diff --git a/cmd/sqlcmd/main.go b/cmd/sqlcmd/main.go index bb09e023..da9c30c5 100644 --- a/cmd/sqlcmd/main.go +++ b/cmd/sqlcmd/main.go @@ -9,7 +9,7 @@ import ( "github.com/alecthomas/kong" "github.com/denisenkom/go-mssqldb/azuread" - "github.com/gohxs/readline" + "github.com/microsoft/go-sqlcmd/pkg/console" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" ) @@ -194,18 +194,12 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { } iactive := args.InputFile == nil && args.Query == "" - var console sqlcmd.Console = nil - var line *readline.Instance + var line sqlcmd.Console = nil if iactive { - line, err = readline.New(">") - if err != nil { - return 1, err - } - console = line - defer line.Close() + line = console.NewConsole("") } - s := sqlcmd.New(console, wd, vars) + s := sqlcmd.New(line, wd, vars) setConnect(&s.Connect, args, vars) if args.BatchTerminator != "GO" { err = s.Cmd.SetBatchTerminator(args.BatchTerminator) diff --git a/go.mod b/go.mod index d26d8dd1..26ef0462 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,10 @@ go 1.16 require ( github.com/alecthomas/kong v0.2.18-0.20210621093454-54558f65e86f - github.com/chzyer/logex v1.1.10 // indirect - github.com/chzyer/test v0.0.0-20210722231415-061457976a23 // indirect github.com/denisenkom/go-mssqldb v0.12.0 - github.com/gohxs/readline v0.0.0-20171011095936-a780388e6e7c github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 github.com/google/uuid v1.2.0 + github.com/peterh/liner v1.2.2 github.com/stretchr/testify v1.7.0 ) diff --git a/go.sum b/go.sum index 38288101..42b2d91f 100644 --- a/go.sum +++ b/go.sum @@ -6,23 +6,21 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0 h1:v9p9TfTbf7AwNb5NYQt7hI4 github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8= github.com/alecthomas/kong v0.2.18-0.20210621093454-54558f65e86f h1:VgRM6/wqZIB1D9W3XMllm/wplTmPgI5yvCHUXEsmKps= github.com/alecthomas/kong v0.2.18-0.20210621093454-54558f65e86f/go.mod h1:ka3VZ8GZNPXv9Ov+j4YNLkI8mTuhXyr/0ktSlqIydQQ= -github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/test v0.0.0-20210722231415-061457976a23 h1:dZ0/VyGgQdVGAss6Ju0dt5P0QltE0SFY5Woh6hbIfiQ= -github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= -github.com/gohxs/readline v0.0.0-20171011095936-a780388e6e7c h1:yE35fKFwcelIte3q5q1/cPiY7pI7vvf5/j/0ddxNCKs= -github.com/gohxs/readline v0.0.0-20171011095936-a780388e6e7c/go.mod h1:9S/fKAutQ6wVHqm1jnp9D9sc5hu689s9AaTWFS92LaU= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 h1:+eHOFJl1BaXrQxKX+T06f78590z4qA2ZzBTqahsKSE4= github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188/go.mod h1:vXjM/+wXQnTPR4KqTKDgJukSZ6amVRtWMPEjE6sQoK8= github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mattn/go-runewidth v0.0.3 h1:a+kO+98RDGEfo6asOGMmpodZq4FNtnGP54yps8BzLR4= +github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= +github.com/peterh/liner v1.2.2 h1:aJ4AOodmL+JxOZZEL2u9iJf8omNRpqHc/EbrK+3mAXw= +github.com/peterh/liner v1.2.2/go.mod h1:xFwJyiKIXJZUKItq5dGHZSTBRAuG/CpeNpWLyiNRNwI= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 h1:49lOXmGaUpV9Fz3gd7TFZY106KVlPVa5jcYD1gaQf98= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -44,8 +42,9 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1 h1:kwrAHlwJ0DUBZwQ238v+Uod/3eZ8B2K5rYsUHBQvzmI= +golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= diff --git a/pkg/console/complete.go b/pkg/console/complete.go new file mode 100644 index 00000000..d55ded83 --- /dev/null +++ b/pkg/console/complete.go @@ -0,0 +1,245 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package console + +import "strings" + +// CompleteLine returns a set of candidate TSQL keywords to complete the current input line +func CompleteLine(line string) []string { + idx := strings.LastIndexAny(line, " ;") + 1 + // we don't try to complete without a starting letter + if idx == len(line) { + return []string{} + } + prefix := strings.ToUpper(string(line[idx:])) + left := 0 + right := len(keywords) - 1 + for left <= right { + mid := (left + right) / 2 + comp := 0 + if len(keywords[mid]) >= len(prefix) { + comp = strings.Compare(prefix, string(keywords[mid][:len(prefix)])) + } else { + comp = strings.Compare(prefix, keywords[mid]) + } + if comp < 0 { + right = mid - 1 + } else if comp > 0 { + left = mid + 1 + } else { + // look up and down the list from mid and return the slice of matching words + first := mid - 1 + last := mid + 1 + for first >= 0 && strings.HasPrefix(keywords[first], prefix) { + first-- + } + for last < len(keywords) && strings.HasPrefix(keywords[last], prefix) { + last++ + } + lines := make([]string, last-first-1) + for i, w := range keywords[first+1 : last] { + lines[i] = mergeLine(line, w, idx) + } + return lines + } + } + return []string{} +} + +// mergeline appends keyword to line starting at index idx +// It matches the case of the current character in the line +func mergeLine(line string, keyword string, idx int) string { + upcase := line[idx] >= 'A' && line[idx] <= 'Z' + b := strings.Builder{} + b.Write([]byte(line[:idx])) + if !upcase { + b.WriteString(strings.ToLower(keyword)) + } else { + b.WriteString(keyword) + } + return b.String() +} + +var keywords = []string{ + "ADD", + "ALL", + "ALTER", + "AND", + "ANY", + "AS", + "ASC", + "AUTHORIZATION", + "BACKUP", + "BEGIN", + "BETWEEN", + "BREAK", + "BROWSE", + "BULK", + "BY", + "CASCADE", + "CASE", + "CHECK", + "CHECKPOINT", + "CLOSE", + "CLUSTERED", + "COALESCE", + "COLLATE", + "COLUMN", + "COMMIT", + "COMPUTE", + "CONSTRAINT", + "CONTAINS", + "CONTAINSTABLE", + "CONTINUE", + "CONVERT", + "CREATE", + "CROSS", + "CURRENT", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "CURSOR", + "DATABASE", + "DBCC", + "DEALLOCATE", + "DECLARE", + "DEFAULT", + "DELETE", + "DENY", + "DESC", + "DISTINCT", + "DISTRIBUTED", + "DOUBLE", + "DROP", + "ELSE", + "END", + "ERRLVL", + "ESCAPE", + "EXCEPT", + "EXEC", + "EXECUTE", + "EXISTS", + "EXIT", + "EXTERNAL", + "FETCH", + "FILE", + "FILLFACTOR", + "FOR", + "FOREIGN", + "FREETEXT", + "FREETEXTTABLE", + "FROM", + "FULL", + "FUNCTION", + "GOTO", + "GRANT", + "GROUP", + "HAVING", + "HOLDLOCK", + "IDENTITY", + "IDENTITY_INSERT", + "IDENTITYCOL", + "IF", + "IN", + "INDEX", + "INNER", + "INSERT", + "INTERSECT", + "INTO", + "IS", + "JOIN", + "KEY", + "KILL", + "LEFT", + "LIKE", + "LINENO", + "MERGE", + "NATIONAL", + "NOCHECK", + "NONCLUSTERED", + "NOT", + "NULL", + "NULLIF", + "OF", + "OFF", + "OFFSETS", + "ON", + "OPEN", + "OPENDATASOURCE", + "OPENQUERY", + "OPENROWSET", + "OPENXML", + "OPTION", + "OR", + "ORDER", + "OUTER", + "OVER", + "PERCENT", + "PIVOT", + "PLAN", + "PRIMARY", + "PRINT", + "PROC", + "PROCEDURE", + "PUBLIC", + "RAISERROR", + "READ", + "READTEXT", + "RECONFIGURE", + "REFERENCES", + "REPLICATION", + "RESTORE", + "RESTRICT", + "RETURN", + "REVERT", + "REVOKE", + "RIGHT", + "ROLLBACK", + "ROWCOUNT", + "ROWGUIDCOL", + "RULE", + "SAVE", + "SCHEMA", + "SELECT", + "SEMANTICKEYPHRASETABLE", + "SEMANTICSIMILARITYDETAILSTABLE", + "SEMANTICSIMILARITYTABLE", + "SESSION_USER", + "SET", + "SETUSER", + "SHUTDOWN", + "SOME", + "STATISTICS", + "SYSTEM_USER", + "TABLE", + "TABLESAMPLE", + "TEXTSIZE", + "THEN", + "TO", + "TOP", + "TRAN", + "TRANSACTION", + "TRIGGER", + "TRUNCATE", + "TRY_CONVERT", + "TSEQUAL", + "UNION", + "UNIQUE", + "UNPIVOT", + "UPDATE", + "UPDATETEXT", + "USE", + "USER", + "VALUES", + "VARYING", + "VIEW", + "WAITFOR", + "WHEN", + "WHERE", + "WHERECURRENT", + "WHILE", + "WITH", + "WRITETEXT", +} diff --git a/pkg/console/complete_test.go b/pkg/console/complete_test.go new file mode 100644 index 00000000..d71db4ff --- /dev/null +++ b/pkg/console/complete_test.go @@ -0,0 +1,165 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package console + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +type completer func(line string) []string + +func TestKeywordComplete(t *testing.T) { + testKeywordComplete(t, CompleteLine) +} + +func BenchmarkAutoComplete(b *testing.B) { + for i := 0; i < b.N; i++ { + testKeywordComplete(b, CompleteLine) + } +} + +func BenchmarkPrefixTreeAutoComplete(b *testing.B) { + for i := 0; i < b.N; i++ { + testKeywordComplete(b, prefixTreeCompleteLine) + } +} + +// This method is generic so we could plug in different implementations of completer for comparison +func testKeywordComplete(t testing.TB, c completer) { + t.Helper() + keywords := c("BR") + assert.ElementsMatch(t, []string{"BREAK", "BROWSE"}, keywords, "CompleteLine(ER)") + keywords = c("SELECT name fr") + assert.ElementsMatch(t, []string{"SELECT name freetext", "SELECT name freetexttable", "SELECT name from"}, keywords, "CompleteLine(SELECT name fr)") + keywords = c("my word is SEMANTIC") + assert.ElementsMatch(t, []string{"my word is SEMANTICKEYPHRASETABLE", "my word is SEMANTICSIMILARITYDETAILSTABLE", "my word is SEMANTICSIMILARITYTABLE"}, keywords, "CompleteLine(SEMANTIC)") + keywords = c("BREAD") + assert.Empty(t, keywords, "CompleteLine(BREAD)") + keywords = c("Z") + assert.Empty(t, keywords, "CompleteLine(BREAD)") +} + +// This code provides an alternative implementation of console.CompleteLine +// It's here for benchmarking purposes, so if one outperforms the other we can easily swap. +// Binary search outperforms it, likely due to the relatively small search space. +// With 179 keywords: +/*goos: windows +goarch: amd64 +pkg: github.com/microsoft/go-sqlcmd/pkg/console +cpu: AMD Ryzen 9 5950X 16-Core Processor +BenchmarkAutoComplete +BenchmarkAutoComplete-32 491948 2481 ns/op 586 B/op 31 allocs/op +BenchmarkPrefixTreeAutoComplete +BenchmarkPrefixTreeAutoComplete-32 383217 3124 ns/op 1450 B/op 37 allocs/op +PASS +ok github.com/microsoft/go-sqlcmd/pkg/console 2.746s +*/ +func prefixTreeCompleteLine(line string) []string { + idx := strings.LastIndexAny(line, " ;") + 1 + // we don't try to complete without a starting letter + if idx == len(line) { + return []string{} + } + prefix := strings.ToUpper(string(line[idx:])) + words := keywordList.GetKeywords(prefix) + lines := make([]string, len(words)) + for i, w := range words { + lines[i] = mergeLine(line, w, idx) + } + return lines +} + +type prefixTree struct { + children [27]*prefixTree + isLeaf bool + maxlen int + word string + validChildren []int +} + +func newPrefixTree() *prefixTree { + var tree = &prefixTree{ + isLeaf: false, + validChildren: make([]int, 0, 10), + } + return tree +} + +func runeIndex(ch rune) rune { + if ch == '_' { + return 26 + } else { + return ch - 'A' + } +} + +// Insert relies on the incoming list being sorted +func (p *prefixTree) Insert(word string) { + cur := p + for _, ch := range word { + idx := runeIndex(ch) + if cur.children[idx] == nil { + cur.children[idx] = newPrefixTree() + cur.validChildren = append(cur.validChildren, int(idx)) + } + cur = cur.children[idx] + } + cur.isLeaf = true + cur.word = word + if len(word) > p.maxlen { + p.maxlen = len(word) + } +} + +func (p *prefixTree) GetKeywords(prefix string) []string { + cur := p + for _, ch := range prefix { + idx := runeIndex(ch) + if idx < 0 || idx > 26 { + return []string{} + } + if cur.children[idx] == nil { + return []string{} + } else { + cur = cur.children[idx] + } + } + length := len(prefix) + words := make([]string, 0, 10) + word := make([]rune, length, p.maxlen) + copy(word, []rune(prefix)) + cur.appendLevel(&word, length, &words) + return words +} + +func (p *prefixTree) appendLevel(word *[]rune, length int, words *[]string) { + if p.isLeaf { + *words = append(*words, p.word) + } + for _, i := range p.validChildren { + ch := '_' + if i < 26 { + ch = 'A' + rune(i) + } + if len(*word) == length { + *word = append(*word, ch) + } else { + (*word)[length] = ch + } + p.children[i].appendLevel(word, length+1, words) + } +} + +var keywordList *prefixTree + +// use capital letters for all keywords +func init() { + keywordList = newPrefixTree() + for i := range keywords { + keywordList.Insert(keywords[i]) + } +} diff --git a/pkg/console/console.go b/pkg/console/console.go new file mode 100644 index 00000000..4e0bb001 --- /dev/null +++ b/pkg/console/console.go @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package console + +import ( + "os" + + "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" + "github.com/peterh/liner" +) + +type console struct { + impl *liner.State + historyFile string + prompt string +} + +// NewConsole creates a sqlcmdConsole implementation that provides these features: +// - Storage of input history to a local file. History can be scrolled through using the up and down arrow keys. +// - Simple tab key completion of SQL keywords +func NewConsole(historyFile string) sqlcmd.Console { + c := &console{ + impl: liner.NewLiner(), + historyFile: historyFile, + } + c.impl.SetCtrlCAborts(true) + c.impl.SetCompleter(CompleteLine) + if c.historyFile != "" { + if f, err := os.Open(historyFile); err == nil { + _, _ = c.impl.ReadHistory(f) + f.Close() + } + } + return c +} + +// Close writes out the history data to disk and closes the console buffers +func (c *console) Close() { + if c.historyFile != "" { + if f, err := os.Create(c.historyFile); err == nil { + _, _ = c.impl.WriteHistory(f) + f.Close() + } + } + c.impl.Close() +} + +// Readline displays the current prompt and returns a line of text entered by the user. +// It appends the returned line to the history buffer. +// If the user presses Ctrl-C the error returned is sqlcmd.ErrCtrlC +func (c *console) Readline() (string, error) { + s, err := c.impl.Prompt(c.prompt) + if err == liner.ErrPromptAborted { + return "", sqlcmd.ErrCtrlC + } + c.impl.AppendHistory(s) + return s, err +} + +// ReadPassword displays the given prompt and returns the password entered by the user. +// If the user presses Ctrl-C the error returned is sqlcmd.ErrCtrlC +func (c *console) ReadPassword(prompt string) ([]byte, error) { + b, err := c.impl.PasswordPrompt(prompt) + if err == liner.ErrPromptAborted { + return []byte{}, sqlcmd.ErrCtrlC + } + return []byte(b), err +} + +// SetPrompt sets the prompt text shown to input the next line +func (c *console) SetPrompt(s string) { + c.prompt = s +} diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index efca05d3..6a2fa91f 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -117,9 +117,6 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error { if !execute { return nil } - } else if err.Error() == "Interrupt" { - // Ignore any error printing the ctrl-c notice since we are exiting - _, _ = s.GetOutput().Write([]byte(ErrCtrlC.Error() + SqlcmdEol)) } else { _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) }