flashinfer.norm.fused_dit_gate_residual_layernorm_scale_shift
- flashinfer.norm.fused_dit_gate_residual_layernorm_scale_shift(input: Tensor, residual: Tensor, gate: Tensor, scale: Tensor, shift: Tensor, *, gate_bias: Tensor | None = None, scale_bias: Tensor | None = None, shift_bias: Tensor | None = None, epsilon: float = 1e-06, use_nvfp4: bool = False, use_mxfp8: bool = False, global_scaling_factor: Tensor | None = None, input_global_scaling_factor: Tensor | None = None, residual_out: Tensor | None = None, norm_out: Tensor | None = None, sf_out: Tensor | None = None) → Tuple[Tensor, Tensor]
Fused gate + residual + LayerNorm + scale/shift for DIT self-attention.
Computes in a single kernel:
residual_out = residual + input * (gate + gate_bias)
norm_out = LayerNorm(residual_out) * (1 + scale + scale_bias) + (shift + shift_bias)
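To make the two formulas concrete, here is a minimal pure-Python sketch of the same arithmetic for a single row. It is not the kernel's implementation: `reference_row` is a hypothetical helper, and it assumes that LayerNorm normalizes over the hidden dimension with a biased variance and no learned affine weights, and that a missing bias behaves like zero. It also ignores BF16 rounding, so it only approximates the fused output.

```python
import math


def reference_row(x, residual, gate, scale, shift, *,
                  gate_bias=None, scale_bias=None, shift_bias=None,
                  epsilon=1e-6):
    """Reference math for one row of length hidden_dim (hypothetical helper).

    Assumptions (not confirmed by the docs): LayerNorm has no learned
    weight/bias, uses the biased variance, and a bias of None means zero.
    """
    n = len(x)
    zeros = [0.0] * n
    gb = zeros if gate_bias is None else gate_bias
    sb = zeros if scale_bias is None else scale_bias
    hb = zeros if shift_bias is None else shift_bias

    # residual_out = residual + input * (gate + gate_bias)
    residual_out = [r + xi * (g + b)
                    for r, xi, g, b in zip(residual, x, gate, gb)]

    # LayerNorm statistics over the hidden dimension
    mean = sum(residual_out) / n
    var = sum((v - mean) ** 2 for v in residual_out) / n
    inv_std = 1.0 / math.sqrt(var + epsilon)

    # norm_out = LayerNorm(residual_out) * (1 + scale + scale_bias)
    #            + (shift + shift_bias)
    norm_out = [(v - mean) * inv_std * (1.0 + s + b1) + (sh + b2)
                for v, s, b1, sh, b2 in zip(residual_out, scale, sb, shift, hb)]
    return residual_out, norm_out
```

A sketch like this can serve as a tolerance-based check against the fused kernel on small inputs, provided the comparison allows for BF16 precision.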
- Parameters:
input (torch.Tensor) – Input tensor [batch, num_rows, 3072], BF16, contiguous.
residual (torch.Tensor) – Residual tensor, same shape as input, BF16, contiguous.
gate (torch.Tensor) – Gating tensor, BF16. Stride 6 * hidden_dim from temb.chunk(6, dim=2).
scale (torch.Tensor) – Scale tensor, BF16. Same stride convention as gate.
shift (torch.Tensor) – Shift tensor, BF16. Same stride convention as gate.
gate_bias (Optional[torch.Tensor]) – Bias added to gate, [3072] or [1, 3072], FP32.
scale_bias (Optional[torch.Tensor]) – Bias added to scale, [3072] or [1, 3072], FP32.
shift_bias (Optional[torch.Tensor]) – Bias added to shift, [3072] or [1, 3072], FP32.
epsilon (float) – LayerNorm epsilon.
use_nvfp4 (bool) – Quantize norm output to NVFP4 (SM100+ only).
use_mxfp8 (bool) – Quantize norm output to MXFP8 (SM100+ only).
global_scaling_factor (Optional[torch.Tensor]) – Global scale for NVFP4 output, [1], FP32.
input_global_scaling_factor (Optional[torch.Tensor]) – Scale for pre-quantized NVFP4 inputs, [1], FP32.
residual_out (Optional[torch.Tensor]) – Pre-allocated residual output [batch, num_rows, 3072], BF16.
norm_out (Optional[torch.Tensor]) – Pre-allocated norm output. Shape depends on output format.
sf_out (Optional[torch.Tensor]) – Pre-allocated scale factor output for NVFP4/MXFP8.
- Returns:
(residual_out, norm_out).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
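The stride convention for gate, scale, and shift can be illustrated with plain offset arithmetic. The sketch below assumes temb is a contiguous tensor of shape [batch, num_rows, 6 * hidden_dim], so each of the six chunks is a view whose consecutive rows sit 6 * hidden_dim elements apart. `chunk_offset` is a hypothetical helper, and which chunk index corresponds to gate, scale, or shift is not stated in this page.

```python
HIDDEN_DIM = 3072   # the only hidden size the kernel supports, per the docs
NUM_CHUNKS = 6      # temb.chunk(6, dim=2)


def chunk_offset(chunk, b, row, h, num_rows, hidden_dim=HIDDEN_DIM):
    """Flat element offset of chunk[b, row, h] (hypothetical helper).

    Assumes temb is contiguous with shape [batch, num_rows, 6 * hidden_dim],
    so a chunk is a strided view into it rather than a packed copy.
    """
    row_stride = NUM_CHUNKS * hidden_dim      # 6 * hidden_dim between rows
    batch_stride = num_rows * row_stride
    return b * batch_stride + row * row_stride + chunk * hidden_dim + h


# Moving one row forward inside any chunk skips a full temb row.
step = chunk_offset(0, 0, 1, 0, num_rows=4) - chunk_offset(0, 0, 0, 0, num_rows=4)
```

Under these assumptions, calling .contiguous() on a chunk would produce a packed copy with row stride hidden_dim, which no longer matches the stride this page describes.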
Note
This kernel targets WAN 2.2 5B (hidden_dim=3072 only). Primary targets: SM90 (Hopper), SM100/SM103 (Blackwell). BF16 output compatible with SM80+. NVFP4/MXFP8 requires SM100+.
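The architecture requirements in the note can be turned into a small pre-flight check. This is a sketch only: `supported_norm_outputs` is a hypothetical helper, not part of flashinfer, and it reads the thresholds in the note literally (BF16 on SM80 and newer, NVFP4/MXFP8 on SM100 and newer).

```python
def supported_norm_outputs(major, minor=0):
    """Norm output formats available for a compute capability (hypothetical).

    Thresholds are taken literally from the note: BF16 output on SM80+,
    NVFP4 and MXFP8 output on SM100+.
    """
    sm = major * 10 + minor   # e.g. (9, 0) -> 90, (10, 3) -> 103
    formats = []
    if sm >= 80:
        formats.append("bf16")
    if sm >= 100:
        formats += ["nvfp4", "mxfp8"]
    return formats
```

In PyTorch, the (major, minor) pair can be obtained from torch.cuda.get_device_capability(), so a caller could choose use_nvfp4 or use_mxfp8 only when the corresponding format appears in the returned list.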