flashinfer.gemm.mm_fp4
- flashinfer.gemm.mm_fp4(a: torch.Tensor, b: torch.Tensor, a_descale: torch.Tensor, b_descale: torch.Tensor, alpha: torch.Tensor, out_dtype: torch.dtype, out: torch.Tensor | None = None, block_size: int = 16, use_8x4_sf_layout: bool = False, backend: Literal['cudnn', 'trtllm', 'cutlass'] = 'cudnn') → torch.Tensor
MM FP4: matrix multiplication with block-scaled FP4 (e2m1) operands.
- Parameters:
a (torch.Tensor) – Input tensor, shape (m, k), fp4 e2m1fn_x2 or uint8.
b (torch.Tensor) – Mat2 tensor, shape (k, n), should be column major, fp4 e2m1fn_x2 or uint8.
a_descale (torch.Tensor) – Block scale tensor for A, shape (m, k // block_size), float8_e4m3fn or uint8.
b_descale (torch.Tensor) – Block scale tensor for B, shape (k // block_size, n), float8_e4m3fn or uint8.
alpha (torch.Tensor) – Global scale, a scalar float tensor; typically 1.0 / (a_global_sf * b_global_sf), as in the example below.
out_dtype (torch.dtype) – Output dtype, bf16 or fp16.
out (Optional[torch.Tensor]) – Out tensor, shape (m, n), bf16 or fp16, defaults to None.
block_size (int) – Block size for FP4 quantization, only 16 is supported.
use_8x4_sf_layout (bool) – Whether to use the 8x4 scale factor layout for a instead of the 128x4 layout (trtllm backend only; see Notes), defaults to False.
backend (Literal["cudnn", "trtllm", "cutlass"]) – Backend to use, defaults to “cudnn”.
Notes
When the cudnn or cutlass backend is used, both a and b should be quantized with nvfp4_quantize using the 128x4 scale factor layout and do_shuffle=False. When the trtllm backend is used, b must be quantized with the 128x4 layout and do_shuffle=True, while a can be quantized with either the 128x4 or the 8x4 layout (controlled by use_8x4_sf_layout) and do_shuffle=False. A sketch of the cudnn/cutlass path follows.
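As a minimal, unverified sketch of the cudnn/cutlass path, assuming a, b, and the global scale factors are prepared exactly as in the Examples section below; the only difference from the trtllm example there is that b is quantized with do_shuffle=False:
>>> a_fp4, a_sf = nvfp4_quantize(a, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False)
>>> b_fp4, b_sf = nvfp4_quantize(b, b_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False)
>>> out = mm_fp4(a_fp4, b_fp4.T, a_sf, b_sf.T, 1.0 / (a_global_sf * b_global_sf), torch.bfloat16, None, backend="cudnn")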
- Returns:
out – Out tensor, shape (m, n), bf16 or fp16.
- Return type:
torch.Tensor
Examples
>>> import torch
>>> from flashinfer import nvfp4_quantize, mm_fp4, SfLayout
>>> a = torch.randn([48, 128], device="cuda", dtype=torch.bfloat16)
>>> b = torch.randn([256, 128], device="cuda", dtype=torch.bfloat16)
>>> a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
>>> b_global_sf = (448 * 6) / b.float().abs().nan_to_num().max()
>>> a_fp4, a_sf = nvfp4_quantize(a, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False)
>>> b_fp4, b_sf = nvfp4_quantize(b, b_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=True)
>>> out = mm_fp4(a_fp4, b_fp4.T, a_sf, b_sf.T, 1.0/(a_global_sf * b_global_sf), torch.bfloat16, None, backend="trtllm")
>>> out.shape
torch.Size([48, 256])
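With the trtllm backend, a can instead use the 8x4 scale factor layout. A hedged sketch, reusing b_fp4 and b_sf from the example above and assuming SfLayout exposes a layout_8x4 member matching the use_8x4_sf_layout parameter:
>>> # Sketch only: a in the 8x4 layout (trtllm backend); b is unchanged from above.
>>> a_fp4_8x4, a_sf_8x4 = nvfp4_quantize(a, a_global_sf, sfLayout=SfLayout.layout_8x4, do_shuffle=False)
>>> out = mm_fp4(a_fp4_8x4, b_fp4.T, a_sf_8x4, b_sf.T, 1.0 / (a_global_sf * b_global_sf), torch.bfloat16, None, use_8x4_sf_layout=True, backend="trtllm")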