flashinfer.rope.apply_rope
- flashinfer.rope.apply_rope(q: torch.Tensor, k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, rotary_dim: int | None = None, interleave: bool = False, rope_scale: float = 1, rope_theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]
Apply rotary embedding (RoPE) to a batch of queries/keys stored as a ragged tensor. The cos/sin values are computed on the fly inside the kernel.
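To make "computed on the fly" concrete, here is a minimal pure-Python sketch of the rotation applied to one head vector at one position. It is not FlashInfer's kernel. It assumes the standard RoPE frequencies rope_theta^(-2i/rotary_dim) and assumes that rope_scale divides the position (linear position interpolation); the helper name rope_rotate is hypothetical.

```python
import math
from typing import List, Optional


def rope_rotate(
    x: List[float],
    pos: int,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 1.0,
    rope_theta: float = 1e4,
) -> List[float]:
    """Hypothetical reference: rotate one head vector x (length head_dim) at position pos."""
    head_dim = len(x)
    rdim = head_dim if rotary_dim is None else rotary_dim
    assert rdim % 2 == 0 and rdim <= head_dim
    half = rdim // 2
    out = list(x)  # dimensions >= rdim pass through unchanged
    for i in range(half):
        # cos/sin are derived from the position here rather than read from a table.
        inv_freq = rope_theta ** (-2.0 * i / rdim)
        # Assumption: rope_scale divides the position (linear interpolation).
        angle = (pos / rope_scale) * inv_freq
        c, s = math.cos(angle), math.sin(angle)
        if interleave:
            a, b = 2 * i, 2 * i + 1  # pair each even dimension with the next odd one
        else:
            a, b = i, i + half  # pair first-half dimension i with second-half i + half
        x1, x2 = x[a], x[b]
        out[a] = x1 * c - x2 * s
        out[b] = x2 * c + x1 * s
    return out
```

For a one-hot vector at dimension 0 in the default (non-interleaved) layout, the first frequency is 1, so dimensions 0 and head_dim//2 of the result are cos(pos) and sin(pos).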
We use indptr to denote the start pointer of each segment in the batch: the query of the i-th segment is q[indptr[i]:indptr[i+1]] and the key of the i-th segment is k[indptr[i]:indptr[i+1]]. The first element of indptr is always 0, and the last element of indptr is the total number of queries/keys in the batch. Please see the Ragged Tensor tutorial for more details about the ragged tensor.
- Parameters:
  q (torch.Tensor) – Query ragged tensor, shape: (nnz, num_q_heads, head_dim), where nnz is the last element of indptr.
  k (torch.Tensor) – Key ragged tensor, shape: (nnz, num_k_heads, head_dim), where nnz is the last element of indptr.
  indptr (torch.Tensor) – Indptr tensor, shape: (batch_size + 1).
  offsets (torch.Tensor) – The relative position offsets of each query (segment) in the batch, shape: (batch_size).
  rotary_dim (Optional[int]) – The number of dimensions to apply RoPE to. If None, we apply RoPE to the entire head dimension; otherwise, we apply RoPE to the first rotary_dim dimensions, default: None.
  interleave (bool) – Whether to use interleaved layout in the last dimension, default: False.
    If True, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ([..., ::2]) and odd dimensions ([..., 1::2]).
    If False, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ([..., :head_dim//2]) and the second half dimensions ([..., head_dim//2:]).
  rope_scale (float) – The scaling factor used in the rope embedding, default: 1.
  rope_theta (float) – The theta value used in the rope embedding, default: 1e4.
- Returns:
  q_rope (torch.Tensor) – The rotated query tensor, shape: (nnz, num_q_heads, head_dim).
  k_rope (torch.Tensor) – The rotated key tensor, shape: (nnz, num_k_heads, head_dim).
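The way indptr and offsets jointly determine which position each row is rotated at can be sketched in plain Python. This sketch assumes that row j of segment i is rotated at position offsets[i] + (j - indptr[i]), i.e., offsets shifts the start position of each segment; ragged_positions and segment_rows are hypothetical helpers, not part of FlashInfer.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def segment_rows(rows: Sequence[T], indptr: List[int], i: int) -> List[T]:
    """Rows of segment i, mirroring q[indptr[i]:indptr[i+1]]."""
    return list(rows[indptr[i]:indptr[i + 1]])


def ragged_positions(indptr: List[int], offsets: List[int]) -> List[int]:
    """Hypothetical helper: the RoPE position of every row of a ragged q/k tensor."""
    batch_size = len(indptr) - 1
    assert indptr[0] == 0, "the first element of indptr must be 0"
    assert len(offsets) == batch_size, "offsets has shape (batch_size,)"
    positions: List[int] = []
    for i in range(batch_size):
        start, end = indptr[i], indptr[i + 1]
        assert start <= end, "indptr must be non-decreasing"
        # Assumption: positions restart at offsets[i] for every segment.
        positions.extend(offsets[i] + (j - start) for j in range(start, end))
    return positions  # length is nnz == indptr[-1]
```

Under this assumption, the Examples section (offset 10 for every segment, qkv_len = 1024) rotates each segment's rows at positions 10 through 1033.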
Examples
>>> import torch
>>> import flashinfer
>>> batch_size = 128
>>> qkv_len = 1024
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> nnz = batch_size * qkv_len
>>> qkv_packed = torch.randn(
...     nnz,
...     (num_qo_heads + 2 * num_kv_heads) * head_dim,
...     dtype=torch.float16,
...     device="cuda:0",
... )
>>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
>>> k = qkv_packed[
...     :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
... ].reshape(nnz, num_kv_heads, head_dim)
>>> indptr = torch.tensor(
...     [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
... )
>>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
>>> q_rope, k_rope = flashinfer.apply_rope(q, k, indptr, offsets)
>>> q_rope.shape
torch.Size([131072, 32, 128])
>>> k_rope.shape
torch.Size([131072, 32, 128])
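The example uses equal-length segments, so indptr is simply i * qkv_len. For segments of different lengths, indptr is the exclusive prefix sum of the segment lengths. A small sketch with itertools.accumulate follows; indptr_from_lengths is a hypothetical name, not a FlashInfer function.

```python
from itertools import accumulate
from typing import List


def indptr_from_lengths(lengths: List[int]) -> List[int]:
    """Exclusive prefix sum: segment i spans rows indptr[i]:indptr[i+1]."""
    if any(n < 0 for n in lengths):
        raise ValueError("segment lengths must be non-negative")
    # initial=0 yields the leading 0 required as the first element of indptr.
    return list(accumulate(lengths, initial=0))
```

On the GPU the same prefix sum can be built with torch.cumsum, prepending a 0 and converting the result to torch.int32 to match the dtype used in the example.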
See also