flashinfer.norm.fused_qk_rmsnorm_rope

flashinfer.norm.fused_qk_rmsnorm_rope(qkv: Tensor, q_weight: Tensor, k_weight: Tensor, *, ppf: int, pph: int, ppw: int, num_frame_channels: int, num_height_channels: int, num_width_channels: int, num_heads_q: int, num_heads_k: int, num_heads_v: int, head_dim: int, eps: float = 1e-06, base: float = 10000.0, interleave: bool = True, factor: float = 1.0, low: float = 0.0, high: float = 0.0, attention_factor: float = 1.0, is_qk_norm: bool = True, output_fp8: bool = False, output_quant_scale: float = 1.0, v_quant_scale: float = 1.0, q_out: Tensor | None = None, k_out: Tensor | None = None, v_out: Tensor | None = None) → Tuple[Tensor, Tensor, Tensor]

Fused QK RMSNorm + 3D RoPE + V copy for self-attention in video-generation DiT (Diffusion Transformer) models.

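Applies RMSNorm to Q and K across heads (one statistic per token over the concatenated num_heads * head_dim vector, rather than per head), then applies rotary position embeddings (RoPE) with a 3D spatial decomposition (frame/height/width), and copies V into a contiguous output buffer. Optionally quantizes all three outputs to FP8 E4M3.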
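The across-heads normalization can be illustrated with a minimal pure-Python reference for a single token. This is a sketch under stated assumptions, not the kernel's implementation: it computes one RMS statistic over the whole num_heads * head_dim vector (as the [num_heads * head_dim] weight shapes imply), works in float64 instead of BF16, and the helper names rmsnorm_across_heads and split_heads are hypothetical.

```python
import math
from typing import List, Sequence


def rmsnorm_across_heads(x: Sequence[float], weight: Sequence[float], eps: float = 1e-6) -> List[float]:
    """RMSNorm over one token's full Q (or K) vector of length num_heads * head_dim.

    A single mean-square statistic is shared by every head of the token;
    the weight has one entry per element (hypothetical reference helper).
    """
    if len(x) != len(weight):
        raise ValueError("weight must have num_heads * head_dim elements")
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def split_heads(x: Sequence[float], num_heads: int, head_dim: int) -> List[List[float]]:
    """Split the normalized vector into per-head slices, ready for RoPE."""
    return [list(x[h * head_dim:(h + 1) * head_dim]) for h in range(num_heads)]
```

For example, x = [3, 4, 0, 0] viewed as two heads of dimension 2 normalizes to roughly [1.2, 1.6, 0, 0] with the shared statistic; a per-head RMSNorm would instead divide the first head by sqrt(12.5).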

Parameters:
  • qkv (torch.Tensor) –

    Combined QKV input, BF16, contiguous. Accepted shapes:
      - 3D: [batch, seq_len, (num_heads_q + num_heads_k + num_heads_v) * head_dim]
      - 2D: [num_tokens, (num_heads_q + num_heads_k + num_heads_v) * head_dim], where num_tokens must be divisible by ppf * pph * ppw.

  • q_weight (torch.Tensor) – RMSNorm weight for Q, shape [num_heads_q * head_dim], BF16.

  • k_weight (torch.Tensor) – RMSNorm weight for K, shape [num_heads_k * head_dim], BF16.

  • ppf (int) – Number of patches along the frame dimension.

  • pph (int) – Number of patches along the height dimension.

  • ppw (int) – Number of patches along the width dimension. For 3D input, seq_len must equal ppf * pph * ppw.

  • num_frame_channels (int) – RoPE frequency channels for the frame dimension (must be even).

  • num_height_channels (int) – RoPE frequency channels for the height dimension (must be even).

  • num_width_channels (int) – RoPE frequency channels for the width dimension (must be even). The three channel counts must sum to head_dim: num_frame_channels + num_height_channels + num_width_channels == head_dim.

  • num_heads_q (int) – Number of query heads.

  • num_heads_k (int) – Number of key heads.

  • num_heads_v (int) – Number of value heads.

  • head_dim (int) – Dimension per head (must be 64, 128, or 256).

  • eps (float) – Epsilon added to the mean square in RMSNorm for numerical stability.

  • base (float) – RoPE base frequency.

  • interleave (bool) – If True, use interleaved (GPT-J style) RoPE; if False, use NeoX-style (rotate-half) RoPE.

  • factor (float) – YaRN RoPE scaling factor. A value of 1.0 disables YaRN.

  • low (float) – YaRN low-frequency threshold.

  • high (float) – YaRN high-frequency threshold.

  • attention_factor (float) – YaRN attention factor applied to cos/sin. Must be 1.0 when factor is 1.0.

  • is_qk_norm (bool) – Whether to apply RMSNorm to Q and K. If False, normalization is skipped and only RoPE is applied.

  • output_fp8 (bool) – If True, quantize the Q, K, and V outputs to FP8 E4M3.

  • output_quant_scale (float) – FP8 quantization scale for Q and K outputs.

  • v_quant_scale (float) – FP8 quantization scale for V output.

  • q_out (Optional[torch.Tensor]) – Pre-allocated Q output tensor (destination-passing style).

  • k_out (Optional[torch.Tensor]) – Pre-allocated K output tensor.

  • v_out (Optional[torch.Tensor]) – Pre-allocated V output tensor.
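The 3D RoPE parameters (ppf, pph, ppw, the three channel counts, base, and interleave) can be related with a pure-Python sketch that maps a token to its (frame, height, width) patch coordinates and rotates one head. Several details are assumptions, not taken from the FlashInfer source: tokens are laid out frame-major with width varying fastest; for 2D input the position is the token index modulo ppf * pph * ppw; each axis uses its own frequency ladder base ** (-2i / C) over its own C channels, with the resulting angles assigned to the head's channel pairs in frame, height, width order; the NeoX pairing (i, i + head_dim / 2) spans the whole head; and YaRN scaling is omitted (factor = 1.0). The function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def token_to_fhw(token: int, ppf: int, pph: int, ppw: int) -> Tuple[int, int, int]:
    """Map a token index to (frame, height, width) patch coordinates."""
    # ASSUMPTION: frame-major layout (width varies fastest). For 2D input the
    # batch is folded into the token axis, so positions wrap every ppf*pph*ppw tokens.
    pos = token % (ppf * pph * ppw)
    f, rem = divmod(pos, pph * ppw)
    h, w = divmod(rem, ppw)
    return f, h, w


def rope_angles(fhw: Sequence[int], channels: Sequence[int], base: float = 10000.0) -> List[float]:
    """One rotation angle per channel pair: frame pairs, then height, then width."""
    angles: List[float] = []
    for p, c in zip(fhw, channels):
        # ASSUMPTION: each axis has its own inverse-frequency ladder over its c channels.
        for i in range(c // 2):
            angles.append(p * base ** (-2.0 * i / c))
    return angles


def apply_rope(head: Sequence[float], angles: Sequence[float], interleave: bool = True) -> List[float]:
    """Rotate one head's channel pairs by the given angles (YaRN omitted)."""
    d = len(head)
    if 2 * len(angles) != d:
        raise ValueError("need exactly head_dim / 2 angles")
    out = list(head)
    for i, theta in enumerate(angles):
        # Interleaved pairs channels (2i, 2i+1); NeoX style pairs (i, i + d/2).
        a, b = (2 * i, 2 * i + 1) if interleave else (i, i + d // 2)
        cos, sin = math.cos(theta), math.sin(theta)
        out[a] = head[a] * cos - head[b] * sin
        out[b] = head[a] * sin + head[b] * cos
    return out
```

For example, with ppf = 2, pph = 4, ppw = 5, token 33 maps to (1, 2, 3); with channel counts (2, 2, 2) each axis has a single pair with inverse frequency 1, so rope_angles((1, 2, 3), (2, 2, 2)) returns [1.0, 2.0, 3.0]. Because every pair is a plane rotation, the head's L2 norm is preserved in either pairing mode.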

Returns:

(q_out, k_out, v_out), where num_heads_x stands for num_heads_q, num_heads_k, or num_heads_v respectively. If the input is 3D, each output has shape [batch, seq_len, num_heads_x, head_dim]; if the input is 2D, each has shape [num_tokens, num_heads_x, head_dim].
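The shape rules from the parameter and return descriptions can be collected into a small shape-inference helper, sketched below in pure Python. It checks only the constraints stated on this page; the helper names are hypothetical, and qkv_column_ranges additionally assumes that the last dimension of qkv is packed in Q, K, V order, which the name suggests but the page does not state.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def qkv_output_shapes(
    qkv_shape: Sequence[int], *, ppf: int, pph: int, ppw: int,
    num_frame_channels: int, num_height_channels: int, num_width_channels: int,
    num_heads_q: int, num_heads_k: int, num_heads_v: int, head_dim: int,
) -> Tuple[Shape, Shape, Shape]:
    """Validate the documented constraints and return the (q, k, v) output shapes."""
    if len(qkv_shape) not in (2, 3):
        raise ValueError("qkv must be 2D or 3D")
    if head_dim not in (64, 128, 256):
        raise ValueError("head_dim must be 64, 128, or 256")
    channels = (num_frame_channels, num_height_channels, num_width_channels)
    if any(c % 2 for c in channels) or sum(channels) != head_dim:
        raise ValueError("RoPE channel counts must be even and sum to head_dim")
    hidden = (num_heads_q + num_heads_k + num_heads_v) * head_dim
    if qkv_shape[-1] != hidden:
        raise ValueError(f"last dim must be {hidden}, got {qkv_shape[-1]}")
    tokens_per_video = ppf * pph * ppw
    if len(qkv_shape) == 3:
        batch, seq_len, _ = qkv_shape
        if seq_len != tokens_per_video:
            raise ValueError("seq_len must equal ppf * pph * ppw")
        lead: Shape = (batch, seq_len)
    else:
        num_tokens = qkv_shape[0]
        if num_tokens % tokens_per_video:
            raise ValueError("num_tokens must be divisible by ppf * pph * ppw")
        lead = (num_tokens,)
    return (
        lead + (num_heads_q, head_dim),
        lead + (num_heads_k, head_dim),
        lead + (num_heads_v, head_dim),
    )


def qkv_column_ranges(num_heads_q: int, num_heads_k: int, num_heads_v: int, head_dim: int):
    """Half-open column ranges of Q, K, V in the last dim of qkv."""
    # ASSUMPTION: the last dimension is packed as [Q | K | V].
    q_end = num_heads_q * head_dim
    k_end = q_end + num_heads_k * head_dim
    return (0, q_end), (q_end, k_end), (k_end, k_end + num_heads_v * head_dim)
```

For example, a 3D input of shape (2, 256, 9216) with ppf = 4, pph = 8, ppw = 8, 24 heads each for Q, K, and V, head_dim = 128, and channel counts (44, 42, 42) yields three outputs of shape (2, 256, 24, 128).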

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]