Skip to content

Commit 630f00e

Browse files
committed
feat: implement __sub__ feature for DiffractionObject
1 parent 1ea8e9a commit 630f00e

File tree

2 files changed

+88
-84
lines changed

2 files changed

+88
-84
lines changed

src/diffpy/utils/diffraction_objects.py

Lines changed: 20 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -217,42 +217,16 @@ def __add__(self, other):
217217

218218
__radd__ = __add__
219219

220-
def _check_operation_compatibility(self, other):
221-
if not isinstance(other, (DiffractionObject, int, float)):
222-
raise TypeError(invalid_add_type_emsg)
223-
if isinstance(other, DiffractionObject):
224-
self_yarray = self.all_arrays[:, 0]
225-
other_yarray = other.all_arrays[:, 0]
226-
if len(self_yarray) != len(other_yarray):
227-
raise ValueError(y_grid_length_mismatch_emsg)
228-
229220
def __sub__(self, other):
230-
subtracted = deepcopy(self)
231-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
232-
subtracted.on_tth[1] = self.on_tth[1] - other
233-
subtracted.on_q[1] = self.on_q[1] - other
234-
elif not isinstance(other, DiffractionObject):
235-
raise TypeError("I only know how to subtract two Scattering_object objects")
236-
elif self.on_tth[0].all() != other.on_tth[0].all():
237-
raise RuntimeError(y_grid_length_mismatch_emsg)
238-
else:
239-
subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1]
240-
subtracted.on_q[1] = self.on_q[1] - other.on_q[1]
241-
return subtracted
221+
self._check_operation_compatibility(other)
222+
subtracted_do = deepcopy(self)
223+
if isinstance(other, (int, float)):
224+
subtracted_do._all_arrays[:, 0] -= other
225+
if isinstance(other, DiffractionObject):
226+
subtracted_do._all_arrays[:, 0] -= other.all_arrays[:, 0]
227+
return subtracted_do
242228

243-
def __rsub__(self, other):
244-
subtracted = deepcopy(self)
245-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
246-
subtracted.on_tth[1] = other - self.on_tth[1]
247-
subtracted.on_q[1] = other - self.on_q[1]
248-
elif not isinstance(other, DiffractionObject):
249-
raise TypeError("I only know how to subtract two Scattering_object objects")
250-
elif self.on_tth[0].all() != other.on_tth[0].all():
251-
raise RuntimeError(y_grid_length_mismatch_emsg)
252-
else:
253-
subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1]
254-
subtracted.on_q[1] = other.on_q[1] - self.on_q[1]
255-
return subtracted
229+
__rsub__ = __sub__
256230

257231
def __mul__(self, other):
258232
multiplied = deepcopy(self)
@@ -268,19 +242,10 @@ def __mul__(self, other):
268242
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
269243
return multiplied
270244

271-
def __rmul__(self, other):
272-
multiplied = deepcopy(self)
273-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
274-
multiplied.on_tth[1] = other * self.on_tth[1]
275-
multiplied.on_q[1] = other * self.on_q[1]
276-
elif self.on_tth[0].all() != other.on_tth[0].all():
277-
raise RuntimeError(y_grid_length_mismatch_emsg)
278-
else:
279-
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
280-
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
281-
return multiplied
245+
__rmul__ = __mul__
282246

283247
def __truediv__(self, other):
248+
284249
divided = deepcopy(self)
285250
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
286251
divided.on_tth[1] = other / self.on_tth[1]
@@ -294,17 +259,16 @@ def __truediv__(self, other):
294259
divided.on_q[1] = self.on_q[1] / other.on_q[1]
295260
return divided
296261

