Description
Casting a UINT8 tensor to FLOAT16/FLOAT32 after a Transpose operation breaks the graph: the parsed network ends up with no outputs (network.num_outputs = 0).
Casting before the transpose works as expected.
Environment
TensorRT Version: 10.1.0
NVIDIA GPU: A100-SXM4-80GB
NVIDIA Driver Version: 550.90.07
CUDA Version: 12.5
CUDNN Version: 9.1.0
Operating System: Ubuntu 22.04.4 LTS
Python Version: Python 3.10.12
PyTorch Version: 2.3.1+cu121
ONNX version: 1.16.1
Container: nvcr.io/nvidia/tensorrt:24.06-py3
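The Python-side versions above can be collected in one step when filing or re-checking the report. A minimal sketch, assuming TensorRT, PyTorch and ONNX are installed as pip distributions named "tensorrt", "torch" and "onnx" (inside the container they may be packaged differently, in which case the value is None). The helper name environment_versions is hypothetical; GPU, driver and CUDA versions still come from nvidia-smi.

```python
# Collect the Python-side versions listed in the Environment section.
# Assumption: the packages are installed as distributions named
# "tensorrt", "torch" and "onnx"; anything not found is reported as None.
import platform
from importlib import metadata


def environment_versions(packages=("tensorrt", "torch", "onnx")):
    """Return {name: version or None}, plus the running Python version."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed as a pip distribution (e.g. a system package).
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in environment_versions().items():
        print(f"{name}: {version}")
```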
Steps To Reproduce
Minimal example
import tensorrt as trt
import torch as th

class TransposeCast(th.nn.Module):
    def forward(self, x):
        x = x.permute(1, 0)  # Transpose first...
        x = x.float()        # ...then cast UINT8 -> FLOAT32
        return x

class CastTranspose(th.nn.Module):
    def forward(self, x):
        x = x.float()        # Cast UINT8 -> FLOAT32 first...
        x = x.permute(1, 0)  # ...then transpose
        return x

TC_model = TransposeCast()
CT_model = CastTranspose()
data = th.zeros((1, 1), dtype=th.uint8)
th.onnx.export(TC_model, data, "TC.onnx", opset_version=17)
th.onnx.export(CT_model, data, "CT.onnx", opset_version=17)

logger = trt.Logger()
# Uncomment to get the verbose logs shown below.
#logger.min_severity = trt.Logger.Severity.VERBOSE
builder = trt.Builder(logger)
for file in ("TC.onnx", "CT.onnx"):
    # EXPLICIT_BATCH is deprecated in TensorRT 10 (explicit batch is always on);
    # kept here exactly as in the original run.
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    with open(file, 'rb') as fd:
        parser.parse(fd.read())  # return value is not checked
    print(f"{file}: {network.num_outputs=}")
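The repro discards the boolean returned by parser.parse(), so any error the parser records for the /Cast node is never printed (the logs alone do not show whether parse() returned False here). A minimal sketch of a helper that surfaces those errors, assuming the parser object behaves like tensorrt.OnnxParser (parse(), the num_errors property, get_error(i)); the helper name parse_and_collect_errors is hypothetical.

```python
# Sketch: parse an ONNX model and return the parser's error messages instead
# of discarding parse()'s boolean result. `parser` is assumed to behave like
# tensorrt.OnnxParser: parse(bytes) -> bool, num_errors, get_error(i).


def parse_and_collect_errors(parser, model_bytes):
    """Return a list of error strings; the list is empty if parsing succeeded."""
    if parser.parse(model_bytes):
        return []
    # str() of each recorded error gives a human-readable description.
    return [str(parser.get_error(i)) for i in range(parser.num_errors)]


# Hypothetical usage inside the repro loop:
#     with open(file, "rb") as fd:
#         for msg in parse_and_collect_errors(parser, fd.read()):
#             print(f"{file}: {msg}")
```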
Output
TC.onnx: network.num_outputs=0
CT.onnx: network.num_outputs=1
Both networks should have one output.
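The expectation holds because Transpose and Cast commute: every UINT8 value is exactly representable in FLOAT16 and FLOAT32, so TC.onnx and CT.onnx compute the same function. A small worked check, using NumPy as a stand-in for the two ONNX ops (np.transpose with perm (1, 0) and astype); the function names are hypothetical, and a non-square 2x3 input is used so the transpose is visible.

```python
# Worked check: Transpose-then-Cast and Cast-then-Transpose give identical
# results, using NumPy as a stand-in for the ONNX Transpose and Cast ops.
import numpy as np


def transpose_then_cast(x, dtype=np.float32):
    return np.transpose(x, (1, 0)).astype(dtype)


def cast_then_transpose(x, dtype=np.float32):
    return np.transpose(x.astype(dtype), (1, 0))


x = np.arange(6, dtype=np.uint8).reshape(2, 3)
for dtype in (np.float16, np.float32):
    a = transpose_then_cast(x, dtype)
    b = cast_then_transpose(x, dtype)
    assert a.shape == b.shape == (3, 2)
    assert a.dtype == b.dtype == dtype
    assert np.array_equal(a, b)
print("both orderings agree")
```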
Logs:
TransposeCast
[07/05/2024-14:25:29] [TRT] [I] [MemUsageChange] Init CUDA: CPU +19, GPU +0, now: CPU 111, GPU 26482 (MiB)
[07/05/2024-14:25:29] [TRT] [V] Trying to load shared library libnvinfer_builder_resource.so.10.1.0
[07/05/2024-14:25:29] [TRT] [V] Loaded shared library libnvinfer_builder_resource.so.10.1.0
[07/05/2024-14:25:31] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +1931, GPU +354, now: CPU 2189, GPU 26836 (MiB)
[07/05/2024-14:25:31] [TRT] [V] CUDA lazy loading is enabled.
[07/05/2024-14:25:31] [TRT] [V] Adding network input: onnx::Transpose_0 with dtype: uint8, dimensions: (1, 1)
[07/05/2024-14:25:31] [TRT] [V] Registering tensor: onnx::Transpose_0 for ONNX tensor: onnx::Transpose_0
[07/05/2024-14:25:31] [TRT] [V] Static check for parsing node: /Transpose [Transpose]
[07/05/2024-14:25:31] [TRT] [V] Parsing node: /Transpose [Transpose]
[07/05/2024-14:25:31] [TRT] [V] Searching for input: onnx::Transpose_0
[07/05/2024-14:25:31] [TRT] [V] /Transpose [Transpose] inputs: [onnx::Transpose_0 -> (1, 1)[UINT8]],
[07/05/2024-14:25:31] [TRT] [V] Registering layer: /Transpose for ONNX node: /Transpose
[07/05/2024-14:25:31] [TRT] [V] Registering tensor: /Transpose_output_0 for ONNX tensor: /Transpose_output_0
[07/05/2024-14:25:31] [TRT] [V] Static check for parsing node: /Cast [Cast]
TC.onnx: network.num_outputs=0
CastTranspose
[07/05/2024-14:25:31] [TRT] [V] Adding network input: onnx::Cast_0 with dtype: uint8, dimensions: (1, 1)
[07/05/2024-14:25:31] [TRT] [V] Registering tensor: onnx::Cast_0 for ONNX tensor: onnx::Cast_0
[07/05/2024-14:25:31] [TRT] [V] Static check for parsing node: /Cast [Cast]
[07/05/2024-14:25:31] [TRT] [V] Parsing node: /Cast [Cast]
[07/05/2024-14:25:31] [TRT] [V] Searching for input: onnx::Cast_0
[07/05/2024-14:25:31] [TRT] [V] /Cast [Cast] inputs: [onnx::Cast_0 -> (1, 1)[UINT8]],
[07/05/2024-14:25:31] [TRT] [V] Casting to type: float32
[07/05/2024-14:25:31] [TRT] [V] Registering layer: /Cast for ONNX node: /Cast
[07/05/2024-14:25:31] [TRT] [V] Registering tensor: /Cast_output_0 for ONNX tensor: /Cast_output_0
[07/05/2024-14:25:31] [TRT] [V] /Cast [Cast] outputs: [/Cast_output_0 -> (1, 1)[FLOAT]],
[07/05/2024-14:25:31] [TRT] [V] Static check for parsing node: /Transpose [Transpose]
[07/05/2024-14:25:31] [TRT] [V] Parsing node: /Transpose [Transpose]
[07/05/2024-14:25:31] [TRT] [V] Searching for input: /Cast_output_0
[07/05/2024-14:25:31] [TRT] [V] /Transpose [Transpose] inputs: [/Cast_output_0 -> (1, 1)[FLOAT]],
[07/05/2024-14:25:31] [TRT] [V] Registering layer: /Transpose for ONNX node: /Transpose
[07/05/2024-14:25:31] [TRT] [V] Registering tensor: 2_0 for ONNX tensor: 2
[07/05/2024-14:25:31] [TRT] [V] /Transpose [Transpose] outputs: [2 -> (1, 1)[FLOAT]],
[07/05/2024-14:25:31] [TRT] [V] Marking 2_0 as output: 2
CT.onnx: network.num_outputs=1
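The logs show the TC parse stopping right after "Static check for parsing node: /Cast [Cast]". To cross-check the node order in each exported file independently of TensorRT, a minimal sketch that lists op types in graph order; it accepts any object shaped like onnx.ModelProto (model.graph.node[i].op_type), and the helper name op_sequence is hypothetical.

```python
# Sketch: list the op types of an ONNX graph in node order, to confirm which
# file is Transpose -> Cast and which is Cast -> Transpose.
# `model` is assumed to be shaped like onnx.ModelProto.


def op_sequence(model):
    return [node.op_type for node in model.graph.node]


# Hypothetical usage (requires the onnx package and the exported files):
#     import onnx
#     for path in ("TC.onnx", "CT.onnx"):
#         model = onnx.load(path)
#         onnx.checker.check_model(model)
#         print(path, " -> ".join(op_sequence(model)))
```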