flashinfer.gemm.mm_bf16_fp4

flashinfer.gemm.mm_bf16_fp4(a: Tensor, b: Tensor, b_descale: Tensor, alpha: Tensor | None = None, *, backend: Literal['cudnn', 'cute-dsl'], out_dtype: dtype | None = None, out: Tensor | None = None, block_size: int = 16, enable_pdl: bool = True) → Tensor

BF16 x FP4 GEMM: out = (a @ dequant(b).T) * alpha.
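The formula can be made concrete with a slow pure-Python reference. This is a sketch, not FlashInfer's implementation. It assumes b is given unpacked as one 4-bit E2M1 code per element (sign in bit 3), that b_descale holds one float per block of 16 consecutive elements of each row of b, and it ignores bf16 rounding of a and of the output. All names below are hypothetical.

```python
from typing import List, Optional

Matrix = List[List[float]]

# Magnitudes of the 8 non-negative FP4 E2M1 values, indexed by the low 3 bits.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code; bit 3 is the sign bit."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[code & 0x7]


def dequant(b_codes: List[List[int]], b_scales: Matrix, block_size: int = 16) -> Matrix:
    """(N, K) codes and (N, K // block_size) scales -> (N, K) real-valued weights."""
    weights = []
    for row, scales in zip(b_codes, b_scales):
        if len(row) != len(scales) * block_size:
            raise ValueError("K must equal len(scales) * block_size")
        # Every element in block j of this row is multiplied by scales[j].
        weights.append([decode_e2m1(c) * scales[k // block_size] for k, c in enumerate(row)])
    return weights


def mm_bf16_fp4_reference(
    a: Matrix,
    b_codes: List[List[int]],
    b_scales: Matrix,
    alpha: Optional[float] = None,
    block_size: int = 16,
) -> Matrix:
    """out = (a @ dequant(b).T) * alpha, treating alpha=None as 1."""
    w = dequant(b_codes, b_scales, block_size)  # (N, K)
    g = 1.0 if alpha is None else alpha
    # out[m][n] = g * sum_k a[m][k] * w[n][k], i.e. a @ w.T
    return [[g * sum(x * y for x, y in zip(a_row, w_row)) for w_row in w] for a_row in a]
```

Such a reference is only useful for checking small cases by hand; the real kernels consume the backend-specific packed layouts produced by prepare_bf16_fp4_weights().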

Intended for W4A16 workloads (4-bit weights, 16-bit activations). The NVFP4 weights b, their scale factors b_descale, and the global scale alpha must first be prepared for the chosen backend with prepare_bf16_fp4_weights().

Example

import flashinfer

# a: (M, K) bfloat16 activations; b, b_descale, alpha: NVFP4 weights and scales.
# 1) Prepare weights for a backend (once, at model load).
b_p, sf_p, alpha_p = flashinfer.prepare_bf16_fp4_weights(
    b, b_descale, alpha, backend="cute-dsl",
)
# 2) Run the GEMM with the *same* backend tag.
out = flashinfer.mm_bf16_fp4(
    a, b_p, sf_p, alpha_p, backend="cute-dsl",
)
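Because the prepared tensors and the backend tag must always match, it can help to bundle them together. The class below is a hypothetical helper, not part of FlashInfer: it only calls the two functions used in the example above, and it accepts substitute callables (an assumption made purely for illustration) so the prepare-once, same-backend contract can be exercised without a GPU.

```python
from typing import Any, Callable, Optional


class PreparedFP4Weight:
    """Hypothetical helper: prepares NVFP4 weights once and pins the backend tag."""

    def __init__(
        self,
        b: Any,
        b_descale: Any,
        alpha: Any,
        backend: str,
        prepare_fn: Optional[Callable[..., Any]] = None,
        mm_fn: Optional[Callable[..., Any]] = None,
    ) -> None:
        if prepare_fn is None or mm_fn is None:
            import flashinfer  # only needed when the real kernels are used

            prepare_fn = prepare_fn or flashinfer.prepare_bf16_fp4_weights
            mm_fn = mm_fn or flashinfer.mm_bf16_fp4
        self.backend = backend
        self._mm = mm_fn
        # Prepare once (e.g. at model load); the returned alpha may be None.
        self.b, self.b_descale, self.alpha = prepare_fn(b, b_descale, alpha, backend=backend)

    def __call__(self, a: Any, **kwargs: Any) -> Any:
        # Always reuse the backend tag the weights were prepared for;
        # extra keyword arguments (out_dtype, out, ...) are forwarded unchanged.
        return self._mm(a, self.b, self.b_descale, self.alpha, backend=self.backend, **kwargs)
```

With FlashInfer installed on a supported GPU, usage would look like `layer = PreparedFP4Weight(b, b_descale, alpha, backend="cute-dsl")` at load time and `out = layer(a)` per forward pass.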

Parameters:
  • a – (M, K) activation matrix in torch.bfloat16. This is currently the only supported activation dtype; fp16 support may be added if needed.

  • b – Prepared weight tensor (backend-specific layout).

  • b_descale – Prepared scale-factor tensor (backend-specific layout).

  • alpha – Optional (1,) float32 global scale. Pass through whatever prepare_bf16_fp4_weights returned; it may be None if the backend folded it into b_descale.

  • backend – Same identifier passed to prepare_bf16_fp4_weights.

  • out_dtype – Output dtype. Defaults to a.dtype (bfloat16).

  • out – Optional preallocated (M, N) output tensor.

  • block_size – Scale-factor (SF) block size, i.e. the number of consecutive FP4 elements that share one scale factor. Always 16 for FP4.

  • enable_pdl – Whether to enable Programmatic Dependent Launch (PDL). Defaults to True.
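Why alpha can legitimately come back as None is easiest to see from the math. Assuming folding simply multiplies every block scale by the global scalar (the exact scheme each backend uses is not specified here), write q for the decoded FP4 values of b and s for its per-block scales with block size 16:

```latex
\mathrm{out}_{mn}
  = \alpha \sum_{k} a_{mk}\, q_{nk}\, s_{n,\lfloor k/16 \rfloor}
  = \sum_{k} a_{mk}\, q_{nk}\, \bigl(\alpha\, s_{n,\lfloor k/16 \rfloor}\bigr)
```

So scales s' = αs reproduce the same output with no separate alpha. This ignores any rounding introduced when the folded scales are stored in the backend's scale format.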

Returns:

(M, N) tensor of out_dtype.