Skip to content

Commit 3a17423

Browse files
MAJOR: Added image_representations for continuous environments, fixed some bugs, and made minor improvements
1 parent 5604327 commit 3a17423

8 files changed

+386
-91
lines changed

example.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def discrete_environment_example():
2525
config["seed"] = 0
2626

2727
config["state_space_type"] = "discrete"
28-
config["state_space_size"] = 8
28+
config["action_space_size"] = 8
2929
config["delay"] = 1
3030
config["sequence_length"] = 3
3131
config["reward_scale"] = 2.5
@@ -59,7 +59,7 @@ def discrete_environment_image_representations_example():
5959
config["seed"] = 0
6060

6161
config["state_space_type"] = "discrete"
62-
config["state_space_size"] = 8
62+
config["action_space_size"] = 8
6363
config["image_representations"] = True
6464
config["delay"] = 1
6565
config["sequence_length"] = 3

mdp_playground/envs/rl_toy_env.py

+51-82
Large diffs are not rendered by default.

mdp_playground/spaces/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from mdp_playground.spaces.box_extended import BoxExtended
33
from mdp_playground.spaces.multi_discrete_extended import MultiDiscreteExtended
44
from mdp_playground.spaces.image_multi_discrete import ImageMultiDiscrete
5+
from mdp_playground.spaces.image_continuous import ImageContinuous
56
from mdp_playground.spaces.tuple_extended import TupleExtended
67

