Skip to content

Commit a152251

Browse files
authored
Flexible error metric (#21)
Allows to use a user-defined error metric. The old hard-coded error_metric is replaced by an interface but shipped as the default metric.
1 parent 4a70be9 commit a152251

File tree

9 files changed

+72
-22
lines changed

9 files changed

+72
-22
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ project(gridtools-verification CXX C)
99

1010
include(ExternalProject)
1111

12-
set(GRIDTOOLS_VERIFICATION_VERSION_STRING "0.3")
12+
set(GRIDTOOLS_VERIFICATION_VERSION_STRING "0.4")
1313
set(SERIALBOX_VERSION_REQUIRED "2.2.1")
1414

1515
#----------------- CMake options

src/verification/error_metric.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
#include <cmath>
3939
#include "../common.h"
40+
#include "error_metric_interface.h"
4041

4142
namespace gt_verification {
4243

@@ -54,7 +55,7 @@ namespace gt_verification {
5455
* @ingroup DycoreUnittestVerificationLibrary
5556
*/
5657
template < typename T >
57-
class error_metric {
58+
class error_metric : public error_metric_interface< T > {
5859
public:
5960
error_metric(const error_metric &) = default;
6061
error_metric &operator=(const error_metric &) = default;
@@ -72,7 +73,7 @@ namespace gt_verification {
7273
*
7374
* @return true iff absolute(a - b) <= (atol + rtol * absolute(b))
7475
*/
75-
inline bool equal(T a, T b) const noexcept { return (std::fabs(a - b) <= (atol_ + rtol_ * std::fabs(b))); }
76+
bool equal(T a, T b) const noexcept override { return (std::fabs(a - b) <= (atol_ + rtol_ * std::fabs(b))); }
7677

7778
private:
7879
T rtol_;
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
GridTools Libraries
3+
4+
Copyright (c) 2016, GridTools Consortium
5+
All rights reserved.
6+
7+
Redistribution and use in source and binary forms, with or without
8+
modification, are permitted provided that the following conditions are
9+
met:
10+
11+
1. Redistributions of source code must retain the above copyright
12+
notice, this list of conditions and the following disclaimer.
13+
14+
2. Redistributions in binary form must reproduce the above copyright
15+
notice, this list of conditions and the following disclaimer in the
16+
documentation and/or other materials provided with the distribution.
17+
18+
3. Neither the name of the copyright holder nor the names of its
19+
contributors may be used to endorse or promote products derived from
20+
this software without specific prior written permission.
21+
22+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26+
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33+
34+
For information: http://eth-cscs.github.io/gridtools/
35+
*/
36+
#pragma once
37+
38+
#include <cmath>
39+
#include "../common.h"
40+
41+
namespace gt_verification {
42+
43+
template < typename T >
44+
class error_metric_interface {
45+
public:
46+
/**
47+
* @brief Check if two real numbers @c a and @c b are equal within a tolerance
48+
*/
49+
virtual bool equal(T a, T b) const noexcept = 0;
50+
};
51+
}

src/verification/field_collection.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
#include "../core/logger.h"
4545
#include "../verification_exception.h"
4646
#include "verification_reporter.h"
47-
#include "error_metric.h"
47+
#include "error_metric_interface.h"
4848
#include "boundary_extent.h"
4949
#include "verification.h"
5050
#include "verification_result.h"
@@ -206,18 +206,18 @@ namespace gt_verification {
206206
*
207207
* @return VerificationResult
208208
*/
209-
verification_result verify(error_metric< T > errorMetric) {
209+
verification_result verify(const error_metric_interface< T > &error_metric) {
210210
verifications_.clear();
211211

212212
verification_result totalResult(true, "\n");
213213

214214
// Iterate over output fields and compare them to the reference fields
215215
for (std::size_t i = 0; i < outputFields_.size(); ++i) {
216216
verifications_.emplace_back(
217-
outputFields_[i].second, referenceFields_[i].second.to_view(), errorMetric, boundaries_[i]);
217+
outputFields_[i].second, referenceFields_[i].second.to_view(), boundaries_[i]);
218218

219219
// Perform actual verification and merge results
220-
totalResult.merge(verifications_.back().verify());
220+
totalResult.merge(verifications_.back().verify(error_metric));
221221
}
222222
return totalResult;
223223
}

src/verification/unittest_environment.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "field_collection.h"
4+
#include "error_metric.h"
45
#include "../core.h"
56
#include <gtest/gtest.h>
67
#include <string>
@@ -16,14 +17,16 @@ namespace gt_verification {
1617
private boost::noncopyable /* singleton */
1718
{
1819
public:
19-
unittest_environment(command_line &cl, std::string data_name, std::string archive_type="Binary") : cl_(cl), data_path_("./") {
20+
unittest_environment(command_line &cl, std::string data_name, std::string archive_type = "Binary")
21+
: cl_(cl), data_path_("./") {
2022
if (cl_.has("path"))
2123
data_path_ = cl_.as< std::string >("path");
2224

2325
VERIFICATION_LOG() << "Using serializer data-path: '" << data_path_ << "'" << logger_action::endl;
2426

2527
// Initialize the serializer
26-
reference_serializer_ = std::make_shared< ser::serializer >(ser::open_mode::Read, data_path_, data_name, archive_type);
28+
reference_serializer_ =
29+
std::make_shared< ser::serializer >(ser::open_mode::Read, data_path_, data_name, archive_type);
2730

2831
// Initialize error serializer
2932
error_serializer_ = std::make_shared< ser::serializer >(ser::open_mode::Write, ".", "Error");
@@ -125,7 +128,7 @@ namespace gt_verification {
125128
*/
126129
template < typename T >
127130
testing::AssertionResult verify_collection(
128-
field_collection< T > &fieldCollection, error_metric< T > errorMetric) {
131+
field_collection< T > &fieldCollection, const error_metric_interface< T > &errorMetric) {
129132
verification_result result = fieldCollection.verify(errorMetric);
130133
if (!result.passed())
131134
fieldCollection.report_failures();

src/verification/verification.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,8 @@ namespace gt_verification {
8686
*/
8787
verification(type_erased_field_view< T > outputField,
8888
type_erased_field_view< T > referenceField,
89-
error_metric< T > errorMetric,
9089
boundary_extent boundary = boundary_extent())
91-
: outputField_(outputField), referenceField_(referenceField), errorMetric_(errorMetric),
92-
boundary_(boundary) {}
90+
: outputField_(outputField), referenceField_(referenceField), boundary_(boundary) {}
9391

9492
/**
9593
* @brief Verify that outputField is equal to refrenceField within the given error metric
@@ -98,7 +96,7 @@ namespace gt_verification {
9896
*
9997
* @return VerificationResult
10098
*/
101-
verification_result verify() noexcept {
99+
verification_result verify(const error_metric_interface< T > &error_metric) noexcept {
102100
// Sync fields with Host
103101
outputField_.sync();
104102
referenceField_.sync();
@@ -130,7 +128,7 @@ namespace gt_verification {
130128
for (int k = boundary_.k_minus(); k < (kSizeOut + boundary_.k_plus()); ++k)
131129
for (int j = boundary_.j_minus(); j < (jSizeOut + boundary_.j_plus()); ++j)
132130
for (int i = boundary_.i_minus(); i < (iSizeOut + boundary_.i_plus()); ++i)
133-
if (!errorMetric_.equal(outputField_(i, j, k), referenceField_(i, j, k)))
131+
if (!error_metric.equal(outputField_(i, j, k), referenceField_(i, j, k)))
134132
failures_.push_back(failure{i, j, k, outputField_(i, j, k), referenceField_(i, j, k)});
135133

136134
if (failures_.empty())
@@ -184,7 +182,6 @@ namespace gt_verification {
184182
private:
185183
type_erased_field_view< T > outputField_;
186184
type_erased_field_view< T > referenceField_;
187-
error_metric< T > errorMetric_;
188185
boundary_extent boundary_;
189186

190187
std::vector< failure > failures_;

src/verification/verification_reporter.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
#include "../core/command_line.h"
5151
#include "../core/error.h"
5252
#include "../core/utility.h"
53-
#include "error_metric.h"
5453

5554
namespace gt_verification {
5655

src/verification/verification_specification.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
#include "../core/error.h"
4040
#include "../core/utility.h"
4141
#include "../core/logger.h"
42-
#include "error_metric.h"
4342
#include "verification_specification.h"
4443
#include <cstdlib>
4544
#include <string>

unittest/verification/test_verification.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ class test_Verification : public ::testing::Test {
7171
TEST_F(test_Verification, UnmodifiedFieldShouldPass) {
7272
error_metric< Real > errorMetric(1e-6, 1e-8);
7373

74-
verification< Real > test(outView, refView, errorMetric);
75-
ASSERT_TRUE(test.verify().passed());
74+
verification< Real > test(outView, refView);
75+
ASSERT_TRUE(test.verify(errorMetric).passed());
7676
}
7777

7878
TEST_F(test_Verification, ModifiedFieldShouldFail) {
@@ -87,8 +87,8 @@ TEST_F(test_Verification, ModifiedFieldShouldFail) {
8787
outField.h2d_update();
8888
#endif
8989

90-
verification< Real > test(outView, refView, errorMetric);
91-
ASSERT_FALSE(test.verify().passed());
90+
verification< Real > test(outView, refView);
91+
ASSERT_FALSE(test.verify(errorMetric).passed());
9292
}
9393

9494
#endif

0 commit comments

Comments
 (0)