diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a34bbe2..1c46343 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,7 +26,7 @@ jobs: - name: 🧸 golangci-lint uses: golangci/golangci-lint-action@v4 with: - version: v1.56.2 + version: v1.57.1 - name: 🔨 Test run: go test -race ./... env: diff --git a/.gitignore b/.gitignore index fca0085..c6b4c04 100644 --- a/.gitignore +++ b/.gitignore @@ -6,8 +6,9 @@ !/.github/ !/.gitignore !/.golangci.yaml -!/.markdownlint.json +!/.markdownlint.yaml !/.mockery.yaml +!/.prettierrc.yaml !/.yamlfmt !/.yamllint /bin/ diff --git a/.markdownlint.json b/.markdownlint.json deleted file mode 100644 index 1504898..0000000 --- a/.markdownlint.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "no-hard-tabs": { - "ignore_code_languages": [ - "go" - ], - "spaces_per_tab": 4 - }, - "line-length": { - "line_length": 120 - } -} diff --git a/.markdownlint.yaml b/.markdownlint.yaml new file mode 100644 index 0000000..06eee16 --- /dev/null +++ b/.markdownlint.yaml @@ -0,0 +1,7 @@ +--- +no-hard-tabs: + ignore_code_languages: + - go + spaces_per_tab: 4 +line-length: + line_length: 120 diff --git a/.prettierrc.yaml b/.prettierrc.yaml new file mode 100644 index 0000000..0769836 --- /dev/null +++ b/.prettierrc.yaml @@ -0,0 +1,9 @@ +--- +printWidth: 120 +proseWrap: always +tabWidth: 4 +useTabs: false +overrides: + - files: "*.md" + options: + tabWidth: 2 diff --git a/README.md b/README.md index 9f337df..c21c27f 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Maintainability](https://api.codeclimate.com/v1/badges/12a77c18122e2d1e1f6b/maintainability)](https://codeclimate.com/github/fillmore-labs/promise/maintainability) [![Go Report Card](https://goreportcard.com/badge/fillmore-labs.com/promise)](https://goreportcard.com/report/fillmore-labs.com/promise) [![License](https://img.shields.io/github/license/fillmore-labs/promise)](https://www.apache.org/licenses/LICENSE-2.0) -[![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2Ffillmore-labs%2Fpromise.svg?type=shield&issueType=license)](https://app.fossa.com/projects/git%2Bgithub.com%2Ffillmore-labs%2Fpromise?ref=badge_shield&issueType=license) +[![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2Ffillmore-labs%2Fpromise.svg?type=shield&issueType=license)](https://app.fossa.com/projects/git%2Bgithub.com%2Ffillmore-labs%2Fpromise) The `promise` package provides interfaces and utilities for writing asynchronous code in Go. @@ -16,19 +16,18 @@ The `promise` package provides interfaces and utilities for writing asynchronous Promises and futures are constructs used for asynchronous and concurrent programming, allowing developers to work with values that may not be immediately available and can be evaluated in a different execution context. -Go is known for its built-in concurrency features like goroutines and channels. -The select statement further allows for efficient multiplexing and synchronization of multiple channels, thereby -enabling developers to coordinate and orchestrate asynchronous operations effectively. -Additionally, the context package offers a standardized way to manage cancellation, deadlines, and timeouts within -concurrent and asynchronous code. +Go is known for its built-in concurrency features like goroutines and channels. The select statement further allows for +efficient multiplexing and synchronization of multiple channels, thereby enabling developers to coordinate and +orchestrate asynchronous operations effectively. Additionally, the context package offers a standardized way to manage +cancellation, deadlines, and timeouts within concurrent and asynchronous code. On the other hand, Go's error handling mechanism, based on explicit error values returned from functions, provides a clear and concise way to handle errors. -The purpose of this package is to provide a library which simplifies the integration of concurrent -code while providing a cohesive strategy for handling asynchronous errors. -By adhering to Go's standard conventions for asynchronous and concurrent code, as well as error propagation, this -package aims to enhance developer productivity and code reliability in scenarios requiring asynchronous operations. +The purpose of this package is to provide a library which simplifies the integration of concurrent code while providing +a cohesive strategy for handling asynchronous errors. By adhering to Go's standard conventions for asynchronous and +concurrent code, as well as error propagation, this package aims to enhance developer productivity and code reliability +in scenarios requiring asynchronous operations. ## Usage diff --git a/combine.go b/combine.go index 8d7d041..b488abc 100644 --- a/combine.go +++ b/combine.go @@ -20,8 +20,6 @@ import ( "context" "fmt" "reflect" - - "fillmore-labs.com/promise/result" ) // AnyFuture matches a [Future] of any type. @@ -31,36 +29,32 @@ type AnyFuture interface { // AwaitAll returns a function that yields the results of all futures. // If the context is canceled, it returns an error for the remaining futures. -func AwaitAll[R any](ctx context.Context, futures ...Future[R]) func(yield func(int, result.Result[R]) bool) { - i := newIterator(ctx, convertValue[R], futures) - - return i.yieldTo +func AwaitAll[R any](ctx context.Context, futures ...Future[R]) func(yield func(int, Result[R]) bool) { + return newIterator(ctx, convertValue[R], futures) } // AwaitAllAny returns a function that yields the results of all futures. // If the context is canceled, it returns an error for the remaining futures. -func AwaitAllAny(ctx context.Context, futures ...AnyFuture) func(yield func(int, result.Result[any]) bool) { - i := newIterator(ctx, convertValueAny, futures) - - return i.yieldTo +func AwaitAllAny(ctx context.Context, futures ...AnyFuture) func(yield func(int, Result[any]) bool) { + return newIterator(ctx, convertValueAny, futures) } // AwaitAllResults waits for all futures to complete and returns the results. // If the context is canceled, it returns early with errors for the remaining futures. -func AwaitAllResults[R any](ctx context.Context, futures ...Future[R]) []result.Result[R] { +func AwaitAllResults[R any](ctx context.Context, futures ...Future[R]) []Result[R] { return awaitAllResults(len(futures), AwaitAll(ctx, futures...)) } // AwaitAllResultsAny waits for all futures to complete and returns the results. // If the context is canceled, it returns early with errors for the remaining futures. -func AwaitAllResultsAny(ctx context.Context, futures ...AnyFuture) []result.Result[any] { +func AwaitAllResultsAny(ctx context.Context, futures ...AnyFuture) []Result[any] { return awaitAllResults(len(futures), AwaitAllAny(ctx, futures...)) } -func awaitAllResults[R any](n int, iter func(yield func(int, result.Result[R]) bool)) []result.Result[R] { - results := make([]result.Result[R], n) +func awaitAllResults[R any](n int, iter func(yield func(int, Result[R]) bool)) []Result[R] { + results := make([]Result[R], n) - iter(func(i int, r result.Result[R]) bool { + iter(func(i int, r Result[R]) bool { results[i] = r return true @@ -81,17 +75,17 @@ func AwaitAllValuesAny(ctx context.Context, futures ...AnyFuture) ([]any, error) return awaitAllValues(len(futures), AwaitAllAny(ctx, futures...)) } -func awaitAllValues[R any](n int, iter func(yield func(int, result.Result[R]) bool)) ([]R, error) { +func awaitAllValues[R any](n int, iter func(yield func(int, Result[R]) bool)) ([]R, error) { results := make([]R, n) var yieldErr error - iter(func(i int, r result.Result[R]) bool { - if r.Err() != nil { - yieldErr = fmt.Errorf("list AwaitAllValues result %d: %w", i, r.Err()) + iter(func(i int, r Result[R]) bool { + if r.Err != nil { + yieldErr = fmt.Errorf("list AwaitAllValues result %d: %w", i, r.Err) return false } - results[i] = r.Value() + results[i] = r.Value return true }) @@ -111,11 +105,11 @@ func AwaitFirstAny(ctx context.Context, futures ...AnyFuture) (any, error) { return awaitFirst(AwaitAllAny(ctx, futures...)) } -func awaitFirst[R any](iter func(yield func(int, result.Result[R]) bool)) (R, error) { - var v result.Result[R] +func awaitFirst[R any](iter func(yield func(int, Result[R]) bool)) (R, error) { + var v *Result[R] - iter(func(_ int, r result.Result[R]) bool { - v = r + iter(func(_ int, r Result[R]) bool { + v = &r return false }) @@ -124,5 +118,5 @@ func awaitFirst[R any](iter func(yield func(int, result.Result[R]) bool)) (R, er return *new(R), ErrNoResult } - return v.V() + return v.Value, v.Err } diff --git a/combine_all_test.go b/combine_all_test.go index e98f1e1..92a0a19 100644 --- a/combine_all_test.go +++ b/combine_all_test.go @@ -23,7 +23,6 @@ import ( "testing" "fillmore-labs.com/promise" - "fillmore-labs.com/promise/result" "github.com/stretchr/testify/assert" ) @@ -49,20 +48,18 @@ func TestAll(t *testing.T) { defer cancel() // when - results := make([]result.Result[int], len(futures)) + results := make([]promise.Result[int], len(futures)) for i, r := range promise.AwaitAll(ctx, futures...) { //nolint:typecheck results[i] = r } // then - if assert.NoError(t, results[0].Err()) { - assert.Equal(t, 1, results[0].Value()) + if assert.NoError(t, results[0].Err) { + assert.Equal(t, 1, results[0].Value) } - if assert.ErrorIs(t, results[1].Err(), errTest) { - _ = results[1].Value() // Should not panic - } - if assert.NoError(t, results[2].Err()) { - assert.Equal(t, 2, results[2].Value()) + assert.ErrorIs(t, results[1].Err, errTest) + if assert.NoError(t, results[2].Err) { + assert.Equal(t, 2, results[2].Value) } } @@ -99,21 +96,21 @@ func TestAnyAll(t *testing.T) { p3.Resolve(struct{}{}) // when - results := make([]result.Result[any], 3) + results := make([]promise.Result[any], 3) for i, r := range promise.AwaitAllAny(ctx, f1, f2, f3) { //nolint:typecheck results[i] = r } // then for i, r := range results { - if assert.NoError(t, r.Err()) { + if assert.NoError(t, r.Err) { switch i { case 0: - assert.Equal(t, 1, r.Value()) + assert.Equal(t, 1, r.Value) case 1: - assert.Equal(t, "test", r.Value()) + assert.Equal(t, "test", r.Value) case 2: - assert.Equal(t, struct{}{}, r.Value()) + assert.Equal(t, struct{}{}, r.Value) default: assert.Fail(t, "unexpected index") } diff --git a/combine_test.go b/combine_test.go index 2472276..42e659a 100644 --- a/combine_test.go +++ b/combine_test.go @@ -21,7 +21,6 @@ import ( "testing" "fillmore-labs.com/promise" - "fillmore-labs.com/promise/result" "github.com/stretchr/testify/assert" ) @@ -54,9 +53,9 @@ func TestWaitAll(t *testing.T) { // then assert.Len(t, results, len(futures)) - v0, err0 := results[0].V() - _, err1 := results[1].V() - _, err2 := results[2].V() + v0, err0 := results[0].Value, results[0].Err + err1 := results[1].Err + err2 := results[2].Err if assert.NoError(t, err0) { assert.Equal(t, 1, v0) @@ -134,7 +133,7 @@ func TestCombineCancellation(t *testing.T) { {name: "All", combine: func(futures []promise.Future[int], ctx context.Context) error { r := promise.AwaitAllResults(ctx, futures...) - return r[0].Err() + return r[0].Err }}, {name: "AllValues", combine: func(futures []promise.Future[int], ctx context.Context) error { _, err := promise.AwaitAllValues(ctx, futures...) @@ -181,7 +180,7 @@ func TestCombineMemoized(t *testing.T) { return promise.AwaitAllResults(ctx, futures...), nil }, expect: func(t *testing.T, actual any) { t.Helper() - vv, ok := actual.([]result.Result[int]) + vv, ok := actual.([]promise.Result[int]) if !ok { assert.Fail(t, "Unexpected result type") @@ -189,7 +188,7 @@ func TestCombineMemoized(t *testing.T) { } for _, v := range vv { - value, err := v.V() + value, err := v.Value, v.Err if assert.NoError(t, err) { assert.Equal(t, 3, value) } @@ -282,12 +281,12 @@ func TestAllAny(t *testing.T) { p3, f3 := promise.New[struct{}]() p1.Resolve(1) - p2.Resolve("test") + close(p2) p3.Resolve(struct{}{}) // when - results := make([]result.Result[any], 3) - promise.AwaitAllAny(ctx, f1, f2, f3)(func(i int, r result.Result[any]) bool { + results := make([]promise.Result[any], 3) + promise.AwaitAllAny(ctx, f1, f2, f3)(func(i int, r promise.Result[any]) bool { results[i] = r return true @@ -295,42 +294,19 @@ func TestAllAny(t *testing.T) { // then for i, r := range results { - if assert.NoError(t, r.Err()) { - switch i { - case 0: - assert.Equal(t, 1, r.Value()) - case 1: - assert.Equal(t, "test", r.Value()) - case 2: - assert.Equal(t, struct{}{}, r.Value()) - default: - assert.Fail(t, "unexpected index") + switch i { + case 0: + if assert.NoError(t, r.Err) { + assert.Equal(t, 1, r.Value) } + case 1: + assert.ErrorIs(t, r.Err, promise.ErrNoResult) + case 2: + if assert.NoError(t, r.Err) { + assert.Equal(t, struct{}{}, r.Value) + } + default: + assert.Fail(t, "unexpected index") } } } - -func TestAllNil(t *testing.T) { - // given - t.Parallel() - ctx := context.Background() - - p1, f1 := promise.New[struct{}]() - p1 <- nil - - // when - var v result.Result[any] - var set bool - promise.AwaitAllAny(ctx, f1)(func(_ int, r result.Result[any]) bool { - if set { - assert.Fail(t, "Value already set") - } - v = r - set = true - - return false - }) - - // then - assert.ErrorIs(t, v.Err(), promise.ErrNoResult) -} diff --git a/future.go b/future.go index 7a04aaa..027ebb5 100644 --- a/future.go +++ b/future.go @@ -20,15 +20,32 @@ import ( "context" "fmt" "reflect" - - "fillmore-labs.com/promise/result" ) +type Result[R any] struct { + Value R + Err error +} + +func NewResult[R any](value R, err error) Result[R] { + return Result[R]{ + Value: value, + Err: err, + } +} + +func (r Result[R]) Any() Result[any] { + return Result[any]{ + Value: r.Value, + Err: r.Err, + } +} + // Future represents an asynchronous operation that will complete sometime in the future. // // It is a read-only channel that can be used with [Future.Await] to retrieve the final result of a // [Promise]. -type Future[R any] <-chan result.Result[R] +type Future[R any] <-chan Result[R] // NewAsync runs fn asynchronously, immediately returning a [Future] that can be used to retrieve // the eventual result. This allows separating evaluating the result from computation. @@ -45,11 +62,11 @@ func NewAsync[R any](fn func() (R, error)) Future[R] { func (f Future[R]) Await(ctx context.Context) (R, error) { select { case r, ok := <-f: - if !ok || r == nil { + if !ok { return *new(R), ErrNoResult } - return r.V() + return r.Value, r.Err case <-ctx.Done(): return *new(R), fmt.Errorf("channel await: %w", context.Cause(ctx)) @@ -60,17 +77,24 @@ func (f Future[R]) Await(ctx context.Context) (R, error) { func (f Future[R]) Try() (R, error) { select { case r, ok := <-f: - if !ok || r == nil { + if !ok { return *new(R), ErrNoResult } - return r.V() + return r.Value, r.Err default: return *new(R), ErrNotReady } } +// Memoize returns a memoizer for the given future, consuming it in the process. +// +// The [Memoizer] can be queried multiple times from multiple goroutines. +func (f Future[R]) Memoize() *Memoizer[R] { + return NewMemoizer(f) +} + func (f Future[_]) reflect() reflect.Value { return reflect.ValueOf(f) } diff --git a/future_test.go b/future_test.go index d7cedde..126539e 100644 --- a/future_test.go +++ b/future_test.go @@ -112,31 +112,6 @@ func (s *FutureTestSuite) TestTry() { s.ErrorIs(err3, promise.ErrNoResult) } -func (s *FutureTestSuite) TestNil() { - // given - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - s.promise <- nil - - // when - _, err := s.future.Await(ctx) - - // then - s.ErrorIs(err, promise.ErrNoResult) -} - -func (s *FutureTestSuite) TestTryNil() { - // given - s.promise <- nil - - // when - _, err := s.future.Try() - - // then - s.ErrorIs(err, promise.ErrNoResult) -} - func TestAsync(t *testing.T) { t.Parallel() @@ -172,3 +147,15 @@ func TestAsyncCancel(t *testing.T) { //nolint:paralleltest assert.ErrorIs(t, err1, promise.ErrNotReady) assert.ErrorIs(t, err2, context.Canceled) } + +func TestNil(t *testing.T) { + t.Parallel() + + // given + p := promise.Promise[int](nil) + + // when + p.Resolve(1) + + // then +} diff --git a/go.mod b/go.mod index d671049..6822b2e 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,11 @@ go 1.21 toolchain go1.22.1 -require github.com/stretchr/testify v1.9.0 +require ( + github.com/stretchr/testify v1.9.0 + go.uber.org/goleak v1.3.0 + golang.org/x/sync v0.6.0 +) require ( github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index c18bab0..309a010 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,10 @@ github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZV github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/group/group.go b/group/group.go new file mode 100644 index 0000000..304b80e --- /dev/null +++ b/group/group.go @@ -0,0 +1,210 @@ +// Copyright 2023-2024 Oliver Eikemeier. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package group + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + + "fillmore-labs.com/promise" + "golang.org/x/sync/semaphore" +) + +// A Group represents a set of collaborating goroutines within a common overarching task. +// +// Should this group have a limit of active goroutines or cancel on the first error it needs to be created with [New]. +type Group struct { + sema *semaphore.Weighted + err atomic.Pointer[error] + cancel func(error) + panic atomic.Value + wg sync.WaitGroup +} + +// New creates a new [Group] with the given options. +func New(opts ...Option) *Group { + var option options + for _, opt := range opts { + opt.apply(&option) + } + + var sema *semaphore.Weighted + if option.limit > 0 { + sema = semaphore.NewWeighted(int64(option.limit)) + } + + return &Group{sema: sema, cancel: option.cancel} +} + +// options defines configurable parameters for the group. +type options struct { + cancel context.CancelCauseFunc + limit int +} + +// Option defines configurations for [New]. +type Option interface { + apply(opts *options) +} + +// WithLimit is an [Option] to configure the limit of active goroutines. +func WithLimit(limit int) Option { + if limit < 1 { + panic("limit must be at least 1") + } + + return limitOption{limit: limit} +} + +type limitOption struct { + limit int +} + +func (o limitOption) apply(opts *options) { + opts.limit = o.limit +} + +// WithCancel is an [Option] to cancel a context on the first error. +// +// cancel is a function retrieved from [context.WithCancelCause]. +func WithCancel(cancel context.CancelCauseFunc) Option { + return cancelOption{cancel: cancel} +} + +type cancelOption struct { + cancel context.CancelCauseFunc +} + +func (o cancelOption) apply(opts *options) { + opts.cancel = o.cancel +} + +// Wait blocks until all goroutines spawned from [Group.Go] and [DoAsync] are finished and +// returns the first non-nil error from them. +func (g *Group) Wait() error { + g.wg.Wait() + + if p := g.panic.Load(); p != nil { + panic(p) + } + + if err := g.err.Load(); err != nil { + return *err + } + + return nil +} + +type ExecutionError struct { + Value any +} + +func (e ExecutionError) Error() string { + return fmt.Sprintf("execution error: %v", e.Value) +} + +// DoAsync calls the given function in a new goroutine. +// +// If there is a limit on active goroutines within the group, it blocks until it can be spawned without surpassing the +// limit. If the passed context is canceled, a failed future is returned. +// +// If the group was created with [Option] [WithCancel], the first call that returns a non-nil error will cancel the +// group's context. This error will subsequently be returned by [Group.Wait]. +func DoAsync[R any](ctx context.Context, g *Group, fn func() (R, error)) (promise.Future[R], error) { + if err := g.acquire(ctx); err != nil { + return nil, fmt.Errorf("can not schedule goroutine: %w", err) + } + + p, f := promise.New[R]() + go p.Do(func() (R, error) { + defer g.release() + + value, err := exec(g, fn) + if err != nil { + g.setError(err) + } + + return value, err + }) + + return f, nil +} + +// Go calls the given function in a new goroutine. +// +// If there is a limit on active goroutines within the group, it blocks until it can be spawned without surpassing the +// limit. If the passed context is canceled, an error is returned. +// +// If the group was created with [Option] [WithCancel], the first call that returns a non-nil error will cancel the +// group's context. This error will subsequently be returned by [Group.Wait]. +func (g *Group) Go(ctx context.Context, fn func() error) error { + if err := g.acquire(ctx); err != nil { + return err + } + + go func() { + defer g.release() + + _, err := exec(g, func() (struct{}, error) { + return struct{}{}, fn() + }) + if err != nil { + g.setError(err) + } + }() + + return nil +} + +func (g *Group) acquire(ctx context.Context) error { + if g.sema != nil { + if err := g.sema.Acquire(ctx, 1); err != nil { + return fmt.Errorf("can not schedule goroutine: %w", err) + } + } + g.wg.Add(1) + + return nil +} + +func (g *Group) release() { + if g.sema != nil { + g.sema.Release(1) + } + g.wg.Done() +} + +func (g *Group) setError(err error) { + if g.err.CompareAndSwap(nil, &err) { + if g.cancel != nil { + g.cancel(err) + } + } +} + +func exec[R any](g *Group, fn func() (R, error)) (value R, err error) { + defer func() { + if r := recover(); r != nil { + _ = g.panic.CompareAndSwap(nil, r) + err = ExecutionError{Value: r} + } + }() + + return fn() +} diff --git a/group/group_test.go b/group/group_test.go new file mode 100644 index 0000000..289bebb --- /dev/null +++ b/group/group_test.go @@ -0,0 +1,159 @@ +// Copyright 2023-2024 Oliver Eikemeier. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package group_test + +import ( + "context" + "errors" + "testing" + "time" + + "fillmore-labs.com/promise/group" + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" +) + +var errTest = errors.New("test error") + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestDoAsync(t *testing.T) { + t.Parallel() + + // given + ctx, cancel := context.WithCancelCause(context.Background()) + g := group.New(group.WithCancel(cancel), group.WithLimit(1)) + + // when + f, _ := group.DoAsync(ctx, g, func() (int, error) { return 0, errTest }) + err := g.Wait() + _, errf := f.Try() + cause := context.Cause(ctx) + + // then + assert.ErrorIs(t, err, errTest) + assert.ErrorIs(t, errf, errTest) + assert.ErrorIs(t, cause, errTest) + + select { + case <-ctx.Done(): + default: + assert.Fail(t, "context should be canceled") + } +} + +func TestGroupReject(t *testing.T) { + t.Parallel() + + // given + g := group.New(group.WithLimit(1)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // when + ch := make(chan int) + f1, _ := group.DoAsync(ctx, g, func() (int, error) { + return <-ch, nil + }) + + ctx2, cancel2 := context.WithTimeout(ctx, 1*time.Millisecond) + defer cancel2() + + _, err2 := group.DoAsync(ctx2, g, func() (int, error) { return 1, nil }) + ch <- 1 + + err := g.Wait() + v1, err1 := f1.Try() + + // then + assert.NoError(t, err) + if assert.NoError(t, err1) { + assert.Equal(t, 1, v1) + } + assert.ErrorIs(t, err2, context.DeadlineExceeded) +} + +func TestGroupGoReject(t *testing.T) { + t.Parallel() + + // given + g := group.New(group.WithLimit(1)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // when + ch := make(chan error) + err1 := g.Go(ctx, func() error { + return <-ch + }) + + ctx2, cancel2 := context.WithTimeout(ctx, 1*time.Millisecond) + defer cancel2() + + err2 := g.Go(ctx2, func() error { return nil }) + ch <- errTest + + err := g.Wait() + + // then + assert.ErrorIs(t, err, errTest) + assert.NoError(t, err1) + assert.ErrorIs(t, err2, context.DeadlineExceeded) +} + +func TestGroupLimit(t *testing.T) { + t.Parallel() + + // given + defer func() { _ = recover() }() + + // when + _ = group.New(group.WithLimit(0)) + + // then + assert.Fail(t, "limit 0 should panic") +} + +func TestGroupPanic(t *testing.T) { + t.Parallel() + + // given + const mag = "test" + var g group.Group + + // when + f, _ := group.DoAsync(context.Background(), &g, func() (int, error) { panic(mag) }) + + var p any + func() { + defer func() { p = recover() }() + _ = g.Wait() + }() + + _, errf := f.Try() + + // then + + assert.ErrorContains(t, errf, mag) + var err group.ExecutionError + if assert.ErrorAs(t, errf, &err) { + assert.Equal(t, mag, err.Value) + } + assert.Equal(t, mag, p) +} diff --git a/iterator.go b/iterator.go index 810322c..536057c 100644 --- a/iterator.go +++ b/iterator.go @@ -21,22 +21,20 @@ import ( "fmt" "reflect" "runtime/trace" - - "fillmore-labs.com/promise/result" ) // This iterator is used to combine the results of multiple asynchronous operations waiting in parallel. type iterator[R any] struct { _ noCopy - numFutures int - cases []reflect.SelectCase - convertValue func(recv reflect.Value, ok bool) result.Result[R] ctx context.Context //nolint:containedctx + convertValue func(recv reflect.Value, ok bool) Result[R] + cases []reflect.SelectCase + numFutures int } func newIterator[R any, F AnyFuture]( - ctx context.Context, convertValue func(recv reflect.Value, ok bool) result.Result[R], l []F, -) iterator[R] { + ctx context.Context, convertValue func(recv reflect.Value, ok bool) Result[R], l []F, +) func(yield func(int, Result[R]) bool) { numFutures := len(l) cases := make([]reflect.SelectCase, numFutures+1) for idx, future := range l { @@ -50,15 +48,17 @@ func newIterator[R any, F AnyFuture]( Chan: reflect.ValueOf(ctx.Done()), } - return iterator[R]{ + i := iterator[R]{ numFutures: numFutures, cases: cases, convertValue: convertValue, ctx: ctx, } + + return i.yieldTo } -func (i *iterator[R]) yieldTo(yield func(int, result.Result[R]) bool) { +func (i *iterator[R]) yieldTo(yield func(int, Result[R]) bool) { defer trace.StartRegion(i.ctx, "promiseSeq").End() for run := 0; run < i.numFutures; run++ { chosen, recv, ok := reflect.Select(i.cases) @@ -78,28 +78,28 @@ func (i *iterator[R]) yieldTo(yield func(int, result.Result[R]) bool) { } } -func convertValue[R any](recv reflect.Value, ok bool) result.Result[R] { +func convertValue[R any](recv reflect.Value, ok bool) Result[R] { if ok { - if r, ok2 := recv.Interface().(result.Result[R]); ok2 { + if r, ok2 := recv.Interface().(Result[R]); ok2 { return r } } - return result.OfError[R](ErrNoResult) + return Result[R]{Err: ErrNoResult} } -func convertValueAny(recv reflect.Value, ok bool) result.Result[any] { +func convertValueAny(recv reflect.Value, ok bool) Result[any] { if ok { - if a, ok2 := recv.Interface().(result.AnyResult); ok2 { + if a, ok2 := recv.Interface().(interface{ Any() Result[any] }); ok2 { return a.Any() } } - return result.OfError[any](ErrNoResult) + return Result[any]{Err: ErrNoResult} } -func (i *iterator[R]) yieldErr(yield func(int, result.Result[R]) bool, err error) { - e := result.OfError[R](err) +func (i *iterator[R]) yieldErr(yield func(int, Result[R]) bool, err error) { + e := Result[R]{Err: err} for idx := 0; idx < i.numFutures; idx++ { if i.cases[idx].Chan.IsValid() && !yield(idx, e) { break diff --git a/memoizer.go b/memoizer.go index b2b9f5f..2eb2ebf 100644 --- a/memoizer.go +++ b/memoizer.go @@ -19,22 +19,18 @@ package promise import ( "context" "fmt" - - "fillmore-labs.com/promise/result" ) // A Memoizer is created with [Future.Memoize] and contains a memoized result of a future. type Memoizer[R any] struct { _ noCopy - wait chan struct{} - value result.Result[R] + result Result[R] future Future[R] + wait chan struct{} } -// Memoize returns a memoizer for the given future, consuming it in the process. -// -// The [Memoizer] can be queried multiple times from multiple goroutines. -func (f Future[R]) Memoize() *Memoizer[R] { +// NewMemoizer creates a new memoizer from a future. +func NewMemoizer[R any](f <-chan Result[R]) *Memoizer[R] { wait := make(chan struct{}, 1) wait <- struct{}{} @@ -49,23 +45,22 @@ func (m *Memoizer[R]) Await(ctx context.Context) (R, error) { select { case _, ok := <-m.wait: if !ok { - return m.value.V() + return m.result.Value, m.result.Err } case <-ctx.Done(): return *new(R), fmt.Errorf("memoizer canceled: %w", context.Cause(ctx)) } + var ok bool select { - case v, ok := <-m.future: - if ok && v != nil { - m.value = v - } else { - m.value = result.OfError[R](ErrNoResult) + case m.result, ok = <-m.future: + if !ok { + m.result.Err = ErrNoResult } close(m.wait) - return m.value.V() + return m.result.Value, m.result.Err case <-ctx.Done(): m.wait <- struct{}{} @@ -79,23 +74,22 @@ func (m *Memoizer[R]) Try() (R, error) { select { case _, ok := <-m.wait: if !ok { - return m.value.V() + return m.result.Value, m.result.Err } default: return *new(R), ErrNotReady } + var ok bool select { - case v, ok := <-m.future: - if ok && v != nil { - m.value = v - } else { - m.value = result.OfError[R](ErrNoResult) + case m.result, ok = <-m.future: + if !ok { + m.result.Err = ErrNoResult } close(m.wait) - return m.value.V() + return m.result.Value, m.result.Err default: m.wait <- struct{}{} diff --git a/memoizer_test.go b/memoizer_test.go index a8cd893..c42c500 100644 --- a/memoizer_test.go +++ b/memoizer_test.go @@ -23,7 +23,6 @@ import ( "time" "fillmore-labs.com/promise" - "fillmore-labs.com/promise/result" "github.com/stretchr/testify/assert" ) @@ -88,13 +87,13 @@ func TestMemoizerMany(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - results := make([]result.Result[int], iterations) + results := make([]promise.Result[int], iterations) var wg sync.WaitGroup for i := 0; i < iterations; i++ { wg.Add(1) go func(i int) { defer wg.Done() - results[i] = result.Of(m.Await(ctx)) + results[i] = promise.NewResult(m.Await(ctx)) }(i) } @@ -103,8 +102,8 @@ func TestMemoizerMany(t *testing.T) { // then for i := 0; i < iterations; i++ { - if assert.NoError(t, results[i].Err()) { - assert.Equal(t, 1, results[i].Value()) + if assert.NoError(t, results[i].Err) { + assert.Equal(t, 1, results[i].Value) } } } @@ -179,23 +178,3 @@ func TestMemoizerTryClosed(t *testing.T) { assert.ErrorIs(t, err1, promise.ErrNoResult) assert.ErrorIs(t, err2, promise.ErrNoResult) } - -func TestMemoizerNil(t *testing.T) { - t.Parallel() - - // given - p, f := promise.New[int]() - p <- nil - - m := f.Memoize() - - ctx := context.Background() - - // when - _, err1 := m.Try() - _, err2 := m.Await(ctx) - - // then - assert.ErrorIs(t, err1, promise.ErrNoResult) - assert.ErrorIs(t, err2, promise.ErrNoResult) -} diff --git a/promise.go b/promise.go index d68436f..2539aa8 100644 --- a/promise.go +++ b/promise.go @@ -16,40 +16,42 @@ package promise -import "fillmore-labs.com/promise/result" - // Promise is used to send the result of an asynchronous operation. // -// It is a write-only promise. +// It is a write-only channel. // Either [Promise.Resolve] or [Promise.Reject] should be called exactly once. -type Promise[R any] chan<- result.Result[R] +type Promise[R any] chan<- Result[R] // New provides a simple way to create a [Promise] for asynchronous operations. // This allows synchronous and asynchronous code to be composed seamlessly and separating initiation from running. // // The returned [Future] can be used to retrieve the eventual result of the [Promise]. func New[R any]() (Promise[R], Future[R]) { - ch := make(chan result.Result[R], 1) + ch := make(chan Result[R], 1) return ch, ch } // Resolve fulfills the promise with a value. func (p Promise[R]) Resolve(value R) { - p.complete(result.OfValue(value)) + p.complete(Result[R]{Value: value}) } // Reject breaks the promise with an error. func (p Promise[R]) Reject(err error) { - p.complete(result.OfError[R](err)) + p.complete(Result[R]{Err: err}) } // Do runs f synchronously, resolving the promise with the return value. func (p Promise[R]) Do(f func() (R, error)) { - p.complete(result.Of(f())) + p.complete(NewResult(f())) } -func (p Promise[R]) complete(r result.Result[R]) { +func (p Promise[R]) complete(r Result[R]) { + if p == nil { + return + } + p <- r close(p) } diff --git a/result/result.go b/result/result.go deleted file mode 100644 index ff32f3f..0000000 --- a/result/result.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2023-2024 Oliver Eikemeier. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -package result - -// Result defines the interface for returning results from asynchronous operations. -// It encapsulates the final value or error from the operation. -type Result[R any] interface { - AnyResult - V() (R, error) // The V method returns the final value or an error. - Value() R // The Value method returns the final value. - Err() error // The Err method returns the error. -} - -// AnyResult can be used with any [Result]. -type AnyResult interface { - Any() Result[any] // The Any method returns a Result[any] that can be used with any type. -} - -// Of creates a new [Result] from a pair of values. -func Of[R any](value R, err error) Result[R] { - if err != nil { - return errorResult[R]{err: err} - } - - return valueResult[R]{value: value} -} - -// OfValue creates a new [Result] from a value. -func OfValue[R any](value R) Result[R] { - return valueResult[R]{value: value} -} - -// OfError creates a new [Result] from an error. -func OfError[R any](err error) Result[R] { - return errorResult[R]{err: err} -} - -// valueResult is an implementation of [Result] that simply holds a value. -type valueResult[R any] struct { - value R -} - -// V returns the stored value. -func (v valueResult[R]) V() (R, error) { - return v.value, nil -} - -// Value returns the stored value. -func (v valueResult[R]) Value() R { - return v.value -} - -// The Err method returns nil. -func (v valueResult[_]) Err() error { - return nil -} - -// Any returns the valueResult as a Result[any]. -func (v valueResult[_]) Any() Result[any] { - return valueResult[any]{value: v.value} -} - -// errorResult handles errors from failed operations. -type errorResult[_ any] struct { - err error -} - -// V returns the stored error. -func (e errorResult[R]) V() (R, error) { - return *new(R), e.err -} - -// Value returns the null value. -func (e errorResult[R]) Value() R { - return *new(R) -} - -// Err returns the stored error. -func (e errorResult[_]) Err() error { - return e.err -} - -// Any returns the errorResult as a Result[any]. -func (e errorResult[_]) Any() Result[any] { - return errorResult[any](e) -} diff --git a/result/result_test.go b/result/result_test.go deleted file mode 100644 index 59e1318..0000000 --- a/result/result_test.go +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2023-2024 Oliver Eikemeier. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -package result_test - -import ( - "errors" - "testing" - - "fillmore-labs.com/promise/result" - "github.com/stretchr/testify/assert" -) - -var errTest = errors.New("test error") - -func TestV(t *testing.T) { - t.Parallel() - // given - r := result.OfValue(1) - // when - v, err := r.V() - // then - if assert.NoError(t, err) { - assert.Equal(t, 1, v) - } -} - -func TestVErr(t *testing.T) { - t.Parallel() - // given - r := result.OfError[struct{}](errTest) - // when - _, err := r.V() - // then - assert.ErrorIs(t, err, errTest) -} - -func TestOf(t *testing.T) { - t.Parallel() - // given - r := result.Of(1, nil) - // when - v := r.Value() - err := r.Err() - // then - if assert.NoError(t, err) { - assert.Equal(t, 1, v) - } -} - -func TestOfErr(t *testing.T) { - t.Parallel() - // given - r := result.Of(1, errTest) - // when - _ = r.Value() // doesn't panic - err := r.Err() - // then - assert.ErrorIs(t, err, errTest) -} - -func TestAny(t *testing.T) { - t.Parallel() - // given - r := result.OfValue(1) - // when - r2 := r.Any() - // then - if assert.NoError(t, r2.Err()) { - assert.Equal(t, 1, r2.Value()) - } -} - -func TestAnyErr(t *testing.T) { - t.Parallel() - // given - r := result.OfError[int](errTest) - // when - r2 := r.Any() - // then - assert.ErrorIs(t, r2.Err(), errTest) - _ = r2.Value() -}