flashinfer.norm.rmsnorm
- flashinfer.norm.rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-06, out: torch.Tensor | None = None) → torch.Tensor
Root mean square normalization.
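out[i] = (input[i] / RMS(input)) * weight[i]
where i indexes the hidden dimension and RMS(input) = sqrt(mean(input**2) + eps) is computed separately for each row of input.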
- Parameters:
input (torch.Tensor) – Input tensor, shape (batch_size, hidden_size).
weight (torch.Tensor) – Weight tensor, shape (hidden_size,).
eps (float) – Epsilon for numerical stability.
out (Optional[torch.Tensor]) – The output tensor. If specified, it should have the same shape as input, and the kernel writes the result into it in place; otherwise a new tensor is allocated.
- Returns:
output – Normalized tensor, shape (batch_size, hidden_size).
- Return type:
torch.Tensor
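To make the documented semantics concrete, here is a minimal pure-Python sketch that computes the same quantity on nested lists. It assumes RMS is taken per row over the hidden dimension with eps added inside the square root, as is conventional for RMSNorm. `rmsnorm_reference` is a hypothetical helper for checking results by hand; it is not part of FlashInfer and ignores the dtype, device, and precision behaviour of the CUDA kernel.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def rmsnorm_reference(
    input: Matrix,
    weight: List[float],
    eps: float = 1e-6,
    out: Optional[Matrix] = None,
) -> Matrix:
    """Hypothetical reference for RMSNorm on a (batch_size, hidden_size) input.

    Parameter names mirror flashinfer.norm.rmsnorm for readability only.
    """
    hidden_size = len(weight)
    if out is None:
        # No buffer supplied: allocate a fresh output with the same shape as input.
        out = [[0.0] * hidden_size for _ in input]
    for row, out_row in zip(input, out):
        if len(row) != hidden_size:
            raise ValueError("each row of input must have hidden_size elements")
        # Mean of squares over this row's hidden dimension; eps goes inside the sqrt.
        mean_sq = sum(x * x for x in row) / hidden_size
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        # Scale each element by 1/RMS and its per-feature weight,
        # writing into the (possibly caller-provided) buffer in place.
        for i, x in enumerate(row):
            out_row[i] = x * inv_rms * weight[i]
    return out
```

For example, `rmsnorm_reference([[3.0, 4.0]], [1.0, 1.0], eps=0.0)` divides both entries by sqrt(12.5), so the mean of the squared outputs is 1. Passing a preallocated `out` list fills that list and returns the same object, which mirrors the in-place behaviour described for the `out` parameter.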