@@ -14,6 +14,7 @@ import (
14
14
"fmt"
15
15
"io"
16
16
"log"
17
+ "net"
17
18
"net/http"
18
19
"net/http/httptest"
19
20
"net/http/httptrace"
@@ -1551,6 +1552,149 @@ func TestReverseProxyWebSocketCancellation(t *testing.T) {
1551
1552
}
1552
1553
}
1553
1554
1555
+ func TestReverseProxyWebSocketHalfTCP (t * testing.T ) {
1556
+ // Issue #35892: support TCP half-close when HTTP is upgraded in the ReverseProxy.
1557
+ // Specifically testing:
1558
+ // - the communication through the reverse proxy when the client or server closes
1559
+ // either the read or write streams
1560
+ // - that closing the write stream is propagated through the proxy and results in reading
1561
+ // EOF at the other end of the connection
1562
+
1563
+ mustRead := func (t * testing.T , conn * net.TCPConn , msg string ) {
1564
+ b := make ([]byte , len (msg ))
1565
+ if _ , err := conn .Read (b ); err != nil {
1566
+ t .Errorf ("failed to read: %v" , err )
1567
+ }
1568
+
1569
+ if got , want := string (b ), msg ; got != want {
1570
+ t .Errorf ("got %#q, want %#q" , got , want )
1571
+ }
1572
+ }
1573
+
1574
+ mustReadError := func (t * testing.T , conn * net.TCPConn , e error ) {
1575
+ b := make ([]byte , 1 )
1576
+ if _ , err := conn .Read (b ); ! errors .Is (err , e ) {
1577
+ t .Errorf ("failed to read error: %v" , err )
1578
+ }
1579
+ }
1580
+
1581
+ mustWrite := func (t * testing.T , conn * net.TCPConn , msg string ) {
1582
+ if _ , err := conn .Write ([]byte (msg )); err != nil {
1583
+ t .Errorf ("failed to write: %v" , err )
1584
+ }
1585
+ }
1586
+
1587
+ mustCloseRead := func (t * testing.T , conn * net.TCPConn ) {
1588
+ if err := conn .CloseRead (); err != nil {
1589
+ t .Errorf ("failed to CloseRead: %v" , err )
1590
+ }
1591
+ }
1592
+
1593
+ mustCloseWrite := func (t * testing.T , conn * net.TCPConn ) {
1594
+ if err := conn .CloseWrite (); err != nil {
1595
+ t .Errorf ("failed to CloseWrite: %v" , err )
1596
+ }
1597
+ }
1598
+
1599
+ tests := map [string ]func (t * testing.T , cli , srv * net.TCPConn ){
1600
+ "server close read" : func (t * testing.T , cli , srv * net.TCPConn ) {
1601
+ mustCloseRead (t , srv )
1602
+ mustWrite (t , srv , "server sends" )
1603
+ mustRead (t , cli , "server sends" )
1604
+ },
1605
+ "server close write" : func (t * testing.T , cli , srv * net.TCPConn ) {
1606
+ mustCloseWrite (t , srv )
1607
+ mustWrite (t , cli , "client sends" )
1608
+ mustRead (t , srv , "client sends" )
1609
+ mustReadError (t , cli , io .EOF )
1610
+ },
1611
+ "client close read" : func (t * testing.T , cli , srv * net.TCPConn ) {
1612
+ mustCloseRead (t , cli )
1613
+ mustWrite (t , cli , "client sends" )
1614
+ mustRead (t , srv , "client sends" )
1615
+ },
1616
+ "client close write" : func (t * testing.T , cli , srv * net.TCPConn ) {
1617
+ mustCloseWrite (t , cli )
1618
+ mustWrite (t , srv , "server sends" )
1619
+ mustRead (t , cli , "server sends" )
1620
+ mustReadError (t , srv , io .EOF )
1621
+ },
1622
+ }
1623
+
1624
+ for name , test := range tests {
1625
+ t .Run (name , func (t * testing.T ) {
1626
+ var srv * net.TCPConn
1627
+
1628
+ backendServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1629
+ if g , ws := upgradeType (r .Header ), "websocket" ; g != ws {
1630
+ t .Fatalf ("Unexpected upgrade type %q, want %q" , g , ws )
1631
+ }
1632
+
1633
+ conn , _ , err := w .(http.Hijacker ).Hijack ()
1634
+ if err != nil {
1635
+ conn .Close ()
1636
+ t .Fatalf ("hijack failed: %v" , err )
1637
+ }
1638
+
1639
+ var ok bool
1640
+ if srv , ok = conn .(* net.TCPConn ); ! ok {
1641
+ conn .Close ()
1642
+ t .Fatal ("conn is not a TCPConn" )
1643
+ }
1644
+
1645
+ upgradeMsg := "HTTP/1.1 101 Switching Protocols\r \n Connection: upgrade\r \n Upgrade: WebSocket\r \n \r \n "
1646
+ if _ , err := io .WriteString (srv , upgradeMsg ); err != nil {
1647
+ srv .Close ()
1648
+ t .Fatalf ("backend upgrade failed: %v" , err )
1649
+ }
1650
+ }))
1651
+ defer backendServer .Close ()
1652
+
1653
+ backendURL , _ := url .Parse (backendServer .URL )
1654
+ rproxy := NewSingleHostReverseProxy (backendURL )
1655
+ rproxy .ErrorLog = log .New (io .Discard , "" , 0 ) // quiet for tests
1656
+ frontendProxy := httptest .NewServer (http .HandlerFunc (func (rw http.ResponseWriter , req * http.Request ) {
1657
+ rproxy .ServeHTTP (rw , req )
1658
+ }))
1659
+ defer frontendProxy .Close ()
1660
+
1661
+ frontendURL , _ := url .Parse (frontendProxy .URL )
1662
+ addr , err := net .ResolveTCPAddr ("tcp" , frontendURL .Host )
1663
+ if err != nil {
1664
+ t .Fatalf ("failed to resolve TCP address: %v" , err )
1665
+ }
1666
+ cli , err := net .DialTCP ("tcp" , nil , addr )
1667
+ if err != nil {
1668
+ t .Fatalf ("failed to dial TCP address: %v" , err )
1669
+ }
1670
+ defer cli .Close ()
1671
+
1672
+ req , _ := http .NewRequest ("GET" , frontendProxy .URL , nil )
1673
+ req .Header .Set ("Connection" , "Upgrade" )
1674
+ req .Header .Set ("Upgrade" , "websocket" )
1675
+ if err := req .Write (cli ); err != nil {
1676
+ t .Fatalf ("failed to write request: %v" , err )
1677
+ }
1678
+
1679
+ br := bufio .NewReader (cli )
1680
+ resp , err := http .ReadResponse (br , & http.Request {Method : "GET" })
1681
+ if err != nil {
1682
+ t .Fatalf ("failed to read response: %v" , err )
1683
+ }
1684
+ if resp .StatusCode != 101 {
1685
+ t .Fatalf ("status code not 101: %v" , resp .StatusCode )
1686
+ }
1687
+ if strings .ToLower (resp .Header .Get ("Upgrade" )) != "websocket" ||
1688
+ strings .ToLower (resp .Header .Get ("Connection" )) != "upgrade" {
1689
+ t .Fatalf ("frontend upgrade failed" )
1690
+ }
1691
+ defer srv .Close ()
1692
+
1693
+ test (t , cli , srv )
1694
+ })
1695
+ }
1696
+ }
1697
+
1554
1698
func TestUnannouncedTrailer (t * testing.T ) {
1555
1699
backend := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1556
1700
w .WriteHeader (http .StatusOK )
0 commit comments