From 758410596d7d6e7c6653428c0db17a7ad3527f35 Mon Sep 17 00:00:00 2001 From: yucongalicechen Date: Wed, 15 Jan 2025 14:46:32 -0500 Subject: [PATCH 1/6] fix: fix test functions using new group standard, remove _instantiate_test_do --- news/docstring-tests.rst | 23 +++++ tests/test_functions.py | 202 +++++++++++++++++++++++---------------- 2 files changed, 145 insertions(+), 80 deletions(-) create mode 100644 news/docstring-tests.rst diff --git a/news/docstring-tests.rst b/news/docstring-tests.rst new file mode 100644 index 0000000..790d30b --- /dev/null +++ b/news/docstring-tests.rst @@ -0,0 +1,23 @@ +**Added:** + +* + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/tests/test_functions.py b/tests/test_functions.py index 93a067a..6691f73 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -6,116 +6,139 @@ from diffpy.labpdfproc.functions import CVE_METHODS, Gridded_circle, apply_corr, compute_cve from diffpy.utils.diffraction_objects import DiffractionObject -params1 = [ - ([0.5, 3, 1], {(0.0, -0.5), (0.0, 0.0), (0.5, 0.0), (-0.5, 0.0), (0.0, 0.5)}), - ([1, 4, 1], {(-0.333333, -0.333333), (0.333333, -0.333333), (-0.333333, 0.333333), (0.333333, 0.333333)}), -] - -@pytest.mark.parametrize("inputs, expected", params1) -def test_get_grid_points(inputs, expected): - expected_grid = expected - actual_gs = Gridded_circle(radius=inputs[0], n_points_on_diameter=inputs[1], mu=inputs[2]) +@pytest.mark.parametrize( + "inputs, expected_grid", + [ + ( + {"radius": 0.5, "n_points_on_diameter": 3, "mu": 1}, + {(0.0, -0.5), (0.0, 0.0), (0.5, 0.0), (-0.5, 0.0), (0.0, 0.5)}, + ), + ( + {"radius": 1, "n_points_on_diameter": 4, "mu": 1}, + {(-0.333333, -0.333333), (0.333333, -0.333333), (-0.333333, 0.333333), (0.333333, 0.333333)}, + ), + ], +) +def test_get_grid_points(inputs, expected_grid): + actual_gs = Gridded_circle( + radius=inputs["radius"], n_points_on_diameter=inputs["n_points_on_diameter"], mu=inputs["mu"] + ) actual_grid_sorted = sorted(actual_gs.grid) expected_grid_sorted = sorted(expected_grid) for actual_point, expected_point in zip(actual_grid_sorted, expected_grid_sorted): assert actual_point == pytest.approx(expected_point, rel=1e-4, abs=1e-6) -params2 = [ - ([1, 3, 1, 45], [0, 1.4142135, 1.4142135, 2, 2]), - ([1, 3, 1, 90], [0, 0, 2, 2, 2]), - ([1, 3, 1, 120], [0, 0, 2, 3, 1.73205]), - ([1, 4, 1, 30], [2.057347, 2.044451, 1.621801, 1.813330]), - ([1, 4, 1, 90], [1.885618, 1.885618, 2.552285, 1.218951]), - ([1, 4, 1, 140], [1.139021, 2.200102, 2.744909, 1.451264]), -] - - -@pytest.mark.parametrize("inputs, expected", params2) -def test_set_distances_at_angle(inputs, expected): - expected_distances = expected - actual_gs = Gridded_circle(radius=inputs[0], n_points_on_diameter=inputs[1], mu=inputs[2]) - actual_gs.set_distances_at_angle(inputs[3]) +@pytest.mark.parametrize( + "inputs, expected_distances", + [ + ({"radius": 1, "n_points_on_diameter": 3, "mu": 1, "angle": 45}, [0, 1.4142135, 1.4142135, 2, 2]), + ({"radius": 1, "n_points_on_diameter": 3, "mu": 1, "angle": 90}, [0, 0, 2, 2, 2]), + ({"radius": 1, "n_points_on_diameter": 3, "mu": 1, "angle": 120}, [0, 0, 2, 3, 1.73205]), + ({"radius": 1, "n_points_on_diameter": 4, "mu": 1, "angle": 30}, [2.057347, 2.044451, 1.621801, 1.813330]), + ({"radius": 1, "n_points_on_diameter": 4, "mu": 1, "angle": 90}, [1.885618, 1.885618, 2.552285, 1.218951]), + ( + {"radius": 1, "n_points_on_diameter": 4, "mu": 1, "angle": 140}, + [1.139021, 2.200102, 2.744909, 1.451264], + ), + ], +) +def test_set_distances_at_angle(inputs, expected_distances): + actual_gs = Gridded_circle( + radius=inputs["radius"], n_points_on_diameter=inputs["n_points_on_diameter"], mu=inputs["mu"] + ) + actual_gs.set_distances_at_angle(inputs["angle"]) actual_distances_sorted = sorted(actual_gs.distances) expected_distances_sorted = sorted(expected_distances) assert actual_distances_sorted == pytest.approx(expected_distances_sorted, rel=1e-4, abs=1e-6) -params3 = [ - ([1], [1, 1, 0.135335, 0.049787, 0.176921]), - ([2], [1, 1, 0.018316, 0.002479, 0.031301]), -] - - -@pytest.mark.parametrize("inputs, expected", params3) -def test_set_muls_at_angle(inputs, expected): - expected_muls = expected - actual_gs = Gridded_circle(radius=1, n_points_on_diameter=3, mu=inputs[0]) +@pytest.mark.parametrize( + "input_mu, expected_muls", + [ + (1, [1, 1, 0.135335, 0.049787, 0.176921]), + (2, [1, 1, 0.018316, 0.002479, 0.031301]), + ], +) +def test_set_muls_at_angle(input_mu, expected_muls): + actual_gs = Gridded_circle(radius=1, n_points_on_diameter=3, mu=input_mu) actual_gs.set_muls_at_angle(120) actual_muls_sorted = sorted(actual_gs.muls) expected_muls_sorted = sorted(expected_muls) assert actual_muls_sorted == pytest.approx(expected_muls_sorted, rel=1e-4, abs=1e-6) -def _instantiate_test_do(xarray, yarray, xtype="tth", name="test", scat_quantity="x-ray"): - test_do = DiffractionObject( +@pytest.mark.parametrize( + "input_xtype, expected", + [ + ("tth", {"xarray": np.array([90, 90.1, 90.2]), "yarray": np.array([0.5, 0.5, 0.5]), "xtype": "tth"}), + ( + "q", + {"xarray": np.array([5.76998, 5.77501, 5.78004]), "yarray": np.array([0.5, 0.5, 0.5]), "xtype": "q"}, + ), + ], +) +def test_compute_cve(input_xtype, expected, mocker): + xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2]) + expected_cve = np.array([0.5, 0.5, 0.5]) + mocker.patch("numpy.interp", return_value=expected_cve) + input_pattern = DiffractionObject( xarray=xarray, yarray=yarray, - xtype=xtype, + xtype="tth", wavelength=1.54, - scat_quantity=scat_quantity, - name=name, + scat_quantity="x-ray", + name="test", metadata={"thing1": 1, "thing2": "thing2"}, ) - return test_do - - -params4 = [ - (["tth"], [np.array([90, 90.1, 90.2]), np.array([0.5, 0.5, 0.5]), "tth"]), - (["q"], [np.array([5.76998, 5.77501, 5.78004]), np.array([0.5, 0.5, 0.5]), "q"]), -] - - -@pytest.mark.parametrize("inputs, expected", params4) -def test_compute_cve(inputs, expected, mocker): - xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2]) - expected_cve = np.array([0.5, 0.5, 0.5]) - mocker.patch("numpy.interp", return_value=expected_cve) - input_pattern = _instantiate_test_do(xarray, yarray) - actual_cve_do = compute_cve(input_pattern, mud=1, method="polynomial_interpolation", xtype=inputs[0]) - expected_cve_do = _instantiate_test_do( - xarray=expected[0], - yarray=expected[1], - xtype=expected[2], - name="absorption correction, cve, for test", + actual_cve_do = compute_cve(input_pattern, mud=1, method="polynomial_interpolation", xtype=input_xtype) + expected_cve_do = DiffractionObject( + xarray=expected["xarray"], + yarray=expected["yarray"], + xtype=expected["xtype"], + wavelength=1.54, scat_quantity="cve", + name="absorption correction, cve, for test", + metadata={"thing1": 1, "thing2": "thing2"}, ) assert actual_cve_do == expected_cve_do -params_cve_bad = [ - ( - [7, "polynomial_interpolation"], - [ +@pytest.mark.parametrize( + "inputs, msg", + [ + ( + {"mud": 7, "method": "polynomial_interpolation"}, f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. " - f"Please rerun with a value within this range or specifying another method from {*CVE_METHODS, }." - ], - ), - ([1, "invalid_method"], [f"Unknown method: invalid_method. Allowed methods are {*CVE_METHODS, }."]), - ([7, "invalid_method"], [f"Unknown method: invalid_method. Allowed methods are {*CVE_METHODS, }."]), -] - - -@pytest.mark.parametrize("inputs, msg", params_cve_bad) + f"Please rerun with a value within this range or specifying another method from {*CVE_METHODS, }.", + ), + ( + {"mud": 1, "method": "invalid_method"}, + f"Unknown method: invalid_method. Allowed methods are {*CVE_METHODS, }.", + ), + ( + {"mud": 7, "method": "invalid_method"}, + f"Unknown method: invalid_method. Allowed methods are {*CVE_METHODS, }.", + ), + ], +) def test_compute_cve_bad(mocker, inputs, msg): xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2]) expected_cve = np.array([0.5, 0.5, 0.5]) mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray) mocker.patch("numpy.interp", return_value=expected_cve) - input_pattern = _instantiate_test_do(xarray, yarray) - with pytest.raises(ValueError, match=re.escape(msg[0])): - compute_cve(input_pattern, mud=inputs[0], method=inputs[1]) + input_pattern = DiffractionObject( + xarray=xarray, + yarray=yarray, + xtype="tth", + wavelength=1.54, + scat_quantity="x-ray", + name="test", + metadata={"thing1": 1, "thing2": "thing2"}, + ) + with pytest.raises(ValueError, match=re.escape(msg)): + compute_cve(input_pattern, mud=inputs["mud"], method=inputs["method"]) def test_apply_corr(mocker): @@ -123,13 +146,32 @@ def test_apply_corr(mocker): expected_cve = np.array([0.5, 0.5, 0.5]) mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray) mocker.patch("numpy.interp", return_value=expected_cve) - input_pattern = _instantiate_test_do(xarray, yarray) - absorption_correction = _instantiate_test_do( + input_pattern = DiffractionObject( + xarray=xarray, + yarray=yarray, + xtype="tth", + wavelength=1.54, + scat_quantity="x-ray", + name="test", + metadata={"thing1": 1, "thing2": "thing2"}, + ) + absorption_correction = DiffractionObject( xarray=xarray, yarray=expected_cve, - name="absorption correction, cve, for test", + xtype="tth", + wavelength=1.54, scat_quantity="cve", + name="absorption correction, cve, for test", + metadata={"thing1": 1, "thing2": "thing2"}, ) actual_corr = apply_corr(input_pattern, absorption_correction) - expected_corr = _instantiate_test_do(xarray, np.array([1, 1, 1])) + expected_corr = DiffractionObject( + xarray=xarray, + yarray=np.array([1, 1, 1]), + xtype="tth", + wavelength=1.54, + scat_quantity="x-ray", + name="test", + metadata={"thing1": 1, "thing2": "thing2"}, + ) assert actual_corr == expected_corr From 0532b0e389da8a9054da9ec790a8fb59aaf2eeaa Mon Sep 17 00:00:00 2001 From: yucongalicechen Date: Wed, 22 Jan 2025 15:54:30 -0500 Subject: [PATCH 2/6] docs: refine docstrings for functions --- src/diffpy/labpdfproc/functions.py | 224 +++++++++++++---------------- 1 file changed, 96 insertions(+), 128 deletions(-) diff --git a/src/diffpy/labpdfproc/functions.py b/src/diffpy/labpdfproc/functions.py index 7f93e10..13e783d 100644 --- a/src/diffpy/labpdfproc/functions.py +++ b/src/diffpy/labpdfproc/functions.py @@ -14,7 +14,7 @@ TTH_GRID[-1] = 180.00 CVE_METHODS = ["brute_force", "polynomial_interpolation"] -# pre-computed datasets for polynomial interpolation (fast calculation) +# Pre-computed datasets for polynomial interpolation (fast calculation) MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6] CWD = Path(__file__).parent.resolve() MULS = np.loadtxt(CWD / "data" / "inverse_cve.xy") @@ -32,91 +32,43 @@ def __init__(self, radius=1, n_points_on_diameter=N_POINTS_ON_DIAMETER, mu=None) self._get_grid_points() def _get_grid_points(self): - """ - given a radius and a grid size, return a grid of points to uniformly sample that circle - """ + """Given a radius and a grid size, return a grid of points to uniformly sample that circle.""" xs = np.linspace(-self.radius, self.radius, self.npoints) ys = np.linspace(-self.radius, self.radius, self.npoints) self.grid = {(x, y) for x in xs for y in ys if x**2 + y**2 <= self.radius**2} self.total_points_in_grid = len(self.grid) - def set_distances_at_angle(self, angle): - """ - given an angle, set the distances from the grid points to the entry and exit coordinates - - Parameters - ---------- - angle float - the angle in degrees - - Returns - ------- - the list of distances containing total distance, primary distance and secondary distance - - """ - self.primary_distances, self.secondary_distances, self.distances = [], [], [] - for coord in self.grid: - distance, primary, secondary = self.get_path_length(coord, angle) - self.distances.append(distance) - self.primary_distances.append(primary) - self.secondary_distances.append(secondary) - - def set_muls_at_angle(self, angle): - """ - compute muls = exp(-mu*distance) for a given angle - - Parameters - ---------- - angle float - the angle in degrees - - Returns - ------- - an array of floats containing the muls corresponding to each angle - - """ - mu = self.mu - self.muls = [] - if len(self.distances) == 0: - self.set_distances_at_angle(angle) - for distance in self.distances: - self.muls.append(np.exp(-mu * distance)) - def _get_entry_exit_coordinates(self, coordinate, angle): - """ - get the coordinates where the beam enters and leaves the circle for a given angle and grid point - - Parameters - ---------- - grid_point tuple of floats - the coordinates of the grid point - - angle float - the angle in degrees + """Get the coordinates where the beam enters and leaves the circle for a given angle and grid point. - radius float - the radius of the circle in units of inverse mu - - it is calculated in the following way: + It is calculated in the following way: For the entry coordinate, the y-component will be the y of the grid point and the x-component will be minus the value of x on the circle at the height of this y. For the exit coordinate: - Find the line y = ax + b that passes through grid_point at angle angle - The circle is x^2 + y^2 = r^2 + Find the line y = ax + b that passes through grid_point at angle. + The circle is x^2 + y^2 = r^2. The exit point is where these are simultaneous equations x^2 + y^2 = r^2 & y = ax + b x^2 + (ax+b)^2 = r^2 => x^2 + a^2x^2 + 2abx + b^2 - r^2 = 0 => (1+a^2) x^2 + 2abx + (b^2 - r^2) = 0 to find x_exit we find the roots of these equations and pick the root that is above y-grid - then we get y_exit from y_exit = a*x_exit + b + then we get y_exit from y_exit = a*x_exit + b. + + Parameters + ---------- + coordinate : tuple of floats + The coordinates of the grid point. + + angle : float + The angle in degrees. Returns ------- - (1) the coordinate of the entry point and (2) of the exit point of a beam entering horizontally - impinging on a coordinate point that lies in the circle and then exiting at some angle, angle. - + (entry_point, exit_point): tuple of floats + (1) The coordinate of the entry point and (2) of the exit point of a beam entering horizontally + impinging on a coordinate point that lies in the circle and then exiting at some angle, angle. """ epsilon = 1e-7 # precision close to 90 angle = math.radians(angle) @@ -140,28 +92,22 @@ def _get_entry_exit_coordinates(self, coordinate, angle): return entry_point, exit_point - def get_path_length(self, grid_point, angle): - """ - return the path length - - This is the pathlength of a horizontal line entering the circle at the - same height to the grid point then exiting at angle angle + def _get_path_length(self, grid_point, angle): + """Return the path length of a horizontal line entering the circle at the + same height to the grid point then exiting at angle. Parameters ---------- - grid_point double of floats - the coordinate inside the circle + grid_point : double of floats + The coordinate inside the circle. - angle float - the angle of the output beam - - radius - the radius of the circle + angle : float + The angle of the output beam in degrees. Returns ------- - floats total distance, primary distance and secondary distance - + (total distance, primary distance, secondary distance): tuple of floats + The tuple containing three floats, which are the total distance, entry distance and exit distance. """ # move angle a tad above zero if it is zero to avoid it having the wrong sign due to some rounding error @@ -174,13 +120,41 @@ def get_path_length(self, grid_point, angle): total_distance = primary_distance + secondary_distance return total_distance, primary_distance, secondary_distance + def set_distances_at_angle(self, angle): + """Given an angle, set the distances from the grid points to the entry and exit coordinates. -def _cve_brute_force(diffraction_data, mud): - """ - compute cve for the given mud on a global grid using the brute-force method - assume mu=mud/2, given that the same mu*D yields the same cve and D/2=1 - """ + Parameters + ---------- + angle : float + The angle of the output beam in degrees. + """ + self.primary_distances, self.secondary_distances, self.distances = [], [], [] + for coord in self.grid: + distance, primary, secondary = self._get_path_length(coord, angle) + self.distances.append(distance) + self.primary_distances.append(primary) + self.secondary_distances.append(secondary) + + def set_muls_at_angle(self, angle): + """Compute muls = exp(-mu*distance) for a given angle. + Parameters + ---------- + angle : float + The angle of the output beam in degrees. + """ + mu = self.mu + self.muls = [] + if len(self.distances) == 0: + self.set_distances_at_angle(angle) + for distance in self.distances: + self.muls.append(np.exp(-mu * distance)) + + +def _cve_brute_force(input_pattern, mud): + """Compute cve for the given mud on a global grid using the brute-force method. + Assume mu=mud/2, given that the same mu*D yields the same cve and D/2=1. + """ mu_sample_invmm = mud / 2 abs_correction = Gridded_circle(mu=mu_sample_invmm) distances, muls = [], [] @@ -197,19 +171,18 @@ def _cve_brute_force(diffraction_data, mud): xarray=TTH_GRID, yarray=cve, xtype="tth", - wavelength=diffraction_data.wavelength, + wavelength=input_pattern.wavelength, scat_quantity="cve", - name=f"absorption correction, cve, for {diffraction_data.name}", - metadata=diffraction_data.metadata, + name=f"absorption correction, cve, for {input_pattern.name}", + metadata=input_pattern.metadata, ) return cve_do -def _cve_polynomial_interpolation(diffraction_data, mud): - """ - compute cve using polynomial interpolation method, raise an error if mu*D is out of the range (0.5 to 6) +def _cve_polynomial_interpolation(input_pattern, mud): + """Compute cve using polynomial interpolation method, + raise an error if the mu*D value is out of the range (0.5 to 6). """ - if mud > 6 or mud < 0.5: raise ValueError( f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. " @@ -225,18 +198,16 @@ def _cve_polynomial_interpolation(diffraction_data, mud): xarray=TTH_GRID, yarray=cve, xtype="tth", - wavelength=diffraction_data.wavelength, + wavelength=input_pattern.wavelength, scat_quantity="cve", - name=f"absorption correction, cve, for {diffraction_data.name}", - metadata=diffraction_data.metadata, + name=f"absorption correction, cve, for {input_pattern.name}", + metadata=input_pattern.metadata, ) return cve_do def _cve_method(method): - """ - retrieve the cve computation function for the given method - """ + """Retrieve the cve computation function for the given method.""" methods = { "brute_force": _cve_brute_force, "polynomial_interpolation": _cve_polynomial_interpolation, @@ -246,29 +217,28 @@ def _cve_method(method): return methods[method] -def compute_cve(diffraction_data, mud, method="polynomial_interpolation", xtype="tth"): - f""" - compute and interpolate the cve for the given diffraction data and mud using the selected method +def compute_cve(input_pattern, mud, method="polynomial_interpolation", xtype="tth"): + f"""Compute and interpolate the cve for the given input diffraction data and mu*D using the selected method. Parameters ---------- - diffraction_data Diffraction_object - the diffraction pattern - mud float - the mu*D of the diffraction object, where D is the diameter of the circle - xtype str - the quantity on the independent variable axis, allowed values are {*XQUANTITIES, } - method str - the method used to calculate cve, must be one of {*CVE_METHODS, } + input_pattern : DiffractionObject + The input diffraction object to which the cve will be applied. + mud : float + The mu*D value of the diffraction object, where D is the diameter of the circle. + xtype : str + The quantity on the independent variable axis, allowed values are {*XQUANTITIES, }. + method : str + The method used to calculate cve, must be one of {*CVE_METHODS, }. Returns ------- - the diffraction object with cve curves + cve_do: DiffractionObject + The diffraction object that contains the cve to be applied. """ - cve_function = _cve_method(method) - cve_do_on_global_grid = cve_function(diffraction_data, mud) - orig_grid = diffraction_data.on_xtype(xtype)[0] + cve_do_on_global_grid = cve_function(input_pattern, mud) + orig_grid = input_pattern.on_xtype(xtype)[0] global_xtype = cve_do_on_global_grid.on_xtype(xtype)[0] cve_on_global_xtype = cve_do_on_global_grid.on_xtype(xtype)[1] newcve = np.interp(orig_grid, global_xtype, cve_on_global_xtype) @@ -276,30 +246,28 @@ def compute_cve(diffraction_data, mud, method="polynomial_interpolation", xtype= xarray=orig_grid, yarray=newcve, xtype=xtype, - wavelength=diffraction_data.wavelength, + wavelength=input_pattern.wavelength, scat_quantity="cve", - name=f"absorption correction, cve, for {diffraction_data.name}", - metadata=diffraction_data.metadata, + name=f"absorption correction, cve, for {input_pattern.name}", + metadata=input_pattern.metadata, ) return cve_do -def apply_corr(diffraction_pattern, absorption_correction): - """ - Apply absorption correction to the given diffraction object modo with the correction diffraction object abdo +def apply_corr(input_pattern, absorption_correction): + """Apply absorption correction to the given diffraction object with the correction diffraction object. Parameters ---------- - diffraction_pattern Diffraction_object - the input diffraction object to which the cve will be applied - absorption_correction Diffraction_object - the diffraction object that contains the cve to be applied + input_pattern : DiffractionObject + The input diffraction object to which the cve will be applied. + absorption_correction : DiffractionObject + The diffraction object that contains the cve to be applied. Returns ------- - a corrected diffraction object with the correction applied through multiplication - + corrected_pattern: DiffractionObject + The corrected diffraction object with the correction applied through multiplication. """ - - corrected_pattern = diffraction_pattern * absorption_correction + corrected_pattern = input_pattern * absorption_correction return corrected_pattern From 281b2f92b13f754131330051786b50e10252d170 Mon Sep 17 00:00:00 2001 From: yucongalicechen Date: Wed, 22 Jan 2025 16:13:10 -0500 Subject: [PATCH 3/6] docs: refine docstrings in tools, add orcid argument --- src/diffpy/labpdfproc/labpdfprocapp.py | 8 ++ src/diffpy/labpdfproc/tools.py | 171 ++++++++++++++----------- 2 files changed, 101 insertions(+), 78 deletions(-) diff --git a/src/diffpy/labpdfproc/labpdfprocapp.py b/src/diffpy/labpdfproc/labpdfprocapp.py index c17b8fb..2848ef3 100644 --- a/src/diffpy/labpdfproc/labpdfprocapp.py +++ b/src/diffpy/labpdfproc/labpdfprocapp.py @@ -120,6 +120,14 @@ def define_arguments(): ), "default": None, }, + { + "name": ["--orcid"], + "help": ( + "ORCID will be loaded from config files. Specify here " + "only if you want to override that behavior at runtime. " + ), + "default": None, + }, { "name": ["-z", "--z-scan-file"], "help": "Path to the z-scan file to be loaded to determine the mu*D value", diff --git a/src/diffpy/labpdfproc/tools.py b/src/diffpy/labpdfproc/tools.py index 8fcf0c3..db1e124 100644 --- a/src/diffpy/labpdfproc/tools.py +++ b/src/diffpy/labpdfproc/tools.py @@ -13,41 +13,41 @@ def set_output_directory(args): - """ - set the output directory based on the given input arguments - - Parameters - ---------- - args argparse.Namespace - the arguments from the parser + """Set the output directory based on the given input arguments. - it is determined as follows: + It is determined as follows: If user provides an output directory, use it. Otherwise, we set it to the current directory if nothing is provided. We then create the directory if it does not exist. + Parameters + ---------- + args : argparse.Namespace + The arguments from the parser. + Returns ------- - a Path object that contains the full path of the output directory + args : argparse.Namespace + The updated arguments, with output_directory as the full path to the output file directory. """ output_dir = Path(args.output_directory).resolve() if args.output_directory else Path.cwd().resolve() output_dir.mkdir(parents=True, exist_ok=True) - return output_dir + args.output_directory = output_dir + return args def _expand_user_input(args): - """ - Expands the list of inputs by adding files from file lists and wildcards. + """Expand the list of inputs by adding files from file lists and wildcards. Parameters ---------- - args argparse.Namespace - the arguments from the parser + args : argparse.Namespace + The arguments from the parser. Returns ------- - the arguments with the modified input list - + args : argparse.Namespace + The updated arguments with the modified input list. """ file_list_inputs = [input_name for input_name in args.input if "file_list" in input_name] for file_list_input in file_list_inputs: @@ -57,28 +57,34 @@ def _expand_user_input(args): args.input.remove(file_list_input) wildcard_inputs = [input_name for input_name in args.input if "*" in input_name] for wildcard_input in wildcard_inputs: - input_files = [str(file) for file in Path(".").glob(wildcard_input) if "file_list" not in file.name] + input_files = [ + str(file) + for file in Path(".").glob(wildcard_input) + if "file_list" not in file.name and "diffpyconfig.json" not in file.name + ] args.input.extend(input_files) args.input.remove(wildcard_input) return args def set_input_lists(args): - """ - Set input directory and files. - - It takes cli inputs, checks if they are files or directories and creates + """Set input directory and files. It takes cli inputs, checks if they are files or directories and creates a list of files to be processed which is stored in the args Namespace. Parameters ---------- - args argparse.Namespace - the arguments from the parser + args : argparse.Namespace + The arguments from the parser. + + Raises + ------ + FileNotFoundError + Raised when an input is invalid. Returns ------- - args argparse.Namespace - + args : argparse.Namespace + The updated arguments with the modified input list. """ input_paths = [] @@ -91,7 +97,9 @@ def set_input_lists(args): elif input_path.is_dir(): input_files = input_path.glob("*") input_files = [ - file.resolve() for file in input_files if file.is_file() and "file_list" not in file.name + file.resolve() + for file in input_files + if file.is_file() and "file_list" not in file.name and "diffpyconfig.json" not in file.name ] input_paths.extend(input_files) else: @@ -99,26 +107,31 @@ def set_input_lists(args): f"Cannot find {input_name}. Please specify valid input file(s) or directories." ) else: - raise FileNotFoundError(f"Cannot find {input_name}.") + raise FileNotFoundError( + f"Cannot find {input_name}. Please specify valid input file(s) or directories." + ) setattr(args, "input_paths", list(set(input_paths))) return args def set_wavelength(args): - """ - Set the wavelength based on the given input arguments + """Set the wavelength based on the given anode_type. If a wavelength is provided, + it will be used, and the anode_type argument will be removed. Parameters ---------- - args argparse.Namespace - the arguments from the parser + args : argparse.Namespace + The arguments from the parser. - we raise a ValueError if the input wavelength is non-positive - or if the input anode_type is not one of the known sources + Raises + ------ + ValueError + Raised when input wavelength is non-positive or if input anode_type is not one of the known sources. Returns ------- - args argparse.Namespace + args : argparse.Namespace + The updated arguments with the wavelength. """ if args.wavelength is not None and args.wavelength <= 0: raise ValueError( @@ -139,17 +152,17 @@ def set_wavelength(args): def set_xtype(args): - f""" - Set the xtype based on the given input arguments, raise an error if xtype is not one of {*XQUANTITIES, } + f"""Set the xtype based on the given input arguments, raise an error if xtype is not one of {*XQUANTITIES, }. Parameters ---------- - args argparse.Namespace - the arguments from the parser + args : argparse.Namespace + The arguments from the parser. Returns ------- - args argparse.Namespace + args : argparse.Namespace + The updated arguments with the xtype as one of q, tth, or d. """ if args.xtype.lower() not in XQUANTITIES: raise ValueError(f"Unknown xtype: {args.xtype}. Allowed xtypes are {*XQUANTITIES, }.") @@ -160,17 +173,17 @@ def set_xtype(args): def set_mud(args): - """ - Set the mud based on the given input arguments + """Compute mu*D based on the given z-scan file, if provided. Parameters ---------- - args argparse.Namespace - the arguments from the parser + args : argparse.Namespace + The arguments from the parser. Returns ------- - args argparse.Namespace + args : argparse.Namespace + The updated arguments with mu*D. """ if args.z_scan_file: filepath = Path(args.z_scan_file).resolve() @@ -186,22 +199,21 @@ def _load_key_value_pair(s): key = items[0].strip() if len(items) > 1: value = "=".join(items[1:]) - return (key, value) + return key, value def load_user_metadata(args): - """ - Load user metadata into the provided argparse Namespace, raise ValueError if in incorrect format + """Load user metadata into args, raise ValueError if it is in incorrect format. Parameters ---------- - args argparse.Namespace - the arguments from the parser + args : argparse.Namespace + The arguments from the parser. Returns ------- - the updated argparse Namespace with user metadata inserted as key-value pairs - + args : argparse.Namespace + The updated argparse Namespace with user metadata inserted as key-value pairs. """ reserved_keys = vars(args).keys() @@ -215,50 +227,49 @@ def load_user_metadata(args): ) key, value = _load_key_value_pair(item) if key in reserved_keys: - raise ValueError(f"{key} is a reserved name. Please rerun using a different key name. ") + raise ValueError(f"{key} is a reserved name. Please rerun using a different key name.") if hasattr(args, key): - raise ValueError(f"Please do not specify repeated keys: {key}. ") + raise ValueError(f"Please do not specify repeated keys: {key}.") setattr(args, key, value) delattr(args, "user_metadata") return args def load_user_info(args): - """ - Load user info into args. If args are not provided, call check_and_build_global_config function from + """Load user info into args. If none is provided, call check_and_build_global_config function from diffpy.utils to prompt the user for inputs. Otherwise, call get_user_info with the provided arguments. Parameters ---------- - args argparse.Namespace - the arguments from the parser, default is None + args : argparse.Namespace + The arguments from the parser. Returns ------- - the updated argparse Namespace with username and email inserted - + args : argparse.Namespace + The updated argparse Namespace with username, email, and orcid inserted. """ if args.username is None or args.email is None: check_and_build_global_config() - config = get_user_info(owner_name=args.username, owner_email=args.email) + config = get_user_info(owner_name=args.username, owner_email=args.email, owner_orcid=args.orcid) args.username = config.get("owner_name") args.email = config.get("owner_email") + args.orcid = config.get("owner_orcid") return args def load_package_info(args): - """ - Load diffpy.labpdfproc package name and version into args using get_package_info function from diffpy.utils + """Load diffpy.labpdfproc package name and version into args using get_package_info function from diffpy.utils. Parameters ---------- - args argparse.Namespace - the arguments from the parser, default is None + args : argparse.Namespace + The arguments from the parser. Returns ------- - the updated argparse Namespace with diffpy.labpdfproc name and version inserted - + args : argparse.Namespace + The updated argparse Namespace with diffpy.labpdfproc name and version inserted. """ metadata = get_package_info("diffpy.labpdfproc") setattr(args, "package_info", metadata["package_info"]) @@ -266,22 +277,24 @@ def load_package_info(args): def preprocessing_args(args): - """ - Perform preprocessing on the provided argparse Namespace + """Perform preprocessing on the provided args. + The process includes loading package and user information, + setting input, output, wavelength, xtype, mu*D, and loading user metadata. Parameters ---------- - args argparse.Namespace - the arguments from the parser, default is None + args : argparse.Namespace + The arguments from the parser. Returns ------- - the updated argparse Namespace with arguments preprocessed + args : argparse.Namespace + The updated argparse Namespace with arguments preprocessed. """ args = load_package_info(args) args = load_user_info(args) args = set_input_lists(args) - args.output_directory = set_output_directory(args) + args = set_output_directory(args) args = set_wavelength(args) args = set_xtype(args) args = set_mud(args) @@ -290,19 +303,21 @@ def preprocessing_args(args): def load_metadata(args, filepath): - """ - Load relevant metadata from args + """Load the relevant metadata from args to write into the header of the output files. Parameters ---------- - args argparse.Namespace - the arguments from the parser + args : argparse.Namespace + The arguments from the parser. + + filepath : Path + The filepath of the current input file. Returns ------- - A dictionary with relevant arguments from the parser + metadata : dict + The dictionary with relevant arguments from the parser. """ - metadata = copy.deepcopy(vars(args)) for key in METADATA_KEYS_TO_EXCLUDE: metadata.pop(key, None) From c246b2ac73480464146b3a2c5ab975ab86884d3b Mon Sep 17 00:00:00 2001 From: yucongalicechen Date: Wed, 22 Jan 2025 17:00:07 -0500 Subject: [PATCH 4/6] docs: refine test_tools, need to add tests for orcid --- src/diffpy/labpdfproc/tools.py | 2 +- tests/test_tools.py | 389 +++++++++++++++++---------------- 2 files changed, 203 insertions(+), 188 deletions(-) diff --git a/src/diffpy/labpdfproc/tools.py b/src/diffpy/labpdfproc/tools.py index db1e124..32cfaa2 100644 --- a/src/diffpy/labpdfproc/tools.py +++ b/src/diffpy/labpdfproc/tools.py @@ -216,7 +216,7 @@ def load_user_metadata(args): The updated argparse Namespace with user metadata inserted as key-value pairs. """ - reserved_keys = vars(args).keys() + reserved_keys = set(vars(args).keys()) if args.user_metadata: for item in args.user_metadata: diff --git a/tests/test_tools.py b/tests/test_tools.py index bcee89d..82c749c 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -20,62 +20,74 @@ ) from diffpy.utils.diffraction_objects import XQUANTITIES -# Use cases can be found here: https://github.com/diffpy/diffpy.labpdfproc/issues/48 - -# This test covers existing single input file, directory, a file list, and multiple files -# We store absolute path into input_directory and file names into input_file -params_input = [ - (["good_data.chi"], ["good_data.chi"]), # single good file, same directory - (["input_dir/good_data.chi"], ["input_dir/good_data.chi"]), # single good file, input directory - ( # glob current directory - ["."], - ["good_data.chi", "good_data.xy", "good_data.txt", "unreadable_file.txt", "binary.pkl"], - ), - ( # glob input directory - ["./input_dir"], - [ - "input_dir/good_data.chi", - "input_dir/good_data.xy", - "input_dir/good_data.txt", - "input_dir/unreadable_file.txt", - "input_dir/binary.pkl", - ], - ), - ( # glob list of input directories - [".", "./input_dir"], - [ - "./good_data.chi", - "./good_data.xy", - "./good_data.txt", - "./unreadable_file.txt", - "./binary.pkl", - "input_dir/good_data.chi", - "input_dir/good_data.xy", - "input_dir/good_data.txt", - "input_dir/unreadable_file.txt", - "input_dir/binary.pkl", - ], - ), - ( # file_list_example2.txt list of files provided in different directories with wildcard - ["input_dir/file_list_example2.txt"], - ["input_dir/good_data.chi", "good_data.xy", "input_dir/good_data.txt", "input_dir/unreadable_file.txt"], - ), - ( # wildcard pattern, matching files with .chi extension in the same directory - ["./*.chi"], - ["good_data.chi"], - ), - ( # wildcard pattern, matching files with .chi extension in the input directory - ["input_dir/*.chi"], - ["input_dir/good_data.chi"], - ), - ( # wildcard pattern, matching files starting with good_data - ["good_data*"], - ["good_data.chi", "good_data.xy", "good_data.txt"], - ), -] - - -@pytest.mark.parametrize("inputs, expected", params_input) + +@pytest.mark.parametrize( + "inputs, expected", + [ + # Use cases can be found here: https://github.com/diffpy/diffpy.labpdfproc/issues/48 + # This test covers existing single input file, directory, a file list, and multiple files + # We store absolute path into input_directory and file names into input_file + ( # C1: single good file in the current directory, expect to return the absolute Path of the file + ["good_data.chi"], + ["good_data.chi"], + ), + ( # C2: single good file in an input directory, expect to return the absolute Path of the file + ["input_dir/good_data.chi"], + ["input_dir/good_data.chi"], + ), + ( # C3: glob current directory, expect to return all files in the current directory + ["."], + ["good_data.chi", "good_data.xy", "good_data.txt", "unreadable_file.txt", "binary.pkl"], + ), + ( # C4: glob input directory, expect to return all files in that directory + ["./input_dir"], + [ + "input_dir/good_data.chi", + "input_dir/good_data.xy", + "input_dir/good_data.txt", + "input_dir/unreadable_file.txt", + "input_dir/binary.pkl", + ], + ), + ( # C5: glob list of input directories, expect to return all files in the directories + [".", "./input_dir"], + [ + "./good_data.chi", + "./good_data.xy", + "./good_data.txt", + "./unreadable_file.txt", + "./binary.pkl", + "input_dir/good_data.chi", + "input_dir/good_data.xy", + "input_dir/good_data.txt", + "input_dir/unreadable_file.txt", + "input_dir/binary.pkl", + ], + ), + ( # C6: file_list_example2.txt list of files provided in different directories with wildcard, + # expect to return all files listed on the file_list file + ["input_dir/file_list_example2.txt"], + [ + "input_dir/good_data.chi", + "good_data.xy", + "input_dir/good_data.txt", + "input_dir/unreadable_file.txt", + ], + ), + ( # C7: wildcard pattern, expect to match files with .chi extension in the same directory + ["./*.chi"], + ["good_data.chi"], + ), + ( # C8: wildcard pattern, expect to match files with .chi extension in the input directory + ["input_dir/*.chi"], + ["input_dir/good_data.chi"], + ), + ( # C9: wildcard pattern, expect to match files starting with good_data + ["good_data*"], + ["good_data.chi", "good_data.xy", "good_data.txt"], + ), + ], +) def test_set_input_lists(inputs, expected, user_filesystem): base_dir = Path(user_filesystem) os.chdir(base_dir) @@ -87,53 +99,56 @@ def test_set_input_lists(inputs, expected, user_filesystem): assert sorted(actual_args.input_paths) == sorted(expected_paths) -# This test covers non-existing single input file or directory, in this case we raise an error with message -params_input_bad = [ - ( - ["non_existing_file.xy"], - "Cannot find non_existing_file.xy. Please specify valid input file(s) or directories.", - ), - ( - ["./input_dir/non_existing_file.xy"], - "Cannot find ./input_dir/non_existing_file.xy. Please specify valid input file(s) or directories.", - ), - (["./non_existing_dir"], "Cannot find ./non_existing_dir. Please specify valid input file(s) or directories."), - ( # list of files provided (with missing files) - ["good_data.chi", "good_data.xy", "unreadable_file.txt", "missing_file.txt"], - "Cannot find missing_file.txt. Please specify valid input file(s) or directories.", - ), - ( # file_list.txt list of files provided (with missing files) - ["input_dir/file_list.txt"], - "Cannot find missing_file.txt. Please specify valid input file(s) or directories.", - ), -] - - -@pytest.mark.parametrize("inputs, msg", params_input_bad) -def test_set_input_files_bad(inputs, msg, user_filesystem): +@pytest.mark.parametrize( + "inputs, expected_error_msg", + [ + # This test covers non-existing single input file or directory, in this case we raise an error with message + ( # C1: non-existing single file + ["non_existing_file.xy"], + "Cannot find non_existing_file.xy. Please specify valid input file(s) or directories.", + ), + ( # C2: non-existing single file with directory + ["./input_dir/non_existing_file.xy"], + "Cannot find ./input_dir/non_existing_file.xy. Please specify valid input file(s) or directories.", + ), + ( # C3: non-existing single directory + ["./non_existing_dir"], + "Cannot find ./non_existing_dir. Please specify valid input file(s) or directories.", + ), + ( # C4: list of files provided (with missing files) + ["good_data.chi", "good_data.xy", "unreadable_file.txt", "missing_file.txt"], + "Cannot find missing_file.txt. Please specify valid input file(s) or directories.", + ), + ( # C5: file_list.txt list of files provided (with missing files) + ["input_dir/file_list.txt"], + "Cannot find missing_file.txt. Please specify valid input file(s) or directories.", + ), + ], +) +def test_set_input_files_bad(inputs, expected_error_msg, user_filesystem): base_dir = Path(user_filesystem) os.chdir(base_dir) cli_inputs = ["2.5"] + inputs actual_args = get_args(cli_inputs) - with pytest.raises(FileNotFoundError, match=msg[0]): + with pytest.raises(FileNotFoundError, match=re.escape(expected_error_msg)): actual_args = set_input_lists(actual_args) -params1 = [ - ([], ["."]), - (["--output-directory", "."], ["."]), - (["--output-directory", "new_dir"], ["new_dir"]), - (["--output-directory", "input_dir"], ["input_dir"]), -] - - -@pytest.mark.parametrize("inputs, expected", params1) +@pytest.mark.parametrize( + "inputs, expected", + [ + ([], ["."]), + (["--output-directory", "."], ["."]), + (["--output-directory", "new_dir"], ["new_dir"]), + (["--output-directory", "input_dir"], ["input_dir"]), + ], +) def test_set_output_directory(inputs, expected, user_filesystem): os.chdir(user_filesystem) expected_output_directory = Path(user_filesystem) / expected[0] cli_inputs = ["2.5", "data.xy"] + inputs actual_args = get_args(cli_inputs) - actual_args.output_directory = set_output_directory(actual_args) + actual_args = set_output_directory(actual_args) assert actual_args.output_directory == expected_output_directory assert Path(actual_args.output_directory).exists() assert Path(actual_args.output_directory).is_dir() @@ -144,67 +159,66 @@ def test_set_output_directory_bad(user_filesystem): cli_inputs = ["2.5", "data.xy", "--output-directory", "good_data.chi"] actual_args = get_args(cli_inputs) with pytest.raises(FileExistsError): - actual_args.output_directory = set_output_directory(actual_args) + actual_args = set_output_directory(actual_args) assert Path(actual_args.output_directory).exists() assert not Path(actual_args.output_directory).is_dir() -params2 = [ - ([], [0.71073, "Mo"]), - (["--anode-type", "Ag"], [0.59, "Ag"]), - (["--wavelength", "0.25"], [0.25, None]), - (["--wavelength", "0.25", "--anode-type", "Ag"], [0.25, None]), -] - - -@pytest.mark.parametrize("inputs, expected", params2) +@pytest.mark.parametrize( + "inputs, expected", + [ + ([], {"wavelength": 0.71073, "anode_type": "Mo"}), + (["--anode-type", "Ag"], {"wavelength": 0.59, "anode_type": "Ag"}), + (["--wavelength", "0.25"], {"wavelength": 0.25, "anode_type": None}), + (["--wavelength", "0.25", "--anode-type", "Ag"], {"wavelength": 0.25, "anode_type": None}), + ], +) def test_set_wavelength(inputs, expected): - expected_wavelength, expected_anode_type = expected[0], expected[1] cli_inputs = ["2.5", "data.xy"] + inputs actual_args = get_args(cli_inputs) actual_args = set_wavelength(actual_args) - assert actual_args.wavelength == expected_wavelength - assert getattr(actual_args, "anode_type", None) == expected_anode_type - - -params3 = [ - ( - ["--anode-type", "invalid"], - [f"Anode type not recognized. Please rerun specifying an anode_type from {*known_sources, }."], - ), - ( - ["--wavelength", "0"], - ["No valid wavelength. Please rerun specifying a known anode_type or a positive wavelength."], - ), - ( - ["--wavelength", "-1", "--anode-type", "Mo"], - ["No valid wavelength. Please rerun specifying a known anode_type or a positive wavelength."], - ), -] - - -@pytest.mark.parametrize("inputs, msg", params3) -def test_set_wavelength_bad(inputs, msg): + assert actual_args.wavelength == expected["wavelength"] + assert getattr(actual_args, "anode_type", None) == expected["anode_type"] + + +@pytest.mark.parametrize( + "inputs, expected_error_msg", + [ + ( + ["--anode-type", "invalid"], + f"Anode type not recognized. Please rerun specifying an anode_type from {*known_sources, }.", + ), + ( + ["--wavelength", "0"], + "No valid wavelength. Please rerun specifying a known anode_type or a positive wavelength.", + ), + ( + ["--wavelength", "-1", "--anode-type", "Mo"], + "No valid wavelength. Please rerun specifying a known anode_type or a positive wavelength.", + ), + ], +) +def test_set_wavelength_bad(inputs, expected_error_msg): cli_inputs = ["2.5", "data.xy"] + inputs actual_args = get_args(cli_inputs) - with pytest.raises(ValueError, match=re.escape(msg[0])): + with pytest.raises(ValueError, match=re.escape(expected_error_msg)): actual_args = set_wavelength(actual_args) -params4 = [ - ([], ["tth"]), - (["--xtype", "2theta"], ["tth"]), - (["--xtype", "d"], ["d"]), - (["--xtype", "q"], ["q"]), -] - - -@pytest.mark.parametrize("inputs, expected", params4) -def test_set_xtype(inputs, expected): +@pytest.mark.parametrize( + "inputs, expected_xtype", + [ + ([], "tth"), + (["--xtype", "2theta"], "tth"), + (["--xtype", "d"], "d"), + (["--xtype", "q"], "q"), + ], +) +def test_set_xtype(inputs, expected_xtype): cli_inputs = ["2.5", "data.xy"] + inputs actual_args = get_args(cli_inputs) actual_args = set_xtype(actual_args) - assert actual_args.xtype == expected[0] + assert actual_args.xtype == expected_xtype def test_set_xtype_bad(): @@ -242,17 +256,17 @@ def test_set_mud_bad(): actual_args = set_mud(actual_args) -params5 = [ - ([], []), - ( - ["--user-metadata", "facility=NSLS II", "beamline=28ID-2", "favorite color=blue"], - [["facility", "NSLS II"], ["beamline", "28ID-2"], ["favorite color", "blue"]], - ), - (["--user-metadata", "x=y=z"], [["x", "y=z"]]), -] - - -@pytest.mark.parametrize("inputs, expected", params5) +@pytest.mark.parametrize( + "inputs, expected", + [ + ([], []), + ( + ["--user-metadata", "facility=NSLS II", "beamline=28ID-2", "favorite color=blue"], + [["facility", "NSLS II"], ["beamline", "28ID-2"], ["favorite color", "blue"]], + ), + (["--user-metadata", "x=y=z"], [["x", "y=z"]]), + ], +) def test_load_user_metadata(inputs, expected): expected_args = get_args(["2.5", "data.xy"]) for expected_pair in expected: @@ -265,52 +279,50 @@ def test_load_user_metadata(inputs, expected): assert actual_args == expected_args -params6 = [ - ( - ["--user-metadata", "facility=", "NSLS II"], - [ +@pytest.mark.parametrize( + "inputs, expected_error_msg", + [ + ( + ["--user-metadata", "facility=", "NSLS II"], + "Please provide key-value pairs in the format key=value. " + "For more information, use `labpdfproc --help.`", + ), + ( + ["--user-metadata", "favorite", "color=blue"], "Please provide key-value pairs in the format key=value. " - "For more information, use `labpdfproc --help.`" - ], - ), - ( - ["--user-metadata", "favorite", "color=blue"], - "Please provide key-value pairs in the format key=value. " - "For more information, use `labpdfproc --help.`", - ), - ( - ["--user-metadata", "beamline", "=", "28ID-2"], - "Please provide key-value pairs in the format key=value. " - "For more information, use `labpdfproc --help.`", - ), - ( - ["--user-metadata", "facility=NSLS II", "facility=NSLS III"], - "Please do not specify repeated keys: facility. ", - ), - ( - ["--user-metadata", "wavelength=2"], - "wavelength is a reserved name. Please rerun using a different key name. ", - ), -] - - -@pytest.mark.parametrize("inputs, msg", params6) -def test_load_user_metadata_bad(inputs, msg): + "For more information, use `labpdfproc --help.`", + ), + ( + ["--user-metadata", "beamline", "=", "28ID-2"], + "Please provide key-value pairs in the format key=value. " + "For more information, use `labpdfproc --help.`", + ), + ( + ["--user-metadata", "facility=NSLS II", "facility=NSLS III"], + "Please do not specify repeated keys: facility.", + ), + ( + ["--user-metadata", "wavelength=2"], + "wavelength is a reserved name. Please rerun using a different key name.", + ), + ], +) +def test_load_user_metadata_bad(inputs, expected_error_msg): cli_inputs = ["2.5", "data.xy"] + inputs actual_args = get_args(cli_inputs) - with pytest.raises(ValueError, match=msg[0]): + with pytest.raises(ValueError, match=re.escape(expected_error_msg)): actual_args = load_user_metadata(actual_args) -params_user_info = [ - ([None, None], ["home_username", "home@email.com"]), - (["cli_username", None], ["cli_username", "home@email.com"]), - ([None, "cli@email.com"], ["home_username", "cli@email.com"]), - (["cli_username", "cli@email.com"], ["cli_username", "cli@email.com"]), -] - - -@pytest.mark.parametrize("inputs, expected", params_user_info) +@pytest.mark.parametrize( + "inputs, expected", + [ + ([None, None], ["home_username", "home@email.com"]), + (["cli_username", None], ["cli_username", "home@email.com"]), + ([None, "cli@email.com"], ["home_username", "cli@email.com"]), + (["cli_username", "cli@email.com"], ["cli_username", "cli@email.com"]), + ], +) def test_load_user_info(monkeypatch, inputs, expected, user_filesystem): cwd = Path(user_filesystem) home_dir = cwd / "home_dir" @@ -354,6 +366,8 @@ def test_load_metadata(mocker, user_filesystem): "cli_username", "--email", "cli@email.com", + "--orcid", + "cli_orcid", ] actual_args = get_args(cli_inputs) actual_args = preprocessing_args(actual_args) @@ -369,6 +383,7 @@ def test_load_metadata(mocker, user_filesystem): "key": "value", "username": "cli_username", "email": "cli@email.com", + "orcid": "cli_orcid", "package_info": {"diffpy.labpdfproc": "1.2.3", "diffpy.utils": "3.3.0"}, "z_scan_file": None, } From 42fa3641416ff2d88345da716a2c97c36301df29 Mon Sep 17 00:00:00 2001 From: yucongalicechen Date: Thu, 23 Jan 2025 13:29:56 -0500 Subject: [PATCH 5/6] fix: add tests for orcid --- tests/test_tools.py | 43 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index 82c749c..778492b 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -316,11 +316,27 @@ def test_load_user_metadata_bad(inputs, expected_error_msg): @pytest.mark.parametrize( "inputs, expected", - [ - ([None, None], ["home_username", "home@email.com"]), - (["cli_username", None], ["cli_username", "home@email.com"]), - ([None, "cli@email.com"], ["home_username", "cli@email.com"]), - (["cli_username", "cli@email.com"], ["cli_username", "cli@email.com"]), + [ # Test that when cli inputs are present, they override home config, otherwise we take home config + ( + {"username": None, "email": None, "orcid": None}, + {"username": "home_username", "email": "home@email.com", "orcid": "home_orcid"}, + ), + ( + {"username": "cli_username", "email": None, "orcid": None}, + {"username": "cli_username", "email": "home@email.com", "orcid": "home_orcid"}, + ), + ( + {"username": None, "email": "cli@email.com", "orcid": None}, + {"username": "home_username", "email": "cli@email.com", "orcid": "home_orcid"}, + ), + ( + {"username": None, "email": None, "orcid": "cli_orcid"}, + {"username": "home_username", "email": "home@email.com", "orcid": "cli_orcid"}, + ), + ( + {"username": "cli_username", "email": "cli@email.com", "orcid": "cli_orcid"}, + {"username": "cli_username", "email": "cli@email.com", "orcid": "cli_orcid"}, + ), ], ) def test_load_user_info(monkeypatch, inputs, expected, user_filesystem): @@ -329,12 +345,21 @@ def test_load_user_info(monkeypatch, inputs, expected, user_filesystem): monkeypatch.setattr("pathlib.Path.home", lambda _: home_dir) os.chdir(cwd) - expected_username, expected_email = expected - cli_inputs = ["2.5", "data.xy", "--username", inputs[0], "--email", inputs[1]] + cli_inputs = [ + "2.5", + "data.xy", + "--username", + inputs["username"], + "--email", + inputs["email"], + "--orcid", + inputs["orcid"], + ] actual_args = get_args(cli_inputs) actual_args = load_user_info(actual_args) - assert actual_args.username == expected_username - assert actual_args.email == expected_email + assert actual_args.username == expected["username"] + assert actual_args.email == expected["email"] + assert actual_args.orcid == expected["orcid"] def test_load_package_info(mocker): From 84aa7279eb6642cd9cd5bdc05a7a5e80b229eea4 Mon Sep 17 00:00:00 2001 From: yucongalicechen Date: Thu, 23 Jan 2025 13:34:31 -0500 Subject: [PATCH 6/6] docs: add news --- news/docstring-tests.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/news/docstring-tests.rst b/news/docstring-tests.rst index 790d30b..137f773 100644 --- a/news/docstring-tests.rst +++ b/news/docstring-tests.rst @@ -1,10 +1,10 @@ **Added:** -* +* Functionality in `load_user_info` to enable user to enter an ORCID. **Changed:** -* +* All function docstrings and tests to be more informative, incorporating new ORCID function and improving overall clarity. **Deprecated:**