Skip to content

Commit 3c9aee6

Browse files
PytatoPyOpenCLArrayContext: use SVM allocator if available, limit arg size for GPUs (#189)
* LazilyPyOpenCLCompilingFunctionCaller: limit arg size for GPUs * move limit * also check for SVM presence * get_target() * memoize get_target * UNDO BEFORE MERGE: use dev branches * Hackety hack: SVM detection in actx constructor * check whether passed allocator supports SVM * undo loopy branch * implement it for the base class * subclass LoopyPyOpenCLTarget * set actual limit * undo pytato branch * remove unused argument * add type annotations * add logging * Refactor arg size passing to put less logic in the target * flake8 * add a test Co-authored-by: Andreas Kloeckner <inform@tiker.net>
1 parent 1593a67 commit 3c9aee6

File tree

3 files changed

+118
-3
lines changed

3 files changed

+118
-3
lines changed

arraycontext/impl/pytato/__init__.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,21 @@
5454
from arraycontext.container.traversal import (rec_map_array_container,
5555
with_array_context)
5656
from arraycontext.metadata import NameHint
57+
from pytools import memoize_method
5758

5859
if TYPE_CHECKING:
5960
import pytato
6061
import pyopencl as cl
62+
import loopy as lp
6163

6264
if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
6365
import pyopencl as cl # noqa: F811
6466

6567

68+
import logging
69+
logger = logging.getLogger(__name__)
70+
71+
6672
# {{{ tag conversion
6773

6874
def _preprocess_array_tags(tags: ToTagSetConvertible) -> FrozenSet[Tag]:
@@ -203,13 +209,30 @@ def supports_nonscalar_broadcasting(self):
203209
def permits_advanced_indexing(self):
204210
return True
205211

212+
def get_target(self):
213+
return None
214+
206215
# }}}
207216

208217
# }}}
209218

210219

211220
# {{{ PytatoPyOpenCLArrayContext
212221

222+
from pytato.target.loopy import LoopyPyOpenCLTarget
223+
224+
225+
class _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget):
226+
def __init__(self, limit_arg_size_nbytes: int) -> None:
227+
super().__init__()
228+
self.limit_arg_size_nbytes = limit_arg_size_nbytes
229+
230+
@memoize_method
231+
def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]:
232+
from loopy import PyOpenCLTarget
233+
return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes)
234+
235+
213236
class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
214237
"""
215238
A :class:`ArrayContext` that uses :mod:`pytato` data types to represent
@@ -232,7 +255,11 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
232255
"""
233256
def __init__(
234257
self, queue: "cl.CommandQueue", allocator=None, *,
235-
compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None
258+
use_memory_pool: Optional[bool] = None,
259+
compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None,
260+
261+
# do not use: only for testing
262+
_force_svm_arg_limit: Optional[int] = None,
236263
) -> None:
237264
"""
238265
:arg compile_trace_callback: A function of three arguments
@@ -242,16 +269,57 @@ def __init__(
242269
representation. This interface should be considered
243270
unstable.
244271
"""
272+
if allocator is not None and use_memory_pool is not None:
273+
raise TypeError("may not specify both allocator and use_memory_pool")
274+
275+
self.using_svm = None
276+
277+
if allocator is None:
278+
from pyopencl.characterize import has_coarse_grain_buffer_svm
279+
has_svm = has_coarse_grain_buffer_svm(queue.device)
280+
if has_svm:
281+
self.using_svm = True
282+
283+
from pyopencl.tools import SVMAllocator
284+
allocator = SVMAllocator(queue.context, queue=queue)
285+
286+
if use_memory_pool:
287+
from pyopencl.tools import SVMPool
288+
allocator = SVMPool(allocator)
289+
else:
290+
self.using_svm = False
291+
292+
from pyopencl.tools import ImmediateAllocator
293+
allocator = ImmediateAllocator(queue.context)
294+
295+
if use_memory_pool:
296+
from pyopencl.tools import MemoryPool
297+
allocator = MemoryPool(allocator)
298+
else:
299+
# Check whether the passed allocator allocates SVM
300+
try:
301+
from pyopencl import SVMPointer
302+
mem = allocator(4)
303+
if isinstance(mem, SVMPointer):
304+
self.using_svm = True
305+
else:
306+
self.using_svm = False
307+
except ImportError:
308+
self.using_svm = False
309+
245310
import pytato as pt
246311
import pyopencl.array as cla
247312
super().__init__(compile_trace_callback=compile_trace_callback)
248313
self.queue = queue
314+
249315
self.allocator = allocator
250316
self.array_types = (pt.Array, cla.Array)
251317

252318
# unused, but necessary to keep the context alive
253319
self.context = self.queue.context
254320

321+
self._force_svm_arg_limit = _force_svm_arg_limit
322+
255323
@property
256324
def _frozen_array_types(self) -> Tuple[Type, ...]:
257325
import pyopencl.array as cla
@@ -321,6 +389,29 @@ def _to_numpy(ary):
321389
self._rec_map_container(_to_numpy, self.freeze(array)),
322390
actx=None)
323391

392+
@memoize_method
393+
def get_target(self):
394+
import pyopencl as cl
395+
import pyopencl.characterize as cl_char
396+
397+
dev = self.queue.device
398+
399+
if (
400+
self._force_svm_arg_limit is not None
401+
or (
402+
self.using_svm and dev.type & cl.device_type.GPU
403+
and cl_char.has_coarse_grain_buffer_svm(dev))):
404+
405+
limit = dev.max_parameter_size
406+
if self._force_svm_arg_limit is not None:
407+
limit = self._force_svm_arg_limit
408+
409+
logger.info(f"limiting argument buffer size for {dev} to {limit} bytes")
410+
411+
return _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)
412+
else:
413+
return super().get_target()
414+
324415
def freeze(self, array):
325416
if np.isscalar(array):
326417
return array
@@ -415,7 +506,8 @@ def _record_leaf_ary_in_dict(
415506
pt_prg = pt.generate_loopy(transformed_dag,
416507
options=_DEFAULT_LOOPY_OPTIONS,
417508
cl_device=self.queue.device,
418-
function_name=function_name)
509+
function_name=function_name,
510+
target=self.get_target())
419511
pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
420512
self._freeze_prg_cache[normalized_expr] = pt_prg
421513
else:

arraycontext/impl/pytato/compile.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,9 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
420420
options=lp.Options(
421421
return_dict=True,
422422
no_numpy=True),
423-
function_name=_prg_id_to_kernel_name(prg_id))
423+
function_name=_prg_id_to_kernel_name(prg_id),
424+
target=self.actx.get_target(),
425+
)
424426
assert isinstance(pytato_program, BoundPyOpenCLProgram)
425427

426428
self.actx._compile_trace_callback(

test/test_pytato_arraycontext.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,27 @@ def test_tags_preserved_after_freeze(actx_factory):
100100
assert foo.axes[1].tags_of_type(BazTag)
101101

102102

103+
def test_arg_size_limit(actx_factory):
104+
ran_callback = False
105+
106+
def my_ctc(what, stage, ir):
107+
if stage == "final":
108+
assert ir.target.limit_arg_size_nbytes == 42
109+
nonlocal ran_callback
110+
ran_callback = True
111+
112+
def twice(x):
113+
return 2 * x
114+
115+
actx = _PytatoPyOpenCLArrayContextForTests(
116+
actx_factory().queue, compile_trace_callback=my_ctc, _force_svm_arg_limit=42)
117+
118+
f = actx.compile(twice)
119+
f(99)
120+
121+
assert ran_callback
122+
123+
103124
if __name__ == "__main__":
104125
import sys
105126
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)