@@ -55,32 +55,40 @@ def initialise_network(xception_weights: str, weights: str, species: str) -> Seq
55
55
model .add (Dense (9 , activation = "linear" ))
56
56
57
57
if weights != None :
58
- model = load_xception_weights (model , weights )
58
+ model = load_xception_weights (model , weights , species )
59
59
return model
60
60
61
61
62
- def load_xception_weights (model , weights ):
62
+ def load_xception_weights (model , weights , species = "mouse" ):
63
63
with h5py .File (weights , "r" ) as new :
64
64
# set weight of each layer manually
65
- model .layers [1 ].set_weights ([new ["dense" ]["dense" ]["kernel:0" ], new ["dense" ]["dense" ]["bias:0" ]])
66
- model .layers [2 ].set_weights ([new ["dense_1" ]["dense_1" ]["kernel:0" ], new ["dense_1" ]["dense_1" ]["bias:0" ]])
67
- model .layers [3 ].set_weights ([new ["dense_2" ]["dense_2" ]["kernel:0" ], new ["dense_2" ]["dense_2" ]["bias:0" ]])
65
+ if species == "mouse" :
66
+ xception_idx = 0
67
+ dense_idx = 1
68
+ elif species == "rat" :
69
+ # RatModelInProgress.h5 has an "input_2" layer at index 0, so we need to adjust the indices<
70
+ xception_idx = 1
71
+ dense_idx = 2
72
+
73
+ model .layers [dense_idx ].set_weights ([new ["dense" ]["dense" ]["kernel:0" ], new ["dense" ]["dense" ]["bias:0" ]])
74
+ model .layers [dense_idx + 1 ].set_weights ([new ["dense_1" ]["dense_1" ]["kernel:0" ], new ["dense_1" ]["dense_1" ]["bias:0" ]])
75
+ model .layers [dense_idx + 2 ].set_weights ([new ["dense_2" ]["dense_2" ]["kernel:0" ], new ["dense_2" ]["dense_2" ]["bias:0" ]])
68
76
69
77
# Set the weights of the xception model
70
78
weight_names = new ["xception" ].attrs ["weight_names" ].tolist ()
71
79
weight_names_layers = [name .decode ("utf-8" ).split ("/" )[0 ] for name in weight_names ]
72
80
73
- for i in range (len (model .layers [0 ].layers )):
74
- name_of_layer = model .layers [0 ].layers [i ].name
81
+ for i in range (len (model .layers [xception_idx ].layers )):
82
+ name_of_layer = model .layers [xception_idx ].layers [i ].name
75
83
# if layer name is in the weight names, then we will set weights
76
84
if name_of_layer in weight_names_layers :
77
85
# Get name of weights in the layer
78
86
layer_weight_names = []
79
- for weight in model .layers [0 ].layers [i ].weights :
87
+ for weight in model .layers [xception_idx ].layers [i ].weights :
80
88
layer_weight_names .append (weight .name .split ("/" )[1 ])
81
89
h5_group = new ["xception" ][name_of_layer ]
82
90
weights_list = [np .array (h5_group [kk ]) for kk in layer_weight_names ]
83
- model .layers [0 ].layers [i ].set_weights (weights_list )
91
+ model .layers [xception_idx ].layers [i ].set_weights (weights_list )
84
92
return model
85
93
86
94
def load_images_from_path (image_path : str ) -> np .ndarray :
@@ -163,6 +171,7 @@ def predictions_util(
163
171
primary_weights : str ,
164
172
secondary_weights : str ,
165
173
ensemble : bool = False ,
174
+ species : str = "mouse"
166
175
):
167
176
"""
168
177
Predict the image alignments
@@ -174,7 +183,7 @@ def predictions_util(
174
183
:return: The predicted alignments
175
184
:rtype: list
176
185
"""
177
- model = load_xception_weights (model , primary_weights )
186
+ model = load_xception_weights (model , primary_weights , species )
178
187
predictions = model .predict (
179
188
image_generator ,
180
189
steps = image_generator .n // image_generator .batch_size ,
@@ -183,14 +192,14 @@ def predictions_util(
183
192
predictions = predictions .astype (np .float64 )
184
193
if ensemble :
185
194
image_generator .reset ()
186
- model = load_xception_weights (model , secondary_weights )
195
+ model = load_xception_weights (model , secondary_weights , species )
187
196
secondary_predictions = model .predict (
188
197
image_generator ,
189
198
steps = image_generator .n // image_generator .batch_size ,
190
199
verbose = 1 ,
191
200
)
192
201
predictions = np .mean ([predictions , secondary_predictions ], axis = 0 )
193
- model = load_xception_weights (model , primary_weights )
202
+ model = load_xception_weights (model , primary_weights , species )
194
203
filenames = image_generator .filenames
195
204
filenames = [os .path .basename (i ) for i in filenames ]
196
205
predictions_df = pd .DataFrame (
0 commit comments