7-
__all__ = ["BoxExtended", "DiscreteExtended", "MultiDiscreteExtended", "ImageMultiDiscrete", "TupleExtended"]
8+
__all__ = ["BoxExtended", "DiscreteExtended", "MultiDiscreteExtended",\
9+
"ImageMultiDiscrete", "ImageContinuous", "TupleExtended"]
+194
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import warnings
2+
import numpy as np
3+
import gym
4+
from gym.spaces import Box, Space
5+
import PIL.ImageDraw as ImageDraw
6+
import PIL.Image as Image
7+
from PIL.Image import FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM
8+
import os
9+
10+
class ImageContinuous(Box):
    '''A space that maps a continuous 1- or 2-D space 1-to-1 to images so that the
    images may be used as representations for corresponding continuous environments.

    Methods
    -------
    get_concatenated_image(continuous_obs)
        Gets an image representation for a given feature space observation
    '''

    def __init__(self, feature_space, term_spaces=None, width=100, height=100,
                 circle_radius=5, target_point=None, relevant_indices=None,
                 seed=None, use_custom_images=None, cust_path=None,
                 dtype=np.uint8):
        '''
        Parameters
        ----------
        feature_space : Gym.spaces.Box
            The feature space to which this class associates images as external
            observations
        term_spaces : list of Gym.spaces.Box
            Sub-spaces of the feature space which are terminal. They are drawn
            as filled rectangles in the image of the relevant sub-space.
        width : int
            The width of the image
        height : int
            The height of the image
        circle_radius : int
            The radius of the circle which represents the agent and target point
        target_point : np.array
            Position, in feature space coordinates, that is drawn as a circle
            in the image of the relevant sub-space. Nothing is drawn if None.
        relevant_indices : list
            Indices of the feature space dimensions that make up the relevant
            sub-space. All remaining dimensions are treated as irrelevant and
            get an image of their own. Defaults to [0, 1].
        seed : int
            Seed for this space
        use_custom_images : optional
            Currently not used by this class.
        cust_path : optional
            Currently not used by this class.
        dtype : np.dtype
            dtype passed on to the underlying Box space
        '''
        # ##TODO Define a common superclass for this and ImageMultiDiscrete
        self.feature_space = feature_space
        assert len(feature_space.shape) == 1

        # Copy the indices so that neither a shared default nor the caller's
        # list can be mutated through this instance.
        if relevant_indices is None:
            relevant_indices = [0, 1]
        self.relevant_indices = list(relevant_indices)

        # convert_to_pixel() divides by (high - low) of the relevant
        # dimensions, so every one of those bounds has to be finite.
        assert np.isfinite(feature_space.high[self.relevant_indices]).all()
        assert np.isfinite(feature_space.low[self.relevant_indices]).all()

        self.width = width
        self.height = height
        # Warn if resolution is too low?
        self.circle_radius = circle_radius
        self.target_point = target_point
        self.term_spaces = term_spaces

        all_indices = set(range(feature_space.shape[0]))
        # sorted() keeps the order of the irrelevant dimensions deterministic.
        self.irrelevant_indices = sorted(
            all_indices - set(self.relevant_indices))
        self.irrelevant_features = bool(self.irrelevant_indices)

        self.goal_colour = (0, 255, 0)
        self.agent_colour = (0, 0, 255)
        self.term_colour = (0, 0, 0)

        relevant_dims = len(self.relevant_indices)
        irr_dims = len(self.irrelevant_indices)
        assert relevant_dims <= 2 and irr_dims <= 2, "Image observations are "\
                                                     "supported only "\
                                                     "for 1- or 2-D feature spaces."

        # Shape has 1 appended for Ray Rllib to be compatible IIRC
        # NOTE(review): generate_image() returns RGB arrays with 3 channels,
        # which does not match this declared shape - confirm what is intended.
        super(ImageContinuous, self).__init__(shape=(width, height, 1),
                                              dtype=dtype, low=0, high=255)
        super(ImageContinuous, self).seed(seed=seed)

        # Needs self.shape, hence done only after the Box has been initialised.
        if self.target_point is not None:
            self.target_point_pixel = self.convert_to_pixel(target_point)

    def _draw_circle(self, draw, centre, colour):
        '''Draws a filled circle of radius self.circle_radius around the pixel
        position centre.
        '''
        # Draw circle https://stackoverflow.com/a/2980931/11063709
        radius = self.circle_radius
        bounding_box = [tuple(centre - radius), tuple(centre + radius)]
        draw.ellipse(bounding_box, fill=colour)

    def generate_image(self, position, relevant=True):
        '''Draws the image for one sub-space of the feature space.

        Parameters
        ----------
        position : np.array
            Position of the agent inside the sub-space, in feature space
            coordinates
        relevant : bool
            If True, the terminal sub-spaces and the target point are drawn in
            addition to the agent.

        Returns
        -------
        np.array
            RGB image, transposed so that it is indexed as [x, y, channel]
        '''
        # Use RGB
        image_ = Image.new("RGB", (self.width, self.height),
                           color=(255, 255, 255))
        # Use L for black and white 8-bit pixels instead of RGB in case not
        # using custom images
        # image_ = Image.new("L", (self.width, self.height))
        draw = ImageDraw.Draw(image_)

        # Draw term_spaces first, so that others are drawn over it
        if relevant and self.term_spaces is not None:
            for term_space in self.term_spaces:
                left_up = tuple(self.convert_to_pixel(term_space.low))
                right_down = tuple(self.convert_to_pixel(term_space.high))
                draw.rectangle([left_up, right_down], fill=self.term_colour)

        if relevant and self.target_point is not None:
            self._draw_circle(draw, self.target_point_pixel, self.goal_colour)

        # The agent is drawn last so that it is always visible.
        self._draw_circle(draw, self.convert_to_pixel(position),
                          self.agent_colour)

        # Because numpy is row-major and Image is column major, need to transpose
        # ret_arr = np.array(image_).T # For 2-D
        return np.transpose(np.array(image_), axes=(1, 0, 2))

    def get_concatenated_image(self, obs):
        '''Gets the "stitched together" image made from images corresponding to
        each continuous sub-space within the continuous space, concatenated
        along the X-axis.
        '''
        # Index lists are only valid on arrays, so also accept plain sequences.
        obs = np.asarray(obs)

        # For relevant/irrelevant sub-spaces:
        images = [self.generate_image(obs[self.relevant_indices])]
        if self.irrelevant_features:
            images.append(self.generate_image(obs[self.irrelevant_indices],
                                              relevant=False))

        concatenated_image = np.concatenate(tuple(images), axis=0)

        # because Ray expects an image to have >=3 dims
        return np.atleast_3d(concatenated_image)

    def convert_to_pixel(self, position):
        '''Maps a position in feature space coordinates to integer pixel
        coordinates of the image.

        The position is normalised with the bounds of the relevant dimensions
        and scaled by the first two entries of this space's shape.
        '''
        # It's implicit that both relevant and irrelevant sub-spaces have the
        # same max and min here:
        high = self.feature_space.high[self.relevant_indices]
        low = self.feature_space.low[self.relevant_indices]
        normalised = (position - low) / (high - low)

        return (normalised * self.shape[:2]).astype(int)

    def sample(self):
        '''Samples from the underlying feature space and returns the
        corresponding image.
        '''
        sampled = self.feature_space.sample()
        return self.get_concatenated_image(sampled)

    def __repr__(self):
        return "{} with continuous underlying space of shape: {} and "\
               "images of resolution: {} and dtype: {}".format(
                   self.__class__, self.feature_space.shape,
                   self.shape, self.dtype)

    def contains(self, x):
        """
        Return boolean specifying if x is a valid
        member of this space
        """
        # TODO compare each pixel for all possible images?
        # Always return a real boolean; objects without a shape are no members.
        return getattr(x, "shape", None) == (self.width, self.height, 1)

    def to_jsonable(self, sample_n):
        """Convert a batch of samples from this space to a JSONable data type."""
        # Not supported for this space
        raise NotImplementedError

    def from_jsonable(self, sample_n):
        """Convert a JSONable data type to a batch of samples from this space."""
        # Not supported for this space
        raise NotImplementedError

    def __eq__(self, other):
        # Comparing spaces of this type is not supported (yet).
        raise NotImplementedError

