flashinfer.sampling.top_k_mask_logits

flashinfer.sampling.top_k_mask_logits(logits: torch.Tensor, top_k: torch.Tensor | int) → torch.Tensor

Fused GPU kernel for masking logits by top-k thresholding.

Parameters:
  • logits (torch.Tensor) – Logits before softmax, shape (batch_size, num_classes).

  • top_k (Union[torch.Tensor, int]) – Either a scalar or a tensor of shape (batch_size,), giving the top-k threshold for masking logits; should be in (0, num_classes). If a scalar, the same threshold is used for all requests. If a tensor, each request has its own threshold. The top-k logits in each row are kept and the rest are set to negative infinity.

Returns:

masked_logits – Masked logits, shape (batch_size, num_classes).

Return type:

torch.Tensor

Examples

>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_k = 3
>>> logits = torch.randn(batch_size, vocab_size).to(0)
>>> logits
tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
        [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
        [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
        [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]], device='cuda:0')
>>> masked_logits = flashinfer.sampling.top_k_mask_logits(logits, top_k)
>>> masked_logits
tensor([[ 1.9269,  1.4873,  0.9007,    -inf,    -inf],
        [ 1.0783,  0.8008,  1.6806,    -inf,    -inf],
        [   -inf,  0.2415, -0.2316,  0.0418,    -inf],
        [ 0.8599, -0.3097,    -inf,  0.8034,    -inf]], device='cuda:0')
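
The masking rule, including the per-request form of top_k, can be illustrated with a small pure-PyTorch sketch that runs on CPU. This is not the FlashInfer kernel: the helper name top_k_mask_logits_reference is hypothetical, and its tie handling (every logit equal to the k-th largest is kept) is an assumption that the documentation above does not specify.

```python
import torch


def top_k_mask_logits_reference(logits, top_k):
    """Hypothetical reference for the masking rule; not the fused kernel."""
    batch_size, _ = logits.shape
    if isinstance(top_k, int):
        # A scalar threshold is shared by every request in the batch.
        top_k = torch.full((batch_size,), top_k, dtype=torch.long)
    top_k = top_k.to(device=logits.device, dtype=torch.long)
    # Sort each row in descending order; column k - 1 holds the k-th largest logit.
    sorted_logits, _ = torch.sort(logits, dim=-1, descending=True)
    threshold = sorted_logits.gather(-1, (top_k - 1).unsqueeze(-1))
    # Assumption: logits tied with the threshold are kept; the rest become -inf.
    return logits.masked_fill(logits < threshold, float("-inf"))


# Same values as the example above, kept on CPU.
logits = torch.tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
                       [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
                       [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
                       [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]])

# One threshold per request: row i keeps its top_k[i] largest logits.
per_request_k = torch.tensor([1, 2, 3, 4])
masked = top_k_mask_logits_reference(logits, per_request_k)
kept_per_row = torch.isfinite(masked).sum(dim=-1)
```

With a scalar top_k of 3, the sketch masks the same positions as the kernel output shown above; with the tensor form, each row keeps a different number of finite entries.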

Note

The combination of top_k_mask_logits and softmax should be equivalent to top_k_renorm_probs.
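
The equivalence stated in this note can be checked numerically with plain PyTorch on CPU. The sketch below does not call top_k_mask_logits or top_k_renorm_probs; both paths are hand-written stand-ins, assuming that top-k renormalization means zeroing the probabilities outside the top-k and dividing the rest by their sum.

```python
import torch

logits = torch.tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
                       [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
                       [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
                       [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]])
k = 3

# Path 1: mask logits outside the top-k with -inf, then apply softmax.
kth_logit = torch.topk(logits, k, dim=-1).values[:, -1:]
masked_logits = logits.masked_fill(logits < kth_logit, float("-inf"))
probs_from_mask = torch.softmax(masked_logits, dim=-1)

# Path 2 (stand-in for top-k renormalization): softmax first, zero the
# probabilities outside the top-k, then rescale each row to sum to 1.
probs = torch.softmax(logits, dim=-1)
kth_prob = torch.topk(probs, k, dim=-1).values[:, -1:]
kept = torch.where(probs >= kth_prob, probs, torch.zeros_like(probs))
probs_from_renorm = kept / kept.sum(dim=-1, keepdim=True)

same = torch.allclose(probs_from_mask, probs_from_renorm, atol=1e-6)
```

The two paths agree because exp(-inf) is 0, so masked entries drop out of the softmax denominator, and because softmax preserves the ordering of logits, so both paths select the same top-k set.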