Skip to content

Commit 1996d0f

Browse files
authored
Merge pull request #81 from ambitioninc/develop
0.14.1
2 parents 2406167 + e20addf commit 1996d0f

5 files changed

Lines changed: 172 additions & 24 deletions

File tree

docs/release_notes.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
Release Notes
22
=============
33

4+
v0.14.1
5+
-------
6+
* Fix upsert to handle case when the uniqueness constraint is the pk field
7+
48
v0.14.0
59
-------
610
* Drop support for django 1.7, add official support for python 3.5

querybuilder/query.py

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from copy import deepcopy
22

33
from django.db import connection as default_django_connection
4-
from django.db.models import Q
4+
from django.db.models import Q, AutoField
55
from django.db.models.query import QuerySet
66
from django.db.models.constants import LOOKUP_SEP
77
try:
@@ -1185,17 +1185,22 @@ def get_update_sql(self, rows):
11851185

11861186
return self.sql, sql_args
11871187

1188-
def get_upsert_sql(self, rows, unique_fields, update_fields):
1188+
def get_upsert_sql(self, rows, unique_fields, update_fields, auto_field_name=None, only_insert=False):
11891189
"""
1190-
Performs postgres upsert with multiple rows
1190+
Generates the postgres specific sql necessary to perform an upsert (ON CONFLICT)
11911191
11921192
INSERT INTO table_name (field1, field2)
11931193
VALUES (1, 'two')
11941194
ON CONFLICT (unique_field) DO UPDATE SET field2 = EXCLUDED.field2;
11951195
"""
11961196
ModelClass = self.tables[0].model
1197-
pk_name = ModelClass._meta.pk.column
1198-
all_fields = [field for field in ModelClass._meta.fields if field.column != pk_name]
1197+
1198+
# Use all fields except pk unless the uniqueness constraint is the pk field. Null pk field rows will be
1199+
# excluded in the upsert method before calling this method
1200+
all_fields = [field for field in ModelClass._meta.fields if field.column != auto_field_name]
1201+
if auto_field_name in unique_fields and not only_insert:
1202+
all_fields = [field for field in ModelClass._meta.fields]
1203+
11991204
all_field_names = [field.column for field in all_fields]
12001205
all_field_names_sql = ', '.join(all_field_names)
12011206

@@ -1696,40 +1701,86 @@ def update(self, rows):
16961701
# execute the query
16971702
cursor.execute(sql, sql_args)
16981703

1704+
def get_auto_field_name(self, model_class):
1705+
"""
1706+
If one of the unique_fields is the model's AutoField, return the field name, otherwise return None
1707+
"""
1708+
# Get auto field name (a model can only have one AutoField)
1709+
for field in model_class._meta.fields:
1710+
if isinstance(field, AutoField):
1711+
return field.column
1712+
1713+
return None
1714+
16991715
def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_models=False):
17001716
"""
1701-
Performs an upsert on the set of models defined in rows.
1717+
Performs an upsert with the set of models defined in rows. If the unique field which is meant
1718+
to cause a conflict is an auto increment field, then the field should be excluded when its value is null.
1719+
In this case, an upsert will be performed followed by a bulk_create
17021720
"""
17031721
if len(rows) == 0:
17041722
return
17051723

1706-
sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields)
1724+
ModelClass = self.tables[0].model
17071725

1708-
# get the cursor to execute the query
1709-
cursor = self.get_cursor()
1726+
rows_with_null_auto_field_value = []
17101727

1711-
# execute the query
1712-
cursor.execute(sql, sql_args)
1728+
# Get auto field name (a model can only have one AutoField)
1729+
auto_field_name = self.get_auto_field_name(ModelClass)
1730+
1731+
# Check if unique fields list contains an auto field
1732+
if auto_field_name in unique_fields:
1733+
# Separate the rows that need to be inserted vs the rows that need to be upserted
1734+
rows_with_null_auto_field_value = [row for row in rows if getattr(row, auto_field_name) is None]
1735+
rows = [row for row in rows if getattr(row, auto_field_name) is not None]
1736+
1737+
return_value = []
1738+
1739+
if rows:
1740+
sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields, auto_field_name=auto_field_name)
1741+
1742+
# get the cursor to execute the query
1743+
cursor = self.get_cursor()
1744+
1745+
# execute the upsert query
1746+
cursor.execute(sql, sql_args)
1747+
1748+
if return_rows or return_models:
1749+
return_value.extend(self._fetch_all_as_dict(cursor))
1750+
1751+
if rows_with_null_auto_field_value:
1752+
sql, sql_args = self.get_upsert_sql(
1753+
rows_with_null_auto_field_value,
1754+
unique_fields,
1755+
update_fields,
1756+
auto_field_name=auto_field_name,
1757+
only_insert=True,
1758+
)
1759+
1760+
# get the cursor to execute the query
1761+
cursor = self.get_cursor()
1762+
1763+
# execute the upsert query
1764+
cursor.execute(sql, sql_args)
17131765

1714-
if return_rows:
1715-
return self._fetch_all_as_dict(cursor)
1766+
if return_rows or return_models:
1767+
return_value.extend(self._fetch_all_as_dict(cursor))
17161768