mdp_playground/spaces/image_multi_discrete.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class ImageMultiDiscrete(Box):
1616
Gets an image representation for a given multi_discrete_state
1717
'''
1818

19-
def __init__(self, state_space_sizes, width=100, height=100, circle_radius=20, transforms='rotate,flip,scale,shift', sh_quant=1, scale_range=(0.5,1.5), ro_quant=1, seed=None, use_custom_images=None, cust_path=None): # , polygon_sides=4
19+
def __init__(self, state_space_sizes, width=100, height=100, circle_radius=20, transforms='rotate,flip,scale,shift', sh_quant=1, scale_range=(0.5,1.5), ro_quant=1, seed=None, use_custom_images=None, cust_path=None, dtype=np.uint8): # , polygon_sides=4
2020
'''
2121
Parameters
2222
----------
@@ -84,7 +84,7 @@ def __init__(self, state_space_sizes, width=100, height=100, circle_radius=20, t
8484

8585

8686
# self.shape = (width, height, 1)
87-
super(ImageMultiDiscrete, self).__init__(shape=(width, height, 1), dtype=np.int64, low=0, high=255) #
87+
super(ImageMultiDiscrete, self).__init__(shape=(width, height, 1), dtype=dtype, low=0, high=255) #
8888
super(ImageMultiDiscrete, self).seed(seed=seed) #
8989

9090
# def seed(self, seed=None):
@@ -214,21 +214,24 @@ def get_concatenated_image(self, multi_discrete_state,):
214214
# concatenated_image.append(self.disjoint_states[i][multi_discrete_state[i]])
215215
concatenated_image = np.concatenate(tuple(concatenated_image), axis=0)
216216

217-
return concatenated_image[..., np.newaxis] # because Ray expects an image to have >=3 dims
217+
return np.atleast_3d(concatenated_image) # because Ray expects an image to have >=3 dims
218218

219219
# def get_multi_discrete_state(self,
220220

221221
def sample(self):
222222
sss = np.array(self.state_space_sizes)
223-
sampled = (self.np_random.random_sample(sss.shape) * sss).astype(np.int64) # Based on Gym's MultiDiscrete sampling
223+
sampled = (self.np_random.random_sample(sss.shape) * sss).astype(self.dtype) # Based on Gym's MultiDiscrete sampling
224224
# if type(sampled) == int:
225225
# sampled = [sampled]
226226
sampled = list(sampled)
227227

228228
return self.get_concatenated_image(sampled)
229229

230230
def __repr__(self):
231-
return "ImageMultiDiscrete with multi-discrete space of shape: {} and images of resolution: {}".format(self.state_space_sizes, self.shape)
231+
return "{} with multi-discrete space of shape: {} and "\
232+
"images of resolution: {} and dtype: {}".format(self.__class__,\
233+
self.state_space_sizes,\
234+
self.shape, self.dtype)
232235

233236
def contains(self, x):
234237
"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import unittest
2+
import numpy as np
3+
from mdp_playground.spaces.image_continuous import ImageContinuous
4+
from gym.spaces import Box
5+
# import PIL.ImageDraw as ImageDraw
6+
import PIL.Image as Image
7+
8+
9+
class TestImageContinuous(unittest.TestCase):
    '''Tests the images generated by ImageContinuous.

    Every image is still shown for manual inspection, but its shape and a few
    characteristic pixels are now also asserted, so that the test can fail.
    '''

    WHITE = (255, 255, 255)  # background
    BLACK = (0, 0, 0)  # terminal sub-spaces
    GREEN = (0, 255, 0)  # target point
    BLUE = (0, 0, 255)  # agent

    def _assert_pixel(self, image, x, y, colour):
        '''Asserts that the pixel at (x, y) of image has the given RGB colour.'''
        self.assertEqual(tuple(int(c) for c in image[x, y]), colour)

    def test_image_continuous(self):
        lows = 0.0
        highs = 20.0
        cs2 = Box(shape=(2,), low=lows, high=highs,)
        cs4 = Box(shape=(4,), low=lows, high=highs,)

        # Only the agent; a position of (5, 7) in [0, 20]^2 corresponds to
        # pixel (100, 140) of a 400 x 400 image.
        imc = ImageContinuous(cs2, width=400, height=400,)
        pos = np.array([5.0, 7.0])
        arr = imc.generate_image(pos)
        self.assertEqual(arr.shape, (400, 400, 3))
        self._assert_pixel(arr, 100, 140, self.BLUE)
        self._assert_pixel(arr, 0, 0, self.WHITE)
        img1 = Image.fromarray(np.squeeze(arr), 'RGB')
        img1.show()

        # Agent and target point; the target (10, 10) is at pixel (200, 200)
        target = np.array([10, 10])
        imc = ImageContinuous(cs2, target_point=target, width=400, height=400,)
        arr = imc.generate_image(pos)
        self._assert_pixel(arr, 100, 140, self.BLUE)
        self._assert_pixel(arr, 200, 200, self.GREEN)
        img1 = Image.fromarray(np.squeeze(arr), 'RGB')
        img1.show()

        # Terminal sub-spaces
        lows = np.array([2., 4.])
        highs = np.array([3., 6.])
        cs2_term1 = Box(low=lows, high=highs,)
        lows = np.array([12., 3.])
        highs = np.array([13., 4.])
        cs2_term2 = Box(low=lows, high=highs,)
        term_spaces = [cs2_term1, cs2_term2]

        target = np.array([10, 10])
        imc = ImageContinuous(cs2, target_point=target, term_spaces=term_spaces,
                              width=400, height=400,)
        pos = np.array([5.0, 7.0])
        arr = imc.get_concatenated_image(pos)
        self.assertEqual(arr.shape, (400, 400, 3))
        # Pixels well inside the two terminal rectangles
        self._assert_pixel(arr, 50, 100, self.BLACK)
        self._assert_pixel(arr, 250, 70, self.BLACK)
        self._assert_pixel(arr, 100, 140, self.BLUE)
        self._assert_pixel(arr, 200, 200, self.GREEN)
        img1 = Image.fromarray(np.squeeze(arr), 'RGB')
        img1.show()

        # Irrelevant features
        target = np.array([10, 10])
        imc = ImageContinuous(cs4, target_point=target, width=400, height=400,)
        pos = np.array([5.0, 7.0, 10.0, 15.0])
        arr = imc.get_concatenated_image(pos)
        # The image of the irrelevant sub-space is appended along the X-axis
        self.assertEqual(arr.shape, (800, 400, 3))
        self._assert_pixel(arr, 100, 140, self.BLUE)
        self._assert_pixel(arr, 200, 200, self.GREEN)
        # Irrelevant part: agent at pixel (200, 300) and no target point
        self._assert_pixel(arr, 400 + 200, 300, self.BLUE)
        self._assert_pixel(arr, 400 + 200, 200, self.WHITE)
        img1 = Image.fromarray(np.squeeze(arr), 'RGB')
        img1.show()

        # Random sample and __repr__
        imc = ImageContinuous(cs4, target_point=target, width=400, height=400,)
        print(imc)
        self.assertIn("ImageContinuous", repr(imc))
        arr = imc.sample()
        self.assertEqual(arr.shape, (800, 400, 3))
        self.assertEqual(arr.dtype, np.uint8)
        img1 = Image.fromarray(np.squeeze(arr), 'RGB')
        img1.show()


if __name__ == '__main__':
    unittest.main()

mdp_playground/spaces/test_image_multi_discrete.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import numpy as np
3-
from gym.spaces.image_multi_discrete import ImageMultiDiscrete
3+
from mdp_playground.spaces.image_multi_discrete import ImageMultiDiscrete
44
from gym.spaces import Discrete, MultiDiscrete
55
# import gym
66
# from gym.spaces import MultiDiscrete
@@ -13,6 +13,8 @@ class TestImageMultiDiscrete(unittest.TestCase):
1313

1414
def test_image_multi_discrete(self):
1515
ds4 = Discrete(4)
16+
ds4 = [ds4.n]
17+
print(ds4)
1618
imd = ImageMultiDiscrete(ds4, transforms='shift')
1719
from PIL import Image
1820
# img1 = Image.fromarray(imd.disjoint_states[0][1], 'L')

0 commit comments

Comments
 (0)