flashinfer.fused_moe.bgmv_moe
- flashinfer.fused_moe.bgmv_moe(x: Tensor, lora_a_weights: List[Tensor], lora_b_weights: List[Tensor], sorted_token_ids: Tensor, expert_ids: Tensor, lora_indices: Tensor, topk_weights: Tensor, num_experts: int, output_dim: int | None = None) → Tensor
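High-level multi-LoRA MoE BGMV (batched gather matrix-vector multiplication): performs the LoRA shrink (project x down to rank) and expand (project back up to feat_out) steps in a single call.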
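- Computes the LoRA delta for MoE. For each token t, let l = lora_indices[t]; the sum runs over every routing pair p with sorted_token_ids[p] == t, and e = expert_ids[p]:
delta[t] = Σ_p (topk_weights[p] * x[t] @ lora_a[l, e].T @ lora_b[l, e].T)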
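To make the formula concrete, here is a loop-based NumPy reference for a single slice, written only from the shapes documented in the parameter list. It is a sketch of the assumed semantics, not FlashInfer's kernel, and the function name bgmv_moe_reference is hypothetical. The shrink step multiplies lora_a[l, e] by x[t] to get a rank-sized vector; the expand step multiplies lora_b[l, e] by that vector to get feat_out values.

```python
import numpy as np


def bgmv_moe_reference(x, lora_a, lora_b, sorted_token_ids, expert_ids,
                       lora_indices, topk_weights):
    """Hypothetical single-slice reference for the MoE LoRA delta (assumed semantics).

    x:      [num_tokens, hidden_dim]
    lora_a: [max_loras, num_experts, rank, hidden_dim]
    lora_b: [max_loras, num_experts, feat_out, rank]
    sorted_token_ids, expert_ids, topk_weights: [num_pairs]
    lora_indices: [num_tokens]
    """
    num_tokens = x.shape[0]
    feat_out = lora_b.shape[2]
    delta = np.zeros((num_tokens, feat_out), dtype=x.dtype)
    for p in range(len(sorted_token_ids)):
        t = sorted_token_ids[p]          # token this routing pair belongs to
        e = expert_ids[p]                # expert this pair is routed to
        l = lora_indices[t]              # adapter used by the token
        shrunk = lora_a[l, e] @ x[t]     # shrink: [rank]
        expanded = lora_b[l, e] @ shrunk  # expand: [feat_out]
        delta[t] += topk_weights[p] * expanded  # weighted sum over the token's experts
    return delta
```

For instance, with all-ones weights, hidden_dim = 4 and rank = 2, every pair contributes topk_weights[p] * 8 to each output column of its token.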
- Parameters:
x – Input activations [num_tokens, hidden_dim].
lora_a_weights – List of LoRA-A weight tensors, one per slice. Each has shape [max_loras, num_experts, rank, hidden_dim].
lora_b_weights – List of LoRA-B weight tensors, one per slice. Each has shape [max_loras, num_experts, feat_out, rank].
sorted_token_ids – Token index of each (token, expert) routing pair [num_pairs].
expert_ids – Expert index of each routing pair [num_pairs].
lora_indices – LoRA adapter ID for each token, i.e. an index into the max_loras dimension of the LoRA weights [num_tokens].
topk_weights – Router weight of each routing pair [num_pairs].
num_experts – Number of experts.
output_dim – Total output dimension (total_feat_out). If None, it is inferred from lora_b_weights.
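The per-pair arrays usually come from a router's top-k output. The sketch below flattens [num_tokens, top_k] expert ids and weights into the three pair arrays. The helper name routing_to_pairs is hypothetical, and grouping pairs by expert id is an assumption suggested by the name sorted_token_ids but not specified here; FlashInfer may expect a different ordering or padding.

```python
import numpy as np


def routing_to_pairs(topk_ids, topk_weights):
    """Hypothetical helper: flatten router top-k output into per-pair arrays.

    topk_ids, topk_weights: [num_tokens, top_k]
    Returns (sorted_token_ids, expert_ids, pair_weights), each [num_tokens * top_k].
    Pairs are grouped by expert id (assumed ordering, not documented).
    """
    num_tokens, top_k = topk_ids.shape
    token_ids = np.repeat(np.arange(num_tokens), top_k)  # token of each pair
    expert_ids = topk_ids.reshape(-1)                     # expert of each pair
    pair_weights = topk_weights.reshape(-1)               # router weight of each pair
    order = np.argsort(expert_ids, kind="stable")         # group pairs by expert
    return token_ids[order], expert_ids[order], pair_weights[order]
```

With topk_ids = [[3, 1], [0, 3]] and topk_weights = [[0.7, 0.3], [0.6, 0.4]], this yields tokens [1, 0, 0, 1], experts [0, 1, 3, 3] and weights [0.6, 0.3, 0.7, 0.4].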
- Returns:
Output tensor of shape [num_tokens, total_feat_out] holding the LoRA deltas.
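The NumPy sketch below illustrates one way multiple slices could map onto the output. It assumes that each slice's delta occupies consecutive columns in list order and that output_dim defaults to the sum of feat_out over lora_b_weights; both are assumptions, not documented behavior, and the function name bgmv_moe_multi_slice is hypothetical. It vectorizes the per-pair shrink and expand with einsum and scatters the weighted results back to tokens with np.add.at, which accumulates correctly when a token appears in several pairs.

```python
import numpy as np


def bgmv_moe_multi_slice(x, lora_a_weights, lora_b_weights, sorted_token_ids,
                         expert_ids, lora_indices, topk_weights, output_dim=None):
    """Hypothetical multi-slice reference (assumed column layout and output_dim rule)."""
    if output_dim is None:
        # Assumption: total output width is the sum of each slice's feat_out.
        output_dim = sum(b.shape[2] for b in lora_b_weights)
    out = np.zeros((x.shape[0], output_dim), dtype=x.dtype)
    loras = lora_indices[sorted_token_ids]  # adapter of each pair, [num_pairs]
    xs = x[sorted_token_ids]                # activations per pair, [num_pairs, hidden_dim]
    col = 0
    for a, b in zip(lora_a_weights, lora_b_weights):
        feat_out = b.shape[2]
        if col + feat_out > output_dim:
            raise ValueError("output_dim is smaller than the total slice width")
        a_p = a[loras, expert_ids]                        # [num_pairs, rank, hidden_dim]
        b_p = b[loras, expert_ids]                        # [num_pairs, feat_out, rank]
        shrunk = np.einsum("prh,ph->pr", a_p, xs)         # shrink: [num_pairs, rank]
        expanded = np.einsum("pfr,pr->pf", b_p, shrunk)   # expand: [num_pairs, feat_out]
        # Scatter-add weighted pair results into this slice's columns (a view of out).
        np.add.at(out[:, col:col + feat_out], sorted_token_ids,
                  topk_weights[:, None] * expanded)
        col += feat_out  # assumption: next slice starts right after this one
    return out
```

Writing into a column view keeps each slice's contribution in place without a final concatenation step.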