diff --git a/mjx/mujoco/mjx/_src/smooth.py b/mjx/mujoco/mjx/_src/smooth.py index 632a225d84..9e5b06e1c6 100644 --- a/mjx/mujoco/mjx/_src/smooth.py +++ b/mjx/mujoco/mjx/_src/smooth.py @@ -254,7 +254,7 @@ def off_diag_fn(madr_d, madr_ij, qld=qld, width=out_end - out_beg): return -(qld_row[0] / qld[madr_d]) * qld_row qld_update = jp.sum(off_diag_fn(madr_d, madr_ij), axis=0) - qld = qld.at[out_beg:out_end].add(qld_update) + qld = qld.at[jp.arange(out_beg, out_end)].add(qld_update) # TODO(erikfrey): determine if this minimum value guarding is necessary: # qld = qld.at[dof_madr].set(jp.maximum(qld[dof_madr], _MJ_MINVAL))