```diff
@@ -69,8 +69,8 @@ def __init__(
         init_stretch=None,
         rho=0,
         eta=0,
-        max_iter=500,
-        tol=5e-7,
+        max_iter=300,
+        tol=1e-6,
         n_components=None,
         random_state=None,
     ):
```
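Review note: `max_iter` and `tol` are constructor defaults, so the previous stopping behavior stays available to callers. A minimal sketch, with the caveat that the class and data-argument names are inferred from this diff rather than verified against the full module:

```python
# Sketch only: "SNMFOptimizer" and "source_matrix" are assumed names,
# not confirmed by this diff. Restores the pre-change stopping criteria.
model = SNMFOptimizer(
    source_matrix,   # hypothetical positional data argument
    n_components=2,
    max_iter=500,    # previous default; this PR changes it to 300
    tol=5e-7,        # previous default; this PR changes it to 1e-6
)
```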
```diff
@@ -231,24 +231,26 @@ def __init__(
         print("Finished optimization.")

     def optimize_loop(self):
-        # Update components first
-        self._prev_grad_components = self.grad_components.copy()
-        self.update_components()
-        self.num_updates += 1
-        self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_components: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
-        if self.objective_difference is None:
-            self.objective_difference = self._objective_history[-1] - self.objective_function
-
-        # Now we update weights
-        self.update_weights()
-        self.num_updates += 1
-        self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_weights: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
+        for i in range(4):
+            # Update components first
+            self._prev_grad_components = self.grad_components.copy()
+            self.update_components()
+            self.num_updates += 1
+            self.residuals = self.get_residual_matrix()
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after update_components: {self.objective_function:.5e}")
+            self._objective_history.append(self.objective_function)
+            if self.objective_difference is None:
+                self.objective_difference = self._objective_history[-1] - self.objective_function
+
+            # Now we update weights
+            self.update_weights()
+            self.num_updates += 1
+            self.residuals = self.get_residual_matrix()
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after update_weights: {self.objective_function:.5e}")
+            self._objective_history.append(self.objective_function)

         # Now we update stretch
         self.update_stretch()
```
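Review note: each call to `optimize_loop` now runs four component/weight sweeps before the single stretch update, so the stretch variables are refined against better-converged components and weights. The sweep count is a hard-coded literal; if it ever needs tuning, it can become a keyword argument with the same default. A minimal sketch (the `inner_sweeps` name is an invention, not in this PR):

```python
# Hypothetical variant of the loop header above; behavior is identical
# when inner_sweeps=4. The method body stays exactly as in this diff.
def optimize_loop(self, inner_sweeps=4):
    for i in range(inner_sweeps):
        ...  # component and weight updates as above
```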
```diff
@@ -488,7 +490,7 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None

         return stretch_transformed

-    def solve_quadratic_program(self, t, m, alg="trust-constr"):
+    def solve_quadratic_program(self, t, m, alg="L-BFGS-B"):
         """
         Solves the quadratic program for updating y in stretched NMF using scipy.optimize:
```
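For context on the new default: both `L-BFGS-B` and `trust-constr` handle simple box constraints, but L-BFGS-B is usually much cheaper per iteration on smooth problems of this size. A standalone sketch of a bound-constrained quadratic program solved the same way (toy matrices, not the ones this method actually builds):

```python
import numpy as np
from scipy.optimize import minimize

# Minimize 0.5 * y^T P y + q^T y subject to y >= 0.
P = np.array([[4.0, 1.0], [1.0, 3.0]])
q = np.array([-1.0, -2.0])

result = minimize(
    lambda y: 0.5 * y @ P @ y + q @ y,
    x0=np.zeros(2),
    jac=lambda y: P @ y + q,      # analytic gradient of the quadratic
    method="L-BFGS-B",            # new default; "trust-constr" also works
    bounds=[(0.0, None)] * 2,     # elementwise nonnegativity
)
print(result.x)
```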
```diff
@@ -588,7 +590,7 @@ def update_components(self):
             + self.eta * np.sqrt(self.components)
             < 0
         )
-        self.components = mask * self.components
+        self.components[~mask] = 0

         objective_improvement = self._objective_history[-1] - self.get_objective_function(
             residuals=self.get_residual_matrix()
```
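The replaced assignment and its successor agree where it matters: both zero exactly the entries whose mask is `False`, but the new form writes in place instead of allocating a product array. A quick equivalence check on toy arrays:

```python
import numpy as np

components = np.array([[0.5, 1.2], [0.3, 2.0]])
mask = np.array([[True, False], [False, True]])

out_of_place = mask * components  # old style: multiply by a 0/1 mask
components[~mask] = 0             # new style: in-place assignment

assert np.array_equal(out_of_place, components)
```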
```diff
@@ -656,7 +658,18 @@ def regularize_function(self, stretch=None):
             der_reshaped.T + self.rho * stretch @ self._spline_smooth_operator.T @ self._spline_smooth_operator
         )

-        return reg_func, func_grad
+        # Hessian: diagonal of second derivatives
+        hess_diag_vals = np.sum(
+            dd_stretch_components * np.tile(stretch_difference, (1, self.n_components)), axis=0
+        ).ravel(order="F")
+
+        # Add the spline regularization Hessian (rho * PPPP)
+        smooth_hess = self.rho * np.kron(self._spline_smooth_penalty.toarray(), np.eye(self.n_components))
+
+        full_hess_diag = hess_diag_vals + np.diag(smooth_hess)
+        hessian = diags(full_hess_diag, format="csc")
+
+        return reg_func, func_grad, hessian

     def update_stretch(self):
         """
```
```diff
@@ -669,9 +682,9 @@ def update_stretch(self):
         # Define the optimization function
         def objective(stretch_vec):
             stretch_matrix = stretch_vec.reshape(self.stretch.shape)  # Reshape back to matrix form
-            func, grad = self.regularize_function(stretch_matrix)
+            func, grad, hess = self.regularize_function(stretch_matrix)
             grad = grad.flatten()
-            return func, grad
+            return func, grad, hess

         # Optimization constraints: lower bound 0.1, no upper bound
         bounds = [(0.1, None)] * stretch_init_vec.size  # Equivalent to 0.1 * ones(K, M)
```
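Review note: because `fun`, `jac`, and `hess` are passed as separate callables below, `regularize_function` can run up to three times at the same point per iteration. SciPy's `jac=True` shortcut only covers `(fun, grad)` pairs, not three-tuples, so a small cache is the usual workaround. A hypothetical helper, not part of this PR:

```python
import numpy as np

class CachedObjective:
    """Caches the last (func, grad, hess) triple so scipy's separate
    fun/jac/hess callbacks don't recompute at the same point."""

    def __init__(self, raw_objective):
        self.raw_objective = raw_objective  # returns (func, grad, hess)
        self._x = None
        self._value = None

    def _evaluate(self, x):
        if self._x is None or not np.array_equal(x, self._x):
            self._x = np.copy(x)
            self._value = self.raw_objective(x)
        return self._value

    def fun(self, x):
        return self._evaluate(x)[0]

    def jac(self, x):
        return self._evaluate(x)[1]

    def hess(self, x):
        return self._evaluate(x)[2]
```

Used as `cached = CachedObjective(objective)` followed by `minimize(cached.fun, x0, jac=cached.jac, hess=cached.hess, ...)`.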
```diff
@@ -682,6 +695,7 @@ def objective(stretch_vec):
             x0=stretch_init_vec,  # Initial guess
             method="trust-constr",  # Equivalent to 'trust-region-reflective'
             jac=lambda stretch_vec: objective(stretch_vec)[1],  # Gradient
+            hess=lambda stretch_vec: objective(stretch_vec)[2],
             bounds=bounds,  # Lower bounds on stretch
             # TODO: A Hessian can be incorporated for better convergence.
         )
```
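With the `hess` argument added, the TODO comment just above the closing parenthesis is now addressed and could be dropped. For reference, a standalone sketch of this call pattern, trust-constr with an analytic gradient and a sparse diagonal Hessian, on a toy quadratic (not the actual stretch objective):

```python
import numpy as np
from scipy.optimize import minimize
from scipy.sparse import diags

# f(x) = sum(c_i * x_i^2), minimized with the same lower-bound style
# as the stretch update above.
c = np.array([1.0, 2.0, 3.0])

result = minimize(
    lambda x: np.sum(c * x**2),
    x0=np.ones(3),
    method="trust-constr",
    jac=lambda x: 2.0 * c * x,
    hess=lambda x: diags(2.0 * c, format="csc"),  # trust-constr accepts sparse Hessians
    bounds=[(0.1, None)] * 3,
)
print(result.x)  # each coordinate ends at its 0.1 lower bound
```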