flashinfer.rope.apply_rope_with_cos_sin_cache
- flashinfer.rope.apply_rope_with_cos_sin_cache(q: torch.Tensor, k: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, pos_ids: torch.Tensor, interleave: bool = False) -> Tuple[torch.Tensor, torch.Tensor]
Apply rotary embedding to keys and queries with precomputed cos/sin values.
- Parameters:
q (torch.Tensor) – Query tensor, shape: (nnz, num_q_heads, head_dim).
k (torch.Tensor) – Key tensor, shape: (nnz, num_k_heads, head_dim).
cos_cache (torch.Tensor) – Cosine cache tensor, shape: (max_seq_len, rotary_dim).
sin_cache (torch.Tensor) – Sine cache tensor, shape: (max_seq_len, rotary_dim).
pos_ids (torch.Tensor) – Position indices, shape: (nnz).
interleave (bool) – Whether to use the interleaved layout in the last dimension, default: False.
If True, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ([..., ::2]) and odd dimensions ([..., 1::2]).
If False, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first-half dimensions ([..., :head_dim//2]) and the second-half dimensions ([..., head_dim//2:]).
- Returns:
q_rope (torch.Tensor) – The rotated query tensor, shape: (nnz, num_q_heads, head_dim).
k_rope (torch.Tensor) – The rotated key tensor, shape: (nnz, num_k_heads, head_dim).
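The two layouts selected by interleave differ only in which entries of the last dimension are paired for rotation. The following is a minimal pure-Python sketch of that pairing for a single head vector. It is not the library's kernel; the helper name rotate_pairs and the convention of passing one angle per pair are assumptions made for illustration.

```python
import math

def rotate_pairs(x, angles, interleave=False):
    """Rotate pairs of entries of x, one angle per pair (hypothetical helper).

    interleave=True  pairs (x[0], x[1]), (x[2], x[3]), ...
    interleave=False pairs (x[i], x[i + half]) with half = len(x) // 2
    """
    half = len(x) // 2
    assert len(x) == 2 * half and len(angles) == half
    out = list(x)
    for i, theta in enumerate(angles):
        # Pick the two indices that form the i-th rotated pair.
        a, b = (2 * i, 2 * i + 1) if interleave else (i, i + half)
        c, s = math.cos(theta), math.sin(theta)
        # Standard 2-D rotation of the pair (x[a], x[b]) by theta.
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out

x = [1.0, 0.0, 0.0, 1.0]
angles = [math.pi / 2, 0.0]  # rotate the first pair by 90 degrees, leave the second
interleaved = rotate_pairs(x, angles, interleave=True)   # pairs (0, 1) and (2, 3)
split_half = rotate_pairs(x, angles, interleave=False)   # pairs (0, 2) and (1, 3)
```

With the same input and angles, the value at index 0 moves to index 1 in the interleaved layout and to index 2 in the non-interleaved layout, which is why the flag must match the layout the model was trained with.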
Note
The rotary dimension (rotary_dim) is determined by the last dimension of cos_cache and sin_cache.
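As a rough illustration of the note, the sketch below builds small cos/sin tables and reads rotary_dim from the width of a cache row instead of taking it as an argument. Several details are assumptions not stated on this page: the helper names build_caches and apply_rope_reference, the frequency base of 10000, the row layout (each pair's angle stored twice, non-interleaved), and the choice to leave entries beyond rotary_dim unchanged.

```python
import math

def build_caches(max_seq_len, rotary_dim, base=10000.0):
    """Return (cos_cache, sin_cache), each max_seq_len rows of rotary_dim floats.

    Assumed layout (non-interleaved): entry j of a row uses the angle of
    pair j % (rotary_dim // 2), so each angle appears twice per row.
    """
    half = rotary_dim // 2
    inv_freq = [base ** (-2.0 * i / rotary_dim) for i in range(half)]
    cos_cache, sin_cache = [], []
    for pos in range(max_seq_len):
        angles = [pos * f for f in inv_freq] * 2
        cos_cache.append([math.cos(t) for t in angles])
        sin_cache.append([math.sin(t) for t in angles])
    return cos_cache, sin_cache

def apply_rope_reference(x, cos_row, sin_row):
    """Rotate the first len(cos_row) entries of x; copy the rest (assumption)."""
    rotary_dim = len(cos_row)  # inferred from the cache row, not passed in
    half = rotary_dim // 2
    out = list(x)
    for j in range(half):
        a, b = x[j], x[j + half]
        out[j] = a * cos_row[j] - b * sin_row[j]
        out[j + half] = b * cos_row[j + half] + a * sin_row[j + half]
    return out

cos_cache, sin_cache = build_caches(max_seq_len=8, rotary_dim=4)
head = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # head_dim = 6, rotary_dim = 4
pos = 3
rotated = apply_rope_reference(head, cos_cache[pos], sin_cache[pos])
```

In this sketch, changing the width of the cache rows is the only way to change how many dimensions are rotated, which mirrors the signature above having no separate rotary_dim parameter.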