flashinfer.norm

Kernels for normalization layers.

rmsnorm(input, weight[, eps, out, enable_pdl])

Root mean square normalization.
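As a rough illustration of the math only, the sketch below computes RMSNorm for a single row in plain Python. It is not the FlashInfer kernel: the helper name `rmsnorm_reference` is hypothetical, and the `eps=1e-6` default and the absence of any batching or dtype handling are assumptions made for the example.

```python
import math

def rmsnorm_reference(x, weight, eps=1e-6):
    """Hypothetical reference: out[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i]."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

# With unit weights and eps=0 the output row has a root mean square of 1.
out = rmsnorm_reference([3.0, 4.0], [1.0, 1.0], eps=0.0)
print(out)
```

A reference like this can be handy for checking a GPU result against a small hand-computed row.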

rmsnorm_quant(out, input, weight, scale[, ...])

Root mean square normalization + FP8 quantization.

fused_add_rmsnorm(input, residual, weight[, ...])

Fused residual add + root mean square normalization.
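The fused variant can be pictured as two steps done in one pass: add the input into the residual, then normalize the updated residual. The plain-Python sketch below assumes exactly that order and assumes both arguments are updated in place. The helper name `fused_add_rmsnorm_reference` is hypothetical, and the code illustrates the arithmetic rather than the kernel.

```python
import math

def fused_add_rmsnorm_reference(x, residual, weight, eps=1e-6):
    """Hypothetical reference, assuming in-place updates of both lists."""
    # Step 1 (assumed): residual accumulates the input.
    for i in range(len(x)):
        residual[i] += x[i]
    # Step 2 (assumed): x is overwritten with RMSNorm(residual) * weight.
    mean_sq = sum(v * v for v in residual) / len(residual)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    for i in range(len(x)):
        x[i] = residual[i] * inv_rms * weight[i]

x = [1.0, 2.0]
residual = [2.0, 2.0]
fused_add_rmsnorm_reference(x, residual, [1.0, 1.0], eps=0.0)
print(residual)  # the updated residual stream
print(x)         # the normalized output
```

Fusing the two steps avoids writing the summed residual to memory and reading it back for a separate normalization call.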

fused_add_rmsnorm_quant(out, input, ...[, ...])

Fused residual add + root mean square normalization + FP8 quantization.

gemma_rmsnorm(input, weight[, eps, out, ...])

Gemma-style root mean square normalization.
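The Gemma-style variant is commonly described as scaling by `weight + 1` instead of `weight`, so a zero-initialized weight acts as the identity scale. The sketch below assumes that convention. The helper name `gemma_rmsnorm_reference` is hypothetical and the code is a plain-Python illustration, not the kernel itself.

```python
import math

def gemma_rmsnorm_reference(x, weight, eps=1e-6):
    """Hypothetical reference: out[i] = x[i] / sqrt(mean(x^2) + eps) * (weight[i] + 1)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * (w + 1.0) for v, w in zip(x, weight)]

# A zero weight vector leaves the normalized values unscaled.
out = gemma_rmsnorm_reference([3.0, 4.0], [0.0, 0.0], eps=0.0)
print(out)
```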

gemma_fused_add_rmsnorm(input, residual, weight)

Gemma-style fused residual add + root mean square normalization.

layernorm(input, gemma, beta[, eps])

Layer normalization.
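For comparison with the RMSNorm entries, the sketch below shows the standard layer normalization formula for one row in plain Python. It assumes the second argument (spelled `gemma` in the signature above) is the per-element scale usually called gamma, and that `beta` is the per-element shift. The helper name `layernorm_reference` and the `eps` default are hypothetical choices for the example.

```python
import math

def layernorm_reference(x, gamma, beta, eps=1e-6):
    """Hypothetical reference: (x - mean) / sqrt(var + eps) * gamma + beta."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased (population) variance
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]

out = layernorm_reference([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], eps=0.0)
print(out)
```

Unlike RMSNorm, this subtracts the mean before scaling, so the output row is centered at `beta` rather than only rescaled.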

fused_rmsnorm_silu(input, weight[, eps, ...])

Fused RMSNorm + SiLU activation.
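A minimal sketch of what such a fusion computes, assuming the activation is applied to the normalized and weighted value (normalize first, then SiLU). The helper names `silu` and `rmsnorm_silu_reference` are hypothetical, and the plain-Python code only illustrates the arithmetic for one row.

```python
import math

def silu(y):
    """SiLU (swish): y * sigmoid(y)."""
    return y / (1.0 + math.exp(-y))

def rmsnorm_silu_reference(x, weight, eps=1e-6):
    """Hypothetical reference: silu(x[i] / sqrt(mean(x^2) + eps) * weight[i])."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [silu(v * inv_rms * w) for v, w in zip(x, weight)]

out = rmsnorm_silu_reference([3.0, 4.0], [1.0, 1.0], eps=0.0)
print(out)
```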

fused_qk_rmsnorm_rope(qkv, q_weight, ...[, ...])

Fused QK RMSNorm + 3D RoPE + V copy for self-attention in video-generation diffusion transformers (DiT).

fused_dit_residual_layernorm_scale_shift(...)

Fused residual + LayerNorm + scale/shift for DiT self-attention.

fused_dit_gate_residual_layernorm_scale_shift(...)

Fused gate + residual + LayerNorm + scale/shift for DiT self-attention.

fused_dit_gate_residual_layernorm_gamma_beta(...)

Fused gate + residual + LayerNorm(gamma, beta) for DiT self-attention.