flashinfer.gemm.gemm_fp8_nt_blockscaled
- flashinfer.gemm.gemm_fp8_nt_blockscaled(a: Tensor, b: Tensor, a_scale: Tensor, b_scale: Tensor, scale_major_mode: Literal['MN', 'K'] | None = 'MN', mma_sm: int = 1, out: Tensor | None = None, out_dtype: dtype | None = None) → Tensor
Performs matrix multiplication on FP8 inputs using block-wise scaling.
Block-wise scaling is a special case of groupwise scaling in which the scale granularity is
(128, 128, 128). See gemm_fp8_nt_groupwise() for the semantics of each parameter.
- Parameters:
  - a (torch.Tensor) – FP8 input tensor. Shape (M, K).
  - b (torch.Tensor) – FP8 input tensor (transposed weight). Shape (N, K).
  - a_scale (torch.Tensor) – FP32 block-scale tensor for a, layout determined by scale_major_mode.
  - b_scale (torch.Tensor) – FP32 block-scale tensor for b, layout determined by scale_major_mode.
  - scale_major_mode (Optional[Literal["MN", "K"]]) – Storage order for the scale tensors. "MN" (default) places the non-contracted dimension in the major direction; "K" places the contracted dimension in the major direction.
  - mma_sm (int) – Number of SMs to fuse per MMA (1 or 2). Defaults to 1.
  - out (Optional[torch.Tensor]) – Pre-allocated output tensor of shape (M, N). If None, a new tensor is allocated.
  - out_dtype (Optional[torch.dtype]) – Output data type. Defaults to torch.bfloat16.
- Returns:
  Output tensor of shape (M, N) with dtype out_dtype.
- Return type:
  torch.Tensor
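
The arithmetic behind block-wise scaling can be illustrated with a small dependency-free sketch. This is not FlashInfer code: the function name gemm_nt_blockscaled_reference and its block argument are hypothetical, the inputs are plain Python floats rather than FP8 tensors, and the scales are indexed logically as [row block][K block], so the physical layout selected by scale_major_mode is deliberately ignored. FP8 rounding, accumulation precision, and the final cast to out_dtype are also not modeled. It only shows the assumed reference semantics: each element is multiplied by the scale of the block it falls in, then a is multiplied by the transpose of b (the "NT" layout, with b stored as (N, K)).

```python
def gemm_nt_blockscaled_reference(a, b, a_scale, b_scale, block=128):
    """Hypothetical reference: dequantize by block scale, then compute a @ b.T.

    a is an M x K nested list and b is an N x K nested list (b is indexed [n][k]).
    a_scale[i // block][p // block] and b_scale[j // block][p // block] use
    logical (row-block, K-block) indexing; storage order is not modeled here.
    """
    m, k = len(a), len(a[0])
    n = len(b)
    if any(len(row) != k for row in b):
        raise ValueError("a and b must share the contracted dimension K")

    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                sa = a_scale[i // block][p // block]  # scale of a's tile
                sb = b_scale[j // block][p // block]  # scale of b's tile
                acc += (a[i][p] * sa) * (b[j][p] * sb)
            out[i][j] = acc
    return out


# Toy sizes with block=2 (instead of 128) so each tile is easy to inspect.
a = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]   # M=2, K=4
b = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0]]   # N=2, K=4
a_scale = [[0.5, 2.0]]       # one row block, two K blocks
b_scale = [[1.0, 0.25]]      # one row block, two K blocks

result = gemm_nt_blockscaled_reference(a, b, a_scale, b_scale, block=2)
print(result)  # [[2.0, 3.0], [6.0, 7.0]]
```

A loop-based reference like this is only practical for tiny inputs, but it can serve as a mental model when checking that the scale tensors passed to the real kernel line up with the intended blocks.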