Skip to content

Commit cf35373

Browse files
committed
Complete overhaul of interaction between KaitaiStream and KaitaiStruct
1 parent 4accba9 commit cf35373

File tree

1 file changed

+195
-24
lines changed

1 file changed

+195
-24
lines changed

kaitaistruct.py

+195-24
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
import typing
12
import itertools
23
import sys
34
import struct
45
from io import open, BytesIO, SEEK_CUR, SEEK_END # noqa
6+
from io import IOBase, BufferedIOBase
7+
import mmap
8+
from pathlib import Path
9+
from abc import ABC, abstractmethod
510

611
PY2 = sys.version_info[0] == 2
712

@@ -24,51 +29,216 @@
2429
# pylint: disable=useless-object-inheritance,super-with-arguments,consider-using-f-string
2530

2631

27-
class KaitaiStruct(object):
28-
def __init__(self, stream):
29-
self._io = stream
32+
class _NonClosingNonParsingKaitaiStruct:
33+
__slots__ = ("_io", "_parent", "_root")
34+
35+
def __init__(self, _io: "KaitaiStream", _parent: typing.Optional["_NonClosingNonParsingKaitaiStruct"] = None, _root: typing.Optional["_NonClosingNonParsingKaitaiStruct"] = None):
36+
self._io = _io
37+
self._parent = _parent
38+
self._root = _root if _root else self
39+
40+
41+
class NonClosingKaitaiStruct(_NonClosingNonParsingKaitaiStruct, ABC):
42+
__slots__ = ()
43+
44+
@abstractmethod
45+
def _read(self):
46+
raise NotImplementedError()
47+
48+
49+
class KaitaiStruct(NonClosingKaitaiStruct):
50+
__slots__ = ("_shouldExit",)
51+
def __init__(self, io: typing.Union["KaitaiStream", Path, bytes, str]):
52+
if not isinstance(io, KaitaiStream):
53+
io = KaitaiStream(io)
54+
super.__init__(io)
55+
self._shouldExit = False
3056

3157
def __enter__(self):
58+
self._shouldExit = not self.stream.is_entered
59+
if self._shouldExit:
60+
self._io.__enter__()
3261
return self
3362

3463
def __exit__(self, *args, **kwargs):
35-
self.close()
64+
if self.shouldExit:
65+
self._io.__exit__(*args, **kwargs)
3666

37-
def close(self):
38-
self._io.close()
67+
@classmethod
68+
def from_any(cls, o: typing.Union[Path, str]) -> "KaitaiStruct":
69+
with KaitaiStream(o) as io:
70+
s = cls(io)
71+
s._read()
72+
return s
3973

4074
@classmethod
41-
def from_file(cls, filename):
42-
f = open(filename, 'rb')
43-
try:
44-
return cls(KaitaiStream(f))
45-
except Exception:
46-
# close file descriptor, then reraise the exception
47-
f.close()
48-
raise
75+
def from_file(cls, file: typing.Union[Path, str, BufferedIOBase], use_mmap: bool = True) -> "KaitaiStruct":
76+
return cls.from_any(file, use_mmap=use_mmap)
4977

5078
@classmethod
51-
def from_bytes(cls, buf):
52-
return cls(KaitaiStream(BytesIO(buf)))
79+
def from_bytes(cls, data: bytes) -> "KaitaiStruct":
80+
return cls.from_any(data)
5381

5482
@classmethod
55-
def from_io(cls, io):
56-
return cls(KaitaiStream(io))
83+
def from_io(cls, io: IOBase) -> "KaitaiStruct":
84+
return cls.from_any(io)
85+
86+
87+
class IKaitaiDownStream(ABC):
88+
__slots__ = ("_io",)
89+
90+
def __init__(self, _io: typing.Any):
91+
self._io = _io
92+
93+
@property
94+
@abstractmethod
95+
def is_entered(self):
96+
raise NotImplementedError
97+
98+
@abstractmethod
99+
def __enter__(self):
100+
raise NotImplementedError()
101+
102+
def __exit__(self, *args, **kwargs):
103+
if self.is_entered:
104+
self._io.__exit__(*args, **kwargs)
105+
self._io = None
57106

58107

