flashinfer.fused_moe.bgmv_moe

flashinfer.fused_moe.bgmv_moe(x: Tensor, lora_a_weights: List[Tensor], lora_b_weights: List[Tensor], sorted_token_ids: Tensor, expert_ids: Tensor, lora_indices: Tensor, topk_weights: Tensor, num_experts: int, output_dim: int | None = None) → Tensor

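High-level multi-LoRA MoE BGMV (batched gather matrix-vector multiplication): runs the LoRA shrink step (project x down to the LoRA rank with lora_a) and the expand step (project back up with lora_b) in one call.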

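Computes the LoRA delta for each token t of the MoE layer:

delta[t] = Σ_p topk_weights[p] * (x[t] @ lora_a[l, e]ᵀ) @ lora_b[l, e]ᵀ

where p ranges over the (token, expert) pairs with sorted_token_ids[p] == t, e = expert_ids[p], and l = lora_indices[t]. The sum therefore covers only the experts the token was routed to, not all num_experts.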
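The formula can be checked against a slow, loop-based NumPy sketch. It assumes a single slice, NumPy arrays in place of torch tensors, and the shapes listed under Parameters; the function name moe_lora_delta_reference is hypothetical and this is not the FlashInfer kernel. It also does not handle any padding entries or "no adapter" sentinel values that the real kernel may accept.

```python
import numpy as np


def moe_lora_delta_reference(x, lora_a, lora_b, sorted_token_ids, expert_ids,
                             lora_indices, topk_weights):
    """Hypothetical single-slice reference for the MoE LoRA delta.

    x:            [num_tokens, hidden_dim]
    lora_a:       [max_loras, num_experts, rank, hidden_dim]
    lora_b:       [max_loras, num_experts, feat_out, rank]
    sorted_token_ids, expert_ids, topk_weights: [num_pairs]
    lora_indices: [num_tokens]
    """
    num_tokens = x.shape[0]
    feat_out = lora_b.shape[2]
    delta = np.zeros((num_tokens, feat_out), dtype=x.dtype)
    for p in range(len(sorted_token_ids)):
        t = sorted_token_ids[p]        # token of this pair
        e = expert_ids[p]              # expert of this pair
        l = lora_indices[t]            # adapter used by this token
        shrunk = x[t] @ lora_a[l, e].T         # shrink: [hidden_dim] -> [rank]
        expanded = shrunk @ lora_b[l, e].T     # expand: [rank] -> [feat_out]
        delta[t] += topk_weights[p] * expanded  # weight by routing score
    return delta
```

Looping over pairs rather than over all experts makes it explicit that each token only accumulates contributions from the experts it was routed to.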

Parameters:
  • x – Input activations [num_tokens, hidden_dim].

  • lora_a_weights – List of LoRA-A weight tensors, one per slice. Each has shape [max_loras, num_experts, rank, hidden_dim].

  • lora_b_weights – List of LoRA-B weight tensors, one per slice. Each has shape [max_loras, num_experts, feat_out, rank].

  • sorted_token_ids – Token index of each (token, expert) routing pair [num_pairs].

  • expert_ids – Expert index of each pair [num_pairs].

  • lora_indices – LoRA adapter ID of each token [num_tokens].

  • topk_weights – Routing weight of each pair [num_pairs].

  • num_experts – Number of experts.

  • output_dim – Total output dimension (total_feat_out in the returned shape). If None, it is inferred from lora_b_weights.
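The per-pair arrays are typically derived from a router's top-k output of shape [num_tokens, top_k]. A minimal sketch with a hypothetical helper build_pairs follows; grouping pairs by expert is an assumption based on the name sorted_token_ids, and any further ordering or padding the kernel may require is not modeled here.

```python
import numpy as np


def build_pairs(topk_ids, topk_weights):
    """Hypothetical helper: flatten [num_tokens, top_k] routing results into pairs.

    Returns (sorted_token_ids, expert_ids, pair_weights), each of length
    num_tokens * top_k, grouped by expert id (assumed ordering). The sort is
    stable, so tokens keep their original order within each expert.
    """
    num_tokens, top_k = topk_ids.shape
    token_ids = np.repeat(np.arange(num_tokens), top_k)  # token of each pair
    expert_ids = topk_ids.reshape(-1)
    weights = topk_weights.reshape(-1)
    order = np.argsort(expert_ids, kind="stable")        # group pairs by expert
    return token_ids[order], expert_ids[order], weights[order]
```

For example, topk_ids = [[2, 0], [0, 1], [1, 2]] yields expert_ids [0, 0, 1, 1, 2, 2] and sorted_token_ids [0, 1, 1, 2, 0, 2].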

Returns:

Output tensor [num_tokens, total_feat_out] containing the LoRA delta for each token.
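With several slices (for example, a fused gate and up projection, which is an assumed use case), total_feat_out is plausibly the sum of each slice's feat_out. The hypothetical helper slice_column_ranges sketches that bookkeeping, assuming each slice's delta occupies a contiguous block of output columns in list order; check the implementation before relying on this layout.

```python
def slice_column_ranges(lora_b_shapes):
    """Hypothetical helper: compute total_feat_out and per-slice column ranges.

    Each shape is [max_loras, num_experts, feat_out, rank]. Assumes slices are
    laid out contiguously along the output's last dimension, in list order.
    """
    ranges = []
    start = 0
    for shape in lora_b_shapes:
        feat_out = shape[2]
        ranges.append((start, start + feat_out))  # [start, end) columns of this slice
        start += feat_out
    return start, ranges


# Two equally sized slices, e.g. an assumed gate/up pair.
total, cols = slice_column_ranges([(4, 8, 128, 16), (4, 8, 128, 16)])
```

Under these assumptions, total is 256 and the two slices cover columns [0, 128) and [128, 256).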