From e44a1e593fae7131d3d4b4971bb6a1b51b9e0011 Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Sun, 7 Nov 2021 20:23:47 +0900 Subject: [PATCH 01/12] modify interface and implementation --- bulk_insert.go | 83 ++++++++++++++++---------------------------------- 1 file changed, 27 insertions(+), 56 deletions(-) diff --git a/bulk_insert.go b/bulk_insert.go index 9970baf..cfba7ce 100644 --- a/bulk_insert.go +++ b/bulk_insert.go @@ -25,79 +25,58 @@ import ( func BulkInsert(db *gorm.DB, objects []interface{}, chunkSize int, excludeColumns ...string) error { // Split records with specified size not to exceed Database parameter limit for _, objSet := range splitObjects(objects, chunkSize) { - if err := insertObjSet(db, objSet, excludeColumns...); err != nil { + _, err := insertObjSet(db, objSet, excludeColumns...) + if err != nil { return err } } return nil } -// BulkInsertWithAssigningIDs executes the query to insert multiple records at once. -// it will scan the result of `returning id` or `returning *` to [returnedValue] after every insert. -// it's necessary to set "gorm:insert_option"="returning id" in *gorm.DB -// -// [returnedValue] slice of primary_key or model, must be a *[]uint(for integer), *[]string(for uuid), *[]struct(for `returning *`) +// BulkInsertWithReturningValues executes the query to insert multiple records at once. +// This will scan the returned value into `dstValues`. +// It's necessary to set "gorm:insert_option" to execute "returning" query. // // [objects] must be a slice of struct. // +// [dstValues] must be a point to a slice of struct. Struct properties must correspond to returning results. +// // [chunkSize] is a number of variables embedded in query. To prevent the error which occurs embedding a large number of variables at once // and exceeds the limit of prepared statement. Larger size normally leads to better performance, in most cases 2000 to 3000 is reasonable. // // [excludeColumns] is column names to exclude from insert. -func BulkInsertWithAssigningIDs(db *gorm.DB, returnedValue interface{}, objects []interface{}, chunkSize int, excludeColumns ...string) error { - typ := reflect.TypeOf(returnedValue) - if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Slice { - return errors.New("returningId must be a slice ptr") +func BulkInsertWithReturningValues(db *gorm.DB, objects []interface{}, returnedVals interface{}, chunkSize int, excludeColumns ...string) error { + typ := reflect.TypeOf(returnedVals) + if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Slice || typ.Elem().Elem().Kind() != reflect.Struct { + return errors.New("returnedVals must be a pointer to a slice of struct") } - allIds := reflect.Indirect(reflect.ValueOf(returnedValue)) - typ = allIds.Type() - - // Deference value of slice - valueTyp := typ.Elem() - for valueTyp.Kind() == reflect.Ptr { - valueTyp = valueTyp.Elem() - } + refDst := reflect.Indirect(reflect.ValueOf(returnedVals)) + typ = refDst.Type() // Split records with specified size not to exceed Database parameter limit for _, objSet := range splitObjects(objects, chunkSize) { - returnValueSlice := reflect.New(typ) - var scanReturningId func(*gorm.DB) error - switch valueTyp.Kind() { - case reflect.Struct: - // If user want to scan `returning *` with returnedValue=[]struct{...} - scanReturningId = func(db *gorm.DB) error { - return db.Scan(returnValueSlice.Interface()).Error - } - default: - // If user want to scan primary key `returning pk` with returnedValue=[]struct{...} - pk := db.NewScope(objects[0]).PrimaryKey() - scanReturningId = func(db *gorm.DB) error { - return db.Pluck(pk, returnValueSlice.Interface()).Error - } + db, err := insertObjSet(db, objSet, excludeColumns...) + if err != nil { + return err } - - if err := insertObjSetWithCallback(db, objSet, scanReturningId, excludeColumns...); err != nil { + scanned := reflect.New(typ) + if err := db.Scan(scanned.Interface()).Error; err != nil { return err } - - allIds.Set(reflect.AppendSlice(allIds, returnValueSlice.Elem())) + refDst.Set(reflect.AppendSlice(refDst, scanned.Elem())) } return nil } -func insertObjSet(db *gorm.DB, objects []interface{}, excludeColumns ...string) error { - return insertObjSetWithCallback(db, objects, nil, excludeColumns...) -} - -func insertObjSetWithCallback(db *gorm.DB, objects []interface{}, postInsert func(*gorm.DB) error, excludeColumns ...string) error { +func insertObjSet(db *gorm.DB, objects []interface{}, excludeColumns ...string) (*gorm.DB, error) { if len(objects) == 0 { - return nil + return db, nil } firstAttrs, err := extractMapValue(objects[0], excludeColumns) if err != nil { - return err + return db, err } attrSize := len(firstAttrs) @@ -116,12 +95,12 @@ func insertObjSetWithCallback(db *gorm.DB, objects []interface{}, postInsert fun for _, obj := range objects { objAttrs, err := extractMapValue(obj, excludeColumns) if err != nil { - return err + return db, err } // If object sizes are different, SQL statement loses consistency if len(objAttrs) != attrSize { - return errors.New("attribute sizes are inconsistent") + return db, errors.New("attribute sizes are inconsistent") } scope := db.NewScope(obj) @@ -144,7 +123,7 @@ func insertObjSetWithCallback(db *gorm.DB, objects []interface{}, postInsert fun if val, ok := db.Get("gorm:insert_option"); ok { strVal, ok := val.(string) if !ok { - return errors.New("gorm:insert_option should be a string") + return db, errors.New("gorm:insert_option should be a string") } insertOption = strVal } @@ -157,18 +136,10 @@ func insertObjSetWithCallback(db *gorm.DB, objects []interface{}, postInsert fun )) db = db.Raw(mainScope.SQL, mainScope.SQLVars...) - if err := db.Error; err != nil { - return err - } - - if postInsert != nil { - if err := postInsert(db); err != nil { - return err - } + return db, err } - - return nil + return db, nil } // Obtain columns and values required for insert from interface From 3903de24c55eb350cecd9d4ee0142c6b0edb35ec Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Sun, 7 Nov 2021 20:23:56 +0900 Subject: [PATCH 02/12] add test for BulkInsertWithReturningValues --- bulk_insert_test.go | 128 +++++++++++++++++--------------------------- 1 file changed, 50 insertions(+), 78 deletions(-) diff --git a/bulk_insert_test.go b/bulk_insert_test.go index 1dd9099..1c0d01b 100644 --- a/bulk_insert_test.go +++ b/bulk_insert_test.go @@ -26,6 +26,55 @@ type fakeTable struct { UpdatedAt time.Time } +func TestBulkInsertWithReturningValues(t *testing.T) { + type Table struct { + ID uint `gorm:"primary_key;auto_increment"` + RegularColumn string + Custom string `gorm:"column:ThisIsCamelCase"` + } + + db, mock, err := sqlmock.New() + require.NoError(t, err) + + defer db.Close() + + gdb, err := gorm.Open("mysql", db) + require.NoError(t, err) + + mock.ExpectQuery( + "INSERT INTO `tables` \\(`ThisIsCamelCase`, `regular_column`\\)", + ).WithArgs( + "first custom", "first regular", + "second custom", "second regular", + ).WillReturnRows( + sqlmock.NewRows([]string{"id", "ThisIsCamelCase", "regular_column"}). + AddRow(1, "first custom", "first regular"). + AddRow(2, "second custom", "second regular"), + ) + + var returnedVals []Table + obj := []interface{}{ + Table{ + RegularColumn: "first regular", + Custom: "first custom", + }, + Table{ + RegularColumn: "second regular", + Custom: "second custom", + }, + } + + gdb = gdb.Set("gorm_insert_option", "RETURNING id, ThisIsCamelCase, regular_column") + err = BulkInsertWithReturningValues(gdb, obj, &returnedVals, 1000) + require.NoError(t, err) + + expected := []Table{ + {ID: 1, RegularColumn: "first regular", Custom: "first custom"}, + {ID: 2, RegularColumn: "second regular", Custom: "second custom"}, + } + assert.Equal(t, expected, returnedVals) +} + func Test_extractMapValue(t *testing.T) { collectKeys := func(val map[string]interface{}) []string { keys := make([]string, 0, len(val)) @@ -99,7 +148,7 @@ func Test_insertObject(t *testing.T) { sqlmock.NewResult(1, 1), ) - err = insertObjSet(gdb, []interface{}{ + _, err = insertObjSet(gdb, []interface{}{ Table{ RegularColumn: "first regular", Custom: "first custom", @@ -113,59 +162,6 @@ func Test_insertObject(t *testing.T) { require.NoError(t, err) } -func Test_insertObjSetWithCallback(t *testing.T) { - type Table struct { - ID uint `gorm:"primary_key;auto_increment"` - RegularColumn string - Custom string `gorm:"column:ThisIsCamelCase"` - } - - db, mock, err := sqlmock.New() - require.NoError(t, err) - - defer db.Close() - - gdb, err := gorm.Open("mysql", db) - require.NoError(t, err) - - mock.ExpectQuery( - "INSERT INTO `tables` \\(`ThisIsCamelCase`, `regular_column`\\)", - ).WithArgs( - "first custom", "first regular", - "second custom", "second regular", - ).WillReturnRows( - sqlmock.NewRows([]string{"id"}).AddRow(1).AddRow(2), - ) - - returningIdScope := func(db *gorm.DB) *gorm.DB { - return db.Set("gorm:insert_option", "returning id") - } - - err = insertObjSetWithCallback(gdb.Scopes(returningIdScope), []interface{}{ - Table{ - RegularColumn: "first regular", - Custom: "first custom", - }, - Table{ - RegularColumn: "second regular", - Custom: "second custom", - }, - }, func(db *gorm.DB) error { - var ids []uint - if err := db.Pluck("id", &ids).Error; err != nil { - return err - } - require.Len(t, ids, 2, "must return 2 ids") - return nil - }) - - if err != nil { - t.Fatal(err) - } - - require.NoError(t, err) -} - func Test_fieldIsAutoIncrement(t *testing.T) { type explicitSetTable struct { ID int `gorm:"column:id;auto_increment"` @@ -195,27 +191,3 @@ func Test_fieldIsAutoIncrement(t *testing.T) { } } } - -func Test_fieldIsPrimaryAndBlank(t *testing.T) { - type notPrimaryTable struct { - Dummy int - } - type primaryKeyTable struct { - ID int `gorm:"column:id;primary_key"` - } - - cases := []struct { - Value interface{} - Expected bool - }{ - {notPrimaryTable{Dummy: 0}, false}, - {notPrimaryTable{Dummy: 1}, false}, - {primaryKeyTable{ID: 0}, true}, - {primaryKeyTable{ID: 1}, false}, - } - for _, c := range cases { - for _, field := range (&gorm.Scope{Value: c.Value}).Fields() { - assert.Equal(t, fieldIsPrimaryAndBlank(field), c.Expected) - } - } -} From d96d08efbe43f6c0c9efb306d7f079f189261e84 Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Sun, 7 Nov 2021 20:32:53 +0900 Subject: [PATCH 03/12] add test for invalid type --- bulk_insert_test.go | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/bulk_insert_test.go b/bulk_insert_test.go index 1c0d01b..892972f 100644 --- a/bulk_insert_test.go +++ b/bulk_insert_test.go @@ -35,7 +35,6 @@ func TestBulkInsertWithReturningValues(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) - defer db.Close() gdb, err := gorm.Open("mysql", db) @@ -75,6 +74,30 @@ func TestBulkInsertWithReturningValues(t *testing.T) { assert.Equal(t, expected, returnedVals) } +func TestBulkInsertWithReturningValues_InvalidTypeOfReturnedVals(t *testing.T) { + db, _, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + gdb, err := gorm.Open("mysql", db) + require.NoError(t, err) + + tests := []struct { + name string + vals interface{} + } { + {name: "not a pointer", vals: []struct{Name string}{{Name: "1"}}}, + {name: "element is not a slice", vals: &struct{Name string}{Name: "1"}}, + {name: "slice element is not a struct", vals: &[]string{"1"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := BulkInsertWithReturningValues(gdb, []interface{}{}, tt.vals, 1000) + assert.EqualError(t, err, "returnedVals must be a pointer to a slice of struct") + }) + } +} + func Test_extractMapValue(t *testing.T) { collectKeys := func(val map[string]interface{}) []string { keys := make([]string, 0, len(val)) From 59174e7cdce339a2c630b210fb0ca6eba2e8459c Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Sun, 7 Nov 2021 20:38:38 +0900 Subject: [PATCH 04/12] update README --- README.md | 76 +++++++++---------------------------------------------- 1 file changed, 12 insertions(+), 64 deletions(-) diff --git a/README.md b/README.md index b40ab41..70627e0 100644 --- a/README.md +++ b/README.md @@ -46,8 +46,6 @@ In the above pattern `Name` and `Email` fields are excluded. ## Example -### BulkInsert - ```go package main @@ -91,73 +89,23 @@ func main() { // do something } - // columns you want to exclude from Insert, specify as an argument + // Columns you want to exclude from Insert, specify as an argument err = gormbulk.BulkInsert(db, insertRecords, 3000, "Email") if err != nil { // do something } -} -``` - -### BulkInsertWithAssigningIDs - -```go -package main - -import ( - "fmt" - "time" - - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/postgres" - gormbulk "github.com/t-tiger/gorm-bulk-insert/v2" -) - -type fakeTable struct { - IdPK uint `gorm:"primary_key"` - CreatedAt time.Time `gorm:"default:now()"` - Data string -} - -func main() { - db, err := gorm.Open("postgres", "host=localhost port=5432 user=cloudwalker dbname=cloudwalker password=cloudwalker sslmode=disable") + + // Fetch returning values + dbForReturning := db.Set("gorm:insert_option", "RETURNING id, name, created_at") + var returned []struct{ + ID int + Name string + CreatedAt time.Time + } + err = gormbulk.BulkInsertWithReturningValues(dbForReturning, insertRecords, &returned, 3000) if err != nil { - panic(err) - } - defer db.Close() - db.SingularTable(true) - - if err := db.AutoMigrate(fakeTable{}).Error; err != nil { - panic(err) - } - - models := []interface{}{ - fakeTable{Data: "aaa"}, - fakeTable{Data: "bbb"}, - fakeTable{Data: "ccc"}, - } - - // if you want to scan * back - var returnModel []fakeTable - if err := gormbulk.BulkInsertWithAssigningIDs( - db.Set("gorm:insert_option", "returning *"), &returnModel, models, 1000); err != nil { - panic(err) - } - fmt.Printf("success to insert with returning: %+v\n", returnModel) - // success to insert with returning: [ - // {IdPK:1 CreatedAt:2021-10-31 16:21:48.019947 +0000 UTC Data:aaa} - // {IdPK:2 CreatedAt:2021-10-31 16:21:48.019959 +0000 UTC Data:bbb} - // {IdPK:3 CreatedAt:2021-10-31 16:21:48.019965 +0000 UTC Data:ccc} - // ] - - // if you want to scan primary key - var returnId []uint - if err := gormbulk.BulkInsertWithAssigningIDs( - db.Set("gorm:insert_option", "returning id"), &returnId, models, 1000); err != nil { - panic(err) - } - fmt.Printf("success to insert with returning: %+v\n", returnId) - // `success to insert with returning: [4 5 6]` + // do something + } } ``` From b4155aa2b8f86b2215159bf6d1607868b19e392d Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Sun, 7 Nov 2021 20:40:06 +0900 Subject: [PATCH 05/12] modify comment --- bulk_insert.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bulk_insert.go b/bulk_insert.go index cfba7ce..dba72a0 100644 --- a/bulk_insert.go +++ b/bulk_insert.go @@ -39,7 +39,7 @@ func BulkInsert(db *gorm.DB, objects []interface{}, chunkSize int, excludeColumn // // [objects] must be a slice of struct. // -// [dstValues] must be a point to a slice of struct. Struct properties must correspond to returning results. +// [returnedVals] must be a point to a slice of struct. Values returned from `RETURNING` clause will be assigned. // // [chunkSize] is a number of variables embedded in query. To prevent the error which occurs embedding a large number of variables at once // and exceeds the limit of prepared statement. Larger size normally leads to better performance, in most cases 2000 to 3000 is reasonable. From 88dc3250ad1f14fda1ba1ff420dd07b15557a1ce Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Sun, 7 Nov 2021 20:45:40 +0900 Subject: [PATCH 06/12] update README --- README.md | 100 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 51 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 70627e0..7681d64 100644 --- a/README.md +++ b/README.md @@ -50,62 +50,64 @@ In the above pattern `Name` and `Email` fields are excluded. package main import ( - "fmt" - "log" - "time" + "fmt" + "log" + "time" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mysql" - gormbulk "github.com/t-tiger/gorm-bulk-insert/v2" + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/mysql" + gormbulk "github.com/t-tiger/gorm-bulk-insert/v2" ) type fakeTable struct { - ID int `gorm:"AUTO_INCREMENT"` - Name string - Email string - CreatedAt time.Time - UpdatedAt time.Time + ID int `gorm:"AUTO_INCREMENT"` + Name string + Email string + CreatedAt time.Time + UpdatedAt time.Time } func main() { - db, err := gorm.Open("mysql", "mydb") - if err != nil { - log.Fatal(err) - } - - var insertRecords []interface{} - for i := 0; i < 10; i++ { - insertRecords = append(insertRecords, - fakeTable{ - Name: fmt.Sprintf("name%d", i), - Email: fmt.Sprintf("test%d@test.com", i), - // you don't need to set CreatedAt, UpdatedAt - }, - ) - } - - err = gormbulk.BulkInsert(db, insertRecords, 3000) - if err != nil { - // do something - } - - // Columns you want to exclude from Insert, specify as an argument - err = gormbulk.BulkInsert(db, insertRecords, 3000, "Email") - if err != nil { - // do something - } - - // Fetch returning values - dbForReturning := db.Set("gorm:insert_option", "RETURNING id, name, created_at") - var returned []struct{ - ID int - Name string - CreatedAt time.Time - } - err = gormbulk.BulkInsertWithReturningValues(dbForReturning, insertRecords, &returned, 3000) - if err != nil { - // do something - } + db, err := gorm.Open("mysql", "mydb") + if err != nil { + log.Fatal(err) + } + + var insertRecords []interface{} + for i := 0; i < 10; i++ { + insertRecords = append(insertRecords, + fakeTable{ + Name: fmt.Sprintf("name%d", i), + Email: fmt.Sprintf("test%d@test.com", i), + // you don't need to set CreatedAt, UpdatedAt + }, + ) + } + + err = gormbulk.BulkInsert(db, insertRecords, 3000) + if err != nil { + // do something + } + + // Columns you want to exclude from Insert, specify as an argument + err = gormbulk.BulkInsert(db, insertRecords, 3000, "Email") + if err != nil { + // do something + } + + // Fetch returning values + dbForReturning := db.Set("gorm:insert_option", "RETURNING id, name, created_at") + var returned []struct { + ID int + Name string + CreatedAt time.Time + } + err = gormbulk.BulkInsertWithReturningValues(dbForReturning, insertRecords, &returned, 3000) + if err != nil { + // do something + } + // Values of `returned` will be as follows + // {{ID: 1, Name: "name0", CreatedAt: 2021-10-31 16:21:48.019947 +0000 UTC}, ...} } ``` From aad2d274a28c9bd04a1b063a580057718cbad9e7 Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Sun, 7 Nov 2021 20:52:08 +0900 Subject: [PATCH 07/12] recover test --- bulk_insert_test.go | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/bulk_insert_test.go b/bulk_insert_test.go index 892972f..7cce126 100644 --- a/bulk_insert_test.go +++ b/bulk_insert_test.go @@ -85,9 +85,9 @@ func TestBulkInsertWithReturningValues_InvalidTypeOfReturnedVals(t *testing.T) { tests := []struct { name string vals interface{} - } { - {name: "not a pointer", vals: []struct{Name string}{{Name: "1"}}}, - {name: "element is not a slice", vals: &struct{Name string}{Name: "1"}}, + }{ + {name: "not a pointer", vals: []struct{ Name string }{{Name: "1"}}}, + {name: "element is not a slice", vals: &struct{ Name string }{Name: "1"}}, {name: "slice element is not a struct", vals: &[]string{"1"}}, } for _, tt := range tests { @@ -214,3 +214,27 @@ func Test_fieldIsAutoIncrement(t *testing.T) { } } } + +func Test_fieldIsPrimaryAndBlank(t *testing.T) { + type notPrimaryTable struct { + Dummy int + } + type primaryKeyTable struct { + ID int `gorm:"column:id;primary_key"` + } + + cases := []struct { + Value interface{} + Expected bool + }{ + {notPrimaryTable{Dummy: 0}, false}, + {notPrimaryTable{Dummy: 1}, false}, + {primaryKeyTable{ID: 0}, true}, + {primaryKeyTable{ID: 1}, false}, + } + for _, c := range cases { + for _, field := range (&gorm.Scope{Value: c.Value}).Fields() { + assert.Equal(t, fieldIsPrimaryAndBlank(field), c.Expected) + } + } +} From 8c21b862ebba3f3fa6f39c1d1142cdbcf90a76dc Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Sun, 7 Nov 2021 20:56:27 +0900 Subject: [PATCH 08/12] modify comment --- bulk_insert.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bulk_insert.go b/bulk_insert.go index dba72a0..eb355cd 100644 --- a/bulk_insert.go +++ b/bulk_insert.go @@ -34,8 +34,9 @@ func BulkInsert(db *gorm.DB, objects []interface{}, chunkSize int, excludeColumn } // BulkInsertWithReturningValues executes the query to insert multiple records at once. -// This will scan the returned value into `dstValues`. -// It's necessary to set "gorm:insert_option" to execute "returning" query. +// This will scan the returned values into `returnedVals`. +// +// [db] must be set with "gorm:insert_option" to execute RETURNING clause. e.g. db.Set("gorm:insert_option", "RETURNING id") // // [objects] must be a slice of struct. // From 4484575a546514a9ea80a481956884f9ca194a83 Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Sun, 7 Nov 2021 20:59:17 +0900 Subject: [PATCH 09/12] update indent --- README.md | 102 +++++++++++++++++++++++++++--------------------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index 7681d64..f4c7797 100644 --- a/README.md +++ b/README.md @@ -50,64 +50,64 @@ In the above pattern `Name` and `Email` fields are excluded. package main import ( - "fmt" - "log" - "time" + "fmt" + "log" + "time" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mysql" - gormbulk "github.com/t-tiger/gorm-bulk-insert/v2" + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/mysql" + gormbulk "github.com/t-tiger/gorm-bulk-insert/v2" ) type fakeTable struct { - ID int `gorm:"AUTO_INCREMENT"` - Name string - Email string - CreatedAt time.Time - UpdatedAt time.Time + ID int `gorm:"AUTO_INCREMENT"` + Name string + Email string + CreatedAt time.Time + UpdatedAt time.Time } func main() { - db, err := gorm.Open("mysql", "mydb") - if err != nil { - log.Fatal(err) - } - - var insertRecords []interface{} - for i := 0; i < 10; i++ { - insertRecords = append(insertRecords, - fakeTable{ - Name: fmt.Sprintf("name%d", i), - Email: fmt.Sprintf("test%d@test.com", i), - // you don't need to set CreatedAt, UpdatedAt - }, - ) - } - - err = gormbulk.BulkInsert(db, insertRecords, 3000) - if err != nil { - // do something - } - - // Columns you want to exclude from Insert, specify as an argument - err = gormbulk.BulkInsert(db, insertRecords, 3000, "Email") - if err != nil { - // do something - } - - // Fetch returning values - dbForReturning := db.Set("gorm:insert_option", "RETURNING id, name, created_at") - var returned []struct { - ID int - Name string - CreatedAt time.Time - } - err = gormbulk.BulkInsertWithReturningValues(dbForReturning, insertRecords, &returned, 3000) - if err != nil { - // do something - } - // Values of `returned` will be as follows - // {{ID: 1, Name: "name0", CreatedAt: 2021-10-31 16:21:48.019947 +0000 UTC}, ...} + db, err := gorm.Open("mysql", "mydb") + if err != nil { + log.Fatal(err) + } + + var insertRecords []interface{} + for i := 0; i < 10; i++ { + insertRecords = append(insertRecords, + fakeTable{ + Name: fmt.Sprintf("name%d", i), + Email: fmt.Sprintf("test%d@test.com", i), + // you don't need to set CreatedAt, UpdatedAt + }, + ) + } + + err = gormbulk.BulkInsert(db, insertRecords, 3000) + if err != nil { + // do something + } + + // Columns you want to exclude from Insert, specify as an argument + err = gormbulk.BulkInsert(db, insertRecords, 3000, "Email") + if err != nil { + // do something + } + + // Fetch returning values + dbForReturning := db.Set("gorm:insert_option", "RETURNING id, name, created_at") + var returned []struct { + ID int + Name string + CreatedAt time.Time + } + err = gormbulk.BulkInsertWithReturningValues(dbForReturning, insertRecords, &returned, 3000) + if err != nil { + // do something + } + // Values of `returned` will be as follows + // {{ID: 1, Name: "name0", CreatedAt: 2021-10-31 16:21:48.019947 +0000 UTC}, ...} } ``` From 9eeb2670112f79becb68cca920eb8c1fbe647e87 Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Sun, 7 Nov 2021 21:04:57 +0900 Subject: [PATCH 10/12] minor fix --- README.md | 6 +++--- bulk_insert.go | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index f4c7797..9fe7534 100644 --- a/README.md +++ b/README.md @@ -61,8 +61,8 @@ import ( type fakeTable struct { ID int `gorm:"AUTO_INCREMENT"` - Name string - Email string + Name string + Email string CreatedAt time.Time UpdatedAt time.Time } @@ -106,7 +106,7 @@ func main() { if err != nil { // do something } - // Values of `returned` will be as follows + // Values of `returned` will be like this // {{ID: 1, Name: "name0", CreatedAt: 2021-10-31 16:21:48.019947 +0000 UTC}, ...} } ``` diff --git a/bulk_insert.go b/bulk_insert.go index eb355cd..3a29606 100644 --- a/bulk_insert.go +++ b/bulk_insert.go @@ -53,7 +53,6 @@ func BulkInsertWithReturningValues(db *gorm.DB, objects []interface{}, returnedV } refDst := reflect.Indirect(reflect.ValueOf(returnedVals)) - typ = refDst.Type() // Split records with specified size not to exceed Database parameter limit for _, objSet := range splitObjects(objects, chunkSize) { @@ -61,7 +60,7 @@ func BulkInsertWithReturningValues(db *gorm.DB, objects []interface{}, returnedV if err != nil { return err } - scanned := reflect.New(typ) + scanned := reflect.New(refDst.Type()) if err := db.Scan(scanned.Interface()).Error; err != nil { return err } From 241f2add362784c38b29b3387e85ddfa4abaf4a9 Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Thu, 11 Nov 2021 21:37:15 +0900 Subject: [PATCH 11/12] CR fix --- README.md | 3 +-- bulk_insert.go | 35 +++++++++++++++++++---------------- bulk_insert_test.go | 3 +-- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 9fe7534..2a42abd 100644 --- a/README.md +++ b/README.md @@ -96,13 +96,12 @@ func main() { } // Fetch returning values - dbForReturning := db.Set("gorm:insert_option", "RETURNING id, name, created_at") var returned []struct { ID int Name string CreatedAt time.Time } - err = gormbulk.BulkInsertWithReturningValues(dbForReturning, insertRecords, &returned, 3000) + err = gormbulk.BulkInsertWithReturningValues(db, insertRecords, &returned, 3000) if err != nil { // do something } diff --git a/bulk_insert.go b/bulk_insert.go index 3a29606..049083a 100644 --- a/bulk_insert.go +++ b/bulk_insert.go @@ -25,8 +25,7 @@ import ( func BulkInsert(db *gorm.DB, objects []interface{}, chunkSize int, excludeColumns ...string) error { // Split records with specified size not to exceed Database parameter limit for _, objSet := range splitObjects(objects, chunkSize) { - _, err := insertObjSet(db, objSet, excludeColumns...) - if err != nil { + if err := insertObjSet(db, objSet, excludeColumns...); err != nil { return err } } @@ -36,8 +35,6 @@ func BulkInsert(db *gorm.DB, objects []interface{}, chunkSize int, excludeColumn // BulkInsertWithReturningValues executes the query to insert multiple records at once. // This will scan the returned values into `returnedVals`. // -// [db] must be set with "gorm:insert_option" to execute RETURNING clause. e.g. db.Set("gorm:insert_option", "RETURNING id") -// // [objects] must be a slice of struct. // // [returnedVals] must be a point to a slice of struct. Values returned from `RETURNING` clause will be assigned. @@ -51,13 +48,19 @@ func BulkInsertWithReturningValues(db *gorm.DB, objects []interface{}, returnedV if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Slice || typ.Elem().Elem().Kind() != reflect.Struct { return errors.New("returnedVals must be a pointer to a slice of struct") } - refDst := reflect.Indirect(reflect.ValueOf(returnedVals)) + // set insert_option + fields := (&gorm.Scope{Value: returnedVals}).Fields() + returningCols := make([]string, len(fields)) + for i, f := range fields { + returningCols[i] = f.DBName + } + db = db.Set("gorm:insert_option", fmt.Sprintf("RETURNING %s", strings.Join(returningCols, ", "))) + // Split records with specified size not to exceed Database parameter limit for _, objSet := range splitObjects(objects, chunkSize) { - db, err := insertObjSet(db, objSet, excludeColumns...) - if err != nil { + if err := insertObjSet(db, objSet, excludeColumns...); err != nil { return err } scanned := reflect.New(refDst.Type()) @@ -69,14 +72,14 @@ func BulkInsertWithReturningValues(db *gorm.DB, objects []interface{}, returnedV return nil } -func insertObjSet(db *gorm.DB, objects []interface{}, excludeColumns ...string) (*gorm.DB, error) { +func insertObjSet(db *gorm.DB, objects []interface{}, excludeColumns ...string) error { if len(objects) == 0 { - return db, nil + return nil } firstAttrs, err := extractMapValue(objects[0], excludeColumns) if err != nil { - return db, err + return err } attrSize := len(firstAttrs) @@ -95,12 +98,12 @@ func insertObjSet(db *gorm.DB, objects []interface{}, excludeColumns ...string) for _, obj := range objects { objAttrs, err := extractMapValue(obj, excludeColumns) if err != nil { - return db, err + return err } // If object sizes are different, SQL statement loses consistency if len(objAttrs) != attrSize { - return db, errors.New("attribute sizes are inconsistent") + return errors.New("attribute sizes are inconsistent") } scope := db.NewScope(obj) @@ -123,7 +126,7 @@ func insertObjSet(db *gorm.DB, objects []interface{}, excludeColumns ...string) if val, ok := db.Get("gorm:insert_option"); ok { strVal, ok := val.(string) if !ok { - return db, errors.New("gorm:insert_option should be a string") + return errors.New("gorm:insert_option should be a string") } insertOption = strVal } @@ -135,11 +138,11 @@ func insertObjSet(db *gorm.DB, objects []interface{}, excludeColumns ...string) insertOption, )) - db = db.Raw(mainScope.SQL, mainScope.SQLVars...) + *db = *db.Raw(mainScope.SQL, mainScope.SQLVars...) if err := db.Error; err != nil { - return db, err + return err } - return db, nil + return nil } // Obtain columns and values required for insert from interface diff --git a/bulk_insert_test.go b/bulk_insert_test.go index 7cce126..8535486 100644 --- a/bulk_insert_test.go +++ b/bulk_insert_test.go @@ -63,7 +63,6 @@ func TestBulkInsertWithReturningValues(t *testing.T) { }, } - gdb = gdb.Set("gorm_insert_option", "RETURNING id, ThisIsCamelCase, regular_column") err = BulkInsertWithReturningValues(gdb, obj, &returnedVals, 1000) require.NoError(t, err) @@ -171,7 +170,7 @@ func Test_insertObject(t *testing.T) { sqlmock.NewResult(1, 1), ) - _, err = insertObjSet(gdb, []interface{}{ + err = insertObjSet(gdb, []interface{}{ Table{ RegularColumn: "first regular", Custom: "first custom", From 798a708a1293c87cb4970f37c031883fe4081ad2 Mon Sep 17 00:00:00 2001 From: Taiga Tachibana Date: Mon, 6 Dec 2021 22:57:14 +0900 Subject: [PATCH 12/12] implement BulkInsertWithReturningIDs --- README.md | 13 +++++++++---- bulk_insert.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 2a42abd..2560239 100644 --- a/README.md +++ b/README.md @@ -96,17 +96,22 @@ func main() { } // Fetch returning values - var returned []struct { + var returnedVals []struct { ID int Name string CreatedAt time.Time } - err = gormbulk.BulkInsertWithReturningValues(db, insertRecords, &returned, 3000) + err = gormbulk.BulkInsertWithReturningValues(db, insertRecords, &returnedVals, 3000) + if err != nil { + // do something + } + + // Fetch returned IDs + var returnedIDs []int + err = gormbulk.BulkInsertWithReturningIDs(db, insertRecords, &returnedIDs, 3000) if err != nil { // do something } - // Values of `returned` will be like this - // {{ID: 1, Name: "name0", CreatedAt: 2021-10-31 16:21:48.019947 +0000 UTC}, ...} } ``` diff --git a/bulk_insert.go b/bulk_insert.go index 049083a..a7b1d7b 100644 --- a/bulk_insert.go +++ b/bulk_insert.go @@ -72,6 +72,40 @@ func BulkInsertWithReturningValues(db *gorm.DB, objects []interface{}, returnedV return nil } +// BulkInsertWithReturningIDs executes the query to insert multiple records at once. +// This will scan the returned id into `returnedIDs`. If the target table does not have "id" column, please use BulkInsertWithReturningValues instead. +// +// [objects] must be a slice of struct. +// +// [returnedVals] must be a point to a slice. Values returned from `RETURNING` clause will be assigned. +// +// [chunkSize] is a number of variables embedded in query. To prevent the error which occurs embedding a large number of variables at once +// and exceeds the limit of prepared statement. Larger size normally leads to better performance, in most cases 2000 to 3000 is reasonable. +// +// [excludeColumns] is column names to exclude from insert. +func BulkInsertWithReturningIDs(db *gorm.DB, objects []interface{}, returnedIDs interface{}, chunkSize int, excludeColumns ...string) error { + typ := reflect.TypeOf(returnedIDs) + if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Slice { + return errors.New("returnedVals must be a pointer to a slice") + } + refDst := reflect.Indirect(reflect.ValueOf(returnedIDs)) + + db = db.Set("gorm:insert_option", "RETURNING id") + + // Split records with specified size not to exceed Database parameter limit + for _, objSet := range splitObjects(objects, chunkSize) { + if err := insertObjSet(db, objSet, excludeColumns...); err != nil { + return err + } + ids := reflect.New(refDst.Type()) + if err := db.Pluck("ID", ids.Interface()).Error; err != nil { + return err + } + refDst.Set(reflect.AppendSlice(refDst, ids.Elem())) + } + return nil +} + func insertObjSet(db *gorm.DB, objects []interface{}, excludeColumns ...string) error { if len(objects) == 0 { return nil