flashinfer.sampling.chain_speculative_sampling

flashinfer.sampling.chain_speculative_sampling(draft_probs, draft_token_ids, target_probs, maybe_output_accepted_token_num: Tensor | None = None, maybe_output_emitted_draft_token_num: Tensor | None = None, deterministic: bool = True, generator: Generator | None = None, seed: int | Tensor | None = None, offset: int | Tensor | None = None) -> Tensor

Fused GPU kernel for speculative sampling in sequence generation (proposed in the paper Accelerating Large Language Model Decoding with Speculative Sampling), where the draft model generates a chain (sequence) of tokens for each request.
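
At each position, the kernel performs chain rejection sampling: a draft token with draft probability q and target probability p is accepted when a uniform draw r satisfies r < p/q; at the first rejection, a replacement token is sampled from the normalized positive part of (p - q) and the remaining slots are padded with -1; if every draft token is accepted, one extra "bonus" token is sampled from the target model's final distribution. The following single-request PyTorch sketch illustrates this logic (the helper name is hypothetical, and this is not the fused kernel, which runs the loop for the whole batch on the GPU):

import torch

def chain_speculative_sampling_reference(draft_probs, draft_token_ids, target_probs, generator=None):
    # Single-request reference; shapes are the documented ones without the batch dim:
    #   draft_probs: (num_speculate_tokens, vocab_size)
    #   draft_token_ids: (num_speculate_tokens,)
    #   target_probs: (num_speculate_tokens + 1, vocab_size)
    n = draft_token_ids.shape[0]
    output = torch.full((n + 1,), -1, dtype=torch.int32, device=draft_probs.device)
    for i in range(n):
        tok = draft_token_ids[i]
        p = target_probs[i, tok]  # target probability of the draft token
        q = draft_probs[i, tok]   # draft probability of the draft token
        r = torch.rand((), device=draft_probs.device, generator=generator)
        if r < p / q:             # accept with probability min(1, p/q)
            output[i] = tok
        else:
            # First rejection: resample from the normalized positive part of
            # (p - q), keep the remaining slots padded with -1, and stop.
            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
            output[i] = torch.multinomial(residual / residual.sum(), 1, generator=generator)
            return output
    # All draft tokens accepted: sample the extra "bonus" token from the
    # target model's final distribution.
    output[n] = torch.multinomial(target_probs[n], 1, generator=generator)
    return output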

Parameters:
  • draft_probs (torch.Tensor) – The probability over the vocabulary generated by the draft model. Shape: (batch_size, num_speculate_tokens, vocab_size)

  • draft_token_ids (torch.Tensor) – The draft model’s generated token indices. Shape: (batch_size, num_speculate_tokens)

  • target_probs (torch.Tensor) – The probability over the vocabulary generated by the target model. Compared to the input draft_probs, the target model’s probability has one additional slot at the end, because the target model generates one more token than the draft model. Shape: (batch_size, num_speculate_tokens + 1, vocab_size)

  • maybe_output_accepted_token_num (Optional[torch.Tensor]) – The number of tokens that would be accepted if each token were considered independently for each request. This metric ignores the fact that rejection sampling stops at the first token failing the acceptance test r < p/q; it only measures the alignment between the draft and target models. Shape: (batch_size). If provided, the accepted-token counts are added to this tensor in place. Default is None.

  • maybe_output_emitted_draft_token_num (Optional[torch.Tensor]) – The number of draft tokens that are finally emitted for each request. Does not include the bonus token (thus the total number of tokens sampled for a given request is output_emitted_draft_token_num + 1). Shape: (batch_size). If provided, the emitted-token counts are added to this tensor in place; see the usage sketch after this parameter list. Default is None.

  • deterministic (bool) – Whether to use the deterministic kernel implementation. Default is True.

  • generator (Optional[torch.Generator]) – A random number generator for the operation.

  • seed (Optional[Union[int, torch.Tensor]]) –

    Random seed value for the sampling operation. Can be either an integer or a torch.Tensor. When provided as a torch.Tensor, it must be of int64 or uint64 dtype, 1D, and of length 1 or batch_size. Using a torch.Tensor is required for CUDA graph compatibility.

    Warning: If you provide seed and offset explicitly, you are responsible for updating their values between calls to ensure different random samples. Common approaches include:

      - Incrementing offset by the number of random values consumed
      - Updating seed based on the number of calls to the operation

  • offset (Optional[Union[int, torch.Tensor]]) –

    Random offset value for the sampling operation. Can be either an integer or a torch.Tensor. When provided as a torch.Tensor, it must be of int64 or uint64 dtype, 1D, and of length 1 or batch_size. Using a torch.Tensor is required for CUDA graph compatibility.

    Warning: If you provide seed and offset explicitly, you are responsible for updating their values between calls to ensure different random samples. The offset should be incremented based on the number of random values consumed by the operation, as shown in the usage sketch below.
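
The sketch below shows one way to use the in-place counters together with tensor-valued seed and offset. It is illustrative only: it assumes all tensors live on a CUDA device, that the counters use the default int64 dtype, and it treats num_speculate_tokens + 1 as an upper bound on the random values consumed per request per call; check these assumptions against your installed version.

import torch
import flashinfer

batch_size, num_speculate_tokens, vocab_size = 4, 3, 128
device = "cuda"

# Preallocated counters; the kernel accumulates into them in place across calls.
accepted_num = torch.zeros(batch_size, dtype=torch.int64, device=device)
emitted_num = torch.zeros(batch_size, dtype=torch.int64, device=device)

# Tensor-valued seed/offset so they can be updated without re-capturing a
# CUDA graph (device placement is an assumption here).
seed = torch.zeros(1, dtype=torch.int64, device=device)
offset = torch.zeros(1, dtype=torch.int64, device=device)

num_calls = 8
for step in range(num_calls):
    draft_probs = torch.softmax(
        torch.randn(batch_size, num_speculate_tokens, vocab_size, device=device), dim=-1)
    draft_token_ids = torch.multinomial(
        draft_probs.view(-1, vocab_size), 1).view(batch_size, num_speculate_tokens).int()
    target_probs = torch.softmax(
        torch.randn(batch_size, num_speculate_tokens + 1, vocab_size, device=device), dim=-1)
    output_token_ids, accepted_num, emitted_num = \
        flashinfer.sampling.chain_speculative_sampling(
            draft_probs, draft_token_ids, target_probs,
            maybe_output_accepted_token_num=accepted_num,
            maybe_output_emitted_draft_token_num=emitted_num,
            seed=seed, offset=offset)
    # Advance offset so the next call draws fresh random values (assumed upper
    # bound: one uniform per draft token plus one bonus sample per request).
    offset += num_speculate_tokens + 1

# Fraction of draft tokens emitted across all calls (bonus tokens excluded).
emission_rate = emitted_num.float() / (num_calls * num_speculate_tokens)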

Returns:

  • output_token_ids (torch.Tensor) – The output token indices verified by the target model; rejected positions are padded with -1. Compared to the input draft_token_ids, the output tensor has one additional slot at the end for the final token: if all draft tokens are accepted, a “bonus” token is sampled there from the target model’s probability. Shape: (batch_size, num_speculate_tokens + 1)

  • output_accepted_token_num (torch.Tensor) – The number of tokens that would be accepted if each token were considered independently for each request. This metric ignores the fact that rejection sampling stops at the first token failing the acceptance test r < p/q; it only measures the alignment between the draft and target models. Shape: (batch_size)

  • output_emitted_draft_token_num (torch.Tensor) – The number of draft tokens that are finally emitted for each request. Does not include the bonus token (thus the total number of tokens sampled for a given request is output_emitted_draft_token_num + 1). Shape: (batch_size)

Examples

>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 1
>>> num_speculate_tokens = 2
>>> vocab_size = 4
>>> draft_probs = torch.tensor([[[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.1]]]).to(0)
>>> # token 2 was sampled from draft model for the first token, and
>>> # token 1 was sampled from draft model for the second token
>>> draft_token_ids = torch.tensor([[2, 1]], dtype=torch.int32).to(0)
>>> target_probs = torch.tensor([[[0.0, 0.1, 0.6, 0.3], [1.0, 0.0, 0.0, 0.0], [0.7, 0.1, 0.1, 0.1]]]).to(0)
>>> output_token_ids, output_accepted_token_num, output_emitted_draft_token_num =\
...     flashinfer.sampling.chain_speculative_sampling(
...         draft_probs, draft_token_ids, target_probs)
>>> # the first token is accepted, the second token is rejected and sampled from the difference
>>> # between the target model and the draft model, the third token is padded with -1
>>> output_token_ids
tensor([[ 2,  0, -1]], device='cuda:0', dtype=torch.int32)
>>> output_accepted_token_num
tensor([1], device='cuda:0')
>>> output_emitted_draft_token_num
tensor([1], device='cuda:0')
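
A second, hand-constructed case: the target distributions are one-hot on the draft tokens, so both draft tokens are always accepted and the final slot holds the bonus token drawn from the target model’s (also one-hot) last distribution. The outputs shown are forced by these distributions rather than by the random seed:

>>> draft_probs = torch.tensor([[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]]).to(0)
>>> draft_token_ids = torch.tensor([[1, 3]], dtype=torch.int32).to(0)
>>> target_probs = torch.tensor([[[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0]]]).to(0)
>>> output_token_ids, output_accepted_token_num, output_emitted_draft_token_num =\
...     flashinfer.sampling.chain_speculative_sampling(
...         draft_probs, draft_token_ids, target_probs)
>>> # both draft tokens are accepted (p/q = 4 at each position), and the bonus
>>> # token 2 is sampled from the one-hot final target distribution
>>> output_token_ids
tensor([[1, 3, 2]], device='cuda:0', dtype=torch.int32)
>>> output_accepted_token_num
tensor([2], device='cuda:0')
>>> output_emitted_draft_token_num
tensor([2], device='cuda:0')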