flashinfer.gemm.mm_fp8

flashinfer.gemm.mm_fp8(a: Tensor, b: Tensor, alpha: Tensor | None = None, out_dtype: dtype = torch.bfloat16, out: Tensor | None = None, backend: Literal['trtllm_low_latency'] = 'trtllm_low_latency')

FP8 matrix multiplication.

Parameters:
  • a (torch.Tensor) – Input tensor, shape (m, k), fp8 e4m3.

  • b (torch.Tensor) –

    • When using the “trtllm_low_latency” backend: weight tensor, shape (k // block_size, n, block_size), fp8 e4m3. b must be pre-processed with prepare_low_latency_gemm_weights. block_size is 128 for e4m3.

  • alpha (Optional[torch.Tensor]) – Scale tensor for the output, float. If None, defaults to 1.0 for no scaling.

  • out_dtype (torch.dtype) – Output tensor data type. Default is torch.bfloat16.

  • out (Optional[torch.Tensor]) – Output tensor, shape (m, n). If None, a new tensor will be allocated.

  • backend (Literal["trtllm_low_latency"]) – Backend to use for computation. Default is “trtllm_low_latency”.

    • “trtllm_low_latency”: optimized for small M dimension.
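
The shape given for b above can be checked with a few lines of arithmetic. The following is a minimal sketch, not part of the library: expected_prepared_shape is a hypothetical helper, and the requirement that k be a multiple of block_size is an assumption inferred from the k // block_size term in the documented shape.

```python
BLOCK_SIZE = 128  # block size stated above for fp8 e4m3


def expected_prepared_shape(n: int, k: int, block_size: int = BLOCK_SIZE) -> tuple:
    """Hypothetical helper: documented shape of b for an (n, k) weight matrix."""
    # Assumption: k must divide evenly into blocks, otherwise the
    # (k // block_size, n, block_size) layout would drop elements.
    if k % block_size != 0:
        raise ValueError(f"k={k} is not a multiple of block_size={block_size}")
    return (k // block_size, n, block_size)


shape = expected_prepared_shape(n=2560, k=32768)
print(shape)  # (256, 2560, 128)
```

The blocked shape holds the same number of elements as the original (n, k) weight, so it can serve as a quick sanity check on a prepared tensor before calling mm_fp8.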

Returns:

Output tensor of shape (m, n) with dtype out_dtype.

Return type:

torch.Tensor

Examples

>>> import torch
>>> from flashinfer import mm_fp8, prepare_low_latency_gemm_weights
>>> m = 16
>>> n = 2560
>>> k = 32768
>>> a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
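>>> # to_float8 is not defined in this snippet; it is assumed to be a user-supplied
>>> # per-tensor quantization helper that returns (fp8_tensor, inverse_scale).
>>> a_fp8, a_inv_s = to_float8(a, dtype=torch.float8_e4m3fn)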
>>> b = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
>>> b_fp8, b_inv_s = to_float8(b, dtype=torch.float8_e4m3fn)
>>> prepared_b = prepare_low_latency_gemm_weights(b_fp8)
>>> alpha = a_inv_s * b_inv_s
>>> out = mm_fp8(a_fp8, prepared_b, alpha)
>>> out.shape
torch.Size([16, 2560])
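
The choice alpha = a_inv_s * b_inv_s in the example can be motivated with plain arithmetic. The sketch below is illustrative only and rests on two assumptions: alpha multiplies the accumulated product, and each input is scaled per tensor so that its largest magnitude maps to 448, the largest finite e4m3 value. It uses ordinary Python floats and ignores fp8 rounding, so it shows only why the scales cancel, not the quantization error. per_tensor_scale is a hypothetical helper.

```python
E4M3_MAX = 448.0  # largest finite value representable in fp8 e4m3 (fn variant)


def per_tensor_scale(values):
    """Hypothetical helper: return (scale, inverse_scale) for one tensor."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax
    return scale, 1.0 / scale


a = [0.5, -2.0, 1.0]   # one row of the activation matrix
b = [4.0, 0.25, -1.0]  # one row of the weight matrix

s_a, a_inv_s = per_tensor_scale(a)
s_b, b_inv_s = per_tensor_scale(b)

# Scaled inputs stand in for the fp8 tensors (rounding is ignored here).
a_scaled = [v * s_a for v in a]
b_scaled = [v * s_b for v in b]

# The raw dot product of the scaled inputs is too large by a factor s_a * s_b.
raw = sum(x * y for x, y in zip(a_scaled, b_scaled))

# Multiplying by alpha = a_inv_s * b_inv_s undoes both scales.
alpha = a_inv_s * b_inv_s
recovered = alpha * raw
exact = sum(x * y for x, y in zip(a, b))
```

Here recovered matches exact up to floating-point error, which is the reason the product of the two inverse scales is passed as alpha.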