Skip to content

Commit d8c1513

Browse files
authored
fix(runtime): always use id fields to address existing entity during upsert (#1273)
1 parent 8137481 commit d8c1513

File tree

11 files changed

+742
-51
lines changed

11 files changed

+742
-51
lines changed

Diff for: packages/plugins/swr/tests/test-model-meta.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ export const modelMeta: ModelMeta = {
4343
ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true },
4444
},
4545
},
46-
uniqueConstraints: {},
46+
uniqueConstraints: {
47+
user: { id: { name: 'id', fields: ['id'] } },
48+
post: { id: { name: 'id', fields: ['id'] } },
49+
},
4750
deleteCascade: {
4851
user: ['Post'],
4952
},

Diff for: packages/plugins/tanstack-query/tests/test-model-meta.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ export const modelMeta: ModelMeta = {
4343
ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true },
4444
},
4545
},
46-
uniqueConstraints: {},
46+
uniqueConstraints: {
47+
user: { id: { name: 'id', fields: ['id'] } },
48+
post: { id: { name: 'id', fields: ['id'] } },
49+
},
4750
deleteCascade: {
4851
user: ['Post'],
4952
},

Diff for: packages/runtime/src/cross/utils.ts

+9-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { lowerCaseFirst } from 'lower-case-first';
2-
import { ModelMeta } from '.';
2+
import { ModelMeta, requireField } from '.';
33

