Skip to content

Commit 4d0462d

Browse files
authored
Improve linspace implementation with non-scalar inputs (#2712)
The PR closes #2084. This PR improves implementation of `dpnp.linspace` function aligning it with the latest changes added to NumPy. The tests coverage is also extended to verify more use cases.
1 parent 290ab65 commit 4d0462d

File tree

8 files changed

+167
-105
lines changed

8 files changed

+167
-105
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
5959

6060
* Suppressed a potential deprecation warning triggered during import of the `dpctl.tensor` module [#2709](https://github.com/IntelPython/dpnp/pull/2709)
6161
* Corrected a phonetic spelling issue due to incorrect using of `a nd` in docstrings [#2719](https://github.com/IntelPython/dpnp/pull/2719)
62+
* Resolved an issue causing `dpnp.linspace` to return an incorrect output shape when inputs were passed as arrays [#2712](https://github.com/IntelPython/dpnp/pull/2712)
6263

6364
### Security
6465

dpnp/dpnp_algo/dpnp_arraycreation.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,9 @@ def dpnp_linspace(
172172

173173
num = operator.index(num)
174174
if num < 0:
175-
raise ValueError("Number of points must be non-negative")
175+
raise ValueError(f"Number of samples={num} must be non-negative.")
176176
step_num = (num - 1) if endpoint else num
177177

178-
step_nan = False
179-
if step_num == 0:
180-
step_nan = True
181-
step = dpnp.nan
182-
183178
if dpnp.isscalar(start) and dpnp.isscalar(stop):
184179
# Call linspace() function for scalars.
185180
usm_res = dpt.linspace(
@@ -191,8 +186,13 @@ def dpnp_linspace(
191186
sycl_queue=sycl_queue_normalized,
192187
endpoint=endpoint,
193188
)
194-
if retstep is True and step_nan is False:
195-
step = (stop - start) / step_num
189+
190+
# calculate the used step to return
191+
if retstep is True:
192+
if step_num > 0:
193+
step = (stop - start) / step_num
194+
else:
195+
step = dpnp.nan
196196
else:
197197
usm_start = dpt.asarray(
198198
start,
@@ -204,6 +204,8 @@ def dpnp_linspace(
204204
stop, dtype=dt, usm_type=_usm_type, sycl_queue=sycl_queue_normalized
205205
)
206206

207+
delta = usm_stop - usm_start
208+
207209
usm_res = dpt.arange(
208210
0,
209211
stop=num,
@@ -212,20 +214,30 @@ def dpnp_linspace(
212214
usm_type=_usm_type,
213215
sycl_queue=sycl_queue_normalized,
214216
)
217+
usm_res = dpt.reshape(usm_res, (-1,) + (1,) * delta.ndim, copy=False)
218+
219+
if step_num > 0:
220+
step = delta / step_num
221+
222+
# Needed a special handling for denormal numbers (when step == 0),
223+
# see numpy#5437 for more details.
224+
# Note, dpt.where() is used to avoid a synchronization branch.
225+
usm_res = dpt.where(
226+
step == 0, (usm_res / step_num) * delta, usm_res * step
227+
)
228+
else:
229+
step = dpnp.nan
230+
usm_res = usm_res * delta
215231

216-
if step_nan is False:
217-
step = (usm_stop - usm_start) / step_num
218-
usm_res = dpt.reshape(usm_res, (-1,) + (1,) * step.ndim, copy=False)
219-
usm_res = usm_res * step
220-
usm_res += usm_start
232+
usm_res += usm_start
221233

222234
if endpoint and num > 1:
223-
usm_res[-1] = dpt.full(step.shape, usm_stop)
235+
usm_res[-1, ...] = usm_stop
224236

225237
if axis != 0:
226238
usm_res = dpt.moveaxis(usm_res, 0, axis)
227239

228-
if numpy.issubdtype(dtype, dpnp.integer):
240+
if dpnp.issubdtype(dtype, dpnp.integer):
229241
dpt.floor(usm_res, out=usm_res)
230242

231243
res = dpt.astype(usm_res, dtype, copy=False)

dpnp/dpnp_iface_arraycreation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2704,6 +2704,8 @@ def linspace(
27042704
of tuples, tuples of lists, and ndarrays. If `endpoint` is set to
27052705
``False`` the sequence consists of all but the last of ``num + 1``
27062706
evenly spaced samples, so that `stop` is excluded.
2707+
num : int
2708+
Number of samples. Must have a nonnegative value.
27072709
dtype : {None, str, dtype object}, optional
27082710
The desired dtype for the array. If not given, a default dtype will be
27092711
used that can represent the values (by considering Promotion Type Rule

dpnp/tests/helper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def assert_dtype_allclose(
6969
x.dtype, dpnp.inexact
7070
)
7171

72+
if not hasattr(numpy_arr, "dtype"):
73+
numpy_arr = numpy.array(numpy_arr)
74+
7275
if is_inexact(dpnp_arr) or is_inexact(numpy_arr):
7376
tol_dpnp = (
7477
dpnp.finfo(dpnp_arr).resolution

dpnp/tests/test_arraycreation.py

Lines changed: 127 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
assert_dtype_allclose,
2020
get_all_dtypes,
2121
get_array,
22+
get_float_dtypes,
23+
has_support_aspect64,
2224
is_lts_driver,
2325
is_tgllp_iris_xe,
2426
is_win_platform,
@@ -83,6 +85,131 @@ def test_validate_positional_args(self, xp):
8385
)
8486

8587

88+
class TestLinspace:
89+
@pytest.mark.parametrize("start", [0, -5, 10, -2.5, 9.7])
90+
@pytest.mark.parametrize("stop", [0, 10, -2, 20.5, 120])
91+
@pytest.mark.parametrize("num", [0, 1, 5, numpy.array(10)])
92+
@pytest.mark.parametrize(
93+
"dt", get_all_dtypes(no_bool=True, no_float16=False)
94+
)
95+
@pytest.mark.parametrize("retstep", [True, False])
96+
def test_basic(self, start, stop, num, dt, retstep):
97+
if (
98+
not has_support_aspect64()
99+
and numpy.issubdtype(dt, numpy.integer)
100+
and start == -5
101+
and stop == 10
102+
and num == 10
103+
):
104+
pytest.skip("due to dpctl-1056")
105+
106+
if numpy.issubdtype(dt, numpy.unsignedinteger):
107+
start = abs(start)
108+
stop = abs(stop)
109+
110+
res = dpnp.linspace(start, stop, num, dtype=dt, retstep=retstep)
111+
exp = numpy.linspace(start, stop, num, dtype=dt, retstep=retstep)
112+
if retstep:
113+
res, res_step = res
114+
exp, exp_step = exp
115+
assert_dtype_allclose(res_step, exp_step)
116+
117+
if numpy.issubdtype(dt, numpy.integer):
118+
assert_allclose(res, exp, rtol=1)
119+
else:
120+
assert_dtype_allclose(res, exp)
121+
122+
@pytest.mark.parametrize(
123+
"start, stop",
124+
[
125+
(dpnp.array(1), dpnp.array([-4])),
126+
(dpnp.array([2.6]), dpnp.array([[2.6], [-4]])),
127+
(numpy.array([[-6.7, 3]]), numpy.array(2)),
128+
([1, -4], [[-4.6]]),
129+
((3, 5), (3,)),
130+
],
131+
)
132+
@pytest.mark.parametrize("num", [0, 1, 5])
133+
@pytest.mark.parametrize(
134+
"dt", get_all_dtypes(no_bool=True, no_float16=False)
135+
)
136+
@pytest.mark.parametrize("retstep", [True, False])
137+
def test_start_stop_arrays(self, start, stop, num, dt, retstep):
138+
res = dpnp.linspace(start, stop, num, dtype=dt, retstep=retstep)
139+
exp = numpy.linspace(
140+
get_array(numpy, start),
141+
get_array(numpy, stop),
142+
num,
143+
dtype=dt,
144+
retstep=retstep,
145+
)
146+
if retstep:
147+
res, res_step = res
148+
exp, exp_step = exp
149+
assert_dtype_allclose(res_step, exp_step)
150+
assert_dtype_allclose(res, exp)
151+
152+
@pytest.mark.parametrize(
153+
"start, stop",
154+
[(1 + 2j, 3 + 4j), (1j, 10), ([0, 1], 3 + 2j)],
155+
)
156+
def test_start_stop_complex(self, start, stop):
157+
result = dpnp.linspace(start, stop, num=5)
158+
expected = numpy.linspace(start, stop, num=5)
159+
assert_dtype_allclose(result, expected)
160+
161+
@pytest.mark.parametrize("dt", get_float_dtypes())
162+
def test_denormal_numbers(self, dt):
163+
stop = numpy.nextafter(dt(0), dt(1)) * 5 # denormal number
164+
165+
result = dpnp.linspace(0, stop, num=10, endpoint=False, dtype=dt)
166+
expected = numpy.linspace(0, stop, num=10, endpoint=False, dtype=dt)
167+
assert_dtype_allclose(result, expected)
168+
169+
@pytest.mark.skipif(not has_support_aspect64(), reason="due to dpctl-1056")
170+
def test_equivalent_to_arange(self):
171+
result = dpnp.linspace(0, 35, num=36, dtype=int)
172+
expected = numpy.linspace(0, 35, num=36, dtype=int)
173+
assert_equal(result, expected)
174+
175+
def test_round_negative(self):
176+
result = dpnp.linspace(-1, 3, num=8, dtype=int)
177+
expected = numpy.linspace(-1, 3, num=8, dtype=int)
178+
assert_array_equal(result, expected)
179+
180+
def test_step_zero(self):
181+
start = numpy.array([0.0, 1.0])
182+
stop = numpy.array([2.0, 1.0])
183+
184+
result = dpnp.linspace(start, stop, num=3)
185+
expected = numpy.linspace(start, stop, num=3)
186+
assert_array_equal(result, expected)
187+
188+
@pytest.mark.parametrize("endpoint", [True, False])
189+
def test_num_zero(self, endpoint):
190+
start, stop = 0, [0, 1, 2, 3, 4]
191+
result = dpnp.linspace(start, stop, num=0, endpoint=endpoint)
192+
expected = numpy.linspace(start, stop, num=0, endpoint=endpoint)
193+
assert_dtype_allclose(result, expected)
194+
195+
@pytest.mark.parametrize("axis", [0, 1])
196+
def test_axis(self, axis):
197+
func = lambda xp: xp.linspace([2, 3], [20, 15], num=10, axis=axis)
198+
assert_allclose(func(dpnp), func(numpy))
199+
200+
@pytest.mark.parametrize("xp", [dpnp, numpy])
201+
def test_negative_num(self, xp):
202+
with pytest.raises(ValueError, match="must be non-negative"):
203+
_ = xp.linspace(0, 10, num=-1)
204+
205+
@pytest.mark.parametrize("xp", [dpnp, numpy])
206+
def test_float_num(self, xp):
207+
with pytest.raises(
208+
TypeError, match="cannot be interpreted as an integer"
209+
):
210+
_ = xp.linspace(0, 1, num=2.5)
211+
212+
86213
class TestTrace:
87214
@pytest.mark.parametrize("a_sh", [(3, 4), (2, 2, 2)])
88215
@pytest.mark.parametrize(
@@ -734,37 +861,6 @@ def test_dpctl_tensor_input(func, args):
734861
assert_array_equal(X, Y)
735862

736863

737-
@pytest.mark.parametrize("start", [0, -5, 10, -2.5, 9.7])
738-
@pytest.mark.parametrize("stop", [0, 10, -2, 20.5, 120])
739-
@pytest.mark.parametrize(
740-
"num",
741-
[1, 5, numpy.array(10), dpnp.array(17), dpt.asarray(100)],
742-
ids=["1", "5", "numpy.array(10)", "dpnp.array(17)", "dpt.asarray(100)"],
743-
)
744-
@pytest.mark.parametrize(
745-
"dtype",
746-
get_all_dtypes(no_bool=True, no_float16=False),
747-
)
748-
@pytest.mark.parametrize("retstep", [True, False])
749-
def test_linspace(start, stop, num, dtype, retstep):
750-
if numpy.issubdtype(dtype, numpy.unsignedinteger):
751-
start = abs(start)
752-
stop = abs(stop)
753-
754-
res_np = numpy.linspace(start, stop, num, dtype=dtype, retstep=retstep)
755-
res_dp = dpnp.linspace(start, stop, num, dtype=dtype, retstep=retstep)
756-
757-
if retstep:
758-
[res_np, step_np] = res_np
759-
[res_dp, step_dp] = res_dp
760-
assert_allclose(step_np, step_dp)
761-
762-
if numpy.issubdtype(dtype, dpnp.integer):
763-
assert_allclose(res_np, res_dp, rtol=1)
764-
else:
765-
assert_dtype_allclose(res_dp, res_np)
766-
767-
768864
@pytest.mark.parametrize("func", ["geomspace", "linspace", "logspace"])
769865
@pytest.mark.parametrize(
770866
"start_dtype", [numpy.float64, numpy.float32, numpy.int64, numpy.int32]
@@ -778,57 +874,6 @@ def test_space_numpy_dtype(func, start_dtype, stop_dtype):
778874
getattr(dpnp, func)(start, stop, 10)
779875

780876

781-
@pytest.mark.parametrize(
782-
"start",
783-
[
784-
dpnp.array(1),
785-
dpnp.array([2.6]),
786-
numpy.array([[-6.7, 3]]),
787-
[1, -4],
788-
(3, 5),
789-
],
790-
)
791-
@pytest.mark.parametrize(
792-
"stop",
793-
[
794-
dpnp.array([-4]),
795-
dpnp.array([[2.6], [-4]]),
796-
numpy.array(2),
797-
[[-4.6]],
798-
(3,),
799-
],
800-
)
801-
def test_linspace_arrays(start, stop):
802-
func = lambda xp: xp.linspace(get_array(xp, start), get_array(xp, stop), 10)
803-
assert func(numpy).shape == func(dpnp).shape
804-
805-
806-
def test_linspace_complex():
807-
func = lambda xp: xp.linspace(0, 3 + 2j, num=1000)
808-
assert_allclose(func(dpnp), func(numpy))
809-
810-
811-
@pytest.mark.parametrize("axis", [0, 1])
812-
def test_linspace_axis(axis):
813-
func = lambda xp: xp.linspace([2, 3], [20, 15], num=10, axis=axis)
814-
assert_allclose(func(dpnp), func(numpy))
815-
816-
817-
def test_linspace_step_nan():
818-
func = lambda xp: xp.linspace(1, 2, num=0, endpoint=False)
819-
assert_allclose(func(dpnp), func(numpy))
820-
821-
822-
@pytest.mark.parametrize("start", [1, [1, 1]])
823-
@pytest.mark.parametrize("stop", [10, [10 + 10]])
824-
def test_linspace_retstep(start, stop):
825-
func = lambda xp: xp.linspace(start, stop, num=10, retstep=True)
826-
np_res = func(numpy)
827-
dpnp_res = func(dpnp)
828-
assert_allclose(dpnp_res[0], np_res[0])
829-
assert_allclose(dpnp_res[1], np_res[1])
830-
831-
832877
@pytest.mark.parametrize(
833878
"arrays",
834879
[[], [[1]], [[1, 2, 3], [4, 5, 6]], [[1, 2], [3, 4], [5, 6]]],
@@ -862,10 +907,8 @@ def test_geomspace_zero_error():
862907

863908
def test_space_num_error():
864909
with pytest.raises(ValueError):
865-
dpnp.linspace(2, 5, -3)
866910
dpnp.geomspace(2, 5, -3)
867911
dpnp.logspace(2, 5, -3)
868-
dpnp.linspace([2, 3], 5, -3)
869912
dpnp.geomspace([2, 3], 5, -3)
870913
dpnp.logspace([2, 3], 5, -3)
871914

dpnp/tests/test_arraypad.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,8 @@ def test_non_contiguous_array(self, mode):
7373
else:
7474
assert_array_equal(result, expected)
7575

76-
# TODO: include "linear_ramp" when dpnp issue gh-2084 is resolved
7776
@pytest.mark.parametrize("pad_width", [0, (0, 0), ((0, 0), (0, 0))])
78-
@pytest.mark.parametrize(
79-
"mode", [m for m in _modes if m not in {"linear_ramp"}]
80-
)
77+
@pytest.mark.parametrize("mode", _modes)
8178
def test_zero_pad_width(self, pad_width, mode):
8279
arr = dpnp.arange(30).reshape(6, 5)
8380
assert_array_equal(arr, dpnp.pad(arr, pad_width, mode=mode))

dpnp/tests/third_party/cupy/creation_tests/test_ranges.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import functools
24
import math
35
import unittest
@@ -226,8 +228,8 @@ def test_linspace_mixed_start_stop2(self, xp, dtype_range, dtype_out):
226228
# TODO (ev-br): np 2.0: had to bump the default rtol on Windows
227229
# and numpy 1.26+weak promotion from 0 to 5e-6
228230
if xp.dtype(dtype_range).kind == "u":
229-
# to avoid overflow, limit `val` to be smaller
230-
# than xp.iinfo(dtype).max
231+
# to avoid overflow, limit `val` to be smaller than
232+
# xp.iinfo(dtype).max (TODO: check if dpctl-2230 resolves that)
231233
if dtype_range in [xp.uint8, xp.uint16] or dtype_out in [
232234
xp.int8,
233235
xp.uint8,

dpnp/tests/third_party/cupy/functional_tests/test_piecewise.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import unittest
24

35
import numpy

0 commit comments

Comments
 (0)