59-
class KaitaiStream(object):
60-
def __init__(self, io):
108+
class KaitaiIODownStream(IKaitaiDownStream):
109+
__slots__ = ()
110+
111+
def __init__(self, data: typing.Any):
112+
super().__init__(data)
113+
114+
@property
115+
def is_entered(self):
116+
return isinstance(self._io, IOBase)
117+
118+
def __enter__(self):
119+
if not self.is_entered:
120+
self._io = open(self._io).__enter__()
121+
return self
122+
123+
124+
class KaitaiBytesDownStream(KaitaiIODownStream):
125+
__slots__ = ()
126+
127+
def __init__(self, data: bytes):
128+
super().__init__(data)
129+
130+
131+
class KaitaiFileSyscallDownStream(KaitaiIODownStream):
132+
__slots__ = ()
133+
134+
def __init__(self, io: typing.Union[Path, str, IOBase]):
135+
if isinstance(io, str):
136+
io = Path(io)
137+
super().__init__(io)
138+
139+
140+
class KaitaiRawMMapDownStream(KaitaiIODownStream):
141+
__slots__ = ()
142+
143+
def __init__(self, io: typing.Union[mmap.mmap]):
144+
super().__init__(None)
61145
self._io = io
62-
self.align_to_byte()
146+
147+
@property
148+
def is_entered(self):
149+
return isinstance(self._io, mmap.mmap)
150+
151+
def __enter__(self):
152+
return self
153+
154+
def __exit__(self, *args, **kwargs):
155+
super().__exit__(*args, **kwargs)
156+
157+
158+
class KaitaiFileMapDownStream(KaitaiRawMMapDownStream):
159+
__slots__ = ("file",)
160+
161+
def __init__(self, io: typing.Union[Path, str, IOBase]):
162+
super().__init__(None)
163+
self.file = KaitaiFileSyscallDownStream(io)
164+
165+
@property
166+
def is_entered(self):
167+
return isinstance(self._io, mmap.mmap)
63168

64169
def __enter__(self):
170+
self.file = self.file.__enter__()
171+
self._io = mmap.mmap(self.file.file.fileno(), 0, access=mmap.ACCESS_READ).__enter__()
65172
return self
66173

67174
def __exit__(self, *args, **kwargs):
68-
self.close()
175+
super().__exit__(*args, **kwargs)
176+
if self.file is not None:
177+
self.file.__exit__(*args, **kwargs)
178+
self.file = None
179+
180+
181+
def get_file_down_stream(path: Path, *args, use_mmap: bool = True, **kwargs) -> IKaitaiDownStream:
182+
if use_mmap:
183+
cls = KaitaiFileMapDownStream
184+
else:
185+
cls = KaitaiFileSyscallDownStream
186+
187+
return cls(path, *args, **kwargs)
188+
189+
190+
def get_mmap_downstream(mapping: mmap.mmap):
191+
return KaitaiRawMMapDownStream(mapping)
192+
69193

70-
def close(self):
71-
self._io.close()
194+
downstreamMapping = {
195+
bytes: KaitaiBytesDownStream,
196+
BytesIO: KaitaiBytesDownStream,
197+
str: get_file_down_stream,
198+
Path: get_file_down_stream,
199+
BufferedIOBase: get_file_down_stream,
200+
mmap.mmap: get_mmap_downstream,
201+
}
202+
203+
204+
def get_downstream_ctor(t) -> typing.Type[IKaitaiDownStream]:
205+
ctor = downstreamMapping.get(t, None)
206+
if ctor:
207+
return ctor
208+
types = t.mro()
209+
for t1 in types[1:]:
210+
ctor = downstreamMapping.get(t1, None)
211+
if ctor:
212+
downstreamMapping[t] = ctor
213+
return ctor
214+
raise TypeError("Unsupported type", t, types)
215+
216+
217+
def get_downstream(x: typing.Union[bytes, str, Path], *args, **kwargs) -> IKaitaiDownStream:
218+
return get_downstream_ctor(type(x))(x, *args, **kwargs)
219+
220+
221+
class KaitaiStream():
222+
def __init__(self, o: typing.Union[bytes, str, Path, IKaitaiDownStream]):
223+
if not isinstance(o, IKaitaiDownStream):
224+
o = get_downstream(o)
225+
self._downstream = o
226+
self.align_to_byte()
227+
228+
@property
229+
def _io(self):
230+
return self._downstream._io
231+
232+
def __enter__(self):
233+
self._downstream.__enter__()
234+
return self
235+
236+
@property
237+
def is_entered(self):
238+
return self._downstream is not None and self._downstream.is_entered
239+
240+
def __exit__(self, *args, **kwargs):
241+
self._downstream.__exit__(*args, **kwargs)
72242

73243
# ========================================================================
74244
# Stream positioning
@@ -450,6 +620,7 @@ class KaitaiStructError(Exception):
450620
Stores KSY source path, pointing to an element supposedly guilty of
451621
an error.
452622
"""
623+
453624
def __init__(self, msg, src_path):
454625
super(KaitaiStructError, self).__init__("%s: %s" % (src_path, msg))
455626
self.src_path = src_path

0 commit comments

Comments
 (0)