Skip to content

Commit 828b2a6

Browse files
committed
now passing tests
1 parent 90be271 commit 828b2a6

File tree

2 files changed

+65
-49
lines changed

2 files changed

+65
-49
lines changed

src/diffpy/utils/diffraction_objects.py

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import warnings
23
from copy import deepcopy
34

45
import numpy as np
@@ -18,17 +19,31 @@
1819
)
1920

2021

22+
def _xtype_wmsg(xtype):
23+
return (
24+
f"WARNING: I don't know how to handle the xtype, '{xtype}'. Please rerun specifying and "
25+
f"xtype from {*XQUANTITIES, }"
26+
)
27+
28+
2129
class DiffractionObject:
2230
def __init__(
23-
self, name="", wavelength=None, scat_quantity="", metadata={}, xarray=None, yarray=None, xtype=""
31+
self, name=None, wavelength=None, scat_quantity=None, metadata=None, xarray=None, yarray=None, xtype=""
2432
):
33+
if name is None:
34+
name = ""
35+
self.name = name
36+
if metadata is None:
37+
metadata = {}
38+
self.metadata = metadata
39+
self.scat_quantity = scat_quantity
40+
self.wavelength = wavelength
41+
2542
if xarray is None:
2643
xarray = np.empty(0)
2744
if yarray is None:
2845
yarray = np.empty(0)
29-
self.insert_scattering_quantity(
30-
xarray, yarray, xtype, metadata=metadata, scat_quantity=scat_quantity, name=name, wavelength=wavelength
31-
)
46+
self.insert_scattering_quantity(xarray, yarray, xtype)
3247

3348
def __eq__(self, other):
3449
if not isinstance(other, DiffractionObject):
@@ -241,15 +256,32 @@ def get_angle_index(self, angle):
241256
if count >= len(self.angles):
242257
raise IndexError(f"WARNING: no angle {angle} found in angles list")
243258

259+
def _set_xarrays(self, xarray, xtype):
260+
self.all_arrays = np.empty(shape=(len(xarray), 4))
261+
if xtype.lower() in QQUANTITIES:
262+
self.all_arrays[:, 1] = xarray
263+
self.all_arrays[:, 2] = q_to_tth(xarray, self.wavelength)
264+
self.all_arrays[:, 3] = q_to_d(xarray)
265+
elif xtype.lower() in ANGLEQUANTITIES:
266+
self.all_arrays[:, 2] = xarray
267+
self.all_arrays[:, 1] = tth_to_q(xarray, self.wavelength)
268+
self.all_arrays[:, 3] = tth_to_d(xarray, self.wavelength)
269+
elif xtype.lower() in DQUANTITIES:
270+
self.all_arrays[:, 3] = xarray
271+
self.all_arrays[:, 1] = d_to_q(xarray)
272+
self.all_arrays[:, 2] = d_to_tth(xarray, self.wavelength)
273+
self.qmin = np.nanmin(self.all_arrays[:, 1], initial=np.inf)
274+
self.qmax = np.nanmax(self.all_arrays[:, 1], initial=0.0)
275+
self.tthmin = np.nanmin(self.all_arrays[:, 2], initial=np.inf)
276+
self.tthmax = np.nanmax(self.all_arrays[:, 2], initial=0.0)
277+
self.dmin = np.nanmin(self.all_arrays[:, 3], initial=np.inf)
278+
self.dmax = np.nanmax(self.all_arrays[:, 3], initial=0.0)
279+
244280
def insert_scattering_quantity(
245281
self,
246282
xarray,
247283
yarray,
248284
xtype,
249-
metadata={},
250-
scat_quantity="",
251-
name=None,
252-
wavelength=None,
253285
):
254286
f"""
255287
insert a new scattering quantity into the scattering object
@@ -262,38 +294,14 @@ def insert_scattering_quantity(
262294
the dependent variable array
263295
xtype string
264296
the type of quantity for the independent variable from {*XQUANTITIES, }
265-
metadata: dict
266-
the metadata in the form of a dictionary of user-supplied key:value pairs
267297
268298
Returns
269299
-------
270300
271301
"""
272-
self.input_xtype = xtype
273-
self.metadata = metadata
274-
self.scat_quantity = scat_quantity
275-
self.name = name
276-
self.wavelength = wavelength
277-
self.all_arrays = np.empty(shape=(len(yarray), 4))
302+
self._set_xarrays(xarray, xtype)
278303
self.all_arrays[:, 0] = yarray
279-
if xtype.lower() in QQUANTITIES:
280-
self.all_arrays[:, 1] = xarray
281-
self.all_arrays[:, 2] = q_to_tth(xarray, wavelength)
282-
self.all_arrays[:, 3] = q_to_d(xarray)
283-
elif xtype.lower() in ANGLEQUANTITIES:
284-
self.all_arrays[:, 2] = xarray
285-
self.all_arrays[:, 1] = tth_to_q(xarray, wavelength)
286-
self.all_arrays[:, 3] = tth_to_d(xarray, wavelength)
287-
elif xtype.lower() in DQUANTITIES:
288-
self.all_arrays[:, 3] = xarray
289-
self.all_arrays[:, 1] = d_to_q(xarray)
290-
self.all_arrays[:, 2] = d_to_tth(xarray, wavelength)
291-
self.qmin = np.nanmin(self.all_arrays[:, 1], initial=np.inf)
292-
self.qmax = np.nanmax(self.all_arrays[:, 1], initial=0.0)
293-
self.tthmin = np.nanmin(self.all_arrays[:, 2], initial=np.inf)
294-
self.tthmax = np.nanmax(self.all_arrays[:, 2], initial=0.0)
295-
self.dmin = np.nanmin(self.all_arrays[:, 3], initial=np.inf)
296-
self.dmax = np.nanmax(self.all_arrays[:, 3], initial=0.0)
304+
self.input_xtype = xtype
297305

