flashinfer.sampling.softmax
- flashinfer.sampling.softmax(logits: Tensor, temperature: float | Tensor | None = None, enable_pdl: bool | None = None) → Tensor
Fused GPU kernel for online safe softmax with temperature scaling.
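The "online safe softmax with temperature scaling" idea can be sketched in plain Python. This is a minimal CPU illustration, not FlashInfer's CUDA kernel. It assumes the common definition of temperature scaling, in which each logit is divided by the temperature before the softmax, and the name online_softmax is hypothetical. "Safe" means the maximum is subtracted before exponentiating so no exponent is positive; "online" means the maximum and the normalizer are accumulated in a single pass, rescaling the partial sum whenever a larger value appears.

```python
import math
from typing import List


def online_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Hypothetical CPU reference: online safe softmax with temperature scaling.

    Pass 1 keeps a running maximum ``m`` and a running denominator ``d``.
    When a larger scaled logit arrives, ``d`` is rescaled by exp(m_old - m_new)
    so every exponent stays <= 0 and cannot overflow.
    Pass 2 normalizes each element with the final ``m`` and ``d``.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    m = -math.inf  # running max of logit / temperature
    d = 0.0        # running sum of exp(logit / temperature - m)
    for logit in logits:
        x = logit / temperature
        m_new = max(m, x)
        # Rescale the old sum to the new max, then add the new term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(logit / temperature - m) / d for logit in logits]


if __name__ == "__main__":
    # A naive math.exp(1000.0) raises OverflowError; the safe version does not.
    print(online_softmax([1000.0, 1000.0]))
    # Higher temperature flattens the distribution, lower sharpens it.
    print(online_softmax([1.0, 2.0, 3.0], temperature=10.0))
    print(online_softmax([1.0, 2.0, 3.0], temperature=0.5))
```

For two equal logits of 1000.0 this sketch returns [0.5, 0.5], whereas exponentiating the raw values directly would overflow.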
- Parameters:
logits (torch.Tensor) – Input tensor of logits.
temperature (Optional[Union[torch.Tensor, float]]) – Either a scalar or a tensor of shape (batch_size,), representing the temperature for temperature scaling. If a scalar, the same temperature is used for all requests. If a tensor, each request has its own temperature.
enable_pdl (Optional[bool]) – Whether to enable Programmatic Dependent Launch (PDL) for improved performance on supported hardware. If None (default), PDL is enabled automatically on devices with compute capability >= 9.0.
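The two accepted forms of temperature differ only in how the value is broadcast over the batch. The hypothetical batched_softmax below is a CPU reference that mirrors the parameter description: a float is reused for every row, while a sequence of length batch_size assigns temperature[i] to row i. It assumes logits are divided by the temperature, and it models neither device placement nor dtype, which the real torch.Tensor argument may constrain.

```python
import math
from typing import List, Sequence, Union


def batched_softmax(
    logits: Sequence[Sequence[float]],
    temperature: Union[float, Sequence[float]] = 1.0,
) -> List[List[float]]:
    """Hypothetical CPU reference for the scalar vs. per-request temperature semantics."""
    batch_size = len(logits)
    if isinstance(temperature, (int, float)):
        temps = [float(temperature)] * batch_size  # broadcast the scalar to every request
    else:
        temps = [float(t) for t in temperature]
        if len(temps) != batch_size:
            raise ValueError("temperature must have shape (batch_size,)")
    out = []
    for row, t in zip(logits, temps):
        scaled = [x / t for x in row]
        m = max(scaled)  # subtract the row max so exp() cannot overflow
        exps = [math.exp(x - m) for x in scaled]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


if __name__ == "__main__":
    logits = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
    same = batched_softmax(logits, 1.0)             # one temperature for all requests
    per_request = batched_softmax(logits, [1.0, 5.0])  # request 1 uses a hotter temperature
    print(same)
    print(per_request)
```

In this sketch, raising one request's temperature flattens only that request's distribution; rows with an unchanged temperature produce identical results to the scalar call.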
- Returns:
probs – Tensor of the same shape as input containing the softmax probabilities.
- Return type:
torch.Tensor
Examples
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> logits = torch.rand(batch_size, vocab_size).to(0)
>>> logits
tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
        [0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
        [0.9346, 0.5936, 0.8694, 0.5677, 0.7411],
        [0.4294, 0.8854, 0.5739, 0.2666, 0.6274]], device='cuda:0')
>>> probs = flashinfer.sampling.softmax(logits, temperature=1.0)
>>> probs
tensor([[0.2309, 0.2385, 0.1401, 0.2493, 0.1412],
        [0.2019, 0.1431, 0.2448, 0.2837, 0.1265],
        [0.2401, 0.1707, 0.2249, 0.1664, 0.1979],
        [0.1724, 0.2719, 0.1991, 0.1465, 0.2101]], device='cuda:0')
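The example output can be cross-checked on the CPU without a GPU or FlashInfer: with temperature=1.0 each row of probs should match an ordinary softmax of the corresponding row of logits. The sketch below copies the rounded values from the example and compares them with a plain max-subtracted softmax (softmax_row is a hypothetical helper). The 1e-3 tolerance used when checking is an assumption chosen to absorb the 4-decimal rounding of the printout.

```python
import math
from typing import List

# Values copied from the example output (printed there to 4 decimals).
logits = [
    [0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
    [0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
    [0.9346, 0.5936, 0.8694, 0.5677, 0.7411],
    [0.4294, 0.8854, 0.5739, 0.2666, 0.6274],
]
printed_probs = [
    [0.2309, 0.2385, 0.1401, 0.2493, 0.1412],
    [0.2019, 0.1431, 0.2448, 0.2837, 0.1265],
    [0.2401, 0.1707, 0.2249, 0.1664, 0.1979],
    [0.1724, 0.2719, 0.1991, 0.1465, 0.2101],
]


def softmax_row(row: List[float], temperature: float = 1.0) -> List[float]:
    """Max-subtracted softmax of one row, with logits divided by temperature."""
    m = max(row)
    exps = [math.exp((x - m) / temperature) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


reference = [softmax_row(row, temperature=1.0) for row in logits]
# Largest elementwise gap between the CPU reference and the printed probabilities.
max_err = max(
    abs(r - p)
    for ref_row, printed_row in zip(reference, printed_probs)
    for r, p in zip(ref_row, printed_row)
)
print(f"max |reference - printed| = {max_err:.1e}")
```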