Skip to content
This repository was archived by the owner on Nov 6, 2024. It is now read-only.

Commit 24d0af6

Browse files
vikas-sindhwanistephentu
authored andcommitted
source sync
PiperOrigin-RevId: 420358716
1 parent 322e99d commit 24d0af6

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

tests/optimizers_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def cost(x, u, t, params):
266266
objective = functools.partial(optimizers.objective,
267267
functools.partial(cost, params=params_cost),
268268
dynamics)
269-
grad = functools.partial(optimizers._grad_wrt_inputs, cost, dynamics)
269+
grad = functools.partial(optimizers.grad_wrt_controls, cost, dynamics)
270270
gradient = grad(U, x0, (params_cost,), ())
271271

272272
def obj(Uflat):
@@ -295,7 +295,7 @@ def cost(x, u, t, params):
295295
obj = functools.partial(optimizers.objective,
296296
functools.partial(cost, params=params_cost),
297297
dynamics)
298-
grad = functools.partial(optimizers._grad_wrt_inputs, cost, dynamics)
298+
grad = functools.partial(optimizers.grad_wrt_controls, cost, dynamics)
299299
gradient = grad(U, x0, (params_cost,), ())
300300
jax_gradient = jax.grad(obj)(U, x0)
301301
self.assertTrue(np.allclose(gradient, jax_gradient))

trajax/optimizers.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _objective_fwd(cost, dynamics, U, x0, cost_args, dynamics_args):
236236

237237

238238
def _objective_bwd(cost, dynamics, res, g):
239-
return (g * _grad_wrt_inputs(cost, dynamics, *res),) + (None,) * 3
239+
return (g * grad_wrt_controls(cost, dynamics, *res),) + (None,) * 3
240240

241241

242242
_objective.defvjp(_objective_fwd, _objective_bwd)
@@ -275,7 +275,7 @@ def body(p, t): # backward recursion of Adjoint equations.
275275
return np.flipud(g), np.vstack((np.flipud(P[:T - 1]), q[T])), p
276276

277277

278-
def _grad_wrt_inputs(cost, dynamics, U, x0, cost_args, dynamics_args):
278+
def grad_wrt_controls(cost, dynamics, U, x0, cost_args, dynamics_args):
279279
"""Evaluates gradient at a control sequence.
280280
281281
Args:
@@ -315,7 +315,7 @@ def hvp(cost, dynamics, U, x0, V, cost_args, dynamics_args):
315315
Returns:
316316
gradient (T, m) of total cost with respect to controls.
317317
"""
318-
grad_fn = partial(_grad_wrt_inputs, cost, dynamics)
318+
grad_fn = partial(grad_wrt_controls, cost, dynamics)
319319
return jax.jvp(lambda U1: grad_fn(U1, x0, cost_args, dynamics_args), (U,),
320320
(V,))
321321

@@ -450,14 +450,14 @@ def ilqr(cost,
450450
cost_fn, cost_args = custom_derivatives.closure_convert(cost, x0, U[0], 0)
451451
dynamics_fn, dynamics_args = custom_derivatives.closure_convert(
452452
dynamics, x0, U[0], 0)
453-
return _ilqr(cost_fn, dynamics_fn, x0, U, tuple(cost_args),
453+
return ilqr_base(cost_fn, dynamics_fn, x0, U, tuple(cost_args),
454454
tuple(dynamics_args), maxiter, grad_norm_threshold, make_psd,
455455
psd_delta, alpha_0, alpha_min)
456456

457457

458458
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
459459
@partial(jit, static_argnums=(0, 1))
460-
def _ilqr(cost, dynamics, x0, U, cost_args, dynamics_args, maxiter,
460+
def ilqr_base(cost, dynamics, x0, U, cost_args, dynamics_args, maxiter,
461461
grad_norm_threshold, make_psd, psd_delta, alpha_0, alpha_min):
462462
"""ilqr implementation."""
463463

@@ -529,7 +529,7 @@ def continuation_criterion(inputs):
529529

530530
def _ilqr_fwd(cost, dynamics, *args):
531531
"""Forward pass of custom vector-Jacobian product implementation."""
532-
ilqr_output = _ilqr(cost, dynamics, *args) # pylint: disable=no-value-for-parameter
532+
ilqr_output = ilqr_base(cost, dynamics, *args) # pylint: disable=no-value-for-parameter
533533
X, U, _, _, adjoints, lqr, _ = ilqr_output
534534
return ilqr_output, (args, X, U, adjoints, lqr)
535535

@@ -560,7 +560,7 @@ def _ilqr_bwd(cost, dynamics, fwd_residuals, gX_gU_gNonDifferentiableOutputs):
560560
return (zeros_like_args[:2] + ((gradients, *zeros_like_args[2][1:]),) +
561561
zeros_like_args[3:])
562562

563-
_ilqr.defvjp(_ilqr_fwd, _ilqr_bwd)
563+
ilqr_base.defvjp(_ilqr_fwd, _ilqr_bwd)
564564

565565

566566
def hamiltonian(cost, dynamics):
@@ -649,7 +649,7 @@ def scipy_minimize(cost,
649649
"""
650650

651651
obj_fn = jit(partial(objective, cost, dynamics))
652-
grad_fn = jit(partial(_grad_wrt_inputs, cost, dynamics,
652+
grad_fn = jit(partial(grad_wrt_controls, cost, dynamics,
653653
cost_args=(), dynamics_args=()))
654654
T, m = U.shape
655655

0 commit comments

Comments
 (0)