Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type Interface[K comparable, V any] interface {
Clear(context.Context, K) Interface[K, V]
ClearAll() Interface[K, V]
Prime(ctx context.Context, key K, value V) Interface[K, V]
Flush()
}

// BatchFunc is a function, which when given a slice of keys (string), returns a slice of `results`.
Expand Down Expand Up @@ -264,21 +265,38 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
l.count++
// if we hit our limit, force the batch to start
if l.count == l.batchCap {
// end the batcher synchronously here because another call to Load
// end/flush the batcher synchronously here because another call to Load
// may concurrently happen and needs to go to a new batcher.
l.curBatcher.end()
// end the sleeper for the current batcher.
// this is to stop the goroutine without waiting for the
// sleeper timeout.
close(l.endSleeper)
l.reset()
l.flush()
}
}
l.batchLock.Unlock()

return thunk
}

// flush dispatches whatever keys are waiting in the current batch
// immediately, rather than waiting for the sleeper timeout to fire.
//
// It must only be called while holding l.batchLock, and only when a
// live batcher exists: closing l.endSleeper a second time would panic,
// so callers (Load's batchCap path, Flush) are responsible for ensuring
// l.curBatcher is non-nil — see the nil check in Flush.
func (l *Loader[K, V]) flush() {
	l.curBatcher.end()

	// end the sleeper for the current batcher.
	// this is to stop the goroutine without waiting for the
	// sleeper timeout.
	close(l.endSleeper)
	l.reset()
}

// Flush dispatches the items in the current batch immediately instead of
// waiting for the batch timer to expire. It is a no-op when no batch is
// pending.
func (l *Loader[K, V]) Flush() {
	l.batchLock.Lock()
	defer l.batchLock.Unlock()

	// A nil curBatcher means nothing has been queued since the last
	// reset, so there is nothing to dispatch.
	if l.curBatcher != nil {
		l.flush()
	}
}

// LoadMany loads multiple keys, returning a thunk (type: ThunkMany) that will resolve the keys passed in.
func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) ThunkMany[V] {
ctx, finish := l.tracer.TraceLoadMany(originalContext, keys)
Expand Down
59 changes: 53 additions & 6 deletions dataloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ import (
"strconv"
"sync"
"testing"
"time"
)

///////////////////////////////////////////////////
// Tests
///////////////////////////////////////////////////
/*
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apologies about this change — vim kept reformatting the comments to `// //////...` on save, so I gave up and switched them to this style.

Tests
*/
func TestLoader(t *testing.T) {
t.Run("test Load method", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -289,6 +290,7 @@ func TestLoader(t *testing.T) {
t.Parallel()
identityLoader, loadCalls := IDLoader[string](0)
ctx := context.Background()
start := time.Now()
future1 := identityLoader.Load(ctx, "1")
future2 := identityLoader.Load(ctx, "1")

Expand All @@ -301,6 +303,12 @@ func TestLoader(t *testing.T) {
t.Error(err.Error())
}

// also check that it took the full timeout to return
var duration = time.Since(start)
if duration < 16*time.Millisecond {
t.Errorf("took %v when expected it to take more than 16 ms because of wait", duration)
}

calls := *loadCalls
inner := []string{"1"}
expected := [][]string{inner}
Expand All @@ -309,6 +317,45 @@ func TestLoader(t *testing.T) {
}
})

t.Run("doesn't wait for timeout if Flush() is called", func(t *testing.T) {
t.Parallel()
identityLoader, loadCalls := IDLoader[string](0)
ctx := context.Background()
start := time.Now()
future1 := identityLoader.Load(ctx, "1")
future2 := identityLoader.Load(ctx, "2")

// trigger them to be fetched immediately vs waiting for the 16 ms timer
identityLoader.Flush()

_, err := future1()
if err != nil {
t.Error(err.Error())
}
_, err = future2()
if err != nil {
t.Error(err.Error())
}

var duration = time.Since(start)
if duration > 2*time.Millisecond {
t.Errorf("took %v when expected it to take less than 2 ms b/c we called Flush()", duration)
}

calls := *loadCalls
inner := []string{"1", "2"}
expected := [][]string{inner}
if !reflect.DeepEqual(calls, expected) {
t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
}
})

t.Run("Nothing for Flush() to do on empty loader with current batch", func(t *testing.T) {
t.Parallel()
identityLoader, _ := IDLoader[string](0)
identityLoader.Flush()
})

t.Run("allows primed cache", func(t *testing.T) {
t.Parallel()
identityLoader, loadCalls := IDLoader[string](0)
Expand Down Expand Up @@ -678,9 +725,9 @@ func FaultyLoader[K comparable]() (*Loader[K, K], *[][]K) {
return loader, &loadCalls
}

///////////////////////////////////////////////////
// Benchmarks
///////////////////////////////////////////////////
/*
Benchmarks
*/
var a = &Avg{}

func batchIdentity[K comparable](_ context.Context, keys []K) (results []*Result[K]) {
Expand Down