
Commit 094ebf7

Merge pull request dennybritz#118 from praveen-palanisamy/master
Fixes for issues and for compatibility with TensorFlow v 1.0+
2 parents 18100ea + 10ce5dc

2 files changed (+28, -23 lines)
DQN/Deep Q Learning.ipynb (+14, -12)
@@ -11,6 +11,7 @@
 "%matplotlib inline\n",
 "\n",
 "import gym\n",
+"from gym.wrappers import Monitor\n",
 "import itertools\n",
 "import numpy as np\n",
 "import os\n",
@@ -67,7 +68,7 @@
 "            self.output = tf.image.rgb_to_grayscale(self.input_state)\n",
 "            self.output = tf.image.crop_to_bounding_box(self.output, 34, 0, 160, 160)\n",
 "            self.output = tf.image.resize_images(\n",
-"                self.output, 84, 84, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
+"                self.output, [84, 84], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
 "            self.output = tf.squeeze(self.output)\n",
 "\n",
 "    def process(self, sess, state):\n",
@@ -107,7 +108,7 @@
 "                summary_dir = os.path.join(summaries_dir, \"summaries_{}\".format(scope))\n",
 "                if not os.path.exists(summary_dir):\n",
 "                    os.makedirs(summary_dir)\n",
-"                self.summary_writer = tf.train.SummaryWriter(summary_dir)\n",
+"                self.summary_writer = tf.summary.FileWriter(summary_dir)\n",
 "\n",
 "    def _build_model(self):\n",
 "        \"\"\"\n",
@@ -151,11 +152,11 @@
 "        self.train_op = self.optimizer.minimize(self.loss, global_step=tf.contrib.framework.get_global_step())\n",
 "\n",
 "        # Summaries for Tensorboard\n",
-"        self.summaries = tf.merge_summary([\n",
-"            tf.scalar_summary(\"loss\", self.loss),\n",
-"            tf.histogram_summary(\"loss_hist\", self.losses),\n",
-"            tf.histogram_summary(\"q_values_hist\", self.predictions),\n",
-"            tf.scalar_summary(\"max_q_value\", tf.reduce_max(self.predictions))\n",
+"        self.summaries = tf.summary.merge([\n",
+"            tf.summary.scalar(\"loss\", self.loss),\n",
+"            tf.summary.histogram(\"loss_hist\", self.losses),\n",
+"            tf.summary.histogram(\"q_values_hist\", self.predictions),\n",
+"            tf.summary.scalar(\"max_q_value\", tf.reduce_max(self.predictions))\n",
 "        ])\n",
 "\n",
 "\n",
@@ -212,7 +213,7 @@
 "sp = StateProcessor()\n",
 "\n",
 "with tf.Session() as sess:\n",
-"    sess.run(tf.initialize_all_variables())\n",
+"    sess.run(tf.global_variables_initializer())\n",
 "    \n",
 "    # Example observation batch\n",
 "    observation = env.reset()\n",
@@ -391,9 +392,10 @@
 "        pass\n",
 "\n",
 "    # Record videos\n",
-"    env.monitor.start(monitor_path,\n",
-"                      resume=True,\n",
-"                      video_callable=lambda count: count % record_video_every == 0)\n",
+"    env= Monitor(env,\n",
+"                 directory=monitor_path,\n",
+"                 resume=True,\n",
+"                 video_callable=lambda count: count % record_video_every == 0)\n",
 "\n",
 "    for i_episode in range(num_episodes):\n",
 "\n",
@@ -526,7 +528,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.5.1"
+   "version": "3.6.0"
   }
  },
 "nbformat": 4,

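For reference, the notebook changes above are a rename from the pre-1.0 TensorFlow API to the 1.x equivalents, plus passing the resize target as a list. A minimal, self-contained sketch of the renamed calls (assuming TensorFlow 1.x; the log directory and the dummy input below are placeholders, not part of this commit):

import numpy as np
import tensorflow as tf

# Preprocessing: tf.image.resize_images now takes the target size as a list.
frame = tf.placeholder(shape=[210, 160, 3], dtype=tf.uint8)
gray = tf.image.rgb_to_grayscale(frame)
small = tf.image.resize_images(
    gray, [84, 84], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

# Summaries: tf.scalar_summary / tf.histogram_summary / tf.merge_summary
# became tf.summary.scalar / tf.summary.histogram / tf.summary.merge.
pixels = tf.cast(small, tf.float32)
summaries = tf.summary.merge([
    tf.summary.scalar("max_pixel", tf.reduce_max(pixels)),
    tf.summary.histogram("pixel_hist", pixels),
])

# tf.train.SummaryWriter became tf.summary.FileWriter.
writer = tf.summary.FileWriter("/tmp/summaries_example")  # placeholder log dir

with tf.Session() as sess:
    # tf.initialize_all_variables() became tf.global_variables_initializer().
    sess.run(tf.global_variables_initializer())
    summary_str = sess.run(
        summaries, feed_dict={frame: np.zeros([210, 160, 3], np.uint8)})
    writer.add_summary(summary_str, global_step=0)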
DQN/dqn.py (+14, -11)
@@ -1,4 +1,5 @@
 import gym
+from gym.wrappers import Monitor
 import itertools
 import numpy as np
 import os
@@ -28,7 +29,7 @@ def __init__(self):
             self.output = tf.image.rgb_to_grayscale(self.input_state)
             self.output = tf.image.crop_to_bounding_box(self.output, 34, 0, 160, 160)
             self.output = tf.image.resize_images(
-                self.output, 84, 84, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+                self.output, [84, 84], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
             self.output = tf.squeeze(self.output)

     def process(self, sess, state):
@@ -59,7 +60,7 @@ def __init__(self, scope="estimator", summaries_dir=None):
                 summary_dir = os.path.join(summaries_dir, "summaries_{}".format(scope))
                 if not os.path.exists(summary_dir):
                     os.makedirs(summary_dir)
-                self.summary_writer = tf.train.SummaryWriter(summary_dir)
+                self.summary_writer = tf.summary.FileWriter(summary_dir)

     def _build_model(self):
         """
@@ -103,11 +104,11 @@ def _build_model(self):
         self.train_op = self.optimizer.minimize(self.loss, global_step=tf.contrib.framework.get_global_step())

         # Summaries for Tensorboard
-        self.summaries = tf.merge_summary([
-            tf.scalar_summary("loss", self.loss),
-            tf.histogram_summary("loss_hist", self.losses),
-            tf.histogram_summary("q_values_hist", self.predictions),
-            tf.scalar_summary("max_q_value", tf.reduce_max(self.predictions))
+        self.summaries = tf.summary.merge([
+            tf.summary.scalar("loss", self.loss),
+            tf.summary.histogram("loss_hist", self.losses),
+            tf.summary.histogram("q_values_hist", self.predictions),
+            tf.summary.scalar("max_q_value", tf.reduce_max(self.predictions))
         ])


@@ -292,9 +293,11 @@ def deep_q_learning(sess,
             state = next_state

     # Record videos
-    env.monitor.start(monitor_path,
-                      resume=True,
-                      video_callable=lambda count: count % record_video_every == 0)
+    # Use the gym env Monitor wrapper
+    env = Monitor(env,
+                  directory=monitor_path,
+                  resume=True,
+                  video_callable=lambda count: count % record_video_every ==0)

     for i_episode in range(num_episodes):

@@ -398,7 +401,7 @@ def deep_q_learning(sess,
 state_processor = StateProcessor()

 with tf.Session() as sess:
-    sess.run(tf.initialize_all_variables())
+    sess.run(tf.global_variables_initializer())
     for t, stats in deep_q_learning(sess,
                                     env,
                                     q_estimator=q_estimator,
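The dqn.py side is the same set of TensorFlow renames, plus the switch from the removed env.monitor.start(...) call to the gym.wrappers.Monitor wrapper. A minimal sketch of that wrapper usage (assuming an older gym release that still ships gym.wrappers.Monitor and the Breakout-v0 id; the output directory and recording interval below are placeholders):

import gym
from gym.wrappers import Monitor

monitor_path = "/tmp/gym-monitor"   # placeholder output directory
record_video_every = 50             # placeholder recording interval

env = gym.make("Breakout-v0")
# Wrapping the env replaces the old env.monitor.start(...) API; a video is
# recorded for every episode index accepted by video_callable.
env = Monitor(env,
              directory=monitor_path,
              resume=True,
              video_callable=lambda count: count % record_video_every == 0)

state = env.reset()
for _ in range(100):
    state, reward, done, _ = env.step(env.action_space.sample())
    if done:
        state = env.reset()
env.close()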
