|
1 | 1 | from copy import deepcopy |
2 | 2 |
|
| 3 | +from django import VERSION |
3 | 4 | from django.db import connection as default_django_connection |
4 | 5 | from django.db.models import Q, AutoField |
5 | 6 | from django.db.models.query import QuerySet |
|
10 | 11 |
|
11 | 12 |
|
12 | 13 | from querybuilder.fields import FieldFactory, CountField, MaxField, MinField, SumField, AvgField |
13 | | -from querybuilder.helpers import set_value_for_keypath |
| 14 | +from querybuilder.helpers import set_value_for_keypath, copy_instance |
14 | 15 | from querybuilder.tables import TableFactory, ModelTable, QueryTable |
15 | 16 |
|
16 | 17 | SERIAL_DTYPES = ['serial', 'bigserial'] |
@@ -1119,6 +1120,19 @@ def get_insert_sql(self, rows): |
1119 | 1120 |
|
1120 | 1121 | return self.sql, sql_args |
1121 | 1122 |
|
| 1123 | + def should_not_cast_value(self, field_object): |
| 1124 | + """ |
| 1125 | + In Django 4.1 on PostgreSQL, AutoField, BigAutoField, and SmallAutoField are now created as identity |
| 1126 | + columns rather than serial columns with sequences. |
| 1127 | + """ |
| 1128 | + db_type = field_object.db_type(self.connection) |
| 1129 | + if db_type in SERIAL_DTYPES: |
| 1130 | + return True |
| 1131 | + if (VERSION[0] == 4 and VERSION[1] >= 1) or VERSION[0] >= 5: |
| 1132 | + if getattr(field_object, 'primary_key', None) and getattr(field_object, 'serialize', None) is False: |
| 1133 | + return True |
| 1134 | + return False |
| 1135 | + |
1122 | 1136 | def get_update_sql(self, rows): |
1123 | 1137 | """ |
1124 | 1138 | Returns SQL UPDATE for rows ``rows`` |
@@ -1171,8 +1185,8 @@ def get_update_sql(self, rows): |
1171 | 1185 | field_object = self.tables[0].model._meta.get_field(field_names[field_index]) |
1172 | 1186 | db_type = field_object.db_type(self.connection) |
1173 | 1187 |
|
1174 | | - # Don't cast the pk |
1175 | | - if db_type in SERIAL_DTYPES: |
| 1188 | + # Don't cast serial types |
| 1189 | + if self.should_not_cast_value(field_object): |
1176 | 1190 | placeholders.append('%s') |
1177 | 1191 | else: |
1178 | 1192 | # Cast the placeholder to the data type |
@@ -1536,7 +1550,7 @@ def wrap(self, alias=None): |
1536 | 1550 | :return: The wrapped query |
1537 | 1551 | """ |
1538 | 1552 | field_names = self.get_field_names() |
1539 | | - query = Query(self.connection).from_table(deepcopy(self), alias=alias) |
| 1553 | + query = Query(self.connection).from_table(copy_instance(self), alias=alias) |
1540 | 1554 | self.__dict__.update(query.__dict__) |
1541 | 1555 |
|
1542 | 1556 | # set explicit field names |
|
0 commit comments