flashinfer.rope.apply_rope_pos_ids_inplace
flashinfer.rope.apply_rope_pos_ids_inplace(q: Tensor, k: Tensor, pos_ids: Tensor, rotary_dim: int | None = None, interleave: bool = False, rope_scale: float = 1, rope_theta: float = 10000.0) → None
Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) in place. The cos/sin values are computed on the fly inside the kernel.
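As a rough mental model of that on-the-fly computation, the following pure-Python sketch rotates a single head vector for one position. It is not the library's kernel. The helper name rope_reference is hypothetical, and two conventions are assumptions not stated on this page: the position is divided by rope_scale, and with a partial rotary_dim the non-interleaved pairs are taken within the first rotary_dim dimensions.

```python
import math


def rope_reference(x, pos, rotary_dim=None, interleave=False,
                   rope_scale=1.0, rope_theta=1e4):
    """Rotate one head vector `x` (list of floats) for position `pos`.

    Hypothetical reference helper for illustration only.
    """
    head_dim = len(x)
    d = head_dim if rotary_dim is None else rotary_dim
    half = d // 2
    out = list(x)  # dimensions beyond the first `d` pass through unchanged
    for i in range(half):
        # Assumed convention: angle = (pos / rope_scale) * theta^(-2i/d).
        angle = (pos / rope_scale) * rope_theta ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        # Interleaved layout pairs (even, odd) neighbours;
        # non-interleaved layout pairs (first half, second half).
        a, b = (2 * i, 2 * i + 1) if interleave else (i, i + half)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out


x = [1.0, 0.0, 0.0, 0.0]
half_split = rope_reference(x, pos=1)                    # rotates pair (0, 2)
interleaved = rope_reference(x, pos=1, interleave=True)  # rotates pair (0, 1)
```

Both layouts apply the same 2D rotation; they differ only in which two dimensions form each rotated pair.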
We use indptr to denote the start pointer of each segment in the batch: the query of the i-th segment is q[indptr[i]:indptr[i+1]] and the key of the i-th segment is k[indptr[i]:indptr[i+1]]. The first element of indptr is always 0 and the last element of indptr is the total number of queries/keys in the batch. Please see the Ragged Tensor tutorial for more details about the ragged tensor.
Parameters:
q (torch.Tensor) – Query ragged tensor, shape: (nnz, num_q_heads, head_dim), where nnz is the last element of indptr.
k (torch.Tensor) – Key ragged tensor, shape: (nnz, num_k_heads, head_dim), where nnz is the last element of indptr.
pos_ids (torch.Tensor) – Position indices, shape: (nnz).
rotary_dim (Optional[int]) – The dimensions to apply RoPE to. If None, we apply RoPE to the entire head dimension; otherwise, we apply RoPE to the first rotary_dim dimensions. Default: None.
interleave (bool) – Whether to use interleaved layout in the last dimension, default: False.
- If True, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ([..., ::2]) and odd dimensions ([..., 1::2]).
- If False, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ([..., :head_dim//2]) and the second half dimensions ([..., head_dim//2:]).
rope_scale (float) – The scaling factor used in the rope embedding, default: 1.
rope_theta (float) – The theta value used in the rope embedding, default: 1e4.
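Because this function takes pos_ids rather than indptr, the position of every token has to be supplied explicitly. A minimal sketch of deriving per-token positions from indptr is shown below in plain Python. The helper pos_ids_from_indptr and its offsets argument are hypothetical and not part of flashinfer; offsets stands for an assumed per-segment starting position, for example when earlier tokens of a sequence have already been processed.

```python
def pos_ids_from_indptr(indptr, offsets=None):
    """Return one position index per token of a ragged batch.

    Hypothetical helper for illustration; `offsets[i]` is the assumed
    position of the first token in segment i (0 if omitted).
    """
    num_segments = len(indptr) - 1
    if offsets is None:
        offsets = [0] * num_segments
    pos_ids = []
    for i in range(num_segments):
        seg_len = indptr[i + 1] - indptr[i]
        # Tokens inside a segment get consecutive positions.
        pos_ids.extend(range(offsets[i], offsets[i] + seg_len))
    return pos_ids


indptr = [0, 3, 5]                      # two segments: 3 tokens and 2 tokens
pos_ids = pos_ids_from_indptr(indptr)   # [0, 1, 2, 0, 1]
shifted = pos_ids_from_indptr(indptr, offsets=[10, 0])  # [10, 11, 12, 0, 1]
```

The resulting list has length nnz (the last element of indptr) and would be converted to a torch.Tensor of shape (nnz) before being passed as pos_ids alongside q and k.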
See also