Skip to content

Commit e6948a6

Browse files
author
The TensorFlow Datasets Authors
committed
Fix tfds builders that try to access gcs even though the data is local.
PiperOrigin-RevId: 666405348
1 parent 858fbe5 commit e6948a6

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

tensorflow_datasets/datasets/pg19/pg19_dataset_builder.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from __future__ import annotations
1919

20+
from collections.abc import Mapping
2021
import os
2122

2223
import numpy as np
@@ -44,13 +45,27 @@ def _info(self):
4445
homepage='https://github.com/deepmind/pg19',
4546
)
4647

48+
def _get_paths(self, data_dir: str) -> Mapping[str, str]:
  """Maps each PG-19 resource name to its path underneath `data_dir`.

  Args:
    data_dir: Directory that the resource paths are joined onto.

  Returns:
    A mapping with the keys 'metadata', 'train', 'validation' and 'test'
    (in that order). 'metadata' points at `metadata.csv` inside `data_dir`;
    every other key points at the entry of the same name inside `data_dir`.
  """
  resource_paths = {'metadata': os.path.join(data_dir, 'metadata.csv')}
  # The remaining entries are named exactly like their mapping key.
  for split_name in ('train', 'validation', 'test'):
    resource_paths[split_name] = os.path.join(data_dir, split_name)
  return resource_paths
55+
4756
def _split_generators(self, dl_manager):
4857
"""Returns SplitGenerators."""
4958
del dl_manager # Unused
5059

5160
metadata_dict = dict()
52-
metadata_path = os.path.join(_DATA_DIR, 'metadata.csv')
53-
metadata = tf.io.gfile.GFile(metadata_path).read().splitlines()
61+
if self.data_dir and all(
62+
map(os.path.exists, self._get_paths(self.data_dir).values())
63+
):
64+
data_dir = self._data_dir
65+
else:
66+
data_dir = _DATA_DIR
67+
paths = self._get_paths(data_dir)
68+
metadata = tf.io.gfile.GFile(paths['metadata']).read().splitlines()
5469

5570
for row in metadata:
5671
row_split = row.split(',')
@@ -62,21 +77,21 @@ def _split_generators(self, dl_manager):
6277
name=tfds.Split.TRAIN,
6378
gen_kwargs={
6479
'metadata': metadata_dict,
65-
'filepath': os.path.join(_DATA_DIR, 'train'),
80+
'filepath': paths['train'],
6681
},
6782
),
6883
tfds.core.SplitGenerator(
6984
name=tfds.Split.VALIDATION,
7085
gen_kwargs={
7186
'metadata': metadata_dict,
72-
'filepath': os.path.join(_DATA_DIR, 'validation'),
87+
'filepath': paths['validation'],
7388
},
7489
),
7590
tfds.core.SplitGenerator(
7691
name=tfds.Split.TEST,
7792
gen_kwargs={
7893
'metadata': metadata_dict,
79-
'filepath': os.path.join(_DATA_DIR, 'test'),
94+
'filepath': paths['test'],
8095
},
8196
),
8297
]

tensorflow_datasets/robotics/dataset_importer_builder.py

+2
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ def get_relative_dataset_location(self):
7272
pass
7373

7474
def get_dataset_location(self):
  """Returns the location this builder's dataset is read from.

  Returns:
    `self._data_dir` when it is truthy. Otherwise `str(self._GCS_BUCKET)`
    joined with the result of `self.get_relative_dataset_location()`.
  """
  local_dir = self._data_dir
  if not local_dir:
    # No data dir configured: fall back to the bucket-relative location.
    bucket_root = str(self._GCS_BUCKET)
    return os.path.join(bucket_root, self.get_relative_dataset_location())
  return local_dir

0 commit comments

Comments
 (0)