Skip to content

Commit a584140

Browse files
committed
Complete overhaul of interaction between KaitaiStream and KaitaiStruct
1 parent c7f2b89 commit a584140

File tree

1 file changed

+167
-25
lines changed

1 file changed

+167
-25
lines changed

kaitaistruct.py

+167-25
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
import typing
12
import itertools
23
import sys
34
import struct
45
from io import open, BytesIO, SEEK_CUR, SEEK_END # noqa
6+
from io import IOBase, BufferedIOBase
7+
import mmap
8+
from pathlib import Path
9+
from abc import ABC, abstractmethod
510

611
PY2 = sys.version_info[0] == 2
712

@@ -15,51 +20,188 @@
1520
__version__ = '0.9'
1621

1722

18-
class KaitaiStruct(object):
19-
def __init__(self, stream):
20-
self._io = stream
23+
class _NonClosingNonParsingKaitaiStruct:
24+
__slots__ = ("_io", "_parent", "_root")
25+
26+
def __init__(self, _io:"KaitaiStream", _parent:typing.Optional["_NonClosingNonParsingKaitaiStruct"]=None, _root:typing.Optional["_NonClosingNonParsingKaitaiStruct"]=None):
27+
self._io = _io
28+
self._parent = _parent
29+
self._root = _root if _root else self
30+
31+
32+
class NonClosingKaitaiStruct(_NonClosingNonParsingKaitaiStruct, ABC):
33+
__slots__ = ()
34+
35+
@abstractmethod
36+
def _read(self):
37+
raise NotImplementedError()
38+
39+
40+
class KaitaiStruct(NonClosingKaitaiStruct):
41+
__slots__ = ("_shouldExit",)
42+
def __init__(self, io: typing.Union["KaitaiStream", Path, bytes, str]):
43+
if not isinstance(io, KaitaiStream):
44+
io = KaitaiStream(io)
45+
super.__init__(io)
46+
self._shouldExit = False
2147

2248
def __enter__(self):
49+
self._shouldExit = not stream.is_entered
50+
if self._shouldExit:
51+
self._io.__enter__()
2352
return self
2453

2554
def __exit__(self, *args, **kwargs):
26-
self.close()
55+
if self.shouldExit:
56+
self._io.__exit__(*args, **kwargs)
2757

28-
def close(self):
29-
self._io.close()
58+
@classmethod
59+
def from_any(cls, o: typing.Union[Path, str]) -> "KaitaiStruct":
60+
with KaitaiStream(o) as io:
61+
s = cls(io)
62+
s._read()
63+
return s
3064

3165
@classmethod
32-
def from_file(cls, filename):
33-
f = open(filename, 'rb')
34-
try:
35-
return cls(KaitaiStream(f))
36-
except Exception:
37-
# close file descriptor, then reraise the exception
38-
f.close()
39-
raise
66+
def from_file(cls, file: typing.Union[Path, str, BufferedIOBase], use_mmap:bool=True) -> "KaitaiStruct":
67+
return cls.from_any(file, use_mmap=use_mmap)
4068

4169
@classmethod
42-
def from_bytes(cls, buf):
43-
return cls(KaitaiStream(BytesIO(buf)))
70+
def from_bytes(cls, data: bytes) -> "KaitaiStruct":
71+
return cls.from_any(data)
4472

4573
@classmethod
46-
def from_io(cls, io):
47-
return cls(KaitaiStream(io))
74+
def from_io(cls, io: IOBase) -> "KaitaiStruct":
75+
return cls.from_any(io)
4876

4977

50-
class KaitaiStream(object):
51-
def __init__(self, io):
52-
self._io = io
53-
self.align_to_byte()
78+
class IKaitaiDownStream(ABC):
79+
__slots__ =("_io",)
80+
81+
def __init__(self, _io: typing.Any):
82+
self._io = _io
83+
84+
@property
85+
@abstractmethod
86+
def is_entered(self):
87+
raise NotImplementedError
88+
89+
@abstractmethod
90+
def __enter__(self):
91+
raise NotImplementedError()
92+
93+
def __exit__(self, *args, **kwargs):
94+
if self.is_entered:
95+
self._io.__exit__(*args, **kwargs)
96+
self._io = None
97+
98+
99+
class KaitaiIODownStream(IKaitaiDownStream):
100+
__slots__ = ()
101+
def __init__(self, data: typing.Any):
102+
super().__init__(data)
103+
104+
@property
105+
def is_entered(self):
106+
return isinstance(self._io, IOBase)
54107

