diff --git a/driver/e2e_test.go b/driver/e2e_test.go deleted file mode 100644 index cd9f3975e4..0000000000 --- a/driver/e2e_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package driver_test - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestQuery(t *testing.T) { - mtb, records := personMemTable("db", "person") - db := sqlOpen(t, mtb, t.Name()+"?jsonAs=object") - - var name, email string - var numbers interface{} - var created time.Time - var count int - - cases := []struct { - Name, Query string - Pointers Pointers - Expect Records - }{ - {"Select All", "SELECT * FROM db.person", []V{&name, &email, &numbers, &created}, records}, - {"Select First", "SELECT * FROM db.person LIMIT 1", []V{&name, &email, &numbers, &created}, records.Rows(0)}, - {"Select Name", "SELECT name FROM db.person", []V{&name}, records.Columns(0)}, - {"Select Count", "SELECT COUNT(1) FROM db.person", []V{&count}, Records{{len(records)}}}, - - {"Insert", `INSERT INTO db.person VALUES ('foo', 'bar', '["baz"]', NOW())`, []V{}, Records{}}, - {"Select Inserted", "SELECT name, email, phone_numbers FROM db.person WHERE name = 'foo'", []V{&name, &email, &numbers}, Records{{"foo", "bar", []V{"baz"}}}}, - - {"Update", "UPDATE db.person SET name = 'asdf' WHERE name = 'foo'", []V{}, Records{}}, - {"Delete", "DELETE FROM db.person WHERE name = 'asdf'", []V{}, Records{}}, - } - - for _, c := range cases { - t.Run(c.Name, func(t *testing.T) { - rows, err := db.Query(c.Query) - require.NoError(t, err, "Query") - - var i int - for ; rows.Next(); i++ { - require.NoError(t, rows.Scan(c.Pointers...), "Scan") - values := c.Pointers.Values() - - if i >= len(c.Expect) { - t.Errorf("Got row %d, expected %d total: %v", i+1, len(c.Expect), values) - continue - } - - assert.EqualValues(t, c.Expect[i], values, "Values") - } - - require.NoError(t, rows.Err(), "Rows.Err") - - if i < len(c.Expect) { - t.Errorf("Expected %d row(s), got %d", len(c.Expect), i) - } - }) - } -} - -func TestExec(t *testing.T) { - mtb, records := personMemTable("db", "person") - db := sqlOpen(t, mtb, t.Name()) - - cases := []struct { - Name, Statement string - RowsAffected int - }{ - {"Insert", `INSERT INTO db.person VALUES ('asdf', 'qwer', '["zxcv"]', NOW())`, 1}, - {"Update", "UPDATE db.person SET name = 'foo' WHERE name = 'asdf'", 1}, - {"Delete", "DELETE FROM db.person WHERE name = 'foo'", 1}, - {"Delete All", "DELETE FROM db.person WHERE LENGTH(name) < 100", len(records)}, - } - - for _, c := range cases { - t.Run(c.Name, func(t *testing.T) { - res, err := db.Exec(c.Statement) - require.NoError(t, err, "Exec") - - count, err := res.RowsAffected() - require.NoError(t, err, "RowsAffected") - assert.EqualValues(t, c.RowsAffected, count, "RowsAffected") - }) - } - - errCases := []struct { - Name, Statement string - Error string - }{ - {"Select", "SELECT * FROM db.person", "no result"}, - } - - for _, c := range errCases { - t.Run(c.Name, func(t *testing.T) { - res, err := db.Exec(c.Statement) - require.NoError(t, err, "Exec") - - _, err = res.RowsAffected() - require.Error(t, err, "RowsAffected") - assert.Equal(t, c.Error, err.Error()) - }) - } -} diff --git a/driver/fixtures_test.go b/driver/fixtures_test.go deleted file mode 100644 index 3da368b08e..0000000000 --- a/driver/fixtures_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package driver_test - -import ( - "sync" - "time" - - "github.com/dolthub/go-mysql-server/driver" - "github.com/dolthub/go-mysql-server/memory" - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/information_schema" -) - -type memTable struct { - DatabaseName string - TableName string - Schema sql.PrimaryKeySchema - Records Records - - once sync.Once - dbProvider sql.DatabaseProvider -} - -func (f *memTable) Resolve(name string, _ *driver.Options) (string, sql.DatabaseProvider, error) { - f.once.Do(func() { - database := memory.NewDatabase(f.DatabaseName) - - table := memory.NewTable(f.TableName, f.Schema, database.GetForeignKeyCollection()) - - if f.Records != nil { - ctx := sql.NewEmptyContext() - for _, row := range f.Records { - table.Insert(ctx, sql.NewRow(row...)) - } - } - - database.AddTable(f.TableName, table) - - pro := memory.NewMemoryDBProvider( - database, - information_schema.NewInformationSchemaDatabase()) - f.dbProvider = pro - }) - - return name, f.dbProvider, nil -} - -func personMemTable(database, table string) (*memTable, Records) { - records := Records{ - []V{"John Doe", "john@doe.com", []V{"555-555-555"}, time.Now()}, - []V{"John Doe", "johnalt@doe.com", []V{}, time.Now()}, - []V{"Jane Doe", "jane@doe.com", []V{}, time.Now()}, - []V{"Evil Bob", "evilbob@gmail.com", []V{"555-666-555", "666-666-666"}, time.Now()}, - } - - mtb := &memTable{ - DatabaseName: database, - TableName: table, - Schema: sql.NewPrimaryKeySchema(sql.Schema{ - {Name: "name", Type: sql.Text, Nullable: false, Source: table}, - {Name: "email", Type: sql.Text, Nullable: false, Source: table}, - {Name: "phone_numbers", Type: sql.JSON, Nullable: false, Source: table}, - {Name: "created_at", Type: sql.Timestamp, Nullable: false, Source: table}, - }), - Records: records, - } - - return mtb, records -} diff --git a/driver/helpers_test.go b/driver/helpers_test.go deleted file mode 100644 index 5f796dc008..0000000000 --- a/driver/helpers_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package driver_test - -import ( - "database/sql" - "reflect" - "sync" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/dolthub/go-mysql-server/driver" -) - -type V = interface{} - -var driverMu sync.Mutex -var drivers = map[driver.Provider]*driver.Driver{} - -func sqlOpen(t *testing.T, provider driver.Provider, dsn string) *sql.DB { - driverMu.Lock() - drv, ok := drivers[provider] - if !ok { - drv = driver.New(provider, nil) - drivers[provider] = drv - } - driverMu.Unlock() - - conn, err := drv.OpenConnector(dsn) - require.NoError(t, err) - return sql.OpenDB(conn) -} - -type Pointers []V - -func (ptrs Pointers) Values() []V { - values := make([]V, len(ptrs)) - for i := range values { - values[i] = reflect.ValueOf(ptrs[i]).Elem().Interface() - } - return values -} - -type Records [][]V - -func (records Records) Rows(rows ...int) Records { - result := make(Records, len(rows)) - - for i := range rows { - result[i] = records[rows[i]] - } - - return result -} - -func (records Records) Columns(cols ...int) Records { - result := make(Records, len(records)) - - for i := range records { - result[i] = make([]V, len(cols)) - for j := range cols { - result[i][j] = records[i][cols[j]] - } - } - - return result -} diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 7656860b12..e5d0bdbef9 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -475,12 +475,6 @@ func TestLoadDataPrepared(t *testing.T) { } func TestScriptsPrepared(t *testing.T) { - //TODO: when foreign keys are implemented in the memory table, we can do the following test - for i := len(queries.ScriptTests) - 1; i >= 0; i-- { - if queries.ScriptTests[i].Name == "failed statements data validation for DELETE, REPLACE" { - queries.ScriptTests = append(queries.ScriptTests[:i], queries.ScriptTests[i+1:]...) - } - } enginetest.TestScriptsPrepared(t, enginetest.NewMemoryHarness("default", 1, testNumPartitions, true, mergableIndexDriver)) } diff --git a/enginetest/queries/priv_auth_queries.go b/enginetest/queries/priv_auth_queries.go index fb40711c87..d964c43a58 100644 --- a/enginetest/queries/priv_auth_queries.go +++ b/enginetest/queries/priv_auth_queries.go @@ -311,36 +311,36 @@ var UserPrivTests = []UserPrivilegeTest{ { "localhost", // Host "root", // User - "Y", // Select_priv - "Y", // Insert_priv - "Y", // Update_priv - "Y", // Delete_priv - "Y", // Create_priv - "Y", // Drop_priv - "Y", // Reload_priv - "Y", // Shutdown_priv - "Y", // Process_priv - "Y", // File_priv - "Y", // Grant_priv - "Y", // References_priv - "Y", // Index_priv - "Y", // Alter_priv - "Y", // Show_db_priv - "Y", // Super_priv - "Y", // Create_tmp_table_priv - "Y", // Lock_tables_priv - "Y", // Execute_priv - "Y", // Repl_slave_priv - "Y", // Repl_client_priv - "Y", // Create_view_priv - "Y", // Show_view_priv - "Y", // Create_routine_priv - "Y", // Alter_routine_priv - "Y", // Create_user_priv - "Y", // Event_priv - "Y", // Trigger_priv - "Y", // Create_tablespace_priv - "", // ssl_type + uint16(2), // Select_priv + uint16(2), // Insert_priv + uint16(2), // Update_priv + uint16(2), // Delete_priv + uint16(2), // Create_priv + uint16(2), // Drop_priv + uint16(2), // Reload_priv + uint16(2), // Shutdown_priv + uint16(2), // Process_priv + uint16(2), // File_priv + uint16(2), // Grant_priv + uint16(2), // References_priv + uint16(2), // Index_priv + uint16(2), // Alter_priv + uint16(2), // Show_db_priv + uint16(2), // Super_priv + uint16(2), // Create_tmp_table_priv + uint16(2), // Lock_tables_priv + uint16(2), // Execute_priv + uint16(2), // Repl_slave_priv + uint16(2), // Repl_client_priv + uint16(2), // Create_view_priv + uint16(2), // Show_view_priv + uint16(2), // Create_routine_priv + uint16(2), // Alter_routine_priv + uint16(2), // Create_user_priv + uint16(2), // Event_priv + uint16(2), // Trigger_priv + uint16(2), // Create_tablespace_priv + uint16(1), // ssl_type "", // ssl_cipher "", // x509_issuer "", // x509_subject @@ -350,12 +350,12 @@ var UserPrivTests = []UserPrivilegeTest{ uint32(0), // max_user_connections "mysql_native_password", // plugin "", // authentication_string - "N", // password_expired + uint16(1), // password_expired time.Unix(1, 0).UTC(), // password_last_changed nil, // password_lifetime - "N", // account_locked - "Y", // Create_role_priv - "Y", // Drop_role_priv + uint16(1), // account_locked + uint16(2), // Create_role_priv + uint16(2), // Drop_role_priv nil, // Password_reuse_history nil, // Password_reuse_time nil, // Password_require_current @@ -461,7 +461,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.db;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "Y", "N", "Y", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "Y", "N", "N"}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", uint16(2), uint16(1), uint16(2), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(2), uint16(1), uint16(1)}}, }, { User: "root", @@ -473,7 +473,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.db;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "Y", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "Y", "N", "N"}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", uint16(2), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(2), uint16(1), uint16(1)}}, }, { User: "root", @@ -493,7 +493,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.db;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "Y", "Y", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "Y", "N", "N"}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", uint16(2), uint16(2), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(1), uint16(2), uint16(1), uint16(1)}}, }, }, }, @@ -514,7 +514,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.tables_priv;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), "Select,Delete,Drop", ""}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), uint64(0b101001), uint64(0)}}, }, { User: "root", @@ -526,7 +526,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.tables_priv;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), "Select,Drop", ""}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), uint64(0b100001), uint64(0)}}, }, { User: "root", @@ -546,7 +546,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.tables_priv;", - Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), "References,Index", ""}}, + Expected: []sql.Row{{"localhost", "mydb", "tester", "test", "", time.Unix(1, 0).UTC(), uint64(0b110000000), uint64(0)}}, }, }, }, @@ -569,7 +569,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", - Expected: []sql.Row{{"tester", "localhost", "Y"}}, + Expected: []sql.Row{{"tester", "localhost", uint16(2)}}, }, { User: "root", @@ -587,7 +587,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", - Expected: []sql.Row{{"tester", "localhost", "N"}}, + Expected: []sql.Row{{"tester", "localhost", uint16(1)}}, }, }, }, @@ -616,7 +616,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, Select_priv, Insert_priv FROM mysql.user WHERE User = 'tester';", - Expected: []sql.Row{{"tester", "localhost", "Y", "Y"}}, + Expected: []sql.Row{{"tester", "localhost", uint16(2), uint16(2)}}, }, { User: "root", @@ -640,7 +640,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, Select_priv, Insert_priv FROM mysql.user WHERE User = 'tester';", - Expected: []sql.Row{{"tester", "localhost", "N", "N"}}, + Expected: []sql.Row{{"tester", "localhost", uint16(1), uint16(1)}}, }, }, }, @@ -654,7 +654,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, account_locked FROM mysql.user WHERE User = 'test_role';", - Expected: []sql.Row{{"test_role", "%", "Y"}}, + Expected: []sql.Row{{"test_role", "%", uint16(2)}}, }, }, }, @@ -691,7 +691,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.role_edges;", - Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", uint16(1)}}, }, { User: "tester", @@ -703,7 +703,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", - Expected: []sql.Row{{"tester", "localhost", "N"}}, + Expected: []sql.Row{{"tester", "localhost", uint16(1)}}, }, }, }, @@ -729,7 +729,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.role_edges;", - Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", uint16(1)}}, }, { User: "root", @@ -785,7 +785,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.role_edges;", - Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", uint16(1)}}, }, { User: "root", @@ -847,7 +847,7 @@ var UserPrivTests = []UserPrivilegeTest{ User: "root", Host: "localhost", Query: "SELECT * FROM mysql.role_edges;", - Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", uint16(1)}}, }, { User: "root", diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 9751f23ca0..dc5a153a5b 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -1359,7 +1359,7 @@ var ScriptTests = []ScriptTest{ Assertions: []ScriptTestAssertion{ { Query: "SELECT * FROM test;", - Expected: []sql.Row{{1, "b", "b"}, {2, "a", "a"}}, + Expected: []sql.Row{{1, uint16(2), uint64(2)}, {2, uint16(1), uint64(1)}}, }, { Query: "UPDATE test SET v1 = 3 WHERE v1 = 2;", @@ -1367,7 +1367,7 @@ var ScriptTests = []ScriptTest{ }, { Query: "SELECT * FROM test;", - Expected: []sql.Row{{1, "c", "b"}, {2, "a", "a"}}, + Expected: []sql.Row{{1, uint16(3), uint64(2)}, {2, uint16(1), uint64(1)}}, }, { Query: "UPDATE test SET v2 = 3 WHERE 2 = v2;", @@ -1375,7 +1375,7 @@ var ScriptTests = []ScriptTest{ }, { Query: "SELECT * FROM test;", - Expected: []sql.Row{{1, "c", "a,b"}, {2, "a", "a"}}, + Expected: []sql.Row{{1, uint16(3), uint64(3)}, {2, uint16(1), uint64(1)}}, }, }, }, diff --git a/enginetest/queries/variable_queries.go b/enginetest/queries/variable_queries.go index f16a9dded5..6c005d122f 100644 --- a/enginetest/queries/variable_queries.go +++ b/enginetest/queries/variable_queries.go @@ -76,7 +76,7 @@ var VariableQueries = []ScriptTest{ }, Query: "SELECT @@autocommit, @@session.sql_mode", Expected: []sql.Row{ - {1, ""}, + {1, uint64(0)}, }, }, { @@ -86,7 +86,7 @@ var VariableQueries = []ScriptTest{ }, Query: "SELECT @@autocommit, @@session.sql_mode", Expected: []sql.Row{ - {1, ""}, + {1, uint64(0)}, }, }, { @@ -189,7 +189,7 @@ var VariableQueries = []ScriptTest{ }, Query: "SELECT @@sql_mode", Expected: []sql.Row{ - {"ALLOW_INVALID_DATES"}, + {uint64(1)}, }, }, { @@ -199,7 +199,7 @@ var VariableQueries = []ScriptTest{ }, Query: "SELECT @@sql_mode", Expected: []sql.Row{ - {"ALLOW_INVALID_DATES"}, + {uint64(1)}, }, }, { @@ -209,7 +209,7 @@ var VariableQueries = []ScriptTest{ }, Query: "SELECT @@sql_mode", Expected: []sql.Row{ - {"ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE,STRICT_ALL_TABLES,STRICT_TRANS_TABLES,TRADITIONAL"}, + {uint64(0b10110000110100000100)}, }, }, // User variables diff --git a/enginetest/testdata.go b/enginetest/testdata.go index a7c92f12c1..dabd191bb7 100644 --- a/enginetest/testdata.go +++ b/enginetest/testdata.go @@ -199,21 +199,21 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(1, "first_row"), - sql.NewRow(2, "second_row"), - sql.NewRow(3, "third_row"), - sql.NewRow(4, `%`), - sql.NewRow(5, `'`), - sql.NewRow(6, `"`), - sql.NewRow(7, "\t"), - sql.NewRow(8, "\n"), - sql.NewRow(9, "\v"), - sql.NewRow(10, `test%test`), - sql.NewRow(11, `test'test`), - sql.NewRow(12, `test"test`), - sql.NewRow(13, "test\ttest"), - sql.NewRow(14, "test\ntest"), - sql.NewRow(15, "test\vtest"), + sql.NewRow(int64(1), "first_row"), + sql.NewRow(int64(2), "second_row"), + sql.NewRow(int64(3), "third_row"), + sql.NewRow(int64(4), `%`), + sql.NewRow(int64(5), `'`), + sql.NewRow(int64(6), `"`), + sql.NewRow(int64(7), "\t"), + sql.NewRow(int64(8), "\n"), + sql.NewRow(int64(9), "\v"), + sql.NewRow(int64(10), `test%test`), + sql.NewRow(int64(11), `test'test`), + sql.NewRow(int64(12), `test"test`), + sql.NewRow(int64(13), "test\ttest"), + sql.NewRow(int64(14), "test\ntest"), + sql.NewRow(int64(15), "test\vtest"), ) } else { t.Logf("Warning: could not create table %s: %s", "specialtable", err) @@ -252,10 +252,10 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(0, 0, 1, 2, 3, 4), - sql.NewRow(1, 10, 11, 12, 13, 14), - sql.NewRow(2, 20, 21, 22, 23, 24), - sql.NewRow(3, 30, 31, 32, 33, 34)) + sql.NewRow(int8(0), int8(0), int8(1), int8(2), int8(3), int8(4)), + sql.NewRow(int8(1), int8(10), int8(11), int8(12), int8(13), int8(14)), + sql.NewRow(int8(2), int8(20), int8(21), int8(22), int8(23), int8(24)), + sql.NewRow(int8(3), int8(30), int8(31), int8(32), int8(33), int8(34))) } else { t.Logf("Warning: could not create table %s: %s", "one_pk", err) } @@ -273,10 +273,10 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(1, "row one", sql.JSONDocument{Val: []interface{}{1, 2}}, sql.JSONDocument{Val: map[string]interface{}{"a": 2}}), - sql.NewRow(2, "row two", sql.JSONDocument{Val: []interface{}{3, 4}}, sql.JSONDocument{Val: map[string]interface{}{"b": 2}}), - sql.NewRow(3, "row three", sql.JSONDocument{Val: []interface{}{5, 6}}, sql.JSONDocument{Val: map[string]interface{}{"c": 2}}), - sql.NewRow(4, "row four", sql.JSONDocument{Val: []interface{}{7, 8}}, sql.JSONDocument{Val: map[string]interface{}{"d": 2}})) + sql.NewRow(int8(1), "row one", sql.JSONDocument{Val: []interface{}{1, 2}}, sql.JSONDocument{Val: map[string]interface{}{"a": 2}}), + sql.NewRow(int8(2), "row two", sql.JSONDocument{Val: []interface{}{3, 4}}, sql.JSONDocument{Val: map[string]interface{}{"b": 2}}), + sql.NewRow(int8(3), "row three", sql.JSONDocument{Val: []interface{}{5, 6}}, sql.JSONDocument{Val: map[string]interface{}{"c": 2}}), + sql.NewRow(int8(4), "row four", sql.JSONDocument{Val: []interface{}{7, 8}}, sql.JSONDocument{Val: map[string]interface{}{"d": 2}})) } else { t.Logf("Warning: could not create table %s: %s", "jsontable", err) } @@ -297,10 +297,10 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(0, 0, 0, 1, 2, 3, 4), - sql.NewRow(0, 1, 10, 11, 12, 13, 14), - sql.NewRow(1, 0, 20, 21, 22, 23, 24), - sql.NewRow(1, 1, 30, 31, 32, 33, 34)) + sql.NewRow(int8(0), int8(0), int8(0), int8(1), int8(2), int8(3), int8(4)), + sql.NewRow(int8(0), int8(1), int8(10), int8(11), int8(12), int8(13), int8(14)), + sql.NewRow(int8(1), int8(0), int8(20), int8(21), int8(22), int8(23), int8(24)), + sql.NewRow(int8(1), int8(1), int8(30), int8(31), int8(32), int8(33), int8(34))) } else { t.Logf("Warning: could not create table %s: %s", "two_pk", err) } @@ -317,14 +317,14 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(0, 0, 0), - sql.NewRow(1, 1, 1), - sql.NewRow(2, 2, 2), - sql.NewRow(3, 3, 3), - sql.NewRow(4, 4, 4), - sql.NewRow(5, 5, 5), - sql.NewRow(6, 6, 6), - sql.NewRow(7, 7, 7)) + sql.NewRow(int64(0), int64(0), int64(0)), + sql.NewRow(int64(1), int64(1), int64(1)), + sql.NewRow(int64(2), int64(2), int64(2)), + sql.NewRow(int64(3), int64(3), int64(3)), + sql.NewRow(int64(4), int64(4), int64(4)), + sql.NewRow(int64(5), int64(5), int64(5)), + sql.NewRow(int64(6), int64(6), int64(6)), + sql.NewRow(int64(7), int64(7), int64(7))) } else { t.Logf("Warning: could not create table %s: %s", "one_pk_two_idx", err) } @@ -342,14 +342,14 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(0, 0, 0, 0), - sql.NewRow(1, 0, 0, 1), - sql.NewRow(2, 0, 1, 0), - sql.NewRow(3, 0, 2, 2), - sql.NewRow(4, 1, 0, 0), - sql.NewRow(5, 2, 0, 3), - sql.NewRow(6, 3, 3, 0), - sql.NewRow(7, 4, 4, 4)) + sql.NewRow(int64(0), int64(0), int64(0), int64(0)), + sql.NewRow(int64(1), int64(0), int64(0), int64(1)), + sql.NewRow(int64(2), int64(0), int64(1), int64(0)), + sql.NewRow(int64(3), int64(0), int64(2), int64(2)), + sql.NewRow(int64(4), int64(1), int64(0), int64(0)), + sql.NewRow(int64(5), int64(2), int64(0), int64(3)), + sql.NewRow(int64(6), int64(3), int64(3), int64(0)), + sql.NewRow(int64(7), int64(4), int64(4), int64(4))) } else { t.Logf("Warning: could not create table %s: %s", "one_pk_three_idx", err) } @@ -383,9 +383,9 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(int64(1), "first row"), - sql.NewRow(int64(2), "second row"), - sql.NewRow(int64(3), "third row")) + sql.NewRow(int32(1), "first row"), + sql.NewRow(int32(2), "second row"), + sql.NewRow(int32(3), "third row")) } else { t.Logf("Warning: could not create table %s: %s", "tabletest", err) } @@ -510,11 +510,11 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), sql.NewRow(int64(1), nil, nil, nil), - sql.NewRow(int64(2), int64(2), 1, nil), - sql.NewRow(int64(3), nil, 0, nil), + sql.NewRow(int64(2), int64(2), int8(1), nil), + sql.NewRow(int64(3), nil, int8(0), nil), sql.NewRow(int64(4), int64(4), nil, float64(4)), - sql.NewRow(int64(5), nil, 1, float64(5)), - sql.NewRow(int64(6), int64(6), 0, float64(6))) + sql.NewRow(int64(5), nil, int8(1), float64(5)), + sql.NewRow(int64(6), int64(6), int8(0), float64(6))) } else { t.Logf("Warning: could not create table %s: %s", "niltable", err) } @@ -585,7 +585,7 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string t1, t2, "fourteen", - 0, + int8(0), nil, nil, )) @@ -607,9 +607,9 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(1, mustParseDate("2019-12-31T12:00:00Z"), mustParseTime("2020-01-01T12:00:00Z"), mustParseTime("2020-01-02T12:00:00Z"), mustSQLTime(3*time.Hour+10*time.Minute)), - sql.NewRow(2, mustParseDate("2020-01-03T12:00:00Z"), mustParseTime("2020-01-04T12:00:00Z"), mustParseTime("2020-01-05T12:00:00Z"), mustSQLTime(4*time.Hour+44*time.Second)), - sql.NewRow(3, mustParseDate("2020-01-07T00:00:00Z"), mustParseTime("2020-01-07T12:00:00Z"), mustParseTime("2020-01-07T12:00:01Z"), mustSQLTime(15*time.Hour+5*time.Millisecond)), + sql.NewRow(int64(1), mustParseDate("2019-12-31T12:00:00Z"), mustParseTime("2020-01-01T12:00:00Z"), mustParseTime("2020-01-02T12:00:00Z"), mustSQLTime(3*time.Hour+10*time.Minute)), + sql.NewRow(int64(2), mustParseDate("2020-01-03T12:00:00Z"), mustParseTime("2020-01-04T12:00:00Z"), mustParseTime("2020-01-05T12:00:00Z"), mustSQLTime(4*time.Hour+44*time.Second)), + sql.NewRow(int64(3), mustParseDate("2020-01-07T00:00:00Z"), mustParseTime("2020-01-07T12:00:00Z"), mustParseTime("2020-01-07T12:00:01Z"), mustSQLTime(15*time.Hour+5*time.Millisecond)), ) } }) @@ -666,9 +666,9 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), - sql.NewRow(1, 1, "first row"), - sql.NewRow(2, 2, "second row"), - sql.NewRow(3, 3, "third row"), + sql.NewRow(int64(1), int64(1), "first row"), + sql.NewRow(int64(2), int64(2), "second row"), + sql.NewRow(int64(3), int64(3), "third row"), ) } else { t.Logf("Warning: could not create table %s: %s", "fk_tbl", err) @@ -688,9 +688,9 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil && ok { ctx := NewContext(harness) InsertRows(t, ctx, mustInsertableTable(t, autoTbl), - sql.NewRow(1, 11), - sql.NewRow(2, 22), - sql.NewRow(3, 33), + sql.NewRow(int64(1), int64(11)), + sql.NewRow(int64(2), int64(22)), + sql.NewRow(int64(3), int64(33)), ) // InsertRows bypasses integrator auto increment methods // manually set the auto increment value here @@ -713,9 +713,9 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil && ok { InsertRows(t, NewContext(harness), mustInsertableTable(t, autoTbl), - sql.NewRow(0, 2, 2), - sql.NewRow(1, 1, 0), - sql.NewRow(2, 0, 1), + sql.NewRow(int64(0), int64(2), int64(2)), + sql.NewRow(int64(1), int64(1), int64(0)), + sql.NewRow(int64(2), int64(0), int64(1)), ) } else { t.Logf("Warning: could not create table %s: %s", "invert_pk", err) @@ -734,16 +734,16 @@ func createSubsetTestData(t *testing.T, harness Harness, includedTables []string if err == nil { InsertRows(t, NewContext(harness), mustInsertableTable(t, table), []sql.Row{ - {"pie", "crust", 1}, - {"pie", "filling", 2}, - {"crust", "flour", 20}, - {"crust", "sugar", 2}, - {"crust", "butter", 15}, - {"crust", "salt", 15}, - {"filling", "sugar", 5}, - {"filling", "fruit", 9}, - {"filling", "salt", 3}, - {"filling", "butter", 3}, + {"pie", "crust", int64(1)}, + {"pie", "filling", int64(2)}, + {"crust", "flour", int64(20)}, + {"crust", "sugar", int64(2)}, + {"crust", "butter", int64(15)}, + {"crust", "salt", int64(15)}, + {"filling", "sugar", int64(5)}, + {"filling", "fruit", int64(9)}, + {"filling", "salt", int64(3)}, + {"filling", "butter", int64(3)}, }...) } else { t.Logf("Warning: could not create table %s: %s", "parts", err) diff --git a/memory/table.go b/memory/table.go index 44d0373aeb..9de1a13b12 100644 --- a/memory/table.go +++ b/memory/table.go @@ -19,6 +19,7 @@ import ( "encoding/gob" "fmt" "io" + "reflect" "sort" "strconv" "strings" @@ -1451,3 +1452,18 @@ func (t *Table) PartitionRows2(ctx *sql.Context, partition sql.Partition) (sql.R return iter.(*tableIter), nil } + +func (t *Table) verifyRowTypes(row sql.Row) { + //TODO: only run this when in testing mode + if len(row) == len(t.schema.Schema) { + for i := range t.schema.Schema { + col := t.schema.Schema[i] + rowVal := row[i] + valType := reflect.TypeOf(rowVal) + expectedType := col.Type.ValueType() + if valType != expectedType && rowVal != nil && !valType.AssignableTo(expectedType) { + panic(fmt.Errorf("Actual Value Type: %s, Expected Value Type: %s", valType.String(), expectedType.String())) + } + } + } +} diff --git a/memory/table_editor.go b/memory/table_editor.go index ed2d169202..227da2897b 100644 --- a/memory/table_editor.go +++ b/memory/table_editor.go @@ -70,6 +70,7 @@ func (t *tableEditor) Insert(ctx *sql.Context, row sql.Row) error { if err := checkRow(t.table.schema.Schema, row); err != nil { return err } + t.table.verifyRowTypes(row) partitionRow, added, err := t.ea.Get(row) if err != nil { @@ -119,6 +120,7 @@ func (t *tableEditor) Delete(ctx *sql.Context, row sql.Row) error { if err := checkRow(t.table.schema.Schema, row); err != nil { return err } + t.table.verifyRowTypes(row) err := t.ea.Delete(row) if err != nil { @@ -136,6 +138,8 @@ func (t *tableEditor) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) e if err := checkRow(t.table.schema.Schema, newRow); err != nil { return err } + t.table.verifyRowTypes(oldRow) + t.table.verifyRowTypes(newRow) err := t.ea.Delete(oldRow) if err != nil { diff --git a/sql/arraytype.go b/sql/arraytype.go index 6d85cf9e96..afb87ab423 100644 --- a/sql/arraytype.go +++ b/sql/arraytype.go @@ -18,11 +18,14 @@ import ( "encoding/json" "fmt" "io" + "reflect" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var arrayValueType = reflect.TypeOf((*[]interface{})(nil)).Elem() + type arrayType struct { underlying Type } @@ -153,7 +156,7 @@ func (t arrayType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { } } - val = appendAndSlice(dest, val) + val = appendAndSliceBytes(dest, val) return sqltypes.MakeTrusted(sqltypes.TypeJSON, val), nil } @@ -166,6 +169,10 @@ func (t arrayType) Type() query.Type { return sqltypes.TypeJSON } +func (t arrayType) ValueType() reflect.Type { + return arrayValueType +} + func (t arrayType) Zero() interface{} { return nil } diff --git a/sql/bit.go b/sql/bit.go index 19cfa155c4..ca5d771a5c 100644 --- a/sql/bit.go +++ b/sql/bit.go @@ -17,7 +17,7 @@ package sql import ( "encoding/binary" "fmt" - "strconv" + "reflect" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" @@ -35,10 +35,12 @@ const ( var ( promotedBitType = MustCreateBitType(BitTypeMaxBits) errBeyondMaxBit = errors.NewKind("%v is beyond the maximum value that can be held by %v bits") + bitValueType = reflect.TypeOf(uint64(0)) ) -// Represents the BIT type. +// BitType represents the BIT type. // https://dev.mysql.com/doc/refman/8.0/en/bit-type.html +// The type of the returned value is uint64. type BitType interface { Type NumberOfBits() uint8 @@ -133,6 +135,11 @@ func (t bitType) Convert(v interface{}) (interface{}, error) { return nil, fmt.Errorf(`negative floats cannot become bit values`) } value = uint64(val) + case decimal.NullDecimal: + if !val.Valid { + return nil, nil + } + return t.Convert(val.Decimal) case decimal.Decimal: val = val.Round(0) if val.GreaterThan(dec_uint64_max) { @@ -190,10 +197,16 @@ func (t bitType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { if err != nil { return sqltypes.Value{}, err } + bitVal := value.(uint64) - stop := len(dest) - dest = strconv.AppendUint(dest, value.(uint64), 10) - val := dest[stop:] + var data []byte + for i := uint64(0); i < uint64(t.numOfBits); i += 8 { + data = append(data, byte(bitVal>>i)) + } + for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 { + data[i], data[j] = data[j], data[i] + } + val := appendAndSliceBytes(dest, data) return sqltypes.MakeTrusted(sqltypes.Bit, val), nil } @@ -208,6 +221,11 @@ func (t bitType) Type() query.Type { return sqltypes.Bit } +// ValueType implements Type interface. +func (t bitType) ValueType() reflect.Type { + return bitValueType +} + // Zero implements Type interface. Returns a uint64 value. func (t bitType) Zero() interface{} { return uint64(0) diff --git a/sql/bit_test.go b/sql/bit_test.go index 48a390296d..8680e74fa6 100644 --- a/sql/bit_test.go +++ b/sql/bit_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "testing" "time" @@ -128,6 +129,9 @@ func TestBitConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.Equal(t, test.typ.ValueType(), reflect.TypeOf(val)) + } } }) } diff --git a/sql/core.go b/sql/core.go index 5994de0b2e..67dab90608 100644 --- a/sql/core.go +++ b/sql/core.go @@ -1037,7 +1037,7 @@ type ExternalStoredProcedureDetails struct { Schema Schema // Function is the implementation of the external stored procedure. All functions should have the following definition: // `func(*Context, ) (RowIter, error)`. The may be any of the following types: `bool`, - // `string`, `[]byte`, `int8`-`int64`, `uint8`-`uint64`, `float32`, `float64`, `time.Time`, or `decimal` + // `string`, `[]byte`, `int8`-`int64`, `uint8`-`uint64`, `float32`, `float64`, `time.Time`, or `Decimal` // (shopspring/decimal). The architecture-dependent types `int` and `uint` (without a number) are also supported. // It is valid to return a nil RowIter if there are no rows to be returned. // @@ -1046,7 +1046,7 @@ type ExternalStoredProcedureDetails struct { // // Values are converted to their nearest type before being passed in, following the conversion rules of their // related SQL types. The exceptions are `time.Time` (treated as a `DATETIME`), string (treated as a `LONGTEXT` with - // the default collation) and decimal (treated with a larger precision and scale). Take extra care when using decimal + // the default collation) and Decimal (treated with a larger precision and scale). Take extra care when using decimal // for an INOUT parameter, to ensure that the returned value fits the original's precision and scale, else an error // will occur. // diff --git a/sql/datetimetype.go b/sql/datetimetype.go index 928b41f7ad..c0ba3483eb 100644 --- a/sql/datetimetype.go +++ b/sql/datetimetype.go @@ -16,8 +16,11 @@ package sql import ( "math" + "reflect" "time" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -87,10 +90,13 @@ var ( Datetime = MustCreateDatetimeType(sqltypes.Datetime) // Timestamp is an UNIX timestamp. Timestamp = MustCreateDatetimeType(sqltypes.Timestamp) + + datetimeValueType = reflect.TypeOf(time.Time{}) ) -// Represents DATE, DATETIME, and TIMESTAMP. +// DatetimeType represents DATE, DATETIME, and TIMESTAMP. // https://dev.mysql.com/doc/refman/8.0/en/datetime.html +// The type of the returned value is time.Time. type DatetimeType interface { Type ConvertWithoutRangeCheck(v interface{}) (time.Time, error) @@ -275,6 +281,16 @@ func (t datetimeType) ConvertWithoutRangeCheck(v interface{}) (time.Time, error) return zeroTime, nil } return zeroTime, ErrConvertingToTime.New(v) + case decimal.Decimal: + if value.IsZero() { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(v) + case decimal.NullDecimal: + if value.Valid && value.Decimal.IsZero() { + return zeroTime, nil + } + return zeroTime, ErrConvertingToTime.New(v) default: return zeroTime, ErrConvertToSQL.New(t) } @@ -317,37 +333,37 @@ func (t datetimeType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { vt := v.(time.Time) var typ query.Type - var val []byte + var val string switch t.baseType { case sqltypes.Date: typ = sqltypes.Date if vt.Equal(zeroTime) { - val = []byte(vt.Format(zeroDateStr)) + val = vt.Format(zeroDateStr) } else { - val = []byte(vt.Format(DateLayout)) + val = vt.Format(DateLayout) } case sqltypes.Datetime: typ = sqltypes.Datetime if vt.Equal(zeroTime) { - val = []byte(vt.Format(zeroTimestampDatetimeStr)) + val = vt.Format(zeroTimestampDatetimeStr) } else { - val = []byte(vt.Format(TimestampDatetimeLayout)) + val = vt.Format(TimestampDatetimeLayout) } case sqltypes.Timestamp: typ = sqltypes.Timestamp if vt.Equal(zeroTime) { - val = []byte(vt.Format(zeroTimestampDatetimeStr)) + val = vt.Format(zeroTimestampDatetimeStr) } else { - val = []byte(vt.Format(TimestampDatetimeLayout)) + val = vt.Format(TimestampDatetimeLayout) } default: panic(ErrInvalidBaseType.New(t.baseType.String(), "datetime")) } - val = appendAndSlice(dest, val) + valBytes := appendAndSliceString(dest, val) - return sqltypes.MakeTrusted(typ, val), nil + return sqltypes.MakeTrusted(typ, valBytes), nil } func (t datetimeType) String() string { @@ -368,6 +384,11 @@ func (t datetimeType) Type() query.Type { return t.baseType } +// ValueType implements Type interface. +func (t datetimeType) ValueType() reflect.Type { + return datetimeValueType +} + func (t datetimeType) Zero() interface{} { return zeroTime } diff --git a/sql/datetimetype_test.go b/sql/datetimetype_test.go index 10716b52c3..266efbec1a 100644 --- a/sql/datetimetype_test.go +++ b/sql/datetimetype_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "testing" "time" @@ -304,6 +305,9 @@ func TestDatetimeConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.Equal(t, test.typ.ValueType(), reflect.TypeOf(val)) + } } }) } diff --git a/sql/decimal.go b/sql/decimal.go index 6c0b4d16f0..146319c8e2 100644 --- a/sql/decimal.go +++ b/sql/decimal.go @@ -17,6 +17,7 @@ package sql import ( "fmt" "math/big" + "reflect" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" @@ -37,14 +38,29 @@ var ( ErrConvertingToDecimal = errors.NewKind("value %v is not a valid Decimal") ErrConvertToDecimalLimit = errors.NewKind("value of Decimal is too large for type") ErrMarshalNullDecimal = errors.NewKind("Decimal cannot marshal a null value") + + decimalValueType = reflect.TypeOf(decimal.Decimal{}) ) +// DecimalType represents the DECIMAL type. +// https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html +// The type of the returned value is decimal.Decimal. type DecimalType interface { Type - ConvertToDecimal(v interface{}) (decimal.NullDecimal, error) + // ConvertToNullDecimal converts the given value to a decimal.NullDecimal if it has a compatible type. It is worth + // noting that Convert() returns a nil value for nil inputs, and also returns decimal.Decimal rather than + // decimal.NullDecimal. + ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, error) + // ExclusiveUpperBound returns the exclusive upper bound for this Decimal. + // For example, DECIMAL(5,2) would return 1000, as 999.99 is the max represented. ExclusiveUpperBound() decimal.Decimal + // MaximumScale returns the maximum scale allowed for the current precision. MaximumScale() uint8 + // Precision returns the base-10 precision of the type, which is the total number of digits. For example, a + // precision of 3 means that 999, 99.9, 9.99, and .999 are all valid maximums (depending on the scale). Precision() uint8 + // Scale returns the scale, or number of digits after the decimal, that may be held. + // This will always be less than or equal to the precision. Scale() uint8 } @@ -103,11 +119,11 @@ func (t decimalType) Compare(a interface{}, b interface{}) (int, error) { return res, nil } - af, err := t.ConvertToDecimal(a) + af, err := t.ConvertToNullDecimal(a) if err != nil { return 0, err } - bf, err := t.ConvertToDecimal(b) + bf, err := t.ConvertToNullDecimal(b) if err != nil { return 0, err } @@ -117,18 +133,18 @@ func (t decimalType) Compare(a interface{}, b interface{}) (int, error) { // Convert implements Type interface. func (t decimalType) Convert(v interface{}) (interface{}, error) { - dec, err := t.ConvertToDecimal(v) + dec, err := t.ConvertToNullDecimal(v) if err != nil { return nil, err } if !dec.Valid { return nil, nil } - return dec.Decimal.StringFixed(int32(t.scale)), nil + return dec.Decimal, nil } -// Precision returns the precision, or total number of digits, that may be held. -func (t decimalType) ConvertToDecimal(v interface{}) (decimal.NullDecimal, error) { +// ConvertToNullDecimal implements DecimalType interface. +func (t decimalType) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, error) { if v == nil { return decimal.NullDecimal{}, nil } @@ -137,21 +153,21 @@ func (t decimalType) ConvertToDecimal(v interface{}) (decimal.NullDecimal, error switch value := v.(type) { case int: - return t.ConvertToDecimal(int64(value)) + return t.ConvertToNullDecimal(int64(value)) case uint: - return t.ConvertToDecimal(uint64(value)) + return t.ConvertToNullDecimal(uint64(value)) case int8: - return t.ConvertToDecimal(int64(value)) + return t.ConvertToNullDecimal(int64(value)) case uint8: - return t.ConvertToDecimal(uint64(value)) + return t.ConvertToNullDecimal(uint64(value)) case int16: - return t.ConvertToDecimal(int64(value)) + return t.ConvertToNullDecimal(int64(value)) case uint16: - return t.ConvertToDecimal(uint64(value)) + return t.ConvertToNullDecimal(uint64(value)) case int32: res = decimal.NewFromInt32(value) case uint32: - return t.ConvertToDecimal(uint64(value)) + return t.ConvertToNullDecimal(uint64(value)) case int64: res = decimal.NewFromInt(value) case uint64: @@ -175,11 +191,11 @@ func (t decimalType) ConvertToDecimal(v interface{}) (decimal.NullDecimal, error } } case *big.Float: - return t.ConvertToDecimal(value.Text('f', -1)) + return t.ConvertToNullDecimal(value.Text('f', -1)) case *big.Int: - return t.ConvertToDecimal(value.Text(10)) + return t.ConvertToNullDecimal(value.Text(10)) case *big.Rat: - return t.ConvertToDecimal(new(big.Float).SetRat(value)) + return t.ConvertToNullDecimal(new(big.Float).SetRat(value)) case decimal.Decimal: res = value case decimal.NullDecimal: @@ -232,7 +248,7 @@ func (t decimalType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(value.(string))) + val := appendAndSliceString(dest, value.(decimal.Decimal).StringFixed(int32(t.scale))) return sqltypes.MakeTrusted(sqltypes.Decimal, val), nil } @@ -242,18 +258,22 @@ func (t decimalType) String() string { return fmt.Sprintf("DECIMAL(%v,%v)", t.precision, t.scale) } -// Zero implements Type interface. Returns a uint64 value. +// ValueType implements Type interface. +func (t decimalType) ValueType() reflect.Type { + return decimalValueType +} + +// Zero implements Type interface. func (t decimalType) Zero() interface{} { return decimal.NewFromInt(0).StringFixed(int32(t.scale)) } -// ExclusiveUpperBound returns the exclusive upper bound for this Decimal. -// For example, DECIMAL(5,2) would return 1000, as 999.99 is the max represented. +// ExclusiveUpperBound implements DecimalType interface. func (t decimalType) ExclusiveUpperBound() decimal.Decimal { return t.exclusiveUpperBound } -// MaximumScale returns the maximum scale allowed for the current precision. +// MaximumScale implements DecimalType interface. func (t decimalType) MaximumScale() uint8 { if t.precision >= DecimalTypeMaxScale { return DecimalTypeMaxScale @@ -261,14 +281,12 @@ func (t decimalType) MaximumScale() uint8 { return t.precision } -// Precision returns the base-10 precision of the type, which is the total number of digits. -// For example, a precision of 3 means that 999, 99.9, 9.99, and .999 are all valid maximums (depending on the scale). +// Precision implements DecimalType interface. func (t decimalType) Precision() uint8 { return t.precision } -// Scale returns the scale, or number of digits after the decimal, that may be held. -// This will always be less than or equal to the precision. +// Scale implements DecimalType interface. func (t decimalType) Scale() uint8 { return t.scale } diff --git a/sql/decimal_test.go b/sql/decimal_test.go index 3dcb0ff86f..fa4b6bff8d 100644 --- a/sql/decimal_test.go +++ b/sql/decimal_test.go @@ -17,6 +17,7 @@ package sql import ( "fmt" "math/big" + "reflect" "strings" "testing" "time" @@ -69,7 +70,7 @@ func TestDecimalAccuracy(t *testing.T) { for _, test := range tests { decimalType := MustCreateDecimalType(uint8(precision), uint8(test.scale)) - decimal := big.NewInt(0) + decimalInt := big.NewInt(0) bigIntervals := make([]*big.Int, len(test.intervals)) for i, interval := range test.intervals { bigInterval := new(big.Int) @@ -81,18 +82,18 @@ func TestDecimalAccuracy(t *testing.T) { upperBound := new(big.Int) _ = upperBound.UnmarshalText([]byte("1" + strings.Repeat("0", test.scale))) - for decimal.Cmp(upperBound) == -1 { - decimalStr := decimal.Text(10) + for decimalInt.Cmp(upperBound) == -1 { + decimalStr := decimalInt.Text(10) fullDecimalStr := strings.Repeat("0", test.scale-len(decimalStr)) + decimalStr fullStr := baseStr + fullDecimalStr t.Run(fmt.Sprintf("Scale:%v DecVal:%v", test.scale, fullDecimalStr), func(t *testing.T) { res, err := decimalType.Convert(fullStr) require.NoError(t, err) - require.Equal(t, fullStr, res) + require.Equal(t, fullStr, res.(decimal.Decimal).StringFixed(int32(decimalType.Scale()))) }) - decimal.Add(decimal, bigIntervals[intervalIndex]) + decimalInt.Add(decimalInt, bigIntervals[intervalIndex]) intervalIndex = (intervalIndex + 1) % len(bigIntervals) } } @@ -267,12 +268,20 @@ func TestDecimalConvert(t *testing.T) { for _, test := range tests { t.Run(fmt.Sprintf("%v %v %v", test.precision, test.scale, test.val), func(t *testing.T) { - val, err := MustCreateDecimalType(test.precision, test.scale).Convert(test.val) + typ := MustCreateDecimalType(test.precision, test.scale) + val, err := typ.Convert(test.val) if test.expectedErr { assert.Error(t, err) } else { require.NoError(t, err) - assert.Equal(t, test.expectedVal, val) + if test.expectedVal == nil { + assert.Nil(t, val) + } else { + expectedVal, err := decimal.NewFromString(test.expectedVal.(string)) + require.NoError(t, err) + assert.True(t, expectedVal.Equal(val.(decimal.Decimal))) + assert.Equal(t, typ.ValueType(), reflect.TypeOf(val)) + } } }) } diff --git a/sql/deferredtype.go b/sql/deferredtype.go index 36ae340398..b793d106a0 100644 --- a/sql/deferredtype.go +++ b/sql/deferredtype.go @@ -15,6 +15,8 @@ package sql import ( + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) @@ -85,6 +87,11 @@ func (t deferredType) Type() query.Type { return sqltypes.Expression } +// ValueType implements Type interface. +func (t deferredType) ValueType() reflect.Type { + return nil +} + // Zero implements Type interface. func (t deferredType) Zero() interface{} { return nil diff --git a/sql/enumtype.go b/sql/enumtype.go index 53f5d8ce3d..c6b57ae316 100644 --- a/sql/enumtype.go +++ b/sql/enumtype.go @@ -16,9 +16,12 @@ package sql import ( "fmt" + "reflect" "strconv" "strings" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -35,23 +38,26 @@ const ( var ( ErrConvertingToEnum = errors.NewKind("value %v is not valid for this Enum") ErrUnmarshallingEnum = errors.NewKind("value %v is not a marshalled value for this Enum") + + enumValueType = reflect.TypeOf(uint16(0)) ) // Comments with three slashes were taken directly from the linked documentation. -// Represents the ENUM type. +// EnumType represents the ENUM type. // https://dev.mysql.com/doc/refman/8.0/en/enum.html +// The type of the returned value is uint16. type EnumType interface { Type + // At returns the string at the given index, as well if the string was found. At(index int) (string, bool) CharacterSet() CharacterSet Collation() Collation - ConvertToIndex(v interface{}) (int, error) + // IndexOf returns the index of the given string. If the string was not found, then this returns -1. IndexOf(v string) int - //TODO: move this out of go-mysql-server and into the Dolt layer - Marshal(v interface{}) (int64, error) + // NumberOfElements returns the number of enumerations. NumberOfElements() uint16 - Unmarshal(v int64) (string, error) + // Values returns the elements, in order, of every enumeration. Values() []string } @@ -104,18 +110,20 @@ func (t enumType) Compare(a interface{}, b interface{}) (int, error) { return res, nil } - ai, err := t.ConvertToIndex(a) + ai, err := t.Convert(a) if err != nil { return 0, err } - bi, err := t.ConvertToIndex(b) + bi, err := t.Convert(b) if err != nil { return 0, err } + au := ai.(uint16) + bu := bi.(uint16) - if ai < bi { + if au < bu { return -1, nil - } else if ai > bi { + } else if au > bu { return 1, nil } return 0, nil @@ -129,8 +137,8 @@ func (t enumType) Convert(v interface{}) (interface{}, error) { switch value := v.(type) { case int: - if str, ok := t.At(value); ok { - return str, nil + if _, ok := t.At(value); ok { + return uint16(value), nil } case uint: return t.Convert(int(value)) @@ -154,12 +162,17 @@ func (t enumType) Convert(v interface{}) (interface{}, error) { return t.Convert(int(value)) case float64: return t.Convert(int(value)) + case decimal.Decimal: + return t.Convert(value.IntPart()) + case decimal.NullDecimal: + if !value.Valid { + return nil, nil + } + return t.Convert(value.Decimal.IntPart()) case string: if index := t.IndexOf(value); index != -1 { - realStr, _ := t.At(index) - return realStr, nil + return uint16(index), nil } - return nil, ErrConvertingToEnum.New(`"` + value + `"`) case []byte: return t.Convert(string(value)) } @@ -189,47 +202,6 @@ func (t enumType) Equals(otherType Type) bool { return false } -// ConvertToIndex is similar to Convert, except that it converts to the index rather than the value. -// Returns an error on nil. -func (t enumType) ConvertToIndex(v interface{}) (int, error) { - switch value := v.(type) { - case int: - if _, ok := t.At(value); ok { - return value, nil - } - case uint: - return t.ConvertToIndex(int(value)) - case int8: - return t.ConvertToIndex(int(value)) - case uint8: - return t.ConvertToIndex(int(value)) - case int16: - return t.ConvertToIndex(int(value)) - case uint16: - return t.ConvertToIndex(int(value)) - case int32: - return t.ConvertToIndex(int(value)) - case uint32: - return t.ConvertToIndex(int(value)) - case int64: - return t.ConvertToIndex(int(value)) - case uint64: - return t.ConvertToIndex(int(value)) - case float32: - return t.ConvertToIndex(int(value)) - case float64: - return t.ConvertToIndex(int(value)) - case string: - if index := t.IndexOf(value); index != -1 { - return index, nil - } - case []byte: - return t.ConvertToIndex(string(value)) - } - - return -1, ErrConvertingToEnum.New(v) -} - // Promote implements the Type interface. func (t enumType) Promote() Type { return t @@ -240,12 +212,13 @@ func (t enumType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { if v == nil { return sqltypes.NULL, nil } - value, err := t.Convert(v) + convertedValue, err := t.Convert(v) if err != nil { return sqltypes.Value{}, err } + value, _ := t.At(int(convertedValue.(uint16))) - val := appendAndSlice(dest, []byte(value.(string))) + val := appendAndSliceString(dest, value) return sqltypes.MakeTrusted(sqltypes.Enum, val), nil } @@ -267,13 +240,18 @@ func (t enumType) Type() query.Type { return sqltypes.Enum } +// ValueType implements Type interface. +func (t enumType) ValueType() reflect.Type { + return enumValueType +} + // Zero implements Type interface. func (t enumType) Zero() interface{} { /// If an ENUM column is declared NOT NULL, its default value is the first element of the list of permitted values. return t.indexToVal[0] } -// At returns the string at the given index, as well if the string was found. +// At implements EnumType interface. func (t enumType) At(index int) (string, bool) { /// The elements listed in the column specification are assigned index numbers, beginning with 1. index -= 1 @@ -283,15 +261,17 @@ func (t enumType) At(index int) (string, bool) { return t.indexToVal[index], true } +// CharacterSet implements EnumType interface. func (t enumType) CharacterSet() CharacterSet { return t.collation.CharacterSet() } +// Collation implements EnumType interface. func (t enumType) Collation() Collation { return t.collation } -// IndexOf returns the index of the given string. If the string was not found, then this returns -1. +// IndexOf implements EnumType interface. func (t enumType) IndexOf(v string) int { if index, ok := t.valToIndex[v]; ok { return index @@ -308,27 +288,12 @@ func (t enumType) IndexOf(v string) int { return -1 } -// Marshal takes a valid Enum value and returns it as an int64. -func (t enumType) Marshal(v interface{}) (int64, error) { - i, err := t.ConvertToIndex(v) - return int64(i), err -} - -// NumberOfElements returns the number of enumerations. +// NumberOfElements implements EnumType interface. func (t enumType) NumberOfElements() uint16 { return uint16(len(t.indexToVal)) } -// Unmarshal takes a previously-marshalled value and returns it as a string. -func (t enumType) Unmarshal(v int64) (string, error) { - str, found := t.At(int(v)) - if !found { - return "", ErrUnmarshallingEnum.New(v) - } - return str, nil -} - -// Values returns the elements, in order, of every enumeration. +// Values implements EnumType interface. func (t enumType) Values() []string { vals := make([]string, len(t.indexToVal)) copy(vals, t.indexToVal) diff --git a/sql/enumtype_test.go b/sql/enumtype_test.go index e44880d53c..d5d640f3bf 100644 --- a/sql/enumtype_test.go +++ b/sql/enumtype_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "strconv" "testing" "time" @@ -144,15 +145,15 @@ func TestEnumConvert(t *testing.T) { assert.Error(t, err) } else { require.NoError(t, err) - assert.Equal(t, test.expectedVal, val) if test.val != nil { - mar, err := typ.Marshal(test.val) - require.NoError(t, err) - umar, err := typ.Unmarshal(mar) - require.NoError(t, err) + umar, ok := typ.At(int(val.(uint16))) + require.True(t, ok) cmp, err := typ.Compare(test.val, umar) require.NoError(t, err) assert.Equal(t, 0, cmp) + assert.Equal(t, typ.ValueType(), reflect.TypeOf(val)) + } else { + assert.Equal(t, test.expectedVal, val) } } }) diff --git a/sql/expression/function/aggregation/window_partition_test.go b/sql/expression/function/aggregation/window_partition_test.go index 5adda99411..ecc4fcabea 100644 --- a/sql/expression/function/aggregation/window_partition_test.go +++ b/sql/expression/function/aggregation/window_partition_test.go @@ -96,15 +96,15 @@ func mustNewRowIter(t *testing.T, ctx *sql.Context) sql.RowIter { table := memory.NewTable("test", childSchema, nil) rows := []sql.Row{ - {int64(1), "forest", "leaf", 4}, - {int64(2), "forest", "bark", 4}, - {int64(3), "forest", "canopy", 6}, - {int64(4), "forest", "bug", 3}, - {int64(5), "forest", "wildflower", 10}, - {int64(6), "desert", "sand", 4}, - {int64(7), "desert", "cactus", 6}, - {int64(8), "desert", "scorpion", 8}, - {int64(9), "desert", "mummy", 5}, + {int64(1), "forest", "leaf", int32(4)}, + {int64(2), "forest", "bark", int32(4)}, + {int64(3), "forest", "canopy", int32(6)}, + {int64(4), "forest", "bug", int32(3)}, + {int64(5), "forest", "wildflower", int32(10)}, + {int64(6), "desert", "sand", int32(4)}, + {int64(7), "desert", "cactus", int32(6)}, + {int64(8), "desert", "scorpion", int32(8)}, + {int64(9), "desert", "mummy", int32(5)}, } for _, r := range rows { @@ -135,15 +135,15 @@ func TestWindowPartition_MaterializeInput(t *testing.T) { buf, ordering, err := i.materializeInput(ctx) require.NoError(t, err) expBuf := []sql.Row{ - {int64(1), "forest", "leaf", 4}, - {int64(2), "forest", "bark", 4}, - {int64(3), "forest", "canopy", 6}, - {int64(4), "forest", "bug", 3}, - {int64(5), "forest", "wildflower", 10}, - {int64(6), "desert", "sand", 4}, - {int64(7), "desert", "cactus", 6}, - {int64(8), "desert", "scorpion", 8}, - {int64(9), "desert", "mummy", 5}, + {int64(1), "forest", "leaf", int32(4)}, + {int64(2), "forest", "bark", int32(4)}, + {int64(3), "forest", "canopy", int32(6)}, + {int64(4), "forest", "bug", int32(3)}, + {int64(5), "forest", "wildflower", int32(10)}, + {int64(6), "desert", "sand", int32(4)}, + {int64(7), "desert", "cactus", int32(6)}, + {int64(8), "desert", "scorpion", int32(8)}, + {int64(9), "desert", "mummy", int32(5)}, } require.ElementsMatch(t, expBuf, buf) expOrd := []int{0, 1, 2, 3, 4, 5, 6, 7, 8} @@ -157,15 +157,15 @@ func TestWindowPartition_InitializePartitions(t *testing.T) { PartitionBy: partitionByX, }) i.input = []sql.Row{ - {int64(1), "forest", "leaf", 4}, - {int64(2), "forest", "bark", 4}, - {int64(3), "forest", "canopy", 6}, - {int64(4), "forest", "bug", 3}, - {int64(5), "forest", "wildflower", 10}, - {int64(6), "desert", "sand", 4}, - {int64(7), "desert", "cactus", 6}, - {int64(8), "desert", "scorpion", 8}, - {int64(9), "desert", "mummy", 5}, + {int64(1), "forest", "leaf", int32(4)}, + {int64(2), "forest", "bark", int32(4)}, + {int64(3), "forest", "canopy", int32(6)}, + {int64(4), "forest", "bug", int32(3)}, + {int64(5), "forest", "wildflower", int32(10)}, + {int64(6), "desert", "sand", int32(4)}, + {int64(7), "desert", "cactus", int32(6)}, + {int64(8), "desert", "scorpion", int32(8)}, + {int64(9), "desert", "mummy", int32(5)}, } partitions, err := i.initializePartitions(ctx) require.NoError(t, err) diff --git a/sql/expression/function/timediff.go b/sql/expression/function/timediff.go index ac4eb6dd0c..cd7cca3989 100644 --- a/sql/expression/function/timediff.go +++ b/sql/expression/function/timediff.go @@ -98,13 +98,12 @@ func (td *TimeDiff) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { rightDatetime = rightDatetime.In(leftDatetime.Location()) } return sql.Time.Convert(leftDatetime.Sub(rightDatetime)) - } else if leftTime, err := sql.Time.ConvertToTimeDuration(left); err == nil { - rightTime, err := sql.Time.ConvertToTimeDuration(right) + } else if leftTime, err := sql.Time.ConvertToTimespan(left); err == nil { + rightTime, err := sql.Time.ConvertToTimespan(right) if err != nil { return nil, err } - resTime := leftTime - rightTime - return sql.Time.Convert(resTime) + return leftTime.Subtract(rightTime), nil } else { return nil, ErrInvalidArgumentType.New("timediff") } diff --git a/sql/expression/function/timediff_test.go b/sql/expression/function/timediff_test.go index de8830b3db..00d8763374 100644 --- a/sql/expression/function/timediff_test.go +++ b/sql/expression/function/timediff_test.go @@ -26,19 +26,27 @@ import ( ) func TestTimeDiff(t *testing.T) { + toTimespan := func(str string) sql.Timespan { + res, err := sql.Time.ConvertToTimespan(str) + if err != nil { + t.Fatal(err) + } + return res + } + ctx := sql.NewEmptyContext() testCases := []struct { name string from sql.Expression to sql.Expression - expected string + expected sql.Timespan err bool }{ { "invalid type text", expression.NewLiteral("hello there", sql.Text), expression.NewConvert(expression.NewLiteral("01:00:00", sql.Text), expression.ConvertToTime), - "", + toTimespan(""), true, }, //TODO: handle Date properly @@ -53,70 +61,70 @@ func TestTimeDiff(t *testing.T) { "type mismatch 1", expression.NewLiteral(time.Date(2008, time.December, 29, 1, 1, 1, 2, time.Local), sql.Timestamp), expression.NewConvert(expression.NewLiteral("01:00:00", sql.Text), expression.ConvertToTime), - "", + toTimespan(""), true, }, { "type mismatch 2", expression.NewLiteral("00:00:00.2", sql.Text), expression.NewLiteral("2020-10-10 10:10:10", sql.Text), - "", + toTimespan(""), true, }, { "valid mismatch", expression.NewLiteral(time.Date(2008, time.December, 29, 1, 1, 1, 2, time.Local), sql.Timestamp), expression.NewLiteral(time.Date(2008, time.December, 30, 1, 1, 1, 2, time.Local), sql.Datetime), - "-24:00:00", + toTimespan("-24:00:00"), false, }, { "timestamp types 1", expression.NewLiteral(time.Date(2018, time.May, 2, 0, 0, 0, 0, time.Local), sql.Timestamp), expression.NewLiteral(time.Date(2018, time.May, 2, 0, 0, 1, 0, time.Local), sql.Timestamp), - "-00:00:01", + toTimespan("-00:00:01"), false, }, { "timestamp types 2", expression.NewLiteral(time.Date(2008, time.December, 31, 23, 59, 59, 1, time.Local), sql.Timestamp), expression.NewLiteral(time.Date(2008, time.December, 30, 1, 1, 1, 2, time.Local), sql.Timestamp), - "46:58:57.999999", + toTimespan("46:58:57.999999"), false, }, { "time types 1", expression.NewConvert(expression.NewLiteral("00:00:00.1", sql.Text), expression.ConvertToTime), expression.NewConvert(expression.NewLiteral("00:00:00.2", sql.Text), expression.ConvertToTime), - "-00:00:00.100000", + toTimespan("-00:00:00.100000"), false, }, { "time types 2", expression.NewLiteral("00:00:00.2", sql.Text), expression.NewLiteral("00:00:00.4", sql.Text), - "-00:00:00.200000", + toTimespan("-00:00:00.200000"), false, }, { "datetime types", expression.NewLiteral(time.Date(2008, time.December, 29, 0, 0, 0, 0, time.Local), sql.Datetime), expression.NewLiteral(time.Date(2008, time.December, 30, 0, 0, 0, 0, time.Local), sql.Datetime), - "-24:00:00", + toTimespan("-24:00:00"), false, }, { "datetime string types", expression.NewLiteral("2008-12-29 00:00:00", sql.Text), expression.NewLiteral("2008-12-30 00:00:00", sql.Text), - "-24:00:00", + toTimespan("-24:00:00"), false, }, { "datetime string mix types", expression.NewLiteral(time.Date(2008, time.December, 29, 0, 0, 0, 0, time.UTC), sql.Datetime), expression.NewLiteral("2008-12-30 00:00:00", sql.Text), - "-24:00:00", + toTimespan("-24:00:00"), false, }, } diff --git a/sql/geometry.go b/sql/geometry.go index 68c6316721..b05b818f86 100644 --- a/sql/geometry.go +++ b/sql/geometry.go @@ -17,21 +17,35 @@ package sql import ( "encoding/binary" "math" + "reflect" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" ) +// GeometryType represents the GEOMETRY type. +// https://dev.mysql.com/doc/refman/8.0/en/gis-class-geometry.html +// The type of the returned value is one of the following (each implements GeometryValue): Point, Polygon, LineString. type GeometryType struct { SRID uint32 DefinedSRID bool } +// GeometryValue is the value type returned from GeometryType, which is an interface over the following types: +// Point, Polygon, LineString. +type GeometryValue interface { + implementsGeometryValue() +} + var _ Type = GeometryType{} var _ SpatialColumnType = GeometryType{} -var ErrNotGeometry = errors.NewKind("Value of type %T is not a geometry") +var ( + ErrNotGeometry = errors.NewKind("Value of type %T is not a geometry") + + geometryValueType = reflect.TypeOf((*GeometryValue)(nil)).Elem() +) const ( CartesianSRID = uint32(0) @@ -262,7 +276,8 @@ func (t GeometryType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.Value{}, nil } - val := appendAndSlice(dest, []byte(pv.(string))) + //TODO: pretty sure this is wrong, pv is not a string type + val := appendAndSliceString(dest, pv.(string)) return sqltypes.MakeTrusted(sqltypes.Geometry, val), nil } @@ -277,6 +292,11 @@ func (t GeometryType) Type() query.Type { return sqltypes.Geometry } +// ValueType implements Type interface. +func (t GeometryType) ValueType() reflect.Type { + return geometryValueType +} + // Zero implements Type interface. func (t GeometryType) Zero() interface{} { // TODO: it doesn't make sense for geometry to have a zero type diff --git a/sql/json.go b/sql/json.go index d663aeaa6a..f29742de1b 100644 --- a/sql/json.go +++ b/sql/json.go @@ -16,16 +16,24 @@ package sql import ( "encoding/json" + "reflect" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" ) -var ErrConvertingToJSON = errors.NewKind("value %v is not valid JSON") +var ( + ErrConvertingToJSON = errors.NewKind("value %v is not valid JSON") + + jsonValueType = reflect.TypeOf((*JSONValue)(nil)).Elem() +) var JSON JsonType = jsonType{} +// JsonType represents the JSON type. +// https://dev.mysql.com/doc/refman/8.0/en/json.html +// The type of the returned value is JSONValue. type JsonType interface { Type } @@ -97,7 +105,7 @@ func (t jsonType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.NULL, err } - val := appendAndSlice(dest, []byte(s)) + val := appendAndSliceString(dest, s) return sqltypes.MakeTrusted(sqltypes.TypeJSON, val), nil } @@ -112,6 +120,11 @@ func (t jsonType) Type() query.Type { return sqltypes.TypeJSON } +// ValueType implements Type interface. +func (t jsonType) ValueType() reflect.Type { + return jsonValueType +} + // Zero implements Type interface. func (t jsonType) Zero() interface{} { // JSON Null diff --git a/sql/json_test.go b/sql/json_test.go index 6399107655..25deca792a 100644 --- a/sql/json_test.go +++ b/sql/json_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "testing" querypb "github.com/dolthub/vitess/go/vt/proto/query" @@ -122,6 +123,9 @@ func TestJsonConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.True(t, reflect.TypeOf(val).Implements(JSON.ValueType())) + } } }) } diff --git a/sql/linestring.go b/sql/linestring.go index faef48dc8b..9c70597ab9 100644 --- a/sql/linestring.go +++ b/sql/linestring.go @@ -15,28 +15,37 @@ package sql import ( + "reflect" + "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) -// Represents the LineString type. +// LineStringType represents the LINESTRING type. // https://dev.mysql.com/doc/refman/8.0/en/gis-class-linestring.html -type LineString struct { - SRID uint32 - Points []Point -} - +// The type of the returned value is LineString. type LineStringType struct { SRID uint32 DefinedSRID bool } +// LineString is the value type returned from LineStringType. Implements GeometryValue. +type LineString struct { + SRID uint32 + Points []Point +} + var _ Type = LineStringType{} var _ SpatialColumnType = LineStringType{} +var _ GeometryValue = LineString{} + +var ( + ErrNotLineString = errors.NewKind("value of type %T is not a linestring") -var ErrNotLineString = errors.NewKind("value of type %T is not a linestring") + lineStringValueType = reflect.TypeOf(LineString{}) +) // Compare implements Type interface. func (t LineStringType) Compare(a interface{}, b interface{}) (int, error) { @@ -146,7 +155,8 @@ func (t LineStringType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) return sqltypes.Value{}, nil } - val := appendAndSlice(dest, []byte(pv.(string))) + //TODO: pretty sure this is wrong, pv is not a string type + val := appendAndSliceString(dest, pv.(string)) return sqltypes.MakeTrusted(sqltypes.Geometry, val), nil } @@ -161,6 +171,11 @@ func (t LineStringType) Type() query.Type { return sqltypes.Geometry } +// ValueType implements Type interface. +func (t LineStringType) ValueType() reflect.Type { + return lineStringValueType +} + // Zero implements Type interface. func (t LineStringType) Zero() interface{} { return LineString{Points: []Point{{}, {}}} @@ -191,3 +206,6 @@ func (t LineStringType) MatchSRID(v interface{}) error { } return ErrNotMatchingSRID.New(val.SRID, t.SRID) } + +// implementsGeometryValue implements GeometryValue interface. +func (p LineString) implementsGeometryValue() {} diff --git a/sql/mysql_db/db_table.go b/sql/mysql_db/db_table.go index 764f5b5c3d..13dfa669d6 100644 --- a/sql/mysql_db/db_table.go +++ b/sql/mysql_db/db_table.go @@ -74,7 +74,7 @@ func (conv DbConverter) AddRowToEntry(ctx *sql.Context, row sql.Row, entry in_me } var privs []sql.PrivilegeType for i, val := range row { - if strVal, ok := val.(string); ok && strVal == "Y" { + if uintVal, ok := val.(uint16); ok && uintVal == 2 { switch i { case dbTblColIndex_Select_priv: privs = append(privs, sql.PrivilegeType_Select) @@ -167,43 +167,43 @@ func (conv DbConverter) EntryToRows(ctx *sql.Context, entry in_mem_table.Entry) for _, priv := range dbSet.ToSlice() { switch priv { case sql.PrivilegeType_Select: - row[dbTblColIndex_Select_priv] = "Y" + row[dbTblColIndex_Select_priv] = uint16(2) case sql.PrivilegeType_Insert: - row[dbTblColIndex_Insert_priv] = "Y" + row[dbTblColIndex_Insert_priv] = uint16(2) case sql.PrivilegeType_Update: - row[dbTblColIndex_Update_priv] = "Y" + row[dbTblColIndex_Update_priv] = uint16(2) case sql.PrivilegeType_Delete: - row[dbTblColIndex_Delete_priv] = "Y" + row[dbTblColIndex_Delete_priv] = uint16(2) case sql.PrivilegeType_Create: - row[dbTblColIndex_Create_priv] = "Y" + row[dbTblColIndex_Create_priv] = uint16(2) case sql.PrivilegeType_Drop: - row[dbTblColIndex_Drop_priv] = "Y" + row[dbTblColIndex_Drop_priv] = uint16(2) case sql.PrivilegeType_Grant: - row[dbTblColIndex_Grant_priv] = "Y" + row[dbTblColIndex_Grant_priv] = uint16(2) case sql.PrivilegeType_References: - row[dbTblColIndex_References_priv] = "Y" + row[dbTblColIndex_References_priv] = uint16(2) case sql.PrivilegeType_Index: - row[dbTblColIndex_Index_priv] = "Y" + row[dbTblColIndex_Index_priv] = uint16(2) case sql.PrivilegeType_Alter: - row[dbTblColIndex_Alter_priv] = "Y" + row[dbTblColIndex_Alter_priv] = uint16(2) case sql.PrivilegeType_CreateTempTable: - row[dbTblColIndex_Create_tmp_table_priv] = "Y" + row[dbTblColIndex_Create_tmp_table_priv] = uint16(2) case sql.PrivilegeType_LockTables: - row[dbTblColIndex_Lock_tables_priv] = "Y" + row[dbTblColIndex_Lock_tables_priv] = uint16(2) case sql.PrivilegeType_CreateView: - row[dbTblColIndex_Create_view_priv] = "Y" + row[dbTblColIndex_Create_view_priv] = uint16(2) case sql.PrivilegeType_ShowView: - row[dbTblColIndex_Show_view_priv] = "Y" + row[dbTblColIndex_Show_view_priv] = uint16(2) case sql.PrivilegeType_CreateRoutine: - row[dbTblColIndex_Create_routine_priv] = "Y" + row[dbTblColIndex_Create_routine_priv] = uint16(2) case sql.PrivilegeType_AlterRoutine: - row[dbTblColIndex_Alter_routine_priv] = "Y" + row[dbTblColIndex_Alter_routine_priv] = uint16(2) case sql.PrivilegeType_Execute: - row[dbTblColIndex_Execute_priv] = "Y" + row[dbTblColIndex_Execute_priv] = uint16(2) case sql.PrivilegeType_Event: - row[dbTblColIndex_Event_priv] = "Y" + row[dbTblColIndex_Event_priv] = uint16(2) case sql.PrivilegeType_Trigger: - row[dbTblColIndex_Trigger_priv] = "Y" + row[dbTblColIndex_Trigger_priv] = uint16(2) } } rows = append(rows, row) diff --git a/sql/mysql_db/role_edge.go b/sql/mysql_db/role_edge.go index 363c4c6584..02361a35d0 100644 --- a/sql/mysql_db/role_edge.go +++ b/sql/mysql_db/role_edge.go @@ -44,7 +44,7 @@ func (r *RoleEdge) NewFromRow(ctx *sql.Context, row sql.Row) (in_mem_table.Entry FromUser: row[roleEdgesTblColIndex_FROM_USER].(string), ToHost: row[roleEdgesTblColIndex_TO_HOST].(string), ToUser: row[roleEdgesTblColIndex_TO_USER].(string), - WithAdminOption: row[roleEdgesTblColIndex_WITH_ADMIN_OPTION].(string) == "Y", + WithAdminOption: row[roleEdgesTblColIndex_WITH_ADMIN_OPTION].(uint16) == 2, }, nil } @@ -61,9 +61,9 @@ func (r *RoleEdge) ToRow(ctx *sql.Context) sql.Row { row[roleEdgesTblColIndex_TO_HOST] = r.ToHost row[roleEdgesTblColIndex_TO_USER] = r.ToUser if r.WithAdminOption { - row[roleEdgesTblColIndex_WITH_ADMIN_OPTION] = "Y" + row[roleEdgesTblColIndex_WITH_ADMIN_OPTION] = uint16(2) } else { - row[roleEdgesTblColIndex_WITH_ADMIN_OPTION] = "N" + row[roleEdgesTblColIndex_WITH_ADMIN_OPTION] = uint16(1) } return row } diff --git a/sql/mysql_db/tables_priv.go b/sql/mysql_db/tables_priv.go index 235b7ae6f8..783f25a1d9 100644 --- a/sql/mysql_db/tables_priv.go +++ b/sql/mysql_db/tables_priv.go @@ -78,12 +78,16 @@ func (conv TablesPrivConverter) AddRowToEntry(ctx *sql.Context, row sql.Row, ent if !ok { return nil, errTablesPrivRow } - tablePrivs, ok := row[tablesPrivTblColIndex_Table_priv].(string) + tablePrivs, ok := row[tablesPrivTblColIndex_Table_priv].(uint64) if !ok { return nil, errTablesPrivRow } + tablePrivStrs, err := tablesPrivTblSchema[tablesPrivTblColIndex_Table_priv].Type.(sql.SetType).BitsToString(tablePrivs) + if err != nil { + return nil, err + } var privs []sql.PrivilegeType - for _, val := range strings.Split(tablePrivs, ",") { + for _, val := range strings.Split(tablePrivStrs, ",") { switch val { case "Select": privs = append(privs, sql.PrivilegeType_Select) @@ -204,7 +208,7 @@ func (conv TablesPrivConverter) EntryToRows(ctx *sql.Context, entry in_mem_table if err != nil { return nil, err } - row[tablesPrivTblColIndex_Table_priv] = formattedSet.(string) + row[tablesPrivTblColIndex_Table_priv] = formattedSet.(uint64) rows = append(rows, row) } } diff --git a/sql/mysql_db/user.go b/sql/mysql_db/user.go index 334f18c3cd..fc5f121a9a 100644 --- a/sql/mysql_db/user.go +++ b/sql/mysql_db/user.go @@ -64,7 +64,7 @@ func (u *User) NewFromRow(ctx *sql.Context, row sql.Row) (in_mem_table.Entry, er Plugin: row[userTblColIndex_plugin].(string), Password: row[userTblColIndex_authentication_string].(string), PasswordLastChanged: passwordLastChanged, - Locked: row[userTblColIndex_account_locked].(string) == "Y", + Locked: row[userTblColIndex_account_locked].(uint16) == 2, Attributes: attributes, IsRole: false, }, nil @@ -97,7 +97,7 @@ func (u *User) ToRow(ctx *sql.Context) sql.Row { row[userTblColIndex_authentication_string] = u.Password row[userTblColIndex_password_last_changed] = u.PasswordLastChanged if u.Locked { - row[userTblColIndex_account_locked] = "Y" + row[userTblColIndex_account_locked] = uint16(2) } if u.Attributes != nil { row[userTblColIndex_User_attributes] = *u.Attributes @@ -172,127 +172,127 @@ func (u *User) rowToPrivSet(ctx *sql.Context, row sql.Row) PrivilegeSet { for i, val := range row { switch i { case userTblColIndex_Select_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Select) } case userTblColIndex_Insert_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Insert) } case userTblColIndex_Update_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Update) } case userTblColIndex_Delete_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Delete) } case userTblColIndex_Create_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Create) } case userTblColIndex_Drop_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Drop) } case userTblColIndex_Reload_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Reload) } case userTblColIndex_Shutdown_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Shutdown) } case userTblColIndex_Process_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Process) } case userTblColIndex_File_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_File) } case userTblColIndex_Grant_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Grant) } case userTblColIndex_References_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_References) } case userTblColIndex_Index_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Index) } case userTblColIndex_Alter_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Alter) } case userTblColIndex_Show_db_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_ShowDB) } case userTblColIndex_Super_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Super) } case userTblColIndex_Create_tmp_table_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateTempTable) } case userTblColIndex_Lock_tables_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_LockTables) } case userTblColIndex_Execute_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Execute) } case userTblColIndex_Repl_slave_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_ReplicationSlave) } case userTblColIndex_Repl_client_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_ReplicationClient) } case userTblColIndex_Create_view_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateView) } case userTblColIndex_Show_view_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_ShowView) } case userTblColIndex_Create_routine_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateRoutine) } case userTblColIndex_Alter_routine_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_AlterRoutine) } case userTblColIndex_Create_user_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateUser) } case userTblColIndex_Event_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Event) } case userTblColIndex_Trigger_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_Trigger) } case userTblColIndex_Create_tablespace_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateTablespace) } case userTblColIndex_Create_role_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_CreateRole) } case userTblColIndex_Drop_role_priv: - if val.(string) == "Y" { + if val.(uint16) == 2 { privSet.AddGlobalStatic(sql.PrivilegeType_DropRole) } } @@ -306,67 +306,67 @@ func (u *User) privSetToRow(ctx *sql.Context, row sql.Row) { for _, priv := range u.PrivilegeSet.ToSlice() { switch priv { case sql.PrivilegeType_Select: - row[userTblColIndex_Select_priv] = "Y" + row[userTblColIndex_Select_priv] = uint16(2) case sql.PrivilegeType_Insert: - row[userTblColIndex_Insert_priv] = "Y" + row[userTblColIndex_Insert_priv] = uint16(2) case sql.PrivilegeType_Update: - row[userTblColIndex_Update_priv] = "Y" + row[userTblColIndex_Update_priv] = uint16(2) case sql.PrivilegeType_Delete: - row[userTblColIndex_Delete_priv] = "Y" + row[userTblColIndex_Delete_priv] = uint16(2) case sql.PrivilegeType_Create: - row[userTblColIndex_Create_priv] = "Y" + row[userTblColIndex_Create_priv] = uint16(2) case sql.PrivilegeType_Drop: - row[userTblColIndex_Drop_priv] = "Y" + row[userTblColIndex_Drop_priv] = uint16(2) case sql.PrivilegeType_Reload: - row[userTblColIndex_Reload_priv] = "Y" + row[userTblColIndex_Reload_priv] = uint16(2) case sql.PrivilegeType_Shutdown: - row[userTblColIndex_Shutdown_priv] = "Y" + row[userTblColIndex_Shutdown_priv] = uint16(2) case sql.PrivilegeType_Process: - row[userTblColIndex_Process_priv] = "Y" + row[userTblColIndex_Process_priv] = uint16(2) case sql.PrivilegeType_File: - row[userTblColIndex_File_priv] = "Y" + row[userTblColIndex_File_priv] = uint16(2) case sql.PrivilegeType_Grant: - row[userTblColIndex_Grant_priv] = "Y" + row[userTblColIndex_Grant_priv] = uint16(2) case sql.PrivilegeType_References: - row[userTblColIndex_References_priv] = "Y" + row[userTblColIndex_References_priv] = uint16(2) case sql.PrivilegeType_Index: - row[userTblColIndex_Index_priv] = "Y" + row[userTblColIndex_Index_priv] = uint16(2) case sql.PrivilegeType_Alter: - row[userTblColIndex_Alter_priv] = "Y" + row[userTblColIndex_Alter_priv] = uint16(2) case sql.PrivilegeType_ShowDB: - row[userTblColIndex_Show_db_priv] = "Y" + row[userTblColIndex_Show_db_priv] = uint16(2) case sql.PrivilegeType_Super: - row[userTblColIndex_Super_priv] = "Y" + row[userTblColIndex_Super_priv] = uint16(2) case sql.PrivilegeType_CreateTempTable: - row[userTblColIndex_Create_tmp_table_priv] = "Y" + row[userTblColIndex_Create_tmp_table_priv] = uint16(2) case sql.PrivilegeType_LockTables: - row[userTblColIndex_Lock_tables_priv] = "Y" + row[userTblColIndex_Lock_tables_priv] = uint16(2) case sql.PrivilegeType_Execute: - row[userTblColIndex_Execute_priv] = "Y" + row[userTblColIndex_Execute_priv] = uint16(2) case sql.PrivilegeType_ReplicationSlave: - row[userTblColIndex_Repl_slave_priv] = "Y" + row[userTblColIndex_Repl_slave_priv] = uint16(2) case sql.PrivilegeType_ReplicationClient: - row[userTblColIndex_Repl_client_priv] = "Y" + row[userTblColIndex_Repl_client_priv] = uint16(2) case sql.PrivilegeType_CreateView: - row[userTblColIndex_Create_view_priv] = "Y" + row[userTblColIndex_Create_view_priv] = uint16(2) case sql.PrivilegeType_ShowView: - row[userTblColIndex_Show_view_priv] = "Y" + row[userTblColIndex_Show_view_priv] = uint16(2) case sql.PrivilegeType_CreateRoutine: - row[userTblColIndex_Create_routine_priv] = "Y" + row[userTblColIndex_Create_routine_priv] = uint16(2) case sql.PrivilegeType_AlterRoutine: - row[userTblColIndex_Alter_routine_priv] = "Y" + row[userTblColIndex_Alter_routine_priv] = uint16(2) case sql.PrivilegeType_CreateUser: - row[userTblColIndex_Create_user_priv] = "Y" + row[userTblColIndex_Create_user_priv] = uint16(2) case sql.PrivilegeType_Event: - row[userTblColIndex_Event_priv] = "Y" + row[userTblColIndex_Event_priv] = uint16(2) case sql.PrivilegeType_Trigger: - row[userTblColIndex_Trigger_priv] = "Y" + row[userTblColIndex_Trigger_priv] = uint16(2) case sql.PrivilegeType_CreateTablespace: - row[userTblColIndex_Create_tablespace_priv] = "Y" + row[userTblColIndex_Create_tablespace_priv] = uint16(2) case sql.PrivilegeType_CreateRole: - row[userTblColIndex_Create_role_priv] = "Y" + row[userTblColIndex_Create_role_priv] = uint16(2) case sql.PrivilegeType_DropRole: - row[userTblColIndex_Drop_role_priv] = "Y" + row[userTblColIndex_Drop_role_priv] = uint16(2) } } } diff --git a/sql/nulltype.go b/sql/nulltype.go index 0af99c9e38..a37226eb42 100644 --- a/sql/nulltype.go +++ b/sql/nulltype.go @@ -15,6 +15,8 @@ package sql import ( + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -83,6 +85,11 @@ func (t nullType) Type() query.Type { return sqltypes.Null } +// ValueType implements Type interface. +func (t nullType) ValueType() reflect.Type { + return nil +} + // Zero implements Type interface. func (t nullType) Zero() interface{} { return nil diff --git a/sql/numbertype.go b/sql/numbertype.go index 64de6b91e5..0b6d703b92 100644 --- a/sql/numbertype.go +++ b/sql/numbertype.go @@ -18,6 +18,7 @@ import ( "encoding/hex" "fmt" "math" + "reflect" "strconv" "time" @@ -74,11 +75,23 @@ var ( dec_int64_min = decimal.NewFromInt(math.MinInt64) // decimal that represents the zero value dec_zero = decimal.NewFromInt(0) + + numberInt8ValueType = reflect.TypeOf(int8(0)) + numberInt16ValueType = reflect.TypeOf(int16(0)) + numberInt32ValueType = reflect.TypeOf(int32(0)) + numberInt64ValueType = reflect.TypeOf(int64(0)) + numberUint8ValueType = reflect.TypeOf(uint8(0)) + numberUint16ValueType = reflect.TypeOf(uint16(0)) + numberUint32ValueType = reflect.TypeOf(uint32(0)) + numberUint64ValueType = reflect.TypeOf(uint64(0)) + numberFloat32ValueType = reflect.TypeOf(float32(0)) + numberFloat64ValueType = reflect.TypeOf(float64(0)) ) -// Represents all integer and floating point types. +// NumberType represents all integer and floating point types. // https://dev.mysql.com/doc/refman/8.0/en/integer-types.html // https://dev.mysql.com/doc/refman/8.0/en/floating-point-types.html +// The type of the returned value is one of the following: int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64. type NumberType interface { Type IsSigned() bool @@ -597,6 +610,38 @@ func (t numberTypeImpl) Type() query.Type { return t.baseType } +// ValueType implements Type interface. +func (t numberTypeImpl) ValueType() reflect.Type { + switch t.baseType { + case sqltypes.Int8: + return numberInt8ValueType + case sqltypes.Uint8: + return numberUint8ValueType + case sqltypes.Int16: + return numberInt16ValueType + case sqltypes.Uint16: + return numberUint16ValueType + case sqltypes.Int24: + return numberInt32ValueType + case sqltypes.Uint24: + return numberUint32ValueType + case sqltypes.Int32: + return numberInt32ValueType + case sqltypes.Uint32: + return numberUint32ValueType + case sqltypes.Int64: + return numberInt64ValueType + case sqltypes.Uint64: + return numberUint64ValueType + case sqltypes.Float32: + return numberFloat32ValueType + case sqltypes.Float64: + return numberFloat64ValueType + default: + panic(fmt.Sprintf("%v is not a valid number base type", t.baseType.String())) + } +} + // Zero implements Type interface. func (t numberTypeImpl) Zero() interface{} { switch t.baseType { diff --git a/sql/numbertype_test.go b/sql/numbertype_test.go index 925b1bf9fd..297ff1b912 100644 --- a/sql/numbertype_test.go +++ b/sql/numbertype_test.go @@ -17,6 +17,7 @@ package sql import ( "fmt" "math" + "reflect" "strconv" "testing" "time" @@ -227,6 +228,9 @@ func TestNumberConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.Equal(t, test.typ.ValueType(), reflect.TypeOf(val)) + } } }) } diff --git a/sql/plan/common_test.go b/sql/plan/common_test.go index 1c2327bdc0..6eb470adf1 100644 --- a/sql/plan/common_test.go +++ b/sql/plan/common_test.go @@ -39,15 +39,19 @@ var benchtable = func() *memory.Table { for i := 0; i < 100; i++ { n := fmt.Sprint(i) + boolVal := int8(0) + if i%2 == 0 { + boolVal = 1 + } err := t.Insert( sql.NewEmptyContext(), sql.NewRow( repeatStr(n, i%10+1), float64(i), - i%2 == 0, + boolVal, int32(i), int64(i), - []byte(repeatStr(n, 100+(i%100))), + repeatStr(n, 100+(i%100)), ), ) if err != nil { @@ -60,10 +64,10 @@ var benchtable = func() *memory.Table { sql.NewRow( repeatStr(n, i%10+1), float64(i), - i%2 == 0, + boolVal, int32(i), int64(i), - []byte(repeatStr(n, 100+(i%100))), + repeatStr(n, 100+(i%100)), ), ) if err != nil { diff --git a/sql/plan/external_procedure.go b/sql/plan/external_procedure.go index c13a66768e..750d281b77 100644 --- a/sql/plan/external_procedure.go +++ b/sql/plan/external_procedure.go @@ -183,11 +183,7 @@ func (n *ExternalProcedure) processParam(ctx *sql.Context, funcParamType reflect exprParamVal = int(exprParamVal.(uint64)) } case decimalType: - var err error - exprParamVal, err = decimal.NewFromString(exprParamVal.(string)) - if err != nil { - return reflect.Value{}, err - } + exprParamVal = exprParamVal.(decimal.Decimal) } if funcParamType.Kind() == reflect.Ptr { // Coincides with INOUT diff --git a/sql/plan/sort_test.go b/sql/plan/sort_test.go index 15e03980bf..1fcf94481a 100644 --- a/sql/plan/sort_test.go +++ b/sql/plan/sort_test.go @@ -42,9 +42,9 @@ func TestSort(t *testing.T) { { rows: []sql.Row{ sql.NewRow("c", nil, nil), - sql.NewRow("a", int32(3), 3.0), - sql.NewRow("b", int32(3), 3.0), - sql.NewRow("c", int32(1), 1.0), + sql.NewRow("a", int32(3), float64(3.0)), + sql.NewRow("b", int32(3), float64(3.0)), + sql.NewRow("c", int32(1), float64(1.0)), sql.NewRow(nil, int32(1), nil), }, sortFields: []sql.SortField{ @@ -55,14 +55,14 @@ func TestSort(t *testing.T) { expected: []sql.Row{ sql.NewRow("c", nil, nil), sql.NewRow(nil, int32(1), nil), - sql.NewRow("c", int32(1), 1.0), - sql.NewRow("b", int32(3), 3.0), - sql.NewRow("a", int32(3), 3.0), + sql.NewRow("c", int32(1), float64(1.0)), + sql.NewRow("b", int32(3), float64(3.0)), + sql.NewRow("a", int32(3), float64(3.0)), }, }, { rows: []sql.Row{ - sql.NewRow("c", int32(3), 3.0), + sql.NewRow("c", int32(3), float64(3.0)), sql.NewRow("c", int32(3), nil), }, sortFields: []sql.SortField{ @@ -72,15 +72,15 @@ func TestSort(t *testing.T) { }, expected: []sql.Row{ sql.NewRow("c", int32(3), nil), - sql.NewRow("c", int32(3), 3.0), + sql.NewRow("c", int32(3), float64(3.0)), }, }, { rows: []sql.Row{ sql.NewRow("c", nil, nil), - sql.NewRow("a", int32(3), 3.0), - sql.NewRow("b", int32(3), 3.0), - sql.NewRow("c", int32(1), 1.0), + sql.NewRow("a", int32(3), float64(3.0)), + sql.NewRow("b", int32(3), float64(3.0)), + sql.NewRow("c", int32(1), float64(1.0)), sql.NewRow(nil, int32(1), nil), }, sortFields: []sql.SortField{ @@ -91,15 +91,15 @@ func TestSort(t *testing.T) { expected: []sql.Row{ sql.NewRow("c", nil, nil), sql.NewRow(nil, int32(1), nil), - sql.NewRow("c", int32(1), 1.0), - sql.NewRow("a", int32(3), 3.0), - sql.NewRow("b", int32(3), 3.0), + sql.NewRow("c", int32(1), float64(1.0)), + sql.NewRow("a", int32(3), float64(3.0)), + sql.NewRow("b", int32(3), float64(3.0)), }, }, { rows: []sql.Row{ - sql.NewRow("a", int32(1), 2), - sql.NewRow("a", int32(1), 1), + sql.NewRow("a", int32(1), float64(2)), + sql.NewRow("a", int32(1), float64(1)), }, sortFields: []sql.SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, @@ -107,18 +107,18 @@ func TestSort(t *testing.T) { {Column: expression.NewGetField(2, sql.Float64, "col3", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, }, expected: []sql.Row{ - sql.NewRow("a", int32(1), 1), - sql.NewRow("a", int32(1), 2), + sql.NewRow("a", int32(1), float64(1)), + sql.NewRow("a", int32(1), float64(2)), }, }, { rows: []sql.Row{ - sql.NewRow("a", int32(1), 2), - sql.NewRow("a", int32(1), 1), - sql.NewRow("a", int32(2), 2), - sql.NewRow("a", int32(3), 1), - sql.NewRow("b", int32(2), 2), - sql.NewRow("c", int32(3), 1), + sql.NewRow("a", int32(1), float64(2)), + sql.NewRow("a", int32(1), float64(1)), + sql.NewRow("a", int32(2), float64(2)), + sql.NewRow("a", int32(3), float64(1)), + sql.NewRow("b", int32(2), float64(2)), + sql.NewRow("c", int32(3), float64(1)), }, sortFields: []sql.SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, @@ -126,18 +126,18 @@ func TestSort(t *testing.T) { {Column: expression.NewGetField(2, sql.Float64, "col3", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, }, expected: []sql.Row{ - sql.NewRow("a", int32(3), 1), - sql.NewRow("a", int32(2), 2), - sql.NewRow("a", int32(1), 1), - sql.NewRow("a", int32(1), 2), - sql.NewRow("b", int32(2), 2), - sql.NewRow("c", int32(3), 1), + sql.NewRow("a", int32(3), float64(1)), + sql.NewRow("a", int32(2), float64(2)), + sql.NewRow("a", int32(1), float64(1)), + sql.NewRow("a", int32(1), float64(2)), + sql.NewRow("b", int32(2), float64(2)), + sql.NewRow("c", int32(3), float64(1)), }, }, { rows: []sql.Row{ - sql.NewRow(nil, nil, 2), - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(2)), + sql.NewRow(nil, nil, float64(1)), }, sortFields: []sql.SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, @@ -145,14 +145,14 @@ func TestSort(t *testing.T) { {Column: expression.NewGetField(2, sql.Float64, "col3", true), Order: sql.Ascending, NullOrdering: sql.NullsFirst}, }, expected: []sql.Row{ - sql.NewRow(nil, nil, 1), - sql.NewRow(nil, nil, 2), + sql.NewRow(nil, nil, float64(1)), + sql.NewRow(nil, nil, float64(2)), }, }, { rows: []sql.Row{ - sql.NewRow(nil, nil, 1), - sql.NewRow(nil, nil, 2), + sql.NewRow(nil, nil, float64(1)), + sql.NewRow(nil, nil, float64(2)), }, sortFields: []sql.SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: sql.Descending, NullOrdering: sql.NullsFirst}, @@ -160,13 +160,13 @@ func TestSort(t *testing.T) { {Column: expression.NewGetField(2, sql.Float64, "col3", true), Order: sql.Descending, NullOrdering: sql.NullsFirst}, }, expected: []sql.Row{ - sql.NewRow(nil, nil, 2), - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(2)), + sql.NewRow(nil, nil, float64(1)), }, }, { rows: []sql.Row{ - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(1)), sql.NewRow(nil, nil, nil), }, sortFields: []sql.SortField{ @@ -176,13 +176,13 @@ func TestSort(t *testing.T) { }, expected: []sql.Row{ sql.NewRow(nil, nil, nil), - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(1)), }, }, { rows: []sql.Row{ sql.NewRow(nil, nil, nil), - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(1)), }, sortFields: []sql.SortField{ {Column: expression.NewGetField(0, sql.Text, "col1", true), Order: sql.Ascending, NullOrdering: sql.NullsLast}, @@ -190,7 +190,7 @@ func TestSort(t *testing.T) { {Column: expression.NewGetField(2, sql.Float64, "col3", true), Order: sql.Ascending, NullOrdering: sql.NullsLast}, }, expected: []sql.Row{ - sql.NewRow(nil, nil, 1), + sql.NewRow(nil, nil, float64(1)), sql.NewRow(nil, nil, nil), }, }, diff --git a/sql/point.go b/sql/point.go index b5b832bed4..b6f19e125f 100644 --- a/sql/point.go +++ b/sql/point.go @@ -15,28 +15,37 @@ package sql import ( + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" ) -// Represents the Point type. +// PointType represents the POINT type. // https://dev.mysql.com/doc/refman/8.0/en/gis-class-point.html +// The type of the returned value is Point. +type PointType struct { + SRID uint32 + DefinedSRID bool +} + +// Point is the value type returned from PointType. Implements GeometryValue. type Point struct { SRID uint32 X float64 Y float64 } -type PointType struct { - SRID uint32 - DefinedSRID bool -} - var _ Type = PointType{} var _ SpatialColumnType = PointType{} +var _ GeometryValue = Point{} -var ErrNotPoint = errors.NewKind("value of type %T is not a point") +var ( + ErrNotPoint = errors.NewKind("value of type %T is not a point") + + pointValueType = reflect.TypeOf(Point{}) +) // Compare implements Type interface. func (t PointType) Compare(a interface{}, b interface{}) (int, error) { @@ -133,7 +142,10 @@ func (t PointType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.Value{}, nil } - return sqltypes.MakeTrusted(sqltypes.Geometry, []byte(pv.(string))), nil + //TODO: pretty sure this is wrong, pv is not a string type + val := appendAndSliceString(dest, pv.(string)) + + return sqltypes.MakeTrusted(sqltypes.Geometry, val), nil } // String implements Type interface. @@ -151,6 +163,11 @@ func (t PointType) Zero() interface{} { return Point{X: 0.0, Y: 0.0} } +// ValueType implements Type interface. +func (t PointType) ValueType() reflect.Type { + return pointValueType +} + // GetSpatialTypeSRID implements SpatialColumnType interface. func (t PointType) GetSpatialTypeSRID() (uint32, bool) { return t.SRID, t.DefinedSRID @@ -176,3 +193,6 @@ func (t PointType) MatchSRID(v interface{}) error { } return ErrNotMatchingSRID.New(val.SRID, t.SRID) } + +// implementsGeometryValue implements GeometryValue interface. +func (p Point) implementsGeometryValue() {} diff --git a/sql/polygon.go b/sql/polygon.go index 01b32259b8..b4055b7cf2 100644 --- a/sql/polygon.go +++ b/sql/polygon.go @@ -15,28 +15,37 @@ package sql import ( + "reflect" + "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) -// Represents the Polygon type. +// PolygonType represents the POLYGON type. // https://dev.mysql.com/doc/refman/8.0/en/gis-class-polygon.html -type Polygon struct { - SRID uint32 - Lines []LineString -} - +// The type of the returned value is Polygon. type PolygonType struct { SRID uint32 DefinedSRID bool } +// Polygon is the value type returned from PolygonType. Implements GeometryValue. +type Polygon struct { + SRID uint32 + Lines []LineString +} + var _ Type = PolygonType{} var _ SpatialColumnType = PolygonType{} +var _ GeometryValue = Polygon{} -var ErrNotPolygon = errors.NewKind("value of type %T is not a polygon") +var ( + ErrNotPolygon = errors.NewKind("value of type %T is not a polygon") + + polygonValueType = reflect.TypeOf(Polygon{}) +) // Compare implements Type interface. func (t PolygonType) Compare(a interface{}, b interface{}) (int, error) { @@ -146,7 +155,10 @@ func (t PolygonType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.Value{}, nil } - return sqltypes.MakeTrusted(sqltypes.Geometry, []byte(lv.(string))), nil + //TODO: pretty sure this is wrong, lv is not a string type + val := appendAndSliceString(dest, lv.(string)) + + return sqltypes.MakeTrusted(sqltypes.Geometry, val), nil } // String implements Type interface. @@ -159,6 +171,11 @@ func (t PolygonType) Type() query.Type { return sqltypes.Geometry } +// ValueType implements Type interface. +func (t PolygonType) ValueType() reflect.Type { + return polygonValueType +} + // Zero implements Type interface. func (t PolygonType) Zero() interface{} { return Polygon{Lines: []LineString{{Points: []Point{{}, {}, {}, {}}}}} @@ -189,3 +206,6 @@ func (t PolygonType) MatchSRID(v interface{}) error { } return ErrNotMatchingSRID.New(val.SRID, t.SRID) } + +// implementsGeometryValue implements GeometryValue interface. +func (p Polygon) implementsGeometryValue() {} diff --git a/sql/settype.go b/sql/settype.go index 4def914722..64fb02777f 100644 --- a/sql/settype.go +++ b/sql/settype.go @@ -18,9 +18,12 @@ import ( "fmt" "math" "math/bits" + "reflect" "strconv" "strings" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -36,20 +39,24 @@ var ( ErrDuplicateEntrySet = errors.NewKind("duplicate entry: %v") ErrInvalidSetValue = errors.NewKind("value %v was not found in the set") ErrTooLargeForSet = errors.NewKind(`value "%v" is too large for this set`) + + setValueType = reflect.TypeOf(uint64(0)) ) // Comments with three slashes were taken directly from the linked documentation. -// Represents the SET type. +// SetType represents the SET type. // https://dev.mysql.com/doc/refman/8.0/en/set.html +// The type of the returned value is uint64. type SetType interface { Type CharacterSet() CharacterSet Collation() Collation - //TODO: move this out of go-mysql-server and into the Dolt layer - Marshal(v interface{}) (uint64, error) + // NumberOfElements returns the number of elements in this set. NumberOfElements() uint16 - Unmarshal(bits uint64) (string, error) + // BitsToString takes a previously-converted value and returns it as a string. + BitsToString(bits uint64) (string, error) + // Values returns all of the set's values in ascending order according to their corresponding bit value. Values() []string } @@ -121,18 +128,20 @@ func (t setType) Compare(a interface{}, b interface{}) (int, error) { return res, nil } - ai, err := t.Marshal(a) + ai, err := t.Convert(a) if err != nil { return 0, err } - bi, err := t.Marshal(b) + bi, err := t.Convert(b) if err != nil { return 0, err } + au := ai.(uint64) + bu := bi.(uint64) - if ai < bi { + if au < bu { return -1, nil - } else if ai > bi { + } else if au > bu { return 1, nil } return 0, nil @@ -166,26 +175,26 @@ func (t setType) Convert(v interface{}) (interface{}, error) { return t.Convert(uint64(value)) case uint64: if value <= t.allValuesBitField() { - return t.convertBitFieldToString(value) + return value, nil } - return nil, ErrConvertingToSet.New(v) case float32: return t.Convert(uint64(value)) case float64: return t.Convert(uint64(value)) - case string: - // For SET('a','b') and given a string 'b,a,a', we would return 'a,b', so we can't return the input. - bitField, err := t.convertStringToBitField(value) - if err != nil { - return nil, err + case decimal.Decimal: + return t.Convert(value.BigInt().Uint64()) + case decimal.NullDecimal: + if !value.Valid { + return nil, nil } - setStr, _ := t.convertBitFieldToString(bitField) - return setStr, nil + return t.Convert(value.Decimal.BigInt().Uint64()) + case string: + return t.convertStringToBitField(value) case []byte: return t.Convert(string(value)) } - return nil, ErrConvertingToSet.New(v) + return uint64(0), ErrConvertingToSet.New(v) } // MustConvert implements the Type interface. @@ -220,12 +229,16 @@ func (t setType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { if v == nil { return sqltypes.NULL, nil } - value, err := t.Convert(v) + convertedValue, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + value, err := t.BitsToString(convertedValue.(uint64)) if err != nil { return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(value.(string))) + val := appendAndSliceString(dest, value) return sqltypes.MakeTrusted(sqltypes.Set, val), nil } @@ -247,68 +260,37 @@ func (t setType) Type() query.Type { return sqltypes.Set } +// ValueType implements Type interface. +func (t setType) ValueType() reflect.Type { + return setValueType +} + // Zero implements Type interface. func (t setType) Zero() interface{} { return "" } +// CharacterSet implements EnumType interface. func (t setType) CharacterSet() CharacterSet { return t.collation.CharacterSet() } +// Collation implements EnumType interface. func (t setType) Collation() Collation { return t.collation } -// Marshal takes a valid Set value and returns it as an uint64. -func (t setType) Marshal(v interface{}) (uint64, error) { - switch value := v.(type) { - case int: - return t.Marshal(uint64(value)) - case uint: - return t.Marshal(uint64(value)) - case int8: - return t.Marshal(uint64(value)) - case uint8: - return t.Marshal(uint64(value)) - case int16: - return t.Marshal(uint64(value)) - case uint16: - return t.Marshal(uint64(value)) - case int32: - return t.Marshal(uint64(value)) - case uint32: - return t.Marshal(uint64(value)) - case int64: - return t.Marshal(uint64(value)) - case uint64: - if value <= t.allValuesBitField() { - return value, nil - } - case float32: - return t.Marshal(uint64(value)) - case float64: - return t.Marshal(uint64(value)) - case string: - return t.convertStringToBitField(value) - case []byte: - return t.Marshal(string(value)) - } - - return uint64(0), ErrConvertingToSet.New(v) -} - -// NumberOfElements returns the number of elements in this set. +// NumberOfElements implements EnumType interface. func (t setType) NumberOfElements() uint16 { return uint16(len(t.valToBit)) } -// Unmarshal takes a previously-marshalled value and returns it as a string. -func (t setType) Unmarshal(v uint64) (string, error) { +// BitsToString implements EnumType interface. +func (t setType) BitsToString(v uint64) (string, error) { return t.convertBitFieldToString(v) } -// Values returns all of the set's values in ascending order according to their corresponding bit value. +// Values implements EnumType interface. func (t setType) Values() []string { bitEdge := 64 - bits.LeadingZeros64(t.allValuesBitField()) valArray := make([]string, bitEdge) diff --git a/sql/settype_test.go b/sql/settype_test.go index a2200ba2de..00d6f0c69c 100644 --- a/sql/settype_test.go +++ b/sql/settype_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "strconv" "testing" "time" @@ -185,7 +186,12 @@ func TestSetConvert(t *testing.T) { assert.Error(t, err) } else { require.NoError(t, err) - assert.Equal(t, test.expectedVal, val) + res, err := typ.Compare(test.expectedVal, val) + require.NoError(t, err) + assert.Equal(t, 0, res) + if val != nil { + assert.Equal(t, typ.ValueType(), reflect.TypeOf(val)) + } } }) } @@ -208,12 +214,14 @@ func TestSetMarshalMax(t *testing.T) { for _, test := range tests { t.Run(fmt.Sprintf("%v", test), func(t *testing.T) { - bits, err := typ.Marshal(test) + bits, err := typ.Convert(test) require.NoError(t, err) - res1, err := typ.Unmarshal(bits) + res1, err := typ.BitsToString(bits.(uint64)) require.NoError(t, err) require.Equal(t, test, res1) - res2, err := typ.Convert(bits) + bits2, err := typ.Convert(bits) + require.NoError(t, err) + res2, err := typ.BitsToString(bits2.(uint64)) require.NoError(t, err) require.Equal(t, test, res2) }) diff --git a/sql/stringtype.go b/sql/stringtype.go index f3b5cff5fd..2a07056530 100644 --- a/sql/stringtype.go +++ b/sql/stringtype.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "strconv" "strings" "time" @@ -53,12 +54,15 @@ var ( Blob = MustCreateBinary(sqltypes.Blob, textBlobMax) MediumBlob = MustCreateBinary(sqltypes.Blob, mediumTextBlobMax) LongBlob = MustCreateBinary(sqltypes.Blob, longTextBlobMax) + + stringValueType = reflect.TypeOf(string("")) ) // StringType represents all string types, including VARCHAR and BLOB. // https://dev.mysql.com/doc/refman/8.0/en/char.html // https://dev.mysql.com/doc/refman/8.0/en/binary-varbinary.html // https://dev.mysql.com/doc/refman/8.0/en/blob.html +// The type of the returned value is string. type StringType interface { Type CharacterSet() CharacterSet @@ -349,7 +353,7 @@ func (t stringType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(v.(string))) + val := appendAndSliceString(dest, v.(string)) return sqltypes.MakeTrusted(t.baseType, val), nil } @@ -407,6 +411,11 @@ func (t stringType) Type() query.Type { return t.baseType } +// ValueType implements Type interface. +func (t stringType) ValueType() reflect.Type { + return stringValueType +} + // Zero implements Type interface. func (t stringType) Zero() interface{} { return "" @@ -442,7 +451,14 @@ func (t stringType) CreateMatcher(likeStr string) (regex.DisposableMatcher, erro } } -func appendAndSlice(buffer, addition []byte) (slice []byte) { +func appendAndSliceString(buffer []byte, addition string) (slice []byte) { + stop := len(buffer) + buffer = append(buffer, addition...) + slice = buffer[stop:] + return +} + +func appendAndSliceBytes(buffer, addition []byte) (slice []byte) { stop := len(buffer) buffer = append(buffer, addition...) slice = buffer[stop:] diff --git a/sql/stringtype_test.go b/sql/stringtype_test.go index 9b1b7a03ca..b4b35214a2 100644 --- a/sql/stringtype_test.go +++ b/sql/stringtype_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "strings" "testing" "time" @@ -327,6 +328,9 @@ func TestStringConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.Equal(t, test.typ.ValueType(), reflect.TypeOf(val)) + } } }) } diff --git a/sql/system_booltype.go b/sql/system_booltype.go index 96e5f872a9..3f6f5a43f4 100644 --- a/sql/system_booltype.go +++ b/sql/system_booltype.go @@ -15,13 +15,18 @@ package sql import ( + "reflect" "strconv" "strings" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemBoolValueType = reflect.TypeOf(int8(0)) + // systemBoolType is an internal boolean type ONLY for system variables. type systemBoolType struct { varName string @@ -95,6 +100,14 @@ func (t systemBoolType) Convert(v interface{}) (interface{}, error) { if value == float64(int64(value)) { return t.Convert(int64(value)) } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } case string: switch strings.ToLower(value) { case "on", "true": @@ -157,6 +170,11 @@ func (t systemBoolType) Type() query.Type { return sqltypes.Int8 } +// ValueType implements Type interface. +func (t systemBoolType) ValueType() reflect.Type { + return systemBoolValueType +} + // Zero implements Type interface. func (t systemBoolType) Zero() interface{} { return int8(0) diff --git a/sql/system_doubletype.go b/sql/system_doubletype.go index d038d64090..5a7849f188 100644 --- a/sql/system_doubletype.go +++ b/sql/system_doubletype.go @@ -15,12 +15,17 @@ package sql import ( + "reflect" "strconv" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemDoubleValueType = reflect.TypeOf(float64(0)) + // systemDoubleType is an internal double type ONLY for system variables. type systemDoubleType struct { varName string @@ -87,6 +92,14 @@ func (t systemDoubleType) Convert(v interface{}) (interface{}, error) { if value >= t.lowerbound && value <= t.upperbound { return value, nil } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } } return nil, ErrInvalidSystemVariableValue.New(t.varName, v) @@ -142,6 +155,11 @@ func (t systemDoubleType) Type() query.Type { return sqltypes.Float64 } +// ValueType implements Type interface. +func (t systemDoubleType) ValueType() reflect.Type { + return systemDoubleValueType +} + // Zero implements Type interface. func (t systemDoubleType) Zero() interface{} { return float64(0) diff --git a/sql/system_enumtype.go b/sql/system_enumtype.go index 90c8c2646f..d219b861c1 100644 --- a/sql/system_enumtype.go +++ b/sql/system_enumtype.go @@ -15,12 +15,17 @@ package sql import ( + "reflect" "strings" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemEnumValueType = reflect.TypeOf(string("")) + // systemEnumType is an internal enum type ONLY for system variables. type systemEnumType struct { varName string @@ -98,6 +103,14 @@ func (t systemEnumType) Convert(v interface{}) (interface{}, error) { if value == float64(int(value)) { return t.Convert(int(value)) } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } case string: if idx, ok := t.valToIndex[strings.ToLower(value)]; ok { return t.indexToVal[idx], nil @@ -145,7 +158,7 @@ func (t systemEnumType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(v.(string))) + val := appendAndSliceString(dest, v.(string)) return sqltypes.MakeTrusted(t.Type(), val), nil } @@ -160,6 +173,11 @@ func (t systemEnumType) Type() query.Type { return sqltypes.VarChar } +// ValueType implements Type interface. +func (t systemEnumType) ValueType() reflect.Type { + return systemEnumValueType +} + // Zero implements Type interface. func (t systemEnumType) Zero() interface{} { return "" diff --git a/sql/system_inttype.go b/sql/system_inttype.go index 7696012f3b..0b2ccb3556 100644 --- a/sql/system_inttype.go +++ b/sql/system_inttype.go @@ -15,12 +15,17 @@ package sql import ( + "reflect" "strconv" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemIntValueType = reflect.TypeOf(int64(0)) + // systemIntType is an internal integer type ONLY for system variables. type systemIntType struct { varName string @@ -95,6 +100,14 @@ func (t systemIntType) Convert(v interface{}) (interface{}, error) { if value == float64(int64(value)) { return t.Convert(int64(value)) } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } } return nil, ErrInvalidSystemVariableValue.New(t.varName, v) @@ -150,6 +163,11 @@ func (t systemIntType) Type() query.Type { return sqltypes.Int64 } +// ValueType implements Type interface. +func (t systemIntType) ValueType() reflect.Type { + return systemIntValueType +} + // Zero implements Type interface. func (t systemIntType) Zero() interface{} { return int64(0) diff --git a/sql/system_settype.go b/sql/system_settype.go index 749dfd10d2..30150f2e1e 100644 --- a/sql/system_settype.go +++ b/sql/system_settype.go @@ -15,8 +15,11 @@ package sql import ( + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/shopspring/decimal" ) // systemSetType is an internal set type ONLY for system variables. @@ -37,19 +40,21 @@ func (t systemSetType) Compare(a interface{}, b interface{}) (int, error) { if a == nil || b == nil { return 0, ErrInvalidSystemVariableValue.New(t.varName, nil) } - ai, err := t.Marshal(a) + ai, err := t.Convert(a) if err != nil { return 0, err } - bi, err := t.Marshal(b) + bi, err := t.Convert(b) if err != nil { return 0, err } + au := ai.(uint64) + bu := bi.(uint64) - if ai == bi { + if au == bu { return 0, nil } - if ai < bi { + if au < bu { return -1, nil } return 1, nil @@ -87,6 +92,14 @@ func (t systemSetType) Convert(v interface{}) (interface{}, error) { if value == float64(int64(value)) { return t.SetType.Convert(int64(value)) } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } case string: return t.SetType.Convert(value) } @@ -121,13 +134,16 @@ func (t systemSetType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { if v == nil { return sqltypes.NULL, nil } - - v, err := t.Convert(v) + convertedValue, err := t.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + value, err := t.BitsToString(convertedValue.(uint64)) if err != nil { return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(v.(string))) + val := appendAndSliceString(dest, value) return sqltypes.MakeTrusted(t.Type(), val), nil } @@ -142,6 +158,11 @@ func (t systemSetType) Type() query.Type { return sqltypes.VarChar } +// ValueType implements Type interface. +func (t systemSetType) ValueType() reflect.Type { + return t.SetType.ValueType() +} + // Zero implements Type interface. func (t systemSetType) Zero() interface{} { return "" @@ -149,11 +170,11 @@ func (t systemSetType) Zero() interface{} { // EncodeValue implements SystemVariableType interface. func (t systemSetType) EncodeValue(val interface{}) (string, error) { - expectedVal, ok := val.(string) + expectedVal, ok := val.(uint64) if !ok { return "", ErrSystemVariableCodeFail.New(val, t.String()) } - return expectedVal, nil + return t.BitsToString(expectedVal) } // DecodeValue implements SystemVariableType interface. diff --git a/sql/system_stringtype.go b/sql/system_stringtype.go index 5337da08ec..ddcdd13e65 100644 --- a/sql/system_stringtype.go +++ b/sql/system_stringtype.go @@ -15,10 +15,14 @@ package sql import ( + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemStringValueType = reflect.TypeOf(string("")) + // systemStringType is an internal string type ONLY for system variables. type systemStringType struct { varName string @@ -98,7 +102,7 @@ func (t systemStringType) SQL(dest []byte, v interface{}) (sqltypes.Value, error return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(v.(string))) + val := appendAndSliceString(dest, v.(string)) return sqltypes.MakeTrusted(t.Type(), val), nil } @@ -113,6 +117,11 @@ func (t systemStringType) Type() query.Type { return sqltypes.VarChar } +// ValueType implements Type interface. +func (t systemStringType) ValueType() reflect.Type { + return systemStringValueType +} + // Zero implements Type interface. func (t systemStringType) Zero() interface{} { return "" diff --git a/sql/system_uinttype.go b/sql/system_uinttype.go index 2ec670fb6f..5ff304db9f 100644 --- a/sql/system_uinttype.go +++ b/sql/system_uinttype.go @@ -15,12 +15,17 @@ package sql import ( + "reflect" "strconv" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var systemUintValueType = reflect.TypeOf(uint64(0)) + // systemUintType is an internal unsigned integer type ONLY for system variables. type systemUintType struct { varName string @@ -91,6 +96,14 @@ func (t systemUintType) Convert(v interface{}) (interface{}, error) { if value == float64(uint64(value)) { return t.Convert(uint64(value)) } + case decimal.Decimal: + f, _ := value.Float64() + return t.Convert(f) + case decimal.NullDecimal: + if value.Valid { + f, _ := value.Decimal.Float64() + return t.Convert(f) + } } return nil, ErrInvalidSystemVariableValue.New(t.varName, v) @@ -146,6 +159,11 @@ func (t systemUintType) Type() query.Type { return sqltypes.Uint64 } +// ValueType implements Type interface. +func (t systemUintType) ValueType() reflect.Type { + return systemUintValueType +} + // Zero implements Type interface. func (t systemUintType) Zero() interface{} { return uint64(0) diff --git a/sql/timetype.go b/sql/timetype.go index eff9a83433..f0bed3adad 100644 --- a/sql/timetype.go +++ b/sql/timetype.go @@ -17,10 +17,13 @@ package sql import ( "fmt" "math" + "reflect" "strconv" "strings" "time" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -37,28 +40,35 @@ var ( microsecondsPerMinute int64 = 60000000 microsecondsPerHour int64 = 3600000000 nanosecondsPerMicrosecond int64 = 1000 + + timeValueType = reflect.TypeOf(Timespan(0)) ) -// Represents the TIME type. +// TimeType represents the TIME type. // https://dev.mysql.com/doc/refman/8.0/en/time.html -// TIME is implemented as TIME(6) +// TIME is implemented as TIME(6). +// The type of the returned value is Timespan. // TODO: implement parameters on the TIME type type TimeType interface { Type + // ConvertToTimespan returns a Timespan from the given interface. Follows the same conversion rules as + // Convert(), in that this will process the value based on its base-10 visual representation (for example, Convert() + // will interpret the value `1234` as 12 minutes and 34 seconds). Returns an error for nil values. + ConvertToTimespan(v interface{}) (Timespan, error) + // ConvertToTimeDuration returns a time.Duration from the given interface. Follows the same conversion rules as + // Convert(), in that this will process the value based on its base-10 visual representation (for example, Convert() + // will interpret the value `1234` as 12 minutes and 34 seconds). Returns an error for nil values. ConvertToTimeDuration(v interface{}) (time.Duration, error) - //TODO: move this out of go-mysql-server and into the Dolt layer - Marshal(v interface{}) (int64, error) - Unmarshal(v int64) string + // MicrosecondsToTimespan returns a Timespan from the given number of microseconds. This differs from Convert(), as + // that will process the value based on its base-10 visual representation (for example, Convert() will interpret + // the value `1234` as 12 minutes and 34 seconds). This clamps the given microseconds to the allowed range. + MicrosecondsToTimespan(v int64) Timespan } type timespanType struct{} -type timespanImpl struct { - negative bool - hours int16 - minutes int8 - seconds int8 - microseconds int32 -} + +// Timespan is the value type returned by TimeType.Convert(). +type Timespan int64 // Compare implements Type interface. func (t timespanType) Compare(a interface{}, b interface{}) (int, error) { @@ -66,24 +76,16 @@ func (t timespanType) Compare(a interface{}, b interface{}) (int, error) { return res, nil } - as, err := t.ConvertToTimespanImpl(a) + as, err := t.ConvertToTimespan(a) if err != nil { return 0, err } - bs, err := t.ConvertToTimespanImpl(b) + bs, err := t.ConvertToTimespan(b) if err != nil { return 0, err } - ai := as.AsMicroseconds() - bi := bs.AsMicroseconds() - - if ai < bi { - return -1, nil - } else if ai > bi { - return 1, nil - } - return 0, nil + return as.Compare(bs), nil } func (t timespanType) Convert(v interface{}) (interface{}, error) { @@ -91,11 +93,7 @@ func (t timespanType) Convert(v interface{}) (interface{}, error) { return nil, nil } - if ti, err := t.ConvertToTimespanImpl(v); err != nil { - return nil, err - } else { - return ti.String(), nil - } + return t.ConvertToTimespan(v) } // MustConvert implements the Type interface. @@ -107,38 +105,44 @@ func (t timespanType) MustConvert(v interface{}) interface{} { return value } -// Convert implements Type interface. -func (t timespanType) ConvertToTimespanImpl(v interface{}) (timespanImpl, error) { +// ConvertToTimespan converts the given interface value to a Timespan. This follows the conversion rules of MySQL, which +// are based on the base-10 visual representation of numbers (for example, Time.Convert() will interpret the value +// `1234` as 12 minutes and 34 seconds). Returns an error on a nil value. +func (t timespanType) ConvertToTimespan(v interface{}) (Timespan, error) { switch value := v.(type) { + case Timespan: + // We only create a Timespan if it's valid, so we can skip this check if we receive a Timespan. + // Timespan values are not intended to be modified by an integrator, therefore it is on the integrator if they corrupt a Timespan. + return value, nil case int: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case uint: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case int8: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case uint8: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case int16: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case uint16: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case int32: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case uint32: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case int64: absValue := int64Abs(value) if absValue >= -59 && absValue <= 59 { - return microsecondsToTimespan(value * microsecondsPerSecond), nil + return t.MicrosecondsToTimespan(value * microsecondsPerSecond), nil } else if absValue >= 100 && absValue <= 9999 { minutes := absValue / 100 seconds := absValue % 100 if minutes <= 59 && seconds <= 59 { microseconds := (seconds * microsecondsPerSecond) + (minutes * microsecondsPerMinute) if value < 0 { - return microsecondsToTimespan(-1 * microseconds), nil + return t.MicrosecondsToTimespan(-1 * microseconds), nil } - return microsecondsToTimespan(microseconds), nil + return t.MicrosecondsToTimespan(microseconds), nil } } else if absValue >= 10000 && absValue <= 9999999 { hours := absValue / 10000 @@ -147,15 +151,15 @@ func (t timespanType) ConvertToTimespanImpl(v interface{}) (timespanImpl, error) if minutes <= 59 && seconds <= 59 { microseconds := (seconds * microsecondsPerSecond) + (minutes * microsecondsPerMinute) + (hours * microsecondsPerHour) if value < 0 { - return microsecondsToTimespan(-1 * microseconds), nil + return t.MicrosecondsToTimespan(-1 * microseconds), nil } - return microsecondsToTimespan(microseconds), nil + return t.MicrosecondsToTimespan(microseconds), nil } } case uint64: - return t.ConvertToTimespanImpl(int64(value)) + return t.ConvertToTimespan(int64(value)) case float32: - return t.ConvertToTimespanImpl(float64(value)) + return t.ConvertToTimespan(float64(value)) case float64: intValue := int64(value) microseconds := int64Abs(int64(math.Round((value - float64(intValue)) * float64(microsecondsPerSecond)))) @@ -163,18 +167,18 @@ func (t timespanType) ConvertToTimespanImpl(v interface{}) (timespanImpl, error) if absValue >= -59 && absValue <= 59 { totalMicroseconds := (absValue * microsecondsPerSecond) + microseconds if value < 0 { - return microsecondsToTimespan(-1 * totalMicroseconds), nil + return t.MicrosecondsToTimespan(-1 * totalMicroseconds), nil } - return microsecondsToTimespan(totalMicroseconds), nil + return t.MicrosecondsToTimespan(totalMicroseconds), nil } else if absValue >= 100 && absValue <= 9999 { minutes := absValue / 100 seconds := absValue % 100 if minutes <= 59 && seconds <= 59 { totalMicroseconds := (seconds * microsecondsPerSecond) + (minutes * microsecondsPerMinute) + microseconds if value < 0 { - return microsecondsToTimespan(-1 * totalMicroseconds), nil + return t.MicrosecondsToTimespan(-1 * totalMicroseconds), nil } - return microsecondsToTimespan(totalMicroseconds), nil + return t.MicrosecondsToTimespan(totalMicroseconds), nil } } else if absValue >= 10000 && absValue <= 9999999 { hours := absValue / 10000 @@ -183,11 +187,17 @@ func (t timespanType) ConvertToTimespanImpl(v interface{}) (timespanImpl, error) if minutes <= 59 && seconds <= 59 { totalMicroseconds := (seconds * microsecondsPerSecond) + (minutes * microsecondsPerMinute) + (hours * microsecondsPerHour) + microseconds if value < 0 { - return microsecondsToTimespan(-1 * totalMicroseconds), nil + return t.MicrosecondsToTimespan(-1 * totalMicroseconds), nil } - return microsecondsToTimespan(totalMicroseconds), nil + return t.MicrosecondsToTimespan(totalMicroseconds), nil } } + case decimal.Decimal: + return t.ConvertToTimespan(value.IntPart()) + case decimal.NullDecimal: + if value.Valid { + return t.ConvertToTimespan(value.Decimal.IntPart()) + } case string: impl, err := stringToTimespan(value) if err == nil { @@ -196,26 +206,27 @@ func (t timespanType) ConvertToTimespanImpl(v interface{}) (timespanImpl, error) if strings.Contains(value, ".") { strAsDouble, err := strconv.ParseFloat(value, 64) if err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(v) + return Timespan(0), ErrConvertingToTimeType.New(v) } - return t.ConvertToTimespanImpl(strAsDouble) + return t.ConvertToTimespan(strAsDouble) } else { strAsInt, err := strconv.ParseInt(value, 10, 64) if err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(v) + return Timespan(0), ErrConvertingToTimeType.New(v) } - return t.ConvertToTimespanImpl(strAsInt) + return t.ConvertToTimespan(strAsInt) } case time.Duration: - microseconds := value.Nanoseconds() / 1000 - return microsecondsToTimespan(microseconds), nil + microseconds := value.Nanoseconds() / nanosecondsPerMicrosecond + return t.MicrosecondsToTimespan(microseconds), nil } - return timespanImpl{}, ErrConvertingToTimeType.New(v) + return Timespan(0), ErrConvertingToTimeType.New(v) } +// ConvertToTimeDuration implements the TimeType interface. func (t timespanType) ConvertToTimeDuration(v interface{}) (time.Duration, error) { - val, err := t.ConvertToTimespanImpl(v) + val, err := t.ConvertToTimespan(v) if err != nil { return time.Duration(0), err } @@ -235,12 +246,15 @@ func (t timespanType) Promote() Type { // SQL implements Type interface. func (t timespanType) SQL(dest []byte, v interface{}) (sqltypes.Value, error) { - ti, err := t.ConvertToTimespanImpl(v) + if v == nil { + return sqltypes.NULL, nil + } + ti, err := t.ConvertToTimespan(v) if err != nil { return sqltypes.Value{}, err } - val := appendAndSlice(dest, []byte(ti.String())) + val := appendAndSliceString(dest, ti.String()) return sqltypes.MakeTrusted(sqltypes.Time, val), nil } @@ -255,23 +269,14 @@ func (t timespanType) Type() query.Type { return sqltypes.Time } -// Zero implements Type interface. -func (t timespanType) Zero() interface{} { - return "00:00:00" -} - -// Marshal takes a valid Time value and returns it as an int64. -func (t timespanType) Marshal(v interface{}) (int64, error) { - if ti, err := t.ConvertToTimespanImpl(v); err != nil { - return 0, err - } else { - return ti.AsMicroseconds(), nil - } +// ValueType implements Type interface. +func (t timespanType) ValueType() reflect.Type { + return timeValueType } -// Unmarshal takes a previously-marshalled value and returns it as a string. -func (t timespanType) Unmarshal(v int64) string { - return microsecondsToTimespan(v).String() +// Zero implements Type interface. +func (t timespanType) Zero() interface{} { + return Timespan(0) } // No built in for absolute values on int64 @@ -280,10 +285,15 @@ func int64Abs(v int64) int64 { return (v ^ shift) - shift } -func stringToTimespan(s string) (timespanImpl, error) { - impl := timespanImpl{} +func stringToTimespan(s string) (Timespan, error) { + var negative bool + var hours int16 + var minutes int8 + var seconds int8 + var microseconds int32 + if len(s) > 0 && s[0] == '-' { - impl.negative = true + negative = true s = s[1:] } @@ -296,15 +306,15 @@ func stringToTimespan(s string) (timespanImpl, error) { microStr += strings.Repeat("0", 6-len(comps[1])) } microStr, remainStr := microStr[0:6], microStr[6:] - microseconds, err := strconv.Atoi(microStr) + convertedMicroseconds, err := strconv.Atoi(microStr) if err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) } // MySQL just uses the last digit to round up. This is weird, but matches their implementation. if len(remainStr) > 0 && remainStr[len(remainStr)-1:] >= "5" { - microseconds++ + convertedMicroseconds++ } - impl.microseconds = int32(microseconds) + microseconds = int32(convertedMicroseconds) } // Parse H-M-S time @@ -312,16 +322,16 @@ func stringToTimespan(s string) (timespanImpl, error) { hms := make([]string, 3) if len(hmsComps) >= 2 { if len(hmsComps[0]) > 3 { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) } hms[0] = hmsComps[0] if len(hmsComps[1]) > 2 { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) } hms[1] = hmsComps[1] if len(hmsComps) == 3 { if len(hmsComps[2]) > 2 { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) } hms[2] = hmsComps[2] } @@ -332,52 +342,52 @@ func stringToTimespan(s string) (timespanImpl, error) { hms[0] = safeSubstr(hmsComps[0], l-7, l-4) } - hours, err := strconv.Atoi(hms[0]) + hmsHours, err := strconv.Atoi(hms[0]) if len(hms[0]) > 0 && err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) } - impl.hours = int16(hours) + hours = int16(hmsHours) - minutes, err := strconv.Atoi(hms[1]) + hmsMinutes, err := strconv.Atoi(hms[1]) if len(hms[1]) > 0 && err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(s) - } else if minutes >= 60 { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) + } else if hmsMinutes >= 60 { + return Timespan(0), ErrConvertingToTimeType.New(s) } - impl.minutes = int8(minutes) + minutes = int8(hmsMinutes) - seconds, err := strconv.Atoi(hms[2]) + hmsSeconds, err := strconv.Atoi(hms[2]) if len(hms[2]) > 0 && err != nil { - return timespanImpl{}, ErrConvertingToTimeType.New(s) - } else if seconds >= 60 { - return timespanImpl{}, ErrConvertingToTimeType.New(s) + return Timespan(0), ErrConvertingToTimeType.New(s) + } else if hmsSeconds >= 60 { + return Timespan(0), ErrConvertingToTimeType.New(s) } - impl.seconds = int8(seconds) + seconds = int8(hmsSeconds) - if impl.microseconds == int32(microsecondsPerSecond) { - impl.microseconds = 0 - impl.seconds++ + if microseconds == int32(microsecondsPerSecond) { + microseconds = 0 + seconds++ } - if impl.seconds == 60 { - impl.seconds = 0 - impl.minutes++ + if seconds == 60 { + seconds = 0 + minutes++ } - if impl.minutes == 60 { - impl.minutes = 0 - impl.hours++ + if minutes == 60 { + minutes = 0 + hours++ } - if impl.hours > 838 { - impl.hours = 838 - impl.minutes = 59 - impl.seconds = 59 + if hours > 838 { + hours = 838 + minutes = 59 + seconds = 59 } - if impl.hours == 838 && impl.minutes == 59 && impl.seconds == 59 { - impl.microseconds = 0 + if hours == 838 && minutes == 59 && seconds == 59 { + microseconds = 0 } - return impl, nil + return unitsToTimespan(negative, hours, minutes, seconds, microseconds), nil } func safeSubstr(s string, start int, end int) string { @@ -396,46 +406,103 @@ func safeSubstr(s string, start int, end int) string { return s[start:end] } -func microsecondsToTimespan(v int64) timespanImpl { +// MicrosecondsToTimespan implements the TimeType interface. +func (_ timespanType) MicrosecondsToTimespan(v int64) Timespan { if v < timespanMinimum { v = timespanMinimum } else if v > timespanMaximum { v = timespanMaximum } + return Timespan(v) +} - absV := int64Abs(v) - - return timespanImpl{ - negative: v < 0, - hours: int16(absV / microsecondsPerHour), - minutes: int8((absV / microsecondsPerMinute) % 60), - seconds: int8((absV / microsecondsPerSecond) % 60), - microseconds: int32(absV % microsecondsPerSecond), +func unitsToTimespan(isNegative bool, hours int16, minutes int8, seconds int8, microseconds int32) Timespan { + negative := int64(1) + if isNegative { + negative = -1 } + return Timespan(negative * + (int64(microseconds) + + (int64(seconds) * microsecondsPerSecond) + + (int64(minutes) * microsecondsPerMinute) + + (int64(hours) * microsecondsPerHour))) } -func (t timespanImpl) String() string { +func (t Timespan) timespanToUnits() (isNegative bool, hours int16, minutes int8, seconds int8, microseconds int32) { + isNegative = t < 0 + absV := int64Abs(int64(t)) + hours = int16(absV / microsecondsPerHour) + minutes = int8((absV / microsecondsPerMinute) % 60) + seconds = int8((absV / microsecondsPerSecond) % 60) + microseconds = int32(absV % microsecondsPerSecond) + return +} + +// String returns the Timespan formatted as a string (such as for display purposes). +func (t Timespan) String() string { + isNegative, hours, minutes, seconds, microseconds := t.timespanToUnits() sign := "" - if t.negative { + if isNegative { sign = "-" } - if t.microseconds == 0 { - return fmt.Sprintf("%v%02d:%02d:%02d", sign, t.hours, t.minutes, t.seconds) + if microseconds == 0 { + return fmt.Sprintf("%v%02d:%02d:%02d", sign, hours, minutes, seconds) } - return fmt.Sprintf("%v%02d:%02d:%02d.%06d", sign, t.hours, t.minutes, t.seconds, t.microseconds) + return fmt.Sprintf("%v%02d:%02d:%02d.%06d", sign, hours, minutes, seconds, microseconds) } -func (t timespanImpl) AsMicroseconds() int64 { - negative := int64(1) - if t.negative { - negative = -1 - } - return negative * (int64(t.microseconds) + - (int64(t.seconds) * microsecondsPerSecond) + - (int64(t.minutes) * microsecondsPerMinute) + - (int64(t.hours) * microsecondsPerHour)) +// AsMicroseconds returns the Timespan in microseconds. +func (t Timespan) AsMicroseconds() int64 { + // Timespan already being implemented in microseconds is an implementation detail that integrators do not need to + // know about. This is also the reason for the comparison functions. + return int64(t) } -func (t timespanImpl) AsTimeDuration() time.Duration { +// AsTimeDuration returns the Timespan as a time.Duration. +func (t Timespan) AsTimeDuration() time.Duration { return time.Duration(t.AsMicroseconds() * nanosecondsPerMicrosecond) } + +// Equals returns whether the calling Timespan and given Timespan are equivalent. +func (t Timespan) Equals(other Timespan) bool { + return t == other +} + +// Compare returns an integer comparing two values. The result will be 0 if t==other, -1 if t < other, and +1 if t > other. +func (t Timespan) Compare(other Timespan) int { + if t < other { + return -1 + } else if t > other { + return 1 + } + return 0 +} + +// Negate returns a new Timespan that has been negated. +func (t Timespan) Negate() Timespan { + return -1 * t +} + +// Add returns a new Timespan that is the sum of the calling Timespan and given Timespan. The resulting Timespan is +// clamped to the allowed range. +func (t Timespan) Add(other Timespan) Timespan { + v := int64(t + other) + if v < timespanMinimum { + v = timespanMinimum + } else if v > timespanMaximum { + v = timespanMaximum + } + return Timespan(v) +} + +// Subtract returns a new Timespan that is the difference of the calling Timespan and given Timespan. The resulting +// Timespan is clamped to the allowed range. +func (t Timespan) Subtract(other Timespan) Timespan { + v := int64(t - other) + if v < timespanMinimum { + v = timespanMinimum + } else if v > timespanMaximum { + v = timespanMaximum + } + return Timespan(v) +} diff --git a/sql/timetype_test.go b/sql/timetype_test.go index 577d72c547..68fce4bb41 100644 --- a/sql/timetype_test.go +++ b/sql/timetype_test.go @@ -166,12 +166,16 @@ func TestTimeConvert(t *testing.T) { assert.Error(t, err) } else { require.NoError(t, err) - assert.Equal(t, test.expectedVal, val) - if test.val != nil { - mar, err := Time.Marshal(test.val) + if test.val == nil { + assert.Equal(t, test.expectedVal, val) + } else { + assert.Equal(t, test.expectedVal, val.(Timespan).String()) + timespan, err := Time.ConvertToTimespan(test.val) require.NoError(t, err) - umar := Time.Unmarshal(mar) - cmp, err := Time.Compare(test.val, umar) + require.True(t, timespan.Equals(val.(Timespan))) + ms := timespan.AsMicroseconds() + ums := Time.MicrosecondsToTimespan(ms) + cmp, err := Time.Compare(test.val, ums) require.NoError(t, err) assert.Equal(t, 0, cmp) } diff --git a/sql/tupletype.go b/sql/tupletype.go index 095a759f2b..aa31a67b1b 100644 --- a/sql/tupletype.go +++ b/sql/tupletype.go @@ -16,14 +16,19 @@ package sql import ( "fmt" + "reflect" "strings" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) +var tupleValueType = reflect.TypeOf((*[]interface{})(nil)).Elem() + type TupleType []Type +var _ Type = TupleType{nil} + // CreateTuple returns a new tuple type with the given element types. func CreateTuple(types ...Type) Type { return TupleType(types) @@ -124,6 +129,11 @@ func (t TupleType) Type() query.Type { return sqltypes.Expression } +// ValueType implements Type interface. +func (t TupleType) ValueType() reflect.Type { + return tupleValueType +} + func (t TupleType) Zero() interface{} { zeroes := make([]interface{}, len(t)) for i, tt := range t { diff --git a/sql/type.go b/sql/type.go index aeb6c9a05b..3134309ac2 100644 --- a/sql/type.go +++ b/sql/type.go @@ -17,6 +17,7 @@ package sql import ( "fmt" "io" + "reflect" "strconv" "strings" "time" @@ -66,6 +67,8 @@ type Type interface { SQL(dest []byte, v interface{}) (sqltypes.Value, error) // Type returns the query.Type for the given Type. Type() query.Type + // ValueType returns the Go type of the value returned by Convert(). + ValueType() reflect.Type // Zero returns the golang zero value for this type Zero() interface{} fmt.Stringer @@ -140,7 +143,7 @@ func ApproximateTypeFromValue(val interface{}) Type { return Uint16 case uint8: return Uint8 - case time.Duration: + case Timespan, time.Duration: return Time case time.Time: return Datetime diff --git a/sql/yeartype.go b/sql/yeartype.go index 3b1623e864..13fd3177d7 100644 --- a/sql/yeartype.go +++ b/sql/yeartype.go @@ -15,9 +15,12 @@ package sql import ( + "reflect" "strconv" "time" + "github.com/shopspring/decimal" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -27,10 +30,13 @@ var ( Year YearType = yearType{} ErrConvertingToYear = errors.NewKind("value %v is not a valid Year") + + yearValueType = reflect.TypeOf(int16(0)) ) -// Represents the YEAR type. +// YearType represents the YEAR type. // https://dev.mysql.com/doc/refman/8.0/en/year.html +// The type of the returned value is int16. type YearType interface { Type } @@ -105,6 +111,13 @@ func (t yearType) Convert(v interface{}) (interface{}, error) { return t.Convert(int64(value)) case float64: return t.Convert(int64(value)) + case decimal.Decimal: + return t.Convert(value.IntPart()) + case decimal.NullDecimal: + if !value.Valid { + return nil, nil + } + return t.Convert(value.Decimal.IntPart()) case string: valueLength := len(value) if valueLength == 1 || valueLength == 2 || valueLength == 4 { @@ -175,6 +188,11 @@ func (t yearType) Type() query.Type { return sqltypes.Year } +// ValueType implements Type interface. +func (t yearType) ValueType() reflect.Type { + return yearValueType +} + // Zero implements Type interface. func (t yearType) Zero() interface{} { return int16(0) diff --git a/sql/yeartype_test.go b/sql/yeartype_test.go index 6fe6372b30..339878a874 100644 --- a/sql/yeartype_test.go +++ b/sql/yeartype_test.go @@ -16,6 +16,7 @@ package sql import ( "fmt" + "reflect" "testing" "time" @@ -95,6 +96,9 @@ func TestYearConvert(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, test.expectedVal, val) + if val != nil { + assert.Equal(t, Year.ValueType(), reflect.TypeOf(val)) + } } }) }