|
5 | 5 | import pytest
|
6 | 6 | from freezegun import freeze_time
|
7 | 7 |
|
8 |
| -from diffpy.utils.diffraction_objects import XQUANTITIES, DiffractionObject |
| 8 | +from diffpy.utils.diffraction_objects import DiffractionObject |
9 | 9 | from diffpy.utils.transforms import wavelength_warning_emsg
|
10 | 10 |
|
11 | 11 |
|
@@ -212,42 +212,65 @@ def _test_valid_diffraction_objects(actual_diffraction_object, function, expecte
|
212 | 212 | return np.allclose(actual_array, expected_array)
|
213 | 213 |
|
214 | 214 |
|
215 |
| -def test_get_angle_index(): |
216 |
| - test = DiffractionObject( |
217 |
| - wavelength=0.71, xarray=np.array([30, 60, 90]), yarray=np.array([1, 2, 3]), xtype="tth" |
218 |
| - ) |
219 |
| - actual_index = test.get_array_index(xtype="tth", value=30) |
220 |
| - assert actual_index == 0 |
| 215 | +params_index = [ |
| 216 | + # UC1: exact match |
| 217 | + ([4 * np.pi, np.array([30.005, 60]), np.array([1, 2]), "tth", "tth", 30.005], [0]), |
| 218 | + # UC2: target value lies in the array, returns the (first) closest index |
| 219 | + ([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "tth", 45], [0]), |
| 220 | + ([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "q", 0.25], [0]), |
| 221 | + # UC3: target value out of the range but within reasonable distance, returns the closest index |
| 222 | + ([4 * np.pi, np.array([0.25, 0.5, 0.71]), np.array([1, 2, 3]), "q", "q", 0.1], [0]), |
| 223 | + ([4 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "tth", 63], [1]), |
| 224 | +] |
| 225 | + |
| 226 | + |
| 227 | +@pytest.mark.parametrize("inputs, expected", params_index) |
| 228 | +def test_get_array_index(inputs, expected): |
| 229 | + test = DiffractionObject(wavelength=inputs[0], xarray=inputs[1], yarray=inputs[2], xtype=inputs[3]) |
| 230 | + actual = test.get_array_index(value=inputs[5], xtype=inputs[4]) |
| 231 | + assert actual == expected[0] |
221 | 232 |
|
222 | 233 |
|
223 | 234 | params_index_bad = [
|
224 |
| - # UC1: empty array |
| 235 | + # UC0: empty array |
225 | 236 | (
|
226 |
| - [0.71, np.array([]), np.array([]), "tth", "tth", 10], |
227 |
| - [IndexError, "WARNING: no matching value 10 found in the tth array."], |
| 237 | + [2 * np.pi, np.array([]), np.array([]), "tth", "tth", 30], |
| 238 | + [ValueError, "The 'tth' array is empty. Please ensure it is initialized and the correct xtype is used."], |
228 | 239 | ),
|
229 |
| - # UC2: invalid xtype |
| 240 | + # UC1: empty array (because of invalid xtype) |
230 | 241 | (
|
231 |
| - [None, np.array([]), np.array([]), "tth", "invalid", 10], |
| 242 | + [2 * np.pi, np.array([30, 60]), np.array([1, 2]), "tth", "invalid", 30], |
232 | 243 | [
|
233 | 244 | ValueError,
|
234 |
| - f"WARNING: I don't know how to handle the xtype, 'invalid'. " |
235 |
| - f"Please rerun specifying an xtype from {*XQUANTITIES, }", |
| 245 | + "The 'invalid' array is empty. Please ensure it is initialized and the correct xtype is used.", |
236 | 246 | ],
|
237 | 247 | ),
|
238 |
| - # UC3: pre-defined array with non-matching value |
| 248 | + # UC3: value is too far from any element in the array |
239 | 249 | (
|
240 |
| - [0.71, np.array([30, 60, 90]), np.array([1, 2, 3]), "tth", "q", 30], |
241 |
| - [IndexError, "WARNING: no matching value 30 found in the q array."], |
| 250 | + [2 * np.pi, np.array([30, 60, 90]), np.array([1, 2, 3]), "tth", "tth", 140], |
| 251 | + [ |
| 252 | + IndexError, |
| 253 | + "The value 140 is too far from any value in the 'tth' array. " |
| 254 | + "Please check if you have specified the correct xtype.", |
| 255 | + ], |
| 256 | + ), |
| 257 | + # UC4: value is too far from any element in the array (because of wrong xtype) |
| 258 | + ( |
| 259 | + [2 * np.pi, np.array([30, 60, 90]), np.array([1, 2, 3]), "tth", "q", 30], |
| 260 | + [ |
| 261 | + IndexError, |
| 262 | + "The value 30 is too far from any value in the 'q' array. " |
| 263 | + "Please check if you have specified the correct xtype.", |
| 264 | + ], |
242 | 265 | ),
|
243 | 266 | ]
|
244 | 267 |
|
245 | 268 |
|
246 | 269 | @pytest.mark.parametrize("inputs, expected", params_index_bad)
|
247 |
| -def test_get_angle_index_bad(inputs, expected): |
| 270 | +def test_get_array_index_bad(inputs, expected): |
248 | 271 | test = DiffractionObject(wavelength=inputs[0], xarray=inputs[1], yarray=inputs[2], xtype=inputs[3])
|
249 | 272 | with pytest.raises(expected[0], match=re.escape(expected[1])):
|
250 |
| - test.get_array_index(xtype=inputs[4], value=inputs[5]) |
| 273 | + test.get_array_index(value=inputs[5], xtype=inputs[4]) |
251 | 274 |
|
252 | 275 |
|
253 | 276 | def test_dump(tmp_path, mocker):
|
|
0 commit comments