|
| 1 | +""" |
| 2 | +Dithered Finite Scalar Quantization |
| 3 | +Code adapted from https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/finite_scalar_quantization.py |
| 4 | +""" |
| 5 | + |
| 6 | +from typing import List, Tuple |
| 7 | +import random |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | +from torch.nn import Module |
| 12 | +from torch import Tensor, int32 |
| 13 | +from torch.amp import autocast |
| 14 | + |
| 15 | +from einops import rearrange |
| 16 | + |
| 17 | + |
| 18 | +def leaky_hard_clip(x: Tensor, alpha: float = 1e-3) -> Tensor: |
| 19 | + return (1-alpha) * torch.clamp(x, -1, 1) + alpha * x |
| 20 | + |
| 21 | +def round_ste(z: Tensor) -> Tensor: |
| 22 | + """Round with straight through gradients.""" |
| 23 | + zhat = z.round() |
| 24 | + return z + (zhat - z).detach() |
| 25 | + |
| 26 | +class DitheredFSQ(Module): |
| 27 | + def __init__( |
| 28 | + self, |
| 29 | + levels: List[int], |
| 30 | + dither_inference: bool = False, |
| 31 | + num_codebooks: int = 1, |
| 32 | + noise_dropout: float = 0.5, |
| 33 | + scale: float = 1.0, |
| 34 | + ): |
| 35 | + super().__init__() |
| 36 | + self.levels = levels |
| 37 | + |
| 38 | + _levels = torch.tensor(levels, dtype=torch.int64) |
| 39 | + self.register_buffer("_levels", _levels, persistent = False) |
| 40 | + |
| 41 | + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int64) |
| 42 | + self.register_buffer("_basis", _basis, persistent = False) |
| 43 | + |
| 44 | + codebook_dim = len(levels) |
| 45 | + self.codebook_dim = codebook_dim |
| 46 | + |
| 47 | + self.codebook_size = _levels.prod().item() |
| 48 | + |
| 49 | + self.num_codebooks = num_codebooks |
| 50 | + |
| 51 | + self.dim = codebook_dim * num_codebooks |
| 52 | + |
| 53 | + self.dither_inference = dither_inference |
| 54 | + |
| 55 | + self.scale = scale |
| 56 | + |
| 57 | + half_l = self.scale * 2 / (self._levels - 1) |
| 58 | + self.register_buffer("half_l", half_l, persistent = False) |
| 59 | + |
| 60 | + self.allowed_dtypes = (torch.float32, torch.float64) |
| 61 | + |
| 62 | + self.noise_dropout = noise_dropout |
| 63 | + |
| 64 | + def quantize(self, z, skip_tanh: bool = False): |
| 65 | + if not skip_tanh: z = torch.tanh(z) |
| 66 | + |
| 67 | + if not self.training: |
| 68 | + quantized = self._scale_and_shift_inverse(round_ste(self._scale_and_shift(z))) |
| 69 | + else: |
| 70 | + quantized = z |
| 71 | + mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z) |
| 72 | + quantized = torch.where(mask, quantized, self._scale_and_shift_inverse(round_ste(self._scale_and_shift(quantized)))) |
| 73 | + mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z) |
| 74 | + quantized = torch.where(mask, quantized, z + (torch.rand_like(z) - 0.5) * self.half_l) |
| 75 | + |
| 76 | + return quantized |
| 77 | + |
| 78 | + def _scale_and_shift(self, z): |
| 79 | + level_indices = (z + 1 * self.scale) / self.half_l |
| 80 | + return level_indices |
| 81 | + |
| 82 | + def _scale_and_shift_inverse(self, level_indices): |
| 83 | + z = level_indices * self.half_l - 1 * self.scale |
| 84 | + return z |
| 85 | + |
| 86 | + def _indices_to_codes(self, indices): |
| 87 | + level_indices = self._indices_to_level_indices(indices) |
| 88 | + codes = self._scale_and_shift_inverse(level_indices) |
| 89 | + return codes |
| 90 | + |
| 91 | + def _codes_to_indices(self, zhat): |
| 92 | + zhat = self._scale_and_shift(zhat) |
| 93 | + zhat = zhat.round().to(torch.int64) |
| 94 | + out = (zhat * self._basis).sum(dim=-1) |
| 95 | + return out |
| 96 | + |
| 97 | + def _indices_to_level_indices(self, indices): |
| 98 | + indices = rearrange(indices, '... -> ... 1') |
| 99 | + codes_non_centered = (indices // self._basis) % self._levels |
| 100 | + return codes_non_centered |
| 101 | + |
| 102 | + def indices_to_codes(self, indices): |
| 103 | + # Expects input of batch x sequence x num_codebooks |
| 104 | + assert indices.shape[-1] == self.num_codebooks, f'expected last dimension of {self.num_codebooks} but found last dimension of {indices.shape[-1]}' |
| 105 | + codes = self._indices_to_codes(indices.to(torch.int64)) |
| 106 | + codes = rearrange(codes, '... c d -> ... (c d)') |
| 107 | + return codes |
| 108 | + |
| 109 | + @autocast(device_type="cuda", enabled = False) |
| 110 | + def forward(self, z, skip_tanh: bool = False): |
| 111 | + |
| 112 | + orig_dtype = z.dtype |
| 113 | + |
| 114 | + assert z.shape[-1] == self.dim, f'expected dimension of {self.num_codebooks * self.dim} but found dimension of {z.shape[-1]}' |
| 115 | + |
| 116 | + z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks) |
| 117 | + |
| 118 | + # make sure allowed dtype before quantizing |
| 119 | + |
| 120 | + if z.dtype not in self.allowed_dtypes: |
| 121 | + z = z.to(torch.float64) |
| 122 | + |
| 123 | + codes = self.quantize(z, skip_tanh=skip_tanh) |
| 124 | + indices = self._codes_to_indices(codes) |
| 125 | + codes = rearrange(codes, 'b n c d -> b n (c d)') |
| 126 | + |
| 127 | + # cast codes back to original dtype |
| 128 | + |
| 129 | + if codes.dtype != orig_dtype: |
| 130 | + codes = codes.type(orig_dtype) |
| 131 | + |
| 132 | + # return quantized output and indices |
| 133 | + |
| 134 | + return codes, indices |
0 commit comments