Skip to content

Commit 0fba127

Browse files
authored
Implement connect command and password prompt (#51)
* impl password prompt and connect command * move connectsettings to new file * temporarily use shueybubbles go-mssqldb
1 parent 9d446a0 commit 0fba127

12 files changed

+451
-266
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ We will be implementing command line switches and behaviors over time. Several s
2626
- Some behaviors that were kept to maintain compatibility with `OSQL` may be changed, such as alignment of column headers for some data types.
2727
- All commands must fit on one line, even `EXIT`. Interactive mode will not check for open parentheses or quotes for commands and prompt for successive lines. The ODBC sqlcmd allows the query run by `EXIT(query)` to span multiple lines.
2828

29+
### Miscellaneous enhancements
30+
31+
- `:Connect` now has an optional `-G` parameter to select one of the authentication methods for Azure SQL Database - `SqlAuthentication`, `ActiveDirectoryDefault`, `ActiveDirectoryIntegrated`, `ActiveDirectoryServicePrincipal`, `ActiveDirectoryManagedIdentity`, `ActiveDirectoryPassword`. If `-G` is not provided, either Integrated security or SQL Authentication will be used, dependent on the presence of a `-U` user name parameter.
32+
- The new `--driver-logging-level` command line parameter allows you to see traces from the `go-mssqldb` client driver. Use `64` to see all traces.
33+
2934
### Azure Active Directory Authentication
3035

3136
This version of sqlcmd supports a broader range of AAD authentication models, based on the [azidentity package](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity). The implementation relies on an AAD Connector in the [driver](https://github.com/denisenkom/go-mssqldb).

cmd/sqlcmd/main.go

+37-22
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func main() {
9999
setVars(vars, &args)
100100

101101
// so far sqlcmd prints all the errors itself so ignore it
102-
exitCode, _ := run(vars)
102+
exitCode, _ := run(vars, &args)
103103
os.Exit(exitCode)
104104
}
105105

@@ -156,43 +156,57 @@ func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) {
156156

157157
}
158158

159-
func setConnect(s *sqlcmd.Sqlcmd, args *SQLCmdArguments) {
159+
func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sqlcmd.Variables) {
160160
if !args.DisableCmdAndWarn {
161-
s.Connect.Password = os.Getenv(sqlcmd.SQLCMDPASSWORD)
162-
}
163-
s.Connect.UseTrustedConnection = args.UseTrustedConnection
164-
s.Connect.TrustServerCertificate = args.TrustServerCertificate
165-
s.Connect.AuthenticationMethod = args.authenticationMethod(s.Connect.Password != "")
166-
s.Connect.DisableEnvironmentVariables = args.DisableCmdAndWarn
167-
s.Connect.DisableVariableSubstitution = args.DisableVariableSubstitution
168-
s.Connect.ApplicationIntent = args.ApplicationIntent
169-
s.Connect.LoginTimeoutSeconds = args.LoginTimeout
170-
s.Connect.Encrypt = args.EncryptConnection
171-
s.Connect.PacketSize = args.PacketSize
172-
s.Connect.WorkstationName = args.WorkstationName
173-
s.Connect.LogLevel = args.DriverLoggingLevel
174-
s.Connect.ExitOnError = args.ExitOnError
175-
s.Connect.ErrorSeverityLevel = args.ErrorSeverityLevel
161+
connect.Password = os.Getenv(sqlcmd.SQLCMDPASSWORD)
162+
}
163+
connect.ServerName = args.Server
164+
if connect.ServerName == "" {
165+
connect.ServerName, _ = vars.Get(sqlcmd.SQLCMDSERVER)
166+
}
167+
connect.Database = args.DatabaseName
168+
if connect.Database == "" {
169+
connect.Database, _ = vars.Get(sqlcmd.SQLCMDDBNAME)
170+
}
171+
connect.UserName = args.UserName
172+
if connect.UserName == "" {
173+
connect.UserName, _ = vars.Get(sqlcmd.SQLCMDUSER)
174+
}
175+
connect.UseTrustedConnection = args.UseTrustedConnection
176+
connect.TrustServerCertificate = args.TrustServerCertificate
177+
connect.AuthenticationMethod = args.authenticationMethod(connect.Password != "")
178+
connect.DisableEnvironmentVariables = args.DisableCmdAndWarn
179+
connect.DisableVariableSubstitution = args.DisableVariableSubstitution
180+
connect.ApplicationIntent = args.ApplicationIntent
181+
connect.LoginTimeoutSeconds = args.LoginTimeout
182+
connect.Encrypt = args.EncryptConnection
183+
connect.PacketSize = args.PacketSize
184+
connect.WorkstationName = args.WorkstationName
185+
connect.LogLevel = args.DriverLoggingLevel
186+
connect.ExitOnError = args.ExitOnError
187+
connect.ErrorSeverityLevel = args.ErrorSeverityLevel
176188
}
177189

