Skip to content

Commit 5a32de4

Browse files
initial commit, raise value error if xtype is invalid
1 parent 49bb353 commit 5a32de4

2 files changed

Lines changed: 44 additions & 10 deletions

File tree

src/diffpy/utils/scattering_objects/diffraction_objects.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,8 @@ def on_xtype(self, xtype):
742742
return self.on_q
743743
elif xtype.lower() in DQUANTITIES:
744744
return self.on_d
745-
pass
745+
else:
746+
raise ValueError(f"Unknown xtype: {xtype}. Allowed xtypes are {*XQUANTITIES, }.")
746747

747748
def dump(self, filepath, xtype=None):
748749
if xtype is None:

tests/diffpy/utils/scattering_objects/test_diffraction_objects.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1+
import re
12
from pathlib import Path
23

34
import numpy as np
45
import pytest
56
from freezegun import freeze_time
67

7-
from diffpy.utils.scattering_objects.diffraction_objects import DiffractionObject
8-
from diffpy.utils.transforms import wavelength_warning_emsg
8+
from diffpy.utils.scattering_objects.diffraction_objects import (
9+
ANGLEQUANTITIES,
10+
DQUANTITIES,
11+
QQUANTITIES,
12+
XQUANTITIES,
13+
DiffractionObject,
14+
)
915

1016
params = [
1117
( # Default
@@ -232,13 +238,40 @@ def test_diffraction_objects_equality(inputs1, inputs2, expected):
232238
assert (diffraction_object1 == diffraction_object2) == expected
233239

234240

235-
def _test_valid_diffraction_objects(actual_diffraction_object, function, expected_array):
236-
if actual_diffraction_object.wavelength is None:
237-
with pytest.warns(UserWarning) as warn_record:
238-
getattr(actual_diffraction_object, function)()
239-
assert str(warn_record[0].message) == wavelength_warning_emsg
240-
actual_array = getattr(actual_diffraction_object, function)()
241-
return np.allclose(actual_array, expected_array)
241+
params_on_xtype = [
242+
(
243+
[
244+
np.array([1, 2, 3, 4, 5, 6]), # intensity array
245+
np.array([0, 30, 60, 90, 120, 180]), # tth array
246+
np.array([1, 2, 3, 4, 5, 6]), # q array
247+
np.array([10, 20, 30, 40, 50, 60]), # d array
248+
],
249+
[
250+
np.array([[0, 30, 60, 90, 120, 180], [1, 2, 3, 4, 5, 6]]), # expected on_tth
251+
np.array([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]]), # expected on_q
252+
np.array([[10, 20, 30, 40, 50, 60], [1, 2, 3, 4, 5, 6]]), # expected on_d
253+
],
254+
)
255+
]
256+
257+
258+
@pytest.mark.parametrize("inputs, expected", params_on_xtype)
259+
def test_on_xtype(inputs, expected):
260+
test = DiffractionObject()
261+
test.on_tth = np.array([inputs[1], inputs[0]])
262+
test.on_q = np.array([inputs[2], inputs[0]])
263+
test.on_d = np.array([inputs[3], inputs[0]])
264+
for xtype_list, expected_value in zip([ANGLEQUANTITIES, QQUANTITIES, DQUANTITIES], expected):
265+
for xtype in xtype_list:
266+
assert np.allclose(test.on_xtype(xtype), expected_value)
267+
268+
269+
def test_on_xtype_bad():
270+
test = DiffractionObject()
271+
with pytest.raises(
272+
ValueError, match=re.escape(f"Unknown xtype: invalid. Allowed xtypes are {*XQUANTITIES, }.")
273+
):
274+
test.on_xtype("invalid")
242275

243276

244277
def test_dump(tmp_path, mocker):

0 commit comments

Comments
 (0)