Skip to content
This repository was archived by the owner on Nov 6, 2024. It is now read-only.

Commit af9876a

Browse files
ssingh19stephentu
authored andcommitted
source sync
PiperOrigin-RevId: 432980481
1 parent 24d0af6 commit af9876a

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

trajax/optimizers.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
from jax import jacobian
5858
from jax import jit
5959
from jax import lax
60-
from jax import ops
6160
from jax import random
6261
from jax import vmap
6362
import jax.numpy as np
@@ -451,14 +450,14 @@ def ilqr(cost,
451450
dynamics_fn, dynamics_args = custom_derivatives.closure_convert(
452451
dynamics, x0, U[0], 0)
453452
return ilqr_base(cost_fn, dynamics_fn, x0, U, tuple(cost_args),
454-
tuple(dynamics_args), maxiter, grad_norm_threshold, make_psd,
455-
psd_delta, alpha_0, alpha_min)
453+
tuple(dynamics_args), maxiter, grad_norm_threshold, make_psd,
454+
psd_delta, alpha_0, alpha_min)
456455

457456

458457
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
459458
@partial(jit, static_argnums=(0, 1))
460459
def ilqr_base(cost, dynamics, x0, U, cost_args, dynamics_args, maxiter,
461-
grad_norm_threshold, make_psd, psd_delta, alpha_0, alpha_min):
460+
grad_norm_threshold, make_psd, psd_delta, alpha_0, alpha_min):
462461
"""ilqr implementation."""
463462

464463
T, m = U.shape
@@ -963,7 +962,8 @@ def augmented_lagrangian(x, u, t, dual_equality, dual_inequality, penalty):
963962
inequality = inequality_constraint(x, u, t)
964963

965964
# active set
966-
active_set = np.invert(np.isclose(dual_inequality[t], 0.0) & (inequality < 0.0))
965+
active_set = np.invert(
966+
np.isclose(dual_inequality[t], 0.0) & (inequality < 0.0))
967967

968968
# update cost
969969
# TODO(taylorhowell): Gauss-Newton approximation for constraints,
@@ -1052,11 +1052,17 @@ def body(inputs):
10521052

10531053
def continuation_criteria(inputs):
10541054
# unpack
1055+
dual_inequality = inputs[3]
1056+
inequality_constraints = inputs[6]
10551057
max_constraint_violation = inputs[7]
10561058
iteration_al = inputs[11]
1059+
max_complementary_slack = np.max(
1060+
np.abs(inequality_constraints * dual_inequality))
10571061
# check maximum constraint violation and augmented Lagrangian iterations
10581062
return np.logical_and(iteration_al < maxiter_al,
1059-
max_constraint_violation > constraints_threshold)
1063+
np.logical_or(
1064+
max_constraint_violation > constraints_threshold,
1065+
max_complementary_slack > constraints_threshold))
10601066

10611067
return lax.while_loop(
10621068
continuation_criteria, body,

0 commit comments

Comments
 (0)