flashinfer.gemm.prepare_bf16_fp4_weights

flashinfer.gemm.prepare_bf16_fp4_weights(b: Tensor, b_descale: Tensor, alpha: Tensor | None = None, *, backend: Literal['cudnn', 'cute-dsl'], block_size: int = 16) → Tuple[Tensor, Tensor, Tensor | None]

Prepare FP4 weights for the bf16 x fp4 GEMM, for a specific backend.

The caller is expected to start with weights in the canonical format that flashinfer.nvfp4_quantize() produces with sfLayout=layout_128x4:

  • b is an (N, K // 2) uint8 tensor with two FP4 codes packed per byte: in byte i of a row, the low nibble holds element k = 2i and the high nibble holds element k = 2i + 1.

  • b_descale holds the 128x4-swizzled FP8-E4M3 per-block scales, either as a 1-D byte buffer or as a 2-D tensor.
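The nibble order described for b can be illustrated with a minimal pure-Python sketch. The helper names pack_fp4_codes and unpack_fp4_codes are hypothetical and are not part of flashinfer; they operate on plain integer codes (0..15) and only show the byte layout, not the FP4 quantization itself.

```python
def pack_fp4_codes(codes):
    """Pack a flat list of 4-bit codes (0..15) into bytes, two codes per byte.

    Hypothetical helper: element 2i goes to the low nibble of byte i,
    element 2i+1 goes to the high nibble, matching the layout described above.
    """
    if len(codes) % 2 != 0:
        raise ValueError("need an even number of codes")
    packed = []
    for i in range(0, len(codes), 2):
        lo, hi = codes[i], codes[i + 1]
        if not (0 <= lo <= 15 and 0 <= hi <= 15):
            raise ValueError("each code must fit in 4 bits")
        packed.append(lo | (hi << 4))
    return packed


def unpack_fp4_codes(packed):
    """Inverse of pack_fp4_codes: recover the codes in K order."""
    codes = []
    for byte in packed:
        codes.append(byte & 0x0F)         # element 2i
        codes.append((byte >> 4) & 0x0F)  # element 2i + 1
    return codes


# One row with K = 4 elements becomes K // 2 = 2 bytes.
row = [0x1, 0x2, 0x3, 0xF]
packed_row = pack_fp4_codes(row)
```

With this convention, a row of K codes always occupies K // 2 bytes, which is why b has shape (N, K // 2).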

Each backend transforms these into whatever layout its compute kernel expects. The returned (b, b_descale, alpha) tuple must be passed back to flashinfer.mm_bf16_fp4() with the same backend – the shapes / dtypes may not match other backends’ expectations.

Parameters:
  • b – (N, K // 2) uint8 packed FP4 weight.

  • b_descale – 128x4-swizzled FP8-E4M3 scale factors from nvfp4_quantize. Either a 1-D byte buffer or a 2-D tensor.

  • alpha – Optional (1,) float32 global scalar. Pass None (default) for implicit alpha=1.0. Returned unchanged; forward the returned tuple to flashinfer.mm_bf16_fp4().

  • backend – Identifier of a supported backend ("cudnn" or "cute-dsl").

  • block_size – SF block size. Always 16 for FP4.
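As a rough guide to how block_size relates to the size of b_descale, the sketch below computes the number of scale factors for an (N, K) weight. The helper scale_factor_shape is hypothetical. The logical count (one scale per block_size elements along K) follows from the description above; the padding of rows to a multiple of 128 and of scale columns to a multiple of 4 is an assumption about the 128x4 layout and should be checked against the nvfp4_quantize documentation.

```python
def round_up(x, multiple):
    """Round x up to the nearest multiple of `multiple`."""
    return (x + multiple - 1) // multiple * multiple


def scale_factor_shape(n, k, block_size=16):
    """Return (logical_shape, padded_shape) of the per-block scales.

    Hypothetical helper. logical_shape has one scale per block of
    `block_size` elements along K. padded_shape ASSUMES the 128x4 layout
    pads rows to a multiple of 128 and scale columns to a multiple of 4.
    """
    if k % block_size != 0:
        raise ValueError("K must be a multiple of block_size")
    cols = k // block_size
    logical = (n, cols)
    padded = (round_up(n, 128), round_up(cols, 4))
    return logical, padded
```

For example, under these assumptions a weight with N = 100 and K = 32 has 2 scales per row, and the padded buffer would cover 128 rows by 4 columns.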

Returns:

(b_prepared, b_descale_prepared, alpha_prepared) – pass all three to flashinfer.mm_bf16_fp4() with the same backend.

Raises:
  • ValueError – backend is unknown, or an input has an invalid shape (b not 2-D, K not a multiple of block_size, or alpha not shape (1,)).

  • TypeError – b is not uint8 or alpha is not float32.
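The documented error conditions can be mirrored in a small pre-flight check, which may be useful for validating inputs before touching the GPU. This is a minimal sketch, not the library's implementation: the function check_prepare_inputs is hypothetical, it takes plain shape tuples and dtype names rather than tensors, and the order in which the checks run is an assumption.

```python
SUPPORTED_BACKENDS = ("cudnn", "cute-dsl")


def check_prepare_inputs(b_shape, b_dtype, backend,
                         alpha_shape=None, alpha_dtype=None, block_size=16):
    """Mirror the documented ValueError / TypeError conditions.

    Hypothetical helper: shapes are tuples and dtypes are strings such as
    "uint8" or "float32". Returns (N, K) on success.
    """
    if backend not in SUPPORTED_BACKENDS:
        raise ValueError(f"unknown backend: {backend!r}")
    if b_dtype != "uint8":
        raise TypeError("b must be uint8")
    if len(b_shape) != 2:
        raise ValueError("b must be 2-D")
    n, packed_k = b_shape
    k = 2 * packed_k  # two FP4 codes per byte
    if k % block_size != 0:
        raise ValueError("K must be a multiple of block_size")
    if alpha_shape is not None:
        if alpha_dtype != "float32":
            raise TypeError("alpha must be float32")
        if tuple(alpha_shape) != (1,):
            raise ValueError("alpha must have shape (1,)")
    return n, k
```

For instance, check_prepare_inputs((128, 32), "uint8", "cudnn") returns (128, 64), since 32 packed bytes per row correspond to K = 64.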