Skip to content

Commit e12db82

Browse files
author
User
committed
update
1 parent e3f1673 commit e12db82

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

rl2/cartpole/dqn_tf.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,10 @@
1616
from q_learning_bins import plot_running_avg
1717

1818

19+
# global counter
20+
global_iters = 0
21+
22+
1923
# a version of HiddenLayer that keeps track of params
2024
class HiddenLayer:
2125
def __init__(self, M1, M2, f=tf.nn.tanh, use_bias=True):
@@ -154,6 +158,7 @@ def sample_action(self, x, eps):
154158

155159

156160
def play_one(env, model, tmodel, eps, gamma, copy_period):
161+
global global_iters
157162
observation = env.reset()
158163
done = False
159164
totalreward = 0
@@ -174,8 +179,9 @@ def play_one(env, model, tmodel, eps, gamma, copy_period):
174179
model.train(tmodel)
175180

176181
iters += 1
182+
global_iters += 1
177183

178-
if iters % copy_period == 0:
184+
if global_iters % copy_period == 0:
179185
tmodel.copy_from(model)
180186

181187
return totalreward

rl2/cartpole/dqn_theano.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,10 @@
1717
from q_learning_bins import plot_running_avg
1818

1919

20+
# global counter
21+
global_iters = 0
22+
23+
2024
# helper for adam optimizer
2125
# use tensorflow defaults
2226
def adam(cost, params, lr0=1e-2, beta1=0.9, beta2=0.999, eps=1e-8):
@@ -170,6 +174,7 @@ def sample_action(self, x, eps):
170174

171175

172176
def play_one(env, model, tmodel, eps, gamma, copy_period):
177+
global global_iters
173178
observation = env.reset()
174179
done = False
175180
totalreward = 0
@@ -190,8 +195,9 @@ def play_one(env, model, tmodel, eps, gamma, copy_period):
190195
model.train(tmodel)
191196

192197
iters += 1
198+
global_iters += 1
193199

194-
if iters % copy_period == 0:
200+
if global_iters % copy_period == 0:
195201
tmodel.copy_from(model)
196202

197203
return totalreward

0 commit comments

Comments
 (0)