Skip to content

Commit b89020e

Browse files
authored
Add SOCKS5 support
- Bundle the golang.org/x/net/proxy package to x_net_proxy.go. The package contains a SOCKS5 proxy. The package is bundled to avoid adding a dependency from the weboscket package to golang.org/x/net. - Restructure the existing HTTP proxy code so the code can be used as a dialer with the proxy package. - Modify Dialer.Dial to use proxy.FromURL. - Improve tests (avoid modifying package-level data, use timeouts in tests, use correct proxy URLs in tests). Fixes gorilla#297.
1 parent 8c6cfd4 commit b89020e

File tree

4 files changed

+677
-67
lines changed

4 files changed

+677
-67
lines changed

client.go

+36-58
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
package websocket
66

77
import (
8-
"bufio"
98
"bytes"
109
"crypto/tls"
11-
"encoding/base64"
1210
"errors"
1311
"io"
1412
"io/ioutil"
@@ -106,7 +104,7 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
106104
return hostPort, hostNoPort
107105
}
108106

109-
// DefaultDialer is a dialer with all fields set to the default zero values.
107+
// DefaultDialer is a dialer with all fields set to the default values.
110108
var DefaultDialer = &Dialer{
111109
Proxy: http.ProxyFromEnvironment,
112110
}
@@ -202,36 +200,52 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
202200
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
203201
}
204202

205-
hostPort, hostNoPort := hostPortNoPort(u)
206-
207-
var proxyURL *url.URL
208-
// Check wether the proxy method has been configured
209-
if d.Proxy != nil {
210-
proxyURL, err = d.Proxy(req)
211-
}
212-
if err != nil {
213-
return nil, nil, err
214-
}
215-
216-
var targetHostPort string
217-
if proxyURL != nil {
218-
targetHostPort, _ = hostPortNoPort(proxyURL)
219-
} else {
220-
targetHostPort = hostPort
221-
}
222-
223203
var deadline time.Time
224204
if d.HandshakeTimeout != 0 {
225205
deadline = time.Now().Add(d.HandshakeTimeout)
226206
}
227207

208+
// Get network dial function.
228209
netDial := d.NetDial
229210
if netDial == nil {
230211
netDialer := &net.Dialer{Deadline: deadline}
231212
netDial = netDialer.Dial
232213
}
233214

234-
netConn, err := netDial("tcp", targetHostPort)
215+
// If needed, wrap the dial function to set the connection deadline.
216+
if !deadline.Equal(time.Time{}) {
217+
forwardDial := netDial
218+
netDial = func(network, addr string) (net.Conn, error) {
219+
c, err := forwardDial(network, addr)
220+
if err != nil {
221+
return nil, err
222+
}
223+
err = c.SetDeadline(deadline)
224+
if err != nil {
225+
c.Close()
226+
return nil, err
227+
}
228+
return c, nil
229+
}
230+
}
231+
232+
// If needed, wrap the dial function to connect through a proxy.
233+
if d.Proxy != nil {
234+
proxyURL, err := d.Proxy(req)
235+
if err != nil {
236+
return nil, nil, err
237+
}
238+
if proxyURL != nil {
239+
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
240+
if err != nil {
241+
return nil, nil, err
242+
}
243+
netDial = dialer.Dial
244+
}
245+
}
246+
247+
hostPort, hostNoPort := hostPortNoPort(u)
248+
netConn, err := netDial("tcp", hostPort)
235249
if err != nil {
236250
return nil, nil, err
237251
}
@@ -242,42 +256,6 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
242256
}
243257
}()
244258

