flashinfer.fused_moe.interleave_moe_scales_for_sm90_mixed_gemm

flashinfer.fused_moe.interleave_moe_scales_for_sm90_mixed_gemm(scales: Tensor, group_size: int = 32) → Tensor

Interleave MXFP4 block scales for the SM90 mixed-input MoE GEMM.

The kernel expects scales in layout (num_experts, K // (group_size * factor), rows * factor), where factor = 128 // group_size (4 for the default group_size of 32), rather than the natural (num_experts, rows, K // group_size) produced by the MXFP4 quantizer. This helper performs the reshape + permute equivalent to TensorRT-LLM’s WFP4A16FusedMoEMethod.load_quant_scales (PR #12451), using the fixed MXFP4 interleave factor of 128 // group_size.

Parameters:
  • scales – uint8 tensor of E8M0 block scales with shape [num_experts, rows, K // group_size].

  • group_size – MXFP4 quantization group size (default 32).

Returns:

Contiguous uint8 tensor with shape [num_experts, K // (group_size * factor), rows * factor] where factor = 128 // group_size.

Return type:

torch.Tensor
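
The reshape + permute described above can be sketched as a plain-PyTorch reference. This is a minimal illustration of the layout transform, not the library implementation; the helper name and the exact element ordering (splitting the K-group dimension by the interleave factor and moving the factor next to the row dimension) are assumptions, while the input and output shapes follow the documentation.

```python
import torch


def interleave_scales_sketch(scales: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Hypothetical reference for the SM90 MXFP4 scale interleave.

    Input:  [num_experts, rows, K // group_size] uint8 E8M0 scales.
    Output: [num_experts, K // (group_size * factor), rows * factor],
            with factor = 128 // group_size (assumed element ordering).
    """
    factor = 128 // group_size  # 4 for the default group_size of 32
    num_experts, rows, k_groups = scales.shape
    assert k_groups % factor == 0, "K // group_size must be divisible by the interleave factor"
    # Split the K-group dimension into (k_groups // factor, factor) ...
    s = scales.reshape(num_experts, rows, k_groups // factor, factor)
    # ... move the factor-sized chunk next to the row dimension ...
    s = s.permute(0, 2, 1, 3)
    # ... and fold it into the row dimension, yielding the interleaved layout.
    return s.reshape(num_experts, k_groups // factor, rows * factor).contiguous()


# Example: 2 experts, 8 rows, K = 256, group_size = 32 -> 8 scale groups per row.
scales = torch.arange(2 * 8 * 8, dtype=torch.uint8).reshape(2, 8, 8)
out = interleave_scales_sketch(scales)
print(out.shape)  # torch.Size([2, 2, 32])
```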