diff --git a/callbacks/update.go b/callbacks/update.go index c9fa3830c..db7f513c5 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -32,24 +32,37 @@ func SetupUpdateReflectValue(db *gorm.DB) { // BeforeUpdate before update hooks func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { - callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { - // save a copy before executing the hook so that can find out which fields were modified after the hook is executed. - rv := reflect.Indirect(reflect.ValueOf(value)) - rvClone := reflect.New(rv.Type()).Elem() - rvClone.Set(rv) - + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + var ( + beforeSaveInterface BeforeSaveInterface + isBeforeSaveHook bool + beforeUpdateInterface BeforeUpdateInterface + isBeforeUpdateHook bool + ) if db.Statement.Schema.BeforeSave { - if i, ok := value.(BeforeSaveInterface); ok { - called = true - db.AddError(i.BeforeSave(tx)) - } + beforeSaveInterface, isBeforeSaveHook = value.(BeforeSaveInterface) } - if db.Statement.Schema.BeforeUpdate { - if i, ok := value.(BeforeUpdateInterface); ok { - called = true - db.AddError(i.BeforeUpdate(tx)) - } + beforeUpdateInterface, isBeforeUpdateHook = value.(BeforeUpdateInterface) + } + + var ( + called bool + rv reflect.Value + rvSnapshot reflect.Value + ) + if isBeforeSaveHook || isBeforeUpdateHook { + called = true + // save a snapshot of the struct before the hook was called + rv = reflect.Indirect(reflect.ValueOf(value)) + rvSnapshot = reflect.New(rv.Type()).Elem() + rvSnapshot.Set(rv) + } + if isBeforeSaveHook { + db.AddError(beforeSaveInterface.BeforeSave(tx)) + } + if isBeforeUpdateHook { + db.AddError(beforeUpdateInterface.BeforeUpdate(tx)) } if called { @@ -61,8 +74,8 @@ func BeforeUpdate(db *gorm.DB) { if !ok { continue } - // compare with the copy value and update the field if there is a difference - if !reflect.DeepEqual(rv.FieldByName(field.Name).Interface(), rvClone.FieldByName(field.Name).Interface()) { + // compare with the snapshot and update the field if there is a difference + if !reflect.DeepEqual(rv.FieldByName(field.Name).Interface(), rvSnapshot.FieldByName(field.Name).Interface()) { db.Statement.SetColumn(dbFieldName, rv.FieldByName(field.Name).Interface()) } }