diff --git a/opty/tests/test_utils.py b/opty/tests/test_utils.py index 23a74168..72e36037 100644 --- a/opty/tests/test_utils.py +++ b/opty/tests/test_utils.py @@ -200,6 +200,43 @@ def test_invalid_integration_limits(self): sym.Integral(self.x ** 2, (self.t, 0, 1)), self.state_symbols, self.input_symbols, self.unknown_symbols, self.N, 1) + def test_variable_time(self): + + def expected_obj(free): + f = free[2*self.N:-1] + return free[-1]*np.sum(f**2) + + def expected_obj_grad(free): + f = free[2*self.N:-1] + grad = np.zeros_like(free) + grad[2*self.N:-1] = 2.0*free[-1]*free[2*self.N:-1] + grad[-1] = np.sum(f**2) + return grad + + obj_expr = sym.Integral(self.f1**2 + self.f2**2, self.t) + obj, obj_grad = utils.create_objective_function( + obj_expr, self.state_symbols, self.input_symbols, + self.unknown_symbols, self.N, self.h, time_symbols=self.t) + np.testing.assert_allclose(obj(self.free), expected_obj(self.free)) + np.testing.assert_allclose(obj_grad(self.free), + expected_obj_grad(self.free)) + + def expected_obj(free): + return free[-1] + + def expected_obj_grad(free): + grad = np.zeros_like(free) + grad[-1] = 1.0 + return grad + + obj_expr = sym.Integral(1, self.t) + obj, obj_grad = utils.create_objective_function( + obj_expr, self.state_symbols, self.input_symbols, + self.unknown_symbols, self.N, self.h, time_symbols=self.t) + np.testing.assert_allclose(obj(self.free), expected_obj(self.free)) + np.testing.assert_allclose(obj_grad(self.free), + expected_obj_grad(self.free)) + def test_parse_free():