
🐛 [Bug] quantized_resnet_test.py failed on no attribute 'EXPLICIT_PRECISION' #3362

Open
korkland opened this issue Jan 22, 2025 · 7 comments
Labels
bug Something isn't working

Comments

@korkland

The example script fx/quantized_resnet_test.py in the Torch-TensorRT repository fails to execute due to the use of a deprecated attribute EXPLICIT_PRECISION in the TensorRT Python API. This attribute is no longer available in recent versions of TensorRT (e.g., TensorRT 10.1).

The error traceback is as follows:

Traceback (most recent call last):
  File "/home/yz9qvs/projects/Torch-TensorRT/examples/fx/quantized_resnet_test.py", line 142, in <module>
    int8_trt = build_int8_trt(rn18)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/yz9qvs/projects/Torch-TensorRT/examples/fx/quantized_resnet_test.py", line 60, in build_int8_trt
    interp = TRTInterpreter(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/fx/fx2trt.py", line 59, in __init__
    trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION
AttributeError: type object 'tensorrt.tensorrt.NetworkDefinitionCreationFlag' has no attribute 'EXPLICIT_PRECISION'

To Reproduce

Steps to reproduce the behavior:

  1. Clone the Torch-TensorRT repository.
  2. Navigate to the examples/fx directory.
  3. Run the script quantized_resnet_test.py:
python quantized_resnet_test.py

Expected behavior

The script should run successfully, converting the quantized ResNet model to TensorRT without encountering an error.

Environment

Torch-TensorRT Version: 2.4.0
PyTorch Version: 2.4.0
CPU Architecture: amd64
OS: Ubuntu 22.04
How you installed PyTorch: pip
Build command you used (if compiling from source): N/A
Are you using local sources or building from archives: Building from local sources
Python version: 3.10
CUDA version: 11.8
GPU models and configuration: NVIDIA A40
Any other relevant information: Running TensorRT 10.1.0

Additional context

The issue seems to stem from the use of the deprecated EXPLICIT_PRECISION flag in the TRTInterpreter class within torch_tensorrt/fx/fx2trt.py. TensorRT 10.1 does not support this attribute, and its usage needs to be updated to align with the latest TensorRT API.
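
For reference, the pattern in question presumably looks something like the following (a paraphrase of the network-creation call, not the actual fx2trt.py code). On TensorRT 10 the EXPLICIT_PRECISION member no longer exists, and because explicit precision is the default there, the flag can simply be dropped:

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)

# Pre-TensorRT-10 style (raises AttributeError on TensorRT 10.x, where the
# EXPLICIT_PRECISION flag has been removed):
#   flags = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) | (
#       1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)
#   )

# TensorRT-10-compatible: explicit precision is the default, so only the
# explicit-batch flag is needed (itself deprecated but still accepted).
flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flags)
```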

This script is one of the very few examples that demonstrates how to quantize a model using FX and lower it to TensorRT. It is a valuable resource for users looking to implement this workflow.

If addressing this issue immediately is not feasible, it would be extremely helpful if an alternative example could be provided to demonstrate how to achieve model quantization and conversion to TensorRT using FX. This would ensure users can still proceed with their workflows while awaiting a permanent fix.
THANKS!

@korkland korkland added the bug Something isn't working label Jan 22, 2025
@korkland korkland changed the title from "🐛 [Bug] Encountered bug when using Torch-TensorRT" to "🐛 [Bug] quantized_resnet_test.py failed on no attribute 'EXPLICIT_PRECISION'" Jan 22, 2025
@narendasan
Collaborator

Does this example not work for you? https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_ptq.html.

@korkland
Author

Thank you for pointing out the vgg16_ptq example! However, this example uses modelopt for post-training quantization, while our workflow specifically relies on torch.fx for quantization and lowering to TensorRT. The quantized_resnet_test.py script appears to be one of the few examples in the repository demonstrating this approach.

Unfortunately, the script does not work as expected due to the issue with the deprecated EXPLICIT_PRECISION flag in TensorRT. This raises a couple of questions:
1. Is the fx-based quantization and TensorRT lowering workflow (as shown in quantized_resnet_test.py) still supported?
2. Has this test been disabled in the CI, or is it no longer being actively maintained?

If this workflow is no longer supported, are there plans to update it, or could you provide an alternative example demonstrating fx-based quantization and integration with TensorRT? This would be immensely helpful for users exploring this specific workflow.

Thank you for your support!

@narendasan
Collaborator

Is the fx-based quantization and TensorRT lowering workflow (as shown in quantized_resnet_test.py) still supported?

In theory you can continue to use this workflow through the dynamo frontend. As long as the operations that the fx converters use aren't getting lowered out, they will still work today. If they are getting lowered out, we can patch them for dynamo.

Has this test been disabled in the CI, or is it no longer being actively maintained?

The fx frontend is no longer being actively maintained, as it has been superseded by dynamo.

If this workflow is no longer supported, are there plans to update it, or could you provide an alternative example demonstrating fx-based quantization and integration with TensorRT

We can explore making an example for the dynamo frontend that replicates the behavior, but it won't use the FX frontend.
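
For concreteness, a rough, untested sketch of that path (assuming the standard torch.ao FX quantization APIs and the dynamo ir of torch_tensorrt.compile; whether the resulting quantize/dequantize ops are actually handled by the dynamo converters is exactly what is being discussed here):

```python
import torch
import torch_tensorrt
import torchvision.models as models
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

model = models.resnet18(weights=None).eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# Quantize with torch.fx (post-training static quantization).
qconfig_mapping = get_default_qconfig_mapping("x86")
prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)  # calibration pass (use real data in practice)
quantized = convert_fx(prepared)

# Lower through the dynamo frontend instead of the unmaintained fx frontend.
trt_model = torch_tensorrt.compile(
    quantized,
    ir="dynamo",
    inputs=[torch_tensorrt.Input(shape=(1, 3, 224, 224))],
    enabled_precisions={torch.int8},
)
```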

@korkland
Author

Got it. So if I understand you correctly, we can still quantize using torch.fx and then lower to TensorRT with the Dynamo frontend? That's good to know, because right now it feels like there are two main options for quantization: NVIDIA's Model Optimization Toolkit, which is still pretty early in development, and torch.fx, which a lot of people are already using.

Would it be possible to share an example showing how to quantize a model using torch.fx and then lower it using the Dynamo frontend in Torch-TensorRT? It would really help clarify how to transition workflows like this without relying on outdated tools.

@korkland
Author

I’m currently stuck with this workflow. We’re quantizing models using torch.fx, but I’m running into issues with all of Torch-TensorRT’s frontends:

TorchScript (TS) and Dynamo: these don't seem to support torch.fx quantized graphs.
FX frontend: fails with the error 'tensorrt.tensorrt.INetworkDefinition' object has no attribute 'has_explicit_precision', making it incompatible with TensorRT 10 (the recommended version).

Given this situation, I'm unsure what my options are. Could you clarify:

1. Is there any supported way to lower torch.fx quantized graphs to TensorRT?
2. If not, are there plans to address this, or alternative workflows you'd recommend?

Would appreciate any guidance on this.

@narendasan
Collaborator

To solve your issue right now, you can either use ModelOpt, or quickly patch fx's TRTInterpreter to not need the explicit precision flag (it's explicit precision by default, so you don't need to replace it with anything). We are investigating supporting PT2 quantization in Dynamo, but the opset is different, so we cannot directly use the same converters. It is also unclear from PyTorch's side whether PT2 quantization or torchao is their future direction, so we are trying to clarify this with them before committing to supporting it.
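
A minimal sketch of what such a quick patch might look like (hypothetical; the exact code in torch_tensorrt/fx/fx2trt.py differs), guarding the flag so the same network-creation call works on both old and new TensorRT:

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)

flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
if hasattr(trt.NetworkDefinitionCreationFlag, "EXPLICIT_PRECISION"):
    # Older TensorRT still defines the flag; on TensorRT >= 10 it was removed
    # because explicit precision is the default, so nothing else is needed.
    flags |= 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)

network = builder.create_network(flags)
```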

@korkland
Author

korkland commented Feb 2, 2025

Should int8 quantization work in torch_tensorrt 2.4.0 (we are restricted to this version by our Python version)?
I'm trying to run the vgg example after the vgg training example (using modelopt 0.17.0) and got an error.
It works with fp8, but we are interested in int8.

/usr/bin/python /home/yz9qvs/projects/l3_workspace/adas/neural_nets_emb/tensorrt/test_vgg16_dynamo_int8.py
WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
WARNING:py.warnings:/home/yz9qvs/projects/l3_workspace/adas/neural_nets_emb/tensorrt/test_vgg16_dynamo_int8.py:93: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  ckpt = torch.load("/home/yz9qvs/projects/Torch-TensorRT/examples/int8/training/vgg16/vgg16_ckpts/ckpt_epoch100.pth")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
100%|██████████████████████████████████████████████████████████████████████████████████████| 170498071/170498071 [00:26<00:00, 6460079.20it/s]
Extracting ./data/cifar-10-python.tar.gz to ./data
Inserted 86 quantizers
PTQ Loss: 0.00010 Acc: 99.63%
Files already downloaded and verified
Loading extension modelopt_cuda_ext...
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.i8: 3>}, debug=False, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache.bin')

INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 597, GPU 2437 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +1619, GPU +290, now: CPU 2363, GPU 2727 (MiB)
WARNING:py.warnings:/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/impl/activation/base.py:40: DeprecationWarning: Use Deprecated in TensorRT 10.1. Superseded by explicit quantization. instead.
  if input_val.dynamic_range is not None and dyn_range_fn is not None:

INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.004329
WARNING:torch_tensorrt [TensorRT Conversion Context]:Calibrator is not being used. Users must provide dynamic range for all tensors that are not Int32 or Bool.
ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 4: Internal Error (Calibration failure occurred with no scaling factors detected. This could be due to no int8 calibrator or insufficient custom scales for network layers. Please see int8 sample to setup calibration correctly.)
Traceback (most recent call last):
  File "/home/yz9qvs/projects/l3_workspace/adas/neural_nets_emb/tensorrt/test_vgg16_dynamo_int8.py", line 185, in <module>
    trt_model = torchtrt.dynamo.compile(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 230, in compile
    trt_gm = compile_module(gm, inputs, settings)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 418, in compile_module
    trt_module = convert_module(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 106, in convert_module
    interpreter_result = interpret_module_to_result(module, inputs, settings)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 87, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 344, in run
    assert serialized_engine
AssertionError
WARNING:py.warnings:/usr/lib/python3.10/tempfile.py:999: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpp57x4hyy'>
  _warnings.warn(warn_message, ResourceWarning)
