@@ -10,6 +10,8 @@ import (
10
10
"sync"
11
11
"syscall"
12
12
"time"
13
+
14
+ "github.com/zrepl/zrepl/util/circlog"
13
15
)
14
16
15
17
type Endpoint struct {
@@ -168,7 +170,7 @@ func (conn *SSHConn) shutdownProcess() *shutdownResult {
168
170
conn .shutdownResult = & shutdownResult {waitErr }
169
171
case <- timeout .C :
170
172
conn .cmdCancel ()
171
- waitErr := <- wait // reuse existing Wait invocation, must not call twice
173
+ waitErr := <- wait // reuse existing Wait invocation, must not call twice
172
174
conn .shutdownResult = & shutdownResult {waitErr }
173
175
}
174
176
return conn .shutdownResult
@@ -220,11 +222,6 @@ type SSHError struct {
220
222
// Error() will try to present a one-line error message unless ssh stderr output is longer than one line
221
223
func (e * SSHError ) Error () string {
222
224
223
- if e .RWCError == io .EOF {
224
- // rwccmd returns io.EOF on exit status 0, but we do not expect ssh to do that
225
- return fmt .Sprintf ("ssh exited unexpectedly with exit status 0" )
226
- }
227
-
228
225
exitErr , ok := e .RWCError .(* exec.ExitError )
229
226
if ! ok {
230
227
return fmt .Sprintf ("ssh: %s" , e .RWCError )
@@ -286,18 +283,31 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn, error) {
286
283
if err != nil {
287
284
return nil , err
288
285
}
289
- // stderr is required for *exec.ExitErr
286
+
287
+ stderrBuf , err := circlog .NewCircularLog (1 << 15 )
288
+ if err != nil {
289
+ panic (err ) // wrong API usage
290
+ }
291
+ cmd .Stderr = stderrBuf
290
292
291
293
if err = cmd .Start (); err != nil {
292
294
return nil , err
293
295
}
296
+ cmdWaitErrOrIOErr := func (ioErr error , what string ) * SSHError {
297
+ werr := cmd .Wait ()
298
+ if werr , ok := werr .(* exec.ExitError ); ok {
299
+ werr .Stderr = []byte (stderrBuf .String ())
300
+ return & SSHError {werr , what }
301
+ }
302
+ return & SSHError {ioErr , what }
303
+ }
294
304
295
305
confErrChan := make (chan error , 1 )
296
306
go func () {
297
307
defer close (confErrChan )
298
308
var buf bytes.Buffer
299
309
if _ , err := io .CopyN (& buf , stdout , int64 (len (banner_msg ))); err != nil {
300
- confErrChan <- & SSHError { err , "read banner" }
310
+ confErrChan <- cmdWaitErrOrIOErr ( err , "read banner" )
301
311
return
302
312
}
303
313
resp := buf .Bytes ()
@@ -314,7 +324,7 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn, error) {
314
324
buf .Reset ()
315
325
buf .Write (begin_msg )
316
326
if _ , err := io .Copy (stdin , & buf ); err != nil {
317
- confErrChan <- & SSHError { err , "send begin message" }
327
+ confErrChan <- cmdWaitErrOrIOErr ( err , "send begin message" )
318
328
return
319
329
}
320
330
}()
@@ -332,6 +342,9 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn, error) {
332
342
for _ = range confErrChan {
333
343
}
334
344
345
+ // TODO collect stderr in this case
346
+ // can probably extend *SSHError for this but need to implement net.Error
347
+
335
348
return nil , dialCtx .Err ()
336
349
337
350
case err := <- confErrChan :
0 commit comments