flashinfer.gdn_decode.gated_delta_rule_decode
- flashinfer.gdn_decode.gated_delta_rule_decode(q: Tensor, k: Tensor, v: Tensor, state: Tensor, A_log: Tensor, a: Tensor, dt_bias: Tensor, b: Tensor, scale: float | None = None, output: Tensor | None = None, use_qk_l2norm: bool = True) → Tuple[Tensor, Tensor]
Gated Delta Rule Decode kernel (K-major layout, no transpose needed).
Implements the decode phase of gated delta rule linear attention, processing one token at a time and updating the recurrent state. This variant uses the K-major state layout [B, HV, K, V] (no transposition).
- Parameters:
  q (torch.Tensor) – Current query of shape [B, 1, H, K]. Must be float16/bfloat16.
  k (torch.Tensor) – Current key of shape [B, 1, H, K]. Must be float16/bfloat16.
  v (torch.Tensor) – Current value of shape [B, 1, HV, V]. Must be float16/bfloat16.
  state (torch.Tensor) – Current state of shape [B, HV, K, V] (K-major layout). Must be float32. Updated in-place.
  A_log (torch.Tensor) – Log decay parameter of shape [HV]. Must be float32.
  a (torch.Tensor) – Input-dependent decay of shape [B, 1, HV]. Must be float16/bfloat16.
  dt_bias (torch.Tensor) – Decay bias of shape [HV]. Must be bfloat16 or float32.
  b (torch.Tensor) – Update gate (beta) input of shape [B, 1, HV]. Must be float16/bfloat16.
  scale (float, optional) – Scale factor for queries. If None, defaults to 1 / sqrt(K).
  output (torch.Tensor, optional) – Pre-allocated output tensor of shape [B, 1, HV, V]. Allocated automatically when None.
  use_qk_l2norm (bool) – Whether to apply L2 normalization to q and k. Default: True.
- Returns:
(output, state), where output has shape [B, 1, HV, V] and state has shape [B, HV, K, V] (updated in-place).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
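To make the description concrete, here is a hypothetical NumPy reference for a single decode step, written against the documented K-major shapes. It is a sketch under stated assumptions, not FlashInfer's implementation: the gate formulas (g = -exp(A_log) * softplus(a + dt_bias), alpha = exp(g), beta = sigmoid(b)), the L2-normalization epsilon, and the mapping of q/k heads onto value heads when HV > H are assumed conventions (common in Mamba2/Qwen3-Next-style models), not facts stated on this page. The name gdn_decode_reference is hypothetical, and unlike the kernel it returns a new float32 state instead of writing in place.

```python
import numpy as np


def softplus(x):
    # Numerically stable softplus: log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def gdn_decode_reference(q, k, v, state, A_log, a, dt_bias, b,
                         scale=None, use_qk_l2norm=True, eps=1e-6):
    """Hypothetical float32 reference for one gated-delta-rule decode step.

    Shapes follow the documented K-major convention:
      q, k: [B, 1, H, K]; v: [B, 1, HV, V]; state: [B, HV, K, V]
      A_log, dt_bias: [HV]; a, b: [B, 1, HV]
    Assumption: HV is a multiple of H and value head h reads q/k head h // (HV // H).
    Returns (output [B, 1, HV, V], new_state [B, HV, K, V]); the input state is not modified.
    """
    B, _, H, K = q.shape
    HV = v.shape[2]
    if scale is None:
        scale = 1.0 / np.sqrt(K)

    q = q[:, 0].astype(np.float32)  # [B, H, K]
    k = k[:, 0].astype(np.float32)  # [B, H, K]
    v = v[:, 0].astype(np.float32)  # [B, HV, V]
    if use_qk_l2norm:
        # Assumed normalization: x / sqrt(sum(x^2) + eps).
        q = q / np.sqrt(np.sum(q * q, axis=-1, keepdims=True) + eps)
        k = k / np.sqrt(np.sum(k * k, axis=-1, keepdims=True) + eps)

    # Assumed grouped mapping: repeat each q/k head for HV // H value heads.
    group = HV // H
    q = np.repeat(q, group, axis=1)  # [B, HV, K]
    k = np.repeat(k, group, axis=1)  # [B, HV, K]

    # Assumed gating convention (Mamba2 / Qwen3-Next style).
    g = -np.exp(A_log) * softplus(a[:, 0].astype(np.float32) + dt_bias)  # [B, HV]
    alpha = np.exp(g)[..., None, None]                                  # [B, HV, 1, 1]
    beta = 1.0 / (1.0 + np.exp(-b[:, 0].astype(np.float32)))            # [B, HV]

    # Delta rule on the decayed K-major state S of shape [K, V] per head:
    #   S <- alpha * S;  S <- S + beta * k (v - S^T k)^T;  o = scale * S^T q
    S = alpha * state.astype(np.float32)                 # copy, decayed
    pred = np.einsum("bhk,bhkv->bhv", k, S)              # S^T k, current prediction for v
    delta = beta[..., None] * (v - pred)                 # gated prediction error
    S = S + np.einsum("bhk,bhv->bhkv", k, delta)         # rank-1 update k delta^T
    out = scale * np.einsum("bhk,bhkv->bhv", q, S)       # read out with the updated state
    return out[:, None], S
```

A reference like this is mainly useful for sanity-checking kernel outputs on small inputs, comparing in float32 with a loose tolerance. One property worth noting: with beta = 1 and a unit-norm key, the updated state maps k exactly to v (S^T k = v), which is what distinguishes the delta rule from a plain additive linear-attention update.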
Notes
Requires SM90 (Hopper) architecture.
State is updated in-place.
K and V must each be >= 128. V must also be a multiple of 32 (TILE_V_NT): the launcher conservatively asserts the large-batch tile size so that both code paths are covered, even though the small-batch kernel could in principle accept V % 16 == 0 (TILE_V_SMALL_NT).
State layout is K-major: [B, HV, K, V] (no transpose needed).
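The shape constraints above can be checked before launching the kernel. Below is a hypothetical, framework-agnostic helper (check_gdn_decode_shapes is not part of FlashInfer) that works on plain shape tuples, so it also accepts torch.Size objects. It mirrors only the shape rules listed on this page; dtypes and the SM90 requirement are not checked. Because a V-major state [B, HV, V, K] has the same element count, comparing against the expected K-major shape catches a layout mix-up when K != V; when K == V the two shapes coincide and only the caller's convention distinguishes them.

```python
from typing import List, Sequence


def check_gdn_decode_shapes(q_shape: Sequence[int], v_shape: Sequence[int],
                            state_shape: Sequence[int]) -> List[str]:
    """Hypothetical pre-flight check mirroring the documented shape constraints.

    Returns a list of human-readable problems; an empty list means none were found.
    """
    problems = []
    B, T, H, K = q_shape
    Bv, Tv, HV, V = v_shape
    if T != 1 or Tv != 1:
        problems.append("decode processes one token: the sequence dim must be 1")
    if B != Bv:
        problems.append(f"batch mismatch between q ({B}) and v ({Bv})")
    # K-major state: the key dim comes before the value dim.
    expected = (B, HV, K, V)
    if tuple(state_shape) != expected:
        problems.append(f"state must be K-major [B, HV, K, V] = {expected}, got {tuple(state_shape)}")
    if K < 128:
        problems.append(f"K must be >= 128, got {K}")
    if V < 128:
        problems.append(f"V must be >= 128, got {V}")
    # The launcher asserts the large-batch tile size (TILE_V_NT = 32).
    if V % 32 != 0:
        problems.append(f"V must be a multiple of 32 (TILE_V_NT), got {V}")
    return problems
```

For example, V = 144 satisfies V >= 128 and V % 16 == 0 but is still rejected, because the launcher requires a multiple of 32 regardless of which code path would run.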