flashinfer.norm.rmsnorm_quant
flashinfer.norm.rmsnorm_quant(out: Tensor, input: Tensor, weight: Tensor, scale: Tensor, eps: float = 1e-06, enable_pdl: bool | None = None) -> None
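Root mean square normalization followed by FP8 quantization.
out[i] = ((input[i] / RMS(input)) * weight[i]).to(fp8)
where RMS(input) is the root mean square of each row of input, taken over hidden_size, with eps added for numerical stability.
Parameters: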
out (torch.Tensor) – Output tensor with the same shape as input. The result is quantized to this tensor's dtype (an FP8 dtype such as torch.float8_e4m3fn).
input (torch.Tensor) – Input tensor, 2D shape (batch_size, hidden_size).
weight (torch.Tensor) – Weight tensor, shape (hidden_size,).
scale (torch.Tensor) – Scale factor for quantization, shape (1,).
eps (float) – Epsilon for numerical stability. Default: 1e-6.
enable_pdl (bool, optional) – Whether to enable programmatic dependent launch (PDL).
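To make the formula concrete, here is a minimal pure-Python sketch of the per-row computation. It is not FlashInfer's implementation. The function name `rmsnorm_quant_reference` is hypothetical, and the sketch assumes that `scale` divides the normalized value before the FP8 cast (the formula does not show where `scale` enters). It also models only saturation to the FP8 E4M3 range of ±448 and ignores the final rounding to FP8 precision.

```python
import math

# Largest finite magnitude of the FP8 E4M3 (float8_e4m3fn) format.
FP8_E4M3_MAX = 448.0


def rmsnorm_quant_reference(x, weight, scale, eps=1e-6):
    """Hypothetical reference for rmsnorm_quant on plain Python lists.

    x: list of rows, shape (batch_size, hidden_size)
    weight: list of length hidden_size
    scale: float quantization scale

    Assumption: the normalized value is divided by `scale` before the
    FP8 cast. Only the clamp to the E4M3 range is modeled, not the
    rounding to FP8 precision.
    """
    out = []
    for row in x:
        # Root mean square over the hidden dimension, with eps for stability.
        rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
        q_row = []
        for v, w in zip(row, weight):
            y = (v / rms) * w / scale
            # Saturate to the representable FP8 E4M3 range.
            q_row.append(max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, y)))
        out.append(q_row)
    return out


# Usage: one row [3, 4] has RMS sqrt(12.5) when eps is 0.
print(rmsnorm_quant_reference([[3.0, 4.0]], [1.0, 1.0], 1.0, eps=0.0))
```

Values that exceed the FP8 range after normalization and scaling saturate to ±448 rather than overflowing, which is why a well-chosen `scale` matters for preserving precision.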