From af1b9087379ba4438f2e8fd2bc1e139bdf560f97 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Mon, 2 Feb 2026 15:15:40 -0800 Subject: [PATCH 1/7] Begin Cythonization of _program.py - Rename _program.py to _program.pyx - Convert Program to cdef class with _program.pxd declarations - Extract _MembersNeededForFinalize to module-level _ProgramMNFF (nested classes not allowed in cdef class) - Add __repr__ method to Program - Keep ProgramOptions as @dataclass (unchanged) - Keep weakref.finalize pattern for handle cleanup --- cuda_core/cuda/core/_program.pxd | 12 +++++ .../cuda/core/{_program.py => _program.pyx} | 46 ++++++++++--------- 2 files changed, 37 insertions(+), 21 deletions(-) create mode 100644 cuda_core/cuda/core/_program.pxd rename cuda_core/cuda/core/{_program.py => _program.pyx} (97%) diff --git a/cuda_core/cuda/core/_program.pxd b/cuda_core/cuda/core/_program.pxd new file mode 100644 index 0000000000..7dc89cf87d --- /dev/null +++ b/cuda_core/cuda/core/_program.pxd @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + + +cdef class Program: + cdef: + object _mnff + str _backend + object _linker # Linker (not yet Cythonized) + object _options # ProgramOptions + object __weakref__ diff --git a/cuda_core/cuda/core/_program.py b/cuda_core/cuda/core/_program.pyx similarity index 97% rename from cuda_core/cuda/core/_program.py rename to cuda_core/cuda/core/_program.pyx index 1ef1aa51f5..45e5441cac 100644 --- a/cuda_core/cuda/core/_program.py +++ b/cuda_core/cuda/core/_program.pyx @@ -631,7 +631,27 @@ def __repr__(self): ProgramHandleT = Union["cuda.bindings.nvrtc.nvrtcProgram", LinkerHandleT] -class Program: +class _ProgramMNFF: + """Members needed for postrm release of program handles.""" + + __slots__ = "handle", "backend" + + def __init__(self, program_obj, handle, backend): + self.handle = handle + self.backend = backend + weakref.finalize(program_obj, self.close) + + def close(self): + if self.handle is not None: + if self.backend == "NVRTC": + handle_return(nvrtc.nvrtcDestroyProgram(self.handle)) + elif self.backend == "NVVM": + nvvm = _get_nvvm_module() + nvvm.destroy_program(self.handle) + self.handle = None + + +cdef class Program: """Represent a compilation machinery to process programs into :obj:`~_module.ObjectCode`. @@ -650,27 +670,8 @@ class Program: See :obj:`ProgramOptions` for more information. """ - class _MembersNeededForFinalize: - __slots__ = "handle", "backend" - - def __init__(self, program_obj, handle, backend): - self.handle = handle - self.backend = backend - weakref.finalize(program_obj, self.close) - - def close(self): - if self.handle is not None: - if self.backend == "NVRTC": - handle_return(nvrtc.nvrtcDestroyProgram(self.handle)) - elif self.backend == "NVVM": - nvvm = _get_nvvm_module() - nvvm.destroy_program(self.handle) - self.handle = None - - __slots__ = ("__weakref__", "_mnff", "_backend", "_linker", "_options") - def __init__(self, code, code_type, options: ProgramOptions = None): - self._mnff = Program._MembersNeededForFinalize(self, None, None) + self._mnff = _ProgramMNFF(self, None, None) self._options = options = check_or_create_options(ProgramOptions, options, "Program options") code_type = code_type.lower() @@ -858,3 +859,6 @@ def handle(self) -> ProgramHandleT: handle, call ``int(Program.handle)``. """ return self._mnff.handle + + def __repr__(self) -> str: + return f"" From 08a02c046ea18332faf7ea49822443c93ac9e085 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Mon, 2 Feb 2026 15:23:38 -0800 Subject: [PATCH 2/7] Extract Program helpers to module-level cdef functions - Move _translate_program_options to Program_translate_options (cdef) - Move _can_load_generated_ptx to Program_can_load_generated_ptx (cdef) - Remove unused TYPE_CHECKING import block - Follow _memory/_buffer.pyx helper function patterns --- cuda_core/cuda/core/_program.pyx | 65 +++++++++++++++++--------------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/cuda_core/cuda/core/_program.pyx b/cuda_core/cuda/core/_program.pyx index 45e5441cac..4d2eccd893 100644 --- a/cuda_core/cuda/core/_program.pyx +++ b/cuda_core/cuda/core/_program.pyx @@ -7,12 +7,9 @@ from __future__ import annotations import weakref from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Union +from typing import Union from warnings import warn -if TYPE_CHECKING: - import cuda.bindings - from cuda.core._device import Device from cuda.core._linker import Linker, LinkerHandleT, LinkerOptions from cuda.core._module import ObjectCode @@ -689,7 +686,7 @@ cdef class Program: elif code_type == "ptx": assert_type(code, str) self._linker = Linker( - ObjectCode._init(code.encode(), code_type), options=self._translate_program_options(options) + ObjectCode._init(code.encode(), code_type), options=Program_translate_options(options) ) self._backend = self._linker.backend @@ -711,36 +708,12 @@ cdef class Program: assert code_type not in supported_code_types, f"{code_type=}" raise RuntimeError(f"Unsupported {code_type=} ({supported_code_types=})") - def _translate_program_options(self, options: ProgramOptions) -> LinkerOptions: - return LinkerOptions( - name=options.name, - arch=options.arch, - max_register_count=options.max_register_count, - time=options.time, - link_time_optimization=options.link_time_optimization, - debug=options.debug, - lineinfo=options.lineinfo, - ftz=options.ftz, - prec_div=options.prec_div, - prec_sqrt=options.prec_sqrt, - fma=options.fma, - split_compile=options.split_compile, - ptxas_options=options.ptxas_options, - no_cache=options.no_cache, - ) - def close(self): """Destroy this program.""" if self._linker: self._linker.close() self._mnff.close() - @staticmethod - def _can_load_generated_ptx(): - driver_ver = handle_return(driver.cuDriverGetVersion()) - nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion()) - return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver - def compile(self, target_type, name_expressions=(), logs=None): """Compile the program with a specific compilation type. @@ -768,7 +741,7 @@ cdef class Program: raise ValueError(f'Unsupported target_type="{target_type}" ({supported_target_types=})') if self._backend == "NVRTC": - if target_type == "ptx" and not self._can_load_generated_ptx(): + if target_type == "ptx" and not Program_can_load_generated_ptx(): warn( "The CUDA driver version is older than the backend version. " "The generated ptx will not be loadable by the current driver.", @@ -862,3 +835,35 @@ cdef class Program: def __repr__(self) -> str: return f"" + + +# ============================================================================= +# Helper functions +# ============================================================================= + + +cdef bint Program_can_load_generated_ptx(): + """Check if the driver can load PTX generated by the current NVRTC version.""" + driver_ver = handle_return(driver.cuDriverGetVersion()) + nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion()) + return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver + + +cdef object Program_translate_options(object options): + """Translate ProgramOptions to LinkerOptions for PTX compilation.""" + return LinkerOptions( + name=options.name, + arch=options.arch, + max_register_count=options.max_register_count, + time=options.time, + link_time_optimization=options.link_time_optimization, + debug=options.debug, + lineinfo=options.lineinfo, + ftz=options.ftz, + prec_div=options.prec_div, + prec_sqrt=options.prec_sqrt, + fma=options.fma, + split_compile=options.split_compile, + ptxas_options=options.ptxas_options, + no_cache=options.no_cache, + ) From a136180961861de72bcdd576325a30c719206acd Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Mon, 2 Feb 2026 16:54:21 -0800 Subject: [PATCH 3/7] Complete Cythonization of _program.py - Reorganize file structure per developer guide (principal class first) - Add module docstring, __all__, type alias section - Factor long methods into cdef inline helpers - Add proper exception specs to cdef functions - Fix docstrings (use :class: refs, public paths) - Add type annotations to public methods - Inline _nvvm_exception_manager (single use) - Remove Union import, use | syntax - Add public Program.driver_can_load_nvrtc_ptx_output() API - Update tests to use new public API Closes #1082 --- cuda_core/cuda/core/_program.pxd | 2 +- cuda_core/cuda/core/_program.pyx | 1037 ++++++++++++++++-------------- cuda_core/tests/test_module.py | 4 +- 3 files changed, 544 insertions(+), 499 deletions(-) diff --git a/cuda_core/cuda/core/_program.pxd b/cuda_core/cuda/core/_program.pxd index 7dc89cf87d..444257f1e4 100644 --- a/cuda_core/cuda/core/_program.pxd +++ b/cuda_core/cuda/core/_program.pxd @@ -7,6 +7,6 @@ cdef class Program: cdef: object _mnff str _backend - object _linker # Linker (not yet Cythonized) + object _linker # Linker object _options # ProgramOptions object __weakref__ diff --git a/cuda_core/cuda/core/_program.pyx b/cuda_core/cuda/core/_program.pyx index 4d2eccd893..79a3cd4f7f 100644 --- a/cuda_core/cuda/core/_program.pyx +++ b/cuda_core/cuda/core/_program.pyx @@ -1,15 +1,19 @@ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 +"""Compilation machinery for CUDA programs. + +This module provides :class:`Program` for compiling source code into +:class:`~cuda.core.ObjectCode`, with :class:`ProgramOptions` for configuration. +""" from __future__ import annotations import weakref -from contextlib import contextmanager from dataclasses import dataclass -from typing import Union from warnings import warn +from cuda.bindings import driver, nvrtc from cuda.core._device import Device from cuda.core._linker import Linker, LinkerHandleT, LinkerOptions from cuda.core._module import ObjectCode @@ -18,115 +22,127 @@ from cuda.core._utils.cuda_utils import ( CUDAError, _handle_boolean_option, check_or_create_options, - driver, get_binding_version, handle_return, is_nested_sequence, is_sequence, - nvrtc, ) +__all__ = ["Program", "ProgramOptions"] -@contextmanager -def _nvvm_exception_manager(self): - """ - Taken from _linker.py - """ - try: - yield - except Exception as e: - error_log = "" - if hasattr(self, "_mnff"): - try: - nvvm = _get_nvvm_module() - logsize = nvvm.get_program_log_size(self._mnff.handle) - if logsize > 1: - log = bytearray(logsize) - nvvm.get_program_log(self._mnff.handle, log) - error_log = log.decode("utf-8", errors="backslashreplace") - except Exception: - error_log = "" - # Starting Python 3.11 we could also use Exception.add_note() for the same purpose, but - # unfortunately we are still supporting Python 3.10... - e.args = (e.args[0] + (f"\nNVVM program log: {error_log}" if error_log else ""), *e.args[1:]) - raise e +ProgramHandleT = nvrtc.nvrtcProgram | LinkerHandleT +"""Type alias for program handle types across different backends.""" -_nvvm_module = None -_nvvm_import_attempted = False +# ============================================================================= +# Principal Class +# ============================================================================= -def _get_nvvm_module(): - """ - Handles the import of NVVM module with version and availability checks. - NVVM bindings were added in cuda-bindings 12.9.0, so we need to handle cases where: - 1. cuda.bindings is not new enough (< 12.9.0) - 2. libnvvm is not found in the Python environment +cdef class Program: + """Represent a compilation machinery to process programs into + :class:`~cuda.core.ObjectCode`. - Returns: - The nvvm module if available and working + This object provides a unified interface to multiple underlying + compiler libraries. Compilation support is enabled for a wide + range of code types and compilation types. - Raises: - RuntimeError: If NVVM is not available due to version or library issues + Parameters + ---------- + code : str | bytes | bytearray + The source code to compile. For C++ and PTX, must be a string. + For NVVM IR, can be str, bytes, or bytearray. + code_type : str + The type of source code. Must be one of ``"c++"``, ``"ptx"``, or ``"nvvm"``. + options : :class:`ProgramOptions`, optional + Options to customize the compilation process. """ - global _nvvm_module, _nvvm_import_attempted - if _nvvm_import_attempted: - if _nvvm_module is None: - raise RuntimeError("NVVM module is not available (previous import attempt failed)") - return _nvvm_module + def __init__(self, code: str | bytes | bytearray, code_type: str, options: ProgramOptions | None = None): + Program_init(self, code, code_type, options) - _nvvm_import_attempted = True + def close(self): + """Destroy this program.""" + if self._linker: + self._linker.close() + self._mnff.close() - try: - version = get_binding_version() - if version < (12, 9): - raise RuntimeError( - f"NVVM bindings require cuda-bindings >= 12.9.0, but found {version[0]}.{version[1]}.x. " - "Please update cuda-bindings to use NVVM features." - ) + def compile( + self, target_type: str, name_expressions: tuple | list = (), logs = None + ) -> ObjectCode: + """Compile the program to the specified target type. - from cuda.bindings import nvvm - from cuda.bindings._internal.nvvm import _inspect_function_pointer + Parameters + ---------- + target_type : str + The compilation target. Must be one of ``"ptx"``, ``"cubin"``, or ``"ltoir"``. + name_expressions : tuple | list, optional + Sequence of name expressions to make accessible in the compiled code. + Used for template instantiation and similar cases. + logs : object, optional + Object with a ``write`` method to receive compilation logs. - if _inspect_function_pointer("__nvvmCreateProgram") == 0: - raise RuntimeError("NVVM library (libnvvm) is not available in this Python environment. ") + Returns + ------- + :class:`~cuda.core.ObjectCode` + The compiled object code. + """ + return Program_compile(self, target_type, name_expressions, logs) - _nvvm_module = nvvm - return _nvvm_module + @property + def backend(self) -> str: + """Return this Program instance's underlying backend.""" + return self._backend - except RuntimeError as e: - _nvvm_module = None - raise e + @property + def handle(self) -> ProgramHandleT: + """Return the underlying handle object. + .. note:: -def _process_define_macro_inner(formatted_options, macro): - if isinstance(macro, str): - formatted_options.append(f"--define-macro={macro}") - return True - if isinstance(macro, tuple): - if len(macro) != 2 or any(not isinstance(val, str) for val in macro): - raise RuntimeError(f"Expected define_macro tuple[str, str], got {macro}") - formatted_options.append(f"--define-macro={macro[0]}={macro[1]}") - return True - return False + The type of the returned object depends on the backend. + .. caution:: -def _process_define_macro(formatted_options, macro): - union_type = "Union[str, tuple[str, str]]" - if _process_define_macro_inner(formatted_options, macro): - return - if is_nested_sequence(macro): - for seq_macro in macro: - if not _process_define_macro_inner(formatted_options, seq_macro): - raise RuntimeError(f"Expected define_macro {union_type}, got {seq_macro}") - return - raise RuntimeError(f"Expected define_macro {union_type}, list[{union_type}], got {macro}") + This handle is a Python object. To get the memory address of the underlying C + handle, call ``int(Program.handle)``. + """ + return self._mnff.handle + + @staticmethod + def driver_can_load_nvrtc_ptx_output() -> bool: + """Check if the CUDA driver can load PTX generated by NVRTC. + + NVRTC generates PTX targeting a specific CUDA version. If the installed + driver is older than the NVRTC version, it may not be able to load the + generated PTX. + + Returns + ------- + bool + True if the driver version is new enough to load PTX generated + by the current NVRTC version, False otherwise. + + Examples + -------- + >>> if Program.driver_can_load_nvrtc_ptx_output(): + ... obj = program.compile("ptx") + ... kernel = obj.get_kernel("my_kernel") + """ + return _can_load_generated_ptx() + + def __repr__(self) -> str: + return f"" + + +# ============================================================================= +# Other Public Classes +# ============================================================================= @dataclass class ProgramOptions: - """Customizable options for configuring `Program`. + """Customizable options for configuring :class:`Program`. Attributes ---------- @@ -149,7 +165,7 @@ class ProgramOptions: Generate line-number information. Default: False device_code_optimize : bool, optional - Enable device code optimization. When specified along with ā€˜-G’, enables limited debug information generation + Enable device code optimization. When specified along with '-G', enables limited debug information generation for optimized device code. Default: None ptxas_options : Union[str, list[str]], optional @@ -351,238 +367,10 @@ class ProgramOptions: self.arch = f"sm_{Device().arch}" def _prepare_nvrtc_options(self) -> list[bytes]: - # Build NVRTC-specific options - options = [f"-arch={self.arch}"] - if self.relocatable_device_code is not None: - options.append(f"--relocatable-device-code={_handle_boolean_option(self.relocatable_device_code)}") - if self.extensible_whole_program is not None and self.extensible_whole_program: - options.append("--extensible-whole-program") - if self.debug is not None and self.debug: - options.append("--device-debug") - if self.lineinfo is not None and self.lineinfo: - options.append("--generate-line-info") - if self.device_code_optimize is not None and self.device_code_optimize: - options.append("--dopt=on") - if self.ptxas_options is not None: - opt_name = "--ptxas-options" - if isinstance(self.ptxas_options, str): - options.append(f"{opt_name}={self.ptxas_options}") - elif is_sequence(self.ptxas_options): - for opt_value in self.ptxas_options: - options.append(f"{opt_name}={opt_value}") - if self.max_register_count is not None: - options.append(f"--maxrregcount={self.max_register_count}") - if self.ftz is not None: - options.append(f"--ftz={_handle_boolean_option(self.ftz)}") - if self.prec_sqrt is not None: - options.append(f"--prec-sqrt={_handle_boolean_option(self.prec_sqrt)}") - if self.prec_div is not None: - options.append(f"--prec-div={_handle_boolean_option(self.prec_div)}") - if self.fma is not None: - options.append(f"--fmad={_handle_boolean_option(self.fma)}") - if self.use_fast_math is not None and self.use_fast_math: - options.append("--use_fast_math") - if self.extra_device_vectorization is not None and self.extra_device_vectorization: - options.append("--extra-device-vectorization") - if self.link_time_optimization is not None and self.link_time_optimization: - options.append("--dlink-time-opt") - if self.gen_opt_lto is not None and self.gen_opt_lto: - options.append("--gen-opt-lto") - if self.define_macro is not None: - _process_define_macro(options, self.define_macro) - if self.undefine_macro is not None: - if isinstance(self.undefine_macro, str): - options.append(f"--undefine-macro={self.undefine_macro}") - elif is_sequence(self.undefine_macro): - for macro in self.undefine_macro: - options.append(f"--undefine-macro={macro}") - if self.include_path is not None: - if isinstance(self.include_path, str): - options.append(f"--include-path={self.include_path}") - elif is_sequence(self.include_path): - for path in self.include_path: - options.append(f"--include-path={path}") - if self.pre_include is not None: - if isinstance(self.pre_include, str): - options.append(f"--pre-include={self.pre_include}") - elif is_sequence(self.pre_include): - for header in self.pre_include: - options.append(f"--pre-include={header}") - if self.no_source_include is not None and self.no_source_include: - options.append("--no-source-include") - if self.std is not None: - options.append(f"--std={self.std}") - if self.builtin_move_forward is not None: - options.append(f"--builtin-move-forward={_handle_boolean_option(self.builtin_move_forward)}") - if self.builtin_initializer_list is not None: - options.append(f"--builtin-initializer-list={_handle_boolean_option(self.builtin_initializer_list)}") - if self.disable_warnings is not None and self.disable_warnings: - options.append("--disable-warnings") - if self.restrict is not None and self.restrict: - options.append("--restrict") - if self.device_as_default_execution_space is not None and self.device_as_default_execution_space: - options.append("--device-as-default-execution-space") - if self.device_int128 is not None and self.device_int128: - options.append("--device-int128") - if self.device_float128 is not None and self.device_float128: - options.append("--device-float128") - if self.optimization_info is not None: - options.append(f"--optimization-info={self.optimization_info}") - if self.no_display_error_number is not None and self.no_display_error_number: - options.append("--no-display-error-number") - if self.diag_error is not None: - if isinstance(self.diag_error, int): - options.append(f"--diag-error={self.diag_error}") - elif is_sequence(self.diag_error): - for error in self.diag_error: - options.append(f"--diag-error={error}") - if self.diag_suppress is not None: - if isinstance(self.diag_suppress, int): - options.append(f"--diag-suppress={self.diag_suppress}") - elif is_sequence(self.diag_suppress): - for suppress in self.diag_suppress: - options.append(f"--diag-suppress={suppress}") - if self.diag_warn is not None: - if isinstance(self.diag_warn, int): - options.append(f"--diag-warn={self.diag_warn}") - elif is_sequence(self.diag_warn): - for warn in self.diag_warn: - options.append(f"--diag-warn={warn}") - if self.brief_diagnostics is not None: - options.append(f"--brief-diagnostics={_handle_boolean_option(self.brief_diagnostics)}") - if self.time is not None: - options.append(f"--time={self.time}") - if self.split_compile is not None: - options.append(f"--split-compile={self.split_compile}") - if self.fdevice_syntax_only is not None and self.fdevice_syntax_only: - options.append("--fdevice-syntax-only") - if self.minimal is not None and self.minimal: - options.append("--minimal") - if self.no_cache is not None and self.no_cache: - options.append("--no-cache") - if self.fdevice_time_trace is not None: - options.append(f"--fdevice-time-trace={self.fdevice_time_trace}") - if self.frandom_seed is not None: - options.append(f"--frandom-seed={self.frandom_seed}") - if self.ofast_compile is not None: - options.append(f"--Ofast-compile={self.ofast_compile}") - # PCH options (CUDA 12.8+) - if self.pch is not None and self.pch: - options.append("--pch") - if self.create_pch is not None: - options.append(f"--create-pch={self.create_pch}") - if self.use_pch is not None: - options.append(f"--use-pch={self.use_pch}") - if self.pch_dir is not None: - options.append(f"--pch-dir={self.pch_dir}") - if self.pch_verbose is not None: - options.append(f"--pch-verbose={_handle_boolean_option(self.pch_verbose)}") - if self.pch_messages is not None: - options.append(f"--pch-messages={_handle_boolean_option(self.pch_messages)}") - if self.instantiate_templates_in_pch is not None: - options.append( - f"--instantiate-templates-in-pch={_handle_boolean_option(self.instantiate_templates_in_pch)}" - ) - if self.numba_debug: - options.append("--numba-debug") - return [o.encode() for o in options] + return _prepare_nvrtc_options_impl(self) def _prepare_nvvm_options(self, as_bytes: bool = True) -> list[bytes] | list[str]: - options = [] - - # Options supported by NVVM - assert self.arch is not None - arch = self.arch - if arch.startswith("sm_"): - arch = f"compute_{arch[3:]}" - options.append(f"-arch={arch}") - if self.debug is not None and self.debug: - options.append("-g") - if self.device_code_optimize is False: - options.append("-opt=0") - elif self.device_code_optimize is True: - options.append("-opt=3") - # NVVM uses 0/1 instead of true/false for boolean options - if self.ftz is not None: - options.append(f"-ftz={'1' if self.ftz else '0'}") - if self.prec_sqrt is not None: - options.append(f"-prec-sqrt={'1' if self.prec_sqrt else '0'}") - if self.prec_div is not None: - options.append(f"-prec-div={'1' if self.prec_div else '0'}") - if self.fma is not None: - options.append(f"-fma={'1' if self.fma else '0'}") - - # Check for unsupported options and raise error if they are set - unsupported = [] - if self.relocatable_device_code is not None: - unsupported.append("relocatable_device_code") - if self.extensible_whole_program is not None and self.extensible_whole_program: - unsupported.append("extensible_whole_program") - if self.lineinfo is not None and self.lineinfo: - unsupported.append("lineinfo") - if self.ptxas_options is not None: - unsupported.append("ptxas_options") - if self.max_register_count is not None: - unsupported.append("max_register_count") - if self.use_fast_math is not None and self.use_fast_math: - unsupported.append("use_fast_math") - if self.extra_device_vectorization is not None and self.extra_device_vectorization: - unsupported.append("extra_device_vectorization") - if self.gen_opt_lto is not None and self.gen_opt_lto: - unsupported.append("gen_opt_lto") - if self.define_macro is not None: - unsupported.append("define_macro") - if self.undefine_macro is not None: - unsupported.append("undefine_macro") - if self.include_path is not None: - unsupported.append("include_path") - if self.pre_include is not None: - unsupported.append("pre_include") - if self.no_source_include is not None and self.no_source_include: - unsupported.append("no_source_include") - if self.std is not None: - unsupported.append("std") - if self.builtin_move_forward is not None: - unsupported.append("builtin_move_forward") - if self.builtin_initializer_list is not None: - unsupported.append("builtin_initializer_list") - if self.disable_warnings is not None and self.disable_warnings: - unsupported.append("disable_warnings") - if self.restrict is not None and self.restrict: - unsupported.append("restrict") - if self.device_as_default_execution_space is not None and self.device_as_default_execution_space: - unsupported.append("device_as_default_execution_space") - if self.device_int128 is not None and self.device_int128: - unsupported.append("device_int128") - if self.optimization_info is not None: - unsupported.append("optimization_info") - if self.no_display_error_number is not None and self.no_display_error_number: - unsupported.append("no_display_error_number") - if self.diag_error is not None: - unsupported.append("diag_error") - if self.diag_suppress is not None: - unsupported.append("diag_suppress") - if self.diag_warn is not None: - unsupported.append("diag_warn") - if self.brief_diagnostics is not None: - unsupported.append("brief_diagnostics") - if self.time is not None: - unsupported.append("time") - if self.split_compile is not None: - unsupported.append("split_compile") - if self.fdevice_syntax_only is not None and self.fdevice_syntax_only: - unsupported.append("fdevice_syntax_only") - if self.minimal is not None and self.minimal: - unsupported.append("minimal") - if self.numba_debug is not None and self.numba_debug: - unsupported.append("numba_debug") - if unsupported: - raise CUDAError(f"The following options are not supported by NVVM backend: {', '.join(unsupported)}") - - if as_bytes: - return [o.encode() for o in options] - else: - return options + return _prepare_nvvm_options_impl(self, as_bytes) def as_bytes(self, backend: str) -> list[bytes]: """Convert program options to bytes format for the specified backend. @@ -625,7 +413,13 @@ class ProgramOptions: return f"ProgramOptions(name={self.name!r}, arch={self.arch!r})" -ProgramHandleT = Union["cuda.bindings.nvrtc.nvrtcProgram", LinkerHandleT] +# ============================================================================= +# Private Classes and Helper Functions +# ============================================================================= + +# Module-level state for NVVM lazy loading +cdef object_nvvm_module = None +cdef bint _nvvm_import_attempted = False class _ProgramMNFF: @@ -648,208 +442,73 @@ class _ProgramMNFF: self.handle = None -cdef class Program: - """Represent a compilation machinery to process programs into - :obj:`~_module.ObjectCode`. - - This object provides a unified interface to multiple underlying - compiler libraries. Compilation support is enabled for a wide - range of code types and compilation types. - - Parameters - ---------- - code : Any - String of the CUDA Runtime Compilation program. - code_type : Any - String of the code type. Currently ``"ptx"``, ``"c++"``, and ``"nvvm"`` are supported. - options : ProgramOptions, optional - A ProgramOptions object to customize the compilation process. - See :obj:`ProgramOptions` for more information. - """ - - def __init__(self, code, code_type, options: ProgramOptions = None): - self._mnff = _ProgramMNFF(self, None, None) - - self._options = options = check_or_create_options(ProgramOptions, options, "Program options") - code_type = code_type.lower() - - if code_type == "c++": - assert_type(code, str) - # TODO: support pre-loaded headers & include names - # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved - - self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], [])) - self._mnff.backend = "NVRTC" - self._backend = "NVRTC" - self._linker = None - - elif code_type == "ptx": - assert_type(code, str) - self._linker = Linker( - ObjectCode._init(code.encode(), code_type), options=Program_translate_options(options) - ) - self._backend = self._linker.backend - - elif code_type == "nvvm": - if isinstance(code, str): - code = code.encode("utf-8") - elif not isinstance(code, (bytes, bytearray)): - raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray") - - nvvm = _get_nvvm_module() - self._mnff.handle = nvvm.create_program() - self._mnff.backend = "NVVM" - nvvm.add_module_to_program(self._mnff.handle, code, len(code), options._name.decode()) - self._backend = "NVVM" - self._linker = None - - else: - supported_code_types = ("c++", "ptx", "nvvm") - assert code_type not in supported_code_types, f"{code_type=}" - raise RuntimeError(f"Unsupported {code_type=} ({supported_code_types=})") - - def close(self): - """Destroy this program.""" - if self._linker: - self._linker.close() - self._mnff.close() - - def compile(self, target_type, name_expressions=(), logs=None): - """Compile the program with a specific compilation type. +def _get_nvvm_module(): + """Get the NVVM module, importing it lazily with availability checks.""" + global _nvvm_module, _nvvm_import_attempted - Parameters - ---------- - target_type : Any - String of the targeted compilation type. - Supported options are "ptx", "cubin" and "ltoir". - name_expressions : Union[list, tuple], optional - List of explicit name expressions to become accessible. - (Default to no expressions) - logs : Any, optional - Object with a write method to receive the logs generated - from compilation. - (Default to no logs) + if _nvvm_import_attempted: + if _nvvm_module is None: + raise RuntimeError("NVVM module is not available (previous import attempt failed)") + return _nvvm_module - Returns - ------- - :obj:`~_module.ObjectCode` - Newly created code object. + _nvvm_import_attempted = True - """ - supported_target_types = ("ptx", "cubin", "ltoir") - if target_type not in supported_target_types: - raise ValueError(f'Unsupported target_type="{target_type}" ({supported_target_types=})') - - if self._backend == "NVRTC": - if target_type == "ptx" and not Program_can_load_generated_ptx(): - warn( - "The CUDA driver version is older than the backend version. " - "The generated ptx will not be loadable by the current driver.", - stacklevel=1, - category=RuntimeWarning, - ) - if name_expressions: - for n in name_expressions: - handle_return( - nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()), - handle=self._mnff.handle, - ) - options = self._options.as_bytes("nvrtc") - handle_return( - nvrtc.nvrtcCompileProgram(self._mnff.handle, len(options), options), - handle=self._mnff.handle, + try: + version = get_binding_version() + if version < (12, 9): + raise RuntimeError( + f"NVVM bindings require cuda-bindings >= 12.9.0, but found {version[0]}.{version[1]}.x. " + "Please update cuda-bindings to use NVVM features." ) - size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size") - comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}") - size = handle_return(size_func(self._mnff.handle), handle=self._mnff.handle) - data = b" " * size - handle_return(comp_func(self._mnff.handle, data), handle=self._mnff.handle) - - symbol_mapping = {} - if name_expressions: - for n in name_expressions: - symbol_mapping[n] = handle_return( - nvrtc.nvrtcGetLoweredName(self._mnff.handle, n.encode()), handle=self._mnff.handle - ) - - if logs is not None: - logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._mnff.handle), handle=self._mnff.handle) - if logsize > 1: - log = b" " * logsize - handle_return(nvrtc.nvrtcGetProgramLog(self._mnff.handle, log), handle=self._mnff.handle) - logs.write(log.decode("utf-8", errors="backslashreplace")) - - return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping, name=self._options.name) - - elif self._backend == "NVVM": - if target_type not in ("ptx", "ltoir"): - raise ValueError(f'NVVM backend only supports target_type="ptx", "ltoir", got "{target_type}"') - - # TODO: flip to True when NVIDIA/cuda-python#1354 is resolved and CUDA 12 is dropped - nvvm_options = self._options._prepare_nvvm_options(as_bytes=False) - if target_type == "ltoir" and "-gen-lto" not in nvvm_options: - nvvm_options.append("-gen-lto") - nvvm = _get_nvvm_module() - with _nvvm_exception_manager(self): - nvvm.verify_program(self._mnff.handle, len(nvvm_options), nvvm_options) - nvvm.compile_program(self._mnff.handle, len(nvvm_options), nvvm_options) - - size = nvvm.get_compiled_result_size(self._mnff.handle) - data = bytearray(size) - nvvm.get_compiled_result(self._mnff.handle, data) - - if logs is not None: - logsize = nvvm.get_program_log_size(self._mnff.handle) - if logsize > 1: - log = bytearray(logsize) - nvvm.get_program_log(self._mnff.handle, log) - logs.write(log.decode("utf-8", errors="backslashreplace")) - - return ObjectCode._init(data, target_type, name=self._options.name) - - supported_backends = ("nvJitLink", "driver") - if self._backend not in supported_backends: - raise ValueError(f'Unsupported backend="{self._backend}" ({supported_backends=})') - return self._linker.link(target_type) - - @property - def backend(self) -> str: - """Return this Program instance's underlying backend.""" - return self._backend - - @property - def handle(self) -> ProgramHandleT: - """Return the underlying handle object. + from cuda.bindings import nvvm + from cuda.bindings._internal.nvvm import _inspect_function_pointer - .. note:: + if _inspect_function_pointer("__nvvmCreateProgram") == 0: + raise RuntimeError("NVVM library (libnvvm) is not available in this Python environment. ") - The type of the returned object depends on the backend. + _nvvm_module = nvvm + return _nvvm_module - .. caution:: + except RuntimeError as e: + _nvvm_module = None + raise e - This handle is a Python object. To get the memory address of the underlying C - handle, call ``int(Program.handle)``. - """ - return self._mnff.handle - def __repr__(self) -> str: - return f"" +cdef inline bint _process_define_macro_inner(list options, object macro) except? -1: + """Process a single define macro, returning True if successful.""" + if isinstance(macro, str): + options.append(f"--define-macro={macro}") + return True + if isinstance(macro, tuple): + if len(macro) != 2 or any(not isinstance(val, str) for val in macro): + raise RuntimeError(f"Expected define_macro tuple[str, str], got {macro}") + options.append(f"--define-macro={macro[0]}={macro[1]}") + return True + return False -# ============================================================================= -# Helper functions -# ============================================================================= +cdef inline void _process_define_macro(list options, object macro) except *: + """Process define_macro option which can be str, tuple, or list thereof.""" + union_type = "Union[str, tuple[str, str]]" + if _process_define_macro_inner(options, macro): + return + if is_nested_sequence(macro): + for seq_macro in macro: + if not _process_define_macro_inner(options, seq_macro): + raise RuntimeError(f"Expected define_macro {union_type}, got {seq_macro}") + return + raise RuntimeError(f"Expected define_macro {union_type}, list[{union_type}], got {macro}") -cdef bint Program_can_load_generated_ptx(): +cdef inline bint _can_load_generated_ptx() except? -1: """Check if the driver can load PTX generated by the current NVRTC version.""" driver_ver = handle_return(driver.cuDriverGetVersion()) nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion()) return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver -cdef object Program_translate_options(object options): +cdef inline object _translate_program_options(object options): """Translate ProgramOptions to LinkerOptions for PTX compilation.""" return LinkerOptions( name=options.name, @@ -867,3 +526,389 @@ cdef object Program_translate_options(object options): ptxas_options=options.ptxas_options, no_cache=options.no_cache, ) + + +cdef inline int Program_init(Program self, object code, str code_type, object options) except -1: + """Initialize a Program instance.""" + self._mnff = _ProgramMNFF(self, None, None) + self._options = options = check_or_create_options(ProgramOptions, options, "Program options") + code_type = code_type.lower() + + if code_type == "c++": + assert_type(code, str) + # TODO: support pre-loaded headers & include names + # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved + self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], [])) + self._mnff.backend = "NVRTC" + self._backend = "NVRTC" + self._linker = None + + elif code_type == "ptx": + assert_type(code, str) + self._linker = Linker( + ObjectCode._init(code.encode(), code_type), options=_translate_program_options(options) + ) + self._backend = self._linker.backend + + elif code_type == "nvvm": + if isinstance(code, str): + code = code.encode("utf-8") + elif not isinstance(code, (bytes, bytearray)): + raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray") + + nvvm = _get_nvvm_module() + self._mnff.handle = nvvm.create_program() + self._mnff.backend = "NVVM" + nvvm.add_module_to_program(self._mnff.handle, code, len(code), options._name.decode()) + self._backend = "NVVM" + self._linker = None + + else: + supported_code_types = ("c++", "ptx", "nvvm") + assert code_type not in supported_code_types, f"{code_type=}" + raise RuntimeError(f"Unsupported {code_type=} ({supported_code_types=})") + + return 0 + + +cdef object Program_compile_nvrtc(Program self, str target_type, object name_expressions, object logs): + """Compile using NVRTC backend.""" + if target_type == "ptx" and not _can_load_generated_ptx(): + warn( + "The CUDA driver version is older than the backend version. " + "The generated ptx will not be loadable by the current driver.", + stacklevel=2, + category=RuntimeWarning, + ) + + if name_expressions: + for n in name_expressions: + handle_return( + nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()), + handle=self._mnff.handle, + ) + + options = self._options.as_bytes("nvrtc") + handle_return( + nvrtc.nvrtcCompileProgram(self._mnff.handle, len(options), options), + handle=self._mnff.handle, + ) + + size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size") + comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}") + size = handle_return(size_func(self._mnff.handle), handle=self._mnff.handle) + data = b" " * size + handle_return(comp_func(self._mnff.handle, data), handle=self._mnff.handle) + + symbol_mapping = {} + if name_expressions: + for n in name_expressions: + symbol_mapping[n] = handle_return( + nvrtc.nvrtcGetLoweredName(self._mnff.handle, n.encode()), handle=self._mnff.handle + ) + + if logs is not None: + logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._mnff.handle), handle=self._mnff.handle) + if logsize > 1: + log = b" " * logsize + handle_return(nvrtc.nvrtcGetProgramLog(self._mnff.handle, log), handle=self._mnff.handle) + logs.write(log.decode("utf-8", errors="backslashreplace")) + + return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping, name=self._options.name) + + +cdef object Program_compile_nvvm(Program self, str target_type, object logs): + """Compile using NVVM backend.""" + if target_type not in ("ptx", "ltoir"): + raise ValueError(f'NVVM backend only supports target_type="ptx", "ltoir", got "{target_type}"') + + # TODO: flip to True when NVIDIA/cuda-python#1354 is resolved and CUDA 12 is dropped + nvvm_options = self._options._prepare_nvvm_options(as_bytes=False) + if target_type == "ltoir" and "-gen-lto" not in nvvm_options: + nvvm_options.append("-gen-lto") + + nvvm = _get_nvvm_module() + try: + nvvm.verify_program(self._mnff.handle, len(nvvm_options), nvvm_options) + nvvm.compile_program(self._mnff.handle, len(nvvm_options), nvvm_options) + except Exception as e: + # Capture NVVM program log on error + error_log = "" + try: + logsize = nvvm.get_program_log_size(self._mnff.handle) + if logsize > 1: + log = bytearray(logsize) + nvvm.get_program_log(self._mnff.handle, log) + error_log = log.decode("utf-8", errors="backslashreplace") + except Exception: + pass + e.args = (e.args[0] + (f"\nNVVM program log: {error_log}" if error_log else ""), *e.args[1:]) + raise + + size = nvvm.get_compiled_result_size(self._mnff.handle) + data = bytearray(size) + nvvm.get_compiled_result(self._mnff.handle, data) + + if logs is not None: + logsize = nvvm.get_program_log_size(self._mnff.handle) + if logsize > 1: + log = bytearray(logsize) + nvvm.get_program_log(self._mnff.handle, log) + logs.write(log.decode("utf-8", errors="backslashreplace")) + + return ObjectCode._init(data, target_type, name=self._options.name) + + +cdef object Program_compile(Program self, str target_type, object name_expressions, object logs): + """Compile the program to the specified target type.""" + supported_target_types = ("ptx", "cubin", "ltoir") + if target_type not in supported_target_types: + raise ValueError(f'Unsupported target_type="{target_type}" ({supported_target_types=})') + + if self._backend == "NVRTC": + return Program_compile_nvrtc(self, target_type, name_expressions, logs) + elif self._backend == "NVVM": + return Program_compile_nvvm(self, target_type, logs) + + # Linker backend (PTX code type) + supported_backends = ("nvJitLink", "driver") + if self._backend not in supported_backends: + raise ValueError(f'Unsupported backend="{self._backend}" ({supported_backends=})') + return self._linker.link(target_type) + + +cdef inline list _prepare_nvrtc_options_impl(object opts): + """Build NVRTC-specific compiler options.""" + options = [f"-arch={opts.arch}"] + if opts.relocatable_device_code is not None: + options.append(f"--relocatable-device-code={_handle_boolean_option(opts.relocatable_device_code)}") + if opts.extensible_whole_program is not None and opts.extensible_whole_program: + options.append("--extensible-whole-program") + if opts.debug is not None and opts.debug: + options.append("--device-debug") + if opts.lineinfo is not None and opts.lineinfo: + options.append("--generate-line-info") + if opts.device_code_optimize is not None and opts.device_code_optimize: + options.append("--dopt=on") + if opts.ptxas_options is not None: + opt_name = "--ptxas-options" + if isinstance(opts.ptxas_options, str): + options.append(f"{opt_name}={opts.ptxas_options}") + elif is_sequence(opts.ptxas_options): + for opt_value in opts.ptxas_options: + options.append(f"{opt_name}={opt_value}") + if opts.max_register_count is not None: + options.append(f"--maxrregcount={opts.max_register_count}") + if opts.ftz is not None: + options.append(f"--ftz={_handle_boolean_option(opts.ftz)}") + if opts.prec_sqrt is not None: + options.append(f"--prec-sqrt={_handle_boolean_option(opts.prec_sqrt)}") + if opts.prec_div is not None: + options.append(f"--prec-div={_handle_boolean_option(opts.prec_div)}") + if opts.fma is not None: + options.append(f"--fmad={_handle_boolean_option(opts.fma)}") + if opts.use_fast_math is not None and opts.use_fast_math: + options.append("--use_fast_math") + if opts.extra_device_vectorization is not None and opts.extra_device_vectorization: + options.append("--extra-device-vectorization") + if opts.link_time_optimization is not None and opts.link_time_optimization: + options.append("--dlink-time-opt") + if opts.gen_opt_lto is not None and opts.gen_opt_lto: + options.append("--gen-opt-lto") + if opts.define_macro is not None: + _process_define_macro(options, opts.define_macro) + if opts.undefine_macro is not None: + if isinstance(opts.undefine_macro, str): + options.append(f"--undefine-macro={opts.undefine_macro}") + elif is_sequence(opts.undefine_macro): + for macro in opts.undefine_macro: + options.append(f"--undefine-macro={macro}") + if opts.include_path is not None: + if isinstance(opts.include_path, str): + options.append(f"--include-path={opts.include_path}") + elif is_sequence(opts.include_path): + for path in opts.include_path: + options.append(f"--include-path={path}") + if opts.pre_include is not None: + if isinstance(opts.pre_include, str): + options.append(f"--pre-include={opts.pre_include}") + elif is_sequence(opts.pre_include): + for header in opts.pre_include: + options.append(f"--pre-include={header}") + if opts.no_source_include is not None and opts.no_source_include: + options.append("--no-source-include") + if opts.std is not None: + options.append(f"--std={opts.std}") + if opts.builtin_move_forward is not None: + options.append(f"--builtin-move-forward={_handle_boolean_option(opts.builtin_move_forward)}") + if opts.builtin_initializer_list is not None: + options.append(f"--builtin-initializer-list={_handle_boolean_option(opts.builtin_initializer_list)}") + if opts.disable_warnings is not None and opts.disable_warnings: + options.append("--disable-warnings") + if opts.restrict is not None and opts.restrict: + options.append("--restrict") + if opts.device_as_default_execution_space is not None and opts.device_as_default_execution_space: + options.append("--device-as-default-execution-space") + if opts.device_int128 is not None and opts.device_int128: + options.append("--device-int128") + if opts.device_float128 is not None and opts.device_float128: + options.append("--device-float128") + if opts.optimization_info is not None: + options.append(f"--optimization-info={opts.optimization_info}") + if opts.no_display_error_number is not None and opts.no_display_error_number: + options.append("--no-display-error-number") + if opts.diag_error is not None: + if isinstance(opts.diag_error, int): + options.append(f"--diag-error={opts.diag_error}") + elif is_sequence(opts.diag_error): + for error in opts.diag_error: + options.append(f"--diag-error={error}") + if opts.diag_suppress is not None: + if isinstance(opts.diag_suppress, int): + options.append(f"--diag-suppress={opts.diag_suppress}") + elif is_sequence(opts.diag_suppress): + for suppress in opts.diag_suppress: + options.append(f"--diag-suppress={suppress}") + if opts.diag_warn is not None: + if isinstance(opts.diag_warn, int): + options.append(f"--diag-warn={opts.diag_warn}") + elif is_sequence(opts.diag_warn): + for w in opts.diag_warn: + options.append(f"--diag-warn={w}") + if opts.brief_diagnostics is not None: + options.append(f"--brief-diagnostics={_handle_boolean_option(opts.brief_diagnostics)}") + if opts.time is not None: + options.append(f"--time={opts.time}") + if opts.split_compile is not None: + options.append(f"--split-compile={opts.split_compile}") + if opts.fdevice_syntax_only is not None and opts.fdevice_syntax_only: + options.append("--fdevice-syntax-only") + if opts.minimal is not None and opts.minimal: + options.append("--minimal") + if opts.no_cache is not None and opts.no_cache: + options.append("--no-cache") + if opts.fdevice_time_trace is not None: + options.append(f"--fdevice-time-trace={opts.fdevice_time_trace}") + if opts.frandom_seed is not None: + options.append(f"--frandom-seed={opts.frandom_seed}") + if opts.ofast_compile is not None: + options.append(f"--Ofast-compile={opts.ofast_compile}") + # PCH options (CUDA 12.8+) + if opts.pch is not None and opts.pch: + options.append("--pch") + if opts.create_pch is not None: + options.append(f"--create-pch={opts.create_pch}") + if opts.use_pch is not None: + options.append(f"--use-pch={opts.use_pch}") + if opts.pch_dir is not None: + options.append(f"--pch-dir={opts.pch_dir}") + if opts.pch_verbose is not None: + options.append(f"--pch-verbose={_handle_boolean_option(opts.pch_verbose)}") + if opts.pch_messages is not None: + options.append(f"--pch-messages={_handle_boolean_option(opts.pch_messages)}") + if opts.instantiate_templates_in_pch is not None: + options.append( + f"--instantiate-templates-in-pch={_handle_boolean_option(opts.instantiate_templates_in_pch)}" + ) + if opts.numba_debug: + options.append("--numba-debug") + return [o.encode() for o in options] + + +cdef inline object _prepare_nvvm_options_impl(object opts, bint as_bytes): + """Build NVVM-specific compiler options.""" + options = [] + + # Options supported by NVVM + assert opts.arch is not None + arch = opts.arch + if arch.startswith("sm_"): + arch = f"compute_{arch[3:]}" + options.append(f"-arch={arch}") + if opts.debug is not None and opts.debug: + options.append("-g") + if opts.device_code_optimize is False: + options.append("-opt=0") + elif opts.device_code_optimize is True: + options.append("-opt=3") + # NVVM uses 0/1 instead of true/false for boolean options + if opts.ftz is not None: + options.append(f"-ftz={'1' if opts.ftz else '0'}") + if opts.prec_sqrt is not None: + options.append(f"-prec-sqrt={'1' if opts.prec_sqrt else '0'}") + if opts.prec_div is not None: + options.append(f"-prec-div={'1' if opts.prec_div else '0'}") + if opts.fma is not None: + options.append(f"-fma={'1' if opts.fma else '0'}") + + # Check for unsupported options and raise error if they are set + unsupported = [] + if opts.relocatable_device_code is not None: + unsupported.append("relocatable_device_code") + if opts.extensible_whole_program is not None and opts.extensible_whole_program: + unsupported.append("extensible_whole_program") + if opts.lineinfo is not None and opts.lineinfo: + unsupported.append("lineinfo") + if opts.ptxas_options is not None: + unsupported.append("ptxas_options") + if opts.max_register_count is not None: + unsupported.append("max_register_count") + if opts.use_fast_math is not None and opts.use_fast_math: + unsupported.append("use_fast_math") + if opts.extra_device_vectorization is not None and opts.extra_device_vectorization: + unsupported.append("extra_device_vectorization") + if opts.gen_opt_lto is not None and opts.gen_opt_lto: + unsupported.append("gen_opt_lto") + if opts.define_macro is not None: + unsupported.append("define_macro") + if opts.undefine_macro is not None: + unsupported.append("undefine_macro") + if opts.include_path is not None: + unsupported.append("include_path") + if opts.pre_include is not None: + unsupported.append("pre_include") + if opts.no_source_include is not None and opts.no_source_include: + unsupported.append("no_source_include") + if opts.std is not None: + unsupported.append("std") + if opts.builtin_move_forward is not None: + unsupported.append("builtin_move_forward") + if opts.builtin_initializer_list is not None: + unsupported.append("builtin_initializer_list") + if opts.disable_warnings is not None and opts.disable_warnings: + unsupported.append("disable_warnings") + if opts.restrict is not None and opts.restrict: + unsupported.append("restrict") + if opts.device_as_default_execution_space is not None and opts.device_as_default_execution_space: + unsupported.append("device_as_default_execution_space") + if opts.device_int128 is not None and opts.device_int128: + unsupported.append("device_int128") + if opts.optimization_info is not None: + unsupported.append("optimization_info") + if opts.no_display_error_number is not None and opts.no_display_error_number: + unsupported.append("no_display_error_number") + if opts.diag_error is not None: + unsupported.append("diag_error") + if opts.diag_suppress is not None: + unsupported.append("diag_suppress") + if opts.diag_warn is not None: + unsupported.append("diag_warn") + if opts.brief_diagnostics is not None: + unsupported.append("brief_diagnostics") + if opts.time is not None: + unsupported.append("time") + if opts.split_compile is not None: + unsupported.append("split_compile") + if opts.fdevice_syntax_only is not None and opts.fdevice_syntax_only: + unsupported.append("fdevice_syntax_only") + if opts.minimal is not None and opts.minimal: + unsupported.append("minimal") + if opts.numba_debug is not None and opts.numba_debug: + unsupported.append("numba_debug") + if unsupported: + raise CUDAError(f"The following options are not supported by NVVM backend: {', '.join(unsupported)}") + + if as_bytes: + return [o.encode() for o in options] + else: + return options diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index c0760cee45..e237cf5474 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -146,7 +146,7 @@ def test_object_code_load_ptx(get_saxpy_kernel_ptx): sym_map = mod.symbol_mapping mod_obj = ObjectCode.from_ptx(ptx, symbol_mapping=sym_map) assert mod.code == ptx - if not Program._can_load_generated_ptx(): + if not Program.driver_can_load_nvrtc_ptx_output(): pytest.skip("PTX version too new for current driver") mod_obj.get_kernel("saxpy") # force loading @@ -160,7 +160,7 @@ def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path): mod_obj = ObjectCode.from_ptx(str(ptx_file), symbol_mapping=sym_map) assert mod_obj.code == str(ptx_file) assert mod_obj.code_type == "ptx" - if not Program._can_load_generated_ptx(): + if not Program.driver_can_load_nvrtc_ptx_output(): pytest.skip("PTX version too new for current driver") mod_obj.get_kernel("saxpy") # force loading From 182feabf8fb88380d8676d518822aeba3bc5da0f Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 3 Feb 2026 10:10:31 -0800 Subject: [PATCH 4/7] Extend test_object_protocols.py with Program and ObjectCode variations Add fixtures for different Program backends (NVRTC, PTX, NVVM) and ObjectCode code types (cubin, PTX, LTOIR). Split API_TYPES into more precise HASH_TYPES, EQ_TYPES, and WEAKREF_TYPES lists. Derive DICT_KEY_TYPES and WEAK_KEY_TYPES for collection tests. --- cuda_core/tests/test_object_protocols.py | 147 ++++++++++++++++++++--- 1 file changed, 127 insertions(+), 20 deletions(-) diff --git a/cuda_core/tests/test_object_protocols.py b/cuda_core/tests/test_object_protocols.py index 69a83f6514..a5237b9c64 100644 --- a/cuda_core/tests/test_object_protocols.py +++ b/cuda_core/tests/test_object_protocols.py @@ -56,16 +56,86 @@ def sample_launch_config(): @pytest.fixture -def sample_object_code(init_cuda): - """A sample ObjectCode object.""" +def sample_kernel(sample_object_code_cubin): + """A sample Kernel object.""" + return sample_object_code_cubin.get_kernel("test_kernel") + + +# ============================================================================= +# Fixtures - ObjectCode variations (by code_type) +# ============================================================================= + + +@pytest.fixture +def sample_object_code_cubin(init_cuda): + """An ObjectCode compiled to cubin.""" prog = Program('extern "C" __global__ void test_kernel() {}', "c++") return prog.compile("cubin") @pytest.fixture -def sample_kernel(sample_object_code): - """A sample Kernel object.""" - return sample_object_code.get_kernel("test_kernel") +def sample_object_code_ptx(init_cuda): + """An ObjectCode compiled to PTX.""" + if not Program.driver_can_load_nvrtc_ptx_output(): + pytest.skip("PTX version too new for current driver") + prog = Program('extern "C" __global__ void test_kernel() {}', "c++") + return prog.compile("ptx") + + +@pytest.fixture +def sample_object_code_ltoir(init_cuda): + """An ObjectCode compiled to LTOIR.""" + prog = Program('extern "C" __global__ void test_kernel() {}', "c++") + return prog.compile("ltoir") + + +# ============================================================================= +# Fixtures - Program variations (by backend) +# ============================================================================= + + +@pytest.fixture +def sample_program_nvrtc(init_cuda): + """A Program using NVRTC backend (C++ code).""" + return Program('extern "C" __global__ void k() {}', "c++") + + +@pytest.fixture +def sample_program_ptx(init_cuda): + """A Program using linker backend (PTX code).""" + # First compile C++ to PTX, then create a Program from PTX + if not Program.driver_can_load_nvrtc_ptx_output(): + pytest.skip("PTX version too new for current driver") + prog = Program('extern "C" __global__ void k() {}', "c++") + obj = prog.compile("ptx") + ptx_code = obj.code.decode() if isinstance(obj.code, bytes) else obj.code + return Program(ptx_code, "ptx") + + +@pytest.fixture +def sample_program_nvvm(init_cuda): + """A Program using NVVM backend (NVVM IR code).""" + # Minimal NVVM IR that declares a kernel + # fmt: off + nvvm_ir = ( + 'target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-' + 'i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"\n' + 'target triple = "nvptx64-nvidia-cuda"\n' + "\n" + "define void @test_kernel() {\n" + " ret void\n" + "}\n" + "\n" + "!nvvm.annotations = !{!0}\n" + '!0 = !{void ()* @test_kernel, !"kernel", i32 1}\n' + ) + # fmt: on + try: + return Program(nvvm_ir, "nvvm") + except RuntimeError as e: + if "NVVM" in str(e): + pytest.skip("NVVM not available") + raise # ============================================================================= @@ -131,18 +201,43 @@ def sample_kernel_alt(sample_object_code_alt): # Type groupings # ============================================================================= -# All types that should support weak references -API_TYPES = [ +# Types with __hash__ support +HASH_TYPES = [ "sample_device", "sample_stream", "sample_event", "sample_context", "sample_buffer", "sample_launch_config", - "sample_object_code", + "sample_object_code_cubin", "sample_kernel", ] +# Types with __eq__ support +EQ_TYPES = [ + "sample_device", + "sample_stream", + "sample_event", + "sample_context", + "sample_buffer", + "sample_launch_config", + "sample_object_code_cubin", + "sample_kernel", +] + +# Types with __weakref__ support +WEAKREF_TYPES = [ + "sample_device", + "sample_stream", + "sample_event", + "sample_context", + "sample_buffer", + "sample_launch_config", + "sample_object_code_cubin", + "sample_kernel", + "sample_program_nvrtc", +] + # Pairs of distinct objects of the same type (for inequality testing) # Device and Context pairs require multi-GPU and will skip on single-GPU machines SAME_TYPE_PAIRS = [ @@ -152,7 +247,7 @@ def sample_kernel_alt(sample_object_code_alt): ("sample_context", "sample_context_alt"), ("sample_buffer", "sample_buffer_alt"), ("sample_launch_config", "sample_launch_config_alt"), - ("sample_object_code", "sample_object_code_alt"), + ("sample_object_code_cubin", "sample_object_code_alt"), ("sample_kernel", "sample_kernel_alt"), ] @@ -163,8 +258,13 @@ def sample_kernel_alt(sample_object_code_alt): ("sample_kernel", lambda k: Kernel.from_handle(int(k.handle))), ] +# Derived type groupings for collection tests +DICT_KEY_TYPES = sorted(set(HASH_TYPES) & set(EQ_TYPES)) +WEAK_KEY_TYPES = sorted(set(HASH_TYPES) & set(EQ_TYPES) & set(WEAKREF_TYPES)) + # Pairs of (fixture_name, regex_pattern) for repr format validation REPR_PATTERNS = [ + # Core types ("sample_device", r""), ("sample_stream", r""), ("sample_event", r""), @@ -175,8 +275,15 @@ def sample_kernel_alt(sample_object_code_alt): r"LaunchConfig\(grid=\(\d+, \d+, \d+\), cluster=.+, block=\(\d+, \d+, \d+\), " r"shmem_size=\d+, cooperative_launch=(?:True|False)\)", ), - ("sample_object_code", r""), ("sample_kernel", r""), + # ObjectCode variations (by code_type) + ("sample_object_code_cubin", r""), + ("sample_object_code_ptx", r""), + ("sample_object_code_ltoir", r""), + # Program variations (by backend) + ("sample_program_nvrtc", r""), + ("sample_program_ptx", r""), + ("sample_program_nvvm", r""), ] @@ -185,7 +292,7 @@ def sample_kernel_alt(sample_object_code_alt): # ============================================================================= -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", WEAKREF_TYPES) def test_weakref_supported(fixture_name, request): """Object supports weak references.""" obj = request.getfixturevalue(fixture_name) @@ -198,7 +305,7 @@ def test_weakref_supported(fixture_name, request): # ============================================================================= -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", HASH_TYPES) def test_hash_consistency(fixture_name, request): """Hash is consistent across multiple calls.""" obj = request.getfixturevalue(fixture_name) @@ -213,7 +320,7 @@ def test_hash_distinct_same_type(a_name, b_name, request): assert hash(obj_a) != hash(obj_b) # extremely unlikely -@pytest.mark.parametrize("a_name,b_name", itertools.combinations(API_TYPES, 2)) +@pytest.mark.parametrize("a_name,b_name", itertools.combinations(HASH_TYPES, 2)) def test_hash_distinct_cross_type(a_name, b_name, request): """Distinct objects of different types have different hashes.""" obj_a = request.getfixturevalue(a_name) @@ -226,7 +333,7 @@ def test_hash_distinct_cross_type(a_name, b_name, request): # ============================================================================= -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", EQ_TYPES) def test_equality_basic(fixture_name, request): """Object equality: reflexive, not equal to None or other types.""" obj = request.getfixturevalue(fixture_name) @@ -237,7 +344,7 @@ def test_equality_basic(fixture_name, request): assert obj != obj.handle -@pytest.mark.parametrize("a_name,b_name", itertools.combinations(API_TYPES, 2)) +@pytest.mark.parametrize("a_name,b_name", itertools.combinations(EQ_TYPES, 2)) def test_no_cross_type_equality(a_name, b_name, request): """No two distinct objects of different types should compare equal.""" obj_a = request.getfixturevalue(a_name) @@ -268,7 +375,7 @@ def test_equality_same_handle(fixture_name, copy_fn, request): # ============================================================================= -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", DICT_KEY_TYPES) def test_usable_as_dict_key(fixture_name, request): """Object can be used as a dictionary key.""" obj = request.getfixturevalue(fixture_name) @@ -277,7 +384,7 @@ def test_usable_as_dict_key(fixture_name, request): assert obj in d -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", DICT_KEY_TYPES) def test_usable_in_set(fixture_name, request): """Object can be added to a set.""" obj = request.getfixturevalue(fixture_name) @@ -285,7 +392,7 @@ def test_usable_in_set(fixture_name, request): assert obj in s -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", WEAKREF_TYPES) def test_usable_in_weak_value_dict(fixture_name, request): """Object can be used as a WeakValueDictionary value.""" obj = request.getfixturevalue(fixture_name) @@ -294,7 +401,7 @@ def test_usable_in_weak_value_dict(fixture_name, request): assert wvd["key"] is obj -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", WEAK_KEY_TYPES) def test_usable_in_weak_key_dict(fixture_name, request): """Object can be used as a WeakKeyDictionary key.""" obj = request.getfixturevalue(fixture_name) @@ -303,7 +410,7 @@ def test_usable_in_weak_key_dict(fixture_name, request): assert wkd[obj] == "value" -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", WEAK_KEY_TYPES) def test_usable_in_weak_set(fixture_name, request): """Object can be added to a WeakSet.""" obj = request.getfixturevalue(fixture_name) From b9b90d6968bc9da4fa523329c0a0349a35e0bd59 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 3 Feb 2026 11:28:15 -0800 Subject: [PATCH 5/7] Add NVRTC/NVVM resource handles and remove Program MNFF - Add NvrtcProgramHandle and NvvmProgramHandle to resource handles module - Add function pointer initialization for nvrtcDestroyProgram and nvvmDestroyProgram - Forward-declare nvvmProgram to avoid nvvm.h dependency - Refactor detail::make_py to accept module name parameter - Remove _ProgramMNFF class from _program.pyx - Program now uses typed handles directly with RAII cleanup - Update handle property to return None when handle is null --- cuda_core/cuda/core/_cpp/resource_handles.cpp | 66 ++++++++++++ cuda_core/cuda/core/_cpp/resource_handles.hpp | 102 ++++++++++++++++-- cuda_core/cuda/core/_program.pxd | 5 +- cuda_core/cuda/core/_program.pyx | 97 +++++++++-------- cuda_core/cuda/core/_resource_handles.pxd | 20 +++- cuda_core/cuda/core/_resource_handles.pyx | 47 ++++++++ 6 files changed, 279 insertions(+), 58 deletions(-) diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 581a7af69a..27e76b817a 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -56,6 +56,12 @@ decltype(&cuLibraryLoadData) p_cuLibraryLoadData = nullptr; decltype(&cuLibraryUnload) p_cuLibraryUnload = nullptr; decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr; +// NVRTC function pointers +decltype(&nvrtcDestroyProgram) p_nvrtcDestroyProgram = nullptr; + +// NVVM function pointers (may be null if NVVM is not available) +NvvmDestroyProgramFn p_nvvmDestroyProgram = nullptr; + // ============================================================================ // GIL management helpers // ============================================================================ @@ -764,4 +770,64 @@ KernelHandle create_kernel_handle_ref(CUkernel kernel, const LibraryHandle& h_li return KernelHandle(box, &box->resource); } +// ============================================================================ +// NVRTC Program Handles +// ============================================================================ + +namespace { +struct NvrtcProgramBox { + nvrtcProgram resource; +}; +} // namespace + +NvrtcProgramHandle create_nvrtc_program_handle(nvrtcProgram prog) { + auto box = std::shared_ptr( + new NvrtcProgramBox{prog}, + [](NvrtcProgramBox* b) { + // Note: nvrtcDestroyProgram takes nvrtcProgram* and nulls it, + // but we're deleting the box anyway so nulling is harmless. + // Errors are ignored (standard destructor practice). + p_nvrtcDestroyProgram(&b->resource); + delete b; + } + ); + return NvrtcProgramHandle(box, &box->resource); +} + +NvrtcProgramHandle create_nvrtc_program_handle_ref(nvrtcProgram prog) { + auto box = std::make_shared(NvrtcProgramBox{prog}); + return NvrtcProgramHandle(box, &box->resource); +} + +// ============================================================================ +// NVVM Program Handles +// ============================================================================ + +namespace { +struct NvvmProgramBox { + nvvmProgram resource; +}; +} // namespace + +NvvmProgramHandle create_nvvm_program_handle(nvvmProgram prog) { + auto box = std::shared_ptr( + new NvvmProgramBox{prog}, + [](NvvmProgramBox* b) { + // Note: nvvmDestroyProgram takes nvvmProgram* and nulls it, + // but we're deleting the box anyway so nulling is harmless. + // If NVVM is not available, the function pointer is null. + if (p_nvvmDestroyProgram) { + p_nvvmDestroyProgram(&b->resource); + } + delete b; + } + ); + return NvvmProgramHandle(box, &box->resource); +} + +NvvmProgramHandle create_nvvm_program_handle_ref(nvvmProgram prog) { + auto box = std::make_shared(NvvmProgramBox{prog}); + return NvvmProgramHandle(box, &box->resource); +} + } // namespace cuda_core diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp index b6118a07a2..fe64fd343d 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.hpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp @@ -6,9 +6,16 @@ #include #include +#include #include #include +// Forward declaration for NVVM - avoids nvvm.h dependency +// Use void* to match cuda.bindings.cynvvm's typedef for compatibility +#ifndef CYTHON_EXTERN_C +typedef void *nvvmProgram; +#endif + namespace cuda_core { // ============================================================================ @@ -67,6 +74,28 @@ extern decltype(&cuLibraryLoadData) p_cuLibraryLoadData; extern decltype(&cuLibraryUnload) p_cuLibraryUnload; extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel; +// ============================================================================ +// NVRTC function pointers +// +// These are populated by _resource_handles.pyx at module import time using +// function pointers extracted from cuda.bindings.cynvrtc.__pyx_capi__. +// ============================================================================ + +extern decltype(&nvrtcDestroyProgram) p_nvrtcDestroyProgram; + +// ============================================================================ +// NVVM function pointers +// +// These are populated by _resource_handles.pyx at module import time using +// function pointers extracted from cuda.bindings.cynvvm.__pyx_capi__. +// Note: May be null if NVVM is not available at runtime. +// ============================================================================ + +// Function pointer type for nvvmDestroyProgram (avoids nvvm.h dependency) +// Signature: nvvmResult nvvmDestroyProgram(nvvmProgram *prog) +using NvvmDestroyProgramFn = int (*)(nvvmProgram*); +extern NvvmDestroyProgramFn p_nvvmDestroyProgram; + // ============================================================================ // Handle type aliases - expose only the raw CUDA resource // ============================================================================ @@ -77,6 +106,8 @@ using EventHandle = std::shared_ptr; using MemoryPoolHandle = std::shared_ptr; using LibraryHandle = std::shared_ptr; using KernelHandle = std::shared_ptr; +using NvrtcProgramHandle = std::shared_ptr; +using NvvmProgramHandle = std::shared_ptr; // ============================================================================ // Context handle functions @@ -260,6 +291,33 @@ KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* na // Use for borrowed kernels. The library handle keeps the library alive. KernelHandle create_kernel_handle_ref(CUkernel kernel, const LibraryHandle& h_library); +// ============================================================================ +// NVRTC Program handle functions +// ============================================================================ + +// Create an owning NVRTC program handle. +// When the last reference is released, nvrtcDestroyProgram is called. +// Use this to wrap a program created via nvrtcCreateProgram. +NvrtcProgramHandle create_nvrtc_program_handle(nvrtcProgram prog); + +// Create a non-owning NVRTC program handle (references existing program). +// The program will NOT be destroyed when the handle is released. +NvrtcProgramHandle create_nvrtc_program_handle_ref(nvrtcProgram prog); + +// ============================================================================ +// NVVM Program handle functions +// ============================================================================ + +// Create an owning NVVM program handle. +// When the last reference is released, nvvmDestroyProgram is called. +// Use this to wrap a program created via nvvmCreateProgram. +// Note: If NVVM is not available (p_nvvmDestroyProgram is null), the deleter is a no-op. +NvvmProgramHandle create_nvvm_program_handle(nvvmProgram prog); + +// Create a non-owning NVVM program handle (references existing program). +// The program will NOT be destroyed when the handle is released. +NvvmProgramHandle create_nvvm_program_handle_ref(nvvmProgram prog); + // ============================================================================ // Overloaded helper functions to extract raw resources from handles // ============================================================================ @@ -293,6 +351,14 @@ inline CUkernel as_cu(const KernelHandle& h) noexcept { return h ? *h : nullptr; } +inline nvrtcProgram as_cu(const NvrtcProgramHandle& h) noexcept { + return h ? *h : nullptr; +} + +inline nvvmProgram as_cu(const NvvmProgramHandle& h) noexcept { + return h ? *h : nullptr; +} + // as_intptr() - extract handle as intptr_t for Python interop // Using signed intptr_t per C standard convention and issue #1342 inline std::intptr_t as_intptr(const ContextHandle& h) noexcept { @@ -323,11 +389,19 @@ inline std::intptr_t as_intptr(const KernelHandle& h) noexcept { return reinterpret_cast(as_cu(h)); } -// as_py() - convert handle to Python driver wrapper object (returns new reference) +inline std::intptr_t as_intptr(const NvrtcProgramHandle& h) noexcept { + return reinterpret_cast(as_cu(h)); +} + +inline std::intptr_t as_intptr(const NvvmProgramHandle& h) noexcept { + return reinterpret_cast(as_cu(h)); +} + +// as_py() - convert handle to Python wrapper object (returns new reference) namespace detail { // n.b. class lookup is not cached to avoid deadlock hazard, see DESIGN.md -inline PyObject* make_py(const char* class_name, std::intptr_t value) noexcept { - PyObject* mod = PyImport_ImportModule("cuda.bindings.driver"); +inline PyObject* make_py(const char* module_name, const char* class_name, std::intptr_t value) noexcept { + PyObject* mod = PyImport_ImportModule(module_name); if (!mod) return nullptr; PyObject* cls = PyObject_GetAttrString(mod, class_name); Py_DECREF(mod); @@ -339,31 +413,39 @@ inline PyObject* make_py(const char* class_name, std::intptr_t value) noexcept { } // namespace detail inline PyObject* as_py(const ContextHandle& h) noexcept { - return detail::make_py("CUcontext", as_intptr(h)); + return detail::make_py("cuda.bindings.driver", "CUcontext", as_intptr(h)); } inline PyObject* as_py(const StreamHandle& h) noexcept { - return detail::make_py("CUstream", as_intptr(h)); + return detail::make_py("cuda.bindings.driver", "CUstream", as_intptr(h)); } inline PyObject* as_py(const EventHandle& h) noexcept { - return detail::make_py("CUevent", as_intptr(h)); + return detail::make_py("cuda.bindings.driver", "CUevent", as_intptr(h)); } inline PyObject* as_py(const MemoryPoolHandle& h) noexcept { - return detail::make_py("CUmemoryPool", as_intptr(h)); + return detail::make_py("cuda.bindings.driver", "CUmemoryPool", as_intptr(h)); } inline PyObject* as_py(const DevicePtrHandle& h) noexcept { - return detail::make_py("CUdeviceptr", as_intptr(h)); + return detail::make_py("cuda.bindings.driver", "CUdeviceptr", as_intptr(h)); } inline PyObject* as_py(const LibraryHandle& h) noexcept { - return detail::make_py("CUlibrary", as_intptr(h)); + return detail::make_py("cuda.bindings.driver", "CUlibrary", as_intptr(h)); } inline PyObject* as_py(const KernelHandle& h) noexcept { - return detail::make_py("CUkernel", as_intptr(h)); + return detail::make_py("cuda.bindings.driver", "CUkernel", as_intptr(h)); +} + +inline PyObject* as_py(const NvrtcProgramHandle& h) noexcept { + return detail::make_py("cuda.bindings.nvrtc", "nvrtcProgram", as_intptr(h)); +} + +inline PyObject* as_py(const NvvmProgramHandle& h) noexcept { + return detail::make_py("cuda.bindings.nvvm", "nvvmProgram", as_intptr(h)); } } // namespace cuda_core diff --git a/cuda_core/cuda/core/_program.pxd b/cuda_core/cuda/core/_program.pxd index 444257f1e4..92d30f8c0c 100644 --- a/cuda_core/cuda/core/_program.pxd +++ b/cuda_core/cuda/core/_program.pxd @@ -2,10 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +from ._resource_handles cimport NvrtcProgramHandle, NvvmProgramHandle + cdef class Program: cdef: - object _mnff + NvrtcProgramHandle _h_nvrtc + NvvmProgramHandle _h_nvvm str _backend object _linker # Linker object _options # ProgramOptions diff --git a/cuda_core/cuda/core/_program.pyx b/cuda_core/cuda/core/_program.pyx index 79a3cd4f7f..e30d7f65b4 100644 --- a/cuda_core/cuda/core/_program.pyx +++ b/cuda_core/cuda/core/_program.pyx @@ -9,11 +9,21 @@ This module provides :class:`Program` for compiling source code into from __future__ import annotations -import weakref from dataclasses import dataclass from warnings import warn from cuda.bindings import driver, nvrtc + +from libc.stdint cimport intptr_t + +from ._resource_handles cimport ( + NvrtcProgramHandle, + NvvmProgramHandle, + as_intptr, + create_nvrtc_program_handle, + create_nvvm_program_handle, +) +from cuda.bindings cimport cynvrtc, cynvvm from cuda.core._device import Device from cuda.core._linker import Linker, LinkerHandleT, LinkerOptions from cuda.core._module import ObjectCode @@ -65,7 +75,9 @@ cdef class Program: """Destroy this program.""" if self._linker: self._linker.close() - self._mnff.close() + # Reset handles - the C++ shared_ptr destructor handles cleanup + self._h_nvrtc = NvrtcProgramHandle() + self._h_nvvm = NvvmProgramHandle() def compile( self, target_type: str, name_expressions: tuple | list = (), logs = None @@ -107,7 +119,15 @@ cdef class Program: This handle is a Python object. To get the memory address of the underlying C handle, call ``int(Program.handle)``. """ - return self._mnff.handle + if self._backend == "NVRTC": + ptr = as_intptr(self._h_nvrtc) + return nvrtc.nvrtcProgram(ptr) if ptr else None + elif self._backend == "NVVM": + # NVVM uses raw integers for handles, not wrapper classes + ptr = as_intptr(self._h_nvvm) + return ptr if ptr else None + else: + return self._linker.handle if self._linker else None @staticmethod def driver_can_load_nvrtc_ptx_output() -> bool: @@ -422,26 +442,6 @@ cdef object_nvvm_module = None cdef bint _nvvm_import_attempted = False -class _ProgramMNFF: - """Members needed for postrm release of program handles.""" - - __slots__ = "handle", "backend" - - def __init__(self, program_obj, handle, backend): - self.handle = handle - self.backend = backend - weakref.finalize(program_obj, self.close) - - def close(self): - if self.handle is not None: - if self.backend == "NVRTC": - handle_return(nvrtc.nvrtcDestroyProgram(self.handle)) - elif self.backend == "NVVM": - nvvm = _get_nvvm_module() - nvvm.destroy_program(self.handle) - self.handle = None - - def _get_nvvm_module(): """Get the NVVM module, importing it lazily with availability checks.""" global _nvvm_module, _nvvm_import_attempted @@ -530,7 +530,6 @@ cdef inline object _translate_program_options(object options): cdef inline int Program_init(Program self, object code, str code_type, object options) except -1: """Initialize a Program instance.""" - self._mnff = _ProgramMNFF(self, None, None) self._options = options = check_or_create_options(ProgramOptions, options, "Program options") code_type = code_type.lower() @@ -538,8 +537,8 @@ cdef inline int Program_init(Program self, object code, str code_type, object op assert_type(code, str) # TODO: support pre-loaded headers & include names # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved - self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], [])) - self._mnff.backend = "NVRTC" + py_prog = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], [])) + self._h_nvrtc = create_nvrtc_program_handle(int(py_prog)) self._backend = "NVRTC" self._linker = None @@ -557,9 +556,9 @@ cdef inline int Program_init(Program self, object code, str code_type, object op raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray") nvvm = _get_nvvm_module() - self._mnff.handle = nvvm.create_program() - self._mnff.backend = "NVVM" - nvvm.add_module_to_program(self._mnff.handle, code, len(code), options._name.decode()) + py_prog = nvvm.create_program() + nvvm.add_module_to_program(py_prog, code, len(code), options._name.decode()) + self._h_nvvm = create_nvvm_program_handle(int(py_prog)) self._backend = "NVVM" self._linker = None @@ -581,37 +580,40 @@ cdef object Program_compile_nvrtc(Program self, str target_type, object name_exp category=RuntimeWarning, ) + # Create Python wrapper for handle_return calls that need it + py_handle = nvrtc.nvrtcProgram(as_intptr(self._h_nvrtc)) + if name_expressions: for n in name_expressions: handle_return( - nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()), - handle=self._mnff.handle, + nvrtc.nvrtcAddNameExpression(py_handle, n.encode()), + handle=py_handle, ) options = self._options.as_bytes("nvrtc") handle_return( - nvrtc.nvrtcCompileProgram(self._mnff.handle, len(options), options), - handle=self._mnff.handle, + nvrtc.nvrtcCompileProgram(py_handle, len(options), options), + handle=py_handle, ) size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size") comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}") - size = handle_return(size_func(self._mnff.handle), handle=self._mnff.handle) + size = handle_return(size_func(py_handle), handle=py_handle) data = b" " * size - handle_return(comp_func(self._mnff.handle, data), handle=self._mnff.handle) + handle_return(comp_func(py_handle, data), handle=py_handle) symbol_mapping = {} if name_expressions: for n in name_expressions: symbol_mapping[n] = handle_return( - nvrtc.nvrtcGetLoweredName(self._mnff.handle, n.encode()), handle=self._mnff.handle + nvrtc.nvrtcGetLoweredName(py_handle, n.encode()), handle=py_handle ) if logs is not None: - logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._mnff.handle), handle=self._mnff.handle) + logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(py_handle), handle=py_handle) if logsize > 1: log = b" " * logsize - handle_return(nvrtc.nvrtcGetProgramLog(self._mnff.handle, log), handle=self._mnff.handle) + handle_return(nvrtc.nvrtcGetProgramLog(py_handle, log), handle=py_handle) logs.write(log.decode("utf-8", errors="backslashreplace")) return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping, name=self._options.name) @@ -628,32 +630,35 @@ cdef object Program_compile_nvvm(Program self, str target_type, object logs): nvvm_options.append("-gen-lto") nvvm = _get_nvvm_module() + # NVVM uses raw integers for handles + py_handle = as_intptr(self._h_nvvm) + try: - nvvm.verify_program(self._mnff.handle, len(nvvm_options), nvvm_options) - nvvm.compile_program(self._mnff.handle, len(nvvm_options), nvvm_options) + nvvm.verify_program(py_handle, len(nvvm_options), nvvm_options) + nvvm.compile_program(py_handle, len(nvvm_options), nvvm_options) except Exception as e: # Capture NVVM program log on error error_log = "" try: - logsize = nvvm.get_program_log_size(self._mnff.handle) + logsize = nvvm.get_program_log_size(py_handle) if logsize > 1: log = bytearray(logsize) - nvvm.get_program_log(self._mnff.handle, log) + nvvm.get_program_log(py_handle, log) error_log = log.decode("utf-8", errors="backslashreplace") except Exception: pass e.args = (e.args[0] + (f"\nNVVM program log: {error_log}" if error_log else ""), *e.args[1:]) raise - size = nvvm.get_compiled_result_size(self._mnff.handle) + size = nvvm.get_compiled_result_size(py_handle) data = bytearray(size) - nvvm.get_compiled_result(self._mnff.handle, data) + nvvm.get_compiled_result(py_handle, data) if logs is not None: - logsize = nvvm.get_program_log_size(self._mnff.handle) + logsize = nvvm.get_program_log_size(py_handle) if logsize > 1: log = bytearray(logsize) - nvvm.get_program_log(self._mnff.handle, log) + nvvm.get_program_log(py_handle, log) logs.write(log.decode("utf-8", errors="backslashreplace")) return ObjectCode._init(data, target_type, name=self._options.name) diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index 5f08172909..d573862d16 100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -8,6 +8,8 @@ from libc.stdint cimport intptr_t from libcpp.memory cimport shared_ptr from cuda.bindings cimport cydriver +from cuda.bindings cimport cynvrtc +from cuda.bindings cimport cynvvm # ============================================================================= @@ -23,6 +25,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": ctypedef shared_ptr[const cydriver.CUdeviceptr] DevicePtrHandle ctypedef shared_ptr[const cydriver.CUlibrary] LibraryHandle ctypedef shared_ptr[const cydriver.CUkernel] KernelHandle + ctypedef shared_ptr[const cynvrtc.nvrtcProgram] NvrtcProgramHandle + ctypedef shared_ptr[const cynvvm.nvvmProgram] NvvmProgramHandle # as_cu() - extract the raw CUDA handle (inline C++) cydriver.CUcontext as_cu(ContextHandle h) noexcept nogil @@ -32,6 +36,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUdeviceptr as_cu(DevicePtrHandle h) noexcept nogil cydriver.CUlibrary as_cu(LibraryHandle h) noexcept nogil cydriver.CUkernel as_cu(KernelHandle h) noexcept nogil + cynvrtc.nvrtcProgram as_cu(NvrtcProgramHandle h) noexcept nogil + cynvvm.nvvmProgram as_cu(NvvmProgramHandle h) noexcept nogil # as_intptr() - extract handle as intptr_t for Python interop (inline C++) intptr_t as_intptr(ContextHandle h) noexcept nogil @@ -41,8 +47,10 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": intptr_t as_intptr(DevicePtrHandle h) noexcept nogil intptr_t as_intptr(LibraryHandle h) noexcept nogil intptr_t as_intptr(KernelHandle h) noexcept nogil + intptr_t as_intptr(NvrtcProgramHandle h) noexcept nogil + intptr_t as_intptr(NvvmProgramHandle h) noexcept nogil - # as_py() - convert handle to Python driver wrapper object (inline C++; requires GIL) + # as_py() - convert handle to Python wrapper object (inline C++; requires GIL) object as_py(ContextHandle h) object as_py(StreamHandle h) object as_py(EventHandle h) @@ -50,6 +58,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": object as_py(DevicePtrHandle h) object as_py(LibraryHandle h) object as_py(KernelHandle h) + object as_py(NvrtcProgramHandle h) + object as_py(NvvmProgramHandle h) # ============================================================================= @@ -112,3 +122,11 @@ cdef LibraryHandle create_library_handle_ref(cydriver.CUlibrary library) except+ cdef KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* name) except+ nogil cdef KernelHandle create_kernel_handle_ref( cydriver.CUkernel kernel, const LibraryHandle& h_library) except+ nogil + +# NVRTC Program handles +cdef NvrtcProgramHandle create_nvrtc_program_handle(cynvrtc.nvrtcProgram prog) except+ nogil +cdef NvrtcProgramHandle create_nvrtc_program_handle_ref(cynvrtc.nvrtcProgram prog) except+ nogil + +# NVVM Program handles +cdef NvvmProgramHandle create_nvvm_program_handle(cynvvm.nvvmProgram prog) except+ nogil +cdef NvvmProgramHandle create_nvvm_program_handle_ref(cynvvm.nvvmProgram prog) except+ nogil diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index 0d3e732a4f..2652d4448e 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -14,6 +14,8 @@ from cpython.pycapsule cimport PyCapsule_GetName, PyCapsule_GetPointer from libc.stddef cimport size_t from cuda.bindings cimport cydriver +from cuda.bindings cimport cynvrtc +from cuda.bindings cimport cynvvm from ._resource_handles cimport ( ContextHandle, @@ -23,9 +25,13 @@ from ._resource_handles cimport ( DevicePtrHandle, LibraryHandle, KernelHandle, + NvrtcProgramHandle, + NvvmProgramHandle, ) import cuda.bindings.cydriver as cydriver +import cuda.bindings.cynvrtc as cynvrtc +import cuda.bindings.cynvvm as cynvvm # ============================================================================= # C++ function declarations (non-inline, implemented in resource_handles.cpp) @@ -107,6 +113,18 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": KernelHandle create_kernel_handle_ref "cuda_core::create_kernel_handle_ref" ( cydriver.CUkernel kernel, const LibraryHandle& h_library) except+ nogil + # NVRTC Program handles + NvrtcProgramHandle create_nvrtc_program_handle "cuda_core::create_nvrtc_program_handle" ( + cynvrtc.nvrtcProgram prog) except+ nogil + NvrtcProgramHandle create_nvrtc_program_handle_ref "cuda_core::create_nvrtc_program_handle_ref" ( + cynvrtc.nvrtcProgram prog) except+ nogil + + # NVVM Program handles + NvvmProgramHandle create_nvvm_program_handle "cuda_core::create_nvvm_program_handle" ( + cynvvm.nvvmProgram prog) except+ nogil + NvvmProgramHandle create_nvvm_program_handle_ref "cuda_core::create_nvvm_program_handle_ref" ( + cynvvm.nvvmProgram prog) except+ nogil + # ============================================================================= # CUDA Driver API capsule @@ -174,6 +192,12 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": void* p_cuLibraryUnload "reinterpret_cast(cuda_core::p_cuLibraryUnload)" void* p_cuLibraryGetKernel "reinterpret_cast(cuda_core::p_cuLibraryGetKernel)" + # NVRTC + void* p_nvrtcDestroyProgram "reinterpret_cast(cuda_core::p_nvrtcDestroyProgram)" + + # NVVM + void* p_nvvmDestroyProgram "reinterpret_cast(cuda_core::p_nvvmDestroyProgram)" + # Initialize driver function pointers from cydriver.__pyx_capi__ at module load cdef void* _get_driver_fn(str name): @@ -223,3 +247,26 @@ p_cuLibraryLoadFromFile = _get_driver_fn("cuLibraryLoadFromFile") p_cuLibraryLoadData = _get_driver_fn("cuLibraryLoadData") p_cuLibraryUnload = _get_driver_fn("cuLibraryUnload") p_cuLibraryGetKernel = _get_driver_fn("cuLibraryGetKernel") + +# ============================================================================= +# NVRTC function pointer initialization +# ============================================================================= + +cdef void* _get_nvrtc_fn(str name): + capsule = cynvrtc.__pyx_capi__[name] + return PyCapsule_GetPointer(capsule, PyCapsule_GetName(capsule)) + +p_nvrtcDestroyProgram = _get_nvrtc_fn("nvrtcDestroyProgram") + +# ============================================================================= +# NVVM function pointer initialization +# +# NVVM may not be available at runtime, so we handle missing function pointers +# gracefully. The C++ deleter checks for null before calling. +# ============================================================================= + +cdef void* _get_nvvm_fn(str name): + capsule = cynvvm.__pyx_capi__[name] + return PyCapsule_GetPointer(capsule, PyCapsule_GetName(capsule)) + +p_nvvmDestroyProgram = _get_nvvm_fn("nvvmDestroyProgram") From c15b12eb64eff5c86219397c6befeea21e7e64f0 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 3 Feb 2026 11:59:58 -0800 Subject: [PATCH 6/7] Add HANDLE_RETURN_NVRTC and HANDLE_RETURN_NVVM, simplify HANDLE_RETURN - Add NVVMError exception class - Add HANDLE_RETURN_NVRTC for nogil NVRTC error handling with program log - Add HANDLE_RETURN_NVVM for nogil NVVM error handling with program log - Remove vestigial supported_error_type fused type - Simplify HANDLE_RETURN to directly take cydriver.CUresult --- cuda_core/cuda/core/_utils/cuda_utils.pxd | 10 ++- cuda_core/cuda/core/_utils/cuda_utils.pyx | 76 +++++++++++++++++++++-- 2 files changed, 76 insertions(+), 10 deletions(-) diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pxd b/cuda_core/cuda/core/_utils/cuda_utils.pxd index 9b5044beda..6cb2f76e73 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pxd +++ b/cuda_core/cuda/core/_utils/cuda_utils.pxd @@ -6,11 +6,7 @@ cimport cpython from cpython.object cimport PyObject from libc.stdint cimport int64_t, int32_t -from cuda.bindings cimport cydriver - - -ctypedef fused supported_error_type: - cydriver.CUresult +from cuda.bindings cimport cydriver, cynvrtc, cynvvm ctypedef fused integer_t: @@ -22,7 +18,9 @@ ctypedef fused integer_t: cdef const cydriver.CUcontext CU_CONTEXT_INVALID = (-2) -cdef int HANDLE_RETURN(supported_error_type err) except?-1 nogil +cdef int HANDLE_RETURN(cydriver.CUresult err) except?-1 nogil +cdef int HANDLE_RETURN_NVRTC(cynvrtc.nvrtcResult err, cynvrtc.nvrtcProgram prog) except?-1 nogil +cdef int HANDLE_RETURN_NVVM(cynvvm.nvvmResult err, cynvvm.nvvmProgram prog) except?-1 nogil # TODO: stop exposing these within the codebase? diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pyx b/cuda_core/cuda/core/_utils/cuda_utils.pyx index c7f867a0d5..0d1d7d3344 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pyx +++ b/cuda_core/cuda/core/_utils/cuda_utils.pyx @@ -20,6 +20,8 @@ except ImportError: from cuda import cudart as runtime from cuda import nvrtc +from cuda.bindings cimport cynvrtc, cynvvm + from cuda.core._utils.driver_cu_result_explanations import DRIVER_CU_RESULT_EXPLANATIONS from cuda.core._utils.runtime_cuda_error_explanations import RUNTIME_CUDA_ERROR_EXPLANATIONS @@ -32,6 +34,10 @@ class NVRTCError(CUDAError): pass +class NVVMError(CUDAError): + pass + + ComputeCapability = namedtuple("ComputeCapability", ("major", "minor")) @@ -57,10 +63,72 @@ def _reduce_3_tuple(t: tuple): return t[0] * t[1] * t[2] -cdef int HANDLE_RETURN(supported_error_type err) except?-1 nogil: - if supported_error_type is cydriver.CUresult: - if err != cydriver.CUresult.CUDA_SUCCESS: - return _check_driver_error(err) +cdef int HANDLE_RETURN(cydriver.CUresult err) except?-1 nogil: + if err != cydriver.CUresult.CUDA_SUCCESS: + return _check_driver_error(err) + return 0 + + +cdef int HANDLE_RETURN_NVRTC(cynvrtc.nvrtcResult err, cynvrtc.nvrtcProgram prog) except?-1 nogil: + """Handle NVRTC result codes, raising NVRTCError with program log on failure.""" + if err == cynvrtc.nvrtcResult.NVRTC_SUCCESS: + return 0 + # Get error string (can be called without GIL) + cdef const char* err_str = cynvrtc.nvrtcGetErrorString(err) + # Get program log size for additional context + cdef size_t logsize = 0 + if prog != NULL: + cynvrtc.nvrtcGetProgramLogSize(prog, &logsize) + # Need GIL for Python string operations and exception raising + with gil: + _raise_nvrtc_error(err, err_str, prog, logsize) + return -1 # Never reached, but satisfies return type + + +cdef int _raise_nvrtc_error(cynvrtc.nvrtcResult err, const char* err_str, + cynvrtc.nvrtcProgram prog, size_t logsize) except -1: + """Helper to raise NVRTCError with program log (requires GIL).""" + cdef bytes log_bytes + cdef str log_str = "" + if logsize > 1 and prog != NULL: + log_bytes = b" " * logsize + if cynvrtc.nvrtcGetProgramLog(prog, log_bytes) == cynvrtc.nvrtcResult.NVRTC_SUCCESS: + log_str = log_bytes.decode("utf-8", errors="backslashreplace") + err_msg = f"{err}: {err_str.decode()}" if err_str != NULL else f"NVRTC error {err}" + if log_str: + err_msg += f", compilation log:\n\n{log_str}" + raise NVRTCError(err_msg) + + +cdef int HANDLE_RETURN_NVVM(cynvvm.nvvmResult err, cynvvm.nvvmProgram prog) except?-1 nogil: + """Handle NVVM result codes, raising NVVMError with program log on failure.""" + if err == cynvvm.nvvmResult.NVVM_SUCCESS: + return 0 + # Get error string (can be called without GIL) + cdef const char* err_str = cynvvm.nvvmGetErrorString(err) + # Get program log size for additional context + cdef size_t logsize = 0 + if prog != NULL: + cynvvm.nvvmGetProgramLogSize(prog, &logsize) + # Need GIL for Python string operations and exception raising + with gil: + _raise_nvvm_error(err, err_str, prog, logsize) + return -1 # Never reached, but satisfies return type + + +cdef int _raise_nvvm_error(cynvvm.nvvmResult err, const char* err_str, + cynvvm.nvvmProgram prog, size_t logsize) except -1: + """Helper to raise NVVMError with program log (requires GIL).""" + cdef bytes log_bytes + cdef str log_str = "" + if logsize > 1 and prog != NULL: + log_bytes = b" " * logsize + if cynvvm.nvvmGetProgramLog(prog, log_bytes) == cynvvm.nvvmResult.NVVM_SUCCESS: + log_str = log_bytes.decode("utf-8", errors="backslashreplace") + err_msg = f"{err}: {err_str.decode()}" if err_str != NULL else f"NVVM error {err}" + if log_str: + err_msg += f", compilation log:\n\n{log_str}" + raise NVVMError(err_msg) cdef object _RUNTIME_SUCCESS = runtime.cudaError_t.cudaSuccess From 95f149ae6f097126f7a3c8f622681c88c364474d Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 3 Feb 2026 16:24:44 -0800 Subject: [PATCH 7/7] Fix build errors, update tests, remove unused imports - Change cdef function return types from ObjectCode to object (Cython limitation) - Remove unused imports: intptr_t, NvrtcProgramHandle, NvvmProgramHandle, as_intptr - Update as_py(NvvmProgramHandle) to return Python int via PyLong_FromSsize_t - Update test assertions: remove handle checks after close(), test idempotency instead - Update NVVM error message regex to match new unified format --- cuda_core/cuda/core/_cpp/resource_handles.hpp | 9 +- cuda_core/cuda/core/_module.pyx | 3 +- cuda_core/cuda/core/_program.pyx | 268 +++++++++++------- cuda_core/cuda/core/_utils/cuda_utils.pxd | 4 +- cuda_core/cuda/core/_utils/cuda_utils.pyx | 4 +- cuda_core/tests/test_module.py | 2 + cuda_core/tests/test_program.py | 9 +- 7 files changed, 178 insertions(+), 121 deletions(-) diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp index fe64fd343d..cb66841172 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.hpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp @@ -11,10 +11,8 @@ #include // Forward declaration for NVVM - avoids nvvm.h dependency -// Use void* to match cuda.bindings.cynvvm's typedef for compatibility -#ifndef CYTHON_EXTERN_C -typedef void *nvvmProgram; -#endif +// Use void* to match cuda.bindings.cynvvm's typedef +using nvvmProgram = void*; namespace cuda_core { @@ -445,7 +443,8 @@ inline PyObject* as_py(const NvrtcProgramHandle& h) noexcept { } inline PyObject* as_py(const NvvmProgramHandle& h) noexcept { - return detail::make_py("cuda.bindings.nvvm", "nvvmProgram", as_intptr(h)); + // NVVM bindings use raw integers, not wrapper classes + return PyLong_FromSsize_t(as_intptr(h)); } } // namespace cuda_core diff --git a/cuda_core/cuda/core/_module.pyx b/cuda_core/cuda/core/_module.pyx index 5508f3f0c6..af3977b15a 100644 --- a/cuda_core/cuda/core/_module.pyx +++ b/cuda_core/cuda/core/_module.pyx @@ -816,7 +816,8 @@ cdef class ObjectCode: try: name = self._sym_map[name] except KeyError: - name = name.encode() + if isinstance(name, str): + name = name.encode() cdef KernelHandle h_kernel = create_kernel_handle(self._h_library, name) if not h_kernel: diff --git a/cuda_core/cuda/core/_program.pyx b/cuda_core/cuda/core/_program.pyx index e30d7f65b4..10743e0b78 100644 --- a/cuda_core/cuda/core/_program.pyx +++ b/cuda_core/cuda/core/_program.pyx @@ -14,16 +14,16 @@ from warnings import warn from cuda.bindings import driver, nvrtc -from libc.stdint cimport intptr_t +from libcpp.vector cimport vector from ._resource_handles cimport ( - NvrtcProgramHandle, - NvvmProgramHandle, - as_intptr, + as_cu, + as_py, create_nvrtc_program_handle, create_nvvm_program_handle, ) from cuda.bindings cimport cynvrtc, cynvvm +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN_NVRTC, HANDLE_RETURN_NVVM from cuda.core._device import Device from cuda.core._linker import Linker, LinkerHandleT, LinkerOptions from cuda.core._module import ObjectCode @@ -40,8 +40,11 @@ from cuda.core._utils.cuda_utils import ( __all__ = ["Program", "ProgramOptions"] -ProgramHandleT = nvrtc.nvrtcProgram | LinkerHandleT -"""Type alias for program handle types across different backends.""" +ProgramHandleT = nvrtc.nvrtcProgram | int | LinkerHandleT +"""Type alias for program handle types across different backends. + +The ``int`` type covers NVVM handles, which don't have a wrapper class. +""" # ============================================================================= @@ -76,8 +79,8 @@ cdef class Program: if self._linker: self._linker.close() # Reset handles - the C++ shared_ptr destructor handles cleanup - self._h_nvrtc = NvrtcProgramHandle() - self._h_nvvm = NvvmProgramHandle() + self._h_nvrtc.reset() + self._h_nvvm.reset() def compile( self, target_type: str, name_expressions: tuple | list = (), logs = None @@ -120,14 +123,11 @@ cdef class Program: handle, call ``int(Program.handle)``. """ if self._backend == "NVRTC": - ptr = as_intptr(self._h_nvrtc) - return nvrtc.nvrtcProgram(ptr) if ptr else None + return as_py(self._h_nvrtc) elif self._backend == "NVVM": - # NVVM uses raw integers for handles, not wrapper classes - ptr = as_intptr(self._h_nvvm) - return ptr if ptr else None + return as_py(self._h_nvvm) # returns int (NVVM uses raw integers) else: - return self._linker.handle if self._linker else None + return self._linker.handle @staticmethod def driver_can_load_nvrtc_ptx_output() -> bool: @@ -392,7 +392,7 @@ class ProgramOptions: def _prepare_nvvm_options(self, as_bytes: bool = True) -> list[bytes] | list[str]: return _prepare_nvvm_options_impl(self, as_bytes) - def as_bytes(self, backend: str) -> list[bytes]: + def as_bytes(self, backend: str, target_type: str | None = None) -> list[bytes]: """Convert program options to bytes format for the specified backend. This method transforms the program options into a format suitable for the @@ -403,6 +403,9 @@ class ProgramOptions: ---------- backend : str The compiler backend to prepare options for. Must be either "nvrtc" or "nvvm". + target_type : str, optional + The compilation target type (e.g., "ptx", "cubin", "ltoir"). Some backends + require additional options based on the target type. Returns ------- @@ -425,7 +428,10 @@ class ProgramOptions: if backend == "nvrtc": return self._prepare_nvrtc_options() elif backend == "nvvm": - return self._prepare_nvvm_options(as_bytes=True) + options = self._prepare_nvvm_options(as_bytes=True) + if target_type == "ltoir" and b"-gen-lto" not in options: + options.append(b"-gen-lto") + return options else: raise ValueError(f"Unknown backend '{backend}'. Must be one of: 'nvrtc', 'nvvm'") @@ -530,15 +536,27 @@ cdef inline object _translate_program_options(object options): cdef inline int Program_init(Program self, object code, str code_type, object options) except -1: """Initialize a Program instance.""" + cdef cynvrtc.nvrtcProgram nvrtc_prog + cdef cynvvm.nvvmProgram nvvm_prog + cdef bytes code_bytes + cdef const char* code_ptr + cdef const char* name_ptr + cdef size_t code_len + self._options = options = check_or_create_options(ProgramOptions, options, "Program options") code_type = code_type.lower() if code_type == "c++": assert_type(code, str) # TODO: support pre-loaded headers & include names - # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved - py_prog = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], [])) - self._h_nvrtc = create_nvrtc_program_handle(int(py_prog)) + code_bytes = code.encode() + code_ptr = code_bytes + name_ptr = options._name + + with nogil: + HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram( + &nvrtc_prog, code_ptr, name_ptr, 0, NULL, NULL)) + self._h_nvrtc = create_nvrtc_program_handle(nvrtc_prog) self._backend = "NVRTC" self._linker = None @@ -550,15 +568,21 @@ cdef inline int Program_init(Program self, object code, str code_type, object op self._backend = self._linker.backend elif code_type == "nvvm": + _get_nvvm_module() # Validate NVVM availability if isinstance(code, str): code = code.encode("utf-8") elif not isinstance(code, (bytes, bytearray)): raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray") - nvvm = _get_nvvm_module() - py_prog = nvvm.create_program() - nvvm.add_module_to_program(py_prog, code, len(code), options._name.decode()) - self._h_nvvm = create_nvvm_program_handle(int(py_prog)) + code_ptr = (code) + name_ptr = options._name + code_len = len(code) + + with nogil: + HANDLE_RETURN_NVVM(NULL, cynvvm.nvvmCreateProgram(&nvvm_prog)) + self._h_nvvm = create_nvvm_program_handle(nvvm_prog) # RAII from here + with nogil: + HANDLE_RETURN_NVVM(nvvm_prog, cynvvm.nvvmAddModuleToProgram(nvvm_prog, code_ptr, code_len, name_ptr)) self._backend = "NVVM" self._linker = None @@ -571,115 +595,149 @@ cdef inline int Program_init(Program self, object code, str code_type, object op cdef object Program_compile_nvrtc(Program self, str target_type, object name_expressions, object logs): - """Compile using NVRTC backend.""" - if target_type == "ptx" and not _can_load_generated_ptx(): - warn( - "The CUDA driver version is older than the backend version. " - "The generated ptx will not be loadable by the current driver.", - stacklevel=2, - category=RuntimeWarning, - ) - - # Create Python wrapper for handle_return calls that need it - py_handle = nvrtc.nvrtcProgram(as_intptr(self._h_nvrtc)) - + """Compile using NVRTC backend and return ObjectCode.""" + cdef cynvrtc.nvrtcProgram prog = as_cu(self._h_nvrtc) + cdef size_t output_size = 0 + cdef size_t logsize = 0 + cdef vector[const char*] options_vec + cdef char* data_ptr = NULL + cdef bytes name_bytes + cdef const char* name_ptr = NULL + cdef const char* lowered_name = NULL + cdef dict symbol_mapping = {} + + # Add name expressions before compilation if name_expressions: for n in name_expressions: - handle_return( - nvrtc.nvrtcAddNameExpression(py_handle, n.encode()), - handle=py_handle, - ) - - options = self._options.as_bytes("nvrtc") - handle_return( - nvrtc.nvrtcCompileProgram(py_handle, len(options), options), - handle=py_handle, - ) - - size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size") - comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}") - size = handle_return(size_func(py_handle), handle=py_handle) - data = b" " * size - handle_return(comp_func(py_handle, data), handle=py_handle) - - symbol_mapping = {} + name_bytes = n.encode() if isinstance(n, str) else n + name_ptr = name_bytes + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcAddNameExpression(prog, name_ptr)) + + # Build options array + options_list = self._options.as_bytes("nvrtc", target_type) + options_vec.resize(len(options_list)) + for i in range(len(options_list)): + options_vec[i] = (options_list[i]) + + # Compile + with nogil: + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcCompileProgram(prog, options_vec.size(), options_vec.data())) + + # Get compiled output based on target type + if target_type == "ptx": + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetPTXSize(prog, &output_size)) + data = bytearray(output_size) + data_ptr = (data) + with nogil: + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetPTX(prog, data_ptr)) + elif target_type == "cubin": + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetCUBINSize(prog, &output_size)) + data = bytearray(output_size) + data_ptr = (data) + with nogil: + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetCUBIN(prog, data_ptr)) + else: # ltoir + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLTOIRSize(prog, &output_size)) + data = bytearray(output_size) + data_ptr = (data) + with nogil: + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLTOIR(prog, data_ptr)) + + # Get lowered names after compilation if name_expressions: for n in name_expressions: - symbol_mapping[n] = handle_return( - nvrtc.nvrtcGetLoweredName(py_handle, n.encode()), handle=py_handle - ) + name_bytes = n.encode() if isinstance(n, str) else n + name_ptr = name_bytes + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLoweredName(prog, name_ptr, &lowered_name)) + symbol_mapping[n] = lowered_name if lowered_name != NULL else None + # Get compilation log if requested if logs is not None: - logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(py_handle), handle=py_handle) + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetProgramLogSize(prog, &logsize)) if logsize > 1: - log = b" " * logsize - handle_return(nvrtc.nvrtcGetProgramLog(py_handle, log), handle=py_handle) + log = bytearray(logsize) + data_ptr = (log) + with nogil: + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetProgramLog(prog, data_ptr)) logs.write(log.decode("utf-8", errors="backslashreplace")) - return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping, name=self._options.name) + return ObjectCode._init(bytes(data), target_type, symbol_mapping=symbol_mapping, name=self._options.name) cdef object Program_compile_nvvm(Program self, str target_type, object logs): - """Compile using NVVM backend.""" - if target_type not in ("ptx", "ltoir"): - raise ValueError(f'NVVM backend only supports target_type="ptx", "ltoir", got "{target_type}"') - - # TODO: flip to True when NVIDIA/cuda-python#1354 is resolved and CUDA 12 is dropped - nvvm_options = self._options._prepare_nvvm_options(as_bytes=False) - if target_type == "ltoir" and "-gen-lto" not in nvvm_options: - nvvm_options.append("-gen-lto") - - nvvm = _get_nvvm_module() - # NVVM uses raw integers for handles - py_handle = as_intptr(self._h_nvvm) - - try: - nvvm.verify_program(py_handle, len(nvvm_options), nvvm_options) - nvvm.compile_program(py_handle, len(nvvm_options), nvvm_options) - except Exception as e: - # Capture NVVM program log on error - error_log = "" - try: - logsize = nvvm.get_program_log_size(py_handle) - if logsize > 1: - log = bytearray(logsize) - nvvm.get_program_log(py_handle, log) - error_log = log.decode("utf-8", errors="backslashreplace") - except Exception: - pass - e.args = (e.args[0] + (f"\nNVVM program log: {error_log}" if error_log else ""), *e.args[1:]) - raise - - size = nvvm.get_compiled_result_size(py_handle) - data = bytearray(size) - nvvm.get_compiled_result(py_handle, data) - + """Compile using NVVM backend and return ObjectCode.""" + cdef cynvvm.nvvmProgram prog = as_cu(self._h_nvvm) + cdef size_t output_size = 0 + cdef size_t logsize = 0 + cdef vector[const char*] options_vec + cdef char* data_ptr = NULL + + # Build options array + options_list = self._options.as_bytes("nvvm", target_type) + options_vec.resize(len(options_list)) + for i in range(len(options_list)): + options_vec[i] = (options_list[i]) + + # Compile + with nogil: + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmVerifyProgram(prog, options_vec.size(), options_vec.data())) + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmCompileProgram(prog, options_vec.size(), options_vec.data())) + + # Get compiled result + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetCompiledResultSize(prog, &output_size)) + data = bytearray(output_size) + data_ptr = (data) + with nogil: + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetCompiledResult(prog, data_ptr)) + + # Get compilation log if requested if logs is not None: - logsize = nvvm.get_program_log_size(py_handle) + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetProgramLogSize(prog, &logsize)) if logsize > 1: log = bytearray(logsize) - nvvm.get_program_log(py_handle, log) + data_ptr = (log) + with nogil: + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetProgramLog(prog, data_ptr)) logs.write(log.decode("utf-8", errors="backslashreplace")) - return ObjectCode._init(data, target_type, name=self._options.name) + return ObjectCode._init(bytes(data), target_type, name=self._options.name) + +# Supported target types per backend +cdef dict SUPPORTED_TARGETS = { + "NVRTC": ("ptx", "cubin", "ltoir"), + "NVVM": ("ptx", "ltoir"), + "nvJitLink": ("cubin", "ptx"), + "driver": ("cubin", "ptx"), +} cdef object Program_compile(Program self, str target_type, object name_expressions, object logs): """Compile the program to the specified target type.""" - supported_target_types = ("ptx", "cubin", "ltoir") - if target_type not in supported_target_types: - raise ValueError(f'Unsupported target_type="{target_type}" ({supported_target_types=})') + # Validate target_type for this backend + supported = SUPPORTED_TARGETS.get(self._backend) + if supported is None: + raise ValueError(f'Unknown backend="{self._backend}"') + if target_type not in supported: + raise ValueError( + f'Unsupported target_type="{target_type}" for {self._backend} ' + f'(supported: {", ".join(repr(t) for t in supported)})' + ) if self._backend == "NVRTC": + if target_type == "ptx" and not _can_load_generated_ptx(): + warn( + "The CUDA driver version is older than the backend version. " + "The generated ptx will not be loadable by the current driver.", + stacklevel=2, + category=RuntimeWarning, + ) return Program_compile_nvrtc(self, target_type, name_expressions, logs) + elif self._backend == "NVVM": return Program_compile_nvvm(self, target_type, logs) - # Linker backend (PTX code type) - supported_backends = ("nvJitLink", "driver") - if self._backend not in supported_backends: - raise ValueError(f'Unsupported backend="{self._backend}" ({supported_backends=})') - return self._linker.link(target_type) + else: + return self._linker.link(target_type) cdef inline list _prepare_nvrtc_options_impl(object opts): diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pxd b/cuda_core/cuda/core/_utils/cuda_utils.pxd index 6cb2f76e73..339b485682 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pxd +++ b/cuda_core/cuda/core/_utils/cuda_utils.pxd @@ -19,8 +19,8 @@ cdef const cydriver.CUcontext CU_CONTEXT_INVALID = (-2) cdef int HANDLE_RETURN(cydriver.CUresult err) except?-1 nogil -cdef int HANDLE_RETURN_NVRTC(cynvrtc.nvrtcResult err, cynvrtc.nvrtcProgram prog) except?-1 nogil -cdef int HANDLE_RETURN_NVVM(cynvvm.nvvmResult err, cynvvm.nvvmProgram prog) except?-1 nogil +cdef int HANDLE_RETURN_NVRTC(cynvrtc.nvrtcProgram prog, cynvrtc.nvrtcResult err) except?-1 nogil +cdef int HANDLE_RETURN_NVVM(cynvvm.nvvmProgram prog, cynvvm.nvvmResult err) except?-1 nogil # TODO: stop exposing these within the codebase? diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pyx b/cuda_core/cuda/core/_utils/cuda_utils.pyx index 0d1d7d3344..a3c49d8e27 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pyx +++ b/cuda_core/cuda/core/_utils/cuda_utils.pyx @@ -69,7 +69,7 @@ cdef int HANDLE_RETURN(cydriver.CUresult err) except?-1 nogil: return 0 -cdef int HANDLE_RETURN_NVRTC(cynvrtc.nvrtcResult err, cynvrtc.nvrtcProgram prog) except?-1 nogil: +cdef int HANDLE_RETURN_NVRTC(cynvrtc.nvrtcProgram prog, cynvrtc.nvrtcResult err) except?-1 nogil: """Handle NVRTC result codes, raising NVRTCError with program log on failure.""" if err == cynvrtc.nvrtcResult.NVRTC_SUCCESS: return 0 @@ -100,7 +100,7 @@ cdef int _raise_nvrtc_error(cynvrtc.nvrtcResult err, const char* err_str, raise NVRTCError(err_msg) -cdef int HANDLE_RETURN_NVVM(cynvvm.nvvmResult err, cynvvm.nvvmProgram prog) except?-1 nogil: +cdef int HANDLE_RETURN_NVVM(cynvvm.nvvmProgram prog, cynvvm.nvvmResult err) except?-1 nogil: """Handle NVVM result codes, raising NVVMError with program log on failure.""" if err == cynvvm.nvvmResult.NVVM_SUCCESS: return 0 diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index e237cf5474..c2c2ee1925 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -104,6 +104,8 @@ def test_get_kernel(init_cuda): kernel = object_code.get_kernel("ABC") assert object_code.handle is not None assert kernel.handle is not None + # Also works with bytes + assert object_code.get_kernel(b"ABC").handle is not None @pytest.mark.parametrize( diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 259e6a9c98..abf29ae1f3 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -284,7 +284,6 @@ def test_cpp_program_with_various_options(init_cuda, options): assert program.backend == "NVRTC" program.compile("ptx") program.close() - assert program.handle is None @pytest.mark.skipif( @@ -299,7 +298,6 @@ def test_cpp_program_with_trace_option(init_cuda, tmp_path): assert program.backend == "NVRTC" program.compile("ptx") program.close() - assert program.handle is None @pytest.mark.skipif((_get_nvrtc_version_for_tests() or 0) < 12800, reason="PCH requires NVRTC >= 12.8") @@ -314,7 +312,6 @@ def test_cpp_program_with_pch_options(init_cuda, tmp_path): assert program.backend == "NVRTC" program.compile("ptx") program.close() - assert program.handle is None options = [ @@ -339,7 +336,6 @@ def test_ptx_program_with_various_options(init_cuda, ptx_code_object, options): assert program.backend == ("driver" if is_culink_backend else "nvJitLink") program.compile("cubin") program.close() - assert program.handle is None def test_program_init_valid_code_type(): @@ -409,7 +405,8 @@ def test_program_close(): code = 'extern "C" __global__ void my_kernel() {}' program = Program(code, "c++") program.close() - assert program.handle is None + # close() is idempotent + program.close() @nvvm_available @@ -441,7 +438,7 @@ def test_nvvm_program_creation_compilation(nvvm_ir): def test_nvvm_compile_invalid_target(nvvm_ir): """Test that NVVM programs reject invalid compilation targets""" program = Program(nvvm_ir, "nvvm") - with pytest.raises(ValueError, match='NVVM backend only supports target_type="ptx"'): + with pytest.raises(ValueError, match='Unsupported target_type="cubin" for NVVM'): program.compile("cubin") program.close()