flashinfer.norm.fused_dit_residual_layernorm_scale_shift
- flashinfer.norm.fused_dit_residual_layernorm_scale_shift(input: Tensor, scale: Tensor, shift: Tensor, *, residual: Tensor | None = None, scale_bias: Tensor | None = None, shift_bias: Tensor | None = None, epsilon: float = 1e-06, use_nvfp4: bool = False, use_mxfp8: bool = False, global_scaling_factor: Tensor | None = None, input_global_scaling_factor: Tensor | None = None, residual_out: Tensor | None = None, norm_out: Tensor | None = None, sf_out: Tensor | None = None) → Tuple[Tensor, Tensor]
Fused residual add + LayerNorm + scale/shift for DiT self-attention.
Computes the following in a single kernel:

    residual_out = residual + input   # or just input if residual is None
    norm_out = LayerNorm(residual_out) * (1 + scale + scale_bias) + (shift + shift_bias)
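The formula above can be checked against a plain-Python reference for a single row. This is a sketch, not the kernel: it assumes LayerNorm applies no learnable weight or bias of its own (only the scale/shift modulation), uses the biased variance as `torch.nn.functional.layer_norm` does, and treats a missing `scale_bias`/`shift_bias` as zero. It computes in Python floats rather than BF16/FP32, so results will differ slightly from the kernel's. The function name `reference_row` is hypothetical.

```python
import math
from typing import List, Optional, Tuple


def reference_row(
    x: List[float],
    scale: List[float],
    shift: List[float],
    residual: Optional[List[float]] = None,
    scale_bias: Optional[List[float]] = None,
    shift_bias: Optional[List[float]] = None,
    epsilon: float = 1e-6,
) -> Tuple[List[float], List[float]]:
    """Hypothetical one-row reference for the fused op (not the FlashInfer kernel)."""
    n = len(x)
    zeros = [0.0] * n
    # Assumption: a missing bias behaves like a zero bias.
    scale_bias = scale_bias if scale_bias is not None else zeros
    shift_bias = shift_bias if shift_bias is not None else zeros

    # residual_out = residual + input (or just input when residual is None)
    if residual is None:
        residual_out = list(x)
    else:
        residual_out = [r + v for r, v in zip(residual, x)]

    # LayerNorm over the hidden dimension: biased variance, no affine params (assumption).
    mean = sum(residual_out) / n
    var = sum((v - mean) ** 2 for v in residual_out) / n
    inv_std = 1.0 / math.sqrt(var + epsilon)
    normed = [(v - mean) * inv_std for v in residual_out]

    # Modulation: * (1 + scale + scale_bias) + (shift + shift_bias)
    norm_out = [
        h * (1.0 + s + sb) + (t + tb)
        for h, s, sb, t, tb in zip(normed, scale, scale_bias, shift, shift_bias)
    ]
    return residual_out, norm_out
```

For example, `reference_row([1.0, 2.0, 3.0, 4.0], [0.0] * 4, [0.0] * 4)` returns the input unchanged as `residual_out` and a plain zero-mean LayerNorm of it as `norm_out`, since scale and shift are zero.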
- Parameters:
input (torch.Tensor) – Input tensor [batch, num_rows, 3072], BF16, contiguous.
scale (torch.Tensor) – Scale tensor, BF16. Stride 6 * hidden_dim, as produced by temb.chunk(6, dim=2).
shift (torch.Tensor) – Shift tensor, BF16. Same stride convention as scale.
residual (Optional[torch.Tensor]) – Residual tensor, BF16, contiguous. If None, no residual addition is performed.
scale_bias (Optional[torch.Tensor]) – Bias added to scale, [3072] or [1, 3072], FP32.
shift_bias (Optional[torch.Tensor]) – Bias added to shift, [3072] or [1, 3072], FP32.
epsilon (float) – LayerNorm epsilon.
use_nvfp4 (bool) – Quantize the norm output to NVFP4 (SM100+ only).
use_mxfp8 (bool) – Quantize the norm output to MXFP8 (SM100+ only).
global_scaling_factor (Optional[torch.Tensor]) – Global scale for the NVFP4 output, [1], FP32.
input_global_scaling_factor (Optional[torch.Tensor]) – Scale for pre-quantized NVFP4 inputs, [1], FP32.
residual_out (Optional[torch.Tensor]) – Pre-allocated residual output, [batch, num_rows, 3072], BF16.
norm_out (Optional[torch.Tensor]) – Pre-allocated norm output. Its shape depends on the output format.
sf_out (Optional[torch.Tensor]) – Pre-allocated scale-factor output for NVFP4/MXFP8.
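The stride convention for scale and shift follows from `torch.Tensor.chunk` returning views into the original storage rather than copies. Assuming a contiguous temb of shape [batch, num_rows, 6 * hidden_dim] (an assumed layout; the parameter description only names temb.chunk(6, dim=2)), chunk k covers columns k * hidden_dim up to (k + 1) * hidden_dim - 1 of every row, so consecutive rows of one chunk sit 6 * hidden_dim elements apart. The hypothetical helpers below work out those offsets and strides by hand.

```python
HIDDEN_DIM = 3072  # the only hidden size this kernel supports
NUM_CHUNKS = 6     # temb.chunk(6, dim=2)


def chunk_offset(b: int, row: int, col: int, k: int, num_rows: int,
                 hidden_dim: int = HIDDEN_DIM) -> int:
    """Flat offset of element [b, row, col] of chunk k inside a contiguous
    temb of shape [batch, num_rows, NUM_CHUNKS * hidden_dim] (assumed layout)."""
    row_width = NUM_CHUNKS * hidden_dim  # elements in one full temb row
    return b * num_rows * row_width + row * row_width + k * hidden_dim + col


def chunk_strides(num_rows: int, hidden_dim: int = HIDDEN_DIM) -> tuple:
    """Strides of one chunk: how far the flat offset moves per unit step
    along the batch, row and column dimensions."""
    base = chunk_offset(0, 0, 0, 0, num_rows, hidden_dim)
    return (
        chunk_offset(1, 0, 0, 0, num_rows, hidden_dim) - base,  # batch stride
        chunk_offset(0, 1, 0, 0, num_rows, hidden_dim) - base,  # row stride = 6 * hidden_dim
        chunk_offset(0, 0, 1, 0, num_rows, hidden_dim) - base,  # column stride = 1
    )
```

With num_rows = 4, `chunk_strides(4)` gives (73728, 18432, 1); in PyTorch the same numbers would come from `temb.chunk(6, dim=2)[k].stride()` on such a tensor. The row stride of 18432 = 6 * 3072 is the "6 * hidden_dim" the scale and shift parameters expect.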
- Returns:
(residual_out, norm_out).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
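Because the shape of the returned norm_out (and the need for sf_out) depends on the output format, it can help to work out per-row sizes before pre-allocating. The sketch below takes its block sizes from the format definitions, not from this API: NVFP4 packs two 4-bit values per byte with one scale factor per 16 elements, and MXFP8 stores one byte per element with one scale factor per 32 elements. The exact tensor dtypes, shapes, and any padding or swizzling FlashInfer applies to sf_out are not stated here, so treat these as minimum counts. The names `RowFootprint` and `row_footprint` are hypothetical.

```python
from dataclasses import dataclass

HIDDEN_DIM = 3072


@dataclass(frozen=True)
class RowFootprint:
    norm_out_bytes: int  # bytes of norm output data per row
    scale_factors: int   # scale factors per row (0 for BF16), before any padding


def row_footprint(fmt: str, hidden_dim: int = HIDDEN_DIM) -> RowFootprint:
    """Minimum per-row storage for each output format (hypothetical helper)."""
    if fmt == "bf16":
        # 2 bytes per element, no scale factors.
        return RowFootprint(norm_out_bytes=2 * hidden_dim, scale_factors=0)
    if fmt == "nvfp4":
        # 4-bit elements packed two per byte; one scale factor per 16 elements.
        return RowFootprint(norm_out_bytes=hidden_dim // 2, scale_factors=hidden_dim // 16)
    if fmt == "mxfp8":
        # 8-bit elements; one scale factor per 32 elements.
        return RowFootprint(norm_out_bytes=hidden_dim, scale_factors=hidden_dim // 32)
    raise ValueError(f"unknown output format: {fmt!r}")
```

For hidden_dim = 3072 this works out to 6144 bytes per row for BF16, 1536 bytes plus 192 scale factors for NVFP4, and 3072 bytes plus 96 scale factors for MXFP8.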
Note
This kernel targets WAN 2.2 5B and supports hidden_dim=3072 only. Primary targets are SM90 (Hopper) and SM100/SM103 (Blackwell). BF16 output is supported on SM80+; NVFP4/MXFP8 output requires SM100+.
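The constraints in the note can be checked before launching the kernel. The helper below is a hypothetical pre-flight check, not part of FlashInfer. It takes the compute capability as a (major, minor) tuple, which is what `torch.cuda.get_device_capability()` returns, and reads "SM100+" literally as capability 10.0 or higher; that reading is an assumption and also admits newer architectures the note does not list.

```python
from typing import Tuple

SUPPORTED_HIDDEN_DIM = 3072


def check_dit_norm_config(capability: Tuple[int, int], hidden_dim: int,
                          use_nvfp4: bool = False, use_mxfp8: bool = False) -> None:
    """Hypothetical check mirroring the documented hardware/shape constraints."""
    major, minor = capability
    sm = major * 10 + minor  # e.g. (9, 0) -> 90, (10, 3) -> 103
    if hidden_dim != SUPPORTED_HIDDEN_DIM:
        raise ValueError(f"hidden_dim must be {SUPPORTED_HIDDEN_DIM}, got {hidden_dim}")
    if sm < 80:
        raise RuntimeError(f"BF16 output requires SM80+, got SM{sm}")
    if (use_nvfp4 or use_mxfp8) and sm < 100:
        raise RuntimeError(f"NVFP4/MXFP8 output requires SM100+, got SM{sm}")
```

A call site might pass `torch.cuda.get_device_capability()` as the first argument, so a Hopper GPU (9, 0) passes for BF16 output but fails fast if `use_nvfp4=True`.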