Skip to content

Commit 6d8be90

Browse files
committed
added create_parser function to handle application input
1 parent b982382 commit 6d8be90

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

diffpy/snmf/stretchednmfapp.py

+23-16
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,33 @@
11
import numpy as np
2+
import argparse
23

34
from diffpy.snmf.io import load_input_signals, initialize_variables
45

56

7+
def create_parser():
8+
parser = argparse.ArgumentParser(
9+
prog="stretched_nmf",
10+
description="Stretched Nonnegative Matrix Factorization"
11+
)
12+
parser.add_argument('-v', '--version', action='version', help='Print the software version number')
13+
parser.add_argument('-d', '--directory', type=str,
14+
help="Directory containing experimental data. Ensure it is in quotations or apostrophes.")
15+
16+
parser.add_argument('component_number', type=int,
17+
help="The number of component signals to obtain from experimental "
18+
"data. Must be an integer greater than 0.")
19+
parser.add_argument('data_type', type=str, choices=['xrd', 'pdf'], help="The type of the experimental data.")
20+
args = parser.parse_args()
21+
return args
22+
23+
624
def main():
7-
directory_path = input("Specify Path (Optional. Press enter to skip):")
8-
if not directory_path:
9-
directory_path = None
10-
11-
data_type = input("Specify the data type ('xrd' or 'pdf'): ")
12-
if data_type != 'xrd' and data_type != 'pdf':
13-
raise ValueError("The data type must be 'xrd' or 'pdf'")
14-
15-
component_amount = input("\nEnter the amount of components to obtain:")
16-
try:
17-
component_amount = int(component_amount)
18-
except TypeError:
19-
raise TypeError("Please enter an integer greater than 0")
20-
21-
grid, data_input = load_input_signals(directory_path)
22-
variables = initialize_variables(data_input, component_amount, data_type)
25+
args = create_parser()
26+
27+
grid, data_input = load_input_signals(args.directory)
28+
variables = initialize_variables(data_input, args.component_number, args.data_type)
2329
lifted_data = data_input - np.ndarray.min(data_input[:])
30+
return lifted_data
2431

2532

2633
if __name__ == "__main__":

0 commit comments

Comments
 (0)