Skip to content

Commit 09d6bc4

Browse files
committed
req Go 1.11, deadline support, full net.Conn + CloseWrite support
1 parent c56ad38 commit 09d6bc4

File tree

8 files changed

+557
-72
lines changed

8 files changed

+557
-72
lines changed

Diff for: Gopkg.lock

+55-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: Gopkg.toml

+3
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ ignored = ["github.com/spf13/cobra"]
88
branch = "master"
99
name = "github.com/problame/go-rwccmd"
1010

11+
[[constraint]]
12+
name = "github.com/theckman/goconstraint"
13+
version = "1.11.0"

Diff for: dial.go

+143-29
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
package netssh
22

33
import (
4-
"net"
5-
"fmt"
4+
"bytes"
65
"context"
6+
"fmt"
77
"io"
8-
"bytes"
9-
"github.com/problame/go-rwccmd"
8+
"net"
109
"os/exec"
10+
"sync"
1111
"syscall"
12+
"time"
1213
)
1314

1415
type Endpoint struct {
@@ -45,54 +46,152 @@ func (e Endpoint) CmdArgs() (cmd string, args []string, env []string) {
4546
return
4647
}
4748

48-
// FIXME: should conform to net.Conn one day, but deadlines as required by net.Conn are complicated:
49-
// it requires to keep the connection open when the deadline is exceeded, but rwcconn.Cmd does not provide Deadlines
50-
// for good reason, see their docs for details.
5149
type SSHConn struct {
52-
c *rwccmd.Cmd
50+
cmd *exec.Cmd
51+
stdin io.WriteCloser
52+
stdout io.ReadCloser
53+
54+
shutdownMtx sync.Mutex
55+
shutdownResult *shutdownResult
56+
cmdCancel context.CancelFunc
5357
}
5458

55-
const go_network string = "SSH"
59+
const go_network string = "netssh"
5660

57-
type addr struct {
61+
type clientAddr struct {
5862
pid int
5963
}
6064

61-
func (a addr) Network() string {
65+
func (a clientAddr) Network() string {
6266
return go_network
6367
}
6468

65-
func (a addr) String() string {
69+
func (a clientAddr) String() string {
6670
return fmt.Sprintf("pid=%d", a.pid)
6771
}
6872

6973
func (conn *SSHConn) LocalAddr() net.Addr {
70-
return addr{conn.c.Pid()}
74+
proc := conn.cmd.Process
75+
if proc == nil {
76+
return clientAddr{-1}
77+
}
78+
return clientAddr{proc.Pid}
7179
}
7280

7381
func (conn *SSHConn) RemoteAddr() net.Addr {
74-
return addr{conn.c.Pid()}
82+
return conn.LocalAddr()
7583
}
7684

85+
// Read implements io.Reader.
86+
// It returns *IOError for any non-nil error that is != io.EOF.
7787
func (conn *SSHConn) Read(p []byte) (int, error) {
78-
return conn.c.Read(p)
88+
n, err := conn.stdout.Read(p)
89+
if err != nil && err != io.EOF {
90+
return n, &IOError{err}
91+
}
92+
return n, nil
7993
}
8094

95+
// Write implements io.Writer.
96+
// It returns *IOError for any error != nil.
8197
func (conn *SSHConn) Write(p []byte) (int, error) {
82-
return conn.c.Write(p)
98+
n, err := conn.stdin.Write(p)
99+
if err != nil {
100+
return n, &IOError{err}
101+
}
102+
return n, nil
103+
}
104+
105+
func (conn *SSHConn) CloseWrite() error {
106+
return conn.stdin.Close()
107+
}
108+
109+
type deadliner interface {
110+
SetReadDeadline(time.Time) error
111+
SetWriteDeadline(time.Time) error
112+
}
113+
114+
func (conn *SSHConn) SetReadDeadline(t time.Time) error {
115+
// type assertion is covered by test TestExecCmdPipesDeadlineBehavior
116+
return conn.stdout.(deadliner).SetReadDeadline(t)
117+
}
118+
119+
func (conn *SSHConn) SetWriteDeadline(t time.Time) error {
120+
// type assertion is covered by test TestExecCmdPipesDeadlineBehavior
121+
return conn.stdin.(deadliner).SetWriteDeadline(t)
122+
}
123+
124+
func (conn *SSHConn) SetDeadline(t time.Time) error {
125+
// try both
126+
rerr := conn.SetReadDeadline(t)
127+
werr := conn.SetWriteDeadline(t)
128+
if rerr != nil {
129+
return rerr
130+
}
131+
if werr != nil {
132+
return werr
133+
}
134+
return nil
135+
}
136+
137+
func (conn *SSHConn) Close() error {
138+
conn.shutdownProcess()
139+
return nil // FIXME: waitError will be non-zero because we signaled it, shutdownProcess needs to distinguish that
140+
}
141+
142+
type shutdownResult struct {
143+
waitErr error
144+
}
145+
146+
func (conn *SSHConn) shutdownProcess() *shutdownResult {
147+
conn.shutdownMtx.Lock()
148+
defer conn.shutdownMtx.Unlock()
149+
150+
if conn.shutdownResult != nil {
151+
return conn.shutdownResult
152+
}
153+
154+
termSuccessful := make(chan error, 1)
155+
go func() {
156+
if err := conn.cmd.Process.Signal(syscall.SIGTERM); err != nil {
157+
// TODO log error
158+
return
159+
}
160+
termSuccessful <- conn.cmd.Wait()
161+
}()
162+
163+
timeout := time.NewTimer(1 * time.Second) // FIXME const
164+
defer timeout.Stop()
165+
166+
select {
167+
case waitErr := <-termSuccessful:
168+
conn.shutdownResult = &shutdownResult{waitErr}
169+
case <-timeout.C:
170+
conn.cmdCancel()
171+
waitErr := conn.cmd.Wait()
172+
conn.shutdownResult = &shutdownResult{waitErr}
173+
}
174+
return conn.shutdownResult
83175
}
84176

85-
func (conn *SSHConn) Close() (error) {
86-
return conn.c.Close()
177+
// Cmd returns the underlying *exec.Cmd (the ssh client process)
178+
// Use read-only, should not be necessary for regular users.
179+
func (conn *SSHConn) Cmd() *exec.Cmd {
180+
return conn.cmd
87181
}
88182

89-
// Use at your own risk...
90-
func (conn *SSHConn) Cmd() *rwccmd.Cmd {
91-
return conn.c
183+
// CmdCancel bypasses the normal shutdown mechanism of SSHConn
184+
// (that is, calling Close) and cancels the process's context,
185+
// which usually results in SIGKILL being sent to the process.
186+
// Intended for integration tests, regular users shouldn't use it.
187+
func (conn *SSHConn) CmdCancel() {
188+
conn.cmdCancel()
92189
}
93190

94191
const bannerMessageLen = 31
192+
95193
var messages = make(map[string][]byte)
194+
96195
func mustMessage(str string) []byte {
97196
if len(str) > bannerMessageLen {
98197
panic("message length must be smaller than bannerMessageLen")
@@ -108,12 +207,13 @@ func mustMessage(str string) []byte {
108207
buf.Write(bytes.Repeat([]byte{0}, bannerMessageLen-n))
109208
return buf.Bytes()
110209
}
210+
111211
var banner_msg = mustMessage("SSHCON_HELO")
112212
var proxy_error_msg = mustMessage("SSHCON_PROXY_ERROR")
113213
var begin_msg = mustMessage("SSHCON_BEGIN")
114214

115215
type SSHError struct {
116-
RWCError error
216+
RWCError error
117217
WhileActivity string
118218
}
119219

@@ -172,23 +272,31 @@ func (e ProtocolError) Error() string {
172272
// If the handshake completes, dialCtx's deadline does not affect the returned connection.
173273
//
174274
// Errors returned are either dialCtx.Err(), or intances of ProtocolError or *SSHError
175-
func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn , error) {
275+
func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn, error) {
176276

177277
sshCmd, sshArgs, sshEnv := endpoint.CmdArgs()
178278
commandCtx, commandCancel := context.WithCancel(context.Background())
179-
cmd, err := rwccmd.CommandContext(commandCtx, sshCmd, sshArgs, sshEnv)
279+
cmd := exec.CommandContext(commandCtx, sshCmd, sshArgs...)
280+
cmd.Env = sshEnv
281+
stdin, err := cmd.StdinPipe()
282+
if err != nil {
283+
return nil, err
284+
}
285+
stdout, err := cmd.StdoutPipe()
180286
if err != nil {
181287
return nil, err
182288
}
289+
// stderr is required for *exec.ExitErr
290+
183291
if err = cmd.Start(); err != nil {
184292
return nil, err
185293
}
186294

187-
confErrChan := make(chan error)
295+
confErrChan := make(chan error, 1)
188296
go func() {
189297
defer close(confErrChan)
190298
var buf bytes.Buffer
191-
if _, err := io.CopyN(&buf, cmd, int64(len(banner_msg))); err != nil {
299+
if _, err := io.CopyN(&buf, stdout, int64(len(banner_msg))); err != nil {
192300
confErrChan <- &SSHError{err, "read banner"}
193301
return
194302
}
@@ -205,7 +313,7 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn , error) {
205313
}
206314
buf.Reset()
207315
buf.Write(begin_msg)
208-
if _, err := io.Copy(cmd, &buf); err != nil {
316+
if _, err := io.Copy(stdin, &buf); err != nil {
209317
confErrChan <- &SSHError{err, "send begin message"}
210318
return
211319
}
@@ -221,7 +329,8 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn , error) {
221329
// ignore the error and return the cancellation cause
222330

223331
// draining always terminates because we know the channel is always closed
224-
for _ = range confErrChan {}
332+
for _ = range confErrChan {
333+
}
225334

226335
return nil, dialCtx.Err()
227336

@@ -232,5 +341,10 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn , error) {
232341
}
233342
}
234343

235-
return &SSHConn{cmd}, nil
344+
return &SSHConn{
345+
cmd: cmd,
346+
stdin: stdin,
347+
stdout: stdout,
348+
cmdCancel: commandCancel,
349+
}, nil
236350
}

0 commit comments

Comments
 (0)