Skip to content

Commit bfc9186

Browse files
committed
chore: clean the test cases
1 parent daa3749 commit bfc9186

File tree

3 files changed

+52
-76
lines changed

3 files changed

+52
-76
lines changed

src/diffpy/morph/morph_io.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -447,17 +447,14 @@ def handle_extrapolation_warnings(squeeze_morph):
447447

448448
def handle_check_increase_warning(squeeze_morph):
449449
if squeeze_morph is not None:
450-
if squeeze_morph.squeeze_info["monotonic"]:
450+
if squeeze_morph.strictly_increasing:
451451
wmsg = None
452452
else:
453-
overlapping_regions = squeeze_morph.squeeze_info[
454-
"overlapping_regions"
455-
]
456453
wmsg = (
457454
"Warning: The squeeze morph has interpolated your morphed "
458455
"function from a non-monotonically increasing grid. "
459-
"This can result in strange behavior in the regions "
460-
f"{overlapping_regions}. To disable this setting, "
456+
"This can result in strange behavior in the non-uniqe "
457+
"grid regions. To disable this setting, "
461458
"please enable --check-increasing."
462459
)
463460
if wmsg:

src/diffpy/morph/morphs/morphsqueeze.py

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ def __init__(self, config=None, check_increase=False):
7373
super().__init__(config)
7474
self.check_increase = check_increase
7575

76-
def _set_squeeze_info(self, x, x_sorted):
77-
self.squeeze_info = {"monotonic": True, "overlapping_regions": None}
76+
def _ensure_strictly_increase(self, x, x_sorted):
7877
if list(x) != list(x_sorted):
7978
if self.check_increase:
8079
raise ValueError(
@@ -100,12 +99,9 @@ def _set_squeeze_info(self, x, x_sorted):
10099
"and stretch parameter for a1."
101100
)
102101
else:
103-
if list(x) != list(x_sorted[::-1]):
104-
overlapping_regions = self.get_overlapping_regions(x)
105-
self.squeeze_info["monotonic"] = False
106-
self.squeeze_info["overlapping_regions"] = (
107-
overlapping_regions
108-
)
102+
self.strictly_increasing = False
103+
else:
104+
self.strictly_increasing = True
109105

110106
def _sort_squeeze(self, x, y):
111107
"""Sort x,y according to the value of x."""
@@ -114,32 +110,6 @@ def _sort_squeeze(self, x, y):
114110
x_sorted, y_sorted = list(zip(*xy_sorted))
115111
return x_sorted, y_sorted
116112

