flashinfer.fused_moe.interleave_moe_scales_for_sm90_mixed_gemm
- flashinfer.fused_moe.interleave_moe_scales_for_sm90_mixed_gemm(scales: Tensor, group_size: int = 32) → Tensor
Interleave MXFP4 block scales for the SM90 mixed-input MoE GEMM.
The kernel expects scales in the layout `(num_experts, K // (group_size * 4), rows * 4)` rather than the natural `(num_experts, rows, K // group_size)` produced by the MXFP4 quantizer. This helper performs the reshape + permute equivalent to TensorRT-LLM's `WFP4A16FusedMoEMethod.load_quant_scales` (PR #12451), with the fixed interleave factor of `128 // group_size` used for MXFP4.
- Parameters:
scales – `[num_experts, rows, K // group_size]` uint8 tensor of E8M0 block scales.
group_size – MXFP4 quantization group size (default 32).
- Returns:
Contiguous uint8 tensor with shape `[num_experts, K // (group_size * factor), rows * factor]` where `factor = 128 // group_size`.
- Return type:
torch.Tensor
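The reshape + permute described above can be sketched as follows. This is a minimal illustration, not the library's implementation: it assumes the interleave packs `factor` consecutive scale columns of each row together, which matches the documented output shape but the exact within-chunk ordering expected by the SM90 kernel is an assumption. NumPy is used here for portability; the real helper operates on `torch.Tensor`.

```python
import numpy as np

def interleave_scales_sketch(scales: np.ndarray, group_size: int = 32) -> np.ndarray:
    """Hypothetical re-implementation of the scale interleave.

    Takes the natural [num_experts, rows, K // group_size] layout and
    returns [num_experts, K // (group_size * factor), rows * factor].
    """
    num_experts, rows, k_groups = scales.shape
    factor = 128 // group_size  # fixed MXFP4 interleave factor (4 for group_size=32)
    assert k_groups % factor == 0, "K // group_size must be divisible by the factor"
    return np.ascontiguousarray(
        scales.reshape(num_experts, rows, k_groups // factor, factor)  # split K-groups into chunks
              .transpose(0, 2, 1, 3)                                  # bring the chunk axis forward
              .reshape(num_experts, k_groups // factor, rows * factor)  # fuse rows with chunk entries
    )

# Example: 2 experts, 8 rows, K // group_size = 16 scale columns.
scales = np.arange(2 * 8 * 16, dtype=np.uint8).reshape(2, 8, 16)
out = interleave_scales_sketch(scales)
print(out.shape)  # (2, 4, 32)
```

With the default `group_size=32` the factor is 4, so 16 scale columns collapse into 4 interleaved chunks of `rows * 4` entries each, matching the shape stated under Returns.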