Skip to content

Commit 1830da3

Browse files
committed
Complete overhaul of interaction between KaitaiStream and KaitaiStruct
1 parent 924f325 commit 1830da3

File tree

1 file changed

+195
-24
lines changed

1 file changed

+195
-24
lines changed

Diff for: kaitaistruct.py

+195-24
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
import typing
12
import itertools
23
import sys
34
import struct
45
from io import open, BytesIO, SEEK_CUR, SEEK_END # noqa
6+
from io import IOBase, BufferedIOBase
7+
import mmap
8+
from pathlib import Path
9+
from abc import ABC, abstractmethod
510

611
PY2 = sys.version_info[0] == 2
712

@@ -15,51 +20,216 @@
1520
__version__ = '0.9'
1621

1722

18-
class KaitaiStruct(object):
19-
def __init__(self, stream):
20-
self._io = stream
23+
class _NonClosingNonParsingKaitaiStruct:
24+
__slots__ = ("_io", "_parent", "_root")
25+
26+
def __init__(self, _io: "KaitaiStream", _parent: typing.Optional["_NonClosingNonParsingKaitaiStruct"] = None, _root: typing.Optional["_NonClosingNonParsingKaitaiStruct"] = None):
27+
self._io = _io
28+
self._parent = _parent
29+
self._root = _root if _root else self
30+
31+
32+
class NonClosingKaitaiStruct(_NonClosingNonParsingKaitaiStruct, ABC):
33+
__slots__ = ()
34+
35+
@abstractmethod
36+
def _read(self):
37+
raise NotImplementedError()
38+
39+
40+
class KaitaiStruct(NonClosingKaitaiStruct):
41+
__slots__ = ("_shouldExit",)
42+
def __init__(self, io: typing.Union["KaitaiStream", Path, bytes, str]):
43+
if not isinstance(io, KaitaiStream):
44+
io = KaitaiStream(io)
45+
super.__init__(io)
46+
self._shouldExit = False
2147

2248
def __enter__(self):
49+
self._shouldExit = not self.stream.is_entered
50+
if self._shouldExit:
51+
self._io.__enter__()
2352
return self
2453

2554
def __exit__(self, *args, **kwargs):
26-
self.close()
55+
if self.shouldExit:
56+
self._io.__exit__(*args, **kwargs)
2757

28-
def close(self):
29-
self._io.close()
58+
@classmethod
59+
def from_any(cls, o: typing.Union[Path, str]) -> "KaitaiStruct":
60+
with KaitaiStream(o) as io:
61+
s = cls(io)
62+
s._read()
63+
return s
3064

3165
@classmethod
32-
def from_file(cls, filename):
33-
f = open(filename, 'rb')
34-
try:
35-
return cls(KaitaiStream(f))
36-
except Exception:
37-
# close file descriptor, then reraise the exception
38-
f.close()
39-
raise
66+
def from_file(cls, file: typing.Union[Path, str, BufferedIOBase], use_mmap: bool = True) -> "KaitaiStruct":
67+
return cls.from_any(file, use_mmap=use_mmap)
4068

4169
@classmethod
42-
def from_bytes(cls, buf):
43-
return cls(KaitaiStream(BytesIO(buf)))
70+
def from_bytes(cls, data: bytes) -> "KaitaiStruct":
71+
return cls.from_any(data)
4472

4573
@classmethod
46-
def from_io(cls, io):
47-
return cls(KaitaiStream(io))
74+
def from_io(cls, io: IOBase) -> "KaitaiStruct":
75+
return cls.from_any(io)
76+
77+
78+
class IKaitaiDownStream(ABC):
79+
__slots__ = ("_io",)
80+
81+
def __init__(self, _io: typing.Any):
82+
self._io = _io
83+
84+
@property
85+
@abstractmethod
86+
def is_entered(self):
87+
raise NotImplementedError
88+
89+
@abstractmethod
90+
def __enter__(self):
91+
raise NotImplementedError()
92+
93+
def __exit__(self, *args, **kwargs):
94+
if self.is_entered:
95+
self._io.__exit__(*args, **kwargs)
96+
self._io = None
4897

4998

