|
1 | 1 | from copy import deepcopy
|
2 | 2 |
|
| 3 | +from django import VERSION |
3 | 4 | from django.db import connection as default_django_connection
|
4 | 5 | from django.db.models import Q, AutoField
|
5 | 6 | from django.db.models.query import QuerySet
|
|
10 | 11 |
|
11 | 12 |
|
12 | 13 | from querybuilder.fields import FieldFactory, CountField, MaxField, MinField, SumField, AvgField
|
13 |
| -from querybuilder.helpers import set_value_for_keypath |
| 14 | +from querybuilder.helpers import set_value_for_keypath, copy_instance |
14 | 15 | from querybuilder.tables import TableFactory, ModelTable, QueryTable
|
15 | 16 |
|
16 | 17 | SERIAL_DTYPES = ['serial', 'bigserial']
|
@@ -1119,6 +1120,19 @@ def get_insert_sql(self, rows):
|
1119 | 1120 |
|
1120 | 1121 | return self.sql, sql_args
|
1121 | 1122 |
|
| 1123 | + def should_not_cast_value(self, field_object): |
| 1124 | + """ |
| 1125 | + In Django 4.1 on PostgreSQL, AutoField, BigAutoField, and SmallAutoField are now created as identity |
| 1126 | + columns rather than serial columns with sequences. |
| 1127 | + """ |
| 1128 | + db_type = field_object.db_type(self.connection) |
| 1129 | + if db_type in SERIAL_DTYPES: |
| 1130 | + return True |
| 1131 | + if (VERSION[0] == 4 and VERSION[1] >= 1) or VERSION[0] >= 5: |
| 1132 | + if getattr(field_object, 'primary_key', None) and getattr(field_object, 'serialize', None) is False: |
| 1133 | + return True |
| 1134 | + return False |
| 1135 | + |
1122 | 1136 | def get_update_sql(self, rows):
|
1123 | 1137 | """
|
1124 | 1138 | Returns SQL UPDATE for rows ``rows``
|
@@ -1171,8 +1185,8 @@ def get_update_sql(self, rows):
|
1171 | 1185 | field_object = self.tables[0].model._meta.get_field(field_names[field_index])
|
1172 | 1186 | db_type = field_object.db_type(self.connection)
|
1173 | 1187 |
|
1174 |
| - # Don't cast the pk |
1175 |
| - if db_type in SERIAL_DTYPES: |
| 1188 | + # Don't cast serial types |
| 1189 | + if self.should_not_cast_value(field_object): |
1176 | 1190 | placeholders.append('%s')
|
1177 | 1191 | else:
|
1178 | 1192 | # Cast the placeholder to the data type
|
@@ -1536,7 +1550,7 @@ def wrap(self, alias=None):
|
1536 | 1550 | :return: The wrapped query
|
1537 | 1551 | """
|
1538 | 1552 | field_names = self.get_field_names()
|
1539 |
| - query = Query(self.connection).from_table(deepcopy(self), alias=alias) |
| 1553 | + query = Query(self.connection).from_table(copy_instance(self), alias=alias) |
1540 | 1554 | self.__dict__.update(query.__dict__)
|
1541 | 1555 |
|
1542 | 1556 | # set explicit field names
|
|
0 commit comments