Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions cuda_core/cuda/core/_cpp/resource_handles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ decltype(&cuLibraryLoadData) p_cuLibraryLoadData = nullptr;
decltype(&cuLibraryUnload) p_cuLibraryUnload = nullptr;
decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr;

// Graph
decltype(&cuGraphDestroy) p_cuGraphDestroy = nullptr;

// Linker
decltype(&cuLinkDestroy) p_cuLinkDestroy = nullptr;

Expand Down Expand Up @@ -901,6 +904,57 @@ LibraryHandle get_kernel_library(const KernelHandle& h) noexcept {
return get_box(h)->h_library;
}

// ============================================================================
// Graph Handles
// ============================================================================

namespace {
// Control block for a GraphHandle. The aliasing shared_ptr returned by the
// create_* functions points at `resource` while this box (and its deleter,
// if any) owns the lifetime.
struct GraphBox {
    CUgraph resource;
    GraphHandle h_parent;  // Keeps parent alive for child/branch graphs
};
}  // namespace

// Wrap an externally-created CUgraph in an owning handle: when the last
// reference is dropped, cuGraphDestroy is invoked on the wrapped graph.
// The deleter releases the GIL first so driver teardown cannot deadlock
// against Python.
GraphHandle create_graph_handle(CUgraph graph) {
    const auto destroy_box = [](const GraphBox* box) {
        GILReleaseGuard gil;
        p_cuGraphDestroy(box->resource);
        delete box;
    };
    std::shared_ptr<const GraphBox> box(new GraphBox{graph, {}}, destroy_box);
    // Aliasing constructor: the handle exposes &resource but owns the box.
    return GraphHandle(box, &box->resource);
}

// Non-owning graph handle: no deleter runs cuGraphDestroy. The box merely
// pins h_parent so the parent graph outlives this child/branch reference.
GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent) {
    auto box = std::make_shared<const GraphBox>(GraphBox{graph, h_parent});
    const CUgraph* p_resource = &box->resource;
    return GraphHandle(std::move(box), p_resource);
}

namespace {
// Control block for a GraphNodeHandle: the node plus the graph that owns it.
// NOTE(review): get_box() below recovers this box from &resource via
// offsetof; GraphNodeBox holds a shared_ptr so it is not standard-layout,
// and offsetof on such a type is only conditionally supported — confirm
// this matches the compilers/pattern already relied on elsewhere in this
// file (e.g. the KernelHandle get_box).
struct GraphNodeBox {
    CUgraphNode resource;
    GraphHandle h_graph;
};
}  // namespace

// Recover the owning GraphNodeBox from a handle. The handle's stored
// pointer is &box->resource (set by the aliasing construction in
// create_graph_node_handle), so stepping back by offsetof(resource)
// yields the enclosing box.
// NOTE(review): offsetof on a non-standard-layout type (GraphNodeBox has
// a shared_ptr member) is conditionally supported — verify the supported
// toolchains accept it, as the existing get_box overloads presumably do.
static const GraphNodeBox* get_box(const GraphNodeHandle& h) {
    const CUgraphNode* p = h.get();
    return reinterpret_cast<const GraphNodeBox*>(
        reinterpret_cast<const char*>(p) - offsetof(GraphNodeBox, resource)
    );
}

// Bundle a CUgraphNode with the graph that owns it. Nodes are never
// destroyed independently; the handle only pins h_graph so the graph
// outlives every node reference. The stored pointer must be
// &box->resource — get_box() depends on it.
GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) {
    auto box = std::make_shared<const GraphNodeBox>(GraphNodeBox{node, h_graph});
    const CUgraphNode* p_resource = &box->resource;
    return GraphNodeHandle(std::move(box), p_resource);
}

// Return the graph that owns this node, or an empty handle when the node
// handle itself is empty.
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept {
    if (!h) {
        return GraphHandle{};
    }
    return get_box(h)->h_graph;
}

// ============================================================================
// Graphics Resource Handles
// ============================================================================
Expand Down
59 changes: 59 additions & 0 deletions cuda_core/cuda/core/_cpp/resource_handles.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ extern decltype(&cuLibraryLoadData) p_cuLibraryLoadData;
extern decltype(&cuLibraryUnload) p_cuLibraryUnload;
extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel;

// Graph
extern decltype(&cuGraphDestroy) p_cuGraphDestroy;

// Linker
extern decltype(&cuLinkDestroy) p_cuLinkDestroy;

