Skip to content

Commit 73c68ef

Browse files
authored
Merge pull request #46 from aajayi-21/update_weights2
function update_weights
2 parents c644543 + 9c49872 commit 73c68ef

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

diffpy/snmf/subroutines.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,51 @@ def construct_weight_matrix(components):
134134
raise ValueError(f"Number of components = {number_of_components}. Number of components must be >= 1")
135135
if number_of_signals == 0:
136136
raise ValueError(f"Number of signals = {number_of_signals}. Number_of_signals must be >= 1.")
137-
weights_matrix = np.zeros((number_of_components,number_of_signals))
137+
weights_matrix = np.zeros((number_of_components, number_of_signals))
138138
for i, component in enumerate(components):
139139
weights_matrix[i] = component.weights
140140
return weights_matrix
141141

142+
143+
def update_weights(components, data_input, method=None):
144+
"""Updates the weights matrix.
145+
146+
Updates the weights matrix and the weights vector for each ComponentSignal object.
147+
148+
Parameters
149+
----------
150+
components: tuple of ComponentSignal objects
151+
The tuple containing the component signals.
152+
method: str
153+
The string specifying which method should be used to find a new weight matrix: non-negative least squares or a
154+
quadratic program.
155+
data_input: 2d array
156+
The 2d array containing the user-provided signals.
157+
158+
Returns
159+
-------
160+
2d array
161+
The 2d array containing the weight factors for each component for each signal from `data_input`. Has dimensions
162+
K x M where K is the number of components and M is the number of signals in `data_input.`
163+
"""
164+
data_input = np.asarray(data_input)
165+
weight_matrix = construct_weight_matrix(components)
166+
number_of_signals = len(components[0].weights)
167+
number_of_components = len(components)
168+
signal_length = len(components[0].grid)
169+
for signal in range(number_of_signals):
170+
stretched_components = np.zeros((signal_length, number_of_components))
171+
for i, component in enumerate(components):
172+
stretched_components[:, i] = component.apply_stretch(signal)[0]
173+
if method == 'align':
174+
weights = lsqnonneg(stretched_components, data_input[:,signal])
175+
else:
176+
weights = get_weights(stretched_components.T @ stretched_components,
177+
-stretched_components.T @ data_input[:, signal], 0, 1)
178+
weight_matrix[:, signal] = weights
179+
return weight_matrix
180+
181+
142182
def initialize_arrays(number_of_components, number_of_moments, signal_length):
143183
"""Generates the initial guesses for the weight, stretching, and component matrices
144184

diffpy/snmf/tests/test_subroutines.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import numpy as np
33
from diffpy.snmf.containers import ComponentSignal
44
from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
5-
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, construct_component_matrix, construct_weight_matrix
5+
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, \
6+
construct_component_matrix, construct_weight_matrix, update_weights
67

78
to = [
89
([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
@@ -207,6 +208,7 @@ def test_construct_component_matrix(tccm):
207208
for component in tccm:
208209
np.testing.assert_allclose(actual[component.id], component.iq)
209210

211+
210212
tcwm = [
211213
([ComponentSignal([0,.25,.5,.75,1],20,0)]),
212214
# ([ComponentSignal([0,.25,.5,.75,1],0,0)]), # 0 signal length. Failure expected
@@ -225,3 +227,28 @@ def test_construct_weight_matrix(tcwm):
225227
actual = construct_weight_matrix(tcwm)
226228
for component in tcwm:
227229
np.testing.assert_allclose(actual[component.id], component.weights)
230+
231+
232+
tuw = [([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
233+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[1, 1], [1.2, 1.3], [1.3, 1.4], [1.4, 1.5], [2, 2.1]], None),
234+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
235+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[1, 1], [1.2, 1.3], [1.3, 1.4], [1.4, 1.5], [2, 2.1]], "align"),
236+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
237+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], None),
238+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
239+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], "align"),
240+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
241+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[-.5, 1], [1.2, -1.3], [1.1, -1], [0, -1.5], [0, .1]], None),
242+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
243+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [[-.5, 1], [1.2, -1.3], [1.1, -1], [0, -1.5], [0, .1]], "align"),
244+
# ([ComponentSignal([0, .25, .5, .75, 1], 0, 0), ComponentSignal([0, .25, .5, .75, 1], 0, 1),
245+
# ComponentSignal([0, .25, .5, .75, 1], 0, 2)], [[1, 1], [1.2, 1.3], [1.3, 1.4], [1.4, 1.5], [2, 2.1]], None),
246+
# ([ComponentSignal([0, .25, .5, .75, 1], 0, 0), ComponentSignal([0, .25, .5, .75, 1], 0, 1),
247+
# ComponentSignal([0, .25, .5, .75, 1], 0, 2)], [], None),
248+
# ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
249+
# ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [], 170),
250+
]
251+
@pytest.mark.parametrize('tuw', tuw)
252+
def test_update_weights(tuw):
253+
actual = update_weights(tuw[0], tuw[1], tuw[2])
254+
assert np.shape(actual) == (len(tuw[0]), len(tuw[0][0].weights))

0 commit comments

Comments
 (0)