flashinfer.rope.apply_llama31_rope_inplace

flashinfer.rope.apply_llama31_rope_inplace(q: torch.Tensor, k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, rotary_dim: int | None = None, interleave: bool = False, rope_scale: float = 8, rope_theta: float = 500000.0, low_freq_factor: float = 1, high_freq_factor: float = 4, old_context_len: int = 8192) -> None

Apply Llama 3.1-style rotary embedding (RoPE) in place to a batch of queries/keys stored as ragged tensors. The cos/sin values are computed on the fly inside the kernel.

We use indptr to denote the start pointer of each segment in the batch: the query of the i-th segment is q[indptr[i]:indptr[i+1]], and the key of the i-th segment is k[indptr[i]:indptr[i+1]]. The first element of indptr is always 0, and the last element is the total number of queries/keys in the batch. Please see the Ragged Tensor tutorial for more details about ragged tensors.
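Because indptr is an exclusive prefix sum of the segment lengths, it can be built and queried in a few lines. The pure-Python sketch below is illustrative only: build_indptr and segment_of are hypothetical helpers, not part of FlashInfer. With tensors, the same values come from torch.cumsum over the lengths with a leading 0 prepended; the example on this page stores indptr as an int32 CUDA tensor.

```python
from bisect import bisect_right
from itertools import accumulate


def build_indptr(seq_lens: list[int]) -> list[int]:
    """Exclusive prefix sum of segment lengths: indptr[0] == 0, indptr[-1] == nnz."""
    return [0, *accumulate(seq_lens)]


def segment_of(indptr: list[int], row: int) -> int:
    """Return the segment i such that indptr[i] <= row < indptr[i + 1]."""
    if not 0 <= row < indptr[-1]:
        raise IndexError(f"row {row} out of range [0, {indptr[-1]})")
    # bisect_right skips empty segments (repeated indptr values) correctly.
    return bisect_right(indptr, row) - 1


# Three sequences of different lengths packed into one ragged batch.
seq_lens = [3, 1, 4]
indptr = build_indptr(seq_lens)  # [0, 3, 4, 8]
rows = list(range(indptr[-1]))   # stand-in for the nnz rows of q or k
segments = [rows[indptr[i]:indptr[i + 1]] for i in range(len(seq_lens))]
print(segments)  # [[0, 1, 2], [3], [4, 5, 6, 7]]
```

Segment i then owns rows indptr[i] through indptr[i+1] - 1 of both q and k, which is why q and k must share the same indptr.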

Parameters:
  • q (torch.Tensor) – Query ragged tensor, shape: (nnz, num_q_heads, head_dim), where nnz is the last element of indptr.

  • k (torch.Tensor) – Key ragged tensor, shape: (nnz, num_k_heads, head_dim), where nnz is the last element of indptr.

  • indptr (torch.Tensor) – Indptr tensor, shape: (batch_size + 1).

  • offsets (torch.Tensor) – The position offset of each segment in the batch, shape: (batch_size). Row j of segment i (i.e., q[indptr[i] + j] and k[indptr[i] + j]) is rotated as position offsets[i] + j.

  • rotary_dim (Optional[int]) – The number of leading dimensions to apply RoPE to. If None, RoPE is applied to the entire head dimension; otherwise, it is applied to the first rotary_dim dimensions, default: None.

  • interleave (bool) –

    Whether to use interleaved layout in the last dimension, default: False.

    • If True, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ([..., ::2]) and odd dimensions ([..., 1::2]).

    • If False, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half of the dimensions ([..., :head_dim//2]) and the second half ([..., head_dim//2:]).

  • rope_scale (float) – The scaling factor used in the RoPE embedding, default: 8.

  • rope_theta (float) – The theta value used in the RoPE embedding, default: 5e5.

  • low_freq_factor (float) – The low frequency factor used in Llama 3.1 RoPE, default: 1.

  • high_freq_factor (float) – The high frequency factor used in Llama 3.1 RoPE, default: 4.

  • old_context_len (int) – The old context length used in Llama 3.1 RoPE, default: 8192.
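rope_scale, low_freq_factor, high_freq_factor and old_context_len together define Llama 3.1's frequency rescaling. As a sketch, assuming the kernel follows the scaling rule in Meta's Llama 3.1 reference implementation, the per-pair inverse frequencies would be computed as below. The function llama31_inv_freq is hypothetical and not a FlashInfer API; it also assumes frequencies are taken relative to rotary_dim.

```python
import math


def llama31_inv_freq(
    rotary_dim: int,
    rope_theta: float = 500000.0,
    rope_scale: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    old_context_len: int = 8192,
) -> list[float]:
    """Per-pair inverse frequencies with Llama 3.1 scaling (reference sketch)."""
    # Standard RoPE: one inverse frequency per rotated pair, theta^(-2i / rotary_dim).
    base = [rope_theta ** (-2 * i / rotary_dim) for i in range(rotary_dim // 2)]
    low_freq_wavelen = old_context_len / low_freq_factor    # longer: fully scaled
    high_freq_wavelen = old_context_len / high_freq_factor  # shorter: untouched
    out = []
    for freq in base:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            out.append(freq)  # high-frequency band: keep as-is
        elif wavelen > low_freq_wavelen:
            out.append(freq / rope_scale)  # low-frequency band: divide by rope_scale
        else:
            # Transition band: blend linearly in old_context_len / wavelen.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            out.append((1 - smooth) * freq / rope_scale + smooth * freq)
    return out


scaled = llama31_inv_freq(128)
```

With the defaults, components whose wavelength is under 8192 / 4 = 2048 positions keep their frequency, those above 8192 / 1 = 8192 positions are slowed down by rope_scale = 8, and the band in between is interpolated between the two.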

Examples

>>> import torch
>>> import flashinfer
>>> batch_size = 128
>>> qkv_len = 1024
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> nnz = batch_size * qkv_len
>>> qkv_packed = torch.randn(
...    nnz,
...    (num_qo_heads + 2 * num_kv_heads) * head_dim,
...    dtype=torch.float16,
...    device="cuda:0",
... )
>>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
>>> k = qkv_packed[
...    :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
... ].reshape(nnz, num_kv_heads, head_dim)
>>> indptr = torch.tensor(
...    [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
>>> )
>>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
>>> flashinfer.apply_llama31_rope_inplace(q, k, indptr, offsets)
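To sanity-check results on small inputs, a slow pure-Python reference can help. The sketch below assumes that row j of segment i is rotated to position offsets[i] + j and that dimensions are paired as described under interleave; token_positions and rope_rotate are hypothetical helpers, not FlashInfer functions. It takes inv_freq as an argument, so it can be fed plain RoPE frequencies (as in the demo) or Llama 3.1-scaled ones.

```python
import math
from typing import Optional


def token_positions(indptr: list[int], offsets: list[int]) -> list[int]:
    """Position of every packed row: offsets[i] + j for row j of segment i."""
    positions = []
    for i, start in enumerate(offsets):
        seg_len = indptr[i + 1] - indptr[i]
        positions.extend(start + j for j in range(seg_len))
    return positions


def rope_rotate(
    x: list[float],
    pos: int,
    inv_freq: list[float],
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
) -> list[float]:
    """Return a rotated copy of one head vector x at position pos.

    inv_freq must hold rotary_dim // 2 values, one per rotated pair.
    """
    d = len(x) if rotary_dim is None else rotary_dim
    half = d // 2
    out = list(x)  # dimensions >= rotary_dim are copied through unchanged
    for i in range(half):
        # Pair layout: (2i, 2i + 1) if interleaved, else (i, i + half).
        a, b = (2 * i, 2 * i + 1) if interleave else (i, i + half)
        angle = pos * inv_freq[i]
        c, s = math.cos(angle), math.sin(angle)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[b] * c + x[a] * s
    return out


# Two segments of lengths 2 and 3 whose first tokens sit at positions 10 and 0.
print(token_positions([0, 2, 5], [10, 0]))  # [10, 11, 0, 1, 2]

head_dim, theta = 8, 500000.0
# Plain (unscaled) RoPE inverse frequencies, used here only for the demo.
inv_freq = [theta ** (-2 * i / head_dim) for i in range(head_dim // 2)]
x = [0.1 * (k + 1) for k in range(head_dim)]
y = rope_rotate(x, pos=10, inv_freq=inv_freq)
```

To compare against the kernel, one could copy a row before the call (for example q[r, h].float().tolist()), run apply_llama31_rope_inplace, and compare the same row afterwards using Llama 3.1-scaled frequencies and a tolerance that allows for float16 rounding.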