Skip to content

Commit 56e7278

Browse files
add more tests
1 parent 5257ee2 commit 56e7278

File tree

2 files changed

+68
-27
lines changed

2 files changed

+68
-27
lines changed

src/diffpy/utils/diffraction_objects.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,9 @@ def _set_array_from_range(self, begin, end, step_size=None, n_steps=None):
248248
array = np.linspace(begin, end, n_steps)
249249
return array
250250

251-
def get_array_index(self, xtype, value):
251+
def get_array_index(self, value, xtype=None):
252252
"""
253-
returns the index of a given value in the array associated with the specified xtype
253+
returns the index of the closest value in the array associated with the specified xtype
254254
255255
Parameters
256256
----------
@@ -263,12 +263,30 @@ def get_array_index(self, xtype, value):
263263
-------
264264
the index of the value in the array
265265
"""
266-
if self.on_xtype(xtype) is None:
267-
raise ValueError(_xtype_wmsg(xtype))
268-
for i, target in enumerate(self.on_xtype(xtype)[0]):
269-
if value == target:
270-
return i
271-
raise IndexError(f"WARNING: no matching value {value} found in the {xtype} array.")
266+
267+
if xtype is None:
268+
xtype = self.input_xtype
269+
if self.on_xtype(xtype) is None or len(self.on_xtype(xtype)[0]) == 0:
270+
raise ValueError(
271+
f"The '{xtype}' array is empty. " "Please ensure it is initialized and the correct xtype is used."
272+
)
273+
array = self.on_xtype(xtype)[0]
274+
i = (np.abs(array - value)).argmin()
275+
nearest_value = np.abs(array[i] - value)
276+
distance = min(np.abs(value - array.min()), np.abs(value - array.max()))
277+
threshold = 0.5 * (array.max() - array.min())
278+
279+
if nearest_value != 0 and (array.min() <= value <= array.max() or distance <= threshold):
280+
warnings.warn(
281+
f"WARNING: The value {value} is not an exact match of the '{xtype}' array. "
282+
f"Returning the index of the closest value."
283+
)
284+
elif distance > threshold:
285+
raise IndexError(
286+
f"The value {value} is too far from any value in the '{xtype}' array. "
287+
f"Please check if you have specified the correct xtype. "
288+
)
289+
return i
272290

273291
def _set_xarrays(self, xarray, xtype):
274292
self.all_arrays = np.empty(shape=(len(xarray), 4))

tests/test_diffraction_objects.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from freezegun import freeze_time
77

8-
from diffpy.utils.diffraction_objects import XQUANTITIES, DiffractionObject
8+
from diffpy.utils.diffraction_objects import DiffractionObject
99
from diffpy.utils.transforms import wavelength_warning_emsg
1010

1111

@@ -212,42 +212,65 @@ def _test_valid_diffraction_objects(actual_diffraction_object, function, expecte
212212
return np.allclose(actual_array, expected_array)
213213

214214

215-
def test_get_angle_index():
216-
test = DiffractionObject(
217-
wavelength=0.71, xarray=np.array([30, 60, 90]), yarray=np.array([1, 2, 3]), xtype="tth"
218-
)
219-
actual_index = test.get_array_index(xtype="tth", value=30)
220-
assert actual_index == 0
215+
params_index = [
216+
# UC1: exact match
217+
([4 * np.pi, np.array([30.005, 60]), np.array([1, 2]), "tth", "tth", 30.005], [0]),
218+
# UC2: target value lies in the array, returns the (first) closest index
219+
([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "tth", 45], [0]),
220+
([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "q", 0.25], [0]),
221+
# UC3: target value out of the range but within reasonable distance, returns the closest index
222+
([4 * np.pi, np.array([0.25, 0.5, 0.71]), np.array([1, 2, 3]), "q", "q", 0.1], [0]),
223+
([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "tth", 63], [1]),
224+
]
225+
226+
227+
@pytest.mark.parametrize("inputs, expected", params_index)
228+
def test_get_array_index(inputs, expected):
229+
test = DiffractionObject(wavelength=inputs[0], xarray=inputs[1], yarray=inputs[2], xtype=inputs[3])
230+
actual = test.get_array_index(value=inputs[5], xtype=inputs[4])
231+
assert actual == expected[0]
221232

222233

223234
params_index_bad = [
224-
# UC1: empty array
235+
# UC0: empty array
225236
(
226-
[0.71, np.array([]), np.array([]), "tth", "tth", 10],
227-
[IndexError, "WARNING: no matching value 10 found in the tth array."],
237+
[2 * np.pi, np.array([]), np.array([]), "tth", "tth", 30],
238+
[ValueError, "The 'tth' array is empty. Please ensure it is initialized and the correct xtype is used."],
228239
),
229-
# UC2: invalid xtype
240+
# UC1: empty array (because of invalid xtype)
230241
(
231-
[None, np.array([]), np.array([]), "tth", "invalid", 10],
242+
[2 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "invalid", 30],
232243
[
233244
ValueError,
234-
f"WARNING: I don't know how to handle the xtype, 'invalid'. "
235-
f"Please rerun specifying an xtype from {*XQUANTITIES, }",
245+
"The 'invalid' array is empty. Please ensure it is initialized and the correct xtype is used.",
236246
],
237247
),
238-
# UC3: pre-defined array with non-matching value
248+
# UC3: value is too far from any element in the array
239249
(
240-
[0.71, np.array([30, 60, 90]), np.array([1, 2, 3]), "tth", "q", 30],
241-
[IndexError, "WARNING: no matching value 30 found in the q array."],
250+
[2 * np.pi, np.array([30, 60, 90]), np.array([1, 2, 3]), "tth", "tth", 140],
251+
[
252+
IndexError,
253+
"The value 140 is too far from any value in the 'tth' array. "
254+
"Please check if you have specified the correct xtype.",
255+
],
256+
),
257+
# UC4: value is too far from any element in the array (because of wrong xtype)
258+
(
259+
[2 * np.pi, np.array([30, 60, 90]), np.array([1, 2, 3]), "tth", "q", 30],
260+
[
261+
IndexError,
262+
"The value 30 is too far from any value in the 'q' array. "
263+
"Please check if you have specified the correct xtype.",
264+
],
242265
),
243266
]
244267

245268

246269
@pytest.mark.parametrize("inputs, expected", params_index_bad)
247-
def test_get_angle_index_bad(inputs, expected):
270+
def test_get_array_index_bad(inputs, expected):
248271
test = DiffractionObject(wavelength=inputs[0], xarray=inputs[1], yarray=inputs[2], xtype=inputs[3])
249272
with pytest.raises(expected[0], match=re.escape(expected[1])):
250-
test.get_array_index(xtype=inputs[4], value=inputs[5])
273+
test.get_array_index(value=inputs[5], xtype=inputs[4])
251274

252275

253276
def test_dump(tmp_path, mocker):

0 commit comments

Comments
 (0)