
Commit 3024d38

mryab and justheuristic authored
Support learning rate schedulers in ExpertBackend (#196)
* Add empty __init__ to hivemind_cli for correct package discovery
* Support learning rate schedulers in ExpertBackend
* Save/load full expert state
* Don't pass compression to make_empty
* spawn -> fork
* Remove load_expert_states
* Make TaskPoolBase an abstract class
* Output warning if some of the keys in state_dict are missing

Co-authored-by: justheuristic <[email protected]>
1 parent f132294 · commit 3024d38

21 files changed: +384 −227 lines
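For reference, a minimal usage sketch of the new Server.create signature introduced by this commit. The concrete values below (expert pattern, step counts) are illustrative, and only the default scheduler name 'none' is visible in this diff; other scheduler names depend on schedule_name_to_scheduler, which is defined in a file not shown here.

import torch
from hivemind.server import Server

# Sketch only: start a server whose experts use the new LR-scheduler arguments.
server = Server.create(
    listen_on='0.0.0.0:*',
    num_experts=2,                    # UIDs are sampled from expert_pattern
    expert_pattern='expert.[0:256]',
    expert_cls='ffn',
    hidden_dim=1024,
    optim_cls=torch.optim.Adam,
    scheduler='none',                 # name is looked up in schedule_name_to_scheduler
    num_warmup_steps=1000,            # consumed by the LR schedule, if one is selected
    num_training_steps=10000,
    checkpoint_dir=None,              # pass a directory to save/load full expert state
    start=True,
)
# ... serve requests ...
server.shutdown()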

.circleci/config.yml

Lines changed: 6 additions & 6 deletions
@@ -8,11 +8,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:
@@ -28,11 +28,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:
@@ -48,11 +48,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:

hivemind/hivemind_cli/__init__.py

Whitespace-only changes.

hivemind/hivemind_cli/run_server.py

Lines changed: 8 additions & 1 deletion
@@ -8,6 +8,7 @@
 from hivemind.server import Server
 from hivemind.utils.threading import increase_file_limit
 from hivemind.utils.logging import get_logger
+from hivemind.server.layers import schedule_name_to_scheduler

 logger = get_logger(__name__)

@@ -28,13 +29,20 @@ def main():
     parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
                         help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
     parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
+
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
                         help='The total number of examples in the same batch will not exceed this value')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
+
     parser.add_argument('--optimizer', type=str, default='adam', required=False, help='adam, sgd or none')
+    parser.add_argument('--scheduler', type=str, choices=schedule_name_to_scheduler.keys(), default='none',
+                        help='LR scheduler type to use')
+    parser.add_argument('--num-warmup-steps', type=int, required=False, help='the number of warmup steps for LR schedule')
+    parser.add_argument('--num-training-steps', type=int, required=False, help='the total number of steps for LR schedule')
+
     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                         help='one or more peers that can welcome you to the dht, e.g. 1.2.3.4:1337 192.132.231.4:4321')
@@ -45,7 +53,6 @@ def main():
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression '
                         'parameter for grpc. Can be NONE, MEANSTD or FLOAT16')
     parser.add_argument('--checkpoint_dir', type=Path, required=False, help='Directory to store expert checkpoints')
-    parser.add_argument('--load_experts', action='store_true', help='Load experts from the checkpoint directory')

     # fmt:on
     args = vars(parser.parse_args())
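The --scheduler choices come from schedule_name_to_scheduler in hivemind/server/layers, one of the changed files not shown in this excerpt. Below is a hypothetical sketch of the kind of name-to-factory mapping it could contain; the 'linear' entry and the get_linear_schedule_with_warmup helper are assumptions for illustration, not code from this commit.

from torch.optim.lr_scheduler import LambdaLR

def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    # Linear warmup followed by linear decay, expressed as an LR multiplier.
    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)
        remaining = num_training_steps - current_step
        return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))
    return LambdaLR(optimizer, lr_lambda)

# Hypothetical mapping: argparse exposes the keys as --scheduler choices,
# Server.create resolves the chosen name to a scheduler factory (or None).
schedule_name_to_scheduler = {
    'linear': get_linear_schedule_with_warmup,  # assumed entry
    'none': None,                               # default: no LR scheduler
}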

hivemind/server/__init__.py

Lines changed: 48 additions & 94 deletions
@@ -2,23 +2,22 @@

 import multiprocessing as mp
 import multiprocessing.synchronize
-import random
 import threading
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, Optional, Tuple, List
+from typing import Dict, Optional, Tuple
 from pathlib import Path

 import torch

 import hivemind
 from hivemind.dht import DHT
-from hivemind.server.expert_uid import UID_DELIMITER
-from hivemind.server.checkpoints import CheckpointSaver, load_weights, dir_is_correct
+from hivemind.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
+from hivemind.server.checkpoints import CheckpointSaver, load_experts, is_directory
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.server.layers import name_to_block, name_to_input
+from hivemind.server.layers import name_to_block, name_to_input, schedule_name_to_scheduler
 from hivemind.server.runtime import Runtime
 from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
 from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
@@ -68,11 +67,12 @@ def __init__(
         if start:
             self.run_in_background(await_ready=True)

-    @staticmethod
-    def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
-               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               load_experts=False, compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
+    @classmethod
+    def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
+               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
+               num_warmup_steps=None, num_training_steps=None, num_handlers=None, max_batch_size=4096, device=None,
+               no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
+               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -85,16 +85,20 @@ def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = No
         :param num_handlers: server will use this many parallel processes to handle incoming requests
         :param max_batch_size: total num examples in the same batch will not exceed this value
         :param device: all experts will use this device in torch notation; default: cuda if available else cpu
+
         :param optim_cls: uses this optimizer to train all experts
+        :param scheduler: if not `none`, the name of the expert LR scheduler
+        :param num_warmup_steps: the number of warmup steps for LR schedule
+        :param num_training_steps: the total number of steps for LR schedule
+
         :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: a list of peers that will introduce this node to the dht,\
           e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers

         :param dht_port: DHT node will listen on this port, default = find open port
         You can then use this node as initial peer for subsequent servers.

-        :param checkpoint_dir: directory to save expert checkpoints
-        :param load_experts: whether to load expert checkpoints from checkpoint_dir
+        :param checkpoint_dir: directory to save and load expert checkpoints

         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
            hosted on this server. For a more fine-grained compression, start server in python and specify compression
@@ -113,23 +117,29 @@ def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = No
             dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
             logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")

-        if load_experts:
-            assert dir_is_correct(checkpoint_dir)
-            assert expert_uids is None, "Can't both load saved experts and create new ones from given UIDs"
-            expert_uids = [child.name for child in checkpoint_dir.iterdir() if (child / 'checkpoint_last.pt').exists()]
-            if expert_uids:
-                logger.info(f"Located checkpoints for experts {expert_uids}, ignoring UID generation options")
-            else:
-                logger.info(f"No expert checkpoints found in {checkpoint_dir}, generating...")
-
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
-            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
+        assert ((expert_pattern is None and num_experts is None and expert_uids is not None) or
+                (num_experts is not None and expert_uids is None)), \
+            "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"

-        # get expert uids if not loaded previously
         if expert_uids is None:
-            assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
-            logger.info(f"Generating expert uids from pattern {expert_pattern}")
-            expert_uids = generate_uids_from_pattern(num_experts, expert_pattern, dht=dht)
+            if checkpoint_dir is not None:
+                assert is_directory(checkpoint_dir)
+                expert_uids = [child.name for child in checkpoint_dir.iterdir() if
+                               (child / 'checkpoint_last.pt').exists()]
+                total_experts_in_checkpoint = len(expert_uids)
+                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
+
+                if total_experts_in_checkpoint > num_experts:
+                    raise ValueError(
+                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
+                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints.")
+            else:
+                expert_uids = []
+
+            uids_to_generate = num_experts - len(expert_uids)
+            if uids_to_generate > 0:
+                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
+                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))

         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
@@ -142,6 +152,8 @@ def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = No
         else:
             args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)

+        scheduler = schedule_name_to_scheduler[scheduler]
+
         # initialize experts
         experts = {}
         for expert_uid in expert_uids:
@@ -150,15 +162,17 @@ def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = No
                                                          args_schema=args_schema,
                                                          outputs_schema=hivemind.BatchTensorDescriptor(
                                                              hidden_dim, compression=compression),
-                                                         opt=optim_cls(expert.parameters()),
+                                                         optimizer=optim_cls(expert.parameters()),
+                                                         scheduler=scheduler,
+                                                         num_warmup_steps=num_warmup_steps,
+                                                         num_training_steps=num_training_steps,
                                                          max_batch_size=max_batch_size)

-        if load_experts:
-            load_weights(experts, checkpoint_dir)
+        if checkpoint_dir is not None:
+            load_experts(experts, checkpoint_dir)

-        server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
-                        start=start)
-        return server
+        return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
+                   checkpoint_dir=checkpoint_dir, start=start)

     def run(self):
         """
@@ -241,7 +255,7 @@ def shutdown(self):
 def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.get_context("spawn").Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
+    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)

     try:
         runner.start()
@@ -269,63 +283,3 @@ def _server_runner(pipe, *args, **kwargs):
         server.shutdown()
         server.join()
     logger.info("Server shut down.")
-
-
-def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
-                               attempts_per_expert=10) -> List[str]:
-    """
-    Sample experts from a given pattern, remove duplicates.
-    :param num_experts: sample this many unique expert uids
-    :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
-       means "sample random experts between myprefix.0.0 and myprefix.255.255;
-    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
-    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
-    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
-       a small chance of sampling duplicate expert uids.
-    """
-    remaining_attempts = attempts_per_expert * num_experts
-    found_uids, attempted_uids = list(), set()
-
-    def _generate_uid():
-        if expert_pattern is None:
-            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
-
-        uid = []
-        for block in expert_pattern.split(UID_DELIMITER):
-            try:
-                if '[' not in block and ']' not in block:
-                    uid.append(block)
-                elif block.startswith('[') and block.endswith(']') and ':' in block:
-                    slice_start, slice_end = map(int, block[1:-1].split(':'))
-                    uid.append(str(random.randint(slice_start, slice_end - 1)))
-                else:
-                    raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt as e:
-                raise e
-            except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return UID_DELIMITER.join(uid)
-
-    while remaining_attempts > 0 and len(found_uids) < num_experts:
-
-        # 1. sample new expert uids at random
-        new_uids = []
-        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
-            new_uid = _generate_uid()
-            remaining_attempts -= 1
-            if new_uid not in attempted_uids:
-                attempted_uids.add(new_uid)
-                new_uids.append(new_uid)
-
-        # 2. look into DHT (if given) and remove duplicates
-        if dht:
-            existing_expert_uids = {found_expert.uid for found_expert in dht.get_experts(new_uids)
-                                    if found_expert is not None}
-            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
-
-        found_uids += new_uids
-
-    if len(found_uids) != num_experts:
-        logger.warning(f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
-                       f"{attempts_per_expert * num_experts} attempts")
-    return found_uids
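The new UID-resolution logic in Server.create first reuses any UIDs found as checkpoints under checkpoint_dir, then tops the list up to num_experts with generate_uids_from_pattern, which this commit moves into hivemind.server.expert_uid. A small standalone usage sketch of that helper, using the pattern form documented in its docstring:

from hivemind.server.expert_uid import generate_uids_from_pattern

# Sample four unique UIDs such as 'expert.17.203'; passing a DHT instance as the
# third argument would additionally skip UIDs already claimed by other peers.
uids = generate_uids_from_pattern(4, "expert.[0:32].[0:256]")
print(uids)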

hivemind/server/checkpoints.py

Lines changed: 15 additions & 8 deletions
@@ -1,17 +1,20 @@
+import os
 import threading
 from datetime import datetime
 from pathlib import Path
 from shutil import copy2
 from tempfile import TemporaryDirectory
 from typing import Dict
-import os

 import torch

 from hivemind.server.expert_backend import ExpertBackend
+from hivemind.utils.logging import get_logger

+logger = get_logger(__name__)

-def dir_is_correct(directory: Path):
+
+def is_directory(directory: Path):
     assert directory is not None
     assert directory.exists()
     assert directory.is_dir()
@@ -33,7 +36,7 @@ def copy_tree(src: str, dst: str):
 class CheckpointSaver(threading.Thread):
     def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
         super().__init__()
-        assert dir_is_correct(checkpoint_dir)
+        assert is_directory(checkpoint_dir)
         self.expert_backends = expert_backends
         self.update_period = update_period
         self.checkpoint_dir = checkpoint_dir
@@ -48,21 +51,25 @@ def run(self) -> None:


 def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
-    assert dir_is_correct(checkpoint_dir)
+    logger.debug(f'Storing experts at {checkpoint_dir.absolute()}')
+    assert is_directory(checkpoint_dir)
     timestamp = datetime.now().isoformat(sep='_')
     with TemporaryDirectory() as tmpdirname:
         for expert_name, expert_backend in experts.items():
             expert_dir = Path(tmpdirname) / expert_name
             expert_dir.mkdir()
             checkpoint_name = expert_dir / f'checkpoint_{timestamp}.pt'
-            torch.save(expert_backend.state_dict(), checkpoint_name)
+            torch.save(expert_backend.get_full_state(), checkpoint_name)
             os.symlink(checkpoint_name, expert_dir / 'checkpoint_last.pt')
         copy_tree(tmpdirname, str(checkpoint_dir))


-def load_weights(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
-    assert dir_is_correct(checkpoint_dir)
+def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+    assert is_directory(checkpoint_dir)
     for expert_name, expert in experts.items():
         checkpoints_folder = checkpoint_dir / expert_name
         latest_checkpoint = checkpoints_folder / 'checkpoint_last.pt'
-        expert.load_state_dict(torch.load(latest_checkpoint))
+        if latest_checkpoint.exists():
+            expert.load_full_state(torch.load(latest_checkpoint))
+        else:
+            logger.warning(f'Failed to load checkpoint for expert {expert_name}')
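store_experts/load_experts now round-trip ExpertBackend.get_full_state() and load_full_state() instead of a bare state_dict(); the ExpertBackend side of that change lives in expert_backend.py, which is not part of this excerpt. A generic sketch of the idea, assuming the full state bundles module, optimizer and scheduler state dicts (the function and key names below are illustrative, not the library's API):

import torch

def save_full_state(module, optimizer, scheduler, path):
    # Persist everything needed to resume training, not just the module weights.
    state = {
        'module': module.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict() if scheduler is not None else None,
    }
    torch.save(state, path)

def restore_full_state(module, optimizer, scheduler, path):
    state = torch.load(path)
    # strict=False tolerates missing keys; per the commit message, a warning is
    # emitted when some keys of the state_dict are missing.
    module.load_state_dict(state['module'], strict=False)
    optimizer.load_state_dict(state['optimizer'])
    if scheduler is not None and state.get('scheduler') is not None:
        scheduler.load_state_dict(state['scheduler'])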

hivemind/server/connection_handler.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 logger = get_logger(__name__)


-class ConnectionHandler(mp.Process):
+class ConnectionHandler(mp.context.ForkProcess):
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.

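Subclassing mp.context.ForkProcess (available on Unix only) pins the connection handler to the fork start method even when the interpreter's default start method is spawn, in line with the "spawn -> fork" note in the commit message. A minimal standalone illustration of that behaviour, not taken from the repository:

import multiprocessing as mp

class ForkOnlyWorker(mp.context.ForkProcess):
    # A subclass of ForkProcess is always created via fork (Unix only),
    # regardless of the globally configured start method.
    def run(self):
        print("running in a forked child process")

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # the global default is spawn...
    worker = ForkOnlyWorker()
    worker.start()                            # ...but this worker is still forked
    worker.join()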