Skip to content

Commit 332f5e7

Browse files
authored
Add --database flag to sqlcmd query (#288)
1 parent b8fed69 commit 332f5e7

File tree

5 files changed

+47
-11
lines changed

5 files changed

+47
-11
lines changed

cmd/modern/root/install/mssql-base.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
326326
Password: secret.Encode(saPassword, c.encryptPassword)},
327327
Name: "sa"}
328328

329-
c.sql.Connect(endpoint, saUser, sql.ConnectOptions{Interactive: false})
329+
c.sql.Connect(endpoint, saUser, sql.ConnectOptions{Database: "master", Interactive: false})
330330

331331
c.createNonSaUser(userName, password)
332332

cmd/modern/root/query.go

+21-4
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
package root
55

66
import (
7+
"fmt"
78
"github.com/microsoft/go-sqlcmd/internal/cmdparser"
89
"github.com/microsoft/go-sqlcmd/internal/config"
10+
"github.com/microsoft/go-sqlcmd/internal/pal"
911
"github.com/microsoft/go-sqlcmd/internal/sql"
1012
)
1113

1214
// Query defines the `sqlcmd query` command
1315
type Query struct {
1416
cmdparser.Cmd
1517

16-
text string
18+
text string
19+
database string
1720
}
1821

1922
func (c *Query) DefineCommand(...cmdparser.CommandOptions) {
@@ -25,7 +28,15 @@ func (c *Query) DefineCommand(...cmdparser.CommandOptions) {
2528
`sqlcmd query "SELECT @@SERVERNAME"`,
2629
`sqlcmd query --text "SELECT @@SERVERNAME"`,
2730
`sqlcmd query --query "SELECT @@SERVERNAME"`,
28-
}}},
31+
}},
32+
{Description: "Run a query using [master] database", Steps: []string{
33+
`sqlcmd query "SELECT DB_NAME()" --database master`,
34+
}},
35+
{Description: "Set new default database", Steps: []string{
36+
fmt.Sprintf(`sqlcmd query "ALTER LOGIN [%s] WITH DEFAULT_DATABASE = [tempdb]" --database master`,
37+
pal.UserName()),
38+
}},
39+
},
2940
Run: c.run,
3041
FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{
3142
Flag: "text",
@@ -47,6 +58,12 @@ func (c *Query) DefineCommand(...cmdparser.CommandOptions) {
4758
Name: "query",
4859
Shorthand: "q",
4960
Usage: "Command text to run"})
61+
62+
c.AddFlag(cmdparser.FlagOptions{
63+
String: &c.database,
64+
Name: "database",
65+
Shorthand: "d",
66+
Usage: "Database to use"})
5067
}
5168

5269
// run executes the Query command.
@@ -58,9 +75,9 @@ func (c *Query) run() {
5875

5976
s := sql.New(sql.SqlOptions{})
6077
if c.text == "" {
61-
s.Connect(endpoint, user, sql.ConnectOptions{Interactive: true})
78+
s.Connect(endpoint, user, sql.ConnectOptions{Database: c.database, Interactive: true})
6279
} else {
63-
s.Connect(endpoint, user, sql.ConnectOptions{Interactive: false})
80+
s.Connect(endpoint, user, sql.ConnectOptions{Database: c.database, Interactive: false})
6481
}
6582

6683
s.Query(c.text)

cmd/modern/root/query_test.go

+19-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"fmt"
88
"github.com/microsoft/go-sqlcmd/cmd/modern/root/config"
99
"github.com/microsoft/go-sqlcmd/internal/cmdparser"
10-
"github.com/stretchr/testify/assert"
1110
"os"
1211
"runtime"
1312
"testing"
@@ -18,8 +17,27 @@ func TestQuery(t *testing.T) {
1817
if runtime.GOOS != "windows" {
1918
t.Skip("stuartpa: This is failing in the pipeline (Login failed for user 'sa'.)")
2019
}
20+
2121
cmdparser.TestSetup(t)
2222

23+
setupContext(t)
24+
cmdparser.TestCmd[*Query]("PRINT")
25+
}
26+
27+
func TestQueryWithNonDefaultDatabase(t *testing.T) {
28+
if runtime.GOOS != "windows" {
29+
t.Skip("stuartpa: This is failing in the pipeline (Login failed for user 'sa'.)")
30+
}
31+
32+
cmdparser.TestSetup(t)
33+
34+
setupContext(t)
35+
cmdparser.TestCmd[*Query](`--text "PRINT DB_NAME()" --database master`)
36+
37+
// TODO: Add test validation that DB name was actually master!
38+
}
39+
40+
func setupContext(t *testing.T) {
2341
// if SQLCMDSERVER != "" add an endpoint using the --address
2442
if os.Getenv("SQLCMDSERVER") == "" {
2543
cmdparser.TestCmd[*config.AddEndpoint]()
@@ -33,10 +51,6 @@ func TestQuery(t *testing.T) {
3351
if os.Getenv("SQLCMDPASSWORD") != "" &&
3452
os.Getenv("SQLCMDUSER") != "" {
3553

36-
// sqlcmd uses the SQLCMD_PASSWORD env var, but the tests use the
37-
// SQLCMDPASSWORD env var
38-
err := os.Setenv("SQLCMD_PASSWORD", os.Getenv("SQLCMDPASSWORD"))
39-
assert.Nil(t, err)
4054
cmdparser.TestCmd[*config.AddUser](
4155
fmt.Sprintf("--name user1 --username %s",
4256
os.Getenv("SQLCMDUSER")))
@@ -45,5 +59,4 @@ func TestQuery(t *testing.T) {
4559
cmdparser.TestCmd[*config.AddContext]("--endpoint endpoint")
4660
}
4761
cmdparser.TestCmd[*config.View]() // displaying the config (info in-case test fails)
48-
cmdparser.TestCmd[*Query]("PRINT")
4962
}

internal/sql/interface.go

+2
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,7 @@ type Sql interface {
1414
}
1515

1616
type ConnectOptions struct {
17+
Database string
18+
1719
Interactive bool
1820
}

internal/sql/mssql.go

+4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ func (m *mssql) Connect(
4040
ApplicationName: "sqlcmd",
4141
}
4242

43+
if options.Database != "" {
44+
connect.Database = options.Database
45+
}
46+
4347
if user == nil {
4448
connect.UseTrustedConnection = true
4549
} else {

0 commit comments

Comments
 (0)