@@ -236,7 +236,7 @@ def _objective_fwd(cost, dynamics, U, x0, cost_args, dynamics_args):
236
236
237
237
238
238
def _objective_bwd (cost , dynamics , res , g ):
239
- return (g * _grad_wrt_inputs (cost , dynamics , * res ),) + (None ,) * 3
239
+ return (g * grad_wrt_controls (cost , dynamics , * res ),) + (None ,) * 3
240
240
241
241
242
242
_objective .defvjp (_objective_fwd , _objective_bwd )
@@ -275,7 +275,7 @@ def body(p, t): # backward recursion of Adjoint equations.
275
275
return np .flipud (g ), np .vstack ((np .flipud (P [:T - 1 ]), q [T ])), p
276
276
277
277
278
- def _grad_wrt_inputs (cost , dynamics , U , x0 , cost_args , dynamics_args ):
278
+ def grad_wrt_controls (cost , dynamics , U , x0 , cost_args , dynamics_args ):
279
279
"""Evaluates gradient at a control sequence.
280
280
281
281
Args:
@@ -315,7 +315,7 @@ def hvp(cost, dynamics, U, x0, V, cost_args, dynamics_args):
315
315
Returns:
316
316
gradient (T, m) of total cost with respect to controls.
317
317
"""
318
- grad_fn = partial (_grad_wrt_inputs , cost , dynamics )
318
+ grad_fn = partial (grad_wrt_controls , cost , dynamics )
319
319
return jax .jvp (lambda U1 : grad_fn (U1 , x0 , cost_args , dynamics_args ), (U ,),
320
320
(V ,))
321
321
@@ -450,14 +450,14 @@ def ilqr(cost,
450
450
cost_fn , cost_args = custom_derivatives .closure_convert (cost , x0 , U [0 ], 0 )
451
451
dynamics_fn , dynamics_args = custom_derivatives .closure_convert (
452
452
dynamics , x0 , U [0 ], 0 )
453
- return _ilqr (cost_fn , dynamics_fn , x0 , U , tuple (cost_args ),
453
+ return ilqr_base (cost_fn , dynamics_fn , x0 , U , tuple (cost_args ),
454
454
tuple (dynamics_args ), maxiter , grad_norm_threshold , make_psd ,
455
455
psd_delta , alpha_0 , alpha_min )
456
456
457
457
458
458
@partial (jax .custom_vjp , nondiff_argnums = (0 , 1 ))
459
459
@partial (jit , static_argnums = (0 , 1 ))
460
- def _ilqr (cost , dynamics , x0 , U , cost_args , dynamics_args , maxiter ,
460
+ def ilqr_base (cost , dynamics , x0 , U , cost_args , dynamics_args , maxiter ,
461
461
grad_norm_threshold , make_psd , psd_delta , alpha_0 , alpha_min ):
462
462
"""ilqr implementation."""
463
463
@@ -529,7 +529,7 @@ def continuation_criterion(inputs):
529
529
530
530
def _ilqr_fwd (cost , dynamics , * args ):
531
531
"""Forward pass of custom vector-Jacobian product implementation."""
532
- ilqr_output = _ilqr (cost , dynamics , * args ) # pylint: disable=no-value-for-parameter
532
+ ilqr_output = ilqr_base (cost , dynamics , * args ) # pylint: disable=no-value-for-parameter
533
533
X , U , _ , _ , adjoints , lqr , _ = ilqr_output
534
534
return ilqr_output , (args , X , U , adjoints , lqr )
535
535
@@ -560,7 +560,7 @@ def _ilqr_bwd(cost, dynamics, fwd_residuals, gX_gU_gNonDifferentiableOutputs):
560
560
return (zeros_like_args [:2 ] + ((gradients , * zeros_like_args [2 ][1 :]),) +
561
561
zeros_like_args [3 :])
562
562
563
- _ilqr .defvjp (_ilqr_fwd , _ilqr_bwd )
563
+ ilqr_base .defvjp (_ilqr_fwd , _ilqr_bwd )
564
564
565
565
566
566
def hamiltonian (cost , dynamics ):
@@ -649,7 +649,7 @@ def scipy_minimize(cost,
649
649
"""
650
650
651
651
obj_fn = jit (partial (objective , cost , dynamics ))
652
- grad_fn = jit (partial (_grad_wrt_inputs , cost , dynamics ,
652
+ grad_fn = jit (partial (grad_wrt_controls , cost , dynamics ,
653
653
cost_args = (), dynamics_args = ()))
654
654
T , m = U .shape
655
655
0 commit comments