Commit effd8b7

shared average
1 parent 1b158f1 commit effd8b7

2 files changed: +6 additions, -2 deletions

moshi/moshi/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -16,4 +16,4 @@
 from . import quantization
 from . import utils

-__version__ = "0.2.3a1"
+__version__ = "0.2.3a2"

moshi/moshi/quantization/core_vq.py

Lines changed: 5 additions & 1 deletion

@@ -333,7 +333,6 @@ def forward(
             embedding_sum.scatter_add_(0, repeat(flat_codes, "n -> n d", d=self.dim), x)
             _ema_inplace(self.embedding_sum, embedding_sum, self.decay)
             self.register_buffer('_embedding', None)
-            _average_tensors([self.embedding_sum, self.cluster_usage])

         return _CodebookForwardResult(quantized, codes, metrics)

@@ -496,6 +495,11 @@ def forward(
         if self.training:
             # Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
             quantized_out = x + (quantized_out - x).detach()
+            to_average = []
+            for layer in self.layers:
+                assert isinstance(layer, VectorQuantization)
+                to_average += [layer._codebook.cluster_usage, layer._codebook.embedding_sum]
+            _average_tensors(to_average)

         out_losses, out_codes = map(torch.stack, (all_losses, all_codes))
         return _VQForwardResult(quantized_out, out_codes, out_losses, all_metrics)
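
Net effect of the change: instead of each codebook averaging its own embedding_sum and cluster_usage inside its forward, the residual quantizer now collects both EMA buffers from every VectorQuantization layer and averages them in a single _average_tensors call during training. Below is a minimal sketch of what such a shared average could look like, assuming _average_tensors synchronizes buffers across data-parallel workers via torch.distributed; the helper name _shared_average and the flatten-then-all-reduce strategy are illustrative assumptions, not the repository's actual implementation.

import torch
import torch.distributed as dist

def _shared_average(tensors: list[torch.Tensor]) -> None:
    """Average each tensor in-place across all workers (no-op when not distributed).

    Illustrative stand-in for moshi's _average_tensors; the real helper may differ.
    """
    if not dist.is_available() or not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    # Flatten every buffer into one contiguous tensor so a single all_reduce
    # covers all codebooks, instead of issuing one collective per tensor.
    flat = torch.cat([t.detach().reshape(-1) for t in tensors])
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)
    flat /= world_size
    # Scatter the averaged values back into the original buffers.
    offset = 0
    for t in tensors:
        n = t.numel()
        t.data.copy_(flat[offset:offset + n].view_as(t))
        offset += n

Under that assumption, batching all layers' buffers into one call keeps the number of collective operations per training step constant, regardless of how many quantizer layers the RVQ stack holds.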
