1
1
package netssh
2
2
3
3
import (
4
- "net"
5
- "fmt"
4
+ "bytes"
6
5
"context"
6
+ "fmt"
7
7
"io"
8
- "bytes"
9
- "github.com/problame/go-rwccmd"
8
+ "net"
10
9
"os/exec"
10
+ "sync"
11
11
"syscall"
12
+ "time"
12
13
)
13
14
14
15
type Endpoint struct {
@@ -45,54 +46,152 @@ func (e Endpoint) CmdArgs() (cmd string, args []string, env []string) {
45
46
return
46
47
}
47
48
48
- // FIXME: should conform to net.Conn one day, but deadlines as required by net.Conn are complicated:
49
- // it requires to keep the connection open when the deadline is exceeded, but rwcconn.Cmd does not provide Deadlines
50
- // for good reason, see their docs for details.
51
49
type SSHConn struct {
52
- c * rwccmd.Cmd
50
+ cmd * exec.Cmd
51
+ stdin io.WriteCloser
52
+ stdout io.ReadCloser
53
+
54
+ shutdownMtx sync.Mutex
55
+ shutdownResult * shutdownResult
56
+ cmdCancel context.CancelFunc
53
57
}
54
58
55
- const go_network string = "SSH "
59
+ const go_network string = "netssh "
56
60
57
- type addr struct {
61
+ type clientAddr struct {
58
62
pid int
59
63
}
60
64
61
- func (a addr ) Network () string {
65
+ func (a clientAddr ) Network () string {
62
66
return go_network
63
67
}
64
68
65
- func (a addr ) String () string {
69
+ func (a clientAddr ) String () string {
66
70
return fmt .Sprintf ("pid=%d" , a .pid )
67
71
}
68
72
69
73
func (conn * SSHConn ) LocalAddr () net.Addr {
70
- return addr {conn .c .Pid ()}
74
+ proc := conn .cmd .Process
75
+ if proc == nil {
76
+ return clientAddr {- 1 }
77
+ }
78
+ return clientAddr {proc .Pid }
71
79
}
72
80
73
81
func (conn * SSHConn ) RemoteAddr () net.Addr {
74
- return addr { conn .c . Pid ()}
82
+ return conn .LocalAddr ()
75
83
}
76
84
85
+ // Read implements io.Reader.
86
+ // It returns *IOError for any non-nil error that is != io.EOF.
77
87
func (conn * SSHConn ) Read (p []byte ) (int , error ) {
78
- return conn .c .Read (p )
88
+ n , err := conn .stdout .Read (p )
89
+ if err != nil && err != io .EOF {
90
+ return n , & IOError {err }
91
+ }
92
+ return n , nil
79
93
}
80
94
95
+ // Write implements io.Writer.
96
+ // It returns *IOError for any error != nil.
81
97
func (conn * SSHConn ) Write (p []byte ) (int , error ) {
82
- return conn .c .Write (p )
98
+ n , err := conn .stdin .Write (p )
99
+ if err != nil {
100
+ return n , & IOError {err }
101
+ }
102
+ return n , nil
103
+ }
104
+
105
+ func (conn * SSHConn ) CloseWrite () error {
106
+ return conn .stdin .Close ()
107
+ }
108
+
109
+ type deadliner interface {
110
+ SetReadDeadline (time.Time ) error
111
+ SetWriteDeadline (time.Time ) error
112
+ }
113
+
114
+ func (conn * SSHConn ) SetReadDeadline (t time.Time ) error {
115
+ // type assertion is covered by test TestExecCmdPipesDeadlineBehavior
116
+ return conn .stdout .(deadliner ).SetReadDeadline (t )
117
+ }
118
+
119
+ func (conn * SSHConn ) SetWriteDeadline (t time.Time ) error {
120
+ // type assertion is covered by test TestExecCmdPipesDeadlineBehavior
121
+ return conn .stdin .(deadliner ).SetWriteDeadline (t )
122
+ }
123
+
124
+ func (conn * SSHConn ) SetDeadline (t time.Time ) error {
125
+ // try both
126
+ rerr := conn .SetReadDeadline (t )
127
+ werr := conn .SetWriteDeadline (t )
128
+ if rerr != nil {
129
+ return rerr
130
+ }
131
+ if werr != nil {
132
+ return werr
133
+ }
134
+ return nil
135
+ }
136
+
137
+ func (conn * SSHConn ) Close () error {
138
+ conn .shutdownProcess ()
139
+ return nil // FIXME: waitError will be non-zero because we signaled it, shutdownProcess needs to distinguish that
140
+ }
141
+
142
+ type shutdownResult struct {
143
+ waitErr error
144
+ }
145
+
146
+ func (conn * SSHConn ) shutdownProcess () * shutdownResult {
147
+ conn .shutdownMtx .Lock ()
148
+ defer conn .shutdownMtx .Unlock ()
149
+
150
+ if conn .shutdownResult != nil {
151
+ return conn .shutdownResult
152
+ }
153
+
154
+ termSuccessful := make (chan error , 1 )
155
+ go func () {
156
+ if err := conn .cmd .Process .Signal (syscall .SIGTERM ); err != nil {
157
+ // TODO log error
158
+ return
159
+ }
160
+ termSuccessful <- conn .cmd .Wait ()
161
+ }()
162
+
163
+ timeout := time .NewTimer (1 * time .Second ) // FIXME const
164
+ defer timeout .Stop ()
165
+
166
+ select {
167
+ case waitErr := <- termSuccessful :
168
+ conn .shutdownResult = & shutdownResult {waitErr }
169
+ case <- timeout .C :
170
+ conn .cmdCancel ()
171
+ waitErr := conn .cmd .Wait ()
172
+ conn .shutdownResult = & shutdownResult {waitErr }
173
+ }
174
+ return conn .shutdownResult
83
175
}
84
176
85
- func (conn * SSHConn ) Close () (error ) {
86
- return conn .c .Close ()
177
+ // Cmd returns the underlying *exec.Cmd (the ssh client process)
178
+ // Use read-only, should not be necessary for regular users.
179
+ func (conn * SSHConn ) Cmd () * exec.Cmd {
180
+ return conn .cmd
87
181
}
88
182
89
- // Use at your own risk...
90
- func (conn * SSHConn ) Cmd () * rwccmd.Cmd {
91
- return conn .c
183
+ // CmdCancel bypasses the normal shutdown mechanism of SSHConn
184
+ // (that is, calling Close) and cancels the process's context,
185
+ // which usually results in SIGKILL being sent to the process.
186
+ // Intended for integration tests, regular users shouldn't use it.
187
+ func (conn * SSHConn ) CmdCancel () {
188
+ conn .cmdCancel ()
92
189
}
93
190
94
191
const bannerMessageLen = 31
192
+
95
193
var messages = make (map [string ][]byte )
194
+
96
195
func mustMessage (str string ) []byte {
97
196
if len (str ) > bannerMessageLen {
98
197
panic ("message length must be smaller than bannerMessageLen" )
@@ -108,12 +207,13 @@ func mustMessage(str string) []byte {
108
207
buf .Write (bytes .Repeat ([]byte {0 }, bannerMessageLen - n ))
109
208
return buf .Bytes ()
110
209
}
210
+
111
211
var banner_msg = mustMessage ("SSHCON_HELO" )
112
212
var proxy_error_msg = mustMessage ("SSHCON_PROXY_ERROR" )
113
213
var begin_msg = mustMessage ("SSHCON_BEGIN" )
114
214
115
215
type SSHError struct {
116
- RWCError error
216
+ RWCError error
117
217
WhileActivity string
118
218
}
119
219
@@ -172,23 +272,31 @@ func (e ProtocolError) Error() string {
172
272
// If the handshake completes, dialCtx's deadline does not affect the returned connection.
173
273
//
174
274
// Errors returned are either dialCtx.Err(), or intances of ProtocolError or *SSHError
175
- func Dial (dialCtx context.Context , endpoint Endpoint ) (* SSHConn , error ) {
275
+ func Dial (dialCtx context.Context , endpoint Endpoint ) (* SSHConn , error ) {
176
276
177
277
sshCmd , sshArgs , sshEnv := endpoint .CmdArgs ()
178
278
commandCtx , commandCancel := context .WithCancel (context .Background ())
179
- cmd , err := rwccmd .CommandContext (commandCtx , sshCmd , sshArgs , sshEnv )
279
+ cmd := exec .CommandContext (commandCtx , sshCmd , sshArgs ... )
280
+ cmd .Env = sshEnv
281
+ stdin , err := cmd .StdinPipe ()
282
+ if err != nil {
283
+ return nil , err
284
+ }
285
+ stdout , err := cmd .StdoutPipe ()
180
286
if err != nil {
181
287
return nil , err
182
288
}
289
+ // stderr is required for *exec.ExitErr
290
+
183
291
if err = cmd .Start (); err != nil {
184
292
return nil , err
185
293
}
186
294
187
- confErrChan := make (chan error )
295
+ confErrChan := make (chan error , 1 )
188
296
go func () {
189
297
defer close (confErrChan )
190
298
var buf bytes.Buffer
191
- if _ , err := io .CopyN (& buf , cmd , int64 (len (banner_msg ))); err != nil {
299
+ if _ , err := io .CopyN (& buf , stdout , int64 (len (banner_msg ))); err != nil {
192
300
confErrChan <- & SSHError {err , "read banner" }
193
301
return
194
302
}
@@ -205,7 +313,7 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn , error) {
205
313
}
206
314
buf .Reset ()
207
315
buf .Write (begin_msg )
208
- if _ , err := io .Copy (cmd , & buf ); err != nil {
316
+ if _ , err := io .Copy (stdin , & buf ); err != nil {
209
317
confErrChan <- & SSHError {err , "send begin message" }
210
318
return
211
319
}
@@ -221,7 +329,8 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn , error) {
221
329
// ignore the error and return the cancellation cause
222
330
223
331
// draining always terminates because we know the channel is always closed
224
- for _ = range confErrChan {}
332
+ for _ = range confErrChan {
333
+ }
225
334
226
335
return nil , dialCtx .Err ()
227
336
@@ -232,5 +341,10 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn , error) {
232
341
}
233
342
}
234
343
235
- return & SSHConn {cmd }, nil
344
+ return & SSHConn {
345
+ cmd : cmd ,
346
+ stdin : stdin ,
347
+ stdout : stdout ,
348
+ cmdCancel : commandCancel ,
349
+ }, nil
236
350
}
0 commit comments