flashinfer.sampling.chain_speculative_sampling
- flashinfer.sampling.chain_speculative_sampling(draft_probs, draft_token_ids, uniform_samples, target_probs, maybe_output_accepted_token_num: torch.Tensor | None = None, maybe_output_emitted_token_num: torch.Tensor | None = None, deterministic: bool = True) → torch.Tensor
Fused GPU kernel for speculative sampling in sequence generation (proposed in the paper Accelerating Large Language Model Decoding with Speculative Sampling), where the draft model generates a sequence (chain) of tokens for each request.
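The per-position rule that speculative sampling applies can be sketched in plain Python. This is a minimal, hypothetical reference (`verify_one_token` is not part of FlashInfer) assuming the standard rule: accept the drafted token x when u < p(x)/q(x), where p is the target distribution and q the draft distribution; otherwise resample from max(0, p - q), renormalized. It runs on the CPU and only illustrates the math, not the fused kernel.

```python
import random


def verify_one_token(draft_token, q, p, u):
    """Hypothetical single-position check (not a FlashInfer API).

    q: draft-model probabilities over the vocabulary (list of floats)
    p: target-model probabilities over the vocabulary (list of floats)
    u: uniform sample in [0, 1)
    Returns (accepted, token).
    """
    # Accept with probability min(1, p(x) / q(x)); comparing u * q(x) < p(x)
    # is equivalent for q(x) > 0 and avoids a division.
    if u * q[draft_token] < p[draft_token]:
        return True, draft_token
    # Rejected: draw from the residual max(0, p - q), renormalized.
    residual = [max(pi - qi, 0.0) for pi, qi in zip(p, q)]
    # Inverse-CDF draw with a fresh uniform sample (an assumption of this sketch).
    needle = random.random() * sum(residual)
    cumulative = 0.0
    for token, weight in enumerate(residual):
        cumulative += weight
        if needle < cumulative:
            return False, token
    # Round-off guard: fall back to the last token with positive residual mass.
    return False, max(i for i, w in enumerate(residual) if w > 0)


# Second position of the example further down: draft token 1, target puts all mass on token 0.
print(verify_one_token(1, [0.2, 0.3, 0.4, 0.1], [1.0, 0.0, 0.0, 0.0], 0.9150))  # (False, 0)
```

Because the residual here is [0.8, 0, 0, 0], the resampled token is 0 no matter which uniform sample the redraw uses.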
- Parameters:
draft_probs (torch.Tensor) – The probability over the vocabulary generated by the draft model. Shape: (batch_size, num_speculate_tokens, vocab_size)
draft_token_ids (torch.Tensor) – The draft model’s generated token indices. Shape: (batch_size, num_speculate_tokens)
uniform_samples (torch.Tensor) – The uniform samples used as the needle for sampling. Shape: (batch_size, num_speculate_tokens + 1). Expected to be uniformly distributed in [0, 1).
target_probs (torch.Tensor) – The probability over the vocabulary generated by the target model. Compared to the input draft_probs, the target model’s probability has an additional slot at the end because the target model generates one more token than the draft model. Shape: (batch_size, num_speculate_tokens + 1, vocab_size)
maybe_output_accepted_token_num (Optional[torch.Tensor]) – The number of tokens that can be accepted if each token is considered independently for each request. This metric does not account for the fact that rejection sampling stops at the first token that fails the acceptance test r < p/q, where r is the corresponding uniform sample and p and q are the target and draft model probabilities of the drafted token. It only measures the alignment between the draft model and the target model. Shape: (batch_size). If specified, the accepted token count is added to this tensor in place. Default is None.
maybe_output_emitted_token_num (Optional[torch.Tensor]) – The number of tokens that are finally emitted/generated for each request. Shape: (batch_size). If specified, the emitted token count is added to this tensor in place. Default is None.
deterministic (bool) – Whether to use the deterministic kernel implementation. Default is True.
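Describing uniform_samples as a "needle" suggests inverse-CDF sampling: walk the cumulative distribution and pick the first token whose cumulative mass exceeds u. The sketch below assumes that interpretation; `sample_with_needle` is a hypothetical helper, and the kernel's actual procedure may differ in detail (for example in how it handles floating-point round-off).

```python
import bisect
import itertools


def sample_with_needle(probs, u):
    """Hypothetical inverse-CDF sampler: return the first token whose cumulative
    probability exceeds the needle u * total_mass (u in [0, 1))."""
    cdf = list(itertools.accumulate(probs))
    needle = u * cdf[-1]  # scaling by total mass also handles unnormalized weights
    # bisect_right finds the first index i with cdf[i] > needle,
    # which never lands on a zero-probability token.
    token = bisect.bisect_right(cdf, needle)
    # Clamp in case round-off pushes the needle past the last bucket.
    return min(token, len(probs) - 1)


print(sample_with_needle([0.7, 0.1, 0.1, 0.1], 0.3829))  # 0
print(sample_with_needle([0.7, 0.1, 0.1, 0.1], 0.75))    # 1
```

A single uniform sample per position is enough for this scheme, which is consistent with uniform_samples having one entry per draft position plus one for the final draw.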
- Returns:
output_token_ids (torch.Tensor) – The output token indices verified by the target model; rejected positions are padded with -1. Compared to the input draft_token_ids, the output tensor has one additional slot at the end for the final token: if all previous tokens are accepted, another “bonus” token is sampled from the target model’s probability. Shape: (batch_size, num_speculate_tokens + 1)
output_accepted_token_num (torch.Tensor) – The number of tokens that can be accepted if each token is considered independently for each request. This metric does not account for the fact that rejection sampling stops at the first token that fails the acceptance test r < p/q (r being the uniform sample, p and q the target and draft model probabilities of the drafted token). It only measures the alignment between the draft model and the target model. Shape: (batch_size)
output_emitted_token_num (torch.Tensor) – The number of tokens that are finally emitted/generated for each request. Shape: (batch_size)
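To make the three return values concrete, here is a hypothetical plain-Python reference (`chain_speculative_sampling_ref` is not a FlashInfer function) over nested lists with the same shapes as the tensors described above. Assumptions: the acceptance test is r * q(x) < p(x), equivalent to r < p/q when q(x) > 0; the last uniform sample is used for whichever final draw happens (residual resample or bonus token), since at most one occurs per request; and the emitted count is the length of the accepted draft prefix, which matches the documented example. How the kernel actually assigns uniform samples is not specified here.

```python
def chain_speculative_sampling_ref(draft_probs, draft_token_ids, uniform_samples, target_probs):
    """Hypothetical plain-Python reference (not the FlashInfer kernel).

    draft_probs:     [batch][num_speculate_tokens][vocab_size]
    draft_token_ids: [batch][num_speculate_tokens]
    uniform_samples: [batch][num_speculate_tokens + 1]
    target_probs:    [batch][num_speculate_tokens + 1][vocab_size]
    Returns (output_token_ids, accepted_num, emitted_num).
    """

    def sample(weights, u):
        # Inverse-CDF draw over non-negative (possibly unnormalized) weights.
        needle = u * sum(weights)
        cumulative, last_positive = 0.0, 0
        for token, w in enumerate(weights):
            if w > 0:
                last_positive = token
            cumulative += w
            if needle < cumulative:
                return token
        return last_positive  # round-off guard

    output_token_ids, accepted_num, emitted_num = [], [], []
    for q_rows, drafts, us, p_rows in zip(draft_probs, draft_token_ids, uniform_samples, target_probs):
        n = len(drafts)
        passes = [us[i] * q_rows[i][x] < p_rows[i][x] for i, x in enumerate(drafts)]
        out = [-1] * (n + 1)  # rejected / unused positions stay padded with -1
        emitted = 0
        for i, x in enumerate(drafts):
            if passes[i]:
                out[i] = x
                emitted += 1
                continue
            # First rejection: resample from max(0, p - q) and stop the chain.
            residual = [max(pt - qd, 0.0) for pt, qd in zip(p_rows[i], q_rows[i])]
            out[i] = sample(residual, us[n])
            break
        else:
            # Every draft accepted: sample the bonus token from the extra target slot.
            out[n] = sample(p_rows[n], us[n])
        output_token_ids.append(out)
        # Independent count: each draft token is tested even after a rejection.
        accepted_num.append(sum(passes))
        emitted_num.append(emitted)
    return output_token_ids, accepted_num, emitted_num
```

Feeding it the inputs from the example gives ([[2, 0, -1]], [1], [1]), the same values the example reports for the GPU kernel.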
Examples
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 1
>>> num_speculate_tokens = 2
>>> vocab_size = 4
>>> draft_probs = torch.tensor([[[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.1]]]).to(0)
>>> # token 2 was sampled from the draft model for the first token, and
>>> # token 1 was sampled from the draft model for the second token
>>> draft_token_ids = torch.tensor([[2, 1]], dtype=torch.int32).to(0)
>>> # uniform samples for rejection sampling
>>> uniform_samples = torch.rand(batch_size, num_speculate_tokens + 1).to(0)
>>> uniform_samples
tensor([[0.8823, 0.9150, 0.3829]], device='cuda:0')
>>> target_probs = torch.tensor([[[0.0, 0.1, 0.6, 0.3], [1.0, 0.0, 0.0, 0.0], [0.7, 0.1, 0.1, 0.1]]]).to(0)
>>> output_token_ids, output_accepted_token_num, output_emitted_token_num =\
...     flashinfer.sampling.chain_speculative_sampling(
...         draft_probs, draft_token_ids, uniform_samples, target_probs)
>>> # the first token is accepted; the second token is rejected and resampled from the
>>> # difference between the target and draft distributions; the third slot is padded with -1
>>> output_token_ids
tensor([[ 2,  0, -1]], device='cuda:0', dtype=torch.int32)
>>> output_accepted_token_num
tensor([1], device='cuda:0')
>>> output_emitted_token_num
tensor([1], device='cuda:0')
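Because the maybe_output_accepted_token_num and maybe_output_emitted_token_num counters are added to in place, a caller can pass the same zero-initialized tensors of shape (batch_size) to every verification step and summarize them afterwards. The sketch below mimics that accumulation with plain lists; `accumulate_` and `speculative_metrics` are hypothetical helpers, and it assumes, as in the example above, that the emitted count excludes the one resampled or bonus token each step contributes.

```python
def accumulate_(counter, step_counts):
    """Hypothetical helper: add per-request counts into `counter` in place,
    mirroring the documented in-place add on the output tensors."""
    for i, c in enumerate(step_counts):
        counter[i] += c


def speculative_metrics(accepted_total, emitted_total, num_steps, num_speculate_tokens):
    """Summarize counters accumulated over `num_steps` verification calls for one request."""
    proposed = num_steps * num_speculate_tokens
    return {
        # Fraction of draft tokens the target model would accept in isolation.
        "acceptance_rate": accepted_total / proposed,
        # Assumption: each step yields one resampled or bonus token on top of the emitted drafts.
        "tokens_per_step": (emitted_total + num_steps) / num_steps,
    }


# Two verification steps for a batch of one request, num_speculate_tokens = 2.
accepted = [0]
emitted = [0]
for step_accepted, step_emitted in [([1], [1]), ([2], [2])]:
    accumulate_(accepted, step_accepted)
    accumulate_(emitted, step_emitted)

metrics = speculative_metrics(accepted[0], emitted[0], num_steps=2, num_speculate_tokens=2)
print(metrics)  # {'acceptance_rate': 0.75, 'tokens_per_step': 2.5}
```

Keeping the two counters separate is what lets the acceptance rate (draft/target alignment) be distinguished from the realized speedup (tokens per step).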