flashinfer.logits_processor.Temperature
- class flashinfer.logits_processor.Temperature(**params: Any)
Temperature scaling processor for logits.
Scales logits by dividing by a temperature value.
TensorType.LOGITS -> TensorType.LOGITS
- Parameters:
temperature (float, Runtime) – Temperature value for scaling. Must be positive.
Examples
>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, Temperature, Sample
>>> torch.manual_seed(42)
>>> pipe = LogitsPipe([Temperature()])
>>> logits = torch.randn(2, 2, device="cuda")
>>> logits
tensor([[ 0.1940,  2.1614],
        [-0.1721,  0.8491]], device='cuda:0')
>>> scaled_logits = pipe(logits, temperature=0.8)
>>> scaled_logits
tensor([[ 0.2425,  2.7017],
        [-0.2151,  1.0613]], device='cuda:0')
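The arithmetic behind the example above can be checked without a GPU. The following is a minimal pure-Python sketch of temperature scaling, not FlashInfer's implementation: `scale_by_temperature` and `softmax` are hypothetical helper names introduced here for illustration, and the positivity check simply mirrors the documented requirement that the temperature be positive.

```python
import math


def scale_by_temperature(logits, temperature):
    """Hypothetical reference: divide every logit by a positive temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return [[x / temperature for x in row] for row in logits]


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Same input values as the doctest above, written as nested lists.
logits = [[0.1940, 2.1614], [-0.1721, 0.8491]]
scaled = scale_by_temperature(logits, 0.8)

# A temperature below 1 widens the gaps between logits, so the
# softmax distribution becomes more peaked on the largest logit.
probs_before = [softmax(row) for row in logits]
probs_after = [softmax(row) for row in scaled]
```

Under these assumptions, `scaled` agrees with the `scaled_logits` tensor in the example up to rounding, and each row of `probs_after` puts more mass on its largest logit than the matching row of `probs_before`. A temperature above 1 would flatten the distribution instead.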
- __init__(**params: Any)
Constructor for Temperature processor. No compile-time parameters are needed.
Methods
__init__(**params): Constructor for Temperature processor.
legalize(input_type): Legalize the processor into a list of low-level operators.