diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml index 5b26119..2d964d0 100644 --- a/.buildkite/pipeline.yaml +++ b/.buildkite/pipeline.yaml @@ -4,19 +4,14 @@ steps: commands: - golangci-lint run --timeout 10m0s - - label: ':hammer: Test' + - label: ':hammer: Test (:codecov: + :codeclimate:)' commands: - - gotestsum --junitfile test.xml ./... + - gotestsum --junitfile test.xml -- -race -coverprofile=cover.out ./... + - sh .buildkite/upload_coverage.sh cover.out plugins: - - test-collector#v1.10.0: + - test-collector#v1.10.1: files: test.xml format: junit env: GOEXPERIMENT: rangefunc - - label: ':codecov: + :codeclimate: Coverage' - commands: - - go test -race -coverprofile=cover.out ./... - - sh .buildkite/upload_coverage.sh cover.out - env: - GOEXPERIMENT: rangefunc diff --git a/.codeclimate.yml b/.codeclimate.yml index 275b6fd..052c19b 100644 --- a/.codeclimate.yml +++ b/.codeclimate.yml @@ -13,3 +13,6 @@ exclude_patterns: - "go.sum" - "LICENSE" - "nocopy.go" +engines: + golangci: + enabled: true diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f5fa4e4..a34bbe2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,7 +26,7 @@ jobs: - name: 🧸 golangci-lint uses: golangci/golangci-lint-action@v4 with: - version: v1.56.1 + version: v1.56.2 - name: 🔨 Test run: go test -race ./... env: diff --git a/README.md b/README.md index 422e2a3..9f337df 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ [![Maintainability](https://api.codeclimate.com/v1/badges/12a77c18122e2d1e1f6b/maintainability)](https://codeclimate.com/github/fillmore-labs/promise/maintainability) [![Go Report Card](https://goreportcard.com/badge/fillmore-labs.com/promise)](https://goreportcard.com/report/fillmore-labs.com/promise) [![License](https://img.shields.io/github/license/fillmore-labs/promise)](https://www.apache.org/licenses/LICENSE-2.0) +[![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2Ffillmore-labs%2Fpromise.svg?type=shield&issueType=license)](https://app.fossa.com/projects/git%2Bgithub.com%2Ffillmore-labs%2Fpromise?ref=badge_shield&issueType=license) The `promise` package provides interfaces and utilities for writing asynchronous code in Go. diff --git a/combine.go b/combine.go index 299b932..8d7d041 100644 --- a/combine.go +++ b/combine.go @@ -19,28 +19,48 @@ package promise import ( "context" "fmt" - "runtime/trace" + "reflect" "fillmore-labs.com/promise/result" ) -// List is a list of [Future], representing results of asynchronous tasks. -type List[R any] []Future[R] +// AnyFuture matches a [Future] of any type. +type AnyFuture interface { + reflect() reflect.Value +} + +// AwaitAll returns a function that yields the results of all futures. +// If the context is canceled, it returns an error for the remaining futures. +func AwaitAll[R any](ctx context.Context, futures ...Future[R]) func(yield func(int, result.Result[R]) bool) { + i := newIterator(ctx, convertValue[R], futures) -// All returns a function that yields the results of all futures. + return i.yieldTo +} + +// AwaitAllAny returns a function that yields the results of all futures. // If the context is canceled, it returns an error for the remaining futures. -func (l List[R]) All(ctx context.Context) func(yield func(int, result.Result[R]) bool) { - defer trace.StartRegion(ctx, "asyncSeq").End() - s := newIterator(ctx, l) +func AwaitAllAny(ctx context.Context, futures ...AnyFuture) func(yield func(int, result.Result[any]) bool) { + i := newIterator(ctx, convertValueAny, futures) + + return i.yieldTo +} - return s.yieldTo +// AwaitAllResults waits for all futures to complete and returns the results. +// If the context is canceled, it returns early with errors for the remaining futures. +func AwaitAllResults[R any](ctx context.Context, futures ...Future[R]) []result.Result[R] { + return awaitAllResults(len(futures), AwaitAll(ctx, futures...)) } -// AwaitAll waits for all futures to complete and returns the results. +// AwaitAllResultsAny waits for all futures to complete and returns the results. // If the context is canceled, it returns early with errors for the remaining futures. -func (l List[R]) AwaitAll(ctx context.Context) []result.Result[R] { - results := make([]result.Result[R], len(l)) - l.All(ctx)(func(i int, r result.Result[R]) bool { +func AwaitAllResultsAny(ctx context.Context, futures ...AnyFuture) []result.Result[any] { + return awaitAllResults(len(futures), AwaitAllAny(ctx, futures...)) +} + +func awaitAllResults[R any](n int, iter func(yield func(int, result.Result[R]) bool)) []result.Result[R] { + results := make([]result.Result[R], n) + + iter(func(i int, r result.Result[R]) bool { results[i] = r return true @@ -51,10 +71,21 @@ func (l List[R]) AwaitAll(ctx context.Context) []result.Result[R] { // AwaitAllValues returns the values of completed futures. // If any future fails or the context is canceled, it returns early with an error. -func (l List[R]) AwaitAllValues(ctx context.Context) ([]R, error) { - results := make([]R, len(l)) +func AwaitAllValues[R any](ctx context.Context, futures ...Future[R]) ([]R, error) { + return awaitAllValues(len(futures), AwaitAll(ctx, futures...)) +} + +// AwaitAllValuesAny returns the values of completed futures. +// If any future fails or the context is canceled, it returns early with an error. +func AwaitAllValuesAny(ctx context.Context, futures ...AnyFuture) ([]any, error) { + return awaitAllValues(len(futures), AwaitAllAny(ctx, futures...)) +} + +func awaitAllValues[R any](n int, iter func(yield func(int, result.Result[R]) bool)) ([]R, error) { + results := make([]R, n) var yieldErr error - l.All(ctx)(func(i int, r result.Result[R]) bool { + + iter(func(i int, r result.Result[R]) bool { if r.Err() != nil { yieldErr = fmt.Errorf("list AwaitAllValues result %d: %w", i, r.Err()) @@ -70,13 +101,25 @@ func (l List[R]) AwaitAllValues(ctx context.Context) ([]R, error) { // AwaitFirst returns the result of the first completed future. // If the context is canceled, it returns early with an error. -func (l List[R]) AwaitFirst(ctx context.Context) (R, error) { +func AwaitFirst[R any](ctx context.Context, futures ...Future[R]) (R, error) { + return awaitFirst(AwaitAll(ctx, futures...)) +} + +// AwaitFirstAny returns the result of the first completed future. +// If the context is canceled, it returns early with an error. +func AwaitFirstAny(ctx context.Context, futures ...AnyFuture) (any, error) { + return awaitFirst(AwaitAllAny(ctx, futures...)) +} + +func awaitFirst[R any](iter func(yield func(int, result.Result[R]) bool)) (R, error) { var v result.Result[R] - l.All(ctx)(func(_ int, r result.Result[R]) bool { + + iter(func(_ int, r result.Result[R]) bool { v = r return false }) + if v == nil { return *new(R), ErrNoResult } diff --git a/combine_all_test.go b/combine_all_test.go index b5efa22..e98f1e1 100644 --- a/combine_all_test.go +++ b/combine_all_test.go @@ -42,16 +42,15 @@ func TestAll(t *testing.T) { } for i, v := range values { - value, err := v.value, v.err - go promises[i].Do(func() (int, error) { return value, err }) + promises[i].Do(func() (int, error) { return v.value, v.err }) } ctx, cancel := context.WithCancel(context.Background()) defer cancel() // when - var results [3]result.Result[int] - for i, r := range futures.All(ctx) { //nolint:typecheck + results := make([]result.Result[int], len(futures)) + for i, r := range promise.AwaitAll(ctx, futures...) { //nolint:typecheck results[i] = r } @@ -74,13 +73,50 @@ func TestAllEmpty(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - var futures promise.List[int] - // when - allFutures := futures.All(ctx) + allFutures := promise.AwaitAllResults[int](ctx) // then + assert.Zero(t, len(allFutures)) for _, v := range allFutures { //nolint:typecheck t.Errorf("Invalid value %v", v) } } + +func TestAnyAll(t *testing.T) { + t.Parallel() + + // given + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p1, f1 := promise.New[int]() + p2, f2 := promise.New[string]() + p3, f3 := promise.New[struct{}]() + + p1.Resolve(1) + p2.Resolve("test") + p3.Resolve(struct{}{}) + + // when + results := make([]result.Result[any], 3) + for i, r := range promise.AwaitAllAny(ctx, f1, f2, f3) { //nolint:typecheck + results[i] = r + } + + // then + for i, r := range results { + if assert.NoError(t, r.Err()) { + switch i { + case 0: + assert.Equal(t, 1, r.Value()) + case 1: + assert.Equal(t, "test", r.Value()) + case 2: + assert.Equal(t, struct{}{}, r.Value()) + default: + assert.Fail(t, "unexpected index") + } + } + } +} diff --git a/combine_test.go b/combine_test.go index 0e18cfc..2472276 100644 --- a/combine_test.go +++ b/combine_test.go @@ -27,7 +27,7 @@ import ( const iterations = 3 -func makePromisesAndFutures[R any]() ([]promise.Promise[R], promise.List[R]) { +func makePromisesAndFutures[R any]() ([]promise.Promise[R], []promise.Future[R]) { var promises [iterations]promise.Promise[R] var futures [iterations]promise.Future[R] @@ -50,7 +50,7 @@ func TestWaitAll(t *testing.T) { // when ctx := context.Background() - results := futures.AwaitAll(ctx) + results := promise.AwaitAllResults(ctx, futures...) // then assert.Len(t, results, len(futures)) @@ -76,7 +76,7 @@ func TestAllValues(t *testing.T) { // when ctx := context.Background() - results, err := futures.AwaitAllValues(ctx) + results, err := promise.AwaitAllValues(ctx, futures...) // then if assert.NoError(t, err) { @@ -96,7 +96,7 @@ func TestAllValuesError(t *testing.T) { // when ctx := context.Background() - _, err := futures.AwaitAllValues(ctx) + _, err := promise.AwaitAllValues(ctx, futures...) // then assert.ErrorIs(t, err, errTest) @@ -111,11 +111,11 @@ func TestFirst(t *testing.T) { // when ctx := context.Background() - result, err := futures.AwaitFirst(ctx) + v, err := promise.AwaitFirst(ctx, futures...) // then if assert.NoError(t, err) { - assert.Equal(t, 2, result) + assert.Equal(t, 2, v) } } @@ -124,20 +124,20 @@ func TestCombineCancellation(t *testing.T) { subTests := []struct { name string - combine func(promise.List[int], context.Context) error + combine func([]promise.Future[int], context.Context) error }{ - {name: "First", combine: func(futures promise.List[int], ctx context.Context) error { - _, err := futures.AwaitFirst(ctx) + {name: "First", combine: func(futures []promise.Future[int], ctx context.Context) error { + _, err := promise.AwaitFirst(ctx, futures...) return err }}, - {name: "All", combine: func(futures promise.List[int], ctx context.Context) error { - r := futures.AwaitAll(ctx) + {name: "All", combine: func(futures []promise.Future[int], ctx context.Context) error { + r := promise.AwaitAllResults(ctx, futures...) return r[0].Err() }}, - {name: "AllValues", combine: func(futures promise.List[int], ctx context.Context) error { - _, err := futures.AwaitAllValues(ctx) + {name: "AllValues", combine: func(futures []promise.Future[int], ctx context.Context) error { + _, err := promise.AwaitAllValues(ctx, futures...) return err }}, @@ -171,14 +171,14 @@ func TestCombineMemoized(t *testing.T) { subTests := []struct { name string - combine func(promise.List[int], context.Context) (any, error) + combine func(context.Context, []promise.Future[int]) (any, error) expect func(t *testing.T, actual any) }{ - {name: "First", combine: func(futures promise.List[int], ctx context.Context) (any, error) { - return futures.AwaitFirst(ctx) + {name: "First", combine: func(ctx context.Context, futures []promise.Future[int]) (any, error) { + return promise.AwaitFirst(ctx, futures...) }, expect: func(t *testing.T, actual any) { t.Helper(); assert.Equal(t, 3, actual) }}, - {name: "All", combine: func(futures promise.List[int], ctx context.Context) (any, error) { - return futures.AwaitAll(ctx), nil + {name: "All", combine: func(ctx context.Context, futures []promise.Future[int]) (any, error) { + return promise.AwaitAllResults(ctx, futures...), nil }, expect: func(t *testing.T, actual any) { t.Helper() vv, ok := actual.([]result.Result[int]) @@ -195,8 +195,8 @@ func TestCombineMemoized(t *testing.T) { } } }}, - {name: "AllValues", combine: func(futures promise.List[int], ctx context.Context) (any, error) { - return futures.AwaitAllValues(ctx) + {name: "AllValues", combine: func(ctx context.Context, futures []promise.Future[int]) (any, error) { + return promise.AwaitAllValues(ctx, futures...) }, expect: func(t *testing.T, actual any) { t.Helper(); assert.Equal(t, []int{3, 3, 3}, actual) }}, } @@ -213,16 +213,16 @@ func TestCombineMemoized(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - for _, promise := range promises { - promise.Resolve(3) + for _, p := range promises { + p.Resolve(3) } // when - result, err := combine(futures, ctx) + v, err := combine(ctx, futures) // then if assert.NoError(t, err) { - expect(t, result) + expect(t, v) } }) } @@ -231,28 +231,28 @@ func TestCombineMemoized(t *testing.T) { func TestAwaitAllEmpty(t *testing.T) { t.Parallel() + // given ctx, cancel := context.WithCancel(context.Background()) defer cancel() - var futures promise.List[int] - // when - results := futures.AwaitAll(ctx) + results := promise.AwaitAllResultsAny(ctx) + // then assert.Empty(t, results) } func TestAwaitAllValuesEmpty(t *testing.T) { t.Parallel() + // given ctx, cancel := context.WithCancel(context.Background()) defer cancel() - var futures promise.List[int] - // when - results, err := futures.AwaitAllValues(ctx) + results, err := promise.AwaitAllValuesAny(ctx) + // then if assert.NoError(t, err) { assert.Empty(t, results) } @@ -261,13 +261,76 @@ func TestAwaitAllValuesEmpty(t *testing.T) { func TestAwaitFirstEmpty(t *testing.T) { t.Parallel() + // given ctx, cancel := context.WithCancel(context.Background()) defer cancel() - var futures promise.List[int] - // when - _, err := futures.AwaitFirst(ctx) + _, err := promise.AwaitFirstAny(ctx) + // then assert.ErrorIs(t, err, promise.ErrNoResult) } + +func TestAllAny(t *testing.T) { + // given + t.Parallel() + ctx := context.Background() + + p1, f1 := promise.New[int]() + p2, f2 := promise.New[string]() + p3, f3 := promise.New[struct{}]() + + p1.Resolve(1) + p2.Resolve("test") + p3.Resolve(struct{}{}) + + // when + results := make([]result.Result[any], 3) + promise.AwaitAllAny(ctx, f1, f2, f3)(func(i int, r result.Result[any]) bool { + results[i] = r + + return true + }) + + // then + for i, r := range results { + if assert.NoError(t, r.Err()) { + switch i { + case 0: + assert.Equal(t, 1, r.Value()) + case 1: + assert.Equal(t, "test", r.Value()) + case 2: + assert.Equal(t, struct{}{}, r.Value()) + default: + assert.Fail(t, "unexpected index") + } + } + } +} + +func TestAllNil(t *testing.T) { + // given + t.Parallel() + ctx := context.Background() + + p1, f1 := promise.New[struct{}]() + p1 <- nil + + // when + var v result.Result[any] + var set bool + promise.AwaitAllAny(ctx, f1)(func(_ int, r result.Result[any]) bool { + if set { + assert.Fail(t, "Value already set") + } + v = r + set = true + + return false + }) + + // then + assert.ErrorIs(t, v.Err(), promise.ErrNoResult) +} diff --git a/future.go b/future.go index 6e3f4ac..7a04aaa 100644 --- a/future.go +++ b/future.go @@ -45,7 +45,7 @@ func NewAsync[R any](fn func() (R, error)) Future[R] { func (f Future[R]) Await(ctx context.Context) (R, error) { select { case r, ok := <-f: - if !ok { + if !ok || r == nil { return *new(R), ErrNoResult } @@ -60,7 +60,7 @@ func (f Future[R]) Await(ctx context.Context) (R, error) { func (f Future[R]) Try() (R, error) { select { case r, ok := <-f: - if !ok { + if !ok || r == nil { return *new(R), ErrNoResult } diff --git a/future_test.go b/future_test.go index 310079e..d7cedde 100644 --- a/future_test.go +++ b/future_test.go @@ -112,7 +112,32 @@ func (s *FutureTestSuite) TestTry() { s.ErrorIs(err3, promise.ErrNoResult) } -func TestAsyncFuture(t *testing.T) { +func (s *FutureTestSuite) TestNil() { + // given + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.promise <- nil + + // when + _, err := s.future.Await(ctx) + + // then + s.ErrorIs(err, promise.ErrNoResult) +} + +func (s *FutureTestSuite) TestTryNil() { + // given + s.promise <- nil + + // when + _, err := s.future.Try() + + // then + s.ErrorIs(err, promise.ErrNoResult) +} + +func TestAsync(t *testing.T) { t.Parallel() // given diff --git a/go.mod b/go.mod index f846e98..d671049 100644 --- a/go.mod +++ b/go.mod @@ -2,15 +2,14 @@ module fillmore-labs.com/promise go 1.21 -toolchain go1.22.0 +toolchain go1.22.1 -require github.com/stretchr/testify v1.8.4 +require github.com/stretchr/testify v1.9.0 require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/kr/pretty v0.1.0 // indirect - github.com/kr/text v0.2.0 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 193bd00..c18bab0 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +1,22 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/iterator.go b/iterator.go index 3068d36..810322c 100644 --- a/iterator.go +++ b/iterator.go @@ -20,18 +20,23 @@ import ( "context" "fmt" "reflect" + "runtime/trace" "fillmore-labs.com/promise/result" ) // This iterator is used to combine the results of multiple asynchronous operations waiting in parallel. type iterator[R any] struct { - numFutures int - cases []reflect.SelectCase - ctxErr func() error + _ noCopy + numFutures int + cases []reflect.SelectCase + convertValue func(recv reflect.Value, ok bool) result.Result[R] + ctx context.Context //nolint:containedctx } -func newIterator[R any](ctx context.Context, l List[R]) *iterator[R] { +func newIterator[R any, F AnyFuture]( + ctx context.Context, convertValue func(recv reflect.Value, ok bool) result.Result[R], l []F, +) iterator[R] { numFutures := len(l) cases := make([]reflect.SelectCase, numFutures+1) for idx, future := range l { @@ -45,37 +50,52 @@ func newIterator[R any](ctx context.Context, l List[R]) *iterator[R] { Chan: reflect.ValueOf(ctx.Done()), } - return &iterator[R]{ - numFutures: numFutures, - cases: cases, - ctxErr: func() error { return context.Cause(ctx) }, + return iterator[R]{ + numFutures: numFutures, + cases: cases, + convertValue: convertValue, + ctx: ctx, } } func (i *iterator[R]) yieldTo(yield func(int, result.Result[R]) bool) { + defer trace.StartRegion(i.ctx, "promiseSeq").End() for run := 0; run < i.numFutures; run++ { chosen, recv, ok := reflect.Select(i.cases) if chosen == i.numFutures { // context channel - err := fmt.Errorf("list yield canceled: %w", i.ctxErr()) + err := fmt.Errorf("list yield canceled: %w", context.Cause(i.ctx)) i.yieldErr(yield, err) break } i.cases[chosen].Chan = reflect.Value{} // Disable case + v := i.convertValue(recv, ok) + if !yield(chosen, v) { + break + } + } +} - var v result.Result[R] - if ok { - v, _ = recv.Interface().(result.Result[R]) - } else { - v = result.OfError[R](ErrNoResult) +func convertValue[R any](recv reflect.Value, ok bool) result.Result[R] { + if ok { + if r, ok2 := recv.Interface().(result.Result[R]); ok2 { + return r } + } - if !yield(chosen, v) { - break + return result.OfError[R](ErrNoResult) +} + +func convertValueAny(recv reflect.Value, ok bool) result.Result[any] { + if ok { + if a, ok2 := recv.Interface().(result.AnyResult); ok2 { + return a.Any() } } + + return result.OfError[any](ErrNoResult) } func (i *iterator[R]) yieldErr(yield func(int, result.Result[R]) bool, err error) { diff --git a/memoizer.go b/memoizer.go index 987f25a..b2b9f5f 100644 --- a/memoizer.go +++ b/memoizer.go @@ -25,8 +25,7 @@ import ( // A Memoizer is created with [Future.Memoize] and contains a memoized result of a future. type Memoizer[R any] struct { - _ noCopy - + _ noCopy wait chan struct{} value result.Result[R] future Future[R] @@ -59,7 +58,7 @@ func (m *Memoizer[R]) Await(ctx context.Context) (R, error) { select { case v, ok := <-m.future: - if ok { + if ok && v != nil { m.value = v } else { m.value = result.OfError[R](ErrNoResult) @@ -89,7 +88,7 @@ func (m *Memoizer[R]) Try() (R, error) { select { case v, ok := <-m.future: - if ok { + if ok && v != nil { m.value = v } else { m.value = result.OfError[R](ErrNoResult) diff --git a/memoizer_test.go b/memoizer_test.go index 6cfd026..a8cd893 100644 --- a/memoizer_test.go +++ b/memoizer_test.go @@ -129,12 +129,12 @@ func TestMemoizerCancel(t *testing.T) { time.Sleep(time.Millisecond) f2 := promise.NewAsync(fn2) - cancel1() - _, err1 := f1.Await(ctx) - _, err2 := f2.Try() - cancel2() - _, err3 := f2.Await(ctx) + _, err1 := f2.Await(ctx) + _, err2 := f1.Try() + + cancel1() + _, err3 := f1.Await(ctx) // then assert.ErrorIs(t, err1, context.Canceled) @@ -147,13 +147,12 @@ func TestMemoizerClosed(t *testing.T) { // given p, f := promise.New[int]() + close(p) + m := f.Memoize() ctx := context.Background() - p.Resolve(1) - _, _ = f.Await(ctx) - // when _, err := m.Await(ctx) @@ -166,12 +165,31 @@ func TestMemoizerTryClosed(t *testing.T) { // given p, f := promise.New[int]() + close(p) + m := f.Memoize() ctx := context.Background() - p.Resolve(1) - _, _ = f.Await(ctx) + // when + _, err1 := m.Try() + _, err2 := m.Await(ctx) + + // then + assert.ErrorIs(t, err1, promise.ErrNoResult) + assert.ErrorIs(t, err2, promise.ErrNoResult) +} + +func TestMemoizerNil(t *testing.T) { + t.Parallel() + + // given + p, f := promise.New[int]() + p <- nil + + m := f.Memoize() + + ctx := context.Background() // when _, err1 := m.Try() diff --git a/result/result.go b/result/result.go index ad8a9dc..ff32f3f 100644 --- a/result/result.go +++ b/result/result.go @@ -19,10 +19,14 @@ package result // Result defines the interface for returning results from asynchronous operations. // It encapsulates the final value or error from the operation. type Result[R any] interface { + AnyResult V() (R, error) // The V method returns the final value or an error. Value() R // The Value method returns the final value. Err() error // The Err method returns the error. +} +// AnyResult can be used with any [Result]. +type AnyResult interface { Any() Result[any] // The Any method returns a Result[any] that can be used with any type. } @@ -65,7 +69,7 @@ func (v valueResult[_]) Err() error { return nil } -// AnyResult returns the valueResult as a Result[any]. +// Any returns the valueResult as a Result[any]. func (v valueResult[_]) Any() Result[any] { return valueResult[any]{value: v.value} } @@ -90,7 +94,7 @@ func (e errorResult[_]) Err() error { return e.err } -// AnyResult returns the errorResult as a Result[any]. +// Any returns the errorResult as a Result[any]. func (e errorResult[_]) Any() Result[any] { return errorResult[any](e) } diff --git a/result/result_test.go b/result/result_test.go index 48d82ec..59e1318 100644 --- a/result/result_test.go +++ b/result/result_test.go @@ -29,7 +29,7 @@ var errTest = errors.New("test error") func TestV(t *testing.T) { t.Parallel() // given - r := result.Of(1, nil) + r := result.OfValue(1) // when v, err := r.V() // then @@ -41,13 +41,37 @@ func TestV(t *testing.T) { func TestVErr(t *testing.T) { t.Parallel() // given - r := result.Of(0, errTest) + r := result.OfError[struct{}](errTest) // when _, err := r.V() // then assert.ErrorIs(t, err, errTest) } +func TestOf(t *testing.T) { + t.Parallel() + // given + r := result.Of(1, nil) + // when + v := r.Value() + err := r.Err() + // then + if assert.NoError(t, err) { + assert.Equal(t, 1, v) + } +} + +func TestOfErr(t *testing.T) { + t.Parallel() + // given + r := result.Of(1, errTest) + // when + _ = r.Value() // doesn't panic + err := r.Err() + // then + assert.ErrorIs(t, err, errTest) +} + func TestAny(t *testing.T) { t.Parallel() // given