Skip to content

Commit d45d6d5

Browse files
authored
Merge pull request #5 from conversocial/dev-equality
feat: noticket - Update Document __eq__/__ne__ to return NotImplemented
2 parents f72b951 + c808684 commit d45d6d5

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

mongoengine/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
__all__ = (document.__all__ + fields.__all__ + connection.__all__ +
1313
queryset.__all__ + signals.__all__)
1414

15-
VERSION = (0, 6, 20)
15+
VERSION = (0, 6, 21)
1616

1717

1818
def get_version():

mongoengine/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,12 +1143,14 @@ def __str__(self):
11431143

11441144
def __eq__(self, other):
11451145
if isinstance(other, self.__class__) and hasattr(other, 'id'):
1146-
if self.id == other.id:
1147-
return True
1148-
return False
1146+
return self.id == other.id
1147+
return NotImplemented
11491148

11501149
def __ne__(self, other):
1151-
return not self.__eq__(other)
1150+
res = self.__eq__(other)
1151+
if res == NotImplemented:
1152+
return NotImplemented
1153+
return not res
11521154

11531155
def __hash__(self):
11541156
if self.pk is None:

tests/document.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,5 +2279,45 @@ def __str__(self):
22792279
}
22802280
) ]), "1,2")
22812281

2282+
2283+
def test_document_equality(self):
2284+
class A(Document):
2285+
name = StringField()
2286+
2287+
class B(Document):
2288+
name = StringField()
2289+
2290+
class SomethingElse(object):
2291+
def __init__(self, id, name):
2292+
self.id = id
2293+
self.name = name
2294+
2295+
a = A(id=1, name="BOOK")
2296+
a_ = A(id=1, name="BOOK")
2297+
b = B(id=1, name="BOOK")
2298+
somethingElse = SomethingElse(id=1, name="BOOK")
2299+
2300+
# ensure __eq__ returns NotImplemented
2301+
self.assertEqual(a.__eq__(a_), True)
2302+
self.assertEqual(a.__eq__(b), NotImplemented)
2303+
self.assertEqual(a.__eq__(somethingElse), NotImplemented)
2304+
2305+
# test equality
2306+
self.assertTrue(a == a_)
2307+
self.assertFalse(a == b)
2308+
self.assertFalse(a == somethingElse)
2309+
2310+
# test __ne__ returns NotImplemented
2311+
self.assertEqual(a.__ne__(a_), False)
2312+
self.assertEqual(a.__ne__(b), NotImplemented)
2313+
self.assertEqual(a.__ne__(somethingElse), NotImplemented)
2314+
2315+
# test not equal
2316+
self.assertFalse(a != a_)
2317+
self.assertTrue(a != b)
2318+
self.assertTrue(a != somethingElse)
2319+
2320+
2321+
22822322
if __name__ == '__main__':
22832323
unittest.main()

0 commit comments

Comments
 (0)