178-
func run(vars *sqlcmd.Variables) (int, error) {
190+
func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
179191
wd, err := os.Getwd()
180192
if err != nil {
181193
return 1, err
182194
}
183195

184196
iactive := args.InputFile == nil && args.Query == ""
197+
var console sqlcmd.Console = nil
185198
var line *readline.Instance
186199
if iactive {
187200
line, err = readline.New(">")
188201
if err != nil {
189202
return 1, err
190203
}
204+
console = line
191205
defer line.Close()
192206
}
193207

194-
s := sqlcmd.New(line, wd, vars)
195-
208+
s := sqlcmd.New(console, wd, vars)
209+
setConnect(&s.Connect, args, vars)
196210
if args.BatchTerminator != "GO" {
197211
err = s.Cmd.SetBatchTerminator(args.BatchTerminator)
198212
if err != nil {
@@ -203,7 +217,7 @@ func run(vars *sqlcmd.Variables) (int, error) {
203217
return 1, err
204218
}
205219

206-
setConnect(s, &args)
220+
setConnect(&s.Connect, args, vars)
207221
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false)
208222
if args.OutputFile != "" {
209223
err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile})
@@ -218,7 +232,8 @@ func run(vars *sqlcmd.Variables) (int, error) {
218232
once = true
219233
s.Query = args.Query
220234
}
221-
err = s.ConnectDb("", "", "", !iactive)
235+
// connect using no overrides
236+
err = s.ConnectDb(nil, !iactive)
222237
if err != nil {
223238
return 1, err
224239
}

cmd/sqlcmd/main_test.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func TestRunInputFiles(t *testing.T) {
122122
vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0")
123123
setVars(vars, &args)
124124

125-
exitCode, err := run(vars)
125+
exitCode, err := run(vars, &args)
126126
assert.NoError(t, err, "run")
127127
assert.Equal(t, 0, exitCode, "exitCode")
128128
bytes, err := os.ReadFile(o.Name())
@@ -148,7 +148,7 @@ func TestQueryAndExit(t *testing.T) {
148148
vars.Set("VAR1", "100")
149149
setVars(vars, &args)
150150

151-
exitCode, err := run(vars)
151+
exitCode, err := run(vars, &args)
152152
assert.NoError(t, err, "run")
153153
assert.Equal(t, 0, exitCode, "exitCode")
154154
bytes, err := os.ReadFile(o.Name())
@@ -173,8 +173,7 @@ func TestAzureAuth(t *testing.T) {
173173
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
174174
vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0")
175175
setVars(vars, &args)
176-
177-
exitCode, err := run(vars)
176+
exitCode, err := run(vars, &args)
178177
assert.NoError(t, err, "run")
179178
assert.Equal(t, 0, exitCode, "exitCode")
180179
bytes, err := os.ReadFile(o.Name())

go.mod

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ module github.com/microsoft/go-sqlcmd
33
go 1.16
44

55
require (
6-
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0
7-
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0
86
github.com/alecthomas/kong v0.2.18-0.20210621093454-54558f65e86f
97
github.com/chzyer/logex v1.1.10 // indirect
108
github.com/chzyer/test v0.0.0-20210722231415-061457976a23 // indirect
@@ -15,3 +13,4 @@ require (
1513
github.com/stretchr/testify v1.7.0
1614
)
1715

16+
replace github.com/denisenkom/go-mssqldb => github.com/shueybubbles/go-mssqldb v0.10.1-0.20220303143659-8896461e4ec7

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMn
1313
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
1414
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
1515
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
16-
github.com/denisenkom/go-mssqldb v0.12.0 h1:VtrkII767ttSPNRfFekePK3sctr+joXgO58stqQbtUA=
17-
github.com/denisenkom/go-mssqldb v0.12.0/go.mod h1:iiK0YP1ZeepvmBQk/QpLEhhTNJgfzrpArPY/aFvc9yU=
1816
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
1917
github.com/gohxs/readline v0.0.0-20171011095936-a780388e6e7c h1:yE35fKFwcelIte3q5q1/cPiY7pI7vvf5/j/0ddxNCKs=
2018
github.com/gohxs/readline v0.0.0-20171011095936-a780388e6e7c/go.mod h1:9S/fKAutQ6wVHqm1jnp9D9sc5hu689s9AaTWFS92LaU=
@@ -31,6 +29,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
3129
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
3230
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
3331
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
32+
github.com/shueybubbles/go-mssqldb v0.10.1-0.20220303143659-8896461e4ec7 h1:4CIaYagSRCGr0/Gh6cfF5cQx3RVE3qrQukZn8iMO6Y8=
33+
github.com/shueybubbles/go-mssqldb v0.10.1-0.20220303143659-8896461e4ec7/go.mod h1:iiK0YP1ZeepvmBQk/QpLEhhTNJgfzrpArPY/aFvc9yU=
3434
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
3535
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
3636
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=

pkg/sqlcmd/azure_auth.go

+11-11
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package sqlcmd
55

66
import (
77
"database/sql/driver"
8+
"fmt"
89
"net/url"
910
"os"
1011

@@ -24,28 +25,27 @@ func getSqlClientId() string {
2425
return sqlClientId
2526
}
2627

27-
func (s *Sqlcmd) GetTokenBasedConnection(connstr string, user string, password string) (driver.Connector, error) {
28+
func GetTokenBasedConnection(connstr string, authenticationMethod string) (driver.Connector, error) {
2829

2930
connectionUrl, err := url.Parse(connstr)
3031
if err != nil {
3132
return nil, err
3233
}
3334

34-
if user != "" {
35-
connectionUrl.User = url.UserPassword(user, password)
36-
}
37-
3835
query := connectionUrl.Query()
39-
query.Set("fedauth", s.Connect.authenticationMethod())
36+
query.Set("fedauth", authenticationMethod)
4037
query.Set("applicationclientid", getSqlClientId())
41-
42-
switch s.Connect.AuthenticationMethod {
43-
case azuread.ActiveDirectoryServicePrincipal:
44-
case azuread.ActiveDirectoryApplication:
38+
switch authenticationMethod {
39+
case azuread.ActiveDirectoryServicePrincipal, azuread.ActiveDirectoryApplication:
4540
query.Set("clientcertpath", os.Getenv("AZURE_CLIENT_CERTIFICATE_PATH"))
4641
case azuread.ActiveDirectoryInteractive:
42+
loginTimeout := query.Get("connection timeout")
43+
loginTimeoutSeconds := 0
44+
if loginTimeout != "" {
45+
_, _ = fmt.Sscanf(loginTimeout, "%d", &loginTimeoutSeconds)
46+
}
4747
// AAD interactive needs minutes at minimum
48-
if s.Connect.LoginTimeoutSeconds < 120 {
48+
if loginTimeoutSeconds > 0 && loginTimeoutSeconds < 120 {
4949
query.Set("connection timeout", "120")
5050
}
5151
}

pkg/sqlcmd/commands.go

+43
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"sort"
1111
"strings"
1212
"syscall"
13+
14+
"github.com/alecthomas/kong"
1315
)
1416

1517
// Command defines a sqlcmd action which can be intermixed with the SQL batch
@@ -80,6 +82,11 @@ func newCommands() Commands {
8082
action: listCommand,
8183
name: "LIST",
8284
},
85+
"CONNECT": {
86+
regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`),
87+
action: connectCommand,
88+
name: "CONNECT",
89+
},
8390
}
8491

8592
}
@@ -285,3 +292,39 @@ func listCommand(s *Sqlcmd, args []string, line uint) error {
285292

286293
return nil
287294
}
295+
296+
type connectData struct {
297+
Server string `arg:""`
298+
Database string `short:"D"`
299+
Username string `short:"U"`
300+
Password string `short:"P"`
301+
LoginTimeout int `short:"l"`
302+
AuthenticationMethod string `short:"G"`
303+
}
304+
305+
func connectCommand(s *Sqlcmd, args []string, line uint) error {
306+
if len(args) == 0 || strings.TrimSpace(args[0]) == "" {
307+
return InvalidCommandError("CONNECT", line)
308+
}
309+
arguments := &connectData{}
310+
parser, err := kong.New(arguments)
311+
if err != nil {
312+
return InvalidCommandError("CONNECT", line)
313+
}
314+
if _, err = parser.Parse(strings.Split(args[0], " ")); err != nil {
315+
return InvalidCommandError("CONNECT", line)
316+
}
317+
318+
connect := s.Connect
319+
connect.UserName = arguments.Username
320+
connect.Password = arguments.Password
321+
connect.ServerName = arguments.Server
322+
if arguments.LoginTimeout > 0 {
323+
connect.LoginTimeoutSeconds = arguments.LoginTimeout
324+
}
325+
connect.AuthenticationMethod = arguments.AuthenticationMethod
326+
// If no user name is provided we switch to integrated auth
327+
_ = s.ConnectDb(&connect, s.lineIo == nil)
328+
// ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option
329+
return nil
330+
}

pkg/sqlcmd/commands_test.go

+39
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ package sqlcmd
55

66
import (
77
"bytes"
8+
"fmt"
9+
"os"
810
"strings"
911
"testing"
1012

@@ -41,6 +43,7 @@ func TestCommandParsing(t *testing.T) {
4143
{`:EXIT (select 100 as count)`, "EXIT", []string{"(select 100 as count)"}},
4244
{`:EXIT ( )`, "EXIT", []string{"( )"}},
4345
{`EXIT `, "EXIT", []string{""}},
46+
{`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}},
4447
}
4548

4649
for _, test := range commands {
@@ -151,3 +154,39 @@ func TestListCommand(t *testing.T) {
151154
o := buf.buf.String()
152155
assert.Equal(t, o, "select 1"+SqlcmdEol, ":list output not equal to batch")
153156
}
157+
158+
func TestConnectCommand(t *testing.T) {
159+
s, _ := setupSqlCmdWithMemoryOutput(t)
160+
prompted := false
161+
s.lineIo = &testConsole{
162+
OnPasswordPrompt: func(prompt string) ([]byte, error) {
163+
prompted = true
164+
return []byte{}, nil
165+
},
166+
}
167+
err := connectCommand(s, []string{"someserver -U someuser"}, 1)
168+
assert.NoError(t, err, "connectCommand with valid arguments doesn't return an error on connect failure")
169+
assert.True(t, prompted, "connectCommand with user name and no password should prompt for password")
170+
assert.NotEqual(t, "someserver", s.Connect.ServerName, "On error, sqlCmd.Connect does not copy inputs")
171+
172+
err = connectCommand(s, []string{}, 2)
173+
assert.EqualError(t, err, InvalidCommandError("CONNECT", 2).Error(), ":Connect with no arguments should return an error")
174+
c := newConnect(t)
175+
176+
authenticationMethod := ""
177+
if c.Password == "" {
178+
c.UserName = os.Getenv("AZURE_CLIENT_ID") + "@" + os.Getenv("AZURE_TENANT_ID")
179+
c.Password = os.Getenv("AZURE_CLIENT_SECRET")
180+
authenticationMethod = "-G ActiveDirectoryServicePrincipal"
181+
if c.Password == "" {
182+
t.Log("Not trying :Connect with valid password due to no password being available")
183+
return
184+
}
185+
err = connectCommand(s, []string{fmt.Sprintf("%s -U %s -P %s %s", c.ServerName, c.UserName, c.Password, authenticationMethod)}, 3)
186+
assert.NoError(t, err, "connectCommand with valid parameters should not return an error")
187+
// not using assert to avoid printing passwords in the log
188+
if s.Connect.UserName != c.UserName || c.Password != s.Connect.Password {
189+
t.Fatal("After connect, sqlCmd.Connect is not updated")
190+
}
191+
}
192+
}

0 commit comments

Comments
 (0)