Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit f08d821

Browse files
authored
Merge pull request #134 from datafold/inf_threshold
Fixed tests; bisection_threshold can now be inf
2 parents cc2a323 + 01fa893 commit f08d821

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

data_diff/diff_tables.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from abc import ABC, abstractmethod
55
import time
66
import os
7+
from numbers import Number
78
from operator import attrgetter, methodcaller
89
from collections import defaultdict
9-
from typing import List, Tuple, Iterator, Optional, Type
10+
from typing import List, Tuple, Iterator, Optional
1011
import logging
1112
from concurrent.futures import ThreadPoolExecutor
1213

@@ -263,14 +264,14 @@ class TableDiffer:
263264
264265
Parameters:
265266
bisection_factor (int): Into how many segments to bisect per iteration.
266-
bisection_threshold (int): When should we stop bisecting and compare locally (in row count).
267+
bisection_threshold (Number): When should we stop bisecting and compare locally (in row count).
267268
threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads.
268269
max_threadpool_size (int): Maximum size of each threadpool. ``None`` means auto. Only relevant when `threaded` is ``True``.
269270
There may be many pools, so number of actual threads can be a lot higher.
270271
"""
271272

272273
bisection_factor: int = DEFAULT_BISECTION_FACTOR
273-
bisection_threshold: int = DEFAULT_BISECTION_THRESHOLD
274+
bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests
274275
threaded: bool = True
275276
max_threadpool_size: Optional[int] = 1
276277

tests/test_database_types.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
589589
# reasonable amount of rows each. These will then be downloaded in
590590
# parallel, using the existing implementation.
591591
dl_factor = max(int(N_SAMPLES / 100_000), 2) if BENCHMARK else 2
592-
dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else N_SAMPLES + 1
592+
dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else math.inf
593593
dl_threads = 1
594594
differ = TableDiffer(
595595
bisection_threshold=dl_threshold, bisection_factor=dl_factor, max_threadpool_size=dl_threads
@@ -599,10 +599,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
599599
download_duration = time.time() - start
600600
expected = []
601601
self.assertEqual(expected, diff)
602-
if type_category == "uuid":
603-
pass # UUIDs aren't serial, so they mess with the first max_rows estimation.
604-
else:
605-
self.assertEqual(len(sample_values), differ.stats.get("rows_downloaded", 0))
602+
self.assertEqual(len(sample_values), differ.stats.get("rows_downloaded", 0))
606603

607604
result = {
608605
"test": self._testMethodName,

0 commit comments

Comments
 (0)