Skip to content

Commit b982382

Browse files
committed
updated main, fixed io
1 parent dfa4910 commit b982382

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

diffpy/snmf/io.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,16 @@ def initialize_variables(data_input, component_amount, data_type, sparsity=1, sm
4444

4545
component_matrix_guess = np.random.rand(signal_length, component_amount)
4646
weight_matrix_guess = np.random.rand(component_amount, moment_amount)
47-
stretching_matrix_guess = np.ones(component_amount, moment_amount) + np.random.randn(component_amount,
47+
stretching_matrix_guess = np.ones((component_amount, moment_amount)) + np.random.randn(component_amount,
4848
moment_amount) * 1e-3
4949

5050
diagonals = [np.ones(moment_amount - 2), -2 * np.ones(moment_amount - 2), np.ones(moment_amount - 2)]
5151
smoothness_term = .25 * scipy.sparse.diags(diagonals, [0, 1, 2], shape=(moment_amount - 2, moment_amount))
5252

5353
hessian_helper_matrix = scipy.sparse.block_diag([smoothness_term.T @ smoothness_term] * component_amount)
5454
sequence = np.arange(moment_amount * component_amount).reshape(component_amount, moment_amount).T.flatten()
55+
56+
hessian_helper_matrix = hessian_helper_matrix.tocsr()
5557
hessian_helper_matrix = hessian_helper_matrix[sequence, :][:, sequence]
5658

5759
return {
@@ -102,7 +104,7 @@ def load_input_signals(file_path=None):
102104
for item in directory_path.iterdir():
103105
if item.is_file():
104106
data = loadData(item.resolve())
105-
if current_grid and current_grid != data[:, 0]:
107+
if len(current_grid) != 0 and (current_grid != data[:, 0]).any():
106108
print(f"{item.name} was ignored as it is not on a compatible grid.")
107109
continue
108110
else:

diffpy/snmf/stretchednmfapp.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
1+
import numpy as np
2+
3+
from diffpy.snmf.io import load_input_signals, initialize_variables
4+
5+
16
def main():
2-
print("Hello World!")
7+
directory_path = input("Specify Path (Optional. Press enter to skip):")
8+
if not directory_path:
9+
directory_path = None
10+
11+
data_type = input("Specify the data type ('xrd' or 'pdf'): ")
12+
if data_type != 'xrd' and data_type != 'pdf':
13+
raise ValueError("The data type must be 'xrd' or 'pdf'")
14+
15+
component_amount = input("\nEnter the amount of components to obtain:")
16+
try:
17+
component_amount = int(component_amount)
18+
except TypeError:
19+
raise TypeError("Please enter an integer greater than 0")
20+
21+
grid, data_input = load_input_signals(directory_path)
22+
variables = initialize_variables(data_input, component_amount, data_type)
23+
lifted_data = data_input - np.ndarray.min(data_input[:])
324

425

526
if __name__ == "__main__":

0 commit comments

Comments
 (0)