50-
class KaitaiStream(object):
51-
def __init__(self, io):
99+
class KaitaiIODownStream(IKaitaiDownStream):
100+
__slots__ = ()
101+
102+
def __init__(self, data: typing.Any):
103+
super().__init__(data)
104+
105+
@property
106+
def is_entered(self):
107+
return isinstance(self._io, IOBase)
108+
109+
def __enter__(self):
110+
if not self.is_entered:
111+
self._io = open(self._io).__enter__()
112+
return self
113+
114+
115+
class KaitaiBytesDownStream(KaitaiIODownStream):
116+
__slots__ = ()
117+
118+
def __init__(self, data: bytes):
119+
super().__init__(data)
120+
121+
122+
class KaitaiFileSyscallDownStream(KaitaiIODownStream):
123+
__slots__ = ()
124+
125+
def __init__(self, io: typing.Union[Path, str, IOBase]):
126+
if isinstance(io, str):
127+
io = Path(io)
128+
super().__init__(io)
129+
130+
131+
class KaitaiRawMMapDownStream(KaitaiIODownStream):
132+
__slots__ = ()
133+
134+
def __init__(self, io: typing.Union[mmap.mmap]):
135+
super().__init__(None)
52136
self._io = io
53-
self.align_to_byte()
137+
138+
@property
139+
def is_entered(self):
140+
return isinstance(self._io, mmap.mmap)
141+
142+
def __enter__(self):
143+
return self
144+
145+
def __exit__(self, *args, **kwargs):
146+
super().__exit__(*args, **kwargs)
147+
148+
149+
class KaitaiFileMapDownStream(KaitaiRawMMapDownStream):
150+
__slots__ = ("file",)
151+
152+
def __init__(self, io: typing.Union[Path, str, IOBase]):
153+
super().__init__(None)
154+
self.file = KaitaiFileSyscallDownStream(io)
155+
156+
@property
157+
def is_entered(self):
158+
return isinstance(self._io, mmap.mmap)
54159

55160
def __enter__(self):
161+
self.file = self.file.__enter__()
162+
self._io = mmap.mmap(self.file.file.fileno(), 0, access=mmap.ACCESS_READ).__enter__()
56163
return self
57164

58165
def __exit__(self, *args, **kwargs):
59-
self.close()
166+
super().__exit__(*args, **kwargs)
167+
if self.file is not None:
168+
self.file.__exit__(*args, **kwargs)
169+
self.file = None
170+
171+
172+
def get_file_down_stream(path: Path, *args, use_mmap: bool = True, **kwargs) -> IKaitaiDownStream:
173+
if use_mmap:
174+
cls = KaitaiFileMapDownStream
175+
else:
176+
cls = KaitaiFileSyscallDownStream
177+
178+
return cls(path, *args, **kwargs)
179+
180+
181+
def get_mmap_downstream(mapping: mmap.mmap):
182+
return KaitaiRawMMapDownStream(mapping)
183+
60184

61-
def close(self):
62-
self._io.close()
185+
downstreamMapping = {
186+
bytes: KaitaiBytesDownStream,
187+
BytesIO: KaitaiBytesDownStream,
188+
str: get_file_down_stream,
189+
Path: get_file_down_stream,
190+
BufferedIOBase: get_file_down_stream,
191+
mmap.mmap: get_mmap_downstream,
192+
}
193+
194+
195+
def get_downstream_ctor(t) -> typing.Type[IKaitaiDownStream]:
196+
ctor = downstreamMapping.get(t, None)
197+
if ctor:
198+
return ctor
199+
types = t.mro()
200+
for t1 in types[1:]:
201+
ctor = downstreamMapping.get(t1, None)
202+
if ctor:
203+
downstreamMapping[t] = ctor
204+
return ctor
205+
raise TypeError("Unsupported type", t, types)
206+
207+
208+
def get_downstream(x: typing.Union[bytes, str, Path], *args, **kwargs) -> IKaitaiDownStream:
209+
return get_downstream_ctor(type(x))(x, *args, **kwargs)
210+
211+
212+
class KaitaiStream():
213+
def __init__(self, o: typing.Union[bytes, str, Path, IKaitaiDownStream]):
214+
if not isinstance(o, IKaitaiDownStream):
215+
o = get_downstream(o)
216+
self._downstream = o
217+
self.align_to_byte()
218+
219+
@property
220+
def _io(self):
221+
return self._downstream._io
222+
223+
def __enter__(self):
224+
self._downstream.__enter__()
225+
return self
226+
227+
@property
228+
def is_entered(self):
229+
return self._downstream is not None and self._downstream.is_entered
230+
231+
def __exit__(self, *args, **kwargs):
232+
self._downstream.__exit__(*args, **kwargs)
63233

64234
# ========================================================================
65235
# Stream positioning
@@ -419,6 +589,7 @@ class KaitaiStructError(Exception):
419589
Stores KSY source path, pointing to an element supposedly guilty of
420590
an error.
421591
"""
592+
422593
def __init__(self, msg, src_path):
423594
super(KaitaiStructError, self).__init__("%s: %s" % (src_path, msg))
424595
self.src_path = src_path

0 commit comments

Comments
 (0)