
perf(cuda_core): cache native LaunchConfig struct and make fields read-only#2070

Draft
KRRT7 wants to merge 6 commits into NVIDIA:main from KRRT7:perf/cache-native-launch-config

Conversation

@KRRT7 KRRT7 commented May 12, 2026

Motivation

launch() is on the critical path. Every call currently pays the full cost of
_to_native_launch_config() — a memset, vector::resize, and attribute
rebuild — even when the LaunchConfig hasn't changed between calls, which is
the normal pattern in a tight dispatch loop.

LaunchConfig is already designed as an immutable value type (__hash__ and
__eq__ are defined), so the native CUlaunchConfig struct is a pure function
of its fields. We can compute it once and reuse it.

Changes

_launch_config.pxd: change `public` to `readonly` on all five fields. Python
callers can still read config.grid etc. but can no longer mutate them after
construction. Cython-internal code is unaffected (direct C field access is
unchanged). Two new private C fields: _cached_drv_cfg and _cache_valid.
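In pure-Python terms, the readonly change behaves like exposing the fields through getter-only properties. A minimal sketch (hypothetical analog; the actual change is a Cython `readonly` declaration, and `_grid`/`_block` here are illustrative names):

```python
class LaunchConfig:
    """Sketch of read-only fields; the real class uses Cython `readonly`
    declarations, so internal C-level access stays unchanged."""

    def __init__(self, grid, block):
        self._grid = grid    # private storage; internal code writes here
        self._block = block

    @property
    def grid(self):
        # Readable from Python; assignment raises AttributeError (no setter).
        return self._grid

    @property
    def block(self):
        return self._block
```

Reads like `config.grid` keep working, while `config.grid = ...` raises `AttributeError`, which is what keeps the native-struct cache from going stale.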

_launch_config.pyx: _to_native_launch_config (the cdef method called
from launch()) now short-circuits on _cache_valid. On first call it builds
the struct as before, stores it, and sets _cache_valid = True. Subsequent
calls return a struct copy in O(1). The attrs pointer in the cached struct
is stable because self._attrs is never resized after the cache is set.

No changes to _launcher.pyx, _graph_node.pyx, or any test that reads
fields — readonly fields are accessed identically from both Python and
Cython.

Correctness

  • Fields are now read-only from Python, so the cache can never go stale from
    user code.
  • The attrs pointer in the cached CUlaunchConfig points into self._attrs.
    Since _attrs.resize(0) is skipped on the fast path, the vector is never
    reallocated after the cache is populated; the pointer is valid for the
    lifetime of any cuLaunchKernelEx call.
  • Thread safety: the build path runs under the GIL; worst case is two threads
    both compute the same result simultaneously, which is harmless.
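The benign-race argument can be illustrated in plain Python (a sketch; the real build path runs under the GIL in Cython): because the build is a pure function of immutable fields, threads that both take the slow path store identical values, so the last writer changes nothing observable.

```python
import threading

class _RacyCache:
    """Sketch: two threads may both take the slow path, but since the
    result is a pure function of an immutable input, all writers agree."""

    def __init__(self, value):
        self._value = value
        self._cached = None
        self._cache_valid = False

    def get(self):
        if self._cache_valid:
            return self._cached
        built = (self._value,)   # deterministic build from immutable input
        self._cached = built     # benign race: every writer stores the same value
        self._cache_valid = True
        return built

cache = _RacyCache(42)
results = []
threads = [threading.Thread(target=lambda: results.append(cache.get()))
           for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```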

Benchmark

Measured on a T4 (CUDA 12.9), 50,000 iterations, noop kernel:

| Scenario                                     | µs/call |
| -------------------------------------------- | ------- |
| launch() reused config (cache warm)          | 3.98    |
| launch() fresh config each call (cache cold) | 6.34    |
| LaunchConfig() construction alone            | 1.76    |

1.6x speedup on launch() when reusing a LaunchConfig across calls,
which is the expected pattern for any steady-state dispatch loop.
LaunchConfig construction accounts for ~28% of the cold-path cost; the rest
is the _to_native_launch_config rebuild that the cache eliminates.
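The PR does not include its benchmark script; a minimal harness along these lines would produce comparable µs/call figures (hypothetical helper, not the actual script used):

```python
import time

def usec_per_call(fn, iters=50_000):
    """Mean microseconds per call of `fn` over `iters` iterations."""
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters * 1e6

# Usage sketch (names assumed): compare warm-cache vs. fresh-config paths, e.g.
#   warm = usec_per_call(lambda: launch(stream, cfg, kernel))
#   cold = usec_per_call(lambda: launch(stream, LaunchConfig(grid, block), kernel))
```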

Tests

  • test_launch_config_fields_are_readonly — all five fields raise AttributeError on write
  • test_launch_config_native_cache_stable — two calls to _to_native_launch_config on the same config return consistent grid/block/shmem/numAttrs values
  • test_launch_config_native_cache_cooperative — cached cooperative config retains its attribute (numAttrs == 1)


copy-pr-bot Bot commented May 12, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions github-actions Bot added the cuda.core Everything related to the cuda.core module label May 12, 2026
KRRT7 added 2 commits May 12, 2026 10:57
…d-only

_to_native_launch_config() rebuilt CUlaunchConfig on every launch() call
even when the config was unchanged. Since LaunchConfig is already designed
as an immutable value type (__hash__, __eq__), cache the result after the
first build and return a struct copy on subsequent calls.

Fields are changed from `public` to `readonly` so the cache can never go
stale from Python-side mutation. Cython-internal access is unaffected.

Benchmark (T4, 50k iters, noop kernel):
  launch() reused config (cache warm):  3.98 us/call
  launch() fresh config each call:      6.34 us/call
  speedup: 1.6x

KRRT7 commented May 12, 2026

meant to open the draft in my fork so that it could run in my CI, apologies, will clean up shortly

@KRRT7 KRRT7 force-pushed the perf/cache-native-launch-config branch from 7bb484b to 6d5e25d Compare May 12, 2026 16:03
KRRT7 and others added 4 commits May 12, 2026 11:03
- Expose _cache_valid as readonly so tests can assert the cdef cache path
- Add test_launch_config_cdef_cache_populated_by_launch: verifies the
  cdef _to_native_launch_config cache is set after a real launch() call
- Add test_launch_config_native_conversion_stable_cluster: cluster config
  consistency via the cpdef wrapper
- Rename cpdef-level cache tests to make clear they test the Python wrapper,
  not the cdef cache
- Add NOTE comment at the cpdef wrapper distinguishing it from the cached
  cdef method
- Add breaking change entry to 1.0.1 release notes for readonly fields
- Use typed values in test_launch_config_fields_are_readonly instead of None
- Narrow except clause to CUDAError in cooperative skip guard
- Assert cache persists through a second launch in cdef cache test
