flashinfer.norm

Kernels for normalization layers.

rmsnorm(input, weight[, eps, out])

Root mean square normalization.
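As we understand the documented behavior, `rmsnorm` computes, for each row of `input`, `out[i] = input[i] / sqrt(mean(input**2) + eps) * weight[i]`. Here `eps` is assumed to default to `1e-6`. The library runs this on GPU tensors, one row at a time over the hidden dimension. The sketch below is a pure-Python reference of the same math for a single row so it can run anywhere. `rmsnorm_ref` is a hypothetical helper, not part of FlashInfer.

```python
import math
from typing import List, Sequence


def rmsnorm_ref(x: Sequence[float], weight: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Hypothetical reference RMSNorm for one row (not the FlashInfer kernel)."""
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    # Mean of squares over the hidden dimension; eps is added before the sqrt.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Normalize each element, then scale it by the learned weight.
    return [v * inv_rms * w for v, w in zip(x, weight)]


# With unit weights the output has a root mean square of (almost exactly) 1.
print(rmsnorm_ref([3.0, 4.0], [1.0, 1.0], eps=0.0))
```

There is no mean subtraction and no bias, unlike LayerNorm. Only the scale is normalized.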

fused_add_rmsnorm(input, residual, weight[, eps])

Fused residual add and root mean square normalization.
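The signature has no `out` argument. Our reading of the documented semantics is that the call works in place in two steps: first `residual += input`, then `input = rmsnorm(residual) * weight`. This leaves the updated residual stream in `residual` and the normalized result in `input`. The pure-Python sketch below mimics that for one row using mutable lists. `fused_add_rmsnorm_ref` is a hypothetical helper, and the `eps` default of `1e-6` is an assumption.

```python
import math
from typing import List


def fused_add_rmsnorm_ref(
    x: List[float], residual: List[float], weight: List[float], eps: float = 1e-6
) -> None:
    """Hypothetical reference for one row; updates x and residual in place."""
    n = len(x)
    if not (len(residual) == n == len(weight)):
        raise ValueError("x, residual and weight must have the same length")
    # Step 1: residual <- residual + x (the new residual stream).
    for i in range(n):
        residual[i] += x[i]
    # Step 2: x <- RMSNorm(residual) * weight (the normalized output overwrites x).
    inv_rms = 1.0 / math.sqrt(sum(r * r for r in residual) / n + eps)
    for i in range(n):
        x[i] = residual[i] * inv_rms * weight[i]


x, res = [1.0, 2.0], [2.0, 2.0]
fused_add_rmsnorm_ref(x, res, [1.0, 1.0], eps=0.0)
print(res, x)  # residual now holds the sum; x holds its normalized form
```

The usual reason to fuse the two steps is memory traffic. A single kernel reads `input` and `residual` once and writes both results. Running an add followed by a separate normalization would store the sum and then read it back.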

gemma_rmsnorm(input, weight[, eps, out])

Gemma-style root mean square normalization.
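The Gemma variant differs from plain `rmsnorm` in one respect. The normalized value is scaled by `(1 + weight)` instead of `weight`, so a weight of zero leaves the normalized value unchanged. The sketch below is a pure-Python reference for one row under that reading. `gemma_rmsnorm_ref` is hypothetical, and the `eps` default of `1e-6` is assumed.

```python
import math
from typing import List, Sequence


def gemma_rmsnorm_ref(x: Sequence[float], weight: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Hypothetical reference Gemma-style RMSNorm for one row."""
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    # Gemma stores the scale as an offset from 1, hence (1 + w).
    return [v * inv_rms * (1.0 + w) for v, w in zip(x, weight)]


# Zero weights reduce to plain RMS normalization.
print(gemma_rmsnorm_ref([3.0, 4.0], [0.0, 0.0], eps=0.0))
```

Weights trained for this convention cannot be passed to plain `rmsnorm` unchanged. They would need `1` added first.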

gemma_fused_add_rmsnorm(input, residual, weight)

Gemma-style fused residual add and root mean square normalization.
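This combines the in-place residual update of `fused_add_rmsnorm` with the `(1 + weight)` scaling of `gemma_rmsnorm`. As we understand it, `residual += input` runs first, then `input = rmsnorm(residual) * (1 + weight)`. The listed signature shows no `eps`. The sketch below exposes one anyway, with an assumed default of `1e-6`, and treats it as an illustration only. `gemma_fused_add_rmsnorm_ref` is a hypothetical pure-Python helper for one row.

```python
import math
from typing import List


def gemma_fused_add_rmsnorm_ref(
    x: List[float], residual: List[float], weight: List[float], eps: float = 1e-6
) -> None:
    """Hypothetical reference for one row; updates x and residual in place."""
    n = len(x)
    if not (len(residual) == n == len(weight)):
        raise ValueError("x, residual and weight must have the same length")
    # Step 1: accumulate the input into the residual stream.
    for i in range(n):
        residual[i] += x[i]
    # Step 2: Gemma-style normalization of the residual, written back into x.
    inv_rms = 1.0 / math.sqrt(sum(r * r for r in residual) / n + eps)
    for i in range(n):
        x[i] = residual[i] * inv_rms * (1.0 + weight[i])


x, res = [1.0, 2.0], [2.0, 2.0]
gemma_fused_add_rmsnorm_ref(x, res, [0.0, 1.0], eps=0.0)
print(res, x)
```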