
Commit 5968e5c

Refactor ray pipeline: rename RayGPUExecutor to RayDiffusionPipeline and give it a diffusers-style from_pretrained / prepare_run / __call__ interface, move pipeline-class selection and text-encoder loading from the worker into the caller, and replace --world_size with --use_ray / --ray_world_size.
1 parent 9f0a196 commit 5968e5c

File tree

9 files changed (+130, −90)


examples/ray_example.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

examples/run.sh

Lines changed: 4 additions & 4 deletions
@@ -3,7 +3,7 @@ set -x
 export PYTHONPATH=$PWD:$PYTHONPATH
 
 # Select the model type
-export MODEL_TYPE="Flux"
+export MODEL_TYPE="Sd3"
 # Configuration for different model types
 # script, model_id, inference_step
 declare -A MODEL_CONFIGS=(

@@ -29,8 +29,8 @@ TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"
 
 
 # On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
-N_GPUS=8
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 2"
+N_GPUS=2
+PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
 
 # CFG_ARGS="--use_cfg_parallel"
 

@@ -49,7 +49,7 @@ PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 2"
 # Use this flag to quantize the T5 text encoder, which could reduce the memory usage and have no effect on the result quality.
 # QUANTIZE_FLAG="--use_fp8_t5_encoder"
 
-# export CUDA_VISIBLE_DEVICES=4,5,6,7
+export CUDA_VISIBLE_DEVICES=0,1
 
 torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
 --model $MODEL_ID \
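A quick way to sanity-check a launch configuration like the one above: the product of the parallel degrees (doubled when --use_cfg_parallel splits the batch) must equal the GPU count that torchrun spawns, which is what the "On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2" comment expresses. A minimal Python check, as an illustration only, not code from the repo:

# Illustrative check: world size must equal the product of the parallel degrees.
n_gpus = 2
pipefusion_degree, ulysses_degree, ring_degree = 2, 1, 1
cfg_degree = 1  # would be 2 if CFG_ARGS="--use_cfg_parallel" were enabled
assert pipefusion_degree * ulysses_degree * ring_degree * cfg_degree == n_gpus, \
    "parallel degrees must multiply to N_GPUS"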

examples/sd3_example.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def main():
 
     if get_world_group().rank == get_world_group().world_size - 1:
         print(
-            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, peak memory: {peak_memory/1e9:.2f} GB"
+            f"epoch time: {elapsed_time} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, peak memory: {peak_memory/1e9:.2f} GB"
         )
 
     get_runtime_state().destory_distributed_env()
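The only change here drops the `:.2f` format spec from the epoch time, so the full float now prints instead of a two-decimal rounding. For reference:

elapsed_time = 12.3456789
print(f"epoch time: {elapsed_time:.2f} sec")  # old: epoch time: 12.35 sec
print(f"epoch time: {elapsed_time} sec")      # new: epoch time: 12.3456789 sec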

examples/sd_run.sh renamed to tests/executor/sd_run.sh

Lines changed: 6 additions & 5 deletions
@@ -9,7 +9,7 @@ export MODEL_TYPE="Sd3"
 declare -A MODEL_CONFIGS=(
     ["Pixart-alpha"]="pixartalpha_example.py /cfs/dit/PixArt-XL-2-1024-MS 20"
     ["Pixart-sigma"]="pixartsigma_example.py /cfs/dit/PixArt-Sigma-XL-2-2K-MS 20"
-    ["Sd3"]="ray_example.py /data/stable-diffusion-3-medium-diffusers 20"
+    ["Sd3"]="./test_ray.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
     ["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev 28"
     ["HunyuanDiT"]="hunyuandit_example.py /cfs/dit/HunyuanDiT-v1.2-Diffusers 50"
 )

@@ -29,7 +29,7 @@ TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"
 
 
 # On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
-N_GPUS=1
+N_GPUS=2
 PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
 
 # CFG_ARGS="--use_cfg_parallel"

@@ -49,9 +49,9 @@ PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
 # Use this flag to quantize the T5 text encoder, which could reduce the memory usage and have no effect on the result quality.
 # QUANTIZE_FLAG="--use_fp8_t5_encoder"
 
-# export CUDA_VISIBLE_DEVICES=4,5,6,7
+export CUDA_VISIBLE_DEVICES=0,1
 
-python ./examples/$SCRIPT \
+python ./tests/executor/$SCRIPT \
 --model $MODEL_ID \
 $PARALLEL_ARGS \
 $TASK_ARGS \

@@ -60,7 +60,8 @@ $OUTPUT_ARGS \
 --num_inference_steps $INFERENCE_STEP \
 --warmup_steps 1 \
 --prompt "brown dog laying on the ground with a metal bowl in front of him." \
---world_size 2 \
+--use_ray \
+--ray_world_size $N_GPUS \
 $CFG_ARGS \
 $PARALLLEL_VAE \
 $COMPILE_FLAG \
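Note the launcher change: the Ray path runs plain `python` rather than `torchrun`, because the executor spawns the worker processes itself and test_ray.py supplies the rendezvous variables that torchrun would otherwise inject. A minimal stand-in, with the values copied from the test file:

import os

# test_ray.py sets these itself; under torchrun they come from the launcher.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")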

tests/executor/test_ray.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+import time
+import os
+import torch
+import torch.distributed
+from transformers import T5EncoderModel
+from xfuser import xFuserStableDiffusion3Pipeline, xFuserArgs
+from xfuser.executor.gpu_executor import RayDiffusionPipeline
+from xfuser.config import FlexibleArgumentParser
+from xfuser.core.distributed import (
+    get_world_group,
+    is_dp_last_group,
+    get_data_parallel_rank,
+    get_runtime_state,
+)
+from xfuser.core.distributed.parallel_state import get_data_parallel_world_size
+from xfuser.executor.gpu_executor import RayDiffusionPipeline
+from xfuser.worker.worker import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline
+
+def main():
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    parser = FlexibleArgumentParser(description="xFuser Arguments")
+    args = xFuserArgs.add_cli_args(parser).parse_args()
+    engine_args = xFuserArgs.from_cli_args(args)
+    engine_config, input_config = engine_args.create_config()
+    pipeline_map = {
+        "PixArt-XL-2-1024-MS": xFuserPixArtAlphaPipeline,
+        "PixArt-Sigma-XL-2-2K-MS": xFuserPixArtSigmaPipeline,
+        "stable-diffusion-3-medium-diffusers": xFuserStableDiffusion3Pipeline,
+        "HunyuanDiT-v1.2-Diffusers": xFuserHunyuanDiTPipeline,
+        "FLUX.1-schnell": xFuserFluxPipeline,
+    }
+    model_name = engine_config.model_config.model.split("/")[-1]
+    PipelineClass = pipeline_map.get(model_name)
+    if PipelineClass is None:
+        raise NotImplementedError(f"{model_name} is currently not supported!")
+    text_encoder_3 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)
+    if args.use_fp8_t5_encoder:
+        from optimum.quanto import freeze, qfloat8, quantize
+        quantize(text_encoder_3, weights=qfloat8)
+        freeze(text_encoder_3)
+
+    pipe = RayDiffusionPipeline.from_pretrained(
+        PipelineClass=PipelineClass,
+        pretrained_model_name_or_path=engine_config.model_config.model,
+        engine_config=engine_config,
+        torch_dtype=torch.float16,
+        text_encoder_3=text_encoder_3,
+    )
+    pipe.prepare_run(input_config)
+
+    start_time = time.time()
+    output = pipe(
+        height=input_config.height,
+        width=input_config.width,
+        prompt=input_config.prompt,
+        num_inference_steps=input_config.num_inference_steps,
+        output_type=input_config.output_type,
+        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
+    )
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+
+    print(f"elapsed time:{elapsed_time}")
+
+
+if __name__ == "__main__":
+    main()
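One subtlety worth flagging: per the gpu_executor.py diff below, `pipe(...)` forwards to `_run_workers("execute", ...)`, so `output` here is a list with one entry per Ray worker rather than a single pipeline result. A hedged sketch of unpacking it; the `.images` attribute is assumed from diffusers-style pipeline outputs and is not shown in this diff:

# Assumption: each worker returns its pipeline output, and diffusers-style
# outputs expose `.images`; treat this as an illustration, not repo code.
final = next((o for o in reversed(output) if o is not None), None)
if final is not None:
    for i, image in enumerate(final.images):
        image.save(f"results/ray_sd3_result_{i}.png")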

xfuser/config/args.py

Lines changed: 18 additions & 10 deletions
@@ -79,7 +79,9 @@ class xFuserArgs:
     # tensor parallel
     tensor_parallel_degree: int = 1
     split_scheme: Optional[str] = "row"
-    world_size: int = 1
+    # ray arguments
+    use_ray: bool = False
+    ray_world_size: int = 1
     # pipefusion parallel
     pipefusion_parallel_degree: int = 1
     num_pipeline_patch: Optional[int] = None

@@ -152,8 +154,13 @@ def add_cli_args(parser: FlexibleArgumentParser):
 
         # Parallel arguments
         parallel_group = parser.add_argument_group("Parallel Processing Options")
+        runtime_group.add_argument(
+            "--use_ray",
+            action="store_true",
+            help="Enable ray to run inference in multi-card",
+        )
         parallel_group.add_argument(
-            "--world_size",
+            "--ray_world_size",
             type=int,
             default=1,
             help="World size.",

@@ -329,14 +336,15 @@ def from_cli_args(cls, args: argparse.Namespace):
     def create_config(
         self,
     ) -> Tuple[EngineConfig, InputConfig]:
-        # if not torch.distributed.is_initialized():
-        #     logger.warning(
-        #         "Distributed environment is not initialized. " "Initializing..."
-        #     )
-        #     init_distributed_environment(
-        #         rank=self.rank,
-        #         world_size=self.world_size,
-        #     )
+        if not self.use_ray and not torch.distributed.is_initialized():
+            logger.warning(
+                "Distributed environment is not initialized. " "Initializing..."
+            )
+            init_distributed_environment()
+        if self.use_ray:
+            self.world_size = self.ray_world_size
+        else:
+            self.world_size = torch.distributed.get_world_size()
 
         model_config = ModelConfig(
             model=self.model,
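The effect of the new flags on create_config() can be reproduced with a few lines of plain argparse (a stand-in for FlexibleArgumentParser; the snippet is an illustration, not library code):

import argparse

parser = argparse.ArgumentParser(description="stand-in for xFuserArgs")
parser.add_argument("--use_ray", action="store_true",
                    help="Enable ray to run inference in multi-card")
parser.add_argument("--ray_world_size", type=int, default=1, help="World size.")
args = parser.parse_args(["--use_ray", "--ray_world_size", "2"])

# Mirrors create_config(): under Ray the world size comes from the CLI flag;
# otherwise it is read from the already-initialized torch.distributed group.
world_size = args.ray_world_size if args.use_ray else 1  # torch.distributed.get_world_size()
print(world_size)  # 2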

xfuser/executor/gpu_executor.py

Lines changed: 12 additions & 8 deletions
@@ -8,7 +8,7 @@
 from xfuser.logger import init_logger
 from xfuser.worker.worker_wrappers import RayWorkerWrapper
 from xfuser.config.config import InputConfig, EngineConfig
-
+from xfuser.model_executor.pipelines.base_pipeline import xFuserPipelineBaseWrapper
 logger = init_logger(__name__)
 
 

@@ -17,10 +17,11 @@ def _init_executor(self):
         pass
 
 
-class RayGPUExecutor(GPUExecutor):
+class RayDiffusionPipeline(GPUExecutor):
     workers = []
     def _init_executor(self):
         self._init_ray_workers()
+        self._run_workers("init_worker_distributed_environment")
 
     def _init_ray_workers(self):
         placement_group = initialize_ray_cluster(self.engine_config.parallel_config)

@@ -96,11 +97,14 @@ def _run_workers(
 
         return ray_worker_outputs
 
-    def init_distributed_environment(self):
-        self._run_workers("init_worker_distributed_environment")
+    @classmethod
+    def from_pretrained(cls, PipelineClass, pretrained_model_name_or_path: str, engine_config: EngineConfig, **kwargs):
+        pipeline = cls(engine_config)
+        pipeline._run_workers("from_pretrained", PipelineClass, pretrained_model_name_or_path, engine_config, **kwargs)
+        return pipeline
 
-    def load_model(self, engine_config: EngineConfig):
-        self._run_workers("load_model", engine_config)
+    def prepare_run(self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1):
+        self._run_workers("prepare_run", input_config, steps, sync_steps)
 
-    def execute(self, input_config: InputConfig):
-        self._run_workers("execute", input_config)
+    def __call__(self, **kwargs):
+        return self._run_workers("execute", **kwargs)
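Every public method on RayDiffusionPipeline is now a thin fan-out over `_run_workers`, whose body this commit does not change and the hunk does not show. A self-contained sketch of the pattern, with a toy actor standing in for RayWorkerWrapper (all names other than `_run_workers` are illustrative assumptions):

import ray

@ray.remote
class ToyWorker:  # stands in for RayWorkerWrapper
    def execute(self, **kwargs):
        return kwargs  # a real worker would run the diffusion pipeline here

class ToyExecutor:
    def __init__(self, n_workers: int):
        self.workers = [ToyWorker.remote() for _ in range(n_workers)]

    def _run_workers(self, method: str, *args, **kwargs):
        # Dispatch the same method to every worker, then block on all results,
        # which is why __call__ above returns one output per worker.
        refs = [getattr(w, method).remote(*args, **kwargs) for w in self.workers]
        return ray.get(refs)

if __name__ == "__main__":
    ray.init()
    print(ToyExecutor(2)._run_workers("execute", prompt="a dog"))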

xfuser/executor/ray_utils.py

Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,7 @@
+# Copyright 2024 The xDiT team.
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/main/vllm/executor/ray_utils.py
+# Copyright (c) 2022, vLLM team. All rights reserved.
 import time
 import socket
 from typing import Dict, List, Optional

xfuser/worker/worker.py

Lines changed: 17 additions & 39 deletions
@@ -43,8 +43,8 @@ def execute(
         raise NotImplementedError
 
     @abstractmethod
-    def load_model(
-        self, engine_config: EngineConfig, *args, **kwargs
+    def from_pretrained(
+        self, PipelineClass, engine_config: EngineConfig, **kwargs,
     ):
         raise NotImplementedError
 

@@ -73,46 +73,23 @@ def init_worker_distributed_environment(self):
             world_size=self.parallel_config.world_size,
         )
 
-    def load_model(self, engine_config: EngineConfig):
+    def from_pretrained(self, PipelineClass, pretrained_model_name_or_path: str, engine_config: EngineConfig, **kwargs):
         local_rank = get_world_group().local_rank
-        pipeline_map = {
-            "PixArt-XL-2-1024-MS": xFuserPixArtAlphaPipeline,
-            "PixArt-Sigma-XL-2-2K-MS": xFuserPixArtSigmaPipeline,
-            "stable-diffusion-3-medium-diffusers": xFuserStableDiffusion3Pipeline,
-            "HunyuanDiT-v1.2-Diffusers": xFuserHunyuanDiTPipeline,
-            "FLUX.1-schnell": xFuserFluxPipeline,
-        }
-        model_name = engine_config.model_config.model.split("/")[-1]
-        PipelineClass = pipeline_map.get(model_name)
-        if PipelineClass is None:
-            raise NotImplementedError(f"{model_name} is currently not supported!")
-        if model_name == "stable-diffusion-3-medium-diffusers":  # FIXME: hard code
-            text_encoder = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)
-            pipe = PipelineClass.from_pretrained(
-                pretrained_model_name_or_path=engine_config.model_config.model,
-                engine_config=engine_config,
-                torch_dtype=torch.float16,
-                text_encoder_3=text_encoder,  # FIXME: hard code
-            ).to(f"cuda:{local_rank}")
-        else:
-            pipe = PipelineClass.from_pretrained(
-                pretrained_model_name_or_path=engine_config.model_config.model,
-                engine_config=engine_config,
-                torch_dtype=torch.float16,
-            ).to(f"cuda:{local_rank}")
+        pipe = PipelineClass.from_pretrained(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            engine_config=engine_config,
+            **kwargs
+        ).to(f"cuda:{local_rank}")
         self.pipe = pipe
         return
+
+    def prepare_run(self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1):
+        self.pipe.prepare_run(input_config, steps, sync_steps)
 
-    def execute(self, input_config: InputConfig):
-        self.pipe.prepare_run(input_config)
-        output = self.pipe(
-            height=input_config.height,
-            width=input_config.width,
-            prompt=input_config.prompt,
-            num_inference_steps=input_config.num_inference_steps,
-            output_type=input_config.output_type,
-            generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
-        )
+    def execute(self, **kwargs):
+        time_start = time.time()
+        output = self.pipe(**kwargs)
+        time_end = time.time()
         if self.pipe.is_dp_last_group():
             if not os.path.exists("results"):
                 os.mkdir("results")

@@ -123,4 +100,5 @@ def execute(self, input_config: InputConfig):
             print(
                 f"image {i} saved to /data/results/stable_diffusion_3_result_{i}.png"
             )
-        return
+        print(f"time cost: {time_end - time_start}")
+        return output
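Taken together, the refactor leaves each worker with a three-step lifecycle driven from RayDiffusionPipeline: from_pretrained builds the pipeline on cuda:{local_rank} with caller-supplied kwargs instead of the old hard-coded model map, prepare_run does the warm-up once, and execute serves timed inference per request. A hedged end-to-end sketch; the CLI values mirror sd_run.sh and the wiring mirrors test_ray.py, but the snippet is an illustration rather than repo code:

import torch
from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.executor.gpu_executor import RayDiffusionPipeline
from xfuser.worker.worker import xFuserStableDiffusion3Pipeline

parser = FlexibleArgumentParser(description="xFuser Arguments")
args = xFuserArgs.add_cli_args(parser).parse_args(
    ["--model", "/cfs/dit/stable-diffusion-3-medium-diffusers",
     "--use_ray", "--ray_world_size", "2"]
)
engine_config, input_config = xFuserArgs.from_cli_args(args).create_config()

pipe = RayDiffusionPipeline.from_pretrained(  # Worker.from_pretrained on each rank
    PipelineClass=xFuserStableDiffusion3Pipeline,
    pretrained_model_name_or_path=engine_config.model_config.model,
    engine_config=engine_config,
    torch_dtype=torch.float16,
)
pipe.prepare_run(input_config)                # Worker.prepare_run: warm-up, once
outputs = pipe(                               # Worker.execute: one result per worker
    prompt=input_config.prompt,
    height=input_config.height,
    width=input_config.width,
    num_inference_steps=input_config.num_inference_steps,
)

Splitting prepare_run out of execute means the warm-up cost is paid once, after which the same pipeline object can serve repeated __call__ requests.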
