5050
5151_TAXONOMY_FILE_NAME = 'taxonomy.json'
5252
53- _MODEL_SUBPATH = os .path .join ('models' , 'model_normalized.obj' )
54-
5553_SPLIT_FILE_URL = \
5654 'http://shapenet.cs.stanford.edu/shapenet/obj-zip/SHREC16/all.csv'
5755
5856_CHECKSUMS_DIR = os .path .normpath (
5957 os .path .join (os .path .dirname (__file__ ), 'checksums/' ))
6058
6159
62- class Shapenet (tfds .core .GeneratorBasedBuilder ):
63- """ShapeNetCore V2.
64-
65- Example usage of the dataset:
60+ class ShapenetConfig (tfds .core .BuilderConfig ):
61+ """Base class for Shapenet BuilderConfigs.
6662
67- import tensorflow_datasets as tfds
68- from tensorflow_graphics.datasets.shapenet import Shapenet
63+ The Shapenet database builder delegates the implementation of info,
64+ split_generators and generate_examples to the specified ShapenetConfig. This
65+ is done to allow multiple versions of the dataset.
66+ """
6967
70- data_set = Shapenet.load(
71- split='train',
72- download_and_prepare_kwargs={
73- 'download_config':
74- tfds.download.DownloadConfig(manual_dir='~/shapenet_base')
75- })
68+ def info (self , dataset_builder ):
69+ """Delegated Shapenet._info."""
70+ raise NotImplementedError ('Abstract method' )
7671
77- for example in data_set.take(1 ):
78- trimesh, label, model_id = example['trimesh'], example['label'],
79- example['model_id']
72+ def split_generators ( self , dl_manager , dataset_builder ):
73+ """Delegated Shapenet._split_generators."""
74+ raise NotImplementedError ( 'Abstract method' )
8075
81- """
76+ def generate_examples (self , ** kwargs ):
77+ """Delegated Shapenet._generate_examples."""
78+ raise NotImplementedError ('Abstract method' )
8279
83- VERSION = tfds .core .Version ('1.0.0' )
8480
85- @staticmethod
86- def load (* args , ** kwargs ):
87- return tfds .load ('shapenet' , * args , ** kwargs )
81+ class MeshConfig (ShapenetConfig ):
82+ """A Shapenet config for loading the original .obj files."""
8883
89- MANUAL_DOWNLOAD_INSTRUCTIONS = textwrap .dedent ("""\
90- manual_dir should contain the extracted ShapeNetCore.v2.zip archive.
91- You need to register on https://shapenet.org/download/shapenetcore in order
92- to get the link to download the dataset.
93- """ )
84+ _MODEL_SUBPATH = os .path .join ('models' , 'model_normalized.obj' )
9485
95- def __init__ (self , * args , ** kwargs ):
96- super (Shapenet , self ).__init__ (* args , ** kwargs )
97- tfds .download .add_checksums_dir (_CHECKSUMS_DIR )
86+ def __init__ (self , model_subpath = _MODEL_SUBPATH ):
87+ super (MeshConfig , self ).__init__ (
88+ name = 'shapenet_trimesh' ,
89+ description = _DESCRIPTION ,
90+ version = tfds .core .Version ('1.0.0' ))
91+ self .model_subpath = model_subpath
9892
99- def _info (self ):
93+ def info (self , dataset_builder ):
10094 return tfds .core .DatasetInfo (
101- builder = self ,
95+ builder = dataset_builder ,
10296 description = _DESCRIPTION ,
10397 features = tfds_features .FeaturesDict ({
10498 'trimesh' : tfg_features .TriangleMesh (),
@@ -111,8 +105,7 @@ def _info(self):
111105 citation = _CITATION ,
112106 )
113107
114- def _split_generators (self , dl_manager ):
115- """Returns SplitGenerators."""
108+ def split_generators (self , dl_manager , dataset_builder ):
116109 # Extract the synset ids from the taxonomy file and update the ClassLabel
117110 # feature.
118111 with tf .io .gfile .GFile (
@@ -122,7 +115,7 @@ def _split_generators(self, dl_manager):
122115 # Remove duplicate labels (the json file contains two identical entries
123116 # for synset '04591713').
124117 labels = list (collections .OrderedDict .fromkeys (labels ))
125- self .info .features ['label' ].names = labels
118+ dataset_builder .info .features ['label' ].names = labels
126119
127120 split_file = dl_manager .download (_SPLIT_FILE_URL )
128121 fieldnames = ['id' , 'synset' , 'sub_synset' , 'model_id' , 'split' ]
@@ -155,7 +148,7 @@ def _split_generators(self, dl_manager):
155148 ),
156149 ]
157150
158- def _generate_examples (self , base_dir , models ):
151+ def generate_examples (self , base_dir , models ):
159152 """Yields examples.
160153
161154 The structure of the examples:
@@ -172,7 +165,8 @@ def _generate_examples(self, base_dir, models):
172165 for model in models :
173166 synset = model ['synset' ]
174167 model_id = model ['model_id' ]
175- model_filepath = os .path .join (base_dir , synset , model_id , _MODEL_SUBPATH )
168+ model_filepath = os .path .join (base_dir , synset , model_id ,
169+ self .model_subpath )
176170 # If the model doesn't exist, skip it.
177171 if not tf .io .gfile .exists (model_filepath ):
178172 continue
@@ -181,3 +175,54 @@ def _generate_examples(self, base_dir, models):
181175 'label' : synset ,
182176 'model_id' : model_id ,
183177 }
178+
179+
180+ class Shapenet (tfds .core .GeneratorBasedBuilder ):
181+ """ShapeNetCore V2.
182+
183+ Example usage of the dataset:
184+
185+ import tensorflow_datasets as tfds
186+ from tensorflow_graphics.datasets.shapenet import Shapenet
187+
188+ data_set = Shapenet.load(
189+ split='train',
190+ download_and_prepare_kwargs={
191+ 'download_config':
192+ tfds.download.DownloadConfig(manual_dir='~/shapenet_base')
193+ })
194+
195+ for example in data_set.take(1):
196+ trimesh, label, model_id = example['trimesh'], example['label'],
197+ example['model_id']
198+
199+ """
200+
201+ BUILDER_CONFIGS = [MeshConfig ()]
202+
203+ VERSION = tfds .core .Version ('1.0.0' )
204+
205+ @staticmethod
206+ def load (* args , ** kwargs ):
207+ return tfds .load ('shapenet' , * args , ** kwargs )
208+
209+ MANUAL_DOWNLOAD_INSTRUCTIONS = textwrap .dedent ("""\
210+ manual_dir should contain the extracted ShapeNetCore.v2.zip archive.
211+ You need to register on https://shapenet.org/download/shapenetcore in order
212+ to get the link to download the dataset.
213+ """ )
214+
215+ def __init__ (self , * args , ** kwargs ):
216+ super (Shapenet , self ).__init__ (* args , ** kwargs )
217+ tfds .download .add_checksums_dir (_CHECKSUMS_DIR )
218+
219+ def _info (self ):
220+ return self .builder_config .info (self )
221+
222+ def _split_generators (self , dl_manager ):
223+ """Returns SplitGenerators."""
224+ return self .builder_config .split_generators (dl_manager , self )
225+
226+ def _generate_examples (self , ** kwargs ):
227+ """Yields examples."""
228+ return self .builder_config .generate_examples (** kwargs )
0 commit comments