Skip to content

Commit 26358fa

Browse files
oulgenpytorchmergebot
authored andcommitted
Add AppendingByteSerializer class (pytorch#148226)
This PR adds a new util class that enables efficient appending of sequential byte data with custom serialization and deserialization. Pull Request resolved: pytorch#148226 Approved by: https://github.com/aorenste
1 parent b59776d commit 26358fa

File tree

2 files changed

+194
-0
lines changed

2 files changed

+194
-0
lines changed
+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Owner(s): ["module: inductor"]
2+
3+
import dataclasses
4+
5+
from torch.testing._internal.common_utils import TestCase
6+
from torch.utils._appending_byte_serializer import (
7+
AppendingByteSerializer,
8+
BytesReader,
9+
BytesWriter,
10+
)
11+
12+
13+
class TestAppendingByteSerializer(TestCase):
14+
def test_write_and_read_int(self) -> None:
15+
def int_serializer(writer: BytesWriter, i: int) -> None:
16+
writer.write_uint64(i)
17+
18+
def int_deserializer(reader: BytesReader) -> int:
19+
return reader.read_uint64()
20+
21+
s = AppendingByteSerializer(serialize_fn=int_serializer)
22+
23+
data = [1, 2, 3, 4]
24+
s.extend(data)
25+
self.assertListEqual(
26+
data,
27+
AppendingByteSerializer.to_list(
28+
s.to_bytes(), deserialize_fn=int_deserializer
29+
),
30+
)
31+
32+
data2 = [8, 9, 10, 11]
33+
s.extend(data2)
34+
self.assertListEqual(
35+
data + data2,
36+
AppendingByteSerializer.to_list(
37+
s.to_bytes(), deserialize_fn=int_deserializer
38+
),
39+
)
40+
41+
def test_write_and_read_class(self) -> None:
42+
@dataclasses.dataclass(frozen=True, eq=True)
43+
class Foo:
44+
x: int
45+
y: str
46+
z: bytes
47+
48+
@staticmethod
49+
def serialize(writer: BytesWriter, cls: "Foo") -> None:
50+
writer.write_uint64(cls.x)
51+
writer.write_str(cls.y)
52+
writer.write_bytes(cls.z)
53+
54+
@staticmethod
55+
def deserialize(reader: BytesReader) -> "Foo":
56+
x = reader.read_uint64()
57+
y = reader.read_str()
58+
z = reader.read_bytes()
59+
return Foo(x, y, z)
60+
61+
a = Foo(5, "ok", bytes([15]))
62+
b = Foo(10, "lol", bytes([25]))
63+
64+
s = AppendingByteSerializer(serialize_fn=Foo.serialize)
65+
s.append(a)
66+
self.assertListEqual(
67+
[a],
68+
AppendingByteSerializer.to_list(
69+
s.to_bytes(), deserialize_fn=Foo.deserialize
70+
),
71+
)
72+
73+
s.append(b)
74+
self.assertListEqual(
75+
[a, b],
76+
AppendingByteSerializer.to_list(
77+
s.to_bytes(), deserialize_fn=Foo.deserialize
78+
),
79+
)
80+
81+
82+
if __name__ == "__main__":
83+
from torch._inductor.test_case import run_tests
84+
85+
run_tests()
+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from collections.abc import Iterable
2+
from typing import Callable, Generic, TypeVar
3+
4+
5+
T = TypeVar("T")
6+
7+
_ENCODING_VERSION: int = 1
8+
9+
__all__ = ["AppendingByteSerializer"]
10+
11+
12+
#######################################
13+
# Helper classes
14+
#######################################
15+
16+
17+
class BytesWriter:
18+
def __init__(self, preallocate_size: int) -> None:
19+
self._data = bytearray(preallocate_size)
20+
21+
def write_uint64(self, i: int) -> None:
22+
self._data.extend(i.to_bytes(8, byteorder="big", signed=False))
23+
24+
def write_str(self, s: str) -> None:
25+
payload = s.encode("utf-8")
26+
self.write_bytes(payload)
27+
28+
def write_bytes(self, b: bytes) -> None:
29+
self.write_uint64(len(b))
30+
self._data.extend(b)
31+
32+
def to_bytes(self) -> bytes:
33+
return bytes(self._data)
34+
35+
36+
class BytesReader:
37+
def __init__(self, data: bytes) -> None:
38+
self._data = data
39+
self._i = 0
40+
41+
def is_finished(self) -> bool:
42+
return len(self._data) == self._i
43+
44+
def read_uint64(self) -> int:
45+
result = int.from_bytes(
46+
self._data[self._i : self._i + 8], byteorder="big", signed=False
47+
)
48+
self._i += 8
49+
return result
50+
51+
def read_str(self) -> str:
52+
return self.read_bytes().decode("utf-8")
53+
54+
def read_bytes(self) -> bytes:
55+
size = self.read_uint64()
56+
result = self._data[self._i : self._i + size]
57+
self._i += size
58+
return result
59+
60+
61+
#######################################
62+
# AppendingByteSerializer
63+
#######################################
64+
65+
66+
class AppendingByteSerializer(Generic[T]):
67+
"""
68+
Provides efficient serialization and deserialization of list of bytes
69+
Note that this does not provide any guarantees around byte order
70+
"""
71+
72+
_serialize_fn: Callable[[BytesWriter, T], None]
73+
_writer: BytesWriter
74+
_preallocate_size: int
75+
76+
def __init__(
77+
self,
78+
*,
79+
serialize_fn: Callable[[BytesWriter, T], None],
80+
preallocate_size: int = 0,
81+
) -> None:
82+
self._serialize_fn = serialize_fn
83+
self._preallocate_size = preallocate_size
84+
self.clear()
85+
86+
def clear(self) -> None:
87+
self._writer = BytesWriter(preallocate_size=self._preallocate_size)
88+
# First 8-bytes are for version
89+
self._writer.write_uint64(_ENCODING_VERSION)
90+
91+
def append(self, data: T) -> None:
92+
self._serialize_fn(self._writer, data)
93+
94+
def extend(self, elems: Iterable[T]) -> None:
95+
for elem in elems:
96+
self.append(elem)
97+
98+
def to_bytes(self) -> bytes:
99+
return self._writer.to_bytes()
100+
101+
@staticmethod
102+
def to_list(data: bytes, *, deserialize_fn: Callable[[BytesReader], T]) -> list[T]:
103+
reader = BytesReader(data)
104+
assert reader.read_uint64() == _ENCODING_VERSION
105+
106+
result: list[T] = []
107+
while not reader.is_finished():
108+
result.append(deserialize_fn(reader))
109+
return result

0 commit comments

Comments
 (0)