Skip to content

Commit 834f6b2

Browse files
authored
Merge pull request #52 from wjguan/main
Update neural_network.py to be able to load old model weights. This enables use of DeepSlice with TF2 and newer python versions
2 parents 9a947ed + aee5f00 commit 834f6b2

File tree

3 files changed

+42
-11
lines changed

3 files changed

+42
-11
lines changed

Diff for: DeepSlice/main.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,11 @@ def predict(
7777
if use_secondary_model:
7878
print("Using secondary model")
7979
predictions = neural_network.predictions_util(
80-
self.model, image_generator, secondary_weights,None, ensemble
80+
self.model, image_generator, secondary_weights,None, ensemble, self.species
8181
)
8282
else:
8383
predictions = neural_network.predictions_util(
84-
self.model, image_generator, primary_weights, secondary_weights, ensemble
84+
self.model, image_generator, primary_weights, secondary_weights, ensemble, self.species
8585
)
8686
predictions["width"] = width
8787
predictions["height"] = height

Diff for: DeepSlice/neural_network/neural_network.py

+37-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import warnings
1313
import imghdr
1414
import struct
15+
import h5py
1516

1617

1718
def gray_scale(img: np.ndarray) -> np.ndarray:
@@ -38,8 +39,7 @@ def initialise_network(xception_weights: str, weights: str, species: str) -> Seq
3839
:rtype: keras.models.Sequential
3940
"""
4041
base_model = Xception(include_top=True, weights=xception_weights)
41-
base_model._layers.pop()
42-
base_model._layers.pop()
42+
4343
if species == "rat":
4444
inputs = Input(shape=(299, 299, 3))
4545
base_model_layer = base_model(inputs, training=True)
@@ -55,11 +55,41 @@ def initialise_network(xception_weights: str, weights: str, species: str) -> Seq
5555
model.add(Dense(9, activation="linear"))
5656

5757
if weights != None:
58-
model.load_weights(weights)
58+
model = load_xception_weights(model, weights, species)
5959
return model
6060

6161

62+
def load_xception_weights(model, weights, species = "mouse"):
63+
with h5py.File(weights, "r") as new:
64+
# set weight of each layer manually
65+
if species == "mouse":
66+
xception_idx = 0
67+
dense_idx = 1
68+
elif species == "rat":
69+
# RatModelInProgress.h5 has an "input_2" layer at index 0, so we need to adjust the indices<
70+
xception_idx = 1
71+
dense_idx = 2
72+
73+
model.layers[dense_idx].set_weights([new["dense"]["dense"]["kernel:0"], new["dense"]["dense"]["bias:0"]])
74+
model.layers[dense_idx+1].set_weights([new["dense_1"]["dense_1"]["kernel:0"], new["dense_1"]["dense_1"]["bias:0"]])
75+
model.layers[dense_idx+2].set_weights([new["dense_2"]["dense_2"]["kernel:0"], new["dense_2"]["dense_2"]["bias:0"]])
76+
77+
# Set the weights of the xception model
78+
weight_names = new["xception"].attrs["weight_names"].tolist()
79+
weight_names_layers = [name.decode("utf-8").split("/")[0] for name in weight_names]
6280

81+
for i in range(len(model.layers[xception_idx].layers)):
82+
name_of_layer = model.layers[xception_idx].layers[i].name
83+
# if layer name is in the weight names, then we will set weights
84+
if name_of_layer in weight_names_layers:
85+
# Get name of weights in the layer
86+
layer_weight_names = []
87+
for weight in model.layers[xception_idx].layers[i].weights:
88+
layer_weight_names.append(weight.name.split("/")[1])
89+
h5_group = new["xception"][name_of_layer]
90+
weights_list = [np.array(h5_group[kk]) for kk in layer_weight_names]
91+
model.layers[xception_idx].layers[i].set_weights(weights_list)
92+
return model
6393

6494
def load_images_from_path(image_path: str) -> np.ndarray:
6595
"""
@@ -141,6 +171,7 @@ def predictions_util(
141171
primary_weights: str,
142172
secondary_weights: str,
143173
ensemble: bool = False,
174+
species : str = "mouse"
144175
):
145176
"""
146177
Predict the image alignments
@@ -152,7 +183,7 @@ def predictions_util(
152183
:return: The predicted alignments
153184
:rtype: list
154185
"""
155-
model.load_weights(primary_weights)
186+
model = load_xception_weights(model, primary_weights, species)
156187
predictions = model.predict(
157188
image_generator,
158189
steps=image_generator.n // image_generator.batch_size,
@@ -161,14 +192,14 @@ def predictions_util(
161192
predictions = predictions.astype(np.float64)
162193
if ensemble:
163194
image_generator.reset()
164-
model.load_weights(secondary_weights)
195+
model = load_xception_weights(model, secondary_weights, species)
165196
secondary_predictions = model.predict(
166197
image_generator,
167198
steps=image_generator.n // image_generator.batch_size,
168199
verbose=1,
169200
)
170201
predictions = np.mean([predictions, secondary_predictions], axis=0)
171-
model.load_weights(primary_weights)
202+
model = load_xception_weights(model, primary_weights, species)
172203
filenames = image_generator.filenames
173204
filenames = [os.path.basename(i) for i in filenames]
174205
predictions_df = pd.DataFrame(

Diff for: setup.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
'numpy',
2727
'scikit-learn',
2828
'scikit-image',
29-
'tensorflow==1.15.0',
30-
'h5py==2.10.0',
29+
'tensorflow==2.13.1',
30+
'h5py',
3131
'typing',
32-
'pandas==1.3.5',
32+
'pandas',
3333
'requests',
3434
'protobuf==3.20',
3535
'lxml',

0 commit comments

Comments
 (0)