flashinfer.fused_moe.bgmv_moe_expand
- flashinfer.fused_moe.bgmv_moe_expand(y: Tensor, x: Tensor, w_ptr: Tensor, sorted_token_ids: Tensor, expert_ids: Tensor, topk_weights: Tensor, lora_indices: Tensor, slice_start_loc: Tensor, output_slices: List[int], lora_stride: int) → None
MoE LoRA expand operation: projects the rank-sized shrink output through the per-expert LoRA-B matrices, scales each result by its routing weight, and accumulates it into y.
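- For each slice and each (token, expert) pair, computes:
y[token, col_offset:col_offset+feat] += topk_weight * (x[slice, pair, :] @ lora_b[expert, lora_id, :, :])
where token = sorted_token_ids[pair], expert = expert_ids[pair], topk_weight = topk_weights[pair], lora_id = lora_indices[token], col_offset = slice_start_loc[slice], feat = output_slices[slice], and lora_b is the LoRA-B weight of that slice, addressed through w_ptr and lora_stride.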
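The accumulation above can be checked on tiny inputs against a pure-Python reference model. This is a sketch, not FlashInfer's implementation: tensors are plain nested lists, and the w_ptr pointer table together with lora_stride is replaced by a hypothetical nested list lora_b[slice][expert][lora_id], each entry a rank × feat matrix so that x[slice][pair] @ entry yields feat columns, as in the formula. The function name bgmv_moe_expand_reference is likewise hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def bgmv_moe_expand_reference(
    y: Matrix,
    x: List[Matrix],
    lora_b: List[List[List[Matrix]]],
    sorted_token_ids: List[int],
    expert_ids: List[int],
    topk_weights: List[float],
    lora_indices: List[int],
    slice_start_loc: List[int],
    output_slices: List[int],
) -> None:
    """Reference model of the expand step; modifies y in place.

    Hypothetical layout: lora_b[slice][expert][lora_id] is a rank x feat
    matrix standing in for the weights addressed via w_ptr and lora_stride.
    """
    for s, feat in enumerate(output_slices):
        col_offset = slice_start_loc[s]
        for p, token in enumerate(sorted_token_ids):
            expert = expert_ids[p]
            lora_id = lora_indices[token]  # adapter is selected per token
            b = lora_b[s][expert][lora_id]
            xv = x[s][p]                   # rank-sized shrink output of this pair
            for j in range(feat):
                # Dot product of xv with column j of the LoRA-B matrix.
                acc = sum(xv[r] * b[r][j] for r in range(len(xv)))
                y[token][col_offset + j] += topk_weights[p] * acc


# Tiny example: 2 tokens, 2 experts, 1 adapter, rank 2, slices of width 1 and 2.
y = [[0.0] * 3 for _ in range(2)]
x = [
    [[1, 0], [0, 1], [1, 1]],  # slice 0, one row per pair
    [[1, 0], [0, 1], [1, 1]],  # slice 1
]
lora_b = [
    [[[[2], [3]]], [[[4], [5]]]],              # slice 0: expert 0, expert 1
    [[[[1, 2], [3, 4]]], [[[1, 0], [0, 1]]]],  # slice 1: expert 0, expert 1
]
bgmv_moe_expand_reference(
    y, x, lora_b,
    sorted_token_ids=[0, 0, 1],
    expert_ids=[0, 1, 1],
    topk_weights=[0.5, 0.5, 1.0],
    lora_indices=[0, 0],
    slice_start_loc=[0, 1],
    output_slices=[1, 2],
)
print(y)  # [[3.5, 0.5, 1.5], [9.0, 1.0, 1.0]]
```

Because every pair adds into y, token 0 (routed to two experts) receives two weighted contributions per slice; this is why y behaves as an accumulation buffer rather than a plain output.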
- Parameters:
y – Output tensor [num_tokens, total_feat_out], updated in place: results are added to its existing contents. Float32 accumulation buffer.
x – Output of the LoRA shrink step [num_slices, num_pairs, rank].
w_ptr – Int64 table of LoRA-B weight pointers [num_slices, num_experts].
sorted_token_ids – Token indices for each pair [num_pairs].
expert_ids – Expert indices for each pair [num_pairs].
topk_weights – Routing weights for each pair [num_pairs]. Float32.
lora_indices – LoRA adapter ID for each token [num_tokens].
slice_start_loc – Starting column in y for each slice [num_slices]. Int64.
output_slices – Output feature width (number of columns written in y) for each slice.
lora_stride – Stride between consecutive LoRA adapters in the weight tensor.
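If the slices are packed back to back in y (an assumption; this page only states that slice_start_loc holds each slice's column offset), then slice_start_loc is the exclusive prefix sum of output_slices and total_feat_out is their sum. A minimal sketch, where slice_layout is a hypothetical helper:

```python
from itertools import accumulate
from typing import List, Tuple


def slice_layout(output_slices: List[int]) -> Tuple[List[int], int]:
    """Return (slice_start_loc, total_feat_out) for contiguously packed slices.

    Assumption: slice s occupies columns [start, start + output_slices[s])
    and the slices appear in order with no gaps between them.
    """
    # accumulate(..., initial=0) yields 0, w0, w0+w1, ...; drop the final total.
    starts = list(accumulate(output_slices, initial=0))[:-1]
    return starts, sum(output_slices)


starts, total = slice_layout([128, 128])
print(starts, total)  # [0, 128] 256
```

Since slice_start_loc is an Int64 tensor, the offsets would then be converted with something like torch.tensor(starts, dtype=torch.int64, device=y.device).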