flashinfer.gemm.grouped_gemm_nt_masked
- flashinfer.gemm.grouped_gemm_nt_masked(lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, masked_m: Tensor, *, ab_dtype: str, sf_dtype: str, c_dtype: str, sf_vec_size: int, topk_weights: Tensor | None = None, idx_src_info: Tensor | None = None, rank_src_info: Tensor | None = None, out_ptrs: Tensor | None = None, num_ranks: int = 0, dst_signals: Tensor | None = None, sm_count: int | None = None, barrier_flag_local: Tensor | None = None, barrier_flag_multicast: Tensor | None = None, is_combine_fusion: bool = False, is_swap_ab: bool = False, **kwargs)
Masked, batched, block-scaled GEMM on Blackwell SM100.
Executes a masked, batched matrix multiplication with scale factors and optional per-batch alpha scaling on the output.
alpha is currently applied internally by the kernel; see Notes for the canonical tensor layouts.
- Parameters:
lhs (Tuple[torch.Tensor, torch.Tensor]) – (A, SFA): left-hand-side input tensor and its scale-factor tensor. A has logical shape (m, k, l) (physically (l, m, k)); for FP4 with 8-bit storage the physical shape is (m, k/2, l). SFA has logical shape (m32, m4, rm, k4, rk, l) (physically (l, rm, rk, m32, m4, k4)).
rhs (Tuple[torch.Tensor, torch.Tensor]) – (B, SFB): right-hand-side input tensor and its scale-factor tensor. B has logical shape (n, k, l) (physically (l, n, k); FP4 with 8-bit storage is (n, k/2, l)). SFB has logical shape (n32, n4, rn, k4, rk, l) (physically (l, rn, rk, n32, n4, k4)).
out (torch.Tensor) – Output tensor of shape (l, m, n). Mutated in place.
masked_m (torch.Tensor) – 1-D int32 tensor of shape (l,) giving the valid row count of each batch. Rows at index masked_m[batch] and above are ignored.
ab_dtype (str) – Data type for A and B. One of "float4_e2m1fn", "float8_e4m3fn", "float8_e5m2".
sf_dtype (str) – Data type for the scale factors. One of "float8_e8m0fnu" or "float8_e4m3fn".
c_dtype (str) – Data type for output matrix C. One of "float16", "bfloat16", "float32", "float8_e4m3fn", "float8_e5m2".
sf_vec_size (int) – Vector size for scale factors (typically 16 or 32).
topk_weights (Optional[torch.Tensor]) – 2-D float32 tensor of shape (l, m) containing top-k routing weights. Defaults to None.
idx_src_info (Optional[torch.Tensor]) – 2-D int32 tensor of shape (l, m) carrying source-index metadata for the combine fusion path. Defaults to None.
rank_src_info (Optional[torch.Tensor]) – 2-D int32 tensor of shape (l, m) carrying rank-source metadata for the combine fusion path. Defaults to None.
out_ptrs (Optional[torch.Tensor]) – 1-D int64 tensor of shape (num_ranks,) containing remote output pointers for multi-rank combine. Defaults to None.
num_ranks (int) – Number of ranks participating in the combine path. Defaults to 0.
dst_signals (Optional[torch.Tensor]) – Optional 1-D signal tensor used by the combine-fusion path to notify consumers. Defaults to None.
sm_count (Optional[int]) – Number of SMs to use. If None, the runtime picks the max available under the CTA configuration.
barrier_flag_local (Optional[torch.Tensor]) – 1-D int32 tensor of shape (sm_count,) containing flags for local barrier synchronization (spin-lock wait in multi-rank ops). Defaults to None.
barrier_flag_multicast (Optional[torch.Tensor]) – 1-D int32 tensor of shape (sm_count,) containing flags for multicast barrier synchronization (release across ranks). Defaults to None.
is_combine_fusion (bool) – If True, enable the fused GEMM + combine operation mode. Defaults to False.
is_swap_ab (bool) – If True, swap the lhs/rhs input tensors. Defaults to False.
**kwargs –
Additional keyword arguments. Currently recognized:
mma_tiler_mn (Tuple[int, int]): shape of the MMA tiler (M, N). Defaults to (128, 128). mma_tiler_mn[0] == 256 enables the 2-CTA MMA path. Must be (128, 128) when is_combine_fusion=True.
cluster_shape_mn (Tuple[int, int]): shape of the CTA cluster (ClusterM, ClusterN). Defaults to (1, 1).
alpha (Optional[torch.Tensor]): optional 1-D tensor of shape (l,) containing per-batch scaling factors. When provided, each batch output is multiplied by its corresponding alpha value: out = alpha * (A @ B).
alpha_dtype (str): elemental dtype string for the alpha tensor (e.g. "float32"). Required when alpha is provided.
Other entries are reserved for forward-compatible kernel options.
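The documented dtype strings and keyword constraints can be checked before launching the kernel. The following is a minimal sketch in plain Python; check_documented_constraints and its return dictionary are hypothetical helpers written for illustration and are not part of flashinfer. It encodes only the rules stated in the parameter list above.

```python
from typing import Any, Dict

# Allowed dtype strings, copied from the parameter descriptions above.
AB_DTYPES = {"float4_e2m1fn", "float8_e4m3fn", "float8_e5m2"}
SF_DTYPES = {"float8_e8m0fnu", "float8_e4m3fn"}
C_DTYPES = {"float16", "bfloat16", "float32", "float8_e4m3fn", "float8_e5m2"}


def check_documented_constraints(
    ab_dtype: str,
    sf_dtype: str,
    c_dtype: str,
    is_combine_fusion: bool = False,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical pre-flight check; NOT a flashinfer API."""
    if ab_dtype not in AB_DTYPES:
        raise ValueError(f"unsupported ab_dtype: {ab_dtype!r}")
    if sf_dtype not in SF_DTYPES:
        raise ValueError(f"unsupported sf_dtype: {sf_dtype!r}")
    if c_dtype not in C_DTYPES:
        raise ValueError(f"unsupported c_dtype: {c_dtype!r}")

    # Documented defaults for the recognized keyword arguments.
    mma_tiler_mn = tuple(kwargs.get("mma_tiler_mn", (128, 128)))
    cluster_shape_mn = tuple(kwargs.get("cluster_shape_mn", (1, 1)))

    # The combine-fusion path requires the (128, 128) tiler.
    if is_combine_fusion and mma_tiler_mn != (128, 128):
        raise ValueError("mma_tiler_mn must be (128, 128) when is_combine_fusion=True")

    # alpha_dtype is required whenever alpha is supplied.
    if kwargs.get("alpha") is not None and kwargs.get("alpha_dtype") is None:
        raise ValueError("alpha_dtype is required when alpha is provided")

    return {
        "mma_tiler_mn": mma_tiler_mn,
        "cluster_shape_mn": cluster_shape_mn,
        # A tiler M of 256 selects the 2-CTA MMA path.
        "uses_2cta_mma": mma_tiler_mn[0] == 256,
    }
```

A caller might run this check first and then forward the same keyword arguments to grouped_gemm_nt_masked unchanged.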
Notes
Tensor-layout conventions:
l is the batch size, m/n are the row/column counts, and k is the contraction dimension.
m32/n32, m4/n4, and k4 are the constants 32, 4, and 4, respectively.
m32 * m4 * rm equals M, where M is m padded up to the nearest multiple of 128.
n32 * n4 * rn equals N, where N is n padded up to the nearest multiple of 128.
k4 * rk equals K, where K is k / sf_vec_size padded up to the nearest multiple of 4.
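The padding rules above determine the scale-factor tensor shapes. As a rough sketch, the arithmetic could be written as follows. The helper names round_up and scale_factor_shapes are hypothetical and not part of flashinfer, and the sketch assumes k / sf_vec_size is rounded up when k is not an exact multiple of sf_vec_size, which the notes do not state.

```python
from typing import Dict, Tuple


def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple of `multiple`."""
    return -(-value // multiple) * multiple


def scale_factor_shapes(
    rows: int, k: int, l: int, sf_vec_size: int
) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: SFA shapes for rows=m, or SFB shapes for rows=n."""
    padded_rows = round_up(rows, 128)       # M (or N): rows padded to 128
    r_rows = padded_rows // (32 * 4)        # rm (or rn), since m32 * m4 = 128
    # Assumption: ceil division when sf_vec_size does not divide k.
    sf_cols = -(-k // sf_vec_size)
    padded_k = round_up(sf_cols, 4)         # K: scale-factor columns padded to 4
    rk = padded_k // 4                      # k4 = 4
    return {
        "logical": (32, 4, r_rows, 4, rk, l),    # (m32, m4, rm, k4, rk, l)
        "physical": (l, r_rows, rk, 32, 4, 4),   # (l, rm, rk, m32, m4, k4)
    }


# Example: m = 200, k = 512, l = 4, sf_vec_size = 16
# M = 256 so rm = 2; K = 32 so rk = 8.
shapes = scale_factor_shapes(rows=200, k=512, l=4, sf_vec_size=16)
```

For this example the logical shape is (32, 4, 2, 4, 8, 4) and the physical shape is (4, 2, 8, 32, 4, 4).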
Masking is applied per batch via masked_m. When alpha is provided (see **kwargs), each batch output is multiplied by its corresponding alpha value: out = alpha * (A @ B).
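The masking and alpha semantics can be illustrated with a small pure-Python reference. This is a sketch for checking results on tiny inputs, not the kernel's implementation. It assumes A and B have already been dequantized to floats (block scaling is omitted), that the contraction runs over k with B stored as (n, k) per batch, as the logical shapes above suggest, and that masked rows of out are left untouched. The function name masked_gemm_nt_reference is hypothetical.

```python
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def masked_gemm_nt_reference(
    a: Sequence[Matrix],               # [l][m][k], assumed already dequantized
    b: Sequence[Matrix],               # [l][n][k], assumed already dequantized
    out: List[Matrix],                 # [l][m][n], mutated in place
    masked_m: Sequence[int],           # [l], valid row count per batch
    alpha: Optional[Sequence[float]] = None,  # [l], per-batch scale
) -> None:
    """Hypothetical reference: out[b, i, j] = alpha[b] * sum_k a[b, i, k] * b[b, j, k]."""
    for batch, (a_mat, b_mat) in enumerate(zip(a, b)):
        scale = 1.0 if alpha is None else alpha[batch]
        # Only the first masked_m[batch] rows are computed; the rest are skipped.
        for i in range(min(masked_m[batch], len(a_mat))):
            a_row = a_mat[i]
            out[batch][i] = [
                scale * sum(x * y for x, y in zip(a_row, b_row))
                for b_row in b_mat
            ]


# One batch with m = 2, n = 3, k = 2; only the first row is valid.
a = [[[1.0, 2.0], [3.0, 4.0]]]
b = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]
out = [[[-1.0] * 3 for _ in range(2)]]
masked_gemm_nt_reference(a, b, out, masked_m=[1], alpha=[2.0])
# out[0][0] is now [2.0, 4.0, 6.0]; out[0][1] keeps its initial values.
```

Comparing a kernel result against such a reference is only meaningful on the valid rows of each batch, since rows at index masked_m[batch] and above are ignored.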