flashinfer.gdn_decode.gated_delta_rule_decode

flashinfer.gdn_decode.gated_delta_rule_decode(q: Tensor, k: Tensor, v: Tensor, state: Tensor, A_log: Tensor, a: Tensor, dt_bias: Tensor, b: Tensor, scale: float | None = None, output: Tensor | None = None, use_qk_l2norm: bool = True) → Tuple[Tensor, Tensor]

Gated Delta Rule Decode kernel (K-major layout, no transpose needed).

Implements the decode phase of gated delta rule linear attention, processing one token at a time and updating the recurrent state. This variant uses K-major state layout [B, HV, K, V] (no transposition).
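The single-token update can be sketched in plain Python for one head. This is an illustrative reference only, not the kernel's implementation. It assumes the commonly used gated delta rule formulation, in which the decay is exp(-exp(A_log) * softplus(a + dt_bias)) and beta is sigmoid(b). Those formulas, the eps value in the normalization, and the name gdn_decode_step_ref are assumptions that do not come from this page.

```python
import math


def l2_normalize(x, eps=1e-6):
    # eps placement is an assumption; the kernel may differ in detail.
    norm = math.sqrt(sum(t * t for t in x) + eps)
    return [t / norm for t in x]


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gdn_decode_step_ref(q, k, v, state, A_log, a, dt_bias, b,
                        scale=None, use_qk_l2norm=True):
    """Hypothetical single-head reference for one decode step.

    q, k: lists of length K; v: list of length V;
    state: K x V nested list (K-major), modified in place.
    """
    K, V = len(k), len(v)
    if scale is None:
        scale = 1.0 / math.sqrt(K)
    if use_qk_l2norm:
        q, k = l2_normalize(q), l2_normalize(k)

    # Assumed gating: scalar decay in (0, 1] and update gate beta in (0, 1).
    decay = math.exp(-math.exp(A_log) * softplus(a + dt_bias))
    beta = sigmoid(b)

    # 1) Decay the previous state.
    for i in range(K):
        for j in range(V):
            state[i][j] *= decay

    # 2) Delta rule: correct the value the state currently predicts for k.
    predicted = [sum(state[i][j] * k[i] for i in range(K)) for j in range(V)]
    delta = [beta * (v[j] - predicted[j]) for j in range(V)]
    for i in range(K):
        for j in range(V):
            state[i][j] += k[i] * delta[j]

    # 3) Read out with the scaled query.
    out = [scale * sum(state[i][j] * q[i] for i in range(K)) for j in range(V)]
    return out, state
```

Starting from a zero state with q equal to k, the sketch returns approximately scale * beta * v, which is a convenient sanity check when comparing against a reference implementation.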

Parameters:
  • q (torch.Tensor) – Current query of shape [B, 1, H, K]. Must be float16/bfloat16.

  • k (torch.Tensor) – Current key of shape [B, 1, H, K]. Must be float16/bfloat16.

  • v (torch.Tensor) – Current value of shape [B, 1, HV, V]. Must be float16/bfloat16.

  • state (torch.Tensor) – Current state of shape [B, HV, K, V] (K-major layout). Must be float32. Updated in-place.

  • A_log (torch.Tensor) – Log decay parameter of shape [HV]. Must be float32.

  • a (torch.Tensor) – Input-dependent decay of shape [B, 1, HV]. Must be float16/bfloat16.

  • dt_bias (torch.Tensor) – Decay bias of shape [HV]. Must be bfloat16 or float32.

  • b (torch.Tensor) – Update gate (beta) input of shape [B, 1, HV]. Must be float16/bfloat16.

  • scale (float, optional) – Scale factor for queries. If None, defaults to 1 / sqrt(K).

  • output (torch.Tensor, optional) – Pre-allocated output tensor of shape [B, 1, HV, V]. Allocated automatically when None.

  • use_qk_l2norm (bool) – Whether to apply L2 normalization to q and k. Default: True.

Returns:

(output, state) where output has shape [B, 1, HV, V] and state has shape [B, HV, K, V] (updated in-place).

Return type:

Tuple[torch.Tensor, torch.Tensor]
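A minimal usage sketch follows, built only from the shapes and dtypes listed above. The helper names expected_shapes and run_decode_example are hypothetical, and the choice H == HV is an assumption made to sidestep any head-grouping constraint, which this page does not document. The call itself needs a CUDA device that meets the requirements in the Notes, so its imports are kept inside the function.

```python
def expected_shapes(B, H, HV, K, V):
    """Shapes and dtypes from the parameter list, keyed by argument name."""
    return {
        "q": ((B, 1, H, K), "bfloat16"),
        "k": ((B, 1, H, K), "bfloat16"),
        "v": ((B, 1, HV, V), "bfloat16"),
        "state": ((B, HV, K, V), "float32"),
        "A_log": ((HV,), "float32"),
        "a": ((B, 1, HV), "bfloat16"),
        "dt_bias": ((HV,), "float32"),  # bfloat16 is also accepted
        "b": ((B, 1, HV), "bfloat16"),
    }


def run_decode_example(B=2, H=16, HV=16, K=128, V=128):
    # Local imports: loading this file does not require torch or a GPU.
    import torch
    from flashinfer.gdn_decode import gated_delta_rule_decode

    inputs = {
        name: torch.randn(shape, dtype=getattr(torch, dtype), device="cuda")
        for name, (shape, dtype) in expected_shapes(B, H, HV, K, V).items()
    }
    # Start from an empty recurrent state, as at the first decoded token.
    inputs["state"].zero_()

    output, state = gated_delta_rule_decode(**inputs)
    # Per the Returns section: output is [B, 1, HV, V], state is [B, HV, K, V].
    return tuple(output.shape), tuple(state.shape)
```

In a decode loop, the same state tensor would be passed back in on every step, since it is updated in-place.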

Notes

  • Requires SM90 (Hopper) architecture.

  • State is updated in-place.

  • K and V must each be >= 128. V must also be a multiple of 32 (TILE_V_NT). The launcher conservatively asserts the large-batch tile size so that a single check covers both code paths, even though the small-batch kernel could in principle accept V % 16 == 0 (TILE_V_SMALL_NT).

  • State layout is K-major: [B, HV, K, V] (no transpose needed).
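The dimension constraints above can be checked before launching the kernel. The following is a sketch of such a pre-flight check. The function name check_gdn_decode_dims is hypothetical, and the constant simply mirrors the TILE_V_NT value quoted in the note rather than being imported from the library.

```python
TILE_V_NT = 32  # large-batch tile size quoted in the Notes (not imported)


def check_gdn_decode_dims(K, V):
    """Raise ValueError if K or V violate the documented constraints."""
    if K < 128 or V < 128:
        raise ValueError(f"K and V must each be >= 128, got K={K}, V={V}")
    if V % TILE_V_NT != 0:
        raise ValueError(f"V must be a multiple of {TILE_V_NT}, got V={V}")


check_gdn_decode_dims(128, 128)  # passes silently
```

For example, V = 144 is a multiple of 16 but not of 32, so it fails this check even though the small-batch tile size alone would allow it.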