Skip to content

Commit 6af1cd0

Browse files
Merge pull request #9 from DoWithLogic/feat/dbtx
chore: database transaction
2 parents b6daf76 + c1e5bd9 commit 6af1cd0

File tree

7 files changed

+82
-91
lines changed

7 files changed

+82
-91
lines changed

internal/app/service.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ func (app *App) StartService() error {
1111
userRepo := userRepository.NewRepository(app.DB, app.Log)
1212

1313
// define usecase
14-
userUC := userUseCase.NewUseCase(userRepo, app.DB, app.Log)
14+
userUC := userUseCase.NewUseCase(userRepo, app.Log)
1515

1616
// define controllers
1717
userCTRL := userV1.NewHandlers(userUC, app.Log)

internal/users/entities/users.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ type (
1919
}
2020

2121
LockingOpt struct {
22-
ForUpdateNoWait bool
23-
ForUpdate bool
22+
PessimisticLocking bool
2423
}
2524
)
2625

internal/users/mock/repository_mock.go

+16
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/users/repository/repository.go

+35-14
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,55 @@ package repository
22

33
import (
44
"context"
5+
"database/sql"
56

67
"github.com/DoWithLogic/golang-clean-architecture/internal/users/entities"
78
"github.com/DoWithLogic/golang-clean-architecture/internal/users/repository/repository_query"
89
"github.com/DoWithLogic/golang-clean-architecture/pkg/datasource"
910
"github.com/DoWithLogic/golang-clean-architecture/pkg/otel/zerolog"
1011
"github.com/DoWithLogic/golang-clean-architecture/pkg/utils"
12+
"github.com/jmoiron/sqlx"
1113
)
1214

1315
type (
1416
Repository interface {
17+
Atomic(ctx context.Context, opt *sql.TxOptions, repo func(tx Repository) error) error
18+
1519
SaveNewUser(context.Context, entities.Users) (int64, error)
1620
UpdateUserByID(context.Context, entities.UpdateUsers) error
1721
GetUserByID(context.Context, int64, ...entities.LockingOpt) (entities.Users, error)
1822
UpdateUserStatusByID(context.Context, entities.UpdateUserStatus) error
1923
}
2024

2125
repository struct {
22-
conn datasource.SQLTxConn
26+
db *sqlx.DB
27+
conn datasource.ConnTx
2328
log *zerolog.Logger
2429
}
2530
)
2631

