Skip to content

Commit 6b30257

Browse files
committed
Complete overhaul of interaction between KaitaiStream and KaitaiStruct
1 parent 4b97420 commit 6b30257

File tree

1 file changed

+166
-25
lines changed

1 file changed

+166
-25
lines changed

kaitaistruct.py

+166-25
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
import typing
12
import itertools
23
import sys
34
import struct
45
from io import open, BytesIO, SEEK_CUR, SEEK_END # noqa
6+
from io import IOBase
7+
import mmap
8+
from pathlib import Path
9+
from abc import ABC, abstractmethod
510

611
PY2 = sys.version_info[0] == 2
712

@@ -15,51 +20,187 @@
1520
__version__ = '0.9'
1621

1722

18-
class KaitaiStruct(object):
19-
def __init__(self, stream):
20-
self._io = stream
23+
class _NonClosingNonParsingKaitaiStruct:
24+
__slots__ = ("_io", "_parent", "_root")
25+
def __init__(self, _io:"KaitaiStream", _parent:typing.Optional["_NonClosingNonParsingKaitaiStruct"]=None, _root:typing.Optional["_NonClosingNonParsingKaitaiStruct"]=None):
26+
self._io = _io
27+
self._parent = _parent
28+
self._root = _root if _root else self
29+
30+
31+
class NonClosingKaitaiStruct(_NonClosingNonParsingKaitaiStruct, ABC):
32+
__slots__ = ()
33+
34+
@abstractmethod
35+
def _read(self):
36+
raise NotImplementedError()
37+
38+
39+
class KaitaiStruct(NonClosingKaitaiStruct):
40+
__slots__ = ("_shouldExit",)
41+
def __init__(self, io: typing.Union["KaitaiStream", Path, bytes, str]):
42+
if not isinstance(io, KaitaiStream):
43+
io = KaitaiStream(io)
44+
super.__init__(io)
45+
self._shouldExit = False
2146

2247
def __enter__(self):
48+
self._shouldExit = not stream.is_entered
49+
if self._shouldExit:
50+
self._io.__enter__()
2351
return self
2452

2553
def __exit__(self, *args, **kwargs):
26-
self.close()
54+
if self.shouldExit:
55+
self._io.__exit__(*args, **kwargs)
2756

28-
def close(self):
29-
self._io.close()
57+
@classmethod
58+
def from_any(cls, o: typing.Union[Path, str]) -> "KaitaiStruct":
59+
with KaitaiStream(o) as io:
60+
s = cls(io)
61+
s._read()
62+
return s
3063

3164
@classmethod
32-
def from_file(cls, filename):
33-
f = open(filename, 'rb')
34-
try:
35-
return cls(KaitaiStream(f))
36-
except Exception:
37-
# close file descriptor, then reraise the exception
38-
f.close()
39-
raise
65+
def from_file(cls, file: typing.Union[Path, str, _io._BufferedIOBase], use_mmap:bool=True) -> "KaitaiStruct":
66+
return cls.from_any(file, use_mmap=use_mmap)
4067

4168
@classmethod
42-
def from_bytes(cls, buf):
43-
return cls(KaitaiStream(BytesIO(buf)))
69+
def from_bytes(cls, data: bytes) -> "KaitaiStruct":
70+
return cls.from_any(data)
4471

4572
@classmethod
46-
def from_io(cls, io):
47-
return cls(KaitaiStream(io))
73+
def from_io(cls, io: IOBase) -> "KaitaiStruct":
74+
return cls.from_any(io)
4875

4976

50-
class KaitaiStream(object):
51-
def __init__(self, io):
52-
self._io = io
53-
self.align_to_byte()
77+
class IKaitaiDownStream(ABC):
78+
__slots__ =("_io",)
79+
80+
def __init__(self, _io: typing.Any):
81+
self._io = _io
82+
83+
@abstractmethod
84+
@property
85+
def is_entered(self):
86+
raise NotImplementedError
87+
88+
@abstractmethod
89+
def __enter__(self):
90+
raise NotImplementedError()
91+
92+
def __exit__(self, *args, **kwargs):
93+
if self.is_entered:
94+
self._io.__exit__(*args, **kwargs)
95+
self._io = None
96+
97+
98+
class KaitaiIODownStream(IKaitaiDownStream):
99+
__slots__ = ()
100+
def __init__(self, data: typing.Any):
101+
super().__init__(data)
102+
103+
@property
104+
def is_entered(self):
105+
return isinstance(self._io, IOBase)
54106

