flashinfer.norm.layernorm

flashinfer.norm.layernorm(input: Tensor, gemma: Tensor, beta: Tensor, eps: float = 1e-06) → Tensor

Layer normalization, applied over the hidden dimension of the input.

Parameters:

input (torch.Tensor) – Input tensor, shape (batch_size, hidden_size). Must be bfloat16.

gemma (torch.Tensor) – Gamma (scale) tensor, shape (hidden_size,). Must be float32.

beta (torch.Tensor) – Beta (shift) tensor, shape (hidden_size,). Must be float32.

eps (float) – Epsilon for numerical stability.
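The docstring does not spell out the formula; assuming the standard layer normalization definition, the output is

$$ \text{output} = \frac{\text{input} - \mathrm{E}[\text{input}]}{\sqrt{\mathrm{Var}[\text{input}] + \text{eps}}} \cdot \text{gemma} + \text{beta} $$

where the mean and variance are computed over the hidden dimension.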

Returns:

output – Layer-normalized tensor, shape (batch_size, hidden_size), with the same dtype as input.

Return type:

torch.Tensor
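
A minimal usage sketch, assuming flashinfer is installed and a CUDA-capable GPU is available (placing the tensors on CUDA is an assumption, since flashinfer kernels run on the GPU):

    import torch
    import flashinfer

    batch_size, hidden_size = 8, 4096

    # Input must be bfloat16; gemma/beta must be float32, per the
    # dtype requirements above.
    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
    gemma = torch.ones(hidden_size, dtype=torch.float32, device="cuda")
    beta = torch.zeros(hidden_size, dtype=torch.float32, device="cuda")

    y = flashinfer.norm.layernorm(x, gemma, beta, eps=1e-6)

    # Output has the same shape and dtype as the input.
    assert y.shape == (batch_size, hidden_size)
    assert y.dtype == torch.bfloat16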