From b1507687ae02291b5402f0c64b7ceae2c712e9e7 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 20 Dec 2021 15:48:19 +0300 Subject: [PATCH] Use hivemind.Optimizer in examples/albert (#426) * Use hivemind.Optimizer in examples/albert * Deprecate CollaborativeOptimizer * Add is_synchronized_with_peers * Enable delayed gradient averaging --- benchmarks/benchmark_optimizer.py | 3 +- docs/modules/optim.rst | 4 +- examples/albert/arguments.py | 25 +++--- examples/albert/run_trainer.py | 90 +++++++++++-------- examples/albert/run_training_monitor.py | 35 ++++---- hivemind/__init__.py | 2 +- hivemind/optim/__init__.py | 2 +- hivemind/optim/base.py | 8 ++ hivemind/optim/collaborative.py | 10 +-- hivemind/optim/experimental/__init__.py | 0 .../optim/{experimental => }/grad_averager.py | 0 .../optim/{experimental => }/optimizer.py | 23 +++-- .../{experimental => }/progress_tracker.py | 0 .../{experimental => }/state_averager.py | 0 tests/test_optimizer.py | 8 +- 15 files changed, 121 insertions(+), 89 deletions(-) delete mode 100644 hivemind/optim/experimental/__init__.py rename hivemind/optim/{experimental => }/grad_averager.py (100%) rename hivemind/optim/{experimental => }/optimizer.py (98%) rename hivemind/optim/{experimental => }/progress_tracker.py (100%) rename hivemind/optim/{experimental => }/state_averager.py (100%) diff --git a/benchmarks/benchmark_optimizer.py b/benchmarks/benchmark_optimizer.py index b5dd956f5..8f93ff692 100644 --- a/benchmarks/benchmark_optimizer.py +++ b/benchmarks/benchmark_optimizer.py @@ -6,7 +6,6 @@ from functools import partial from typing import Callable -import numpy as np import torch import torchvision from torch import nn as nn @@ -14,7 +13,7 @@ from torch.utils.data import Dataset import hivemind -from hivemind.optim.experimental.optimizer import Optimizer +from hivemind.optim.optimizer import Optimizer from hivemind.utils.crypto import RSAPrivateKey diff --git a/docs/modules/optim.rst b/docs/modules/optim.rst index a07984392..11fd765fd 100644 --- a/docs/modules/optim.rst +++ b/docs/modules/optim.rst @@ -9,8 +9,8 @@

-.. automodule:: hivemind.optim.experimental.optimizer -.. currentmodule:: hivemind.optim.experimental.optimizer +.. automodule:: hivemind.optim.optimizer +.. currentmodule:: hivemind.optim.optimizer **hivemind.Optimizer** ---------------------- diff --git a/examples/albert/arguments.py b/examples/albert/arguments.py index f5f0227e1..a64af658f 100644 --- a/examples/albert/arguments.py +++ b/examples/albert/arguments.py @@ -45,12 +45,11 @@ class BaseTrainingArguments: @dataclass class AveragerArguments: - averaging_expiration: float = field( - default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"} - ) - averaging_timeout: float = field( - default=60.0, metadata={"help": "Give up on averaging step after this many seconds"} - ) + target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"}) + + +@dataclass +class ProgressTrackerArguments: min_refresh_period: float = field( default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"} ) @@ -66,17 +65,13 @@ class AveragerArguments: expected_drift_rate: float = field( default=0.2, metadata={"help": "Trainer assumes that this fraction of current size can join per step"} ) - performance_ema_alpha: float = field( - default=0.1, metadata={"help": "Uses this alpha for moving average estimate of samples per second"} - ) - target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"}) metadata_expiration: float = field( default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"} ) @dataclass -class CollaborativeOptimizerArguments: +class OptimizerArguments: target_batch_size: int = field( default=4096, metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"}, @@ -93,10 +88,16 @@ class CollaborativeOptimizerArguments: default=100.0, metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"}, ) + averaging_timeout: float = field( + default=60.0, metadata={"help": "Give up on averaging step after this many seconds"} + ) + matchmaking_time: float = field( + default=5.0, metadata={"help": "When looking for group, wait for requests for at least this many seconds"} + ) @dataclass -class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArguments): +class CollaborationArguments(OptimizerArguments, BaseTrainingArguments): statistics_expiration: float = field( default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"} ) diff --git a/examples/albert/run_trainer.py b/examples/albert/run_trainer.py index 8a6611a1f..ecaa6026e 100644 --- a/examples/albert/run_trainer.py +++ b/examples/albert/run_trainer.py @@ -16,11 +16,17 @@ from transformers.trainer import Trainer from transformers.trainer_utils import is_main_process -import hivemind +from hivemind import DHT, Float16Compression, Optimizer, get_dht_time from hivemind.utils.logging import get_logger, use_hivemind_log_handler import utils -from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments +from arguments import ( + AlbertTrainingArguments, + AveragerArguments, + CollaborationArguments, + DatasetArguments, + ProgressTrackerArguments, +) use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) @@ -90,8 +96,8 @@ class CollaborativeCallback(transformers.TrainerCallback): def __init__( self, - dht: hivemind.DHT, - optimizer: hivemind.CollaborativeOptimizer, + dht: DHT, + optimizer: Optimizer, model: torch.nn.Module, local_public_key: bytes, statistics_expiration: float, @@ -99,7 +105,7 @@ def __init__( ): super().__init__() self.model = model - self.dht, self.collaborative_optimizer = dht, optimizer + self.dht, self.optimizer = dht, optimizer self.local_public_key = local_public_key self.statistics_expiration = statistics_expiration self.last_reported_collaboration_step = -1 @@ -114,7 +120,7 @@ def on_train_begin( self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs ): logger.info("Loading state from peers") - self.collaborative_optimizer.load_state_from_peers() + self.optimizer.load_state_from_peers() def on_step_end( self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs @@ -124,40 +130,43 @@ def on_step_end( self.restore_from_backup(self.latest_backup) return control + local_progress = self.optimizer.local_progress + if state.log_history: self.loss += state.log_history[-1]["loss"] self.steps += 1 - if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step: - self.last_reported_collaboration_step = self.collaborative_optimizer.local_step + + if self.optimizer.local_epoch != self.last_reported_collaboration_step: + self.last_reported_collaboration_step = self.optimizer.local_epoch self.total_samples_processed += self.samples - samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second + samples_per_second = local_progress.samples_per_second statistics = utils.LocalMetrics( - step=self.collaborative_optimizer.local_step, + step=self.optimizer.local_epoch, samples_per_second=samples_per_second, samples_accumulated=self.samples, loss=self.loss, mini_steps=self.steps, ) - logger.info(f"Step #{self.collaborative_optimizer.local_step}") + logger.info(f"Step #{self.optimizer.local_epoch}") logger.info(f"Your current contribution: {self.total_samples_processed} samples") logger.info(f"Performance: {samples_per_second} samples per second.") if self.steps: logger.info(f"Local loss: {self.loss / self.steps}") - if self.collaborative_optimizer.local_step % self.backup_every_steps == 0: + if self.optimizer.local_epoch % self.backup_every_steps == 0: self.latest_backup = self.backup_state() self.loss = 0 self.steps = 0 - if self.collaborative_optimizer.is_synchronized: + if self.optimizer.is_synchronized_with_peers(): self.dht.store( - key=self.collaborative_optimizer.prefix + "_metrics", + key=self.optimizer.run_id + "_metrics", subkey=self.local_public_key, value=statistics.dict(), - expiration_time=hivemind.get_dht_time() + self.statistics_expiration, + expiration_time=get_dht_time() + self.statistics_expiration, return_future=True, ) - self.samples = self.collaborative_optimizer.local_samples_accumulated + self.samples = local_progress.samples_accumulated return control @@ -170,19 +179,17 @@ def params_are_finite(self): @torch.no_grad() def backup_state(self) -> bytes: - return pickle.dumps( - {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()} - ) + return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()}) @torch.no_grad() def restore_from_backup(self, backup: bytes): state = pickle.loads(backup) self.model.load_state_dict(state["model"]) - self.collaborative_optimizer.opt.load_state_dict(state["optimizer"]) + self.optimizer.load_state_dict(state["optimizer"]) class NoOpScheduler(LRSchedulerBase): - """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler""" + """Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler""" def get_lr(self): return [group["lr"] for group in self.optimizer.param_groups] @@ -202,8 +209,16 @@ def load_state_dict(self, *args, **kwargs): def main(): - parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments)) - training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses() + parser = HfArgumentParser( + ( + AlbertTrainingArguments, + DatasetArguments, + CollaborationArguments, + AveragerArguments, + ProgressTrackerArguments, + ) + ) + training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses() logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}") if len(collaboration_args.initial_peers) == 0: @@ -228,7 +243,7 @@ def main(): validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix) - dht = hivemind.DHT( + dht = DHT( start=True, initial_peers=collaboration_args.initial_peers, client_mode=collaboration_args.client_mode, @@ -246,19 +261,24 @@ def main(): adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead - collaborative_optimizer = hivemind.CollaborativeOptimizer( - opt=opt, + optimizer = Optimizer( dht=dht, - scheduler=scheduler, - prefix=collaboration_args.experiment_prefix, - compression=hivemind.Float16Compression(), - batch_size_per_step=total_batch_size_per_step, - bandwidth=collaboration_args.bandwidth, + run_id=collaboration_args.experiment_prefix, target_batch_size=adjusted_target_batch_size, + batch_size_per_step=total_batch_size_per_step, + optimizer=opt, + scheduler=scheduler, + matchmaking_time=collaboration_args.matchmaking_time, + averaging_timeout=collaboration_args.averaging_timeout, + offload_optimizer=True, + delay_optimizer_step=True, + delay_grad_averaging=True, client_mode=collaboration_args.client_mode, + grad_compression=Float16Compression(), + state_averaging_compression=Float16Compression(), + averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)}, + tracker_opts=asdict(tracker_args), verbose=True, - start=True, - **asdict(averager_args), ) class TrainerWithIndependentShuffling(Trainer): @@ -274,11 +294,11 @@ def get_train_dataloader(self) -> DataLoader: data_collator=data_collator, train_dataset=tokenized_datasets["train"] if training_args.do_train else None, eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None, - optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)), + optimizers=(optimizer, NoOpScheduler(optimizer)), callbacks=[ CollaborativeCallback( dht, - collaborative_optimizer, + optimizer, model, local_public_key, collaboration_args.statistics_expiration, diff --git a/examples/albert/run_training_monitor.py b/examples/albert/run_training_monitor.py index 880fd45bb..2e80d43ab 100644 --- a/examples/albert/run_training_monitor.py +++ b/examples/albert/run_training_monitor.py @@ -12,10 +12,11 @@ from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser import hivemind +from hivemind.optim.state_averager import TrainingStateAverager from hivemind.utils.logging import get_logger, use_hivemind_log_handler import utils -from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments +from arguments import AveragerArguments, BaseTrainingArguments, OptimizerArguments use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) @@ -55,14 +56,14 @@ class TrainingMonitorArguments(BaseTrainingArguments): upload_interval: Optional[float] = field( default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"} ) - store_checkpoins: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"}) + store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"}) class CheckpointHandler: def __init__( self, monitor_args: TrainingMonitorArguments, - collab_optimizer_args: CollaborativeOptimizerArguments, + optimizer_args: OptimizerArguments, averager_args: AveragerArguments, dht: hivemind.DHT, ): @@ -95,17 +96,13 @@ def __init__( debias=True, ) - adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead - - self.collaborative_optimizer = hivemind.CollaborativeOptimizer( - opt=opt, + self.state_averager = TrainingStateAverager( dht=dht, + optimizer=opt, prefix=experiment_prefix, - compression_type=hivemind.Float16Compression(), - bandwidth=collab_optimizer_args.bandwidth, - target_batch_size=adjusted_target_batch_size, - client_mode=collab_optimizer_args.client_mode, - verbose=True, + state_compression=hivemind.Float16Compression(), + bandwidth=optimizer_args.bandwidth, + client_mode=optimizer_args.client_mode, start=True, **asdict(averager_args), ) @@ -121,7 +118,7 @@ def is_time_to_save_state(self, cur_step): def save_state(self, cur_step): logger.info("Saving state from peers") - self.collaborative_optimizer.load_state_from_peers() + self.state_averager.load_state_from_peers() self.previous_step = cur_step def is_time_to_upload(self): @@ -134,7 +131,7 @@ def is_time_to_upload(self): def upload_checkpoint(self, current_loss): logger.info("Saving optimizer") - torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt") + torch.save(self.state_averager.optimizer.state_dict(), f"{self.repo_path}/optimizer_state.pt") self.previous_timestamp = time.time() logger.info("Started uploading to Model Hub") self.model.push_to_hub( @@ -146,8 +143,8 @@ def upload_checkpoint(self, current_loss): if __name__ == "__main__": - parser = HfArgumentParser((TrainingMonitorArguments, CollaborativeOptimizerArguments, AveragerArguments)) - monitor_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses() + parser = HfArgumentParser((TrainingMonitorArguments, OptimizerArguments, AveragerArguments)) + monitor_args, optimizer_args, averager_args = parser.parse_args_into_dataclasses() if monitor_args.use_google_dns: request = requests.get("https://api.ipify.org") @@ -176,8 +173,8 @@ def upload_checkpoint(self, current_loss): wandb.init(project=monitor_args.wandb_project) current_step = 0 - if monitor_args.store_checkpoins: - checkpoint_handler = CheckpointHandler(monitor_args, collab_optimizer_args, averager_args, dht) + if monitor_args.store_checkpoints: + checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht) while True: metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True) @@ -219,7 +216,7 @@ def upload_checkpoint(self, current_loss): } ) - if monitor_args.store_checkpoins: + if monitor_args.store_checkpoints: if checkpoint_handler.is_time_to_save_state(current_step): checkpoint_handler.save_state(current_step) if checkpoint_handler.is_time_to_upload(): diff --git a/hivemind/__init__.py b/hivemind/__init__.py index 37a4c82a6..f241bcdf4 100644 --- a/hivemind/__init__.py +++ b/hivemind/__init__.py @@ -23,4 +23,4 @@ from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo from hivemind.utils import * -__version__ = "1.0.0dev0" +__version__ = "1.0.0" diff --git a/hivemind/optim/__init__.py b/hivemind/optim/__init__.py index 625f07daf..e1b09f3a1 100644 --- a/hivemind/optim/__init__.py +++ b/hivemind/optim/__init__.py @@ -1,7 +1,7 @@ from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer from hivemind.optim.base import DecentralizedOptimizerBase from hivemind.optim.collaborative import CollaborativeOptimizer -from hivemind.optim.experimental.optimizer import Optimizer from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler +from hivemind.optim.optimizer import Optimizer from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD from hivemind.optim.training_averager import TrainingAverager diff --git a/hivemind/optim/base.py b/hivemind/optim/base.py index e5eba1216..975c6ffeb 100644 --- a/hivemind/optim/base.py +++ b/hivemind/optim/base.py @@ -1,3 +1,5 @@ +from warnings import warn + import torch from hivemind.dht import DHT @@ -8,6 +10,12 @@ class DecentralizedOptimizerBase(torch.optim.Optimizer): def __init__(self, opt: torch.optim.Optimizer, dht: DHT): self.opt, self.dht = opt, dht + warn( + "DecentralizedOptimizerBase and its subclasses have been deprecated and will be removed " + "in hivemind 1.1.0. Use hivemind.Optimizer instead", + FutureWarning, + stacklevel=2, + ) @property def state(self): diff --git a/hivemind/optim/collaborative.py b/hivemind/optim/collaborative.py index 203fb268e..078d139be 100644 --- a/hivemind/optim/collaborative.py +++ b/hivemind/optim/collaborative.py @@ -57,15 +57,15 @@ class TrainingProgressSchema(BaseModel): class CollaborativeOptimizer(DecentralizedOptimizerBase): """ - :note: **For new projects please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that. - Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and a many advanced ones. - CollaborativeOptimizer will still be supported for a while, but it will be deprecated eventually. - - An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers + An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers. These optimizers use DHT to track how much progress did the collaboration make towards target batch size. Once enough samples were accumulated, optimizers will compute a weighted average of their statistics. + :note: **For new projects, please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that. + Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and many advanced ones. + CollaborativeOptimizer will still be supported for a while, but it will be deprecated in v1.1.0. + :note: This optimizer behaves unlike regular pytorch optimizers in two ways: * calling .step will periodically zero-out gradients w.r.t. model parameters after each step diff --git a/hivemind/optim/experimental/__init__.py b/hivemind/optim/experimental/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/hivemind/optim/experimental/grad_averager.py b/hivemind/optim/grad_averager.py similarity index 100% rename from hivemind/optim/experimental/grad_averager.py rename to hivemind/optim/grad_averager.py diff --git a/hivemind/optim/experimental/optimizer.py b/hivemind/optim/optimizer.py similarity index 98% rename from hivemind/optim/experimental/optimizer.py rename to hivemind/optim/optimizer.py index a6768a98a..129875850 100644 --- a/hivemind/optim/experimental/optimizer.py +++ b/hivemind/optim/optimizer.py @@ -11,9 +11,10 @@ from hivemind.averaging.control import AveragingStage, StepControl from hivemind.compression import CompressionBase, NoCompression from hivemind.dht import DHT -from hivemind.optim.experimental.grad_averager import GradientAverager -from hivemind.optim.experimental.progress_tracker import ProgressTracker -from hivemind.optim.experimental.state_averager import ( +from hivemind.optim.grad_averager import GradientAverager +from hivemind.optim.grad_scaler import GradScaler +from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker +from hivemind.optim.state_averager import ( LRSchedulerBase, OptimizerFactory, Parameters, @@ -22,7 +23,6 @@ TorchOptimizer, TrainingStateAverager, ) -from hivemind.optim.grad_scaler import GradScaler from hivemind.utils import PerformanceEMA, get_dht_time, get_logger logger = get_logger(__name__) @@ -154,7 +154,7 @@ class Optimizer(torch.optim.Optimizer): :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager :param tracker_opts: additional keyword arguments forwarded to ProgressTracker - :param performance_ema_alpha: moving average alpha in ProgressTracer, TrainingStateAverager and Optimizer + :param performance_ema_alpha: moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer :param verbose: if True, report internal events such as accumilating gradients and running background tasks :note: in a large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer @@ -345,6 +345,10 @@ def local_epoch(self) -> int: """ return self.state_averager.local_epoch + @property + def local_progress(self) -> LocalTrainingProgress: + return self.tracker.local_progress + @property def use_local_updates(self) -> bool: return self.grad_averager is None @@ -384,7 +388,7 @@ def step( with torch.enable_grad(): loss = closure() - if not self.auxiliary and self.should_load_state_from_peers(): + if not self.auxiliary and self._should_load_state_from_peers(): logger.log(self.status_loglevel, "Peer is out of sync") self.load_state_from_peers() return loss # local gradients were computed with out-of-sync parameters, must start over @@ -564,7 +568,6 @@ def _maybe_schedule_state_averaging(self) -> None: if eta_seconds_to_averaging <= self.matchmaking_time: if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done(): - min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"] actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time) logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f} sec") @@ -626,7 +629,7 @@ def zero_grad(self, set_to_none: bool = False): else: param.grad.zero_() - def should_load_state_from_peers(self) -> bool: + def _should_load_state_from_peers(self) -> bool: """ If true, peer will discard local progress and attempt to download state from peers. This method allows peer to continue training in two cases: @@ -646,6 +649,10 @@ def should_load_state_from_peers(self) -> bool: return self.local_epoch != self.tracker.global_epoch # require exact synchronization once per step return self.local_epoch < self.tracker.global_epoch - 1 # catch up if a peer just switched to next epoch + def is_synchronized_with_peers(self) -> bool: + """Checks whether the current peer is up-to-date with others in terms of the epoch (step) number.""" + return self.local_epoch >= self.tracker.global_epoch - 1 + def load_state_from_peers(self, **kwargs): """ Attempt to load the newest collaboration state from other peers within the same run_id. diff --git a/hivemind/optim/experimental/progress_tracker.py b/hivemind/optim/progress_tracker.py similarity index 100% rename from hivemind/optim/experimental/progress_tracker.py rename to hivemind/optim/progress_tracker.py diff --git a/hivemind/optim/experimental/state_averager.py b/hivemind/optim/state_averager.py similarity index 100% rename from hivemind/optim/experimental/state_averager.py rename to hivemind/optim/state_averager.py diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index e6a38a61d..446077c4c 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -11,10 +11,10 @@ import hivemind from hivemind.averaging.control import AveragingStage -from hivemind.optim.experimental.grad_averager import GradientAverager -from hivemind.optim.experimental.optimizer import Optimizer -from hivemind.optim.experimental.progress_tracker import ProgressTracker -from hivemind.optim.experimental.state_averager import TrainingStateAverager +from hivemind.optim.grad_averager import GradientAverager +from hivemind.optim.optimizer import Optimizer +from hivemind.optim.progress_tracker import ProgressTracker +from hivemind.optim.state_averager import TrainingStateAverager from hivemind.utils.crypto import RSAPrivateKey