flashinfer.nvfp4_attention_sm120.nvfp4_attention_sm120_fwd
- flashinfer.nvfp4_attention_sm120.nvfp4_attention_sm120_fwd(q_fp4: Tensor, k_fp4: Tensor, v_fp4_t: Tensor, q_scale: Tensor, k_scale: Tensor, v_scale_t: Tensor, qk_correction: Tensor, sm_scale: float | None = None, causal: bool = False, per_block_mean: bool = True, out: Tensor | None = None, lse: Tensor | None = None, out_dtype: dtype = torch.bfloat16, softmax_scale: float | None = None) → Tuple[Tensor, Tensor]
Run SM120 NVFP4 attention on pre-quantized Q/K/V tensors.
The packed tensors should be produced by nvfp4_attention_sm120_quantize_qkv(). q_fp4 and k_fp4 use the layout [batch, num_heads, seq_len, head_dim / 2]; v_fp4_t and v_scale_t are stored transposed as [batch, num_heads, head_dim, packed_seq_len].
- Parameters:
q_fp4 (torch.Tensor) – Packed NVFP4 Q tensor.
k_fp4 (torch.Tensor) – Packed NVFP4 K tensor.
v_fp4_t (torch.Tensor) – Packed NVFP4 V tensor, stored transposed.
q_scale (torch.Tensor) – Per-vector FP8 scale factors for Q.
k_scale (torch.Tensor) – Per-vector FP8 scale factors for K.
v_scale_t (torch.Tensor) – Per-vector FP8 scale factors for V, stored transposed.
qk_correction (torch.Tensor) – FP32 correction term returned by nvfp4_attention_sm120_quantize_qkv.
sm_scale (Optional[float], optional) – Scale applied to QK scores before softmax. Defaults to 1 / sqrt(head_dim) when omitted.
causal (bool, optional) – Whether to apply a causal mask. Defaults to False.
per_block_mean (bool, optional) – Must match the value passed to nvfp4_attention_sm120_quantize_qkv. Defaults to True.
out (Optional[torch.Tensor], optional) – Optional preallocated output buffer.
lse (Optional[torch.Tensor], optional) – Optional preallocated log-sum-exp buffer.
out_dtype (torch.dtype, optional) – Output dtype used when out is not provided. Defaults to torch.bfloat16.
softmax_scale (Optional[float], optional) – Deprecated alias for sm_scale.
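The layout and scale rules above can be sketched in plain Python. The helpers pack_fp4_codes, expected_packed_shapes and resolve_sm_scale are hypothetical and are not part of flashinfer. They assume two FP4 codes per byte with the even-indexed element in the low nibble, packed_seq_len == seq_len // 2 with no padding or alignment, and that passing both sm_scale and softmax_scale is an error; the real quantizer and kernel may differ on all three points.

```python
import math
import warnings
from typing import Dict, List, Optional, Tuple


def pack_fp4_codes(codes: List[int]) -> List[int]:
    """Pack pairs of 4-bit codes into bytes (hypothetical helper).

    Assumption: the even-indexed element goes in the low nibble.
    """
    if len(codes) % 2:
        raise ValueError("need an even number of FP4 codes")
    return [(codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4) for i in range(0, len(codes), 2)]


def expected_packed_shapes(
    batch: int, num_heads: int, seq_len: int, head_dim: int
) -> Dict[str, Tuple[int, int, int, int]]:
    """Shapes of the packed Q/K/V data tensors (hypothetical helper).

    Assumption: packed_seq_len == seq_len // 2, with no padding.
    """
    if head_dim % 2 or seq_len % 2:
        raise ValueError("head_dim and seq_len must be even")
    return {
        # Q and K are packed along head_dim: two FP4 values per byte.
        "q_fp4": (batch, num_heads, seq_len, head_dim // 2),
        "k_fp4": (batch, num_heads, seq_len, head_dim // 2),
        # V is transposed, so its packed axis runs along the sequence.
        "v_fp4_t": (batch, num_heads, head_dim, seq_len // 2),
    }


def resolve_sm_scale(
    head_dim: int, sm_scale: Optional[float] = None, softmax_scale: Optional[float] = None
) -> float:
    """Return sm_scale, else the deprecated alias, else 1 / sqrt(head_dim).

    Assumption: supplying both arguments is rejected.
    """
    if softmax_scale is not None:
        if sm_scale is not None:
            raise ValueError("pass sm_scale or softmax_scale, not both")
        warnings.warn("softmax_scale is deprecated; use sm_scale", DeprecationWarning)
        return softmax_scale
    if sm_scale is not None:
        return sm_scale
    return 1.0 / math.sqrt(head_dim)


print(expected_packed_shapes(2, 8, 128, 64)["v_fp4_t"])  # (2, 8, 64, 64)
print(resolve_sm_scale(64))  # 0.125
```

A plausible reason for the transposed V, not stated in this reference, is that it puts the packed axis on the dimension each matrix product reduces over: head_dim for Q·Kᵀ and the key sequence for P·V.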
- Returns:
Attention output and log-sum-exp tensor.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
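To make the two return values concrete, the following pure-Python sketch computes unquantized attention and its log-sum-exp for a single (batch, head) slice. It is a hypothetical reference, not the SM120 kernel: it ignores NVFP4 quantization, the scale factors and qk_correction, uses the natural logarithm for the LSE, and assumes a bottom-right aligned causal mask. The kernel's LSE base, shape and mask alignment are not specified in this reference.

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def reference_attention(
    q: Matrix,
    k: Matrix,
    v: Matrix,
    sm_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[Matrix, List[float]]:
    """Unquantized attention for one (batch, head) slice (hypothetical reference).

    q has shape [q_len][head_dim]; k and v have shape [kv_len][head_dim].
    Returns (out, lse), where lse[i] is the natural-log log-sum-exp of the
    scaled, unmasked scores of query i.
    """
    head_dim = len(q[0])
    # Same default as documented for sm_scale: 1 / sqrt(head_dim).
    scale = 1.0 / math.sqrt(head_dim) if sm_scale is None else sm_scale
    q_len, kv_len = len(q), len(k)
    if causal and q_len > kv_len:
        raise ValueError("this sketch requires q_len <= kv_len when causal")
    # Assumed bottom-right alignment: the last query sees every key.
    offset = kv_len - q_len
    out: Matrix = []
    lse: List[float] = []
    for i, qi in enumerate(q):
        last = i + offset if causal else kv_len - 1
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(last + 1)]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        denom = sum(weights)
        lse.append(m + math.log(denom))
        out.append(
            [sum(w * v[j][d] for j, w in enumerate(weights)) / denom for d in range(head_dim)]
        )
    return out, lse
```

Because lse[i] records each row's softmax normalizer, partial outputs computed over disjoint key ranges can be merged by weighting each one by exp(lse_part - lse_total); that is one common reason attention kernels return it alongside the output.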