From 9eeef015f28e910503fe368fdb1b848ab51aa482 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sun, 10 May 2026 21:59:19 -0500 Subject: [PATCH 1/4] perf(cuda_core): deduplicate Device() call in LaunchConfig.__init__ Avoid calling Device() twice (once for cluster validation, once for cooperative check). Now called at most once, and zero times for the common simple-launch path where neither cluster nor is_cooperative is set. Co-Authored-By: Claude --- cuda_core/cuda/core/_launch_config.pyx | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/_launch_config.pyx b/cuda_core/cuda/core/_launch_config.pyx index b1a9a96cb2..a3fa7e83e8 100644 --- a/cuda_core/cuda/core/_launch_config.pyx +++ b/cuda_core/cuda/core/_launch_config.pyx @@ -74,8 +74,11 @@ cdef class LaunchConfig: # look up the device from stream. We probably need to defer the checks related to # device compute capability or attributes. # thread block clusters are supported starting H100 + if cluster is not None or is_cooperative: + dev = Device() + if cluster is not None: - cc = Device().compute_capability + cc = dev.compute_capability if cc < (9, 0): raise CUDAError( f"thread block clusters are not supported on devices with compute capability < 9.0 (got {cc})" @@ -92,7 +95,7 @@ cdef class LaunchConfig: self.is_cooperative = is_cooperative - if self.is_cooperative and not Device().properties.cooperative_launch: + if self.is_cooperative and not dev.properties.cooperative_launch: raise CUDAError("cooperative kernels are not supported on this device") def _identity(self): From bc6ba8c1ab33ce54a690a3fdc633e1c8f26e4cc8 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 12 May 2026 08:21:55 -0500 Subject: [PATCH 2/4] refactor(cuda_core): defer device checks in LaunchConfig to launch time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LaunchConfig.__init__ previously called Device() to validate compute capability (for cluster launches) and cooperative_launch support, but at construction time the stream — and therefore the correct device — is not yet known. Move both checks into _launcher.pyx where the stream is available: - _check_cluster_launch: queries stream.device.compute_capability and raises if CC < 9.0 (thread block clusters require H100+) - _check_cooperative_launch: now also guards cooperative_launch support via stream.device before the grid-size check LaunchConfig.__init__ is now a pure data class with no driver calls. Cluster and cooperative config objects can be constructed without a CUDA context, and errors surface at launch() time with the correct device in scope. Remove the try/except CUDAError skip guards from cluster-related tests; constructing LaunchConfig(cluster=...) no longer raises on sub-CC-9.0 devices, so those tests run on all hardware. --- cuda_core/cuda/core/_launch_config.pyx | 17 ----- cuda_core/cuda/core/_launcher.pyx | 13 ++++ cuda_core/tests/test_launcher.py | 90 ++++++++++++-------------- 3 files changed, 54 insertions(+), 66 deletions(-) diff --git a/cuda_core/cuda/core/_launch_config.pyx b/cuda_core/cuda/core/_launch_config.pyx index a3fa7e83e8..03e76dbbe3 100644 --- a/cuda_core/cuda/core/_launch_config.pyx +++ b/cuda_core/cuda/core/_launch_config.pyx @@ -4,9 +4,7 @@ from libc.string cimport memset -from cuda.core._device import Device from cuda.core._utils.cuda_utils import ( - CUDAError, cast_to_3_tuple, driver, ) @@ -70,19 +68,7 @@ cdef class LaunchConfig: self.grid = cast_to_3_tuple("LaunchConfig.grid", grid) self.block = cast_to_3_tuple("LaunchConfig.block", block) - # FIXME: Calling Device() strictly speaking is not quite right; we should instead - # look up the device from stream. We probably need to defer the checks related to - # device compute capability or attributes. - # thread block clusters are supported starting H100 - if cluster is not None or is_cooperative: - dev = Device() - if cluster is not None: - cc = dev.compute_capability - if cc < (9, 0): - raise CUDAError( - f"thread block clusters are not supported on devices with compute capability < 9.0 (got {cc})" - ) self.cluster = cast_to_3_tuple("LaunchConfig.cluster", cluster) else: self.cluster = None @@ -95,9 +81,6 @@ cdef class LaunchConfig: self.is_cooperative = is_cooperative - if self.is_cooperative and not dev.properties.cooperative_launch: - raise CUDAError("cooperative kernels are not supported on this device") - def _identity(self): return tuple(getattr(self, attr) for attr in _LAUNCH_CONFIG_ATTRS) diff --git a/cuda_core/cuda/core/_launcher.pyx b/cuda_core/cuda/core/_launcher.pyx index e6a07ad28e..3148c0ef89 100644 --- a/cuda_core/cuda/core/_launcher.pyx +++ b/cuda_core/cuda/core/_launcher.pyx @@ -17,6 +17,7 @@ from cuda.core._utils.cuda_utils cimport ( ) from cuda.core._module import Kernel from cuda.core._stream import Stream +from cuda.core._utils.cuda_utils import CUDAError from math import prod @@ -52,14 +53,26 @@ def launch(stream: Stream | GraphBuilder | IsStreamType, config: LaunchConfig, k drv_cfg = conf._to_native_launch_config() drv_cfg.hStream = as_cu(s._h_stream) + if conf.cluster is not None: + _check_cluster_launch(conf, s) if conf.is_cooperative: _check_cooperative_launch(kernel, conf, s) with nogil: HANDLE_RETURN(cydriver.cuLaunchKernelEx(&drv_cfg, func_handle, args_ptr, NULL)) +cdef _check_cluster_launch(config: LaunchConfig, stream: Stream): + cc = stream.device.compute_capability + if cc < (9, 0): + raise CUDAError( + f"thread block clusters are not supported on devices with compute capability < 9.0 (got {cc})" + ) + + cdef _check_cooperative_launch(kernel: Kernel, config: LaunchConfig, stream: Stream): dev = stream.device + if not dev.properties.cooperative_launch: + raise CUDAError("cooperative kernels are not supported on this device") num_sm = dev.properties.multiprocessor_count max_grid_size = ( kernel.occupancy.max_active_blocks_per_multiprocessor(prod(config.block), config.shmem_size) * num_sm diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index f4858cdaef..ec57738c6c 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -63,66 +63,58 @@ def test_launch_config_shmem_size(): assert config.shmem_size == 0 -def test_launch_config_cluster_grid_conversion(init_cuda): +def test_launch_config_cluster_grid_conversion(): """Test that LaunchConfig preserves original grid values and conversion happens in native config.""" - try: - # Test case 1: 1D - Issue #867 example - config = LaunchConfig(grid=4, cluster=2, block=32) - assert config.grid == (4, 1, 1), f"Expected (4, 1, 1), got {config.grid}" - assert config.cluster == (2, 1, 1), f"Expected (2, 1, 1), got {config.cluster}" - assert config.block == (32, 1, 1), f"Expected (32, 1, 1), got {config.block}" + # Test case 1: 1D - Issue #867 example + config = LaunchConfig(grid=4, cluster=2, block=32) + assert config.grid == (4, 1, 1), f"Expected (4, 1, 1), got {config.grid}" + assert config.cluster == (2, 1, 1), f"Expected (2, 1, 1), got {config.cluster}" + assert config.block == (32, 1, 1), f"Expected (32, 1, 1), got {config.block}" - # Test case 2: 2D grid and cluster - config = LaunchConfig(grid=(2, 3), cluster=(2, 2), block=32) - assert config.grid == (2, 3, 1), f"Expected (2, 3, 1), got {config.grid}" - assert config.cluster == (2, 2, 1), f"Expected (2, 2, 1), got {config.cluster}" + # Test case 2: 2D grid and cluster + config = LaunchConfig(grid=(2, 3), cluster=(2, 2), block=32) + assert config.grid == (2, 3, 1), f"Expected (2, 3, 1), got {config.grid}" + assert config.cluster == (2, 2, 1), f"Expected (2, 2, 1), got {config.cluster}" - # Test case 3: 3D full specification - config = LaunchConfig(grid=(2, 2, 2), cluster=(3, 3, 3), block=(8, 8, 8)) - assert config.grid == (2, 2, 2), f"Expected (2, 2, 2), got {config.grid}" - assert config.cluster == (3, 3, 3), f"Expected (3, 3, 3), got {config.cluster}" + # Test case 3: 3D full specification + config = LaunchConfig(grid=(2, 2, 2), cluster=(3, 3, 3), block=(8, 8, 8)) + assert config.grid == (2, 2, 2), f"Expected (2, 2, 2), got {config.grid}" + assert config.cluster == (3, 3, 3), f"Expected (3, 3, 3), got {config.cluster}" - # Test case 4: Identity case - config = LaunchConfig(grid=1, cluster=1, block=32) - assert config.grid == (1, 1, 1), f"Expected (1, 1, 1), got {config.grid}" + # Test case 4: Identity case + config = LaunchConfig(grid=1, cluster=1, block=32) + assert config.grid == (1, 1, 1), f"Expected (1, 1, 1), got {config.grid}" - # Test case 5: No cluster (should not convert grid) - config = LaunchConfig(grid=4, block=32) - assert config.grid == (4, 1, 1), f"Expected (4, 1, 1), got {config.grid}" - assert config.cluster is None - - except CUDAError: - pytest.skip("Driver or GPU not new enough for thread block clusters") + # Test case 5: No cluster (should not convert grid) + config = LaunchConfig(grid=4, block=32) + assert config.grid == (4, 1, 1), f"Expected (4, 1, 1), got {config.grid}" + assert config.cluster is None def test_launch_config_native_conversion(init_cuda): """Test that _to_native_launch_config correctly converts grid from cluster units to block units.""" from cuda.core._launch_config import _to_native_launch_config - try: - # Test case 1: 1D - Issue #867 example - config = LaunchConfig(grid=4, cluster=2, block=32) - native_config = _to_native_launch_config(config) - assert native_config.gridDimX == 8, f"Expected gridDimX=8, got {native_config.gridDimX}" - assert native_config.gridDimY == 1, f"Expected gridDimY=1, got {native_config.gridDimY}" - assert native_config.gridDimZ == 1, f"Expected gridDimZ=1, got {native_config.gridDimZ}" - - # Test case 2: 2D grid and cluster - config = LaunchConfig(grid=(2, 3), cluster=(2, 2), block=32) - native_config = _to_native_launch_config(config) - assert native_config.gridDimX == 4, f"Expected gridDimX=4, got {native_config.gridDimX}" - assert native_config.gridDimY == 6, f"Expected gridDimY=6, got {native_config.gridDimY}" - assert native_config.gridDimZ == 1, f"Expected gridDimZ=1, got {native_config.gridDimZ}" - - # Test case 3: No cluster (should not convert grid) - config = LaunchConfig(grid=4, block=32) - native_config = _to_native_launch_config(config) - assert native_config.gridDimX == 4, f"Expected gridDimX=4, got {native_config.gridDimX}" - assert native_config.gridDimY == 1, f"Expected gridDimY=1, got {native_config.gridDimY}" - assert native_config.gridDimZ == 1, f"Expected gridDimZ=1, got {native_config.gridDimZ}" - - except CUDAError: - pytest.skip("Driver or GPU not new enough for thread block clusters") + # Test case 1: 1D - Issue #867 example + config = LaunchConfig(grid=4, cluster=2, block=32) + native_config = _to_native_launch_config(config) + assert native_config.gridDimX == 8, f"Expected gridDimX=8, got {native_config.gridDimX}" + assert native_config.gridDimY == 1, f"Expected gridDimY=1, got {native_config.gridDimY}" + assert native_config.gridDimZ == 1, f"Expected gridDimZ=1, got {native_config.gridDimZ}" + + # Test case 2: 2D grid and cluster + config = LaunchConfig(grid=(2, 3), cluster=(2, 2), block=32) + native_config = _to_native_launch_config(config) + assert native_config.gridDimX == 4, f"Expected gridDimX=4, got {native_config.gridDimX}" + assert native_config.gridDimY == 6, f"Expected gridDimY=6, got {native_config.gridDimY}" + assert native_config.gridDimZ == 1, f"Expected gridDimZ=1, got {native_config.gridDimZ}" + + # Test case 3: No cluster (should not convert grid) + config = LaunchConfig(grid=4, block=32) + native_config = _to_native_launch_config(config) + assert native_config.gridDimX == 4, f"Expected gridDimX=4, got {native_config.gridDimX}" + assert native_config.gridDimY == 1, f"Expected gridDimY=1, got {native_config.gridDimY}" + assert native_config.gridDimZ == 1, f"Expected gridDimZ=1, got {native_config.gridDimZ}" def test_launch_invalid_values(init_cuda): From 10855e37757c96d27b4edcfcbe158ece2b969721 Mon Sep 17 00:00:00 2001 From: Michael Droettboom Date: Mon, 11 May 2026 16:04:02 -0400 Subject: [PATCH 3/4] Fix tab completion (#2055) * Fix tab completion * Fix tests * Always install the monkeypatch * Update release note * Apply suggestion from @leofang Co-authored-by: Leo Fang * Fix test * Fix tests hanging on Windows --------- Co-authored-by: Leo Fang --- cuda_core/cuda/core/__init__.py | 40 +++++++ cuda_core/docs/source/release/1.0.0-notes.rst | 9 ++ cuda_core/tests/test_rlcompleter_patch.py | 106 ++++++++++++++++++ 3 files changed, 155 insertions(+) create mode 100644 cuda_core/tests/test_rlcompleter_patch.py diff --git a/cuda_core/cuda/core/__init__.py b/cuda_core/cuda/core/__init__.py index 825b29e6ca..f2d7c85b62 100644 --- a/cuda_core/cuda/core/__init__.py +++ b/cuda_core/cuda/core/__init__.py @@ -28,6 +28,46 @@ def _import_versioned_module(): del _import_versioned_module +def _patch_rlcompleter_for_cython_properties(): + # TODO: This can be removed when Python 3.13 is our minimum-supported version: + # https://github.com/python/cpython/pull/149577 + + # Cython @property on cdef class compiles to a C-level getset_descriptor, + # which rlcompleter's narrow isinstance(..., property) check misses; the + # fallback getattr() then invokes the descriptor and any non-AttributeError + # it raises kills tab completion. Extend that isinstance check to also + # match getset_descriptor / member_descriptor. Only installed in + # interactive mode so library users running scripts see no global + # rlcompleter side effect. + import os + + if int(os.environ.get("CUDA_CORE_DONT_FIX_TAB_COMPLETION", "0")): + # Explicit opt-out for users who don't want the global rlcompleter + # side effect, even in an interactive session. + return + + import rlcompleter + from types import GetSetDescriptorType, MemberDescriptorType + + # This works by overriding the `property` built-in with a custom subclass of + # property, but only in the rlcompleter module. This subclass overrides the + # `__instancecheck__` method to also return True for getset_descriptor and + # member_descriptor types, which are what Cython uses for properties on cdef + # classes. + class _PatchedPropMeta(type): + def __instancecheck__(cls, inst): + return isinstance(inst, (property, GetSetDescriptorType, MemberDescriptorType)) + + class _PatchedProperty(metaclass=_PatchedPropMeta): + pass + + rlcompleter.property = _PatchedProperty + + +_patch_rlcompleter_for_cython_properties() +del _patch_rlcompleter_for_cython_properties + + from cuda.core import checkpoint, system, utils from cuda.core._context import Context, ContextOptions from cuda.core._device import Device diff --git a/cuda_core/docs/source/release/1.0.0-notes.rst b/cuda_core/docs/source/release/1.0.0-notes.rst index 714dc48ff6..fab5f484e1 100644 --- a/cuda_core/docs/source/release/1.0.0-notes.rst +++ b/cuda_core/docs/source/release/1.0.0-notes.rst @@ -357,3 +357,12 @@ Fixes and enhancements package size. Debug builds are now supported via ``--config-settings=debug=true``. (`#1890 `__) +- Fixed tab completion silently breaking in the CPython REPL on some + ``cuda.core`` objects (for example, hitting Tab after ``mr.`` for a + ``DeviceMemoryResource`` would produce no suggestions due to a CPython + limitation interacting with Cython properties). When ``cuda.core`` is + imported in an interactive session it now applies a small patch to the + standard library REPL completer so tab completion works as expected. + If you would rather not have ``cuda.core`` modify the REPL completer, set + ``CUDA_CORE_DONT_FIX_TAB_COMPLETION=1`` to opt out. + (`#2053 `__) diff --git a/cuda_core/tests/test_rlcompleter_patch.py b/cuda_core/tests/test_rlcompleter_patch.py new file mode 100644 index 0000000000..ecde4e5b0c --- /dev/null +++ b/cuda_core/tests/test_rlcompleter_patch.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Tests for the rlcompleter monkeypatch installed by `cuda.core` in interactive +sessions. + +These tests reproduce the original bug report (NVIDIA/cuda-python#2053): tab +completion on a non-IPC-enabled DeviceMemoryResource crashes because the +Cython @property `allocation_handle` raises RuntimeError, and rlcompleter's +narrow `isinstance(..., property)` check misses C-level getset_descriptor +types and therefore invokes the descriptor. + +The patch only installs in interactive mode, so each scenario is exercised in +a fresh subprocess with a controlled combination of `PYTHONINSPECT` and +`CUDA_CORE_DONT_FIX_TAB_COMPLETION`. +""" + +import os +import subprocess +import sys +import tempfile +import textwrap + +import pytest + +from cuda.core import Device + + +def _gpu_with_mempool_or_skip(): + """Skip when no GPU or no mempool support — test mirrors the bug repro.""" + if len(Device.get_all_devices()) == 0: + pytest.skip("Test requires a CUDA device") + dev = Device(0) + if not dev.properties.memory_pools_supported: + pytest.skip("Device 0 does not support mempool operations") + + +# Probe script: reproduces the bug-report repro literally, then runs +# rlcompleter against `mr` and reports the outcome. +_PROBE_SCRIPT = textwrap.dedent(""" + import rlcompleter + from cuda.core import Device, DeviceMemoryResource + + dev = Device(0) + dev.set_current() + mr = DeviceMemoryResource(dev) + assert not mr.is_ipc_enabled, "test setup: mr should not be IPC-enabled" + + completer = rlcompleter.Completer({"mr": mr}) + try: + matches = completer.attr_matches("mr.") + except Exception as exc: + print(f"crash: {type(exc).__name__}: {exc}") + else: + print(f"ok: {len(matches)} matches") + print(f"allocation_handle: {'mr.allocation_handle' in matches}") +""") + + +def _run_probe(*, pythoninspect: bool, opt_out: bool = False) -> subprocess.CompletedProcess: + env = os.environ.copy() + # Don't let parent-environment values bleed into the subprocess. + env.pop("CUDA_CORE_DONT_FIX_TAB_COMPLETION", None) + # Drop PYTHONPATH so the subprocess can't find a source-tree cuda.core + # via an inherited path entry; we want it to import the installed wheel. + env.pop("PYTHONPATH", None) + if opt_out: + env["CUDA_CORE_DONT_FIX_TAB_COMPLETION"] = "1" + # `python -c` puts the parent's CWD at the head of sys.path. If pytest is + # run from `cuda_core/` (which contains a `cuda/core/` source tree), that + # source tree shadows the installed package. Run the subprocess from a + # neutral temp dir to avoid this. + with tempfile.TemporaryDirectory() as tmpdir: + return subprocess.run( # noqa: S603 + [sys.executable, "-c", _PROBE_SCRIPT], + capture_output=True, + text=True, + env=env, + check=False, + # PYTHONINSPECT keeps the interpreter alive after `-c`; close stdin + # so the implicit REPL exits immediately. + stdin=subprocess.DEVNULL, + cwd=tmpdir, + ) + + +def test_patched_completion_succeeds_on_non_ipc_resource(): + """With the patch installed (PYTHONINSPECT=1), tab completion must not + crash and `mr.allocation_handle` must appear in the matches.""" + _gpu_with_mempool_or_skip() + + result = _run_probe(pythoninspect=True) + assert result.returncode == 0, f"stderr: {result.stderr}\nstdout: {result.stdout}" + assert result.stdout.startswith("ok:"), result.stdout + assert "allocation_handle: True" in result.stdout, result.stdout + + +def test_opt_out_env_var_disables_patch_even_when_interactive(): + """`CUDA_CORE_DONT_FIX_TAB_COMPLETION=1` must short-circuit before the + interactive check, so the bug reproduces again even under PYTHONINSPECT.""" + _gpu_with_mempool_or_skip() + + result = _run_probe(pythoninspect=True, opt_out=True) + assert result.returncode == 0, f"stderr: {result.stderr}\nstdout: {result.stdout}" + assert "crash: RuntimeError" in result.stdout, result.stdout From 38923c8a08f804bb6a57e92f2f89581165087610 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 12 May 2026 08:31:47 -0500 Subject: [PATCH 4/4] refactor(cuda_core): defer device checks in LaunchConfig to launch time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LaunchConfig.__init__ previously called Device() to validate compute capability (for cluster launches) and cooperative_launch support, but at construction time the stream — and therefore the correct device — is not yet known. Move both checks into _launcher.pyx where the stream is available: - _check_cluster_launch: queries stream.device.compute_capability and raises if CC < 9.0 (thread block clusters require H100+) - _check_cooperative_launch: now also guards cooperative_launch support via stream.device before the grid-size check LaunchConfig.__init__ is now a pure data class with no driver calls. Cluster and cooperative config objects can be constructed without a CUDA context, and errors surface at launch() time with the correct device in scope. Remove the try/except CUDAError skip guards from cluster-related tests; constructing LaunchConfig(cluster=...) no longer raises on sub-CC-9.0 devices, so those tests run on all hardware. --- cuda_core/tests/test_launcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py index ec57738c6c..6b5ed64998 100644 --- a/cuda_core/tests/test_launcher.py +++ b/cuda_core/tests/test_launcher.py @@ -25,7 +25,6 @@ launch, ) from cuda.core._memory._legacy import _SynchronousMemoryResource -from cuda.core._utils.cuda_utils import CUDAError from cuda.core.typing import ObjectCodeFormatType, SourceCodeType