Skip to content

Commit 2e67f59

Browse files
authored
Merge pull request #2219 from IntelPython/resolve_gh_2213
Fix DLPack C-contiguous stride reconstruction
2 parents 019b203 + 8090e95 commit 2e67f59

File tree

2 files changed

+83
-9
lines changed

2 files changed

+83
-9
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@ from .._backend cimport (
3636
DPCTLSyclDeviceRef,
3737
DPCTLSyclUSMRef,
3838
)
39-
from ._usmarray cimport USM_ARRAY_WRITABLE, usm_ndarray
39+
from ._usmarray cimport (
40+
USM_ARRAY_C_CONTIGUOUS,
41+
USM_ARRAY_F_CONTIGUOUS,
42+
USM_ARRAY_WRITABLE,
43+
usm_ndarray,
44+
)
4045

4146
import ctypes
4247

@@ -266,6 +271,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
266271
cdef int64_t *shape_strides_ptr = NULL
267272
cdef int i = 0
268273
cdef int device_id = -1
274+
cdef int flags = 0
269275
cdef Py_ssize_t element_offset = 0
270276
cdef Py_ssize_t byte_offset = 0
271277
cdef Py_ssize_t si = 1
@@ -291,14 +297,29 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
291297
for i in range(nd):
292298
shape_strides_ptr[i] = shape_ptr[i]
293299
strides_ptr = usm_ary.get_strides()
300+
flags = usm_ary.flags_
294301
if strides_ptr:
295302
for i in range(nd):
296303
shape_strides_ptr[nd + i] = strides_ptr[i]
297304
else:
298-
si = 1
299-
for i in range(0, nd):
300-
shape_strides_ptr[nd + i] = si
301-
si = si * shape_ptr[i]
305+
if flags & USM_ARRAY_C_CONTIGUOUS:
306+
si = 1
307+
for i in range(nd - 1, -1, -1):
308+
shape_strides_ptr[nd + i] = si
309+
si = si * shape_ptr[i]
310+
elif flags & USM_ARRAY_F_CONTIGUOUS:
311+
si = 1
312+
for i in range(0, nd):
313+
shape_strides_ptr[nd + i] = si
314+
si = si * shape_ptr[i]
315+
else:
316+
stdlib.free(shape_strides_ptr)
317+
stdlib.free(dlm_tensor)
318+
raise BufferError(
319+
"to_dlpack_capsule: Invalid array encountered "
320+
"when building strides"
321+
)
322+
302323
strides_ptr = <Py_ssize_t *>&shape_strides_ptr[nd]
303324

304325
ary_dt = usm_ary.dtype
@@ -409,10 +430,24 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
409430
for i in range(nd):
410431
shape_strides_ptr[nd + i] = strides_ptr[i]
411432
else:
412-
si = 1
413-
for i in range(0, nd):
414-
shape_strides_ptr[nd + i] = si
415-
si = si * shape_ptr[i]
433+
if flags & USM_ARRAY_C_CONTIGUOUS:
434+
si = 1
435+
for i in range(nd - 1, -1, -1):
436+
shape_strides_ptr[nd + i] = si
437+
si = si * shape_ptr[i]
438+
elif flags & USM_ARRAY_F_CONTIGUOUS:
439+
si = 1
440+
for i in range(0, nd):
441+
shape_strides_ptr[nd + i] = si
442+
si = si * shape_ptr[i]
443+
else:
444+
stdlib.free(shape_strides_ptr)
445+
stdlib.free(dlmv_tensor)
446+
raise BufferError(
447+
"to_dlpack_versioned_capsule: Invalid array encountered "
448+
"when building strides"
449+
)
450+
416451
strides_ptr = <Py_ssize_t *>&shape_strides_ptr[nd]
417452

418453
# this can all be a function for building the dl_tensor

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,45 @@ def test_dlpack_capsule_readonly_array_to_kdlcpu():
664664
assert not y1.flags["W"]
665665

666666

667+
def test_to_dlpack_capsule_c_and_f_contig():
668+
try:
669+
x = dpt.asarray(np.random.rand(2, 3))
670+
except dpctl.SyclDeviceCreationError:
671+
pytest.skip("No default device available")
672+
673+
cap = _dlp.to_dlpack_capsule(x)
674+
y = _dlp.from_dlpack_capsule(cap)
675+
assert np.allclose(dpt.asnumpy(x), dpt.asnumpy(y))
676+
assert x.strides == y.strides
677+
678+
x_f = x.T
679+
cap = _dlp.to_dlpack_capsule(x_f)
680+
yf = _dlp.from_dlpack_capsule(cap)
681+
assert np.allclose(dpt.asnumpy(x_f), dpt.asnumpy(yf))
682+
assert x_f.strides == yf.strides
683+
del cap
684+
685+
686+
def test_to_dlpack_versioned_capsule_c_and_f_contig():
687+
try:
688+
x = dpt.asarray(np.random.rand(2, 3))
689+
max_supported_ver = _dlp.get_build_dlpack_version()
690+
except dpctl.SyclDeviceCreationError:
691+
pytest.skip("No default device available")
692+
693+
cap = x.__dlpack__(max_version=max_supported_ver)
694+
y = _dlp.from_dlpack_capsule(cap)
695+
assert np.allclose(dpt.asnumpy(x), dpt.asnumpy(y))
696+
assert x.strides == y.strides
697+
698+
x_f = x.T
699+
cap = x_f.__dlpack__(max_version=max_supported_ver)
700+
yf = _dlp.from_dlpack_capsule(cap)
701+
assert np.allclose(dpt.asnumpy(x_f), dpt.asnumpy(yf))
702+
assert x_f.strides == yf.strides
703+
del cap
704+
705+
667706
def test_used_dlpack_capsule_from_numpy():
668707
get_queue_or_skip()
669708

0 commit comments

Comments
 (0)