diff --git a/CHANGELOG.md b/CHANGELOG.md index 2973996..09e4f04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [dev] (MM/DD/YYYY) ### Added +* Added `mkl_random` patching for NumPy, with `mkl_random` context manager, `is_patched` query, and `patch_numpy_random` and `restore_numpy_random` calls to replace `numpy.random` calls with calls from `mkl_random.interfaces.numpy_random` [gh-90](https://github.com/IntelPython/mkl_random/pull/90) + * Added `mkl_random.interfaces` with `mkl_random.interfaces.numpy_random` interface, which aliases `mkl_random` functionality to more strictly adhere to NumPy's API (i.e., drops arguments and functions which are not part of standard NumPy) [gh-92](https://github.com/IntelPython/mkl_random/pull/92) ### Removed diff --git a/mkl_random/__init__.py b/mkl_random/__init__.py index 3b5f272..774811b 100644 --- a/mkl_random/__init__.py +++ b/mkl_random/__init__.py @@ -95,6 +95,13 @@ from mkl_random import interfaces +from ._patch_numpy import ( + is_patched, + mkl_random, + patch_numpy_random, + restore_numpy_random, +) + __all__ = [ "MKLRandomState", "RandomState", @@ -147,6 +154,10 @@ "shuffle", "permutation", "interfaces", + "mkl_random", + "patch_numpy_random", + "restore_numpy_random", + "is_patched", ] del _init_helper diff --git a/mkl_random/_patch_numpy.py b/mkl_random/_patch_numpy.py new file mode 100644 index 0000000..7d518e7 --- /dev/null +++ b/mkl_random/_patch_numpy.py @@ -0,0 +1,184 @@ +# Copyright (c) 2019, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Define functions for patching NumPy with MKL-based NumPy interface.""" + +from contextlib import ContextDecorator +from threading import Lock, local + +import numpy as np + +import mkl_random.interfaces.numpy_random as _nrand + + +class _GlobalPatch: + def __init__(self): + self._lock = Lock() + self._patch_count = 0 + self._restore_dict = {} + # make _patched_functions a tuple (immutable) + self._patched_functions = tuple(_nrand.__all__) + self._tls = local() + + def _register_func(self, name, func): + if name not in self._patched_functions: + raise ValueError(f"{name} not an mkl_random function.") + if name not in self._restore_dict: + self._restore_dict[name] = getattr(np.random, name) + setattr(np.random, name, func) + + def _restore_func(self, name, verbose=False): + if name not in self._patched_functions: + raise ValueError(f"{name} not an mkl_random function.") + try: + val = self._restore_dict[name] + except KeyError: + if verbose: + print(f"failed to restore {name}") + return + else: + if verbose: + print(f"found and restoring {name}...") + setattr(np.random, name, val) + + def do_patch(self, verbose=False): + with self._lock: + local_count = getattr(self._tls, "local_count", 0) + if self._patch_count == 0: + if verbose: + print( + "Now patching NumPy random submodule with mkl_random " + "NumPy interface." + ) + print( + "Please direct bug reports to " + "https://github.com/IntelPython/mkl_random" + ) + for f in self._patched_functions: + self._register_func(f, getattr(_nrand, f)) + self._patch_count += 1 + self._tls.local_count = local_count + 1 + + def do_restore(self, verbose=False): + with self._lock: + local_count = getattr(self._tls, "local_count", 0) + if local_count <= 0: + if verbose: + print( + "Warning: restore_numpy_random called more times than " + "patch_numpy_random in this thread." + ) + return + self._tls.local_count -= 1 + self._patch_count -= 1 + if self._patch_count == 0: + if verbose: + print("Now restoring original NumPy random submodule.") + for name in tuple(self._restore_dict): + self._restore_func(name, verbose=verbose) + self._restore_dict.clear() + + def is_patched(self): + with self._lock: + return self._patch_count > 0 + + +_patch = _GlobalPatch() + + +def patch_numpy_random(verbose=False): + """ + Patch NumPy's random submodule with mkl_random's numpy_interface. + + Parameters + ---------- + verbose : bool, optional + print message when starting the patching process. + + Notes + ----- + This function uses reference-counted semantics. Each call increments a + global patch counter. Restoration requires a matching number of calls + between `patch_numpy_random` and `restore_numpy_random`. + + In multi-threaded programs, prefer the `mkl_random` context manager. + + """ + _patch.do_patch(verbose=verbose) + + +def restore_numpy_random(verbose=False): + """ + Restore NumPy's random submodule to its original implementations. + + Parameters + ---------- + verbose : bool, optional + print message when starting restoration process. + + Notes + ----- + This function uses reference-counted semantics. Each call decrements a + global patch counter. Restoration requires a matching number of calls + between `patch_numpy_random` and `restore_numpy_random`. + + In multi-threaded programs, prefer the `mkl_random` context manager. + + """ + _patch.do_restore(verbose=verbose) + + +def is_patched(): + """Return True if NumPy's random sm is currently patched by mkl_random.""" + return _patch.is_patched() + + +class mkl_random(ContextDecorator): + """ + Context manager and decorator to temporarily patch NumPy random submodule + with MKL-based implementations. + + Examples + -------- + >>> import mkl_random + >>> mkl_random.is_patched() + # False + + >>> with mkl_random.mkl_random(): # Enable mkl_random in NumPy + >>> print(mkl_random.is_patched()) + # True + + >>> mkl_random.is_patched() + # False + + """ + + def __enter__(self): + patch_numpy_random() + return self + + def __exit__(self, *exc): + restore_numpy_random() + return False diff --git a/mkl_random/tests/test_patch.py b/mkl_random/tests/test_patch.py new file mode 100644 index 0000000..beffc1c --- /dev/null +++ b/mkl_random/tests/test_patch.py @@ -0,0 +1,105 @@ +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np + +import mkl_random +import mkl_random.interfaces.numpy_random as _nrand + + +def test_is_patched(): + """Test that is_patched() returns correct status.""" + assert not mkl_random.is_patched() + try: + mkl_random.patch_numpy_random() + assert mkl_random.is_patched() + mkl_random.restore_numpy_random() + assert not mkl_random.is_patched() + finally: + while mkl_random.is_patched(): + mkl_random.restore_numpy_random() + + +def test_patch(): + old_module = np.random.normal.__module__ + assert not mkl_random.is_patched() + + try: + mkl_random.patch_numpy_random() # Enable mkl_random in NumPy + assert mkl_random.is_patched() + assert np.random.normal.__module__ == _nrand.normal.__module__ + + mkl_random.restore_numpy_random() # Disable mkl_random in NumPy + assert not mkl_random.is_patched() + assert np.random.normal.__module__ == old_module + finally: + while mkl_random.is_patched(): + mkl_random.restore_numpy_random() + + +def test_patch_redundant_patching(): + old_module = np.random.normal.__module__ + assert not mkl_random.is_patched() + + try: + mkl_random.patch_numpy_random() + mkl_random.patch_numpy_random() + + assert mkl_random.is_patched() + assert np.random.normal.__module__ == _nrand.normal.__module__ + + mkl_random.restore_numpy_random() + assert mkl_random.is_patched() + assert np.random.normal.__module__ == _nrand.normal.__module__ + + mkl_random.restore_numpy_random() + assert not mkl_random.is_patched() + assert np.random.normal.__module__ == old_module + finally: + while mkl_random.is_patched(): + mkl_random.restore_numpy_random() + + +def test_patch_reentrant(): + old_module = np.random.normal.__module__ + assert not mkl_random.is_patched() + + try: + with mkl_random.mkl_random(): + assert mkl_random.is_patched() + assert np.random.normal.__module__ == _nrand.normal.__module__ + + with mkl_random.mkl_random(): + assert mkl_random.is_patched() + assert np.random.normal.__module__ == _nrand.normal.__module__ + + assert mkl_random.is_patched() + assert np.random.normal.__module__ == _nrand.normal.__module__ + + assert not mkl_random.is_patched() + assert np.random.normal.__module__ == old_module + finally: + while mkl_random.is_patched(): + mkl_random.restore_numpy_random() diff --git a/pyproject.toml b/pyproject.toml index 3352468..6c136fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,12 @@ line_length = 80 multi_line_output = 3 use_parentheses = true +[tool.pylint.main] +extension-pkg-allow-list = ["numpy", "mkl_random.mklrand"] + +[tool.pylint.typecheck] +generated-members = ["RandomState", "min", "max"] + [tool.setuptools] include-package-data = true diff --git a/setup.py b/setup.py index bd64aaa..0b90bf9 100644 --- a/setup.py +++ b/setup.py @@ -92,7 +92,7 @@ def extensions(): extra_compile_args=eca, define_macros=defs + [("NDEBUG", None)], language="c++", - ) + ), ] return exts