Skip to content

Commit 0774046

Browse files
committed
recover update executor
1 parent be59040 commit 0774046

File tree

3 files changed

+195
-156
lines changed

3 files changed

+195
-156
lines changed

pkg/datasource/sql/exec/at/update_executor.go

Lines changed: 82 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ package at
2020
import (
2121
"context"
2222
"database/sql/driver"
23-
"errors"
2423
"fmt"
24+
"github.com/arana-db/parser/model"
25+
"seata.apache.org/seata-go/pkg/datasource/sql/util"
2526
"strings"
2627

2728
"github.com/arana-db/parser/ast"
@@ -93,32 +94,37 @@ func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, e
9394
return nil, nil
9495
}
9596

96-
tableName, _ := u.parserCtx.GetTableName()
97-
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
97+
selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContext.NamedValues)
9898
if err != nil {
9999
return nil, err
100100
}
101101

102-
selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, "", u.execContext.NamedValues)
103-
102+
tableName, _ := u.parserCtx.GetTableName()
103+
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
104104
if err != nil {
105105
return nil, err
106106
}
107-
if selectSQL == "" {
108-
return nil, errors.New("build select sql by update sourceQuery fail")
109-
}
110107

111-
rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs)
112-
defer func() {
113-
if rowsi != nil {
114-
if err := rowsi.Close(); err != nil {
115-
log.Errorf("rows close fail, err:%v", err)
116-
return
108+
var rowsi driver.Rows
109+
queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext)
110+
var queryer driver.Queryer
111+
if !ok {
112+
queryer, ok = u.execContext.Conn.(driver.Queryer)
113+
}
114+
if ok {
115+
rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
116+
defer func() {
117+
if rowsi != nil {
118+
rowsi.Close()
117119
}
120+
}()
121+
if err != nil {
122+
log.Errorf("ctx driver query: %+v", err)
123+
return nil, err
118124
}
119-
}()
120-
if err != nil {
121-
return nil, err
125+
} else {
126+
log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
127+
return nil, fmt.Errorf("invalid conn")
122128
}
123129

124130
image, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
@@ -149,17 +155,26 @@ func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.Recor
149155
}
150156
selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData)
151157

