flashinfer.gemm.bmm_fp8
- flashinfer.gemm.bmm_fp8(A: torch.Tensor, B: torch.Tensor, A_scale: torch.Tensor, B_scale: torch.Tensor, dtype: torch.dtype, out: torch.Tensor | None = None) → torch.Tensor
Batched matrix multiplication (BMM) with FP8 inputs.
- Parameters:
A (torch.Tensor) – Input tensor, shape (b, m, k), fp8 e4m3 or fp8 e5m2.
B (torch.Tensor) – Mat2 tensor, shape (b, k, n), should be column major, fp8 e4m3 or fp8 e5m2.
A_scale (torch.Tensor) – Scale tensor for A, float.
B_scale (torch.Tensor) – Scale tensor for B, float.
dtype (torch.dtype) – Output dtype, bf16 or fp16.
out (Optional[torch.Tensor]) – Output tensor, shape (b, m, n), bf16 or fp16. Defaults to None.
- Returns:
out – Output tensor, shape (b, m, n), bf16 or fp16.
- Return type:
torch.Tensor
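The scales act as per-tensor dequantization factors: the helper in the example below quantizes with scale = finfo.max / amax and passes bmm_fp8 the reciprocal, so the result corresponds to dequantizing both operands and running an ordinary batched matmul. A minimal PyTorch reference sketch of that semantics (an assumption for illustration, not the library's actual implementation):

import torch

def bmm_fp8_reference(A, B, A_scale, B_scale, dtype):
    # Dequantize the fp8 operands with their per-tensor float scales,
    # then run an ordinary batched matmul in float32 and cast to dtype.
    a = A.to(torch.float32) * A_scale
    b = B.to(torch.float32) * B_scale
    return torch.bmm(a, b).to(dtype)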
Examples
>>> import torch
>>> import torch.nn.functional as F
>>> import flashinfer
>>> # quantize to fp8, returning the tensor and its dequantization (inverse) scale
>>> def to_float8(x, dtype=torch.float8_e4m3fn):
...     finfo = torch.finfo(dtype)
...     min_val, max_val = x.aminmax()
...     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
...     scale = finfo.max / amax
...     x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
...     return x_scl_sat.to(dtype), scale.float().reciprocal()
>>>
>>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
>>> input_fp8, input_inv_s = to_float8(input, dtype=torch.float8_e4m3fn)
>>> # column-major weight: allocate (b, n, k) contiguously, then transpose to (b, k, n)
>>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
>>> weight_fp8, weight_inv_s = to_float8(weight, dtype=torch.float8_e4m3fn)
>>> out = flashinfer.bmm_fp8(input_fp8, weight_fp8, input_inv_s, weight_inv_s, torch.bfloat16)
>>> out.shape
torch.Size([16, 48, 80])
>>> out.dtype
torch.bfloat16
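Per the signature, a pre-allocated output buffer of the documented shape and dtype can be passed via out to avoid an extra allocation; a sketch continuing the example above:

>>> out_buf = torch.empty([16, 48, 80], device="cuda", dtype=torch.bfloat16)
>>> out = flashinfer.bmm_fp8(input_fp8, weight_fp8, input_inv_s, weight_inv_s, torch.bfloat16, out=out_buf)
>>> out.shape
torch.Size([16, 48, 80])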