Skip to content
This repository was archived by the owner on Nov 6, 2024. It is now read-only.

Commit d4e807a

Browse files
ssingh19stephentu
authored andcommitted
source sync
PiperOrigin-RevId: 395326593
1 parent d2ab930 commit d4e807a

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

trajax/tvlqr.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,20 +134,27 @@ def tvlqr(Q, q, R, r, M, A, B, c):
134134
m = R.shape[1]
135135
n = Q.shape[1]
136136

137+
P = np.zeros((T+1, n, n))
138+
p = np.zeros((T+1, n))
137139
K = np.zeros((T, m, n))
138140
k = np.zeros((T, m))
139141

142+
P = P.at[-1].set(Q[T])
143+
p = p.at[-1].set(q[T])
144+
140145
def body(tt, inputs):
141146
K, k, P, p = inputs
142147
t = T - 1 - tt
143-
P, p, K_t, k_t = lqr_step(P, p, Q[t], q[t], R[t], r[t], M[t], A[t], B[t],
144-
c[t])
148+
P_t, p_t, K_t, k_t = lqr_step(P[t+1], p[t+1], Q[t], q[t], R[t], r[t], M[t],
149+
A[t], B[t], c[t])
145150
K = ops.index_update(K, ops.index[t], K_t)
146151
k = ops.index_update(k, ops.index[t], k_t)
152+
P = ops.index_update(P, ops.index[t], P_t)
153+
p = ops.index_update(p, ops.index[t], p_t)
147154

148155
return K, k, P, p
149156

150-
return lax.fori_loop(0, T, body, (K, k, Q[T], q[T]))
157+
return lax.fori_loop(0, T, body, (K, k, P, p))
151158

152159

153160
@partial(jit, static_argnums=(0,))

0 commit comments

Comments
 (0)