|
57 | 57 | from jax import jacobian
|
58 | 58 | from jax import jit
|
59 | 59 | from jax import lax
|
60 |
| -from jax import ops |
61 | 60 | from jax import random
|
62 | 61 | from jax import vmap
|
63 | 62 | import jax.numpy as np
|
@@ -451,14 +450,14 @@ def ilqr(cost,
|
451 | 450 | dynamics_fn, dynamics_args = custom_derivatives.closure_convert(
|
452 | 451 | dynamics, x0, U[0], 0)
|
453 | 452 | return ilqr_base(cost_fn, dynamics_fn, x0, U, tuple(cost_args),
|
454 |
| - tuple(dynamics_args), maxiter, grad_norm_threshold, make_psd, |
455 |
| - psd_delta, alpha_0, alpha_min) |
| 453 | + tuple(dynamics_args), maxiter, grad_norm_threshold, make_psd, |
| 454 | + psd_delta, alpha_0, alpha_min) |
456 | 455 |
|
457 | 456 |
|
458 | 457 | @partial(jax.custom_vjp, nondiff_argnums=(0, 1))
|
459 | 458 | @partial(jit, static_argnums=(0, 1))
|
460 | 459 | def ilqr_base(cost, dynamics, x0, U, cost_args, dynamics_args, maxiter,
|
461 |
| - grad_norm_threshold, make_psd, psd_delta, alpha_0, alpha_min): |
| 460 | + grad_norm_threshold, make_psd, psd_delta, alpha_0, alpha_min): |
462 | 461 | """ilqr implementation."""
|
463 | 462 |
|
464 | 463 | T, m = U.shape
|
@@ -963,7 +962,8 @@ def augmented_lagrangian(x, u, t, dual_equality, dual_inequality, penalty):
|
963 | 962 | inequality = inequality_constraint(x, u, t)
|
964 | 963 |
|
965 | 964 | # active set
|
966 |
| - active_set = np.invert(np.isclose(dual_inequality[t], 0.0) & (inequality < 0.0)) |
| 965 | + active_set = np.invert( |
| 966 | + np.isclose(dual_inequality[t], 0.0) & (inequality < 0.0)) |
967 | 967 |
|
968 | 968 | # update cost
|
969 | 969 | # TODO(taylorhowell): Gauss-Newton approximation for constraints,
|
@@ -1052,11 +1052,17 @@ def body(inputs):
|
1052 | 1052 |
|
1053 | 1053 | def continuation_criteria(inputs):
|
1054 | 1054 | # unpack
|
| 1055 | + dual_inequality = inputs[3] |
| 1056 | + inequality_constraints = inputs[6] |
1055 | 1057 | max_constraint_violation = inputs[7]
|
1056 | 1058 | iteration_al = inputs[11]
|
| 1059 | + max_complementary_slack = np.max( |
| 1060 | + np.abs(inequality_constraints * dual_inequality)) |
1057 | 1061 | # check maximum constraint violation and augmented Lagrangian iterations
|
1058 | 1062 | return np.logical_and(iteration_al < maxiter_al,
|
1059 |
| - max_constraint_violation > constraints_threshold) |
| 1063 | + np.logical_or( |
| 1064 | + max_constraint_violation > constraints_threshold, |
| 1065 | + max_complementary_slack > constraints_threshold)) |
1060 | 1066 |
|
1061 | 1067 | return lax.while_loop(
|
1062 | 1068 | continuation_criteria, body,
|
|
0 commit comments