flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant

flashinfer.comm.trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant(input: Tensor, residual_in: Tensor, gamma: Tensor, workspace: MNNVLAllReduceFusionWorkspace, epsilon: float | None = None, output: Tensor | None = None, residual_out: Tensor | None = None, quant_out: Tensor | None = None, scale_out: Tensor | None = None, output_scale: Tensor | float | None = None, layout_code: int = 0, quant_type: int = 0, launch_with_pdl: bool = False, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, weight_bias: float = 0.0) → Tuple[Tensor, Tensor | None, Tensor, Tensor | None]

Perform MNNVL AllReduce + Residual + RMSNorm + FP8/NVFP4 quantization.

Quantization is applied after RMSNorm. output is optional; pass it only when the normalized non-quantized tensor is also needed. The quantized result is always returned as quant_out.
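The unquantized part of the fused operation can be sketched in pure Python to make the data flow concrete. This is a minimal reference under stated assumptions, not the library's kernel: the function name reference_allreduce_add_rmsnorm is hypothetical, the AllReduce is modelled as a plain sum over per-rank inputs, residual_out is assumed to be the pre-normalization sum, and weight_bias is applied as (gamma + weight_bias) following the parameter description below. The quantization step is left out because the exact scaling convention is not stated here.

```python
import math


def reference_allreduce_add_rmsnorm(per_rank_inputs, residual_in, gamma,
                                    epsilon=1e-6, weight_bias=0.0):
    """Hypothetical CPU reference for AllReduce + residual add + RMSNorm.

    per_rank_inputs: list (one entry per rank) of [num_tokens][hidden_dim] lists.
    residual_in:     [num_tokens][hidden_dim] nested list.
    gamma:           [hidden_dim] list.
    epsilon:         arbitrary default for this sketch; the documented default
                     is torch.finfo(input.dtype).eps.
    """
    hidden_dim = len(gamma)
    residual_out, output = [], []
    for t in range(len(residual_in)):
        # AllReduce (modelled as a sum across ranks) followed by the residual add.
        row = [
            sum(rank[t][h] for rank in per_rank_inputs) + residual_in[t][h]
            for h in range(hidden_dim)
        ]
        # RMSNorm over the hidden dimension.
        mean_sq = sum(v * v for v in row) / hidden_dim
        inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
        # Assumption: weight_bias is added to gamma before scaling.
        output.append([v * inv_rms * (gamma[h] + weight_bias)
                       for h, v in enumerate(row)])
        residual_out.append(row)
    return residual_out, output


# Two ranks, one token, hidden_dim = 2.
res, out = reference_allreduce_add_rmsnorm(
    per_rank_inputs=[[[1.0, 2.0]], [[2.0, 2.0]]],
    residual_in=[[0.0, 0.0]],
    gamma=[1.0, 1.0],
    epsilon=0.0,
)
```

In this sketch, gamma = 0 with weight_bias = 1.0 gives the same result as gamma = 1 with weight_bias = 0.0, which is how the Gemma-style variant described under weight_bias differs from standard RMSNorm.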

Parameters:
  • input – Input tensor with shape [num_tokens, hidden_dim].

  • residual_in – Residual tensor with the same shape as input.

  • gamma – RMSNorm gamma tensor with shape [hidden_dim].

  • workspace – MNNVL workspace for the tensor-parallel group.

  • epsilon – RMSNorm epsilon. Defaults to torch.finfo(input.dtype).eps.

  • output – Optional normalized output tensor with shape [num_tokens, hidden_dim].

  • residual_out – Optional residual output tensor with shape [num_tokens, hidden_dim].

  • quant_out – Optional quantized output. For FP8, shape must be [num_tokens, hidden_dim] and dtype torch.float8_e4m3fn. For NVFP4, shape must be [num_tokens, hidden_dim // 2] and dtype torch.uint8 or torch.float4_e2m1fn_x2.

  • scale_out – Optional NVFP4 scale output. For LINEAR layout, shape is [num_tokens, hidden_dim // 16]. For SWIZZLED_128x4, provide a 1-D tensor large enough for the padded swizzled scale layout. FP8 ignores this argument.

  • output_scale – Scalar float or float32 tensor used as the quantization output scale. Defaults to 1.0.

  • layout_code – NVFP4 scale layout. MNNVL supports SWIZZLED_128x4 and LINEAR; SWIZZLED_8x4 is not supported.

  • quant_type – MNNVLQuantType.FP8 or MNNVLQuantType.NVFP4.

  • launch_with_pdl – Whether to launch the kernel with Programmatic Dependent Launch (PDL).

  • strategy – MNNVL execution strategy. AUTO uses internal heuristics.

  • weight_bias – Bias added to gamma before scaling. 0.0 for standard RMSNorm; 1.0 for Gemma / Qwen3.5 RMSNorm.
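
The shape rules for quant_out and scale_out can be collected into a small helper when pre-allocating buffers. The sketch below is illustrative only: the function name mnnvl_quant_buffer_shapes and the string tags "fp8", "nvfp4", "linear", and "swizzled_128x4" are hypothetical stand-ins for the MNNVLQuantType and layout-code values. The swizzled size is an assumption: it pads token rows to a multiple of 128 and scale columns to a multiple of 4, as the layout name suggests, and should be checked against the library before use.

```python
def mnnvl_quant_buffer_shapes(num_tokens, hidden_dim,
                              quant="fp8", scale_layout="linear"):
    """Return (quant_out_shape, scale_out_shape) per the parameter docs.

    Hypothetical helper; string tags stand in for the real enum values.
    """
    if quant == "fp8":
        # FP8: one byte per element; scale_out is ignored.
        return (num_tokens, hidden_dim), None
    if quant != "nvfp4":
        raise ValueError(f"unknown quant type: {quant!r}")
    if hidden_dim % 16 != 0:
        # Assumption: one scale per 16 elements requires divisibility by 16.
        raise ValueError("hidden_dim must be a multiple of 16 for NVFP4")

    # NVFP4: two 4-bit values packed per byte.
    quant_shape = (num_tokens, hidden_dim // 2)
    scale_cols = hidden_dim // 16

    if scale_layout == "linear":
        return quant_shape, (num_tokens, scale_cols)
    if scale_layout == "swizzled_128x4":
        # ASSUMPTION: pad rows to 128 and scale columns to 4, then flatten to 1-D.
        padded_rows = -(-num_tokens // 128) * 128
        padded_cols = -(-scale_cols // 4) * 4
        return quant_shape, (padded_rows * padded_cols,)
    raise ValueError(f"unsupported scale layout: {scale_layout!r}")


fp8_shapes = mnnvl_quant_buffer_shapes(4, 4096, quant="fp8")
linear_shapes = mnnvl_quant_buffer_shapes(4, 4096, quant="nvfp4")
swizzled_shapes = mnnvl_quant_buffer_shapes(
    4, 4096, quant="nvfp4", scale_layout="swizzled_128x4")
```

The returned tuples could then be passed to an allocator such as torch.empty with the dtypes listed under quant_out; the scale_out dtype is not stated in this section, so it is not assumed here.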

Returns:

A tuple (quant_out, scale_out, residual_out, output). scale_out is None for FP8, and output is None unless requested.