|
1 | 1 | from copy import deepcopy |
2 | 2 |
|
3 | 3 | from django.db import connection as default_django_connection |
4 | | -from django.db.models import Q |
| 4 | +from django.db.models import Q, AutoField |
5 | 5 | from django.db.models.query import QuerySet |
6 | 6 | from django.db.models.constants import LOOKUP_SEP |
7 | 7 | try: |
@@ -1185,17 +1185,22 @@ def get_update_sql(self, rows): |
1185 | 1185 |
|
1186 | 1186 | return self.sql, sql_args |
1187 | 1187 |
|
1188 | | - def get_upsert_sql(self, rows, unique_fields, update_fields): |
| 1188 | + def get_upsert_sql(self, rows, unique_fields, update_fields, auto_field_name=None, only_insert=False): |
1189 | 1189 | """ |
1190 | | - Performs postgres upsert with multiple rows |
| 1190 | + Generates the postgres specific sql necessary to perform an upsert (ON CONFLICT) |
1191 | 1191 |
|
1192 | 1192 | INSERT INTO table_name (field1, field2) |
1193 | 1193 | VALUES (1, 'two') |
1194 | 1194 | ON CONFLICT (unique_field) DO UPDATE SET field2 = EXCLUDED.field2; |
1195 | 1195 | """ |
1196 | 1196 | ModelClass = self.tables[0].model |
1197 | | - pk_name = ModelClass._meta.pk.column |
1198 | | - all_fields = [field for field in ModelClass._meta.fields if field.column != pk_name] |
| 1197 | + |
| 1198 | + # Use all fields except pk unless the uniqueness constraint is the pk field. Null pk field rows will be |
| 1199 | + # excluded in the upsert method before calling this method |
| 1200 | + all_fields = [field for field in ModelClass._meta.fields if field.column != auto_field_name] |
| 1201 | + if auto_field_name in unique_fields and not only_insert: |
| 1202 | + all_fields = [field for field in ModelClass._meta.fields] |
| 1203 | + |
1199 | 1204 | all_field_names = [field.column for field in all_fields] |
1200 | 1205 | all_field_names_sql = ', '.join(all_field_names) |
1201 | 1206 |
|
@@ -1696,40 +1701,86 @@ def update(self, rows): |
1696 | 1701 | # execute the query |
1697 | 1702 | cursor.execute(sql, sql_args) |
1698 | 1703 |
|
| 1704 | + def get_auto_field_name(self, model_class): |
| 1705 | + """ |
| 1706 | + If one of the unique_fields is the model's AutoField, return the field name, otherwise return None |
| 1707 | + """ |
| 1708 | + # Get auto field name (a model can only have one AutoField) |
| 1709 | + for field in model_class._meta.fields: |
| 1710 | + if isinstance(field, AutoField): |
| 1711 | + return field.column |
| 1712 | + |
| 1713 | + return None |
| 1714 | + |
1699 | 1715 | def upsert(self, rows, unique_fields, update_fields, return_rows=False, return_models=False): |
1700 | 1716 | """ |
1701 | | - Performs an upsert on the set of models defined in rows. |
| 1717 | + Performs an upsert with the set of models defined in rows. If the unique field which is meant |
| 1718 | + to cause a conflict is an auto increment field, then the field should be excluded when its value is null. |
| 1719 | + In this case, an upsert will be performed followed by a bulk_create |
1702 | 1720 | """ |
1703 | 1721 | if len(rows) == 0: |
1704 | 1722 | return |
1705 | 1723 |
|
1706 | | - sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields) |
| 1724 | + ModelClass = self.tables[0].model |
1707 | 1725 |
|
1708 | | - # get the cursor to execute the query |
1709 | | - cursor = self.get_cursor() |
| 1726 | + rows_with_null_auto_field_value = [] |
1710 | 1727 |
|
1711 | | - # execute the query |
1712 | | - cursor.execute(sql, sql_args) |
| 1728 | + # Get auto field name (a model can only have one AutoField) |
| 1729 | + auto_field_name = self.get_auto_field_name(ModelClass) |
| 1730 | + |
| 1731 | + # Check if unique fields list contains an auto field |
| 1732 | + if auto_field_name in unique_fields: |
| 1733 | + # Separate the rows that need to be inserted vs the rows that need to be upserted |
| 1734 | + rows_with_null_auto_field_value = [row for row in rows if getattr(row, auto_field_name) is None] |
| 1735 | + rows = [row for row in rows if getattr(row, auto_field_name) is not None] |
| 1736 | + |
| 1737 | + return_value = [] |
| 1738 | + |
| 1739 | + if rows: |
| 1740 | + sql, sql_args = self.get_upsert_sql(rows, unique_fields, update_fields, auto_field_name=auto_field_name) |
| 1741 | + |
| 1742 | + # get the cursor to execute the query |
| 1743 | + cursor = self.get_cursor() |
| 1744 | + |
| 1745 | + # execute the upsert query |
| 1746 | + cursor.execute(sql, sql_args) |
| 1747 | + |
| 1748 | + if return_rows or return_models: |
| 1749 | + return_value.extend(self._fetch_all_as_dict(cursor)) |
| 1750 | + |
| 1751 | + if rows_with_null_auto_field_value: |
| 1752 | + sql, sql_args = self.get_upsert_sql( |
| 1753 | + rows_with_null_auto_field_value, |
| 1754 | + unique_fields, |
| 1755 | + update_fields, |
| 1756 | + auto_field_name=auto_field_name, |
| 1757 | + only_insert=True, |
| 1758 | + ) |
| 1759 | + |
| 1760 | + # get the cursor to execute the query |
| 1761 | + cursor = self.get_cursor() |
| 1762 | + |
| 1763 | + # execute the upsert query |
| 1764 | + cursor.execute(sql, sql_args) |
1713 | 1765 |
|
1714 | | - if return_rows: |
1715 | | - return self._fetch_all_as_dict(cursor) |
| 1766 | + if return_rows or return_models: |
| 1767 | + return_value.extend(self._fetch_all_as_dict(cursor)) |
1716 | 1768 |
|
1717 | 1769 | if return_models: |
1718 | | - row_dicts = self._fetch_all_as_dict(cursor) |
1719 | 1770 | ModelClass = self.tables[0].model |
1720 | 1771 | model_objects = [ |
1721 | 1772 | ModelClass(**row_dict) |
1722 | | - for row_dict in row_dicts |
| 1773 | + for row_dict in return_value |
1723 | 1774 | ] |
1724 | 1775 |
|
1725 | 1776 | # Set the state to indicate the object has been loaded from db |
1726 | 1777 | for model_object in model_objects: |
1727 | 1778 | model_object._state.adding = False |
1728 | 1779 | model_object._state.db = 'default' |
1729 | 1780 |
|
1730 | | - return model_objects |
| 1781 | + return_value = model_objects |
1731 | 1782 |
|
1732 | | - return [] |
| 1783 | + return return_value |
1733 | 1784 |
|
1734 | 1785 | def sql_delete(self): |
1735 | 1786 | """ |
|
0 commit comments