Skip to content

Commit 0d44783

Browse files
authored
PYTHON-3821 use overload pattern for _DocumentType (#1352)
1 parent c1d3383 commit 0d44783

File tree

3 files changed

+71
-24
lines changed

3 files changed

+71
-24
lines changed

bson/__init__.py

+46-6
Original file line numberDiff line numberDiff line change
@@ -1106,9 +1106,21 @@ def _decode_all(
11061106
_decode_all = _cbson._decode_all # noqa: F811
11071107

11081108

1109+
@overload
1110+
def decode_all(data: "_ReadableBuffer", codec_options: None = None) -> "List[Dict[str, Any]]":
1111+
...
1112+
1113+
1114+
@overload
11091115
def decode_all(
1110-
data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None
1116+
data: "_ReadableBuffer", codec_options: "CodecOptions[_DocumentType]"
11111117
) -> "List[_DocumentType]":
1118+
...
1119+
1120+
1121+
def decode_all(
1122+
data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None
1123+
) -> "Union[List[Dict[str, Any]], List[_DocumentType]]":
11121124
"""Decode BSON data to multiple documents.
11131125
11141126
`data` must be a bytes-like object implementing the buffer protocol that
@@ -1131,11 +1143,13 @@ def decode_all(
11311143
Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with
11321144
`codec_options`.
11331145
"""
1134-
opts = codec_options or DEFAULT_CODEC_OPTIONS
1135-
if not isinstance(opts, CodecOptions):
1146+
if codec_options is None:
1147+
return _decode_all(data, DEFAULT_CODEC_OPTIONS)
1148+
1149+
if not isinstance(codec_options, CodecOptions):
11361150
raise _CODEC_OPTIONS_TYPE_ERROR
11371151

1138-
return _decode_all(data, opts) # type:ignore[arg-type]
1152+
return _decode_all(data, codec_options)
11391153

11401154

11411155
def _decode_selective(rawdoc: Any, fields: Any, codec_options: Any) -> Mapping[Any, Any]:
@@ -1242,9 +1256,21 @@ def _decode_all_selective(data: Any, codec_options: CodecOptions, fields: Any) -
12421256
]
12431257

12441258

1259+
@overload
1260+
def decode_iter(data: bytes, codec_options: None = None) -> "Iterator[Dict[str, Any]]":
1261+
...
1262+
1263+
1264+
@overload
12451265
def decode_iter(
1246-
data: bytes, codec_options: "Optional[CodecOptions[_DocumentType]]" = None
1266+
data: bytes, codec_options: "CodecOptions[_DocumentType]"
12471267
) -> "Iterator[_DocumentType]":
1268+
...
1269+
1270+
1271+
def decode_iter(
1272+
data: bytes, codec_options: "Optional[CodecOptions[_DocumentType]]" = None
1273+
) -> "Union[Iterator[Dict[str, Any]], Iterator[_DocumentType]]":
12481274
"""Decode BSON data to multiple documents as a generator.
12491275
12501276
Works similarly to the decode_all function, but yields one document at a
@@ -1278,9 +1304,23 @@ def decode_iter(
12781304
yield _bson_to_dict(elements, opts)
12791305

12801306

1307+
@overload
12811308
def decode_file_iter(
1282-
file_obj: Union[BinaryIO, IO], codec_options: "Optional[CodecOptions[_DocumentType]]" = None
1309+
file_obj: Union[BinaryIO, IO], codec_options: None = None
1310+
) -> "Iterator[Dict[str, Any]]":
1311+
...
1312+
1313+
1314+
@overload
1315+
def decode_file_iter(
1316+
file_obj: Union[BinaryIO, IO], codec_options: "CodecOptions[_DocumentType]"
12831317
) -> "Iterator[_DocumentType]":
1318+
...
1319+
1320+
1321+
def decode_file_iter(
1322+
file_obj: Union[BinaryIO, IO], codec_options: "Optional[CodecOptions[_DocumentType]]" = None
1323+
) -> "Union[Iterator[Dict[str, Any]], Iterator[_DocumentType]]":
12841324
"""Decode bson data from a file to multiple documents as a generator.
12851325
12861326
Works similarly to the decode_all function, but reads from the file object

pymongo/collection.py

-18
Original file line numberDiff line numberDiff line change
@@ -427,24 +427,6 @@ def database(self) -> Database[_DocumentType]:
427427
"""
428428
return self.__database
429429

430-
# @overload
431-
# def with_options(
432-
# self,
433-
# codec_options: None = None,
434-
# read_preference: Optional[_ServerMode] = None,
435-
# write_concern: Optional[WriteConcern] = None,
436-
# read_concern: Optional[ReadConcern] = None,
437-
# ) -> Collection[Dict[str, Any]]: ...
438-
439-
# @overload
440-
# def with_options(
441-
# self,
442-
# codec_options: bson.CodecOptions[_DocumentType],
443-
# read_preference: Optional[_ServerMode] = None,
444-
# write_concern: Optional[WriteConcern] = None,
445-
# read_concern: Optional[ReadConcern] = None,
446-
# ) -> Collection[_DocumentType]: ...
447-
448430
def with_options(
449431
self,
450432
codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None,

test/test_typing.py

+25
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,11 @@ def foo(self):
242242
rt_document3 = decode(bsonbytes2, codec_options=codec_options2)
243243
assert rt_document3.raw
244244

245+
def test_bson_decode_no_codec_option(self) -> None:
246+
doc = decode_all(encode({"a": 1}))
247+
assert doc
248+
doc[0]["a"] = 2
249+
245250
def test_bson_decode_all(self) -> None:
246251
doc = {"_id": 1}
247252
bsonbytes = encode(doc)
@@ -266,6 +271,15 @@ def foo(self):
266271
rt_documents3 = decode_all(bsonbytes3, codec_options3)
267272
assert rt_documents3[0].raw
268273

274+
def test_bson_decode_all_no_codec_option(self) -> None:
275+
docs = decode_all(b"")
276+
docs.append({"new": 1})
277+
278+
docs = decode_all(encode({"a": 1}))
279+
assert docs
280+
docs[0]["a"] = 2
281+
docs.append({"new": 1})
282+
269283
def test_bson_decode_iter(self) -> None:
270284
doc = {"_id": 1}
271285
bsonbytes = encode(doc)
@@ -290,6 +304,11 @@ def foo(self):
290304
rt_documents3 = decode_iter(bsonbytes3, codec_options3)
291305
assert next(rt_documents3).raw
292306

307+
def test_bson_decode_iter_no_codec_option(self) -> None:
308+
doc = next(decode_iter(encode({"a": 1})))
309+
assert doc
310+
doc["a"] = 2
311+
293312
def make_tempfile(self, content: bytes) -> Any:
294313
fileobj = tempfile.TemporaryFile()
295314
fileobj.write(content)
@@ -324,6 +343,12 @@ def foo(self):
324343
rt_documents3 = decode_file_iter(fileobj3, codec_options3)
325344
assert next(rt_documents3).raw
326345

346+
def test_bson_decode_file_iter_none_codec_option(self) -> None:
347+
fileobj = self.make_tempfile(encode({"new": 1}))
348+
doc = next(decode_file_iter(fileobj))
349+
assert doc
350+
doc["a"] = 2
351+
327352

328353
class TestDocumentType(unittest.TestCase):
329354
@only_type_check

0 commit comments

Comments
 (0)