Skip to content

Commit eb2941b

Browse files
committed
feat: add __div__ function to DiffractionObject
1 parent ddeeb26 commit eb2941b

File tree

2 files changed

+97
-64
lines changed

2 files changed

+97
-64
lines changed

src/diffpy/utils/diffraction_objects.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -240,21 +240,15 @@ def __mul__(self, other):
240240
__rmul__ = __mul__
241241

242242
def __truediv__(self, other):
243+
self._check_operation_compatibility(other)
244+
divided_do = deepcopy(self)
245+
if isinstance(other, (int, float)):
246+
divided_do._all_arrays[:, 0] /= other
247+
if isinstance(other, DiffractionObject):
248+
divided_do._all_arrays[:, 0] /= other.all_arrays[:, 0]
249+
return divided_do
243250

244-
divided = deepcopy(self)
245-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
246-
divided.on_tth[1] = other / self.on_tth[1]
247-
divided.on_q[1] = other / self.on_q[1]
248-
elif not isinstance(other, DiffractionObject):
249-
raise TypeError("I only know how to multiply two Scattering_object objects")
250-
elif self.on_tth[0].all() != other.on_tth[0].all():
251-
raise RuntimeError(y_grid_length_mismatch_emsg)
252-
else:
253-
divided.on_tth[1] = self.on_tth[1] / other.on_tth[1]
254-
divided.on_q[1] = self.on_q[1] / other.on_q[1]
255-
return divided
256-
257-
__rmul__ = __mul__
251+
__rtruediv__ = __truediv__
258252

259253
def _check_operation_compatibility(self, other):
260254
if not isinstance(other, (DiffractionObject, int, float)):

tests/test_diffraction_objects.py

Lines changed: 89 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -713,64 +713,84 @@ def test_copy_object(do_minimal):
713713

714714

