flashinfer.norm.gemma_rmsnorm

flashinfer.norm.gemma_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-06, out: torch.Tensor | None = None) → torch.Tensor

Gemma-style root mean square normalization.

out[i] = (input[i] / RMS(input)) * (weight[i] + 1)
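The formula can be traced by hand for a single row. The following is a minimal plain-Python sketch, not the library's kernel: the helper name `gemma_rmsnorm_row` is hypothetical, and it assumes that `eps` is added to the mean of squares before the square root, as is conventional for RMS normalization.

```python
import math

def gemma_rmsnorm_row(row, weight, eps=1e-6):
    """Normalize one row of length hidden_size (illustrative helper, not part of flashinfer)."""
    # Mean of squares over the hidden dimension.
    mean_sq = sum(x * x for x in row) / len(row)
    # Assumption: eps is added inside the square root.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Gemma scales by (weight + 1) rather than by weight alone,
    # so a zero weight leaves the normalized value unscaled.
    return [x * inv_rms * (w + 1.0) for x, w in zip(row, weight)]

row = [3.0, 4.0]
weight = [0.0, 1.0]
out = gemma_rmsnorm_row(row, weight)
print(out)  # approximately [0.8485, 2.2627]
```

Here RMS(input) is sqrt((9 + 16) / 2) ≈ 3.5355, so the first element is scaled by 1 and the second by 2.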

Parameters:
  • input (torch.Tensor) – Input tensor, shape (batch_size, hidden_size).

  • weight (torch.Tensor) – Weight tensor, shape (hidden_size,).

  • eps (float) – Small constant added for numerical stability. Defaults to 1e-6.

  • out (Optional[torch.Tensor]) – The output tensor. If specified, the kernel writes the result into this tensor in place.

Returns:

output – Gemma-normalized tensor, shape (batch_size, hidden_size).

Return type:

torch.Tensor
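For checking shapes and the `out` contract without a GPU, a pure-PyTorch reference can be useful. The sketch below is an assumption-laden stand-in, not the library's implementation: `gemma_rmsnorm_reference` is a hypothetical name, it accumulates in float32, it adds `eps` to the mean of squares, and it assumes the same tensor passed as `out` is what gets returned.

```python
import torch

def gemma_rmsnorm_reference(input, weight, eps=1e-6, out=None):
    """Pure-PyTorch reference (hypothetical helper, not part of flashinfer)."""
    # Compute in float32 so half-precision inputs do not lose accuracy.
    x = input.float()
    # Reciprocal RMS per row, reduced over the hidden dimension.
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    result = (x * inv_rms * (weight.float() + 1.0)).to(input.dtype)
    if out is None:
        return result
    # Mirror the documented behavior: write into the caller's buffer.
    out.copy_(result)
    return out

batch_size, hidden_size = 2, 8
x = torch.randn(batch_size, hidden_size)
w = torch.zeros(hidden_size)   # zero weight means a scale of exactly 1
buf = torch.empty_like(x)      # preallocated output buffer
y = gemma_rmsnorm_reference(x, w, out=buf)
```

With a CUDA build of flashinfer available, the same inputs moved to the GPU could be passed to `flashinfer.norm.gemma_rmsnorm` and compared against this reference with `torch.testing.assert_close`, using tolerances appropriate to the input dtype.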