Skip to content

Commit 5dcdb2c

Browse files
committed
dial: better error handling if ssh command exits with non-zero exit status
SSHError.Error() relied on go-rwccmd behavior of returning io.EOF if the ssh binary exited with status code 0. We no longer use go-rwccmd => capture Stderr ourselves using zrepl's circlog (depending on zrepl is not pretty, but since this package is supposedly only used by zrepl ATM, this is fine) refs zrepl/zrepl#237
1 parent 1668537 commit 5dcdb2c

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

dial.go

+22-9
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"sync"
1111
"syscall"
1212
"time"
13+
14+
"github.com/zrepl/zrepl/util/circlog"
1315
)
1416

1517
type Endpoint struct {
@@ -168,7 +170,7 @@ func (conn *SSHConn) shutdownProcess() *shutdownResult {
168170
conn.shutdownResult = &shutdownResult{waitErr}
169171
case <-timeout.C:
170172
conn.cmdCancel()
171-
waitErr := <- wait // reuse existing Wait invocation, must not call twice
173+
waitErr := <-wait // reuse existing Wait invocation, must not call twice
172174
conn.shutdownResult = &shutdownResult{waitErr}
173175
}
174176
return conn.shutdownResult
@@ -220,11 +222,6 @@ type SSHError struct {
220222
// Error() will try to present a one-line error message unless ssh stderr output is longer than one line
221223
func (e *SSHError) Error() string {
222224

223-
if e.RWCError == io.EOF {
224-
// rwccmd returns io.EOF on exit status 0, but we do not expect ssh to do that
225-
return fmt.Sprintf("ssh exited unexpectedly with exit status 0")
226-
}
227-
228225
exitErr, ok := e.RWCError.(*exec.ExitError)
229226
if !ok {
230227
return fmt.Sprintf("ssh: %s", e.RWCError)
@@ -286,18 +283,31 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn, error) {
286283
if err != nil {
287284
return nil, err
288285
}
289-
// stderr is required for *exec.ExitErr
286+
287+
stderrBuf, err := circlog.NewCircularLog(1 << 15)
288+
if err != nil {
289+
panic(err) // wrong API usage
290+
}
291+
cmd.Stderr = stderrBuf
290292

291293
if err = cmd.Start(); err != nil {
292294
return nil, err
293295
}
296+
cmdWaitErrOrIOErr := func(ioErr error, what string) *SSHError {
297+
werr := cmd.Wait()
298+
if werr, ok := werr.(*exec.ExitError); ok {
299+
werr.Stderr = []byte(stderrBuf.String())
300+
return &SSHError{werr, what}
301+
}
302+
return &SSHError{ioErr, what}
303+
}
294304

295305
confErrChan := make(chan error, 1)
296306
go func() {
297307
defer close(confErrChan)
298308
var buf bytes.Buffer
299309
if _, err := io.CopyN(&buf, stdout, int64(len(banner_msg))); err != nil {
300-
confErrChan <- &SSHError{err, "read banner"}
310+
confErrChan <- cmdWaitErrOrIOErr(err, "read banner")
301311
return
302312
}
303313
resp := buf.Bytes()
@@ -314,7 +324,7 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn, error) {
314324
buf.Reset()
315325
buf.Write(begin_msg)
316326
if _, err := io.Copy(stdin, &buf); err != nil {
317-
confErrChan <- &SSHError{err, "send begin message"}
327+
confErrChan <- cmdWaitErrOrIOErr(err, "send begin message")
318328
return
319329
}
320330
}()
@@ -332,6 +342,9 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn, error) {
332342
for _ = range confErrChan {
333343
}
334344

345+
// TODO collect stderr in this case
346+
// can probably extend *SSHError for this but need to implement net.Error
347+
335348
return nil, dialCtx.Err()
336349

337350
case err := <-confErrChan:

0 commit comments

Comments
 (0)