flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache
- flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache(q_rope: Tensor, k_rope: Tensor, q_nope: Tensor | None, k_nope: Tensor | None, v: Tensor | None, cos_sin_cache: Tensor, pos_ids: Tensor, paged_kv_cache: Tuple[Tensor, Tensor], kv_indices: Tensor, kv_indptr: Tensor, batch_indices: Tensor, positions: Tensor, is_neox: bool = True, quantize_dtype: dtype | None = None, quant_scale_q: float = 1.0, quant_scale_kv: float = 1.0, page_size: int = 16, kv_layout: str = 'NHD', q_rope_out: Tensor | None = None, q_nope_out: Tensor | None = None, enable_pdl: bool = False) → Tuple[Tensor, Tensor]
Apply RoPE (Rotary Positional Embeddings), quantize to FP8, and append K/V to a paged cache.
This fused function applies RoPE to query/key (Q/K) rotary dimension tensors, quantizes all Q/K tensors (and V for GQA/MHA) to FP8 format, and directly appends the quantized K/V to a paged KV cache. It returns quantized Q tensors for use in attention computation. Supports MLA, GQA, and MHA architectures with automatic detection based on input tensor shapes.
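The shape-based dispatch can be sketched in plain Python to show which cache tensors each architecture expects. This is not FlashInfer code: detect_architecture and expected_cache_shapes are hypothetical helpers, and the GQA/MHA relation head_dim = rope_dim + no_rope_dim is an assumption drawn from the note that the full K (RoPE + noRoPE) is written to k_cache.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def detect_architecture(k_rope_shape: Shape) -> str:
    """Documented rule: 2D K tensors mean MLA, 3D K tensors mean GQA/MHA."""
    if len(k_rope_shape) == 2:
        return "MLA"
    if len(k_rope_shape) == 3:
        return "GQA/MHA"
    raise ValueError(f"k_rope must be 2D or 3D, got shape {k_rope_shape}")


def expected_cache_shapes(k_rope_shape: Shape, k_nope_shape: Shape,
                          max_pages: int, page_size: int = 16,
                          kv_layout: str = "NHD") -> Tuple[Shape, Shape]:
    """Hypothetical helper: shapes of the two paged_kv_cache tensors."""
    rope_dim = k_rope_shape[-1]
    no_rope_dim = k_nope_shape[-1]
    if detect_architecture(k_rope_shape) == "MLA":
        # (ckv_cache, kpe_cache); kv_layout is ignored for MLA.
        return ((max_pages, page_size, no_rope_dim),
                (max_pages, page_size, rope_dim))
    num_kv_heads = k_rope_shape[1]
    head_dim = rope_dim + no_rope_dim  # assumed: full K = RoPE part + non-RoPE part
    if kv_layout == "NHD":
        shape = (max_pages, page_size, num_kv_heads, head_dim)
    elif kv_layout == "HND":
        shape = (max_pages, num_kv_heads, page_size, head_dim)
    else:
        raise ValueError(f"unknown kv_layout {kv_layout!r}")
    return shape, shape  # (k_cache, v_cache): v_cache matches k_cache
```

For example, an MLA call with k_rope of shape (7, 64) and k_nope of shape (7, 512) would need ckv_cache of shape (max_pages, 16, 512) and kpe_cache of shape (max_pages, 16, 64) at the default page_size.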
- Parameters:
q_rope (torch.Tensor) – Query tensor (rotary dimensions), shape: (nnz, num_qo_heads, rope_dim). Must be float16 or bfloat16.
k_rope (torch.Tensor) – Key tensor (rotary dimensions). For GQA/MHA: (nnz, num_kv_heads, rope_dim). For MLA: (nnz, rope_dim). Must be float16 or bfloat16.
q_nope (torch.Tensor) – Query tensor (non-rotary dimensions), shape: (nnz, num_qo_heads, no_rope_dim). Must be float16 or bfloat16.
k_nope (torch.Tensor) – Key tensor (non-rotary dimensions). For GQA/MHA: (nnz, num_kv_heads, no_rope_dim). For MLA: (nnz, no_rope_dim). Must be float16 or bfloat16.
v (Optional[torch.Tensor]) – Value tensor for GQA/MHA, shape: (nnz, num_kv_heads, head_dim). Must be float16 or bfloat16. For MLA, pass None (MLA does not use a separate V; the non-RoPE part of K acts as the compressed KV).
cos_sin_cache (torch.Tensor) – Precomputed cosine and sine values, shape: (max_seq_len, rope_dim). Along the last dimension, the first half holds the cosine values and the second half holds the sine values. Must be float32.
pos_ids (torch.Tensor) – Position index of each token, shape: (nnz,).
paged_kv_cache (Tuple[torch.Tensor, torch.Tensor]) –
- For MLA:
(ckv_cache, kpe_cache) where:
ckv_cache: (max_pages, page_size, no_rope_dim) in FP8
kpe_cache: (max_pages, page_size, rope_dim) in FP8
- For GQA/MHA:
(k_cache, v_cache) where:
k_cache: (max_pages, page_size, num_kv_heads, head_dim) or (max_pages, num_kv_heads, page_size, head_dim) depending on kv_layout, in FP8 (head_dim = rope_dim + no_rope_dim, since the full K is stored)
v_cache: same shape as k_cache, in FP8
kv_indices (torch.Tensor) – Page indices, shape: (total_pages,). Typically torch.arange(total_pages).
kv_indptr (torch.Tensor) – Page indptr array, shape: (batch_size + 1,). kv_indptr[i] is the offset in kv_indices where the pages of request i start, so request i owns pages kv_indices[kv_indptr[i]:kv_indptr[i + 1]].
batch_indices (torch.Tensor) – Request index of each token, shape: (nnz,). Maps each token to its request.
positions (torch.Tensor) – Position of each token within its request's sequence, shape: (nnz,).
is_neox (bool) – RoPE layout style. If True (default), use the non-interleaved layout (first/second half). If False, use the interleaved layout (even/odd dimensions).
quantize_dtype (Optional[torch.dtype]) – Target quantization dtype. If None, it is inferred from the output tensors or defaults to torch.float8_e4m3fn. Must be torch.float8_e4m3fn or torch.float8_e5m2.
quant_scale_q (float) – Quantization scaling factor for the query tensors. Default: 1.0.
quant_scale_kv (float) – Quantization scaling factor for the key/value tensors. Default: 1.0.
page_size (int) – Number of tokens per page in the paged cache. Default: 16.
kv_layout (str) – Cache memory layout for GQA/MHA. Options: "NHD" (page, seq, head, dim) or "HND" (page, head, seq, dim). Default: "NHD". Ignored for MLA.
q_rope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for the quantized query (rotary). If None, allocated automatically.
q_nope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for the quantized query (non-rotary). If None, allocated automatically.
enable_pdl (bool) – Whether to enable PDL (Programmatic Dependent Launch). Default: False.
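To make cos_sin_cache and is_neox concrete, here is a pure-Python reference that rotates one head vector. Assumptions: the cache is built with the conventional RoPE frequencies base**(-2i/rope_dim) with base = 10000 (this page does not say how the cache is computed), and both layouts use cos[i] and sin[i] for the i-th rotated pair. build_cos_sin_cache and apply_rope are illustrative names, not FlashInfer functions.

```python
import math
from typing import List


def build_cos_sin_cache(max_seq_len: int, rope_dim: int,
                        base: float = 10000.0) -> List[List[float]]:
    """Row p = [cos(p*f_0), ..., cos(p*f_{h-1}), sin(p*f_0), ..., sin(p*f_{h-1})]."""
    half = rope_dim // 2
    freqs = [base ** (-2.0 * i / rope_dim) for i in range(half)]  # assumed schedule
    cache = []
    for p in range(max_seq_len):
        cos_part = [math.cos(p * f) for f in freqs]
        sin_part = [math.sin(p * f) for f in freqs]
        cache.append(cos_part + sin_part)  # first half cos, second half sin
    return cache


def apply_rope(x: List[float], cos_sin_row: List[float],
               is_neox: bool = True) -> List[float]:
    """Rotate one rope_dim-long vector using one row of the cache."""
    half = len(x) // 2
    cos, sin = cos_sin_row[:half], cos_sin_row[half:]
    out = [0.0] * len(x)
    for i in range(half):
        if is_neox:
            a, b = i, i + half       # non-interleaved: pair element i with i + half
        else:
            a, b = 2 * i, 2 * i + 1  # interleaved: pair even/odd neighbours
        out[a] = x[a] * cos[i] - x[b] * sin[i]
        out[b] = x[b] * cos[i] + x[a] * sin[i]
    return out
```

At position 0 the rotation is the identity, and at every position it preserves the vector's norm, which makes a quick sanity check for a custom cos_sin_cache.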
- Returns:
Quantized query tensors: (q_rope_out, q_nope_out). K/V are written directly to the paged cache and not returned.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
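The returned Q tensors (and the K/V written to the cache) are FP8. The sketch below simulates a float8_e4m3fn cast in pure Python so the effect of quant_scale_q and quant_scale_kv can be inspected. It assumes the scale multiplies the value before the cast and that out-of-range values saturate at ±448; neither detail is stated on this page, so check the kernel before relying on either. float8_e5m2 has a different mantissa width and range and is not covered.

```python
import math
from typing import List

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def round_to_e4m3fn(v: float) -> float:
    """Round to the nearest float8_e4m3fn value (ties to even), saturating at +-448."""
    if v == 0.0 or math.isnan(v):
        return v
    sign = -1.0 if v < 0 else 1.0
    a = min(abs(v), E4M3_MAX)
    # Binade exponent, floored at the minimum normal exponent (-6);
    # below 2**-6 values are subnormal with a fixed step of 2**-9.
    e = max(math.floor(math.log2(a)), -6)
    step = 2.0 ** (e - 3)            # 3 explicit mantissa bits
    q = round(a / step) * step       # Python's round() is round-half-to-even
    return sign * min(q, E4M3_MAX)


def quantize(values: List[float], quant_scale: float = 1.0) -> List[float]:
    """Sketch of scale-then-cast (assumes the scale is applied multiplicatively)."""
    return [round_to_e4m3fn(x * quant_scale) for x in values]
```

Because e4m3fn keeps only 3 mantissa bits, 0.3 comes back as 0.3125. Since floating-point spacing is relative, the scale mainly matters at the extremes: too small and values fall into the coarse subnormal range below 2**-6, too large and they clip at 448.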
Notes
Architecture detection: Automatically distinguishes MLA (2D K tensors) from GQA/MHA (3D K tensors).
MLA writes K-RoPE to kpe_cache and K-noRoPE to ckv_cache; V is not used.
GQA/MHA writes the full K (RoPE + noRoPE) to k_cache and V to v_cache.
The batch_indices and positions tensors are typically obtained from flashinfer.get_batch_indices_positions().
Cache tensors must already be allocated in the target FP8 dtype.
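The page bookkeeping can be traced in pure Python. batch_indices_positions is a hypothetical stand-in for flashinfer.get_batch_indices_positions(), assuming each request's new tokens occupy the last positions of its sequence (with seq_lens already counting them). page_and_slot assumes request b owns kv_indices[kv_indptr[b]:kv_indptr[b + 1]], and flat_offset shows where one element lands in a contiguous GQA/MHA cache under each kv_layout. None of these helpers are FlashInfer APIs.

```python
from typing import List, Tuple


def batch_indices_positions(append_indptr: List[int],
                            seq_lens: List[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical stand-in: the j-th new token of request b sits at
    position seq_lens[b] - append_len + j (seq_lens include the new tokens)."""
    batch_indices: List[int] = []
    positions: List[int] = []
    for b, seq_len in enumerate(seq_lens):
        append_len = append_indptr[b + 1] - append_indptr[b]
        for j in range(append_len):
            batch_indices.append(b)
            positions.append(seq_len - append_len + j)
    return batch_indices, positions


def page_and_slot(b: int, pos: int, kv_indices: List[int],
                  kv_indptr: List[int], page_size: int) -> Tuple[int, int]:
    """Map (request, position) to (physical page, slot within that page)."""
    page = kv_indices[kv_indptr[b] + pos // page_size]
    return page, pos % page_size


def flat_offset(page: int, slot: int, head: int, dim: int, num_kv_heads: int,
                page_size: int, head_dim: int, kv_layout: str = "NHD") -> int:
    """Element offset of (page, slot, head, dim) in a contiguous k_cache/v_cache."""
    if kv_layout == "NHD":  # (page, slot, head, dim)
        return ((page * page_size + slot) * num_kv_heads + head) * head_dim + dim
    if kv_layout == "HND":  # (page, head, slot, dim)
        return ((page * num_kv_heads + head) * page_size + slot) * head_dim + dim
    raise ValueError(f"unknown kv_layout {kv_layout!r}")
```

With append_indptr = [0, 2, 5] and seq_lens = [6, 3], the five new tokens get batch_indices [0, 0, 1, 1, 1] and positions [4, 5, 0, 1, 2]. With page_size = 4, kv_indptr = [0, 2, 3] and kv_indices = [5, 2, 7], request 0's two tokens land in slots 0 and 1 of physical page 2, showing why kv_indices need not be torch.arange(total_pages).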