Skip to content

Commit 1ea8e9a

Browse files
authored
Merge pull request #285 from bobleesj/do-ops
Refactor `__add__` operation in `DiffractionObject` and add tests
2 parents 3f2d0a4 + da70bd6 commit 1ea8e9a

File tree

4 files changed

+185
-34
lines changed

4 files changed

+185
-34
lines changed

Diff for: news/add-operations-tests.rst

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
**Added:**
2+
3+
* unit tests for __add__ operation for DiffractionObject
4+
5+
**Changed:**
6+
7+
* <news item>
8+
9+
**Deprecated:**
10+
11+
* <news item>
12+
13+
**Removed:**
14+
15+
* <news item>
16+
17+
**Fixed:**
18+
19+
* <news item>
20+
21+
**Security:**
22+
23+
* <news item>

Diff for: src/diffpy/utils/diffraction_objects.py

+64-34
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,15 @@
1414
XQUANTITIES = ANGLEQUANTITIES + DQUANTITIES + QQUANTITIES
1515
XUNITS = ["degrees", "radians", "rad", "deg", "inv_angs", "inv_nm", "nm-1", "A-1"]
1616

17-
x_grid_emsg = (
18-
"objects are not on the same x-grid. You may add them using the self.add method "
19-
"and specifying how to handle the mismatch."
17+
y_grid_length_mismatch_emsg = (
18+
"The two objects have different y-array lengths. "
19+
"Please ensure the length of the y-value during initialization is identical."
20+
)
21+
22+
invalid_add_type_emsg = (
23+
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. "
24+
"Please rerun by adding another DiffractionObject instance or a scalar value. "
25+
"e.g., my_do_1 + my_do_2 or my_do + 10 or 10 + my_do"
2026
)
2127

2228

@@ -169,32 +175,56 @@ def __eq__(self, other):
169175
return True
170176

171177
def __add__(self, other):
172-
summed = deepcopy(self)
173-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
174-
summed.on_tth[1] = self.on_tth[1] + other
175-
summed.on_q[1] = self.on_q[1] + other
176-
elif not isinstance(other, DiffractionObject):
177-
raise TypeError("I only know how to sum two DiffractionObject objects")
178-
elif self.on_tth[0].all() != other.on_tth[0].all():
179-
raise RuntimeError(x_grid_emsg)
180-
else:
181-
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1]
182-
summed.on_q[1] = self.on_q[1] + other.on_q[1]
183-
return summed
178+
"""Add a scalar value or another DiffractionObject to the yarray of the
179+
DiffractionObject.
184180
185-
def __radd__(self, other):
186-
summed = deepcopy(self)
187-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
188-
summed.on_tth[1] = self.on_tth[1] + other
189-
summed.on_q[1] = self.on_q[1] + other
190-
elif not isinstance(other, DiffractionObject):
191-
raise TypeError("I only know how to sum two Scattering_object objects")
192-
elif self.on_tth[0].all() != other.on_tth[0].all():
193-
raise RuntimeError(x_grid_emsg)
194-
else:
195-
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1]
196-
summed.on_q[1] = self.on_q[1] + other.on_q[1]
197-
return summed
181+
Parameters
182+
----------
183+
other : DiffractionObject or int or float
184+
The object to add to the current DiffractionObject. If `other` is a scalar value,
185+
it will be added to all yarray. The length of the yarray must match if `other` is
186+
an instance of DiffractionObject.
187+
188+
Returns
189+
-------
190+
DiffractionObject
191+
The new and deep-copied DiffractionObject instance after adding values to the yarray.
192+
193+
Raises
194+
------
195+
ValueError
196+
Raised when the length of the yarray of the two DiffractionObject instances do not match.
197+
TypeError
198+
Raised when the type of `other` is not an instance of DiffractionObject, int, or float.
199+
200+
Examples
201+
--------
202+
Add a scalar value to the yarray of the DiffractionObject instance:
203+
>>> new_do = my_do + 10.1
204+
>>> new_do = 10.1 + my_do
205+
206+
Add the yarray of two DiffractionObject instances:
207+
>>> new_do = my_do_1 + my_do_2
208+
"""
209+
210+
self._check_operation_compatibility(other)
211+
summed_do = deepcopy(self)
212+
if isinstance(other, (int, float)):
213+
summed_do._all_arrays[:, 0] += other
214+
if isinstance(other, DiffractionObject):
215+
summed_do._all_arrays[:, 0] += other.all_arrays[:, 0]
216+
return summed_do
217+
218+
__radd__ = __add__
219+
220+
def _check_operation_compatibility(self, other):
221+
if not isinstance(other, (DiffractionObject, int, float)):
222+
raise TypeError(invalid_add_type_emsg)
223+
if isinstance(other, DiffractionObject):
224+
self_yarray = self.all_arrays[:, 0]
225+
other_yarray = other.all_arrays[:, 0]
226+
if len(self_yarray) != len(other_yarray):
227+
raise ValueError(y_grid_length_mismatch_emsg)
198228

