flashinfer.gemm.gemm_fp8_nt_blockscaled

flashinfer.gemm.gemm_fp8_nt_blockscaled(a: Tensor, b: Tensor, a_scale: Tensor, b_scale: Tensor, scale_major_mode: Literal['MN', 'K'] | None = 'MN', mma_sm: int = 1, out: Tensor | None = None, out_dtype: dtype | None = None) → Tensor

Performs matrix multiplication on FP8 inputs using block scaling.

Block scaling is a special case of groupwise scaling in which the scale granularity is (128, 128, 128). See gemm_fp8_nt_groupwise() for the semantics of each parameter.
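The arithmetic implied by block scaling can be sketched in plain Python. This is only an illustrative reference, not the library's kernel: it assumes the granularity triple is ordered (M, N, K), that each scale applies to one 128-wide block, and that the scales are indexed as [row_block][k_block]. The helper name blockscaled_gemm_nt_reference and its block argument are hypothetical and exist only so the example can run on tiny inputs.

```python
def blockscaled_gemm_nt_reference(a, b, a_scale, b_scale, block=128):
    """Reference for out = dequant(a) @ dequant(b).T with per-block scales.

    a: M x K nested list, b: N x K nested list ("NT" layout: b is transposed).
    a_scale[m // block][k // block] and b_scale[n // block][k // block] are the
    assumed scale lookups; the real tensor layout depends on scale_major_mode.
    """
    M, K, N = len(a), len(a[0]), len(b)
    out = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            # Accumulate one K-block at a time; the two scales are constant
            # within a block, so they can be applied to the partial sum.
            for k0 in range(0, K, block):
                partial = sum(a[m][k] * b[n][k] for k in range(k0, min(k0 + block, K)))
                acc += partial * a_scale[m // block][k0 // block] * b_scale[n // block][k0 // block]
            out[m][n] = acc
    return out


# Tiny example with block=2 instead of 128 so the numbers are easy to follow.
a = [[1.0, 2.0, 3.0, 4.0]]                  # M=1, K=4
b = [[1.0, 1.0, 1.0, 1.0],
     [2.0, 0.0, 0.0, 1.0]]                  # N=2, K=4
a_scale = [[0.5, 2.0]]                      # one row block, two K blocks
b_scale = [[1.0, 0.25]]                     # one row block, two K blocks
result = blockscaled_gemm_nt_reference(a, b, a_scale, b_scale, block=2)
print(result)  # [[5.0, 3.0]]
```

Applying the scales to each K-block's partial sum, rather than dequantizing every element first, mirrors how scaled GEMMs are usually described, but the actual kernel may organize the computation differently.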

Parameters:
  • a (torch.Tensor) – FP8 input tensor. Shape (M, K).

  • b (torch.Tensor) – FP8 input tensor (transposed weight). Shape (N, K).

  • a_scale (torch.Tensor) – FP32 block-scale tensor for a, layout determined by scale_major_mode.

  • b_scale (torch.Tensor) – FP32 block-scale tensor for b, layout determined by scale_major_mode.

  • scale_major_mode (Optional[Literal["MN", "K"]]) – Storage order for the scale tensors. "MN" (default) places the non-contracted dimension in the major direction; "K" places the contracted dimension in the major direction.

  • mma_sm (int) – Number of SMs to fuse per MMA (1 or 2). Defaults to 1.

  • out (Optional[torch.Tensor]) – Pre-allocated output tensor of shape (M, N). If None, a new tensor is allocated.

  • out_dtype (Optional[torch.dtype]) – Output data type. If None, defaults to torch.bfloat16.
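As a rough aid for preparing a_scale and b_scale, the sketch below computes the scale-tensor shape one might expect for each scale_major_mode. It rests on assumptions that should be checked against gemm_fp8_nt_groupwise(): that "major" refers to the contiguous (last) dimension of a row-major tensor, that one scale covers a 128-element block along each dimension, and that partial blocks round up. The helper expected_scale_shape is hypothetical and not part of FlashInfer.

```python
import math


def expected_scale_shape(rows, k, scale_major_mode="MN", block=128):
    """Guess the scale-tensor shape for an FP8 operand of shape (rows, k).

    rows is M for a and N for b. Assumption: in "K" mode the K blocks are
    contiguous, giving (row_blocks, k_blocks); in "MN" mode the row blocks
    are contiguous, giving (k_blocks, row_blocks).
    """
    row_blocks = math.ceil(rows / block)
    k_blocks = math.ceil(k / block)
    if scale_major_mode == "K":
        return (row_blocks, k_blocks)
    if scale_major_mode == "MN":
        return (k_blocks, row_blocks)
    raise ValueError(f"unsupported scale_major_mode: {scale_major_mode!r}")


# a with shape (M=256, K=512) under both modes.
shape_k = expected_scale_shape(256, 512, "K")
shape_mn = expected_scale_shape(256, 512, "MN")
print(shape_k, shape_mn)  # (2, 4) (4, 2)
```

The library may instead require M, N, and K to be exact multiples of 128, in which case the rounding step never comes into play.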

Returns:

Output tensor of shape (M, N) with dtype out_dtype.

Return type:

torch.Tensor