Skip to content

Commit c6cdeac

Browse files
committed
Add typings
1 parent 7609144 commit c6cdeac

File tree

2 files changed

+125
-59
lines changed

2 files changed

+125
-59
lines changed

pgproto.pyi

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import codecs
2+
3+
class CodecContext:
4+
def get_text_codec(self) -> codecs.CodecInfo: ...
5+
def is_encoding_utf8(self) -> bool: ...
6+
7+
class ReadBuffer: ...
8+
class WriteBuffer: ...

types.py

+117-59
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,29 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8+
import builtins
9+
import sys
10+
import typing
11+
import typing_extensions
12+
13+
814
__all__ = (
915
'BitString', 'Point', 'Path', 'Polygon',
1016
'Box', 'Line', 'LineSegment', 'Circle',
1117
)
1218

19+
_BS = typing.TypeVar('_BS', bound='BitString')
20+
_P = typing.TypeVar('_P', bound='Point')
21+
_BitOrder = typing_extensions.Literal['big', 'little']
22+
1323

1424
class BitString:
1525
"""Immutable representation of PostgreSQL `bit` and `varbit` types."""
1626

1727
__slots__ = '_bytes', '_bitlength'
1828

19-
def __init__(self, bitstring=None):
29+
def __init__(self,
30+
bitstring: typing.Optional[builtins.bytes] = None) -> None:
2031
if not bitstring:
2132
self._bytes = bytes()
2233
self._bitlength = 0
@@ -28,7 +39,7 @@ def __init__(self, bitstring=None):
2839
bit_pos = 0
2940

3041
for i, bit in enumerate(bitstring):
31-
if bit == ' ':
42+
if bit == ' ': # type: ignore
3243
continue
3344
bit = int(bit)
3445
if bit != 0 and bit != 1:
@@ -53,14 +64,15 @@ def __init__(self, bitstring=None):
5364
self._bitlength = bitlen
5465

