@@ -60,6 +60,20 @@ def _is_distributed() -> bool:
     return distributed.is_initialized() and distributed.get_world_size() > 1
 
 
+def _average_tensors(tensors: tp.Sequence[torch.Tensor]) -> None:
+    if not _is_distributed():
+        return
+    world_size = distributed.get_world_size()
+    handles = []
+    for tensor in tensors:
+        handle = distributed.all_reduce(
+            tensor.data, op=distributed.ReduceOp.SUM, async_op=True)
+        handles.append(handle)
+    for tensor, handle in zip(tensors, handles):
+        handle.wait()
+        tensor.data /= world_size
+
+
 def _run_kmeans(samples: torch.Tensor, num_clusters: int, num_iters: int = 50) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     # Kmeans algorithm used to initialize the codebooks.
     dim = samples.shape[-1]
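For context on the helper added above: `distributed.all_reduce` with `ReduceOp.SUM` and `async_op=True` launches all reductions without blocking, and dividing by the world size after `wait()` turns the summed tensors into per-rank means, in place. A minimal sketch of the same pattern follows; the single-process gloo group and the store path are hypothetical scaffolding so the snippet can run outside a real multi-worker launch, and the toy tensors stand in for the real statistics.

import torch
import torch.distributed as distributed

# Hypothetical single-process group; in real training the launcher creates it.
store = distributed.FileStore("/tmp/average_tensors_demo_store", 1)
distributed.init_process_group(backend="gloo", store=store, rank=0, world_size=1)

stats = [torch.ones(4), torch.arange(4, dtype=torch.float32)]
handles = [distributed.all_reduce(t, op=distributed.ReduceOp.SUM, async_op=True)
           for t in stats]
for t, handle in zip(stats, handles):
    handle.wait()
    t /= distributed.get_world_size()  # sum / world_size == mean across ranks

print(stats)  # unchanged with world_size == 1; averaged in place with more ranks
distributed.destroy_process_group()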
@@ -481,6 +495,11 @@ def forward(
         if self.training:
             # Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
             quantized_out = x + (quantized_out - x).detach()
+            to_average = []
+            for layer in self.layers:
+                assert isinstance(layer, VectorQuantization)
+                to_average += [layer._codebook.cluster_usage, layer._codebook.embedding_sum]
+            _average_tensors(to_average)
 
         out_losses, out_codes = map(torch.stack, (all_losses, all_codes))
         return _VQForwardResult(quantized_out, out_codes, out_losses, all_metrics)
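Two notes on the forward-pass change, as a sketch rather than part of the diff: the straight-through line referenced by the encodec issue keeps the forward value equal to the quantized output while routing gradients to `x`, and the new `_average_tensors` call then averages each layer's `cluster_usage` and `embedding_sum` buffers over workers, which keeps the EMA codebook statistics consistent under data parallelism. The straight-through identity can be checked on toy tensors; `torch.round` below is only a stand-in for the module's non-differentiable quantizer.

import torch

x = torch.randn(3, requires_grad=True)
quantized = torch.round(x)  # stand-in for the non-differentiable quantizer

out = x + (quantized - x).detach()
assert torch.allclose(out, quantized)            # forward value is the quantized one
out.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))   # gradient passes straight through to x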