3
3
from scipy .optimize import minimize
4
4
from scipy .sparse import block_diag , coo_matrix , diags
5
5
6
- # from scipy.sparse import csr_matrix, spdiags (needed for hessian once fixed)
7
-
8
6
9
7
class SNMFOptimizer :
10
8
def __init__ (self , MM , Y0 = None , X0 = None , A = None , rho = 1e12 , eta = 610 , maxiter = 300 , components = None ):
@@ -67,6 +65,7 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, maxiter=300,
67
65
f", Obj - reg/sparse: { self .objective_function - regularization_term - sparsity_term :.5e} "
68
66
)
69
67
68
+ # Main optimization loop
70
69
for outiter in range (self .maxiter ):
71
70
self .outiter = outiter
72
71
self .outer_loop ()
@@ -81,10 +80,18 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, maxiter=300,
81
80
)
82
81
83
82
# Convergence check: Stop if diffun is small and at least 20 iterations have passed
84
- # This check is not working, so have temporarily set 20->120 instead
85
- if self .objective_difference < self .objective_function * 1e-6 and outiter >= 120 :
83
+ print ( self . objective_difference , " < " , self . objective_function * 1e-6 )
84
+ if self .objective_difference < self .objective_function * 1e-6 and outiter >= 20 :
86
85
break
87
86
87
+ # Normalize our results
88
+ Y_row_max = np .max (self .Y , axis = 1 , keepdims = True )
89
+ self .Y = self .Y / Y_row_max
90
+ A_row_max = np .max (self .A , axis = 1 , keepdims = True )
91
+ self .A = self .A / A_row_max
92
+ # TODO loop to normalize X (currently not normalized)
93
+ # effectively just re-running class with non-normalized X, normalized Y/A as inputs, then only update X
94
+
88
95
def outer_loop (self ):
89
96
# This inner loop runs up to four times per outer loop, making updates to X, Y
90
97
for iter in range (4 ):
@@ -108,25 +115,19 @@ def outer_loop(self):
108
115
self .objective_history .append (self .objective_function )
109
116
110
117
# Check whether to break out early
118
+ # TODO this condition has not been tested, and may have issues
111
119
if len (self .objective_history ) >= 3 : # Ensure at least 3 values exist
112
120
if self .objective_history [- 3 ] - self .objective_function < self .objective_difference * 1e-3 :
113
121
break # Stop if improvement is too small
114
122
115
- if self .outiter == 0 :
116
- print ("Testing regularize_function:" )
117
- test_fun , test_gra , test_hess = self .regularize_function ()
118
- print (f"Fun: { test_fun :.5e} " )
119
- np .savetxt ("output/py_test_gra.txt" , test_gra , fmt = "%.8g" , delimiter = " " )
120
- np .savetxt ("output/py_test_hess.txt" , test_hess , fmt = "%.8g" , delimiter = " " )
121
-
122
123
self .updateA2 ()
123
124
124
125
self .num_updates += 1
125
126
self .R = self .get_residual_matrix ()
126
127
self .objective_function = self .get_objective_function ()
127
128
print (f"Objective function after updateA2: { self .objective_function :.5e} " )
128
129
self .objective_history .append (self .objective_function )
129
- self .objective_difference = self .objective_history [- 1 ] - self .objective_function
130
+ self .objective_difference = self .objective_history [- 2 ] - self .objective_history [ - 1 ]
130
131
131
132
def apply_interpolation (self , a , x ):
132
133
"""
0 commit comments