44
/**
55
* Gets field names in a data model entity, filtering out internal fields.
@@ -47,17 +47,15 @@ export function zip<T1, T2>(x: Enumerable<T1>, y: Enumerable<T2>): Array<[T1, T2
4747
}
4848

4949
export function getIdFields(modelMeta: ModelMeta, model: string, throwIfNotFound = false) {
50-
let fields = modelMeta.fields[lowerCaseFirst(model)];
51-
if (!fields) {
50+
const uniqueConstraints = modelMeta.uniqueConstraints[lowerCaseFirst(model)] ?? {};
51+
52+
const entries = Object.values(uniqueConstraints);
53+
if (entries.length === 0) {
5254
if (throwIfNotFound) {
53-
throw new Error(`Unable to load fields for ${model}`);
54-
} else {
55-
fields = {};
55+
throw new Error(`Model ${model} does not have any id field`);
5656
}
57+
return [];
5758
}
58-
const result = Object.values(fields).filter((f) => f.isId);
59-
if (result.length === 0 && throwIfNotFound) {
60-
throw new Error(`model ${model} does not have an id field`);
61-
}
62-
return result;
59+
60+
return entries[0].fields.map((f) => requireField(modelMeta, model, f));
6361
}

Diff for: packages/runtime/src/enhancements/policy/handler.ts

+82-19
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
420420
});
421421

422422
// return only the ids of the top-level entity
423-
const ids = this.utils.getEntityIds(this.model, result);
423+
const ids = this.utils.getEntityIds(model, result);
424424
return { result: ids, postWriteChecks: [...postCreateChecks.values()] };
425425
}
426426

@@ -792,8 +792,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
792792
}
793793

794794
// proceed with the create and collect post-create checks
795-
const { postWriteChecks: checks } = await this.doCreate(model, { data: createData }, db);
795+
const { postWriteChecks: checks, result } = await this.doCreate(model, { data: createData }, db);
796796
postWriteChecks.push(...checks);
797+
798+
return result;
797799
};
798800

799801
const _createMany = async (
@@ -881,18 +883,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
881883
// check pre-update guard
882884
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args);
883885

884-
// handles the case where id fields are updated
885-
const postUpdateIds = this.utils.clone(existing);
886-
for (const key of Object.keys(existing)) {
887-
const updateValue = (args as any).data ? (args as any).data[key] : (args as any)[key];
888-
if (
889-
typeof updateValue === 'string' ||
890-
typeof updateValue === 'number' ||
891-
typeof updateValue === 'bigint'
892-
) {
893-
postUpdateIds[key] = updateValue;
894-
}
895-
}
886+
// handle the case where id fields are updated
887+
const _args: any = args;
888+
const updatePayload = _args.data && typeof _args.data === 'object' ? _args.data : _args;
889+
const postUpdateIds = this.calculatePostUpdateIds(model, existing, updatePayload);
896890

897891
// register post-update check
898892
await _registerPostUpdateCheck(model, existing, postUpdateIds);
@@ -984,10 +978,13 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
984978
// update case
985979

986980
// check pre-update guard
987-
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args);
981+
await this.utils.checkPolicyForUnique(model, existing, 'update', db, args);
982+
983+
// handle the case where id fields are updated
984+
const postUpdateIds = this.calculatePostUpdateIds(model, existing, args.update);
988985

989986
// register post-update check
990-
await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter);
987+
await _registerPostUpdateCheck(model, existing, postUpdateIds);
991988

992989
// convert upsert to update
993990
const convertedUpdate = {
@@ -1021,9 +1018,22 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
10211018
if (existing) {
10221019
// connect
10231020
await _connectDisconnect(model, args.where, context);
1021+
return true;
10241022
} else {
10251023
// create
1026-
await _create(model, args.create, context);
1024+
const created = await _create(model, args.create, context);
1025+
1026+
const upperContext = context.nestingPath[context.nestingPath.length - 2];
1027+
if (upperContext?.where && context.field) {
1028+
// check if the where clause of the upper context references the id
1029+
// of the connected entity, if so, we need to update it
1030+
this.overrideForeignKeyFields(upperContext.model, upperContext.where, context.field, created);
1031+
}
1032+
1033+
// remove the payload from the parent
1034+
this.removeFromParent(context.parent, 'connectOrCreate', args);
1035+
1036+
return false;
10271037
}
10281038
},
10291039

@@ -1093,6 +1103,52 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
10931103
return { result, postWriteChecks };
10941104
}
10951105

1106+
// calculate id fields used for post-update check given an update payload
1107+
private calculatePostUpdateIds(_model: string, currentIds: any, updatePayload: any) {
1108+
const result = this.utils.clone(currentIds);
1109+
for (const key of Object.keys(currentIds)) {
1110+
const updateValue = updatePayload[key];
1111+
if (typeof updateValue === 'string' || typeof updateValue === 'number' || typeof updateValue === 'bigint') {
1112+
result[key] = updateValue;
1113+
}
1114+
}
1115+
return result;
1116+
}
1117+
1118+
// updates foreign key fields inside `payload` based on relation id fields in `newIds`
1119+
private overrideForeignKeyFields(
1120+
model: string,
1121+
payload: any,
1122+
relation: FieldInfo,
1123+
newIds: Record<string, unknown>
1124+
) {
1125+
if (!relation.foreignKeyMapping || Object.keys(relation.foreignKeyMapping).length === 0) {
1126+
return;
1127+
}
1128+
1129+
// override foreign key values
1130+
for (const [id, fk] of Object.entries(relation.foreignKeyMapping)) {
1131+
if (payload[fk] !== undefined && newIds[id] !== undefined) {
1132+
payload[fk] = newIds[id];
1133+
}
1134+
}
1135+
1136+
// deal with compound id fields
1137+
const uniqueConstraints = this.utils.getUniqueConstraints(model);
1138+
for (const [name, constraint] of Object.entries(uniqueConstraints)) {
1139+
if (constraint.fields.length > 1) {
1140+
const target = payload[name];
1141+
if (target) {
1142+
for (const [id, fk] of Object.entries(relation.foreignKeyMapping)) {
1143+
if (target[fk] !== undefined && newIds[id] !== undefined) {
1144+
target[fk] = newIds[id];
1145+
}
1146+
}
1147+
}
1148+
}
1149+
}
1150+
}
1151+
10961152
// Validates the given update payload against Zod schema if any
10971153
private validateUpdateInputSchema(model: string, data: any) {
10981154
const schema = this.utils.getZodSchema(model, 'update');
@@ -1224,11 +1280,18 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
12241280

12251281
const { result, error } = await this.transaction(async (tx) => {
12261282
const { where, create, update, ...rest } = args;
1227-
const existing = await this.utils.checkExistence(tx, this.model, args.where);
1283+
const existing = await this.utils.checkExistence(tx, this.model, where);
12281284

12291285
if (existing) {
12301286
// update case
1231-
const { result, postWriteChecks } = await this.doUpdate({ where, data: update, ...rest }, tx);
1287+
const { result, postWriteChecks } = await this.doUpdate(
1288+
{
1289+
where: this.utils.composeCompoundUniqueField(this.model, existing),
1290+
data: update,
1291+
...rest,
1292+
},
1293+
tx
1294+
);
12321295
await this.runPostWriteChecks(postWriteChecks, tx);
12331296
return this.utils.readBack(tx, this.model, 'update', args, result);
12341297
} else {

Diff for: packages/runtime/src/enhancements/policy/policy-utils.ts

+30
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,27 @@ export class PolicyUtil {
569569
}
570570
}
571571

572+
composeCompoundUniqueField(model: string, fieldData: any) {
573+
const uniqueConstraints = this.modelMeta.uniqueConstraints?.[lowerCaseFirst(model)];
574+
if (!uniqueConstraints) {
575+
return fieldData;
576+
}
577+
578+
// e.g.: { a: '1', b: '1' } => { a_b: { a: '1', b: '1' } }
579+
const result: any = this.clone(fieldData);
580+
for (const [name, constraint] of Object.entries(uniqueConstraints)) {
581+
if (constraint.fields.length > 1 && constraint.fields.every((f) => fieldData[f] !== undefined)) {
582+
// multi-field unique constraint, compose it
583+
result[name] = constraint.fields.reduce<any>(
584+
(prev, field) => ({ ...prev, [field]: fieldData[field] }),
585+
{}
586+
);
587+
constraint.fields.forEach((f) => delete result[f]);
588+
}
589+
}
590+
return result;
591+
}
592+
572593
/**
573594
* Gets unique constraints for the given model.
574595
*/
@@ -642,6 +663,15 @@ export class PolicyUtil {
642663
// preserve the original structure
643664
currQuery[currField.backLink] = { ...visitWhere };
644665
}
666+
667+
if (forMutationPayload && currQuery[currField.backLink]) {
668+
// reconstruct compound unique field
669+
currQuery[currField.backLink] = this.composeCompoundUniqueField(
670+
backLinkField.type,
671+
currQuery[currField.backLink]
672+
);
673+
}
674+
645675
currQuery = currQuery[currField.backLink];
646676
}
647677
currField = field;

