flashinfer.logits_processor.Temperature

class flashinfer.logits_processor.Temperature(**params: Any)

Temperature scaling processor for logits.

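Scales logits by dividing them elementwise by the temperature: output = logits / temperature. Values below 1 sharpen the distribution obtained from the scaled logits; values above 1 flatten it.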
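The arithmetic is plain elementwise division. The sketch below uses pure Python and needs neither FlashInfer nor CUDA. It reproduces the numbers in the example further down for temperature=0.8. The helper scale_by_temperature is hypothetical and not part of FlashInfer. Raising an error for a non-positive temperature is an assumption based on the "Must be positive" requirement; this page does not say how the library reacts to such a value.

```python
def scale_by_temperature(logits, temperature):
    """Divide every logit by `temperature` (hypothetical helper, not FlashInfer API)."""
    # Assumption: reject non-positive values, mirroring "Must be positive".
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return [[x / temperature for x in row] for row in logits]


# The same 2x2 values printed in the docstring example (rounded to 4 decimals).
logits = [[0.1940, 2.1614], [-0.1721, 0.8491]]
scaled = scale_by_temperature(logits, 0.8)
print([[round(x, 4) for x in row] for row in scaled])
```

Dividing by 0.8 is the same as multiplying by 1.25. That is why every value in the example's output is 25% larger in magnitude than its input.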

TensorType.LOGITS -> TensorType.LOGITS

Parameters:

temperature (float, Runtime) – Temperature value for scaling. Must be positive.
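The temperature only changes which tokens get sampled once the scaled logits pass through a softmax. The pure-Python sketch below shows that effect on one row of the example logits. The helper softmax_with_temperature is hypothetical and is not a FlashInfer function. As the temperature decreases, probability mass concentrates on the largest logit. Very large temperatures push the distribution toward uniform.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature for a single row (hypothetical helper)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max before exp() for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


row = [0.1940, 2.1614]  # first row of the docstring example
for t in (0.5, 0.8, 1.0, 2.0):
    probs = softmax_with_temperature(row, t)
    print(f"T={t}: p(larger logit) = {probs[1]:.4f}")
```

The division has to happen before the softmax. Dividing the probabilities afterwards would only rescale them uniformly and leave the distribution's shape unchanged.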

Examples

>>> import torch
>>> from flashinfer.logits_processor import LogitsPipe, Temperature
>>> torch.manual_seed(42)
>>> pipe = LogitsPipe([Temperature()])
>>> logits = torch.randn(2, 2, device="cuda")
>>> logits
tensor([[ 0.1940,  2.1614],
        [-0.1721,  0.8491]], device='cuda:0')
>>> scaled_logits = pipe(logits, temperature=0.8)
>>> scaled_logits
tensor([[ 0.2425,  2.7017],
        [-0.2151,  1.0613]], device='cuda:0')

__init__(**params: Any)

Constructor for the Temperature processor. No compile-time parameters are needed; the temperature is supplied at call time as a runtime parameter.

Methods

__init__(**params)

Constructor for the Temperature processor.

legalize(input_type)

Legalize the processor into a list of low-level operators.
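This page does not document what legalize produces internally. The toy sketch below is a hypothetical illustration of the general idea and is not FlashInfer's implementation. A high-level processor lowers itself into a list of low-level operators, and a pipeline applies those operators in order, with runtime parameters such as temperature supplied only at call time. Several names are stand-ins: ToyTemperature, scale_op, run_pipe, and the string "logits" (used in place of TensorType.LOGITS).

```python
from typing import Callable, Dict, List

# A low-level op takes a row of logits plus the runtime parameters.
Op = Callable[[List[float], Dict[str, float]], List[float]]


def scale_op(logits: List[float], params: Dict[str, float]) -> List[float]:
    """Hypothetical low-level op: divide by the runtime `temperature` parameter."""
    return [x / params["temperature"] for x in logits]


class ToyTemperature:
    """Hypothetical stand-in for a high-level processor; not FlashInfer's class."""

    def legalize(self, input_type: str) -> List[Op]:
        # "logits" stands in for TensorType.LOGITS in this toy.
        if input_type != "logits":
            raise TypeError("temperature scaling expects logits")
        return [scale_op]


def run_pipe(processors, logits: List[float], **params: float) -> List[float]:
    """Lower every processor to ops, then apply the ops in sequence."""
    ops = [op for p in processors for op in p.legalize("logits")]
    for op in ops:
        logits = op(logits, params)
    return logits


out = run_pipe([ToyTemperature()], [0.1940, 2.1614], temperature=0.8)
print(out)
```

The processor object holds no temperature value. That matches the constructor note above: the value arrives only when the pipeline is called, as in pipe(logits, temperature=0.8).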