
Commit 06cb796

Committed Jan 14, 2025
Add local FSQ implementation as temp fix until SAT upstream
Parent: 97fa3b5

2 files changed (+135 -1 lines)


stable_codec/fsq.py

New file, +134 lines:

"""
Dithered Finite Scalar Quantization
Code adapted from https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/finite_scalar_quantization.py
"""

from typing import List, Tuple
import random

import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
from torch.amp import autocast

from einops import rearrange


def leaky_hard_clip(x: Tensor, alpha: float = 1e-3) -> Tensor:
    return (1-alpha) * torch.clamp(x, -1, 1) + alpha * x

def round_ste(z: Tensor) -> Tensor:
    """Round with straight through gradients."""
    zhat = z.round()
    return z + (zhat - z).detach()

class DitheredFSQ(Module):
    def __init__(
        self,
        levels: List[int],
        dither_inference: bool = False,
        num_codebooks: int = 1,
        noise_dropout: float = 0.5,
        scale: float = 1.0,
    ):
        super().__init__()
        self.levels = levels

        _levels = torch.tensor(levels, dtype=torch.int64)
        self.register_buffer("_levels", _levels, persistent = False)

        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int64)
        self.register_buffer("_basis", _basis, persistent = False)

        codebook_dim = len(levels)
        self.codebook_dim = codebook_dim

        self.codebook_size = _levels.prod().item()

        self.num_codebooks = num_codebooks

        self.dim = codebook_dim * num_codebooks

        self.dither_inference = dither_inference

        self.scale = scale

        half_l = self.scale * 2 / (self._levels - 1)
        self.register_buffer("half_l", half_l, persistent = False)

        self.allowed_dtypes = (torch.float32, torch.float64)

        self.noise_dropout = noise_dropout

    def quantize(self, z, skip_tanh: bool = False):
        if not skip_tanh: z = torch.tanh(z)

        if not self.training:
            quantized = self._scale_and_shift_inverse(round_ste(self._scale_and_shift(z)))
        else:
            quantized = z
            mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z)
            quantized = torch.where(mask, quantized, self._scale_and_shift_inverse(round_ste(self._scale_and_shift(quantized))))
            mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z)
            quantized = torch.where(mask, quantized, z + (torch.rand_like(z) - 0.5) * self.half_l)

        return quantized

    def _scale_and_shift(self, z):
        level_indices = (z + 1 * self.scale) / self.half_l
        return level_indices

    def _scale_and_shift_inverse(self, level_indices):
        z = level_indices * self.half_l - 1 * self.scale
        return z

    def _indices_to_codes(self, indices):
        level_indices = self._indices_to_level_indices(indices)
        codes = self._scale_and_shift_inverse(level_indices)
        return codes

    def _codes_to_indices(self, zhat):
        zhat = self._scale_and_shift(zhat)
        zhat = zhat.round().to(torch.int64)
        out = (zhat * self._basis).sum(dim=-1)
        return out

    def _indices_to_level_indices(self, indices):
        indices = rearrange(indices, '... -> ... 1')
        codes_non_centered = (indices // self._basis) % self._levels
        return codes_non_centered

    def indices_to_codes(self, indices):
        # Expects input of batch x sequence x num_codebooks
        assert indices.shape[-1] == self.num_codebooks, f'expected last dimension of {self.num_codebooks} but found last dimension of {indices.shape[-1]}'
        codes = self._indices_to_codes(indices.to(torch.int64))
        codes = rearrange(codes, '... c d -> ... (c d)')
        return codes

    @autocast(device_type="cuda", enabled = False)
    def forward(self, z, skip_tanh: bool = False):

        orig_dtype = z.dtype

        assert z.shape[-1] == self.dim, f'expected dimension of {self.num_codebooks * self.dim} but found dimension of {z.shape[-1]}'

        z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)

        # make sure allowed dtype before quantizing
        if z.dtype not in self.allowed_dtypes:
            z = z.to(torch.float64)

        codes = self.quantize(z, skip_tanh=skip_tanh)
        indices = self._codes_to_indices(codes)
        codes = rearrange(codes, 'b n c d -> b n (c d)')

        # cast codes back to original dtype
        if codes.dtype != orig_dtype:
            codes = codes.type(orig_dtype)

        # return quantized output and indices
        return codes, indices
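
For orientation, here is a minimal usage sketch of the DitheredFSQ module added above. The levels, shapes, and batch size are illustrative assumptions, not values taken from this commit; only the import path and the (codes, indices) return contract come from the code itself.

import torch
from stable_codec.fsq import DitheredFSQ

# Illustrative configuration (assumed values): four levels per codebook and two
# codebooks, so the expected feature size is len(levels) * num_codebooks = 8.
quantizer = DitheredFSQ(levels=[8, 5, 5, 5], num_codebooks=2)

z = torch.randn(2, 100, quantizer.dim)  # (batch, sequence, codebook_dim * num_codebooks)

# Inference path: eval() selects the deterministic tanh -> round -> rescale branch.
quantizer.eval()
with torch.no_grad():
    codes, indices = quantizer(z)

print(codes.shape)              # torch.Size([2, 100, 8])  quantized values
print(indices.shape)            # torch.Size([2, 100, 2])  one integer index per codebook
print(quantizer.codebook_size)  # 1000 (= 8 * 5 * 5 * 5) possible codes per codebook

# Training path: per batch element, the value is either kept raw, rounded, or
# perturbed with uniform dither, gated by two Bernoulli(noise_dropout) masks.
quantizer.train()
codes_train, _ = quantizer(z)

# Indices can be mapped back to code vectors; in eval mode this should
# reproduce `codes`.
recovered = quantizer.indices_to_codes(indices)  # shape (2, 100, 8)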

stable_codec/residual_fsq.py

+1 -1:

@@ -3,7 +3,7 @@
 
 from typing import List, Tuple
 from einops import rearrange
-from stable_audio_tools.models.fsq import DitheredFSQ
+from .fsq import DitheredFSQ
 
 class ResidualFSQBottleneck(nn.Module):
     def __init__(self, stages: List[Tuple[List[int], float]]):
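
The only things visible in this hunk are the switch to the local import and the constructor signature of ResidualFSQBottleneck. A hedged sketch of what that means for callers follows; the stage values are hypothetical, and the role of the float in each (levels, float) tuple is an assumption based solely on the type hint.

from stable_codec.residual_fsq import ResidualFSQBottleneck

# After this commit, DitheredFSQ resolves from the local stable_codec/fsq.py,
# so this import no longer pulls in stable_audio_tools.models.fsq.
# Hypothetical stages matching List[Tuple[List[int], float]]; the float is
# assumed to be a per-stage scaling factor, which this diff does not confirm.
bottleneck = ResidualFSQBottleneck(stages=[
    ([8, 5, 5, 5], 1.0),
    ([5, 5, 5], 0.25),
])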
