diff --git a/opxrd/analyser.py b/opxrd/analyser.py index cf119a3..1b1cea6 100644 --- a/opxrd/analyser.py +++ b/opxrd/analyser.py @@ -5,11 +5,15 @@ from matplotlib.axes import Axes from numpy.typing import NDArray from sklearn.decomposition import PCA +from tabulate import tabulate + from xrdpattern.pattern import PatternDB from xrdpattern.pattern import XrdPattern from matplotlib import pyplot as plt from opxrd import OpXRD +from xrdpattern.xrd import LabelType + # ---------------------------------------------------------- class DatabaseAnalyser: @@ -78,27 +82,36 @@ def plot_pattern_dbs(self, title : str): self._plot_reconstructed(pca, example_xy_list, example_pca_coords, title=title) print('done') - def compute_effective_components(self, tolerance : float = 0.10): + def plot_effective_components(self): for db in self.databases: max_components = len(db.patterns) standardized_intensities = [p.get_pattern_data()[1] for p in db.patterns] pca = PCA(n_components=max_components) - pca_coords = pca.fit_transform(standardized_intensities) + db_pca_coords = pca.fit_transform(standardized_intensities) self._plot_reconstructed(pca, example_xy_list=[p.get_pattern_data() for p in db.patterns[:20]], - example_pca_coords=pca_coords[:20], title=db.name) + example_pca_coords=db_pca_coords[:20], title=db.name) - for n_comp in range(max_components): + plot_components = 10 + accuracies = [] + x = np.linspace(0,plot_components/max_components, num=plot_components) + for n_comp in range(plot_components): mismatches = [] for j, p in enumerate(db.patterns): _, i1 = p.get_pattern_data() - i2 = pca.inverse_transform(pca_coords[j]) + limited_pca = db_pca_coords[j][:n_comp] + zero_padded_comp = np.pad(limited_pca, (0, max_components - n_comp)) + i2 = pca.inverse_transform(zero_padded_comp) mismatch = self.compute_mismatch(i1, i2) mismatches.append(mismatch) - avg_mismatch = np.mean(mismatches) - if avg_mismatch < tolerance: - print(f'Database {db.name} has {n_comp} effective components') - break + accuracy = 1-np.mean(mismatches) + accuracies.append(accuracy) + print(f'Computed accuracy for {db.name} with {n_comp} components = {accuracy}') + + plt.plot(x,accuracies) + plt.title(f'Accuracy vs fraction of max components for {db.name}') + + plt.show() # ----------------------- # tools @@ -177,12 +190,41 @@ def compute_mismatch(i1 : NDArray, i2 : NDArray) -> float: mismatch = delta_norm / norm_original return mismatch + def show_label_fractions(self): + table_data = [] + for d in self.databases: + label_counts = {l: 0 for l in LabelType} + patterns = d.patterns + for l in LabelType: + for p in patterns: + if p.has_label(label_type=l): + label_counts[l] += 1 + db_percentages = [label_counts[l] / len(patterns) for l in LabelType] + table_data.append(db_percentages) + + col_headers = [label.name for label in LabelType] + row_headers = [db.name for db in self.databases] + + table = tabulate(table_data, headers=col_headers, showindex=row_headers, tablefmt='psql') + print(table) + + def print_total_counts(self): + num_total = len(self.get_all_patterns()) + + labeled_patterns = [p for p in self.get_all_patterns() if p.is_labeled()] + num_labelel = len(labeled_patterns) + print(f'Total number of patterns = {num_total}') + print(f'Number of labeled patterns = {num_labelel}') + if __name__ == "__main__": test_dirpath = '/tmp/opxrd_test' full_dirpath = '/home/daniel/aimat/data/opXRD/final/' opxrd_databases = OpXRD.as_database_list(root_dirpath=test_dirpath) analyser = DatabaseAnalyser(databases=opxrd_databases, output_dirpath='/tmp/opxrd_analysis') + analyser.plot_effective_components() + # analyser.plot_databases_in_single() # analyser.compute_effective_components() - analyser.plot_fourier() \ No newline at end of file + # analyser.plot_fourier() + # analyser.print_total_counts() \ No newline at end of file