flashinfer.sampling.chain_speculative_sampling#

flashinfer.sampling.chain_speculative_sampling(draft_probs, draft_token_ids, uniform_samples, target_probs)#

Fused GPU kernel for chain speculative sampling in sequence generation (proposed in the paper Accelerating Large Language Model Decoding with Speculative Sampling), where the draft model generates a sequence (chain) of tokens for each request.

Parameters:
  • draft_probs (torch.Tensor) – The probability over vocabulary generated by draft model. Shape: (batch_size, num_speculate_tokens, vocab_size)

  • draft_token_ids (torch.Tensor) – The token indices generated by the draft model. Shape: (batch_size, num_speculate_tokens)

  • uniform_samples (torch.Tensor) – The uniform random samples used for the accept/reject decisions during sampling, shape (batch_size, num_speculate_tokens + 1). Expected to be uniformly distributed in [0, 1).

  • target_probs (torch.Tensor) – The probability over vocabulary generated by the target model. Compared to draft_probs, the target model’s probability has one additional slot at the end, because the target model generates one more token than the draft model. Shape: (batch_size, num_speculate_tokens + 1, vocab_size)

Returns:

output_token_ids – The output token indices verified by the target model; rejected positions are padded with -1. Compared to draft_token_ids, the output tensor has one additional slot at the end for the final token: if all draft tokens are accepted, a “bonus” token is sampled from the target model’s probability distribution. Shape: (batch_size, num_speculate_tokens + 1)

Return type:

torch.Tensor
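To make the accept/reject semantics above concrete, here is a simplified pure-Python reference for a single request (no batching, no fused kernel). It is a sketch of the standard chain speculative sampling rule, not the actual CUDA implementation: a draft token is accepted when u * draft_prob < target_prob, and on the first rejection a replacement token is drawn from the residual distribution norm(max(target - draft, 0)). The choice of which uniform sample is consumed for the resample (here, the next one) is an assumption for illustration; the helper names are hypothetical.

```python
def inverse_cdf_sample(probs, u):
    # Return the smallest index whose cumulative probability exceeds u.
    cum = 0.0
    for i, p in enumerate(probs):
        cum += p
        if u < cum:
            return i
    return len(probs) - 1


def chain_speculative_sampling_ref(draft_probs, draft_token_ids,
                                   uniform_samples, target_probs):
    # Reference semantics for one request; shapes mirror the batched API
    # with the batch dimension removed.
    n = len(draft_token_ids)      # num_speculate_tokens
    out = [-1] * (n + 1)          # rejected positions stay padded with -1
    for i, tok in enumerate(draft_token_ids):
        p_target = target_probs[i][tok]
        p_draft = draft_probs[i][tok]
        # Accept the draft token with probability min(1, p_target / p_draft).
        if uniform_samples[i] * p_draft < p_target:
            out[i] = tok
            continue
        # Rejected: resample from norm(max(target - draft, 0)) and stop.
        # (Assumption: the next uniform sample is consumed for the resample.)
        residual = [max(t - d, 0.0)
                    for t, d in zip(target_probs[i], draft_probs[i])]
        total = sum(residual)
        residual = [r / total for r in residual]
        out[i] = inverse_cdf_sample(residual, uniform_samples[i + 1])
        return out
    # All draft tokens accepted: sample the "bonus" token from the target
    # model's extra (n+1)-th distribution using the last uniform sample.
    out[n] = inverse_cdf_sample(target_probs[n], uniform_samples[n])
    return out
```

For example, with num_speculate_tokens = 2 and vocab_size = 3, if every draft token passes the acceptance test, the output has all three slots filled (the last being the bonus token); if the first draft token is rejected, the output is the resampled token followed by two -1 pads.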