Skip to content

Commit 19e64f6

Browse files
Merge branch 'main' into muD
2 parents becbe94 + 103f4dc commit 19e64f6

File tree

4 files changed

+89
-84
lines changed

4 files changed

+89
-84
lines changed

news/fastcalto7.rst

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
**Added:**
2+
3+
* Fast calculation supports values up to muD = 7
4+
5+
**Changed:**
6+
7+
* Default to brute-force computation when muD < 0.5 or > 7.
8+
* Print a warning message instead of error, explicitly stating the input muD value
9+
10+
**Deprecated:**
11+
12+
* <news item>
13+
14+
**Removed:**
15+
16+
* <news item>
17+
18+
**Fixed:**
19+
20+
* <news item>
21+
22+
**Security:**
23+
24+
* <news item>
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
-6.543824057313183,2.517676049711881e-10,-430.98705781061807,-1390.1242894055792,-2233.6043531698424,-2624.226082870084,-2590.22123249732
2-
5.020666247920394,-4.676545705328227e-10,875.3743680589598,2744.521375373735,4366.230491983095,5113.620084734874,5051.51673764446
3-
1.8615172482459748,3.25649449911046e-10,-663.8228448830648,-2027.2101338882571,-3194.0427904869357,-3728.3876362721985,-3684.6430045258744
4-
-1.9688714316525813,0.9999999998992586,224.64320300800793,666.0546805817288,1038.3830530062712,1207.4338980088214,1193.1790131535502
5-
0.9819818977427489,1.1680695590205036e-11,-28.533774174029173,-82.17596308875525,-126.6770208129477,-146.65755463691826,-144.8506009433783
1+
-20619.128648244743,2.4364997877410417e-06,26505.585137008606,-337279.3213902739,-1056051.2792994028,-1815539.1160187968,-2372171.980894541,-2642993.4538858426
2+
56203.20048346139,-6.774323967911372e-06,-51225.01460831775,1005023.6254387972,3051609.2101441943,5201294.212075991,6773264.754466731,7538833.212217793
3+
-63802.106117440504,7.845612740225726e-06,33089.880661840034,-1242527.9972772636,-3668761.3028854067,-6202335.40819644,-8050708.344579743,-8951452.327576341
4+
38603.18842240787,-4.844608638276231e-06,-4028.752491140328,816333.2796219026,2349291.871863084,3940828.5623758254,5099123.959731602,5663734.051678175
5+
-13126.533425725584,1.6822282153058894e-06,-4395.4012467221355,-300757.7094990732,-845218.597424347,-1407243.8630987857,-1815244.4369415767,-2014105.8565458413
6+
2378.155272695758,0.9999996885549859,1912.0755691304305,58944.34388638723,162014.93660541167,267802.2088119374,344395.5526651557,381710.2334140182
7+
-178.70587585316136,2.4018054840276848e-08,-234.76082555771987,-4803.1698085743365,-12928.521798999534,-21220.235622246797,-27207.092578054173,-30121.285267280207

src/diffpy/labpdfproc/functions.py

+16-27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import warnings
23
from pathlib import Path
34

45
import numpy as np
@@ -16,7 +17,7 @@
1617
CVE_METHODS = ["brute_force", "polynomial_interpolation"]
1718

