flashinfer.gemm.mm_bf16

flashinfer.gemm.mm_bf16(a: Tensor, b: Tensor, bias: Tensor | None = None, pdl: bool = False, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, backend: Literal['cutlass', 'tgv', 'auto'] = 'tgv') → Tensor

BF16 matrix multiplication: computes out = a @ b (plus bias, if provided) for bf16 inputs.

Parameters:
  • a (torch.Tensor) – Input tensor, shape (m, k), bf16 in row-major layout.

  • b (torch.Tensor) – Weight tensor, shape (k, n), bf16 in column-major layout.

  • bias (Optional[torch.Tensor]) – Optional bias tensor, shape (n,). If provided, can only be used with the TGV backend. Defaults to None.

  • pdl (bool) – Whether to enable Programmatic Dependent Launch (PDL). Can only be used with the TGV backend. Defaults to False.

  • out (Optional[torch.Tensor]) – Output tensor, shape (m, n), bf16 or fp16. Can only be provided when using the CUTLASS backend. Defaults to None.

  • out_dtype (torch.dtype) – Output dtype, bf16 or fp16. A non-default value (torch.float16) is only supported by the CUTLASS backend. Defaults to torch.bfloat16.

  • backend (Literal["cutlass", "tgv", "auto"]) – The backend to use for the operation. Defaults to "tgv". "auto" allows selecting the best tactic from all available backends when autotune is enabled.

Returns:

Output tensor, shape (m, n), bf16 or fp16, in row-major layout.

Return type:

torch.Tensor

Examples

>>> import torch
>>> import flashinfer
>>> # Using the TGV backend
>>> a = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
>>> b = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> bias = torch.randn([80], device="cuda", dtype=torch.bfloat16)
>>> out = flashinfer.mm_bf16(a, b, bias=bias, pdl=True, backend="tgv")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16
>>> # Using the CUTLASS backend
>>> fp16_out = torch.empty([48, 80], device="cuda", dtype=torch.float16)
>>> out = flashinfer.mm_bf16(a, b, out=fp16_out, out_dtype=torch.float16, backend="cutlass")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.float16
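
The "auto" backend can also be called directly. The sketch below is a minimal illustration that reuses the a and b tensors from above and only checks the shape and dtype of the result; which tactic "auto" picks depends on whether autotuning is enabled, as described for the backend parameter.

>>> # Using the "auto" backend (reuses a and b from above); when autotuning is
>>> # enabled, the best tactic is selected from all available backends.
>>> out = flashinfer.mm_bf16(a, b, backend="auto")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16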