flashinfer.rope.apply_rope#

flashinfer.rope.apply_rope(q: torch.Tensor, k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, interleave: bool = False, rope_scale: float = 1, rope_theta: float = 10000.0) → Tuple[torch.Tensor, torch.Tensor]#

Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor).

We use indptr to denote the start pointer of each segment in the batch: the query of the i-th segment is q[indptr[i]:indptr[i+1]] and the key of the i-th segment is k[indptr[i]:indptr[i+1]]. The first element of indptr is always 0, and its last element is the total number of queries/keys in the batch. Please see the Ragged Tensor tutorial for more details about the ragged tensor layout; a toy sketch of this indexing follows.
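
For illustration only, here is a toy sketch of how indptr delimits the segments of such a ragged batch. It uses plain PyTorch with made-up segment lengths and head sizes, not flashinfer itself:

>>> import torch
>>> indptr = torch.tensor([0, 2, 5, 9])      # 3 segments of lengths 2, 3 and 4
>>> nnz = int(indptr[-1])                    # total number of tokens in the batch: 9
>>> q = torch.randn(nnz, 8, 64)              # (nnz, num_q_heads, head_dim)
>>> q[indptr[1] : indptr[2]].shape           # queries of segment i = 1
torch.Size([3, 8, 64])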

Parameters:
  • q (torch.Tensor) – Query ragged tensor, shape: (nnz, num_q_heads, head_dim), where nnz is the last element of indptr.

  • k (torch.Tensor) – Key ragged tensor, shape: (nnz, num_k_heads, head_dim), where nnz is the last element of indptr.

  • indptr (torch.Tensor) – Indptr tensor, shape: (batch_size + 1).

  • offsets (torch.Tensor) – The relative position offsets of each query in the batch, shape: (batch_size).

  • interleave (bool) –

    Whether to use interleaved layout in the last dimension, default: False.

    • If True, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ([..., ::2]) and odd dimensions ([..., 1::2]).

    • If False, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ([..., :head_dim//2]) and the second half dimensions ([..., head_dim//2:]). See the reference sketch before the Examples section for the exact pairing in both layouts.

  • rope_scale (float) – The scaling factor used in the rope embedding, default: 1.

  • rope_theta (float) – The theta value used in the rope embedding, default: 1e4.

Returns:

  • q_rope (torch.Tensor) – The rotated query tensor, shape: (nnz, num_q_heads, head_dim).

  • k_rope (torch.Tensor) – The rotated key tensor, shape: (nnz, num_k_heads, head_dim).
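
The two layouts differ only in how the entries of head_dim are paired before rotation. The following pure-PyTorch sketch illustrates the pairing selected by interleave; it is for illustration only (not the flashinfer kernel) and assumes the common RoPE frequencies rope_theta ** (-2 * i / head_dim) with positions divided by rope_scale:

>>> import torch
>>> def rotate_ref(x, pos, interleave=False, rope_scale=1.0, rope_theta=1e4):
...     # x: (..., head_dim); pos: position index of this token (an int)
...     d = x.shape[-1]
...     freqs = rope_theta ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)   # (d // 2,)
...     angle = (pos / rope_scale) * freqs
...     cos, sin = torch.cos(angle), torch.sin(angle)
...     if interleave:
...         x1, x2 = x[..., 0::2], x[..., 1::2]               # pair (2i, 2i + 1)
...     else:
...         x1, x2 = x[..., : d // 2], x[..., d // 2 :]       # pair (i, i + d // 2)
...     r1, r2 = x1 * cos - x2 * sin, x1 * sin + x2 * cos
...     if interleave:
...         return torch.stack([r1, r2], dim=-1).flatten(-2)  # put pairs back in place
...     return torch.cat([r1, r2], dim=-1)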

Examples

>>> import torch
>>> import flashinfer
>>> batch_size = 128
>>> qkv_len = 1024
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> nnz = batch_size * qkv_len
>>> qkv_packed = torch.randn(
...    nnz,
...    (num_qo_heads + 2 * num_kv_heads) * head_dim,
...    dtype=torch.float16,
...    device="cuda:0",
... )
>>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
>>> k = qkv_packed[
...    :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
... ].reshape(nnz, num_kv_heads, head_dim)
>>> indptr = torch.tensor(
...    [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
... )
>>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
>>> q_rope, k_rope = flashinfer.apply_rope(q, k, indptr, offsets)
>>> q_rope.shape
torch.Size([131072, 32, 128])
>>> k_rope.shape
torch.Size([131072, 32, 128])
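
The interleaved layout is requested the same way; only the pairing inside head_dim changes, so the output shapes are unchanged:

>>> q_rope_i, k_rope_i = flashinfer.apply_rope(q, k, indptr, offsets, interleave=True)
>>> q_rope_i.shape
torch.Size([131072, 32, 128])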