flashinfer.fused_moe.trtllm_fp8_block_scale_moe
- flashinfer.fused_moe.trtllm_fp8_block_scale_moe(routing_logits: Tensor, routing_bias: Tensor | None, hidden_states: Tensor, hidden_states_scale: Tensor, gemm1_weights: Tensor, gemm1_weights_scale: Tensor, gemm2_weights: Tensor, gemm2_weights_scale: Tensor, num_experts: int, top_k: int, n_group: int | None, topk_group: int | None, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float | None, routing_method_type: int = 0, use_shuffled_weight: bool = False, weight_layout: int = 0, enable_pdl: bool | None = None, tune_max_num_tokens: int = 8192, fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8) → Tensor
FP8 block-scale Mixture-of-Experts (MoE) operation.
- Parameters:
routing_logits – [seq_len, num_experts] tensor of routing logits
routing_bias – [num_experts] tensor of routing bias
hidden_states – [seq_len, hidden_size] tensor of input hidden states
hidden_states_scale – [hidden_size//128, seq_len] tensor of hidden states block scales
gemm1_weights – tensor of first-layer weights:
  - [num_experts, 2*intermediate_size, hidden_size] if weight_layout == WeightLayout.MajorK
  - [num_experts, 2*intermediate_size // 128, hidden_size, 128] if weight_layout == WeightLayout.BlockMajorK
gemm1_weights_scale – [num_experts, 2*intermediate_size//B, hidden_size//B] tensor of first-layer block scales, where B is 32 for MXFP8 and 128 for DeepSeekFp8
gemm2_weights – tensor of second-layer weights:
  - [num_experts, hidden_size, intermediate_size] if weight_layout == WeightLayout.MajorK
  - [num_experts, hidden_size // 128, intermediate_size, 128] if weight_layout == WeightLayout.BlockMajorK
gemm2_weights_scale – [num_experts, hidden_size//B, intermediate_size//B] tensor of second-layer block scales, where B is 32 for MXFP8 and 128 for DeepSeekFp8
num_experts – Total number of experts
top_k – Number of experts to route to per token
n_group – Number of expert groups
topk_group – Number of groups to consider for top-k routing
intermediate_size – Size of intermediate layer
local_expert_offset – Offset of local experts in global expert space
local_num_experts – Number of experts handled by this device
routed_scaling_factor – Scaling factor for routing
routing_method_type – Type of routing method to use (default: 0)
use_shuffled_weight – Whether the gemm weights have been pre-shuffled for the kernel (default: False)
weight_layout – Weight layout format (default: WeightLayout.MajorK). Supported layouts:
  - 0: MajorK - K-major layout [Mn, K]
  - 2: BlockMajorK - blocked along the K dimension [K/blockK, Mn, blockK]
enable_pdl – Whether to enable Programmatic Dependent Launch (PDL). When None, PDL is enabled automatically on SM90 and later GPUs.
tune_max_num_tokens – Maximum number of tokens used for tuning (default: 8192)
fp8_quantization_type – Type of FP8 quantization to use (default: DeepSeekFp8)
- Returns:
Output tensor of shape [seq_len, hidden_size]
- Return type:
torch.Tensor
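The shape constraints above are easy to get wrong, so the following is a minimal, framework-free sketch that computes the expected shape of each tensor argument for the DeepSeekFp8 path (128-element scale blocks) with the default MajorK weight layout. The helper name and the example sizes are hypothetical, chosen only for illustration; the actual call still requires CUDA tensors built with PyTorch.

```python
# Hypothetical helper: derive expected argument shapes for
# trtllm_fp8_block_scale_moe from the parameter descriptions above.
# Assumes DeepSeekFp8 quantization (scale block size 128) and the
# default MajorK weight layout.
def expected_shapes(seq_len, hidden_size, intermediate_size, num_experts, block=128):
    """Return a dict mapping argument name -> expected tensor shape."""
    return {
        "routing_logits": (seq_len, num_experts),
        "hidden_states": (seq_len, hidden_size),
        # Note the transposed layout: block index first, then seq_len.
        "hidden_states_scale": (hidden_size // block, seq_len),
        # First layer is a fused gate/up projection, hence 2*intermediate_size.
        "gemm1_weights": (num_experts, 2 * intermediate_size, hidden_size),
        "gemm1_weights_scale": (num_experts, 2 * intermediate_size // block, hidden_size // block),
        "gemm2_weights": (num_experts, hidden_size, intermediate_size),
        "gemm2_weights_scale": (num_experts, hidden_size // block, intermediate_size // block),
        "output": (seq_len, hidden_size),
    }

# Illustrative sizes only (not taken from any specific model config).
shapes = expected_shapes(seq_len=4, hidden_size=7168, intermediate_size=2048, num_experts=256)
print(shapes["gemm1_weights_scale"])  # (256, 32, 56)
```

A quick check like this before allocating the real FP8 tensors catches the two most common mistakes: forgetting the 2x factor on the first-layer weights and the transposed [hidden_size//128, seq_len] layout of hidden_states_scale.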