245-
if err := netConn.SetDeadline(deadline); err != nil {
246-
return nil, nil, err
247-
}
248-
249-
if proxyURL != nil {
250-
connectHeader := make(http.Header)
251-
if user := proxyURL.User; user != nil {
252-
proxyUser := user.Username()
253-
if proxyPassword, passwordSet := user.Password(); passwordSet {
254-
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
255-
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
256-
}
257-
}
258-
connectReq := &http.Request{
259-
Method: "CONNECT",
260-
URL: &url.URL{Opaque: hostPort},
261-
Host: hostPort,
262-
Header: connectHeader,
263-
}
264-
265-
connectReq.Write(netConn)
266-
267-
// Read response.
268-
// Okay to use and discard buffered reader here, because
269-
// TLS server will not speak until spoken to.
270-
br := bufio.NewReader(netConn)
271-
resp, err := http.ReadResponse(br, connectReq)
272-
if err != nil {
273-
return nil, nil, err
274-
}
275-
if resp.StatusCode != 200 {
276-
f := strings.SplitN(resp.Status, " ", 2)
277-
return nil, nil, errors.New(f[1])
278-
}
279-
}
280-
281259
if u.Scheme == "https" {
282260
cfg := cloneTLSConfig(d.TLSClientConfig)
283261
if cfg.ServerName == "" {

client_server_test.go

+91-9
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
package websocket
66

77
import (
8+
"bytes"
89
"crypto/tls"
910
"crypto/x509"
1011
"encoding/base64"
12+
"encoding/binary"
1113
"io"
1214
"io/ioutil"
15+
"net"
1316
"net/http"
1417
"net/http/cookiejar"
1518
"net/http/httptest"
@@ -31,9 +34,10 @@ var cstUpgrader = Upgrader{
3134
}
3235

3336
var cstDialer = Dialer{
34-
Subprotocols: []string{"p1", "p2"},
35-
ReadBufferSize: 1024,
36-
WriteBufferSize: 1024,
37+
Subprotocols: []string{"p1", "p2"},
38+
ReadBufferSize: 1024,
39+
WriteBufferSize: 1024,
40+
HandshakeTimeout: 30 * time.Second,
3741
}
3842

3943
type cstHandler struct{ *testing.T }
@@ -143,8 +147,9 @@ func TestProxyDial(t *testing.T) {
143147
s := newServer(t)
144148
defer s.Close()
145149

146-
surl, _ := url.Parse(s.URL)
150+
surl, _ := url.Parse(s.Server.URL)
147151

152+
cstDialer := cstDialer // make local copy for modification on next line.
148153
cstDialer.Proxy = http.ProxyURL(surl)
149154

150155
connect := false
@@ -173,16 +178,16 @@ func TestProxyDial(t *testing.T) {
173178
}
174179
defer ws.Close()
175180
sendRecv(t, ws)
176-
177-
cstDialer.Proxy = http.ProxyFromEnvironment
178181
}
179182

180183
func TestProxyAuthorizationDial(t *testing.T) {
181184
s := newServer(t)
182185
defer s.Close()
183186

184-
surl, _ := url.Parse(s.URL)
187+
surl, _ := url.Parse(s.Server.URL)
185188
surl.User = url.UserPassword("username", "password")
189+
190+
cstDialer := cstDialer // make local copy for modification on next line.
186191
cstDialer.Proxy = http.ProxyURL(surl)
187192

188193
connect := false
@@ -213,8 +218,6 @@ func TestProxyAuthorizationDial(t *testing.T) {
213218
}
214219
defer ws.Close()
215220
sendRecv(t, ws)
216-
217-
cstDialer.Proxy = http.ProxyFromEnvironment
218221
}
219222

220223
func TestDial(t *testing.T) {
@@ -518,3 +521,82 @@ func TestDialCompression(t *testing.T) {
518521
defer ws.Close()
519522
sendRecv(t, ws)
520523
}
524+
525+
func TestSocksProxyDial(t *testing.T) {
526+
s := newServer(t)
527+
defer s.Close()
528+
529+
proxyListener, err := net.Listen("tcp", "127.0.0.1:0")
530+
if err != nil {
531+
t.Fatalf("listen failed: %v", err)
532+
}
533+
defer proxyListener.Close()
534+
go func() {
535+
c1, err := proxyListener.Accept()
536+
if err != nil {
537+
t.Errorf("proxy accept failed: %v", err)
538+
return
539+
}
540+
defer c1.Close()
541+
542+
c1.SetDeadline(time.Now().Add(30 * time.Second))
543+
544+
buf := make([]byte, 32)
545+
if _, err := io.ReadFull(c1, buf[:3]); err != nil {
546+
t.Errorf("read failed: %v", err)
547+
return
548+
}
549+
if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) {
550+
t.Errorf("read %x, want %x", buf[:len(want)], want)
551+
}
552+
if _, err := c1.Write([]byte{5, 0}); err != nil {
553+
t.Errorf("write failed: %v", err)
554+
return
555+
}
556+
if _, err := io.ReadFull(c1, buf[:10]); err != nil {
557+
t.Errorf("read failed: %v", err)
558+
return
559+
}
560+
if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) {
561+
t.Errorf("read %x, want %x", buf[:len(want)], want)
562+
return
563+
}
564+
buf[1] = 0
565+
if _, err := c1.Write(buf[:10]); err != nil {
566+
t.Errorf("write failed: %v", err)
567+
return
568+
}
569+
570+
ip := net.IP(buf[4:8])
571+
port := binary.BigEndian.Uint16(buf[8:10])
572+
573+
c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)})
574+
if err != nil {
575+
t.Errorf("dial failed; %v", err)
576+
return
577+
}
578+
defer c2.Close()
579+
done := make(chan struct{})
580+
go func() {
581+
io.Copy(c1, c2)
582+
close(done)
583+
}()
584+
io.Copy(c2, c1)
585+
<-done
586+
}()
587+
588+
purl, err := url.Parse("socks5://" + proxyListener.Addr().String())
589+
if err != nil {
590+
t.Fatalf("parse failed: %v", err)
591+
}
592+
593+
cstDialer := cstDialer // make local copy for modification on next line.
594+
cstDialer.Proxy = http.ProxyURL(purl)
595+
596+
ws, _, err := cstDialer.Dial(s.URL, nil)
597+
if err != nil {
598+
t.Fatalf("Dial: %v", err)
599+
}
600+
defer ws.Close()
601+
sendRecv(t, ws)
602+
}

