Skip to content

Commit 553f5a1

Browse files
authored
Implement :ED (#146)
* implement ED command * don't omit blank lines * fix test for linux * fix space in test
1 parent cbf535e commit 553f5a1

File tree

8 files changed

+89
-11
lines changed

8 files changed

+89
-11
lines changed

pkg/sqlcmd/batch.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ parse:
147147
if err == nil {
148148
i = min(i, b.rawlen)
149149
empty := isEmptyLine(b.raw, 0, i)
150-
appendLine := b.quote != 0 || b.comment || !empty
150+
appendLine := true
151151
if !b.comment && command != nil && empty {
152152
appendLine = false
153153
}

pkg/sqlcmd/commands.go

+52-4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ type Command struct {
2626
action func(*Sqlcmd, []string, uint) error
2727
// Name of the command
2828
name string
29+
// whether the command is a system command
30+
isSystem bool
2931
}
3032

3133
// Commands is the set of sqlcmd command implementations
@@ -89,9 +91,16 @@ func newCommands() Commands {
8991
name: "CONNECT",
9092
},
9193
"EXEC": {
92-
regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(?:[ \t]+(.*$)|$)`),
93-
action: execCommand,
94-
name: "EXEC",
94+
regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(?:[ \t]+(.*$)|$)`),
95+
action: execCommand,
96+
name: "EXEC",
97+
isSystem: true,
98+
},
99+
"EDIT": {
100+
regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`),
101+
action: editCommand,
102+
name: "EDIT",
103+
isSystem: true,
95104
},
96105
}
97106
}
@@ -103,8 +112,13 @@ func (c Commands) DisableSysCommands(exitOnCall bool) {
103112
if exitOnCall {
104113
f = errorDisabled
105114
}
106-
c["EXEC"].action = f
115+
for _, cmd := range c {
116+
if cmd.isSystem {
117+
cmd.action = f
118+
}
119+
}
107120
}
121+
108122
func (c Commands) matchCommand(line string) (*Command, []string) {
109123
for _, cmd := range c {
110124
matchedCommand := cmd.regex.FindStringSubmatch(line)
@@ -411,6 +425,40 @@ func execCommand(s *Sqlcmd, args []string, line uint) error {
411425
return nil
412426
}
413427

428+
func editCommand(s *Sqlcmd, args []string, line uint) error {
429+
if args != nil && strings.TrimSpace(args[0]) != "" {
430+
return InvalidCommandError("ED", line)
431+
}
432+
file, err := os.CreateTemp("", "sq*.sql")
433+
if err != nil {
434+
return err
435+
}
436+
fileName := file.Name()
437+
defer os.Remove(fileName)
438+
text := s.batch.String()
439+
if s.batch.State() == "-" {
440+
text = fmt.Sprintf("%s%s", text, SqlcmdEol)
441+
}
442+
_, err = file.WriteString(text)
443+
if err != nil {
444+
return err
445+
}
446+
file.Close()
447+
cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`)
448+
cmd.Stderr = s.GetError()
449+
cmd.Stdout = s.GetOutput()
450+
err = cmd.Run()
451+
if err != nil {
452+
return err
453+
}
454+
wasEcho := s.echoFileLines
455+
s.echoFileLines = true
456+
s.batch.Reset(nil)
457+
_ = s.IncludeFile(fileName, false)
458+
s.echoFileLines = wasEcho
459+
return nil
460+
}
461+
414462
func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) {
415463
var b *strings.Builder
416464
end := len(arg)

pkg/sqlcmd/commands_test.go

+11
Original file line numberDiff line numberDiff line change
@@ -287,3 +287,14 @@ func TestDisableSysCommandBlocksExec(t *testing.T) {
287287
assert.Equal(t, 1, s.Exitcode, "ExitCode after error")
288288
}
289289
}
290+
291+
func TestEditCommand(t *testing.T) {
292+
s, buf := setupSqlCmdWithMemoryOutput(t)
293+
defer buf.Close()
294+
s.vars.Set(SQLCMDEDITOR, "echo select 5000> ")
295+
c := []string{"set nocount on", "go", "select 100", ":ed", "go"}
296+
err := runSqlCmd(t, s, c)
297+
if assert.NoError(t, err, ":ed should not raise error") {
298+
assert.Equal(t, "1> select 5000"+SqlcmdEol+"5000"+SqlcmdEol+SqlcmdEol, buf.buf.String(), "Incorrect output from query after :ed command")
299+
}
300+
}

