Skip to content

Commit 4e3e89b

Browse files
committed
Add typings
1 parent 7609144 commit 4e3e89b

File tree

2 files changed

+107
-59
lines changed

2 files changed

+107
-59
lines changed

pgproto.pyi

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import codecs
2+
3+
class CodecContext:
4+
def get_text_codec(self) -> codecs.CodecInfo: ...
5+
def is_encoding_utf8(self) -> bool: ...
6+
7+
class ReadBuffer: ...
8+
class WriteBuffer: ...

types.py

+99-59
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,28 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8+
import builtins
9+
import typing
10+
import typing_extensions
11+
12+
813
__all__ = (
914
'BitString', 'Point', 'Path', 'Polygon',
1015
'Box', 'Line', 'LineSegment', 'Circle',
1116
)
1217

18+
_BS = typing.TypeVar('_BS', bound='BitString')
19+
_P = typing.TypeVar('_P', bound='Point')
20+
_BitOrder = typing_extensions.Literal['big', 'little']
21+
1322

1423
class BitString:
1524
"""Immutable representation of PostgreSQL `bit` and `varbit` types."""
1625

1726
__slots__ = '_bytes', '_bitlength'
1827

19-
def __init__(self, bitstring=None):
28+
def __init__(self,
29+
bitstring: typing.Optional[builtins.bytes] = None) -> None:
2030
if not bitstring:
2131
self._bytes = bytes()
2232
self._bitlength = 0
@@ -28,7 +38,7 @@ def __init__(self, bitstring=None):
2838
bit_pos = 0
2939

3040
for i, bit in enumerate(bitstring):
31-
if bit == ' ':
41+
if bit == ' ': # type: ignore
3242
continue
3343
bit = int(bit)
3444
if bit != 0 and bit != 1:
@@ -53,14 +63,15 @@ def __init__(self, bitstring=None):
5363
self._bitlength = bitlen
5464

