@@ -34,16 +34,17 @@ def main():
    # Graph
    # We must not use tf.train.replica_device_setter for normal operations
    # with tf.device(tf.train.replica_device_setter(ps_tasks=n_pss\
-     #     ,worker_device="/job:%s/task:%d/cpu:0" % (FLAGS.job_name,FLAGS.task_index))):
-     with tf.device(tf.train.replica_device_setter(ps_tasks=n_pss\
-         ,worker_device="/job:%s/task:%d" % (FLAGS.job_name,FLAGS.task_index))):
+     #     ,worker_device="/job:%s/task:%d" % (FLAGS.job_name,FLAGS.task_index))):

+     # Local operations
+     with tf.device("/job:worker/replica:0/task:%d" % FLAGS.task_index):
        a = tf.Variable(tf.constant(0., shape=[2]), dtype=tf.float32,
                        collections=[tf.GraphKeys.LOCAL_VARIABLES])
        b = tf.Variable(tf.constant(0., shape=[2]), dtype=tf.float32,
                        collections=[tf.GraphKeys.LOCAL_VARIABLES])
        c = a + b
-
+     with tf.device(tf.train.replica_device_setter(ps_tasks=n_pss\
+         ,worker_device="/job:%s/task:%d" % (FLAGS.job_name,FLAGS.task_index))):

        local_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='local_step',
                                 collections=['local_non_trainable'])
        # global step is tricky
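A standalone, single-process sketch of why a, b, and c above are registered under tf.GraphKeys.LOCAL_VARIABLES: variables in that collection are excluded from tf.global_variables(), so the chief's global initializer and the checkpoint saver never touch them, and each worker keeps its own private copy on the device pinned above. The variable names below are illustrative and not part of the commit.

import tensorflow as tf

local_v = tf.Variable(tf.constant(0., shape=[2]),
                      collections=[tf.GraphKeys.LOCAL_VARIABLES])
global_v = tf.Variable(tf.constant(0., shape=[2]))    # default: GLOBAL_VARIABLES

print([v.op.name for v in tf.global_variables()])     # lists only global_v
print([v.op.name for v in tf.local_variables()])      # lists only local_v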
@@ -54,19 +55,11 @@ def main():
        # all workers use the same learning rate and it is decided on by task 0
        # or maybe from the graph of the chief worker
        base_lr = .0001
-         loptimizer = tf.train.GradientDescentOptimizer(base_lr)
+         loptimizer = tf.train.GradientDescentOptimizer(base_lr)  # local optimizer
        optimizer = tf.train.GradientDescentOptimizer(base_lr)  # the learning rate set here is global

        # create global variables and/or references
        local_to_global, global_to_local = create_global_variables()
-
-         # local optimizers and steps
-         # actually only need one optimizer
-         # optimizers=[]
-         # local_steps = []
-         # for w in range(n_workers):
-         #     local_steps.append(tf.Variable(0,dtype=tf.int32,trainable=False,name='local_step_%d'%w))
-         #     optimizers.append(tf.train.GradientDescentOptimizer(base_lr))

        # ADAG (simplest case since all batches are the same)
        update_window = 3  # T: update/communication window, a.k.a. the number of gradients to use before sending to ps
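A framework-free toy sketch of the communication window T = update_window noted above, under the assumption that a worker takes T local SGD steps and then reconciles once with the parameter-server copy; the quadratic-loss gradient and the variables here are made up for illustration.

import numpy as np

w_global = np.zeros(2)             # stand-in for the ps copy of a variable
w_local = w_global.copy()          # worker-private copy
base_lr, update_window = .0001, 3

for t in range(update_window):     # T local steps between communications
    grad = 2.0 * w_local - 1.0     # made-up gradient of a toy loss
    w_local -= base_lr * grad      # one local GradientDescentOptimizer-style step

w_global += w_local - w_global     # push the accumulated change to the ps once
w_local = w_global.copy()          # pull the fresh global values back down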
@@ -86,14 +79,20 @@ def main():
                                       zip(grads, [local_to_global[v] for v in varss])
                                       , global_step=global_step)  # apply the gradients to variables on ps

+         # Push to global server
        with tf.control_dependencies([opt]):
            assign_locals = assign_global_to_local(global_to_local)

        # Init ops
        init_local = tf.variables_initializer(tf.local_variables() + tf.get_collection('local_non_trainable'))  # tf.local_variables_initializer()  # for local variables
        init = tf.global_variables_initializer()  # for global variables

+         # TODO: Grab global state before training so all workers have same initialization
+         grab_global_init = assign_global_to_local(global_to_local)
+
        # TODO: Add op that assigns local values to global ones for chief to execute
+         assign_global = assign_local_to_global(local_to_global)
+

    # Session
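A minimal single-process sketch of the ordering that tf.control_dependencies([opt]) enforces above: the pull of global values into the local copies can only run after the global apply-gradients op has finished. The toy variables and the assign_add stand-in below are assumptions, not the commit's code.

import tensorflow as tf

g = tf.Variable(0.)                        # stand-in for a global (ps) variable
l = tf.Variable(0.)                        # stand-in for its worker-local copy
push = tf.assign_add(g, 1.)                # stand-in for the apply_gradients op
with tf.control_dependencies([push]):
    pull = tf.assign(l, g.read_value())    # the pull is forced to wait for the push

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(pull))                  # 1.0: the copy sees the updated value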
@@ -105,10 +104,12 @@ def main():
    sess = tf.train.MonitoredTrainingSession(master=server.target, is_chief=is_chief, config=config,
                                             scaffold=scaff, hooks=hooks, save_checkpoint_secs=1, checkpoint_dir='logdir')
    if is_chief:
+         sess.run(assign_global)  # TODO  # assigns chief's initial values to ps
        time.sleep(10)  # grace period to wait on other workers before starting training

    # Train until hook stops session
    print('Starting training on worker %d' % FLAGS.task_index)
+     sess.run(grab_global_init)
    while not sess.should_stop():
        _, _, r, gs, ls = sess.run([opt, assign_locals, c, global_step, local_step])
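For the run-until-hook loop above, a self-contained single-process sketch of the same pattern; the StopAtStepHook is only an assumed example of what the undisclosed hooks list might contain, and train_op stands in for the real [opt, assign_locals, ...] fetches.

import tensorflow as tf

gs = tf.train.get_or_create_global_step()
train_op = tf.assign_add(gs, 1)                  # stand-in for one training step
hooks = [tf.train.StopAtStepHook(last_step=5)]   # assumed contents of `hooks`

with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
    while not sess.should_stop():                # the hook ends the loop at step 5
        print(sess.run(train_op))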
@@ -127,14 +128,40 @@ def main():
    print('Session from worker %d closed cleanly' % FLAGS.task_index)

def assign_global_to_local(global_to_local):
+     """
+     global_to_local : dictionary with the corresponding local variable for each global key
+
+     Assigns global variable values to local variables
+     """
    assign_ops = []
    for v in global_to_local.keys():
        assign_ops.append(tf.assign(global_to_local[v], v))
    return tf.group(*assign_ops)

+ def assign_local_to_global(local_to_global):
+     """
+     local_to_global : dictionary with the corresponding global variable for each local key
+
+     Assigns local variable values to global variables
+     """
+     assign_ops = []
+     for v in local_to_global.keys():
+         assign_ops.append(tf.assign(local_to_global[v], v))
+     return tf.group(*assign_ops)
+
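Taken together, the two helpers form the pull/push pair wired up in main(); as a reminder of the direction of each copy (the names pull_op and push_op below are illustrative only; main() binds these ops to assign_locals / grab_global_init and assign_global):

pull_op = assign_global_to_local(global_to_local)   # ps values -> worker-local copies
push_op = assign_local_to_global(local_to_global)   # worker-local values -> ps copies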

def get_global_variable_by_name(name):
+     """
+     name : the name of the global variable
+
+     Returns the global variable with the given name
+     """
    return [v for v in tf.global_variables() if v.name == name][0]

def create_global_variables():
+     """
+     Creates global variables for local variables on the graph.
+
+     Returns dictionaries for the local-to-global and global-to-local
+     variable mappings.
+     """
    # TODO: swap static string with tf.train.replica_device_setter(ps_tasks=n_pss)
    local_to_global = {}
    global_to_local = {}
@@ -149,10 +176,10 @@ def create_global_variables():
            global_to_local[v_g] = v
    return local_to_global, global_to_local

- # TODO: initialize global ps variables
- # according to the chief's initial values
- def assign_global_values():
-     return None
+ # # TODO: initialize global ps variables
+ # # according to the chief's initial values
+ # def assign_global_values():
+ #     return None


if __name__ == '__main__':