2323# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2424# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525
26- # distutils: language = c
27- # cython: language_level=3
28-
29- """
30- Patch NumPy's `numpy.random` symbols to use mkl_random implementations.
31-
32- This is attribute-level monkey patching. It can replace legacy APIs like
33- `numpy.random.RandomState` and global distribution functions, but it does not
34- replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully
35- compatible replacements.
36- """
26+ """Define functions for patching NumPy with MKL-based NumPy interface."""
3727
3828from contextlib import ContextDecorator
3929from threading import Lock , local
4333from . import mklrand as _mr
4434
4535
46- cdef tuple _DEFAULT_NAMES = (
36+ _DEFAULT_NAMES = (
4737 # Legacy seeding / state
4838 "seed" ,
4939 "get_state" ,
5040 "set_state" ,
5141 "RandomState" ,
52-
5342 # Common global sampling helpers
5443 "random" ,
5544 "random_sample" ,
5645 "sample" ,
5746 "rand" ,
5847 "randn" ,
5948 "bytes" ,
60-
6149 # Integers
6250 "randint" ,
63-
6451 # Common distributions (only patched if present on both sides)
6552 "standard_normal" ,
6653 "normal" ,
8269 "wald" ,
8370 "weibull" ,
8471 "zipf" ,
85-
8672 # Permutations / choices
8773 "choice" ,
8874 "permutation" ,
@@ -94,9 +80,11 @@ class _GlobalPatch:
9480 def __init__ (self ):
9581 self ._lock = Lock ()
9682 self ._patch_count = 0
83+ self ._restore_dict = {}
84+ self ._patched_functions = tuple (_DEFAULT_NAMES )
9785 self ._numpy_module = None
9886 self ._requested_names = None
99- self ._originals = {}
87+ self ._active_names = ()
10088 self ._patched = ()
10189 self ._tls = local ()
10290
@@ -111,22 +99,44 @@ def _validate_module(self, numpy_module):
11199 "Expected a numpy-like module with a `.random` attribute."
112100 )
113101
114- def _apply_patch (self , numpy_module , names , strict ):
102+ def _register_func (self , name , func ):
103+ if name not in self ._patched_functions :
104+ raise ValueError (f"{ name } not an mkl_random function." )
105+ np_random = self ._numpy_module .random
106+ if name not in self ._restore_dict :
107+ self ._restore_dict [name ] = getattr (np_random , name )
108+ setattr (np_random , name , func )
109+
110+ def _restore_func (self , name , verbose = False ):
111+ if name not in self ._patched_functions :
112+ raise ValueError (f"{ name } not an mkl_random function." )
113+ try :
114+ val = self ._restore_dict [name ]
115+ except KeyError :
116+ if verbose :
117+ print (f"failed to restore { name } " )
118+ return
119+ else :
120+ if verbose :
121+ print (f"found and restoring { name } ..." )
122+ np_random = self ._numpy_module .random
123+ setattr (np_random , name , val )
124+
125+ def _initialize_patch (self , numpy_module , names , strict ):
126+ self ._validate_module (numpy_module )
115127 np_random = numpy_module .random
116- originals = {}
117- patched = []
118128 missing = []
129+ patchable = []
119130 for name in names :
131+ if name not in self ._patched_functions :
132+ missing .append (name )
133+ continue
120134 if not hasattr (np_random , name ) or not hasattr (_mr , name ):
121135 missing .append (name )
122136 continue
123- originals[name] = getattr (np_random, name)
124- setattr (np_random, name, getattr (_mr, name))
125- patched.append(name)
137+ patchable .append (name )
126138
127139 if strict and missing :
128- for name, value in originals.items():
129- setattr (np_random, name, value)
130140 raise AttributeError (
131141 "Could not patch these names (missing on numpy.random or "
132142 "mkl_random.mklrand): "
@@ -135,8 +145,8 @@ def _apply_patch(self, numpy_module, names, strict):
135145
136146 self ._numpy_module = numpy_module
137147 self ._requested_names = names
138- self ._originals = originals
139- self ._patched = tuple (patched )
148+ self ._active_names = tuple ( patchable )
149+ self ._patched = tuple (patchable )
140150
141151 def do_patch (
142152 self ,
@@ -148,13 +158,23 @@ def do_patch(
148158 if numpy_module is None :
149159 numpy_module = _np
150160 names = self ._normalize_names (names )
151- self ._validate_module(numpy_module)
152161 strict = bool (strict )
153162
154163 with self ._lock :
155164 local_count = getattr (self ._tls , "local_count" , 0 )
156165 if self ._patch_count == 0 :
157- self ._apply_patch(numpy_module, names, strict)
166+ self ._initialize_patch (numpy_module , names , strict )
167+ if verbose :
168+ print (
169+ "Now patching NumPy random submodule with mkl_random "
170+ "NumPy interface."
171+ )
172+ print (
173+ "Please direct bug reports to "
174+ "https://github.com/IntelPython/mkl_random"
175+ )
176+ for name in self ._active_names :
177+ self ._register_func (name , getattr (_mr , name ))
158178 else :
159179 if self ._numpy_module is not numpy_module :
160180 raise RuntimeError (
@@ -175,20 +195,22 @@ def do_restore(self, verbose=False):
175195 if local_count <= 0 :
176196 if verbose :
177197 print (
178- " Warning: restore called more times than monkey_patch "
179- " in this thread."
198+ "Warning: restore_numpy_random called more times than "
199+ "patch_numpy_random in this thread."
180200 )
181201 return
182202
183203 self ._tls .local_count = local_count - 1
184204 self ._patch_count -= 1
185205 if self ._patch_count == 0 :
186- np_random = self ._numpy_module.random
187- for name, value in self ._originals.items():
188- setattr (np_random, name, value)
206+ if verbose :
207+ print ("Now restoring original NumPy random submodule." )
208+ for name in tuple (self ._restore_dict ):
209+ self ._restore_func (name , verbose = verbose )
210+ self ._restore_dict .clear ()
189211 self ._numpy_module = None
190212 self ._requested_names = None
191- self ._originals = {}
213+ self ._active_names = ()
192214 self ._patched = ()
193215
194216 def is_patched (self ):
@@ -203,18 +225,33 @@ def patched_names(self):
203225_patch = _GlobalPatch ()
204226
205227
206- def monkey_patch (numpy_module = None , names = None , strict = False , verbose = False ):
228+ def patch_numpy_random (
229+ numpy_module = None ,
230+ names = None ,
231+ strict = False ,
232+ verbose = False ,
233+ ):
207234 """
208- Enables using mkl_random in the given NumPy module by patching
209- `numpy.random`.
235+ Patch NumPy's random submodule with mkl_random's NumPy interface.
236+
237+ Parameters
238+ ----------
239+ numpy_module : module, optional
240+ NumPy-like module to patch. Defaults to imported NumPy.
241+ names : iterable[str], optional
242+ Attributes under `numpy_module.random` to patch.
243+ strict : bool, optional
244+ Raise if any requested symbol cannot be patched.
245+ verbose : bool, optional
246+ Print messages when starting the patching process.
210247
211248 Examples
212249 --------
213250 >>> import numpy as np
214251 >>> import mkl_random
215252 >>> mkl_random.is_patched()
216253 False
217- >>> mkl_random.monkey_patch (np)
254+ >>> mkl_random.patch_numpy_random (np)
218255 >>> mkl_random.is_patched()
219256 True
220257 >>> mkl_random.restore()
@@ -229,11 +266,31 @@ def monkey_patch(numpy_module=None, names=None, strict=False, verbose=False):
229266 )
230267
231268
232- def use_in_numpy ( numpy_module = None , names = None , strict = False , verbose = False ):
269+ def restore_numpy_random ( verbose = False ):
233270 """
234- Backward-compatible alias for monkey_patch().
271+ Restore NumPy's random submodule to its original implementations.
272+
273+ Parameters
274+ ----------
275+ verbose : bool, optional
276+ Print message when starting restoration process.
235277 """
236- monkey_patch(
278+ _patch .do_restore (verbose = bool (verbose ))
279+
280+
281+ def monkey_patch (numpy_module = None , names = None , strict = False , verbose = False ):
282+ """Backward-compatible alias for patch_numpy_random()."""
283+ patch_numpy_random (
284+ numpy_module = numpy_module ,
285+ names = names ,
286+ strict = strict ,
287+ verbose = verbose ,
288+ )
289+
290+
291+ def use_in_numpy (numpy_module = None , names = None , strict = False , verbose = False ):
292+ """Backward-compatible alias for patch_numpy_random()."""
293+ patch_numpy_random (
237294 numpy_module = numpy_module ,
238295 names = names ,
239296 strict = strict ,
@@ -242,11 +299,8 @@ def use_in_numpy(numpy_module=None, names=None, strict=False, verbose=False):
242299
243300
244301def restore (verbose = False ):
245- """
246- Disables using mkl_random in NumPy by restoring the original
247- `numpy.random` symbols.
248- """
249- _patch.do_restore(verbose = bool (verbose))
302+ """Backward-compatible alias for restore_numpy_random()."""
303+ restore_numpy_random (verbose = verbose )
250304
251305
252306def is_patched ():
@@ -265,7 +319,8 @@ def patched_names():
265319
266320class mkl_random (ContextDecorator ):
267321 """
268- Context manager and decorator to temporarily patch NumPy's `numpy.random`.
322+ Context manager and decorator to temporarily patch NumPy random submodule
323+ with MKL-based implementations.
269324
270325 Examples
271326 --------
@@ -274,19 +329,20 @@ class mkl_random(ContextDecorator):
274329 >>> with mkl_random.mkl_random(np):
275330 ... x = np.random.normal(size=10)
276331 """
332+
277333 def __init__ (self , numpy_module = None , names = None , strict = False ):
278334 self ._numpy_module = numpy_module
279335 self ._names = names
280336 self ._strict = strict
281337
282338 def __enter__ (self ):
283- monkey_patch (
339+ patch_numpy_random (
284340 numpy_module = self ._numpy_module ,
285341 names = self ._names ,
286342 strict = self ._strict ,
287343 )
288344 return self
289345
290346 def __exit__ (self , * exc ):
291- restore ()
347+ restore_numpy_random ()
292348 return False
0 commit comments