Skip to content

Commit 36d2c90

Browse files
john-halloranJohn Halloran
and
John Halloran
authored
fix stop condition and add Y/A normalization (#141)
Co-authored-by: John Halloran <[email protected]>
1 parent 56f89a1 commit 36d2c90

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from scipy.optimize import minimize
44
from scipy.sparse import block_diag, coo_matrix, diags
55

6-
# from scipy.sparse import csr_matrix, spdiags (needed for hessian once fixed)
7-
86

97
class SNMFOptimizer:
108
def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, maxiter=300, components=None):
@@ -67,6 +65,7 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, maxiter=300,
6765
f", Obj - reg/sparse: {self.objective_function - regularization_term - sparsity_term:.5e}"
6866
)
6967

68+
# Main optimization loop
7069
for outiter in range(self.maxiter):
7170
self.outiter = outiter
7271
self.outer_loop()
@@ -81,10 +80,18 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, maxiter=300,
8180
)
8281

8382
# Convergence check: Stop if diffun is small and at least 20 iterations have passed
84-
# This check is not working, so have temporarily set 20->120 instead
85-
if self.objective_difference < self.objective_function * 1e-6 and outiter >= 120:
83+
print(self.objective_difference, " < ", self.objective_function * 1e-6)
84+
if self.objective_difference < self.objective_function * 1e-6 and outiter >= 20:
8685
break
8786

87+
# Normalize our results
88+
Y_row_max = np.max(self.Y, axis=1, keepdims=True)
89+
self.Y = self.Y / Y_row_max
90+
A_row_max = np.max(self.A, axis=1, keepdims=True)
91+
self.A = self.A / A_row_max
92+
# TODO loop to normalize X (currently not normalized)
93+
# effectively just re-running class with non-normalized X, normalized Y/A as inputs, then only update X
94+
8895
def outer_loop(self):
8996
# This inner loop runs up to four times per outer loop, making updates to X, Y
9097
for iter in range(4):
@@ -108,25 +115,19 @@ def outer_loop(self):
108115
self.objective_history.append(self.objective_function)
109116

110117
# Check whether to break out early
118+
# TODO this condition has not been tested, and may have issues
111119
if len(self.objective_history) >= 3: # Ensure at least 3 values exist
112120
if self.objective_history[-3] - self.objective_function < self.objective_difference * 1e-3:
113121
break # Stop if improvement is too small
114122

115-
if self.outiter == 0:
116-
print("Testing regularize_function:")
117-
test_fun, test_gra, test_hess = self.regularize_function()
118-
print(f"Fun: {test_fun:.5e}")
119-
np.savetxt("output/py_test_gra.txt", test_gra, fmt="%.8g", delimiter=" ")
120-
np.savetxt("output/py_test_hess.txt", test_hess, fmt="%.8g", delimiter=" ")
121-
122123
self.updateA2()
123124

124125
self.num_updates += 1
125126
self.R = self.get_residual_matrix()
126127
self.objective_function = self.get_objective_function()
127128
print(f"Objective function after updateA2: {self.objective_function:.5e}")
128129
self.objective_history.append(self.objective_function)
129-
self.objective_difference = self.objective_history[-1] - self.objective_function
130+
self.objective_difference = self.objective_history[-2] - self.objective_history[-1]
130131

131132
def apply_interpolation(self, a, x):
132133
"""

0 commit comments

Comments
 (0)