Skip to content

Commit b9fb8d0

Browse files
committed
Add test for diff grad in test_costs.py
1 parent 6ab1383 commit b9fb8d0

File tree

1 file changed

+54
-3
lines changed

1 file changed

+54
-3
lines changed

tests/python/test_costs.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,45 @@
66

77
import pytest
88

9+
EPS = 1e-7
10+
ATOL = 2 * EPS**0.5
11+
12+
13+
def finite_grad(costmodel, space, x, u, EPS=1e-8):
14+
ndx = space.ndx
15+
nu = u.size
16+
grad = np.zeros(ndx + nu)
17+
dx = np.zeros(ndx)
18+
du = np.zeros(nu)
19+
data = costmodel.createData()
20+
costmodel.evaluate(x, u, data)
21+
# distance to origin
22+
_dx = space.difference(space.neutral(), x)
23+
ex = EPS * max(1.0, np.linalg.norm(_dx))
24+
vref = data.value
25+
for i in range(ndx):
26+
dx[i] = ex
27+
x1 = space.integrate(x, dx)
28+
costmodel.evaluate(x1, u, data)
29+
grad[i] = (data.value - vref) / ex
30+
dx[i] = 0.0
31+
32+
for i in range(ndx, ndx + nu):
33+
du[i - ndx] = ex
34+
u1 = u + du
35+
costmodel.evaluate(x, u1, data)
36+
grad[i] = (data.value - vref) / ex
37+
du[i - ndx] = 0.0
38+
39+
return grad
40+
41+
42+
def sample_gauss(space):
43+
x0 = space.neutral()
44+
d = np.random.randn(space.ndx) * 0.1
45+
x1 = space.integrate(x0, d)
46+
return x1
47+
948

1049
def test_cost_stack():
1150
nx = 2
@@ -107,6 +146,12 @@ def test_composite_cost():
107146
print(data.value)
108147
print(data.grad)
109148
print(data.hess)
149+
for i in range(100):
150+
x0 = sample_gauss(space)
151+
cost.evaluate(x0, u0, data)
152+
cost.computeGradients(x0, u0, data)
153+
fgrad = finite_grad(cost, space, x0, u0)
154+
assert np.allclose(fgrad, data.grad)
110155
print("----")
111156

112157

@@ -127,10 +172,9 @@ def test_log_barrier():
127172

128173
np.random.seed(40)
129174

130-
weight = np.random.rand()
175+
weights = np.ones(fun.nr)
131176
thresh = np.random.rand()
132-
cost = aligator.RelaxedLogBarrierCost(space, fun, weight, thresh)
133-
weights = np.array([weight] * fun.nr)
177+
cost = aligator.RelaxedLogBarrierCost(space, fun, weights, thresh)
134178
assert np.array_equal(weights, cost.weights)
135179

136180
data = cost.createData()
@@ -145,6 +189,13 @@ def test_log_barrier():
145189
print(data.value)
146190
print(data.grad)
147191
print(data.hess)
192+
193+
for i in range(100):
194+
x0 = sample_gauss(space)
195+
cost.evaluate(x0, u0, data)
196+
cost.computeGradients(x0, u0, data)
197+
fgrad = finite_grad(cost, space, x0, u0)
198+
assert np.allclose(fgrad, data.grad)
148199
print("----")
149200

150201

0 commit comments

Comments
 (0)