flashinfer.gemm.mm_fp8
- flashinfer.gemm.mm_fp8(a: Tensor, b: Tensor, alpha: Tensor | None = None, out_dtype: dtype = torch.bfloat16, out: Tensor | None = None, backend: Literal['trtllm_low_latency'] = 'trtllm_low_latency')
FP8 matrix multiplication.
- Parameters:
a (torch.Tensor) – Input tensor, shape (m, k), fp8 e4m3.
b (torch.Tensor) –
Weight tensor, fp8 e4m3. When using the “trtllm_low_latency” backend, b has shape (k // block_size, n, block_size) and must be pre-processed with prepare_low_latency_gemm_weights; block_size is 128 for e4m3.
alpha (Optional[torch.Tensor]) – Float scale tensor applied to the output, typically the product of the inverse quantization scales of a and b (as in the example below). If None, defaults to 1.0 (no scaling).
out_dtype (torch.dtype) – Output tensor data type. Default is torch.bfloat16.
out (Optional[torch.Tensor]) – Output tensor, shape (m, n). If None, a new tensor will be allocated.
backend (Literal["trtllm_low_latency"]) – Backend to use for computation. Default is “trtllm_low_latency”.
  - “trtllm_low_latency”: optimized for a small M dimension (the number of rows of a).
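The blocked shape required for b by the “trtllm_low_latency” backend can be illustrated in plain Python. This sketch assumes that (k // block_size, n, block_size) means the K dimension of an (n, k) weight (the orientation used in the example) is cut into contiguous blocks of block_size columns. low_latency_weight_shape and block_k_major are hypothetical helpers, and the real prepare_low_latency_gemm_weights may additionally reorder or shuffle data for the kernel, so always use that function, not this sketch, to prepare actual weights.

```python
# Sketch of the blocked weight layout described for the "trtllm_low_latency"
# backend. Both helpers are hypothetical illustrations, not flashinfer APIs.


def low_latency_weight_shape(n: int, k: int, block_size: int = 128) -> tuple:
    """Shape (k // block_size, n, block_size) for an (n, k) weight."""
    if k % block_size != 0:
        raise ValueError(f"k={k} must be a multiple of block_size={block_size}")
    return (k // block_size, n, block_size)


def block_k_major(w: list, block_size: int) -> list:
    """Naive re-layout of an (n, k) row-major weight into K-blocks.

    out[kb][i][j] == w[i][kb * block_size + j]. This only reproduces the
    documented shape; the real preprocessing may also reorder elements.
    """
    n, k = len(w), len(w[0])
    num_blocks, _, _ = low_latency_weight_shape(n, k, block_size)
    return [
        [w[i][kb * block_size:(kb + 1) * block_size] for i in range(n)]
        for kb in range(num_blocks)
    ]


if __name__ == "__main__":
    # Shapes from the example: n = 2560, k = 32768, block_size = 128.
    print(low_latency_weight_shape(2560, 32768))  # (256, 2560, 128)
    # Tiny case: a 2 x 4 weight split into K-blocks of width 2.
    print(block_k_major([[0, 1, 2, 3], [4, 5, 6, 7]], block_size=2))
    # [[[0, 1], [4, 5]], [[2, 3], [6, 7]]]
```

The divisibility check mirrors the documented shape: k // block_size only describes the whole weight when k is a multiple of block_size.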
- Returns:
Output tensor of shape (m, n) with dtype out_dtype.
- Return type:
torch.Tensor
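The example sets alpha = a_inv_s * b_inv_s. Assuming the kernel computes out = alpha * (a_fp8 @ b_fp8^T), with b given as (n, k) before preprocessing (which is what that choice of alpha implies), the per-tensor scales cancel as in the pure-Python sketch below. FP8 rounding is deliberately left out so the scale algebra holds up to ordinary floating-point error; matmul_nt, per_tensor_scale and the scale variable names are illustrative, not flashinfer APIs.

```python
import math

E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude


def matmul_nt(a, b):
    """Reference C = A @ B^T for row-major lists: a is (m, k), b is (n, k)."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


def per_tensor_scale(x):
    """Scale that maps the largest |value| of x onto E4M3_MAX."""
    amax = max(abs(v) for row in x for v in row)
    return E4M3_MAX / max(amax, 1e-12)


a = [[0.5, -1.0, 2.0], [3.0, 0.25, -0.75]]   # (m=2, k=3)
b = [[1.0, 2.0, -0.5], [-1.5, 0.5, 4.0]]     # (n=2, k=3)

a_s, b_s = per_tensor_scale(a), per_tensor_scale(b)
a_q = [[v * a_s for v in row] for row in a]  # stand-in for a_fp8 (no rounding)
b_q = [[v * b_s for v in row] for row in b]  # stand-in for b_fp8 (no rounding)
alpha = (1.0 / a_s) * (1.0 / b_s)            # plays the role of a_inv_s * b_inv_s

out = [[alpha * c for c in row] for row in matmul_nt(a_q, b_q)]
ref = matmul_nt(a, b)
print(out)  # matches ref up to floating-point error
print(ref)
```

With real FP8 inputs the result only approximates a @ b^T, because each quantized value carries e4m3 rounding error on top of this scale algebra.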
Examples
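>>> import torch
>>> from flashinfer import mm_fp8, prepare_low_latency_gemm_weights
>>> # Helper assumed by this example: per-tensor FP8 quantization (not a flashinfer export).
>>> def to_float8(x, dtype=torch.float8_e4m3fn):
...     finfo = torch.finfo(dtype)
...     amax = x.abs().max().float().clamp(min=1e-12)
...     scale = finfo.max / amax  # maps the largest |x| onto the FP8 maximum
...     x_fp8 = (x.float() * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
...     return x_fp8, scale.reciprocal()  # quantized tensor and inverse scale
...
>>> m = 16
>>> n = 2560
>>> k = 32768
>>> a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
>>> a_fp8, a_inv_s = to_float8(a, dtype=torch.float8_e4m3fn)
>>> b = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
>>> b_fp8, b_inv_s = to_float8(b, dtype=torch.float8_e4m3fn)
>>> prepared_b = prepare_low_latency_gemm_weights(b_fp8)
>>> alpha = a_inv_s * b_inv_s
>>> out = mm_fp8(a_fp8, prepared_b, alpha)
>>> out.shape
torch.Size([16, 2560])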
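The example relies on a per-tensor helper to_float8 that is not part of the flashinfer API. To see what such a helper does numerically, the sketch below emulates it in pure Python: scale the tensor so its largest magnitude maps to 448 (the float8_e4m3fn maximum), round each value to the nearest e4m3 value (3 mantissa bits, ties to even, subnormal spacing 2**-9, saturating at ±448), and return the inverse scale. round_to_e4m3 and to_float8_reference are hypothetical names; this is an illustrative reference model, not flashinfer or PyTorch code.

```python
import math

E4M3_MAX = 448.0   # largest finite float8_e4m3fn magnitude
E4M3_MIN_EXP = -9  # subnormal spacing is 2**-9


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest float8_e4m3fn value (ties to even), saturating at +/-448."""
    if x == 0.0:
        return 0.0
    _, exp = math.frexp(abs(x))  # abs(x) = m * 2**exp with 0.5 <= m < 1
    # 3 mantissa bits => spacing 2**(exp - 4) in this binade; subnormals share 2**-9.
    step = 2.0 ** max(exp - 4, E4M3_MIN_EXP)
    q = round(abs(x) / step) * step  # Python's round() ties to even
    return math.copysign(min(q, E4M3_MAX), x)


def to_float8_reference(x: list) -> tuple:
    """Per-tensor quantization of a flat list: returns (quantized values, inverse scale)."""
    amax = max(max(abs(v) for v in x), 1e-12)
    scale = E4M3_MAX / amax
    q = [round_to_e4m3(v * scale) for v in x]
    return q, 1.0 / scale


if __name__ == "__main__":
    x = [0.5, -1.2, 3.0, 0.01]
    q, inv_s = to_float8_reference(x)
    print(q)                       # values on the e4m3 grid, largest magnitude 448
    print([v * inv_s for v in q])  # dequantized approximation of x
```

Clamping before the cast matters: values beyond ±448 have no finite e4m3fn encoding, which is why both this model and the doctest helper saturate instead of overflowing.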