|
| 1 | +import re |
1 | 2 | from pathlib import Path |
2 | 3 |
|
3 | 4 | import numpy as np |
4 | 5 | import pytest |
5 | 6 | from freezegun import freeze_time |
6 | 7 |
|
7 | | -from diffpy.utils.scattering_objects.diffraction_objects import DiffractionObject |
8 | | -from diffpy.utils.transforms import wavelength_warning_emsg |
| 8 | +from diffpy.utils.scattering_objects.diffraction_objects import ( |
| 9 | + ANGLEQUANTITIES, |
| 10 | + DQUANTITIES, |
| 11 | + QQUANTITIES, |
| 12 | + XQUANTITIES, |
| 13 | + DiffractionObject, |
| 14 | +) |
9 | 15 |
|
10 | 16 | params = [ |
11 | 17 | ( # Default |
@@ -232,13 +238,40 @@ def test_diffraction_objects_equality(inputs1, inputs2, expected): |
232 | 238 | assert (diffraction_object1 == diffraction_object2) == expected |
233 | 239 |
|
234 | 240 |
|
235 | | -def _test_valid_diffraction_objects(actual_diffraction_object, function, expected_array): |
236 | | - if actual_diffraction_object.wavelength is None: |
237 | | - with pytest.warns(UserWarning) as warn_record: |
238 | | - getattr(actual_diffraction_object, function)() |
239 | | - assert str(warn_record[0].message) == wavelength_warning_emsg |
240 | | - actual_array = getattr(actual_diffraction_object, function)() |
241 | | - return np.allclose(actual_array, expected_array) |
| 241 | +params_on_xtype = [ |
| 242 | + ( |
| 243 | + [ |
| 244 | + np.array([1, 2, 3, 4, 5, 6]), # intensity array |
| 245 | + np.array([0, 30, 60, 90, 120, 180]), # tth array |
| 246 | + np.array([1, 2, 3, 4, 5, 6]), # q array |
| 247 | + np.array([10, 20, 30, 40, 50, 60]), # d array |
| 248 | + ], |
| 249 | + [ |
| 250 | + np.array([[0, 30, 60, 90, 120, 180], [1, 2, 3, 4, 5, 6]]), # expected on_tth |
| 251 | + np.array([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]]), # expected on_q |
| 252 | + np.array([[10, 20, 30, 40, 50, 60], [1, 2, 3, 4, 5, 6]]), # expected on_d |
| 253 | + ], |
| 254 | + ) |
| 255 | +] |
| 256 | + |
| 257 | + |
| 258 | +@pytest.mark.parametrize("inputs, expected", params_on_xtype) |
| 259 | +def test_on_xtype(inputs, expected): |
| 260 | + test = DiffractionObject() |
| 261 | + test.on_tth = np.array([inputs[1], inputs[0]]) |
| 262 | + test.on_q = np.array([inputs[2], inputs[0]]) |
| 263 | + test.on_d = np.array([inputs[3], inputs[0]]) |
| 264 | + for xtype_list, expected_value in zip([ANGLEQUANTITIES, QQUANTITIES, DQUANTITIES], expected): |
| 265 | + for xtype in xtype_list: |
| 266 | + assert np.allclose(test.on_xtype(xtype), expected_value) |
| 267 | + |
| 268 | + |
| 269 | +def test_on_xtype_bad(): |
| 270 | + test = DiffractionObject() |
| 271 | + with pytest.raises( |
| 272 | + ValueError, match=re.escape(f"Unknown xtype: invalid. Allowed xtypes are {*XQUANTITIES, }.") |
| 273 | + ): |
| 274 | + test.on_xtype("invalid") |
242 | 275 |
|
243 | 276 |
|
244 | 277 | def test_dump(tmp_path, mocker): |
|
0 commit comments