Skip to content

Commit 5c63733

Browse files
committed
Reorganize graph tests and add device-side launch support
Reorganize graph tests and add device-side launch support

- Reorganize graph tests into tests/graph/ subdirectory:
  - test_basic.py: basic graph construction tests
  - test_conditional.py: conditional node tests (if, if-else, switch, while)
  - test_advanced.py: child graphs, update, stream lifetime
  - test_options.py: debug print, complete options, build mode
  - test_capture_alloc.py: graph memory resource tests (moved from test_graph_mem.py)
  - test_device_launch.py: new device-side graph launch tests
- Add Graph.handle property to expose CUgraphExec handle, consistent with Stream.handle and Event.handle patterns
- Add tests/helpers/graph_kernels.py with shared kernel compilation helpers
- Device-side graph launch tests require Hopper (sm_90+) architecture
1 parent 3257477 commit 5c63733

File tree

9 files changed

+1024
-764
lines changed

9 files changed

+1024
-764
lines changed

cuda_core/cuda/core/_graph.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,18 @@ def close(self):
746746
"""Destroy the graph."""
747747
self._mnff.close()
748748

@property
def handle(self) -> driver.CUgraphExec:
    """Return the underlying ``CUgraphExec`` object for this graph.

    .. caution::

        This handle is a Python object. To get the memory address of the underlying C
        handle, call ``int()`` on the returned object.

    """
    executable = self._mnff.graph
    return executable
