Refactor ray pipeline
lihuahua123 committed Dec 18, 2024
1 parent 9f0a196 commit 5968e5c
Showing 9 changed files with 130 additions and 90 deletions.
23 changes: 0 additions & 23 deletions examples/ray_example.py

This file was deleted.

8 changes: 4 additions & 4 deletions examples/run.sh
@@ -3,7 +3,7 @@ set -x
export PYTHONPATH=$PWD:$PYTHONPATH

# Select the model type
-export MODEL_TYPE="Flux"
+export MODEL_TYPE="Sd3"
# Configuration for different model types
# script, model_id, inference_step
declare -A MODEL_CONFIGS=(
@@ -29,8 +29,8 @@ TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"


# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
-N_GPUS=8
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 2"
+N_GPUS=2
+PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"

# CFG_ARGS="--use_cfg_parallel"

@@ -49,7 +49,7 @@ PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 2"
# Use this flag to quantize the T5 text encoder, which could reduce the memory usage and have no effect on the result quality.
# QUANTIZE_FLAG="--use_fp8_t5_encoder"

-# export CUDA_VISIBLE_DEVICES=4,5,6,7
+export CUDA_VISIBLE_DEVICES=0,1

torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
--model $MODEL_ID \
2 changes: 1 addition & 1 deletion examples/sd3_example.py
@@ -75,7 +75,7 @@ def main():

    if get_world_group().rank == get_world_group().world_size - 1:
        print(
-            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, peak memory: {peak_memory/1e9:.2f} GB"
+            f"epoch time: {elapsed_time} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, peak memory: {peak_memory/1e9:.2f} GB"
        )

    get_runtime_state().destory_distributed_env()
11 changes: 6 additions & 5 deletions examples/sd_run.sh → tests/executor/sd_run.sh
@@ -9,7 +9,7 @@ export MODEL_TYPE="Sd3"
declare -A MODEL_CONFIGS=(
["Pixart-alpha"]="pixartalpha_example.py /cfs/dit/PixArt-XL-2-1024-MS 20"
["Pixart-sigma"]="pixartsigma_example.py /cfs/dit/PixArt-Sigma-XL-2-2K-MS 20"
["Sd3"]="ray_example.py /data/stable-diffusion-3-medium-diffusers 20"
["Sd3"]="./test_ray.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev 28"
["HunyuanDiT"]="hunyuandit_example.py /cfs/dit/HunyuanDiT-v1.2-Diffusers 50"
)
@@ -29,7 +29,7 @@ TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"


# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
-N_GPUS=1
+N_GPUS=2
PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"

# CFG_ARGS="--use_cfg_parallel"
@@ -49,9 +49,9 @@ PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
# Use this flag to quantize the T5 text encoder, which could reduce the memory usage and have no effect on the result quality.
# QUANTIZE_FLAG="--use_fp8_t5_encoder"

-# export CUDA_VISIBLE_DEVICES=4,5,6,7
+export CUDA_VISIBLE_DEVICES=0,1

-python ./examples/$SCRIPT \
+python ./tests/executor/$SCRIPT \
--model $MODEL_ID \
$PARALLEL_ARGS \
$TASK_ARGS \
@@ -60,7 +60,8 @@ $OUTPUT_ARGS \
--num_inference_steps $INFERENCE_STEP \
--warmup_steps 1 \
--prompt "brown dog laying on the ground with a metal bowl in front of him." \
---world_size 2 \
+--use_ray \
+--ray_world_size $N_GPUS \
$CFG_ARGS \
$PARALLLEL_VAE \
$COMPILE_FLAG \
68 changes: 68 additions & 0 deletions tests/executor/test_ray.py
@@ -0,0 +1,68 @@
import time
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    is_dp_last_group,
    get_data_parallel_rank,
    get_runtime_state,
)
from xfuser.core.distributed.parallel_state import get_data_parallel_world_size
from xfuser.executor.gpu_executor import RayDiffusionPipeline
from xfuser.worker.worker import (
    xFuserPixArtAlphaPipeline,
    xFuserPixArtSigmaPipeline,
    xFuserStableDiffusion3Pipeline,
    xFuserHunyuanDiTPipeline,
    xFuserFluxPipeline,
)


def main():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    pipeline_map = {
        "PixArt-XL-2-1024-MS": xFuserPixArtAlphaPipeline,
        "PixArt-Sigma-XL-2-2K-MS": xFuserPixArtSigmaPipeline,
        "stable-diffusion-3-medium-diffusers": xFuserStableDiffusion3Pipeline,
        "HunyuanDiT-v1.2-Diffusers": xFuserHunyuanDiTPipeline,
        "FLUX.1-schnell": xFuserFluxPipeline,
    }
    model_name = engine_config.model_config.model.split("/")[-1]
    PipelineClass = pipeline_map.get(model_name)
    if PipelineClass is None:
        raise NotImplementedError(f"{model_name} is currently not supported!")
    text_encoder_3 = T5EncoderModel.from_pretrained(
        engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16
    )
    if args.use_fp8_t5_encoder:
        from optimum.quanto import freeze, qfloat8, quantize
        quantize(text_encoder_3, weights=qfloat8)
        freeze(text_encoder_3)

    pipe = RayDiffusionPipeline.from_pretrained(
        PipelineClass=PipelineClass,
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.float16,
        text_encoder_3=text_encoder_3,
    )
    pipe.prepare_run(input_config)

    start_time = time.time()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time

    print(f"elapsed time: {elapsed_time}")


if __name__ == "__main__":
    main()
28 changes: 18 additions & 10 deletions xfuser/config/args.py
@@ -79,7 +79,9 @@ class xFuserArgs:
    # tensor parallel
    tensor_parallel_degree: int = 1
    split_scheme: Optional[str] = "row"
-    world_size: int = 1
+    # ray arguments
+    use_ray: bool = False
+    ray_world_size: int = 1
    # pipefusion parallel
    pipefusion_parallel_degree: int = 1
    num_pipeline_patch: Optional[int] = None
@@ -152,8 +154,13 @@ def add_cli_args(parser: FlexibleArgumentParser):

        # Parallel arguments
        parallel_group = parser.add_argument_group("Parallel Processing Options")
+        runtime_group.add_argument(
+            "--use_ray",
+            action="store_true",
+            help="Enable Ray to run inference across multiple cards",
+        )
        parallel_group.add_argument(
-            "--world_size",
+            "--ray_world_size",
            type=int,
            default=1,
            help="World size.",
@@ -329,14 +336,15 @@ def from_cli_args(cls, args: argparse.Namespace):
    def create_config(
        self,
    ) -> Tuple[EngineConfig, InputConfig]:
-        # if not torch.distributed.is_initialized():
-        #     logger.warning(
-        #         "Distributed environment is not initialized. " "Initializing..."
-        #     )
-        #     init_distributed_environment(
-        #         rank=self.rank,
-        #         world_size=self.world_size,
-        #     )
+        if not self.use_ray and not torch.distributed.is_initialized():
+            logger.warning(
+                "Distributed environment is not initialized. " "Initializing..."
+            )
+            init_distributed_environment()
+        if self.use_ray:
+            self.world_size = self.ray_world_size
+        else:
+            self.world_size = torch.distributed.get_world_size()

        model_config = ModelConfig(
            model=self.model,
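Aside: the effect of the new branch in create_config() is that, under Ray, the driver takes the world size from the CLI flag rather than from an initialized torch.distributed process group, which the driver no longer joins. A condensed, hypothetical restatement of just that branch (the real method also builds ModelConfig and the rest of the engine config):

```python
# Hypothetical restatement of the world-size resolution introduced above;
# the real logic lives inside xFuserArgs.create_config().
import torch.distributed

def resolve_world_size(use_ray: bool, ray_world_size: int) -> int:
    if use_ray:
        # Ray path: the driver is not itself a torch.distributed rank,
        # so the size must come from the --ray_world_size flag.
        return ray_world_size
    # torchrun path: the process group is already initialized.
    return torch.distributed.get_world_size()
```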
20 changes: 12 additions & 8 deletions xfuser/executor/gpu_executor.py
@@ -8,7 +8,7 @@
from xfuser.logger import init_logger
from xfuser.worker.worker_wrappers import RayWorkerWrapper
from xfuser.config.config import InputConfig, EngineConfig
-
+from xfuser.model_executor.pipelines.base_pipeline import xFuserPipelineBaseWrapper
logger = init_logger(__name__)


@@ -17,10 +17,11 @@ def _init_executor(self):
        pass


-class RayGPUExecutor(GPUExecutor):
+class RayDiffusionPipeline(GPUExecutor):
    workers = []
    def _init_executor(self):
        self._init_ray_workers()
+        self._run_workers("init_worker_distributed_environment")

    def _init_ray_workers(self):
        placement_group = initialize_ray_cluster(self.engine_config.parallel_config)
@@ -96,11 +97,14 @@ def _run_workers(

        return ray_worker_outputs

-    def init_distributed_environment(self):
-        self._run_workers("init_worker_distributed_environment")
+    @classmethod
+    def from_pretrained(cls, PipelineClass, pretrained_model_name_or_path: str, engine_config: EngineConfig, **kwargs):
+        pipeline = cls(engine_config)
+        pipeline._run_workers("from_pretrained", PipelineClass, pretrained_model_name_or_path, engine_config, **kwargs)
+        return pipeline

-    def load_model(self, engine_config: EngineConfig):
-        self._run_workers("load_model", engine_config)
+    def prepare_run(self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1):
+        self._run_workers("prepare_run", input_config, steps, sync_steps)

-    def execute(self, input_config: InputConfig):
-        self._run_workers("execute", input_config)
+    def __call__(self, **kwargs):
+        return self._run_workers("execute", **kwargs)
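For orientation, the driver-side flow after this refactor, pieced together from tests/executor/test_ray.py above, looks roughly like the sketch below. It is not an additional file in the commit; every name in it comes from the diffs in this page.

```python
# Sketch of the refactored driver flow (names as in tests/executor/test_ray.py).
import torch
from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.executor.gpu_executor import RayDiffusionPipeline
from xfuser.worker.worker import xFuserStableDiffusion3Pipeline

parser = FlexibleArgumentParser(description="xFuser Arguments")
args = xFuserArgs.add_cli_args(parser).parse_args()
engine_config, input_config = xFuserArgs.from_cli_args(args).create_config()

# from_pretrained() spins up the Ray workers and loads the pipeline on each;
# extra kwargs (e.g. torch_dtype) are forwarded to PipelineClass.from_pretrained.
pipe = RayDiffusionPipeline.from_pretrained(
    PipelineClass=xFuserStableDiffusion3Pipeline,
    pretrained_model_name_or_path=engine_config.model_config.model,
    engine_config=engine_config,
    torch_dtype=torch.float16,
)
pipe.prepare_run(input_config)  # warm-up steps on every worker
output = pipe(                  # __call__ -> _run_workers("execute", ...)
    height=input_config.height,
    width=input_config.width,
    prompt=input_config.prompt,
    num_inference_steps=input_config.num_inference_steps,
)
```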
4 changes: 4 additions & 0 deletions xfuser/executor/ray_utils.py
@@ -1,3 +1,7 @@
+# Copyright 2024 The xDiT team.
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/main/vllm/executor/ray_utils.py
+# Copyright (c) 2022, vLLM team. All rights reserved.
import time
import socket
from typing import Dict, List, Optional
56 changes: 17 additions & 39 deletions xfuser/worker/worker.py
@@ -43,8 +43,8 @@ def execute(
        raise NotImplementedError

    @abstractmethod
-    def load_model(
-        self, engine_config: EngineConfig, *args, **kwargs
+    def from_pretrained(
+        self, PipelineClass, engine_config: EngineConfig, **kwargs,
    ):
        raise NotImplementedError

@@ -73,46 +73,23 @@ def init_worker_distributed_environment(self):
            world_size=self.parallel_config.world_size,
        )

-    def load_model(self, engine_config: EngineConfig):
+    def from_pretrained(self, PipelineClass, pretrained_model_name_or_path: str, engine_config: EngineConfig, **kwargs):
        local_rank = get_world_group().local_rank
-        pipeline_map = {
-            "PixArt-XL-2-1024-MS": xFuserPixArtAlphaPipeline,
-            "PixArt-Sigma-XL-2-2K-MS": xFuserPixArtSigmaPipeline,
-            "stable-diffusion-3-medium-diffusers": xFuserStableDiffusion3Pipeline,
-            "HunyuanDiT-v1.2-Diffusers": xFuserHunyuanDiTPipeline,
-            "FLUX.1-schnell": xFuserFluxPipeline,
-        }
-        model_name = engine_config.model_config.model.split("/")[-1]
-        PipelineClass = pipeline_map.get(model_name)
-        if PipelineClass is None:
-            raise NotImplementedError(f"{model_name} is currently not supported!")
-        if model_name == "stable-diffusion-3-medium-diffusers":  # FIXME: hard code
-            text_encoder = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)
-            pipe = PipelineClass.from_pretrained(
-                pretrained_model_name_or_path=engine_config.model_config.model,
-                engine_config=engine_config,
-                torch_dtype=torch.float16,
-                text_encoder_3=text_encoder,  # FIXME: hard code
-            ).to(f"cuda:{local_rank}")
-        else:
-            pipe = PipelineClass.from_pretrained(
-                pretrained_model_name_or_path=engine_config.model_config.model,
-                engine_config=engine_config,
-                torch_dtype=torch.float16,
-            ).to(f"cuda:{local_rank}")
+        pipe = PipelineClass.from_pretrained(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            engine_config=engine_config,
+            **kwargs,
+        ).to(f"cuda:{local_rank}")
        self.pipe = pipe
        return

+    def prepare_run(self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1):
+        self.pipe.prepare_run(input_config, steps, sync_steps)

-    def execute(self, input_config: InputConfig):
-        self.pipe.prepare_run(input_config)
-        output = self.pipe(
-            height=input_config.height,
-            width=input_config.width,
-            prompt=input_config.prompt,
-            num_inference_steps=input_config.num_inference_steps,
-            output_type=input_config.output_type,
-            generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
-        )
+    def execute(self, **kwargs):
+        time_start = time.time()
+        output = self.pipe(**kwargs)
+        time_end = time.time()
        if self.pipe.is_dp_last_group():
            if not os.path.exists("results"):
                os.mkdir("results")
@@ -123,4 +100,5 @@ def execute(self, input_config: InputConfig):
                print(
                    f"image {i} saved to /data/results/stable_diffusion_3_result_{i}.png"
                )
-        return
+        print(f"time cost: {time_end - time_start}")
+        return output
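Taken with the executor change above, the refactor settles on a simple fan-out contract: RayDiffusionPipeline.__call__ forwards generation kwargs to every worker's execute(), which passes them untouched to the underlying pipeline. A stand-in sketch of that contract follows; the class names here are hypothetical, not from the commit.

```python
# Stand-in sketch of the executor -> worker contract (hypothetical classes;
# the real ones live in gpu_executor.py and worker.py above).
from typing import Any, Callable, List

class _Worker:
    def __init__(self, pipe: Callable[..., Any]):
        self.pipe = pipe  # stands in for the per-GPU xFuser pipeline wrapper

    def execute(self, **kwargs) -> Any:
        # Generation arguments arrive exactly as the driver passed them.
        return self.pipe(**kwargs)

class _Executor:
    def __init__(self, workers: List[_Worker]):
        self.workers = workers

    def __call__(self, **kwargs) -> List[Any]:
        # Fan the same kwargs out to every worker; collect per-worker outputs.
        return [w.execute(**kwargs) for w in self.workers]
```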
