Skip to content

Commit 80a8585

Browse files
Add methods for serialization
1 parent 4b97420 commit 80a8585

File tree

1 file changed

+118
-36
lines changed

1 file changed

+118
-36
lines changed

kaitaistruct.py

+118-36
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import struct
44
from io import open, BytesIO, SEEK_CUR, SEEK_END # noqa
55

6-
PY2 = sys.version_info[0] == 2
6+
PY_OLD = sys.version_info[0] == 2 or (sys.version_info[0] == 3 and sys.version_info[1] < 2)
77

8-
# Kaitai Struct runtime streaming API version, defined as per PEP-0396
9-
# standard. Used for two purposes:
8+
# Kaitai Struct runtime streaming API version, defined as per PEP-0396 standard.
9+
# Used for two purposes:
1010
#
11-
# * .py files generated by ksc from .ksy check that they import proper
12-
# KS runtime library by this version number;
11+
# * .py files generated by ksc from .ksy check that they import proper KS runtime
12+
# library by this version number;
1313
# * distribution utils (setup.py) use this when packaging for PyPI
1414
#
1515
__version__ = '0.9'
@@ -38,14 +38,32 @@ def from_file(cls, filename):
3838
f.close()
3939
raise
4040

41+
@classmethod
42+
def to_file(cls, filename, data=None):
43+
f = open(filename, 'rb')
44+
try:
45+
return cls(KaitaiStream(f), _mode='w', _data=data)
46+
except Exception:
47+
# close file descriptor, then reraise the exception
48+
f.close()
49+
raise
50+
4151
@classmethod
4252
def from_bytes(cls, buf):
4353
return cls(KaitaiStream(BytesIO(buf)))
4454

55+
@classmethod
56+
def to_bytes(cls, buf, data=None):
57+
return cls(KaitaiStream(BytesIO(buf)), _mode='w', _data=data)
58+
4559
@classmethod
4660
def from_io(cls, io):
4761
return cls(KaitaiStream(io))
4862

63+
@classmethod
64+
def to_io(cls, io, data=None):
65+
return cls(KaitaiStream(io), _mode='w', _data=data)
66+
4967

5068
class KaitaiStream(object):
5169
def __init__(self, io):
@@ -125,6 +143,9 @@ def size(self):
125143
def read_s1(self):
126144
return KaitaiStream.packer_s1.unpack(self.read_bytes(1))[0]
127145

146+
def write_s1(self, data):
147+
return self.write_bytes(KaitaiStream.packer_s1.pack(data))
148+
128149
# ........................................................................
129150
# Big-endian
130151
# ........................................................................
@@ -138,6 +159,15 @@ def read_s4be(self):
138159
def read_s8be(self):
139160
return KaitaiStream.packer_s8be.unpack(self.read_bytes(8))[0]
140161

162+
def write_s2be(self, data):
163+
return self.write_bytes(KaitaiStream.packer_s2be.pack(data))
164+
165+
def write_s4be(self, data):
166+
return self.write_bytes(KaitaiStream.packer_s4be.pack(data))
167+
168+
def write_s8be(self, data):
169+
return self.write_bytes(KaitaiStream.packer_s8be.pack(data))
170+
141171
# ........................................................................
142172
# Little-endian
143173
# ........................................................................
@@ -151,13 +181,25 @@ def read_s4le(self):
151181
def read_s8le(self):
152182
return KaitaiStream.packer_s8le.unpack(self.read_bytes(8))[0]
153183

184+
def write_s2le(self, data):
185+
return self.write_bytes(KaitaiStream.packer_s2le.pack(data))
186+
187+
def write_s4le(self, data):
188+
return self.write_bytes(KaitaiStream.packer_s4le.pack(data))
189+
190+
def write_s8le(self, data):
191+
return self.write_bytes(KaitaiStream.packer_s8le.pack(data))
192+
154193
# ------------------------------------------------------------------------
155194
# Unsigned
156195
# ------------------------------------------------------------------------
157196