199229
def __sub__(self, other):
200230
subtracted = deepcopy(self)
@@ -204,7 +234,7 @@ def __sub__(self, other):
204234
elif not isinstance(other, DiffractionObject):
205235
raise TypeError("I only know how to subtract two Scattering_object objects")
206236
elif self.on_tth[0].all() != other.on_tth[0].all():
207-
raise RuntimeError(x_grid_emsg)
237+
raise RuntimeError(y_grid_length_mismatch_emsg)
208238
else:
209239
subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1]
210240
subtracted.on_q[1] = self.on_q[1] - other.on_q[1]
@@ -218,7 +248,7 @@ def __rsub__(self, other):
218248
elif not isinstance(other, DiffractionObject):
219249
raise TypeError("I only know how to subtract two Scattering_object objects")
220250
elif self.on_tth[0].all() != other.on_tth[0].all():
221-
raise RuntimeError(x_grid_emsg)
251+
raise RuntimeError(y_grid_length_mismatch_emsg)
222252
else:
223253
subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1]
224254
subtracted.on_q[1] = other.on_q[1] - self.on_q[1]
@@ -232,7 +262,7 @@ def __mul__(self, other):
232262
elif not isinstance(other, DiffractionObject):
233263
raise TypeError("I only know how to multiply two Scattering_object objects")
234264
elif self.on_tth[0].all() != other.on_tth[0].all():
235-
raise RuntimeError(x_grid_emsg)
265+
raise RuntimeError(y_grid_length_mismatch_emsg)
236266
else:
237267
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
238268
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
@@ -244,7 +274,7 @@ def __rmul__(self, other):
244274
multiplied.on_tth[1] = other * self.on_tth[1]
245275
multiplied.on_q[1] = other * self.on_q[1]
246276
elif self.on_tth[0].all() != other.on_tth[0].all():
247-
raise RuntimeError(x_grid_emsg)
277+
raise RuntimeError(y_grid_length_mismatch_emsg)
248278
else:
249279
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
250280
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
@@ -258,7 +288,7 @@ def __truediv__(self, other):
258288
elif not isinstance(other, DiffractionObject):
259289
raise TypeError("I only know how to multiply two Scattering_object objects")
260290
elif self.on_tth[0].all() != other.on_tth[0].all():
261-
raise RuntimeError(x_grid_emsg)
291+
raise RuntimeError(y_grid_length_mismatch_emsg)
262292
else:
263293
divided.on_tth[1] = self.on_tth[1] / other.on_tth[1]
264294
divided.on_q[1] = self.on_q[1] / other.on_q[1]
@@ -270,7 +300,7 @@ def __rtruediv__(self, other):
270300
divided.on_tth[1] = other / self.on_tth[1]
271301
divided.on_q[1] = other / self.on_q[1]
272302
elif self.on_tth[0].all() != other.on_tth[0].all():
273-
raise RuntimeError(x_grid_emsg)
303+
raise RuntimeError(y_grid_length_mismatch_emsg)
274304
else:
275305
divided.on_tth[1] = other.on_tth[1] / self.on_tth[1]
276306
divided.on_q[1] = other.on_q[1] / self.on_q[1]

Diff for: tests/conftest.py

+23
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ def do_minimal_tth():
4747
return DiffractionObject(wavelength=2 * np.pi, xarray=np.array([30, 60]), yarray=np.array([1, 2]), xtype="tth")
4848

4949

50+
@pytest.fixture
51+
def do_minimal_d():
52+
# Create an instance of DiffractionObject with non-empty xarray, yarray, and wavelength values
53+
return DiffractionObject(wavelength=1.54, xarray=np.array([1, 2]), yarray=np.array([1, 2]), xtype="d")
54+
55+
5056
@pytest.fixture
5157
def wavelength_warning_msg():
5258
return (
@@ -63,3 +69,20 @@ def invalid_q_or_d_or_wavelength_error_msg():
6369
"The supplied input array and wavelength will result in an impossible two-theta. "
6470
"Please check these values and re-instantiate the DiffractionObject with correct values."
6571
)
72+
73+
74+
@pytest.fixture
75+
def invalid_add_type_error_msg():
76+
return (
77+
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. "
78+
"Please rerun by adding another DiffractionObject instance or a scalar value. "
79+
"e.g., my_do_1 + my_do_2 or my_do + 10 or 10 + my_do"
80+
)
81+
82+
83+
@pytest.fixture
84+
def y_grid_size_mismatch_error_msg():
85+
return (
86+
"The two objects have different y-array lengths. "
87+
"Please ensure the length of the y-value during initialization is identical."
88+
)

