Skip to content

Commit 20e2f06

Browse files
authored
[Flux] Fix missing field in flux argparse (#1120)
## Context: 1. Bug fix — add the missing fields to `arg_parser`. 2. Rebase onto the main-folder changes and add the `ft_pg` field in train.py.
1 parent 14c810e commit 20e2f06

File tree

9 files changed

+43
-41
lines changed

9 files changed

+43
-41
lines changed

torchtitan/experiments/flux/dataset/flux_dataset.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,13 @@ def build_flux_dataloader(
262262
ds = FluxDataset(
263263
dataset_name=dataset_name,
264264
dataset_path=dataset_path,
265-
t5_tokenizer=FluxTokenizer(t5_encoder_name, max_length=max_t5_encoding_len),
265+
t5_tokenizer=FluxTokenizer(
266+
t5_encoder_name,
267+
max_length=max_t5_encoding_len,
268+
),
266269
clip_tokenizer=FluxTokenizer(
267-
clip_encoder_name, max_length=77
270+
clip_encoder_name,
271+
max_length=77,
268272
), # fix max_length for CLIP
269273
job_config=job_config,
270274
dp_rank=dp_rank,

torchtitan/experiments/flux/dataset/tokenizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,20 @@ class FluxTokenizer(Tokenizer):
2323
2424
"""
2525

26-
def __init__(self, model_path: str = "t5-small", max_length: int = 77):
26+
def __init__(self, model_path: str = "t5-small", max_length: int = 77, **hf_kwargs):
2727
super().__init__()
2828
self._n_words = 8 # TODO(jianiw): check
2929
self._max_length = max_length
3030

31-
self.is_clip = model_path.startswith("openai")
31+
self.is_clip = "clip" in model_path.lower()
3232

3333
if self.is_clip:
3434
self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
35-
model_path, max_length=max_length
35+
model_path, max_length=max_length, **hf_kwargs
3636
)
3737
else:
3838
self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
39-
model_path, max_length=max_length
39+
model_path, max_length=max_length, **hf_kwargs
4040
)
4141

4242
def encode(

torchtitan/experiments/flux/flux_argparser.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
import argparse
88

9-
import torch
10-
119

1210
def extend_parser(parser: argparse.ArgumentParser) -> None:
1311
parser.add_argument(
@@ -26,31 +24,28 @@ def extend_parser(parser: argparse.ArgumentParser) -> None:
2624
"--encoder.t5_encoder",
2725
type=str,
2826
default="google/t5-v1_1-small",
29-
help="T5 encoder to use, HuggingFace model name.",
27+
help="T5 encoder to use, HuggingFace model name. This field could be either a local folder path, \
28+
or a Huggingface repo name.",
3029
)
3130
parser.add_argument(
3231
"--encoder.clip_encoder",
3332
type=str,
3433
default="openai/clip-vit-large-patch14",
35-
help="Clip encoder to use, HuggingFace model name.",
34+
help="Clip encoder to use, HuggingFace model name. This field could be either a local folder path, \
35+
or a Huggingface repo name.",
3636
)
3737
parser.add_argument(
38-
"--encoder.encoder_dtype",
39-
type=torch.dtype,
40-
default=torch.bfloat16,
41-
help="Which dtype to load for autoencoder. ",
38+
"--encoder.autoencoder_path",
39+
type=str,
40+
default="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors",
41+
help="Autoencoder checkpoint path to load. This should be a local path referring to a safetensors file.",
4242
)
4343
parser.add_argument(
4444
"--encoder.max_t5_encoding_len",
4545
type=int,
4646
default=512,
4747
help="Maximum length of the T5 encoding.",
4848
)
49-
parser.add_argument(
50-
"--encoder.offload_encoder",
51-
action="store_true",
52-
help="Whether to shard the encoder using FSDP",
53-
)
5449
# eval configs
5550
parser.add_argument(
5651
"--eval.enable_classifer_free_guidance",

torchtitan/experiments/flux/model/hf_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
class FluxEmbedder(nn.Module):
1212
def __init__(self, version: str, **hf_kwargs):
1313
super().__init__()
14-
self.is_clip = version.startswith("openai")
14+
self.is_clip = "clip" in version.lower()
1515
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
1616

1717
if self.is_clip:

torchtitan/experiments/flux/tests/test_generate_image.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,26 +79,27 @@ def test_generate_image(self):
7979
device=torch_device,
8080
dtype=torch.bfloat16,
8181
)
82-
clip_tokenizer = FluxTokenizer(
83-
model_path=config.encoder.clip_encoder, max_length=77
84-
)
82+
t5_encoder = FluxEmbedder(
83+
version=config.encoder.t5_encoder,
84+
).to(torch_device, dtype=torch.bfloat16)
8585
t5_tokenizer = FluxTokenizer(
8686
model_path=config.encoder.t5_encoder,
8787
max_length=config.encoder.max_t5_encoding_len,
8888
)
89-
clip_encoder = FluxEmbedder(version=config.encoder.clip_encoder).to(
90-
torch_device, dtype=torch.bfloat16
91-
)
92-
t5_encoder = FluxEmbedder(version=config.encoder.t5_encoder).to(
93-
torch_device, dtype=torch.bfloat16
89+
clip_encoder = FluxEmbedder(
90+
version=config.encoder.clip_encoder,
91+
).to(torch_device, dtype=torch.bfloat16)
92+
clip_tokenizer = FluxTokenizer(
93+
model_path=config.encoder.clip_encoder,
94+
max_length=77,
9495
)
9596

9697
if torch.cuda.is_available():
9798
torch.cuda.synchronize()
9899
t1 = time.perf_counter()
99100

100101
model = self._get_test_model(
101-
context_in_dim=768, device=torch_device, dtype=torch.bfloat16
102+
context_in_dim=4096, device=torch_device, dtype=torch.bfloat16
102103
)
103104
model.eval()
104105

torchtitan/experiments/flux/train.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,18 @@ def __init__(self, job_config: JobConfig):
5151
model_config = self.train_spec.config[job_config.model.flavor]
5252

5353
self.autoencoder = load_ae(
54-
job_config.encoder.auto_encoder_path,
54+
job_config.encoder.autoencoder_path,
5555
model_config.autoencoder_params,
5656
device=self.device,
5757
dtype=self._dtype,
5858
)
59-
self.clip_encoder = FluxEmbedder(version=job_config.encoder.clip_encoder).to(
60-
device=self.device, dtype=self._dtype
61-
)
62-
self.t5_encoder = FluxEmbedder(version=job_config.encoder.t5_encoder).to(
63-
device=self.device, dtype=self._dtype
64-
)
59+
60+
self.clip_encoder = FluxEmbedder(
61+
version=job_config.encoder.clip_encoder,
62+
).to(device=self.device, dtype=self._dtype)
63+
self.t5_encoder = FluxEmbedder(
64+
version=job_config.encoder.t5_encoder,
65+
).to(device=self.device, dtype=self._dtype)
6566

6667
# Apply FSDP to the T5 model / CLIP model
6768
self.t5_encoder, self.clip_encoder = parallelize_encoders(
@@ -159,9 +160,10 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
159160
or parallel_dims.cp_enabled
160161
):
161162
loss = loss.detach()
163+
ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None
162164
global_avg_loss, global_max_loss = (
163-
dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
164-
dist_utils.dist_max(loss, world_mesh["dp_cp"]),
165+
dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
166+
dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
165167
)
166168
else:
167169
global_avg_loss = global_max_loss = loss.item()

torchtitan/experiments/flux/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ img_size = 256
4646
t5_encoder = "google/t5-v1_1-xxl"
4747
clip_encoder = "openai/clip-vit-large-patch14"
4848
max_t5_encoding_len = 4096
49-
auto_encoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
49+
autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
5050

5151
[eval]
5252
enable_classifer_free_guidance = true

torchtitan/experiments/flux/train_configs/flux_dev_model.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ img_size = 256
4545
t5_encoder = "google/t5-v1_1-xxl"
4646
clip_encoder = "openai/clip-vit-large-patch14"
4747
max_t5_encoding_len = 4096
48-
auto_encoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
48+
autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
4949

5050
[eval]
5151
enable_classifer_free_guidance = true

torchtitan/experiments/flux/train_configs/flux_schnell_model.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ img_size = 256
4545
t5_encoder = "google/t5-v1_1-xxl"
4646
clip_encoder = "openai/clip-vit-large-patch14"
4747
max_t5_encoding_len = 4096
48-
auto_encoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
48+
autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
4949

5050
[eval]
5151
enable_classifer_free_guidance = true

0 commit comments

Comments
 (0)