298306
def _get_original_array(self):
299307
if self.input_xtype in QQUANTITIES:
@@ -319,7 +327,7 @@ def scale_to(self, target_diff_object, xtype=None, xvalue=None):
319327
Parameters
320328
----------
321329
target_diff_object: DiffractionObject
322-
the diffractoin object you want to scale the current one on to
330+
the diffraction object you want to scale the current one on to
323331
xtype: string, optional. Default is Q
324332
the xtype, from {XQUANTITIES}, that you will specify a point from to scale to
325333
xvalue: float. Default is the midpoint of the array
@@ -351,6 +359,7 @@ def scale_to(self, target_diff_object, xtype=None, xvalue=None):
351359
def on_xtype(self, xtype):
352360
"""
353361
return a 2D np array with x in the first column and y in the second for x of type type
362+
354363
Parameters
355364
----------
356365
xtype
@@ -360,24 +369,25 @@ def on_xtype(self, xtype):
360369
361370
"""
362371
if xtype.lower() in ANGLEQUANTITIES:
363-
return self.on_tth
372+
return self.on_tth()
364373
elif xtype.lower() in QQUANTITIES:
365-
return self.on_q
374+
return self.on_q()
366375
elif xtype.lower() in DQUANTITIES:
367-
return self.on_d
368-
pass
376+
return self.on_d()
377+
else:
378+
warnings.warn(_xtype_wmsg(xtype))
369379

370380
def dump(self, filepath, xtype=None):
371381
if xtype is None:
372-
xtype = " q"
373-
if xtype == "q":
382+
xtype = "q"
383+
if xtype in QQUANTITIES:
374384
data_to_save = np.column_stack((self.on_q()[0], self.on_q()[1]))
375-
elif xtype == "tth":
385+
elif xtype in ANGLEQUANTITIES:
376386
data_to_save = np.column_stack((self.on_tth()[0], self.on_tth()[1]))
377-
elif xtype == "d":
387+
elif xtype in DQUANTITIES:
378388
data_to_save = np.column_stack((self.on_d()[0], self.on_d()[1]))
379389
else:
380-
print(f"WARNING: cannot handle the xtype '{xtype}'")
390+
warnings.warn(_xtype_wmsg(xtype))
381391
self.metadata.update(get_package_info("diffpy.utils", metadata=self.metadata))
382392
self.metadata["creation_time"] = datetime.datetime.now()
383393

tests/test_diffraction_objects.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,23 @@ def compare_dicts(dict1, dict2):
2222

2323
def dicts_equal(dict1, dict2):
2424
equal = True
25+
print("")
26+
print(dict1)
27+
print(dict2)
2528
if not dict1.keys() == dict2.keys():
2629
equal = False
2730
for key in dict1:
2831
val1, val2 = dict1[key], dict2[key]
2932
if isinstance(val1, np.ndarray) and isinstance(val2, np.ndarray):
3033
if not np.allclose(val1, val2):
3134
equal = False
35+
elif isinstance(val1, list) and isinstance(val2, list):
36+
if not val1.all() == val2.all():
37+
equal = False
3238
elif isinstance(val1, np.float64) and isinstance(val2, np.float64):
3339
if not np.isclose(val1, val2):
3440
equal = False
3541
else:
36-
print(key, val1, val2)
3742
if not val1 == val2:
3843
equal = False
3944
return equal
@@ -187,12 +192,13 @@ def dicts_equal(dict1, dict2):
187192

188193
@pytest.mark.parametrize("inputs1, inputs2, expected", params)
189194
def test_diffraction_objects_equality(inputs1, inputs2, expected):
190-
diffraction_object1 = DiffractionObject(inputs1)
191-
diffraction_object2 = DiffractionObject(inputs2)
195+
diffraction_object1 = DiffractionObject(**inputs1)
196+
diffraction_object2 = DiffractionObject(**inputs2)
192197
# diffraction_object1_attributes = [key for key in diffraction_object1.__dict__ if not key.startswith("_")]
193198
# for i, attribute in enumerate(diffraction_object1_attributes):
194199
# setattr(diffraction_object1, attribute, inputs1[i])
195200
# setattr(diffraction_object2, attribute, inputs2[i])
201+
print(dicts_equal(diffraction_object1.__dict__, diffraction_object2.__dict__), expected)
196202
assert dicts_equal(diffraction_object1.__dict__, diffraction_object2.__dict__) == expected
197203

198204

@@ -246,7 +252,7 @@ def test_dump(tmp_path, mocker):
246252
"metadata": {},
247253
"input_xtype": "",
248254
"name": "",
249-
"scat_quantity": "",
255+
"scat_quantity": None,
250256
"qmin": np.float64(np.inf),
251257
"qmax": np.float64(0.0),
252258
"tthmin": np.float64(np.inf),
@@ -291,7 +297,7 @@ def test_dump(tmp_path, mocker):
291297
"metadata": {},
292298
"input_xtype": "tth",
293299
"name": "",
294-
"scat_quantity": "",
300+
"scat_quantity": None,
295301
"qmin": np.float64(0.0),
296302
"qmax": np.float64(1.0),
297303
"tthmin": np.float64(0.0),

0 commit comments

Comments
 (0)