Skip to content

Commit 97856b4

Browse files
authored
Merge pull request #21 from libp2p/fix/keepalive-race
fix: synchronize when resetting the keepalive timer
2 parents: 51522d4 + 345f639; commit 97856b4

File tree

5 files changed: +414 -258 lines changed

.travis.yml

+3-1
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,10 @@ go:
88

99
env:
1010
global:
11-
- GOTFLAGS="-race"
1211
- BUILD_DEPTYPE=gomod
12+
matrix:
13+
- GOTFLAGS="-race"
14+
- GOTFLAGS="-count 5"
1315

1416

1517
# disable travis install

bench_test.go

+8-5
Original file line number | Diff line number | Diff line change
@@ -48,24 +48,25 @@ func BenchmarkAccept(b *testing.B) {
4848
func BenchmarkSendRecv(b *testing.B) {
4949
client, server := testClientServer()
5050
defer client.Close()
51-
defer server.Close()
5251

5352
sendBuf := make([]byte, 512)
5453
recvBuf := make([]byte, 512)
5554

5655
doneCh := make(chan struct{})
5756
go func() {
57+
defer close(doneCh)
58+
defer server.Close()
5859
stream, err := server.AcceptStream()
5960
if err != nil {
6061
return
6162
}
6263
defer stream.Close()
6364
for i := 0; i < b.N; i++ {
6465
if _, err := io.ReadFull(stream, recvBuf); err != nil {
65-
b.Fatalf("err: %v", err)
66+
b.Errorf("err: %v", err)
67+
return
6668
}
6769
}
68-
close(doneCh)
6970
}()
7071

7172
stream, err := client.Open()
@@ -95,6 +96,8 @@ func BenchmarkSendRecvLarge(b *testing.B) {
9596
recvDone := make(chan struct{})
9697

9798
go func() {
99+
defer close(recvDone)
100+
defer server.Close()
98101
stream, err := server.AcceptStream()
99102
if err != nil {
100103
return
@@ -103,11 +106,11 @@ func BenchmarkSendRecvLarge(b *testing.B) {
103106
for i := 0; i < b.N; i++ {
104107
for j := 0; j < sendSize/recvSize; j++ {
105108
if _, err := io.ReadFull(stream, recvBuf); err != nil {
106-
b.Fatalf("err: %v", err)
109+
b.Errorf("err: %v", err)
110+
return
107111
}
108112
}
109113
}
110-
close(recvDone)
111114
}()
112115

113116
stream, err := client.Open()

session.go

+34-19
Original file line number | Diff line number | Diff line change
@@ -87,15 +87,11 @@ type Session struct {
8787

8888
// keepaliveTimer is a periodic timer for keepalive messages. It's nil
8989
// when keepalives are disabled.
90-
keepaliveLock sync.Mutex
91-
keepaliveTimer *time.Timer
90+
keepaliveLock sync.Mutex
91+
keepaliveTimer *time.Timer
92+
keepaliveActive bool
9293
}
9394

94-
const (
95-
stageInitial uint32 = iota
96-
stageFinal
97-
)
98-
9995
// newSession is used to construct a new session
10096
func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Session {
10197
var reader io.Reader = conn
@@ -327,23 +323,27 @@ func (s *Session) startKeepalive() {
327323
defer s.keepaliveLock.Unlock()
328324
s.keepaliveTimer = time.AfterFunc(s.config.KeepAliveInterval, func() {
329325
s.keepaliveLock.Lock()
330-
331-
if s.keepaliveTimer == nil {
326+
if s.keepaliveTimer == nil || s.keepaliveActive {
327+
// keepalives have been stopped or a keepalive is active.
332328
s.keepaliveLock.Unlock()
333-
// keepalives have been stopped.
334329
return
335330
}
331+
s.keepaliveActive = true
332+
s.keepaliveLock.Unlock()
333+
336334
_, err := s.Ping()
335+
336+
s.keepaliveLock.Lock()
337+
s.keepaliveActive = false
338+
if s.keepaliveTimer != nil {
339+
s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
340+
}
341+
s.keepaliveLock.Unlock()
342+
337343
if err != nil {
338-
// Make sure to unlock before exiting so we don't
339-
// deadlock trying to shutdown keepalives.
340-
s.keepaliveLock.Unlock()
341344
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
342345
s.exitErr(ErrKeepAliveTimeout)
343-
return
344346
}
345-
s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
346-
s.keepaliveLock.Unlock()
347347
})
348348
}
349349

@@ -353,7 +353,24 @@ func (s *Session) stopKeepalive() {
353353
defer s.keepaliveLock.Unlock()
354354
if s.keepaliveTimer != nil {
355355
s.keepaliveTimer.Stop()
356+
s.keepaliveTimer = nil
357+
}
358+
}
359+
360+
func (s *Session) extendKeepalive() {
361+
s.keepaliveLock.Lock()
362+
if s.keepaliveTimer != nil && !s.keepaliveActive {
363+
// Don't stop the timer and drain the channel. This is an
364+
// AfterFunc, not a normal timer, and any attempts to drain the
365+
// channel will block forever.
366+
//
367+
// Go will stop the timer for us internally anyways. The docs
368+
// say one must stop the timer before calling Reset but that's
369+
// to ensure that the timer doesn't end up firing immediately
370+
// after calling Reset.
371+
s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
356372
}
373+
s.keepaliveLock.Unlock()
357374
}
358375

359376
// send sends the header and body.
@@ -512,9 +529,7 @@ func (s *Session) recvLoop() error {
512529
// There's no reason to keepalive if we're active. Worse, if the
513530
// peer is busy sending us stuff, the pong might get stuck
514531
// behind a bunch of data.
515-
if s.keepaliveTimer != nil {
516-
s.keepaliveTimer.Reset(s.config.KeepAliveInterval)
517-
}
532+
s.extendKeepalive()
518533

519534
// Verify the version
520535
if hdr.Version() != protoVersion {

session_norace_test.go

+163
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,163 @@
1+
//+build !race
2+
3+
package yamux
4+
5+
import (
6+
"bytes"
7+
"io"
8+
"io/ioutil"
9+
"sync"
10+
"testing"
11+
"time"
12+
)
13+
14+
func TestSession_PingOfDeath(t *testing.T) {
15+
client, server := testClientServerConfig(testConfNoKeepAlive())
16+
defer client.Close()
17+
defer server.Close()
18+
19+
count := 10000
20+
21+
var wg sync.WaitGroup
22+
begin := make(chan struct{})
23+
for i := 0; i < count; i++ {
24+
wg.Add(2)
25+
go func() {
26+
defer wg.Done()
27+
<-begin
28+
if _, err := server.Ping(); err != nil {
29+
t.Error(err)
30+
}
31+
}()
32+
go func() {
33+
defer wg.Done()
34+
<-begin
35+
if _, err := client.Ping(); err != nil {
36+
t.Error(err)
37+
}
38+
}()
39+
}
40+
close(begin)
41+
wg.Wait()
42+
}
43+
44+
func TestSendData_VeryLarge(t *testing.T) {
45+
client, server := testClientServer()
46+
defer client.Close()
47+
defer server.Close()
48+
49+
var n int64 = 1 * 1024 * 1024 * 1024
50+
var workers int = 16
51+
52+
wg := &sync.WaitGroup{}
53+
wg.Add(workers * 2)
54+
55+
for i := 0; i < workers; i++ {
56+
go func() {
57+
defer wg.Done()
58+
stream, err := server.AcceptStream()
59+
if err != nil {
60+
t.Errorf("err: %v", err)
61+
return
62+
}
63+
defer stream.Close()
64+
65+
buf := make([]byte, 4)
66+
_, err = io.ReadFull(stream, buf)
67+
if err != nil {
68+
t.Errorf("err: %v", err)
69+
return
70+
}
71+
if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
72+
t.Errorf("bad header")
73+
return
74+
}
75+
76+
recv, err := io.Copy(ioutil.Discard, stream)
77+
if err != nil {
78+
t.Errorf("err: %v", err)
79+
return
80+
}
81+
if recv != n {
82+
t.Errorf("bad: %v", recv)
83+
return
84+
}
85+
}()
86+
}
87+
for i := 0; i < workers; i++ {
88+
go func() {
89+
defer wg.Done()
90+
stream, err := client.Open()
91+
if err != nil {
92+
t.Errorf("err: %v", err)
93+
return
94+
}
95+
defer stream.Close()
96+
97+
_, err = stream.Write([]byte{0, 1, 2, 3})
98+
if err != nil {
99+
t.Errorf("err: %v", err)
100+
return
101+
}
102+
103+
unlimited := &UnlimitedReader{}
104+
sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
105+
if err != nil {
106+
t.Errorf("err: %v", err)
107+
return
108+
}
109+
if sent != n {
110+
t.Errorf("bad: %v", sent)
111+
return
112+
}
113+
}()
114+
}
115+
116+
doneCh := make(chan struct{})
117+
go func() {
118+
wg.Wait()
119+
close(doneCh)
120+
}()
121+
select {
122+
case <-doneCh:
123+
case <-time.After(20 * time.Second):
124+
server.Close()
125+
client.Close()
126+
wg.Wait()
127+
t.Fatal("timeout")
128+
}
129+
}
130+
131+
func TestLargeWindow(t *testing.T) {
132+
conf := DefaultConfig()
133+
conf.MaxStreamWindowSize *= 2
134+
135+
client, server := testClientServerConfig(conf)
136+
defer client.Close()
137+
defer server.Close()
138+
139+
stream, err := client.Open()
140+
if err != nil {
141+
t.Fatalf("err: %v", err)
142+
}
143+
defer stream.Close()
144+
145+
stream2, err := server.Accept()
146+
if err != nil {
147+
t.Fatalf("err: %v", err)
148+
}
149+
defer stream2.Close()
150+
151+
err = stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
152+
if err != nil {
153+
t.Fatal(err)
154+
}
155+
buf := make([]byte, conf.MaxStreamWindowSize)
156+
n, err := stream.Write(buf)
157+
if err != nil {
158+
t.Fatalf("err: %v", err)
159+
}
160+
if n != len(buf) {
161+
t.Fatalf("short write: %d", n)
162+
}
163+
}

0 commit comments

Comments
 (0)