flashinfer.gemm.mm_bf16
- flashinfer.gemm.mm_bf16(a: Tensor, b: Tensor, bias: Tensor | None = None, pdl: bool = False, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, backend: Literal['cudnn', 'cutlass', 'tgv', 'auto'] = 'cudnn') → Tensor
BF16 matrix multiplication.
- Parameters:
a (torch.Tensor) – Input tensor, shape (m, k), bf16 in row-major layout.
b (torch.Tensor) – Weight tensor, shape (k, n), bf16 in column-major layout.
bias (Optional[torch.Tensor]) – Optional bias tensor, shape (n,). Enabled for the TGV backend. Defaults to None.
pdl (bool) – Whether to use Programmatic Dependent Launch (PDL). Enabled for the TGV backend. Defaults to False.
out (Optional[torch.Tensor]) – Output tensor, shape (m, n), bf16 or fp16. Enabled for the CUTLASS backend. Defaults to None.
out_dtype (torch.dtype) – Output dtype, bf16 or fp16. Enabled for the CUTLASS and cuDNN backends. Defaults to torch.bfloat16.
backend (Literal["cudnn", "cutlass", "tgv", "auto"]) – The backend to use for the operation. Defaults to "cudnn". "cudnn" uses the cuDNN backend; "cutlass" uses the CUTLASS backend; "tgv" uses the TGV backend; "auto" selects the best tactic from all available backends when autotuning is enabled (see the sketch after this list).
- Returns:
Output tensor, shape (m, n), bf16 or fp16 in row-major layout.
- Return type:
torch.Tensor
Examples
>>> import torch
>>> import flashinfer
>>> # Using the TGV backend
>>> a = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
>>> b = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> bias = torch.randn([80], device="cuda", dtype=torch.bfloat16)
>>> out = flashinfer.mm_bf16(a, b, bias=bias, pdl=True, backend="tgv")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16
>>> # Using the CUTLASS backend
>>> fp16_out = torch.empty([48, 80], device="cuda", dtype=torch.float16)
>>> out = flashinfer.mm_bf16(a, b, out=fp16_out, out_dtype=torch.float16, backend="cutlass")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.float16
>>> # Using the cuDNN backend
>>> out = flashinfer.mm_bf16(a, b, backend="cudnn")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16
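As a sanity check, the result can be compared against a plain PyTorch reference computed in fp32. The tolerances below are loose illustrative assumptions for bf16 accumulation, not documented guarantees.
>>> # Reference GEMM in fp32, rounded back to bf16 (reuses a, b, bias above).
>>> ref = (a.float() @ b.float() + bias.float()).to(torch.bfloat16)
>>> out = flashinfer.mm_bf16(a, b, bias=bias, pdl=True, backend="tgv")
>>> # Assumed tolerances; tighten or loosen for your problem shape.
>>> torch.testing.assert_close(out.float(), ref.float(), atol=1e-1, rtol=1e-2)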