Skip to content

Commit e77daa6

Browse files
authored
Merge pull request #2861 from bagerard/warning_dup_Document_class
refactor _document_registry + log a warning when user register multip…
2 parents 11943d9 + f0de61e commit e77daa6

File tree

11 files changed

+96
-61
lines changed

11 files changed

+96
-61
lines changed

docs/changelog.rst

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ Development
1313
- make sure to read https://www.mongodb.com/docs/manual/core/transactions-in-applications/#callback-api-vs-core-api
1414
- run_in_transaction context manager relies on Pymongo coreAPI, it will retry automatically in case of `UnknownTransactionCommitResult` but not `TransientTransactionError` exceptions
1515
- Using .count() in a transaction will always use Collection.count_document (as estimated_document_count is not supported in transactions)
16+
- BREAKING CHANGE: wrap _document_registry (normally not used by end users) with _DocumentRegistry which acts as a singleton to access the registry
17+
- Log a warning in case users creates multiple Document classes with the same name as it can lead to unexpected behavior #1778
1618
- Fix use of $geoNear or $collStats in aggregate #2493
1719
- BREAKING CHANGE: Further to the deprecation warning, remove ability to use an unpacked list to `Queryset.aggregate(*pipeline)`, a plain list must be provided instead `Queryset.aggregate(pipeline)`, as it's closer to pymongo interface
1820
- BREAKING CHANGE: Further to the deprecation warning, remove `full_response` from `QuerySet.modify` as it wasn't supported with Pymongo 3+
@@ -21,6 +23,7 @@ Development
2123
- BREAKING CHANGE: Remove LongField as it's equivalent to IntField since we drop support to Python2 long time ago (User should simply switch to IntField) #2309
2224
- BugFix - Calling .clear on a ListField wasn't being marked as changed (and flushed to db upon .save()) #2858
2325

26+
2427
Changes in 0.29.0
2528
=================
2629
- Fix weakref in EmbeddedDocumentListField (causing brief mem leak in certain circumstances) #2827