27-
func NewRepository(conn datasource.SQLTxConn, log *zerolog.Logger) Repository {
28-
return &repository{conn, log}
32+
func NewRepository(c *sqlx.DB, l *zerolog.Logger) Repository {
33+
return &repository{conn: c, log: l, db: c}
34+
}
35+
36+
// Atomic implements vendor.Repository for transaction query
37+
func (r *repository) Atomic(ctx context.Context, opt *sql.TxOptions, repo func(tx Repository) error) error {
38+
txConn, err := r.db.BeginTx(ctx, opt)
39+
if err != nil {
40+
r.log.Z().Err(err).Msg("[repository]Atomic.BeginTxx")
41+
42+
return err
43+
}
44+
45+
newRepository := &repository{conn: txConn, db: r.db}
46+
47+
repo(newRepository)
48+
49+
if err := new(datasource.DataSource).EndTx(txConn, err); err != nil {
50+
return err
51+
}
52+
53+
return nil
2954
}
3055

3156
func (repo *repository) SaveNewUser(ctx context.Context, user entities.Users) (int64, error) {
@@ -39,7 +64,7 @@ func (repo *repository) SaveNewUser(ctx context.Context, user entities.Users) (i
3964
}
4065

4166
var userID int64
42-
err := new(datasource.SQL).Exec(repo.conn.ExecContext(ctx, repository_query.InsertUsers, args...)).Scan(nil, &userID)
67+
err := new(datasource.DataSource).ExecSQL(repo.conn.ExecContext(ctx, repository_query.InsertUsers, args...)).Scan(nil, &userID)
4368
if err != nil {
4469
repo.log.Z().Err(err).Msg("[repository]SaveNewUser.ExecContext")
4570

@@ -59,7 +84,7 @@ func (repo *repository) UpdateUserByID(ctx context.Context, user entities.Update
5984
user.UserID,
6085
}
6186

62-
err := new(datasource.SQL).Exec(repo.conn.ExecContext(ctx, repository_query.UpdateUsers, args...)).Scan(nil, nil)
87+
err := new(datasource.DataSource).ExecSQL(repo.conn.ExecContext(ctx, repository_query.UpdateUsers, args...)).Scan(nil, nil)
6388
if err != nil {
6489
repo.log.Z().Err(err).Msg("[repository]UpdateUserByID.ExecContext")
6590

@@ -69,7 +94,7 @@ func (repo *repository) UpdateUserByID(ctx context.Context, user entities.Update
6994
return nil
7095
}
7196

72-
func (repo *repository) GetUserByID(ctx context.Context, userID int64, lockOpt ...entities.LockingOpt) (userData entities.Users, err error) {
97+
func (repo *repository) GetUserByID(ctx context.Context, userID int64, options ...entities.LockingOpt) (userData entities.Users, err error) {
7398
args := utils.Array{
7499
userID,
75100
}
@@ -87,15 +112,11 @@ func (repo *repository) GetUserByID(ctx context.Context, userID int64, lockOpt .
87112

88113
query := repository_query.GetUserByID
89114

90-
if len(lockOpt) >= 1 {
91-
if lockOpt[0].ForUpdate {
92-
query += " FOR UPDATE;"
93-
} else {
94-
query += " FOR UPDATE NO WAIT;"
95-
}
115+
if len(options) >= 1 && options[0].PessimisticLocking {
116+
query += " FOR UPDATE"
96117
}
97118

98-
if err = new(datasource.SQL).Query(repo.conn.QueryContext(ctx, query, args...)).Scan(row); err != nil {
119+
if err = new(datasource.DataSource).QuerySQL(repo.conn.QueryContext(ctx, query, args...)).Scan(row); err != nil {
99120
repo.log.Z().Err(err).Msg("[repository]GetUserByID.QueryContext")
100121
return userData, err
101122
}
@@ -112,7 +133,7 @@ func (repo *repository) UpdateUserStatusByID(ctx context.Context, req entities.U
112133
}
113134

114135
var updatedID int64
115-
err := new(datasource.SQL).Exec(repo.conn.ExecContext(ctx, repository_query.UpdateUserStatusByID, args...)).Scan(nil, &updatedID)
136+
err := new(datasource.DataSource).ExecSQL(repo.conn.ExecContext(ctx, repository_query.UpdateUserStatusByID, args...)).Scan(nil, &updatedID)
116137
if err != nil {
117138
repo.log.Z().Err(err).Msg("[repository]UpdateUserStatusByID.ExecContext")
118139

internal/users/usecase/usecase.go

+9-21
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@ package usecase
22

33
import (
44
"context"
5+
"database/sql"
56

67
"github.com/DoWithLogic/golang-clean-architecture/internal/users/entities"
78
"github.com/DoWithLogic/golang-clean-architecture/internal/users/repository"
8-
"github.com/DoWithLogic/golang-clean-architecture/pkg/datasource"
99
"github.com/DoWithLogic/golang-clean-architecture/pkg/otel/zerolog"
10-
"github.com/jmoiron/sqlx"
1110
)
1211

1312
type (
@@ -19,13 +18,12 @@ type (
1918

2019
usecase struct {
2120
repo repository.Repository
22-
dbTx *sqlx.DB
2321
log *zerolog.Logger
2422
}
2523
)
2624

27-
func NewUseCase(repo repository.Repository, txConn *sqlx.DB, log *zerolog.Logger) Usecase {
28-
return &usecase{repo, txConn, log}
25+
func NewUseCase(repo repository.Repository, log *zerolog.Logger) Usecase {
26+
return &usecase{repo, log}
2927
}
3028

3129
func (uc *usecase) CreateUser(ctx context.Context, payload entities.CreateUser) (int64, error) {
@@ -40,32 +38,22 @@ func (uc *usecase) CreateUser(ctx context.Context, payload entities.CreateUser)
4038
}
4139

4240
func (uc *usecase) UpdateUser(ctx context.Context, updateData entities.UpdateUsers) error {
43-
return func(dbTx *sqlx.DB) error {
44-
txConn, err := uc.dbTx.BeginTx(ctx, nil)
45-
if err != nil {
46-
return err
47-
}
48-
49-
defer func() {
50-
if err := new(datasource.SQL).EndTx(txConn, err); err != nil {
51-
return
52-
}
53-
}()
41+
return uc.repo.Atomic(ctx, &sql.TxOptions{}, func(tx repository.Repository) error {
5442

55-
repoTx := repository.NewRepository(txConn, uc.log)
56-
57-
if _, err := repoTx.GetUserByID(ctx, updateData.UserID, entities.LockingOpt{ForUpdate: true}); err != nil {
43+
if _, err := tx.GetUserByID(ctx, updateData.UserID, entities.LockingOpt{PessimisticLocking: true}); err != nil {
5844
uc.log.Z().Err(err).Msg("[usecase]UpdateUser.GetUserByID")
45+
5946
return err
6047
}
6148

62-
if err = repoTx.UpdateUserByID(ctx, entities.NewUpdateUsers(updateData)); err != nil {
49+
if err := tx.UpdateUserByID(ctx, entities.NewUpdateUsers(updateData)); err != nil {
6350
uc.log.Z().Err(err).Msg("[usecase]UpdateUser.UpdateUserByID")
51+
6452
return err
6553
}
6654

6755
return nil
68-
}(uc.dbTx)
56+
})
6957
}
7058

7159
func (uc *usecase) UpdateUserStatus(ctx context.Context, req entities.UpdateUserStatus) error {

internal/users/usecase/usecase_test.go

+1-6
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ func Test_usecase_CreateUser(t *testing.T) {
4949
repo := mocks.NewMockRepository(ctrl)
5050
uc := usecase.NewUseCase(
5151
repo,
52-
nil,
5352
zerolog.NewZeroLog(ctx, os.Stdout),
5453
)
5554

@@ -104,11 +103,7 @@ func Test_usecase_UpdateUserStatus(t *testing.T) {
104103

105104
ctx := context.Background()
106105
repo := mocks.NewMockRepository(ctrl)
107-
uc := usecase.NewUseCase(
108-
repo,
109-
nil,
110-
zerolog.NewZeroLog(ctx, os.Stdout),
111-
)
106+
uc := usecase.NewUseCase(repo, zerolog.NewZeroLog(ctx, os.Stdout))
112107

113108
args := entities.UpdateUserStatus{
114109
UserID: 1,

pkg/datasource/sql.go

+19-47
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,17 @@ import (
1313
)
1414

1515
type (
16-
BeginTx interface {
16+
Conn interface {
1717
BeginTx(ctx context.Context, opts *sql.TxOptions) (tx *sql.Tx, err error)
18-
}
19-
ExecContext interface {
20-
ExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result, err error)
21-
}
22-
PingContext interface {
2318
PingContext(ctx context.Context) (err error)
19+
io.Closer
20+
ConnTx
2421
}
25-
PrepareContext interface {
22+
23+
ConnTx interface {
24+
ExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result, err error)
2625
PrepareContext(ctx context.Context, query string) (stmt *sql.Stmt, err error)
27-
}
28-
QueryContext interface {
2926
QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error)
30-
}
31-
QueryRowContext interface {
3227
QueryRowContext(ctx context.Context, query string, args ...interface{}) (row *sql.Row)
3328
}
3429

@@ -37,11 +32,6 @@ type (
3732
}
3833

3934
Query interface {
40-
// Scan accept do, a func that accept `i int` as index and returns a List
41-
// of pointer.
42-
// List == nil // break the loop
43-
// len(List) < 1 // skip the current loop
44-
// len(List) > 0 // assign the pointer, must be same as the length of columns
4535
Scan(row func(i int) utils.Array) (err error)
4636
}
4737

@@ -55,30 +45,17 @@ type (
5545
err error
5646
}
5747

58-
SQLConn interface {
59-
BeginTx
60-
io.Closer
61-
PingContext
62-
SQLTxConn
63-
}
64-
65-
SQLTxConn interface {
66-
ExecContext
67-
PrepareContext
68-
QueryContext
69-
QueryRowContext
70-
}
71-
72-
SQL struct{}
48+
DataSource struct{}
7349
)
7450

7551
var (
76-
_ SQLConn = (*sql.Conn)(nil)
77-
_ SQLConn = (*sql.DB)(nil)
78-
_ SQLTxConn = (*sql.Tx)(nil)
79-
log = zerolog.NewZeroLog(context.Background(), os.Stdout)
52+
_ Conn = (*sql.Conn)(nil)
53+
_ Conn = (*sql.DB)(nil)
54+
_ ConnTx = (*sql.Tx)(nil)
55+
log = zerolog.NewZeroLog(context.Background(), os.Stdout)
8056
)
8157

58+
// datasource errors
8259
var (
8360
ErrNoColumnReturned = errors.New("no columns returned")
8461
ErrDataNotFound = errors.New("data not found")
@@ -193,20 +170,15 @@ func (x query) Scan(row func(i int) utils.Array) error {
193170
return err
194171
}
195172

196-
func (SQL) Exec(sqlResult sql.Result, err error) Exec { return exec{sqlResult, err} }
173+
func (DataSource) ExecSQL(sqlResult sql.Result, err error) exec {
174+
return exec{sqlResult, err}
175+
}
197176

198-
func (SQL) Query(sqlRows *sql.Rows, err error) Query { return query{sqlRows, err} }
177+
func (DataSource) QuerySQL(sqlRows *sql.Rows, err error) Query {
178+
return query{sqlRows, err}
179+
}
199180

200-
// EndTx will end transaction with provided *sql.Tx and error. The tx argument
201-
// should be valid, and then will check the err, if any error occurred, will
202-
// commencing the ROLLBACK else will COMMIT the transaction.
203-
//
204-
// txc := XSQLTxConn(db) // shared between *sql.Tx, *sql.DB and *sql.Conn
205-
// if tx, err := db.BeginTx(ctx, nil); err == nil && tx != nil {
206-
// defer func() { err = xsql.EndTx(tx, err) }()
207-
// txc = tx
208-
// }
209-
func (SQL) EndTx(tx *sql.Tx, err error) error {
181+
func (DataSource) EndTx(tx *sql.Tx, err error) error {
210182
if tx == nil {
211183
log.Z().Err(ErrInvalidTransaction).Msg("[database:EndTx]")
212184

0 commit comments

Comments
 (0)