flashinfer.norm.fused_dit_gate_residual_layernorm_scale_shift

flashinfer.norm.fused_dit_gate_residual_layernorm_scale_shift(input: Tensor, residual: Tensor, gate: Tensor, scale: Tensor, shift: Tensor, *, gate_bias: Tensor | None = None, scale_bias: Tensor | None = None, shift_bias: Tensor | None = None, epsilon: float = 1e-06, use_nvfp4: bool = False, use_mxfp8: bool = False, global_scaling_factor: Tensor | None = None, input_global_scaling_factor: Tensor | None = None, residual_out: Tensor | None = None, norm_out: Tensor | None = None, sf_out: Tensor | None = None) → Tuple[Tensor, Tensor]

Fused gate, residual add, LayerNorm, and scale/shift modulation for DiT (Diffusion Transformer) self-attention blocks.

Computes in a single kernel:

residual_out = residual + input * (gate + gate_bias)
norm_out = LayerNorm(residual_out) * (1 + scale + scale_bias) + (shift + shift_bias)
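As a point of comparison, the two equations above can be written as an unfused PyTorch reference. This is a sketch rather than FlashInfer code: the function name `dit_gate_residual_ln_ref` is hypothetical, and it assumes that an omitted bias behaves like zero, that LayerNorm carries no learnable weight or bias (the signature exposes none), and that the normalization reads the FP32 residual before it is rounded to BF16. It covers only the BF16-output path, not NVFP4/MXFP8 quantization.

```python
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def dit_gate_residual_ln_ref(
    input: torch.Tensor,
    residual: torch.Tensor,
    gate: torch.Tensor,
    scale: torch.Tensor,
    shift: torch.Tensor,
    gate_bias: Optional[torch.Tensor] = None,
    scale_bias: Optional[torch.Tensor] = None,
    shift_bias: Optional[torch.Tensor] = None,
    epsilon: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical unfused FP32 reference for the BF16-output path."""

    def with_bias(t: torch.Tensor, b: Optional[torch.Tensor]) -> torch.Tensor:
        # Assumption: a missing bias acts as zero. A [1, C] bias is
        # flattened to [C] so it broadcasts over [batch, num_rows, C].
        t = t.float()
        return t if b is None else t + b.float().reshape(-1)

    # residual_out = residual + input * (gate + gate_bias), computed in FP32.
    res = residual.float() + input.float() * with_bias(gate, gate_bias)

    # Assumption: non-affine LayerNorm over the hidden dimension,
    # applied to the FP32 residual before it is rounded to BF16.
    normed = F.layer_norm(res, (res.shape[-1],), eps=epsilon)

    # norm_out = LN(residual_out) * (1 + scale + scale_bias) + (shift + shift_bias)
    out = normed * (1.0 + with_bias(scale, scale_bias)) + with_bias(shift, shift_bias)
    return res.to(input.dtype), out.to(input.dtype)
```

The reference accepts any hidden size, so it runs on CPU with small shapes. When comparing it against the fused kernel, a BF16-appropriate tolerance (for example via `torch.testing.assert_close`) is a safer check than exact equality, since the accumulation order may differ.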

Parameters:
  • input (torch.Tensor) – Input tensor [batch, num_rows, 3072], BF16, contiguous.

  • residual (torch.Tensor) – Residual tensor, same shape as input, BF16, contiguous.

  • gate (torch.Tensor) – Gating tensor, BF16. Expected to be a strided view with a row stride of 6 * hidden_dim, as produced by temb.chunk(6, dim=2).

  • scale (torch.Tensor) – Scale tensor, BF16. Same stride convention as gate.

  • shift (torch.Tensor) – Shift tensor, BF16. Same stride convention as gate.

  • gate_bias (Optional[torch.Tensor]) – Optional bias added to gate, shape [3072] or [1, 3072], FP32.

  • scale_bias (Optional[torch.Tensor]) – Optional bias added to scale, shape [3072] or [1, 3072], FP32.

  • shift_bias (Optional[torch.Tensor]) – Optional bias added to shift, shape [3072] or [1, 3072], FP32.

  • epsilon (float) – LayerNorm epsilon.

  • use_nvfp4 (bool) – Quantize norm output to NVFP4 (SM100+ only).

  • use_mxfp8 (bool) – Quantize norm output to MXFP8 (SM100+ only).

  • global_scaling_factor (Optional[torch.Tensor]) – Global scaling factor for the NVFP4 output, shape [1], FP32.

  • input_global_scaling_factor (Optional[torch.Tensor]) – Scaling factor for pre-quantized NVFP4 inputs, shape [1], FP32.

  • residual_out (Optional[torch.Tensor]) – Pre-allocated residual output [batch, num_rows, 3072], BF16.

  • norm_out (Optional[torch.Tensor]) – Pre-allocated norm output. Its shape and dtype depend on the output format (BF16, NVFP4, or MXFP8).

  • sf_out (Optional[torch.Tensor]) – Pre-allocated scale factor output for NVFP4/MXFP8.
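The gate, scale, and shift descriptions above expect strided views into one shared modulation tensor rather than contiguous copies. A minimal sketch, assuming `temb` has shape [batch, num_rows, 6 * hidden_dim]; the helper `split_modulation` and the chunk order (which chunk is shift, scale, or gate is model-specific) are hypothetical.

```python
from typing import Tuple

import torch

HIDDEN_DIM = 3072  # the only hidden size this kernel supports


def split_modulation(temb: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Hypothetical helper: split temb [batch, num_rows, 6 * HIDDEN_DIM]
    into six [batch, num_rows, HIDDEN_DIM] views without copying."""
    if temb.shape[-1] != 6 * HIDDEN_DIM:
        raise ValueError(f"expected last dim {6 * HIDDEN_DIM}, got {temb.shape[-1]}")
    # chunk() returns views, so each piece keeps temb's row stride of 6 * HIDDEN_DIM.
    return temb.chunk(6, dim=2)


if __name__ == "__main__":
    temb = torch.randn(2, 5, 6 * HIDDEN_DIM).to(torch.bfloat16)
    # Hypothetical order; check the model's own modulation layout.
    shift, scale, gate = split_modulation(temb)[:3]
    print(gate.shape, gate.stride())  # torch.Size([2, 5, 3072]) (92160, 18432, 1)
```

Because `chunk` returns views, no data is copied; calling `.contiguous()` on these views would add an extra copy that the stride convention appears designed to avoid.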

Returns:

(residual_out, norm_out).

Return type:

Tuple[torch.Tensor, torch.Tensor]

Note

This kernel targets WAN 2.2 5B and supports hidden_dim=3072 only. Primary targets are SM90 (Hopper) and SM100/SM103 (Blackwell); the BF16 output path also runs on SM80 and newer. NVFP4/MXFP8 output requires SM100+.
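The constraints in the note can be checked before dispatching to the kernel. The helper below is a hypothetical sketch, not part of FlashInfer; it encodes only what the note states (hidden_dim=3072, SM80+ for BF16 output, SM100+ for NVFP4/MXFP8) and takes the compute capability as a (major, minor) tuple, such as the value returned by `torch.cuda.get_device_capability()`.

```python
from typing import Tuple

SUPPORTED_HIDDEN_DIM = 3072  # per the note: hidden_dim=3072 only


def check_fused_dit_support(
    compute_capability: Tuple[int, int],
    hidden_dim: int,
    use_nvfp4: bool = False,
    use_mxfp8: bool = False,
) -> None:
    """Hypothetical pre-flight check; raises ValueError on an unsupported setup."""
    major, minor = compute_capability
    sm = major * 10 + minor  # e.g. (9, 0) -> 90, (10, 3) -> 103
    if hidden_dim != SUPPORTED_HIDDEN_DIM:
        raise ValueError(f"hidden_dim must be {SUPPORTED_HIDDEN_DIM}, got {hidden_dim}")
    if sm < 80:
        raise ValueError(f"BF16 output needs SM80+, got SM{sm}")
    if (use_nvfp4 or use_mxfp8) and sm < 100:
        raise ValueError(f"NVFP4/MXFP8 output needs SM100+, got SM{sm}")
```

Raising early keeps an unsupported configuration from reaching the kernel, where the failure would be harder to diagnose.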