flashinfer.gemm.bmm_bf16
- flashinfer.gemm.bmm_bf16(A: Tensor, B: Tensor, out: Tensor | None = None, out_dtype: dtype = torch.bfloat16, backend: Literal['cutlass'] = 'cutlass') → Tensor
Batched matrix multiplication (BMM) with bfloat16 inputs.
- Parameters:
A (torch.Tensor) – Input tensor, shape (b, m, k), bf16 in row-major layout.
B (torch.Tensor) – Weight tensor, shape (b, k, n), bf16 in column-major layout.
out (Optional[torch.Tensor]) – Output tensor, shape (b, m, n), bf16 or fp16. Defaults to None.
out_dtype (torch.dtype) – Output dtype, bf16 (default) or fp16.
backend (Literal["cutlass"]) – Backend to use, defaults to “cutlass”.
- Returns:
Output tensor, shape (b, m, n), bf16 or fp16 in row-major layout.
- Return type:
torch.Tensor
Examples
>>> import torch
>>> import flashinfer
>>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
>>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> out = flashinfer.bmm_bf16(input, weight)
>>> out.shape
torch.Size([16, 48, 80])
>>> out.dtype
torch.bfloat16
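For readers without a CUDA device, the computation and layout convention can be sketched with plain torch.bmm on CPU. This is only an illustration of the expected shapes and of how B's column-major layout arises from a transposed view; it is not the flashinfer kernel itself, and the shape values are just the ones used in the example above.

```python
import torch

# Reference semantics of bmm_bf16, sketched with torch.bmm on CPU.
# bmm_bf16 itself runs on CUDA via the CUTLASS backend; this snippet
# only mirrors its shape and layout contract.
b, m, k, n = 16, 48, 64, 80

# A: row-major (b, m, k), bf16.
A = torch.randn(b, m, k, dtype=torch.bfloat16)

# B: column-major (b, k, n), obtained as a transposed view of a
# contiguous (b, n, k) tensor -- the same trick as in the doctest.
B = torch.randn(b, n, k, dtype=torch.bfloat16).transpose(-2, -1)

out = torch.bmm(A, B)  # shape (b, m, n) = (16, 48, 80), dtype bf16
```

Passing a transposed view (rather than calling .contiguous() on it) is what gives B its column-major strides, which is the layout the documented signature expects.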