Skip to content

Commit

Permalink
Handle postgres array
Browse files Browse the repository at this point in the history
  • Loading branch information
kenjihikmatullah committed Jan 1, 2025
1 parent f482f25 commit 4465ee6
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 0 deletions.
4 changes: 4 additions & 0 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ type Dialector interface {
Explain(sql string, vars ...interface{}) string
}

type ArrayValueHandler interface {
HandleArray(field *schema.Field) error
}

// Plugin GORM plugin interface
type Plugin interface {
Name() string
Expand Down
9 changes: 9 additions & 0 deletions schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ const (
String DataType = "string"
Time DataType = "time"
Bytes DataType = "bytes"
Array DataType = "array"
)

const DefaultAutoIncrementIncrement int64 = 1
Expand Down Expand Up @@ -282,6 +283,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
case reflect.Array, reflect.Slice:
if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" {
field.DataType = Bytes
} else {
elemType := reflect.Indirect(fieldValue).Type().Elem()
field.DataType = Array
field.TagSettings["ELEM_TYPE"] = elemType.Kind().String()
}
}

Expand Down Expand Up @@ -977,6 +982,10 @@ func (field *Field) setupValuerAndSetter() {
return
}
}

if field.DataType != "" && field.FieldType.Kind() == reflect.Slice && field.FieldType.Elem().Kind() != reflect.Uint8 {
field.TagSettings["ARRAY_FIELD"] = "true"
}
}

func (field *Field) setupNewValuePool() {
Expand Down
67 changes: 67 additions & 0 deletions tests/scanner_valuer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"os"
"reflect"
"regexp"
"strconv"
Expand Down Expand Up @@ -65,6 +67,54 @@ func TestScannerValuer(t *testing.T) {
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
}

func TestScannerValuerArray(t *testing.T) {
// Use custom dialector to enable array handler
os.Setenv("GORM_DIALECT", "postgres")
os.Setenv("GORM_ENABLE_ARRAY_HANDLER", "true")
var err error
if DB, err = OpenTestConnection(&gorm.Config{}); err != nil {
log.Printf("failed to connect database, got error %v", err)
os.Exit(1)
}

DB.Migrator().DropTable(&ScannerValuerStructOfArrays{})
if err := DB.Migrator().AutoMigrate(&ScannerValuerStructOfArrays{}); err != nil {
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
}

data := ScannerValuerStructOfArrays{
StringArray: []string{"a", "b", "c"},
IntArray: []int{1, 2, 3},
Int8Array: []int8{1, 2, 3},
Int16Array: []int16{1, 2, 3},
Int32Array: []int32{1, 2, 3},
Int64Array: []int64{1, 2, 3},
UintArray: []uint{1, 2, 3},
Uint16Array: []uint16{1, 2, 3},
Uint32Array: []uint32{1, 2, 3},
Uint64Array: []uint64{1, 2, 3},
Float32Array: []float32{
1.1, 2.2, 3.3,
},
Float64Array: []float64{
1.1, 2.2, 3.3,
},
BoolArray: []bool{true, false, true},
}

if err := DB.Create(&data).Error; err != nil {
t.Fatalf("No error should happened when create scanner valuer struct, but got %v", err)
}

var result ScannerValuerStructOfArrays

if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil {
t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err)
}

AssertObjEqual(t, data, result, "StringArray", "IntArray", "Int8Array", "Int16Array", "Int32Array", "Int64Array", "UintArray", "Uint16Array", "Uint32Array", "Uint64Array", "Float32Array", "Float64Array", "BoolArray")
}

func TestScannerValuerWithFirstOrCreate(t *testing.T) {
DB.Migrator().DropTable(&ScannerValuerStruct{})
if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil {
Expand Down Expand Up @@ -162,6 +212,23 @@ type ScannerValuerStruct struct {
ExampleStructPtr *ExampleStruct
}

type ScannerValuerStructOfArrays struct {
gorm.Model
StringArray []string
IntArray []int
Int8Array []int8
Int16Array []int16
Int32Array []int32
Int64Array []int64
UintArray []uint
Uint16Array []uint16
Uint32Array []uint32
Uint64Array []uint64
Float32Array []float32
Float64Array []float64
BoolArray []bool
}

type EncryptedData []byte

func (data *EncryptedData) Scan(value interface{}) error {
Expand Down
2 changes: 2 additions & 0 deletions tests/tests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func init() {

func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
dbDSN := os.Getenv("GORM_DSN")
enableArrayHandler := os.Getenv("GORM_ENABLE_ARRAY_HANDLER")
switch os.Getenv("GORM_DIALECT") {
case "mysql":
log.Println("testing mysql...")
Expand All @@ -63,6 +64,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
db, err = gorm.Open(postgres.New(postgres.Config{
DSN: dbDSN,
PreferSimpleProtocol: true,
EnableArrayHandler: enableArrayHandler == "true",
}), cfg)
case "sqlserver":
// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
Expand Down

0 comments on commit 4465ee6

Please sign in to comment.