import warnings
import numpy as np
import gym
from gym.spaces import Box, Space
import PIL.ImageDraw as ImageDraw
import PIL.Image as Image
from PIL.Image import FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM
import os

class ImageContinuous(Box):
    '''A space that maps a continuous 1- or 2-D space 1-to-1 to images so that the
    images may be used as representations for corresponding continuous environments.

    Methods
    -------
    get_concatenated_image(continuous_obs)
        Gets an image representation for a given feature space observation
    '''

    def __init__(self, feature_space, term_spaces=None, width=100, height=100,
                 circle_radius=5, target_point=None, relevant_indices=[0, 1],
                 seed=None, use_custom_images=None, cust_path=None, dtype=np.uint8):
        '''
        Parameters
        ----------
        feature_space : Gym.spaces.Box
            The feature space to which this class associates images as external
            observations
        term_spaces : list of Gym.spaces.Box
            Sub-spaces of the feature space which are terminal
        width : int
            The width of the image
        height : int
            The height of the image
        circle_radius : int
            The radius of the circle which represents the agent and target point
        target_point : np.array
            The target point in the feature space; if it is not None, it is drawn
            as a filled circle in the goal colour
        relevant_indices : list
            Indices of the feature space dimensions that are relevant to the
            task; the remaining dimensions are rendered in a separate image
        seed : int
            Seed for this space
        '''
        # ##TODO Define a common superclass for this and ImageMultiDiscrete
        self.feature_space = feature_space
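        # The feature space must be fully bounded for positions to be mapped
        # to pixel coordinates in convert_to_pixel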
        assert (self.feature_space.high != np.inf).all()
        assert (self.feature_space.low != -np.inf).all()
        self.width = width
        self.height = height
        # Warn if resolution is too low?
        self.circle_radius = circle_radius
        self.target_point = target_point
        self.term_spaces = term_spaces
        self.relevant_indices = relevant_indices
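        # Dimensions not listed in relevant_indices are treated as irrelevant
        # and are rendered in a separate image by get_concatenated_image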
        all_indices = set(range(self.feature_space.shape[0]))
        self.irrelevant_indices = list(all_indices - set(self.relevant_indices))
        if len(self.irrelevant_indices) == 0:
            self.irrelevant_features = False
        else:
            self.irrelevant_features = True

        self.goal_colour = (0, 255, 0)
        self.agent_colour = (0, 0, 255)
        self.term_colour = (0, 0, 0)

        assert len(feature_space.shape) == 1
        relevant_dims = len(relevant_indices)
        irr_dims = len(self.irrelevant_indices)
        assert relevant_dims <= 2 and irr_dims <= 2, "Image observations are "\
            "supported only for 1- or 2-D feature spaces."

        # Shape has 1 appended for Ray Rllib to be compatible IIRC
        super(ImageContinuous, self).__init__(shape=(width, height, 1),
                                              dtype=dtype, low=0, high=255)
        super(ImageContinuous, self).seed(seed=seed)

        if self.target_point is not None:
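            # Pre-compute the target point's pixel coordinates once, since the
            # target is fixed for the lifetime of this space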
            self.target_point_pixel = self.convert_to_pixel(target_point)

    def generate_image(self, position, relevant=True):
        '''
        Parameters
        ----------
        position : np.array
            The position in the (relevant or irrelevant) feature sub-space for
            which the agent's circle is drawn
        '''
        # Use RGB
        image_ = Image.new("RGB", (self.width, self.height), color=(255, 255, 255))
        # Use L for black and white 8-bit pixels instead of RGB in case not
        # using custom images
        # image_ = Image.new("L", (self.width, self.height))
        draw = ImageDraw.Draw(image_)

        # Draw term_spaces first, so that others are drawn over it
        if self.term_spaces is not None and relevant:
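            # Each terminal sub-space is drawn as a filled rectangle spanning
            # its low and high corners (converted to pixel coordinates)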
            for term_space in self.term_spaces:
                low = self.convert_to_pixel(term_space.low)
                high = self.convert_to_pixel(term_space.high)

                leftUpPoint = tuple(low)
                rightDownPoint = tuple(high)
                twoPointList = [leftUpPoint, rightDownPoint]
                draw.rectangle(twoPointList, fill=self.term_colour)

        R = self.circle_radius

        if self.target_point is not None and relevant:
            # print("draw2", self.target_point_pixel)
            leftUpPoint = tuple(self.target_point_pixel - R)
            rightDownPoint = tuple(self.target_point_pixel + R)
            twoPointList = [leftUpPoint, rightDownPoint]
            draw.ellipse(twoPointList, fill=self.goal_colour)

        pos_pixel = self.convert_to_pixel(position)
        # print("draw1", pos_pixel)
        # Draw circle https://stackoverflow.com/a/2980931/11063709
        leftUpPoint = tuple(pos_pixel - R)
        rightDownPoint = tuple(pos_pixel + R)
        twoPointList = [leftUpPoint, rightDownPoint]
        draw.ellipse(twoPointList, fill=self.agent_colour)

        # Because numpy is row-major and Image is column-major, need to transpose
        # ret_arr = np.array(image_).T # For 2-D
        ret_arr = np.transpose(np.array(image_), axes=(1, 0, 2))

        return ret_arr

    def get_concatenated_image(self, obs):
        '''Gets the "stitched together" image made from images corresponding to
        each continuous sub-space within the continuous space, concatenated
        along the X-axis.
        '''
        concatenated_image = []
        # For relevant/irrelevant sub-spaces:
        concatenated_image.append(self.generate_image(obs[self.relevant_indices]))
        if self.irrelevant_features:
            irr_image = self.generate_image(obs[self.irrelevant_indices], relevant=False)
            concatenated_image.append(irr_image)

        concatenated_image = np.concatenate(tuple(concatenated_image), axis=0)

        # because Ray expects an image to have >= 3 dims
        return np.atleast_3d(concatenated_image)

    def convert_to_pixel(self, position):
        '''Converts a position in the feature space to pixel coordinates in the
        image.
        '''
        # It's implicit that both relevant and irrelevant sub-spaces have the
        # same max and min here:
        max = self.feature_space.high[self.relevant_indices]
        min = self.feature_space.low[self.relevant_indices]
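        # Normalise the position to [0, 1] within the feature space bounds and
        # then scale it to the image resolution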
        pos_pixel = (position - min) / (max - min)
        pos_pixel = (pos_pixel * self.shape[:2]).astype(int)

        return pos_pixel

    def sample(self):
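        # Sample a point in the underlying feature space and return its image
        # representation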
        sampled = self.feature_space.sample()
        return self.get_concatenated_image(sampled)

    def __repr__(self):
        return "{} with continuous underlying space of shape: {} and "\
               "images of resolution: {} and dtype: {}".format(
                   self.__class__, self.feature_space.shape, self.shape, self.dtype)

    def contains(self, x):
        """
        Return boolean specifying if x is a valid
        member of this space
        """
        # TODO compare each pixel for all possible images?
        return x.shape == (self.width, self.height, 1)

    def to_jsonable(self, sample_n):
        """Convert a batch of samples from this space to a JSONable data type."""
        raise NotImplementedError

    def from_jsonable(self, sample_n):
        """Convert a JSONable data type to a batch of samples from this space."""
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError
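# Minimal usage sketch, assuming a bounded 2-D Box feature space; the bounds,
# target_point, resolution and seed below are illustrative assumptions only.
if __name__ == "__main__":
    feature_space = Box(low=0.0, high=10.0, shape=(2,), dtype=np.float64)
    img_space = ImageContinuous(feature_space, width=100, height=100,
                                target_point=np.array([5.0, 5.0]), seed=0)
    obs = feature_space.sample()
    img = img_space.get_concatenated_image(obs)
    print(img.shape)  # (100, 100, 3) when all feature dimensions are relevant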