152-
rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs)
153-
defer func() {
154-
if rowsi != nil {
155-
if err := rowsi.Close(); err != nil {
156-
log.Errorf("rows close fail, err:%v", err)
157-
return
158+
var rowsi driver.Rows
159+
queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext)
160+
var queryer driver.Queryer
161+
if !ok {
162+
queryer, ok = u.execContext.Conn.(driver.Queryer)
163+
}
164+
if ok {
165+
rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
166+
defer func() {
167+
if rowsi != nil {
168+
rowsi.Close()
158169
}
170+
}()
171+
if err != nil {
172+
log.Errorf("ctx driver query: %+v", err)
173+
return nil, err
159174
}
160-
}()
161-
if err != nil {
162-
return nil, err
175+
} else {
176+
log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
177+
return nil, fmt.Errorf("invalid conn")
163178
}
164179

165180
afterImage, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
@@ -201,19 +216,53 @@ func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta
201216
}
202217

203218
// buildAfterImageSQL build the SQL to query before image data
204-
func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, tableAliases string, args []driver.NamedValue) (string, []driver.NamedValue, error) {
219+
func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) {
205220
if !u.isAstStmtValid() {
206221
log.Errorf("invalid update stmt")
207222
return "", nil, fmt.Errorf("invalid update stmt")
208223
}
209224

210225
updateStmt := u.parserCtx.UpdateStmt
211-
fields, err := u.buildSelectFields(ctx, tableMeta, tableAliases, updateStmt.List)
212-
if err != nil {
213-
return "", nil, err
214-
}
215-
if len(fields) == 0 {
216-
return "", nil, err
226+
fields := make([]*ast.SelectField, 0, len(updateStmt.List))
227+
228+
if undo.UndoConfig.OnlyCareUpdateColumns {
229+
for _, column := range updateStmt.List {
230+
fields = append(fields, &ast.SelectField{
231+
Expr: &ast.ColumnNameExpr{
232+
Name: column.Column,
233+
},
234+
})
235+
}
236+
237+
// select indexes columns
238+
tableName, _ := u.parserCtx.GetTableName()
239+
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
240+
if err != nil {
241+
return "", nil, err
242+
}
243+
for _, columnName := range metaData.GetPrimaryKeyOnlyName() {
244+
fields = append(fields, &ast.SelectField{
245+
Expr: &ast.ColumnNameExpr{
246+
Name: &ast.ColumnName{
247+
Name: model.CIStr{
248+
O: columnName,
249+
L: columnName,
250+
},
251+
},
252+
},
253+
})
254+
}
255+
} else {
256+
fields = append(fields, &ast.SelectField{
257+
Expr: &ast.ColumnNameExpr{
258+
Name: &ast.ColumnName{
259+
Name: model.CIStr{
260+
O: "*",
261+
L: "*",
262+
},
263+
},
264+
},
265+
})
217266
}
218267

219268
selStmt := ast.SelectStmt{

pkg/datasource/sql/exec/at/update_executor_test.go

Lines changed: 3 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,20 @@ package at
2020
import (
2121
"context"
2222
"database/sql/driver"
23-
"os"
23+
"seata.apache.org/seata-go/pkg/datasource/sql/undo"
2424
"testing"
2525

2626
"github.com/stretchr/testify/assert"
2727

2828
"seata.apache.org/seata-go/pkg/datasource/sql/exec"
2929
"seata.apache.org/seata-go/pkg/datasource/sql/parser"
3030
"seata.apache.org/seata-go/pkg/datasource/sql/types"
31-
"seata.apache.org/seata-go/pkg/datasource/sql/undo"
3231
"seata.apache.org/seata-go/pkg/datasource/sql/util"
3332
_ "seata.apache.org/seata-go/pkg/util/log"
3433
)
3534

36-
var (
37-
MetaDataMap map[string]*types.TableMeta
38-
)
39-
40-
func initTest() {
41-
MetaDataMap = map[string]*types.TableMeta{
35+
func TestBuildSelectSQLByUpdate(t *testing.T) {
36+
MetaDataMap := map[string]*types.TableMeta{
4237
"t_user": {
4338
TableName: "t_user",
4439
Indexs: map[string]types.IndexMeta{
@@ -65,124 +60,9 @@ func initTest() {
6560
},
6661
ColumnNames: []string{"id", "name", "age"},
6762
},
68-
"table1": {
69-
TableName: "table1",
70-
Indexs: map[string]types.IndexMeta{
71-
"id": {
72-
IType: types.IndexTypePrimaryKey,
73-
Columns: []types.ColumnMeta{
74-
{ColumnName: "id"},
75-
},
76-
},
77-
},
78-
Columns: map[string]types.ColumnMeta{
79-
"id": {
80-
ColumnDef: nil,
81-
ColumnName: "id",
82-
},
83-
"name": {
84-
ColumnDef: nil,
85-
ColumnName: "name",
86-
},
87-
"age": {
88-
ColumnDef: nil,
89-
ColumnName: "age",
90-
},
91-
},
92-
ColumnNames: []string{"id", "name", "age"},
93-
},
94-
"table2": {
95-
TableName: "table2",
96-
Indexs: map[string]types.IndexMeta{
97-
"id": {
98-
IType: types.IndexTypePrimaryKey,
99-
Columns: []types.ColumnMeta{
100-
{ColumnName: "id"},
101-
},
102-
},
103-
},
104-
Columns: map[string]types.ColumnMeta{
105-
"id": {
106-
ColumnDef: nil,
107-
ColumnName: "id",
108-
},
109-
"name": {
110-
ColumnDef: nil,
111-
ColumnName: "name",
112-
},
113-
"age": {
114-
ColumnDef: nil,
115-
ColumnName: "age",
116-
},
117-
"kk": {
118-
ColumnDef: nil,
119-
ColumnName: "kk",
120-
},
121-
"addr": {
122-
ColumnDef: nil,
123-
ColumnName: "addr",
124-
},
125-
},
126-
ColumnNames: []string{"id", "name", "age", "kk", "addr"},
127-
},
128-
"table3": {
129-
TableName: "table3",
130-
Indexs: map[string]types.IndexMeta{
131-
"id": {
132-
IType: types.IndexTypePrimaryKey,
133-
Columns: []types.ColumnMeta{
134-
{ColumnName: "id"},
135-
},
136-
},
137-
},
138-
Columns: map[string]types.ColumnMeta{
139-
"id": {
140-
ColumnDef: nil,
141-
ColumnName: "id",
142-
},
143-
"age": {
144-
ColumnDef: nil,
145-
ColumnName: "age",
146-
},
147-
},
148-
ColumnNames: []string{"id", "age"},
149-
},
150-
"table4": {
151-
TableName: "table4",
152-
Indexs: map[string]types.IndexMeta{
153-
"id": {
154-
IType: types.IndexTypePrimaryKey,
155-
Columns: []types.ColumnMeta{
156-
{ColumnName: "id"},
157-
},
158-
},
159-
},
160-
Columns: map[string]types.ColumnMeta{
161-
"id": {
162-
ColumnDef: nil,
163-
ColumnName: "id",
164-
},
165-
"age": {
166-
ColumnDef: nil,
167-
ColumnName: "age",
168-
},
169-
},
170-
ColumnNames: []string{"id", "age"},
171-
},
17263
}
17364

17465
undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true})
175-
}
176-
177-
func TestMain(m *testing.M) {
178-
// 调用初始化函数
179-
initTest()
180-
181-
// 启动测试
182-
os.Exit(m.Run())
183-
}
184-
185-
func TestBuildSelectSQLByUpdate(t *testing.T) {
18666

18767
tests := []struct {
18868
name string

0 commit comments

Comments
 (0)