flashinfer.fused_moe.interleave_moe_weights_for_sm90_mixed_gemm
- flashinfer.fused_moe.interleave_moe_weights_for_sm90_mixed_gemm(weight: Tensor, quant_type: str = 'fp4') → Tensor
Interleave 4-bit packed MoE weights for the SM90 mixed-input GEMM.
The SM90 mixed-dtype MoE GEMM (used by
cutlass_fused_moe with use_w4_group_scaling=True) expects weights in a specific interleaved layout; without preprocessing, the LUT-based FP4→BF16 conversion reads bytes from the wrong positions and the output diverges from a dequantized reference for any K > 128. TensorRT-LLM's W4A16 MoE runs the equivalent preprocessing at weight-load time (see interleave_4bit_weights_for_Hopper_mixed_gemm in TRT-LLM PR #12451).
- Parameters:
weight –
[num_experts, n, k // 2] uint8 CUDA tensor (4-bit values packed two per byte).

quant_type –
"fp4" for MXFP4 (the W4A16 path) or "int4" for INT4 (the W4A8 path).
- Returns:
A new uint8 tensor with the same shape as
weight, holding the interleaved layout. Feed this directly as fc1_expert_weights / fc2_expert_weights to cutlass_fused_moe().
- Return type:
torch.Tensor
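A minimal sketch of the expected input format. The packing demo below is illustrative only: the shapes are hypothetical, the nibble order (low nibble = even index, high nibble = odd index) is an assumption for demonstration, and in practice the packed codes come from an FP4/INT4 quantizer rather than random data. The commented-out lines show how the resulting tensor would then be fed to the documented function on a CUDA setup.

```python
import numpy as np

# Hypothetical shapes for illustration: 8 experts, n=256, k=512.
num_experts, n, k = 8, 256, 512

# Random 4-bit codes in [0, 15]; in a real pipeline these come from
# FP4/INT4 quantization of the expert weights.
codes = np.random.randint(0, 16, size=(num_experts, n, k), dtype=np.uint8)

# Pack two 4-bit values per byte to get the [num_experts, n, k // 2]
# uint8 layout the function expects. Nibble order here is an assumption:
# low nibble holds the even-index value, high nibble the odd-index value.
packed = (codes[..., 0::2] | (codes[..., 1::2] << 4)).astype(np.uint8)
assert packed.shape == (num_experts, n, k // 2)

# On a CUDA-enabled setup, the packed tensor would then be interleaved
# before being handed to cutlass_fused_moe():
#   import torch
#   from flashinfer.fused_moe import interleave_moe_weights_for_sm90_mixed_gemm
#   w = torch.from_numpy(packed).cuda()
#   w_interleaved = interleave_moe_weights_for_sm90_mixed_gemm(w, quant_type="fp4")
```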