flashinfer.gdn_prefill.chunk_gated_delta_rule

flashinfer.gdn_prefill.chunk_gated_delta_rule(q: Tensor, k: Tensor, v: Tensor, g: Tensor | None = None, beta: Tensor | None = None, scale: float | None = None, initial_state: Tensor | None = None, output_final_state: bool = False, cu_seqlens: Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, output: Tensor | None = None, output_state: Tensor | None = None, state_checkpoints: Tensor | None = None, checkpoint_cu_starts: Tensor | None = None, checkpoint_every_n_tokens: int = 0) → Tensor | Tuple[Tensor, Tensor]

Chunked Gated Delta Rule (GDN) attention for prefill.

Implements the gated delta rule linear attention mechanism for efficient training and inference. Supports both GQA (grouped query attention) and GVA (grouped value attention) configurations.
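The recurrence that the chunked kernel evaluates can be written token by token. The sketch below is a naive pure-Python reference for a single head of a single sequence. It assumes the standard gated delta rule update S_t = alpha_t * S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T with read-out o_t = scale * S_t q_t, treats g as a multiplicative decay (consistent with its all-ones default), and stores the state as [V][K] to match the layout in the Notes. The name gated_delta_rule_reference is hypothetical and not part of flashinfer, and the sketch skips the optional in-kernel L2 normalization.

```python
import math
from typing import List, Optional, Tuple

Vector = List[float]
Matrix = List[List[float]]  # state stored as [V][K]


def gated_delta_rule_reference(
    q: List[Vector],
    k: List[Vector],
    v: List[Vector],
    alpha: List[float],
    beta: List[float],
    scale: Optional[float] = None,
    initial_state: Optional[Matrix] = None,
) -> Tuple[List[Vector], Matrix]:
    """Hypothetical token-by-token reference for one head of one sequence.

    Assumes S_t = alpha_t * S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T
    and o_t = scale * S_t q_t, with S stored as [V][K].
    """
    d_k, d_v = len(k[0]), len(v[0])
    if scale is None:
        scale = 1.0 / math.sqrt(d_k)  # mirrors the documented default
    if initial_state is None:
        state = [[0.0] * d_k for _ in range(d_v)]  # zero initial state
    else:
        state = [row[:] for row in initial_state]  # copy; never mutate input

    outputs: List[Vector] = []
    for q_t, k_t, v_t, a_t, b_t in zip(q, k, v, alpha, beta):
        # 1) Gate (decay) the whole state by alpha_t.
        state = [[a_t * s for s in row] for row in state]
        # 2) Predict v_t from the decayed state: v_pred = S k_t.
        v_pred = [sum(row[j] * k_t[j] for j in range(d_k)) for row in state]
        # 3) Delta-rule correction: S += beta_t * (v_t - v_pred) k_t^T.
        for i in range(d_v):
            err = b_t * (v_t[i] - v_pred[i])
            for j in range(d_k):
                state[i][j] += err * k_t[j]
        # 4) Read out: o_t = scale * S q_t.
        outputs.append(
            [scale * sum(row[j] * q_t[j] for j in range(d_k)) for row in state]
        )
    return outputs, state
```

With alpha = beta = 1 and a repeated unit key, the second value overwrites the first rather than accumulating, which is the defining behaviour of the delta rule. Comparing such a reference against the kernel on tiny inputs is one way to sanity-check argument shapes; expect only approximate agreement because the kernel works chunk-wise.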

Parameters:
  • q (torch.Tensor) – Queries of shape [total_seq_len, num_q_heads, head_size]. Must be contiguous and on CUDA.

  • k (torch.Tensor) – Keys of shape [total_seq_len, num_k_heads, head_size]. Must be contiguous and on CUDA.

  • v (torch.Tensor) – Values of shape [total_seq_len, num_v_heads, head_size]. Must be contiguous and on CUDA.

  • g (torch.Tensor, optional) – Forget gate (alpha) of shape [total_seq_len, num_sab_heads] where num_sab_heads = max(num_q_heads, num_v_heads). Must be float32. Defaults to all ones when None.

  • beta (torch.Tensor, optional) – Update gate (beta) of shape [total_seq_len, num_sab_heads]. Must be float32. Defaults to all ones when None.

  • scale (float, optional) – Scale factor for the attention scores. Defaults to 1 / sqrt(head_size) when None.

  • initial_state (torch.Tensor, optional) – Initial KV state of shape [num_seqs, num_sab_heads, head_size, head_size]. Must be float32. Starts from zero state when None.

  • output_final_state (bool) – Whether to output the final state. Default: False.

  • cu_seqlens (torch.Tensor) – Cumulative sequence lengths of shape [num_seqs + 1], with an integer dtype, on the same CUDA device as q. Required despite the None default in the signature: inputs are always packed in variable-length (varlen) form, and the function asserts that cu_seqlens is not None. It is cast internally to int32 for the SM100 (Blackwell) CuTe-DSL kernel and to int64 for the SM90 (Hopper) C++ kernel, so either dtype may be passed.

  • use_qk_l2norm_in_kernel (bool) – Whether to L2-normalize q and k inside the kernel. Default: False.

  • output (torch.Tensor, optional) – Pre-allocated output tensor of shape [total_seq_len, num_o_heads, head_size] where num_o_heads = max(num_q_heads, num_v_heads). Allocated automatically when None.

  • output_state (torch.Tensor, optional) – Pre-allocated output state tensor of shape [num_seqs, num_sab_heads, head_size, head_size], float32. Required when output_final_state=True.

  • state_checkpoints (torch.Tensor, optional) – Pre-allocated checkpoint tensor of shape [total_checkpoints, num_sab_heads, head_size, head_size], float32. Required when checkpoint_every_n_tokens > 0.

  • checkpoint_cu_starts (torch.Tensor, optional) – Cumulative checkpoint counts of shape [num_seqs + 1], int64. checkpoint_cu_starts[i+1] - checkpoint_cu_starts[i] is the number of checkpoints for sequence i (= seq_len_i // checkpoint_every_n_tokens). Required when checkpoint_every_n_tokens > 0.

  • checkpoint_every_n_tokens (int) – Store intermediate state every N tokens. Must be a multiple of the chunk size (64). 0 disables checkpointing (default).
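The cu_seqlens and checkpoint_cu_starts arguments are both exclusive prefix sums over per-sequence counts. The helper below is a hypothetical pure-Python sketch (build_varlen_offsets is not a flashinfer function) that derives both from a list of sequence lengths, enforcing the documented rules that checkpoint_every_n_tokens is a multiple of 64 and that sequence i gets seq_len_i // checkpoint_every_n_tokens checkpoints.

```python
from itertools import accumulate
from typing import List, Tuple

CHUNK_SIZE = 64  # chunk size quoted in the parameter docs


def build_varlen_offsets(
    seq_lens: List[int], checkpoint_every_n_tokens: int = 0
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: return (cu_seqlens, checkpoint_cu_starts) as lists."""
    if any(n < 0 for n in seq_lens):
        raise ValueError("sequence lengths must be non-negative")
    if checkpoint_every_n_tokens < 0 or checkpoint_every_n_tokens % CHUNK_SIZE:
        raise ValueError(
            "checkpoint_every_n_tokens must be 0 or a positive multiple of 64"
        )

    # cu_seqlens[i] is the row where sequence i starts in the packed tensors.
    cu_seqlens = [0, *accumulate(seq_lens)]

    # Each sequence stores seq_len // N checkpoints; none when disabled (N == 0).
    if checkpoint_every_n_tokens == 0:
        counts = [0] * len(seq_lens)
    else:
        counts = [n // checkpoint_every_n_tokens for n in seq_lens]
    checkpoint_cu_starts = [0, *accumulate(counts)]
    return cu_seqlens, checkpoint_cu_starts
```

For lengths [100, 200, 64] with N = 64 this yields cu_seqlens = [0, 100, 300, 364] and checkpoint_cu_starts = [0, 1, 4, 5], so total_checkpoints (the leading dimension of state_checkpoints) is checkpoint_cu_starts[-1] = 5. The lists can then be moved to the GPU with torch.tensor(..., device=q.device), using int64 for checkpoint_cu_starts as documented.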

Returns:

When output_final_state=False, the output tensor of shape [total_seq_len, num_o_heads, head_size]. Otherwise a tuple (output, final_state) where final_state has shape [num_seqs, num_sab_heads, head_size, head_size].

Return type:

torch.Tensor or Tuple[torch.Tensor, torch.Tensor]
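Because all sequences are packed along the first dimension, the returned tensors can be checked and split using cu_seqlens alone. This hypothetical sketch (expected_return_shapes and per_sequence_rows are illustrative names, not flashinfer APIs) computes the shapes described under Returns and the half-open row range of each sequence; with a real tensor, output[start:end] would then select sequence i, and final_state[i] its state.

```python
from typing import List, Tuple, Union

Shape = Tuple[int, ...]


def expected_return_shapes(
    cu_seqlens: List[int],
    num_q_heads: int,
    num_v_heads: int,
    head_size: int,
    output_final_state: bool = False,
) -> Union[Shape, Tuple[Shape, Shape]]:
    """Shapes described in the Returns section (hypothetical helper)."""
    total_seq_len = cu_seqlens[-1]
    num_seqs = len(cu_seqlens) - 1
    num_heads = max(num_q_heads, num_v_heads)  # num_o_heads == num_sab_heads
    output_shape = (total_seq_len, num_heads, head_size)
    if not output_final_state:
        return output_shape
    state_shape = (num_seqs, num_heads, head_size, head_size)
    return output_shape, state_shape


def per_sequence_rows(cu_seqlens: List[int]) -> List[Tuple[int, int]]:
    """Half-open [start, end) row ranges of each sequence in the packed output."""
    return list(zip(cu_seqlens[:-1], cu_seqlens[1:]))
```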

Notes

  • Supports GQA (num_q_heads > num_k_heads = num_v_heads) and GVA (num_v_heads > num_q_heads = num_k_heads).

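  • The final state layout is [N, H, V, K]: N = num_seqs, H = num_sab_heads, and V and K are the value and key head dimensions (both equal to head_size).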

  • Requires SM90 (Hopper) or SM100 (Blackwell) architecture. The SM100 path requires head_size == 128 and nvidia-cutlass-dsl[cu13]>=4.4.2 (pip install flashinfer-python[cu13]).
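The head-count and hardware constraints in these Notes can be checked before launching. The sketch below is hypothetical (check_gdn_config is not a flashinfer function). It assumes compute-capability major 9 corresponds to SM90 and major 10 to SM100, reports equal head counts as a separate "equal" case because the Notes do not mention it, and assumes the larger head count must be a multiple of the smaller one. It does not check the nvidia-cutlass-dsl version.

```python
from typing import Tuple


def check_gdn_config(
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    compute_capability: Tuple[int, int],
) -> Tuple[str, int]:
    """Hypothetical pre-flight check; returns (layout, num_sab_heads)."""
    major, _minor = compute_capability
    # Assumption: major 9 -> SM90 (Hopper), major 10 -> SM100 (Blackwell).
    if major not in (9, 10):
        raise ValueError(f"unsupported architecture sm_{major}x")
    if major == 10 and head_size != 128:
        raise ValueError("the SM100 path requires head_size == 128")
    if min(num_q_heads, num_k_heads, num_v_heads) <= 0:
        raise ValueError("head counts must be positive")

    if num_q_heads > num_k_heads == num_v_heads:
        layout, small, large = "GQA", num_k_heads, num_q_heads
    elif num_v_heads > num_q_heads == num_k_heads:
        layout, small, large = "GVA", num_q_heads, num_v_heads
    elif num_q_heads == num_k_heads == num_v_heads:
        # Not named in the Notes; reported separately so the caller can decide.
        layout, small, large = "equal", num_q_heads, num_q_heads
    else:
        raise ValueError("head counts match neither GQA nor GVA")

    # Assumption: grouping needs the larger head count to divide evenly.
    if large % small:
        raise ValueError("head counts do not divide evenly into groups")
    return layout, max(num_q_heads, num_v_heads)
```

On a live GPU, torch.cuda.get_device_capability() returns the (major, minor) pair to pass in.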