@@ -107,26 +107,44 @@ def test_convex_combinations_start_equals_end(interpolation_data, func):
107107 np .testing .assert_allclose (result , arr , rtol = interpolation_data ["rtol" ])
108108
109109
110- def test_linear_impmat_interpolate (interpolation_data ):
110+ @pytest .mark .parametrize (
111+ "func,expected" ,
112+ [
113+ (
114+ linear_interp_matrix_elemwise ,
115+ np .array (
116+ [
117+ [[1.0 , 2.0 ], [3.0 , 4.0 ]],
118+ [[2.0 , 3.0 ], [4.0 , 5.0 ]],
119+ [[3.0 , 4.0 ], [5.0 , 6.0 ]],
120+ [[4.0 , 5.0 ], [6.0 , 7.0 ]],
121+ [[5.0 , 6.0 ], [7.0 , 8.0 ]],
122+ ]
123+ ),
124+ ),
125+ (
126+ exponential_interp_matrix_elemwise ,
127+ np .array (
128+ [
129+ [[1.0 , 2.0 ], [3.0 , 4.0 ]],
130+ [[1.49534878 , 2.63214803 ], [3.70779275 , 4.75682846 ]],
131+ [[2.23606798 , 3.46410162 ], [4.58257569 , 5.65685425 ]],
132+ [[3.34370152 , 4.55901411 ], [5.66374698 , 6.72717132 ]],
133+ [[5.0 , 6.0 ], [7.0 , 8.0 ]],
134+ ]
135+ ),
136+ ),
137+ ],
138+ )
139+ def test_impmat_interpolate (interpolation_data , func , expected ):
111140 data = interpolation_data
112- result = linear_interp_matrix_elemwise (
113- data ["imp_mat0" ], data ["imp_mat1" ], data ["time_points" ]
114- )
141+ result = func (data ["imp_mat0" ], data ["imp_mat1" ], data ["time_points" ])
115142
116143 assert len (result ) == data ["time_points" ]
117144 assert all (isinstance (mat , csr_matrix ) for mat in result )
118145
119146 dense = np .array ([r .todense () for r in result ])
120- expected = np .array (
121- [
122- [[1.0 , 2.0 ], [3.0 , 4.0 ]],
123- [[2.0 , 3.0 ], [4.0 , 5.0 ]],
124- [[3.0 , 4.0 ], [5.0 , 6.0 ]],
125- [[4.0 , 5.0 ], [6.0 , 7.0 ]],
126- [[5.0 , 6.0 ], [7.0 , 8.0 ]],
127- ]
128- )
129- np .testing .assert_array_equal (dense , expected )
147+ np .testing .assert_array_almost_equal (dense , expected )
130148
131149
132150# --- Tests for Interpolation Strategies ---
0 commit comments