Skip to content

Commit b1a3301

Browse files
committed
fix: do not use memcpy impl, match sycl docs better
1 parent ae41152 commit b1a3301

File tree

6 files changed

+257
-5
lines changed

6 files changed

+257
-5
lines changed

dpctl/_backend.pxd

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,18 @@ cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
511511
size_t Count,
512512
const DPCTLSyclEventRef *depEvents,
513513
size_t depEventsCount)
514+
cdef DPCTLSyclEventRef DPCTLQueue_CopyData(
515+
const DPCTLSyclQueueRef Q,
516+
void *Dest,
517+
const void *Src,
518+
size_t Count)
519+
cdef DPCTLSyclEventRef DPCTLQueue_CopyDataWithEvents(
520+
const DPCTLSyclQueueRef Q,
521+
void *Dest,
522+
const void *Src,
523+
size_t Count,
524+
const DPCTLSyclEventRef *depEvents,
525+
size_t depEventsCount)
514526
cdef DPCTLSyclEventRef DPCTLQueue_Memset(
515527
const DPCTLSyclQueueRef Q,
516528
void *Dest,

dpctl/_sycl_queue.pyx

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ from ._backend cimport ( # noqa: E211
3535
DPCTLFilterSelector_Create,
3636
DPCTLQueue_AreEq,
3737
DPCTLQueue_Copy,
38+
DPCTLQueue_CopyData,
39+
DPCTLQueue_CopyDataWithEvents,
3840
DPCTLQueue_Create,
3941
DPCTLQueue_Delete,
4042
DPCTLQueue_GetBackend,
@@ -535,6 +537,80 @@ cdef DPCTLSyclEventRef _memcpy_impl(
535537
return ERef
536538

537539

540+
# Submit a copy of ``byte_count`` bytes from ``src`` to ``dst`` on queue ``q``
# and return the event reference produced by the C-API call.
#
# ``src`` and ``dst`` may each be a ``dpctl.memory._Memory`` instance or an
# object supporting the Python buffer protocol; a buffer used as ``dst`` is
# requested as writable. When ``dep_events`` is NULL or ``dep_events_count``
# is zero, ``DPCTLQueue_CopyData`` is called; otherwise
# ``DPCTLQueue_CopyDataWithEvents`` is called with the given dependencies.
#
# Raises:
#     TypeError: ``src`` or ``dst`` is neither ``_Memory`` nor a buffer.
#     RuntimeError: a buffer view could not be acquired.
#
# Any buffer view acquired here is released on every exit path, including
# the error paths (previously the ``src`` view leaked when ``dst`` had an
# unsupported type).
#
# NOTE(review): the views are released as soon as the copy is submitted, not
# when the returned event completes; presumably callers keep ``src``/``dst``
# alive and unmodified until then -- confirm.
# NOTE(review): ``byte_count`` is not checked against the buffer lengths.
cdef DPCTLSyclEventRef _copy_impl(
    SyclQueue q,
    object dst,
    object src,
    size_t byte_count,
    DPCTLSyclEventRef *dep_events,
    size_t dep_events_count
) except *:
    cdef void *c_dst_ptr = NULL
    cdef void *c_src_ptr = NULL
    cdef DPCTLSyclEventRef ERef = NULL
    cdef Py_buffer src_buf_view
    cdef Py_buffer dst_buf_view
    # Track which views were actually acquired; only those may be released.
    cdef bint src_is_buf = False
    cdef bint dst_is_buf = False
    cdef int ret_code = 0

    try:
        if isinstance(src, _Memory):
            c_src_ptr = <void*>(<_Memory>src).get_data_ptr()
        elif _is_buffer(src):
            ret_code = PyObject_GetBuffer(
                src, &src_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS
            )
            if ret_code != 0:  # pragma: no cover
                raise RuntimeError("Could not access buffer")
            c_src_ptr = src_buf_view.buf
            src_is_buf = True
        else:
            raise TypeError(
                "Parameter `src` should have either type "
                "`dpctl.memory._Memory` or a type that "
                "supports Python buffer protocol"
            )

        if isinstance(dst, _Memory):
            c_dst_ptr = <void*>(<_Memory>dst).get_data_ptr()
        elif _is_buffer(dst):
            ret_code = PyObject_GetBuffer(
                dst, &dst_buf_view,
                PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | PyBUF_WRITABLE
            )
            if ret_code != 0:  # pragma: no cover
                raise RuntimeError("Could not access buffer")
            c_dst_ptr = dst_buf_view.buf
            dst_is_buf = True
        else:
            raise TypeError(
                "Parameter `dst` should have either type "
                "`dpctl.memory._Memory` or a type that "
                "supports Python buffer protocol"
            )

        if dep_events_count == 0 or dep_events is NULL:
            ERef = DPCTLQueue_CopyData(
                q._queue_ref, c_dst_ptr, c_src_ptr, byte_count
            )
        else:
            ERef = DPCTLQueue_CopyDataWithEvents(
                q._queue_ref,
                c_dst_ptr,
                c_src_ptr,
                byte_count,
                dep_events,
                dep_events_count
            )
    finally:
        # Runs on success and on every raised exception above.
        if src_is_buf:
            PyBuffer_Release(&src_buf_view)
        if dst_is_buf:
            PyBuffer_Release(&dst_buf_view)

    return ERef
612+
613+
538614
cdef class _SyclQueue:
539615
""" Barebone data owner class used by SyclQueue.
540616
"""
@@ -1381,7 +1457,7 @@ cdef class SyclQueue(_SyclQueue):
13811457
"""Copy memory from `src` to `dst`"""
13821458
cdef DPCTLSyclEventRef ERef = NULL
13831459

1384-
ERef = _memcpy_impl(<SyclQueue>self, dest, src, count, NULL, 0)
1460+
ERef = _copy_impl(<SyclQueue>self, dest, src, count, NULL, 0)
13851461
if (ERef is NULL):
13861462
raise RuntimeError(
13871463
"SyclQueue.memcpy operation encountered an error"
@@ -1399,7 +1475,7 @@ cdef class SyclQueue(_SyclQueue):
13991475
cdef size_t nDE = 0
14001476

14011477
if dEvents is None:
1402-
ERef = _memcpy_impl(<SyclQueue>self, dest, src, count, NULL, 0)
1478+
ERef = _copy_impl(<SyclQueue>self, dest, src, count, NULL, 0)
14031479
else:
14041480
nDE = len(dEvents)
14051481
depEvents = (
@@ -1416,7 +1492,7 @@ cdef class SyclQueue(_SyclQueue):
14161492
raise TypeError(
14171493
"A sequence of dpctl.SyclEvent is expected"
14181494
)
1419-
ERef = _memcpy_impl(self, dest, src, count, depEvents, nDE)
1495+
ERef = _copy_impl(self, dest, src, count, depEvents, nDE)
14201496
free(depEvents)
14211497

14221498
if (ERef is NULL):
@@ -1429,6 +1505,9 @@ cdef class SyclQueue(_SyclQueue):
14291505
cpdef copy(self, dest, src, size_t count):
14301506
"""Copy ``count`` bytes from ``src`` to ``dest`` and wait.
14311507
1508+
Internally, this dispatches ``sycl::queue::copy`` instantiated for
1509+
byte-sized elements.
1510+
14321511
This is a synchronizing variant corresponding to
14331512
:meth:`dpctl.SyclQueue.copy_async`.
14341513
"""
@@ -1448,6 +1527,9 @@ cdef class SyclQueue(_SyclQueue):
14481527
):
14491528
"""Copy ``count`` bytes from ``src`` to ``dest`` asynchronously.
14501529
1530+
Internally, this dispatches ``sycl::queue::copy`` instantiated for
1531+
byte-sized elements.
1532+
14511533
Args:
14521534
dest:
14531535
Destination USM object or Python object supporting

dpctl/tests/test_sycl_queue_memcpy.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,40 @@ def test_memcpy_async():
155155
assert dst_buf2 == src_buf
156156

157157

158+
def test_copy_copy_host_to_host():
    """``SyclQueue.copy`` between two host buffers reproduces the source."""
    try:
        queue = dpctl.SyclQueue()
    except dpctl.SyclQueueCreationError:
        pytest.skip("Default constructor for SyclQueue failed")

    payload = b"abcdefghijklmnopqrstuvwxyz"
    n_bytes = len(payload)
    received = bytearray(n_bytes)

    # Synchronizing variant: the data is in place once the call returns.
    queue.copy(received, payload, n_bytes)

    assert received == payload
170+
171+
172+
def test_copy_async():
    """``SyclQueue.copy_async`` works with and without dependency events."""
    try:
        queue = dpctl.SyclQueue()
    except dpctl.SyclQueueCreationError:
        pytest.skip("Default constructor for SyclQueue failed")

    payload = b"abcdefghijklmnopqrstuvwxyz"
    n_bytes = len(payload)
    first_dst = bytearray(n_bytes)
    second_dst = bytearray(n_bytes)

    # The second copy lists the first copy's event as a dependency.
    first_ev = queue.copy_async(first_dst, payload, n_bytes)
    second_ev = queue.copy_async(second_dst, payload, n_bytes, [first_ev])

    first_ev.wait()
    second_ev.wait()
    assert first_dst == payload
    assert second_dst == payload
190+
191+
158192
def test_memcpy_type_error():
159193
try:
160194
q = dpctl.SyclQueue()

libsyclinterface/include/syclinterface/dpctl_sycl_queue_interface.h

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,9 @@ void DPCTLQueue_Wait(__dpctl_keep const DPCTLSyclQueueRef QRef);
292292
* @param QRef An opaque pointer to the ``sycl::queue``.
293293
* @param Dest An USM pointer to the destination memory.
294294
* @param Src An USM pointer to the source memory.
295-
* @param Count A number of bytes to copy.
295+
* @param Count A number of bytes to copy, as taken by
* ``sycl::queue::memcpy``.
296298
* @return An opaque pointer to the ``sycl::event`` returned by the
297299
* ``sycl::queue::memcpy`` function.
298300
* @ingroup QueueInterface
@@ -310,7 +312,9 @@ DPCTLQueue_Memcpy(__dpctl_keep const DPCTLSyclQueueRef QRef,
310312
* @param QRef An opaque pointer to the ``sycl::queue``.
311313
* @param Dest An USM pointer to the destination memory.
312314
* @param Src An USM pointer to the source memory.
313-
* @param Count A number of bytes to copy.
315+
* @param Count A number of bytes to copy, as taken by
* ``sycl::queue::memcpy``.
314318
* @param DepEvents A pointer to array of DPCTLSyclEventRef opaque
315319
* pointers to dependent events.
316320
* @param DepEventsCount A number of dependent events.
@@ -327,6 +331,47 @@ DPCTLQueue_MemcpyWithEvents(__dpctl_keep const DPCTLSyclQueueRef QRef,
327331
__dpctl_keep const DPCTLSyclEventRef *DepEvents,
328332
size_t DepEventsCount);
329333

334+
/*!
335+
* @brief C-API wrapper for ``sycl::queue::copy``.
336+
*
337+
* @param QRef An opaque pointer to the ``sycl::queue``.
338+
* @param Dest A destination pointer.
339+
* @param Src A source pointer.
340+
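* @param Count A number of bytes to copy. The wrapper instantiates
* ``sycl::queue::copy`` with ``std::uint8_t`` elements, so the
* element count passed to SYCL equals the byte count.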
341+
* @return An opaque pointer to the ``sycl::event`` returned by the
342+
* ``sycl::queue::copy`` function.
343+
* @ingroup QueueInterface
344+
*/
345+
DPCTL_API
346+
__dpctl_give DPCTLSyclEventRef
347+
DPCTLQueue_CopyData(__dpctl_keep const DPCTLSyclQueueRef QRef,
348+
void *Dest,
349+
const void *Src,
350+
size_t Count);
351+
352+
/*!
353+
* @brief C-API wrapper for ``sycl::queue::copy``.
354+
*
355+
* @param QRef An opaque pointer to the ``sycl::queue``.
356+
* @param Dest A destination pointer.
357+
* @param Src A source pointer.
358+
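* @param Count A number of bytes to copy. The wrapper instantiates
* ``sycl::queue::copy`` with ``std::uint8_t`` elements, so the
* element count passed to SYCL equals the byte count.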
359+
* @param DepEvents A pointer to array of DPCTLSyclEventRef opaque
360+
* pointers to dependent events.
361+
* @param DepEventsCount A number of dependent events.
362+
* @return An opaque pointer to the ``sycl::event`` returned by the
363+
* ``sycl::queue::copy`` function.
364+
* @ingroup QueueInterface
365+
*/
366+
DPCTL_API
367+
__dpctl_give DPCTLSyclEventRef
368+
DPCTLQueue_CopyDataWithEvents(__dpctl_keep const DPCTLSyclQueueRef QRef,
369+
void *Dest,
370+
const void *Src,
371+
size_t Count,
372+
__dpctl_keep const DPCTLSyclEventRef *DepEvents,
373+
size_t DepEventsCount);
374+
330375
/*!
331376
* @brief C-API wrapper for ``sycl::queue::prefetch``.
332377
*

libsyclinterface/source/dpctl_sycl_queue_interface.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include <stdexcept>
4242
#include <sycl/sycl.hpp> /* SYCL headers */
4343
#include <utility>
44+
#include <vector>
4445

4546
#if defined(SYCL_EXT_ONEAPI_WORK_GROUP_MEMORY) || \
4647
defined(SYCL_EXT_ONEAPI_RAW_KERNEL_ARG)
@@ -693,6 +694,70 @@ DPCTLQueue_MemcpyWithEvents(__dpctl_keep const DPCTLSyclQueueRef QRef,
693694
return wrap<event>(new event(ev));
694695
}
695696

697+
/*!
 * Submits ``sycl::queue::copy`` of ``Count`` ``std::uint8_t`` elements from
 * ``Src`` to ``Dest`` on the queue behind ``QRef``.
 *
 * Returns a newly allocated event reference on success. Returns ``nullptr``
 * after reporting through ``error_handler`` when ``QRef`` unwraps to a null
 * queue or when the copy throws a ``std::exception``.
 */
__dpctl_give DPCTLSyclEventRef
DPCTLQueue_CopyData(__dpctl_keep const DPCTLSyclQueueRef QRef,
                    void *Dest,
                    const void *Src,
                    size_t Count)
{
    auto QueuePtr = unwrap<queue>(QRef);
    if (!QueuePtr) {
        error_handler("QRef passed to copy was NULL.", __FILE__, __func__,
                      __LINE__);
        return nullptr;
    }

    // Byte-typed views make the element count of queue::copy a byte count.
    const auto *SrcBytes = static_cast<const std::uint8_t *>(Src);
    auto *DestBytes = static_cast<std::uint8_t *>(Dest);

    sycl::event CopyEv;
    try {
        CopyEv = QueuePtr->copy(SrcBytes, DestBytes, Count);
    } catch (std::exception const &e) {
        error_handler(e, __FILE__, __func__, __LINE__);
        return nullptr;
    }

    return wrap<event>(new event(std::move(CopyEv)));
}
722+
723+
__dpctl_give DPCTLSyclEventRef
724+
DPCTLQueue_CopyDataWithEvents(__dpctl_keep const DPCTLSyclQueueRef QRef,
725+
void *Dest,
726+
const void *Src,
727+
size_t Count,
728+
const DPCTLSyclEventRef *DepEvents,
729+
size_t DepEventsCount)
730+
{
731+
auto Q = unwrap<queue>(QRef);
732+
if (Q) {
733+
try {
734+
std::vector<event> dep_events;
735+
if (DepEvents) {
736+
dep_events.reserve(DepEventsCount);
737+
for (size_t i = 0; i < DepEventsCount; ++i) {
738+
event *ei = unwrap<event>(DepEvents[i]);
739+
if (ei)
740+
dep_events.push_back(*ei);
741+
}
742+
}
743+
744+
// Bind queue::copy with uint8_t so Count is interpreted as bytes.
745+
auto ev = Q->copy(static_cast<const std::uint8_t *>(Src),
746+
static_cast<std::uint8_t *>(Dest), Count,
747+
dep_events);
748+
return wrap<event>(new event(std::move(ev)));
749+
} catch (const std::exception &ex) {
750+
error_handler(ex, __FILE__, __func__, __LINE__);
751+
return nullptr;
752+
}
753+
}
754+
else {
755+
error_handler("QRef passed to copy was NULL.", __FILE__, __func__,
756+
__LINE__);
757+
return nullptr;
758+
}
759+
}
760+
696761
__dpctl_give DPCTLSyclEventRef
697762
DPCTLQueue_Prefetch(__dpctl_keep DPCTLSyclQueueRef QRef,
698763
const void *Ptr,

libsyclinterface/tests/test_sycl_queue_interface.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,13 @@ TEST(TestDPCTLSyclQueueInterface, CheckMemOpsZeroQRef)
371371
ERef = DPCTLQueue_MemcpyWithEvents(QRef, p1, p2, n_bytes, NULL, 0));
372372
ASSERT_FALSE(bool(ERef));
373373

374+
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_CopyData(QRef, p1, p2, n_bytes));
375+
ASSERT_FALSE(bool(ERef));
376+
377+
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_CopyDataWithEvents(
378+
QRef, p1, p2, n_bytes, NULL, 0));
379+
ASSERT_FALSE(bool(ERef));
380+
374381
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_Prefetch(QRef, p1, n_bytes));
375382
ASSERT_FALSE(bool(ERef));
376383

@@ -429,6 +436,13 @@ TEST_P(TestDPCTLQueueMemberFunctions, CheckMemOpsNullPtr)
429436
ERef = DPCTLQueue_MemcpyWithEvents(QRef, p1, p2, n_bytes, NULL, 0));
430437
ASSERT_FALSE(bool(ERef));
431438

439+
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_CopyData(QRef, p1, p2, n_bytes));
440+
ASSERT_FALSE(bool(ERef));
441+
442+
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_CopyDataWithEvents(
443+
QRef, p1, p2, n_bytes, NULL, 0));
444+
ASSERT_FALSE(bool(ERef));
445+
432446
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_Prefetch(QRef, p1, n_bytes));
433447
if (ERef) {
434448
ASSERT_NO_FATAL_FAILURE(DPCTLEvent_Wait(ERef));

0 commit comments

Comments
 (0)