Skip to content

Commit bac5dfe

Browse files
committed
Fix for Linux
1 parent 7dccba5 commit bac5dfe

File tree

2 files changed

+72
-32
lines changed

2 files changed

+72
-32
lines changed

cpp/CMakeLists.txt

+57-28
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,9 @@ if(NOT WIN32)
2222
set(ColorBoldRed "${ColorRed}${ColorBold}")
2323
endif()
2424

25-
#--------------------------- CMAKE VARIABLES (partly for Cmake GUI) ----------------------------------------------------
26-
27-
set(USE_BACKEND CACHE STRING "Neural net backend")
28-
string(TOUPPER "${USE_BACKEND}" USE_BACKEND)
29-
set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA OPENCL EIGEN ONNXRUNTIME)
30-
31-
set(USE_TCMALLOC 0 CACHE BOOL "Use TCMalloc")
32-
set(NO_GIT_REVISION 0 CACHE BOOL "Disable embedding the git revision into the compiled exe")
33-
set(Boost_USE_STATIC_LIBS_ON 0 CACHE BOOL "Compile against boost statically instead of dynamically")
34-
set(USE_AVX2 0 CACHE BOOL "Compile with AVX2")
35-
set(USE_BIGGER_BOARDS_EXPENSIVE 0 CACHE BOOL "Allow boards up to size 29. Compiling with this will use more memory and slow down KataGo, even when playing on boards of size 19.")
36-
37-
#--------------------------- NEURAL NET BACKEND ------------------------------------------------------------------------
38-
39-
message(STATUS "Building 'katago' executable for GTP engine and other tools.")
40-
if(USE_BACKEND STREQUAL "CUDA")
41-
message(STATUS "-DUSE_BACKEND=CUDA, using CUDA backend.")
25+
#--------------------------- CUDA MACRO -------------------------------------------------------------------------------
4226

