flashinfer.sampling.softmax

flashinfer.sampling.softmax(logits: Tensor, temperature: float | Tensor | None = None, enable_pdl: bool | None = None) → Tensor

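Fused GPU kernel for online safe softmax with temperature scaling. The output is softmax(logits / temperature), computed independently for each row (request) of logits over the last (vocabulary) dimension.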
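The terms in the summary can be made concrete with a plain-Python sketch for a single row. "Safe" means subtracting the running maximum before exponentiating, so large logits cannot overflow; "online" means the maximum and the normalizing sum are accumulated in one sweep, with the sum rescaled whenever a new maximum appears. The helper `online_safe_softmax` is hypothetical and only illustrates the technique; it is not FlashInfer's CUDA kernel, whose internals may differ.

```python
import math
from typing import List


def online_safe_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Hypothetical CPU reference: online safe softmax of one row, with temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    m = -math.inf  # running maximum of the scaled logits seen so far
    d = 0.0        # running normalizer: sum of exp(x - m) over the logits seen so far
    for x in logits:
        x = x / temperature
        m_new = max(m, x)
        # Rescale the old sum from the old maximum to the new one, then add this term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new

    # Second sweep: normalize each term with the final maximum and sum.
    return [math.exp(x / temperature - m) / d for x in logits]


# First row of logits from the example below, at temperature 1.0.
row = [0.8823, 0.9150, 0.3829, 0.9593, 0.3904]
print([round(p, 4) for p in online_safe_softmax(row)])
```

Its output agrees with the first row of `probs` in the example up to rounding. Because only differences `x - m` are exponentiated, inputs such as `[1000.0, 1000.0]` yield `[0.5, 0.5]` instead of overflowing.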

Parameters:
  • logits (torch.Tensor) – Input tensor of logits with shape (batch_size, vocab_size), as in the example below.

  • temperature (Optional[Union[torch.Tensor, float]]) – Temperature by which the logits are divided before the softmax. Either a scalar or a tensor of shape (batch_size,). If a scalar, the same temperature is used for all requests; if a tensor, each request (row of logits) has its own temperature.

  • enable_pdl (Optional[bool]) – Whether to enable Programmatic Dependent Launch (PDL) for improved performance on supported hardware. If None (default), PDL will be automatically enabled on devices with compute capability >= 9.0.
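The scalar-versus-tensor temperature semantics can be pinned down with an unfused PyTorch reference, which is also handy for testing. This is a sketch, not FlashInfer's implementation: `softmax_reference` is a hypothetical helper, it assumes 2-D logits of shape (batch_size, vocab_size), and treating `temperature=None` as 1.0 is an assumption of the sketch.

```python
from typing import Optional, Union

import torch


def softmax_reference(
    logits: torch.Tensor,
    temperature: Optional[Union[float, torch.Tensor]] = None,
) -> torch.Tensor:
    """Hypothetical unfused reference: softmax(logits / temperature) over the last dim."""
    if temperature is None:
        # Assumption of this sketch: no temperature behaves like temperature 1.0.
        temperature = 1.0
    if isinstance(temperature, torch.Tensor):
        # Per-request temperatures: (batch_size,) -> (batch_size, 1), so each row
        # of logits is divided by its own value via broadcasting.
        temperature = temperature.to(device=logits.device, dtype=logits.dtype).unsqueeze(-1)
    return torch.softmax(logits / temperature, dim=-1)


logits = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
# Row 0 uses temperature 1.0, row 1 uses 0.5 (a sharper distribution).
probs = softmax_reference(logits, torch.tensor([1.0, 0.5]))
print(probs)
```

A temperature below 1.0 sharpens the distribution and one above 1.0 flattens it. On a machine with FlashInfer and a CUDA GPU, the output of `flashinfer.sampling.softmax` for the same inputs can be compared against this reference with `torch.testing.assert_close`, allowing for floating-point tolerance.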

Returns:

probs – Tensor of the same shape as logits, containing the softmax probabilities (each row sums to 1).

Return type:

torch.Tensor

Examples

>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> logits = torch.rand(batch_size, vocab_size).to(0)
>>> logits
tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
        [0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
        [0.9346, 0.5936, 0.8694, 0.5677, 0.7411],
        [0.4294, 0.8854, 0.5739, 0.2666, 0.6274]], device='cuda:0')
>>> probs = flashinfer.sampling.softmax(logits, temperature=1.0)
>>> probs
tensor([[0.2309, 0.2385, 0.1401, 0.2493, 0.1412],
        [0.2019, 0.1431, 0.2448, 0.2837, 0.1265],
        [0.2401, 0.1707, 0.2249, 0.1664, 0.1979],
        [0.1724, 0.2719, 0.1991, 0.1465, 0.2101]], device='cuda:0')
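The printed output can be checked on a CPU without FlashInfer: with temperature=1.0 it should match PyTorch's `torch.softmax` over the last dimension. The sketch below copies the printed tensors, which are rounded to four decimals, so it compares with a loose tolerance.

```python
import torch

# Logits as printed in the example (rounded to 4 decimals).
logits = torch.tensor([
    [0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
    [0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
    [0.9346, 0.5936, 0.8694, 0.5677, 0.7411],
    [0.4294, 0.8854, 0.5739, 0.2666, 0.6274],
])
# probs as printed by flashinfer.sampling.softmax(logits, temperature=1.0).
expected = torch.tensor([
    [0.2309, 0.2385, 0.1401, 0.2493, 0.1412],
    [0.2019, 0.1431, 0.2448, 0.2837, 0.1265],
    [0.2401, 0.1707, 0.2249, 0.1664, 0.1979],
    [0.1724, 0.2719, 0.1991, 0.1465, 0.2101],
])

temperature = 1.0
reference = torch.softmax(logits / temperature, dim=-1)

# Loose tolerance because both tensors were rounded when printed.
assert torch.allclose(reference, expected, atol=1e-3)
# Each row is a probability distribution over the vocabulary.
assert torch.allclose(reference.sum(dim=-1), torch.ones(logits.shape[0]))
```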