1819
# Pre-computed datasets for polynomial interpolation (fast calculation)
19-
MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6]
20+
MUD_LIST = np.array([0.5, 1, 2, 3, 4, 5, 6, 7])
2021
CWD = Path(__file__).parent.resolve()
2122
MULS = np.loadtxt(CWD / "data" / "inverse_cve.xy")
2223
COEFFICIENT_LIST = np.array(
@@ -74,7 +75,6 @@ def _get_entry_exit_coordinates(self, coordinate, angle):
7475
----------
7576
coordinate : tuple of floats
7677
The coordinates of the grid point.
77-
7878
angle : float
7979
The angle in degrees.
8080
@@ -90,9 +90,7 @@ def _get_entry_exit_coordinates(self, coordinate, angle):
9090
angle = math.radians(angle)
9191
xgrid = coordinate[0]
9292
ygrid = coordinate[1]
93-
9493
entry_point = (-math.sqrt(self.radius**2 - ygrid**2), ygrid)
95-
9694
if not math.isclose(angle, math.pi / 2, abs_tol=epsilon):
9795
b = ygrid - xgrid * math.tan(angle)
9896
a = math.tan(angle)
@@ -107,7 +105,6 @@ def _get_entry_exit_coordinates(self, coordinate, angle):
107105
exit_point = (xexit_root1, yexit_root1)
108106
else:
109107
exit_point = (xgrid, math.sqrt(self.radius**2 - xgrid**2))
110-
111108
return entry_point, exit_point
112109

113110
def _get_path_length(self, grid_point, angle):
@@ -119,7 +116,6 @@ def _get_path_length(self, grid_point, angle):
119116
----------
120117
grid_point : double of floats
121118
The coordinate inside the circle.
122-
123119
angle : float
124120
The angle of the output beam in degrees.
125121
@@ -129,7 +125,6 @@ def _get_path_length(self, grid_point, angle):
129125
The tuple containing three floats,
130126
which are the total distance, entry distance and exit distance.
131127
"""
132-
133128
# move angle a tad above zero if it is zero
134129
# to avoid it having the wrong sign due to some rounding error
135130
angle_delta = 0.000001
@@ -181,7 +176,9 @@ def _cve_brute_force(input_pattern, mud):
181176
Assume mu=mud/2, given that the same mu*D yields the same cve and D/2=1.
182177
"""
183178
mu_sample_invmm = mud / 2
184-
abs_correction = Gridded_circle(mu=mu_sample_invmm)
179+
abs_correction = Gridded_circle(
180+
n_points_on_diameter=N_POINTS_ON_DIAMETER, mu=mu_sample_invmm
181+
)
185182
distances, muls = [], []
186183
for angle in TTH_GRID:
187184
abs_correction.set_distances_at_angle(angle)
@@ -191,7 +188,6 @@ def _cve_brute_force(input_pattern, mud):
191188
distances = np.array(distances) / abs_correction.total_points_in_grid
192189
muls = np.array(muls) / abs_correction.total_points_in_grid
193190
cve = 1 / muls
194-
195191
cve_do = DiffractionObject(
196192
xarray=TTH_GRID,
197193
yarray=cve,
@@ -206,28 +202,21 @@ def _cve_brute_force(input_pattern, mud):
206202

207203
def _cve_polynomial_interpolation(input_pattern, mud):
208204
"""Compute cve using polynomial interpolation method,
209-
raise an error if the mu*D value is out of the range (0.5 to 6).
205+
default to brute-force computation if mu*D is
206+
out of the range (0.5 to 7).
210207
"""
211-
if mud > 6 or mud < 0.5:
212-
raise ValueError(
213-
f"mu*D is out of the acceptable range (0.5 to 6) "
208+
if mud > 7 or mud < 0.5:
209+
warnings.warn(
210+
f"Input mu*D = {mud} is out of the acceptable range "
211+
f"({np.min(MUD_LIST)} to {np.max(MUD_LIST)}) "
214212
f"for polynomial interpolation. "
215-
f"Please rerun with a value within this range "
216-
f"or specifying another method from {*CVE_METHODS, }."
213+
f"Proceeding with brute-force computation. "
217214
)
218-
coeff_a, coeff_b, coeff_c, coeff_d, coeff_e = [
219-
interpolation_function(mud)
220-
for interpolation_function in INTERPOLATION_FUNCTIONS
221-
]
222-
muls = np.array(
223-
coeff_a * MULS**4
224-
+ coeff_b * MULS**3
225-
+ coeff_c * MULS**2
226-
+ coeff_d * MULS
227-
+ coeff_e
228-
)
229-
cve = 1 / muls
215+
return _cve_brute_force(input_pattern, mud)
230216

217+
coeffs = np.array([f(mud) for f in INTERPOLATION_FUNCTIONS])
218+
muls = np.polyval(coeffs, MULS)
219+
cve = 1 / muls
231220
cve_do = DiffractionObject(
232221
xarray=TTH_GRID,
233222
yarray=cve,

tests/test_functions.py

+42-52
Original file line numberDiff line numberDiff line change
@@ -105,49 +105,55 @@ def test_set_muls_at_angle(input_mu, expected_muls):
105105

106106

107107
@pytest.mark.parametrize(
108-
"input_xtype, expected",
109-
[
110-
(
111-
"tth",
108+
"input_diffraction_data, input_cve_params",
109+
[ # Test that cve diffraction object contains the expected info
110+
# Note that all cve values are interpolated to 0.5
111+
# cve do should contain the same input xarray, xtype,
112+
# wavelength, and metadata
113+
( # C1: User did not specify method, default to fast calculation
112114
{
113115
"xarray": np.array([90, 90.1, 90.2]),
114-
"yarray": np.array([0.5, 0.5, 0.5]),
115-
"xtype": "tth",
116+
"yarray": np.array([2, 2, 2]),
116117
},
118+
{"mud": 1, "xtype": "tth"},
117119
),
118-
(
119-
"q",
120+
( # C2: User specified brute-force computation method
120121
{
121-
"xarray": np.array([5.76998, 5.77501, 5.78004]),
122-
"yarray": np.array([0.5, 0.5, 0.5]),
123-
"xtype": "q",
122+
"xarray": np.array([5.1, 5.2, 5.3]),
123+
"yarray": np.array([2, 2, 2]),
124124
},
125+
{"mud": 1, "method": "brute_force", "xtype": "q"},
126+
),
127+
( # C3: User specified mu*D outside the fast calculation range,
128+
# default to brute-force computation
129+
{
130+
"xarray": np.array([5.1, 5.2, 5.3]),
131+
"yarray": np.array([2, 2, 2]),
132+
},
133+
{"mud": 20, "xtype": "q"},
125134
),
126135
],
127136
)
128-
def test_compute_cve(input_xtype, expected, mocker):
129-
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
137+
def test_compute_cve(mocker, input_diffraction_data, input_cve_params):
138+
expected_xarray = input_diffraction_data["xarray"]
130139
expected_cve = np.array([0.5, 0.5, 0.5])
140+
expected_xtype = input_cve_params["xtype"]
141+
mocker.patch("diffpy.labpdfproc.functions.N_POINTS_ON_DIAMETER", 4)
131142
mocker.patch("numpy.interp", return_value=expected_cve)
132143
input_pattern = DiffractionObject(
133-
xarray=xarray,
134-
yarray=yarray,
135-
xtype="tth",
144+
xarray=input_diffraction_data["xarray"],
145+
yarray=input_diffraction_data["yarray"],
146+
xtype=input_cve_params["xtype"],
136147
wavelength=1.54,
137148
scat_quantity="x-ray",
138149
name="test",
139150
metadata={"thing1": 1, "thing2": "thing2"},
140151
)
141-
actual_cve_do = compute_cve(
142-
input_pattern,
143-
mud=1,
144-
method="polynomial_interpolation",
145-
xtype=input_xtype,
146-
)
152+
actual_cve_do = compute_cve(input_pattern, **input_cve_params)
147153
expected_cve_do = DiffractionObject(
148-
xarray=expected["xarray"],
149-
yarray=expected["yarray"],
150-
xtype=expected["xtype"],
154+
xarray=expected_xarray,
155+
yarray=expected_cve,
156+
xtype=expected_xtype,
151157
wavelength=1.54,
152158
scat_quantity="cve",
153159
name="absorption correction, cve, for test",
@@ -156,32 +162,9 @@ def test_compute_cve(input_xtype, expected, mocker):
156162
assert actual_cve_do == expected_cve_do
157163

158164

159-
@pytest.mark.parametrize(
160-
"inputs, msg",
161-
[
162-
(
163-
{"mud": 7, "method": "polynomial_interpolation"},
164-
f"mu*D is out of the acceptable range (0.5 to 6) "
165-
f"for polynomial interpolation. "
166-
f"Please rerun with a value within this range "
167-
f"or specifying another method from {*CVE_METHODS, }.",
168-
),
169-
(
170-
{"mud": 1, "method": "invalid_method"},
171-
f"Unknown method: invalid_method. "
172-
f"Allowed methods are {*CVE_METHODS, }.",
173-
),
174-
(
175-
{"mud": 7, "method": "invalid_method"},
176-
f"Unknown method: invalid_method. "
177-
f"Allowed methods are {*CVE_METHODS, }.",
178-
),
179-
],
180-
)
181-
def test_compute_cve_bad(mocker, inputs, msg):
165+
def test_compute_cve_bad(mocker):
182166
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
183167
expected_cve = np.array([0.5, 0.5, 0.5])
184-
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
185168
mocker.patch("numpy.interp", return_value=expected_cve)
186169
input_pattern = DiffractionObject(
187170
xarray=xarray,
@@ -192,14 +175,21 @@ def test_compute_cve_bad(mocker, inputs, msg):
192175
name="test",
193176
metadata={"thing1": 1, "thing2": "thing2"},
194177
)
195-
with pytest.raises(ValueError, match=re.escape(msg)):
196-
compute_cve(input_pattern, mud=inputs["mud"], method=inputs["method"])
178+
# Test that the function raises a ValueError
179+
# when an invalid method is provided
180+
with pytest.raises(
181+
ValueError,
182+
match=re.escape(
183+
f"Unknown method: invalid_method. "
184+
f"Allowed methods are {*CVE_METHODS, }."
185+
),
186+
):
187+
compute_cve(input_pattern, mud=1, method="invalid_method")
197188

198189

199190
def test_apply_corr(mocker):
200191
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
201192
expected_cve = np.array([0.5, 0.5, 0.5])
202-
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
203193
mocker.patch("numpy.interp", return_value=expected_cve)
204194
input_pattern = DiffractionObject(
205195
xarray=xarray,

0 commit comments

Comments
 (0)