flashinfer.norm.rmsnorm
- flashinfer.norm.rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-06)
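Root mean square normalization (RMSNorm). Each row of x is divided by its root mean square over hidden_size and then scaled elementwise by w.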
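Written out per element, with d = hidden_size, row index i and column index j, the computation is sketched below. This assumes eps is added to the mean of squares inside the square root, which is the usual RMSNorm convention.

```latex
y_{ij} = \frac{x_{ij}}{\sqrt{\frac{1}{d}\sum_{k=1}^{d} x_{ik}^{2} + \epsilon}} \, w_{j}
```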
- Parameters:
x (torch.Tensor) – Input tensor, shape (batch_size, hidden_size).
w (torch.Tensor) – Weight tensor, shape (hidden_size,).
eps (float) – Small constant added to the mean of squares before taking the square root, for numerical stability. Defaults to 1e-6.
- Returns:
y – Normalized tensor, shape (batch_size, hidden_size).
- Return type:
torch.Tensor
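A minimal pure-PyTorch reference can be handy for checking outputs on small inputs. This is a sketch under stated assumptions, not FlashInfer's implementation. `rmsnorm_reference` is a hypothetical helper that is not part of FlashInfer. It assumes that normalization runs over the last dimension, that eps goes inside the square root, and that accumulation happens in float32. The float32 choice is a precision assumption, not a claim about the kernel's internals. The optional cross-check at the end runs only when a CUDA device and the flashinfer package are both available.

```python
import torch


def rmsnorm_reference(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical pure-PyTorch RMSNorm over the last dim (not part of FlashInfer)."""
    # Assumption: accumulate in float32 so half-precision squares do not lose precision.
    x32 = x.float()
    # Mean of squares over hidden_size; eps is added inside the square root (assumed convention).
    rms_rcp = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Normalize each row, scale elementwise by the weight, and cast back to the input dtype.
    return (x32 * rms_rcp * w.float()).to(x.dtype)


if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, hidden_size = 4, 4096
    # Generate in float32, then cast, so this also works on CPU builds without half-precision randn.
    x = torch.randn(batch_size, hidden_size).half()
    w = torch.randn(hidden_size).half()
    y_ref = rmsnorm_reference(x, w)
    print(y_ref.shape, y_ref.dtype)

    # Optional cross-check against FlashInfer; needs a CUDA device and the flashinfer package.
    if torch.cuda.is_available():
        try:
            import flashinfer
        except ImportError:
            flashinfer = None
        if flashinfer is not None:
            y = flashinfer.norm.rmsnorm(x.cuda(), w.cuda(), eps=1e-6)
            # Loose tolerances: float16 outputs may differ by an ulp or so.
            torch.testing.assert_close(y.cpu(), y_ref, rtol=1e-2, atol=1e-2)
```

The float32 upcast keeps the reference stable for float16 and bfloat16 inputs at the cost of a temporary copy. That cost is acceptable for a correctness check but not for a production path.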