
Commit 020c068

Authored by borzunov, yhn112, and justheuristic
Log collaboration step to Wandb, store metrics only if peer is synchronized (#267)
These are small practical changes moved from https://github.com/mryab/collaborative-training

Co-authored-by: Michael Diskin <[email protected]>
Co-authored-by: justheuristic <[email protected]>
1 parent e58f65d commit 020c068

File tree: 3 files changed, +12 -10 lines

examples/albert/run_first_peer.py

Lines changed: 3 additions & 2 deletions

@@ -177,13 +177,14 @@ def upload_checkpoint(self, current_loss):
                     num_samples += item.samples_accumulated
                     sum_mini_steps += item.mini_steps
                 current_loss = sum_loss / sum_mini_steps
-
+
                 if coordinator_args.wandb_project is not None:
                     wandb.log({
                         "loss": current_loss,
                         "alive peers": alive_peers,
                         "samples": num_samples,
-                        "performance": sum_perf
+                        "performance": sum_perf,
+                        "step": latest_step
                     })
                 if checkpoint_handler.is_time_to_save_state(current_step):
                     checkpoint_handler.save_state(current_step)
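For context, the payload the coordinator now logs can be pictured with a small self-contained sketch. Everything below is illustrative: the PeerMetrics dataclass, the aggregate_for_wandb helper, and the numbers are hypothetical stand-ins for the metrics_utils.LocalMetrics entries that run_first_peer.py actually reads from the DHT.

from dataclasses import dataclass
from typing import List

@dataclass
class PeerMetrics:
    # hypothetical stand-in for the metrics each trainer publishes to the DHT
    step: int
    samples_per_second: float
    samples_accumulated: int
    loss: float
    mini_steps: int

def aggregate_for_wandb(metrics: List[PeerMetrics]) -> dict:
    """Aggregate per-peer metrics into one wandb.log() payload (illustrative helper)."""
    latest_step = max(item.step for item in metrics)
    sum_loss = sum(item.loss for item in metrics)
    sum_perf = sum(item.samples_per_second for item in metrics)
    num_samples = sum(item.samples_accumulated for item in metrics)
    sum_mini_steps = sum(item.mini_steps for item in metrics)
    return {
        "loss": sum_loss / sum_mini_steps,
        "alive peers": len(metrics),
        "samples": num_samples,
        "performance": sum_perf,
        "step": latest_step,  # the field this commit starts logging
    }

if __name__ == "__main__":
    fake_metrics = [
        PeerMetrics(step=42, samples_per_second=96.0, samples_accumulated=4096, loss=350.2, mini_steps=128),
        PeerMetrics(step=41, samples_per_second=80.5, samples_accumulated=3072, loss=290.7, mini_steps=96),
    ]
    # run_first_peer.py would pass this dict to wandb.log(...) when --wandb_project is set
    print(aggregate_for_wandb(fake_metrics))

Logging the step alongside the other values presumably lets the Wandb dashboard align coordinator-side metrics with the collaboration step rather than with wall-clock logging order.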

examples/albert/run_trainer.py

Lines changed: 7 additions & 6 deletions

@@ -112,7 +112,7 @@ def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer
 
     def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
                        control: transformers.TrainerControl, **kwargs):
-        logger.warning('Loading state from peers')
+        logger.info('Loading state from peers')
         self.collaborative_optimizer.load_state_from_peers()
 
     def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
@@ -139,14 +139,15 @@ def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
                 logger.info(f"Step {self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 if self.steps:
-                    logger.info(f"Loss of your model: {self.loss/self.steps}")
+                    logger.info(f"Local loss: {self.loss / self.steps}")
 
                 self.loss = 0
                 self.steps = 0
-                self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
-                               subkey=self.local_public_key, value=statistics.dict(),
-                               expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
-                               return_future=True)
+                if self.collaborative_optimizer.is_synchronized:
+                    self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
+                                   subkey=self.local_public_key, value=statistics.dict(),
+                                   expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                                   return_future=True)
 
         self.samples = self.collaborative_optimizer.local_samples_accumulated
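The second hunk is the behavioural change named in the commit title: the DHT write is now guarded by CollaborativeOptimizer.is_synchronized, so a peer that has not yet caught up with the collaboration does not publish metrics. A minimal runnable sketch of that pattern follows; StubOptimizer, StubDHT, and report_metrics are names invented for this example, not hivemind APIs.

import time
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class StubOptimizer:
    # stand-in for hivemind.CollaborativeOptimizer in this sketch
    prefix: str = "albert"
    is_synchronized: bool = False

@dataclass
class StubDHT:
    # stand-in for hivemind.DHT: just records what would have been stored
    records: List[Tuple] = field(default_factory=list)

    def store(self, key, subkey, value, expiration_time, return_future=False):
        self.records.append((key, subkey, value, expiration_time))

def report_metrics(dht: StubDHT, opt: StubOptimizer, public_key: str,
                   statistics: dict, expiration: float = 600.0) -> bool:
    """Publish local metrics only when this peer is in sync with the collaboration."""
    if not opt.is_synchronized:
        return False  # skip the write, mirroring the new guard in on_step_end
    dht.store(key=opt.prefix + "_metrics", subkey=public_key, value=statistics,
              expiration_time=time.time() + expiration, return_future=True)
    return True

dht, opt = StubDHT(), StubOptimizer()
print(report_metrics(dht, opt, "peer-key", {"loss": 2.5}))  # False: not synchronized, nothing stored
opt.is_synchronized = True
print(report_metrics(dht, opt, "peer-key", {"loss": 2.5}))  # True: metrics recorded

Without the guard, a peer still loading or averaging state could presumably overwrite its DHT entry with stale local statistics, skewing the aggregates that run_first_peer.py logs.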

hivemind/client/averaging/allreduce.py

Lines changed: 2 additions & 2 deletions

@@ -54,7 +54,7 @@ def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoi
         self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
         self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
         self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
-
+
         self.num_senders = len([mode for mode in modes if mode != AveragingMode.AUX])
 
         if self.num_senders == 0:
@@ -258,7 +258,7 @@ async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.Averaging
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
 
 
-def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int]) -> Tuple[torch.Tensor, ...]:
+def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int, ...]) -> Tuple[torch.Tensor, ...]:
     """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
     flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
     return torch.split_with_sizes(flat_tensor, part_sizes, dim=0)
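This file only gets a whitespace touch-up and a type-annotation fix: part_sizes is a variable-length tuple, so Tuple[int] (exactly one int) becomes Tuple[int, ...]. The function is short enough to try standalone; the example tensors and sizes below are made up.

from typing import Sequence, Tuple
import torch

def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int, ...]) -> Tuple[torch.Tensor, ...]:
    """Flatten all tensors into one 1-D buffer and split it into chunks of the given sizes."""
    flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
    return torch.split_with_sizes(flat_tensor, part_sizes, dim=0)

tensors = [torch.zeros(2, 3), torch.ones(4)]           # 6 + 4 = 10 elements in total
parts = split_into_parts(tensors, part_sizes=(6, 4))   # e.g. one chunk per peer in a group of two
print([tuple(p.shape) for p in parts])                 # [(6,), (4,)]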
