17
17
18
18
from __future__ import annotations
19
19
20
+ from collections .abc import Mapping
20
21
import os
21
22
22
23
import numpy as np
@@ -44,13 +45,27 @@ def _info(self):
44
45
homepage = 'https://github.com/deepmind/pg19' ,
45
46
)
46
47
48
+ def _get_paths (self , data_dir : str ) -> Mapping [str , str ]:
49
+ return {
50
+ 'metadata' : os .path .join (data_dir , 'metadata.csv' ),
51
+ 'train' : os .path .join (data_dir , 'train' ),
52
+ 'validation' : os .path .join (data_dir , 'validation' ),
53
+ 'test' : os .path .join (data_dir , 'test' ),
54
+ }
55
+
47
56
def _split_generators (self , dl_manager ):
48
57
"""Returns SplitGenerators."""
49
58
del dl_manager # Unused
50
59
51
60
metadata_dict = dict ()
52
- metadata_path = os .path .join (_DATA_DIR , 'metadata.csv' )
53
- metadata = tf .io .gfile .GFile (metadata_path ).read ().splitlines ()
61
+ if self .data_dir and all (
62
+ map (os .path .exists , self ._get_paths (self .data_dir ).values ())
63
+ ):
64
+ data_dir = self ._data_dir
65
+ else :
66
+ data_dir = _DATA_DIR
67
+ paths = self ._get_paths (data_dir )
68
+ metadata = tf .io .gfile .GFile (paths ['metadata' ]).read ().splitlines ()
54
69
55
70
for row in metadata :
56
71
row_split = row .split (',' )
@@ -62,21 +77,21 @@ def _split_generators(self, dl_manager):
62
77
name = tfds .Split .TRAIN ,
63
78
gen_kwargs = {
64
79
'metadata' : metadata_dict ,
65
- 'filepath' : os . path . join ( _DATA_DIR , 'train' ) ,
80
+ 'filepath' : paths [ 'train' ] ,
66
81
},
67
82
),
68
83
tfds .core .SplitGenerator (
69
84
name = tfds .Split .VALIDATION ,
70
85
gen_kwargs = {
71
86
'metadata' : metadata_dict ,
72
- 'filepath' : os . path . join ( _DATA_DIR , 'validation' ) ,
87
+ 'filepath' : paths [ 'validation' ] ,
73
88
},
74
89
),
75
90
tfds .core .SplitGenerator (
76
91
name = tfds .Split .TEST ,
77
92
gen_kwargs = {
78
93
'metadata' : metadata_dict ,
79
- 'filepath' : os . path . join ( _DATA_DIR , 'test' ) ,
94
+ 'filepath' : paths [ 'test' ] ,
80
95
},
81
96
),
82
97
]
0 commit comments