Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit ad34a1d

Browse files
author
Sergey Vasilyev
committed
Annotate simple iterators
1 parent ec6786c commit ad34a1d

File tree

5 files changed

+35
-17
lines changed

5 files changed

+35
-17
lines changed

data_diff/databases/base.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,22 @@
55
import math
66
import sys
77
import logging
8-
from typing import Any, Callable, ClassVar, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar
8+
from typing import (
9+
Any,
10+
Callable,
11+
ClassVar,
12+
Dict,
13+
Generator,
14+
Iterator,
15+
NewType,
16+
Tuple,
17+
Optional,
18+
Sequence,
19+
Type,
20+
List,
21+
Union,
22+
TypeVar,
23+
)
924
from functools import partial, wraps
1025
from concurrent.futures import ThreadPoolExecutor
1126
import threading
@@ -885,20 +900,21 @@ def optimizer_hints(self, hints: str) -> str:
885900

886901

887902
T = TypeVar("T", bound=BaseDialect)
903+
Row = Sequence[Any]
888904

889905

890906
@attrs.define(frozen=True)
891907
class QueryResult:
892-
rows: list
908+
rows: List[Row]
893909
columns: Optional[list] = None
894910

895-
def __iter__(self):
911+
def __iter__(self) -> Iterator[Row]:
896912
return iter(self.rows)
897913

898914
def __len__(self) -> int:
899915
return len(self.rows)
900916

901-
def __getitem__(self, i):
917+
def __getitem__(self, i) -> Row:
902918
return self.rows[i]
903919

904920

data_diff/diff_tables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from enum import Enum
77
from contextlib import contextmanager
88
from operator import methodcaller
9-
from typing import Dict, Set, List, Tuple, Iterator, Optional, Union
9+
from typing import Any, Dict, Set, List, Tuple, Iterator, Optional, Union
1010
from concurrent.futures import ThreadPoolExecutor, as_completed
1111

1212
import attrs
@@ -89,7 +89,7 @@ class DiffResultWrapper:
8989
stats: dict
9090
result_list: list = attrs.field(factory=list)
9191

92-
def __iter__(self):
92+
def __iter__(self) -> Iterator[Any]:
9393
yield from self.result_list
9494
for i in self.diff:
9595
self.result_list.append(i)

data_diff/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Collection, Iterable, Optional
2+
from typing import Any, Collection, Iterator, Optional
33

44
import attrs
55

@@ -28,7 +28,7 @@ class RawColumnInfo(Collection[Any]):
2828
collation_name: Optional[str] = None
2929

3030
# It was a tuple once, so we keep it backward compatible temporarily, until remade to classes.
31-
def __iter__(self) -> Iterable[Any]:
31+
def __iter__(self) -> Iterator[Any]:
3232
return iter(
3333
(self.column_name, self.data_type, self.datetime_precision, self.numeric_precision, self.numeric_scale)
3434
)

data_diff/thread_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from concurrent.futures import ThreadPoolExecutor
66
from concurrent.futures.thread import _WorkItem
77
from time import sleep
8-
from typing import Callable, Iterator, Optional
8+
from typing import Any, Callable, Iterator, Optional
99

1010
import attrs
1111

@@ -80,7 +80,7 @@ def _worker(self, fn, *args, **kwargs):
8080
def submit(self, fn: Callable, *args, priority: int = 0, **kwargs):
8181
self._futures.append(self._pool.submit(self._worker, fn, *args, priority=priority, **kwargs))
8282

83-
def __iter__(self) -> Iterator:
83+
def __iter__(self) -> Iterator[Any]:
8484
while True:
8585
if self._exception:
8686
raise self._exception

tests/test_database_types.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
import logging
99
from decimal import Decimal
1010
from itertools import islice, repeat, chain
11+
from typing import Iterator
1112

1213
from parameterized import parameterized
1314

15+
from data_diff.databases.base import Row
1416
from data_diff.utils import number_to_human
1517
from data_diff.queries.api import table, commit, this, Code
1618
from data_diff.queries.api import insert_rows_in_batches
@@ -371,7 +373,7 @@ def __init__(self, table_path, conn) -> None:
371373
self.table_path = table_path
372374
self.conn = conn
373375

374-
def __iter__(self):
376+
def __iter__(self) -> Iterator[Row]:
375377
last_id = 0
376378
while True:
377379
query = (
@@ -402,7 +404,7 @@ def __init__(self, max) -> None:
402404
super().__init__()
403405
self.max = max
404406

405-
def __iter__(self):
407+
def __iter__(self) -> Iterator[datetime]:
406408
initial = datetime(2000, 1, 1, 0, 0, 0, 0)
407409
step = timedelta(seconds=3, microseconds=571)
408410
return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max)
@@ -418,7 +420,7 @@ def __init__(self, max) -> None:
418420
super().__init__()
419421
self.max = max
420422

421-
def __iter__(self):
423+
def __iter__(self) -> Iterator[int]:
422424
initial = -128
423425
step = 1
424426
return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max)
@@ -434,7 +436,7 @@ def __init__(self, max) -> None:
434436
super().__init__()
435437
self.max = max
436438

437-
def __iter__(self):
439+
def __iter__(self) -> Iterator[bool]:
438440
return iter(self.MANUAL_FAKES[: self.max])
439441

440442
def __len__(self) -> int:
@@ -465,7 +467,7 @@ def __init__(self, max) -> None:
465467
super().__init__()
466468
self.max = max
467469

468-
def __iter__(self):
470+
def __iter__(self) -> Iterator[float]:
469471
initial = -10.0001
470472
step = 0.00571
471473
return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max)
@@ -482,7 +484,7 @@ def __init__(self, max) -> None:
482484
def __len__(self) -> int:
483485
return self.max
484486

485-
def __iter__(self):
487+
def __iter__(self) -> Iterator[uuid.UUID]:
486488
return (uuid.uuid1(i) for i in range(self.max))
487489

488490

@@ -495,7 +497,7 @@ def __init__(self, max) -> None:
495497
super().__init__()
496498
self.max = max
497499

498-
def __iter__(self):
500+
def __iter__(self) -> Iterator[str]:
499501
return iter(self.MANUAL_FAKES[: self.max])
500502

501503
def __len__(self) -> int:

0 commit comments

Comments
 (0)