Skip to content

Commit 9cdaf44

Browse files
committed
feat: add before find hook
1 parent 9f27377 commit 9cdaf44

File tree

4 files changed

+22
-2
lines changed

4 files changed

+22
-2
lines changed

callbacks/callbacks.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
4848
createCallback.Clauses = config.CreateClauses
4949

5050
queryCallback := db.Callback().Query()
51+
queryCallback.Register("gorm:before_query", BeforeQuery)
5152
queryCallback.Register("gorm:query", Query)
5253
queryCallback.Register("gorm:preload", Preload)
5354
queryCallback.Register("gorm:after_query", AfterQuery)

callbacks/interfaces.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ type AfterDeleteInterface interface {
3434
AfterDelete(*gorm.DB) error
3535
}
3636

37+
type BeforeFindInterface interface {
38+
BeforeFind(*gorm.DB) error
39+
}
40+
3741
type AfterFindInterface interface {
3842
AfterFind(*gorm.DB) error
3943
}

callbacks/query.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,18 @@ import (
1111
"gorm.io/gorm/utils"
1212
)
1313

14+
func BeforeQuery(db *gorm.DB) {
15+
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.Statement.SkipHooks && db.Statement.Schema.BeforeFind {
16+
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
17+
if i, ok := value.(BeforeFindInterface); ok {
18+
db.AddError(i.BeforeFind(tx))
19+
return true
20+
}
21+
return false
22+
})
23+
}
24+
}
25+
1426
func Query(db *gorm.DB) {
1527
if db.Error == nil {
1628
BuildQuerySQL(db)

schema/schema.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ const (
2424
callbackTypeAfterSave callbackType = "AfterSave"
2525
callbackTypeBeforeDelete callbackType = "BeforeDelete"
2626
callbackTypeAfterDelete callbackType = "AfterDelete"
27+
callbackTypeBeforeFind callbackType = "BeforeFind"
2728
callbackTypeAfterFind callbackType = "AfterFind"
2829
)
2930

@@ -52,7 +53,7 @@ type Schema struct {
5253
BeforeUpdate, AfterUpdate bool
5354
BeforeDelete, AfterDelete bool
5455
BeforeSave, AfterSave bool
55-
AfterFind bool
56+
BeforeFind, AfterFind bool
5657
err error
5758
initialized chan struct{}
5859
namer Namer
@@ -308,7 +309,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
308309
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
309310
callbackTypeBeforeSave, callbackTypeAfterSave,
310311
callbackTypeBeforeDelete, callbackTypeAfterDelete,
311-
callbackTypeAfterFind,
312+
callbackTypeBeforeFind, callbackTypeAfterFind,
312313
}
313314
for _, cbName := range callbackTypes {
314315
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
@@ -396,6 +397,8 @@ func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect
396397
return modelType.MethodByName(string(callbackTypeBeforeDelete))
397398
case callbackTypeAfterDelete:
398399
return modelType.MethodByName(string(callbackTypeAfterDelete))
400+
case callbackTypeBeforeFind:
401+
return modelType.MethodByName(string(callbackTypeBeforeFind))
399402
case callbackTypeAfterFind:
400403
return modelType.MethodByName(string(callbackTypeAfterFind))
401404
default:

0 commit comments

Comments
 (0)