@@ -3,6 +3,7 @@ package cmd
3
3
import (
4
4
"bytes"
5
5
"context"
6
+ "fmt"
6
7
"io"
7
8
"io/ioutil"
8
9
"log"
@@ -20,83 +21,103 @@ var connectArgs struct {
20
21
responseTimeout time.Duration
21
22
dialTimeout time.Duration
22
23
endpoint netssh.Endpoint
24
+ numAttempts int
25
+ attemptInterval time.Duration
26
+ }
27
+
28
+ func init () {
29
+ RootCmd .AddCommand (connectCmd )
30
+ connectCmd .Flags ().DurationVar (& connectArgs .killSSHDuration , "killSSH" , 0 , "" )
31
+ connectCmd .Flags ().DurationVar (& connectArgs .waitBeforeRequestDuration , "wait" , 0 , "" )
32
+ connectCmd .Flags ().DurationVar (& connectArgs .responseTimeout , "responseTimeout" , math .MaxInt64 , "" )
33
+ connectCmd .Flags ().DurationVar (& connectArgs .dialTimeout , "dialTimeout" , math .MaxInt64 , "" )
34
+ connectCmd .Flags ().StringVar (& connectArgs .endpoint .Host , "ssh.host" , "" , "" )
35
+ connectCmd .Flags ().StringVar (& connectArgs .endpoint .User , "ssh.user" , "" , "" )
36
+ connectCmd .Flags ().StringVar (& connectArgs .endpoint .IdentityFile , "ssh.identity" , "" , "" )
37
+ connectCmd .Flags ().Uint16Var (& connectArgs .endpoint .Port , "ssh.port" , 22 , "" )
38
+ connectCmd .Flags ().IntVar (& connectArgs .numAttempts , "attempts.count" , 1 , "number of connection attempts, 0 for infinite" )
39
+ connectCmd .Flags ().DurationVar (& connectArgs .attemptInterval , "attempts.interval" , 1 * time .Second , "sleep between connection attempts" )
23
40
}
24
41
25
42
var connectCmd = & cobra.Command {
26
43
Use : "connect" ,
27
44
Short : "connect to server over SSH using proxy" ,
28
45
Run : func (cmd * cobra.Command , args []string ) {
29
-
30
46
log := log .New (os .Stdout , "" , log .Ltime | log .Lmicroseconds | log .Lshortfile )
31
47
32
- log .Printf ("dialing %#v" , connectArgs .endpoint )
33
- log .Printf ("timeout %s" , connectArgs .dialTimeout )
34
- ctx := netssh .ContextWithLog (context .TODO (), log )
35
- dialCtx , dialCancel := context .WithTimeout (ctx , connectArgs .dialTimeout )
36
- outstream , err := netssh .Dial (dialCtx , connectArgs .endpoint )
37
- dialCancel ()
38
- if err == context .DeadlineExceeded {
39
- log .Panic ("dial timeout exceeded" )
40
- } else if err != nil {
41
- log .Panic (err )
48
+ lastPanicked := false
49
+ for a := 0 ; connectArgs .numAttempts == 0 || a < connectArgs .numAttempts ; a ++ {
50
+ log .SetPrefix (fmt .Sprintf ("attempt %03d: " , a ))
51
+ func () {
52
+ defer func () {
53
+ e := recover ()
54
+ lastPanicked = e != nil
55
+ log .Printf ("panicked=%v %s" , lastPanicked , e )
56
+ }()
57
+ connectAttempt (log )
58
+ }()
59
+ time .Sleep (connectArgs .attemptInterval )
42
60
}
43
61
44
- defer func () {
45
- log .Printf ("closing connection in defer" )
46
- err := outstream .Close ()
47
- if err != nil {
48
- log .Printf ("error closing connection in defer: %s" , err )
49
- }
50
- }()
62
+ },
63
+ }
51
64
52
- if connectArgs .killSSHDuration != 0 {
53
- go func () {
54
- time .Sleep (connectArgs .killSSHDuration )
55
- log .Printf ("killing ssh process" )
56
- outstream .CmdCancel ()
57
- }()
58
- }
65
+ func connectAttempt (log * log.Logger ) {
59
66
60
- time .Sleep (connectArgs .waitBeforeRequestDuration )
67
+ log .Printf ("dialing %#v" , connectArgs .endpoint )
68
+ log .Printf ("timeout %s" , connectArgs .dialTimeout )
69
+ ctx := netssh .ContextWithLog (context .TODO (), log )
70
+ dialCtx , dialCancel := context .WithTimeout (ctx , connectArgs .dialTimeout )
71
+ outstream , err := netssh .Dial (dialCtx , connectArgs .endpoint )
72
+ dialCancel ()
73
+ if err == context .DeadlineExceeded {
74
+ log .Panic ("dial timeout exceeded" )
75
+ } else if err != nil {
76
+ log .Panic (err )
77
+ }
61
78
62
- log .Print ("writing request" )
63
- n , err := outstream .Write ([]byte ("b\n " ))
64
- if n != 2 || err != nil {
65
- log .Panic (err )
66
- }
67
- log .Print ("read response" )
68
- _ , err = io .CopyN (ioutil .Discard , outstream , int64 (Bytecount ))
79
+ defer func () {
80
+ log .Printf ("closing connection in defer" )
81
+ err := outstream .Close ()
69
82
if err != nil {
70
- log .Panic ( err )
83
+ log .Printf ( "error closing connection in defer: %s" , err )
71
84
}
85
+ }()
72
86
73
- log .Print ("request for close" )
74
- n , err = outstream .Write ([]byte ("a\n " ))
75
- if n != 2 || err != nil {
76
- log .Panic (err )
77
- }
78
- log .Printf ("wait for close message" )
79
- var resp [2 ]byte
80
- n , err = outstream .Read (resp [:])
81
- if n != 2 || err != nil {
82
- log .Panic (err )
83
- }
84
- if bytes .Compare (resp [:], []byte ("A\n " )) != 0 {
85
- log .Panicf ("unexpected close message: %v" , resp )
86
- }
87
- log .Printf ("received close message" )
87
+ if connectArgs .killSSHDuration != 0 {
88
+ go func () {
89
+ time .Sleep (connectArgs .killSSHDuration )
90
+ log .Printf ("killing ssh process" )
91
+ outstream .CmdCancel ()
92
+ }()
93
+ }
88
94
89
- },
90
- }
95
+ time .Sleep (connectArgs .waitBeforeRequestDuration )
91
96
92
- func init () {
93
- RootCmd .AddCommand (connectCmd )
94
- connectCmd .Flags ().DurationVar (& connectArgs .killSSHDuration , "killSSH" , 0 , "" )
95
- connectCmd .Flags ().DurationVar (& connectArgs .waitBeforeRequestDuration , "wait" , 0 , "" )
96
- connectCmd .Flags ().DurationVar (& connectArgs .responseTimeout , "responseTimeout" , math .MaxInt64 , "" )
97
- connectCmd .Flags ().DurationVar (& connectArgs .dialTimeout , "dialTimeout" , math .MaxInt64 , "" )
98
- connectCmd .Flags ().StringVar (& connectArgs .endpoint .Host , "ssh.host" , "" , "" )
99
- connectCmd .Flags ().StringVar (& connectArgs .endpoint .User , "ssh.user" , "" , "" )
100
- connectCmd .Flags ().StringVar (& connectArgs .endpoint .IdentityFile , "ssh.identity" , "" , "" )
101
- connectCmd .Flags ().Uint16Var (& connectArgs .endpoint .Port , "ssh.port" , 22 , "" )
97
+ log .Print ("writing request" )
98
+ n , err := outstream .Write ([]byte ("b\n " ))
99
+ if n != 2 || err != nil {
100
+ log .Panic (err )
101
+ }
102
+ log .Print ("read response" )
103
+ _ , err = io .CopyN (ioutil .Discard , outstream , int64 (Bytecount ))
104
+ if err != nil {
105
+ log .Panic (err )
106
+ }
107
+
108
+ log .Print ("request for close" )
109
+ n , err = outstream .Write ([]byte ("a\n " ))
110
+ if n != 2 || err != nil {
111
+ log .Panic (err )
112
+ }
113
+ log .Printf ("wait for close message" )
114
+ var resp [2 ]byte
115
+ n , err = outstream .Read (resp [:])
116
+ if n != 2 || err != nil {
117
+ log .Panic (err )
118
+ }
119
+ if bytes .Compare (resp [:], []byte ("A\n " )) != 0 {
120
+ log .Panicf ("unexpected close message: %v" , resp )
121
+ }
122
+ log .Printf ("received close message" )
102
123
}
0 commit comments