Commit d2efc30

feat: use hessian in optimization

Author: John Halloran (committed)
Parent: 27ea989

File tree: 1 file changed, +38 -24 lines


src/diffpy/snmf/snmf_class.py (38 additions, 24 deletions)
@@ -69,8 +69,8 @@ def __init__(
         init_stretch=None,
         rho=0,
         eta=0,
-        max_iter=500,
-        tol=5e-7,
+        max_iter=300,
+        tol=1e-6,
         n_components=None,
         random_state=None,
     ):
@@ -231,24 +231,26 @@ def __init__(
         print("Finished optimization.")
 
     def optimize_loop(self):
-        # Update components first
-        self._prev_grad_components = self.grad_components.copy()
-        self.update_components()
-        self.num_updates += 1
-        self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_components: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
-        if self.objective_difference is None:
-            self.objective_difference = self._objective_history[-1] - self.objective_function
 
-        # Now we update weights
-        self.update_weights()
-        self.num_updates += 1
-        self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_weights: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
+        for i in range(4):
+            # Update components first
+            self._prev_grad_components = self.grad_components.copy()
+            self.update_components()
+            self.num_updates += 1
+            self.residuals = self.get_residual_matrix()
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after update_components: {self.objective_function:.5e}")
+            self._objective_history.append(self.objective_function)
+            if self.objective_difference is None:
+                self.objective_difference = self._objective_history[-1] - self.objective_function
+
+            # Now we update weights
+            self.update_weights()
+            self.num_updates += 1
+            self.residuals = self.get_residual_matrix()
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after update_weights: {self.objective_function:.5e}")
+            self._objective_history.append(self.objective_function)
 
         # Now we update stretch
         self.update_stretch()
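
With this change, each call to optimize_loop runs four rounds of component and weight updates before the single stretch update that follows. A minimal sketch of that block-coordinate schedule, using the update_* method names from this diff (the wrapper function and its arguments are hypothetical):

    # Sketch: alternating (block-coordinate) update schedule from this commit.
    # `model` stands for any object exposing the update_* methods shown above.
    def outer_step(model, inner_iters=4):
        for _ in range(inner_iters):
            model.update_components()  # refine components; weights and stretch fixed
            model.update_weights()     # refine weights; components and stretch fixed
        model.update_stretch()         # one stretch update per outer pass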
@@ -488,7 +490,7 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None
 
         return stretch_transformed
 
-    def solve_quadratic_program(self, t, m, alg="trust-constr"):
+    def solve_quadratic_program(self, t, m, alg="L-BFGS-B"):
         """
         Solves the quadratic program for updating y in stretched NMF using scipy.optimize:
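
The default solver flips from trust-constr to L-BFGS-B, which handles a smooth box-bounded subproblem with gradient information alone and avoids trust-constr's per-iteration overhead. A self-contained sketch of a bound-constrained quadratic program solved this way (toy Q and c, not the matrices this method actually builds):

    import numpy as np
    from scipy.optimize import minimize

    # Toy bound-constrained QP: min 0.5 * y^T Q y + c^T y  subject to y >= 0.
    Q = np.array([[2.0, 0.5],
                  [0.5, 1.0]])   # symmetric positive definite
    c = np.array([-1.0, -2.0])

    def fun(y):
        return 0.5 * y @ Q @ y + c @ y

    def jac(y):
        return Q @ y + c          # exact gradient; L-BFGS-B needs no Hessian

    result = minimize(fun, x0=np.zeros(2), method="L-BFGS-B", jac=jac,
                      bounds=[(0.0, None)] * 2)
    print(result.x)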
@@ -588,7 +590,7 @@ def update_components(self):
             + self.eta * np.sqrt(self.components)
             < 0
         )
-        self.components = mask * self.components
+        self.components[~mask] = 0
 
         objective_improvement = self._objective_history[-1] - self.get_objective_function(
             residuals=self.get_residual_matrix()
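
The rewritten line zeroes the failing entries of self.components in place rather than allocating a new array with mask * self.components; both forms leave the same values. A tiny check with a stand-in mask:

    import numpy as np

    components = np.array([0.5, -0.2, 1.3, 0.0])
    mask = components > 0              # stand-in for the pruning condition above

    out_of_place = mask * components   # previous form: allocates a new array
    components[~mask] = 0              # new form: zeroes failing entries in place
    assert np.array_equal(out_of_place, components)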
@@ -656,7 +658,18 @@ def regularize_function(self, stretch=None):
             der_reshaped.T + self.rho * stretch @ self._spline_smooth_operator.T @ self._spline_smooth_operator
         )
 
-        return reg_func, func_grad
+        # Hessian: diagonal of second derivatives
+        hess_diag_vals = np.sum(
+            dd_stretch_components * np.tile(stretch_difference, (1, self.n_components)), axis=0
+        ).ravel(order="F")
+
+        # Add the spline regularization Hessian (rho * PPPP)
+        smooth_hess = self.rho * np.kron(self._spline_smooth_penalty.toarray(), np.eye(self.n_components))
+
+        full_hess_diag = hess_diag_vals + np.diag(smooth_hess)
+        hessian = diags(full_hess_diag, format="csc")
+
+        return reg_func, func_grad, hessian
 
     def update_stretch(self):
         """
@@ -669,9 +682,9 @@ def update_stretch(self):
         # Define the optimization function
         def objective(stretch_vec):
             stretch_matrix = stretch_vec.reshape(self.stretch.shape)  # Reshape back to matrix form
-            func, grad = self.regularize_function(stretch_matrix)
+            func, grad, hess = self.regularize_function(stretch_matrix)
             grad = grad.flatten()
-            return func, grad
+            return func, grad, hess
 
         # Optimization constraints: lower bound 0.1, no upper bound
         bounds = [(0.1, None)] * stretch_init_vec.size  # Equivalent to 0.1 * ones(K, M)
@@ -682,6 +695,7 @@ def objective(stretch_vec):
             x0=stretch_init_vec,  # Initial guess
             method="trust-constr",  # Equivalent to 'trust-region-reflective'
             jac=lambda stretch_vec: objective(stretch_vec)[1],  # Gradient
+            hess=lambda stretch_vec: objective(stretch_vec)[2],
             bounds=bounds,  # Lower bounds on stretch
             # TODO: A Hessian can be incorporated for better convergence.
         )
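
The new hess callback (addressing the inline TODO above) lets trust-constr take curvature-informed steps; scipy accepts a sparse matrix as the callback's return value. A toy example mirroring this call pattern:

    import numpy as np
    from scipy.optimize import minimize
    from scipy.sparse import diags

    # Toy problem mirroring the call above: min ||x - 2||^2 with x >= 0.1.
    fun = lambda x: np.sum((x - 2.0) ** 2)
    jac = lambda x: 2.0 * (x - 2.0)
    hess = lambda x: diags(np.full(x.size, 2.0), format="csc")  # sparse diagonal Hessian

    result = minimize(
        fun,
        x0=np.ones(4),
        method="trust-constr",
        jac=jac,
        hess=hess,
        bounds=[(0.1, None)] * 4,
    )
    print(result.x)  # close to [2, 2, 2, 2]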