715715
@pytest.mark.parametrize(
716-
"operation, starting_all_arrays, scalar_value, expected_all_arrays",
716+
"operation, starting_yarray, scalar_value, expected_yarray",
717717
[
718-
# C1: Test scalar addition to yarray values (intensity), expect no change to xarrays (q, tth, d)
719-
( # 1. Add integer 5
718+
# C1: Test scalar addition to y-values (intensity), expect no change to x-values (q, tth, d)
719+
( # 1. Add 5
720720
"add",
721-
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
721+
np.array([1.0, 2.0]),
722722
5,
723-
np.array([[6.0, 0.51763809, 30.0, 12.13818192], [7.0, 1.0, 60.0, 6.28318531]]),
723+
np.array([6.0, 7.0]),
724724
),
725-
( # 2. Add float 5.1
725+
( # 2. Add 5.1
726726
"add",
727-
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
727+
np.array([1.0, 2.0]),
728728
5.1,
729-
np.array([[6.1, 0.51763809, 30.0, 12.13818192], [7.1, 1.0, 60.0, 6.28318531]]),
729+
np.array([6.1, 7.1]),
730730
),
731-
# C2. Test scalar subtraction to yarray values (intensity), expect no change to xarrays (q, tth, d)
732-
( # 1. Subtract integer 1
731+
# C2: Test scalar subtraction to y-values (intensity), expect no change to x-values (q, tth, d)
732+
( # 1. Subtract 1
733733
"sub",
734-
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
734+
np.array([1.0, 2.0]),
735735
1,
736-
np.array([[0.0, 0.51763809, 30.0, 12.13818192], [1.0, 1.0, 60.0, 6.28318531]]),
736+
np.array([0.0, 1.0]),
737737
),
738-
( # 2. Subtract float 0.5
738+
( # 2. Subtract 0.5
739739
"sub",
740-
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
740+
np.array([1.0, 2.0]),
741741
0.5,
742-
np.array([[0.5, 0.51763809, 30.0, 12.13818192], [1.5, 1.0, 60.0, 6.28318531]]),
742+
np.array([0.5, 1.5]),
743743
),
744-
# C2. Test scalar multiplication to yarray values (intensity), expect no change to xarrays (q, tth, d)
745-
( # 1. Multipliy by integer 2
744+
# C3: Test scalar multiplication to y-values (intensity), expect no change to x-values (q, tth, d)
745+
( # 1. Multiply by 2
746746
"mul",
747-
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
747+
np.array([1.0, 2.0]),
748748
2,
749-
np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),
749+
np.array([2.0, 4.0]),
750750
),
751-
( # 2. Multipliy by float 0.5
751+
( # 2. Multiply by 2.5
752752
"mul",
753-
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
753+
np.array([1.0, 2.0]),
754754
2.5,
755-
np.array([[2.5, 0.51763809, 30.0, 12.13818192], [5.0, 1.0, 60.0, 6.28318531]]),
755+
np.array([2.5, 5.0]),
756+
),
757+
# C4: Test scalar division to y-values (intensity), expect no change to x-values (q, tth, d)
758+
( # 1. Divide by 2
759+
"div",
760+
np.array([1.0, 2.0]),
761+
2,
762+
np.array([0.5, 1.0]),
763+
),
764+
( # 2. Divide by 2.5
765+
"div",
766+
np.array([1.0, 2.0]),
767+
2.5,
768+
np.array([0.4, 0.8]),
756769
),
757770
],
758771
)
759-
def test_scalar_operations(operation, starting_all_arrays, scalar_value, expected_all_arrays, do_minimal_tth):
772+
def test_scalar_operations(operation, starting_yarray, scalar_value, expected_yarray, do_minimal_tth):
760773
do = do_minimal_tth
761-
assert np.allclose(do.all_arrays, starting_all_arrays)
774+
expected_xarray_constant = np.array([[0.51763809, 30.0, 12.13818192], [1.0, 60.0, 6.28318531]])
775+
assert np.allclose(do.all_arrays[:, [1, 2, 3]], expected_xarray_constant)
776+
assert np.allclose(do.all_arrays[:, 0], starting_yarray)
762777
if operation == "add":
763-
result_right = do + scalar_value
764-
result_left = scalar_value + do
778+
do_right_op = do + scalar_value
779+
do_left_op = scalar_value + do
765780
elif operation == "sub":
766-
result_right = do - scalar_value
767-
result_left = scalar_value - do
781+
do_right_op = do - scalar_value
782+
do_left_op = scalar_value - do
768783
elif operation == "mul":
769-
result_right = do * scalar_value
770-
result_left = scalar_value * do
771-
772-
assert np.allclose(result_right.all_arrays, expected_all_arrays)
773-
assert np.allclose(result_left.all_arrays, expected_all_arrays)
784+
do_right_op = do * scalar_value
785+
do_left_op = scalar_value * do
786+
elif operation == "div":
787+
do_right_op = do / scalar_value
788+
do_left_op = scalar_value / do
789+
assert np.allclose(do_right_op.all_arrays[:, 0], expected_yarray)
790+
assert np.allclose(do_left_op.all_arrays[:, 0], expected_yarray)
791+
# Ensure x-values are unchanged
792+
assert np.allclose(do_right_op.all_arrays[:, [1, 2, 3]], expected_xarray_constant)
793+
assert np.allclose(do_left_op.all_arrays[:, [1, 2, 3]], expected_xarray_constant)
774794

775795

776796
@pytest.mark.parametrize(
@@ -793,6 +813,11 @@ def test_scalar_operations(operation, starting_all_arrays, scalar_value, expecte
793813
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),
794814
np.array([[1.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),
795815
),
816+
(
817+
"div",
818+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [1.0, 1.0, 60.0, 6.28318531]]),
819+
np.array([[1.0, 6.28318531, 100.70777771, 1], [1.0, 3.14159265, 45.28748053, 2.0]]),
820+
),
796821
],
797822
)
798823
def test_binary_operator_on_do(
@@ -820,31 +845,45 @@ def test_binary_operator_on_do(
820845
elif operation == "mul":
821846
do_1_y_modified = do_1 * do_2
822847
do_2_y_modified = do_2 * do_1
848+
elif operation == "div":
849+
do_1_y_modified = do_1 / do_2
850+
do_2_y_modified = do_2 / do_1
823851

824852
assert np.allclose(do_1_y_modified.all_arrays, expected_do_1_all_arrays_with_y_modified)
825853
assert np.allclose(do_2_y_modified.all_arrays, expected_do_2_all_arrays_with_y_modified)
826854

827855

828-
def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
829-
# Add a string to a DO object, expect TypeError, only scalar (int, float) allowed for addition
856+
def test_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
830857
do = do_minimal_tth
831-
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
832-
do + "string_value"
833-
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
834-
"string_value" + do
835-
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
836-
do - "string_value"
837-
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
838-
"string_value" - do
839-
840-
841-
def test_addition_operator_invalid_yarray_length(do_minimal, do_minimal_tth, y_grid_size_mismatch_error_msg):
842-
# Combine two DO objects, one with empty xarrays (do_minimal) and the other with non-empty xarrays
858+
invalid_value = "string_value"
859+
860+
operations = [
861+
(lambda x, y: x + y), # Test addition
862+
(lambda x, y: x - y), # Test subtraction
863+
(lambda x, y: x * y), # Test multiplication
864+
(lambda x, y: x / y), # Test division
865+
]
866+
867+
# Test each operation with both orderings of operands
868+
for operation in operations:
869+
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
870+
operation(do, invalid_value)
871+
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
872+
operation(invalid_value, do)
873+
874+
875+
@pytest.mark.parametrize("operation", ["add", "sub", "mul", "div"])
876+
def test_operator_invalid_yarray_length(operation, do_minimal, do_minimal_tth, y_grid_size_mismatch_error_msg):
843877
do_1 = do_minimal
844878
do_2 = do_minimal_tth
845879
assert len(do_1.all_arrays[:, 0]) == 0
846880
assert len(do_2.all_arrays[:, 0]) == 2
847881
with pytest.raises(ValueError, match=re.escape(y_grid_size_mismatch_error_msg)):
848-
do_1 + do_2
849-
with pytest.raises(ValueError, match=re.escape(y_grid_size_mismatch_error_msg)):
850-
do_1 - do_2
882+
if operation == "add":
883+
do_1 + do_2
884+
elif operation == "sub":
885+
do_1 - do_2
886+
elif operation == "mul":
887+
do_1 * do_2
888+
elif operation == "div":
889+
do_1 / do_2

0 commit comments

Comments
 (0)