Skip to content

Commit d9b26a4

Browse files
committed
example: option for running multiple attempts
1 parent 18d8aa6 commit d9b26a4

File tree

2 files changed

+82
-62
lines changed

2 files changed

+82
-62
lines changed

example/cmd/connect.go

+82-61
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cmd
33
import (
44
"bytes"
55
"context"
6+
"fmt"
67
"io"
78
"io/ioutil"
89
"log"
@@ -20,83 +21,103 @@ var connectArgs struct {
2021
responseTimeout time.Duration
2122
dialTimeout time.Duration
2223
endpoint netssh.Endpoint
24+
numAttempts int
25+
attemptInterval time.Duration
26+
}
27+
28+
func init() {
29+
RootCmd.AddCommand(connectCmd)
30+
connectCmd.Flags().DurationVar(&connectArgs.killSSHDuration, "killSSH", 0, "")
31+
connectCmd.Flags().DurationVar(&connectArgs.waitBeforeRequestDuration, "wait", 0, "")
32+
connectCmd.Flags().DurationVar(&connectArgs.responseTimeout, "responseTimeout", math.MaxInt64, "")
33+
connectCmd.Flags().DurationVar(&connectArgs.dialTimeout, "dialTimeout", math.MaxInt64, "")
34+
connectCmd.Flags().StringVar(&connectArgs.endpoint.Host, "ssh.host", "", "")
35+
connectCmd.Flags().StringVar(&connectArgs.endpoint.User, "ssh.user", "", "")
36+
connectCmd.Flags().StringVar(&connectArgs.endpoint.IdentityFile, "ssh.identity", "", "")
37+
connectCmd.Flags().Uint16Var(&connectArgs.endpoint.Port, "ssh.port", 22, "")
38+
connectCmd.Flags().IntVar(&connectArgs.numAttempts, "attempts.count", 1, "number of connection attempts, 0 for infinite")
39+
connectCmd.Flags().DurationVar(&connectArgs.attemptInterval, "attempts.interval", 1*time.Second, "sleep between connection attempts")
2340
}
2441

2542
var connectCmd = &cobra.Command{
2643
Use: "connect",
2744
Short: "connect to server over SSH using proxy",
2845
Run: func(cmd *cobra.Command, args []string) {
29-
3046
log := log.New(os.Stdout, "", log.Ltime|log.Lmicroseconds|log.Lshortfile)
3147

32-
log.Printf("dialing %#v", connectArgs.endpoint)
33-
log.Printf("timeout %s", connectArgs.dialTimeout)
34-
ctx := netssh.ContextWithLog(context.TODO(), log)
35-
dialCtx, dialCancel := context.WithTimeout(ctx, connectArgs.dialTimeout)
36-
outstream, err := netssh.Dial(dialCtx, connectArgs.endpoint)
37-
dialCancel()
38-
if err == context.DeadlineExceeded {
39-
log.Panic("dial timeout exceeded")
40-
} else if err != nil {
41-
log.Panic(err)
48+
lastPanicked := false
49+
for a := 0; connectArgs.numAttempts == 0 || a < connectArgs.numAttempts; a++ {
50+
log.SetPrefix(fmt.Sprintf("attempt %03d: ", a))
51+
func() {
52+
defer func() {
53+
e := recover()
54+
lastPanicked = e != nil
55+
log.Printf("panicked=%v %s", lastPanicked, e)
56+
}()
57+
connectAttempt(log)
58+
}()
59+
time.Sleep(connectArgs.attemptInterval)
4260
}
4361

44-
defer func() {
45-
log.Printf("closing connection in defer")
46-
err := outstream.Close()
47-
if err != nil {
48-
log.Printf("error closing connection in defer: %s", err)
49-
}
50-
}()
62+
},
63+
}
5164

52-
if connectArgs.killSSHDuration != 0 {
53-
go func() {
54-
time.Sleep(connectArgs.killSSHDuration)
55-
log.Printf("killing ssh process")
56-
outstream.CmdCancel()
57-
}()
58-
}
65+
func connectAttempt(log *log.Logger) {
5966

60-
time.Sleep(connectArgs.waitBeforeRequestDuration)
67+
log.Printf("dialing %#v", connectArgs.endpoint)
68+
log.Printf("timeout %s", connectArgs.dialTimeout)
69+
ctx := netssh.ContextWithLog(context.TODO(), log)
70+
dialCtx, dialCancel := context.WithTimeout(ctx, connectArgs.dialTimeout)
71+
outstream, err := netssh.Dial(dialCtx, connectArgs.endpoint)
72+
dialCancel()
73+
if err == context.DeadlineExceeded {
74+
log.Panic("dial timeout exceeded")
75+
} else if err != nil {
76+
log.Panic(err)
77+
}
6178

62-
log.Print("writing request")
63-
n, err := outstream.Write([]byte("b\n"))
64-
if n != 2 || err != nil {
65-
log.Panic(err)
66-
}
67-
log.Print("read response")
68-
_, err = io.CopyN(ioutil.Discard, outstream, int64(Bytecount))
79+
defer func() {
80+
log.Printf("closing connection in defer")
81+
err := outstream.Close()
6982
if err != nil {
70-
log.Panic(err)
83+
log.Printf("error closing connection in defer: %s", err)
7184
}
85+
}()
7286

73-
log.Print("request for close")
74-
n, err = outstream.Write([]byte("a\n"))
75-
if n != 2 || err != nil {
76-
log.Panic(err)
77-
}
78-
log.Printf("wait for close message")
79-
var resp [2]byte
80-
n, err = outstream.Read(resp[:])
81-
if n != 2 || err != nil {
82-
log.Panic(err)
83-
}
84-
if bytes.Compare(resp[:], []byte("A\n")) != 0 {
85-
log.Panicf("unexpected close message: %v", resp)
86-
}
87-
log.Printf("received close message")
87+
if connectArgs.killSSHDuration != 0 {
88+
go func() {
89+
time.Sleep(connectArgs.killSSHDuration)
90+
log.Printf("killing ssh process")
91+
outstream.CmdCancel()
92+
}()
93+
}
8894

89-
},
90-
}
95+
time.Sleep(connectArgs.waitBeforeRequestDuration)
9196

92-
func init() {
93-
RootCmd.AddCommand(connectCmd)
94-
connectCmd.Flags().DurationVar(&connectArgs.killSSHDuration, "killSSH", 0, "")
95-
connectCmd.Flags().DurationVar(&connectArgs.waitBeforeRequestDuration, "wait", 0, "")
96-
connectCmd.Flags().DurationVar(&connectArgs.responseTimeout, "responseTimeout", math.MaxInt64, "")
97-
connectCmd.Flags().DurationVar(&connectArgs.dialTimeout, "dialTimeout", math.MaxInt64, "")
98-
connectCmd.Flags().StringVar(&connectArgs.endpoint.Host, "ssh.host", "", "")
99-
connectCmd.Flags().StringVar(&connectArgs.endpoint.User, "ssh.user", "", "")
100-
connectCmd.Flags().StringVar(&connectArgs.endpoint.IdentityFile, "ssh.identity", "", "")
101-
connectCmd.Flags().Uint16Var(&connectArgs.endpoint.Port, "ssh.port", 22, "")
97+
log.Print("writing request")
98+
n, err := outstream.Write([]byte("b\n"))
99+
if n != 2 || err != nil {
100+
log.Panic(err)
101+
}
102+
log.Print("read response")
103+
_, err = io.CopyN(ioutil.Discard, outstream, int64(Bytecount))
104+
if err != nil {
105+
log.Panic(err)
106+
}
107+
108+
log.Print("request for close")
109+
n, err = outstream.Write([]byte("a\n"))
110+
if n != 2 || err != nil {
111+
log.Panic(err)
112+
}
113+
log.Printf("wait for close message")
114+
var resp [2]byte
115+
n, err = outstream.Read(resp[:])
116+
if n != 2 || err != nil {
117+
log.Panic(err)
118+
}
119+
if bytes.Compare(resp[:], []byte("A\n")) != 0 {
120+
log.Panicf("unexpected close message: %v", resp)
121+
}
122+
log.Printf("received close message")
102123
}

example/cmd/serve.go

-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ var serveCmd = &cobra.Command{
3333
rwc, err := listener.Accept()
3434
defer rwc.Close()
3535

36-
3736
log.Print("urandom")
3837
rand, err := os.Open("/dev/urandom")
3938
if err != nil {

0 commit comments

Comments
 (0)