flashinfer.gemm.fp8_blockscale_gemm_sm90
- flashinfer.gemm.fp8_blockscale_gemm_sm90(input: Tensor, weight: Tensor, input_scale: Tensor | None = None, weight_scale: Tensor | None = None, out: Tensor | None = None, out_dtype: dtype | None = None) → Tensor
Perform FP8 block-scaled GEMM with automatic swapAB optimization. This function automatically selects between the normal and swapAB kernels based on the M dimension: for small M (M < 32) it uses the swapAB kernel for better performance, and otherwise the normal kernel.
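The selection rule above can be pictured with a small pure-Python sketch. It assumes, as the name suggests, that swapAB exchanges the roles of the two operands so that the kernel effectively computes the transposed product. The names `select_kernel`, `matmul_nt`, and `SWAP_AB_THRESHOLD` are hypothetical and are not part of FlashInfer.

```python
SWAP_AB_THRESHOLD = 32  # from the description above: swapAB is used when M < 32


def select_kernel(m: int) -> str:
    """Return the kernel variant the documented rule would pick for M rows."""
    return "swapAB" if m < SWAP_AB_THRESHOLD else "normal"


def matmul_nt(a, b):
    """Compute a @ b^T for lists of rows: a is (P, K), b is (Q, K), result is (P, Q)."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


def transpose(m):
    """Transpose a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]


# Toy operands: activations (M=3, K=2) and weights (N=4, K=2).
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]

normal = matmul_nt(x, w)              # (M, N), computed directly
swapped = transpose(matmul_nt(w, x))  # operands exchanged: (N, M), then transposed back
```

Both paths give the same (M, N) result here, so under this assumption the choice between them is a performance decision rather than a numerical one.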
Supported Dtype Combinations
BF16 + BF16 → BF16: Both inputs BF16, internal quantization (no scales needed)
BF16 + FP8 → BF16: BF16 input, FP8 weight
FP8 + FP8 → BF16 (W8A8): Both inputs FP8 with scales required
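The combinations listed above can be summarized as a lookup table. The sketch below is illustrative only: `required_scales` and the string dtype names are hypothetical stand-ins, and the scale requirements are read off this page rather than taken from the library's own validation code.

```python
BF16 = "bfloat16"
FP8 = "float8_e4m3fn"

# (input dtype, weight dtype) -> (input_scale required, weight_scale required)
SUPPORTED = {
    (BF16, BF16): (False, False),  # both operands quantized internally
    (BF16, FP8): (False, True),    # only the FP8 weight needs a scale
    (FP8, FP8): (True, True),      # W8A8: both scales required
}


def required_scales(input_dtype: str, weight_dtype: str):
    """Return (needs_input_scale, needs_weight_scale) for a listed combination."""
    try:
        return SUPPORTED[(input_dtype, weight_dtype)]
    except KeyError:
        raise ValueError(
            f"combination not listed: input={input_dtype}, weight={weight_dtype}"
        ) from None


needs_input_scale, needs_weight_scale = required_scales(BF16, FP8)
```

A combination that is not listed, such as an FP8 input with a BF16 weight, raises `ValueError` in this sketch.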
- param input:
Input activation tensor of shape (M, K): BF16 (torch.bfloat16), quantized internally.
- type input:
torch.Tensor
- param weight:
Weight tensor of shape (N, K). Can be FP8 (torch.float8_e4m3fn), with weight_scale required, or BF16 (torch.bfloat16), quantized internally.
- type weight:
torch.Tensor
- param input_scale:
Scaling factors for input, in per-token format (M, K//128). Required if input is FP8.
- type input_scale:
torch.Tensor, optional
- param weight_scale:
Scaling factors for weight. Required if weight is FP8.
- type weight_scale:
torch.Tensor, optional
- param out:
Output tensor of shape (M, N). If None, will be allocated.
- type out:
torch.Tensor, optional
- param out_dtype:
Output data type. Default is torch.bfloat16.
- type out_dtype:
torch.dtype, optional
- returns:
Output tensor of shape (M, N) with dtype out_dtype.
- rtype:
torch.Tensor
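To make the scaling factors in the parameter list more concrete, here is a minimal pure-Python sketch of how one scale per 1x128 block of a row could be derived. It assumes the common convention scale = amax / 448, where 448 is the largest finite value of torch.float8_e4m3fn. The exact convention and rounding used by FlashInfer's helpers is not stated on this page, and `per_token_scales` and `scale_row` are hypothetical names.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn
BLOCK = 128           # block width along K, per the 1x128 granularity


def per_token_scales(row):
    """Return one scale per 1x128 block of a single row (assumed: amax / 448)."""
    if len(row) % BLOCK != 0:
        raise ValueError("row length must be a multiple of 128")
    scales = []
    for start in range(0, len(row), BLOCK):
        amax = max(abs(v) for v in row[start:start + BLOCK])
        # Guard against all-zero blocks so the division below stays defined.
        scales.append(amax / FP8_E4M3_MAX if amax > 0 else 1.0)
    return scales


def scale_row(row, scales):
    """Divide each element by its block's scale (FP8 rounding is not simulated)."""
    return [v / scales[i // BLOCK] for i, v in enumerate(row)]


row = [float(i % 7 - 3) for i in range(256)]  # K = 256, so two blocks
scales = per_token_scales(row)                # one scale per block
scaled = scale_row(row, scales)               # values now span the FP8 range
```

Multiplying a scaled value by its block's scale recovers the original value, which is the role the scale tensors play when the GEMM result is rescaled.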
Examples
>>> import torch
>>> from flashinfer.gemm import fp8_blockscale_gemm_sm90
>>>
>>> M, N, K = 16, 4096, 4096
>>> device = "cuda"
>>>
>>> # BF16 inputs
>>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
>>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
>>> output = fp8_blockscale_gemm_sm90(input_bf16, weight_bf16)
>>> print(output.shape)  # torch.Size([16, 4096])
>>>
>>> # Mixed: BF16 input + FP8 weight
>>> from flashinfer.testing.utils import per_token_cast_to_fp8
>>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
>>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
>>> weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16)
>>> output = fp8_blockscale_gemm_sm90(input_bf16, weight_fp8, None, weight_scale)
>>> print(output.shape)  # torch.Size([16, 4096])
>>>
>>> # FP8 weight with 128x128 block scales
>>> from flashinfer.testing.utils import per_block_cast_to_fp8
>>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
>>> weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16)
>>> # weight_scale has shape (N // 128, K // 128)
>>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
>>> output = fp8_blockscale_gemm_sm90(input_bf16, weight_fp8, None, weight_scale)
>>> print(output.shape)  # torch.Size([16, 4096])
Notes
This function requires NVIDIA Hopper (SM90) architecture and CUDA 12.8+
SwapAB kernel is automatically used when M < 32 (threshold)
For FP8 inputs, scaling factors must be provided
For BF16 inputs, quantization and scaling happen internally
Weight scales support two granularities:
* Per-token (1x128 blocks): (N, K//128)
* Per-block (128x128 blocks): (N//128, K//128)
Input scales only support per-token format: (M, K//128)
The function uses DeepGEMM backend with JIT compilation
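The scale shapes given in the notes can be checked before calling the function. The following is a small sketch using plain tuples; the helper names are hypothetical, and it assumes N and K are multiples of 128, as the floor divisions in the notes imply.

```python
BLOCK = 128  # block size used by both scale granularities


def weight_scale_shape(n: int, k: int, granularity: str):
    """Expected weight_scale shape for an (N, K) weight."""
    if granularity == "per_token":   # 1x128 blocks
        return (n, k // BLOCK)
    if granularity == "per_block":   # 128x128 blocks
        return (n // BLOCK, k // BLOCK)
    raise ValueError(f"unknown granularity: {granularity}")


def input_scale_shape(m: int, k: int):
    """Expected input_scale shape for an (M, K) input (per-token only)."""
    return (m, k // BLOCK)


def infer_weight_granularity(n: int, k: int, scale_shape):
    """Guess the granularity from a weight_scale shape, or raise if it matches neither."""
    for granularity in ("per_token", "per_block"):
        if tuple(scale_shape) == weight_scale_shape(n, k, granularity):
            return granularity
    raise ValueError(f"unexpected weight_scale shape {tuple(scale_shape)}")


M, N, K = 16, 4096, 4096
per_token = weight_scale_shape(N, K, "per_token")
per_block = weight_scale_shape(N, K, "per_block")
activation = input_scale_shape(M, K)
```

With the sizes used in the examples (M, N, K = 16, 4096, 4096), the per-token weight scale has 4096 x 32 entries, while the per-block scale has only 32 x 32.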