proxy.go

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package websocket
6+
7+
import (
8+
"bufio"
9+
"encoding/base64"
10+
"errors"
11+
"net"
12+
"net/http"
13+
"net/url"
14+
"strings"
15+
)
16+
17+
type netDialerFunc func(netowrk, addr string) (net.Conn, error)
18+
19+
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
20+
return fn(network, addr)
21+
}
22+
23+
func init() {
24+
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
25+
return &httpProxyDialer{proxyURL: proxyURL, fowardDial: forwardDialer.Dial}, nil
26+
})
27+
}
28+
29+
type httpProxyDialer struct {
30+
proxyURL *url.URL
31+
fowardDial func(network, addr string) (net.Conn, error)
32+
}
33+
34+
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
35+
hostPort, _ := hostPortNoPort(hpd.proxyURL)
36+
conn, err := hpd.fowardDial(network, hostPort)
37+
if err != nil {
38+
return nil, err
39+
}
40+
41+
connectHeader := make(http.Header)
42+
if user := hpd.proxyURL.User; user != nil {
43+
proxyUser := user.Username()
44+
if proxyPassword, passwordSet := user.Password(); passwordSet {
45+
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
46+
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
47+
}
48+
}
49+
50+
connectReq := &http.Request{
51+
Method: "CONNECT",
52+
URL: &url.URL{Opaque: addr},
53+
Host: addr,
54+
Header: connectHeader,
55+
}
56+
57+
if err := connectReq.Write(conn); err != nil {
58+
conn.Close()
59+
return nil, err
60+
}
61+
62+
// Read response. It's OK to use and discard buffered reader here becaue
63+
// the remote server does not speak until spoken to.
64+
br := bufio.NewReader(conn)
65+
resp, err := http.ReadResponse(br, connectReq)
66+
if err != nil {
67+
conn.Close()
68+
return nil, err
69+
}
70+
71+
if resp.StatusCode != 200 {
72+
conn.Close()
73+
f := strings.SplitN(resp.Status, " ", 2)
74+
return nil, errors.New(f[1])
75+
}
76+
return conn, nil
77+
}

0 commit comments

Comments
 (0)