diff --git a/diffpy/snmf/io.py b/diffpy/snmf/io.py index f666949..150065d 100644 --- a/diffpy/snmf/io.py +++ b/diffpy/snmf/io.py @@ -2,6 +2,10 @@ import scipy.sparse from pathlib import Path from diffpy.utils.parsers.loaddata import loadData +import matplotlib.pyplot as plt +from bg_mpl_stylesheet.bg_mpl_stylesheet import bg_mpl_style + +plt.style.use(bg_mpl_style) def initialize_variables(data_input, component_amount, data_type, sparsity=1, smoothness=1e18): @@ -114,3 +118,32 @@ def load_input_signals(file_path=None): grid_vector = np.unique(grid_array, axis=1) values_array = np.column_stack(values_list) return grid_vector, values_array + + +def drawfig(moment_amount, stretching_matrix, weight_matrix, grid_vector): + plt.ion() + fig = plt.figure() + grid = plt.GridSpec(4, 4) + stretching_plot = fig.add_subplot(grid[0:2, 0:2]) + weight_plot = fig.add_subplot(grid[2:, 0:2]) + component_plot = fig.add_subplot(grid[:, 2:]) + + stretching_plot.plot(stretching_matrix.T) + stretching_plot.set_title("Component Stretching Factors") + stretching_plot.set_xlabel("Moment") + stretching_plot.set_ylabel("Stretching Factor") + + weight_plot.plot(weight_matrix.T) + weight_plot.set_title("Component Weights") + weight_plot.set_ylabel("Weight") + weight_plot.set_xlabel("Moment") + + component_plot.plot(grid_vector, component_plot) + component_plot.set_title("Component Signals") + component_plot.sex_ylabel("g(r)") + component_plot.set_xlabel("r") + + fig.canvas.draw() + fig.canvas.flush_events() + plt.tight_layout() + plt.show()