diff --git a/contracts/database/orm/orm.go b/contracts/database/orm/orm.go index 0bada3091..37d8274f2 100644 --- a/contracts/database/orm/orm.go +++ b/contracts/database/orm/orm.go @@ -94,6 +94,8 @@ type Query interface { Order(value any) Query // OrWhere add an "or where" clause to the query. OrWhere(query any, args ...any) Query + // OrWhereIn adds an "or where column in" clause to the query. + OrWhereIn(column string, values []any) Query // Paginate the given query into a simple paginator. Paginate(page, limit int, dest any, total *int64) error // Pluck retrieves a single column from the database. diff --git a/database/gorm/query.go b/database/gorm/query.go index 1c4b7ff5d..5c8b12f04 100644 --- a/database/gorm/query.go +++ b/database/gorm/query.go @@ -668,6 +668,10 @@ func (r *QueryImpl) WhereIn(column string, values []any) ormcontract.Query { return r.Where(fmt.Sprintf("%s IN ?", column), values) } +func (r *QueryImpl) OrWhereIn(column string, values []any) ormcontract.Query { + return r.OrWhere(fmt.Sprintf("%s IN ?", column), values) +} + func (r *QueryImpl) WithoutEvents() ormcontract.Query { return NewQueryImplByInstance(r.instance, &QueryImpl{ config: r.config, diff --git a/database/gorm/query_test.go b/database/gorm/query_test.go index d37100f80..8143c979b 100644 --- a/database/gorm/query_test.go +++ b/database/gorm/query_test.go @@ -2725,6 +2725,24 @@ func (s *QueryTestSuite) TestWhereIn() { } } +func (s *QueryTestSuite) TestOrWhereIn() { + for driver, query := range s.queries { + s.Run(driver.String(), func() { + user := User{Name: "where_in_user", Avatar: "where_in_avatar"} + s.Nil(query.Create(&user)) + s.True(user.ID > 0) + + user1 := User{Name: "where_in_user_1", Avatar: "where_in_avatar_1"} + s.Nil(query.Create(&user1)) + s.True(user1.ID > 0) + + var users []User + s.Nil(query.Where("id = ?", -1).OrWhereIn("id", []any{user.ID, user1.ID}).Find(&users)) + s.True(len(users) == 2) + }) + } +} + func (s *QueryTestSuite) TestWithoutEvents() { for _, query := range s.queries { tests := []struct { diff --git a/mocks/database/orm/Query.go b/mocks/database/orm/Query.go index e6ded2f60..e7849ef71 100644 --- a/mocks/database/orm/Query.go +++ b/mocks/database/orm/Query.go @@ -544,6 +544,22 @@ func (_m *Query) OrWhere(query interface{}, args ...interface{}) orm.Query { return r0 } +// OrWhereIn provides a mock function with given fields: column, values +func (_m *Query) OrWhereIn(column string, values []interface{}) orm.Query { + ret := _m.Called(column, values) + + var r0 orm.Query + if rf, ok := ret.Get(0).(func(string, []interface{}) orm.Query); ok { + r0 = rf(column, values) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(orm.Query) + } + } + + return r0 +} + // Order provides a mock function with given fields: value func (_m *Query) Order(value interface{}) orm.Query { ret := _m.Called(value) diff --git a/mocks/database/orm/Transaction.go b/mocks/database/orm/Transaction.go index 884af9b4e..e97f1cb84 100644 --- a/mocks/database/orm/Transaction.go +++ b/mocks/database/orm/Transaction.go @@ -558,6 +558,22 @@ func (_m *Transaction) OrWhere(query interface{}, args ...interface{}) orm.Query return r0 } +// OrWhereIn provides a mock function with given fields: column, values +func (_m *Transaction) OrWhereIn(column string, values []interface{}) orm.Query { + ret := _m.Called(column, values) + + var r0 orm.Query + if rf, ok := ret.Get(0).(func(string, []interface{}) orm.Query); ok { + r0 = rf(column, values) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(orm.Query) + } + } + + return r0 +} + // Order provides a mock function with given fields: value func (_m *Transaction) Order(value interface{}) orm.Query { ret := _m.Called(value)