Skip to content

Commit

Permalink
Merge pull request #165 from conductor-sdk/feature/graceful-shutdown
Browse files Browse the repository at this point in the history
Task Runner graceful shutdown
  • Loading branch information
jmigueprieto authored Feb 3, 2025
2 parents 05ae15f + 59a2ce3 commit 504d773
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
21 changes: 20 additions & 1 deletion sdk/worker/task_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,31 @@ func (c *TaskRunner) Resume(taskName string) {
c.pausedWorkers[taskName] = false
}

// Shutdown the TaskRunner will stop polling for tasks and once all running workers are done,
// a signal will be sent to the WaitGroup to indicate that this worker has completed its work.
// When used in conjunction with TaskRunner.WaitWorkers() it allows a graceful shutdown.
func (c *TaskRunner) Shutdown(taskName string) {
c.batchSizeByTaskNameMutex.Lock()
delete(c.batchSizeByTaskName, taskName)
c.batchSizeByTaskNameMutex.Unlock()

c.pausedWorkersMutex.Lock()
delete(c.pausedWorkers, taskName)
c.pausedWorkersMutex.Unlock()

c.pollIntervalByTaskNameMutex.Lock()
delete(c.pollIntervalByTaskName, taskName)
c.pollIntervalByTaskNameMutex.Unlock()
}

func (c *TaskRunner) isPaused(taskName string) bool {
c.pausedWorkersMutex.RLock()
defer c.pausedWorkersMutex.RUnlock()
return c.pausedWorkers[taskName]
}

// WaitWorkers uses an internal waitgroup to block the calling thread until all workers started by this TaskRunner have
// been stopped.
// been shut down.
func (c *TaskRunner) WaitWorkers() {
c.workerWaitGroup.Wait()
}
Expand Down Expand Up @@ -469,6 +486,7 @@ func (c *TaskRunner) increaseRunningWorkers(taskName string) error {
c.runningWorkersByTaskNameMutex.Lock()
defer c.runningWorkersByTaskNameMutex.Unlock()
c.runningWorkersByTaskName[taskName] += 1
c.workerWaitGroup.Add(1)
log.Trace("Increased running workers for task: ", taskName)
return nil
}
Expand All @@ -477,6 +495,7 @@ func (c *TaskRunner) runningWorkerDone(taskName string) error {
c.runningWorkersByTaskNameMutex.Lock()
defer c.runningWorkersByTaskNameMutex.Unlock()
c.runningWorkersByTaskName[taskName] -= 1
c.workerWaitGroup.Done()
log.Trace("Running worker done for task: ", taskName)
return nil
}
Expand Down
51 changes: 49 additions & 2 deletions test/unit_tests/worker_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
package unit_tests

import (
"testing"
"time"

"github.com/conductor-sdk/conductor-go/sdk/client"
"github.com/conductor-sdk/conductor-go/sdk/model"
"github.com/conductor-sdk/conductor-go/sdk/settings"
"github.com/conductor-sdk/conductor-go/sdk/worker"
"github.com/stretchr/testify/assert"
"testing"
"time"
)

func TestSimpleTaskRunner(t *testing.T) {
Expand Down Expand Up @@ -75,6 +76,52 @@ func TestPauseResume(t *testing.T) {

}

func TestShutown(t *testing.T) {
authenticationSettings := settings.NewAuthenticationSettings(
"keyId",
"keySecret",
)
apiClient := client.NewAPIClient(
authenticationSettings,
settings.NewHttpDefaultSettings(),
)
taskRunner := worker.NewTaskRunnerWithApiClient(
apiClient,
)
taskRunner.StartWorker("test_shutdown1", TaskWorker, 4, time.Second)
taskRunner.StartWorker("test_shutdown2", TaskWorker, 4, time.Second)

start := time.Now()
go func() {
time.Sleep(3 * time.Second)
taskRunner.Shutdown("test_shutdown1")
taskRunner.Shutdown("test_shutdown2")
}()

taskRunner.WaitWorkers()
elapsed := time.Since(start)
assert.GreaterOrEqual(t, elapsed.Seconds(), 2.9)

assert.Equal(t, 0, taskRunner.GetBatchSizeForTask("test_shutdown1"))
assert.Equal(t, 0, taskRunner.GetBatchSizeForTask("test_shutdown2"))

err := taskRunner.IncreaseBatchSize("test_shutdown1", 1)
assert.NotNil(t, err)
assert.Equal(t, "no worker registered for taskName: test_shutdown1", err.Error())

err = taskRunner.IncreaseBatchSize("test_shutdown2", 1)
assert.NotNil(t, err)
assert.Equal(t, "no worker registered for taskName: test_shutdown2", err.Error())

pollInteval, err := taskRunner.GetPollIntervalForTask("test_shutdown1")
assert.Equal(t, time.Duration(0), pollInteval)
assert.Equal(t, "poll interval not registered for task: test_shutdown1", err.Error())

pollInteval, err = taskRunner.GetPollIntervalForTask("test_shutdown2")
assert.Equal(t, time.Duration(0), pollInteval)
assert.Equal(t, "poll interval not registered for task: test_shutdown2", err.Error())
}

func TaskWorker(task *model.Task) (interface{}, error) {
return map[string]interface{}{
"zip": "10121",
Expand Down

0 comments on commit 504d773

Please sign in to comment.