flashinfer.logits_processor.Sample
- class flashinfer.logits_processor.Sample(deterministic: bool = True, **params: Any)
Sampling processor to generate token indices.
Samples tokens from logits or probability distributions.
TensorType.LOGITS -> TensorType.INDICES | TensorType.PROBS -> TensorType.INDICES
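To make the two accepted input types concrete, the following is a minimal CPU sketch in pure Python. It is not FlashInfer's kernel, and the helper names (softmax, sample_from_probs, sample_from_logits) are hypothetical, introduced only for illustration. It assumes that sampling from logits is equivalent to applying softmax and then sampling from the resulting categorical distribution.

```python
import math
import random

def softmax(logits):
    """Convert one row of logits into probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_from_probs(probs, rng):
    """PROBS -> INDICES: draw one index by inverse-CDF sampling."""
    u = rng.random() * sum(probs)  # scale by the sum to tolerate rounding error
    cumulative = 0.0
    for idx, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return idx
    return len(probs) - 1  # guard against floating-point round-off

def sample_from_logits(logits, rng):
    """LOGITS -> INDICES: softmax first, then sample (assumed equivalence)."""
    return sample_from_probs(softmax(logits), rng)

# Hypothetical stand-in for a (2, 5) logits tensor
rng = random.Random(42)
logits_batch = [[0.1940, 2.1614, -0.1721, 0.8491, -1.9244],
                [0.6530, -0.6494, -0.8175, 0.5280, -1.2753]]
tokens = [sample_from_logits(row, rng) for row in logits_batch]
print(tokens)  # one token index per row, each in range(5)
```

Either path ends in a single index per row, which is why both input types map to TensorType.INDICES.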
- Parameters:
deterministic (bool, optional, Compile-time) – Whether to use deterministic kernel implementation. Default is True.
indices (torch.Tensor, optional, Runtime) – Indices for batched sampling when probability tensors are shared.
generator (torch.Generator, optional, Runtime) – Random number generator for reproducible sampling.
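The roles of indices and generator can be sketched on the CPU in pure Python. This is an illustration under stated assumptions, not the library's implementation: it assumes that indices[i] = j means output i is drawn from row j of the probability tensor, and it uses random.Random as a stand-in for torch.Generator. The function name batched_sample and its seed argument are hypothetical.

```python
import random

def batched_sample(probs, indices=None, seed=None):
    """Draw one index per entry of `indices`, reusing shared rows of `probs`.

    Assumption: indices[i] = j means output i is sampled from probs[j].
    A seeded random.Random plays the role of a torch.Generator here.
    """
    rng = random.Random(seed)
    if indices is None:
        indices = range(len(probs))  # default: one output per row
    out = []
    for row in indices:
        weights = probs[row]
        # random.choices draws according to the given weights
        out.append(rng.choices(range(len(weights)), weights=weights, k=1)[0])
    return out

probs = [[0.1, 0.2, 0.7],
         [0.5, 0.5, 0.0]]
indices = [0, 0, 0, 1]  # three outputs share row 0, one uses row 1

first = batched_sample(probs, indices, seed=123)
second = batched_sample(probs, indices, seed=123)
print(first == second)  # True: the same seed reproduces the same draws
```

Sharing rows through an index list avoids copying a distribution once per output, and seeding the random source is what makes repeated runs comparable.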
Examples
>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, Sample, TensorType
>>> torch.manual_seed(42)
>>>
>>> # Sampling from logits
>>> pipe = LogitsPipe([Sample(deterministic=True)], input_type=TensorType.LOGITS)
>>> logits = torch.randn(2, 5, device="cuda")
>>> logits
tensor([[ 0.1940,  2.1614, -0.1721,  0.8491, -1.9244],
        [ 0.6530, -0.6494, -0.8175,  0.5280, -1.2753]], device='cuda:0')
>>> tokens = pipe(logits, top_k=1)
>>> tokens
tensor([0, 1], device='cuda:0')
>>>
>>> # Sampling from probabilities
>>> pipe = LogitsPipe([Sample(deterministic=True)], input_type=TensorType.PROBS)
>>> probs = torch.randn(2, 5, device="cuda")
>>> probs_normed = probs / probs.sum(dim=-1, keepdim=True)
>>> probs_normed
tensor([[ 2.8827,  0.0870,  0.2340, -3.2731,  1.0694],
        [ 0.3526,  0.0928,  0.1601, -0.1737,  0.5683]], device='cuda:0')
>>> tokens = pipe(probs_normed, top_k=1)
>>> tokens
tensor([0, 0], device='cuda:0')
Notes
Outputs TensorType.INDICES; no operators can follow.
See also
sampling_from_logits(), sampling_from_probs()
- __init__(deterministic: bool = True, **params: Any)
Constructor for Sample processor.
- Parameters:
deterministic (bool, optional) – Whether to use deterministic kernel implementation. Default is True.
Methods
__init__([deterministic]) – Constructor for Sample processor.
legalize(input_type) – Legalize the processor into a list of low-level operators.