Merge pull request colinhowe#7 from conversocial/get_tests_passing
fix: CON-112 - Get tests passing
Marcus Baker authored Dec 1, 2017
2 parents 01c1e42 + 8b1f869 commit 12b4d2d
Showing 18 changed files with 1,165 additions and 687 deletions.
152 changes: 101 additions & 51 deletions mongoengine/base.py

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions mongoengine/connection.py
@@ -32,8 +32,7 @@ def register_connection(alias, host='localhost', port=27017,
     :param name: the name of the specific database to use
     :param host: the host name of the :program:`mongod` instance to connect to
     :param port: the port that the :program:`mongod` instance is running on
-    :param is_slave: whether the connection can act as a slave ** Depreciated pymongo 2.0.1+
-    :param read_preference: The read preference for the collection ** Added pymongo 2.1
+    :param read_preference: The read preference for the collection
     :param slaves: a list of aliases of slave connections; each of these must
         be a registered connection that has :attr:`is_slave` set to ``True``
     :param username: username to authenticate with
@@ -109,7 +108,8 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
         try:
             _connections[alias] = MongoClient(**conn_settings)
         except Exception, e:
-            raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
+            raise ConnectionError(
+                "Cannot connect to database %s :\n%s" % (alias, e))
     return _connections[alias]


@@ -126,6 +126,7 @@ def register_db(
         'db_name': db_name,
     }
 
+
 def get_db(alias=DEFAULT_DB_ALIAS, reconnect=False, refresh=False):
     global _dbs
     global _db_settings
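For orientation, a minimal usage sketch of the two helpers touched above; the alias name is made up, and register_db is left out because its full signature is not visible in this diff:

    from mongoengine.connection import register_connection, get_connection

    # read_preference supersedes the deprecated is_slave/slaves options
    # (pymongo 2.0.1+), per the docstring change above.
    register_connection('analytics', host='localhost', port=27017)

    # Lazily builds the MongoClient; raises ConnectionError naming the
    # alias if the client cannot be created.
    conn = get_connection('analytics')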
46 changes: 31 additions & 15 deletions mongoengine/dereference.py
@@ -69,21 +69,29 @@ def _find_references(self, items, depth=0):
                 for field_name, field in item._fields.iteritems():
                     v = item._data.get(field_name, None)
                     if isinstance(v, (DBRef)):
-                        reference_map.setdefault(field.document_type, []).append(v.id)
+                        reference_map.setdefault(field.document_type, []) \
+                            .append(v.id)
                     elif isinstance(v, (dict, SON)) and '_ref' in v:
-                        reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id)
-                    elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
-                        field_cls = getattr(getattr(field, 'field', None), 'document_type', None)
+                        reference_map.setdefault(get_document(v['_cls']), []) \
+                            .append(v['_ref'].id)
+                    elif isinstance(v, (dict, list, tuple)) and \
+                            depth <= self.max_depth:
+                        field_cls = getattr(getattr(field, 'field', None),
+                                            'document_type', None)
                         references = self._find_references(v, depth)
                         for key, refs in references.iteritems():
-                            if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)):
+                            if isinstance(
+                                    field_cls,
+                                    (Document, TopLevelDocumentMetaclass)):
                                 key = field_cls
                             reference_map.setdefault(key, []).extend(refs)
             elif isinstance(item, (DBRef)):
                 reference_map.setdefault(item.collection, []).append(item.id)
             elif isinstance(item, (dict, SON)) and '_ref' in item:
-                reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id)
-            elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
+                reference_map.setdefault(get_document(item['_cls']), []) \
+                    .append(item['_ref'].id)
+            elif isinstance(item, (dict, list, tuple)) and \
+                    depth - 1 <= self.max_depth:
                 references = self._find_references(item, depth - 1)
                 for key, refs in references.iteritems():
                     reference_map.setdefault(key, []).extend(refs)
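As context for the wrapped setdefault calls above: this pass accumulates a map of {document class or collection name: [referenced ids]} so each collection can later be fetched with a single $in query. A tiny standalone illustration of the same accumulation pattern, using only bson.DBRef:

    from bson import DBRef

    reference_map = {}
    items = [DBRef('users', 1), DBRef('users', 2), DBRef('posts', 9)]
    for item in items:
        # Group referenced ids by collection, exactly as the DBRef branch
        # above does with item.collection / item.id.
        reference_map.setdefault(item.collection, []).append(item.id)

    assert reference_map == {'users': [1, 2], 'posts': [9]}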
@@ -96,14 +104,17 @@ def _fetch_objects(self, doc_type=None):
         object_map = {}
         for col, dbrefs in self.reference_map.iteritems():
             keys = object_map.keys()
-            refs = list(set([dbref for dbref in dbrefs if str(dbref) not in keys]))
+            refs = list(set(
+                [dbref for dbref in dbrefs if str(dbref) not in keys]))
             if hasattr(col, 'objects'):  # We have a document class for the refs
                 references = col.objects.in_bulk(refs)
                 for key, doc in references.iteritems():
                     object_map[key] = doc
             else:  # Generic reference: use the refs data to convert to document
-                if doc_type and not isinstance(doc_type, (ListField, DictField, MapField,) ):
-                    references = doc_type._get_db()[col].find({'_id': {'$in': refs}})
+                if doc_type and \
+                        not isinstance(doc_type, (ListField, DictField, MapField,)):  # noqa
+                    references = doc_type._get_db()[col].find(
+                        {'_id': {'$in': refs}})
                     for ref in references:
                         doc = doc_type._from_son(ref)
                         object_map[doc.id] = doc
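For reference, the typed branch above relies on QuerySet.in_bulk, which returns a dict keyed by object id, mirroring how object_map is filled. A hedged sketch; the model and database names are hypothetical and a local mongod is assumed:

    from mongoengine import Document, StringField, connect

    class User(Document):
        name = StringField()

    connect('demo')  # assumes a local mongod

    ids = [u.pk for u in User.objects[:3]]
    # One query per collection: {'_id': {'$in': ids}} under the hood.
    user_map = User.objects.in_bulk(ids)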
@@ -170,13 +181,18 @@ def _attach_objects(self, items, depth=0, instance=None, name=None):
                     if isinstance(v, (DBRef)):
                         data[k]._data[field_name] = self.object_map.get(v.id, v)
                     elif isinstance(v, (dict, SON)) and '_ref' in v:
-                        data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v)
+                        data[k]._data[field_name] = \
+                            self.object_map.get(v['_ref'].id, v)
                     elif isinstance(v, dict) and depth <= self.max_depth:
-                        data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name)
-                    elif isinstance(v, (list, tuple)) and depth <= self.max_depth:
-                        data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name)
+                        data[k]._data[field_name] = self._attach_objects(
+                            v, depth, instance=instance, name=name)
+                    elif isinstance(v, (list, tuple)) and \
+                            depth <= self.max_depth:
+                        data[k]._data[field_name] = self._attach_objects(
+                            v, depth, instance=instance, name=name)
             elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
-                data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name)
+                data[k] = self._attach_objects(
+                    v, depth - 1, instance=instance, name=name)
             elif hasattr(v, 'id'):
                 data[k] = self.object_map.get(v.id, v)

53 changes: 33 additions & 20 deletions mongoengine/document.py
@@ -8,7 +8,8 @@
 from connection import get_db, DEFAULT_CONNECTION_NAME
 
 __all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument',
-           'DynamicEmbeddedDocument', 'OperationError', 'InvalidCollectionError']
+           'DynamicEmbeddedDocument', 'OperationError',
+           'InvalidCollectionError']
 
 
 class InvalidCollectionError(Exception):
@@ -81,14 +82,15 @@ def pk():
"""
def fget(self):
return getattr(self, self._meta['id_field'])

def fset(self, value):
return setattr(self, self._meta['id_field'], value)
return property(fget, fset)

@classmethod
def _get_db(cls):
"""Some Model using other db_alias"""
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME ))
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))

@classmethod
def _get_subdocuments(cls):
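To make the pk property and _get_db concrete, a hedged sketch; the model is hypothetical, and 'analytics' would have to be a registered alias before any database access:

    from bson import ObjectId
    from mongoengine import Document, StringField

    class AuditLog(Document):
        name = StringField()
        meta = {'db_alias': 'analytics'}  # _get_db() resolves this alias

    log = AuditLog(name='boot')
    log.pk = ObjectId()      # fset writes the field named by meta['id_field']
    assert log.pk == log.id  # fget reads it back; 'id' is the default id_field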
@@ -122,7 +124,8 @@ def _get_collection(cls):
                 if options.get('max') != max_documents or \
                    options.get('size') != max_size:
                     msg = ('Cannot create collection "%s" as a capped '
-                           'collection as it already exists') % cls._collection
+                           'collection as it already exists'
+                           % cls._collection)
                     raise InvalidCollectionError(msg)
             else:
                 # Create the collection as a capped collection
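For context, the capped-collection branch above is driven by document meta options; a minimal sketch with a hypothetical model:

    from mongoengine import Document, StringField

    class LogEntry(Document):
        line = StringField()
        # First access creates a capped collection (max 1000 docs, ~2 MB);
        # if a collection of that name already exists with different
        # options, InvalidCollectionError is raised as in the hunk above.
        meta = {'max_documents': 1000, 'max_size': 2000000}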
@@ -139,7 +142,7 @@ def _get_collection(cls):
         return cls._collection
 
     def save(self, force_insert=False, validate=True, write_options=None,
-            cascade=None, cascade_kwargs=None, _refs=None):
+             cascade=None, cascade_kwargs=None, _refs=None):
         """Save the :class:`~mongoengine.Document` to the database. If the
         document already exists, it will be updated, otherwise it will be
         created.
@@ -150,24 +153,28 @@ def save(self, force_insert=False, validate=True, write_options=None,
         :param write_options: Extra keyword arguments are passed down to
             :meth:`~pymongo.collection.Collection.save` OR
             :meth:`~pymongo.collection.Collection.insert`
-            which will be used as options for the resultant ``getLastError`` command.
-            For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
-            have recorded the write and will force an fsync on each server being written to.
-        :param cascade: Sets the flag for cascading saves. You can set a default by setting
-            "cascade" in the document __meta__
-        :param cascade_kwargs: optional kwargs dictionary to be passed throw to cascading saves
+            which will be used as options for the resultant
+            ``getLastError`` command.
+            For example, ``save(..., w=2, fsync=True)`` will wait until at
+            least two servers have recorded the write and will force an
+            fsync on each server being written to.
+        :param cascade: Sets the flag for cascading saves. You can set a
+            default by setting "cascade" in the document __meta__
+        :param cascade_kwargs: optional kwargs dictionary to be passed throw
+            to cascading saves
         :param _refs: A list of processed references used in cascading saves
         .. versionchanged:: 0.5
             In existing documents it only saves changed fields using set / unset
             Saves are cascaded and any :class:`~bson.dbref.DBRef` objects
             that have changes are saved as well.
         .. versionchanged:: 0.6
-            Cascade saves are optional = defaults to True, if you want fine grain
-            control then you can turn off using document meta['cascade'] = False
-            Also you can pass different kwargs to the cascade save using cascade_kwargs
-            which overwrites the existing kwargs with custom values
+            Cascade saves are optional = defaults to True, if you want fine
+            grain control then you can turn off using document
+            meta['cascade'] = False
+            Also you can pass different kwargs to the cascade save using
+            cascade_kwargs which overwrites the existing kwargs with custom
+            values
         """
         signals.pre_save.send(self.__class__, document=self)

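A hedged usage sketch of the rewrapped save options; the document class is hypothetical and a local mongod is assumed:

    from mongoengine import Document, StringField, connect

    class BlogPost(Document):
        title = StringField()

    connect('demo')  # assumes a local mongod

    post = BlogPost(title='hello')
    # getLastError options: wait for two servers and an fsync, per the
    # docstring above.
    post.save(write_options={'w': 2, 'fsync': True})
    # Per-call override of meta['cascade']; cascade_kwargs would be
    # forwarded to the cascaded saves.
    post.save(cascade=False)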
@@ -201,11 +208,17 @@ def save(self, force_insert=False, validate=True, write_options=None,

             upsert = self._created
             if updates:
-                collection.update(select_dict, {"$set": updates}, upsert=upsert, **write_options)
+                collection.update(select_dict,
+                                  {"$set": updates},
+                                  upsert=upsert,
+                                  **write_options)
             if removals:
-                collection.update(select_dict, {"$unset": removals}, upsert=upsert, **write_options)
+                collection.update(select_dict,
+                                  {"$unset": removals},
+                                  upsert=upsert,
+                                  **write_options)
 
-        cascade = self._meta.get('cascade', True) if cascade is None else cascade
+        cascade = self._meta.get('cascade', True) if cascade is None else cascade  # noqa
         if cascade:
             kwargs = {
                 "force_insert": force_insert,
@@ -419,8 +432,8 @@ def object(self):
             try:
                 self.key = id_field_type(self.key)
             except:
-                raise Exception("Could not cast key as %s" % \
-                    id_field_type.__name__)
+                raise Exception(
+                    "Could not cast key as %s" % id_field_type.__name__)
 
         if not hasattr(self, "_key_object"):
             self._key_object = self._document.objects.with_id(self.key)
(Diffs for the remaining changed files are not rendered here.)