Skip to content

Commit cb5dd57

Browse files
authored
implement startup script (#148)
* implement startup script * fix test * fix test for SQL DB
1 parent 036e4f8 commit cb5dd57

File tree

4 files changed

+69
-16
lines changed

4 files changed

+69
-16
lines changed

cmd/sqlcmd/main.go

+31-16
Original file line numberDiff line numberDiff line change
@@ -283,28 +283,43 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
283283
}
284284
}
285285
}
286-
once := false
287-
if args.InitialQuery != "" {
288-
s.Query = args.InitialQuery
289-
} else if args.Query != "" {
290-
once = true
291-
s.Query = args.Query
292-
}
286+
293287
// connect using no overrides
294288
err = s.ConnectDb(nil, line == nil)
295289
if err != nil {
296290
return 1, err
297291
}
298292

299-
iactive := args.InputFile == nil && args.Query == ""
300-
if iactive || s.Query != "" {
301-
err = s.Run(once, false)
302-
} else {
303-
for f := range args.InputFile {
304-
if err = s.IncludeFile(args.InputFile[f], true); err != nil {
305-
s.WriteError(s.GetError(), err)
306-
s.Exitcode = 1
307-
break
293+
script := vars.StartupScriptFile()
294+
if !args.DisableCmdAndWarn && len(script) > 0 {
295+
f, fileErr := os.Open(script)
296+
if fileErr != nil {
297+
s.WriteError(s.GetError(), sqlcmd.InvalidVariableValue(sqlcmd.SQLCMDINI, script))
298+
} else {
299+
_ = f.Close()
300+
// IncludeFile won't return an error for a SQL error, but ExitCode will be non-zero if -b was passed on the commandline
301+
err = s.IncludeFile(script, true)
302+
}
303+
}
304+
305+
if err == nil && s.Exitcode == 0 {
306+
once := false
307+
if args.InitialQuery != "" {
308+
s.Query = args.InitialQuery
309+
} else if args.Query != "" {
310+
once = true
311+
s.Query = args.Query
312+
}
313+
iactive := args.InputFile == nil && args.Query == ""
314+
if iactive || s.Query != "" {
315+
err = s.Run(once, false)
316+
} else {
317+
for f := range args.InputFile {
318+
if err = s.IncludeFile(args.InputFile[f], true); err != nil {
319+
s.WriteError(s.GetError(), err)
320+
s.Exitcode = 1
321+
break
322+
}
308323
}
309324
}
310325
}

cmd/sqlcmd/main_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,30 @@ func TestConditionsForPasswordPrompt(t *testing.T) {
383383
}
384384
}
385385

386+
func TestStartupScript(t *testing.T) {
387+
o, err := os.CreateTemp("", "sqlcmdmain")
388+
assert.NoError(t, err, "os.CreateTemp")
389+
defer os.Remove(o.Name())
390+
defer o.Close()
391+
args = newArguments()
392+
args.OutputFile = o.Name()
393+
args.Query = "set nocount on"
394+
if canTestAzureAuth() {
395+
args.UseAad = true
396+
}
397+
vars := sqlcmd.InitializeVariables(true)
398+
setVars(vars, &args)
399+
vars.Set(sqlcmd.SQLCMDINI, "testdata/select100.sql")
400+
vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0")
401+
exitCode, err := run(vars, &args)
402+
assert.NoError(t, err, "run")
403+
assert.Equal(t, 0, exitCode, "exitCode")
404+
bytes, err := os.ReadFile(o.Name())
405+
if assert.NoError(t, err, "os.ReadFile") {
406+
assert.Equal(t, "100"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run")
407+
}
408+
}
409+
386410
// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set
387411
func canTestAzureAuth() bool {
388412
server := os.Getenv(sqlcmd.SQLCMDSERVER)

pkg/sqlcmd/errors.go

+9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package sqlcmd
66
import (
77
"errors"
88
"fmt"
9+
"strings"
910
)
1011

1112
// ErrorPrefix is the prefix for all sqlcmd-generated errors
@@ -56,6 +57,14 @@ func UndefinedVariable(variable string) *VariableError {
5657
}
5758
}
5859

60+
// InvalidVariableValue indicates the variable was set to an invalid value
61+
func InvalidVariableValue(variable string, value string) *VariableError {
62+
return &VariableError{
63+
Variable: variable,
64+
MessageFormat: "The environment variable: '%s' has invalid value: '" + strings.ReplaceAll(value, `%`, `%%`) + "'.",
65+
}
66+
}
67+
5968
// CommandError indicates syntax errors for specific sqlcmd commands
6069
type CommandError struct {
6170
Command string

pkg/sqlcmd/variables.go

+5
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,11 @@ func (v Variables) Format() string {
179179
return "horizontal"
180180
}
181181

182+
// StartupScriptFile is the path to the file that contains the startup script
183+
func (v Variables) StartupScriptFile() string {
184+
return v[SQLCMDINI]
185+
}
186+
182187
// TextEditor is the query editor application launched by the :ED command
183188
func (v Variables) TextEditor() string {
184189
return v[SQLCMDEDITOR]

0 commit comments

Comments
 (0)