Skip to content

Commit 14c810e

Browse files
authored
Fix seed checkpoint creation assertion (#1112)
Fixes #1107
1 parent 508350b commit 14c810e

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

torchtitan/experiments/flux/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,8 @@ def eval_step(self, prompt: str = "A photo of a cat"):
226226
try:
227227
trainer = FluxTrainer(config)
228228
if config.checkpoint.create_seed_checkpoint:
229-
assert int(
230-
os.environ["WORLD_SIZE"]
229+
assert (
230+
int(os.environ["WORLD_SIZE"]) == 1
231231
), "Must create seed checkpoint using a single device, to disable sharding."
232232
assert (
233233
config.checkpoint.enable_checkpoint

torchtitan/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,8 @@ def close(self) -> None:
455455
trainer = Trainer(config)
456456

457457
if config.checkpoint.create_seed_checkpoint:
458-
assert int(
459-
os.environ["WORLD_SIZE"]
458+
assert (
459+
int(os.environ["WORLD_SIZE"]) == 1
460460
), "Must create seed checkpoint using a single device, to disable sharding."
461461
assert (
462462
config.checkpoint.enable_checkpoint

0 commit comments

Comments
 (0)