Skip to content

Commit 5d35e72

Browse files
Aviv Segallcopybara-github
Aviv Segall
authored andcommitted
Refactor the Shapenet dataset to support different BUILDER_CONFIGS.
PiperOrigin-RevId: 311760835
1 parent 6800317 commit 5d35e72

File tree

1 file changed

+82
-37
lines changed

1 file changed

+82
-37
lines changed

tensorflow_graphics/datasets/shapenet/shapenet.py

+82-37
Original file line numberDiff line numberDiff line change
@@ -50,55 +50,49 @@
5050

5151
_TAXONOMY_FILE_NAME = 'taxonomy.json'
5252

53-
_MODEL_SUBPATH = os.path.join('models', 'model_normalized.obj')
54-
5553
_SPLIT_FILE_URL = \
5654
'http://shapenet.cs.stanford.edu/shapenet/obj-zip/SHREC16/all.csv'
5755

5856
_CHECKSUMS_DIR = os.path.normpath(
5957
os.path.join(os.path.dirname(__file__), 'checksums/'))
6058

6159

62-
class Shapenet(tfds.core.GeneratorBasedBuilder):
63-
"""ShapeNetCore V2.
64-
65-
Example usage of the dataset:
60+
class ShapenetConfig(tfds.core.BuilderConfig):
61+
"""Base class for Shapenet BuilderConfigs.
6662
67-
import tensorflow_datasets as tfds
68-
from tensorflow_graphics.datasets.shapenet import Shapenet
63+
The Shapenet database builder delegates the implementation of info,
64+
split_generators and generate_examples to the specified ShapenetConfig. This
65+
is done to allow multiple versions of the dataset.
66+
"""
6967

70-
data_set = Shapenet.load(
71-
split='train',
72-
download_and_prepare_kwargs={
73-
'download_config':
74-
tfds.download.DownloadConfig(manual_dir='~/shapenet_base')
75-
})
68+
def info(self, dataset_builder):
69+
"""Delegated Shapenet._info."""
70+
raise NotImplementedError('Abstract method')
7671

77-
for example in data_set.take(1):
78-
trimesh, label, model_id = example['trimesh'], example['label'],
79-
example['model_id']
72+
def split_generators(self, dl_manager, dataset_builder):
73+
"""Delegated Shapenet._split_generators."""
74+
raise NotImplementedError('Abstract method')
8075

81-
"""
76+
def generate_examples(self, **kwargs):
77+
"""Delegated Shapenet._generate_examples."""
78+
raise NotImplementedError('Abstract method')
8279

83-
VERSION = tfds.core.Version('1.0.0')
8480

85-
@staticmethod
86-
def load(*args, **kwargs):
87-
return tfds.load('shapenet', *args, **kwargs)
81+
class MeshConfig(ShapenetConfig):
82+
"""A Shapenet config for loading the original .obj files."""
8883

89-
MANUAL_DOWNLOAD_INSTRUCTIONS = textwrap.dedent("""\
90-
manual_dir should contain the extracted ShapeNetCore.v2.zip archive.
91-
You need to register on https://shapenet.org/download/shapenetcore in order
92-
to get the link to download the dataset.
93-
""")
84+
_MODEL_SUBPATH = os.path.join('models', 'model_normalized.obj')
9485

95-
def __init__(self, *args, **kwargs):
96-
super(Shapenet, self).__init__(*args, **kwargs)
97-
tfds.download.add_checksums_dir(_CHECKSUMS_DIR)
86+
def __init__(self, model_subpath=_MODEL_SUBPATH):
87+
super(MeshConfig, self).__init__(
88+
name='shapenet_trimesh',
89+
description=_DESCRIPTION,
90+
version=tfds.core.Version('1.0.0'))
91+
self.model_subpath = model_subpath
9892

