flashinfer.gemm.bmm_fp8

flashinfer.gemm.bmm_fp8(A: torch.Tensor, B: torch.Tensor, A_scale: torch.Tensor, B_scale: torch.Tensor, dtype: torch.dtype, out: torch.Tensor | None = None) → torch.Tensor

Batched matrix multiplication (BMM) of scaled FP8 tensors, producing a bf16 or fp16 result.

Parameters:
  • A (torch.Tensor) – First input tensor, shape (b, m, k), fp8 e4m3 (torch.float8_e4m3fn) or fp8 e5m2 (torch.float8_e5m2).

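  • B (torch.Tensor) – Second input tensor (mat2), shape (b, k, n), fp8 e4m3 or fp8 e5m2. It should be column-major, for example obtained by transposing a contiguous (b, n, k) tensor with transpose(-2, -1), as in the example below.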

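  • A_scale (torch.Tensor) – Scale tensor for A, torch.float32. The example below passes the dequantization scale, i.e. the reciprocal of the scale used to quantize A.

  • B_scale (torch.Tensor) – Scale tensor for B, torch.float32, with the same convention as A_scale.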

  • dtype (torch.dtype) – Output dtype, torch.bfloat16 or torch.float16.

  • out (Optional[torch.Tensor]) – Preallocated output tensor, shape (b, m, n), bf16 or fp16. Defaults to None, in which case a new output tensor is allocated.
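Whether B has the expected layout can be read off its strides. The sketch below assumes that "column major" for a (b, k, n) tensor means stride 1 along the k dimension (dim -2), which is exactly what transposing a contiguous (b, n, k) tensor produces. `is_column_major` is a hypothetical helper for illustration, not part of FlashInfer, and the sketch runs on CPU.

```python
import torch

def is_column_major(b: torch.Tensor) -> bool:
    # Hypothetical helper (not a FlashInfer API). Assumes "column major" for a
    # (batch, k, n) tensor means k varies fastest: stride 1 along dim -2 and
    # stride k along dim -1.
    _, k, _ = b.shape
    return b.stride(-2) == 1 and b.stride(-1) == k

# Build B by transposing a contiguous (batch, n, k) tensor, as the example does.
batch, k, n = 16, 64, 80
b_row = torch.randn(batch, n, k)
B = b_row.transpose(-2, -1)  # shape (16, 64, 80); a view, no data is copied
print(B.shape, B.stride(), is_column_major(B))  # torch.Size([16, 64, 80]) (5120, 1, 64) True
print(is_column_major(B.contiguous()))  # False: contiguous() gives row-major strides
```

Because transpose returns a view, preparing a column-major B this way costs nothing; calling contiguous() on it would undo the layout.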

Returns:

out – Output tensor, shape (b, m, n), with the requested dtype (bf16 or fp16).

Return type:

torch.Tensor

Examples

>>> import torch
>>> import flashinfer
>>> def to_float8(x, dtype=torch.float8_e4m3fn):
...     finfo = torch.finfo(dtype)
...     min_val, max_val = x.aminmax()
...     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
...     scale = finfo.max / amax
...     x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
...     return x_scl_sat.to(dtype), scale.float().reciprocal()
>>>
>>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
>>> input_fp8, input_inv_s = to_float8(input, dtype=torch.float8_e4m3fn)
>>> # column-major weight: transpose a contiguous (b, n, k) tensor to shape (b, k, n)
>>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> weight_fp8, weight_inv_s = to_float8(weight, dtype=torch.float8_e4m3fn)
>>> out = flashinfer.bmm_fp8(input_fp8, weight_fp8, input_inv_s, weight_inv_s, torch.bfloat16)
>>> out.shape
torch.Size([16, 48, 80])
>>> out.dtype
torch.bfloat16