749761
def update(self, builder: GraphBuilder):
750762
"""Update the graph using new build configuration from the builder.
751763
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
"""Advanced graph feature tests (child graphs, update, stream lifetime)."""
5+
6+
import numpy as np
7+
import pytest
8+
from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource, launch
9+
from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels
10+
11+
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_graph_child_graph(init_cuda):
    """Embed a captured child graph inside a parent graph and verify both run."""
    mod = compile_common_kernels()
    add_one = mod.get_kernel("add_one")

    # Launch stream plus two pinned int32 counters, both zeroed.
    launch_stream = Device().create_stream()
    mr = LegacyPinnedMemoryResource()
    b = mr.allocate(8)
    arr = np.from_dlpack(b).view(np.int32)
    arr[0] = 0
    arr[1] = 0

    # Capture the child graph: three increments of the second counter.
    child = Device().create_graph_builder().begin_building()
    for _ in range(3):
        launch(child, LaunchConfig(grid=1, block=1), add_one, arr[1:].ctypes.data)
    child.end_building()

    # Capture the parent graph: one increment of the first counter...
    parent = Device().create_graph_builder().begin_building()
    launch(parent, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)

    # ...then embed the child, skipping gracefully on drivers without support.
    try:
        parent.add_child(child)
    except NotImplementedError as e:
        with pytest.raises(
            NotImplementedError,
            match="^Launching child graphs is not implemented for versions older than CUDA 12",
        ):
            raise e
        parent.end_building()
        b.close()
        pytest.skip("Launching child graphs is not implemented for versions older than CUDA 12")

    launch(parent, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
    graph = parent.end_building().complete()

    # Parent bumps the first counter twice; the child bumps the second three times.
    assert arr[0] == 0
    assert arr[1] == 0
    graph.launch(launch_stream)
    launch_stream.sync()
    assert arr[0] == 2
    assert arr[1] == 3

    # Close the buffer explicitly so the garbage collector cannot free it
    # while a later graph builder process is still capturing.
    b.close()
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_graph_update(init_cuda):
    """Complete one switch-graph variant, then swap in the others via update()."""
    mod = compile_conditional_kernels(int)
    add_one = mod.get_kernel("add_one")

    # Launch stream plus three pinned int32 counters, all zeroed.
    launch_stream = Device().create_stream()
    mr = LegacyPinnedMemoryResource()
    b = mr.allocate(12)
    arr = np.from_dlpack(b).view(np.int32)
    arr[0] = 0
    arr[1] = 0
    arr[2] = 0

    def build_graph(condition_value):
        # Capture a graph whose switch condition defaults to condition_value.
        gb = Device().create_graph_builder().begin_building()

        # Conditional handle that feeds the switch node.
        handle = gb.create_conditional_handle(default_value=condition_value)

        # Add the 3-way switch node, validating the error on unsupported
        # driver/binding versions before propagating it to the caller.
        try:
            cases = list(gb.switch(handle, 3))
        except Exception as e:
            with pytest.raises(RuntimeError, match="^(Driver|Binding) version"):
                raise e
            gb.end_building()
            raise e

        # Case i applies three increments to counter i.
        for index, case in enumerate(cases):
            case.begin_building()
            for _ in range(3):
                launch(case, LaunchConfig(grid=1, block=1), add_one, arr[index:].ctypes.data)
            case.end_building()

        return gb.end_building()

    try:
        graph_variants = [build_graph(0), build_graph(1), build_graph(2)]
    except Exception as e:
        with pytest.raises(RuntimeError, match="^(Driver|Binding) version"):
            raise e
        b.close()
        pytest.skip("Driver does not support conditional switch")

    # Launch the first variant: only the first counter advances.
    assert arr[0] == 0
    assert arr[1] == 0
    assert arr[2] == 0
    graph = graph_variants[0].complete()
    graph.launch(launch_stream)
    launch_stream.sync()
    assert arr[0] == 3
    assert arr[1] == 0
    assert arr[2] == 0

    # Swap in the second variant and relaunch: the second counter advances.
    graph.update(graph_variants[1])
    graph.launch(launch_stream)
    launch_stream.sync()
    assert arr[0] == 3
    assert arr[1] == 3
    assert arr[2] == 0

    # Swap in the third variant and relaunch: the third counter advances.
    graph.update(graph_variants[2])
    graph.launch(launch_stream)
    launch_stream.sync()
    assert arr[0] == 3
    assert arr[1] == 3
    assert arr[2] == 3

    # Close the buffer explicitly so the garbage collector cannot free it
    # while a later graph builder process is still capturing.
    b.close()
def test_graph_stream_lifetime(init_cuda):
    """Closing a graph and its builder must not tear down the originating stream."""
    mod = compile_common_kernels()
    empty_kernel = mod.get_kernel("empty_kernel")

    # Build and complete a trivial graph from a device-owned builder...
    gb = Device().create_graph_builder().begin_building()
    launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
    graph = gb.end_building().complete()

    # ...then destroy both.
    gb.close()
    graph.close()

    # Repeat with a builder created from an explicit stream.
    stream = Device().create_stream()
    gb = stream.create_graph_builder().begin_building()
    launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
    graph = gb.end_building().complete()

    # Destroy the graph and builder again.
    gb.close()
    graph.close()

    # The originating stream must still accept and complete work.
    launch(stream, LaunchConfig(grid=1, block=1), empty_kernel)
    stream.sync()

    # Finally tear down the stream itself.
    stream.close()
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
"""Basic graph construction and topology tests."""
5+
6+
import numpy as np
7+
import pytest
8+
from cuda.core import Device, GraphBuilder, LaunchConfig, LegacyPinnedMemoryResource, launch
9+
from helpers.graph_kernels import compile_common_kernels
10+
11+
def test_graph_is_building(init_cuda):
    """is_building tracks the begin_building/end_building lifecycle."""
    builder = Device().create_graph_builder()
    assert builder.is_building is False
    builder.begin_building()
    assert builder.is_building is True
    builder.end_building()
    assert builder.is_building is False
def test_graph_straight(init_cuda):
    """Capture a straight-line chain of kernel nodes and launch it."""
    mod = compile_common_kernels()
    empty_kernel = mod.get_kernel("empty_kernel")
    launch_stream = Device().create_stream()

    # Three kernels captured back-to-back form a linear topology.
    builder = Device().create_graph_builder().begin_building()
    for _ in range(3):
        launch(builder, LaunchConfig(grid=1, block=1), empty_kernel)
    graph = builder.end_building().complete()

    # Sanity: the graph uploads and launches cleanly.
    graph.upload(launch_stream)
    graph.launch(launch_stream)
    launch_stream.sync()
def test_graph_fork_join(init_cuda):
    """Capture a diamond topology via split()/join() and launch it."""
    mod = compile_common_kernels()
    empty_kernel = mod.get_kernel("empty_kernel")
    launch_stream = Device().create_stream()

    # Root node of the diamond.
    gb = Device().create_graph_builder().begin_building()
    launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)

    # A split must produce at least two branches.
    with pytest.raises(ValueError, match="^Invalid split count: expecting >= 2, got 1"):
        gb.split(1)

    # Fork into two branches, two kernels apiece.
    left, right = gb.split(2)
    for branch in (left, left, right, right):
        launch(branch, LaunchConfig(grid=1, block=1), empty_kernel)

    # A join needs at least two builders.
    with pytest.raises(ValueError, match="^Must join with at least two graph builders"):
        GraphBuilder.join(left)

    gb = GraphBuilder.join(left, right)

    # Tail node after the join completes the diamond.
    launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
    graph = gb.end_building().complete()

    # Sanity: the graph uploads and launches cleanly.
    graph.upload(launch_stream)
    graph.launch(launch_stream)
    launch_stream.sync()
def test_graph_is_join_required(init_cuda):
    """Only the original (primary) builder of a split is exempt from joining."""
    mod = compile_common_kernels()
    empty_kernel = mod.get_kernel("empty_kernel")

    # The starting builder is always primary and never needs a join.
    gb = Device().create_graph_builder()
    assert gb.is_join_required is False
    gb.begin_building()

    # Root node.
    launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)

    # The first builder returned by split() is the original itself.
    first_split = gb.split(3)
    assert first_split[0] is gb

    # Only that original builder is exempt from joining.
    assert first_split[0].is_join_required is False
    for builder in first_split[1:]:
        assert builder.is_join_required is True

    # One kernel on each first-level branch.
    for builder in first_split:
        launch(builder, LaunchConfig(grid=1, block=1), empty_kernel)

    # Splitting a non-primary builder yields branches that all must join.
    second_split = first_split[-1].split(3)
    first_split = first_split[0:-1]
    for builder in second_split:
        assert builder.is_join_required is True

    # One kernel on each second-level branch.
    for builder in second_split:
        launch(builder, LaunchConfig(grid=1, block=1), empty_kernel)

    # Joining only join-required builders still requires a join; folding the
    # primary builder back in lifts that requirement.
    gb = GraphBuilder.join(*second_split)
    assert gb.is_join_required is True
    gb = GraphBuilder.join(gb, *first_split)
    assert gb.is_join_required is False

    # Final node, then finish the graph.
    launch(gb, LaunchConfig(grid=1, block=1), empty_kernel)
    gb.end_building().complete()
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_graph_repeat_capture(init_cuda):
    """A completed graph can be relaunched, but its builder cannot restart."""
    mod = compile_common_kernels()
    add_one = mod.get_kernel("add_one")

    # Launch stream plus one pinned int32 counter, zeroed.
    launch_stream = Device().create_stream()
    mr = LegacyPinnedMemoryResource()
    b = mr.allocate(4)
    arr = np.from_dlpack(b).view(np.int32)
    arr[0] = 0

    # Capture a single-increment graph from the stream.
    gb = launch_stream.create_graph_builder().begin_building()
    launch(gb, LaunchConfig(grid=1, block=1), add_one, arr.ctypes.data)
    graph = gb.end_building().complete()

    # First launch bumps the counter once.
    graph.launch(launch_stream)
    launch_stream.sync()
    assert arr[0] == 1

    # The builder cannot resume capture once building has ended.
    with pytest.raises(RuntimeError, match="^Cannot resume building after building has ended."):
        gb.begin_building()

    # The completed graph remains relaunchable.
    for _ in range(3):
        graph.launch(launch_stream)
    launch_stream.sync()
    assert arr[0] == 4

    # Close the buffer explicitly so the garbage collector cannot free it
    # while a later graph builder process is still capturing.
    b.close()
def test_graph_capture_errors(init_cuda):
    """complete() is rejected until building has actually finished."""
    gb = Device().create_graph_builder()

    # Building never started.
    with pytest.raises(RuntimeError, match="^Graph has not finished building."):
        gb.complete()

    # Building started but not ended.
    gb.begin_building()
    with pytest.raises(RuntimeError, match="^Graph has not finished building."):
        gb.complete()

    # Properly ended: completion succeeds.
    gb.end_building().complete()

cuda_core/tests/test_graph_mem.py renamed to cuda_core/tests/graph/test_capture_alloc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
"""Graph memory resource tests."""
6+
57
import pytest
68
from cuda.core import (
79
Device,

0 commit comments

Comments
 (0)