import inspect
import string
from functools import partial
from typing import Any, Callable, Dict, Sequence, Union
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_datasets.core.dataset_builder import DatasetBuilder
from dlimp.utils import parallel_vmap
def _wrap(f, is_flattened):
"""Wraps a method to return a DLataset instead of a tf.data.Dataset."""
def wrapper(*args, **kwargs):
result = f(*args, **kwargs)
if not isinstance(result, DLataset) and isinstance(result, tf.data.Dataset):
# make the result a subclass of DLataset and the original class
result.__class__ = type(
"DLataset", (DLataset, type(result)), DLataset.__dict__.copy()
)
# propagate the is_flattened flag
if is_flattened is None:
result.is_flattened = f.__self__.is_flattened
else:
result.is_flattened = is_flattened
return result
return wrapper
class DLataset(tf.data.Dataset):
"""A DLimp Dataset. This is a thin wrapper around tf.data.Dataset that adds some utilities for working
with datasets of trajectories.
A DLataset starts out as dataset of trajectories, where each dataset element is a single trajectory. A
dataset element is always a (possibly nested) dictionary from strings to tensors; however, a trajectory
has the additional property that each tensor has the same leading dimension, which is the trajectory
length. Each element of the trajectory is known as a frame.
A DLataset is just a tf.data.Dataset, so you can always use standard methods like `.map` and `.filter`.
However, a DLataset is also aware of the difference between trajectories and frames, so it provides some
additional methods. To perform a transformation at the trajectory level (e.g., restructuring, relabeling,
truncating), use `.traj_map`. To perform a transformation at the frame level (e.g., image decoding,
resizing, augmentations) use `.frame_map`.
Once there are no more trajectory-level transformation to perform, you can convert to DLataset to a
dataset of frames using `.flatten`. You can still use `.frame_map` after flattening, but using `.traj_map`
will raise an error.
"""
def __getattribute__(self, name):
# monkey-patches tf.data.Dataset methods to return DLatasets
attr = super().__getattribute__(name)
if inspect.ismethod(attr):
return _wrap(attr, None)
return attr
def _apply_options(self):
"""Applies some default options for performance."""
options = tf.data.Options()
        options.autotune.enabled = True
        options.deterministic = False
        options.experimental_optimization.apply_default_optimizations = True
        options.experimental_optimization.map_fusion = True
        options.experimental_optimization.map_and_filter_fusion = True
        options.experimental_optimization.inject_prefetch = False
        options.experimental_warm_start = True
return self.with_options(options)
def with_ram_budget(self, gb: int) -> "DLataset":
"""Sets the RAM budget for the dataset. The default is half of the available memory.
Args:
gb (int): The RAM budget in GB.
"""
options = tf.data.Options()
options.autotune.ram_budget = gb * 1024 * 1024 * 1024 # GB --> Bytes
return self.with_options(options)
@staticmethod
def from_tfrecords(
dir_or_paths: Union[str, Sequence[str]],
shuffle: bool = True,
        num_parallel_reads: int = tf.data.AUTOTUNE,
) -> "DLataset":
"""Creates a DLataset from tfrecord files. The type spec of the dataset is inferred from the first file. The
only constraint is that each example must be a trajectory where each entry is either a scalar, a tensor of shape
(1, ...), or a tensor of shape (T, ...), where T is the length of the trajectory.
Args:
dir_or_paths (Union[str, Sequence[str]]): Either a directory containing .tfrecord files, or a list of paths
to tfrecord files.
shuffle (bool, optional): Whether to shuffle the tfrecord files. Defaults to True.
num_parallel_reads (int, optional): The number of tfrecord files to read in parallel. Defaults to AUTOTUNE. This
can use an excessive amount of memory if reading from cloud storage; decrease if necessary.
"""
if isinstance(dir_or_paths, str):
paths = tf.io.gfile.glob(tf.io.gfile.join(dir_or_paths, "*.tfrecord"))
else:
paths = dir_or_paths
if len(paths) == 0:
raise ValueError(f"No tfrecord files found in {dir_or_paths}")
if shuffle:
paths = tf.random.shuffle(paths)
# extract the type spec from the first file
type_spec = _get_type_spec(paths[0])
# read the tfrecords (yields raw serialized examples)
dataset = _wrap(tf.data.TFRecordDataset, False)(
paths,
num_parallel_reads=num_parallel_reads,
)._apply_options()
# decode the examples (yields trajectories)
dataset = dataset.traj_map(partial(_decode_example, type_spec=type_spec))
# broadcast traj metadata, as well as add some extra metadata (_len, _traj_index, _frame_index)
dataset = dataset.enumerate().traj_map(_broadcast_metadata)
return dataset
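
    # Example (hypothetical path): load every *.tfrecord under a directory and inspect
    # the inferred per-trajectory structure.
    #
    #     ds = DLataset.from_tfrecords("gs://my-bucket/my-dataset", shuffle=False)
    #     print(ds.element_spec)  # nested dict of TensorSpecs, one trajectory per element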
@staticmethod
def from_rlds(
builder: DatasetBuilder,
split: str = "train",
shuffle: bool = True,
        num_parallel_reads: int = tf.data.AUTOTUNE,
) -> "DLataset":
"""Creates a DLataset from the RLDS format (which is a special case of the TFDS format).
Args:
builder (DatasetBuilder): The TFDS dataset builder to load the dataset from.
data_dir (str): The directory to load the dataset from.
split (str, optional): The split to load, specified in TFDS format. Defaults to "train".
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
num_parallel_reads (int, optional): The number of tfrecord files to read in parallel. Defaults to AUTOTUNE. This
can use an excessive amount of memory if reading from cloud storage; decrease if necessary.
"""
dataset = _wrap(builder.as_dataset, False)(
split=split,
shuffle_files=shuffle,
decoders={"steps": tfds.decode.SkipDecoding()},
read_config=tfds.ReadConfig(
skip_prefetch=True,
num_parallel_calls_for_interleave_files=num_parallel_reads,
interleave_cycle_length=num_parallel_reads,
),
        )._apply_options()
dataset = dataset.enumerate().traj_map(_broadcast_metadata_rlds)
return dataset
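
    # Example (hypothetical dataset name and data_dir): wrap an existing TFDS/RLDS builder.
    #
    #     builder = tfds.builder("my_rlds_dataset", data_dir="gs://my-bucket/tfds")
    #     ds = DLataset.from_rlds(builder, split="train", shuffle=True)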
def map(
self,
fn: Callable[[Dict[str, Any]], Dict[str, Any]],
        num_parallel_calls=tf.data.AUTOTUNE,
**kwargs,
) -> "DLataset":
return super().map(fn, num_parallel_calls=num_parallel_calls, **kwargs)
def traj_map(
self,
fn: Callable[[Dict[str, Any]], Dict[str, Any]],
        num_parallel_calls=tf.data.AUTOTUNE,
**kwargs,
) -> "DLataset":
"""Maps a function over the trajectories of the dataset. The function should take a single trajectory
as input and return a single trajectory as output.
"""
if self.is_flattened:
raise ValueError("Cannot call traj_map on a flattened dataset.")
return super().map(fn, num_parallel_calls=num_parallel_calls, **kwargs)
def frame_map(
self,
fn: Callable[[Dict[str, Any]], Dict[str, Any]],
        num_parallel_calls=tf.data.AUTOTUNE,
**kwargs,
) -> "DLataset":
"""Maps a function over the frames of the dataset. The function should take a single frame as input
and return a single frame as output.
"""
if self.is_flattened:
return super().map(fn, num_parallel_calls=num_parallel_calls, **kwargs)
else:
return super().map(
parallel_vmap(fn, num_parallel_calls=num_parallel_calls),
num_parallel_calls=num_parallel_calls,
**kwargs,
)
    def flatten(self, *, num_parallel_calls=tf.data.AUTOTUNE) -> "DLataset":
"""Flattens the dataset of trajectories into a dataset of frames."""
if self.is_flattened:
raise ValueError("Dataset is already flattened.")
dataset = self.interleave(
lambda traj: tf.data.Dataset.from_tensor_slices(traj),
cycle_length=num_parallel_calls,
num_parallel_calls=num_parallel_calls,
)
dataset.is_flattened = True
return dataset
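
    # Sketch of the effect (field name is hypothetical): a trajectory element
    # {"obs": (T, H, W, 3)} becomes T frame elements, each {"obs": (H, W, 3)}, with
    # `_frame_index` running from 0 to T - 1.
    #
    #     frames = traj_dataset.flatten()
    #     frame = next(iter(frames.iterator()))  # a single frame dict, no leading T dimension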
    def iterator(self, *, prefetch=tf.data.AUTOTUNE):
if prefetch == 0:
return self.as_numpy_iterator()
return self.prefetch(prefetch).as_numpy_iterator()
@staticmethod
def choose_from_datasets(datasets, choice_dataset, stop_on_empty_dataset=True):
if not isinstance(datasets[0], DLataset):
raise ValueError("Please pass DLatasets to choose_from_datasets.")
return _wrap(tf.data.Dataset.choose_from_datasets, datasets[0].is_flattened)(
datasets, choice_dataset, stop_on_empty_dataset=stop_on_empty_dataset
)
@staticmethod
def sample_from_datasets(
datasets,
weights=None,
seed=None,
stop_on_empty_dataset=False,
rerandomize_each_iteration=None,
):
if not isinstance(datasets[0], DLataset):
raise ValueError("Please pass DLatasets to sample_from_datasets.")
return _wrap(tf.data.Dataset.sample_from_datasets, datasets[0].is_flattened)(
datasets,
weights=weights,
seed=seed,
stop_on_empty_dataset=stop_on_empty_dataset,
rerandomize_each_iteration=rerandomize_each_iteration,
)
@staticmethod
def zip(*args, datasets=None, name=None):
if datasets is not None:
raise ValueError("Please do not pass `datasets=` to zip.")
if not isinstance(args[0], DLataset):
raise ValueError("Please pass DLatasets to zip.")
return _wrap(tf.data.Dataset.zip, args[0].is_flattened)(*args, name=name)
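

# Example of combining DLatasets (weights are hypothetical): the result inherits
# `is_flattened` from the first dataset, so mix datasets at the same stage of the pipeline.
#
#     mixed = DLataset.sample_from_datasets([ds_a, ds_b], weights=[0.7, 0.3])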
def _decode_example(
example_proto: tf.Tensor, type_spec: Dict[str, tf.TensorSpec]
) -> Dict[str, tf.Tensor]:
features = {key: tf.io.FixedLenFeature([], tf.string) for key in type_spec.keys()}
parsed_features = tf.io.parse_single_example(example_proto, features)
parsed_tensors = {
key: tf.io.parse_tensor(parsed_features[key], spec.dtype)
if spec is not None
else parsed_features[key]
for key, spec in type_spec.items()
}
for key in parsed_tensors:
if type_spec[key] is not None:
parsed_tensors[key] = tf.ensure_shape(
parsed_tensors[key], type_spec[key].shape
)
return parsed_tensors
def _get_type_spec(path: str) -> Dict[str, tf.TensorSpec]:
"""Get a type spec from a tfrecord file.
Args:
path (str): Path to a single tfrecord file.
Returns:
dict: A dictionary mapping feature names to tf.TensorSpecs.
"""
data = next(iter(tf.data.TFRecordDataset(path))).numpy()
example = tf.train.Example()
example.ParseFromString(data)
printable_chars = set(bytes(string.printable, "utf-8"))
out = {}
for key, value in example.features.feature.items():
data = value.bytes_list.value[0]
        # heuristic: if the bytes are entirely printable characters, assume this is a raw
        # string feature rather than a serialized tensor
if all(char in printable_chars for char in data):
out[key] = None
continue
tensor_proto = tf.make_tensor_proto([])
tensor_proto.ParseFromString(data)
dtype = tf.dtypes.as_dtype(tensor_proto.dtype)
shape = [d.size for d in tensor_proto.tensor_shape.dim]
if shape:
shape[0] = None # first dimension is trajectory length, which is variable
out[key] = tf.TensorSpec(shape=shape, dtype=dtype)
return out
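

# Example of an inferred type spec (keys and shapes are hypothetical):
#
#     {
#         "actions": tf.TensorSpec(shape=(None, 7), dtype=tf.float32),
#         "image": tf.TensorSpec(shape=(None,), dtype=tf.string),  # e.g. encoded JPEG bytes
#         "language_instruction": None,  # raw string feature, left undecoded
#     }
#
# The leading `None` dimension is the (variable) trajectory length; `None` entries are
# passed through `_decode_example` as raw bytes.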
def _broadcast_metadata(
i: tf.Tensor, traj: Dict[str, tf.Tensor]
) -> Dict[str, tf.Tensor]:
"""
Each element of a dlimp dataset is a trajectory. This means each entry must either have a leading dimension equal to
the length of the trajectory, have a leading dimension of 1, or be a scalar. Entries with a leading dimension of 1
and scalars are assumed to be trajectory-level metadata. This function broadcasts these entries to the length of the
trajectory, as well as adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.
"""
# get the length of each dict entry
traj_lens = {
k: tf.shape(v)[0] if len(v.shape) > 0 else None for k, v in traj.items()
}
# take the maximum length as the canonical length (elements should either be the same length or length 1)
traj_len = tf.reduce_max([l for l in traj_lens.values() if l is not None])
for k in traj:
# broadcast scalars to the length of the trajectory
if traj_lens[k] is None:
traj[k] = tf.repeat(traj[k], traj_len)
traj_lens[k] = traj_len
# broadcast length-1 elements to the length of the trajectory
if traj_lens[k] == 1:
traj[k] = tf.repeat(traj[k], traj_len, axis=0)
traj_lens[k] = traj_len
asserts = [
# make sure all the lengths are the same
tf.assert_equal(
tf.size(tf.unique(tf.stack(list(traj_lens.values()))).y),
1,
message="All elements must have the same length.",
),
]
assert "_len" not in traj
assert "_traj_index" not in traj
assert "_frame_index" not in traj
traj["_len"] = tf.repeat(traj_len, traj_len)
traj["_traj_index"] = tf.repeat(i, traj_len)
traj["_frame_index"] = tf.range(traj_len)
with tf.control_dependencies(asserts):
return traj
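

# Sketch of the broadcasting (field names are hypothetical): for a length-3 trajectory
#
#     {"obs": shape (3, ...), "task_id": scalar, "goal": shape (1, ...)}
#
# every entry ends up with leading dimension 3, and the added metadata is
#
#     {"_len": [3, 3, 3], "_traj_index": [i, i, i], "_frame_index": [0, 1, 2]}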
def _broadcast_metadata_rlds(i: tf.Tensor, traj: Dict[str, Any]) -> Dict[str, Any]:
"""
In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
entry. This function moves the "steps" entry to the top level, broadcasting any metadata to the length of the
trajectory. This function also adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.
"""
steps = traj.pop("steps")
traj_len = tf.shape(tf.nest.flatten(steps)[0])[0]
# broadcast metadata to the length of the trajectory
metadata = tf.nest.map_structure(lambda x: tf.repeat(x, traj_len), traj)
# put steps back in
assert "traj_metadata" not in steps
traj = {**steps, "traj_metadata": metadata}
assert "_len" not in traj
assert "_traj_index" not in traj
assert "_frame_index" not in traj
traj["_len"] = tf.repeat(traj_len, traj_len)
traj["_traj_index"] = tf.repeat(i, traj_len)
traj["_frame_index"] = tf.range(traj_len)
return traj
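

# Sketch of the RLDS restructuring (keys follow the usual RLDS layout, but the exact
# names depend on the dataset): an input element like
#
#     {"episode_metadata": {...}, "steps": {"observation": (T, ...), "action": (T, ...)}}
#
# becomes
#
#     {"observation": (T, ...), "action": (T, ...),
#      "traj_metadata": {"episode_metadata": {... broadcast to length T ...}},
#      "_len": (T,), "_traj_index": (T,), "_frame_index": (T,)}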