Expand Down Expand Up @@ -143,6 +146,8 @@ using EventHandle = std::shared_ptr<const CUevent>;
using MemoryPoolHandle = std::shared_ptr<const CUmemoryPool>;
using LibraryHandle = std::shared_ptr<const CUlibrary>;
using KernelHandle = std::shared_ptr<const CUkernel>;
using GraphHandle = std::shared_ptr<const CUgraph>;
using GraphNodeHandle = std::shared_ptr<const CUgraphNode>;
using GraphicsResourceHandle = std::shared_ptr<const CUgraphicsResource>;
using NvrtcProgramHandle = std::shared_ptr<const nvrtcProgram>;
using NvvmProgramHandle = std::shared_ptr<const NvvmProgramValue>;
Expand Down Expand Up @@ -371,6 +376,33 @@ KernelHandle create_kernel_handle_ref(CUkernel kernel);
// Returns empty handle if the kernel has no library dependency.
LibraryHandle get_kernel_library(const KernelHandle& h) noexcept;

// ============================================================================
// Graph handle functions
// ============================================================================

// Wrap an externally-created CUgraph with RAII cleanup.
// When the last reference is released, cuGraphDestroy is called automatically.
// The caller must have already created the graph via cuGraphCreate.
GraphHandle create_graph_handle(CUgraph graph);

// Create a non-owning graph handle that keeps h_parent alive.
// Use for graphs owned by a child/conditional node in a parent graph.
// The child graph will NOT be destroyed when this handle is released,
// but h_parent will be prevented from destruction while this handle exists.
GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent);

// ============================================================================
// Graph node handle functions
// ============================================================================

// Create a node handle. Nodes are owned by their parent graph (not
// independently destroyable). The GraphHandle dependency ensures the
// graph outlives any node reference.
GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph);

// Extract the owning graph handle from a node handle.
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept;

// ============================================================================
// Graphics resource handle functions
// ============================================================================
Expand Down Expand Up @@ -467,6 +499,14 @@ inline CUkernel as_cu(const KernelHandle& h) noexcept {
return h ? *h : nullptr;
}

// Extract the raw CUgraph, or nullptr for an empty handle.
inline CUgraph as_cu(const GraphHandle& h) noexcept {
    if (!h) {
        return nullptr;
    }
    return *h;
}

// Extract the raw CUgraphNode, or nullptr for an empty handle.
inline CUgraphNode as_cu(const GraphNodeHandle& h) noexcept {
    if (!h) {
        return nullptr;
    }
    return *h;
}

// Extract the raw CUgraphicsResource, or nullptr for an empty handle.
inline CUgraphicsResource as_cu(const GraphicsResourceHandle& h) noexcept {
    return h ? *h : nullptr;
}
Expand Down Expand Up @@ -517,6 +557,14 @@ inline std::intptr_t as_intptr(const KernelHandle& h) noexcept {
return reinterpret_cast<std::intptr_t>(as_cu(h));
}

inline std::intptr_t as_intptr(const GraphHandle& h) noexcept {
return reinterpret_cast<std::intptr_t>(as_cu(h));
}

inline std::intptr_t as_intptr(const GraphNodeHandle& h) noexcept {
return reinterpret_cast<std::intptr_t>(as_cu(h));
}

// Expose the raw graphics-resource pointer as an integer (for Python interop).
inline std::intptr_t as_intptr(const GraphicsResourceHandle& h) noexcept {
    return reinterpret_cast<std::intptr_t>(as_cu(h));
}
Expand Down Expand Up @@ -595,6 +643,17 @@ inline PyObject* as_py(const KernelHandle& h) noexcept {
return detail::make_py("cuda.bindings.driver", "CUkernel", as_intptr(h));
}

// Convert to a Python cuda.bindings.driver.CUgraph object.
// NOTE(review): unlike the GraphNodeHandle overload, an empty handle here
// produces CUgraph(0) rather than Python None — confirm callers expect
// this asymmetry.
inline PyObject* as_py(const GraphHandle& h) noexcept {
    return detail::make_py("cuda.bindings.driver", "CUgraph", as_intptr(h));
}

// Convert to a Python cuda.bindings.driver.CUgraphNode object; an empty
// node handle maps to Python None instead of CUgraphNode(0).
inline PyObject* as_py(const GraphNodeHandle& h) noexcept {
    const std::intptr_t raw = as_intptr(h);
    if (raw == 0) {
        Py_RETURN_NONE;
    }
    return detail::make_py("cuda.bindings.driver", "CUgraphNode", raw);
}

// Convert to a Python cuda.bindings.nvrtc.nvrtcProgram object.
inline PyObject* as_py(const NvrtcProgramHandle& h) noexcept {
    return detail::make_py("cuda.bindings.nvrtc", "nvrtcProgram", as_intptr(h));
}
Expand Down
159 changes: 77 additions & 82 deletions cuda_core/cuda/core/_graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,43 @@ class GraphDebugPrintOptions:
extra_topo_info: bool = False
conditional_node_params: bool = False

