diff --git a/src/diffpy/labpdfproc/labpdfprocapp.py b/src/diffpy/labpdfproc/labpdfprocapp.py index 7a9531f..ef3f253 100644 --- a/src/diffpy/labpdfproc/labpdfprocapp.py +++ b/src/diffpy/labpdfproc/labpdfprocapp.py @@ -1,9 +1,14 @@ import sys from argparse import ArgumentParser -from pathlib import Path from diffpy.labpdfproc.functions import apply_corr, compute_cve -from diffpy.labpdfproc.tools import known_sources, load_user_metadata, set_output_directory, set_wavelength +from diffpy.labpdfproc.tools import ( + known_sources, + load_user_metadata, + set_input_files, + set_output_directory, + set_wavelength, +) from diffpy.utils.parsers.loaddata import loadData from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object @@ -11,7 +16,15 @@ def get_args(override_cli_inputs=None): p = ArgumentParser() p.add_argument("mud", help="Value of mu*D for your " "sample. Required.", type=float) - p.add_argument("-i", "--input-file", help="The filename of the " "datafile to load.") + p.add_argument( + "input", + nargs="+", + help="The filename(s) or folder(s) of the datafile(s) to load. Required. " + "Supports multiple arguments of input file or directory. " + "The file can be either a data file or a file containing a list of files. " + "If a directory is provided, we will load all data files in it. " + "For example, file.xy, data/file.xy, file_list.txt, ./data/file.xy, ./data are all valid inputs. 
", + ) p.add_argument( "-a", "--anode-type", @@ -76,45 +89,51 @@ def get_args(override_cli_inputs=None): def main(): args = get_args() + args = set_input_files(args) args.output_directory = set_output_directory(args) args.wavelength = set_wavelength(args) args = load_user_metadata(args) - filepath = Path(args.input_file) - outfilestem = filepath.stem + "_corrected" - corrfilestem = filepath.stem + "_cve" - outfile = args.output_directory / (outfilestem + ".chi") - corrfile = args.output_directory / (corrfilestem + ".chi") + for filepath in args.input_directory: + outfilestem = filepath.stem + "_corrected" + corrfilestem = filepath.stem + "_cve" + outfile = args.output_directory / (outfilestem + ".chi") + corrfile = args.output_directory / (corrfilestem + ".chi") - if outfile.exists() and not args.force_overwrite: - sys.exit( - f"Output file {str(outfile)} already exists. Please rerun " - f"specifying -f if you want to overwrite it." - ) - if corrfile.exists() and args.output_correction and not args.force_overwrite: - sys.exit( - f"Corrections file {str(corrfile)} was requested and already " - f"exists. Please rerun specifying -f if you want to overwrite it." - ) + if outfile.exists() and not args.force_overwrite: + sys.exit( + f"Output file {str(outfile)} already exists. Please rerun " + f"specifying -f if you want to overwrite it." + ) + if corrfile.exists() and args.output_correction and not args.force_overwrite: + sys.exit( + f"Corrections file {str(corrfile)} was requested and already " + f"exists. Please rerun specifying -f if you want to overwrite it." 
+ ) - input_pattern = Diffraction_object(wavelength=args.wavelength) - xarray, yarray = loadData(args.input_file, unpack=True) - input_pattern.insert_scattering_quantity( - xarray, - yarray, - "tth", - scat_quantity="x-ray", - name=str(args.input_file), - metadata={"muD": args.mud, "anode_type": args.anode_type}, - ) + input_pattern = Diffraction_object(wavelength=args.wavelength) + + try: + xarray, yarray = loadData(args.input_file, unpack=True) + except Exception as e: + raise ValueError(f"Failed to load data from {filepath}: {e}.") + + input_pattern.insert_scattering_quantity( + xarray, + yarray, + "tth", + scat_quantity="x-ray", + name=str(args.input_file), + metadata={"muD": args.mud, "anode_type": args.anode_type}, + ) - absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength) - corrected_data = apply_corr(input_pattern, absorption_correction) - corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}" - corrected_data.dump(f"{outfile}", xtype="tth") + absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength) + corrected_data = apply_corr(input_pattern, absorption_correction) + corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}" + corrected_data.dump(f"{outfile}", xtype="tth") - if args.output_correction: - absorption_correction.dump(f"{corrfile}", xtype="tth") + if args.output_correction: + absorption_correction.dump(f"{corrfile}", xtype="tth") if __name__ == "__main__": diff --git a/src/diffpy/labpdfproc/tests/conftest.py b/src/diffpy/labpdfproc/tests/conftest.py index 1e4ab40..075428c 100644 --- a/src/diffpy/labpdfproc/tests/conftest.py +++ b/src/diffpy/labpdfproc/tests/conftest.py @@ -39,4 +39,13 @@ def user_filesystem(tmp_path): with open(os.path.join(input_dir, "binary.pkl"), "wb") as f: f.write(binary_data) + file_list_dir = Path(tmp_path).resolve() / "file_list_dir" + file_list_dir.mkdir(parents=True, exist_ok=True) + with open(os.path.join(file_list_dir, 
"file_list.txt"), "w") as f: + f.write("good_data.chi \n good_data.xy \n good_data.txt \n missing_file.txt") + with open(os.path.join(file_list_dir, "file_list_example2.txt"), "w") as f: + f.write("input_dir/good_data.chi \n") + f.write("good_data.xy \n") + f.write(str(os.path.abspath(os.path.join(input_dir, "good_data.txt"))) + "\n") + yield tmp_path diff --git a/src/diffpy/labpdfproc/tests/test_tools.py b/src/diffpy/labpdfproc/tests/test_tools.py index 10491ac..717e5eb 100644 --- a/src/diffpy/labpdfproc/tests/test_tools.py +++ b/src/diffpy/labpdfproc/tests/test_tools.py @@ -1,10 +1,119 @@ +import os import re from pathlib import Path import pytest from diffpy.labpdfproc.labpdfprocapp import get_args -from diffpy.labpdfproc.tools import known_sources, load_user_metadata, set_output_directory, set_wavelength +from diffpy.labpdfproc.tools import ( + known_sources, + load_user_metadata, + set_input_files, + set_output_directory, + set_wavelength, +) +from diffpy.utils.parsers.loaddata import loadData + +# Use cases can be found here: https://github.com/diffpy/diffpy.labpdfproc/issues/48 + +# This test covers existing single input file, directory, a file list, and multiple files +# We store absolute path into input_directory and file names into input_file +params_input = [ + (["good_data.chi"], ["good_data.chi"]), # single good file, same directory + (["input_dir/good_data.chi"], ["input_dir/good_data.chi"]), # single good file, input directory + ( # glob current directory + ["."], + ["good_data.chi", "good_data.xy", "good_data.txt", "unreadable_file.txt", "binary.pkl"], + ), + ( # glob input directory + ["./input_dir"], + [ + "input_dir/good_data.chi", + "input_dir/good_data.xy", + "input_dir/good_data.txt", + "input_dir/unreadable_file.txt", + "input_dir/binary.pkl", + ], + ), + ( # list of files provided (we skip if encountering invalid files) + ["good_data.chi", "good_data.xy", "unreadable_file.txt", "missing_file.txt"], + ["good_data.chi", "good_data.xy", 
"unreadable_file.txt"], + ), + ( # list of files provided (with invalid files and files in different directories) + ["input_dir/good_data.chi", "good_data.chi", "missing_file.txt"], + ["input_dir/good_data.chi", "good_data.chi"], + ), + ( # file_list.txt list of files provided + ["file_list_dir/file_list.txt"], + ["good_data.chi", "good_data.xy", "good_data.txt"], + ), + ( # file_list_example2.txt list of files provided in different directories + ["file_list_dir/file_list_example2.txt"], + ["input_dir/good_data.chi", "good_data.xy", "input_dir/good_data.txt"], + ), +] + + +@pytest.mark.parametrize("inputs, expected", params_input) +def test_set_input_files(inputs, expected, user_filesystem): + expected_input_directory = [] + for expected_path in expected: + expected_input_directory.append(Path(user_filesystem) / expected_path) + + cli_inputs = ["2.5"] + inputs + actual_args = get_args(cli_inputs) + actual_args = set_input_files(actual_args) + assert set(actual_args.input_directory) == set(expected_input_directory) + + +# This test is for existing single input file or directory absolute path not in cwd +# Here we are in user_filesystem/input_dir, testing for a file or directory in user_filesystem +params_input_not_cwd = [ + (["good_data.chi"], ["good_data.chi"]), + (["."], ["good_data.chi", "good_data.xy", "good_data.txt", "unreadable_file.txt", "binary.pkl"]), +] + + +@pytest.mark.parametrize("inputs, expected", params_input_not_cwd) +def test_set_input_files_not_cwd(inputs, expected, user_filesystem): + expected_input_directory = [] + for expected_path in expected: + expected_input_directory.append(Path(user_filesystem) / expected_path) + actual_input = [str(Path(user_filesystem) / inputs[0])] + os.chdir("input_dir") + + cli_inputs = ["2.5"] + actual_input + actual_args = get_args(cli_inputs) + actual_args = set_input_files(actual_args) + assert set(actual_args.input_directory) == set(expected_input_directory) + + +# This test covers non-existing single input file 
# This test covers non-existing single input file or directory, in this case we raise an error with message
params_input_bad = [
    (["non_existing_file.xy"], "Please specify at least one valid input file or directory."),
    (["./input_dir/non_existing_file.xy"], "Please specify at least one valid input file or directory."),
    (["./non_existing_dir"], "Please specify at least one valid input file or directory."),
]


@pytest.mark.parametrize("inputs, msg", params_input_bad)
def test_set_input_files_bad(inputs, msg, user_filesystem):
    """Check that inputs which resolve to no file raise a ValueError with the expected message."""
    cli_inputs = ["2.5"] + inputs
    actual_args = get_args(cli_inputs)
    # msg is a plain string in params_input_bad, so msg[0] would be just "P".
    # Match the whole message, escaped because it contains regex metacharacters.
    with pytest.raises(ValueError, match=re.escape(msg)):
        set_input_files(actual_args)


# Pass files to loadData and use it to check if file is valid or not
def test_loadData_with_input_files(user_filesystem):
    """Check that loadData accepts the good data files and raises ValueError on the bad ones."""
    for good_file in ("good_data.chi", "good_data.xy", "good_data.txt"):
        loadData(good_file, unpack=True)
    for bad_file in ("unreadable_file.txt", "binary.pkl"):
        with pytest.raises(ValueError):
            loadData(bad_file, unpack=True)
actual_args.output_directory = set_output_directory(actual_args) @@ -45,7 +154,7 @@ def test_set_output_directory_bad(user_filesystem): @pytest.mark.parametrize("inputs, expected", params2) def test_set_wavelength(inputs, expected): expected_wavelength = expected[0] - cli_inputs = ["2.5"] + inputs + cli_inputs = ["2.5", "data.xy"] + inputs actual_args = get_args(cli_inputs) actual_args.wavelength = set_wavelength(actual_args) assert actual_args.wavelength == expected_wavelength @@ -69,7 +178,7 @@ def test_set_wavelength(inputs, expected): @pytest.mark.parametrize("inputs, msg", params3) def test_set_wavelength_bad(inputs, msg): - cli_inputs = ["2.5"] + inputs + cli_inputs = ["2.5", "data.xy"] + inputs actual_args = get_args(cli_inputs) with pytest.raises(ValueError, match=re.escape(msg[0])): actual_args.wavelength = set_wavelength(actual_args) @@ -87,12 +196,12 @@ def test_set_wavelength_bad(inputs, msg): @pytest.mark.parametrize("inputs, expected", params5) def test_load_user_metadata(inputs, expected): - expected_args = get_args(["2.5"]) + expected_args = get_args(["2.5", "data.xy"]) for expected_pair in expected: setattr(expected_args, expected_pair[0], expected_pair[1]) delattr(expected_args, "user_metadata") - cli_inputs = ["2.5"] + inputs + cli_inputs = ["2.5", "data.xy"] + inputs actual_args = get_args(cli_inputs) actual_args = load_user_metadata(actual_args) assert actual_args == expected_args @@ -129,7 +238,7 @@ def test_load_user_metadata(inputs, expected): @pytest.mark.parametrize("inputs, msg", params6) def test_load_user_metadata_bad(inputs, msg): - cli_inputs = ["2.5"] + inputs + cli_inputs = ["2.5", "data.xy"] + inputs actual_args = get_args(cli_inputs) with pytest.raises(ValueError, match=msg[0]): actual_args = load_user_metadata(actual_args) diff --git a/src/diffpy/labpdfproc/tools.py b/src/diffpy/labpdfproc/tools.py index caa012d..df3139d 100644 --- a/src/diffpy/labpdfproc/tools.py +++ b/src/diffpy/labpdfproc/tools.py @@ -1,9 +1,70 @@ +import glob 
def set_input_files(args):
    """
    Set input directory and files

    Parameters
    ----------
    args argparse.Namespace
        the arguments from the parser, where args.input is the list of input file or directory names

    It is implemented as the following:
    Each input is examined on its own.
    An input that does not exist is skipped.
    If the input is a directory, all files directly inside it are collected.
    If the input is a file, we first try to read it as a file list. When its first line names an
    existing file, every line that names an existing file is collected (relative names are looked up
    from the current working directory). Otherwise the input itself is collected as a data file.
    If nothing was collected, we raise a ValueError telling user to specify at least one valid file
    or directory.

    Returns
    -------
    args argparse.Namespace
        the same namespace, with the list of resolved paths stored in args.input_directory

    Raises
    ------
    ValueError
        if an existing input file cannot be opened, or if no valid input file is found
    """
    input_paths = []
    for input_name in args.input:
        input_path = Path(input_name)
        if not input_path.exists():
            # Missing inputs are skipped; an error is raised only if nothing valid remains.
            continue
        if input_path.is_dir():
            input_paths.extend(_files_in_directory(input_path))
        else:
            input_paths.extend(_expand_input_file(input_path))

    if not input_paths:
        raise ValueError("Please specify at least one valid input file or directory.")

    args.input_directory = input_paths
    return args


def _expand_input_file(input_path):
    """Return the data files named by input_path, which is either a file list or a data file."""
    try:
        with open(input_path, "r") as f:
            # The first line alone decides whether this is a file list, so a data
            # file is never read past its first line.
            first_line = f.readline().strip()
            if not os.path.isfile(first_line):
                # Covers empty files too: isfile("") is False.
                return [input_path.resolve()]
            lines = [first_line] + [line.strip() for line in f]
    except UnicodeDecodeError:
        # Not decodable as text, so it cannot be a file list; pass it on as a data file.
        return [input_path.resolve()]
    except OSError as e:
        raise ValueError(f"Cannot read {input_path}: {e}.") from e

    return [Path(line).resolve() for line in lines if os.path.isfile(line)]


def _files_in_directory(input_dir):
    """Return the resolved paths of the files directly inside input_dir, in sorted order."""
    # Escape the directory part so characters such as "[" in its name are taken literally.
    pattern = os.path.join(glob.escape(str(input_dir.resolve())), "*")
    return [Path(name).resolve() for name in sorted(glob.glob(pattern)) if os.path.isfile(name)]