Skip to content

Commit 82ba497

Browse files
authored
TST: remove usages of deprecated numpy.testing.assert_warns (#833)
Deprecated in numpy 2.4.0. Also try `pytest.raises` instead of `assert_raises` in one file; both should work, the pytest builtin functionality will be preferred for new code.
2 parents 3e86c5a + b7e3de4 commit 82ba497

File tree

4 files changed

+73
-46
lines changed

4 files changed

+73
-46
lines changed

pywt/tests/test_cwt_wavelets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
assert_almost_equal,
1111
assert_equal,
1212
assert_raises,
13-
assert_warns,
1413
)
1514

1615
import pywt
@@ -341,7 +340,8 @@ def test_cwt_parameters_in_names():
341340
for func in [pywt.ContinuousWavelet, pywt.DiscreteContinuousWavelet]:
342341
for name in ['fbsp', 'cmor', 'shan']:
343342
# additional parameters should be specified within the name
344-
assert_warns(FutureWarning, func, name)
343+
with pytest.warns(FutureWarning):
344+
func(name)
345345

346346
for name in ['cmor', 'shan']:
347347
# valid names

pywt/tests/test_deprecations.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,40 @@
11
import warnings
22

33
import numpy as np
4-
from numpy.testing import assert_array_equal, assert_warns
4+
import pytest
5+
from numpy.testing import assert_array_equal
56

67
import pywt
78

89

910
def test_intwave_deprecation():
1011
wavelet = pywt.Wavelet('db3')
11-
assert_warns(DeprecationWarning, pywt.intwave, wavelet)
12+
with pytest.warns(DeprecationWarning):
13+
pywt.intwave(wavelet)
1214

1315

1416
def test_centrfrq_deprecation():
1517
wavelet = pywt.Wavelet('db3')
16-
assert_warns(DeprecationWarning, pywt.centrfrq, wavelet)
18+
with pytest.warns(DeprecationWarning):
19+
pywt.centrfrq(wavelet)
1720

1821

1922
def test_scal2frq_deprecation():
2023
wavelet = pywt.Wavelet('db3')
21-
assert_warns(DeprecationWarning, pywt.scal2frq, wavelet, 1)
24+
with pytest.warns(DeprecationWarning):
25+
pywt.scal2frq(wavelet, 1)
2226

2327

2428
def test_orthfilt_deprecation():
25-
assert_warns(DeprecationWarning, pywt.orthfilt, range(6))
29+
with pytest.warns(DeprecationWarning):
30+
pywt.orthfilt(range(6))
2631

2732

2833
def test_integrate_wave_tuple():
2934
sig = [0, 1, 2, 3]
3035
xgrid = [0, 1, 2, 3]
31-
assert_warns(DeprecationWarning, pywt.integrate_wavelet, (sig, xgrid))
36+
with pytest.warns(DeprecationWarning):
37+
pywt.integrate_wavelet((sig, xgrid))
3238

3339

3440
old_modes = ['zpd',
@@ -42,15 +48,17 @@ def test_integrate_wave_tuple():
4248

4349
def test_MODES_from_object_deprecation():
4450
for mode in old_modes:
45-
assert_warns(DeprecationWarning, pywt.Modes.from_object, mode)
51+
with pytest.warns(DeprecationWarning):
52+
pywt.Modes.from_object(mode)
4653

4754

4855
def test_MODES_attributes_deprecation():
4956
def get_mode(Modes, name):
5057
return getattr(Modes, name)
5158

5259
for mode in old_modes:
53-
assert_warns(DeprecationWarning, get_mode, pywt.Modes, mode)
60+
with pytest.warns(DeprecationWarning):
61+
get_mode(pywt.Modes, mode)
5462

5563

5664
def test_mode_equivalence():

pywt/tests/test_multilevel.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
assert_equal,
1515
assert_raises,
1616
assert_raises_regex,
17-
assert_warns,
1817
)
1918

2019
import pywt
@@ -899,8 +898,9 @@ def test_fswavedecn_fswaverecn_variable_levels():
899898
assert_raises(ValueError, pywt.fswavedecn, data, 'haar', levels=(1, 1, 1, 1))
900899

