Skip to content

Commit 5257ee2

Browse files
add more bad tests
1 parent 1acc80f commit 5257ee2

File tree

2 files changed

+47
-24
lines changed

2 files changed

+47
-24
lines changed

src/diffpy/utils/diffraction_objects.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -248,25 +248,27 @@ def _set_array_from_range(self, begin, end, step_size=None, n_steps=None):
248248
array = np.linspace(begin, end, n_steps)
249249
return array
250250

251-
def get_angle_index(self, angle):
251+
def get_array_index(self, xtype, value):
252252
"""
253-
returns the index of a given angle in the angles list
253+
returns the index of a given value in the array associated with the specified xtype
254254
255255
Parameters
256256
----------
257-
angle float
258-
the angle to search for
257+
xtype str
258+
the xtype used to access the array
259+
value float
260+
the target value to search for
259261
260262
Returns
261263
-------
262-
the index of the angle in the angles list
264+
the index of the value in the array
263265
"""
264-
if not hasattr(self, "angles"):
265-
self.angles = np.array([])
266-
for i, target in enumerate(self.angles):
267-
if angle == target:
266+
if self.on_xtype(xtype) is None:
267+
raise ValueError(_xtype_wmsg(xtype))
268+
for i, target in enumerate(self.on_xtype(xtype)[0]):
269+
if value == target:
268270
return i
269-
raise IndexError(f"WARNING: no angle {angle} found in angles list.")
271+
raise IndexError(f"WARNING: no matching value {value} found in the {xtype} array.")
270272

271273
def _set_xarrays(self, xarray, xtype):
272274
self.all_arrays = np.empty(shape=(len(xarray), 4))

tests/test_diffraction_objects.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import re
12
from pathlib import Path
23

34
import numpy as np
45
import pytest
56
from freezegun import freeze_time
67

7-
from diffpy.utils.diffraction_objects import DiffractionObject
8+
from diffpy.utils.diffraction_objects import XQUANTITIES, DiffractionObject
89
from diffpy.utils.transforms import wavelength_warning_emsg
910

1011

@@ -212,21 +213,41 @@ def _test_valid_diffraction_objects(actual_diffraction_object, function, expecte
212213

213214

214215
def test_get_angle_index():
215-
test = DiffractionObject()
216-
test.angles = np.array([10, 20, 30, 40, 50, 60])
217-
actual_angle_index = test.get_angle_index(angle=10)
218-
assert actual_angle_index == 0
216+
test = DiffractionObject(
217+
wavelength=0.71, xarray=np.array([30, 60, 90]), yarray=np.array([1, 2, 3]), xtype="tth"
218+
)
219+
actual_index = test.get_array_index(xtype="tth", value=30)
220+
assert actual_index == 0
221+
222+
223+
params_index_bad = [
224+
# UC1: empty array
225+
(
226+
[0.71, np.array([]), np.array([]), "tth", "tth", 10],
227+
[IndexError, "WARNING: no matching value 10 found in the tth array."],
228+
),
229+
# UC2: invalid xtype
230+
(
231+
[None, np.array([]), np.array([]), "tth", "invalid", 10],
232+
[
233+
ValueError,
234+
f"WARNING: I don't know how to handle the xtype, 'invalid'. "
235+
f"Please rerun specifying an xtype from {*XQUANTITIES, }",
236+
],
237+
),
238+
# UC3: pre-defined array with non-matching value
239+
(
240+
[0.71, np.array([30, 60, 90]), np.array([1, 2, 3]), "tth", "q", 30],
241+
[IndexError, "WARNING: no matching value 30 found in the q array."],
242+
),
243+
]
219244

220245

221-
def test_get_angle_index_bad():
222-
test = DiffractionObject()
223-
# empty angles list
224-
with pytest.raises(IndexError, match="WARNING: no angle 11 found in angles list."):
225-
test.get_angle_index(angle=11)
226-
# pre-defined angles list
227-
test.angles = np.array([10, 20, 30, 40, 50, 60])
228-
with pytest.raises(IndexError, match="WARNING: no angle 11 found in angles list."):
229-
test.get_angle_index(angle=11)
246+
@pytest.mark.parametrize("inputs, expected", params_index_bad)
247+
def test_get_angle_index_bad(inputs, expected):
248+
test = DiffractionObject(wavelength=inputs[0], xarray=inputs[1], yarray=inputs[2], xtype=inputs[3])
249+
with pytest.raises(expected[0], match=re.escape(expected[1])):
250+
test.get_array_index(xtype=inputs[4], value=inputs[5])
230251

231252

232253
def test_dump(tmp_path, mocker):

0 commit comments

Comments
 (0)