flashinfer.green_ctx.split_device_green_ctx_by_sm_count

flashinfer.green_ctx.split_device_green_ctx_by_sm_count(dev: torch.device, sm_counts: List[int]) → Tuple[List[torch.Stream], List[cuda.bindings.driver.CUdevResource]]

Split the device into multiple green contexts, each with a fixed number of SMs, and return the corresponding stream and CUdevResource for each partition, plus one additional pair for the remaining SMs. Green contexts allow concurrent execution of multiple kernels on different SM partitions.

Parameters:
  • dev – The device to split.

  • sm_counts – List of SM counts, one per requested partition. Each count is rounded up to meet the device's minimum and alignment requirements (see the Note below).

Returns:

  • streams – The list of torch.Stream objects corresponding to the green contexts, including one for the partition holding the remaining SMs.

  • resources – The list of CUdevResource objects corresponding to the green contexts.

Return type:

Tuple[List[torch.Stream], List[cuda.bindings.driver.CUdevResource]]

Raises:
  • RuntimeError – If the requested SM allocation exceeds device capacity or the CUDA setup fails:
      ◦ when sum(rounded_sm_counts) > total_device_sms
      ◦ when CUDA operations fail due to invalid resource types
      ◦ when the device is not properly initialized

  • ValueError – If sm_counts is empty or contains invalid values (e.g., negative values).
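
For instance, an over-sized request should surface as a RuntimeError. A minimal sketch, assuming 10_000 SMs exceeds the total on any current device (the printed message is illustrative, not the library's exact wording):

>>> import torch
>>> from flashinfer.green_ctx import split_device_green_ctx_by_sm_count
>>> try:
...     split_device_green_ctx_by_sm_count(torch.device("cuda:0"), [10_000])
... except RuntimeError:
...     print("over-allocation rejected")
over-allocation rejected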

Example

>>> from flashinfer.green_ctx import split_device_green_ctx_by_sm_count
>>> import torch
>>> dev = torch.device("cuda:0")
>>>
>>> # Create three partitions with specific SM counts
>>> streams, resources = split_device_green_ctx_by_sm_count(dev, [8, 16, 24])
>>> print([r.sm.smCount for r in resources])  # last value is the remaining SMs
[8, 16, 24, 84]
>>>
>>> # Execute kernels on different partitions
>>> with torch.cuda.stream(streams[0]):
...     x = torch.randn(4096, 4096, device=dev, dtype=torch.bfloat16)
...     y = torch.randn(4096, 4096, device=dev, dtype=torch.bfloat16)
...     z = x @ y
...     print(f"Partition 0 result: {z.shape}")
...
>>> with torch.cuda.stream(streams[1]):
...     # Different computation on partition 1
...     a = torch.randn(2048, 2048, device=dev, dtype=torch.bfloat16)
...     b = torch.randn(2048, 2048, device=dev, dtype=torch.bfloat16)
...     c = a @ b
...     print(f"Partition 1 result: {c.shape}")

Note

The returned streams and resources lists have length len(sm_counts) + 1; the final entry corresponds to the partition holding the remaining SMs that were not allocated.
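
For example, continuing from the call above, a quick sanity check of the returned lengths:

>>> streams, resources = split_device_green_ctx_by_sm_count(dev, [8, 16, 24])
>>> assert len(streams) == len(resources) == len([8, 16, 24]) + 1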

SM count alignment examples for Compute Capability 9.0+:

  • Requested 7 SMs → Allocated 8 SMs (rounded up to minimum)
  • Requested 10 SMs → Allocated 16 SMs (rounded up to multiple of 8)
  • Requested 16 SMs → Allocated 16 SMs (no rounding needed)
  • Requested 17 SMs → Allocated 24 SMs (rounded up to multiple of 8)
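
This rounding rule can be illustrated with a hypothetical helper (not part of the flashinfer API; it assumes the minimum of 8 SMs and alignment to multiples of 8 described above):

>>> def round_up_sm_count(requested: int, minimum: int = 8, align: int = 8) -> int:
...     # Hypothetical sketch of the CC 9.0+ rounding rule
...     return max(minimum, (requested + align - 1) // align * align)
>>> [round_up_sm_count(n) for n in (7, 10, 16, 17)]
[8, 16, 16, 24]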

The actual SM count can be obtained from the .sm.smCount field of the returned resources.

See CUDA Green Contexts for more details.