Skip to content

Commit f78b382

Browse files
authored
fix: resolve vars in error and out commands (#535)
1 parent 9cd8538 commit f78b382

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

pkg/sqlcmd/commands.go

+17-8
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,18 @@ func outCommand(s *Sqlcmd, args []string, line uint) error {
264264
if len(args) == 0 || args[0] == "" {
265265
return InvalidCommandError("OUT", line)
266266
}
267+
filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
268+
if err != nil {
269+
return err
270+
}
271+
267272
switch {
268-
case strings.EqualFold(args[0], "stdout"):
273+
case strings.EqualFold(filePath, "stdout"):
269274
s.SetOutput(os.Stdout)
270-
case strings.EqualFold(args[0], "stderr"):
275+
case strings.EqualFold(filePath, "stderr"):
271276
s.SetOutput(os.Stderr)
272277
default:
273-
o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
278+
o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
274279
if err != nil {
275280
return InvalidFileError(err, args[0])
276281
}
@@ -290,15 +295,19 @@ func outCommand(s *Sqlcmd, args []string, line uint) error {
290295
// errorCommand changes the error writer to use a file
291296
func errorCommand(s *Sqlcmd, args []string, line uint) error {
292297
if len(args) == 0 || args[0] == "" {
293-
return InvalidCommandError("OUT", line)
298+
return InvalidCommandError("ERROR", line)
299+
}
300+
filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
301+
if err != nil {
302+
return err
294303
}
295304
switch {
296-
case strings.EqualFold(args[0], "stderr"):
305+
case strings.EqualFold(filePath, "stderr"):
297306
s.SetError(os.Stderr)
298-
case strings.EqualFold(args[0], "stdout"):
307+
case strings.EqualFold(filePath, "stdout"):
299308
s.SetError(os.Stdout)
300309
default:
301-
o, err := os.OpenFile(args[0], os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
310+
o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
302311
if err != nil {
303312
return InvalidFileError(err, args[0])
304313
}
@@ -549,7 +558,7 @@ func xmlCommand(s *Sqlcmd, args []string, line uint) error {
549558
func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) {
550559
var b *strings.Builder
551560
end := len(arg)
552-
for i := 0; i < end; {
561+
for i := 0; i < end && !s.Connect.DisableVariableSubstitution; {
553562
c, next := arg[i], grab(arg, i+1, end)
554563
switch {
555564
case c == '$' && next == '(':

pkg/sqlcmd/commands_test.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -246,24 +246,28 @@ func TestConnectCommand(t *testing.T) {
246246

247247
func TestErrorCommand(t *testing.T) {
248248
s, buf := setupSqlCmdWithMemoryOutput(t)
249+
defer s.SetError(nil)
249250
defer buf.Close()
250251
file, err := os.CreateTemp("", "sqlcmderr")
251252
assert.NoError(t, err, "os.CreateTemp")
252253
defer os.Remove(file.Name())
253254
fileName := file.Name()
254255
_ = file.Close()
255256
err = errorCommand(s, []string{""}, 1)
256-
assert.EqualError(t, err, InvalidCommandError("OUT", 1).Error(), "errorCommand with empty file name")
257+
assert.EqualError(t, err, InvalidCommandError("ERROR", 1).Error(), "errorCommand with empty file name")
257258
err = errorCommand(s, []string{fileName}, 1)
258259
assert.NoError(t, err, "errorCommand")
259260
// Only some error kinds go to the error output
260261
err = runSqlCmd(t, s, []string{"print N'message'", "RAISERROR(N'Error', 16, 1)", "SELECT 1", ":SETVAR 1", "GO"})
261262
assert.NoError(t, err, "runSqlCmd")
262-
s.SetError(nil)
263263
errText, err := os.ReadFile(file.Name())
264264
if assert.NoError(t, err, "ReadFile") {
265265
assert.Regexp(t, "Msg 50000, Level 16, State 1, Server .*, Line 2"+SqlcmdEol+"Error"+SqlcmdEol, string(errText), "Error file contents: "+string(errText))
266266
}
267+
s.vars.Set("myvar", "stdout")
268+
err = errorCommand(s, []string{"$(myvar)"}, 1)
269+
assert.NoError(t, err, "errorCommand with a variable")
270+
assert.Equal(t, os.Stdout, s.err, "error set to stdout using a variable")
267271
}
268272

269273
func TestOnErrorCommand(t *testing.T) {
@@ -320,6 +324,11 @@ func TestResolveArgumentVariables(t *testing.T) {
320324
if assert.ErrorContains(t, err, UndefinedVariable("var2").Error(), "fail on unresolved variable") {
321325
assert.Empty(t, actual, "fail on unresolved variable")
322326
}
327+
s.Connect.DisableVariableSubstitution = true
328+
input := "$(var1) notvar"
329+
actual, err = resolveArgumentVariables(s, []rune(input), true)
330+
assert.NoError(t, err)
331+
assert.Equal(t, input, actual, "resolveArgumentVariables when DisableVariableSubstitution is false")
323332
}
324333

325334
func TestExecCommand(t *testing.T) {

0 commit comments

Comments
 (0)