901900
# levels too large for array size
902-
assert_warns(UserWarning, pywt.fswavedecn, data, 'haar',
903-
levels=int(np.log2(np.min(data.shape)))+1)
901+
with pytest.warns(UserWarning):
902+
pywt.fswavedecn(data, 'haar',
903+
levels=int(np.log2(np.min(data.shape)))+1)
904904

905905

906906
def test_fswavedecn_fswaverecn_variable_wavelets_and_modes():
@@ -967,8 +967,8 @@ def test_fswavedecnresult():
967967
k, np.zeros(tuple([s + 1 for s in d.shape])))
968968

969969
# warns on assigning with a non-matching dtype
970-
assert_warns(UserWarning, result.__setitem__,
971-
k, np.zeros_like(d).astype(np.float32))
970+
with pytest.warns(UserWarning):
971+
result.__setitem__(k, np.zeros_like(d).astype(np.float32))
972972

973973
# all coefficients are stacked into result.coeffs (same ndim)
974974
assert_equal(result.coeffs.ndim, data.ndim)

pywt/tests/test_swt.py

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
assert_allclose,
1313
assert_array_equal,
1414
assert_equal,
15-
assert_raises,
16-
assert_warns,
1715
)
1816

1917
import pywt
@@ -69,7 +67,9 @@ def test_swt_decomposition():
6967

7068
def test_swt_max_level():
7169
# odd sized signal will warn about no levels of decomposition possible
72-
assert_warns(UserWarning, pywt.swt_max_level, 11)
70+
with pytest.warns(UserWarning):
71+
pywt.swt_max_level(11)
72+
7373
with warnings.catch_warnings():
7474
warnings.simplefilter('ignore', UserWarning)
7575
assert_equal(pywt.swt_max_level(11), 0)
@@ -134,7 +134,8 @@ def test_swt_axis():
134134
assert_array_equal(row, cD2)
135135

136136
# axis too large
137-
assert_raises(ValueError, pywt.swt, x, db1, level=2, axis=5)
137+
with pytest.raises(ValueError):
138+
pywt.swt(x, db1, level=2, axis=5)
138139

139140

140141
def test_swt_iswt_integration():
@@ -217,9 +218,8 @@ def test_swt_default_level_by_axis():
217218

218219
def test_swt2_ndim_error():
219220
x = np.ones(8)
220-
with warnings.catch_warnings():
221-
warnings.simplefilter('ignore', FutureWarning)
222-
assert_raises(ValueError, pywt.swt2, x, 'haar', level=1)
221+
with pytest.raises(ValueError):
222+
pywt.swt2(x, 'haar', level=1)
223223

224224

225225
@pytest.mark.slow
@@ -298,10 +298,12 @@ def test_swt2_axes():
298298
assert_allclose(X, r2, atol=atol)
299299

300300
# duplicate axes not allowed
301-
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1,
302-
axes=(0, 0))
301+
with pytest.raises(ValueError):
302+
pywt.swt2(X, current_wavelet, 1, axes=(0, 0))
303+
303304
# too few axes
304-
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, axes=(0, ))
305+
with pytest.raises(ValueError):
306+
pywt.swt2(X, current_wavelet, 1, axes=(0, ))
305307

306308

307309
def test_swtn_axes():
@@ -325,21 +327,24 @@ def test_swtn_axes():
325327
assert_equal(empty, [])
326328

327329
# duplicate axes not allowed
328-
assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0))
330+
with pytest.raises(ValueError):
331+
pywt.swtn(X, current_wavelet, 1, axes=(0, 0))
329332

330333
# data.ndim = 0
331-
assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1)
334+
with pytest.raises(ValueError):
335+
pywt.swtn(np.asarray([]), current_wavelet, 1)
332336

333337
# start_level too large
334-
assert_raises(ValueError, pywt.swtn, X, current_wavelet,
335-
level=1, start_level=2)
338+
with pytest.raises(ValueError):
339+
pywt.swtn(X, current_wavelet, level=1, start_level=2)
336340

337341
# level < 1 in swt_axis call
338-
assert_raises(ValueError, swt_axis, X, current_wavelet, level=0,
339-
start_level=0)
342+
with pytest.raises(ValueError):
343+
swt_axis(X, current_wavelet, level=0, start_level=0)
344+
340345
# odd-sized data not allowed
341-
assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=0,
342-
start_level=0, axis=0)
346+
with pytest.raises(ValueError):
347+
swt_axis( X[-1, :], current_wavelet, level=0, start_level=0, axis=0)
343348

