Skip to content

Commit a053e79

Browse files
MrWakogopherbot
authored andcommitted
net/http: support TCP half-close when HTTP is upgraded in ReverseProxy
This CL propagates closing the write stream from either side of the reverse proxy and ensures the proxy waits for both copy-to and the copy-from the backend to complete. The new unit test checks communication through the reverse proxy when the backend or frontend closes either the read or write streams. That closing the write stream is propagated through the proxy from either the backend or the frontend. That closing the read stream is not propagated through the proxy. Fixes #35892 Change-Id: I83ce377df66a0f17b9ba2b53caf9e4991a95f6a0 Reviewed-on: https://go-review.googlesource.com/c/go/+/637939 Reviewed-by: Michael Pratt <[email protected]> Reviewed-by: Sean Liao <[email protected]> Auto-Submit: Sean Liao <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Damien Neil <[email protected]> Reviewed-by: Matej Kramny <[email protected]>
1 parent 7181118 commit a053e79

File tree

3 files changed

+184
-5
lines changed

3 files changed

+184
-5
lines changed

src/net/http/httputil/reverseproxy.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,15 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
793793
spc := switchProtocolCopier{user: conn, backend: backConn}
794794
go spc.copyToBackend(errc)
795795
go spc.copyFromBackend(errc)
796-
<-errc
796+
797+
// wait until both copy functions have sent on the error channel
798+
err := <-errc
799+
if err == nil {
800+
err = <-errc
801+
}
802+
if err != nil {
803+
p.getErrorHandler()(rw, req, fmt.Errorf("can't copy: %v", err))
804+
}
797805
}
798806

799807
// switchProtocolCopier exists so goroutines proxying data back and
@@ -803,13 +811,33 @@ type switchProtocolCopier struct {
803811
}
804812

805813
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
806-
_, err := io.Copy(c.user, c.backend)
807-
errc <- err
814+
if _, err := io.Copy(c.user, c.backend); err != nil {
815+
errc <- err
816+
return
817+
}
818+
819+
// backend conn has reached EOF so propogate close write to user conn
820+
if wc, ok := c.user.(interface{ CloseWrite() error }); ok {
821+
errc <- wc.CloseWrite()
822+
return
823+
}
824+
825+
errc <- nil
808826
}
809827

810828
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
811-
_, err := io.Copy(c.backend, c.user)
812-
errc <- err
829+
if _, err := io.Copy(c.backend, c.user); err != nil {
830+
errc <- err
831+
return
832+
}
833+
834+
// user conn has reached EOF so propogate close write to backend conn
835+
if wc, ok := c.backend.(interface{ CloseWrite() error }); ok {
836+
errc <- wc.CloseWrite()
837+
return
838+
}
839+
840+
errc <- nil
813841
}
814842

815843
func cleanQueryParams(s string) string {

src/net/http/httputil/reverseproxy_test.go

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"fmt"
1515
"io"
1616
"log"
17+
"net"
1718
"net/http"
1819
"net/http/httptest"
1920
"net/http/httptrace"
@@ -1551,6 +1552,149 @@ func TestReverseProxyWebSocketCancellation(t *testing.T) {
15511552
}
15521553
}
15531554

