flashinfer.norm.gemma_fused_add_rmsnorm

flashinfer.norm.gemma_fused_add_rmsnorm(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-06) → None

Gemma-style fused add root mean square normalization.

Step 1: residual[i] += input[i]

Step 2: input[i] = (residual[i] / RMS(residual)) * (weight + 1)
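Written out per row, with d = hidden_size, x the input row, r the residual row and w the weight, the two steps read as follows. This is a sketch that assumes the usual RMSNorm convention: the mean of squares is taken over the hidden dimension of each row, and eps is added inside the square root. The page itself does not define RMS.

$$
r_i \leftarrow r_i + x_i, \qquad
\mathrm{RMS}(r) = \sqrt{\frac{1}{d}\sum_{j=1}^{d} r_j^{2} + \varepsilon}, \qquad
x_i \leftarrow \frac{r_i}{\mathrm{RMS}(r)}\,\bigl(w_i + 1\bigr)
$$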

Parameters:
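  • input (torch.Tensor) – Input tensor, shape (batch_size, hidden_size). Overwritten in place with the normalized output (Step 2).

  • residual (torch.Tensor) – Residual tensor, shape (batch_size, hidden_size). Updated in place to residual + input (Step 1).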

  • weight (torch.Tensor) – Weight tensor, shape (hidden_size,).

  • eps (float) – Small constant added inside the RMS computation for numerical stability. Defaults to 1e-6.
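A minimal pure-PyTorch reference sketch of the same semantics follows. It can be used to check shapes and in-place behavior on CPU. The name `gemma_fused_add_rmsnorm_ref` is hypothetical. Computing in float32 and placing eps inside the square root are assumptions and are not taken from the FlashInfer kernel.

```python
import torch


def gemma_fused_add_rmsnorm_ref(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
) -> None:
    """Hypothetical reference: mutates `residual` and `input` in place, returns None."""
    # Step 1: residual[i] += input[i], done in place.
    residual.add_(input)
    # Assumption: accumulate in float32 regardless of the storage dtype.
    r = residual.float()
    # RMS over the hidden dimension of each row, with eps inside the sqrt (assumption).
    rms = torch.sqrt(r.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Step 2: input[i] = residual[i] / RMS(residual) * (weight + 1), written back in place.
    out = r / rms * (weight.float() + 1.0)
    input.copy_(out.to(input.dtype))


# Usage on CPU: batch_size=2, hidden_size=4.
x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [0.5, -0.5, 0.5, -0.5]])
res = torch.zeros_like(x)
w = torch.zeros(4)  # weight + 1 == 1, so this reduces to plain RMS normalization
gemma_fused_add_rmsnorm_ref(x, res, w)
```

Because the function returns None, the results are read back from `x` and `res` after the call. Here `res` holds the original input and every row of `x` has a mean square of roughly 1.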