flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant
- flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant(input: Tensor, residual_in: Tensor, gamma: Tensor, workspace: MNNVLAllReduceFusionWorkspace, epsilon: float | None = None, output: Tensor | None = None, residual_out: Tensor | None = None, quant_out: Tensor | None = None, scale_out: Tensor | None = None, output_scale: Tensor | float | None = None, layout_code: int = 0, quant_type: int = 0, launch_with_pdl: bool = False, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, weight_bias: float = 0.0) → Tuple[Tensor, Tensor | None, Tensor, Tensor | None]
Perform MNNVL AllReduce + Residual + RMSNorm + FP8/NVFP4 quantization.
Quantization is applied after RMSNorm.
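The unquantized part of this pipeline can be written out as an unfused reference. The pure-Python sketch below works on a single token row. It assumes three things: residual_out holds the pre-normalization sum (the AllReduce result plus residual_in), epsilon is added inside the square root, and weight_bias is added elementwise to gamma. The function name reference_allreduce_add_rmsnorm is hypothetical and is not part of flashinfer. FP8 or NVFP4 quantization would then be applied to the returned normalized values.

```python
import math
from typing import List, Sequence, Tuple


def reference_allreduce_add_rmsnorm(
    rank_inputs: Sequence[Sequence[float]],  # the same token row as seen on each rank
    residual_in: Sequence[float],
    gamma: Sequence[float],
    epsilon: float,
    weight_bias: float = 0.0,
) -> Tuple[List[float], List[float]]:
    """Hypothetical unfused reference; returns (residual_out, output) for one row."""
    hidden_dim = len(residual_in)
    # AllReduce: elementwise sum of this row across all ranks.
    reduced = [sum(values) for values in zip(*rank_inputs)]
    # Residual add. Assumption: residual_out stores this pre-norm sum.
    residual_out = [a + r for a, r in zip(reduced, residual_in)]
    # RMSNorm over the hidden dimension, with epsilon inside the square root.
    mean_sq = sum(v * v for v in residual_out) / hidden_dim
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    # weight_bias = 0.0 gives standard RMSNorm; 1.0 gives the (1 + gamma) variant.
    output = [v * inv_rms * (g + weight_bias) for v, g in zip(residual_out, gamma)]
    return residual_out, output
```

The fused kernel performs these steps in a single launch across the tensor-parallel group. The sketch pins down only the arithmetic. It does not model the kernel's numerics, such as accumulation precision.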
output is optional. Pass it only when the normalized, non-quantized tensor is also needed. The quantized result is always returned as quant_out.
- Parameters:
input – Input tensor with shape [num_tokens, hidden_dim].
residual_in – Residual tensor with the same shape as input.
gamma – RMSNorm gamma tensor with shape [hidden_dim].
workspace – MNNVL workspace for the tensor-parallel group.
epsilon – RMSNorm epsilon. Defaults to torch.finfo(input.dtype).eps.
output – Optional normalized (non-quantized) output tensor with shape [num_tokens, hidden_dim].
residual_out – Optional residual output tensor with shape [num_tokens, hidden_dim].
quant_out – Optional quantized output. For FP8, the shape must be [num_tokens, hidden_dim] and the dtype must be torch.float8_e4m3fn. For NVFP4, the shape must be [num_tokens, hidden_dim // 2] and the dtype must be torch.uint8 or torch.float4_e2m1fn_x2.
scale_out – Optional NVFP4 scale output. For the LINEAR layout, the shape is [num_tokens, hidden_dim // 16]. For SWIZZLED_128x4, provide a 1-D tensor large enough for the padded swizzled scale layout. FP8 ignores this argument.
output_scale – Scalar float or float32 tensor used as the quantization output scale. Defaults to 1.0.
layout_code – NVFP4 scale layout. MNNVL supports SWIZZLED_128x4 and LINEAR; SWIZZLED_8x4 is not supported.
quant_type – MNNVLQuantType.FP8 or MNNVLQuantType.NVFP4.
launch_with_pdl – Whether to launch the kernel with Programmatic Dependent Launch (PDL).
strategy – MNNVL execution strategy. AUTO uses internal heuristics.
weight_bias – Offset added to gamma before it scales the normalized values: 0.0 for standard RMSNorm, 1.0 for Gemma / Qwen3.5 RMSNorm.
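The shape rules for quant_out and scale_out can be collected into a small helper. This is a pure-Python sketch under stated assumptions. The string labels "fp8", "nvfp4", "linear", "swizzled_128x4", and "swizzled_8x4" and the function name expected_buffer_shapes are hypothetical stand-ins for the integer quant_type and layout_code values. The SWIZZLED_128x4 size formula pads rows to a multiple of 128 and scale columns to a multiple of 4. That formula is an assumption about the padded swizzled layout and is not stated on this page.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]

NVFP4_BLOCK_SIZE = 16  # one scale per 16 elements, matching hidden_dim // 16


def _round_up(value: int, multiple: int) -> int:
    return (value + multiple - 1) // multiple * multiple


def expected_buffer_shapes(
    num_tokens: int, hidden_dim: int, quant: str, layout: str = "linear"
) -> Tuple[Shape, Optional[Shape]]:
    """Hypothetical helper: return (quant_out shape, scale_out shape).

    The scale_out shape is None for FP8. `quant` and `layout` are illustrative
    string labels, not the real quant_type / layout_code integers.
    """
    if quant == "fp8":
        # FP8: one float8_e4m3fn element per value; scale_out is not used.
        return (num_tokens, hidden_dim), None
    if quant != "nvfp4":
        raise ValueError(f"unknown quant type: {quant!r}")
    if hidden_dim % NVFP4_BLOCK_SIZE != 0:
        # This sketch's own check, so that hidden_dim // 16 is exact.
        raise ValueError("this sketch requires hidden_dim to be a multiple of 16")
    # NVFP4: two 4-bit values are packed into each byte.
    quant_shape = (num_tokens, hidden_dim // 2)
    scale_cols = hidden_dim // NVFP4_BLOCK_SIZE
    if layout == "linear":
        return quant_shape, (num_tokens, scale_cols)
    if layout == "swizzled_128x4":
        # Assumption: rows are padded to a multiple of 128 and scale columns
        # to a multiple of 4, stored as one flat 1-D buffer.
        padded = _round_up(num_tokens, 128) * _round_up(scale_cols, 4)
        return quant_shape, (padded,)
    if layout == "swizzled_8x4":
        raise ValueError("SWIZZLED_8x4 is not supported by MNNVL")
    raise ValueError(f"unknown layout: {layout!r}")
```

For example, 4 tokens with hidden_dim 4096 in NVFP4 with the LINEAR layout gives quant_out as (4, 2048) and scale_out as (4, 256). In real code these shapes would be used to preallocate the tensors with the quant_out dtypes listed above.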
- Returns:
A tuple (quant_out, scale_out, residual_out, output). scale_out is None for FP8, and output is None unless requested.