Skip to content

Commit 90fd625

Browse files
authored
Merge pull request #293 from bobleesj/op-mul-sub
feat: Support *, /, - operations between two DiffractionObjects or scalar
2 parents 1ea8e9a + 21711fb commit 90fd625

File tree

3 files changed

+202
-127
lines changed

3 files changed

+202
-127
lines changed

Diff for: news/op-mul-sub-div.rst

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
**Added:**
2+
3+
* addition, multiplication, subtraction, and division operators between two DiffractionObject instances or a scalar value with another DiffractionObject for modifying yarray (intensity) values.
4+
5+
**Changed:**
6+
7+
* <news item>
8+
9+
**Deprecated:**
10+
11+
* <news item>
12+
13+
**Removed:**
14+
15+
* <news item>
16+
17+
**Fixed:**
18+
19+
* <news item>
20+
21+
**Security:**
22+
23+
* <news item>

Diff for: src/diffpy/utils/diffraction_objects.py

+33-82
Original file line numberDiff line numberDiff line change
@@ -217,94 +217,45 @@ def __add__(self, other):
217217

218218
__radd__ = __add__
219219

220-
def _check_operation_compatibility(self, other):
221-
if not isinstance(other, (DiffractionObject, int, float)):
222-
raise TypeError(invalid_add_type_emsg)
220+
def __sub__(self, other):
221+
self._check_operation_compatibility(other)
222+
subtracted_do = deepcopy(self)
223+
if isinstance(other, (int, float)):
224+
subtracted_do._all_arrays[:, 0] -= other
223225
if isinstance(other, DiffractionObject):
224-
self_yarray = self.all_arrays[:, 0]
225-
other_yarray = other.all_arrays[:, 0]
226-
if len(self_yarray) != len(other_yarray):
227-
raise ValueError(y_grid_length_mismatch_emsg)
226+
subtracted_do._all_arrays[:, 0] -= other.all_arrays[:, 0]
227+
return subtracted_do
228228

229-
def __sub__(self, other):
230-
subtracted = deepcopy(self)
231-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
232-
subtracted.on_tth[1] = self.on_tth[1] - other
233-
subtracted.on_q[1] = self.on_q[1] - other
234-
elif not isinstance(other, DiffractionObject):
235-
raise TypeError("I only know how to subtract two Scattering_object objects")
236-
elif self.on_tth[0].all() != other.on_tth[0].all():
237-
raise RuntimeError(y_grid_length_mismatch_emsg)
238-
else:
239-
subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1]
240-
subtracted.on_q[1] = self.on_q[1] - other.on_q[1]
241-
return subtracted
242-
243-
def __rsub__(self, other):
244-
subtracted = deepcopy(self)
245-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
246-
subtracted.on_tth[1] = other - self.on_tth[1]
247-
subtracted.on_q[1] = other - self.on_q[1]
248-
elif not isinstance(other, DiffractionObject):
249-
raise TypeError("I only know how to subtract two Scattering_object objects")
250-
elif self.on_tth[0].all() != other.on_tth[0].all():
251-
raise RuntimeError(y_grid_length_mismatch_emsg)
252-
else:
253-
subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1]
254-
subtracted.on_q[1] = other.on_q[1] - self.on_q[1]
255-
return subtracted
229+
__rsub__ = __sub__
256230

257231
def __mul__(self, other):
258-
multiplied = deepcopy(self)
259-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
260-
multiplied.on_tth[1] = other * self.on_tth[1]
261-
multiplied.on_q[1] = other * self.on_q[1]
262-
elif not isinstance(other, DiffractionObject):
263-
raise TypeError("I only know how to multiply two Scattering_object objects")
264-
elif self.on_tth[0].all() != other.on_tth[0].all():
265-
raise RuntimeError(y_grid_length_mismatch_emsg)
266-
else:
267-
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
268-
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
269-
return multiplied
270-
271-
def __rmul__(self, other):
272-
multiplied = deepcopy(self)
273-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
274-
multiplied.on_tth[1] = other * self.on_tth[1]
275-
multiplied.on_q[1] = other * self.on_q[1]
276-
elif self.on_tth[0].all() != other.on_tth[0].all():
277-
raise RuntimeError(y_grid_length_mismatch_emsg)
278-
else:
279-
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
280-
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
281-
return multiplied
232+
self._check_operation_compatibility(other)
233+
multiplied_do = deepcopy(self)
234+
if isinstance(other, (int, float)):
235+
multiplied_do._all_arrays[:, 0] *= other
236+
if isinstance(other, DiffractionObject):
237+
multiplied_do._all_arrays[:, 0] *= other.all_arrays[:, 0]
238+
return multiplied_do
239+
240+
__rmul__ = __mul__
282241

