Use hivemind.Optimizer in examples/albert (#426)
* Use hivemind.Optimizer in examples/albert

* Deprecate CollaborativeOptimizer

* Add is_synchronized_with_peers

* Enable delayed gradient averaging
mryab authored Dec 20, 2021
1 parent 9a921af commit b150768
Showing 15 changed files with 121 additions and 89 deletions.
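In short, the ALBERT example drops the deprecated hivemind.CollaborativeOptimizer wrapper and builds hivemind.Optimizer directly. Below is a condensed sketch of the new construction pattern, assembled from the run_trainer.py diff further down; the model, learning rate, run_id value, and batch sizes are placeholders for illustration, and the real script wires these from its CLI arguments.

import torch
from hivemind import DHT, Float16Compression, Optimizer

# Placeholder model and base optimizer; the example itself uses ALBERT with LAMB.
model = torch.nn.Linear(512, 512)
base_opt = torch.optim.Adam(model.parameters(), lr=1e-3)

dht = DHT(start=True)  # the example also passes initial_peers, client_mode, host/announce maddrs

optimizer = Optimizer(
    dht=dht,
    run_id="albert_experiment",              # replaces CollaborativeOptimizer's `prefix`
    target_batch_size=4096,                  # step after all peers collectively accumulate this many samples
    batch_size_per_step=32,                  # samples this peer contributes per local .step()
    optimizer=base_opt,
    matchmaking_time=5.0,                    # replaces the old `averaging_expiration` argument
    averaging_timeout=60.0,
    offload_optimizer=True,
    delay_optimizer_step=True,
    delay_grad_averaging=True,               # the delayed gradient averaging enabled by this commit
    grad_compression=Float16Compression(),
    state_averaging_compression=Float16Compression(),
    verbose=True,
    start=True,
)

# The training loop itself stays the usual PyTorch pattern: loss.backward(); optimizer.step()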
3 changes: 1 addition & 2 deletions benchmarks/benchmark_optimizer.py
@@ -6,15 +6,14 @@
from functools import partial
from typing import Callable

import numpy as np
import torch
import torchvision
from torch import nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset

import hivemind
from hivemind.optim.experimental.optimizer import Optimizer
from hivemind.optim.optimizer import Optimizer
from hivemind.utils.crypto import RSAPrivateKey


4 changes: 2 additions & 2 deletions docs/modules/optim.rst
@@ -9,8 +9,8 @@

<br><br>

.. automodule:: hivemind.optim.experimental.optimizer
.. currentmodule:: hivemind.optim.experimental.optimizer
.. automodule:: hivemind.optim.optimizer
.. currentmodule:: hivemind.optim.optimizer

**hivemind.Optimizer**
----------------------
25 changes: 13 additions & 12 deletions examples/albert/arguments.py
@@ -45,12 +45,11 @@ class BaseTrainingArguments:

@dataclass
class AveragerArguments:
averaging_expiration: float = field(
default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
)
averaging_timeout: float = field(
default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
)
target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})


@dataclass
class ProgressTrackerArguments:
min_refresh_period: float = field(
default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
)
@@ -66,17 +65,13 @@ class AveragerArguments:
expected_drift_rate: float = field(
default=0.2, metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
)
performance_ema_alpha: float = field(
default=0.1, metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
)
target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
metadata_expiration: float = field(
default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
)


@dataclass
class CollaborativeOptimizerArguments:
class OptimizerArguments:
target_batch_size: int = field(
default=4096,
metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
@@ -93,10 +88,16 @@ class CollaborativeOptimizerArguments:
default=100.0,
metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
)
averaging_timeout: float = field(
default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
)
matchmaking_time: float = field(
default=5.0, metadata={"help": "When looking for group, wait for requests for at least this many seconds"}
)


@dataclass
class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArguments):
class CollaborationArguments(OptimizerArguments, BaseTrainingArguments):
statistics_expiration: float = field(
default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
)
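The argument split above mirrors how run_trainer.py consumes the dataclasses: OptimizerArguments feed hivemind.Optimizer directly, while the trimmed AveragerArguments and the new ProgressTrackerArguments are forwarded as opaque keyword dicts. A small sketch of that hand-off, assuming the dataclasses are instantiated with their defaults rather than parsed from the command line:

from dataclasses import asdict
from arguments import AveragerArguments, OptimizerArguments, ProgressTrackerArguments

# Hypothetical instances; the real script gets these from HfArgumentParser.
optimizer_args = OptimizerArguments()
averager_args = AveragerArguments()
tracker_args = ProgressTrackerArguments()

averager_opts = {"bandwidth": optimizer_args.bandwidth, **asdict(averager_args)}  # extra kwargs for Optimizer's internal averagers
tracker_opts = asdict(tracker_args)                                               # extra kwargs for its progress tracker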
90 changes: 55 additions & 35 deletions examples/albert/run_trainer.py
@@ -16,11 +16,17 @@
from transformers.trainer import Trainer
from transformers.trainer_utils import is_main_process

import hivemind
from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
from arguments import (
AlbertTrainingArguments,
AveragerArguments,
CollaborationArguments,
DatasetArguments,
ProgressTrackerArguments,
)

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
@@ -90,16 +96,16 @@ class CollaborativeCallback(transformers.TrainerCallback):

def __init__(
self,
dht: hivemind.DHT,
optimizer: hivemind.CollaborativeOptimizer,
dht: DHT,
optimizer: Optimizer,
model: torch.nn.Module,
local_public_key: bytes,
statistics_expiration: float,
backup_every_steps: int,
):
super().__init__()
self.model = model
self.dht, self.collaborative_optimizer = dht, optimizer
self.dht, self.optimizer = dht, optimizer
self.local_public_key = local_public_key
self.statistics_expiration = statistics_expiration
self.last_reported_collaboration_step = -1
@@ -114,7 +120,7 @@ def on_train_begin(
self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
):
logger.info("Loading state from peers")
self.collaborative_optimizer.load_state_from_peers()
self.optimizer.load_state_from_peers()

def on_step_end(
self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
@@ -124,40 +130,43 @@ def on_step_end(
self.restore_from_backup(self.latest_backup)
return control

local_progress = self.optimizer.local_progress

if state.log_history:
self.loss += state.log_history[-1]["loss"]
self.steps += 1
if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
self.last_reported_collaboration_step = self.collaborative_optimizer.local_step

if self.optimizer.local_epoch != self.last_reported_collaboration_step:
self.last_reported_collaboration_step = self.optimizer.local_epoch
self.total_samples_processed += self.samples
samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
samples_per_second = local_progress.samples_per_second
statistics = utils.LocalMetrics(
step=self.collaborative_optimizer.local_step,
step=self.optimizer.local_epoch,
samples_per_second=samples_per_second,
samples_accumulated=self.samples,
loss=self.loss,
mini_steps=self.steps,
)
logger.info(f"Step #{self.collaborative_optimizer.local_step}")
logger.info(f"Step #{self.optimizer.local_epoch}")
logger.info(f"Your current contribution: {self.total_samples_processed} samples")
logger.info(f"Performance: {samples_per_second} samples per second.")
if self.steps:
logger.info(f"Local loss: {self.loss / self.steps}")
if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
if self.optimizer.local_epoch % self.backup_every_steps == 0:
self.latest_backup = self.backup_state()

self.loss = 0
self.steps = 0
if self.collaborative_optimizer.is_synchronized:
if self.optimizer.is_synchronized_with_peers():
self.dht.store(
key=self.collaborative_optimizer.prefix + "_metrics",
key=self.optimizer.run_id + "_metrics",
subkey=self.local_public_key,
value=statistics.dict(),
expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
expiration_time=get_dht_time() + self.statistics_expiration,
return_future=True,
)

self.samples = self.collaborative_optimizer.local_samples_accumulated
self.samples = local_progress.samples_accumulated

return control

@@ -170,19 +179,17 @@ def params_are_finite(self):

@torch.no_grad()
def backup_state(self) -> bytes:
return pickle.dumps(
{"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
)
return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()})

@torch.no_grad()
def restore_from_backup(self, backup: bytes):
state = pickle.loads(backup)
self.model.load_state_dict(state["model"])
self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])
self.optimizer.load_state_dict(state["optimizer"])


class NoOpScheduler(LRSchedulerBase):
"""Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
"""Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler"""

def get_lr(self):
return [group["lr"] for group in self.optimizer.param_groups]
@@ -202,8 +209,16 @@ def load_state_dict(self, *args, **kwargs):


def main():
parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
parser = HfArgumentParser(
(
AlbertTrainingArguments,
DatasetArguments,
CollaborationArguments,
AveragerArguments,
ProgressTrackerArguments,
)
)
training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()

logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
if len(collaboration_args.initial_peers) == 0:
Expand All @@ -228,7 +243,7 @@ def main():

validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)

dht = hivemind.DHT(
dht = DHT(
start=True,
initial_peers=collaboration_args.initial_peers,
client_mode=collaboration_args.client_mode,
@@ -246,19 +261,24 @@ def main():

adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead

collaborative_optimizer = hivemind.CollaborativeOptimizer(
opt=opt,
optimizer = Optimizer(
dht=dht,
scheduler=scheduler,
prefix=collaboration_args.experiment_prefix,
compression=hivemind.Float16Compression(),
batch_size_per_step=total_batch_size_per_step,
bandwidth=collaboration_args.bandwidth,
run_id=collaboration_args.experiment_prefix,
target_batch_size=adjusted_target_batch_size,
batch_size_per_step=total_batch_size_per_step,
optimizer=opt,
scheduler=scheduler,
matchmaking_time=collaboration_args.matchmaking_time,
averaging_timeout=collaboration_args.averaging_timeout,
offload_optimizer=True,
delay_optimizer_step=True,
delay_grad_averaging=True,
client_mode=collaboration_args.client_mode,
grad_compression=Float16Compression(),
state_averaging_compression=Float16Compression(),
averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
tracker_opts=asdict(tracker_args),
verbose=True,
start=True,
**asdict(averager_args),
)

class TrainerWithIndependentShuffling(Trainer):
@@ -274,11 +294,11 @@ def get_train_dataloader(self) -> DataLoader:
data_collator=data_collator,
train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
optimizers=(optimizer, NoOpScheduler(optimizer)),
callbacks=[
CollaborativeCallback(
dht,
collaborative_optimizer,
optimizer,
model,
local_public_key,
collaboration_args.statistics_expiration,
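For quick reference, the callback changes above boil down to the following attribute mapping between the old CollaborativeOptimizer API and hivemind.Optimizer; a sketch, assuming `optimizer` is the Optimizer instance built in main():

# assuming `optimizer` is the hivemind.Optimizer created in main()
step = optimizer.local_epoch                      # was collaborative_optimizer.local_step
progress = optimizer.local_progress               # replaces performance_ema / local_samples_accumulated
samples_per_second = progress.samples_per_second
samples_accumulated = progress.samples_accumulated
synced = optimizer.is_synchronized_with_peers()   # was the is_synchronized property
metrics_key = optimizer.run_id + "_metrics"       # was collaborative_optimizer.prefix + "_metrics"
state = optimizer.state_dict()                    # was collaborative_optimizer.opt.state_dict()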
35 changes: 16 additions & 19 deletions examples/albert/run_training_monitor.py
@@ -12,10 +12,11 @@
from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser

import hivemind
from hivemind.optim.state_averager import TrainingStateAverager
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
from arguments import AveragerArguments, BaseTrainingArguments, OptimizerArguments

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
@@ -55,14 +56,14 @@ class TrainingMonitorArguments(BaseTrainingArguments):
upload_interval: Optional[float] = field(
default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
)
store_checkpoins: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})


class CheckpointHandler:
def __init__(
self,
monitor_args: TrainingMonitorArguments,
collab_optimizer_args: CollaborativeOptimizerArguments,
optimizer_args: OptimizerArguments,
averager_args: AveragerArguments,
dht: hivemind.DHT,
):
@@ -95,17 +96,13 @@ def __init__(
debias=True,
)

adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead

self.collaborative_optimizer = hivemind.CollaborativeOptimizer(
opt=opt,
self.state_averager = TrainingStateAverager(
dht=dht,
optimizer=opt,
prefix=experiment_prefix,
compression_type=hivemind.Float16Compression(),
bandwidth=collab_optimizer_args.bandwidth,
target_batch_size=adjusted_target_batch_size,
client_mode=collab_optimizer_args.client_mode,
verbose=True,
state_compression=hivemind.Float16Compression(),
bandwidth=optimizer_args.bandwidth,
client_mode=optimizer_args.client_mode,
start=True,
**asdict(averager_args),
)
@@ -121,7 +118,7 @@ def is_time_to_save_state(self, cur_step):

def save_state(self, cur_step):
logger.info("Saving state from peers")
self.collaborative_optimizer.load_state_from_peers()
self.state_averager.load_state_from_peers()
self.previous_step = cur_step

def is_time_to_upload(self):
Expand All @@ -134,7 +131,7 @@ def is_time_to_upload(self):

def upload_checkpoint(self, current_loss):
logger.info("Saving optimizer")
torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
torch.save(self.state_averager.optimizer.state_dict(), f"{self.repo_path}/optimizer_state.pt")
self.previous_timestamp = time.time()
logger.info("Started uploading to Model Hub")
self.model.push_to_hub(
@@ -146,8 +143,8 @@ def upload_checkpoint(self, current_loss):


if __name__ == "__main__":
parser = HfArgumentParser((TrainingMonitorArguments, CollaborativeOptimizerArguments, AveragerArguments))
monitor_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
parser = HfArgumentParser((TrainingMonitorArguments, OptimizerArguments, AveragerArguments))
monitor_args, optimizer_args, averager_args = parser.parse_args_into_dataclasses()

if monitor_args.use_google_dns:
request = requests.get("https://api.ipify.org")
@@ -176,8 +173,8 @@ def upload_checkpoint(self, current_loss):
wandb.init(project=monitor_args.wandb_project)

current_step = 0
if monitor_args.store_checkpoins:
checkpoint_handler = CheckpointHandler(monitor_args, collab_optimizer_args, averager_args, dht)
if monitor_args.store_checkpoints:
checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht)

while True:
metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
@@ -219,7 +216,7 @@ def upload_checkpoint(self, current_loss):
}
)

if monitor_args.store_checkpoins:
if monitor_args.store_checkpoints:
if checkpoint_handler.is_time_to_save_state(current_step):
checkpoint_handler.save_state(current_step)
if checkpoint_handler.is_time_to_upload():
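The training monitor no longer needs a full optimizer: it only keeps a replica of the shared state, so the diff above swaps CollaborativeOptimizer for TrainingStateAverager. A minimal sketch of that usage, with a toy model and a placeholder prefix standing in for the example's ALBERT model and experiment_prefix:

import torch
import hivemind
from hivemind.optim.state_averager import TrainingStateAverager

model = torch.nn.Linear(8, 8)                        # placeholder for AlbertForPreTraining
opt = torch.optim.Adam(model.parameters(), lr=1e-3)  # placeholder for the LAMB optimizer above
dht = hivemind.DHT(start=True)                       # the monitor also passes initial_peers etc.

state_averager = TrainingStateAverager(
    dht=dht,
    optimizer=opt,
    prefix="albert_experiment",
    state_compression=hivemind.Float16Compression(),
    start=True,
)

# Periodically pull the latest averaged state from training peers (requires live peers),
# then checkpoint the optimizer the same way upload_checkpoint() does above:
state_averager.load_state_from_peers()
torch.save(state_averager.optimizer.state_dict(), "optimizer_state.pt")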
2 changes: 1 addition & 1 deletion hivemind/__init__.py
@@ -23,4 +23,4 @@
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
from hivemind.utils import *

__version__ = "1.0.0dev0"
__version__ = "1.0.0"