Skip to content
This repository was archived by the owner on Nov 6, 2024. It is now read-only.

Commit 322e99d

Browse files
trajax authorsstephentu
authored andcommitted
[JAX] Replace uses of deprecated jax.ops.index_update(x, idx, y) APIs with their up-to-date, more succinct equivalent x.at[idx].set(y).
The JAX operators: jax.ops.index_update(x, jax.ops.index[idx], y) jax.ops.index_add(x, jax.ops.index[idx], y) ... have long been deprecated in lieu of their more succinct counterparts: x.at[idx].set(y) x.at[idx].add(y) ... This change updates users of the deprecated APIs to use the current APIs, in preparation for removing the deprecated forms from JAX. The main subtlety is that if `x` is not a JAX array, we must cast it to one using `jnp.asarray(x)` before using the new form, since `.at[...]` is only defined on JAX arrays. PiperOrigin-RevId: 400214219
1 parent d4e807a commit 322e99d

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

trajax/optimizers.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -344,15 +344,15 @@ def ddp_rollout(dynamics, X, U, K, k, alpha, *args):
344344
T, m = U.shape
345345
Xnew = np.zeros((T + 1, n))
346346
Unew = np.zeros((T, m))
347-
Xnew = ops.index_update(Xnew, ops.index[0], X[0])
347+
Xnew = Xnew.at[0].set(X[0])
348348

349349
def body(t, inputs):
350350
Xnew, Unew = inputs
351351
del_u = alpha * k[t] + np.matmul(K[t], Xnew[t] - X[t])
352352
u = U[t] + del_u
353353
x = dynamics(Xnew[t], u, t, *args)
354-
Unew = ops.index_update(Unew, ops.index[t], u)
355-
Xnew = ops.index_update(Xnew, ops.index[t + 1], x)
354+
Unew = Unew.at[t].set(u)
355+
Xnew = Xnew.at[t + 1].set(x)
356356
return Xnew, Unew
357357

358358
return lax.fori_loop(0, T, body, (Xnew, Unew))
@@ -742,9 +742,8 @@ def gaussian_samples(random_key, mean, stdev, control_low, control_high,
742742
smoothing_coef = hyperparams['sampling_smoothing']
743743

744744
def body_fun(t, noises):
745-
return jax.ops.index_update(
746-
noises, jax.ops.index[:, t], smoothing_coef * noises[:, t - 1] +
747-
np.sqrt(1 - smoothing_coef**2) * noises[:, t])
745+
return noises.at[:, t].set(smoothing_coef * noises[:, t - 1] +
746+
np.sqrt(1 - smoothing_coef**2) * noises[:, t])
748747

749748
noises = jax.lax.fori_loop(1, horizon, body_fun, noises)
750749
samples = noises * stdev

trajax/tvlqr.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ def rollout(K, k, x0, A, B, c):
4646
T, m, n = K.shape
4747
X = np.zeros((T + 1, n))
4848
U = np.zeros((T, m))
49-
X = ops.index_update(X, ops.index[0], x0)
49+
X = X.at[0].set(x0)
5050

5151
def body(t, inputs):
5252
X, U = inputs
5353
u = np.matmul(K[t], X[t]) + k[t]
5454
x = np.matmul(A[t], X[t]) + np.matmul(B[t], u) + c[t]
55-
X = ops.index_update(X, ops.index[t + 1], x)
56-
U = ops.index_update(U, ops.index[t], u)
55+
X = X.at[t + 1].set(x)
56+
U = U.at[t].set(u)
5757
return X, U
5858

5959
return lax.fori_loop(0, T, body, (X, U))
@@ -147,10 +147,10 @@ def body(tt, inputs):
147147
t = T - 1 - tt
148148
P_t, p_t, K_t, k_t = lqr_step(P[t+1], p[t+1], Q[t], q[t], R[t], r[t], M[t],
149149
A[t], B[t], c[t])
150-
K = ops.index_update(K, ops.index[t], K_t)
151-
k = ops.index_update(k, ops.index[t], k_t)
152-
P = ops.index_update(P, ops.index[t], P_t)
153-
p = ops.index_update(p, ops.index[t], p_t)
150+
K = K.at[t].set(K_t)
151+
k = k.at[t].set(k_t)
152+
P = P.at[t].set(P_t)
153+
p = p.at[t].set(p_t)
154154

155155
return K, k, P, p
156156

0 commit comments

Comments
 (0)