158197
def read_u1(self):
159198
return KaitaiStream.packer_u1.unpack(self.read_bytes(1))[0]
160199

200+
def write_u1(self, data):
201+
return self.write_bytes(KaitaiStream.packer_u1.pack(data))
202+
161203
# ........................................................................
162204
# Big-endian
163205
# ........................................................................
@@ -171,6 +213,15 @@ def read_u4be(self):
171213
def read_u8be(self):
172214
return KaitaiStream.packer_u8be.unpack(self.read_bytes(8))[0]
173215

216+
def write_u2be(self, data):
217+
return self.write_bytes(KaitaiStream.packer_u2be.pack(data))
218+
219+
def write_u4be(self, data):
220+
return self.write_bytes(KaitaiStream.packer_u4be.pack(data))
221+
222+
def write_u8be(self, data):
223+
return self.write_bytes(KaitaiStream.packer_u8be.pack(data))
224+
174225
# ........................................................................
175226
# Little-endian
176227
# ........................................................................
@@ -184,6 +235,15 @@ def read_u4le(self):
184235
def read_u8le(self):
185236
return KaitaiStream.packer_u8le.unpack(self.read_bytes(8))[0]
186237

238+
def write_u2le(self, data):
239+
return self.write_bytes(KaitaiStream.packer_u2le.pack(data))
240+
241+
def write_u4le(self, data):
242+
return self.write_bytes(KaitaiStream.packer_u4le.pack(data))
243+
244+
def write_u8le(self, data):
245+
return self.write_bytes(KaitaiStream.packer_u8le.pack(data))
246+
187247
# ========================================================================
188248
# Floating point numbers
189249
# ========================================================================
@@ -203,6 +263,12 @@ def read_f4be(self):
203263
def read_f8be(self):
204264
return KaitaiStream.packer_f8be.unpack(self.read_bytes(8))[0]
205265

266+
def write_f4be(self, data):
267+
return self.write_bytes(KaitaiStream.packer_f4be.pack(data))
268+
269+
def write_f8be(self, data):
270+
return self.write_bytes(KaitaiStream.packer_f8be.pack(data))
271+
206272
# ........................................................................
207273
# Little-endian
208274
# ........................................................................
@@ -213,6 +279,12 @@ def read_f4le(self):
213279
def read_f8le(self):
214280
return KaitaiStream.packer_f8le.unpack(self.read_bytes(8))[0]
215281

282+
def write_f4le(self, data):
283+
return self.write_bytes(KaitaiStream.packer_f4le.pack(data))
284+
285+
def write_f8le(self, data):
286+
return self.write_bytes(KaitaiStream.packer_f8le.pack(data))
287+
216288
# ========================================================================
217289
# Unaligned bit values
218290
# ========================================================================
@@ -279,57 +351,67 @@ def read_bits_int_le(self, n):
279351
# Byte arrays
280352
# ========================================================================
281353

282-
def read_bytes(self, n):
354+
def alignment(self, a):
355+
return (a - self.pos()) % a
356+
357+
def read_bytes(self, n, align=0):
283358
if n < 0:
284-
raise ValueError(
285-
"requested invalid %d amount of bytes" %
286-
(n,)
287-
)
359+
raise ValueError("%d is invalid amount of bytes" % n)
288360
r = self._io.read(n)
289361
if len(r) < n:
290-
raise EOFError(
291-
"requested %d bytes, but got only %d bytes" %
292-
(n, len(r))
293-
)
362+
raise EOFError("got only %d bytes out of %d requested" % (n, len(r)))
363+
if align > 1:
364+
self._io.seek(self.alignment(align), 1)
294365
return r
295366

367+
def write_bytes(self, data, align=0, pad=0, padding=b'\0'):
368+
if data is None:
369+
return
370+
nb = len(data)
371+
if nb == 0 and align < 2 and pad < 1:
372+
return
373+
if self._io.write(data) != nb:
374+
raise Exception("not all bytes written")
375+
if pad > 0:
376+
self._io.write(padding * pad)
377+
if align > 1:
378+
self._io.write(padding * self.alignment(align))
379+
return
380+
296381
def read_bytes_full(self):
297382
return self._io.read()
298383

