@@ -20,8 +20,9 @@ package at
20
20
import (
21
21
"context"
22
22
"database/sql/driver"
23
- "errors"
24
23
"fmt"
24
+ "github.com/arana-db/parser/model"
25
+ "seata.apache.org/seata-go/pkg/datasource/sql/util"
25
26
"strings"
26
27
27
28
"github.com/arana-db/parser/ast"
@@ -93,32 +94,37 @@ func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, e
93
94
return nil , nil
94
95
}
95
96
96
- tableName , _ := u .parserCtx .GetTableName ()
97
- metaData , err := datasource .GetTableCache (types .DBTypeMySQL ).GetTableMeta (ctx , u .execContext .DBName , tableName )
97
+ selectSQL , selectArgs , err := u .buildBeforeImageSQL (ctx , u .execContext .NamedValues )
98
98
if err != nil {
99
99
return nil , err
100
100
}
101
101
102
- selectSQL , selectArgs , err := u .buildBeforeImageSQL ( ctx , metaData , "" , u . execContext . NamedValues )
103
-
102
+ tableName , _ := u .parserCtx . GetTableName ( )
103
+ metaData , err := datasource . GetTableCache ( types . DBTypeMySQL ). GetTableMeta ( ctx , u . execContext . DBName , tableName )
104
104
if err != nil {
105
105
return nil , err
106
106
}
107
- if selectSQL == "" {
108
- return nil , errors .New ("build select sql by update sourceQuery fail" )
109
- }
110
107
111
- rowsi , err := u .rowsPrepare (ctx , u .execContext .Conn , selectSQL , selectArgs )
112
- defer func () {
113
- if rowsi != nil {
114
- if err := rowsi .Close (); err != nil {
115
- log .Errorf ("rows close fail, err:%v" , err )
116
- return
108
+ var rowsi driver.Rows
109
+ queryerCtx , ok := u .execContext .Conn .(driver.QueryerContext )
110
+ var queryer driver.Queryer
111
+ if ! ok {
112
+ queryer , ok = u .execContext .Conn .(driver.Queryer )
113
+ }
114
+ if ok {
115
+ rowsi , err = util .CtxDriverQuery (ctx , queryerCtx , queryer , selectSQL , selectArgs )
116
+ defer func () {
117
+ if rowsi != nil {
118
+ rowsi .Close ()
117
119
}
120
+ }()
121
+ if err != nil {
122
+ log .Errorf ("ctx driver query: %+v" , err )
123
+ return nil , err
118
124
}
119
- }()
120
- if err != nil {
121
- return nil , err
125
+ } else {
126
+ log . Errorf ( "target conn should been driver.QueryerContext or driver.Queryer" )
127
+ return nil , fmt . Errorf ( "invalid conn" )
122
128
}
123
129
124
130
image , err := u .buildRecordImages (rowsi , metaData , types .SQLTypeUpdate )
@@ -149,17 +155,26 @@ func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.Recor
149
155
}
150
156
selectSQL , selectArgs := u .buildAfterImageSQL (beforeImage , metaData )
151
157
152
- rowsi , err := u .rowsPrepare (ctx , u .execContext .Conn , selectSQL , selectArgs )
153
- defer func () {
154
- if rowsi != nil {
155
- if err := rowsi .Close (); err != nil {
156
- log .Errorf ("rows close fail, err:%v" , err )
157
- return
158
+ var rowsi driver.Rows
159
+ queryerCtx , ok := u .execContext .Conn .(driver.QueryerContext )
160
+ var queryer driver.Queryer
161
+ if ! ok {
162
+ queryer , ok = u .execContext .Conn .(driver.Queryer )
163
+ }
164
+ if ok {
165
+ rowsi , err = util .CtxDriverQuery (ctx , queryerCtx , queryer , selectSQL , selectArgs )
166
+ defer func () {
167
+ if rowsi != nil {
168
+ rowsi .Close ()
158
169
}
170
+ }()
171
+ if err != nil {
172
+ log .Errorf ("ctx driver query: %+v" , err )
173
+ return nil , err
159
174
}
160
- }()
161
- if err != nil {
162
- return nil , err
175
+ } else {
176
+ log . Errorf ( "target conn should been driver.QueryerContext or driver.Queryer" )
177
+ return nil , fmt . Errorf ( "invalid conn" )
163
178
}
164
179
165
180
afterImage , err := u .buildRecordImages (rowsi , metaData , types .SQLTypeUpdate )
@@ -201,19 +216,53 @@ func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta
201
216
}
202
217
203
218
// buildAfterImageSQL build the SQL to query before image data
204
- func (u * updateExecutor ) buildBeforeImageSQL (ctx context.Context , tableMeta * types. TableMeta , tableAliases string , args []driver.NamedValue ) (string , []driver.NamedValue , error ) {
219
+ func (u * updateExecutor ) buildBeforeImageSQL (ctx context.Context , args []driver.NamedValue ) (string , []driver.NamedValue , error ) {
205
220
if ! u .isAstStmtValid () {
206
221
log .Errorf ("invalid update stmt" )
207
222
return "" , nil , fmt .Errorf ("invalid update stmt" )
208
223
}
209
224
210
225
updateStmt := u .parserCtx .UpdateStmt
211
- fields , err := u .buildSelectFields (ctx , tableMeta , tableAliases , updateStmt .List )
212
- if err != nil {
213
- return "" , nil , err
214
- }
215
- if len (fields ) == 0 {
216
- return "" , nil , err
226
+ fields := make ([]* ast.SelectField , 0 , len (updateStmt .List ))
227
+
228
+ if undo .UndoConfig .OnlyCareUpdateColumns {
229
+ for _ , column := range updateStmt .List {
230
+ fields = append (fields , & ast.SelectField {
231
+ Expr : & ast.ColumnNameExpr {
232
+ Name : column .Column ,
233
+ },
234
+ })
235
+ }
236
+
237
+ // select indexes columns
238
+ tableName , _ := u .parserCtx .GetTableName ()
239
+ metaData , err := datasource .GetTableCache (types .DBTypeMySQL ).GetTableMeta (ctx , u .execContext .DBName , tableName )
240
+ if err != nil {
241
+ return "" , nil , err
242
+ }
243
+ for _ , columnName := range metaData .GetPrimaryKeyOnlyName () {
244
+ fields = append (fields , & ast.SelectField {
245
+ Expr : & ast.ColumnNameExpr {
246
+ Name : & ast.ColumnName {
247
+ Name : model.CIStr {
248
+ O : columnName ,
249
+ L : columnName ,
250
+ },
251
+ },
252
+ },
253
+ })
254
+ }
255
+ } else {
256
+ fields = append (fields , & ast.SelectField {
257
+ Expr : & ast.ColumnNameExpr {
258
+ Name : & ast.ColumnName {
259
+ Name : model.CIStr {
260
+ O : "*" ,
261
+ L : "*" ,
262
+ },
263
+ },
264
+ },
265
+ })
217
266
}
218
267
219
268
selStmt := ast.SelectStmt {
0 commit comments