flashinfer.comm.decode_cp_a2a_workspace_size

flashinfer.comm.decode_cp_a2a_workspace_size(cp_size: int) → int

Return the per-rank workspace size, in bytes, for a context-parallel (CP) group of the given size.

Parameters:

cp_size (int) – Context-parallel group size (number of ranks).

Returns:

Workspace size in bytes per rank.

Return type:

int

Examples

>>> from flashinfer.comm import decode_cp_a2a_workspace_size
>>> decode_cp_a2a_workspace_size(4)
16778240
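Because the returned value is a per-rank figure, a memory budget for the whole CP group scales with cp_size. The following is a minimal sketch, assuming each of the cp_size ranks allocates one workspace of exactly the returned size. The helpers group_workspace_bytes and to_mib are hypothetical and not part of FlashInfer; they only do the arithmetic and take the per-rank size as an input, so they run without FlashInfer installed.

```python
# Hypothetical helpers (not part of FlashInfer) for budgeting CP workspace memory.
# Assumption: each of the cp_size ranks holds one buffer of the per-rank size
# returned by flashinfer.comm.decode_cp_a2a_workspace_size.

BYTES_PER_MIB = 1 << 20  # 1 MiB = 2**20 bytes


def group_workspace_bytes(per_rank_bytes: int, cp_size: int) -> int:
    """Total bytes across the CP group when every rank holds one workspace."""
    if cp_size < 1:
        raise ValueError("cp_size must be a positive number of ranks")
    if per_rank_bytes < 0:
        raise ValueError("per_rank_bytes must be non-negative")
    return per_rank_bytes * cp_size


def to_mib(nbytes: int) -> float:
    """Convert a byte count to mebibytes (MiB)."""
    return nbytes / BYTES_PER_MIB


# Usage with FlashInfer installed (not executed here):
#   from flashinfer.comm import decode_cp_a2a_workspace_size
#   per_rank = decode_cp_a2a_workspace_size(4)
#   total = group_workspace_bytes(per_rank, 4)
```

With the documented value of 16778240 bytes for cp_size=4 (16.0009765625 MiB per rank), the group as a whole holds 67112960 bytes, or 64.00390625 MiB.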