Skip to content

Commit 2f47e9e

Browse files
committed
Extract Program helpers to module-level cdef functions
- Move _translate_program_options to Program_translate_options (cdef) - Move _can_load_generated_ptx to Program_can_load_generated_ptx (cdef) - Remove unused TYPE_CHECKING import block - Follow _memory/_buffer.pyx helper function patterns
1 parent e17027a commit 2f47e9e

File tree

1 file changed

+35
-30
lines changed

1 file changed

+35
-30
lines changed

cuda_core/cuda/core/_program.pyx

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,9 @@ from __future__ import annotations
77
import weakref
88
from contextlib import contextmanager
99
from dataclasses import dataclass
10-
from typing import TYPE_CHECKING, Union
10+
from typing import Union
1111
from warnings import warn
1212

13-
if TYPE_CHECKING:
14-
import cuda.bindings
15-
1613
from cuda.core._device import Device
1714
from cuda.core._linker import Linker, LinkerHandleT, LinkerOptions
1815
from cuda.core._module import ObjectCode
@@ -689,7 +686,7 @@ cdef class Program:
689686
elif code_type == "ptx":
690687
assert_type(code, str)
691688
self._linker = Linker(
692-
ObjectCode._init(code.encode(), code_type), options=self._translate_program_options(options)
689+
ObjectCode._init(code.encode(), code_type), options=Program_translate_options(options)
693690
)
694691
self._backend = self._linker.backend
695692

@@ -711,36 +708,12 @@ cdef class Program:
711708
assert code_type not in supported_code_types, f"{code_type=}"
712709
raise RuntimeError(f"Unsupported {code_type=} ({supported_code_types=})")
713710

714-
def _translate_program_options(self, options: ProgramOptions) -> LinkerOptions:
715-
return LinkerOptions(
716-
name=options.name,
717-
arch=options.arch,
718-
max_register_count=options.max_register_count,
719-
time=options.time,
720-
link_time_optimization=options.link_time_optimization,
721-
debug=options.debug,
722-
lineinfo=options.lineinfo,
723-
ftz=options.ftz,
724-
prec_div=options.prec_div,
725-
prec_sqrt=options.prec_sqrt,
726-
fma=options.fma,
727-
split_compile=options.split_compile,
728-
ptxas_options=options.ptxas_options,
729-
no_cache=options.no_cache,
730-
)
731-
732711
def close(self):
733712
"""Destroy this program."""
734713
if self._linker:
735714
self._linker.close()
736715
self._mnff.close()
737716

738-
@staticmethod
739-
def _can_load_generated_ptx():
740-
driver_ver = handle_return(driver.cuDriverGetVersion())
741-
nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion())
742-
return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver
743-
744717
def compile(self, target_type, name_expressions=(), logs=None):
745718
"""Compile the program with a specific compilation type.
746719
@@ -768,7 +741,7 @@ cdef class Program:
768741
raise ValueError(f'Unsupported target_type="{target_type}" ({supported_target_types=})')
769742

770743
if self._backend == "NVRTC":
771-
if target_type == "ptx" and not self._can_load_generated_ptx():
744+
if target_type == "ptx" and not Program_can_load_generated_ptx():
772745
warn(
773746
"The CUDA driver version is older than the backend version. "
774747
"The generated ptx will not be loadable by the current driver.",
@@ -862,3 +835,35 @@ cdef class Program:
862835

863836
def __repr__(self) -> str:
864837
return f"<Program backend='{self._backend}'>"
838+
839+
840+
# =============================================================================
841+
# Helper functions
842+
# =============================================================================
843+
844+
845+
cdef bint Program_can_load_generated_ptx():
846+
"""Check if the driver can load PTX generated by the current NVRTC version."""
847+
driver_ver = handle_return(driver.cuDriverGetVersion())
848+
nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion())
849+
return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver
850+
851+
852+
cdef object Program_translate_options(object options):
853+
"""Translate ProgramOptions to LinkerOptions for PTX compilation."""
854+
return LinkerOptions(
855+
name=options.name,
856+
arch=options.arch,
857+
max_register_count=options.max_register_count,
858+
time=options.time,
859+
link_time_optimization=options.link_time_optimization,
860+
debug=options.debug,
861+
lineinfo=options.lineinfo,
862+
ftz=options.ftz,
863+
prec_div=options.prec_div,
864+
prec_sqrt=options.prec_sqrt,
865+
fma=options.fma,
866+
split_compile=options.split_compile,
867+
ptxas_options=options.ptxas_options,
868+
no_cache=options.no_cache,
869+
)

0 commit comments

Comments
 (0)