Skip to content

Commit 82f7387

Browse files
committed
minor fix
ghstack-source-id: 0dd35232e76d80a4542a7e91b2d25fea663938b6 Pull Request resolved: #788
1 parent 95677cb commit 82f7387

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

docs/checkpoint.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,5 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l
7575
To create a seed checkpoint, use the same model config as you use for training.
7676
e.g.
7777
```bash
78-
NGPU=1 CONFIG=<path_to_model_config> ./run_llama_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --training.data_parallel_shard_degree 1
78+
NGPU=1 CONFIG=<path_to_model_config> ./run_llama_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --training.data_parallel_replicate_degree 1 --training.data_parallel_shard_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1 --experimental.context_parallel_degree 1
7979
```

torchtitan/parallelisms/pipelining_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
107107
)
108108
logger.info(
109109
f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} \
110-
with {n_microbatches} and {num_total_stages} stages."
110+
with {n_microbatches} microbatches and {num_total_stages} stages."
111111
)
112112

113113
if pp_schedule_csv:

train.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,10 @@ def loss_fn(pred, labels):
201201
if job_config.checkpoint.create_seed_checkpoint:
202202
assert (
203203
world_size == 1
204-
), "Must create seed-checkpoint using one gpu, to disable sharding"
204+
), "Must create seed checkpoint using a single device, to disable sharding"
205+
assert (
206+
job_config.checkpoint.enable_checkpoint
207+
), "Must enable checkpointing when creating a seed checkpoint"
205208
checkpoint.save(curr_step=0, force=True)
206209
logger.info("Created seed checkpoint")
207210
return

0 commit comments

Comments
 (0)