Skip to content

Commit 2aa98bf

Browse files
Try to fix failing tests
1 parent a7afbd5 commit 2aa98bf

5 files changed

+73
-15
lines changed

experiments/dqn_image_representations.py

+5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
agent_config specifies static agent configurations
2020
model_config specifies static NN model configurations
2121
eval_config specifies static evaluation configurations
22+
23+
NOTE: Please note that for any configuration values not provided here, reasonable
24+
default values would be used. As such, these config values are much more verbose
25+
than needed. We only explicitly provide many important configuration values here
26+
to have them be easy to find.
2227
"""
2328
from ray import tune
2429
from collections import OrderedDict

experiments/dqn_image_representations_sh_quant.py

+5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
agent_config specifies static agent configurations
2020
model_config specifies static NN model configurations
2121
eval_config specifies static evaluation configurations
22+
23+
NOTE: Please note that for any configuration values not provided here, reasonable
24+
default values would be used. As such, these config values are much more verbose
25+
than needed. We only explicitly provide many important configuration values here
26+
to have them be easy to find.
2227
"""
2328
import itertools
2429
from ray import tune

mdp_playground/spaces/image_continuous.py

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102

103103
if self.target_point is not None:
104104
if self.draw_grid:
105+
target_point = target_point.astype(float)
105106
target_point += 0.5
106107
self.target_point_pixel = self.convert_to_pixel(target_point)
107108

mdp_playground/spaces/test_image_continuous.py

+61-14
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
class TestImageContinuous(unittest.TestCase):
1313
def test_image_continuous(self):
14+
render = False
1415
lows = 0.0
1516
highs = 20.0
1617
cs2 = Box(
@@ -26,24 +27,25 @@ def test_image_continuous(self):
2627

2728
imc = ImageContinuous(
2829
cs2,
29-
width=100,
30-
height=100,
30+
width=400,
31+
height=400,
3132
)
3233
pos = np.array([5.0, 7.0])
3334
img1 = Image.fromarray(np.squeeze(imc.generate_image(pos)), "RGB")
3435
# img1 = ImageOps.invert(img1)
35-
img1.show()
36+
if render: img1.show()
3637
# img1.save("cont_state_no_target.pdf")
3738

3839
target = np.array([10, 10])
3940
imc = ImageContinuous(
4041
cs2,
42+
circle_radius=10,
4143
target_point=target,
42-
width=100,
43-
height=100,
44+
width=400,
45+
height=400,
4446
)
4547
img1 = Image.fromarray(np.squeeze(imc.generate_image(pos)), "RGB")
46-
img1.show()
48+
if render: img1.show()
4749
# img1.save("cont_state_target.pdf")
4850

4951
# Terminal sub-spaces
@@ -65,13 +67,14 @@ def test_image_continuous(self):
6567
imc = ImageContinuous(
6668
cs2,
6769
target_point=target,
70+
circle_radius=10,
6871
term_spaces=term_spaces,
69-
width=100,
70-
height=100,
72+
width=400,
73+
height=400,
7174
)
7275
pos = np.array([5.0, 7.0])
7376
img1 = Image.fromarray(np.squeeze(imc.get_concatenated_image(pos)), "RGB")
74-
img1.show()
77+
if render: img1.show()
7578
# img1.save("cont_state_target_terminal_states.pdf")
7679

7780
# Irrelevant features
@@ -84,7 +87,7 @@ def test_image_continuous(self):
8487
)
8588
pos = np.array([5.0, 7.0, 10.0, 15.0])
8689
img1 = Image.fromarray(np.squeeze(imc.get_concatenated_image(pos)), "RGB")
87-
img1.show()
90+
if render: img1.show()
8891
# print(imc.get_concatenated_image(pos).shape)
8992

9093
# Random sample and __repr__
@@ -96,14 +99,58 @@ def test_image_continuous(self):
9699
)
97100
# print(imc)
98101
img1 = Image.fromarray(np.squeeze(imc.sample()), "RGB")
99-
img1.show()
102+
if render: img1.show()
100103

101104
# Draw grid
105+
grid_shape=(5, 5)
106+
cs2_grid = Box(
107+
low=0 * np.array(grid_shape).astype(np.float64),
108+
high=np.array(grid_shape).astype(np.float64),
109+
)
110+
pos = np.array([2, 3])
111+
target = np.array([4, 4])
102112
imc = ImageContinuous(
103-
cs4, target_point=target, width=400, height=400, grid=(5, 5)
113+
cs2_grid,
114+
target_point=target,
115+
circle_radius=10,
116+
width=400,
117+
height=400,
118+
grid_shape=grid_shape
104119
)
105-
img1 = Image.fromarray(np.squeeze(imc.sample()), "RGB")
106-
img1.show()
120+
img1 = Image.fromarray(np.squeeze(imc.get_concatenated_image(pos)), "RGB")
121+
if render: img1.show()
122+
# img1.save("grid_target.pdf")
123+
124+
# Grid with terminal sub-spaces
125+
lows = np.array([2.0, 4.0])
126+
highs = np.array([2.0, 4.0])
127+
cs2_term1 = Box(
128+
low=lows,
129+
high=highs,
130+
)
131+
lows = np.array([1.0, 1.0])
132+
highs = np.array([1.0, 1.0])
133+
cs2_term2 = Box(
134+
low=lows,
135+
high=highs,
136+
)
137+
term_spaces = [cs2_term1, cs2_term2]
138+
139+
pos = np.array([2, 3])
140+
target = np.array([4, 4])
141+
imc = ImageContinuous(
142+
cs2_grid,
143+
circle_radius=10,
144+
target_point=target,
145+
term_spaces=term_spaces,
146+
width=400,
147+
height=400,
148+
grid_shape=grid_shape
149+
)
150+
img1 = Image.fromarray(np.squeeze(imc.get_concatenated_image(pos)), "RGB")
151+
if render: img1.show()
152+
# img1.save("grid_target_terminal_states.pdf")
153+
107154

108155

109156
if __name__ == "__main__":

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
extras_require_disc = [
2626
"ray[rllib,debug]==0.7.3",
27-
"tensorflow==1.13.0rc1",
27+
"tensorflow==1.13.1",
2828
"pillow>=6.1.0",
2929
"requests==2.22.0",
3030
"configspace==0.4.10",

0 commit comments

Comments
 (0)