**What does this PR do?**
1. This PR introduces `ModelSpec` to describe a model and how to parallelize it.
2. All the models should define `build_model_spec()` or `model_spec` to
be imported by the `model` module.
3. `build_model_specs()` is called in the trainer to get the `model_specs` and the result is used to get the corresponding model spec.
4. Users can also use `--experimental.model_module_path` to dynamically import a model that is not implemented by TorchTitan.
**Why do we need this PR?**
This allows users to use TorchTitan with a new model without intrusively changing TorchTitan code.
**Next steps**
1. This PR only includes the model definitions, configurations, tokenizer, parallelize_fn, and
pipelining_fn. We may also want to extend ModelSpec to include the optimizer and lr_scheduler.
2. The current TorchTitan parallelize and pipelining_fn import ModelArgs, which can cause circular imports.
We should fix this issue.
**What does this PR do?**
1. Introduces `ModelSpec` to describe a model and how to parallelize it.
2. Requires all models to define `build_model_spec()` or `model_spec`, which will be imported by the model module.
3. Calls `build_model_specs()` in the trainer to obtain `model_specs`, which are then used to retrieve the corresponding model spec.
4. Allows users to dynamically import a model not implemented by TorchTitan using `--experimental.model_module_path`.
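
The flow described in the list above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not TorchTitan's actual code: the field names on `ModelSpec`, the helper names `spec_from_module` and `collect_model_specs` (a stand-in for `build_model_specs()`), and the idea that the value of `--experimental.model_module_path` can be handed to `importlib.import_module` are all hypothetical.

```python
import importlib
import sys
import types
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional


@dataclass
class ModelSpec:
    # Field names are illustrative assumptions, not the real definition.
    name: str
    model_cls: type
    configs: Dict[str, Any]              # flavor name -> model arguments
    tokenizer: str
    parallelize_fn: Callable[..., Any]
    pipelining_fn: Optional[Callable[..., Any]] = None


def spec_from_module(module: types.ModuleType) -> ModelSpec:
    """Read `model_spec` or call `build_model_spec()` on a model module."""
    if hasattr(module, "model_spec"):
        return module.model_spec
    if hasattr(module, "build_model_spec"):
        return module.build_model_spec()
    raise AttributeError(
        f"{module.__name__} defines neither model_spec nor build_model_spec()"
    )


def collect_model_specs(module_names: List[str]) -> Dict[str, ModelSpec]:
    """Hypothetical stand-in for build_model_specs(): import and index specs."""
    specs: Dict[str, ModelSpec] = {}
    for module_name in module_names:
        spec = spec_from_module(importlib.import_module(module_name))
        specs[spec.name] = spec
    return specs


# --- A toy "user" model living outside the trainer's code base ---
class ToyModel:
    def __init__(self, dim: int) -> None:
        self.dim = dim


def toy_parallelize(model: ToyModel) -> ToyModel:
    return model  # no-op placeholder for a real parallelization function


# Simulate an importable user module without touching the file system.
user_module = types.ModuleType("my_toy_model")
user_module.build_model_spec = lambda: ModelSpec(
    name="toy",
    model_cls=ToyModel,
    configs={"debug": {"dim": 8}},
    tokenizer="toy_tokenizer",
    parallelize_fn=toy_parallelize,
)
sys.modules["my_toy_model"] = user_module

# Trainer side: collect the specs, pick one by name, build and parallelize.
model_specs = collect_model_specs(["my_toy_model"])
spec = model_specs["toy"]
model = spec.parallelize_fn(spec.model_cls(**spec.configs["debug"]))
```

The point of the design is that the trainer only touches the spec object, so a new model can be plugged in by supplying a module that exposes a spec rather than by editing trainer code.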
**Why do we need this PR?**
This PR enables users to integrate new models with TorchTitan without making intrusive changes to the TorchTitan codebase.
**Next steps**
1. This PR includes only the model definitions, configurations, tokenizer, `parallelize_fn`, and `pipelining_fn`. We may want to extend `ModelSpec` to include the optimizer and learning rate scheduler.
2. The current TorchTitan parallelize and `pipelining_fn` import `ModelArgs`, which can lead to circular imports. This issue needs to be addressed.
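
One common Python pattern for the circular-import problem in step 2 is to import the arguments class only for type checking. The sketch below shows that general technique, not the fix this PR or TorchTitan adopts. The module path `my_models.toy.args` and the `n_layers` attribute are hypothetical.

```python
from types import SimpleNamespace
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so this
    # import cannot take part in an import cycle.
    # `my_models.toy.args` is a hypothetical module path.
    from my_models.toy.args import ModelArgs


def parallelize_fn(model: dict, args: "ModelArgs") -> dict:
    """Toy stand-in: records the layer count taken from the model arguments."""
    # The string annotation above is not resolved at runtime, so any object
    # with the needed attributes works here.
    return {**model, "n_layers": args.n_layers}


# Runtime use needs only an object with the right attributes.
result = parallelize_fn({"name": "toy"}, SimpleNamespace(n_layers=2))
```

Because the function body relies on attributes rather than on the concrete class, the parallelization module no longer has to import the model's argument definitions when it is loaded.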
ghstack-source-id: 28259eb74975eeb7ad790a774b6e719f3aa19a31
Pull Request resolved: #814