Skip to content

Commit 87125fc

Browse files
committed
tcpreuse: fix Scope() for *tls.Conn
1 parent 080e6c8 commit 87125fc

File tree

3 files changed

+113
-20
lines changed

3 files changed

+113
-20
lines changed

p2p/test/transport/gating_test.go

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,26 @@ import (
2222

2323
//go:generate go run go.uber.org/mock/mockgen -package transport_integration -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater
2424

25-
func stripCertHash(addr ma.Multiaddr) ma.Multiaddr {
25+
// normalize removes the certhash and replaces /wss with /tls/ws
26+
func normalize(addr ma.Multiaddr) ma.Multiaddr {
2627
for {
2728
if _, err := addr.ValueForProtocol(ma.P_CERTHASH); err != nil {
2829
break
2930
}
3031
addr, _ = ma.SplitLast(addr)
3132
}
32-
return addr
33+
34+
// replace /wss with /tls/ws
35+
components := []ma.Multiaddr{}
36+
ma.ForEach(addr, func(c ma.Component) bool {
37+
if c.Protocol().Code == ma.P_WSS {
38+
components = append(components, ma.StringCast("/tls/ws"))
39+
} else {
40+
components = append(components, &c)
41+
}
42+
return true
43+
})
44+
return ma.Join(components...)
3345
}
3446

3547
func addrPort(addr ma.Multiaddr) netip.AddrPort {
@@ -119,8 +131,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) {
119131
connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true),
120132
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
121133
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
122-
// remove the certhash component from WebTransport and WebRTC addresses
123-
require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String())
134+
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.RemoteMultiaddr()))
124135
}),
125136
)
126137
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
@@ -154,8 +165,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) {
154165
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
155166
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true),
156167
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
157-
// remove the certhash component from WebTransport addresses
158-
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr())
168+
require.Equal(t, normalize(h2.Addrs()[0]), normalize(c.RemoteMultiaddr()))
159169
require.Equal(t, h1.ID(), c.LocalPeer())
160170
require.Equal(t, h2.ID(), c.RemotePeer())
161171
}))
@@ -189,17 +199,15 @@ func TestInterceptAccept(t *testing.T) {
189199
// In WebRTC, retransmissions of the STUN packet might cause us to create multiple connections,
190200
// if the first connection attempt is rejected.
191201
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
192-
// remove the certhash component from WebTransport addresses
193-
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
202+
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
194203
}).AnyTimes()
195-
} else if strings.Contains(tc.Name, "WebSocket-Shared") {
204+
} else if strings.Contains(tc.Name, "WebSocket-Shared") || strings.Contains(tc.Name, "WebSocket-Secured-Shared") {
196205
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
197206
require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr()))
198207
})
199208
} else {
200209
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
201-
// remove the certhash component from WebTransport addresses
202-
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr())
210+
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
203211
})
204212
}
205213

@@ -236,8 +244,7 @@ func TestInterceptSecuredIncoming(t *testing.T) {
236244
gomock.InOrder(
237245
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
238246
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
239-
// remove the certhash component from WebTransport addresses
240-
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
247+
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
241248
}),
242249
)
243250
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
@@ -270,8 +277,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
270277
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
271278
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true),
272279
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
273-
// remove the certhash component from WebTransport addresses
274-
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr())
280+
require.Equal(t, normalize(h2.Addrs()[0]), normalize(c.LocalMultiaddr()))
275281
require.Equal(t, h1.ID(), c.RemotePeer())
276282
require.Equal(t, h2.ID(), c.LocalPeer())
277283
}),