def _to_flags(self) -> int:
    """Convert options to CUDA driver API flags (internal use).

    Each boolean option maps to the CUgraphDebugDot_flags member whose
    name is ``CU_GRAPH_DEBUG_DOT_FLAGS_`` + the option name upper-cased.
    """
    option_names = (
        "verbose",
        "runtime_types",
        "kernel_node_params",
        "memcpy_node_params",
        "memset_node_params",
        "host_node_params",
        "event_node_params",
        "ext_semas_signal_node_params",
        "ext_semas_wait_node_params",
        "kernel_node_attributes",
        "handles",
        "mem_alloc_node_params",
        "mem_free_node_params",
        "batch_mem_op_node_params",
        "extra_topo_info",
        "conditional_node_params",
    )
    flag_enum = driver.CUgraphDebugDot_flags
    flags = 0
    for name in option_names:
        if getattr(self, name):
            flags |= getattr(flag_enum, f"CU_GRAPH_DEBUG_DOT_FLAGS_{name.upper()}")
    return flags


@dataclass
class GraphCompleteOptions:
Expand Down Expand Up @@ -118,6 +155,44 @@ class GraphCompleteOptions:
use_node_priority: bool = False


def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph:
    """Instantiate a driver-level CUgraph into an executable :class:`Graph`.

    Args:
        h_graph: Driver CUgraph handle to instantiate.
        options: Optional instantiation options; translated into
            CUgraphInstantiate_flags (and an upload stream) on the params
            struct. When None, default params are used.

    Returns:
        Graph: The instantiated graph.

    Raises:
        RuntimeError: If ``params.result_out`` reports anything other than
            success after instantiation.
    """
    params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
    if options:
        flags = 0
        if options.auto_free_on_launch:
            flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
        if options.upload_stream:
            # The UPLOAD flag is only meaningful together with the stream to
            # upload on, so both are set as a pair.
            flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
            params.hUploadStream = options.upload_stream.handle
        if options.device_launch:
            flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
        if options.use_node_priority:
            flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
        params.flags = flags

    graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)))
    # Translate the driver's result_out code into a descriptive error.
    if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR:
        # NOTE: handle_return above should normally have raised for this case
        # already, so reaching here is unexpected.
        raise RuntimeError(
            "Instantiation failed for an unexpected reason which is described in the return value of the function."
        )
    elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
        raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
    elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
        raise RuntimeError(
            "Instantiation for device launch failed because the graph contained an unsupported operation."
        )
    elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
        raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
    elif (
        # CONDITIONAL_HANDLE_UNUSED only exists in cuda-bindings >= 12.8, so
        # the version check must come first to avoid an AttributeError.
        _py_major_minor >= (12, 8)
        and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
    ):
        raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
    elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
        raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")
    return graph


class GraphBuilder:
"""Represents a graph under construction.

Expand Down Expand Up @@ -280,53 +355,7 @@ def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
if not self._building_ended:
raise RuntimeError("Graph has not finished building.")

if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
flags = 0
if options:
if options.auto_free_on_launch:
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
if options.use_node_priority:
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
return Graph._init(handle_return(driver.cuGraphInstantiateWithFlags(self._mnff.graph, flags)))

params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
if options:
flags = 0
if options.auto_free_on_launch:
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
if options.upload_stream:
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
params.hUploadStream = options.upload_stream.handle
if options.device_launch:
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
if options.use_node_priority:
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
params.flags = flags

graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(self._mnff.graph, params)))
if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR:
# NOTE: Should never get here since the handle_return should have caught this case
raise RuntimeError(
"Instantiation failed for an unexpected reason which is described in the return value of the function."
)
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
raise RuntimeError(
"Instantiation for device launch failed because the graph contained an unsupported operation."
)
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
raise RuntimeError(
"Instantiation for device launch failed due to the nodes belonging to different contexts."
)
elif (
_py_major_minor >= (12, 8)
and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
):
raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")
return graph
return _instantiate_graph(self._mnff.graph, options)

def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
"""Generates a DOT debug file for the graph builder.
Expand All @@ -341,41 +370,7 @@ def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
"""
if not self._building_ended:
raise RuntimeError("Graph has not finished building.")
flags = 0
if options:
if options.verbose:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE
if options.runtime_types:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES
if options.kernel_node_params:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS
if options.memcpy_node_params:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS
if options.memset_node_params:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS
if options.host_node_params:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS
if options.event_node_params:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS
if options.ext_semas_signal_node_params:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS
if options.ext_semas_wait_node_params:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS
if options.kernel_node_attributes:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES
if options.handles:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES
if options.mem_alloc_node_params:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS
if options.mem_free_node_params:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS
if options.batch_mem_op_node_params:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS
if options.extra_topo_info:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO
if options.conditional_node_params:
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS

flags = options._to_flags() if options else 0
handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags))

def split(self, count: int) -> tuple[GraphBuilder, ...]:
Expand Down
Loading
Loading