Skip to content

增加多库切换 #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
package constants

const (
Comma = ","
LeftBracket = "("
RightBracket = ")"
DefaultPrimaryName = "id"
Comma = ","
LeftBracket = "("
RightBracket = ")"
DefaultPrimaryName = "id"
DefaultGormPlusConnName = "DefaultGormPlusConnName" //内置的gorm-plus 数据库连接名
)
31 changes: 18 additions & 13 deletions gplus/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ var columnNameCache sync.Map
var modelInstanceCache sync.Map

// Cache 缓存实体对象所有的字段名
func Cache(models ...any) {
func Cache(opt Option, models ...any) {
db, _, _ := getDefaultDbByOpt(opt)
for _, model := range models {
columnNameMap := getColumnNameMap(model)
columnNameMap := getColumnNameMap(model, db.Config.NamingStrategy)
for pointer, columnName := range columnNameMap {
columnNameCache.Store(pointer, columnName)
}
Expand All @@ -43,7 +44,7 @@ func Cache(models ...any) {
}
}

func getColumnNameMap(model any) map[uintptr]string {
func getColumnNameMap(model any, namingStrategy schema.Namer) map[uintptr]string {
var columnNameMap = make(map[uintptr]string)
valueOf := reflect.ValueOf(model).Elem()
typeOf := reflect.TypeOf(model).Elem()
Expand All @@ -52,22 +53,23 @@ func getColumnNameMap(model any) map[uintptr]string {
// 如果当前实体嵌入了其他实体,同样需要缓存它的字段名
if field.Anonymous {
// 如果存在多重嵌套,通过递归方式获取他们的字段名
subFieldMap := getSubFieldColumnNameMap(valueOf, field)
subFieldMap := getSubFieldColumnNameMap(valueOf, field, namingStrategy)
for pointer, columnName := range subFieldMap {
columnNameMap[pointer] = columnName
}
} else {
// 获取对象字段指针值
pointer := valueOf.Field(i).Addr().Pointer()
columnName := parseColumnName(field)
columnName := parseColumnName(field, namingStrategy)
columnNameMap[pointer] = columnName
}
}
return columnNameMap
}

// GetModel 获取
func GetModel[T any]() *T {
func GetModel[T any](opts ...OptionFunc) *T {
opt := getDefaultOptionInfo(opts...) //兼容设计
modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String()
if model, ok := modelInstanceCache.Load(modelTypeStr); ok {
m, isReal := model.(*T)
Expand All @@ -76,12 +78,12 @@ func GetModel[T any]() *T {
}
}
t := new(T)
Cache(t)
Cache(opt, t)
return t
}

// 递归获取嵌套字段名
func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField) map[uintptr]string {
func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField, namingStrategy schema.Namer) map[uintptr]string {
result := make(map[uintptr]string)
modelType := field.Type
if modelType.Kind() == reflect.Ptr {
Expand All @@ -90,28 +92,31 @@ func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField)
for j := 0; j < modelType.NumField(); j++ {
subField := modelType.Field(j)
if subField.Anonymous {
nestedFields := getSubFieldColumnNameMap(valueOf, subField)
nestedFields := getSubFieldColumnNameMap(valueOf, subField, namingStrategy)
for key, value := range nestedFields {
result[key] = value
}
} else {
pointer := valueOf.FieldByName(modelType.Field(j).Name).Addr().Pointer()
name := parseColumnName(modelType.Field(j))
name := parseColumnName(modelType.Field(j), namingStrategy)
result[pointer] = name
}
}

return result
}

// 解析字段名称
func parseColumnName(field reflect.StructField) string {
// 解析字段名称 兼容多数据库切换,
// 如果用户使用Option的GetDb而没有传数据库连接名这边获取的namingStrategy 是默认的一个可能会有问题,
// 所以建议用户多数据库的时候弃用Option里的Db,并且重新改写初始化,给与每个db连接有连接名
// 并且改造下多数据使用NewQuery和GetModel和NewQueryModel相关方法传入数据库连接名
func parseColumnName(field reflect.StructField, namingStrategy schema.Namer) string {
tagSetting := schema.ParseTagSetting(field.Tag.Get("gorm"), ";")
name, ok := tagSetting["COLUMN"]
if ok {
return name
}
return globalDb.Config.NamingStrategy.ColumnName("", field.Name)
return namingStrategy.ColumnName("", field.Name)
}

func getColumnName(v any) string {
Expand Down
116 changes: 107 additions & 9 deletions gplus/dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package gplus

import (
"database/sql"
"errors"
"fmt"
"github.com/acmestack/gorm-plus/constants"
"gorm.io/gorm"
Expand All @@ -28,11 +29,38 @@ import (
"time"
)

var globalDb *gorm.DB
var globalDbMap = make(map[string]*gorm.DB)
var globalDbKeys []string
var defaultBatchSize = 1000

func Init(db *gorm.DB) {
globalDb = db
// Init 可选参数dbConnNameArr 代表数据库连接名,只需要传一个就行,
// 主要为了兼容之前用户只传一个db无需修改
func Init(db *gorm.DB, dbConnNameArr ...string) error {
var dbConnName = ""
if len(dbConnNameArr) > 0 {
dbConnName = dbConnNameArr[0]
}
return setGlobalInfo(db, dbConnName)
}

// InitMany 初始化多个
func InitMany(dic map[string]*gorm.DB) []error {
var errs []error
for k, v := range dic {
if err := setGlobalInfo(v, k); err != nil {
errs = append(errs, err)
}
}
return errs
}

// GetDb 获取数据库连接
func GetDb(dbConnName string) (*gorm.DB, error) {
db, exists := globalDbMap[dbConnName]
if exists {
return db, nil
}
return nil, errors.New("MultipleDbChange not exists dbConn:" + dbConnName + ",please check")
}

type Page[T any] struct {
Expand All @@ -45,8 +73,8 @@ type Page[T any] struct {

type Dao[T any] struct{}

func (dao Dao[T]) NewQuery() (*QueryCond[T], *T) {
return NewQuery[T]()
func (dao Dao[T]) NewQuery(opts ...OptionFunc) (*QueryCond[T], *T) {
return NewQuery[T](opts...)
}

func NewPage[T any](current, size int) *Page[T] {
Expand Down Expand Up @@ -157,7 +185,7 @@ func UpdateZeroById[T any](entity *T, opts ...OptionFunc) *gorm.DB {
func updateAllIfNeed(entity any, opts []OptionFunc, db *gorm.DB) {
option := getOption(opts)
if len(option.Selects) == 0 {
columnNameMap := getColumnNameMap(entity)
columnNameMap := getColumnNameMap(entity, db.Config.NamingStrategy)
var columnNames []string
for _, columnName := range columnNameMap {
columnNames = append(columnNames, columnName)
Expand Down Expand Up @@ -449,14 +477,21 @@ func buildSqlAndArgs[T any](expressions []any, sqlBuilder *strings.Builder, quer
}

func getDb(opts ...OptionFunc) *gorm.DB {
var db *gorm.DB
option := getOption(opts)
// Clauses()目的是为了初始化Db,如果db已经被初始化了,会直接返回db
var db = globalDb.Clauses()

if option.Db != nil {
db = option.Db.Clauses()
db = option.Db
} else {
db, option.DbConnName, _ = getDefaultDbByName(option.DbConnName)
}

//设置session,如果需要子句仅在当前会话生效,先调用 Session(),再调用 Clauses()。
setSessionIfNeed(option, db)

// Clauses()目的是为了初始化Db,如果db已经被初始化了,会直接返回db
db = db.Clauses()

// 设置需要忽略的字段
setOmitIfNeed(option, db)

Expand Down Expand Up @@ -496,6 +531,12 @@ func setOmitIfNeed(option Option, db *gorm.DB) {
}
}

func setSessionIfNeed(option Option, db *gorm.DB) {
if option.DbSession != nil {
db.Session(option.DbSession)
}
}

func getPkColumnName[T any]() string {
var entity T
entityType := reflect.TypeOf(entity)
Expand All @@ -520,3 +561,60 @@ func getPkColumnName[T any]() string {
}
return columnName
}

func getDefaultDbConnName() string {
dbConnName := constants.DefaultGormPlusConnName
//如果用户没传数据库连接名称,优先判断全局自定义的连接名是否存在,
//如果上面不存在其次从全局globalDbKeys里获取第一个连接名
//1.避免用户使用InitDb方法初始化数据库 自定义数据库连接名 ,然后方法里不传是哪个数据库连接名 则只能默认取第一条
//2.再混用单库Init取初始化,做方法兼容
_, exists := globalDbMap[dbConnName]
if exists {
return dbConnName
}
dbConnName = globalDbKeys[0]
return dbConnName
}

// 获取如果连接名为空则默认填充的option数据
func getDefaultOptionInfo(opts ...OptionFunc) Option {
option := getOption(opts)
if len(option.DbConnName) == 0 {
option.DbConnName = getDefaultDbConnName() //兼容之前设计
}
return option
}

func getDefaultDbByOpt(opt Option) (*gorm.DB, string, error) {
return getDefaultDbByName(opt.DbConnName)
}

func getDefaultDbByName(dbConnName string) (*gorm.DB, string, error) {
if len(dbConnName) == 0 {
dbConnName = getDefaultDbConnName()
}
db, err := GetDb(dbConnName)
return db, dbConnName, err
}

func setGlobalInfo(db *gorm.DB, dbConnName string) error {
if len(dbConnName) == 0 {
//return errors.New("InitMultiple dbConnName is empty please check")
//如果字典里不包含了默认名则使用默认名,兼容之前单库
_, exists := globalDbMap[constants.DefaultGormPlusConnName]
if exists {
//根据db指针地址获取作为连接名,因为GORM 本身不提供直接获取数据库连接地址的方法,也不推荐使用反射来获取dsn
dbConnName = fmt.Sprintf("%p", db)
} else {
dbConnName = constants.DefaultGormPlusConnName
}
}
_, exists := globalDbMap[dbConnName]
if !exists {
// db instance register to global variable
globalDbMap[dbConnName] = db
globalDbKeys = append(globalDbKeys, dbConnName)
return nil
}
return errors.New("InitMultiple have same name:" + dbConnName + ",please check")
}
17 changes: 14 additions & 3 deletions gplus/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@

package gplus

import "gorm.io/gorm"
import (
"gorm.io/gorm"
)

type Option struct {
Db *gorm.DB
Selects []any
Omits []any
IgnoreTotal bool
DbConnName string
DbSession *gorm.Session
}

type OptionFunc func(*Option)
Expand All @@ -35,10 +39,10 @@ func Db(db *gorm.DB) OptionFunc {
}
}

// Session 创建回话
// Session 创建会话
func Session(session *gorm.Session) OptionFunc {
return func(o *Option) {
o.Db = globalDb.Session(session)
o.DbSession = session //调整session 在dao类的getDb方法那边处理
}
}

Expand All @@ -62,3 +66,10 @@ func IgnoreTotal() OptionFunc {
o.IgnoreTotal = true
}
}

// DbConnName 多个数据库连接根据自定义连接名称选择切换
func DbConnName(dbConnName string) OptionFunc {
return func(o *Option) {
o.DbConnName = dbConnName
}
}
12 changes: 7 additions & 5 deletions gplus/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ func (q *QueryCond[T]) getSqlSegment() string {
}

// NewQuery 构建查询条件
func NewQuery[T any]() (*QueryCond[T], *T) {
func NewQuery[T any](opts ...OptionFunc) (*QueryCond[T], *T) {
opt := getDefaultOptionInfo(opts...) //兼容设计
q := &QueryCond[T]{}
modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String()
if model, ok := modelInstanceCache.Load(modelTypeStr); ok {
Expand All @@ -56,12 +57,13 @@ func NewQuery[T any]() (*QueryCond[T], *T) {
}
}
m := new(T)
Cache(m)
Cache(opt, m)
return q, m
}

// NewQueryModel 构建查询条件
func NewQueryModel[T any, R any]() (*QueryCond[T], *T, *R) {
func NewQueryModel[T any, R any](opts ...OptionFunc) (*QueryCond[T], *T, *R) {
opt := getDefaultOptionInfo(opts...) //兼容设计
q := &QueryCond[T]{}
var t *T
var r *R
Expand All @@ -83,12 +85,12 @@ func NewQueryModel[T any, R any]() (*QueryCond[T], *T, *R) {

if t == nil {
t = new(T)
Cache(t)
Cache(opt, t)
}

if r == nil {
r = new(R)
Cache(r)
Cache(opt, r)
}

return q, t, r
Expand Down
Loading
Loading