@@ -116,13 +116,9 @@ def __init__(self, prev_layer, act=None, name=None, *args, **kwargs):
 
         self.inputs = None
         self.outputs = None
-        self.graph = {}
         self.all_layers = list()
         self.all_params = list()
         self.all_drop = dict()
-        self.all_graphs = list()
-
-        self.layer_args = self._get_init_args(skip=4)
 
         if name is None:
             raise ValueError('Layer must have a name.')
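Note: the deleted `layer_args` line above relied on a `_get_init_args(skip=4)` helper that is not shown in this diff; it presumably inspects the call stack to capture the keyword arguments passed to the subclass constructor so they can later be serialized into the graph dict. Below is a minimal sketch of how such a helper can work, assuming a stack-inspection approach; the function name, the filtered argument names, and the `skip` depth are illustrative, not the library's actual implementation.

```python
import inspect

def get_init_args(skip=4):
    # Hypothetical stand-in for the removed _get_init_args helper: grab
    # the frame of the subclass __init__ sitting `skip` levels up the
    # stack and return its named arguments as a plain dict.
    frame = inspect.stack()[skip].frame
    args, _, _, values = inspect.getargvalues(frame)
    # Keep only serializable constructor configuration.
    return {k: values[k] for k in args if k not in ('self', 'prev_layer', 'act', 'name')}
```

In a standalone test, the right `skip` depends on how many wrapper frames sit between this helper and the constructor being recorded.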
@@ -146,7 +142,6 @@ def __init__(self, prev_layer, act=None, name=None, *args, **kwargs):
             self._add_layers(prev_layer.all_layers)
             self._add_params(prev_layer.all_params)
             self._add_dropout_layers(prev_layer.all_drop)
-            self._add_graphs(prev_layer.all_graphs)
 
         elif isinstance(prev_layer, list):
             # 2. for layers with multiple inputs, e.g. ConcatLayer
@@ -156,7 +151,6 @@ def __init__(self, prev_layer, act=None, name=None, *args, **kwargs):
             self._add_layers(sum([l.all_layers for l in prev_layer], []))
             self._add_params(sum([l.all_params for l in prev_layer], []))
             self._add_dropout_layers(sum([list(l.all_drop.items()) for l in prev_layer], []))
-            self._add_graphs(sum([l.all_graphs for l in prev_layer], []))
 
         elif isinstance(prev_layer, tf.Tensor) or isinstance(prev_layer, tf.Variable):  # placeholders
             if self.__class__.__name__ not in ['InputLayer', 'OneHotInputLayer', 'Word2vecEmbeddingInputlayer',
@@ -165,49 +159,15 @@ def __init__(self, prev_layer, act=None, name=None, *args, **kwargs):
 
             self.inputs = prev_layer
 
-            self._add_graphs(
-                (
-                    self.inputs.name,  # .split(':')[0],
-                    {
-                        'shape': self.inputs.get_shape().as_list(),
-                        'dtype': self.inputs.dtype.name,
-                        'class': 'placeholder',
-                        'prev_layer': None
-                    }
-                )
-            )
-
         elif prev_layer is not None:
             # 4. tl.models
             self._add_layers(prev_layer.all_layers)
             self._add_params(prev_layer.all_params)
             self._add_dropout_layers(prev_layer.all_drop)
-            self._add_graphs(prev_layer.all_graphs)
 
             if hasattr(prev_layer, "outputs"):
                 self.inputs = prev_layer.outputs
 
-        # TL Graph
-        if isinstance(prev_layer, list):  # e.g. ConcatLayer, ElementwiseLayer have multiple previous layers
-            _list = []
-            for layer in prev_layer:
-                _list.append(layer.name)
-            self.graph.update({'class': self.__class__.__name__.split('.')[-1], 'prev_layer': _list})
-        elif prev_layer is None:
-            self.graph.update({'class': self.__class__.__name__.split('.')[-1], 'prev_layer': None})
-        else:  # normal layers, e.g. Conv2d
-            self.graph.update({'class': self.__class__.__name__.split('.')[-1], 'prev_layer': prev_layer.name})
-        # if act:  # convert activation from function to string
-        #     try:
-        #         act = act.__name__
-        #     except:
-        #         pass
-        #     self.graph.update({'act': act})
-        # print(self.layer_args)
-        self.graph.update(self.layer_args)
-        # print(self.graph)
-        self._add_graphs((self.name, self.graph))
-
     def print_params(self, details=True, session=None):
         """Print all info of parameters in the network"""
         for i, p in enumerate(self.all_params):
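For reference, the "TL Graph" block deleted in the hunk above produced one entry per layer: a dict recording the layer's class, the name(s) of its predecessor(s), and the constructor arguments captured in `layer_args`, pushed onto `all_graphs` as a `(name, graph)` tuple. A sketch of what the accumulated structure looked like; the layer names, shapes, and argument values here are hypothetical:

```python
# Hypothetical all_graphs contents for a placeholder followed by one
# Conv2d layer, mirroring the entry shapes the removed code built.
all_graphs = [
    ('x:0', {'shape': [None, 28, 28, 1], 'dtype': 'float32',
             'class': 'placeholder', 'prev_layer': None}),
    ('conv1', {'class': 'Conv2d',        # self.__class__.__name__.split('.')[-1]
               'prev_layer': 'input',    # a name, a list of names, or None
               'n_filter': 32, 'filter_size': (3, 3)}),  # from layer_args
]
```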
@@ -278,7 +238,6 @@ def __getitem__(self, key):
         net_new._add_layers(net_new.outputs)
 
         net_new._add_params(self.all_params)
-        net_new._add_graphs(self.all_graphs)
         net_new._add_dropout_layers(self.all_drop)
 
         return net_new
@@ -353,17 +312,6 @@ def _add_params(self, params):
 
         self.all_params = list_remove_repeat(self.all_params)
 
-    @protected_method
-    def _add_graphs(self, graphs):
-
-        if isinstance(graphs, list):
-            self.all_graphs.extend(list(graphs))
-
-        else:
-            self.all_graphs.append(graphs)
-
-        # self.all_graphs = list_remove_repeat(self.all_graphs)  # cannot repeat
-
     @protected_method
     def _add_dropout_layers(self, drop_layers):
         if isinstance(drop_layers, dict) or isinstance(drop_layers, list):
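The `_add_graphs` helper deleted above merged entries without de-duplication; unlike `_add_params` and `_add_layers`, it never called `list_remove_repeat` (the removed comment suggests de-duplication was intentionally disabled). A standalone sketch of the same accumulation behavior, with a hypothetical free function in place of the method:

```python
def add_graphs(all_graphs, graphs):
    # A list of (name, graph_dict) entries is merged element-wise;
    # a single entry is appended. Duplicates are deliberately kept.
    if isinstance(graphs, list):
        all_graphs.extend(list(graphs))
    else:
        all_graphs.append(graphs)

net_graphs = []
add_graphs(net_graphs, ('x:0', {'class': 'placeholder', 'prev_layer': None}))
add_graphs(net_graphs, [('conv1', {'class': 'Conv2d', 'prev_layer': 'x:0'})])
```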