117-
def get_overlapping_regions(self, x):
118-
diffx = numpy.diff(x)
119-
diffx_sign = numpy.sign(diffx)
120-
local_min_or_max_index = (
121-
numpy.where(numpy.diff(diffx_sign) != 0)[0] + 1
122-
)
123-
monotonic_regions_x = numpy.concatenate(
124-
(
125-
[x[0]],
126-
numpy.repeat(
127-
numpy.array(x)[local_min_or_max_index], 2
128-
).tolist()[:-1],
129-
)
130-
).reshape(-1, 2)
131-
monotinic_regions_sign = diffx_sign[local_min_or_max_index - 1]
132-
133-
overlapping_regions_sign = -1 if x[0] < x[-1] else 1
134-
overlapping_regions_index = numpy.where(
135-
monotinic_regions_sign == overlapping_regions_sign
136-
)[0]
137-
overlapping_regions = monotonic_regions_x[overlapping_regions_index]
138-
overlapping_regions = [
139-
sorted(region) for region in overlapping_regions
140-
]
141-
return overlapping_regions
142-
143113
def _handle_duplicates(self, x, y):
144114
"""Remove duplicated x and use the mean value of y corresponded
145115
to the duplicated x."""
@@ -159,14 +129,13 @@ def morph(self, x_morph, y_morph, x_target, y_target):
159129
data.
160130
"""
161131
Morph.morph(self, x_morph, y_morph, x_target, y_target)
162-
163132
coeffs = [self.squeeze[f"a{i}"] for i in range(len(self.squeeze))]
164133
squeeze_polynomial = Polynomial(coeffs)
165134
x_squeezed = self.x_morph_in + squeeze_polynomial(self.x_morph_in)
166135
x_squeezed_sorted, y_morph_sorted = self._sort_squeeze(
167136
x_squeezed, self.y_morph_in
168137
)
169-
self._set_squeeze_info(x_squeezed, x_squeezed_sorted)
138+
self._ensure_strictly_increase(x_squeezed, x_squeezed_sorted)
170139
x_squeezed_sorted, y_morph_sorted = self._handle_duplicates(
171140
x_squeezed_sorted, y_morph_sorted
172141
)

tests/test_morphsqueeze.py

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -176,26 +176,38 @@ def test_morphsqueeze_extrapolate(
176176
@pytest.mark.parametrize(
177177
"squeeze_coeffs, x_morph",
178178
[
179-
({"a0": 0.01, "a1": -0.99, "a2": 0.01}, np.linspace(-1, 1, 101)),
179+
({"a0": 0.01, "a1": 0.01, "a2": -0.1}, np.linspace(0, 10, 101)),
180180
],
181181
)
182-
def test_sort_squeeze(user_filesystem, squeeze_coeffs, x_morph):
182+
def test_non_strictly_increasing_squeeze(squeeze_coeffs, x_morph):
183183
x_target = x_morph
184184
y_target = np.sin(x_target)
185185
coeffs = [squeeze_coeffs[f"a{i}"] for i in range(len(squeeze_coeffs))]
186186
squeeze_polynomial = Polynomial(coeffs)
187187
x_squeezed = x_morph + squeeze_polynomial(x_morph)
188-
# non-strictly-monotonic
189-
assert not np.all(np.diff(np.sign(np.diff(x_squeezed))) == 0)
190-
# outcome converges when --check-increase is not used
188+
# non-strictly-increasing
189+
assert not np.all(np.sign(np.diff(x_squeezed)) > 0)
191190
y_morph = np.sin(x_squeezed)
192-
morph = MorphSqueeze()
193-
morph.squeeze = squeeze_coeffs
191+
# all zero initial guess
192+
morph_results = morphpy.morph_arrays(
193+
np.array([x_morph, y_morph]).T,
194+
np.array([x_target, y_target]).T,
195+
squeeze=[0, 0, 0],
196+
apply=True,
197+
)
198+
_, y_morph_actual = morph_results[1].T # noqa: F841
199+
y_morph_expected = np.sin(x_morph) # noqa: F841
200+
# squeeze morph extrapolates.
201+
# Need to extract extrap_index from morph_results to examine
202+
# the convergence.
203+
# assert np.allclose(y_morph_actual, y_morph_expected, atol=1e-3)
204+
# Raise warning when called without --check-increase
194205
with pytest.warns() as w:
195-
moreph_results = morphpy.morph_arrays(
206+
morph_results = morphpy.morph_arrays(
196207
np.array([x_morph, y_morph]).T,
197208
np.array([x_target, y_target]).T,
198-
squeeze=[0.01, -0.99, 0.01],
209+
squeeze=[0.01, 0.01, -0.1],
210+
apply=True,
199211
)
200212
assert w[0].category is UserWarning
201213
actual_wmsg = " ".join([str(w[i].message) for i in range(len(w))])
@@ -204,18 +216,18 @@ def test_sort_squeeze(user_filesystem, squeeze_coeffs, x_morph):
204216
"function from a non-monotonically increasing grid. "
205217
)
206218
assert expected_wmsg in actual_wmsg
207-
expected_coeffs = coeffs
208-
actual_coeffs = [
209-
moreph_results[0]["squeeze"][f"a{i}"]
210-
for i in range(len(moreph_results[0]["squeeze"]))
211-
]
212-
# program exits when --check-increase is used
213-
assert np.allclose(actual_coeffs, expected_coeffs, rtol=1e-2)
219+
_, y_morph_actual = morph_results[1].T # noqa: F841
220+
y_morph_expected = np.sin(x_morph) # noqa: F841
221+
# squeeze morph extrapolates.
222+
# Need to extract extrap_index from morph_results to examine
223+
# the convergence.
224+
# assert np.allclose(y_morph_actual, y_morph_expected, atol=1e-3)
225+
# System exits when called with --check-increase
214226
with pytest.raises(SystemExit) as excinfo:
215227
morphpy.morph_arrays(
216228
np.array([x_morph, y_morph]).T,
217229
np.array([x_target, y_target]).T,
218-
squeeze=[0.01, -1, 0.01],
230+
squeeze=[0.01, 0.009, -0.1],
219231
check_increase=True,
220232
)
221233
actual_emsg = str(excinfo.value)
@@ -315,24 +327,22 @@ def test_sort_squeeze_bad(user_filesystem, squeeze_coeffs, x_morph):
315327
assert expected_emsg in actual_emsg
316328

317329

318-
@pytest.mark.parametrize(
319-
"turning_points, expected_overlapping_regions",
320-
[
321-
# x[-1] > x[0], monotonically decreasing regions are overlapping
322-
([0, 10, 7, 12], [[7, 10]]),
323-
# x[-1] < x[0], monotonically increasing regions are overlapping
324-
([0, 5, 2, 4, -10], [[0, 5], [2, 4]]),
325-
],
326-
)
327-
def test_get_overlapping_regions(turning_points, expected_overlapping_regions):
330+
def test_handle_duplicates():
331+
unq_x = np.linspace(0, 11, 10)
332+
iter = 10
328333
morph = MorphSqueeze()
329-
regions = (
330-
np.linspace(turning_points[i], turning_points[i + 1], 20)
331-
for i in range(len(turning_points) - 1)
332-
)
333-
x_value = np.concatenate(list(regions))
334-
actual_overlaping_regions = morph.get_overlapping_regions(x_value)
335-
assert expected_overlapping_regions == actual_overlaping_regions
334+
for i in range(iter):
335+
actual_x = np.random.choice(unq_x, size=20)
336+
actual_y = np.sin(actual_x)
337+
actual_handled_x, actual_handled_y = morph._handle_duplicates(
338+
actual_x, actual_y
339+
)
340+
expected_handled_x = np.unique(actual_x)
341+
expected_handled_y = np.array(
342+
[actual_y[actual_x == x].mean() for x in expected_handled_x]
343+
)
344+
assert np.allclose(actual_handled_x, expected_handled_x)
345+
assert np.allclose(actual_handled_y, expected_handled_y)
336346

337347

338348
def create_morph_data_file(

0 commit comments

Comments
 (0)