Diff for: packages/sdk/src/model-meta-generator.ts

+37-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import {
22
ArrayExpr,
33
DataModel,
4+
DataModelAttribute,
45
DataModelField,
56
isArrayExpr,
67
isBooleanLiteral,
@@ -239,10 +240,7 @@ function getFieldAttributes(field: DataModelField): RuntimeAttribute[] {
239240
function getUniqueConstraints(model: DataModel) {
240241
const constraints: Array<{ name: string; fields: string[] }> = [];
241242

242-
// model-level constraints
243-
for (const attr of model.attributes.filter(
244-
(attr) => attr.decl.ref?.name === '@@unique' || attr.decl.ref?.name === '@@id'
245-
)) {
243+
const extractConstraint = (attr: DataModelAttribute) => {
246244
const argsMap = getAttributeArgs(attr);
247245
if (argsMap.fields) {
248246
const fieldNames = (argsMap.fields as ArrayExpr).items.map(
@@ -253,14 +251,45 @@ function getUniqueConstraints(model: DataModel) {
253251
// default constraint name is fields concatenated with underscores
254252
constraintName = fieldNames.join('_');
255253
}
256-
constraints.push({ name: constraintName, fields: fieldNames });
254+
return { name: constraintName, fields: fieldNames };
255+
} else {
256+
return undefined;
257+
}
258+
};
259+
260+
const addConstraint = (constraint: { name: string; fields: string[] }) => {
261+
if (!constraints.some((c) => c.name === constraint.name)) {
262+
constraints.push(constraint);
263+
}
264+
};
265+
266+
// field-level @id first
267+
for (const field of model.fields) {
268+
if (hasAttribute(field, '@id')) {
269+
addConstraint({ name: field.name, fields: [field.name] });
257270
}
258271
}
259272

260-
// field-level constraints
273+
// then model-level @@id
274+
for (const attr of model.attributes.filter((attr) => attr.decl.ref?.name === '@@id')) {
275+
const constraint = extractConstraint(attr);
276+
if (constraint) {
277+
addConstraint(constraint);
278+
}
279+
}
280+
281+
// then field-level @unique
261282
for (const field of model.fields) {
262-
if (hasAttribute(field, '@id') || hasAttribute(field, '@unique')) {
263-
constraints.push({ name: field.name, fields: [field.name] });
283+
if (hasAttribute(field, '@unique')) {
284+
addConstraint({ name: field.name, fields: [field.name] });
285+
}
286+
}
287+
288+
// then model-level @@unique
289+
for (const attr of model.attributes.filter((attr) => attr.decl.ref?.name === '@@unique')) {
290+
const constraint = extractConstraint(attr);
291+
if (constraint) {
292+
addConstraint(constraint);
264293
}
265294
}
266295

0 commit comments

Comments
 (0)