1555+
func TestReverseProxyWebSocketHalfTCP(t *testing.T) {
1556+
// Issue #35892: support TCP half-close when HTTP is upgraded in the ReverseProxy.
1557+
// Specifically testing:
1558+
// - the communication through the reverse proxy when the client or server closes
1559+
// either the read or write streams
1560+
// - that closing the write stream is propagated through the proxy and results in reading
1561+
// EOF at the other end of the connection
1562+
1563+
mustRead := func(t *testing.T, conn *net.TCPConn, msg string) {
1564+
b := make([]byte, len(msg))
1565+
if _, err := conn.Read(b); err != nil {
1566+
t.Errorf("failed to read: %v", err)
1567+
}
1568+
1569+
if got, want := string(b), msg; got != want {
1570+
t.Errorf("got %#q, want %#q", got, want)
1571+
}
1572+
}
1573+
1574+
mustReadError := func(t *testing.T, conn *net.TCPConn, e error) {
1575+
b := make([]byte, 1)
1576+
if _, err := conn.Read(b); !errors.Is(err, e) {
1577+
t.Errorf("failed to read error: %v", err)
1578+
}
1579+
}
1580+
1581+
mustWrite := func(t *testing.T, conn *net.TCPConn, msg string) {
1582+
if _, err := conn.Write([]byte(msg)); err != nil {
1583+
t.Errorf("failed to write: %v", err)
1584+
}
1585+
}
1586+
1587+
mustCloseRead := func(t *testing.T, conn *net.TCPConn) {
1588+
if err := conn.CloseRead(); err != nil {
1589+
t.Errorf("failed to CloseRead: %v", err)
1590+
}
1591+
}
1592+
1593+
mustCloseWrite := func(t *testing.T, conn *net.TCPConn) {
1594+
if err := conn.CloseWrite(); err != nil {
1595+
t.Errorf("failed to CloseWrite: %v", err)
1596+
}
1597+
}
1598+
1599+
tests := map[string]func(t *testing.T, cli, srv *net.TCPConn){
1600+
"server close read": func(t *testing.T, cli, srv *net.TCPConn) {
1601+
mustCloseRead(t, srv)
1602+
mustWrite(t, srv, "server sends")
1603+
mustRead(t, cli, "server sends")
1604+
},
1605+
"server close write": func(t *testing.T, cli, srv *net.TCPConn) {
1606+
mustCloseWrite(t, srv)
1607+
mustWrite(t, cli, "client sends")
1608+
mustRead(t, srv, "client sends")
1609+
mustReadError(t, cli, io.EOF)
1610+
},
1611+
"client close read": func(t *testing.T, cli, srv *net.TCPConn) {
1612+
mustCloseRead(t, cli)
1613+
mustWrite(t, cli, "client sends")
1614+
mustRead(t, srv, "client sends")
1615+
},
1616+
"client close write": func(t *testing.T, cli, srv *net.TCPConn) {
1617+
mustCloseWrite(t, cli)
1618+
mustWrite(t, srv, "server sends")
1619+
mustRead(t, cli, "server sends")
1620+
mustReadError(t, srv, io.EOF)
1621+
},
1622+
}
1623+
1624+
for name, test := range tests {
1625+
t.Run(name, func(t *testing.T) {
1626+
var srv *net.TCPConn
1627+
1628+
backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1629+
if g, ws := upgradeType(r.Header), "websocket"; g != ws {
1630+
t.Fatalf("Unexpected upgrade type %q, want %q", g, ws)
1631+
}
1632+
1633+
conn, _, err := w.(http.Hijacker).Hijack()
1634+
if err != nil {
1635+
conn.Close()
1636+
t.Fatalf("hijack failed: %v", err)
1637+
}
1638+
1639+
var ok bool
1640+
if srv, ok = conn.(*net.TCPConn); !ok {
1641+
conn.Close()
1642+
t.Fatal("conn is not a TCPConn")
1643+
}
1644+
1645+
upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
1646+
if _, err := io.WriteString(srv, upgradeMsg); err != nil {
1647+
srv.Close()
1648+
t.Fatalf("backend upgrade failed: %v", err)
1649+
}
1650+
}))
1651+
defer backendServer.Close()
1652+
1653+
backendURL, _ := url.Parse(backendServer.URL)
1654+
rproxy := NewSingleHostReverseProxy(backendURL)
1655+
rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
1656+
frontendProxy := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
1657+
rproxy.ServeHTTP(rw, req)
1658+
}))
1659+
defer frontendProxy.Close()
1660+
1661+
frontendURL, _ := url.Parse(frontendProxy.URL)
1662+
addr, err := net.ResolveTCPAddr("tcp", frontendURL.Host)
1663+
if err != nil {
1664+
t.Fatalf("failed to resolve TCP address: %v", err)
1665+
}
1666+
cli, err := net.DialTCP("tcp", nil, addr)
1667+
if err != nil {
1668+
t.Fatalf("failed to dial TCP address: %v", err)
1669+
}
1670+
defer cli.Close()
1671+
1672+
req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
1673+
req.Header.Set("Connection", "Upgrade")
1674+
req.Header.Set("Upgrade", "websocket")
1675+
if err := req.Write(cli); err != nil {
1676+
t.Fatalf("failed to write request: %v", err)
1677+
}
1678+
1679+
br := bufio.NewReader(cli)
1680+
resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
1681+
if err != nil {
1682+
t.Fatalf("failed to read response: %v", err)
1683+
}
1684+
if resp.StatusCode != 101 {
1685+
t.Fatalf("status code not 101: %v", resp.StatusCode)
1686+
}
1687+
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
1688+
strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
1689+
t.Fatalf("frontend upgrade failed")
1690+
}
1691+
defer srv.Close()
1692+
1693+
test(t, cli, srv)
1694+
})
1695+
}
1696+
}
1697+
15541698
func TestUnannouncedTrailer(t *testing.T) {
15551699
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
15561700
w.WriteHeader(http.StatusOK)

src/net/http/transport.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2575,6 +2575,13 @@ func (b *readWriteCloserBody) Read(p []byte) (n int, err error) {
25752575
return b.ReadWriteCloser.Read(p)
25762576
}
25772577

2578+
func (b *readWriteCloserBody) CloseWrite() error {
2579+
if cw, ok := b.ReadWriteCloser.(interface{ CloseWrite() error }); ok {
2580+
return cw.CloseWrite()
2581+
}
2582+
return fmt.Errorf("CloseWrite: %w", ErrNotSupported)
2583+
}
2584+
25782585
// nothingWrittenError wraps a write errors which ended up writing zero bytes.
25792586
type nothingWrittenError struct {
25802587
error

0 commit comments

Comments
 (0)