Skip to content

Commit 992bf55

Browse files
author
Beacontownfc
committed
fix load_weights
1 parent eac68ff commit 992bf55

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

Diff for: src/TensorFlowNET.Keras/Saving/hdf5_format.cs

+7-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
using static Tensorflow.Binding;
88
using static Tensorflow.KerasApi;
99
using System.Linq;
10+
using System.Text.RegularExpressions;
11+
1012
namespace Tensorflow.Keras.Saving
1113
{
1214
public class hdf5_format
@@ -132,7 +134,9 @@ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
132134
var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
133135
foreach (var i_ in weight_names)
134136
{
135-
(success, Array result) = Hdf5.ReadDataset<float>(g, i_);
137+
var vm = Regex.Replace(i_, "/", "$");
138+
vm = i_.Split('/')[0] + "/$" + vm.Substring(i_.Split('/')[0].Length + 1, i_.Length - i_.Split('/')[0].Length - 1);
139+
(success, Array result) = Hdf5.ReadDataset<float>(g, vm);
136140
if (success)
137141
weight_values.Add(np.array(result));
138142
}
@@ -193,7 +197,8 @@ public static void save_weights_to_hdf5_group(long f, List<ILayer> layers)
193197
if (name.IndexOf("/") > 1)
194198
{
195199
var crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(name.Split('/')[0]));
196-
WriteDataset(crDataGroup, name.Split('/')[1], tensor);
200+
var _name = Regex.Replace(name.Substring(name.Split('/')[0].Length, name.Length - name.Split('/')[0].Length), "/", "$");
201+
WriteDataset(crDataGroup, _name, tensor);
197202
Hdf5.CloseGroup(crDataGroup);
198203
}
199204
else

0 commit comments

Comments
 (0)