pkg/sqlcmd/exec_darwin.go

+2
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@ func comSpec() string {
1414
// /bin/sh will be a link to the shell
1515
return `/bin/sh`
1616
}
17+
18+
const defaultEditor = "vi"

pkg/sqlcmd/exec_linux.go

+2
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@ func comSpec() string {
1414
// /bin/sh will be a link to the shell
1515
return `/bin/sh`
1616
}
17+
18+
const defaultEditor = "vi"

pkg/sqlcmd/exec_windows.go

+2
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,5 @@ func comSpec() string {
2323
func comArgs(args string) string {
2424
return `/c ` + args
2525
}
26+
27+
const defaultEditor = "notepad.exe"

pkg/sqlcmd/sqlcmd.go

+13-5
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ type Sqlcmd struct {
6262
out io.WriteCloser
6363
err io.WriteCloser
6464
batch *Batch
65+
echoFileLines bool
6566
// Exitcode is returned to the operating system when the process exits
6667
Exitcode int
6768
Connect *ConnectSettings
@@ -310,6 +311,7 @@ func (s *Sqlcmd) IncludeFile(path string, processAll bool) error {
310311
buf := make([]byte, maxLineBuffer)
311312
scanner.Buffer(buf, maxLineBuffer)
312313
curLine := s.batch.read
314+
echoFileLines := s.echoFileLines
313315
s.batch.read = func() (string, error) {
314316
if !scanner.Scan() {
315317
err := scanner.Err()
@@ -318,14 +320,20 @@ func (s *Sqlcmd) IncludeFile(path string, processAll bool) error {
318320
}
319321
return "", err
320322
}
321-
return scanner.Text(), nil
323+
t := scanner.Text()
324+
if echoFileLines {
325+
_, _ = s.GetOutput().Write([]byte(s.Prompt() + t + SqlcmdEol))
326+
}
327+
return t, nil
322328
}
323329
err = s.Run(false, processAll)
324330
s.batch.read = curLine
325-
if s.batch.State() == "=" {
326-
s.batch.batchline = 1
327-
} else {
328-
s.batch.batchline = b + 1
331+
if !s.echoFileLines {
332+
if s.batch.State() == "=" {
333+
s.batch.batchline = 1
334+
} else {
335+
s.batch.batchline = b + 1
336+
}
329337
}
330338
return err
331339
}

pkg/sqlcmd/variables.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,11 @@ func (v Variables) Format() string {
179179
return "horizontal"
180180
}
181181

182+
// TextEditor is the query editor application launched by the :ED command
183+
func (v Variables) TextEditor() string {
184+
return v[SQLCMDEDITOR]
185+
}
186+
182187
func mustValue(val string) int64 {
183188
var n int64
184189
_, err := fmt.Sscanf(val, "%d", &n)
@@ -193,7 +198,7 @@ func mustValue(val string) int64 {
193198
var defaultVariables = Variables{
194199
SQLCMDCOLSEP: " ",
195200
SQLCMDCOLWIDTH: "0",
196-
SQLCMDEDITOR: "edit.com",
201+
SQLCMDEDITOR: defaultEditor,
197202
SQLCMDERRORLEVEL: "0",
198203
SQLCMDHEADERS: "0",
199204
SQLCMDLOGINTIMEOUT: "30",

0 commit comments

Comments
 (0)