flashinfer.gemm.bmm_bf16

flashinfer.gemm.bmm_bf16(A: Tensor, B: Tensor, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, backend: Literal['cutlass'] = 'cutlass') → torch.Tensor

Batched matrix multiplication (BMM) with BF16 (bfloat16) inputs.

Parameters:
  • A (torch.Tensor) – Input tensor, shape (b, m, k), bf16 in row-major layout.

  • B (torch.Tensor) – Weight tensor, shape (b, k, n), bf16 in column-major layout.

  • out (Optional[torch.Tensor]) – Output tensor, shape (b, m, n), bf16 or fp16. Defaults to None, in which case a new output tensor is allocated.

  • out_dtype (torch.dtype) – Output dtype, bf16 (default) or fp16.

  • backend (Literal["cutlass"]) – Backend to use, defaults to "cutlass".

Returns:

Output tensor, shape (b, m, n), bf16 or fp16 in row-major layout.

Return type:

torch.Tensor

Examples

>>> import torch
>>> import flashinfer
>>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
>>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> out = flashinfer.bmm_bf16(input, weight)
>>> out.shape
torch.Size([16, 48, 80])
>>> out.dtype
torch.bfloat16
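
As a shape and dtype sanity check, the same batched multiply can be sketched with plain torch.bmm. This is a minimal CPU reference for the semantics (out[i] = A[i] @ B[i] in bf16), not the CUTLASS kernel, and it does not require flashinfer or a GPU:

```python
import torch

# Reference semantics of bmm_bf16: per-batch matmul in bf16.
b, m, k, n = 16, 48, 64, 80
A = torch.randn(b, m, k, dtype=torch.bfloat16)  # row-major, shape (b, m, k)
# Column-major (b, k, n) operand, built the same way as in the example above.
B = torch.randn(b, n, k, dtype=torch.bfloat16).transpose(-2, -1)
out = torch.bmm(A, B)  # shape (b, m, n) = (16, 48, 80), dtype bf16
```

The transpose trick yields the column-major layout the CUTLASS backend expects for B while torch.bmm itself accepts either layout.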