flashinfer.gemm.mm_mxfp8

flashinfer.gemm.mm_mxfp8(a: Tensor, b: Tensor, a_descale: Tensor, b_descale: Tensor, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, use_8x4_sf_layout: bool = False, backend: Literal['cutlass', 'cute-dsl', 'trtllm', 'auto'] = 'auto') -> Tensor

MXFP8 matrix multiplication (block scale size 32).

Parameters:
  • a (torch.Tensor) – Input A tensor, shape (m, k), mxfp8 e4m3.

  • b (torch.Tensor) – Input B tensor, shape (k, n), should be column major, mxfp8 e4m3.

  • a_descale (torch.Tensor) –

    Block scale tensor for A, dtype uint8. Can be:

    - 2D non-swizzled: shape (m, k // 32)

    - 1D swizzled: shape (M_padded * K_padded,), where M_padded = round_up(m, 8 if 8x4 layout else 128) and K_padded = round_up(k // 32, 4).

  • b_descale (torch.Tensor) –

    Block scale tensor for B, dtype uint8. Can be:

    - 2D non-swizzled: shape (k // 32, n) - the transposed format, typically passed as scale.t()

    - 1D swizzled: shape (N_padded * K_padded,), flattened from a (N_padded, K_padded) layout, where N_padded = round_up(n, 128) and K_padded = round_up(k // 32, 4).

  • out (Optional[torch.Tensor]) – Output tensor, shape (m, n), bf16 or fp16. If provided, only the CUTLASS backend can be used. Defaults to None.

  • out_dtype (torch.dtype) – Output dtype, bf16 or fp16. Defaults to torch.bfloat16.

  • use_8x4_sf_layout (bool) – Whether the scale tensors for a are in 8x4 layout (vs 128x4).

  • backend (Literal["cutlass", "cute-dsl", "trtllm", "auto"]) –

    The backend to use for the operation. Defaults to "auto", which selects the CUTLASS backend.

    • The "cute-dsl" backend currently requires 1D swizzled scales (mxfp8_quantize(..., is_sf_swizzled_layout=True)).

    • The "trtllm" backend requires b to be quantized with the 128x4 swizzle layout and shuffled; a can be quantized with either the 128x4 or the 8x4 layout (controlled by use_8x4_sf_layout).

    • On SM12x GPUs, the "cutlass" backend only supports 1D swizzled scales (SfLayout.layout_128x4); passing 2D linear scales will raise an error. Use mxfp8_quantize(..., sf_swizzle_layout=SfLayout.layout_128x4).
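The padding rules for the 1D swizzled scale layouts above can be sketched in plain Python. This is an illustrative sketch only: round_up and the two sizing helpers are hypothetical names, not part of the flashinfer API.

```python
# Sketch of the 1D swizzled scale-tensor sizing rules described above.
# round_up, a_descale_numel, and b_descale_numel are hypothetical helpers
# for illustration; they are not part of the flashinfer API.

def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest positive multiple."""
    return ((x + multiple - 1) // multiple) * multiple

def a_descale_numel(m: int, k: int, use_8x4_sf_layout: bool = False) -> int:
    """Element count of a 1D swizzled scale tensor for A (block size 32)."""
    m_padded = round_up(m, 8 if use_8x4_sf_layout else 128)
    k_padded = round_up(k // 32, 4)
    return m_padded * k_padded

def b_descale_numel(n: int, k: int) -> int:
    """Element count of a 1D swizzled scale tensor for B (block size 32)."""
    n_padded = round_up(n, 128)
    k_padded = round_up(k // 32, 4)
    return n_padded * k_padded

# For m, n, k = 512, 256, 128 (the shapes used in the examples below):
print(a_descale_numel(512, 128))  # 512 * 4 = 2048
print(b_descale_numel(256, 128))  # 256 * 4 = 1024
```

Note that padding makes the 1D swizzled tensors larger than the logical (m, k // 32) scale grid whenever m or n is not already a multiple of the row padding, which is why the 1D layout cannot simply be viewed as a 2D tensor.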

Returns:

out – Output tensor, shape (m, n), bf16 or fp16.

Return type:

torch.Tensor

Examples

>>> import torch
>>> from flashinfer import mxfp8_quantize, mm_mxfp8
>>> m, n, k = 512, 256, 128
>>> # Create input tensors - note: weight is [n, k] for typical NN layers
>>> a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
>>> weight = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
>>>
>>> # Option 1: Use swizzled layout (recommended for accuracy)
>>> # Quantize input [m, k] - scales are 1D swizzled for (M, K/32) layout
>>> a_mx, a_sf = mxfp8_quantize(input=a, is_sf_swizzled_layout=True)
>>> # Quantize weight [n, k] - scales are 1D swizzled for (N, K/32) layout
>>> w_mx, w_sf = mxfp8_quantize(input=weight, is_sf_swizzled_layout=True)
>>> # Pass weight.T as [k, n] and 1D swizzled scales directly
>>> out = mm_mxfp8(a_mx, w_mx.t(), a_sf, w_sf, out_dtype=torch.bfloat16)
>>> out.shape
torch.Size([512, 256])
>>>
>>> # Option 2: Use non-swizzled layout (for compatibility)
>>> a_mx, a_sf = mxfp8_quantize(input=a, is_sf_swizzled_layout=False)
>>> w_mx, w_sf = mxfp8_quantize(input=weight, is_sf_swizzled_layout=False)
>>> # For non-swizzled: reshape to 2D and transpose weight scale to (k//32, n)
>>> a_sf_2d = a_sf.view(m, k // 32)
>>> w_sf_2d = w_sf.view(n, k // 32).t()  # Transpose to (k // 32, n)
>>> out = mm_mxfp8(a_mx, w_mx.t(), a_sf_2d, w_sf_2d, out_dtype=torch.bfloat16)
>>> out.shape
torch.Size([512, 256])