Examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend.
## Data parallel distributed inference based on Accelerate
Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend. In this case, the entire model is loaded onto each GPU and different chunks of the batch input are processed on each device.
See the examples `data_parallel_gpt2.py` and `data_parallel_stable_diffusion.py` for more details.
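For reference, a minimal sketch of the pattern, assuming a GPT-2 model and illustrative prompts (see the example scripts for the exact compilation settings):

```python
import torch
import torch_tensorrt  # registers the "torch_tensorrt" dynamo backend
from accelerate import PartialState
from transformers import AutoTokenizer, GPT2LMHeadModel

# Each process owns one GPU; the full model is replicated on every device.
distributed_state = PartialState()

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)

# Compile the forward pass with the Torch-TensorRT backend.
model.forward = torch.compile(
    model.forward,
    backend="torch_tensorrt",
    options={"enabled_precisions": {torch.float16}},
    dynamic=False,
)

# Accelerate hands each process a different chunk of the prompt batch.
with distributed_state.split_between_processes(
    ["Hello, my dog is", "Hello, my cat is"]
) as prompts:
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids.to(
        distributed_state.device
    )
    outputs = model.generate(input_ids, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Run the script with a multi-process launcher such as `torchrun --nproc_per_node=2` or `accelerate launch`.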
## Tensor parallel distributed inference

Here we use `torch.distributed` as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.
```bash
torchrun --nproc_per_node=2 tensor_parallel_llama2.py
```
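A minimal sketch of sharding a module with `torch.distributed` tensor parallelism and compiling it, assuming two GPUs (the toy model, layer sizes, and sharding plan are illustrative, not the Llama-2 setup from the example):

```python
import os

import torch
import torch_tensorrt  # registers the "torch_tensorrt" dynamo backend
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Toy two-layer MLP standing in for the real model in the example.
class ToyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1024, 4096)
        self.fc2 = torch.nn.Linear(4096, 1024)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# torchrun sets LOCAL_RANK and the process-group environment; one GPU per rank.
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
mesh = init_device_mesh("cuda", (2,))

# Shard fc1 column-wise and fc2 row-wise across the two GPUs.
model = ToyMLP().cuda()
model = parallelize_module(
    model, mesh, {"fc1": ColwiseParallel(), "fc2": RowwiseParallel()}
)

# Compile the already-sharded module with the Torch-TensorRT backend.
model = torch.compile(model, backend="torch_tensorrt", dynamic=False)
out = model(torch.randn(8, 1024, device="cuda"))
```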
## Tensor parallel distributed inference using the TensorRT-LLM plugin

We use the `torch.distributed` package to shard the model with tensor parallelism. The distributed ops (`all_gather` and `all_reduce`) are then expressed as TensorRT-LLM plugins to avoid graph breaks during Torch-TensorRT compilation. The converters for these operators are already available in Torch-TensorRT. The functional implementation of the ops is imported from the `tensorrt_llm` package (to be more specific, only `libnvinfer_plugin_tensorrt_llm.so` is required). So we have two options here:
### Option 1: Install TensorRT-LLM

Follow the instructions to [install TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/installation/linux.html).
If the default installation fails due to issues like library version mismatches or Python incompatibility, it is recommended to use Option 2. After a successful installation, verify it by running `import torch_tensorrt` and confirming it works without errors. The import might fail if the `tensorrt_llm` installation overrides `torch_tensorrt` dependencies. Option 2 is particularly advisable if you prefer not to install `tensorrt_llm` and its associated dependencies.
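For example, an installation via pip might look like this (a sketch assuming a Linux environment with a supported Python; the official guide may require additional system packages):

```bash
# Install TensorRT-LLM from the NVIDIA Python index.
pip install tensorrt_llm --extra-index-url https://pypi.nvidia.com

# Sanity check: torch_tensorrt should still import cleanly afterwards.
python -c "import torch_tensorrt"
```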
### Option 2: Load the plugin library directly

Another alternative is to load `libnvinfer_plugin_tensorrt_llm.so` directly. You can do this by:
- Downloading the `tensorrt_llm-0.16.0` wheel file from the NVIDIA Python index.
- Extracting the wheel file to a directory; the `libnvinfer_plugin_tensorrt_llm.so` library is under the `tensorrt_llm/libs` directory.
- Setting the environment variable `TRTLLM_PLUGINS_PATH` to the above extracted path at the `initialize_distributed_env()` call, as sketched below.
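A shell sketch of these steps (the download directory is hypothetical, and we assume `TRTLLM_PLUGINS_PATH` takes the full path to the `.so`; pick the wheel matching your Python version and platform):

```bash
# Fetch the wheel without its dependencies from the NVIDIA Python index.
pip download tensorrt_llm==0.16.0 --no-deps -d /tmp/trtllm \
    --extra-index-url https://pypi.nvidia.com

# A wheel is a zip archive; extract it to reach the bundled libraries.
unzip /tmp/trtllm/tensorrt_llm-0.16.0-*.whl -d /tmp/trtllm/extracted

# Point Torch-TensorRT at the plugin library before initialize_distributed_env() runs.
export TRTLLM_PLUGINS_PATH=/tmp/trtllm/extracted/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
```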
After configuring TensorRT-LLM or the TensorRT-LLM plugin library path, run the following command, which illustrates tensor parallelism of a simple model and its compilation with Torch-TensorRT:
```bash
mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
```
We also provide a tensor parallelism compilation example on a more advanced model, Llama-3. Here's the command to run it:
```bash
mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
```