Skip to content

Commit ea6dac4

Browse files
authored
sqlite3: handle return-type with factory argument. (#11571)
1 parent 8558307 commit ea6dac4

File tree

2 files changed

+91
-14
lines changed

2 files changed

+91
-14
lines changed

stdlib/sqlite3/dbapi2.pyi

+65-14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from typing import Any, Literal, Protocol, SupportsIndex, TypeVar, final, overlo
88
from typing_extensions import Self, TypeAlias
99

1010
_T = TypeVar("_T")
11+
_ConnectionT = TypeVar("_ConnectionT", bound=Connection)
1112
_CursorT = TypeVar("_CursorT", bound=Cursor)
1213
_SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
1314
# Data that is passed through adapters can be of any type accepted by an adapter.
@@ -223,29 +224,79 @@ def adapt(obj: Any, proto: Any, alt: _T, /) -> Any | _T: ...
223224
def complete_statement(statement: str) -> bool: ...
224225

225226
if sys.version_info >= (3, 12):
227+
@overload
226228
def connect(
227229
database: StrOrBytesPath,
228-
timeout: float = ...,
229-
detect_types: int = ...,
230-
isolation_level: str | None = ...,
231-
check_same_thread: bool = ...,
232-
factory: type[Connection] | None = ...,
233-
cached_statements: int = ...,
234-
uri: bool = ...,
230+
timeout: float = 5.0,
231+
detect_types: int = 0,
232+
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None = "DEFERRED",
233+
check_same_thread: bool = True,
234+
cached_statements: int = 128,
235+
uri: bool = False,
236+
*,
235237
autocommit: bool = ...,
236238
) -> Connection: ...
239+
@overload
240+
def connect(
241+
database: StrOrBytesPath,
242+
timeout: float,
243+
detect_types: int,
244+
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None,
245+
check_same_thread: bool,
246+
factory: type[_ConnectionT],
247+
cached_statements: int = 128,
248+
uri: bool = False,
249+
*,
250+
autocommit: bool = ...,
251+
) -> _ConnectionT: ...
252+
@overload
253+
def connect(
254+
database: StrOrBytesPath,
255+
timeout: float = 5.0,
256+
detect_types: int = 0,
257+
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None = "DEFERRED",
258+
check_same_thread: bool = True,
259+
*,
260+
factory: type[_ConnectionT],
261+
cached_statements: int = 128,
262+
uri: bool = False,
263+
autocommit: bool = ...,
264+
) -> _ConnectionT: ...
237265

238266
else:
267+
@overload
239268
def connect(
240269
database: StrOrBytesPath,
241-
timeout: float = ...,
242-
detect_types: int = ...,
243-
isolation_level: str | None = ...,
244-
check_same_thread: bool = ...,
245-
factory: type[Connection] | None = ...,
246-
cached_statements: int = ...,
247-
uri: bool = ...,
270+
timeout: float = 5.0,
271+
detect_types: int = 0,
272+
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None = "DEFERRED",
273+
check_same_thread: bool = True,
274+
cached_statements: int = 128,
275+
uri: bool = False,
248276
) -> Connection: ...
277+
@overload
278+
def connect(
279+
database: StrOrBytesPath,
280+
timeout: float,
281+
detect_types: int,
282+
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None,
283+
check_same_thread: bool,
284+
factory: type[_ConnectionT],
285+
cached_statements: int = 128,
286+
uri: bool = False,
287+
) -> _ConnectionT: ...
288+
@overload
289+
def connect(
290+
database: StrOrBytesPath,
291+
timeout: float = 5.0,
292+
detect_types: int = 0,
293+
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None = "DEFERRED",
294+
check_same_thread: bool = True,
295+
*,
296+
factory: type[_ConnectionT],
297+
cached_statements: int = 128,
298+
uri: bool = False,
299+
) -> _ConnectionT: ...
249300

250301
def enable_callback_tracebacks(enable: bool, /) -> None: ...
251302

test_cases/stdlib/check_sqlite3.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
import sqlite3
4+
from typing_extensions import assert_type
5+
6+
7+
class MyConnection(sqlite3.Connection):
8+
pass
9+
10+
11+
# Default return-type is Connection.
12+
assert_type(sqlite3.connect(":memory:"), sqlite3.Connection)
13+
14+
# Providing an alternate factory changes the return-type.
15+
assert_type(sqlite3.connect(":memory:", factory=MyConnection), MyConnection)
16+
17+
# Provides a true positive error. When checking the connect() function,
18+
# mypy should report an arg-type error for the factory argument.
19+
with sqlite3.connect(":memory:", factory=None) as con: # type: ignore
20+
pass
21+
22+
# The Connection class also accepts a `factory` arg but it does not affect
23+
# the return-type. This use case is not idiomatic--connections should be
24+
# established using the `connect()` function, not directly (as shown here).
25+
assert_type(sqlite3.Connection(":memory:", factory=None), sqlite3.Connection)
26+
assert_type(sqlite3.Connection(":memory:", factory=MyConnection), sqlite3.Connection)

0 commit comments

Comments
 (0)