Skip to content

Commit c0be760

Browse files
authored
Merge pull request #49 from aajayi-21/reconstruct_signal
function reconstruct_signal
2 parents 6a0fee3 + be81e67 commit c0be760

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

diffpy/snmf/subroutines.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,42 @@ def update_weights(components, data_input, method=None):
171171
for i, component in enumerate(components):
172172
stretched_components[:, i] = component.apply_stretch(signal)[0]
173173
if method == 'align':
174-
weights = lsqnonneg(stretched_components, data_input[:,signal])
174+
weights = lsqnonneg(stretched_components, data_input[:, signal])
175175
else:
176176
weights = get_weights(stretched_components.T @ stretched_components,
177177
-stretched_components.T @ data_input[:, signal], 0, 1)
178178
weight_matrix[:, signal] = weights
179179
return weight_matrix
180180

181181

182+
def reconstruct_signal(components, signal_idx):
183+
"""Reconstructs a specific signal from its weighted and stretched components.
184+
185+
Calculates the linear combination of stretched components where each term is the stretched component multiplied
186+
by its weight factor.
187+
188+
Parameters
189+
----------
190+
components: tuple of ComponentSignal objects
191+
The tuple containing the ComponentSignal objects
192+
signal_idx: int
193+
The index of the specific signal in the input data to be reconstructed
194+
195+
Returns
196+
-------
197+
1d array like
198+
The reconstruction of a signal from calculated weights, stretching factors, and iq values.
199+
200+
"""
201+
signal_length = len(components[0].grid)
202+
reconstruction = np.zeros(signal_length)
203+
for component in components:
204+
stretched = component.apply_stretch(signal_idx)[0]
205+
stretched_and_weighted = component.apply_weight(signal_idx, stretched)
206+
reconstruction += stretched_and_weighted
207+
return reconstruction
208+
209+
182210
def initialize_arrays(number_of_components, number_of_moments, signal_length):
183211
"""Generates the initial guesses for the weight, stretching, and component matrices
184212

diffpy/snmf/tests/test_subroutines.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from diffpy.snmf.containers import ComponentSignal
44
from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
55
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, \
6-
construct_component_matrix, construct_weight_matrix, update_weights
6+
construct_component_matrix, construct_weight_matrix, update_weights, reconstruct_signal
77

88
to = [
99
([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
@@ -252,3 +252,17 @@ def test_construct_weight_matrix(tcwm):
252252
def test_update_weights(tuw):
253253
actual = update_weights(tuw[0], tuw[1], tuw[2])
254254
assert np.shape(actual) == (len(tuw[0]), len(tuw[0][0].weights))
255+
256+
trs = [([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
257+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], 1),
258+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
259+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], 0),
260+
([ComponentSignal([0, .25, .5, .75, 1], 3, 0), ComponentSignal([0, .25, .5, .75, 1], 3, 1),
261+
ComponentSignal([0, .25, .5, .75, 1], 3, 2)], 2),
262+
# ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
263+
# ComponentSignal([0, .25, .5, .75, 1], 2, 2)], -1),
264+
]
265+
@pytest.mark.parametrize('trs',trs)
266+
def test_reconstruct_signal(trs):
267+
actual = reconstruct_signal(trs[0], trs[1])
268+
assert len(actual) == len(trs[0][0].grid)

0 commit comments

Comments
 (0)