flashinfer.gemm.prepare_bf16_fp4_weights
- flashinfer.gemm.prepare_bf16_fp4_weights(b: Tensor, b_descale: Tensor, alpha: Tensor | None = None, *, backend: Literal['cudnn', 'cute-dsl'], block_size: int = 16) → Tuple[Tensor, Tensor, Tensor | None]
Prepare FP4 weights for the bf16 x fp4 GEMM, for a specific backend.
The caller is expected to start with weights in the canonical format that
flashinfer.nvfp4_quantize() produces with sfLayout=layout_128x4: b is an (N, K // 2) uint8 tensor with two FP4 codes packed per byte (low nibble = element at K index 2i, high nibble = element at K index 2i+1). b_descale holds the 128x4-swizzled FP8-E4M3 per-block scales, either as a 1-D byte buffer or as a 2-D tensor.
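The nibble order described above can be illustrated with a minimal pure-Python sketch. The helpers pack_fp4_codes and unpack_fp4_codes are hypothetical names introduced here for illustration only. They are not part of flashinfer, and they operate on a single row of raw 4-bit codes rather than on tensors.

```python
def pack_fp4_codes(codes):
    """Pack one row of 4-bit codes (ints in 0..15) into bytes, two codes per byte.

    Hypothetical helper for illustration; not a flashinfer function.
    """
    if len(codes) % 2 != 0:
        raise ValueError("K must be even to pack two FP4 codes per byte")
    if any(not 0 <= c <= 0xF for c in codes):
        raise ValueError("each FP4 code must fit in 4 bits")
    # Element 2i goes to the low nibble, element 2i+1 to the high nibble.
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))


def unpack_fp4_codes(packed):
    """Inverse of pack_fp4_codes: recover the K codes from K // 2 bytes."""
    codes = []
    for byte in packed:
        codes.append(byte & 0xF)  # low nibble  -> element 2i
        codes.append(byte >> 4)   # high nibble -> element 2i+1
    return codes


row = [0x1, 0x2, 0xF, 0x0]      # K = 4 FP4 codes
packed = pack_fp4_codes(row)    # K // 2 = 2 bytes: 0x21, 0x0F
```

Each row of K codes therefore occupies K // 2 bytes, which is where the (N, K // 2) shape of b comes from.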
Each backend transforms these into whatever layout its compute kernel expects. The returned
(b, b_descale, alpha) tuple must be passed back to flashinfer.mm_bf16_fp4() with the same backend; the shapes / dtypes may not match other backends' expectations.
- Parameters:
  - b – (N, K // 2) uint8 packed FP4 weight.
  - b_descale – 128x4-swizzled FP8-E4M3 scale factors from nvfp4_quantize. Either a 1-D byte buffer or a 2-D tensor.
  - alpha – Optional (1,) float32 global scalar. Pass None (default) for implicit alpha=1.0. Returned unchanged; forward the returned tuple to flashinfer.mm_bf16_fp4().
  - backend – Identifier of a supported backend ("cudnn" or "cute-dsl").
  - block_size – Scale-factor (SF) block size. Always 16 for FP4.
- Returns:
  (b_prepared, b_descale_prepared, alpha_prepared) – pass all three to flashinfer.mm_bf16_fp4() with the same backend.
- Raises:
  - ValueError – backend is unknown, or an input has an invalid shape (b not 2-D, K not a multiple of block_size, or alpha not shape (1,)).
  - TypeError – b is not uint8 or alpha is not float32.
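
The documented error conditions can be mirrored in a small pre-flight check that a caller might run before touching the GPU. This is a sketch under stated assumptions, not the library's implementation: check_fp4_weight_args and SUPPORTED_BACKENDS are hypothetical names, the order of the checks is a guess, and the function takes plain shape tuples and dtype names so that it runs without flashinfer or a tensor library installed.

```python
# Hypothetical constant; the two values are the backends named in the signature.
SUPPORTED_BACKENDS = ("cudnn", "cute-dsl")


def check_fp4_weight_args(b_shape, b_dtype, *, backend,
                          alpha_shape=None, alpha_dtype=None, block_size=16):
    """Mirror the documented ValueError / TypeError conditions.

    Hypothetical helper for illustration; the check order is an assumption.
    Returns (N, K) when every check passes.
    """
    if backend not in SUPPORTED_BACKENDS:
        raise ValueError(f"unknown backend: {backend!r}")
    if b_dtype != "uint8":
        raise TypeError("b must be uint8")
    if len(b_shape) != 2:
        raise ValueError("b must be 2-D with shape (N, K // 2)")
    n, packed_k = b_shape
    k = packed_k * 2  # two FP4 codes per byte
    if k % block_size != 0:
        raise ValueError("K must be a multiple of block_size")
    if alpha_shape is not None:
        if alpha_dtype != "float32":
            raise TypeError("alpha must be float32")
        if tuple(alpha_shape) != (1,):
            raise ValueError("alpha must have shape (1,)")
    return n, k


# A (256, 64) uint8 weight holds N = 256 rows of K = 128 FP4 codes.
n, k = check_fp4_weight_args((256, 64), "uint8", backend="cudnn")
```

With real tensors, a caller could pass something like tuple(b.shape) and the dtype's short name. How the dtype name is obtained depends on the tensor library and is left out of this sketch.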