11
11
12
12
class TestImageContinuous (unittest .TestCase ):
13
13
def test_image_continuous (self ):
14
+ render = False
14
15
lows = 0.0
15
16
highs = 20.0
16
17
cs2 = Box (
@@ -26,24 +27,25 @@ def test_image_continuous(self):
26
27
27
28
imc = ImageContinuous (
28
29
cs2 ,
29
- width = 100 ,
30
- height = 100 ,
30
+ width = 400 ,
31
+ height = 400 ,
31
32
)
32
33
pos = np .array ([5.0 , 7.0 ])
33
34
img1 = Image .fromarray (np .squeeze (imc .generate_image (pos )), "RGB" )
34
35
# img1 = ImageOps.invert(img1)
35
- img1 .show ()
36
+ if render : img1 .show ()
36
37
# img1.save("cont_state_no_target.pdf")
37
38
38
39
target = np .array ([10 , 10 ])
39
40
imc = ImageContinuous (
40
41
cs2 ,
42
+ circle_radius = 10 ,
41
43
target_point = target ,
42
- width = 100 ,
43
- height = 100 ,
44
+ width = 400 ,
45
+ height = 400 ,
44
46
)
45
47
img1 = Image .fromarray (np .squeeze (imc .generate_image (pos )), "RGB" )
46
- img1 .show ()
48
+ if render : img1 .show ()
47
49
# img1.save("cont_state_target.pdf")
48
50
49
51
# Terminal sub-spaces
@@ -65,13 +67,14 @@ def test_image_continuous(self):
65
67
imc = ImageContinuous (
66
68
cs2 ,
67
69
target_point = target ,
70
+ circle_radius = 10 ,
68
71
term_spaces = term_spaces ,
69
- width = 100 ,
70
- height = 100 ,
72
+ width = 400 ,
73
+ height = 400 ,
71
74
)
72
75
pos = np .array ([5.0 , 7.0 ])
73
76
img1 = Image .fromarray (np .squeeze (imc .get_concatenated_image (pos )), "RGB" )
74
- img1 .show ()
77
+ if render : img1 .show ()
75
78
# img1.save("cont_state_target_terminal_states.pdf")
76
79
77
80
# Irrelevant features
@@ -84,7 +87,7 @@ def test_image_continuous(self):
84
87
)
85
88
pos = np .array ([5.0 , 7.0 , 10.0 , 15.0 ])
86
89
img1 = Image .fromarray (np .squeeze (imc .get_concatenated_image (pos )), "RGB" )
87
- img1 .show ()
90
+ if render : img1 .show ()
88
91
# print(imc.get_concatenated_image(pos).shape)
89
92
90
93
# Random sample and __repr__
@@ -96,14 +99,58 @@ def test_image_continuous(self):
96
99
)
97
100
# print(imc)
98
101
img1 = Image .fromarray (np .squeeze (imc .sample ()), "RGB" )
99
- img1 .show ()
102
+ if render : img1 .show ()
100
103
101
104
# Draw grid
105
+ grid_shape = (5 , 5 )
106
+ cs2_grid = Box (
107
+ low = 0 * np .array (grid_shape ).astype (np .float64 ),
108
+ high = np .array (grid_shape ).astype (np .float64 ),
109
+ )
110
+ pos = np .array ([2 , 3 ])
111
+ target = np .array ([4 , 4 ])
102
112
imc = ImageContinuous (
103
- cs4 , target_point = target , width = 400 , height = 400 , grid = (5 , 5 )
113
+ cs2_grid ,
114
+ target_point = target ,
115
+ circle_radius = 10 ,
116
+ width = 400 ,
117
+ height = 400 ,
118
+ grid_shape = grid_shape
104
119
)
105
- img1 = Image .fromarray (np .squeeze (imc .sample ()), "RGB" )
106
- img1 .show ()
120
+ img1 = Image .fromarray (np .squeeze (imc .get_concatenated_image (pos )), "RGB" )
121
+ if render : img1 .show ()
122
+ # img1.save("grid_target.pdf")
123
+
124
+ # Grid with terminal sub-spaces
125
+ lows = np .array ([2.0 , 4.0 ])
126
+ highs = np .array ([2.0 , 4.0 ])
127
+ cs2_term1 = Box (
128
+ low = lows ,
129
+ high = highs ,
130
+ )
131
+ lows = np .array ([1.0 , 1.0 ])
132
+ highs = np .array ([1.0 , 1.0 ])
133
+ cs2_term2 = Box (
134
+ low = lows ,
135
+ high = highs ,
136
+ )
137
+ term_spaces = [cs2_term1 , cs2_term2 ]
138
+
139
+ pos = np .array ([2 , 3 ])
140
+ target = np .array ([4 , 4 ])
141
+ imc = ImageContinuous (
142
+ cs2_grid ,
143
+ circle_radius = 10 ,
144
+ target_point = target ,
145
+ term_spaces = term_spaces ,
146
+ width = 400 ,
147
+ height = 400 ,
148
+ grid_shape = grid_shape
149
+ )
150
+ img1 = Image .fromarray (np .squeeze (imc .get_concatenated_image (pos )), "RGB" )
151
+ if render : img1 .show ()
152
+ # img1.save("grid_target_terminal_states.pdf")
153
+
107
154
108
155
109
156
if __name__ == "__main__" :
0 commit comments