1
+ #include " testPch.h"
2
+ #include " test/onnx/TestCase.h"
3
+ #include " test/onnx/heap_buffer.h"
4
+ #include " test/util/include/test/compare_ortvalue.h"
5
+ #include " ort_value_helper.h"
6
+ #include " onnxruntime_cxx_api.h"
7
+ #include " StringHelpers.h"
8
+ #include " skip_model_tests.h"
9
+
10
+ #ifndef BUILD_GOOGLE_TEST
11
+ #error Must use googletest for value-parameterized tests
12
+ #endif
13
+
14
+ using namespace onnxruntime ::test;
15
+ using namespace winml ;
16
+ using namespace onnxruntime ;
17
+
18
namespace WinML {
// Global needed to keep the actual ITestCase objects alive while the tests are
// going on: googletest only stores the raw ITestCase* passed as test
// parameters, so unique_ptr ownership has to live here for the duration of the
// suite. Cleared in ModelTest::TearDownTestSuite.
std::vector<std::unique_ptr<ITestCase>> ownedTests;

22
// Value-parameterized fixture: each instantiation pairs one ITestCase (model +
// input/output data sets) with one device kind (CPU or DirectX).
class ModelTest : public testing::TestWithParam<std::tuple<ITestCase*, winml::LearningModelDeviceKind>> {
 protected:
  // Unpacks the (test case, device kind) parameter and reads the per-test
  // comparison tolerances / post-processing flag from the test case.
  void SetUp() override {
    std::tie(m_testCase, m_deviceKind) = GetParam();
    WINML_EXPECT_NO_THROW(m_testCase->GetPerSampleTolerance(&m_perSampleTolerance));
    WINML_EXPECT_NO_THROW(m_testCase->GetRelativePerSampleTolerance(&m_relativePerSampleTolerance));
    WINML_EXPECT_NO_THROW(m_testCase->GetPostProcessing(&m_postProcessing));
  }
  // Called after the last test in this test suite.
  static void TearDownTestSuite() {
    ownedTests.clear();  // clear the global vector
  }
  winml::LearningModelDeviceKind m_deviceKind;
  ITestCase* m_testCase;
  // Defaults used if SetUp's Get*Tolerance calls fail; overwritten otherwise.
  double m_perSampleTolerance = 1e-3;
  double m_relativePerSampleTolerance = 1e-3;
  bool m_postProcessing = false;

  // Converts each Ort::Value in `feed` to a WinML ITensor and binds it to the
  // session binding under its (wide-string) input name.
  void BindInputsFromFeed(LearningModelBinding& binding, std::unordered_map<std::string, Ort::Value>& feed) {
    for (auto& [name, value] : feed) {
      ITensor bindingValue;
      WINML_EXPECT_NO_THROW(bindingValue = OrtValueHelpers::LoadTensorFromOrtValue(value));
      WINML_EXPECT_NO_THROW(binding.Bind(_winml::Strings::WStringFromString(name), bindingValue));
    }
  }