344349

345350
@pytest.mark.slow
@@ -401,12 +406,17 @@ def test_iswtn_errors():
401406
coeffs = pywt.swtn(x, w, max_level, axes=axes)
402407

403408
# more axes than dimensions transformed
404-
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2))
409+
with pytest.raises(ValueError):
410+
pywt.iswtn(coeffs, w, axes=(0, 1, 2))
411+
405412
# duplicate axes not allowed
406-
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0))
413+
with pytest.raises(ValueError):
414+
pywt.iswtn(coeffs, w, axes=(0, 0))
415+
407416
# mismatched coefficient size
408417
coeffs[0]['da'] = coeffs[0]['da'][:-1, :]
409-
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
418+
with pytest.raises(RuntimeError):
419+
pywt.iswtn(coeffs, w, axes=axes)
410420

411421

412422
def test_swtn_iswtn_unique_shape_per_axis():
@@ -441,8 +451,11 @@ def test_per_axis_wavelets():
441451
assert_allclose(pywt.iswtn(coefs, wavelets[:1]), data, atol=1e-14)
442452

443453
# length of wavelets doesn't match the length of axes
444-
assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level)
445-
assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2])
454+
with pytest.raises(ValueError):
455+
pywt.swtn(data, wavelets[:2], level)
456+
457+
with pytest.raises(ValueError):
458+
pywt.iswtn(coefs, wavelets[:2])
446459

447460
with warnings.catch_warnings():
448461
warnings.simplefilter('ignore', FutureWarning)
@@ -458,11 +471,12 @@ def test_error_on_continuous_wavelet():
458471
for dec_func, rec_func in zip([pywt.swt, pywt.swt2, pywt.swtn],
459472
[pywt.iswt, pywt.iswt2, pywt.iswtn]):
460473
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
461-
assert_raises(ValueError, dec_func, data, wavelet=cwave,
462-
level=3)
474+
with pytest.raises(ValueError):
475+
dec_func(data, wavelet=cwave, level=3)
463476

464477
c = dec_func(data, 'db1', level=3)
465-
assert_raises(ValueError, rec_func, c, wavelet=cwave)
478+
with pytest.raises(ValueError):
479+
rec_func(c, wavelet=cwave)
466480

467481

468482
def test_iswt_mixed_dtypes():
@@ -552,11 +566,13 @@ def test_iswtn_mixed_dtypes():
552566

553567
def test_swt_zero_size_axes():
554568
# raise on empty input array
555-
assert_raises(ValueError, pywt.swt, [], 'db2')
569+
with pytest.raises(ValueError):
570+
pywt.swt([], 'db2')
556571

557572
# >1D case uses a different code path so check there as well
558573
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
559-
assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,))
574+
with pytest.raises(ValueError):
575+
pywt.swtn(x, 'db2', level=1, axes=(0,))
560576

561577

562578
def test_swt_variance_and_energy_preservation():
@@ -575,7 +591,8 @@ def test_swt_variance_and_energy_preservation():
575591
np.linalg.norm(np.concatenate(coeffs)))
576592

577593
# non-orthogonal wavelet with norm=True raises a warning
578-
assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True)
594+
with pytest.warns(UserWarning):
595+
pywt.swt(x, 'bior2.2', norm=True)
579596

580597

581598
def test_swt2_variance_and_energy_preservation():
@@ -598,7 +615,8 @@ def test_swt2_variance_and_energy_preservation():
598615
np.linalg.norm(np.concatenate(coeff_list)))
599616

600617
# non-orthogonal wavelet with norm=True raises a warning
601-
assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True)
618+
with pytest.warns(UserWarning):
619+
pywt.swt2(x, 'bior2.2', level=4, norm=True)
602620

603621

604622
def test_swtn_variance_and_energy_preservation():
@@ -621,7 +639,8 @@ def test_swtn_variance_and_energy_preservation():
621639
np.linalg.norm(np.concatenate(coeff_list)))
622640

623641
# non-orthogonal wavelet with norm=True raises a warning
624-
assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True)
642+
with pytest.warns(UserWarning):
643+
pywt.swtn(x, 'bior2.2', level=4, norm=True)
625644

626645

627646
def test_swt_ravel_and_unravel():

0 commit comments

Comments
 (0)