Skip to content

Commit 06e8049

Browse files
authored
24kHz model release (#13)
* add support for multiple models * expose model_type arg to cli + docs update * version updates * update tests * expose download args via cli * doc correction * minor change * update readme * minor change
1 parent 3e877db commit 06e8049

File tree

7 files changed

+130
-26
lines changed

7 files changed

+130
-26
lines changed

README.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,14 @@ pip install git+https://github.com/descriptinc/descript-audio-codec
3232

3333
### Weights
3434
Weights are released as part of this repo under MIT license.
35-
They are automatically downloaded when you first run `encode` or `decode` command. They can be cached locally with
35+
We release weights for models that can natively support 24kHz and 44.1kHz sampling rates.
36+
Weights are automatically downloaded the first time you run the `encode` or `decode` command. You can also cache them ahead of time using one of the following commands:
37+
```bash
38+
python3 -m dac download # downloads the default 44kHz variant
39+
python3 -m dac download --model_type 44khz # downloads the 44kHz variant
40+
python3 -m dac download --model_type 24khz # downloads the 24kHz variant
3641
```
37-
python3 -m dac download
38-
```
39-
We provide a Dockerfile that installs all required dependencies for encoding and decoding. The build process caches model weights inside the image. This allows the image to be used without an internet connection. [Please refer to instructions below.](#docker-image)
42+
We provide a Dockerfile that installs all required dependencies for encoding and decoding. The build process caches the default model weights inside the image. This allows the image to be used without an internet connection. [Please refer to the instructions below.](#docker-image)
4043

4144

4245
### Compress audio
@@ -74,7 +77,7 @@ from audiotools import AudioSignal
7477
model = DAC()
7578

7679
# Load compatible pre-trained model
77-
model = load_model(dac.__model_version__)
80+
model = load_model(tag="latest", model_type="44khz")
7881
model.eval()
7982
model.to('cuda')
8083

dac/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
__version__ = "0.0.3"
2-
__model_version__ = "0.0.1"
1+
__version__ = "0.0.4"
2+
3+
# preserved here for legacy reasons
4+
__model_version__ = "latest"
5+
36
import audiotools
47

58
audiotools.ml.BaseModel.INTERN += ["dac.**"]

dac/utils/__init__.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,68 @@
11
from pathlib import Path
22

3+
import argbind
34
from audiotools import ml
45

56
import dac
67

7-
88
DAC = dac.model.DAC
99
Accelerator = ml.Accelerator
1010

11+
__MODEL_LATEST_TAGS__ = {
12+
"44khz": "0.0.1",
13+
"24khz": "0.0.4",
14+
}
15+
16+
__MODEL_URLS__ = {
17+
(
18+
"44khz",
19+
"0.0.1",
20+
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
21+
(
22+
"24khz",
23+
"0.0.4",
24+
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
25+
}
26+
1127

12-
def ensure_default_model(tag: str = dac.__model_version__):
28+
@argbind.bind(group="download", positional=True, without_prefix=True)
29+
def ensure_default_model(tag: str = "latest", model_type: str = "44khz"):
1330
"""
14-
Function that downloads the weights file from URL if a local cache is not
15-
found.
31+
Function that downloads the weights file from URL if a local cache is not found.
1632
17-
Args:
18-
tag (str): The tag of the model to download.
33+
Parameters
34+
----------
35+
tag : str
36+
The tag of the model to download. Defaults to "latest".
37+
model_type : str
38+
The type of model to download. Must be one of "44khz" or "24khz". Defaults to "44khz".
39+
40+
Returns
41+
-------
42+
Path
43+
Directory path required to load model via audiotools.
1944
"""
20-
download_link = f"https://github.com/descriptinc/descript-audio-codec/releases/download/{tag}/weights.pth"
21-
local_path = Path.home() / ".cache" / "descript" / tag / "dac" / f"weights.pth"
45+
model_type = model_type.lower()
46+
tag = tag.lower()
47+
48+
assert model_type in [
49+
"44khz",
50+
"24khz",
51+
], "model_type must be one of '44khz' or '24khz'"
52+
53+
if tag == "latest":
54+
tag = __MODEL_LATEST_TAGS__[model_type]
55+
56+
download_link = __MODEL_URLS__.get((model_type, tag), None)
57+
58+
if download_link is None:
59+
raise ValueError(
60+
f"Could not find model with tag {tag} and model type {model_type}"
61+
)
62+
63+
local_path = (
64+
Path.home() / ".cache" / "descript" / model_type / tag / "dac" / f"weights.pth"
65+
)
2266
if not local_path.exists():
2367
local_path.parent.mkdir(parents=True, exist_ok=True)
2468

@@ -38,11 +82,12 @@ def ensure_default_model(tag: str = dac.__model_version__):
3882

3983

4084
def load_model(
41-
tag: str,
85+
tag: str = "latest",
4286
load_path: str = "",
87+
model_type: str = "44khz",
4388
):
4489
if not load_path:
45-
load_path = ensure_default_model(tag)
90+
load_path = ensure_default_model(tag, model_type)
4691
kwargs = {
4792
"folder": load_path,
4893
"map_location": "cpu",

dac/utils/decode.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from audiotools import AudioSignal
88
from tqdm import tqdm
99

10-
import dac
1110
from dac.utils import load_model
1211

1312
warnings.filterwarnings("ignore", category=UserWarning)
@@ -99,13 +98,36 @@ def decode(
9998
input: str,
10099
output: str = "",
101100
weights_path: str = "",
102-
model_tag: str = dac.__model_version__,
101+
model_tag: str = "latest",
103102
preserve_sample_rate: bool = False,
104103
device: str = "cuda",
104+
model_type: str = "44khz",
105105
):
106+
"""Decode audio from codes.
107+
108+
Parameters
109+
----------
110+
input : str
111+
Path to input directory or file
112+
output : str, optional
113+
Path to output directory, by default "".
114+
If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
115+
weights_path : str, optional
116+
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
117+
model_tag and model_type.
118+
model_tag : str, optional
119+
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
120+
preserve_sample_rate : bool, optional
121+
If True, return audio will have the same sample rate as the original
122+
device : str, optional
123+
Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
124+
model_type : str, optional
125+
The type of model to download. Must be one of "44khz" or "24khz". Defaults to "44khz". Ignored if `weights_path` is specified.
126+
"""
106127
generator = load_model(
107128
tag=model_tag,
108129
load_path=weights_path,
130+
model_type=model_type,
109131
)
110132
generator.to(device)
111133
generator.eval()

dac/utils/encode.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from audiotools.core import util
1010
from tqdm import tqdm
1111

12-
import dac
1312
from dac.utils import load_model
1413

1514
warnings.filterwarnings("ignore", category=UserWarning)
@@ -124,13 +123,35 @@ def encode(
124123
input: str,
125124
output: str = "",
126125
weights_path: str = "",
127-
model_tag: str = dac.__model_version__,
126+
model_tag: str = "latest",
128127
n_quantizers: int = None,
129128
device: str = "cuda",
129+
model_type: str = "44khz",
130130
):
131+
"""Encode audio files in input path to .dac format.
132+
133+
Parameters
134+
----------
135+
input : str
136+
Path to input audio file or directory
137+
output : str, optional
138+
Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
139+
weights_path : str, optional
140+
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
141+
model_tag and model_type.
142+
model_tag : str, optional
143+
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
144+
n_quantizers : int, optional
145+
Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
146+
device : str, optional
147+
Device to use, by default "cuda"
148+
model_type : str, optional
149+
The type of model to download. Must be one of "44khz" or "24khz". Defaults to "44khz". Ignored if `weights_path` is specified.
150+
"""
131151
generator = load_model(
132152
tag=model_tag,
133153
load_path=weights_path,
154+
model_type=model_type,
134155
)
135156
generator.to(device)
136157
generator.eval()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setup(
88
name="descript-audio-codec",
9-
version="0.0.3",
9+
version="0.0.4",
1010
classifiers=[
1111
"Intended Audience :: Developers",
1212
"Natural Language :: English",

tests/test_cli.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import argbind
88
import numpy as np
9+
import pytest
10+
import torch
911
from audiotools import AudioSignal
1012

1113
from dac.__main__ import run
@@ -28,20 +30,23 @@ def teardown_module(module):
2830
subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/assets"])
2931

3032

31-
def test_reconstruction():
33+
@pytest.mark.parametrize("model_type", ["44khz", "24khz"])
34+
def test_reconstruction(model_type):
3235
# Test encoding
3336
input_dir = Path(__file__).parent / "assets" / "input"
34-
output_dir = input_dir.parent / "encoded_output"
37+
output_dir = input_dir.parent / model_type / "encoded_output"
3538
args = {
3639
"input": str(input_dir),
3740
"output": str(output_dir),
41+
"device": "cuda" if torch.cuda.is_available() else "cpu",
42+
"model_type": model_type,
3843
}
3944
with argbind.scope(args):
4045
run("encode")
4146

4247
# Test decoding
4348
input_dir = output_dir
44-
output_dir = input_dir.parent / "decoded_output"
49+
output_dir = input_dir.parent / model_type / "decoded_output"
4550
args = {
4651
"input": str(input_dir),
4752
"output": str(output_dir),
@@ -54,7 +59,12 @@ def test_compression():
5459
# Test encoding
5560
input_dir = Path(__file__).parent / "assets" / "input"
5661
output_dir = input_dir.parent / "encoded_output_quantizers"
57-
args = {"input": str(input_dir), "output": str(output_dir), "n_quantizers": 3}
62+
args = {
63+
"input": str(input_dir),
64+
"output": str(output_dir),
65+
"n_quantizers": 3,
66+
"device": "cuda" if torch.cuda.is_available() else "cpu",
67+
}
5868
with argbind.scope(args):
5969
run("encode")
6070

0 commit comments

Comments
 (0)