Skip to content

load input directory and files #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
77 changes: 45 additions & 32 deletions src/diffpy/labpdfproc/labpdfprocapp.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,13 @@
from pathlib import Path

from diffpy.labpdfproc.functions import apply_corr, compute_cve
from diffpy.labpdfproc.tools import known_sources, load_user_metadata, set_output_directory, set_wavelength
from diffpy.labpdfproc.tools import (
known_sources,
load_user_metadata,
set_input_files,
set_output_directory,
set_wavelength,
)
from diffpy.utils.parsers.loaddata import loadData
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object

@@ -76,45 +82,52 @@ def get_args(override_cli_inputs=None):

def main():
args = get_args()
args = set_input_files(args)
args.output_directory = set_output_directory(args)
args.wavelength = set_wavelength(args)
args = load_user_metadata(args)

filepath = Path(args.input_file)
outfilestem = filepath.stem + "_corrected"
corrfilestem = filepath.stem + "_cve"
outfile = args.output_directory / (outfilestem + ".chi")
corrfile = args.output_directory / (corrfilestem + ".chi")
for input_file in args.input_file:
filepath = Path(args.input_file)
outfilestem = filepath.stem + "_corrected"
corrfilestem = filepath.stem + "_cve"
outfile = args.output_directory / (outfilestem + ".chi")
corrfile = args.output_directory / (corrfilestem + ".chi")

if outfile.exists() and not args.force_overwrite:
sys.exit(
f"Output file {str(outfile)} already exists. Please rerun "
f"specifying -f if you want to overwrite it."
)
if corrfile.exists() and args.output_correction and not args.force_overwrite:
sys.exit(
f"Corrections file {str(corrfile)} was requested and already "
f"exists. Please rerun specifying -f if you want to overwrite it."
)
if outfile.exists() and not args.force_overwrite:
sys.exit(
f"Output file {str(outfile)} already exists. Please rerun "
f"specifying -f if you want to overwrite it."
)
if corrfile.exists() and args.output_correction and not args.force_overwrite:
sys.exit(
f"Corrections file {str(corrfile)} was requested and already "
f"exists. Please rerun specifying -f if you want to overwrite it."
)

input_pattern = Diffraction_object(wavelength=args.wavelength)
xarray, yarray = loadData(args.input_file, unpack=True)
input_pattern.insert_scattering_quantity(
xarray,
yarray,
"tth",
scat_quantity="x-ray",
name=str(args.input_file),
metadata={"muD": args.mud, "anode_type": args.anode_type},
)
input_pattern = Diffraction_object(wavelength=args.wavelength)

try:
xarray, yarray = loadData(args.input_file, unpack=True)
except Exception as e:
raise ValueError(f"Failed to load data from {filepath}: {e}.")

input_pattern.insert_scattering_quantity(
xarray,
yarray,
"tth",
scat_quantity="x-ray",
name=str(args.input_file),
metadata={"muD": args.mud, "anode_type": args.anode_type},
)

absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
corrected_data = apply_corr(input_pattern, absorption_correction)
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
corrected_data.dump(f"{outfile}", xtype="tth")
absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
corrected_data = apply_corr(input_pattern, absorption_correction)
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
corrected_data.dump(f"{outfile}", xtype="tth")

if args.output_correction:
absorption_correction.dump(f"{corrfile}", xtype="tth")
if args.output_correction:
absorption_correction.dump(f"{corrfile}", xtype="tth")


if __name__ == "__main__":
45 changes: 44 additions & 1 deletion src/diffpy/labpdfproc/tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,50 @@
import pytest

from diffpy.labpdfproc.labpdfprocapp import get_args
from diffpy.labpdfproc.tools import known_sources, load_user_metadata, set_output_directory, set_wavelength
from diffpy.labpdfproc.tools import (
known_sources,
load_user_metadata,
set_input_files,
set_output_directory,
set_wavelength,
)
from diffpy.utils.parsers.loaddata import loadData