27+
macro(CONFIGURE_CUDA)
4328
# Ensure dynamic cuda linking (Versions prior to 3.17)
4429
if (${CMAKE_VERSION} VERSION_LESS "3.17")
4530
set(CMAKE_CUDA_FLAGS "" CACHE STRING "")
@@ -144,6 +129,26 @@ if(USE_BACKEND STREQUAL "CUDA")
144129
"
145130
)
146131
endif()
132+
endmacro()
133+
134+
#--------------------------- CMAKE VARIABLES (partly for Cmake GUI) ----------------------------------------------------
135+
136+
set(USE_BACKEND CACHE STRING "Neural net backend")
137+
string(TOUPPER "${USE_BACKEND}" USE_BACKEND)
138+
set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA OPENCL EIGEN ONNXRUNTIME)
139+
140+
set(USE_TCMALLOC 0 CACHE BOOL "Use TCMalloc")
141+
set(NO_GIT_REVISION 0 CACHE BOOL "Disable embedding the git revision into the compiled exe")
142+
set(Boost_USE_STATIC_LIBS_ON 0 CACHE BOOL "Compile against boost statically instead of dynamically")
143+
set(USE_AVX2 0 CACHE BOOL "Compile with AVX2")
144+
set(USE_BIGGER_BOARDS_EXPENSIVE 0 CACHE BOOL "Allow boards up to size 29. Compiling with this will use more memory and slow down KataGo, even when playing on boards of size 19.")
145+
146+
#--------------------------- NEURAL NET BACKEND ------------------------------------------------------------------------
147+
148+
message(STATUS "Building 'katago' executable for GTP engine and other tools.")
149+
if(USE_BACKEND STREQUAL "CUDA")
150+
message(STATUS "-DUSE_BACKEND=CUDA, using CUDA backend.")
151+
configure_cuda()
147152
elseif(USE_BACKEND STREQUAL "OPENCL")
148153
message(STATUS "-DUSE_BACKEND=OPENCL, using OpenCL backend.")
149154
set(NEURALNET_BACKEND_SOURCES
@@ -166,6 +171,9 @@ elseif(USE_BACKEND STREQUAL "ONNXRUNTIME")
166171
set(ORT_TENSORRT 0 CACHE BOOL "Use TensorRT execution provider for ONNXRuntime.")
167172
set(ORT_DIRECTML 0 CACHE BOOL "Use DirectML execution provider for ONNXRuntime.")
168173
set(ORT_MIGRAPHX 0 CACHE BOOL "Use MIGraphX execution provider for ONNXRuntime.")
174+
if(ORT_CUDA OR ORT_TENSORRT)
175+
configure_cuda()
176+
endif()
169177
if(ORT_MIGRAPHX)
170178
set(NEURALNET_BACKEND_SOURCES
171179
neuralnet/ortbackend.cpp
@@ -345,18 +353,38 @@ elseif(USE_BACKEND STREQUAL "EIGEN")
345353
endif()
346354
elseif(USE_BACKEND STREQUAL "ONNXRUNTIME")
347355
target_compile_definitions(katago PRIVATE USE_ONNXRUNTIME_BACKEND)
356+
set(ORT_LIB_DIR CACHE STRING "ONNXRuntime library location")
357+
set(ORT_INCLUDE_DIR CACHE STRING "ONNXRuntime header files location")
358+
message(STATUS "ORT_LIB_DIR: " ${ORT_LIB_DIR})
359+
message(STATUS "ORT_INCLUDE_DIR: " ${ORT_INCLUDE_DIR})
360+
include_directories(${ORT_INCLUDE_DIR})
361+
if(EXISTS ${ORT_INCLUDE_DIR}/core/session)
362+
include_directories(${ORT_INCLUDE_DIR}/core/session)
363+
endif()
364+
if(EXISTS ${ORT_INCLUDE_DIR}/core/providers/cpu)
365+
include_directories(${ORT_INCLUDE_DIR}/core/providers/cpu)
366+
endif()
367+
find_library(ORT_LIBRARY NAMES onnxruntime PATHS ${ORT_LIB_DIR})
368+
if(NOT ORT_LIBRARY)
369+
message(FATAL_ERROR "Could not find onnxruntime")
370+
endif()
371+
target_link_libraries(katago ${ORT_LIBRARY})
348372
if(ORT_CUDA)
349373
target_compile_definitions(katago PRIVATE USE_ORT_CUDA)
350374
endif()
351375
if(ORT_TENSORRT)
352376
target_compile_definitions(katago PRIVATE USE_ORT_TENSORRT)
353-
set(TENSORRT_ROOT_DIR CACHE STRING "TensorRT root location")
354-
include_directories(${TENSORRT_ROOT_DIR}/include)
355-
find_library(TENSORRT_LIBRARY NAMES nvinfer PATHS ${TENSORRT_ROOT_DIR}/lib)
377+
set(TENSORRT_LIB_DIR CACHE STRING "TensorRT library location")
378+
set(TENSORRT_INCLUDE_DIR CACHE STRING "TensorRT header file location")
379+
include_directories(${TENSORRT_INCLUDE_DIR})
380+
find_library(TENSORRT_LIBRARY NAMES nvinfer PATHS ${TENSORRT_LIB_DIR})
356381
if(NOT TENSORRT_LIBRARY)
357382
message(FATAL_ERROR "Could not find nvinfer")
358383
endif()
359384
target_link_libraries(katago ${TENSORRT_LIBRARY})
385+
if(EXISTS ${ORT_INCLUDE_DIR}/core/providers/tensorrt)
386+
include_directories(${ORT_INCLUDE_DIR}/core/providers/tensorrt)
387+
endif()
360388
endif()
361389
if(ORT_CUDA OR ORT_TENSORRT)
362390
find_package(CUDA REQUIRED)
@@ -367,21 +395,22 @@ elseif(USE_BACKEND STREQUAL "ONNXRUNTIME")
367395
find_library(CUDNN_LIBRARY libcudnn.so PATHS /usr/local/cuda/lib64 /opt/cuda/lib64)
368396
include_directories(SYSTEM ${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIR}) #SYSTEM is for suppressing some compiler warnings in thrust libraries
369397
target_link_libraries(katago ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_LIBRARIES})
398+
if(EXISTS ${ORT_INCLUDE_DIR}/core/providers/cuda)
399+
include_directories(${ORT_INCLUDE_DIR}/core/providers/cuda)
400+
endif()
370401
endif()
371402
if(ORT_DIRECTML)
372403
target_compile_definitions(katago PRIVATE USE_ORT_DIRECTML)
404+
if(EXISTS ${ORT_INCLUDE_DIR}/core/providers/directml)
405+
include_directories(${ORT_INCLUDE_DIR}/core/providers/directml)
406+
endif()
373407
endif()
374408
if(ORT_MIGRAPHX)
375409
target_compile_definitions(katago PRIVATE USE_ORT_MIGRAPHX)
410+
if(EXISTS ${ORT_INCLUDE_DIR}/core/providers/migraphx)
411+
include_directories(${ORT_INCLUDE_DIR}/core/providers/migraphx)
412+
endif()
376413
endif()
377-
set(ORT_ROOT_DIR CACHE STRING "ONNXRuntime root location")
378-
message(STATUS "ORT_ROOT_DIR: " ${ORT_ROOT_DIR})
379-
include_directories(${ORT_ROOT_DIR}/include)
380-
find_library(ORT_LIBRARY NAMES onnxruntime PATHS ${ORT_ROOT_DIR}/lib)
381-
if(NOT ORT_LIBRARY)
382-
message(FATAL_ERROR "Could not find onnxruntime")
383-
endif()
384-
target_link_libraries(katago ${ORT_LIBRARY})
385414
endif()
386415

387416
if(USE_BIGGER_BOARDS_EXPENSIVE)

cpp/neuralnet/ortbackend.cpp

+15-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
#if defined(USE_ORT_CUDA) || defined(USE_ORT_TENSORRT)
1515
#include <cuda_provider_factory.h>
16-
#include <cuda.h>
16+
#include "../neuralnet/cudaincludes.h"
1717
#endif
1818
#ifdef USE_ORT_TENSORRT
1919
#include <tensorrt_provider_factory.h>
@@ -84,6 +84,8 @@ Rules NeuralNet::getSupportedRules(const LoadedModel* loadedModel, const Rules&
8484

8585
//------------------------------------------------------------------------------
8686

87+
std::unique_ptr < Ort::Env> env = nullptr;
88+
8789
struct Model {
8890
string name;
8991
int version;
@@ -110,14 +112,19 @@ struct Model {
110112
numScoreValueChannels = desc->numScoreValueChannels;
111113
numOwnershipChannels = desc->numOwnershipChannels;
112114

113-
Ort::Env env(ORT_LOGGING_LEVEL_ERROR, "Default");
115+
auto envLocal = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_ERROR, "Default");
116+
env = std::move(envLocal);
114117
Ort::SessionOptions sf;
115118
sf.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
116119
string dir = HomeData::getHomeDataDir(true, homeDataDirOverride);
117120
MakeDir::make(dir);
118121
string optModelPath = dir + "/" + onnxOptModelFile;
122+
#ifdef _WIN32
119123
std::wstring optModelFile = std::wstring(optModelPath.begin(), optModelPath.end());
120-
sf.SetOptimizedModelFilePath(optModelFile.c_str());
124+
sf.SetOptimizedModelFilePath(optModelFile.data());
125+
#else
126+
sf.SetOptimizedModelFilePath(optModelPath.data());
127+
#endif
121128

122129
if(onnxRuntimeExecutionProvider == "CUDA") {
123130
#ifdef USE_ORT_CUDA
@@ -153,8 +160,12 @@ struct Model {
153160
throw StringError("Invalid ONNXRuntime backend");
154161
}
155162

163+
#ifdef _WIN32
156164
std::wstring modelName = std::wstring(name.begin(), name.end());
157-
session = new Ort::Session{env, modelName.c_str(), sf};
165+
session = new Ort::Session(*env, modelName.data(), sf);
166+
#else
167+
session = new Ort::Session(*env, name.data(), sf);
168+
#endif
158169

159170
Ort::AllocatorWithDefaultOptions allocator;
160171

0 commit comments

Comments
 (0)