flashinfer.green_ctx.split_device_green_ctx

flashinfer.green_ctx.split_device_green_ctx(dev: torch.device, num_groups: int, min_count: int) → Tuple[List[torch.Stream], List[cuda.bindings.driver.CUdevResource]]

Split the device into multiple green contexts and return the corresponding streams and CUdevResource objects, one for each group plus one for the remaining SMs. Green contexts allow concurrent execution of multiple kernels on different SM partitions.

Parameters:
  • dev – The device to split.

  • num_groups – The number of groups to split the device into.

  • min_count – Minimum number of SMs required for each group; the value is adjusted upward to meet the device's alignment and granularity requirements.

Returns:

  • streams – The list of torch.Stream objects corresponding to the green contexts.

  • resources – The list of CUdevResource objects corresponding to the green contexts.

Return type:

Tuple[List[torch.Stream], List[cuda.bindings.driver.CUdevResource]]

Example

>>> from flashinfer.green_ctx import split_device_green_ctx
>>> import torch
>>> dev = torch.device("cuda:0")
>>> streams, resources = split_device_green_ctx(dev, 2, 16)
>>> print([r.sm.smCount for r in resources])
[16, 16, 100]
>>> with torch.cuda.stream(streams[0]):
...     x = torch.randn(8192, 8192, device=dev, dtype=torch.bfloat16)
...     y = torch.randn(8192, 8192, device=dev, dtype=torch.bfloat16)
...     z = x @ y
...     print(z.shape)
...
torch.Size([8192, 8192])

Note

The length of the returned streams and resources is num_groups + 1, where the last one is the remaining SMs.

The following examples show how the SM count is rounded up to meet the alignment and granularity requirements:

  • Requested 7 SMs → Allocated 8 SMs (rounded up to minimum)

  • Requested 10 SMs → Allocated 16 SMs (rounded up to multiple of 8)

  • Requested 16 SMs → Allocated 16 SMs (no rounding needed)

  • Requested 17 SMs → Allocated 24 SMs (rounded up to multiple of 8)
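The rounding behavior above can be sketched in plain Python. This is an illustrative sketch only: it assumes a per-group minimum of 8 SMs and a granularity of 8, as implied by the examples; the actual values are architecture-dependent and are reported by the CUDA driver.

```python
def rounded_sm_count(min_count: int,
                     min_sm_per_group: int = 8,
                     granularity: int = 8) -> int:
    """Illustrative sketch of the SM-count rounding described above.

    Assumes a minimum of 8 SMs per group and a granularity of 8;
    real values come from the CUDA driver and depend on the GPU
    architecture.
    """
    # Round up to the nearest multiple of the granularity, then
    # enforce the per-group minimum.
    rounded = ((min_count + granularity - 1) // granularity) * granularity
    return max(rounded, min_sm_per_group)

print([rounded_sm_count(n) for n in (7, 10, 16, 17)])  # [8, 16, 16, 24]
```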

Raises:
  • RuntimeError – when the requested SM allocation exceeds device capacity, i.e. num_groups * rounded_min_count > total_device_sms.
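The capacity condition can be illustrated with a small sketch. The function name and the 132-SM figure below are hypothetical, chosen for illustration; the real check is performed inside flashinfer / the CUDA driver.

```python
def check_sm_capacity(num_groups: int, rounded_min_count: int,
                      total_device_sms: int) -> None:
    """Hypothetical sketch of the capacity check described above."""
    requested = num_groups * rounded_min_count
    if requested > total_device_sms:
        raise RuntimeError(
            f"Requested {num_groups} groups x {rounded_min_count} SMs = "
            f"{requested} SMs, but the device only has "
            f"{total_device_sms} SMs."
        )

# e.g. on a hypothetical 132-SM GPU, 8 groups of 24 SMs (192 total) fail:
try:
    check_sm_capacity(8, 24, 132)
except RuntimeError as e:
    print(e)
```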