flashinfer.gemm.mm_bf16
- flashinfer.gemm.mm_bf16(a: Tensor, b: Tensor, bias: Tensor | None = None, pdl: bool = False, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, backend: Literal['cutlass', 'tgv', 'auto'] = 'tgv') → Tensor
BF16 matrix multiplication: computes a @ b with an optional bias.
- Parameters:
a (torch.Tensor) – Input tensor, shape (m, k), bf16 in row-major layout.
b (torch.Tensor) – Weight tensor, shape (k, n), bf16 in column-major layout.
bias (Optional[torch.Tensor]) – Optional bias tensor, shape (n,). If provided, can only be used with the TGV backend. Defaults to None.
pdl (bool) – Whether to enable PDL (Programmatic Dependent Launch). Can only be used with the TGV backend. Defaults to False.
out (Optional[torch.Tensor]) – Output tensor, shape (m, n), bf16 or fp16. If provided, can only be used with the CUTLASS backend. Defaults to None.
out_dtype (torch.dtype) – Output dtype, bf16 or fp16. If provided, can only be used with the CUTLASS backend. Defaults to torch.bfloat16.
backend (Literal["cutlass", "tgv", "auto"]) – The backend to use for the operation. "auto" allows selecting the best tactic from all available backends when autotuning is enabled. Defaults to "tgv".
- Returns:
Output tensor, shape (m, n), bf16 or fp16 in row-major layout.
- Return type:
torch.Tensor
Examples
>>> import torch
>>> import flashinfer
>>> # Using the TGV backend
>>> a = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
>>> b = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> bias = torch.randn([80], device="cuda", dtype=torch.bfloat16)
>>> out = flashinfer.mm_bf16(a, b, bias=bias, pdl=True, backend="tgv")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16
>>> # Using the CUTLASS backend
>>> fp16_out = torch.empty([48, 80], device="cuda", dtype=torch.float16)
>>> out = flashinfer.mm_bf16(a, b, out=fp16_out, out_dtype=torch.float16, backend="cutlass")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.float16
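As a quick sanity check, the output can be compared against a plain PyTorch reference. This is a sketch rather than part of the official examples; the tolerances below are assumptions chosen to absorb bf16 rounding error:

>>> # Hypothetical correctness check against a PyTorch reference
>>> # (atol/rtol are assumed values, loose enough for bf16 rounding)
>>> out = flashinfer.mm_bf16(a, b, bias=bias, backend="tgv")
>>> ref = a @ b + bias
>>> torch.testing.assert_close(out, ref, atol=2e-1, rtol=2e-2)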