50
50
51
51
_TAXONOMY_FILE_NAME = 'taxonomy.json'
52
52
53
- _MODEL_SUBPATH = os .path .join ('models' , 'model_normalized.obj' )
54
-
55
53
_SPLIT_FILE_URL = \
56
54
'http://shapenet.cs.stanford.edu/shapenet/obj-zip/SHREC16/all.csv'
57
55
58
56
_CHECKSUMS_DIR = os .path .normpath (
59
57
os .path .join (os .path .dirname (__file__ ), 'checksums/' ))
60
58
61
59
62
- class Shapenet (tfds .core .GeneratorBasedBuilder ):
63
- """ShapeNetCore V2.
64
-
65
- Example usage of the dataset:
60
+ class ShapenetConfig (tfds .core .BuilderConfig ):
61
+ """Base class for Shapenet BuilderConfigs.
66
62
67
- import tensorflow_datasets as tfds
68
- from tensorflow_graphics.datasets.shapenet import Shapenet
63
+ The Shapenet database builder delegates the implementation of info,
64
+ split_generators and generate_examples to the specified ShapenetConfig. This
65
+ is done to allow multiple versions of the dataset.
66
+ """
69
67
70
- data_set = Shapenet.load(
71
- split='train',
72
- download_and_prepare_kwargs={
73
- 'download_config':
74
- tfds.download.DownloadConfig(manual_dir='~/shapenet_base')
75
- })
68
+ def info (self , dataset_builder ):
69
+ """Delegated Shapenet._info."""
70
+ raise NotImplementedError ('Abstract method' )
76
71
77
- for example in data_set.take(1 ):
78
- trimesh, label, model_id = example['trimesh'], example['label'],
79
- example['model_id']
72
+ def split_generators ( self , dl_manager , dataset_builder ):
73
+ """Delegated Shapenet._split_generators."""
74
+ raise NotImplementedError ( 'Abstract method' )
80
75
81
- """
76
+ def generate_examples (self , ** kwargs ):
77
+ """Delegated Shapenet._generate_examples."""
78
+ raise NotImplementedError ('Abstract method' )
82
79
83
- VERSION = tfds .core .Version ('1.0.0' )
84
80
85
- @staticmethod
86
- def load (* args , ** kwargs ):
87
- return tfds .load ('shapenet' , * args , ** kwargs )
81
+ class MeshConfig (ShapenetConfig ):
82
+ """A Shapenet config for loading the original .obj files."""
88
83
89
- MANUAL_DOWNLOAD_INSTRUCTIONS = textwrap .dedent ("""\
90
- manual_dir should contain the extracted ShapeNetCore.v2.zip archive.
91
- You need to register on https://shapenet.org/download/shapenetcore in order
92
- to get the link to download the dataset.
93
- """ )
84
+ _MODEL_SUBPATH = os .path .join ('models' , 'model_normalized.obj' )
94
85
95
- def __init__ (self , * args , ** kwargs ):
96
- super (Shapenet , self ).__init__ (* args , ** kwargs )
97
- tfds .download .add_checksums_dir (_CHECKSUMS_DIR )
86
+ def __init__ (self , model_subpath = _MODEL_SUBPATH ):
87
+ super (MeshConfig , self ).__init__ (
88
+ name = 'shapenet_trimesh' ,
89
+ description = _DESCRIPTION ,
90
+ version = tfds .core .Version ('1.0.0' ))
91
+ self .model_subpath = model_subpath
98
92
99
- def _info (self ):
93
+ def info (self , dataset_builder ):
100
94
return tfds .core .DatasetInfo (
101
- builder = self ,
95
+ builder = dataset_builder ,
102
96
description = _DESCRIPTION ,
103
97
features = tfds_features .FeaturesDict ({
104
98
'trimesh' : tfg_features .TriangleMesh (),
@@ -111,8 +105,7 @@ def _info(self):
111
105
citation = _CITATION ,
112
106
)
113
107
114
- def _split_generators (self , dl_manager ):
115
- """Returns SplitGenerators."""
108
+ def split_generators (self , dl_manager , dataset_builder ):
116
109
# Extract the synset ids from the taxonomy file and update the ClassLabel
117
110
# feature.
118
111
with tf .io .gfile .GFile (
@@ -122,7 +115,7 @@ def _split_generators(self, dl_manager):
122
115
# Remove duplicate labels (the json file contains two identical entries
123
116
# for synset '04591713').
124
117
labels = list (collections .OrderedDict .fromkeys (labels ))
125
- self .info .features ['label' ].names = labels
118
+ dataset_builder .info .features ['label' ].names = labels
126
119
127
120
split_file = dl_manager .download (_SPLIT_FILE_URL )
128
121
fieldnames = ['id' , 'synset' , 'sub_synset' , 'model_id' , 'split' ]
@@ -155,7 +148,7 @@ def _split_generators(self, dl_manager):
155
148
),
156
149
]
157
150
158
- def _generate_examples (self , base_dir , models ):
151
+ def generate_examples (self , base_dir , models ):
159
152
"""Yields examples.
160
153
161
154
The structure of the examples:
@@ -172,7 +165,8 @@ def _generate_examples(self, base_dir, models):
172
165
for model in models :
173
166
synset = model ['synset' ]
174
167
model_id = model ['model_id' ]
175
- model_filepath = os .path .join (base_dir , synset , model_id , _MODEL_SUBPATH )
168
+ model_filepath = os .path .join (base_dir , synset , model_id ,
169
+ self .model_subpath )
176
170
# If the model doesn't exist, skip it.
177
171
if not tf .io .gfile .exists (model_filepath ):
178
172
continue
@@ -181,3 +175,54 @@ def _generate_examples(self, base_dir, models):
181
175
'label' : synset ,
182
176
'model_id' : model_id ,
183
177
}
178
+
179
+
180
+ class Shapenet (tfds .core .GeneratorBasedBuilder ):
181
+ """ShapeNetCore V2.
182
+
183
+ Example usage of the dataset:
184
+
185
+ import tensorflow_datasets as tfds
186
+ from tensorflow_graphics.datasets.shapenet import Shapenet
187
+
188
+ data_set = Shapenet.load(
189
+ split='train',
190
+ download_and_prepare_kwargs={
191
+ 'download_config':
192
+ tfds.download.DownloadConfig(manual_dir='~/shapenet_base')
193
+ })
194
+
195
+ for example in data_set.take(1):
196
+ trimesh, label, model_id = example['trimesh'], example['label'],
197
+ example['model_id']
198
+
199
+ """
200
+
201
+ BUILDER_CONFIGS = [MeshConfig ()]
202
+
203
+ VERSION = tfds .core .Version ('1.0.0' )
204
+
205
+ @staticmethod
206
+ def load (* args , ** kwargs ):
207
+ return tfds .load ('shapenet' , * args , ** kwargs )
208
+
209
+ MANUAL_DOWNLOAD_INSTRUCTIONS = textwrap .dedent ("""\
210
+ manual_dir should contain the extracted ShapeNetCore.v2.zip archive.
211
+ You need to register on https://shapenet.org/download/shapenetcore in order
212
+ to get the link to download the dataset.
213
+ """ )
214
+
215
+ def __init__ (self , * args , ** kwargs ):
216
+ super (Shapenet , self ).__init__ (* args , ** kwargs )
217
+ tfds .download .add_checksums_dir (_CHECKSUMS_DIR )
218
+
219
+ def _info (self ):
220
+ return self .builder_config .info (self )
221
+
222
+ def _split_generators (self , dl_manager ):
223
+ """Returns SplitGenerators."""
224
+ return self .builder_config .split_generators (dl_manager , self )
225
+
226
+ def _generate_examples (self , ** kwargs ):
227
+ """Yields examples."""
228
+ return self .builder_config .generate_examples (** kwargs )
0 commit comments