flashinfer.gemm.bmm_mxfp8¶
- flashinfer.gemm.bmm_mxfp8(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor, dtype: dtype, out: Tensor | None = None, backend: Literal['cudnn', 'cutlass', 'auto'] = 'auto') Tensor¶
Batched matrix multiplication (BMM) with MXFP8 block-scaled inputs.
- Parameters:
A (torch.Tensor) – Input tensor, shape (b, m, k), fp8 e4m3 or fp8 e5m2.
B (torch.Tensor) – Mat2 tensor, shape (b, k, n), should be column major, fp8 e4m3 or fp8 e5m2.
A_scale (torch.Tensor) – Scale tensor for A, uint8 (fp8 e8m0 format).
B_scale (torch.Tensor) – Scale tensor for B, uint8 (fp8 e8m0 format).
dtype (torch.dtype) – out dtype, bf16 or fp16.
out (Optional[torch.Tensor]) – Output tensor, shape (b, m, n), bf16 or fp16. Defaults to None.
backend (Literal["cudnn", "cutlass", "auto"]) – The backend to use for the operation. Defaults to "auto". On SM120/121 GPUs, "auto" selects the CUTLASS backend; scales must be 1D swizzled (SfLayout.layout_128x4). Pass B in the standard shape [b, k, n] (column-major); the CUTLASS path transposes internally.
- Returns:
out – Output tensor, shape (b, m, n), bf16 or fp16.
- Return type:
torch.Tensor
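The A_scale and B_scale tensors hold one uint8 byte per block of elements, in fp8 e8m0 format: an exponent-only encoding whose value is 2**(byte - 127). As a minimal sketch of how such a scale byte can be derived for a block of values (following the common convention from the OCP Microscaling spec for e4m3 elements; the helper name and the plain-list input are illustrative, not part of flashinfer's API):

```python
import math

E8M0_BIAS = 127      # e8m0 value = 2**(byte - 127); byte 255 is reserved for NaN
E4M3_MAX_EXP = 8     # largest power-of-two exponent in fp8 e4m3 (max finite value = 448)

def e8m0_scale_for_block(block):
    """Illustrative helper: compute the shared uint8 e8m0 scale byte for one
    block of values, so that block[i] / 2**(byte - 127) fits in fp8 e4m3."""
    max_abs = max(abs(x) for x in block)
    if max_abs == 0:
        return E8M0_BIAS  # all-zero block: use a scale of 1.0
    shared_exp = math.floor(math.log2(max_abs)) - E4M3_MAX_EXP
    # clamp to the encodable e8m0 range (bytes 0..254)
    return min(max(shared_exp + E8M0_BIAS, 0), 254)

# A block whose largest magnitude is exactly the e4m3 max (448) needs no
# rescaling, so its scale byte encodes 1.0 (byte 127).
print(e8m0_scale_for_block([448.0, -2.0, 0.5]))  # → 127
```

In practice these bytes are produced by a quantization kernel (one byte per 32-element block along k) and then laid out in the swizzled SfLayout.layout_128x4 order expected by the CUTLASS backend; the sketch above only shows the per-block exponent math.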