flashinfer.rope.apply_llama31_rope_pos_ids

flashinfer.rope.apply_llama31_rope_pos_ids(q: torch.Tensor, k: torch.Tensor, pos_ids: torch.Tensor, rotary_dim: int | None = None, interleave: bool = False, rope_scale: float = 8, rope_theta: float = 500000.0, low_freq_factor: float = 1, high_freq_factor: float = 4, old_context_len: int = 8192) → Tuple[torch.Tensor, torch.Tensor]

Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as RaggedTensor). cos/sin values are computed on the fly inside the kernel.

We use indptr to denote the start pointer of each segment in the batch: the query of the i-th segment is q[indptr[i]:indptr[i+1]] and the key of the i-th segment is k[indptr[i]:indptr[i+1]]. The first element of indptr is always 0, and the last element is the total number of queries/keys in the batch. Please see the Ragged Tensor tutorial for more details about the ragged tensor.
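As a concrete illustration of the ragged layout, the snippet below builds a hypothetical batch of two sequences of lengths 3 and 2, and one common way to derive per-token pos_ids (restarting at 0 within each segment; the actual pos_ids you pass may differ, e.g. when continuing generation from a prefix):

```python
# Hypothetical ragged batch of two sequences with lengths 3 and 2.
indptr = [0, 3, 5]           # segment start offsets; the last entry is nnz
nnz = indptr[-1]             # total number of tokens in the batch = 5

# One common way to build per-token position ids: restart at 0 per segment.
pos_ids = []
for i in range(len(indptr) - 1):
    seg_len = indptr[i + 1] - indptr[i]
    pos_ids.extend(range(seg_len))
# pos_ids == [0, 1, 2, 0, 1]
```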

Parameters:
  • q (torch.Tensor) – Query ragged tensor, shape: (nnz, num_q_heads, head_dim), where nnz is the last element of indptr.

  • k (torch.Tensor) – Key ragged tensor, shape: (nnz, num_k_heads, head_dim), where nnz is the last element of indptr.

  • pos_ids (torch.Tensor) – Position indices, shape: (nnz,).

  • rotary_dim (Optional[int]) – The number of dimensions to apply RoPE to. If None, RoPE is applied to the entire head dimension; otherwise, it is applied only to the first rotary_dim dimensions, default: None.

  • interleave (bool) –

    Whether to use interleaved layout in the last dimension, default: False.

    • If True, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ([..., ::2]) and odd dimensions ([..., 1::2]).

    • If False, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ([..., :head_dim//2]) and the second half dimensions ([..., head_dim//2:]).

  • rope_scale (float) – The scaling factor used in the rope embedding, default: 8.

  • rope_theta (float) – The theta value used in the rope embedding, default: 5e5.

  • low_freq_factor (float) – The low frequency factor used in Llama 3.1 RoPE, default: 1.

  • high_freq_factor (float) – The high frequency factor used in Llama 3.1 RoPE, default: 4.

  • old_context_len (int) – The old context length used in Llama 3.1 RoPE, default: 8192.
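The low_freq_factor, high_freq_factor, and old_context_len parameters control the wavelength-dependent frequency rescaling introduced in Llama 3.1: high-frequency components are left unchanged, low-frequency components are divided by rope_scale, and components in between are smoothly interpolated. The following pure-Python sketch follows the published Llama 3.1 scheme; the helper name is ours, and the kernel computes these values on the fly rather than materializing them:

```python
import math

def llama31_inv_freq(head_dim, rope_theta=500000.0, rope_scale=8.0,
                     low_freq_factor=1.0, high_freq_factor=4.0,
                     old_context_len=8192):
    """Reference sketch of Llama 3.1 RoPE frequency rescaling (one value per dim pair)."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    inv_freq = []
    for i in range(0, head_dim, 2):
        freq = 1.0 / (rope_theta ** (i / head_dim))
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High frequency: keep as-is.
            inv_freq.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low frequency: scale down by rope_scale.
            inv_freq.append(freq / rope_scale)
        else:
            # Medium frequency: linearly interpolate between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor)
            inv_freq.append((1 - smooth) * freq / rope_scale + smooth * freq)
    return inv_freq
```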

Returns:

  • q_rope (torch.Tensor) – The rotated query tensor, shape: (nnz, num_q_heads, head_dim).

  • k_rope (torch.Tensor) – The rotated key tensor, shape: (nnz, num_k_heads, head_dim).
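For reference, the rotation itself in the non-interleaved layout (interleave=False) can be sketched for a single token and head in plain Python as follows. This is an illustrative reference only, not the kernel's implementation, and rotate_half_rope is a hypothetical helper name:

```python
import math

def rotate_half_rope(x, pos, inv_freq):
    """Rotate one head vector x (non-interleaved layout) at position pos.

    x: list of head_dim floats; inv_freq: list of head_dim//2 inverse frequencies.
    Pairs x[j] with x[j + head_dim//2], as in the interleave=False layout.
    """
    half = len(x) // 2
    out = [0.0] * len(x)
    for j in range(half):
        angle = pos * inv_freq[j]
        c, s = math.cos(angle), math.sin(angle)
        x1, x2 = x[j], x[j + half]
        out[j] = x1 * c - x2 * s
        out[j + half] = x2 * c + x1 * s
    return out
```

At position 0 the rotation is the identity, and for any position it preserves the vector's norm, which is a quick sanity check for RoPE implementations.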