Skip to content

load input directory and files #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
79 changes: 46 additions & 33 deletions src/diffpy/labpdfproc/labpdfprocapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@
from pathlib import Path

from diffpy.labpdfproc.functions import apply_corr, compute_cve
from diffpy.labpdfproc.tools import known_sources, load_user_metadata, set_output_directory, set_wavelength
from diffpy.labpdfproc.tools import (
known_sources,
load_user_metadata,
set_input_files,
set_output_directory,
set_wavelength,
)
from diffpy.utils.parsers.loaddata import loadData
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object


def get_args(override_cli_inputs=None):
p = ArgumentParser()
p.add_argument("mud", help="Value of mu*D for your " "sample. Required.", type=float)
p.add_argument("-i", "--input-file", help="The filename of the " "datafile to load.")
p.add_argument("input", help="The filename or directory of the " "datafile to load.")
p.add_argument(
"-a",
"--anode-type",
Expand Down Expand Up @@ -76,45 +82,52 @@ def get_args(override_cli_inputs=None):

def main():
args = get_args()
args = set_input_files(args)
args.output_directory = set_output_directory(args)
args.wavelength = set_wavelength(args)
args = load_user_metadata(args)

filepath = Path(args.input_file)
outfilestem = filepath.stem + "_corrected"
corrfilestem = filepath.stem + "_cve"
outfile = args.output_directory / (outfilestem + ".chi")
corrfile = args.output_directory / (corrfilestem + ".chi")
for input_file in args.input_file:
filepath = Path(args.input_file)
outfilestem = filepath.stem + "_corrected"
corrfilestem = filepath.stem + "_cve"
outfile = args.output_directory / (outfilestem + ".chi")
corrfile = args.output_directory / (corrfilestem + ".chi")

if outfile.exists() and not args.force_overwrite:
sys.exit(
f"Output file {str(outfile)} already exists. Please rerun "
f"specifying -f if you want to overwrite it."
)
if corrfile.exists() and args.output_correction and not args.force_overwrite:
sys.exit(
f"Corrections file {str(corrfile)} was requested and already "
f"exists. Please rerun specifying -f if you want to overwrite it."
)
if outfile.exists() and not args.force_overwrite:
sys.exit(
f"Output file {str(outfile)} already exists. Please rerun "
f"specifying -f if you want to overwrite it."
)
if corrfile.exists() and args.output_correction and not args.force_overwrite:
sys.exit(
f"Corrections file {str(corrfile)} was requested and already "
f"exists. Please rerun specifying -f if you want to overwrite it."
)

input_pattern = Diffraction_object(wavelength=args.wavelength)
xarray, yarray = loadData(args.input_file, unpack=True)
input_pattern.insert_scattering_quantity(
xarray,
yarray,
"tth",
scat_quantity="x-ray",
name=str(args.input_file),
metadata={"muD": args.mud, "anode_type": args.anode_type},
)
input_pattern = Diffraction_object(wavelength=args.wavelength)

try:
xarray, yarray = loadData(args.input_file, unpack=True)
except Exception as e:
raise ValueError(f"Failed to load data from {filepath}: {e}.")

input_pattern.insert_scattering_quantity(
xarray,
yarray,
"tth",
scat_quantity="x-ray",
name=str(args.input_file),
metadata={"muD": args.mud, "anode_type": args.anode_type},
)

absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
corrected_data = apply_corr(input_pattern, absorption_correction)
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
corrected_data.dump(f"{outfile}", xtype="tth")
absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
corrected_data = apply_corr(input_pattern, absorption_correction)
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
corrected_data.dump(f"{outfile}", xtype="tth")

if args.output_correction:
absorption_correction.dump(f"{corrfile}", xtype="tth")
if args.output_correction:
absorption_correction.dump(f"{corrfile}", xtype="tth")


if __name__ == "__main__":
Expand Down
89 changes: 81 additions & 8 deletions src/diffpy/labpdfproc/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,80 @@
import pytest

from diffpy.labpdfproc.labpdfprocapp import get_args
from diffpy.labpdfproc.tools import known_sources, load_user_metadata, set_output_directory, set_wavelength
from diffpy.labpdfproc.tools import (
known_sources,
load_user_metadata,
set_input_files,
set_output_directory,
set_wavelength,
)
from diffpy.utils.parsers.loaddata import loadData

# Use cases can be found here: https://github.com/diffpy/diffpy.labpdfproc/issues/48
params_input = [
(["good_data.chi"], [".", "good_data.chi"]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to have a discussion, by comment thread of zoom if it is easier, about what we want the program to do given the inputs and therefore what we want to test. It seems to me that these tests are just testing something to do with metadata, but this PR is about much more than that. Here is a start:

single-file case:

  1. check the file exists
  2. read the file
  3. if valid, compute the cve and process the data
  4. if unreadable, error with helpful message
  5. find the absolute path and store it in metadata
  6. write this into the output file header

We want to make sure all these things are tested. Some will be tested by other functions. But we need tests here for all the things that won't be covered by other functions.

Then we would like a similar list for teh other cases (a list of files, a glob....) and tests for those too. Please can you think about this and have a crack at it.

(["input_dir/good_data.chi"], ["input_dir", "good_data.chi"]),
(["./input_dir/good_data.chi"], ["input_dir", "good_data.chi"]),
(
["."],
[
".",
["good_data.chi", "good_data.xy", "good_data.txt", "unreadable_file.txt", "binary.pkl"],
],
),
(
["./input_dir"],
[
"input_dir",
["good_data.chi", "good_data.xy", "good_data.txt", "unreadable_file.txt", "binary.pkl"],
],
),
(
["input_dir"],
[
"input_dir",
["good_data.chi", "good_data.xy", "good_data.txt", "unreadable_file.txt", "binary.pkl"],
],
),
]


@pytest.mark.parametrize("inputs, expected", params_input)
def test_set_input_files(inputs, expected, user_filesystem):
expected_input_directory = Path(user_filesystem) / expected[0]
expected_input_files = expected[1]

cli_inputs = ["2.5"] + inputs
actual_args = get_args(cli_inputs)
actual_args = set_input_files(actual_args)
assert actual_args.input_directory == expected_input_directory
assert set(actual_args.input_file) == set(expected_input_files)


params_input_bad = [
(["new_file.xy"]),
(["./input_dir/new_file.xy"]),
(["./new_dir"]),
]


@pytest.mark.parametrize("inputs", params_input_bad)
def test_set_input_files_bad(inputs, user_filesystem):
cli_inputs = ["2.5"] + inputs
actual_args = get_args(cli_inputs)
with pytest.raises(ValueError):
actual_args = set_input_files(actual_args)


def test_loadData_with_input_files(user_filesystem):
xarray_chi, yarray_chi = loadData("good_data.chi", unpack=True)
xarray_xy, yarray_xy = loadData("good_data.xy", unpack=True)
xarray_txt, yarray_txt = loadData("good_data.txt", unpack=True)
with pytest.raises(ValueError):
xarray_txt, yarray_txt = loadData("unreadable_file.txt", unpack=True)
with pytest.raises(ValueError):
xarray_pkl, yarray_pkl = loadData("binary.pkl", unpack=True)


params1 = [
([], ["."]),
Expand All @@ -17,7 +90,7 @@
@pytest.mark.parametrize("inputs, expected", params1)
def test_set_output_directory(inputs, expected, user_filesystem):
expected_output_directory = Path(user_filesystem) / expected[0]
cli_inputs = ["2.5"] + inputs
cli_inputs = ["2.5", "data.xy"] + inputs
actual_args = get_args(cli_inputs)
actual_args.output_directory = set_output_directory(actual_args)
assert actual_args.output_directory == expected_output_directory
Expand All @@ -26,7 +99,7 @@ def test_set_output_directory(inputs, expected, user_filesystem):


def test_set_output_directory_bad(user_filesystem):
cli_inputs = ["2.5", "--output-directory", "good_data.chi"]
cli_inputs = ["2.5", "data.xy", "--output-directory", "good_data.chi"]
actual_args = get_args(cli_inputs)
with pytest.raises(FileExistsError):
actual_args.output_directory = set_output_directory(actual_args)
Expand All @@ -45,7 +118,7 @@ def test_set_output_directory_bad(user_filesystem):
@pytest.mark.parametrize("inputs, expected", params2)
def test_set_wavelength(inputs, expected):
expected_wavelength = expected[0]
cli_inputs = ["2.5"] + inputs
cli_inputs = ["2.5", "data.xy"] + inputs
actual_args = get_args(cli_inputs)
actual_args.wavelength = set_wavelength(actual_args)
assert actual_args.wavelength == expected_wavelength
Expand All @@ -69,7 +142,7 @@ def test_set_wavelength(inputs, expected):

@pytest.mark.parametrize("inputs, msg", params3)
def test_set_wavelength_bad(inputs, msg):
cli_inputs = ["2.5"] + inputs
cli_inputs = ["2.5", "data.xy"] + inputs
actual_args = get_args(cli_inputs)
with pytest.raises(ValueError, match=re.escape(msg[0])):
actual_args.wavelength = set_wavelength(actual_args)
Expand All @@ -87,12 +160,12 @@ def test_set_wavelength_bad(inputs, msg):

@pytest.mark.parametrize("inputs, expected", params5)
def test_load_user_metadata(inputs, expected):
expected_args = get_args(["2.5"])
expected_args = get_args(["2.5", "data.xy"])
for expected_pair in expected:
setattr(expected_args, expected_pair[0], expected_pair[1])
delattr(expected_args, "user_metadata")

cli_inputs = ["2.5"] + inputs
cli_inputs = ["2.5", "data.xy"] + inputs
actual_args = get_args(cli_inputs)
actual_args = load_user_metadata(actual_args)
assert actual_args == expected_args
Expand Down Expand Up @@ -129,7 +202,7 @@ def test_load_user_metadata(inputs, expected):

@pytest.mark.parametrize("inputs, msg", params6)
def test_load_user_metadata_bad(inputs, msg):
cli_inputs = ["2.5"] + inputs
cli_inputs = ["2.5", "data.xy"] + inputs
actual_args = get_args(cli_inputs)
with pytest.raises(ValueError, match=msg[0]):
actual_args = load_user_metadata(actual_args)
31 changes: 31 additions & 0 deletions src/diffpy/labpdfproc/tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,40 @@
import glob
import os
from pathlib import Path

WAVELENGTHS = {"Mo": 0.71, "Ag": 0.59, "Cu": 1.54}
known_sources = [key for key in WAVELENGTHS.keys()]


def set_input_files(args):
"""
Set input directory and files

Parameters
----------
args argparse.Namespace
the arguments from the parser

Returns
-------
args argparse.Namespace

"""
if not Path(args.input).exists():
raise ValueError("Please specify valid input file or directory.")

if not Path(args.input).is_dir():
input_dir = Path.cwd() / Path(args.input).parent
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether this will fail if the user gives a filename with a path that doesn't include cwd? For example, I think a valid test would be:

cd /user/me/analysis
labpdfcor 2.5 /user/me/data/my_file.xy

Please could you add test for this situation and make sure it passes?

input_file_name = Path(args.input).name
else:
input_dir = Path(args.input).resolve()
input_files = [file for file in glob.glob(str(input_dir) + "/*", recursive=True) if os.path.isfile(file)]
input_file_name = [os.path.basename(input_file_path) for input_file_path in input_files]
setattr(args, "input_directory", input_dir)
setattr(args, "input_file", input_file_name)
return args
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

otherwise, I think this looks good.



def set_output_directory(args):
"""
set the output directory based on the given input arguments
Expand Down
Loading