Skip to content

Commit cc231cf

Browse files
authored
Serialize NamedDataStoreOutput into PTD. (#9758)
Update PTD serialization to account for blobs from the NamedDataStoreOutput. A possible future improvement is to consolidate tensors (which go through the emitter) and blobs (which come from the NamedDataStore). Differential Revision: [D70939807](https://our.internmc.facebook.com/intern/diff/D70939807/)
1 parent 7fd589d commit cc231cf

File tree

4 files changed

+185
-58
lines changed

4 files changed

+185
-58
lines changed

exir/_serialize/_serialize.py

+48-27
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66

77
# pyre-strict
88

9-
from typing import Dict, Optional, Tuple
9+
from typing import Dict, Optional, Set, Tuple
1010

1111
from executorch.exir._serialize import _serialize_pte_binary
1212

1313
from executorch.exir._serialize._cord import Cord
1414
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
1515
from executorch.exir._serialize.data_serializer import (
16+
DataEntry,
1617
DataPayload,
1718
DataSerializer,
1819
TensorEntry,
@@ -74,39 +75,59 @@ def serialize_for_executorch(
7475
tensor.extra_tensor_info.fully_qualified_name
7576
] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order)
7677

78+
if len(fqn_to_tensor_layout) == 0 and (
79+
named_data is None or len(named_data.external_data) == 0
80+
):
81+
return pte, ptd_files
82+
83+
# Consolidate tensors and opaque data with the same external tag so they
84+
# can be saved to the same PTD.
85+
all_external_tags: Set[str] = set()
86+
if named_data is not None and len(named_data.external_data) > 0:
87+
assert (
88+
len(named_data.buffers) > 0
89+
), "External data exists, but there are no buffers provided."
90+
all_external_tags = set(named_data.external_data.keys())
91+
7792
if len(fqn_to_tensor_layout) > 0:
7893
# emitter_output.external_constant_map contains the mapping from
7994
# {file: {fqn: index into external_constant_buffer}}
8095
# Contains the locations of the tensor buffers, and must be non-empty
8196
# if there are external tensors to serialize.
82-
assert emitter_output.external_constant_map is not None
83-
for (
84-
filename,
85-
fqn_to_index,
86-
) in (
87-
# pyre-ignore Undefined attribute [16]: Optional type has no attribute `items`.
88-
emitter_output.external_constant_map.items()
89-
):
90-
# Create a TensorEntry for each external tensor.
91-
fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
92-
for fqn, index in fqn_to_index.items():
93-
assert fqn in fqn_to_tensor_layout
94-
fqn_to_tensor_entry[fqn] = TensorEntry(
95-
buffer_index=index,
96-
layout=fqn_to_tensor_layout[fqn],
97-
)
98-
99-
ptd_files[filename] = data_serializer.serialize(
100-
DataPayload(
101-
buffers=emitter_output.external_constant_buffer,
102-
fqn_to_tensor=fqn_to_tensor_entry,
103-
)
97+
assert (
98+
emitter_output.external_constant_map is not None
99+
), "External exists, but there are no buffers provided."
100+
all_external_tags = all_external_tags | set(
101+
emitter_output.external_constant_map.keys()
102+
)
103+
104+
for tag in all_external_tags:
105+
fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
106+
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
107+
fqn_to_index = emitter_output.external_constant_map.get(tag, {})
108+
# Create a TensorEntry for each external tensor.
109+
for fqn, index in fqn_to_index.items():
110+
assert fqn in fqn_to_tensor_layout
111+
fqn_to_tensor_entry[fqn] = TensorEntry(
112+
buffer_index=index,
113+
layout=fqn_to_tensor_layout[fqn],
104114
)
105115

106-
if named_data is None or len(named_data.external_data) == 0:
107-
return pte, ptd_files
116+
# Extract external data.
117+
key_to_data: Dict[str, DataEntry] = {}
118+
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
119+
key_to_buffer_index = named_data.external_data.get(tag, {})
120+
for key, index in key_to_buffer_index.items():
121+
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
122+
key_to_data[key] = DataEntry(index, named_data.buffers[index].alignment)
108123

109-
if len(named_data.buffers) == 0:
110-
raise RuntimeError("External data exists, but there are no buffers provided.")
124+
# Serialize into PTD file.
125+
ptd_files[tag] = data_serializer.serialize(
126+
DataPayload(
127+
buffers=emitter_output.external_constant_buffer,
128+
fqn_to_tensor=fqn_to_tensor_entry,
129+
key_to_data=key_to_data,
130+
)
131+
)
111132

112133
return pte, ptd_files

exir/_serialize/data_serializer.py

+17
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,21 @@ class TensorEntry:
3838
layout: TensorLayout
3939

4040

41+
@dataclass
class DataEntry:
    """Describes one opaque blob held by a `DataPayload`.

    Attributes:
        buffer_index: Position within `DataPayload.buffers` of the buffer
            this entry refers to.
        alignment: The alignment of the data.
    """

    # Field order matters: callers construct this positionally as
    # DataEntry(buffer_index, alignment).
    buffer_index: int
    alignment: int
54+
55+
4156
@dataclass
4257
class DataPayload:
4358
"""Contains the data and metadata required for serialization.
@@ -49,10 +64,12 @@ class DataPayload:
4964
Attributes:
5065
buffers: a sequence of tensor buffers.
5166
fqn_to_tensor: a map from fully qualified names to serializable tensors.
67+
key_to_data: a map from unique keys to serializable opaque data.
5268
"""
5369

5470
buffers: Sequence[bytes]
5571
fqn_to_tensor: Dict[str, TensorEntry]
72+
key_to_data: Dict[str, DataEntry]
5673

5774

5875
class DataSerializer(ABC):

extension/flat_tensor/serialize/serialize.py

+70-14
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
import json
10+
import math
1011
import os
1112
import tempfile
1213
from dataclasses import dataclass
@@ -19,6 +20,7 @@
1920
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
2021
from executorch.exir._serialize._program import _insert_flatbuffer_header
2122
from executorch.exir._serialize.data_serializer import (
23+
DataEntry,
2224
DataPayload,
2325
DataSerializer,
2426
TensorEntry,
@@ -29,6 +31,7 @@
2931
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
3032
DataSegment,
3133
FlatTensor,
34+
NamedData,
3235
TensorMetadata,
3336
)
3437

@@ -202,6 +205,24 @@ def to_bytes(self) -> bytes:
202205
return data
203206

204207

208+
@dataclass
class AlignedData:
    """
    Pairs data to be serialized with the alignment it requires.

    Attributes:
        data: The data to serialize, as a cord.
        alignment: The alignment required for the data. Falls back to 1
            when the constructor is given no alignment.
    """

    data: Cord
    alignment: int

    # An explicit __init__ is kept by @dataclass (it does not overwrite one
    # defined in the class body), which lets `alignment` be optional here.
    def __init__(self, data: Cord, alignment: Optional[int] = None) -> None:
        # A missing (None) or zero alignment is normalized to 1.
        self.alignment = alignment if alignment else 1
        self.data = data
224+
225+
205226
def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
206227
"""Returns the extended header of the flat_tensor data, if present and valid."""
207228
try:
@@ -216,7 +237,7 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
216237
def _extract_tensors(
217238
fqn_to_tensor: Dict[str, TensorEntry],
218239
buffers: Sequence[bytes],
219-
segments: List[Cord],
240+
segments: List[AlignedData],
220241
tensor_alignment: int,
221242
) -> List[TensorMetadata]:
222243
"""Places tensors into a single segment, aligned to tensor_alignment within
@@ -265,10 +286,43 @@ def _extract_tensors(
265286
offset=offset,
266287
)
267288
)
268-
segments.append(tensor_data)
289+
segments.append(AlignedData(tensor_data))
269290
return tensors
270291

271292

293+
def _extract_named_data(
    key_to_data: Dict[str, DataEntry],
    buffers: Sequence[bytes],
    segments: List[AlignedData],
) -> List[NamedData]:
    """Appends each referenced blob to `segments` and records where it went.

    Every distinct buffer index is placed into its own new segment, carrying
    the alignment of the first entry that references it. Entries that share a
    buffer index reuse the segment created for that buffer.

    Args:
        key_to_data: A map from keys to opaque data entries.
        buffers: A sequence of buffers holding opaque blob data.
        segments: A list of segments to append data to. Modified in-place.

    Returns:
        A list of NamedData, one per key, each pointing at the index of the
        segment that holds the key's blob.
    """

    # buffer index -> index of the segment already created for that buffer.
    buffer_to_segment: Dict[int, int] = {}

    result: List[NamedData] = []
    for key, entry in key_to_data.items():
        source_index = entry.buffer_index
        if source_index not in buffer_to_segment:
            # First time this buffer is seen: it becomes the next segment.
            blob = AlignedData(Cord(buffers[source_index]), entry.alignment)
            buffer_to_segment[source_index] = len(segments)
            segments.append(blob)
        result.append(
            NamedData(key=key, segment_index=buffer_to_segment[source_index])
        )
    return result
324+
325+
272326
class FlatTensorSerializer(DataSerializer):
273327
"""A concrete implementation of the DataSerializer interface that
274328
serializes and deserializes data to/from the FlatTensor format.
@@ -289,35 +343,37 @@ def serialize(
289343
) -> Cord:
290344
"""Serializes a list of tensors and named data into a blob."""
291345

292-
segments: List[Cord] = []
346+
segments: List[AlignedData] = []
293347
tensors = _extract_tensors(
294348
data.fqn_to_tensor,
295349
data.buffers,
296350
segments,
297351
self.config.tensor_alignment,
298352
)
353+
named_data = _extract_named_data(data.key_to_data, data.buffers, segments)
299354

300355
data_segments: List[DataSegment] = []
301-
segment_data = Cord()
356+
aggregated_segment_data = Cord()
302357
for segment in segments:
303358
prev_end = (
304359
(data_segments[-1].offset + data_segments[-1].size)
305360
if data_segments
306361
else 0
307362
)
363+
alignment = math.lcm(self.config.segment_alignment, segment.alignment)
308364
data_segments.append(
309365
DataSegment(
310-
offset=aligned_size(prev_end, self.config.segment_alignment),
311-
size=len(segment),
366+
offset=aligned_size(prev_end, alignment),
367+
size=len(segment.data),
312368
)
313369
)
314-
# Pad segment_data to segment alignment.
370+
# Pad aggregated_segment_data to segment alignment.
315371
segment_pad_length = padding_required(
316-
len(segment_data), self.config.segment_alignment
372+
len(aggregated_segment_data), alignment
317373
)
318374
if segment_pad_length > 0:
319-
segment_data.append(b"\x00" * segment_pad_length)
320-
segment_data.append(segment)
375+
aggregated_segment_data.append(b"\x00" * segment_pad_length)
376+
aggregated_segment_data.append(segment.data)
321377

322378
# Create FlatTensor, which describes the contents of the file and
323379
# points to all the data segments. It will be serialized to flatbuffer.
@@ -326,7 +382,7 @@ def serialize(
326382
tensor_alignment=self.config.tensor_alignment,
327383
tensors=tensors,
328384
segments=data_segments,
329-
named_data=[],
385+
named_data=named_data,
330386
)
331387

332388
flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)
@@ -351,7 +407,7 @@ def serialize(
351407
flatbuffer_offset=padded_header_length,
352408
flatbuffer_size=len(flatbuffer_payload),
353409
segment_base_offset=segment_base_offset,
354-
segment_data_size=len(segment_data),
410+
segment_data_size=len(aggregated_segment_data),
355411
).to_bytes()
356412

357413
# Pad header and payload to segment alignment.
@@ -371,15 +427,15 @@ def serialize(
371427
assert eh.flatbuffer_size == original_flatbuffer_payload_size
372428
assert eh.segment_base_offset == segment_base_offset
373429
assert eh.flatbuffer_offset == padded_header_length
374-
assert eh.segment_data_size == len(segment_data)
430+
assert eh.segment_data_size == len(aggregated_segment_data)
375431

376432
del header_data
377433
del flatbuffer_payload
378434

379435
# Place everything into one segment.
380436
payload = Cord()
381437
payload.append(injected_flatbuffer_data)
382-
payload.append(segment_data)
438+
payload.append(aggregated_segment_data)
383439

384440
return payload
385441

0 commit comments

Comments
 (0)