flashinfer.fused_moe.interleave_moe_weights_for_sm90_mixed_gemm

flashinfer.fused_moe.interleave_moe_weights_for_sm90_mixed_gemm(weight: Tensor, quant_type: str = 'fp4') → Tensor

Interleave 4-bit packed MoE weights for the SM90 mixed-input GEMM.

The SM90 mixed-dtype MoE GEMM (used by cutlass_fused_moe with use_w4_group_scaling=True) expects weights in a specific interleaved layout; without preprocessing, the LUT-based FP4→BF16 conversion reads bytes from the wrong positions and the output diverges from a dequantized reference for any K > 128. TensorRT-LLM’s W4A16 MoE runs the equivalent preprocessing at weight-load time (see interleave_4bit_weights_for_Hopper_mixed_gemm in TRT-LLM PR #12451).

Parameters:
  • weight – [num_experts, n, k // 2] uint8 CUDA tensor (4-bit values packed two per byte along the last dimension).

  • quant_type – "fp4" for MXFP4 (the W4A16 path) or "int4" for INT4 (the W4A8 path).
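The expected input convention (two 4-bit values per byte, so the last dimension is k // 2) can be illustrated with plain NumPy. This is a sketch of the packing only, not of the SM90 interleaving itself; the nibble order shown (even column in the low nibble) is an assumption for illustration:

```python
import numpy as np

np.random.seed(0)
num_experts, n, k = 2, 4, 8  # toy sizes; real kernels use much larger K

# One 4-bit value per (expert, row, column) weight element.
nibbles = np.random.randint(0, 16, size=(num_experts, n, k), dtype=np.uint8)

# Pack two 4-bit values per byte -> shape [num_experts, n, k // 2].
# Assumed order: even column in the low nibble, odd column in the high nibble.
packed = (nibbles[..., 0::2] | (nibbles[..., 1::2] << 4)).astype(np.uint8)
assert packed.shape == (num_experts, n, k // 2)

# Unpack to verify the round trip.
unpacked = np.empty_like(nibbles)
unpacked[..., 0::2] = packed & 0x0F
unpacked[..., 1::2] = packed >> 4
assert np.array_equal(unpacked, nibbles)
```

A tensor packed this way (as a uint8 CUDA tensor) is what this function reorders into the interleaved layout the SM90 kernel reads.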

Returns:

A new uint8 tensor with the same shape as weight, holding the interleaved layout. Pass it directly as fc1_expert_weights / fc2_expert_weights to cutlass_fused_moe().

Return type:

torch.Tensor