flashinfer.gemm.mm_bf16

flashinfer.gemm.mm_bf16(a: Tensor, b: Tensor, bias: Tensor | None = None, pdl: bool = False, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, backend: Literal['cudnn', 'cutlass', 'tgv', 'cublaslt', 'tinygemm', 'auto'] = 'cudnn') Tensor

BF16 matrix multiplication (GEMM): computes ``a @ b`` with an optional bias.

Parameters:
  • a (torch.Tensor) – Input tensor, shape (m, k), bf16 in row-major layout.

  • b (torch.Tensor) – Weight tensor, shape (k, n), bf16 in column-major layout.

  • bias (Optional[torch.Tensor]) – Optional bias tensor, shape (n,). Supported only by the TGV and TinyGEMM backends. Defaults to None.

  • pdl (bool) – Whether to use Programmatic Dependent Launch (PDL). Supported only by the TGV and TinyGEMM backends. Defaults to False.

  • out (Optional[torch.Tensor]) – Optional output tensor, shape (m, n), bf16, fp16, or fp32. FP16 and FP32 outputs are supported by the CUTLASS and cuDNN backends; TinyGEMM requires bf16 output. Defaults to None, in which case a new tensor is allocated.

  • out_dtype (torch.dtype) – Output dtype, bf16, fp16, or fp32. Supported by the CUTLASS, cuDNN, and cuBLASLt backends. Defaults to torch.bfloat16.

  • backend (Literal["cudnn", "cutlass", "tgv", "cublaslt", "tinygemm", "auto"]) – The backend to use for the operation. Defaults to "cudnn".

      • "cudnn" – the cuDNN backend.

      • "cutlass" – the CUTLASS backend.

      • "tgv" – the TGV backend.

      • "cublaslt" – the cuBLASLt backend with heuristic algorithm search.

      • "tinygemm" – the TinyGEMM backend for small-M BF16 GEMM.

      • "auto" – selects the best tactic from all available backends when autotuning is enabled.

Returns:

Output tensor, shape (m, n), bf16, fp16, or fp32 in row-major layout.

Return type:

torch.Tensor
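Semantically, the call computes ``out = a @ b``, with the bias (when given) added to each row of the result. A minimal pure-Python sketch of that contract (illustrative only; the real kernels run on the GPU in bf16, and ``mm_ref`` is not part of the flashinfer API):

```python
# Pure-Python reference for the mm_bf16 contract: out = a @ b (+ bias).
# Illustrative only -- the actual backends run fused GPU kernels in bf16.

def mm_ref(a, b, bias=None):
    """a: m x k, b: k x n (nested lists); bias: length-n list or None."""
    m, k, n = len(a), len(b), len(b[0])
    # Plain triple-loop matmul expressed as a comprehension.
    out = [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
           for i in range(m)]
    if bias is not None:
        # Bias has shape (n,) and is broadcast across the m rows.
        out = [[out[i][j] + bias[j] for j in range(n)] for i in range(m)]
    return out

a = [[1.0, 2.0],
     [3.0, 4.0]]        # (m, k) = (2, 2)
b = [[5.0, 6.0, 7.0],
     [8.0, 9.0, 10.0]]  # (k, n) = (2, 3)
bias = [0.5, 0.5, 0.5]

out = mm_ref(a, b, bias)
# out[0][0] = 1*5 + 2*8 + 0.5 = 21.5
```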

Examples

>>> import torch
>>> import flashinfer
>>> # Using the TGV backend
>>> a = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
>>> b = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> bias = torch.randn([80], device="cuda", dtype=torch.bfloat16)
>>> out = flashinfer.mm_bf16(a, b, bias=bias, pdl=True, backend="tgv")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16
>>> # Using the CUTLASS backend
>>> fp16_out = torch.empty([48, 80], device="cuda", dtype=torch.float16)
>>> out = flashinfer.mm_bf16(a, b, out=fp16_out, out_dtype=torch.float16, backend="cutlass")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.float16
>>> # Using the cuDNN backend
>>> out = flashinfer.mm_bf16(a, b, backend="cudnn")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16
>>> # Using the cuBLASLt backend
>>> out = flashinfer.mm_bf16(a, b, backend="cublaslt")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16