Skip to content

Commit adae364

Browse files
Improved example.py by adding MiniGrid and ProcGen examples
1 parent e029d8a commit adae364

File tree

4 files changed

+89
-50
lines changed

4 files changed

+89
-50
lines changed

example.py

+84-43
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,8 @@
88
one for grid environments with image representations
99
one for wrapping Atari env qbert
1010
one for wrapping Mujoco env HalfCheetah
11-
one for wrapping Minigrid env
11+
one for wrapping MiniGrid env
12+
one for wrapping ProcGen env
1213
two examples at the end showing how to create toy envs using gym.make()
1314
1415
Many further examples can be found in test_mdp_playground.py.
@@ -22,6 +23,17 @@
2223
import numpy as np
2324

2425

26+
def display_image(obs, mode="RGB"):
27+
# Display the image observation associated with the next state
28+
from PIL import Image
29+
30+
# Because numpy is row-major and Image is column major, need to transpose
31+
obs = obs.transpose(1, 0, 2)
32+
img1 = Image.fromarray(np.squeeze(obs), mode) # squeeze() is
33+
# used because the image is 3-D because frameworks like Ray expect the image
34+
# to be 3-D.
35+
img1.show()
36+
2537
def discrete_environment_example():
2638

2739
config = {}
@@ -101,18 +113,10 @@ def discrete_environment_image_representations_example():
101113
# the current discrete state.
102114
print("sars', done =", state, action, reward, next_state, done)
103115

104-
# Display the image observation associated with the next state
105-
from PIL import Image
106-
107-
# Because numpy is row-major and Image is column major, need to transpose
108-
next_state_image = next_state_image.transpose(1, 0, 2)
109-
img1 = Image.fromarray(np.squeeze(next_state_image), "L") # 'L' is used for
110-
# black and white. squeeze() is used because the image is 3-D because
111-
# frameworks like Ray expect the image to be 3-D.
112-
img1.show()
113-
114116
env.close()
115117

118+
display_image(next_state_image, mode="L")
119+
116120

117121
def continuous_environment_example_move_along_a_line():
118122

@@ -236,15 +240,8 @@ def grid_environment_image_representations_example():
236240
env.reset()
237241
env.close()
238242

239-
# Display the image observation associated with the next state
240-
from PIL import Image
243+
display_image(next_obs)
241244

242-
# Because numpy is row-major and Image is column major, need to transpose
243-
next_obs = next_obs.transpose(1, 0, 2)
244-
img1 = Image.fromarray(np.squeeze(next_obs), "RGB") # squeeze() is
245-
# used because the image is 3-D because frameworks like Ray expect the image
246-
# to be 3-D.
247-
img1.show()
248245

249246

250247
def atari_wrapper_example():
@@ -265,21 +262,24 @@ def atari_wrapper_example():
265262
state = env.reset()
266263

267264
print(
268-
"Taking a step in the environment with a random action and printing the transition:"
269-
)
270-
action = env.action_space.sample()
271-
next_state, reward, done, info = env.step(action)
272-
print(
273-
"s.shape ar s'.shape, done =",
274-
state.shape,
275-
action,
276-
reward,
277-
next_state.shape,
278-
done,
265+
"Taking 10 steps in the environment with a random action and printing the transition:"
279266
)
267+
for i in range(10):
268+
action = env.action_space.sample()
269+
next_state, reward, done, info = env.step(action)
270+
print(
271+
"s.shape ar s'.shape, done =",
272+
state.shape,
273+
action,
274+
reward,
275+
next_state.shape,
276+
done,
277+
)
280278

281279
env.close()
282280

281+
display_image(next_state)
282+
283283

284284
def mujoco_wrapper_example():
285285

@@ -302,23 +302,23 @@ def mujoco_wrapper_example():
302302
try:
303303
from mdp_playground.envs import get_mujoco_wrapper
304304
from gym.envs.mujoco.half_cheetah_v3 import HalfCheetahEnv
305-
except Exception as e:
306-
print("Exception:", e, "caught. You may need to install mujoco-py. NOT running mujoco_wrapper_example.")
307-
return
305+
wrapped_mujoco_env = get_mujoco_wrapper(HalfCheetahEnv)
308306

309-
wrapped_mujoco_env = get_mujoco_wrapper(HalfCheetahEnv)
307+
env = wrapped_mujoco_env(**config)
308+
state = env.reset()
310309

311-
env = wrapped_mujoco_env(**config)
312-
state = env.reset()
310+
print(
311+
"Taking a step in the environment with a random action and printing the transition:"
312+
)
313+
action = env.action_space.sample()
314+
next_state, reward, done, info = env.step(action)
315+
print("sars', done =", state, action, reward, next_state, done)
313316

