flashinfer.gemm.mm_bf16_fp4
- flashinfer.gemm.mm_bf16_fp4(a: Tensor, b: Tensor, b_descale: Tensor, alpha: Tensor | None = None, *, backend: Literal['cudnn', 'cute-dsl'], out_dtype: dtype | None = None, out: Tensor | None = None, block_size: int = 16, enable_pdl: bool = True) → Tensor
BF16 x FP4 GEMM:
out = (a @ dequant(b).T) * alpha.
Intended to support W4A16 workloads (4-bit weights, 16-bit activations). The NVFP4 weight inputs (b, b_descale, and alpha) must first be prepared for backend by prepare_bf16_fp4_weights().
Example
    import flashinfer

    # a: (M, K) bf16 activations; b, b_descale, alpha: NVFP4 weight inputs.

    # 1) Prepare weights for a backend (once, at model load).
    b_p, sf_p, alpha_p = flashinfer.prepare_bf16_fp4_weights(
        b, b_descale, alpha,
        backend="cute-dsl",
    )

    # 2) Run the GEMM with the *same* backend tag.
    out = flashinfer.mm_bf16_fp4(
        a, b_p, sf_p, alpha_p,
        backend="cute-dsl",
    )
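To make out = (a @ dequant(b).T) * alpha concrete, the following minimal pure-Python reference sketch spells out the arithmetic. It is an illustration under stated assumptions, not FlashInfer's implementation or data layout: b is taken as an (N, K) list of unpacked 4-bit E2M1 codes (0 to 15, sign in bit 3), and b_descale as an (N, K // block_size) list of already-decoded float block scales (in NVFP4 the block scales are stored as FP8 E4M3, and the prepared tensors use backend-specific layouts). The names decode_e2m1, dequant_nvfp4, and reference_mm_bf16_fp4 are hypothetical. It accumulates in Python floats, so it will not match bf16 kernel output bit for bit.

```python
# Hypothetical reference for out = (a @ dequant(b).T) * alpha.
# Assumed (illustrative) layout, NOT FlashInfer's prepared layout:
#   b         -> (N, K) list of unpacked 4-bit E2M1 codes (0..15)
#   b_descale -> (N, K // block_size) list of decoded float block scales

# Magnitudes of the eight non-negative E2M1 codes; bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(code):
    """Decode one 4-bit E2M1 code (0..15) to a Python float."""
    if not 0 <= code <= 15:
        raise ValueError(f"not a 4-bit code: {code}")
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def dequant_nvfp4(b, b_descale, block_size=16):
    """Return the (N, K) float matrix: each code times its K-block's scale."""
    return [
        [decode_e2m1(code) * scales[k // block_size] for k, code in enumerate(row)]
        for row, scales in zip(b, b_descale)
    ]


def reference_mm_bf16_fp4(a, b, b_descale, alpha=None, block_size=16):
    """(M, K) x (N, K)^T -> (M, N), then multiply by the global alpha if given."""
    w = dequant_nvfp4(b, b_descale, block_size)
    scale = 1.0 if alpha is None else alpha
    return [
        [sum(x * y for x, y in zip(a_row, w_row)) * scale for w_row in w]
        for a_row in a
    ]
```

For example, with a single block (K = 16), a = [[1.0] * 16], b = [[2] * 16, [10] * 16] (the codes for 1.0 and -1.0), b_descale = [[0.5], [2.0]], and alpha = 0.25, reference_mm_bf16_fp4 returns [[2.0, -8.0]].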
- Parameters:
a – (M, K) activation matrix in torch.bfloat16. This is currently the only supported activation dtype; fp16 support can be added when needed.
b – Prepared weight tensor (backend-specific layout).
b_descale – Prepared scale-factor tensor (backend-specific layout).
alpha – Optional (1,) float32 global scalar. Pass through whatever prepare_bf16_fp4_weights returned; it may be None if the backend folded it into b_descale.
backend – Backend identifier ('cudnn' or 'cute-dsl'). Must be the same identifier passed to prepare_bf16_fp4_weights.
out_dtype – Output dtype. Defaults to a.dtype (bfloat16).
out – Optional preallocated (M, N) output tensor.
block_size – Scale-factor (SF) block size. Always 16 for FP4.
enable_pdl – Whether to enable Programmatic Dependent Launch (PDL). Defaults to True.
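The alpha entry notes that a backend may fold the global scalar into b_descale and hand back None. That is legitimate because alpha scales every output element and each output is linear in every block scale, so (a @ dequant(b, s).T) * alpha equals a @ dequant(b, s * alpha).T. The self-contained sketch below checks the identity on a toy problem. The helpers dequant, matmul_nt, and fold_alpha are hypothetical (not FlashInfer functions), b is represented by already-decoded FP4 values rather than packed codes, and block_size is shrunk to 2 for readability (the real value is 16).

```python
# Hypothetical illustration (not FlashInfer code) of folding a global alpha
# into per-block scales. Weights are already-decoded FP4 values (N, K);
# scales are (N, K // block_size) floats.


def dequant(values, scales, block_size):
    """Multiply each decoded FP4 value by the scale of its K-block."""
    return [
        [v * row_scales[k // block_size] for k, v in enumerate(row)]
        for row, row_scales in zip(values, scales)
    ]


def matmul_nt(a, w):
    """(M, K) x (N, K)^T -> (M, N) with plain Python floats."""
    return [[sum(x * y for x, y in zip(a_row, w_row)) for w_row in w] for a_row in a]


def fold_alpha(scales, alpha):
    """Return block scales with alpha multiplied in, plus None for alpha."""
    return [[s * alpha for s in row] for row in scales], None


values = [[1.0, -0.5, 6.0, 2.0], [3.0, 0.0, -1.5, 4.0]]  # N=2, K=4 (E2M1 values)
scales = [[0.5, 2.0], [1.0, 0.25]]  # block_size=2 -> 2 blocks per row
a = [[1.0, 2.0, 3.0, 4.0]]  # M=1
alpha = 0.5
block_size = 2

# Path 1: apply alpha after the GEMM.
unfolded = [[x * alpha for x in row] for row in matmul_nt(a, dequant(values, scales, block_size))]

# Path 2: fold alpha into the block scales; no separate alpha afterwards.
folded_scales, folded_alpha = fold_alpha(scales, alpha)
folded = matmul_nt(a, dequant(values, folded_scales, block_size))

# Both paths give [[26.0, 2.9375]] here.
```

Here alpha = 0.5 is a power of two, so both paths agree exactly; for other alpha values the two orders of multiplication can differ by floating-point rounding.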
- Returns:
(M, N) tensor of out_dtype.
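The parameter constraints and the Returns entry can be summarized as a pre-flight check. The sketch below is hypothetical: check_mm_bf16_fp4_args is not a FlashInfer function, dtypes are passed as strings so it runs without PyTorch or a GPU, n (the number of output features) is passed explicitly because the prepared b has a backend-specific layout, and requiring K to be a multiple of block_size is an assumption about block scaling rather than a documented rule. The real function may validate differently.

```python
from typing import Optional, Tuple

# Taken from the Literal type of `backend` in the signature.
SUPPORTED_BACKENDS = ("cudnn", "cute-dsl")


def check_mm_bf16_fp4_args(
    a_shape: Tuple[int, int],
    n: int,
    backend: str,
    a_dtype: str = "bfloat16",
    out_dtype: Optional[str] = None,
    out_shape: Optional[Tuple[int, int]] = None,
    block_size: int = 16,
) -> Tuple[Tuple[int, int], str]:
    """Hypothetical pre-flight check; returns the result's (shape, dtype)."""
    if backend not in SUPPORTED_BACKENDS:
        raise ValueError(f"backend must be one of {SUPPORTED_BACKENDS}, got {backend!r}")
    # bfloat16 is documented as the only supported activation dtype.
    if a_dtype != "bfloat16":
        raise TypeError(f"a must be bfloat16, got {a_dtype}")
    # Documented: the SF block size is always 16 for FP4.
    if block_size != 16:
        raise ValueError("block_size is always 16 for FP4")
    m, k = a_shape
    # Assumption: K splits evenly into scale-factor blocks.
    if k % block_size != 0:
        raise ValueError(f"K={k} is not a multiple of block_size={block_size}")
    expected = (m, n)
    # A preallocated `out` must have the documented (M, N) shape.
    if out_shape is not None and tuple(out_shape) != expected:
        raise ValueError(f"out must have shape {expected}, got {tuple(out_shape)}")
    # out_dtype defaults to a.dtype, i.e. bfloat16.
    return expected, (a_dtype if out_dtype is None else out_dtype)
```

For instance, check_mm_bf16_fp4_args((8, 64), 32, "cute-dsl") returns ((8, 32), "bfloat16"), while passing backend="cutlass" raises ValueError.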