Skip to content

link launch and sync conda/workspace locations #742

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions python/monarch/_src/actor/proc_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import asyncio
import logging
import os
import sys
import threading
import warnings
Expand Down Expand Up @@ -70,6 +69,8 @@
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.future import DeprecatedNotAFuture, Future
from monarch._src.actor.shape import MeshTrait
from monarch.tools.config import Workspace
from monarch.tools.utils import conda as conda_utils

HAS_TENSOR_ENGINE = False
try:
Expand Down Expand Up @@ -369,7 +370,10 @@ def rank_tensors(self) -> Dict[str, "Tensor"]:
return self._device_mesh.ranks

async def sync_workspace(
self, conda: bool = False, auto_reload: bool = False
self,
workspace: Workspace = None,
conda: bool = False,
auto_reload: bool = False,
) -> None:
if self._code_sync_client is None:
self._code_sync_client = CodeSyncMeshClient.spawn_blocking(
Expand All @@ -382,21 +386,21 @@ async def sync_workspace(
# The workspace shape (i.e. only perform one rsync per host).
assert set(self._shape.labels).issubset({"gpus", "hosts"})

# TODO(agallagher): Is there a better way to infer/set the local
# workspace dir, rather than use PWD?
workspaces = [
WorkspaceConfig(
local=Path(os.getcwd()),
remote=RemoteWorkspace(
location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
shape=WorkspaceShape.shared("gpus"),
workspaces = []
if workspace is not None:
workspaces.append(
WorkspaceConfig(
local=Path(workspace),
remote=RemoteWorkspace(
location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
shape=WorkspaceShape.shared("gpus"),
),
method=CodeSyncMethod.Rsync,
),
method=CodeSyncMethod.Rsync,
),
]
)

# If `conda` is set, also sync the currently activated conda env.
conda_prefix = os.environ.get("CONDA_PREFIX")
conda_prefix = conda_utils.active_env_dir()
if conda and conda_prefix is not None:
workspaces.append(
WorkspaceConfig(
Expand Down
18 changes: 14 additions & 4 deletions python/monarch/tools/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

# pyre-strict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, TYPE_CHECKING

from torchx.specs import Role
# Defer the import of Role to avoid requiring torchx at import time
if TYPE_CHECKING:
from torchx.specs import Role


NOT_SET: str = "__NOT_SET__"
Expand All @@ -20,10 +22,18 @@ class UnnamedAppDef:
A TorchX AppDef without a name.
"""

roles: List[Role] = field(default_factory=list)
roles: List["Role"] = field(default_factory=list)
metadata: Dict[str, str] = field(default_factory=dict)


# TODO: provide a proper Workspace class to support
# - multiple workspaces
# - empty workspaces
# - no workspace
# - experimental directories
# A workspace is currently just a local directory path (or None for "no
# workspace") that gets synced to remote hosts.
# NOTE(review): the `str | None` union is evaluated at import time (it is a
# module-level assignment, not a deferred annotation), so this line requires
# Python 3.10+ — confirm that matches the project's minimum supported version.
Workspace = str | None


@dataclass
class Config:
"""
Expand All @@ -32,6 +42,6 @@ class Config:

scheduler: str = NOT_SET
scheduler_args: dict[str, Any] = field(default_factory=dict)
workspace: Optional[str] = None
workspace: Workspace = None
dryrun: bool = False
appdef: UnnamedAppDef = field(default_factory=UnnamedAppDef)
6 changes: 3 additions & 3 deletions python/monarch/tools/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

"""Defines defaults for ``monarch.tools``"""

from typing import Callable, Optional
from typing import Callable

from monarch.tools.components import hyperactor
from monarch.tools.config import Config, UnnamedAppDef
from monarch.tools.config import Config, UnnamedAppDef, Workspace

from torchx import specs
from torchx.schedulers import (
Expand Down Expand Up @@ -40,7 +40,7 @@ def scheduler_factories() -> dict[str, SchedulerFactory]:
}


def config(scheduler: str, workspace: Workspace = None) -> Config:
    """Build the default :py:class:`~monarch.tools.config.Config` for ``scheduler``.

    Args:
        scheduler: name of the scheduler the config will be submitted to.
        workspace: optional local workspace directory to sync; ``None``
            means no workspace is configured.

    Returns:
        A :py:class:`~monarch.tools.config.Config` populated with the given
        scheduler and workspace; all other fields keep their defaults.
    """
    default_config = Config(scheduler=scheduler, workspace=workspace)
    return default_config

Expand Down
44 changes: 44 additions & 0 deletions python/tests/test_python_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
local_proc_mesh,
proc_mesh,
)
from monarch.tools.config import defaults
from typing_extensions import assert_type


Expand Down Expand Up @@ -950,6 +951,49 @@ async def test_same_actor_twice() -> None:
), f"Expected error message about duplicate actor name, got: {error_msg}"


class LsActor(Actor):
    """Test helper actor that exposes a single endpoint listing a directory.

    Used to observe, from inside the proc mesh, what files a workspace sync
    has materialized in the remote workspace directory.
    """

    def __init__(self, workspace: str):
        # Path of the directory whose contents the `ls` endpoint reports.
        self.workspace = workspace

    @endpoint
    async def ls(self) -> list[str]:
        """Return the names of the entries currently in the workspace directory."""
        entries = os.listdir(self.workspace)
        return entries


async def test_sync_workspace() -> None:
    """End-to-end check that `sync_workspace` rsyncs a local dir to the remote one.

    Creates a source ("local") and destination ("remote") directory, points
    WORKSPACE_DIR at the destination, and verifies that a file written to the
    source appears in the destination after a forced sync.
    """
    pm = await proc_mesh(gpus=1)

    # create two workspaces: one for local and one for remote
    with tempfile.TemporaryDirectory() as workspace_src, tempfile.TemporaryDirectory() as workspace_dst, unittest.mock.patch.dict(
        os.environ, {"WORKSPACE_DIR": workspace_dst}
    ):
        # patch.dict above already sets WORKSPACE_DIR (and restores it on
        # exit), so no explicit os.environ assignment is needed here.
        config = defaults.config("slurm", workspace_src)
        await pm.sync_workspace(
            workspace=config.workspace, conda=False, auto_reload=True
        )

        # no file in remote workspace initially
        am = await pm.spawn("ls", LsActor, workspace_dst)
        for item in list(am.ls.call().get()):
            assert len(item[1]) == 0

        # write a file to local workspace
        file_path = os.path.join(workspace_src, "new_file")
        with open(file_path, "w") as f:
            f.write("hello world")
            f.flush()

        # force a sync and it should populate on the dst workspace
        await pm.sync_workspace(
            workspace=config.workspace, conda=False, auto_reload=True
        )
        for item in list(am.ls.call().get()):
            assert len(item[1]) == 1
            assert item[1][0] == "new_file"
            file_path = os.path.join(workspace_dst, item[1][0])
            with open(file_path, "r") as f:
                assert f.readline() == "hello world"


class TestActorMeshStop(unittest.IsolatedAsyncioTestCase):
async def test_actor_mesh_stop(self) -> None:
pm = proc_mesh(gpus=2)
Expand Down
Loading