diff --git a/diffpy/snmf/subroutines.py b/diffpy/snmf/subroutines.py index c0052bb..46d394c 100644 --- a/diffpy/snmf/subroutines.py +++ b/diffpy/snmf/subroutines.py @@ -206,7 +206,31 @@ def reconstruct_signal(components, signal_idx): reconstruction += stretched_and_weighted return reconstruction +def reconstruct_signal_hess(components, signal_idx): + """Reconstruct a specific signal's hessian (second derivative) from its weighted and stretched components. + Calculates the linear combination of stretched components' hessians where each term is a stretched component's + hessian mulitplied by its weight factor. + + Parameters + ---------- + components: tuple of ComponentSignal objects + The tuple containing the ComponentSignal objects + signal_idx: int + The index of the specific signal in the input data to be reconstructed. + + Returns + ------- + 1d array like + The reconstruction of a signal's hessian from calculated weights, stretching factors, and iq values + """ + signal_length = len(components[0].grid) + reconstruction = np.zeros(signal_length) + for component in components: + stretched = component.apply_stretch(signal_idx)[2] + stretched_and_weighted = component.apply_weight(signal_idx, stretched) + reconstruction += stretched_and_weighted + return reconstruction def initialize_arrays(number_of_components, number_of_moments, signal_length): """Generates the initial guesses for the weight, stretching, and component matrices diff --git a/diffpy/snmf/tests/test_subroutines.py b/diffpy/snmf/tests/test_subroutines.py index 1a93971..8f67728 100644 --- a/diffpy/snmf/tests/test_subroutines.py +++ b/diffpy/snmf/tests/test_subroutines.py @@ -3,7 +3,7 @@ from diffpy.snmf.containers import ComponentSignal from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \ update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, \ - construct_component_matrix, construct_weight_matrix, update_weights, reconstruct_signal + construct_component_matrix, construct_weight_matrix, update_weights, reconstruct_signal, reconstruct_signal_hess to = [ ([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14), @@ -251,3 +251,17 @@ def test_update_weights(tuw): def test_reconstruct_signal(trs): actual = reconstruct_signal(trs[0], trs[1]) assert len(actual) == len(trs[0][0].grid) + +trsh = [([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), + ComponentSignal([0, .25, .5, .75, 1], 2, 2)], 1), + ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), + ComponentSignal([0, .25, .5, .75, 1], 2, 2)], 0), + ([ComponentSignal([0, .25, .5, .75, 1], 3, 0), ComponentSignal([0, .25, .5, .75, 1], 3, 1), + ComponentSignal([0, .25, .5, .75, 1], 3, 2)], 2), + # ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), + # ComponentSignal([0, .25, .5, .75, 1], 2, 2)], -1), +] +@pytest.mark.parametrize('trsh', trsh) +def test_reconstruct_signal_hess(trsh): + actual = reconstruct_signal(trsh[0], trsh[1]) + assert len(actual) == len(trsh[0][0].grid)