flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache

flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache(q_rope: Tensor, k_rope: Tensor, q_nope: Tensor | None, k_nope: Tensor | None, v: Tensor | None, cos_sin_cache: Tensor, pos_ids: Tensor, paged_kv_cache: Tuple[Tensor, Tensor], kv_indices: Tensor, kv_indptr: Tensor, batch_indices: Tensor, positions: Tensor, is_neox: bool = True, quantize_dtype: dtype | None = None, quant_scale_q: float = 1.0, quant_scale_kv: float = 1.0, page_size: int = 16, kv_layout: str = 'NHD', q_rope_out: Tensor | None = None, q_nope_out: Tensor | None = None, enable_pdl: bool = False) → Tuple[Tensor, Tensor]

Apply RoPE (Rotary Positional Embeddings), quantize to FP8, and append K/V to a paged KV cache.

This fused function applies RoPE to the rotary-dimension query and key (Q/K) tensors, quantizes all Q/K tensors (and V for GQA/MHA) to FP8, and appends the quantized K/V directly to a paged KV cache. It returns the quantized Q tensors for use in attention computation. MLA, GQA, and MHA architectures are supported; the architecture is detected automatically from the input tensor shapes.
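
To make the RoPE step concrete, here is a minimal pure-Python sketch of the rotation for a single vector. It follows the cos_sin_cache layout (cosine in the first half, sine in the second half) and the two is_neox pairing styles described in the parameter list. The helper names build_cos_sin_row and apply_rope, and the base of 10000 used for the frequencies, are assumptions made for illustration; this is not the fused kernel and it omits quantization and the cache write.

```python
import math


def build_cos_sin_row(pos, rope_dim, base=10000.0):
    """One row of a cos/sin cache for position `pos`: [cos | sin], length rope_dim.

    The frequency schedule (base=10000) is an assumption for this sketch.
    """
    half = rope_dim // 2
    angles = [pos / (base ** (2 * i / rope_dim)) for i in range(half)]
    return [math.cos(a) for a in angles] + [math.sin(a) for a in angles]


def apply_rope(x, cos_sin_row, is_neox=True):
    """Rotate one vector of length rope_dim by the angles stored in cos_sin_row."""
    half = len(x) // 2
    cos, sin = cos_sin_row[:half], cos_sin_row[half:]
    out = [0.0] * len(x)
    for i in range(half):
        # Non-interleaved (is_neox=True) pairs element i with i + half;
        # interleaved (is_neox=False) pairs element 2i with 2i + 1.
        a, b = (i, i + half) if is_neox else (2 * i, 2 * i + 1)
        out[a] = x[a] * cos[i] - x[b] * sin[i]
        out[b] = x[b] * cos[i] + x[a] * sin[i]
    return out


# Example: rotate the same vector with both pairing styles at position 1.
row = build_cos_sin_row(pos=1, rope_dim=4)
neox_out = apply_rope([1.0, 0.0, 0.0, 0.0], row, is_neox=True)
interleaved_out = apply_rope([1.0, 0.0, 0.0, 0.0], row, is_neox=False)
```

In both styles each pair of elements is rotated by the same angle, so the vector norm is unchanged; only the choice of which elements form a pair differs.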

Parameters:
  • q_rope (torch.Tensor) – Query tensor (rotary dimensions), shape: (nnz, num_qo_heads, rope_dim). Must be float16 or bfloat16.

  • k_rope (torch.Tensor) – Key tensor (rotary dimensions). For GQA/MHA: (nnz, num_kv_heads, rope_dim). For MLA: (nnz, rope_dim). Must be float16 or bfloat16.

  • q_nope (Optional[torch.Tensor]) – Query tensor (non-rotary dimensions), shape: (nnz, num_qo_heads, no_rope_dim). Must be float16 or bfloat16.

  • k_nope (Optional[torch.Tensor]) – Key tensor (non-rotary dimensions). For GQA/MHA: (nnz, num_kv_heads, no_rope_dim). For MLA: (nnz, no_rope_dim). Must be float16 or bfloat16.

  • v (Optional[torch.Tensor]) – Value tensor for GQA/MHA: (nnz, num_kv_heads, head_dim). Must be float16 or bfloat16. For MLA: pass None (MLA does not use separate V; K non-RoPE acts as compressed KV).

  • cos_sin_cache (torch.Tensor) – Precomputed cosine and sine values, shape: (max_seq_len, rope_dim). First half contains cosine values, second half contains sine values. Must be float32.

  • pos_ids (torch.Tensor) – Position indices for each token, shape: (nnz,).

  • paged_kv_cache (Tuple[torch.Tensor, torch.Tensor]) –

    For MLA: (ckv_cache, kpe_cache) where:
    • ckv_cache: (max_pages, page_size, no_rope_dim) in FP8

    • kpe_cache: (max_pages, page_size, rope_dim) in FP8

    For GQA/MHA: (k_cache, v_cache) where:
    • k_cache: (max_pages, page_size, num_kv_heads, head_dim) or (max_pages, num_kv_heads, page_size, head_dim) depending on layout, in FP8

    • v_cache: same shape as k_cache, in FP8

  • kv_indices (torch.Tensor) – Page indices mapping, shape: (total_pages,). Typically torch.arange(total_pages).

  • kv_indptr (torch.Tensor) – Page indptr array for each request, shape: (batch_size + 1,). kv_indptr[i] is the offset into kv_indices at which the pages of request i begin.

  • batch_indices (torch.Tensor) – Batch index for each token, shape: (nnz,). Maps each token to its request.

  • positions (torch.Tensor) – Position within each request’s sequence for each token, shape: (nnz,).

  • is_neox (bool) – RoPE layout style. If True (default), use non-interleaved layout (first/second half). If False, use interleaved layout (even/odd dimensions).

  • quantize_dtype (Optional[torch.dtype]) – Target quantization dtype. If None, inferred from output tensors or defaults to torch.float8_e4m3fn. Must be torch.float8_e4m3fn or torch.float8_e5m2.

  • quant_scale_q (float) – Quantization scaling factor for query tensors, default: 1.0.

  • quant_scale_kv (float) – Quantization scaling factor for key/value tensors, default: 1.0.

  • page_size (int) – Number of entries per page in the paged cache, default: 16.

  • kv_layout (str) – Cache memory layout for GQA/MHA. Options: "NHD" (page, seq, head, dim) or "HND" (page, head, seq, dim). Default: "NHD". Ignored for MLA.

  • q_rope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for quantized query (rotary). If None, allocated automatically.

  • q_nope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for quantized query (non-rotary). If None, allocated automatically.

  • enable_pdl (bool) – Whether to enable PDL (Programmatic Dependent Launch). Default: False.
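
The indexing parameters above (kv_indices, kv_indptr, batch_indices, positions, page_size) together determine where each token lands in the paged cache. The sketch below shows the usual paged-KV addressing rule in plain Python, assuming the conventional scheme in which a token at position p of a request goes to that request's (p // page_size)-th page, at slot p % page_size. The helper locate_slot is hypothetical and the fused kernel's actual indexing is not shown in this reference.

```python
def locate_slot(token, batch_indices, positions, kv_indptr, kv_indices, page_size=16):
    """Return (page_id, slot_in_page) for one token (hypothetical helper)."""
    request = batch_indices[token]
    pos = positions[token]
    # kv_indptr[request] is where this request's pages start inside kv_indices.
    page_id = kv_indices[kv_indptr[request] + pos // page_size]
    return page_id, pos % page_size


# Two requests: request 0 owns pages [3, 0], request 1 owns page [2].
kv_indices = [3, 0, 2]
kv_indptr = [0, 2, 3]
# Three tokens to append: two for request 0, one for request 1.
batch_indices = [0, 0, 1]
positions = [15, 16, 4]

slots = [
    locate_slot(t, batch_indices, positions, kv_indptr, kv_indices)
    for t in range(len(positions))
]
```

With page_size=16, position 15 is the last slot of request 0's first page and position 16 is the first slot of its second page, which is why the two tokens of request 0 end up on different pages.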

Returns:

Quantized query tensors: (q_rope_out, q_nope_out). K/V are written directly to the paged cache and not returned.

Return type:

Tuple[torch.Tensor, torch.Tensor]

Notes

  • Architecture detection: Automatically distinguishes MLA (2D K tensors) from GQA/MHA (3D K tensors).

  • MLA writes K-RoPE to kpe_cache and K-noRoPE to ckv_cache; V is not used.

  • GQA/MHA writes full K (RoPE+noRoPE) to k_cache and V to v_cache.

  • The batch_indices and positions tensors are typically obtained from flashinfer.get_batch_indices_positions().

  • Cache tensors must already be allocated in the target FP8 dtype.
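
To illustrate what the batch_indices and positions tensors contain, here is a plain-Python sketch that derives them from per-request lengths. It assumes that seq_lens holds each request's total length after the append, so the new tokens occupy the last positions of the sequence. The helper batch_indices_positions is hypothetical and does not reproduce the signature of flashinfer.get_batch_indices_positions().

```python
def batch_indices_positions(append_lens, seq_lens):
    """Per-token request index and in-sequence position (hypothetical helper).

    append_lens[i]: number of new tokens for request i.
    seq_lens[i]: total length of request i after the append (assumption).
    """
    batch_indices, positions = [], []
    for request, (n_new, total) in enumerate(zip(append_lens, seq_lens)):
        start = total - n_new  # tokens already cached before this append
        for j in range(n_new):
            batch_indices.append(request)
            positions.append(start + j)
    return batch_indices, positions


# Request 0 is a fresh 2-token prefill; request 1 appends 3 tokens to 7 cached ones.
b_idx, pos = batch_indices_positions(append_lens=[2, 3], seq_lens=[2, 10])
```

Both lists have length nnz, the total number of appended tokens, matching the (nnz,) shapes listed for batch_indices and positions.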