Skip to content

Commit 3e877db

Browse files
authored
Use n_quantizers and save codes as uint16 (#12)
* restrict full quantizer usage during inference * save/load uint16 codes * add test * bump version
1 parent 7cb5f5b commit 3e877db

File tree

6 files changed

+29
-4
lines changed

6 files changed

+29
-4
lines changed

dac/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.0.2"
1+
__version__ = "0.0.3"
22
__model_version__ = "0.0.1"
33
import audiotools
44

dac/nn/quantize.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ def forward(self, z, n_quantizers: int = None):
171171
n_quantizers = n_quantizers.to(z.device)
172172

173173
for i, quantizer in enumerate(self.quantizers):
174+
if self.training is False and i >= n_quantizers:
175+
break
176+
174177
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
175178
residual
176179
)

dac/utils/decode.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def process(
5252
"""
5353
if isinstance(generator, torch.nn.DataParallel):
5454
generator = generator.module
55-
audio_signal = AudioSignal(artifacts["codes"], generator.sample_rate)
55+
audio_signal = AudioSignal(
56+
artifacts["codes"].astype(np.int64), generator.sample_rate
57+
)
5658
metadata = artifacts["metadata"]
5759

5860
# Decode chunks

dac/utils/encode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def process(
104104
codebook_indices = torch.cat(codebook_indices, dim=0)
105105

106106
return {
107-
"codes": codebook_indices.numpy(),
107+
"codes": codebook_indices.numpy().astype(np.uint16),
108108
"metadata": {
109109
"original_db": input_db,
110110
"overlap_hop_duration": overlap_hop_duration,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setup(
88
name="descript-audio-codec",
9-
version="0.0.2",
9+
version="0.0.3",
1010
classifiers=[
1111
"Intended Audience :: Developers",
1212
"Natural Language :: English",

tests/test_cli.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,24 @@ def test_reconstruction():
5050
run("decode")
5151

5252

53+
def test_compression():
54+
# Test encoding
55+
input_dir = Path(__file__).parent / "assets" / "input"
56+
output_dir = input_dir.parent / "encoded_output_quantizers"
57+
args = {"input": str(input_dir), "output": str(output_dir), "n_quantizers": 3}
58+
with argbind.scope(args):
59+
run("encode")
60+
61+
# Open .dac file
62+
dac_file = output_dir / "sample_0.dac"
63+
artifacts = np.load(dac_file, allow_pickle=True)[()]
64+
codes = artifacts["codes"]
65+
66+
# Ensure that the number of quantizers is correct
67+
assert codes.shape[1] == 3
68+
69+
# Ensure that dtype of compression is uint16
70+
assert codes.dtype == np.uint16
71+
72+
5373
# CUDA_VISIBLE_DEVICES=0 python -m pytest tests/test_cli.py -s

0 commit comments

Comments
 (0)