Skip to content

Commit 11c4166

Browse files
committed
refactor: add yarrays together instead of xarrays for __add__
1 parent 12848e8 commit 11c4166

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed

Diff for: src/diffpy/utils/diffraction_objects.py

+31-30
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
XQUANTITIES = ANGLEQUANTITIES + DQUANTITIES + QQUANTITIES
1515
XUNITS = ["degrees", "radians", "rad", "deg", "inv_angs", "inv_nm", "nm-1", "A-1"]
1616

17-
x_grid_length_mismatch_emsg = (
18-
"The two objects have different x-array lengths. "
19-
"Please ensure the length of the x-value during initialization is identical."
17+
y_grid_length_mismatch_emsg = (
18+
"The two objects have different y-array lengths. "
19+
"Please ensure the length of the y-value during initialization is identical."
2020
)
2121

2222
invalid_add_type_emsg = (
2323
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. "
2424
"Please rerun by adding another DiffractionObject instance or a scalar value. "
25-
"e.g., my_do_1 + my_do_2 or my_do + 10"
25+
"e.g., my_do_1 + my_do_2 or my_do + 10 or 10 + my_do"
2626
)
2727

2828

@@ -175,56 +175,57 @@ def __eq__(self, other):
175175
return True
176176

177177
def __add__(self, other):
178-
"""Add a scalar value or another DiffractionObject to the xarrays of
179-
the DiffractionObject.
178+
"""Add a scalar value or another DiffractionObject to the yarray of the
179+
DiffractionObject.
180180
181181
Parameters
182182
----------
183183
other : DiffractionObject or int or float
184184
The object to add to the current DiffractionObject. If `other` is a scalar value,
185-
it will be added to all xarrays. The length of the xarrays must match if `other` is
185+
it will be added to all yarray. The length of the yarray must match if `other` is
186186
an instance of DiffractionObject.
187187
188188
Returns
189189
-------
190190
DiffractionObject
191-
The new and deep-copied DiffractionObject instance after adding values to the xarrays.
191+
The new and deep-copied DiffractionObject instance after adding values to the yarray.
192192
193193
Raises
194194
------
195195
ValueError
196-
Raised when the length of the xarrays of the two DiffractionObject instances do not match.
196+
Raised when the length of the yarray of the two DiffractionObject instances do not match.
197197
TypeError
198198
Raised when the type of `other` is not an instance of DiffractionObject, int, or float.
199199
200200
Examples
201201
--------
202-
Add a scalar value to the xarrays of the DiffractionObject instance:
202+
Add a scalar value to the yarray of the DiffractionObject instance:
203203
>>> new_do = my_do + 10.1
204+
>>> new_do = 10.1 + my_do
204205
205-
Add the xarrays of two DiffractionObject instances:
206+
Add the yarray of two DiffractionObject instances:
206207
>>> new_do = my_do_1 + my_do_2
207208
"""
208209

210+
self._check_operation_compatibility(other)
209211
summed_do = deepcopy(self)
210-
# Add scalar value to all xarrays by broadcasting
211212
if isinstance(other, (int, float)):
212-
summed_do._all_arrays[:, 1] += other
213-
summed_do._all_arrays[:, 2] += other
214-
summed_do._all_arrays[:, 3] += other
215-
# Add xarrays of two DiffractionObject instances
216-
elif isinstance(other, DiffractionObject):
217-
if len(self.on_tth()[0]) != len(other.on_tth()[0]):
218-
raise ValueError(x_grid_length_mismatch_emsg)
219-
summed_do._all_arrays[:, 1] += other.on_q()[0]
220-
summed_do._all_arrays[:, 2] += other.on_tth()[0]
221-
summed_do._all_arrays[:, 3] += other.on_d()[0]
222-
else:
223-
raise TypeError(invalid_add_type_emsg)
213+
summed_do._all_arrays[:, 0] += other
214+
if isinstance(other, DiffractionObject):
215+
summed_do._all_arrays[:, 0] += other.all_arrays[:, 0]
224216
return summed_do
225217

226218
__radd__ = __add__
227219

220+
def _check_operation_compatibility(self, other):
221+
if not isinstance(other, (DiffractionObject, int, float)):
222+
raise TypeError(invalid_add_type_emsg)
223+
if isinstance(other, DiffractionObject):
224+
self_yarray = self.all_arrays[:, 0]
225+
other_yarray = other.all_arrays[:, 0]
226+
if len(self_yarray) != len(other_yarray):
227+
raise ValueError(y_grid_length_mismatch_emsg)
228+
228229
def __sub__(self, other):
229230
subtracted = deepcopy(self)
230231
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
@@ -233,7 +234,7 @@ def __sub__(self, other):
233234
elif not isinstance(other, DiffractionObject):
234235
raise TypeError("I only know how to subtract two Scattering_object objects")
235236
elif self.on_tth[0].all() != other.on_tth[0].all():
236-
raise RuntimeError(x_grid_length_mismatch_emsg)
237+
raise RuntimeError(y_grid_length_mismatch_emsg)
237238
else:
238239
subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1]
239240
subtracted.on_q[1] = self.on_q[1] - other.on_q[1]
@@ -247,7 +248,7 @@ def __rsub__(self, other):
247248
elif not isinstance(other, DiffractionObject):
248249
raise TypeError("I only know how to subtract two Scattering_object objects")
249250
elif self.on_tth[0].all() != other.on_tth[0].all():
250-
raise RuntimeError(x_grid_length_mismatch_emsg)
251+
raise RuntimeError(y_grid_length_mismatch_emsg)
251252
else:
252253
subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1]
253254
subtracted.on_q[1] = other.on_q[1] - self.on_q[1]
@@ -261,7 +262,7 @@ def __mul__(self, other):
261262
elif not isinstance(other, DiffractionObject):
262263
raise TypeError("I only know how to multiply two Scattering_object objects")
263264
elif self.on_tth[0].all() != other.on_tth[0].all():
264-
raise RuntimeError(x_grid_length_mismatch_emsg)
265+
raise RuntimeError(y_grid_length_mismatch_emsg)
265266
else:
266267
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
267268
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
@@ -273,7 +274,7 @@ def __rmul__(self, other):
273274
multiplied.on_tth[1] = other * self.on_tth[1]
274275
multiplied.on_q[1] = other * self.on_q[1]
275276
elif self.on_tth[0].all() != other.on_tth[0].all():
276-
raise RuntimeError(x_grid_length_mismatch_emsg)
277+
raise RuntimeError(y_grid_length_mismatch_emsg)
277278
else:
278279
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
279280
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
@@ -287,7 +288,7 @@ def __truediv__(self, other):
287288
elif not isinstance(other, DiffractionObject):
288289
raise TypeError("I only know how to multiply two Scattering_object objects")
289290
elif self.on_tth[0].all() != other.on_tth[0].all():
290-
raise RuntimeError(x_grid_length_mismatch_emsg)
291+
raise RuntimeError(y_grid_length_mismatch_emsg)
291292
else:
292293
divided.on_tth[1] = self.on_tth[1] / other.on_tth[1]
293294
divided.on_q[1] = self.on_q[1] / other.on_q[1]
@@ -299,7 +300,7 @@ def __rtruediv__(self, other):
299300
divided.on_tth[1] = other / self.on_tth[1]
300301
divided.on_q[1] = other / self.on_q[1]
301302
elif self.on_tth[0].all() != other.on_tth[0].all():
302-
raise RuntimeError(x_grid_length_mismatch_emsg)
303+
raise RuntimeError(y_grid_length_mismatch_emsg)
303304
else:
304305
divided.on_tth[1] = other.on_tth[1] / self.on_tth[1]
305306
divided.on_q[1] = other.on_q[1] / self.on_q[1]

0 commit comments

Comments
 (0)