diff --git a/CHANGELOG.md b/CHANGELOG.md
index a67bb266..c81d2225 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,38 @@
# TensorRT OSS Release Changelog
-## 10.3.0 GA - 2024-08-07
+## 10.4.0 GA - 2024-09-11
+Key Features and Updates:
+
+- Demo changes
+ - Added [Stable Cascade](demo/Diffusion) pipeline.
+ - Enabled INT8 and FP8 quantization for Stable Diffusion v1.5, v2.0 and v2.1 pipelines.
+ - Enabled FP8 quantization for Stable Diffusion XL pipeline.
+- Sample changes
+ - Added a new Python sample, `aliased_io_plugin`, which demonstrates how in-place updates to plugin inputs can be achieved through I/O aliasing.
+- Plugin changes
+ - Migrated the IPluginV2-descendent versions of the following plugins to newer versions that implement IPluginV3 (old version -> new version):
+ - scatterElementsPlugin (1->2)
+ - skipLayerNormPlugin (1->5, 2->6, 3->7, 4->8)
+ - embLayerNormPlugin (2->4, 3->5)
+ - bertQKVToContextPlugin (1->4, 2->5, 3->6)
+ - Note
+ - Each newer version preserves the attributes and I/O of the corresponding older plugin version.
+ - The older plugin versions are deprecated and will be removed in a future release; a sketch of pinning a specific plugin version appears at the end of this list.
+
+- Quickstart guide
+ - Updated deploy_to_triton guide and removed legacy APIs.
+ - Removed legacy TF-TRT code as the project is no longer supported.
+ - Removed quantization_tutorial, as pytorch_quantization has been deprecated. See https://github.com/NVIDIA/TensorRT-Model-Optimizer for the latest quantization support, and [Stable Diffusion XL (Base/Turbo) and Stable Diffusion 1.5 Quantization with Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/diffusers/quantization) for integration with TensorRT.
+- Parser changes
+ - Added support for tensor `axes` for `Pad` operations.
+ - Added support for `BlackmanWindow`, `HammingWindow`, and `HannWindow` operations.
+ - Improved error handling in `IParserRefitter`.
+ - Fixed kernel shape inference in multi-input convolutions.
+
+- Updated tooling
+ - polygraphy-extension-trtexec v0.0.9
+
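+A hedged sketch of pinning a plugin version explicitly, assuming the TensorRT 10 Python bindings' `get_creator` registry lookup (the name below is skipLayerNormPlugin's registered plugin name):
+
+```python
+import tensorrt as trt
+
+trt.init_libnvinfer_plugins(None, "")  # register the shipped plugins
+registry = trt.get_plugin_registry()
+# Request the IPluginV3-based version "5" explicitly; version "1" is the
+# deprecated IPluginV2-descendent equivalent.
+creator = registry.get_creator("CustomSkipLayerNormPluginDynamic", "5")
+```
+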
+## 10.3.0 GA - 2024-08-02
Key Features and Updates:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a1f072a5..2928f5ef 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -80,7 +80,7 @@ option(BUILD_PARSERS "Build TensorRT parsers" ON)
option(BUILD_SAMPLES "Build TensorRT samples" ON)
# C++14
-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
diff --git a/README.md b/README.md
index 4a120aab..cdd6bdbb 100644
--- a/README.md
+++ b/README.md
@@ -26,13 +26,13 @@ You can skip the **Build** section to enjoy TensorRT with Python.
To build the TensorRT-OSS components, you will first need the following software packages.
**TensorRT GA build**
-* TensorRT v10.3.0.26
+* TensorRT v10.4.0.26
* Available from direct download links listed below
**System Packages**
* [CUDA](https://developer.nvidia.com/cuda-toolkit)
* Recommended versions:
- * cuda-12.5.0 + cuDNN-8.9
+ * cuda-12.6.0 + cuDNN-8.9
* cuda-11.8.0 + cuDNN-8.9
* [GNU make](https://ftp.gnu.org/gnu/make/) >= v4.1
* [cmake](https://github.com/Kitware/CMake/releases) >= v3.13
@@ -73,25 +73,25 @@ To build the TensorRT-OSS components, you will first need the following software
If using the TensorRT OSS build container, TensorRT libraries are preinstalled under `/usr/lib/x86_64-linux-gnu` and you may skip this step.
Else download and extract the TensorRT GA build from [NVIDIA Developer Zone](https://developer.nvidia.com) with the direct links below:
- - [TensorRT 10.3.0.26 for CUDA 11.8, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz)
- - [TensorRT 10.3.0.26 for CUDA 12.5, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz)
- - [TensorRT 10.3.0.26 for CUDA 11.8, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/zip/TensorRT-10.3.0.26.Windows.win10.cuda-11.8.zip)
- - [TensorRT 10.3.0.26 for CUDA 12.5, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/zip/TensorRT-10.3.0.26.Windows.win10.cuda-12.5.zip)
+ - [TensorRT 10.4.0.26 for CUDA 11.8, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz)
+ - [TensorRT 10.4.0.26 for CUDA 12.6, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz)
+ - [TensorRT 10.4.0.26 for CUDA 11.8, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/zip/TensorRT-10.4.0.26.Windows.win10.cuda-11.8.zip)
+ - [TensorRT 10.4.0.26 for CUDA 12.6, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/zip/TensorRT-10.4.0.26.Windows.win10.cuda-12.6.zip)
- **Example: Ubuntu 20.04 on x86-64 with cuda-12.5**
+ **Example: Ubuntu 20.04 on x86-64 with cuda-12.6**
```bash
cd ~/Downloads
- tar -xvzf TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz
- export TRT_LIBPATH=`pwd`/TensorRT-10.3.0.26
+ tar -xvzf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz
+ export TRT_LIBPATH=`pwd`/TensorRT-10.4.0.26
```
- **Example: Windows on x86-64 with cuda-12.5**
+ **Example: Windows on x86-64 with cuda-12.6**
```powershell
- Expand-Archive -Path TensorRT-10.3.0.26.Windows.win10.cuda-12.5.zip
- $env:TRT_LIBPATH="$pwd\TensorRT-10.3.0.26\lib"
+ Expand-Archive -Path TensorRT-10.4.0.26.Windows.win10.cuda-12.6.zip
+ $env:TRT_LIBPATH="$pwd\TensorRT-10.4.0.26\lib"
```
## Setting Up The Build Environment
@@ -101,27 +101,27 @@ For Linux platforms, we recommend that you generate a docker container for build
1. #### Generate the TensorRT-OSS build container.
The TensorRT-OSS build container can be generated using the supplied Dockerfiles and build scripts. The build containers are configured for building TensorRT OSS out-of-the-box.
- **Example: Ubuntu 20.04 on x86-64 with cuda-12.5 (default)**
+ **Example: Ubuntu 20.04 on x86-64 with cuda-12.6 (default)**
```bash
- ./docker/build.sh --file docker/ubuntu-20.04.Dockerfile --tag tensorrt-ubuntu20.04-cuda12.5
+ ./docker/build.sh --file docker/ubuntu-20.04.Dockerfile --tag tensorrt-ubuntu20.04-cuda12.6
```
- **Example: Rockylinux8 on x86-64 with cuda-12.5**
+ **Example: Rockylinux8 on x86-64 with cuda-12.6**
```bash
- ./docker/build.sh --file docker/rockylinux8.Dockerfile --tag tensorrt-rockylinux8-cuda12.5
+ ./docker/build.sh --file docker/rockylinux8.Dockerfile --tag tensorrt-rockylinux8-cuda12.6
```
- **Example: Ubuntu 22.04 cross-compile for Jetson (aarch64) with cuda-12.5 (JetPack SDK)**
+ **Example: Ubuntu 22.04 cross-compile for Jetson (aarch64) with cuda-12.6 (JetPack SDK)**
```bash
- ./docker/build.sh --file docker/ubuntu-cross-aarch64.Dockerfile --tag tensorrt-jetpack-cuda12.5
+ ./docker/build.sh --file docker/ubuntu-cross-aarch64.Dockerfile --tag tensorrt-jetpack-cuda12.6
```
- **Example: Ubuntu 22.04 on aarch64 with cuda-12.5**
+ **Example: Ubuntu 22.04 on aarch64 with cuda-12.6**
```bash
- ./docker/build.sh --file docker/ubuntu-22.04-aarch64.Dockerfile --tag tensorrt-aarch64-ubuntu22.04-cuda12.5
+ ./docker/build.sh --file docker/ubuntu-22.04-aarch64.Dockerfile --tag tensorrt-aarch64-ubuntu22.04-cuda12.6
```
2. #### Launch the TensorRT-OSS build container.
**Example: Ubuntu 20.04 build container**
```bash
- ./docker/launch.sh --tag tensorrt-ubuntu20.04-cuda12.5 --gpus all
+ ./docker/launch.sh --tag tensorrt-ubuntu20.04-cuda12.6 --gpus all
```
> NOTE:
1. Use the `--tag` corresponding to the build container generated in Step 1.
@@ -132,38 +132,38 @@ For Linux platforms, we recommend that you generate a docker container for build
## Building TensorRT-OSS
* Generate Makefiles and build.
- **Example: Linux (x86-64) build with default cuda-12.5**
+ **Example: Linux (x86-64) build with default cuda-12.6**
```bash
cd $TRT_OSSPATH
mkdir -p build && cd build
cmake .. -DTRT_LIB_DIR=$TRT_LIBPATH -DTRT_OUT_DIR=`pwd`/out
make -j$(nproc)
```
- **Example: Linux (aarch64) build with default cuda-12.5**
+ **Example: Linux (aarch64) build with default cuda-12.6**
```bash
cd $TRT_OSSPATH
mkdir -p build && cd build
cmake .. -DTRT_LIB_DIR=$TRT_LIBPATH -DTRT_OUT_DIR=`pwd`/out -DCMAKE_TOOLCHAIN_FILE=$TRT_OSSPATH/cmake/toolchains/cmake_aarch64-native.toolchain
make -j$(nproc)
```
- **Example: Native build on Jetson (aarch64) with cuda-12.5**
+ **Example: Native build on Jetson (aarch64) with cuda-12.6**
```bash
cd $TRT_OSSPATH
mkdir -p build && cd build
- cmake .. -DTRT_LIB_DIR=$TRT_LIBPATH -DTRT_OUT_DIR=`pwd`/out -DTRT_PLATFORM_ID=aarch64 -DCUDA_VERSION=12.5
+ cmake .. -DTRT_LIB_DIR=$TRT_LIBPATH -DTRT_OUT_DIR=`pwd`/out -DTRT_PLATFORM_ID=aarch64 -DCUDA_VERSION=12.6
CC=/usr/bin/gcc make -j$(nproc)
```
> NOTE: C compiler must be explicitly specified via CC= for native aarch64 builds of protobuf.
- **Example: Ubuntu 22.04 Cross-Compile for Jetson (aarch64) with cuda-12.5 (JetPack)**
+ **Example: Ubuntu 22.04 Cross-Compile for Jetson (aarch64) with cuda-12.6 (JetPack)**
```bash
cd $TRT_OSSPATH
mkdir -p build && cd build
- cmake .. -DCMAKE_TOOLCHAIN_FILE=$TRT_OSSPATH/cmake/toolchains/cmake_aarch64.toolchain -DCUDA_VERSION=12.5 -DCUDNN_LIB=/pdk_files/cudnn/usr/lib/aarch64-linux-gnu/libcudnn.so -DCUBLAS_LIB=/usr/local/cuda-12.5/targets/aarch64-linux/lib/stubs/libcublas.so -DCUBLASLT_LIB=/usr/local/cuda-12.5/targets/aarch64-linux/lib/stubs/libcublasLt.so -DTRT_LIB_DIR=/pdk_files/tensorrt/lib
+ cmake .. -DCMAKE_TOOLCHAIN_FILE=$TRT_OSSPATH/cmake/toolchains/cmake_aarch64.toolchain -DCUDA_VERSION=12.6 -DCUDNN_LIB=/pdk_files/cudnn/usr/lib/aarch64-linux-gnu/libcudnn.so -DCUBLAS_LIB=/usr/local/cuda-12.6/targets/aarch64-linux/lib/stubs/libcublas.so -DCUBLASLT_LIB=/usr/local/cuda-12.6/targets/aarch64-linux/lib/stubs/libcublasLt.so -DTRT_LIB_DIR=/pdk_files/tensorrt/lib
make -j$(nproc)
```
- **Example: Native builds on Windows (x86) with cuda-12.5**
+ **Example: Native builds on Windows (x86) with cuda-12.6**
```powershell
cd $TRT_OSSPATH
mkdir -p build
diff --git a/VERSION b/VERSION
index 92bc5b53..9de5818c 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-10.3.0.26
+10.4.0.26
diff --git a/demo/BERT/README.md b/demo/BERT/README.md
index b48bd8be..2aa91952 100755
--- a/demo/BERT/README.md
+++ b/demo/BERT/README.md
@@ -75,8 +75,8 @@ The following software version configuration has been tested:
|Software|Version|
|--------|-------|
|Python|>=3.8|
-|TensorRT|10.3.0.26|
-|CUDA|12.5|
+|TensorRT|10.4.0.26|
+|CUDA|12.6|
## Setup
@@ -430,245 +430,244 @@ The following sections provide details on how we achieved our performance and in
Results were obtained by running `scripts/inference_benchmark.sh --gpu Ampere` on NVIDIA A100 (40G).
-##### BERT Base
+##### BERT base
| Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | |
|-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------|
| | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average |
-| 128 | 1 | 0.53 | 0.68 | 0.54 | 0.79 | 0.79 | 0.64 |
-| 128 | 2 | 0.76 | 0.76 | 0.60 | 0.72 | 0.91 | 0.72 |
-| 128 | 4 | 0.73 | 0.92 | 0.73 | 1.03 | 1.04 | 0.93 |
-| 128 | 8 | 0.94 | 1.20 | 0.95 | 1.31 | 1.31 | 1.31 |
-| 128 | 12 | 1.19 | 1.20 | 1.19 | 1.72 | 1.73 | 1.72 |
-| 128 | 16 | 1.33 | 1.71 | 1.34 | 2.07 | 2.08 | 2.05 |
-| 128 | 24 | 1.82 | 1.82 | 1.81 | 3.04 | 3.07 | 3.01 |
-| 128 | 32 | 2.23 | 2.24 | 2.23 | 3.90 | 3.93 | 3.86 |
-| 128 | 64 | 4.15 | 4.17 | 4.12 | 7.62 | 7.70 | 7.57 |
-| 128 | 128 | 8.11 | 8.12 | 8.03 | 15.34 | 15.35 | 15.20 |
-| 384 | 1 | 1.13 | 1.45 | 1.13 | 1.24 | 1.25 | 1.24 |
-| 384 | 2 | 1.31 | 1.31 | 1.31 | 1.54 | 1.98 | 1.55 |
-| 384 | 4 | 1.66 | 1.66 | 1.66 | 2.12 | 2.12 | 2.12 |
-| 384 | 8 | 2.21 | 2.21 | 2.20 | 3.34 | 3.36 | 3.32 |
-| 384 | 12 | 3.32 | 3.32 | 3.31 | 4.78 | 4.82 | 4.77 |
-| 384 | 16 | 4.01 | 4.01 | 4.00 | 6.37 | 6.44 | 6.35 |
-| 384 | 24 | 5.71 | 5.71 | 5.70 | 9.47 | 9.49 | 9.39 |
-| 384 | 32 | 7.64 | 7.64 | 7.63 | 13.00 | 13.04 | 12.85 |
-| 384 | 64 | 14.87 | 14.88 | 14.73 | 25.12 | 25.14 | 24.78 |
-| 384 | 128 | 28.96 | 28.97 | 28.70 | 48.93 | 49.13 | 48.57 |
-
-##### BERT Large
+| 128 | 1 | 0.69 | 0.69 | 0.55 | 0.79 | 0.79 | 0.63 |
+| 128 | 2 | 0.60 | 0.76 | 0.60 | 0.72 | 0.91 | 0.72 |
+| 128 | 4 | 0.73 | 0.93 | 0.73 | 1.09 | 1.09 | 0.94 |
+| 128 | 8 | 1.21 | 1.21 | 0.95 | 1.31 | 1.31 | 1.30 |
+| 128 | 12 | 1.40 | 1.40 | 1.21 | 1.72 | 1.72 | 1.72 |
+| 128 | 16 | 1.34 | 1.71 | 1.34 | 2.08 | 2.08 | 2.06 |
+| 128 | 24 | 1.82 | 1.83 | 1.82 | 3.05 | 3.06 | 3.03 |
+| 128 | 32 | 2.23 | 2.24 | 2.23 | 3.95 | 3.99 | 3.91 |
+| 128 | 64 | 4.19 | 4.20 | 4.14 | 7.82 | 7.83 | 7.69 |
+| 128 | 128 | 8.14 | 8.19 | 8.09 | 15.37 | 15.42 | 15.32 |
+| 384 | 1 | 1.13 | 1.45 | 1.14 | 1.25 | 1.60 | 1.26 |
+| 384 | 2 | 1.32 | 1.69 | 1.32 | 1.55 | 1.98 | 1.55 |
+| 384 | 4 | 1.66 | 2.12 | 1.66 | 2.12 | 2.13 | 2.12 |
+| 384 | 8 | 2.21 | 2.21 | 2.20 | 3.37 | 3.40 | 3.33 |
+| 384 | 12 | 3.31 | 3.31 | 3.31 | 4.82 | 4.83 | 4.78 |
+| 384 | 16 | 4.00 | 4.00 | 4.00 | 6.38 | 6.43 | 6.37 |
+| 384 | 24 | 5.70 | 5.75 | 5.70 | 9.44 | 9.49 | 9.35 |
+| 384 | 32 | 7.72 | 7.74 | 7.66 | 13.02 | 13.02 | 12.91 |
+| 384 | 64 | 14.88 | 14.90 | 14.84 | 25.17 | 25.25 | 24.88 |
+| 384 | 128 | 29.00 | 29.01 | 28.83 | 49.03 | 49.22 | 48.77 |
+
+##### BERT large
| Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | |
|-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------|
| | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average |
-| 128 | 1 | 1.22 | 1.23 | 1.22 | 1.54 | 1.91 | 1.55 |
-| 128 | 2 | 1.42 | 1.42 | 1.41 | 1.82 | 1.82 | 1.82 |
-| 128 | 4 | 1.78 | 2.06 | 1.79 | 2.50 | 2.50 | 2.50 |
-| 128 | 8 | 2.64 | 2.64 | 2.64 | 3.98 | 3.98 | 3.98 |
-| 128 | 12 | 3.09 | 3.09 | 3.08 | 5.02 | 5.07 | 4.99 |
-| 128 | 16 | 4.09 | 4.09 | 4.08 | 6.93 | 6.94 | 6.86 |
-| 128 | 24 | 5.28 | 5.28 | 5.27 | 9.64 | 9.68 | 9.56 |
-| 128 | 32 | 7.01 | 7.01 | 6.95 | 12.92 | 13.07 | 12.85 |
-| 128 | 64 | 12.86 | 12.86 | 12.73 | 24.79 | 25.07 | 24.59 |
-| 128 | 128 | 25.03 | 25.26 | 24.99 | 49.12 | 49.28 | 48.83 |
-| 384 | 1 | 2.55 | 2.55 | 2.55 | 2.96 | 2.96 | 2.95 |
+| 128 | 1 | 1.24 | 1.24 | 1.24 | 1.54 | 1.55 | 1.54 |
+| 128 | 2 | 1.42 | 1.79 | 1.42 | 1.82 | 1.82 | 1.82 |
+| 128 | 4 | 1.78 | 1.79 | 1.78 | 2.53 | 2.53 | 2.52 |
+| 128 | 8 | 2.64 | 2.64 | 2.64 | 4.07 | 4.10 | 4.06 |
+| 128 | 12 | 3.11 | 3.12 | 3.11 | 5.08 | 5.10 | 5.03 |
+| 128 | 16 | 4.03 | 4.03 | 4.03 | 6.95 | 6.95 | 6.90 |
+| 128 | 24 | 5.32 | 5.34 | 5.30 | 9.80 | 9.90 | 9.72 |
+| 128 | 32 | 7.07 | 7.07 | 7.00 | 13.08 | 13.08 | 12.93 |
+| 128 | 64 | 12.94 | 13.01 | 12.82 | 24.83 | 24.99 | 24.69 |
+| 128 | 128 | 25.29 | 25.29 | 25.09 | 49.70 | 49.72 | 49.06 |
+| 384 | 1 | 2.55 | 2.56 | 2.55 | 2.96 | 2.96 | 2.96 |
| 384 | 2 | 3.04 | 3.04 | 3.03 | 3.90 | 3.90 | 3.90 |
-| 384 | 4 | 4.01 | 4.02 | 4.01 | 5.68 | 5.74 | 5.67 |
-| 384 | 8 | 7.18 | 7.18 | 7.17 | 11.13 | 11.13 | 11.01 |
-| 384 | 12 | 9.14 | 9.15 | 9.13 | 15.43 | 15.44 | 15.32 |
-| 384 | 16 | 12.28 | 12.28 | 12.27 | 21.14 | 21.15 | 20.90 |
-| 384 | 24 | 17.68 | 17.68 | 17.54 | 30.98 | 31.02 | 30.68 |
-| 384 | 32 | 23.24 | 23.24 | 23.02 | 41.11 | 41.20 | 40.58 |
-| 384 | 64 | 44.86 | 45.13 | 44.78 | 79.25 | 79.68 | 79.10 |
-| 384 | 128 | 87.82 | 87.84 | 87.69 | 156.70 | 157.02 | 155.61 |
+| 384 | 4 | 4.01 | 4.01 | 4.01 | 5.74 | 5.79 | 5.71 |
+| 384 | 8 | 7.16 | 7.16 | 7.15 | 11.15 | 11.24 | 11.09 |
+| 384 | 12 | 9.15 | 9.23 | 9.14 | 15.46 | 15.47 | 15.40 |
+| 384 | 16 | 12.40 | 12.40 | 12.29 | 21.17 | 21.18 | 21.05 |
+| 384 | 24 | 17.72 | 17.85 | 17.64 | 31.09 | 31.36 | 30.81 |
+| 384 | 32 | 23.29 | 23.31 | 23.15 | 41.32 | 41.34 | 40.86 |
+| 384 | 64 | 45.38 | 45.40 | 45.02 | 79.95 | 80.27 | 79.31 |
+| 384 | 128 | 87.97 | 87.99 | 87.89 | 156.97 | 157.56 | 155.84 |
##### Megatron Large with Sparsity
| Sequence Length | Batch Size | INT8 QAT Latency (ms) | | |
|-----------------|------------|-----------------|-----------------|---------|
| | | 95th Percentile | 99th Percentile | Average |
-| 128 | 1 | 1.11 | 1.40 | 1.11 |
-| 128 | 2 | 1.33 | 1.33 | 1.33 |
-| 128 | 4 | 1.78 | 1.78 | 1.78 |
-| 128 | 8 | 2.54 | 2.54 | 2.53 |
-| 128 | 12 | 2.97 | 2.97 | 2.97 |
-| 128 | 16 | 3.99 | 3.99 | 3.98 |
-| 128 | 24 | 4.91 | 4.91 | 4.90 |
-| 128 | 32 | 7.13 | 7.13 | 7.12 |
-| 128 | 64 | 11.61 | 11.62 | 11.60 |
-| 128 | 128 | 21.22 | 21.32 | 21.09 |
-| 384 | 1 | 1.71 | 2.15 | 1.71 |
-| 384 | 2 | 2.21 | 2.21 | 2.21 |
-| 384 | 4 | 3.47 | 3.48 | 3.47 |
-| 384 | 8 | 5.74 | 5.74 | 5.74 |
-| 384 | 12 | 8.21 | 8.21 | 8.20 |
-| 384 | 16 | 10.33 | 10.34 | 10.32 |
-| 384 | 24 | 14.68 | 14.69 | 14.67 |
-| 384 | 32 | 18.73 | 18.74 | 18.72 |
-| 384 | 64 | 35.77 | 35.78 | 35.49 |
-| 384 | 128 | 67.78 | 67.95 | 67.63 |
+| 128 | 1 | 1.24 | 1.56 | 1.24 |
+| 128 | 2 | 1.42 | 1.42 | 1.42 |
+| 128 | 4 | 1.78 | 1.79 | 1.78 |
+| 128 | 8 | 2.64 | 2.65 | 2.64 |
+| 128 | 12 | 3.11 | 3.12 | 3.11 |
+| 128 | 16 | 4.03 | 4.03 | 4.02 |
+| 128 | 24 | 5.32 | 5.34 | 5.31 |
+| 128 | 32 | 7.07 | 7.09 | 7.02 |
+| 128 | 64 | 12.98 | 13.01 | 12.86 |
+| 128 | 128 | 25.40 | 25.55 | 25.17 |
+| 384 | 1 | 2.55 | 2.55 | 2.55 |
+| 384 | 2 | 3.03 | 3.04 | 3.03 |
+| 384 | 4 | 4.01 | 4.01 | 4.01 |
+| 384 | 8 | 7.16 | 7.16 | 7.16 |
+| 384 | 12 | 9.14 | 9.23 | 9.14 |
+| 384 | 16 | 12.31 | 12.41 | 12.29 |
+| 384 | 24 | 17.85 | 17.90 | 17.68 |
+| 384 | 32 | 23.41 | 23.51 | 23.23 |
+| 384 | 64 | 45.39 | 45.40 | 45.09 |
+| 384 | 128 | 88.73 | 88.79 | 88.11 |
### Inference Performance NVIDIA L4
Results were obtained by running `scripts/inference_benchmark.sh --gpu Ampere` on NVIDIA L4.
-##### BERT Base
+##### BERT base
| Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | |
|-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------|
| | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average |
-| 128 | 1 | 0.61 | 0.61 | 0.60 | 1.01 | 1.01 | 1.00 |
-| 128 | 2 | 0.79 | 0.80 | 0.77 | 1.32 | 1.35 | 1.31 |
-| 128 | 4 | 1.14 | 1.15 | 1.12 | 2.22 | 2.23 | 2.14 |
-| 128 | 8 | 1.94 | 1.96 | 1.90 | 3.66 | 3.67 | 3.63 |
-| 128 | 12 | 2.67 | 2.67 | 2.61 | 5.34 | 5.34 | 5.26 |
-| 128 | 16 | 3.37 | 3.38 | 3.32 | 6.69 | 6.69 | 6.64 |
-| 128 | 24 | 4.84 | 4.84 | 4.75 | 10.53 | 10.64 | 10.50 |
-| 128 | 32 | 6.21 | 6.28 | 6.13 | 13.91 | 13.91 | 13.72 |
-| 128 | 64 | 13.40 | 13.60 | 13.20 | 31.48 | 31.53 | 31.01 |
-| 128 | 128 | 28.42 | 28.68 | 27.84 | 70.60 | 71.10 | 69.25 |
-| 384 | 1 | 1.27 | 1.27 | 1.27 | 2.08 | 2.09 | 2.07 |
-| 384 | 2 | 1.84 | 1.84 | 1.82 | 3.15 | 3.19 | 3.11 |
-| 384 | 4 | 2.94 | 2.94 | 2.91 | 5.68 | 5.75 | 5.63 |
-| 384 | 8 | 5.53 | 5.55 | 5.42 | 11.45 | 11.59 | 11.32 |
-| 384 | 12 | 8.21 | 8.31 | 8.07 | 17.16 | 17.36 | 17.00 |
-| 384 | 16 | 10.96 | 11.07 | 10.80 | 23.20 | 23.50 | 22.81 |
-| 384 | 24 | 16.71 | 16.74 | 16.55 | 39.82 | 40.46 | 38.15 |
-| 384 | 32 | 22.82 | 23.00 | 22.63 | 50.56 | 50.89 | 50.14 |
-| 384 | 64 | 49.66 | 50.18 | 48.40 | 104.90 | 105.55 | 103.81 |
-| 384 | 128 | 104.78 | 105.09 | 103.96 | 208.20 | 208.70 | 206.93 |
-
-##### BERT Large
+| 128 | 1 | 0.62 | 0.62 | 0.60 | 1.03 | 1.03 | 1.01 |
+| 128 | 2 | 0.79 | 0.80 | 0.77 | 1.31 | 1.35 | 1.30 |
+| 128 | 4 | 1.14 | 1.15 | 1.12 | 2.23 | 2.23 | 2.15 |
+| 128 | 8 | 1.97 | 1.97 | 1.92 | 3.68 | 3.69 | 3.63 |
+| 128 | 12 | 2.66 | 2.67 | 2.61 | 5.34 | 5.35 | 5.27 |
+| 128 | 16 | 3.39 | 3.39 | 3.34 | 6.62 | 6.69 | 6.58 |
+| 128 | 24 | 4.84 | 4.85 | 4.76 | 10.49 | 10.55 | 10.32 |
+| 128 | 32 | 6.20 | 6.29 | 6.14 | 13.92 | 13.92 | 13.75 |
+| 128 | 64 | 13.42 | 13.42 | 13.26 | 31.28 | 31.48 | 31.07 |
+| 128 | 128 | 28.48 | 28.64 | 28.19 | 66.10 | 66.23 | 65.36 |
+| 384 | 1 | 1.29 | 1.30 | 1.29 | 2.08 | 2.09 | 2.08 |
+| 384 | 2 | 1.83 | 1.84 | 1.82 | 3.15 | 3.19 | 3.11 |
+| 384 | 4 | 2.99 | 2.99 | 2.92 | 5.75 | 5.81 | 5.68 |
+| 384 | 8 | 5.53 | 5.54 | 5.42 | 11.28 | 11.33 | 11.08 |
+| 384 | 12 | 8.26 | 8.29 | 8.09 | 17.19 | 17.22 | 16.99 |
+| 384 | 16 | 11.00 | 11.08 | 10.85 | 23.38 | 23.38 | 22.90 |
+| 384 | 24 | 16.79 | 16.89 | 16.60 | 37.90 | 38.29 | 37.18 |
+| 384 | 32 | 23.08 | 23.31 | 22.74 | 50.70 | 50.94 | 50.27 |
+| 384 | 64 | 49.43 | 49.86 | 48.56 | 103.88 | 104.19 | 102.89 |
+| 384 | 128 | 104.55 | 104.97 | 103.74 | 211.09 | 211.67 | 209.85 |
+
+##### BERT large
| Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | |
|-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------|
| | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average |
-| 128 | 1 | 1.79 | 1.80 | 1.77 | 3.11 | 3.11 | 3.09 |
-| 128 | 2 | 2.49 | 2.49 | 2.43 | 4.35 | 4.37 | 4.33 |
-| 128 | 4 | 3.62 | 3.70 | 3.60 | 6.86 | 6.89 | 6.78 |
-| 128 | 8 | 6.26 | 6.31 | 6.24 | 12.85 | 12.91 | 12.73 |
-| 128 | 12 | 8.40 | 8.41 | 8.28 | 18.42 | 18.43 | 18.33 |
-| 128 | 16 | 11.23 | 11.24 | 11.12 | 25.18 | 25.19 | 25.10 |
-| 128 | 24 | 15.95 | 16.09 | 15.90 | 35.67 | 35.67 | 35.47 |
-| 128 | 32 | 21.26 | 21.31 | 20.91 | 48.92 | 49.21 | 48.26 |
-| 128 | 64 | 44.10 | 44.11 | 43.92 | 108.81 | 109.12 | 107.18 |
-| 128 | 128 | 94.22 | 95.02 | 92.65 | 217.32 | 219.58 | 212.68 |
-| 384 | 1 | 3.41 | 3.43 | 3.39 | 6.55 | 6.57 | 6.36 |
-| 384 | 2 | 5.55 | 5.56 | 5.46 | 10.34 | 10.35 | 10.18 |
-| 384 | 4 | 9.69 | 9.79 | 9.53 | 20.66 | 20.95 | 19.94 |
-| 384 | 8 | 18.08 | 18.19 | 17.92 | 38.41 | 39.30 | 37.62 |
-| 384 | 12 | 26.20 | 26.44 | 26.11 | 60.38 | 60.91 | 58.67 |
-| 384 | 16 | 36.33 | 36.41 | 36.02 | 81.66 | 82.16 | 80.52 |
-| 384 | 24 | 53.54 | 53.61 | 53.08 | 123.01 | 123.34 | 122.10 |
-| 384 | 32 | 75.01 | 75.43 | 74.40 | 170.40 | 171.03 | 169.12 |
-| 384 | 64 | 157.97 | 158.62 | 155.87 | 349.25 | 351.53 | 344.76 |
-| 384 | 128 | 330.88 | 331.87 | 328.27 | 632.85 | 633.88 | 629.74 |
+| 128 | 1 | 1.78 | 1.79 | 1.76 | 3.11 | 3.11 | 3.10 |
+| 128 | 2 | 2.50 | 2.51 | 2.44 | 4.35 | 4.45 | 4.31 |
+| 128 | 4 | 3.60 | 3.63 | 3.54 | 6.83 | 6.86 | 6.77 |
+| 128 | 8 | 6.27 | 6.31 | 6.25 | 12.98 | 13.01 | 12.80 |
+| 128 | 12 | 8.40 | 8.41 | 8.27 | 18.45 | 18.66 | 18.22 |
+| 128 | 16 | 11.22 | 11.23 | 11.12 | 25.18 | 25.19 | 25.14 |
+| 128 | 24 | 15.95 | 16.10 | 15.82 | 35.67 | 35.68 | 35.59 |
+| 128 | 32 | 21.30 | 21.35 | 20.90 | 49.02 | 49.26 | 48.33 |
+| 128 | 64 | 44.08 | 44.32 | 43.93 | 107.89 | 108.30 | 107.11 |
+| 128 | 128 | 93.69 | 94.36 | 92.69 | 215.00 | 215.46 | 213.84 |
+| 384 | 1 | 3.43 | 3.44 | 3.41 | 6.58 | 6.66 | 6.40 |
+| 384 | 2 | 5.55 | 5.55 | 5.49 | 10.56 | 10.59 | 10.44 |
+| 384 | 4 | 9.80 | 9.88 | 9.58 | 20.55 | 20.94 | 19.93 |
+| 384 | 8 | 18.04 | 18.11 | 17.86 | 38.87 | 39.47 | 37.69 |
+| 384 | 12 | 26.44 | 26.61 | 26.14 | 59.28 | 59.85 | 56.90 |
+| 384 | 16 | 36.37 | 36.48 | 36.04 | 82.93 | 83.33 | 81.95 |
+| 384 | 24 | 53.60 | 53.73 | 53.15 | 122.78 | 123.06 | 122.05 |
+| 384 | 32 | 75.52 | 75.84 | 74.45 | 164.55 | 164.98 | 163.68 |
+| 384 | 64 | 157.71 | 158.27 | 155.68 | 345.90 | 346.53 | 344.57 |
+| 384 | 128 | 331.37 | 332.44 | 329.06 | 663.75 | 664.69 | 661.89 |
##### Megatron Large with Sparsity
| Sequence Length | Batch Size | INT8 QAT Latency (ms) | | |
|-----------------|------------|-----------------|-----------------|---------|
| | | 95th Percentile | 99th Percentile | Average |
-| 128 | 1 | 1.49 | 1.49 | 1.48 |
-| 128 | 2 | 2.03 | 2.03 | 1.99 |
-| 128 | 4 | 2.99 | 3.00 | 2.93 |
-| 128 | 8 | 5.00 | 5.07 | 4.99 |
-| 128 | 12 | 6.69 | 6.72 | 6.58 |
-| 128 | 16 | 8.77 | 8.84 | 8.66 |
-| 128 | 24 | 13.28 | 13.30 | 13.14 |
-| 128 | 32 | 17.41 | 17.44 | 17.26 |
-| 128 | 64 | 35.73 | 36.07 | 35.49 |
-| 128 | 128 | 79.03 | 79.15 | 78.47 |
-| 384 | 1 | 2.78 | 2.79 | 2.72 |
-| 384 | 2 | 4.10 | 4.12 | 4.06 |
-| 384 | 4 | 7.57 | 7.58 | 7.45 |
-| 384 | 8 | 15.03 | 15.10 | 14.86 |
-| 384 | 12 | 21.52 | 21.69 | 21.31 |
-| 384 | 16 | 28.29 | 28.33 | 28.10 |
-| 384 | 24 | 46.83 | 47.09 | 46.29 |
-| 384 | 32 | 60.29 | 60.47 | 59.37 |
-| 384 | 64 | 125.58 | 125.64 | 125.24 |
-| 384 | 128 | 253.46 | 253.90 | 252.28 |
+| 128 | 1 | 1.78 | 1.79 | 1.76 |
+| 128 | 2 | 2.50 | 2.51 | 2.44 |
+| 128 | 4 | 3.56 | 3.57 | 3.54 |
+| 128 | 8 | 6.27 | 6.31 | 6.26 |
+| 128 | 12 | 8.40 | 8.41 | 8.29 |
+| 128 | 16 | 11.23 | 11.23 | 11.16 |
+| 128 | 24 | 16.06 | 16.12 | 15.90 |
+| 128 | 32 | 21.31 | 21.34 | 20.98 |
+| 128 | 64 | 44.15 | 44.66 | 43.88 |
+| 128 | 128 | 94.19 | 94.93 | 92.81 |
+| 384 | 1 | 3.39 | 3.43 | 3.37 |
+| 384 | 2 | 5.56 | 5.56 | 5.48 |
+| 384 | 4 | 9.81 | 9.90 | 9.61 |
+| 384 | 8 | 18.07 | 18.25 | 17.94 |
+| 384 | 12 | 26.47 | 26.57 | 26.27 |
+| 384 | 16 | 36.78 | 37.14 | 36.37 |
+| 384 | 24 | 54.16 | 54.53 | 53.65 |
+| 384 | 32 | 75.33 | 75.62 | 74.69 |
+| 384 | 64 | 158.72 | 159.55 | 156.72 |
+| 384 | 128 | 333.24 | 334.26 | 330.67 |
### Inference Performance NVIDIA L40S
Results were obtained by running `scripts/inference_benchmark.sh --gpu Ampere` on NVIDIA L40S.
-##### BERT Base
+##### BERT base
| Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | |
|-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------|
| | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average |
-| 128 | 1 | 0.33 | 0.33 | 0.33 | 0.48 | 0.48 | 0.48 |
-| 128 | 2 | 0.41 | 0.41 | 0.41 | 0.57 | 0.57 | 0.57 |
-| 128 | 4 | 0.50 | 0.51 | 0.50 | 0.78 | 0.78 | 0.78 |
-| 128 | 8 | 0.67 | 0.67 | 0.67 | 1.33 | 1.33 | 1.32 |
-| 128 | 12 | 0.91 | 0.91 | 0.91 | 1.75 | 1.76 | 1.73 |
-| 128 | 16 | 1.10 | 1.10 | 1.09 | 2.29 | 2.29 | 2.28 |
-| 128 | 24 | 1.48 | 1.49 | 1.47 | 3.30 | 3.31 | 3.27 |
-| 128 | 32 | 1.84 | 1.84 | 1.83 | 3.98 | 3.99 | 3.97 |
-| 128 | 64 | 3.61 | 3.66 | 3.56 | 8.64 | 8.70 | 8.51 |
-| 128 | 128 | 7.92 | 7.99 | 7.82 | 18.78 | 18.82 | 18.45 |
-| 384 | 1 | 0.73 | 0.73 | 0.73 | 1.11 | 1.12 | 1.10 |
-| 384 | 2 | 0.88 | 0.88 | 0.88 | 1.39 | 1.39 | 1.38 |
-| 384 | 4 | 1.17 | 1.17 | 1.17 | 2.19 | 2.20 | 2.19 |
-| 384 | 8 | 1.74 | 1.74 | 1.73 | 3.53 | 3.53 | 3.50 |
-| 384 | 12 | 2.75 | 2.75 | 2.73 | 5.32 | 5.33 | 5.29 |
-| 384 | 16 | 3.33 | 3.33 | 3.31 | 7.62 | 7.64 | 7.57 |
-| 384 | 24 | 4.97 | 4.98 | 4.95 | 10.53 | 10.57 | 10.40 |
-| 384 | 32 | 6.55 | 6.57 | 6.48 | 14.36 | 14.47 | 14.20 |
-| 384 | 64 | 14.27 | 14.37 | 14.07 | 33.31 | 33.51 | 32.65 |
-| 384 | 128 | 30.38 | 30.52 | 29.73 | 67.34 | 68.04 | 66.06 |
-
-##### BERT Large
+| 128 | 1 | 0.34 | 0.34 | 0.34 | 0.48 | 0.48 | 0.48 |
+| 128 | 2 | 0.41 | 0.41 | 0.41 | 0.56 | 0.56 | 0.55 |
+| 128 | 4 | 0.50 | 0.50 | 0.50 | 0.77 | 0.77 | 0.77 |
+| 128 | 8 | 0.67 | 0.67 | 0.67 | 1.30 | 1.30 | 1.29 |
+| 128 | 12 | 0.91 | 0.91 | 0.91 | 1.68 | 1.68 | 1.67 |
+| 128 | 16 | 1.09 | 1.10 | 1.09 | 2.22 | 2.23 | 2.22 |
+| 128 | 24 | 1.50 | 1.50 | 1.48 | 3.23 | 3.24 | 3.20 |
+| 128 | 32 | 1.82 | 1.83 | 1.82 | 3.94 | 3.94 | 3.93 |
+| 128 | 64 | 3.47 | 3.47 | 3.45 | 8.24 | 8.26 | 8.14 |
+| 128 | 128 | 7.74 | 7.91 | 7.66 | 17.73 | 17.86 | 17.56 |
+| 384 | 1 | 0.73 | 0.73 | 0.73 | 1.02 | 1.02 | 1.02 |
+| 384 | 2 | 0.88 | 0.89 | 0.88 | 1.38 | 1.38 | 1.37 |
+| 384 | 4 | 1.17 | 1.17 | 1.16 | 2.16 | 2.17 | 2.15 |
+| 384 | 8 | 1.72 | 1.73 | 1.72 | 3.45 | 3.46 | 3.45 |
+| 384 | 12 | 2.73 | 2.73 | 2.72 | 5.07 | 5.07 | 5.05 |
+| 384 | 16 | 3.28 | 3.28 | 3.27 | 7.41 | 7.44 | 7.37 |
+| 384 | 24 | 4.93 | 4.94 | 4.90 | 10.16 | 10.19 | 10.09 |
+| 384 | 32 | 6.33 | 6.34 | 6.29 | 14.07 | 14.11 | 13.96 |
+| 384 | 64 | 13.74 | 13.76 | 13.57 | 30.65 | 30.82 | 30.14 |
+| 384 | 128 | 28.25 | 28.41 | 27.87 | 62.48 | 62.67 | 61.70 |
+
+##### BERT large
| Sequence Length | Batch Size | INT8 Latency (ms) | | | FP16 Latency (ms) | | |
|-----------------|------------|-----------------|-----------------|---------|-----------------|-----------------|---------|
| | | 95th Percentile | 99th Percentile | Average | 95th Percentile | 99th Percentile | Average |
-| 128 | 1 | 0.89 | 0.89 | 0.88 | 1.30 | 1.30 | 1.29 |
-| 128 | 2 | 0.97 | 0.98 | 0.97 | 1.45 | 1.45 | 1.44 |
-| 128 | 4 | 1.36 | 1.36 | 1.35 | 2.30 | 2.30 | 2.29 |
-| 128 | 8 | 1.94 | 1.96 | 1.93 | 3.89 | 3.90 | 3.88 |
-| 128 | 12 | 2.82 | 2.82 | 2.80 | 5.89 | 5.90 | 5.85 |
-| 128 | 16 | 3.26 | 3.27 | 3.24 | 6.85 | 6.86 | 6.80 |
-| 128 | 24 | 4.62 | 4.63 | 4.59 | 10.72 | 10.73 | 10.64 |
-| 128 | 32 | 5.74 | 5.76 | 5.70 | 13.22 | 13.23 | 13.04 |
-| 128 | 64 | 12.18 | 12.20 | 11.97 | 29.42 | 29.59 | 28.89 |
-| 128 | 128 | 26.68 | 26.86 | 26.23 | 68.72 | 69.05 | 67.12 |
-| 384 | 1 | 1.68 | 1.68 | 1.68 | 2.78 | 2.78 | 2.77 |
-| 384 | 2 | 2.31 | 2.31 | 2.30 | 3.95 | 3.95 | 3.94 |
-| 384 | 4 | 3.29 | 3.30 | 3.29 | 6.57 | 6.58 | 6.50 |
-| 384 | 8 | 5.16 | 5.17 | 5.13 | 10.89 | 10.90 | 10.79 |
-| 384 | 12 | 8.16 | 8.17 | 8.10 | 19.81 | 19.91 | 19.31 |
-| 384 | 16 | 9.90 | 9.93 | 9.80 | 23.34 | 23.51 | 23.10 |
-| 384 | 24 | 15.60 | 15.62 | 15.39 | 37.37 | 37.48 | 36.93 |
-| 384 | 32 | 20.66 | 20.73 | 20.33 | 50.13 | 50.34 | 49.52 |
-| 384 | 64 | 46.31 | 46.53 | 45.39 | 111.74 | 111.98 | 110.14 |
-| 384 | 128 | 93.80 | 94.04 | 92.33 | 213.05 | 214.15 | 210.25 |
+| 128 | 1 | 0.89 | 0.89 | 0.88 | 1.30 | 1.30 | 1.30 |
+| 128 | 2 | 0.98 | 0.98 | 0.98 | 1.47 | 1.47 | 1.47 |
+| 128 | 4 | 1.34 | 1.35 | 1.33 | 2.29 | 2.29 | 2.28 |
+| 128 | 8 | 1.92 | 1.92 | 1.91 | 3.82 | 3.83 | 3.79 |
+| 128 | 12 | 2.77 | 2.78 | 2.75 | 5.73 | 5.73 | 5.71 |
+| 128 | 16 | 3.22 | 3.22 | 3.19 | 6.72 | 6.73 | 6.67 |
+| 128 | 24 | 4.54 | 4.55 | 4.52 | 10.38 | 10.39 | 10.31 |
+| 128 | 32 | 5.67 | 5.68 | 5.63 | 12.87 | 12.90 | 12.74 |
+| 128 | 64 | 11.92 | 11.96 | 11.77 | 28.21 | 28.40 | 27.89 |
+| 128 | 128 | 25.44 | 25.49 | 25.12 | 61.85 | 62.01 | 61.20 |
+| 384 | 1 | 1.68 | 1.68 | 1.68 | 2.74 | 2.75 | 2.74 |
+| 384 | 2 | 2.32 | 2.32 | 2.30 | 3.87 | 3.87 | 3.85 |
+| 384 | 4 | 3.27 | 3.28 | 3.27 | 6.32 | 6.35 | 6.29 |
+| 384 | 8 | 5.09 | 5.09 | 5.06 | 10.76 | 10.77 | 10.60 |
+| 384 | 12 | 8.06 | 8.07 | 8.02 | 18.73 | 18.77 | 18.64 |
+| 384 | 16 | 9.70 | 9.75 | 9.61 | 22.15 | 22.26 | 21.95 |
+| 384 | 24 | 15.02 | 15.04 | 14.88 | 35.43 | 35.48 | 35.15 |
+| 384 | 32 | 20.37 | 20.49 | 20.00 | 46.36 | 46.37 | 45.86 |
+| 384 | 64 | 43.50 | 43.65 | 43.02 | 105.84 | 106.08 | 104.92 |
+| 384 | 128 | 86.30 | 86.48 | 85.58 | 195.18 | 195.98 | 192.90 |
##### Megatron Large with Sparsity
| Sequence Length | Batch Size | INT8 QAT Latency (ms) | | |
|-----------------|------------|-----------------|-----------------|---------|
| | | 95th Percentile | 99th Percentile | Average |
-| 128 | 1 | 0.76 | 0.76 | 0.76 |
-| 128 | 2 | 0.91 | 0.91 | 0.91 |
-| 128 | 4 | 1.13 | 1.13 | 1.13 |
-| 128 | 8 | 1.70 | 1.70 | 1.70 |
-| 128 | 12 | 2.26 | 2.26 | 2.25 |
-| 128 | 16 | 2.72 | 2.72 | 2.71 |
-| 128 | 24 | 4.54 | 4.55 | 4.52 |
-| 128 | 32 | 5.14 | 5.16 | 5.10 |
-| 128 | 64 | 10.07 | 10.08 | 10.01 |
-| 128 | 128 | 21.57 | 21.67 | 21.21 |
-| 384 | 1 | 1.13 | 1.13 | 1.13 |
-| 384 | 2 | 1.64 | 1.65 | 1.62 |
-| 384 | 4 | 2.51 | 2.51 | 2.50 |
-| 384 | 8 | 5.02 | 5.03 | 4.99 |
-| 384 | 12 | 6.43 | 6.43 | 6.41 |
-| 384 | 16 | 8.47 | 8.49 | 8.41 |
-| 384 | 24 | 12.62 | 12.65 | 12.54 |
-| 384 | 32 | 16.88 | 16.91 | 16.74 |
-| 384 | 64 | 36.62 | 36.71 | 36.12 |
-| 384 | 128 | 79.88 | 80.18 | 77.33 |
-
+| 128 | 1 | 0.89 | 0.89 | 0.88 |
+| 128 | 2 | 0.98 | 0.98 | 0.98 |
+| 128 | 4 | 1.34 | 1.36 | 1.33 |
+| 128 | 8 | 1.93 | 1.95 | 1.91 |
+| 128 | 12 | 2.79 | 2.82 | 2.77 |
+| 128 | 16 | 3.24 | 3.24 | 3.22 |
+| 128 | 24 | 4.59 | 4.59 | 4.57 |
+| 128 | 32 | 5.68 | 5.68 | 5.65 |
+| 128 | 64 | 11.81 | 11.87 | 11.71 |
+| 128 | 128 | 26.21 | 26.24 | 25.86 |
+| 384 | 1 | 1.68 | 1.68 | 1.68 |
+| 384 | 2 | 2.31 | 2.32 | 2.31 |
+| 384 | 4 | 3.29 | 3.29 | 3.28 |
+| 384 | 8 | 5.14 | 5.15 | 5.10 |
+| 384 | 12 | 8.05 | 8.06 | 8.01 |
+| 384 | 16 | 9.78 | 9.80 | 9.66 |
+| 384 | 24 | 15.14 | 15.15 | 15.01 |
+| 384 | 32 | 20.34 | 20.42 | 19.99 |
+| 384 | 64 | 43.81 | 43.97 | 43.39 |
+| 384 | 128 | 88.37 | 88.64 | 87.38 |
\ No newline at end of file
diff --git a/demo/Diffusion/README.md b/demo/Diffusion/README.md
index 469f0b26..dd9aa50d 100755
--- a/demo/Diffusion/README.md
+++ b/demo/Diffusion/README.md
@@ -7,7 +7,7 @@ This demo application ("demoDiffusion") showcases the acceleration of Stable Dif
### Clone the TensorRT OSS repository
```bash
-git clone git@github.com:NVIDIA/TensorRT.git -b release/10.2 --single-branch
+git clone git@github.com:NVIDIA/TensorRT.git -b release/10.4 --single-branch
cd TensorRT
```
@@ -43,17 +43,17 @@ pip3 install -r requirements.txt
> NOTE: demoDiffusion has been tested on systems with NVIDIA H100, A100, L40, T4, and RTX4090 GPUs, and the following software configuration.
```
-diffusers 0.26.3
+diffusers 0.29.2
onnx 1.15.0
onnx-graphsurgeon 0.5.2
onnxruntime 1.16.3
polygraphy 0.49.9
-tensorrt 10.3.0.26
+tensorrt 10.4.0.26
tokenizers 0.13.3
torch 2.2.0
transformers 4.33.1
controlnet-aux 0.0.6
-nvidia-modelopt 0.11.2
+nvidia-modelopt 0.15.1
```
> NOTE: optionally install the HuggingFace [accelerate](https://pypi.org/project/accelerate/) package for faster and less memory-intensive model loading. Note that installing accelerate is known to cause failures while running certain pipelines in Torch Compile mode ([known issue](https://github.com/huggingface/diffusers/issues/9091))
@@ -84,6 +84,20 @@ export HF_TOKEN=
python3 demo_txt2img.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN
```
+### Faster Text-to-Image using SD1.5 or SD2.1 with INT8 & FP8 quantization from ModelOpt
+
+Run the command below to generate an image with SD1.5 or SD2.1 in INT8:
+
+```bash
+python3 demo_txt2img.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --int8
+```
+
+Run the command below to generate an image with SD1.5 or SD2.1 in FP8. (FP8 is only supported on Hopper.)
+
+```bash
+python3 demo_txt2img.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --fp8
+```
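+
+Since FP8 engines require Hopper, checking the GPU's compute capability before a lengthy calibration and engine build can save time. A minimal sketch using the demo's existing PyTorch dependency (SM 9.0 is assumed as the Hopper threshold):
+
+```python
+import torch
+
+# Hopper GPUs (e.g. H100) report compute capability 9.x.
+major, minor = torch.cuda.get_device_capability()
+if (major, minor) < (9, 0):
+    raise RuntimeError(f"--fp8 requires a Hopper (SM90) GPU; found SM{major}{minor}")
+```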
+
### Generate an image guided by an initial image and a text prompt
```bash
@@ -139,16 +153,26 @@ python3 demo_txt2img_xl.py "a photo of an astronaut riding a horse on mars" --hf
python3 demo_txt2img_xl.py "Picture of a rustic Italian village with Olive trees and mountains" --version=xl-1.0 --lora-path "ostris/crayon_style_lora_sdxl" "ostris/watercolor_style_lora_sdxl" --lora-scale 0.3 0.7 --onnx-dir onnx-sdxl-lora --engine-dir engine-sdxl-lora --build-enable-refit
```
-### Faster Text-to-image using SDXL & INT8 quantization using ModelOpt
+### Faster Text-to-Image using SDXL with INT8 & FP8 quantization from ModelOpt
+
+Run the command below to generate an image with Stable Diffusion XL in INT8:
```bash
python3 demo_txt2img_xl.py "a photo of an astronaut riding a horse on mars" --version xl-1.0 --onnx-dir onnx-sdxl --engine-dir engine-sdxl --int8
```
-> Note that INT8 quantization is only supported for SDXL, and won't work with LoRA weights. Some prompts may produce better inputs with fewer denoising steps (e.g. `--denoising-steps 20`) but this will repeat the calibration, ONNX export, and engine building processes for the U-Net.
-For step-by-step tutorials to run INT8 inference on stable diffusion models, please refer to examples in [TensorRT ModelOpt diffusers sample](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/diffusers).
+Run the command below to generate an image with Stable Diffusion XL in FP8. (FP8 is only supported on Hopper.)
+
+```bash
+python3 demo_txt2img_xl.py "a photo of an astronaut riding a horse on mars" --version xl-1.0 --onnx-dir onnx-sdxl --engine-dir engine-sdxl --fp8
+```
+
+> Note that INT8 & FP8 quantization is only supported for SDXL, SD1.5, SD2.1, and SD2.1-base, and won't work with LoRA weights. FP8 quantization is only supported on Hopper. Some prompts may produce better results with fewer denoising steps (e.g. `--denoising-steps 20`), but this will repeat the calibration, ONNX export, and engine-building processes for the U-Net.
+
+For step-by-step tutorials to run INT8 & FP8 inference on stable diffusion models, please refer to examples in [TensorRT ModelOpt diffusers sample](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/diffusers).
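+
+As a rough illustration of the flow these tutorials walk through, the sketch below shows ModelOpt's generic quantize-with-calibration pattern (API names as in the TensorRT-Model-Optimizer examples; the stand-in model and calibration loop are illustrative, and the demo's `--int8`/`--fp8` flags drive an equivalent calibration over the U-Net internally):
+
+```python
+import torch
+import modelopt.torch.quantization as mtq  # nvidia-modelopt, listed in the tested configuration above
+
+model = torch.nn.Linear(16, 16).cuda()  # stand-in for the U-Net
+
+def forward_loop(m):
+    # Feed calibration data through the model; the demo runs denoising steps
+    # over a set of calibration prompts here.
+    for _ in range(8):
+        m(torch.randn(4, 16, device="cuda"))
+
+# Quantize in place; swap in mtq.FP8_DEFAULT_CFG for FP8 on Hopper.
+model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)
+```
+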
### Faster Text-to-Image using SDXL + LCM (Latent Consistency Model) LoRA weights
+
[LCM-LoRA](https://arxiv.org/abs/2311.05556) produces good quality images in 4 to 8 denoising steps instead of the 30+ needed by the base model. Note that we use the LCM scheduler and disable classifier-free guidance by setting `--guidance-scale` to 0.
LoRA weights are fused into the ONNX and finalized TensorRT plan files in this example.
```bash
@@ -200,6 +224,24 @@ python3 demo_img2vid.py --version svd-xt-1.1 --onnx-dir onnx-svd-xt-1-1 --engine
NOTE: The min and max guidance scales are configured using --min-guidance-scale and --max-guidance-scale respectively.
+### Generate an image using Stable Cascade guided by a text prompt
+
+Run the command below to generate an image using Stable Cascade:
+```bash
+python3 demo_stable_cascade.py --onnx-opset=16 "Anthropomorphic cat dressed as a pilot" --onnx-dir onnx-sc --engine-dir engine-sc
+```
+
+The lite versions of the models are also supported using the command below:
+```bash
+python3 demo_stable_cascade.py --onnx-opset=16 "Anthropomorphic cat dressed as a pilot" --onnx-dir onnx-sc-lite --engine-dir engine-sc-lite --lite
+```
+
+> NOTE: The pipeline is only enabled for BF16 model weights.
+
+> NOTE: The pipeline only supports ONNX export using Opset 16.
+
+> NOTE: The denoising steps and guidance scale for the Prior and Decoder models are configured using --prior-denoising-steps, --prior-guidance-scale, --decoder-denoising-steps, and --decoder-guidance-scale respectively.
+
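+For example, assuming the flag defaults defined in `demo_stable_cascade.py` (prior: 20 steps, guidance 4.0; decoder: 10 steps, guidance 0.0), the stage-specific options can be tuned independently:
+
+```bash
+python3 demo_stable_cascade.py --onnx-opset=16 "Anthropomorphic cat dressed as a pilot" \
+  --onnx-dir onnx-sc --engine-dir engine-sc \
+  --prior-denoising-steps 30 --prior-guidance-scale 4.0 \
+  --decoder-denoising-steps 10 --decoder-guidance-scale 0.0
+```
+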
## Configuration options
- Noise scheduler can be set using `--scheduler `. Note: not all schedulers are available for every version.
- To accelerate engine building time use `--timing-cache `. The cache file will be created if it does not already exist. Note that performance may degrade if cache files are used across multiple GPU targets. It is recommended to use timing caches only during development. To achieve the best performance in deployment, please build engines without a timing cache.
diff --git a/demo/Diffusion/demo_stable_cascade.py b/demo/Diffusion/demo_stable_cascade.py
new file mode 100644
index 00000000..e7f3bed0
--- /dev/null
+++ b/demo/Diffusion/demo_stable_cascade.py
@@ -0,0 +1,159 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import torch
+import argparse
+
+from cuda import cudart
+
+from stable_cascade_pipeline import StableCascadePipeline
+from utilities import PIPELINE_TYPE, add_arguments, process_pipeline_args
+
+def parseArgs():
+ parser = argparse.ArgumentParser(description="Options for Stable Cascade Txt2Img Demo", conflict_handler='resolve')
+ parser = add_arguments(parser)
+ parser.add_argument('--version', type=str, default="cascade", choices=["cascade"], help="Version of Stable Cascade")
+ parser.add_argument('--height', type=int, default=1024, help="Height of image to generate (must be multiple of 8)")
+ parser.add_argument('--width', type=int, default=1024, help="Width of image to generate (must be multiple of 8)")
+ parser.add_argument('--lite', action='store_true', help="Use the Lite Version of the Stage B and Stage C models")
+ parser.add_argument('--prior-guidance-scale', type=float, default=4.0, help="Value of classifier-free guidance scale for the prior")
+ parser.add_argument('--decoder-guidance-scale', type=float, default=0.0, help="Value of classifier-free guidance scale for the decoder")
+ parser.add_argument('--prior-denoising-steps', type=int, default=20, help="Number of denoising steps for the prior")
+ parser.add_argument('--decoder-denoising-steps', type=int, default=10, help="Number of denoising steps for the decoder")
+ return parser.parse_args()
+
+
+class StableCascadeDemoPipeline(StableCascadePipeline):
+ def __init__(self, prior_denoising_steps, decoder_denoising_steps, prior_guidance_scale, decoder_guidance_scale, lite, **kwargs):
+ self.nvtx_profile = kwargs['nvtx_profile']
+ self.prior = StableCascadePipeline(
+ pipeline_type=PIPELINE_TYPE.CASCADE_PRIOR,
+ denoising_steps=prior_denoising_steps,
+ guidance_scale=prior_guidance_scale,
+ return_latents=True,
+ lite=lite,
+ **kwargs,
+ )
+ self.decoder = StableCascadePipeline(
+ pipeline_type=PIPELINE_TYPE.CASCADE_DECODER,
+ denoising_steps=decoder_denoising_steps,
+ guidance_scale=decoder_guidance_scale,
+ lite=lite,
+ **kwargs,
+ )
+
+ def loadEngines(self, framework_model_dir, onnx_dir, engine_dir, **kwargs):
+ prior_suffix = "prior_lite" if self.prior.lite else "prior"
+ decoder_suffix = "decoder_lite" if self.decoder.lite else "decoder"
+ self.prior.loadEngines(
+ os.path.join(engine_dir, prior_suffix),
+ framework_model_dir,
+ os.path.join(onnx_dir, prior_suffix),
+ **kwargs)
+ self.decoder.loadEngines(
+ os.path.join(engine_dir, decoder_suffix),
+ framework_model_dir,
+ os.path.join(onnx_dir, decoder_suffix),
+ **kwargs)
+
+ def activateEngines(self, shared_device_memory=None):
+ self.prior.activateEngines(shared_device_memory)
+ self.decoder.activateEngines(shared_device_memory)
+
+ def loadResources(self, image_height, image_width, batch_size, seed):
+ self.prior.loadResources(image_height, image_width, batch_size, seed)
+ # Use a different seed for decoder
+ self.decoder.loadResources(image_height, image_width, batch_size, ((seed+1) if seed is not None else None))
+
+ def get_max_device_memory(self):
+ max_device_memory = self.prior.calculateMaxDeviceMemory()
+ max_device_memory = max(max_device_memory, self.decoder.calculateMaxDeviceMemory())
+ return max_device_memory
+
+ def run(self, prompt, negative_prompt, height, width, batch_size, batch_count, num_warmup_runs, use_cuda_graph):
+ # Process prompt
+ if not isinstance(prompt, list):
+ raise ValueError(f"`prompt` must be of type `str` list, but is {type(prompt)}")
+ prompt = prompt * batch_size
+
+ if not isinstance(negative_prompt, list):
+ raise ValueError(f"`--negative-prompt` must be of type `str` list, but is {type(negative_prompt)}")
+ if len(negative_prompt) == 1:
+ negative_prompt = negative_prompt * batch_size
+
+ num_warmup_runs = max(1, num_warmup_runs) if use_cuda_graph else num_warmup_runs
+ if num_warmup_runs > 0:
+ print("[I] Warming up ..")
+ for _ in range(num_warmup_runs):
+ latents, _ = self.prior.infer(prompt, negative_prompt, height, width, warmup=True)
+ latents = latents.to(torch.float16) if self.decoder.fp16 else latents
+ images, _ = self.decoder.infer(prompt, negative_prompt, height, width, image_embeddings=latents, warmup=True)
+
+ for _ in range(batch_count):
+ print("[I] Running Stable Cascade pipeline")
+ if self.nvtx_profile:
+ cudart.cudaProfilerStart()
+ latents, time_prior = self.prior.infer(prompt, negative_prompt, height, width, warmup=False)
+ latents = latents.to(torch.float16) if self.decoder.fp16 else latents
+ images, time_decoder = self.decoder.infer(prompt, negative_prompt, height, width, image_embeddings=latents, warmup=False)
+
+ if self.nvtx_profile:
+ cudart.cudaProfilerStop()
+ print('|-----------------|--------------|')
+ print('| {:^15} | {:>9.2f} ms |'.format('e2e', time_prior + time_decoder))
+ print('|-----------------|--------------|')
+
+ def teardown(self):
+ self.prior.teardown()
+ self.decoder.teardown()
+
+if __name__ == "__main__":
+ print("[I] Initializing StableCascade txt2img demo using TensorRT")
+ args = parseArgs()
+
+ kwargs_init_pipeline, kwargs_load_engine, args_run_demo = process_pipeline_args(args)
+
+ # Initialize demo
+ _ = kwargs_init_pipeline.pop('guidance_scale')
+ _ = kwargs_init_pipeline.pop('denoising_steps')
+ demo = StableCascadeDemoPipeline(
+ args.prior_denoising_steps,
+ args.decoder_denoising_steps,
+ args.prior_guidance_scale,
+ args.decoder_guidance_scale,
+ args.lite,
+ **kwargs_init_pipeline
+ )
+
+ # Load TensorRT engines and pytorch modules
+ demo.loadEngines(
+ args.framework_model_dir,
+ args.onnx_dir,
+ args.engine_dir,
+ **kwargs_load_engine,
+ )
+
+ # Load resources
+ _, shared_device_memory = cudart.cudaMalloc(demo.get_max_device_memory())
+ demo.activateEngines(shared_device_memory)
+ demo.loadResources(args.height, args.width, args.batch_size, args.seed)
+
+ # Run inference
+ demo.run(*args_run_demo)
+
+ demo.teardown()
diff --git a/demo/Diffusion/demo_txt2img_sd3.py b/demo/Diffusion/demo_txt2img_sd3.py
index 197de964..f61a2fed 100644
--- a/demo/Diffusion/demo_txt2img_sd3.py
+++ b/demo/Diffusion/demo_txt2img_sd3.py
@@ -20,16 +20,14 @@
from cuda import cudart
from stable_diffusion_3_pipeline import StableDiffusion3Pipeline
-from utilities import PIPELINE_TYPE
+from utilities import PIPELINE_TYPE, add_arguments
from utils_sd3.other_impls import preprocess_image_sd3
-def add_arguments(parser):
- # Stable Diffusion configuration
+def parseArgs():
+ # Stable Diffusion 3 configuration
+ parser = argparse.ArgumentParser(description="Options for Stable Diffusion 3 Txt2Img Demo", conflict_handler='resolve')
+ parser = add_arguments(parser)
parser.add_argument('--version', type=str, default="sd3", choices=["sd3"], help="Version of Stable Diffusion")
- parser.add_argument('prompt', nargs = '*', help="Text prompt(s) to guide image generation")
- parser.add_argument('--negative-prompt', nargs = '*', default=[''], help="The negative prompt(s) to guide the image generation.")
- parser.add_argument('--batch-size', type=int, default=1, choices=[1, 2, 4], help="Batch size (repeat prompt)")
- parser.add_argument('--batch-count', type=int, default=1, help="Number of images to generate in sequence, one at a time.")
parser.add_argument('--height', type=int, default=1024, help="Height of image to generate (must be multiple of 8)")
parser.add_argument('--width', type=int, default=1024, help="Width of image to generate (must be multiple of 8)")
parser.add_argument('--shift', type=float, default=1.0, help="Shift parameter for SD3")
@@ -38,31 +36,7 @@ def add_arguments(parser):
parser.add_argument('--denoising-percentage', type=float, default=0.6, help="Percentage of denoising steps to run. This parameter is only used if input-image is provided")
parser.add_argument('--input-image', type=str, default="", help="Path to the input image")
- # ONNX export
- parser.add_argument('--onnx-opset', type=int, default=19, choices=range(7,20), help="Select ONNX opset version to target for exported models")
- parser.add_argument('--onnx-dir', default='onnx', help="Output directory for ONNX export")
-
- # Framework model ckpt
- parser.add_argument('--framework-model-dir', default='pytorch_model', help="Directory for HF saved models")
-
- # TensorRT engine build
- parser.add_argument('--engine-dir', default='engine', help="Output directory for TensorRT engines")
- parser.add_argument('--build-static-batch', action='store_true', help="Build TensorRT engines with fixed batch size.")
- parser.add_argument('--build-dynamic-shape', action='store_true', help="Build TensorRT engines with dynamic image shapes.")
- parser.add_argument('--build-all-tactics', action='store_true', help="Build TensorRT engines using all tactic sources.")
- parser.add_argument('--timing-cache', default=None, type=str, help="Path to the precached timing measurements to accelerate build.")
-
- # TensorRT inference
- parser.add_argument('--num-warmup-runs', type=int, default=5, help="Number of warmup runs before benchmarking performance")
- parser.add_argument('--use-cuda-graph', action='store_true', help="Enable cuda graph")
- parser.add_argument('--nvtx-profile', action='store_true', help="Enable NVTX markers for performance profiling")
- parser.add_argument('--torch-inference', default='', help="Run inference with PyTorch (using specified compilation mode) instead of TensorRT.")
-
- parser.add_argument('--seed', type=int, default=None, help="Seed for random generator to get consistent results")
- parser.add_argument('--output-dir', default='output', help="Output directory for logs and image artifacts")
- parser.add_argument('--hf-token', type=str, help="HuggingFace API access token for downloading model checkpoints")
- parser.add_argument('-v', '--verbose', action='store_true', help="Show verbose output")
- return parser
+ return parser.parse_args()
def process_pipeline_args(args):
if args.height % 8 != 0 or args.width % 8 != 0:
@@ -119,11 +93,6 @@ def process_pipeline_args(args):
return kwargs_init_pipeline, kwargs_load_engine, args_run_demo
-def parseArgs():
- parser = argparse.ArgumentParser(description="Options for Stable Diffusion 3 Demo")
- parser = add_arguments(parser)
- return parser.parse_args()
-
if __name__ == "__main__":
print("[I] Initializing Stable Diffusion 3 demo using TensorRT")
args = parseArgs()
diff --git a/demo/Diffusion/models.py b/demo/Diffusion/models.py
index d9d82b69..cdd2300c 100644
--- a/demo/Diffusion/models.py
+++ b/demo/Diffusion/models.py
@@ -23,7 +23,9 @@
ControlNetModel,
UNet2DConditionModel,
UNetSpatioTemporalConditionModel,
+ StableCascadeUNet
)
+from diffusers.pipelines.wuerstchen import PaellaVQModel
import json
import numpy as np
import onnx
@@ -48,6 +50,13 @@
from utils_sd3.sd3_impls import BaseModel as BaseModelSD3
from utils_sd3.sd3_impls import SDVAE
from utils_sd3.other_impls import load_into, SDClipModel, SDXLClipG, T5XXLModel
+from utils_modelopt import (
+ convert_zp_fp8,
+ cast_resize_io,
+ convert_fp16_io,
+ cast_fp8_mha_io,
+)
+from onnxmltools.utils.float16_converter import convert_float_to_float16
class Optimizer():
def __init__(
@@ -99,7 +108,7 @@ def infer_shapes(self, return_onnx=False):
if return_onnx:
return onnx_graph
- def clip_add_hidden_states(self, return_onnx=False):
+ def clip_add_hidden_states(self, hidden_layer_offset, return_onnx=False):
hidden_layers = -1
onnx_graph = gs.export_onnx(self.graph)
for i in range(len(onnx_graph.graph.node)):
@@ -109,10 +118,10 @@ def clip_add_hidden_states(self, return_onnx=False):
hidden_layers = max(int(name.split(".")[1].split("/")[0]), hidden_layers)
for i in range(len(onnx_graph.graph.node)):
for j in range(len(onnx_graph.graph.node[i].output)):
- if onnx_graph.graph.node[i].output[j] == "/text_model/encoder/layers.{}/Add_1_output_0".format(hidden_layers-1):
+ if onnx_graph.graph.node[i].output[j] == "/text_model/encoder/layers.{}/Add_1_output_0".format(hidden_layers+hidden_layer_offset):
onnx_graph.graph.node[i].output[j] = "hidden_states"
for j in range(len(onnx_graph.graph.node[i].input)):
- if onnx_graph.graph.node[i].input[j] == "/text_model/encoder/layers.{}/Add_1_output_0".format(hidden_layers-1):
+ if onnx_graph.graph.node[i].input[j] == "/text_model/encoder/layers.{}/Add_1_output_0".format(hidden_layers+hidden_layer_offset):
onnx_graph.graph.node[i].input[j] = "hidden_states"
if return_onnx:
return onnx_graph
@@ -169,16 +178,30 @@ def fuse_mha_qkv_int8_sq(self):
print(f"Removed {removed} QDQ nodes")
return removed # expected 72 for L2.5
+ def modify_fp8_graph(self):
+ onnx_graph = gs.export_onnx(self.graph)
+ # Convert INT8 Zero to FP8.
+ onnx_graph = convert_zp_fp8(onnx_graph)
+ # Convert weights and activations to FP16 and insert Cast nodes in FP8 MHA.
+ onnx_graph = convert_float_to_float16(onnx_graph, keep_io_types=True, disable_shape_infer=True)
+ self.graph = gs.import_onnx(onnx_graph)
+ # Add cast nodes to Resize I/O.
+ cast_resize_io(self.graph)
+ # Convert model inputs and outputs to fp16 I/O.
+ convert_fp16_io(self.graph)
+ # Add cast nodes to MHA's BMM1 and BMM2's I/O.
+ cast_fp8_mha_io(self.graph)
+
def get_path(version, pipeline, controlnets=None):
if controlnets is not None:
return ["lllyasviel/sd-controlnet-" + modality for modality in controlnets]
-
+
if version in ("1.4", "1.5") and pipeline.is_inpaint():
- return "runwayml/stable-diffusion-inpainting"
+ return "benjamin-paine/stable-diffusion-v1-5-inpainting"
elif version == "1.4":
return "CompVis/stable-diffusion-v1-4"
elif version == "1.5":
- return "runwayml/stable-diffusion-v1-5"
+ return "benjamin-paine/stable-diffusion-v1-5"
elif version == 'dreamshaper-7':
return 'Lykon/dreamshaper-7'
elif version in ("2.0-base", "2.0") and pipeline.is_inpaint():
@@ -202,6 +225,11 @@ def get_path(version, pipeline, controlnets=None):
return "stabilityai/stable-diffusion-3-medium"
elif version == 'svd-xt-1.1' and pipeline.is_img2vid():
return "stabilityai/stable-video-diffusion-img2vid-xt-1-1"
+ elif version == 'cascade':
+ if pipeline.is_cascade_decoder():
+ return "stabilityai/stable-cascade"
+ else:
+ return "stabilityai/stable-cascade-prior"
else:
raise ValueError(f"Unsupported version {version} + pipeline {pipeline.name}")
@@ -218,7 +246,7 @@ def get_clip_embedding_dim(version, pipeline):
raise ValueError(f"Invalid version {version} + pipeline {pipeline}")
def get_clipwithproj_embedding_dim(version, pipeline):
- if version in ("xl-1.0", "xl-turbo"):
+ if version in ("xl-1.0", "xl-turbo", "cascade"):
return 1280
else:
raise ValueError(f"Invalid version {version} + pipeline {pipeline}")
@@ -230,6 +258,8 @@ def get_unet_embedding_dim(version, pipeline):
return 1024
elif version in ("xl-1.0", "xl-turbo") and pipeline.is_sd_xl_base():
return 2048
+ elif version in ("cascade",):
+ return 1280
elif version in ("xl-1.0", "xl-turbo") and pipeline.is_sd_xl_refiner():
return 1280
elif pipeline.is_img2vid():
@@ -305,10 +335,13 @@ def __init__(self,
verbose=True,
framework_model_dir='pytorch_model',
fp16=False,
+ bf16=False,
int8=False,
+ fp8=False,
max_batch_size=16,
text_maxlen=77,
embedding_dim=768,
+ compression_factor=8
):
self.name = self.__class__.__name__
@@ -322,23 +355,28 @@ def __init__(self,
self.framework_model_dir = framework_model_dir
self.fp16 = fp16
+ self.bf16 = bf16
self.int8 = int8
+ self.fp8 = fp8
+ self.compression_factor = compression_factor
self.min_batch = 1
self.max_batch = max_batch_size
self.min_image_shape = 256 # min image resolution: 256x256
self.max_image_shape = 1024 # max image resolution: 1024x1024
- self.min_latent_shape = self.min_image_shape // 8
- self.max_latent_shape = self.max_image_shape // 8
+ self.min_latent_shape = self.min_image_shape // self.compression_factor
+ self.max_latent_shape = self.max_image_shape // self.compression_factor
self.text_maxlen = text_maxlen
self.embedding_dim = embedding_dim
self.extra_output_names = []
self.lora_dict = None
+ self.do_constant_folding = True
def get_pipeline(self):
model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if self.fp16 else {}
+ model_opts = {'variant': 'bf16', 'torch_dtype': torch.bfloat16} if self.bf16 else model_opts
return DiffusionPipeline.from_pretrained(
self.path,
use_safetensors=self.hf_safetensor,
@@ -399,7 +437,7 @@ def export_onnx(model):
onnx_path,
export_params=True,
opset_version=onnx_opset,
- do_constant_folding=True,
+ do_constant_folding=self.do_constant_folding,
input_names=self.get_input_names(),
output_names=self.get_output_names(),
dynamic_axes=self.get_dynamic_axes(),
@@ -479,13 +517,17 @@ def optimize(self, onnx_graph, return_onnx=True, **kwargs):
opt.info(self.name + ': original')
opt.cleanup()
opt.info(self.name + ': cleanup')
- opt.fold_constants()
- opt.info(self.name + ': fold constants')
- opt.infer_shapes()
- opt.info(self.name + ': shape inference')
- if kwargs.get('fuse_mha_qkv_int8', False):
- opt.fuse_mha_qkv_int8_sq()
- opt.info(self.name + ': fuse QKV nodes')
+ if kwargs.get('modify_fp8_graph', False):
+ opt.modify_fp8_graph()
+ opt.info(self.name + ': modify fp8 graph')
+ else:
+ opt.fold_constants()
+ opt.info(self.name + ': fold constants')
+ opt.infer_shapes()
+ opt.info(self.name + ': shape inference')
+ if kwargs.get('fuse_mha_qkv_int8', False):
+ opt.fuse_mha_qkv_int8_sq()
+ opt.info(self.name + ': fuse QKV nodes')
onnx_opt_graph = opt.cleanup(return_onnx=return_onnx)
opt.info(self.name + ': finished')
return onnx_opt_graph
@@ -493,8 +535,8 @@ def optimize(self, onnx_graph, return_onnx=True, **kwargs):
def check_dims(self, batch_size, image_height, image_width):
assert batch_size >= self.min_batch and batch_size <= self.max_batch
assert image_height % 8 == 0 and image_width % 8 == 0
- latent_height = image_height // 8
- latent_width = image_width // 8
+ latent_height = image_height // self.compression_factor
+ latent_width = image_width // self.compression_factor
assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
return (latent_height, latent_width)
@@ -502,8 +544,8 @@ def check_dims(self, batch_size, image_height, image_width):
def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
min_batch = batch_size if static_batch else self.min_batch
max_batch = batch_size if static_batch else self.max_batch
- latent_height = image_height // 8
- latent_width = image_width // 8
+ latent_height = image_height // self.compression_factor
+ latent_width = image_width // self.compression_factor
min_image_height = image_height if static_shape else self.min_image_shape
max_image_height = image_height if static_shape else self.max_image_shape
min_image_width = image_width if static_shape else self.min_image_shape
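
The hunks above replace the hard-coded 8x VAE downsampling with the parameterized `compression_factor`. A minimal standalone sketch of the resulting latent sizing, using only values that appear in this diff (8 for the SD-family VAEs; 42 plus a `latent_dim_scale` of 10.67 for Stable Cascade):

    def latent_hw(image_height, image_width, compression_factor=8):
        # Same integer division as check_dims() / get_minmax_dims() above.
        return image_height // compression_factor, image_width // compression_factor

    assert latent_hw(1024, 1024) == (128, 128)                        # SD 1.5 / SDXL VAE
    assert latent_hw(1024, 1024, compression_factor=42) == (24, 24)   # Stable Cascade prior
    assert int(24 * 10.67) == 256                                     # decoder latent via latent_dim_scale
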
@@ -526,13 +568,15 @@ def __init__(self,
max_batch_size,
embedding_dim,
fp16=False,
+ bf16=False,
output_hidden_states=False,
subfolder="text_encoder",
lora_dict=None,
lora_alphas=None,
):
- super(CLIPModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+ super(CLIPModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, bf16=bf16, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
self.subfolder = subfolder
+ self.hidden_layer_offset = 0 if pipeline.is_cascade() else -1
# Output the final hidden state
if output_hidden_states:
@@ -599,7 +643,7 @@ def optimize(self, onnx_graph):
opt.info(self.name + ': remove output[0]')
opt_onnx_graph = opt.cleanup(return_onnx=True)
if 'hidden_states' in self.extra_output_names:
- opt_onnx_graph = opt.clip_add_hidden_states(return_onnx=True)
+ opt_onnx_graph = opt.clip_add_hidden_states(self.hidden_layer_offset, return_onnx=True)
opt.info(self.name + ': added hidden_states')
opt.info(self.name + ': finished')
return opt_onnx_graph
@@ -614,6 +658,7 @@ def __init__(self,
verbose,
framework_model_dir,
fp16=False,
+ bf16=False,
max_batch_size=16,
output_hidden_states=False,
subfolder="text_encoder_2",
@@ -621,34 +666,64 @@ def __init__(self,
lora_alphas=None,
):
- super(CLIPWithProjModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, max_batch_size=max_batch_size, embedding_dim=get_clipwithproj_embedding_dim(version, pipeline), output_hidden_states=output_hidden_states)
+ super(CLIPWithProjModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, bf16=bf16, max_batch_size=max_batch_size, embedding_dim=get_clipwithproj_embedding_dim(version, pipeline), output_hidden_states=output_hidden_states)
self.subfolder = subfolder
def get_model(self, torch_inference=''):
+ model_opts = {'variant': 'bf16', 'torch_dtype': torch.bfloat16} if self.bf16 else {}
clip_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder)
- if not os.path.exists(clip_model_dir):
+ clip_path = self.get_model_path(clip_model_dir, model_opts, model_name='model')
+ if not os.path.exists(clip_path):
model = CLIPTextModelWithProjection.from_pretrained(self.path,
subfolder=self.subfolder,
use_safetensors=self.hf_safetensor,
- use_auth_token=self.hf_token).to(self.device)
- model.save_pretrained(clip_model_dir)
+ use_auth_token=self.hf_token,
+ **model_opts).to(self.device)
+ model.save_pretrained(clip_model_dir, **model_opts)
else:
- print(f"[I] Load CLIPTextModelWithProjection model from: {clip_model_dir}")
- model = CLIPTextModelWithProjection.from_pretrained(clip_model_dir).to(self.device)
+ print(f"[I] Load CLIPTextModelWithProjection model from: {clip_path}")
+ model = CLIPTextModelWithProjection.from_pretrained(clip_model_dir, **model_opts).to(self.device)
model = optimize_checkpoint(model, torch_inference)
return model
+ def get_input_names(self):
+ return ['input_ids', 'attention_mask']
+
+ def get_output_names(self):
+ return ['text_embeddings']
+
+ def get_dynamic_axes(self):
+ return {
+ 'input_ids': {0: 'B'},
+ 'attention_mask': {0: 'B'},
+ 'text_embeddings': {0: 'B'}
+ }
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ self.check_dims(batch_size, image_height, image_width)
+ min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ 'input_ids': [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)],
+ 'attention_mask': [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
+ }
+
def get_shape_dict(self, batch_size, image_height, image_width):
self.check_dims(batch_size, image_height, image_width)
output = {
'input_ids': (batch_size, self.text_maxlen),
+ 'attention_mask': (batch_size, self.text_maxlen),
'text_embeddings': (batch_size, self.embedding_dim)
}
if 'hidden_states' in self.extra_output_names:
output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim)
-
return output
+ def get_sample_input(self, batch_size, image_height, image_width, static_shape):
+ self.check_dims(batch_size, image_height, image_width)
+ return (
+ torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device),
+ torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)
+ )
class SD3_CLIPGModel(CLIPModel):
def __init__(self,
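
CLIPWithProjModel now declares its own I/O names, dynamic axes, and optimization profile. For each input, `get_input_profile()` returns a `[min, opt, max]` shape triple; a minimal sketch of how such a dict maps onto a TensorRT optimization profile via polygraphy (shapes illustrative, assuming `text_maxlen=77` and the `Profile` API of the polygraphy 0.49 pin in this diff):

    from polygraphy.backend.trt import Profile

    input_profile = {
        'input_ids':      [(1, 77), (2, 77), (16, 77)],
        'attention_mask': [(1, 77), (2, 77), (16, 77)],
    }
    profile = Profile()
    for name, (min_shape, opt_shape, max_shape) in input_profile.items():
        profile.add(name, min=min_shape, opt=opt_shape, max=max_shape)
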
@@ -914,6 +989,7 @@ def __init__(self,
framework_model_dir,
fp16 = False,
int8 = False,
+ fp8 = False,
max_batch_size = 16,
text_maxlen = 77,
controlnets = None,
@@ -923,7 +999,7 @@ def __init__(self,
do_classifier_free_guidance = False,
):
- super(UNetModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, max_batch_size=max_batch_size, text_maxlen=text_maxlen, embedding_dim=get_unet_embedding_dim(version, pipeline))
+ super(UNetModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, int8=int8, fp8=fp8, max_batch_size=max_batch_size, text_maxlen=text_maxlen, embedding_dim=get_unet_embedding_dim(version, pipeline))
self.subfolder = 'unet'
self.controlnets = get_path(version, pipeline, controlnets) if controlnets else None
self.unet_dim = (9 if pipeline.is_inpaint() else 4)
@@ -1054,6 +1130,13 @@ def get_sample_input(self, batch_size, image_height, image_width, static_shape):
torch.randn(len(self.controlnets), dtype=dtype, device=self.device)
)
+ def optimize(self, onnx_graph):
+ if self.fp8:
+ return super().optimize(onnx_graph, modify_fp8_graph=True)
+ if self.int8:
+ return super().optimize(onnx_graph, fuse_mha_qkv_int8=True)
+ return super().optimize(onnx_graph)
+
class UNetXLModel(BaseModel):
def __init__(self,
@@ -1065,6 +1148,7 @@ def __init__(self,
framework_model_dir,
fp16 = False,
int8 = False,
+ fp8 = False,
max_batch_size = 16,
text_maxlen = 77,
lora_scales = None,
@@ -1072,7 +1156,7 @@ def __init__(self,
lora_alphas = None,
do_classifier_free_guidance = False,
):
- super(UNetXLModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, max_batch_size=max_batch_size, text_maxlen=text_maxlen, embedding_dim=get_unet_embedding_dim(version, pipeline))
+ super(UNetXLModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, int8=int8, fp8=fp8, max_batch_size=max_batch_size, text_maxlen=text_maxlen, embedding_dim=get_unet_embedding_dim(version, pipeline))
self.subfolder = 'unet'
self.unet_dim = (9 if pipeline.is_inpaint() else 4)
self.time_dim = (5 if pipeline.is_sd_xl_refiner() else 6)
@@ -1164,7 +1248,11 @@ def get_sample_input(self, batch_size, image_height, image_width, static_shape):
)
def optimize(self, onnx_graph):
- return super().optimize(onnx_graph, fuse_mha_qkv_int8=True)
+ if self.fp8:
+ return super().optimize(onnx_graph, modify_fp8_graph=True)
+ if self.int8:
+ return super().optimize(onnx_graph, fuse_mha_qkv_int8=True)
+ return super().optimize(onnx_graph)
class SD3_MMDiTModel(BaseModel):
def __init__(self,
@@ -1333,6 +1421,151 @@ def get_sample_input(self, batch_size, image_height, image_width):
)
+class UNetCascadeModel(BaseModel):
+ def __init__(self,
+ version,
+ pipeline,
+ device,
+ hf_token,
+ verbose,
+ framework_model_dir,
+ fp16 = False,
+ bf16 = False,
+ max_batch_size = 16,
+ text_maxlen = 77,
+ do_classifier_free_guidance = False,
+ compression_factor=42,
+ latent_dim_scale=10.67,
+ image_embedding_dim=768,
+ lite=False
+ ):
+ super(UNetCascadeModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, bf16=bf16, max_batch_size=max_batch_size, text_maxlen=text_maxlen, embedding_dim=get_unet_embedding_dim(version, pipeline), compression_factor=compression_factor)
+ self.is_prior = True if pipeline.is_cascade_prior() else False
+ self.subfolder = 'prior' if self.is_prior else 'decoder'
+ if lite:
+ self.subfolder += '_lite'
+ self.prior_dim = 16
+ self.decoder_dim = 4
+ self.xB = 2 if do_classifier_free_guidance else 1 # batch multiplier
+ self.latent_dim_scale = latent_dim_scale
+ self.min_latent_shape = self.min_image_shape // self.compression_factor
+ self.max_latent_shape = self.max_image_shape // self.compression_factor
+ self.do_constant_folding = False
+ self.image_embedding_dim = image_embedding_dim
+
+ def get_model(self, torch_inference=''):
+ # FP16 variant doesn't exist
+ model_opts = {'torch_dtype': torch.float16} if self.fp16 else {}
+ model_opts = {'variant': 'bf16', 'torch_dtype': torch.bfloat16} if self.bf16 else model_opts
+ unet_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder)
+ unet_path = self.get_model_path(unet_model_dir, model_opts)
+ if not os.path.exists(unet_path):
+ model = StableCascadeUNet.from_pretrained(self.path,
+ subfolder=self.subfolder,
+ use_safetensors=self.hf_safetensor,
+ use_auth_token=self.hf_token,
+ **model_opts).to(self.device)
+ model.save_pretrained(unet_model_dir, **model_opts)
+ else:
+ print(f"[I] Load Stable Cascade UNet pytorch model from: {unet_path}")
+ model = StableCascadeUNet.from_pretrained(unet_model_dir, **model_opts).to(self.device)
+ model = optimize_checkpoint(model, torch_inference)
+ return model
+
+ def get_input_names(self):
+ if self.is_prior:
+ return ['sample', 'timestep_ratio', 'clip_text_pooled', 'clip_text', 'clip_img']
+ else:
+ return ['sample', 'timestep_ratio', 'clip_text_pooled', 'effnet']
+
+ def get_output_names(self):
+ return ['latent']
+
+ def get_dynamic_axes(self):
+ xB = '2B' if self.xB == 2 else 'B'
+ if self.is_prior:
+ return {
+ 'sample': {0: xB, 2: 'H', 3: 'W'},
+ 'timestep_ratio': {0: xB},
+ 'clip_text_pooled': {0: xB},
+ 'clip_text': {0: xB},
+ 'clip_img': {0: xB},
+ 'latent': {0: xB, 2: 'H', 3: 'W'}
+ }
+ else:
+ return {
+ 'sample': {0: xB, 2: 'H', 3: 'W'},
+ 'timestep_ratio': {0: xB},
+ 'clip_text_pooled': {0: xB},
+ 'effnet': {0: xB, 2: 'H_effnet', 3: 'W_effnet'},
+ 'latent': {0: xB, 2: 'H', 3: 'W'}
+ }
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = \
+ self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ if self.is_prior:
+ return {
+ 'sample': [(self.xB*min_batch, self.prior_dim, min_latent_height, min_latent_width), (self.xB*batch_size, self.prior_dim, latent_height, latent_width), (self.xB*max_batch, self.prior_dim, max_latent_height, max_latent_width)],
+ 'timestep_ratio': [(self.xB*min_batch,), (self.xB*batch_size,), (self.xB*max_batch,)],
+ 'clip_text_pooled': [(self.xB*min_batch, 1, self.embedding_dim), (self.xB*batch_size, 1, self.embedding_dim), (self.xB*max_batch, 1, self.embedding_dim)],
+ 'clip_text': [(self.xB*min_batch, self.text_maxlen, self.embedding_dim), (self.xB*batch_size, self.text_maxlen, self.embedding_dim), (self.xB*max_batch, self.text_maxlen, self.embedding_dim)],
+ 'clip_img': [(self.xB*min_batch, 1, self.image_embedding_dim), (self.xB*batch_size, 1, self.image_embedding_dim), (self.xB*max_batch, 1, self.image_embedding_dim)],
+ }
+ else:
+ return {
+ 'sample': [(self.xB*min_batch, self.decoder_dim, int(min_latent_height * self.latent_dim_scale), int(min_latent_width * self.latent_dim_scale)),
+ (self.xB*batch_size, self.decoder_dim, int(latent_height * self.latent_dim_scale), int(latent_width * self.latent_dim_scale)),
+ (self.xB*max_batch, self.decoder_dim, int(max_latent_height * self.latent_dim_scale), int(max_latent_width * self.latent_dim_scale))],
+ 'timestep_ratio': [(self.xB*min_batch,), (self.xB*batch_size,), (self.xB*max_batch,)],
+ 'clip_text_pooled': [(self.xB*min_batch, 1, self.embedding_dim), (self.xB*batch_size, 1, self.embedding_dim), (self.xB*max_batch, 1, self.embedding_dim)],
+ 'effnet': [(self.xB*min_batch, self.prior_dim, min_latent_height, min_latent_width), (self.xB*batch_size, self.prior_dim, latent_height, latent_width), (self.xB*max_batch, self.prior_dim, max_latent_height, max_latent_width)]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ if self.is_prior:
+ return {
+ 'sample': (self.xB*batch_size, self.prior_dim, latent_height, latent_width),
+ 'timestep_ratio': (self.xB*batch_size,),
+ 'clip_text_pooled': (self.xB*batch_size, 1, self.embedding_dim),
+ 'clip_text': (self.xB*batch_size, self.text_maxlen, self.embedding_dim),
+ 'clip_img': (self.xB*batch_size, 1, self.image_embedding_dim),
+ 'latent': (self.xB*batch_size, self.prior_dim, latent_height, latent_width)
+ }
+ else:
+ return {
+ 'sample': (self.xB*batch_size, self.decoder_dim, int(latent_height * self.latent_dim_scale), int(latent_width * self.latent_dim_scale)),
+ 'timestep_ratio': (self.xB*batch_size,),
+ 'clip_text_pooled': (self.xB*batch_size, 1, self.embedding_dim),
+ 'effnet': (self.xB*batch_size, self.prior_dim, latent_height, latent_width),
+ 'latent': (self.xB*batch_size, self.decoder_dim, int(latent_height * self.latent_dim_scale), int(latent_width * self.latent_dim_scale))
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ dtype = torch.float16 if self.fp16 else torch.bfloat16 if self.bf16 else torch.float32
+ if self.is_prior:
+ return (
+ torch.randn(batch_size, self.prior_dim, latent_height, latent_width, dtype=dtype, device=self.device),
+ torch.tensor([1.]*batch_size, dtype=dtype, device=self.device),
+ torch.randn(batch_size, 1, self.embedding_dim, dtype=dtype, device=self.device),
+ {
+ 'clip_text': torch.randn(batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+ 'clip_img': torch.randn(batch_size, 1, self.image_embedding_dim, dtype=dtype, device=self.device),
+ }
+ )
+ else:
+ return (
+ torch.randn(batch_size, self.decoder_dim, int(latent_height * self.latent_dim_scale), int(latent_width * self.latent_dim_scale), dtype=dtype, device=self.device),
+ torch.tensor([1.]*batch_size, dtype=dtype, device=self.device),
+ torch.randn(batch_size, 1, self.embedding_dim, dtype=dtype, device=self.device),
+ {
+ 'effnet': torch.randn(batch_size, self.prior_dim, latent_height, latent_width, dtype=dtype, device=self.device),
+ }
+ )
+
class VAEModel(BaseModel):
def __init__(self,
version,
@@ -1633,6 +1866,90 @@ def get_sample_input(self, batch_size, image_height, image_width, static_shape):
dtype = torch.float16 if self.fp16 else torch.float32
return torch.randn(batch_size, 3, image_height, image_width, dtype=dtype, device=self.device)
+class VQGANModel(BaseModel):
+ def __init__(self,
+ version,
+ pipeline,
+ device,
+ hf_token,
+ verbose,
+ framework_model_dir,
+ fp16=False,
+ bf16=False,
+ max_batch_size=16,
+ compression_factor=42,
+ latent_dim_scale=10.67,
+ scale_factor=0.3764
+ ):
+ super(VQGANModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, bf16=bf16, max_batch_size=max_batch_size, compression_factor=compression_factor)
+ self.subfolder = 'vqgan'
+ self.latent_dim_scale = latent_dim_scale
+ self.scale_factor = scale_factor
+
+ def get_model(self, torch_inference=''):
+ model_opts = {'variant': 'bf16', 'torch_dtype': torch.bfloat16} if self.bf16 else {}
+ vqgan_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder)
+ vqgan_path = self.get_model_path(vqgan_model_dir, model_opts, model_name='model')
+ if not os.path.exists(vqgan_path):
+ model = PaellaVQModel.from_pretrained(self.path,
+ subfolder=self.subfolder,
+ use_safetensors=self.hf_safetensor,
+ use_auth_token=self.hf_token,
+ **model_opts).to(self.device)
+ model.save_pretrained(vqgan_model_dir, **model_opts)
+ else:
+ print(f"[I] Load VQGAN pytorch model from: {vqgan_path}")
+ model = PaellaVQModel.from_pretrained(vqgan_model_dir, **model_opts).to(self.device)
+ model.forward = model.decode
+ model = optimize_checkpoint(model, torch_inference)
+ return model
+
+ def get_input_names(self):
+ return ['latent']
+
+ def get_output_names(self):
+ return ['images']
+
+ def get_dynamic_axes(self):
+ return {
+ 'latent': {0: 'B', 2: 'H', 3: 'W'},
+ 'images': {0: 'B', 2: '8H', 3: '8W'}
+ }
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = \
+ self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ 'latent': [(min_batch, 4, min_latent_height, min_latent_width), (batch_size, 4, latent_height, latent_width), (max_batch, 4, max_latent_height, max_latent_width)]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ 'latent': (batch_size, 4, latent_height, latent_width),
+ 'images': (batch_size, 3, image_height, image_width)
+ }
+ def get_sample_input(self, batch_size, image_height, image_width, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ dtype = torch.float16 if self.fp16 else torch.bfloat16 if self.bf16 else torch.float32
+ return torch.randn(batch_size, 4, latent_height, latent_width, dtype=dtype, device=self.device)
+
+ def check_dims(self, batch_size, image_height, image_width):
+ latent_height, latent_width = super().check_dims(batch_size, image_height, image_width)
+ latent_height = int(latent_height * self.latent_dim_scale)
+ latent_width = int(latent_width * self.latent_dim_scale)
+ return (latent_height, latent_width)
+
+ def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
+ min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width, min_latent_height, max_latent_height, min_latent_width, max_latent_width = \
+ super().get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ min_latent_height = int(min_latent_height * self.latent_dim_scale)
+ min_latent_width = int(min_latent_width * self.latent_dim_scale)
+ max_latent_height = int(max_latent_height * self.latent_dim_scale)
+ max_latent_width = int(max_latent_width * self.latent_dim_scale)
+ return (min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width, min_latent_height, max_latent_height, min_latent_width, max_latent_width)
+
def make_tokenizer(version, pipeline, hf_token, framework_model_dir, subfolder="tokenizer", **kwargs):
tokenizer_model_dir = get_checkpoint_dir(framework_model_dir, version, pipeline.name, subfolder)
if not os.path.exists(tokenizer_model_dir):
diff --git a/demo/Diffusion/requirements.txt b/demo/Diffusion/requirements.txt
index fc2979b5..280fe886 100755
--- a/demo/Diffusion/requirements.txt
+++ b/demo/Diffusion/requirements.txt
@@ -1,17 +1,17 @@
colored
controlnet_aux==0.0.6
cuda-python
-diffusers==0.26.3
+diffusers==0.29.2
ftfy
matplotlib
nvtx
onnx==1.15.0
-onnxruntime==1.16.3
+onnxruntime==1.17.3
opencv-python==4.8.0.74
scipy
transformers==4.36.2
--extra-index-url https://pypi.nvidia.com
-nvidia-modelopt==0.11.2
+nvidia-modelopt[torch,onnx]==0.15.1
onnx-graphsurgeon
polygraphy==0.49.9
sentencepiece
diff --git a/demo/Diffusion/stable_cascade_pipeline.py b/demo/Diffusion/stable_cascade_pipeline.py
new file mode 100644
index 00000000..2d9d639e
--- /dev/null
+++ b/demo/Diffusion/stable_cascade_pipeline.py
@@ -0,0 +1,340 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from cuda import cudart
+from diffusers import (
+ DDPMWuerstchenScheduler
+)
+import inspect
+from models import (
+ make_tokenizer,
+ CLIPWithProjModel,
+ UNetCascadeModel,
+ VQGANModel
+)
+import tensorrt as trt
+import time
+import torch
+from utilities import (
+ PIPELINE_TYPE,
+ TRT_LOGGER,
+)
+from stable_diffusion_pipeline import StableDiffusionPipeline
+
+class StableCascadePipeline(StableDiffusionPipeline):
+ """
+ Application showcasing the acceleration of Stable Cascade pipelines using NVIDIA TensorRT.
+ """
+ def __init__(
+ self,
+ version='cascade',
+ pipeline_type=PIPELINE_TYPE.CASCADE_PRIOR,
+ latent_dim_scale=10.67,
+ lite=False,
+ **kwargs
+ ):
+ """
+ Initializes the Stable Cascade pipeline.
+
+ Args:
+ version (str):
+ The version of the pipeline. Should be one of [cascade]
+ pipeline_type (PIPELINE_TYPE):
+ Type of current pipeline.
+ latent_dim_scale (float):
+ Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
+ height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and
+ width=int(24*10.67)=256 in order to match the training conditions.
+ lite (bool):
+ Boolean indicating if the Lite Version of the Stage B and Stage C models is to be used
+ """
+ super().__init__(
+ version=version,
+ pipeline_type=pipeline_type,
+ **kwargs
+ )
+ self.config['clip_hidden_states'] = True
+ # from Diffusers: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py#L91C9-L91C41
+ self.latent_dim_scale = latent_dim_scale
+ self.lite = lite
+
+ def initializeModels(self, framework_model_dir, int8, fp8):
+ # Load text tokenizer(s)
+ self.tokenizer = make_tokenizer(self.version, self.pipeline_type, self.hf_token, framework_model_dir)
+
+ # Load pipeline models
+ models_args = {'version': self.version, 'pipeline': self.pipeline_type, 'device': self.device,
+ 'hf_token': self.hf_token, 'verbose': self.verbose, 'framework_model_dir': framework_model_dir,
+ 'max_batch_size': self.max_batch_size}
+
+ self.fp16 = False # TODO: enable FP16 mode for decoder model (requires strongly typed engine)
+ self.bf16 = True
+ if 'clip' in self.stages:
+ self.models['clip'] = CLIPWithProjModel(**models_args, fp16=self.fp16, bf16=self.bf16, output_hidden_states=self.config.get('clip_hidden_states', False), subfolder='text_encoder')
+
+ if 'unet' in self.stages:
+ self.models['unet'] = UNetCascadeModel(**models_args, fp16=self.fp16, bf16=self.bf16, lite=self.lite, do_classifier_free_guidance=self.do_classifier_free_guidance)
+
+ if 'vqgan' in self.stages:
+ self.models['vqgan'] = VQGANModel(**models_args, fp16=self.fp16, bf16=self.bf16, latent_dim_scale = self.latent_dim_scale)
+
+ def encode_prompt(self, prompt, negative_prompt, encoder='clip', pooled_outputs=False, output_hidden_states=False):
+ self.profile_start('clip', color='green')
+
+ tokenizer = self.tokenizer
+
+ def tokenize(prompt, output_hidden_states):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids.type(torch.int32).to(self.device)
+ attention_mask = text_inputs.attention_mask.type(torch.int32).to(self.device)
+
+ text_hidden_states = None
+ if self.torch_inference:
+ outputs = self.torch_models[encoder](text_input_ids, attention_mask=attention_mask, output_hidden_states=output_hidden_states)
+ text_embeddings = outputs[0].clone()
+ if output_hidden_states:
+ hidden_state_layer = -1
+ text_hidden_states = outputs['hidden_states'][hidden_state_layer].clone()
+ else:
+ # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
+ outputs = self.runEngine(encoder, {'input_ids': text_input_ids, 'attention_mask': attention_mask})
+ text_embeddings = outputs['text_embeddings'].clone()
+ if output_hidden_states:
+ text_hidden_states = outputs['hidden_states'].clone()
+
+ return text_embeddings, text_hidden_states
+
+ # Tokenize prompt
+ text_embeddings, text_hidden_states = tokenize(prompt, output_hidden_states)
+
+ if self.do_classifier_free_guidance:
+ # Tokenize negative prompt
+ uncond_embeddings, uncond_hidden_states = tokenize(negative_prompt, output_hidden_states)
+
+ # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
+ text_embeddings = torch.cat([text_embeddings, uncond_embeddings])
+
+ if pooled_outputs:
+ pooled_output = text_embeddings
+
+ if output_hidden_states:
+ text_embeddings = torch.cat([text_hidden_states, uncond_hidden_states]) if self.do_classifier_free_guidance else text_hidden_states
+
+ self.profile_stop('clip')
+ if pooled_outputs:
+ return text_embeddings, pooled_output
+ return text_embeddings
+
+ def denoise_latent(self,
+ latents,
+ pooled_embeddings,
+ text_embeddings=None,
+ image_embeds=None,
+ effnet=None,
+ denoiser='unet',
+ timesteps=None,
+ ):
+
+ do_autocast = False
+ with torch.autocast('cuda', enabled=do_autocast):
+ self.profile_start('denoise', color='blue')
+ for step_index, timestep in enumerate(timesteps):
+ # ratio input required for stable cascade prior
+ timestep_ratio = timestep.expand(latents.size(0)).to(latents.dtype)
+ # Expand the latents and timestep_ratio if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ timestep_ratio_input = torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio
+
+ params = {"sample": latent_model_input, "timestep_ratio": timestep_ratio_input, "clip_text_pooled": pooled_embeddings}
+ if text_embeddings is not None:
+ params.update({'clip_text': text_embeddings})
+ if image_embeds is not None:
+ params.update({'clip_img': image_embeds})
+ if effnet is not None:
+ params.update({'effnet': effnet})
+
+ # Predict the noise residual
+ if self.torch_inference:
+ noise_pred = self.torch_models[denoiser](**params)['sample']
+ else:
+ noise_pred = self.runEngine(denoiser, params)['latent']
+
+ # Perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # from diffusers (prepare_extra_step_kwargs)
+ extra_step_kwargs = {}
+ if "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()):
+ # TODO: configurable eta
+ eta = 0.0
+ extra_step_kwargs["eta"] = eta
+ if "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()):
+ extra_step_kwargs["generator"] = self.generator
+
+ latents = self.scheduler.step(noise_pred, timestep_ratio, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ latents = latents.to(dtype=torch.bfloat16 if self.bf16 else torch.float32)
+
+ self.profile_stop('denoise')
+ return latents
+
+ def decode_latent(self, latents):
+ self.profile_start('vqgan', color='red')
+ latents = self.models['vqgan'].scale_factor * latents
+ if self.torch_inference:
+ images = self.torch_models['vqgan'](latents)['sample']
+ else:
+ images = self.runEngine('vqgan', {'latent': latents})['images']
+ self.profile_stop('vqgan')
+ return images
+
+ def print_summary(self, denoising_steps, walltime_ms, batch_size):
+ print('|-----------------|--------------|')
+ print('| {:^15} | {:^12} |'.format('Module', 'Latency'))
+ print('|-----------------|--------------|')
+ print('| {:^15} | {:>9.2f} ms |'.format('CLIP', cudart.cudaEventElapsedTime(self.events['clip'][0], self.events['clip'][1])[1]))
+ print('| {:^15} | {:>9.2f} ms |'.format('UNet'+' x '+str(denoising_steps), cudart.cudaEventElapsedTime(self.events['denoise'][0], self.events['denoise'][1])[1]))
+ if 'vqgan' in self.stages:
+ print('| {:^15} | {:>9.2f} ms |'.format('VQGAN', cudart.cudaEventElapsedTime(self.events['vqgan'][0], self.events['vqgan'][1])[1]))
+ print('|-----------------|--------------|')
+ print('| {:^15} | {:>9.2f} ms |'.format('Pipeline', walltime_ms))
+ print('|-----------------|--------------|')
+ print('Throughput: {:.2f} image/s'.format(batch_size*1000./walltime_ms))
+
+ def infer(
+ self,
+ prompt,
+ negative_prompt,
+ image_height,
+ image_width,
+ image_embeddings=None,
+ warmup=False,
+ verbose=False,
+ save_image=True,
+ ):
+ """
+ Run the diffusion pipeline.
+
+ Args:
+ prompt (str):
+ The text prompt to guide image generation.
+ negative_prompt (str):
+ The prompt not to guide the image generation.
+ image_height (int):
+ Height (in pixels) of the image to be generated. Must be a multiple of 8.
+ image_width (int):
+ Width (in pixels) of the image to be generated. Must be a multiple of 8.
+ image_embeddings (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ Image Embeddings either extracted from an image or generated by a Prior Model.
+ warmup (bool):
+ Indicate if this is a warmup run.
+ verbose (bool):
+ Enable verbose logging.
+ save_image (bool):
+ Save the generated image (if applicable)
+ """
+ if self.pipeline_type.is_cascade_decoder():
+ assert image_embeddings is not None, "Image embeddings are required to run the decoder; received None."
+ assert len(prompt) == len(negative_prompt)
+ batch_size = len(prompt)
+
+ # Spatial dimensions of latent tensor (Stable Cascade uses a spatial compression factor of 42)
+ latent_height = image_height // 42
+ latent_width = image_width // 42
+
+ if image_embeddings is not None:
+ assert latent_height == image_embeddings.shape[-2]
+ assert latent_width == image_embeddings.shape[-1]
+
+ if self.generator and self.seed:
+ self.generator.manual_seed(self.seed)
+
+ num_inference_steps = self.denoising_steps
+
+ with torch.inference_mode(), trt.Runtime(TRT_LOGGER):
+ torch.cuda.synchronize()
+ e2e_tic = time.perf_counter()
+
+ denoise_kwargs = {}
+ # TODO: support custom timesteps
+ timesteps = None
+ if timesteps is not None:
+ if not ("timesteps" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())):
+ raise ValueError(
+ f"The current scheduler class {self.scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ self.scheduler.set_timesteps(timesteps=timesteps, device=self.device)
+ assert self.denoising_steps == len(self.scheduler.timesteps)
+ else:
+ self.scheduler.set_timesteps(self.denoising_steps, device=self.device)
+ timesteps = self.scheduler.timesteps.to(self.device)
+ if isinstance(self.scheduler, DDPMWuerstchenScheduler):
+ timesteps = timesteps[:-1]
+ denoise_kwargs.update({'timesteps': timesteps})
+
+ # Initialize latents
+ latents_dtype = torch.float16 if self.fp16 else torch.bfloat16 if self.bf16 else torch.float32
+ latents = self.initialize_latents(
+ batch_size=batch_size,
+ unet_channels=16 if self.pipeline_type.is_cascade_prior() else 4, # TODO: can we query "in_channels" from config
+ latent_height=latent_height if self.pipeline_type.is_cascade_prior() else int(latent_height * self.latent_dim_scale),
+ latent_width=latent_width if self.pipeline_type.is_cascade_prior() else int(latent_width * self.latent_dim_scale),
+ latents_dtype=latents_dtype
+ )
+
+ # CLIP text encoder(s)
+ text_embeddings, pooled_embeddings = self.encode_prompt(prompt, negative_prompt,
+ encoder='clip', pooled_outputs=True, output_hidden_states=True)
+
+ if self.pipeline_type.is_cascade_prior():
+ denoise_kwargs.update({'text_embeddings': text_embeddings})
+
+ # image embeds
+ image_embeds_pooled = torch.zeros(batch_size, 1, 768, device=self.device, dtype=latents_dtype)
+ image_embeds = (torch.cat([image_embeds_pooled, torch.zeros_like(image_embeds_pooled)]) if self.do_classifier_free_guidance else image_embeds_pooled)
+ denoise_kwargs.update({'image_embeds': image_embeds})
+ else:
+ effnet = (torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) if self.do_classifier_free_guidance else image_embeddings)
+ denoise_kwargs.update({'effnet': effnet})
+
+ # UNet denoiser
+ latents = self.denoise_latent(latents, pooled_embeddings.unsqueeze(1), denoiser='unet', **denoise_kwargs)
+
+ if not self.return_latents:
+ images = self.decode_latent(latents)
+
+ torch.cuda.synchronize()
+ e2e_toc = time.perf_counter()
+
+ walltime_ms = (e2e_toc - e2e_tic) * 1000.
+ if not warmup:
+ self.print_summary(num_inference_steps, walltime_ms, batch_size)
+ if not self.return_latents and save_image:
+ # post-process images (VQGAN decoder output assumed in [0, 1], so no [-1, 1] shift)
+ images = (images * 255).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
+ self.save_image(images, self.pipeline_type.name.lower(), prompt, self.seed)
+
+ return (latents, walltime_ms) if self.return_latents else (images, walltime_ms)
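
`denoise_latent()` above runs the conditional and unconditional batches in a single forward pass (`xB = 2`) and blends the two predictions with `guidance_scale`. A self-contained sketch of that blend with illustrative prior-stage shapes and an arbitrary guidance scale:

    import torch

    guidance_scale = 4.0
    # Two prompts run as one 4-sample batch: [conditional | unconditional] on dim 0,
    # matching the torch.cat ordering in encode_prompt() and denoise_latent().
    noise_pred = torch.randn(4, 16, 24, 24)
    noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
    guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    assert guided.shape == (2, 16, 24, 24)
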
diff --git a/demo/Diffusion/stable_diffusion_3_pipeline.py b/demo/Diffusion/stable_diffusion_3_pipeline.py
index ea691e96..8749354a 100644
--- a/demo/Diffusion/stable_diffusion_3_pipeline.py
+++ b/demo/Diffusion/stable_diffusion_3_pipeline.py
@@ -562,6 +562,8 @@ def infer(
num_inference_steps = int(self.denoising_steps * self.denoising_percentage)
self.print_summary(num_inference_steps, walltime_ms, batch_size)
if save_image:
+ # post-process images
+ images = ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
self.save_image(images, self.pipeline_type.name.lower(), prompt, self.seed)
return images, walltime_ms
diff --git a/demo/Diffusion/stable_diffusion_pipeline.py b/demo/Diffusion/stable_diffusion_pipeline.py
index c1316c66..a9500f9d 100644
--- a/demo/Diffusion/stable_diffusion_pipeline.py
+++ b/demo/Diffusion/stable_diffusion_pipeline.py
@@ -26,6 +26,7 @@
LCMScheduler, LMSDiscreteScheduler,
PNDMScheduler,
UniPCMultistepScheduler,
+ DDPMWuerstchenScheduler
)
from hashlib import md5
import inspect
@@ -65,6 +66,11 @@
filter_func,
quantize_lvl,
get_int8_config,
+ check_lora,
+ set_fmha,
+ generate_fp8_scales,
+ SD_FP8_FP16_DEFAULT_CONFIG,
+ SD_FP8_FP32_DEFAULT_CONFIG,
)
class StableDiffusionPipeline:
@@ -171,6 +177,10 @@ def __init__(
self.stages = ['clip2', 'unetxl', 'vae']
elif self.pipeline_type.is_img2vid():
self.stages = ['clip-vis', 'clip-imgfe', 'unet-temp', 'vae-temp']
+ elif self.pipeline_type.is_cascade_prior():
+ self.stages = ['clip', 'unet']
+ elif self.pipeline_type.is_cascade_decoder():
+ self.stages = ['clip', 'unet', 'vqgan']
else:
raise ValueError(f"Unsupported pipeline {self.pipeline_type.name}.")
self.return_latents = return_latents
@@ -186,7 +196,8 @@ def __init__(
'2.1': 'DDIM',
'xl-1.0' : 'Euler',
'xl-turbo': 'EulerA',
- 'svd-xt-1.1': 'Euler'
+ 'svd-xt-1.1': 'Euler',
+ 'cascade': 'DDPMWuerstchen'
}
if not scheduler:
@@ -212,6 +223,8 @@ def makeScheduler(cls, subfolder="scheduler", **kwargs):
self.scheduler = makeScheduler(PNDMScheduler)
elif scheduler == "UniPC":
self.scheduler = makeScheduler(UniPCMultistepScheduler)
+ elif scheduler == "DDPMWuerstchen":
+ self.scheduler = makeScheduler(DDPMWuerstchenScheduler)
else:
raise ValueError(f"Unsupported scheduler {scheduler}. Should be either DDIM, DDPM, EulerA, Euler, LCM, LMSD, PNDM, or UniPC.")
@@ -256,7 +269,7 @@ def loadResources(self, image_height, image_width, batch_size, seed):
self.generator = torch.Generator(device="cuda").manual_seed(seed)
# Create CUDA events and stream
- for stage in ['clip', 'denoise', 'vae', 'vae_encoder']:
+ for stage in ['clip', 'denoise', 'vae', 'vae_encoder', 'vqgan']:
self.events[stage] = [cudart.cudaEventCreate()[1], cudart.cudaEventCreate()[1]]
self.stream = cudart.cudaStreamCreate()[1]
@@ -310,6 +323,49 @@ def getStateDictPath(self, model_name, onnx_dir, suffix=''):
os.makedirs(onnx_model_dir, exist_ok=True)
return os.path.join(onnx_model_dir, 'state_dict.pt')
+ def initializeModels(self, framework_model_dir, int8, fp8):
+ # Load text tokenizer(s)
+ if not self.pipeline_type.is_sd_xl_refiner():
+ self.tokenizer = make_tokenizer(self.version, self.pipeline_type, self.hf_token, framework_model_dir)
+ if self.pipeline_type.is_sd_xl():
+ self.tokenizer2 = make_tokenizer(self.version, self.pipeline_type, self.hf_token, framework_model_dir, subfolder='tokenizer_2')
+
+ # Load pipeline models
+ models_args = {'version': self.version, 'pipeline': self.pipeline_type, 'device': self.device,
+ 'hf_token': self.hf_token, 'verbose': self.verbose, 'framework_model_dir': framework_model_dir,
+ 'max_batch_size': self.max_batch_size}
+
+ if 'clip' in self.stages:
+ subfolder = 'text_encoder'
+ self.models['clip'] = CLIPModel(**models_args, fp16=True, embedding_dim=get_clip_embedding_dim(self.version, self.pipeline_type), output_hidden_states=self.config.get('clip_hidden_states', False), subfolder=subfolder)
+
+ if 'clip2' in self.stages:
+ subfolder = 'text_encoder_2'
+ self.models['clip2'] = CLIPWithProjModel(**models_args, fp16=True, output_hidden_states=self.config.get('clip_hidden_states', False), subfolder=subfolder)
+
+ lora_dict, lora_alphas = (None, None)
+ if 'unet' in self.stages:
+ if self.lora_loader:
+ lora_dict, lora_alphas = self.lora_loader.get_dicts('unet')
+ assert len(lora_dict) == len(self.lora_scales)
+ self.models['unet'] = UNetModel(**models_args, fp16=True, int8=int8, fp8=fp8, controlnets=self.controlnets,
+ lora_scales=self.lora_scales, lora_dict=lora_dict, lora_alphas=lora_alphas, do_classifier_free_guidance=self.do_classifier_free_guidance)
+
+ if 'unetxl' in self.stages:
+ if not self.pipeline_type.is_sd_xl_refiner() and self.lora_loader:
+ lora_dict, lora_alphas = self.lora_loader.get_dicts('unet')
+ assert len(lora_dict) == len(self.lora_scales)
+ self.models['unetxl'] = UNetXLModel(**models_args, fp16=True, int8=int8, fp8=fp8,
+ lora_scales=self.lora_scales, lora_dict=lora_dict, lora_alphas=lora_alphas, do_classifier_free_guidance=self.do_classifier_free_guidance)
+
+ vae_fp16 = not self.pipeline_type.is_sd_xl()
+
+ if 'vae' in self.stages:
+ self.models['vae'] = VAEModel(**models_args, fp16=vae_fp16)
+
+ if 'vae_encoder' in self.stages:
+ self.models['vae_encoder'] = VAEEncoderModel(**models_args, fp16=vae_fp16)
+
def loadEngines(
self,
engine_dir,
@@ -325,6 +381,7 @@ def loadEngines(
enable_all_tactics=False,
timing_cache=None,
int8=False,
+ fp8=False,
quantization_level=2.5,
quantization_percentile=1.0,
quantization_alpha=0.8,
@@ -361,7 +418,9 @@ def loadEngines(
timing_cache (str):
Path to the timing cache to speed up TensorRT build.
int8 (bool):
- Whether to quantize to int8 format or not (SDXL only).
+ Whether to quantize to int8 format or not (SDXL, SD1.5 and SD2.1 only).
+ fp8 (bool):
+ Whether to quantize to fp8 format or not (SDXL, SD1.5 and SD2.1 only).
quantization_level (float):
Controls which layers to quantize. 1: CNN, 2: CNN+FFN, 2.5: CNN+FFN+QKV, 3: CNN+FC
quantization_percentile (float):
@@ -382,47 +441,8 @@ def loadEngines(
print(f"[I] Create directory: {directory}")
pathlib.Path(directory).mkdir(parents=True)
- # Load text tokenizer(s)
- if not self.pipeline_type.is_sd_xl_refiner():
- self.tokenizer = make_tokenizer(self.version, self.pipeline_type, self.hf_token, framework_model_dir)
- if self.pipeline_type.is_sd_xl():
- self.tokenizer2 = make_tokenizer(self.version, self.pipeline_type, self.hf_token, framework_model_dir, subfolder='tokenizer_2')
-
- # Load pipeline models
- models_args = {'version': self.version, 'pipeline': self.pipeline_type, 'device': self.device,
- 'hf_token': self.hf_token, 'verbose': self.verbose, 'framework_model_dir': framework_model_dir,
- 'max_batch_size': self.max_batch_size}
-
- if 'clip' in self.stages:
- subfolder = 'text_encoder'
- self.models['clip'] = CLIPModel(**models_args, fp16=True, embedding_dim=get_clip_embedding_dim(self.version, self.pipeline_type), output_hidden_states=self.config.get('clip_hidden_states', False), subfolder=subfolder)
-
- if 'clip2' in self.stages:
- subfolder = 'text_encoder_2'
- self.models['clip2'] = CLIPWithProjModel(**models_args, fp16=True, output_hidden_states=self.config.get('clip_hidden_states', False), subfolder=subfolder)
-
- lora_dict, lora_alphas = (None, None)
- if 'unet' in self.stages:
- if self.lora_loader:
- lora_dict, lora_alphas = self.lora_loader.get_dicts('unet')
- assert len(lora_dict) == len(self.lora_scales)
- self.models['unet'] = UNetModel(**models_args, fp16=True, controlnets=self.controlnets,
- lora_scales=self.lora_scales, lora_dict=lora_dict, lora_alphas=lora_alphas, do_classifier_free_guidance=self.do_classifier_free_guidance)
-
- if 'unetxl' in self.stages:
- if not self.pipeline_type.is_sd_xl_refiner() and self.lora_loader:
- lora_dict, lora_alphas = self.lora_loader.get_dicts('unet')
- assert len(lora_dict) == len(self.lora_scales)
- self.models['unetxl'] = UNetXLModel(**models_args, fp16=True,
- lora_scales=self.lora_scales, lora_dict=lora_dict, lora_alphas=lora_alphas, do_classifier_free_guidance=self.do_classifier_free_guidance)
-
- vae_fp16 = not self.pipeline_type.is_sd_xl()
-
- if 'vae' in self.stages:
- self.models['vae'] = VAEModel(**models_args, fp16=vae_fp16)
-
- if 'vae_encoder' in self.stages:
- self.models['vae_encoder'] = VAEEncoderModel(**models_args, fp16=vae_fp16)
+ # Initialize models
+ self.initializeModels(framework_model_dir, int8, fp8)
# Configure pipeline models to load
model_names = self.models.keys()
@@ -434,10 +454,17 @@ def loadEngines(
torch_fallback = dict(zip(model_names, [self.torch_inference for model_name in model_names]))
model_suffix = dict(zip(model_names, [lora_suffix if do_lora_merge[model_name] else '' for model_name in model_names]))
use_int8 = dict.fromkeys(model_names, False)
+ use_fp8 = dict.fromkeys(model_names, False)
if int8:
- assert self.pipeline_type.is_sd_xl_base(), "int8 quantization only supported for SDXL pipeline"
- use_int8['unetxl'] = True
- model_suffix['unetxl'] += f"-int8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}"
+ assert self.pipeline_type.is_sd_xl_base() or self.version in ["1.5", "2.1", "2.1-base"], "int8 quantization only supported for SDXL, SD1.5 and SD2.1 pipelines"
+ model_name = 'unetxl' if self.pipeline_type.is_sd_xl() else 'unet'
+ use_int8[model_name] = True
+ model_suffix[model_name] += f"-int8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}"
+ elif fp8:
+ assert self.pipeline_type.is_sd_xl() or self.version in ["1.5", "2.1", "2.1-base"], "fp8 quantization only supported for SDXL, SD1.5 and SD2.1 pipelines"
+ model_name = 'unetxl' if self.pipeline_type.is_sd_xl() else 'unet'
+ use_fp8[model_name] = True
+ model_suffix[model_name] += f"-fp8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}"
onnx_path = dict(zip(model_names, [self.getOnnxPath(model_name, onnx_dir, opt=False, suffix=model_suffix[model_name]) for model_name in model_names]))
onnx_opt_path = dict(zip(model_names, [self.getOnnxPath(model_name, onnx_dir, suffix=model_suffix[model_name]) for model_name in model_names]))
engine_path = dict(zip(model_names, [self.getEnginePath(model_name, engine_dir, do_engine_refit[model_name], suffix=model_suffix[model_name]) for model_name in model_names]))
@@ -451,30 +478,25 @@ def loadEngines(
do_export_weights_map = weights_map_path[model_name] and not os.path.exists(weights_map_path[model_name])
if do_export_onnx or do_export_weights_map:
# Non-quantized ONNX export
- if not use_int8[model_name]:
+ if not use_int8[model_name] and not use_fp8[model_name]:
obj.export_onnx(onnx_path[model_name], onnx_opt_path[model_name], onnx_opset, opt_image_height, opt_image_width, enable_lora_merge=do_lora_merge[model_name], static_shape=static_shape)
else:
+ pipeline = obj.get_pipeline()
+ model = pipeline.unet
+ if use_fp8[model_name] and quantization_level == 4.0:
+ set_fmha(model)
+
state_dict_path = self.getStateDictPath(model_name, onnx_dir, suffix=model_suffix[model_name])
if not os.path.exists(state_dict_path):
print(f"[I] Calibrated weights not found, generating {state_dict_path}")
- pipeline = obj.get_pipeline()
- model = pipeline.unet
calibration_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'calibration-prompts.txt')
calibration_prompts = load_calib_prompts(calib_batch_size, calibration_file)
# TODO check size > calibration_size
- quant_config = get_int8_config(
- model,
- quantization_level,
- quantization_alpha,
- quantization_percentile,
- self.denoising_steps
- )
-
- def do_calibrate(base, calibration_prompts, **kwargs):
+ def do_calibrate(pipeline, calibration_prompts, **kwargs):
for i_th, prompts in enumerate(calibration_prompts):
if i_th >= kwargs["calib_size"]:
return
- base(
+ pipeline(
prompt=prompts,
num_inference_steps=kwargs["n_steps"],
negative_prompt=[
@@ -482,35 +504,43 @@ def do_calibrate(base, calibration_prompts, **kwargs):
]
* len(prompts),
).images
-
- def calibration_loop(unet):
- pipeline.model = unet
+
+ def forward_loop(model):
+ pipeline.unet = model
do_calibrate(
- base=pipeline,
+ pipeline=pipeline,
calibration_prompts=calibration_prompts,
calib_size=calibration_size // calib_batch_size,
n_steps=self.denoising_steps,
)
-
- print(f"[I] Performing int8 calibration for {calibration_size} steps.")
- mtq.quantize(model, quant_config, forward_loop=calibration_loop)
+
+ print(f"[I] Performing calibration for {calibration_size} steps.")
+ if use_int8[model_name]:
+ quant_config = get_int8_config(
+ model,
+ quantization_level,
+ quantization_alpha,
+ quantization_percentile,
+ self.denoising_steps
+ )
+ elif use_fp8[model_name]:
+ check_lora(model)
+ quant_config = SD_FP8_FP32_DEFAULT_CONFIG if self.version == "2.1" else SD_FP8_FP16_DEFAULT_CONFIG
+ mtq.quantize(model, quant_config, forward_loop)
mto.save(model, state_dict_path)
+ else:
+ mto.restore(model, state_dict_path)
print(f"[I] Generating quantized ONNX model: {onnx_opt_path[model_name]}")
if not os.path.exists(onnx_path[model_name]):
- model = obj.get_model()
- mto.restore(model, state_dict_path)
- quantize_lvl(model, quantization_level)
+ quantize_lvl(model, quantization_level)
mtq.disable_quantizer(model, filter_func)
- model.to(torch.float32).to("cpu") # QDQ needs to be in FP32
- # WAR to enable ONNX export of quantized UNet
- obj.device="cpu"
- obj.fp16=False
+ if use_fp8[model_name]:
+ generate_fp8_scales(model)
else:
- model = None
- obj.export_onnx(onnx_path[model_name], onnx_opt_path[model_name], onnx_opset, opt_image_height, opt_image_width, custom_model=model)
- obj.fp16=True # Part of WAR, UNET obj.fp16 defaults to True so it is safe to reset this way
-
+ model = None
+ obj.export_onnx(onnx_path[model_name], onnx_opt_path[model_name], onnx_opset, opt_image_height, opt_image_width, custom_model=model, static_shape=static_shape)
+
# FIXME do_export_weights_map needs ONNX graph
if do_export_weights_map:
print(f"[I] Saving weights map: {weights_map_path[model_name]}")
@@ -523,14 +553,20 @@ def calibration_loop(unet):
engine = Engine(engine_path[model_name])
if not os.path.exists(engine_path[model_name]):
update_output_names = obj.get_output_names() + obj.extra_output_names if obj.extra_output_names else None
+ fp16amp = obj.fp16 if not use_fp8[model_name] else False
+ bf16amp = obj.bf16 if not use_fp8[model_name] else False
+ strongly_typed = use_fp8[model_name]
extra_build_args = {'verbose': self.verbose}
if use_int8[model_name]:
extra_build_args['int8'] = True
extra_build_args['precision_constraints'] = 'prefer'
extra_build_args['builder_optimization_level'] = 4
- fp16amp = obj.fp16
+ elif use_fp8[model_name]:
+ extra_build_args['builder_optimization_level'] = 4
engine.build(onnx_opt_path[model_name],
+ strongly_typed=strongly_typed,
fp16=fp16amp,
+ bf16=bf16amp,
input_profile=obj.get_input_profile(
opt_batch_size, opt_image_height, opt_image_width,
static_batch=static_batch, static_shape=static_shape
@@ -588,8 +624,8 @@ def runEngine(self, model_name, feed_dict):
engine = self.engine[model_name]
return engine.infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph)
- def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width):
- latents_dtype = torch.float32 # text_embeddings.dtype
+ def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width, latents_dtype=torch.float32):
latents_shape = (batch_size, unet_channels, latent_height, latent_width)
latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator)
# Scale the initial noise by the standard deviation required by the scheduler
@@ -1002,6 +1038,8 @@ def _get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype,
if not warmup:
self.print_summary(num_inference_steps, walltime_ms, batch_size)
if not self.return_latents and save_image:
+ # post-process images
+ images = ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
self.save_image(images, self.pipeline_type.name.lower(), prompt, self.seed)
return (latents, walltime_ms) if self.return_latents else (images, walltime_ms)
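
For reference, the post-processing line moved into the pipelines maps decoder output from [-1, 1] to uint8 [0, 255]; the cascade pipeline omits the +1 shift because its VQGAN output is already non-negative. A quick check of the arithmetic:

    import torch

    images = torch.tensor([[-1.0, 0.0, 1.0]])
    u8 = ((images + 1) * 255 / 2).clamp(0, 255).round().type(torch.uint8)
    print(u8)  # tensor([[  0, 128, 255]], dtype=torch.uint8)
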
diff --git a/demo/Diffusion/utilities.py b/demo/Diffusion/utilities.py
index 6dece14f..911139b6 100644
--- a/demo/Diffusion/utilities.py
+++ b/demo/Diffusion/utilities.py
@@ -19,7 +19,6 @@
from collections import OrderedDict
from cuda import cudart
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
-from diffusers.utils.torch_utils import randn_tensor
from enum import Enum, auto
import gc
from io import BytesIO
@@ -50,26 +49,17 @@
def GiB(val):
return val * 1 << 30
-# Map of numpy dtype -> torch dtype
-numpy_to_torch_dtype_dict = {
- np.uint8 : torch.uint8,
- np.int8 : torch.int8,
- np.int16 : torch.int16,
- np.int32 : torch.int32,
- np.int64 : torch.int64,
- np.float16 : torch.float16,
- np.float32 : torch.float32,
- np.float64 : torch.float64,
- np.complex64 : torch.complex64,
- np.complex128 : torch.complex128
+# Map of TensorRT dtype -> torch dtype
+trt_to_torch_dtype_dict = {
+ trt.DataType.BOOL : torch.bool,
+ trt.DataType.UINT8 : torch.uint8,
+ trt.DataType.INT8 : torch.int8,
+ trt.DataType.INT32 : torch.int32,
+ trt.DataType.INT64 : torch.int64,
+ trt.DataType.HALF : torch.float16,
+ trt.DataType.FLOAT : torch.float32,
+ trt.DataType.BF16 : torch.bfloat16
}
-if np.version.full_version >= "1.24.0":
- numpy_to_torch_dtype_dict[np.bool_] = torch.bool
-else:
- numpy_to_torch_dtype_dict[np.bool] = torch.bool
-
-# Map of torch dtype -> numpy dtype
-torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
def unload_model(model):
if model:
@@ -160,6 +150,8 @@ class PIPELINE_TYPE(Enum):
CONTROLNET = auto()
XL_BASE = auto()
XL_REFINER = auto()
+ CASCADE_PRIOR = auto()
+ CASCADE_DECODER = auto()
def is_txt2img(self):
return self == self.TXT2IMG
@@ -185,6 +177,15 @@ def is_sd_xl_refiner(self):
def is_sd_xl(self):
return self.is_sd_xl_base() or self.is_sd_xl_refiner()
+ def is_cascade_prior(self):
+ return self == self.CASCADE_PRIOR
+
+ def is_cascade_decoder(self):
+ return self == self.CASCADE_DECODER
+
+ def is_cascade(self):
+ return self.is_cascade_prior() or self.is_cascade_decoder()
+
class Engine():
def __init__(
self,
@@ -236,9 +237,12 @@ def refit(self, refit_weights, is_fp16):
def build(self,
onnx_path,
+ strongly_typed=False,
fp16=True,
+ bf16=False,
tf32=False,
int8=False,
+ fp8=False,
input_profile=None,
enable_refit=False,
enable_all_tactics=False,
@@ -261,7 +265,11 @@ def build(self,
flags = []
if native_instancenorm:
flags.append(trt.OnnxParserFlag.NATIVE_INSTANCENORM)
- network = network_from_onnx_path(onnx_path, flags=flags)
+ network = network_from_onnx_path(
+ onnx_path,
+ flags=flags,
+ strongly_typed=strongly_typed
+ )
if update_output_names:
print(f"Updating network outputs to {update_output_names}")
network = ModifyNetworkOutputs(network, update_output_names)
@@ -269,8 +277,10 @@ def build(self,
engine = engine_from_network(
network,
config=CreateConfig(fp16=fp16,
+ bf16=bf16,
tf32=tf32,
int8=int8,
+ fp8=fp8,
refittable=enable_refit,
profiles=[p],
load_timing_cache=timing_cache,
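
`Engine.build()` now forwards `strongly_typed`, `bf16`, and `fp8` through to polygraphy. For FP8 engines the network is strongly typed, so precision comes from the Q/DQ ONNX graph itself and the fp16/bf16 builder flags stay off; BF16 models such as Stable Cascade pass `bf16=True` instead. A minimal sketch of the two configurations, assuming these `CreateConfig` keywords are available in the polygraphy 0.49.9 pin above:

    from polygraphy.backend.trt import CreateConfig

    # FP8: strongly typed network, so no weak-typing precision flags; the network
    # itself would be loaded with network_from_onnx_path(onnx_path, strongly_typed=True).
    fp8_engine_config = CreateConfig(fp16=False, bf16=False, builder_optimization_level=4)
    # BF16 (e.g. Stable Cascade decoder):
    bf16_engine_config = CreateConfig(fp16=False, bf16=True)
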
@@ -306,10 +316,10 @@ def allocate_buffers(self, shape_dict=None, device='cuda'):
shape = shape_dict[name]
else:
shape = self.engine.get_tensor_shape(name)
- dtype = trt.nptype(self.engine.get_tensor_dtype(name))
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self.context.set_input_shape(name, shape)
- tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
+ dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(name)]
+ tensor = torch.empty(tuple(shape), dtype=dtype, device=device)
self.tensors[name] = tensor
@@ -350,7 +360,6 @@ def save_image(images, image_path_dir, image_name_prefix, image_name_suffix):
"""
Save the generated images to png files.
"""
- images = ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
for i in range(images.shape[0]):
image_path = os.path.join(image_path_dir, image_name_prefix+str(i+1)+'-'+str(random.randint(1000,9999))+'-'+image_name_suffix+'.png')
print(f"Saving image {i+1} / {images.shape[0]} to: {image_path}")
@@ -575,7 +584,7 @@ def append(self, item):
def add_arguments(parser):
# Stable Diffusion configuration
- parser.add_argument('--version', type=str, default="1.5", choices=["1.4", "1.5", "dreamshaper-7", "2.0-base", "2.0", "2.1-base", "2.1", "xl-1.0", "xl-turbo"], help="Version of Stable Diffusion")
+ parser.add_argument('--version', type=str, default="1.5", choices=["1.4", "1.5", "dreamshaper-7", "2.0-base", "2.0", "2.1-base", "2.1", "xl-1.0", "xl-turbo", "svd-xt-1.1", "sd3", "cascade"], help="Version of Stable Diffusion")
parser.add_argument('prompt', nargs = '*', help="Text prompt(s) to guide image generation")
parser.add_argument('--negative-prompt', nargs = '*', default=[''], help="The negative prompt(s) to guide the image generation.")
parser.add_argument('--batch-size', type=int, default=1, choices=[1, 2, 4], help="Batch size (repeat prompt)")
@@ -598,7 +607,8 @@ def add_arguments(parser):
# TensorRT engine build
parser.add_argument('--engine-dir', default='engine', help="Output directory for TensorRT engines")
parser.add_argument('--int8', action='store_true', help="Apply int8 quantization.")
- parser.add_argument('--quantization-level', type=float, default=2.5, choices=[1.0, 2.0, 2.5, 3.0], help="int8/fp8 quantization level, 1: CNN, 2: CNN+FFN, 2.5: CNN+FFN+QKV, 3: CNN+FC")
+ parser.add_argument('--fp8', action='store_true', help="Apply fp8 quantization.")
+ parser.add_argument('--quantization-level', type=float, default=0.0, choices=[0.0, 1.0, 2.0, 2.5, 3.0, 4.0], help="int8/fp8 quantization level, 1: CNN, 2: CNN + FFN, 2.5: CNN + FFN + QKV, 3: CNN + Almost all Linear (Including FFN, QKV, Proj and others), 4: CNN + Almost all Linear + fMHA, 0: Default to 2.5 for int8 and 4.0 for fp8.")
parser.add_argument('--build-static-batch', action='store_true', help="Build TensorRT engines with fixed batch size.")
parser.add_argument('--build-dynamic-shape', action='store_true', help="Build TensorRT engines with dynamic image shapes.")
parser.add_argument('--build-enable-refit', action='store_true', help="Enable Refit option in TensorRT engines during build.")
@@ -628,8 +638,28 @@ def process_pipeline_args(args):
if args.use_cuda_graph and (not args.build_static_batch or args.build_dynamic_shape):
raise ValueError(f"Using CUDA graph requires static dimensions. Enable `--build-static-batch` and do not specify `--build-dynamic-shape`")
- if args.int8 and not args.version.startswith('xl'):
- raise ValueError(f"int8 quantization only supported for SDXL pipeline.")
+ if args.int8 and not any(args.version.startswith(prefix) for prefix in ['xl', '1.5', '2.1']):
+ raise ValueError(f"int8 quantization is only supported for SDXL, SD1.5 and SD2.1 pipelines.")
+
+ if args.fp8 and not any(args.version.startswith(prefix) for prefix in ['xl', '1.5', '2.1']):
+ raise ValueError(f"fp8 quantization is only supported for SDXL, SD1.5 and SD2.1 pipelines.")
+
+ if args.fp8 and args.int8:
+ raise ValueError(f"Cannot apply both int8 and fp8 quantization, please choose only one.")
+
+ if args.fp8:
+ device_info = torch.cuda.get_device_properties(0)
+ version = device_info.major * 10 + device_info.minor
+ if version < 90: # if Ada or older
+ raise ValueError(f"Cannot apply FP8 quantization for GPU with compute capability {version / 10.0}. Only Hopper is supported.")
+
+ if args.quantization_level == 0.0:
+ if args.fp8:
+ args.quantization_level = 4.0
+ print("The default quantization level has been set to 4.0 for FP8.")
+ elif args.int8:
+ args.quantization_level = 2.5
+ print("The default quantization level has been set to 2.5 for INT8.")
if args.lora_scale:
for lora_scale in (lora_scale for lora_scale in args.lora_scale if not 0 <= lora_scale <= 1):
@@ -663,6 +693,7 @@ def process_pipeline_args(args):
'enable_refit': args.build_enable_refit,
'timing_cache': args.timing_cache,
'int8': args.int8,
+ 'fp8': args.fp8,
'quantization_level': args.quantization_level,
}
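
The FP8 gate above packs the CUDA compute capability into a two-digit integer. A minimal runnable sketch of the same check, using only the `torch.cuda.get_device_properties` call the diff itself relies on:

```python
import torch

# Same packing as process_pipeline_args: major*10 + minor.
# H100 (9.0) -> 90 passes; Ada L40S (8.9) -> 89 and A100 (8.0) -> 80 are rejected.
props = torch.cuda.get_device_properties(0)
cc = props.major * 10 + props.minor
print(f"compute capability {cc / 10.0}: FP8 {'supported' if cc >= 90 else 'unsupported'}")
```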
diff --git a/demo/Diffusion/utils_modelopt.py b/demo/Diffusion/utils_modelopt.py
index b8735bfa..fbaaed97 100644
--- a/demo/Diffusion/utils_modelopt.py
+++ b/demo/Diffusion/utils_modelopt.py
@@ -17,12 +17,21 @@
import re
import torch
+import numpy as np
+import onnx
+import onnx_graphsurgeon as gs
from modelopt.torch.quantization import utils as quant_utils
from modelopt.torch.quantization.calib.max import MaxCalibrator
-from diffusers.models.attention_processor import Attention
+from diffusers.models.attention_processor import Attention, AttnProcessor
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+USE_PEFT = True
+try:
+ from peft.tuners.lora.layer import Conv2d as PEFTLoRAConv2d
+ from peft.tuners.lora.layer import Linear as PEFTLoRALinear
+except ModuleNotFoundError:
+ USE_PEFT = False
class PercentileCalibrator(MaxCalibrator):
def __init__(self, num_bits=8, axis=None, unsigned=False, track_amax=False, **kwargs):
@@ -94,12 +103,11 @@ def __repr__(self):
def filter_func(name):
pattern = re.compile(
- r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding).*"
+ r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|proj_out).*"
)
return pattern.match(name) is not None
-
-def quantize_lvl(unet, quant_level=2.5):
+def quantize_lvl(unet, quant_level=2.5, linear_only=False):
"""
We should disable the unwanted quantizer when exporting the onnx
Because in the current modelopt setting, it will load the quantizer amax for all the layers even
@@ -107,8 +115,12 @@ def quantize_lvl(unet, quant_level=2.5):
"""
for name, module in unet.named_modules():
if isinstance(module, (torch.nn.Conv2d, LoRACompatibleConv)):
- module.input_quantizer.enable()
- module.weight_quantizer.enable()
+ if linear_only:
+ module.input_quantizer.disable()
+ module.weight_quantizer.disable()
+ else:
+ module.input_quantizer.enable()
+ module.weight_quantizer.enable()
elif isinstance(module, (torch.nn.Linear, LoRACompatibleLinear)):
if (
(quant_level >= 2 and "ff.net" in name)
@@ -121,18 +133,17 @@ def quantize_lvl(unet, quant_level=2.5):
module.input_quantizer.disable()
module.weight_quantizer.disable()
elif isinstance(module, Attention):
- if quant_level >= 4:
+ head_size = int(module.inner_dim / module.heads)
+ if quant_level >= 4 and head_size % 16 == 0:
module.q_bmm_quantizer.enable()
module.k_bmm_quantizer.enable()
module.v_bmm_quantizer.enable()
module.softmax_quantizer.enable()
- module.bmm2_output_quantizer.enable()
else:
module.q_bmm_quantizer.disable()
module.k_bmm_quantizer.disable()
module.v_bmm_quantizer.disable()
module.softmax_quantizer.disable()
- module.bmm2_output_quantizer.disable()
def get_int8_config(
model,
@@ -185,3 +196,279 @@ def get_int8_config(
}
return quant_config
+SD_FP8_FP16_DEFAULT_CONFIG = {
+ "quant_cfg": {
+ "*weight_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "Half"},
+ "*input_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "Half"},
+ "*output_quantizer": {"enable": False},
+ "*q_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "Half"},
+ "*k_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "Half"},
+ "*v_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "Half"},
+ "*softmax_quantizer": {
+ "num_bits": (4, 3),
+ "axis": None,
+ "trt_high_precision_dtype": "Half",
+ },
+ "default": {"enable": False},
+ },
+ "algorithm": "max",
+}
+
+SD_FP8_FP32_DEFAULT_CONFIG = {
+ "quant_cfg": {
+ "*weight_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "Float"},
+ "*input_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "Float"},
+ "*output_quantizer": {"enable": False},
+ "*q_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "Float"},
+ "*k_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "Float"},
+ "*v_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "Float"},
+ "*softmax_quantizer": {
+ "num_bits": (4, 3),
+ "axis": None,
+ "trt_high_precision_dtype": "Float",
+ },
+ "default": {"enable": False},
+ },
+ "algorithm": "max",
+}
+
+def set_fmha(unet):
+ for name, module in unet.named_modules():
+ if isinstance(module, Attention):
+ module.set_processor(AttnProcessor())
+
+def check_lora(unet):
+ for name, module in unet.named_modules():
+ if isinstance(module, (LoRACompatibleConv, LoRACompatibleLinear)):
+ assert (
+ module.lora_layer is None
+ ), f"To quantize {name}, LoRA layer should be fused/merged. Please fuse the LoRA layer before quantization."
+ elif USE_PEFT and isinstance(module, (PEFTLoRAConv2d, PEFTLoRALinear)):
+ assert (
+ module.merged
+ ), f"To quantize {name}, LoRA layer should be fused/merged. Please fuse the LoRA layer before quantization."
+
+def generate_fp8_scales(unet):
+ # temporary solution due to a known bug in torch.onnx._dynamo_export
+ for _, module in unet.named_modules():
+ if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)) and (
+ hasattr(module.input_quantizer, "_amax") and module.input_quantizer is not None
+ ):
+ module.input_quantizer._num_bits = 8
+ module.weight_quantizer._num_bits = 8
+ module.input_quantizer._amax = module.input_quantizer._amax * (127 / 448.0)
+ module.weight_quantizer._amax = module.weight_quantizer._amax * (127 / 448.0)
+ elif isinstance(module, Attention) and (
+ hasattr(module.q_bmm_quantizer, "_amax") and module.q_bmm_quantizer is not None
+ ):
+ module.q_bmm_quantizer._num_bits = 8
+ module.q_bmm_quantizer._amax = module.q_bmm_quantizer._amax * (127 / 448.0)
+ module.k_bmm_quantizer._num_bits = 8
+ module.k_bmm_quantizer._amax = module.k_bmm_quantizer._amax * (127 / 448.0)
+ module.v_bmm_quantizer._num_bits = 8
+ module.v_bmm_quantizer._amax = module.v_bmm_quantizer._amax * (127 / 448.0)
+ module.softmax_quantizer._num_bits = 8
+ module.softmax_quantizer._amax = module.softmax_quantizer._amax * (127 / 448.0)
+
+def get_parent_nodes(node):
+ """
+ Returns a list of the producer nodes of the given node's inputs.
+ """
+ parents = []
+ for tensor in node.inputs:
+ # If the tensor is neither a constant nor a graph input, it has exactly
+ # one producer node, and that producer is a parent of `node`.
+ if len(tensor.inputs) == 1:
+ parents.append(tensor.inputs[0])
+ return parents
+
+def get_child_nodes(node):
+ """
+ Returns a list of the consumer nodes of the given node's outputs.
+ """
+ children = []
+ for tensor in node.outputs:
+ for consumer in tensor.outputs: # Traverse all consumers of the tensor
+ children.append(consumer)
+ return children
+
+def has_path_type(node, graph, path_type, is_forward, wild_card_types, path_nodes):
+ """
+ Return True if a node chain matching `path_type` starts at `node`, collecting the matched nodes in `path_nodes`.
+ """
+ if not path_type:
+ # All types matched
+ return True
+
+ # Bail out if the node type matches neither the expected path type nor a wildcard type
+ node_type = node.op
+ is_match = node_type == path_type[0]
+ is_wild_match = node_type in wild_card_types
+ if not is_match and not is_wild_match:
+ return False
+
+ if is_match:
+ path_nodes.append(node)
+ next_path_type = path_type[1:]
+ else:
+ next_path_type = path_type[:]
+
+ if is_forward:
+ next_level_nodes = get_child_nodes(node)
+ else:
+ next_level_nodes = get_parent_nodes(node)
+
+ # Check if any child (forward path) or parent (backward path) can match the remaining path types
+ for next_node in next_level_nodes:
+ sub_path = []
+ if has_path_type(next_node, graph, next_path_type, is_forward, wild_card_types, sub_path):
+ path_nodes.extend(sub_path)
+ return True
+
+ # The path matches if no expected types remain
+ return not next_path_type
+
+def insert_cast(graph, input_tensor, attrs):
+ """
+ Create a cast layer using tensor as input.
+ """
+ output_tensor = gs.Variable(name=f"{input_tensor.name}/Cast_output", dtype=attrs["to"])
+ next_node_list = input_tensor.outputs.copy()
+ graph.layer(
+ op="Cast",
+ name=f"{input_tensor.name}/Cast",
+ inputs=[input_tensor],
+ outputs=[output_tensor],
+ attrs=attrs,
+ )
+
+ # use cast output as input to next node
+ for next_node in next_node_list:
+ for idx, next_input in enumerate(next_node.inputs):
+ if next_input.name == input_tensor.name:
+ next_node.inputs[idx] = output_tensor
+
+def convert_zp_fp8(onnx_graph):
+ """
+ Convert the Q/DQ zero-point datatype from INT8 to FP8.
+ """
+ # Find all zero constant nodes
+ qdq_zero_nodes = set()
+ for node in onnx_graph.graph.node:
+ if node.op_type == "QuantizeLinear":
+ if len(node.input) > 2:
+ qdq_zero_nodes.add(node.input[2])
+
+ print(f"Found {len(qdq_zero_nodes)} QDQ pairs")
+
+ # Convert zero point datatype from INT8 to FP8.
+ for node in onnx_graph.graph.node:
+ if node.output[0] in qdq_zero_nodes:
+ node.attribute[0].t.data_type = onnx.TensorProto.FLOAT8E4M3FN
+
+ return onnx_graph
+
+def cast_resize_io(graph):
+ """
+ After all activations and weights are converted to fp16, add Cast nodes
+ around the Resize nodes' I/O, because Resize needs to run in fp32.
+ """
+ nodes = graph.nodes
+ up_block_resize_regex = r"\/up_blocks.[0-2]\/upsamplers.0\/Resize"
+ up_block_resize_nodes = [_n for _n in nodes if re.match(up_block_resize_regex, _n.name)]
+
+ print(f"Found {len(up_block_resize_nodes)} Resize nodes to fix")
+ for resize_node in up_block_resize_nodes:
+ for input_tensor in resize_node.inputs:
+ if input_tensor.name:
+ insert_cast(graph, input_tensor=input_tensor, attrs={"to": np.float32})
+ for output_tensor in resize_node.outputs:
+ if output_tensor.name:
+ insert_cast(graph, input_tensor=output_tensor, attrs={"to": np.float16})
+
+def cast_fp8_mha_io(graph):
+ r"""
+ Insert three cast ops.
+ The first cast will be added before the input0 of MatMul to cast fp16 to fp32.
+ The second cast will be added before the input1 of MatMul to cast fp16 to fp32.
+ The third cast will be added after the output of MatMul to cast fp32 back to fp16.
+ Q Q
+ | |
+ DQ DQ
+ | |
+ Cast Cast
+ (fp16 to fp32) (fp16 to fp32)
+ \ /
+ \ /
+ \ /
+ MatMul
+ |
+ Cast (fp32 to fp16)
+ |
+ Q
+ |
+ DQ
+ Inserting these Cast ops around the FP8 MHA prevents the MHAs from running
+ with FP16 accumulation, because TensorRT only has FP32-accumulation kernels for FP8 MHAs.
+ """
+ # Find FP8 MHA pattern.
+ # Match FP8 MHA: Q -> DQ -> BMM1 -> (Mul/Div) -> (Add) -> Softmax -> (Cast) -> Q -> DQ -> BMM2 -> Q -> DQ
+ softmax_bmm1_chain_type = ["Softmax", "MatMul", "DequantizeLinear", "QuantizeLinear"]
+ softmax_bmm2_chain_type = [
+ "Softmax",
+ "QuantizeLinear",
+ "DequantizeLinear",
+ "MatMul",
+ "QuantizeLinear",
+ "DequantizeLinear",
+ ]
+ wild_card_types = [
+ "Div",
+ "Mul",
+ "ConstMul",
+ "Add",
+ "BiasAdd",
+ "Reshape",
+ "Transpose",
+ "Flatten",
+ "Cast",
+ ]
+
+ fp8_mha_partitions = []
+ for node in graph.nodes:
+ if node.op == "Softmax":
+ fp8_mha_partition = []
+ if has_path_type(
+ node, graph, softmax_bmm1_chain_type, False, wild_card_types, fp8_mha_partition
+ ) and has_path_type(
+ node, graph, softmax_bmm2_chain_type, True, wild_card_types, fp8_mha_partition
+ ):
+ if (
+ len(fp8_mha_partition) == 10
+ and fp8_mha_partition[1].op == "MatMul"
+ and fp8_mha_partition[7].op == "MatMul"
+ ):
+ fp8_mha_partitions.append(fp8_mha_partition)
+
+ print(f"Found {len(fp8_mha_partitions)} FP8 attentions")
+
+ # Insert Cast nodes for BMM1 and BMM2.
+ for fp8_mha_partition in fp8_mha_partitions:
+ bmm1_node = fp8_mha_partition[1]
+ insert_cast(graph, input_tensor=bmm1_node.inputs[0], attrs={"to": np.float32})
+ insert_cast(graph, input_tensor=bmm1_node.inputs[1], attrs={"to": np.float32})
+ insert_cast(graph, input_tensor=bmm1_node.outputs[0], attrs={"to": np.float16})
+
+ bmm2_node = fp8_mha_partition[7]
+ insert_cast(graph, input_tensor=bmm2_node.inputs[0], attrs={"to": np.float32})
+ insert_cast(graph, input_tensor=bmm2_node.inputs[1], attrs={"to": np.float32})
+ insert_cast(graph, input_tensor=bmm2_node.outputs[0], attrs={"to": np.float16})
+
+def convert_fp16_io(graph):
+ """
+ Convert graph I/O to FP16.
+ """
+ for input_tensor in graph.inputs:
+ input_tensor.dtype = onnx.TensorProto.FLOAT16
+ for output_tensor in graph.outputs:
+ output_tensor.dtype = onnx.TensorProto.FLOAT16
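
Taken together, the new helpers form a post-export pass over the quantized ONNX model. A sketch of how they might be chained; the pass order and file names are assumptions, only the helper names come from the diff above:

```python
import onnx
import onnx_graphsurgeon as gs

model = convert_zp_fp8(onnx.load("unet_fp8.onnx"))  # INT8 -> FP8 zero points
graph = gs.import_onnx(model)
convert_fp16_io(graph)   # run network I/O in fp16
cast_resize_io(graph)    # keep the up-block Resize nodes in fp32
cast_fp8_mha_io(graph)   # force fp32 accumulation around FP8 MHAs
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "unet_fp8.post.onnx")
```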
diff --git a/docker/rockylinux8.Dockerfile b/docker/rockylinux8.Dockerfile
index 2ad1caf9..24cd4bce 100644
--- a/docker/rockylinux8.Dockerfile
+++ b/docker/rockylinux8.Dockerfile
@@ -15,7 +15,7 @@
# limitations under the License.
#
-ARG CUDA_VERSION=12.5.0
+ARG CUDA_VERSION=12.6.0
FROM nvidia/cuda:${CUDA_VERSION}-devel-rockylinux8
LABEL maintainer="NVIDIA CORPORATION"
@@ -25,7 +25,7 @@ ENV NV_CUDNN_VERSION 8.9.6.50-1
ENV NV_CUDNN_PACKAGE libcudnn8-${NV_CUDNN_VERSION}.cuda12.2
ENV NV_CUDNN_PACKAGE_DEV libcudnn8-devel-${NV_CUDNN_VERSION}.cuda12.2
-ENV TRT_VERSION 10.3.0.26
+ENV TRT_VERSION 10.4.0.26
SHELL ["/bin/bash", "-c"]
RUN dnf install -y \
@@ -62,15 +62,15 @@ RUN dnf install -y python38 python38-devel &&\
# Install TensorRT
RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \
- wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
- && tar -xf TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
- && cp -a TensorRT-10.3.0.26/lib/*.so* /usr/lib64 \
- && pip install TensorRT-10.3.0.26/python/tensorrt-10.3.0-cp38-none-linux_x86_64.whl ;\
+ wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
+ && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
+ && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib64 \
+ && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp38-none-linux_x86_64.whl ;\
elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \
- wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz \
- && tar -xf TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz \
- && cp -a TensorRT-10.3.0.26/lib/*.so* /usr/lib64 \
- && pip install TensorRT-10.3.0.26/python/tensorrt-10.3.0-cp38-none-linux_x86_64.whl ;\
+ wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \
+ && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \
+ && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib64 \
+ && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp38-none-linux_x86_64.whl ;\
else \
echo "Invalid CUDA_VERSION"; \
exit 1; \
@@ -97,7 +97,7 @@ RUN ln -s /usr/bin/python3 /usr/bin/python
# Set environment and working directory
ENV TRT_LIBPATH /usr/lib64
ENV TRT_OSSPATH /workspace/TensorRT
-ENV PATH="${PATH}:/usr/local/bin/ngc-cli"
+ENV PATH="/workspace/TensorRT/build/out:${PATH}:/usr/local/bin/ngc-cli"
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${TRT_OSSPATH}/build/out:${TRT_LIBPATH}"
WORKDIR /workspace
diff --git a/docker/rockylinux9.Dockerfile b/docker/rockylinux9.Dockerfile
index 8741977b..95a87cce 100644
--- a/docker/rockylinux9.Dockerfile
+++ b/docker/rockylinux9.Dockerfile
@@ -15,7 +15,7 @@
# limitations under the License.
#
-ARG CUDA_VERSION=12.5.0
+ARG CUDA_VERSION=12.6.0
FROM nvidia/cuda:${CUDA_VERSION}-devel-rockylinux9
LABEL maintainer="NVIDIA CORPORATION"
@@ -25,7 +25,7 @@ ENV NV_CUDNN_VERSION 8.9.6.50-1
ENV NV_CUDNN_PACKAGE libcudnn8-${NV_CUDNN_VERSION}.cuda12.2
ENV NV_CUDNN_PACKAGE_DEV libcudnn8-devel-${NV_CUDNN_VERSION}.cuda12.2
-ENV TRT_VERSION 10.3.0.26
+ENV TRT_VERSION 10.4.0.26
SHELL ["/bin/bash", "-c"]
RUN dnf install -y \
@@ -67,15 +67,15 @@ RUN dnf -y install \
# Install TensorRT
RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \
- wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
- && tar -xf TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
- && cp -a TensorRT-10.3.0.26/lib/*.so* /usr/lib64 \
- && pip install TensorRT-10.3.0.26/python/tensorrt-10.3.0-cp39-none-linux_x86_64.whl ;\
+ wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
+ && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
+ && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib64 \
+ && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp39-none-linux_x86_64.whl ;\
elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \
- wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz \
- && tar -xf TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz \
- && cp -a TensorRT-10.3.0.26/lib/*.so* /usr/lib64 \
- && pip install TensorRT-10.3.0.26/python/tensorrt-10.3.0-cp39-none-linux_x86_64.whl ;\
+ wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \
+ && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \
+ && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib64 \
+ && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp39-none-linux_x86_64.whl ;\
else \
echo "Invalid CUDA_VERSION"; \
exit 1; \
@@ -96,7 +96,7 @@ RUN ln -s /usr/bin/python3 /usr/bin/python
# Set environment and working directory
ENV TRT_LIBPATH /usr/lib64
ENV TRT_OSSPATH /workspace/TensorRT
-ENV PATH="${PATH}:/usr/local/bin/ngc-cli"
+ENV PATH="/workspace/TensorRT/build/out:${PATH}:/usr/local/bin/ngc-cli"
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${TRT_OSSPATH}/build/out:${TRT_LIBPATH}"
WORKDIR /workspace
diff --git a/docker/ubuntu-20.04.Dockerfile b/docker/ubuntu-20.04.Dockerfile
index b481d945..88e504f4 100644
--- a/docker/ubuntu-20.04.Dockerfile
+++ b/docker/ubuntu-20.04.Dockerfile
@@ -15,7 +15,7 @@
# limitations under the License.
#
-ARG CUDA_VERSION=12.5.0
+ARG CUDA_VERSION=12.6.0
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
LABEL maintainer="NVIDIA CORPORATION"
@@ -28,7 +28,7 @@ ENV CUDA_VERSION_MAJOR_MINOR=12.2
ENV NV_CUDNN_PACKAGE "libcudnn8=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}"
ENV NV_CUDNN_PACKAGE_DEV "libcudnn8-dev=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}"
-ENV TRT_VERSION 10.3.0.26
+ENV TRT_VERSION 10.4.0.26
SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -84,15 +84,15 @@ RUN apt-get install -y --no-install-recommends \
# Install TensorRT
RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \
- wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
- && tar -xf TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
- && cp -a TensorRT-10.3.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \
- && pip install TensorRT-10.3.0.26/python/tensorrt-10.3.0-cp38-none-linux_x86_64.whl ;\
+ wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
+ && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
+ && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \
+ && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp38-none-linux_x86_64.whl ;\
elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \
- wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz \
- && tar -xf TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz \
- && cp -a TensorRT-10.3.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \
- && pip install TensorRT-10.3.0.26/python/tensorrt-10.3.0-cp38-none-linux_x86_64.whl ;\
+ wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \
+ && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \
+ && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \
+ && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp38-none-linux_x86_64.whl ;\
else \
echo "Invalid CUDA_VERSION"; \
exit 1; \
@@ -120,7 +120,7 @@ RUN cd /usr/local/bin && wget https://ngc.nvidia.com/downloads/ngccli_cat_linux.
# Set environment and working directory
ENV TRT_LIBPATH /usr/lib/x86_64-linux-gnu
ENV TRT_OSSPATH /workspace/TensorRT
-ENV PATH="${PATH}:/usr/local/bin/ngc-cli"
+ENV PATH="/workspace/TensorRT/build/out:${PATH}:/usr/local/bin/ngc-cli"
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${TRT_OSSPATH}/build/out:${TRT_LIBPATH}"
WORKDIR /workspace
diff --git a/docker/ubuntu-22.04-aarch64.Dockerfile b/docker/ubuntu-22.04-aarch64.Dockerfile
index e6991c4c..47836ddf 100644
--- a/docker/ubuntu-22.04-aarch64.Dockerfile
+++ b/docker/ubuntu-22.04-aarch64.Dockerfile
@@ -15,12 +15,12 @@
# limitations under the License.
#
-ARG CUDA_VERSION=12.5.0
+ARG CUDA_VERSION=12.6.0
# Multi-arch container support available in non-cudnn containers.
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
-ENV TRT_VERSION 10.3.0.26
+ENV TRT_VERSION 10.4.0.26
SHELL ["/bin/bash", "-c"]
# Setup user account
@@ -71,7 +71,7 @@ RUN apt-get install -y --no-install-recommends \
# Install TensorRT. This will also pull in CUDNN
RUN ver="${CUDA_VERSION%.*}" &&\
if [ "${ver%.*}" = "12" ] ; then \
- ver="12.5"; \
+ ver="12.6"; \
fi &&\
v="${TRT_VERSION}-1+cuda${ver}" &&\
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/sbsa/3bf863cc.pub &&\
@@ -104,7 +104,7 @@ RUN cd /usr/local/bin && wget https://ngc.nvidia.com/downloads/ngccli_arm64.zip
# Set environment and working directory
ENV TRT_LIBPATH /usr/lib/aarch64-linux-gnu/
ENV TRT_OSSPATH /workspace/TensorRT
-ENV PATH="${PATH}:/usr/local/bin/ngc-cli"
+ENV PATH="/workspace/TensorRT/build/out:${PATH}:/usr/local/bin/ngc-cli"
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${TRT_OSSPATH}/build/out:${TRT_LIBPATH}"
WORKDIR /workspace
diff --git a/docker/ubuntu-22.04.Dockerfile b/docker/ubuntu-22.04.Dockerfile
index 43f872a6..c218693d 100644
--- a/docker/ubuntu-22.04.Dockerfile
+++ b/docker/ubuntu-22.04.Dockerfile
@@ -15,7 +15,7 @@
# limitations under the License.
#
-ARG CUDA_VERSION=12.5.0
+ARG CUDA_VERSION=12.6.0
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
LABEL maintainer="NVIDIA CORPORATION"
@@ -28,7 +28,7 @@ ENV CUDA_VERSION_MAJOR_MINOR=12.2
ENV NV_CUDNN_PACKAGE "libcudnn8=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}"
ENV NV_CUDNN_PACKAGE_DEV "libcudnn8-dev=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}"
-ENV TRT_VERSION 10.3.0.26
+ENV TRT_VERSION 10.4.0.26
SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -84,15 +84,15 @@ RUN apt-get install -y --no-install-recommends \
# Install TensorRT
RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \
- wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
- && tar -xf TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
- && cp -a TensorRT-10.3.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \
- && pip install TensorRT-10.3.0.26/python/tensorrt-10.3.0-cp310-none-linux_x86_64.whl ;\
+ wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
+ && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \
+ && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \
+ && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp310-none-linux_x86_64.whl ;\
elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \
- wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz \
- && tar -xf TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz \
- && cp -a TensorRT-10.3.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \
- && pip install TensorRT-10.3.0.26/python/tensorrt-10.3.0-cp310-none-linux_x86_64.whl ;\
+ wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \
+ && tar -xf TensorRT-10.4.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \
+ && cp -a TensorRT-10.4.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \
+ && pip install TensorRT-10.4.0.26/python/tensorrt-10.4.0-cp310-none-linux_x86_64.whl ;\
else \
echo "Invalid CUDA_VERSION"; \
exit 1; \
@@ -120,7 +120,7 @@ RUN cd /usr/local/bin && wget https://ngc.nvidia.com/downloads/ngccli_cat_linux.
# Set environment and working directory
ENV TRT_LIBPATH /usr/lib/x86_64-linux-gnu
ENV TRT_OSSPATH /workspace/TensorRT
-ENV PATH="${PATH}:/usr/local/bin/ngc-cli"
+ENV PATH="/workspace/TensorRT/build/out:${PATH}:/usr/local/bin/ngc-cli"
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${TRT_OSSPATH}/build/out:${TRT_LIBPATH}"
WORKDIR /workspace
diff --git a/docker/ubuntu-cross-aarch64.Dockerfile b/docker/ubuntu-cross-aarch64.Dockerfile
index 6a03c874..7a105f69 100644
--- a/docker/ubuntu-cross-aarch64.Dockerfile
+++ b/docker/ubuntu-cross-aarch64.Dockerfile
@@ -15,13 +15,13 @@
# limitations under the License.
#
-ARG CUDA_VERSION=12.5.0
+ARG CUDA_VERSION=12.6.0
ARG OS_VERSION=22.04
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${OS_VERSION}
LABEL maintainer="NVIDIA CORPORATION"
-ENV TRT_VERSION 10.3.0.26
+ENV TRT_VERSION 10.4.0.26
ENV DEBIAN_FRONTEND=noninteractive
ARG uid=1000
diff --git a/include/NvInfer.h b/include/NvInfer.h
index fc8d4ec8..18e10be4 100644
--- a/include/NvInfer.h
+++ b/include/NvInfer.h
@@ -2536,7 +2536,6 @@ constexpr inline int32_t EnumMax() noexcept
//! * GatherMode::kDEFAULT: s = q + r - 1 - nbElementwiseDims
//! * GatherMode::kND: s = q + r - indices.d[q-1] - 1 - nbElementwiseDims
//! * GatherMode::kELEMENT: s = q = r.
-//! The output can be a shape tensor only if the mode is GatherMode::kDEFAULT.
//!
//! The dimensions of the output likewise depends on the mode:
//!
@@ -3037,7 +3036,7 @@ struct Permutation
//! This layer shuffles data by applying in sequence: a transpose operation, a reshape operation
//! and a second transpose operation. The dimension types of the output are those of the reshape dimension.
//!
-//! The layer has an optional second input. If present, it must be a 1D Int32 shape tensor,
+//! The layer has an optional second input. If present, it must be a 1D tensor of type Int32 or Int64,
//! and the reshape dimensions are taken from it.
//!
//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.
@@ -3126,7 +3125,7 @@ class IShuffleLayer : public ILayer
//! The indices in the dynamic case are as follows:
//!
//! - 0: Data or Shape tensor to be shuffled.
- //! - 1: The dimensions for the reshape operation, as a 1D Int32 shape tensor.
+ //! - 1: The dimensions for the reshape operation, as a 1D tensor of type Int32 or Int64.
//!
//! If this function is called with the value 1, then the function getNbInputs() changes
//! from returning 1 to 2.
@@ -3247,7 +3246,7 @@ constexpr inline int32_t EnumMax() noexcept
//!
//! The slice layer selects for each dimension a start location from within the input tensor, and
//! copies elements to the output tensor using the specified stride across the input tensor.
-//! Start, size, and stride tensors must be 1D Int32 shape tensors if not specified via Dims.
+//! Start, size, and stride tensors must be 1D tensors of type Int32 or Int64 if not specified via Dims.
//!
//! An example of using slice on a tensor:
//! input = {{0, 2, 4}, {1, 3, 5}}
@@ -3285,10 +3284,12 @@ constexpr inline int32_t EnumMax() noexcept
//! The following constraints must be satisfied to execute this layer on DLA:
//! * start, size, and stride are build time constants, either as static Dims or as constant input tensors.
//! * axes, if provided, are build time constants, either as static Dims or as a constant input tensor.
-//! * sampleMode is kSTRICT_BOUNDS.
+//! * sampleMode is kDEFAULT, kWRAP, or kFILL.
//! * Strides are 1 for all dimensions.
-//! * Slicing is not performed on the first dimension
-//! * The input tensor has four dimensions
+//! * Slicing is not performed on the first dimension.
+//! * The input tensor has four dimensions.
+//! * For kFILL sliceMode, the fill value input is a scalar output of an IConstantLayer with value 0 that is not
+//! consumed by any other layer.
//!
//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.
//!
@@ -3412,15 +3413,15 @@ class ISliceLayer : public ILayer
//! The indices are as follows:
//!
//! - 0: Tensor to be sliced.
- //! - 1: The start tensor to begin slicing, as a 1D Int32 shape tensor.
- //! - 2: The size tensor of the resulting slice, as a 1D Int32 shape tensor.
- //! - 3: The stride of the slicing operation, as a 1D Int32 shape tensor.
+ //! - 1: The start tensor to begin slicing, as a 1D tensor of type Int32 or Int64.
+ //! - 2: The size tensor of the resulting slice, as a 1D tensor of type Int32 or Int64.
+ //! - 3: The stride of the slicing operation, as a 1D tensor of type Int32 or Int64.
//! - 4: Value for the kFILL slice mode. The fill value data type should either be the same
//! or be implicitly convertible to the input data type.
//! Implicit data type conversion is supported among kFLOAT, kHALF, kINT8, and kFP8 data types.
//! This input is disallowed for other modes.
//! - 5: The axes tensor indicating the corresponding axes that start, size, and stride
- //! should apply to, as a 1D Int32 shape tensor. Negative values for axes
+ //! should apply to, as a 1D tensor of type Int32 or Int64. Negative values for axes
//! indicate indexing from the back of the input tensor. Values must be unique and be
//! within the interval of [-rank(input), rank(input)-1].
//!
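
The dynamic-input indices documented above can be exercised from the Python API. A minimal sketch (shapes and values are illustrative) that overrides start/size/stride and restricts them to the spatial axes via input 5:

```python
import numpy as np
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network()
data = network.add_input("data", trt.float32, (1, 3, 8, 8))

def const_1d(name, values):
    # 1D Int64 tensors, as permitted for slice inputs 1-3 and 5.
    layer = network.add_constant((len(values),), np.asarray(values, dtype=np.int64))
    layer.name = name
    return layer.get_output(0)

slice_layer = network.add_slice(data, (0, 0, 0, 0), (1, 3, 4, 4), (1, 1, 1, 1))
slice_layer.set_input(1, const_1d("start", [2, 2]))   # input 1: start
slice_layer.set_input(2, const_1d("size", [4, 4]))    # input 2: size
slice_layer.set_input(3, const_1d("stride", [1, 1]))  # input 3: stride
slice_layer.set_input(5, const_1d("axes", [2, 3]))    # input 5: axes
```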
@@ -4208,7 +4209,7 @@ class IResizeLayer : public ILayer
//! The indices in the dynamic case are as follows:
//!
//! - 0: Execution tensor to be resized.
- //! - 1: The output dimensions, as a 1D Int32 shape tensor.
+ //! - 1: The output dimensions, as a 1D tensor of type Int32 or Int64.
//!
//! If this function is called with the value 1, then the function getNbInputs() changes
//! from returning 1 to 2.
@@ -4467,7 +4468,9 @@ class IConditionLayer : public IIfConditionalBoundaryLayer
//!
//! \brief This layer represents an output of an IIfConditional.
//!
-//! An IIfConditionalOutputLayer has exactly one output.
+//! An IIfConditionalOutputLayer has two inputs and one output.
+//!
+//! \see IIfConditional::addOutput
//!
class IIfConditionalOutputLayer : public IIfConditionalBoundaryLayer
{
@@ -4539,6 +4542,8 @@ class IIfConditional : public INoCopy
//! Each output layer of an IIfConditional represents a single output of either the true-subgraph or the
//! false-subgraph of an IIfConditional, depending on which subgraph was executed.
//!
+ //! The shapes of the two tensors must be equal unless the condition is a build-time constant.
+ //!
//! \see IIfConditionalOutputLayer
//!
IIfConditionalOutputLayer* addOutput(ITensor& trueSubgraphOutput, ITensor& falseSubgraphOutput) noexcept
@@ -4693,7 +4698,7 @@ class ILoopOutputLayer : public ILoopBoundaryLayer
//! The indices in the kCONCATENATE or kREVERSE cases are as follows:
//!
//! - 0: Contribution to the output tensor. The contribution must come from inside the loop.
- //! - 1: The concatenation length scalar value, must come from outside the loop, as a 0D Int32 or Int64 shape tensor.
+ //! - 1: The concatenation length scalar value, which must come from outside the loop, as a 0D shape tensor of type Int32 or Int64.
//!
//! If this function is called with the value 1, then the function getNbInputs() changes
//! from returning 1 to 2.
@@ -5775,8 +5780,8 @@ class IScatterLayer : public ILayer
//! Output, and an axis attribute.
//! * Indices is an Int32 tensor that determines which locations in Output to set as on_value.
//! * Values is a two-element (rank=1) tensor that consists of [off_value, on_value]
-//! * Depth is an Int32 shape tensor of rank 0, which contains the depth (number of classes) of the one-hot encoding.
-//! The depth tensor must be a build-time constant, and its value should be positive.
+//! * Depth is a 0D tensor of type Int32 or Int64, which contains the depth (number of classes) of the one-hot encoding.
+//! The depth tensor must be a positive build-time constant.
//! * Output is a tensor with rank = rank(indices)+1, where the added dimension contains the one-hot encoding.
//! The data type of Output is equal to the Values data type.
//! * Axis is a scalar specifying to which dimension of the output one-hot encoding is added.
@@ -7046,7 +7051,7 @@ class INetworkDefinition : public INoCopy
//!
//! \see IParametricReLULayer
//!
- //! \warning Int32 tensors are not valid input tensors.
+ //! \warning Tensors of type Int32, Int64, Bool, or UInt8 are not allowed as inputs.
//!
//! \return The new parametric ReLU layer, or nullptr if it could not be created.
//!
@@ -9585,6 +9590,30 @@ class IBuilderConfig : public INoCopy
return mImpl->getRuntimePlatform();
}
+ //!
+ //! \brief Set the maximum number of tactics to time when there is a choice of tactics.
+ //!
+ //! Limiting the number of timed tactics can shorten engine build time, possibly at the cost of less optimal tactic selection.
+ //!
+ //! \see getMaxNbTactics()
+ //!
+ void setMaxNbTactics(int32_t maxNbTactics) noexcept
+ {
+ mImpl->setMaxNbTactics(maxNbTactics);
+ }
+
+ //!
+ //! \brief Query the maximum number of tactics timed when there is a choice.
+ //!
+ //! By default the value is -1, indicating that TensorRT determines the number of tactics based on its own heuristic.
+ //!
+ //! \see setMaxNbTactics()
+ //!
+ int32_t getMaxNbTactics() const noexcept
+ {
+ return mImpl->getMaxNbTactics();
+ }
+
protected:
apiv::VBuilderConfig* mImpl;
};
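
A hedged usage sketch of the new setter: the Python spelling `max_nb_tactics` is an assumption mirroring the C++ `setMaxNbTactics()`/`getMaxNbTactics()` pair above; the builder and config creation calls are standard TensorRT Python API.

```python
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
config = builder.create_builder_config()
# Hypothetical Python spelling of setMaxNbTactics(); -1 (the default) lets
# TensorRT choose heuristically, a positive value caps the tactics timed.
config.max_nb_tactics = 8
```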
diff --git a/include/NvInferConsistency.h b/include/NvInferConsistency.h
index 32bca28b..0d1b8b40 100644
--- a/include/NvInferConsistency.h
+++ b/include/NvInferConsistency.h
@@ -41,7 +41,9 @@ namespace consistency
//!
//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.
//!
-class IConsistencyChecker
+//! \deprecated Deprecated in TensorRT 10.4.
+//!
+class TRT_DEPRECATED IConsistencyChecker
{
public:
//!
@@ -80,7 +82,9 @@ class IConsistencyChecker
//!
//! Supported IPlugin interfaces are limited to IPluginV2IOExt.
//!
-class IPluginChecker : public IPluginCreator
+//! \deprecated Deprecated in TensorRT 10.4.
+//!
+class TRT_DEPRECATED IPluginChecker : public IPluginCreator
{
public:
//!
@@ -114,6 +118,9 @@ class IPluginChecker : public IPluginCreator
} // namespace nvinfer1
+//!
+//! \deprecated Deprecated in TensorRT 10.4.
+//!
extern "C" TENSORRTAPI void* createConsistencyChecker_INTERNAL(void* logger, void const* blob, size_t size,
int32_t version); //!< Internal C entry point for creating IConsistencyChecker.
@@ -132,7 +139,10 @@ namespace consistency
namespace // anonymous
{
-inline IConsistencyChecker* createConsistencyChecker(ILogger& logger, void const* blob, size_t size)
+//!
+//! \deprecated Deprecated in TensorRT 10.4.
+//!
+TRT_DEPRECATED inline IConsistencyChecker* createConsistencyChecker(ILogger& logger, void const* blob, size_t size)
{
return static_cast<IConsistencyChecker*>(
createConsistencyChecker_INTERNAL(&logger, blob, size, NV_TENSORRT_VERSION));
diff --git a/include/NvInferConsistencyImpl.h b/include/NvInferConsistencyImpl.h
index b0626a26..0b4e8dd3 100644
--- a/include/NvInferConsistencyImpl.h
+++ b/include/NvInferConsistencyImpl.h
@@ -32,7 +32,10 @@ namespace nvinfer1
namespace apiv
{
-class VConsistencyChecker
+//!
+//! \deprecated Deprecated in TensorRT 10.4.
+//!
+class TRT_DEPRECATED VConsistencyChecker
{
public:
virtual ~VConsistencyChecker() noexcept = default;
diff --git a/include/NvInferImpl.h b/include/NvInferImpl.h
index d202c3d1..2c7df74a 100644
--- a/include/NvInferImpl.h
+++ b/include/NvInferImpl.h
@@ -1170,6 +1170,8 @@ class VBuilderConfig : public VRoot
virtual IProgressMonitor* getProgressMonitor() const noexcept = 0;
virtual void setRuntimePlatform(RuntimePlatform runtimePlatform) noexcept = 0;
virtual RuntimePlatform getRuntimePlatform() const noexcept = 0;
+ virtual void setMaxNbTactics(int32_t maxTactics) noexcept = 0;
+ virtual int32_t getMaxNbTactics() const noexcept = 0;
};
class VSerializationConfig : public VRoot
diff --git a/include/NvInferVersion.h b/include/NvInferVersion.h
index 11b9cc6d..477c5719 100644
--- a/include/NvInferVersion.h
+++ b/include/NvInferVersion.h
@@ -24,9 +24,9 @@
#define NV_INFER_VERSION_H
#define NV_TENSORRT_MAJOR 10 //!< TensorRT major version.
-#define NV_TENSORRT_MINOR 3 //!< TensorRT minor version.
-#define NV_TENSORRT_PATCH 0 //!< TensorRT patch version.
-#define NV_TENSORRT_BUILD 26 //!< TensorRT build number.
+#define NV_TENSORRT_MINOR 4 //!< TensorRT minor version.
+#define NV_TENSORRT_PATCH 0 //!< TensorRT patch version.
+#define NV_TENSORRT_BUILD 26 //!< TensorRT build number.
#define NV_TENSORRT_LWS_MAJOR 0 //!< TensorRT LWS major version.
#define NV_TENSORRT_LWS_MINOR 0 //!< TensorRT LWS minor version.
diff --git a/parsers/onnx b/parsers/onnx
index 62bdde2a..3775e499 160000
--- a/parsers/onnx
+++ b/parsers/onnx
@@ -1 +1 @@
-Subproject commit 62bdde2a04fcd53c2409cb895ee18db445b7e755
+Subproject commit 3775e499322eee17c837e27bff6d07af4261767a
diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt
index 112d45f7..32ec4ed5 100644
--- a/plugin/CMakeLists.txt
+++ b/plugin/CMakeLists.txt
@@ -127,7 +127,7 @@ if (CUDA_VERSION VERSION_LESS 11.0)
endif()
set_target_properties(${SHARED_TARGET} PROPERTIES
- CXX_STANDARD "14"
+ CXX_STANDARD "17"
CXX_STANDARD_REQUIRED "YES"
CXX_EXTENSIONS "NO"
ARCHIVE_OUTPUT_DIRECTORY "${TRT_OUT_DIR}"
@@ -173,7 +173,7 @@ target_include_directories(${STATIC_TARGET}
)
set_target_properties(${STATIC_TARGET} PROPERTIES
- CXX_STANDARD "14"
+ CXX_STANDARD "17"
CXX_STANDARD_REQUIRED "YES"
CXX_EXTENSIONS "NO"
ARCHIVE_OUTPUT_DIRECTORY "${TRT_OUT_DIR}"
@@ -206,7 +206,7 @@ target_include_directories(${VFC_SHARED_TARGET}
)
set_target_properties(${VFC_SHARED_TARGET} PROPERTIES
- CXX_STANDARD "14"
+ CXX_STANDARD "17"
CXX_STANDARD_REQUIRED "YES"
CXX_EXTENSIONS "NO"
ARCHIVE_OUTPUT_DIRECTORY "${TRT_OUT_DIR}"
diff --git a/plugin/README.md b/plugin/README.md
index 0416ecc0..1a826743 100644
--- a/plugin/README.md
+++ b/plugin/README.md
@@ -17,7 +17,8 @@
| [disentangledAttentionPlugin](disentangledAttentionPlugin) | DisentangledAttention_TRT | 1 |
| [efficientNMSPlugin](efficientNMSPlugin) | EfficientNMS_TRT | 1 |
| [efficientNMSONNXPlugin](efficientNMSPlugin) [DEPRECATED] | EfficientNMS_ONNX_TRT | 1 |
-| [embLayerNormPlugin](embLayerNormPlugin) | CustomEmbLayerNormPluginDynamic | 1, 2 |
+| [embLayerNormPlugin](embLayerNormPlugin) [DEPRECATED]| CustomEmbLayerNormPluginDynamic | 1, 2, 3 |
+| [embLayerNormPlugin](embLayerNormPlugin) | CustomEmbLayerNormPluginDynamic | 4, 5 |
| [fcPlugin](fcPlugin) | CustomFCPluginDynamic | 1 |
| [flattenConcat](flattenConcat) | FlattenConcat_TRT | 1 |
| [geluPlugin](geluPlugin) [DEPRECATED] | CustomGeluPluginDynamic | 1 |
@@ -50,7 +51,8 @@
| [scatterElementsPlugin](scatterElementsPlugin) [DEPRECATED] | ScatterElements | 1 |
| [scatterElementsPlugin](scatterElementsPlugin) | ScatterElements | 2 |
| [scatterPlugin](scatterPlugin) | ScatterND | 1 |
-| [skipLayerNormPlugin](skipLayerNormPlugin) | CustomSkipLayerNormPluginDynamic | 1, 2, 3 |
+| [skipLayerNormPlugin](skipLayerNormPlugin) [DEPRECATED] | CustomSkipLayerNormPluginDynamic | 1, 2, 3, 4 |
+| [skipLayerNormPlugin](skipLayerNormPlugin) | CustomSkipLayerNormPluginDynamic | 5, 6, 7, 8 |
| [specialSlicePlugin](specialSlicePlugin) [DEPRECATED] | SpecialSlice_TRT | 1 |
| [splitPlugin](splitPlugin) [DEPRECATED] | Split | 1 |
| [voxelGeneratorPlugin](voxelGeneratorPlugin) | VoxelGeneratorPlugin | 1 |
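
To pick up the non-deprecated versions listed in the table above, a creator can be resolved by name and version. A sketch using the classic registry lookup; whether IPluginV3-based versions instead require the newer `get_creator()` interface is an assumption, not stated in this table:

```python
import tensorrt as trt

registry = trt.get_plugin_registry()
# Version "5" is the first non-deprecated skipLayerNorm version per the table.
creator = registry.get_plugin_creator("CustomSkipLayerNormPluginDynamic", "5")
```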
diff --git a/plugin/bertQKVToContextPlugin/qkvToContext.cu b/plugin/bertQKVToContextPlugin/qkvToContext.cu
index cd5c69e7..2b89d428 100644
--- a/plugin/bertQKVToContextPlugin/qkvToContext.cu
+++ b/plugin/bertQKVToContextPlugin/qkvToContext.cu
@@ -630,9 +630,9 @@ static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype)
class FusedMHARunnerFP16::mhaImpl
{
public:
- mhaImpl(FusedMHARunnerFP16* interface)
- : interface(interface)
- , sm(interface->mSm)
+ mhaImpl(FusedMHARunnerFP16* mhaInterface)
+ : mhaInterface(mhaInterface)
+ , sm(mhaInterface->mSm)
, xmmaKernel(getXMMAKernels(DATA_TYPE_FP16, sm))
, xmmas_m(0U)
, xmmas_n(0U)
@@ -647,8 +647,8 @@ public:
// check that we initialized
assert(xmmas_m > 0);
assert(threads_per_cta > 0);
- assert(interface->mB > 0);
- return interface->mB * xmmas_m * threads_per_cta * sizeof(uint32_t);
+ assert(mhaInterface->mB > 0);
+ return mhaInterface->mB * xmmas_m * threads_per_cta * sizeof(uint32_t);
}
void setup(int32_t S, int32_t B, int32_t headSize)
@@ -679,7 +679,7 @@ public:
// The number of xmmas in the N dimension.
xmmas_n = (S + 16 * warps_n - 1) / (16 * warps_n);
- const float scale_bmm1 = interface->mRsqrtHeadSize;
+ const float scale_bmm1 = mhaInterface->mRsqrtHeadSize;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
@@ -689,13 +689,13 @@ public:
set_alpha(params.scale_bmm2, scale_bmm2, scale_type);
params.b = B;
- params.h = interface->mNumHeads;
+ params.h = mhaInterface->mNumHeads;
params.s = S;
- params.d = interface->mHeadSize;
+ params.d = mhaInterface->mHeadSize;
- params.qkv_stride_in_bytes = get_size_in_bytes(interface->mLdQKV, DATA_TYPE_FP16);
+ params.qkv_stride_in_bytes = get_size_in_bytes(mhaInterface->mLdQKV, DATA_TYPE_FP16);
params.packed_mask_stride_in_bytes = xmmas_m * threads_per_cta * sizeof(uint32_t);
- params.o_stride_in_bytes = get_size_in_bytes(interface->mLdOut, DATA_TYPE_FP16);
+ params.o_stride_in_bytes = get_size_in_bytes(mhaInterface->mLdOut, DATA_TYPE_FP16);
}
void run(const PluginTensorDesc& inputDesc, const PluginTensorDesc& outputDesc, const void* qkvPtr,
@@ -718,7 +718,7 @@ public:
}
private:
- FusedMHARunnerFP16* interface;
+ FusedMHARunnerFP16* mhaInterface;
Fused_multihead_attention_params params;
int sm;
const FusedMultiHeadAttentionXMMAKernel* xmmaKernel;
@@ -774,11 +774,11 @@ class FusedMHARunnerInt8::mhaImpl
{
public:
- mhaImpl(FusedMHARunnerInt8* interface)
- : interface(interface)
- , sm(interface->mSm)
+ mhaImpl(FusedMHARunnerInt8* mhaInterface)
+ : mhaInterface(mhaInterface)
+ , sm(mhaInterface->mSm)
, xmmaKernel(getXMMAKernels(DATA_TYPE_INT8, sm))
- , mDqProbs(interface->mDqProbs)
+ , mDqProbs(mhaInterface->mDqProbs)
, xmmas_m(0U)
, xmmas_n(0U)
, threads_per_cta(1U)
@@ -791,8 +791,8 @@ public:
{
assert(xmmas_m > 0);
assert(threads_per_cta > 0);
- assert(interface->mB > 0);
- return interface->mB * xmmas_m * threads_per_cta * sizeof(uint32_t);
+ assert(mhaInterface->mB > 0);
+ return mhaInterface->mB * xmmas_m * threads_per_cta * sizeof(uint32_t);
}
void setup(int32_t S, int32_t B, int32_t headSize)
@@ -823,13 +823,13 @@ public:
params.b = B;
- params.h = interface->mNumHeads;
+ params.h = mhaInterface->mNumHeads;
params.s = S;
- params.d = interface->mHeadSize;
+ params.d = mhaInterface->mHeadSize;
- params.qkv_stride_in_bytes = get_size_in_bytes(interface->mLdQKV, DATA_TYPE_INT8);
+ params.qkv_stride_in_bytes = get_size_in_bytes(mhaInterface->mLdQKV, DATA_TYPE_INT8);
params.packed_mask_stride_in_bytes = xmmas_m * threads_per_cta * sizeof(uint32_t);
- params.o_stride_in_bytes = get_size_in_bytes(interface->mLdOut, DATA_TYPE_INT8);
+ params.o_stride_in_bytes = get_size_in_bytes(mhaInterface->mLdOut, DATA_TYPE_INT8);
}
void run(const PluginTensorDesc& inputDesc, const PluginTensorDesc& outputDesc, const void* qkvPtr,
@@ -838,7 +838,7 @@ public:
float scaleQkv = inputDesc.scale;
float scaleCtx = outputDesc.scale;
- float scaleBmm1 = scaleQkv * scaleQkv * interface->mRsqrtHeadSize;
+ float scaleBmm1 = scaleQkv * scaleQkv * mhaInterface->mRsqrtHeadSize;
float scaleBmm2 = mDqProbs * scaleQkv / scaleCtx;
float scaleSoftmax = 1.f / mDqProbs;
@@ -866,7 +866,7 @@ public:
private:
float mDqProbs;
- FusedMHARunnerInt8* interface;
+ FusedMHARunnerInt8* mhaInterface;
Fused_multihead_attention_params params;
int sm;
const FusedMultiHeadAttentionXMMAKernel* xmmaKernel;
@@ -920,9 +920,9 @@ bool FusedMHARunnerInt8::isValid(int32_t headSize, int32_t s) const
class FusedMHARunnerFP16v2::mhaImpl
{
public:
- mhaImpl(FusedMHARunnerFP16v2* interface)
- : interface(interface)
- , sm(interface->mSm)
+ mhaImpl(FusedMHARunnerFP16v2* mhaInterface)
+ : mhaInterface(mhaInterface)
+ , sm(mhaInterface->mSm)
, xmmaKernel(getXMMAKernelsV2(DATA_TYPE_FP16, sm))
{
assert((sm == kSM_72 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_87 || sm == kSM_89 || sm == kSM_90)
@@ -937,8 +937,8 @@ public:
// check that we initialized
assert(xmmas_m > 0);
assert(threads_per_cta > 0);
- assert(interface->mB > 0);
- return interface->mB * xmmas_m * threads_per_cta * sizeof(uint32_t);
+ assert(mhaInterface->mB > 0);
+ return mhaInterface->mB * xmmas_m * threads_per_cta * sizeof(uint32_t);
}
void setup(int32_t S, int32_t B, int32_t headSize)
@@ -986,7 +986,7 @@ public:
// The number of xmmas in the N dimension.
xmmas_n = (S + 16 * warps_n - 1) / (16 * warps_n);
- const float scale_bmm1 = interface->mRsqrtHeadSize;
+ const float scale_bmm1 = mhaInterface->mRsqrtHeadSize;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
@@ -996,16 +996,16 @@ public:
set_alpha(params.scale_bmm2, scale_bmm2, scale_type);
params.b = B;
- params.h = interface->mNumHeads;
+ params.h = mhaInterface->mNumHeads;
params.s = S;
- params.d = interface->mHeadSize;
+ params.d = mhaInterface->mHeadSize;
// mLdQKV = 3 * B * mNumHeads * mHeadSize;
// mLdOut = B * mNumHeads * mHeadSize;
- params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half);
+ params.qkv_stride_in_bytes = 3 * mhaInterface->mNumHeads * mhaInterface->mHeadSize * sizeof(half);
params.packed_mask_stride_in_bytes = xmmas_m * threads_per_cta * sizeof(uint32_t);
- params.o_stride_in_bytes = interface->mNumHeads * interface->mHeadSize * sizeof(half);
+ params.o_stride_in_bytes = mhaInterface->mNumHeads * mhaInterface->mHeadSize * sizeof(half);
}
void run(const PluginTensorDesc& inputDesc, const PluginTensorDesc& outputDesc, const void* qkvPtr,
@@ -1030,7 +1030,7 @@ public:
}
private:
- FusedMHARunnerFP16v2* interface;
+ FusedMHARunnerFP16v2* mhaInterface;
Fused_multihead_attention_params_v2 params;
int sm;
const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel;
@@ -1087,11 +1087,11 @@ class FusedMHARunnerInt8v2::mhaImpl
{
public:
- mhaImpl(FusedMHARunnerInt8v2* interface)
- : interface(interface)
- , sm(interface->mSm)
+ mhaImpl(FusedMHARunnerInt8v2* mhaInterface)
+ : mhaInterface(mhaInterface)
+ , sm(mhaInterface->mSm)
, xmmaKernel(getXMMAKernelsV2(DATA_TYPE_INT8, sm))
- , mDqProbs(interface->mDqProbs)
+ , mDqProbs(mhaInterface->mDqProbs)
, xmmas_m(0U)
, xmmas_n(0U)
, threads_per_cta(1U)
@@ -1107,8 +1107,8 @@ public:
{
assert(xmmas_m > 0);
assert(threads_per_cta > 0);
- assert(interface->mB > 0);
- return interface->mB * xmmas_m * threads_per_cta * sizeof(uint32_t);
+ assert(mhaInterface->mB > 0);
+ return mhaInterface->mB * xmmas_m * threads_per_cta * sizeof(uint32_t);
}
void setup(int32_t S, int32_t B, int32_t headSize)
@@ -1163,13 +1163,13 @@ public:
xmmas_n = (S + 16 * warps_n - 1) / (16 * warps_n);
params.b = B;
- params.h = interface->mNumHeads;
+ params.h = mhaInterface->mNumHeads;
params.s = S;
- params.d = interface->mHeadSize;
- params.use_int8_scale_max = interface->mUseInt8ScaleMax;
+ params.d = mhaInterface->mHeadSize;
+ params.use_int8_scale_max = mhaInterface->mUseInt8ScaleMax;
params.packed_mask_stride_in_bytes = xmmas_m * threads_per_cta * sizeof(uint32_t);
- params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(int8_t);
- params.o_stride_in_bytes = interface->mNumHeads * interface->mHeadSize * sizeof(int8_t);
+ params.qkv_stride_in_bytes = 3 * mhaInterface->mNumHeads * mhaInterface->mHeadSize * sizeof(int8_t);
+ params.o_stride_in_bytes = mhaInterface->mNumHeads * mhaInterface->mHeadSize * sizeof(int8_t);
}
void run(const PluginTensorDesc& inputDesc, const PluginTensorDesc& outputDesc, const void* qkvPtr,
@@ -1178,7 +1178,7 @@ public:
float scaleQkv = inputDesc.scale;
float scaleCtx = outputDesc.scale;
- float scaleBmm1 = scaleQkv * scaleQkv * interface->mRsqrtHeadSize;
+ float scaleBmm1 = scaleQkv * scaleQkv * mhaInterface->mRsqrtHeadSize;
float scaleBmm2 = mDqProbs * scaleQkv / scaleCtx;
float scaleSoftmax = 1.f / mDqProbs;
@@ -1194,7 +1194,7 @@ public:
// dummy input in V2/V3 because now we use cu_seqlens
params.packed_mask_ptr = nullptr;
- params.use_int8_scale_max = interface->mUseInt8ScaleMax;
+ params.use_int8_scale_max = mhaInterface->mUseInt8ScaleMax;
params.o_ptr = output;
@@ -1211,7 +1211,7 @@ public:
private:
float mDqProbs;
- FusedMHARunnerInt8v2* interface;
+ FusedMHARunnerInt8v2* mhaInterface;
Fused_multihead_attention_params_v2 params;
int sm;
const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel;
diff --git a/plugin/common/common.cuh b/plugin/common/common.cuh
index 4fff6d49..4a7620cd 100644
--- a/plugin/common/common.cuh
+++ b/plugin/common/common.cuh
@@ -24,18 +24,6 @@
#include
#endif // CUDA_VERSION
-#if CUDA_VERSION >= 12050
-#include
-#undef _CCCL_FORCEINLINE
-
-#if defined(_CCCL_CUDA_COMPILER)
-# define _CCCL_FORCEINLINE __forceinline__
-#else // ^^^ _CCCL_CUDA_COMPILER ^^^ / vvv !_CCCL_CUDA_COMPILER vvv
-# define _CCCL_FORCEINLINE inline
-#endif // !_CCCL_CUDA_COMPILER
-
-#endif // CUDA_VERSION >= 12050
-
#include "common/cublasWrapper.h"
#include
#include
diff --git a/plugin/common/cub_helper.h b/plugin/common/cub_helper.h
index 7c947b2f..28c5cbdb 100644
--- a/plugin/common/cub_helper.h
+++ b/plugin/common/cub_helper.h
@@ -21,18 +21,6 @@
#include
#endif // CUDA_VERSION
-#if CUDA_VERSION >= 12050
-#include
-#undef _CCCL_FORCEINLINE
-
-#if defined(_CCCL_CUDA_COMPILER)
-# define _CCCL_FORCEINLINE __forceinline__
-#else // ^^^ _CCCL_CUDA_COMPILER ^^^ / vvv !_CCCL_CUDA_COMPILER vvv
-# define _CCCL_FORCEINLINE inline
-#endif // !_CCCL_CUDA_COMPILER
-
-#endif // CUDA_VERSION >= 12050
-
#include "common/kernels/kernel.h"
#include
template
diff --git a/plugin/common/plugin.cpp b/plugin/common/plugin.cpp
index b685528a..a4b228ff 100644
--- a/plugin/common/plugin.cpp
+++ b/plugin/common/plugin.cpp
@@ -149,5 +149,20 @@ int32_t dimToInt32(int64_t d)
return static_cast(d);
}
+bool supportsMemPoolsHelper()
+{
+ int32_t device;
+ PLUGIN_CUASSERT(cudaGetDevice(&device));
+ int32_t value;
+ PLUGIN_CUASSERT(cudaDeviceGetAttribute(&value, cudaDevAttrMemoryPoolsSupported, device));
+ return value != 0;
+}
+
+bool supportsMemPools()
+{
+ static bool sResult = supportsMemPoolsHelper();
+ return sResult;
+}
+
} // namespace plugin
} // namespace nvinfer1
diff --git a/plugin/common/plugin.h b/plugin/common/plugin.h
index 833ebb22..450259f6 100644
--- a/plugin/common/plugin.h
+++ b/plugin/common/plugin.h
@@ -141,6 +141,11 @@ struct CudaBind
// Throw exception if it doesn't fit.
int32_t dimToInt32(int64_t);
+// Helper function to determine whether memory pool support is available on the device.
+bool supportsMemPoolsHelper();
+
+// Wrapper function around the helper to keep the result in a static variable to avoid multiple calls to CUDA APIs.
+bool supportsMemPools();
} // namespace plugin
} // namespace nvinfer1
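
The same query-once pattern can be sketched in Python with cuda-python, assuming only the documented `cudaDeviceGetAttribute` call:

```python
import functools

from cuda import cudart

@functools.lru_cache(maxsize=None)
def supports_mem_pools(device: int = 0) -> bool:
    # Mirrors the C++ helper: query the device attribute once, then reuse
    # the cached answer on subsequent calls.
    err, value = cudart.cudaDeviceGetAttribute(
        cudart.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, device
    )
    assert err == cudart.cudaError_t.cudaSuccess
    return value != 0
```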
diff --git a/plugin/embLayerNormPlugin/CustomEmbLayerNormPluginDynamic_PluginConfig.yaml b/plugin/embLayerNormPlugin/CustomEmbLayerNormPluginDynamic_PluginConfig.yaml
index a942508c..6d8d1125 100644
--- a/plugin/embLayerNormPlugin/CustomEmbLayerNormPluginDynamic_PluginConfig.yaml
+++ b/plugin/embLayerNormPlugin/CustomEmbLayerNormPluginDynamic_PluginConfig.yaml
@@ -82,13 +82,13 @@ versions:
bert_embeddings_token_type_embeddings: 2
bert_embeddings_position_embeddings: 2
attribute_dim_range:
- output_fp16:
+ output_fp16:
- min: "=1"
- max: "=1"
- full_mask:
+ full_mask:
- min: "=1"
- max: "=1"
- mha_type_id:
+ mha_type_id:
- min: "=1"
- max: "=1"
bert_embeddings_layernorm_beta:
@@ -107,29 +107,29 @@ versions:
- min: "=1, =1"
- max: "=pinf, =pinf"
attribute_options:
- output_fp16:
+ output_fp16:
- 0
- 1
- full_mask:
+ full_mask:
- 0
- 1
- mha_type_id:
+ mha_type_id:
- 0
- 1
- 2
- bert_embeddings_layernorm_beta:
+ bert_embeddings_layernorm_beta:
min: "=ninf"
max: "=pinf"
- bert_embeddings_layernorm_gamma:
+ bert_embeddings_layernorm_gamma:
min: "=ninf"
max: "=pinf"
- bert_embeddings_word_embeddings:
+ bert_embeddings_word_embeddings:
min: "=ninf"
max: "=pinf"
- bert_embeddings_token_type_embeddings:
+ bert_embeddings_token_type_embeddings:
min: "=ninf"
max: "=pinf"
- bert_embeddings_position_embeddings:
+ bert_embeddings_position_embeddings:
min: "=ninf"
max: "=pinf"
attributes_required:
@@ -148,23 +148,23 @@ versions:
segment_id: int32
input_mask: int32
attribute_options:
- output_fp16:
+ output_fp16:
value: 0
shape: "1"
- full_mask:
+ full_mask:
value: 0
shape: "1"
- mha_type_id:
+ mha_type_id:
value: 0
shape: "1"
- bert_embeddings_layernorm_beta:
+ bert_embeddings_layernorm_beta:
shape: "128"
- bert_embeddings_layernorm_gamma:
+ bert_embeddings_layernorm_gamma:
shape: "128"
- bert_embeddings_word_embeddings:
+ bert_embeddings_word_embeddings:
shape: "100, 128"
- bert_embeddings_token_type_embeddings:
+ bert_embeddings_token_type_embeddings:
shape: "2, 128"
- bert_embeddings_position_embeddings:
+ bert_embeddings_position_embeddings:
shape: "20, 128"
...
diff --git a/plugin/embLayerNormPlugin/README.md b/plugin/embLayerNormPlugin/README.md
index bd767ef3..ca0ed259 100644
--- a/plugin/embLayerNormPlugin/README.md
+++ b/plugin/embLayerNormPlugin/README.md
@@ -15,45 +15,78 @@
The plugin performs the following two tasks:
1. Embeds an input sequence consisting of token ids and segment ids. This consists of token embedding lookup, segment embedding lookup, adding positional embeddings and finally, layer normalization.
-2. Preprocesses input masks, that are used to mark valid input tokens in sequences that are padded to the target sequence length.
+2. For version 1 of the plugin only, preprocesses input masks, which are used to mark valid input tokens in sequences padded to the target sequence length.
Assuming contiguous input masks, encodes the masks as a single number denoting the number of valid elements, e.g.:
-```
-111100 => 4
-110000 => 2
-110100: Invalid mask, because it is not contiguous
-```
+ ```
+ 111100 => 4
+ 110000 => 2
+ 110100: Invalid mask, because it is not contiguous
+ ```
+ For subsequent versions (2,3,4,5), the input mask is returned after casting to `half` and reshaping to the shape of the embedded output.
### Structure
-The `embLayerNormPlugin` takes three inputs; `token_id`, `segmend_id`, and `input_mask`.
+The version 1 `embLayerNormPlugin` takes three inputs: `token_id`, `segment_id`, and `input_mask`.
+The subsequent versions 2, 3, 4, and 5 (variable seqlen) take four inputs: `token_id`, `segment_id`, `cu_seqlen`, and `max_seqlen`.
-`token_id`
-An input sequence containing token ids. token_id is an `int32` tensor with shape `[S, B]` where `S` is the sequence length and `B` is the batch size.
+### Version 1
+Inputs:
+- `token_id`
+An input sequence containing token ids. token_id is an `int32` tensor with shape `[S, B]` where `S` is the sequence length and `B` is the batch size.
Tokens typically identify words or word pieces that were obtained by preprocessing the input text.
-`segment_id`
+- `segment_id`
An input sequence containing segment ids. segment_id is an `int32` tensor with shape `[S, B]` where `S` is the sequence length and `B` is the batch size.
The segment id is used to distinguish between different parts of the input sequence that might serve different purposes. E.g. in a squad task, the input sequence might consist of a segment representing the knowledge base (i.e. a paragraph of text) and a segment representing the question.
-`input_mask`
+- `input_mask`
input_mask is an `int32` tensor with shape `[S, B]` where `S` is the sequence length and `B` is the batch size.
The input mask denotes valid elements in a sequence that was padded to the sequence length `S`.
+Outputs:
-The `embLayerNormPlugin` generates the following two outputs:
-
-`embedded_input`
-embedded_input is an floating point tensor with shape `[S, B, E]` where `S` is sequence length, `B` is batch size, and `E` is hidden size.
+- `embedded_output`
+embedded_output is a floating point tensor with shape `[S, B, E]` where `S` is sequence length, `B` is batch size, and `E` is hidden size.
The final output embedding is the sum of embeddings for the token, the segment and the position in the sequence.
-`maskIdx`
+- `maskIdx`
The `maskIdx` is a more compact representation of the input mask, consisting of the number of valid elements, assuming that the original mask was contiguous.
For fixed sequence length version 1, the `maskIdx` is an `int32` tensor with shape `[B, packSize]` where `B` is batch size, `packSize` is the packed mask size that depends on the sequence length.
-For huggingface style variable sequence length version 2, the `maskIdx` is an `int32` empty tensor.
-For megatron style variable sequence length version 3, the `maskIdx` is a `half` tensor with shape `[B, S, 1, 1]` where `B` is batch size, `S` is the sequence length.
+
+### Version >= 2
+
+Inputs:
+- `token_id`
+An input sequence containing token ids. token_id is a 1-D `int32` tensor with shape `[SxB]` where `S` is the sequence length and `B` is the batch size.
+Tokens typically identify words or word pieces that were obtained by preprocessing the input text.
+
+- `segment_id`
+An input sequence containing segment ids. segment_id is also a 1-D `int32` tensor with shape `[SxB]` where `S` is the sequence length and `B` is the batch size.
+The segment id is used to distinguish between different parts of the input sequence that might serve different purposes. E.g. in a SQuAD task, the input sequence might consist of a segment representing the knowledge base (i.e. a paragraph of text) and a segment representing the question.
+
+- `input_mask`
+input_mask is also a 1-D `int32` tensor with shape `[SxB]` where `S` is the sequence length and `B` is the batch size.
+The input mask denotes valid elements in a sequence that was padded to the sequence length `S`.
+
+- `cu_seqlen` (Version 2,3,4,5 only)
+An input sequence containing the "cumulative sequence lengths", used to index into the right sequence when sequences have variable lengths. `cu_seqlen` is a 1-D, `int32` tensor with shape `[B+1]` where `B` is the batch size.
+
+- `max_seqlen` (Version 2,3,4,5 only)
+Scalar `int32` value that specifies the maximum sequence length.
+
+Outputs:
+
+- `embedded_output`
+embedded_output is a floating point tensor with shape `[SxB, E, 1, 1]` where `S` is sequence length, `B` is batch size, and `E` is hidden size.
+The final output embedding is the sum of embeddings for the token, the segment and the position in the sequence.
+
+- `maskIdx`
+(1) Huggingface variant (versions 2,4): An empty tensor (for backwards compatibility)
+(2) Megatron variant (versions 3,5): The input masks returned as a `half` tensor with the same shape as `embedded_output` - `[SxB, E, 1, 1]`.
+
## Parameters
@@ -62,16 +95,16 @@ For megatron style variable sequence length version 3, the `maskIdx` is a `half`
The parameters are defined below and consists of the following attributes:
-| Type | Parameter | Version | Description
-|----------|----------------------------------------|----------|--------------------------------------------------------
-|`int` |`output_fp16` | 1, 2 |Integer encoding the DataType, set 0 when build FP32 network and set 1 when build FP32/INT8 network (0: FP32, 1: FP16)
-|`int` |`full_mask` | 1 |Whether to output the full mask that works with the specialized multi-head-attention plugin kernels (this is deprecated, please use mha_type_id)
-|`int` |`mha_type_id` | 1 |Integer encoding the multi-head-attention plugin DataType (0: FP32, 1: FP16, 2: INT8)
-|`Weights` |`bert_embeddings_layernorm_beta` | 1, 2 |Beta parameter for layer norm. Shape: `[E,]` where `E` is hidden size
-|`Weights` |`bert_embeddings_layernorm_gamma` | 1, 2 |Gamma parameter for layer norm. Shape: `[E,]` where `E` is hidden size
-|`Weights` |`bert_embeddings_word_embeddings` | 1, 2 |Token embedding matrix. Shape: `[word_vocab_size, E]` where `E` is hidden size
-|`Weights` |`bert_embeddings_token_type_embeddings` | 1, 2 |Token type embedding matrix. Shape: `[type_vocab_size, E]` where `E` is hidden size
-|`Weights` |`bert_embeddings_position_embeddings` | 1, 2 |Positional embedding matrix. Shape: `[S, E]` where `S` is the maximum sequence length and `E` is hidden size
+| Type | Parameter | Version | Description
+|----------|----------------------------------------|----------------|--------------------------------------------------------
+|`int`     |`output_fp16`                            | 1, 2, 3, 4, 5  |Integer encoding the DataType, set to 0 when building an FP32 network and to 1 when building an FP16/INT8 network (0: FP32, 1: FP16)
+|`int` |`full_mask` | 1 |Whether to output the full mask that works with the specialized multi-head-attention plugin kernels (this is deprecated, please use mha_type_id)
+|`int` |`mha_type_id` | 1 |Integer encoding the multi-head-attention plugin DataType (0: FP32, 1: FP16, 2: INT8)
+|`Weights` |`bert_embeddings_layernorm_beta` | 1, 2, 3, 4, 5 |Beta parameter for layer norm. Shape: `[E,]` where `E` is hidden size
+|`Weights` |`bert_embeddings_layernorm_gamma` | 1, 2, 3, 4, 5 |Gamma parameter for layer norm. Shape: `[E,]` where `E` is hidden size
+|`Weights` |`bert_embeddings_word_embeddings` | 1, 2, 3, 4, 5 |Token embedding matrix. Shape: `[word_vocab_size, E]` where `E` is hidden size
+|`Weights` |`bert_embeddings_token_type_embeddings` | 1, 2, 3, 4, 5 |Token type embedding matrix. Shape: `[type_vocab_size, E]` where `E` is hidden size
+|`Weights` |`bert_embeddings_position_embeddings` | 1, 2, 3, 4, 5 |Positional embedding matrix. Shape: `[S, E]` where `S` is the maximum sequence length and `E` is hidden size
## Additional resources
@@ -90,10 +123,14 @@ documentation.
## Changelog
-October 2020
+July 2024:
+Add `EmbLayerNormPlugin` versions 4 & 5 that duplicate the behavior of the version 2 and 3 plugins respectively, but implement the `IPluginV3` interface instead of the deprecated `IPluginV2DynamicExt` interface.
+Update this README with a revised description of the plugin's I/O and structure.
+
+October 2020:
Add V2 plugin that supports variable sequence length.
-November 2019
+November 2019:
This is the first release of this `README.md` file.
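
To make the variable-seqlen contract in the README above concrete: with per-sequence lengths `{3, 5, 2}`, the packed `token_id` tensor has shape `[10]` and `cu_seqlen` is `{0, 3, 8, 10}`. Below is a small host-side sketch of building the `[B+1]` `cu_seqlen` buffer; `buildCuSeqlens` is an illustrative helper, not part of the plugin:

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Builds the cu_seqlen host buffer described in the README: for batch size B
// it holds B+1 entries, where entry i is the total length of sequences 0..i-1.
std::vector<int32_t> buildCuSeqlens(std::vector<int32_t> const& seqLens)
{
    std::vector<int32_t> cuSeqlens(seqLens.size() + 1, 0);
    // Exclusive prefix sum: cuSeqlens[0] = 0, cuSeqlens[i+1] = cuSeqlens[i] + seqLens[i].
    std::partial_sum(seqLens.begin(), seqLens.end(), cuSeqlens.begin() + 1);
    return cuSeqlens;
}
```
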
diff --git a/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPlugin.cpp b/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPlugin.cpp
index 4313faa7..6b9a61ad 100644
--- a/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPlugin.cpp
+++ b/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPlugin.cpp
@@ -30,9 +30,99 @@ using namespace nvinfer1::plugin::bert;
namespace
{
-char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"2"};
-char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"3"};
-char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{"CustomEmbLayerNormPluginDynamic"};
+constexpr char const* kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"4"};
+constexpr char const* kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"5"};
+constexpr char const* kEMB_LAYER_NORM_VAR_SEQLEN_NAME{"CustomEmbLayerNormPluginDynamic"};
+
+void checkConfigurationInputs(
+ PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
+{
+ // Validate input arguments
+ PLUGIN_ASSERT(nbInputs == 4);
+ PLUGIN_ASSERT(nbOutputs == 2);
+
+ PLUGIN_ASSERT(inputs[0].dims.nbDims == 1);
+ PLUGIN_ASSERT(inputs[1].dims.nbDims == 1);
+
+ PLUGIN_ASSERT(inputs[1].dims.d[0] == inputs[0].dims.d[0]);
+
+ PLUGIN_ASSERT(inputs[2].dims.nbDims == 1);
+
+ PLUGIN_ASSERT(outputs[0].dims.nbDims == 4);
+    PLUGIN_ASSERT(static_cast<size_t>(outputs[0].dims.d[0]) == static_cast<size_t>(inputs[0].dims.d[0]));
+ PLUGIN_ASSERT(outputs[0].dims.d[2] == 1);
+ PLUGIN_ASSERT(outputs[0].dims.d[3] == 1);
+
+ PLUGIN_ASSERT(inputs[0].type == DataType::kINT32);
+ PLUGIN_ASSERT(inputs[1].type == DataType::kINT32);
+ PLUGIN_ASSERT(inputs[2].type == DataType::kINT32);
+}
+
+bool initializeFields(char const* name, PluginFieldCollection const* fc, Weights& beta, Weights& gamma,
+ Weights& word_emb, Weights& pos_emb, Weights& tok_emb)
+{
+ bool output_fp16 = false;
+    std::set<std::string> const requiredAttributes{
+ "bert_embeddings_layernorm_beta",
+ "bert_embeddings_layernorm_gamma",
+ "bert_embeddings_word_embeddings",
+ "bert_embeddings_token_type_embeddings",
+ "bert_embeddings_position_embeddings",
+ };
+ plugin::validateRequiredAttributesExist(requiredAttributes, fc);
+
+ for (int32_t i = 0; i < fc->nbFields; i++)
+ {
+ std::string field_name(fc->fields[i].name);
+ if (field_name.compare("bert_embeddings_layernorm_beta") == 0)
+ {
+ BERT_DEBUG_MSG("Building bert_embeddings_layernorm_beta...");
+ beta.values = fc->fields[i].data;
+ beta.count = fc->fields[i].length;
+ beta.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+
+ else if (field_name.compare("bert_embeddings_layernorm_gamma") == 0)
+ {
+ BERT_DEBUG_MSG("Building bert_embeddings_layernorm_gamma...");
+ gamma.values = fc->fields[i].data;
+ gamma.count = fc->fields[i].length;
+ gamma.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+
+ else if (field_name.compare("bert_embeddings_word_embeddings") == 0)
+ {
+ BERT_DEBUG_MSG("Building bert_embeddings_word_embeddings...");
+ word_emb.values = fc->fields[i].data;
+ word_emb.count = fc->fields[i].length;
+ word_emb.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+
+ else if (field_name.compare("bert_embeddings_token_type_embeddings") == 0)
+ {
+ BERT_DEBUG_MSG("Building bert_embeddings_token_type_embeddings...");
+ tok_emb.values = fc->fields[i].data;
+ tok_emb.count = fc->fields[i].length;
+ tok_emb.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+
+ else if (field_name.compare("bert_embeddings_position_embeddings") == 0)
+ {
+ BERT_DEBUG_MSG("Building bert_embeddings_position_embeddings...");
+ pos_emb.values = fc->fields[i].data;
+ pos_emb.count = fc->fields[i].length;
+ pos_emb.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+ else if (field_name.compare("output_fp16") == 0)
+ {
+ BERT_DEBUG_MSG("Building output_fp16...");
+ PLUGIN_VALIDATE(fc->fields[i].type == PluginFieldType::kINT32);
+            output_fp16 = static_cast<int32_t const*>(fc->fields[i].data)[0] != 0;
+ }
+ }
+ return output_fp16;
+}
+
} // namespace
// Static class fields initialization
@@ -75,67 +165,76 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(std::string con
copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mType), mTokEmbDev);
}
-EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
- std::string const& name, void const* data, size_t length)
- : mLayerName(name)
- , mGammaDev(nullptr)
- , mBetaDev(nullptr)
- , mWordEmbDev(nullptr)
- , mTokEmbDev(nullptr)
- , mPosEmbDev(nullptr)
-{
- // Deserialize in the same order as serialization
- deserialize_value(&data, &length, &mType);
- deserialize_value(&data, &length, &mLd);
- deserialize_value(&data, &length, &mWordVocabSize);
- deserialize_value(&data, &length, &mPosVocabSize);
- deserialize_value(&data, &length, &mTokVocabSize);
- deserialize_value(&data, &length, &mMaskType);
-
-    char const* d = static_cast<char const*>(data);
- mBeta.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT);
- mGamma.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT);
-
- mWordEmb.convertAndCopy(d, mLd * mWordVocabSize, mType);
- mPosEmb.convertAndCopy(d, mLd * mPosVocabSize, mType);
- mTokEmb.convertAndCopy(d, mLd * mTokVocabSize, mType);
-
- copyToDevice(mGamma, sizeof(float) * mGamma.count, mGammaDev);
- copyToDevice(mBeta, sizeof(float) * mBeta.count, mBetaDev);
-
- copyToDevice(mWordEmb, getWeightsSize(mWordEmb, mType), mWordEmbDev);
- copyToDevice(mPosEmb, getWeightsSize(mPosEmb, mType), mPosEmbDev);
- copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mType), mTokEmbDev);
-}
-
EmbLayerNormVarSeqlenPluginHFace::EmbLayerNormVarSeqlenPluginHFace(std::string const& name, DataType const type,
Weights const& beta, Weights const& gamma, Weights const& wordEmb, Weights const& posEmb, Weights const& tokEmb)
: EmbLayerNormVarSeqlenPluginBase(name, type, beta, gamma, wordEmb, posEmb, tokEmb, DataType::kINT32)
{
-}
-
-EmbLayerNormVarSeqlenPluginHFace::EmbLayerNormVarSeqlenPluginHFace(
- std::string const& name, void const* data, size_t length)
- : EmbLayerNormVarSeqlenPluginBase(name, data, length)
-{
- BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace deserialize");
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace creation");
}
EmbLayerNormVarSeqlenPluginMTron::EmbLayerNormVarSeqlenPluginMTron(std::string const& name, DataType const type,
Weights const& beta, Weights const& gamma, Weights const& wordEmb, Weights const& posEmb, Weights const& tokEmb)
: EmbLayerNormVarSeqlenPluginBase(name, type, beta, gamma, wordEmb, posEmb, tokEmb, type)
{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron creation");
}
-EmbLayerNormVarSeqlenPluginMTron::EmbLayerNormVarSeqlenPluginMTron(
- std::string const& name, void const* data, size_t length)
- : EmbLayerNormVarSeqlenPluginBase(name, data, length)
+EmbLayerNormVarSeqlenPluginBase::~EmbLayerNormVarSeqlenPluginBase()
{
- BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron deserialize");
+ try
+ {
+        // This gets called when the network containing the plugin is destroyed
+ mGammaDev.reset(nullptr);
+ mBetaDev.reset(nullptr);
+ mWordEmbDev.reset(nullptr);
+ mPosEmbDev.reset(nullptr);
+ mTokEmbDev.reset(nullptr);
+        // Note: TRT will delete this plugin object; no explicit "delete this" is needed here.
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+}
+
+EmbLayerNormVarSeqlenPluginHFace::~EmbLayerNormVarSeqlenPluginHFace()
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace destruction");
+}
+
+EmbLayerNormVarSeqlenPluginMTron::~EmbLayerNormVarSeqlenPluginMTron()
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron destruction");
+}
+
+//////
+// IPluginV3 method definitions:
+// - getCapabilityInterface() (Base)
+// - clone() (HFace, MTron)
+//////
+IPluginCapability* EmbLayerNormVarSeqlenPluginBase::getCapabilityInterface(PluginCapabilityType type) noexcept
+{
+ try
+ {
+        if (type == PluginCapabilityType::kBUILD)
+        {
+            return static_cast<IPluginV3OneBuild*>(this);
+        }
+        if (type == PluginCapabilityType::kRUNTIME)
+        {
+            return static_cast<IPluginV3OneRuntime*>(this);
+        }
+        PLUGIN_ASSERT(type == PluginCapabilityType::kCORE);
+        return static_cast<IPluginV3OneCore*>(this);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
}
-// IPluginV2DynamicExt Methods
-IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginHFace::clone() const noexcept
+IPluginV3* EmbLayerNormVarSeqlenPluginHFace::clone() noexcept
{
try
{
@@ -143,7 +242,6 @@ IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginHFace::clone() const noexcept
auto p = new EmbLayerNormVarSeqlenPluginHFace(mLayerName, mType, mBeta, mGamma, mWordEmb, mPosEmb, mTokEmb);
p->setPluginNamespace(mNamespace.c_str());
-
return p;
}
catch (std::exception const& e)
@@ -153,7 +251,7 @@ IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginHFace::clone() const noexcept
return nullptr;
}
-IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginMTron::clone() const noexcept
+IPluginV3* EmbLayerNormVarSeqlenPluginMTron::clone() noexcept
{
try
{
@@ -171,158 +269,101 @@ IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginMTron::clone() const noexcept
return nullptr;
}
-DimsExprs EmbLayerNormVarSeqlenPluginHFace::getOutputDimensions(
- int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept
-{
- // Input should be input ids and token ids and cumulative seqlens
- // Output should be the embeddings tensor and mask indices
- PLUGIN_ASSERT(nbInputs == 4);
-
- PLUGIN_ASSERT(inputs[0].nbDims == 1); // sum of all s
- PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims);
+// End IPluginV3 method definitions
- PLUGIN_ASSERT(inputs[2].nbDims == 1); // B+1
+//////
+// IPluginV3OneRuntime method definitions:
+// - getFieldsToSerialize() (Base)
+// - onShapeChange() (Base)
+// - attachToContext() (Base)
+// - enqueue() (HFace, MTron)
+//////
- PLUGIN_ASSERT(outputIndex == 0 || outputIndex == 1);
-
- if (outputIndex == 0)
+PluginFieldCollection const* EmbLayerNormVarSeqlenPluginBase::getFieldsToSerialize() noexcept
+{
+ mDataToSerialize.clear();
+ bool output_fp16 = mType == DataType::kHALF;
+ mDataToSerialize.emplace_back("output_fp16", &output_fp16, PluginFieldType::kINT32, 1);
+ mDataToSerialize.emplace_back("bert_embeddings_layernorm_beta", static_cast(mBeta.values),
+ PluginFieldType::kFLOAT32, mBeta.count);
+ mDataToSerialize.emplace_back("bert_embeddings_layernorm_gamma", static_cast(mGamma.values),
+ PluginFieldType::kFLOAT32, mGamma.count);
+ if (output_fp16)
{
- DimsExprs ret;
- ret.nbDims = 4;
- ret.d[0] = inputs[0].d[0];
- ret.d[1] = exprBuilder.constant(mLd);
- ret.d[2] = exprBuilder.constant(1);
- ret.d[3] = exprBuilder.constant(1);
- return ret;
+ mDataToSerialize.emplace_back("bert_embeddings_word_embeddings", static_cast(mWordEmb.values),
+ PluginFieldType::kFLOAT16, mWordEmb.count);
+ mDataToSerialize.emplace_back("bert_embeddings_token_type_embeddings", static_cast(mTokEmb.values),
+ PluginFieldType::kFLOAT16, mTokEmb.count);
+ mDataToSerialize.emplace_back("bert_embeddings_position_embeddings", static_cast(mPosEmb.values),
+ PluginFieldType::kFLOAT16, mPosEmb.count);
}
-
- // Return empty tensor since this is dummy output, we do not delete it for backward compatibility.
- DimsExprs ret{};
- ret.nbDims = 0;
- return ret;
-}
-
-DimsExprs EmbLayerNormVarSeqlenPluginMTron::getOutputDimensions(
- int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept
-{
- // Input should be input ids and token ids and cumulative seqlens
- // Output should be the embeddings tensor and mask indices
- PLUGIN_ASSERT(nbInputs == 4);
-
- PLUGIN_ASSERT(inputs[0].nbDims == 1); // sum of all s
- PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims);
-
- PLUGIN_ASSERT(inputs[2].nbDims == 1); // B+1
-
- PLUGIN_ASSERT(outputIndex == 0 || outputIndex == 1);
-
- DimsExprs ret;
- ret.nbDims = 4;
- ret.d[0] = inputs[0].d[0];
- ret.d[1] = exprBuilder.constant(mLd);
- ret.d[2] = exprBuilder.constant(1);
- ret.d[3] = exprBuilder.constant(1);
- return ret;
+ else
+ {
+ mDataToSerialize.emplace_back("bert_embeddings_word_embeddings", static_cast(mWordEmb.values),
+ PluginFieldType::kFLOAT32, mWordEmb.count);
+ mDataToSerialize.emplace_back("bert_embeddings_token_type_embeddings",
+ static_cast(mTokEmb.values), PluginFieldType::kFLOAT32, mTokEmb.count);
+ mDataToSerialize.emplace_back("bert_embeddings_position_embeddings", static_cast(mPosEmb.values),
+ PluginFieldType::kFLOAT32, mPosEmb.count);
+ }
+ mFCToSerialize.nbFields = mDataToSerialize.size();
+ mFCToSerialize.fields = mDataToSerialize.data();
+ return &mFCToSerialize;
}
-bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
- int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
+int32_t EmbLayerNormVarSeqlenPluginHFace::onShapeChange(
+ PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
{
- // The four inputs to this plugin input_ids, segment_ids, cu_seqlens and a dummy input with the
- // size of the max seq length in that order
- PLUGIN_ASSERT(nbInputs == 4);
- // The two outputs of the plugin are embedding and the mask
- PLUGIN_ASSERT(nbOutputs == 2);
-
- PluginTensorDesc const& desc = inOut[pos];
- if (desc.format != TensorFormat::kLINEAR)
- {
- return false;
- }
- if (pos == 0 || pos == 2) // input_ids and cu_seqlens
+ try
{
- return desc.type == DataType::kINT32 && desc.dims.nbDims == 1;
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace onShapeChange");
+ checkConfigurationInputs(inputs, nbInputs, outputs, nbOutputs);
+
+ // output 0 is the embedding
+        PLUGIN_ASSERT(static_cast<size_t>(outputs[0].dims.d[1]) == static_cast<size_t>(mLd));
+ PLUGIN_ASSERT(outputs[0].type == mType);
+ // output 1 is the mask indices (empty for HFace variant)
+ PLUGIN_ASSERT(outputs[1].dims.nbDims == 0);
+ PLUGIN_ASSERT(outputs[1].type == mMaskType);
+ return pluginStatus_t::STATUS_SUCCESS;
}
-
- PluginTensorDesc const& prev = inOut[pos - 1];
- if (pos == 1) // segment ids: check it's the same as input_ids
+ catch (std::exception const& e)
{
- return desc.type == DataType::kINT32 && desc.dims.nbDims == 1 && desc.dims.d[0] == prev.dims.d[0];
+ caughtError(e);
}
+ return pluginStatus_t::STATUS_FAILURE;
+}
- if (pos == 3)
+int32_t EmbLayerNormVarSeqlenPluginMTron::onShapeChange(
+ PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
+{
+ try
{
- return desc.dims.nbDims == 1;
+ // Validate input arguments
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron onShapeChange");
+ checkConfigurationInputs(inputs, nbInputs, outputs, nbOutputs);
+        PLUGIN_ASSERT(static_cast<size_t>(outputs[0].dims.d[1]) == static_cast<size_t>(mLd));
+
+        PLUGIN_ASSERT(outputs[1].dims.nbDims == 4);
+        PLUGIN_ASSERT(static_cast<size_t>(outputs[1].dims.d[0]) == static_cast<size_t>(inputs[0].dims.d[0]));
+        PLUGIN_ASSERT(static_cast<size_t>(outputs[1].dims.d[1]) == static_cast<size_t>(mLd));
+ PLUGIN_ASSERT(outputs[1].dims.d[2] == 1);
+ PLUGIN_ASSERT(outputs[1].dims.d[3] == 1);
+
+ PLUGIN_ASSERT(outputs[0].type == mType);
+ PLUGIN_ASSERT(outputs[1].type == mMaskType);
+ return pluginStatus_t::STATUS_SUCCESS;
}
-
- // embedded sequence
- if (pos == nbInputs)
+ catch (std::exception const& e)
{
- return desc.type == mType && desc.dims.nbDims == 4 && desc.dims.d[0] == inOut[0].dims.d[0]
- && desc.dims.d[2] == 1 && desc.dims.d[3] == 1;
+ caughtError(e);
}
- // mask
- return desc.type == mMaskType;
-}
-
-void checkConfigurationInputs(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
- DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
-{
- // Validate input arguments
- PLUGIN_ASSERT(nbInputs == 4);
- PLUGIN_ASSERT(nbOutputs == 2);
-
- PLUGIN_ASSERT(inputs[0].desc.dims.nbDims == 1);
- PLUGIN_ASSERT(inputs[1].desc.dims.nbDims == 1);
-
- PLUGIN_ASSERT(inputs[1].desc.dims.d[0] == inputs[0].desc.dims.d[0]);
-
- PLUGIN_ASSERT(inputs[2].desc.dims.nbDims == 1);
-
- PLUGIN_ASSERT(outputs[0].desc.dims.nbDims == 4);
-    PLUGIN_ASSERT(static_cast<size_t>(outputs[0].desc.dims.d[0]) == static_cast<size_t>(inputs[0].desc.dims.d[0]));
- PLUGIN_ASSERT(outputs[0].desc.dims.d[2] == 1);
- PLUGIN_ASSERT(outputs[0].desc.dims.d[3] == 1);
-
- PLUGIN_ASSERT(inputs[0].desc.type == DataType::kINT32);
- PLUGIN_ASSERT(inputs[1].desc.type == DataType::kINT32);
- PLUGIN_ASSERT(inputs[2].desc.type == DataType::kINT32);
-}
-
-void EmbLayerNormVarSeqlenPluginHFace::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
- DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
-{
- BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace configurePlugin");
- checkConfigurationInputs(inputs, nbInputs, outputs, nbOutputs);
-    PLUGIN_ASSERT(static_cast<size_t>(outputs[0].desc.dims.d[1]) == static_cast<size_t>(mLd));
-
- // check mask
- PLUGIN_ASSERT(outputs[1].desc.dims.nbDims == 0);
- PLUGIN_ASSERT(outputs[0].desc.type == mType);
- PLUGIN_ASSERT(outputs[1].desc.type == mMaskType);
-}
-
-void EmbLayerNormVarSeqlenPluginMTron::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
- DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
-{
- BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron configurePlugin");
- checkConfigurationInputs(inputs, nbInputs, outputs, nbOutputs);
-    PLUGIN_ASSERT(static_cast<size_t>(outputs[0].desc.dims.d[1]) == static_cast<size_t>(mLd));
-
- PLUGIN_ASSERT(outputs[1].desc.dims.nbDims == 4);
-    PLUGIN_ASSERT(static_cast<size_t>(outputs[1].desc.dims.d[0]) == static_cast<size_t>(inputs[0].desc.dims.d[0]));
-    PLUGIN_ASSERT(static_cast<size_t>(outputs[1].desc.dims.d[1]) == static_cast<size_t>(mLd));
- PLUGIN_ASSERT(outputs[1].desc.dims.d[2] == 1);
- PLUGIN_ASSERT(outputs[1].desc.dims.d[3] == 1);
-
- PLUGIN_ASSERT(outputs[0].desc.type == mType);
- PLUGIN_ASSERT(outputs[1].desc.type == mMaskType);
+ return pluginStatus_t::STATUS_FAILURE;
}
-size_t EmbLayerNormVarSeqlenPluginBase::getWorkspaceSize(
- PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
+IPluginV3* EmbLayerNormVarSeqlenPluginBase::attachToContext(IPluginResourceContext* context) noexcept
{
- return 0;
+ return clone();
}
int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(PluginTensorDesc const* inputDesc,
@@ -471,126 +512,191 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(PluginTensorDesc const* inputD
return STATUS_FAILURE;
}
-// IPluginV2Ext Methods
-DataType EmbLayerNormVarSeqlenPluginBase::getOutputDataType(
- int32_t index, DataType const* inputTypes, int32_t nbInputs) const noexcept
-{
- PLUGIN_ASSERT(index == 0 || index == 1);
- PLUGIN_ASSERT(mType == DataType::kHALF || mType == DataType::kFLOAT);
- return index == 0 ? mType : mMaskType;
-}
+// end IPluginV3OneRuntime method definitions
-// IPluginV2 Methods
-char const* EmbLayerNormVarSeqlenPluginBase::getPluginType() const noexcept
-{
- return EMB_LAYER_NORM_VAR_SEQLEN_NAME;
-}
-
-char const* EmbLayerNormVarSeqlenPluginHFace::getPluginVersion() const noexcept
-{
- return EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE;
-}
-
-char const* EmbLayerNormVarSeqlenPluginMTron::getPluginVersion() const noexcept
-{
- return EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON;
-}
+///////
+// IPluginV3OneBuild method definitions
+// - getNbOutputs() (Base)
+// - supportsFormatCombination() (Base)
+// - getOutputShapes() (HFace, MTron)
+// - getOutputDataTypes() (Base)
+// - configurePlugin() (Base)
+// - getWorkspaceSize() (Base)
+//////
int32_t EmbLayerNormVarSeqlenPluginBase::getNbOutputs() const noexcept
{
return 2;
}
-int32_t EmbLayerNormVarSeqlenPluginHFace::initialize() noexcept
+bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
+ int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
- BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace initialize");
- return 0;
-}
+    // The four inputs to this plugin are input_ids, segment_ids, cu_seqlens, and a dummy input with the
+    // size of the max sequence length, in that order
+    PLUGIN_ASSERT(nbInputs == 4);
+    // The two outputs of the plugin are the embedding and the mask
+ PLUGIN_ASSERT(nbOutputs == 2);
-int32_t EmbLayerNormVarSeqlenPluginMTron::initialize() noexcept
-{
- BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron initialize");
- return 0;
+ PluginTensorDesc const& desc = inOut[pos].desc;
+ if (desc.format != TensorFormat::kLINEAR)
+ {
+ return false;
+ }
+ if (pos == 0 || pos == 2) // input_ids and cu_seqlens
+ {
+ return desc.type == DataType::kINT32 && desc.dims.nbDims == 1;
+ }
+
+ PluginTensorDesc const& prev = inOut[pos - 1].desc;
+ if (pos == 1) // segment ids: check it's the same as input_ids
+ {
+ return desc.type == DataType::kINT32 && desc.dims.nbDims == 1 && desc.dims.d[0] == prev.dims.d[0];
+ }
+
+ if (pos == 3)
+ {
+ return desc.dims.nbDims == 1;
+ }
+
+ // embedded sequence
+ if (pos == nbInputs)
+ {
+ return desc.type == mType && desc.dims.nbDims == 4 && desc.dims.d[0] == inOut[0].desc.dims.d[0]
+ && desc.dims.d[2] == 1 && desc.dims.d[3] == 1;
+ }
+ // mask
+ return desc.type == mMaskType;
}
-void EmbLayerNormVarSeqlenPluginHFace::terminate() noexcept
+int32_t EmbLayerNormVarSeqlenPluginHFace::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs,
+ DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs,
+ IExprBuilder& exprBuilder) noexcept
{
- BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace terminate");
+ try
+ {
+ PLUGIN_VALIDATE(inputs != nullptr);
+ PLUGIN_VALIDATE(outputs != nullptr);
+
+        // Inputs should be input ids, token ids, and cumulative seqlens
+        // Outputs should be the embeddings tensor and mask indices
+ PLUGIN_ASSERT(nbInputs == 4);
+ PLUGIN_ASSERT(nbOutputs == 2);
+
+ PLUGIN_ASSERT(inputs[0].nbDims == 1); // sum of all s
+ PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims);
+
+ PLUGIN_ASSERT(inputs[2].nbDims == 1); // B+1
+
+ // output 0 : embedded input
+ outputs[0].nbDims = 4;
+ outputs[0].d[0] = inputs[0].d[0];
+ outputs[0].d[1] = exprBuilder.constant(mLd);
+ outputs[0].d[2] = exprBuilder.constant(1);
+ outputs[0].d[3] = exprBuilder.constant(1);
+
+ // Output 1 : maskIdx
+        // Return an empty tensor since this is a dummy output; we keep it for backward compatibility.
+ outputs[1].nbDims = 0;
+ return pluginStatus_t::STATUS_SUCCESS;
+ }
+ catch (const std::exception& e)
+ {
+ caughtError(e);
+ }
+ return pluginStatus_t::STATUS_FAILURE;
}
-void EmbLayerNormVarSeqlenPluginMTron::terminate() noexcept
+int32_t EmbLayerNormVarSeqlenPluginMTron::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs,
+ DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs,
+ IExprBuilder& exprBuilder) noexcept
{
- BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron terminate");
+ try
+ {
+ PLUGIN_VALIDATE(inputs != nullptr);
+ PLUGIN_VALIDATE(outputs != nullptr);
+        // Inputs should be input ids, token ids, and cumulative seqlens
+        // Outputs should be the embeddings tensor and mask indices
+ PLUGIN_ASSERT(nbInputs == 4);
+ PLUGIN_ASSERT(nbOutputs == 2);
+
+ PLUGIN_ASSERT(inputs[0].nbDims == 1); // sum of all s
+ PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims);
+ PLUGIN_ASSERT(inputs[2].nbDims == 1); // B+1
+
+ // Output 0 : embedded input
+ outputs[0].nbDims = 4;
+ outputs[0].d[0] = inputs[0].d[0];
+ outputs[0].d[1] = exprBuilder.constant(mLd);
+ outputs[0].d[2] = exprBuilder.constant(1);
+ outputs[0].d[3] = exprBuilder.constant(1);
+
+ // Output 1 : maskIdx
+ outputs[1].nbDims = 4;
+ outputs[1].d[0] = inputs[0].d[0];
+ outputs[1].d[1] = exprBuilder.constant(mLd);
+ outputs[1].d[2] = exprBuilder.constant(1);
+ outputs[1].d[3] = exprBuilder.constant(1);
+
+ return pluginStatus_t::STATUS_SUCCESS;
+ }
+ catch (const std::exception& e)
+ {
+ caughtError(e);
+ }
+ return pluginStatus_t::STATUS_FAILURE;
}
-size_t EmbLayerNormVarSeqlenPluginBase::getSerializationSize() const noexcept
+int32_t EmbLayerNormVarSeqlenPluginBase::getOutputDataTypes(
+ DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept
{
- size_t const wordSize = getElementSize(mType);
- return 2 * sizeof(float) * mLd // beta + gamma
- + sizeof(mType) //
- + sizeof(mLd) //
- + sizeof(mWordVocabSize) //
- + sizeof(mPosVocabSize) //
- + sizeof(mTokVocabSize) //
- + wordSize * mLd * mWordVocabSize // word emb
- + wordSize * mLd * mPosVocabSize // pos emb
- + wordSize * mLd * mTokVocabSize // tok emb
- + sizeof(mMaskType) // mask type
- ;
+ try
+ {
+ PLUGIN_ASSERT(mType == DataType::kHALF || mType == DataType::kFLOAT);
+ outputTypes[0] = mType;
+ outputTypes[1] = mMaskType;
+ return pluginStatus_t::STATUS_SUCCESS;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return pluginStatus_t::STATUS_FAILURE;
}
-void EmbLayerNormVarSeqlenPluginBase::serialize(void* buffer) const noexcept
+int32_t EmbLayerNormVarSeqlenPluginBase::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
{
- serialize_value(&buffer, mType);
- serialize_value(&buffer, mLd);
- serialize_value(&buffer, mWordVocabSize);
- serialize_value(&buffer, mPosVocabSize);
- serialize_value(&buffer, mTokVocabSize);
- serialize_value(&buffer, mMaskType);
-
-    char* d = static_cast<char*>(buffer);
- size_t const wordSize = getElementSize(mType);
-
- serFromDev(d, mBetaDev.get(), mLd);
- serFromDev(d, mGammaDev.get(), mLd);
-    serFromDev(d, static_cast<char*>(mWordEmbDev.get()), mLd * mWordVocabSize * wordSize);
-    serFromDev(d, static_cast<char*>(mPosEmbDev.get()), mLd * mPosVocabSize * wordSize);
-    serFromDev(d, static_cast<char*>(mTokEmbDev.get()), mLd * mTokVocabSize * wordSize);
+ return pluginStatus_t::STATUS_SUCCESS;
}
-void EmbLayerNormVarSeqlenPluginBase::destroy() noexcept
+size_t EmbLayerNormVarSeqlenPluginBase::getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
{
- // This gets called when the network containing plugin is destroyed
- mGammaDev.reset(nullptr);
- mBetaDev.reset(nullptr);
- mWordEmbDev.reset(nullptr);
- mPosEmbDev.reset(nullptr);
- mTokEmbDev.reset(nullptr);
- delete this;
+ return 0;
}
+// End IPluginV3OneBuild method definitions
-void EmbLayerNormVarSeqlenPluginHFace::destroy() noexcept
+//////
+// IPluginV3OneCore method definitions
+// - getPluginVersion() (MTron, HFace)
+// - getPluginName() (Base)
+// - getPluginNamespace() (Base)
+// - setPluginNamespace() (Base)
+//////
+char const* EmbLayerNormVarSeqlenPluginHFace::getPluginVersion() const noexcept
{
- BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace destroy");
- EmbLayerNormVarSeqlenPluginBase::destroy();
+ return kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE;
}
-void EmbLayerNormVarSeqlenPluginMTron::destroy() noexcept
+char const* EmbLayerNormVarSeqlenPluginMTron::getPluginVersion() const noexcept
{
- BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron destroy");
- EmbLayerNormVarSeqlenPluginBase::destroy();
+ return kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON;
}
-void EmbLayerNormVarSeqlenPluginBase::setPluginNamespace(char const* libNamespace) noexcept
+char const* EmbLayerNormVarSeqlenPluginBase::getPluginName() const noexcept
{
- try
- {
- mNamespace = libNamespace;
- }
- catch (std::exception const& e)
- {
- caughtError(e);
- }
+ return kEMB_LAYER_NORM_VAR_SEQLEN_NAME;
}
char const* EmbLayerNormVarSeqlenPluginBase::getPluginNamespace() const noexcept
@@ -598,34 +704,46 @@ char const* EmbLayerNormVarSeqlenPluginBase::getPluginNamespace() const noexcept
return mNamespace.c_str();
}
-///////////////////////
+void EmbLayerNormVarSeqlenPluginBase::setPluginNamespace(char const* libNamespace) noexcept
+{
+ mNamespace = libNamespace;
+}
+// End IPluginV3OneCore method definitions
+
+//////////////////////////// Plugin Creator member definitions /////////////////////////////
EmbLayerNormVarSeqlenPluginBaseCreator::EmbLayerNormVarSeqlenPluginBaseCreator()
{
+ static std::mutex sMutex;
+    std::lock_guard<std::mutex> lock(sMutex);
mPluginAttributes.clear();
- mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_beta"));
- mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_gamma"));
- mPluginAttributes.emplace_back(PluginField("bert_embeddings_word_embeddings"));
- mPluginAttributes.emplace_back(PluginField("bert_embeddings_token_type_embeddings"));
- mPluginAttributes.emplace_back(PluginField("bert_embeddings_position_embeddings"));
- mPluginAttributes.emplace_back(PluginField("output_fp16"));
+ mPluginAttributes.emplace_back(PluginField("output_fp16", nullptr, PluginFieldType::kINT32, 1));
+ // the length of beta, gamma, word_emb, pos_emb, and tok_emb will only be known at the time of plugin creation
+ // so we set it to 0 here
+ mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_beta", nullptr, PluginFieldType::kFLOAT32, 0));
+ mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_gamma", nullptr, PluginFieldType::kFLOAT32, 0));
+ // the embeddings datatype is determined by the output_fp16 attribute known at runtime
+ // so we set it to kUNKNOWN here
+ mPluginAttributes.emplace_back(PluginField("bert_embeddings_word_embeddings", nullptr, PluginFieldType::kUNKNOWN, 0));
+ mPluginAttributes.emplace_back(PluginField("bert_embeddings_token_type_embeddings", nullptr, PluginFieldType::kUNKNOWN, 0));
+ mPluginAttributes.emplace_back(PluginField("bert_embeddings_position_embeddings", nullptr, PluginFieldType::kUNKNOWN, 0));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
char const* EmbLayerNormVarSeqlenPluginBaseCreator::getPluginName() const noexcept
{
- return EMB_LAYER_NORM_VAR_SEQLEN_NAME;
+ return kEMB_LAYER_NORM_VAR_SEQLEN_NAME;
}
char const* EmbLayerNormVarSeqlenPluginHFaceCreator::getPluginVersion() const noexcept
{
- return EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE;
+ return kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE;
}
char const* EmbLayerNormVarSeqlenPluginMTronCreator::getPluginVersion() const noexcept
{
- return EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON;
+ return kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON;
}
PluginFieldCollection const* EmbLayerNormVarSeqlenPluginBaseCreator::getFieldNames() noexcept
@@ -633,73 +751,8 @@ PluginFieldCollection const* EmbLayerNormVarSeqlenPluginBaseCreator::getFieldNam
return &mFC;
}
-bool initializeFields(char const* name, PluginFieldCollection const* fc, Weights& beta, Weights& gamma,
- Weights& word_emb, Weights& pos_emb, Weights& tok_emb)
-{
- bool output_fp16 = false;
-    std::set<std::string> const requiredAttributes{
- "bert_embeddings_layernorm_beta",
- "bert_embeddings_layernorm_gamma",
- "bert_embeddings_word_embeddings",
- "bert_embeddings_token_type_embeddings",
- "bert_embeddings_position_embeddings",
- };
- plugin::validateRequiredAttributesExist(requiredAttributes, fc);
-
- for (int32_t i = 0; i < fc->nbFields; i++)
- {
- std::string field_name(fc->fields[i].name);
- if (field_name.compare("bert_embeddings_layernorm_beta") == 0)
- {
- BERT_DEBUG_MSG("Building bert_embeddings_layernorm_beta...");
- beta.values = fc->fields[i].data;
- beta.count = fc->fields[i].length;
- beta.type = fieldTypeToDataType(fc->fields[i].type);
- }
-
- if (field_name.compare("bert_embeddings_layernorm_gamma") == 0)
- {
- BERT_DEBUG_MSG("Building bert_embeddings_layernorm_gamma...");
- gamma.values = fc->fields[i].data;
- gamma.count = fc->fields[i].length;
- gamma.type = fieldTypeToDataType(fc->fields[i].type);
- }
-
- if (field_name.compare("bert_embeddings_word_embeddings") == 0)
- {
- BERT_DEBUG_MSG("Building bert_embeddings_word_embeddings...");
- word_emb.values = fc->fields[i].data;
- word_emb.count = fc->fields[i].length;
- word_emb.type = fieldTypeToDataType(fc->fields[i].type);
- }
-
- if (field_name.compare("bert_embeddings_token_type_embeddings") == 0)
- {
- BERT_DEBUG_MSG("Building bert_embeddings_token_type_embeddings...");
- tok_emb.values = fc->fields[i].data;
- tok_emb.count = fc->fields[i].length;
- tok_emb.type = fieldTypeToDataType(fc->fields[i].type);
- }
-
- if (field_name.compare("bert_embeddings_position_embeddings") == 0)
- {
- BERT_DEBUG_MSG("Building bert_embeddings_position_embeddings...");
- pos_emb.values = fc->fields[i].data;
- pos_emb.count = fc->fields[i].length;
- pos_emb.type = fieldTypeToDataType(fc->fields[i].type);
- }
- if (field_name.compare("output_fp16") == 0)
- {
- BERT_DEBUG_MSG("Building output_fp16...");
- PLUGIN_VALIDATE(fc->fields[i].type == PluginFieldType::kINT32);
-            output_fp16 = static_cast<int32_t const*>(fc->fields[i].data)[0] != 0;
- }
- }
- return output_fp16;
-}
-
-IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin(
- char const* name, PluginFieldCollection const* fc) noexcept
+IPluginV3* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin(
+ char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept
{
try
{
@@ -729,8 +782,8 @@ IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin(
return nullptr;
}
-IPluginV2* EmbLayerNormVarSeqlenPluginMTronCreator::createPlugin(
- char const* name, PluginFieldCollection const* fc) noexcept
+IPluginV3* EmbLayerNormVarSeqlenPluginMTronCreator::createPlugin(
+ char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept
{
try
{
@@ -760,38 +813,6 @@ IPluginV2* EmbLayerNormVarSeqlenPluginMTronCreator::createPlugin(
return nullptr;
}
-IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept
-{
- try
- {
- // This object will be deleted when the network is destroyed, which will
- // call EmbLayerNormVarSeqlen::destroy()
- return new EmbLayerNormVarSeqlenPluginHFace(name, serialData, serialLength);
- }
- catch (std::exception const& e)
- {
- caughtError(e);
- }
- return nullptr;
-}
-
-IPluginV2* EmbLayerNormVarSeqlenPluginMTronCreator::deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept
-{
- try
- {
- // This object will be deleted when the network is destroyed, which will
- // call EmbLayerNormVarSeqlen::destroy()
- return new EmbLayerNormVarSeqlenPluginMTron(name, serialData, serialLength);
- }
- catch (std::exception const& e)
- {
- caughtError(e);
- }
- return nullptr;
-}
-
void EmbLayerNormVarSeqlenPluginBaseCreator::setPluginNamespace(char const* libNamespace) noexcept
{
try
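
For orientation, here is a sketch of how an application might instantiate the new `IPluginV3` version of this plugin through the registry. The wrapper function, the FP32-only field types, and the weight pointers are illustrative assumptions, not part of this change:

```cpp
#include <cstdint>
#include <vector>

#include "NvInfer.h"

// Hypothetical helper: builds the attribute collection and asks the registered
// creator (name "CustomEmbLayerNormPluginDynamic", version "4", the HFace
// variant) for a build-phase plugin instance.
nvinfer1::IPluginV3* createEmbLayerNormV4(float const* beta, float const* gamma, float const* wordEmb,
    float const* posEmb, float const* tokEmb, int32_t ld, int32_t wordVocab, int32_t posVocab, int32_t tokVocab)
{
    using namespace nvinfer1;
    int32_t outputFp16{0}; // 0: FP32 embeddings, 1: FP16 embeddings
    std::vector<PluginField> fields{
        {"output_fp16", &outputFp16, PluginFieldType::kINT32, 1},
        {"bert_embeddings_layernorm_beta", beta, PluginFieldType::kFLOAT32, ld},
        {"bert_embeddings_layernorm_gamma", gamma, PluginFieldType::kFLOAT32, ld},
        {"bert_embeddings_word_embeddings", wordEmb, PluginFieldType::kFLOAT32, ld * wordVocab},
        {"bert_embeddings_position_embeddings", posEmb, PluginFieldType::kFLOAT32, ld * posVocab},
        {"bert_embeddings_token_type_embeddings", tokEmb, PluginFieldType::kFLOAT32, ld * tokVocab},
    };
    PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};

    auto* creator = static_cast<IPluginCreatorV3One*>(
        getPluginRegistry()->getCreator("CustomEmbLayerNormPluginDynamic", "4", ""));
    return creator != nullptr ? creator->createPlugin("emb_ln", &fc, TensorRTPhase::kBUILD) : nullptr;
}
```

Note that the same collection shape is what `getFieldsToSerialize()` above emits: when an engine is deserialized, TensorRT calls the creator again with `TensorRTPhase::kRUNTIME` and the serialized fields, which replaces the legacy `deserializePlugin()` path.
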
diff --git a/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPlugin.h b/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPlugin.h
index d3141a6b..57612d06 100644
--- a/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPlugin.h
+++ b/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPlugin.h
@@ -43,41 +43,63 @@ int32_t embSkipLayerNormMTron(cudaStream_t stream, int32_t ld, int32_t B, int32_
int32_t const* tokenIds, int32_t const* cuSeqlens, float const* beta, float const* gamma, T const* wordEmb,
T const* posEmb, T const* tokEmb, int32_t const wordSize, int32_t const tokSize, T* output, T* skip);
-class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt
+class EmbLayerNormVarSeqlenPluginBase : public IPluginV3,
+ public IPluginV3OneCore,
+ public IPluginV3OneBuild,
+ public IPluginV3OneRuntime
{
public:
EmbLayerNormVarSeqlenPluginBase(std::string const& name, DataType type, Weights const& beta, Weights const& gamma,
Weights const& word_emb, Weights const& pos_emb, Weights const& tok_emb, DataType maskType);
- EmbLayerNormVarSeqlenPluginBase(std::string const& name, void const* data, size_t length);
-
// It doesn't make sense to make EmbLayerNormVarSeqlenPlugin without arguments, so we
// delete default constructor.
EmbLayerNormVarSeqlenPluginBase() = delete;
- // IPluginV2DynamicExt Methods
+ ~EmbLayerNormVarSeqlenPluginBase() override;
+
+ // IPluginV3 Methods
+    // NOTE: since this is itself an abstract class, the remaining virtual methods are defined in its child classes
+ IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept override;
+ // end of IPluginV3 Methods
+
+ // IPluginV3OneCore Methods
+ char const* getPluginName() const noexcept override;
+
+ char const* getPluginNamespace() const noexcept override;
+
+ void setPluginNamespace(char const* pluginNamespace) noexcept;
+ // end of IPluginV3OneCore Methods
+
+    // IPluginV3OneBuild Methods
bool supportsFormatCombination(
- int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
- size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs,
- nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
+ int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
+
+ int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out,
+ int32_t nbOutputs) noexcept override;
+
+ size_t getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
- // IPluginV2Ext Methods
- nvinfer1::DataType getOutputDataType(
- int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override;
+ int32_t getOutputDataTypes(
+ DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override;
- // IPluginV2 Methods
- char const* getPluginType() const noexcept override;
int32_t getNbOutputs() const noexcept override;
- size_t getSerializationSize() const noexcept override;
- void serialize(void* buffer) const noexcept override;
- void destroy() noexcept override;
- char const* getPluginNamespace() const noexcept override;
- void setPluginNamespace(char const* pluginNamespace) noexcept override;
+    // end IPluginV3OneBuild Methods
+
+    // IPluginV3OneRuntime Methods
+
+    IPluginV3* attachToContext(IPluginResourceContext* context) noexcept override;
+
+    PluginFieldCollection const* getFieldsToSerialize() noexcept override;
+    // end IPluginV3OneRuntime Methods
protected:
+ // metadata fields
std::string const mLayerName;
std::string mNamespace;
+ // device-side
    bert::cuda_unique_ptr<float> mGammaDev;
    bert::cuda_unique_ptr<float> mBetaDev;
    bert::cuda_unique_ptr<void> mWordEmbDev;
@@ -87,6 +109,8 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt
size_t mWordVocabSize;
size_t mPosVocabSize;
size_t mTokVocabSize;
+
+    // members that participate in ser/deserialization
bert::WeightsWithOwnership mBeta;
bert::WeightsWithOwnership mGamma;
bert::WeightsWithOwnership mWordEmb;
@@ -94,6 +118,10 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt
bert::WeightsWithOwnership mPosEmb;
DataType mType{};
DataType mMaskType{};
+
+ // IPluginV3 serialization related
+    std::vector<nvinfer1::PluginField> mDataToSerialize;
+ nvinfer1::PluginFieldCollection mFCToSerialize;
};
class EmbLayerNormVarSeqlenPluginHFace : public EmbLayerNormVarSeqlenPluginBase
@@ -103,26 +131,27 @@ class EmbLayerNormVarSeqlenPluginHFace : public EmbLayerNormVarSeqlenPluginBase
nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma, nvinfer1::Weights const& word_emb,
nvinfer1::Weights const& pos_emb, nvinfer1::Weights const& tok_emb);
- EmbLayerNormVarSeqlenPluginHFace(std::string const& name, void const* data, size_t length);
-
// It doesn't make sense to make EmbLayerNormVarSeqlenPlugin without arguments, so we
// delete default constructor.
EmbLayerNormVarSeqlenPluginHFace() = delete;
- // IPluginV2DynamicExt Methods
- nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
- nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs,
- nvinfer1::IExprBuilder& exprBuilder) noexcept override;
- void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
- nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
+ ~EmbLayerNormVarSeqlenPluginHFace() override;
+
+ // IPluginV3Runtime overrides
+ IPluginV3* clone() noexcept;
+
+ int32_t onShapeChange(
+ PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
+
int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
- // IPluginV2 Methods
- int32_t initialize() noexcept override;
- void terminate() noexcept override;
- void destroy() noexcept override;
+ // IPluginV3OneCore override
char const* getPluginVersion() const noexcept override;
+
+ // IPluginV3OneBuild override
+ int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs,
+ int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override;
};
class EmbLayerNormVarSeqlenPluginMTron : public EmbLayerNormVarSeqlenPluginBase
@@ -132,29 +161,30 @@ class EmbLayerNormVarSeqlenPluginMTron : public EmbLayerNormVarSeqlenPluginBase
nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma, nvinfer1::Weights const& word_emb,
nvinfer1::Weights const& pos_emb, nvinfer1::Weights const& tok_emb);
- EmbLayerNormVarSeqlenPluginMTron(std::string const& name, void const* data, size_t length);
-
// It doesn't make sense to make EmbLayerNormVarSeqlenPlugin without arguments, so we
// delete default constructor.
EmbLayerNormVarSeqlenPluginMTron() = delete;
- // IPluginV2DynamicExt Methods
- nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
- nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs,
- nvinfer1::IExprBuilder& exprBuilder) noexcept override;
- void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
- nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
+ ~EmbLayerNormVarSeqlenPluginMTron() override;
+
+ // IPluginV3Runtime overrides
+ IPluginV3* clone() noexcept;
+
+ int32_t onShapeChange(
+ PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
+
int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
- // IPluginV2 Methods
- int32_t initialize() noexcept override;
- void terminate() noexcept override;
- void destroy() noexcept override;
+ // IPluginV3OneCore override
char const* getPluginVersion() const noexcept override;
+
+ // IPluginV3OneBuild override
+ int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs,
+ int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override;
};
-class EmbLayerNormVarSeqlenPluginBaseCreator : public nvinfer1::IPluginCreator
+class EmbLayerNormVarSeqlenPluginBaseCreator : public nvinfer1::IPluginCreatorV3One
{
public:
EmbLayerNormVarSeqlenPluginBaseCreator();
@@ -163,7 +193,7 @@ class EmbLayerNormVarSeqlenPluginBaseCreator : public nvinfer1::IPluginCreator
nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
- void setPluginNamespace(char const* pluginNamespace) noexcept override;
+ void setPluginNamespace(char const* libNamespace) noexcept;
char const* getPluginNamespace() const noexcept override;
@@ -176,19 +206,15 @@ class EmbLayerNormVarSeqlenPluginBaseCreator : public nvinfer1::IPluginCreator
class EmbLayerNormVarSeqlenPluginHFaceCreator : public EmbLayerNormVarSeqlenPluginBaseCreator
{
public:
- nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+ IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override;
char const* getPluginVersion() const noexcept override;
- nvinfer1::IPluginV2* deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept override;
};
class EmbLayerNormVarSeqlenPluginMTronCreator : public EmbLayerNormVarSeqlenPluginBaseCreator
{
public:
- nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+ IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override;
char const* getPluginVersion() const noexcept override;
- nvinfer1::IPluginV2* deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept override;
};
} // namespace bert
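
The class above inherits `IPluginV3OneCore`, `IPluginV3OneBuild`, and `IPluginV3OneRuntime` side by side, which is why `getCapabilityInterface()` must hand back a pointer cast to the specific capability base. A self-contained illustration of that dispatch pattern in plain C++ (not TensorRT API):

```cpp
#include <cassert>

// With multiple inheritance, each base subobject lives at its own offset inside
// the derived object, so the cast in the dispatcher adjusts 'this' accordingly.
struct Core
{
    virtual char const* name() const = 0;
    virtual ~Core() = default;
};
struct Build
{
    virtual int nbOutputs() const = 0;
    virtual ~Build() = default;
};

struct Plugin : Core, Build
{
    char const* name() const override { return "demo"; }
    int nbOutputs() const override { return 2; }

    // Mirrors getCapabilityInterface(): static_cast picks the requested base.
    void* capability(bool wantBuild)
    {
        if (wantBuild)
        {
            return static_cast<Build*>(this);
        }
        return static_cast<Core*>(this);
    }
};

int main()
{
    Plugin p;
    auto* build = static_cast<Build*>(p.capability(true));
    auto* core = static_cast<Core*>(p.capability(false));
    assert(build->nbOutputs() == 2 && core->name() != nullptr);
    return 0;
}
```
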
diff --git a/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPluginLegacy.cpp b/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPluginLegacy.cpp
new file mode 100644
index 00000000..f86700fd
--- /dev/null
+++ b/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPluginLegacy.cpp
@@ -0,0 +1,814 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda.h>
+
+#include <cstring>
+#include <set>
+#include <vector>
+
+#include "NvInfer.h"
+#include "common/serialize.hpp"
+#include "embLayerNormVarSeqlenPluginLegacy.h"
+
+using namespace nvinfer1;
+using namespace nvinfer1::plugin;
+using namespace nvinfer1::plugin::bert;
+
+namespace
+{
+constexpr char const* kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"2"};
+constexpr char const* kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"3"};
+constexpr char const* kEMB_LAYER_NORM_VAR_SEQLEN_NAME{"CustomEmbLayerNormPluginDynamic"};
+} // namespace
+
+// Static class fields initialization
+PluginFieldCollection EmbLayerNormVarSeqlenPluginLegacyBaseCreator::mFC{};
+std::vector<PluginField> EmbLayerNormVarSeqlenPluginLegacyBaseCreator::mPluginAttributes;
+
+REGISTER_TENSORRT_PLUGIN(EmbLayerNormVarSeqlenPluginLegacyHFaceCreator);
+REGISTER_TENSORRT_PLUGIN(EmbLayerNormVarSeqlenPluginLegacyMTronCreator);
+
+EmbLayerNormVarSeqlenPluginLegacyBase::EmbLayerNormVarSeqlenPluginLegacyBase(std::string const& name, DataType type,
+ Weights const& beta, Weights const& gamma, Weights const& wordEmb, Weights const& posEmb, Weights const& tokEmb,
+ DataType maskType)
+ : mLayerName(name)
+ , mLd(beta.count)
+ , mType(type)
+ , mMaskType(maskType)
+{
+ // Assuming Weights.count is the number of elements and not bytes
+ PLUGIN_VALIDATE(beta.count == gamma.count);
+ PLUGIN_VALIDATE(mLd > 0U);
+ PLUGIN_VALIDATE(wordEmb.count % mLd == 0);
+ PLUGIN_VALIDATE(posEmb.count % mLd == 0);
+ PLUGIN_VALIDATE(tokEmb.count % mLd == 0);
+ mWordVocabSize = wordEmb.count / mLd;
+ mPosVocabSize = posEmb.count / mLd;
+ mTokVocabSize = tokEmb.count / mLd;
+
+ mBeta.convertAndCopy(beta, nvinfer1::DataType::kFLOAT);
+ mGamma.convertAndCopy(gamma, nvinfer1::DataType::kFLOAT);
+
+ mWordEmb.convertAndCopy(wordEmb, mType);
+ mTokEmb.convertAndCopy(tokEmb, mType);
+ mPosEmb.convertAndCopy(posEmb, mType);
+
+ copyToDevice(mGamma, sizeof(float) * mGamma.count, mGammaDev);
+ copyToDevice(mBeta, sizeof(float) * mBeta.count, mBetaDev);
+
+ copyToDevice(mWordEmb, getWeightsSize(mWordEmb, mType), mWordEmbDev);
+ copyToDevice(mPosEmb, getWeightsSize(mPosEmb, mType), mPosEmbDev);
+ copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mType), mTokEmbDev);
+}
+
+EmbLayerNormVarSeqlenPluginLegacyBase::EmbLayerNormVarSeqlenPluginLegacyBase(
+ std::string const& name, void const* data, size_t length)
+ : mLayerName(name)
+ , mGammaDev(nullptr)
+ , mBetaDev(nullptr)
+ , mWordEmbDev(nullptr)
+ , mTokEmbDev(nullptr)
+ , mPosEmbDev(nullptr)
+{
+ // Deserialize in the same order as serialization
+ deserialize_value(&data, &length, &mType);
+ deserialize_value(&data, &length, &mLd);
+ deserialize_value(&data, &length, &mWordVocabSize);
+ deserialize_value(&data, &length, &mPosVocabSize);
+ deserialize_value(&data, &length, &mTokVocabSize);
+ deserialize_value(&data, &length, &mMaskType);
+
+    char const* d = static_cast<char const*>(data);
+ mBeta.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT);
+ mGamma.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT);
+
+ mWordEmb.convertAndCopy(d, mLd * mWordVocabSize, mType);
+ mPosEmb.convertAndCopy(d, mLd * mPosVocabSize, mType);
+ mTokEmb.convertAndCopy(d, mLd * mTokVocabSize, mType);
+
+ copyToDevice(mGamma, sizeof(float) * mGamma.count, mGammaDev);
+ copyToDevice(mBeta, sizeof(float) * mBeta.count, mBetaDev);
+
+ copyToDevice(mWordEmb, getWeightsSize(mWordEmb, mType), mWordEmbDev);
+ copyToDevice(mPosEmb, getWeightsSize(mPosEmb, mType), mPosEmbDev);
+ copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mType), mTokEmbDev);
+}
+
+EmbLayerNormVarSeqlenPluginLegacyHFace::EmbLayerNormVarSeqlenPluginLegacyHFace(std::string const& name,
+ DataType const type, Weights const& beta, Weights const& gamma, Weights const& wordEmb, Weights const& posEmb,
+ Weights const& tokEmb)
+ : EmbLayerNormVarSeqlenPluginLegacyBase(name, type, beta, gamma, wordEmb, posEmb, tokEmb, DataType::kINT32)
+{
+}
+
+EmbLayerNormVarSeqlenPluginLegacyHFace::EmbLayerNormVarSeqlenPluginLegacyHFace(
+ std::string const& name, void const* data, size_t length)
+ : EmbLayerNormVarSeqlenPluginLegacyBase(name, data, length)
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyHFace deserialize");
+}
+
+EmbLayerNormVarSeqlenPluginLegacyMTron::EmbLayerNormVarSeqlenPluginLegacyMTron(std::string const& name,
+ DataType const type, Weights const& beta, Weights const& gamma, Weights const& wordEmb, Weights const& posEmb,
+ Weights const& tokEmb)
+ : EmbLayerNormVarSeqlenPluginLegacyBase(name, type, beta, gamma, wordEmb, posEmb, tokEmb, type)
+{
+}
+
+EmbLayerNormVarSeqlenPluginLegacyMTron::EmbLayerNormVarSeqlenPluginLegacyMTron(
+ std::string const& name, void const* data, size_t length)
+ : EmbLayerNormVarSeqlenPluginLegacyBase(name, data, length)
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyMTron deserialize");
+}
+
+// IPluginV2DynamicExt Methods
+IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginLegacyHFace::clone() const noexcept
+{
+ try
+ {
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyHFace clone");
+
+ auto p
+ = new EmbLayerNormVarSeqlenPluginLegacyHFace(mLayerName, mType, mBeta, mGamma, mWordEmb, mPosEmb, mTokEmb);
+ p->setPluginNamespace(mNamespace.c_str());
+
+ return p;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginLegacyMTron::clone() const noexcept
+{
+ try
+ {
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyMTron clone");
+
+ auto p
+ = new EmbLayerNormVarSeqlenPluginLegacyMTron(mLayerName, mType, mBeta, mGamma, mWordEmb, mPosEmb, mTokEmb);
+ p->setPluginNamespace(mNamespace.c_str());
+
+ return p;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+DimsExprs EmbLayerNormVarSeqlenPluginLegacyHFace::getOutputDimensions(
+ int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept
+{
+ // Inputs should be input ids, token ids, and cumulative seqlens
+ // Outputs should be the embeddings tensor and mask indices
+ PLUGIN_ASSERT(nbInputs == 4);
+
+ PLUGIN_ASSERT(inputs[0].nbDims == 1); // sum of all s
+ PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims);
+
+ PLUGIN_ASSERT(inputs[2].nbDims == 1); // B+1
+
+ PLUGIN_ASSERT(outputIndex == 0 || outputIndex == 1);
+
+ if (outputIndex == 0)
+ {
+ DimsExprs ret;
+ ret.nbDims = 4;
+ ret.d[0] = inputs[0].d[0];
+ ret.d[1] = exprBuilder.constant(mLd);
+ ret.d[2] = exprBuilder.constant(1);
+ ret.d[3] = exprBuilder.constant(1);
+ return ret;
+ }
+
+ // Return an empty tensor; this is a dummy output that is kept (rather than removed) for backward compatibility.
+ DimsExprs ret{};
+ ret.nbDims = 0;
+ return ret;
+}
+
+DimsExprs EmbLayerNormVarSeqlenPluginLegacyMTron::getOutputDimensions(
+ int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept
+{
+ // Inputs should be input ids, token ids, and cumulative seqlens
+ // Outputs should be the embeddings tensor and mask indices
+ PLUGIN_ASSERT(nbInputs == 4);
+
+ PLUGIN_ASSERT(inputs[0].nbDims == 1); // sum of all s
+ PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims);
+
+ PLUGIN_ASSERT(inputs[2].nbDims == 1); // B+1
+
+ PLUGIN_ASSERT(outputIndex == 0 || outputIndex == 1);
+
+ DimsExprs ret;
+ ret.nbDims = 4;
+ ret.d[0] = inputs[0].d[0];
+ ret.d[1] = exprBuilder.constant(mLd);
+ ret.d[2] = exprBuilder.constant(1);
+ ret.d[3] = exprBuilder.constant(1);
+ return ret;
+}
+
+bool EmbLayerNormVarSeqlenPluginLegacyBase::supportsFormatCombination(
+ int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
+{
+ // The four inputs to this plugin are input_ids, segment_ids, cu_seqlens, and a dummy input
+ // sized to the max sequence length, in that order
+ PLUGIN_ASSERT(nbInputs == 4);
+ // The two outputs of the plugin are embedding and the mask
+ PLUGIN_ASSERT(nbOutputs == 2);
+
+ PluginTensorDesc const& desc = inOut[pos];
+ if (desc.format != TensorFormat::kLINEAR)
+ {
+ return false;
+ }
+ if (pos == 0 || pos == 2) // input_ids and cu_seqlens
+ {
+ return desc.type == DataType::kINT32 && desc.dims.nbDims == 1;
+ }
+
+ PluginTensorDesc const& prev = inOut[pos - 1];
+ if (pos == 1) // segment ids: check it's the same as input_ids
+ {
+ return desc.type == DataType::kINT32 && desc.dims.nbDims == 1 && desc.dims.d[0] == prev.dims.d[0];
+ }
+
+ if (pos == 3)
+ {
+ return desc.dims.nbDims == 1;
+ }
+
+ // embedded sequence
+ if (pos == nbInputs)
+ {
+ return desc.type == mType && desc.dims.nbDims == 4 && desc.dims.d[0] == inOut[0].dims.d[0]
+ && desc.dims.d[2] == 1 && desc.dims.d[3] == 1;
+ }
+ // mask
+ return desc.type == mMaskType;
+}
+
+void checkConfigurationInputs(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
+{
+ // Validate input arguments
+ PLUGIN_ASSERT(nbInputs == 4);
+ PLUGIN_ASSERT(nbOutputs == 2);
+
+ PLUGIN_ASSERT(inputs[0].desc.dims.nbDims == 1);
+ PLUGIN_ASSERT(inputs[1].desc.dims.nbDims == 1);
+
+ PLUGIN_ASSERT(inputs[1].desc.dims.d[0] == inputs[0].desc.dims.d[0]);
+
+ PLUGIN_ASSERT(inputs[2].desc.dims.nbDims == 1);
+
+ PLUGIN_ASSERT(outputs[0].desc.dims.nbDims == 4);
+ PLUGIN_ASSERT(static_cast<size_t>(outputs[0].desc.dims.d[0]) == static_cast<size_t>(inputs[0].desc.dims.d[0]));
+ PLUGIN_ASSERT(outputs[0].desc.dims.d[2] == 1);
+ PLUGIN_ASSERT(outputs[0].desc.dims.d[3] == 1);
+
+ PLUGIN_ASSERT(inputs[0].desc.type == DataType::kINT32);
+ PLUGIN_ASSERT(inputs[1].desc.type == DataType::kINT32);
+ PLUGIN_ASSERT(inputs[2].desc.type == DataType::kINT32);
+}
+
+void EmbLayerNormVarSeqlenPluginLegacyHFace::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyHFace configurePlugin");
+ checkConfigurationInputs(inputs, nbInputs, outputs, nbOutputs);
+ PLUGIN_ASSERT(static_cast<size_t>(outputs[0].desc.dims.d[1]) == static_cast<size_t>(mLd));
+
+ // check mask
+ PLUGIN_ASSERT(outputs[1].desc.dims.nbDims == 0);
+ PLUGIN_ASSERT(outputs[0].desc.type == mType);
+ PLUGIN_ASSERT(outputs[1].desc.type == mMaskType);
+}
+
+void EmbLayerNormVarSeqlenPluginLegacyMTron::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyMTron configurePlugin");
+ checkConfigurationInputs(inputs, nbInputs, outputs, nbOutputs);
+ PLUGIN_ASSERT(static_cast<size_t>(outputs[0].desc.dims.d[1]) == static_cast<size_t>(mLd));
+
+ PLUGIN_ASSERT(outputs[1].desc.dims.nbDims == 4);
+ PLUGIN_ASSERT(static_cast<size_t>(outputs[1].desc.dims.d[0]) == static_cast<size_t>(inputs[0].desc.dims.d[0]));
+ PLUGIN_ASSERT(static_cast<size_t>(outputs[1].desc.dims.d[1]) == static_cast<size_t>(mLd));
+ PLUGIN_ASSERT(outputs[1].desc.dims.d[2] == 1);
+ PLUGIN_ASSERT(outputs[1].desc.dims.d[3] == 1);
+
+ PLUGIN_ASSERT(outputs[0].desc.type == mType);
+ PLUGIN_ASSERT(outputs[1].desc.type == mMaskType);
+}
+
+size_t EmbLayerNormVarSeqlenPluginLegacyBase::getWorkspaceSize(
+ PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
+{
+ return 0;
+}
+
+int32_t EmbLayerNormVarSeqlenPluginLegacyHFace::enqueue(PluginTensorDesc const* inputDesc,
+ PluginTensorDesc const* /* outputDesc */, void const* const* inputs, void* const* outputs, void* /* workspace */,
+ cudaStream_t stream) noexcept
+{
+ try
+ {
+ PLUGIN_VALIDATE(inputDesc != nullptr && inputs != nullptr && outputs != nullptr);
+
+ int32_t const batchSize = inputDesc[2].dims.d[0] - 1;
+ // read out the maximum sequence length from the dummy input
+ int32_t const maxSeqlen = inputDesc[3].dims.d[0];
+
+ // There are four versions of the kernel which are optimized for sequence lengths 384, 256, 192 and 128.
+ // Find the closest sequence length bigger than the max seq length in this batch.
+ int32_t S = 384;
+ if (maxSeqlen <= 128)
+ {
+ S = 128;
+ }
+ else if (maxSeqlen <= 192)
+ {
+ S = 192;
+ }
+ else if (maxSeqlen <= 256)
+ {
+ S = 256;
+ }
+
+ // Our plugin outputs only one tensor
+ auto const inputIds = static_cast<int32_t const*>(inputs[0]);
+ auto const segmentIds = static_cast<int32_t const*>(inputs[1]);
+ int32_t const* cuSeqlens = static_cast<int32_t const*>(inputs[2]);
+
+ float const* beta = mBetaDev.get();
+ float const* gamma = mGammaDev.get();
+ if (mType == DataType::kFLOAT)
+ {
+ auto output = static_cast<float*>(outputs[0]);
+ auto const wordEmb = static_cast<float const*>(mWordEmbDev.get());
+ auto const tokEmb = static_cast<float const*>(mTokEmbDev.get());
+ auto const posEmb = static_cast<float const*>(mPosEmbDev.get());
+
+ return embSkipLayerNormHFace<float>(stream, static_cast<int32_t>(mLd), batchSize, S, inputIds, segmentIds,
+ cuSeqlens, beta, gamma, wordEmb, posEmb, tokEmb, mWordVocabSize, mTokVocabSize, output);
+ }
+ if (mType == DataType::kHALF)
+ {
+ auto output = static_cast<half*>(outputs[0]);
+ auto const wordEmb = static_cast<half const*>(mWordEmbDev.get());
+ auto const tokEmb = static_cast<half const*>(mTokEmbDev.get());
+ auto const posEmb = static_cast<half const*>(mPosEmbDev.get());
+
+ return embSkipLayerNormHFace<half>(stream, static_cast<int32_t>(mLd), batchSize, S, inputIds, segmentIds,
+ cuSeqlens, beta, gamma, wordEmb, posEmb, tokEmb, mWordVocabSize, mTokVocabSize, output);
+ }
+ else
+ {
+ gLogError << "Unsupported type error, expected [kHALF,kFLOAT], but received " << static_cast(mType)
+ << std::endl;
+
+ return STATUS_NOT_SUPPORTED;
+ }
+
+ return STATUS_SUCCESS;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return STATUS_FAILURE;
+}
+
+int32_t EmbLayerNormVarSeqlenPluginLegacyMTron::enqueue(PluginTensorDesc const* inputDesc,
+ PluginTensorDesc const* /* outputDesc */, void const* const* inputs, void* const* outputs, void* /* workspace */,
+ cudaStream_t stream) noexcept
+{
+ try
+ {
+ PLUGIN_VALIDATE(inputDesc != nullptr && inputs != nullptr && outputs != nullptr);
+
+ int32_t const batchSize = inputDesc[2].dims.d[0] - 1;
+ // read out the maximum sequence length from the dummy input
+ int32_t const maxSeqlen = inputDesc[3].dims.d[0];
+
+ // There are four versions of the kernel which are optimized for sequence lengths 384, 256, 192 and 128.
+ // Find the closest sequence length bigger than the max seq length in this batch.
+ int32_t S = 384;
+ if (maxSeqlen <= 128)
+ {
+ S = 128;
+ }
+ else if (maxSeqlen <= 192)
+ {
+ S = 192;
+ }
+ else if (maxSeqlen <= 256)
+ {
+ S = 256;
+ }
+
+ // Our plugin outputs only one tensor
+ auto const inputIds = static_cast<int32_t const*>(inputs[0]);
+ auto const segmentIds = static_cast<int32_t const*>(inputs[1]);
+ int32_t const* cuSeqlens = static_cast<int32_t const*>(inputs[2]);
+
+ float const* beta = mBetaDev.get();
+ float const* gamma = mGammaDev.get();
+ if (mType == DataType::kFLOAT)
+ {
+ auto output = static_cast<float*>(outputs[0]);
+ auto skip = static_cast<float*>(outputs[1]);
+ auto const wordEmb = static_cast<float const*>(mWordEmbDev.get());
+ auto const tokEmb = static_cast<float const*>(mTokEmbDev.get());
+ auto const posEmb = static_cast<float const*>(mPosEmbDev.get());
+
+ return embSkipLayerNormMTron<float>(stream, static_cast<int32_t>(mLd), batchSize, S, inputIds, segmentIds,
+ cuSeqlens, beta, gamma, wordEmb, posEmb, tokEmb, mWordVocabSize, mTokVocabSize, output, skip);
+ }
+ if (mType == DataType::kHALF)
+ {
+ auto output = static_cast<half*>(outputs[0]);
+ auto skip = static_cast<half*>(outputs[1]);
+ auto const wordEmb = static_cast<half const*>(mWordEmbDev.get());
+ auto const tokEmb = static_cast<half const*>(mTokEmbDev.get());
+ auto const posEmb = static_cast<half const*>(mPosEmbDev.get());
+
+ return embSkipLayerNormMTron<half>(stream, static_cast<int32_t>(mLd), batchSize, S, inputIds, segmentIds,
+ cuSeqlens, beta, gamma, wordEmb, posEmb, tokEmb, mWordVocabSize, mTokVocabSize, output, skip);
+ }
+ else
+ {
+ gLogError << "Unsupported type error, expected [kHALF,kFLOAT], but received " << static_cast(mType)
+ << std::endl;
+
+ return STATUS_NOT_SUPPORTED;
+ }
+
+ return STATUS_SUCCESS;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return STATUS_FAILURE;
+}
+
+// IPluginV2Ext Methods
+DataType EmbLayerNormVarSeqlenPluginLegacyBase::getOutputDataType(
+ int32_t index, DataType const* inputTypes, int32_t nbInputs) const noexcept
+{
+ PLUGIN_ASSERT(index == 0 || index == 1);
+ PLUGIN_ASSERT(mType == DataType::kHALF || mType == DataType::kFLOAT);
+ return index == 0 ? mType : mMaskType;
+}
+
+// IPluginV2 Methods
+char const* EmbLayerNormVarSeqlenPluginLegacyBase::getPluginType() const noexcept
+{
+ return kEMB_LAYER_NORM_VAR_SEQLEN_NAME;
+}
+
+char const* EmbLayerNormVarSeqlenPluginLegacyHFace::getPluginVersion() const noexcept
+{
+ return kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE;
+}
+
+char const* EmbLayerNormVarSeqlenPluginLegacyMTron::getPluginVersion() const noexcept
+{
+ return kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON;
+}
+
+int32_t EmbLayerNormVarSeqlenPluginLegacyBase::getNbOutputs() const noexcept
+{
+ return 2;
+}
+
+int32_t EmbLayerNormVarSeqlenPluginLegacyHFace::initialize() noexcept
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyHFace initialize");
+ return 0;
+}
+
+int32_t EmbLayerNormVarSeqlenPluginLegacyMTron::initialize() noexcept
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyMTron initialize");
+ return 0;
+}
+
+void EmbLayerNormVarSeqlenPluginLegacyHFace::terminate() noexcept
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyHFace terminate");
+}
+
+void EmbLayerNormVarSeqlenPluginLegacyMTron::terminate() noexcept
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyMTron terminate");
+}
+
+size_t EmbLayerNormVarSeqlenPluginLegacyBase::getSerializationSize() const noexcept
+{
+ size_t const wordSize = getElementSize(mType);
+ return 2 * sizeof(float) * mLd // beta + gamma
+ + sizeof(mType) //
+ + sizeof(mLd) //
+ + sizeof(mWordVocabSize) //
+ + sizeof(mPosVocabSize) //
+ + sizeof(mTokVocabSize) //
+ + wordSize * mLd * mWordVocabSize // word emb
+ + wordSize * mLd * mPosVocabSize // pos emb
+ + wordSize * mLd * mTokVocabSize // tok emb
+ + sizeof(mMaskType) // mask type
+ ;
+}
+
+void EmbLayerNormVarSeqlenPluginLegacyBase::serialize(void* buffer) const noexcept
+{
+ serialize_value(&buffer, mType);
+ serialize_value(&buffer, mLd);
+ serialize_value(&buffer, mWordVocabSize);
+ serialize_value(&buffer, mPosVocabSize);
+ serialize_value(&buffer, mTokVocabSize);
+ serialize_value(&buffer, mMaskType);
+
+ char* d = static_cast<char*>(buffer);
+ size_t const wordSize = getElementSize(mType);
+
+ serFromDev(d, mBetaDev.get(), mLd);
+ serFromDev(d, mGammaDev.get(), mLd);
+ serFromDev(d, static_cast<char const*>(mWordEmbDev.get()), mLd * mWordVocabSize * wordSize);
+ serFromDev(d, static_cast<char const*>(mPosEmbDev.get()), mLd * mPosVocabSize * wordSize);
+ serFromDev(d, static_cast<char const*>(mTokEmbDev.get()), mLd * mTokVocabSize * wordSize);
+}
+
+void EmbLayerNormVarSeqlenPluginLegacyBase::destroy() noexcept
+{
+ // This gets called when the network containing the plugin is destroyed
+ mGammaDev.reset(nullptr);
+ mBetaDev.reset(nullptr);
+ mWordEmbDev.reset(nullptr);
+ mPosEmbDev.reset(nullptr);
+ mTokEmbDev.reset(nullptr);
+ delete this;
+}
+
+void EmbLayerNormVarSeqlenPluginLegacyHFace::destroy() noexcept
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyHFace destroy");
+ EmbLayerNormVarSeqlenPluginLegacyBase::destroy();
+}
+
+void EmbLayerNormVarSeqlenPluginLegacyMTron::destroy() noexcept
+{
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenPluginLegacyMTron destroy");
+ EmbLayerNormVarSeqlenPluginLegacyBase::destroy();
+}
+
+void EmbLayerNormVarSeqlenPluginLegacyBase::setPluginNamespace(char const* libNamespace) noexcept
+{
+ try
+ {
+ mNamespace = libNamespace;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+}
+
+char const* EmbLayerNormVarSeqlenPluginLegacyBase::getPluginNamespace() const noexcept
+{
+ return mNamespace.c_str();
+}
+
+///////////////////////
+
+EmbLayerNormVarSeqlenPluginLegacyBaseCreator::EmbLayerNormVarSeqlenPluginLegacyBaseCreator()
+{
+ mPluginAttributes.clear();
+ mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_beta"));
+ mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_gamma"));
+ mPluginAttributes.emplace_back(PluginField("bert_embeddings_word_embeddings"));
+ mPluginAttributes.emplace_back(PluginField("bert_embeddings_token_type_embeddings"));
+ mPluginAttributes.emplace_back(PluginField("bert_embeddings_position_embeddings"));
+ mPluginAttributes.emplace_back(PluginField("output_fp16"));
+ mFC.nbFields = mPluginAttributes.size();
+ mFC.fields = mPluginAttributes.data();
+}
+
+char const* EmbLayerNormVarSeqlenPluginLegacyBaseCreator::getPluginName() const noexcept
+{
+ return kEMB_LAYER_NORM_VAR_SEQLEN_NAME;
+}
+
+char const* EmbLayerNormVarSeqlenPluginLegacyHFaceCreator::getPluginVersion() const noexcept
+{
+ return kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE;
+}
+
+char const* EmbLayerNormVarSeqlenPluginLegacyMTronCreator::getPluginVersion() const noexcept
+{
+ return kEMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON;
+}
+
+PluginFieldCollection const* EmbLayerNormVarSeqlenPluginLegacyBaseCreator::getFieldNames() noexcept
+{
+ return &mFC;
+}
+
+bool initializeFields(char const* name, PluginFieldCollection const* fc, Weights& beta, Weights& gamma,
+ Weights& word_emb, Weights& pos_emb, Weights& tok_emb)
+{
+ bool output_fp16 = false;
+ std::set<std::string> const requiredAttributes{
+ "bert_embeddings_layernorm_beta",
+ "bert_embeddings_layernorm_gamma",
+ "bert_embeddings_word_embeddings",
+ "bert_embeddings_token_type_embeddings",
+ "bert_embeddings_position_embeddings",
+ };
+ plugin::validateRequiredAttributesExist(requiredAttributes, fc);
+
+ for (int32_t i = 0; i < fc->nbFields; i++)
+ {
+ std::string field_name(fc->fields[i].name);
+ if (field_name.compare("bert_embeddings_layernorm_beta") == 0)
+ {
+ BERT_DEBUG_MSG("Building bert_embeddings_layernorm_beta...");
+ beta.values = fc->fields[i].data;
+ beta.count = fc->fields[i].length;
+ beta.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+
+ if (field_name.compare("bert_embeddings_layernorm_gamma") == 0)
+ {
+ BERT_DEBUG_MSG("Building bert_embeddings_layernorm_gamma...");
+ gamma.values = fc->fields[i].data;
+ gamma.count = fc->fields[i].length;
+ gamma.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+
+ if (field_name.compare("bert_embeddings_word_embeddings") == 0)
+ {
+ BERT_DEBUG_MSG("Building bert_embeddings_word_embeddings...");
+ word_emb.values = fc->fields[i].data;
+ word_emb.count = fc->fields[i].length;
+ word_emb.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+
+ if (field_name.compare("bert_embeddings_token_type_embeddings") == 0)
+ {
+ BERT_DEBUG_MSG("Building bert_embeddings_token_type_embeddings...");
+ tok_emb.values = fc->fields[i].data;
+ tok_emb.count = fc->fields[i].length;
+ tok_emb.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+
+ if (field_name.compare("bert_embeddings_position_embeddings") == 0)
+ {
+ BERT_DEBUG_MSG("Building bert_embeddings_position_embeddings...");
+ pos_emb.values = fc->fields[i].data;
+ pos_emb.count = fc->fields[i].length;
+ pos_emb.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+ if (field_name.compare("output_fp16") == 0)
+ {
+ BERT_DEBUG_MSG("Building output_fp16...");
+ PLUGIN_VALIDATE(fc->fields[i].type == PluginFieldType::kINT32);
+ output_fp16 = static_cast<int32_t const*>(fc->fields[i].data)[0] != 0;
+ }
+ }
+ return output_fp16;
+}
+
+IPluginV2* EmbLayerNormVarSeqlenPluginLegacyHFaceCreator::createPlugin(
+ char const* name, PluginFieldCollection const* fc) noexcept
+{
+ try
+ {
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenHFace createPlugin");
+
+ Weights beta{}; // required attribute: validateRequiredAttributesExist() call in initializeFields() will verify
+ // existence
+ Weights gamma{}; // required attribute: validateRequiredAttributesExist() call in initializeFields() will verify
+ // existence
+ Weights word_emb{}; // required attribute: validateRequiredAttributesExist() call in initializeFields() will
+ // verify existence
+ Weights pos_emb{}; // required attribute: validateRequiredAttributesExist() call in initializeFields() will
+ // verify existence
+ Weights tok_emb{}; // required attribute: validateRequiredAttributesExist() call in initializeFields() will
+ // verify existence
+ bool output_fp16 = initializeFields(name, fc, beta, gamma, word_emb, pos_emb, tok_emb);
+
+ BERT_DEBUG_MSG("Building the Plugin...");
+ EmbLayerNormVarSeqlenPluginLegacyHFace* p = new EmbLayerNormVarSeqlenPluginLegacyHFace(
+ name, output_fp16 ? DataType::kHALF : DataType::kFLOAT, beta, gamma, word_emb, pos_emb, tok_emb);
+ return p;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+IPluginV2* EmbLayerNormVarSeqlenPluginLegacyMTronCreator::createPlugin(
+ char const* name, PluginFieldCollection const* fc) noexcept
+{
+ try
+ {
+ BERT_DEBUG_MSG("EmbLayerNormVarSeqlenMTron createPlugin");
+
+ Weights beta{}; // required attribute: validateRequiredAttributesExist() call in initializeFields() will verify
+ // existence
+ Weights gamma{}; // required attribute: validateRequiredAttributesExist() call in initializeFields() will verify
+ // existence
+ Weights word_emb{}; // required attribute: validateRequiredAttributesExist() call in initializeFields() will
+ // verify existence
+ Weights pos_emb{}; // required attribute: validateRequiredAttributesExist() call in initializeFields() will
+ // verify existence
+ Weights tok_emb{}; // required attribute: validateRequiredAttributesExist() call in initializeFields() will
+ // verify existence
+ bool output_fp16 = initializeFields(name, fc, beta, gamma, word_emb, pos_emb, tok_emb);
+
+ BERT_DEBUG_MSG("Building the Plugin...");
+ EmbLayerNormVarSeqlenPluginLegacyMTron* p = new EmbLayerNormVarSeqlenPluginLegacyMTron(
+ name, output_fp16 ? DataType::kHALF : DataType::kFLOAT, beta, gamma, word_emb, pos_emb, tok_emb);
+ return p;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+IPluginV2* EmbLayerNormVarSeqlenPluginLegacyHFaceCreator::deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept
+{
+ try
+ {
+ // This object will be deleted when the network is destroyed, which will
+ // call EmbLayerNormVarSeqlenPluginLegacyHFace::destroy()
+ return new EmbLayerNormVarSeqlenPluginLegacyHFace(name, serialData, serialLength);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+IPluginV2* EmbLayerNormVarSeqlenPluginLegacyMTronCreator::deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept
+{
+ try
+ {
+ // This object will be deleted when the network is destroyed, which will
+ // call EmbLayerNormVarSeqlenPluginLegacyMTron::destroy()
+ return new EmbLayerNormVarSeqlenPluginLegacyMTron(name, serialData, serialLength);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+void EmbLayerNormVarSeqlenPluginLegacyBaseCreator::setPluginNamespace(char const* libNamespace) noexcept
+{
+ try
+ {
+ mNamespace = libNamespace;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+}
+
+char const* EmbLayerNormVarSeqlenPluginLegacyBaseCreator::getPluginNamespace() const noexcept
+{
+ return mNamespace.c_str();
+}
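The enqueue() implementations above rely on two conventions that are easy to miss: `cu_seqlens` carries `B+1` cumulative offsets into the packed token stream, so the batch size is its length minus one, and the batch's maximum sequence length is rounded up to one of the four kernel-optimized sizes {128, 192, 256, 384}. Below is a minimal host-side sketch of that bookkeeping with hypothetical standalone names; note the plugin itself reads the max sequence length from the dummy input's dimension rather than recomputing it from the offsets.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Round the batch's max sequence length up to the nearest kernel-optimized
// size, mirroring the S selection in the enqueue() implementations above.
int32_t pickKernelSeqlen(int32_t maxSeqlen)
{
    if (maxSeqlen <= 128) { return 128; }
    if (maxSeqlen <= 192) { return 192; }
    if (maxSeqlen <= 256) { return 256; }
    return 384;
}

int main()
{
    // Three sequences of lengths 5, 9, and 3 packed back to back:
    // cuSeqlens holds B+1 cumulative offsets into the packed token stream.
    std::vector<int32_t> cuSeqlens{0, 5, 14, 17};
    int32_t const batchSize = static_cast<int32_t>(cuSeqlens.size()) - 1;

    int32_t maxSeqlen = 0;
    for (int32_t b = 0; b < batchSize; ++b)
    {
        maxSeqlen = std::max(maxSeqlen, cuSeqlens[b + 1] - cuSeqlens[b]);
    }
    assert(batchSize == 3);
    assert(maxSeqlen == 9);
    assert(pickKernelSeqlen(maxSeqlen) == 128); // 9 rounds up to the smallest bucket
    return 0;
}
```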
diff --git a/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPluginLegacy.h b/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPluginLegacy.h
new file mode 100644
index 00000000..a42a2b87
--- /dev/null
+++ b/plugin/embLayerNormPlugin/embLayerNormVarSeqlenPluginLegacy.h
@@ -0,0 +1,198 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef TRT_EMB_LAYER_NORM_VARSEQ_PLUGIN_LEGACY_H
+#define TRT_EMB_LAYER_NORM_VARSEQ_PLUGIN_LEGACY_H
+
+#include <cuda.h>
+
+#include "NvInferPlugin.h"
+#include "NvInferRuntime.h"
+
+#include "common/bertCommon.h"
+#include <string>
+#include <vector>
+
+namespace nvinfer1
+{
+namespace plugin
+{
+namespace bert
+{
+
+template <typename T>
+int32_t embSkipLayerNormHFace(cudaStream_t stream, int32_t ld, int32_t B, int32_t S, int32_t const* inputIds,
+ int32_t const* tokenIds, int32_t const* cuSeqlens, float const* beta, float const* gamma, T const* wordEmb,
+ T const* posEmb, T const* tokEmb, int32_t const wordSize, int32_t const tokSize, T* output);
+
+template <typename T>
+int32_t embSkipLayerNormMTron(cudaStream_t stream, int32_t ld, int32_t B, int32_t S, int32_t const* inputIds,
+ int32_t const* tokenIds, int32_t const* cuSeqlens, float const* beta, float const* gamma, T const* wordEmb,
+ T const* posEmb, T const* tokEmb, int32_t const wordSize, int32_t const tokSize, T* output, T* skip);
+
+class EmbLayerNormVarSeqlenPluginLegacyBase : public nvinfer1::IPluginV2DynamicExt
+{
+public:
+ EmbLayerNormVarSeqlenPluginLegacyBase(std::string const& name, DataType type, Weights const& beta,
+ Weights const& gamma, Weights const& word_emb, Weights const& pos_emb, Weights const& tok_emb,
+ DataType maskType);
+
+ EmbLayerNormVarSeqlenPluginLegacyBase(std::string const& name, void const* data, size_t length);
+
+ // It doesn't make sense to construct EmbLayerNormVarSeqlenPluginLegacy without arguments, so we
+ // delete the default constructor.
+ EmbLayerNormVarSeqlenPluginLegacyBase() = delete;
+
+ // IPluginV2DynamicExt Methods
+ bool supportsFormatCombination(
+ int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
+ size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs,
+ nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
+
+ // IPluginV2Ext Methods
+ nvinfer1::DataType getOutputDataType(
+ int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override;
+
+ // IPluginV2 Methods
+ char const* getPluginType() const noexcept override;
+ int32_t getNbOutputs() const noexcept override;
+ size_t getSerializationSize() const noexcept override;
+ void serialize(void* buffer) const noexcept override;
+ void destroy() noexcept override;
+ char const* getPluginNamespace() const noexcept override;
+ void setPluginNamespace(char const* pluginNamespace) noexcept override;
+
+protected:
+ std::string const mLayerName;
+ std::string mNamespace;
+
+ bert::cuda_unique_ptr<float> mGammaDev;
+ bert::cuda_unique_ptr<float> mBetaDev;
+ bert::cuda_unique_ptr<void> mWordEmbDev;
+ bert::cuda_unique_ptr<void> mTokEmbDev;
+ bert::cuda_unique_ptr<void> mPosEmbDev;
+ size_t mLd; // leading dim = hidden size
+ size_t mWordVocabSize;
+ size_t mPosVocabSize;
+ size_t mTokVocabSize;
+ bert::WeightsWithOwnership mBeta;
+ bert::WeightsWithOwnership mGamma;
+ bert::WeightsWithOwnership mWordEmb;
+ bert::WeightsWithOwnership mTokEmb;
+ bert::WeightsWithOwnership mPosEmb;
+ DataType mType{};
+ DataType mMaskType{};
+};
+
+class EmbLayerNormVarSeqlenPluginLegacyHFace : public EmbLayerNormVarSeqlenPluginLegacyBase
+{
+public:
+ EmbLayerNormVarSeqlenPluginLegacyHFace(std::string const& name, nvinfer1::DataType const type,
+ nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma, nvinfer1::Weights const& word_emb,
+ nvinfer1::Weights const& pos_emb, nvinfer1::Weights const& tok_emb);
+
+ EmbLayerNormVarSeqlenPluginLegacyHFace(std::string const& name, void const* data, size_t length);
+
+ // It doesn't make sense to construct EmbLayerNormVarSeqlenPluginLegacy without arguments, so we
+ // delete the default constructor.
+ EmbLayerNormVarSeqlenPluginLegacyHFace() = delete;
+
+ // IPluginV2DynamicExt Methods
+ nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
+ nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs,
+ nvinfer1::IExprBuilder& exprBuilder) noexcept override;
+ void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
+ nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
+ int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
+
+ // IPluginV2 Methods
+ int32_t initialize() noexcept override;
+ void terminate() noexcept override;
+ void destroy() noexcept override;
+ char const* getPluginVersion() const noexcept override;
+};
+
+class EmbLayerNormVarSeqlenPluginLegacyMTron : public EmbLayerNormVarSeqlenPluginLegacyBase
+{
+public:
+ EmbLayerNormVarSeqlenPluginLegacyMTron(std::string const& name, nvinfer1::DataType const type,
+ nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma, nvinfer1::Weights const& word_emb,
+ nvinfer1::Weights const& pos_emb, nvinfer1::Weights const& tok_emb);
+
+ EmbLayerNormVarSeqlenPluginLegacyMTron(std::string const& name, void const* data, size_t length);
+
+ // It doesn't make sense to construct EmbLayerNormVarSeqlenPluginLegacy without arguments, so we
+ // delete the default constructor.
+ EmbLayerNormVarSeqlenPluginLegacyMTron() = delete;
+
+ // IPluginV2DynamicExt Methods
+ nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
+ nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs,
+ nvinfer1::IExprBuilder& exprBuilder) noexcept override;
+ void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
+ nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
+ int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
+
+ // IPluginV2 Methods
+ int32_t initialize() noexcept override;
+ void terminate() noexcept override;
+ void destroy() noexcept override;
+ char const* getPluginVersion() const noexcept override;
+};
+
+class EmbLayerNormVarSeqlenPluginLegacyBaseCreator : public nvinfer1::IPluginCreator
+{
+public:
+ EmbLayerNormVarSeqlenPluginLegacyBaseCreator();
+
+ char const* getPluginName() const noexcept override;
+
+ nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
+
+ void setPluginNamespace(char const* pluginNamespace) noexcept override;
+
+ char const* getPluginNamespace() const noexcept override;
+
+protected:
+ static nvinfer1::PluginFieldCollection mFC;
+ static std::vector<nvinfer1::PluginField> mPluginAttributes;
+ std::string mNamespace;
+};
+
+class EmbLayerNormVarSeqlenPluginLegacyHFaceCreator : public EmbLayerNormVarSeqlenPluginLegacyBaseCreator
+{
+public:
+ nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+ char const* getPluginVersion() const noexcept override;
+ nvinfer1::IPluginV2* deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept override;
+};
+
+class EmbLayerNormVarSeqlenPluginLegacyMTronCreator : public EmbLayerNormVarSeqlenPluginLegacyBaseCreator
+{
+public:
+ nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+ char const* getPluginVersion() const noexcept override;
+ nvinfer1::IPluginV2* deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept override;
+};
+
+} // namespace bert
+} // namespace plugin
+} // namespace nvinfer1
+#endif // TRT_EMB_LAYER_NORM_VARSEQ_PLUGIN_LEGACY_H
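Since the creators above are registered via REGISTER_TENSORRT_PLUGIN, the legacy plugin can be instantiated through the plugin registry. A minimal sketch, assuming the legacy HFace variant keeps version "2" (consistent with the 2->4 migration noted in the changelog) and using stand-in host weight buffers; it is not taken from this patch.

```cpp
#include "NvInferPlugin.h"

#include <array>
#include <cstdint>
#include <vector>

// Builds the legacy HFace embLayerNorm plugin through the registry.
// Assumes the TensorRT plugin library that registers these creators is loaded.
nvinfer1::IPluginV2* buildEmbLayerNormVarSeqlen(std::vector<float> const& beta, std::vector<float> const& gamma,
    std::vector<float> const& wordEmb, std::vector<float> const& tokEmb, std::vector<float> const& posEmb)
{
    using nvinfer1::PluginField;
    using nvinfer1::PluginFieldType;

    int32_t const outputFp16 = 0; // 0: FP32 embeddings, 1: FP16 embeddings
    std::array<PluginField, 6> fields{{
        {"bert_embeddings_layernorm_beta", beta.data(), PluginFieldType::kFLOAT32, static_cast<int32_t>(beta.size())},
        {"bert_embeddings_layernorm_gamma", gamma.data(), PluginFieldType::kFLOAT32, static_cast<int32_t>(gamma.size())},
        {"bert_embeddings_word_embeddings", wordEmb.data(), PluginFieldType::kFLOAT32, static_cast<int32_t>(wordEmb.size())},
        {"bert_embeddings_token_type_embeddings", tokEmb.data(), PluginFieldType::kFLOAT32, static_cast<int32_t>(tokEmb.size())},
        {"bert_embeddings_position_embeddings", posEmb.data(), PluginFieldType::kFLOAT32, static_cast<int32_t>(posEmb.size())},
        {"output_fp16", &outputFp16, PluginFieldType::kINT32, 1},
    }};
    nvinfer1::PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};

    auto* creator = getPluginRegistry()->getPluginCreator("CustomEmbLayerNormPluginDynamic", "2");
    return creator != nullptr ? creator->createPlugin("emb_ln", &fc) : nullptr;
}
```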
diff --git a/plugin/fcPlugin/fcPlugin.h b/plugin/fcPlugin/fcPlugin.h
index 855ce96d..458d4dc3 100644
--- a/plugin/fcPlugin/fcPlugin.h
+++ b/plugin/fcPlugin/fcPlugin.h
@@ -451,12 +451,18 @@ nvinfer1::pluginInternal::cublasLtMatmulAlgo_t gemmSearch(int32_t const m, int32
Gemm<T> g(m, n, k, false, false);
std::vector<customMatmulPerf_t> perfResults(kNB_ALGO_COMBINATIONS);
- PLUGIN_CUASSERT(cudaMallocAsync(reinterpret_cast<void**>(&g.A), g.bytesA, stream));
- PLUGIN_CUASSERT(cudaMallocAsync(reinterpret_cast<void**>(&g.B), g.bytesB, stream));
- PLUGIN_CUASSERT(cudaMallocAsync(reinterpret_cast<void**>(&g.C), g.bytesC, stream));
+ bool const useAsync = supportsMemPools();
+
+ PLUGIN_CUASSERT(useAsync ? cudaMallocAsync(reinterpret_cast<void**>(&g.A), g.bytesA, stream)
+ : cudaMalloc(reinterpret_cast<void**>(&g.A), g.bytesA));
+ PLUGIN_CUASSERT(useAsync ? cudaMallocAsync(reinterpret_cast<void**>(&g.B), g.bytesB, stream)
+ : cudaMalloc(reinterpret_cast<void**>(&g.B), g.bytesB));
+ PLUGIN_CUASSERT(useAsync ? cudaMallocAsync(reinterpret_cast<void**>(&g.C), g.bytesC, stream)
+ : cudaMalloc(reinterpret_cast<void**>(&g.C), g.bytesC));
void* workspace;
- PLUGIN_CUASSERT(cudaMallocAsync(&workspace, workspaceSize, stream));
+ PLUGIN_CUASSERT(
+ useAsync ? cudaMallocAsync(&workspace, workspaceSize, stream) : cudaMalloc(&workspace, workspaceSize));
nvinfer1::pluginInternal::cublasLtHandle_t lt;
nvinfer1::pluginInternal::CublasLtWrapper& cublasLtWrapper = nvinfer1::pluginInternal::getCublasLtWrapper();
PLUGIN_CUBLASASSERT(cublasLtWrapper.cublasLtCreate(<));
@@ -464,11 +470,11 @@ nvinfer1::pluginInternal::cublasLtMatmulAlgo_t gemmSearch(int32_t const m, int32
LtGemmSearch(lt, g, workspace, workspaceSize, perfResults, stream);
PLUGIN_CUASSERT(cudaStreamSynchronize(stream));
PLUGIN_CUBLASASSERT(cublasLtWrapper.cublasLtDestroy(lt));
- PLUGIN_CUASSERT(cudaFreeAsync(workspace, stream));
+ PLUGIN_CUASSERT(useAsync ? cudaFreeAsync(workspace, stream) : cudaFree(workspace));
- PLUGIN_CUASSERT(cudaFreeAsync(g.A, stream));
- PLUGIN_CUASSERT(cudaFreeAsync(g.B, stream));
- PLUGIN_CUASSERT(cudaFreeAsync(g.C, stream));
+ PLUGIN_CUASSERT(useAsync ? cudaFreeAsync(g.A, stream) : cudaFree(g.A));
+ PLUGIN_CUASSERT(useAsync ? cudaFreeAsync(g.B, stream) : cudaFree(g.B));
+ PLUGIN_CUASSERT(useAsync ? cudaFreeAsync(g.C, stream) : cudaFree(g.C));
actualWorkspace = perfResults[0].workspaceSize;
return perfResults[0].algo;
@@ -480,12 +486,18 @@ nvinfer1::pluginInternal::cublasLtMatmulAlgo_t gemmSearch(
{
std::vector<customMatmulPerf_t> perfResults(kNB_ALGO_COMBINATIONS);
- PLUGIN_CUASSERT(cudaMallocAsync(&g.A, g.bytesA, stream));
- PLUGIN_CUASSERT(cudaMallocAsync(&g.B, g.bytesB, stream));
- PLUGIN_CUASSERT(cudaMallocAsync(&g.C, g.bytesC, stream));
+ bool const useAsync = supportsMemPools();
+
+ PLUGIN_CUASSERT(useAsync ? cudaMallocAsync(reinterpret_cast<void**>(&g.A), g.bytesA, stream)
+ : cudaMalloc(reinterpret_cast<void**>(&g.A), g.bytesA));
+ PLUGIN_CUASSERT(useAsync ? cudaMallocAsync(reinterpret_cast<void**>(&g.B), g.bytesB, stream)
+ : cudaMalloc(reinterpret_cast<void**>(&g.B), g.bytesB));
+ PLUGIN_CUASSERT(useAsync ? cudaMallocAsync(reinterpret_cast<void**>(&g.C), g.bytesC, stream)
+ : cudaMalloc(reinterpret_cast<void**>(&g.C), g.bytesC));
void* workspace;
- PLUGIN_CUASSERT(cudaMallocAsync(&workspace, workspaceSize, stream));
+ PLUGIN_CUASSERT(
+ useAsync ? cudaMallocAsync(&workspace, workspaceSize, stream) : cudaMalloc(&workspace, workspaceSize));
nvinfer1::pluginInternal::cublasLtHandle_t lt;
nvinfer1::pluginInternal::CublasLtWrapper& cublasLtWrapper = nvinfer1::pluginInternal::getCublasLtWrapper();
PLUGIN_CUBLASASSERT(cublasLtWrapper.cublasLtCreate(<));
@@ -493,11 +505,11 @@ nvinfer1::pluginInternal::cublasLtMatmulAlgo_t gemmSearch(
LtGemmSearch(lt, g, workspace, workspaceSize, perfResults, stream);
PLUGIN_CUASSERT(cudaStreamSynchronize(stream));
PLUGIN_CUBLASASSERT(cublasLtWrapper.cublasLtDestroy(lt));
- PLUGIN_CUASSERT(cudaFreeAsync(workspace, stream));
+ PLUGIN_CUASSERT(useAsync ? cudaFreeAsync(workspace, stream) : cudaFree(workspace));
- PLUGIN_CUASSERT(cudaFreeAsync(g.A, stream));
- PLUGIN_CUASSERT(cudaFreeAsync(g.B, stream));
- PLUGIN_CUASSERT(cudaFreeAsync(g.C, stream));
+ PLUGIN_CUASSERT(useAsync ? cudaFreeAsync(g.A, stream) : cudaFree(g.A));
+ PLUGIN_CUASSERT(useAsync ? cudaFreeAsync(g.B, stream) : cudaFree(g.B));
+ PLUGIN_CUASSERT(useAsync ? cudaFreeAsync(g.C, stream) : cudaFree(g.C));
actualWorkspace = perfResults[0].workspaceSize;
return perfResults[0].algo;
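The change above falls back from stream-ordered allocation (cudaMallocAsync/cudaFreeAsync) to plain cudaMalloc/cudaFree on platforms whose driver or device lacks memory-pool support. A sketch of how such a capability check is commonly written; the repo's supportsMemPools() helper may be implemented differently.

```cpp
#include <cuda_runtime_api.h>

// Returns true when the current device supports stream-ordered allocation,
// i.e. cudaMallocAsync/cudaFreeAsync backed by memory pools.
bool deviceSupportsMemPools()
{
    int device{0};
    if (cudaGetDevice(&device) != cudaSuccess)
    {
        return false;
    }
    int value{0};
    if (cudaDeviceGetAttribute(&value, cudaDevAttrMemoryPoolsSupported, device) != cudaSuccess)
    {
        return false;
    }
    return value != 0;
}
```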
diff --git a/plugin/skipLayerNormPlugin/CustomSkipLayerNormPluginDynamic_PluginConfig.yaml b/plugin/skipLayerNormPlugin/CustomSkipLayerNormPluginDynamic_PluginConfig.yaml
index a39fd9bc..117fcbf1 100644
--- a/plugin/skipLayerNormPlugin/CustomSkipLayerNormPluginDynamic_PluginConfig.yaml
+++ b/plugin/skipLayerNormPlugin/CustomSkipLayerNormPluginDynamic_PluginConfig.yaml
@@ -16,9 +16,9 @@
#
---
name: CustomSkipLayerNormPluginDynamic
-interface: "IPluginV2DynamicExt"
+interface: "IPluginV3"
versions:
- "1":
+ "5": # SkipLayerNormPluginV3
inputs:
- input
- skip
@@ -115,13 +115,13 @@ versions:
attribute_options:
type_id:
value: 0
- ld:
+ ld:
value: 128
- beta:
+ beta:
shape: "1, 1, 128"
- gamma:
+ gamma:
shape: "1, 1, 128"
- bias:
+ bias:
shape: "1, 1, 128"
config2:
input_types:
@@ -130,12 +130,118 @@ versions:
attribute_options:
type_id:
value: 1
- ld:
+ ld:
value: 768
- beta:
+ beta:
shape: "1, 1, 768"
- gamma:
+ gamma:
shape: "1, 1, 768"
- bias:
+ bias:
+ shape: "1, 1, 768"
+ "6": # SkipLayerNormVarSeqlenPluginV3
+ inputs:
+ - input
+ - skip
+ outputs:
+ - output
+ input_dims:
+ input: 5
+ skip: 5
+ input_dim_constraints:
+ - "input_2 == bias_2"
+ - "skip_0 == input_0"
+ - "skip_1 == input_1"
+ - "skip_2 == input_2"
+ input_dim_range:
+ input:
+ min: "=1, =1, =1, =1, =1"
+ max: "=pinf, =pinf, =pinf, =1, =1"
+ skip:
+ min: "=1, =1, =1, =1, =1"
+ max: "=pinf, =pinf, =pinf, =1, =1"
+ supported_input_types:
+ - combination1:
+ input: float32
+ skip: float32
+ - combination2:
+ input: float16
+ skip: float16
+ output_dims:
+ output: "input_0, input_1, input_2, input_3, input_4"
+ attributes:
+ - type_id
+ - beta
+ - gamma
+ - bias
+ attribute_types:
+ type_id: int32
+ beta: float32
+ gamma: float32
+ bias: float32
+ attribute_dims:
+ type_id: 1
+ beta: 3
+ gamma: 3
+ bias: 3
+ attribute_dim_range:
+ type_id:
+ min: "=1"
+ max: "=1"
+ beta:
+ min: "=1, =1, =1"
+ max: "=1, =1, =pinf"
+ gamma:
+ min: "=1, =1, =1"
+ max: "=1, =1, =pinf"
+ bias:
+ min: "=1, =1, =1"
+ max: "=1, =1, =pinf"
+ attribute_options:
+ type_id:
+ - 0
+ - 1
+ - 2
+ beta:
+ min: "=ninf"
+ max: "=pinf"
+ gamma:
+ min: "=ninf"
+ max: "=pinf"
+ bias:
+ min: "=ninf"
+ max: "=pinf"
+ attributes_required:
+ - type_id
+ - beta
+ - gamma
+ golden_reference_script: "plugin/skipLayerNormPlugin/CustomSkipLayerNormPluginDynamic_PluginReference.py"
+ abs_tol: 1e-2
+ rel_tol: 1e-2
+ configs:
+ config1:
+ input_types:
+ input: float32
+ skip: float32
+ attribute_options:
+ type_id:
+ value: 0
+ beta:
+ shape: "1, 1, 128"
+ gamma:
+ shape: "1, 1, 128"
+ bias:
+ shape: "1, 1, 128"
+ config2:
+ input_types:
+ input: float16
+ skip: float16
+ attribute_options:
+ type_id:
+ value: 1
+ beta:
+ shape: "1, 1, 768"
+ gamma:
+ shape: "1, 1, 768"
+ bias:
shape: "1, 1, 768"
...
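For reference, the computation these configs exercise is a residual (skip) addition followed by layer normalization over the hidden dimension E, scaled by gamma and shifted by beta, with an optional bias added before normalization. A minimal CPU sketch of that golden reference, with an illustrative epsilon; the plugin's own epsilon and precision handling may differ.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference skip-layer-norm: y = gamma * normalize(input + skip [+ bias]) + beta,
// normalized independently over each hidden vector of size E.
std::vector<float> skipLayerNormRef(std::vector<float> const& input, std::vector<float> const& skip,
    std::vector<float> const& beta, std::vector<float> const& gamma, std::vector<float> const* bias, size_t E)
{
    std::vector<float> out(input.size());
    for (size_t base = 0; base < input.size(); base += E)
    {
        float mean = 0.F;
        for (size_t i = 0; i < E; ++i)
        {
            float const v = input[base + i] + skip[base + i] + (bias != nullptr ? (*bias)[i] : 0.F);
            out[base + i] = v;
            mean += v;
        }
        mean /= static_cast<float>(E);

        float var = 0.F;
        for (size_t i = 0; i < E; ++i)
        {
            float const d = out[base + i] - mean;
            var += d * d;
        }
        var /= static_cast<float>(E);

        float const rstd = 1.F / std::sqrt(var + 1e-5F); // epsilon avoids divide by zero
        for (size_t i = 0; i < E; ++i)
        {
            out[base + i] = gamma[i] * (out[base + i] - mean) * rstd + beta[i];
        }
    }
    return out;
}
```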
diff --git a/plugin/skipLayerNormPlugin/README.md b/plugin/skipLayerNormPlugin/README.md
index 5b5ffbda..e80a846b 100644
--- a/plugin/skipLayerNormPlugin/README.md
+++ b/plugin/skipLayerNormPlugin/README.md
@@ -21,7 +21,7 @@ Optionally, adds a bias vector before layer-normalization.
The `skipLayerNormPlugin` takes two inputs; `input` and `skip`.
`input`
-For V1 and V2, input is a tensor with shape `[S, B, E, 1, 1]` where `S` is the sequence length, `B` is the batch size, `E` is the hidden size, and the last two dimensions are of size 1.
+For V1, V2, V5, V6, input is a tensor with shape `[S, B, E, 1, 1]` where `S` is the sequence length, `B` is the batch size, `E` is the hidden size, and the last two dimensions are of size 1.
For V3 and V4, input is a tensor with shape `[1, E, S', 1]` where `S'` is the accumulated sequence length, `E` is the hidden size, and the first and last dimensions are of size 1.
`skip`
@@ -41,13 +41,13 @@ output is a tensor with the same shape as the input.
The parameters are defined below and consist of the following attributes:
-| Type | Parameter | Version | Description
-|----------|-----------------------------------------|------------|-------------------------------------------------------------------
-|`int` |`type_id` | 1, 2 |Integer encoding the DataType (0: FP32, 1: FP16, 2: INT8)
-|`int` |`ld` | 1 |The leading dimension of the input tensor, corresponding to the hidden size, denoted by `E` above.
-|`Weights` |`beta` | 1, 2, 3, 4|The mean to normalize to. Shape: `[1, 1, E]`
-|`Weights` |`gamma` | 1, 2, 3, 4|The standard deviation to normalize to. Shape: `[1, 1, E]`
-|`Weights` |`bias` | 1, 2 |An optional bias vector to add before normalization. Shape: `[1, 1, E]`
+| Type | Parameter | Version | Description
+|----------|-----------------------------------------|-------------------------|-------------------------------------------------------------------
+|`int` |`type_id` | 1, 2, 5, 6 |Integer encoding the DataType (0: FP32, 1: FP16, 2: INT8)
+|`int` |`ld` | 1, 5 |The leading dimension of the input tensor, corresponding to the hidden size, denoted by `E` above.
+|`Weights` |`beta` | 1, 2, 3, 4, 5, 6, 7, 8 |The mean to normalize to. Shape: `[1, 1, E]`
+|`Weights` |`gamma` | 1, 2, 3, 4, 5, 6, 7, 8 |The standard deviation to normalize to. Shape: `[1, 1, E]`
+|`Weights` |`bias` | 1, 2, 5, 6 |An optional bias vector to add before normalization. Shape: `[1, 1, E]`
## Additional resources
@@ -63,6 +63,9 @@ documentation.
## Changelog
+July 2024
+Add v5, v6, v7 and v8 plugins that duplicate the behavior of the v1, v2, v3 and v4 plugins respectively, but implement the `IPluginV3` interface instead of the deprecated `IPluginV2DynamicExt` interface.
+
February 2024
Add epsilon to avoid divide by zero.
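A small helper illustrating the type_id encoding from the parameter table above (0: FP32, 1: FP16, 2: INT8); the mapping function is illustrative and not part of the plugin.

```cpp
#include "NvInfer.h"

#include <cstdint>
#include <stdexcept>

// Decode the plugin's type_id attribute into a TensorRT DataType.
nvinfer1::DataType typeIdToDataType(int32_t typeId)
{
    switch (typeId)
    {
    case 0: return nvinfer1::DataType::kFLOAT;
    case 1: return nvinfer1::DataType::kHALF;
    case 2: return nvinfer1::DataType::kINT8;
    default: throw std::invalid_argument("type_id must be 0 (FP32), 1 (FP16), or 2 (INT8)");
    }
}
```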
diff --git a/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPlugin.cpp b/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPlugin.cpp
index 1b74f944..266ba76c 100644
--- a/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPlugin.cpp
+++ b/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPlugin.cpp
@@ -18,6 +18,7 @@
#include "skipLayerNormInt8InterleavedPlugin.h"
#include "NvInfer.h"
#include "common/serialize.hpp"
+
#include <cstring>
#include <vector>
@@ -30,9 +31,59 @@ using namespace nvinfer1::plugin::bert;
// Clip plugin specific constants
namespace
{
-char const* kSKIP_LAYER_NORM_INTERLEAVED_VERSION_HFACE{"3"};
-char const* kSKIP_LAYER_NORM_INTERLEAVED_VERSION_MTRON{"4"};
-char const* kSKIP_LAYER_NORM_INTERLEAVED_NAME{"CustomSkipLayerNormPluginDynamic"};
+constexpr char const* kSKIP_LAYER_NORM_INTERLEAVED_VERSION_HFACE{"7"};
+constexpr char const* kSKIP_LAYER_NORM_INTERLEAVED_VERSION_MTRON{"8"};
+constexpr char const* kSKIP_LAYER_NORM_INTERLEAVED_NAME{"CustomSkipLayerNormPluginDynamic"};
+
+void checkDescs(PluginTensorDesc const& iDesc, PluginTensorDesc const& sDesc, PluginTensorDesc const& oDesc)
+{
+ PLUGIN_VALIDATE(iDesc.dims.nbDims == 4);
+ PLUGIN_VALIDATE(iDesc.dims.nbDims == sDesc.dims.nbDims);
+ PLUGIN_VALIDATE(std::equal(iDesc.dims.d, iDesc.dims.d + iDesc.dims.nbDims, sDesc.dims.d));
+ PLUGIN_VALIDATE(std::equal(iDesc.dims.d, iDesc.dims.d + iDesc.dims.nbDims, oDesc.dims.d));
+ PLUGIN_VALIDATE(iDesc.dims.d[0] == 1);
+ PLUGIN_VALIDATE(iDesc.dims.d[3] == 1);
+ PLUGIN_VALIDATE(iDesc.format == TensorFormat::kCHW32);
+ PLUGIN_VALIDATE(iDesc.type == DataType::kINT8);
+ PLUGIN_VALIDATE(iDesc.format == sDesc.format);
+ PLUGIN_VALIDATE(iDesc.format == oDesc.format);
+ PLUGIN_VALIDATE(iDesc.type == sDesc.type);
+ PLUGIN_VALIDATE(iDesc.type == oDesc.type);
+}
+
+void buildBetaAndGamma(PluginFieldCollection const* fc, Weights& beta, Weights& gamma)
+{
+ PLUGIN_VALIDATE(fc != nullptr, "SkipLayerNorm: Plugin Field collection is null");
+ PLUGIN_VALIDATE(fc->fields != nullptr, "SkipLayerNorm: Plugin Fields are null");
+ plugin::validateRequiredAttributesExist({"beta", "gamma"}, fc);
+
+ for (int32_t i = 0; i < fc->nbFields; i++)
+ {
+ std::string fieldName(fc->fields[i].name);
+
+ if (fieldName.compare("beta") == 0)
+ {
+ BERT_DEBUG_MSG("Building beta...");
+ beta.values = fc->fields[i].data;
+ beta.count = fc->fields[i].length;
+ beta.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+
+ if (fieldName.compare("gamma") == 0)
+ {
+ BERT_DEBUG_MSG("Building gamma...");
+ gamma.values = fc->fields[i].data;
+ gamma.count = fc->fields[i].length;
+ gamma.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+ }
+
+ PLUGIN_VALIDATE(beta.values != nullptr, "SkipLayerNorm: invalid beta");
+ PLUGIN_VALIDATE(beta.count > 0, "SkipLayerNorm: invalid beta");
+
+ PLUGIN_VALIDATE(gamma.values != nullptr, "SkipLayerNorm: invalid gamma");
+ PLUGIN_VALIDATE(gamma.count > 0, "SkipLayerNorm: invalid gamma");
+}
} // namespace
// Static class fields initialization
@@ -42,17 +93,7 @@ std::vector<PluginField> SkipLayerNormInterleavedPluginBaseCreator::mPluginAttri
REGISTER_TENSORRT_PLUGIN(SkipLayerNormInterleavedPluginHFaceCreator);
REGISTER_TENSORRT_PLUGIN(SkipLayerNormInterleavedPluginMTronCreator);
-constexpr auto param_type = DataType::kHALF;
-
-static inline DataType getParamWordType(DataType cfgType)
-{
- if (cfgType == DataType::kINT8)
- {
- return DataType::kHALF;
- }
-
- return cfgType;
-}
+constexpr auto kPARAM_TYPE = DataType::kHALF;
SkipLayerNormInterleavedPluginBase::SkipLayerNormInterleavedPluginBase(
std::string const& name, Weights const& beta, Weights const& gamma)
@@ -66,10 +107,10 @@ SkipLayerNormInterleavedPluginBase::SkipLayerNormInterleavedPluginBase(
PLUGIN_VALIDATE(beta.count == gamma.count);
// dataType for beta, gamma weights is always fp16
- mParamWordsize = getElementSize(param_type);
+ mParamWordsize = getElementSize(kPARAM_TYPE);
- mBeta.convertAndCopy(beta, param_type);
- mGamma.convertAndCopy(gamma, param_type);
+ mBeta.convertAndCopy(beta, kPARAM_TYPE);
+ mGamma.convertAndCopy(gamma, kPARAM_TYPE);
}
SkipLayerNormInterleavedPluginHFace::SkipLayerNormInterleavedPluginHFace(
@@ -84,48 +125,48 @@ SkipLayerNormInterleavedPluginMTron::SkipLayerNormInterleavedPluginMTron(
{
}
-SkipLayerNormInterleavedPluginBase::SkipLayerNormInterleavedPluginBase(
- std::string const& name, void const* data, size_t length)
- : mLayerName(name)
- , mGammaDev(nullptr)
- , mBetaDev(nullptr)
- , mParamsOnDevice(false)
+SkipLayerNormInterleavedPluginBase::~SkipLayerNormInterleavedPluginBase()
{
- // Deserialize in the same order as serialization
- deserialize_value(&data, &length, &mLd);
-
- mParamWordsize = getElementSize(param_type);
-
- char const* d = static_cast<char const*>(data);
- mBeta.convertAndCopy(d, mLd, param_type);
- mGamma.convertAndCopy(d, mLd, param_type);
+ try
+ {
+ mGammaDev.reset(nullptr);
+ mBetaDev.reset(nullptr);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
}
-SkipLayerNormInterleavedPluginHFace::SkipLayerNormInterleavedPluginHFace(
- std::string const& name, void const* data, size_t length)
- : SkipLayerNormInterleavedPluginBase(name, data, length)
+SkipLayerNormInterleavedPluginHFace::~SkipLayerNormInterleavedPluginHFace()
{
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFace deserialize");
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFace destructor");
}
-SkipLayerNormInterleavedPluginMTron::SkipLayerNormInterleavedPluginMTron(
- std::string const& name, void const* data, size_t length)
- : SkipLayerNormInterleavedPluginBase(name, data, length)
+SkipLayerNormInterleavedPluginMTron::~SkipLayerNormInterleavedPluginMTron()
{
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTron deserialize");
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTron destructor");
}
-// IPluginV2DynamicExt Methods
-IPluginV2DynamicExt* SkipLayerNormInterleavedPluginHFace::clone() const noexcept
+//////
+// IPluginV3 method definitions:
+// - getCapabilityInterface() (Base)
+// - clone() (HFace, MTron)
+//////
+IPluginCapability* SkipLayerNormInterleavedPluginBase::getCapabilityInterface(PluginCapabilityType type) noexcept
{
try
{
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFace clone");
-
- auto* p = new SkipLayerNormInterleavedPluginHFace(mLayerName, mBeta, mGamma);
- p->initialize();
- p->setPluginNamespace(mNamespace.c_str());
- return p;
+ if (type == PluginCapabilityType::kBUILD)
+ {
+ return static_cast<IPluginV3OneBuild*>(this);
+ }
+ if (type == PluginCapabilityType::kRUNTIME)
+ {
+ return static_cast<IPluginV3OneRuntime*>(this);
+ }
+ PLUGIN_ASSERT(type == PluginCapabilityType::kCORE);
+ return static_cast<IPluginV3OneCore*>(this);
}
catch (std::exception const& e)
{
@@ -134,14 +175,13 @@ IPluginV2DynamicExt* SkipLayerNormInterleavedPluginHFace::clone() const noexcept
return nullptr;
}
-IPluginV2DynamicExt* SkipLayerNormInterleavedPluginMTron::clone() const noexcept
+IPluginV3* SkipLayerNormInterleavedPluginHFace::clone() noexcept
{
try
{
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTron clone");
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFace clone");
- auto* p = new SkipLayerNormInterleavedPluginMTron(mLayerName, mBeta, mGamma);
- p->initialize();
+ auto* p = new SkipLayerNormInterleavedPluginHFace(mLayerName, mBeta, mGamma);
p->setPluginNamespace(mNamespace.c_str());
return p;
}
@@ -152,46 +192,48 @@ IPluginV2DynamicExt* SkipLayerNormInterleavedPluginMTron::clone() const noexcept
return nullptr;
}
-DimsExprs SkipLayerNormInterleavedPluginBase::getOutputDimensions(
- int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept
+IPluginV3* SkipLayerNormInterleavedPluginMTron::clone() noexcept
{
try
{
- PLUGIN_VALIDATE(inputs != nullptr);
- PLUGIN_VALIDATE(nbInputs == 2);
- PLUGIN_VALIDATE(outputIndex >= 0 && outputIndex < getNbOutputs());
- PLUGIN_VALIDATE(inputs[0].nbDims == inputs[1].nbDims);
- return inputs[0];
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTron clone");
+
+ auto* p = new SkipLayerNormInterleavedPluginMTron(mLayerName, mBeta, mGamma);
+ p->setPluginNamespace(mNamespace.c_str());
+ return p;
}
catch (std::exception const& e)
{
caughtError(e);
}
- return DimsExprs{};
+ return nullptr;
}
-bool SkipLayerNormInterleavedPluginBase::supportsFormatCombination(
- int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
-{
- try
- {
- PLUGIN_VALIDATE(inOut != nullptr);
- PLUGIN_VALIDATE(nbInputs == 2);
- PLUGIN_VALIDATE(nbOutputs == getNbOutputs());
- PLUGIN_VALIDATE(pos >= 0 && pos < (nbInputs + nbOutputs));
+// End IPluginV3 method definitions
- PluginTensorDesc const& desc = inOut[pos];
- return desc.type == DataType::kINT8 && desc.format == TensorFormat::kCHW32;
- }
- catch (std::exception const& e)
- {
- caughtError(e);
- }
- return false;
+//////
+// IPluginV3OneRuntime method definitions:
+// - getFieldsToSerialize() (Base)
+// - onShapeChange() (Base)
+// - attachToContext() (HFace, MTron)
+// - execute() (HFace, MTron)
+/////
+PluginFieldCollection const* SkipLayerNormInterleavedPluginBase::getFieldsToSerialize() noexcept
+{
+ mDataToSerialize.clear();
+ mDataToSerialize.emplace_back(
+ "beta", static_cast(mBeta.values), PluginFieldType::kFLOAT16, mBeta.count);
+ PLUGIN_ASSERT(mBeta.type == kPARAM_TYPE);
+ mDataToSerialize.emplace_back(
+ "gamma", static_cast(mGamma.values), PluginFieldType::kFLOAT16, mGamma.count);
+ PLUGIN_ASSERT(mGamma.type == kPARAM_TYPE);
+ mFCToSerialize.nbFields = mDataToSerialize.size();
+ mFCToSerialize.fields = mDataToSerialize.data();
+ return &mFCToSerialize;
}
-void SkipLayerNormInterleavedPluginBase::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
- DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
+int32_t SkipLayerNormInterleavedPluginBase::onShapeChange(
+ PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
{
try
{
@@ -200,51 +242,36 @@ void SkipLayerNormInterleavedPluginBase::configurePlugin(DynamicPluginTensorDesc
PLUGIN_VALIDATE(outputs != nullptr);
PLUGIN_VALIDATE(nbOutputs == getNbOutputs());
PLUGIN_VALIDATE(nbInputs == 2);
- PLUGIN_VALIDATE(DataType::kINT8 == inputs[0].desc.type);
- PLUGIN_VALIDATE(DataType::kINT8 == inputs[1].desc.type);
+ PLUGIN_VALIDATE(DataType::kINT8 == inputs[0].type);
+ PLUGIN_VALIDATE(DataType::kINT8 == inputs[1].type);
- auto const& inDims0 = inputs[0].desc.dims;
- auto const& inDims1 = inputs[1].desc.dims;
+ auto const& inDims0 = inputs[0].dims;
+ auto const& inDims1 = inputs[1].dims;
TRT_UNUSED inDims1;
PLUGIN_VALIDATE(inDims0.nbDims == inDims1.nbDims);
PLUGIN_VALIDATE(std::equal(inDims0.d, inDims0.d + inDims0.nbDims, inDims1.d));
- mParamWordsize = getElementSize(param_type);
+ mParamWordsize = getElementSize(kPARAM_TYPE);
if (!mParamsOnDevice)
{
- copyToDevice(mGamma, getWeightsSize(mGamma, param_type), mGammaDev);
- copyToDevice(mBeta, getWeightsSize(mBeta, param_type), mBetaDev);
+ copyToDevice(mGamma, getWeightsSize(mGamma, kPARAM_TYPE), mGammaDev);
+ copyToDevice(mBeta, getWeightsSize(mBeta, kPARAM_TYPE), mBetaDev);
mParamsOnDevice = true;
}
+ return pluginStatus_t::STATUS_SUCCESS;
}
catch (std::exception const& e)
{
caughtError(e);
}
+ return pluginStatus_t::STATUS_FAILURE;
}
-size_t SkipLayerNormInterleavedPluginBase::getWorkspaceSize(
- PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
+IPluginV3* SkipLayerNormInterleavedPluginBase::attachToContext(IPluginResourceContext* context) noexcept
{
- return 0;
-}
-
-void checkDescs(PluginTensorDesc const& iDesc, PluginTensorDesc const& sDesc, PluginTensorDesc const& oDesc)
-{
- PLUGIN_VALIDATE(iDesc.dims.nbDims == 4);
- PLUGIN_VALIDATE(iDesc.dims.nbDims == sDesc.dims.nbDims);
- PLUGIN_VALIDATE(std::equal(iDesc.dims.d, iDesc.dims.d + iDesc.dims.nbDims, sDesc.dims.d));
- PLUGIN_VALIDATE(std::equal(iDesc.dims.d, iDesc.dims.d + iDesc.dims.nbDims, oDesc.dims.d));
- PLUGIN_VALIDATE(iDesc.dims.d[0] == 1);
- PLUGIN_VALIDATE(iDesc.dims.d[3] == 1);
- PLUGIN_VALIDATE(iDesc.format == TensorFormat::kCHW32);
- PLUGIN_VALIDATE(iDesc.type == DataType::kINT8);
- PLUGIN_VALIDATE(iDesc.format == sDesc.format);
- PLUGIN_VALIDATE(iDesc.format == oDesc.format);
- PLUGIN_VALIDATE(iDesc.type == sDesc.type);
- PLUGIN_VALIDATE(iDesc.type == oDesc.type);
+ return clone();
}
int32_t SkipLayerNormInterleavedPluginHFace::enqueue(PluginTensorDesc const* inputDesc,
@@ -331,41 +358,17 @@ int32_t SkipLayerNormInterleavedPluginMTron::enqueue(PluginTensorDesc const* inp
}
return -1;
}
-
-// IPluginV2Ext Methods
-DataType SkipLayerNormInterleavedPluginBase::getOutputDataType(
- int32_t index, DataType const* inputTypes, int32_t nbInputs) const noexcept
-{
- try
- {
- PLUGIN_VALIDATE(inputTypes != nullptr);
- PLUGIN_VALIDATE(index >= 0 && index < getNbOutputs());
- PLUGIN_VALIDATE(nbInputs == 2);
- return inputTypes[0];
- }
- catch (std::exception const& e)
- {
- caughtError(e);
- }
- return DataType{};
-}
-
-// IPluginV2 Methods
-char const* SkipLayerNormInterleavedPluginBase::getPluginType() const noexcept
-{
- return kSKIP_LAYER_NORM_INTERLEAVED_NAME;
-}
-
-char const* SkipLayerNormInterleavedPluginHFace::getPluginVersion() const noexcept
-{
- return kSKIP_LAYER_NORM_INTERLEAVED_VERSION_HFACE;
-}
-
-char const* SkipLayerNormInterleavedPluginMTron::getPluginVersion() const noexcept
-{
- return kSKIP_LAYER_NORM_INTERLEAVED_VERSION_MTRON;
-}
-
+// end IPluginV3OneRuntime method definitions
+
+///////
+// IPluginV3OneBuild method definitions
+// - getNbOutputs() (MTron, HFace)
+// - supportsFormatCombination() (Base)
+// - getOutputShapes() (Base)
+// - getOutputDataTypes() (Base)
+// - configurePlugin() (Base)
+// - getWorkspaceSize() (Base)
+//////
int32_t SkipLayerNormInterleavedPluginHFace::getNbOutputs() const noexcept
{
return 1;
@@ -376,79 +379,102 @@ int32_t SkipLayerNormInterleavedPluginMTron::getNbOutputs() const noexcept
return 2;
}
-int32_t SkipLayerNormInterleavedPluginHFace::initialize() noexcept
-{
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFace initialize");
- return 0;
-}
-
-int32_t SkipLayerNormInterleavedPluginMTron::initialize() noexcept
-{
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTron initialize");
- return 0;
-}
-
-void SkipLayerNormInterleavedPluginHFace::terminate() noexcept
-{
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFace terminate");
-}
-
-void SkipLayerNormInterleavedPluginMTron::terminate() noexcept
-{
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTron terminate");
-}
-
-size_t SkipLayerNormInterleavedPluginBase::getSerializationSize() const noexcept
+bool SkipLayerNormInterleavedPluginBase::supportsFormatCombination(
+ int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
- return 2 * mParamWordsize * mLd + sizeof(mLd);
+ try
+ {
+ PLUGIN_VALIDATE(inOut != nullptr);
+ PLUGIN_VALIDATE(nbInputs == 2);
+ PLUGIN_VALIDATE(nbOutputs == getNbOutputs());
+ PLUGIN_VALIDATE(pos >= 0 && pos < (nbInputs + nbOutputs));
+ PluginTensorDesc const& desc = inOut[pos].desc;
+ return desc.type == DataType::kINT8 && desc.format == TensorFormat::kCHW32;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return false;
}
-void SkipLayerNormInterleavedPluginBase::serialize(void* buffer) const noexcept
+int32_t SkipLayerNormInterleavedPluginBase::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs,
+ DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs,
+ IExprBuilder& exprBuilder) noexcept
{
try
{
- serialize_value(&buffer, mLd);
-
- char* d = static_cast<char*>(buffer);
- serFromDev(d, static_cast<char const*>(mBetaDev.get()), mLd * mParamWordsize);
- serFromDev(d, static_cast<char const*>(mGammaDev.get()), mLd * mParamWordsize);
+ PLUGIN_VALIDATE(inputs != nullptr);
+ PLUGIN_VALIDATE(nbInputs == 2);
+ PLUGIN_VALIDATE(nbOutputs == getNbOutputs());
+ PLUGIN_VALIDATE(inputs[0].nbDims == inputs[1].nbDims);
+ for (int32_t i = 0; i < nbOutputs; ++i)
+ {
+ outputs[i] = inputs[0];
+ }
+ return pluginStatus_t::STATUS_SUCCESS;
}
catch (std::exception const& e)
{
caughtError(e);
}
+ return pluginStatus_t::STATUS_FAILURE;
}
-void SkipLayerNormInterleavedPluginBase::destroy() noexcept
+int32_t SkipLayerNormInterleavedPluginBase::getOutputDataTypes(
+ DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept
{
try
{
- // This gets called when the network containing plugin is destroyed
- mGammaDev.reset(nullptr);
- mBetaDev.reset(nullptr);
- delete this;
+ PLUGIN_VALIDATE(inputTypes != nullptr);
+ PLUGIN_VALIDATE(nbOutputs == getNbOutputs());
+ PLUGIN_VALIDATE(nbInputs == 2);
+ for (int32_t i = 0; i < nbOutputs; ++i)
+ {
+ outputTypes[i] = inputTypes[0];
+ }
+ return pluginStatus_t::STATUS_SUCCESS;
}
catch (std::exception const& e)
{
caughtError(e);
}
+ return pluginStatus_t::STATUS_FAILURE;
}
-void SkipLayerNormInterleavedPluginHFace::destroy() noexcept
+int32_t SkipLayerNormInterleavedPluginBase::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
{
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFace destroy");
- SkipLayerNormInterleavedPluginBase::destroy();
+ return pluginStatus_t::STATUS_SUCCESS;
}
-void SkipLayerNormInterleavedPluginMTron::destroy() noexcept
+size_t SkipLayerNormInterleavedPluginBase::getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
{
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTron destroy");
- SkipLayerNormInterleavedPluginBase::destroy();
+ return 0;
}
+// End IPluginV3OneBuild method definitions
-void SkipLayerNormInterleavedPluginBase::setPluginNamespace(char const* libNamespace) noexcept
+//////
+// IPluginV3OneCore method definitions
+// - getPluginVersion() (MTron, HFace)
+// - getPluginName() (Base)
+// - getPluginNamespace() (Base)
+// - setPluginNamespace() (Base)
+//////
+char const* SkipLayerNormInterleavedPluginHFace::getPluginVersion() const noexcept
{
- mNamespace = libNamespace;
+ return kSKIP_LAYER_NORM_INTERLEAVED_VERSION_HFACE;
+}
+
+char const* SkipLayerNormInterleavedPluginMTron::getPluginVersion() const noexcept
+{
+ return kSKIP_LAYER_NORM_INTERLEAVED_VERSION_MTRON;
+}
+
+char const* SkipLayerNormInterleavedPluginBase::getPluginName() const noexcept
+{
+ return kSKIP_LAYER_NORM_INTERLEAVED_NAME;
}
char const* SkipLayerNormInterleavedPluginBase::getPluginNamespace() const noexcept
@@ -456,10 +482,18 @@ char const* SkipLayerNormInterleavedPluginBase::getPluginNamespace() const noexc
return mNamespace.c_str();
}
-/////////////////////////////////////////////////////////
+void SkipLayerNormInterleavedPluginBase::setPluginNamespace(char const* libNamespace) noexcept
+{
+ mNamespace = libNamespace;
+}
+// End IPluginV3OneCore method definitions
+
+//////////////////////////// Plugin Creator member definitions /////////////////////////////
SkipLayerNormInterleavedPluginBaseCreator::SkipLayerNormInterleavedPluginBaseCreator()
{
+ static std::mutex sMutex;
+ std::lock_guard<std::mutex> lock(sMutex);
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("beta"));
mPluginAttributes.emplace_back(PluginField("gamma"));
@@ -497,40 +531,8 @@ PluginFieldCollection const* SkipLayerNormInterleavedPluginBaseCreator::getField
return &mFC;
}
-void buildBetaAndGamma(PluginFieldCollection const* fc, Weights& beta, Weights& gamma)
-{
- plugin::validateRequiredAttributesExist({"beta", "gamma"}, fc);
-
- for (int32_t i = 0; i < fc->nbFields; i++)
- {
- std::string field_name(fc->fields[i].name);
-
- if (field_name.compare("beta") == 0)
- {
- BERT_DEBUG_MSG("Building beta...");
- beta.values = fc->fields[i].data;
- beta.count = fc->fields[i].length;
- beta.type = fieldTypeToDataType(fc->fields[i].type);
- }
-
- if (field_name.compare("gamma") == 0)
- {
- BERT_DEBUG_MSG("Building gamma...");
- gamma.values = fc->fields[i].data;
- gamma.count = fc->fields[i].length;
- gamma.type = fieldTypeToDataType(fc->fields[i].type);
- }
- }
-
- PLUGIN_VALIDATE(beta.values != nullptr, "SkipLayerNorm: invalid beta");
- PLUGIN_VALIDATE(beta.count > 0, "SkipLayerNorm: invalid beta");
-
- PLUGIN_VALIDATE(gamma.values != nullptr, "SkipLayerNorm: invalid gamma");
- PLUGIN_VALIDATE(gamma.count > 0, "SkipLayerNorm: invalid gamma");
-}
-
-IPluginV2* SkipLayerNormInterleavedPluginHFaceCreator::createPlugin(
- char const* name, PluginFieldCollection const* fc) noexcept
+IPluginV3* SkipLayerNormInterleavedPluginHFaceCreator::createPlugin(
+ char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept
{
try
{
@@ -549,8 +551,8 @@ IPluginV2* SkipLayerNormInterleavedPluginHFaceCreator::createPlugin(
return nullptr;
}
-IPluginV2* SkipLayerNormInterleavedPluginMTronCreator::createPlugin(
- char const* name, PluginFieldCollection const* fc) noexcept
+IPluginV3* SkipLayerNormInterleavedPluginMTronCreator::createPlugin(
+ char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept
{
try
{
@@ -571,40 +573,6 @@ IPluginV2* SkipLayerNormInterleavedPluginMTronCreator::createPlugin(
return nullptr;
}
-IPluginV2* SkipLayerNormInterleavedPluginHFaceCreator::deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept
-{
- // This object will be deleted when the network is destroyed, which will
- // call SkipLayerNormInterleavedPlugin::destroy()
- try
- {
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFaceCreator deserializePlugin");
- return new SkipLayerNormInterleavedPluginHFace(name, serialData, serialLength);
- }
- catch (std::exception const& e)
- {
- caughtError(e);
- }
- return nullptr;
-}
-
-IPluginV2* SkipLayerNormInterleavedPluginMTronCreator::deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept
-{
- // This object will be deleted when the network is destroyed, which will
- // call SkipLayerNormInterleavedPlugin::destroy()
- try
- {
- BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTronCreator deserializePlugin");
- return new SkipLayerNormInterleavedPluginMTron(name, serialData, serialLength);
- }
- catch (std::exception const& e)
- {
- caughtError(e);
- }
- return nullptr;
-}
-
void SkipLayerNormInterleavedPluginBaseCreator::setPluginNamespace(char const* libNamespace) noexcept
{
mNamespace = libNamespace;
@@ -614,3 +582,4 @@ char const* SkipLayerNormInterleavedPluginBaseCreator::getPluginNamespace() cons
{
return mNamespace.c_str();
}
+// End Plugin Creator member definitions
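
The migrated creators now implement `IPluginCreatorV3One`, so plugins are obtained through the registry's V3 lookup instead of `deserializePlugin`. A minimal sketch of that flow, assuming fp16 `beta`/`gamma` weights (per `kPARAM_TYPE`), the default `""` namespace, and version `"7"` for the HFace interleaved variant; the helper name is illustrative:

```cpp
#include "NvInfer.h"
#include <vector>

nvinfer1::IPluginV3* makeSkipLayerNormInterleavedHFace(
    nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma)
{
    using namespace nvinfer1;
    // The version string must match kSKIP_LAYER_NORM_INTERLEAVED_VERSION_HFACE.
    auto* creator = static_cast<IPluginCreatorV3One*>(
        getPluginRegistry()->getCreator("CustomSkipLayerNormPluginDynamic", "7", ""));
    if (creator == nullptr)
    {
        return nullptr;
    }
    // "beta" and "gamma" are the two required attributes registered by the creator.
    std::vector<PluginField> fields{
        PluginField{"beta", beta.values, PluginFieldType::kFLOAT16, static_cast<int32_t>(beta.count)},
        PluginField{"gamma", gamma.values, PluginFieldType::kFLOAT16, static_cast<int32_t>(gamma.count)}};
    PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};
    // V3 creators also receive the TensorRT phase; network construction uses kBUILD.
    return creator->createPlugin("skip_ln_interleaved", &fc, TensorRTPhase::kBUILD);
}
```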
diff --git a/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPlugin.h b/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPlugin.h
index e858919b..3f675232 100644
--- a/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPlugin.h
+++ b/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPlugin.h
@@ -48,50 +48,77 @@ int32_t launch_large_mtron(cudaStream_t stream, int32_t const ld, int32_t const
int8_t const* skip, half const* beta, half const* gamma, int8_t* output, int8_t* preln, float const dqScaleIn,
float const dqScaleSkip, float const qScale, float const qSkipScale);
-class SkipLayerNormInterleavedPluginBase : public nvinfer1::IPluginV2DynamicExt
+class SkipLayerNormInterleavedPluginBase : public IPluginV3,
+ public IPluginV3OneCore,
+ public IPluginV3OneBuild,
+ public IPluginV3OneRuntime
{
public:
SkipLayerNormInterleavedPluginBase(
std::string const& name, nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma);
- SkipLayerNormInterleavedPluginBase(std::string const& name, void const* data, size_t length);
-
// It doesn't make sense to make SkipLayerNormInterleavedPlugin without
// arguments, so we delete default constructor.
SkipLayerNormInterleavedPluginBase() = delete;
- // IPluginV2DynamicExt Methods
- nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs,
- nvinfer1::IExprBuilder& exprBuilder) noexcept override;
- bool supportsFormatCombination(
- int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
- void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
- nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
- size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs,
- nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
-
- // IPluginV2Ext Methods
- nvinfer1::DataType getOutputDataType(
- int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override;
-
- // IPluginV2 Methods
- char const* getPluginType() const noexcept override;
- size_t getSerializationSize() const noexcept override;
- void serialize(void* buffer) const noexcept override;
- void destroy() noexcept override;
- void setPluginNamespace(char const* pluginNamespace) noexcept override;
+ ~SkipLayerNormInterleavedPluginBase() override;
+
+ // IPluginV3 Methods
+ // NOTE: since this is itself an abstract class, the remaining virtual methods are defined in its child classes
+ IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept override;
+ // end of IPluginV3 Methods
+
+ // IPluginV3OneCore Methods
+ char const* getPluginName() const noexcept override;
+
char const* getPluginNamespace() const noexcept override;
+ void setPluginNamespace(char const* pluginNamespace) noexcept;
+ // end of IPluginV3OneCore Methods
+
+ // IPluginV3OneBuild Methods
+ bool supportsFormatCombination(
+ int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
+
+ int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs,
+ int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override;
+
+ int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out,
+ int32_t nbOutputs) noexcept override;
+
+ size_t getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
+
+ int32_t getOutputDataTypes(
+ DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override;
+ // end IPluginV3OneBuild Methods
+
+ // IPluginV3OneRuntime Methods
+ int32_t onShapeChange(
+ PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
+
+ IPluginV3* attachToContext(IPluginResourceContext* context) noexcept override;
+
+ PluginFieldCollection const* getFieldsToSerialize() noexcept override;
+ // end IPluginV3OneRuntime Methods
+
protected:
+ // metadata fields
std::string const& mLayerName;
std::string mNamespace;
+ std::vector<nvinfer1::PluginField> mDataToSerialize;
+ nvinfer1::PluginFieldCollection mFCToSerialize;
- bert::cuda_unique_ptr<void> mGammaDev;
- bert::cuda_unique_ptr<void> mBetaDev;
- size_t mLd{}; // leading dim
+ // members that participate in ser/deserialization
bert::WeightsWithOwnership mGamma;
bert::WeightsWithOwnership mBeta;
+ // device-side
+ bert::cuda_unique_ptr<void> mGammaDev;
+ bert::cuda_unique_ptr<void> mBetaDev;
+
+ // derived members
+ size_t mLd{}; // leading dim
size_t mParamWordsize{};
bool mParamsOnDevice{};
};
@@ -102,22 +129,22 @@ class SkipLayerNormInterleavedPluginHFace : public SkipLayerNormInterleavedPlugi
SkipLayerNormInterleavedPluginHFace(
std::string const& name, nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma);
- SkipLayerNormInterleavedPluginHFace(std::string const& name, void const* data, size_t length);
-
// It doesn't make sense to make SkipLayerNormInterleavedPlugin without
// arguments, so we delete default constructor.
SkipLayerNormInterleavedPluginHFace() = delete;
- // IPluginV2DynamicExt Methods
- nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
+ ~SkipLayerNormInterleavedPluginHFace() override;
+
+ // IPluginV3OneRuntime overrides
+ IPluginV3* clone() noexcept;
+
int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
- // IPluginV2 Methods
- int32_t initialize() noexcept override;
- void terminate() noexcept override;
- void destroy() noexcept override;
+ // IPluginV3OneCore override
char const* getPluginVersion() const noexcept override;
+
+ // IPluginV3OneBuild override
int32_t getNbOutputs() const noexcept override;
};
@@ -127,35 +154,36 @@ class SkipLayerNormInterleavedPluginMTron : public SkipLayerNormInterleavedPlugi
SkipLayerNormInterleavedPluginMTron(
std::string const& name, nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma);
- SkipLayerNormInterleavedPluginMTron(std::string const& name, void const* data, size_t length);
-
// It doesn't make sense to make SkipLayerNormInterleavedPlugin without
// arguments, so we delete default constructor.
SkipLayerNormInterleavedPluginMTron() = delete;
- // IPluginV2DynamicExt Methods
- nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
+ ~SkipLayerNormInterleavedPluginMTron() override;
+
+ // IPluginV3OneRuntime overrides
+ IPluginV3* clone() noexcept;
+
int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
- // IPluginV2 Methods
- int32_t initialize() noexcept override;
- void terminate() noexcept override;
- void destroy() noexcept override;
+ // IPluginV3OneCore override
char const* getPluginVersion() const noexcept override;
+
+ // IPluginV3OneBuild override
int32_t getNbOutputs() const noexcept override;
};
-class SkipLayerNormInterleavedPluginBaseCreator : public nvinfer1::IPluginCreator
+class SkipLayerNormInterleavedPluginBaseCreator : public nvinfer1::IPluginCreatorV3One
{
public:
SkipLayerNormInterleavedPluginBaseCreator();
+ ~SkipLayerNormInterleavedPluginBaseCreator() override = default;
char const* getPluginName() const noexcept override;
nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
- void setPluginNamespace(char const* pluginNamespace) noexcept override;
+ void setPluginNamespace(char const* pluginNamespace) noexcept;
char const* getPluginNamespace() const noexcept override;
@@ -170,11 +198,11 @@ class SkipLayerNormInterleavedPluginHFaceCreator : public SkipLayerNormInterleav
public:
SkipLayerNormInterleavedPluginHFaceCreator();
+ ~SkipLayerNormInterleavedPluginHFaceCreator() override = default;
+
char const* getPluginVersion() const noexcept override;
- nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
- nvinfer1::IPluginV2* deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept override;
+ IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override;
};
class SkipLayerNormInterleavedPluginMTronCreator : public SkipLayerNormInterleavedPluginBaseCreator
@@ -182,11 +210,11 @@ class SkipLayerNormInterleavedPluginMTronCreator : public SkipLayerNormInterleav
public:
SkipLayerNormInterleavedPluginMTronCreator();
+ ~SkipLayerNormInterleavedPluginMTronCreator() override = default;
+
char const* getPluginVersion() const noexcept override;
- nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
- nvinfer1::IPluginV2* deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept override;
+ IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override;
};
} // namespace bert
diff --git a/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPluginLegacy.cpp b/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPluginLegacy.cpp
new file mode 100644
index 00000000..6a6100db
--- /dev/null
+++ b/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPluginLegacy.cpp
@@ -0,0 +1,606 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "skipLayerNormInt8InterleavedPluginLegacy.h"
+#include "NvInfer.h"
+#include "common/serialize.hpp"
+#include <cuda.h>
+
+#include <cstring>
+#include <vector>
+
+using namespace nvinfer1;
+using namespace nvinfer1::plugin;
+using namespace nvinfer1::plugin::bert;
+
+// SkipLayerNorm plugin specific constants
+namespace
+{
+constexpr char const* kSKIP_LAYER_NORM_INTERLEAVED_VERSION_HFACE_LEGACY{"3"};
+constexpr char const* kSKIP_LAYER_NORM_INTERLEAVED_VERSION_MTRON_LEGACY{"4"};
+constexpr char const* kSKIP_LAYER_NORM_INTERLEAVED_NAME{"CustomSkipLayerNormPluginDynamic"};
+
+void buildBetaAndGamma(PluginFieldCollection const* fc, Weights& beta, Weights& gamma)
+{
+ plugin::validateRequiredAttributesExist({"beta", "gamma"}, fc);
+
+ for (int32_t i = 0; i < fc->nbFields; i++)
+ {
+ std::string field_name(fc->fields[i].name);
+
+ if (field_name.compare("beta") == 0)
+ {
+ BERT_DEBUG_MSG("Building beta...");
+ beta.values = fc->fields[i].data;
+ beta.count = fc->fields[i].length;
+ beta.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+
+ if (field_name.compare("gamma") == 0)
+ {
+ BERT_DEBUG_MSG("Building gamma...");
+ gamma.values = fc->fields[i].data;
+ gamma.count = fc->fields[i].length;
+ gamma.type = fieldTypeToDataType(fc->fields[i].type);
+ }
+ }
+
+ PLUGIN_VALIDATE(beta.values != nullptr, "SkipLayerNorm: invalid beta");
+ PLUGIN_VALIDATE(beta.count > 0, "SkipLayerNorm: invalid beta");
+
+ PLUGIN_VALIDATE(gamma.values != nullptr, "SkipLayerNorm: invalid gamma");
+ PLUGIN_VALIDATE(gamma.count > 0, "SkipLayerNorm: invalid gamma");
+}
+
+void checkDescs(PluginTensorDesc const& iDesc, PluginTensorDesc const& sDesc, PluginTensorDesc const& oDesc)
+{
+ PLUGIN_VALIDATE(iDesc.dims.nbDims == 4);
+ PLUGIN_VALIDATE(iDesc.dims.nbDims == sDesc.dims.nbDims);
+ PLUGIN_VALIDATE(std::equal(iDesc.dims.d, iDesc.dims.d + iDesc.dims.nbDims, sDesc.dims.d));
+ PLUGIN_VALIDATE(std::equal(iDesc.dims.d, iDesc.dims.d + iDesc.dims.nbDims, oDesc.dims.d));
+ PLUGIN_VALIDATE(iDesc.dims.d[0] == 1);
+ PLUGIN_VALIDATE(iDesc.dims.d[3] == 1);
+ PLUGIN_VALIDATE(iDesc.format == TensorFormat::kCHW32);
+ PLUGIN_VALIDATE(iDesc.type == DataType::kINT8);
+ PLUGIN_VALIDATE(iDesc.format == sDesc.format);
+ PLUGIN_VALIDATE(iDesc.format == oDesc.format);
+ PLUGIN_VALIDATE(iDesc.type == sDesc.type);
+ PLUGIN_VALIDATE(iDesc.type == oDesc.type);
+}
+} // namespace
+
+// Static class fields initialization
+PluginFieldCollection SkipLayerNormInterleavedPluginBaseLegacyCreator::mFC{};
+std::vector<PluginField> SkipLayerNormInterleavedPluginBaseLegacyCreator::mPluginAttributes;
+
+REGISTER_TENSORRT_PLUGIN(SkipLayerNormInterleavedPluginHFaceLegacyCreator);
+REGISTER_TENSORRT_PLUGIN(SkipLayerNormInterleavedPluginMTronLegacyCreator);
+
+constexpr auto kPARAM_TYPE = DataType::kHALF;
+
+SkipLayerNormInterleavedPluginBaseLegacy::SkipLayerNormInterleavedPluginBaseLegacy(
+ std::string const& name, Weights const& beta, Weights const& gamma)
+ : mLayerName(name)
+ , mGammaDev(nullptr)
+ , mBetaDev(nullptr)
+ , mLd(beta.count)
+ , mParamsOnDevice(false)
+{
+ PLUGIN_VALIDATE(mLd > 0);
+ PLUGIN_VALIDATE(beta.count == gamma.count);
+ // dataType for beta, gamma weights is always fp16
+
+ mParamWordsize = getElementSize(kPARAM_TYPE);
+
+ mBeta.convertAndCopy(beta, kPARAM_TYPE);
+ mGamma.convertAndCopy(gamma, kPARAM_TYPE);
+}
+
+SkipLayerNormInterleavedPluginHFaceLegacy::SkipLayerNormInterleavedPluginHFaceLegacy(
+ std::string const& name, Weights const& beta, Weights const& gamma)
+ : SkipLayerNormInterleavedPluginBaseLegacy(name, beta, gamma)
+{
+}
+
+SkipLayerNormInterleavedPluginMTronLegacy::SkipLayerNormInterleavedPluginMTronLegacy(
+ std::string const& name, Weights const& beta, Weights const& gamma)
+ : SkipLayerNormInterleavedPluginBaseLegacy(name, beta, gamma)
+{
+}
+
+SkipLayerNormInterleavedPluginBaseLegacy::SkipLayerNormInterleavedPluginBaseLegacy(
+ std::string const& name, void const* data, size_t length)
+ : mLayerName(name)
+ , mGammaDev(nullptr)
+ , mBetaDev(nullptr)
+ , mParamsOnDevice(false)
+{
+ // Deserialize in the same order as serialization
+ deserialize_value(&data, &length, &mLd);
+
+ mParamWordsize = getElementSize(kPARAM_TYPE);
+
+ char const* d = static_cast<char const*>(data);
+ mBeta.convertAndCopy(d, mLd, kPARAM_TYPE);
+ mGamma.convertAndCopy(d, mLd, kPARAM_TYPE);
+}
+
+SkipLayerNormInterleavedPluginHFaceLegacy::SkipLayerNormInterleavedPluginHFaceLegacy(
+ std::string const& name, void const* data, size_t length)
+ : SkipLayerNormInterleavedPluginBaseLegacy(name, data, length)
+{
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFaceLegacy deserialize");
+}
+
+SkipLayerNormInterleavedPluginMTronLegacy::SkipLayerNormInterleavedPluginMTronLegacy(
+ std::string const& name, void const* data, size_t length)
+ : SkipLayerNormInterleavedPluginBaseLegacy(name, data, length)
+{
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTronLegacy deserialize");
+}
+
+// IPluginV2DynamicExt Methods
+IPluginV2DynamicExt* SkipLayerNormInterleavedPluginHFaceLegacy::clone() const noexcept
+{
+ try
+ {
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFaceLegacy clone");
+
+ auto* p = new SkipLayerNormInterleavedPluginHFaceLegacy(mLayerName, mBeta, mGamma);
+ p->initialize();
+ p->setPluginNamespace(mNamespace.c_str());
+ return p;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+IPluginV2DynamicExt* SkipLayerNormInterleavedPluginMTronLegacy::clone() const noexcept
+{
+ try
+ {
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTronLegacy clone");
+
+ auto* p = new SkipLayerNormInterleavedPluginMTronLegacy(mLayerName, mBeta, mGamma);
+ p->initialize();
+ p->setPluginNamespace(mNamespace.c_str());
+ return p;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+DimsExprs SkipLayerNormInterleavedPluginBaseLegacy::getOutputDimensions(
+ int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept
+{
+ try
+ {
+ PLUGIN_VALIDATE(inputs != nullptr);
+ PLUGIN_VALIDATE(nbInputs == 2);
+ PLUGIN_VALIDATE(outputIndex >= 0 && outputIndex < getNbOutputs());
+ PLUGIN_VALIDATE(inputs[0].nbDims == inputs[1].nbDims);
+ return inputs[0];
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return DimsExprs{};
+}
+
+bool SkipLayerNormInterleavedPluginBaseLegacy::supportsFormatCombination(
+ int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
+{
+ try
+ {
+ PLUGIN_VALIDATE(inOut != nullptr);
+ PLUGIN_VALIDATE(nbInputs == 2);
+ PLUGIN_VALIDATE(nbOutputs == getNbOutputs());
+ PLUGIN_VALIDATE(pos >= 0 && pos < (nbInputs + nbOutputs));
+
+ PluginTensorDesc const& desc = inOut[pos];
+ return desc.type == DataType::kINT8 && desc.format == TensorFormat::kCHW32;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return false;
+}
+
+void SkipLayerNormInterleavedPluginBaseLegacy::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
+{
+ try
+ {
+ // Validate input arguments
+ PLUGIN_VALIDATE(inputs != nullptr);
+ PLUGIN_VALIDATE(outputs != nullptr);
+ PLUGIN_VALIDATE(nbOutputs == getNbOutputs());
+ PLUGIN_VALIDATE(nbInputs == 2);
+ PLUGIN_VALIDATE(DataType::kINT8 == inputs[0].desc.type);
+ PLUGIN_VALIDATE(DataType::kINT8 == inputs[1].desc.type);
+
+ auto const& inDims0 = inputs[0].desc.dims;
+ auto const& inDims1 = inputs[1].desc.dims;
+ TRT_UNUSED inDims1;
+
+ PLUGIN_VALIDATE(inDims0.nbDims == inDims1.nbDims);
+ PLUGIN_VALIDATE(std::equal(inDims0.d, inDims0.d + inDims0.nbDims, inDims1.d));
+
+ mParamWordsize = getElementSize(kPARAM_TYPE);
+
+ if (!mParamsOnDevice)
+ {
+ copyToDevice(mGamma, getWeightsSize(mGamma, kPARAM_TYPE), mGammaDev);
+ copyToDevice(mBeta, getWeightsSize(mBeta, kPARAM_TYPE), mBetaDev);
+ mParamsOnDevice = true;
+ }
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+}
+
+size_t SkipLayerNormInterleavedPluginBaseLegacy::getWorkspaceSize(
+ PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
+{
+ return 0;
+}
+
+int32_t SkipLayerNormInterleavedPluginHFaceLegacy::enqueue(PluginTensorDesc const* inputDesc,
+ PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* /* workspace */,
+ cudaStream_t stream) noexcept
+{
+ try
+ {
+ PLUGIN_VALIDATE(inputDesc != nullptr && outputDesc != nullptr && inputs != nullptr && outputs != nullptr);
+
+ // Input shape: 1x(hxd)xtotalx1
+ auto const iDesc = inputDesc[0];
+ auto const sDesc = inputDesc[1];
+ auto const oDesc = outputDesc[0];
+ checkDescs(iDesc, sDesc, oDesc);
+
+ int32_t const ld = iDesc.dims.d[1];
+ int32_t const total = iDesc.dims.d[2];
+ float const dqScaleIn = iDesc.scale;
+ float const dqScaleSkip = sDesc.scale;
+ float const qScale = 1.F / oDesc.scale;
+ int8_t const* input = static_cast<int8_t const*>(inputs[0]);
+ int8_t const* skip = static_cast<int8_t const*>(inputs[1]);
+ int8_t* output = static_cast<int8_t*>(outputs[0]);
+ half const* gamma = static_cast<half const*>(mGammaDev.get());
+ half const* beta = static_cast<half const*>(mBetaDev.get());
+
+ if (total < 4096)
+ {
+ return launch_small_hface(
+ stream, ld, total, input, skip, beta, gamma, output, dqScaleIn, dqScaleSkip, qScale);
+ }
+
+ return launch_large_hface(stream, ld, total, input, skip, beta, gamma, output, dqScaleIn, dqScaleSkip, qScale);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return -1;
+}
+
+int32_t SkipLayerNormInterleavedPluginMTronLegacy::enqueue(PluginTensorDesc const* inputDesc,
+ PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* /* workspace */,
+ cudaStream_t stream) noexcept
+{
+ try
+ {
+ PLUGIN_VALIDATE(inputDesc != nullptr && outputDesc != nullptr && inputs != nullptr && outputs != nullptr);
+
+ // Input shape: 1x(hxd)xtotalx1
+ auto const iDesc = inputDesc[0];
+ auto const sDesc = inputDesc[1];
+ auto const oDesc = outputDesc[0];
+ auto const pDesc = outputDesc[1];
+ checkDescs(iDesc, sDesc, oDesc);
+ PLUGIN_VALIDATE(std::equal(iDesc.dims.d, iDesc.dims.d + iDesc.dims.nbDims, pDesc.dims.d));
+
+ int32_t const ld = iDesc.dims.d[1];
+ int32_t const total = iDesc.dims.d[2];
+ float const dqScaleIn = iDesc.scale;
+ float const dqScaleSkip = sDesc.scale;
+ float const qScale = 1.F / oDesc.scale;
+ float const qSkipScale = 1.F / pDesc.scale;
+ int8_t const* input = static_cast<int8_t const*>(inputs[0]);
+ int8_t const* skip = static_cast<int8_t const*>(inputs[1]);
+ int8_t* output = static_cast<int8_t*>(outputs[0]);
+ int8_t* preln = static_cast<int8_t*>(outputs[1]);
+ half const* gamma = static_cast<half const*>(mGammaDev.get());
+ half const* beta = static_cast<half const*>(mBetaDev.get());
+
+ if (total < 4096)
+ {
+ return launch_small_mtron(
+ stream, ld, total, input, skip, beta, gamma, output, preln, dqScaleIn, dqScaleSkip, qScale, qSkipScale);
+ }
+
+ return launch_large_mtron(
+ stream, ld, total, input, skip, beta, gamma, output, preln, dqScaleIn, dqScaleSkip, qScale, qSkipScale);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return -1;
+}
+
+// IPluginV2Ext Methods
+DataType SkipLayerNormInterleavedPluginBaseLegacy::getOutputDataType(
+ int32_t index, DataType const* inputTypes, int32_t nbInputs) const noexcept
+{
+ try
+ {
+ PLUGIN_VALIDATE(inputTypes != nullptr);
+ PLUGIN_VALIDATE(index >= 0 && index < getNbOutputs());
+ PLUGIN_VALIDATE(nbInputs == 2);
+ return inputTypes[0];
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return DataType{};
+}
+
+// IPluginV2 Methods
+char const* SkipLayerNormInterleavedPluginBaseLegacy::getPluginType() const noexcept
+{
+ return kSKIP_LAYER_NORM_INTERLEAVED_NAME;
+}
+
+char const* SkipLayerNormInterleavedPluginHFaceLegacy::getPluginVersion() const noexcept
+{
+ return kSKIP_LAYER_NORM_INTERLEAVED_VERSION_HFACE_LEGACY;
+}
+
+char const* SkipLayerNormInterleavedPluginMTronLegacy::getPluginVersion() const noexcept
+{
+ return kSKIP_LAYER_NORM_INTERLEAVED_VERSION_MTRON_LEGACY;
+}
+
+int32_t SkipLayerNormInterleavedPluginHFaceLegacy::getNbOutputs() const noexcept
+{
+ return 1;
+}
+
+int32_t SkipLayerNormInterleavedPluginMTronLegacy::getNbOutputs() const noexcept
+{
+ return 2;
+}
+
+int32_t SkipLayerNormInterleavedPluginHFaceLegacy::initialize() noexcept
+{
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFaceLegacy initialize");
+ return 0;
+}
+
+int32_t SkipLayerNormInterleavedPluginMTronLegacy::initialize() noexcept
+{
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTronLegacy initialize");
+ return 0;
+}
+
+void SkipLayerNormInterleavedPluginHFaceLegacy::terminate() noexcept
+{
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFaceLegacy terminate");
+}
+
+void SkipLayerNormInterleavedPluginMTronLegacy::terminate() noexcept
+{
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTronLegacy terminate");
+}
+
+size_t SkipLayerNormInterleavedPluginBaseLegacy::getSerializationSize() const noexcept
+{
+ return 2 * mParamWordsize * mLd + sizeof(mLd);
+}
+
+void SkipLayerNormInterleavedPluginBaseLegacy::serialize(void* buffer) const noexcept
+{
+ try
+ {
+ serialize_value(&buffer, mLd);
+
+ char* d = static_cast<char*>(buffer);
+ serFromDev(d, static_cast<char const*>(mBetaDev.get()), mLd * mParamWordsize);
+ serFromDev(d, static_cast<char const*>(mGammaDev.get()), mLd * mParamWordsize);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+}
+
+void SkipLayerNormInterleavedPluginBaseLegacy::destroy() noexcept
+{
+ try
+ {
+ // This gets called when the network containing plugin is destroyed
+ mGammaDev.reset(nullptr);
+ mBetaDev.reset(nullptr);
+ delete this;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+}
+
+void SkipLayerNormInterleavedPluginHFaceLegacy::destroy() noexcept
+{
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFaceLegacy destroy");
+ SkipLayerNormInterleavedPluginBaseLegacy::destroy();
+}
+
+void SkipLayerNormInterleavedPluginMTronLegacy::destroy() noexcept
+{
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTronLegacy destroy");
+ SkipLayerNormInterleavedPluginBaseLegacy::destroy();
+}
+
+void SkipLayerNormInterleavedPluginBaseLegacy::setPluginNamespace(char const* libNamespace) noexcept
+{
+ mNamespace = libNamespace;
+}
+
+char const* SkipLayerNormInterleavedPluginBaseLegacy::getPluginNamespace() const noexcept
+{
+ return mNamespace.c_str();
+}
+
+/////////////////////////////////////////////////////////
+
+SkipLayerNormInterleavedPluginBaseLegacyCreator::SkipLayerNormInterleavedPluginBaseLegacyCreator()
+{
+ mPluginAttributes.clear();
+ mPluginAttributes.emplace_back(PluginField("beta"));
+ mPluginAttributes.emplace_back(PluginField("gamma"));
+ mFC.nbFields = mPluginAttributes.size();
+ mFC.fields = mPluginAttributes.data();
+}
+
+SkipLayerNormInterleavedPluginHFaceLegacyCreator::SkipLayerNormInterleavedPluginHFaceLegacyCreator()
+ : SkipLayerNormInterleavedPluginBaseLegacyCreator()
+{
+}
+
+SkipLayerNormInterleavedPluginMTronLegacyCreator::SkipLayerNormInterleavedPluginMTronLegacyCreator()
+ : SkipLayerNormInterleavedPluginBaseLegacyCreator()
+{
+}
+
+char const* SkipLayerNormInterleavedPluginBaseLegacyCreator::getPluginName() const noexcept
+{
+ return kSKIP_LAYER_NORM_INTERLEAVED_NAME;
+}
+
+char const* SkipLayerNormInterleavedPluginHFaceLegacyCreator::getPluginVersion() const noexcept
+{
+ return kSKIP_LAYER_NORM_INTERLEAVED_VERSION_HFACE_LEGACY;
+}
+
+char const* SkipLayerNormInterleavedPluginMTronLegacyCreator::getPluginVersion() const noexcept
+{
+ return kSKIP_LAYER_NORM_INTERLEAVED_VERSION_MTRON_LEGACY;
+}
+
+PluginFieldCollection const* SkipLayerNormInterleavedPluginBaseLegacyCreator::getFieldNames() noexcept
+{
+ return &mFC;
+}
+
+IPluginV2* SkipLayerNormInterleavedPluginHFaceLegacyCreator::createPlugin(
+ char const* name, PluginFieldCollection const* fc) noexcept
+{
+ try
+ {
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFaceLegacyCreator createPlugin");
+
+ Weights beta{DataType::kFLOAT, nullptr, 0};
+ Weights gamma{DataType::kFLOAT, nullptr, 0};
+ buildBetaAndGamma(fc, beta, gamma);
+
+ return new SkipLayerNormInterleavedPluginHFaceLegacy(name, beta, gamma);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+IPluginV2* SkipLayerNormInterleavedPluginMTronLegacyCreator::createPlugin(
+ char const* name, PluginFieldCollection const* fc) noexcept
+{
+ try
+ {
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTronLegacyCreator createPlugin");
+
+ PLUGIN_VALIDATE(fc != nullptr);
+
+ Weights beta{DataType::kFLOAT, nullptr, 0};
+ Weights gamma{DataType::kFLOAT, nullptr, 0};
+ buildBetaAndGamma(fc, beta, gamma);
+
+ return new SkipLayerNormInterleavedPluginMTronLegacy(name, beta, gamma);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+IPluginV2* SkipLayerNormInterleavedPluginHFaceLegacyCreator::deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept
+{
+ // This object will be deleted when the network is destroyed, which will
+ // call SkipLayerNormInterleavedPlugin::destroy()
+ try
+ {
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginHFaceLegacyCreator deserializePlugin");
+ return new SkipLayerNormInterleavedPluginHFaceLegacy(name, serialData, serialLength);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+IPluginV2* SkipLayerNormInterleavedPluginMTronLegacyCreator::deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept
+{
+ // This object will be deleted when the network is destroyed, which will
+ // call SkipLayerNormInterleavedPlugin::destroy()
+ try
+ {
+ BERT_DEBUG_MSG("SkipLayerNormInterleavedPluginMTronLegacyCreator deserializePlugin");
+ return new SkipLayerNormInterleavedPluginMTronLegacy(name, serialData, serialLength);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+void SkipLayerNormInterleavedPluginBaseLegacyCreator::setPluginNamespace(char const* libNamespace) noexcept
+{
+ mNamespace = libNamespace;
+}
+
+char const* SkipLayerNormInterleavedPluginBaseLegacyCreator::getPluginNamespace() const noexcept
+{
+ return mNamespace.c_str();
+}
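
The legacy path above keeps the flat byte serialization: `getSerializationSize()` is `sizeof(mLd)` plus fp16 beta and gamma arrays of `mLd` elements each, written in that order and read back in the same order by the deserializing constructor. A host-only sketch of that layout, with `std::uint16_t` standing in for the device-side `half` and the helper names invented for illustration:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Blob layout: [size_t mLd][half beta[mLd]][half gamma[mLd]]
std::size_t legacySerializationSize(std::size_t ld, std::size_t paramWordsize)
{
    return 2 * paramWordsize * ld + sizeof(std::size_t);
}

// Host-only mock of serialize(); the real code streams beta/gamma from device
// memory via serFromDev, but the resulting byte layout is the same.
void legacySerialize(void* buffer, std::size_t ld, std::uint16_t const* beta, std::uint16_t const* gamma)
{
    char* d = static_cast<char*>(buffer);
    std::memcpy(d, &ld, sizeof(ld)); // serialize_value(&buffer, mLd)
    d += sizeof(ld);
    std::memcpy(d, beta, ld * sizeof(std::uint16_t)); // serFromDev(d, beta, ...)
    d += ld * sizeof(std::uint16_t);
    std::memcpy(d, gamma, ld * sizeof(std::uint16_t)); // serFromDev(d, gamma, ...)
}
```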
diff --git a/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPluginLegacy.h b/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPluginLegacy.h
new file mode 100644
index 00000000..b5b56b38
--- /dev/null
+++ b/plugin/skipLayerNormPlugin/skipLayerNormInt8InterleavedPluginLegacy.h
@@ -0,0 +1,195 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TRT_SKIP_LAYER_NORM_INTERLEAVED_PLUGIN_LEGACY_H
+#define TRT_SKIP_LAYER_NORM_INTERLEAVED_PLUGIN_LEGACY_H
+#include "NvInferPlugin.h"
+#include <cuda.h>
+
+#include "common/bertCommon.h"
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace nvinfer1
+{
+namespace plugin
+{
+namespace bert
+{
+
+int32_t launch_small_hface(cudaStream_t stream, int32_t const ld, int32_t const total, int8_t const* input,
+ int8_t const* skip, half const* beta, half const* gamma, int8_t* output, float const dqScaleIn,
+ float const dqScaleSkip, float const qScale);
+
+int32_t launch_large_hface(cudaStream_t stream, int32_t const ld, int32_t const total, int8_t const* input,
+ int8_t const* skip, half const* beta, half const* gamma, int8_t* output, float const dqScaleIn,
+ float const dqScaleSkip, float const qScale);
+
+int32_t launch_small_mtron(cudaStream_t stream, int32_t const ld, int32_t const total, int8_t const* input,
+ int8_t const* skip, half const* beta, half const* gamma, int8_t* output, int8_t* preln, float const dqScaleIn,
+ float const dqScaleSkip, float const qScale, float const qSkipScale);
+
+int32_t launch_large_mtron(cudaStream_t stream, int32_t const ld, int32_t const total, int8_t const* input,
+ int8_t const* skip, half const* beta, half const* gamma, int8_t* output, int8_t* preln, float const dqScaleIn,
+ float const dqScaleSkip, float const qScale, float const qSkipScale);
+
+class SkipLayerNormInterleavedPluginBaseLegacy : public nvinfer1::IPluginV2DynamicExt
+{
+public:
+ SkipLayerNormInterleavedPluginBaseLegacy(
+ std::string const& name, nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma);
+
+ SkipLayerNormInterleavedPluginBaseLegacy(std::string const& name, void const* data, size_t length);
+
+ // It doesn't make sense to make SkipLayerNormInterleavedPlugin without
+ // arguments, so we delete default constructor.
+ SkipLayerNormInterleavedPluginBaseLegacy() = delete;
+
+ // IPluginV2DynamicExt Methods
+ nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs,
+ nvinfer1::IExprBuilder& exprBuilder) noexcept override;
+ bool supportsFormatCombination(
+ int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
+ void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
+ nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
+ size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs,
+ nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
+
+ // IPluginV2Ext Methods
+ nvinfer1::DataType getOutputDataType(
+ int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override;
+
+ // IPluginV2 Methods
+ char const* getPluginType() const noexcept override;
+ size_t getSerializationSize() const noexcept override;
+ void serialize(void* buffer) const noexcept override;
+ void destroy() noexcept override;
+ void setPluginNamespace(char const* pluginNamespace) noexcept override;
+ char const* getPluginNamespace() const noexcept override;
+
+protected:
+ std::string const& mLayerName;
+ std::string mNamespace;
+
+ bert::cuda_unique_ptr<void> mGammaDev;
+ bert::cuda_unique_ptr<void> mBetaDev;
+ size_t mLd{}; // leading dim
+ bert::WeightsWithOwnership mGamma;
+ bert::WeightsWithOwnership mBeta;
+
+ size_t mParamWordsize{};
+ bool mParamsOnDevice{};
+};
+
+class SkipLayerNormInterleavedPluginHFaceLegacy : public SkipLayerNormInterleavedPluginBaseLegacy
+{
+public:
+ SkipLayerNormInterleavedPluginHFaceLegacy(
+ std::string const& name, nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma);
+
+ SkipLayerNormInterleavedPluginHFaceLegacy(std::string const& name, void const* data, size_t length);
+
+ // It doesn't make sense to make SkipLayerNormInterleavedPlugin without
+ // arguments, so we delete default constructor.
+ SkipLayerNormInterleavedPluginHFaceLegacy() = delete;
+
+ // IPluginV2DynamicExt Methods
+ nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
+ int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
+
+ // IPluginV2 Methods
+ int32_t initialize() noexcept override;
+ void terminate() noexcept override;
+ void destroy() noexcept override;
+ char const* getPluginVersion() const noexcept override;
+ int32_t getNbOutputs() const noexcept override;
+};
+
+class SkipLayerNormInterleavedPluginMTronLegacy : public SkipLayerNormInterleavedPluginBaseLegacy
+{
+public:
+ SkipLayerNormInterleavedPluginMTronLegacy(
+ std::string const& name, nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma);
+
+ SkipLayerNormInterleavedPluginMTronLegacy(std::string const& name, void const* data, size_t length);
+
+ // It doesn't make sense to make SkipLayerNormInterleavedPlugin without
+ // arguments, so we delete default constructor.
+ SkipLayerNormInterleavedPluginMTronLegacy() = delete;
+
+ // IPluginV2DynamicExt Methods
+ nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
+ int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
+
+ // IPluginV2 Methods
+ int32_t initialize() noexcept override;
+ void terminate() noexcept override;
+ void destroy() noexcept override;
+ char const* getPluginVersion() const noexcept override;
+ int32_t getNbOutputs() const noexcept override;
+};
+
+class SkipLayerNormInterleavedPluginBaseLegacyCreator : public nvinfer1::IPluginCreator
+{
+public:
+ SkipLayerNormInterleavedPluginBaseLegacyCreator();
+
+ char const* getPluginName() const noexcept override;
+
+ nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
+
+ void setPluginNamespace(char const* pluginNamespace) noexcept override;
+
+ char const* getPluginNamespace() const noexcept override;
+
+private:
+ static nvinfer1::PluginFieldCollection mFC;
+ static std::vector<nvinfer1::PluginField> mPluginAttributes;
+ std::string mNamespace;
+};
+
+class SkipLayerNormInterleavedPluginHFaceLegacyCreator : public SkipLayerNormInterleavedPluginBaseLegacyCreator
+{
+public:
+ SkipLayerNormInterleavedPluginHFaceLegacyCreator();
+
+ char const* getPluginVersion() const noexcept override;
+
+ nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+ nvinfer1::IPluginV2* deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept override;
+};
+
+class SkipLayerNormInterleavedPluginMTronLegacyCreator : public SkipLayerNormInterleavedPluginBaseLegacyCreator
+{
+public:
+ SkipLayerNormInterleavedPluginMTronLegacyCreator();
+
+ char const* getPluginVersion() const noexcept override;
+
+ nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+ nvinfer1::IPluginV2* deserializePlugin(
+ char const* name, void const* serialData, size_t serialLength) noexcept override;
+};
+
+} // namespace bert
+} // namespace plugin
+} // namespace nvinfer1
+#endif // TRT_SKIP_LAYER_NORM_INTERLEAVED_PLUGIN_LEGACY_H
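
Keeping the `*Legacy` classes (versions "3" and "4") alongside the V3 rewrite means previously serialized engines still load, but the two generations are rebuilt along different paths: the legacy creator receives the raw byte blob, while the V3 creator is re-invoked with the `PluginFieldCollection` produced by `getFieldsToSerialize()`. A sketch of the contrast, with the layer name `"skip_ln"` as a placeholder:

```cpp
#include "NvInfer.h"
#include <cstddef>

using namespace nvinfer1;

// Legacy engines: the creator parses the serialized bytes itself.
IPluginV2* rebuildLegacy(IPluginCreator& creator, void const* blob, std::size_t len)
{
    return creator.deserializePlugin("skip_ln", blob, len);
}

// V3 engines: the runtime replays the serialized fields through createPlugin(),
// flagging deserialization time via TensorRTPhase::kRUNTIME.
IPluginV3* rebuildV3(IPluginCreatorV3One& creator, PluginFieldCollection const* fc)
{
    return creator.createPlugin("skip_ln", fc, TensorRTPhase::kRUNTIME);
}
```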
diff --git a/plugin/skipLayerNormPlugin/skipLayerNormKernel.cu b/plugin/skipLayerNormPlugin/skipLayerNormKernel.cu
index da0cee19..b107a824 100644
--- a/plugin/skipLayerNormPlugin/skipLayerNormKernel.cu
+++ b/plugin/skipLayerNormPlugin/skipLayerNormKernel.cu
@@ -23,6 +23,7 @@
#include "common/common.cuh"
#include "common/serialize.hpp"
#include "skipLayerNormPlugin.h"
+#include "skipLayerNormPluginLegacy.h"
#include <cassert>
#include <cstring>
diff --git a/plugin/skipLayerNormPlugin/skipLayerNormPlugin.cpp b/plugin/skipLayerNormPlugin/skipLayerNormPlugin.cpp
index c792486b..6d3daa82 100644
--- a/plugin/skipLayerNormPlugin/skipLayerNormPlugin.cpp
+++ b/plugin/skipLayerNormPlugin/skipLayerNormPlugin.cpp
@@ -32,38 +32,28 @@ using namespace nvinfer1::plugin::bert;
// Clip plugin specific constants
namespace
{
-char const* kSKIP_LAYER_NORM_VERSION{"1"};
-char const* kSKIP_LAYER_NORM_NAME{"CustomSkipLayerNormPluginDynamic"};
-char const* kSKIP_LAYER_NORM_VAR_SEQLEN_VERSION{"2"};
+constexpr char const* kSKIP_LAYER_NORM_VERSION{"5"};
+constexpr char const* kSKIP_LAYER_NORM_NAME{"CustomSkipLayerNormPluginDynamic"};
+constexpr char const* kSKIP_LAYER_NORM_VAR_SEQLEN_VERSION{"6"};
} // namespace
// Static class fields initialization
-PluginFieldCollection SkipLayerNormPluginDynamicCreator::mFC{};
-std::vector<PluginField> SkipLayerNormPluginDynamicCreator::mPluginAttributes;
+PluginFieldCollection SkipLayerNormPluginV3Creator::mFC{};
+std::vector<PluginField> SkipLayerNormPluginV3Creator::mPluginAttributes;
-PluginFieldCollection SkipLayerNormVarSeqlenPluginCreator::mFC{};
-std::vector<PluginField> SkipLayerNormVarSeqlenPluginCreator::mPluginAttributes;
+PluginFieldCollection SkipLayerNormVarSeqlenPluginV3Creator::mFC{};
+std::vector<PluginField> SkipLayerNormVarSeqlenPluginV3Creator::mPluginAttributes;
-REGISTER_TENSORRT_PLUGIN(SkipLayerNormPluginDynamicCreator);
-REGISTER_TENSORRT_PLUGIN(SkipLayerNormVarSeqlenPluginCreator);
+REGISTER_TENSORRT_PLUGIN(SkipLayerNormPluginV3Creator);
+REGISTER_TENSORRT_PLUGIN(SkipLayerNormVarSeqlenPluginV3Creator);
-static inline DataType getParamWordType(DataType cfgType) noexcept
-{
- if (cfgType == DataType::kINT8)
- {
- return DataType::kHALF;
- }
-
- return cfgType;
-}
-
-SkipLayerNormPluginDynamic::SkipLayerNormPluginDynamic(const std::string name, const DataType type, int32_t const ld,
+SkipLayerNormPluginV3::SkipLayerNormPluginV3(const std::string name, const DataType type, int32_t const ld,
Weights const& beta, Weights const& gamma, Weights const& bias)
: mLayerName(name)
+ , mType(type)
+ , mLd(ld)
, mGammaDev(nullptr)
, mBetaDev(nullptr)
- , mLd(ld)
- , mType(type)
, mBiasDev(nullptr)
{
PLUGIN_VALIDATE(mType == nvinfer1::DataType::kFLOAT || mType == nvinfer1::DataType::kHALF
@@ -88,50 +78,32 @@ SkipLayerNormPluginDynamic::SkipLayerNormPluginDynamic(const std::string name, c
{
copyToDevice(mBias, getWeightsSize(mBias, mCfgType), mBiasDev);
}
+ BERT_DEBUG_MSG("SkipLayerNormPluginV3 initialize");
}
-SkipLayerNormPluginDynamic::SkipLayerNormPluginDynamic(const std::string name, void const* data, size_t length)
- : mLayerName(name)
- , mGammaDev(nullptr)
- , mBetaDev(nullptr)
- , mBiasDev(nullptr)
+SkipLayerNormPluginV3::~SkipLayerNormPluginV3()
{
- BERT_DEBUG_MSG("SkipLayerNormPluginDynamic deserialize");
-
- // Deserialize in the same order as serialization
- deserialize_value(&data, &length, &mType);
- deserialize_value(&data, &length, &mCfgType);
- deserialize_value(&data, &length, &mLd);
- deserialize_value(&data, &length, &mHasBias);
-
- PLUGIN_VALIDATE(mCfgType == nvinfer1::DataType::kFLOAT || mCfgType == nvinfer1::DataType::kHALF);
- mParamWordsize = getElementSize(mCfgType);
-
- char const* d = static_cast<char const*>(data);
- mBeta.convertAndCopy(d, mLd, mCfgType);
- mGamma.convertAndCopy(d, mLd, mCfgType);
- if (mHasBias)
+ BERT_DEBUG_MSG("SkipLayerNormPluginV3 terminate");
+ try
{
- mBias.convertAndCopy(d, mLd, mCfgType);
+ BERT_DEBUG_MSG("SkipLayerNormPluginV3 destroy");
+ mGammaDev.reset(nullptr);
+ mBetaDev.reset(nullptr);
+ mBiasDev.reset(nullptr);
}
-
- copyToDevice(mGamma, getWeightsSize(mGamma, mCfgType), mGammaDev);
- copyToDevice(mBeta, getWeightsSize(mBeta, mCfgType), mBetaDev);
- if (mHasBias)
+ catch (std::exception const& e)
{
- copyToDevice(mBias, getWeightsSize(mBias, mCfgType), mBiasDev);
+ caughtError(e);
}
}
-// IPluginV2DynamicExt Methods
-IPluginV2DynamicExt* SkipLayerNormPluginDynamic::clone() const noexcept
+IPluginV3* SkipLayerNormPluginV3::clone() noexcept
{
try
{
- BERT_DEBUG_MSG("SkipLayerNormPluginDynamic clone");
+ BERT_DEBUG_MSG("SkipLayerNormPluginV3 clone");
- auto* p = new SkipLayerNormPluginDynamic(mLayerName, mType, mLd, mBeta, mGamma, mBias);
- p->initialize();
+ auto* p = new SkipLayerNormPluginV3(mLayerName, mType, mLd, mBeta, mGamma, mBias);
p->setPluginNamespace(mNamespace.c_str());
return p;
}
@@ -142,26 +114,26 @@ IPluginV2DynamicExt* SkipLayerNormPluginDynamic::clone() const noexcept
return nullptr;
}
-DimsExprs SkipLayerNormPluginDynamic::getOutputDimensions(
- int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept
+int32_t SkipLayerNormPluginV3::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs,
+ int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept
{
try
{
PLUGIN_VALIDATE(inputs != nullptr);
PLUGIN_VALIDATE(nbInputs == 2);
- PLUGIN_VALIDATE(outputIndex == 0);
PLUGIN_VALIDATE(inputs[0].nbDims == inputs[1].nbDims);
- return inputs[0];
+ outputs[0] = inputs[0];
+ return pluginStatus_t::STATUS_SUCCESS;
}
catch (std::exception const& e)
{
caughtError(e);
}
- return DimsExprs{};
+ return pluginStatus_t::STATUS_FAILURE;
}
-bool SkipLayerNormPluginDynamic::supportsFormatCombination(
- int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
+bool SkipLayerNormPluginV3::supportsFormatCombination(
+ int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
try
{
@@ -170,7 +142,7 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
PLUGIN_VALIDATE(nbOutputs == 1);
PLUGIN_VALIDATE(pos >= 0 && pos < (nbInputs + nbOutputs));
- PluginTensorDesc const& in = inOut[pos];
+ PluginTensorDesc const& in = inOut[pos].desc;
if (pos == 0)
{
// Since H = W = 1, we can report CHWx for any x
@@ -192,7 +164,7 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
}
return (in.type == mType) && (in.format == TensorFormat::kLINEAR);
}
- PluginTensorDesc const& prev = inOut[pos - 1];
+ PluginTensorDesc const& prev = inOut[pos - 1].desc;
return in.type == prev.type && in.format == prev.format;
}
@@ -203,59 +175,21 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
return false;
}
-void SkipLayerNormPluginDynamic::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+int32_t SkipLayerNormPluginV3::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
{
- try
- {
- BERT_DEBUG_MSG("SkipLayerNormPluginDynamic configurePlugin");
-
- // Validate input arguments
- PLUGIN_VALIDATE(inputs != nullptr);
- PLUGIN_VALIDATE(outputs != nullptr);
- PLUGIN_VALIDATE(nbOutputs == 1);
- PLUGIN_VALIDATE(nbInputs == 2);
- if (mType == DataType::kFLOAT || mType == DataType::kHALF)
- {
- PLUGIN_VALIDATE(mType == inputs[0].desc.type);
- PLUGIN_VALIDATE(mType == inputs[1].desc.type);
- }
- else
- {
- PLUGIN_VALIDATE(mType == inputs[0].desc.type || DataType::kFLOAT == inputs[0].desc.type);
- PLUGIN_VALIDATE(mType == inputs[1].desc.type || DataType::kFLOAT == inputs[1].desc.type);
- }
- auto const& inDims0 = inputs[0].desc.dims;
- auto const& inDims1 = inputs[1].desc.dims;
- PLUGIN_VALIDATE(inDims0.nbDims == inDims1.nbDims);
-
- PLUGIN_VALIDATE(std::equal(inDims0.d, inDims0.d + inDims0.nbDims, inDims1.d));
-
- PLUGIN_VALIDATE(inDims0.nbDims == 5);
- mLd = inDims0.d[HDIM]; // hiddensize
- PLUGIN_VALIDATE(mLd != 0U);
- PLUGIN_VALIDATE(inDims0.d[3] == 1);
- PLUGIN_VALIDATE(inDims0.d[4] == 1);
-
- mCfgType = inputs[0].desc.type == DataType::kINT8 ? DataType::kHALF : inputs[0].desc.type;
-
- auto const paramType = getParamWordType(mCfgType);
- mParamWordsize = getElementSize(paramType);
- }
- catch (std::exception const& e)
- {
- caughtError(e);
- }
+ return pluginStatus_t::STATUS_SUCCESS;
}
-size_t SkipLayerNormPluginDynamic::getWorkspaceSize(
- PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
+size_t SkipLayerNormPluginV3::getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
{
return 0;
}
-int32_t SkipLayerNormPluginDynamic::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc,
- void const* const* inputs, void* const* outputs, void* /* workspace */, cudaStream_t stream) noexcept
+int32_t SkipLayerNormPluginV3::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
+ nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept
{
int32_t status = -1;
try
@@ -342,120 +276,181 @@ int32_t SkipLayerNormPluginDynamic::enqueue(PluginTensorDesc const* inputDesc, P
return status;
}
-// IPluginV2Ext Methods
-DataType SkipLayerNormPluginDynamic::getOutputDataType(
- int32_t index, DataType const* inputTypes, int32_t nbInputs) const noexcept
+int32_t SkipLayerNormPluginV3::getOutputDataTypes(
+ DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept
{
try
{
+ PLUGIN_VALIDATE(outputTypes != nullptr);
+ PLUGIN_VALIDATE(nbOutputs == 1);
PLUGIN_VALIDATE(inputTypes != nullptr);
- PLUGIN_VALIDATE(index == 0);
PLUGIN_VALIDATE(nbInputs == 2);
- return inputTypes[0];
+ outputTypes[0] = inputTypes[0];
+ return pluginStatus_t::STATUS_SUCCESS;
}
catch (std::exception const& e)
{
caughtError(e);
}
- return DataType{};
+ return pluginStatus_t::STATUS_FAILURE;
}
-// IPluginV2 Methods
-char const* SkipLayerNormPluginDynamic::getPluginType() const noexcept
+char const* SkipLayerNormPluginV3::getPluginVersion() const noexcept
{
- return kSKIP_LAYER_NORM_NAME;
+ return kSKIP_LAYER_NORM_VERSION;
}
-char const* SkipLayerNormPluginDynamic::getPluginVersion() const noexcept
+int32_t SkipLayerNormPluginV3::getNbOutputs() const noexcept
{
- return kSKIP_LAYER_NORM_VERSION;
+ return 1;
}
-int32_t SkipLayerNormPluginDynamic::getNbOutputs() const noexcept
+PluginFieldCollection const* SkipLayerNormPluginV3::getFieldsToSerialize() noexcept
{
- return 1;
+ mDataToSerialize.clear();
+ mDataToSerialize.emplace_back("type_id", &mType, PluginFieldType::kINT32, 1);
+ mDataToSerialize.emplace_back("ld", &mLd, PluginFieldType::kINT32, 1);
+ if (mCfgType == DataType::kHALF)
+ {
+ mDataToSerialize.emplace_back(
+ "beta", static_cast(mBeta.values), PluginFieldType::kFLOAT16, mBeta.count);
+ PLUGIN_ASSERT(mBeta.type == mCfgType);
+ mDataToSerialize.emplace_back(
+ "gamma", static_cast(mGamma.values), PluginFieldType::kFLOAT16, mGamma.count);
+ PLUGIN_ASSERT(mGamma.type == mCfgType);
+ if (mHasBias)
+ {
+ mDataToSerialize.emplace_back(
+ "bias", static_cast(mBias.values), PluginFieldType::kFLOAT16, mBias.count);
+ PLUGIN_ASSERT(mBias.type == mCfgType);
+ }
+ }
+ else
+ {
+ PLUGIN_ASSERT(mCfgType == DataType::kFLOAT);
+ mDataToSerialize.emplace_back(
+ "beta", static_cast(mBeta.values), PluginFieldType::kFLOAT32, mBeta.count);
+ PLUGIN_ASSERT(mBeta.type == mCfgType);
+ mDataToSerialize.emplace_back(
+ "gamma", static_cast(mGamma.values), PluginFieldType::kFLOAT32, mGamma.count);
+ PLUGIN_ASSERT(mGamma.type == mCfgType);
+ if (mHasBias)
+ {
+ mDataToSerialize.emplace_back(
+ "bias", static_cast(mBias.values), PluginFieldType::kFLOAT32, mBias.count);
+ PLUGIN_ASSERT(mBias.type == mCfgType);
+ }
+ }
+
+ mFCToSerialize.nbFields = mDataToSerialize.size();
+ mFCToSerialize.fields = mDataToSerialize.data();
+
+ return &mFCToSerialize;
}
-int32_t SkipLayerNormPluginDynamic::initialize() noexcept
+void SkipLayerNormPluginV3::setPluginNamespace(char const* libNamespace) noexcept
{
- BERT_DEBUG_MSG("SkipLayerNormPluginDynamic initialize");
- return 0;
+ try
+ {
+ mNamespace = libNamespace;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
}
-void SkipLayerNormPluginDynamic::terminate() noexcept
+char const* SkipLayerNormPluginV3::getPluginNamespace() const noexcept
{
- BERT_DEBUG_MSG("SkipLayerNormPluginDynamic terminate");
+ return mNamespace.c_str();
}
-size_t SkipLayerNormPluginDynamic::getSerializationSize() const noexcept
+char const* SkipLayerNormPluginV3::getPluginName() const noexcept
{
- const size_t biasSize = mHasBias ? (mLd * mParamWordsize) : 0;
- return 2 * mParamWordsize * mLd + 2 * sizeof(DataType) + sizeof(mLd) + biasSize + sizeof(mHasBias);
+ return kSKIP_LAYER_NORM_NAME;
}
-void SkipLayerNormPluginDynamic::serialize(void* buffer) const noexcept
+int32_t SkipLayerNormPluginV3::onShapeChange(
+ PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
{
try
{
- serialize_value(&buffer, mType);
- serialize_value(&buffer, mCfgType);
- serialize_value(&buffer, mLd);
- serialize_value(&buffer, mHasBias);
-
- char* d = static_cast<char*>(buffer);
- serFromDev(d, static_cast<char*>(mBetaDev.get()), mLd * mParamWordsize);
- serFromDev(d, static_cast<char*>(mGammaDev.get()), mLd * mParamWordsize);
- if (mHasBias)
+ BERT_DEBUG_MSG("SkipLayerNormPluginV3 onShapeChange");
+
+ // Validate input arguments
+ PLUGIN_VALIDATE(inputs != nullptr);
+ PLUGIN_VALIDATE(outputs != nullptr);
+ PLUGIN_VALIDATE(nbOutputs == 1);
+ PLUGIN_VALIDATE(nbInputs == 2);
+ if (mType == DataType::kFLOAT || mType == DataType::kHALF)
+ {
+ PLUGIN_VALIDATE(mType == inputs[0].type);
+ PLUGIN_VALIDATE(mType == inputs[1].type);
+ }
+ else
{
- serFromDev(d, static_cast<char*>(mBiasDev.get()), mLd * mParamWordsize);
+ PLUGIN_VALIDATE(mType == inputs[0].type || DataType::kFLOAT == inputs[0].type);
+ PLUGIN_VALIDATE(mType == inputs[1].type || DataType::kFLOAT == inputs[1].type);
}
+ auto const& inDims0 = inputs[0].dims;
+ auto const& inDims1 = inputs[1].dims;
+ PLUGIN_VALIDATE(inDims0.nbDims == inDims1.nbDims);
+
+ PLUGIN_VALIDATE(std::equal(inDims0.d, inDims0.d + inDims0.nbDims, inDims1.d));
+
+ PLUGIN_VALIDATE(inDims0.nbDims == 5);
+ mLd = inDims0.d[HDIM]; // hiddensize
+ PLUGIN_VALIDATE(mLd != 0);
+ PLUGIN_VALIDATE(inDims0.d[3] == 1);
+ PLUGIN_VALIDATE(inDims0.d[4] == 1);
+
+ mCfgType = inputs[0].type == DataType::kINT8 ? DataType::kHALF : inputs[0].type;
+
+ mParamWordsize = getElementSize(mCfgType);
+ return pluginStatus_t::STATUS_SUCCESS;
}
catch (std::exception const& e)
{
caughtError(e);
}
+ return pluginStatus_t::STATUS_FAILURE;
}
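
`onShapeChange` also documents the tensor layout this plugin expects: a 5-D input whose dimension at `HDIM` is the hidden size, with the trailing two dimensions fixed at 1. As a rough illustration (the batch/sequence placement and the `HDIM == 2` value are assumptions for the sketch, not taken from this diff):

```cpp
#include "NvInferRuntime.h"

// A shape that would pass the checks above, assuming HDIM == 2.
nvinfer1::Dims dims{};
dims.nbDims = 5;
dims.d[0] = 128; // e.g. sequence length (hypothetical)
dims.d[1] = 8;   // e.g. batch size (hypothetical)
dims.d[2] = 768; // hidden size; onShapeChange stores this into mLd
dims.d[3] = 1;   // validated to be 1
dims.d[4] = 1;   // validated to be 1
```
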
-void SkipLayerNormPluginDynamic::destroy() noexcept
+IPluginV3* SkipLayerNormPluginV3::attachToContext(IPluginResourceContext* context) noexcept
{
- try
- {
- BERT_DEBUG_MSG("SkipLayerNormPluginDynamic destroy");
- // This gets called when the network containing plugin is destroyed
- mGammaDev.reset(nullptr);
- mBetaDev.reset(nullptr);
- mBiasDev.reset(nullptr);
- delete this;
- }
- catch (std::exception const& e)
- {
- caughtError(e);
- }
+ return clone();
}
-void SkipLayerNormPluginDynamic::setPluginNamespace(char const* libNamespace) noexcept
+IPluginCapability* SkipLayerNormPluginV3::getCapabilityInterface(PluginCapabilityType type) noexcept
{
try
{
- mNamespace = libNamespace;
+ if (type == PluginCapabilityType::kBUILD)
+ {
+ return static_cast<IPluginV3OneBuild*>(this);
+ }
+ if (type == PluginCapabilityType::kRUNTIME)
+ {
+ return static_cast<IPluginV3OneRuntime*>(this);
+ }
+ PLUGIN_ASSERT(type == PluginCapabilityType::kCORE);
+ return static_cast<IPluginV3OneCore*>(this);
}
catch (std::exception const& e)
{
caughtError(e);
}
+ return nullptr;
}
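
Because the plugin class multiply-inherits `IPluginV3OneCore`, `IPluginV3OneBuild`, and `IPluginV3OneRuntime`, each branch in `getCapabilityInterface` must `static_cast` to the matching base so the correct subobject pointer is returned before being erased to `IPluginCapability*`. A minimal sketch of the consuming side (assumed usage, not TensorRT-internal code):

```cpp
#include "NvInferRuntime.h"
using namespace nvinfer1;

// Ask a plugin for its build-time capability interface.
IPluginV3OneBuild* getBuildCapability(IPluginV3& plugin)
{
    // Safe to downcast: kBUILD is answered with the IPluginV3OneBuild subobject.
    return static_cast<IPluginV3OneBuild*>(plugin.getCapabilityInterface(PluginCapabilityType::kBUILD));
}
```
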
-char const* SkipLayerNormPluginDynamic::getPluginNamespace() const noexcept
-{
- return mNamespace.c_str();
-}
-
-/////////////////////////////////////////////////////////
+////////////////////////// SkipLayerNormPluginV3 (version:5) Creator ///////////////////////////////
-SkipLayerNormPluginDynamicCreator::SkipLayerNormPluginDynamicCreator()
+SkipLayerNormPluginV3Creator::SkipLayerNormPluginV3Creator()
{
+ static std::mutex sMutex;
+ std::lock_guard<std::mutex> guard(sMutex);
mPluginAttributes.clear();
- mPluginAttributes.emplace_back(PluginField("ld"));
mPluginAttributes.emplace_back(PluginField("type_id"));
+ mPluginAttributes.emplace_back(PluginField("ld"));
mPluginAttributes.emplace_back(PluginField("beta"));
mPluginAttributes.emplace_back(PluginField("gamma"));
mPluginAttributes.emplace_back(PluginField("bias"));
@@ -463,26 +458,27 @@ SkipLayerNormPluginDynamicCreator::SkipLayerNormPluginDynamicCreator()
mFC.fields = mPluginAttributes.data();
}
-char const* SkipLayerNormPluginDynamicCreator::getPluginName() const noexcept
+char const* SkipLayerNormPluginV3Creator::getPluginName() const noexcept
{
return kSKIP_LAYER_NORM_NAME;
}
-char const* SkipLayerNormPluginDynamicCreator::getPluginVersion() const noexcept
+char const* SkipLayerNormPluginV3Creator::getPluginVersion() const noexcept
{
return kSKIP_LAYER_NORM_VERSION;
}
-PluginFieldCollection const* SkipLayerNormPluginDynamicCreator::getFieldNames() noexcept
+PluginFieldCollection const* SkipLayerNormPluginV3Creator::getFieldNames() noexcept
{
return &mFC;
}
-IPluginV2* SkipLayerNormPluginDynamicCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
+IPluginV3* SkipLayerNormPluginV3Creator::createPlugin(
+ char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept
{
try
{
- BERT_DEBUG_MSG("SkipLayerNormPluginDynamicCreator createPlugin");
+ BERT_DEBUG_MSG("SkipLayerNormPluginV3Creator createPlugin");
int32_t ld = 0;
Weights beta{DataType::kFLOAT, nullptr, 0};
@@ -491,46 +487,32 @@ IPluginV2* SkipLayerNormPluginDynamicCreator::createPlugin(char const* name, Plu
int32_t typeId = -1;
PLUGIN_VALIDATE(fc != nullptr);
+ PLUGIN_VALIDATE(fc->fields != nullptr);
plugin::validateRequiredAttributesExist({"type_id", "beta", "ld", "gamma"}, fc);
for (int32_t i = 0; i < fc->nbFields; i++)
{
- std::string field_name(fc->fields[i].name);
- if (field_name.compare("ld") == 0)
- {
- ld = *static_cast<int32_t const*>(fc->fields[i].data);
- BERT_DEBUG_VALUE("Building ld: ", ld);
- }
-
- if (field_name.compare("type_id") == 0)
+ std::string fieldName(fc->fields[i].name);
+ if (fieldName == "type_id")
{
typeId = *static_cast<int32_t const*>(fc->fields[i].data);
BERT_DEBUG_VALUE("Building typeId: ", typeId);
}
-
- if (field_name.compare("beta") == 0)
+ else if (fieldName == "ld")
{
- BERT_DEBUG_MSG("Building beta...");
- beta.values = fc->fields[i].data;
- beta.count = fc->fields[i].length;
- beta.type = fieldTypeToDataType(fc->fields[i].type);
+ ld = *static_cast<int32_t const*>(fc->fields[i].data);
+ BERT_DEBUG_VALUE("Building ld: ", ld);
}
-
- if (field_name.compare("gamma") == 0)
+ // process the weight tensors beta, gamma, bias
+ else if (fieldName == "beta" || fieldName == "gamma" || fieldName == "bias")
{
- BERT_DEBUG_MSG("Building gamma...");
- gamma.values = fc->fields[i].data;
- gamma.count = fc->fields[i].length;
- gamma.type = fieldTypeToDataType(fc->fields[i].type);
- }
+ Weights* weightPtr = (fieldName == "beta") ? &beta : (fieldName == "gamma") ? &gamma : &bias;
- if (field_name.compare("bias") == 0)
- {
- BERT_DEBUG_MSG("Building bias...");
- bias.values = fc->fields[i].data;
- bias.count = fc->fields[i].length;
- bias.type = fieldTypeToDataType(fc->fields[i].type);
+ BERT_DEBUG_MSG(("Building " + fieldName + "...").c_str());
+ weightPtr->type = fieldTypeToDataType(fc->fields[i].type);
+ weightPtr->values = fc->fields[i].data;
+ weightPtr->count = fc->fields[i].length;
}
}
BERT_DEBUG_VALUE("Type ", typeId);
@@ -540,27 +522,14 @@ IPluginV2* SkipLayerNormPluginDynamicCreator::createPlugin(char const* name, Plu
PLUGIN_VALIDATE(beta.values != nullptr, "SkipLayerNorm: invalid beta");
PLUGIN_VALIDATE(beta.count > 0, "SkipLayerNorm: invalid beta");
-
PLUGIN_VALIDATE(gamma.values != nullptr, "SkipLayerNorm: invalid gamma");
PLUGIN_VALIDATE(gamma.count > 0, "SkipLayerNorm: invalid gamma");
+ if (bias.values != nullptr)
+ {
+ PLUGIN_VALIDATE(bias.count > 0, "SkipLayerNorm: invalid bias");
+ }
- return new SkipLayerNormPluginDynamic(name, static_cast<DataType>(typeId), ld, beta, gamma, bias);
- }
- catch (std::exception const& e)
- {
- caughtError(e);
- }
- return nullptr;
-}
-
-IPluginV2* SkipLayerNormPluginDynamicCreator::deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept
-{
- // This object will be deleted when the network is destroyed, which will
- // call SkipLayerNormPluginDynamic::destroy()
- try
- {
- return new SkipLayerNormPluginDynamic(name, serialData, serialLength);
+ return new SkipLayerNormPluginV3(name, static_cast<DataType>(typeId), ld, beta, gamma, bias);
}
catch (std::exception const& e)
{
@@ -569,7 +538,7 @@ IPluginV2* SkipLayerNormPluginDynamicCreator::deserializePlugin(
return nullptr;
}
-void SkipLayerNormPluginDynamicCreator::setPluginNamespace(char const* libNamespace) noexcept
+void SkipLayerNormPluginV3Creator::setPluginNamespace(char const* libNamespace) noexcept
{
try
{
@@ -581,12 +550,14 @@ void SkipLayerNormPluginDynamicCreator::setPluginNamespace(char const* libNamesp
}
}
-char const* SkipLayerNormPluginDynamicCreator::getPluginNamespace() const noexcept
+char const* SkipLayerNormPluginV3Creator::getPluginNamespace() const noexcept
{
return mNamespace.c_str();
}
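
For build-time use, a caller assembles the same four required attributes that `createPlugin` validates above and goes through the plugin registry. A minimal sketch, assuming `beta` and `gamma` each point at `ld` floats and that this plugin registers under the `CustomSkipLayerNormPluginDynamic` name with version "5" (per the section header above):

```cpp
#include <array>
#include "NvInferRuntime.h"
using namespace nvinfer1;

IPluginV3* makeSkipLayerNorm(int32_t ld, float const* beta, float const* gamma)
{
    int32_t typeId = 0; // 0 == DataType::kFLOAT
    std::array<PluginField, 4> fields{{
        {"type_id", &typeId, PluginFieldType::kINT32, 1},
        {"ld", &ld, PluginFieldType::kINT32, 1},
        {"beta", beta, PluginFieldType::kFLOAT32, ld},
        {"gamma", gamma, PluginFieldType::kFLOAT32, ld},
    }};
    PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};

    auto* creator = static_cast<IPluginCreatorV3One*>(
        getPluginRegistry()->getCreator("CustomSkipLayerNormPluginDynamic", "5", ""));
    return creator != nullptr ? creator->createPlugin("skip_ln", &fc, TensorRTPhase::kBUILD) : nullptr;
}
```
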
-SkipLayerNormVarSeqlenPlugin::SkipLayerNormVarSeqlenPlugin(
+////////////////////////// SkipLayerNormVarSeqlenPluginV3 (skipLayerNorm version: 6) ///////////////////////////////
+
+SkipLayerNormVarSeqlenPluginV3::SkipLayerNormVarSeqlenPluginV3(
const std::string name, const DataType type, Weights const& beta, Weights const& gamma, Weights const& bias)
: mLayerName(name)
, mGammaDev(nullptr)
@@ -621,48 +592,27 @@ SkipLayerNormVarSeqlenPlugin::SkipLayerNormVarSeqlenPlugin(
}
}
-SkipLayerNormVarSeqlenPlugin::SkipLayerNormVarSeqlenPlugin(const std::string name, void const* data, size_t length)
- : mLayerName(name)
- , mGammaDev(nullptr)
- , mBetaDev(nullptr)
- , mBiasDev(nullptr)
+SkipLayerNormVarSeqlenPluginV3::~SkipLayerNormVarSeqlenPluginV3()
{
- BERT_DEBUG_MSG("SkipLayerNormVarSeqlenPlugin deserialize");
-
- // Deserialize in the same order as serialization
- deserialize_value(&data, &length, &mType);
- deserialize_value(&data, &length, &mCfgType);
- deserialize_value(&data, &length, &mLd);
- deserialize_value(&data, &length, &mHasBias);
-
- PLUGIN_VALIDATE(mCfgType == nvinfer1::DataType::kFLOAT || mCfgType == nvinfer1::DataType::kHALF);
- mParamWordsize = getElementSize(mCfgType);
-
- char const* d = static_cast(data);
- mBeta.convertAndCopy(d, mLd, mCfgType);
- mGamma.convertAndCopy(d, mLd, mCfgType);
- if (mHasBias)
+ try
{
- mBias.convertAndCopy(d, mLd, mCfgType);
+ BERT_DEBUG_MSG("SkipLayerNormVarSeqlenPluginV3 destroy");
+ mGammaDev.reset(nullptr);
+ mBetaDev.reset(nullptr);
+ mBiasDev.reset(nullptr);
}
-
- copyToDevice(mGamma, getWeightsSize(mGamma, mCfgType), mGammaDev);
- copyToDevice(mBeta, getWeightsSize(mBeta, mCfgType), mBetaDev);
- if (mHasBias)
+ catch (std::exception const& e)
{
- copyToDevice(mBias, getWeightsSize(mBias, mCfgType), mBiasDev);
+ caughtError(e);
}
}
-// IPluginV2DynamicExt Methods
-IPluginV2DynamicExt* SkipLayerNormVarSeqlenPlugin::clone() const noexcept
+IPluginV3* SkipLayerNormVarSeqlenPluginV3::clone() noexcept
{
try
{
- BERT_DEBUG_MSG("SkipLayerNormVarSeqlenPlugin clone");
-
- auto* p = new SkipLayerNormVarSeqlenPlugin(mLayerName, mType, mBeta, mGamma, mBias);
- p->initialize();
+ BERT_DEBUG_MSG("SkipLayerNormVarSeqlenPluginV3 clone");
+ auto* p = new SkipLayerNormVarSeqlenPluginV3(mLayerName, mType, mBeta, mGamma, mBias);
p->setPluginNamespace(mNamespace.c_str());
return p;
}
@@ -673,26 +623,28 @@ IPluginV2DynamicExt* SkipLayerNormVarSeqlenPlugin::clone() const noexcept
return nullptr;
}
-DimsExprs SkipLayerNormVarSeqlenPlugin::getOutputDimensions(
- int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept
+int32_t SkipLayerNormVarSeqlenPluginV3::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs,
+ DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs,
+ IExprBuilder& exprBuilder) noexcept
{
try
{
PLUGIN_VALIDATE(inputs != nullptr);
PLUGIN_VALIDATE(nbInputs == 2);
- PLUGIN_VALIDATE(outputIndex == 0);
+ PLUGIN_VALIDATE(nbOutputs == 1);
PLUGIN_VALIDATE(inputs[0].nbDims == inputs[1].nbDims);
- return inputs[0];
+ outputs[0] = inputs[0];
+ return pluginStatus_t::STATUS_SUCCESS;
}
catch (std::exception const& e)
{
caughtError(e);
}
- return DimsExprs{};
+ return pluginStatus_t::STATUS_FAILURE;
}
-bool SkipLayerNormVarSeqlenPlugin::supportsFormatCombination(
- int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
+bool SkipLayerNormVarSeqlenPluginV3::supportsFormatCombination(
+ int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
try
{
@@ -701,7 +653,7 @@ bool SkipLayerNormVarSeqlenPlugin::supportsFormatCombination(
PLUGIN_VALIDATE(nbOutputs == 1);
PLUGIN_VALIDATE(pos >= 0 && pos < (nbInputs + nbOutputs));
- PluginTensorDesc const& in = inOut[pos];
+ PluginTensorDesc const& in = inOut[pos].desc;
if (mType != in.type)
return false;
@@ -726,7 +678,7 @@ bool SkipLayerNormVarSeqlenPlugin::supportsFormatCombination(
}
return in.format == TensorFormat::kLINEAR;
}
- PluginTensorDesc const& prev = inOut[pos - 1];
+ PluginTensorDesc const& prev = inOut[pos - 1].desc;
return in.format == prev.format;
}
@@ -737,52 +689,21 @@ bool SkipLayerNormVarSeqlenPlugin::supportsFormatCombination(
return false;
}
-void SkipLayerNormVarSeqlenPlugin::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+int32_t SkipLayerNormVarSeqlenPluginV3::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
{
- try
- {
- // Validate input arguments
- PLUGIN_VALIDATE(inputs != nullptr);
- PLUGIN_VALIDATE(outputs != nullptr);
- PLUGIN_VALIDATE(nbOutputs == 1);
- PLUGIN_VALIDATE(nbInputs == 2);
-
- if (mType == DataType::kFLOAT || mType == DataType::kHALF)
- {
- PLUGIN_VALIDATE(mType == inputs[0].desc.type);
- PLUGIN_VALIDATE(mType == inputs[1].desc.type);
- }
- else
- {
- PLUGIN_VALIDATE(mType == inputs[0].desc.type || DataType::kFLOAT == inputs[0].desc.type);
- PLUGIN_VALIDATE(mType == inputs[1].desc.type || DataType::kFLOAT == inputs[1].desc.type);
- }
- auto const& inDims0 = inputs[0].desc.dims;
- auto const& inDims1 = inputs[1].desc.dims;
- PLUGIN_VALIDATE(inDims0.nbDims == inDims1.nbDims);
-
- PLUGIN_VALIDATE(std::equal(inDims0.d, inDims0.d + inDims0.nbDims, inDims1.d));
-
- mCfgType = inputs[0].desc.type == DataType::kINT8 ? DataType::kHALF : inputs[0].desc.type;
-
- auto const paramType = getParamWordType(mCfgType);
- mParamWordsize = getElementSize(paramType);
- }
- catch (std::exception const& e)
- {
- caughtError(e);
- }
+ return pluginStatus_t::STATUS_SUCCESS;
}
-size_t SkipLayerNormVarSeqlenPlugin::getWorkspaceSize(
- PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
+size_t SkipLayerNormVarSeqlenPluginV3::getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
{
return 0;
}
-int32_t SkipLayerNormVarSeqlenPlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc,
- void const* const* inputs, void* const* outputs, void* /* workspace */, cudaStream_t stream) noexcept
+int32_t SkipLayerNormVarSeqlenPluginV3::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
+ nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept
{
int32_t status = -1;
try
@@ -870,102 +791,165 @@ int32_t SkipLayerNormVarSeqlenPlugin::enqueue(PluginTensorDesc const* inputDesc,
return status;
}
-// IPluginV2Ext Methods
-DataType SkipLayerNormVarSeqlenPlugin::getOutputDataType(
- int32_t index, DataType const* inputTypes, int32_t nbInputs) const noexcept
-{
- PLUGIN_VALIDATE(inputTypes != nullptr);
- PLUGIN_VALIDATE(index == 0);
- PLUGIN_VALIDATE(nbInputs == 2);
- return inputTypes[0];
-}
-
-// IPluginV2 Methods
-char const* SkipLayerNormVarSeqlenPlugin::getPluginType() const noexcept
+int32_t SkipLayerNormVarSeqlenPluginV3::getOutputDataTypes(
+ DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept
{
- return kSKIP_LAYER_NORM_NAME;
+ try
+ {
+ PLUGIN_VALIDATE(outputTypes != nullptr);
+ PLUGIN_VALIDATE(nbOutputs == 1);
+ PLUGIN_VALIDATE(inputTypes != nullptr);
+ PLUGIN_VALIDATE(nbInputs == 2);
+ outputTypes[0] = inputTypes[0];
+ return pluginStatus_t::STATUS_SUCCESS;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return pluginStatus_t::STATUS_FAILURE;
}
-char const* SkipLayerNormVarSeqlenPlugin::getPluginVersion() const noexcept
+char const* SkipLayerNormVarSeqlenPluginV3::getPluginVersion() const noexcept
{
return kSKIP_LAYER_NORM_VAR_SEQLEN_VERSION;
}
-int32_t SkipLayerNormVarSeqlenPlugin::getNbOutputs() const noexcept
+int32_t SkipLayerNormVarSeqlenPluginV3::getNbOutputs() const noexcept
{
return 1;
}
-int32_t SkipLayerNormVarSeqlenPlugin::initialize() noexcept
+PluginFieldCollection const* SkipLayerNormVarSeqlenPluginV3::getFieldsToSerialize() noexcept
{
- BERT_DEBUG_MSG("SkipLayerNormVarSeqlenPlugin initialize");
- return 0;
+ mDataToSerialize.clear();
+ mDataToSerialize.emplace_back("type_id", &mType, PluginFieldType::kINT32, 1);
+ mDataToSerialize.emplace_back("ld", &mLd, PluginFieldType::kINT32, 1);
+ if (mCfgType == DataType::kHALF)
+ {
+ mDataToSerialize.emplace_back(
+ "beta", static_cast(mBeta.values), PluginFieldType::kFLOAT16, mBeta.count);
+ PLUGIN_ASSERT(mBeta.type == mCfgType);
+ mDataToSerialize.emplace_back(
+ "gamma", static_cast(mGamma.values), PluginFieldType::kFLOAT16, mGamma.count);
+ PLUGIN_ASSERT(mGamma.type == mCfgType);
+ if (mHasBias)
+ {
+ mDataToSerialize.emplace_back(
+ "bias", static_cast(mBias.values), PluginFieldType::kFLOAT16, mBias.count);
+ PLUGIN_ASSERT(mBias.type == mCfgType);
+ }
+ }
+ else
+ {
+ PLUGIN_ASSERT(mCfgType == DataType::kFLOAT);
+ mDataToSerialize.emplace_back(
+ "beta", static_cast(mBeta.values), PluginFieldType::kFLOAT32, mBeta.count);
+ PLUGIN_ASSERT(mBeta.type == mCfgType);
+ mDataToSerialize.emplace_back(
+ "gamma", static_cast(mGamma.values), PluginFieldType::kFLOAT32, mGamma.count);
+ PLUGIN_ASSERT(mGamma.type == mCfgType);
+ if (mHasBias)
+ {
+ mDataToSerialize.emplace_back(
+ "bias", static_cast(mBias.values), PluginFieldType::kFLOAT32, mBias.count);
+ PLUGIN_ASSERT(mBias.type == mCfgType);
+ }
+ }
+
+ mFCToSerialize.nbFields = mDataToSerialize.size();
+ mFCToSerialize.fields = mDataToSerialize.data();
+
+ return &mFCToSerialize;
+}
+
+void SkipLayerNormVarSeqlenPluginV3::setPluginNamespace(char const* libNamespace) noexcept
+{
+ mNamespace = libNamespace;
}
-void SkipLayerNormVarSeqlenPlugin::terminate() noexcept
+char const* SkipLayerNormVarSeqlenPluginV3::getPluginNamespace() const noexcept
{
- BERT_DEBUG_MSG("SkipLayerNormVarSeqlenPlugin terminate");
+ return mNamespace.c_str();
}
-size_t SkipLayerNormVarSeqlenPlugin::getSerializationSize() const noexcept
+char const* SkipLayerNormVarSeqlenPluginV3::getPluginName() const noexcept
{
- const size_t biasSize = mHasBias ? (mLd * mParamWordsize) : 0;
- return 2 * mParamWordsize * mLd + 2 * sizeof(DataType) + sizeof(mLd) + biasSize + sizeof(mHasBias);
+ return kSKIP_LAYER_NORM_NAME;
}
-void SkipLayerNormVarSeqlenPlugin::serialize(void* buffer) const noexcept
+int32_t SkipLayerNormVarSeqlenPluginV3::onShapeChange(
+ PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
{
try
{
- serialize_value(&buffer, mType);
- serialize_value(&buffer, mCfgType);
- serialize_value(&buffer, mLd);
- serialize_value(&buffer, mHasBias);
-
- char* d = static_cast<char*>(buffer);
- serFromDev(d, static_cast<char*>(mBetaDev.get()), mLd * mParamWordsize);
- serFromDev(d, static_cast<char*>(mGammaDev.get()), mLd * mParamWordsize);
- if (mHasBias)
+ // Validate input arguments
+ PLUGIN_VALIDATE(inputs != nullptr);
+ PLUGIN_VALIDATE(outputs != nullptr);
+ PLUGIN_VALIDATE(nbOutputs == 1);
+ PLUGIN_VALIDATE(nbInputs == 2);
+
+ if (mType == DataType::kFLOAT || mType == DataType::kHALF)
{
- serFromDev(d, static_cast<char*>(mBiasDev.get()), mLd * mParamWordsize);
+ PLUGIN_VALIDATE(mType == inputs[0].type);
+ PLUGIN_VALIDATE(mType == inputs[1].type);
}
+ else
+ {
+ PLUGIN_VALIDATE(mType == inputs[0].type || DataType::kFLOAT == inputs[0].type);
+ PLUGIN_VALIDATE(mType == inputs[1].type || DataType::kFLOAT == inputs[1].type);
+ }
+ auto const& inDims0 = inputs[0].dims;
+ auto const& inDims1 = inputs[1].dims;
+ PLUGIN_VALIDATE(inDims0.nbDims == inDims1.nbDims);
+
+ PLUGIN_VALIDATE(std::equal(inDims0.d, inDims0.d + inDims0.nbDims, inDims1.d));
+
+ mCfgType = inputs[0].type == DataType::kINT8 ? DataType::kHALF : inputs[0].type;
+
+ mParamWordsize = getElementSize(mCfgType);
+
+ return pluginStatus_t::STATUS_SUCCESS;
}
catch (std::exception const& e)
{
caughtError(e);
}
+ return pluginStatus_t::STATUS_FAILURE;
}
-void SkipLayerNormVarSeqlenPlugin::destroy() noexcept
+IPluginV3* SkipLayerNormVarSeqlenPluginV3::attachToContext(IPluginResourceContext* context) noexcept
+{
+ return clone();
+}
+
+IPluginCapability* SkipLayerNormVarSeqlenPluginV3::getCapabilityInterface(PluginCapabilityType type) noexcept
{
try
{
- BERT_DEBUG_MSG("SkipLayerNormVarSeqlenPlugin destroy");
- // This gets called when the network containing plugin is destroyed
- mGammaDev.reset(nullptr);
- mBetaDev.reset(nullptr);
- mBiasDev.reset(nullptr);
- delete this;
+ if (type == PluginCapabilityType::kBUILD)
+ {
+ return static_cast<IPluginV3OneBuild*>(this);
+ }
+ if (type == PluginCapabilityType::kRUNTIME)
+ {
+ return static_cast<IPluginV3OneRuntime*>(this);
+ }
+ PLUGIN_ASSERT(type == PluginCapabilityType::kCORE);
+ return static_cast<IPluginV3OneCore*>(this);
}
catch (std::exception const& e)
{
caughtError(e);
}
+ return nullptr;
}
-void SkipLayerNormVarSeqlenPlugin::setPluginNamespace(char const* libNamespace) noexcept
-{
- mNamespace = libNamespace;
-}
-
-char const* SkipLayerNormVarSeqlenPlugin::getPluginNamespace() const noexcept
-{
- return mNamespace.c_str();
-}
-
-/////////////////////////////////////////////////////////
+////////////////////////// SkipLayerNormVarSeqlenPluginV3Creator ///////////////////////////////
-SkipLayerNormVarSeqlenPluginCreator::SkipLayerNormVarSeqlenPluginCreator()
+SkipLayerNormVarSeqlenPluginV3Creator::SkipLayerNormVarSeqlenPluginV3Creator()
{
+ static std::mutex sMutex;
+ std::lock_guard<std::mutex> guard(sMutex);
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("type_id"));
mPluginAttributes.emplace_back(PluginField("beta"));
@@ -975,26 +959,27 @@ SkipLayerNormVarSeqlenPluginCreator::SkipLayerNormVarSeqlenPluginCreator()
mFC.fields = mPluginAttributes.data();
}
-char const* SkipLayerNormVarSeqlenPluginCreator::getPluginName() const noexcept
+char const* SkipLayerNormVarSeqlenPluginV3Creator::getPluginName() const noexcept
{
return kSKIP_LAYER_NORM_NAME;
}
-char const* SkipLayerNormVarSeqlenPluginCreator::getPluginVersion() const noexcept
+char const* SkipLayerNormVarSeqlenPluginV3Creator::getPluginVersion() const noexcept
{
return kSKIP_LAYER_NORM_VAR_SEQLEN_VERSION;
}
-PluginFieldCollection const* SkipLayerNormVarSeqlenPluginCreator::getFieldNames() noexcept
+PluginFieldCollection const* SkipLayerNormVarSeqlenPluginV3Creator::getFieldNames() noexcept
{
return &mFC;
}
-IPluginV2* SkipLayerNormVarSeqlenPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
+IPluginV3* SkipLayerNormVarSeqlenPluginV3Creator::createPlugin(
+ char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept
{
try
{
- BERT_DEBUG_MSG("SkipLayerNormVarSeqlenPluginCreator createPlugin");
+ BERT_DEBUG_MSG("SkipLayerNormVarSeqlenPluginV3Creator createPlugin");
Weights beta{DataType::kFLOAT, nullptr, 0};
Weights gamma{DataType::kFLOAT, nullptr, 0};
@@ -1007,36 +992,21 @@ IPluginV2* SkipLayerNormVarSeqlenPluginCreator::createPlugin(char const* name, P
for (int32_t i = 0; i < fc->nbFields; i++)
{
- std::string field_name(fc->fields[i].name);
-
- if (field_name.compare("type_id") == 0)
+ std::string fieldName(fc->fields[i].name);
+ if (fieldName == "type_id")
{
typeId = *static_cast<int32_t const*>(fc->fields[i].data);
BERT_DEBUG_VALUE("Building typeId: ", typeId);
}
-
- if (field_name.compare("beta") == 0)
+ // process the weight tensors beta, gamma, bias
+ else if (fieldName == "beta" || fieldName == "gamma" || fieldName == "bias")
{
- BERT_DEBUG_MSG("Building beta...");
- beta.values = fc->fields[i].data;
- beta.count = fc->fields[i].length;
- beta.type = fieldTypeToDataType(fc->fields[i].type);
- }
+ Weights* weightPtr = (fieldName == "beta") ? &beta : (fieldName == "gamma") ? &gamma : &bias;
- if (field_name.compare("gamma") == 0)
- {
- BERT_DEBUG_MSG("Building gamma...");
- gamma.values = fc->fields[i].data;
- gamma.count = fc->fields[i].length;
- gamma.type = fieldTypeToDataType(fc->fields[i].type);
- }
-
- if (field_name.compare("bias") == 0)
- {
- BERT_DEBUG_MSG("Building bias...");
- bias.values = fc->fields[i].data;
- bias.count = fc->fields[i].length;
- bias.type = fieldTypeToDataType(fc->fields[i].type);
+ BERT_DEBUG_MSG(("Building " + fieldName + "...").c_str());
+ weightPtr->type = fieldTypeToDataType(fc->fields[i].type);
+ weightPtr->values = fc->fields[i].data;
+ weightPtr->count = fc->fields[i].length;
}
}
BERT_DEBUG_VALUE("Type ", typeId);
@@ -1050,7 +1020,7 @@ IPluginV2* SkipLayerNormVarSeqlenPluginCreator::createPlugin(char const* name, P
PLUGIN_VALIDATE(gamma.values != nullptr, "SkipLayerNorm: invalid gamma");
PLUGIN_VALIDATE(gamma.count > 0, "SkipLayerNorm: invalid gamma");
- return new SkipLayerNormVarSeqlenPlugin(name, static_cast<DataType>(typeId), beta, gamma, bias);
+ return new SkipLayerNormVarSeqlenPluginV3(name, static_cast<DataType>(typeId), beta, gamma, bias);
}
catch (std::exception const& e)
{
@@ -1059,28 +1029,19 @@ IPluginV2* SkipLayerNormVarSeqlenPluginCreator::createPlugin(char const* name, P
return nullptr;
}
-IPluginV2* SkipLayerNormVarSeqlenPluginCreator::deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept
+void SkipLayerNormVarSeqlenPluginV3Creator::setPluginNamespace(char const* libNamespace) noexcept
{
- // This object will be deleted when the network is destroyed, which will
- // call SkipLayerNormVarSeqlenPlugin::destroy()
try
{
- return new SkipLayerNormVarSeqlenPlugin(name, serialData, serialLength);
+ mNamespace = libNamespace;
}
catch (std::exception const& e)
{
caughtError(e);
}
- return nullptr;
-}
-
-void SkipLayerNormVarSeqlenPluginCreator::setPluginNamespace(char const* libNamespace) noexcept
-{
- mNamespace = libNamespace;
}
-char const* SkipLayerNormVarSeqlenPluginCreator::getPluginNamespace() const noexcept
+char const* SkipLayerNormVarSeqlenPluginV3Creator::getPluginNamespace() const noexcept
{
return mNamespace.c_str();
}
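
With both variants migrated, a network consumes the new plugin objects through the V3 entry point rather than `addPluginV2`. A minimal sketch, assuming `input` and `skip` are tensors already in the network and `plugin` came from the creator:

```cpp
#include "NvInfer.h"
using namespace nvinfer1;

ITensor* addSkipLayerNorm(INetworkDefinition& network, ITensor* input, ITensor* skip, IPluginV3& plugin)
{
    ITensor* ioTensors[] = {input, skip};
    // Two data inputs, no shape inputs; the plugin reports a single output
    // whose shape and type mirror inputs[0] (see getOutputShapes/getOutputDataTypes).
    IPluginV3Layer* layer = network.addPluginV3(ioTensors, 2, nullptr, 0, plugin);
    return layer != nullptr ? layer->getOutput(0) : nullptr;
}
```
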
diff --git a/plugin/skipLayerNormPlugin/skipLayerNormPlugin.h b/plugin/skipLayerNormPlugin/skipLayerNormPlugin.h
index 9b1a783a..364fc68d 100644
--- a/plugin/skipLayerNormPlugin/skipLayerNormPlugin.h
+++ b/plugin/skipLayerNormPlugin/skipLayerNormPlugin.h
@@ -43,75 +43,100 @@ template
int32_t computeSkipLayerNorm(cudaStream_t stream, int32_t const ld, int32_t const n, T const* input, T const* skip,
T const* beta, T const* gamma, T* output, T const* bias);
-class SkipLayerNormPluginDynamic : public nvinfer1::IPluginV2DynamicExt
+class SkipLayerNormPluginV3 : public IPluginV3,
+ public IPluginV3OneCore,
+ public IPluginV3OneBuild,
+ public IPluginV3OneRuntime
{
public:
- SkipLayerNormPluginDynamic(const std::string name, const nvinfer1::DataType type, int32_t const ld,
+ SkipLayerNormPluginV3(const std::string name, const nvinfer1::DataType type, int32_t const ld,
nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma, nvinfer1::Weights const& bias);
- SkipLayerNormPluginDynamic(const std::string name, void const* data, size_t length);
-
- // It doesn't make sense to make SkipLayerNormPluginDynamic without arguments,
+ // It doesn't make sense to make SkipLayerNormPluginV3 without arguments,
// so we delete default constructor.
- SkipLayerNormPluginDynamic() = delete;
+ SkipLayerNormPluginV3() = delete;
+
+ ~SkipLayerNormPluginV3() override;
+
+ // IPluginV3 Methods
+ IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept override;
+
+ IPluginV3* clone() noexcept override;
+ // end of IPluginV3 Methods
+
+ // IPluginV3OneCore Methods
+ char const* getPluginName() const noexcept override;
+
+ char const* getPluginVersion() const noexcept override;
+
+ char const* getPluginNamespace() const noexcept override;
+
+ void setPluginNamespace(char const* pluginNamespace) noexcept;
+ // end of IPluginV3OneCore Methods
+
+ // IPluginV3Build Methods
+ int32_t getNbOutputs() const noexcept override;
- // IPluginV2DynamicExt Methods
- nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
- nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs,
- nvinfer1::IExprBuilder& exprBuilder) noexcept override;
bool supportsFormatCombination(
- int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
- void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
- nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
- size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs,
- nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
+ int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
+
+ int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs,
+ int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override;
+
+ int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out,
+ int32_t nbOutputs) noexcept override;
+
+ size_t getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
+
+ int32_t getOutputDataTypes(
+ DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override;
+ // end IPluginV3Build Methods
+
+ // IPluginV3Runtime Methods
int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
- // IPluginV2Ext Methods
- nvinfer1::DataType getOutputDataType(
- int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override;
+ int32_t onShapeChange(
+ PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
- // IPluginV2 Methods
- char const* getPluginType() const noexcept override;
- char const* getPluginVersion() const noexcept override;
- int32_t getNbOutputs() const noexcept override;
- int32_t initialize() noexcept override;
- void terminate() noexcept override;
- size_t getSerializationSize() const noexcept override;
- void serialize(void* buffer) const noexcept override;
- void destroy() noexcept override;
- void setPluginNamespace(char const* pluginNamespace) noexcept override;
- char const* getPluginNamespace() const noexcept override;
+ IPluginV3* attachToContext(IPluginResourceContext* context) noexcept override;
+
+ PluginFieldCollection const* getFieldsToSerialize() noexcept override;
+ // end IPluginV3Runtime Methods
private:
+ // metadata
const std::string mLayerName;
std::string mNamespace;
- bert::cuda_unique_ptr<void> mGammaDev;
- bert::cuda_unique_ptr<void> mBetaDev;
- size_t mLd{}; // leading dim
+ // members that participate in ser/deserialization
bert::WeightsWithOwnership mGamma;
bert::WeightsWithOwnership mBeta;
+ bert::WeightsWithOwnership mBias;
nvinfer1::DataType mType;
nvinfer1::DataType mCfgType;
-
+ int32_t mLd{}; // leading dim
bool mHasBias{};
+
+ // device-side
+ bert::cuda_unique_ptr<void> mGammaDev;
+ bert::cuda_unique_ptr<void> mBetaDev;
bert::cuda_unique_ptr<void> mBiasDev;
- bert::WeightsWithOwnership mBias;
+ // derived member from mCfgType
size_t mParamWordsize{};
- using IPluginV2::enqueue;
- using IPluginV2::getOutputDimensions;
- using IPluginV2::getWorkspaceSize;
- using IPluginV2Ext::configurePlugin;
+ // serialization data structures
+ std::vector<nvinfer1::PluginField> mDataToSerialize;
+ nvinfer1::PluginFieldCollection mFCToSerialize;
};
-class SkipLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator
+class SkipLayerNormPluginV3Creator : public nvinfer1::IPluginCreatorV3One
{
public:
- SkipLayerNormPluginDynamicCreator();
+ SkipLayerNormPluginV3Creator();
+ ~SkipLayerNormPluginV3Creator() override = default;
char const* getPluginName() const noexcept override;
@@ -119,61 +144,81 @@ class SkipLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator
nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
- nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+ IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override;
- nvinfer1::IPluginV2* deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept override;
-
- void setPluginNamespace(char const* pluginNamespace) noexcept override;
+ void setPluginNamespace(char const* libNamespace) noexcept;
char const* getPluginNamespace() const noexcept override;
private:
- static nvinfer1::PluginFieldCollection mFC;
- static std::vector<nvinfer1::PluginField> mPluginAttributes;
+ static PluginFieldCollection mFC;
+ static std::vector<PluginField> mPluginAttributes;
std::string mNamespace;
};
-class SkipLayerNormVarSeqlenPlugin : public nvinfer1::IPluginV2DynamicExt
+class SkipLayerNormVarSeqlenPluginV3 : public IPluginV3,
+ public IPluginV3OneCore,
+ public IPluginV3OneBuild,
+ public IPluginV3OneRuntime
{
public:
- SkipLayerNormVarSeqlenPlugin(const std::string name, const nvinfer1::DataType type, nvinfer1::Weights const& beta,
+ SkipLayerNormVarSeqlenPluginV3(const std::string name, const nvinfer1::DataType type, nvinfer1::Weights const& beta,
nvinfer1::Weights const& gamma, nvinfer1::Weights const& bias);
- SkipLayerNormVarSeqlenPlugin(const std::string name, void const* data, size_t length);
+ SkipLayerNormVarSeqlenPluginV3(const std::string name, void const* data, size_t length);
- // It doesn't make sense to make SkipLayerNormVarSeqlenPlugin without
+ // It doesn't make sense to make SkipLayerNormVarSeqlenPluginV3 without
// arguments, so we delete default constructor.
- SkipLayerNormVarSeqlenPlugin() = delete;
+ SkipLayerNormVarSeqlenPluginV3() = delete;
+
+ ~SkipLayerNormVarSeqlenPluginV3() override;
+
+ // IPluginV3 Methods
+ IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept override;
+
+ IPluginV3* clone() noexcept override;
+ // end of IPluginV3 Methods
+
+ // IPluginV3OneCore Methods
+ char const* getPluginName() const noexcept override;
+
+ char const* getPluginVersion() const noexcept override;
+
+ char const* getPluginNamespace() const noexcept override;
+
+ void setPluginNamespace(char const* pluginNamespace) noexcept;
+ // end of IPluginV3OneCore Methods
+
+ // IPluginV3Build Methods
+ int32_t getNbOutputs() const noexcept override;
- // IPluginV2DynamicExt Methods
- nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
- nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs,
- nvinfer1::IExprBuilder& exprBuilder) noexcept override;
bool supportsFormatCombination(
- int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
- void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
- nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
- size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs,
- nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
+ int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
+
+ int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs,
+ int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override;
+
+ int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out,
+ int32_t nbOutputs) noexcept override;
+
+ size_t getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
+
+ int32_t getOutputDataTypes(
+ DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override;
+ // end IPluginV3Build Methods
+
+ // IPluginV3Runtime Methods
int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
- // IPluginV2Ext Methods
- nvinfer1::DataType getOutputDataType(
- int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override;
+ int32_t onShapeChange(
+ PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
- // IPluginV2 Methods
- char const* getPluginType() const noexcept override;
- char const* getPluginVersion() const noexcept override;
- int32_t getNbOutputs() const noexcept override;
- int32_t initialize() noexcept override;
- void terminate() noexcept override;
- size_t getSerializationSize() const noexcept override;
- void serialize(void* buffer) const noexcept override;
- void destroy() noexcept override;
- void setPluginNamespace(char const* pluginNamespace) noexcept override;
- char const* getPluginNamespace() const noexcept override;
+ IPluginV3* attachToContext(IPluginResourceContext* context) noexcept override;
+
+ PluginFieldCollection const* getFieldsToSerialize() noexcept override;
+ // end IPluginV3Runtime Methods
private:
const std::string mLayerName;
@@ -181,7 +226,7 @@ class SkipLayerNormVarSeqlenPlugin : public nvinfer1::IPluginV2DynamicExt
bert::cuda_unique_ptr<void> mGammaDev;
bert::cuda_unique_ptr<void> mBetaDev;
- size_t mLd{}; // leading dim
+ int32_t mLd{}; // leading dim
bert::WeightsWithOwnership mGamma;
bert::WeightsWithOwnership mBeta;
nvinfer1::DataType mType;
@@ -193,29 +238,25 @@ class SkipLayerNormVarSeqlenPlugin : public nvinfer1::IPluginV2DynamicExt
size_t mParamWordsize{};
- using IPluginV2::enqueue;
- using IPluginV2::getOutputDimensions;
- using IPluginV2::getWorkspaceSize;
- using IPluginV2Ext::configurePlugin;
+ std::vector<nvinfer1::PluginField> mDataToSerialize;
+ nvinfer1::PluginFieldCollection mFCToSerialize;
};
-class SkipLayerNormVarSeqlenPluginCreator : public nvinfer1::IPluginCreator
+class SkipLayerNormVarSeqlenPluginV3Creator : public nvinfer1::IPluginCreatorV3One
{
public:
- SkipLayerNormVarSeqlenPluginCreator();
+ SkipLayerNormVarSeqlenPluginV3Creator();
+ ~SkipLayerNormVarSeqlenPluginV3Creator() override = default;
char const* getPluginName() const noexcept override;
char const* getPluginVersion() const noexcept override;
- nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
-
- nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+ PluginFieldCollection const* getFieldNames() noexcept override;
- nvinfer1::IPluginV2* deserializePlugin(
- char const* name, void const* serialData, size_t serialLength) noexcept override;
+ IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override;
- void setPluginNamespace(char const* pluginNamespace) noexcept override;
+ void setPluginNamespace(char const* libNamespace) noexcept;
char const* getPluginNamespace() const noexcept override;
diff --git a/plugin/skipLayerNormPlugin/skipLayerNormPluginLegacy.cpp b/plugin/skipLayerNormPlugin/skipLayerNormPluginLegacy.cpp
new file mode 100644
index 00000000..119cc8d0
--- /dev/null
+++ b/plugin/skipLayerNormPlugin/skipLayerNormPluginLegacy.cpp
@@ -0,0 +1,1078 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda.h>
+#if CUDA_VERSION >= 10010
+
+#include "NvInfer.h"
+#include "common/serialize.hpp"
+#include "skipLayerNormPluginLegacy.h"
+
+#include <cstring>
+#include <vector>
+
+using namespace nvinfer1;
+using namespace nvinfer1::plugin;
+using namespace nvinfer1::plugin::bert;
+
+// Clip plugin specific constants
+namespace
+{
+constexpr char const* kSKIP_LAYER_NORM_VERSION{"1"};
+constexpr char const* kSKIP_LAYER_NORM_NAME{"CustomSkipLayerNormPluginDynamic"};
+constexpr char const* kSKIP_LAYER_NORM_VAR_SEQLEN_VERSION{"2"};
+} // namespace
+
+// Static class fields initialization
+PluginFieldCollection SkipLayerNormPluginDynamicCreator::mFC{};
+std::vector<PluginField> SkipLayerNormPluginDynamicCreator::mPluginAttributes;
+
+PluginFieldCollection SkipLayerNormVarSeqlenPluginCreator::mFC{};
+std::vector<PluginField> SkipLayerNormVarSeqlenPluginCreator::mPluginAttributes;
+
+REGISTER_TENSORRT_PLUGIN(SkipLayerNormPluginDynamicCreator);
+REGISTER_TENSORRT_PLUGIN(SkipLayerNormVarSeqlenPluginCreator);
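
Keeping the IPluginV2 creators registered preserves the deserialization path for engines built with the older plugin versions. A minimal sketch of that path (the buffer arguments would come from a previously serialized engine; version "1" is the `kSKIP_LAYER_NORM_VERSION` defined above):

```cpp
#include "NvInferRuntime.h"
using namespace nvinfer1;

IPluginV2* deserializeLegacySkipLayerNorm(void const* serialData, size_t serialLength)
{
    IPluginCreator* creator
        = getPluginRegistry()->getPluginCreator("CustomSkipLayerNormPluginDynamic", "1", "");
    return creator != nullptr ? creator->deserializePlugin("skip_ln", serialData, serialLength) : nullptr;
}
```
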
+
+SkipLayerNormPluginDynamic::SkipLayerNormPluginDynamic(const std::string name, const DataType type, int32_t const ld,
+ Weights const& beta, Weights const& gamma, Weights const& bias)
+ : mLayerName(name)
+ , mGammaDev(nullptr)
+ , mBetaDev(nullptr)
+ , mLd(ld)
+ , mType(type)
+ , mBiasDev(nullptr)
+{
+ PLUGIN_VALIDATE(mType == nvinfer1::DataType::kFLOAT || mType == nvinfer1::DataType::kHALF
+ || mType == nvinfer1::DataType::kINT8);
+ // mCfgType is the dataType for the beta, gamma, and bias weights; always fp16 or fp32
+ // mType is the plugin IO datatype, can be int8
+ mCfgType = mType == DataType::kINT8 ? DataType::kHALF : mType;
+ mParamWordsize = getElementSize(mCfgType);
+
+ mBeta.convertAndCopy(beta, mCfgType);
+ mGamma.convertAndCopy(gamma, mCfgType);
+
+ mHasBias = (bias.values != nullptr);
+ if (mHasBias)
+ {
+ mBias.convertAndCopy(bias, mCfgType);
+ }
+
+ copyToDevice(mGamma, getWeightsSize(mGamma, mCfgType), mGammaDev);
+ copyToDevice(mBeta, getWeightsSize(mBeta, mCfgType), mBetaDev);
+ if (mHasBias)
+ {
+ copyToDevice(mBias, getWeightsSize(mBias, mCfgType), mBiasDev);
+ }
+}
+
+SkipLayerNormPluginDynamic::SkipLayerNormPluginDynamic(const std::string name, void const* data, size_t length)
+ : mLayerName(name)
+ , mGammaDev(nullptr)
+ , mBetaDev(nullptr)
+ , mBiasDev(nullptr)
+{
+ BERT_DEBUG_MSG("SkipLayerNormPluginDynamic deserialize");
+
+ // Deserialize in the same order as serialization
+ deserialize_value(&data, &length, &mType);
+ deserialize_value(&data, &length, &mCfgType);
+ deserialize_value(&data, &length, &mLd);
+ deserialize_value(&data, &length, &mHasBias);
+
+ PLUGIN_VALIDATE(mCfgType == nvinfer1::DataType::kFLOAT || mCfgType == nvinfer1::DataType::kHALF);
+ mParamWordsize = getElementSize(mCfgType);
+
+ char const* d = static_cast<char const*>(data);
+ mBeta.convertAndCopy(d, mLd, mCfgType);
+ mGamma.convertAndCopy(d, mLd, mCfgType);
+ if (mHasBias)
+ {
+ mBias.convertAndCopy(d, mLd, mCfgType);
+ }
+
+ copyToDevice(mGamma, getWeightsSize(mGamma, mCfgType), mGammaDev);
+ copyToDevice(mBeta, getWeightsSize(mBeta, mCfgType), mBetaDev);
+ if (mHasBias)
+ {
+ copyToDevice(mBias, getWeightsSize(mBias, mCfgType), mBiasDev);
+ }
+}
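
The deserializing constructor consumes the buffer in exactly the order `serialize` writes it, so the length must match what `getSerializationSize` reports: the two `DataType` scalars, `mLd`, the bias flag, then two (or three, with bias) parameter vectors of `mLd` words each. A sketch of that accounting, mirroring the formula in the removed `getSerializationSize` above (assumes `mLd` remains `size_t` in the legacy header, matching the removed declaration):

```cpp
#include <cstddef>
#include "NvInferRuntime.h"

// Expected legacy buffer layout:
//   mType | mCfgType | mLd | mHasBias | beta[mLd] | gamma[mLd] | (bias[mLd])
size_t expectedSerializationSize(size_t ld, size_t paramWordsize, bool hasBias)
{
    size_t const biasSize = hasBias ? ld * paramWordsize : 0;
    return 2 * paramWordsize * ld        // beta + gamma device words
        + 2 * sizeof(nvinfer1::DataType) // mType + mCfgType
        + sizeof(size_t)                 // mLd
        + sizeof(bool)                   // mHasBias
        + biasSize;
}
```
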
+
+// IPluginV2DynamicExt Methods
+IPluginV2DynamicExt* SkipLayerNormPluginDynamic::clone() const noexcept
+{
+ try
+ {
+ BERT_DEBUG_MSG("SkipLayerNormPluginDynamic clone");
+
+ auto* p = new SkipLayerNormPluginDynamic(mLayerName, mType, mLd, mBeta, mGamma, mBias);
+ p->initialize();
+ p->setPluginNamespace(mNamespace.c_str());
+ return p;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return nullptr;
+}
+
+DimsExprs SkipLayerNormPluginDynamic::getOutputDimensions(
+ int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept
+{
+ try
+ {
+ PLUGIN_VALIDATE(inputs != nullptr);
+ PLUGIN_VALIDATE(nbInputs == 2);
+ PLUGIN_VALIDATE(outputIndex == 0);
+ PLUGIN_VALIDATE(inputs[0].nbDims == inputs[1].nbDims);
+ return inputs[0];
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return DimsExprs{};
+}
+
+bool SkipLayerNormPluginDynamic::supportsFormatCombination(
+ int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
+{
+ try
+ {
+ PLUGIN_VALIDATE(inOut != nullptr);
+ PLUGIN_VALIDATE(nbInputs == 2);
+ PLUGIN_VALIDATE(nbOutputs == 1);
+ PLUGIN_VALIDATE(pos >= 0 && pos < (nbInputs + nbOutputs));
+
+ PluginTensorDesc const& in = inOut[pos];
+ if (pos == 0)
+ {
+ // Since H = W = 1, we can report CHWx for any x
+ if (mType == DataType::kINT8)
+ {
+ // won't work for hiddensize too small!
+ TensorFormat myFmt = TensorFormat::kCHW32;
+ if (mLd < 32)
+ {
+ myFmt = TensorFormat::kCHW4;
+ BERT_DEBUG_VALUE("SkipLayerNormDQQ: TensorFormat CHW4 for LD=", mLd);
+ }
+ else
+ {
+ BERT_DEBUG_VALUE("SkipLayerNormDQQ: TensorFormat CHW32 for LD=", mLd);
+ }
+ // TODO do we need to check if the vectorization divides mLd?
+ return ((in.type == mType) && (in.format == myFmt));
+ }
+ return (in.type == mType) && (in.format == TensorFormat::kLINEAR);
+ }
+ PluginTensorDesc const& prev = inOut[pos - 1];
+
+ return in.type == prev.type && in.format == prev.format;
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+ return false;
+}
+
+void SkipLayerNormPluginDynamic::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+ DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept
+{
+ try
+ {
+ BERT_DEBUG_MSG("SkipLayerNormPluginDynamic configurePlugin");
+
+ // Validate input arguments
+ PLUGIN_VALIDATE(inputs != nullptr);
+ PLUGIN_VALIDATE(outputs != nullptr);
+ PLUGIN_VALIDATE(nbOutputs == 1);
+ PLUGIN_VALIDATE(nbInputs == 2);
+ if (mType == DataType::kFLOAT || mType == DataType::kHALF)
+ {
+ PLUGIN_VALIDATE(mType == inputs[0].desc.type);
+ PLUGIN_VALIDATE(mType == inputs[1].desc.type);
+ }
+ else
+ {
+ PLUGIN_VALIDATE(mType == inputs[0].desc.type || DataType::kFLOAT == inputs[0].desc.type);
+ PLUGIN_VALIDATE(mType == inputs[1].desc.type || DataType::kFLOAT == inputs[1].desc.type);
+ }
+ auto const& inDims0 = inputs[0].desc.dims;
+ auto const& inDims1 = inputs[1].desc.dims;
+ PLUGIN_VALIDATE(inDims0.nbDims == inDims1.nbDims);
+
+ PLUGIN_VALIDATE(std::equal(inDims0.d, inDims0.d + inDims0.nbDims, inDims1.d));
+
+ PLUGIN_VALIDATE(inDims0.nbDims == 5);
+ mLd = inDims0.d[HDIM]; // hiddensize
+ PLUGIN_VALIDATE(mLd != 0U);
+ PLUGIN_VALIDATE(inDims0.d[3] == 1);
+ PLUGIN_VALIDATE(inDims0.d[4] == 1);
+
+ mCfgType = inputs[0].desc.type == DataType::kINT8 ? DataType::kHALF : inputs[0].desc.type;
+
+ auto const paramType = mCfgType == DataType::kINT8 ? DataType::kHALF : mCfgType;
+ mParamWordsize = getElementSize(paramType);
+ }
+ catch (std::exception const& e)
+ {
+ caughtError(e);
+ }
+}
+
+size_t SkipLayerNormPluginDynamic::getWorkspaceSize(
+ PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
+{
+ return 0;
+}
+
+int32_t SkipLayerNormPluginDynamic::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc,
+ void const* const* inputs, void* const* outputs, void* /* workspace */, cudaStream_t stream) noexcept
+{
+ int32_t status = -1;
+ try
+ {
+ PLUGIN_VALIDATE(inputDesc != nullptr && outputDesc != nullptr && inputs != nullptr && outputs != nullptr);
+
+ int32_t const inputVolume = volume(inputDesc[0].dims);
+ DataType iType = inputDesc->type;
+
+ // Our plugin outputs only one tensor
+ // Launch CUDA kernel wrapper and save its return value
+ if (iType == DataType::kFLOAT)
+ {
+ auto const* const input = static_cast<float const*>(inputs[0]);
+ auto const* const skip = static_cast<float const*>(inputs[1]);
+ auto* output = static_cast