flashinfer.gemm.mm_fp4

flashinfer.gemm.mm_fp4(a: Tensor, b: Tensor, a_descale: Tensor, b_descale: Tensor, alpha: Tensor | None = None, out_dtype: dtype = torch.bfloat16, out: Tensor | None = None, block_size: int = 16, use_8x4_sf_layout: bool = False, backend: Literal['cudnn', 'trtllm', 'cutlass', 'auto'] = 'auto', use_nvfp4: bool = True) → Tensor

Matrix multiplication (MM) with FP4-quantized inputs.

Parameters:
  • a (torch.Tensor) – Input tensor, shape (m, k), fp4 e2m1fn_x2 or uint8.

  • b (torch.Tensor) – Mat2 tensor, shape (k, n), should be column major, fp4 e2m1fn_x2 or uint8.

  • a_descale (torch.Tensor) – Block scale tensor for A, shape (m, k // block_size), float8_e4m3fn or uint8.

  • b_descale (torch.Tensor) – Block scale tensor for B, shape (k // block_size, n), float8_e4m3fn or uint8.

  • alpha (Optional[torch.Tensor]) – Global scale factor, a scalar float tensor, defaults to None.

  • out_dtype (torch.dtype) – Output dtype, bf16 or fp16. When backend="trtllm", only bf16 is supported.

  • out (Optional[torch.Tensor]) – Output tensor, shape (m, n), bf16 or fp16, defaults to None.

  • block_size (int) – Block size for FP4 quantization; only 16 and 32 are supported: 16 for nvfp4 quantization, 32 for mxfp4 quantization. Defaults to 16. The resulting scale-tensor shapes are illustrated in the sketch after this list.

  • use_8x4_sf_layout (bool) – Whether to use the 8x4 scale factor layout (True) or the 128x4 scale factor layout (False) for a's scale factors; see Notes for per-backend requirements. Defaults to False.

  • backend (Literal["cudnn", "trtllm", "cutlass", "auto"]) – Backend to use, defaults to "auto", which automatically selects the best backend between "cudnn" and "cutlass" based on the current CUDA and cuDNN versions. The "trtllm" backend is never selected when backend="auto" because it requires different weight preparation.

  • use_nvfp4 (bool) – Whether to use nvfp4 quantization or mxfp4 quantization, defaults to True. See the block_size parameter for related constraints.
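
As a quick orientation for the shape parameters above, here is how the documented logical shapes work out for a concrete problem size. This is plain-Python arithmetic over the shapes listed in the parameter descriptions; the packed fp4 data and the swizzled scale-factor layouts returned by the quantization helpers may use different physical storage shapes.

>>> m, n, k, block_size = 48, 256, 128, 16  # block_size 16 corresponds to nvfp4
>>> (m, k), (k, n)            # logical shapes of a and (column-major) b
((48, 128), (128, 256))
>>> (m, k // block_size)      # logical shape of a_descale
(48, 8)
>>> (k // block_size, n)      # logical shape of b_descale
(8, 256)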

Notes

When the cudnn or cutlass backend is used, both a and b should be quantized with nvfp4_quantize using the 128x4 scale factor layout and do_shuffle=False (see the second example below). When the trtllm backend is used, b must be quantized with the 128x4 layout and do_shuffle=True, while a can be quantized with either the 128x4 or the 8x4 layout (controlled by use_8x4_sf_layout) and do_shuffle=False.

Returns:

out – Output tensor, shape (m, n), bf16 or fp16.

Return type:

torch.Tensor

Examples

>>> import torch
>>> from flashinfer import nvfp4_quantize, mm_fp4, SfLayout
>>> a = torch.randn([48, 128], device="cuda", dtype=torch.bfloat16)
>>> b = torch.randn([256, 128], device="cuda", dtype=torch.bfloat16)
>>> a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
>>> b_global_sf = (448 * 6) / b.float().abs().nan_to_num().max()
>>> a_fp4, a_sf = nvfp4_quantize(a, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False)
>>> b_fp4, b_sf = nvfp4_quantize(b, b_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=True)
>>> out = mm_fp4(a_fp4, b_fp4.T, a_sf, b_sf.T, 1.0/(a_global_sf * b_global_sf), torch.bfloat16, None, backend="trtllm")
>>> out.shape
torch.Size([48, 256])
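
The example above targets the trtllm backend, which is why b is quantized with do_shuffle=True. For the cudnn/cutlass path described in the Notes, a minimal sketch (reusing the tensors and global scale factors from above, and assuming the cutlass backend is available on the current device) quantizes both operands with the 128x4 layout and do_shuffle=False:

>>> a_fp4, a_sf = nvfp4_quantize(a, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False)
>>> b_fp4, b_sf = nvfp4_quantize(b, b_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False)
>>> out = mm_fp4(a_fp4, b_fp4.T, a_sf, b_sf.T, 1.0 / (a_global_sf * b_global_sf), torch.bfloat16, None, backend="cutlass")
>>> out.shape
torch.Size([48, 256])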