Skip to content

Commit 90884d7

Browse files
authored
allow configurable scheduler load group
Differential Revision: D67290464 Pull Request resolved: #992
1 parent c1a195a commit 90884d7

File tree

6 files changed

+40
-9
lines changed

6 files changed

+40
-9
lines changed

torchx/runner/config.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,15 @@ def _configparser() -> configparser.ConfigParser:
197197

198198

199199
def _get_scheduler(name: str) -> Scheduler:
200-
schedulers = get_scheduler_factories()
200+
schedulers = {
201+
**get_scheduler_factories(),
202+
**(
203+
get_scheduler_factories(
204+
group="torchx.schedulers.orchestrator", skip_defaults=True
205+
)
206+
or {}
207+
),
208+
}
201209
if name not in schedulers:
202210
raise ValueError(
203211
f"`{name}` is not a registered scheduler. Valid scheduler names: {schedulers.keys()}"

torchx/runner/test/config_test.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,10 +470,22 @@ def test_dump_and_load_all_registered_schedulers(self) -> None:
470470
sfile = StringIO()
471471
dump(sfile)
472472

473-
for sched_name, sched in get_scheduler_factories().items():
473+
scheduler_factories = {
474+
**get_scheduler_factories(),
475+
**(
476+
get_scheduler_factories(
477+
group="torchx.schedulers.orchestrator", skip_defaults=True
478+
)
479+
or {}
480+
),
481+
}
482+
483+
for sched_name, sched in scheduler_factories.items():
474484
sfile.seek(0) # reset the file pos
475485
cfg = {}
476486
load(scheduler=sched_name, f=sfile, cfg=cfg)
477-
478487
for opt_name, _ in sched("test").run_opts():
479-
self.assertTrue(opt_name in cfg)
488+
self.assertTrue(
489+
opt_name in cfg,
490+
f"missing {opt_name} in {sched} run opts with cfg {cfg}",
491+
)

torchx/schedulers/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@ def run(*args: object, **kwargs: object) -> Scheduler:
4242
return run
4343

4444

45-
def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
45+
def get_scheduler_factories(
46+
group: str = "torchx.schedulers", skip_defaults: bool = False
47+
) -> Dict[str, SchedulerFactory]:
4648
"""
47-
get_scheduler_factories returns all the available schedulers names and the
49+
get_scheduler_factories returns all the available schedulers names under `group` and the
4850
method to instantiate them.
4951
5052
The first scheduler in the dictionary is used as the default scheduler.
@@ -55,8 +57,9 @@ def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
5557
default_schedulers[scheduler] = _defer_load_scheduler(path)
5658

5759
return load_group(
58-
"torchx.schedulers",
60+
group,
5961
default=default_schedulers,
62+
skip_defaults=skip_defaults,
6063
)
6164

6265

torchx/schedulers/test/registry_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __call__(
2222
group: str,
2323
default: Dict[str, Any],
2424
ignore_missing: Optional[bool] = False,
25+
skip_defaults: bool = False,
2526
) -> Dict[str, Any]:
2627
return default
2728

torchx/util/entrypoints.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ def run(*args: object, **kwargs: object) -> object:
5151

5252
# pyre-ignore-all-errors[3, 2]
5353
def load_group(
54-
group: str,
55-
default: Optional[Dict[str, Any]] = None,
54+
group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False
5655
):
5756
"""
5857
Loads all the entry points specified by ``group`` and returns
@@ -72,6 +71,7 @@ def load_group(
7271
1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")``
7372
1. ``load_group("food")`` -> ``None``
7473
1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")``
74+
1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None``
7575
7676
7777
If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn``
@@ -90,6 +90,8 @@ def load_group(
9090
entrypoints = metadata.entry_points().select(group=group)
9191

9292
if len(entrypoints) == 0:
93+
if skip_defaults:
94+
return None
9395
return default
9496

9597
eps = {}

torchx/util/test/entrypoints_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ def test_load_group_with_default(self, _: MagicMock) -> None:
122122
self.assertEqual("barbaz", eps["foo"]())
123123
self.assertEqual("foobar", eps["bar"]())
124124

125+
eps = load_group(
126+
"ep.grp.test.missing", {"foo": barbaz, "bar": foobar}, skip_defaults=True
127+
)
128+
self.assertIsNone(eps)
129+
125130
@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
126131
def test_load_group_missing(self, _: MagicMock) -> None:
127132
with self.assertRaises(AttributeError):

0 commit comments

Comments
 (0)