flashinfer.rope.apply_rope_pos_ids
- flashinfer.rope.apply_rope_pos_ids(q: torch.Tensor, k: torch.Tensor, pos_ids: torch.Tensor, rotary_dim: int | None = None, interleave: bool = False, rope_scale: float = 1, rope_theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]
Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). cos/sin values are computed on the fly inside the kernel.
We use indptr to denote the start pointer of each segment in the batch: the query of the i-th segment is q[indptr[i]:indptr[i+1]] and the key of the i-th segment is k[indptr[i]:indptr[i+1]]. The first element of indptr is always 0, and the last element of indptr is the total number of queries/keys in the batch. Note that indptr only describes the layout; it is not an argument of this function, which instead receives the position of every query/key row directly through pos_ids. Please see the Ragged Tensor tutorial for more details about the ragged tensor.
- Parameters:
q (torch.Tensor) – Query ragged tensor, shape: (nnz, num_q_heads, head_dim), where nnz is the last element of indptr.
k (torch.Tensor) – Key ragged tensor, shape: (nnz, num_k_heads, head_dim), where nnz is the last element of indptr.
pos_ids (torch.Tensor) – Position indices, shape: (nnz), i.e. one position per query/key row.
rotary_dim (Optional[int]) – The dimensions to apply RoPE to. If None, we apply RoPE to the entire head dimension; otherwise, we apply RoPE to the first rotary_dim dimensions, default: None.
interleave (bool) – Whether to use interleaved layout in the last dimension, default: False.
  If True, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ([..., ::2]) and odd dimensions ([..., 1::2]).
  If False, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ([..., :head_dim//2]) and the second half dimensions ([..., head_dim//2:]).
rope_scale (float) – The scaling factor used in the RoPE embedding, default: 1.
rope_theta (float) – The theta value used in the RoPE embedding, default: 1e4.
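The two layouts selected by interleave, and the roles of rotary_dim, rope_scale and rope_theta, can be illustrated with a plain-Python reference that rotates a single head vector. This is a minimal sketch under an assumption: it uses the standard RoPE formulation, in which the j-th rotation pair at position pos is rotated by the angle (pos / rope_scale) * rope_theta^(-2j / rotary_dim). It is not derived from the flashinfer kernel source, and rope_reference is a hypothetical name; treat it as a way to reason about the layouts, not as a bit-exact replacement for the CUDA kernel.

```python
import math
from typing import List, Optional


def rope_reference(
    x: List[float],
    pos: int,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 1.0,
    rope_theta: float = 1e4,
) -> List[float]:
    """Hypothetical reference: rotate one head vector x (length head_dim) at position pos.

    Assumes the standard RoPE angle (pos / rope_scale) * rope_theta ** (-2j / rotary_dim)
    for the j-th rotation pair; not derived from the flashinfer kernel source.
    """
    head_dim = len(x)
    if rotary_dim is None:
        rotary_dim = head_dim  # None means: rotate the whole head dimension
    assert rotary_dim % 2 == 0 and rotary_dim <= head_dim
    half = rotary_dim // 2
    out = list(x)  # dimensions at index >= rotary_dim are passed through unchanged
    for j in range(half):
        # cos/sin depend only on the position and j, so they can be computed on the fly
        angle = (pos / rope_scale) * rope_theta ** (-2.0 * j / rotary_dim)
        c, s = math.cos(angle), math.sin(angle)
        # interleave=True pairs (2j, 2j+1); interleave=False pairs (j, j + half)
        a, b = (2 * j, 2 * j + 1) if interleave else (j, j + half)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[b] * c + x[a] * s
    return out
```

For example, rope_reference([1.0, 0.0], pos=1) returns approximately [0.5403, 0.8415], i.e. [cos 1, sin 1]. At pos = 0 every pair is left unchanged, and for any position the rotation preserves the vector's norm; only the choice of which coordinates form a pair differs between the two layouts.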
- Returns:
q_rope (torch.Tensor) – The rotated query tensor, shape: (nnz, num_q_heads, head_dim).
k_rope (torch.Tensor) – The rotated key tensor, shape: (nnz, num_k_heads, head_dim).
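Because pos_ids carries one position per row, a caller that holds a ragged batch described by indptr has to expand it into one position per query/key row. The helper below is hypothetical (it is not part of flashinfer) and assumes the common convention that positions restart at 0 in every segment, optionally shifted by a per-segment offset such as the number of tokens already cached for that sequence.

```python
from typing import List, Optional, Sequence


def pos_ids_from_indptr(
    indptr: Sequence[int], offsets: Optional[Sequence[int]] = None
) -> List[int]:
    """Hypothetical helper: one position per row of a ragged batch.

    Row t in segment i (indptr[i] <= t < indptr[i + 1]) gets position
    offsets[i] + (t - indptr[i]); offsets default to 0 for every segment.
    """
    assert indptr[0] == 0, "the first element of indptr is always 0"
    batch_size = len(indptr) - 1
    offsets = [0] * batch_size if offsets is None else list(offsets)
    assert len(offsets) == batch_size
    pos_ids: List[int] = []
    for i in range(batch_size):
        seg_len = indptr[i + 1] - indptr[i]
        # positions inside segment i count up from its offset
        pos_ids.extend(offsets[i] + p for p in range(seg_len))
    # nnz = indptr[-1]: exactly one position per query/key row
    assert len(pos_ids) == indptr[-1]
    return pos_ids
```

For example, indptr = [0, 3, 5] with offsets = [10, 4] yields [10, 11, 12, 4, 5]. The list would then be converted to an integer tensor on the same device as q and k before calling apply_rope_pos_ids; check which integer dtype your flashinfer version expects.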
See also