2
2
import numpy as np
3
3
from diffpy .snmf .containers import ComponentSignal
4
4
from diffpy .snmf .subroutines import objective_function , get_stretched_component , reconstruct_data , get_residual_matrix , \
5
- update_weights_matrix , initialize_arrays , lift_data , initialize_components , construct_stretching_matrix , construct_component_matrix , construct_weight_matrix
5
+ update_weights_matrix , initialize_arrays , lift_data , initialize_components , construct_stretching_matrix , \
6
+ construct_component_matrix , construct_weight_matrix , update_weights
6
7
7
8
to = [
8
9
([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 8 ]], 1e11 , [[1 , 2 ], [3 , 4 ]], [[1 , 2 ], [3 , 4 ]], 1 ], 2.574e14 ),
@@ -207,6 +208,7 @@ def test_construct_component_matrix(tccm):
207
208
for component in tccm :
208
209
np .testing .assert_allclose (actual [component .id ], component .iq )
209
210
211
+
210
212
tcwm = [
211
213
([ComponentSignal ([0 ,.25 ,.5 ,.75 ,1 ],20 ,0 )]),
212
214
# ([ComponentSignal([0,.25,.5,.75,1],0,0)]), # 0 signal length. Failure expected
@@ -225,3 +227,28 @@ def test_construct_weight_matrix(tcwm):
225
227
actual = construct_weight_matrix (tcwm )
226
228
for component in tcwm :
227
229
np .testing .assert_allclose (actual [component .id ], component .weights )
230
+
231
+
232
+ tuw = [([ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 0 ), ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 1 ),
233
+ ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 2 )], [[1 , 1 ], [1.2 , 1.3 ], [1.3 , 1.4 ], [1.4 , 1.5 ], [2 , 2.1 ]], None ),
234
+ ([ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 0 ), ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 1 ),
235
+ ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 2 )], [[1 , 1 ], [1.2 , 1.3 ], [1.3 , 1.4 ], [1.4 , 1.5 ], [2 , 2.1 ]], "align" ),
236
+ ([ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 0 ), ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 1 ),
237
+ ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 2 )], [[0 , 0 ], [0 , 0 ], [0 , 0 ], [0 , 0 ], [0 , 0 ]], None ),
238
+ ([ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 0 ), ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 1 ),
239
+ ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 2 )], [[0 , 0 ], [0 , 0 ], [0 , 0 ], [0 , 0 ], [0 , 0 ]], "align" ),
240
+ ([ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 0 ), ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 1 ),
241
+ ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 2 )], [[- .5 , 1 ], [1.2 , - 1.3 ], [1.1 , - 1 ], [0 , - 1.5 ], [0 , .1 ]], None ),
242
+ ([ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 0 ), ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 1 ),
243
+ ComponentSignal ([0 , .25 , .5 , .75 , 1 ], 2 , 2 )], [[- .5 , 1 ], [1.2 , - 1.3 ], [1.1 , - 1 ], [0 , - 1.5 ], [0 , .1 ]], "align" ),
244
+ # ([ComponentSignal([0, .25, .5, .75, 1], 0, 0), ComponentSignal([0, .25, .5, .75, 1], 0, 1),
245
+ # ComponentSignal([0, .25, .5, .75, 1], 0, 2)], [[1, 1], [1.2, 1.3], [1.3, 1.4], [1.4, 1.5], [2, 2.1]], None),
246
+ # ([ComponentSignal([0, .25, .5, .75, 1], 0, 0), ComponentSignal([0, .25, .5, .75, 1], 0, 1),
247
+ # ComponentSignal([0, .25, .5, .75, 1], 0, 2)], [], None),
248
+ # ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
249
+ # ComponentSignal([0, .25, .5, .75, 1], 2, 2)], [], 170),
250
+ ]
251
+ @pytest .mark .parametrize ('tuw' , tuw )
252
+ def test_update_weights (tuw ):
253
+ actual = update_weights (tuw [0 ], tuw [1 ], tuw [2 ])
254
+ assert np .shape (actual ) == (len (tuw [0 ]), len (tuw [0 ][0 ].weights ))
0 commit comments