This repository has been archived by the owner on Nov 6, 2024. It is now read-only.

Does not work w/ BRAX #6

Open
MahanFathi opened this issue Jul 11, 2022 · 0 comments


Has anyone tried the solvers on BRAX environments? Here's what I have:

import trajax
import jax
from jax import numpy as jnp
from jax.flatten_util import ravel_pytree
import brax
from brax import envs

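def get_f_and_c(env):
    # Build the dynamics f and cost c in the (x, u, t) form passed to ilqr below.
    key = jax.random.PRNGKey(0)
    state = env.reset(key)
    # Flatten the QP pytree once to obtain x2qp, which maps a flat vector back to a QP.
    _, x2qp = ravel_pytree(state.qp)
    def f(x, u, t):
        # Dynamics: unflatten x, advance the physics one step, and flatten the result.
        qp = x2qp(x)
        nqp, _ = env.sys.step(qp, u)
        return ravel_pytree(nqp)[0]
    def c(x, u, t):
        # Cost: negative of the reward the environment returns after one step from x.
        qp = x2qp(x)
        dstate = state.replace(qp=qp)
        nstate = env.step(dstate, u)
        return -nstate.reward
    return f, c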

env = envs.create('inverted_pendulum')
key = jax.random.PRNGKey(0)
state = env.reset(key)
x_init, x2qp = ravel_pytree(state.qp)

f, c = get_f_and_c(env)

x, u, cost, *outputs = trajax.optimizers.ilqr(c, f, x_init, jnp.zeros([1, env.action_size]))

which gives:

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
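
For reference, the same exception class can be triggered without Brax or trajax. The following is a minimal sketch, not a diagnosis of the traceback above: it assumes only that a NumPy function is applied to a value while JAX is tracing, which is the situation the linked error page describes. The names `step_numpy` and `step_jax` are hypothetical and do not come from either library.

```python
import jax
import jax.numpy as jnp
import numpy as np


def step_numpy(x):
    # np.sin needs a concrete ndarray, so it calls __array__() on its input.
    # Under jit the input is a tracer, and that conversion is what raises.
    return np.sin(x)


def step_jax(x):
    # The jax.numpy equivalent operates on tracers directly.
    return jnp.sin(x)


x = jnp.arange(3, dtype=jnp.float32)

try:
    jax.jit(step_numpy)(x)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True

# The traceable version runs under jit without a conversion error.
y = jax.jit(step_jax)(x)
```

If something in the call chain under `f` or `c` performs a similar NumPy conversion, the error would surface the same way once the solver traces those functions. Whether that is what happens here is not confirmed by the message shown.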
MahanFathi changed the title from "Does not work on BRAX" to "Does not work w/ BRAX" on Jul 11, 2022