314-
print(
315-
"Taking a step in the environment with a random action and printing the transition:"
316-
)
317-
action = env.action_space.sample()
318-
next_state, reward, done, info = env.step(action)
319-
print("sars', done =", state, action, reward, next_state, done)
317+
env.close()
320318

321-
env.close()
319+
except ImportError as e:
320+
print("Exception:", type(e), e, "caught. You may need to install mujoco-py. NOT running mujoco_wrapper_example.")
321+
return
322322

323323

324324
def minigrid_wrapper_example():
@@ -358,6 +358,44 @@ def minigrid_wrapper_example():
358358

359359
env.close()
360360

361+
display_image(next_obs)
362+
363+
364+
def procgen_wrapper_example():
365+
366+
config = {
367+
"seed": 0,
368+
"delay": 1,
369+
"transition_noise": 0.25,
370+
"reward_noise": lambda a: a.normal(0, 0.1),
371+
"state_space_type": "discrete",
372+
}
373+
374+
from mdp_playground.envs.gym_env_wrapper import GymEnvWrapper
375+
import gym
376+
377+
env = gym.make("procgen:procgen-coinrun-v0")
378+
env = GymEnvWrapper(env, **config)
379+
obs = env.reset()
380+
381+
print(
382+
"Taking a step in the environment with a random action and printing the transition:"
383+
)
384+
action = env.action_space.sample()
385+
next_obs, reward, done, info = env.step(action)
386+
print(
387+
"s.shape ar s'.shape, done =",
388+
obs.shape,
389+
action,
390+
reward,
391+
next_obs.shape,
392+
done,
393+
)
394+
395+
env.close()
396+
397+
display_image(next_obs)
398+
361399

362400
if __name__ == "__main__":
363401

@@ -404,6 +442,9 @@ def minigrid_wrapper_example():
404442
print(set_ansi_escape + "\nRunning MiniGrid wrapper example:\n" + reset_ansi_escape)
405443
minigrid_wrapper_example()
406444

445+
# print(set_ansi_escape + "\nRunning ProcGen wrapper example:\n" + reset_ansi_escape)
446+
# procgen_wrapper_example()
447+
407448
# Using gym.make() example 1
408449
import mdp_playground
409450
import gym

mdp_playground/envs/__init__.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,8 @@
11
from mdp_playground.envs.rl_toy_env import RLToyEnv
2+
from gym import error
23

34
try:
45
from mdp_playground.envs.gym_env_wrapper import GymEnvWrapper
56
from mdp_playground.envs.mujoco_env_wrapper import get_mujoco_wrapper
6-
except Exception as e:
7-
print("Exception:", e, "caught. You may need to install Ray or mujoco-py.")
7+
except error.DependencyNotInstalled as e:
8+
print("Exception:", type(e), e, "caught. You may need to install Ray or mujoco-py.")

mdp_playground/envs/gym_env_wrapper.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import sys
55
from gym.spaces import Box, Tuple
66
from gym.wrappers import AtariPreprocessing
7-
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
87
from mdp_playground.envs.rl_toy_env import RLToyEnv
98
import warnings
109
import PIL.ImageDraw as ImageDraw
@@ -151,6 +150,7 @@ def __init__(self, env, **config):
151150
if (
152151
"wrap_deepmind_ray" in config and config["wrap_deepmind_ray"]
153152
): # hack ##TODO remove?
153+
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
154154
self.env = wrap_deepmind(self.env, dim=42, framestack=True)
155155
elif "atari_preprocessing" in config and config["atari_preprocessing"]:
156156
self.frame_skip = 4 # default for AtariPreprocessing

setup.py

+1-4
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,9 @@
1212
package_data = {"": ["*"]}
1313

1414
extras_require = [
15-
"ray[default,rllib,debug]==1.3.0",
15+
"ray[default,rllib]==1.3.0",
1616
"tensorflow==2.2.0",
1717
"pillow>=6.1.0",
18-
"pandas==0.25.0",
1918
"requests==2.22.0",
2019
"configspace==0.4.10",
2120
"scipy>=1.3.0",
@@ -27,7 +26,6 @@
2726
"ray[rllib,debug]==0.7.3",
2827
"tensorflow==1.13.0rc1",
2928
"pillow>=6.1.0",
30-
"pandas==0.25.0",
3129
"requests==2.22.0",
3230
"configspace==0.4.10",
3331
"scipy==1.3.0",
@@ -42,7 +40,6 @@
4240
# 'ray[rllib,debug]==0.9.0',
4341
"tensorflow==2.2.0",
4442
"tensorflow-probability==0.9.0",
45-
"pandas==0.25.0",
4643
"requests==2.22.0",
4744
"mujoco-py==2.0.2.13", # with mujoco 2.0
4845
"configspace>=0.4.10",

0 commit comments

Comments (0)