flashinfer.prefill.trtllm_fmha_v2_prefill

flashinfer.prefill.trtllm_fmha_v2_prefill(qkv: Tensor | Tuple[Tensor, Tensor] | Tuple[Tensor, Tensor, Tensor], input_layout: str, workspace_buffer: Tensor, seq_lens: Tensor, max_q_len: int, max_kv_len: int, bmm1_scale: float, bmm2_scale: float, batch_size: int, cum_seq_lens_q: Tensor, cum_seq_lens_kv: Tensor, block_tables: Tensor | None = None, out: Tensor | None = None, out_dtype: str | dtype | None = None, sinks: List[Tensor] | None = None, pos_encoding_mode: str = None, logits_soft_cap_scale: float | None = None, mask_mode: str = 'causal', window_left: int = -1, chunked_attention_size: int = 0, save_softmax_stats: bool = False, skip_softmax_threshold_scale_factor: float = 0) → Tensor | Tuple[Tensor, Tensor]

TRT-LLM FMHAv2 prefill attention.

Parameters:
  • qkv – Query/key/value input; expected format is determined by input_layout.

  • input_layout

    Specifies the layout of the query/key/value tensors:

    • PACKED_QKV: qkv is a single tensor of shape [num_tokens, 3, num_heads, head_dim].

    • CONTIGUOUS_Q_KV: qkv is (Q, KV) where Q has shape [num_tokens, num_heads, head_dim] and KV has shape [num_tokens, 2, num_kv_heads, head_dim] (KV[:, 0, ...] is key, KV[:, 1, ...] is value).

    • Q_PAGED_KV_HND: qkv is (Q, paged_KV) where Q has shape [num_tokens, num_heads, head_dim] and paged_KV has shape [num_pages, 2, num_kv_heads, page_size, head_dim] (paged_KV[:, 0, ...] is key cache, paged_KV[:, 1, ...] is value cache).

    • Q_PAGED_KV_NHD: same as Q_PAGED_KV_HND but paged_KV shape is [num_pages, 2, page_size, num_kv_heads, head_dim].

    • SEPARATE_Q_K_V: qkv is (Q, K, V) where Q has shape [num_tokens, num_heads, head_dim] and K, V have shape [num_tokens, num_kv_heads, head_dim].

  • workspace_buffer – The workspace buffer. Must be initialized to 0 for its first use.

  • seq_lens – The KV sequence length of each request, shape: [batch_size].

  • max_q_len – The maximum sequence length for query.

  • max_kv_len – The maximum sequence length for KV cache.

  • bmm1_scale – The fused scale for BMM1 (QK^T) computation.

  • bmm2_scale – The fused scale for BMM2 (softmax(QK^T) * V) computation.

  • batch_size – The batch size.

  • cum_seq_lens_q – The cumulative sequence lengths for query, shape: [batch_size + 1].

  • cum_seq_lens_kv – The cumulative sequence lengths for KV cache, shape: [batch_size + 1].

  • block_tables – The page table for KV cache, shape: [batch_size, max_num_pages_per_seq]. Required when using paged KV cache format.

  • out – The output tensor. If not provided, it is allocated with out_dtype; if out_dtype is also not provided, the dtype of the query is used.

  • out_dtype – The output dtype. If not provided, the dtype of out (or of the query) is used.

  • sinks – Additional value per head in the denominator of the softmax.

  • pos_encoding_mode – The position encoding mode; may be alibi. Defaults to None.

  • logits_soft_cap_scale – The logits soft cap scale. Defaults to None, which means no soft cap.

  • mask_mode – The mask mode: one of causal, sliding_window, or chunked. Defaults to causal.

  • window_left – The left (inclusive) window size for the attention window. When set to -1, the window spans the full length of the sequence. Defaults to -1. Only effective when mask_mode is sliding_window.

  • chunked_attention_size – The chunked attention size. Defaults to 0, which means no chunked attention. Only effective when mask_mode is chunked. Must be a power of 2.

  • save_softmax_stats – Whether to save the softmax statistics. Defaults to False.

  • skip_softmax_threshold_scale_factor – The scale factor for skip-softmax (sparse attention). Softmax and BMM2 are skipped when exp(local_max - global_max) < threshold, where threshold = skip_softmax_threshold_scale_factor / seqlen. Defaults to 0 (disabled).
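
The skip criterion stated for skip_softmax_threshold_scale_factor can be illustrated with a small pure-Python sketch. The helper names below (skip_softmax_threshold, should_skip_block) are hypothetical and not part of flashinfer; the text does not say which sequence length seqlen refers to, so the sketch simply takes it as an argument. It only mirrors the documented inequality and says nothing about how the kernel evaluates it internally.

```python
import math


def skip_softmax_threshold(scale_factor: float, seqlen: int) -> float:
    """threshold = skip_softmax_threshold_scale_factor / seqlen."""
    return scale_factor / seqlen


def should_skip_block(
    local_max: float, global_max: float, scale_factor: float, seqlen: int
) -> bool:
    """Return True when exp(local_max - global_max) < threshold.

    Hypothetical helper: local_max is the maximum logit of one block,
    global_max the running maximum seen so far.
    """
    if scale_factor == 0:
        # Documented default: a factor of 0 disables skipping.
        return False
    threshold = skip_softmax_threshold(scale_factor, seqlen)
    return math.exp(local_max - global_max) < threshold
```

A larger factor or a shorter sequence raises the threshold, so more blocks satisfy the skip condition; a block whose maximum equals the global maximum is skipped only if the threshold exceeds 1.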

Returns:

  • If save_softmax_stats is False, the attention output tensor.

  • If save_softmax_stats is True, a tuple of two tensors:

    • The attention output tensor.

    • The softmax statistics tensor (LSE).
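
The length-related arguments (batch_size, seq_lens, max_q_len, max_kv_len, cum_seq_lens_q, cum_seq_lens_kv) can all be derived from per-request query and KV lengths. The following is a minimal sketch in plain Python; build_prefill_metadata is a hypothetical helper, not a flashinfer function. It assumes the cumulative arrays start with a leading 0, which is consistent with the documented shape [batch_size + 1], and that a paged KV cache needs ceil(kv_len / page_size) pages per request.

```python
from itertools import accumulate


def build_prefill_metadata(q_lens, kv_lens, page_size=None):
    """Derive length metadata from per-request query and KV lengths.

    Hypothetical helper. Returns plain Python values; the caller would
    still convert the list-valued entries to tensors on the target device.
    """
    if not q_lens or len(q_lens) != len(kv_lens):
        raise ValueError("q_lens and kv_lens must be non-empty and equal in length")

    meta = {
        "batch_size": len(q_lens),
        "seq_lens": list(kv_lens),  # KV length of each request
        "max_q_len": max(q_lens),
        "max_kv_len": max(kv_lens),
        # Assumed convention: prefix sums with a leading 0, length batch_size + 1.
        "cum_seq_lens_q": [0, *accumulate(q_lens)],
        "cum_seq_lens_kv": [0, *accumulate(kv_lens)],
    }

    if page_size is not None:
        # Ceiling division: pages needed to hold each request's KV entries.
        pages_per_seq = [-(-n // page_size) for n in kv_lens]
        meta["max_num_pages_per_seq"] = max(pages_per_seq)

    return meta
```

With q_lens = [3, 5, 2], kv_lens = [7, 5, 10] and page_size = 4, the sketch yields cum_seq_lens_q = [0, 3, 8, 10], cum_seq_lens_kv = [0, 7, 12, 22] and max_num_pages_per_seq = 3, which matches the second dimension documented for block_tables.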