flashinfer.green_ctx.split_device_green_ctx_by_sm_count¶
- flashinfer.green_ctx.split_device_green_ctx_by_sm_count(dev: torch.device, sm_counts: List[int]) → Tuple[List[torch.Stream], List[cuda.bindings.driver.CUdevResource]]¶
Split the device into multiple green contexts, each with a fixed number of SMs, and return the corresponding streams and CUdevResource objects for each partition, plus one final entry for the remaining unallocated SMs. Green contexts allow concurrent execution of multiple kernels on different SM partitions.
- Parameters:
dev – The device to split.
sm_counts – List of SM counts for each partition. Each count will be rounded up to meet the minimum and alignment requirements.
- Returns:
streams – The list of torch.Stream objects corresponding to the green contexts.
resources – The list of CUdevResource objects corresponding to the green contexts.
- Return type:
Tuple[List[torch.Stream], List[CUdevResource]]
- Raises:
RuntimeError – If the requested SM allocation exceeds device capacity:
  - When sum(rounded_sm_counts) > total_device_sms
  - When CUDA operations fail due to invalid resource types
  - When the device is not properly initialized
ValueError – If sm_counts is empty or contains invalid values (e.g., negative values).
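The error conditions above can be mirrored in a standalone pre-check before calling the API. The following is a hypothetical helper (not part of flashinfer), assuming the Compute Capability 9.0+ rules of a minimum of 8 SMs per partition and alignment to multiples of 8:

```python
from typing import List

def precheck_sm_counts(sm_counts: List[int], total_device_sms: int,
                       min_sms: int = 8, align: int = 8) -> List[int]:
    """Reproduce the documented error conditions of
    split_device_green_ctx_by_sm_count (illustrative helper only)."""
    # ValueError: empty list or invalid (non-positive) counts
    if not sm_counts or any(c <= 0 for c in sm_counts):
        raise ValueError("sm_counts must be a non-empty list of positive integers")
    # Round each count up to the minimum and alignment requirements
    rounded = [max(min_sms, ((c + align - 1) // align) * align) for c in sm_counts]
    # RuntimeError: requested allocation exceeds device capacity
    if sum(rounded) > total_device_sms:
        raise RuntimeError(
            f"sum(rounded_sm_counts)={sum(rounded)} exceeds "
            f"total_device_sms={total_device_sms}"
        )
    return rounded
```

For example, on a device with 132 SMs (an assumed figure), `precheck_sm_counts([7, 10], 132)` returns `[8, 16]`, while `precheck_sm_counts([100, 100], 132)` raises RuntimeError because the rounded total exceeds capacity.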
Example
>>> from flashinfer.green_ctx import split_device_green_ctx_by_sm_count
>>> import torch
>>> dev = torch.device("cuda:0")
>>>
>>> # Create three partitions with specific SM counts
>>> streams, resources = split_device_green_ctx_by_sm_count(dev, [8, 16, 24])
>>> print([r.sm.smCount for r in resources])
[8, 16, 24, 84]  # Last value is the remaining SMs
>>>
>>> # Execute kernels on different partitions
>>> with torch.cuda.stream(streams[0]):
...     x = torch.randn(4096, 4096, device=dev, dtype=torch.bfloat16)
...     y = torch.randn(4096, 4096, device=dev, dtype=torch.bfloat16)
...     z = x @ y
...     print(f"Partition 0 result: {z.shape}")
...
>>> with torch.cuda.stream(streams[1]):
...     # Different computation on partition 1
...     a = torch.randn(2048, 2048, device=dev, dtype=torch.bfloat16)
...     b = torch.randn(2048, 2048, device=dev, dtype=torch.bfloat16)
...     c = a @ b
...     print(f"Partition 1 result: {c.shape}")
Note
The length of the returned streams and resources is len(sm_counts) + 1, where the last entry contains the remaining SMs that were not allocated.
SM count alignment examples for Compute Capability 9.0+:
- Requested 7 SMs → Allocated 8 SMs (rounded up to minimum)
- Requested 10 SMs → Allocated 16 SMs (rounded up to multiple of 8)
- Requested 16 SMs → Allocated 16 SMs (no rounding needed)
- Requested 17 SMs → Allocated 24 SMs (rounded up to multiple of 8)
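The rounding rule in the examples above reduces to "round up to a multiple of the alignment, with a floor at the minimum." A minimal sketch, assuming the Compute Capability 9.0+ values of an 8-SM minimum and 8-SM alignment (the function name is illustrative, not a flashinfer API):

```python
def round_sm_count(requested: int, min_sms: int = 8, align: int = 8) -> int:
    """Round a requested SM count up to the minimum and alignment granularity."""
    # Integer ceiling to the nearest multiple of `align`, then apply the floor
    return max(min_sms, ((requested + align - 1) // align) * align)

# Reproduces the alignment examples above:
# round_sm_count(7)  -> 8
# round_sm_count(10) -> 16
# round_sm_count(16) -> 16
# round_sm_count(17) -> 24
```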
The actual SM count can be obtained from the .sm.smCount field of the returned resources. See CUDA Green Contexts for more details.