flashinfer.norm.fused_add_rmsnorm
- flashinfer.norm.fused_add_rmsnorm(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-06) → None
Fused residual add and root mean square normalization (RMSNorm). The function returns None: both input and residual are modified in place.
Step 1:
residual[i] += input[i]
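Step 2:
input[i] = (residual[i] / RMS(residual)) * weight[i]
RMS(residual) is computed separately for each row over its hidden_size elements, with eps added for numerical stability.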
- Parameters:
input (torch.Tensor) – Input tensor, shape (batch_size, hidden_size).
residual (torch.Tensor) – Residual tensor, shape (batch_size, hidden_size).
weight (torch.Tensor) – Weight tensor, shape (hidden_size,).
eps (float) – Epsilon for numerical stability.
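The two steps can be checked against an unfused reference. The sketch below is a hypothetical helper, fused_add_rmsnorm_reference, and is not part of flashinfer. It assumes the standard RMSNorm form RMS(x) = sqrt(mean(x²) + eps) taken per row, and it performs the normalization arithmetic in float32. It follows the documented in-place contract: residual receives the sum, input receives the normalized output, and nothing is returned.

```python
import torch


def fused_add_rmsnorm_reference(
    input: torch.Tensor,     # (batch_size, hidden_size), overwritten with the output
    residual: torch.Tensor,  # (batch_size, hidden_size), overwritten with input + residual
    weight: torch.Tensor,    # (hidden_size,)
    eps: float = 1e-6,
) -> None:
    """Hypothetical unfused sketch of fused_add_rmsnorm's documented semantics."""
    # Step 1: residual[i] += input[i], done in place on residual.
    residual.add_(input)
    # Step 2: divide each row of the updated residual by its RMS over the
    # hidden dimension. Computing in float32 is an assumption of this sketch.
    r = residual.float()
    rms = torch.sqrt(r.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = (r / rms) * weight.float()
    # Write the result back into input, mirroring the in-place contract.
    input.copy_(out.to(input.dtype))


if __name__ == "__main__":
    x = torch.randn(2, 8)
    res = torch.randn(2, 8)
    w = torch.ones(8)
    x_before, res_before = x.clone(), res.clone()  # keep copies; both get overwritten
    fused_add_rmsnorm_reference(x, res, w)
    print(torch.allclose(res, x_before + res_before))  # residual now holds the sum
```

Because both tensors are overwritten, clone input beforehand if the caller still needs its original values. The same applies to the real kernel, given its documented in-place behavior.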