From 0774046a7f3ba4d33df4ebb36ae5f7b511d41217 Mon Sep 17 00:00:00 2001 From: lxfeng1997 <824141436@qq.com> Date: Fri, 10 Jan 2025 19:22:34 +0800 Subject: [PATCH] recover update executor --- pkg/datasource/sql/exec/at/update_executor.go | 115 +++++++++++----- .../sql/exec/at/update_executor_test.go | 126 +----------------- .../sql/exec/at/update_join_executor_test.go | 110 +++++++++++++++ 3 files changed, 195 insertions(+), 156 deletions(-) diff --git a/pkg/datasource/sql/exec/at/update_executor.go b/pkg/datasource/sql/exec/at/update_executor.go index 906fb15a..0f14e97b 100644 --- a/pkg/datasource/sql/exec/at/update_executor.go +++ b/pkg/datasource/sql/exec/at/update_executor.go @@ -20,8 +20,9 @@ package at import ( "context" "database/sql/driver" - "errors" "fmt" + "github.com/arana-db/parser/model" + "seata.apache.org/seata-go/pkg/datasource/sql/util" "strings" "github.com/arana-db/parser/ast" @@ -93,32 +94,37 @@ func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, e return nil, nil } - tableName, _ := u.parserCtx.GetTableName() - metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) + selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContext.NamedValues) if err != nil { return nil, err } - selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, "", u.execContext.NamedValues) - + tableName, _ := u.parserCtx.GetTableName() + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) if err != nil { return nil, err } - if selectSQL == "" { - return nil, errors.New("build select sql by update sourceQuery fail") - } - rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs) - defer func() { - if rowsi != nil { - if err := rowsi.Close(); err != nil { - log.Errorf("rows close fail, err:%v", err) - return + var rowsi driver.Rows + queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = u.execContext.Conn.(driver.Queryer) + } + if ok { + rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) + defer func() { + if rowsi != nil { + rowsi.Close() } + }() + if err != nil { + log.Errorf("ctx driver query: %+v", err) + return nil, err } - }() - if err != nil { - return nil, err + } else { + log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") + return nil, fmt.Errorf("invalid conn") } image, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) @@ -149,17 +155,26 @@ func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.Recor } selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData) - rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs) - defer func() { - if rowsi != nil { - if err := rowsi.Close(); err != nil { - log.Errorf("rows close fail, err:%v", err) - return + var rowsi driver.Rows + queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = u.execContext.Conn.(driver.Queryer) + } + if ok { + rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) + defer func() { + if rowsi != nil { + rowsi.Close() } + }() + if err != nil { + log.Errorf("ctx driver query: %+v", err) + return nil, err } - }() - if err != nil { - return nil, err + } else { + log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") + return nil, fmt.Errorf("invalid conn") } afterImage, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) @@ -201,19 +216,53 @@ func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta } // buildAfterImageSQL build the SQL to query before image data -func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, tableAliases string, args []driver.NamedValue) (string, []driver.NamedValue, error) { +func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) { if !u.isAstStmtValid() { log.Errorf("invalid update stmt") return "", nil, fmt.Errorf("invalid update stmt") } updateStmt := u.parserCtx.UpdateStmt - fields, err := u.buildSelectFields(ctx, tableMeta, tableAliases, updateStmt.List) - if err != nil { - return "", nil, err - } - if len(fields) == 0 { - return "", nil, err + fields := make([]*ast.SelectField, 0, len(updateStmt.List)) + + if undo.UndoConfig.OnlyCareUpdateColumns { + for _, column := range updateStmt.List { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: column.Column, + }, + }) + } + + // select indexes columns + tableName, _ := u.parserCtx.GetTableName() + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) + if err != nil { + return "", nil, err + } + for _, columnName := range metaData.GetPrimaryKeyOnlyName() { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.CIStr{ + O: columnName, + L: columnName, + }, + }, + }, + }) + } + } else { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.CIStr{ + O: "*", + L: "*", + }, + }, + }, + }) } selStmt := ast.SelectStmt{ diff --git a/pkg/datasource/sql/exec/at/update_executor_test.go b/pkg/datasource/sql/exec/at/update_executor_test.go index 103358d1..34173aaf 100644 --- a/pkg/datasource/sql/exec/at/update_executor_test.go +++ b/pkg/datasource/sql/exec/at/update_executor_test.go @@ -20,7 +20,7 @@ package at import ( "context" "database/sql/driver" - "os" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" "testing" "github.com/stretchr/testify/assert" @@ -28,17 +28,12 @@ import ( "seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/parser" "seata.apache.org/seata-go/pkg/datasource/sql/types" - "seata.apache.org/seata-go/pkg/datasource/sql/undo" "seata.apache.org/seata-go/pkg/datasource/sql/util" _ "seata.apache.org/seata-go/pkg/util/log" ) -var ( - MetaDataMap map[string]*types.TableMeta -) - -func initTest() { - MetaDataMap = map[string]*types.TableMeta{ +func TestBuildSelectSQLByUpdate(t *testing.T) { + MetaDataMap := map[string]*types.TableMeta{ "t_user": { TableName: "t_user", Indexs: map[string]types.IndexMeta{ @@ -65,124 +60,9 @@ func initTest() { }, ColumnNames: []string{"id", "name", "age"}, }, - "table1": { - TableName: "table1", - Indexs: map[string]types.IndexMeta{ - "id": { - IType: types.IndexTypePrimaryKey, - Columns: []types.ColumnMeta{ - {ColumnName: "id"}, - }, - }, - }, - Columns: map[string]types.ColumnMeta{ - "id": { - ColumnDef: nil, - ColumnName: "id", - }, - "name": { - ColumnDef: nil, - ColumnName: "name", - }, - "age": { - ColumnDef: nil, - ColumnName: "age", - }, - }, - ColumnNames: []string{"id", "name", "age"}, - }, - "table2": { - TableName: "table2", - Indexs: map[string]types.IndexMeta{ - "id": { - IType: types.IndexTypePrimaryKey, - Columns: []types.ColumnMeta{ - {ColumnName: "id"}, - }, - }, - }, - Columns: map[string]types.ColumnMeta{ - "id": { - ColumnDef: nil, - ColumnName: "id", - }, - "name": { - ColumnDef: nil, - ColumnName: "name", - }, - "age": { - ColumnDef: nil, - ColumnName: "age", - }, - "kk": { - ColumnDef: nil, - ColumnName: "kk", - }, - "addr": { - ColumnDef: nil, - ColumnName: "addr", - }, - }, - ColumnNames: []string{"id", "name", "age", "kk", "addr"}, - }, - "table3": { - TableName: "table3", - Indexs: map[string]types.IndexMeta{ - "id": { - IType: types.IndexTypePrimaryKey, - Columns: []types.ColumnMeta{ - {ColumnName: "id"}, - }, - }, - }, - Columns: map[string]types.ColumnMeta{ - "id": { - ColumnDef: nil, - ColumnName: "id", - }, - "age": { - ColumnDef: nil, - ColumnName: "age", - }, - }, - ColumnNames: []string{"id", "age"}, - }, - "table4": { - TableName: "table4", - Indexs: map[string]types.IndexMeta{ - "id": { - IType: types.IndexTypePrimaryKey, - Columns: []types.ColumnMeta{ - {ColumnName: "id"}, - }, - }, - }, - Columns: map[string]types.ColumnMeta{ - "id": { - ColumnDef: nil, - ColumnName: "id", - }, - "age": { - ColumnDef: nil, - ColumnName: "age", - }, - }, - ColumnNames: []string{"id", "age"}, - }, } undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true}) -} - -func TestMain(m *testing.M) { - // 调用初始化函数 - initTest() - - // 启动测试 - os.Exit(m.Run()) -} - -func TestBuildSelectSQLByUpdate(t *testing.T) { tests := []struct { name string diff --git a/pkg/datasource/sql/exec/at/update_join_executor_test.go b/pkg/datasource/sql/exec/at/update_join_executor_test.go index 0f396ddb..0ef30da3 100644 --- a/pkg/datasource/sql/exec/at/update_join_executor_test.go +++ b/pkg/datasource/sql/exec/at/update_join_executor_test.go @@ -20,6 +20,7 @@ package at import ( "context" "database/sql/driver" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" "testing" "github.com/stretchr/testify/assert" @@ -32,6 +33,115 @@ import ( ) func TestBuildSelectSQLByUpdateJoin(t *testing.T) { + MetaDataMap := map[string]*types.TableMeta{ + "table1": { + TableName: "table1", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "name": { + ColumnDef: nil, + ColumnName: "name", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + }, + ColumnNames: []string{"id", "name", "age"}, + }, + "table2": { + TableName: "table2", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "name": { + ColumnDef: nil, + ColumnName: "name", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + "kk": { + ColumnDef: nil, + ColumnName: "kk", + }, + "addr": { + ColumnDef: nil, + ColumnName: "addr", + }, + }, + ColumnNames: []string{"id", "name", "age", "kk", "addr"}, + }, + "table3": { + TableName: "table3", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + }, + ColumnNames: []string{"id", "age"}, + }, + "table4": { + TableName: "table4", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + }, + ColumnNames: []string{"id", "age"}, + }, + } + + undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true}) + tests := []struct { name string sourceQuery string