5565
@classmethod
56-
def frombytes(cls, bytes_=None, bitlength=None):
57-
if bitlength is None and bytes_ is None:
58-
bytes_ = bytes()
59-
bitlength = 0
60-
61-
elif bitlength is None:
62-
bitlength = len(bytes_) * 8
63-
66+
def frombytes(cls: typing.Type[_BS],
67+
bytes_: typing.Optional[builtins.bytes] = None,
68+
bitlength: typing.Optional[int] = None) -> _BS:
69+
if bitlength is None:
70+
if bytes_ is None:
71+
bytes_ = bytes()
72+
bitlength = 0
73+
else:
74+
bitlength = len(bytes_) * 8
6475
else:
6576
if bytes_ is None:
6677
bytes_ = bytes(bitlength // 8 + 1)
@@ -87,10 +98,10 @@ def frombytes(cls, bytes_=None, bitlength=None):
8798
return result
8899

89100
@property
90-
def bytes(self):
101+
def bytes(self) -> builtins.bytes:
91102
return self._bytes
92103

93-
def as_string(self):
104+
def as_string(self) -> str:
94105
s = ''
95106

96107
for i in range(self._bitlength):
@@ -100,7 +111,8 @@ def as_string(self):
100111

101112
return s.strip()
102113

103-
def to_int(self, bitorder='big', *, signed=False):
114+
def to_int(self, bitorder: _BitOrder = 'big',
115+
*, signed: bool = False) -> int:
104116
"""Interpret the BitString as a Python int.
105117
Acts similarly to int.from_bytes.
106118
@@ -135,7 +147,8 @@ def to_int(self, bitorder='big', *, signed=False):
135147
return x
136148

137149
@classmethod
138-
def from_int(cls, x, length, bitorder='big', *, signed=False):
150+
def from_int(cls: typing.Type[_BS], x: int, length: int,
151+
bitorder: _BitOrder = 'big', *, signed: bool = False) -> _BS:
139152
"""Represent the Python int x as a BitString.
140153
Acts similarly to int.to_bytes.
141154
@@ -187,27 +200,27 @@ def from_int(cls, x, length, bitorder='big', *, signed=False):
187200
bytes_ = x.to_bytes((length + 7) // 8, byteorder='big')
188201
return cls.frombytes(bytes_, length)
189202

190-
def __repr__(self):
203+
def __repr__(self) -> str:
191204
return '<BitString {}>'.format(self.as_string())
192205

193206
__str__ = __repr__
194207

195-
def __eq__(self, other):
208+
def __eq__(self, other: object) -> bool:
196209
if not isinstance(other, BitString):
197210
return NotImplemented
198211

199212
return (self._bytes == other._bytes and
200213
self._bitlength == other._bitlength)
201214

202-
def __hash__(self):
215+
def __hash__(self) -> int:
203216
return hash((self._bytes, self._bitlength))
204217

205-
def _getitem(self, i):
218+
def _getitem(self, i: int) -> int:
206219
byte = self._bytes[i // 8]
207220
shift = 8 - i % 8 - 1
208221
return (byte >> shift) & 0x1
209222

210-
def __getitem__(self, i):
223+
def __getitem__(self, i: int) -> int:
211224
if isinstance(i, slice):
212225
raise NotImplementedError('BitString does not support slices')
213226

@@ -216,100 +229,117 @@ def __getitem__(self, i):
216229

217230
return self._getitem(i)
218231

219-
def __len__(self):
232+
def __len__(self) -> int:
220233
return self._bitlength
221234

222235

223-
class Point(tuple):
236+
class Point(typing.Tuple[float, float]):
224237
"""Immutable representation of PostgreSQL `point` type."""
225238

226239
__slots__ = ()
227240

228-
def __new__(cls, x, y):
229-
return super().__new__(cls, (float(x), float(y)))
230-
231-
def __repr__(self):
241+
def __new__(cls,
242+
x: typing.Union[typing.SupportsFloat,
243+
'builtins._SupportsIndex',
244+
typing.Text,
245+
builtins.bytes,
246+
builtins.bytearray],
247+
y: typing.Union[typing.SupportsFloat,
248+
'builtins._SupportsIndex',
249+
typing.Text,
250+
builtins.bytes,
251+
builtins.bytearray]) -> 'Point':
252+
return super().__new__(cls,
253+
typing.cast(typing.Any, (float(x), float(y))))
254+
255+
def __repr__(self) -> str:
232256
return '{}.{}({})'.format(
233257
type(self).__module__,
234258
type(self).__name__,
235259
tuple.__repr__(self)
236260
)
237261

238262
@property
239-
def x(self):
263+
def x(self) -> float:
240264
return self[0]
241265

242266
@property
243-
def y(self):
267+
def y(self) -> float:
244268
return self[1]
245269

246270

247-
class Box(tuple):
271+
class Box(typing.Tuple[Point, Point]):
248272
"""Immutable representation of PostgreSQL `box` type."""
249273

250274
__slots__ = ()
251275

252-
def __new__(cls, high, low):
253-
return super().__new__(cls, (Point(*high), Point(*low)))
276+
def __new__(cls, high: typing.Sequence[float],
277+
low: typing.Sequence[float]) -> 'Box':
278+
return super().__new__(cls,
279+
typing.cast(typing.Any, (Point(*high),
280+
Point(*low))))
254281

255-
def __repr__(self):
282+
def __repr__(self) -> str:
256283
return '{}.{}({})'.format(
257284
type(self).__module__,
258285
type(self).__name__,
259286
tuple.__repr__(self)
260287
)
261288

262289
@property
263-
def high(self):
290+
def high(self) -> Point:
264291
return self[0]
265292

266293
@property
267-
def low(self):
294+
def low(self) -> Point:
268295
return self[1]
269296

270297

271-
class Line(tuple):
298+
class Line(typing.Tuple[float, float, float]):
272299
"""Immutable representation of PostgreSQL `line` type."""
273300

274301
__slots__ = ()
275302

276-
def __new__(cls, A, B, C):
277-
return super().__new__(cls, (A, B, C))
303+
def __new__(cls, A: float, B: float, C: float) -> 'Line':
304+
return super().__new__(cls, typing.cast(typing.Any, (A, B, C)))
278305

279306
@property
280-
def A(self):
307+
def A(self) -> float:
281308
return self[0]
282309

283310
@property
284-
def B(self):
311+
def B(self) -> float:
285312
return self[1]
286313

287314
@property
288-
def C(self):
315+
def C(self) -> float:
289316
return self[2]
290317

291318

292-
class LineSegment(tuple):
319+
class LineSegment(typing.Tuple[Point, Point]):
293320
"""Immutable representation of PostgreSQL `lseg` type."""
294321

295322
__slots__ = ()
296323

297-
def __new__(cls, p1, p2):
298-
return super().__new__(cls, (Point(*p1), Point(*p2)))
324+
def __new__(cls, p1: typing.Sequence[float],
325+
p2: typing.Sequence[float]) -> 'LineSegment':
326+
return super().__new__(cls,
327+
typing.cast(typing.Any, (Point(*p1),
328+
Point(*p2))))
299329

300-
def __repr__(self):
330+
def __repr__(self) -> str:
301331
return '{}.{}({})'.format(
302332
type(self).__module__,
303333
type(self).__name__,
304334
tuple.__repr__(self)
305335
)
306336

307337
@property
308-
def p1(self):
338+
def p1(self) -> Point:
309339
return self[0]
310340

311341
@property
312-
def p2(self):
342+
def p2(self) -> Point:
313343
return self[1]
314344

315345

@@ -318,34 +348,44 @@ class Path:
318348

319349
__slots__ = '_is_closed', 'points'
320350

321-
def __init__(self, *points, is_closed=False):
351+
def __init__(self, *points: typing.Sequence[float],
352+
is_closed: bool = False) -> None:
322353
self.points = tuple(Point(*p) for p in points)
323354
self._is_closed = is_closed
324355

325356
@property
326-
def is_closed(self):
357+
def is_closed(self) -> bool:
327358
return self._is_closed
328359

329-
def __eq__(self, other):
360+
def __eq__(self, other: object) -> bool:
330361
if not isinstance(other, Path):
331362
return NotImplemented
332363

333364
return (self.points == other.points and
334365
self._is_closed == other._is_closed)
335366

336-
def __hash__(self):
367+
def __hash__(self) -> int:
337368
return hash((self.points, self.is_closed))
338369

339-
def __iter__(self):
370+
def __iter__(self) -> typing.Iterator[Point]:
340371
return iter(self.points)
341372

342-
def __len__(self):
373+
def __len__(self) -> int:
343374
return len(self.points)
344375

345-
def __getitem__(self, i):
376+
@typing.overload
377+
def __getitem__(self, i: int) -> Point:
378+
...
379+
380+
@typing.overload
381+
def __getitem__(self, i: slice) -> typing.Tuple[Point, ...]:
382+
...
383+
384+
def __getitem__(self, i: typing.Union[int, slice]) \
385+
-> typing.Union[Point, typing.Tuple[Point, ...]]:
346386
return self.points[i]
347387

348-
def __contains__(self, point):
388+
def __contains__(self, point: object) -> bool:
349389
return point in self.points
350390

351391

@@ -354,23 +394,23 @@ class Polygon(Path):
354394

355395
__slots__ = ()
356396

357-
def __init__(self, *points):
397+
def __init__(self, *points: typing.Sequence[float]) -> None:
358398
# polygon is always closed
359399
super().__init__(*points, is_closed=True)
360400

361401

362-
class Circle(tuple):
402+
class Circle(typing.Tuple[Point, float]):
363403
"""Immutable representation of PostgreSQL `circle` type."""
364404

365405
__slots__ = ()
366406

367-
def __new__(cls, center, radius):
368-
return super().__new__(cls, (center, radius))
407+
def __new__(cls, center: Point, radius: float) -> 'Circle':
408+
return super().__new__(cls, typing.cast(typing.Any, (center, radius)))
369409

370410
@property
371-
def center(self):
411+
def center(self) -> Point:
372412
return self[0]
373413

374414
@property
375-
def radius(self):
415+
def radius(self) -> float:
376416
return self[1]

0 commit comments

Comments
 (0)