diff --git a/callbacks/preload.go b/callbacks/preload.go index fd8214bb2..0d648d944 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -12,29 +12,6 @@ import ( "gorm.io/gorm/utils" ) -// parsePreloadMap extracts nested preloads. e.g. -// -// // schema has a "k0" relation and a "k7.k8" embedded relation -// parsePreloadMap(schema, map[string][]interface{}{ -// clause.Associations: {"arg1"}, -// "k1": {"arg2"}, -// "k2.k3": {"arg3"}, -// "k4.k5.k6": {"arg4"}, -// }) -// // preloadMap is -// map[string]map[string][]interface{}{ -// "k0": {}, -// "k7": { -// "k8": {}, -// }, -// "k1": {}, -// "k2": { -// "k3": {"arg3"}, -// }, -// "k4": { -// "k5.k6": {"arg4"}, -// }, -// } func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} { preloadMap := map[string]map[string][]interface{}{} setPreloadMap := func(name, value string, args []interface{}) { @@ -74,7 +51,6 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { } names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) for _, relation := range embeddedRelations.Relations { - // skip first struct name names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], ".")) } for _, relations := range embeddedRelations.EmbeddedRelations { @@ -84,10 +60,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { } // preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point. -// If the current relationship is embedded or joined, current query will be ignored. -// -//nolint:cyclop -func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { +func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}, customJoin func(*gorm.DB) *gorm.DB) error { preloadMap := parsePreloadMap(db.Statement.Schema, preloads) // avoid random traversal of the map @@ -116,7 +89,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati for _, name := range preloadNames { if relations := relationships.EmbeddedRelations[name]; relations != nil { - if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil { + if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds, customJoin); err != nil { return err } } else if rel := relationships.Relations[name]; rel != nil { @@ -138,14 +111,14 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati } tx := preloadDB(db, reflectValue, reflectValue.Interface()) - if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, customJoin); err != nil { return err } } case reflect.Struct, reflect.Pointer: reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) tx := preloadDB(db, reflectValue, reflectValue.Interface()) - if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds, customJoin); err != nil { return err } default: @@ -155,7 +128,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) tx.Statement.ReflectValue = db.Statement.ReflectValue tx.Statement.Unscoped = db.Statement.Unscoped - if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { + if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name], customJoin); err != nil { return err } } @@ -182,7 +155,7 @@ func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm. return tx } -func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { +func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}, customJoin func(*gorm.DB) *gorm.DB) error { var ( reflectValue = tx.Statement.ReflectValue relForeignKeys []string @@ -193,6 +166,10 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload inlineConds []interface{} ) + if customJoin != nil { + tx = customJoin(tx) + } + if rel.JoinTable != nil { var ( joinForeignFields = make([]*schema.Field, 0, len(rel.References)) @@ -268,7 +245,13 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload // nested preload for p, pvs := range preloads { - tx = tx.Preload(p, pvs...) + if customJoin != nil { + tx = tx.Preload(p, pvs, func(tx *gorm.DB) *gorm.DB { + return customJoin(tx) + }) + } else { + tx = tx.Preload(p, pvs...) + } } reflectResults := rel.FieldSchema.MakeSlice().Elem() diff --git a/callbacks/query.go b/callbacks/query.go index bbf238a9f..d013d2870 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -280,7 +280,7 @@ func Preload(db *gorm.DB) { return } - db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) + db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations], nil)) } } diff --git a/go.mod b/go.mod index 3060fc8f6..b55aabe0c 100644 --- a/go.mod +++ b/go.mod @@ -7,3 +7,8 @@ require ( github.com/jinzhu/now v1.1.5 golang.org/x/text v0.20.0 ) + +require ( + github.com/mattn/go-sqlite3 v1.14.22 // indirect + gorm.io/driver/sqlite v1.5.6 // indirect +) diff --git a/go.sum b/go.sum index 9af115728..90802b21e 100644 --- a/go.sum +++ b/go.sum @@ -2,5 +2,9 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= +gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= diff --git a/tests/preload_custom_test.go b/tests/preload_custom_test.go new file mode 100644 index 000000000..1afb21303 --- /dev/null +++ b/tests/preload_custom_test.go @@ -0,0 +1,205 @@ +package tests_test + +import ( + "testing" + "time" + + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// Structs for preload tests +type PreloadItem struct { + ID uint + Name string + Tags []PreloadTag `gorm:"many2many:preload_items_preload_tags"` + CreatedAt time.Time +} + +type PreloadTag struct { + ID uint + Name string + Status string + SubTags []PreloadSubTag `gorm:"many2many:tag_sub_tags"` +} + +type PreloadSubTag struct { + ID uint + Name string + Status string +} + +// Setup database for preload tests +func setupPreloadTestDB(t *testing.T) *gorm.DB { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to connect database: %v", err) + } + err = db.AutoMigrate(&PreloadItem{}, &PreloadTag{}, &PreloadSubTag{}) + if err != nil { + t.Fatalf("failed to migrate database: %v", err) + } + return db +} + +// Test default preload functionality +func TestDefaultPreload(t *testing.T) { + db := setupPreloadTestDB(t) + + tag1 := PreloadTag{Name: "Tag1", Status: "active"} + item := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag1}} + db.Create(&item) + + var items []PreloadItem + err := db.Preload("Tags").Find(&items).Error + if err != nil { + t.Fatalf("default preload failed: %v", err) + } + + if len(items) != 1 || len(items[0].Tags) != 1 || items[0].Tags[0].Name != "Tag1" { + t.Errorf("unexpected default preload results: %v", items) + } +} + +// Test preloading with custom joins and conditions +func TestCustomJoinsWithConditions(t *testing.T) { + db := setupPreloadTestDB(t) + + tag1 := PreloadTag{Name: "Tag1", Status: "active"} + tag2 := PreloadTag{Name: "Tag2", Status: "inactive"} + item := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag1, tag2}} + db.Create(&item) + + var items []PreloadItem + err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB { + return tx.Joins("JOIN preload_items_preload_tags ON preload_items_preload_tags.preload_tag_id = preload_tags.id"). + Where("preload_tags.status = ?", "active") + }).Find(&items).Error + if err != nil { + t.Fatalf("custom join with conditions failed: %v", err) + } + + if len(items) != 1 || len(items[0].Tags) != 1 || items[0].Tags[0].Status != "active" { + t.Errorf("unexpected results in TestCustomJoinsWithConditions: %v", items) + } +} + +// Test nested preload functionality with custom joins +func TestNestedPreloadWithCustomJoins(t *testing.T) { + db := setupPreloadTestDB(t) + + subTag := PreloadSubTag{Name: "SubTag1", Status: "active"} + tag := PreloadTag{Name: "Tag1", Status: "active", SubTags: []PreloadSubTag{subTag}} + item := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag}} + db.Create(&item) + + var items []PreloadItem + err := db.Preload("Tags.SubTags", func(tx *gorm.DB) *gorm.DB { + return tx.Joins("JOIN tag_sub_tags ON tag_sub_tags.preload_sub_tag_id = preload_sub_tags.id"). + Where("preload_sub_tags.status = ?", "active") + }).Find(&items).Error + if err != nil { + t.Fatalf("nested preload with custom joins failed: %v", err) + } + + if len(items) != 1 || len(items[0].Tags) != 1 || len(items[0].Tags[0].SubTags) != 1 || items[0].Tags[0].SubTags[0].Name != "SubTag1" { + t.Errorf("unexpected nested preload results: %v", items) + } +} + +// Test behavior when no matching records exist +func TestNoMatchingRecords(t *testing.T) { + db := setupPreloadTestDB(t) + + tag := PreloadTag{Name: "Tag1", Status: "inactive"} + item := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag}} + db.Create(&item) + + var items []PreloadItem + err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB { + return tx.Joins("JOIN preload_items_preload_tags ON preload_items_preload_tags.preload_tag_id = preload_tags.id"). + Where("preload_tags.status = ?", "active") + }).Find(&items).Error + if err != nil { + t.Fatalf("preload with no matching records failed: %v", err) + } + + if len(items) != 1 || len(items[0].Tags) != 0 { + t.Errorf("unexpected results in TestNoMatchingRecords: %v", items) + } +} + +// Test behavior with an empty database +func TestEmptyDatabase(t *testing.T) { + db := setupPreloadTestDB(t) + + var items []PreloadItem + err := db.Preload("Tags").Find(&items).Error + if err != nil { + t.Fatalf("preload with empty database failed: %v", err) + } + + if len(items) != 0 { + t.Errorf("unexpected results in TestEmptyDatabase: %v", items) + } +} + +// Test multiple items with different tag statuses +func TestMultipleItemsWithDifferentTagStatuses(t *testing.T) { + db := setupPreloadTestDB(t) + + tag1 := PreloadTag{Name: "Tag1", Status: "active"} + tag2 := PreloadTag{Name: "Tag2", Status: "inactive"} + item1 := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag1}} + item2 := PreloadItem{Name: "Item2", Tags: []PreloadTag{tag2}} + db.Create(&item1) + db.Create(&item2) + + var items []PreloadItem + err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB { + return tx.Joins("JOIN preload_items_preload_tags ON preload_items_preload_tags.preload_tag_id = preload_tags.id"). + Where("preload_tags.status = ?", "active") + }).Find(&items).Error + if err != nil { + t.Fatalf("preload with multiple items failed: %v", err) + } + + if len(items) != 2 || len(items[0].Tags) != 1 || len(items[1].Tags) != 0 { + t.Errorf("unexpected results in TestMultipleItemsWithDifferentTagStatuses: %v", items) + } +} + +// Test duplicate preload conditions +func TestDuplicatePreloadConditions(t *testing.T) { + db := setupPreloadTestDB(t) + + tag1 := PreloadTag{Name: "Tag1", Status: "active"} + tag2 := PreloadTag{Name: "Tag2", Status: "inactive"} + item := PreloadItem{Name: "Item1", Tags: []PreloadTag{tag1, tag2}} + db.Create(&item) + + var activeTagsItems []PreloadItem + var inactiveTagsItems []PreloadItem + + err := db.Preload("Tags", func(tx *gorm.DB) *gorm.DB { + return tx.Where("status = ?", "active") + }).Find(&activeTagsItems).Error + if err != nil { + t.Fatalf("preload for active tags failed: %v", err) + } + + err = db.Preload("Tags", func(tx *gorm.DB) *gorm.DB { + return tx.Where("status = ?", "inactive") + }).Find(&inactiveTagsItems).Error + if err != nil { + t.Fatalf("preload for inactive tags failed: %v", err) + } + + if len(activeTagsItems) != 1 || len(activeTagsItems[0].Tags) != 1 || activeTagsItems[0].Tags[0].Status != "active" { + t.Errorf("unexpected active tag results in TestDuplicatePreloadConditions: %v", activeTagsItems) + } + + if len(inactiveTagsItems) != 1 || len(inactiveTagsItems[0].Tags) != 1 || inactiveTagsItems[0].Tags[0].Status != "inactive" { + t.Errorf("unexpected inactive tag results in TestDuplicatePreloadConditions: %v", inactiveTagsItems) + } +}