flashinfer.gemm.grouped_gemm_nt_masked

flashinfer.gemm.grouped_gemm_nt_masked(lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, masked_m: Tensor, *, ab_dtype: str, sf_dtype: str, c_dtype: str, sf_vec_size: int, topk_weights: Tensor | None = None, idx_src_info: Tensor | None = None, rank_src_info: Tensor | None = None, out_ptrs: Tensor | None = None, num_ranks: int = 0, dst_signals: Tensor | None = None, sm_count: int | None = None, barrier_flag_local: Tensor | None = None, barrier_flag_multicast: Tensor | None = None, is_combine_fusion: bool = False, is_swap_ab: bool = False, **kwargs)

Masked, batched, block-scaled GEMM on Blackwell SM100.

Executes a masked, batched matrix multiplication with block scale factors, with optional per-batch alpha scaling of the output. When supplied, alpha is currently applied inside the kernel. See Notes for the canonical tensor layouts.
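Block scaling means that each run of sf_vec_size consecutive elements along k shares one scale factor. The NumPy sketch below illustrates that idea for a single (rows, k) slice using plain float arrays. It is a hypothetical helper (dequantize_rows is not a FlashInfer function), and it deliberately ignores the FP4/FP8 bit encodings and the swizzled SFA/SFB layout described in Notes.

```python
import numpy as np


def dequantize_rows(x: np.ndarray, sf: np.ndarray, sf_vec_size: int) -> np.ndarray:
    """Hypothetical helper: apply one scale per sf_vec_size-wide block along k.

    x  : (rows, k) values, already decoded to float
    sf : (rows, ceil(k / sf_vec_size)) scale factors, already decoded to float
    """
    rows, k = x.shape
    num_blocks = -(-k // sf_vec_size)  # ceiling division
    assert sf.shape == (rows, num_blocks)
    # Column j of x belongs to scale block j // sf_vec_size.
    block_idx = np.arange(k) // sf_vec_size
    return x * sf[:, block_idx]


# Example: k = 32 with sf_vec_size = 16 gives two scale blocks per row.
x = np.ones((2, 32), dtype=np.float32)
sf = np.array([[1.0, 2.0], [0.5, 4.0]], dtype=np.float32)
x_deq = dequantize_rows(x, sf, sf_vec_size=16)
```

In the example, the first 16 columns of row 0 are scaled by 1.0 and the last 16 by 2.0.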

Parameters:
  • lhs (Tuple[torch.Tensor, torch.Tensor]) – (A, SFA): the left-hand-side input tensor and its scale-factor tensor. A has logical shape (m, k, l) (physically (l, m, k)); for FP4 with 8-bit storage, two values are packed per byte, so the k extent is halved: logical shape (m, k/2, l) (physically (l, m, k/2)). SFA has logical shape (m32, m4, rm, k4, rk, l) (physically (l, rm, rk, m32, m4, k4)).

  • rhs (Tuple[torch.Tensor, torch.Tensor]) – (B, SFB): the right-hand-side input tensor and its scale-factor tensor. B has logical shape (n, k, l) (physically (l, n, k)); for FP4 with 8-bit storage, the k extent is halved: logical shape (n, k/2, l) (physically (l, n, k/2)). SFB has logical shape (n32, n4, rn, k4, rk, l) (physically (l, rn, rk, n32, n4, k4)).

  • out (torch.Tensor) – Output tensor of shape (l, m, n). Mutated in place.

  • masked_m (torch.Tensor) – 1-D int32 tensor of shape (l,) giving the number of valid rows in each batch. For batch i, rows with index >= masked_m[i] are ignored.

  • ab_dtype (str) – Data type for A and B. One of "float4_e2m1fn", "float8_e4m3fn", "float8_e5m2".

  • sf_dtype (str) – Data type for the scale factors. One of "float8_e8m0fnu" or "float8_e4m3fn".

  • c_dtype (str) – Data type for output matrix C. One of "float16", "bfloat16", "float32", "float8_e4m3fn", "float8_e5m2".

  • sf_vec_size (int) – Vector size for scale factors (typically 16 or 32).

  • topk_weights (Optional[torch.Tensor]) – 2-D float32 tensor of shape (l, m) containing top-k routing weights. Defaults to None.

  • idx_src_info (Optional[torch.Tensor]) – 2-D int32 tensor of shape (l, m) carrying source-index metadata for the combine fusion path. Defaults to None.

  • rank_src_info (Optional[torch.Tensor]) – 2-D int32 tensor of shape (l, m) carrying rank-source metadata for the combine fusion path. Defaults to None.

  • out_ptrs (Optional[torch.Tensor]) – 1-D int64 tensor of shape (num_ranks,) containing remote output pointers for multi-rank combine. Defaults to None.

  • num_ranks (int) – Number of ranks participating in the combine path. Defaults to 0.

  • dst_signals (Optional[torch.Tensor]) – 1-D signal tensor used by the combine-fusion path to notify consumers. Defaults to None.

  • sm_count (Optional[int]) – Number of SMs to use. If None (the default), the maximum number of SMs available under the CTA configuration is used.

  • barrier_flag_local (Optional[torch.Tensor]) – 1-D int32 tensor of shape (sm_count,) containing flags for local barrier synchronization (spin-lock wait in multi-rank ops). Defaults to None.

  • barrier_flag_multicast (Optional[torch.Tensor]) – 1-D int32 tensor of shape (sm_count,) containing flags for multicast barrier synchronization (release across ranks). Defaults to None.

  • is_combine_fusion (bool) – If True, enable the fused GEMM + combine operation mode. Defaults to False.

  • is_swap_ab (bool) – If True, swap the lhs/rhs input tensors. Defaults to False.

  • **kwargs

    Additional keyword arguments. Currently recognized:

    • mma_tiler_mn (Tuple[int, int]): shape of the MMA tiler (M, N). Defaults to (128, 128). mma_tiler_mn[0] == 256 enables the 2-CTA MMA path. Must be (128, 128) when is_combine_fusion=True.

    • cluster_shape_mn (Tuple[int, int]): shape of the CTA cluster (ClusterM, ClusterN). Defaults to (1, 1).

    • alpha (Optional[torch.Tensor]): optional 1-D tensor of shape (l,) containing per-batch scaling factors. When provided, each batch's output is multiplied by its alpha value: out[i] = alpha[i] * (A[i] @ B[i]^T), where B[i] is stored as (n, k).

    • alpha_dtype (str): elemental dtype string for the alpha tensor (e.g. "float32"). Required when alpha is provided.

    Other entries are reserved for forward-compatible kernel options.
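The constraints stated above can be checked on the host before a launch. The sketch below is a hypothetical helper (check_masked_gemm_options is not a FlashInfer function) that encodes only what this page documents: the accepted dtype strings, the 2-CTA trigger, the combine-fusion tiler restriction and the alpha_dtype requirement. The kernel may enforce further constraints that are not listed here.

```python
from typing import Optional, Tuple

# Accepted dtype strings, copied from the parameter descriptions.
AB_DTYPES = {"float4_e2m1fn", "float8_e4m3fn", "float8_e5m2"}
SF_DTYPES = {"float8_e8m0fnu", "float8_e4m3fn"}
C_DTYPES = {"float16", "bfloat16", "float32", "float8_e4m3fn", "float8_e5m2"}


def check_masked_gemm_options(
    ab_dtype: str,
    sf_dtype: str,
    c_dtype: str,
    is_combine_fusion: bool = False,
    mma_tiler_mn: Tuple[int, int] = (128, 128),
    has_alpha: bool = False,
    alpha_dtype: Optional[str] = None,
) -> bool:
    """Hypothetical host-side check (not a FlashInfer API).

    Raises ValueError for a documented violation and returns True when
    the documented 2-CTA MMA path (mma_tiler_mn[0] == 256) would be used.
    """
    for name, value, allowed in (
        ("ab_dtype", ab_dtype, AB_DTYPES),
        ("sf_dtype", sf_dtype, SF_DTYPES),
        ("c_dtype", c_dtype, C_DTYPES),
    ):
        if value not in allowed:
            raise ValueError(f"unsupported {name}: {value!r}")
    if is_combine_fusion and tuple(mma_tiler_mn) != (128, 128):
        raise ValueError("is_combine_fusion=True requires mma_tiler_mn == (128, 128)")
    if has_alpha and alpha_dtype is None:
        raise ValueError("alpha_dtype is required when alpha is provided")
    return mma_tiler_mn[0] == 256


# Example: FP4 inputs with FP8 E4M3 scale factors on the 2-CTA path.
uses_2cta = check_masked_gemm_options(
    "float4_e2m1fn", "float8_e4m3fn", "bfloat16", mma_tiler_mn=(256, 128)
)
```

Raising early on the host gives a clearer error than a failed kernel launch; the returned flag only reflects the documented 2-CTA rule.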

Notes

Tensor-layout conventions:

  • l is the batch size, m/n are the row/column counts, and k is the contraction dimension.

  • m32/n32, m4/n4 and k4 are the constants 32, 4 and 4, respectively.

  • m32 * m4 * rm equals M, where M is m padded up to the nearest multiple of 128.

  • n32 * n4 * rn equals N, where N is n padded up to the nearest multiple of 128.

  • k4 * rk equals K, where K is k / sf_vec_size padded up to the nearest multiple of 4.
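The padding rules above fix the physical scale-factor shapes. The pure-Python sketch below (sf_physical_shape is a hypothetical helper, not a FlashInfer function) derives the physical shape (l, rm, rk, 32, 4, 4) of SFA from m, k, l and sf_vec_size; passing n instead of m gives the shape of SFB. It assumes k is a multiple of sf_vec_size.

```python
from typing import Tuple


def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple of `multiple`."""
    return (x + multiple - 1) // multiple * multiple


def sf_physical_shape(rows: int, k: int, l: int, sf_vec_size: int) -> Tuple[int, ...]:
    """Hypothetical helper: physical shape (l, r_rows, rk, 32, 4, 4) of SFA/SFB."""
    assert k % sf_vec_size == 0, "sketch assumes k is a multiple of sf_vec_size"
    padded_rows = round_up(rows, 128)             # M (or N) = 32 * 4 * rm
    padded_k_sf = round_up(k // sf_vec_size, 4)   # K = 4 * rk
    rm = padded_rows // (32 * 4)
    rk = padded_k_sf // 4
    return (l, rm, rk, 32, 4, 4)


# Example: m = 200, k = 512, l = 3 batches, sf_vec_size = 16.
shape = sf_physical_shape(200, 512, 3, 16)
```

For this example, M pads to 256 (rm = 2) and k / 16 = 32 is already a multiple of 4 (rk = 8), so the tensor holds l * M * K = 3 * 256 * 32 scale factors.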

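Masking is applied per batch via masked_m: for batch i, only rows r < masked_m[i] are computed. For those rows, out[i, r, :] = alpha[i] * (A_i[r, :] @ B_i^T), where A_i and B_i are the dequantized (scale-factor-applied) inputs of batch i, the product contracts over k because B is stored as (n, k), and alpha[i] is taken as 1 when alpha is not provided (see **kwargs).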
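As a reading aid for the masking and alpha semantics, the NumPy sketch below computes the same result in floating point for inputs that have already been dequantized. It is hypothetical reference code (masked_gemm_nt_reference is not part of FlashInfer): it uses batch-major (l, m, k) and (l, n, k) arrays, ignores the FP4/FP8 encodings and scale-factor layouts, and simply leaves rows at or beyond masked_m[i] untouched; it makes no claim about what the kernel writes there.

```python
from typing import Optional

import numpy as np


def masked_gemm_nt_reference(
    a: np.ndarray,
    b: np.ndarray,
    masked_m: np.ndarray,
    alpha: Optional[np.ndarray] = None,
    out: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Hypothetical float reference for masked NT GEMM semantics.

    a        : (l, m, k) dequantized A
    b        : (l, n, k) dequantized B ("NT": B is stored n x k)
    masked_m : (l,) number of valid rows per batch
    alpha    : optional (l,) per-batch output scale
    out      : optional (l, m, n) buffer; masked-out rows are not written
    """
    l, m, _ = a.shape
    n = b.shape[1]
    if out is None:
        out = np.zeros((l, m, n), dtype=np.float32)
    for i in range(l):
        rows = int(masked_m[i])
        acc = a[i, :rows] @ b[i].T  # (rows, n), contracting over k
        if alpha is not None:
            acc = alpha[i] * acc
        out[i, :rows] = acc
    return out


# Example: batch 0 has 2 valid rows, batch 1 has all 4.
a = np.ones((2, 4, 8), dtype=np.float32)
b = np.ones((2, 3, 8), dtype=np.float32)
masked_m = np.array([2, 4], dtype=np.int32)
alpha = np.array([0.5, 2.0], dtype=np.float32)
out = masked_gemm_nt_reference(a, b, masked_m, alpha)
```

With all-ones inputs and k = 8, every valid entry is alpha[i] * 8, so batch 0 rows 0 and 1 hold 4.0 and batch 1 holds 16.0 throughout.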