Merge pull request #164 from lab-cosmo/dyload-cuda-stream
Dynamically load `c10::cuda::getCurrentCUDAStream` in sphericart-torch
nickjbrowning authored Jan 16, 2025
2 parents a6f23ec + 487bbd8 commit 0aa2287
Showing 7 changed files with 116 additions and 44 deletions.
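In outline, the change stops `sphericart_torch` from calling `c10::cuda::getCurrentCUDAStream` directly (which would require linking `libtorch_cuda`): the call is moved into a small companion library, `libsphericart_torch_cuda_stream.so`, which the main extension resolves lazily at runtime through `dlopen`/`dlsym`, so it still loads against a CPU-only torch install. A condensed, Linux-only sketch of the pattern follows; the wrapper function `current_cuda_stream` is illustrative only, while the real code in `autograd.cpp` below wraps the same logic in a singleton with error handling.

// Provider side, compiled into libsphericart_torch_cuda_stream.so (links torch):
//     extern "C" void* get_current_cuda_stream(uint8_t device_id);
//
// Consumer side, compiled into libsphericart_torch (no CUDA link-time dependency):
#include <cstdint>
#include <stdexcept>
#include <string>
#include <dlfcn.h>

using get_stream_t = void* (*)(uint8_t);

void* current_cuda_stream(uint8_t device_id) { // illustrative helper, not in the diff
    // RTLD_NOW resolves symbols immediately; the companion library is found
    // next to libsphericart_torch thanks to INSTALL_RPATH "$ORIGIN"
    static void* handle = dlopen("libsphericart_torch_cuda_stream.so", RTLD_NOW);
    if (handle == nullptr) {
        throw std::runtime_error(std::string("dlopen failed: ") + dlerror());
    }
    static auto get_stream =
        reinterpret_cast<get_stream_t>(dlsym(handle, "get_current_cuda_stream"));
    if (get_stream == nullptr) {
        throw std::runtime_error(std::string("dlsym failed: ") + dlerror());
    }
    return get_stream(device_id);
}

The CMake changes mirror this split: the main target now links only `torch_cpu_library`, while the new `sphericart_torch_cuda_stream` target links the full `torch`.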
10 changes: 1 addition & 9 deletions scripts/check-format.sh
@@ -5,12 +5,4 @@ set -eu
ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")"/.. && pwd)
cd "$ROOT_DIR"

find . -type f \( \
-name "*.c" -o -name "*.cpp" \
-o -name "*.h" -o -name "*.hpp" \
-o -name "*.cu" -o -name "*.cuh" \
\) \
-not -path "*/external/*" \
-not -path "*/build/*" \
-not -path "*/.tox/*" \
-exec clang-format --dry-run --Werror {} \;
git ls-files '*.cpp' '*.c' '*.hpp' '*.h' '*.cu' '*.cuh' | xargs -L 1 clang-format --dry-run --Werror
10 changes: 1 addition & 9 deletions scripts/format.sh
@@ -5,12 +5,4 @@ set -eu
ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")"/.. && pwd)
cd "$ROOT_DIR"

find . -type f \( \
-name "*.c" -o -name "*.cpp" \
-o -name "*.h" -o -name "*.hpp" \
-o -name "*.cu" -o -name "*.cuh" \
\) \
-not -path "*/external/*" \
-not -path "*/build/*" \
-not -path "*/.tox/*" \
-exec clang-format -i {} \;
git ls-files '*.cpp' '*.c' '*.hpp' '*.h' '*.cu' '*.cuh' | xargs -L 1 clang-format -i
39 changes: 30 additions & 9 deletions sphericart-torch/CMakeLists.txt
@@ -122,17 +122,16 @@ else()
target_sources(sphericart_torch PUBLIC "src/torch_cuda_wrapper_stub.cpp")
endif()

target_link_libraries(sphericart_torch PUBLIC torch sphericart)
if(OpenMP_CXX_FOUND)
target_link_libraries(sphericart_torch PUBLIC OpenMP::OpenMP_CXX)
endif()
target_link_libraries(sphericart_torch PUBLIC sphericart)

message (STATUS "CUDA Toolkit Dir: ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}")
# Only link against `torch_cpu_library` instead of `torch`, since the latter
# could also pull in `libtorch_cuda`.
target_link_libraries(sphericart_torch PUBLIC torch_cpu_library)
target_include_directories(sphericart_torch PUBLIC "${TORCH_INCLUDE_DIRS}")
target_compile_definitions(sphericart_torch PUBLIC "${TORCH_CXX_FLAGS}")

if(CMAKE_CUDA_COMPILER)
target_compile_definitions(sphericart_torch PRIVATE CUDA_AVAILABLE)
target_compile_definitions(sphericart_torch PRIVATE C10_CUDA_NO_CMAKE_CONFIGURE_FILE)
target_include_directories(sphericart_torch PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
if(OpenMP_CXX_FOUND)
target_link_libraries(sphericart_torch PUBLIC OpenMP::OpenMP_CXX)
endif()

target_compile_features(sphericart_torch PUBLIC cxx_std_17)
@@ -142,10 +141,32 @@ target_include_directories(sphericart_torch PUBLIC
$<INSTALL_INTERFACE:include>
)

if (LINUX)
# so dlopen can find libsphericart_torch_cuda_stream.so
set_target_properties(sphericart_torch PROPERTIES INSTALL_RPATH "$ORIGIN")
endif()

add_library(sphericart_torch_cuda_stream SHARED
"src/streams.cpp"
)
target_link_libraries(sphericart_torch_cuda_stream PUBLIC torch)
target_compile_features(sphericart_torch_cuda_stream PUBLIC cxx_std_17)

if(CMAKE_CUDA_COMPILER)
target_compile_definitions(sphericart_torch_cuda_stream PRIVATE CUDA_AVAILABLE)
target_compile_definitions(sphericart_torch_cuda_stream PRIVATE C10_CUDA_NO_CMAKE_CONFIGURE_FILE)
target_include_directories(sphericart_torch_cuda_stream PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()


install(TARGETS sphericart_torch
LIBRARY DESTINATION "lib"
)

install(TARGETS sphericart_torch_cuda_stream
LIBRARY DESTINATION "lib"
)

if (SPHERICART_TORCH_BUILD_FOR_PYTHON)
install(
FILES ${CMAKE_CURRENT_BINARY_DIR}/_build_torch_version.py
65 changes: 54 additions & 11 deletions sphericart-torch/src/autograd.cpp
@@ -1,19 +1,67 @@
#include <cstdint> // For intptr_t

#ifdef __linux__
#include <dlfcn.h>
#endif

#include "sphericart/autograd.hpp"

#include "cuda_base.hpp"
#include "sphericart.hpp"
#include "sphericart/torch.hpp"
#include "sphericart/torch_cuda_wrapper.hpp"
#include <torch/torch.h>

#ifdef CUDA_AVAILABLE
#include <c10/cuda/CUDAStream.h>
/// Dynamically load `get_current_cuda_stream`; see `streams.cpp` for more
/// information.
class CUDAStream {
public:
static CUDAStream& instance() {
static CUDAStream instance;
return instance;
}

bool loaded() { return handle != nullptr; }

using get_stream_t = void* (*)(uint8_t);
get_stream_t get_stream = nullptr;

CUDAStream() {
#ifdef __linux__
handle = dlopen("libsphericart_torch_cuda_stream.so", RTLD_NOW);
if (!handle) {
throw std::runtime_error(
std::string("Failed to load libsphericart_torch_cuda_stream.so: ") + dlerror()
);
}

auto get_stream = reinterpret_cast<get_stream_t>(dlsym(handle, "get_current_cuda_stream"));
if (!get_stream) {
throw std::runtime_error(
std::string("Failed to load get_current_cuda_stream: ") + dlerror()
);
}
this->get_stream = get_stream;
#else
throw std::runtime_error("Platform not supported for dynamic loading of CUDA streams");
#endif
}

~CUDAStream() {
#ifdef __linux__
if (handle) {
dlclose(handle);
}
#endif
}

// Prevent copying
CUDAStream(const CUDAStream&) = delete;
CUDAStream& operator=(const CUDAStream&) = delete;

void* handle = nullptr;
};

using namespace sphericart_torch;
using namespace at;

template <template <typename> class C, typename scalar_t>
std::vector<torch::Tensor> _compute_raw_cpu(
@@ -269,9 +317,6 @@ std::vector<torch::Tensor> SphericartAutograd::forward(
}

void* stream = nullptr;
#ifdef CUDA_AVAILABLE
stream = reinterpret_cast<void*>(at::cuda::getCurrentCUDAStream().stream());
#endif

auto sph = torch::Tensor();
auto dsph = torch::Tensor();
@@ -288,6 +333,7 @@ std::vector<torch::Tensor> SphericartAutograd::forward(
dsph = results[1];
ddsph = results[2];
} else if (xyz.device().is_cuda()) {
stream = CUDAStream::instance().get_stream(xyz.device().index());
auto results = calculator.compute_raw_cuda(xyz, requires_grad, requires_hessian, stream);
sph = results[0];
dsph = results[1];
@@ -339,10 +385,6 @@ torch::Tensor SphericartAutogradBackward::forward(
) {

void* stream = nullptr;
#ifdef CUDA_AVAILABLE
stream = reinterpret_cast<void*>(at::cuda::getCurrentCUDAStream().stream());
#endif

auto dsph = saved_variables[1];
auto ddsph = saved_variables[2];

@@ -351,6 +393,7 @@ torch::Tensor SphericartAutogradBackward::forward(
if (xyz.device().is_cpu()) {
xyz_grad = backward_cpu(xyz, dsph, grad_outputs);
} else if (xyz.device().is_cuda()) {
stream = CUDAStream::instance().get_stream(xyz.device().index());
xyz_grad =
sphericart_torch::spherical_harmonics_backward_cuda(xyz, dsph, grad_outputs, stream);
} else {
27 changes: 27 additions & 0 deletions sphericart-torch/src/streams.cpp
@@ -0,0 +1,27 @@
#include <cstdint>

#include <torch/torch.h>

#ifdef CUDA_AVAILABLE
#include <c10/cuda/CUDAStream.h>

// This function re-exposes `c10::cuda::getCurrentCUDAStream` in a way we can
// call through dlopen/dlsym, and is intended to be compiled in a separate
// shared library from the main sphericart_torch one.
//
// The alternative would be to link the main library directly to
// `libtorch_cuda`, but doing so would prevent users from loading
// sphericart_torch when using a CPU-only version of torch.
extern "C" void* get_current_cuda_stream(uint8_t device_id) {
return reinterpret_cast<void*>(c10::cuda::getCurrentCUDAStream(device_id).stream());
}

#else

extern "C" void* get_current_cuda_stream(uint8_t device_id) {
TORCH_WARN_ONCE("Something wrong is happening: trying to get the current CUDA stream, "
"but this version of sphericart was compiled without CUDA support");
return nullptr;
}

#endif
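Downstream, the opaque `void*` returned here is presumably cast back to a `cudaStream_t` inside the CUDA backend; the receiving code (in `cuda_base.hpp` and the kernels) is not part of this diff, so the helper below is a hypothetical sketch of that cast rather than the project's actual function.

#include <cuda_runtime.h>

// Hypothetical receiving side: compute_raw_cuda forwards the opaque pointer,
// and the CUDA backend reinterprets it as the stream to launch kernels on.
void launch_on_stream(void* stream_ptr) {
    // the provider returned CUDAStream::stream(), i.e. a cudaStream_t, so the
    // round trip through void* is lossless
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    // e.g. some_kernel<<<grid, block, 0, stream>>>(...);
    (void)stream; // placeholder: no kernel is launched in this sketch
}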
3 changes: 0 additions & 3 deletions sphericart-torch/src/torch_cuda_wrapper.cpp
@@ -1,7 +1,4 @@
#include "sphericart/torch_cuda_wrapper.hpp"
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
#include <torch/torch.h>

#define _SPHERICART_INTERNAL_IMPLEMENTATION // gives us access to
6 changes: 3 additions & 3 deletions sphericart/CMakeLists.txt
@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.27)
project(sphericart LANGUAGES C CXX)

#[[
This function wraps the input file with the appropriate string to allow us to
This function wraps the input file with the appropriate string to allow us to
static-initialize string variables, for example:
static const char* CUDA_CODE =
#include "generated/wrapped_sphericart_impl.cu"
@@ -128,7 +128,7 @@ if (CMAKE_CUDA_COMPILER AND SPHERICART_ENABLE_CUDA)
# Prepend headers to the source file
prepend_headers_to_source(
"${CMAKE_CURRENT_SOURCE_DIR}/src/sphericart_impl.cu"
"${CMAKE_CURRENT_BINARY_DIR}/generated/tmp.cu"
"${CMAKE_CURRENT_BINARY_DIR}/generated/tmp.cu"
"${SPHERICART_CUDA_HEADERS}"
)

@@ -261,4 +261,4 @@ install(FILES "${PROJECT_BINARY_DIR}/sphericart-config-version.cmake"
DESTINATION ${LIB_INSTALL_DIR}/cmake/sphericart)

install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/ DESTINATION ${INCLUDE_INSTALL_DIR})
install(DIRECTORY ${PROJECT_BINARY_DIR}/include/ DESTINATION ${INCLUDE_INSTALL_DIR})
install(DIRECTORY ${PROJECT_BINARY_DIR}/include/ DESTINATION ${INCLUDE_INSTALL_DIR})
