Skip to content

Commit 1564578

Browse files
committed
Attempt to switch everything to cmake
stack-info: PR: #1659, branch: drisspg/stack/33
1 parent 6ffe236 commit 1564578

File tree

2 files changed

+172
-157
lines changed

2 files changed

+172
-157
lines changed

CMakeLists.txt

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
cmake_minimum_required(VERSION 3.19)
8+
project(torchao_core CUDA CXX)
9+
10+
set(CMAKE_CXX_STANDARD 17)
11+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
12+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
13+
14+
if(NOT CMAKE_BUILD_TYPE)
15+
set(CMAKE_BUILD_TYPE Release)
16+
endif()
17+
18+
# Find PyTorch package
19+
find_package(Torch REQUIRED)
20+
21+
# Global compile definitions
22+
add_compile_definitions(Py_LIMITED_API=0x03090000)
23+
24+
# Set compiler flags based on platform and build type
25+
if(MSVC)
26+
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
27+
add_compile_options(/Od /ZI /DEBUG)
28+
else()
29+
add_compile_options(/O2 /permissive-)
30+
endif()
31+
else()
32+
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
33+
add_compile_options(-g -O0)
34+
if(CMAKE_CUDA_COMPILER)
35+
add_compile_options($<$<COMPILE_LANGUAGE:CUDA>:-g>)
36+
endif()
37+
else()
38+
add_compile_options(-O3)
39+
if(CMAKE_CUDA_COMPILER)
40+
add_compile_options($<$<COMPILE_LANGUAGE:CUDA>:-O3>)
41+
endif()
42+
endif()
43+
44+
# Add color diagnostics for non-Windows builds
45+
add_compile_options(-fdiagnostics-color=always)
46+
endif()
47+
48+
# Include directories
49+
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
50+
51+
# CUDA Setup
52+
if(CMAKE_CUDA_COMPILER)
53+
enable_language(CUDA)
54+
add_definitions(-DTORCH_CUDA_AVAILABLE)
55+
56+
# Set CUDA architectures if not already set
57+
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
58+
set(CMAKE_CUDA_ARCHITECTURES 70 75 80 86)
59+
endif()
60+
61+
# Set CUDA flags
62+
# Set CUDA architectures and TORCH_CUDA_ARCH_LIST
63+
if(NOT DEFINED TORCH_CUDA_ARCH_LIST)
64+
set(TORCH_CUDA_ARCH_LIST "7.0;7.5;8.0;8.6;9.0")
65+
endif()
66+
67+
# CUTLASS support for non-Windows CUDA builds
68+
if(NOT WIN32)
69+
set(CUTLASS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass)
70+
if(EXISTS ${CUTLASS_DIR})
71+
add_definitions(-DTORCHAO_USE_CUTLASS)
72+
include_directories(${CUTLASS_DIR}/include)
73+
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DTORCHAO_USE_CUTLASS -I${CUTLASS_DIR}/include")
74+
endif()
75+
endif()
76+
endif()
77+
78+
# Find source files
79+
file(GLOB_RECURSE CPP_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/torchao/csrc/**/*.cpp")
80+
if(CMAKE_CUDA_COMPILER)
81+
file(GLOB_RECURSE CUDA_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/torchao/csrc/cuda/**/*.cu")
82+
endif()
83+
84+
# Create the core library
85+
add_library(torchao_core SHARED
86+
${CPP_SOURCES}
87+
${CUDA_SOURCES}
88+
)
89+
90+
target_link_libraries(torchao_core PRIVATE
91+
${TORCH_LIBRARIES}
92+
)
93+
94+
# Set Python limited API version
95+
target_compile_definitions(torchao_core PRIVATE Py_LIMITED_API=0x03090000)
96+
97+
# Installation
98+
install(TARGETS torchao_core
99+
LIBRARY DESTINATION lib
100+
RUNTIME DESTINATION lib
101+
)

0 commit comments

Comments
 (0)