flashinfer.prefill.trtllm_fmha_v2_prefill
- flashinfer.prefill.trtllm_fmha_v2_prefill(qkv: Tensor | Tuple[Tensor, Tensor] | Tuple[Tensor, Tensor, Tensor], input_layout: str, workspace_buffer: Tensor, seq_lens: Tensor, max_q_len: int, max_kv_len: int, bmm1_scale: float, bmm2_scale: float, batch_size: int, cum_seq_lens_q: Tensor, cum_seq_lens_kv: Tensor, block_tables: Tensor | None = None, out: Tensor | None = None, out_dtype: str | dtype | None = None, sinks: List[Tensor] | None = None, pos_encoding_mode: str = None, logits_soft_cap_scale: float | None = None, mask_mode: str = 'causal', window_left: int = -1, chunked_attention_size: int = 0, save_softmax_stats: bool = False, skip_softmax_threshold_scale_factor: float = 0) → Tensor | Tuple[Tensor, Tensor]
TRT-LLM FMHAv2 prefill attention.
- Parameters:
qkv – Query/key/value input; the expected format is determined by input_layout.
input_layout – Specifies the layout of the query/key/value tensors:
- PACKED_QKV: qkv is a single tensor of shape [num_tokens, 3, num_heads, head_dim].
- CONTIGUOUS_Q_KV: qkv is (Q, KV) where Q has shape [num_tokens, num_heads, head_dim] and KV has shape [num_tokens, 2, num_kv_heads, head_dim] (KV[:, 0, ...] is key, KV[:, 1, ...] is value).
- Q_PAGED_KV_HND: qkv is (Q, paged_KV) where Q has shape [num_tokens, num_heads, head_dim] and paged_KV has shape [num_pages, 2, num_kv_heads, page_size, head_dim] (paged_KV[:, 0, ...] is key cache, paged_KV[:, 1, ...] is value cache).
- Q_PAGED_KV_NHD: same as Q_PAGED_KV_HND but paged_KV shape is [num_pages, 2, page_size, num_kv_heads, head_dim].
- SEPARATE_Q_K_V: qkv is (Q, K, V) where Q has shape [num_tokens, num_heads, head_dim] and K, V have shape [num_tokens, num_kv_heads, head_dim].
workspace_buffer – The workspace buffer. Must be initialized to 0 for its first use.
seq_lens – The KV sequence length of each request, shape: [batch_size].
max_q_len – The maximum sequence length for query.
max_kv_len – The maximum sequence length for KV cache.
bmm1_scale – The fused scale for BMM1 (QK^T) computation.
bmm2_scale – The fused scale for BMM2 (softmax(QK^T) * V) computation.
batch_size – The batch size.
cum_seq_lens_q – The cumulative sequence lengths for query, shape: [batch_size + 1].
cum_seq_lens_kv – The cumulative sequence lengths for KV cache, shape: [batch_size + 1].
block_tables – The page table for KV cache, shape: [batch_size, max_num_pages_per_seq]. Required when using paged KV cache format.
out – The output tensor. If not provided, it is allocated with out_dtype. If out_dtype is also not provided, the dtype of query is used.
out_dtype – The output dtype. If not provided, the dtype of out or query is used.
sinks – Additional value per head in the denominator of the softmax.
pos_encoding_mode – The position encoding mode, could be alibi. Defaults to None.
logits_soft_cap_scale – The logits soft cap scale. Defaults to None, which means no soft cap.
mask_mode – The mask mode, one of causal, sliding_window, or chunked. Defaults to causal.
window_left – The left (inclusive) window size for the attention window. When set to -1, the window spans the full length of the sequence. Defaults to -1. Only effective when mask_mode is sliding_window.
chunked_attention_size – The chunked attention size. Defaults to 0, which means no chunked attention. Only effective when mask_mode is chunked. Must be a power of 2.
save_softmax_stats – Whether to save the softmax statistics. Defaults to False.
skip_softmax_threshold_scale_factor – The scale factor for skip-softmax (sparse attention). Softmax and BMM2 are skipped when exp(local_max - global_max) < threshold, where threshold = skip_softmax_threshold_scale_factor / seqlen. Defaults to 0 (disabled).
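The length-related arguments (batch_size, max_q_len, max_kv_len, cum_seq_lens_q, cum_seq_lens_kv) and the second dimension of block_tables can all be derived from the per-request lengths. The following is a minimal pure-Python sketch of that bookkeeping, not part of flashinfer: build_prefill_metadata, PrefillMetadata, and is_power_of_two are hypothetical helpers. It assumes the cumulative arrays are prefix sums starting at 0 (consistent with the documented [batch_size + 1] shape) and that max_num_pages_per_seq is at least ceil(max_kv_len / page_size). In real use the lists would be converted to tensors on the target device.

```python
from itertools import accumulate
from typing import List, NamedTuple


class PrefillMetadata(NamedTuple):
    """Hypothetical container for illustration; not a flashinfer type."""
    batch_size: int
    max_q_len: int
    max_kv_len: int
    cum_seq_lens_q: List[int]
    cum_seq_lens_kv: List[int]
    max_num_pages_per_seq: int


def build_prefill_metadata(
    q_lens: List[int], kv_lens: List[int], page_size: int
) -> PrefillMetadata:
    """Derive the length arguments from per-request query and KV lengths."""
    if not q_lens or len(q_lens) != len(kv_lens):
        raise ValueError("q_lens and kv_lens must be non-empty and equally long")
    if page_size <= 0:
        raise ValueError("page_size must be positive")

    # Assumption: cumulative lengths are prefix sums with a leading 0,
    # giving batch_size + 1 entries.
    cum_q = [0] + list(accumulate(q_lens))
    cum_kv = [0] + list(accumulate(kv_lens))

    max_kv_len = max(kv_lens)
    # Assumption: enough page-table columns to cover the longest KV
    # sequence, i.e. ceil(max_kv_len / page_size).
    max_pages = -(-max_kv_len // page_size)

    return PrefillMetadata(
        batch_size=len(q_lens),
        max_q_len=max(q_lens),
        max_kv_len=max_kv_len,
        cum_seq_lens_q=cum_q,
        cum_seq_lens_kv=cum_kv,
        max_num_pages_per_seq=max_pages,
    )


def is_power_of_two(n: int) -> bool:
    """Check the documented constraint on chunked_attention_size."""
    return n > 0 and (n & (n - 1)) == 0


meta = build_prefill_metadata(q_lens=[3, 5], kv_lens=[7, 5], page_size=4)
print(meta.cum_seq_lens_q, meta.cum_seq_lens_kv, meta.max_num_pages_per_seq)
```

For two requests with query lengths 3 and 5, the cumulative query array has three entries and its last entry equals the total number of query tokens, which is the num_tokens dimension of Q in the non-packed layouts.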
- Returns:
If save_softmax_stats is False, the attention output tensor.
If save_softmax_stats is True, a tuple of two tensors:
* The attention output tensor.
* The softmax statistics tensor (LSE).
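The skip-softmax rule given for skip_softmax_threshold_scale_factor can be written out as a small predicate. This is only an illustrative sketch of the documented inequality, not the kernel's implementation: should_skip_block is a hypothetical name, the granularity at which local_max is computed is not specified here, and which sequence length is meant by seqlen is an assumption left to the caller.

```python
import math


def should_skip_block(
    local_max: float,
    global_max: float,
    scale_factor: float,
    seqlen: int,
) -> bool:
    """Return True when softmax and BMM2 would be skipped under the documented rule.

    Hypothetical helper: local_max is the maximum logit of the block under
    consideration and global_max is the maximum seen so far for the row.
    """
    if scale_factor == 0:
        # Documented default: 0 disables skip-softmax.
        return False
    threshold = scale_factor / seqlen
    return math.exp(local_max - global_max) < threshold


# A block whose largest logit is far below the running maximum contributes
# little to the softmax and is skipped; a block close to the maximum is kept.
print(should_skip_block(-10.0, 0.0, scale_factor=1.0, seqlen=1024))
print(should_skip_block(-1.0, 0.0, scale_factor=1.0, seqlen=1024))
```

Because the threshold is divided by seqlen, a fixed scale factor yields a smaller threshold for longer sequences, so a block must fall further below the running maximum before it is skipped.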