Skip to content

Commit 2d028c8

Browse files
committed
fix 'Unsupported' object has no attribute 'info'
1 parent accbac4 commit 2d028c8

File tree

6 files changed

+103
-98
lines changed

6 files changed

+103
-98
lines changed

Diff for: examples/build_test_model.py

+20-14
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,32 @@
44
from tensorflow.keras.optimizers import Adam
55
from utils.BASEDIR import BASEDIR
66

7+
8+
def one_hot(idx):
9+
x = [0 for _ in range(3)]
10+
x[idx] = 1
11+
return x
12+
13+
714
samples = 10000
8-
x_train = (np.random.rand(samples, 1) * 10)
9-
# y_train = x_train.take(axis=1, indices=1) * 2
10-
y_train = ((x_train * 1.5) + 2.5) / 2
15+
x_train = (np.random.rand(samples, 3) * 10)
16+
y_train = np.array([one_hot(np.argmax(sample)) for sample in x_train])
1117

1218
model = Sequential()
13-
model.add(Dense(256, activation='relu', input_shape=x_train.shape[1:]))
19+
model.add(Dense(32, activation='relu', input_shape=x_train.shape[1:]))
20+
model.add(Dropout(0.2))
1421
model.add(BatchNormalization())
15-
model.add(Dense(128, activation='relu'))
16-
# model.add(BatchNormalization())
17-
model.add(Dense(64, activation='relu'))
18-
# model.add(BatchNormalization())
19-
model.add(Dense(1, activation='linear'))
2022

21-
model.compile(optimizer=Adam(lr=0.001, amsgrad=True), loss='mse')
22-
model.fit(x_train, y_train, batch_size=64, epochs=10, verbose=1, validation_split=0.2)
23+
model.add(Dense(16, activation='relu'))
24+
model.add(Dropout(0.2))
25+
model.add(BatchNormalization())
26+
27+
model.add(Dense(3, activation='softmax'))
28+
29+
model.compile(optimizer=Adam(lr=0.003, amsgrad=True), loss='categorical_crossentropy')
30+
model.fit(x_train, y_train, batch_size=64, epochs=20, verbose=1, validation_split=0.2)
2331

2432
model.save('{}/examples/batch_norm.h5'.format(BASEDIR))
25-
print(model.predict([[4.5]]))
33+
print(model.predict([[4.5, 4.5, 9]]).tolist())
2634
print('Saved!')
27-
print(model.layers[0].get_weights()[0].shape)
28-
print(model.layers[1].get_weights()[0].shape)
2935
# exit()

Diff for: examples/load.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from utils.BASEDIR import BASEDIR
44

55
model = load_model('{}/examples/batch_norm.h5'.format(BASEDIR))
6-
print(model.predict([[0.5]]))
6+
print(model.predict([[[4.5, 4.5]]]).tolist())
77

88

99
# exit()

Diff for: konverter/utils/konverter_support.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -59,26 +59,26 @@ def attr_map(self, classes, attr):
5959
def get_model_info(self, model):
6060
name = getattr(model, '_keras_api_names_v1')[0]
6161
model_class = self.get_class_from_name(name, 'models')
62+
model_class.info = BaseModelInfo()
6263
if not model_class:
6364
model_class = Models.Unsupported()
6465
model_class.name = name
65-
return model_class
66+
else:
67+
model_class.info.supported = True
68+
model_class.info.input_shape = model.input_shape
6669

67-
model_class.info = BaseModelInfo()
68-
model_class.info.input_shape = model.input_shape
69-
model_class.info.supported = True
7070
return model_class
7171

7272
def get_layer_info(self, layer):
7373
name = getattr(layer, '_keras_api_names_v1')
7474
if not len(name):
7575
name = getattr(layer, '_keras_api_names')
7676
layer_class = self.get_class_from_name(name[0], 'layers') # assume only one name
77+
layer_class.info = BaseLayerInfo()
7778
if not layer_class:
7879
layer_class = Layers.Unsupported() # add activation below to raise exception with
7980
layer_class.name = name
8081

81-
layer_class.info = BaseLayerInfo()
8282
layer_class.info.is_ignored = layer_class.name in self.ignored_layers
8383

8484
is_linear = False

Diff for: misc/old/konverter/run.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55

66
os.chdir(BASEDIR)
77

8-
model = load_model('examples/batch_norm.h5')
8+
model = load_model('examples/latest_maybe_good.h5')
99
kon = Konverter()
10-
kon.konvert(model, 'examples/batch_norm.py', 2, verbose=True)
10+
kon.konvert(model, 'examples/latest_maybe_good.py', 2, verbose=True)

0 commit comments

Comments
 (0)