Skip to content

Commit 984ce91

Browse files
committed
Dial(): handle expiring dialCtx correctly
The returned connection must not be affected by an expired dialCtx. Make this clear through argument names.
1 parent ffa145d commit 984ce91

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

dial.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,16 @@ var banner_msg = mustMessage("SSHCON_HELO")
110110
var proxy_error_msg = mustMessage("SSHCON_PROXY_ERROR")
111111
var begin_msg = mustMessage("SSHCON_BEGIN")
112112

113-
func Dial(ctx context.Context, endpoint Endpoint) (*SSHConn , error) {
113+
// Dial connects to the remote endpoint where it expects a command executing Proxy().
114+
// Dial performs a handshake consisting of the exchange of banner messages before returning the connection.
115+
// If the handshake cannot be completed before dialCtx is Done(), the underlying ssh command is killed
116+
// and the dialCtx.Err() returned.
117+
// If the handshake completes, dialCtx's deadline does not affect the returned connection.
118+
func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn , error) {
114119

115120
sshCmd, sshArgs, sshEnv := endpoint.CmdArgs()
116-
cmd, err := rwccmd.CommandContext(ctx, sshCmd, sshArgs, sshEnv)
121+
commandCtx, commandCancel := context.WithCancel(context.Background())
122+
cmd, err := rwccmd.CommandContext(commandCtx, sshCmd, sshArgs, sshEnv)
117123
if err != nil {
118124
return nil, err
119125
}
@@ -149,10 +155,12 @@ func Dial(ctx context.Context, endpoint Endpoint) (*SSHConn , error) {
149155
}()
150156

151157
select {
152-
case <-ctx.Done():
153-
return nil, ctx.Err()
158+
case <-dialCtx.Done():
159+
commandCancel()
160+
return nil, dialCtx.Err()
154161
case err := <-confErrChan:
155162
if err != nil {
163+
commandCancel()
156164
return nil, err
157165
}
158166
}

example/cmd/connect.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ var connectArgs struct {
1818
killSSHDuration time.Duration
1919
waitBeforeRequestDuration time.Duration
2020
responseTimeout time.Duration
21+
dialTimeout time.Duration
2122
endpoint netssh.Endpoint
2223

2324
}
@@ -29,12 +30,19 @@ var connectCmd = &cobra.Command{
2930

3031
log := log.New(os.Stdout, "", log.Ltime|log.Lmicroseconds|log.Lshortfile)
3132

33+
log.Print("dialing %#v", connectArgs.endpoint)
34+
log.Printf("timeout %s", connectArgs.dialTimeout)
3235
ctx := netssh.ContextWithLog(context.TODO(), log)
3336
ctx = rwccmd.ContextWithLog(ctx, log)
34-
outstream, err := netssh.Dial(ctx, connectArgs.endpoint)
35-
if err != nil {
37+
dialCtx, dialCancel := context.WithTimeout(ctx, connectArgs.dialTimeout)
38+
outstream, err := netssh.Dial(dialCtx, connectArgs.endpoint)
39+
dialCancel()
40+
if err == context.DeadlineExceeded {
41+
log.Panic("dial timeout exceeded")
42+
} else if err != nil {
3643
log.Panic(err)
3744
}
45+
3846
defer func() {
3947
log.Printf("closing connection in defer")
4048
err := outstream.Close()
@@ -86,6 +94,7 @@ func init() {
8694
connectCmd.Flags().DurationVar(&connectArgs.killSSHDuration, "killSSH",0, "")
8795
connectCmd.Flags().DurationVar(&connectArgs.waitBeforeRequestDuration, "wait",0, "")
8896
connectCmd.Flags().DurationVar(&connectArgs.responseTimeout, "responseTimeout",math.MaxInt64, "")
97+
connectCmd.Flags().DurationVar(&connectArgs.dialTimeout, "dialTimeout",math.MaxInt64, "")
8998
connectCmd.Flags().StringVar(&connectArgs.endpoint.Host, "ssh.host", "", "")
9099
connectCmd.Flags().StringVar(&connectArgs.endpoint.User, "ssh.user", "", "")
91100
connectCmd.Flags().StringVar(&connectArgs.endpoint.IdentityFile, "ssh.identity", "", "")

0 commit comments

Comments
 (0)