Commit 04cb5c6

chore: Fix docs and example

1 parent 186ac1b commit 04cb5c6

3 files changed, +24 −28 lines changed


Diff for: examples/distributed_inference/README.md

+22-25
@@ -2,49 +2,46 @@

Examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend.

-1. Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference)
+## Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference)

Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend. In this case, the entire model
will be loaded onto each GPU and different chunks of the batch input are processed on each device.

-See the examples started with `data_parallel` for more details.
+See the examples [data_parallel_gpt2.py](https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_gpt2.py) and [data_parallel_stable_diffusion.py](https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_stable_diffusion.py) for more details.
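
For orientation, the pattern in those data-parallel examples is roughly the sketch below: every process loads the full model, compiles it with the Torch-TensorRT backend, and Accelerate splits the prompts across processes. This is a hedged sketch, not the exact scripts; the model name, prompts, and compile options are illustrative placeholders.

```py
# Minimal data-parallel sketch: each GPU holds the whole model and processes
# its own share of the inputs. Launch with `accelerate launch this_script.py`.
import torch
import torch_tensorrt  # noqa: F401  (registers the "torch_tensorrt" compile backend)
from accelerate import PartialState
from transformers import AutoTokenizer, GPT2LMHeadModel

distributed_state = PartialState()  # one process per GPU

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)

# The whole model lives on each device; only the input batch is split.
model.forward = torch.compile(
    model.forward,
    backend="torch_tensorrt",
    options={"enabled_precisions": {torch.float16}},  # illustrative option
    dynamic=False,
)

prompts = ["GPT-2 is", "TensorRT is", "Distributed inference is", "Accelerate is"]
with distributed_state.split_between_processes(prompts) as local_prompts:
    for prompt in local_prompts:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(distributed_state.device)
        output = model.generate(input_ids, max_length=20)
        print(f"rank {distributed_state.process_index}: "
              f"{tokenizer.decode(output[0], skip_special_tokens=True)}")
```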

-2. Tensor parallel distributed inference
+## Tensor parallel distributed inference

Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.

torchrun --nproc_per_node=2 tensor_parallel_llama2.py
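
As a rough illustration of what "properly sharded" means here, the toy sketch below uses the torch.distributed tensor-parallel APIs to column-/row-shard a made-up two-layer module. The module, layer names, and mesh size are illustrative assumptions; the real examples shard Llama blocks.

```py
# Toy tensor-parallel sharding sketch with torch.distributed.
# Launch with `torchrun --nproc_per_node=2 this_script.py`.
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    def __init__(self, dim: int = 1024):
        super().__init__()
        self.in_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.out_proj = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))


local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

mesh = init_device_mesh("cuda", (2,))  # one mesh dimension spanning two GPUs
model = ToyMLP().cuda()

# Column-shard the first projection and row-shard the second, so the block
# needs only a single all_reduce at its output.
tp_model = parallelize_module(
    model,
    mesh,
    {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)

x = torch.randn(8, 1024, device="cuda")
out = tp_model(x)  # the collectives run under the hood via torch.distributed
```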

-3. Tensor parallel distributed inference using nccl ops plugin
+## Tensor parallel distributed inference on a simple model using nccl ops plugin

-apt install libmpich-dev
+
+We use the [torch.distributed](https://pytorch.org/docs/stable/distributed.html) package to shard the model with tensor parallelism. The distributed ops (`all_gather` and `all_reduce`) are then expressed as TensorRT-LLM plugins to avoid graph breaks during Torch-TensorRT compilation. The [converters for these operators](https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py#L25-L55) are already available in Torch-TensorRT. The functional implementation of the ops is imported from the `tensorrt_llm` package (to be more specific, only `libnvinfer_plugin_tensorrt_llm.so` is required), so we have two options here:

-apt install libopenmpi-dev
+### Option 1: Install TensorRT-LLM

-#For python3.10
+Follow the instructions to [install TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/installation/linux.html).

-pip install tensorrt-llm
+If the default installation fails due to issues like library version mismatches or Python compatibility, it is recommended to use Option 2. After a successful installation, test by running `import torch_tensorrt` to confirm it works without errors. The import might fail if the `tensorrt_llm` installation overrides the `torch_tensorrt` dependencies. Option 2 is particularly advisable if you prefer not to install `tensorrt_llm` and its associated dependencies.

-For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so. Please set that in the environment variable export TRTLLM_PLUGINS_PATH={lib_path}. For example, we have already set the variable in initialize_distributed_env(). You can replace this with your TRTLLM_PLUGINS_PATH and unset it there
+### Option 2: Link TensorRT-LLM directly

-#then pip install the tensorrt and torch version compatible with installed torchTRT
+Another alternative is to load `libnvinfer_plugin_tensorrt_llm.so` directly. You can do this by:
+* Downloading the [tensorrt_llm-0.16.0](https://pypi.nvidia.com/tensorrt-llm/tensorrt_llm-0.16.0-cp310-cp310-linux_x86_64.whl#sha256=f86c6b89647802f49b26b4f6e40824701da14c0f053dbda3e1e7a8709d6939c7) wheel file from the NVIDIA Python index.
+* Extracting the wheel file to a directory; the `libnvinfer_plugin_tensorrt_llm.so` library is located under its `tensorrt_llm/libs` directory.
+* Setting the environment variable `TRTLLM_PLUGINS_PATH` to the extracted path at the [initialize_distributed_env()](https://github.com/pytorch/TensorRT/blob/54e36dbafe567c75f36b3edb22d6f49d4278c12a/examples/distributed_inference/tensor_parallel_initialize_dist.py#L45) call, as shown in the snippet below.
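
For example, a minimal way to do this (the path below is a placeholder for wherever you extracted the wheel) is to export the variable before `torch_tensorrt` is imported:

```py
# Placeholder path: replace with the directory from the extracted wheel that
# contains libnvinfer_plugin_tensorrt_llm.so (typically .../tensorrt_llm/libs).
import os

os.environ["TRTLLM_PLUGINS_PATH"] = "/path/to/extracted/tensorrt_llm/libs"
```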

-mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py

-#For other python
+After configuring TensorRT-LLM or the TensorRT-LLM plugin library path, run the following command, which illustrates tensor parallelism of a simple model and its compilation with Torch-TensorRT:

-4. Tensor parallel distributed llama3 inference using nccl ops plugin
+```py
+mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
+```
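
For orientation, once the plugin library can be loaded, the compilation step looks roughly like the sketch below: the already-sharded module is handed to `torch.compile` with the Torch-TensorRT backend, and the `all_gather`/`all_reduce` calls inside it are converted through the plugin converters instead of causing graph breaks. The helper name and option values here are illustrative assumptions, not the exact ones in the script.

```py
# Hedged sketch of the compilation step for an already-sharded module
# (for example, the tp_model from the sharding sketch above).
import torch
import torch.nn as nn
import torch_tensorrt  # noqa: F401  (registers the "torch_tensorrt" backend)


def compile_tp_model(tp_model: nn.Module, example_input: torch.Tensor) -> nn.Module:
    compiled = torch.compile(
        tp_model,
        backend="torch_tensorrt",
        options={
            "enabled_precisions": {torch.float32, torch.float16},
            "use_python_runtime": True,
            "min_block_size": 1,
        },
        dynamic=False,
    )
    compiled(example_input)  # trigger compilation once with a representative input
    return compiled
```

The Llama-3 example below applies the same idea to a larger sharded model.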

-apt install libmpich-dev
+We also provide a tensor parallelism compilation example on a more advanced model like `Llama-3`. Here's the command to run it:

-apt install libopenmpi-dev
-
-#For python3.10
-
-pip install tensorrt-llm
-
-For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so
-
-#then pip install the tensorrt and torch version compatible with installed torchTRT
-
-mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
+```py
+mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
+```

Diff for: examples/distributed_inference/tensor_parallel_llama3.py

+2-2
@@ -17,8 +17,8 @@
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
    "./tensor_parallel_llama3"
)
-# Import should be after initialization of the TRT-LLM plugin .so path
-import tensorrt_llm
+
+import torch_tensorrt

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (

Diff for: examples/distributed_inference/tensor_parallel_simple_example.py

-1
@@ -15,7 +15,6 @@
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
    "./tensor_parallel_simple_example"
)
-import tensorrt_llm

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
