flashinfer.norm.layernorm
- flashinfer.norm.layernorm(input: Tensor, gemma: Tensor, beta: Tensor, eps: float = 1e-06) → Tensor
Layer normalization.
- Parameters:
input (torch.Tensor) – Input tensor, shape (batch_size, hidden_size). Must be bfloat16.
gemma (torch.Tensor) – Scale (gamma) tensor, shape (hidden_size,). Must be float32.
beta (torch.Tensor) – Bias (beta) tensor, shape (hidden_size,). Must be float32.
eps (float) – Epsilon added for numerical stability. Defaults to 1e-06.
- Returns:
output – Layer-normalized tensor, shape (batch_size, hidden_size). Same dtype as input.
- Return type:
torch.Tensor
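The computation described above can be sketched in plain Python. This is a minimal reference for the standard layer-normalization formula, y = (x - mean) / sqrt(var + eps) * gamma + beta, applied independently to each row of length hidden_size. It is not FlashInfer's kernel. The helper name `layernorm_reference` is hypothetical, the `gamma` argument plays the role of the `gemma` parameter, and the use of the biased (population) variance is an assumption taken from the conventional definition rather than from this page.

```python
import math


def layernorm_reference(rows, gamma, beta, eps=1e-6):
    """Reference layer normalization over the last axis (hypothetical helper).

    rows:  list of rows, each a list of hidden_size floats
    gamma: per-feature scale, length hidden_size (the `gemma` argument)
    beta:  per-feature bias, length hidden_size
    """
    out = []
    for row in rows:
        n = len(row)
        mean = sum(row) / n
        # Biased variance (divide by n), assumed per the usual definition.
        var = sum((x - mean) ** 2 for x in row) / n
        inv_std = 1.0 / math.sqrt(var + eps)
        out.append(
            [(x - mean) * inv_std * g + b for x, g, b in zip(row, gamma, beta)]
        )
    return out


# Two rows (batch_size=2) with hidden_size=3, identity scale and zero bias.
x = [[1.0, 2.0, 3.0], [4.0, 4.0, 10.0]]
y = layernorm_reference(x, gamma=[1.0, 1.0, 1.0], beta=[0.0, 0.0, 0.0])
```

With identity scale and zero bias, each output row has approximately zero mean and unit variance. Because the real function takes a bfloat16 input, its results should be expected to match a float reference like this one only to within bfloat16 precision.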