flashinfer.rope.rope_quantize_fp8

flashinfer.rope.rope_quantize_fp8(q_rope: Tensor, k_rope: Tensor, q_nope: Tensor | None, k_nope: Tensor | None, cos_sin_cache: Tensor, pos_ids: Tensor, is_neox: bool = True, quantize_dtype: dtype | None = None, quant_scale_q: float = 1.0, quant_scale_kv: float = 1.0, q_rope_out: Tensor | None = None, k_rope_out: Tensor | None = None, q_nope_out: Tensor | None = None, k_nope_out: Tensor | None = None, enable_pdl: bool = False) → Tuple[Tensor, Tensor, Tensor, Tensor]

Apply RoPE (Rotary Positional Embeddings) and quantize to FP8 format.

This function takes query/key tensors that have already been split into rotary (rope) and non-rotary (nope) parts, applies RoPE to the rotary tensors, and quantizes both the rotary and non-rotary tensors to FP8. It supports MLA, GQA, and MHA architectures.
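To illustrate what the FP8 step does to the values, the following pure-Python sketch simulates rounding to torch.float8_e4m3fn (4 exponent bits, 3 mantissa bits, largest finite value 448) after applying a scale. It is not FlashInfer's kernel: treating the scale as a multiplier applied before the cast and saturating out-of-range values at ±448 are assumptions made for illustration, and round_to_e4m3fn / quantize_vec are hypothetical helper names.

```python
import math

E4M3_MAX = 448.0                  # largest finite float8_e4m3fn value
E4M3_MIN_NORMAL = 2.0 ** -6       # smallest normal float8_e4m3fn value
E4M3_SUBNORMAL_STEP = 2.0 ** -9   # spacing between subnormal values


def round_to_e4m3fn(x: float) -> float:
    """Round x to the nearest float8_e4m3fn value (round-half-to-even).

    Assumption: magnitudes >= 448 saturate to +/-448 instead of becoming NaN.
    """
    if math.isnan(x):
        return x
    mag = abs(x)
    if mag >= E4M3_MAX:
        return math.copysign(E4M3_MAX, x)
    if mag < E4M3_MIN_NORMAL:
        step = E4M3_SUBNORMAL_STEP
    else:
        # frexp gives mag = m * 2**e with 0.5 <= m < 1, so mag lies in
        # [2**(e-1), 2**e); 3 mantissa bits give a spacing of 2**(e-4).
        _, e = math.frexp(mag)
        step = 2.0 ** (e - 4)
    rounded = round(mag / step) * step  # Python's round() is half-to-even
    return math.copysign(min(rounded, E4M3_MAX), x)


def quantize_vec(values, scale=1.0):
    """Hypothetical helper: multiply each value by scale, then round to e4m3fn."""
    return [round_to_e4m3fn(v * scale) for v in values]


# One token and head: the rotary part (already rotated) and the non-rotary part
# are quantized the same way.
rope_part = [0.3, -1.0625, 3.3]
nope_part = [1000.0, 0.001]
print(quantize_vec(rope_part))  # → [0.3125, -1.0, 3.25]
print(quantize_vec(nope_part))  # → [448.0, 0.001953125]
```

With only 3 mantissa bits, neighbouring representable values in [1, 2) are 0.125 apart, which is why quant_scale_q and quant_scale_kv matter: they move the data into the range where this coarse grid loses the least information.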

Parameters:
  • q_rope (torch.Tensor) – Query tensor (rotary dimensions), shape: (nnz, num_qo_heads, rope_dim). Must be float16 or bfloat16.

  • k_rope (torch.Tensor) – Key tensor (rotary dimensions). For GQA/MHA: (nnz, num_kv_heads, rope_dim). For MLA: (nnz, rope_dim). Must be float16 or bfloat16.

  • q_nope (Optional[torch.Tensor]) – Query tensor (non-rotary dimensions), shape: (nnz, num_qo_heads, no_rope_dim). If None, it is treated as having no non-rotary dimensions, and a size-0 tensor is created internally.

  • k_nope (Optional[torch.Tensor]) – Key tensor (non-rotary dimensions). For GQA/MHA: (nnz, num_kv_heads, no_rope_dim). For MLA: (nnz, no_rope_dim). If None, it is treated as having no non-rotary dimensions, and a size-0 tensor is created internally.

  • cos_sin_cache (torch.Tensor) – Precomputed cosine and sine values, shape: (max_seq_len, rope_dim). Along the last dimension, the first rope_dim // 2 entries hold the cosine values and the last rope_dim // 2 entries hold the sine values. Must be float32.

  • pos_ids (torch.Tensor) – Position indices for each token, shape: (nnz,).

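  • is_neox (bool) – RoPE layout style. If True (default), use the non-interleaved (NeoX-style) layout, which pairs each dimension in the first half of the rotary dimensions with the corresponding dimension in the second half. If False, use the interleaved layout, which pairs each even dimension with the following odd dimension.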

  • quantize_dtype (Optional[torch.dtype]) – Target quantization dtype: torch.float8_e4m3fn or torch.float8_e5m2. If None, it is inferred from the provided output tensors, falling back to torch.float8_e4m3fn.

  • quant_scale_q (float) – Quantization scaling factor for the query tensors (q_rope and q_nope). Default: 1.0.

  • quant_scale_kv (float) – Quantization scaling factor for the key tensors (k_rope and k_nope). Default: 1.0.

  • q_rope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for quantized query (rotary). If None, allocated automatically.

  • k_rope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for quantized key (rotary). If None, allocated automatically.

  • q_nope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for quantized query (non-rotary). If None, allocated automatically.

  • k_nope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for quantized key (non-rotary). If None, allocated automatically.

  • enable_pdl (bool) – Whether to enable PDL (Programmatic Dependent Launch). Default: False.
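The cos_sin_cache layout and the effect of is_neox can be illustrated with a pure-Python reference for a single rotary vector. Each cache row stores rope_dim // 2 cosines followed by rope_dim // 2 sines, as described above; here the cache is a nested list rather than a float32 tensor. The inverse-frequency schedule with base 10000 is an assumption (the common RoPE choice), not something this function prescribes, and build_cos_sin_cache / apply_rope are hypothetical helper names.

```python
import math


def build_cos_sin_cache(max_seq_len, rope_dim, base=10000.0):
    """Rows of [cos_0..cos_{h-1}, sin_0..sin_{h-1}] with h = rope_dim // 2.

    Assumption: standard RoPE frequencies inv_freq[i] = base ** (-2i / rope_dim).
    """
    half = rope_dim // 2
    inv_freq = [base ** (-2.0 * i / rope_dim) for i in range(half)]
    cache = []
    for pos in range(max_seq_len):
        angles = [pos * f for f in inv_freq]
        cache.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return cache


def apply_rope(x, cache_row, is_neox=True):
    """Rotate one rotary vector x (length rope_dim) by the angles in cache_row."""
    half = len(x) // 2
    cos, sin = cache_row[:half], cache_row[half:]
    out = list(x)
    for i in range(half):
        if is_neox:
            # Non-interleaved: pair dimension i with dimension i + half.
            a, b = i, i + half
        else:
            # Interleaved: pair dimension 2i with dimension 2i + 1.
            a, b = 2 * i, 2 * i + 1
        # 2-D rotation of the pair (x[a], x[b]) by angle i of this position.
        out[a] = x[a] * cos[i] - x[b] * sin[i]
        out[b] = x[b] * cos[i] + x[a] * sin[i]
    return out


cache = build_cos_sin_cache(max_seq_len=8, rope_dim=4)
x = [1.0, 2.0, 3.0, 4.0]
print(apply_rope(x, cache[0]))                 # position 0: unchanged
print(apply_rope(x, cache[3], is_neox=True))   # NeoX-style pairing
print(apply_rope(x, cache[3], is_neox=False))  # interleaved pairing
```

Both layouts read the same cache; only the pairing of dimensions differs, so is_neox has to match the convention the model's weights were trained with.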

Returns:

Quantized tensors: (q_rope_out, k_rope_out, q_nope_out, k_nope_out).

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]