297-
def __rtruediv__(self, other):
298-
divided = deepcopy(self)
299-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
300-
divided.on_tth[1] = other / self.on_tth[1]
301-
divided.on_q[1] = other / self.on_q[1]
302-
elif self.on_tth[0].all() != other.on_tth[0].all():
303-
raise RuntimeError(y_grid_length_mismatch_emsg)
304-
else:
305-
divided.on_tth[1] = other.on_tth[1] / self.on_tth[1]
306-
divided.on_q[1] = other.on_q[1] / self.on_q[1]
307-
return divided
262+
__rmul__ = __mul__
263+
264+
def _check_operation_compatibility(self, other):
265+
if not isinstance(other, (DiffractionObject, int, float)):
266+
raise TypeError(invalid_add_type_emsg)
267+
if isinstance(other, DiffractionObject):
268+
self_yarray = self.all_arrays[:, 0]
269+
other_yarray = other.all_arrays[:, 0]
270+
if self_yarray.shape != other_yarray.shape:
271+
raise ValueError(y_grid_length_mismatch_emsg)
308272

309273
@property
310274
def all_arrays(self):

tests/test_diffraction_objects.py

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -713,59 +713,93 @@ def test_copy_object(do_minimal):
713713

714714

715715
@pytest.mark.parametrize(
716-
"starting_all_arrays, scalar_to_add, expected_all_arrays",
716+
"operation, starting_all_arrays, scalar_value, expected_all_arrays",
717717
[
718-
# Test scalar addition to yarray values (intensity) and expect no change to xarrays (q, tth, d)
719-
( # C1: Add integer of 5, expect yarray to increase by by 5
718+
# C1: Test scalar addition to yarray values (intensity), expect no change to xarrays (q, tth, d)
719+
( # 1. Add integer 5
720+
"add",
720721
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
721722
5,
722723
np.array([[6.0, 0.51763809, 30.0, 12.13818192], [7.0, 1.0, 60.0, 6.28318531]]),
723724
),
724-
( # C2: Add float of 5.1, expect yarray to be added by 5.1
725+
( # 2. Add float 5.1
726+
"add",
725727
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
726728
5.1,
727729
np.array([[6.1, 0.51763809, 30.0, 12.13818192], [7.1, 1.0, 60.0, 6.28318531]]),
728730
),
731+
# C2. Test scalar subtraction to yarray values (intensity), expect no change to xarrays (q, tth, d)
732+
( # 1. Subtract integer 1
733+
"sub",
734+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
735+
1,
736+
np.array([[0.0, 0.51763809, 30.0, 12.13818192], [1.0, 1.0, 60.0, 6.28318531]]),
737+
),
738+
( # 2. Subtract float 0.5
739+
"sub",
740+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
741+
0.5,
742+
np.array([[0.5, 0.51763809, 30.0, 12.13818192], [1.5, 1.0, 60.0, 6.28318531]]),
743+
),
729744
],
730745
)
731-
def test_addition_operator_by_scalar(starting_all_arrays, scalar_to_add, expected_all_arrays, do_minimal_tth):
746+
def test_scalar_operations(operation, starting_all_arrays, scalar_value, expected_all_arrays, do_minimal_tth):
732747
do = do_minimal_tth
733748
assert np.allclose(do.all_arrays, starting_all_arrays)
734-
do_scalar_right_sum = do + scalar_to_add
735-
assert np.allclose(do_scalar_right_sum.all_arrays, expected_all_arrays)
736-
do_scalar_left_sum = scalar_to_add + do
737-
assert np.allclose(do_scalar_left_sum.all_arrays, expected_all_arrays)
749+
750+
if operation == "add":
751+
result_right = do + scalar_value
752+
result_left = scalar_value + do
753+
elif operation == "sub":
754+
result_right = do - scalar_value
755+
result_left = scalar_value - do
756+
757+
assert np.allclose(result_right.all_arrays, expected_all_arrays)
758+
assert np.allclose(result_left.all_arrays, expected_all_arrays)
738759

739760

740761
@pytest.mark.parametrize(
741-
"do_1_all_arrays, "
742-
"do_2_all_arrays, "
743-
"expected_do_1_all_arrays_with_y_summed, "
744-
"expected_do_2_all_arrays_with_y_summed",
762+
"operation, " "expected_do_1_all_arrays_with_y_modified, " "expected_do_2_all_arrays_with_y_modified",
745763
[
746-
# Test addition of two DO objects, expect combined yarray values and no change to xarrays ((q, tth, d)
747-
( # C1: Add two DO objects, expect sum of yarray values
748-
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),),
749-
(np.array([[1.0, 6.28318531, 100.70777771, 1], [2.0, 3.14159265, 45.28748053, 2.0]]),),
750-
(np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),),
751-
(np.array([[2.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),),
764+
# Test addition of two DO objects, expect combined yarray values
765+
(
766+
"add",
767+
np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),
768+
np.array([[2.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),
769+
),
770+
# Test subtraction of two DO objects, expect differences in yarray values
771+
(
772+
"sub",
773+
np.array([[0.0, 0.51763809, 30.0, 12.13818192], [0.0, 1.0, 60.0, 6.28318531]]),
774+
np.array([[0.0, 6.28318531, 100.70777771, 1], [0.0, 3.14159265, 45.28748053, 2.0]]),
752775
),
753776
],
754777
)
755-
def test_addition_operator_by_another_do(
756-
do_1_all_arrays,
757-
do_2_all_arrays,
758-
expected_do_1_all_arrays_with_y_summed,
759-
expected_do_2_all_arrays_with_y_summed,
778+
def test_binary_operator_on_do(
779+
operation,
780+
expected_do_1_all_arrays_with_y_modified,
781+
expected_do_2_all_arrays_with_y_modified,
760782
do_minimal_tth,
761783
do_minimal_d,
762784
):
763785
do_1 = do_minimal_tth
764-
assert np.allclose(do_1.all_arrays, do_1_all_arrays)
765786
do_2 = do_minimal_d
766-
assert np.allclose(do_2.all_arrays, do_2_all_arrays)
767-
assert np.allclose((do_1 + do_2).all_arrays, expected_do_1_all_arrays_with_y_summed)
768-
assert np.allclose((do_2 + do_1).all_arrays, expected_do_2_all_arrays_with_y_summed)
787+
assert np.allclose(
788+
do_1.all_arrays, np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]])
789+
)
790+
assert np.allclose(
791+
do_2.all_arrays, np.array([[1.0, 6.28318531, 100.70777771, 1], [2.0, 3.14159265, 45.28748053, 2.0]])
792+
)
793+
794+
if operation == "add":
795+
do_1_y_modified = do_1 + do_2
796+
do_2_y_modified = do_2 + do_1
797+
elif operation == "sub":
798+
do_1_y_modified = do_1 - do_2
799+
do_2_y_modified = do_2 - do_1
800+
801+
assert np.allclose(do_1_y_modified.all_arrays, expected_do_1_all_arrays_with_y_modified)
802+
assert np.allclose(do_2_y_modified.all_arrays, expected_do_2_all_arrays_with_y_modified)
769803

770804

771805
def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
@@ -775,6 +809,10 @@ def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_m
775809
do + "string_value"
776810
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
777811
"string_value" + do
812+
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
813+
do - "string_value"
814+
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
815+
"string_value" - do
778816

779817

780818
def test_addition_operator_invalid_yarray_length(do_minimal, do_minimal_tth, y_grid_size_mismatch_error_msg):
@@ -785,3 +823,5 @@ def test_addition_operator_invalid_yarray_length(do_minimal, do_minimal_tth, y_g
785823
assert len(do_2.all_arrays[:, 0]) == 2
786824
with pytest.raises(ValueError, match=re.escape(y_grid_size_mismatch_error_msg)):
787825
do_1 + do_2
826+
with pytest.raises(ValueError, match=re.escape(y_grid_size_mismatch_error_msg)):
827+
do_1 - do_2

0 commit comments

Comments
 (0)