flashinfer.gdn_prefill.chunk_gated_delta_rule
- flashinfer.gdn_prefill.chunk_gated_delta_rule(q: Tensor, k: Tensor, v: Tensor, g: Tensor | None = None, beta: Tensor | None = None, scale: float | None = None, initial_state: Tensor | None = None, output_final_state: bool = False, cu_seqlens: Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, output: Tensor | None = None, output_state: Tensor | None = None, state_checkpoints: Tensor | None = None, checkpoint_cu_starts: Tensor | None = None, checkpoint_every_n_tokens: int = 0) → Tensor | Tuple[Tensor, Tensor]
Chunked Gated Delta Rule (GDN) attention for prefill.
Implements the gated delta rule linear attention mechanism for efficient training and inference. Supports both GQA (grouped query attention) and GVA (grouped value attention) configurations.
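For orientation, the recurrence behind the gated delta rule can be sketched token by token. The sketch below follows the formulation commonly given for Gated DeltaNet (decay the state by alpha, then apply a beta-weighted delta update toward the value); it assumes `g` acts as a multiplicative alpha, handles a single sequence and a single head with plain Python lists, and skips chunking and L2 normalization. The function name is hypothetical, and this is not the kernel's implementation.

```python
from math import sqrt


def gated_delta_rule_reference(q, k, v, alpha, beta, scale=None):
    """Naive per-token recurrence for one sequence and one head.

    q, k: lists of length-K vectors; v: list of length-V vectors;
    alpha, beta: lists of per-token scalars. Illustrative only.
    """
    d_k, d_v = len(k[0]), len(v[0])
    if scale is None:
        scale = 1.0 / sqrt(d_k)  # same default as the documented `scale`
    # State S stored as [V, K], matching the documented final state layout.
    state = [[0.0] * d_k for _ in range(d_v)]
    outputs = []
    for q_t, k_t, v_t, a_t, b_t in zip(q, k, v, alpha, beta):
        # 1) Forget gate: decay the whole state by alpha_t.
        state = [[a_t * s for s in row] for row in state]
        # 2) What the decayed state currently predicts for key k_t.
        pred = [sum(s * kk for s, kk in zip(row, k_t)) for row in state]
        # 3) Delta update: move the prediction toward v_t, weighted by beta_t.
        for i in range(d_v):
            err = b_t * (v_t[i] - pred[i])
            for j in range(d_k):
                state[i][j] += err * k_t[j]
        # 4) Read out with the query.
        outputs.append(
            [scale * sum(s * qq for s, qq in zip(row, q_t)) for row in state]
        )
    return outputs, state
```

With a unit-norm key and `beta = 1`, each step overwrites the value stored under that key, which is the behavior the delta rule is designed to give.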
- Parameters:
  - q (torch.Tensor) – Queries of shape [total_seq_len, num_q_heads, head_size]. Must be contiguous and on CUDA.
  - k (torch.Tensor) – Keys of shape [total_seq_len, num_k_heads, head_size]. Must be contiguous and on CUDA.
  - v (torch.Tensor) – Values of shape [total_seq_len, num_v_heads, head_size]. Must be contiguous and on CUDA.
  - g (torch.Tensor, optional) – Forget gate (alpha) of shape [total_seq_len, num_sab_heads] where num_sab_heads = max(num_q_heads, num_v_heads). Must be float32. Defaults to all ones when None.
  - beta (torch.Tensor, optional) – Update gate (beta) of shape [total_seq_len, num_sab_heads]. Must be float32. Defaults to all ones when None.
  - scale (float, optional) – Scale factor for the attention scores. Defaults to 1 / sqrt(head_size) when None.
  - initial_state (torch.Tensor, optional) – Initial KV state of shape [num_seqs, num_sab_heads, head_size, head_size]. Must be float32. Starts from zero state when None.
  - output_final_state (bool) – Whether to output the final state. Default: False.
  - cu_seqlens (torch.Tensor) – Cumulative sequence lengths of shape [num_seqs + 1], integer dtype on the same CUDA device as q. Required for variable-length sequences (varlen mode); must not be None (asserted at the top of the function body). Internally cast to int32 for the SM100/Blackwell CuTe-DSL kernel and to int64 for the SM90/Hopper C++ kernel, so the caller can pass either dtype.
  - use_qk_l2norm_in_kernel (bool) – Whether to use QK L2 normalization in kernel. Default: False.
  - output (torch.Tensor, optional) – Pre-allocated output tensor of shape [total_seq_len, num_o_heads, head_size] where num_o_heads = max(num_q_heads, num_v_heads). Allocated automatically when None.
  - output_state (torch.Tensor, optional) – Pre-allocated output state tensor of shape [num_seqs, num_sab_heads, head_size, head_size], float32. Required when output_final_state=True.
  - state_checkpoints (torch.Tensor, optional) – Pre-allocated checkpoint tensor of shape [total_checkpoints, num_sab_heads, head_size, head_size], float32. Required when checkpoint_every_n_tokens > 0.
  - checkpoint_cu_starts (torch.Tensor, optional) – Cumulative checkpoint counts of shape [num_seqs + 1], int64. checkpoint_cu_starts[i+1] - checkpoint_cu_starts[i] is the number of checkpoints for sequence i (= seq_len_i // checkpoint_every_n_tokens). Required when checkpoint_every_n_tokens > 0.
  - checkpoint_every_n_tokens (int) – Store intermediate state every N tokens. Must be a multiple of the chunk size (64). 0 disables checkpointing (default).
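The offset arrays described above can be derived from per-sequence lengths with a prefix sum. The following is a minimal sketch in plain Python; `build_varlen_offsets` and `CHUNK_SIZE` are hypothetical names introduced for illustration, and the result is a pair of Python lists that the caller would still need to convert into tensors of the documented dtypes.

```python
from itertools import accumulate

CHUNK_SIZE = 64  # chunk size quoted in the parameter description


def build_varlen_offsets(seq_lens, checkpoint_every_n_tokens=0):
    """Return (cu_seqlens, checkpoint_cu_starts) as Python lists.

    checkpoint_cu_starts is None when checkpointing is disabled.
    """
    n = checkpoint_every_n_tokens
    if n < 0 or n % CHUNK_SIZE != 0:
        raise ValueError("checkpoint_every_n_tokens must be a multiple of 64")
    # cu_seqlens[i] is the token offset at which sequence i starts.
    cu_seqlens = [0, *accumulate(seq_lens)]
    if n == 0:
        return cu_seqlens, None
    # Sequence i contributes seq_len_i // n checkpoints.
    counts = [length // n for length in seq_lens]
    checkpoint_cu_starts = [0, *accumulate(counts)]
    return cu_seqlens, checkpoint_cu_starts


cu_seqlens, checkpoint_cu_starts = build_varlen_offsets([128, 200, 64], 64)
print(cu_seqlens)            # [0, 128, 328, 392]
print(checkpoint_cu_starts)  # [0, 2, 5, 6]
```

Under these definitions, the last entry of `checkpoint_cu_starts` is presumably the `total_checkpoints` dimension of `state_checkpoints`. The lists can be moved to the GPU with, for example, `torch.tensor(cu_seqlens, dtype=torch.int32, device=q.device)`.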
- Returns:
When output_final_state=False, the output tensor of shape [total_seq_len, num_o_heads, head_size]. Otherwise, a tuple (output, final_state) where final_state has shape [num_seqs, num_sab_heads, head_size, head_size].
- Return type:
torch.Tensor or Tuple[torch.Tensor, torch.Tensor]
Notes
- Supports GQA (num_q_heads > num_k_heads = num_v_heads) and GVA (num_v_heads > num_q_heads = num_k_heads).
- The final state layout is [N, H, V, K].
- Requires SM90 (Hopper) or SM100 (Blackwell) architecture. The SM100 path requires head_size == 128 and nvidia-cutlass-dsl[cu13]>=4.4.2 (pip install flashinfer-python[cu13]).
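The head-count rules and shapes stated above can be collected into a small pre-flight helper. This is a sketch only: `expected_shapes` is a hypothetical name, it works on plain integers without touching the GPU, and it reports `None` as the mode for any head configuration that the notes do not name (for example, equal head counts), since this page does not say how those are handled.

```python
def expected_shapes(total_seq_len, num_seqs, num_q_heads, num_k_heads,
                    num_v_heads, head_size):
    """Derive the documented output/state shapes from the head counts."""
    if num_q_heads > num_k_heads and num_k_heads == num_v_heads:
        mode = "GQA"  # more query heads than key/value heads
    elif num_v_heads > num_q_heads and num_q_heads == num_k_heads:
        mode = "GVA"  # more value heads than query/key heads
    else:
        mode = None   # not described in the notes; treat as unknown
    # num_sab_heads and num_o_heads share the same definition in the docs.
    num_sab_heads = max(num_q_heads, num_v_heads)
    return {
        "mode": mode,
        "num_sab_heads": num_sab_heads,
        "output": (total_seq_len, num_sab_heads, head_size),
        "state": (num_seqs, num_sab_heads, head_size, head_size),
    }


shapes = expected_shapes(392, 3, num_q_heads=16, num_k_heads=4,
                         num_v_heads=4, head_size=128)
print(shapes["mode"], shapes["output"], shapes["state"])
```

The `state` tuple is the shape to use when pre-allocating `output_state` or `initial_state` as float32 tensors.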