flashinfer.rope.mla_rope_quantize_fp8
- flashinfer.rope.mla_rope_quantize_fp8(q_rope: Tensor, k_rope: Tensor, q_nope: Tensor, k_nope: Tensor, cos_sin_cache: Tensor, pos_ids: Tensor, is_neox: bool = True, quantize_dtype: dtype | None = None, quant_scale_q: float = 1.0, quant_scale_kv: float = 1.0, q_rope_out: Tensor | None = None, k_rope_out: Tensor | None = None, q_nope_out: Tensor | None = None, k_nope_out: Tensor | None = None, enable_pdl: bool = False) → Tuple[Tensor, Tensor, Tensor, Tensor]
Apply RoPE and quantize to FP8 for MLA attention.
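As a rough, non-authoritative model of the two steps in that summary (not FlashInfer's kernel), the pure-Python sketch below rotates one head vector with RoPE and then applies the scale-and-saturate step that would precede an FP8 cast. Several assumptions are ours: a cos/sin cache row is laid out as [cos half | sin half]; the helper names cos_sin_row, rope_rotate, and scale_and_saturate and the frequency base of 10000 are hypothetical; the quantization scale is multiplied in before clamping; and the non-rotary parts are only scaled and clamped. FP8 mantissa rounding is not emulated.

```python
import math

# Largest finite values of the two FP8 formats named on this page.
FP8_E4M3FN_MAX = 448.0
FP8_E5M2_MAX = 57344.0


def cos_sin_row(pos, rope_dim, base=10000.0):
    """Hypothetical helper: one cache row laid out as [cos half | sin half] (assumed layout)."""
    half = rope_dim // 2
    inv_freq = [base ** (-2.0 * i / rope_dim) for i in range(half)]
    return [math.cos(pos * f) for f in inv_freq] + [math.sin(pos * f) for f in inv_freq]


def rope_rotate(x, row, is_neox=True):
    """Rotate one rope_dim-long head vector using a row produced by cos_sin_row()."""
    half = len(x) // 2
    cos, sin = row[:half], row[half:]
    out = [0.0] * len(x)
    for i in range(half):
        # NeoX pairs element i with i + half; interleaved pairs 2i with 2i + 1.
        a, b = (i, i + half) if is_neox else (2 * i, 2 * i + 1)
        out[a] = x[a] * cos[i] - x[b] * sin[i]
        out[b] = x[b] * cos[i] + x[a] * sin[i]
    return out


def scale_and_saturate(x, scale, fp8_max=FP8_E4M3FN_MAX):
    """Multiply by the quantization scale, then clamp to the FP8 range (assumed order).

    This does NOT emulate FP8 mantissa rounding; a real cast would also round.
    """
    return [max(-fp8_max, min(fp8_max, v * scale)) for v in x]


# One token at position 3 with rope_dim = 4 (illustrative sizes only).
row = cos_sin_row(pos=3, rope_dim=4)
q_rope_vec = rope_rotate([1.0, 2.0, 3.0, 4.0], row, is_neox=True)
q_rope_q = scale_and_saturate(q_rope_vec, scale=1.0)
# Non-rotary parts are only scaled and clamped (assumed behavior).
q_nope_q = scale_and_saturate([1000.0, -0.5], scale=1.0)
print(q_nope_q)  # [448.0, -0.5]
```

The is_neox flag only changes which elements are paired for rotation; both layouts preserve the vector's norm, so any saturation comes from the scale and the FP8 range, not from RoPE itself.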
Thin wrapper that forwards to rope_quantize_fp8() with the MLA tensor layout: k_rope and k_nope are 2-D (no KV-head axis). All other parameters and behavior match rope_quantize_fp8(); see that function for full parameter documentation.
- Parameters:
q_rope (torch.Tensor) – Query rotary portion, shape (nnz, num_qo_heads, rope_dim), fp16/bf16.
k_rope (torch.Tensor) – Key rotary portion in MLA layout, shape (nnz, rope_dim), fp16/bf16.
q_nope (torch.Tensor) – Query non-rotary portion, shape (nnz, num_qo_heads, no_rope_dim).
k_nope (torch.Tensor) – Key non-rotary portion in MLA layout, shape (nnz, no_rope_dim).
cos_sin_cache (torch.Tensor) – Precomputed cos/sin cache, shape (max_seq_len, rope_dim), fp32.
pos_ids (torch.Tensor) – Per-token position indices, shape (nnz,).
is_neox (bool) – True for the NeoX (split-half) layout, False for the interleaved layout.
quantize_dtype (torch.dtype, optional) – Target quantization dtype (float8_e4m3fn or float8_e5m2). Inferred from the *_out tensors when None.
quant_scale_q (float) – Quantization scale applied to queries.
quant_scale_kv (float) – Quantization scale applied to keys.
q_rope_out (torch.Tensor, optional) – Pre-allocated output tensor for the rotated, quantized q_rope. Allocated automatically when None.
k_rope_out (torch.Tensor, optional) – Pre-allocated output tensor for the rotated, quantized k_rope. Allocated automatically when None.
q_nope_out (torch.Tensor, optional) – Pre-allocated output tensor for the quantized q_nope. Allocated automatically when None.
k_nope_out (torch.Tensor, optional) – Pre-allocated output tensor for the quantized k_nope. Allocated automatically when None.
enable_pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL).
- Returns:
(q_rope_out, k_rope_out, q_nope_out, k_nope_out).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
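To make the MLA layout concrete, here is a small pre-flight check that validates shapes against the layout documented above, using plain tuples instead of tensors so it runs without a GPU. mla_output_shapes is a hypothetical helper, not part of FlashInfer. Two points are assumptions on our side: each output keeps the shape of its corresponding input, and q_nope and k_nope must share no_rope_dim (read off the shared symbol in the parameter list). Dtype checks (fp16/bf16 inputs, fp32 cache) are omitted.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def mla_output_shapes(q_rope: Shape, k_rope: Shape, q_nope: Shape, k_nope: Shape,
                      cos_sin_cache: Shape, pos_ids: Shape) -> Tuple[Shape, Shape, Shape, Shape]:
    """Hypothetical check of the documented MLA layout.

    Returns the shapes expected for (q_rope_out, k_rope_out, q_nope_out, k_nope_out),
    assuming each output has the same shape as its input.
    """
    if len(q_rope) != 3 or len(q_nope) != 3:
        raise ValueError("q_rope and q_nope must be 3-D: (nnz, num_qo_heads, dim)")
    if len(k_rope) != 2 or len(k_nope) != 2:
        raise ValueError("k_rope and k_nope must be 2-D in the MLA layout (no KV-head axis)")
    nnz, num_qo_heads, rope_dim = q_rope
    no_rope_dim = q_nope[2]
    if tuple(q_nope[:2]) != (nnz, num_qo_heads):
        raise ValueError("q_nope must share (nnz, num_qo_heads) with q_rope")
    if tuple(k_rope) != (nnz, rope_dim):
        raise ValueError("k_rope must be (nnz, rope_dim)")
    if tuple(k_nope) != (nnz, no_rope_dim):
        raise ValueError("k_nope must be (nnz, no_rope_dim)")
    if len(cos_sin_cache) != 2 or cos_sin_cache[1] != rope_dim:
        raise ValueError("cos_sin_cache must be (max_seq_len, rope_dim)")
    if tuple(pos_ids) != (nnz,):
        raise ValueError("pos_ids must be (nnz,)")
    return q_rope, k_rope, q_nope, k_nope


# Illustrative sizes only: 8 tokens, 16 query heads, rope_dim 64, no_rope_dim 512.
shapes = mla_output_shapes((8, 16, 64), (8, 64), (8, 16, 512), (8, 512), (4096, 64), (8,))
print(shapes[1])  # (8, 64)
```

With real tensors the same check can take x.shape directly, since torch.Size is a tuple subclass. A 3-D key such as (nnz, 1, rope_dim) is rejected here because the MLA layout has no KV-head axis.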