flashinfer.norm.rmsnorm

flashinfer.norm.rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-06)

Root mean square normalization (RMSNorm). Each row of x is divided by its root mean square and then scaled elementwise by w.
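Written out per row, and assuming the standard RMSNorm definition in which eps is added to the mean of squares before the square root (d denotes hidden_size), the operation is:

```latex
y_{b,i} = \frac{x_{b,i}}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_{b,j}^{2} + \varepsilon}} \, w_i,
\qquad b = 1,\dots,\mathrm{batch\_size},\quad i = 1,\dots,d
```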

Parameters:
  • x (torch.Tensor) – Input tensor, shape (batch_size, hidden_size).

  • w (torch.Tensor) – Weight tensor, shape (hidden_size,).

  • eps (float) – Small constant added to the mean of squares before the square root, for numerical stability. Default: 1e-06.

Returns:

y – Normalized tensor, shape (batch_size, hidden_size).

Return type:

torch.Tensor
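A plain PyTorch reference can help when checking results or running without a GPU. The sketch below assumes the standard RMSNorm definition over the last dimension, with eps added inside the square root. `rmsnorm_reference` is a hypothetical helper, not part of FlashInfer, and the float32 accumulation is a choice made here, not necessarily what the FlashInfer kernel does internally.

```python
import torch


def rmsnorm_reference(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical pure-PyTorch RMSNorm (not a FlashInfer API).

    x: (batch_size, hidden_size), w: (hidden_size,).
    Assumes eps is added to the mean of squares before the square root.
    """
    # Accumulate in float32 so half-precision inputs keep accuracy in the sum of squares.
    x32 = x.float()
    # Mean of squares per row, kept as shape (batch_size, 1) so it broadcasts over hidden_size.
    mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
    # Divide each row by sqrt(mean_sq + eps), then scale elementwise by the weight vector.
    y = x32 * torch.rsqrt(mean_sq + eps) * w.float()
    # Cast back to the input dtype; the shape stays (batch_size, hidden_size).
    return y.to(x.dtype)


if __name__ == "__main__":
    x = torch.tensor([[3.0, 4.0], [1.0, 1.0]])
    w = torch.ones(2)
    print(rmsnorm_reference(x, w, eps=0.0))
```

With CUDA tensors, this reference can be compared against `flashinfer.norm.rmsnorm(x, w, eps)` using `torch.testing.assert_close`; half-precision inputs need looser tolerances than float32.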