5454from arraycontext .container .traversal import (rec_map_array_container ,
5555 with_array_context )
5656from arraycontext .metadata import NameHint
57+ from pytools import memoize_method
5758
5859if TYPE_CHECKING :
5960 import pytato
6061 import pyopencl as cl
62+ import loopy as lp
6163
6264if getattr (sys , "_BUILDING_SPHINX_DOCS" , False ):
6365 import pyopencl as cl # noqa: F811
6466
6567
68+ import logging
69+ logger = logging .getLogger (__name__ )
70+
71+
6672# {{{ tag conversion
6773
6874def _preprocess_array_tags (tags : ToTagSetConvertible ) -> FrozenSet [Tag ]:
@@ -203,13 +209,30 @@ def supports_nonscalar_broadcasting(self):
203209 def permits_advanced_indexing (self ):
204210 return True
205211
def get_target(self):
    """Return the target object for code generation.

    This implementation always returns *None*.
    """
    # NOTE(review): presumably *None* means "no explicit target, let the
    # caller fall back to its default" -- confirm against the call sites.
    return None
214+
206215 # }}}
207216
208217# }}}
209218
210219
211220# {{{ PytatoPyOpenCLArrayContext
212221
222+ from pytato .target .loopy import LoopyPyOpenCLTarget
223+
224+
class _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget):
    """
    A :class:`pytato.target.loopy.LoopyPyOpenCLTarget` whose loopy target
    is constructed with an argument-size limit.

    .. attribute:: limit_arg_size_nbytes

        The value handed to :class:`loopy.PyOpenCLTarget` as its
        *limit_arg_size_nbytes* keyword argument.
    """

    def __init__(self, limit_arg_size_nbytes: int) -> None:
        """
        :arg limit_arg_size_nbytes: the limit (in bytes, going by the name)
            stored on the instance for later use by
            :meth:`get_loopy_target`.
        """
        super().__init__()
        self.limit_arg_size_nbytes = limit_arg_size_nbytes

    @memoize_method
    def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]:
        """
        Return a :class:`loopy.PyOpenCLTarget` created with the stored
        :attr:`limit_arg_size_nbytes`. The result is cached per instance
        by :func:`pytools.memoize_method`.
        """
        # Imported lazily, exactly as the module-level code only imports
        # loopy under TYPE_CHECKING.
        import loopy
        return loopy.PyOpenCLTarget(
            limit_arg_size_nbytes=self.limit_arg_size_nbytes)
234+
235+
213236class PytatoPyOpenCLArrayContext (_BasePytatoArrayContext ):
214237 """
215238 A :class:`ArrayContext` that uses :mod:`pytato` data types to represent
@@ -232,7 +255,11 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
232255 """
233256 def __init__ (
234257 self , queue : "cl.CommandQueue" , allocator = None , * ,
235- compile_trace_callback : Optional [Callable [[Any , str , Any ], None ]] = None
258+ use_memory_pool : Optional [bool ] = None ,
259+ compile_trace_callback : Optional [Callable [[Any , str , Any ], None ]] = None ,
260+
261+ # do not use: only for testing
262+ _force_svm_arg_limit : Optional [int ] = None ,
236263 ) -> None :
237264 """
238265 :arg compile_trace_callback: A function of three arguments
@@ -242,16 +269,57 @@ def __init__(
242269 representation. This interface should be considered
243270 unstable.
244271 """
272+ if allocator is not None and use_memory_pool is not None :
273+ raise TypeError ("may not specify both allocator and use_memory_pool" )
274+
275+ self .using_svm = None
276+
277+ if allocator is None :
278+ from pyopencl .characterize import has_coarse_grain_buffer_svm
279+ has_svm = has_coarse_grain_buffer_svm (queue .device )
280+ if has_svm :
281+ self .using_svm = True
282+
283+ from pyopencl .tools import SVMAllocator
284+ allocator = SVMAllocator (queue .context , queue = queue )
285+
286+ if use_memory_pool :
287+ from pyopencl .tools import SVMPool
288+ allocator = SVMPool (allocator )
289+ else :
290+ self .using_svm = False
291+
292+ from pyopencl .tools import ImmediateAllocator
293+ allocator = ImmediateAllocator (queue .context )
294+
295+ if use_memory_pool :
296+ from pyopencl .tools import MemoryPool
297+ allocator = MemoryPool (allocator )
298+ else :
299+ # Check whether the passed allocator allocates SVM
300+ try :
301+ from pyopencl import SVMPointer
302+ mem = allocator (4 )
303+ if isinstance (mem , SVMPointer ):
304+ self .using_svm = True
305+ else :
306+ self .using_svm = False
307+ except ImportError :
308+ self .using_svm = False
309+
245310 import pytato as pt
246311 import pyopencl .array as cla
247312 super ().__init__ (compile_trace_callback = compile_trace_callback )
248313 self .queue = queue
314+
249315 self .allocator = allocator
250316 self .array_types = (pt .Array , cla .Array )
251317
252318 # unused, but necessary to keep the context alive
253319 self .context = self .queue .context
254320
321+ self ._force_svm_arg_limit = _force_svm_arg_limit
322+
255323 @property
256324 def _frozen_array_types (self ) -> Tuple [Type , ...]:
257325 import pyopencl .array as cla
@@ -321,6 +389,29 @@ def _to_numpy(ary):
321389 self ._rec_map_container (_to_numpy , self .freeze (array )),
322390 actx = None )
323391
@memoize_method
def get_target(self):
    """
    Return the target to pass to :func:`pytato.generate_loopy`.

    An :class:`_ArgSizeLimitingPytatoLoopyPyOpenCLTarget` is returned when
    either

    - ``self._force_svm_arg_limit`` is not *None* (testing hook), or
    - ``self.using_svm`` is true, the queue's device is a GPU, and the
      device reports coarse-grain buffer SVM support.

    The limit is the device's ``max_parameter_size`` unless
    ``self._force_svm_arg_limit`` overrides it. In every other case the
    superclass's target is returned. The result is cached per instance
    by :func:`pytools.memoize_method`.
    """
    import pyopencl as cl
    import pyopencl.characterize as cl_char

    dev = self.queue.device
    forced_limit = self._force_svm_arg_limit

    # Same short-circuit order as a single combined condition: the forced
    # limit is checked first, device properties only if it is unset.
    wants_arg_limit = forced_limit is not None or (
        self.using_svm
        and dev.type & cl.device_type.GPU
        and cl_char.has_coarse_grain_buffer_svm(dev))

    if not wants_arg_limit:
        return super().get_target()

    limit = dev.max_parameter_size
    if forced_limit is not None:
        limit = forced_limit

    logger.info(f"limiting argument buffer size for {dev} to {limit} bytes")

    return _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)
414+
324415 def freeze (self , array ):
325416 if np .isscalar (array ):
326417 return array
@@ -415,7 +506,8 @@ def _record_leaf_ary_in_dict(
415506 pt_prg = pt .generate_loopy (transformed_dag ,
416507 options = _DEFAULT_LOOPY_OPTIONS ,
417508 cl_device = self .queue .device ,
418- function_name = function_name )
509+ function_name = function_name ,
510+ target = self .get_target ())
419511 pt_prg = pt_prg .with_transformed_program (self .transform_loopy_program )
420512 self ._freeze_prg_cache [normalized_expr ] = pt_prg
421513 else :
0 commit comments