  // Compares every expected output feed against the evaluation result of the
  // same name, using the tolerances read in SetUp.
  void CompareEvaluationResults(LearningModelEvaluationResult& results,
                                std::unordered_map<std::string, Ort::Value>& expectedOutputFeeds) {
    for (const auto& [name, value] : expectedOutputFeeds) {
      // Extract the output buffer from the evaluation output
      std::wstring outputName = _winml::Strings::WStringFromString(name);
      auto actualOutputTensorValue = results.Outputs().Lookup(outputName).as<ITensorNative>();
      BYTE* actualData;
      uint32_t actualSizeInBytes;
      WINML_EXPECT_HRESULT_SUCCEEDED(actualOutputTensorValue->GetBuffer(&actualData, &actualSizeInBytes));

      // Create a copy of Ort::Value from evaluation output.
      // NOTE(review): CreateTensor aliases `actualData`, which is owned by the
      // WinML tensor — `actualOutput` must not outlive this loop iteration.
      // The expected value's shape/type info is reused, assuming actual and
      // expected tensors agree in shape — CompareOrtValue verifies contents.
      auto expectedShapeAndTensorType = Ort::TensorTypeAndShapeInfo{nullptr};
      auto memoryInfo = Ort::MemoryInfo{nullptr};
      WINML_EXPECT_NO_THROW(expectedShapeAndTensorType = value.GetTensorTypeAndShapeInfo());
      WINML_EXPECT_NO_THROW(memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault));
      Ort::Value actualOutput = Ort::Value{nullptr};
      WINML_EXPECT_NO_THROW(
          actualOutput = Ort::Value::CreateTensor(
              memoryInfo,
              actualData,
              actualSizeInBytes,
              expectedShapeAndTensorType.GetShape().data(),
              expectedShapeAndTensorType.GetShape().size(),
              expectedShapeAndTensorType.GetElementType()));

      // Use the expected and actual OrtValues to compare
      std::pair<COMPARE_RESULT, std::string> ret = CompareOrtValue(*actualOutput, *value, m_perSampleTolerance, m_relativePerSampleTolerance, m_postProcessing);
      WINML_EXPECT_EQUAL(COMPARE_RESULT::SUCCESS, ret.first) << ret.second;
    }
  }
};
79
+
80
// Parameterized test body: loads the model for the current (test case, device)
// pair, then for every data set in the case binds the inputs, evaluates the
// session, and compares the outputs against the expected feeds.
TEST_P(ModelTest, Run) {
  LearningModel model = nullptr;
  LearningModelDevice device = nullptr;
  LearningModelSession session = nullptr;
  LearningModelBinding binding = nullptr;
  WINML_EXPECT_NO_THROW(model = LearningModel::LoadFromFilePath(m_testCase->GetModelUrl()));
  WINML_EXPECT_NO_THROW(device = LearningModelDevice(m_deviceKind));
  WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device));
  WINML_EXPECT_NO_THROW(binding = LearningModelBinding(session));
  for (size_t i = 0; i < m_testCase->GetDataCount(); i++) {
    // Load and bind inputs. The HeapBuffer must stay alive while the feeds
    // (which reference its memory) are in use.
    onnxruntime::test::HeapBuffer inputHolder;
    std::unordered_map<std::string, Ort::Value> inputFeeds;
    WINML_EXPECT_NO_THROW(m_testCase->LoadTestData(i, inputHolder, inputFeeds, true));
    WINML_EXPECT_NO_THROW(BindInputsFromFeed(binding, inputFeeds));

    // evaluate
    LearningModelEvaluationResult results = nullptr;
    WINML_EXPECT_NO_THROW(results = session.Evaluate(binding, L"Testing"));

    // Load expected outputs (LoadTestData with is_input=false)
    onnxruntime::test::HeapBuffer outputHolder;
    std::unordered_map<std::string, Ort::Value> outputFeeds;
    WINML_EXPECT_NO_THROW(m_testCase->LoadTestData(i, outputHolder, outputFeeds, false));

    // compare results
    CompareEvaluationResults(results, outputFeeds);
  }
}
109
+
110
+ // Get the path of the model test collateral. Will return empty string if it doesn't exist.
111
+ std::string GetTestDataPath () {
112
+ std::string testDataPath (MAX_PATH, ' \0 ' );
113
+ auto environmentVariableFetchSuceeded = GetEnvironmentVariableA (" WINML_TEST_DATA_PATH" , testDataPath.data (), MAX_PATH);
114
+ if (environmentVariableFetchSuceeded == 0 && GetLastError () == ERROR_ENVVAR_NOT_FOUND || environmentVariableFetchSuceeded > MAX_PATH) {
115
+ // if the WINML_TEST_DATA_PATH environment variable cannot be found, attempt to find the hardcoded models folder
116
+ std::wstring modulePath = FileHelpers::GetModulePath ();
117
+ std::filesystem::path currPath = modulePath.substr (0 ,modulePath.find_last_of (L" \\ " ));
118
+ std::filesystem::path parentPath = currPath.parent_path ();
119
+ auto hardcodedModelPath = parentPath.string () + " \\ models" ;
120
+ if (std::filesystem::exists (hardcodedModelPath) && hardcodedModelPath.length () <= MAX_PATH) {
121
+ return hardcodedModelPath;
122
+ }
123
+ }
124
+ return testDataPath;
125
+ }
126
+
127
+ // This function returns the list of all test cases inside model test collateral
128
+ static std::vector<ITestCase*> GetAllTestCases () {
129
+ std::vector<ITestCase*> tests;
130
+ std::vector<std::basic_string<PATH_CHAR_TYPE>> whitelistedTestCases;
131
+ double perSampleTolerance = 1e-3 ;
132
+ double relativePerSampleTolerance = 1e-3 ;
133
+ std::unordered_set<std::basic_string<ORTCHAR_T>> allDisabledTests;
134
+ std::vector<std::basic_string<PATH_CHAR_TYPE>> dataDirs;
135
+ auto testDataPath = GetTestDataPath ();
136
+ if (testDataPath == " " ) return tests;
137
+
138
+ for (auto & p : std::filesystem::directory_iterator (testDataPath.c_str ())) {
139
+ if (p.is_directory ()) {
140
+ dataDirs.push_back (std::move (p.path ()));
141
+ }
142
+ }
143
+
144
+ WINML_EXPECT_NO_THROW (LoadTests (dataDirs, whitelistedTestCases, perSampleTolerance, relativePerSampleTolerance,
145
+ allDisabledTests,
146
+ [&tests](std::unique_ptr<ITestCase> l) {
147
+ tests.push_back (l.get ());
148
+ ownedTests.push_back (std::move (l));
149
+ }));
150
+ return tests;
151
+ }
152
+
153
+ // determine if test should be disabled
154
+ void DetermineIfDisableTest (std::string& testName, winml::LearningModelDeviceKind deviceKind) {
155
+ bool shouldSkip = false ;
156
+ std::string reason = " Reason not found." ;
157
+ if (disabledTests.find (testName) != disabledTests.end ()) {
158
+ reason = disabledTests.at (testName);
159
+ shouldSkip = true ;
160
+ } else if (deviceKind == LearningModelDeviceKind::DirectX) {
161
+ if (SkipGpuTests ()) {
162
+ reason = " GPU tests are not enabled for this build." ;
163
+ shouldSkip = true ;
164
+ } else if (disabledGpuTests.find (testName) != disabledGpuTests.end ()) {
165
+ reason = disabledGpuTests.at (testName);
166
+ shouldSkip = true ;
167
+ }
168
+ } else if (disabledx86Tests.find (testName) != disabledx86Tests.end ()) {
169
+ #if !defined(__amd64__) && !defined(_M_AMD64)
170
+ reason = disabledx86Tests.at (testName);
171
+ shouldSkip = true ;
172
+ #endif
173
+ }
174
+ if (shouldSkip) {
175
+ printf (" Disabling %s test because : %s\n " , testName.c_str (), reason.c_str ());
176
+ testName = " DISABLED_" + testName;
177
+ }
178
+ }
179
+
180
+ // This function gets the name of the test
181
+ static std::string GetNameOfTest (const testing::TestParamInfo<ModelTest::ParamType>& info) {
182
+ std::string name = " " ;
183
+ auto modelPath = std::wstring (std::get<0 >(info.param )->GetModelUrl ());
184
+ auto modelPathStr = _winml::Strings::UTF8FromUnicode (modelPath.c_str (), modelPath.length ());
185
+ std::vector<std::string> tokenizedModelPath;
186
+ std::istringstream ss (modelPathStr);
187
+ std::string token;
188
+ while (std::getline (ss, token, ' \\ ' )) {
189
+ tokenizedModelPath.push_back (std::move (token));
190
+ }
191
+ // The model path is structured like this "<opset>/<model_name>/model.onnx
192
+ // The desired naming of the test is like this <model_name>_<opset>_<CPU/GPU>
193
+ name += tokenizedModelPath[tokenizedModelPath.size () - 2 ] += " _" ; // model name
194
+ name += tokenizedModelPath[tokenizedModelPath.size () - 3 ]; // opset version
195
+
196
+ std::replace_if (name.begin (), name.end (), [](char c) { return !google::protobuf::ascii_isalnum (c); }, ' _' );
197
+
198
+ auto deviceKind = std::get<1 >(info.param );
199
+ // Determine if test should be skipped
200
+ DetermineIfDisableTest (name, deviceKind);
201
+ if (deviceKind == winml::LearningModelDeviceKind::Cpu) {
202
+ name += " _CPU" ;
203
+ } else {
204
+ name += " _GPU" ;
205
+ }
206
+
207
+ return name;
208
+ }
209
+
210
// Instantiate the parameterized suite: every test case discovered in the
// collateral is combined with both device kinds (CPU and DirectX), and
// GetNameOfTest produces a unique, googletest-safe (possibly
// DISABLED_-prefixed) name for each combination.
INSTANTIATE_TEST_SUITE_P(ModelTests, ModelTest, testing::Combine(testing::ValuesIn(GetAllTestCases()), testing::Values(winml::LearningModelDeviceKind::Cpu, winml::LearningModelDeviceKind::DirectX)),
                         GetNameOfTest);
}  // namespace WinML
0 commit comments