Skip to content

Commit

Permalink
opxrd/analyser: Added plot_effective_components
Browse files Browse the repository at this point in the history
  • Loading branch information
Somerandomguy10111 committed Dec 14, 2024
1 parent c0cc7c3 commit b6a108d
Showing 1 changed file with 52 additions and 10 deletions.
62 changes: 52 additions & 10 deletions opxrd/analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
from matplotlib.axes import Axes
from numpy.typing import NDArray
from sklearn.decomposition import PCA
from tabulate import tabulate

from xrdpattern.pattern import PatternDB
from xrdpattern.pattern import XrdPattern
from matplotlib import pyplot as plt

from opxrd import OpXRD
from xrdpattern.xrd import LabelType

# ----------------------------------------------------------

class DatabaseAnalyser:
Expand Down Expand Up @@ -78,27 +82,36 @@ def plot_pattern_dbs(self, title : str):
self._plot_reconstructed(pca, example_xy_list, example_pca_coords, title=title)
print('done')

def compute_effective_components(self, tolerance : float = 0.10):
def plot_effective_components(self):
for db in self.databases:
max_components = len(db.patterns)
standardized_intensities = [p.get_pattern_data()[1] for p in db.patterns]
pca = PCA(n_components=max_components)
pca_coords = pca.fit_transform(standardized_intensities)
db_pca_coords = pca.fit_transform(standardized_intensities)

self._plot_reconstructed(pca, example_xy_list=[p.get_pattern_data() for p in db.patterns[:20]],
example_pca_coords=pca_coords[:20], title=db.name)
example_pca_coords=db_pca_coords[:20], title=db.name)

for n_comp in range(max_components):
plot_components = 10
accuracies = []
x = np.linspace(0,plot_components/max_components, num=plot_components)
for n_comp in range(plot_components):
mismatches = []
for j, p in enumerate(db.patterns):
_, i1 = p.get_pattern_data()
i2 = pca.inverse_transform(pca_coords[j])
limited_pca = db_pca_coords[j][:n_comp]
zero_padded_comp = np.pad(limited_pca, (0, max_components - n_comp))
i2 = pca.inverse_transform(zero_padded_comp)
mismatch = self.compute_mismatch(i1, i2)
mismatches.append(mismatch)
avg_mismatch = np.mean(mismatches)
if avg_mismatch < tolerance:
print(f'Database {db.name} has {n_comp} effective components')
break
accuracy = 1-np.mean(mismatches)
accuracies.append(accuracy)
print(f'Computed accuracy for {db.name} with {n_comp} components = {accuracy}')

plt.plot(x,accuracies)
plt.title(f'Accuracy vs fraction of max components for {db.name}')

plt.show()

# -----------------------
# tools
Expand Down Expand Up @@ -177,12 +190,41 @@ def compute_mismatch(i1 : NDArray, i2 : NDArray) -> float:
mismatch = delta_norm / norm_original
return mismatch

def show_label_fractions(self):
table_data = []
for d in self.databases:
label_counts = {l: 0 for l in LabelType}
patterns = d.patterns
for l in LabelType:
for p in patterns:
if p.has_label(label_type=l):
label_counts[l] += 1
db_percentages = [label_counts[l] / len(patterns) for l in LabelType]
table_data.append(db_percentages)

col_headers = [label.name for label in LabelType]
row_headers = [db.name for db in self.databases]

table = tabulate(table_data, headers=col_headers, showindex=row_headers, tablefmt='psql')
print(table)

def print_total_counts(self):
num_total = len(self.get_all_patterns())

labeled_patterns = [p for p in self.get_all_patterns() if p.is_labeled()]
num_labelel = len(labeled_patterns)
print(f'Total number of patterns = {num_total}')
print(f'Number of labeled patterns = {num_labelel}')


if __name__ == "__main__":
test_dirpath = '/tmp/opxrd_test'
full_dirpath = '/home/daniel/aimat/data/opXRD/final/'
opxrd_databases = OpXRD.as_database_list(root_dirpath=test_dirpath)
analyser = DatabaseAnalyser(databases=opxrd_databases, output_dirpath='/tmp/opxrd_analysis')
analyser.plot_effective_components()

# analyser.plot_databases_in_single()
# analyser.compute_effective_components()
analyser.plot_fourier()
# analyser.plot_fourier()
# analyser.print_total_counts()

0 comments on commit b6a108d

Please sign in to comment.