Commit b1cd59d

Merge pull request #1038 from tensorlayer/Jingqing-patch

minor update lambda layer doc and test

2 parents: 385e193 + 6e20844

File tree

2 files changed: +38 lines, -12 lines

tensorlayer/layers/lambda_layers.py

Lines changed: 12 additions & 12 deletions
@@ -35,14 +35,14 @@ class Lambda(Layer):
 
     Examples
     ---------
-    Non-parametric and non-args case
+    Non-parametric and non-args case:
     This case is supported in the Model.save() / Model.load() to save / load the whole model architecture and weights(optional).
 
     >>> x = tl.layers.Input([8, 3], name='input')
     >>> y = tl.layers.Lambda(lambda x: 2*x, name='lambda')(x)
 
 
-    Non-parametric and with args case
+    Non-parametric and with args case:
     This case is supported in the Model.save() / Model.load() to save / load the whole model architecture and weights(optional).
 
     >>> def customize_func(x, foo=42): # x is the inputs, foo is an argument
@@ -51,19 +51,19 @@ class Lambda(Layer):
     >>> lambdalayer = tl.layers.Lambda(customize_func, fn_args={'foo': 2}, name='lambda')(x)
 
 
-    Any function with outside variables
+    Any function with outside variables:
     This case has not been supported in Model.save() / Model.load() yet.
     Please avoid using Model.save() / Model.load() to save / load models that contain such Lambda layer. Instead, you may use Model.save_weights() / Model.load_weights() to save / load model weights.
     Note: In this case, fn_weights should be a list, and then the trainable weights in this Lambda layer can be added into the weights of the whole model.
 
-    >>> vara = [tf.Variable(1.0)]
+    >>> a = tf.Variable(1.0)
     >>> def func(x):
-    >>>     return x + vara
+    >>>     return x + a
     >>> x = tl.layers.Input([8, 3], name='input')
-    >>> y = tl.layers.Lambda(func, fn_weights=a, name='lambda')(x)
+    >>> y = tl.layers.Lambda(func, fn_weights=[a], name='lambda')(x)
 
 
-    Parametric case, merge other wrappers into TensorLayer
+    Parametric case, merge other wrappers into TensorLayer:
     This case is supported in the Model.save() / Model.load() to save / load the whole model architecture and weights(optional).
 
     >>> layers = [
@@ -74,27 +74,27 @@ class Lambda(Layer):
     >>> perceptron = tf.keras.Sequential(layers)
     >>> # in order to compile keras model and get trainable_variables of the keras model
     >>> _ = perceptron(np.random.random([100, 5]).astype(np.float32))
-
+    >>>
     >>> class CustomizeModel(tl.models.Model):
     >>>     def __init__(self):
     >>>         super(CustomizeModel, self).__init__()
     >>>         self.dense = tl.layers.Dense(in_channels=1, n_units=5)
     >>>         self.lambdalayer = tl.layers.Lambda(perceptron, perceptron.trainable_variables)
-
+    >>>
    >>>     def forward(self, x):
    >>>         z = self.dense(x)
    >>>         z = self.lambdalayer(z)
    >>>         return z
-
+    >>>
     >>> optimizer = tf.optimizers.Adam(learning_rate=0.1)
     >>> model = CustomizeModel()
     >>> model.train()
-
+    >>>
     >>> for epoch in range(50):
     >>>     with tf.GradientTape() as tape:
     >>>         pred_y = model(data_x)
     >>>         loss = tl.cost.mean_squared_error(pred_y, data_y)
-
+    >>>
     >>>     gradients = tape.gradient(loss, model.trainable_weights)
     >>>     optimizer.apply_gradients(zip(gradients, model.trainable_weights))
 
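For readers skimming the diff above: the substance of the docstring fix is that fn_weights expects a list of variables, not a bare tf.Variable. Below is a minimal runnable sketch of the corrected usage, assembled from the examples in the hunk (an illustration only; it assumes TensorLayer 2.x with eager TensorFlow 2):

    import tensorflow as tf
    import tensorlayer as tl

    # A variable captured from outside the function body.
    a = tf.Variable(1.0)

    def func(x):
        return x + a

    x = tl.layers.Input([8, 3], name='input')
    # fn_weights must be a list, so that `a` is registered as a trainable
    # weight of the Lambda layer (and of any model that contains it).
    y = tl.layers.Lambda(func, fn_weights=[a], name='lambda')(x)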

tests/layers/test_layers_lambda.py

Lines changed: 26 additions & 0 deletions
@@ -101,6 +101,32 @@ def forward(self, x, bar):
         out, out2 = model(self.data_x, bar=2)
         self.assertTrue(np.array_equal(out2.numpy(), out.numpy()))
 
+    def test_lambda_func_with_weight(self):
+
+        a = tf.Variable(1.0)
+
+        def customize_fn(x):
+            return x + a
+
+        class CustomizeModel(tl.models.Model):
+
+            def __init__(self):
+                super(CustomizeModel, self).__init__()
+                self.dense = tl.layers.Dense(in_channels=1, n_units=5)
+                self.lambdalayer = tl.layers.Lambda(customize_fn, fn_weights=[a])
+
+            def forward(self, x):
+                z = self.dense(x)
+                z = self.lambdalayer(z)
+                return z
+
+        model = CustomizeModel()
+        print(model.lambdalayer)
+        model.train()
+
+        out = model(self.data_x)
+        print(out.shape)
+
     def test_lambda_func_without_args(self):
 
         class CustomizeModel(tl.models.Model):
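As a companion to the new test, a quick way to confirm that the captured variable really joins the model's trainable weights and receives gradients is sketched below. It reuses the pattern from the docstring's training loop; the input shape [8, 1] is an illustrative assumption matching Dense(in_channels=1):

    import numpy as np
    import tensorflow as tf
    import tensorlayer as tl

    a = tf.Variable(1.0)

    def customize_fn(x):
        return x + a

    class CustomizeModel(tl.models.Model):

        def __init__(self):
            super(CustomizeModel, self).__init__()
            self.dense = tl.layers.Dense(in_channels=1, n_units=5)
            self.lambdalayer = tl.layers.Lambda(customize_fn, fn_weights=[a])

        def forward(self, x):
            return self.lambdalayer(self.dense(x))

    model = CustomizeModel()
    model.train()

    data_x = np.random.random([8, 1]).astype(np.float32)
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(model(data_x))

    # Because fn_weights=[a] is a list, `a` appears in model.trainable_weights
    # and gets a gradient (d(loss)/d(a) == 1.0 for this additive function).
    grads = tape.gradient(loss, model.trainable_weights)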
