flashinfer.norm.fused_dit_gate_residual_layernorm_gamma_beta
- flashinfer.norm.fused_dit_gate_residual_layernorm_gamma_beta(input: Tensor, residual: Tensor, gate: Tensor, gamma: Tensor, beta: Tensor, *, gate_bias: Tensor | None = None, epsilon: float = 1e-06, use_nvfp4: bool = False, use_mxfp8: bool = False, global_scaling_factor: Tensor | None = None, input_global_scaling_factor: Tensor | None = None, residual_out: Tensor | None = None, norm_out: Tensor | None = None, sf_out: Tensor | None = None) → Tuple[Tensor, Tensor]
Fused gate + residual + LayerNorm(gamma, beta) for DiT (diffusion transformer) self-attention.
Computes, in a single kernel:
residual_out = residual + input * (gate + gate_bias)
norm_out = LayerNorm(residual_out, gamma, beta)
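For checking results, the two formulas above can be written as an unfused PyTorch reference for the BF16-output path. This is a minimal sketch, not FlashInfer code: the helper name dit_gate_residual_layernorm_ref is hypothetical, and it assumes the arithmetic is done in FP32, that the LayerNorm is applied to the FP32 sum (not the BF16-rounded residual_out), and that a missing gate_bias means no bias. The fused kernel may round differently, so compare with a tolerance rather than exact equality.

```python
import torch
import torch.nn.functional as F


def dit_gate_residual_layernorm_ref(x, residual, gate, gamma, beta,
                                    gate_bias=None, epsilon=1e-6):
    """Hypothetical unfused reference; not the FlashInfer implementation."""
    # Gate in FP32; a gate_bias of shape [3072] or [1, 3072] is flattened so it
    # broadcasts over the trailing hidden dimension.
    g = gate.float()
    if gate_bias is not None:
        g = g + gate_bias.float().reshape(-1)
    # residual_out = residual + input * (gate + gate_bias), accumulated in FP32.
    res = residual.float() + x.float() * g
    # LayerNorm over the hidden dimension with FP32 gamma/beta. Normalizing the
    # FP32 sum instead of the BF16-rounded residual is an assumption here.
    normed = F.layer_norm(res, (res.shape[-1],), gamma.float(), beta.float(), epsilon)
    return res.to(torch.bfloat16), normed.to(torch.bfloat16)


# Small CPU example using the only supported hidden size.
torch.manual_seed(0)
H = 3072
x, residual, gate = (torch.randn(1, 4, H).to(torch.bfloat16) for _ in range(3))
gamma, beta = torch.ones(H), torch.zeros(H)
residual_out, norm_out = dit_gate_residual_layernorm_ref(x, residual, gate, gamma, beta)
print(residual_out.shape, residual_out.dtype, norm_out.dtype)
```

With gamma set to ones and beta to zeros, each row of norm_out should have mean close to 0 and variance close to 1, which is a quick sanity check against the fused output.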
- Parameters:
input (torch.Tensor) – Input tensor [batch, num_rows, 3072], BF16, contiguous.
residual (torch.Tensor) – Residual tensor, same shape as input, BF16, contiguous.
gate (torch.Tensor) – Gating tensor [batch, num_rows, 3072], BF16. May be non-contiguous with stride 6 * hidden_dim in dim 1 (from WAN's temb.chunk(6, dim=2) pattern).
gamma (torch.Tensor) – LayerNorm weight [3072], FP32.
beta (torch.Tensor) – LayerNorm bias [3072], FP32.
gate_bias (Optional[torch.Tensor]) – Bias for gate [3072] or [1, 3072], FP32.
epsilon (float) – LayerNorm epsilon.
use_nvfp4 (bool) – Quantize norm output to NVFP4 (SM100+ only).
use_mxfp8 (bool) – Quantize norm output to MXFP8 (SM100+ only).
global_scaling_factor (Optional[torch.Tensor]) – Global scale for NVFP4 output [1], FP32. Required when use_nvfp4=True.
input_global_scaling_factor (Optional[torch.Tensor]) – Scale applied to input before gating (for pre-quantized NVFP4 inputs) [1], FP32.
residual_out (Optional[torch.Tensor]) – Pre-allocated residual output [batch, num_rows, 3072], BF16.
norm_out (Optional[torch.Tensor]) – Pre-allocated norm output. Shape depends on output format.
sf_out (Optional[torch.Tensor]) – Pre-allocated scale factor output for NVFP4/MXFP8.
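The strided gate described for the gate parameter can be produced and passed directly, without a .contiguous() copy. The sketch below assumes a WAN-style modulation tensor that packs six [batch, num_rows, 3072] chunks along dim 2, and takes chunk index 2 as the self-attention gate; that index, the tensor name temb, and the sizes are illustrative assumptions. The FlashInfer call runs only when CUDA and the flashinfer package are available, and uses only arguments from the signature above.

```python
import importlib.util

import torch

HIDDEN = 3072                     # the only hidden size the kernel supports
batch, num_rows = 1, 8
dev = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical WAN-style modulation tensor: six chunks packed along dim 2.
# Build it on the target device first: copying a strided (non-dense) view with
# .to() would give a contiguous tensor and lose the layout shown here.
temb = torch.randn(batch, num_rows, 6 * HIDDEN, device=dev).to(torch.bfloat16)
gate = temb.chunk(6, dim=2)[2]    # assumed gate index; chunk() returns views, no copy

# Stepping one row skips all six chunks, so dim 1 has stride 6 * HIDDEN.
print(gate.stride(), gate.is_contiguous())

x = torch.randn(batch, num_rows, HIDDEN, device=dev).to(torch.bfloat16)
residual = torch.randn(batch, num_rows, HIDDEN, device=dev).to(torch.bfloat16)
gamma = torch.ones(HIDDEN, device=dev)    # FP32 LayerNorm weight
beta = torch.zeros(HIDDEN, device=dev)    # FP32 LayerNorm bias

if dev == "cuda" and importlib.util.find_spec("flashinfer") is not None:
    from flashinfer.norm import fused_dit_gate_residual_layernorm_gamma_beta

    # BF16 output path: both outputs are [batch, num_rows, HIDDEN] BF16.
    residual_out, norm_out = fused_dit_gate_residual_layernorm_gamma_beta(
        x, residual, gate, gamma, beta, epsilon=1e-6
    )
```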
- Returns:
(residual_out, norm_out). For BF16 output, both are [batch, num_rows, 3072] BF16. For NVFP4/MXFP8 output, residual_out is BF16 and norm_out is packed int32.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
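The packed int32 norm_out implies some simple size arithmetic when pre-allocating buffers. The sketch below assumes values are packed densely along the hidden dimension (8 FP4 or 4 FP8 values per int32 word) and uses the standard block sizes of the two formats: one scale per 16 elements for NVFP4 and one per 32 elements for MXFP8. packed_row_sizes is a hypothetical helper, not part of FlashInfer, and it reports unpadded per-row counts only; the actual sf_out layout may be padded or swizzled, so check the library before sizing sf_out from these numbers.

```python
HIDDEN = 3072

# name: (bits per element, elements per scale block)
FORMATS = {
    "nvfp4": (4, 16),
    "mxfp8": (8, 32),
}


def packed_row_sizes(fmt: str, hidden: int = HIDDEN) -> tuple[int, int]:
    """Hypothetical helper: (int32 words, scale factors) per row, unpadded."""
    bits, block = FORMATS[fmt]
    per_int32 = 32 // bits                  # elements that fit in one int32 word
    if hidden % per_int32 or hidden % block:
        raise ValueError("hidden size must divide evenly into words and scale blocks")
    return hidden // per_int32, hidden // block


for fmt in FORMATS:
    words, scales = packed_row_sizes(fmt)
    print(f"{fmt}: {words} int32 words, {scales} scale factors per row")
```

Under these assumptions a 3072-wide row packs into 384 int32 words with 192 scales for NVFP4, and 768 words with 96 scales for MXFP8.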
Note
This kernel targets WAN 2.2 5B (hidden_dim=3072 only). Primary targets are SM90 (Hopper) and SM100/SM103 (Blackwell). BF16 output is also compatible with SM80+. NVFP4/MXFP8 output requires SM100+.
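The hardware requirements in this note can drive a simple dispatch decision before calling the kernel. The sketch below is a hypothetical helper (choose_output_format is not a FlashInfer function) that takes a (major, minor) compute capability, such as the tuple returned by torch.cuda.get_device_capability(), and falls back to BF16 output when a quantized format is unavailable. It encodes only the rules stated here: hidden_dim must be 3072, BF16 output needs SM80 or newer, and NVFP4/MXFP8 output needs SM100 or newer.

```python
def choose_output_format(capability: tuple[int, int], hidden_dim: int,
                         want: str = "nvfp4") -> str:
    """Hypothetical helper: pick an output format the kernel supports."""
    if hidden_dim != 3072:
        raise ValueError("kernel supports hidden_dim=3072 only")
    major, _minor = capability
    if major < 8:
        raise RuntimeError("BF16 output needs SM80 or newer")
    if want in ("nvfp4", "mxfp8") and major >= 10:
        return want          # SM100+ can emit quantized output
    return "bf16"            # otherwise fall back to BF16 output


# Map the choice onto the kernel's flags. Note that use_nvfp4=True also
# requires passing global_scaling_factor.
fmt = choose_output_format((9, 0), 3072)      # Hopper: no quantized output
flags = {"use_nvfp4": fmt == "nvfp4", "use_mxfp8": fmt == "mxfp8"}
print(fmt, flags)
```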