Skip to content

Commit 98ab844

Browse files
authored
Adding support for building BERT plugins with GPU_ARCHS specified (#255)
1 parent fb305de commit 98ab844

File tree

33 files changed

+102
-93
lines changed

33 files changed

+102
-93
lines changed

CMakeLists.txt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,33 @@ endif()
7777

7878
set(CMAKE_CXX_FLAGS "-Wno-deprecated-declarations ${CMAKE_CXX_FLAGS} -DBUILD_SYSTEM=cmake_oss")
7979

80+
if (DEFINED GPU_ARCHS)
81+
message(STATUS "GPU_ARCHS defined as ${GPU_ARCHS}. Generating CUDA code for SM ${GPU_ARCHS}")
82+
separate_arguments(GPU_ARCHS)
83+
else()
84+
list(APPEND GPU_ARCHS
85+
35
86+
53
87+
61
88+
70
89+
75
90+
)
91+
message(STATUS "GPU_ARCHS is not defined. Generating CUDA code for default SMs: ${GPU_ARCHS}")
92+
endif()
93+
set(BERT_GENCODES)
94+
# Generate SASS for each architecture
95+
foreach(arch ${GPU_ARCHS})
96+
if (${arch} GREATER_EQUAL 70)
97+
set(BERT_GENCODES "${BERT_GENCODES} -gencode arch=compute_${arch},code=sm_${arch}")
98+
endif()
99+
set(GENCODES "${GENCODES} -gencode arch=compute_${arch},code=sm_${arch}")
100+
endforeach()
101+
# Generate PTX for the last architecture in the list.
102+
list(GET GPU_ARCHS -1 LATEST_SM)
103+
set(GENCODES "${GENCODES} -gencode arch=compute_${LATEST_SM},code=compute_${LATEST_SM}")
104+
if (${LATEST_SM} GREATER_EQUAL 70)
105+
set(BERT_GENCODES "${BERT_GENCODES} -gencode arch=compute_${LATEST_SM},code=compute_${LATEST_SM}")
106+
endif()
80107
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wno-deprecated-declarations")
81108

82109
################################### DEPENDENCIES ##########################################

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,12 @@ NOTE: Along with the TensorRT OSS components, the following source packages will
174174

175175
- `CUB_VERSION`: The version of CUB to use, for example [`1.8.0`].
176176

177-
- `GPU_ARCHS`: GPU (SM) architectures to target. By default we generate CUDA code for the latest SM version. If lower SM versions are desired, they can be specified here as a comma separated list. Table of compute capabilities of NVIDIA GPUs can be found [here](https://developer.nvidia.com/cuda-gpus). Examples:
177+
- `GPU_ARCHS`: GPU (SM) architectures to target. By default we generate CUDA code for all major SMs. Specific SM versions can be specified here as a quoted space-separated list to reduce compilation time and binary size. Table of compute capabilities of NVIDIA GPUs can be found [here](https://developer.nvidia.com/cuda-gpus). Examples:
178178
- Titan V: `-DGPU_ARCHS="70"`
179179
- Tesla V100: `-DGPU_ARCHS="70"`
180180
- GeForce RTX 2080: `-DGPU_ARCHS="75"`
181181
- Tesla T4: `-DGPU_ARCHS="75"`
182+
- Multiple SMs: `-DGPU_ARCHS="70 75"`
182183

183184
## Install the TensorRT OSS Components [Optional]
184185

plugin/CMakeLists.txt

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,10 @@ if(${CMAKE_BUILD_TYPE} MATCHES "Debug")
2626
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g")
2727
endif()
2828

29-
set(PLUGIN_SRCS)
3029
set(PLUGIN_SOURCES)
31-
set(CUDA_SRCS)
32-
set(COMMON_SRCS)
30+
set(PLUGIN_CU_SOURCES)
3331

3432
set(PLUGIN_LISTS
35-
embLayerNormPlugin
36-
fcPlugin
37-
geluPlugin
38-
bertQKVToContextPlugin
39-
skipLayerNormPlugin
4033
nmsPlugin
4134
normalizePlugin
4235
priorBoxPlugin
@@ -57,6 +50,19 @@ set(PLUGIN_LISTS
5750
instanceNormalizationPlugin
5851
)
5952

53+
# Add BERT sources if ${BERT_GENCODES} was populated
54+
if(BERT_GENCODES)
55+
set(BERT_CU_SOURCES)
56+
set(PLUGIN_LISTS
57+
${PLUGIN_LISTS}
58+
embLayerNormPlugin
59+
fcPlugin
60+
geluPlugin
61+
bertQKVToContextPlugin
62+
skipLayerNormPlugin
63+
)
64+
endif()
65+
6066
include_directories(common common/kernels ../samples/common)
6167

6268
foreach(PLUGIN_ITER ${PLUGIN_LISTS})
@@ -67,6 +73,14 @@ endforeach(PLUGIN_ITER)
6773
# Add common
6874
add_subdirectory(common)
6975

76+
# Set gencodes
77+
set_source_files_properties(${PLUGIN_CU_SOURCES} PROPERTIES COMPILE_FLAGS ${GENCODES})
78+
list(APPEND PLUGIN_SOURCES "${PLUGIN_CU_SOURCES}")
79+
if (BERT_CU_SOURCES)
80+
set_source_files_properties(${BERT_CU_SOURCES} PROPERTIES COMPILE_FLAGS ${BERT_GENCODES})
81+
list(APPEND PLUGIN_SOURCES "${BERT_CU_SOURCES}")
82+
endif()
83+
7084
list(APPEND PLUGIN_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/InferPlugin.cpp")
7185
list(APPEND PLUGIN_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../samples/common/logger.cpp")
7286

plugin/batchTilePlugin/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
file(GLOB SRCS *.cpp *.cu)
16+
file(GLOB SRCS *.cpp)
1717
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
1818
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)

plugin/batchedNMSPlugin/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
file(GLOB SRCS *.cpp *.cu)
16+
file(GLOB SRCS *.cpp)
1717
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
1818
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
19+
file(GLOB CU_SRCS *.cu)
20+
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS})
21+
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)
1922

2023

plugin/bertQKVToContextPlugin/CMakeLists.txt

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
17-
string(FIND ${CMAKE_CUDA_FLAGS} "sm_7" POS_SM)
18-
string(FIND ${CMAKE_CUDA_FLAGS} "compute_7" POS_COMPUTE)
19-
20-
if(${POS_SM} GREATER_EQUAL 0 OR ${POS_COMPUTE} GREATER_EQUAL 0)
21-
file(GLOB SRCS *.cpp *.cu)
22-
23-
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
24-
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
25-
endif()
26-
27-
16+
file(GLOB CU_SRCS *.cu)
17+
set(BERT_CU_SOURCES ${BERT_CU_SOURCES} ${CU_SRCS})
18+
set(BERT_CU_SOURCES ${BERT_CU_SOURCES} PARENT_SCOPE)

plugin/bertQKVToContextPlugin/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,4 @@ This is the first release of this `README.md` file.
7272

7373
## Known issues
7474

75-
There are no known issues in this plugin.
75+
This plugin only supports GPUs with compute capability >= 7.0. For more information see the [CUDA GPU Compute Capability Support Matrix](https://developer.nvidia.com/cuda-gpus#compute)

plugin/common/CMakeLists.txt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
file(GLOB SRCS *.cpp *.cu)
17-
1816
add_subdirectory(kernels)
19-
17+
file(GLOB SRCS *.cpp)
2018
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
2119
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
20+
file(GLOB CU_SRCS *.cu)
21+
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS})
22+
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)
23+
2224

plugin/common/bertCommon.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ extern LogStreamConsumer gLogFatal;
3535

3636
void setReportableSeverity(Logger::Severity severity);
3737

38+
#define TRT_UNUSED (void)
39+
3840
#include <numeric>
3941
#include <vector>
4042

plugin/common/kernels/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
file(GLOB SRCS *.cpp *.cu)
16+
file(GLOB SRCS *.cpp)
1717
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
1818
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
19+
file(GLOB CU_SRCS *.cu)
20+
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS})
21+
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)

0 commit comments

Comments
 (0)