Skip to content

Commit c4dec5c

Browse files
opxrd/analysis: Prepared analysis module
1 parent ef11894 commit c4dec5c

File tree

2 files changed

+138
-1
lines changed

2 files changed

+138
-1
lines changed

opxrd/analysis.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import os
2+
3+
import numpy as np
4+
from matplotlib.axes import Axes
5+
from sklearn.decomposition import PCA
6+
from xrdpattern.pattern import PatternDB
7+
from xrdpattern.pattern import XrdPattern
8+
from matplotlib import pyplot as plt
9+
10+
from opxrd import OpXRD
11+
# ----------------------------------------------------------
12+
13+
class DatabaseAnalyser:
    """Plot-based analysis of one or more pattern databases.

    The databases are kept individually (used for per-database colouring in
    the PCA scatter plot) and additionally merged into a single database via
    ``PatternDB.merge``. Figures produced by the PCA helpers are written
    into ``output_dirpath``.
    """

    def __init__(self, databases : list[PatternDB], output_dirpath : str):
        """Store the databases, merge them and create the output directory.

        Args:
            databases: Databases to analyse; must not be empty.
            output_dirpath: Directory that receives the saved figures. It is
                created if it does not exist yet.

        Raises:
            ValueError: If ``databases`` is empty.
        """
        if not databases:
            raise ValueError('No databases provided')
        self.databases : list[PatternDB] = databases
        self.joined_db : PatternDB = PatternDB.merge(databases)
        self.output_dirpath : str = output_dirpath
        os.makedirs(self.output_dirpath, exist_ok=True)

    @staticmethod
    def plot_in_single(patterns : list[XrdPattern]):
        """Draw the (x, y) data of every given pattern into one shared axes."""
        data = [p.get_pattern_data() for p in patterns]
        fig, ax = plt.subplots()
        for x, y in data:
            # Thin lines keep many overlaid curves distinguishable.
            ax.plot(x, y, linewidth=0.1)

        ax.set_xlabel('X Label')
        ax.set_ylabel('Y Label')
        ax.set_title('Multiple XY Plots')
        plt.show()

    def plot_fourier(self, x, y, max_freq=10):
        """Show the one-sided FFT magnitude spectrum of ``y`` sampled at ``x``.

        The sample spacing is derived from the first and last entry of ``x``,
        i.e. ``x`` is treated as evenly spaced.

        Args:
            x: Sample positions; only ``x[0]`` and ``x[-1]`` are read.
            y: Sample values.
            max_freq: Upper frequency limit of the plot, or ``None`` to plot
                the full one-sided spectrum.

        Raises:
            ValueError: If fewer than two samples are given (the sample
                spacing would be undefined).
        """
        n_samples = len(y)
        if n_samples < 2:
            raise ValueError('At least two sample points are required for a Fourier transform')

        spacing = (x[-1] - x[0]) / (n_samples - 1)
        half = n_samples // 2
        yf = np.fft.fft(y)
        # Keep only the non-negative half of the frequency axis.
        xf = np.fft.fftfreq(n_samples, spacing)[:half]
        magnitude = 2.0 / n_samples * np.abs(yf[:half])

        fig, ax = plt.subplots(figsize=(10, 4))
        self._set_ax_properties(ax, title='Fourier Transform', xlabel='Frequency (Hz)', ylabel='Magnitude')
        ax.grid(True)

        if max_freq is not None:
            valid_indices = xf <= max_freq
            ax.plot(xf[valid_indices], magnitude[valid_indices])
        else:
            ax.plot(xf, magnitude)

        plt.show()

    def plot_pattern_dbs(self, title : str):
        """Run a two-component PCA over all patterns and plot the results.

        Produces three figures: the PCA scatter plot, the two basis curves
        and a comparison of ten randomly drawn patterns (with replacement)
        against their PCA reconstruction.

        Args:
            title: Label inserted into the figure titles and file names.
        """
        combined_pattern_list = self.get_all_patterns()
        xy_list = [p.get_pattern_data() for p in combined_pattern_list]
        # NOTE(review): PCA needs a rectangular matrix, so all y arrays are
        # presumably of equal length -- confirm against get_pattern_data().
        combined_y_list = [y for x, y in xy_list]

        pca = PCA(n_components=2)
        transformed_data = pca.fit_transform(combined_y_list)

        rand_indices = [np.random.randint(low=0, high=len(combined_y_list)) for _ in range(10)]
        example_xy_list = [combined_pattern_list[idx].get_pattern_data() for idx in rand_indices]
        example_pca_coords = [transformed_data[idx] for idx in rand_indices]

        self._plot_pca_scatter(transformed_data, title=title)
        self._plot_pca_basis(pca, title=title)
        self._plot_reconstructed(pca, example_xy_list, example_pca_coords, title=title)
        print('done')

    # -----------------------
    # tools

    def _output_path(self, filename : str) -> str:
        """Return the path of ``filename`` inside the output directory."""
        return os.path.join(self.output_dirpath, filename)

    def _plot_pca_scatter(self, transformed_data, title : str):
        """Scatter the PCA coordinates, one colour per database.

        ``transformed_data`` is consumed in consecutive slices whose sizes
        are the pattern counts of ``self.databases``; at most 50 randomly
        chosen points are drawn per database.
        """
        # NOTE(review): relies on the merged database keeping the patterns
        # in the order of self.databases -- confirm for PatternDB.merge.
        max_points = 50
        # Dedicated figure so nothing leaks in from a previous pyplot call.
        fig, ax = plt.subplots()
        start = 0
        for j, db in enumerate(self.databases):
            db_len = len(db.patterns)
            partial = transformed_data[start:start + db_len]
            start += db_len
            if len(partial) > max_points:
                indices = np.random.choice(len(partial), size=max_points, replace=False)
                partial = partial[indices]
            ax.scatter(partial[:, 0], partial[:, 1], label=f'db number {j}')

        ax.set_title(f'(1): Two Component PCA Scatter Plot for {title}')
        ax.set_xlabel('Component 1')
        ax.set_ylabel('Component 2')
        ax.legend()
        fig.savefig(self._output_path(f'pca_scatter_{title}.png'))
        plt.show()

    def _plot_pca_basis(self, pca, title : str):
        """Plot the curves obtained by inverse-transforming (1, 0) and (0, 1)."""
        # NOTE(review): inverse_transform presumably adds the fitted mean, so
        # these curves would be "mean + component" rather than the bare
        # components -- confirm whether pca.components_ is what is wanted.
        b1, b2 = pca.inverse_transform(np.array([1,0])), pca.inverse_transform(np.array([0,1]))
        # NOTE(review): the x axis is assumed to span 0..180 -- confirm.
        x = np.linspace(start=0,stop=180, num=len(b1))
        fig, ax = plt.subplots()
        ax.plot(x, b1)
        ax.plot(x, b2)
        ax.set_title(f'(2): Principal Components for {title}')
        fig.savefig(self._output_path(f'pca_basis_{title}.png'))
        plt.show()

    def _plot_reconstructed(self, pca, example_xy_list, example_pca_coords, title):
        """Plot each example pattern next to its PCA reconstruction.

        Args:
            pca: Fitted PCA used for the inverse transform.
            example_xy_list: ``(x, y)`` pairs of the original patterns.
            example_pca_coords: PCA coordinates matching ``example_xy_list``.
            title: Label inserted into the figure title and file name.
        """
        num_examples = len(example_xy_list)
        # squeeze=False keeps axs two-dimensional even for a single example.
        fig, axs = plt.subplots(num_examples, 2, figsize=(10, 5 * num_examples), squeeze=False)
        fig.suptitle(f'(3): Comparison of Original and Reconstructed Patterns for {title}', fontsize=16)
        for index, ((x1, y1), pca_coords) in enumerate(zip(example_xy_list, example_pca_coords)):
            axs[index, 0].plot(x1, y1, 'b-')
            self._set_ax_properties(axs[index, 0], title='Original pattern', xlabel='x', ylabel='Relative intensity')

            reconstructed = pca.inverse_transform(pca_coords)
            # NOTE(review): the x axis is assumed to span 0..180 -- confirm.
            x = np.linspace(start=0, stop=180, num=len(reconstructed))
            axs[index, 1].plot(x, reconstructed, 'r-')
            self._set_ax_properties(axs[index, 1], title='Reconstructed pattern', xlabel='x',
                                    ylabel='Relative intensity')

        fig.tight_layout()
        fig.savefig(self._output_path(f'reconstructed_{title}.png'))
        plt.show()

    @staticmethod
    def _set_ax_properties(ax : Axes, title : str, xlabel : str, ylabel : str):
        """Set title and both axis labels of ``ax`` in one call."""
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    def get_all_patterns(self) -> list[XrdPattern]:
        """Return the patterns of the merged database."""
        return self.joined_db.patterns