5566
@classmethod
56-
def frombytes(cls, bytes_=None, bitlength=None):
57-
if bitlength is None and bytes_ is None:
58-
bytes_ = bytes()
59-
bitlength = 0
60-
61-
elif bitlength is None:
62-
bitlength = len(bytes_) * 8
63-
67+
def frombytes(cls: typing.Type[_BS],
68+
bytes_: typing.Optional[builtins.bytes] = None,
69+
bitlength: typing.Optional[int] = None) -> _BS:
70+
if bitlength is None:
71+
if bytes_ is None:
72+
bytes_ = bytes()
73+
bitlength = 0
74+
else:
75+
bitlength = len(bytes_) * 8
6476
else:
6577
if bytes_ is None:
6678
bytes_ = bytes(bitlength // 8 + 1)
@@ -87,10 +99,10 @@ def frombytes(cls, bytes_=None, bitlength=None):
8799
return result
88100

89101
@property
90-
def bytes(self):
102+
def bytes(self) -> builtins.bytes:
91103
return self._bytes
92104

93-
def as_string(self):
105+
def as_string(self) -> str:
94106
s = ''
95107

96108
for i in range(self._bitlength):
@@ -100,7 +112,8 @@ def as_string(self):
100112

101113
return s.strip()
102114

103-
def to_int(self, bitorder='big', *, signed=False):
115+
def to_int(self, bitorder: _BitOrder = 'big',
116+
*, signed: bool = False) -> int:
104117
"""Interpret the BitString as a Python int.
105118
Acts similarly to int.from_bytes.
106119
@@ -135,7 +148,8 @@ def to_int(self, bitorder='big', *, signed=False):
135148
return x
136149

137150
@classmethod
138-
def from_int(cls, x, length, bitorder='big', *, signed=False):
151+
def from_int(cls: typing.Type[_BS], x: int, length: int,
152+
bitorder: _BitOrder = 'big', *, signed: bool = False) -> _BS:
139153
"""Represent the Python int x as a BitString.
140154
Acts similarly to int.to_bytes.
141155
@@ -187,27 +201,27 @@ def from_int(cls, x, length, bitorder='big', *, signed=False):
187201
bytes_ = x.to_bytes((length + 7) // 8, byteorder='big')
188202
return cls.frombytes(bytes_, length)
189203

190-
def __repr__(self):
204+
def __repr__(self) -> str:
191205
return '<BitString {}>'.format(self.as_string())
192206

193207
__str__ = __repr__
194208

195-
def __eq__(self, other):
209+
def __eq__(self, other: object) -> bool:
196210
if not isinstance(other, BitString):
197211
return NotImplemented
198212

199213
return (self._bytes == other._bytes and
200214
self._bitlength == other._bitlength)
201215

202-
def __hash__(self):
216+
def __hash__(self) -> int:
203217
return hash((self._bytes, self._bitlength))
204218

205-
def _getitem(self, i):
219+
def _getitem(self, i: int) -> int:
206220
byte = self._bytes[i // 8]
207221
shift = 8 - i % 8 - 1
208222
return (byte >> shift) & 0x1
209223

210-
def __getitem__(self, i):
224+
def __getitem__(self, i: int) -> int:
211225
if isinstance(i, slice):
212226
raise NotImplementedError('BitString does not support slices')
213227

@@ -216,100 +230,134 @@ def __getitem__(self, i):
216230

217231
return self._getitem(i)
218232

219-
def __len__(self):
233+
def __len__(self) -> int:
220234
return self._bitlength
221235

222236

223-
class Point(tuple):
237+
if typing.TYPE_CHECKING or sys.version_info >= (3, 6):
238+
_PointBase = typing.Tuple[float, float]
239+
_BoxBase = typing.Tuple['Point', 'Point']
240+
_LineBase = typing.Tuple[float, float, float]
241+
_LineSegmentBase = typing.Tuple['Point', 'Point']
242+
_CircleBase = typing.Tuple['Point', float]
243+
else:
244+
# In Python 3.5, subclassing from typing.Tuple does not make the
245+
# subclass act like a tuple in certain situations (like starred
246+
# expressions)
247+
_PointBase = tuple
248+
_BoxBase = tuple
249+
_LineBase = tuple
250+
_LineSegmentBase = tuple
251+
_CircleBase = tuple
252+
253+
254+
class Point(_PointBase):
224255
"""Immutable representation of PostgreSQL `point` type."""
225256

226257
__slots__ = ()
227258

228-
def __new__(cls, x, y):
229-
return super().__new__(cls, (float(x), float(y)))
230-
231-
def __repr__(self):
259+
def __new__(cls,
260+
x: typing.Union[typing.SupportsFloat,
261+
'builtins._SupportsIndex',
262+
typing.Text,
263+
builtins.bytes,
264+
builtins.bytearray],
265+
y: typing.Union[typing.SupportsFloat,
266+
'builtins._SupportsIndex',
267+
typing.Text,
268+
builtins.bytes,
269+
builtins.bytearray]) -> 'Point':
270+
return super().__new__(cls,
271+
typing.cast(typing.Any, (float(x), float(y))))
272+
273+
def __repr__(self) -> str:
232274
return '{}.{}({})'.format(
233275
type(self).__module__,
234276
type(self).__name__,
235277
tuple.__repr__(self)
236278
)
237279

238280
@property
239-
def x(self):
281+
def x(self) -> float:
240282
return self[0]
241283

242284
@property
243-
def y(self):
285+
def y(self) -> float:
244286
return self[1]
245287

246288

247-
class Box(tuple):
289+
class Box(_BoxBase):
248290
"""Immutable representation of PostgreSQL `box` type."""
249291

250292
__slots__ = ()
251293

252-
def __new__(cls, high, low):
253-
return super().__new__(cls, (Point(*high), Point(*low)))
294+
def __new__(cls, high: typing.Sequence[float],
295+
low: typing.Sequence[float]) -> 'Box':
296+
return super().__new__(cls,
297+
typing.cast(typing.Any, (Point(*high),
298+
Point(*low))))
254299

255-
def __repr__(self):
300+
def __repr__(self) -> str:
256301
return '{}.{}({})'.format(
257302
type(self).__module__,
258303
type(self).__name__,
259304
tuple.__repr__(self)
260305
)
261306

262307
@property
263-
def high(self):
308+
def high(self) -> Point:
264309
return self[0]
265310

266311
@property
267-
def low(self):
312+
def low(self) -> Point:
268313
return self[1]
269314

270315

271-
class Line(tuple):
316+
class Line(_LineBase):
272317
"""Immutable representation of PostgreSQL `line` type."""
273318

274319
__slots__ = ()
275320

276-
def __new__(cls, A, B, C):
277-
return super().__new__(cls, (A, B, C))
321+
def __new__(cls, A: float, B: float, C: float) -> 'Line':
322+
return super().__new__(cls, typing.cast(typing.Any, (A, B, C)))
278323

279324
@property
280-
def A(self):
325+
def A(self) -> float:
281326
return self[0]
282327

283328
@property
284-
def B(self):
329+
def B(self) -> float:
285330
return self[1]
286331

287332
@property
288-
def C(self):
333+
def C(self) -> float:
289334
return self[2]
290335

291336

292-
class LineSegment(tuple):
337+
class LineSegment(_LineSegmentBase):
293338
"""Immutable representation of PostgreSQL `lseg` type."""
294339

295340
__slots__ = ()
296341

297-
def __new__(cls, p1, p2):
298-
return super().__new__(cls, (Point(*p1), Point(*p2)))
342+
def __new__(cls, p1: typing.Sequence[float],
343+
p2: typing.Sequence[float]) -> 'LineSegment':
344+
return super().__new__(cls,
345+
typing.cast(typing.Any, (Point(*p1),
346+
Point(*p2))))
299347

300-
def __repr__(self):
348+
def __repr__(self) -> str:
301349
return '{}.{}({})'.format(
302350
type(self).__module__,
303351
type(self).__name__,
304352
tuple.__repr__(self)
305353
)
306354

307355
@property
308-
def p1(self):
356+
def p1(self) -> Point:
309357
return self[0]
310358

311359
@property
312-
def p2(self):
360+
def p2(self) -> Point:
313361
return self[1]
314362

315363

@@ -318,34 +366,44 @@ class Path:
318366

319367
__slots__ = '_is_closed', 'points'
320368

321-
def __init__(self, *points, is_closed=False):
369+
def __init__(self, *points: typing.Sequence[float],
370+
is_closed: bool = False) -> None:
322371
self.points = tuple(Point(*p) for p in points)
323372
self._is_closed = is_closed
324373

325374
@property
326-
def is_closed(self):
375+
def is_closed(self) -> bool:
327376
return self._is_closed
328377

329-
def __eq__(self, other):
378+
def __eq__(self, other: object) -> bool:
330379
if not isinstance(other, Path):
331380
return NotImplemented
332381

333382
return (self.points == other.points and
334383
self._is_closed == other._is_closed)
335384

336-
def __hash__(self):
385+
def __hash__(self) -> int:
337386
return hash((self.points, self.is_closed))
338387

339-
def __iter__(self):
388+
def __iter__(self) -> typing.Iterator[Point]:
340389
return iter(self.points)
341390

342-
def __len__(self):
391+
def __len__(self) -> int:
343392
return len(self.points)
344393

345-
def __getitem__(self, i):
394+
@typing.overload
395+
def __getitem__(self, i: int) -> Point:
396+
...
397+
398+
@typing.overload
399+
def __getitem__(self, i: slice) -> typing.Tuple[Point, ...]:
400+
...
401+
402+
def __getitem__(self, i: typing.Union[int, slice]) \
403+
-> typing.Union[Point, typing.Tuple[Point, ...]]:
346404
return self.points[i]
347405

348-
def __contains__(self, point):
406+
def __contains__(self, point: object) -> bool:
349407
return point in self.points
350408

351409

@@ -354,23 +412,23 @@ class Polygon(Path):
354412

355413
__slots__ = ()
356414

357-
def __init__(self, *points):
415+
def __init__(self, *points: typing.Sequence[float]) -> None:
358416
# polygon is always closed
359417
super().__init__(*points, is_closed=True)
360418

361419

362-
class Circle(tuple):
420+
class Circle(_CircleBase):
363421
"""Immutable representation of PostgreSQL `circle` type."""
364422

365423
__slots__ = ()
366424

367-
def __new__(cls, center, radius):
368-
return super().__new__(cls, (center, radius))
425+
def __new__(cls, center: Point, radius: float) -> 'Circle':
426+
return super().__new__(cls, typing.cast(typing.Any, (center, radius)))
369427

370428
@property
371-
def center(self):
429+
def center(self) -> Point:
372430
return self[0]
373431

374432
@property
375-
def radius(self):
433+
def radius(self) -> float:
376434
return self[1]

0 commit comments

Comments
 (0)