flashinfer.rope.apply_rope_with_cos_sin_cache

flashinfer.rope.apply_rope_with_cos_sin_cache(q: torch.Tensor, k: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, pos_ids: torch.Tensor, interleave: bool = False) → Tuple[torch.Tensor, torch.Tensor]

Apply rotary positional embedding to queries and keys, using precomputed cosine/sine caches.

Parameters:
  • q (torch.Tensor) – Query tensor, shape: (nnz, num_q_heads, head_dim).

  • k (torch.Tensor) – Key tensor, shape: (nnz, num_k_heads, head_dim).

  • cos_cache (torch.Tensor) – Cosine cache tensor, shape: (max_seq_len, rotary_dim).

  • sin_cache (torch.Tensor) – Sine cache tensor, shape: (max_seq_len, rotary_dim).

  • pos_ids (torch.Tensor) – Position indices, shape: (nnz,).

  • interleave (bool) –

    Whether to use interleaved layout in the last dimension, default: False.

    • If True, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ([..., ::2]) and odd dimensions ([..., 1::2]).

    • If False, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half of the dimensions ([..., :head_dim//2]) against the second half ([..., head_dim//2:]).

Returns:

  • q_rope (torch.Tensor) – The rotated query tensor, shape: (nnz, num_q_heads, head_dim).

  • k_rope (torch.Tensor) – The rotated key tensor, shape: (nnz, num_k_heads, head_dim).

Note

The rotary dimension (rotary_dim) is inferred from the last dimension of cos_cache and sin_cache; it must not exceed head_dim. When rotary_dim < head_dim, only the first rotary_dim entries of each head are rotated and the remaining entries pass through unchanged.
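
The semantics above can be illustrated with a pure-PyTorch reference sketch (not the flashinfer kernel itself). It assumes the caches store full-width values, i.e., each angle's cosine/sine is repeated across both rotated halves (HF-style layout); rope_ref is a hypothetical helper name.

```python
import torch

def rope_ref(q, k, cos_cache, sin_cache, pos_ids, interleave=False):
    """Pure-PyTorch sketch of the rotation applied by the fused kernel.

    Assumes cos_cache/sin_cache have shape (max_seq_len, rotary_dim) with
    each angle's value repeated across the two rotated halves.
    """
    # rotary_dim is inferred from the cache's last dimension (see Note above).
    rotary_dim = cos_cache.shape[-1]
    cos = cos_cache[pos_ids].unsqueeze(1)  # (nnz, 1, rotary_dim), broadcast over heads
    sin = sin_cache[pos_ids].unsqueeze(1)

    def rotate(x):
        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        if interleave:
            # Pair even/odd dimensions: out[2i] = x[2i]*cos - x[2i+1]*sin, etc.
            x1, x2 = x_rot[..., ::2], x_rot[..., 1::2]
            rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
        else:
            # Pair first half with second half of the rotary dimensions.
            half = rotary_dim // 2
            x1, x2 = x_rot[..., :half], x_rot[..., half:]
            rotated = torch.cat((-x2, x1), dim=-1)
        # Dimensions beyond rotary_dim (if any) pass through unrotated.
        return torch.cat((x_rot * cos + rotated * sin, x_pass), dim=-1)

    return rotate(q), rotate(k)
```

For example, with pos_ids = 0 (angle 0, cos = 1, sin = 0) the output equals the input, and a quarter-turn angle maps [a, b, c, d] to [-c, -d, a, b] in the non-interleaved layout.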