Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix deadlock raised in #11 #13

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 42 additions & 13 deletions errgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ type Group struct {

wg sync.WaitGroup

errOnce sync.Once
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since now we will be reading err continuously then we needed an RWMutex already.
We could keep the sync.Once but the same can be achieved with RWMutex that is already needed, so removed the sync.Once. Let me know if you think it's cleaner to keep both sync.Once and sync.RWMutex.

err error
err error

// errMu protects err.
errMu sync.RWMutex
Copy link

@umitanuki umitanuki May 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

atomc.Value could also work?


// numG is the maximum number of goroutines that can be started.
numG int
Expand Down Expand Up @@ -154,13 +156,26 @@ func (g *Group) Go(f func() error) {
return
}

g.qCh <- f
for {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this for loop we simply do the following:

  • Get a read lock on err
  • If err is NOT nil we return as there is no need to send function to buffer
  • If err is nil we try to send to channel, if we can't send to channel then we break the select and go back to checking if err is not nil

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only concern with this is it could become a busy loop. I wonder if sync.Cond can do something better here

g.errMu.RLock()
if g.err != nil {
g.errMu.RUnlock()
g.qMu.Unlock()

// Check if we can or should start a new goroutine?
g.maybeStartG()
return
}

g.qMu.Unlock()
select {
case g.qCh <- f:
g.errMu.RUnlock()
g.maybeStartG()
g.qMu.Unlock()

return
default:
g.errMu.RUnlock()
}
}
}

// maybeStartG might start a new worker goroutine, if
Expand Down Expand Up @@ -204,16 +219,30 @@ func (g *Group) startG() {
return
}

if err := f(); err != nil {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks nicer but as explained above we are already needing the RWMutex so this can be achieved by using the RWMutex instead of also introducing a new field sync.Once.

So let me know if you think its better to stay using sync.Once in addition to the RWMutex just for readability.

g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel()
}
})
err := f()
if err == nil {
// happy path
continue
}

// an error exists
// checking if it's the first group error
g.errMu.Lock()
if g.err != nil {
// this is not the first group error
// no need to set it
g.errMu.Unlock()
return
}

g.err = err
g.errMu.Unlock()

if g.cancel != nil {
g.cancel()
}

return
}
}()
}
70 changes: 70 additions & 0 deletions errgroupn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package errgroup_test

import (
"context"
"errors"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -105,6 +106,54 @@ func TestGroup(t *testing.T) {

}

func TestGroupWithErrors(t *testing.T) {
testCases := []struct {
name string
newG func() (grouper, context.Context)
}{
{name: "errgroup_zero", newG: newErrgroupZero},
{name: "errgroup_wctx", newG: newErrgroupWithContext},
{name: "errgroupn_zero", newG: newErrgroupnZero},
{name: "errgroupn_wctx", newG: newErrgroupnWithContext},
{name: "errgroupn_wctx_0_0", newG: newErrgroupnWithContextN(0, 0)},
{name: "errgroupn_wctx_1_0", newG: newErrgroupnWithContextN(1, 0)},
{name: "errgroupn_wctx_1_1", newG: newErrgroupnWithContextN(1, 1)},
{name: "errgroupn_wctx_4_16", newG: newErrgroupnWithContextN(4, 16)},
{name: "errgroupn_wctx_16_4", newG: newErrgroupnWithContextN(16, 4)},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
f := func() {
g, _ := tc.newG()

for i := 0; i < 1000; i++ {
i := i
g.Go(func() error {
// return an error for every 3
if i%3 == 0 {
return errors.New("sample error")
}

return nil
})
}

if g.Wait() == nil {
t.Error("Wait should return an error but did not")
}
}

// this may cause a deadlock, so running test with a timeout
testTimeout := 10 * time.Second
if err := mustRunInTime(testTimeout, f); err != nil {
t.Errorf("mustRunInTime failed with error: %v", err)
}
})
}
}

func TestEquivalence_GoWaitThenGoAgain(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -265,3 +314,24 @@ func equalInts(a, b []int) bool {

return true
}

// mustRunInTime returns an error if execution of a function
// takes more than the timeout set
func mustRunInTime(d time.Duration, f func()) error {
c := make(chan struct{}, 1)

// Run your long running function in it's own goroutine and pass back it's
// response into our channel.
go func() {
f()
c <- struct{}{}
}()

// Listen on our channel AND a timeout channel - which ever happens first.
select {
case <-c:
return nil
case <-time.After(d):
return errors.New("timeout")
}
}