params1 = [
(
["--input-file", "."],
[
".",
["good_data.chi", "good_data.xy", "good_data.txt", "unreadable_file.txt", "binary.pkl"],
],
),
(["--input-file", "good_data.chi"], [".", "good_data.chi"]),
(["--input-file", "input_dir/unreadable_file.txt"], ["input_dir", "unreadable_file.txt"]),
# ([Path.cwd()], [Path.cwd()]),
]


@pytest.mark.parametrize("inputs, expected", params1)
def test_set_input_files(inputs, expected, user_filesystem):
    """
    Check that set_input_files stores the expected input directory and file name(s).

    Each parametrized case supplies the extra command-line tokens (``inputs``) and a
    two-item ``expected`` list: the directory relative to ``user_filesystem`` and the
    expected value of ``args.input_file``.
    """
    # "2.5" is passed as the leading positional command-line argument in every case.
    parsed_args = set_input_files(get_args(["2.5"] + inputs))

    assert parsed_args.input_directory == Path(user_filesystem) / expected[0]
    # Compared as sets, so the order of the discovered files is not checked.
    assert set(parsed_args.input_file) == set(expected[1])


def test_loadData_with_input_files(user_filesystem):
    """
    Check which of the fixture files loadData can parse into two columns.

    The three ``good_data`` files must load and unpack into exactly two arrays;
    the unreadable text file and the binary pickle must raise ``ValueError``.
    """
    # NOTE(review): the file names are relative, so this presumably relies on the
    # user_filesystem fixture making them resolvable -- confirm against the fixture.
    for readable_name in ("good_data.chi", "good_data.xy", "good_data.txt"):
        _xarray, _yarray = loadData(readable_name, unpack=True)

    for rejected_name in ("unreadable_file.txt", "binary.pkl"):
        with pytest.raises(ValueError):
            _xarray, _yarray = loadData(rejected_name, unpack=True)


params1 = [
([], ["."]),
32 changes: 32 additions & 0 deletions src/diffpy/labpdfproc/tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,41 @@
import glob
import os
from pathlib import Path

WAVELENGTHS = {"Mo": 0.71, "Ag": 0.59, "Cu": 1.54}
known_sources = [key for key in WAVELENGTHS.keys()]


def set_input_files(args):
    """
    Set the input directory and the input file name(s) on the parsed arguments.

    Parameters
    ----------
    args argparse.Namespace
        the arguments from the parser; ``args.input_file`` must be the path of an
        existing file or directory

    Returns
    -------
    args argparse.Namespace
        the same namespace, modified in place:

        * ``args.input_directory`` is the resolved directory when a directory was
          given, otherwise the current working directory joined with the parent
          of the given file path
        * ``args.input_file`` is the bare file name (a ``str``) when a single file
          was given, otherwise a sorted list of the names of the regular,
          non-hidden files located directly inside the directory

    Raises
    ------
    ValueError
        if ``args.input_file`` is empty or does not exist on disk
    """
    # Test for an empty value first so that Path() is never built from None.
    if not args.input_file:
        raise ValueError("Please specify valid input file or directory.")
    input_path = Path(args.input_file)
    if not input_path.exists():
        raise ValueError("Please specify valid input file or directory.")

    if input_path.is_dir():
        input_dir = input_path.resolve()
        # Escape the directory part so that characters such as "[" or "*" in its
        # name are matched literally instead of being read as glob syntax.
        pattern = os.path.join(glob.escape(str(input_dir)), "*")
        # Only direct children are listed; sub-directories are dropped by isfile.
        # Sorting makes the processing order independent of the file system.
        args.input_file = sorted(
            os.path.basename(entry) for entry in glob.glob(pattern) if os.path.isfile(entry)
        )
    else:
        input_dir = Path.cwd() / input_path.parent
        # NOTE(review): a single file is stored as a str, not a one-item list;
        # callers that iterate over args.input_file should confirm this is intended.
        args.input_file = input_path.name

    args.input_directory = input_dir
    return args
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

otherwise, I think this looks good.



def set_output_directory(args):
"""
set the output directory based on the given input arguments