Skip to content

Commit db1c456

Browse files
committed
filled function contents and tests
1 parent 726a7e2 commit db1c456

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

diffpy/snmf/subroutines.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def get_residual_matrix(component_matrix, weights_matrix, stretching_matrix, dat
449449
return residual_matrx
450450

451451

452-
def reconstruct_data(components, input_data):
452+
def reconstruct_data(components):
453453
"""Reconstructs the `input_data` matrix
454454
455455
Reconstructs the `input_data` matrix from calculated component signals, weights, and stretching factors.
@@ -458,13 +458,16 @@ def reconstruct_data(components, input_data):
458458
----------
459459
components: tuple of ComponentSignal objects
460460
The tuple containing the component signals.
461-
input_data: 2d array
462-
The 2d array containing the user provided signals.
463461
464462
Returns
465463
-------
466464
2d array
467465
The 2d array containing the reconstruction of input_data.
468466
469467
"""
470-
pass
468+
signal_length = len(components[0].iq)
469+
number_of_signal = len(components[0].weights)
470+
data_reconstruction = np.zeros((signal_length, number_of_signal))
471+
for signal in range(number_of_signal):
472+
data_reconstruction[:, signal] = reconstruct_signal(components, signal)
473+
return data_reconstruction

diffpy/snmf/tests/test_subroutines.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,22 @@ def test_get_residual_matrix(tgrm):
108108

109109

110110
trd = [
111+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
112+
ComponentSignal([0, .25, .5, .75, 1], 2, 2)]),
113+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0)]),
114+
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
115+
ComponentSignal([0, .25, .5, .75, 1], 2, 2), ComponentSignal([0, .25, .5, .75, 1], 2, 3),
116+
ComponentSignal([0, .25, .5, .75, 1], 2, 4)]),
117+
#([]) # Exception expected
111118

112119
]
113120

114121

115122
@pytest.mark.parametrize('trd', trd)
116123
def test_reconstruct_data(trd):
117-
assert False
124+
actual = reconstruct_data(trd)
125+
assert actual.shape == (len(trd[0].iq),len(trd[0].weights))
126+
print(actual)
118127

119128

120129
tld = [(([[[1, -1, 1], [0, 0, 0], [2, 10, -3]], 1]), ([[4, 2, 4], [3, 3, 3], [5, 13, 0]])),

0 commit comments

Comments
 (0)