diff --git a/.gitignore b/.gitignore
index c98f0fb8..20d14a0f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,3 +130,6 @@ dmypy.json
# Pyre type checker
.pyre/
+
+.vscode
+detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o
diff --git a/README.md b/README.md
index 35c8b087..cfc32882 100644
--- a/README.md
+++ b/README.md
@@ -7,8 +7,9 @@ Tensorflow implementation of DETR : Object Detection with Transformers, includin
* [3. Tutorials](#tutorials)
* [4. Finetuning](#finetuning)
* [5. Training](#training)
-* [5. inference](#inference)
-* [6. Acknowledgement](#acknowledgement)
+* [6. Inference](#inference)
+* [7. Inference with TensorRT](#inference-with-tensorrt)
+* [8. Acknowledgement](#acknowledgement)
DETR paper: https://arxiv.org/pdf/2005.12872.pdf
@@ -152,6 +153,82 @@ python webcam_inference.py
+## Inference with TensorRT
+
+### Requirements:
+```
+cmake >= 3.8
+TensorRT 8
+```
+To install TensorRT 8, follow [NVIDIA TensorRT official installation guide](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html).
+
+Python package requirements:
+```
+onnx
+tf2onnx
+```
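+These can be installed with pip: `pip install onnx tf2onnx`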
+
+### Custom plugin for Deformable DETR
+Deformable DETR uses a custom Im2Col operation in its Transformer layers. This operation is not supported by TensorRT, so we need to build a TensorRT custom plugin from source.
+
+```
+cd detr_tensorrt/plugins/ms_deform_im2col
+mkdir build && cd build
+cmake .. \
+ -DTRT_LIB=/path/to/tensorrt/lib/ \
+ -DTRT_INCLUDE=/path/to/tensorrt/include/ \
+ -DCUDA_ARCH_SM=/your_gpu_cuda_arch/
+make -j
+```
+For more detail, see: `detr_tensorrt/plugins/ms_deform_im2col/README.txt`
+
+Parameters:
+- `-DTRT_LIB`: Path to TensorRT lib. It could be `YOUR_TENSORRT_DIR/lib` or `/usr/lib/x86_64-linux-gnu`
+- `-DTRT_INCLUDE`: Path to TensorRT C++ include. It could be `YOUR_TENSORRT_DIR/include` or `/usr/include/x86_64-linux-gnu`
+- `-DCUDA_ARCH_SM`: Compute capability of your NVIDIA GPU. Example: `70` for Tesla V100. Check [here](https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/) for other GPUs.
+
+### Workflow
+Tensorflow model --> ONNX --> TensorRT serialized engine
+
+#### Export Tensorflow graph to ONNX graph:
+
+For each model (detr/deformable-detr), run:
+```
+python3 detr_tensorrt/export_onnx.py MODEL_NAME
+ [--input_shape H W]
+ [--save_to DIR_TO_SAVE_ONNX_FILE]
+```
+Parameters:
+- `--input_shape`: image height and width, default: 1280 1920
+- `--save_to`: directory where the ONNX file will be saved. Default: `./weights/MODEL_NAME/MODEL_NAME_trt/`
+
+#### Convert ONNX model to TensorRT serialized engine:
+```
+python3 detr_tensorrt/onnx2engine.py MODEL_NAME
+ [--precision PRECISION]
+ [--onnx_dir ONNX_DIR]
+ [--verbose]
+```
+Parameters:
+- `--precision`: precision of model weights: FP32, FP16 or MIX. MIX precision gives TensorRT the freedom to run each layer in either FP32 or FP16. In most cases, there is no big difference in inference time between FP16 and MIX (see the sketch after this list).
+- `--onnx_dir`: directory containing the ONNX file to be converted to TensorRT engine. The required ONNX file must be named `MODEL_NAME.onnx`. Default: `./weights/MODEL_NAME/MODEL_NAME_trt/`
+- `--verbose`: Print out TensorRT log of all levels
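+
+For reference, the precision modes map onto TensorRT builder flags roughly as follows (a sketch mirroring `detr_tensorrt/TRTEngineBuilder.py` below):
+
+```
+# FP16: allow FP16 kernels and force them with strict types
+config.set_flag(trt.BuilderFlag.FP16)
+config.set_flag(trt.BuilderFlag.STRICT_TYPES)
+# MIX: allow FP16 kernels but let TensorRT choose per layer (no strict types)
+config.set_flag(trt.BuilderFlag.FP16)
+# FP32: no extra flags
+```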
+
+The TensorRT serialized engine will be saved as `ONNX_DIR/MODEL_NAME_PRECISION.engine`.
+
+### Run inference
+An example of inference with a test image: `images/test.jpeg`
+
+```
+python tensorrt_inference.py --engine_path ENGINE_PATH
+```
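+
+As a sketch, the serialized engine can also be loaded programmatically with the `TRTExecutor` helper from `detr_tensorrt/TRTExecutor.py` (the path and binding index shown here are illustrative):
+
+```
+import numpy as np
+from detr_tensorrt.TRTExecutor import TRTExecutor
+
+model = TRTExecutor("weights/detr/detr_trt/detr_fp16.engine")
+model.print_bindings_info()                     # list input/output bindings
+np.copyto(model.inputs[0].host, image.ravel())  # image: preprocessed float32 array
+model.execute()                                 # H2D copy, inference, D2H copy
+pred_boxes = model.outputs[0].host
+```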
+
+Inference time in milliseconds:
+| | DETR | Deformable DETR |
+|---------------|------|-----------------|
+| Tensorflow | 100 | 160 |
+| TensorRT FP32 | 60 | 100 |
+| TensorRT FP16 | 27 | 60 |
## Acknowledgement
diff --git a/detr_tensorrt/TRTEngineBuilder.py b/detr_tensorrt/TRTEngineBuilder.py
new file mode 100644
index 00000000..cd82c35e
--- /dev/null
+++ b/detr_tensorrt/TRTEngineBuilder.py
@@ -0,0 +1,87 @@
+import tensorrt as trt
+import os
+
+TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+network_creation_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+
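+# Convert a size in GiB to bytes: val * 2**30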
+def GiB(val):
+ return val * 1 << 30
+
+class TRTEngineBuilder():
+ """
+    Works with TensorRT 8. Should also work with TensorRT 7.2.3 (not tested).
+
+    Helper class to build a TensorRT engine from an ONNX graph file (including weights).
+    The graph must have a defined input shape. For more detail, please see the TensorRT Developer Guide:
+ https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#python_topics
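+
+    Example (sketch):
+        builder = TRTEngineBuilder("model.onnx", FP16_allowed=True, strict_type=True)
+        builder.export_engine("model_fp16.engine")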
+ """
+ def __init__(self, onnx_file_path, FP16_allowed=False, INT8_allowed=False, strict_type=False, calibrator=None, logger=TRT_LOGGER):
+ """
+ Parameters:
+ -----------
+ onnx_file_path: str
+ path to ONNX graph file
+ FP16_allowed: bool
+ Enable FP16 precision for engine builder
+ INT8_allowed: bool
+            Enable INT8 precision for engine builder; the user must also provide a calibrator
+ strict_type: bool
+            Force the builder to respect the requested precision (strict type constraints)
+ calibrator: extended instance from tensorrt.IInt8Calibrator
+ Used for INT8 quantization
+ """
+ self.FP16_allowed = FP16_allowed
+ self.INT8_allowed = INT8_allowed
+ self.onnx_file_path = onnx_file_path
+ self.calibrator = calibrator
+ self.max_workspace_size = GiB(8)
+ self.strict_type = strict_type
+ self.logger = logger
+
+ def set_workspace_size(self, workspace_size_GiB):
+ self.max_workspace_size = GiB(workspace_size_GiB)
+
+ def get_engine(self):
+ """
+        Set up the engine builder, read the ONNX graph and build the TensorRT engine.
+ """
+ global network_creation_flag
+ with trt.Builder(self.logger) as builder, builder.create_network(network_creation_flag) as network, trt.OnnxParser(network, self.logger) as parser:
+ builder.max_batch_size = 1
+ config = builder.create_builder_config()
+ config.max_workspace_size = self.max_workspace_size
+ # FP16
+ if self.FP16_allowed:
+ config.set_flag(trt.BuilderFlag.FP16)
+ # INT8
+ if self.INT8_allowed:
+ raise NotImplementedError()
+ if self.strict_type:
+ config.set_flag(trt.BuilderFlag.STRICT_TYPES)
+
+ # Load and build model
+ with open(self.onnx_file_path, 'rb') as model:
+ if not parser.parse(model.read()):
+                print('ERROR: Failed to parse the ONNX file.')
+                for error in range(parser.num_errors):
+                    print(parser.get_error(error))
+ return None
+ else:
+ print("ONNX file is loaded")
+ print("Building engine...")
+ engine = builder.build_engine(network, config)
+ if engine is None:
+ raise Exception("TRT export engine error. Check log")
+ print("Engine built")
+ return engine
+
+ def export_engine(self, engine_path):
+ """Seriazlize TensorRT engine"""
+ engine = self.get_engine()
+ assert engine is not None, "Error while parsing engine from ONNX"
+ with open(engine_path, "wb") as f:
+ print("Serliaze and save as engine: " + engine_path)
+ f.write(engine.serialize())
+ print("Engine exported")
+
+
diff --git a/detr_tensorrt/TRTExecutor.py b/detr_tensorrt/TRTExecutor.py
new file mode 100644
index 00000000..8fce6cc7
--- /dev/null
+++ b/detr_tensorrt/TRTExecutor.py
@@ -0,0 +1,131 @@
+import ctypes
+import pycuda.autoinit as cuda_init
+from detr_tensorrt.trt_helper import *
+import tensorrt as trt
+
+TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+# trt.init_libnvinfer_plugins(None, "")
+
+class TRTExecutor():
+ """
+ A helper class to execute a TensorRT engine.
+
+ Attributes:
+ -----------
+ stream: pycuda.driver.Stream
+ engine: tensorrt.ICudaEngine
+ context: tensorrt.IExecutionContext
+ inputs/outputs: list[HostDeviceMem]
+ see trt_helper.py
+ bindings: list[int]
+ pointers in GPU for each input/output of the engine
+ dict_inputs/dict_outputs: dict[str, HostDeviceMem]
+ key = input node name
+ value = HostDeviceMem of corresponding binding
+
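+    Example (sketch; input/output names depend on the engine):
+        model = TRTExecutor("detr_fp16.engine")
+        np.copyto(model.dict_inputs["input_image"].host, image.ravel())
+        model.execute()
+        pred_boxes = model.dict_outputs["pred_boxes"].host
+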
+ """
+ def __init__(self, engine_path=None, has_dynamic_shape=False, stream=None, engine=None):
+ """
+ Parameters:
+ ----------
+ engine_path: str
+ path to serialized TensorRT engine
+ has_dynamic_shape: bool
+ stream: pycuda.driver.Stream
+ if None, one will be created by allocate_buffers function
+ """
+ self.stream = stream
+ if engine_path is not None:
+ with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
+ print("Reading engine ...")
+ self.engine = runtime.deserialize_cuda_engine(f.read())
+ assert self.engine is not None, "Read engine failed"
+ print("Engine loaded")
+ elif engine is not None:
+ self.engine = engine
+ self.context = self.engine.create_execution_context()
+ if not has_dynamic_shape:
+ self.inputs, self.outputs, self.bindings, self.stream = allocate_buffers(self.context, self.stream)
+ self.dict_inputs = {mem_obj.name:mem_obj for mem_obj in self.inputs}
+ self.dict_outputs = {mem_obj.name:mem_obj for mem_obj in self.outputs}
+
+ def print_bindings_info(self):
+ print("ID / Name / isInput / shape / dtype")
+ for i in range(self.engine.num_bindings):
+ print(f"Binding: {i}, name: {self.engine.get_binding_name(i)}, input: {self.engine.binding_is_input(i)}, shape: {self.engine.get_binding_shape(i)}, dtype: {self.engine.get_binding_dtype(i)}")
+
+ def execute(self):
+ do_inference_async(
+ self.context,
+ bindings=self.bindings,
+ inputs=self.inputs,
+ outputs=self.outputs,
+ stream=self.stream
+ )
+
+ def set_binding_shape(self, binding:int, shape:tuple):
+ self.context.set_binding_shape(binding, shape)
+
+ def allocate_mem(self):
+ self.inputs, self.outputs, self.bindings, self.stream = allocate_buffers(self.context, self.stream)
+ self.dict_inputs = {mem_obj.name:mem_obj for mem_obj in self.inputs}
+ self.dict_outputs = {mem_obj.name:mem_obj for mem_obj in self.outputs}
+
+class TRTExecutor_Sync():
+ """
+ A helper class to execute a TensorRT engine.
+
+ Attributes:
+ -----------
+ engine: tensorrt.ICudaEngine
+ context: tensorrt.IExecutionContext
+ inputs/outputs: list[HostDeviceMem]
+ see trt_helper.py
+ bindings: list[int]
+ pointers in GPU for each input/output of the engine
+ dict_inputs/dict_outputs: dict[str, HostDeviceMem]
+ key = input node name
+ value = HostDeviceMem of corresponding binding
+
+ """
+ def __init__(self, engine_path=None, has_dynamic_shape=False, engine=None):
+ """
+ Parameters:
+ ----------
+ engine_path: str
+ path to serialized TensorRT engine
+ has_dynamic_shape: bool
+ """
+ if engine_path is not None:
+ with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
+ print("Reading engine ...")
+ self.engine = runtime.deserialize_cuda_engine(f.read())
+ assert self.engine is not None, "Read engine failed"
+ print("Engine loaded")
+ elif engine is not None:
+ self.engine = engine
+ self.context = self.engine.create_execution_context()
+ if not has_dynamic_shape:
+ self.inputs, self.outputs, self.bindings, self.stream = allocate_buffers(self.context, is_async=False)
+ self.dict_inputs = {mem_obj.name:mem_obj for mem_obj in self.inputs}
+ self.dict_outputs = {mem_obj.name:mem_obj for mem_obj in self.outputs}
+
+ def print_bindings_info(self):
+ print("ID / Name / isInput / shape / dtype")
+ for i in range(self.engine.num_bindings):
+ print(f"Binding: {i}, name: {self.engine.get_binding_name(i)}, input: {self.engine.binding_is_input(i)}, shape: {self.engine.get_binding_shape(i)}, dtype: {self.engine.get_binding_dtype(i)}")
+
+ def execute(self):
+ do_inference(
+ self.context,
+ bindings=self.bindings,
+ inputs=self.inputs,
+ outputs=self.outputs,
+ )
+
+ def set_binding_shape(self, binding:int, shape:tuple):
+ self.context.set_binding_shape(binding, shape)
+
+
+
+
diff --git a/detr_tensorrt/common.py b/detr_tensorrt/common.py
new file mode 100644
index 00000000..05a496cc
--- /dev/null
+++ b/detr_tensorrt/common.py
@@ -0,0 +1,199 @@
+#
+# Copyright 1993-2020 NVIDIA Corporation. All rights reserved.
+#
+# NOTICE TO LICENSEE:
+#
+# This source code and/or documentation ("Licensed Deliverables") are
+# subject to NVIDIA intellectual property rights under U.S. and
+# international Copyright laws.
+#
+# These Licensed Deliverables contained herein is PROPRIETARY and
+# CONFIDENTIAL to NVIDIA and is being provided under the terms and
+# conditions of a form of NVIDIA software license agreement by and
+# between NVIDIA and Licensee ("License Agreement") or electronically
+# accepted by Licensee. Notwithstanding any terms or conditions to
+# the contrary in the License Agreement, reproduction or disclosure
+# of the Licensed Deliverables to any third party without the express
+# written consent of NVIDIA is prohibited.
+#
+# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
+# LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
+# SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
+# PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
+# NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
+# DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
+# NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
+# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
+# LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
+# SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
+# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
+# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
+# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
+# OF THESE LICENSED DELIVERABLES.
+#
+# U.S. Government End Users. These Licensed Deliverables are a
+# "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
+# 1995), consisting of "commercial computer software" and "commercial
+# computer software documentation" as such terms are used in 48
+# C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
+# only as a commercial end item. Consistent with 48 C.F.R.12.212 and
+# 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
+# U.S. Government End Users acquire the Licensed Deliverables with
+# only those rights set forth herein.
+#
+# Any use of the Licensed Deliverables in individual and commercial
+# software must include, in the user documentation and internal
+# comments to the code, the above Disclaimer and U.S. Government End
+# Users Notice.
+#
+
+from itertools import chain
+import argparse
+import os
+
+import pycuda.driver as cuda
+import pycuda.autoinit
+import numpy as np
+
+import tensorrt as trt
+
+try:
+    # Python 2 does not have FileNotFoundError
+ FileNotFoundError
+except NameError:
+ FileNotFoundError = IOError
+
+EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+
+def GiB(val):
+ return val * 1 << 30
+
+
+def add_help(description):
+ parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ args, _ = parser.parse_known_args()
+
+
+def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]):
+ '''
+ Parses sample arguments.
+
+ Args:
+ description (str): Description of the sample.
+ subfolder (str): The subfolder containing data relevant to this sample
+        find_files (List[str]): A list of filenames to find. Each filename will be replaced with an absolute path.
+
+ Returns:
+ str: Path of data directory.
+ '''
+
+ # Standard command-line arguments for all samples.
+ kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
+ parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument("-d", "--datadir", help="Location of the TensorRT sample data directory, and any additional data directories.", action="append", default=[kDEFAULT_DATA_ROOT])
+ args, _ = parser.parse_known_args()
+
+ def get_data_path(data_dir):
+ # If the subfolder exists, append it to the path, otherwise use the provided path as-is.
+ data_path = os.path.join(data_dir, subfolder)
+ if not os.path.exists(data_path):
+ print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
+ data_path = data_dir
+ # Make sure data directory exists.
+ if not (os.path.exists(data_path)):
+ print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
+ return data_path
+
+ data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
+ return data_paths, locate_files(data_paths, find_files)
+
+def locate_files(data_paths, filenames):
+ """
+ Locates the specified files in the specified data directories.
+ If a file exists in multiple data directories, the first directory is used.
+
+ Args:
+ data_paths (List[str]): The data directories.
+        filenames (List[str]): The names of the files to find.
+
+ Returns:
+ List[str]: The absolute paths of the files.
+
+ Raises:
+ FileNotFoundError if a file could not be located.
+ """
+ found_files = [None] * len(filenames)
+ for data_path in data_paths:
+ # Find all requested files.
+ for index, (found, filename) in enumerate(zip(found_files, filenames)):
+ if not found:
+ file_path = os.path.abspath(os.path.join(data_path, filename))
+ if os.path.exists(file_path):
+ found_files[index] = file_path
+
+ # Check that all files were found
+ for f, filename in zip(found_files, filenames):
+ if not f or not os.path.exists(f):
+ raise FileNotFoundError("Could not find {:}. Searched in data paths: {:}".format(filename, data_paths))
+ return found_files
+
+# Simple helper data class that's a little nicer to use than a 2-tuple.
+class HostDeviceMem(object):
+ def __init__(self, host_mem, device_mem):
+ self.host = host_mem
+ self.device = device_mem
+
+ def __str__(self):
+ return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
+
+ def __repr__(self):
+ return self.__str__()
+
+# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
+def allocate_buffers(engine):
+ inputs = []
+ outputs = []
+ bindings = []
+ stream = cuda.Stream()
+ for binding in engine:
+ size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
+ dtype = trt.nptype(engine.get_binding_dtype(binding))
+ # Allocate host and device buffers
+ host_mem = cuda.pagelocked_empty(size, dtype)
+ device_mem = cuda.mem_alloc(host_mem.nbytes)
+ # Append the device buffer to device bindings.
+ bindings.append(int(device_mem))
+ # Append to the appropriate list.
+ if engine.binding_is_input(binding):
+ inputs.append(HostDeviceMem(host_mem, device_mem))
+ else:
+ outputs.append(HostDeviceMem(host_mem, device_mem))
+ return inputs, outputs, bindings, stream
+
+# This function is generalized for multiple inputs/outputs.
+# inputs and outputs are expected to be lists of HostDeviceMem objects.
+def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
+ # Transfer input data to the GPU.
+ [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+ # Run inference.
+ context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
+ # Transfer predictions back from the GPU.
+ [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
+ # Synchronize the stream
+ stream.synchronize()
+ # Return only the host outputs.
+ return [out.host for out in outputs]
+
+# This function is generalized for multiple inputs/outputs for full dimension networks.
+# inputs and outputs are expected to be lists of HostDeviceMem objects.
+def do_inference_v2(context, bindings, inputs, outputs, stream):
+ # Transfer input data to the GPU.
+ [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+ # Run inference.
+ context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
+ # Transfer predictions back from the GPU.
+ [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
+ # Synchronize the stream
+ stream.synchronize()
+ # Return only the host outputs.
+ return [out.host for out in outputs]
diff --git a/detr_tensorrt/export_onnx.py b/detr_tensorrt/export_onnx.py
new file mode 100644
index 00000000..c5f9acc6
--- /dev/null
+++ b/detr_tensorrt/export_onnx.py
@@ -0,0 +1,89 @@
+import tensorflow as tf
+import numpy as np
+import os
+import onnx
+import tf2onnx
+from pathlib import Path
+
+from detr_tf.training_config import TrainingConfig, training_config_parser
+
+from detr_tf.networks.detr import get_detr_model
+from detr_tf.networks.deformable_detr import get_deformable_detr_model
+
+def get_model(config, args):
+ if args.model == "detr":
+ print("Loading detr...")
+ # Load the model with the new layers to finetune
+ model = get_detr_model(config, include_top=True, weights="detr")
+ config.background_class = 91
+ m_output_names = ["pred_logits", "pred_boxes"]
+ use_mask = True
+ # model.summary()
+ # return model
+ elif args.model == "deformable-detr":
+ print("Loading deformable-detr...")
+ model = get_deformable_detr_model(config, include_top=True, weights="deformable-detr")
+ m_output_names = ["bbox_pred_logits", "bbox_pred_boxes"]
+ # [print(name, model.output[name]) for name in model.output]
+ # model.summary()
+ use_mask = False
+ else:
+ raise NotImplementedError()
+    # Remove auxiliary outputs
+ input_image = tf.keras.Input(args.input_shape, batch_size=1, name="input_image")
+ if use_mask:
+        mask = tf.keras.Input(args.input_shape[:2] + [1], batch_size=1, name="input_mask")
+ m_inputs = (input_image, mask)
+ else:
+ m_inputs = (input_image, )
+ all_outputs = model(m_inputs, training=False)
+
+ m_outputs = {
+ name:tf.identity(all_outputs[name], name=name)
+ for name in m_output_names if name in all_outputs}
+ [print(m_outputs[name]) for name in m_outputs]
+
+ model = tf.keras.Model(m_inputs, m_outputs, name=args.model)
+ model.summary()
+ return model
+
+
+if __name__ == "__main__":
+ physical_devices = tf.config.list_physical_devices('GPU')
+ if len(physical_devices) == 1:
+ tf.config.experimental.set_memory_growth(physical_devices[0], True)
+
+ config = TrainingConfig()
+ parser = training_config_parser()
+ parser.add_argument("model", type=str, default="deformable-detr", help="One of 'detr', or 'deformable-detr'")
+ parser.add_argument("--input_shape", type=int, default=[1280, 1920], nargs=2, help="ex: 1280 1920 3")
+ parser.add_argument('--save_to', type=str, default=None, help="Path to save ONNX file")
+ args = parser.parse_args()
+ config.update_from_args(args)
+
+ args.input_shape.append(3) # C = 3
+
+ if args.save_to is None:
+ args.save_to = os.path.join("weights", args.model, args.model + "_trt")
+
+ # === Load model
+ model = get_model(config, args)
+ # === Save model to pb file
+ if not os.path.isdir(args.save_to):
+ os.makedirs(args.save_to)
+
+ # === Save onnx file
+ input_spec = [tf.TensorSpec.from_tensor(tensor) for tensor in model.input]
+ # print(input_spec)
+ output_path = os.path.join(args.save_to, args.model + ".onnx")
+ model_proto, _ = tf2onnx.convert.from_keras(
+ model, input_signature=input_spec,
+ opset=13, output_path=output_path)
+ print("===== Inputs =======")
+ [print(n.name) for n in model_proto.graph.input]
+ print("===== Outputs =======")
+ [print(n.name) for n in model_proto.graph.output]
+
+
+
+
diff --git a/detr_tensorrt/inference.py b/detr_tensorrt/inference.py
new file mode 100644
index 00000000..f4f215c1
--- /dev/null
+++ b/detr_tensorrt/inference.py
@@ -0,0 +1,76 @@
+import numpy as np
+import os
+import cv2
+import argparse
+
+from detr_tf import bbox
+from detr_tensorrt.TRTExecutor import TRTExecutor
+from scipy.special import softmax
+
+def normalized_images(image, normalized_method="torch_resnet"):
+ """ Normalized images. torch_resnet is used on finetuning
+ since the weights are based on the original paper training code
+ from pytorch. tf_resnet is used when training from scratch with a
+ resnet50 traine don tensorflow.
+ """
+ if normalized_method == "torch_resnet":
+ channel_avg = np.array([0.485, 0.456, 0.406])
+ channel_std = np.array([0.229, 0.224, 0.225])
+ image = (image / 255.0 - channel_avg) / channel_std
+ return image.astype(np.float32)
+ elif normalized_method == "tf_resnet":
+ mean = [103.939, 116.779, 123.68]
+ image = image[..., ::-1]
+ image = image - mean
+ return image.astype(np.float32)
+ else:
+ raise Exception("Can't handler thid normalized method")
+
+def sigmoid(x):
+ return 1/(1 + np.exp(-x))
+
+def get_model_inference(m_outputs: dict, background_class, bbox_format="xy_center", threshold=None):
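+    """
+    Post-process raw model outputs into final detections.
+
+    DETR (softmax activation) drops predictions of the background class;
+    Deformable DETR (sigmoid activation) keeps predictions whose score
+    exceeds `threshold` (default 0.1). Returns (bbox, labels, scores),
+    with boxes converted to `bbox_format`.
+    """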
+
+ #print('get model inference', [key for key in m_outputs])
+
+ # Detr or deformable
+ predicted_bbox = m_outputs["pred_boxes"][0] if "pred_boxes" in m_outputs else m_outputs["bbox_pred_boxes"][0]
+ predicted_labels = m_outputs["pred_logits"][0] if "pred_logits" in m_outputs else m_outputs["bbox_pred_logits"][0]
+ activation = "softmax" if "pred_boxes" in m_outputs else "sigmoid"
+
+ if activation == "softmax": # Detr
+ softmax_scores = softmax(predicted_labels, axis=-1)
+ predicted_scores = np.max(softmax_scores, axis=-1)
+ predicted_labels = np.argmax(softmax_scores, axis=-1)
+ bool_filter = predicted_labels != background_class
+ else: # Deformable Detr
+ sigmoid_scores = sigmoid(predicted_labels)
+ predicted_scores = np.max(sigmoid_scores, axis=-1)
+ predicted_labels = np.argmax(sigmoid_scores, axis=-1)
+ threshold = 0.1 if threshold is None else threshold
+ bool_filter = predicted_scores > threshold
+
+
+ predicted_scores = predicted_scores[bool_filter]
+ predicted_labels = predicted_labels[bool_filter]
+ predicted_bbox = predicted_bbox[bool_filter]
+
+ if bbox_format == "xy_center":
+ predicted_bbox = predicted_bbox
+ elif bbox_format == "xyxy":
+ predicted_bbox = bbox.xcycwh_to_xy_min_xy_max(predicted_bbox)
+ elif bbox_format == "yxyx":
+ predicted_bbox = bbox.xcycwh_to_yx_min_yx_max(predicted_bbox)
+ else:
+ raise NotImplementedError()
+
+ return predicted_bbox, predicted_labels, predicted_scores
+
+def main(engine_path):
+ pass
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--engine_path", type=str, required=True)
+ args = parser.parse_args()
+ main(**vars(args))
\ No newline at end of file
diff --git a/detr_tensorrt/onnx2engine.py b/detr_tensorrt/onnx2engine.py
new file mode 100644
index 00000000..c831acf4
--- /dev/null
+++ b/detr_tensorrt/onnx2engine.py
@@ -0,0 +1,250 @@
+import argparse
+import ctypes
+import tensorrt as trt
+import os
+import onnx
+import numpy as np
+import onnx_graphsurgeon as gs
+
+from TRTEngineBuilder import TRTEngineBuilder, TRT_LOGGER
+from common import GiB
+
+
+def print_graph_io(graph):
+ # Print inputs:
+ print(" ===== Inputs =====")
+ for i in graph.inputs:
+ print(i)
+ # Print outputs:
+ print(" ===== Outputs =====")
+ for i in graph.outputs:
+ print(i)
+
+
+def io_name_handler(graph: gs.Graph):
+ len_suffix = len("tf_op_layer_")
+ for out in graph.outputs:
+ out.name = out.name[len_suffix:]
+
+
+def get_node_by_name(name, onnx_graph: gs.Graph):
+ for n in onnx_graph.nodes:
+ if name in n.name:
+ return n
+ return None
+
+
+def get_nodes_by_op(op_name, onnx_graph):
+ nodes = []
+ for n in onnx_graph.nodes:
+ if n.op == op_name:
+ nodes.append(n)
+ return nodes
+
+def get_nodes_by_prefix(prefix, onnx_graph: gs.Graph):
+ nodes = []
+ for n in onnx_graph.nodes:
+ if n.name.startswith(prefix):
+ nodes.append(n)
+ return nodes
+
+
+def fix_graph_detr(graph: gs.Graph):
+ # === Fix Pad 2 in Resnet backbone ===
+ # TensorRT supports padding only on 2 innermost dimensions
+ resnet_pad2 = get_node_by_name(
+ "detr/detr_finetuning/detr/backbone/pad2/Pad", graph)
+ resnet_pad2.inputs[1] = gs.Constant(
+ "pad2/pads_input", np.array([0, 0, 1, 1, 0, 0, 1, 1]))
+ graph.cleanup()
+ graph.toposort()
+ return graph
+
+
+def fix_graph_deformable_detr(graph: gs.Graph):
+ batch_size = graph.inputs[0].shape[0]
+ # === Fix Pad 2 in Resnet backbone ===
+ # TensorRT supports padding only on 2 innermost dimensions
+ resnet_pad2 = get_node_by_name(
+ "deformable-detr/deformable_detr/detr_core/backbone/pad2/Pad", graph)
+ unused_nodes = [resnet_pad2.i(1), resnet_pad2.i(1).i()]
+ resnet_pad2.inputs[1] = gs.Constant(
+ "pad2/pads_input", np.array([0, 0, 1, 1, 0, 0, 1, 1]))
+ for n in unused_nodes:
+ graph.nodes.remove(n)
+
+ # ======= Add nodes for MsDeformIm2ColTRT ===========
+ tf_im2col_nodes = get_nodes_by_op("MsDeformIm2col", graph)
+
+ spatial_shape_np = tf_im2col_nodes[0].inputs[1].values.reshape((1, -1, 2))
+ spatial_shape_tensor = gs.Constant(
+ name="MsDeformIm2Col_spatial_shape",
+ values=spatial_shape_np)
+
+ start_index_np = tf_im2col_nodes[0].inputs[2].values.reshape((1, -1))
+ start_index_tensor = gs.Constant(
+ name="MsDeformIm2Col_start_index",
+ values=start_index_np)
+
+ def handle_ops_MsDeformIm2ColTRT(graph: gs.Graph, node: gs.Node):
+ inputs = node.inputs
+ inputs.pop(1)
+ inputs.pop(1)
+ inputs.insert(1, start_index_tensor)
+ inputs.insert(1, spatial_shape_tensor)
+ outputs = node.outputs
+ graph.layer(
+ op="MsDeformIm2ColTRT",
+ name=node.name + "_trt",
+ inputs=inputs,
+ outputs=outputs)
+
+ for n in tf_im2col_nodes:
+ handle_ops_MsDeformIm2ColTRT(graph, n)
+ # Detach old node from graph
+ n.inputs.clear()
+ n.outputs.clear()
+ graph.nodes.remove(n)
+
+ # ======= Handle GroupNorm by TensorRT official plugin =======
+ gn_nodes = []
+ for i in range(4):
+ gn_nodes.append(
+ get_nodes_by_prefix(
+ f"deformable-detr/deformable_detr/detr_core/input_proj_gn/{i}", graph))
+
+ def handle_group_norm_nodes(nodes, graph:gs.Graph):
+ # Get GN name
+ gn_name = nodes[0].name[:-7]
+ # Get GN input tensors
+
+ gn_input = nodes[0].i().inputs[0]
+        # Get gamma input
+ mul_node = None
+ for n in nodes:
+ if n.name.endswith("/mul"):
+ mul_node = n
+ assert mul_node is not None
+ gamma_input = gs.Constant(
+ name=gn_name + "gamma:0",
+ values=mul_node.inputs[1].values.reshape((batch_size, -1)))
+ # Get beta input
+ sub_node = None
+ for n in nodes:
+ if n.name.endswith("batchnorm/sub"):
+ sub_node = n
+ assert sub_node is not None
+ beta_input = gs.Constant(
+ name=gn_name+"beta:0",
+ values=sub_node.inputs[0].values.reshape((batch_size, -1)))
+ # Get output tensor
+ gn_output = nodes[-1].outputs[0]
+ # print(gn_output)
+ # Add new plugin node to graph
+ graph.layer(
+ name=gn_name + "group_norm_trt",
+ inputs=[gn_input, gamma_input, beta_input],
+ outputs=[gn_output],
+ op="GroupNormalizationPlugin",
+ attrs={
+ "eps": 1e-5,
+ "num_groups": 32
+ })
+ # Detach gn_output from existing graph
+ gn_out_flatten = gn_output.outputs[0]
+ gn_out_flatten.inputs.pop(0)
+ # Add Transpose
+ transposed_tensor = graph.layer(
+ name=gn_name+"gn_out_transpose",
+ inputs=[gn_output],
+ outputs=[gn_name + "input_proj_flatten:0"],
+ op="Transpose",
+ attrs={"perm": [0, 2, 3, 1]}
+ )
+ gn_out_flatten.inputs.insert(0, transposed_tensor[0])
+ # Disconnect old nodes
+ nodes.insert(0, nodes[0].i()) # for clean up purpose
+ for n in nodes:
+ n.inputs.clear()
+ n.outputs.clear()
+ graph.nodes.remove(n)
+
+ for nodes in gn_nodes:
+ handle_group_norm_nodes(nodes, graph)
+
+
+ return graph
+
+
+def fix_onnx_graph(graph: gs.Graph, model: str):
+ if model == "detr":
+ return fix_graph_detr(graph)
+ elif model == "deformable-detr":
+ return fix_graph_deformable_detr(graph)
+
+
+def main(onnx_dir: str, model: str, precision: str, verbose: bool, **args):
+ print(model)
+ onnx_path = os.path.join(onnx_dir, model + ".onnx")
+ print(onnx_path)
+
+ graph = gs.import_onnx(onnx.load(onnx_path))
+ graph.toposort()
+
+ # === Change graph IO names
+ # print_graph_io(graph)
+ io_name_handler(graph)
+ print_graph_io(graph)
+
+ # === Fix graph to adapt to TensorRT
+ graph = fix_onnx_graph(graph, model)
+
+ # === Export adapted onnx for TRT engine
+ adapted_onnx_path = os.path.join(onnx_dir, model + "_trt.onnx")
+ onnx.save(gs.export_onnx(graph), adapted_onnx_path)
+
+ # === Build engine
+ if verbose:
+ trt_logger = trt.Logger(trt.Logger.VERBOSE)
+ else:
+ trt_logger = trt.Logger(trt.Logger.WARNING)
+
+ builder = TRTEngineBuilder(adapted_onnx_path, logger=trt_logger)
+
+ if precision == "FP32":
+ pass
+ if precision == "FP16":
+ builder.FP16_allowed = True
+ builder.strict_type = True
+ if precision == "MIX":
+ builder.FP16_allowed = True
+ builder.strict_type = False
+
+ builder.export_engine(os.path.join(
+ onnx_dir, model + f"_{precision.lower()}.engine"))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('model', type=str, default="detr",
+ help="detr/deformable-detr")
+ parser.add_argument("--precision", type=str,
+ default="FP16", help="FP32/FP16/MIX")
+ parser.add_argument('--onnx_dir', type=str, default=None,
+ help="path to dir containing the \{model_name\}.onnx file")
+ parser.add_argument("--verbose", action="store_true",
+ help="Print out TensorRT log of all levels")
+ args = parser.parse_args()
+
+ if "deformable" in args.model:
+ MS_DEFORM_IM2COL_PLUGIN_LIB = "./detr_tensorrt/plugins/ms_deform_im2col/build/libms_deform_im2col_trt.so"
+ ctypes.CDLL(MS_DEFORM_IM2COL_PLUGIN_LIB)
+ trt.init_libnvinfer_plugins(TRT_LOGGER, '')
+ PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list
+
+ if args.onnx_dir is None:
+ args.onnx_dir = os.path.join(
+ "weights", args.model, args.model + "_trt")
+ # for plugin in PLUGIN_CREATORS:
+ # print(plugin.name, plugin.plugin_version)
+ main(**vars(args))
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/CMakeLists.txt b/detr_tensorrt/plugins/ms_deform_im2col/CMakeLists.txt
new file mode 100644
index 00000000..e22e58b3
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/CMakeLists.txt
@@ -0,0 +1,58 @@
+# We need cmake >= 3.8, since 3.8 introduced CUDA as a first class language
+cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
+project(MsDeformIm2ColTRT LANGUAGES CXX CUDA)
+
+# Enable all compile warnings
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-long-long -pedantic -Wno-deprecated-declarations")
+
+# Sets variable to a value if variable is unset.
+macro(set_ifndef var val)
+ if (NOT ${var})
+ set(${var} ${val})
+ endif()
+ message(STATUS "Configurable variable ${var} set to ${${var}}")
+endmacro()
+
+# -------- CONFIGURATION --------
+set_ifndef(TRT_LIB /home/ubuntu/TensorRT-8.0.0.3/lib)
+set_ifndef(TRT_INCLUDE /home/ubuntu/TensorRT-8.0.0.3/include)
+set_ifndef(CUDA_INC_DIR /usr/local/cuda/include)
+set_ifndef(CUDA_ARCH_SM 70) # should be fine for Tesla V100
+
+# Find dependencies:
+message("\nThe following variables are derived from the values of the previous variables unless provided explicitly:\n")
+
+# TensorRT's nvinfer lib
+find_library(_NVINFER_LIB nvinfer HINTS ${TRT_LIB} PATH_SUFFIXES lib lib64)
+set_ifndef(NVINFER_LIB ${_NVINFER_LIB})
+
+
+# -------- BUILDING --------
+
+# Add include directories
+include_directories(${CUDA_INC_DIR} ${TRT_INCLUDE} ${CMAKE_SOURCE_DIR}/sources/)
+message(STATUS "CUDA_INC_DIR: ${CUDA_INC_DIR}")
+# Define plugin library target
+add_library(ms_deform_im2col_trt MODULE
+${CMAKE_SOURCE_DIR}/sources/ms_deform_im2col_kernel.cu
+${CMAKE_SOURCE_DIR}/sources/ms_deform_im2col_kernel.h
+${CMAKE_SOURCE_DIR}/sources/ms_deform_im2col_plugin.cpp
+${CMAKE_SOURCE_DIR}/sources/ms_deform_im2col_plugin.h
+)
+
+# Use C++11
+target_compile_features(ms_deform_im2col_trt PUBLIC cxx_std_11)
+
+# Link TensorRT's nvinfer lib
+target_link_libraries(ms_deform_im2col_trt PRIVATE ${NVINFER_LIB})
+
+# We need to explicitly state that we need all CUDA files
+# to be built with -dc as the member functions will be called by
+# other libraries and executables (in our case, Python inference scripts)
+set_target_properties(ms_deform_im2col_trt PROPERTIES
+CUDA_SEPARABLE_COMPILATION ON
+)
+
+# CUDA ARCHITECTURE
+set_target_properties(ms_deform_im2col_trt PROPERTIES
+CUDA_ARCHITECTURES "${CUDA_ARCH_SM}")
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/README.txt b/detr_tensorrt/plugins/ms_deform_im2col/README.txt
new file mode 100644
index 00000000..8ef6c1a9
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/README.txt
@@ -0,0 +1,16 @@
+To build the plugin:
+mkdir build && cd build
+cmake .. && make -j
+
+NOTE: If any of the dependencies are not installed in their default locations, you can manually specify them. For example:
+
+cmake .. -DPYBIND11_DIR=/path/to/pybind11/
+ -DCMAKE_CUDA_COMPILER=/usr/local/cuda-x.x/bin/nvcc (Or adding /path/to/nvcc into $PATH)
+ -DCUDA_INC_DIR=/usr/local/cuda-x.x/include/ (Or adding /path/to/cuda/include into $CPLUS_INCLUDE_PATH)
+ -DPYTHON3_INC_DIR=/usr/include/python3.6/
+ -DTRT_LIB=/path/to/tensorrt/lib/
+ -DTRT_INCLUDE=/path/to/tensorrt/include/
+ -DCUDA_ARCH_SM=70
+
+Check matching sm for Nvidia GPU:
+https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/
\ No newline at end of file
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.cu b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.cu
new file mode 100644
index 00000000..a9b20cb3
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.cu
@@ -0,0 +1,326 @@
+#include "ms_deform_im2col_kernel.h"
+#include <cstdio>
+#include <cassert>
+#include "cuda_fp16.h"
+#include "NvInfer.h"
+
+
+#define assertm(exp, msg) assert(((void)msg, exp))
+
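+// Grid-stride loop: each thread handles indices i, i + blockDim.x * gridDim.x, ...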
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < n; i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+
+int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
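+// Bilinear interpolation of a single (head, channel) value at fractional
+// location (h, w) in a feature map laid out as (H, W, num_heads, channels);
+// out-of-bounds corners contribute zero.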
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+__device__ half ms_deform_attn_im2col_bilinear_half(const half* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const half &h, const half &w, const int &m, const int &c)
+{
+ half one = __float2half(1.0f);
+ half zero = __float2half(0.0f);
+
+ const half h_low = hfloor(h);
+ const half w_low = hfloor(w);
+    const int h_high = __half2int_ru(h);
+    const int w_high = __half2int_ru(w);
+
+
+ const half lh = h - h_low;
+ const half lw = w - w_low;
+ const half hh = one - lh, hw = one - lw;
+
+ const unsigned int w_stride = nheads * channels;
+ const unsigned int h_stride = width * w_stride;
+ const unsigned int h_low_ptr_offset = __half2uint_rd(h_low) * h_stride;
+ const unsigned int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const unsigned int w_low_ptr_offset = __half2uint_rd(w_low) * w_stride;
+ const unsigned int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const unsigned int base_ptr = m * channels + c;
+
+ half v1 = __float2half(0.0f);
+ if (h_low >= zero && w_low >= zero)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+    half v2 = zero;
+ if (h_low >= zero && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+    half v3 = zero;
+ if (h_high <= height - 1 && w_low >= zero)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+    half v4 = zero;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const half w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const half val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int *data_spatial_shapes,
+ const int *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * (spatial_h-1); //- 0.5;
+ const scalar_t w_im = loc_w * (spatial_w-1); //- 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
+__global__ void ms_deformable_im2col_gpu_kernel_half(const int n,
+ const half *data_value,
+ const int *data_spatial_shapes,
+ const int *data_level_start_index,
+ const half *data_sampling_loc,
+ const half *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ half *data_col)
+{
+ half one(1.0f);
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ half *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+        half col = __float2half(0.0f);
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const half *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const half loc_w = data_sampling_loc[data_loc_w_ptr];
+ const half loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const half weight = data_attn_weight[data_weight_ptr];
+
+ const half h_im = loc_h * __int2half_rd(spatial_h-1); //- 0.5;
+ const half w_im = loc_w * __int2half_rd(spatial_w-1); //- 0.5;
+
+ if (h_im > -one && w_im > -one && h_im < __int2half_rd(spatial_h) && w_im < __int2half_rd(spatial_w))
+ {
+ col += ms_deform_attn_im2col_bilinear_half(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
+int ms_deform_im2col_inference(
+ cudaStream_t stream,
+ const void* data_value,
+ const void* data_spatial_shapes,
+ const void* data_level_start_index,
+ const void* data_sampling_loc,
+ const void* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ void* data_col,
+ DataType mDataType
+)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+
+
+ if(mDataType == DataType::kFLOAT)
+ {
+ // printf("Hey FLOAT \n");
+        ms_deformable_im2col_gpu_kernel<float>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+            num_kernels,
+            static_cast<const float*>(data_value),
+            static_cast<const int*>(data_spatial_shapes),
+            static_cast<const int*>(data_level_start_index),
+            static_cast<const float*>(data_sampling_loc),
+            static_cast<const float*>(data_attn_weight),
+            batch_size,
+            spatial_size,
+            num_heads,
+            channels,
+            num_levels,
+            num_query,
+            num_point,
+            static_cast<float*>(data_col));
+ }
+ else if(mDataType == DataType::kHALF)
+ {
+ // printf("Hey HALF \n");
+        ms_deformable_im2col_gpu_kernel_half
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+            num_kernels,
+            static_cast<const half*>(data_value),
+            static_cast<const int*>(data_spatial_shapes),
+            static_cast<const int*>(data_level_start_index),
+            static_cast<const half*>(data_sampling_loc),
+            static_cast<const half*>(data_attn_weight),
+            batch_size,
+            spatial_size,
+            num_heads,
+            channels,
+            num_levels,
+            num_query,
+            num_point,
+            static_cast<half*>(data_col));
+
+ }
+ else return -1;
+
+ return 0;
+}
\ No newline at end of file
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.h b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.h
new file mode 100644
index 00000000..128539d3
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.h
@@ -0,0 +1,29 @@
+#ifndef MS_DEFORM_IM2COL_KERNEL
+#define MS_DEFORM_IM2COL_KERNEL
+
+#include "NvInfer.h"
+#include "cuda_fp16.h"
+
+using namespace nvinfer1;
+
+
+int ms_deform_im2col_inference(
+ cudaStream_t stream,
+ const void* data_value,
+ const void* data_spatial_shapes,
+ const void* data_level_start_index,
+ const void* data_sampling_loc,
+ const void* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ void* data_col,
+ DataType mDataType
+);
+
+#endif
+
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.cpp b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.cpp
new file mode 100644
index 00000000..8a5a4cf2
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.cpp
@@ -0,0 +1,281 @@
+#include <cassert>
+#include <cstring>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ms_deform_im2col_plugin.h"
+#include "NvInfer.h"
+#include "ms_deform_im2col_kernel.h"
+
+using namespace std;
+
+#define assertm(exp, msg) assert(((void)msg, exp))
+
+using namespace nvinfer1;
+
+namespace {
+ static const char* MS_DEFORM_IM2COL_PLUGIN_VERSION{"1"};
+ static const char* MS_DEFORM_IM2COL_PLUGIN_NAME{"MsDeformIm2ColTRT"};
+}
+
+// Static class fields initialization
+PluginFieldCollection MsDeformIm2ColCreator::mFC{};
+std::vector<PluginField> MsDeformIm2ColCreator::mPluginAttributes;
+
+// statically registers the Plugin Creator to the Plugin Registry of TensorRT
+REGISTER_TENSORRT_PLUGIN(MsDeformIm2ColCreator);
+
+// Helper function for serializing plugin
+template <typename T>
+void writeToBuffer(char*& buffer, const T& val)
+{
+    *reinterpret_cast<T*>(buffer) = val;
+ buffer += sizeof(T);
+}
+
+// Helper function for deserializing plugin
+template <typename T>
+T readFromBuffer(const char*& buffer)
+{
+    T val = *reinterpret_cast<const T*>(buffer);
+ buffer += sizeof(T);
+ return val;
+}
+
+
+MsDeformIm2Col::MsDeformIm2Col(const std::string name)
+ : mLayerName(name)
+{
+}
+
+MsDeformIm2Col::MsDeformIm2Col(const std::string name, const void* data, size_t length)
+ : mLayerName(name)
+{
+ // Deserialize in the same order as serialization
+    const char *d = static_cast<const char *>(data);
+ const char *a = d;
+
+    im2col_step = readFromBuffer<int>(d);
+    spatial_size = readFromBuffer<int>(d);
+    num_heads = readFromBuffer<int>(d);
+    channels = readFromBuffer<int>(d);
+    num_levels = readFromBuffer<int>(d);
+    num_query = readFromBuffer<int>(d);
+    num_point = readFromBuffer<int>(d);
+    mDataType = readFromBuffer<DataType>(d);
+
+ assert(d == (a + length));
+}
+
+const char* MsDeformIm2Col::getPluginType() const noexcept
+{
+ return MS_DEFORM_IM2COL_PLUGIN_NAME;
+}
+
+const char* MsDeformIm2Col::getPluginVersion() const noexcept
+{
+ return MS_DEFORM_IM2COL_PLUGIN_VERSION;
+}
+
+int MsDeformIm2Col::getNbOutputs() const noexcept
+{
+ return 1;
+}
+
+Dims MsDeformIm2Col::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept
+{
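+    // The layer outputs one aggregated feature vector per query:
+    // shape (len_q, num_heads * head_dim).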
+ int len_q = inputs[3].d[0];
+ int num_heads = inputs[0].d[1];
+ int head_dim = inputs[0].d[2];
+ return Dims2(len_q, num_heads * head_dim);
+}
+
+int MsDeformIm2Col::initialize() noexcept
+{
+ return 0;
+}
+
+int MsDeformIm2Col::enqueue(int batchSize, const void* const* inputs, void* const* outputs,
+ void* workspace, cudaStream_t stream) noexcept
+{
+ int status = -1;
+ // Launch CUDA kernel wrapper and save its return value
+ status = ms_deform_im2col_inference(
+ stream, inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
+ batchSize,spatial_size, num_heads, channels,
+ num_levels, num_query, num_point, outputs[0], mDataType);
+ assert(status == 0);
+ return status;
+}
+
+size_t MsDeformIm2Col::getSerializationSize() const noexcept
+{
+    // 7 int parameters
+ return 7 * sizeof(int32_t) + sizeof(DataType);
+}
+
+void MsDeformIm2Col::serialize(void* buffer) const noexcept
+{
+    char *d = static_cast<char *>(buffer);
+ const char *a = d;
+
+ writeToBuffer(d, im2col_step);
+ writeToBuffer(d, spatial_size);
+ writeToBuffer(d, num_heads);
+ writeToBuffer(d, channels);
+ writeToBuffer(d, num_levels);
+ writeToBuffer(d, num_query);
+ writeToBuffer(d, num_point);
+ writeToBuffer(d, mDataType);
+
+ assert(d == a + getSerializationSize());
+}
+
+void MsDeformIm2Col::terminate() noexcept {}
+
+void MsDeformIm2Col::destroy() noexcept {
+ // This gets called when the network containing plugin is destroyed
+ delete this;
+}
+
+DataType MsDeformIm2Col::getOutputDataType(int32_t index, const DataType *inputTypes, int32_t nbInputs) const noexcept
+{
+ // only 1 output
+ assert(index == 0);
+ assert(nbInputs == 5);
+ return inputTypes[0]; // return type of input tensor image
+}
+
+bool MsDeformIm2Col::isOutputBroadcastAcrossBatch(int32_t outputIndex,
+ const bool* inputIsBroadcasted, int32_t nbInputs) const noexcept
+{
+ return false;
+}
+
+bool MsDeformIm2Col::canBroadcastInputAcrossBatch(int inputIndex) const noexcept
+{
+ return false;
+}
+
+void MsDeformIm2Col::configurePlugin(const PluginTensorDesc* in, int32_t nbInput,
+ const PluginTensorDesc* out, int32_t nbOutput) noexcept
+{
+ assertm(nbInput == 5, "Must provide 5 inputs: value, spatial_shape, start_index, sampling_locations, attn_weights\n");
+ assertm(in[0].dims.nbDims == 3, "flatten_value must have shape (len_in, num_head, head_dim)\n");
+ assertm(in[1].dims.nbDims == 2, "spatial_shapes must have shape (num_levels, 2)\n");
+ assertm(in[2].dims.nbDims == 1, "start_index must have shape (num_levels, )\n");
+ assertm(in[3].dims.nbDims == 5, "sampling_loc must have shape (len_q, num_head, num_levels, num_points, 2)\n");
+ assertm(in[4].dims.nbDims == 4, "attn_weights must have shape (len_q, num_head, num_levels, num_points)\n");
+ assertm(nbOutput == 1, "This layer has only one output.\n");
+
+ im2col_step = 64;
+ spatial_size = in[0].dims.d[0];
+ num_heads = in[0].dims.d[1];
+ channels = in[0].dims.d[2];
+ num_levels = in[3].dims.d[2];
+ num_query = in[3].dims.d[0];
+ num_point = in[3].dims.d[3];
+ mDataType = in[0].type;
+
+ // cout << "DEBUG in[0].type: " << (int)in[0].type << endl;
+ // cout << "MsDeformIm2Col DEBUG: im2col_step=" << im2col_step << endl;
+ // cout << "MsDeformIm2Col DEBUG: spatial_size=" << spatial_size << endl;
+ // cout << "MsDeformIm2Col DEBUG: num_heads=" << num_heads << endl;
+ // cout << "MsDeformIm2Col DEBUG: channels=" << channels << endl;
+ // cout << "MsDeformIm2Col DEBUG: num_levels=" << num_levels << endl;
+ // cout << "MsDeformIm2Col DEBUG: num_query=" << num_query << endl;
+ // cout << "MsDeformIm2Col DEBUG: num_point=" << num_point << endl;
+}
+
+bool MsDeformIm2Col::supportsFormatCombination(int pos, const PluginTensorDesc* inOut,
+ int nbInputs, int nbOutputs) const noexcept
+{
+ bool ret;
+ ret = inOut[pos].format == TensorFormat::kLINEAR;
+ if((pos == 1) || (pos == 2))
+ {
+ return ret && (inOut[pos].type == DataType::kINT32);
+ }
+ else
+ {
+ bool type_supported = (inOut[pos].type == DataType::kFLOAT) || (inOut[pos].type == DataType::kHALF);
+ type_supported = type_supported && (inOut[pos].type == inOut[0].type);
+ return ret && type_supported;
+ }
+
+}
+
+IPluginV2Ext* MsDeformIm2Col::clone() const noexcept
+{
+ auto plugin = new MsDeformIm2Col(mLayerName);
+ plugin->im2col_step = im2col_step;
+ plugin->spatial_size = spatial_size;
+ plugin->num_heads = num_heads;
+ plugin->channels = channels;
+ plugin->num_levels = num_levels;
+ plugin->num_query = num_query;
+ plugin->num_point = num_point;
+ plugin->mDataType = mDataType;
+ plugin->setPluginNamespace(mNamespace.c_str());
+ return plugin;
+}
+
+void MsDeformIm2Col::setPluginNamespace(const char* libNamespace) noexcept
+{
+ mNamespace = libNamespace;
+}
+
+const char* MsDeformIm2Col::getPluginNamespace() const noexcept
+{
+ return mNamespace.c_str();
+}
+
+MsDeformIm2ColCreator::MsDeformIm2ColCreator()
+{
+ // Describe MsDeformIm2Col's required PluginField arguments
+
+ // Fill PluginFieldCollection with PluginField arguments metadata
+ mFC.nbFields = mPluginAttributes.size();
+ mFC.fields = mPluginAttributes.data();
+}
+
+const char* MsDeformIm2ColCreator::getPluginName() const noexcept
+{
+ return MS_DEFORM_IM2COL_PLUGIN_NAME;
+}
+
+const char* MsDeformIm2ColCreator::getPluginVersion() const noexcept
+{
+ return MS_DEFORM_IM2COL_PLUGIN_VERSION;
+}
+
+const PluginFieldCollection* MsDeformIm2ColCreator::getFieldNames() noexcept
+{
+ return &mFC;
+}
+
+IPluginV2* MsDeformIm2ColCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept
+{
+ // const PluginField* fields = fc->fields;
+
+ // Parse fields from PluginFieldCollection
+ return new MsDeformIm2Col(name);
+}
+
+IPluginV2* MsDeformIm2ColCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept
+{
+ // This object will be deleted when the network is destroyed, which will
+ // call MsDeformIm2Col::destroy()
+ return new MsDeformIm2Col(name, serialData, serialLength);
+}
+
+void MsDeformIm2ColCreator::setPluginNamespace(const char* libNamespace) noexcept
+{
+ mNamespace = libNamespace;
+}
+
+const char* MsDeformIm2ColCreator::getPluginNamespace() const noexcept
+{
+ return mNamespace.c_str();
+}
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.h b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.h
new file mode 100644
index 00000000..aef020a5
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.h
@@ -0,0 +1,103 @@
+#ifndef MS_DEFORM_IM2COL_TRT_PLUGIN_H
+#define MS_DEFORM_IM2COL_TRT_PLUGIN_H
+
+#include "NvInferPlugin.h"
+#include <string>
+#include <vector>
+
+
+using namespace nvinfer1;
+
+// One of the preferred ways of making TensorRT able to see
+// our custom layer is to extend the IPluginV2 and IPluginCreator classes.
+// For requirements for overridden functions, check the TensorRT API docs.
+
+class MsDeformIm2Col : public IPluginV2IOExt
+{
+public:
+ MsDeformIm2Col(const std::string name);
+
+ MsDeformIm2Col(const std::string name, const void* data, size_t length);
+
+    // It doesn't make sense to create MsDeformIm2Col without arguments, so we delete the default constructor.
+ MsDeformIm2Col() = delete;
+
+ int getNbOutputs() const noexcept override;
+
+ Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept override;
+
+ int initialize() noexcept override;
+
+ void terminate() noexcept override;
+
+    size_t getWorkspaceSize(int) const noexcept override { return 0; }
+
+ int enqueue(int batchSize, const void* const* inputs, void* const* outputs, void* workspace,
+ cudaStream_t stream) noexcept override;
+
+ size_t getSerializationSize() const noexcept override;
+
+ void serialize(void* buffer) const noexcept override;
+
+ const char* getPluginType() const noexcept override;
+
+ const char* getPluginVersion() const noexcept override;
+
+ void destroy() noexcept override;
+
+ IPluginV2Ext* clone() const noexcept override;
+
+ void setPluginNamespace(const char* pluginNamespace) noexcept override;
+
+ const char* getPluginNamespace() const noexcept override;
+
+ DataType getOutputDataType(int32_t index, const nvinfer1::DataType *inputTypes, int32_t nbInputs) const noexcept override;
+
+ bool isOutputBroadcastAcrossBatch(int32_t outputIndex, const bool* inputIsBroadcasted, int32_t nbInputs) const noexcept override;
+
+ bool canBroadcastInputAcrossBatch(int inputIndex) const noexcept override;
+
+ void configurePlugin(const PluginTensorDesc* in, int32_t nbInput, const PluginTensorDesc* out, int32_t nbOutput) noexcept override;
+
+ bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const noexcept override;
+
+private:
+ const std::string mLayerName;
+ std::string mNamespace;
+ int im2col_step;
+ int spatial_size;
+ int num_heads;
+ int channels;
+ int num_levels;
+ int num_query;
+ int num_point;
+ DataType mDataType;
+ // DataType mDataType = DataType::kHALF;
+};
+
+class MsDeformIm2ColCreator : public IPluginCreator
+{
+public:
+ MsDeformIm2ColCreator();
+
+ const char* getPluginName() const noexcept override;
+
+ const char* getPluginVersion() const noexcept override;
+
+ const PluginFieldCollection* getFieldNames() noexcept override;
+
+ IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override;
+
+ IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override;
+
+ void setPluginNamespace(const char* pluginNamespace) noexcept override;
+
+ const char* getPluginNamespace() const noexcept override;
+
+private:
+ static PluginFieldCollection mFC;
+    static std::vector<PluginField> mPluginAttributes;
+ std::string mNamespace;
+};
+
+#endif
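
For orientation, a hedged summary of the tensor layout the plugin expects, inferred from the test script below and the Deformable DETR im2col convention (not an official spec):

```python
# Assumed binding shapes (batch size 1, as in the test engine):
# flatten_value:        (1, spatial_size, num_heads, channels)               float32/float16
# spatial_shapes:       (1, num_levels, 2)                                   int32, (H, W) per level
# level_start_index:    (1, num_levels)                                      int32, offset of each level
# flatten_sampling_loc: (1, num_query, num_heads, num_levels, num_point, 2)
# flatten_attn_weight:  (1, num_query, num_heads, num_levels, num_point)
# output:               (1, num_query, num_heads * channels)
```
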
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test.py b/detr_tensorrt/plugins/ms_deform_im2col/test.py
new file mode 100644
index 00000000..111c039b
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/test.py
@@ -0,0 +1,119 @@
+import tensorrt as trt
+import pycuda.driver as cuda
+import numpy as np
+import os
+import ctypes
+import re
+import time
+
+from detr_tensorrt.TRTExecutor import TRTExecutor
+
+TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+
+current_path = os.path.dirname(os.path.abspath(__file__))
+
+MS_DEFORM_IM2COL_PLUGIN_LIB = "./detr_tensorrt/plugins/ms_deform_im2col/build/libms_deform_im2col_trt.so"
+ctypes.CDLL(MS_DEFORM_IM2COL_PLUGIN_LIB)
+trt.init_libnvinfer_plugins(TRT_LOGGER, '')
+PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list
+
+def GiB(val):
+ return val * 1 << 30
+
+def camel_to_snake(name):
+ name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
+ return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
+
+def get_trt_plugin(plugin_name):
+ plugin = None
+ for plugin_creator in PLUGIN_CREATORS:
+ if plugin_creator.name == plugin_name:
+ plugin = plugin_creator.create_plugin(camel_to_snake(plugin_name), None)
+ if plugin is None:
+ raise Exception(f"plugin {plugin_name} not found")
+ return plugin
+
+def build_test_engine(input_shape, dtype=trt.float32):
+ num_level = input_shape["flatten_sampling_loc"][3]
+
+ with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network:
+ builder.max_batch_size = 1
+ config = builder.create_builder_config()
+ config.max_workspace_size = GiB(5)
+ if dtype == trt.float16:
+ config.set_flag(trt.BuilderFlag.FP16)
+ config.set_flag(trt.BuilderFlag.STRICT_TYPES)
+
+ input_flatten_value = network.add_input(
+ name="input_flatten_value", dtype=dtype, shape=input_shape["flatten_value"])
+ input_spatial_shapes = network.add_input(
+ name="input_spatial_shapes", dtype=trt.int32, shape=(1, num_level, 2))
+ input_start_index = network.add_input(
+ name="input_start_index", dtype=trt.int32, shape=(1, num_level))
+ input_flatten_sampling_loc = network.add_input(
+ name="input_flatten_sampling_loc", dtype=dtype, shape=input_shape["flatten_sampling_loc"])
+ input_flatten_attn_weight = network.add_input(
+ name="input_flatten_attn_weight", dtype=dtype, shape=input_shape["flatten_attn_weight"])
+
+ ms_deform_im2col_node = network.add_plugin_v2(
+ inputs=[
+ input_flatten_value, input_spatial_shapes,
+ input_start_index, input_flatten_sampling_loc,
+ input_flatten_attn_weight],
+ plugin=get_trt_plugin("MsDeformIm2ColTRT")
+ )
+ ms_deform_im2col_node.name = "ms_deform_im2col_node"
+ ms_deform_im2col_node.get_output(0).name = "im2col_output"
+
+ network.mark_output(ms_deform_im2col_node.get_output(0))
+
+ return builder.build_engine(network, config)
+
+def get_target_test_tensors(dtype=np.float32):
+ test_dir = os.path.join(current_path, "test_tensors")
+ test_tensors = {}
+ test_shapes = {}
+ for filename in os.listdir(test_dir):
+ tensor_name = filename.split(".")[0]
+ test_tensors[tensor_name] = np.load(os.path.join(test_dir, filename))
+
+ if test_tensors[tensor_name].dtype == np.int64:
+ test_tensors[tensor_name] = test_tensors[tensor_name].astype(np.int32)
+ elif test_tensors[tensor_name].dtype == np.float32:
+ test_tensors[tensor_name] = test_tensors[tensor_name].astype(dtype)
+
+ test_shapes[tensor_name] = test_tensors[tensor_name].shape
+ return test_tensors, test_shapes
+
+
+if __name__ == "__main__":
+ # for plugin in PLUGIN_CREATORS:
+ # print(plugin.name, plugin.plugin_version)
+
+ test_tensors, test_shapes = get_target_test_tensors()
+ for key in test_tensors:
+ print(key, test_tensors[key].shape, test_tensors[key].dtype)
+ test_engine = build_test_engine(test_shapes, dtype=trt.float16)
+
+ trt_model = TRTExecutor(engine=test_engine)
+ trt_model.print_bindings_info()
+
+ trt_model.inputs[0].host = test_tensors["flatten_value"].astype(np.float16)
+ trt_model.inputs[1].host = test_tensors["spatial_shapes"]
+ trt_model.inputs[2].host = test_tensors["level_start_index"][:4].copy()
+ trt_model.inputs[3].host = test_tensors["flatten_sampling_loc"].astype(np.float16)
+ trt_model.inputs[4].host = test_tensors["flatten_attn_weight"].astype(np.float16)
+
+
+ trt_model.execute()
+
+ N = 1000
+ tic = time.time()
+ [trt_model.execute() for i in range(N)]
+ toc = time.time()
+
+ diff = test_tensors["output"] - trt_model.outputs[0].host
+ print(np.abs(diff).mean())
+ print(f"Execution time: {(toc - tic)/N*1000} ms")
+
\ No newline at end of file
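
To make the test fail loudly instead of only printing the mean error, a tolerance check could be appended to the `__main__` block. The threshold below is an assumption for FP16, not a validated bound:

```python
# Hypothetical pass/fail check for the FP16 engine.
mean_abs_err = np.abs(diff).mean()
assert mean_abs_err < 1e-2, f"plugin output diverges from reference: {mean_abs_err}"
```
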
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_attn_weight.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_attn_weight.npy
new file mode 100644
index 00000000..cdfd6859
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_attn_weight.npy differ
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_sampling_loc.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_sampling_loc.npy
new file mode 100644
index 00000000..5dba0170
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_sampling_loc.npy differ
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_value.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_value.npy
new file mode 100644
index 00000000..76f0b919
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_value.npy differ
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/level_start_index.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/level_start_index.npy
new file mode 100644
index 00000000..e96f3c88
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/level_start_index.npy differ
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/output.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/output.npy
new file mode 100644
index 00000000..c3785334
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/output.npy differ
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/spatial_shapes.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/spatial_shapes.npy
new file mode 100644
index 00000000..66b44fa2
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/spatial_shapes.npy differ
diff --git a/detr_tensorrt/trt_helper.py b/detr_tensorrt/trt_helper.py
new file mode 100644
index 00000000..5739402f
--- /dev/null
+++ b/detr_tensorrt/trt_helper.py
@@ -0,0 +1,172 @@
+import pycuda.driver as cuda
+import tensorrt as trt
+
+COCO_PANOPTIC_CLASS_NAMES = [
+ 'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
+ 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
+ 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
+ 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
+ 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
+ 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
+ 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
+ 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
+ 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
+ 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
+ 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator',
+ 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
+ 'toothbrush', 'N/A', 'banner', 'blanket', 'N/A', 'bridge', 'N/A',
+ 'N/A', 'N/A', 'N/A', 'cardboard', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A',
+ 'counter', 'N/A', 'curtain', 'N/A', 'N/A', 'door-stuff', 'N/A', 'N/A', 'N/A', 'N/A',
+ 'N/A', 'floor-wood', 'flower', 'N/A', 'N/A', 'fruit', 'N/A', 'N/A', 'gravel', 'N/A', 'N/A',
+ 'house', 'N/A', 'light', 'N/A', 'N/A', 'mirror-stuff', 'N/A', 'N/A',
+ 'N/A', 'N/A', 'net', 'N/A', 'N/A', 'pillow', 'N/A', 'N/A', 'platform',
+ 'playingfield', 'N/A', 'railroad', 'river', 'road', 'N/A', 'roof', 'N/A', 'N/A',
+ 'sand', 'sea', 'shelf', 'N/A', 'N/A', 'snow', 'N/A', 'stairs', 'N/A', 'N/A', 'N/A',
+ 'N/A', 'tent', 'N/A', 'towel', 'N/A', 'N/A', 'wall-brick', 'N/A', 'N/A', 'N/A', 'wall-stone',
+ 'wall-tile', 'wall-wood', 'water-other', 'N/A', 'window-blind', 'window-other', 'N/A', 'N/A',
+ 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
+ 'cabinet-merged', 'table-merged', 'floor-other-merged', 'pavement-merged',
+ 'mountain-merged', 'grass-merged', 'dirt-merged', 'paper-merged', 'food-other-merged',
+ 'building-other-merged', 'rock-merged', 'wall-other-merged', 'rug-merged', 'N/A',
+ 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A',
+ 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A',
+ 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A',
+    'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'background'
+]
+
+AGV_PANOPTIC_CLASS_NAMES = [
+ "person", 'carpet', 'dirt', 'floor-mable', 'floor-other', 'floor-stone',
+ 'floor-tile', 'floor-wood', 'gravel', 'gournd-other', 'mud', 'pavement', 'platform', 'playingfield',
+ 'railroad', 'road', 'sand', 'snow', 'background'
+]
+
+def GiB(val):
+ """Calculate Gibibit in bits, used to set workspace for TensorRT engine builder."""
+ return val * 1 << 30
+
+class HostDeviceMem(object):
+ """
+    Simple helper class that stores the host and device buffers of an engine binding.
+
+    Attributes:
+    ----------
+    host_mem: np.ndarray
+        page-locked data buffer on the CPU
+    device_mem: pycuda.driver.DeviceAllocation
+        pointer to the corresponding buffer on the GPU
+    shape: tuple
+    dtype: np dtype
+    name: str
+        name of the binding
+ """
+ def __init__(self, host_mem, device_mem, shape, dtype, name=""):
+ self.host = host_mem
+ self.device = device_mem
+ self.shape = shape
+ self.dtype = dtype
+ self.name = name
+
+ def __str__(self):
+ return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
+
+ def __repr__(self):
+ return self.__str__()
+
+def allocate_buffers(context, stream=None, is_async=True):
+ """
+    Read the binding information of the engine behind an ExecutionContext, create a
+    page-locked np.ndarray on the CPU for each binding, and allocate the corresponding
+    memory on the GPU.
+
+    Returns:
+    --------
+    inputs: list[HostDeviceMem]
+    outputs: list[HostDeviceMem]
+    bindings: list[int]
+        list of device pointers, one per binding
+    stream: pycuda.driver.Stream
+        used for memory transfers between CPU and GPU
+ """
+ inputs = []
+ outputs = []
+ bindings = []
+ if stream is None and is_async:
+ stream = cuda.Stream()
+ for binding in context.engine:
+ binding_idx = context.engine.get_binding_index(binding)
+ shape = context.get_binding_shape(binding_idx)
+ size = trt.volume(shape) * context.engine.max_batch_size
+ dtype = trt.nptype(context.engine.get_binding_dtype(binding))
+ # Allocate host and device buffers
+ host_mem = cuda.pagelocked_empty(size, dtype)
+ device_mem = cuda.mem_alloc(host_mem.nbytes)
+ # Append the device buffer to device bindings.
+ bindings.append(int(device_mem))
+ # Append to the appropriate list.
+ if context.engine.binding_is_input(binding):
+ inputs.append(HostDeviceMem(host_mem, device_mem, shape, dtype, binding))
+ else:
+ outputs.append(HostDeviceMem(host_mem, device_mem, shape, dtype, binding))
+ return inputs, outputs, bindings, stream
+
+def do_inference_async(context, bindings, inputs, outputs, stream):
+ """
+    Execute a TensorRT engine asynchronously on the given CUDA stream.
+
+    Parameters:
+    -----------
+    context: tensorrt.IExecutionContext
+    bindings: list[int]
+        list of device pointers, one per binding
+    inputs: list[HostDeviceMem]
+    outputs: list[HostDeviceMem]
+    stream: pycuda.driver.Stream
+        used for memory transfers between CPU and GPU
+
+    Returns:
+    --------
+    list[np.ndarray], one per engine output
+ """
+ # Transfer input data to the GPU.
+ [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+ # Run inference.
+ context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
+ # Transfer predictions back from the GPU.
+ [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
+ # Synchronize the stream
+ stream.synchronize()
+ # Return only the host outputs.
+ for out in outputs:
+ out.host = out.host.reshape(out.shape)
+ return [out.host for out in outputs]
+
+def do_inference(context, bindings, inputs, outputs):
+ """
+    Execute a TensorRT engine synchronously.
+
+    Parameters:
+    -----------
+    context: tensorrt.IExecutionContext
+    bindings: list[int]
+        list of device pointers, one per binding
+    inputs: list[HostDeviceMem]
+    outputs: list[HostDeviceMem]
+
+    Returns:
+    --------
+    list[np.ndarray], one per engine output
+ """
+ # Transfer input data to the GPU.
+ [cuda.memcpy_htod(inp.device, inp.host) for inp in inputs]
+ # Run inference.
+ context.execute_v2(bindings=bindings)
+ # Transfer predictions back from the GPU.
+ [cuda.memcpy_dtoh(out.host, out.device) for out in outputs]
+ # Return only the host outputs.
+ for out in outputs:
+ out.host = out.host.reshape(out.shape)
+ return [out.host for out in outputs]
\ No newline at end of file
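
The helpers above compose as follows when running a serialized engine directly, without the `TRTExecutor` wrapper. A minimal sketch: the engine path is a placeholder, `pycuda.autoinit` is assumed to own the CUDA context, and for deformable-detr engines the custom plugin library must be loaded first (see `tensorrt_inference.py`):

```python
import numpy as np
import pycuda.autoinit  # noqa: F401 -- creates and owns the CUDA context
import tensorrt as trt

from detr_tensorrt.trt_helper import allocate_buffers, do_inference

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

with open("path/to/model.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

context = engine.create_execution_context()
inputs, outputs, bindings, _ = allocate_buffers(context, is_async=False)

# Feed dummy data of the right shape/dtype into the first input binding.
inputs[0].host = np.zeros(inputs[0].shape, dtype=inputs[0].dtype).ravel()
results = do_inference(context, bindings, inputs, outputs)
print([r.shape for r in results])
```
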
diff --git a/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o
index e025baa6..e7d6cd3d 100644
Binary files a/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o and b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o differ
diff --git a/detr_tf/inference.py b/detr_tf/inference.py
index 957233fd..55f6b40c 100644
--- a/detr_tf/inference.py
+++ b/detr_tf/inference.py
@@ -2,15 +2,16 @@
import numpy as np
import cv2
-
-CLASS_COLOR_MAP = np.random.randint(0, 255, (100, 3))
+np.random.seed(20)
+# CLASS_COLOR_MAP = np.random.randint(0, 255, (100, 3))
+CLASS_COLOR_MAP = np.random.random((100, 3))
from detr_tf import bbox
def numpy_bbox_to_image(image, bbox_list, labels=None, scores=None, class_name=[], config=None):
""" Numpy function used to display the bbox (target or prediction)
"""
- assert(image.dtype == np.float32 and image.dtype == np.float32 and len(image.shape) == 3)
+ assert(image.dtype == np.float32 and len(image.shape) == 3)
if config is not None and config.normalized_method == "torch_resnet":
channel_avg = np.array([0.485, 0.456, 0.406])
@@ -18,7 +19,8 @@ def numpy_bbox_to_image(image, bbox_list, labels=None, scores=None, class_name=[
image = (image * channel_std) + channel_avg
image = (image*255).astype(np.uint8)
elif config is not None and config.normalized_method == "tf_resnet":
- image = image + mean
+ mean = [103.939, 116.779, 123.68]
+ image = image + mean
image = image[..., ::-1]
image = image / 255
@@ -36,9 +38,6 @@ def numpy_bbox_to_image(image, bbox_list, labels=None, scores=None, class_name=[
# Go through each bbox
for b in np.argsort(bbox_area)[::-1]:
- # Take a new color at reandon for this instance
- instance_color = np.random.randint(0, 255, (3))
-
x1, y1, x2, y2 = bbox_x1y1x2y2[b]
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
@@ -55,12 +54,10 @@ def numpy_bbox_to_image(image, bbox_list, labels=None, scores=None, class_name=[
class_color = CLASS_COLOR_MAP[int(class_id)]
- color = instance_color
-
- multiplier = image.shape[0] / 500
- cv2.rectangle(image, (x1, y1), (x1 + int(multiplier*15)*len(label_name), y1 + 20), class_color.tolist(), -10)
- cv2.putText(image, label_name, (x1+2, y1 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6 * multiplier, (0, 0, 0), 1)
+ multiplier = image.shape[0] / 1000
cv2.rectangle(image, (x1, y1), (x2, y2), tuple(class_color.tolist()), 2)
+ cv2.rectangle(image, (x1, y1 - 20), (x1 + int(multiplier*15)*len(label_name), y1 + 1), class_color.tolist(), -1)
+ cv2.putText(image, label_name, (x1+2, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.6 * multiplier, (0, 0, 0), 1)
return image
diff --git a/detr_tf/networks/deformable_detr.py b/detr_tf/networks/deformable_detr.py
index 9bc51c00..c5b326ff 100644
--- a/detr_tf/networks/deformable_detr.py
+++ b/detr_tf/networks/deformable_detr.py
@@ -95,6 +95,7 @@ def __init__(self,
layer_norm = functools.partial(tfa.layers.GroupNormalization, groups=32, epsilon=1e-05) #tf.keras.layers.BatchNormalization
+
self.input_proj_0 = tf.keras.layers.Conv2D(self.model_dim, kernel_size=1, name='input_proj/0/0', trainable=train_encoder)
self.input_proj_gn_0 = layer_norm(name="input_proj_gn/0/1", trainable=train_encoder)
@@ -146,7 +147,7 @@ def call(self, inp, training=False, post_process=False):
class DetrClassHead(tf.keras.layers.Layer):
- def __init__(self, detr, include_top, nb_class=None, refine_bbox=False, **kwargs):
+    def __init__(self, detr: DeformableDETR, include_top, nb_class=None, refine_bbox=False, **kwargs):
"""
"""
super().__init__(name="detr_class_head", **kwargs)
diff --git a/detr_tf/networks/transformer.py b/detr_tf/networks/transformer.py
index 60c402a5..5225ccb3 100644
--- a/detr_tf/networks/transformer.py
+++ b/detr_tf/networks/transformer.py
@@ -306,24 +306,23 @@ def call(self, inputs, attn_mask=None, key_padding_mask=None,
if attn_mask is not None:
attn_output_weights += attn_mask
-
- if key_padding_mask is not None:
- key_padding_mask = tf.cast(key_padding_mask, tf.bool)
-
- attn_output_weights = tf.reshape(attn_output_weights,
- [batch_size, self.num_heads, target_len, source_len])
-
- key_padding_mask = tf.expand_dims(key_padding_mask, 1)
- key_padding_mask = tf.expand_dims(key_padding_mask, 2)
- key_padding_mask = tf.tile(key_padding_mask, [1, self.num_heads, target_len, 1])
-
-
- #print("before attn_output_weights", attn_output_weights.shape)
- attn_output_weights = tf.where(key_padding_mask,
- tf.zeros_like(attn_output_weights) + float('-inf'),
- attn_output_weights)
- attn_output_weights = tf.reshape(attn_output_weights,
- [batch_size * self.num_heads, target_len, source_len])
+        # NOTE: keep this block commented out when exporting to ONNX / TensorRT engine
+ # if key_padding_mask is not None:
+ # attn_output_weights = tf.reshape(attn_output_weights,
+ # [batch_size, self.num_heads, target_len, source_len])
+
+ # key_padding_mask = tf.expand_dims(key_padding_mask, 1)
+ # key_padding_mask = tf.expand_dims(key_padding_mask, 2)
+ # key_padding_mask = tf.tile(key_padding_mask, [1, self.num_heads, target_len, 1])
+
+ # key_padding_mask = tf.cast(key_padding_mask, tf.bool)
+
+ # #print("before attn_output_weights", attn_output_weights.shape)
+ # attn_output_weights = tf.where(key_padding_mask,
+ # tf.zeros_like(attn_output_weights) + float('-inf'),
+ # attn_output_weights)
+ # attn_output_weights = tf.reshape(attn_output_weights,
+ # [batch_size * self.num_heads, target_len, source_len])
attn_output_weights = tf.nn.softmax(attn_output_weights, axis=-1)
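
If padding masking is ever needed in the exported graph, an additive mask is a common ONNX-friendly substitute for the `tf.where` construct removed above, since it avoids boolean tensors and `-inf` literals. A sketch only (not part of this patch), assuming `key_padding_mask` is a float tensor of shape `(batch_size, source_len)` with 1.0 at padded positions:

```python
# Additive masking in place of tf.where; -1e9 drives masked logits
# towards zero after the softmax.
if key_padding_mask is not None:
    mask = tf.reshape(key_padding_mask, [batch_size, 1, 1, source_len]) * -1e9
    attn_output_weights = tf.reshape(
        attn_output_weights, [batch_size, self.num_heads, target_len, source_len]) + mask
    attn_output_weights = tf.reshape(
        attn_output_weights, [batch_size * self.num_heads, target_len, source_len])
```
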
diff --git a/images/test.jpeg b/images/test.jpeg
new file mode 100644
index 00000000..533e3260
Binary files /dev/null and b/images/test.jpeg differ
diff --git a/tensorflow_inference.py b/tensorflow_inference.py
new file mode 100644
index 00000000..a539433b
--- /dev/null
+++ b/tensorflow_inference.py
@@ -0,0 +1,84 @@
+import tensorflow as tf
+import numpy as np
+import cv2
+import time
+
+from detr_tf.training_config import TrainingConfig, training_config_parser
+
+from detr_tf.networks.detr import get_detr_model
+from detr_tf.networks.deformable_detr import get_deformable_detr_model
+
+from detr_tf.data import processing
+from detr_tf.data.coco import COCO_CLASS_NAME
+from detr_tf.inference import get_model_inference, numpy_bbox_to_image
+
+
+@tf.function
+def run_inference(model, images, config, use_mask=True):
+
+ if use_mask:
+ mask = tf.zeros((1, images.shape[1], images.shape[2], 1))
+ m_outputs = model((images, mask), training=False)
+ else:
+ m_outputs = model(images, training=False)
+
+ predicted_bbox, predicted_labels, predicted_scores = get_model_inference(
+ m_outputs, config.background_class, bbox_format="xy_center", threshold=0.4)
+ return predicted_bbox, predicted_labels, predicted_scores
+
+
+def main(model, use_mask=True, resize=None):
+
+ image = cv2.imread("images/test.jpeg")
+ # Convert to RGB and process the input image
+ if resize is not None:
+ image = cv2.resize(image, (resize[1], resize[0]))
+ model_input = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ model_input = processing.normalized_images(model_input, config)
+
+ # GPU warm up
+ [run_inference(model, np.expand_dims(model_input, axis=0), config, use_mask=use_mask) for i in range(3)]
+
+    # Timing
+ tic = time.time()
+ predicted_bbox, predicted_labels, predicted_scores = run_inference(model, np.expand_dims(model_input, axis=0), config, use_mask=use_mask)
+ toc = time.time()
+ print(f"Inference latency: {(toc - tic)*1000} ms")
+
+ image = image.astype(np.float32)
+ image = image / 255
+ image = numpy_bbox_to_image(image, predicted_bbox, labels=predicted_labels, scores=predicted_scores, class_name=COCO_CLASS_NAME)
+
+ cv2.imshow('image', image)
+ cv2.waitKey(0)
+ cv2.destroyAllWindows()
+
+
+if __name__ == "__main__":
+
+    physical_devices = tf.config.list_physical_devices('GPU')
+    for device in physical_devices:
+        tf.config.experimental.set_memory_growth(device, True)
+
+ config = TrainingConfig()
+ parser = training_config_parser()
+
+    # CLI arguments
+ parser.add_argument("model", type=str, help="One of 'detr', or 'deformable-detr'")
+ parser.add_argument("--resize", type=int, nargs=2, default=None, help="Resize image before running inference")
+ args = parser.parse_args()
+ config.update_from_args(args)
+
+ if args.model == "detr":
+ print("Loading detr...")
+ # Load the model with the new layers to finetune
+ model = get_detr_model(config, include_top=True, weights="detr")
+ config.background_class = 91
+ use_mask = True
+ elif args.model == "deformable-detr":
+ print("Loading deformable-detr...")
+ model = get_deformable_detr_model(config, include_top=True, weights="deformable-detr")
+ model.summary()
+ use_mask = False
+    # Run inference on the test image
+ main(model, use_mask=use_mask, resize=args.resize)
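
An illustrative invocation of the script above (the resize values mirror the default TensorRT input shape):

```
python tensorflow_inference.py detr --resize 1280 1920
```
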
diff --git a/tensorrt_inference.py b/tensorrt_inference.py
new file mode 100644
index 00000000..f0534592
--- /dev/null
+++ b/tensorrt_inference.py
@@ -0,0 +1,68 @@
+import ctypes
+import tensorrt as trt
+import numpy as np
+import os
+import cv2
+import argparse
+import time
+
+from detr_tensorrt.TRTExecutor import TRTExecutor, TRT_LOGGER
+from detr_tensorrt.inference import normalized_images, get_model_inference
+
+from detr_tf.data.coco import COCO_CLASS_NAME
+from detr_tf.inference import numpy_bbox_to_image
+
+BACKGROUND_CLASS = 91 # COCO background class
+
+
+
+def run_inference(model: TRTExecutor, normalized_image: np.ndarray):
+ model.inputs[0].host = normalized_image
+ model.execute()
+ m_outputs = {out.name:out.host for out in model.outputs}
+ p_bbox, p_labels, p_scores = get_model_inference(m_outputs, BACKGROUND_CLASS, threshold=0.4)
+ return p_bbox, p_labels, p_scores
+
+
+def main(engine_path):
+ # Load TensorRT engine
+ model = TRTExecutor(engine_path)
+ model.print_bindings_info()
+
+ # Read image
+ input_shape = model.inputs[0].shape # (B, H, W, C)
+ H, W = input_shape[1], input_shape[2]
+ image = cv2.imread("images/test.jpeg")
+
+ # Pre-process image
+ model_input = cv2.resize(image, (W, H))
+ model_input = cv2.cvtColor(model_input, cv2.COLOR_BGR2RGB)
+ model_input = normalized_images(model_input)
+
+ # Run inference
+ [model.execute() for i in range(3)] # GPU warm up
+
+ tic = time.time()
+ p_bbox, p_labels, p_scores = run_inference(model, model_input)
+ toc = time.time()
+ print(f"Inference latency: {(toc - tic)*1000} ms")
+
+ image = image.astype(np.float32) / 255
+ image = numpy_bbox_to_image(image, p_bbox, p_labels, p_scores, COCO_CLASS_NAME)
+
+ cv2.imshow("image", image)
+ cv2.waitKey(0)
+ cv2.destroyAllWindows()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--engine_path", type=str, required=True)
+ args = parser.parse_args()
+ if "deformable" in args.engine_path:
+ # Load custom plugin for deformable-detr
+ MS_DEFORM_IM2COL_PLUGIN_LIB = "./detr_tensorrt/plugins/ms_deform_im2col/build/libms_deform_im2col_trt.so"
+ ctypes.CDLL(MS_DEFORM_IM2COL_PLUGIN_LIB)
+ trt.init_libnvinfer_plugins(TRT_LOGGER, '')
+ PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list
+ main(**vars(args))
\ No newline at end of file