Commit 3cddb9b

update
1 parent a1b64cf commit 3cddb9b

3 files changed: 167 additions and 192 deletions


examples/reinforcement_learning/tutorial_C51.py

Lines changed: 54 additions & 63 deletions
@@ -31,10 +31,11 @@
 import matplotlib.pyplot as plt
 
 import tensorlayer as tl
-from tutorial_wrappers import build_env
+import gym
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--mode', help='train or test', default='test')
+parser.add_argument('--train', dest='train', action='store_true', default=False)
+parser.add_argument('--test', dest='test', action='store_true', default=True)
 parser.add_argument(
     '--save_path', default=None, help='folder to save if mode == train else model path,'
     'qnet will be saved once target net update'
@@ -47,7 +48,8 @@
 np.random.seed(args.seed)
 tf.random.set_seed(args.seed) # reproducible
 env_id = args.env_id
-env = build_env(env_id, seed=args.seed)
+env = gym.make(env_id)
+env.seed(args.seed)
 alg_name = 'C51'
 
 # #################### hyper parameters ####################
@@ -195,7 +197,7 @@ class DQN(object):
     def __init__(self):
         model = MLP if qnet_type == 'MLP' else CNN
         self.qnet = model('q')
-        if args.mode == 'train':
+        if args.train:
             self.qnet.train()
         self.targetqnet = model('targetq')
         self.targetqnet.infer()
@@ -211,7 +213,7 @@ def __init__(self):
 
     def get_action(self, obv):
         eps = epsilon(self.niter)
-        if args.mode == 'train' and random.random() < eps:
+        if args.train and random.random() < eps:
             return int(random.random() * out_dim)
         else:
             obv = np.expand_dims(obv, 0).astype('float32') * ob_scale
@@ -275,76 +277,65 @@ def _train_func(self, b_o, b_index, b_m):
 # ############################# Trainer ###################################
 if __name__ == '__main__':
     dqn = DQN()
-    if args.mode == 'train':
+    t0 = time.time()
+    if args.train:
         buffer = ReplayBuffer(buffer_size)
-
-        o = env.reset()
         nepisode = 0
-        t = time.time()
         all_episode_reward = []
         for i in range(1, number_timesteps + 1):
-
-            a = dqn.get_action(o)
-
-            # execute action and feed to replay buffer
-            # note that `_` tail in var name means next
-            o_, r, done, info = env.step(a)
-            buffer.add(o, a, r, o_, done)
-
-            if i >= warm_start:
-                transitions = buffer.sample(batch_size)
-                dqn.train(*transitions)
-
-            if done:
-                episode_reward = info['episode']['r']
-                if nepisode == 0:
-                    all_episode_reward.append(episode_reward)
+            o = env.reset()
+            episode_reward = 0
+            while True:
+                a = dqn.get_action(o)
+                # execute action and feed to replay buffer
+                # note that `_` tail in var name means next
+                o_, r, done, info = env.step(a)
+                buffer.add(o, a, r, o_, done)
+                episode_reward += r
+
+                if i >= warm_start:
+                    transitions = buffer.sample(batch_size)
+                    dqn.train(*transitions)
+
+                if done:
+                    break
                 else:
-                    all_episode_reward.append(all_episode_reward[-1] * 0.9 + episode_reward * 0.1)
-                o = env.reset()
+                    o = o_
+
+            if nepisode == 0:
+                all_episode_reward.append(episode_reward)
             else:
-                o = o_
-
-            # episode in info is real (unwrapped) message
-            if info.get('episode'):
-                nepisode += 1
-                reward, length = info['episode']['r'], info['episode']['l']
-                try:
-                    fps = int(length / (time.time() - t))
-                except:
-                    fps = 0
-                print(
-                    'Time steps so far: {}, episode so far: {}, '
-                    'episode reward: {:.4f}, episode length: {}, FPS: {}'.format(i, nepisode, reward, length, fps)
+                all_episode_reward.append(all_episode_reward[-1] * 0.9 + episode_reward * 0.1)
+            nepisode += 1
+            print(
+                'Training | Episode: {} | Episode Reward: {:.4f} | Running Time: {:.4f}'.format(
+                    nepisode, episode_reward, time.time() - t0
                 )
-                t = time.time()
+            ) # episode num starts from 1 in print
 
         dqn.save(args.save_path)
         plt.plot(all_episode_reward)
        if not os.path.exists('image'):
             os.makedirs('image')
         plt.savefig(os.path.join('image', '_'.join([alg_name, env_id])))
-    else:
+
+    if args.test:
         nepisode = 0
-        o = env.reset()
         for i in range(1, number_timesteps + 1):
-            a = dqn.get_action(o)
-
-            # execute action
-            # note that `_` tail in var name means next
-            o_, r, done, info = env.step(a)
-            env.render()
-
-            if done:
-                o = env.reset()
-            else:
-                o = o_
-
-            # episode in info is real (unwrapped) message
-            if info.get('episode'):
-                nepisode += 1
-                reward, length = info['episode']['r'], info['episode']['l']
-                print(
-                    'Time steps so far: {}, episode so far: {}, '
-                    'episode reward: {:.4f}, episode length: {}'.format(i, nepisode, reward, length)
-                )
+            o = env.reset()
+            episode_reward = 0
+            while True:
+                env.render()
+                a = dqn.get_action(o)
+                o_, r, done, info = env.step(a)
+                episode_reward += r
+                if done:
+                    break
+                else:
+                    o = o_
+            nepisode += 1
+            print(
+                'Testing | Episode: {} | Episode Reward: {:.4f} | Running Time: {:.4f}'.format(
+                    nepisode, episode_reward, time.time() - t0
+                )
+            )
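
The same switch appears in both tutorials touched by this commit: the custom `tutorial_wrappers.build_env` helper is replaced by a plain `gym.make` call with explicit seeding, and the single `--mode` string argument becomes two boolean flags. Below is a hypothetical, minimal sketch of the new setup, distilled from the hunks above; the flag names and the `gym.make`/`env.seed` calls come from the diff, while the `CartPole-v0` default and the seed value are illustrative assumptions only.

# Hypothetical, minimal sketch of the CLI-and-env setup this commit moves to.
# Flag names and gym.make/env.seed come from the diff above; the
# 'CartPole-v0' default and the seed value are illustrative assumptions.
import argparse

import gym

parser = argparse.ArgumentParser()
# two boolean flags replace the old --mode string argument
parser.add_argument('--train', dest='train', action='store_true', default=False)
parser.add_argument('--test', dest='test', action='store_true', default=True)
parser.add_argument('--env_id', default='CartPole-v0')  # assumed default
parser.add_argument('--seed', type=int, default=0)      # assumed default
args = parser.parse_args()

# plain gym replaces the tutorial_wrappers.build_env helper:
# the environment is created directly and seeded explicitly
env = gym.make(args.env_id)
env.seed(args.seed)

if args.train:
    print('training branch would run here')
if args.test:
    print('testing branch would run here')

As the flags are defined in the diff, `--test` defaults to True and `store_true` can only switch a flag on, so an invocation such as `python examples/reinforcement_learning/tutorial_C51.py --train` runs the training branch and then the testing branch afterwards.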

examples/reinforcement_learning/tutorial_DQN_variants.py

Lines changed: 53 additions & 63 deletions
@@ -44,10 +44,11 @@
 import matplotlib.pyplot as plt
 
 import tensorlayer as tl
-from tutorial_wrappers import build_env
+import gym
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--mode', help='train or test', default='train')
+parser.add_argument('--train', dest='train', action='store_true', default=False)
+parser.add_argument('--test', dest='test', action='store_true', default=True)
 parser.add_argument(
     '--save_path', default=None, help='folder to save if mode == train else model path,'
     'qnet will be saved once target net update'
@@ -64,7 +65,8 @@
 tf.random.set_seed(args.seed) # reproducible
 
 env_id = args.env_id
-env = build_env(env_id, seed=args.seed)
+env = gym.make(env_id)
+env.seed(args.seed)
 noise_scale = args.noisy_scale
 double = not args.disable_double
 dueling = not args.disable_dueling
@@ -273,7 +275,7 @@ class DQN(object):
     def __init__(self):
         model = MLP if qnet_type == 'MLP' else CNN
         self.qnet = model('q')
-        if args.mode == 'train':
+        if args.train:
             self.qnet.train()
         self.targetqnet = model('targetq')
         self.targetqnet.infer()
@@ -290,7 +292,7 @@ def __init__(self):
 
     def get_action(self, obv):
         eps = epsilon(self.niter)
-        if args.mode == 'train':
+        if args.train:
            if random.random() < eps:
                return int(random.random() * out_dim)
        obv = np.expand_dims(obv, 0).astype('float32') * ob_scale
@@ -364,77 +366,65 @@ def load(self, path):
 # ############################# Trainer ###################################
 if __name__ == '__main__':
     dqn = DQN()
-    if args.mode == 'train':
+    t0 = time.time()
+    if args.train:
         buffer = ReplayBuffer(buffer_size)
-
-        o = env.reset()
         nepisode = 0
-        t = time.time()
         all_episode_reward = []
         for i in range(1, number_timesteps + 1):
-
-            a = dqn.get_action(o)
-
-            # execute action and feed to replay buffer
-            # note that `_` tail in var name means next
-            o_, r, done, info = env.step(a)
-            buffer.add(o, a, r, o_, done)
-
-            if i >= warm_start:
-                transitions = buffer.sample(batch_size)
-                dqn.train(*transitions)
-
-            if done:
-                episode_reward = info['episode']['r']
-                if nepisode == 0:
-                    all_episode_reward.append(episode_reward)
+            o = env.reset()
+            episode_reward = 0
+            while True:
+                a = dqn.get_action(o)
+
+                # execute action and feed to replay buffer
+                # note that `_` tail in var name means next
+                o_, r, done, info = env.step(a)
+                buffer.add(o, a, r, o_, done)
+
+                if i >= warm_start:
+                    transitions = buffer.sample(batch_size)
+                    dqn.train(*transitions)
+
+                if done:
+                    break
                 else:
-                    all_episode_reward.append(all_episode_reward[-1] * 0.9 + episode_reward * 0.1)
-                o = env.reset()
+                    o = o_
+
+            if nepisode == 0:
+                all_episode_reward.append(episode_reward)
             else:
-                o = o_
-
-            # episode in info is real (unwrapped) message
-            if info.get('episode'):
-                nepisode += 1
-                reward, length = info['episode']['r'], info['episode']['l']
-                try:
-                    fps = int(length / (time.time() - t))
-                except:
-                    fps = 0
-                print(
-                    'Time steps so far: {}, episode so far: {}, '
-                    'episode reward: {:.4f}, episode length: {}, FPS: {}'.format(i, nepisode, reward, length, fps)
+                all_episode_reward.append(all_episode_reward[-1] * 0.9 + episode_reward * 0.1)
+            nepisode += 1
+            print(
+                'Training | Episode: {} | Episode Reward: {:.4f} | Running Time: {:.4f}'.format(
+                    nepisode, episode_reward, time.time() - t0
                )
-                t = time.time()
+            ) # episode num starts from 1 in print
 
         dqn.save(args.save_path)
         plt.plot(all_episode_reward)
         if not os.path.exists('image'):
             os.makedirs('image')
         plt.savefig(os.path.join('image', '_'.join([alg_name, env_id])))
 
-    else:
+    if args.test:
         nepisode = 0
-        o = env.reset()
         for i in range(1, number_timesteps + 1):
-            a = dqn.get_action(o)
-
-            # execute action
-            # note that `_` tail in var name means next
-            o_, r, done, info = env.step(a)
-            env.render()
-
-            if done:
-                o = env.reset()
-            else:
-                o = o_
-
-            # episode in info is real (unwrapped) message
-            if info.get('episode'):
-                nepisode += 1
-                reward, length = info['episode']['r'], info['episode']['l']
-                print(
-                    'Time steps so far: {}, episode so far: {}, '
-                    'episode reward: {:.4f}, episode length: {}'.format(i, nepisode, reward, length)
-                )
+            o = env.reset()
+            episode_reward = 0
+            while True:
+                env.render()
+                a = dqn.get_action(o)
+                o_, r, done, info = env.step(a)
+                episode_reward += r
+                if done:
+                    break
+                else:
+                    o = o_
+            nepisode += 1
+            print(
+                'Testing | Episode: {} | Episode Reward: {:.4f} | Running Time: {:.4f}'.format(
+                    nepisode, episode_reward, time.time() - t0
+                )
+            )
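
Both rewritten trainers share the same loop shape: reset the env once per outer iteration, step it in a `while True` loop until `done`, track the episode return, and append an exponentially smoothed value (0.9 previous, 0.1 new) to `all_episode_reward` for plotting. Below is a minimal runnable sketch of that loop shape, using a random policy and `CartPole-v0` as stand-ins for the DQN agent and `args.env_id`; both stand-ins are assumptions for illustration and are not part of the commit.

# Sketch of the per-episode loop structure the rewritten trainers use.
# Random policy and CartPole-v0 are stand-ins (assumptions), not the
# tutorials' DQN agent or configured env_id.
import time

import gym

env = gym.make('CartPole-v0')
env.seed(0)

t0 = time.time()
all_episode_reward = []
nepisode = 0

for i in range(1, 11):                  # stand-in for number_timesteps
    o = env.reset()
    episode_reward = 0
    while True:
        a = env.action_space.sample()   # stand-in for dqn.get_action(o)
        o_, r, done, info = env.step(a)
        episode_reward += r
        if done:
            break
        else:
            o = o_

    # exponentially smoothed reward curve used for the saved plot
    if nepisode == 0:
        all_episode_reward.append(episode_reward)
    else:
        all_episode_reward.append(all_episode_reward[-1] * 0.9 + episode_reward * 0.1)
    nepisode += 1
    print(
        'Training | Episode: {} | Episode Reward: {:.4f} | Running Time: {:.4f}'.format(
            nepisode, episode_reward, time.time() - t0
        )
    )

Because of the smoothing, `all_episode_reward` tracks a running average rather than raw returns, and that smoothed curve is what the tutorials save under the image/ directory.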