126+
127+
128+
if __name__ == "__main__":
    # Dataset root directories used by this ad-hoc run.
    full_dirpath = '/tmp/opxrd'
    test_dirpath = '/tmp/opxrd_test'

    # Build the analyser on the individual databases found under full_dirpath.
    opxrd_databases = OpXRD.as_database_list(root_dirpath=full_dirpath)
    analyser = DatabaseAnalyser(
        databases=opxrd_databases,
        output_dirpath='/tmp/opxrd_analysis',
    )

    # Load the dataset under test_dirpath and save its histograms.
    opxrd = OpXRD.load(root_dirpath=test_dirpath)
    opxrd.show_histograms(
        save_fpath='/tmp/quantities_hist.png',
        attach_colorbar=False,
    )
137+

opxrd/wrapper/opxrd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def load(cls, root_dirpath : str, download : bool = True, download_in_situ : boo
2626
@classmethod
2727
def as_database_list(cls, root_dirpath : str, download : bool = True, download_in_situ : bool = False) -> list[PatternDB]:
2828
if not os.path.isdir(root_dirpath) and download:
29-
cls._prepare_files(root_dirpath=root_dirpath)
29+
cls._prepare_files(root_dirpath=root_dirpath, include_in_situ=download_in_situ)
3030

3131
pattern_dbs = []
3232
print(f'- Loading databases from {root_dirpath}')

0 commit comments

Comments
 (0)