Skip to content

Commit 98cdc05

Browse files
committed
[parallelisation] Graceful shutdown
1 parent 8671555 commit 98cdc05

File tree

3 files changed

+184
-2
lines changed

3 files changed

+184
-2
lines changed

changes/20231016114710.feature

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:sparkles: `[parallelisation]` Run action with interrupt handling

utils/parallelisation/parallelisation.go

+50-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ package parallelisation
88

99
import (
1010
"context"
11+
"golang.org/x/sync/errgroup"
12+
"os"
13+
"os/signal"
1114
"reflect"
15+
"syscall"
1216
"time"
1317

1418
"go.uber.org/atomic"
@@ -214,7 +218,7 @@ func RunActionWithTimeoutAndCancelStore(ctx context.Context, timeout time.Durati
214218
}
215219

216220
// RunActionWithParallelCheck runs an action with a check in parallel
217-
// The function performing the check should return true if the check was favorable; false otherwise. If the check did not have the expected result and the whole function would be cancelled.
221+
// The function performing the check should return true if the check was favorable; false otherwise. If the check does not have the expected result, the whole function is cancelled.
218222
func RunActionWithParallelCheck(ctx context.Context, action func(ctx context.Context) error, checkAction func(ctx context.Context) bool, checkPeriod time.Duration) error {
219223
err := DetermineContextError(ctx)
220224
if err != nil {
@@ -246,3 +250,48 @@ func RunActionWithParallelCheck(ctx context.Context, action func(ctx context.Con
246250
}
247251
return err
248252
}
253+
254+
// RunActionWithInterruptCancellation runs an action listening to interrupt signals such as SIGTERM or SIGINT
255+
// On interrupt, any cancellation functions in store are called followed by actionOnInterrupt. These functions are not called if no interrupts were raised but action completed.
256+
func RunActionWithInterruptCancellation(ctx context.Context, cancelStore *CancelFunctionStore, action func(ctx context.Context) error, actionOnInterrupt func(ctx context.Context) error) error {
257+
err := DetermineContextError(ctx)
258+
if err != nil {
259+
return err
260+
}
261+
if cancelStore == nil {
262+
cancelStore = NewCancelFunctionsStore()
263+
}
264+
defer cancelStore.Cancel()
265+
// Listening to the following interrupt signals https://www.man7.org/linux/man-pages/man7/signal.7.html
266+
interruptableCtx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGABRT)
267+
cancelStore.RegisterCancelFunction(cancel)
268+
g, groupCancellableCtx := errgroup.WithContext(ctx)
269+
groupCancellableCtx, cancelOnSuccess := context.WithCancel(groupCancellableCtx)
270+
g.Go(func() error {
271+
select {
272+
case <-interruptableCtx.Done():
273+
case <-groupCancellableCtx.Done():
274+
}
275+
err = DetermineContextError(interruptableCtx)
276+
if err != nil {
277+
// An interrupt was raised.
278+
cancelStore.Cancel()
279+
return actionOnInterrupt(ctx)
280+
}
281+
return err
282+
})
283+
g.Go(func() error {
284+
err := action(interruptableCtx)
285+
if err == nil {
286+
cancelOnSuccess()
287+
}
288+
return err
289+
})
290+
return g.Wait()
291+
}
292+
293+
// RunActionWithGracefulShutdown carries out an action until asked to gracefully shutdown on which the shutdownOnSignal is executed.
294+
// if the action is completed before the shutdown request is performed, shutdownOnSignal will not be executed.
295+
func RunActionWithGracefulShutdown(ctx context.Context, action func(ctx context.Context) error, shutdownOnSignal func(ctx context.Context) error) error {
296+
return RunActionWithInterruptCancellation(ctx, NewCancelFunctionsStore(), action, shutdownOnSignal)
297+
}

utils/parallelisation/parallelisation_test.go

+133-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import (
99
"errors"
1010
"fmt"
1111
"math/rand"
12+
"os"
1213
"reflect"
14+
"syscall"
1315
"testing"
1416
"time"
1517

@@ -20,10 +22,11 @@ import (
2022

2123
"github.com/ARM-software/golang-utils/utils/commonerrors"
2224
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
25+
"github.com/ARM-software/golang-utils/utils/platform"
2326
)
2427

2528
var (
26-
random = rand.New(rand.NewSource(time.Now().Unix())) //nolint:gosec //causes G404: Use of weak random number generator (math/rand instead of crypto/rand) (gosec), So disable gosec as this is just for
29+
random = rand.New(rand.NewSource(time.Now().Unix())) //nolint:gosec //causes G404: Use of weak random number generator (math/rand instead of crypto/rand) (gosec), So disable gosec as this is just for tests
2730
)
2831

2932
func TestParallelisationWithResults(t *testing.T) {
@@ -411,3 +414,132 @@ func runActionWithParallelCheckFailAtRandom(t *testing.T, ctx context.Context) {
411414
require.Error(t, err)
412415
errortest.AssertError(t, err, commonerrors.ErrCancelled)
413416
}
417+
418+
func TestRunActionWithGracefulShutdown(t *testing.T) {
419+
if platform.IsWindows() {
420+
// Sending Interrupt on Windows is not implemented - https://golang.org/pkg/os/#Process.Signal
421+
t.Skip("Skipping test on Windows as sending interrupt is not implemented on [this platform](https://golang.org/pkg/os/#Process.Signal)")
422+
}
423+
ctx := context.Background()
424+
425+
defer goleak.VerifyNone(t)
426+
tests := []struct {
427+
name string
428+
signal os.Signal
429+
}{
430+
{
431+
name: "SIGTERM",
432+
signal: syscall.SIGTERM,
433+
},
434+
{
435+
name: "SIGINT",
436+
signal: syscall.SIGINT,
437+
},
438+
{
439+
name: "SIGHUP",
440+
signal: syscall.SIGHUP,
441+
},
442+
{
443+
name: "SIGQUIT",
444+
signal: syscall.SIGQUIT,
445+
},
446+
{
447+
name: "SIGABRT",
448+
signal: syscall.SIGABRT,
449+
},
450+
{
451+
name: "Interrupt",
452+
signal: os.Interrupt,
453+
},
454+
}
455+
456+
process := os.Process{Pid: os.Getpid()}
457+
longAction := func(ctx context.Context) error {
458+
SleepWithContext(ctx, 150*time.Millisecond)
459+
return ctx.Err()
460+
}
461+
shortAction := func(ctx context.Context) error {
462+
return ctx.Err()
463+
}
464+
shortActionWithError := func(_ context.Context) error {
465+
return commonerrors.ErrUnexpected
466+
}
467+
468+
t.Run("cancelled context", func(t *testing.T) {
469+
defer goleak.VerifyNone(t)
470+
cctx, cancel := context.WithCancel(ctx)
471+
cancel()
472+
err := RunActionWithGracefulShutdown(cctx, longAction, func(ctx context.Context) error {
473+
return nil
474+
})
475+
require.Error(t, err)
476+
errortest.AssertError(t, err, commonerrors.ErrTimeout, commonerrors.ErrCancelled)
477+
})
478+
479+
for i := range tests {
480+
test := tests[i]
481+
t.Run(fmt.Sprintf("interrupt [%v] before longAction completion", test.name), func(t *testing.T) {
482+
defer goleak.VerifyNone(t)
483+
called := atomic.NewBool(false)
484+
shutdownAction := func(ctx2 context.Context) error {
485+
err := DetermineContextError(ctx2)
486+
if err == nil {
487+
called.Store(true)
488+
}
489+
return err
490+
}
491+
require.False(t, called.Load())
492+
ScheduleAfter(ctx, time.Duration(random.Intn(100))*time.Millisecond, func(ti time.Time) { //nolint:gosec //causes G404: Use of weak random number generator (math/rand instead of crypto/rand) (gosec), So disable gosec as this is just for tests
493+
if err := process.Signal(test.signal); err != nil {
494+
t.Error("failed sending interrupt signal")
495+
}
496+
})
497+
err := RunActionWithGracefulShutdown(ctx, longAction, shutdownAction)
498+
require.Error(t, err)
499+
errortest.AssertError(t, err, commonerrors.ErrTimeout, commonerrors.ErrCancelled)
500+
require.True(t, called.Load())
501+
})
502+
t.Run(fmt.Sprintf("interrupt [%v] after shortAction completion", test.name), func(t *testing.T) {
503+
defer goleak.VerifyNone(t)
504+
called := atomic.NewBool(false)
505+
shutdownAction := func(ctx2 context.Context) error {
506+
err := DetermineContextError(ctx2)
507+
if err == nil {
508+
called.Store(true)
509+
}
510+
return err
511+
}
512+
require.False(t, called.Load())
513+
ScheduleAfter(ctx, time.Duration(50+random.Intn(100))*time.Millisecond, func(ti time.Time) { //nolint:gosec //causes G404: Use of weak random number generator (math/rand instead of crypto/rand) (gosec), So disable gosec as this is just for tests
514+
if err := process.Signal(test.signal); err != nil {
515+
t.Error("failed sending interrupt signal")
516+
}
517+
})
518+
err := RunActionWithGracefulShutdown(ctx, shortAction, shutdownAction)
519+
require.NoError(t, err)
520+
require.False(t, called.Load())
521+
})
522+
t.Run(fmt.Sprintf("interrupt [%v] after shortActionWithError completion", test.name), func(t *testing.T) {
523+
defer goleak.VerifyNone(t)
524+
called := atomic.NewBool(false)
525+
shutdownAction := func(ctx2 context.Context) error {
526+
err := DetermineContextError(ctx2)
527+
if err == nil {
528+
called.Store(true)
529+
}
530+
return err
531+
}
532+
require.False(t, called.Load())
533+
ScheduleAfter(ctx, time.Duration(50+random.Intn(100))*time.Millisecond, func(ti time.Time) { //nolint:gosec //causes G404: Use of weak random number generator (math/rand instead of crypto/rand) (gosec), So disable gosec as this is just for tests
534+
if err := process.Signal(test.signal); err != nil {
535+
t.Error("failed sending interrupt signal")
536+
}
537+
})
538+
err := RunActionWithGracefulShutdown(ctx, shortActionWithError, shutdownAction)
539+
require.Error(t, err)
540+
errortest.AssertError(t, err, commonerrors.ErrUnexpected)
541+
require.False(t, called.Load())
542+
})
543+
}
544+
545+
}

0 commit comments

Comments
 (0)