flashinfer.rope.rope_quantize_fp8
- flashinfer.rope.rope_quantize_fp8(q_rope: Tensor, k_rope: Tensor, q_nope: Tensor | None, k_nope: Tensor | None, cos_sin_cache: Tensor, pos_ids: Tensor, is_neox: bool = True, quantize_dtype: dtype | None = None, quant_scale_q: float = 1.0, quant_scale_kv: float = 1.0, q_rope_out: Tensor | None = None, k_rope_out: Tensor | None = None, q_nope_out: Tensor | None = None, k_nope_out: Tensor | None = None, enable_pdl: bool = False) → Tuple[Tensor, Tensor, Tensor, Tensor]
Apply RoPE (Rotary Positional Embeddings) and quantize to FP8 format.
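This function takes pre-split query/key tensors (rotary and non-rotary dimensions already separated), applies RoPE to the rotary-dimension tensors, and quantizes both the rotary and non-rotary tensors to FP8. It supports MLA (Multi-head Latent Attention), GQA (Grouped-Query Attention), and MHA (Multi-Head Attention) layouts; for MLA, the key tensors have no head dimension.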
- Parameters:
q_rope (torch.Tensor) – Query tensor (rotary dimensions), shape: (nnz, num_qo_heads, rope_dim), where nnz is the number of tokens. Must be float16 or bfloat16.
k_rope (torch.Tensor) – Key tensor (rotary dimensions). Shape for GQA/MHA: (nnz, num_kv_heads, rope_dim); for MLA: (nnz, rope_dim). Must be float16 or bfloat16.
q_nope (Optional[torch.Tensor]) – Query tensor (non-rotary dimensions), shape: (nnz, num_qo_heads, no_rope_dim). If None, it is treated as zero-dim and a size-0 tensor is created internally.
k_nope (Optional[torch.Tensor]) – Key tensor (non-rotary dimensions). Shape for GQA/MHA: (nnz, num_kv_heads, no_rope_dim); for MLA: (nnz, no_rope_dim). If None, it is treated as zero-dim and created internally.
cos_sin_cache (torch.Tensor) – Precomputed cosine and sine values, shape: (max_seq_len, rope_dim). The first rope_dim // 2 columns hold cosine values and the last rope_dim // 2 columns hold sine values. Must be float32.
pos_ids (torch.Tensor) – Position index of each token, shape: (nnz,). Each entry selects a row of cos_sin_cache.
is_neox (bool) – RoPE layout style. If True (default), use the non-interleaved layout, which pairs dimensions from the first and second halves. If False, use the interleaved layout, which pairs even and odd dimensions.
quantize_dtype (Optional[torch.dtype]) – Target quantization dtype. If None, it is inferred from the output tensors or defaults to torch.float8_e4m3fn. Must be torch.float8_e4m3fn or torch.float8_e5m2.
quant_scale_q (float) – Quantization scaling factor for the query tensors. Default: 1.0.
quant_scale_kv (float) – Quantization scaling factor for the key tensors. Default: 1.0.
q_rope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for the quantized query (rotary). If None, it is allocated automatically.
k_rope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for the quantized key (rotary). If None, it is allocated automatically.
q_nope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for the quantized query (non-rotary). If None, it is allocated automatically.
k_nope_out (Optional[torch.Tensor]) – Pre-allocated output tensor for the quantized key (non-rotary). If None, it is allocated automatically.
enable_pdl (bool) – Whether to enable PDL (Programmatic Dependent Launch). Default: False.
- Returns:
Quantized tensors: (q_rope_out, k_rope_out, q_nope_out, k_nope_out).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
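The caller is responsible for building cos_sin_cache. The following minimal sketch builds a float32 cache in the documented layout (cosines first, sines second) and then passes MLA-shaped inputs (2-D keys) to the function using the signature documented above. It assumes the conventional RoPE frequencies theta_i = base^(-2i / rope_dim) with base = 10000; use the model's own RoPE base in practice. The helper names `build_cos_sin_cache` and `mla_example` are hypothetical, the tensor sizes are illustrative, and the int32 dtype for pos_ids is an assumption. Calling `mla_example` requires a CUDA GPU with flashinfer installed.

```python
import torch


def build_cos_sin_cache(max_seq_len: int, rope_dim: int, base: float = 10000.0) -> torch.Tensor:
    """Return a float32 (max_seq_len, rope_dim) cache: cosines first, sines second."""
    # ASSUMED frequencies: theta_i = base^(-2i / rope_dim), i = 0 .. rope_dim // 2 - 1.
    inv_freq = 1.0 / (base ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)  # (max_seq_len, rope_dim // 2)
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)


def mla_example(nnz: int = 8, num_qo_heads: int = 16, rope_dim: int = 64,
                no_rope_dim: int = 512):
    """Hypothetical MLA call; sizes are illustrative. Needs CUDA and flashinfer."""
    import flashinfer

    dev, dt = "cuda", torch.bfloat16
    q_rope = torch.randn(nnz, num_qo_heads, rope_dim, device=dev, dtype=dt)
    k_rope = torch.randn(nnz, rope_dim, device=dev, dtype=dt)  # MLA: no head dim
    q_nope = torch.randn(nnz, num_qo_heads, no_rope_dim, device=dev, dtype=dt)
    k_nope = torch.randn(nnz, no_rope_dim, device=dev, dtype=dt)  # MLA: no head dim
    cos_sin_cache = build_cos_sin_cache(4096, rope_dim).to(dev)  # stays float32
    pos_ids = torch.arange(nnz, device=dev, dtype=torch.int32)  # dtype is an assumption
    # Returns (q_rope_out, k_rope_out, q_nope_out, k_nope_out) as documented above.
    return flashinfer.rope.rope_quantize_fp8(
        q_rope, k_rope, q_nope, k_nope, cos_sin_cache, pos_ids,
        is_neox=True, quantize_dtype=torch.float8_e4m3fn,
    )
```

Because the function indexes the cache by pos_ids, the cache needs at least max(pos_ids) + 1 rows.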