flashinfer.gemm.mm_bf16

flashinfer.gemm.mm_bf16(a: Tensor, b: Tensor, bias: Tensor | None = None, pdl: bool = False, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, backend: Literal['cudnn', 'cutlass', 'tgv', 'auto'] = 'cudnn') → Tensor

BF16 matrix multiplication: computes out = a @ b (optionally plus bias) with bfloat16 inputs.

Parameters:
  • a (torch.Tensor) – Input tensor, shape (m, k), bf16 in row-major layout.

  • b (torch.Tensor) – Weight tensor, shape (k, n), bf16 in column-major layout (e.g. obtained by transposing a row-major (n, k) tensor; see the sketch after this parameter list).

  • bias (Optional[torch.Tensor]) – Optional bias tensor, shape (n,). Only supported by the TGV backend. Defaults to None.

  • pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL). Only supported by the TGV backend. Defaults to False.

  • out (Optional[torch.Tensor]) – Output tensor, shape (m, n), bf16 or fp16. Only supported by the CUTLASS backend. Defaults to None.

  • out_dtype (torch.dtype) – Output dtype, bf16 or fp16. Supported by the CUTLASS and cuDNN backends. Defaults to torch.bfloat16.

  • backend (Literal["cudnn", "cutlass", "tgv", "auto"]) – The backend to use for the operation. Defaults to "cudnn". "cudnn", "cutlass", and "tgv" select the cuDNN, CUTLASS, and TGV backends respectively; "auto" selects the best tactic from all available backends when autotuning is enabled (see the sketch at the end of the Examples section).
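
The layout and dtype contract above can be sanity-checked against a plain PyTorch reference. A minimal sketch (the tolerances are illustrative, not taken from the library's test suite):

>>> import torch
>>> import flashinfer
>>> m, k, n = 48, 64, 80
>>> a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)  # row-major
>>> # Column-major (k, n): transpose a row-major (n, k) tensor
>>> b = torch.randn([n, k], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> out = flashinfer.mm_bf16(a, b)
>>> ref = a.float() @ b.float()  # fp32 reference
>>> torch.allclose(out.float(), ref, atol=1e-1, rtol=1e-2)  # loose bf16 tolerances
True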

Returns:

Output tensor, shape (m, n), bf16 or fp16, in row-major layout.

Return type:

torch.Tensor

Examples

>>> import torch
>>> import flashinfer
>>> # Using the TGV backend
>>> a = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
>>> b = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> bias = torch.randn([80], device="cuda", dtype=torch.bfloat16)
>>> out = flashinfer.mm_bf16(a, b, bias=bias, pdl=True, backend="tgv")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16
>>> # Using the CUTLASS backend
>>> fp16_out = torch.empty([48, 80], device="cuda", dtype=torch.float16)
>>> out = flashinfer.mm_bf16(a, b, out=fp16_out, out_dtype=torch.float16, backend="cutlass")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.float16
>>> # Using the cuDNN backend
>>> out = flashinfer.mm_bf16(a, b, backend="cudnn")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16
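
The "auto" backend selects among the backends above when autotuning is active. A hedged sketch follows, continuing the session above; it assumes flashinfer exposes an autotune context manager (the exact import path is an assumption, so consult the autotuning documentation):

>>> # Using the "auto" backend under autotuning (autotune import is an assumption)
>>> from flashinfer import autotune
>>> with autotune():
...     out = flashinfer.mm_bf16(a, b, backend="auto")
>>> out.shape
torch.Size([48, 80])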