Skip to content

Commit 3c52108

Browse files
feat: add warning for extrapolation in morphsqueeze (#250)
* feat: add warning for extrapolation in `morphsqueeze` * [pre-commit.ci] auto fixes from pre-commit hooks * chore: add news * test: add test for warning extrapolation in `morphsqueeze` * chore: update warning message for extrapolation in `morphsqueeze.py` * test: add CLI test for extrapolation warning --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3908d63 commit 3c52108

File tree

3 files changed

+137
-0
lines changed

3 files changed

+137
-0
lines changed

news/extrapolate-warning.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
**Added:**
2+
3+
* No news added: Add warning for extrapolation in morphsqueeze.py.
4+
5+
**Changed:**
6+
7+
* <news item>
8+
9+
**Deprecated:**
10+
11+
* <news item>
12+
13+
**Removed:**
14+
15+
* <news item>
16+
17+
**Fixed:**
18+
19+
* <news item>
20+
21+
**Security:**
22+
23+
* <news item>

src/diffpy/morph/morphs/morphsqueeze.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
"""Class MorphSqueeze -- Apply a polynomial to squeeze the morph
22
function."""
33

4+
import warnings
5+
46
import numpy as np
57
from numpy.polynomial import Polynomial
68
from scipy.interpolate import CubicSpline
79

810
from diffpy.morph.morphs.morph import LABEL_GR, LABEL_RA, Morph
911

1012

13+
def custom_formatwarning(msg, *args, **kwargs):
14+
return f"{msg}\n"
15+
16+
17+
warnings.formatwarning = custom_formatwarning
18+
19+
1120
class MorphSqueeze(Morph):
1221
"""Squeeze the morph function.
1322
@@ -85,4 +94,27 @@ def morph(self, x_morph, y_morph, x_target, y_target):
8594
high_extrap = np.where(self.x_morph_in > x_squeezed[-1])[0]
8695
self.extrap_index_low = low_extrap[-1] if low_extrap.size else None
8796
self.extrap_index_high = high_extrap[0] if high_extrap.size else None
97+
below_extrap = min(x_morph) < min(x_squeezed)
98+
above_extrap = max(x_morph) > max(x_squeezed)
99+
if below_extrap or above_extrap:
100+
if not above_extrap:
101+
wmsg = (
102+
"Warning: points with grid value below "
103+
f"{min(x_squeezed)} will be extrapolated."
104+
)
105+
elif not below_extrap:
106+
wmsg = (
107+
"Warning: points with grid value above "
108+
f"{max(x_squeezed)} will be extrapolated."
109+
)
110+
else:
111+
wmsg = (
112+
"Warning: points with grid value below "
113+
f"{min(x_squeezed)} and above {max(x_squeezed)} will be "
114+
"extrapolated."
115+
)
116+
warnings.warn(
117+
wmsg,
118+
UserWarning,
119+
)
88120
return self.xyallout

tests/test_morphsqueeze.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import subprocess
2+
13
import numpy as np
24
import pytest
35
from numpy.polynomial import Polynomial
@@ -85,3 +87,83 @@ def test_morphsqueeze(x_morph, x_target, squeeze_coeffs):
8587
assert np.allclose(x_morph_actual, x_morph_expected)
8688
assert np.allclose(x_target_actual, x_target)
8789
assert np.allclose(y_target_actual, y_target)
90+
91+
92+
@pytest.mark.parametrize(
93+
"squeeze_coeffs, wmsg_gen",
94+
[
95+
# extrapolate below
96+
(
97+
{"a0": 0.01},
98+
lambda x: (
99+
"Warning: points with grid value below "
100+
f"{x[0]} will be extrapolated."
101+
),
102+
),
103+
# extrapolate above
104+
(
105+
{"a0": -0.01},
106+
lambda x: (
107+
"Warning: points with grid value above "
108+
f"{x[1]} will be extrapolated."
109+
),
110+
),
111+
# extrapolate below and above
112+
(
113+
{"a0": 0.01, "a1": -0.002},
114+
lambda x: (
115+
"Warning: points with grid value below "
116+
f"{x[0]} and above {x[1]} will be "
117+
"extrapolated."
118+
),
119+
),
120+
],
121+
)
122+
def test_morphsqueeze_extrapolate(user_filesystem, squeeze_coeffs, wmsg_gen):
123+
x_morph = np.linspace(0, 10, 101)
124+
y_morph = np.sin(x_morph)
125+
x_target = x_morph
126+
y_target = y_morph
127+
morph = MorphSqueeze()
128+
morph.squeeze = squeeze_coeffs
129+
coeffs = [squeeze_coeffs[f"a{i}"] for i in range(len(squeeze_coeffs))]
130+
squeeze_polynomial = Polynomial(coeffs)
131+
x_squeezed = x_morph + squeeze_polynomial(x_morph)
132+
with pytest.warns() as w:
133+
x_morph_actual, y_morph_actual, x_target_actual, y_target_actual = (
134+
morph(x_morph, y_morph, x_target, y_target)
135+
)
136+
assert len(w) == 1
137+
assert w[0].category is UserWarning
138+
actual_wmsg = str(w[0].message)
139+
expected_wmsg = wmsg_gen([min(x_squeezed), max(x_squeezed)])
140+
assert actual_wmsg == expected_wmsg
141+
142+
# CLI test
143+
morph_file, target_file = create_morph_data_file(
144+
user_filesystem / "cwd_dir", x_morph, y_morph, x_target, y_target
145+
)
146+
run_cmd = ["diffpy.morph"]
147+
run_cmd.extend(["--squeeze=" + ",".join(map(str, coeffs))])
148+
run_cmd.extend([str(morph_file), str(target_file)])
149+
run_cmd.append("-n")
150+
result = subprocess.run(run_cmd, capture_output=True, text=True)
151+
assert expected_wmsg in result.stderr
152+
153+
154+
def create_morph_data_file(
155+
data_dir_path, x_morph, y_morph, x_target, y_target
156+
):
157+
morph_file = data_dir_path / "morph_data"
158+
morph_data_text = [
159+
str(x_morph[i]) + " " + str(y_morph[i]) for i in range(len(x_morph))
160+
]
161+
morph_data_text = "\n".join(morph_data_text)
162+
morph_file.write_text(morph_data_text)
163+
target_file = data_dir_path / "target_data"
164+
target_data_text = [
165+
str(x_target[i]) + " " + str(y_target[i]) for i in range(len(x_target))
166+
]
167+
target_data_text = "\n".join(target_data_text)
168+
target_file.write_text(target_data_text)
169+
return morph_file, target_file

0 commit comments

Comments
 (0)