flashinfer.rope.mla_rope_quantize_fp8

flashinfer.rope.mla_rope_quantize_fp8(q_rope: Tensor, k_rope: Tensor, q_nope: Tensor, k_nope: Tensor, cos_sin_cache: Tensor, pos_ids: Tensor, is_neox: bool = True, quantize_dtype: dtype | None = None, quant_scale_q: float = 1.0, quant_scale_kv: float = 1.0, q_rope_out: Tensor | None = None, k_rope_out: Tensor | None = None, q_nope_out: Tensor | None = None, k_nope_out: Tensor | None = None, enable_pdl: bool = False) → Tuple[Tensor, Tensor, Tensor, Tensor]

Apply RoPE and quantize to FP8 for MLA attention.

Thin wrapper that forwards to rope_quantize_fp8() with the MLA tensor layout: k_rope / k_nope are 2-D (no KV-head axis). All other parameters and behavior match rope_quantize_fp8(); see that function for full parameter documentation.

Parameters:
  • q_rope (torch.Tensor) – Query rotary portion, (nnz, num_qo_heads, rope_dim), fp16/bf16.

  • k_rope (torch.Tensor) – Key rotary portion in MLA layout, (nnz, rope_dim), fp16/bf16.

  • q_nope (torch.Tensor) – Query non-rotary portion, (nnz, num_qo_heads, no_rope_dim).

  • k_nope (torch.Tensor) – Key non-rotary portion in MLA layout, (nnz, no_rope_dim).

  • cos_sin_cache (torch.Tensor) – (max_seq_len, rope_dim) precomputed cos/sin cache (fp32).

  • pos_ids (torch.Tensor) – Per-token position indices, (nnz,).

  • is_neox (bool) – True for NeoX (split-half) layout, False for interleaved.

  • quantize_dtype (torch.dtype, optional) – Target quantization dtype (float8_e4m3fn or float8_e5m2). Inferred from *_out tensors when None.

  • quant_scale_q (float) – Quantization scale applied to queries.

  • quant_scale_kv (float) – Quantization scale applied to keys.

  • q_rope_out (torch.Tensor, optional) – Pre-allocated output tensor for the quantized query rotary portion. Allocated automatically when None.

  • k_rope_out (torch.Tensor, optional) – Pre-allocated output tensor for the quantized key rotary portion. Allocated automatically when None.

  • q_nope_out (torch.Tensor, optional) – Pre-allocated output tensor for the quantized query non-rotary portion. Allocated automatically when None.

  • k_nope_out (torch.Tensor, optional) – Pre-allocated output tensor for the quantized key non-rotary portion. Allocated automatically when None.

  • enable_pdl (bool) – Whether to enable Programmatic Dependent Launch.
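
The two layouts selected by is_neox differ only in which elements of a head vector are rotated together. The following is a minimal pure-Python sketch of that pairing, not the library's kernel: rope_rotate is a hypothetical helper, and cos and sin are passed as separate per-pair lists, so the sketch makes no claim about how the values are packed inside cos_sin_cache.

```python
import math


def rope_rotate(x, cos, sin, is_neox=True):
    """Rotate one head vector x (length rope_dim) pair by pair.

    cos and sin each hold rope_dim // 2 values, one per rotated pair.
    Hypothetical helper for illustration only.
    """
    half = len(x) // 2
    out = [0.0] * len(x)
    for i in range(half):
        # NeoX (split-half): element i is paired with element i + half.
        # Interleaved: element 2i is paired with element 2i + 1.
        a, b = (i, i + half) if is_neox else (2 * i, 2 * i + 1)
        out[a] = x[a] * cos[i] - x[b] * sin[i]
        out[b] = x[b] * cos[i] + x[a] * sin[i]
    return out


# Rotate the first pair by 90 degrees and leave the second pair unrotated.
angles = [math.pi / 2, 0.0]
cos = [math.cos(t) for t in angles]
sin = [math.sin(t) for t in angles]
x = [1.0, 0.0, 0.0, 0.0]

neox = rope_rotate(x, cos, sin, is_neox=True)          # mass moves to index 2
interleaved = rope_rotate(x, cos, sin, is_neox=False)  # mass moves to index 1
```

With the same input and angles, the NeoX layout moves the value from index 0 to index 2, while the interleaved layout moves it to index 1.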

Returns:

The quantized tensors (q_rope_out, k_rope_out, q_nope_out, k_nope_out), in that order.

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
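
Example:

The following is a hedged usage sketch built only from the signature and shapes documented above. The dimension values (8 tokens, 128 query heads, rope_dim 64, no_rope_dim 512, max_seq_len 4096), the bf16 dtype for the nope tensors, the int32 dtype of pos_ids, and the random placeholder contents of cos_sin_cache are assumptions made for illustration. mla_input_shapes and run_example are hypothetical helpers, not part of FlashInfer. Running run_example requires PyTorch, FlashInfer, and a CUDA device.

```python
def mla_input_shapes(nnz, num_qo_heads, rope_dim, no_rope_dim, max_seq_len):
    """Input shapes as listed in the parameter documentation (hypothetical helper)."""
    return {
        "q_rope": (nnz, num_qo_heads, rope_dim),
        "k_rope": (nnz, rope_dim),        # MLA layout: no KV-head axis
        "q_nope": (nnz, num_qo_heads, no_rope_dim),
        "k_nope": (nnz, no_rope_dim),     # MLA layout: no KV-head axis
        "cos_sin_cache": (max_seq_len, rope_dim),
        "pos_ids": (nnz,),
    }


def run_example():
    # Imported lazily so the shape helper above works without a GPU stack.
    import torch
    import flashinfer

    nnz = 8  # illustrative token count
    shapes = mla_input_shapes(nnz, num_qo_heads=128, rope_dim=64,
                              no_rope_dim=512, max_seq_len=4096)
    dev = "cuda"
    bf16 = torch.bfloat16  # assumption for q_nope / k_nope; the docs list no dtype

    q_rope = torch.randn(shapes["q_rope"], dtype=bf16, device=dev)
    k_rope = torch.randn(shapes["k_rope"], dtype=bf16, device=dev)
    q_nope = torch.randn(shapes["q_nope"], dtype=bf16, device=dev)
    k_nope = torch.randn(shapes["k_nope"], dtype=bf16, device=dev)

    # Placeholder values only: a real cache holds precomputed cos/sin entries.
    cos_sin_cache = torch.randn(shapes["cos_sin_cache"],
                                dtype=torch.float32, device=dev)
    # int32 is an assumption; the docs only give the shape (nnz,).
    pos_ids = torch.arange(nnz, dtype=torch.int32, device=dev)

    # Output tensors are omitted, so they are allocated automatically.
    return flashinfer.rope.mla_rope_quantize_fp8(
        q_rope, k_rope, q_nope, k_nope, cos_sin_cache, pos_ids,
        is_neox=True,
        quantize_dtype=torch.float8_e4m3fn,
        quant_scale_q=1.0,
        quant_scale_kv=1.0,
    )
```

Calling run_example() returns the four-tensor tuple described under Returns. To reuse buffers across calls, pass pre-allocated tensors through the *_out parameters instead of relying on automatic allocation.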