flashinfer.gemm.mm_bf16
- flashinfer.gemm.mm_bf16(a: Tensor, b: Tensor, bias: Tensor | None = None, pdl: bool = False, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, backend: Literal['cudnn', 'cutlass', 'tgv', 'cublaslt', 'tinygemm', 'auto'] = 'cudnn') → Tensor
BF16 matrix multiplication with optional bias addition.
- Parameters:
a (torch.Tensor) – Input tensor, shape (m, k), bf16 in row-major layout.
b (torch.Tensor) – Weight tensor, shape (k, n), bf16 in column-major layout.
bias (Optional[torch.Tensor]) – Optional bias tensor, shape (n,). Enabled for the TGV and TinyGEMM backends. Defaults to None.
pdl (bool) – Whether to use Programmatic Dependent Launch. Enabled for the TGV and TinyGEMM backends. Defaults to False.
out (Optional[torch.Tensor]) – Output tensor, shape (m, n), bf16, fp16, or fp32. FP16 and FP32 output are enabled for the CUTLASS and cuDNN backends; TinyGEMM requires bf16 output. Defaults to None.
out_dtype (torch.dtype) – Output dtype, bf16, fp16, or fp32. Enabled for the CUTLASS, cuDNN, and cuBLASLt backends. Defaults to torch.bfloat16.
backend (Literal["cudnn", "cutlass", "tgv", "cublaslt", "tinygemm", "auto"]) – The backend to use for the operation. Defaults to "cudnn". "cudnn" uses the cuDNN backend; "cutlass" uses the CUTLASS backend; "tgv" uses the TGV backend; "cublaslt" uses the cuBLASLt backend with heuristic algorithm search; "tinygemm" uses the TinyGEMM backend for small-M BF16 GEMM; "auto" selects the best tactic from all available backends when autotuning is enabled.
- Returns:
Output tensor, shape (m, n), bf16, fp16, or fp32, in row-major layout.
- Return type:
torch.Tensor
Examples
>>> import torch
>>> import flashinfer
>>> # Using the TGV backend
>>> a = torch.randn([48, 64], device="cuda", dtype=torch.bfloat16)
>>> b = torch.randn([80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> bias = torch.randn([80], device="cuda", dtype=torch.bfloat16)
>>> out = flashinfer.mm_bf16(a, b, bias=bias, pdl=True, backend="tgv")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16
>>> # Using the CUTLASS backend
>>> fp16_out = torch.empty([48, 80], device="cuda", dtype=torch.float16)
>>> out = flashinfer.mm_bf16(a, b, out=fp16_out, out_dtype=torch.float16, backend="cutlass")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.float16
>>> # Using the cuDNN backend
>>> out = flashinfer.mm_bf16(a, b, backend="cudnn")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16
>>> # Using the cuBLASLt backend
>>> out = flashinfer.mm_bf16(a, b, backend="cublaslt")
>>> out.shape
torch.Size([48, 80])
>>> out.dtype
torch.bfloat16
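The remaining two backends can be exercised the same way. The snippet below is an illustrative sketch rather than part of the rendered doctests above: the TinyGEMM call follows directly from the signature (small m, bf16 output, optional bias and PDL), while the autotuning context manager flashinfer.autotune shown with the "auto" backend is an assumption and may be named or invoked differently in your flashinfer version.

>>> # Using the TinyGEMM backend (small m; output must be bf16)
>>> small_a = torch.randn([4, 64], device="cuda", dtype=torch.bfloat16)
>>> out = flashinfer.mm_bf16(small_a, b, bias=bias, pdl=True, backend="tinygemm")
>>> out.shape
torch.Size([4, 80])
>>> # Using the "auto" backend: selects the best tactic across backends.
>>> # Assumes an autotuning context manager named flashinfer.autotune
>>> # (check your flashinfer version for the exact autotuning entry point).
>>> with flashinfer.autotune():
...     out = flashinfer.mm_bf16(a, b, backend="auto")
>>> out.shape
torch.Size([48, 80])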