flashinfer.prefill.trtllm_ragged_attention_deepseek
- flashinfer.prefill.trtllm_ragged_attention_deepseek(query: Tensor, key: Tensor, value: Tensor, workspace_buffer: Tensor, seq_lens: Tensor, max_q_len: int, max_kv_len: int, bmm1_scale: float | Tensor, bmm2_scale: float | Tensor, o_sf_scale: float, batch_size: int, window_left: int, cum_seq_lens_q: Tensor, cum_seq_lens_kv: Tensor, enable_pdl: bool | None, is_causal: bool, return_lse: bool, attention_sinks: Tensor | None = None, skip_softmax_threshold_scale_factor: float | None = None, out: Tensor | None = None, lse: Tensor | None = None, sage_attn_sfs: Tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None] = (None, None, None, None), num_elts_per_sage_attn_blk: Tuple[int, int, int, int] = (0, 0, 0, 0), backend: str = 'trtllm-gen') → Tensor | Tuple[Tensor, Tensor]
- Parameters:
query (torch.Tensor) – query tensor with shape [num_tokens, num_heads, head_dim]
key (torch.Tensor) – key tensor with shape [num_tokens, num_heads, head_dim]
value (torch.Tensor) – value tensor with shape [num_tokens, num_heads, head_dim]
workspace_buffer (torch.Tensor) – workspace buffer
seq_lens (torch.Tensor) – sequence lengths
max_q_len (int) – max query length
max_kv_len (int) – max key/value length
bmm1_scale (Union[float, torch.Tensor]) – scale for bmm1, computed as scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5). When using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
bmm2_scale (Union[float, torch.Tensor]) – scale for bmm2, equal to scale_v. When using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
o_sf_scale (float) – scale for output
batch_size (int) – batch size
window_left (int) – left window size for sliding-window attention
cum_seq_lens_q (torch.Tensor) – cumulative sequence lengths for query
cum_seq_lens_kv (torch.Tensor) – cumulative sequence lengths for key/value
enable_pdl (Optional[bool]) – whether to enable Programmatic Dependent Launch (PDL)
is_causal (bool) – whether to apply a causal mask
return_lse (bool) – Whether to allocate and return the log-sum-exp tensor in addition to the attention output. When True, the function returns (out, lse); when False, only out is returned.
attention_sinks (Optional[torch.Tensor]) – attention sinks
skip_softmax_threshold_scale_factor (Optional[float]) – threshold scale factor for skipping softmax operations. Providing a value enables skip-softmax sparsity as described in https://arxiv.org/abs/2512.12087. If no value is provided, standard attention is used. A higher threshold generally increases kernel performance at the cost of accuracy. The actual threshold equals skip_softmax_threshold_scale_factor divided by the context length.
out (Optional[torch.Tensor]) – output tensor. If not provided, it is allocated with shape [query.shape[0], query.shape[1], value.shape[2]].
lse (Optional[torch.Tensor]) – lse tensor. If not provided, it is allocated with shape [query.shape[0], query.shape[1]].
sage_attn_sfs (Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]) – SageAttention scale-factor tensors for the four sub-blocks (q_sf, k_sf, v_block_sum, p_block_sum). Defaults to (None, None, None, None), which disables SageAttention. Currently only consulted by the trtllm-gen backend.
num_elts_per_sage_attn_blk (Tuple[int, int, int, int]) – Per-block element counts for the SageAttention scale-factor tensors, matching the order of sage_attn_sfs. Defaults to (0, 0, 0, 0), which disables SageAttention. Only consulted when sage_attn_sfs contains non-None tensors.
backend (str) – Attention backend to use. “trtllm-gen” (default) or “cute-dsl”.
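The scalar parameters above follow simple formulas, which the sketch below spells out in plain Python. The helper names `attention_scales` and `effective_skip_softmax_threshold` are hypothetical and are not part of FlashInfer. The default quantization scales of 1.0 and the example numbers are illustrative assumptions only.

```python
def attention_scales(head_dim_qk, scale_q=1.0, scale_k=1.0, scale_v=1.0):
    """Return (bmm1_scale, bmm2_scale) using the formulas from the parameter list.

    Hypothetical helper; scale_q, scale_k and scale_v default to 1.0 here,
    which assumes unquantized inputs.
    """
    if head_dim_qk <= 0:
        raise ValueError("head_dim_qk must be positive")
    bmm1_scale = scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5)
    bmm2_scale = scale_v
    return bmm1_scale, bmm2_scale


def effective_skip_softmax_threshold(threshold_scale_factor, context_len):
    """Return the scale factor divided by the context length.

    Hypothetical helper mirroring the description of
    skip_softmax_threshold_scale_factor; the kernel does this internally.
    """
    if context_len <= 0:
        raise ValueError("context_len must be positive")
    return threshold_scale_factor / context_len


# Illustrative values only (not taken from the documentation).
bmm1_scale, bmm2_scale = attention_scales(head_dim_qk=192)
threshold = effective_skip_softmax_threshold(500.0, context_len=4000)
print(bmm1_scale, bmm2_scale, threshold)
```

Because the effective threshold is divided by the context length, the same scale factor yields a smaller threshold for longer contexts.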
- Returns:
out – a torch.Tensor or a Tuple[torch.Tensor, torch.Tensor]. If return_lse is True, the result is a tuple of two tensors: the first is the output tensor and the second is the lse tensor. If return_lse is False, the result is a single output tensor.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
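As a rough illustration of the ragged layout, the following sketch derives batch_size, max_q_len, max_kv_len and the cumulative length arrays from per-request lengths, and computes the default shapes of out and lse. The helpers `ragged_metadata` and `output_shapes` are hypothetical. The leading zero in the cumulative arrays (giving batch_size + 1 entries) is an assumption based on the usual variable-length attention convention, not something this page states. Plain lists stand in for the tensors the function expects.

```python
from itertools import accumulate


def ragged_metadata(q_lens, kv_lens):
    """Build length metadata for a ragged batch (hypothetical helper).

    Assumes cumulative arrays start at 0, so request i occupies rows
    cum[i]:cum[i + 1] of the packed token dimension.
    """
    if len(q_lens) != len(kv_lens) or not q_lens:
        raise ValueError("q_lens and kv_lens must be non-empty and equally long")
    return {
        "batch_size": len(q_lens),
        "max_q_len": max(q_lens),
        "max_kv_len": max(kv_lens),
        "cum_seq_lens_q": [0, *accumulate(q_lens)],
        "cum_seq_lens_kv": [0, *accumulate(kv_lens)],
    }


def output_shapes(query_shape, value_shape):
    """Default shapes of out and lse as described in the parameter list."""
    out_shape = (query_shape[0], query_shape[1], value_shape[2])
    lse_shape = (query_shape[0], query_shape[1])
    return out_shape, lse_shape


# Three requests with illustrative lengths.
meta = ragged_metadata(q_lens=[3, 5, 2], kv_lens=[7, 5, 9])
num_q_tokens = meta["cum_seq_lens_q"][-1]
num_kv_tokens = meta["cum_seq_lens_kv"][-1]

# Head counts and head dims below are illustrative assumptions.
out_shape, lse_shape = output_shapes(
    query_shape=(num_q_tokens, 128, 192),
    value_shape=(num_kv_tokens, 128, 128),
)
print(meta, out_shape, lse_shape)
```

In a real call these lists would be converted to device tensors before being passed as cum_seq_lens_q and cum_seq_lens_kv; the required dtype is not stated on this page.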