06_replay_scenario.py
"""
This example shows how to use the replay functionality of the scenario environment.
"""
import logging

import gymnasium
import numpy as np
from srunner.scenarios.cut_in import CutIn

import mats_gym
from mats_gym import BaseScenarioEnv
from mats_gym.envs.renderers import camera_pov
from mats_gym.wrappers import ReplayWrapper

NUM_EPISODES = 3
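
# The example proceeds in two phases:
#   1. Roll out one episode with a fixed policy; the ReplayWrapper records it.
#   2. Reset with replay options to deterministically re-simulate the first 100 frames,
#      then, for each of several policies, branch off from the exact same state.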


def run_env(env, joint_policy):
    """
    Run the environment with a joint policy until the scenario is finished.
    """
    done = False
    while not done:
        actions = joint_policy()
        obs, reward, done, truncated, info = env.step(actions)
        # "done" comes back as a per-agent dict; the episode ends once every agent is done.
        done = all(done.values())
        env.render()
    return info


def main():
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s - %(filename)s - [%(levelname)s] - %(message)s",
    )
    env = mats_gym.srunner_env(
        host="localhost",
        port=2000,
        seed=42,  # Set the seed to make the scenario deterministic.
        scenario_name="CutInFrom_left_Lane",
        config_file="scenarios/scenario-runner/CutIn.xml",
        render_mode="human",
        render_config=camera_pov(agent="scenario"),
        timeout=10,
    )
    # Record each episode so it can be replayed on later resets.
    env = ReplayWrapper(env)
    obs, info = env.reset(seed=42)

    # Run the environment with a starting policy to generate a history of the scenario.
    info = run_env(env, lambda: {agent: np.array([0.75, 0, 0]) for agent in env.agents})

    # On the next reset, we can provide the history and the number of frames to replay.
    # The environment will then start from the last frame of the replay with the exact
    # same state.
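    # NOTE (assumption): this example only passes "num_frames" below; the ReplayWrapper
    # presumably retains the recorded history internally, so it does not need to be
    # passed back in explicitly.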

    # Replay the environment to frame 100 and then continue with a different policy.
    policies = [
        lambda: {agent: np.array([1.0, 0.0, 0.0]) for agent in env.agents},
        lambda: {agent: np.array([0.0, 0.0, 1.0]) for agent in env.agents},
        lambda: {agent: np.array([0.8, -0.3, 0.0]) for agent in env.agents},
    ]
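    # NOTE (assumption): the action vectors appear to follow a (throttle, steer, brake)
    # layout, which would make the branch policies: full throttle, full brake, and
    # moderate throttle while steering. The exact semantics are defined by the env.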
    for policy in policies:
        obs, info = env.reset(seed=42, options={"replay": {"num_frames": 100}})
        t = 0
        done = False
        while not done:
            if t >= 100:
                actions = policy()
            else:
                actions = {}  # Do nothing during the replay.
            obs, reward, done, truncated, info = env.step(actions)
            done = all(done.values())
            env.render()
            t += 1
    env.close()


if __name__ == "__main__":
    main()