283242
def __truediv__(self, other):
284-
divided = deepcopy(self)
285-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
286-
divided.on_tth[1] = other / self.on_tth[1]
287-
divided.on_q[1] = other / self.on_q[1]
288-
elif not isinstance(other, DiffractionObject):
289-
raise TypeError("I only know how to multiply two Scattering_object objects")
290-
elif self.on_tth[0].all() != other.on_tth[0].all():
291-
raise RuntimeError(y_grid_length_mismatch_emsg)
292-
else:
293-
divided.on_tth[1] = self.on_tth[1] / other.on_tth[1]
294-
divided.on_q[1] = self.on_q[1] / other.on_q[1]
295-
return divided
296-
297-
def __rtruediv__(self, other):
298-
divided = deepcopy(self)
299-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
300-
divided.on_tth[1] = other / self.on_tth[1]
301-
divided.on_q[1] = other / self.on_q[1]
302-
elif self.on_tth[0].all() != other.on_tth[0].all():
303-
raise RuntimeError(y_grid_length_mismatch_emsg)
304-
else:
305-
divided.on_tth[1] = other.on_tth[1] / self.on_tth[1]
306-
divided.on_q[1] = other.on_q[1] / self.on_q[1]
307-
return divided
243+
self._check_operation_compatibility(other)
244+
divided_do = deepcopy(self)
245+
if isinstance(other, (int, float)):
246+
divided_do._all_arrays[:, 0] /= other
247+
if isinstance(other, DiffractionObject):
248+
divided_do._all_arrays[:, 0] /= other.all_arrays[:, 0]
249+
return divided_do
250+
251+
__rtruediv__ = __truediv__
252+
253+
def _check_operation_compatibility(self, other):
254+
if not isinstance(other, (DiffractionObject, int, float)):
255+
raise TypeError(invalid_add_type_emsg)
256+
if isinstance(other, DiffractionObject):
257+
if self.all_arrays.shape != other.all_arrays.shape:
258+
raise ValueError(y_grid_length_mismatch_emsg)
308259

309260
@property
310261
def all_arrays(self):

Diff for: tests/test_diffraction_objects.py

+146-45
Original file line numberDiff line numberDiff line change
@@ -713,75 +713,176 @@ def test_copy_object(do_minimal):
713713

714714

