From 237af3382586e51f5b3fdb31c6a821cff35b1cec Mon Sep 17 00:00:00 2001 From: SebastianLoef Date: Fri, 1 Sep 2023 14:01:58 +0000 Subject: [PATCH] Fixed bug --- src/data/nsynth.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/data/nsynth.py b/src/data/nsynth.py index 08136b9..09d4d80 100644 --- a/src/data/nsynth.py +++ b/src/data/nsynth.py @@ -29,7 +29,7 @@ def __init__( self.sample_rate = sample_rate data_folder = os.path.join(root, "processed", "nsynth", f"nsynth-{self.subset}") if not os.path.exists(data_folder): - self.download() + self.download(root) self.file_paths, self.labels = self.load_data(data_folder) @property @@ -59,11 +59,11 @@ def load_data(self, data_folder: str) -> dict: labels.append(data[key][self._label]) return file_paths, labels - def download(self): - if not os.path.exists(self.root): - os.makedirs(self.root) + def download(self, root): + if not os.path.exists(root): + os.makedirs(root) url = f"http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-{self.subset}.jsonwav.tar.gz" - download_and_extract(url, self.root, "nsynth") + download_and_extract(url, root, "nsynth") def __len__(self): return len(self.file_paths)