Skip to content

Commit 3bd24cd

Browse files
remove unnecessary interpolation function, edit test for compute_cve for different xtypes
1 parent 45039e3 commit 3bd24cd

File tree

2 files changed

+11
-72
lines changed

2 files changed

+11
-72
lines changed

src/diffpy/labpdfproc/functions.py

+3-37
Original file line numberDiff line numberDiff line change
@@ -244,39 +244,6 @@ def _cve_method(method):
244244
return methods[method]
245245

246246

247-
def interpolate_to_xtype_grid(cve_do, xtype):
248-
f"""
249-
interpolates the cve grid to the xtype user specified
250-
251-
Parameters
252-
----------
253-
cve_do Diffraction_object
254-
the diffraction object that contains the cve to be applied
255-
xtype str
256-
the quantity on the independent variable axis, allowed values are {*XQUANTITIES, }
257-
258-
Returns
259-
-------
260-
the new diffraction object with interpolated cve curves
261-
"""
262-
if xtype == "tth":
263-
return cve_do
264-
265-
orig_grid, orig_cve = cve_do.on_tth[0], cve_do.on_tth[1]
266-
new_grid = cve_do.tth_to_q()
267-
new_cve = np.interp(new_grid, orig_grid, orig_cve)
268-
new_cve_do = Diffraction_object(wavelength=cve_do.wavelength)
269-
new_cve_do.insert_scattering_quantity(
270-
new_grid,
271-
new_cve,
272-
xtype,
273-
metadata=cve_do.metadata,
274-
name=cve_do.name,
275-
scat_quantity="cve",
276-
)
277-
return new_cve_do
278-
279-
280247
def compute_cve(diffraction_data, mud, method="polynomial_interpolation", xtype="tth"):
281248
f"""
282249
compute and interpolate the cve for the given diffraction data and mud using the selected method
@@ -298,11 +265,10 @@ def compute_cve(diffraction_data, mud, method="polynomial_interpolation", xtype=
298265
"""
299266

300267
cve_function = _cve_method(method)
301-
cve_do_on_global_tth = cve_function(diffraction_data, mud)
302-
cve_do_on_global_xtype = interpolate_to_xtype_grid(cve_do_on_global_tth, xtype)
268+
cve_do_on_global_grid = cve_function(diffraction_data, mud)
303269
orig_grid = diffraction_data.on_xtype(xtype)[0]
304-
global_xtype = cve_do_on_global_xtype.on_xtype(xtype)[0]
305-
cve_on_global_xtype = cve_do_on_global_xtype.on_xtype(xtype)[1]
270+
global_xtype = cve_do_on_global_grid.on_xtype(xtype)[0]
271+
cve_on_global_xtype = cve_do_on_global_grid.on_xtype(xtype)[1]
306272
newcve = np.interp(orig_grid, global_xtype, cve_on_global_xtype)
307273
cve_do = Diffraction_object(wavelength=diffraction_data.wavelength)
308274
cve_do.insert_scattering_quantity(

tests/test_functions.py

+8-35
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,7 @@
33
import numpy as np
44
import pytest
55

6-
from diffpy.labpdfproc.functions import (
7-
CVE_METHODS,
8-
Gridded_circle,
9-
apply_corr,
10-
compute_cve,
11-
interpolate_to_xtype_grid,
12-
)
6+
from diffpy.labpdfproc.functions import CVE_METHODS, Gridded_circle, apply_corr, compute_cve
137
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
148

159
params1 = [
@@ -78,44 +72,23 @@ def _instantiate_test_do(xarray, yarray, xtype="tth", name="test", scat_quantity
7872

7973

8074
params4 = [
81-
([np.array([30, 60, 90]), np.array([1, 2, 3]), "tth"], [np.array([30, 60, 90]), np.array([1, 2, 3]), "tth"]),
82-
(
83-
[np.array([30, 60, 90]), np.array([1, 2, 3]), "q"],
84-
[np.array([2.11195, 4.07999, 5.76998]), np.array([1, 1, 1]), "q"],
85-
),
75+
(["tth"], [np.array([90, 90.1, 90.2]), np.array([0.5, 0.5, 0.5]), "tth"]),
76+
(["q"], [np.array([5.76998, 5.77501, 5.78004]), np.array([0.5, 0.5, 0.5]), "q"]),
8677
]
8778

8879

8980
@pytest.mark.parametrize("inputs, expected", params4)
90-
def test_interpolate_xtype(inputs, expected, mocker):
91-
expected_cve_do = _instantiate_test_do(
92-
expected[0],
93-
expected[1],
94-
xtype=expected[2],
95-
name="absorption correction, cve, for test",
96-
scat_quantity="cve",
97-
)
98-
input_cve_do = _instantiate_test_do(
99-
inputs[0],
100-
inputs[1],
101-
xtype="tth",
102-
name="absorption correction, cve, for test",
103-
scat_quantity="cve",
104-
)
105-
actual_cve_do = interpolate_to_xtype_grid(input_cve_do, xtype=inputs[2])
106-
assert actual_cve_do == expected_cve_do
107-
108-
109-
def test_compute_cve(mocker):
81+
def test_compute_cve(inputs, expected, mocker):
11082
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
11183
expected_cve = np.array([0.5, 0.5, 0.5])
11284
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
11385
mocker.patch("numpy.interp", return_value=expected_cve)
11486
input_pattern = _instantiate_test_do(xarray, yarray)
115-
actual_cve_do = compute_cve(input_pattern, mud=1)
87+
actual_cve_do = compute_cve(input_pattern, mud=1, method="polynomial_interpolation", xtype=inputs[0])
11688
expected_cve_do = _instantiate_test_do(
117-
xarray,
118-
expected_cve,
89+
expected[0],
90+
expected[1],
91+
expected[2],
11992
name="absorption correction, cve, for test",
12093
scat_quantity="cve",
12194
)

0 commit comments

Comments
 (0)