12
12
import warnings
13
13
import imghdr
14
14
import struct
15
+ import h5py
15
16
16
17
17
18
def gray_scale (img : np .ndarray ) -> np .ndarray :
@@ -38,8 +39,7 @@ def initialise_network(xception_weights: str, weights: str, species: str) -> Seq
38
39
:rtype: keras.models.Sequential
39
40
"""
40
41
base_model = Xception (include_top = True , weights = xception_weights )
41
- base_model ._layers .pop ()
42
- base_model ._layers .pop ()
42
+
43
43
if species == "rat" :
44
44
inputs = Input (shape = (299 , 299 , 3 ))
45
45
base_model_layer = base_model (inputs , training = True )
@@ -55,11 +55,41 @@ def initialise_network(xception_weights: str, weights: str, species: str) -> Seq
55
55
model .add (Dense (9 , activation = "linear" ))
56
56
57
57
if weights != None :
58
- model . load_weights ( weights )
58
+ model = load_xception_weights ( model , weights , species )
59
59
return model
60
60
61
61
62
+ def load_xception_weights (model , weights , species = "mouse" ):
63
+ with h5py .File (weights , "r" ) as new :
64
+ # set weight of each layer manually
65
+ if species == "mouse" :
66
+ xception_idx = 0
67
+ dense_idx = 1
68
+ elif species == "rat" :
69
+ # RatModelInProgress.h5 has an "input_2" layer at index 0, so we need to adjust the indices<
70
+ xception_idx = 1
71
+ dense_idx = 2
72
+
73
+ model .layers [dense_idx ].set_weights ([new ["dense" ]["dense" ]["kernel:0" ], new ["dense" ]["dense" ]["bias:0" ]])
74
+ model .layers [dense_idx + 1 ].set_weights ([new ["dense_1" ]["dense_1" ]["kernel:0" ], new ["dense_1" ]["dense_1" ]["bias:0" ]])
75
+ model .layers [dense_idx + 2 ].set_weights ([new ["dense_2" ]["dense_2" ]["kernel:0" ], new ["dense_2" ]["dense_2" ]["bias:0" ]])
76
+
77
+ # Set the weights of the xception model
78
+ weight_names = new ["xception" ].attrs ["weight_names" ].tolist ()
79
+ weight_names_layers = [name .decode ("utf-8" ).split ("/" )[0 ] for name in weight_names ]
62
80
81
+ for i in range (len (model .layers [xception_idx ].layers )):
82
+ name_of_layer = model .layers [xception_idx ].layers [i ].name
83
+ # if layer name is in the weight names, then we will set weights
84
+ if name_of_layer in weight_names_layers :
85
+ # Get name of weights in the layer
86
+ layer_weight_names = []
87
+ for weight in model .layers [xception_idx ].layers [i ].weights :
88
+ layer_weight_names .append (weight .name .split ("/" )[1 ])
89
+ h5_group = new ["xception" ][name_of_layer ]
90
+ weights_list = [np .array (h5_group [kk ]) for kk in layer_weight_names ]
91
+ model .layers [xception_idx ].layers [i ].set_weights (weights_list )
92
+ return model
63
93
64
94
def load_images_from_path (image_path : str ) -> np .ndarray :
65
95
"""
@@ -141,6 +171,7 @@ def predictions_util(
141
171
primary_weights : str ,
142
172
secondary_weights : str ,
143
173
ensemble : bool = False ,
174
+ species : str = "mouse"
144
175
):
145
176
"""
146
177
Predict the image alignments
@@ -152,7 +183,7 @@ def predictions_util(
152
183
:return: The predicted alignments
153
184
:rtype: list
154
185
"""
155
- model . load_weights ( primary_weights )
186
+ model = load_xception_weights ( model , primary_weights , species )
156
187
predictions = model .predict (
157
188
image_generator ,
158
189
steps = image_generator .n // image_generator .batch_size ,
@@ -161,14 +192,14 @@ def predictions_util(
161
192
predictions = predictions .astype (np .float64 )
162
193
if ensemble :
163
194
image_generator .reset ()
164
- model . load_weights ( secondary_weights )
195
+ model = load_xception_weights ( model , secondary_weights , species )
165
196
secondary_predictions = model .predict (
166
197
image_generator ,
167
198
steps = image_generator .n // image_generator .batch_size ,
168
199
verbose = 1 ,
169
200
)
170
201
predictions = np .mean ([predictions , secondary_predictions ], axis = 0 )
171
- model . load_weights ( primary_weights )
202
+ model = load_xception_weights ( model , primary_weights , species )
172
203
filenames = image_generator .filenames
173
204
filenames = [os .path .basename (i ) for i in filenames ]
174
205
predictions_df = pd .DataFrame (
0 commit comments