@@ -236,6 +236,34 @@ def sgd(lr, tparams, grads, x, mask, y, cost):
236
236
237
237
238
238
def adadelta (lr , tparams , grads , x , mask , y , cost ):
239
+ """
240
+ An adaptive learning rate optimizer
241
+
242
+ Parameters
243
+ ----------
244
+ lr : Theano SharedVariable
245
+ Initial learning rate
246
+ tparams: Theano SharedVariable
247
+ Model parameters
248
+ grads: Theano variable
249
+ Gradients of cost w.r.t. parameters
250
+ x: Theano variable
251
+ Model inputs
252
+ mask: Theano variable
253
+ Sequence mask
254
+ y: Theano variable
255
+ Targets
256
+ cost: Theano variable
257
+ Objective function to minimize
258
+
259
+ Notes
260
+ -----
261
+ For more information, see [ADADELTA]_.
262
+
263
+ .. [ADADELTA] Matthew D. Zeiler, *ADADELTA: An Adaptive Learning
264
+ Rate Method*, arXiv:1212.5701.
265
+ """
266
+
239
267
zipped_grads = [theano .shared (p .get_value () * numpy_floatX (0. ),
240
268
name = '%s_grad' % k )
241
269
for k , p in tparams .iteritems ()]
@@ -269,6 +297,36 @@ def adadelta(lr, tparams, grads, x, mask, y, cost):
269
297
270
298
271
299
def rmsprop (lr , tparams , grads , x , mask , y , cost ):
300
+ """
301
+ A variant of SGD that scales the step size by running average of the
302
+ recent step norms.
303
+
304
+ Parameters
305
+ ----------
306
+ lr : Theano SharedVariable
307
+ Initial learning rate
308
+ tparams: Theano SharedVariable
309
+ Model parameters
310
+ grads: Theano variable
311
+ Gradients of cost w.r.t. parameters
312
+ x: Theano variable
313
+ Model inputs
314
+ mask: Theano variable
315
+ Sequence mask
316
+ y: Theano variable
317
+ Targets
318
+ cost: Theano variable
319
+ Objective function to minimize
320
+
321
+ Notes
322
+ -----
323
+ For more information, see [Hint2014]_.
324
+
325
+ .. [Hint2014] Geoff Hinton, *Neural Networks for Machine Learning*,
326
+ lecture 6a,
327
+ http://cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
328
+ """
329
+
272
330
zipped_grads = [theano .shared (p .get_value () * numpy_floatX (0. ),
273
331
name = '%s_grad' % k )
274
332
for k , p in tparams .iteritems ()]
0 commit comments