flashinfer.fused_moe.trtllm_fp8_block_scale_moe
- flashinfer.fused_moe.trtllm_fp8_block_scale_moe(routing_logits: torch.Tensor, routing_bias: torch.Tensor | None, hidden_states: torch.Tensor, hidden_states_scale: torch.Tensor, gemm1_weights: torch.Tensor, gemm1_weights_scale: torch.Tensor, gemm2_weights: torch.Tensor, gemm2_weights_scale: torch.Tensor, num_experts: int, top_k: int, n_group: int, topk_group: int, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float, tile_tokens_dim: int = 8, routing_method_type: int = 0, use_shuffled_weight: bool = False, weight_layout: int = 0) → torch.Tensor
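Fused Mixture-of-Experts (MoE) forward pass with FP8 hidden states and weights and block-wise scaling factors. Each token is routed to top_k experts, and the expert outputs are combined into one output row per token. As the shapes below indicate, hidden states carry one scale per 128-element block of each token, and weights carry one scale per 128×128 block.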
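The block scales can be pictured with a small pure-Python sketch. It computes one scale per 128-element block of each token and lays the scales out as [hidden_size//128, seq_len], matching hidden_states_scale. Choosing each scale as the block's absolute maximum divided by 448 (the largest finite FP8 E4M3 value) is an assumption about how a caller might quantize, not something this page specifies. The function names are hypothetical, and the final cast to FP8 is not modeled.

```python
# Largest finite value of FP8 E4M3 (torch.float8_e4m3fn). Using amax / 448
# as the block scale is an assumed quantization recipe, not a documented one.
FP8_E4M3_MAX = 448.0
BLOCK = 128  # block size implied by the "//128" in the documented shapes


def block_scales(hidden_states):
    """Hypothetical helper: one scale per 128-element block of each token.

    hidden_states: list of seq_len rows, each a list of hidden_size floats.
    Returns a nested list laid out as [hidden_size // 128][seq_len],
    i.e. the transposed layout documented for hidden_states_scale.
    """
    seq_len = len(hidden_states)
    hidden_size = len(hidden_states[0])
    if hidden_size % BLOCK != 0:
        raise ValueError("hidden_size must be a multiple of 128")
    num_blocks = hidden_size // BLOCK
    scales = [[0.0] * seq_len for _ in range(num_blocks)]
    for t, row in enumerate(hidden_states):
        for b in range(num_blocks):
            block = row[b * BLOCK:(b + 1) * BLOCK]
            amax = max(abs(x) for x in block)
            # Guard against all-zero blocks so the scale stays positive.
            scales[b][t] = max(amax, 1e-12) / FP8_E4M3_MAX
    return scales


def scaled_block_values(hidden_states, scales):
    """Hypothetical helper: divide each element by its block scale.

    The results lie in [-448, 448]; a real pipeline would then cast them
    to FP8 (not modeled here). Dequantization multiplies back by the scale.
    """
    return [
        [x / scales[i // BLOCK][t] for i, x in enumerate(row)]
        for t, row in enumerate(hidden_states)
    ]
```

Note that the token index is the second dimension of the scale layout, so block b of token t is read as scales[b][t].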
- Parameters:
routing_logits – [seq_len, num_experts] tensor of routing logits
routing_bias – [num_experts] tensor of routing bias, or None if no bias is used
hidden_states – [seq_len, hidden_size] tensor of input hidden states
hidden_states_scale – [hidden_size//128, seq_len] tensor of hidden states block scales, one per 128-element block of each token (note the transposed layout: the block index comes first)
gemm1_weights – [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights
gemm1_weights_scale – [num_experts, 2*intermediate_size//128, hidden_size//128] tensor of first layer block scales
gemm2_weights – [num_experts, hidden_size, intermediate_size] tensor of second layer weights
gemm2_weights_scale – [num_experts, hidden_size//128, intermediate_size//128] tensor of second layer block scales
num_experts – Total number of experts
top_k – Number of experts to route to per token
n_group – Number of expert groups
topk_group – Number of expert groups selected per token; the top_k experts are then chosen only from these groups
intermediate_size – Intermediate (feed-forward) dimension of each expert
local_expert_offset – Offset of local experts in global expert space
local_num_experts – Number of experts handled by this device
routed_scaling_factor – Scaling factor applied to the routing weights of the selected experts
tile_tokens_dim – Tile dimension for tokens (default: 8)
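routing_method_type – Type of routing method to use (default: 0)
use_shuffled_weight – Whether the weight tensors have already been shuffled into the kernel's preferred layout (default: False)
weight_layout – Integer code selecting the memory layout of the weight tensors (default: 0)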
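Before calling the operator, it can help to check tensor shapes against the documented ones. The pure-Python sketch below follows this page literally (for example, it uses num_experts as the leading weight dimension). The helper names and the local-expert range check are hypothetical; the range check is an assumed constraint, not one this page states.

```python
# Block size implied by the "//128" in the documented scale shapes.
BLOCK = 128


def expected_shapes(seq_len, hidden_size, intermediate_size, num_experts):
    """Hypothetical helper: shapes listed in the Parameters section."""
    for name, dim in (("hidden_size", hidden_size),
                      ("intermediate_size", intermediate_size)):
        if dim % BLOCK != 0:
            raise ValueError(f"{name} must be a multiple of {BLOCK}, got {dim}")
    return {
        "routing_logits": (seq_len, num_experts),
        "routing_bias": (num_experts,),
        "hidden_states": (seq_len, hidden_size),
        "hidden_states_scale": (hidden_size // BLOCK, seq_len),
        "gemm1_weights": (num_experts, 2 * intermediate_size, hidden_size),
        "gemm1_weights_scale": (num_experts, 2 * intermediate_size // BLOCK,
                                hidden_size // BLOCK),
        "gemm2_weights": (num_experts, hidden_size, intermediate_size),
        "gemm2_weights_scale": (num_experts, hidden_size // BLOCK,
                                intermediate_size // BLOCK),
    }


def check_local_experts(num_experts, local_expert_offset, local_num_experts):
    """Hypothetical helper: assumes the local slice lies in [0, num_experts).

    Returns the global expert ids handled by this device.
    """
    if local_expert_offset < 0 or local_num_experts <= 0:
        raise ValueError("need local_expert_offset >= 0 and local_num_experts > 0")
    if local_expert_offset + local_num_experts > num_experts:
        raise ValueError("local expert slice exceeds num_experts")
    return range(local_expert_offset, local_expert_offset + local_num_experts)
```

Since torch.Size is a tuple subclass, each entry can be compared directly, e.g. `gemm1_weights.shape == shapes["gemm1_weights"]`.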
- Returns:
Output tensor of shape [seq_len, hidden_size]
- Return type:
torch.Tensor
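The routing parameters and the shape of the result can be illustrated with a pure-Python reference for a single token. It implements DeepSeek-V3-style grouped routing: sigmoid scores, routing_bias used only for expert selection, groups ranked by the sum of their two best biased scores, and the selected weights renormalized and multiplied by routed_scaling_factor. This is one scheme that n_group, topk_group, and routed_scaling_factor fit; which scheme the kernel actually runs depends on routing_method_type and is not described on this page. The combine step assumes that experts outside [local_expert_offset, local_expert_offset + local_num_experts) are handled on other devices and their contributions summed elsewhere. All function names are hypothetical.

```python
import math


def grouped_topk_route(logits, bias, n_group, topk_group, top_k,
                       routed_scaling_factor):
    """Hypothetical DeepSeek-V3-style routing for one token.

    logits, bias: lists of num_experts floats.
    Returns top_k (expert_id, weight) pairs.
    """
    num_experts = len(logits)
    if num_experts % n_group != 0:
        raise ValueError("num_experts must be divisible by n_group")
    group_size = num_experts // n_group
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]  # sigmoid
    biased = [s + b for s, b in zip(scores, bias)]  # bias affects selection only
    # Rank each group by the sum of its two best biased scores.
    group_scores = []
    for g in range(n_group):
        members = sorted(biased[g * group_size:(g + 1) * group_size], reverse=True)
        group_scores.append(sum(members[:2]))
    kept = sorted(range(n_group), key=lambda g: group_scores[g],
                  reverse=True)[:topk_group]
    candidates = [e for g in kept
                  for e in range(g * group_size, (g + 1) * group_size)]
    chosen = sorted(candidates, key=lambda e: biased[e], reverse=True)[:top_k]
    # Weights come from the unbiased scores, renormalized, then scaled.
    total = sum(scores[e] for e in chosen)
    return [(e, scores[e] / total * routed_scaling_factor) for e in chosen]


def moe_combine(token, routes, experts, local_expert_offset, local_num_experts):
    """Hypothetical combine: weighted sum of local expert outputs for one token.

    experts: list of local_num_experts callables mapping a hidden vector
    (list of floats) to a hidden vector of the same length.
    """
    out = [0.0] * len(token)
    for e, w in routes:
        local = e - local_expert_offset
        if 0 <= local < local_num_experts:
            y = experts[local](token)
            out = [o + w * v for o, v in zip(out, y)]
    return out
```

Applying moe_combine to each of the seq_len tokens yields one hidden_size-long row per token, which is why the output has shape [seq_len, hidden_size].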