Skip to content

Commit 931f3aa

Browse files
authored
Update neural_network.py to account for rat bug
1 parent 014e193 commit 931f3aa

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

DeepSlice/neural_network/neural_network.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,32 +55,40 @@ def initialise_network(xception_weights: str, weights: str, species: str) -> Seq
5555
model.add(Dense(9, activation="linear"))
5656

5757
if weights != None:
58-
model = load_xception_weights(model, weights)
58+
model = load_xception_weights(model, weights, species)
5959
return model
6060

6161

62-
def load_xception_weights(model, weights):
62+
def load_xception_weights(model, weights, species = "mouse"):
6363
with h5py.File(weights, "r") as new:
6464
# set weight of each layer manually
65-
model.layers[1].set_weights([new["dense"]["dense"]["kernel:0"], new["dense"]["dense"]["bias:0"]])
66-
model.layers[2].set_weights([new["dense_1"]["dense_1"]["kernel:0"], new["dense_1"]["dense_1"]["bias:0"]])
67-
model.layers[3].set_weights([new["dense_2"]["dense_2"]["kernel:0"], new["dense_2"]["dense_2"]["bias:0"]])
65+
if species == "mouse":
66+
xception_idx = 0
67+
dense_idx = 1
68+
elif species == "rat":
69+
# RatModelInProgress.h5 has an "input_2" layer at index 0, so we need to adjust the indices<
70+
xception_idx = 1
71+
dense_idx = 2
72+
73+
model.layers[dense_idx].set_weights([new["dense"]["dense"]["kernel:0"], new["dense"]["dense"]["bias:0"]])
74+
model.layers[dense_idx+1].set_weights([new["dense_1"]["dense_1"]["kernel:0"], new["dense_1"]["dense_1"]["bias:0"]])
75+
model.layers[dense_idx+2].set_weights([new["dense_2"]["dense_2"]["kernel:0"], new["dense_2"]["dense_2"]["bias:0"]])
6876

6977
# Set the weights of the xception model
7078
weight_names = new["xception"].attrs["weight_names"].tolist()
7179
weight_names_layers = [name.decode("utf-8").split("/")[0] for name in weight_names]
7280

73-
for i in range(len(model.layers[0].layers)):
74-
name_of_layer = model.layers[0].layers[i].name
81+
for i in range(len(model.layers[xception_idx].layers)):
82+
name_of_layer = model.layers[xception_idx].layers[i].name
7583
# if layer name is in the weight names, then we will set weights
7684
if name_of_layer in weight_names_layers:
7785
# Get name of weights in the layer
7886
layer_weight_names = []
79-
for weight in model.layers[0].layers[i].weights:
87+
for weight in model.layers[xception_idx].layers[i].weights:
8088
layer_weight_names.append(weight.name.split("/")[1])
8189
h5_group = new["xception"][name_of_layer]
8290
weights_list = [np.array(h5_group[kk]) for kk in layer_weight_names]
83-
model.layers[0].layers[i].set_weights(weights_list)
91+
model.layers[xception_idx].layers[i].set_weights(weights_list)
8492
return model
8593

8694
def load_images_from_path(image_path: str) -> np.ndarray:
@@ -163,6 +171,7 @@ def predictions_util(
163171
primary_weights: str,
164172
secondary_weights: str,
165173
ensemble: bool = False,
174+
species : str = "mouse"
166175
):
167176
"""
168177
Predict the image alignments
@@ -174,7 +183,7 @@ def predictions_util(
174183
:return: The predicted alignments
175184
:rtype: list
176185
"""
177-
model = load_xception_weights(model, primary_weights)
186+
model = load_xception_weights(model, primary_weights, species)
178187
predictions = model.predict(
179188
image_generator,
180189
steps=image_generator.n // image_generator.batch_size,
@@ -183,14 +192,14 @@ def predictions_util(
183192
predictions = predictions.astype(np.float64)
184193
if ensemble:
185194
image_generator.reset()
186-
model = load_xception_weights(model, secondary_weights)
195+
model = load_xception_weights(model, secondary_weights, species)
187196
secondary_predictions = model.predict(
188197
image_generator,
189198
steps=image_generator.n // image_generator.batch_size,
190199
verbose=1,
191200
)
192201
predictions = np.mean([predictions, secondary_predictions], axis=0)
193-
model = load_xception_weights(model, primary_weights)
202+
model = load_xception_weights(model, primary_weights, species)
194203
filenames = image_generator.filenames
195204
filenames = [os.path.basename(i) for i in filenames]
196205
predictions_df = pd.DataFrame(

0 commit comments

Comments
 (0)