55108
def __enter__(self):
109+
if not self.is_entered:
110+
self._io = open(self._io).__enter__()
111+
return self
112+
113+
114+
class KaitaiBytesDownStream(KaitaiIODownStream):
115+
__slots__ = ()
116+
def __init__(self, data: bytes):
117+
super().__init__(data)
118+
119+
120+
class KaitaiFileSyscallDownStream(KaitaiIODownStream):
121+
__slots__ = ()
122+
def __init__(self, io: typing.Union[Path, str, IOBase]):
123+
if isinstance(io, str):
124+
io = Path(io)
125+
super().__init__(io)
126+
127+
128+
class KaitaiFileMapDownStream(KaitaiIODownStream):
129+
__slots__ = ("file",)
130+
def __init__(self, io: typing.Union[Path, str, IOBase]):
131+
super().__init__(None)
132+
self.file = KaitaiFileSyscallDownStream(io)
133+
134+
@property
135+
def is_entered(self):
136+
return isinstance(self._io, mmap.mmap)
137+
138+
def __enter__(self):
139+
self.file = self.file.__enter__()
140+
self._io = mmap.mmap(self.file.file.fileno(), 0, access=mmap.ACCESS_READ).__enter__()
56141
return self
57142

58143
def __exit__(self, *args, **kwargs):
59-
self.close()
144+
super().__exit__(*args, **kwargs)
145+
if self.file is not None:
146+
self.file.__exit__(*args, **kwargs)
147+
self.file = None
148+
149+
150+
def get_file_down_stream(path: Path, *args, use_mmap: bool=True, **kwargs) -> IKaitaiDownStream:
151+
if use_mmap:
152+
cls = KaitaiFileMapDownStream
153+
else:
154+
cls = KaitaiFileSyscallDownStream
155+
return cls(path, *args, **kwargs)
156+
157+
158+
downstreamMapping = {
159+
bytes: KaitaiBytesDownStream,
160+
BytesIO: KaitaiBytesDownStream,
161+
str: get_file_down_stream,
162+
Path: get_file_down_stream,
163+
BufferedIOBase: get_file_down_stream,
164+
}
165+
166+
167+
def get_downstream_ctor(x) -> typing.Type[IKaitaiDownStream]:
168+
ctor = downstreamMapping.get(t, None)
169+
if ctor:
170+
return ctor
171+
types = t.mro()
172+
for t1 in types[1:]:
173+
ctor = downstreamMapping.get(t1, None)
174+
if ctor:
175+
downstreamMapping[t] = ctor
176+
return ctor
177+
raise TypeError("Unsupported type", t, types)
178+
179+
180+
def get_downstream(x: typing.Union[bytes, str, Path], *args, **kwargs) -> IKaitaiDownStream:
181+
return get_downstream_ctor(type(x))(x, *args, **kwargs)
182+
183+
184+
class KaitaiStream():
185+
def __init__(self, o: typing.Union[bytes, str, Path, IKaitaiDownStream]):
186+
if not isinstance(o, IKaitaiDownStream):
187+
o = get_downstream(o)
188+
self._downstream = o
189+
self.align_to_byte()
190+
191+
@property
192+
def _io(self):
193+
return self._downstream._io
194+
195+
def __enter__(self):
196+
self._downstream.__enter__()
197+
return self
60198

61-
def close(self):
62-
self._io.close()
199+
@property
200+
def is_entered(self):
201+
return self._downstream is not None and self._downstream.is_entered
202+
203+
def __exit__(self, *args, **kwargs):
204+
self._downstream.__exit__(*args, **kwargs)
63205

64206
# ========================================================================
65207
# Stream positioning

0 commit comments

Comments
 (0)