Skip to content

Commit 738000f

Browse files
authored
[Flux] Enable classifier-free guidance for training and inference (#1099)
## Context 1. Refactor dataloader to support classifier-free guidance: Replace some of the strings with empty strings. 2. Refactor test_generate_image.py to run a forward pass to generate image with classifer-free guidance. ``` # sampling pseudocode given prompt (e.g. "a picture of a cat") : 1. sample noise of shape [B, 64, height // 16, width // 16] 2. duplicate noise i.e. x = torch.cat((noise, noise)) 3. patchify x 4. choose timesteps, e.g., t = torch.linspace(1.0, 0.0, 30) 5. compute t5 / clip embeddings for prompt and empty string "" (you could also cache the empty string embeddings if you want ) 6. for loop over timesteps for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): evaluate pred = model(x, t, t5_embeddings, clip_embeddings) classifier-free guidance pred_u, pred_c = pred.chunk(2) -> pred = pred_u + cfg_val * (pred_c - pred_u) advanced x = x + (t_prev - t_curr) * pred 7. image = autoencoder(x) ``` 3. Remove guided-distillation related params. 4. Refactor inference related code to `sampling.py`. Adding periodical eval in `train.py` 5. Seed management for Flux model 6. Adding basic eval function (generate image every few steps) during training ## Seed management during Flux training and eval With in 1 single training run: 1. During training, each `dp_shard` rank should have different random seed. Because we are randomly dropout the prompt information to train a unconditional + conditional model. We don't want to each rank drop samples from same position in dataloader. 2. During evaluation/sampling, since we are generating from random noise, so we want to use the exact same seed for each evaluation step. - Run image sampling on Rank 0 only, Use the local random seed on rank0 to generate Between 2 different training runs: - We want to reproduce the seed for the first time to strictly reproduce the 1st run results. ## Next Step: - Adding more evaluation functionality: Calculate loss fix `t` value .
1 parent 96c11f9 commit 738000f

File tree

18 files changed

+548
-352
lines changed

18 files changed

+548
-352
lines changed

torchtitan/distributed/utils.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ def set_determinism(
4646
device: torch.device,
4747
seed: int | None = None,
4848
deterministic: bool = False,
49+
distinct_seed_mesh_dim: str = "pp",
4950
) -> None:
5051
"""
51-
Set the same DTensor manual seed for all ranks within the same DTensor SPMD group, but different
52-
seeds across PP groups (if applicable).
52+
Set the same DTensor manual seed for all dimensions in world mesh, but only different seeds
53+
across dimension denoted by `distinct_seed_mesh_dim`. An example use case is pipeline parallelism,
54+
where we want to have the same seed across SPMD groups, but different seeds across PP groups.
5355
5456
Currently, does not set seeds for the CUDA RNG since TorchTitan always uses DTensor for SPMD parallelisms,
5557
and DTensor manages its own RNG tracker, but we could extend to support both if needed.
@@ -81,22 +83,31 @@ def set_determinism(
8183
torch.distributed.broadcast(seed_tensor, src=0)
8284
seed = seed_tensor.to("cpu").view(torch.uint64).item()
8385

86+
# Set distinct seed for each rank in mesh dimensions, with dimension name provdied by `distinct_seed_mesh_dim`
8487
# For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh,
8588
# and choose a unique seed for each rank on the PP mesh.
86-
if c10d.get_world_size() > 1 and "pp" in world_mesh.mesh_dim_names:
87-
pp_mesh = world_mesh["pp"]
88-
seed += pp_mesh.get_local_rank()
89+
# TODO(jianiw): We could further extend this to support mutiple distinct dimensions instead of just one.
90+
if (
91+
c10d.get_world_size() > 1
92+
and distinct_seed_mesh_dim in world_mesh.mesh_dim_names
93+
):
94+
distinct_mesh = world_mesh[distinct_seed_mesh_dim]
95+
seed += distinct_mesh.get_local_rank()
8996
seed %= 2**64
9097

9198
logger.debug(
92-
f"PP rank {pp_mesh.get_local_rank()}, Global rank {c10d.get_rank()} using seed: {seed}"
99+
f"{distinct_seed_mesh_dim} rank {distinct_mesh.get_local_rank()}, Global rank {c10d.get_rank()} using seed: {seed}"
93100
)
94-
spmd_mesh_dims = list(
95-
filter(lambda name: name != "pp", world_mesh.mesh_dim_names)
101+
duplicate_seed_mesh = list(
102+
filter(
103+
lambda name: name != distinct_seed_mesh_dim, world_mesh.mesh_dim_names
104+
)
105+
)
106+
duplicate_seed_mesh = (
107+
world_mesh[duplicate_seed_mesh] if len(duplicate_seed_mesh) else None
96108
)
97-
spmd_mesh = world_mesh[spmd_mesh_dims] if len(spmd_mesh_dims) else None
98109
else:
99-
spmd_mesh = world_mesh
110+
duplicate_seed_mesh = world_mesh
100111
logger.debug(f"Global Rank {c10d.get_rank()} using seed: {seed}")
101112

102113
# The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency.
@@ -106,8 +117,8 @@ def set_determinism(
106117

107118
# As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh.
108119
# IF PP is also used, this seed is unique per PP rank.
109-
if spmd_mesh and spmd_mesh.get_coordinate() is not None:
110-
torch.distributed.tensor._random.manual_seed(seed, spmd_mesh)
120+
if duplicate_seed_mesh and duplicate_seed_mesh.get_coordinate() is not None:
121+
torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh)
111122

112123

113124
def create_context_parallel_ctx(

torchtitan/experiments/flux/README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,5 @@ Run the following command to train the model on a single GPU:
2525
## TODO
2626
- [ ] More parallesim support (Tensor Parallelism, Context Parallelism, etc)
2727
- [ ] Support for distributed checkpointing and loading
28-
- [ ] Implement init_weights() function to initialize the model weights
2928
- [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
3029
- [ ] Implement test cases in CI for FLUX model. Adding more unit tests for FLUX model (eg, unit test for preprocessor, etc)

torchtitan/experiments/flux/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
axes_dim=(16, 56, 56),
4040
theta=10_000,
4141
qkv_bias=True,
42-
guidance_embed=True,
4342
autoencoder_params=AutoEncoderParams(
4443
resolution=256,
4544
in_channels=3,
@@ -65,7 +64,6 @@
6564
axes_dim=(16, 56, 56),
6665
theta=10_000,
6766
qkv_bias=True,
68-
guidance_embed=False,
6967
autoencoder_params=AutoEncoderParams(
7068
resolution=256,
7169
in_channels=3,
@@ -91,7 +89,6 @@
9189
axes_dim=(16, 56, 56),
9290
theta=10_000,
9391
qkv_bias=True,
94-
guidance_embed=True,
9592
autoencoder_params=AutoEncoderParams(
9693
resolution=256,
9794
in_channels=3,

torchtitan/experiments/flux/dataset/flux_dataset.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@
1010
from typing import Any, Callable, Optional
1111

1212
import numpy as np
13+
import PIL
1314

1415
import torch
15-
1616
from datasets import Dataset, load_dataset
1717
from datasets.distributed import split_dataset_by_node
18-
from PIL import Image
1918

2019
from torch.distributed.checkpoint.stateful import Stateful
2120

@@ -28,7 +27,7 @@
2827

2928

3029
def _process_cc12m_image(
31-
img: Image.Image,
30+
img: PIL.Image.Image,
3231
output_size: int = 256,
3332
) -> Optional[torch.Tensor]:
3433
"""Process CC12M image to the desired size."""
@@ -56,9 +55,9 @@ def _process_cc12m_image(
5655

5756
assert resized_img.size[0] == resized_img.size[1] == output_size
5857

59-
# Skip grayscale images, and RGBA, CMYK images
58+
# Convert grayscale images, and RGBA, CMYK images
6059
if resized_img.mode != "RGB":
61-
return None
60+
resized_img = resized_img.convert("RGB")
6261

6362
np_img = np.array(resized_img).transpose((2, 0, 1))
6463
tensor_img = torch.tensor(np_img).float() / 255.0
@@ -76,7 +75,7 @@ def _process_cc12m_image(
7675
return tensor_img
7776

7877

79-
def _flux_data_processor(
78+
def _cc12m_wds_data_processor(
8079
sample: dict[str, Any],
8180
t5_tokenizer: FluxTokenizer,
8281
clip_tokenizer: FluxTokenizer,
@@ -111,10 +110,10 @@ class TextToImageDatasetConfig:
111110

112111

113112
DATASETS = {
114-
"cc12m": TextToImageDatasetConfig(
113+
"cc12m-wds": TextToImageDatasetConfig(
115114
path="pixparse/cc12m-wds",
116115
loader=lambda path: load_dataset(path, split="train", streaming=True),
117-
data_processor=_flux_data_processor,
116+
data_processor=_cc12m_wds_data_processor,
118117
),
119118
}
120119

@@ -171,7 +170,9 @@ def __init__(
171170
self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
172171

173172
self._t5_tokenizer = t5_tokenizer
173+
self._t5_empty_token = t5_tokenizer.encode("")
174174
self._clip_tokenizer = clip_tokenizer
175+
self._clip_empty_token = clip_tokenizer.encode("")
175176
self._data_processor = data_processor
176177
self.job_config = job_config
177178

@@ -195,7 +196,10 @@ def __iter__(self):
195196
for sample in self._get_data_iter():
196197
# Use the dataset-specific preprocessor
197198
sample_dict = self._data_processor(
198-
sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
199+
sample,
200+
self._t5_tokenizer,
201+
self._clip_tokenizer,
202+
output_size=self.job_config.training.img_size,
199203
)
200204

201205
# skip low quality image or image with color channel = 1
@@ -205,6 +209,14 @@ def __iter__(self):
205209
)
206210
continue
207211

212+
# Classifier-free guidance: Replace some of the strings with empty strings.
213+
# Distinct random seed is initialized at the beginning of training for each FSDP rank.
214+
dropout_prob = self.job_config.training.classifer_free_guidance_prob
215+
if dropout_prob > 0.0:
216+
if random.random() < dropout_prob:
217+
sample_dict["t5_tokens"] = self._t5_empty_token
218+
sample_dict["clip_tokens"] = self._clip_empty_token
219+
208220
self._all_samples.extend(sample_dict)
209221
self._sample_idx += 1
210222

@@ -254,6 +266,7 @@ def build_flux_dataloader(
254266
clip_tokenizer=FluxTokenizer(
255267
clip_encoder_name, max_length=77
256268
), # fix max_length for CLIP
269+
job_config=job_config,
257270
dp_rank=dp_rank,
258271
dp_world_size=dp_world_size,
259272
infinite=infinite,

torchtitan/experiments/flux/flux_argparser.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,16 @@
1111

1212
def extend_parser(parser: argparse.ArgumentParser) -> None:
1313
parser.add_argument(
14-
"--training.guidance",
14+
"--training.classifer_free_guidance_prob",
1515
type=float,
16-
default=3.5,
17-
help="guidance value used for guidance distillation",
16+
default=0.0,
17+
help="Classifier-free guidance with probability p to dropout the text conditioning",
18+
)
19+
parser.add_argument(
20+
"--training.img_size",
21+
type=int,
22+
default=256,
23+
help="Image width to sample",
1824
)
1925
parser.add_argument(
2026
"--encoder.t5_encoder",
@@ -45,3 +51,33 @@ def extend_parser(parser: argparse.ArgumentParser) -> None:
4551
action="store_true",
4652
help="Whether to shard the encoder using FSDP",
4753
)
54+
# eval configs
55+
parser.add_argument(
56+
"--eval.enable_classifer_free_guidance",
57+
action="store_true",
58+
help="Whether to use classifier-free guidance during sampling",
59+
)
60+
parser.add_argument(
61+
"--eval.classifier_free_guidance_scale",
62+
type=float,
63+
default=5.0,
64+
help="Classifier-free guidance scale when sampling",
65+
)
66+
parser.add_argument(
67+
"--eval.denoising_steps",
68+
type=int,
69+
default=50,
70+
help="How many denoising steps to sample when generating an image",
71+
)
72+
parser.add_argument(
73+
"--eval.eval_freq",
74+
type=int,
75+
default=100,
76+
help="Frequency of evaluation/sampling during training",
77+
)
78+
parser.add_argument(
79+
"--eval.save_img_folder",
80+
type=str,
81+
default="img",
82+
help="Directory to save image generated/sampled from the model",
83+
)

torchtitan/experiments/flux/model/model.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ class FluxModelArgs(BaseModelArgs):
4040
axes_dim: tuple = (16, 56, 56)
4141
theta: int = 10_000
4242
qkv_bias: bool = True
43-
guidance_embed: bool = True
4443
autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)
4544

4645
def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
@@ -89,11 +88,6 @@ def __init__(self, model_args: FluxModelArgs):
8988
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
9089
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
9190
self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
92-
self.guidance_in = (
93-
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
94-
if model_args.guidance_embed
95-
else nn.Identity()
96-
)
9791
self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)
9892

9993
self.double_blocks = nn.ModuleList(
@@ -127,11 +121,9 @@ def init_weights(self, buffer_device=None):
127121
nn.init.xavier_uniform_(self.txt_in.weight)
128122
nn.init.constant_(self.txt_in.bias, 0)
129123

130-
# Initialize time_in, vector_in, guidance_in (MLPEmbedder)
124+
# Initialize time_in, vector_in (MLPEmbedder)
131125
self.time_in.init_weights(init_std=0.02)
132126
self.vector_in.init_weights(init_std=0.02)
133-
if self.model_args.guidance_embed:
134-
self.guidance_in.init_weights(init_std=0.02)
135127

136128
# Initialize transformer blocks:
137129
for block in self.single_blocks:
@@ -150,20 +142,13 @@ def forward(
150142
txt_ids: Tensor,
151143
timesteps: Tensor,
152144
y: Tensor,
153-
guidance: Tensor | None = None,
154145
) -> Tensor:
155146
if img.ndim != 3 or txt.ndim != 3:
156147
raise ValueError("Input img and txt tensors must have 3 dimensions.")
157148

158149
# running on sequences img
159150
img = self.img_in(img)
160151
vec = self.time_in(timestep_embedding(timesteps, 256))
161-
if self.model_args.guidance_embed:
162-
if guidance is None:
163-
raise ValueError(
164-
"Didn't get guidance strength for guidance distilled model."
165-
)
166-
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
167152
vec = vec + self.vector_in(y)
168153
txt = self.txt_in(txt)
169154

torchtitan/experiments/flux/parallelize_flux.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def apply_fsdp(
7777
linear_layers = [
7878
model.img_in,
7979
model.time_in,
80-
model.guidance_in,
8180
model.vector_in,
8281
model.txt_in,
8382
]

0 commit comments

Comments
 (0)