flashinfer.prefill.trtllm_ragged_attention_deepseek

flashinfer.prefill.trtllm_ragged_attention_deepseek(query: Tensor, key: Tensor, value: Tensor, workspace_buffer: Tensor, seq_lens: Tensor, max_q_len: int, max_kv_len: int, bmm1_scale: float | Tensor, bmm2_scale: float | Tensor, o_sf_scale: float, batch_size: int, window_left: int, cum_seq_lens_q: Tensor, cum_seq_lens_kv: Tensor, enable_pdl: bool | None, is_causal: bool, return_lse: bool, attention_sinks: Tensor | None = None, skip_softmax_threshold_scale_factor: float | None = None, out: Tensor | None = None, lse: Tensor | None = None, sage_attn_sfs: Tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None] = (None, None, None, None), num_elts_per_sage_attn_blk: Tuple[int, int, int, int] = (0, 0, 0, 0), backend: str = 'trtllm-gen') → Tensor | Tuple[Tensor, Tensor]
Parameters:
  • query (torch.Tensor) – query tensor with shape [num_tokens, num_heads, head_dim]

  • key (torch.Tensor) – key tensor with shape [num_tokens, num_heads, head_dim]

  • value (torch.Tensor) – value tensor with shape [num_tokens, num_heads, head_dim]

  • workspace_buffer (torch.Tensor) – pre-allocated workspace buffer used by the attention kernel

  • seq_lens (torch.Tensor) – per-request sequence lengths

  • max_q_len (int) – maximum query length in the batch

  • max_kv_len (int) – maximum key/value length in the batch

  • bmm1_scale (Union[float, torch.Tensor]) – scale for the first batched matmul (Q·Kᵀ), i.e. scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5). When using the trtllm-gen backend, it can also be a torch.Tensor with dtype torch.float32.

  • bmm2_scale (Union[float, torch.Tensor]) – scale for the second batched matmul (P·V), i.e. scale_v. When using the trtllm-gen backend, it can also be a torch.Tensor with dtype torch.float32.

  • o_sf_scale (float) – scale factor for the output

  • batch_size (int) – batch size

  • window_left (int) – left window size for sliding-window attention

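  • cum_seq_lens_q (torch.Tensor) – cumulative query sequence lengths (prefix sums of the per-request query lengths)

  • cum_seq_lens_kv (torch.Tensor) – cumulative key/value sequence lengths (prefix sums of the per-request key/value lengths)

  • enable_pdl (Optional[bool]) – whether to enable Programmatic Dependent Launch (PDL)

  • is_causal (bool) – whether to apply a causal mask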

  • return_lse (bool) – Whether to allocate and return the log-sum-exp tensor in addition to the attention output. When True the function returns (out, lse); when False only out is returned.

  • attention_sinks (Optional[torch.Tensor]) – attention sinks

  • skip_softmax_threshold_scale_factor (Optional[float]) – threshold scale factor for skipping softmax operations. Providing a value enables skip-softmax sparsity as described in https://arxiv.org/abs/2512.12087; if no value is provided, standard attention is used. A higher value generally increases kernel performance at the cost of accuracy. The effective threshold equals skip_softmax_threshold_scale_factor divided by the context length.

  • out (Optional[torch.Tensor]) – output tensor. If not provided, it is allocated with shape [query.shape[0], query.shape[1], value.shape[2]].

  • lse (Optional[torch.Tensor]) – log-sum-exp (LSE) tensor. If not provided, it is allocated with shape [query.shape[0], query.shape[1]].

  • sage_attn_sfs (Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]) – SageAttention scale-factor tensors for the four sub-blocks (q_sf, k_sf, v_block_sum, p_block_sum). Defaults to (None, None, None, None) which disables SageAttention. Currently only consulted by the trtllm-gen backend.

  • num_elts_per_sage_attn_blk (Tuple[int, int, int, int]) – Per-block element counts for the SageAttention scale-factor tensors, matching the order of sage_attn_sfs. Defaults to (0, 0, 0, 0), which disables SageAttention. Only consulted when sage_attn_sfs contains non-None tensors.

  • backend (str) – Attention backend to use. “trtllm-gen” (default) or “cute-dsl”.
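The host-side arithmetic behind several of the parameters above (cumulative lengths, maximum lengths, bmm1_scale and bmm2_scale, the effective skip-softmax threshold, and the default out/lse shapes) can be sketched as follows. This is a minimal sketch under stated assumptions, not FlashInfer code: the helpers prepare_ragged_args and expected_output_shapes are hypothetical; cum_seq_lens_q and cum_seq_lens_kv are assumed to be exclusive prefix sums with a leading 0 (the usual ragged/varlen convention, worth confirming against your FlashInfer version); seq_lens is assumed to hold the per-request key/value lengths; and "context length" in the skip-softmax description is taken to mean each request's key/value length. Plain Python lists stand in for the torch.Tensor arguments; a real call would pass GPU tensors. The head dimensions in the usage lines (192 for Q/K, 128 for V) are illustrative only.

```python
import math
from itertools import accumulate


def prepare_ragged_args(q_lens, kv_lens, head_dim_qk, scale_q=1.0, scale_k=1.0,
                        scale_v=1.0, skip_softmax_threshold_scale_factor=None):
    """Hypothetical helper computing host-side values for the ragged call.

    Assumptions: cumulative lengths are exclusive prefix sums starting at 0,
    and seq_lens holds the per-request key/value lengths.
    """
    if not q_lens or len(q_lens) != len(kv_lens):
        raise ValueError("need one query length and one KV length per request")
    args = {
        "batch_size": len(q_lens),
        "seq_lens": list(kv_lens),
        # [0, l0, l0 + l1, ...]: batch_size + 1 entries (assumed convention)
        "cum_seq_lens_q": [0] + list(accumulate(q_lens)),
        "cum_seq_lens_kv": [0] + list(accumulate(kv_lens)),
        "max_q_len": max(q_lens),
        "max_kv_len": max(kv_lens),
        # bmm1 (Q @ K^T): dequant scales times 1 / sqrt(head_dim_qk)
        "bmm1_scale": scale_q * scale_k * 1.0 / (head_dim_qk ** 0.5),
        # bmm2 (P @ V): the V dequant scale
        "bmm2_scale": scale_v,
    }
    if skip_softmax_threshold_scale_factor is not None:
        # Illustration only: the kernel takes the scale factor itself; this
        # shows the effective per-request threshold (factor / context length).
        args["skip_softmax_thresholds"] = [
            skip_softmax_threshold_scale_factor / n for n in kv_lens
        ]
    return args


def expected_output_shapes(query_shape, value_shape):
    """Shapes used when out and lse are allocated by the function."""
    out_shape = (query_shape[0], query_shape[1], value_shape[2])
    lse_shape = (query_shape[0], query_shape[1])
    return out_shape, lse_shape


args = prepare_ragged_args(q_lens=[3, 5], kv_lens=[3, 5], head_dim_qk=192)
print(args["cum_seq_lens_q"])  # [0, 3, 8]
print(expected_output_shapes((8, 16, 192), (8, 16, 128)))  # ((8, 16, 128), (8, 16))
```

Note that out takes its last dimension from value, not query, so the output head dimension can differ from the Q/K head dimension.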

Returns:

out – the attention output as a torch.Tensor, or a Tuple[torch.Tensor, torch.Tensor]. If return_lse is True, the function returns a tuple (out, lse), where out is the attention output and lse is the log-sum-exp tensor. If return_lse is False, it returns only the output tensor.

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
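Because the return type depends on return_lse, code that toggles the flag may want a single calling convention. The sketch below is a hypothetical helper, unpack_attention_result, not part of FlashInfer; it only inspects the flag and the returned object, so it runs without a GPU.

```python
from typing import Any, Optional, Tuple


def unpack_attention_result(result: Any, return_lse: bool) -> Tuple[Any, Optional[Any]]:
    """Hypothetical helper: always yield (out, lse), with lse=None if not requested."""
    if return_lse:
        # return_lse=True: the function returns the tuple (out, lse).
        out, lse = result
        return out, lse
    # return_lse=False: the function returns the output tensor alone.
    return result, None
```

A caller would then write out, lse = unpack_attention_result(result, return_lse) for either setting of the flag, passing the same return_lse value that was given to trtllm_ragged_attention_deepseek.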