flashinfer.nvfp4_attention_sm120.nvfp4_attention_sm120_fwd

flashinfer.nvfp4_attention_sm120.nvfp4_attention_sm120_fwd(q_fp4: Tensor, k_fp4: Tensor, v_fp4_t: Tensor, q_scale: Tensor, k_scale: Tensor, v_scale_t: Tensor, qk_correction: Tensor, sm_scale: float | None = None, causal: bool = False, per_block_mean: bool = True, out: Tensor | None = None, lse: Tensor | None = None, out_dtype: dtype = torch.bfloat16, softmax_scale: float | None = None) → Tuple[Tensor, Tensor]

Run SM120 NVFP4 attention on pre-quantized Q/K/V tensors.

The packed tensors should be produced by nvfp4_attention_sm120_quantize_qkv(). q_fp4 and k_fp4 use layout [batch, num_heads, seq_len, head_dim / 2]; v_fp4_t and v_scale_t are stored transposed as [batch, num_heads, head_dim, packed_seq_len].
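The last dimension of q_fp4 and k_fp4 is head_dim / 2 because NVFP4 stores two 4-bit E2M1 values per byte. The sketch below decodes such codes and computes the shapes implied by the layout description above. The helpers decode_e2m1, unpack_byte and expected_shapes are hypothetical and not part of FlashInfer. Two details are assumptions rather than documented behavior: that the low nibble holds the first element, and that packed_seq_len equals seq_len // 2 with no padding.

```python
# Hypothetical helpers illustrating the documented layouts; not part of FlashInfer.

# Magnitudes representable by FP4 E2M1, indexed by the low 3 bits of a code.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code (bit 3 is the sign bit) to a float."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def unpack_byte(byte: int) -> tuple:
    """Split one packed byte into two FP4 values.

    ASSUMPTION: the low nibble holds the first element; the nibble order
    used by the real kernel is not stated in this documentation.
    """
    return decode_e2m1(byte & 0x0F), decode_e2m1((byte >> 4) & 0x0F)


def expected_shapes(batch: int, num_heads: int, seq_len: int, head_dim: int) -> dict:
    """Packed tensor shapes implied by the documented layouts.

    ASSUMPTION: packed_seq_len == seq_len // 2 (two FP4 values per byte
    along the sequence axis); the real kernel may pad this dimension.
    """
    assert head_dim % 2 == 0 and seq_len % 2 == 0
    return {
        "q_fp4": (batch, num_heads, seq_len, head_dim // 2),
        "k_fp4": (batch, num_heads, seq_len, head_dim // 2),
        "v_fp4_t": (batch, num_heads, head_dim, seq_len // 2),
    }
```

For example, expected_shapes(2, 8, 128, 64) gives (2, 8, 128, 32) for q_fp4 and k_fp4, which can serve as a quick sanity check on the outputs of nvfp4_attention_sm120_quantize_qkv() before calling this function.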

Parameters:
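  • q_fp4 (torch.Tensor) – Packed NVFP4 Q tensor with layout [batch, num_heads, seq_len, head_dim / 2].

  • k_fp4 (torch.Tensor) – Packed NVFP4 K tensor with layout [batch, num_heads, seq_len, head_dim / 2].

  • v_fp4_t (torch.Tensor) – Packed NVFP4 V tensor, stored transposed with layout [batch, num_heads, head_dim, packed_seq_len].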

  • q_scale (torch.Tensor) – Per-vector FP8 scale factors for Q.

  • k_scale (torch.Tensor) – Per-vector FP8 scale factors for K.

  • v_scale_t (torch.Tensor) – Per-vector FP8 scale factors for V, stored transposed like v_fp4_t.

  • qk_correction (torch.Tensor) – FP32 correction term returned by nvfp4_attention_sm120_quantize_qkv.

  • sm_scale (Optional[float], optional) – Scale applied to QK scores before softmax. Defaults to 1 / sqrt(head_dim) when omitted.

  • causal (bool, optional) – Whether to apply a causal mask.

  • per_block_mean (bool, optional) – Must match the value used by nvfp4_attention_sm120_quantize_qkv.

  • out (Optional[torch.Tensor], optional) – Optional preallocated output buffer. When omitted, a new output tensor of dtype out_dtype is allocated.

  • lse (Optional[torch.Tensor], optional) – Optional preallocated buffer for the log-sum-exp values.

  • out_dtype (torch.dtype, optional) – Output dtype used when out is not provided.

  • softmax_scale (Optional[float], optional) – Deprecated alias for sm_scale.
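The interaction between sm_scale, its deprecated alias softmax_scale, and the 1 / sqrt(head_dim) default can be sketched as a small resolution function. The name resolve_sm_scale is hypothetical, and the precedence rule (sm_scale wins whenever it is set, softmax_scale is only consulted otherwise) is an assumption; the documentation above does not say what happens when both are passed.

```python
import math
from typing import Optional


def resolve_sm_scale(head_dim: int,
                     sm_scale: Optional[float] = None,
                     softmax_scale: Optional[float] = None) -> float:
    """Hypothetical sketch of resolving the softmax scale.

    ASSUMPTION: sm_scale takes precedence; the deprecated softmax_scale
    alias is used only when sm_scale is None.
    """
    if sm_scale is None:
        sm_scale = softmax_scale  # fall back to the deprecated alias
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)  # documented default
    return sm_scale
```

With head_dim = 64 and neither argument set, this yields 0.125. New code would pass sm_scale only.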

Returns:

A tuple (out, lse) containing the attention output and the log-sum-exp tensor.

Return type:

Tuple[torch.Tensor, torch.Tensor]
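To illustrate what the two returned tensors mean, the sketch below computes dense single-head attention in plain Python on already-dequantized floats and returns (out, lse) per query row. It is a reference for the math only, not the NVFP4 kernel: the name reference_attention is hypothetical, it ignores quantization, qk_correction and per_block_mean, assumes equal query and key lengths so that the causal mask keeps keys j <= i, and uses the natural logarithm for lse. The kernel's actual LSE base and causal alignment are not stated in this documentation.

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def reference_attention(q: Matrix, k: Matrix, v: Matrix,
                        sm_scale: Optional[float] = None,
                        causal: bool = False) -> Tuple[Matrix, List[float]]:
    """Dense single-head attention returning (out, lse).

    ASSUMPTIONS: inputs are dequantized floats, len(q) == len(k) so the
    causal mask keeps keys j <= i, and lse uses the natural log.
    """
    head_dim = len(q[0])
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)
    out: Matrix = []
    lse: List[float] = []
    for i, q_row in enumerate(q):
        # Scaled QK scores; with causal=True only keys 0..i are kept.
        scores = [
            sm_scale * sum(a * b for a, b in zip(q_row, k_row))
            for j, k_row in enumerate(k)
            if not causal or j <= i
        ]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        lse.append(m + math.log(total))  # log(sum(exp(scores)))
        # Softmax-weighted sum of the value rows that survived the mask.
        out.append([
            sum(w * v[j][d] for j, w in enumerate(weights)) / total
            for d in range(len(v[0]))
        ])
    return out, lse
```

For instance, a zero query attends uniformly, so reference_attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]) returns out [[2.0, 3.0]] and lse [log(2)]. Keeping lse alongside out is what lets partial attention results over different key ranges be merged later.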