55107
def __enter__(self):
108+
if not self.is_entered
109+
self._io = open(self._io).__enter__()
110+
return self
111+
112+
113+
class KaitaiBytesDownStream(KaitaiIODownStream):
114+
__slots__ = ()
115+
def __init__(self, data: bytes):
116+
super().__init__(data)
117+
118+
119+
class KaitaiFileSyscallDownStream(KaitaiIODownStream):
120+
__slots__ = ()
121+
def __init__(self, io: typing.Union[Path, str, IOBase]):
122+
if isinstance(io, str):
123+
io = Path(io)
124+
super().__init__(io)
125+
126+
127+
class KaitaiFileMapDownStream(IKaitaiFileDownStream):
128+
__slots__ = ("file",)
129+
def __init__(self, io: typing.Union[Path, str, IOBase]):
130+
super().__init__(None)
131+
self.file = KaitaiFileSyscallDownStream(io)
132+
133+
@property
134+
def is_entered(self):
135+
return isinstance(self._io, mmap.mmap)
136+
137+
def __enter__(self):
138+
self.file = self.file.__enter__()
139+
self._io = mmap.mmap(self.file.file.fileno(), 0, access=mmap.ACCESS_READ).__enter__()
56140
return self
57141

58142
def __exit__(self, *args, **kwargs):
59-
self.close()
143+
super().__exit__(*args, **kwargs)
144+
if self.file is not None:
145+
self.file.__exit__(*args, **kwargs)
146+
self.file = None
147+
148+
149+
def get_file_down_stream(path: Path, *args, use_mmap: bool=True, **kwargs) -> IKaitaiDownStream:
150+
if use_mmap:
151+
cls = KaitaiFileMapDownStream
152+
else:
153+
cls = KaitaiFileSyscallDownStream
154+
return cls(path, *args, **kwargs)
155+
156+
157+
downstreamMapping = {
158+
bytes: KaitaiBytesDownStream,
159+
BytesIO: KaitaiBytesDownStream,
160+
str: get_file_down_stream,
161+
Path: get_file_down_stream,
162+
_io._BufferedIOBase: get_file_down_stream,
163+
}
164+
165+
166+
def get_downstream_ctor(x) -> typing.Type[IKaitaiDownStream]:
167+
ctor = downstreamMapping.get(t, None)
168+
if ctor:
169+
return ctor
170+
types = t.mro()
171+
for t1 in types[1:]:
172+
ctor = downstreamMapping.get(t1, None)
173+
if ctor:
174+
downstreamMapping[t] = ctor
175+
return ctor
176+
raise TypeError("Unsupported type", t, types)
177+
178+
179+
def get_downstream(x: typing.Union[bytes, str, Path], *args, **kwargs) -> IKaitaiDownStream:
180+
return get_downstream_ctor(type(x))(x, *args, **kwargs)
181+
182+
183+
class KaitaiStream():
184+
def __init__(self, o: typing.Union[bytes, str, Path, IKaitaiDownStream]):
185+
if not isinstance(o, IKaitaiDownStream):
186+
o = get_downstream(o)
187+
self._downstream = o
188+
self.align_to_byte()
189+
190+
@property
191+
def _io(self):
192+
return self._downstream._io
193+
194+
def __enter__(self):
195+
self._downstream.__enter__()
196+
return self
60197

61-
def close(self):
62-
self._io.close()
198+
@property
199+
def is_entered(self):
200+
return self._downstream is not None and self._downstream.is_entered
201+
202+
def __exit__(self, *args, **kwargs):
203+
self._downstream.__exit__(*args, **kwargs)
63204

64205
# ========================================================================
65206
# Stream positioning

0 commit comments

Comments
 (0)