17171769
if return_models:
1718-
row_dicts = self._fetch_all_as_dict(cursor)
17191770
ModelClass = self.tables[0].model
17201771
model_objects = [
17211772
ModelClass(**row_dict)
1722-
for row_dict in row_dicts
1773+
for row_dict in return_value
17231774
]
17241775

17251776
# Set the state to indicate the object has been loaded from db
17261777
for model_object in model_objects:
17271778
model_object._state.adding = False
17281779
model_object._state.db = 'default'
17291780

1730-
return model_objects
1781+
return_value = model_objects
17311782

1732-
return []
1783+
return return_value
17331784

17341785
def sql_delete(self):
17351786
"""

querybuilder/tests/upsert_tests.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from django.test.utils import override_settings
22
from django import VERSION
3+
from django_dynamic_fixture import G
34

45
from querybuilder.logger import Logger
56
from querybuilder.query import Query
6-
from querybuilder.tests.models import Uniques
7+
from querybuilder.tests.models import Uniques, User
78
from querybuilder.tests.query_tests import QueryTestCase
89

910

1011
@override_settings(DEBUG=True)
11-
class TestUpdate(QueryTestCase):
12+
class TestUpsert(QueryTestCase):
1213

1314
def setUp(self):
1415
self.logger = Logger()
@@ -123,3 +124,95 @@ def test_upsert(self):
123124
self.assertEqual(models[1].field5, 'not null')
124125
self.assertEqual(models[1].field6, '2.6')
125126
self.assertEqual(models[1].field7, '2.7')
127+
128+
def test_upsert_pk(self):
129+
"""
130+
Makes sure upserting is possible when the only uniqueness constraint is the pk.
131+
"""
132+
user1 = G(User, email='user1')
133+
user1.email = 'user1change'
134+
user2 = User(email='user2')
135+
user3 = User(email='user3')
136+
137+
self.assertEqual(User.objects.count(), 1)
138+
Query().from_table(User).upsert(
139+
[user1, user2, user3],
140+
unique_fields=['id'],
141+
update_fields=['email'],
142+
)
143+
self.assertEqual(User.objects.count(), 3)
144+
145+
users = list(User.objects.order_by('id'))
146+
147+
self.assertEqual(users[0].email, 'user1change')
148+
self.assertEqual(users[1].email, 'user2')
149+
self.assertEqual(users[2].email, 'user3')
150+
151+
def test_upsert_pk_return_dicts(self):
152+
"""
153+
Makes sure upserting is possible when the only uniqueness constraint is the pk. Should return dicts.
154+
"""
155+
user1 = G(User, email='user1')
156+
user1.email = 'user1change'
157+
user2 = User(email='user2')
158+
user3 = User(email='user3')
159+
160+
self.assertEqual(User.objects.count(), 1)
161+
rows = Query().from_table(User).upsert(
162+
[user1, user2, user3],
163+
unique_fields=['id'],
164+
update_fields=['email'],
165+
return_rows=True,
166+
)
167+
self.assertEqual(User.objects.count(), 3)
168+
self.assertEqual(len(rows), 3)
169+
170+
# Check ids
171+
for row in rows:
172+
self.assertIsNotNone(row['id'])
173+
174+
# Check emails
175+
email_set = {
176+
row['email'] for row in rows
177+
}
178+
self.assertEqual(email_set, {'user1change', 'user2', 'user3'})
179+
180+
# Check fields from db
181+
users = list(User.objects.order_by('id'))
182+
self.assertEqual(users[0].email, 'user1change')
183+
self.assertEqual(users[1].email, 'user2')
184+
self.assertEqual(users[2].email, 'user3')
185+
186+
def test_upsert_pk_return_models(self):
187+
"""
188+
Makes sure upserting is possible when the only uniqueness constraint is the pk. Should return models.
189+
"""
190+
user1 = G(User, email='user1')
191+
user1.email = 'user1change'
192+
user2 = User(email='user2')
193+
user3 = User(email='user3')
194+
195+
self.assertEqual(User.objects.count(), 1)
196+
records = Query().from_table(User).upsert(
197+
[user1, user2, user3],
198+
unique_fields=['id'],
199+
update_fields=['email'],
200+
return_models=True,
201+
)
202+
self.assertEqual(len(records), 3)
203+
204+
# Check ids
205+
for record in records:
206+
self.assertIsNotNone(record.id)
207+
208+
# Check emails
209+
email_set = {
210+
record.email for record in records
211+
}
212+
self.assertEqual(email_set, {'user1change', 'user2', 'user3'})
213+
214+
# Check fields from db
215+
users = list(User.objects.order_by('id'))
216+
self.assertEqual(users[0].email, 'user1change')
217+
self.assertEqual(users[1].email, 'user2')
218+
self.assertEqual(users[2].email, 'user3')

querybuilder/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.14.0'
1+
__version__ = '0.14.1'

settings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ def configure_settings():
1313
if test_db is None:
1414
db_config = {
1515
'ENGINE': 'django.db.backends.postgresql_psycopg2',
16-
'NAME': 'ambition_dev',
17-
'USER': 'ambition_dev',
18-
'PASSWORD': 'ambition_dev',
19-
'HOST': 'localhost'
16+
'NAME': 'ambition',
17+
'USER': 'ambition',
18+
'PASSWORD': 'ambition',
19+
'HOST': 'db'
2020
}
2121
elif test_db == 'postgres':
2222
db_config = {

0 commit comments

Comments
 (0)