flashinfer.gemm.bmm_mxfp8

flashinfer.gemm.bmm_mxfp8(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor, dtype: dtype, out: Tensor | None = None, backend: Literal['cudnn', 'cutlass', 'auto'] = 'auto') → Tensor

Batched matrix multiplication (BMM) with MXFP8 (microscaling FP8) inputs.

Parameters:
  • A (torch.Tensor) – Input tensor, shape (b, m, k), fp8 e4m3 or fp8 e5m2.

  • B (torch.Tensor) – Mat2 tensor, shape (b, k, n), must be column-major, fp8 e4m3 or fp8 e5m2.

  • A_scale (torch.Tensor) – Scale tensor for A, uint8 (fp8 e8m0 format).

  • B_scale (torch.Tensor) – Scale tensor for B, uint8 (fp8 e8m0 format).

  • dtype (torch.dtype) – Output dtype, bf16 or fp16.

  • out (Optional[torch.Tensor]) – Output tensor, shape (b, m, n), bf16 or fp16. Defaults to None.

  • backend (Literal["cudnn", "cutlass", "auto"]) – The backend to use for the operation; defaults to "auto". On SM120/121 GPUs, "auto" selects the CUTLASS backend, which requires the scale tensors in a 1D swizzled layout (SfLayout.layout_128x4). Pass B in the standard (b, k, n) column-major shape; the CUTLASS path transposes it internally.
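
The e8m0 scale encoding used by A_scale and B_scale can be illustrated in plain Python. This is an illustrative sketch, not flashinfer code, and the helper names (e8m0_to_float, float_to_e8m0) are hypothetical: e8m0 is an 8-bit exponent-only format with bias 127, so a uint8 byte x represents the power-of-two scale 2**(x - 127).

```python
import math

# Hypothetical helpers (not part of flashinfer) illustrating the e8m0
# scale format: 8 exponent bits, no mantissa, bias 127.

def e8m0_to_float(x: int) -> float:
    """Decode a uint8 e8m0 byte into its power-of-two scale."""
    return 2.0 ** (x - 127)

def float_to_e8m0(scale: float) -> int:
    """Encode a positive power-of-two scale as a uint8 e8m0 byte."""
    e = int(math.log2(scale))
    assert 0 <= e + 127 <= 255, "exponent out of e8m0 range"
    return e + 127

print(e8m0_to_float(127))   # 1.0 (bias 127 encodes 2**0)
print(float_to_e8m0(0.25))  # 125
```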

Returns:

out – Output tensor, shape (b, m, n), bf16 or fp16.

Return type:

torch.Tensor
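
For intuition on where the scale tensors come from, the sketch below follows the OCP Microscaling (MX) convention rather than flashinfer's internals, which this page does not specify: elements are grouped into blocks of 32 along the reduction dimension, and each block stores one shared e8m0 scale chosen from the block's absolute maximum so the rescaled values fit in fp8 e4m3. The function name block_scale and the exponent-selection rule are assumptions for illustration.

```python
import math

BLOCK = 32         # MX block size along the reduction (k) dimension
E4M3_MAX = 448.0   # largest finite fp8 e4m3 magnitude

def block_scale(block: list[float]) -> int:
    """Return a uint8 e8m0 scale byte for one block of values (illustrative)."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 127  # scale 1.0 for an all-zero block
    # Pick a power-of-two scale so amax / scale fits in e4m3 range.
    e = math.floor(math.log2(amax)) - math.floor(math.log2(E4M3_MAX))
    return e + 127

row = [0.5] * BLOCK
print(block_scale(row))  # 118, i.e. scale 2**-9
```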