|
43 | 43 | #include "../verification_exception.h"
|
44 | 44 | #include "error.h"
|
45 | 45 | #include "logger.h"
|
| 46 | +#include <numeric> |
| 47 | +#include <string> |
46 | 48 |
|
47 | 49 | namespace gt_verification {
|
48 | 50 |
|
@@ -86,38 +88,32 @@ namespace gt_verification {
|
86 | 88 | // Get info of serialized field
|
87 | 89 | const ser::field_meta_info &info = serializer_->get_field_meta_info(name);
|
88 | 90 |
|
89 |
| - int iSizeHalo = field.i_size(); |
90 |
| - int jSizeHalo = field.j_size(); |
91 |
| - int kSize = field.k_size(); |
| 91 | + std::vector< int > field_sizes{field.i_size(), field.j_size(), field.k_size()}; |
| 92 | + |
| 93 | + auto mask = mask_for_killed_dimensions(info.dims(), field_sizes); |
| 94 | + field_sizes = apply_mask(mask, field_sizes); |
92 | 95 |
|
93 |
| - VERIFICATION_LOG() << boost::format(" - loading %-15s (%2i, %2i, %2i)") % name % iSizeHalo % jSizeHalo % |
94 |
| - kSize |
| 96 | + VERIFICATION_LOG() << boost::format(" - loading %-15s (%2i, %2i, %2i)") % name % field_sizes[0] % |
| 97 | + field_sizes[1] % field_sizes[2] |
95 | 98 | << logger_action::endl;
|
96 | 99 |
|
97 | 100 | // Check dimensions
|
98 |
| - if ((info.dims()[0] != iSizeHalo) || (info.dims()[1] != jSizeHalo) || (info.dims()[2] != kSize)) |
| 101 | + if (!sizes_compatible(info.dims(), field_sizes)) |
99 | 102 | throw verification_exception("the requested field '%s' has a different size than the provided field.\n"
|
100 |
| - "Registered as: (%i, %i, %i)\n" |
101 |
| - "Given as: (%i, %i, %i)", |
| 103 | + "Registered as: (%s)\n" |
| 104 | + "Given as: (%s)", |
102 | 105 | name,
|
103 |
| - info.dims()[0], |
104 |
| - info.dims()[1], |
105 |
| - info.dims()[2], |
106 |
| - iSizeHalo, |
107 |
| - jSizeHalo, |
108 |
| - kSize); |
| 106 | + to_string(info.dims()), |
| 107 | + to_string(field_sizes)); |
109 | 108 |
|
110 | 109 | // Check types
|
111 | 110 | if (info.type() != serialbox::ToTypeID< T >::value)
|
112 | 111 | throw verification_exception(
|
113 | 112 | "the requested field '%s' has a different type than the provided field.", name);
|
114 | 113 |
|
115 | 114 | // Deserialize field
|
116 |
| - int iStride = field.i_stride(); |
117 |
| - int jStride = field.j_stride(); |
118 |
| - int kStride = field.k_stride(); |
119 |
| - |
120 |
| - std::vector< int > strides{iStride, jStride, kStride}; |
| 115 | + std::vector< int > strides{field.i_stride(), field.j_stride(), field.k_stride()}; |
| 116 | + strides = apply_mask(mask, strides); |
121 | 117 | serializer_->read(name, savepoint, field.data(), strides);
|
122 | 118 |
|
123 | 119 | field.sync();
|
@@ -189,5 +185,62 @@ namespace gt_verification {
|
189 | 185 |
|
190 | 186 | private:
|
191 | 187 | std::shared_ptr< ser::serializer > serializer_;
|
| 188 | + |
| 189 | + bool can_transform_dimension(int serialized_size, int verifier_size) { |
| 190 | + // We allow automatic transformation of D-1-dim fields to D-dim fields if the length of the dimension is 1 |
| 191 | + if (serialized_size == 0 && verifier_size == 1) |
| 192 | + return true; |
| 193 | + else |
| 194 | + return false; |
| 195 | + } |
| 196 | + |
| 197 | + bool sizes_compatible(const std::vector< int > &serialized_sizes, const std::vector< int > &verifier_sizes) { |
| 198 | + for (size_t i = 0; i < serialized_sizes.size(); ++i) { |
| 199 | + if (serialized_sizes[i] == 0) |
| 200 | + return true; |
| 201 | + else if (serialized_sizes[i] != verifier_sizes[i] && |
| 202 | + !can_transform_dimension(serialized_sizes[i], verifier_sizes[i])) |
| 203 | + return false; |
| 204 | + } |
| 205 | + return true; |
| 206 | + } |
| 207 | + |
| 208 | + std::string to_string(const std::vector< int > &v) { |
| 209 | + if (v.size() == 0) |
| 210 | + return std::string("<empty-vector>"); |
| 211 | + else |
| 212 | + return std::accumulate(std::next(v.begin()), |
| 213 | + v.end(), |
| 214 | + std::string(std::to_string(v[0])), |
| 215 | + [](std::string s, int i) { return s + ", " + std::to_string(i); }); |
| 216 | + } |
| 217 | + |
| 218 | + // FIXME: hack the mapping of killed dimension |
| 219 | + std::vector< bool > mask_for_killed_dimensions( |
| 220 | + const std::vector< int > &serialized_sizes, const std::vector< int > &verifier_sizes) const { |
| 221 | + std::vector< bool > mask; |
| 222 | + size_t i_serialized = 0; |
| 223 | + for (size_t i = 0; i < verifier_sizes.size(); ++i) { |
| 224 | + if (verifier_sizes[i] == serialized_sizes[i_serialized]) { |
| 225 | + mask.push_back(true); |
| 226 | + i_serialized++; |
| 227 | + } else if (verifier_sizes[i] == 1) |
| 228 | + mask.push_back(false); |
| 229 | + else |
| 230 | + throw verification_exception("Failed to mask killed dimensions."); |
| 231 | + } |
| 232 | + return mask; |
| 233 | + } |
| 234 | + |
| 235 | + std::vector< int > apply_mask(const std::vector< bool > &mask, const std::vector< int > &v) const { |
| 236 | + if (mask.size() != v.size()) |
| 237 | + throw verification_exception("Size of mask does not match size of vector."); |
| 238 | + std::vector< int > result; |
| 239 | + for (size_t i = 0; i < mask.size(); ++i) { |
| 240 | + if (mask[i]) |
| 241 | + result.push_back(v[i]); |
| 242 | + } |
| 243 | + return result; |
| 244 | + } |
192 | 245 | };
|
193 | 246 | }
|
0 commit comments