Skip to content

Commit bedf468

Browse files
committed
fix: let DiffractionObject.get_array_index to use an optional input xtype
1 parent 8e43660 commit bedf468

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

src/diffpy/utils/diffraction_objects.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -409,26 +409,29 @@ def uuid(self):
409409
def uuid(self, _):
410410
raise AttributeError(_setter_wmsg("uuid"))
411411

412-
def get_array_index(self, xtype, xvalue):
413-
"""Return the index of the closest value in the array associated with
412+
def get_array_index(self, xvalue, xtype=None):
413+
f"""Return the index of the closest value in the array associated with
414414
the specified xtype and the value provided.
415415
416416
Parameters
417417
----------
418-
xtype : str
419-
The type of the independent variable in `xarray`. Must be one
420-
of {*XQUANTITIES}.
421418
xvalue : float
422419
The value of the xtype to find the closest index for.
420+
xtype : str, optional
421+
The type of the independent variable in `xarray`. Must be one
422+
of {*XQUANTITIES,}. Default is {self._input_xtype}
423423
424424
Returns
425425
-------
426426
index : int
427427
The index of the closest value in the array associated with the
428428
specified xtype and the value provided.
429429
"""
430-
431-
xtype = self._input_xtype
430+
if xtype is None:
431+
xtype = self._input_xtype
432+
else:
433+
if xtype not in XQUANTITIES:
434+
raise ValueError(_xtype_wmsg(xtype))
432435
xarray = self.on_xtype(xtype)[0]
433436
if len(xarray) == 0:
434437
raise ValueError(

tests/test_diffraction_objects.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def test_scale_to_bad(org_do_args, target_do_args, scale_inputs):
377377
0,
378378
),
379379
# C2: Target value lies in the array, expect the closest index
380-
( # 1. same xtype
380+
( # 1. xtype(tth) is equal to self._input_xtype(tth)
381381
{
382382
"wavelength": 4 * np.pi,
383383
"xarray": np.array([30, 60]),
@@ -390,7 +390,20 @@ def test_scale_to_bad(org_do_args, target_do_args, scale_inputs):
390390
},
391391
0,
392392
),
393-
( # 2. different xtype
393+
( # 2. use default xtype(equal to self._input_xtype)
394+
{
395+
"wavelength": 4 * np.pi,
396+
"xarray": np.array([30, 60]),
397+
"yarray": np.array([1, 2]),
398+
"xtype": "tth",
399+
},
400+
{
401+
"xtype": None,
402+
"value": 45,
403+
},
404+
0,
405+
),
406+
( # 3. xtype(q) is different from self._input_xtype(tth)
394407
{
395408
"wavelength": 4 * np.pi,
396409
"xarray": np.array([30, 60]),
@@ -435,12 +448,13 @@ def test_scale_to_bad(org_do_args, target_do_args, scale_inputs):
435448
def test_get_array_index(do_args, get_array_index_inputs, expected_index):
436449
do = DiffractionObject(**do_args)
437450
actual_index = do.get_array_index(
438-
get_array_index_inputs["xtype"], get_array_index_inputs["value"]
451+
get_array_index_inputs["value"], get_array_index_inputs["xtype"]
439452
)
440453
assert actual_index == expected_index
441454

442455

443456
def test_get_array_index_bad():
457+
# empty array in DiffractionObject
444458
do = DiffractionObject(
445459
wavelength=2 * np.pi,
446460
xarray=np.array([]),
@@ -454,6 +468,22 @@ def test_get_array_index_bad():
454468
),
455469
):
456470
do.get_array_index(xtype="tth", xvalue=30)
471+
# non-existing xtype
472+
do = DiffractionObject(
473+
wavelength=4 * np.pi,
474+
xarray=np.array([30, 60]),
475+
yarray=np.array([1, 2]),
476+
xtype="tth",
477+
)
478+
non_existing_xtype = "non_existing_xtype"
479+
with pytest.raises(
480+
ValueError,
481+
match=re.escape(
482+
f"I don't know how to handle the xtype, '{non_existing_xtype}'. "
483+
f"Please rerun specifying an xtype from {*XQUANTITIES, }"
484+
),
485+
):
486+
do.get_array_index(xtype=non_existing_xtype, xvalue=30)
457487

458488

459489
def test_dump(tmp_path, mocker):

0 commit comments

Comments
 (0)