Skip to content

Commit fcc02ca

Browse files
authored
Average buffers (#236)
* adding buffer averaging in RVQ * shared average * fix
1 parent 4b38d36 commit fcc02ca

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

moshi/moshi/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
from . import quantization
1717
from . import utils
1818

19-
__version__ = "0.2.2"
19+
__version__ = "0.2.3a2"

moshi/moshi/quantization/core_vq.py

+19
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,20 @@ def _is_distributed() -> bool:
6060
return distributed.is_initialized() and distributed.get_world_size() > 1
6161

6262

63+
def _average_tensors(tensors: tp.Sequence[torch.Tensor]) -> None:
64+
if not _is_distributed():
65+
return
66+
world_size = distributed.get_world_size()
67+
handles = []
68+
for tensor in tensors:
69+
handle = distributed.all_reduce(
70+
tensor.data, op=distributed.ReduceOp.SUM, async_op=True)
71+
handles.append(handle)
72+
for tensor, handle in zip(tensors, handles):
73+
handle.wait()
74+
tensor.data /= world_size
75+
76+
6377
def _run_kmeans(samples: torch.Tensor, num_clusters: int, num_iters: int = 50) -> tp.Tuple[torch.Tensor, torch.Tensor]:
6478
# Kmeans algorithm used to initialize the codebooks.
6579
dim = samples.shape[-1]
@@ -481,6 +495,11 @@ def forward(
481495
if self.training:
482496
# Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
483497
quantized_out = x + (quantized_out - x).detach()
498+
to_average = []
499+
for layer in self.layers:
500+
assert isinstance(layer, VectorQuantization)
501+
to_average += [layer._codebook.cluster_usage, layer._codebook.embedding_sum]
502+
_average_tensors(to_average)
484503

485504
out_losses, out_codes = map(torch.stack, (all_losses, all_codes))
486505
return _VQForwardResult(quantized_out, out_codes, out_losses, all_metrics)

0 commit comments

Comments
 (0)