flashinfer.norm.fused_dit_gate_residual_layernorm_gamma_beta

flashinfer.norm.fused_dit_gate_residual_layernorm_gamma_beta(input: Tensor, residual: Tensor, gate: Tensor, gamma: Tensor, beta: Tensor, *, gate_bias: Tensor | None = None, epsilon: float = 1e-06, use_nvfp4: bool = False, use_mxfp8: bool = False, global_scaling_factor: Tensor | None = None, input_global_scaling_factor: Tensor | None = None, residual_out: Tensor | None = None, norm_out: Tensor | None = None, sf_out: Tensor | None = None) → Tuple[Tensor, Tensor]

Fused gate + residual + LayerNorm(gamma, beta) for DiT (Diffusion Transformer) self-attention.

Computes in a single kernel:

residual_out = residual + input * (gate + gate_bias)
norm_out = LayerNorm(residual_out, gamma, beta)
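
The two formulas above can be written as an unfused PyTorch reference, which may be useful for checking the BF16 output path on small inputs. This is only a sketch: the function name `reference_gate_residual_layernorm` is hypothetical, and the choice to compute in FP32 and to normalize the BF16-rounded `residual_out` is an assumption, since the source does not state the kernel's internal precision.

```python
import torch
import torch.nn.functional as F


def reference_gate_residual_layernorm(input, residual, gate, gamma, beta,
                                      gate_bias=None, epsilon=1e-6):
    """Unfused reference for the BF16 output path (a sketch, not the kernel)."""
    # Assumption: arithmetic is done in FP32 and rounded to BF16 on output.
    effective_gate = gate.float()
    if gate_bias is not None:
        # gate_bias is [3072] or [1, 3072]; flatten so it broadcasts over rows.
        effective_gate = effective_gate + gate_bias.float().reshape(-1)
    residual_out = (residual.float() + input.float() * effective_gate).to(torch.bfloat16)
    # Assumption: LayerNorm is applied to the BF16-rounded residual_out.
    norm_out = F.layer_norm(
        residual_out.float(), (gamma.shape[-1],), gamma, beta, epsilon
    ).to(torch.bfloat16)
    return residual_out, norm_out
```

Because the rounding points are assumed, a comparison against the fused kernel should use a tolerance appropriate for BF16 rather than exact equality.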

Parameters:
  • input (torch.Tensor) – Input tensor [batch, num_rows, 3072], BF16, contiguous.

  • residual (torch.Tensor) – Residual tensor, same shape as input, BF16, contiguous.

  • gate (torch.Tensor) – Gating tensor [batch, num_rows, 3072], BF16. May be non-contiguous with stride 6 * hidden_dim in dim 1 (from WAN’s temb.chunk(6, dim=2) pattern).

  • gamma (torch.Tensor) – LayerNorm weight [3072], FP32.

  • beta (torch.Tensor) – LayerNorm bias [3072], FP32.

  • gate_bias (Optional[torch.Tensor]) – Bias added to gate before it multiplies input, [3072] or [1, 3072], FP32.

  • epsilon (float) – LayerNorm epsilon, added to the variance for numerical stability. Defaults to 1e-06.

  • use_nvfp4 (bool) – Quantize norm output to NVFP4 (SM100+ only).

  • use_mxfp8 (bool) – Quantize norm output to MXFP8 (SM100+ only).

  • global_scaling_factor (Optional[torch.Tensor]) – Global scale for NVFP4 output [1], FP32. Required when use_nvfp4=True.

  • input_global_scaling_factor (Optional[torch.Tensor]) – Scale applied to input before gating (for pre-quantized NVFP4 inputs) [1], FP32.

  • residual_out (Optional[torch.Tensor]) – Pre-allocated residual output [batch, num_rows, 3072], BF16.

  • norm_out (Optional[torch.Tensor]) – Pre-allocated norm output. Shape depends on output format.

  • sf_out (Optional[torch.Tensor]) – Pre-allocated scale factor output for NVFP4/MXFP8.
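
The non-contiguous layout described for gate can be reproduced on the CPU. The sketch below builds a hypothetical modulation tensor `temb` of width 6 * 3072 and splits it with `chunk(6, dim=2)`. Which of the six chunks serves as the gate is an assumption made for illustration; every chunk has the same strides.

```python
import torch

HIDDEN_DIM = 3072
batch, num_rows = 2, 4

# Hypothetical modulation tensor holding six per-token vectors side by side.
temb = torch.zeros(batch, num_rows, 6 * HIDDEN_DIM, dtype=torch.bfloat16)

# chunk() returns views into temb, so no data is copied.
chunks = temb.chunk(6, dim=2)
gate = chunks[2]  # assumption: the index of the gate chunk is illustrative only

print(gate.shape)            # torch.Size([2, 4, 3072])
print(gate.stride())         # (73728, 18432, 1); 18432 == 6 * HIDDEN_DIM
print(gate.is_contiguous())  # False
```

Passing such a view directly avoids the extra copy that calling `.contiguous()` on the gate would make.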

Returns:

(residual_out, norm_out). For BF16 output (neither use_nvfp4 nor use_mxfp8 set): both are [batch, num_rows, 3072] BF16. For NVFP4/MXFP8 output: residual_out is BF16 and norm_out is packed int32.

Return type:

Tuple[torch.Tensor, torch.Tensor]
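
A minimal call on the BF16 output path might look like the sketch below. It follows the signature and dtypes documented above, but it has not been verified on hardware here. The helpers `make_inputs` and `run_fused_bf16` are hypothetical names, and the batch and row counts are arbitrary.

```python
import torch

HIDDEN_DIM = 3072  # the only hidden size the kernel supports, per the note below


def make_inputs(batch=1, num_rows=16, device="cuda"):
    """Build tensors with the shapes and dtypes the parameter list documents."""
    x = torch.randn(batch, num_rows, HIDDEN_DIM, dtype=torch.bfloat16, device=device)
    residual = torch.randn_like(x)
    gate = torch.randn_like(x)
    gamma = torch.ones(HIDDEN_DIM, dtype=torch.float32, device=device)
    beta = torch.zeros(HIDDEN_DIM, dtype=torch.float32, device=device)
    return x, residual, gate, gamma, beta


def run_fused_bf16(batch=1, num_rows=16):
    """Call the fused kernel with default (BF16) output; needs a CUDA GPU."""
    # Imported lazily so this module still loads where flashinfer is absent.
    from flashinfer.norm import fused_dit_gate_residual_layernorm_gamma_beta

    x, residual, gate, gamma, beta = make_inputs(batch, num_rows, device="cuda")
    residual_out, norm_out = fused_dit_gate_residual_layernorm_gamma_beta(
        x, residual, gate, gamma, beta, epsilon=1e-6
    )
    return residual_out, norm_out
```

On a supported GPU, `run_fused_bf16()` should return two BF16 tensors of shape [batch, num_rows, 3072], as described under Returns.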

Note

This kernel targets WAN 2.2 5B and supports hidden_dim=3072 only. Primary targets are SM90 (Hopper) and SM100/SM103 (Blackwell). BF16 output is compatible with SM80+. NVFP4/MXFP8 output requires SM100+.
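
The architecture requirements in this note can be turned into a small pre-flight check. The helper below is a hypothetical sketch, not part of flashinfer. It assumes that "SM80+" and "SM100+" mean any compute capability at or above 8.0 and 10.0 respectively.

```python
def supported_output_formats(major: int, minor: int) -> set:
    """Map a CUDA compute capability to the output formats named in the note."""
    sm = major * 10 + minor  # e.g. (9, 0) -> 90, (10, 3) -> 103
    formats = set()
    if sm >= 80:
        formats.add("bf16")
    if sm >= 100:
        # Assumption: every SM100+ part supports both quantized outputs.
        formats.update({"nvfp4", "mxfp8"})
    return formats


print(sorted(supported_output_formats(9, 0)))   # ['bf16']
print(sorted(supported_output_formats(10, 0)))  # ['bf16', 'mxfp8', 'nvfp4']
```

With PyTorch, the arguments can be taken from `torch.cuda.get_device_capability()`, which returns a `(major, minor)` tuple for the current device.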