flashinfer.fused_moe.bgmv_moe_shrink
- flashinfer.fused_moe.bgmv_moe_shrink(y: Tensor, x: Tensor, w_ptr: Tensor, sorted_token_ids: Tensor, expert_ids: Tensor, lora_indices: Tensor, lora_stride: int) → None
MoE LoRA shrink operation: project input through LoRA-A matrices.
- For each (token, expert) pair, computes:
y[slice, pair, rank] += x[token] @ lora_a[expert, lora_id, :, :]
- Parameters:
y – Output tensor [num_slices, num_pairs, rank]. Accumulated in-place.
x – Input activations [num_tokens, hidden_dim].
w_ptr – Pointer table [num_slices, num_experts] of int64. Each entry points to the start of lora_a weights for (slice, expert). The kernel uses lora_stride to index different LoRA adapters.
sorted_token_ids – Token indices for each pair [num_pairs].
expert_ids – Expert indices for each pair [num_pairs].
lora_indices – LoRA adapter ID for each token [num_tokens]. -1 means no LoRA (pair is skipped).
lora_stride – Stride (in elements) between consecutive LoRA adapters in the weight tensor. For layout [max_loras, num_experts, rank, feat], this is num_experts * rank * feat.
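The semantics described above can be sketched in plain Python for a single slice. This is a minimal reference sketch under stated assumptions, not the library's kernel. The names bgmv_moe_shrink_ref, weights, and expert_base are hypothetical. expert_base[e] stands in for the w_ptr entry of expert e, expressed as an element offset into a flat list laid out as [max_loras, num_experts, rank, feat]. Each [rank, feat] block is assumed to be row-major and contracted with x[token] over feat, and every entry of sorted_token_ids is assumed to be a valid token index.

```python
def bgmv_moe_shrink_ref(y, x, weights, expert_base, sorted_token_ids,
                        expert_ids, lora_indices, lora_stride):
    """Hypothetical single-slice reference; y is [num_pairs][rank], updated in place."""
    feat = len(x[0])
    for pair, (token, expert) in enumerate(zip(sorted_token_ids, expert_ids)):
        lora_id = lora_indices[token]
        if lora_id < 0:
            # -1 means no LoRA for this token, so the pair is skipped.
            continue
        # Start of the [rank, feat] block for (expert, lora_id):
        # per-expert base (assumed stand-in for w_ptr) plus the adapter offset.
        base = expert_base[expert] + lora_id * lora_stride
        for r in range(len(y[pair])):
            row = weights[base + r * feat: base + (r + 1) * feat]
            # Accumulate, matching the documented "+=" behaviour.
            y[pair][r] += sum(a * b for a, b in zip(x[token], row))
    return y


# Tiny example: 2 adapters, 2 experts, rank 2, feat 3.
max_loras, num_experts, rank, feat = 2, 2, 2, 3
weights = [float(i) for i in range(max_loras * num_experts * rank * feat)]
lora_stride = num_experts * rank * feat            # as described for lora_stride
expert_base = [e * rank * feat for e in range(num_experts)]

x = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]]
sorted_token_ids = [0, 1, 2]
expert_ids = [1, 0, 1]
lora_indices = [1, -1, 0]                          # token 1 has no LoRA
y = [[0.0] * rank for _ in sorted_token_ids]

bgmv_moe_shrink_ref(y, x, weights, expert_base, sorted_token_ids,
                    expert_ids, lora_indices, lora_stride)
print(y)  # [[18.0, 21.0], [0.0, 0.0], [21.0, 30.0]]
```

The row for pair 1 stays at zero because its token maps to adapter -1. Because y is accumulated rather than overwritten, callers would typically zero it before the first call.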