Diff for: tests/test_diffraction_objects.py

+75
Original file line numberDiff line numberDiff line change
@@ -710,3 +710,78 @@ def test_copy_object(do_minimal):
710710
do_copy = do.copy()
711711
assert do == do_copy
712712
assert id(do) != id(do_copy)
713+
714+
715+
@pytest.mark.parametrize(
716+
"starting_all_arrays, scalar_to_add, expected_all_arrays",
717+
[
718+
# Test scalar addition to yarray values (intensity) and expect no change to xarrays (q, tth, d)
719+
( # C1: Add integer of 5, expect yarray to increase by by 5
720+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
721+
5,
722+
np.array([[6.0, 0.51763809, 30.0, 12.13818192], [7.0, 1.0, 60.0, 6.28318531]]),
723+
),
724+
( # C2: Add float of 5.1, expect yarray to be added by 5.1
725+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
726+
5.1,
727+
np.array([[6.1, 0.51763809, 30.0, 12.13818192], [7.1, 1.0, 60.0, 6.28318531]]),
728+
),
729+
],
730+
)
731+
def test_addition_operator_by_scalar(starting_all_arrays, scalar_to_add, expected_all_arrays, do_minimal_tth):
732+
do = do_minimal_tth
733+
assert np.allclose(do.all_arrays, starting_all_arrays)
734+
do_scalar_right_sum = do + scalar_to_add
735+
assert np.allclose(do_scalar_right_sum.all_arrays, expected_all_arrays)
736+
do_scalar_left_sum = scalar_to_add + do
737+
assert np.allclose(do_scalar_left_sum.all_arrays, expected_all_arrays)
738+
739+
740+
@pytest.mark.parametrize(
741+
"do_1_all_arrays, "
742+
"do_2_all_arrays, "
743+
"expected_do_1_all_arrays_with_y_summed, "
744+
"expected_do_2_all_arrays_with_y_summed",
745+
[
746+
# Test addition of two DO objects, expect combined yarray values and no change to xarrays ((q, tth, d)
747+
( # C1: Add two DO objects, expect sum of yarray values
748+
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),),
749+
(np.array([[1.0, 6.28318531, 100.70777771, 1], [2.0, 3.14159265, 45.28748053, 2.0]]),),
750+
(np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),),
751+
(np.array([[2.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),),
752+
),
753+
],
754+
)
755+
def test_addition_operator_by_another_do(
756+
do_1_all_arrays,
757+
do_2_all_arrays,
758+
expected_do_1_all_arrays_with_y_summed,
759+
expected_do_2_all_arrays_with_y_summed,
760+
do_minimal_tth,
761+
do_minimal_d,
762+
):
763+
do_1 = do_minimal_tth
764+
assert np.allclose(do_1.all_arrays, do_1_all_arrays)
765+
do_2 = do_minimal_d
766+
assert np.allclose(do_2.all_arrays, do_2_all_arrays)
767+
assert np.allclose((do_1 + do_2).all_arrays, expected_do_1_all_arrays_with_y_summed)
768+
assert np.allclose((do_2 + do_1).all_arrays, expected_do_2_all_arrays_with_y_summed)
769+
770+
771+
def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
772+
# Add a string to a DO object, expect TypeError, only scalar (int, float) allowed for addition
773+
do = do_minimal_tth
774+
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
775+
do + "string_value"
776+
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
777+
"string_value" + do
778+
779+
780+
def test_addition_operator_invalid_yarray_length(do_minimal, do_minimal_tth, y_grid_size_mismatch_error_msg):
781+
# Combine two DO objects, one with empty xarrays (do_minimal) and the other with non-empty xarrays
782+
do_1 = do_minimal
783+
do_2 = do_minimal_tth
784+
assert len(do_1.all_arrays[:, 0]) == 0
785+
assert len(do_2.all_arrays[:, 0]) == 2
786+
with pytest.raises(ValueError, match=re.escape(y_grid_size_mismatch_error_msg)):
787+
do_1 + do_2

0 commit comments

Comments
 (0)