p2p/test/transport/transport_test.go

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,16 @@ package transport_integration
33
import (
44
"bytes"
55
"context"
6+
"crypto/ecdsa"
7+
"crypto/elliptic"
68
"crypto/rand"
9+
"crypto/tls"
10+
"crypto/x509"
11+
"crypto/x509/pkix"
712
"errors"
813
"fmt"
914
"io"
15+
"math/big"
1016
"net"
1117
"runtime"
1218
"strings"
@@ -15,6 +21,8 @@ import (
1521
"testing"
1622
"time"
1723

24+
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
25+
1826
"github.com/libp2p/go-libp2p"
1927
"github.com/libp2p/go-libp2p/config"
2028
"github.com/libp2p/go-libp2p/core/connmgr"
@@ -30,9 +38,9 @@ import (
3038
"github.com/libp2p/go-libp2p/p2p/net/swarm"
3139
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
3240
"github.com/libp2p/go-libp2p/p2p/security/noise"
33-
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
3441
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
3542
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
43+
"github.com/libp2p/go-libp2p/p2p/transport/websocket"
3644
"go.uber.org/mock/gomock"
3745

3846
ma "github.com/multiformats/go-multiaddr"
@@ -68,6 +76,44 @@ func transformOpts(opts TransportTestCaseOpts) []config.Option {
6876
return libp2pOpts
6977
}
7078

79+
func selfSignedTLSConfig(t *testing.T) *tls.Config {
80+
t.Helper()
81+
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
82+
require.NoError(t, err)
83+
84+
notBefore := time.Now()
85+
notAfter := notBefore.Add(365 * 24 * time.Hour)
86+
87+
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
88+
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
89+
require.NoError(t, err)
90+
91+
certTemplate := x509.Certificate{
92+
SerialNumber: serialNumber,
93+
Subject: pkix.Name{
94+
Organization: []string{"Test"},
95+
},
96+
NotBefore: notBefore,
97+
NotAfter: notAfter,
98+
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
99+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
100+
BasicConstraintsValid: true,
101+
}
102+
103+
derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv)
104+
require.NoError(t, err)
105+
106+
cert := tls.Certificate{
107+
Certificate: [][]byte{derBytes},
108+
PrivateKey: priv,
109+
}
110+
111+
tlsConfig := &tls.Config{
112+
Certificates: []tls.Certificate{cert},
113+
}
114+
return tlsConfig
115+
}
116+
71117
var transportsToTest = []TransportTestCase{
72118
{
73119
Name: "TCP / Noise / Yamux",
@@ -89,7 +135,7 @@ var transportsToTest = []TransportTestCase{
89135
Name: "TCP / TLS / Yamux",
90136
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
91137
libp2pOpts := transformOpts(opts)
92-
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
138+
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
93139
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
94140
if opts.NoListen {
95141
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
@@ -106,7 +152,7 @@ var transportsToTest = []TransportTestCase{
106152
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
107153
libp2pOpts := transformOpts(opts)
108154
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
109-
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
155+
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
110156
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
111157
if opts.NoListen {
112158
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
@@ -123,7 +169,7 @@ var transportsToTest = []TransportTestCase{
123169
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
124170
libp2pOpts := transformOpts(opts)
125171
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
126-
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
172+
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
127173
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
128174
libp2pOpts = append(libp2pOpts, libp2p.Transport(tcp.NewTCPTransport, tcp.WithMetrics()))
129175
if opts.NoListen {
@@ -140,7 +186,7 @@ var transportsToTest = []TransportTestCase{
140186
Name: "TCP-WithMetrics / TLS / Yamux",
141187
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
142188
libp2pOpts := transformOpts(opts)
143-
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
189+
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
144190
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
145191
libp2pOpts = append(libp2pOpts, libp2p.Transport(tcp.NewTCPTransport, tcp.WithMetrics()))
146192
if opts.NoListen {
@@ -168,6 +214,23 @@ var transportsToTest = []TransportTestCase{
168214
return h
169215
},
170216
},
217+
{
218+
Name: "WebSocket-Secured-Shared",
219+
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
220+
libp2pOpts := transformOpts(opts)
221+
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
222+
if opts.NoListen {
223+
config := tls.Config{InsecureSkipVerify: true}
224+
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.Transport(websocket.New, websocket.WithTLSClientConfig(&config)))
225+
} else {
226+
config := selfSignedTLSConfig(t)
227+
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/sni/localhost/tls/ws"), libp2p.Transport(websocket.New, websocket.WithTLSConfig(config)))
228+
}
229+
h, err := libp2p.New(libp2pOpts...)
230+
require.NoError(t, err)
231+
return h
232+
},
233+
},
171234
{
172235
Name: "WebSocket",
173236
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
@@ -182,6 +245,22 @@ var transportsToTest = []TransportTestCase{
182245
return h
183246
},
184247
},
248+
{
249+
Name: "WebSocket-Secured",
250+
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
251+
libp2pOpts := transformOpts(opts)
252+
if opts.NoListen {
253+
config := tls.Config{InsecureSkipVerify: true}
254+
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.Transport(websocket.New, websocket.WithTLSClientConfig(&config)))
255+
} else {
256+
config := selfSignedTLSConfig(t)
257+
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/sni/localhost/tls/ws"), libp2p.Transport(websocket.New, websocket.WithTLSConfig(config)))
258+
}
259+
h, err := libp2p.New(libp2pOpts...)
260+
require.NoError(t, err)
261+
return h
262+
},
263+
},
185264
{
186265
Name: "QUIC",
187266
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {

p2p/transport/websocket/conn.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package websocket
22

33
import (
4+
"crypto/tls"
45
"errors"
56
"io"
67
"net"
@@ -142,6 +143,13 @@ func (c *Conn) Scope() network.ConnManagementScope {
142143
}); ok {
143144
return sc.Scope()
144145
}
146+
if nc, ok := nc.(*tls.Conn); ok {
147+
if sc, ok := nc.NetConn().(interface {
148+
Scope() network.ConnManagementScope
149+
}); ok {
150+
return sc.Scope()
151+
}
152+
}
145153
return nil
146154
}
147155

0 commit comments

Comments
 (0)