flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe

flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(routing_logits: torch.Tensor, routing_bias: torch.Tensor | None, hidden_states: torch.Tensor, gemm1_weights: torch.Tensor, output1_scales_scalar: torch.Tensor, output1_scales_gate_scalar: torch.Tensor, gemm2_weights: torch.Tensor, output2_scales_scalar: torch.Tensor, num_experts: int, top_k: int, n_group: int, topk_group: int, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: float, use_routing_scales_on_input: bool, tile_tokens_dim: int = 8, routing_method_type: int = 0) → torch.Tensor

FP8 per-tensor-scale Mixture-of-Experts (MoE) operation.

Parameters:
  • routing_logits – [seq_len, num_experts] tensor of routing logits

  • routing_bias – optional [num_experts] tensor of routing bias, or None

  • hidden_states – [seq_len, hidden_size] tensor of input hidden states

  • gemm1_weights – [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights

  • output1_scales_scalar – [local_num_experts] tensor of first layer output scales

  • output1_scales_gate_scalar – [local_num_experts] tensor of first layer gate scales

  • gemm2_weights – [num_experts, hidden_size, intermediate_size] tensor of second layer weights

  • output2_scales_scalar – [local_num_experts] tensor of second layer output scales

  • num_experts – Total number of experts

  • top_k – Number of experts to route to per token

  • n_group – Number of expert groups

  • topk_group – Number of groups to consider for top-k routing

  • intermediate_size – Size of intermediate layer

  • local_expert_offset – Offset of local experts in global expert space

  • local_num_experts – Number of experts handled by this device

  • routed_scaling_factor – Scaling factor applied to the routed expert outputs

  • use_routing_scales_on_input – Whether to use routing scales on input

  • tile_tokens_dim – Tile dimension for tokens (default: 8)

  • routing_method_type – Type of routing method to use (default: 0)

Returns:

Output tensor of shape [seq_len, hidden_size]

Return type:

torch.Tensor
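
To make the documented shapes and data flow concrete, below is a simplified float32 NumPy reference of the routed MoE compute. It is a sketch only: FP8 quantization, the scale tensors, grouped top-k routing, `routing_method_type`, and expert parallelism (`local_expert_offset` / `local_num_experts`) are all omitted, and the SwiGLU-style gate/up split of the gemm1 output is an assumption about the fused kernel's layout.

```python
import numpy as np

def moe_reference(routing_logits, hidden_states, gemm1_weights, gemm2_weights, top_k):
    """Simplified float32 reference for the routed MoE compute.

    Shapes mirror the documented API:
      routing_logits: [seq_len, num_experts]
      hidden_states:  [seq_len, hidden_size]
      gemm1_weights:  [num_experts, 2*intermediate_size, hidden_size]
      gemm2_weights:  [num_experts, hidden_size, intermediate_size]
    """
    # Softmax over experts, then keep the top_k weights per token.
    exp = np.exp(routing_logits - routing_logits.max(axis=-1, keepdims=True))
    probs = exp / exp.sum(axis=-1, keepdims=True)
    topk_idx = np.argsort(-probs, axis=-1)[:, :top_k]      # [seq_len, top_k]
    topk_w = np.take_along_axis(probs, topk_idx, axis=-1)  # [seq_len, top_k]
    topk_w = topk_w / topk_w.sum(axis=-1, keepdims=True)   # renormalize over selected experts

    out = np.zeros_like(hidden_states)
    inter = gemm1_weights.shape[1] // 2
    for t in range(hidden_states.shape[0]):
        for k in range(top_k):
            e = topk_idx[t, k]
            h = gemm1_weights[e] @ hidden_states[t]        # [2*intermediate_size]
            gate, up = h[:inter], h[inter:]                # gate/up split is an assumption
            act = up * (gate / (1.0 + np.exp(-gate)))      # SwiGLU-style activation
            out[t] += topk_w[t, k] * (gemm2_weights[e] @ act)
    return out

# Tiny smoke test with random weights.
rng = np.random.default_rng(0)
seq_len, hidden, inter, num_experts, top_k = 4, 8, 16, 6, 2
y = moe_reference(
    rng.standard_normal((seq_len, num_experts)).astype(np.float32),
    rng.standard_normal((seq_len, hidden)).astype(np.float32),
    rng.standard_normal((num_experts, 2 * inter, hidden)).astype(np.float32),
    rng.standard_normal((num_experts, hidden, inter)).astype(np.float32),
    top_k,
)
print(y.shape)  # (4, 8): output matches [seq_len, hidden_size]
```

The fused kernel performs the same two GEMMs per selected expert but in FP8, using `output1_scales_scalar` / `output1_scales_gate_scalar` / `output2_scales_scalar` to dequantize each layer's output, which this float32 sketch does not model.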