diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py
index fe9313909..edf3ceb2a 100644
--- a/python/monarch/_src/actor/proc_mesh.py
+++ b/python/monarch/_src/actor/proc_mesh.py
@@ -78,6 +78,9 @@
         _RdmaManager,
     )
 
+    # type: ignore[import]
+    from monarch._src.tensor_engine.tcp import TCPManager  # @manual
+    # type: ignore[16]
     HAS_TENSOR_ENGINE = torch.cuda.is_available()
 except ImportError:
@@ -132,6 +135,9 @@ def __init__(
         self._slice = False
         # type: ignore[21]
         self._rdma_manager: Optional["_RdmaManager"] = None
+        # type: ignore[21]
+        self._tcp_manager: Optional["TCPManager"] = None
+
         self._debug_manager: Optional[DebugManager] = None
         self._code_sync_client: Optional[CodeSyncMeshClient] = None
         self._logging_mesh_client: Optional[LoggingMeshClient] = None
@@ -182,12 +188,20 @@ async def _init_manager_actors_coro(
             else None
         )
 
+        _tcp_manager = (
+            # type: ignore[16]
+            await self._spawn_nonblocking_on(proc_mesh, "tcp_manager", TCPManager)
+            if HAS_TENSOR_ENGINE
+            else None
+        )
+
         _debug_manager = await self._spawn_nonblocking_on(
             proc_mesh, _DEBUG_MANAGER_ACTOR_NAME, DebugManager, await _debug_client()
         )
 
         self._debug_manager = _debug_manager
         self._rdma_manager = _rdma_manager
+        self._tcp_manager = _tcp_manager
 
         if setup is not None:
             # If the user has passed the setup lambda, we need to call
diff --git a/python/monarch/_src/tensor_engine/rdma.py b/python/monarch/_src/tensor_engine/rdma.py
index fe7dddfcf..5208b5ec0 100644
--- a/python/monarch/_src/tensor_engine/rdma.py
+++ b/python/monarch/_src/tensor_engine/rdma.py
@@ -101,6 +101,14 @@ def read_into(
         Returns an ActorFuture that can be awaited or called with .get()
         for blocking operation.
         """
+        try:
+            MonarchContext.get()
+        except LookupError:
+            raise RuntimeError(
+                "RDMABuffer.read_into() can only be called from within a Monarch actor context. "
+                "Make sure you're calling this from within an actor method."
+            )
+
         _assert_tensor_is_1d_contiguous_uint8(dst)
         dst_gpu = None
         if dst.device.type != "cpu":
@@ -148,6 +156,14 @@ def write_from(
         Returns an ActorFuture that can be awaited or called with .get()
         for blocking operation.
         """
+        try:
+            MonarchContext.get()
+        except LookupError:
+            raise RuntimeError(
+                "RDMABuffer.write_from() can only be called from within a Monarch actor context. "
+                "Make sure you're calling this from within an actor method."
+            )
+
         _assert_tensor_is_1d_contiguous_uint8(src)
         src_gpu = None
         if src.device.type != "cpu":
diff --git a/python/monarch/_src/tensor_engine/tcp.py b/python/monarch/_src/tensor_engine/tcp.py
new file mode 100644
index 000000000..79ff2daca
--- /dev/null
+++ b/python/monarch/_src/tensor_engine/tcp.py
@@ -0,0 +1,372 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+This file provides a TCPBuffer, a light-weight version of RDMABuffer that works on
+any hardware.
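+
+A minimal usage sketch (illustrative only; it assumes code running inside a
+spawned Monarch actor, with ``t``, ``src``, and ``dst`` contiguous tensors):
+
+    buf = TCPBuffer(t.view(torch.uint8).flatten())
+    await buf.write_from(src.view(torch.uint8).flatten())
+    await buf.read_into(dst.view(torch.uint8).flatten())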
+"""
+
+import ctypes
+
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple
+
+import torch
+import zmq
+import zmq.asyncio
+
+from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
+from monarch._src.actor.actor_mesh import Actor, ActorMesh, MonarchContext
+from monarch._src.actor.endpoint import endpoint
+from monarch._src.actor.future import Future
+
+
+@dataclass
+class LocalTCPRecord:
+    data: torch.Tensor
+
+
+@dataclass
+class ZMQConnectionInfo:
+    """Connection information needed to establish a ZMQ connection"""
+
+    endpoint: str  # ZMQ endpoint (e.g., "tcp://127.0.0.1:5555")
+
+
+class ZMQConnection:
+    """Manages a bidirectional ZMQ connection between two TCPManagers"""
+
+    def __init__(self, context: zmq.asyncio.Context, is_server: bool = False):
+        self.context = context
+        self.is_server = is_server
+        self.send_socket: Optional[zmq.asyncio.Socket] = None
+        self.recv_socket: Optional[zmq.asyncio.Socket] = None
+        self.endpoint: Optional[str] = None
+        self.connected = False
+
+    def initialize(self) -> ZMQConnectionInfo:
+        """Initialize the connection and return connection info"""
+        if self.is_server:
+            # Server creates a PAIR socket and binds to a random port
+            self.send_socket = self.context.socket(zmq.PAIR)
+            port = self.send_socket.bind_to_random_port("tcp://127.0.0.1")
+            self.endpoint = f"tcp://127.0.0.1:{port}"
+            self.recv_socket = self.send_socket  # PAIR socket is bidirectional
+        else:
+            # Client will connect later
+            self.send_socket = self.context.socket(zmq.PAIR)
+            self.recv_socket = self.send_socket  # PAIR socket is bidirectional
+
+        return ZMQConnectionInfo(endpoint=self.endpoint or "")
+
+    def connect(self, connection_info: ZMQConnectionInfo) -> None:
+        """Connect to the remote endpoint"""
+        if not self.is_server and self.send_socket:
+            self.send_socket.connect(connection_info.endpoint)
+            self.endpoint = connection_info.endpoint
+        # The server side is already bound; either way the pair is now usable.
+        self.connected = True
+
+    async def send(self, data: bytes) -> None:
+        """Send data through the connection"""
+        if not self.connected or not self.send_socket:
+            raise RuntimeError("Connection not established")
+        await self.send_socket.send(data)
+
+    async def recv(self) -> bytes | zmq.Frame:
+        """Receive data from the connection"""
+        if not self.connected or not self.recv_socket:
+            raise RuntimeError("Connection not established")
+        return await self.recv_socket.recv()
+
+    def close(self) -> None:
+        """Close the connection"""
+        if self.send_socket:
+            self.send_socket.close()
+        if self.recv_socket and self.recv_socket != self.send_socket:
+            self.recv_socket.close()
+        self.connected = False
+
+
+_local_buffers: Dict[int, "LocalTCPRecord"] = {}
+
+
+def _get_bytes(storage: torch.Tensor, offset: int, size: int) -> bytearray:
+    """Extracts a bytearray from a 1D, 1 byte per item tensor."""
+    if offset + size > storage.numel():
+        raise ValueError(f"Read out of range: {offset + size} > {storage.numel()}")
+    addr = storage.data_ptr()
+    if storage.device.type != "cpu":
+        result = bytearray(size)
+        result_tensor = torch.frombuffer(
+            result,
+            dtype=torch.uint8,
+        )
+        source_tensor = storage[offset : offset + size]
+        result_tensor.copy_(source_tensor)
+    else:
+        # data_ptr() points at the first element of the tensor, so the
+        # requested offset must be added before reading.
+        ctypes_array = (ctypes.c_byte * size).from_address(addr + offset)
+        result = bytearray(ctypes_array)
+    return result
+
+
+class TCPManager(Actor):
+    # Note - we go through ZMQ instead of Monarch's TCP implementation
+    # to bypass Rust limitations we've seen...
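+    #
+    # Connection setup between two managers is a three-step handshake (driven
+    # by request_connection below): initialize_connection() creates a PAIR
+    # socket on each side, connection_info() shares the server's bound
+    # endpoint, and connect() dials it from the client side. The side whose
+    # actor id compares lower binds as the server.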
+    def __init__(self):
+        # Map between ActorIds and their corresponding ZMQConnection
+        self.connection_map: Dict[ActorId, ZMQConnection] = {}
+        # ZMQ context for managing sockets (lazy-initialized)
+        self._zmq_context: Optional[zmq.asyncio.Context] = None
+
+    @property
+    def zmq_context(self) -> zmq.asyncio.Context:
+        """Lazy-initialize ZMQ context to avoid serialization issues"""
+        if self._zmq_context is None:
+            self._zmq_context = zmq.asyncio.Context()
+        return self._zmq_context
+
+    def __reduce__(self) -> Tuple[type, Tuple[()]]:
+        """
+        Custom pickle reduction that only preserves the class type.
+        Similar to how ActorMeshRef handles pickling - we don't serialize
+        the ZMQ connections or context, just recreate a fresh TCPManager.
+        """
+        return (self.__class__, ())
+
+    @staticmethod
+    def on_proc(proc_id: str) -> "ActorMesh[TCPManager]":
+        ctx = MonarchContext.get()
+        return ActorMesh.from_actor_id(
+            Class=TCPManager,
+            actor_id=ActorId.from_string(f"{proc_id}.tcp_manager[0]"),
+            mailbox=ctx.mailbox,
+        )
+
+    @endpoint
+    async def drop(self, addr: int) -> None:
+        if addr in _local_buffers:
+            del _local_buffers[addr]
+
+    @endpoint
+    async def fetch(self, addr: int, offset: int, nbytes: int) -> bytearray:
+        if addr not in _local_buffers:
+            raise ValueError(f"Unknown buffer {addr}")
+        storage = _local_buffers[addr].data
+        return _get_bytes(storage, offset, nbytes)
+
+    @endpoint
+    async def put(self, addr: int, offset: int, bytes: bytearray) -> None:
+        if addr not in _local_buffers:
+            raise ValueError(f"Unknown buffer {addr}")
+        storage = _local_buffers[addr].data
+        storage[offset : offset + len(bytes)] = torch.frombuffer(
+            bytes, dtype=storage.dtype
+        )
+
+    def _is_connected(self, other_id: ActorId) -> bool:
+        """Check if connected to another TCPManager"""
+        if other_id not in self.connection_map:
+            return False
+        return self.connection_map[other_id].connected
+
+    @endpoint
+    def is_connected(self, other_id: ActorId) -> bool:
+        """Check if connected to another TCPManager"""
+        return self._is_connected(other_id)
+
+    def _initialize_connection(self, remote_id: ActorId) -> bool:
+        """Initialize a new ZMQ connection with another TCPManager"""
+        if remote_id in self.connection_map:
+            return True  # Already initialized
+
+        # Determine if this actor should be the server (based on actor ID comparison)
+        current_id = MonarchContext.get().mailbox.actor_id
+        is_server = str(current_id) < str(remote_id)
+
+        connection = ZMQConnection(self.zmq_context, is_server=is_server)
+        connection.initialize()
+        self.connection_map[remote_id] = connection
+
+        return True
+
+    @endpoint
+    async def initialize_connection(self, other_id: ActorId) -> bool:
+        """Initialize a new ZMQ connection with another TCPManager"""
+        return self._initialize_connection(other_id)
+
+    def _connection_info(self, other_id: ActorId) -> ZMQConnectionInfo:
+        """Get connection information for establishing a ZMQ connection"""
+        if other_id not in self.connection_map:
+            raise ValueError(f"No connection initialized for actor {other_id}")
+
+        connection = self.connection_map[other_id]
+        # The client side has no bound endpoint until it connects; report an
+        # empty endpoint instead of failing so the server-initiated handshake
+        # in request_connection can still proceed.
+        return ZMQConnectionInfo(endpoint=connection.endpoint or "")
+
+    @endpoint
+    async def connection_info(self, other_id: ActorId) -> ZMQConnectionInfo:
+        """Get connection information for establishing a ZMQ connection"""
+        return self._connection_info(other_id)
+
+    def _connect(self, other_id: ActorId, connection_info: ZMQConnectionInfo) -> None:
+        """Establish connection with another TCPManager using provided connection info"""
+        if other_id not in self.connection_map:
+            raise ValueError(f"No connection initialized for actor {other_id}")
+
+        connection = self.connection_map[other_id]
+        connection.connect(connection_info)
+
+    @endpoint
+    async def connect(self, other_id: ActorId, connection_info: ZMQConnectionInfo):
+        """Establish connection with another TCPManager using provided connection info"""
+        self._connect(other_id, connection_info)
+
+    @endpoint
+    async def request_connection(self, remote_id: ActorId) -> ZMQConnection:
+        """
+        Main method to get/create connections with another TCPManager.
+        Similar to RDMAManager's request_queue_pair.
+        """
+        current_id = MonarchContext.get().mailbox.actor_id
+
+        if not self._is_connected(remote_id):
+            is_loopback = remote_id == current_id
+
+            if is_loopback:
+                self._initialize_connection(remote_id)
+                connection_info = self._connection_info(remote_id)
+                self._connect(remote_id, connection_info)
+            else:
+                # Get remote TCPManager reference
+                remote_tcp_manager = TCPManager.on_proc(remote_id.proc_id)
+
+                # Initialize connections on both sides
+                self._initialize_connection(remote_id)
+                # pyre-ignore[16]: Endpoint is not propagating through on_proc.
+                await remote_tcp_manager.initialize_connection.call_one(current_id)
+
+                # Exchange connection information
+                remote_connection_info = (
+                    await remote_tcp_manager.connection_info.call_one(current_id)
+                )
+                self._connect(remote_id, remote_connection_info)
+
+                local_connection_info = self._connection_info(remote_id)
+                await remote_tcp_manager.connect.call_one(
+                    current_id, local_connection_info
+                )
+
+        connection = self.connection_map.get(remote_id)
+        if not connection:
+            raise RuntimeError(f"Failed to establish connection with {remote_id}")
+
+        return connection
+
+
+def _assert_tensor_is_1d_contiguous_uint8(t: torch.Tensor) -> None:
+    if t.ndim != 1:
+        raise ValueError(f"Tensor must be 1D, got {t.ndim}D")
+    if t.dtype != torch.uint8:
+        raise ValueError(f"Tensor must be uint8, got {t.dtype}")
+    if not t.is_contiguous():
+        raise ValueError("Tensor must be contiguous")
+
+
+class TCPBuffer:
+    def __init__(self, data: torch.Tensor) -> None:
+        """
+        TCPBuffer only supports 1D contiguous tensors that are 1 byte per item.
+
+        To create a 1 byte, 1D view, use t.view(torch.uint8).flatten()
+        """
+        _assert_tensor_is_1d_contiguous_uint8(data)
+        assert data.storage_offset() == 0
+        storage = data.untyped_storage()
+        self.addr: int = storage.data_ptr()
+        self.begin = 0
+        self.end: int = storage.size()
+        self.proc_id: str = MonarchContext.get().proc_id
+        self.local_data: object = None
+        _local_buffers[self.addr] = LocalTCPRecord(data)
+
+    def drop(self) -> None:
+        if self.proc_id is None:
+            del _local_buffers[self.addr]
+            return
+        tcp_manager = TCPManager.on_proc(self.proc_id)
+        # pyre-ignore[16]: Undefined attribute [16]: `Endpoint` has no attribute `cast`.
+        tcp_manager.drop.cast(self.addr)
+
+    # Pickling a TCPBuffer ships only (addr, begin, end, proc_id); the actual
+    # bytes stay in _local_buffers on the owning process and are moved on
+    # demand through that process's TCPManager (fetch/put above).
+    def __getstate__(self) -> Tuple[int, int, int, Optional[str]]:
+        proc_id = self.proc_id
+        # locally created TCPBuffer being sent remotely,
+        # record its proc_id so we know how to establish connections to it
+        if proc_id is None:
+            proc_id = MonarchContext.get().proc_id
+        return (self.addr, self.begin, self.end, proc_id)
+
+    def __setstate__(self, state: Tuple[int, int, int, str]) -> None:
+        self.local_data = None
+        self.addr, self.begin, self.end, self.proc_id = state
+
+    def read_into(
+        self, dst: torch.Tensor, offset: int = 0, *args, **kwargs
+    ) -> Future[None]:
+        """
+        Read data from the TCPBuffer into a destination tensor.
+
+        The destination tensor must be contiguous and 1 byte per item.
+        """
+        try:
+            MonarchContext.get()
+        except LookupError:
+            raise RuntimeError(
+                "TCPBuffer.read_into() can only be called from within a Monarch actor context. "
+                "Make sure you're calling this from within an actor method."
+            )
+
+        _assert_tensor_is_1d_contiguous_uint8(dst)
+
+        # pyre-ignore[16]: Endpoint is not propagating through on_proc.
+        bytes_future = TCPManager.on_proc(self.proc_id).fetch.call_one(
+            self.addr, offset, dst.numel()
+        )
+
+        async def coro() -> None:
+            bytes_ = await bytes_future
+            dst.copy_(torch.frombuffer(bytes_, dtype=torch.uint8))
+
+        return Future(coro=coro())
+
+    def write_from(
+        self, src: torch.Tensor, offset: int = 0, *args, **kwargs
+    ) -> Future[None]:
+        """
+        Write data from a source tensor into the TCPBuffer.
+
+        The source tensor must be contiguous and 1 byte per item.
+        """
+        # Check if we're in a Monarch context
+        try:
+            MonarchContext.get()
+        except LookupError:
+            raise RuntimeError(
+                "TCPBuffer.write_from() can only be called from within a Monarch actor context. "
+                "Make sure you're calling this from within an actor method."
+            )
+
+        _assert_tensor_is_1d_contiguous_uint8(src)
+        # src.data_ptr() already accounts for the tensor's storage offset, so
+        # read from the start of the 1D source view.
+        bytes_ = _get_bytes(src, 0, src.numel())
+        # pyre-ignore[16]: Endpoint is not propagating through on_proc.
+        return TCPManager.on_proc(self.proc_id).put.call_one(self.addr, offset, bytes_)
diff --git a/python/monarch/tensor_engine/__init__.py b/python/monarch/tensor_engine/__init__.py
index 172a7a5b1..4870a9e2a 100644
--- a/python/monarch/tensor_engine/__init__.py
+++ b/python/monarch/tensor_engine/__init__.py
@@ -14,10 +14,13 @@
     RDMAReadTransferWarning,
     RDMAWriteTransferWarning,
 )
+from monarch._src.tensor_engine.tcp import TCPBuffer, TCPManager
 
 __all__ = [
     "is_available",
     "RDMABuffer",
     "RDMAReadTransferWarning",
     "RDMAWriteTransferWarning",
+    "TCPBuffer",
+    "TCPManager",
 ]
diff --git a/python/tests/test_tcp_buffer.py b/python/tests/test_tcp_buffer.py
new file mode 100644
index 000000000..b87d628a3
--- /dev/null
+++ b/python/tests/test_tcp_buffer.py
@@ -0,0 +1,191 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
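+
+# These tests exercise TCPBuffer end to end: a parameter-server round trip
+# over CPU and GPU tensors, then a trainer/generator weight-sync loop across
+# two separate proc meshes.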
+
+import pytest
+
+import torch
+from monarch.actor import Actor, current_rank, endpoint, proc_mesh
+from monarch.tensor_engine import TCPBuffer
+
+
+needs_cuda = pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="CUDA not available",
+)
+
+
+class ParameterServer(Actor):
+    def __init__(self):
+        self.params = torch.rand(10, 10)
+        self.grad_buffer = torch.rand(10, 10)
+
+    @endpoint
+    async def grad_handle(self) -> TCPBuffer:
+        byte_tensor = self.grad_buffer.view(torch.uint8).flatten()
+        buffer = TCPBuffer(byte_tensor)
+        return buffer
+
+    @endpoint
+    async def update(self):
+        self.params += 0.01 * self.grad_buffer
+
+    @endpoint
+    async def get_grad_buffer(self) -> torch.Tensor:
+        # just used for testing
+        return self.grad_buffer
+
+
+class ParameterClient(Actor):
+    def __init__(self, server, buffer):
+        self.server = server
+        byte_tensor = buffer.view(torch.uint8).flatten()
+        self.buffer = byte_tensor
+
+    @endpoint
+    async def upload(self, tensor):
+        gh = await self.server.grad_handle.call_one()
+        await gh.write_from(tensor)
+
+    @endpoint
+    async def download(self):
+        gh = await self.server.grad_handle.call_one()
+        await gh.read_into(self.buffer)
+
+    @endpoint
+    async def get_buffer(self):
+        return self.buffer
+
+
+@needs_cuda
+async def test_proc_mesh_tcp():
+    proc = await proc_mesh(gpus=1)
+    server = await proc.spawn("server", ParameterServer)
+
+    # --- CPU TESTS ---
+    client_cpu = await proc.spawn(
+        "client_cpu", ParameterClient, server, torch.ones(10, 10)
+    )
+    x = await client_cpu.get_buffer.call_one()
+    assert torch.sum(x.view(torch.float32).view(10, 10)) == 100
+    zeros = torch.zeros(10, 10)
+    await client_cpu.upload.call_one(zeros.view(torch.uint8).flatten())
+    await client_cpu.download.call_one()
+    x = await client_cpu.get_buffer.call_one()
+    assert torch.sum(x.view(torch.float32).view(10, 10)) == 0
+
+    # --- Modify server's backing buffer directly ---
+    await server.update.call_one()
+
+    # Should reflect updated values
+    await client_cpu.download.call_one()
+
+    buffer = await client_cpu.get_buffer.call_one()
+    remote_grad = await server.get_grad_buffer.call_one()
+    assert torch.allclose(buffer.view(torch.float32).view(10, 10), remote_grad)
+
+    # --- GPU TESTS ---
+    client_gpu = await proc.spawn(
+        "client_gpu", ParameterClient, server, torch.ones(10, 10, device="cuda")
+    )
+    x = await client_gpu.get_buffer.call_one()
+    buffer = x.view(torch.float32).view(10, 10)
+    assert torch.sum(buffer) == 100
+    zeros = torch.zeros(10, 10, device="cuda")
+    await client_gpu.upload.call_one(zeros.view(torch.uint8).flatten())
+    await client_gpu.download.call_one()
+    x = await client_gpu.get_buffer.call_one()
+    buffer_gpu = x.view(torch.float32).view(10, 10)
+    assert torch.sum(buffer_gpu) == 0
+    # copying a tensor across hosts moves it to CPU
+    assert buffer_gpu.device.type == "cpu"
+
+    # Modify server state again
+    await server.update.call_one()
+    await client_gpu.download.call_one()
+    x = await client_gpu.get_buffer.call_one()
+    buffer_gpu = x.view(torch.float32).view(10, 10)
+    remote_grad = await server.get_grad_buffer.call_one()
+    assert torch.allclose(buffer_gpu.cpu(), remote_grad)
+
+
+class TrainerActor(Actor):
+    def __init__(self):
+        super().__init__()
+        # TODO - switch to CUDA once GPU support is added
+        self.trainer = torch.nn.Linear(10, 10).to("cpu")
+        self.trainer.weight.data.zero_()
+
+    @endpoint
+    async def init(self, gen):
+        ranks = current_rank()
+        self.gen = gen.slice(**ranks)
+
+    @endpoint
+    async def exchange_metadata(self):
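+        # Wrap the trainer's weights in a TCPBuffer and hand the handle to
+        # every generator rank; readers later pull the bytes on demand.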
+        byte_tensor = self.trainer.weight.data.view(torch.uint8).flatten()
+        self.handle = TCPBuffer(byte_tensor)
+        await self.gen.attach_weight_buffer.call(self.handle)
+
+    @endpoint
+    async def weights_ready(self):
+        self.trainer.weight.data.add_(1.0)
+
+
+class GeneratorActor(Actor):
+    def __init__(self):
+        super().__init__()
+        self.generator = torch.nn.Linear(10, 10).to("cuda")
+        self.step = 0
+
+    @endpoint
+    async def init(self, trainer):
+        ranks = current_rank()
+        self.trainer = trainer.slice(**ranks)
+
+    @endpoint
+    async def attach_weight_buffer(self, handle):
+        self.handle = handle
+
+    @endpoint
+    async def update_weights(self):
+        self.step += 1
+        byte_tensor = self.generator.weight.data.view(torch.uint8).flatten()
+        await self.handle.read_into(byte_tensor)
+        assert (
+            torch.sum(self.generator.weight.data) == self.step * 100
+        ), f"{torch.sum(self.generator.weight.data)=}, {self.step=}"
+
+
+@needs_cuda
+async def test_gpu_trainer_generator():
+    trainer_proc = await proc_mesh(gpus=1)
+    gen_proc = await proc_mesh(gpus=1)
+    trainer = await trainer_proc.spawn("trainer", TrainerActor)
+    generator = await gen_proc.spawn("gen", GeneratorActor)
+
+    await generator.init.call(trainer)
+    await trainer.init.call(generator)
+    await trainer.exchange_metadata.call()
+
+    for _ in range(3):
+        await trainer.weights_ready.call()
+        await generator.update_weights.call()
+
+
+@needs_cuda
+def test_gpu_trainer_generator_sync() -> None:
+    trainer_proc = proc_mesh(gpus=1).get()
+    gen_proc = proc_mesh(gpus=1).get()
+    trainer = trainer_proc.spawn("trainer", TrainerActor).get()
+    generator = gen_proc.spawn("gen", GeneratorActor).get()
+
+    generator.init.call(trainer).get()
+    trainer.init.call(generator).get()
+    trainer.exchange_metadata.call().get()
+
+    for _ in range(1):
+        trainer.weights_ready.call().get()
+        generator.update_weights.call().get()