|
| 1 | +import os |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +from matplotlib.axes import Axes |
| 5 | +from sklearn.decomposition import PCA |
| 6 | +from xrdpattern.pattern import PatternDB |
| 7 | +from xrdpattern.pattern import XrdPattern |
| 8 | +from matplotlib import pyplot as plt |
| 9 | + |
| 10 | +from opxrd import OpXRD |
| 11 | +# ---------------------------------------------------------- |
| 12 | + |
| 13 | +class DatabaseAnalyser: |
| 14 | + def __init__(self, databases : list[PatternDB], output_dirpath : str): |
| 15 | + if len(databases) == 0: |
| 16 | + raise ValueError('No databases provided') |
| 17 | + self.databases : list[PatternDB] = databases |
| 18 | + self.joined_db : PatternDB = PatternDB.merge(databases) |
| 19 | + self.output_dirpath : str = output_dirpath |
| 20 | + os.makedirs(self.output_dirpath, exist_ok=True) |
| 21 | + |
| 22 | + @staticmethod |
| 23 | + def plot_in_single(patterns : list[XrdPattern]): |
| 24 | + data = [p.get_pattern_data() for p in patterns] |
| 25 | + fig, ax = plt.subplots() |
| 26 | + for x, y in data: |
| 27 | + ax.plot(x, y, linewidth=0.1) |
| 28 | + |
| 29 | + ax.set_xlabel('X Label') |
| 30 | + ax.set_ylabel('Y Label') |
| 31 | + ax.set_title('Multiple XY Plots') |
| 32 | + plt.show() |
| 33 | + |
| 34 | + |
| 35 | + def plot_fourier(self, x, y, max_freq=10): |
| 36 | + N = len(y) # Number of sample points |
| 37 | + T = (x[-1] - x[0]) / (N - 1) # Sample spacing |
| 38 | + yf = np.fft.fft(y) # Perform FFT |
| 39 | + xf = np.fft.fftfreq(N, T)[:N // 2] # Frequency axis |
| 40 | + |
| 41 | + magnitude = 2.0 / N * np.abs(yf[:N // 2]) |
| 42 | + fig, ax = plt.subplots(figsize=(10, 4)) |
| 43 | + self._set_ax_properties(ax, title='Fourier Transform', xlabel='Frequency (Hz)', ylabel='Magnitude') |
| 44 | + ax.grid(True) |
| 45 | + |
| 46 | + if max_freq is not None: |
| 47 | + valid_indices = xf <= max_freq |
| 48 | + plt.plot(xf[valid_indices], magnitude[valid_indices]) |
| 49 | + else: |
| 50 | + plt.plot(xf, magnitude) |
| 51 | + |
| 52 | + plt.show() |
| 53 | + |
| 54 | + def plot_pattern_dbs(self, title : str): |
| 55 | + combined_pattern_list = self.get_all_patterns() |
| 56 | + xy_list = [p.get_pattern_data() for p in combined_pattern_list] |
| 57 | + combined_y_list = [y for x, y in xy_list] |
| 58 | + |
| 59 | + pca = PCA(n_components=2) |
| 60 | + transformed_data = pca.fit_transform(combined_y_list) |
| 61 | + |
| 62 | + rand_indices = [np.random.randint(low=0, high=len(combined_y_list)) for _ in range(10)] |
| 63 | + example_xy_list = [combined_pattern_list[idx].get_pattern_data() for idx in rand_indices] |
| 64 | + example_pca_coords = [transformed_data[idx] for idx in rand_indices] |
| 65 | + |
| 66 | + self._plot_pca_scatter(transformed_data, title=title) |
| 67 | + self._plot_pca_basis(pca, title=title) |
| 68 | + self._plot_reconstructed(pca, example_xy_list, example_pca_coords, title=title) |
| 69 | + print('done') |
| 70 | + |
| 71 | + # ----------------------- |
| 72 | + # tools |
| 73 | + |
| 74 | + def _plot_pca_scatter(self, transformed_data, title : str): |
| 75 | + db_lens = [len(db.patterns) for db in self.databases] |
| 76 | + max_points = 50 |
| 77 | + for j, l in enumerate(db_lens): |
| 78 | + partial = transformed_data[:l] |
| 79 | + if l > max_points: |
| 80 | + indices = np.random.choice(len(partial), size=max_points, replace=False) |
| 81 | + partial = partial[indices] |
| 82 | + plt.scatter(partial[:, 0], partial[:, 1], label=f'db number {j}') |
| 83 | + transformed_data = transformed_data[l:] |
| 84 | + |
| 85 | + plt.title(f'(1): Two Component PCA Scatter Plot for {title}') |
| 86 | + plt.xlabel('Component 1') |
| 87 | + plt.ylabel('Component 2') |
| 88 | + plt.legend() |
| 89 | + plt.savefig(f'{self.output_dirpath}pca_scatter_{title}.png') |
| 90 | + plt.show() |
| 91 | + |
| 92 | + def _plot_pca_basis(self, pca, title : str): |
| 93 | + b1, b2 = pca.inverse_transform(np.array([1,0])), pca.inverse_transform(np.array([0,1])) |
| 94 | + x = np.linspace(start=0,stop=180, num=len(b1)) |
| 95 | + plt.plot(x, b1) |
| 96 | + plt.plot(x, b2) |
| 97 | + plt.title(f'(2): Principal Components for {title}') |
| 98 | + plt.savefig(f'{self.output_dirpath}pca_basis_{title}.png') |
| 99 | + plt.show() |
| 100 | + |
| 101 | + def _plot_reconstructed(self, pca, example_xy_list, example_pca_coords, title): |
| 102 | + fig, axs = plt.subplots(len(example_xy_list), 2, figsize=(10, 5 * len(example_xy_list))) |
| 103 | + fig.suptitle(f'(3): Comparison of Original and Reconstructed Patterns for {title}', fontsize=16) |
| 104 | + for index, ((x1, y1), pca_coords) in enumerate(zip(example_xy_list, example_pca_coords)): |
| 105 | + axs[index, 0].plot(x1, y1, 'b-') |
| 106 | + self._set_ax_properties(axs[index, 0], title='Original pattern', xlabel='x', ylabel='Relative intensity') |
| 107 | + |
| 108 | + reconstructed = pca.inverse_transform(pca_coords) |
| 109 | + x = np.linspace(start=0, stop=180, num=len(reconstructed)) |
| 110 | + axs[index, 1].plot(x, reconstructed, 'r-') |
| 111 | + self._set_ax_properties(axs[index, 1], title='Reconstructed pattern', xlabel='x', |
| 112 | + ylabel='Relative intensity') |
| 113 | + |
| 114 | + plt.tight_layout() |
| 115 | + plt.savefig(f'{self.output_dirpath}reconstructed_{title}.png') |
| 116 | + plt.show() |
| 117 | + |
| 118 | + @staticmethod |
| 119 | + def _set_ax_properties(ax : Axes, title : str, xlabel : str, ylabel : str): |
| 120 | + ax.set_title(title) |
| 121 | + ax.set_xlabel(xlabel) |
| 122 | + ax.set_ylabel(ylabel) |
| 123 | + |
| 124 | + def get_all_patterns(self) -> list[XrdPattern]: |
| 125 | + return self.joined_db.patterns |
| 126 | + |
| 127 | + |
| 128 | +if __name__ == "__main__": |
| 129 | + test_dirpath = '/tmp/opxrd_test' |
| 130 | + full_dirpath = '/tmp/opxrd' |
| 131 | + |
| 132 | + opxrd_databases = OpXRD.as_database_list(root_dirpath=full_dirpath) |
| 133 | + analyser = DatabaseAnalyser(databases=opxrd_databases, output_dirpath='/tmp/opxrd_analysis') |
| 134 | + |
| 135 | + opxrd= OpXRD.load(root_dirpath=test_dirpath) |
| 136 | + opxrd.show_histograms(save_fpath=f'/tmp/quantities_hist.png',attach_colorbar=False) |
| 137 | + |
0 commit comments