Skip to content

Commit 97fa3b5

Browse files
committed
Increase float precision in residual FSQ
1 parent 3cfb98c commit 97fa3b5

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

stable_codec/residual_fsq.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@ def __init__(self, stages: List[Tuple[List[int], float]]):
1818
self.codebook_size = sum(map(len, stages)) * self.n_codebooks
1919

2020
def encode(self, x):
21-
z = torch.tanh(x)
21+
input_dtype = x.dtype
22+
z = torch.tanh(x.to(torch.float64))
2223
z = rearrange(z, "b c n -> b n c")
2324

2425
r = z
2526
res_ids = []
2627
for quantizer in self.quantizers:
2728
q, ids = quantizer(r, skip_tanh=True)
28-
r = r - q
29+
r = r - q.to(torch.float64)
2930
res_ids.append(ids)
3031

3132
return res_ids

0 commit comments

Comments
 (0)