flashinfer.rope.apply_rope_with_cos_sin_cache_inplace

flashinfer.rope.apply_rope_with_cos_sin_cache_inplace(q: torch.Tensor, k: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, pos_ids: torch.Tensor, interleave: bool = False) → None

Apply rotary embedding to queries and keys using precomputed cos/sin values. The result is written back into the input tensors q and k in place; nothing is returned.

Parameters:
  • q (torch.Tensor) – Query tensor, shape: (nnz, num_q_heads, head_dim).

  • k (torch.Tensor) – Key tensor, shape: (nnz, num_k_heads, head_dim).

  • cos_cache (torch.Tensor) – Cosine cache tensor, shape: (max_seq_len, rotary_dim). Expected data type: float32.

  • sin_cache (torch.Tensor) – Sine cache tensor, shape: (max_seq_len, rotary_dim). Expected data type: float32.

  • pos_ids (torch.Tensor) – Position indices, one per token, shape: (nnz,).

  • interleave (bool) –

    Whether to use interleaved layout in the last dimension, default: False.

    • If True, the last dimension of the query/key tensor is interleaved, i.e., the even dimensions ([..., ::2]) are rotated together with the odd dimensions ([..., 1::2]).

    • If False, the last dimension of the query/key tensor is not interleaved, i.e., the first-half dimensions ([..., :head_dim//2]) are rotated together with the second-half dimensions ([..., head_dim//2:]).
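
The two layouts selected by interleave can be illustrated with a small pure-Python sketch. It is not the library's kernel: rotate_head is a hypothetical helper, it works on a single head vector, and it assumes one cos/sin value per rotated pair, which may differ from how cos_cache and sin_cache are actually laid out.

```python
import math


def rotate_head(x, cos, sin, interleave=False):
    """Rotate one head vector `x` (list of floats) pair by pair.

    Hypothetical reference helper for illustration only.
    cos[i] / sin[i] hold the values for the i-th rotated pair
    (an assumed layout, not necessarily the library's cache layout).
    """
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        if interleave:
            a, b = 2 * i, 2 * i + 1   # pairs (0, 1), (2, 3), ...
        else:
            a, b = i, i + half        # pairs (0, half), (1, half + 1), ...
        # Standard 2-D rotation of the pair (x[a], x[b]).
        out[a] = x[a] * cos[i] - x[b] * sin[i]
        out[b] = x[a] * sin[i] + x[b] * cos[i]
    return out


# Same input and angles, two layouts: different elements get paired.
x = [1.0, 0.0, 0.0, 1.0]
angles = [math.pi / 2, 0.0]
cos = [math.cos(t) for t in angles]
sin = [math.sin(t) for t in angles]

half_split = rotate_head(x, cos, sin, interleave=False)
interleaved = rotate_head(x, cos, sin, interleave=True)
print(half_split)
print(interleaved)
```

With interleave=False the first pair is (x[0], x[2]); with interleave=True it is (x[0], x[1]). Tensors must therefore be stored in the layout that matches the flag, since the same values give different results under the two pairings.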

Note

The rotary dimension is not passed as a separate argument; it is determined by the shape of the cosine and sine caches (their last dimension, rotary_dim).