99-
def _info(self):
93+
def info(self, dataset_builder):
10094
return tfds.core.DatasetInfo(
101-
builder=self,
95+
builder=dataset_builder,
10296
description=_DESCRIPTION,
10397
features=tfds_features.FeaturesDict({
10498
'trimesh': tfg_features.TriangleMesh(),
@@ -111,8 +105,7 @@ def _info(self):
111105
citation=_CITATION,
112106
)
113107

114-
def _split_generators(self, dl_manager):
115-
"""Returns SplitGenerators."""
108+
def split_generators(self, dl_manager, dataset_builder):
116109
# Extract the synset ids from the taxonomy file and update the ClassLabel
117110
# feature.
118111
with tf.io.gfile.GFile(
@@ -122,7 +115,7 @@ def _split_generators(self, dl_manager):
122115
# Remove duplicate labels (the json file contains two identical entries
123116
# for synset '04591713').
124117
labels = list(collections.OrderedDict.fromkeys(labels))
125-
self.info.features['label'].names = labels
118+
dataset_builder.info.features['label'].names = labels
126119

127120
split_file = dl_manager.download(_SPLIT_FILE_URL)
128121
fieldnames = ['id', 'synset', 'sub_synset', 'model_id', 'split']
@@ -155,7 +148,7 @@ def _split_generators(self, dl_manager):
155148
),
156149
]
157150

158-
def _generate_examples(self, base_dir, models):
151+
def generate_examples(self, base_dir, models):
159152
"""Yields examples.
160153
161154
The structure of the examples:
@@ -172,7 +165,8 @@ def _generate_examples(self, base_dir, models):
172165
for model in models:
173166
synset = model['synset']
174167
model_id = model['model_id']
175-
model_filepath = os.path.join(base_dir, synset, model_id, _MODEL_SUBPATH)
168+
model_filepath = os.path.join(base_dir, synset, model_id,
169+
self.model_subpath)
176170
# If the model doesn't exist, skip it.
177171
if not tf.io.gfile.exists(model_filepath):
178172
continue
@@ -181,3 +175,54 @@ def _generate_examples(self, base_dir, models):
181175
'label': synset,
182176
'model_id': model_id,
183177
}
178+
179+
180+
class Shapenet(tfds.core.GeneratorBasedBuilder):
181+
"""ShapeNetCore V2.
182+
183+
Example usage of the dataset:
184+
185+
import tensorflow_datasets as tfds
186+
from tensorflow_graphics.datasets.shapenet import Shapenet
187+
188+
data_set = Shapenet.load(
189+
split='train',
190+
download_and_prepare_kwargs={
191+
'download_config':
192+
tfds.download.DownloadConfig(manual_dir='~/shapenet_base')
193+
})
194+
195+
for example in data_set.take(1):
196+
trimesh, label, model_id = example['trimesh'], example['label'],
197+
example['model_id']
198+
199+
"""
200+
201+
BUILDER_CONFIGS = [MeshConfig()]
202+
203+
VERSION = tfds.core.Version('1.0.0')
204+
205+
@staticmethod
206+
def load(*args, **kwargs):
207+
return tfds.load('shapenet', *args, **kwargs)
208+
209+
MANUAL_DOWNLOAD_INSTRUCTIONS = textwrap.dedent("""\
210+
manual_dir should contain the extracted ShapeNetCore.v2.zip archive.
211+
You need to register on https://shapenet.org/download/shapenetcore in order
212+
to get the link to download the dataset.
213+
""")
214+
215+
def __init__(self, *args, **kwargs):
216+
super(Shapenet, self).__init__(*args, **kwargs)
217+
tfds.download.add_checksums_dir(_CHECKSUMS_DIR)
218+
219+
def _info(self):
220+
return self.builder_config.info(self)
221+
222+
def _split_generators(self, dl_manager):
223+
"""Returns SplitGenerators."""
224+
return self.builder_config.split_generators(dl_manager, self)
225+
226+
def _generate_examples(self, **kwargs):
227+
"""Yields examples."""
228+
return self.builder_config.generate_examples(**kwargs)

0 commit comments

Comments
 (0)