Skip to content

Commit 20f33fe

Browse files
committed
fix: make sure GetBlocks() channel is closed on session close
This commit was moved from ipfs/go-bitswap@994279b
1 parent 6b5fcf1 commit 20f33fe

File tree

4 files changed

+73
-8
lines changed

4 files changed

+73
-8
lines changed

bitswap/getter/getter.go

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,40 +61,56 @@ func SyncGetBlock(p context.Context, k cid.Cid, gb GetBlocksFunc) (blocks.Block,
6161
type WantFunc func(context.Context, []cid.Cid)
6262

6363
// AsyncGetBlocks take a set of block cids, a pubsub channel for incoming
64-
// blocks, a want function, and a close function,
65-
// and returns a channel of incoming blocks.
66-
func AsyncGetBlocks(ctx context.Context, keys []cid.Cid, notif notifications.PubSub, want WantFunc, cwants func([]cid.Cid)) (<-chan blocks.Block, error) {
64+
// blocks, a want function, and a close function, and returns a channel of
65+
// incoming blocks.
66+
func AsyncGetBlocks(ctx context.Context, sessctx context.Context, keys []cid.Cid, notif notifications.PubSub,
67+
want WantFunc, cwants func([]cid.Cid)) (<-chan blocks.Block, error) {
68+
69+
// If there are no keys supplied, just return a closed channel
6770
if len(keys) == 0 {
6871
out := make(chan blocks.Block)
6972
close(out)
7073
return out, nil
7174
}
7275

76+
// Use a PubSub notifier to listen for incoming blocks for each key
7377
remaining := cid.NewSet()
7478
promise := notif.Subscribe(ctx, keys...)
7579
for _, k := range keys {
7680
log.Event(ctx, "Bitswap.GetBlockRequest.Start", k)
7781
remaining.Add(k)
7882
}
7983

84+
// Send the want request for the keys to the network
8085
want(ctx, keys)
8186

8287
out := make(chan blocks.Block)
83-
go handleIncoming(ctx, remaining, promise, out, cwants)
88+
go handleIncoming(ctx, sessctx, remaining, promise, out, cwants)
8489
return out, nil
8590
}
8691

87-
func handleIncoming(ctx context.Context, remaining *cid.Set, in <-chan blocks.Block, out chan blocks.Block, cfun func([]cid.Cid)) {
92+
// Listens for incoming blocks, passing them to the out channel.
93+
// If the context is cancelled or the incoming channel closes, calls cfun with
94+
// any keys corresponding to blocks that were never received.
95+
func handleIncoming(ctx context.Context, sessctx context.Context, remaining *cid.Set,
96+
in <-chan blocks.Block, out chan blocks.Block, cfun func([]cid.Cid)) {
97+
8898
ctx, cancel := context.WithCancel(ctx)
99+
100+
// Clean up before exiting this function, and call the cancel function on
101+
// any remaining keys
89102
defer func() {
90103
cancel()
91104
close(out)
92105
// can't just defer this call on its own, arguments are resolved *when* the defer is created
93106
cfun(remaining.Keys())
94107
}()
108+
95109
for {
96110
select {
97111
case blk, ok := <-in:
112+
// If the channel is closed, we're done (note that PubSub closes
113+
// the channel once all the keys have been received)
98114
if !ok {
99115
return
100116
}
@@ -104,9 +120,13 @@ func handleIncoming(ctx context.Context, remaining *cid.Set, in <-chan blocks.Bl
104120
case out <- blk:
105121
case <-ctx.Done():
106122
return
123+
case <-sessctx.Done():
124+
return
107125
}
108126
case <-ctx.Done():
109127
return
128+
case <-sessctx.Done():
129+
return
110130
}
111131
}
112132
}

bitswap/notifications/notifications.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ func (ps *impl) Shutdown() {
6060
}
6161

6262
// Subscribe returns a channel of blocks for the given |keys|. |blockChannel|
63-
// is closed if the |ctx| times out or is cancelled, or after sending len(keys)
64-
// blocks.
63+
// is closed if the |ctx| times out or is cancelled, or after receiving the blocks
64+
// corresponding to |keys|.
6565
func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Block {
6666

6767
blocksCh := make(chan blocks.Block, len(keys))
@@ -82,6 +82,8 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Bl
8282
default:
8383
}
8484

85+
// AddSubOnceEach listens for each key in the list, and closes the channel
86+
// once all keys have been received
8587
ps.wrapped.AddSubOnceEach(valuesCh, toStrings(keys)...)
8688
go func() {
8789
defer func() {

bitswap/session/session.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ func (s *Session) GetBlock(parent context.Context, k cid.Cid) (blocks.Block, err
182182
// guaranteed on the returned blocks.
183183
func (s *Session) GetBlocks(ctx context.Context, keys []cid.Cid) (<-chan blocks.Block, error) {
184184
ctx = logging.ContextWithLoggable(ctx, s.uuid)
185-
return bsgetter.AsyncGetBlocks(ctx, keys, s.notif,
185+
186+
return bsgetter.AsyncGetBlocks(ctx, s.ctx, keys, s.notif,
186187
func(ctx context.Context, keys []cid.Cid) {
187188
select {
188189
case s.newReqs <- keys:

bitswap/session/session_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,45 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) {
416416
t.Fatal("Did not rebroadcast to find more peers")
417417
}
418418
}
419+
420+
func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) {
421+
wantReqs := make(chan wantReq, 1)
422+
cancelReqs := make(chan wantReq, 1)
423+
fwm := &fakeWantManager{wantReqs, cancelReqs}
424+
fpm := &fakePeerManager{}
425+
frs := &fakeRequestSplitter{}
426+
notif := notifications.New()
427+
defer notif.Shutdown()
428+
id := testutil.GenerateSessionID()
429+
430+
// Create a new session with its own context
431+
sessctx, sesscancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
432+
session := New(sessctx, id, fwm, fpm, frs, notif, time.Second, delay.Fixed(time.Minute))
433+
434+
timerCtx, timerCancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
435+
defer timerCancel()
436+
437+
// Request a block with a new context
438+
blockGenerator := blocksutil.NewBlockGenerator()
439+
blks := blockGenerator.Blocks(1)
440+
getctx, getcancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
441+
defer getcancel()
442+
443+
getBlocksCh, err := session.GetBlocks(getctx, []cid.Cid{blks[0].Cid()})
444+
if err != nil {
445+
t.Fatal("error getting blocks")
446+
}
447+
448+
// Cancel the session context
449+
sesscancel()
450+
451+
// Expect the GetBlocks() channel to be closed
452+
select {
453+
case _, ok := <-getBlocksCh:
454+
if ok {
455+
t.Fatal("expected channel to be closed but was not closed")
456+
}
457+
case <-timerCtx.Done():
458+
t.Fatal("expected channel to be closed before timeout")
459+
}
460+
}

0 commit comments

Comments
 (0)