Skip to content

Commit 45039e3

Browse files
add function in tools to set xtype to tth or q
1 parent 1811fe8 commit 45039e3

File tree

4 files changed

+52
-26
lines changed

4 files changed

+52
-26
lines changed

src/diffpy/labpdfproc/functions.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,7 @@
55
import pandas as pd
66
from scipy.interpolate import interp1d
77

8-
from diffpy.utils.scattering_objects.diffraction_objects import (
9-
ANGLEQUANTITIES,
10-
DQUANTITIES,
11-
XQUANTITIES,
12-
Diffraction_object,
13-
)
8+
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
149

1510
RADIUS_MM = 1
1611
N_POINTS_ON_DIAMETER = 300
@@ -251,7 +246,7 @@ def _cve_method(method):
251246

252247
def interpolate_to_xtype_grid(cve_do, xtype):
253248
f"""
254-
interpolates the cve grid to the xtype user specifies, raise an error if xtype is invalid
249+
interpolates the cve grid to the xtype user specified
255250
256251
Parameters
257252
----------
@@ -264,10 +259,7 @@ def interpolate_to_xtype_grid(cve_do, xtype):
264259
-------
265260
the new diffraction object with interpolated cve curves
266261
"""
267-
268-
if xtype.lower() not in XQUANTITIES:
269-
raise ValueError(f"Unknown xtype: {xtype}. Allowed xtypes are {*XQUANTITIES, }.")
270-
if xtype.lower() in ANGLEQUANTITIES or xtype.lower() in DQUANTITIES:
262+
if xtype == "tth":
271263
return cve_do
272264

273265
orig_grid, orig_cve = cve_do.on_tth[0], cve_do.on_tth[1]

src/diffpy/labpdfproc/tools.py

+21
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33

44
from diffpy.labpdfproc.mud_calculator import compute_mud
5+
from diffpy.utils.scattering_objects.diffraction_objects import QQUANTITIES, XQUANTITIES
56
from diffpy.utils.tools import get_package_info, get_user_info
67

78
WAVELENGTHS = {"Mo": 0.71, "Ag": 0.59, "Cu": 1.54}
@@ -135,6 +136,25 @@ def set_wavelength(args):
135136
return args
136137

137138

139+
def set_xtype(args):
140+
f"""
141+
Set the xtype based on the given input arguments, raise an error if xtype is not one of {*XQUANTITIES, }
142+
143+
Parameters
144+
----------
145+
args argparse.Namespace
146+
the arguments from the parser
147+
148+
Returns
149+
-------
150+
args argparse.Namespace
151+
"""
152+
if args.xtype.lower() not in XQUANTITIES:
153+
raise ValueError(f"Unknown xtype: {args.xtype}. Allowed xtypes are {*XQUANTITIES, }.")
154+
args.xtype = "q" if args.xtype.lower() in QQUANTITIES else "tth"
155+
return args
156+
157+
138158
def set_mud(args):
139159
"""
140160
Set the mud based on the given input arguments
@@ -257,6 +277,7 @@ def preprocessing_args(args):
257277
args = set_input_lists(args)
258278
args.output_directory = set_output_directory(args)
259279
args = set_wavelength(args)
280+
args = set_xtype(args)
260281
args = set_mud(args)
261282
args = load_user_metadata(args)
262283
return args

tests/test_functions.py

+1-15
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
compute_cve,
1111
interpolate_to_xtype_grid,
1212
)
13-
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
13+
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
1414

1515
params1 = [
1616
([0.5, 3, 1], {(0.0, -0.5), (0.0, 0.0), (0.5, 0.0), (-0.5, 0.0), (0.0, 0.5)}),
@@ -106,20 +106,6 @@ def test_interpolate_xtype(inputs, expected, mocker):
106106
assert actual_cve_do == expected_cve_do
107107

108108

109-
def test_interpolate_xtype_bad():
110-
input_cve_do = _instantiate_test_do(
111-
np.array([30, 60, 90]),
112-
np.array([1, 2, 3]),
113-
xtype="tth",
114-
name="absorption correction, cve, for test",
115-
scat_quantity="cve",
116-
)
117-
with pytest.raises(
118-
ValueError, match=re.escape(f"Unknown xtype: invalid. Allowed xtypes are {*XQUANTITIES, }.")
119-
):
120-
interpolate_to_xtype_grid(input_cve_do, xtype="invalid")
121-
122-
123109
def test_compute_cve(mocker):
124110
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
125111
expected_cve = np.array([0.5, 0.5, 0.5])

tests/test_tools.py

+27
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
set_mud,
1717
set_output_directory,
1818
set_wavelength,
19+
set_xtype,
1920
)
21+
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES
2022

2123
# Use cases can be found here: https://github.com/diffpy/diffpy.labpdfproc/issues/48
2224

@@ -189,6 +191,31 @@ def test_set_wavelength_bad(inputs, msg):
189191
actual_args = set_wavelength(actual_args)
190192

191193

194+
params4 = [
195+
([], ["tth"]),
196+
(["--xtype", "2theta"], ["tth"]),
197+
(["--xtype", "d"], ["tth"]),
198+
(["--xtype", "q"], ["q"]),
199+
]
200+
201+
202+
@pytest.mark.parametrize("inputs, expected", params4)
203+
def test_set_xtype(inputs, expected):
204+
cli_inputs = ["2.5", "data.xy"] + inputs
205+
actual_args = get_args(cli_inputs)
206+
actual_args = set_xtype(actual_args)
207+
assert actual_args.xtype == expected[0]
208+
209+
210+
def test_set_xtype_bad():
211+
cli_inputs = ["2.5", "data.xy", "--xtype", "invalid"]
212+
actual_args = get_args(cli_inputs)
213+
with pytest.raises(
214+
ValueError, match=re.escape(f"Unknown xtype: invalid. Allowed xtypes are {*XQUANTITIES, }.")
215+
):
216+
actual_args = set_xtype(actual_args)
217+
218+
192219
def test_set_mud(user_filesystem):
193220
cli_inputs = ["2.5", "data.xy"]
194221
actual_args = get_args(cli_inputs)

0 commit comments

Comments
 (0)