7
7
# pyre-strict
8
8
9
9
import json
10
+ import math
10
11
import os
11
12
import tempfile
12
13
from dataclasses import dataclass
19
20
from executorch .exir ._serialize ._flatbuffer import _flatc_compile , _flatc_decompile
20
21
from executorch .exir ._serialize ._program import _insert_flatbuffer_header
21
22
from executorch .exir ._serialize .data_serializer import (
23
+ DataEntry ,
22
24
DataPayload ,
23
25
DataSerializer ,
24
26
TensorEntry ,
29
31
from executorch .extension .flat_tensor .serialize .flat_tensor_schema import (
30
32
DataSegment ,
31
33
FlatTensor ,
34
+ NamedData ,
32
35
TensorMetadata ,
33
36
)
34
37
@@ -202,6 +205,24 @@ def to_bytes(self) -> bytes:
202
205
return data
203
206
204
207
208
@dataclass
class AlignedData:
    """
    A blob of serializable data paired with its alignment requirement.

    Attributes:
        data: The payload to serialize, stored as a Cord.
        alignment: The alignment recorded for the payload. Falls back to 1
            when the caller supplies no alignment (None) or a zero alignment.
    """

    data: Cord
    alignment: int

    def __init__(self, data: Cord, alignment: Optional[int] = None) -> None:
        # Hand-written initializer so callers may omit the alignment (or pass
        # None) while the stored attribute is always a usable int.
        self.data = data
        if alignment:
            self.alignment = alignment
        else:
            self.alignment = 1
224
+
225
+
205
226
def _get_extended_header (flat_tensor_data : bytes ) -> Optional [FlatTensorHeader ]:
206
227
"""Returns the extended header of the flat_tensor data, if present and valid."""
207
228
try :
@@ -216,7 +237,7 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
216
237
def _extract_tensors (
217
238
fqn_to_tensor : Dict [str , TensorEntry ],
218
239
buffers : Sequence [bytes ],
219
- segments : List [Cord ],
240
+ segments : List [AlignedData ],
220
241
tensor_alignment : int ,
221
242
) -> List [TensorMetadata ]:
222
243
"""Places tensors into a single segment, aligned to tensor_alignment within
@@ -265,10 +286,43 @@ def _extract_tensors(
265
286
offset = offset ,
266
287
)
267
288
)
268
- segments .append (tensor_data )
289
+ segments .append (AlignedData ( tensor_data ) )
269
290
return tensors
270
291
271
292
293
def _extract_named_data(
    key_to_data: Dict[str, DataEntry],
    buffers: Sequence[bytes],
    segments: List[AlignedData],
) -> List[NamedData]:
    """Places named data into segments and records the alignment for each.

    Every distinct buffer referenced by key_to_data is appended to segments
    exactly once; keys that reference the same buffer share that segment.
    When several keys share one buffer, the segment's alignment is widened to
    the least common multiple of all of their alignments, so that the
    requirement of every key pointing at the segment is honored (not only the
    first key encountered).

    Args:
        key_to_data: A map from keys to opaque data entries.
        buffers: A sequence of buffers holding opaque blob data.
        segments: A list of segments to append data to. Modified in-place.

    Returns:
        A list of NamedData describing the offsets to the opaque blob data.
    """

    # Map from buffer_idx to segment_idx.
    segment_index_map: Dict[int, int] = {}

    named_data: List[NamedData] = []
    for key, data_entry in key_to_data.items():
        buffer_idx = data_entry.buffer_index
        segment_index = segment_index_map.get(buffer_idx)
        if segment_index is None:
            # First time this buffer is seen: give it its own segment.
            segment_index = len(segments)
            segment_index_map[buffer_idx] = segment_index
            segments.append(
                AlignedData(Cord(buffers[buffer_idx]), data_entry.alignment)
            )
        else:
            # Buffer already has a segment: merge this entry's alignment into
            # it. A missing/zero alignment is treated as 1 (no constraint),
            # which also keeps math.lcm from collapsing the result to 0.
            shared_segment = segments[segment_index]
            shared_segment.alignment = math.lcm(
                shared_segment.alignment, data_entry.alignment or 1
            )
        named_data.append(NamedData(key=key, segment_index=segment_index))
    return named_data
324
+
325
+
272
326
class FlatTensorSerializer (DataSerializer ):
273
327
"""A concrete implementation of the DataSerializer interface that
274
328
serializes and deserializes data to/from the FlatTensor format.
@@ -289,35 +343,37 @@ def serialize(
289
343
) -> Cord :
290
344
"""Serializes a list of tensors and named data into a blob."""
291
345
292
- segments : List [Cord ] = []
346
+ segments : List [AlignedData ] = []
293
347
tensors = _extract_tensors (
294
348
data .fqn_to_tensor ,
295
349
data .buffers ,
296
350
segments ,
297
351
self .config .tensor_alignment ,
298
352
)
353
+ named_data = _extract_named_data (data .key_to_data , data .buffers , segments )
299
354
300
355
data_segments : List [DataSegment ] = []
301
- segment_data = Cord ()
356
+ aggregated_segment_data = Cord ()
302
357
for segment in segments :
303
358
prev_end = (
304
359
(data_segments [- 1 ].offset + data_segments [- 1 ].size )
305
360
if data_segments
306
361
else 0
307
362
)
363
+ alignment = math .lcm (self .config .segment_alignment , segment .alignment )
308
364
data_segments .append (
309
365
DataSegment (
310
- offset = aligned_size (prev_end , self . config . segment_alignment ),
311
- size = len (segment ),
366
+ offset = aligned_size (prev_end , alignment ),
367
+ size = len (segment . data ),
312
368
)
313
369
)
314
- # Pad segment_data to segment alignment.
370
+ # Pad aggregated_segment_data to segment alignment.
315
371
segment_pad_length = padding_required (
316
- len (segment_data ), self . config . segment_alignment
372
+ len (aggregated_segment_data ), alignment
317
373
)
318
374
if segment_pad_length > 0 :
319
- segment_data .append (b"\x00 " * segment_pad_length )
320
- segment_data .append (segment )
375
+ aggregated_segment_data .append (b"\x00 " * segment_pad_length )
376
+ aggregated_segment_data .append (segment . data )
321
377
322
378
# Create FlatTensor, which describes the contents of the file and
323
379
# points to all the data segments. It will be serialized to flatbuffer.
@@ -326,7 +382,7 @@ def serialize(
326
382
tensor_alignment = self .config .tensor_alignment ,
327
383
tensors = tensors ,
328
384
segments = data_segments ,
329
- named_data = [] ,
385
+ named_data = named_data ,
330
386
)
331
387
332
388
flatbuffer_payload = _serialize_to_flatbuffer (flat_tensor )
@@ -351,7 +407,7 @@ def serialize(
351
407
flatbuffer_offset = padded_header_length ,
352
408
flatbuffer_size = len (flatbuffer_payload ),
353
409
segment_base_offset = segment_base_offset ,
354
- segment_data_size = len (segment_data ),
410
+ segment_data_size = len (aggregated_segment_data ),
355
411
).to_bytes ()
356
412
357
413
# Pad header and payload to segment alignment.
@@ -371,15 +427,15 @@ def serialize(
371
427
assert eh .flatbuffer_size == original_flatbuffer_payload_size
372
428
assert eh .segment_base_offset == segment_base_offset
373
429
assert eh .flatbuffer_offset == padded_header_length
374
- assert eh .segment_data_size == len (segment_data )
430
+ assert eh .segment_data_size == len (aggregated_segment_data )
375
431
376
432
del header_data
377
433
del flatbuffer_payload
378
434
379
435
# Place everything into one segment.
380
436
payload = Cord ()
381
437
payload .append (injected_flatbuffer_data )
382
- payload .append (segment_data )
438
+ payload .append (aggregated_segment_data )
383
439
384
440
return payload
385
441
0 commit comments