flashinfer.fused_moe.trtllm_fp8_block_scale_moe

flashinfer.fused_moe.trtllm_fp8_block_scale_moe(routing_logits: torch.Tensor, routing_bias: torch.Tensor | None, hidden_states: torch.Tensor, hidden_states_scale: torch.Tensor, gemm1_weights: torch.Tensor, gemm1_weights_scale: torch.Tensor, gemm2_weights: torch.Tensor, gemm2_weights_scale: torch.Tensor, num_experts: int, top_k: int, n_group: int, topk_group: int, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float, tile_tokens_dim: int = 8, routing_method_type: int = 0, use_shuffled_weight: bool = False, weight_layout: int = 0) → torch.Tensor

FP8 block-scale MoE (Mixture-of-Experts) operation: weights and activations are quantized to FP8, with one scale factor shared per 128-element block.
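As an illustrative sketch (not the kernel itself), the block-scale convention implied by the `//128` scale shapes below can be modeled in plain Python: each 128×128 block of a quantized matrix shares one scale factor, and dequantization multiplies each element by its block's scale. The helper name and list-of-lists representation are assumptions for illustration.

```python
BLOCK = 128  # block size implied by the "//128" scale dimensions

def dequantize_block_scaled(w_q, scales):
    """Dequantize a block-scaled matrix (illustrative sketch).

    w_q:    rows x cols nested lists of quantized values
    scales: (rows // BLOCK) x (cols // BLOCK) nested lists of floats,
            one scale per 128x128 block
    """
    rows, cols = len(w_q), len(w_q[0])
    return [
        [w_q[i][j] * scales[i // BLOCK][j // BLOCK] for j in range(cols)]
        for i in range(rows)
    ]
```

For example, a 256×128 matrix has a 2×1 scale grid: rows 0–127 use `scales[0][0]`, rows 128–255 use `scales[1][0]`.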

Parameters:
  • routing_logits – [seq_len, num_experts] tensor of routing logits

  • routing_bias – [num_experts] tensor of routing bias

  • hidden_states – [seq_len, hidden_size] tensor of input hidden states

  • hidden_states_scale – [hidden_size//128, seq_len] tensor of hidden states block scales

  • gemm1_weights – [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights

  • gemm1_weights_scale – [num_experts, 2*intermediate_size//128, hidden_size//128] tensor of first layer block scales

  • gemm2_weights – [num_experts, hidden_size, intermediate_size] tensor of second layer weights

  • gemm2_weights_scale – [num_experts, hidden_size//128, intermediate_size//128] tensor of second layer block scales

  • num_experts – Total number of experts

  • top_k – Number of experts to route to per token

  • n_group – Number of expert groups

  • topk_group – Number of groups to consider for top-k routing

  • intermediate_size – Size of intermediate layer

  • local_expert_offset – Offset of local experts in global expert space

  • local_num_experts – Number of experts handled by this device

  • routed_scaling_factor – Scaling factor for routing

  • tile_tokens_dim – Tile dimension for tokens (default: 8)

  • routing_method_type – Type of routing method to use (default: 0)

  • use_shuffled_weight – Whether the GEMM weights are stored in the shuffled layout expected by the kernel (default: False)

  • weight_layout – Layout identifier for the weight tensors (default: 0)
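The interaction of `n_group`, `topk_group`, and `top_k` can be sketched as grouped top-k routing: experts are partitioned into `n_group` groups, the best `topk_group` groups are kept (scored here by their best expert's logit), and the final `top_k` experts are chosen from those groups. This is an assumed reference semantics for illustration; the kernel's exact scoring and tie-breaking may differ.

```python
def grouped_topk(logits, n_group, topk_group, top_k):
    """Illustrative grouped top-k routing over a list of expert logits."""
    num_experts = len(logits)
    group_size = num_experts // n_group
    # Score each group by its best expert's logit.
    group_scores = [
        max(logits[g * group_size:(g + 1) * group_size]) for g in range(n_group)
    ]
    # Keep the topk_group highest-scoring groups.
    kept = sorted(range(n_group), key=lambda g: group_scores[g], reverse=True)[:topk_group]
    # Gather all experts from kept groups, then take the global top_k.
    candidates = [e for g in kept for e in range(g * group_size, (g + 1) * group_size)]
    return sorted(candidates, key=lambda e: logits[e], reverse=True)[:top_k]
```

With 8 experts, `n_group=4`, `topk_group=2`, and `top_k=2`, only experts in the two strongest groups can be selected, even if a third group contains the overall third-best logit.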

Returns:

Output tensor of shape [seq_len, hidden_size]

Return type:

torch.Tensor
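To make the tensor contract above concrete, the expected shapes can be written out as a function of the sizing parameters. The helper below is hypothetical (not part of flashinfer); it simply restates the documented shapes, with 128 as the block size implied by the `//128` scale dimensions.

```python
def check_moe_shapes(seq_len, hidden_size, intermediate_size, num_experts):
    """Return the documented tensor shapes for trtllm_fp8_block_scale_moe
    (hypothetical helper; shapes are plain tuples)."""
    block = 128
    return {
        "routing_logits": (seq_len, num_experts),
        "hidden_states": (seq_len, hidden_size),
        "hidden_states_scale": (hidden_size // block, seq_len),
        "gemm1_weights": (num_experts, 2 * intermediate_size, hidden_size),
        "gemm1_weights_scale": (num_experts, 2 * intermediate_size // block, hidden_size // block),
        "gemm2_weights": (num_experts, hidden_size, intermediate_size),
        "gemm2_weights_scale": (num_experts, hidden_size // block, intermediate_size // block),
        "output": (seq_len, hidden_size),
    }
```

Note that `gemm1_weights` carries `2 * intermediate_size` rows because the first GEMM produces both halves of a gated activation, and that `hidden_states_scale` is laid out scale-major (`[hidden_size//128, seq_len]`), not token-major.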