Skip to content

Commit 17e4489

Browse files
committed
fix: review
1 parent 4de0d12 commit 17e4489

File tree

4 files changed

+122
-70
lines changed

4 files changed

+122
-70
lines changed

mkl_random/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,10 @@
9999
is_patched,
100100
mkl_random,
101101
monkey_patch,
102+
patch_numpy_random,
102103
patched_names,
103104
restore,
105+
restore_numpy_random,
104106
use_in_numpy,
105107
)
106108

Lines changed: 106 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,7 @@
2323
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2424
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525

26-
# distutils: language = c
27-
# cython: language_level=3
28-
29-
"""
30-
Patch NumPy's `numpy.random` symbols to use mkl_random implementations.
31-
32-
This is attribute-level monkey patching. It can replace legacy APIs like
33-
`numpy.random.RandomState` and global distribution functions, but it does not
34-
replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully
35-
compatible replacements.
36-
"""
26+
"""Define functions for patching NumPy with MKL-based NumPy interface."""
3727

3828
from contextlib import ContextDecorator
3929
from threading import Lock, local
@@ -43,24 +33,21 @@
4333
from . import mklrand as _mr
4434

4535

46-
cdef tuple _DEFAULT_NAMES = (
36+
_DEFAULT_NAMES = (
4737
# Legacy seeding / state
4838
"seed",
4939
"get_state",
5040
"set_state",
5141
"RandomState",
52-
5342
# Common global sampling helpers
5443
"random",
5544
"random_sample",
5645
"sample",
5746
"rand",
5847
"randn",
5948
"bytes",
60-
6149
# Integers
6250
"randint",
63-
6451
# Common distributions (only patched if present on both sides)
6552
"standard_normal",
6653
"normal",
@@ -82,7 +69,6 @@
8269
"wald",
8370
"weibull",
8471
"zipf",
85-
8672
# Permutations / choices
8773
"choice",
8874
"permutation",
@@ -94,9 +80,11 @@ class _GlobalPatch:
9480
def __init__(self):
9581
self._lock = Lock()
9682
self._patch_count = 0
83+
self._restore_dict = {}
84+
self._patched_functions = tuple(_DEFAULT_NAMES)
9785
self._numpy_module = None
9886
self._requested_names = None
99-
self._originals = {}
87+
self._active_names = ()
10088
self._patched = ()
10189
self._tls = local()
10290

@@ -111,22 +99,44 @@ def _validate_module(self, numpy_module):
11199
"Expected a numpy-like module with a `.random` attribute."
112100
)
113101

114-
def _apply_patch(self, numpy_module, names, strict):
102+
def _register_func(self, name, func):
103+
if name not in self._patched_functions:
104+
raise ValueError(f"{name} not an mkl_random function.")
105+
np_random = self._numpy_module.random
106+
if name not in self._restore_dict:
107+
self._restore_dict[name] = getattr(np_random, name)
108+
setattr(np_random, name, func)
109+
110+
def _restore_func(self, name, verbose=False):
111+
if name not in self._patched_functions:
112+
raise ValueError(f"{name} not an mkl_random function.")
113+
try:
114+
val = self._restore_dict[name]
115+
except KeyError:
116+
if verbose:
117+
print(f"failed to restore {name}")
118+
return
119+
else:
120+
if verbose:
121+
print(f"found and restoring {name}...")
122+
np_random = self._numpy_module.random
123+
setattr(np_random, name, val)
124+
125+
def _initialize_patch(self, numpy_module, names, strict):
126+
self._validate_module(numpy_module)
115127
np_random = numpy_module.random
116-
originals = {}
117-
patched = []
118128
missing = []
129+
patchable = []
119130
for name in names:
131+
if name not in self._patched_functions:
132+
missing.append(name)
133+
continue
120134
if not hasattr(np_random, name) or not hasattr(_mr, name):
121135
missing.append(name)
122136
continue
123-
originals[name] = getattr(np_random, name)
124-
setattr(np_random, name, getattr(_mr, name))
125-
patched.append(name)
137+
patchable.append(name)
126138

127139
if strict and missing:
128-
for name, value in originals.items():
129-
setattr(np_random, name, value)
130140
raise AttributeError(
131141
"Could not patch these names (missing on numpy.random or "
132142
"mkl_random.mklrand): "
@@ -135,8 +145,8 @@ def _apply_patch(self, numpy_module, names, strict):
135145

136146
self._numpy_module = numpy_module
137147
self._requested_names = names
138-
self._originals = originals
139-
self._patched = tuple(patched)
148+
self._active_names = tuple(patchable)
149+
self._patched = tuple(patchable)
140150

141151
def do_patch(
142152
self,
@@ -148,13 +158,23 @@ def do_patch(
148158
if numpy_module is None:
149159
numpy_module = _np
150160
names = self._normalize_names(names)
151-
self._validate_module(numpy_module)
152161
strict = bool(strict)
153162

154163
with self._lock:
155164
local_count = getattr(self._tls, "local_count", 0)
156165
if self._patch_count == 0:
157-
self._apply_patch(numpy_module, names, strict)
166+
self._initialize_patch(numpy_module, names, strict)
167+
if verbose:
168+
print(
169+
"Now patching NumPy random submodule with mkl_random "
170+
"NumPy interface."
171+
)
172+
print(
173+
"Please direct bug reports to "
174+
"https://github.com/IntelPython/mkl_random"
175+
)
176+
for name in self._active_names:
177+
self._register_func(name, getattr(_mr, name))
158178
else:
159179
if self._numpy_module is not numpy_module:
160180
raise RuntimeError(
@@ -175,20 +195,22 @@ def do_restore(self, verbose=False):
175195
if local_count <= 0:
176196
if verbose:
177197
print(
178-
"Warning: restore called more times than monkey_patch "
179-
"in this thread."
198+
"Warning: restore_numpy_random called more times than "
199+
"patch_numpy_random in this thread."
180200
)
181201
return
182202

183203
self._tls.local_count = local_count - 1
184204
self._patch_count -= 1
185205
if self._patch_count == 0:
186-
np_random = self._numpy_module.random
187-
for name, value in self._originals.items():
188-
setattr(np_random, name, value)
206+
if verbose:
207+
print("Now restoring original NumPy random submodule.")
208+
for name in tuple(self._restore_dict):
209+
self._restore_func(name, verbose=verbose)
210+
self._restore_dict.clear()
189211
self._numpy_module = None
190212
self._requested_names = None
191-
self._originals = {}
213+
self._active_names = ()
192214
self._patched = ()
193215

194216
def is_patched(self):
@@ -203,18 +225,33 @@ def patched_names(self):
203225
_patch = _GlobalPatch()
204226

205227

206-
def monkey_patch(numpy_module=None, names=None, strict=False, verbose=False):
228+
def patch_numpy_random(
229+
numpy_module=None,
230+
names=None,
231+
strict=False,
232+
verbose=False,
233+
):
207234
"""
208-
Enables using mkl_random in the given NumPy module by patching
209-
`numpy.random`.
235+
Patch NumPy's random submodule with mkl_random's NumPy interface.
236+
237+
Parameters
238+
----------
239+
numpy_module : module, optional
240+
NumPy-like module to patch. Defaults to imported NumPy.
241+
names : iterable[str], optional
242+
Attributes under `numpy_module.random` to patch.
243+
strict : bool, optional
244+
Raise if any requested symbol cannot be patched.
245+
verbose : bool, optional
246+
Print messages when starting the patching process.
210247
211248
Examples
212249
--------
213250
>>> import numpy as np
214251
>>> import mkl_random
215252
>>> mkl_random.is_patched()
216253
False
217-
>>> mkl_random.monkey_patch(np)
254+
>>> mkl_random.patch_numpy_random(np)
218255
>>> mkl_random.is_patched()
219256
True
220257
>>> mkl_random.restore()
@@ -229,11 +266,31 @@ def monkey_patch(numpy_module=None, names=None, strict=False, verbose=False):
229266
)
230267

231268

232-
def use_in_numpy(numpy_module=None, names=None, strict=False, verbose=False):
269+
def restore_numpy_random(verbose=False):
233270
"""
234-
Backward-compatible alias for monkey_patch().
271+
Restore NumPy's random submodule to its original implementations.
272+
273+
Parameters
274+
----------
275+
verbose : bool, optional
276+
Print message when starting restoration process.
235277
"""
236-
monkey_patch(
278+
_patch.do_restore(verbose=bool(verbose))
279+
280+
281+
def monkey_patch(numpy_module=None, names=None, strict=False, verbose=False):
282+
"""Backward-compatible alias for patch_numpy_random()."""
283+
patch_numpy_random(
284+
numpy_module=numpy_module,
285+
names=names,
286+
strict=strict,
287+
verbose=verbose,
288+
)
289+
290+
291+
def use_in_numpy(numpy_module=None, names=None, strict=False, verbose=False):
292+
"""Backward-compatible alias for patch_numpy_random()."""
293+
patch_numpy_random(
237294
numpy_module=numpy_module,
238295
names=names,
239296
strict=strict,
@@ -242,11 +299,8 @@ def use_in_numpy(numpy_module=None, names=None, strict=False, verbose=False):
242299

243300

244301
def restore(verbose=False):
245-
"""
246-
Disables using mkl_random in NumPy by restoring the original
247-
`numpy.random` symbols.
248-
"""
249-
_patch.do_restore(verbose=bool(verbose))
302+
"""Backward-compatible alias for restore_numpy_random()."""
303+
restore_numpy_random(verbose=verbose)
250304

251305

252306
def is_patched():
@@ -265,7 +319,8 @@ def patched_names():
265319

266320
class mkl_random(ContextDecorator):
267321
"""
268-
Context manager and decorator to temporarily patch NumPy's `numpy.random`.
322+
Context manager and decorator to temporarily patch NumPy random submodule
323+
with MKL-based implementations.
269324
270325
Examples
271326
--------
@@ -274,19 +329,20 @@ class mkl_random(ContextDecorator):
274329
>>> with mkl_random.mkl_random(np):
275330
... x = np.random.normal(size=10)
276331
"""
332+
277333
def __init__(self, numpy_module=None, names=None, strict=False):
278334
self._numpy_module = numpy_module
279335
self._names = names
280336
self._strict = strict
281337

282338
def __enter__(self):
283-
monkey_patch(
339+
patch_numpy_random(
284340
numpy_module=self._numpy_module,
285341
names=self._names,
286342
strict=self._strict,
287343
)
288344
return self
289345

290346
def __exit__(self, *exc):
291-
restore()
347+
restore_numpy_random()
292348
return False

mkl_random/tests/test_patch.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,20 @@ def test_patch_redundant_patching():
133133
orig_normal = np.random.normal
134134
assert not mkl_random.is_patched()
135135

136-
mkl_random.monkey_patch(np)
137-
mkl_random.monkey_patch(np)
138-
139-
assert mkl_random.is_patched()
140-
assert np.random.normal is mkl_random.normal
141-
142-
mkl_random.restore()
143-
assert mkl_random.is_patched()
144-
assert np.random.normal is mkl_random.normal
145-
146-
mkl_random.restore()
147-
assert not mkl_random.is_patched()
148-
assert np.random.normal is orig_normal
136+
try:
137+
mkl_random.monkey_patch(np)
138+
mkl_random.monkey_patch(np)
139+
assert mkl_random.is_patched()
140+
assert np.random.normal is mkl_random.mklrand.normal
141+
mkl_random.restore()
142+
assert mkl_random.is_patched()
143+
assert np.random.normal is mkl_random.mklrand.normal
144+
mkl_random.restore()
145+
assert not mkl_random.is_patched()
146+
assert np.random.normal is orig_normal
147+
finally:
148+
while mkl_random.is_patched():
149+
mkl_random.restore()
149150

150151

151152
def test_patch_reentrant():

setup.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,6 @@ def extensions():
9393
define_macros=defs + [("NDEBUG", None)],
9494
language="c++",
9595
),
96-
Extension(
97-
"mkl_random._patch",
98-
sources=[join("mkl_random", "src", "_patch.pyx")],
99-
include_dirs=[np.get_include()],
100-
define_macros=defs + [("NDEBUG", None)],
101-
language="c",
102-
),
10396
]
10497

10598
return exts

0 commit comments

Comments
 (0)