Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
Training step as a class (#275)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: fairinternal/ssl_scaling#275

Reviewed By: mannatsingh

Differential Revision: D37619081

Pulled By: QuentinDuval

fbshipit-source-id: e4af1644554bc258c15b4186880464598bf6185e
  • Loading branch information
QuentinDuval authored and facebook-github-bot committed Jul 14, 2022
1 parent 3e5f2b6 commit 0d465ce
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 202 deletions.
20 changes: 10 additions & 10 deletions vissl/trainer/train_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
TRAIN_STEP_NAMES = set()


def register_train_step(name):
def register_train_step(name: str):
"""
Registers Self-Supervision Train step.
Expand All @@ -36,30 +36,30 @@ def my_step_name():
To get a train step from a configuration file, see :func:`get_train_step`.
"""

def register_train_step_fn(func):
def register_train_step_cls(cls):
if name in TRAIN_STEP_REGISTRY:
raise ValueError("Cannot register duplicate train step ({})".format(name))

if func.__name__ in TRAIN_STEP_NAMES:
if cls.__name__ in TRAIN_STEP_NAMES:
raise ValueError(
"Cannot register task with duplicate train step name ({})".format(
func.__name__
cls.__name__
)
)
TRAIN_STEP_REGISTRY[name] = func
TRAIN_STEP_NAMES.add(func.__name__)
return func
TRAIN_STEP_REGISTRY[name] = cls
TRAIN_STEP_NAMES.add(cls.__name__)
return cls

return register_train_step_fn
return register_train_step_cls


def get_train_step(train_step_name: str):
def get_train_step(train_step_name: str, **train_step_kwargs):
"""
Lookup the train_step_name in the train step registry and return.
If the train step is not implemented, asserts will be thrown and workflow will exit.
"""
assert train_step_name in TRAIN_STEP_REGISTRY, "Unknown train step"
return TRAIN_STEP_REGISTRY[train_step_name]
return TRAIN_STEP_REGISTRY[train_step_name](**train_step_kwargs)


# automatically import any Python files in the train_steps/ directory
Expand Down
Loading

0 comments on commit 0d465ce

Please sign in to comment.