@@ -7,12 +7,9 @@ from __future__ import annotations
77import weakref
88from contextlib import contextmanager
99from dataclasses import dataclass
10- from typing import TYPE_CHECKING, Union
10+ from typing import Union
1111from warnings import warn
1212
13- if TYPE_CHECKING:
14- import cuda.bindings
15-
1613from cuda.core._device import Device
1714from cuda.core._linker import Linker, LinkerHandleT, LinkerOptions
1815from cuda.core._module import ObjectCode
@@ -689,7 +686,7 @@ cdef class Program:
689686 elif code_type == " ptx" :
690687 assert_type(code, str )
691688 self ._linker = Linker(
692- ObjectCode._init(code.encode(), code_type), options = self ._translate_program_options (options)
689+ ObjectCode._init(code.encode(), code_type), options = Program_translate_options (options)
693690 )
694691 self ._backend = self ._linker.backend
695692
@@ -711,36 +708,12 @@ cdef class Program:
711708 assert code_type not in supported_code_types, f" {code_type=}"
712709 raise RuntimeError (f" Unsupported {code_type=} ({supported_code_types=})" )
713710
714- def _translate_program_options (self , options: ProgramOptions ) -> LinkerOptions:
715- return LinkerOptions(
716- name = options.name,
717- arch = options.arch,
718- max_register_count = options.max_register_count,
719- time = options.time,
720- link_time_optimization = options.link_time_optimization,
721- debug = options.debug,
722- lineinfo = options.lineinfo,
723- ftz = options.ftz,
724- prec_div = options.prec_div,
725- prec_sqrt = options.prec_sqrt,
726- fma = options.fma,
727- split_compile = options.split_compile,
728- ptxas_options = options.ptxas_options,
729- no_cache = options.no_cache,
730- )
731-
732711 def close (self ):
733712 """ Destroy this program."""
734713 if self ._linker:
735714 self ._linker.close()
736715 self ._mnff.close()
737716
738- @staticmethod
739- def _can_load_generated_ptx ():
740- driver_ver = handle_return(driver.cuDriverGetVersion())
741- nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion())
742- return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver
743-
744717 def compile (self , target_type , name_expressions = (), logs = None ):
745718 """ Compile the program with a specific compilation type.
746719
@@ -768,7 +741,7 @@ cdef class Program:
768741 raise ValueError (f' Unsupported target_type="{target_type}" ({supported_target_types=})' )
769742
770743 if self ._backend == " NVRTC" :
771- if target_type == " ptx" and not self ._can_load_generated_ptx ():
744+ if target_type == " ptx" and not Program_can_load_generated_ptx ():
772745 warn(
773746 " The CUDA driver version is older than the backend version. "
774747 " The generated ptx will not be loadable by the current driver." ,
@@ -862,3 +835,35 @@ cdef class Program:
862835
863836 def __repr__(self ) -> str:
864837 return f"<Program backend = ' {self._backend}' > "
838+
839+
840+ # =============================================================================
841+ # Helper functions
842+ # =============================================================================
843+
844+
845+ cdef bint Program_can_load_generated_ptx():
846+ """ Check if the driver can load PTX generated by the current NVRTC version."""
847+ driver_ver = handle_return(driver.cuDriverGetVersion())
848+ nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion())
849+ return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver
850+
851+
852+ cdef object Program_translate_options(object options):
853+ """ Translate ProgramOptions to LinkerOptions for PTX compilation."""
854+ return LinkerOptions(
855+ name = options.name,
856+ arch = options.arch,
857+ max_register_count = options.max_register_count,
858+ time = options.time,
859+ link_time_optimization = options.link_time_optimization,
860+ debug = options.debug,
861+ lineinfo = options.lineinfo,
862+ ftz = options.ftz,
863+ prec_div = options.prec_div,
864+ prec_sqrt = options.prec_sqrt,
865+ fma = options.fma,
866+ split_compile = options.split_compile,
867+ ptxas_options = options.ptxas_options,
868+ no_cache = options.no_cache,
869+ )
0 commit comments