@@ -420,7 +420,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
420
420
} ) ;
421
421
422
422
// return only the ids of the top-level entity
423
- const ids = this . utils . getEntityIds ( this . model , result ) ;
423
+ const ids = this . utils . getEntityIds ( model , result ) ;
424
424
return { result : ids , postWriteChecks : [ ...postCreateChecks . values ( ) ] } ;
425
425
}
426
426
@@ -792,8 +792,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
792
792
}
793
793
794
794
// proceed with the create and collect post-create checks
795
- const { postWriteChecks : checks } = await this . doCreate ( model , { data : createData } , db ) ;
795
+ const { postWriteChecks : checks , result } = await this . doCreate ( model , { data : createData } , db ) ;
796
796
postWriteChecks . push ( ...checks ) ;
797
+
798
+ return result ;
797
799
} ;
798
800
799
801
const _createMany = async (
@@ -881,18 +883,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
881
883
// check pre-update guard
882
884
await this . utils . checkPolicyForUnique ( model , uniqueFilter , 'update' , db , args ) ;
883
885
884
- // handles the case where id fields are updated
885
- const postUpdateIds = this . utils . clone ( existing ) ;
886
- for ( const key of Object . keys ( existing ) ) {
887
- const updateValue = ( args as any ) . data ? ( args as any ) . data [ key ] : ( args as any ) [ key ] ;
888
- if (
889
- typeof updateValue === 'string' ||
890
- typeof updateValue === 'number' ||
891
- typeof updateValue === 'bigint'
892
- ) {
893
- postUpdateIds [ key ] = updateValue ;
894
- }
895
- }
886
+ // handle the case where id fields are updated
887
+ const _args : any = args ;
888
+ const updatePayload = _args . data && typeof _args . data === 'object' ? _args . data : _args ;
889
+ const postUpdateIds = this . calculatePostUpdateIds ( model , existing , updatePayload ) ;
896
890
897
891
// register post-update check
898
892
await _registerPostUpdateCheck ( model , existing , postUpdateIds ) ;
@@ -984,10 +978,13 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
984
978
// update case
985
979
986
980
// check pre-update guard
987
- await this . utils . checkPolicyForUnique ( model , uniqueFilter , 'update' , db , args ) ;
981
+ await this . utils . checkPolicyForUnique ( model , existing , 'update' , db , args ) ;
982
+
983
+ // handle the case where id fields are updated
984
+ const postUpdateIds = this . calculatePostUpdateIds ( model , existing , args . update ) ;
988
985
989
986
// register post-update check
990
- await _registerPostUpdateCheck ( model , uniqueFilter , uniqueFilter ) ;
987
+ await _registerPostUpdateCheck ( model , existing , postUpdateIds ) ;
991
988
992
989
// convert upsert to update
993
990
const convertedUpdate = {
@@ -1021,9 +1018,22 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
1021
1018
if ( existing ) {
1022
1019
// connect
1023
1020
await _connectDisconnect ( model , args . where , context ) ;
1021
+ return true ;
1024
1022
} else {
1025
1023
// create
1026
- await _create ( model , args . create , context ) ;
1024
+ const created = await _create ( model , args . create , context ) ;
1025
+
1026
+ const upperContext = context . nestingPath [ context . nestingPath . length - 2 ] ;
1027
+ if ( upperContext ?. where && context . field ) {
1028
+ // check if the where clause of the upper context references the id
1029
+ // of the connected entity, if so, we need to update it
1030
+ this . overrideForeignKeyFields ( upperContext . model , upperContext . where , context . field , created ) ;
1031
+ }
1032
+
1033
+ // remove the payload from the parent
1034
+ this . removeFromParent ( context . parent , 'connectOrCreate' , args ) ;
1035
+
1036
+ return false ;
1027
1037
}
1028
1038
} ,
1029
1039
@@ -1093,6 +1103,52 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
1093
1103
return { result, postWriteChecks } ;
1094
1104
}
1095
1105
1106
+ // calculate id fields used for post-update check given an update payload
1107
+ private calculatePostUpdateIds ( _model : string , currentIds : any , updatePayload : any ) {
1108
+ const result = this . utils . clone ( currentIds ) ;
1109
+ for ( const key of Object . keys ( currentIds ) ) {
1110
+ const updateValue = updatePayload [ key ] ;
1111
+ if ( typeof updateValue === 'string' || typeof updateValue === 'number' || typeof updateValue === 'bigint' ) {
1112
+ result [ key ] = updateValue ;
1113
+ }
1114
+ }
1115
+ return result ;
1116
+ }
1117
+
1118
+ // updates foreign key fields inside `payload` based on relation id fields in `newIds`
1119
+ private overrideForeignKeyFields (
1120
+ model : string ,
1121
+ payload : any ,
1122
+ relation : FieldInfo ,
1123
+ newIds : Record < string , unknown >
1124
+ ) {
1125
+ if ( ! relation . foreignKeyMapping || Object . keys ( relation . foreignKeyMapping ) . length === 0 ) {
1126
+ return ;
1127
+ }
1128
+
1129
+ // override foreign key values
1130
+ for ( const [ id , fk ] of Object . entries ( relation . foreignKeyMapping ) ) {
1131
+ if ( payload [ fk ] !== undefined && newIds [ id ] !== undefined ) {
1132
+ payload [ fk ] = newIds [ id ] ;
1133
+ }
1134
+ }
1135
+
1136
+ // deal with compound id fields
1137
+ const uniqueConstraints = this . utils . getUniqueConstraints ( model ) ;
1138
+ for ( const [ name , constraint ] of Object . entries ( uniqueConstraints ) ) {
1139
+ if ( constraint . fields . length > 1 ) {
1140
+ const target = payload [ name ] ;
1141
+ if ( target ) {
1142
+ for ( const [ id , fk ] of Object . entries ( relation . foreignKeyMapping ) ) {
1143
+ if ( target [ fk ] !== undefined && newIds [ id ] !== undefined ) {
1144
+ target [ fk ] = newIds [ id ] ;
1145
+ }
1146
+ }
1147
+ }
1148
+ }
1149
+ }
1150
+ }
1151
+
1096
1152
// Validates the given update payload against Zod schema if any
1097
1153
private validateUpdateInputSchema ( model : string , data : any ) {
1098
1154
const schema = this . utils . getZodSchema ( model , 'update' ) ;
@@ -1224,11 +1280,18 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
1224
1280
1225
1281
const { result, error } = await this . transaction ( async ( tx ) => {
1226
1282
const { where, create, update, ...rest } = args ;
1227
- const existing = await this . utils . checkExistence ( tx , this . model , args . where ) ;
1283
+ const existing = await this . utils . checkExistence ( tx , this . model , where ) ;
1228
1284
1229
1285
if ( existing ) {
1230
1286
// update case
1231
- const { result, postWriteChecks } = await this . doUpdate ( { where, data : update , ...rest } , tx ) ;
1287
+ const { result, postWriteChecks } = await this . doUpdate (
1288
+ {
1289
+ where : this . utils . composeCompoundUniqueField ( this . model , existing ) ,
1290
+ data : update ,
1291
+ ...rest ,
1292
+ } ,
1293
+ tx
1294
+ ) ;
1232
1295
await this . runPostWriteChecks ( postWriteChecks , tx ) ;
1233
1296
return this . utils . readBack ( tx , this . model , 'update' , args , result ) ;
1234
1297
} else {
0 commit comments