Skip to content

Commit 6a5d943

Browse files
committed
Updates json array type tests to account for unittest coverage
1 parent a0543b9 commit 6a5d943

File tree

2 files changed

+72
-7
lines changed

2 files changed

+72
-7
lines changed

db_dtypes/json.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,11 @@ def __hash__(self) -> int:
275275
def to_pandas_dtype(self):
276276
return JSONDtype()
277277

278-
279278
# Register the type to be included in RecordBatches, sent over IPC and received in
280-
# another Python process.
281-
pa.register_extension_type(JSONArrowType())
279+
# another Python process. Also handle potential pre-registration
280+
try:
281+
pa.register_extension_type(JSONArrowType())
282+
except pa.ArrowKeyError:
283+
# Type 'dbjson' might already be registered if the module is reloaded,
284+
# which is okay.
285+
pass

tests/unit/test__init__.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import types
44
import warnings
55
from unittest import mock
6+
import pyarrow as pa
67

78
# The module where the version check code resides
89
MODULE_PATH = "db_dtypes"
910
HELPER_MODULE_PATH = f"{MODULE_PATH}._versions_helpers"
1011

11-
@pytest.fixture(autouse=True)
12+
@pytest.fixture
1213
def cleanup_imports():
1314
"""Ensures the target module and its helper are removed from sys.modules
1415
before each test, allowing for clean imports with patching.
@@ -48,7 +49,7 @@ def cleanup_imports():
4849
((3, 12, 0), "3.12.0", False),
4950
]
5051
)
51-
def test_python_version_warning_on_import(mock_version_tuple, version_str, expect_warning):
52+
def test_python_version_warning_on_import(mock_version_tuple, version_str, expect_warning, cleanup_imports):
5253
"""Test that a FutureWarning is raised ONLY for Python 3.7 or 3.8 during import.
5354
"""
5455

@@ -71,8 +72,6 @@ def test_python_version_warning_on_import(mock_version_tuple, version_str, expec
7172
assert len(record) == 1
7273
warning_message = str(record[0].message)
7374
assert "longer supports Python 3.7 and Python 3.8" in warning_message
74-
assert f"Your Python version is {version_str}" in warning_message
75-
assert "https://cloud.google.com/python/docs/supported-python-versions" in warning_message
7675
else:
7776
with warnings.catch_warnings(record=True) as record:
7877
warnings.simplefilter("always")
@@ -88,3 +87,65 @@ def test_python_version_warning_on_import(mock_version_tuple, version_str, expec
8887
assert not found_warning, (
8988
f"Unexpected FutureWarning raised for Python version {version_str}"
9089
)
90+
91+
# --- Test Case 1: JSON types available ---
92+
93+
@pytest.fixture
94+
def cleanup_imports_for_all(request):
95+
"""
96+
Ensures the target module and its dependencies potentially affecting
97+
__all__ are removed from sys.modules before and after each test,
98+
allowing for clean imports with patching. Also handles PyArrow extension type registration.
99+
"""
100+
101+
# Modules that might be checked or imported in __init__
102+
modules_to_clear = [
103+
MODULE_PATH,
104+
f"{MODULE_PATH}.core",
105+
f"{MODULE_PATH}.json",
106+
f"{MODULE_PATH}.version",
107+
f"{MODULE_PATH}._versions_helpers",
108+
]
109+
original_modules = {}
110+
111+
# Store original modules and remove them
112+
for mod_name in modules_to_clear:
113+
original_modules[mod_name] = sys.modules.get(mod_name)
114+
if mod_name in sys.modules:
115+
del sys.modules[mod_name]
116+
117+
yield # Run the test
118+
119+
# Restore original modules after test
120+
for mod_name, original_mod in original_modules.items():
121+
if original_mod:
122+
sys.modules[mod_name] = original_mod
123+
elif mod_name in sys.modules:
124+
# If it wasn't there before but is now, remove it
125+
del sys.modules[mod_name]
126+
127+
def test_all_includes_json_when_available(cleanup_imports_for_all):
128+
"""
129+
Test that __all__ includes JSON types when JSONArray and JSONDtype are available.
130+
"""
131+
132+
# No patching needed for the 'else' block, assume normal import works
133+
# and JSONArray/JSONDtype are truthy.
134+
import db_dtypes
135+
136+
expected_all = [
137+
"__version__",
138+
"DateArray",
139+
"DateDtype",
140+
"JSONDtype",
141+
"JSONArray",
142+
"JSONArrowType",
143+
"TimeArray",
144+
"TimeDtype",
145+
]
146+
# Use set comparison for order independence, as __all__ order isn't critical
147+
assert set(db_dtypes.__all__) == set(expected_all)
148+
# Explicitly check presence of JSON types
149+
assert "JSONDtype" in db_dtypes.__all__
150+
assert "JSONArray" in db_dtypes.__all__
151+
assert "JSONArrowType" in db_dtypes.__all__

0 commit comments

Comments
 (0)