Skip to content

Commit 7e65011

Browse files
authored
ensure Console.Close on exit (#333)
1 parent 37c04c8 commit 7e65011

File tree

3 files changed

+55
-23
lines changed

3 files changed

+55
-23
lines changed

cmd/sqlcmd/sqlcmd.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ package sqlcmd
77
import (
88
"errors"
99
"fmt"
10-
"github.com/alecthomas/kong"
1110
"os"
1211

12+
"github.com/alecthomas/kong"
1313
"github.com/microsoft/go-mssqldb/azuread"
1414
"github.com/microsoft/go-sqlcmd/internal/localizer"
1515
"github.com/microsoft/go-sqlcmd/pkg/console"
@@ -252,6 +252,9 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
252252
}
253253

254254
s := sqlcmd.New(line, wd, vars)
255+
// We want the default behavior on ctrl-c - exit the process
256+
s.SetupCloseHandler()
257+
defer s.StopCloseHandler()
255258
s.UnicodeOutputFile = args.UnicodeOutputFile
256259

257260
if args.DisableCmdAndWarn {

internal/sql/mssql.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@ package sql
55

66
import (
77
"fmt"
8-
"github.com/microsoft/go-sqlcmd/internal/buffer"
9-
"github.com/microsoft/go-sqlcmd/pkg/console"
108
"os"
119
"strings"
1210

11+
"github.com/microsoft/go-sqlcmd/internal/buffer"
12+
"github.com/microsoft/go-sqlcmd/pkg/console"
13+
1314
"github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig"
1415
"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
1516
)
@@ -79,8 +80,12 @@ func (m *mssql) Query(text string) {
7980
err := m.sqlcmd.Run(true, false)
8081
checkErr(err)
8182
} else {
83+
// sqlcmd prints the ErrCtrlC message before returning
84+
// In modern mode we do not exit the process on ctrl-c during interactive mode
8285
err := m.sqlcmd.Run(false, true)
83-
checkErr(err)
86+
if err != sqlcmd.ErrCtrlC {
87+
checkErr(err)
88+
}
8489
}
8590
}
8691

pkg/sqlcmd/sqlcmd.go

+43-19
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,26 @@ type Sqlcmd struct {
7171
echoFileLines bool
7272
// Exitcode is returned to the operating system when the process exits
7373
Exitcode int
74-
Connect *ConnectSettings
75-
vars *Variables
76-
Format Formatter
77-
Query string
78-
Cmd Commands
74+
// Connect controls how Sqlcmd connects to the database
75+
Connect *ConnectSettings
76+
vars *Variables
77+
// Format renders the query output
78+
Format Formatter
79+
// Query is the TSQL query to run
80+
Query string
81+
// Cmd provides the implementation of commands like :list and GO
82+
Cmd Commands
7983
// PrintError allows the host to redirect errors away from the default output. Returns false if the error is not redirected by the host.
80-
PrintError func(msg string, severity uint8) bool
84+
PrintError func(msg string, severity uint8) bool
85+
// UnicodeOutputFile is true when UTF16 file output is needed
8186
UnicodeOutputFile bool
8287
colorizer color.Colorizer
88+
termchan chan os.Signal
8389
}
8490

85-
// New creates a new Sqlcmd instance
91+
// New creates a new Sqlcmd instance.
92+
// The Console instane must be non-nil for Sqlcmd to run in interactive mode.
93+
// The hosting application is responsible for calling Close() on the Console instance before process exit.
8694
func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd {
8795
s := &Sqlcmd{
8896
lineIo: l,
@@ -107,8 +115,9 @@ func (s *Sqlcmd) scanNext() (string, error) {
107115
// Run processes all available batches.
108116
// When once is true it stops after the first query runs.
109117
// When processAll is true it executes any remaining batch content when reaching EOF
118+
// The error returned from Run is mainly of informational value. Its Message will have been printed
119+
// before Run returns.
110120
func (s *Sqlcmd) Run(once bool, processAll bool) error {
111-
setupCloseHandler(s)
112121
iactive := s.lineIo != nil
113122
var lastError error
114123
for {
@@ -160,8 +169,10 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error {
160169
}
161170
}
162171

172+
// Some Console implementations catch the ctrl-c so s.termchan isn't signalled
163173
if err == ErrCtrlC {
164-
os.Exit(0)
174+
s.Exitcode = 0
175+
return err
165176
}
166177
if err != nil && err != io.EOF && (s.Connect.ExitOnError && !s.Connect.IgnoreError) {
167178
// If the error were due to a SQL error, the GO command handler
@@ -392,16 +403,6 @@ func (s *Sqlcmd) getRunnableQuery(q string) string {
392403
return b.String()
393404
}
394405

395-
func setupCloseHandler(s *Sqlcmd) {
396-
c := make(chan os.Signal, 1)
397-
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
398-
go func() {
399-
<-c
400-
s.WriteError(s.GetOutput(), ErrCtrlC)
401-
os.Exit(0)
402-
}()
403-
}
404-
405406
// runQuery runs the query and prints the results
406407
// The return value is based on the first cell of the last column of the last result set.
407408
// If it's numeric, it will be converted to int
@@ -540,3 +541,26 @@ func (s Sqlcmd) Log(_ context.Context, _ msdsn.Log, msg string) {
540541
_, _ = s.GetOutput().Write([]byte("DRIVER:" + msg))
541542
_, _ = s.GetOutput().Write([]byte(SqlcmdEol))
542543
}
544+
545+
// SetupCloseHandler subscribes to the os.Signal channel for SIGTERM.
546+
// When it receives the event due to the user pressing ctrl-c or ctrl-break
547+
// that isn't handled directly by the Console or hosting application,
548+
// it will call Close() on the Console and exit the application.
549+
// Use StopCloseHandler to remove the subscription
550+
func (s *Sqlcmd) SetupCloseHandler() {
551+
s.termchan = make(chan os.Signal, 1)
552+
signal.Notify(s.termchan, os.Interrupt, syscall.SIGTERM)
553+
go func() {
554+
<-s.termchan
555+
s.WriteError(s.GetOutput(), ErrCtrlC)
556+
if s.lineIo != nil {
557+
s.lineIo.Close()
558+
}
559+
os.Exit(0)
560+
}()
561+
}
562+
563+
// StopCloseHandler unsubscribes the Sqlcmd from the SIGTERM signal
564+
func (s *Sqlcmd) StopCloseHandler() {
565+
signal.Stop(s.termchan)
566+
}

0 commit comments

Comments
 (0)