Skip to content

Commit cee9e78

Browse files
authored
Merge pull request dennybritz#188 from JovanSardinha/master
cleaning up lib/envs/gridworld.py
2 parents 8e8a21b + 01b8b13 commit cee9e78

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

lib/envs/gridworld.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import numpy as np
23
import sys
34
from gym.envs.toy_text import discrete
@@ -49,6 +50,7 @@ def __init__(self, shape=[4,4]):
4950
s = it.iterindex
5051
y, x = it.multi_index
5152

53+
# P[s][a] = (prob, next_state, reward, is_done)
5254
P[s] = {a : [] for a in range(nA)}
5355

5456
is_done = lambda s: s == 0 or s == (nS - 1)
@@ -83,10 +85,19 @@ def __init__(self, shape=[4,4]):
8385
super(GridworldEnv, self).__init__(nS, nA, P, isd)
8486

8587
def _render(self, mode='human', close=False):
88+
""" Renders the current gridworld layout
89+
90+
For example, a 4x4 grid with the mode="human" looks like:
91+
T o o o
92+
o x o o
93+
o o o o
94+
o o o T
95+
where x is your position and T are the two terminal states.
96+
"""
8697
if close:
8798
return
8899

89-
outfile = StringIO() if mode == 'ansi' else sys.stdout
100+
outfile = io.StringIO() if mode == 'ansi' else sys.stdout
90101

91102
grid = np.arange(self.nS).reshape(self.shape)
92103
it = np.nditer(grid, flags=['multi_index'])
@@ -102,7 +113,7 @@ def _render(self, mode='human', close=False):
102113
output = " o "
103114

104115
if x == 0:
105-
output = output.lstrip()
116+
output = output.lstrip()
106117
if x == self.shape[1] - 1:
107118
output = output.rstrip()
108119

@@ -111,4 +122,4 @@ def _render(self, mode='human', close=False):
111122
if x == self.shape[1] - 1:
112123
outfile.write("\n")
113124

114-
it.iternext()
125+
it.iternext()

0 commit comments

Comments
 (0)