diff --git a/.gitignore b/.gitignore
index c98f0fb8..20d14a0f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,3 +130,6 @@ dmypy.json
 # Pyre type checker
 .pyre/
+
+.vscode
+detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o
diff --git a/README.md b/README.md
index 35c8b087..cfc32882 100644
--- a/README.md
+++ b/README.md
@@ -7,8 +7,9 @@ Tensorflow implementation of DETR : Object Detection with Transformers, includin
 * [3. Tutorials](#tutorials)
 * [4. Finetuning](#finetuning)
 * [5. Training](#training)
-* [5. inference](#inference)
-* [6. Acknowledgement](#acknowledgement)
+* [6. Inference](#inference)
+* [7. Inference with TensorRT](#inference-with-tensorrt)
+* [8. Acknowledgement](#acknowledgement)
 
 DETR paper: https://arxiv.org/pdf/2005.12872.pdf
@@ -152,6 +153,82 @@ python webcam_inference.py
+## Inference with TensorRT
+
+### Requirements:
+```
+cmake >= 3.8
+TensorRT 8
+```
+To install TensorRT 8, follow the [NVIDIA TensorRT official installation guide](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html).
+
+Python package requirements:
+```
+onnx
+tf2onnx
+```
+
+### Custom plugin for Deformable DETR
+Deformable DETR uses a custom Im2Col operation in its Transformer layers. This operation is not supported by TensorRT, so we need to build a TensorRT custom plugin from source.
+
+```
+cd detr_tensorrt/plugins/ms_deform_im2col
+mkdir build && cd build
+cmake .. \
+    -DTRT_LIB=/path/to/tensorrt/lib/ \
+    -DTRT_INCLUDE=/path/to/tensorrt/include/ \
+    -DCUDA_ARCH_SM=/your_gpu_cuda_arch/
+make -j
+```
+For more detail, see: `detr_tensorrt/plugins/ms_deform_im2col/README.txt`
+
+Parameters:
+- `-DTRT_LIB`: Path to the TensorRT libraries. It could be `YOUR_TENSORRT_DIR/lib` or `/usr/lib/x86_64-linux-gnu`
+- `-DTRT_INCLUDE`: Path to the TensorRT C++ headers. It could be `YOUR_TENSORRT_DIR/include` or `/usr/include/x86_64-linux-gnu`
+- `-DCUDA_ARCH_SM`: Compute capability of your NVIDIA GPU. Example: `70` for Tesla V100. Check [here](https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/) for other GPUs.
+
+### Workflow
+Tensorflow model --> ONNX --> TensorRT serialized engine
+
+#### Export Tensorflow graph to ONNX graph:
+
+For each model (detr/deformable-detr), run:
+```
+python3 detr_tensorrt/export_onnx.py MODEL_NAME
+        [--input_shape H W]
+        [--save_to DIR_TO_SAVE_ONNX_FILE]
+```
+Parameters:
+- `--input_shape`: image height and width, default: 1280 1920
+- `--save_to`: directory where the ONNX file will be saved. Default: `./weights/MODEL_NAME/MODEL_NAME_trt/`
+
+#### Convert ONNX model to TensorRT serialized engine:
+```
+python3 detr_tensorrt/onnx2engine.py MODEL_NAME
+        [--precision PRECISION]
+        [--onnx_dir ONNX_DIR]
+        [--verbose]
+```
+Parameters:
+- `--precision`: precision of model weights: FP32, FP16 or MIX. MIX precision gives TensorRT the freedom to run each layer in either FP32 or FP16. In most cases there is no significant difference in inference time between FP16 and MIX.
+- `--onnx_dir`: directory containing the ONNX file to be converted to a TensorRT engine. The ONNX file must be named `MODEL_NAME.onnx`. Default: `./weights/MODEL_NAME/MODEL_NAME_trt/`
+- `--verbose`: print TensorRT logs of all levels
+
+The TensorRT serialized engine will be saved in `ONNX_DIR/MODEL_NAME_PRECISION.engine`.
+
+### Run inference
+An example of inference with a test image: `images/test.jpeg`
+
+```
+python tensorrt_inference.py --engine_path ENGINE_PATH
+```
+
+Inference time in milliseconds:
+|               | DETR | Deformable DETR |
+|---------------|------|-----------------|
+| Tensorflow    | 100  | 160             |
+| TensorRT FP32 | 60   | 100             |
+| TensorRT FP16 | 27   | 60              |
 
 ## Acknowledgement
diff --git a/detr_tensorrt/TRTEngineBuilder.py b/detr_tensorrt/TRTEngineBuilder.py
new file mode 100644
index 00000000..cd82c35e
--- /dev/null
+++ b/detr_tensorrt/TRTEngineBuilder.py
@@ -0,0 +1,87 @@
+import tensorrt as trt
+import os
+
+TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+network_creation_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+
+def GiB(val):
+    return val * 1 << 30
+
+class TRTEngineBuilder():
+    """
+    Works with TensorRT 8. Should also work with TensorRT 7.2.3 (not tested).
+
+    Helper class to build a TensorRT engine from an ONNX graph file (including weights).
+    The graph must have a defined input shape.
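+
+    A minimal usage sketch (the ONNX and engine paths are illustrative assumptions):
+
+        builder = TRTEngineBuilder("weights/detr/detr_trt/detr.onnx", FP16_allowed=True)
+        builder.export_engine("weights/detr/detr_trt/detr_fp16.engine")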
+
+    For more detail, please see the TensorRT Developer Guide:
+    https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#python_topics
+    """
+    def __init__(self, onnx_file_path, FP16_allowed=False, INT8_allowed=False, strict_type=False, calibrator=None, logger=TRT_LOGGER):
+        """
+        Parameters:
+        -----------
+        onnx_file_path: str
+            path to ONNX graph file
+        FP16_allowed: bool
+            Enable FP16 precision for engine builder
+        INT8_allowed: bool
+            Enable INT8 precision for engine builder; the user must also provide a calibrator
+        strict_type: bool
+            Force the builder to respect the requested precision flags
+        calibrator: extended instance from tensorrt.IInt8Calibrator
+            Used for INT8 quantization
+        """
+        self.FP16_allowed = FP16_allowed
+        self.INT8_allowed = INT8_allowed
+        self.onnx_file_path = onnx_file_path
+        self.calibrator = calibrator
+        self.max_workspace_size = GiB(8)
+        self.strict_type = strict_type
+        self.logger = logger
+
+    def set_workspace_size(self, workspace_size_GiB):
+        self.max_workspace_size = GiB(workspace_size_GiB)
+
+    def get_engine(self):
+        """
+        Set up the engine builder, read the ONNX graph and build the TensorRT engine.
+        """
+        global network_creation_flag
+        with trt.Builder(self.logger) as builder, builder.create_network(network_creation_flag) as network, trt.OnnxParser(network, self.logger) as parser:
+            builder.max_batch_size = 1
+            config = builder.create_builder_config()
+            config.max_workspace_size = self.max_workspace_size
+            # FP16
+            if self.FP16_allowed:
+                config.set_flag(trt.BuilderFlag.FP16)
+            # INT8
+            if self.INT8_allowed:
+                raise NotImplementedError()
+            if self.strict_type:
+                config.set_flag(trt.BuilderFlag.STRICT_TYPES)
+
+            # Load and build model
+            with open(self.onnx_file_path, 'rb') as model:
+                if not parser.parse(model.read()):
+                    print('ERROR: Failed to parse the ONNX file.')
+                    for error in range(parser.num_errors):
+                        print(parser.get_error(error))
+                    return None
+                else:
+                    print("ONNX file is loaded")
+            print("Building engine...")
+            engine = builder.build_engine(network, config)
+            if engine is None:
+                raise Exception("TRT export engine error. Check log")
+            print("Engine built")
+            return engine
+
+    def export_engine(self, engine_path):
+        """Serialize the TensorRT engine"""
+        engine = self.get_engine()
+        assert engine is not None, "Error while parsing engine from ONNX"
+        with open(engine_path, "wb") as f:
+            print("Serialize and save as engine: " + engine_path)
+            f.write(engine.serialize())
+        print("Engine exported")
+
+
diff --git a/detr_tensorrt/TRTExecutor.py b/detr_tensorrt/TRTExecutor.py
new file mode 100644
index 00000000..8fce6cc7
--- /dev/null
+++ b/detr_tensorrt/TRTExecutor.py
@@ -0,0 +1,131 @@
+import ctypes
+import pycuda.autoinit as cuda_init
+from detr_tensorrt.trt_helper import *
+import tensorrt as trt
+
+TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+# trt.init_libnvinfer_plugins(None, "")
+
+class TRTExecutor():
+    """
+    A helper class to execute a TensorRT engine.
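+
+    A minimal usage sketch (engine path and dummy input are illustrative assumptions):
+
+        model = TRTExecutor(engine_path="detr_fp16.engine")
+        model.inputs[0].host = np.zeros((1, 1280, 1920, 3), dtype=np.float32)
+        model.execute()
+        predictions = model.outputs[0].host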
+
+    Attributes:
+    -----------
+    stream: pycuda.driver.Stream
+    engine: tensorrt.ICudaEngine
+    context: tensorrt.IExecutionContext
+    inputs/outputs: list[HostDeviceMem]
+        see trt_helper.py
+    bindings: list[int]
+        pointers in GPU for each input/output of the engine
+    dict_inputs/dict_outputs: dict[str, HostDeviceMem]
+        key = input node name
+        value = HostDeviceMem of corresponding binding
+
+    """
+    def __init__(self, engine_path=None, has_dynamic_shape=False, stream=None, engine=None):
+        """
+        Parameters:
+        ----------
+        engine_path: str
+            path to serialized TensorRT engine
+        has_dynamic_shape: bool
+        stream: pycuda.driver.Stream
+            if None, one will be created by the allocate_buffers function
+        """
+        self.stream = stream
+        if engine_path is not None:
+            with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
+                print("Reading engine ...")
+                self.engine = runtime.deserialize_cuda_engine(f.read())
+                assert self.engine is not None, "Read engine failed"
+                print("Engine loaded")
+        elif engine is not None:
+            self.engine = engine
+        self.context = self.engine.create_execution_context()
+        if not has_dynamic_shape:
+            self.inputs, self.outputs, self.bindings, self.stream = allocate_buffers(self.context, self.stream)
+            self.dict_inputs = {mem_obj.name:mem_obj for mem_obj in self.inputs}
+            self.dict_outputs = {mem_obj.name:mem_obj for mem_obj in self.outputs}
+
+    def print_bindings_info(self):
+        print("ID / Name / isInput / shape / dtype")
+        for i in range(self.engine.num_bindings):
+            print(f"Binding: {i}, name: {self.engine.get_binding_name(i)}, input: {self.engine.binding_is_input(i)}, shape: {self.engine.get_binding_shape(i)}, dtype: {self.engine.get_binding_dtype(i)}")
+
+    def execute(self):
+        do_inference_async(
+            self.context,
+            bindings=self.bindings,
+            inputs=self.inputs,
+            outputs=self.outputs,
+            stream=self.stream
+        )
+
+    def set_binding_shape(self, binding: int, shape: tuple):
+        self.context.set_binding_shape(binding, shape)
+
+    def allocate_mem(self):
+        self.inputs, self.outputs, self.bindings, self.stream = allocate_buffers(self.context, self.stream)
+        self.dict_inputs = {mem_obj.name:mem_obj for mem_obj in self.inputs}
+        self.dict_outputs = {mem_obj.name:mem_obj for mem_obj in self.outputs}
+
+class TRTExecutor_Sync():
+    """
+    A helper class to execute a TensorRT engine in synchronous mode.
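+
+    Same interface as TRTExecutor, but execute() blocks until the results are
+    ready. A minimal sketch (the engine path is an illustrative assumption):
+
+        model = TRTExecutor_Sync(engine_path="detr_fp16.engine")
+        model.execute()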
+ + Attributes: + ----------- + engine: tensorrt.ICudaEngine + context: tensorrt.IExecutionContext + inputs/outputs: list[HostDeviceMem] + see trt_helper.py + bindings: list[int] + pointers in GPU for each input/output of the engine + dict_inputs/dict_outputs: dict[str, HostDeviceMem] + key = input node name + value = HostDeviceMem of corresponding binding + + """ + def __init__(self, engine_path=None, has_dynamic_shape=False, engine=None): + """ + Parameters: + ---------- + engine_path: str + path to serialized TensorRT engine + has_dynamic_shape: bool + """ + if engine_path is not None: + with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime: + print("Reading engine ...") + self.engine = runtime.deserialize_cuda_engine(f.read()) + assert self.engine is not None, "Read engine failed" + print("Engine loaded") + elif engine is not None: + self.engine = engine + self.context = self.engine.create_execution_context() + if not has_dynamic_shape: + self.inputs, self.outputs, self.bindings, self.stream = allocate_buffers(self.context, is_async=False) + self.dict_inputs = {mem_obj.name:mem_obj for mem_obj in self.inputs} + self.dict_outputs = {mem_obj.name:mem_obj for mem_obj in self.outputs} + + def print_bindings_info(self): + print("ID / Name / isInput / shape / dtype") + for i in range(self.engine.num_bindings): + print(f"Binding: {i}, name: {self.engine.get_binding_name(i)}, input: {self.engine.binding_is_input(i)}, shape: {self.engine.get_binding_shape(i)}, dtype: {self.engine.get_binding_dtype(i)}") + + def execute(self): + do_inference( + self.context, + bindings=self.bindings, + inputs=self.inputs, + outputs=self.outputs, + ) + + def set_binding_shape(self, binding:int, shape:tuple): + self.context.set_binding_shape(binding, shape) + + + + diff --git a/detr_tensorrt/common.py b/detr_tensorrt/common.py new file mode 100644 index 00000000..05a496cc --- /dev/null +++ b/detr_tensorrt/common.py @@ -0,0 +1,199 @@ +# +# Copyright 1993-2020 NVIDIA Corporation. All rights reserved. +# +# NOTICE TO LICENSEE: +# +# This source code and/or documentation ("Licensed Deliverables") are +# subject to NVIDIA intellectual property rights under U.S. and +# international Copyright laws. +# +# These Licensed Deliverables contained herein is PROPRIETARY and +# CONFIDENTIAL to NVIDIA and is being provided under the terms and +# conditions of a form of NVIDIA software license agreement by and +# between NVIDIA and Licensee ("License Agreement") or electronically +# accepted by Licensee. Notwithstanding any terms or conditions to +# the contrary in the License Agreement, reproduction or disclosure +# of the Licensed Deliverables to any third party without the express +# written consent of NVIDIA is prohibited. +# +# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE +# LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE +# SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS +# PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. +# NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED +# DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, +# NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 
+# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE +# LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY +# SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY +# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS +# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE +# OF THESE LICENSED DELIVERABLES. +# +# U.S. Government End Users. These Licensed Deliverables are a +# "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT +# 1995), consisting of "commercial computer software" and "commercial +# computer software documentation" as such terms are used in 48 +# C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government +# only as a commercial end item. Consistent with 48 C.F.R.12.212 and +# 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all +# U.S. Government End Users acquire the Licensed Deliverables with +# only those rights set forth herein. +# +# Any use of the Licensed Deliverables in individual and commercial +# software must include, in the user documentation and internal +# comments to the code, the above Disclaimer and U.S. Government End +# Users Notice. +# + +from itertools import chain +import argparse +import os + +import pycuda.driver as cuda +import pycuda.autoinit +import numpy as np + +import tensorrt as trt + +try: + # Sometimes python2 does not understand FileNotFoundError + FileNotFoundError +except NameError: + FileNotFoundError = IOError + +EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + +def GiB(val): + return val * 1 << 30 + + +def add_help(description): + parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter) + args, _ = parser.parse_known_args() + + +def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]): + ''' + Parses sample arguments. + + Args: + description (str): Description of the sample. + subfolder (str): The subfolder containing data relevant to this sample + find_files (str): A list of filenames to find. Each filename will be replaced with an absolute path. + + Returns: + str: Path of data directory. + ''' + + # Standard command-line arguments for all samples. + kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data") + parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("-d", "--datadir", help="Location of the TensorRT sample data directory, and any additional data directories.", action="append", default=[kDEFAULT_DATA_ROOT]) + args, _ = parser.parse_known_args() + + def get_data_path(data_dir): + # If the subfolder exists, append it to the path, otherwise use the provided path as-is. + data_path = os.path.join(data_dir, subfolder) + if not os.path.exists(data_path): + print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.") + data_path = data_dir + # Make sure data directory exists. + if not (os.path.exists(data_path)): + print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path)) + return data_path + + data_paths = [get_data_path(data_dir) for data_dir in args.datadir] + return data_paths, locate_files(data_paths, find_files) + +def locate_files(data_paths, filenames): + """ + Locates the specified files in the specified data directories. 
+ If a file exists in multiple data directories, the first directory is used. + + Args: + data_paths (List[str]): The data directories. + filename (List[str]): The names of the files to find. + + Returns: + List[str]: The absolute paths of the files. + + Raises: + FileNotFoundError if a file could not be located. + """ + found_files = [None] * len(filenames) + for data_path in data_paths: + # Find all requested files. + for index, (found, filename) in enumerate(zip(found_files, filenames)): + if not found: + file_path = os.path.abspath(os.path.join(data_path, filename)) + if os.path.exists(file_path): + found_files[index] = file_path + + # Check that all files were found + for f, filename in zip(found_files, filenames): + if not f or not os.path.exists(f): + raise FileNotFoundError("Could not find {:}. Searched in data paths: {:}".format(filename, data_paths)) + return found_files + +# Simple helper data class that's a little nicer to use than a 2-tuple. +class HostDeviceMem(object): + def __init__(self, host_mem, device_mem): + self.host = host_mem + self.device = device_mem + + def __str__(self): + return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) + + def __repr__(self): + return self.__str__() + +# Allocates all buffers required for an engine, i.e. host/device inputs/outputs. +def allocate_buffers(engine): + inputs = [] + outputs = [] + bindings = [] + stream = cuda.Stream() + for binding in engine: + size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size + dtype = trt.nptype(engine.get_binding_dtype(binding)) + # Allocate host and device buffers + host_mem = cuda.pagelocked_empty(size, dtype) + device_mem = cuda.mem_alloc(host_mem.nbytes) + # Append the device buffer to device bindings. + bindings.append(int(device_mem)) + # Append to the appropriate list. + if engine.binding_is_input(binding): + inputs.append(HostDeviceMem(host_mem, device_mem)) + else: + outputs.append(HostDeviceMem(host_mem, device_mem)) + return inputs, outputs, bindings, stream + +# This function is generalized for multiple inputs/outputs. +# inputs and outputs are expected to be lists of HostDeviceMem objects. +def do_inference(context, bindings, inputs, outputs, stream, batch_size=1): + # Transfer input data to the GPU. + [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] + # Run inference. + context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle) + # Transfer predictions back from the GPU. + [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] + # Synchronize the stream + stream.synchronize() + # Return only the host outputs. + return [out.host for out in outputs] + +# This function is generalized for multiple inputs/outputs for full dimension networks. +# inputs and outputs are expected to be lists of HostDeviceMem objects. +def do_inference_v2(context, bindings, inputs, outputs, stream): + # Transfer input data to the GPU. + [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] + # Run inference. + context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) + # Transfer predictions back from the GPU. + [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] + # Synchronize the stream + stream.synchronize() + # Return only the host outputs. 
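+    # Device buffers (and the stream) stay allocated in the HostDeviceMem
+    # objects and are reused on the next call; only the pinned host arrays
+    # are handed back to the caller.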
+    return [out.host for out in outputs]
diff --git a/detr_tensorrt/export_onnx.py b/detr_tensorrt/export_onnx.py
new file mode 100644
index 00000000..c5f9acc6
--- /dev/null
+++ b/detr_tensorrt/export_onnx.py
@@ -0,0 +1,89 @@
+import tensorflow as tf
+import numpy as np
+import os
+import onnx
+import tf2onnx
+from pathlib import Path
+
+from detr_tf.training_config import TrainingConfig, training_config_parser
+
+from detr_tf.networks.detr import get_detr_model
+from detr_tf.networks.deformable_detr import get_deformable_detr_model
+
+def get_model(config, args):
+    if args.model == "detr":
+        print("Loading detr...")
+        # Load the model with the new layers to finetune
+        model = get_detr_model(config, include_top=True, weights="detr")
+        config.background_class = 91
+        m_output_names = ["pred_logits", "pred_boxes"]
+        use_mask = True
+        # model.summary()
+        # return model
+    elif args.model == "deformable-detr":
+        print("Loading deformable-detr...")
+        model = get_deformable_detr_model(config, include_top=True, weights="deformable-detr")
+        m_output_names = ["bbox_pred_logits", "bbox_pred_boxes"]
+        # [print(name, model.output[name]) for name in model.output]
+        # model.summary()
+        use_mask = False
+    else:
+        raise NotImplementedError()
+    # Remove auxiliary outputs
+    input_image = tf.keras.Input(args.input_shape, batch_size=1, name="input_image")
+    if use_mask:
+        mask = tf.keras.Input(args.input_shape[:2] + [1], batch_size=1, name="input_mask")
+        m_inputs = (input_image, mask)
+    else:
+        m_inputs = (input_image, )
+    all_outputs = model(m_inputs, training=False)
+
+    m_outputs = {
+        name:tf.identity(all_outputs[name], name=name)
+        for name in m_output_names if name in all_outputs}
+    [print(m_outputs[name]) for name in m_outputs]
+
+    model = tf.keras.Model(m_inputs, m_outputs, name=args.model)
+    model.summary()
+    return model
+
+
+if __name__ == "__main__":
+    physical_devices = tf.config.list_physical_devices('GPU')
+    if len(physical_devices) == 1:
+        tf.config.experimental.set_memory_growth(physical_devices[0], True)
+
+    config = TrainingConfig()
+    parser = training_config_parser()
+    parser.add_argument("model", type=str, default="deformable-detr", help="One of 'detr', or 'deformable-detr'")
+    parser.add_argument("--input_shape", type=int, default=[1280, 1920], nargs=2, help="ex: 1280 1920")
+    parser.add_argument('--save_to', type=str, default=None, help="Path to save ONNX file")
+    args = parser.parse_args()
+    config.update_from_args(args)
+
+    args.input_shape.append(3)  # C = 3
+
+    if args.save_to is None:
+        args.save_to = os.path.join("weights", args.model, args.model + "_trt")
+
+    # === Load model
+    model = get_model(config, args)
+    # === Create the directory where the ONNX file will be saved
+    if not os.path.isdir(args.save_to):
+        os.makedirs(args.save_to)
+
+    # === Save onnx file
+    input_spec = [tf.TensorSpec.from_tensor(tensor) for tensor in model.input]
+    # print(input_spec)
+    output_path = os.path.join(args.save_to, args.model + ".onnx")
+    model_proto, _ = tf2onnx.convert.from_keras(
+        model, input_signature=input_spec,
+        opset=13, output_path=output_path)
+    print("===== Inputs =======")
+    [print(n.name) for n in model_proto.graph.input]
+    print("===== Outputs =======")
+    [print(n.name) for n in model_proto.graph.output]
+
+
+
diff --git a/detr_tensorrt/inference.py b/detr_tensorrt/inference.py
new file mode 100644
index 00000000..f4f215c1
--- /dev/null
+++ b/detr_tensorrt/inference.py
@@ -0,0 +1,76 @@
+import numpy as np
+import os
+import cv2
+import argparse
+
+from detr_tf import bbox
+from detr_tensorrt.TRTExecutor import TRTExecutor
+from scipy.special import softmax
+
+def normalized_images(image, normalized_method="torch_resnet"):
+    """ Normalize images. torch_resnet is used for finetuning,
+    since the weights are based on the original paper's training code
+    in pytorch. tf_resnet is used when training from scratch with a
+    resnet50 trained on tensorflow.
+    """
+    if normalized_method == "torch_resnet":
+        channel_avg = np.array([0.485, 0.456, 0.406])
+        channel_std = np.array([0.229, 0.224, 0.225])
+        image = (image / 255.0 - channel_avg) / channel_std
+        return image.astype(np.float32)
+    elif normalized_method == "tf_resnet":
+        mean = [103.939, 116.779, 123.68]
+        image = image[..., ::-1]
+        image = image - mean
+        return image.astype(np.float32)
+    else:
+        raise Exception("Can't handle this normalization method")
+
+def sigmoid(x):
+    return 1/(1 + np.exp(-x))
+
+def get_model_inference(m_outputs: dict, background_class, bbox_format="xy_center", threshold=None):
+
+    #print('get model inference', [key for key in m_outputs])
+
+    # Detr or deformable
+    predicted_bbox = m_outputs["pred_boxes"][0] if "pred_boxes" in m_outputs else m_outputs["bbox_pred_boxes"][0]
+    predicted_labels = m_outputs["pred_logits"][0] if "pred_logits" in m_outputs else m_outputs["bbox_pred_logits"][0]
+    activation = "softmax" if "pred_boxes" in m_outputs else "sigmoid"
+
+    if activation == "softmax": # Detr
+        softmax_scores = softmax(predicted_labels, axis=-1)
+        predicted_scores = np.max(softmax_scores, axis=-1)
+        predicted_labels = np.argmax(softmax_scores, axis=-1)
+        bool_filter = predicted_labels != background_class
+    else: # Deformable Detr
+        sigmoid_scores = sigmoid(predicted_labels)
+        predicted_scores = np.max(sigmoid_scores, axis=-1)
+        predicted_labels = np.argmax(sigmoid_scores, axis=-1)
+        threshold = 0.1 if threshold is None else threshold
+        bool_filter = predicted_scores > threshold
+
+    predicted_scores = predicted_scores[bool_filter]
+    predicted_labels = predicted_labels[bool_filter]
+    predicted_bbox = predicted_bbox[bool_filter]
+
+    if bbox_format == "xy_center":
+        predicted_bbox = predicted_bbox
+    elif bbox_format == "xyxy":
+        predicted_bbox = bbox.xcycwh_to_xy_min_xy_max(predicted_bbox)
+    elif bbox_format == "yxyx":
+        predicted_bbox = bbox.xcycwh_to_yx_min_yx_max(predicted_bbox)
+    else:
+        raise NotImplementedError()
+
+    return predicted_bbox, predicted_labels, predicted_scores
+
+def main(engine_path):
+    pass
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--engine_path", type=str, required=True)
+    args = parser.parse_args()
+    main(**vars(args))
\ No newline at end of file
diff --git a/detr_tensorrt/onnx2engine.py b/detr_tensorrt/onnx2engine.py
new file mode 100644
index 00000000..c831acf4
--- /dev/null
+++ b/detr_tensorrt/onnx2engine.py
@@ -0,0 +1,250 @@
+import argparse
+import ctypes
+import tensorrt as trt
+import os
+import onnx
+import numpy as np
+import onnx_graphsurgeon as gs
+
+from TRTEngineBuilder import TRTEngineBuilder, TRT_LOGGER
+from common import GiB
+
+
+def print_graph_io(graph):
+    # Print inputs:
+    print(" ===== Inputs =====")
+    for i in graph.inputs:
+        print(i)
+    # Print outputs:
+    print(" ===== Outputs =====")
+    for i in graph.outputs:
+        print(i)
+
+
+def io_name_handler(graph: gs.Graph):
+    len_suffix = len("tf_op_layer_")
+    for out in graph.outputs:
+        out.name = out.name[len_suffix:]
+
+
+def get_node_by_name(name, onnx_graph: gs.Graph):
+    for n in onnx_graph.nodes:
+        if name in n.name:
+            return n
+    return None
+
+
+def get_nodes_by_op(op_name, onnx_graph):
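+    """Return every node in the graph whose op type matches op_name."""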
+    nodes = []
+    for n in onnx_graph.nodes:
+        if n.op == op_name:
+            nodes.append(n)
+    return nodes
+
+def get_nodes_by_prefix(prefix, onnx_graph: gs.Graph):
+    nodes = []
+    for n in onnx_graph.nodes:
+        if n.name.startswith(prefix):
+            nodes.append(n)
+    return nodes
+
+
+def fix_graph_detr(graph: gs.Graph):
+    # === Fix Pad 2 in Resnet backbone ===
+    # TensorRT supports padding only on the 2 innermost dimensions
+    resnet_pad2 = get_node_by_name(
+        "detr/detr_finetuning/detr/backbone/pad2/Pad", graph)
+    resnet_pad2.inputs[1] = gs.Constant(
+        "pad2/pads_input", np.array([0, 0, 1, 1, 0, 0, 1, 1]))
+    graph.cleanup()
+    graph.toposort()
+    return graph
+
+
+def fix_graph_deformable_detr(graph: gs.Graph):
+    batch_size = graph.inputs[0].shape[0]
+    # === Fix Pad 2 in Resnet backbone ===
+    # TensorRT supports padding only on the 2 innermost dimensions
+    resnet_pad2 = get_node_by_name(
+        "deformable-detr/deformable_detr/detr_core/backbone/pad2/Pad", graph)
+    unused_nodes = [resnet_pad2.i(1), resnet_pad2.i(1).i()]
+    resnet_pad2.inputs[1] = gs.Constant(
+        "pad2/pads_input", np.array([0, 0, 1, 1, 0, 0, 1, 1]))
+    for n in unused_nodes:
+        graph.nodes.remove(n)
+
+    # ======= Add nodes for MsDeformIm2ColTRT ===========
+    tf_im2col_nodes = get_nodes_by_op("MsDeformIm2col", graph)
+
+    spatial_shape_np = tf_im2col_nodes[0].inputs[1].values.reshape((1, -1, 2))
+    spatial_shape_tensor = gs.Constant(
+        name="MsDeformIm2Col_spatial_shape",
+        values=spatial_shape_np)
+
+    start_index_np = tf_im2col_nodes[0].inputs[2].values.reshape((1, -1))
+    start_index_tensor = gs.Constant(
+        name="MsDeformIm2Col_start_index",
+        values=start_index_np)
+
+    def handle_ops_MsDeformIm2ColTRT(graph: gs.Graph, node: gs.Node):
+        inputs = node.inputs
+        inputs.pop(1)
+        inputs.pop(1)
+        inputs.insert(1, start_index_tensor)
+        inputs.insert(1, spatial_shape_tensor)
+        outputs = node.outputs
+        graph.layer(
+            op="MsDeformIm2ColTRT",
+            name=node.name + "_trt",
+            inputs=inputs,
+            outputs=outputs)
+
+    for n in tf_im2col_nodes:
+        handle_ops_MsDeformIm2ColTRT(graph, n)
+        # Detach old node from graph
+        n.inputs.clear()
+        n.outputs.clear()
+        graph.nodes.remove(n)
+
+    # ======= Handle GroupNorm with the official TensorRT plugin =======
+    gn_nodes = []
+    for i in range(4):
+        gn_nodes.append(
+            get_nodes_by_prefix(
+                f"deformable-detr/deformable_detr/detr_core/input_proj_gn/{i}", graph))
+
+    def handle_group_norm_nodes(nodes, graph: gs.Graph):
+        # Get GN name
+        gn_name = nodes[0].name[:-7]
+        # Get GN input tensors
+        gn_input = nodes[0].i().inputs[0]
+        # Get gamma input
+        mul_node = None
+        for n in nodes:
+            if n.name.endswith("/mul"):
+                mul_node = n
+        assert mul_node is not None
+        gamma_input = gs.Constant(
+            name=gn_name + "gamma:0",
+            values=mul_node.inputs[1].values.reshape((batch_size, -1)))
+        # Get beta input
+        sub_node = None
+        for n in nodes:
+            if n.name.endswith("batchnorm/sub"):
+                sub_node = n
+        assert sub_node is not None
+        beta_input = gs.Constant(
+            name=gn_name+"beta:0",
+            values=sub_node.inputs[0].values.reshape((batch_size, -1)))
+        # Get output tensor
+        gn_output = nodes[-1].outputs[0]
+        # print(gn_output)
+        # Add new plugin node to graph
+        graph.layer(
+            name=gn_name + "group_norm_trt",
+            inputs=[gn_input, gamma_input, beta_input],
+            outputs=[gn_output],
+            op="GroupNormalizationPlugin",
+            attrs={
+                "eps": 1e-5,
+                "num_groups": 32
+            })
+        # Detach gn_output from existing graph
+        gn_out_flatten = gn_output.outputs[0]
+        gn_out_flatten.inputs.pop(0)
+        # Add Transpose
+        transposed_tensor = graph.layer(
+            name=gn_name+"gn_out_transpose",
+            inputs=[gn_output],
+            outputs=[gn_name + "input_proj_flatten:0"],
+            op="Transpose",
+            attrs={"perm": [0, 2, 3, 1]}
+        )
+        gn_out_flatten.inputs.insert(0, transposed_tensor[0])
+        # Disconnect old nodes
+        nodes.insert(0, nodes[0].i())  # for cleanup purposes
+        for n in nodes:
+            n.inputs.clear()
+            n.outputs.clear()
+            graph.nodes.remove(n)
+
+    for nodes in gn_nodes:
+        handle_group_norm_nodes(nodes, graph)
+
+    return graph
+
+
+def fix_onnx_graph(graph: gs.Graph, model: str):
+    if model == "detr":
+        return fix_graph_detr(graph)
+    elif model == "deformable-detr":
+        return fix_graph_deformable_detr(graph)
+
+
+def main(onnx_dir: str, model: str, precision: str, verbose: bool, **args):
+    print(model)
+    onnx_path = os.path.join(onnx_dir, model + ".onnx")
+    print(onnx_path)
+
+    graph = gs.import_onnx(onnx.load(onnx_path))
+    graph.toposort()
+
+    # === Change graph IO names
+    # print_graph_io(graph)
+    io_name_handler(graph)
+    print_graph_io(graph)
+
+    # === Fix graph to adapt to TensorRT
+    graph = fix_onnx_graph(graph, model)
+
+    # === Export adapted onnx for TRT engine
+    adapted_onnx_path = os.path.join(onnx_dir, model + "_trt.onnx")
+    onnx.save(gs.export_onnx(graph), adapted_onnx_path)
+
+    # === Build engine
+    if verbose:
+        trt_logger = trt.Logger(trt.Logger.VERBOSE)
+    else:
+        trt_logger = trt.Logger(trt.Logger.WARNING)
+
+    builder = TRTEngineBuilder(adapted_onnx_path, logger=trt_logger)
+
+    if precision == "FP32":
+        pass
+    elif precision == "FP16":
+        builder.FP16_allowed = True
+        builder.strict_type = True
+    elif precision == "MIX":
+        builder.FP16_allowed = True
+        builder.strict_type = False
+
+    builder.export_engine(os.path.join(
+        onnx_dir, model + f"_{precision.lower()}.engine"))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('model', type=str, default="detr",
+                        help="detr/deformable-detr")
+    parser.add_argument("--precision", type=str,
+                        default="FP16", help="FP32/FP16/MIX")
+    parser.add_argument('--onnx_dir', type=str, default=None,
+                        help="path to the dir containing the {model_name}.onnx file")
+    parser.add_argument("--verbose", action="store_true",
+                        help="Print out TensorRT logs of all levels")
+    args = parser.parse_args()
+
+    if "deformable" in args.model:
+        MS_DEFORM_IM2COL_PLUGIN_LIB = "./detr_tensorrt/plugins/ms_deform_im2col/build/libms_deform_im2col_trt.so"
+        ctypes.CDLL(MS_DEFORM_IM2COL_PLUGIN_LIB)
+        trt.init_libnvinfer_plugins(TRT_LOGGER, '')
+        PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list
+
+    if args.onnx_dir is None:
+        args.onnx_dir = os.path.join(
+            "weights", args.model, args.model + "_trt")
+    # for plugin in PLUGIN_CREATORS:
+    #     print(plugin.name, plugin.plugin_version)
+    main(**vars(args))
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/CMakeLists.txt b/detr_tensorrt/plugins/ms_deform_im2col/CMakeLists.txt
new file mode 100644
index 00000000..e22e58b3
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/CMakeLists.txt
@@ -0,0 +1,58 @@
+# We need cmake >= 3.8, since 3.8 introduced CUDA as a first class language
+cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
+project(MsDeformIm2ColTRT LANGUAGES CXX CUDA)
+
+# Enable all compile warnings
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-long-long -pedantic -Wno-deprecated-declarations")
+
+# Sets variable to a value if variable is unset.
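+# Usage: set_ifndef(VAR default). Values can still be overridden on the
+# command line, e.g. `cmake .. -DTRT_LIB=/opt/tensorrt/lib` (illustrative path).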
+macro(set_ifndef var val) + if (NOT ${var}) + set(${var} ${val}) + endif() + message(STATUS "Configurable variable ${var} set to ${${var}}") +endmacro() + +# -------- CONFIGURATION -------- +set_ifndef(TRT_LIB /home/ubuntu/TensorRT-8.0.0.3/lib) +set_ifndef(TRT_INCLUDE /home/ubuntu/TensorRT-8.0.0.3/include) +set_ifndef(CUDA_INC_DIR /usr/local/cuda/include) +set_ifndef(CUDA_ARCH_SM 70) # should be fine for Tesla V100 + +# Find dependencies: +message("\nThe following variables are derived from the values of the previous variables unless provided explicitly:\n") + +# TensorRT's nvinfer lib +find_library(_NVINFER_LIB nvinfer HINTS ${TRT_LIB} PATH_SUFFIXES lib lib64) +set_ifndef(NVINFER_LIB ${_NVINFER_LIB}) + + +# -------- BUILDING -------- + +# Add include directories +include_directories(${CUDA_INC_DIR} ${TRT_INCLUDE} ${CMAKE_SOURCE_DIR}/sources/) +message(STATUS "CUDA_INC_DIR: ${CUDA_INC_DIR}") +# Define plugin library target +add_library(ms_deform_im2col_trt MODULE +${CMAKE_SOURCE_DIR}/sources/ms_deform_im2col_kernel.cu +${CMAKE_SOURCE_DIR}/sources/ms_deform_im2col_kernel.h +${CMAKE_SOURCE_DIR}/sources/ms_deform_im2col_plugin.cpp +${CMAKE_SOURCE_DIR}/sources/ms_deform_im2col_plugin.h +) + +# Use C++11 +target_compile_features(ms_deform_im2col_trt PUBLIC cxx_std_11) + +# Link TensorRT's nvinfer lib +target_link_libraries(ms_deform_im2col_trt PRIVATE ${NVINFER_LIB}) + +# We need to explicitly state that we need all CUDA files +# to be built with -dc as the member functions will be called by +# other libraries and executables (in our case, Python inference scripts) +set_target_properties(ms_deform_im2col_trt PROPERTIES +CUDA_SEPARABLE_COMPILATION ON +) + +# CUDA ARCHITECTURE +set_target_properties(ms_deform_im2col_trt PROPERTIES +CUDA_ARCHITECTURES "${CUDA_ARCH_SM}") diff --git a/detr_tensorrt/plugins/ms_deform_im2col/README.txt b/detr_tensorrt/plugins/ms_deform_im2col/README.txt new file mode 100644 index 00000000..8ef6c1a9 --- /dev/null +++ b/detr_tensorrt/plugins/ms_deform_im2col/README.txt @@ -0,0 +1,16 @@ +To build the plugin: +mkdir build && cd build +cmake .. && make -j + +NOTE: If any of the dependencies are not installed in their default locations, you can manually specify them. For example: + +cmake .. 
    -DPYBIND11_DIR=/path/to/pybind11/
+    -DCMAKE_CUDA_COMPILER=/usr/local/cuda-x.x/bin/nvcc (or add /path/to/nvcc to $PATH)
+    -DCUDA_INC_DIR=/usr/local/cuda-x.x/include/ (or add /path/to/cuda/include to $CPLUS_INCLUDE_PATH)
+    -DPYTHON3_INC_DIR=/usr/include/python3.6/
+    -DTRT_LIB=/path/to/tensorrt/lib/
+    -DTRT_INCLUDE=/path/to/tensorrt/include/
+    -DCUDA_ARCH_SM=70
+
+Check matching sm for Nvidia GPU:
+https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/
\ No newline at end of file
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.cu b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.cu
new file mode 100644
index 00000000..a9b20cb3
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.cu
@@ -0,0 +1,326 @@
+#include "ms_deform_im2col_kernel.h"
+#include <cstdio>
+#include <cassert>
+#include "cuda_fp16.h"
+#include "NvInfer.h"
+
+
+#define assertm(exp, msg) assert(((void)msg, exp))
+
+#define CUDA_KERNEL_LOOP(i, n)                              \
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x;     \
+         i < n; i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+
+int GET_BLOCKS(const int N, const int num_threads)
+{
+    return (N + num_threads - 1) / num_threads;
+}
+
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+    const int &height, const int &width, const int &nheads, const int &channels,
+    const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+    const int h_low = floor(h);
+    const int w_low = floor(w);
+    const int h_high = h_low + 1;
+    const int w_high = w_low + 1;
+
+    const scalar_t lh = h - h_low;
+    const scalar_t lw = w - w_low;
+    const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+    const int w_stride = nheads * channels;
+    const int h_stride = width * w_stride;
+    const int h_low_ptr_offset = h_low * h_stride;
+    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+    const int w_low_ptr_offset = w_low * w_stride;
+    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+    const int base_ptr = m * channels + c;
+
+    scalar_t v1 = 0;
+    if (h_low >= 0 && w_low >= 0)
+    {
+        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+        v1 = bottom_data[ptr1];
+    }
+    scalar_t v2 = 0;
+    if (h_low >= 0 && w_high <= width - 1)
+    {
+        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+        v2 = bottom_data[ptr2];
+    }
+    scalar_t v3 = 0;
+    if (h_high <= height - 1 && w_low >= 0)
+    {
+        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+        v3 = bottom_data[ptr3];
+    }
+    scalar_t v4 = 0;
+    if (h_high <= height - 1 && w_high <= width - 1)
+    {
+        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+        v4 = bottom_data[ptr4];
+    }
+
+    const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+    const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+    return val;
+}
+
+__device__ half ms_deform_attn_im2col_bilinear_half(const half* &bottom_data,
+    const int &height, const int &width, const int &nheads, const int &channels,
+    const half &h, const half &w, const int &m, const int &c)
+{
+    half one = __float2half(1.0f);
+    half zero = __float2half(0.0f);
+
+    const half h_low = hfloor(h);
+    const half w_low = hfloor(w);
+    const int h_high = hceil(h);
+    const int w_high = hceil(w);
+
+    const half lh = h - h_low;
+    const half lw = w - w_low;
+    const half hh = one - lh, hw = one - lw;
+
+    const unsigned int w_stride = nheads * channels;
+    const unsigned int h_stride = width * w_stride;
+    const unsigned int h_low_ptr_offset = __half2uint_rd(h_low) * h_stride;
+    const unsigned int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+    const unsigned int w_low_ptr_offset = __half2uint_rd(w_low) * w_stride;
+    const unsigned int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+    const unsigned int base_ptr = m * channels + c;
+
+    half v1 = __float2half(0.0f);
+    if (h_low >= zero && w_low >= zero)
+    {
+        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+        v1 = bottom_data[ptr1];
+    }
+    half v2 = __float2half(0.0f);
+    if (h_low >= zero && w_high <= width - 1)
+    {
+        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+        v2 = bottom_data[ptr2];
+    }
+    half v3 = __float2half(0.0f);
+    if (h_high <= height - 1 && w_low >= zero)
+    {
+        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+        v3 = bottom_data[ptr3];
+    }
+    half v4 = __float2half(0.0f);
+    if (h_high <= height - 1 && w_high <= width - 1)
+    {
+        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+        v4 = bottom_data[ptr4];
+    }
+
+    const half w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+    const half val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+    return val;
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+    const scalar_t *data_value,
+    const int *data_spatial_shapes,
+    const int *data_level_start_index,
+    const scalar_t *data_sampling_loc,
+    const scalar_t *data_attn_weight,
+    const int batch_size,
+    const int spatial_size,
+    const int num_heads,
+    const int channels,
+    const int num_levels,
+    const int num_query,
+    const int num_point,
+    scalar_t *data_col)
+{
+    CUDA_KERNEL_LOOP(index, n)
+    {
+        int _temp = index;
+        const int c_col = _temp % channels;
+        _temp /= channels;
+        const int sampling_index = _temp;
+        const int m_col = _temp % num_heads;
+        _temp /= num_heads;
+        const int q_col = _temp % num_query;
+        _temp /= num_query;
+        const int b_col = _temp;
+
+        scalar_t *data_col_ptr = data_col + index;
+        int data_weight_ptr = sampling_index * num_levels * num_point;
+        int data_loc_w_ptr = data_weight_ptr << 1;
+        const int qid_stride = num_heads * channels;
+        const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+        scalar_t col = 0;
+
+        for (int l_col=0; l_col < num_levels; ++l_col)
+        {
+            const int level_start_id = data_level_start_index[l_col];
+            const int spatial_h_ptr = l_col << 1;
+            const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+            const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+            const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+            for (int p_col=0; p_col < num_point; ++p_col)
+            {
+                const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+                const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+                const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+                const scalar_t h_im = loc_h * (spatial_h-1); //- 0.5;
+                const scalar_t w_im = loc_w * (spatial_w-1); //- 0.5;
+
+                if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+                {
+                    col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+                }
+
+                data_weight_ptr += 1;
+                data_loc_w_ptr += 2;
+            }
+        }
+        *data_col_ptr = col;
+    }
+}
+
+__global__ void ms_deformable_im2col_gpu_kernel_half(const int n,
+    const half *data_value,
+    const int *data_spatial_shapes,
+    const int *data_level_start_index,
+    const half *data_sampling_loc,
+    const half *data_attn_weight,
+    const int batch_size,
+    const int spatial_size,
+    const int num_heads,
+    const int channels,
+    const int num_levels,
+    const int num_query,
+    const int num_point,
+    half *data_col)
+{
+    half one(1.0f);
+    CUDA_KERNEL_LOOP(index, n)
+    {
+        int _temp = index;
+        const int c_col = _temp % channels;
+        _temp /= channels;
+        const int sampling_index = _temp;
+        const int m_col = _temp % num_heads;
+        _temp /= num_heads;
+        const int q_col = _temp % num_query;
+        _temp /= num_query;
+        const int b_col = _temp;
+
+        half *data_col_ptr = data_col + index;
+        int data_weight_ptr = sampling_index * num_levels * num_point;
+        int data_loc_w_ptr = data_weight_ptr << 1;
+        const int qid_stride = num_heads * channels;
+        const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+        half col = 0;
+
+        for (int l_col=0; l_col < num_levels; ++l_col)
+        {
+            const int level_start_id = data_level_start_index[l_col];
+            const int spatial_h_ptr = l_col << 1;
+            const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+            const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+            const half *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+            for (int p_col=0; p_col < num_point; ++p_col)
+            {
+                const half loc_w = data_sampling_loc[data_loc_w_ptr];
+                const half loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+                const half weight = data_attn_weight[data_weight_ptr];
+
+                const half h_im = loc_h * __int2half_rd(spatial_h-1); //- 0.5;
+                const half w_im = loc_w * __int2half_rd(spatial_w-1); //- 0.5;
+
+                if (h_im > -one && w_im > -one && h_im < __int2half_rd(spatial_h) && w_im < __int2half_rd(spatial_w))
+                {
+                    col += ms_deform_attn_im2col_bilinear_half(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+                }
+
+                data_weight_ptr += 1;
+                data_loc_w_ptr += 2;
+            }
+        }
+        *data_col_ptr = col;
+    }
+}
+
+int ms_deform_im2col_inference(
+    cudaStream_t stream,
+    const void* data_value,
+    const void* data_spatial_shapes,
+    const void* data_level_start_index,
+    const void* data_sampling_loc,
+    const void* data_attn_weight,
+    const int batch_size,
+    const int spatial_size,
+    const int num_heads,
+    const int channels,
+    const int num_levels,
+    const int num_query,
+    const int num_point,
+    void* data_col,
+    DataType mDataType
+)
+{
+    const int num_kernels = batch_size * num_query * num_heads * channels;
+    const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+    const int num_threads = CUDA_NUM_THREADS;
+
+    if(mDataType == DataType::kFLOAT)
+    {
+        // printf("Hey FLOAT \n");
+        ms_deformable_im2col_gpu_kernel
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+            num_kernels,
+            static_cast<const float*>(data_value),
+            static_cast<const int*>(data_spatial_shapes),
+            static_cast<const int*>(data_level_start_index),
+            static_cast<const float*>(data_sampling_loc),
+            static_cast<const float*>(data_attn_weight),
+            batch_size,
+            spatial_size,
+            num_heads,
+            channels,
+            num_levels,
+            num_query,
+            num_point,
+            static_cast<float*>(data_col));
+    }
+    else if(mDataType == DataType::kHALF)
+    {
+        // printf("Hey HALF \n");
+        ms_deformable_im2col_gpu_kernel_half
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+            num_kernels,
+            static_cast<const half*>(data_value),
+            static_cast<const int*>(data_spatial_shapes),
+            static_cast<const int*>(data_level_start_index),
+            static_cast<const half*>(data_sampling_loc),
+            static_cast<const half*>(data_attn_weight),
+            batch_size,
+            spatial_size,
+            num_heads,
+            channels,
+            num_levels,
+            num_query,
+            num_point,
+            static_cast<half*>(data_col));
+    }
+    else return -1;
+
+    return 0;
+}
\ No newline at end of file
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.h b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.h
new file mode 100644
index 00000000..128539d3
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_kernel.h
@@ -0,0 +1,29 @@
+#ifndef MS_DEFORM_IM2COL_KERNEL
+#define MS_DEFORM_IM2COL_KERNEL
+
+#include "NvInfer.h"
+#include "cuda_fp16.h"
+
+using namespace nvinfer1;
+
+
+int ms_deform_im2col_inference(
+    cudaStream_t stream,
+    const void* data_value,
+    const void* data_spatial_shapes,
+    const void* data_level_start_index,
+    const void* data_sampling_loc,
+    const void* data_attn_weight,
+    const int batch_size,
+    const int spatial_size,
+    const int num_heads,
+    const int channels,
+    const int num_levels,
+    const int num_query,
+    const int num_point,
+    void* data_col,
+    DataType mDataType
+);
+
+#endif
+
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.cpp b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.cpp
new file mode 100644
index 00000000..8a5a4cf2
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.cpp
@@ -0,0 +1,281 @@
+#include <cassert>
+#include <cstring>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ms_deform_im2col_plugin.h"
+#include "NvInfer.h"
+#include "ms_deform_im2col_kernel.h"
+
+using namespace std;
+
+#define assertm(exp, msg) assert(((void)msg, exp))
+
+using namespace nvinfer1;
+
+namespace {
+    static const char* MS_DEFORM_IM2COL_PLUGIN_VERSION{"1"};
+    static const char* MS_DEFORM_IM2COL_PLUGIN_NAME{"MsDeformIm2ColTRT"};
+}
+
+// Static class fields initialization
+PluginFieldCollection MsDeformIm2ColCreator::mFC{};
+std::vector<PluginField> MsDeformIm2ColCreator::mPluginAttributes;
+
+// Statically registers the plugin creator to the plugin registry of TensorRT
+REGISTER_TENSORRT_PLUGIN(MsDeformIm2ColCreator);
+
+// Helper function for serializing plugin
+template <typename T>
+void writeToBuffer(char*& buffer, const T& val)
+{
+    *reinterpret_cast<T*>(buffer) = val;
+    buffer += sizeof(T);
+}
+
+// Helper function for deserializing plugin
+template <typename T>
+T readFromBuffer(const char*& buffer)
+{
+    T val = *reinterpret_cast<const T*>(buffer);
+    buffer += sizeof(T);
+    return val;
+}
+
+
+MsDeformIm2Col::MsDeformIm2Col(const std::string name)
+    : mLayerName(name)
+{
+}
+
+MsDeformIm2Col::MsDeformIm2Col(const std::string name, const void* data, size_t length)
+    : mLayerName(name)
+{
+    // Deserialize in the same order as serialization
+    const char *d = static_cast<const char *>(data);
+    const char *a = d;
+
+    im2col_step = readFromBuffer<int>(d);
+    spatial_size = readFromBuffer<int>(d);
+    num_heads = readFromBuffer<int>(d);
+    channels = readFromBuffer<int>(d);
+    num_levels = readFromBuffer<int>(d);
+    num_query = readFromBuffer<int>(d);
+    num_point = readFromBuffer<int>(d);
+    mDataType = readFromBuffer<DataType>(d);
+
+    assert(d == (a + length));
+}
+
+const char* MsDeformIm2Col::getPluginType() const noexcept
+{
+    return MS_DEFORM_IM2COL_PLUGIN_NAME;
+}
+
+const char* MsDeformIm2Col::getPluginVersion() const noexcept
+{
+    return MS_DEFORM_IM2COL_PLUGIN_VERSION;
+}
+
+int MsDeformIm2Col::getNbOutputs() const noexcept
+{
+    return 1;
+}
+
+Dims MsDeformIm2Col::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept
+{
+    int len_q = inputs[3].d[0];
+    int num_heads = inputs[0].d[1];
+    int head_dim = inputs[0].d[2];
+    return Dims2(len_q, num_heads * head_dim);
+}
+
+int MsDeformIm2Col::initialize() noexcept
+{
+    return 0;
+}
+
+int MsDeformIm2Col::enqueue(int batchSize, const void* const* inputs, void* const* outputs,
+    void* workspace, cudaStream_t stream) noexcept
+{
+    int status = -1;
+    // Launch CUDA kernel wrapper and save its return value
+    status = ms_deform_im2col_inference(
+        stream, inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
+        batchSize, spatial_size, num_heads, channels,
+        num_levels, num_query, num_point, outputs[0], mDataType);
+    assert(status == 0);
+    return status;
+}
+
+size_t MsDeformIm2Col::getSerializationSize() const noexcept
+{
+    // 7 int parameters
+    return 7 * sizeof(int32_t) + sizeof(DataType);
+}
+
+void MsDeformIm2Col::serialize(void* buffer) const noexcept
+{
+    char *d = static_cast<char *>(buffer);
+    const char *a = d;
+
+    writeToBuffer(d, im2col_step);
+    writeToBuffer(d, spatial_size);
+    writeToBuffer(d, num_heads);
+    writeToBuffer(d, channels);
+    writeToBuffer(d, num_levels);
+    writeToBuffer(d, num_query);
+    writeToBuffer(d, num_point);
+    writeToBuffer(d, mDataType);
+
+    assert(d == a + getSerializationSize());
+}
+
+void MsDeformIm2Col::terminate() noexcept {}
+
+void MsDeformIm2Col::destroy() noexcept {
+    // This gets called when the network containing plugin is destroyed
+    delete this;
+}
+
+DataType MsDeformIm2Col::getOutputDataType(int32_t index, const DataType *inputTypes, int32_t nbInputs) const noexcept
+{
+    // only 1 output
+    assert(index == 0);
+    assert(nbInputs == 5);
+    return inputTypes[0]; // return the type of the input value tensor
+}
+
+bool MsDeformIm2Col::isOutputBroadcastAcrossBatch(int32_t outputIndex,
+    const bool* inputIsBroadcasted, int32_t nbInputs) const noexcept
+{
+    return false;
+}
+
+bool MsDeformIm2Col::canBroadcastInputAcrossBatch(int inputIndex) const noexcept
+{
+    return false;
+}
+
+void MsDeformIm2Col::configurePlugin(const PluginTensorDesc* in, int32_t nbInput,
+    const PluginTensorDesc* out, int32_t nbOutput) noexcept
+{
+    assertm(nbInput == 5, "Must provide 5 inputs: value, spatial_shape, start_index, sampling_locations, attn_weights\n");
+    assertm(in[0].dims.nbDims == 3, "flatten_value must have shape (len_in, num_head, head_dim)\n");
+    assertm(in[1].dims.nbDims == 2, "spatial_shapes must have shape (num_levels, 2)\n");
+    assertm(in[2].dims.nbDims == 1, "start_index must have shape (num_levels, )\n");
+    assertm(in[3].dims.nbDims == 5, "sampling_loc must have shape (len_q, num_head, num_levels, num_points, 2)\n");
+    assertm(in[4].dims.nbDims == 4, "attn_weights must have shape (len_q, num_head, num_levels, num_points)\n");
+    assertm(nbOutput == 1, "This layer has only one output.\n");
+
+    im2col_step = 64;
+    spatial_size = in[0].dims.d[0];
+    num_heads = in[0].dims.d[1];
+    channels = in[0].dims.d[2];
+    num_levels = in[3].dims.d[2];
+    num_query = in[3].dims.d[0];
+    num_point = in[3].dims.d[3];
+    mDataType = in[0].type;
+
+    // cout << "DEBUG in[0].type: " << (int)in[0].type << endl;
+    // cout << "MsDeformIm2Col DEBUG: im2col_step=" << im2col_step << endl;
+    // cout << "MsDeformIm2Col DEBUG: spatial_size=" << spatial_size << endl;
+    // cout << "MsDeformIm2Col DEBUG: num_heads=" << num_heads << endl;
+    // cout << "MsDeformIm2Col DEBUG: channels=" << channels << endl;
+    // cout << "MsDeformIm2Col DEBUG: num_levels=" << num_levels << endl;
+    // cout << "MsDeformIm2Col DEBUG: num_query=" << num_query << endl;
+    // cout << "MsDeformIm2Col DEBUG: num_point=" << num_point << endl;
+}
+
+bool MsDeformIm2Col::supportsFormatCombination(int pos, const PluginTensorDesc* inOut,
+    int nbInputs, int nbOutputs) const noexcept
+{
+    bool ret;
+    ret = inOut[pos].format == TensorFormat::kLINEAR;
+    if((pos == 1) || (pos == 2))
+    {
+        return ret && (inOut[pos].type == DataType::kINT32);
+    }
+    else
+    {
+        bool type_supported = (inOut[pos].type == DataType::kFLOAT) || (inOut[pos].type == DataType::kHALF);
+        type_supported = type_supported && (inOut[pos].type == inOut[0].type);
+        return ret && type_supported;
+    }
+}
+
+IPluginV2Ext* MsDeformIm2Col::clone() const noexcept
+{
+    auto plugin = new MsDeformIm2Col(mLayerName);
+    plugin->im2col_step = im2col_step;
+    plugin->spatial_size = spatial_size;
+    plugin->num_heads = num_heads;
+    plugin->channels = channels;
+    plugin->num_levels = num_levels;
+    plugin->num_query = num_query;
+    plugin->num_point = num_point;
+    plugin->mDataType = mDataType;
+    plugin->setPluginNamespace(mNamespace.c_str());
+    return plugin;
+}
+
+void MsDeformIm2Col::setPluginNamespace(const char* libNamespace) noexcept
+{
+    mNamespace = libNamespace;
+}
+
+const char* MsDeformIm2Col::getPluginNamespace() const noexcept
+{
+    return mNamespace.c_str();
+}
+
+MsDeformIm2ColCreator::MsDeformIm2ColCreator()
+{
+    // Describe MsDeformIm2Col's required PluginField arguments
+
+    // Fill PluginFieldCollection with PluginField arguments metadata
+    mFC.nbFields = mPluginAttributes.size();
+    mFC.fields = mPluginAttributes.data();
+}
+
+const char* MsDeformIm2ColCreator::getPluginName() const noexcept
+{
+    return MS_DEFORM_IM2COL_PLUGIN_NAME;
+}
+
+const char* MsDeformIm2ColCreator::getPluginVersion() const noexcept
+{
+    return MS_DEFORM_IM2COL_PLUGIN_VERSION;
+}
+
+const PluginFieldCollection* MsDeformIm2ColCreator::getFieldNames() noexcept
+{
+    return &mFC;
+}
+
+IPluginV2* MsDeformIm2ColCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept
+{
+    // const PluginField* fields = fc->fields;
+
+    // Parse fields from PluginFieldCollection
+    return new MsDeformIm2Col(name);
+}
+
+IPluginV2* MsDeformIm2ColCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept
+{
+    // This object will be deleted when the network is destroyed, which will
+    // call MsDeformIm2Col::destroy()
+    return new MsDeformIm2Col(name, serialData, serialLength);
+}
+
+void MsDeformIm2ColCreator::setPluginNamespace(const char* libNamespace) noexcept
+{
+    mNamespace = libNamespace;
+}
+
+const char* MsDeformIm2ColCreator::getPluginNamespace() const noexcept
+{
+    return mNamespace.c_str();
+}
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.h b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.h
new file mode 100644
index 00000000..aef020a5
--- /dev/null
+++ b/detr_tensorrt/plugins/ms_deform_im2col/sources/ms_deform_im2col_plugin.h
@@ -0,0 +1,103 @@
+#ifndef MS_DEFORM_IM2COL_TRT_PLUGIN_H
+#define MS_DEFORM_IM2COL_TRT_PLUGIN_H
+
+#include "NvInferPlugin.h"
+#include <string>
+#include <vector>
+
+
+using namespace nvinfer1;
+
+// One of the preferred ways of making TensorRT able to see
+// our custom layer is extending the IPluginV2 and IPluginCreator classes.
+// For the requirements on overridden functions, check the TensorRT API docs.
+
+class MsDeformIm2Col : public IPluginV2IOExt
+{
+public:
+    MsDeformIm2Col(const std::string name);
+
+    MsDeformIm2Col(const std::string name, const void* data, size_t length);
+
+    // It doesn't make sense to construct MsDeformIm2Col without arguments, so we delete the default constructor.
+ MsDeformIm2Col() = delete; + + int getNbOutputs() const noexcept override; + + Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept override; + + int initialize() noexcept override; + + void terminate() noexcept override; + + size_t getWorkspaceSize(int) const noexcept override { return 0; }; + + int enqueue(int batchSize, const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; + + size_t getSerializationSize() const noexcept override; + + void serialize(void* buffer) const noexcept override; + + const char* getPluginType() const noexcept override; + + const char* getPluginVersion() const noexcept override; + + void destroy() noexcept override; + + IPluginV2Ext* clone() const noexcept override; + + void setPluginNamespace(const char* pluginNamespace) noexcept override; + + const char* getPluginNamespace() const noexcept override; + + DataType getOutputDataType(int32_t index, const nvinfer1::DataType *inputTypes, int32_t nbInputs) const noexcept override; + + bool isOutputBroadcastAcrossBatch(int32_t outputIndex, const bool* inputIsBroadcasted, int32_t nbInputs) const noexcept override; + + bool canBroadcastInputAcrossBatch(int inputIndex) const noexcept override; + + void configurePlugin(const PluginTensorDesc* in, int32_t nbInput, const PluginTensorDesc* out, int32_t nbOutput) noexcept override; + + bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const noexcept override; + +private: + const std::string mLayerName; + std::string mNamespace; + int im2col_step; + int spatial_size; + int num_heads; + int channels; + int num_levels; + int num_query; + int num_point; + DataType mDataType; + // DataType mDataType = DataType::kHALF; +}; + +class MsDeformIm2ColCreator : public IPluginCreator +{ +public: + MsDeformIm2ColCreator(); + + const char* getPluginName() const noexcept override; + + const char* getPluginVersion() const noexcept override; + + const PluginFieldCollection* getFieldNames() noexcept override; + + IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override; + + IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override; + + void setPluginNamespace(const char* pluginNamespace) noexcept override; + + const char* getPluginNamespace() const noexcept override; + +private: + static PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +#endif diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test.py b/detr_tensorrt/plugins/ms_deform_im2col/test.py new file mode 100644 index 00000000..111c039b --- /dev/null +++ b/detr_tensorrt/plugins/ms_deform_im2col/test.py @@ -0,0 +1,119 @@ +import tensorrt as trt +import pycuda.driver as cuda +import numpy as np +import os +import ctypes +import re +import time + +from surroundnet.detr.tensorrt.TRTExecutor import TRTExecutor + +TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) +EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + +current_path = os.path.dirname(os.path.abspath(__file__)) + +MS_DEFORM_IM2COL_PLUGIN_LIB = "./detr_tensorrt/plugins/ms_deform_im2col/build/libms_deform_im2col_trt.so" +ctypes.CDLL(MS_DEFORM_IM2COL_PLUGIN_LIB) +trt.init_libnvinfer_plugins(TRT_LOGGER, '') +PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list + +def GiB(val): + return val * 1 << 30 + +def camel_to_snake(name): + name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + 
+    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
+
+def get_trt_plugin(plugin_name):
+    plugin = None
+    for plugin_creator in PLUGIN_CREATORS:
+        if plugin_creator.name == plugin_name:
+            plugin = plugin_creator.create_plugin(camel_to_snake(plugin_name), None)
+    if plugin is None:
+        raise Exception(f"plugin {plugin_name} not found")
+    return plugin
+
+def build_test_engine(input_shape, dtype=trt.float32):
+    num_level = input_shape["flatten_sampling_loc"][3]
+
+    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network:
+        builder.max_batch_size = 1
+        config = builder.create_builder_config()
+        config.max_workspace_size = GiB(5)
+        if dtype == trt.float16:
+            config.set_flag(trt.BuilderFlag.FP16)
+            config.set_flag(trt.BuilderFlag.STRICT_TYPES)
+
+        input_flatten_value = network.add_input(
+            name="input_flatten_value", dtype=dtype, shape=input_shape["flatten_value"])
+        input_spatial_shapes = network.add_input(
+            name="input_spatial_shapes", dtype=trt.int32, shape=(1, num_level, 2))
+        input_start_index = network.add_input(
+            name="input_start_index", dtype=trt.int32, shape=(1, num_level))
+        input_flatten_sampling_loc = network.add_input(
+            name="input_flatten_sampling_loc", dtype=dtype, shape=input_shape["flatten_sampling_loc"])
+        input_flatten_attn_weight = network.add_input(
+            name="input_flatten_attn_weight", dtype=dtype, shape=input_shape["flatten_attn_weight"])
+
+        ms_deform_im2col_node = network.add_plugin_v2(
+            inputs=[
+                input_flatten_value, input_spatial_shapes,
+                input_start_index, input_flatten_sampling_loc,
+                input_flatten_attn_weight],
+            plugin=get_trt_plugin("MsDeformIm2ColTRT")
+        )
+        ms_deform_im2col_node.name = "ms_deform_im2col_node"
+        ms_deform_im2col_node.get_output(0).name = "im2col_output"
+
+        network.mark_output(ms_deform_im2col_node.get_output(0))
+
+        return builder.build_engine(network, config)
+
+def get_target_test_tensors(dtype=np.float32):
+    test_dir = os.path.join(current_path, "test_tensors")
+    test_tensors = {}
+    test_shapes = {}
+    for filename in os.listdir(test_dir):
+        tensor_name = filename.split(".")[0]
+        test_tensors[tensor_name] = np.load(os.path.join(test_dir, filename))
+
+        if test_tensors[tensor_name].dtype == np.int64:
+            test_tensors[tensor_name] = test_tensors[tensor_name].astype(np.int32)
+        elif test_tensors[tensor_name].dtype == np.float32:
+            test_tensors[tensor_name] = test_tensors[tensor_name].astype(dtype)
+
+        test_shapes[tensor_name] = test_tensors[tensor_name].shape
+    return test_tensors, test_shapes
+
+
+if __name__ == "__main__":
+    # for plugin in PLUGIN_CREATORS:
+    #     print(plugin.name, plugin.plugin_version)
+
+    test_tensors, test_shapes = get_target_test_tensors()
+    for key in test_tensors:
+        print(key, test_tensors[key].shape, test_tensors[key].dtype)
+    test_engine = build_test_engine(test_shapes, dtype=trt.float16)
+
+    trt_model = TRTExecutor(engine=test_engine)
+    trt_model.print_bindings_info()
+
+    trt_model.inputs[0].host = test_tensors["flatten_value"].astype(np.float16)
+    trt_model.inputs[1].host = test_tensors["spatial_shapes"]
+    trt_model.inputs[2].host = test_tensors["level_start_index"][:4].copy()
+    trt_model.inputs[3].host = test_tensors["flatten_sampling_loc"].astype(np.float16)
+    trt_model.inputs[4].host = test_tensors["flatten_attn_weight"].astype(np.float16)
+
+    trt_model.execute()
+
+    N = 1000
+    tic = time.time()
+    [trt_model.execute() for i in range(N)]
+    toc = time.time()
+
+    diff = test_tensors["output"] - trt_model.outputs[0].host
+    print(np.abs(diff).mean())
+    print(f"Execution time: {(toc - tic)/N*1000} ms")
+    
\ No newline at end of file
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_attn_weight.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_attn_weight.npy
new file mode 100644
index 00000000..cdfd6859
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_attn_weight.npy differ
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_sampling_loc.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_sampling_loc.npy
new file mode 100644
index 00000000..5dba0170
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_sampling_loc.npy differ
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_value.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_value.npy
new file mode 100644
index 00000000..76f0b919
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/flatten_value.npy differ
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/level_start_index.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/level_start_index.npy
new file mode 100644
index 00000000..e96f3c88
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/level_start_index.npy differ
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/output.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/output.npy
new file mode 100644
index 00000000..c3785334
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/output.npy differ
diff --git a/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/spatial_shapes.npy b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/spatial_shapes.npy
new file mode 100644
index 00000000..66b44fa2
Binary files /dev/null and b/detr_tensorrt/plugins/ms_deform_im2col/test_tensors/spatial_shapes.npy differ
diff --git a/detr_tensorrt/trt_helper.py b/detr_tensorrt/trt_helper.py
new file mode 100644
index 00000000..5739402f
--- /dev/null
+++ b/detr_tensorrt/trt_helper.py
@@ -0,0 +1,172 @@
+import pycuda.driver as cuda
+import tensorrt as trt
+
+COCO_PANOPTIC_CLASS_NAMES = [
+    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
+    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
+    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
+    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
+    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
+    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
+    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
+    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
+    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
+    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
+    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator',
+    'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
+    'toothbrush', 'N/A', 'banner', 'blanket', 'N/A', 'bridge', 'N/A',
+    'N/A', 'N/A', 'N/A', 'cardboard', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A',
+    'counter', 'N/A', 'curtain', 'N/A', 'N/A', 'door-stuff', 'N/A', 'N/A', 'N/A', 'N/A',
+    'N/A', 'floor-wood', 'flower', 'N/A', 'N/A', 'fruit', 'N/A', 'N/A', 'gravel', 'N/A', 'N/A',
+    'house', 'N/A', 'light', 'N/A', 'N/A', 'mirror-stuff', 'N/A', 'N/A',
+    'N/A', 'N/A', 'net', 'N/A', 'N/A', 'pillow', 'N/A', 'N/A', 'platform',
+    'playingfield', 'N/A', 'railroad', 'river', 'road', 'N/A', 'roof', 'N/A', 'N/A',
+    'sand', 'sea', 'shelf', 'N/A', 'N/A', 'snow', 'N/A', 'stairs', 'N/A', 'N/A', 'N/A',
+    'N/A', 'tent', 'N/A', 'towel', 'N/A', 'N/A', 'wall-brick', 'N/A', 'N/A', 'N/A', 'wall-stone',
+    'wall-tile', 'wall-wood', 'water-other', 'N/A', 'window-blind', 'window-other', 'N/A', 'N/A',
+    'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
+    'cabinet-merged', 'table-merged', 'floor-other-merged', 'pavement-merged',
+    'mountain-merged', 'grass-merged', 'dirt-merged', 'paper-merged', 'food-other-merged',
+    'building-other-merged', 'rock-merged', 'wall-other-merged', 'rug-merged', 'N/A',
+    'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A',
+    'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A',
+    'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A',
+    'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'background'
+]
+
+AGV_PANOPTIC_CLASS_NAMES = [
+    "person", 'carpet', 'dirt', 'floor-marble', 'floor-other', 'floor-stone',
+    'floor-tile', 'floor-wood', 'gravel', 'ground-other', 'mud', 'pavement', 'platform', 'playingfield',
+    'railroad', 'road', 'sand', 'snow', 'background'
+]
+
+def GiB(val):
+    """Convert gibibytes to bytes, used to set the workspace size for the TensorRT engine builder."""
+    return val * 1 << 30
+
+class HostDeviceMem(object):
+    """
+    Simple helper class to store useful data of an engine's binding
+
+    Attributes:
+    ----------
+    host_mem: np.ndarray
+        data stored on the CPU
+    device_mem: pycuda.driver.DeviceAllocation
+        pointer to the corresponding data in GPU memory
+    shape: tuple
+    dtype: np dtype
+    name: str
+        name of the binding
+    """
+    def __init__(self, host_mem, device_mem, shape, dtype, name=""):
+        self.host = host_mem
+        self.device = device_mem
+        self.shape = shape
+        self.dtype = dtype
+        self.name = name
+
+    def __str__(self):
+        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
+
+    def __repr__(self):
+        return self.__str__()
+
+def allocate_buffers(context, stream=None, is_async=True):
+    """
+    Read binding information from the execution context, create a page-locked
+    np.ndarray on the CPU and allocate the corresponding memory on the GPU
+    for each binding.
+
+    Returns:
+    --------
+    inputs: list[HostDeviceMem]
+    outputs: list[HostDeviceMem]
+    bindings: list[int]
+        list of GPU device pointers, one per binding
+    stream: pycuda.driver.Stream
+        used for memory transfers between CPU and GPU
+    """
+    inputs = []
+    outputs = []
+    bindings = []
+    if stream is None and is_async:
+        stream = cuda.Stream()
+    for binding in context.engine:
+        binding_idx = context.engine.get_binding_index(binding)
+        shape = context.get_binding_shape(binding_idx)
+        size = trt.volume(shape) * context.engine.max_batch_size
+        dtype = trt.nptype(context.engine.get_binding_dtype(binding))
+        # Allocate host and device buffers
+        host_mem = cuda.pagelocked_empty(size, dtype)
+        device_mem = cuda.mem_alloc(host_mem.nbytes)
+        # Append the device buffer to device bindings.
+        bindings.append(int(device_mem))
+        # Append to the appropriate list.
+        if context.engine.binding_is_input(binding):
+            inputs.append(HostDeviceMem(host_mem, device_mem, shape, dtype, binding))
+        else:
+            outputs.append(HostDeviceMem(host_mem, device_mem, shape, dtype, binding))
+    return inputs, outputs, bindings, stream
+
+def do_inference_async(context, bindings, inputs, outputs, stream):
+    """
+    Execute a TensorRT engine asynchronously.
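+
+    Inputs are copied host-to-device, the engine is launched with
+    execute_async_v2 on the given CUDA stream, outputs are copied back
+    device-to-host, and the stream is synchronized before returning.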
+
+    Parameters:
+    -----------
+    context: tensorrt.IExecutionContext
+    bindings: list[int]
+        list of GPU device pointers, one per binding
+    inputs: list[HostDeviceMem]
+    outputs: list[HostDeviceMem]
+    stream: pycuda.driver.Stream
+        used for memory transfers between CPU and GPU
+
+    Returns:
+    --------
+    list[np.ndarray], one per engine output
+    """
+    # Transfer input data to the GPU.
+    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+    # Run inference.
+    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
+    # Transfer predictions back from the GPU.
+    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
+    # Synchronize the stream
+    stream.synchronize()
+    # Return only the host outputs.
+    for out in outputs:
+        out.host = out.host.reshape(out.shape)
+    return [out.host for out in outputs]
+
+def do_inference(context, bindings, inputs, outputs):
+    """
+    Execute a TensorRT engine synchronously.
+
+    Parameters:
+    -----------
+    context: tensorrt.IExecutionContext
+    bindings: list[int]
+        list of GPU device pointers, one per binding
+    inputs: list[HostDeviceMem]
+    outputs: list[HostDeviceMem]
+
+    Returns:
+    --------
+    list[np.ndarray], one per engine output
+    """
+    # Transfer input data to the GPU.
+    [cuda.memcpy_htod(inp.device, inp.host) for inp in inputs]
+    # Run inference.
+    context.execute_v2(bindings=bindings)
+    # Transfer predictions back from the GPU.
+    [cuda.memcpy_dtoh(out.host, out.device) for out in outputs]
+    # Return only the host outputs.
+    for out in outputs:
+        out.host = out.host.reshape(out.shape)
+    return [out.host for out in outputs]
\ No newline at end of file
diff --git a/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o
index e025baa6..e7d6cd3d 100644
Binary files a/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o and b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o differ
diff --git a/detr_tf/inference.py b/detr_tf/inference.py
index 957233fd..55f6b40c 100644
--- a/detr_tf/inference.py
+++ b/detr_tf/inference.py
@@ -2,15 +2,16 @@
 import numpy as np
 import cv2
 
-
-CLASS_COLOR_MAP = np.random.randint(0, 255, (100, 3))
+np.random.seed(20)
+# CLASS_COLOR_MAP = np.random.randint(0, 255, (100, 3))
+CLASS_COLOR_MAP = np.random.random((100, 3))
 
 from detr_tf import bbox
 
 def numpy_bbox_to_image(image, bbox_list, labels=None, scores=None, class_name=[], config=None):
     """ Numpy function used to display the bbox (target or prediction) """
-    assert(image.dtype == np.float32 and image.dtype == np.float32 and len(image.shape) == 3)
+    assert(image.dtype == np.float32 and len(image.shape) == 3)
 
     if config is not None and config.normalized_method == "torch_resnet":
         channel_avg = np.array([0.485, 0.456, 0.406])
@@ -18,7 +19,8 @@ def numpy_bbox_to_image(image, bbox_list, labels=None, scores=None, class_name=[
         image = (image * channel_std) + channel_avg
         image = (image*255).astype(np.uint8)
     elif config is not None and config.normalized_method == "tf_resnet":
-        image = image + mean
+        mean = [103.939, 116.779, 123.68]
+        image = image + mean
         image = image[..., ::-1]
         image = image / 255
@@ -36,9 +38,6 @@ def numpy_bbox_to_image(image, bbox_list, labels=None, scores=None, class_name=[
 
     # Go through each bbox
     for b in np.argsort(bbox_area)[::-1]:
-        # Take a new color at reandon for this instance
-        instance_color = np.random.randint(0, 255, (3))
-
         x1, y1, x2, y2 = bbox_x1y1x2y2[b]
         x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
 
@@ -55,12 +54,10 @@ def numpy_bbox_to_image(image, bbox_list, labels=None, scores=None, class_name=[
 
         class_color = CLASS_COLOR_MAP[int(class_id)]
 
-        color = instance_color
-
-        multiplier = image.shape[0] / 500
-        cv2.rectangle(image, (x1, y1), (x1 + int(multiplier*15)*len(label_name), y1 + 20), class_color.tolist(), -10)
-        cv2.putText(image, label_name, (x1+2, y1 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6 * multiplier, (0, 0, 0), 1)
+        multiplier = image.shape[0] / 1000
         cv2.rectangle(image, (x1, y1), (x2, y2), tuple(class_color.tolist()), 2)
+        cv2.rectangle(image, (x1, y1 - 20), (x1 + int(multiplier*15)*len(label_name), y1 + 1), class_color.tolist(), -1)
+        cv2.putText(image, label_name, (x1+2, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.6 * multiplier, (0, 0, 0), 1)
 
     return image
diff --git a/detr_tf/networks/deformable_detr.py b/detr_tf/networks/deformable_detr.py
index 9bc51c00..c5b326ff 100644
--- a/detr_tf/networks/deformable_detr.py
+++ b/detr_tf/networks/deformable_detr.py
@@ -95,6 +95,7 @@ def __init__(self,
 
         layer_norm = functools.partial(tfa.layers.GroupNormalization, groups=32, epsilon=1e-05) #tf.keras.layers.BatchNormalization
 
+
         self.input_proj_0 = tf.keras.layers.Conv2D(self.model_dim, kernel_size=1, name='input_proj/0/0', trainable=train_encoder)
         self.input_proj_gn_0 = layer_norm(name="input_proj_gn/0/1", trainable=train_encoder)
@@ -146,7 +147,7 @@ def call(self, inp, training=False, post_process=False):
 
 class DetrClassHead(tf.keras.layers.Layer):
 
-    def __init__(self, detr, include_top, nb_class=None, refine_bbox=False, **kwargs):
+    def __init__(self, detr: DeformableDETR, include_top, nb_class=None, refine_bbox=False, **kwargs):
         """ """
         super().__init__(name="detr_class_head", **kwargs)
diff --git a/detr_tf/networks/transformer.py b/detr_tf/networks/transformer.py
index 60c402a5..5225ccb3 100644
--- a/detr_tf/networks/transformer.py
+++ b/detr_tf/networks/transformer.py
@@ -306,24 +306,23 @@ def call(self, inputs, attn_mask=None, key_padding_mask=None,
 
         if attn_mask is not None:
             attn_output_weights += attn_mask
-
-        if key_padding_mask is not None:
-            key_padding_mask = tf.cast(key_padding_mask, tf.bool)
-
-            attn_output_weights = tf.reshape(attn_output_weights,
-                [batch_size, self.num_heads, target_len, source_len])
-
-            key_padding_mask = tf.expand_dims(key_padding_mask, 1)
-            key_padding_mask = tf.expand_dims(key_padding_mask, 2)
-            key_padding_mask = tf.tile(key_padding_mask, [1, self.num_heads, target_len, 1])
-
-
-            #print("before attn_output_weights", attn_output_weights.shape)
-            attn_output_weights = tf.where(key_padding_mask,
-                tf.zeros_like(attn_output_weights) + float('-inf'),
-                attn_output_weights)
-            attn_output_weights = tf.reshape(attn_output_weights,
-                [batch_size * self.num_heads, target_len, source_len])
+        # This code block is commented out when exporting to ONNX and a TensorRT engine
+        # if key_padding_mask is not None:
+        #     attn_output_weights = tf.reshape(attn_output_weights,
+        #         [batch_size, self.num_heads, target_len, source_len])
+
+        #     key_padding_mask = tf.expand_dims(key_padding_mask, 1)
+        #     key_padding_mask = tf.expand_dims(key_padding_mask, 2)
+        #     key_padding_mask = tf.tile(key_padding_mask, [1, self.num_heads, target_len, 1])
+
+        #     key_padding_mask = tf.cast(key_padding_mask, tf.bool)
+
+        #     #print("before attn_output_weights", attn_output_weights.shape)
+        #     attn_output_weights = tf.where(key_padding_mask,
+        #         tf.zeros_like(attn_output_weights) + float('-inf'),
+        #         attn_output_weights)
+        #     attn_output_weights = tf.reshape(attn_output_weights,
+        #         [batch_size * self.num_heads, target_len, source_len])
 
         attn_output_weights = tf.nn.softmax(attn_output_weights, axis=-1)
diff --git a/images/test.jpeg b/images/test.jpeg
new file mode 100644
index 00000000..533e3260
Binary files /dev/null and b/images/test.jpeg differ
diff --git a/tensorflow_inference.py b/tensorflow_inference.py
new file mode 100644
index 00000000..a539433b
--- /dev/null
+++ b/tensorflow_inference.py
@@ -0,0 +1,84 @@
+import tensorflow as tf
+import numpy as np
+import cv2
+import time
+
+from detr_tf.training_config import TrainingConfig, training_config_parser
+
+from detr_tf.networks.detr import get_detr_model
+from detr_tf.networks.deformable_detr import get_deformable_detr_model
+
+from detr_tf.data import processing
+from detr_tf.data.coco import COCO_CLASS_NAME
+from detr_tf.inference import get_model_inference, numpy_bbox_to_image
+
+
+@tf.function
+def run_inference(model, images, config, use_mask=True):
+
+    if use_mask:
+        mask = tf.zeros((1, images.shape[1], images.shape[2], 1))
+        m_outputs = model((images, mask), training=False)
+    else:
+        m_outputs = model(images, training=False)
+
+    predicted_bbox, predicted_labels, predicted_scores = get_model_inference(
+        m_outputs, config.background_class, bbox_format="xy_center", threshold=0.4)
+    return predicted_bbox, predicted_labels, predicted_scores
+
+
+def main(model, use_mask=True, resize=None):
+
+    image = cv2.imread("images/test.jpeg")
+    # Convert to RGB and process the input image
+    if resize is not None:
+        image = cv2.resize(image, (resize[1], resize[0]))
+    model_input = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    model_input = processing.normalized_images(model_input, config)
+
+    # GPU warm up
+    [run_inference(model, np.expand_dims(model_input, axis=0), config, use_mask=use_mask) for i in range(3)]
+
+    # Timing
+    tic = time.time()
+    predicted_bbox, predicted_labels, predicted_scores = run_inference(model, np.expand_dims(model_input, axis=0), config, use_mask=use_mask)
+    toc = time.time()
+    print(f"Inference latency: {(toc - tic)*1000} ms")
+
+    image = image.astype(np.float32)
+    image = image / 255
+    image = numpy_bbox_to_image(image, predicted_bbox, labels=predicted_labels, scores=predicted_scores, class_name=COCO_CLASS_NAME)
+
+    cv2.imshow('image', image)
+    cv2.waitKey(0)
+    cv2.destroyAllWindows()
+
+
+if __name__ == "__main__":
+
+    physical_devices = tf.config.list_physical_devices('GPU')
+    if len(physical_devices) == 1:
+        tf.config.experimental.set_memory_growth(physical_devices[0], True)
+
+    config = TrainingConfig()
+    parser = training_config_parser()
+
+    # Command-line arguments
+    parser.add_argument("model", type=str, help="One of 'detr', or 'deformable-detr'")
+    parser.add_argument("--resize", type=int, nargs=2, default=None, help="Resize image before running inference")
+    args = parser.parse_args()
+    config.update_from_args(args)
+
+    if args.model == "detr":
+        print("Loading detr...")
+        # Load the model with the new layers to finetune
+        model = get_detr_model(config, include_top=True, weights="detr")
+        config.background_class = 91
+        use_mask = True
+    elif args.model == "deformable-detr":
+        print("Loading deformable-detr...")
+        model = get_deformable_detr_model(config, include_top=True, weights="deformable-detr")
+        model.summary()
+        use_mask = False
+    # Run inference on the test image
+    main(model, use_mask=use_mask, resize=args.resize)
diff --git a/tensorrt_inference.py b/tensorrt_inference.py
new file mode 100644
index 00000000..f0534592
--- /dev/null
+++ b/tensorrt_inference.py
@@ -0,0 +1,68 @@
+import ctypes
+import tensorrt as trt
+import numpy as np
+import os
+import cv2
+import argparse
+import time
+
+from detr_tensorrt.TRTExecutor import TRTExecutor, TRT_LOGGER
+from detr_tensorrt.inference import normalized_images, get_model_inference
+
+from detr_tf.data.coco import COCO_CLASS_NAME
+from detr_tf.inference import numpy_bbox_to_image
+
+BACKGROUND_CLASS = 91  # COCO background class
+
+
+def run_inference(model: TRTExecutor, normalized_image: np.ndarray):
+    model.inputs[0].host = normalized_image
+    model.execute()
+    m_outputs = {out.name: out.host for out in model.outputs}
+    p_bbox, p_labels, p_scores = get_model_inference(m_outputs, BACKGROUND_CLASS, threshold=0.4)
+    return p_bbox, p_labels, p_scores
+
+
+def main(engine_path):
+    # Load TensorRT engine
+    model = TRTExecutor(engine_path)
+    model.print_bindings_info()
+
+    # Read image
+    input_shape = model.inputs[0].shape  # (B, H, W, C)
+    H, W = input_shape[1], input_shape[2]
+    image = cv2.imread("images/test.jpeg")
+
+    # Pre-process image
+    model_input = cv2.resize(image, (W, H))
+    model_input = cv2.cvtColor(model_input, cv2.COLOR_BGR2RGB)
+    model_input = normalized_images(model_input)
+
+    # GPU warm up
+    [model.execute() for i in range(3)]
+
+    # Run and time inference
+    tic = time.time()
+    p_bbox, p_labels, p_scores = run_inference(model, model_input)
+    toc = time.time()
+    print(f"Inference latency: {(toc - tic)*1000} ms")
+
+    image = image.astype(np.float32) / 255
+    image = numpy_bbox_to_image(image, p_bbox, p_labels, p_scores, COCO_CLASS_NAME)
+
+    cv2.imshow("image", image)
+    cv2.waitKey(0)
+    cv2.destroyAllWindows()
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--engine_path", type=str, required=True)
+    args = parser.parse_args()
+    if "deformable" in args.engine_path:
+        # Load custom plugin for deformable-detr
+        MS_DEFORM_IM2COL_PLUGIN_LIB = "./detr_tensorrt/plugins/ms_deform_im2col/build/libms_deform_im2col_trt.so"
+        ctypes.CDLL(MS_DEFORM_IM2COL_PLUGIN_LIB)
+        trt.init_libnvinfer_plugins(TRT_LOGGER, '')
+        PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list
+    main(**vars(args))
\ No newline at end of file
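
A quick way to confirm the custom plugin is visible to TensorRT before deserializing a deformable-detr engine is to load the shared library and inspect the plugin registry. A minimal sketch, assuming the plugin was built to the path used above and registers itself under the name `MsDeformIm2ColTRT` (the name used in `test.py`):

```
import ctypes
import tensorrt as trt

# Loading the compiled library lets its plugin creator self-register.
ctypes.CDLL("./detr_tensorrt/plugins/ms_deform_im2col/build/libms_deform_im2col_trt.so")

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, '')

# The registry should now list the custom creator alongside the built-in ones.
names = [c.name for c in trt.get_plugin_registry().plugin_creator_list]
assert "MsDeformIm2ColTRT" in names, "ms_deform_im2col plugin is not registered"
```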