flashinfer.gemm.fp8_blockscale_gemm_sm90

flashinfer.gemm.fp8_blockscale_gemm_sm90(input: Tensor, weight: Tensor, input_scale: Tensor | None = None, weight_scale: Tensor | None = None, out: Tensor | None = None, out_dtype: dtype | None = None) → Tensor

Perform FP8 block-scaled GEMM with automatic swapAB optimization. The function selects between the normal kernel and the swapAB kernel based on the M dimension: for small M (M < 32) it uses the swapAB kernel, which performs better at that size; otherwise it uses the normal kernel.
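
The selection rule can be sketched in plain Python. This is only an illustration of the documented threshold: `select_kernel` and `SWAP_AB_THRESHOLD` are hypothetical names, not part of the FlashInfer API, and the actual dispatch happens inside the library.

```python
# Illustrative only: these names are hypothetical and not part of flashinfer.
SWAP_AB_THRESHOLD = 32  # documented cutoff on the M dimension


def select_kernel(m: int) -> str:
    """Return the kernel variant the documented rule picks for M rows."""
    if m <= 0:
        raise ValueError(f"M must be positive, got {m}")
    # Small M (< 32) goes to the swapAB kernel; everything else is "normal".
    return "swapAB" if m < SWAP_AB_THRESHOLD else "normal"


for m in (1, 16, 31, 32, 4096):
    print(m, select_kernel(m))
```

Note that the comparison is strict, so M = 32 already falls on the normal-kernel side.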

Supported Dtype Combinations

  • BF16 + BF16 → BF16: Both inputs BF16, internal quantization (no scales needed)

  • BF16 + FP8 → BF16: BF16 input, FP8 weight

  • FP8 + FP8 → BF16 (W8A8): Both inputs FP8 with scales required
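
The list above can be read as a lookup from dtype pair to the scale tensors the caller must supply. The sketch below encodes that reading; `REQUIRED_SCALES` and `required_scales` are hypothetical helpers, dtypes are written as plain strings so the code runs without torch, and the library's own argument checks may differ.

```python
# Illustrative only: the table and function below are hypothetical helpers.
# Dtypes are plain strings so the sketch runs without torch or a GPU.
# Value: (input_scale required, weight_scale required)
REQUIRED_SCALES = {
    ("bf16", "bf16"): (False, False),  # internal quantization, no scales
    ("bf16", "fp8"): (False, True),    # FP8 weight needs weight_scale
    ("fp8", "fp8"): (True, True),      # W8A8 needs both scales
}


def required_scales(input_dtype: str, weight_dtype: str) -> tuple:
    """Look up which scale tensors a listed dtype combination needs."""
    key = (input_dtype, weight_dtype)
    if key not in REQUIRED_SCALES:
        raise ValueError(f"combination not listed as supported: {key}")
    return REQUIRED_SCALES[key]


print(required_scales("bf16", "fp8"))
```

Combinations that are not listed, such as an FP8 input with a BF16 weight, raise an error in this sketch rather than guessing at a behavior.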

param input:

Input activation tensor of shape (M, K).

  • BF16 (torch.bfloat16) with internal quantization

type input:

torch.Tensor

param weight:

Weight tensor of shape (N, K). Can be:

  • FP8 (torch.float8_e4m3fn) with weight_scale required

  • BF16 (torch.bfloat16) for internal quantization

type weight:

torch.Tensor

param input_scale:

Scaling factors for input. Required if input is FP8.

type input_scale:

torch.Tensor, optional

param weight_scale:

Scaling factors for weight. Required if weight is FP8.

type weight_scale:

torch.Tensor, optional

param out:

Output tensor of shape (M, N). If None, will be allocated.

type out:

torch.Tensor, optional

param out_dtype:

Output data type. Default is torch.bfloat16.

type out_dtype:

torch.dtype, optional

returns:

Output tensor of shape (M, N) with dtype out_dtype.

rtype:

torch.Tensor

Examples

>>> import torch
>>> from flashinfer.gemm import fp8_blockscale_gemm_sm90
>>>
>>> M, N, K = 16, 4096, 4096
>>> device = "cuda"
>>>
>>> # BF16 inputs
>>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
>>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
>>> output = fp8_blockscale_gemm_sm90(input_bf16, weight_bf16)
>>> print(output.shape)  # torch.Size([16, 4096])
>>>
>>> # Mixed: BF16 input + FP8 weight with per-token (1x128) scales
>>> from flashinfer.testing.utils import per_token_cast_to_fp8
>>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
>>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
>>> weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16)
>>> output = fp8_blockscale_gemm_sm90(input_bf16, weight_fp8, None, weight_scale)
>>> print(output.shape)  # torch.Size([16, 4096])
>>>
>>> # Mixed: BF16 input + FP8 weight with per-block (128x128) scales
>>> from flashinfer.testing.utils import per_block_cast_to_fp8
>>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
>>> weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16)
>>> # weight_scale has shape (N // 128, K // 128)
>>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
>>> output = fp8_blockscale_gemm_sm90(input_bf16, weight_fp8, None, weight_scale)
>>> print(output.shape)  # torch.Size([16, 4096])

Notes

  • This function requires an NVIDIA Hopper (SM90) GPU and CUDA 12.8 or later

  • The swapAB kernel is used automatically when M is below the threshold of 32 (M < 32)

  • For FP8 inputs, scaling factors must be provided

  • For BF16 inputs, quantization and scaling happen internally

  • Weight scales support two granularities: per-token (1x128 blocks) with shape (N, K//128), and per-block (128x128 blocks) with shape (N//128, K//128)

  • Input scales support only the per-token format, with shape (M, K//128)

  • The function uses the DeepGEMM backend with JIT compilation
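
The scale shapes listed in the notes can be computed ahead of time, for example to check a scale tensor before calling the function. The sketch below applies the documented formulas directly; `weight_scale_shape`, `input_scale_shape`, and `BLOCK` are hypothetical names, and the sketch assumes N and K are multiples of 128 (behavior for other sizes is not covered here).

```python
# Illustrative only: hypothetical helpers, not part of flashinfer.
BLOCK = 128  # block edge used by both documented granularities


def weight_scale_shape(n: int, k: int, granularity: str) -> tuple:
    """Expected weight_scale shape for an (N, K) weight."""
    if granularity == "per_token":  # 1x128 blocks
        return (n, k // BLOCK)
    if granularity == "per_block":  # 128x128 blocks
        return (n // BLOCK, k // BLOCK)
    raise ValueError(f"unknown granularity: {granularity!r}")


def input_scale_shape(m: int, k: int) -> tuple:
    """Expected input_scale shape for an (M, K) input (per-token only)."""
    return (m, k // BLOCK)


M, N, K = 16, 4096, 4096
print(weight_scale_shape(N, K, "per_token"))  # (4096, 32)
print(weight_scale_shape(N, K, "per_block"))  # (32, 32)
print(input_scale_shape(M, K))                # (16, 32)
```

With the sizes used in the Examples (M, N, K = 16, 4096, 4096), the per-block weight scale is 128 times smaller along N than the per-token one.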