flashinfer.gemm.mm_mxfp8¶
- flashinfer.gemm.mm_mxfp8(a: Tensor, b: Tensor, a_descale: Tensor, b_descale: Tensor, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, use_8x4_sf_layout: bool = False, backend: Literal['cutlass', 'cute-dsl', 'trtllm', 'auto'] = 'auto') → Tensor¶
Matrix multiplication for MXFP8-quantized tensors (block size 32).
- Parameters:
a (torch.Tensor) – Input A tensor, shape (m, k), mxfp8 e4m3.
b (torch.Tensor) – Input B tensor, shape (k, n), should be column major, mxfp8 e4m3.
a_descale (torch.Tensor) – Block scale tensor for A, dtype uint8. Can be:
- 2D non-swizzled: shape (m, k // 32)
- 1D swizzled: shape (M_padded * K_padded,), where M_padded = round_up(m, 8 if 8x4 layout else 128) and K_padded = round_up(k // 32, 4)
b_descale (torch.Tensor) – Block scale tensor for B, dtype uint8. Can be:
- 2D non-swizzled: shape (k // 32, n) - transposed format
- 1D swizzled: shape (N_padded * K_padded,), where N_padded = round_up(n, 128) and K_padded = round_up(k // 32, 4)
Note: For the 2D format, this is the transposed version (typically passed as scale.t()). For the 1D swizzled format, it is flattened from the (N_padded, K_padded) layout.
out (Optional[torch.Tensor]) – Output tensor, shape (m, n), bf16 or fp16. If provided, it can only be used with the CUTLASS backend. Defaults to None.
out_dtype (torch.dtype) – Output dtype, bf16 or fp16. Defaults to torch.bfloat16.
use_8x4_sf_layout (bool) – Whether the scale tensors for a are in the 8x4 layout (vs 128x4). Defaults to False.
backend (Literal["cutlass", "cute-dsl", "trtllm", "auto"]) – The backend to use for the operation. Defaults to "auto", which selects the CUTLASS backend.
- The "cute-dsl" backend currently requires swizzled 1D scales (mxfp8_quantize(..., is_sf_swizzled_layout=True)).
- The "trtllm" backend requires b to be quantized with the 128x4 swizzle layout and shuffled; a can be quantized with either the 128x4 or 8x4 layout (controlled by use_8x4_sf_layout).
- On SM12x GPUs, the "cutlass" backend only supports 1D swizzled scales (SfLayout.layout_128x4). Passing 2D linear scales will raise an error. Use mxfp8_quantize(..., sf_swizzle_layout=SfLayout.layout_128x4).
- Returns:
out – Output tensor, shape (m, n), bf16 or fp16.
- Return type:
torch.Tensor
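The padded sizes that determine the length of a 1D swizzled scale tensor follow directly from the shapes documented above. As a minimal sketch (the `round_up` and `swizzled_sf_numel` helpers below are illustrative, not part of flashinfer's API):

```python
def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple.
    return ((x + multiple - 1) // multiple) * multiple

def swizzled_sf_numel(rows: int, k: int, use_8x4_sf_layout: bool = False) -> int:
    # Number of uint8 elements in a 1D swizzled scale tensor for a
    # (rows, k) operand with MXFP8 block size 32:
    #   rows_padded = round_up(rows, 8 if 8x4 layout else 128)
    #   k_padded    = round_up(k // 32, 4)
    row_pad = 8 if use_8x4_sf_layout else 128
    rows_padded = round_up(rows, row_pad)
    k_padded = round_up(k // 32, 4)
    return rows_padded * k_padded

# For the example shapes (m, n, k) = (512, 256, 128):
print(swizzled_sf_numel(512, 128))  # a scales, 128x4 layout: 512 * 4 = 2048
print(swizzled_sf_numel(256, 128))  # b scales, 128x4 layout: 256 * 4 = 1024
```

For B, only the 128x4 layout applies, so `use_8x4_sf_layout` stays False; the 8x4 option is relevant only for A's scales.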
Examples
>>> import torch
>>> from flashinfer import mxfp8_quantize, mm_mxfp8
>>> m, n, k = 512, 256, 128
>>> # Create input tensors - note: weight is [n, k] for typical NN layers
>>> a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
>>> weight = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
>>>
>>> # Option 1: Use swizzled layout (recommended for accuracy)
>>> # Quantize input [m, k] - scales are 1D swizzled for (M, K/32) layout
>>> a_mx, a_sf = mxfp8_quantize(input=a, is_sf_swizzled_layout=True)
>>> # Quantize weight [n, k] - scales are 1D swizzled for (N, K/32) layout
>>> w_mx, w_sf = mxfp8_quantize(input=weight, is_sf_swizzled_layout=True)
>>> # Pass weight.T as [k, n] and 1D swizzled scales directly
>>> out = mm_mxfp8(a_mx, w_mx.t(), a_sf, w_sf, out_dtype=torch.bfloat16)
>>> out.shape
torch.Size([512, 256])
>>>
>>> # Option 2: Use non-swizzled layout (for compatibility)
>>> a_mx, a_sf = mxfp8_quantize(input=a, is_sf_swizzled_layout=False)
>>> w_mx, w_sf = mxfp8_quantize(input=weight, is_sf_swizzled_layout=False)
>>> # For non-swizzled: reshape to 2D and transpose weight scale to (k // 32, n)
>>> a_sf_2d = a_sf.view(m, k // 32)
>>> w_sf_2d = w_sf.view(n, k // 32).t()  # Transpose to (k // 32, n)
>>> out = mm_mxfp8(a_mx, w_mx.t(), a_sf_2d, w_sf_2d, out_dtype=torch.bfloat16)
>>> out.shape
torch.Size([512, 256])