How can I transfer t5 decoder_model_merged.onnx to tensorrt #4032

Closed
EASTERNTIGER opened this issue Jul 30, 2024 · 3 comments

@EASTERNTIGER

Hi, when I run the following command on Linux:
trtexec --onnx=decoder_model_merged.onnx --saveEngine=decoder_model_merged.trt
it fails with:
[07/30/2024-02:00:24] [E] Error[4]: [graphShapeAnalyzer.cpp::analyzeShapes::2084] Error Code 4: Miscellaneous (IConditionalOutputLayer optimum::if_OutputLayer_4945: optimum::if_OutputLayer_4945: dimensions not compatible for if-conditional outputs)
[07/30/2024-02:00:24] [E] Engine could not be created from network
[07/30/2024-02:00:24] [E] Building engine failed
[07/30/2024-02:00:24] [E] Failed to create engine from model or file.
[07/30/2024-02:00:24] [E] Engine set up failed

When I use the same command on another file:
trtexec --onnx=decoder_model.onnx --saveEngine=decoder_model.trt
it works. So why, within the same T5 ONNX export, can decoder_model.onnx, encoder_model.onnx, and decoder_with_past_model.onnx be converted to TensorRT engines successfully while decoder_model_merged.onnx cannot? How can I fix that?

@lix19937

"dimensions not compatible for if-conditional outputs"

Your model has if-then-else flow control. There can be no cross-edges connecting layers in the true-branch to layers in the false-branch, and vice versa. In other words, the outputs of one branch cannot depend on layers in the other branch.
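To make the cross-edge rule concrete, here is a minimal pure-Python sketch that scans a toy graph description for edges running between the two branches. The dict-of-lists graph format, the layer names, and the function `find_cross_edges` are hypothetical illustrations only; they are not TensorRT or ONNX APIs, and the sketch checks only the cross-edge rule, not the branch output dimensions named in the error.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy format: each layer name maps to the names of the layers
# whose outputs it consumes. This is NOT a TensorRT or ONNX data structure.
Graph = Dict[str, List[str]]


def find_cross_edges(graph: Graph, true_branch: Set[str],
                     false_branch: Set[str]) -> List[Tuple[str, str]]:
    """Return (producer, consumer) pairs where a layer in one branch
    consumes the output of a layer in the other branch."""
    cross = []
    for consumer, producers in graph.items():
        for producer in producers:
            if consumer in true_branch and producer in false_branch:
                cross.append((producer, consumer))
            elif consumer in false_branch and producer in true_branch:
                cross.append((producer, consumer))
    return cross


if __name__ == "__main__":
    graph = {
        "cond": ["input"],
        "then_a": ["input"],
        "else_b": ["input", "then_a"],  # illegal: else-branch reads a then-branch layer
        "output": ["then_a", "else_b"],  # merging both branches after the If is fine
    }
    print(find_cross_edges(graph, {"then_a"}, {"else_b"}))  # [('then_a', 'else_b')]
```

Layers outside both branches (such as the conditional output that merges them) may read from either side; only edges from one branch directly into the other are flagged.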

You should export scripted PyTorch code to ONNX, as in the following example:

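import torch
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
# Explicit batch is required for ONNX models in TensorRT 8.x;
# TensorRT 10 always uses explicit batch and marks this flag as deprecated.
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

# Scripting (rather than tracing) keeps the Python loop and if statement
# as real control flow, which the exporter emits as ONNX Loop/If nodes.
@torch.jit.script
def sum_even(items):
    s = torch.zeros(1, dtype=torch.float)
    for c in items:
        if c % 2 == 0:
            s += c
    return s

class ExampleModel(torch.nn.Module):
    def forward(self, items):
        return sum_even(items)

def build_engine(model_file, engine_file):
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(EXPLICIT_BATCH)
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    with open(model_file, "rb") as model:
        if not parser.parse(model.read()):
            # Report every parser error instead of a bare assert.
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("failed to parse " + model_file)

    # build_engine() was removed in TensorRT 10; build_serialized_network()
    # is available from TensorRT 8.0 onward.
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("failed to build engine")
    with open(engine_file, "wb") as f:
        f.write(serialized)

def export_to_onnx():
    items = torch.zeros(4, dtype=torch.float)
    example = ExampleModel()
    # enable_onnx_checker is no longer accepted by torch.onnx.export in newer PyTorch releases.
    torch.onnx.export(example, (items,), "example.onnx", verbose=False,
                      opset_version=13, do_constant_folding=True)

export_to_onnx()
build_engine("example.onnx", "example.trt")
print("done")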

Alternatively, replace the if-else control flow with a fixed control branch.
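For a merged T5 decoder, one way to apply the "fixed branch" idea is to move the choice out of the engine and into host code. Assuming (as with Optimum's merged decoders) that the If node selects between the no-past decoder and the with-past decoder, you could build the separate files that already convert and pick an engine per decoding step. This is a minimal stdlib sketch under that assumption: `trtexec_command`, `pick_decoder_engine`, and the step-based policy are hypothetical helpers, and the trtexec flags are only the two used in this issue.

```python
import shutil
import subprocess
from pathlib import Path
from typing import List

# Files reported in this issue as converting successfully.
MODELS = ["encoder_model.onnx", "decoder_model.onnx", "decoder_with_past_model.onnx"]


def trtexec_command(onnx_path: str) -> List[str]:
    """Build the same trtexec invocation used in the issue for one ONNX file."""
    engine_path = str(Path(onnx_path).with_suffix(".trt"))
    return ["trtexec", f"--onnx={onnx_path}", f"--saveEngine={engine_path}"]


def pick_decoder_engine(step: int) -> str:
    """Hypothetical host-side replacement for the If node: the first decoding
    step has no cached key/values; later steps reuse them."""
    return "decoder_model.trt" if step == 0 else "decoder_with_past_model.trt"


if __name__ == "__main__":
    for model in MODELS:
        # Only invoke trtexec when both the tool and the input file are present.
        if shutil.which("trtexec") and Path(model).exists():
            subprocess.run(trtexec_command(model), check=True)
        else:
            print("skipping", model)
```

The trade-off is two decoder engines instead of one, but neither contains an If node, so the conditional-output shape check never runs.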

@EASTERNTIGER
Author


Thank you so much for your reply! I will try that.

@poweiw
Copy link
Collaborator

poweiw commented Feb 11, 2025

Closing this issue for now but feel free to reopen if you see further issues!

@poweiw poweiw closed this as completed Feb 11, 2025