299-
def read_bytes_term(self, term, include_term, consume_term, eos_error):
384+
def read_bytes_term(self, term, include_term=False, consume_term=True, eos_error=True, elem_size=1):
300385
r = b''
301386
while True:
302-
c = self._io.read(1)
387+
c = self._io.read(elem_size)
303388
if c == b'':
304389
if eos_error:
305-
raise Exception(
306-
"end of stream reached, but no terminator %d found" %
307-
(term,)
308-
)
390+
raise Exception("end of stream reached, but no terminator (%d) found" % term)
309391
else:
310392
return r
311393
elif ord(c) == term:
312394
if include_term:
313395
r += c
314396
if not consume_term:
315-
self._io.seek(-1, SEEK_CUR)
397+
self._io.seek(-elem_size, SEEK_CUR)
316398
return r
317399
else:
318400
r += c
319401

402+
def write_bytes_term(self, data, term=b'\0', align=0):
403+
self.write_bytes(data, align=align, pad=1, padding=term)
404+
320405
def ensure_fixed_contents(self, expected):
321406
actual = self._io.read(len(expected))
322407
if actual != expected:
323-
raise Exception(
324-
"unexpected fixed contents: got %r, was waiting for %r" %
325-
(actual, expected)
326-
)
408+
raise Exception("unexpected fixed contents: got %r, was waiting for %r" % (actual, expected))
327409
return actual
328410

329411
@staticmethod
330-
def bytes_strip_right(data, pad_byte):
412+
def bytes_strip_right(data, pad_byte=b'\0'):
331413
new_len = len(data)
332-
if PY2:
414+
if PY_OLD:
333415
# data[...] must yield an integer, to compare with integer pad_byte
334416
data = bytearray(data)
335417

@@ -339,18 +421,18 @@ def bytes_strip_right(data, pad_byte):
339421
return data[:new_len]
340422

341423
@staticmethod
342-
def bytes_terminate(data, term, include_term):
424+
def bytes_terminate(data, term, include_term=True, elem_size=1):
343425
new_len = 0
344426
max_len = len(data)
345-
if PY2:
427+
if PY_OLD:
346428
# data[...] must yield an integer, to compare with integer term
347429
data = bytearray(data)
348430

349431
while new_len < max_len and data[new_len] != term:
350-
new_len += 1
432+
new_len += elem_size
351433

352434
if include_term and new_len < max_len:
353-
new_len += 1
435+
new_len += elem_size
354436

355437
return data[:new_len]
356438

@@ -360,14 +442,14 @@ def bytes_terminate(data, term, include_term):
360442

361443
@staticmethod
362444
def process_xor_one(data, key):
363-
if PY2:
445+
if PY_OLD:
364446
return bytes(bytearray(v ^ key for v in bytearray(data)))
365447
else:
366448
return bytes(v ^ key for v in data)
367449

368450
@staticmethod
369451
def process_xor_many(data, key):
370-
if PY2:
452+
if PY_OLD:
371453
return bytes(bytearray(a ^ b for a, b in zip(bytearray(data), itertools.cycle(bytearray(key)))))
372454
else:
373455
return bytes(a ^ b for a, b in zip(data, itertools.cycle(key)))
@@ -393,10 +475,10 @@ def process_rotate_left(data, amount, group_size):
393475
# ========================================================================
394476

395477
@staticmethod
396-
def int_from_byte(v):
397-
if PY2:
478+
def int_from_byte(v, signed=False):
479+
if PY_OLD:
398480
return ord(v)
399-
return v
481+
return int.from_bytes(v, signed=signed)
400482

401483
@staticmethod
402484
def byte_array_index(data, i):

0 commit comments

Comments
 (0)