Commit 8e67a01

James Sun authored and facebook-github-bot committed
link launch and sync conda/workspace locations (#742)
Summary:
Pull Request resolved: #742

Make sure the conda/workspace locations used during launch match the locations used when we sync.

Reviewed By: kiukchung

Differential Revision: D79516268
1 parent 49900f3 commit 8e67a01


4 files changed, +79 -20 lines


python/monarch/_src/actor/proc_mesh.py

Lines changed: 22 additions & 15 deletions
@@ -8,11 +8,11 @@
 
 import asyncio
 import logging
-import os
 import sys
 import threading
 import warnings
 from contextlib import AbstractContextManager
+from pathlib import Path
 
 from typing import (
     Any,
@@ -75,6 +75,8 @@
 from monarch._src.actor.endpoint import endpoint
 from monarch._src.actor.future import DeprecatedNotAFuture, Future
 from monarch._src.actor.shape import MeshTrait
+from monarch.tools.config import Workspace
+from monarch.tools.utils import conda as conda_utils
 
 HAS_TENSOR_ENGINE = False
 try:
@@ -361,7 +363,12 @@ def rank_tensor(self, dim: str | Sequence[str]) -> "Tensor":
     def rank_tensors(self) -> Dict[str, "Tensor"]:
         return self._device_mesh.ranks
 
-    async def sync_workspace(self, conda: bool = False, auto_reload: bool = False) -> None:
+    async def sync_workspace(
+        self,
+        workspace: Workspace = None,
+        conda: bool = False,
+        auto_reload: bool = False,
+    ) -> None:
         if self._code_sync_client is None:
             self._code_sync_client = CodeSyncMeshClient.spawn_blocking(
                 proc_mesh=await self._proc_mesh_for_asyncio_fixme,
@@ -373,25 +380,25 @@ async def sync_workspace(self, conda: bool = False, auto_reload: bool = False) -
         # The workspace shape (i.e. only perform one rsync per host).
         assert set(self._shape.labels).issubset({"gpus", "hosts"})
 
-        # TODO(agallagher): Is there a better way to infer/set the local
-        # workspace dir, rather than use PWD?
-        workspaces = [
-            WorkspaceConfig(
-                local=os.getcwd(),
-                remote=RemoteWorkspace(
-                    location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
-                    shape=WorkspaceShape.shared("gpus"),
+        workspaces = []
+        if workspace is not None:
+            workspaces.append(
+                WorkspaceConfig(
+                    local=Path(workspace),
+                    remote=RemoteWorkspace(
+                        location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
+                        shape=WorkspaceShape.shared("gpus"),
+                    ),
+                    method=CodeSyncMethod.Rsync,
                 ),
-                method=CodeSyncMethod.Rsync,
-            ),
-        ]
+            )
 
         # If `conda` is set, also sync the currently activated conda env.
-        conda_prefix = os.environ.get("CONDA_PREFIX")
+        conda_prefix = conda_utils.active_env_dir()
         if conda and conda_prefix is not None:
             workspaces.append(
                 WorkspaceConfig(
-                    local=conda_prefix,
+                    local=Path(conda_prefix),
                     remote=RemoteWorkspace(
                         location=WorkspaceLocation.FromEnvVar("CONDA_PREFIX"),
                         shape=WorkspaceShape.shared("gpus"),
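
For orientation, here is a minimal usage sketch (not part of the commit) of the new keyword: the workspace placed in the launch config is the same value handed back to `sync_workspace`, so launch and sync resolve the same local directory. The scheduler name and path are illustrative; the call pattern mirrors the new test added in this commit.

from monarch.tools.config import defaults


async def launch_and_sync(pm) -> None:
    # `pm` is an already-allocated ProcMesh (allocation elided in this sketch).
    # Build the launch config with an explicit workspace (a str, per the new
    # `Workspace = str | None` alias) instead of relying on PWD.
    config = defaults.config("slurm", workspace="/path/to/my/project")

    # Hand the same value back at sync time so the synced location lines up
    # with what the launch config advertised (WORKSPACE_DIR on the remote).
    await pm.sync_workspace(
        workspace=config.workspace,
        conda=False,        # True would also sync the active conda env
        auto_reload=True,
    )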

python/monarch/tools/config/__init__.py

Lines changed: 10 additions & 2 deletions
@@ -6,7 +6,7 @@
 
 # pyre-strict
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 from torchx.specs import Role
 
@@ -24,6 +24,14 @@ class UnnamedAppDef:
     metadata: Dict[str, str] = field(default_factory=dict)
 
 
+# TODO: provide a proper Workspace class to support
+# - multiple workspaces
+# - empty workspaces
+# - no workspace
+# - experimental directories
+Workspace = str | None
+
+
 @dataclass
 class Config:
     """
@@ -32,6 +40,6 @@ class Config:
 
     scheduler: str = NOT_SET
     scheduler_args: dict[str, Any] = field(default_factory=dict)
-    workspace: Optional[str] = None
+    workspace: Workspace = None
     dryrun: bool = False
     appdef: UnnamedAppDef = field(default_factory=UnnamedAppDef)
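
A small sketch (not from the commit) of what the `Workspace` alias accepts and how it flows into `Config`; the path and scheduler below are placeholders.

from monarch.tools.config import Config, Workspace

# A Workspace is currently just "a local directory path, or None".
ws: Workspace = "/data/users/me/my_project"
no_ws: Workspace = None

# Config carries the alias directly; other fields keep their defaults.
cfg = Config(scheduler="slurm", workspace=ws)
assert cfg.workspace == "/data/users/me/my_project"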

python/monarch/tools/config/defaults.py

Lines changed: 3 additions & 3 deletions
@@ -8,10 +8,10 @@
 
 """Defines defaults for ``monarch.tools``"""
 
-from typing import Callable, Optional
+from typing import Callable
 
 from monarch.tools.components import hyperactor
-from monarch.tools.config import Config, UnnamedAppDef
+from monarch.tools.config import Config, UnnamedAppDef, Workspace
 
 from torchx import specs
 from torchx.schedulers import (
@@ -40,7 +40,7 @@ def scheduler_factories() -> dict[str, SchedulerFactory]:
     }
 
 
-def config(scheduler: str, workspace: Optional[str] = None) -> Config:
+def config(scheduler: str, workspace: Workspace = None) -> Config:
     """The default :py:class:`~monarch.tools.config.Config` to use when submitting to the provided ``scheduler``."""
     return Config(scheduler=scheduler, workspace=workspace)
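
One behavioral note worth a sketch (an inference from the `if workspace is not None` guard added in proc_mesh.py, not something the commit states): when the default config carries no workspace, `sync_workspace` builds an empty workspace list and only syncs the conda env if asked to.

from monarch.tools.config import defaults

# No workspace argument: Config.workspace stays None...
cfg = defaults.config("slurm")
assert cfg.workspace is None

# ...so a later pm.sync_workspace(workspace=cfg.workspace, conda=True) would
# skip the rsync'd workspace entirely and sync only the active conda env,
# resolved via conda_utils.active_env_dir().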

python/tests/test_python_actors.py

Lines changed: 44 additions & 0 deletions
@@ -35,6 +35,7 @@
     local_proc_mesh,
     proc_mesh,
 )
+from monarch.tools.config import defaults
 from typing_extensions import assert_type
 
 
@@ -748,6 +749,49 @@ async def test_same_actor_twice() -> None:
     ), f"Expected error message about duplicate actor name, got: {error_msg}"
 
 
+class LsActor(Actor):
+    def __init__(self, workspace: str):
+        self.workspace = workspace
+
+    @endpoint
+    async def ls(self) -> list[str]:
+        return os.listdir(self.workspace)
+
+
+async def test_sync_workspace() -> None:
+    pm = await proc_mesh(gpus=1)
+
+    # create two workspaces: one for local and one for remote
+    with tempfile.TemporaryDirectory() as workspace_src, tempfile.TemporaryDirectory() as workspace_dst, unittest.mock.patch.dict(
+        os.environ, {"WORKSPACE_DIR": workspace_dst}
+    ):
+        os.environ["WORKSPACE_DIR"] = workspace_dst
+        config = defaults.config("slurm", workspace_src)
+        await pm.sync_workspace(
+            workspace=config.workspace, conda=False, auto_reload=True
+        )
+
+        # no file in remote workspace initially
+        am = await pm.spawn("ls", LsActor, workspace_dst)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 0
+
+        # write a file to local workspace
+        file_path = os.path.join(workspace_src, "new_file")
+        with open(file_path, "w") as f:
+            f.write("hello world")
+            f.flush()
+
+        # force a sync and it should populate on the dst workspace
+        await pm.sync_workspace(config.workspace, conda=False, auto_reload=True)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 1
+            assert item[1][0] == "new_file"
+            file_path = os.path.join(workspace_dst, item[1][0])
+            with open(file_path, "r") as f:
+                assert f.readline() == "hello world"
+
+
 class TestActorMeshStop(unittest.IsolatedAsyncioTestCase):
     async def test_actor_mesh_stop(self) -> None:
         pm = proc_mesh(gpus=2)
