flashinfer.logits_processor.Sample

class flashinfer.logits_processor.Sample(deterministic: bool = True, **params: Any)

Sampling processor to generate token indices.

Samples tokens from logits or probability distributions.

TensorType.LOGITS -> TensorType.INDICES | TensorType.PROBS -> TensorType.INDICES
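Conceptually, sampling turns each row's distribution over the vocabulary into one token index. The pure-Python sketch below uses the textbook inverse-CDF method for a single row. Logits are first converted to probabilities with a softmax. The sampler then returns the first index whose running probability sum exceeds a uniform draw `u`. The helper names are hypothetical, and this is not FlashInfer's CUDA kernel. It only illustrates the LOGITS -> INDICES and PROBS -> INDICES mappings.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert logits to probabilities (max-subtracted for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_from_probs(probs: Sequence[float], u: float) -> int:
    """Inverse-CDF sampling: first index whose running sum exceeds u in [0, 1)."""
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return i
    # Rounding can leave the total slightly below 1; return the last token with p > 0.
    return max(i for i, p in enumerate(probs) if p > 0)


def sample_from_logits(logits: Sequence[float], u: float) -> int:
    """LOGITS -> PROBS (softmax) -> INDICES, reusing the probability path."""
    return sample_from_probs(softmax(logits), u)
```

For example, `sample_from_probs([0.2, 0.5, 0.3], 0.5)` returns `1`, because the running sum first exceeds 0.5 after the second token.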

Parameters:
  • deterministic (bool, optional, Compile-time) – Whether to use a deterministic kernel implementation. Default is True.

  • indices (torch.Tensor, optional, Runtime) – Row indices for batched sampling when several outputs share the same probability distribution: output i is sampled from row indices[i] of the input tensor.

  • generator (torch.Generator, optional, Runtime) – Random number generator for reproducible sampling.
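The two runtime parameters can be pictured with the pure-Python analogue below. It assumes that `indices` has the same meaning as in `flashinfer.sampling.sampling_from_probs`: output `i` is drawn from row `indices[i]`, so several outputs can share one distribution. It also uses a seeded `random.Random` as a stand-in for `torch.Generator`. `sample_batch` is a hypothetical helper, not part of FlashInfer.

```python
import random
from typing import List, Optional, Sequence


def sample_batch(
    probs: Sequence[Sequence[float]],
    indices: Optional[Sequence[int]] = None,
    generator: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical analogue of Sample's runtime `indices` and `generator`."""
    rng = generator if generator is not None else random.Random()
    # Without indices, output i uses row i; with indices, output i uses row indices[i].
    rows = range(len(probs)) if indices is None else indices
    out = []
    for row in rows:
        u = rng.random()  # uniform in [0, 1)
        # Fallback if rounding leaves the row sum below 1: last token with p > 0.
        choice = max(i for i, p in enumerate(probs[row]) if p > 0)
        cumulative = 0.0
        for i, p in enumerate(probs[row]):
            cumulative += p
            if u < cumulative:
                choice = i
                break
        out.append(choice)
    return out
```

With `shared = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]`, `sample_batch(shared, indices=[1, 0, 1])` returns `[2, 0, 2]`: three outputs drawn from only two stored rows. Passing two generators created with the same seed yields identical samples, which is the reproducibility the `generator` parameter is meant to provide.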

Examples

>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, Sample, TensorType
>>> torch.manual_seed(42)
>>>
>>> # Sampling from logits
>>> pipe = LogitsPipe([Sample(deterministic=True)], input_type=TensorType.LOGITS)
>>> logits = torch.randn(2, 5, device="cuda")
>>> logits
tensor([[ 0.1940,  2.1614, -0.1721,  0.8491, -1.9244],
        [ 0.6530, -0.6494, -0.8175,  0.5280, -1.2753]], device='cuda:0')
>>> tokens = pipe(logits)
>>> tokens
tensor([0, 1], device='cuda:0')
>>>
>>> # Sampling from probabilities
>>> pipe = LogitsPipe([Sample(deterministic=True)], input_type=TensorType.PROBS)
>>> probs = torch.rand(2, 5, device="cuda")  # non-negative, unlike torch.randn
>>> probs_normed = probs / probs.sum(dim=-1, keepdim=True)  # each row sums to 1
>>> tokens = pipe(probs_normed)
>>> tokens.shape
torch.Size([2])

Notes

Sample outputs TensorType.INDICES, so no operator can follow it: it must be the last processor in a LogitsPipe.
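The restriction follows from type chaining. Each processor maps an accepted input type to an output type, and no processor accepts INDICES. The toy checker below models that rule. `TensorType` here is a local stand-in for `flashinfer.logits_processor.TensorType`. The `Softmax` entry (LOGITS -> PROBS) is an assumption added for illustration, and `check_pipeline` is hypothetical. FlashInfer's own validation may work differently.

```python
from enum import Enum, auto
from typing import Dict, List


class TensorType(Enum):
    """Local stand-in for flashinfer.logits_processor.TensorType."""
    LOGITS = auto()
    PROBS = auto()
    INDICES = auto()


# Hypothetical signature table: processor name -> {accepted input: output}.
SIGNATURES: Dict[str, Dict[TensorType, TensorType]] = {
    "Softmax": {TensorType.LOGITS: TensorType.PROBS},  # assumed for illustration
    "Sample": {
        TensorType.LOGITS: TensorType.INDICES,
        TensorType.PROBS: TensorType.INDICES,
    },
}


def check_pipeline(ops: List[str], input_type: TensorType) -> TensorType:
    """Walk the pipeline, returning the final type or raising on a mismatch."""
    current = input_type
    for name in ops:
        accepted = SIGNATURES[name]
        if current not in accepted:
            # Nothing accepts INDICES, so any processor placed after Sample fails here.
            raise TypeError(f"{name} cannot accept {current.name}")
        current = accepted[current]
    return current
```

`check_pipeline(["Softmax", "Sample"], TensorType.LOGITS)` returns `TensorType.INDICES`, while `["Sample", "Softmax"]` raises `TypeError`.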

See also

sampling_from_logits(), sampling_from_probs()

__init__(deterministic: bool = True, **params: Any)

Constructor for the Sample processor.

Parameters:

  • deterministic (bool, optional) – Whether to use a deterministic kernel implementation. Default is True.
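The documentation says only that `deterministic` selects a deterministic kernel. One general reason parallel GPU reductions can be nondeterministic is that floating-point addition is not associative. If the accumulation order of a prefix sum can vary between runs, the cumulative probabilities an inverse-CDF sampler compares against can vary too. The pure-Python sketch below illustrates that effect. It is an assumption-labeled illustration of the general phenomenon, not a description of FlashInfer's kernels. `left_sum`, `right_sum`, and `sample_with` are hypothetical helpers.

```python
from functools import reduce
from typing import Callable, Sequence

Summer = Callable[[Sequence[float]], float]


def left_sum(xs: Sequence[float]) -> float:
    """Accumulate left to right: ((x0 + x1) + x2) + ..."""
    return reduce(lambda acc, x: acc + x, xs, 0.0)


def right_sum(xs: Sequence[float]) -> float:
    """Accumulate right to left: x0 + (x1 + (x2 + ...))."""
    return reduce(lambda acc, x: x + acc, reversed(xs), 0.0)


def sample_with(probs: Sequence[float], u: float, summer: Summer) -> int:
    """Inverse-CDF sampling where each prefix sum is computed by `summer`."""
    for k in range(len(probs)):
        if u < summer(probs[: k + 1]):
            return k
    return len(probs) - 1
```

With `probs = [0.1, 0.2, 0.3, 0.4]` and `u = 0.6`, the left-to-right prefix sum after three tokens is `0.6000000000000001` and the right-to-left one is `0.6`. As a result, `sample_with(probs, 0.6, left_sum)` returns `2` but `sample_with(probs, 0.6, right_sum)` returns `3`. A fixed accumulation order avoids this run-to-run variation.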

Methods

  • __init__([deterministic]) – Constructor for the Sample processor.

  • legalize(input_type) – Legalize the processor into a list of low-level operators.
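As a rough mental model of `legalize`, the sketch below shows a processor that receives its input type and returns the low-level operator(s) to run. It uses a logits-based sampling op for LOGITS and a probability-based one for PROBS, and it carries the compile-time `deterministic` flag along. `LowLevelOp`, `ToySample`, and the operator names are hypothetical. This page does not document FlashInfer's actual low-level operator classes.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List


class TensorType(Enum):
    """Local stand-in for flashinfer.logits_processor.TensorType."""
    LOGITS = auto()
    PROBS = auto()
    INDICES = auto()


@dataclass(frozen=True)
class LowLevelOp:
    """Hypothetical low-level operator descriptor."""
    name: str
    deterministic: bool


class ToySample:
    """Hypothetical model of Sample.legalize (not FlashInfer's implementation)."""

    def __init__(self, deterministic: bool = True) -> None:
        self.deterministic = deterministic  # compile-time parameter

    def legalize(self, input_type: TensorType) -> List[LowLevelOp]:
        # Choose the operator from the input type; INDICES cannot be sampled again.
        if input_type is TensorType.LOGITS:
            return [LowLevelOp("sample_from_logits", self.deterministic)]
        if input_type is TensorType.PROBS:
            return [LowLevelOp("sample_from_probs", self.deterministic)]
        raise ValueError(f"Sample cannot consume {input_type.name}")
```

In this design the compile-time flag is stored on the processor and baked into the emitted operators. Runtime values such as `indices` and `generator` would be supplied later, when the pipeline is called.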