|
7 | 7 | using static Tensorflow.Binding;
|
8 | 8 | using static Tensorflow.KerasApi;
|
9 | 9 | using System.Linq;
|
| 10 | +using System.Text.RegularExpressions; |
| 11 | + |
10 | 12 | namespace Tensorflow.Keras.Saving
|
11 | 13 | {
|
12 | 14 | public class hdf5_format
|
@@ -132,7 +134,9 @@ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
|
132 | 134 | var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
|
133 | 135 | foreach (var i_ in weight_names)
|
134 | 136 | {
|
135 |
| - (success, Array result) = Hdf5.ReadDataset<float>(g, i_); |
| 137 | + var vm = Regex.Replace(i_, "/", "$"); |
| 138 | + vm = i_.Split('/')[0] + "/$" + vm.Substring(i_.Split('/')[0].Length + 1, i_.Length - i_.Split('/')[0].Length - 1); |
| 139 | + (success, Array result) = Hdf5.ReadDataset<float>(g, vm); |
136 | 140 | if (success)
|
137 | 141 | weight_values.Add(np.array(result));
|
138 | 142 | }
|
@@ -193,7 +197,8 @@ public static void save_weights_to_hdf5_group(long f, List<ILayer> layers)
|
193 | 197 | if (name.IndexOf("/") > 1)
|
194 | 198 | {
|
195 | 199 | var crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(name.Split('/')[0]));
|
196 |
| - WriteDataset(crDataGroup, name.Split('/')[1], tensor); |
| 200 | + var _name = Regex.Replace(name.Substring(name.Split('/')[0].Length, name.Length - name.Split('/')[0].Length), "/", "$"); |
| 201 | + WriteDataset(crDataGroup, _name, tensor); |
197 | 202 | Hdf5.CloseGroup(crDataGroup);
|
198 | 203 | }
|
199 | 204 | else
|
|
0 commit comments