analyser: Added fourier decomposition
Somerandomguy10111 committed Dec 14, 2024
1 parent 8ee261e commit c0cc7c3
Showing 1 changed file with 66 additions and 37 deletions.
103 changes: 66 additions & 37 deletions opxrd/analyser.py
@@ -27,24 +27,40 @@ def plot_databases_in_single(self):
        for database in self.databases:
            database.show_all(single_plot=True, limit_patterns=10)

    # Previous implementation (removed in this commit):
    def plot_fourier(self, x, y, max_freq=10):
        N = len(y)                              # Number of sample points
        T = (x[-1] - x[0]) / (N - 1)            # Sample spacing
        yf = np.fft.fft(y)                      # Perform FFT
        xf = np.fft.fftfreq(N, T)[:N // 2]      # Frequency axis

        magnitude = 2.0 / N * np.abs(yf[:N // 2])
        fig, ax = plt.subplots(figsize=(10, 4))
        self._set_ax_properties(ax, title='Fourier Transform', xlabel='Frequency (Hz)', ylabel='Magnitude')
        ax.grid(True)

        if max_freq is not None:
            valid_indices = xf <= max_freq
            plt.plot(xf[valid_indices], magnitude[valid_indices])
        else:
            plt.plot(xf, magnitude)

        plt.show()

    # New implementation added in this commit:
    def plot_fourier(self, max_freq=2):
        for db in self.databases:
            fig, ax = plt.subplots(figsize=(10, 4), dpi=300)
            ref_x, _ = db.patterns[0].get_pattern_data()
            ref_y = np.zeros(shape=len(ref_x))

            # Synthetic reference: a broad Gaussian centred on the pattern plus weak noise
            size = len(ref_y)
            mean = size // 2
            std_dev = 200

            x = np.linspace(0, size - 1, size)
            gaussian = np.exp(-(x - mean) ** 2 / (2 * std_dev ** 2))
            noise = np.random.normal(0, 0.025, size)
            noisy_gaussian = gaussian + noise
            ref_y += noisy_gaussian

            # First figure: raw patterns together with the synthetic reference
            for p in db.patterns[:10]:
                x, y = p.get_pattern_data()
                plt.plot(x, y, linewidth=0.75, linestyle='--', alpha=0.75)
            plt.plot(ref_x, ref_y, alpha=0.75)
            plt.show()

            # Second figure: Fourier magnitudes of the same patterns and of the reference
            fig, ax = plt.subplots(figsize=(10, 4), dpi=300)
            self._set_ax_properties(ax, title='Fourier Transform', xlabel='Frequency', ylabel='Magnitude')
            ax.grid(True)
            for p in db.patterns[:10]:
                x, y = p.get_pattern_data()
                xf, yf = self.compute_fourier_transform(x, y, max_freq)
                plt.plot(xf, yf, linewidth=0.75, linestyle='--', alpha=0.75)
            ft_ref_x, ft_ref_y = self.compute_fourier_transform(ref_x, ref_y, max_freq)
            plt.plot(ft_ref_x, ft_ref_y, alpha=0.75)

            plt.show()
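The new plot_fourier builds, for each database, a synthetic reference signal (a broad Gaussian plus weak noise) and plots it alongside the first ten patterns, first in real space and then as FFT magnitudes via compute_fourier_transform. As a rough standalone illustration of that comparison — a minimal sketch using only numpy, with all variable names hypothetical and not taken from the commit — the magnitude spectrum of such a noisy Gaussian sits almost entirely in the lowest few frequency bins, while the noise contributes a roughly flat floor at higher frequencies:

import numpy as np

size = 1000
x = np.linspace(0, size - 1, size)
gaussian = np.exp(-(x - size // 2) ** 2 / (2 * 200 ** 2))
signal = gaussian + np.random.normal(0, 0.025, size)

N = len(signal)
T = (x[-1] - x[0]) / (N - 1)                       # sample spacing (1.0 here)
magnitude = 2.0 / N * np.abs(np.fft.fft(signal)[:N // 2])
freqs = np.fft.fftfreq(N, T)[:N // 2]

low = magnitude[freqs <= 0.01].sum()               # contribution near zero frequency
high = magnitude[freqs > 0.01].sum()               # mostly the noise floor
print(low, high)                                   # the low-frequency part should dominate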

    def plot_pattern_dbs(self, title : str):
        combined_pattern_list = self.get_all_patterns()
@@ -62,6 +78,28 @@ def plot_pattern_dbs(self, title : str):
        self._plot_reconstructed(pca, example_xy_list, example_pca_coords, title=title)
        print('done')

    def compute_effective_components(self, tolerance : float = 0.10):
        for db in self.databases:
            max_components = len(db.patterns)
            standardized_intensities = [p.get_pattern_data()[1] for p in db.patterns]
            pca = PCA(n_components=max_components)
            pca_coords = pca.fit_transform(standardized_intensities)

            self._plot_reconstructed(pca, example_xy_list=[p.get_pattern_data() for p in db.patterns[:20]],
                                     example_pca_coords=pca_coords[:20], title=db.name)

            for n_comp in range(max_components):
                mismatches = []
                for j, p in enumerate(db.patterns):
                    _, i1 = p.get_pattern_data()
                    i2 = pca.inverse_transform(pca_coords[j])
                    mismatch = self.compute_mismatch(i1, i2)
                    mismatches.append(mismatch)
                avg_mismatch = np.mean(mismatches)
                if avg_mismatch < tolerance:
                    print(f'Database {db.name} has {n_comp} effective components')
                    break
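Note that inside the n_comp loop above, pca.inverse_transform(pca_coords[j]) always reconstructs with the full set of fitted components, so the average mismatch comes out the same for every n_comp. If the intent is to measure reconstruction error as a function of the number of retained components, one option is to zero the trailing PCA coordinates before inverting. A minimal sketch, assuming scikit-learn's PCA and a plain mean-squared-error mismatch (the commit's compute_mismatch body is collapsed in this view); the data and names are hypothetical:

import numpy as np
from sklearn.decomposition import PCA

def reconstruction_mismatch(pca: PCA, coords: np.ndarray, originals: np.ndarray, n_comp: int) -> float:
    # Zero out everything past the first n_comp coordinates, then invert the PCA.
    truncated = coords.copy()
    truncated[:, n_comp:] = 0.0
    reconstructed = pca.inverse_transform(truncated)
    return float(np.mean((originals - reconstructed) ** 2))

intensities = np.random.rand(40, 500)              # 40 patterns, 500 intensity bins each
pca = PCA(n_components=40)
coords = pca.fit_transform(intensities)

for n_comp in range(1, 41):
    if reconstruction_mismatch(pca, coords, intensities, n_comp) < 0.10:
        print(f'{n_comp} components reach the 0.10 tolerance')
        break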

    # -----------------------
    # tools

@@ -118,29 +156,19 @@ def _set_ax_properties(ax : Axes, title : str, xlabel : str, ylabel : str):
    def get_all_patterns(self) -> list[XrdPattern]:
        return self.joined_db.patterns

    @staticmethod
    def compute_fourier_transform(x, y, max_freq : float):
        N = len(y)                              # Number of sample points
        T = (x[-1] - x[0]) / (N - 1)            # Sample spacing
        yf = np.fft.fft(y)                      # Perform FFT
        xf = np.fft.fftfreq(N, T)[:N // 2]      # Frequency axis

        magnitude = 2.0 / N * np.abs(yf[:N // 2])
        valid_indices = xf <= max_freq

        xf = xf[valid_indices]
        yf = magnitude[valid_indices]
        return xf, yf
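Since x is the diffraction-angle axis rather than time, the frequencies returned here are in cycles per unit of x, not Hz. A quick sanity check of the transform — a standalone sketch with hypothetical values, not taken from the commit — is that a pure sinusoid of known frequency yields a single dominant peak at that frequency with magnitude close to its amplitude:

import numpy as np

x = np.linspace(0, 10, 1000)                  # 10 units of x, 1000 sample points
y = 0.8 * np.sin(2 * np.pi * 3.0 * x)         # 3 cycles per unit of x, amplitude 0.8

N = len(y)
T = (x[-1] - x[0]) / (N - 1)
magnitude = 2.0 / N * np.abs(np.fft.fft(y)[:N // 2])
freqs = np.fft.fftfreq(N, T)[:N // 2]

print(freqs[np.argmax(magnitude)])            # ~3.0
print(magnitude.max())                        # ~0.8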

    @staticmethod
    def compute_mismatch(i1 : NDArray, i2 : NDArray) -> float:
@@ -156,4 +184,5 @@ def compute_mismatch(i1 : NDArray, i2 : NDArray) -> float:
opxrd_databases = OpXRD.as_database_list(root_dirpath=test_dirpath)
analyser = DatabaseAnalyser(databases=opxrd_databases, output_dirpath='/tmp/opxrd_analysis')
# analyser.plot_databases_in_single()
# analyser.compute_effective_components()
analyser.plot_fourier()
