flashinfer.sampling.top_k_mask_logits
- flashinfer.sampling.top_k_mask_logits(logits: torch.Tensor, top_k: torch.Tensor | int) → torch.Tensor
Fused GPU kernel for masking logits by top-k thresholding.
- Parameters:
  - logits (torch.Tensor) – Logits before softmax, shape (batch_size, num_classes).
  - top_k (Union[torch.Tensor, int]) – Either a scalar or a tensor of shape (batch_size,), representing the top-k threshold for masking logits; should be in (0, num_classes). If a scalar, the same threshold is used for all requests; if a tensor, each request has its own threshold. We keep the top-k logits and set the rest to negative infinity.
- Returns:
  masked_logits – Masked logits, shape (batch_size, num_classes).
- Return type:
  torch.Tensor
Examples
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_k = 3
>>> logits = torch.randn(batch_size, vocab_size).to(0)
>>> logits
tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
        [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
        [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
        [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]], device='cuda:0')
>>> masked_logits = flashinfer.sampling.top_k_mask_logits(logits, top_k)
>>> masked_logits
tensor([[ 1.9269,  1.4873,  0.9007,    -inf,    -inf],
        [ 1.0783,  0.8008,  1.6806,    -inf,    -inf],
        [   -inf,  0.2415, -0.2316,  0.0418,    -inf],
        [ 0.8599, -0.3097,    -inf,  0.8034,    -inf]], device='cuda:0')
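The semantics of the fused kernel can be sketched with a plain (unfused) PyTorch reference that runs on CPU. The helper name `top_k_mask_logits_ref` and the per-row loop are illustrative assumptions, not part of the flashinfer API; the fused kernel avoids this loop entirely.

```python
import torch

def top_k_mask_logits_ref(logits: torch.Tensor, top_k) -> torch.Tensor:
    # Reference semantics: per row, keep the top-k logits and set the
    # rest to -inf. `top_k` may be an int or a (batch_size,) tensor,
    # mirroring the documented parameter.
    batch_size, _ = logits.shape
    if isinstance(top_k, int):
        top_k = torch.full((batch_size,), top_k, dtype=torch.long)
    masked = torch.full_like(logits, float("-inf"))
    for i in range(batch_size):
        vals, idx = torch.topk(logits[i], int(top_k[i]))
        masked[i, idx] = vals
    return masked

logits = torch.tensor([[1.0, 3.0, 2.0, 0.5],
                       [0.1, 0.2, 0.3, 0.4]])
print(top_k_mask_logits_ref(logits, 2))
```

Per-request thresholds follow the same path: passing `torch.tensor([2, 3])` keeps two logits in the first row and three in the second.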
Note

The combination of top_k_mask_logits and softmax should be equivalent to top_k_renorm_probs.

See also
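The equivalence stated in the note can be checked numerically with plain PyTorch: softmax over -inf-masked logits assigns zero probability to the masked entries, and the shared softmax denominator cancels, leaving exactly the renormalized top-k probabilities. This is a standalone sketch of that identity, not a flashinfer call.

```python
import torch

logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
k = 2

# Path 1: mask logits to -inf outside the top-k, then softmax.
vals, idx = torch.topk(logits, k, dim=-1)
masked = torch.full_like(logits, float("-inf")).scatter(-1, idx, vals)
probs_via_mask = torch.softmax(masked, dim=-1)

# Path 2: softmax first, then zero out non-top-k entries and renormalize
# (the semantics of top_k_renorm_probs).
probs = torch.softmax(logits, dim=-1)
renorm = torch.zeros_like(probs).scatter(-1, idx, probs.gather(-1, idx))
renorm = renorm / renorm.sum(dim=-1, keepdim=True)

assert torch.allclose(probs_via_mask, renorm)
```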