diff --git a/docsrc/index.rst b/docsrc/index.rst
index e7d5250e52..b4d96dbc8d 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -68,6 +68,7 @@ Tutorials
 * :ref:`mutable_torchtrt_module_example`
 * :ref:`weight_streaming_example`
 * :ref:`pre_allocated_output_example`
+* :ref:`tensor_parallel_llama3`
 
 .. toctree::
    :caption: Tutorials
@@ -87,6 +88,7 @@ Tutorials
    tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
    tutorials/_rendered_examples/dynamo/weight_streaming_example
    tutorials/_rendered_examples/dynamo/pre_allocated_output_example
+   tutorials/_rendered_examples/distributed_inference/tensor_parallel_llama3
 
 Dynamo Frontend
 ----------------
diff --git a/examples/distributed_inference/README.md b/examples/distributed_inference/README.md
deleted file mode 100644
index d4cf9508e1..0000000000
--- a/examples/distributed_inference/README.md
+++ /dev/null
@@ -1,50 +0,0 @@
-# Torch-TensorRT parallelism for distributed inference
-
-Examples in this folder demonstrates doing distributed inference on multiple devices with Torch-TensorRT backend.
-
-1. Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference)
-
-Using Accelerate users can achieve data parallel distributed inference with Torch-TensorRt backend. In this case, the entire model
-will be loaded onto each GPU and different chunks of batch input is processed on each device.
-
-See the examples started with `data_parallel` for more details.
-
-2. Tensor parallel distributed inference
-
-Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.
-
-torchrun --nproc_per_node=2 tensor_parallel_llama2.py
-
-3. Tensor parallel distributed inference using nccl ops plugin
-
-    apt install libmpich-dev
-
-    apt install libopenmpi-dev
-
-    #For python3.10
-
-    pip install tensorrt-llm
-
-    For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so. Please set that in the environment variable export TRTLLM_PLUGINS_PATH={lib_path}. For example, we have already set the variable in initialize_distributed_env(). You can replace this with your TRTLLM_PLUGINS_PATH and unset it there
-
-    #then pip install the tensorrt and torch version compatible with installed torchTRT
-
-    mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
-
-    #For other python
-
-4. Tensor parallel distributed llama3 inference using nccl ops plugin
-
-    apt install libmpich-dev
-
-    apt install libopenmpi-dev
-
-#For python3.10
-
-    pip install tensorrt-llm
-
-    For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so
-
-    #then pip install the tensorrt and torch version compatible with installed torchTRT
-
-    mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
diff --git a/examples/distributed_inference/README.rst b/examples/distributed_inference/README.rst
new file mode 100644
index 0000000000..5f9e2ec99a
--- /dev/null
+++ b/examples/distributed_inference/README.rst
@@ -0,0 +1,83 @@
+.. _tensor_parallel_llama3:
+
+Torch-TensorRT Parallelism for Distributed Inference
+====================================================
+
+Examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend.
+
+Data Parallel Distributed Inference based on `Accelerate <https://huggingface.co/docs/accelerate/usage_guides/distributed_inference>`_
+---------------------------------------------------------------------------------------------------------------------------------------
+
+Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend.
+In this case, the entire model will be loaded onto each GPU, and different chunks of batch input are processed on each device.
+
+See the following examples for more details:
+
+- `data_parallel_gpt2.py `_
+- `data_parallel_stable_diffusion.py `_
+
+Tensor Parallel Distributed Inference
+--------------------------------------
+
+Here, we use `torch.distributed` as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.
+
+.. code-block:: bash
+
+    torchrun --nproc_per_node=2 tensor_parallel_llama2.py
+
+Tensor Parallel Distributed Inference on a Simple Model using NCCL Ops Plugin
+------------------------------------------------------------------------------
+
+We use `torch.distributed `_ to shard the model with tensor parallelism.
+The distributed operations (`all_gather` and `all_reduce`) are then expressed as TensorRT-LLM plugins to avoid graph breaks during Torch-TensorRT compilation.
+The `converters for these operators `_ are already available in Torch-TensorRT.
+The functional implementation of these ops is imported from the `tensorrt_llm` package (specifically, `libnvinfer_plugin_tensorrt_llm.so` is required).
+
+We have two options:
+
+Option 1: Install TensorRT-LLM
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Follow the instructions to `install TensorRT-LLM `_.
+Please note that before installing TensorRT-LLM, you need to:
+
+1. apt install libmpich-dev
+2. apt install libopenmpi-dev
+
+If the default installation fails due to issues like library version mismatches or Python incompatibility, consider using Option 2.
+After a successful installation, verify that the following import runs without errors:
+
+.. code-block:: python
+
+    import torch_tensorrt
+
+The import might fail if `tensorrt_llm` overrides the `torch_tensorrt` dependencies.
+Option 2 is preferable if you do not wish to install `tensorrt_llm` and its dependencies.
+
+Option 2: Link the TensorRT-LLM Plugin Library Directly
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Alternatively, you can load `libnvinfer_plugin_tensorrt_llm.so` manually:
+
+1. Download the `tensorrt_llm-0.16.0 `_ wheel file from NVIDIA's Python index.
+2. Extract the wheel file to a directory and locate `libnvinfer_plugin_tensorrt_llm.so` under the `tensorrt_llm/libs` directory.
+3. Set the environment variable `TRTLLM_PLUGINS_PATH` to the extracted path at the `initialize_distributed_env() `_ call (see the appendix at the end of this README for a short sketch).
+
+After configuring TensorRT-LLM or the TensorRT-LLM plugin library path, run the following command to illustrate tensor parallelism of a simple model and compilation with Torch-TensorRT:
+
+.. code-block:: bash
+
+    mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
+
+We also provide a tensor parallelism compilation example on a more advanced model, `Llama-3`. Run the following command:
+
+.. code-block:: bash
+
+    mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
+
+Tutorials
+-----------------------------------------
+* :ref:`tensor_parallel_llama3`: Illustration of distributed inference on multiple devices with the Torch-TensorRT backend.
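+
+Appendix: Setting `TRTLLM_PLUGINS_PATH` Manually
+--------------------------------------------------
+
+The snippet below is a minimal sketch of the manual plugin-path setup from Option 2. The extraction directory shown is a placeholder, and `initialize_distributed_env()` currently sets a default value for `TRTLLM_PLUGINS_PATH` itself, so check that function to see how the variable is consumed and adjust the sketch to your environment:
+
+.. code-block:: python
+
+    import os
+
+    # Placeholder path: point this at the plugin library extracted from the tensorrt_llm wheel.
+    plugin_lib = "/opt/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+
+    # Fail early if the library is not where we expect it.
+    assert os.path.isfile(plugin_lib), f"TensorRT-LLM plugin not found at {plugin_lib}"
+
+    # Expose the path to the examples. Note that initialize_distributed_env() assigns its own
+    # default for this variable, so adjust or remove that default if you set the path here.
+    os.environ["TRTLLM_PLUGINS_PATH"] = plugin_lib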
\ No newline at end of file
diff --git a/examples/distributed_inference/llama3_model.py b/examples/distributed_inference/llama3_model.py
index 9fa59b5c49..e5e8e0ca6c 100644
--- a/examples/distributed_inference/llama3_model.py
+++ b/examples/distributed_inference/llama3_model.py
@@ -1,3 +1,7 @@
+"""
+This file contains the Llama3 model used in the tensor parallel distributed inference examples.
+"""
+
 # Taken and modified pytorch lightening
 # https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
 
diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py
index 21e4cbc282..9a662e92f7 100644
--- a/examples/distributed_inference/tensor_parallel_initialize_dist.py
+++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py
@@ -1,3 +1,9 @@
+"""
+This script contains utility functions for tensor parallelism
+with Torch-TensorRT. It sets up the distributed environment and the
+communication needed to partition the model across multiple GPUs.
+"""
+
 import logging
 import os
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py
index 998c378be2..9ed985b702 100644
--- a/examples/distributed_inference/tensor_parallel_llama3.py
+++ b/examples/distributed_inference/tensor_parallel_llama3.py
@@ -1,10 +1,25 @@
-# Taken and modified pytorch lightening
-# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
+"""
+.. _tensor_parallel_llama3:
+
+Torch distributed example for the Llama3-7B model
+======================================================
+
+As model sizes grow, models with billions of parameters are trained across many GPUs, and regular data parallel training is no longer sufficient. In this example, we illustrate inference of the Llama3-7B model with the Torch-TensorRT backend, split across multiple GPUs using a form of model parallelism called tensor parallelism. We make use of the PyTorch Distributed Tensor Parallel APIs. Please refer to these tutorials: https://pytorch.org/tutorials/intermediate/TP_tutorial.html and https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning?section=featured
+"""
+
+# %%
+# Imports and Model Definition
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
 import logging
 import os
 import time
 
 import torch
+import torch_tensorrt
+
+# PyTorch Tensor Parallel APIs offer a set of module-level primitives (ParallelStyle) to configure the sharding of tensors in each layer of the model.
+# ParallelTransformer creates the parallelize_plan for the attention and feed-forward layers of the model.
 from llama3_model import ModelArgs, ParallelTransformer
 from tensor_parallel_initialize_dist import initialize_distributed_env
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
@@ -14,11 +29,26 @@
     checkpoint_wrapper,
 )
 
+# %%
+# Initialize the distributed environment
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# The following steps are performed:
+#
+# - Initialize the communicators and the distributed environment
+# - Set the path for the `TRT-LLM` plugin `.so` file, which is required for the NCCL operations in the Torch-TensorRT backend
+# - Initialize the logger:
+#
+# - Example: In a 2-GPU setup, the log files will be:
+#   - `./tensor_parallel_llama3_0.log`
+#   - `./tensor_parallel_llama3_1.log`
+#
 device_mesh, _world_size, _rank, logger = initialize_distributed_env(
     "./tensor_parallel_llama3"
 )
-# Import should be after initialization of the TRT-LLM plugin .so path
-import tensorrt_llm
+
+# %%
+# Model initialization with the torch.distributed parallel plan
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
@@ -36,7 +66,59 @@
 )
 
 with torch.no_grad():
+    # The parallelization plan created by ParallelTransformer is:
+    # plan = {
+    #     "attention": PrepareModuleInput(
+    #         input_layouts=(Shard(1), None),
+    #         desired_input_layouts=(Replicate(), None),
+    #     ),
+    #     "attention.wq": ColwiseParallel(),
+    #     "attention.wk": ColwiseParallel(),
+    #     "attention.wv": ColwiseParallel(),
+    #     "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
+    #     "attention_norm": SequenceParallel(),
+    #     "feed_forward": PrepareModuleInput(
+    #         input_layouts=(Shard(1),),
+    #         desired_input_layouts=(Replicate(),),
+    #     ),
+    #     "feed_forward.w1": ColwiseParallel(),
+    #     "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+    #     "feed_forward.w3": ColwiseParallel(),
+    #     "ffn_norm": SequenceParallel(),
+    # }
+
     model = ParallelTransformer(model_args, device_mesh)
+
+    # %%
+    # Model inference with the Torch-TensorRT backend
+    # ------------------------------------------------
+    # When we compile the distributed model with the **Torch-TensorRT** backend, PyTorch's distributed libraries:
+    #
+    # - Create the **sharded model** across multiple GPUs.
+    # - Use **communicator operations** to ensure proper communication.
+    #
+    # The following components manage different aspects of parallelism:
+    #
+    # - **`ColwiseParallel`** and **`RowwiseParallel`**:
+    #   - Shard the attention layers in a **column-wise** or **row-wise** fashion.
+    #
+    # - **`SequenceParallel`**:
+    #   - Performs **sharded computations** of the normalization layer.
+    #
+    # - **`PrepareModuleInput`**:
+    #   - Configures the model input with the proper **communication operations**.
+    #
+    # **NCCL Operations in TensorRT-LLM:**
+    #
+    # - The **TensorRT-LLM NCCL plugins** handle the distributed backend NCCL operations, preventing **graph breaks**.
+    # - Depending on the **DTensor sharding layout**, the proper **communication operations** are required to transform the DTensor layout.
+    #
+    # **Common NCCL Operations Used:**
+    #
+    # - `allreduce`
+    # - `allgather`
+    # - `reduce_scatter`
+    #
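+    # The torch.compile call itself lives in the unchanged portion of this file and is not
+    # visible in this diff. As an illustration only (the options shown here are assumptions,
+    # not necessarily the ones used below), compiling a sharded module with the Torch-TensorRT
+    # backend follows the usual torch.compile pattern:
+    #
+    # .. code-block:: python
+    #
+    #     compiled_model = torch.compile(
+    #         model,
+    #         backend="torch_tensorrt",
+    #         options={"use_python_runtime": True},
+    #         dynamic=False,
+    #     )
+    #     output = compiled_model(inp)
+    #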
     torch.manual_seed(0)
     inp = torch.randint(32000, (8, 256), device="cuda")
     python_result = model(inp)
@@ -62,9 +144,11 @@
         output = model(inp)
         end = time.time()
         if i == 0:
+            # Logging the compilation time
             logger.info(f"Compilation time is {end-start}")
             assert (
                 python_result - output
             ).std() < 0.01, "Compilation result is not correct."
         elif _rank == 0:
+            # Logging the inference time
             logger.info(f"Inference time is {end-start}")
diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py
index 837648fdb4..ade8f0607d 100755
--- a/examples/distributed_inference/tensor_parallel_simple_example.py
+++ b/examples/distributed_inference/tensor_parallel_simple_example.py
@@ -1,3 +1,7 @@
+"""
+This file contains a simple model example used to demonstrate tensor parallel distributed inference.
+"""
+
 import time
 
 import tensorrt as trt
@@ -15,7 +19,6 @@
 device_mesh, _world_size, _rank, logger = initialize_distributed_env(
     "./tensor_parallel_simple_example"
 )
-import tensorrt_llm
 
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py