715715
@pytest.mark.parametrize(
716-
"starting_all_arrays, scalar_to_add, expected_all_arrays",
716+
"operation, starting_yarray, scalar_value, expected_yarray",
717717
[
718-
# Test scalar addition to yarray values (intensity) and expect no change to xarrays (q, tth, d)
719-
( # C1: Add integer of 5, expect yarray to increase by by 5
720-
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
718+
# Test scalar addition, subtraction, multiplication, and division to y-values by adding a scalar value
719+
# C1: Test scalar addition to y-values (intensity), expect no change to x-values (q, tth, d)
720+
( # 1. Add 5
721+
"add",
722+
np.array([1.0, 2.0]),
721723
5,
722-
np.array([[6.0, 0.51763809, 30.0, 12.13818192], [7.0, 1.0, 60.0, 6.28318531]]),
724+
np.array([6.0, 7.0]),
723725
),
724-
( # C2: Add float of 5.1, expect yarray to be added by 5.1
725-
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
726+
( # 2. Add 5.1
727+
"add",
728+
np.array([1.0, 2.0]),
726729
5.1,
727-
np.array([[6.1, 0.51763809, 30.0, 12.13818192], [7.1, 1.0, 60.0, 6.28318531]]),
730+
np.array([6.1, 7.1]),
731+
),
732+
# C2: Test scalar subtraction to y-values (intensity), expect no change to x-values (q, tth, d)
733+
( # 1. Subtract 1
734+
"sub",
735+
np.array([1.0, 2.0]),
736+
1,
737+
np.array([0.0, 1.0]),
738+
),
739+
( # 2. Subtract 0.5
740+
"sub",
741+
np.array([1.0, 2.0]),
742+
0.5,
743+
np.array([0.5, 1.5]),
744+
),
745+
# C3: Test scalar multiplication to y-values (intensity), expect no change to x-values (q, tth, d)
746+
( # 1. Multiply by 2
747+
"mul",
748+
np.array([1.0, 2.0]),
749+
2,
750+
np.array([2.0, 4.0]),
751+
),
752+
( # 2. Multiply by 2.5
753+
"mul",
754+
np.array([1.0, 2.0]),
755+
2.5,
756+
np.array([2.5, 5.0]),
757+
),
758+
# C4: Test scalar division to y-values (intensity), expect no change to x-values (q, tth, d)
759+
( # 1. Divide by 2
760+
"div",
761+
np.array([1.0, 2.0]),
762+
2,
763+
np.array([0.5, 1.0]),
764+
),
765+
( # 2. Divide by 2.5
766+
"div",
767+
np.array([1.0, 2.0]),
768+
2.5,
769+
np.array([0.4, 0.8]),
728770
),
729771
],
730772
)
731-
def test_addition_operator_by_scalar(starting_all_arrays, scalar_to_add, expected_all_arrays, do_minimal_tth):
773+
def test_scalar_operations(operation, starting_yarray, scalar_value, expected_yarray, do_minimal_tth):
732774
do = do_minimal_tth
733-
assert np.allclose(do.all_arrays, starting_all_arrays)
734-
do_scalar_right_sum = do + scalar_to_add
735-
assert np.allclose(do_scalar_right_sum.all_arrays, expected_all_arrays)
736-
do_scalar_left_sum = scalar_to_add + do
737-
assert np.allclose(do_scalar_left_sum.all_arrays, expected_all_arrays)
775+
expected_xarray_constant = np.array([[0.51763809, 30.0, 12.13818192], [1.0, 60.0, 6.28318531]])
776+
assert np.allclose(do.all_arrays[:, [1, 2, 3]], expected_xarray_constant)
777+
assert np.allclose(do.all_arrays[:, 0], starting_yarray)
778+
if operation == "add":
779+
do_right_op = do + scalar_value
780+
do_left_op = scalar_value + do
781+
elif operation == "sub":
782+
do_right_op = do - scalar_value
783+
do_left_op = scalar_value - do
784+
elif operation == "mul":
785+
do_right_op = do * scalar_value
786+
do_left_op = scalar_value * do
787+
elif operation == "div":
788+
do_right_op = do / scalar_value
789+
do_left_op = scalar_value / do
790+
assert np.allclose(do_right_op.all_arrays[:, 0], expected_yarray)
791+
assert np.allclose(do_left_op.all_arrays[:, 0], expected_yarray)
792+
# Ensure x-values are unchanged
793+
assert np.allclose(do_right_op.all_arrays[:, [1, 2, 3]], expected_xarray_constant)
794+
assert np.allclose(do_left_op.all_arrays[:, [1, 2, 3]], expected_xarray_constant)
738795

739796

740797
@pytest.mark.parametrize(
741-
"do_1_all_arrays, "
742-
"do_2_all_arrays, "
743-
"expected_do_1_all_arrays_with_y_summed, "
744-
"expected_do_2_all_arrays_with_y_summed",
798+
"operation, " "expected_do_1_all_arrays_with_y_modified, " "expected_do_2_all_arrays_with_y_modified",
745799
[
746-
# Test addition of two DO objects, expect combined yarray values and no change to xarrays ((q, tth, d)
747-
( # C1: Add two DO objects, expect sum of yarray values
748-
(np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),),
749-
(np.array([[1.0, 6.28318531, 100.70777771, 1], [2.0, 3.14159265, 45.28748053, 2.0]]),),
750-
(np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),),
751-
(np.array([[2.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),),
800+
# Test addition, subtraction, multiplication, and division of two DO objects
801+
( # Test addition of two DO objects, expect combined yarray values
802+
"add",
803+
np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),
804+
np.array([[2.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),
805+
),
806+
( # Test subtraction of two DO objects, expect differences in yarray values
807+
"sub",
808+
np.array([[0.0, 0.51763809, 30.0, 12.13818192], [0.0, 1.0, 60.0, 6.28318531]]),
809+
np.array([[0.0, 6.28318531, 100.70777771, 1], [0.0, 3.14159265, 45.28748053, 2.0]]),
810+
),
811+
( # Test multiplication of two DO objects, expect multiplication in yarray values
812+
"mul",
813+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),
814+
np.array([[1.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),
815+
),
816+
( # Test division of two DO objects, expect division in yarray values
817+
"div",
818+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [1.0, 1.0, 60.0, 6.28318531]]),
819+
np.array([[1.0, 6.28318531, 100.70777771, 1], [1.0, 3.14159265, 45.28748053, 2.0]]),
752820
),
753821
],
754822
)
755-
def test_addition_operator_by_another_do(
756-
do_1_all_arrays,
757-
do_2_all_arrays,
758-
expected_do_1_all_arrays_with_y_summed,
759-
expected_do_2_all_arrays_with_y_summed,
823+
def test_binary_operator_on_do(
824+
operation,
825+
expected_do_1_all_arrays_with_y_modified,
826+
expected_do_2_all_arrays_with_y_modified,
760827
do_minimal_tth,
761828
do_minimal_d,
762829
):
763830
do_1 = do_minimal_tth
764-
assert np.allclose(do_1.all_arrays, do_1_all_arrays)
765831
do_2 = do_minimal_d
766-
assert np.allclose(do_2.all_arrays, do_2_all_arrays)
767-
assert np.allclose((do_1 + do_2).all_arrays, expected_do_1_all_arrays_with_y_summed)
768-
assert np.allclose((do_2 + do_1).all_arrays, expected_do_2_all_arrays_with_y_summed)
769-
832+
assert np.allclose(
833+
do_1.all_arrays, np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]])
834+
)
835+
assert np.allclose(
836+
do_2.all_arrays, np.array([[1.0, 6.28318531, 100.70777771, 1], [2.0, 3.14159265, 45.28748053, 2.0]])
837+
)
770838

771-
def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
772-
# Add a string to a DO object, expect TypeError, only scalar (int, float) allowed for addition
839+
if operation == "add":
840+
do_1_y_modified = do_1 + do_2
841+
do_2_y_modified = do_2 + do_1
842+
elif operation == "sub":
843+
do_1_y_modified = do_1 - do_2
844+
do_2_y_modified = do_2 - do_1
845+
elif operation == "mul":
846+
do_1_y_modified = do_1 * do_2
847+
do_2_y_modified = do_2 * do_1
848+
elif operation == "div":
849+
do_1_y_modified = do_1 / do_2
850+
do_2_y_modified = do_2 / do_1
851+
852+
assert np.allclose(do_1_y_modified.all_arrays, expected_do_1_all_arrays_with_y_modified)
853+
assert np.allclose(do_2_y_modified.all_arrays, expected_do_2_all_arrays_with_y_modified)
854+
855+
856+
def test_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
857+
# Add a string to a DiffractionObject, expect TypeError
773858
do = do_minimal_tth
774-
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
775-
do + "string_value"
776-
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
777-
"string_value" + do
778-
779-
780-
def test_addition_operator_invalid_yarray_length(do_minimal, do_minimal_tth, y_grid_size_mismatch_error_msg):
781-
# Combine two DO objects, one with empty xarrays (do_minimal) and the other with non-empty xarrays
859+
invalid_value = "string_value"
860+
operations = [
861+
(lambda x, y: x + y), # Test addition
862+
(lambda x, y: x - y), # Test subtraction
863+
(lambda x, y: x * y), # Test multiplication
864+
(lambda x, y: x / y), # Test division
865+
]
866+
for operation in operations:
867+
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
868+
operation(do, invalid_value)
869+
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
870+
operation(invalid_value, do)
871+
872+
873+
@pytest.mark.parametrize("operation", ["add", "sub", "mul", "div"])
874+
def test_operator_invalid_yarray_length(operation, do_minimal, do_minimal_tth, y_grid_size_mismatch_error_msg):
875+
# Add two DO objects with different yarray lengths, expect ValueError
782876
do_1 = do_minimal
783877
do_2 = do_minimal_tth
784878
assert len(do_1.all_arrays[:, 0]) == 0
785879
assert len(do_2.all_arrays[:, 0]) == 2
786880
with pytest.raises(ValueError, match=re.escape(y_grid_size_mismatch_error_msg)):
787-
do_1 + do_2
881+
if operation == "add":
882+
do_1 + do_2
883+
elif operation == "sub":
884+
do_1 - do_2
885+
elif operation == "mul":
886+
do_1 * do_2
887+
elif operation == "div":
888+
do_1 / do_2

0 commit comments

Comments
 (0)