flashinfer.gemm.mm_fp4

flashinfer.gemm.mm_fp4(a: torch.Tensor, b: torch.Tensor, a_descale: torch.Tensor, b_descale: torch.Tensor, alpha: torch.Tensor, out_dtype: torch.dtype, out: torch.Tensor | None = None, block_size: int = 16, use_8x4_sf_layout: bool = False, backend: Literal['cudnn', 'trtllm', 'cutlass'] = 'cudnn') → torch.Tensor

FP4 block-scaled matrix multiplication: computes an (m, n) product of FP4 (e2m1) quantized a and b, applying the per-block descale factors a_descale and b_descale and the global scale alpha.

Parameters:
  • a (torch.Tensor) – Input tensor, shape (m, k), fp4 e2m1fn_x2 or uint8.

  • b (torch.Tensor) – Mat2 tensor, shape (k, n), should be column major, fp4 e2m1fn_x2 or uint8.

  • a_descale (torch.Tensor) – Block scale tensor for A, shape (m, k // block_size), float8_e4m3fn or uint8.

  • b_descale (torch.Tensor) – Block scale tensor for B, shape (k, n // block_size), float8_e4m3fn or uint8.

  • alpha (torch.Tensor) – Global scale factor, a float scalar tensor; typically 1.0 / (a_global_sf * b_global_sf), as in the Examples below.

  • out_dtype (torch.dtype) – Output dtype, bf16 or fp16.

  • out (Optional[torch.Tensor]) – Output tensor, shape (m, n), bf16 or fp16; if None, a new tensor is allocated. Defaults to None (see the preallocation sketch after this list).

  • block_size (int) – Block size for FP4 quantization, only 16 is supported.

  • use_8x4_sf_layout (bool) – Whether to use the 8x4 scale factor layout (True) or the 128x4 scale factor layout (False) for a; only relevant for the trtllm backend (see Notes). Defaults to False.

  • backend (Literal["cudnn", "trtllm", "cutlass"]) – Backend to use, defaults to "cudnn".
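
The out parameter lets the result be written into a preallocated buffer. A minimal sketch, assuming the quantized operands a_fp4, b_fp4, a_sf, b_sf and the global scales prepared as in the Examples section below:

>>> out_buf = torch.empty([48, 256], device="cuda", dtype=torch.bfloat16)
>>> out = mm_fp4(a_fp4, b_fp4.T, a_sf, b_sf.T, 1.0/(a_global_sf * b_global_sf), torch.bfloat16, out=out_buf, backend="trtllm")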

Notes

When the cudnn or cutlass backend is used, both a and b should be quantized with nvfp4_quantize using the 128x4 scale factor layout and do_shuffle=False. When the trtllm backend is used, b must be quantized with the 128x4 layout and do_shuffle=True, while a can be quantized with either the 128x4 or the 8x4 layout (controlled by use_8x4_sf_layout), in both cases with do_shuffle=False. These recipes are sketched below.
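
A compact sketch of the two recipes, using the same nvfp4_quantize / SfLayout API as in the Examples below (SfLayout.layout_8x4 is assumed to be the 8x4 counterpart of SfLayout.layout_128x4):

>>> # cudnn / cutlass: both operands in the 128x4 layout, unshuffled
>>> a_fp4, a_sf = nvfp4_quantize(a, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False)
>>> b_fp4, b_sf = nvfp4_quantize(b, b_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False)
>>> # trtllm: b shuffled in the 128x4 layout; a unshuffled, in either the 128x4
>>> # layout (use_8x4_sf_layout=False) or the 8x4 layout (use_8x4_sf_layout=True)
>>> b_fp4, b_sf = nvfp4_quantize(b, b_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=True)
>>> a_fp4, a_sf = nvfp4_quantize(a, a_global_sf, sfLayout=SfLayout.layout_8x4, do_shuffle=False)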

Returns:

out – Output tensor, shape (m, n), bf16 or fp16.

Return type:

torch.Tensor

Examples

>>> import torch
>>> from flashinfer import nvfp4_quantize, mm_fp4, SfLayout
>>> a = torch.randn([48, 128], device="cuda", dtype=torch.bfloat16)
>>> b = torch.randn([256, 128], device="cuda", dtype=torch.bfloat16)
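>>> # Global scale: 448 is the max finite float8_e4m3fn value, 6 is the max fp4 e2m1 magnitude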
>>> a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
>>> b_global_sf = (448 * 6) / b.float().abs().nan_to_num().max()
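>>> # trtllm backend: a unshuffled; b quantized in the 128x4 layout with do_shuffle=True (see Notes)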
>>> a_fp4, a_sf = nvfp4_quantize(a, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False)
>>> b_fp4, b_sf = nvfp4_quantize(b, b_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=True)
>>> out = mm_fp4(a_fp4, b_fp4.T, a_sf, b_sf.T, 1.0/(a_global_sf * b_global_sf), torch.bfloat16, None, backend="trtllm")
>>> out.shape
torch.Size([48, 256])
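
A variant of the same computation with the cudnn backend; per the Notes, b is re-quantized with do_shuffle=False (a sketch reusing the setup above; the expected shape is unchanged):

>>> b_fp4_c, b_sf_c = nvfp4_quantize(b, b_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False)
>>> out = mm_fp4(a_fp4, b_fp4_c.T, a_sf, b_sf_c.T, 1.0/(a_global_sf * b_global_sf), torch.bfloat16, None, backend="cudnn")
>>> out.shape
torch.Size([48, 256])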