mongoengine/base/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
__all__ = (
1414
# common
1515
"UPDATE_OPERATORS",
16-
"_document_registry",
17-
"get_document",
16+
"_DocumentRegistry",
1817
# datastructures
1918
"BaseDict",
2019
"BaseList",

mongoengine/base/common.py

+54-23
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import warnings
2+
13
from mongoengine.errors import NotRegistered
24

3-
__all__ = ("UPDATE_OPERATORS", "get_document", "_document_registry")
5+
__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry")
46

57

68
UPDATE_OPERATORS = {
@@ -25,28 +27,57 @@
2527
_document_registry = {}
2628

2729

28-
def get_document(name):
29-
"""Get a registered Document class by name."""
30-
doc = _document_registry.get(name, None)
31-
if not doc:
32-
# Possible old style name
33-
single_end = name.split(".")[-1]
34-
compound_end = ".%s" % single_end
35-
possible_match = [
36-
k for k in _document_registry if k.endswith(compound_end) or k == single_end
37-
]
38-
if len(possible_match) == 1:
39-
doc = _document_registry.get(possible_match.pop(), None)
40-
if not doc:
41-
raise NotRegistered(
42-
"""
43-
`%s` has not been registered in the document registry.
44-
Importing the document class automatically registers it, has it
45-
been imported?
46-
""".strip()
47-
% name
48-
)
49-
return doc
30+
class _DocumentRegistry:
31+
"""Wrapper for the document registry (providing a singleton pattern).
32+
This is part of MongoEngine's internals, not meant to be used directly by end-users
33+
"""
34+
35+
@staticmethod
36+
def get(name):
37+
doc = _document_registry.get(name, None)
38+
if not doc:
39+
# Possible old style name
40+
single_end = name.split(".")[-1]
41+
compound_end = ".%s" % single_end
42+
possible_match = [
43+
k
44+
for k in _document_registry
45+
if k.endswith(compound_end) or k == single_end
46+
]
47+
if len(possible_match) == 1:
48+
doc = _document_registry.get(possible_match.pop(), None)
49+
if not doc:
50+
raise NotRegistered(
51+
"""
52+
`%s` has not been registered in the document registry.
53+
Importing the document class automatically registers it, has it
54+
been imported?
55+
""".strip()
56+
% name
57+
)
58+
return doc
59+
60+
@staticmethod
61+
def register(DocCls):
62+
ExistingDocCls = _document_registry.get(DocCls._class_name)
63+
if (
64+
ExistingDocCls is not None
65+
and ExistingDocCls.__module__ != DocCls.__module__
66+
):
67+
# A sign that a codebase may have named two different classes with the same name accidentally,
68+
# this could cause issues with dereferencing because MongoEngine makes the assumption that a Document
69+
# class name is unique.
70+
warnings.warn(
71+
f"Multiple Document classes named `{DocCls._class_name}` were registered, "
72+
f"first from: `{ExistingDocCls.__module__}`, then from: `{DocCls.__module__}`. "
73+
"this may lead to unexpected behavior during dereferencing.",
74+
stacklevel=4,
75+
)
76+
_document_registry[DocCls._class_name] = DocCls
77+
78+
@staticmethod
79+
def unregister(doc_cls_name):
80+
_document_registry.pop(doc_cls_name)
5081

5182

5283
def _get_documents_by_db(connection_alias, default_connection_alias):

mongoengine/base/document.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from bson import SON, DBRef, ObjectId, json_util
88

99
from mongoengine import signals
10-
from mongoengine.base.common import get_document
10+
from mongoengine.base.common import _DocumentRegistry
1111
from mongoengine.base.datastructures import (
1212
BaseDict,
1313
BaseList,
@@ -500,7 +500,7 @@ def __expand_dynamic_values(self, name, value):
500500
# If the value is a dict with '_cls' in it, turn it into a document
501501
is_dict = isinstance(value, dict)
502502
if is_dict and "_cls" in value:
503-
cls = get_document(value["_cls"])
503+
cls = _DocumentRegistry.get(value["_cls"])
504504
return cls(**value)
505505

506506
if is_dict:
@@ -802,7 +802,7 @@ def _from_son(cls, son, _auto_dereference=True, created=False):
802802

803803
# Return correct subclass for document type
804804
if class_name != cls._class_name:
805-
cls = get_document(class_name)
805+
cls = _DocumentRegistry.get(class_name)
806806

807807
errors_dict = {}
808808

mongoengine/base/metaclasses.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import itertools
22
import warnings
33

4-
from mongoengine.base.common import _document_registry
4+
from mongoengine.base.common import _DocumentRegistry
55
from mongoengine.base.fields import (
66
BaseField,
77
ComplexBaseField,
@@ -169,7 +169,7 @@ def __new__(mcs, name, bases, attrs):
169169
new_class._collection = None
170170

171171
# Add class to the _document_registry
172-
_document_registry[new_class._class_name] = new_class
172+
_DocumentRegistry.register(new_class)
173173

174174
# Handle delete rules
175175
for field in new_class._fields.values():

mongoengine/dereference.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
BaseList,
66
EmbeddedDocumentList,
77
TopLevelDocumentMetaclass,
8-
get_document,
8+
_DocumentRegistry,
99
)
1010
from mongoengine.base.datastructures import LazyReference
1111
from mongoengine.connection import _get_session, get_db
@@ -131,9 +131,9 @@ def _find_references(self, items, depth=0):
131131
elif isinstance(v, DBRef):
132132
reference_map.setdefault(field.document_type, set()).add(v.id)
133133
elif isinstance(v, (dict, SON)) and "_ref" in v:
134-
reference_map.setdefault(get_document(v["_cls"]), set()).add(
135-
v["_ref"].id
136-
)
134+
reference_map.setdefault(
135+
_DocumentRegistry.get(v["_cls"]), set()
136+
).add(v["_ref"].id)
137137
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
138138
field_cls = getattr(
139139
getattr(field, "field", None), "document_type", None
@@ -151,9 +151,9 @@ def _find_references(self, items, depth=0):
151151
elif isinstance(item, DBRef):
152152
reference_map.setdefault(item.collection, set()).add(item.id)
153153
elif isinstance(item, (dict, SON)) and "_ref" in item:
154-
reference_map.setdefault(get_document(item["_cls"]), set()).add(
155-
item["_ref"].id
156-
)
154+
reference_map.setdefault(
155+
_DocumentRegistry.get(item["_cls"]), set()
156+
).add(item["_ref"].id)
157157
elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
158158
references = self._find_references(item, depth - 1)
159159
for key, refs in references.items():
@@ -198,9 +198,9 @@ def _fetch_objects(self, doc_type=None):
198198
)
199199
for ref in references:
200200
if "_cls" in ref:
201-
doc = get_document(ref["_cls"])._from_son(ref)
201+
doc = _DocumentRegistry.get(ref["_cls"])._from_son(ref)
202202
elif doc_type is None:
203-
doc = get_document(
203+
doc = _DocumentRegistry.get(
204204
"".join(x.capitalize() for x in collection.split("_"))
205205
)._from_son(ref)
206206
else:
@@ -235,7 +235,7 @@ def _attach_objects(self, items, depth=0, instance=None, name=None):
235235
(items["_ref"].collection, items["_ref"].id), items
236236
)
237237
elif "_cls" in items:
238-
doc = get_document(items["_cls"])._from_son(items)
238+
doc = _DocumentRegistry.get(items["_cls"])._from_son(items)
239239
_cls = doc._data.pop("_cls", None)
240240
del items["_cls"]
241241
doc._data = self._attach_objects(doc._data, depth, doc, None)

mongoengine/document.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
DocumentMetaclass,
1313
EmbeddedDocumentList,
1414
TopLevelDocumentMetaclass,
15-
get_document,
15+
_DocumentRegistry,
1616
)
1717
from mongoengine.base.utils import NonOrderedList
1818
from mongoengine.common import _import_class
@@ -851,12 +851,12 @@ def register_delete_rule(cls, document_cls, field_name, rule):
851851
object.
852852
"""
853853
classes = [
854-
get_document(class_name)
854+
_DocumentRegistry.get(class_name)
855855
for class_name in cls._subclasses
856856
if class_name != cls.__name__
857857
] + [cls]
858858
documents = [
859-
get_document(class_name)
859+
_DocumentRegistry.get(class_name)
860860
for class_name in document_cls._subclasses
861861
if class_name != document_cls.__name__
862862
] + [document_cls]

mongoengine/fields.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
GeoJsonBaseField,
3131
LazyReference,
3232
ObjectIdField,
33-
get_document,
33+
_DocumentRegistry,
3434
)
3535
from mongoengine.base.utils import LazyRegexCompiler
3636
from mongoengine.common import _import_class
@@ -725,7 +725,7 @@ def document_type(self):
725725
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
726726
resolved_document_type = self.owner_document
727727
else:
728-
resolved_document_type = get_document(self.document_type_obj)
728+
resolved_document_type = _DocumentRegistry.get(self.document_type_obj)
729729

730730
if not issubclass(resolved_document_type, EmbeddedDocument):
731731
# Due to the late resolution of the document_type
@@ -801,7 +801,7 @@ def prepare_query_value(self, op, value):
801801

802802
def to_python(self, value):
803803
if isinstance(value, dict):
804-
doc_cls = get_document(value["_cls"])
804+
doc_cls = _DocumentRegistry.get(value["_cls"])
805805
value = doc_cls._from_son(value)
806806

807807
return value
@@ -879,7 +879,7 @@ def to_mongo(self, value, use_db_field=True, fields=None):
879879

880880
def to_python(self, value):
881881
if isinstance(value, dict) and "_cls" in value:
882-
doc_cls = get_document(value["_cls"])
882+
doc_cls = _DocumentRegistry.get(value["_cls"])
883883
if "_ref" in value:
884884
value = doc_cls._get_db().dereference(
885885
value["_ref"], session=_get_session()
@@ -1171,7 +1171,7 @@ def document_type(self):
11711171
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
11721172
self.document_type_obj = self.owner_document
11731173
else:
1174-
self.document_type_obj = get_document(self.document_type_obj)
1174+
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
11751175
return self.document_type_obj
11761176

11771177
@staticmethod
@@ -1195,7 +1195,7 @@ def __get__(self, instance, owner):
11951195
if auto_dereference and isinstance(ref_value, DBRef):
11961196
if hasattr(ref_value, "cls"):
11971197
# Dereference using the class type specified in the reference
1198-
cls = get_document(ref_value.cls)
1198+
cls = _DocumentRegistry.get(ref_value.cls)
11991199
else:
12001200
cls = self.document_type
12011201

@@ -1335,7 +1335,7 @@ def document_type(self):
13351335
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
13361336
self.document_type_obj = self.owner_document
13371337
else:
1338-
self.document_type_obj = get_document(self.document_type_obj)
1338+
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
13391339
return self.document_type_obj
13401340

13411341
@staticmethod
@@ -1498,7 +1498,7 @@ def __get__(self, instance, owner):
14981498

14991499
auto_dereference = instance._fields[self.name]._auto_dereference
15001500
if auto_dereference and isinstance(value, dict):
1501-
doc_cls = get_document(value["_cls"])
1501+
doc_cls = _DocumentRegistry.get(value["_cls"])
15021502
instance._data[self.name] = self._lazy_load_ref(doc_cls, value["_ref"])
15031503

15041504
return super().__get__(instance, owner)
@@ -2443,7 +2443,7 @@ def document_type(self):
24432443
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
24442444
self.document_type_obj = self.owner_document
24452445
else:
2446-
self.document_type_obj = get_document(self.document_type_obj)
2446+
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
24472447
return self.document_type_obj
24482448

24492449
def build_lazyref(self, value):
@@ -2584,7 +2584,7 @@ def build_lazyref(self, value):
25842584
elif value is not None:
25852585
if isinstance(value, (dict, SON)):
25862586
value = LazyReference(
2587-
get_document(value["_cls"]),
2587+
_DocumentRegistry.get(value["_cls"]),
25882588
value["_ref"].id,
25892589
passthrough=self.passthrough,
25902590
)

mongoengine/queryset/base.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pymongo.read_concern import ReadConcern
1414

1515
from mongoengine import signals
16-
from mongoengine.base import get_document
16+
from mongoengine.base import _DocumentRegistry
1717
from mongoengine.common import _import_class
1818
from mongoengine.connection import _get_session, get_db
1919
from mongoengine.context_managers import (
@@ -1956,7 +1956,9 @@ def _fields_to_dbfields(self, fields):
19561956
"""Translate fields' paths to their db equivalents."""
19571957
subclasses = []
19581958
if self._document._meta["allow_inheritance"]:
1959-
subclasses = [get_document(x) for x in self._document._subclasses][1:]
1959+
subclasses = [_DocumentRegistry.get(x) for x in self._document._subclasses][
1960+
1:
1961+
]
19601962

19611963
db_field_paths = []
19621964
for field in fields:

tests/document/test_instance.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from mongoengine import *
1515
from mongoengine import signals
16-
from mongoengine.base import _document_registry, get_document
16+
from mongoengine.base import _DocumentRegistry
1717
from mongoengine.connection import get_db
1818
from mongoengine.context_managers import query_counter, switch_db
1919
from mongoengine.errors import (
@@ -392,7 +392,7 @@ class NicePlace(Place):
392392

393393
# Mimic Place and NicePlace definitions being in a different file
394394
# and the NicePlace model not being imported in at query time.
395-
del _document_registry["Place.NicePlace"]
395+
_DocumentRegistry.unregister("Place.NicePlace")
396396

397397
with pytest.raises(NotRegistered):
398398
list(Place.objects.all())
@@ -407,8 +407,8 @@ class Area(Location):
407407

408408
Location.drop_collection()
409409

410-
assert Area == get_document("Area")
411-
assert Area == get_document("Location.Area")
410+
assert Area == _DocumentRegistry.get("Area")
411+
assert Area == _DocumentRegistry.get("Location.Area")
412412

413413
def test_creation(self):
414414
"""Ensure that document may be created using keyword arguments."""

tests/fields/test_fields.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from mongoengine.base import (
3838
BaseField,
3939
EmbeddedDocumentList,
40-
_document_registry,
40+
_DocumentRegistry,
4141
)
4242
from mongoengine.base.fields import _no_dereference_for_fields
4343
from mongoengine.errors import DeprecatedError
@@ -1678,7 +1678,7 @@ class User(Document):
16781678

16791679
# Mimic User and Link definitions being in a different file
16801680
# and the Link model not being imported in the User file.
1681-
del _document_registry["Link"]
1681+
_DocumentRegistry.unregister("Link")
16821682

16831683
user = User.objects.first()
16841684
try:

0 commit comments

Comments
 (0)