Skip to content

Commit d65dce1

Browse files
initial commit, tests need discussion
1 parent 08de7c9 commit d65dce1

File tree

2 files changed

+97
-3
lines changed

2 files changed

+97
-3
lines changed

src/diffpy/utils/diffraction_objects.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def on_d(self):
366366

367367
def scale_to(self, target_diff_object, xtype=None, xvalue=None):
368368
f"""
369-
returns a new diffraction object which is the current object but recaled in y to the target
369+
returns a new diffraction object which is the current object but rescaled in y to the target
370370
371371
Parameters
372372
----------
@@ -390,14 +390,15 @@ def scale_to(self, target_diff_object, xtype=None, xvalue=None):
390390

391391
data = self.on_xtype(xtype)
392392
target = target_diff_object.on_xtype(xtype)
393+
if len(data[0]) == 0 or len(target[0]) == 0 or len(data[0]) != len(target[0]):
394+
raise ValueError("I cannot scale two diffraction objects with empty or different lengths.")
393395
if xvalue is None:
394396
xvalue = data[0][0] + (data[0][-1] - data[0][0]) / 2.0
395397

396398
xindex = (np.abs(data[0] - xvalue)).argmin()
397399
ytarget = target[1][xindex]
398400
yself = data[1][xindex]
399-
scaled.on_tth[1] = data[1] * ytarget / yself
400-
scaled.on_q[1] = data[1] * ytarget / yself
401+
scaled._all_arrays[:, 0] = data[1] * ytarget / yself
401402
return scaled
402403

403404
def on_xtype(self, xtype):

tests/test_diffraction_objects.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,99 @@ def test_on_xtype_bad():
223223
test.on_xtype("invalid")
224224

225225

226+
params_scale_to = [
227+
# UC1: xvalue exact match
228+
(
229+
[
230+
np.array([10, 15, 25, 30, 60, 140]),
231+
np.array([10, 20, 25, 30, 60, 100]),
232+
"tth",
233+
2 * np.pi,
234+
np.array([10, 15, 25, 30, 60, 140]),
235+
np.array([2, 3, 4, 5, 6, 7]),
236+
"tth",
237+
2 * np.pi,
238+
"tth",
239+
60,
240+
],
241+
[np.array([1, 2, 2.5, 3, 6, 10])],
242+
),
243+
# UC2: xvalue approximate match
244+
(
245+
[
246+
np.array([0.11, 0.24, 0.31, 0.4]),
247+
np.array([10, 20, 40, 60]),
248+
"q",
249+
2 * np.pi,
250+
np.array([0.11, 0.24, 0.31, 0.4]),
251+
np.array([1, 3, 4, 5]),
252+
"q",
253+
2 * np.pi,
254+
"q",
255+
0.1,
256+
],
257+
[np.array([1, 2, 4, 6])],
258+
),
259+
]
260+
261+
262+
@pytest.mark.parametrize("inputs, expected", params_scale_to)
263+
def test_scale_to(inputs, expected):
264+
orig_diff_object = DiffractionObject(xarray=inputs[0], yarray=inputs[1], xtype=inputs[2], wavelength=inputs[3])
265+
target_diff_object = DiffractionObject(
266+
xarray=inputs[4], yarray=inputs[5], xtype=inputs[6], wavelength=inputs[7]
267+
)
268+
scaled_diff_object = orig_diff_object.scale_to(target_diff_object, xtype=inputs[8], xvalue=inputs[9])
269+
# Check the intensity data is same as expected
270+
assert np.allclose(scaled_diff_object.on_xtype(inputs[8])[1], expected[0])
271+
272+
273+
params_scale_to_bad = [
274+
# UC1: at least one of the y-arrays is empty
275+
(
276+
[
277+
np.array([]),
278+
np.array([]),
279+
"tth",
280+
2 * np.pi,
281+
np.array([11, 14, 16, 20, 25, 30]),
282+
np.array([2, 3, 4, 5, 6, 7]),
283+
"tth",
284+
2 * np.pi,
285+
"tth",
286+
60,
287+
]
288+
),
289+
# UC2: diffraction objects with different array lengths
290+
(
291+
[
292+
np.array([0.11, 0.24, 0.31, 0.4, 0.5]),
293+
np.array([10, 20, 40, 50, 60]),
294+
"q",
295+
2 * np.pi,
296+
np.array([0.1, 0.15, 0.3, 0.4]),
297+
np.array([1, 3, 4, 5]),
298+
"q",
299+
2 * np.pi,
300+
"q",
301+
0.1,
302+
]
303+
),
304+
]
305+
306+
307+
@pytest.mark.parametrize("inputs", params_scale_to_bad)
308+
def test_scale_to_bad(inputs):
309+
orig_diff_object = DiffractionObject(xarray=inputs[0], yarray=inputs[1], xtype=inputs[2], wavelength=inputs[3])
310+
target_diff_object = DiffractionObject(
311+
xarray=inputs[4], yarray=inputs[5], xtype=inputs[6], wavelength=inputs[7]
312+
)
313+
with pytest.raises(
314+
ValueError, match="I cannot scale two diffraction objects with empty or different lengths."
315+
):
316+
orig_diff_object.scale_to(target_diff_object, xtype=inputs[8], xvalue=inputs[9])
317+
318+
226319
params_index = [
227320
# UC1: exact match
228321
([4 * np.pi, np.array([30.005, 60]), np.array([1, 2]), "tth", "tth", 30.005], [0]),

0 commit comments

Comments
 (0)