@@ -2,23 +2,22 @@

 import multiprocessing as mp
 import multiprocessing.synchronize
-import random
 import threading
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, Optional, Tuple, List
+from typing import Dict, Optional, Tuple
 from pathlib import Path

 import torch

 import hivemind
 from hivemind.dht import DHT
-from hivemind.server.expert_uid import UID_DELIMITER
-from hivemind.server.checkpoints import CheckpointSaver, load_weights, dir_is_correct
+from hivemind.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
+from hivemind.server.checkpoints import CheckpointSaver, load_experts, is_directory
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.server.layers import name_to_block, name_to_input
+from hivemind.server.layers import name_to_block, name_to_input, schedule_name_to_scheduler
 from hivemind.server.runtime import Runtime
 from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
 from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
@@ -68,11 +67,12 @@ def __init__(
         if start:
             self.run_in_background(await_ready=True)

-    @staticmethod
-    def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
-               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               load_experts=False, compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
+    @classmethod
+    def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
+               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
+               num_warmup_steps=None, num_training_steps=None, num_handlers=None, max_batch_size=4096, device=None,
+               no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
+               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -85,16 +85,20 @@ def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = No
         :param num_handlers: server will use this many parallel processes to handle incoming requests
         :param max_batch_size: total num examples in the same batch will not exceed this value
         :param device: all experts will use this device in torch notation; default: cuda if available else cpu
+
         :param optim_cls: uses this optimizer to train all experts
+        :param scheduler: if not `none`, the name of the expert LR scheduler
+        :param num_warmup_steps: the number of warmup steps for LR schedule
+        :param num_training_steps: the total number of steps for LR schedule
+
         :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: a list of peers that will introduce this node to the dht,\
           e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers

         :param dht_port: DHT node will listen on this port, default = find open port
           You can then use this node as initial peer for subsequent servers.

-        :param checkpoint_dir: directory to save expert checkpoints
-        :param load_experts: whether to load expert checkpoints from checkpoint_dir
+        :param checkpoint_dir: directory to save and load expert checkpoints
 
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
             hosted on this server. For a more fine-grained compression, start server in python and specify compression
@@ -113,23 +117,29 @@ def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = No
             dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
             logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")

-        if load_experts:
-            assert dir_is_correct(checkpoint_dir)
-            assert expert_uids is None, "Can't both load saved experts and create new ones from given UIDs"
-            expert_uids = [child.name for child in checkpoint_dir.iterdir() if (child / 'checkpoint_last.pt').exists()]
-            if expert_uids:
-                logger.info(f"Located checkpoints for experts {expert_uids}, ignoring UID generation options")
-            else:
-                logger.info(f"No expert checkpoints found in {checkpoint_dir}, generating...")
-
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
-            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
+        assert ((expert_pattern is None and num_experts is None and expert_uids is not None) or
+                (num_experts is not None and expert_uids is None)), \
+            "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"

-        # get expert uids if not loaded previously
         if expert_uids is None:
-            assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
-            logger.info(f"Generating expert uids from pattern {expert_pattern}")
-            expert_uids = generate_uids_from_pattern(num_experts, expert_pattern, dht=dht)
+            if checkpoint_dir is not None:
+                assert is_directory(checkpoint_dir)
+                expert_uids = [child.name for child in checkpoint_dir.iterdir() if
+                               (child / 'checkpoint_last.pt').exists()]
+                total_experts_in_checkpoint = len(expert_uids)
+                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
+
+                if total_experts_in_checkpoint > num_experts:
+                    raise ValueError(
+                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
+                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints.")
+            else:
+                expert_uids = []
+
+            uids_to_generate = num_experts - len(expert_uids)
+            if uids_to_generate > 0:
+                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
+                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))

         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
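Aside: the checkpoint discovery in the new branch above only treats a subdirectory of checkpoint_dir as an expert if it contains a checkpoint_last.pt file. Below is a minimal sketch of that layout; the expert uids and the empty state dicts are illustrative placeholders, not part of this patch.

# Sketch (not part of the patch): directory layout expected by the discovery loop above.
from pathlib import Path
import torch

checkpoint_dir = Path("checkpoints")
for uid in ("expert.0.17", "expert.3.42"):          # two previously trained experts (placeholder names)
    expert_dir = checkpoint_dir / uid
    expert_dir.mkdir(parents=True, exist_ok=True)
    torch.save({}, expert_dir / "checkpoint_last.pt")   # placeholder state, real checkpoints hold expert weights

# With num_experts=4 and this checkpoint_dir, create() would restore the two experts above
# and generate two more uids from expert_pattern.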
@@ -142,6 +152,8 @@ def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = No
         else:
             args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)

+        scheduler = schedule_name_to_scheduler[scheduler]
+
         # initialize experts
         experts = {}
         for expert_uid in expert_uids:
@@ -150,15 +162,17 @@ def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = No
                                                          args_schema=args_schema,
                                                          outputs_schema=hivemind.BatchTensorDescriptor(
                                                              hidden_dim, compression=compression),
-                                                         opt=optim_cls(expert.parameters()),
+                                                         optimizer=optim_cls(expert.parameters()),
+                                                         scheduler=scheduler,
+                                                         num_warmup_steps=num_warmup_steps,
+                                                         num_training_steps=num_training_steps,
                                                          max_batch_size=max_batch_size)

-        if load_experts:
-            load_weights(experts, checkpoint_dir)
+        if checkpoint_dir is not None:
+            load_experts(experts, checkpoint_dir)

-        server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
-                        start=start)
-        return server
+        return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
+                   checkpoint_dir=checkpoint_dir, start=start)

     def run(self):
         """
@@ -241,7 +255,7 @@ def shutdown(self):
 def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.get_context("spawn").Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
+    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)

     try:
         runner.start()
@@ -269,63 +283,3 @@ def _server_runner(pipe, *args, **kwargs):
         server.shutdown()
         server.join()
     logger.info("Server shut down.")
-
-
-def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
-                               attempts_per_expert=10) -> List[str]:
-    """
-    Sample experts from a given pattern, remove duplicates.
-    :param num_experts: sample this many unique expert uids
-    :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
-     means "sample random experts between myprefix.0.0 and myprefix.255.255;
-    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
-    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
-    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
-     a small chance of sampling duplicate expert uids.
-    """
-    remaining_attempts = attempts_per_expert * num_experts
-    found_uids, attempted_uids = list(), set()
-
-    def _generate_uid():
-        if expert_pattern is None:
-            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
-
-        uid = []
-        for block in expert_pattern.split(UID_DELIMITER):
-            try:
-                if '[' not in block and ']' not in block:
-                    uid.append(block)
-                elif block.startswith('[') and block.endswith(']') and ':' in block:
-                    slice_start, slice_end = map(int, block[1:-1].split(':'))
-                    uid.append(str(random.randint(slice_start, slice_end - 1)))
-                else:
-                    raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt as e:
-                raise e
-            except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return UID_DELIMITER.join(uid)
-
-    while remaining_attempts > 0 and len(found_uids) < num_experts:
-
-        # 1. sample new expert uids at random
-        new_uids = []
-        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
-            new_uid = _generate_uid()
-            remaining_attempts -= 1
-            if new_uid not in attempted_uids:
-                attempted_uids.add(new_uid)
-                new_uids.append(new_uid)
-
-        # 2. look into DHT (if given) and remove duplicates
-        if dht:
-            existing_expert_uids = {found_expert.uid for found_expert in dht.get_experts(new_uids)
-                                    if found_expert is not None}
-            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
-
-        found_uids += new_uids
-
-    if len(found_uids) != num_experts:
-        logger.warning(f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
-                       f"{attempts_per_expert * num_experts} attempts")
-    return found_uids
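As a usage reference for the new signature, here is a minimal sketch of starting a server with an LR schedule. It assumes hivemind with this patch is importable, that Server is exported at the package top level, and that 'linear' is one of the names registered in schedule_name_to_scheduler; adjust those names to whatever your build provides. Everything else follows the create() arguments shown in the diff above.

# Sketch (not part of the patch): exercising the new scheduler arguments of Server.create.
import hivemind

server = hivemind.Server.create(
    num_experts=2, expert_pattern="expert.[0:256]",      # uids such as expert.42 are sampled from the pattern
    expert_cls="ffn", hidden_dim=1024,
    scheduler="linear",                                  # assumed key of schedule_name_to_scheduler
    num_warmup_steps=1_000, num_training_steps=10_000,
    no_dht=True,                                         # keep this smoke test local; omit to join a DHT
    start=True)                                          # per __init__ above, this awaits readiness
try:
    pass                                                 # interact with the served experts here
finally:
    server.shutdown()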