# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule
from torchtitan.config_manager import JobConfig
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


@dataclass
class BaseModelArgs:
    """All ModelArgs should inherit from this class.

    This class is currently used only for type checking, but it allows us to
    extend common arguments to all models in the future.
    """

    _enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
    """Defines the interface for a model class.

    This is used to enforce that all model classes have some methods that are
    required by the TorchTitan trainer.
    """

    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module: ...


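# A minimal sketch (illustrative only, not part of TorchTitan) of a model that
# satisfies ModelProtocol; the names ``ToyModelArgs`` and ``ToyModel`` are
# hypothetical:
#
#     @dataclass
#     class ToyModelArgs(BaseModelArgs):
#         dim: int = 256
#
#     class ToyModel(nn.Module):
#         def __init__(self, args: ToyModelArgs) -> None:
#             super().__init__()
#             self.proj = nn.Linear(args.dim, args.dim)
#
#         @staticmethod
#         def from_model_args(args: BaseModelArgs) -> nn.Module:
#             return ToyModel(args)

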
@dataclass
class ModelSpec:
    name: str
    cls: Type[nn.Module]
    config: Dict[str, BaseModelArgs]
    # TODO: Add a ``build_dataloader_fn``.
    # For now, the tokenizer field below is a string, so the dataloader has to
    # be built into the TorchTitan library. A better approach would be a
    # dataloader class and a ``build_dataloader`` function that take job_config
    # and consume the different dataloader and tokenizer configs.
    tokenizer: str
    parallelize_fn: Callable[[nn.Module], None]
    pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
    build_optimizers_fn: Callable[[List[nn.Module], JobConfig], OptimizersContainer]
    build_lr_schedulers_fn: Callable[
        [List[nn.Module], JobConfig], LRSchedulersContainer
    ]

    # TODO: Add an FQN conversion fn to allow users to load checkpoints from
    # HuggingFace or other sources that have different FQN conventions.


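# An illustrative sketch (not actual TorchTitan code) of how a ModelSpec might
# be filled in for the hypothetical ``ToyModel`` above; the parallelize,
# pipelining, optimizer, and LR-scheduler helpers named here are assumed to
# exist:
#
#     toy_spec = ModelSpec(
#         name="toy",
#         cls=ToyModel,
#         config={"debugmodel": ToyModelArgs(dim=256)},
#         tokenizer="tiktoken",
#         parallelize_fn=parallelize_toy,
#         pipelining_fn=pipeline_toy,
#         build_optimizers_fn=build_toy_optimizers,
#         build_lr_schedulers_fn=build_toy_lr_schedulers,
#     )

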
_model_specs: Dict[str, ModelSpec] = {}


def register_model_spec(model_spec: ModelSpec) -> None:
    """Register a ModelSpec under its name; raise if the name is already taken."""
    global _model_specs
    if model_spec.name in _model_specs:
        raise ValueError(f"Model {model_spec.name} is already registered.")
    _model_specs[model_spec.name] = model_spec


def get_model_spec(name: str) -> ModelSpec:
    """Return the ModelSpec registered under ``name``; raise if it is unknown."""
    global _model_specs
    if name not in _model_specs:
        raise ValueError(f"Model {name} is not registered.")
    return _model_specs[name]
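

# Illustrative usage of the registry (the ``toy_spec`` above is hypothetical):
#
#     register_model_spec(toy_spec)
#     spec = get_model_spec("toy")
#     model_args = spec.config["debugmodel"]
#     model = spec.cls.from_model_args(model_args)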