flashinfer.norm.fused_qk_rmsnorm_rope
- flashinfer.norm.fused_qk_rmsnorm_rope(qkv: Tensor, q_weight: Tensor, k_weight: Tensor, *, ppf: int, pph: int, ppw: int, num_frame_channels: int, num_height_channels: int, num_width_channels: int, num_heads_q: int, num_heads_k: int, num_heads_v: int, head_dim: int, eps: float = 1e-06, base: float = 10000.0, interleave: bool = True, factor: float = 1.0, low: float = 0.0, high: float = 0.0, attention_factor: float = 1.0, is_qk_norm: bool = True, output_fp8: bool = False, output_quant_scale: float = 1.0, v_quant_scale: float = 1.0, q_out: Tensor | None = None, k_out: Tensor | None = None, v_out: Tensor | None = None) → Tuple[Tensor, Tensor, Tensor]
Fused QK RMSNorm + 3D RoPE + V copy for video-generation DiT self-attention.
Applies across-heads RMSNorm to Q and K, then rotary position embeddings with 3D spatial decomposition (frame/height/width), and copies V to a contiguous output buffer. Optionally quantizes all outputs to FP8 E4M3.
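The two ideas in the description can be sketched in plain Python for a single token. This is only an illustration, not the kernel's implementation, and it rests on two assumptions that the text above does not confirm: that "across-heads" means the mean square is taken over the whole flattened `num_heads * head_dim` vector, and that tokens are laid out in row-major (frame, height, width) order with width varying fastest. The helper names `across_heads_rmsnorm` and `token_to_fhw` are hypothetical.

```python
import math


def across_heads_rmsnorm(x, weight, eps=1e-6):
    """RMSNorm over one token's flattened [num_heads * head_dim] vector.

    Assumption: the mean square covers every element of every head,
    which is how "across-heads" is read in this sketch.
    """
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def token_to_fhw(token_idx, pph, ppw):
    """Split a flat token index into (frame, height, width) patch coordinates.

    Assumption: row-major layout with width varying fastest.
    """
    frame, rem = divmod(token_idx, pph * ppw)
    height, width = divmod(rem, ppw)
    return frame, height, width


# One token with two heads of head_dim = 2, flattened to four values.
normed = across_heads_rmsnorm([3.0, 4.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], eps=0.0)

# Token 23 on a grid with pph = 3 and ppw = 4 (12 patches per frame).
coords = token_to_fhw(23, pph=3, ppw=4)
```

Under these assumptions, each of the three coordinates would drive the rotation of its own slice of the head dimension (`num_frame_channels`, `num_height_channels`, `num_width_channels`).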
- Parameters:
qkv (torch.Tensor) – Combined QKV input, BF16, contiguous. Accepted shapes:
  - 3D: [batch, seq_len, (num_heads_q+num_heads_k+num_heads_v)*head_dim]
  - 2D: [num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim], where num_tokens must be divisible by ppf*pph*ppw.
q_weight (torch.Tensor) – RMSNorm weight for Q, shape [num_heads_q * head_dim], BF16.
k_weight (torch.Tensor) – RMSNorm weight for K, shape [num_heads_k * head_dim], BF16.
ppf (int) – Number of patches in frame dimension.
pph (int) – Number of patches in height dimension.
ppw (int) – Number of patches in width dimension. seq_len = ppf * pph * ppw.
num_frame_channels (int) – RoPE frequency channels for the frame dimension (must be even).
num_height_channels (int) – RoPE frequency channels for the height dimension (must be even).
num_width_channels (int) – RoPE frequency channels for the width dimension (must be even). num_frame_channels + num_height_channels + num_width_channels == head_dim.
num_heads_q (int) – Number of query heads.
num_heads_k (int) – Number of key heads.
num_heads_v (int) – Number of value heads.
head_dim (int) – Dimension per head (must be 64, 128, or 256).
eps (float) – RMSNorm epsilon.
base (float) – RoPE base frequency.
interleave (bool) – True for interleaved RoPE (non-NeoX style), False for NeoX-style.
factor (float) – YARN RoPE scaling factor. 1.0 disables YARN.
low (float) – YARN low frequency threshold.
high (float) – YARN high frequency threshold.
attention_factor (float) – YARN attention factor applied to cos/sin. Must be 1.0 when factor is 1.0.
is_qk_norm (bool) – Whether to apply RMSNorm (False = RoPE only, skip normalization).
output_fp8 (bool) – Quantize Q, K, V outputs to FP8 E4M3.
output_quant_scale (float) – FP8 quantization scale for Q and K outputs.
v_quant_scale (float) – FP8 quantization scale for V output.
q_out (Optional[torch.Tensor]) – Pre-allocated Q output tensor (destination-passing style).
k_out (Optional[torch.Tensor]) – Pre-allocated K output tensor.
v_out (Optional[torch.Tensor]) – Pre-allocated V output tensor.
- Returns:
(q_out, k_out, v_out). If input is 3D, each has shape [batch, seq_len, num_heads_x, head_dim]. If input is 2D, each has shape [num_tokens, num_heads_x, head_dim].
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
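The shape and channel constraints listed above can be checked before launching the kernel. The following is a minimal sketch, not part of FlashInfer: `check_args`, `EXAMPLE_CFG`, and `run_on_gpu` are hypothetical names, and the sizes (a 2 x 3 x 4 patch grid, 8 heads each, a 44/42/42 channel split) are illustrative only. `run_on_gpu` assumes a CUDA device and a FlashInfer build that provides `flashinfer.norm.fused_qk_rmsnorm_rope` with the signature documented here; it imports lazily so the checker itself runs without a GPU.

```python
def check_args(qkv_shape, *, ppf, pph, ppw, num_frame_channels,
               num_height_channels, num_width_channels,
               num_heads_q, num_heads_k, num_heads_v, head_dim):
    """Validate the documented constraints; return expected (q, k, v) shapes."""
    if head_dim not in (64, 128, 256):
        raise ValueError("head_dim must be 64, 128, or 256")
    channels = (num_frame_channels, num_height_channels, num_width_channels)
    if any(c % 2 for c in channels):
        raise ValueError("each RoPE channel count must be even")
    if sum(channels) != head_dim:
        raise ValueError("RoPE channel counts must sum to head_dim")
    hidden = (num_heads_q + num_heads_k + num_heads_v) * head_dim
    if len(qkv_shape) not in (2, 3) or qkv_shape[-1] != hidden:
        raise ValueError(f"qkv must be 2D or 3D with last dimension {hidden}")
    seq_len = ppf * pph * ppw
    if len(qkv_shape) == 3 and qkv_shape[1] != seq_len:
        raise ValueError("3D input: seq_len must equal ppf * pph * ppw")
    if len(qkv_shape) == 2 and qkv_shape[0] % seq_len != 0:
        raise ValueError("2D input: num_tokens must be divisible by ppf * pph * ppw")
    lead = tuple(qkv_shape[:-1])
    return tuple(lead + (n, head_dim)
                 for n in (num_heads_q, num_heads_k, num_heads_v))


# Illustrative configuration (hypothetical values, chosen to satisfy the rules).
EXAMPLE_CFG = dict(
    ppf=2, pph=3, ppw=4,  # seq_len = 24
    num_frame_channels=44, num_height_channels=42, num_width_channels=42,
    num_heads_q=8, num_heads_k=8, num_heads_v=8, head_dim=128,
)


def run_on_gpu():
    """Call the fused op on random data; requires CUDA and FlashInfer."""
    import torch
    import flashinfer.norm

    hidden = 3 * 8 * 128
    qkv = torch.randn(2, 24, hidden, dtype=torch.bfloat16, device="cuda")
    q_weight = torch.ones(8 * 128, dtype=torch.bfloat16, device="cuda")
    k_weight = torch.ones(8 * 128, dtype=torch.bfloat16, device="cuda")

    expected = check_args(tuple(qkv.shape), **EXAMPLE_CFG)
    q, k, v = flashinfer.norm.fused_qk_rmsnorm_rope(
        qkv, q_weight, k_weight, **EXAMPLE_CFG
    )
    # Output shapes should match the Returns description.
    assert (tuple(q.shape), tuple(k.shape), tuple(v.shape)) == expected
    return q, k, v
```

Running the checker first turns a mismatched shape or channel split into a readable Python error rather than a failure inside the kernel call.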