flashinfer.norm.fused_add_rmsnorm_quant

flashinfer.norm.fused_add_rmsnorm_quant(out: Tensor, input: Tensor, residual: Tensor, weight: Tensor, scale: Tensor, eps: float = 1e-06, enable_pdl: bool | None = None) → None

Fused residual add + root mean square (RMS) normalization + FP8 quantization.

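Step 1: residual[i] += input[i] (in place)

Step 2: out[i] = ((residual[i] / RMS(residual)) * weight[i]).to(fp8), where RMS(residual) = sqrt(mean(residual**2) + eps) is computed over the hidden dimension of each row.

The function returns None: residual is modified in place and the quantized result is written to out.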

Parameters:
  • out (torch.Tensor) – Output tensor that receives the result. The result is quantized to this tensor's dtype (e.g. an FP8 dtype such as torch.float8_e4m3fn).

  • input (torch.Tensor) – Input tensor, shape (batch_size, hidden_size).

  • residual (torch.Tensor) – Residual tensor, shape (batch_size, hidden_size).

  • weight (torch.Tensor) – Weight tensor, shape (hidden_size,).

  • scale (torch.Tensor) – Scale factor for quantization, shape (1,).

  • eps (float) – Epsilon for numerical stability.

  • enable_pdl (bool | None) – Whether to enable programmatic dependent launch (PDL).