diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp
index 581a7af69a..27e76b817a 100644
--- a/cuda_core/cuda/core/_cpp/resource_handles.cpp
+++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp
@@ -56,6 +56,12 @@ decltype(&cuLibraryLoadData) p_cuLibraryLoadData = nullptr;
 decltype(&cuLibraryUnload) p_cuLibraryUnload = nullptr;
 decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr;
 
+// NVRTC function pointers
+decltype(&nvrtcDestroyProgram) p_nvrtcDestroyProgram = nullptr;
+
+// NVVM function pointers (may be null if NVVM is not available)
+NvvmDestroyProgramFn p_nvvmDestroyProgram = nullptr;
+
 // ============================================================================
 // GIL management helpers
 // ============================================================================
@@ -764,4 +770,64 @@ KernelHandle create_kernel_handle_ref(CUkernel kernel, const LibraryHandle& h_li
     return KernelHandle(box, &box->resource);
 }
 
+// ============================================================================
+// NVRTC Program Handles
+// ============================================================================
+
+namespace {
+struct NvrtcProgramBox {
+    nvrtcProgram resource;
+};
+}  // namespace
+
+NvrtcProgramHandle create_nvrtc_program_handle(nvrtcProgram prog) {
+    auto box = std::shared_ptr<NvrtcProgramBox>(
+        new NvrtcProgramBox{prog},
+        [](NvrtcProgramBox* b) {
+            // Note: nvrtcDestroyProgram takes nvrtcProgram* and nulls it,
+            // but we're deleting the box anyway so nulling is harmless.
+            // Errors are ignored (standard destructor practice).
+            p_nvrtcDestroyProgram(&b->resource);
+            delete b;
+        }
+    );
+    return NvrtcProgramHandle(box, &box->resource);
+}
+
+NvrtcProgramHandle create_nvrtc_program_handle_ref(nvrtcProgram prog) {
+    auto box = std::make_shared<NvrtcProgramBox>(NvrtcProgramBox{prog});
+    return NvrtcProgramHandle(box, &box->resource);
+}
+
+// ============================================================================
+// NVVM Program Handles
+// ============================================================================
+
+namespace {
+struct NvvmProgramBox {
+    nvvmProgram resource;
+};
+}  // namespace
+
+NvvmProgramHandle create_nvvm_program_handle(nvvmProgram prog) {
+    auto box = std::shared_ptr<NvvmProgramBox>(
+        new NvvmProgramBox{prog},
+        [](NvvmProgramBox* b) {
+            // Note: nvvmDestroyProgram takes nvvmProgram* and nulls it,
+            // but we're deleting the box anyway so nulling is harmless.
+            // If NVVM is not available, the function pointer is null.
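
The factories above follow the same owning/non-owning pattern as the other CUDA handles: both return a shared_ptr aliased to the boxed raw handle, and only the owning variant installs a destroying deleter. A minimal caller-side sketch of the difference (hypothetical usage, not part of this patch; it assumes a `prog` previously created via nvrtcCreateProgram, and uses only the factory and `as_cu` names declared in resource_handles.hpp):

void owning_example(nvrtcProgram prog) {
    // Copies share ownership of the boxed handle.
    auto h = cuda_core::create_nvrtc_program_handle(prog);
    auto h2 = h;  // refcount == 2; both dereference to 'prog'
}                 // refcount hits 0 here -> deleter calls nvrtcDestroyProgram

void borrowed_example(nvrtcProgram prog) {
    // Non-owning: releasing the handle never destroys the program.
    auto h = cuda_core::create_nvrtc_program_handle_ref(prog);
    nvrtcProgram raw = cuda_core::as_cu(h);  // same underlying handle
    (void)raw;
}                 // 'prog' is still alive; the caller must destroy it
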
+            if (p_nvvmDestroyProgram) {
+                p_nvvmDestroyProgram(&b->resource);
+            }
+            delete b;
+        }
+    );
+    return NvvmProgramHandle(box, &box->resource);
+}
+
+NvvmProgramHandle create_nvvm_program_handle_ref(nvvmProgram prog) {
+    auto box = std::make_shared<NvvmProgramBox>(NvvmProgramBox{prog});
+    return NvvmProgramHandle(box, &box->resource);
+}
+
 }  // namespace cuda_core
diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp
index b6118a07a2..cb66841172 100644
--- a/cuda_core/cuda/core/_cpp/resource_handles.hpp
+++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp
@@ -6,9 +6,14 @@
 #include <Python.h>
 #include <cuda.h>
+#include <nvrtc.h>
 #include <cstdint>
 #include <memory>
 
+// Forward declaration for NVVM - avoids nvvm.h dependency
+// Use void* to match cuda.bindings.cynvvm's typedef
+using nvvmProgram = void*;
+
 namespace cuda_core {
 
 // ============================================================================
@@ -67,6 +72,28 @@ extern decltype(&cuLibraryLoadData) p_cuLibraryLoadData;
 extern decltype(&cuLibraryUnload) p_cuLibraryUnload;
 extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel;
 
+// ============================================================================
+// NVRTC function pointers
+//
+// These are populated by _resource_handles.pyx at module import time using
+// function pointers extracted from cuda.bindings.cynvrtc.__pyx_capi__.
+// ============================================================================
+
+extern decltype(&nvrtcDestroyProgram) p_nvrtcDestroyProgram;
+
+// ============================================================================
+// NVVM function pointers
+//
+// These are populated by _resource_handles.pyx at module import time using
+// function pointers extracted from cuda.bindings.cynvvm.__pyx_capi__.
+// Note: May be null if NVVM is not available at runtime.
+// ============================================================================
+
+// Function pointer type for nvvmDestroyProgram (avoids nvvm.h dependency)
+// Signature: nvvmResult nvvmDestroyProgram(nvvmProgram *prog)
+using NvvmDestroyProgramFn = int (*)(nvvmProgram*);
+extern NvvmDestroyProgramFn p_nvvmDestroyProgram;
+
 // ============================================================================
 // Handle type aliases - expose only the raw CUDA resource
 // ============================================================================
@@ -77,6 +104,8 @@ using EventHandle = std::shared_ptr<CUevent>;
 using MemoryPoolHandle = std::shared_ptr<CUmemoryPool>;
 using LibraryHandle = std::shared_ptr<CUlibrary>;
 using KernelHandle = std::shared_ptr<CUkernel>;
+using NvrtcProgramHandle = std::shared_ptr<nvrtcProgram>;
+using NvvmProgramHandle = std::shared_ptr<nvvmProgram>;
 
 // ============================================================================
 // Context handle functions
@@ -260,6 +289,33 @@ KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* na
 // Use for borrowed kernels. The library handle keeps the library alive.
 KernelHandle create_kernel_handle_ref(CUkernel kernel, const LibraryHandle& h_library);
 
+// ============================================================================
+// NVRTC Program handle functions
+// ============================================================================
+
+// Create an owning NVRTC program handle.
+// When the last reference is released, nvrtcDestroyProgram is called.
+// Use this to wrap a program created via nvrtcCreateProgram.
+NvrtcProgramHandle create_nvrtc_program_handle(nvrtcProgram prog);
+
+// Create a non-owning NVRTC program handle (references existing program).
+// The program will NOT be destroyed when the handle is released.
+NvrtcProgramHandle create_nvrtc_program_handle_ref(nvrtcProgram prog);
+
+// ============================================================================
+// NVVM Program handle functions
+// ============================================================================
+
+// Create an owning NVVM program handle.
+// When the last reference is released, nvvmDestroyProgram is called.
+// Use this to wrap a program created via nvvmCreateProgram.
+// Note: If NVVM is not available (p_nvvmDestroyProgram is null), the deleter is a no-op.
+NvvmProgramHandle create_nvvm_program_handle(nvvmProgram prog);
+
+// Create a non-owning NVVM program handle (references existing program).
+// The program will NOT be destroyed when the handle is released.
+NvvmProgramHandle create_nvvm_program_handle_ref(nvvmProgram prog);
+
 // ============================================================================
 // Overloaded helper functions to extract raw resources from handles
 // ============================================================================
@@ -293,6 +349,14 @@ inline CUkernel as_cu(const KernelHandle& h) noexcept {
     return h ? *h : nullptr;
 }
 
+inline nvrtcProgram as_cu(const NvrtcProgramHandle& h) noexcept {
+    return h ? *h : nullptr;
+}
+
+inline nvvmProgram as_cu(const NvvmProgramHandle& h) noexcept {
+    return h ? *h : nullptr;
+}
+
 // as_intptr() - extract handle as intptr_t for Python interop
 // Using signed intptr_t per C standard convention and issue #1342
 inline std::intptr_t as_intptr(const ContextHandle& h) noexcept {
@@ -323,11 +387,19 @@ inline std::intptr_t as_intptr(const KernelHandle& h) noexcept {
     return reinterpret_cast<std::intptr_t>(as_cu(h));
 }
 
-// as_py() - convert handle to Python driver wrapper object (returns new reference)
+inline std::intptr_t as_intptr(const NvrtcProgramHandle& h) noexcept {
+    return reinterpret_cast<std::intptr_t>(as_cu(h));
+}
+
+inline std::intptr_t as_intptr(const NvvmProgramHandle& h) noexcept {
+    return reinterpret_cast<std::intptr_t>(as_cu(h));
+}
+
+// as_py() - convert handle to Python wrapper object (returns new reference)
 namespace detail {
 // n.b. class lookup is not cached to avoid deadlock hazard, see DESIGN.md
-inline PyObject* make_py(const char* class_name, std::intptr_t value) noexcept {
-    PyObject* mod = PyImport_ImportModule("cuda.bindings.driver");
+inline PyObject* make_py(const char* module_name, const char* class_name, std::intptr_t value) noexcept {
+    PyObject* mod = PyImport_ImportModule(module_name);
     if (!mod) return nullptr;
     PyObject* cls = PyObject_GetAttrString(mod, class_name);
     Py_DECREF(mod);
@@ -339,31 +411,40 @@ }  // namespace detail
 
 inline PyObject* as_py(const ContextHandle& h) noexcept {
-    return detail::make_py("CUcontext", as_intptr(h));
+    return detail::make_py("cuda.bindings.driver", "CUcontext", as_intptr(h));
 }
 
 inline PyObject* as_py(const StreamHandle& h) noexcept {
-    return detail::make_py("CUstream", as_intptr(h));
+    return detail::make_py("cuda.bindings.driver", "CUstream", as_intptr(h));
 }
 
 inline PyObject* as_py(const EventHandle& h) noexcept {
-    return detail::make_py("CUevent", as_intptr(h));
+    return detail::make_py("cuda.bindings.driver", "CUevent", as_intptr(h));
 }
 
 inline PyObject* as_py(const MemoryPoolHandle& h) noexcept {
-    return detail::make_py("CUmemoryPool", as_intptr(h));
+    return detail::make_py("cuda.bindings.driver", "CUmemoryPool", as_intptr(h));
 }
 
 inline PyObject* as_py(const DevicePtrHandle& h) noexcept {
-    return detail::make_py("CUdeviceptr", as_intptr(h));
+    return detail::make_py("cuda.bindings.driver", "CUdeviceptr", as_intptr(h));
 }
 
 inline PyObject* as_py(const LibraryHandle& h) noexcept {
-    return detail::make_py("CUlibrary", as_intptr(h));
+    return detail::make_py("cuda.bindings.driver", "CUlibrary", as_intptr(h));
 }
 
 inline PyObject* as_py(const KernelHandle& h) noexcept {
-    return detail::make_py("CUkernel", as_intptr(h));
+    return detail::make_py("cuda.bindings.driver", "CUkernel", as_intptr(h));
+}
+
+inline PyObject* as_py(const NvrtcProgramHandle& h) noexcept {
+    return detail::make_py("cuda.bindings.nvrtc", "nvrtcProgram", as_intptr(h));
+}
+
+inline PyObject* as_py(const NvvmProgramHandle& h) noexcept {
+    // NVVM bindings use raw integers, not wrapper classes
+    return PyLong_FromSsize_t(as_intptr(h));
 }
 
 }  // namespace cuda_core
diff --git a/cuda_core/cuda/core/_module.pyx b/cuda_core/cuda/core/_module.pyx
index 5508f3f0c6..af3977b15a 100644
--- a/cuda_core/cuda/core/_module.pyx
+++ b/cuda_core/cuda/core/_module.pyx
@@ -816,7 +816,8 @@ cdef class ObjectCode:
         try:
             name = self._sym_map[name]
         except KeyError:
-            name = name.encode()
+            if isinstance(name, str):
+                name = name.encode()
 
         cdef KernelHandle h_kernel = create_kernel_handle(self._h_library, name)
         if not h_kernel:
diff --git a/cuda_core/cuda/core/_program.pxd b/cuda_core/cuda/core/_program.pxd
new file mode 100644
index 0000000000..92d30f8c0c
--- /dev/null
+++ b/cuda_core/cuda/core/_program.pxd
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from ._resource_handles cimport NvrtcProgramHandle, NvvmProgramHandle
+
+
+cdef class Program:
+    cdef:
+        NvrtcProgramHandle _h_nvrtc
+        NvvmProgramHandle _h_nvvm
+        str _backend
+        object _linker   # Linker
+        object _options  # ProgramOptions
+        object __weakref__
diff --git a/cuda_core/cuda/core/_program.py b/cuda_core/cuda/core/_program.py
deleted file mode 100644
index 1ef1aa51f5..0000000000
--- a/cuda_core/cuda/core/_program.py
+++ /dev/null
@@ -1,860 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-import weakref
-from contextlib import contextmanager
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Union
-from warnings import warn
-
-if TYPE_CHECKING:
-    import cuda.bindings
-
-from cuda.core._device import Device
-from cuda.core._linker import Linker, LinkerHandleT, LinkerOptions
-from cuda.core._module import ObjectCode
-from cuda.core._utils.clear_error_support import assert_type
-from cuda.core._utils.cuda_utils import (
-    CUDAError,
-    _handle_boolean_option,
-    check_or_create_options,
-    driver,
-    get_binding_version,
-    handle_return,
-    is_nested_sequence,
-    is_sequence,
-    nvrtc,
-)
-
-
-@contextmanager
-def _nvvm_exception_manager(self):
-    """
-    Taken from _linker.py
-    """
-    try:
-        yield
-    except Exception as e:
-        error_log = ""
-        if hasattr(self, "_mnff"):
-            try:
-                nvvm = _get_nvvm_module()
-                logsize = nvvm.get_program_log_size(self._mnff.handle)
-                if logsize > 1:
-                    log = bytearray(logsize)
-                    nvvm.get_program_log(self._mnff.handle, log)
-                    error_log = log.decode("utf-8", errors="backslashreplace")
-            except Exception:
-                error_log = ""
-        # Starting Python 3.11 we could also use Exception.add_note() for the same purpose, but
-        # unfortunately we are still supporting Python 3.10...
-        e.args = (e.args[0] + (f"\nNVVM program log: {error_log}" if error_log else ""), *e.args[1:])
-        raise e
-
-
-_nvvm_module = None
-_nvvm_import_attempted = False
-
-
-def _get_nvvm_module():
-    """
-    Handles the import of NVVM module with version and availability checks.
-    NVVM bindings were added in cuda-bindings 12.9.0, so we need to handle cases where:
-    1. cuda.bindings is not new enough (< 12.9.0)
-    2. libnvvm is not found in the Python environment
-
-    Returns:
-        The nvvm module if available and working
-
-    Raises:
-        RuntimeError: If NVVM is not available due to version or library issues
-    """
-    global _nvvm_module, _nvvm_import_attempted
-
-    if _nvvm_import_attempted:
-        if _nvvm_module is None:
-            raise RuntimeError("NVVM module is not available (previous import attempt failed)")
-        return _nvvm_module
-
-    _nvvm_import_attempted = True
-
-    try:
-        version = get_binding_version()
-        if version < (12, 9):
-            raise RuntimeError(
-                f"NVVM bindings require cuda-bindings >= 12.9.0, but found {version[0]}.{version[1]}.x. "
-                "Please update cuda-bindings to use NVVM features."
-            )
-
-        from cuda.bindings import nvvm
-        from cuda.bindings._internal.nvvm import _inspect_function_pointer
-
-        if _inspect_function_pointer("__nvvmCreateProgram") == 0:
-            raise RuntimeError("NVVM library (libnvvm) is not available in this Python environment. ")
-
-        _nvvm_module = nvvm
-        return _nvvm_module
-
-    except RuntimeError as e:
-        _nvvm_module = None
-        raise e
-
-
-def _process_define_macro_inner(formatted_options, macro):
-    if isinstance(macro, str):
-        formatted_options.append(f"--define-macro={macro}")
-        return True
-    if isinstance(macro, tuple):
-        if len(macro) != 2 or any(not isinstance(val, str) for val in macro):
-            raise RuntimeError(f"Expected define_macro tuple[str, str], got {macro}")
-        formatted_options.append(f"--define-macro={macro[0]}={macro[1]}")
-        return True
-    return False
-
-
-def _process_define_macro(formatted_options, macro):
-    union_type = "Union[str, tuple[str, str]]"
-    if _process_define_macro_inner(formatted_options, macro):
-        return
-    if is_nested_sequence(macro):
-        for seq_macro in macro:
-            if not _process_define_macro_inner(formatted_options, seq_macro):
-                raise RuntimeError(f"Expected define_macro {union_type}, got {seq_macro}")
-        return
-    raise RuntimeError(f"Expected define_macro {union_type}, list[{union_type}], got {macro}")
-
-
-@dataclass
-class ProgramOptions:
-    """Customizable options for configuring `Program`.
-
-    Attributes
-    ----------
-    name : str, optional
-        Name of the program. If the compilation succeeds, the name is passed down to the generated `ObjectCode`.
-    arch : str, optional
-        Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or
-        ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture
-        will be used.
-    relocatable_device_code : bool, optional
-        Enable (disable) the generation of relocatable device code.
-        Default: False
-    extensible_whole_program : bool, optional
-        Do extensible whole program compilation of device code.
-        Default: False
-    debug : bool, optional
-        Generate debug information. If --dopt is not specified, then turns off all optimizations.
-        Default: False
-    lineinfo: bool, optional
-        Generate line-number information.
-        Default: False
-    device_code_optimize : bool, optional
-        Enable device code optimization. When specified along with ‘-G’, enables limited debug information generation
-        for optimized device code.
-        Default: None
-    ptxas_options : Union[str, list[str]], optional
-        Specify one or more options directly to ptxas, the PTX optimizing assembler. Options should be strings.
-        For example ["-v", "-O2"].
-        Default: None
-    max_register_count : int, optional
-        Specify the maximum amount of registers that GPU functions can use.
-        Default: None
-    ftz : bool, optional
-        When performing single-precision floating-point operations, flush denormal values to zero or preserve denormal
-        values.
-        Default: False
-    prec_sqrt : bool, optional
-        For single-precision floating-point square root, use IEEE round-to-nearest mode or use a faster approximation.
-        Default: True
-    prec_div : bool, optional
-        For single-precision floating-point division and reciprocals, use IEEE round-to-nearest mode or use a faster
-        approximation.
-        Default: True
-    fma : bool, optional
-        Enables (disables) the contraction of floating-point multiplies and adds/subtracts into floating-point
-        multiply-add operations.
-        Default: True
-    use_fast_math : bool, optional
-        Make use of fast math operations.
-        Default: False
-    extra_device_vectorization : bool, optional
-        Enables more aggressive device code vectorization in the NVVM optimizer.
-        Default: False
-    link_time_optimization : bool, optional
-        Generate intermediate code for later link-time optimization.
-        Default: False
-    gen_opt_lto : bool, optional
-        Run the optimizer passes before generating the LTO IR.
-        Default: False
-    define_macro : Union[str, tuple[str, str], list[Union[str, tuple[str, str]]]], optional
-        Predefine a macro. Can be either a string, in which case that macro will be set to 1, a 2 element tuple of
-        strings, in which case the first element is defined as the second, or a list of strings or tuples.
-        Default: None
-    undefine_macro : Union[str, list[str]], optional
-        Cancel any previous definition of a macro, or list of macros.
-        Default: None
-    include_path : Union[str, list[str]], optional
-        Add the directory or directories to the list of directories to be searched for headers.
-        Default: None
-    pre_include : Union[str, list[str]], optional
-        Preinclude one or more headers during preprocessing. Can be either a string or a list of strings.
-        Default: None
-    no_source_include : bool, optional
-        Disable the default behavior of adding the directory of each input source to the include path.
-        Default: False
-    std : str, optional
-        Set language dialect to C++03, C++11, C++14, C++17 or C++20.
-        Default: c++17
-    builtin_move_forward : bool, optional
-        Provide builtin definitions of std::move and std::forward.
-        Default: True
-    builtin_initializer_list : bool, optional
-        Provide builtin definitions of std::initializer_list class and member functions.
-        Default: True
-    disable_warnings : bool, optional
-        Inhibit all warning messages.
-        Default: False
-    restrict : bool, optional
-        Programmer assertion that all kernel pointer parameters are restrict pointers.
-        Default: False
-    device_as_default_execution_space : bool, optional
-        Treat entities with no execution space annotation as __device__ entities.
-        Default: False
-    device_int128 : bool, optional
-        Allow the __int128 type in device code.
-        Default: False
-    optimization_info : str, optional
-        Provide optimization reports for the specified kind of optimization.
-        Default: None
-    no_display_error_number : bool, optional
-        Disable the display of a diagnostic number for warning messages.
-        Default: False
-    diag_error : Union[int, list[int]], optional
-        Emit error for a specified diagnostic message number or comma separated list of numbers.
-        Default: None
-    diag_suppress : Union[int, list[int]], optional
-        Suppress a specified diagnostic message number or comma separated list of numbers.
-        Default: None
-    diag_warn : Union[int, list[int]], optional
-        Emit warning for a specified diagnostic message number or comma separated lis of numbers.
-        Default: None
-    brief_diagnostics : bool, optional
-        Disable or enable showing source line and column info in a diagnostic.
-        Default: False
-    time : str, optional
-        Generate a CSV table with the time taken by each compilation phase.
-        Default: None
-    split_compile : int, optional
-        Perform compiler optimizations in parallel.
-        Default: 1
-    fdevice_syntax_only : bool, optional
-        Ends device compilation after front-end syntax checking.
-        Default: False
-    minimal : bool, optional
-        Omit certain language features to reduce compile time for small programs.
-        Default: False
-    no_cache : bool, optional
-        Disable compiler caching.
-        Default: False
-    fdevice_time_trace : str, optional
-        Generate time trace JSON for profiling compilation (NVRTC only).
-        Default: None
-    device_float128 : bool, optional
-        Allow __float128 type in device code (NVRTC only).
-        Default: False
-    frandom_seed : str, optional
-        Set random seed for randomized optimizations (NVRTC only).
-        Default: None
-    ofast_compile : str, optional
-        Fast compilation mode: "0", "min", "mid", or "max" (NVRTC only).
-        Default: None
-    pch : bool, optional
-        Use default precompiled header (NVRTC only, CUDA 12.8+).
-        Default: False
-    create_pch : str, optional
-        Create precompiled header file (NVRTC only, CUDA 12.8+).
-        Default: None
-    use_pch : str, optional
-        Use specific precompiled header file (NVRTC only, CUDA 12.8+).
-        Default: None
-    pch_dir : str, optional
-        PCH directory location (NVRTC only, CUDA 12.8+).
-        Default: None
-    pch_verbose : bool, optional
-        Verbose PCH output (NVRTC only, CUDA 12.8+).
-        Default: False
-    pch_messages : bool, optional
-        Control PCH diagnostic messages (NVRTC only, CUDA 12.8+).
-        Default: False
-    instantiate_templates_in_pch : bool, optional
-        Control template instantiation in PCH (NVRTC only, CUDA 12.8+).
-        Default: False
-    """
-
-    name: str | None = "default_program"
-    arch: str | None = None
-    relocatable_device_code: bool | None = None
-    extensible_whole_program: bool | None = None
-    debug: bool | None = None
-    lineinfo: bool | None = None
-    device_code_optimize: bool | None = None
-    ptxas_options: str | list[str] | tuple[str] | None = None
-    max_register_count: int | None = None
-    ftz: bool | None = None
-    prec_sqrt: bool | None = None
-    prec_div: bool | None = None
-    fma: bool | None = None
-    use_fast_math: bool | None = None
-    extra_device_vectorization: bool | None = None
-    link_time_optimization: bool | None = None
-    gen_opt_lto: bool | None = None
-    define_macro: str | tuple[str, str] | list[str | tuple[str, str]] | tuple[str | tuple[str, str], ...] | None = None
-    undefine_macro: str | list[str] | tuple[str] | None = None
-    include_path: str | list[str] | tuple[str] | None = None
-    pre_include: str | list[str] | tuple[str] | None = None
-    no_source_include: bool | None = None
-    std: str | None = None
-    builtin_move_forward: bool | None = None
-    builtin_initializer_list: bool | None = None
-    disable_warnings: bool | None = None
-    restrict: bool | None = None
-    device_as_default_execution_space: bool | None = None
-    device_int128: bool | None = None
-    optimization_info: str | None = None
-    no_display_error_number: bool | None = None
-    diag_error: int | list[int] | tuple[int] | None = None
-    diag_suppress: int | list[int] | tuple[int] | None = None
-    diag_warn: int | list[int] | tuple[int] | None = None
-    brief_diagnostics: bool | None = None
-    time: str | None = None
-    split_compile: int | None = None
-    fdevice_syntax_only: bool | None = None
-    minimal: bool | None = None
-    no_cache: bool | None = None
-    fdevice_time_trace: str | None = None
-    device_float128: bool | None = None
-    frandom_seed: str | None = None
-    ofast_compile: str | None = None
-    pch: bool | None = None
-    create_pch: str | None = None
-    use_pch: str | None = None
-    pch_dir: str | None = None
-    pch_verbose: bool | None = None
-    pch_messages: bool | None = None
-    instantiate_templates_in_pch: bool | None = None
-    numba_debug: bool | None = None  # Custom option for Numba debugging
-
-    def __post_init__(self):
-        self._name = self.name.encode()
-        # Set arch to default if not provided
-        if self.arch is None:
-            self.arch = f"sm_{Device().arch}"
-
-    def _prepare_nvrtc_options(self) -> list[bytes]:
-        # Build NVRTC-specific options
-        options = [f"-arch={self.arch}"]
-        if self.relocatable_device_code is not None:
-            options.append(f"--relocatable-device-code={_handle_boolean_option(self.relocatable_device_code)}")
-        if self.extensible_whole_program is not None and self.extensible_whole_program:
-            options.append("--extensible-whole-program")
-        if self.debug is not None and self.debug:
-            options.append("--device-debug")
-        if self.lineinfo is not None and self.lineinfo:
-            options.append("--generate-line-info")
-        if self.device_code_optimize is not None and self.device_code_optimize:
-            options.append("--dopt=on")
-        if self.ptxas_options is not None:
-            opt_name = "--ptxas-options"
-            if isinstance(self.ptxas_options, str):
-                options.append(f"{opt_name}={self.ptxas_options}")
-            elif is_sequence(self.ptxas_options):
-                for opt_value in self.ptxas_options:
-                    options.append(f"{opt_name}={opt_value}")
-        if self.max_register_count is not None:
-            options.append(f"--maxrregcount={self.max_register_count}")
-        if self.ftz is not None:
-            options.append(f"--ftz={_handle_boolean_option(self.ftz)}")
-        if self.prec_sqrt is not None:
-            options.append(f"--prec-sqrt={_handle_boolean_option(self.prec_sqrt)}")
-        if self.prec_div is not None:
-            options.append(f"--prec-div={_handle_boolean_option(self.prec_div)}")
-        if self.fma is not None:
-            options.append(f"--fmad={_handle_boolean_option(self.fma)}")
-        if self.use_fast_math is not None and self.use_fast_math:
-            options.append("--use_fast_math")
-        if self.extra_device_vectorization is not None and self.extra_device_vectorization:
-            options.append("--extra-device-vectorization")
-        if self.link_time_optimization is not None and self.link_time_optimization:
-            options.append("--dlink-time-opt")
-        if self.gen_opt_lto is not None and self.gen_opt_lto:
-            options.append("--gen-opt-lto")
-        if self.define_macro is not None:
-            _process_define_macro(options, self.define_macro)
-        if self.undefine_macro is not None:
-            if isinstance(self.undefine_macro, str):
-                options.append(f"--undefine-macro={self.undefine_macro}")
-            elif is_sequence(self.undefine_macro):
-                for macro in self.undefine_macro:
-                    options.append(f"--undefine-macro={macro}")
-        if self.include_path is not None:
-            if isinstance(self.include_path, str):
-                options.append(f"--include-path={self.include_path}")
-            elif is_sequence(self.include_path):
-                for path in self.include_path:
-                    options.append(f"--include-path={path}")
-        if self.pre_include is not None:
-            if isinstance(self.pre_include, str):
-                options.append(f"--pre-include={self.pre_include}")
-            elif is_sequence(self.pre_include):
-                for header in self.pre_include:
-                    options.append(f"--pre-include={header}")
-        if self.no_source_include is not None and self.no_source_include:
-            options.append("--no-source-include")
-        if self.std is not None:
-            options.append(f"--std={self.std}")
-        if self.builtin_move_forward is not None:
-            options.append(f"--builtin-move-forward={_handle_boolean_option(self.builtin_move_forward)}")
-        if self.builtin_initializer_list is not None:
-            options.append(f"--builtin-initializer-list={_handle_boolean_option(self.builtin_initializer_list)}")
-        if self.disable_warnings is not None and self.disable_warnings:
-            options.append("--disable-warnings")
-        if self.restrict is not None and self.restrict:
-            options.append("--restrict")
-        if self.device_as_default_execution_space is not None and self.device_as_default_execution_space:
-            options.append("--device-as-default-execution-space")
-        if self.device_int128 is not None and self.device_int128:
-            options.append("--device-int128")
-        if self.device_float128 is not None and self.device_float128:
-            options.append("--device-float128")
-        if self.optimization_info is not None:
-            options.append(f"--optimization-info={self.optimization_info}")
-        if self.no_display_error_number is not None and self.no_display_error_number:
-            options.append("--no-display-error-number")
-        if self.diag_error is not None:
-            if isinstance(self.diag_error, int):
-                options.append(f"--diag-error={self.diag_error}")
-            elif is_sequence(self.diag_error):
-                for error in self.diag_error:
-                    options.append(f"--diag-error={error}")
-        if self.diag_suppress is not None:
-            if isinstance(self.diag_suppress, int):
-                options.append(f"--diag-suppress={self.diag_suppress}")
-            elif is_sequence(self.diag_suppress):
-                for suppress in self.diag_suppress:
-                    options.append(f"--diag-suppress={suppress}")
-        if self.diag_warn is not None:
-            if isinstance(self.diag_warn, int):
-                options.append(f"--diag-warn={self.diag_warn}")
-            elif is_sequence(self.diag_warn):
-                for warn in self.diag_warn:
-                    options.append(f"--diag-warn={warn}")
-        if self.brief_diagnostics is not None:
-            options.append(f"--brief-diagnostics={_handle_boolean_option(self.brief_diagnostics)}")
-        if self.time is not None:
-            options.append(f"--time={self.time}")
-        if self.split_compile is not None:
-            options.append(f"--split-compile={self.split_compile}")
-        if self.fdevice_syntax_only is not None and self.fdevice_syntax_only:
-            options.append("--fdevice-syntax-only")
-        if self.minimal is not None and self.minimal:
-            options.append("--minimal")
-        if self.no_cache is not None and self.no_cache:
-            options.append("--no-cache")
-        if self.fdevice_time_trace is not None:
-            options.append(f"--fdevice-time-trace={self.fdevice_time_trace}")
-        if self.frandom_seed is not None:
-            options.append(f"--frandom-seed={self.frandom_seed}")
-        if self.ofast_compile is not None:
-            options.append(f"--Ofast-compile={self.ofast_compile}")
-        # PCH options (CUDA 12.8+)
-        if self.pch is not None and self.pch:
-            options.append("--pch")
-        if self.create_pch is not None:
-            options.append(f"--create-pch={self.create_pch}")
-        if self.use_pch is not None:
-            options.append(f"--use-pch={self.use_pch}")
-        if self.pch_dir is not None:
-            options.append(f"--pch-dir={self.pch_dir}")
-        if self.pch_verbose is not None:
-            options.append(f"--pch-verbose={_handle_boolean_option(self.pch_verbose)}")
-        if self.pch_messages is not None:
-            options.append(f"--pch-messages={_handle_boolean_option(self.pch_messages)}")
-        if self.instantiate_templates_in_pch is not None:
-            options.append(
-                f"--instantiate-templates-in-pch={_handle_boolean_option(self.instantiate_templates_in_pch)}"
-            )
-        if self.numba_debug:
-            options.append("--numba-debug")
-        return [o.encode() for o in options]
-
-    def _prepare_nvvm_options(self, as_bytes: bool = True) -> list[bytes] | list[str]:
-        options = []
-
-        # Options supported by NVVM
-        assert self.arch is not None
-        arch = self.arch
-        if arch.startswith("sm_"):
-            arch = f"compute_{arch[3:]}"
-        options.append(f"-arch={arch}")
-        if self.debug is not None and self.debug:
-            options.append("-g")
-        if self.device_code_optimize is False:
-            options.append("-opt=0")
-        elif self.device_code_optimize is True:
-            options.append("-opt=3")
-        # NVVM uses 0/1 instead of true/false for boolean options
-        if self.ftz is not None:
-            options.append(f"-ftz={'1' if self.ftz else '0'}")
-        if self.prec_sqrt is not None:
-            options.append(f"-prec-sqrt={'1' if self.prec_sqrt else '0'}")
-        if self.prec_div is not None:
-            options.append(f"-prec-div={'1' if self.prec_div else '0'}")
-        if self.fma is not None:
-            options.append(f"-fma={'1' if self.fma else '0'}")
-
-        # Check for unsupported options and raise error if they are set
-        unsupported = []
-        if self.relocatable_device_code is not None:
-            unsupported.append("relocatable_device_code")
-        if self.extensible_whole_program is not None and self.extensible_whole_program:
-            unsupported.append("extensible_whole_program")
-        if self.lineinfo is not None and self.lineinfo:
-            unsupported.append("lineinfo")
-        if self.ptxas_options is not None:
-            unsupported.append("ptxas_options")
-        if self.max_register_count is not None:
-            unsupported.append("max_register_count")
-        if self.use_fast_math is not None and self.use_fast_math:
-            unsupported.append("use_fast_math")
-        if self.extra_device_vectorization is not None and self.extra_device_vectorization:
-            unsupported.append("extra_device_vectorization")
-        if self.gen_opt_lto is not None and self.gen_opt_lto:
-            unsupported.append("gen_opt_lto")
-        if self.define_macro is not None:
-            unsupported.append("define_macro")
-        if self.undefine_macro is not None:
-            unsupported.append("undefine_macro")
-        if self.include_path is not None:
-            unsupported.append("include_path")
-        if self.pre_include is not None:
-            unsupported.append("pre_include")
-        if self.no_source_include is not None and self.no_source_include:
-            unsupported.append("no_source_include")
-        if self.std is not None:
-            unsupported.append("std")
-        if self.builtin_move_forward is not None:
-            unsupported.append("builtin_move_forward")
-        if self.builtin_initializer_list is not None:
-            unsupported.append("builtin_initializer_list")
-        if self.disable_warnings is not None and self.disable_warnings:
-            unsupported.append("disable_warnings")
-        if self.restrict is not None and self.restrict:
-            unsupported.append("restrict")
-        if self.device_as_default_execution_space is not None and self.device_as_default_execution_space:
-            unsupported.append("device_as_default_execution_space")
-        if self.device_int128 is not None and self.device_int128:
-            unsupported.append("device_int128")
-        if self.optimization_info is not None:
-            unsupported.append("optimization_info")
-        if self.no_display_error_number is not None and self.no_display_error_number:
-            unsupported.append("no_display_error_number")
-        if self.diag_error is not None:
-            unsupported.append("diag_error")
-        if self.diag_suppress is not None:
-            unsupported.append("diag_suppress")
-        if self.diag_warn is not None:
-            unsupported.append("diag_warn")
-        if self.brief_diagnostics is not None:
-            unsupported.append("brief_diagnostics")
-        if self.time is not None:
-            unsupported.append("time")
-        if self.split_compile is not None:
-            unsupported.append("split_compile")
-        if self.fdevice_syntax_only is not None and self.fdevice_syntax_only:
-            unsupported.append("fdevice_syntax_only")
-        if self.minimal is not None and self.minimal:
-            unsupported.append("minimal")
-        if self.numba_debug is not None and self.numba_debug:
-            unsupported.append("numba_debug")
-        if unsupported:
-            raise CUDAError(f"The following options are not supported by NVVM backend: {', '.join(unsupported)}")
-
-        if as_bytes:
-            return [o.encode() for o in options]
-        else:
-            return options
-
-    def as_bytes(self, backend: str) -> list[bytes]:
-        """Convert program options to bytes format for the specified backend.
-
-        This method transforms the program options into a format suitable for the
-        specified compiler backend. Different backends may use different option names
-        and formats even for the same conceptual options.
-
-        Parameters
-        ----------
-        backend : str
-            The compiler backend to prepare options for. Must be either "nvrtc" or "nvvm".
-
-        Returns
-        -------
-        list[bytes]
-            List of option strings encoded as bytes.
-
-        Raises
-        ------
-        ValueError
-            If an unknown backend is specified.
-        CUDAError
-            If an option incompatible with the specified backend is set.
-
-        Examples
-        --------
-        >>> options = ProgramOptions(arch="sm_80", debug=True)
-        >>> nvrtc_options = options.as_bytes("nvrtc")
-        """
-        backend = backend.lower()
-        if backend == "nvrtc":
-            return self._prepare_nvrtc_options()
-        elif backend == "nvvm":
-            return self._prepare_nvvm_options(as_bytes=True)
-        else:
-            raise ValueError(f"Unknown backend '{backend}'. Must be one of: 'nvrtc', 'nvvm'")
-
-    def __repr__(self):
-        return f"ProgramOptions(name={self.name!r}, arch={self.arch!r})"
-
-
-ProgramHandleT = Union["cuda.bindings.nvrtc.nvrtcProgram", LinkerHandleT]
-
-
-class Program:
-    """Represent a compilation machinery to process programs into
-    :obj:`~_module.ObjectCode`.
-
-    This object provides a unified interface to multiple underlying
-    compiler libraries. Compilation support is enabled for a wide
-    range of code types and compilation types.
-
-    Parameters
-    ----------
-    code : Any
-        String of the CUDA Runtime Compilation program.
-    code_type : Any
-        String of the code type. Currently ``"ptx"``, ``"c++"``, and ``"nvvm"`` are supported.
-    options : ProgramOptions, optional
-        A ProgramOptions object to customize the compilation process.
-        See :obj:`ProgramOptions` for more information.
-    """
-
-    class _MembersNeededForFinalize:
-        __slots__ = "handle", "backend"
-
-        def __init__(self, program_obj, handle, backend):
-            self.handle = handle
-            self.backend = backend
-            weakref.finalize(program_obj, self.close)
-
-        def close(self):
-            if self.handle is not None:
-                if self.backend == "NVRTC":
-                    handle_return(nvrtc.nvrtcDestroyProgram(self.handle))
-                elif self.backend == "NVVM":
-                    nvvm = _get_nvvm_module()
-                    nvvm.destroy_program(self.handle)
-                self.handle = None
-
-    __slots__ = ("__weakref__", "_mnff", "_backend", "_linker", "_options")
-
-    def __init__(self, code, code_type, options: ProgramOptions = None):
-        self._mnff = Program._MembersNeededForFinalize(self, None, None)
-
-        self._options = options = check_or_create_options(ProgramOptions, options, "Program options")
-        code_type = code_type.lower()
-
-        if code_type == "c++":
-            assert_type(code, str)
-            # TODO: support pre-loaded headers & include names
-            # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
-
-            self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], []))
-            self._mnff.backend = "NVRTC"
-            self._backend = "NVRTC"
-            self._linker = None
-
-        elif code_type == "ptx":
-            assert_type(code, str)
-            self._linker = Linker(
-                ObjectCode._init(code.encode(), code_type), options=self._translate_program_options(options)
-            )
-            self._backend = self._linker.backend
-
-        elif code_type == "nvvm":
-            if isinstance(code, str):
-                code = code.encode("utf-8")
-            elif not isinstance(code, (bytes, bytearray)):
-                raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray")
-
-            nvvm = _get_nvvm_module()
-            self._mnff.handle = nvvm.create_program()
-            self._mnff.backend = "NVVM"
-            nvvm.add_module_to_program(self._mnff.handle, code, len(code), options._name.decode())
-            self._backend = "NVVM"
-            self._linker = None
-
-        else:
-            supported_code_types = ("c++", "ptx", "nvvm")
-            assert code_type not in supported_code_types, f"{code_type=}"
-            raise RuntimeError(f"Unsupported {code_type=} ({supported_code_types=})")
-
-    def _translate_program_options(self, options: ProgramOptions) -> LinkerOptions:
-        return LinkerOptions(
-            name=options.name,
-            arch=options.arch,
-            max_register_count=options.max_register_count,
-            time=options.time,
-            link_time_optimization=options.link_time_optimization,
-            debug=options.debug,
-            lineinfo=options.lineinfo,
-            ftz=options.ftz,
-            prec_div=options.prec_div,
-            prec_sqrt=options.prec_sqrt,
-            fma=options.fma,
-            split_compile=options.split_compile,
-            ptxas_options=options.ptxas_options,
-            no_cache=options.no_cache,
-        )
-
-    def close(self):
-        """Destroy this program."""
-        if self._linker:
-            self._linker.close()
-        self._mnff.close()
-
-    @staticmethod
-    def _can_load_generated_ptx():
-        driver_ver = handle_return(driver.cuDriverGetVersion())
-        nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion())
-        return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver
-
-    def compile(self, target_type, name_expressions=(), logs=None):
-        """Compile the program with a specific compilation type.
-
-        Parameters
-        ----------
-        target_type : Any
-            String of the targeted compilation type.
-            Supported options are "ptx", "cubin" and "ltoir".
-        name_expressions : Union[list, tuple], optional
-            List of explicit name expressions to become accessible.
-            (Default to no expressions)
-        logs : Any, optional
-            Object with a write method to receive the logs generated
-            from compilation.
-            (Default to no logs)
-
-        Returns
-        -------
-        :obj:`~_module.ObjectCode`
-            Newly created code object.
-
-        """
-        supported_target_types = ("ptx", "cubin", "ltoir")
-        if target_type not in supported_target_types:
-            raise ValueError(f'Unsupported target_type="{target_type}" ({supported_target_types=})')
-
-        if self._backend == "NVRTC":
-            if target_type == "ptx" and not self._can_load_generated_ptx():
-                warn(
-                    "The CUDA driver version is older than the backend version. "
-                    "The generated ptx will not be loadable by the current driver.",
-                    stacklevel=1,
-                    category=RuntimeWarning,
-                )
-            if name_expressions:
-                for n in name_expressions:
-                    handle_return(
-                        nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()),
-                        handle=self._mnff.handle,
-                    )
-            options = self._options.as_bytes("nvrtc")
-            handle_return(
-                nvrtc.nvrtcCompileProgram(self._mnff.handle, len(options), options),
-                handle=self._mnff.handle,
-            )
-
-            size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size")
-            comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}")
-            size = handle_return(size_func(self._mnff.handle), handle=self._mnff.handle)
-            data = b" " * size
-            handle_return(comp_func(self._mnff.handle, data), handle=self._mnff.handle)
-
-            symbol_mapping = {}
-            if name_expressions:
-                for n in name_expressions:
-                    symbol_mapping[n] = handle_return(
-                        nvrtc.nvrtcGetLoweredName(self._mnff.handle, n.encode()), handle=self._mnff.handle
-                    )
-
-            if logs is not None:
-                logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._mnff.handle), handle=self._mnff.handle)
-                if logsize > 1:
-                    log = b" " * logsize
-                    handle_return(nvrtc.nvrtcGetProgramLog(self._mnff.handle, log), handle=self._mnff.handle)
-                    logs.write(log.decode("utf-8", errors="backslashreplace"))
-
-            return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping, name=self._options.name)
-
-        elif self._backend == "NVVM":
-            if target_type not in ("ptx", "ltoir"):
-                raise ValueError(f'NVVM backend only supports target_type="ptx", "ltoir", got "{target_type}"')
-
-            # TODO: flip to True when NVIDIA/cuda-python#1354 is resolved and CUDA 12 is dropped
-            nvvm_options = self._options._prepare_nvvm_options(as_bytes=False)
-            if target_type == "ltoir" and "-gen-lto" not in nvvm_options:
-                nvvm_options.append("-gen-lto")
-            nvvm = _get_nvvm_module()
-            with _nvvm_exception_manager(self):
-                nvvm.verify_program(self._mnff.handle, len(nvvm_options), nvvm_options)
-                nvvm.compile_program(self._mnff.handle, len(nvvm_options), nvvm_options)
-
-            size = nvvm.get_compiled_result_size(self._mnff.handle)
-            data = bytearray(size)
-            nvvm.get_compiled_result(self._mnff.handle, data)
-
-            if logs is not None:
-                logsize = nvvm.get_program_log_size(self._mnff.handle)
-                if logsize > 1:
-                    log = bytearray(logsize)
-                    nvvm.get_program_log(self._mnff.handle, log)
-                    logs.write(log.decode("utf-8", errors="backslashreplace"))
-
-            return ObjectCode._init(data, target_type, name=self._options.name)
-
-        supported_backends = ("nvJitLink", "driver")
-        if self._backend not in supported_backends:
-            raise ValueError(f'Unsupported backend="{self._backend}" ({supported_backends=})')
-        return self._linker.link(target_type)
-
-    @property
-    def backend(self) -> str:
-        """Return this Program instance's underlying backend."""
-        return self._backend
-
-    @property
-    def handle(self) -> ProgramHandleT:
-        """Return the underlying handle object.
-
-        .. note::
-
-           The type of the returned object depends on the backend.
-
-        .. caution::
-
-            This handle is a Python object. To get the memory address of the underlying C
-            handle, call ``int(Program.handle)``.
-        """
-        return self._mnff.handle
diff --git a/cuda_core/cuda/core/_program.pyx b/cuda_core/cuda/core/_program.pyx
new file mode 100644
index 0000000000..10743e0b78
--- /dev/null
+++ b/cuda_core/cuda/core/_program.pyx
@@ -0,0 +1,977 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+"""Compilation machinery for CUDA programs.
+
+This module provides :class:`Program` for compiling source code into
+:class:`~cuda.core.ObjectCode`, with :class:`ProgramOptions` for configuration.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from warnings import warn
+
+from cuda.bindings import driver, nvrtc
+
+from libcpp.vector cimport vector
+
+from ._resource_handles cimport (
+    as_cu,
+    as_py,
+    create_nvrtc_program_handle,
+    create_nvvm_program_handle,
+)
+from cuda.bindings cimport cynvrtc, cynvvm
+from cuda.core._utils.cuda_utils cimport HANDLE_RETURN_NVRTC, HANDLE_RETURN_NVVM
+from cuda.core._device import Device
+from cuda.core._linker import Linker, LinkerHandleT, LinkerOptions
+from cuda.core._module import ObjectCode
+from cuda.core._utils.clear_error_support import assert_type
+from cuda.core._utils.cuda_utils import (
+    CUDAError,
+    _handle_boolean_option,
+    check_or_create_options,
+    get_binding_version,
+    handle_return,
+    is_nested_sequence,
+    is_sequence,
+)
+
+__all__ = ["Program", "ProgramOptions"]
+
+ProgramHandleT = nvrtc.nvrtcProgram | int | LinkerHandleT
+"""Type alias for program handle types across different backends.
+
+The ``int`` type covers NVVM handles, which don't have a wrapper class.
+"""
+
+
+# =============================================================================
+# Principal Class
+# =============================================================================
+
+
+cdef class Program:
+    """Represent a compilation machinery to process programs into
+    :class:`~cuda.core.ObjectCode`.
+
+    This object provides a unified interface to multiple underlying
+    compiler libraries. Compilation support is enabled for a wide
+    range of code types and compilation types.
+
+    Parameters
+    ----------
+    code : str | bytes | bytearray
+        The source code to compile. For C++ and PTX, must be a string.
+        For NVVM IR, can be str, bytes, or bytearray.
+    code_type : str
+        The type of source code. Must be one of ``"c++"``, ``"ptx"``, or ``"nvvm"``.
+    options : :class:`ProgramOptions`, optional
+        Options to customize the compilation process.
+    """
+
+    def __init__(self, code: str | bytes | bytearray, code_type: str, options: ProgramOptions | None = None):
+        Program_init(self, code, code_type, options)
+
+    def close(self):
+        """Destroy this program."""
+        if self._linker:
+            self._linker.close()
+        # Reset handles - the C++ shared_ptr destructor handles cleanup
+        self._h_nvrtc.reset()
+        self._h_nvvm.reset()
+
+    def compile(
+        self, target_type: str, name_expressions: tuple | list = (), logs=None
+    ) -> ObjectCode:
+        """Compile the program to the specified target type.
+
+        Parameters
+        ----------
+        target_type : str
+            The compilation target. Must be one of ``"ptx"``, ``"cubin"``, or ``"ltoir"``.
+        name_expressions : tuple | list, optional
+            Sequence of name expressions to make accessible in the compiled code.
+            Used for template instantiation and similar cases.
+        logs : object, optional
+            Object with a ``write`` method to receive compilation logs.
+
+        Returns
+        -------
+        :class:`~cuda.core.ObjectCode`
+            The compiled object code.
+        """
+        return Program_compile(self, target_type, name_expressions, logs)
+
+    @property
+    def backend(self) -> str:
+        """Return this Program instance's underlying backend."""
+        return self._backend
+
+    @property
+    def handle(self) -> ProgramHandleT:
+        """Return the underlying handle object.
+
+        .. note::
+
+           The type of the returned object depends on the backend.
+
+        .. caution::
+
+            This handle is a Python object. To get the memory address of the underlying C
+            handle, call ``int(Program.handle)``.
+        """
+        if self._backend == "NVRTC":
+            return as_py(self._h_nvrtc)
+        elif self._backend == "NVVM":
+            return as_py(self._h_nvvm)  # returns int (NVVM uses raw integers)
+        else:
+            return self._linker.handle
+
+    @staticmethod
+    def driver_can_load_nvrtc_ptx_output() -> bool:
+        """Check if the CUDA driver can load PTX generated by NVRTC.
+
+        NVRTC generates PTX targeting a specific CUDA version. If the installed
+        driver is older than the NVRTC version, it may not be able to load the
+        generated PTX.
+
+        Returns
+        -------
+        bool
+            True if the driver version is new enough to load PTX generated
+            by the current NVRTC version, False otherwise.
+
+        Examples
+        --------
+        >>> if Program.driver_can_load_nvrtc_ptx_output():
+        ...     obj = program.compile("ptx")
+        ...     kernel = obj.get_kernel("my_kernel")
+        """
+        return _can_load_generated_ptx()
+
+    def __repr__(self) -> str:
+        return f"<Program backend={self._backend!r}>"
+
+
+# =============================================================================
+# Other Public Classes
+# =============================================================================
+
+
+@dataclass
+class ProgramOptions:
+    """Customizable options for configuring :class:`Program`.
+
+    Attributes
+    ----------
+    name : str, optional
+        Name of the program. If the compilation succeeds, the name is passed down to the generated `ObjectCode`.
+    arch : str, optional
+        Pass the SM architecture value, such as ``sm_<CC>`` (for generating CUBIN) or
+        ``compute_<CC>`` (for generating PTX). If not provided, the current device's architecture
+        will be used.
+    relocatable_device_code : bool, optional
+        Enable (disable) the generation of relocatable device code.
+        Default: False
+    extensible_whole_program : bool, optional
+        Do extensible whole program compilation of device code.
+        Default: False
+    debug : bool, optional
+        Generate debug information. If --dopt is not specified, then turns off all optimizations.
+        Default: False
+    lineinfo: bool, optional
+        Generate line-number information.
+        Default: False
+    device_code_optimize : bool, optional
+        Enable device code optimization. When specified along with '-G', enables limited debug information generation
+        for optimized device code.
+        Default: None
+    ptxas_options : Union[str, list[str]], optional
+        Specify one or more options directly to ptxas, the PTX optimizing assembler. Options should be strings.
+        For example ["-v", "-O2"].
+        Default: None
+    max_register_count : int, optional
+        Specify the maximum amount of registers that GPU functions can use.
+        Default: None
+    ftz : bool, optional
+        When performing single-precision floating-point operations, flush denormal values to zero or preserve denormal
+        values.
+        Default: False
+    prec_sqrt : bool, optional
+        For single-precision floating-point square root, use IEEE round-to-nearest mode or use a faster approximation.
+        Default: True
+    prec_div : bool, optional
+        For single-precision floating-point division and reciprocals, use IEEE round-to-nearest mode or use a faster
+        approximation.
+        Default: True
+    fma : bool, optional
+        Enables (disables) the contraction of floating-point multiplies and adds/subtracts into floating-point
+        multiply-add operations.
+        Default: True
+    use_fast_math : bool, optional
+        Make use of fast math operations.
+        Default: False
+    extra_device_vectorization : bool, optional
+        Enables more aggressive device code vectorization in the NVVM optimizer.
+        Default: False
+    link_time_optimization : bool, optional
+        Generate intermediate code for later link-time optimization.
+        Default: False
+    gen_opt_lto : bool, optional
+        Run the optimizer passes before generating the LTO IR.
+        Default: False
+    define_macro : Union[str, tuple[str, str], list[Union[str, tuple[str, str]]]], optional
+        Predefine a macro. Can be either a string, in which case that macro will be set to 1, a 2-element tuple of
+        strings, in which case the first element is defined as the second, or a list of strings or tuples.
+        Default: None
+    undefine_macro : Union[str, list[str]], optional
+        Cancel any previous definition of a macro, or list of macros.
+        Default: None
+    include_path : Union[str, list[str]], optional
+        Add the directory or directories to the list of directories to be searched for headers.
+        Default: None
+    pre_include : Union[str, list[str]], optional
+        Preinclude one or more headers during preprocessing. Can be either a string or a list of strings.
+        Default: None
+    no_source_include : bool, optional
+        Disable the default behavior of adding the directory of each input source to the include path.
+        Default: False
+    std : str, optional
+        Set language dialect to C++03, C++11, C++14, C++17 or C++20.
+        Default: c++17
+    builtin_move_forward : bool, optional
+        Provide builtin definitions of std::move and std::forward.
+        Default: True
+    builtin_initializer_list : bool, optional
+        Provide builtin definitions of std::initializer_list class and member functions.
+        Default: True
+    disable_warnings : bool, optional
+        Inhibit all warning messages.
+        Default: False
+    restrict : bool, optional
+        Programmer assertion that all kernel pointer parameters are restrict pointers.
+        Default: False
+    device_as_default_execution_space : bool, optional
+        Treat entities with no execution space annotation as __device__ entities.
+        Default: False
+    device_int128 : bool, optional
+        Allow the __int128 type in device code.
+ Default: False + optimization_info : str, optional + Provide optimization reports for the specified kind of optimization. + Default: None + no_display_error_number : bool, optional + Disable the display of a diagnostic number for warning messages. + Default: False + diag_error : Union[int, list[int]], optional + Emit error for a specified diagnostic message number or comma separated list of numbers. + Default: None + diag_suppress : Union[int, list[int]], optional + Suppress a specified diagnostic message number or comma separated list of numbers. + Default: None + diag_warn : Union[int, list[int]], optional + Emit warning for a specified diagnostic message number or comma separated lis of numbers. + Default: None + brief_diagnostics : bool, optional + Disable or enable showing source line and column info in a diagnostic. + Default: False + time : str, optional + Generate a CSV table with the time taken by each compilation phase. + Default: None + split_compile : int, optional + Perform compiler optimizations in parallel. + Default: 1 + fdevice_syntax_only : bool, optional + Ends device compilation after front-end syntax checking. + Default: False + minimal : bool, optional + Omit certain language features to reduce compile time for small programs. + Default: False + no_cache : bool, optional + Disable compiler caching. + Default: False + fdevice_time_trace : str, optional + Generate time trace JSON for profiling compilation (NVRTC only). + Default: None + device_float128 : bool, optional + Allow __float128 type in device code (NVRTC only). + Default: False + frandom_seed : str, optional + Set random seed for randomized optimizations (NVRTC only). + Default: None + ofast_compile : str, optional + Fast compilation mode: "0", "min", "mid", or "max" (NVRTC only). + Default: None + pch : bool, optional + Use default precompiled header (NVRTC only, CUDA 12.8+). + Default: False + create_pch : str, optional + Create precompiled header file (NVRTC only, CUDA 12.8+). + Default: None + use_pch : str, optional + Use specific precompiled header file (NVRTC only, CUDA 12.8+). + Default: None + pch_dir : str, optional + PCH directory location (NVRTC only, CUDA 12.8+). + Default: None + pch_verbose : bool, optional + Verbose PCH output (NVRTC only, CUDA 12.8+). + Default: False + pch_messages : bool, optional + Control PCH diagnostic messages (NVRTC only, CUDA 12.8+). + Default: False + instantiate_templates_in_pch : bool, optional + Control template instantiation in PCH (NVRTC only, CUDA 12.8+). + Default: False + """ + + name: str | None = "default_program" + arch: str | None = None + relocatable_device_code: bool | None = None + extensible_whole_program: bool | None = None + debug: bool | None = None + lineinfo: bool | None = None + device_code_optimize: bool | None = None + ptxas_options: str | list[str] | tuple[str] | None = None + max_register_count: int | None = None + ftz: bool | None = None + prec_sqrt: bool | None = None + prec_div: bool | None = None + fma: bool | None = None + use_fast_math: bool | None = None + extra_device_vectorization: bool | None = None + link_time_optimization: bool | None = None + gen_opt_lto: bool | None = None + define_macro: str | tuple[str, str] | list[str | tuple[str, str]] | tuple[str | tuple[str, str], ...] 
| None = None + undefine_macro: str | list[str] | tuple[str] | None = None + include_path: str | list[str] | tuple[str] | None = None + pre_include: str | list[str] | tuple[str] | None = None + no_source_include: bool | None = None + std: str | None = None + builtin_move_forward: bool | None = None + builtin_initializer_list: bool | None = None + disable_warnings: bool | None = None + restrict: bool | None = None + device_as_default_execution_space: bool | None = None + device_int128: bool | None = None + optimization_info: str | None = None + no_display_error_number: bool | None = None + diag_error: int | list[int] | tuple[int] | None = None + diag_suppress: int | list[int] | tuple[int] | None = None + diag_warn: int | list[int] | tuple[int] | None = None + brief_diagnostics: bool | None = None + time: str | None = None + split_compile: int | None = None + fdevice_syntax_only: bool | None = None + minimal: bool | None = None + no_cache: bool | None = None + fdevice_time_trace: str | None = None + device_float128: bool | None = None + frandom_seed: str | None = None + ofast_compile: str | None = None + pch: bool | None = None + create_pch: str | None = None + use_pch: str | None = None + pch_dir: str | None = None + pch_verbose: bool | None = None + pch_messages: bool | None = None + instantiate_templates_in_pch: bool | None = None + numba_debug: bool | None = None # Custom option for Numba debugging + + def __post_init__(self): + self._name = self.name.encode() + # Set arch to default if not provided + if self.arch is None: + self.arch = f"sm_{Device().arch}" + + def _prepare_nvrtc_options(self) -> list[bytes]: + return _prepare_nvrtc_options_impl(self) + + def _prepare_nvvm_options(self, as_bytes: bool = True) -> list[bytes] | list[str]: + return _prepare_nvvm_options_impl(self, as_bytes) + + def as_bytes(self, backend: str, target_type: str | None = None) -> list[bytes]: + """Convert program options to bytes format for the specified backend. + + This method transforms the program options into a format suitable for the + specified compiler backend. Different backends may use different option names + and formats even for the same conceptual options. + + Parameters + ---------- + backend : str + The compiler backend to prepare options for. Must be either "nvrtc" or "nvvm". + target_type : str, optional + The compilation target type (e.g., "ptx", "cubin", "ltoir"). Some backends + require additional options based on the target type. + + Returns + ------- + list[bytes] + List of option strings encoded as bytes. + + Raises + ------ + ValueError + If an unknown backend is specified. + CUDAError + If an option incompatible with the specified backend is set. + + Examples + -------- + >>> options = ProgramOptions(arch="sm_80", debug=True) + >>> nvrtc_options = options.as_bytes("nvrtc") + """ + backend = backend.lower() + if backend == "nvrtc": + return self._prepare_nvrtc_options() + elif backend == "nvvm": + options = self._prepare_nvvm_options(as_bytes=True) + if target_type == "ltoir" and b"-gen-lto" not in options: + options.append(b"-gen-lto") + return options + else: + raise ValueError(f"Unknown backend '{backend}'. 
Must be one of: 'nvrtc', 'nvvm'") + + def __repr__(self): + return f"ProgramOptions(name={self.name!r}, arch={self.arch!r})" + + +# ============================================================================= +# Private Classes and Helper Functions +# ============================================================================= + +# Module-level state for NVVM lazy loading +cdef object _nvvm_module = None +cdef bint _nvvm_import_attempted = False + + +def _get_nvvm_module(): + """Get the NVVM module, importing it lazily with availability checks.""" + global _nvvm_module, _nvvm_import_attempted + + if _nvvm_import_attempted: + if _nvvm_module is None: + raise RuntimeError("NVVM module is not available (previous import attempt failed)") + return _nvvm_module + + _nvvm_import_attempted = True + + try: + version = get_binding_version() + if version < (12, 9): + raise RuntimeError( + f"NVVM bindings require cuda-bindings >= 12.9.0, but found {version[0]}.{version[1]}.x. " + "Please update cuda-bindings to use NVVM features." + ) + + from cuda.bindings import nvvm + from cuda.bindings._internal.nvvm import _inspect_function_pointer + + if _inspect_function_pointer("__nvvmCreateProgram") == 0: + raise RuntimeError("NVVM library (libnvvm) is not available in this Python environment.") + + _nvvm_module = nvvm + return _nvvm_module + + except RuntimeError: + _nvvm_module = None + raise + + +cdef inline bint _process_define_macro_inner(list options, object macro) except? -1: + """Process a single define macro, returning True if it was handled.""" + if isinstance(macro, str): + options.append(f"--define-macro={macro}") + return True + if isinstance(macro, tuple): + if len(macro) != 2 or any(not isinstance(val, str) for val in macro): + raise RuntimeError(f"Expected define_macro tuple[str, str], got {macro}") + options.append(f"--define-macro={macro[0]}={macro[1]}") + return True + return False + + +cdef inline void _process_define_macro(list options, object macro) except *: + """Process define_macro option which can be str, tuple, or list thereof.""" + union_type = "Union[str, tuple[str, str]]" + if _process_define_macro_inner(options, macro): + return + if is_nested_sequence(macro): + for seq_macro in macro: + if not _process_define_macro_inner(options, seq_macro): + raise RuntimeError(f"Expected define_macro {union_type}, got {seq_macro}") + return + raise RuntimeError(f"Expected define_macro {union_type}, list[{union_type}], got {macro}") + + +cdef inline bint _can_load_generated_ptx() except? 
-1: + """Check if the driver can load PTX generated by the current NVRTC version.""" + driver_ver = handle_return(driver.cuDriverGetVersion()) + nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion()) + return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver + + +cdef inline object _translate_program_options(object options): + """Translate ProgramOptions to LinkerOptions for PTX compilation.""" + return LinkerOptions( + name=options.name, + arch=options.arch, + max_register_count=options.max_register_count, + time=options.time, + link_time_optimization=options.link_time_optimization, + debug=options.debug, + lineinfo=options.lineinfo, + ftz=options.ftz, + prec_div=options.prec_div, + prec_sqrt=options.prec_sqrt, + fma=options.fma, + split_compile=options.split_compile, + ptxas_options=options.ptxas_options, + no_cache=options.no_cache, + ) + + +cdef inline int Program_init(Program self, object code, str code_type, object options) except -1: + """Initialize a Program instance.""" + cdef cynvrtc.nvrtcProgram nvrtc_prog + cdef cynvvm.nvvmProgram nvvm_prog + cdef bytes code_bytes + cdef const char* code_ptr + cdef const char* name_ptr + cdef size_t code_len + + self._options = options = check_or_create_options(ProgramOptions, options, "Program options") + code_type = code_type.lower() + + if code_type == "c++": + assert_type(code, str) + # TODO: support pre-loaded headers & include names + code_bytes = code.encode() + code_ptr = code_bytes + name_ptr = options._name + + with nogil: + HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram( + &nvrtc_prog, code_ptr, name_ptr, 0, NULL, NULL)) + self._h_nvrtc = create_nvrtc_program_handle(nvrtc_prog) + self._backend = "NVRTC" + self._linker = None + + elif code_type == "ptx": + assert_type(code, str) + self._linker = Linker( + ObjectCode._init(code.encode(), code_type), options=_translate_program_options(options) + ) + self._backend = self._linker.backend + + elif code_type == "nvvm": + _get_nvvm_module() # Validate NVVM availability + if isinstance(code, str): + code = code.encode("utf-8") + elif not isinstance(code, (bytes, bytearray)): + raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray") + + code_ptr = <const char*>(code) + name_ptr = options._name + code_len = len(code) + + with nogil: + HANDLE_RETURN_NVVM(NULL, cynvvm.nvvmCreateProgram(&nvvm_prog)) + self._h_nvvm = create_nvvm_program_handle(nvvm_prog) # RAII from here + with nogil: + HANDLE_RETURN_NVVM(nvvm_prog, cynvvm.nvvmAddModuleToProgram(nvvm_prog, code_ptr, code_len, name_ptr)) + self._backend = "NVVM" + self._linker = None + + else: + supported_code_types = ("c++", "ptx", "nvvm") + assert code_type not in supported_code_types, f"{code_type=}" + raise RuntimeError(f"Unsupported {code_type=} ({supported_code_types=})") + + return 0 + + +cdef object Program_compile_nvrtc(Program self, str target_type, object name_expressions, object logs): + """Compile using NVRTC backend and return ObjectCode.""" + cdef cynvrtc.nvrtcProgram prog = as_cu(self._h_nvrtc) + cdef size_t output_size = 0 + cdef size_t logsize = 0 + cdef vector[const char*] options_vec + cdef char* data_ptr = NULL + cdef bytes name_bytes + cdef const char* name_ptr = NULL + cdef const char* lowered_name = NULL + cdef dict symbol_mapping = {} + + # Add name expressions before compilation + if name_expressions: + for n in name_expressions: + name_bytes = n.encode() if isinstance(n, str) else n + name_ptr = name_bytes + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcAddNameExpression(prog, name_ptr)) + + # Build options array + options_list = self._options.as_bytes("nvrtc", target_type) + options_vec.resize(len(options_list)) + for i in range(len(options_list)): + options_vec[i] = <const char*>(options_list[i]) + + # Compile + with nogil: + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcCompileProgram(prog, options_vec.size(), options_vec.data())) + + # Get compiled output based on target type + if target_type == "ptx": + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetPTXSize(prog, &output_size)) + data = bytearray(output_size) + data_ptr = <char*>(data) + with nogil: + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetPTX(prog, data_ptr)) + elif target_type == "cubin": + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetCUBINSize(prog, &output_size)) + data = bytearray(output_size) + data_ptr = <char*>(data) + with nogil: + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetCUBIN(prog, data_ptr)) + else: # ltoir + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLTOIRSize(prog, &output_size)) + data = bytearray(output_size) + data_ptr = <char*>(data) + with nogil: + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLTOIR(prog, data_ptr)) + + # Get lowered names after compilation + if name_expressions: + for n in name_expressions: + name_bytes = n.encode() if isinstance(n, str) else n + name_ptr = name_bytes + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLoweredName(prog, name_ptr, &lowered_name)) + symbol_mapping[n] = lowered_name if lowered_name != NULL else None + + # Get compilation log if requested + if logs is not None: + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetProgramLogSize(prog, &logsize)) + if logsize > 1: + log = bytearray(logsize) + data_ptr = <char*>(log) + with nogil: + HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetProgramLog(prog, data_ptr)) + logs.write(log.decode("utf-8", errors="backslashreplace")) + + return ObjectCode._init(bytes(data), target_type, symbol_mapping=symbol_mapping, name=self._options.name) + + +cdef object Program_compile_nvvm(Program self, str target_type, object logs): + """Compile using NVVM backend and return ObjectCode.""" + cdef cynvvm.nvvmProgram prog = as_cu(self._h_nvvm) + cdef size_t output_size = 0 + cdef size_t logsize = 0 + cdef vector[const char*] options_vec + cdef char* data_ptr = NULL + + # Build options array + options_list = self._options.as_bytes("nvvm", target_type) + options_vec.resize(len(options_list)) + for i in range(len(options_list)): + options_vec[i] = <const char*>(options_list[i]) + + # Compile + with nogil: + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmVerifyProgram(prog, options_vec.size(), options_vec.data())) + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmCompileProgram(prog, options_vec.size(), options_vec.data())) + + # Get compiled result + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetCompiledResultSize(prog, &output_size)) + data = bytearray(output_size) + data_ptr = <char*>(data) + with nogil: + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetCompiledResult(prog, data_ptr)) + + # Get compilation log if requested + if logs is not None: + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetProgramLogSize(prog, &logsize)) + if logsize > 1: + log = bytearray(logsize) + data_ptr = <char*>(log) + with nogil: + HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetProgramLog(prog, data_ptr)) + logs.write(log.decode("utf-8", errors="backslashreplace")) + + return ObjectCode._init(bytes(data), target_type, name=self._options.name) + +# Supported target types per backend +cdef dict SUPPORTED_TARGETS = { + "NVRTC": ("ptx", "cubin", "ltoir"), + "NVVM": ("ptx", "ltoir"), + "nvJitLink": ("cubin", "ptx"), + "driver": ("cubin", "ptx"), +} + + +cdef object Program_compile(Program self, str target_type, object 
name_expressions, object logs): + """Compile the program to the specified target type.""" + # Validate target_type for this backend + supported = SUPPORTED_TARGETS.get(self._backend) + if supported is None: + raise ValueError(f'Unknown backend="{self._backend}"') + if target_type not in supported: + raise ValueError( + f'Unsupported target_type="{target_type}" for {self._backend} ' + f'(supported: {", ".join(repr(t) for t in supported)})' + ) + + if self._backend == "NVRTC": + if target_type == "ptx" and not _can_load_generated_ptx(): + warn( + "The CUDA driver version is older than the backend version. " + "The generated ptx will not be loadable by the current driver.", + stacklevel=2, + category=RuntimeWarning, + ) + return Program_compile_nvrtc(self, target_type, name_expressions, logs) + + elif self._backend == "NVVM": + return Program_compile_nvvm(self, target_type, logs) + + else: + return self._linker.link(target_type) + + +cdef inline list _prepare_nvrtc_options_impl(object opts): + """Build NVRTC-specific compiler options.""" + options = [f"-arch={opts.arch}"] + if opts.relocatable_device_code is not None: + options.append(f"--relocatable-device-code={_handle_boolean_option(opts.relocatable_device_code)}") + if opts.extensible_whole_program is not None and opts.extensible_whole_program: + options.append("--extensible-whole-program") + if opts.debug is not None and opts.debug: + options.append("--device-debug") + if opts.lineinfo is not None and opts.lineinfo: + options.append("--generate-line-info") + if opts.device_code_optimize is not None and opts.device_code_optimize: + options.append("--dopt=on") + if opts.ptxas_options is not None: + opt_name = "--ptxas-options" + if isinstance(opts.ptxas_options, str): + options.append(f"{opt_name}={opts.ptxas_options}") + elif is_sequence(opts.ptxas_options): + for opt_value in opts.ptxas_options: + options.append(f"{opt_name}={opt_value}") + if opts.max_register_count is not None: + options.append(f"--maxrregcount={opts.max_register_count}") + if opts.ftz is not None: + options.append(f"--ftz={_handle_boolean_option(opts.ftz)}") + if opts.prec_sqrt is not None: + options.append(f"--prec-sqrt={_handle_boolean_option(opts.prec_sqrt)}") + if opts.prec_div is not None: + options.append(f"--prec-div={_handle_boolean_option(opts.prec_div)}") + if opts.fma is not None: + options.append(f"--fmad={_handle_boolean_option(opts.fma)}") + if opts.use_fast_math is not None and opts.use_fast_math: + options.append("--use_fast_math") + if opts.extra_device_vectorization is not None and opts.extra_device_vectorization: + options.append("--extra-device-vectorization") + if opts.link_time_optimization is not None and opts.link_time_optimization: + options.append("--dlink-time-opt") + if opts.gen_opt_lto is not None and opts.gen_opt_lto: + options.append("--gen-opt-lto") + if opts.define_macro is not None: + _process_define_macro(options, opts.define_macro) + if opts.undefine_macro is not None: + if isinstance(opts.undefine_macro, str): + options.append(f"--undefine-macro={opts.undefine_macro}") + elif is_sequence(opts.undefine_macro): + for macro in opts.undefine_macro: + options.append(f"--undefine-macro={macro}") + if opts.include_path is not None: + if isinstance(opts.include_path, str): + options.append(f"--include-path={opts.include_path}") + elif is_sequence(opts.include_path): + for path in opts.include_path: + options.append(f"--include-path={path}") + if opts.pre_include is not None: + if isinstance(opts.pre_include, str): + 
options.append(f"--pre-include={opts.pre_include}") + elif is_sequence(opts.pre_include): + for header in opts.pre_include: + options.append(f"--pre-include={header}") + if opts.no_source_include is not None and opts.no_source_include: + options.append("--no-source-include") + if opts.std is not None: + options.append(f"--std={opts.std}") + if opts.builtin_move_forward is not None: + options.append(f"--builtin-move-forward={_handle_boolean_option(opts.builtin_move_forward)}") + if opts.builtin_initializer_list is not None: + options.append(f"--builtin-initializer-list={_handle_boolean_option(opts.builtin_initializer_list)}") + if opts.disable_warnings is not None and opts.disable_warnings: + options.append("--disable-warnings") + if opts.restrict is not None and opts.restrict: + options.append("--restrict") + if opts.device_as_default_execution_space is not None and opts.device_as_default_execution_space: + options.append("--device-as-default-execution-space") + if opts.device_int128 is not None and opts.device_int128: + options.append("--device-int128") + if opts.device_float128 is not None and opts.device_float128: + options.append("--device-float128") + if opts.optimization_info is not None: + options.append(f"--optimization-info={opts.optimization_info}") + if opts.no_display_error_number is not None and opts.no_display_error_number: + options.append("--no-display-error-number") + if opts.diag_error is not None: + if isinstance(opts.diag_error, int): + options.append(f"--diag-error={opts.diag_error}") + elif is_sequence(opts.diag_error): + for error in opts.diag_error: + options.append(f"--diag-error={error}") + if opts.diag_suppress is not None: + if isinstance(opts.diag_suppress, int): + options.append(f"--diag-suppress={opts.diag_suppress}") + elif is_sequence(opts.diag_suppress): + for suppress in opts.diag_suppress: + options.append(f"--diag-suppress={suppress}") + if opts.diag_warn is not None: + if isinstance(opts.diag_warn, int): + options.append(f"--diag-warn={opts.diag_warn}") + elif is_sequence(opts.diag_warn): + for w in opts.diag_warn: + options.append(f"--diag-warn={w}") + if opts.brief_diagnostics is not None: + options.append(f"--brief-diagnostics={_handle_boolean_option(opts.brief_diagnostics)}") + if opts.time is not None: + options.append(f"--time={opts.time}") + if opts.split_compile is not None: + options.append(f"--split-compile={opts.split_compile}") + if opts.fdevice_syntax_only is not None and opts.fdevice_syntax_only: + options.append("--fdevice-syntax-only") + if opts.minimal is not None and opts.minimal: + options.append("--minimal") + if opts.no_cache is not None and opts.no_cache: + options.append("--no-cache") + if opts.fdevice_time_trace is not None: + options.append(f"--fdevice-time-trace={opts.fdevice_time_trace}") + if opts.frandom_seed is not None: + options.append(f"--frandom-seed={opts.frandom_seed}") + if opts.ofast_compile is not None: + options.append(f"--Ofast-compile={opts.ofast_compile}") + # PCH options (CUDA 12.8+) + if opts.pch is not None and opts.pch: + options.append("--pch") + if opts.create_pch is not None: + options.append(f"--create-pch={opts.create_pch}") + if opts.use_pch is not None: + options.append(f"--use-pch={opts.use_pch}") + if opts.pch_dir is not None: + options.append(f"--pch-dir={opts.pch_dir}") + if opts.pch_verbose is not None: + options.append(f"--pch-verbose={_handle_boolean_option(opts.pch_verbose)}") + if opts.pch_messages is not None: + options.append(f"--pch-messages={_handle_boolean_option(opts.pch_messages)}") 
+ if opts.instantiate_templates_in_pch is not None: + options.append( + f"--instantiate-templates-in-pch={_handle_boolean_option(opts.instantiate_templates_in_pch)}" + ) + if opts.numba_debug: + options.append("--numba-debug") + return [o.encode() for o in options] + + +cdef inline object _prepare_nvvm_options_impl(object opts, bint as_bytes): + """Build NVVM-specific compiler options.""" + options = [] + + # Options supported by NVVM + assert opts.arch is not None + arch = opts.arch + if arch.startswith("sm_"): + arch = f"compute_{arch[3:]}" + options.append(f"-arch={arch}") + if opts.debug is not None and opts.debug: + options.append("-g") + if opts.device_code_optimize is False: + options.append("-opt=0") + elif opts.device_code_optimize is True: + options.append("-opt=3") + # NVVM uses 0/1 instead of true/false for boolean options + if opts.ftz is not None: + options.append(f"-ftz={'1' if opts.ftz else '0'}") + if opts.prec_sqrt is not None: + options.append(f"-prec-sqrt={'1' if opts.prec_sqrt else '0'}") + if opts.prec_div is not None: + options.append(f"-prec-div={'1' if opts.prec_div else '0'}") + if opts.fma is not None: + options.append(f"-fma={'1' if opts.fma else '0'}") + + # Check for unsupported options and raise error if they are set + unsupported = [] + if opts.relocatable_device_code is not None: + unsupported.append("relocatable_device_code") + if opts.extensible_whole_program is not None and opts.extensible_whole_program: + unsupported.append("extensible_whole_program") + if opts.lineinfo is not None and opts.lineinfo: + unsupported.append("lineinfo") + if opts.ptxas_options is not None: + unsupported.append("ptxas_options") + if opts.max_register_count is not None: + unsupported.append("max_register_count") + if opts.use_fast_math is not None and opts.use_fast_math: + unsupported.append("use_fast_math") + if opts.extra_device_vectorization is not None and opts.extra_device_vectorization: + unsupported.append("extra_device_vectorization") + if opts.gen_opt_lto is not None and opts.gen_opt_lto: + unsupported.append("gen_opt_lto") + if opts.define_macro is not None: + unsupported.append("define_macro") + if opts.undefine_macro is not None: + unsupported.append("undefine_macro") + if opts.include_path is not None: + unsupported.append("include_path") + if opts.pre_include is not None: + unsupported.append("pre_include") + if opts.no_source_include is not None and opts.no_source_include: + unsupported.append("no_source_include") + if opts.std is not None: + unsupported.append("std") + if opts.builtin_move_forward is not None: + unsupported.append("builtin_move_forward") + if opts.builtin_initializer_list is not None: + unsupported.append("builtin_initializer_list") + if opts.disable_warnings is not None and opts.disable_warnings: + unsupported.append("disable_warnings") + if opts.restrict is not None and opts.restrict: + unsupported.append("restrict") + if opts.device_as_default_execution_space is not None and opts.device_as_default_execution_space: + unsupported.append("device_as_default_execution_space") + if opts.device_int128 is not None and opts.device_int128: + unsupported.append("device_int128") + if opts.optimization_info is not None: + unsupported.append("optimization_info") + if opts.no_display_error_number is not None and opts.no_display_error_number: + unsupported.append("no_display_error_number") + if opts.diag_error is not None: + unsupported.append("diag_error") + if opts.diag_suppress is not None: + unsupported.append("diag_suppress") + if opts.diag_warn is 
not None: + unsupported.append("diag_warn") + if opts.brief_diagnostics is not None: + unsupported.append("brief_diagnostics") + if opts.time is not None: + unsupported.append("time") + if opts.split_compile is not None: + unsupported.append("split_compile") + if opts.fdevice_syntax_only is not None and opts.fdevice_syntax_only: + unsupported.append("fdevice_syntax_only") + if opts.minimal is not None and opts.minimal: + unsupported.append("minimal") + if opts.numba_debug is not None and opts.numba_debug: + unsupported.append("numba_debug") + if unsupported: + raise CUDAError(f"The following options are not supported by NVVM backend: {', '.join(unsupported)}") + + if as_bytes: + return [o.encode() for o in options] + else: + return options diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index 5f08172909..d573862d16 100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -8,6 +8,8 @@ from libc.stdint cimport intptr_t from libcpp.memory cimport shared_ptr from cuda.bindings cimport cydriver +from cuda.bindings cimport cynvrtc +from cuda.bindings cimport cynvvm # ============================================================================= @@ -23,6 +25,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": ctypedef shared_ptr[const cydriver.CUdeviceptr] DevicePtrHandle ctypedef shared_ptr[const cydriver.CUlibrary] LibraryHandle ctypedef shared_ptr[const cydriver.CUkernel] KernelHandle + ctypedef shared_ptr[const cynvrtc.nvrtcProgram] NvrtcProgramHandle + ctypedef shared_ptr[const cynvvm.nvvmProgram] NvvmProgramHandle # as_cu() - extract the raw CUDA handle (inline C++) cydriver.CUcontext as_cu(ContextHandle h) noexcept nogil @@ -32,6 +36,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUdeviceptr as_cu(DevicePtrHandle h) noexcept nogil cydriver.CUlibrary as_cu(LibraryHandle h) noexcept nogil cydriver.CUkernel as_cu(KernelHandle h) noexcept nogil + cynvrtc.nvrtcProgram as_cu(NvrtcProgramHandle h) noexcept nogil + cynvvm.nvvmProgram as_cu(NvvmProgramHandle h) noexcept nogil # as_intptr() - extract handle as intptr_t for Python interop (inline C++) intptr_t as_intptr(ContextHandle h) noexcept nogil @@ -41,8 +47,10 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": intptr_t as_intptr(DevicePtrHandle h) noexcept nogil intptr_t as_intptr(LibraryHandle h) noexcept nogil intptr_t as_intptr(KernelHandle h) noexcept nogil + intptr_t as_intptr(NvrtcProgramHandle h) noexcept nogil + intptr_t as_intptr(NvvmProgramHandle h) noexcept nogil - # as_py() - convert handle to Python driver wrapper object (inline C++; requires GIL) + # as_py() - convert handle to Python wrapper object (inline C++; requires GIL) object as_py(ContextHandle h) object as_py(StreamHandle h) object as_py(EventHandle h) @@ -50,6 +58,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": object as_py(DevicePtrHandle h) object as_py(LibraryHandle h) object as_py(KernelHandle h) + object as_py(NvrtcProgramHandle h) + object as_py(NvvmProgramHandle h) # ============================================================================= @@ -112,3 +122,11 @@ cdef LibraryHandle create_library_handle_ref(cydriver.CUlibrary library) except+ cdef KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* name) except+ nogil cdef KernelHandle create_kernel_handle_ref( cydriver.CUkernel kernel, const LibraryHandle& h_library) except+ nogil + +# 
NVRTC Program handles +cdef NvrtcProgramHandle create_nvrtc_program_handle(cynvrtc.nvrtcProgram prog) except+ nogil +cdef NvrtcProgramHandle create_nvrtc_program_handle_ref(cynvrtc.nvrtcProgram prog) except+ nogil + +# NVVM Program handles +cdef NvvmProgramHandle create_nvvm_program_handle(cynvvm.nvvmProgram prog) except+ nogil +cdef NvvmProgramHandle create_nvvm_program_handle_ref(cynvvm.nvvmProgram prog) except+ nogil diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index 0d3e732a4f..2652d4448e 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -14,6 +14,8 @@ from cpython.pycapsule cimport PyCapsule_GetName, PyCapsule_GetPointer from libc.stddef cimport size_t from cuda.bindings cimport cydriver +from cuda.bindings cimport cynvrtc +from cuda.bindings cimport cynvvm from ._resource_handles cimport ( ContextHandle, @@ -23,9 +25,13 @@ from ._resource_handles cimport ( DevicePtrHandle, LibraryHandle, KernelHandle, + NvrtcProgramHandle, + NvvmProgramHandle, ) import cuda.bindings.cydriver as cydriver +import cuda.bindings.cynvrtc as cynvrtc +import cuda.bindings.cynvvm as cynvvm # ============================================================================= # C++ function declarations (non-inline, implemented in resource_handles.cpp) @@ -107,6 +113,18 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": KernelHandle create_kernel_handle_ref "cuda_core::create_kernel_handle_ref" ( cydriver.CUkernel kernel, const LibraryHandle& h_library) except+ nogil + # NVRTC Program handles + NvrtcProgramHandle create_nvrtc_program_handle "cuda_core::create_nvrtc_program_handle" ( + cynvrtc.nvrtcProgram prog) except+ nogil + NvrtcProgramHandle create_nvrtc_program_handle_ref "cuda_core::create_nvrtc_program_handle_ref" ( + cynvrtc.nvrtcProgram prog) except+ nogil + + # NVVM Program handles + NvvmProgramHandle create_nvvm_program_handle "cuda_core::create_nvvm_program_handle" ( + cynvvm.nvvmProgram prog) except+ nogil + NvvmProgramHandle create_nvvm_program_handle_ref "cuda_core::create_nvvm_program_handle_ref" ( + cynvvm.nvvmProgram prog) except+ nogil + # ============================================================================= # CUDA Driver API capsule @@ -174,6 +192,12 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": void* p_cuLibraryUnload "reinterpret_cast<void*&>(cuda_core::p_cuLibraryUnload)" void* p_cuLibraryGetKernel "reinterpret_cast<void*&>(cuda_core::p_cuLibraryGetKernel)" + # NVRTC + void* p_nvrtcDestroyProgram "reinterpret_cast<void*&>(cuda_core::p_nvrtcDestroyProgram)" + + # NVVM + void* p_nvvmDestroyProgram "reinterpret_cast<void*&>(cuda_core::p_nvvmDestroyProgram)" + # Initialize driver function pointers from cydriver.__pyx_capi__ at module load cdef void* _get_driver_fn(str name): capsule = cydriver.__pyx_capi__[name] return PyCapsule_GetPointer(capsule, PyCapsule_GetName(capsule)) @@ -223,3 +247,26 @@ p_cuLibraryLoadFromFile = _get_driver_fn("cuLibraryLoadFromFile") p_cuLibraryLoadData = _get_driver_fn("cuLibraryLoadData") p_cuLibraryUnload = _get_driver_fn("cuLibraryUnload") p_cuLibraryGetKernel = _get_driver_fn("cuLibraryGetKernel") + +# ============================================================================= +# NVRTC function pointer initialization +# ============================================================================= + +cdef void* _get_nvrtc_fn(str name): + capsule = cynvrtc.__pyx_capi__[name] + return PyCapsule_GetPointer(capsule, PyCapsule_GetName(capsule)) + +p_nvrtcDestroyProgram = _get_nvrtc_fn("nvrtcDestroyProgram") + +# 
============================================================================= +# NVVM function pointer initialization +# +# NVVM may not be available at runtime, so we handle missing function pointers +# gracefully. The C++ deleter checks for null before calling. +# ============================================================================= + +cdef void* _get_nvvm_fn(str name): + capsule = cynvvm.__pyx_capi__[name] + return PyCapsule_GetPointer(capsule, PyCapsule_GetName(capsule)) + +p_nvvmDestroyProgram = _get_nvvm_fn("nvvmDestroyProgram") diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pxd b/cuda_core/cuda/core/_utils/cuda_utils.pxd index 9b5044beda..339b485682 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pxd +++ b/cuda_core/cuda/core/_utils/cuda_utils.pxd @@ -6,11 +6,7 @@ cimport cpython from cpython.object cimport PyObject from libc.stdint cimport int64_t, int32_t -from cuda.bindings cimport cydriver - - -ctypedef fused supported_error_type: - cydriver.CUresult +from cuda.bindings cimport cydriver, cynvrtc, cynvvm ctypedef fused integer_t: @@ -22,7 +18,9 @@ ctypedef fused integer_t: cdef const cydriver.CUcontext CU_CONTEXT_INVALID = <cydriver.CUcontext>(-2) -cdef int HANDLE_RETURN(supported_error_type err) except?-1 nogil +cdef int HANDLE_RETURN(cydriver.CUresult err) except?-1 nogil +cdef int HANDLE_RETURN_NVRTC(cynvrtc.nvrtcProgram prog, cynvrtc.nvrtcResult err) except?-1 nogil +cdef int HANDLE_RETURN_NVVM(cynvvm.nvvmProgram prog, cynvvm.nvvmResult err) except?-1 nogil # TODO: stop exposing these within the codebase? diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pyx b/cuda_core/cuda/core/_utils/cuda_utils.pyx index c7f867a0d5..a3c49d8e27 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pyx +++ b/cuda_core/cuda/core/_utils/cuda_utils.pyx @@ -20,6 +20,8 @@ except ImportError: from cuda import cudart as runtime from cuda import nvrtc +from cuda.bindings cimport cynvrtc, cynvvm + from cuda.core._utils.driver_cu_result_explanations import DRIVER_CU_RESULT_EXPLANATIONS from cuda.core._utils.runtime_cuda_error_explanations import RUNTIME_CUDA_ERROR_EXPLANATIONS @@ -32,6 +34,10 @@ class NVRTCError(CUDAError): pass +class NVVMError(CUDAError): + pass + + ComputeCapability = namedtuple("ComputeCapability", ("major", "minor")) @@ -57,10 +63,72 @@ def _reduce_3_tuple(t: tuple): return t[0] * t[1] * t[2] -cdef int HANDLE_RETURN(supported_error_type err) except?-1 nogil: - if supported_error_type is cydriver.CUresult: - if err != cydriver.CUresult.CUDA_SUCCESS: - return _check_driver_error(err) +cdef int HANDLE_RETURN(cydriver.CUresult err) except?-1 nogil: + if err != cydriver.CUresult.CUDA_SUCCESS: + return _check_driver_error(err) + return 0 + + +cdef int HANDLE_RETURN_NVRTC(cynvrtc.nvrtcProgram prog, cynvrtc.nvrtcResult err) except?-1 nogil: + """Handle NVRTC result codes, raising NVRTCError with program log on failure.""" + if err == cynvrtc.nvrtcResult.NVRTC_SUCCESS: + return 0 + # Get error string (can be called without GIL) + cdef const char* err_str = cynvrtc.nvrtcGetErrorString(err) + # Get program log size for additional context + cdef size_t logsize = 0 + if prog != NULL: + cynvrtc.nvrtcGetProgramLogSize(prog, &logsize) + # Need GIL for Python string operations and exception raising + with gil: + _raise_nvrtc_error(err, err_str, prog, logsize) + return -1 # Never reached, but satisfies return type + + +cdef int _raise_nvrtc_error(cynvrtc.nvrtcResult err, const char* err_str, + cynvrtc.nvrtcProgram prog, size_t logsize) except -1: + """Helper to raise NVRTCError with program log (requires 
GIL).""" + cdef bytes log_bytes + cdef str log_str = "" + if logsize > 1 and prog != NULL: + log_bytes = b" " * logsize + if cynvrtc.nvrtcGetProgramLog(prog, log_bytes) == cynvrtc.nvrtcResult.NVRTC_SUCCESS: + log_str = log_bytes.decode("utf-8", errors="backslashreplace") + err_msg = f"{err}: {err_str.decode()}" if err_str != NULL else f"NVRTC error {err}" + if log_str: + err_msg += f", compilation log:\n\n{log_str}" + raise NVRTCError(err_msg) + + +cdef int HANDLE_RETURN_NVVM(cynvvm.nvvmProgram prog, cynvvm.nvvmResult err) except?-1 nogil: + """Handle NVVM result codes, raising NVVMError with program log on failure.""" + if err == cynvvm.nvvmResult.NVVM_SUCCESS: + return 0 + # Get error string (can be called without GIL) + cdef const char* err_str = cynvvm.nvvmGetErrorString(err) + # Get program log size for additional context + cdef size_t logsize = 0 + if prog != NULL: + cynvvm.nvvmGetProgramLogSize(prog, &logsize) + # Need GIL for Python string operations and exception raising + with gil: + _raise_nvvm_error(err, err_str, prog, logsize) + return -1 # Never reached, but satisfies return type + + +cdef int _raise_nvvm_error(cynvvm.nvvmResult err, const char* err_str, + cynvvm.nvvmProgram prog, size_t logsize) except -1: + """Helper to raise NVVMError with program log (requires GIL).""" + cdef bytes log_bytes + cdef str log_str = "" + if logsize > 1 and prog != NULL: + log_bytes = b" " * logsize + if cynvvm.nvvmGetProgramLog(prog, log_bytes) == cynvvm.nvvmResult.NVVM_SUCCESS: + log_str = log_bytes.decode("utf-8", errors="backslashreplace") + err_msg = f"{err}: {err_str.decode()}" if err_str != NULL else f"NVVM error {err}" + if log_str: + err_msg += f", compilation log:\n\n{log_str}" + raise NVVMError(err_msg) cdef object _RUNTIME_SUCCESS = runtime.cudaError_t.cudaSuccess diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index c0760cee45..c2c2ee1925 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -104,6 +104,8 @@ def test_get_kernel(init_cuda): kernel = object_code.get_kernel("ABC") assert object_code.handle is not None assert kernel.handle is not None + # Also works with bytes + assert object_code.get_kernel(b"ABC").handle is not None @pytest.mark.parametrize( @@ -146,7 +148,7 @@ def test_object_code_load_ptx(get_saxpy_kernel_ptx): sym_map = mod.symbol_mapping mod_obj = ObjectCode.from_ptx(ptx, symbol_mapping=sym_map) assert mod.code == ptx - if not Program._can_load_generated_ptx(): + if not Program.driver_can_load_nvrtc_ptx_output(): pytest.skip("PTX version too new for current driver") mod_obj.get_kernel("saxpy") # force loading @@ -160,7 +162,7 @@ def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path): mod_obj = ObjectCode.from_ptx(str(ptx_file), symbol_mapping=sym_map) assert mod_obj.code == str(ptx_file) assert mod_obj.code_type == "ptx" - if not Program._can_load_generated_ptx(): + if not Program.driver_can_load_nvrtc_ptx_output(): pytest.skip("PTX version too new for current driver") mod_obj.get_kernel("saxpy") # force loading diff --git a/cuda_core/tests/test_object_protocols.py b/cuda_core/tests/test_object_protocols.py index 69a83f6514..a5237b9c64 100644 --- a/cuda_core/tests/test_object_protocols.py +++ b/cuda_core/tests/test_object_protocols.py @@ -56,16 +56,86 @@ def sample_launch_config(): @pytest.fixture -def sample_object_code(init_cuda): - """A sample ObjectCode object.""" +def sample_kernel(sample_object_code_cubin): + """A sample Kernel object.""" + return 
sample_object_code_cubin.get_kernel("test_kernel") + + +# ============================================================================= +# Fixtures - ObjectCode variations (by code_type) +# ============================================================================= + + +@pytest.fixture +def sample_object_code_cubin(init_cuda): + """An ObjectCode compiled to cubin.""" prog = Program('extern "C" __global__ void test_kernel() {}', "c++") return prog.compile("cubin") @pytest.fixture -def sample_kernel(sample_object_code): - """A sample Kernel object.""" - return sample_object_code.get_kernel("test_kernel") +def sample_object_code_ptx(init_cuda): + """An ObjectCode compiled to PTX.""" + if not Program.driver_can_load_nvrtc_ptx_output(): + pytest.skip("PTX version too new for current driver") + prog = Program('extern "C" __global__ void test_kernel() {}', "c++") + return prog.compile("ptx") + + +@pytest.fixture +def sample_object_code_ltoir(init_cuda): + """An ObjectCode compiled to LTOIR.""" + prog = Program('extern "C" __global__ void test_kernel() {}', "c++") + return prog.compile("ltoir") + + +# ============================================================================= +# Fixtures - Program variations (by backend) +# ============================================================================= + + +@pytest.fixture +def sample_program_nvrtc(init_cuda): + """A Program using NVRTC backend (C++ code).""" + return Program('extern "C" __global__ void k() {}', "c++") + + +@pytest.fixture +def sample_program_ptx(init_cuda): + """A Program using linker backend (PTX code).""" + # First compile C++ to PTX, then create a Program from PTX + if not Program.driver_can_load_nvrtc_ptx_output(): + pytest.skip("PTX version too new for current driver") + prog = Program('extern "C" __global__ void k() {}', "c++") + obj = prog.compile("ptx") + ptx_code = obj.code.decode() if isinstance(obj.code, bytes) else obj.code + return Program(ptx_code, "ptx") + + +@pytest.fixture +def sample_program_nvvm(init_cuda): + """A Program using NVVM backend (NVVM IR code).""" + # Minimal NVVM IR that declares a kernel + # fmt: off + nvvm_ir = ( + 'target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-' + 'i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"\n' + 'target triple = "nvptx64-nvidia-cuda"\n' + "\n" + "define void @test_kernel() {\n" + " ret void\n" + "}\n" + "\n" + "!nvvm.annotations = !{!0}\n" + '!0 = !{void ()* @test_kernel, !"kernel", i32 1}\n' + ) + # fmt: on + try: + return Program(nvvm_ir, "nvvm") + except RuntimeError as e: + if "NVVM" in str(e): + pytest.skip("NVVM not available") + raise # ============================================================================= @@ -131,18 +201,43 @@ def sample_kernel_alt(sample_object_code_alt): # Type groupings # ============================================================================= -# All types that should support weak references -API_TYPES = [ +# Types with __hash__ support +HASH_TYPES = [ "sample_device", "sample_stream", "sample_event", "sample_context", "sample_buffer", "sample_launch_config", - "sample_object_code", + "sample_object_code_cubin", "sample_kernel", ] +# Types with __eq__ support +EQ_TYPES = [ + "sample_device", + "sample_stream", + "sample_event", + "sample_context", + "sample_buffer", + "sample_launch_config", + "sample_object_code_cubin", + "sample_kernel", +] + +# Types with __weakref__ support +WEAKREF_TYPES = [ + "sample_device", + "sample_stream", + "sample_event", + "sample_context", + 
"sample_buffer", + "sample_launch_config", + "sample_object_code_cubin", + "sample_kernel", + "sample_program_nvrtc", +] + # Pairs of distinct objects of the same type (for inequality testing) # Device and Context pairs require multi-GPU and will skip on single-GPU machines SAME_TYPE_PAIRS = [ @@ -152,7 +247,7 @@ def sample_kernel_alt(sample_object_code_alt): ("sample_context", "sample_context_alt"), ("sample_buffer", "sample_buffer_alt"), ("sample_launch_config", "sample_launch_config_alt"), - ("sample_object_code", "sample_object_code_alt"), + ("sample_object_code_cubin", "sample_object_code_alt"), ("sample_kernel", "sample_kernel_alt"), ] @@ -163,8 +258,13 @@ def sample_kernel_alt(sample_object_code_alt): ("sample_kernel", lambda k: Kernel.from_handle(int(k.handle))), ] +# Derived type groupings for collection tests +DICT_KEY_TYPES = sorted(set(HASH_TYPES) & set(EQ_TYPES)) +WEAK_KEY_TYPES = sorted(set(HASH_TYPES) & set(EQ_TYPES) & set(WEAKREF_TYPES)) + # Pairs of (fixture_name, regex_pattern) for repr format validation REPR_PATTERNS = [ + # Core types ("sample_device", r""), ("sample_stream", r""), ("sample_event", r""), @@ -175,8 +275,15 @@ def sample_kernel_alt(sample_object_code_alt): r"LaunchConfig\(grid=\(\d+, \d+, \d+\), cluster=.+, block=\(\d+, \d+, \d+\), " r"shmem_size=\d+, cooperative_launch=(?:True|False)\)", ), - ("sample_object_code", r""), ("sample_kernel", r""), + # ObjectCode variations (by code_type) + ("sample_object_code_cubin", r""), + ("sample_object_code_ptx", r""), + ("sample_object_code_ltoir", r""), + # Program variations (by backend) + ("sample_program_nvrtc", r""), + ("sample_program_ptx", r""), + ("sample_program_nvvm", r""), ] @@ -185,7 +292,7 @@ def sample_kernel_alt(sample_object_code_alt): # ============================================================================= -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", WEAKREF_TYPES) def test_weakref_supported(fixture_name, request): """Object supports weak references.""" obj = request.getfixturevalue(fixture_name) @@ -198,7 +305,7 @@ def test_weakref_supported(fixture_name, request): # ============================================================================= -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", HASH_TYPES) def test_hash_consistency(fixture_name, request): """Hash is consistent across multiple calls.""" obj = request.getfixturevalue(fixture_name) @@ -213,7 +320,7 @@ def test_hash_distinct_same_type(a_name, b_name, request): assert hash(obj_a) != hash(obj_b) # extremely unlikely -@pytest.mark.parametrize("a_name,b_name", itertools.combinations(API_TYPES, 2)) +@pytest.mark.parametrize("a_name,b_name", itertools.combinations(HASH_TYPES, 2)) def test_hash_distinct_cross_type(a_name, b_name, request): """Distinct objects of different types have different hashes.""" obj_a = request.getfixturevalue(a_name) @@ -226,7 +333,7 @@ def test_hash_distinct_cross_type(a_name, b_name, request): # ============================================================================= -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", EQ_TYPES) def test_equality_basic(fixture_name, request): """Object equality: reflexive, not equal to None or other types.""" obj = request.getfixturevalue(fixture_name) @@ -237,7 +344,7 @@ def test_equality_basic(fixture_name, request): assert obj != obj.handle -@pytest.mark.parametrize("a_name,b_name", itertools.combinations(API_TYPES, 2)) 
+@pytest.mark.parametrize("a_name,b_name", itertools.combinations(EQ_TYPES, 2)) def test_no_cross_type_equality(a_name, b_name, request): """No two distinct objects of different types should compare equal.""" obj_a = request.getfixturevalue(a_name) @@ -268,7 +375,7 @@ def test_equality_same_handle(fixture_name, copy_fn, request): # ============================================================================= -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", DICT_KEY_TYPES) def test_usable_as_dict_key(fixture_name, request): """Object can be used as a dictionary key.""" obj = request.getfixturevalue(fixture_name) @@ -277,7 +384,7 @@ def test_usable_as_dict_key(fixture_name, request): assert obj in d -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", DICT_KEY_TYPES) def test_usable_in_set(fixture_name, request): """Object can be added to a set.""" obj = request.getfixturevalue(fixture_name) @@ -285,7 +392,7 @@ def test_usable_in_set(fixture_name, request): assert obj in s -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", WEAKREF_TYPES) def test_usable_in_weak_value_dict(fixture_name, request): """Object can be used as a WeakValueDictionary value.""" obj = request.getfixturevalue(fixture_name) @@ -294,7 +401,7 @@ def test_usable_in_weak_value_dict(fixture_name, request): assert wvd["key"] is obj -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", WEAK_KEY_TYPES) def test_usable_in_weak_key_dict(fixture_name, request): """Object can be used as a WeakKeyDictionary key.""" obj = request.getfixturevalue(fixture_name) @@ -303,7 +410,7 @@ def test_usable_in_weak_key_dict(fixture_name, request): assert wkd[obj] == "value" -@pytest.mark.parametrize("fixture_name", API_TYPES) +@pytest.mark.parametrize("fixture_name", WEAK_KEY_TYPES) def test_usable_in_weak_set(fixture_name, request): """Object can be added to a WeakSet.""" obj = request.getfixturevalue(fixture_name) diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 259e6a9c98..abf29ae1f3 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -284,7 +284,6 @@ def test_cpp_program_with_various_options(init_cuda, options): assert program.backend == "NVRTC" program.compile("ptx") program.close() - assert program.handle is None @pytest.mark.skipif( @@ -299,7 +298,6 @@ def test_cpp_program_with_trace_option(init_cuda, tmp_path): assert program.backend == "NVRTC" program.compile("ptx") program.close() - assert program.handle is None @pytest.mark.skipif((_get_nvrtc_version_for_tests() or 0) < 12800, reason="PCH requires NVRTC >= 12.8") @@ -314,7 +312,6 @@ def test_cpp_program_with_pch_options(init_cuda, tmp_path): assert program.backend == "NVRTC" program.compile("ptx") program.close() - assert program.handle is None options = [ @@ -339,7 +336,6 @@ def test_ptx_program_with_various_options(init_cuda, ptx_code_object, options): assert program.backend == ("driver" if is_culink_backend else "nvJitLink") program.compile("cubin") program.close() - assert program.handle is None def test_program_init_valid_code_type(): @@ -409,7 +405,8 @@ def test_program_close(): code = 'extern "C" __global__ void my_kernel() {}' program = Program(code, "c++") program.close() - assert program.handle is None + # close() is idempotent + program.close() @nvvm_available @@ -441,7 +438,7 @@ def test_nvvm_program_creation_compilation(nvvm_ir): def 
test_nvvm_compile_invalid_target(nvvm_ir): """Test that NVVM programs reject invalid compilation targets""" program = Program(nvvm_ir, "nvvm") - with pytest.raises(ValueError, match='NVVM backend only supports target_type="ptx"'): + with pytest.raises(ValueError, match='Unsupported target_type="cubin" for NVVM'): program.compile("cubin") program.close()
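
Reviewer note, not part of the patch: a minimal end-to-end sketch of the "nvvm" code type wired up above. The import path (cuda.core.experimental) and the presence of libnvvm with cuda-bindings >= 12.9 are assumptions, not guaranteed by this diff; the IR string is copied from the sample_program_nvvm test fixture.

```python
# Sketch only. Assumes cuda.core.experimental exports Device, Program, and
# ProgramOptions, and that a CUDA 12.9+ cuda-bindings with libnvvm is installed.
from cuda.core.experimental import Device, Program, ProgramOptions

Device(0).set_current()  # only needed when ProgramOptions.arch is left to default

nvvm_ir = (
    'target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-'
    'i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"\n'
    'target triple = "nvptx64-nvidia-cuda"\n'
    "define void @test_kernel() {\n"
    "  ret void\n"
    "}\n"
    "!nvvm.annotations = !{!0}\n"
    '!0 = !{void ()* @test_kernel, !"kernel", i32 1}\n'
)

opts = ProgramOptions(arch="sm_80", fma=True)         # only NVVM-supported options set
assert b"-arch=compute_80" in opts.as_bytes("nvvm")   # sm_ prefix is rewritten for NVVM
assert b"-fma=1" in opts.as_bytes("nvvm")             # NVVM booleans are 0/1
assert b"-gen-lto" in opts.as_bytes("nvvm", "ltoir")  # appended for LTOIR targets

prog = Program(nvvm_ir, "nvvm", options=opts)  # nvvmCreateProgram + nvvmAddModuleToProgram
ptx = prog.compile("ptx")                      # nvvmVerifyProgram + nvvmCompileProgram
# prog.compile("cubin") would raise ValueError:
#   Unsupported target_type="cubin" for NVVM (supported: 'ptx', 'ltoir')
prog.close()                                   # idempotent, per the updated tests
```

Setting an NVVM-incompatible option (e.g. ptxas_options) on the same ProgramOptions makes as_bytes("nvvm") raise CUDAError, matching the unsupported-option check in _prepare_nvvm_options_impl.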