diff --git a/.circleci/config.yml b/.circleci/config.yml
new file mode 100644
index 0000000..8ab9c5d
--- /dev/null
+++ b/.circleci/config.yml
@@ -0,0 +1,15 @@
+version: 2
+
+jobs:
+  build:
+    docker:
+      - image: circleci/golang:1.14
+
+    working_directory: /go/src/github.com/mingruimingrui/batcher
+
+    steps:
+      - checkout
+      - run: go test -v
+      - run: go test -v -bench=. -run ^$ -cpu=1
+      - run: go test -v -bench=. -run ^$ -cpu=31
+      - run: go test -v -bench=BenchmarkSendRequestParallel -run=^$ -cpu=32
diff --git a/README.md b/README.md
index 28c39f4..430aa59 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,85 @@
-# batcher
-A Go library for batching requests
+`batcher` is a package for batching individual requests into batch
+requests.
+
+For some live services, batching is necessary to maximize throughput.
+In particular, services that require disk I/O or leverage GPUs benefit
+from batching because the overhead cost associated with each request is
+large and near constant.
+Batching is typically achieved with the help of a queue system.
+
+`batcher` takes this design pattern and formalizes it into a template for
+developers to conveniently incorporate batching into their live services.
+
+- [Installation and Docs](#installation-and-docs)
+- [Usage](#usage)
+- [BatchingConfig](#batchingconfig)
+
+
+# Installation and Docs
+
+Within a module, this package can be installed with `go get`.
+
+```
+go get github.com/mingruimingrui/batcher
+```
+
+Auto-generated documentation is available at
+https://pkg.go.dev/github.com/mingruimingrui/batcher.
+
+
+# Usage
+
+Usage of this library typically begins with the creation of a
+`RequestBatcher`, usually declared at global scope.
+
+```golang
+var (
+    requestBatcher *batcher.RequestBatcher
+    ...
+)
+
+...
+
+func main() {
+    requestBatcher = batcher.NewRequestBatcher(
+        context.Background(),
+        &batcher.BatchingConfig{
+            MaxBatchSize: ...,
+            BatchTimeout: ...,
+
+            // SendF is a function that handles a batch of requests
+            SendF: func(body *[]interface{}) (*[]interface{}, error) {
+                ...
+            },
+        },
+    )
+
+    ...
+}
+```
+
+To submit a request, simply use the `SendRequestWithTimeout` method:
+
+```golang
+resp, err := requestBatcher.SendRequestWithTimeout(&someRequestBody, someTimeoutDuration)
+```
+
+When multiple requests are sent from multiple goroutines, the
+`RequestBatcher` groups them into a single batch so that `SendF` is
+called as few times as possible.
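+
+For illustration, here is a hypothetical sketch of several goroutines
+submitting requests concurrently. Assuming a `MaxBatchSize` of at least 8
+and a sub-second `BatchTimeout`, these eight requests would typically be
+served by a single call to `SendF`.
+
+```golang
+var wg sync.WaitGroup
+for i := 0; i < 8; i++ {
+    wg.Add(1)
+    go func(i int) {
+        defer wg.Done()
+        var body interface{} = i
+        resp, err := requestBatcher.SendRequestWithTimeout(&body, time.Second)
+        if err != nil {
+            log.Printf("request %d failed: %v", i, err)
+            return
+        }
+        log.Printf("request %d received %v", i, resp)
+    }(i)
+}
+wg.Wait()
+```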
+
+
+# BatchingConfig
+
+`batcher.BatchingConfig` controls batching behavior and accepts the
+following parameters.
+
+- **`MaxBatchSize`** `{int}`<br>
+  The maximum number of requests per batch.
+
+- **`BatchTimeout`** `{time.Duration}`<br>
+  The maximum wait time before a batch is sent.
+
+- **`SendF`** `{func(body *[]interface{}) (*[]interface{}, error)}`<br>
+  A user-defined function for handling a batch of requests.
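+
+For illustration, a config along these lines (the values here are
+placeholder assumptions, not tuned recommendations) echoes each batch
+back unchanged:
+
+```golang
+config := &batcher.BatchingConfig{
+    MaxBatchSize: 32,
+    BatchTimeout: 5 * time.Millisecond,
+    SendF: func(body *[]interface{}) (*[]interface{}, error) {
+        // Replace this echo with a real batch handler.
+        return body, nil
+    },
+}
+```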
diff --git a/batcher.go b/batcher.go
new file mode 100644
index 0000000..1890969
--- /dev/null
+++ b/batcher.go
@@ -0,0 +1,310 @@
+/*
+batcher is a library for batching requests.
+
+This library encapsulates the process of grouping requests into batches
+for batch processing in an asynchronous manner.
+
+This implementation is adapted from the Google batch API
+https://github.com/terraform-providers/terraform-provider-google/blob/master/google/batcher.go
+
+However, there are a number of notable differences:
+- Usage assumes 1 batcher for 1 API (instead of 1 batcher for multiple APIs)
+- Clients should only receive their own response, and not the batch response
+  their request is sent with
+- Config conventions follow the framework as defined in
+  https://github.com/tensorflow/serving/tree/master/tensorflow_serving/batching#batch-scheduling-parameters-and-tuning
+*/
+
+package batcher
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"sync"
+	"time"
+)
+
+/*
+RequestBatcher handles the receiving of new requests and all the background
+asynchronous tasks that batch and send them.
+
+A new RequestBatcher should be created using the NewRequestBatcher function.
+
+The expected usage pattern of the RequestBatcher involves declaring a
+RequestBatcher at global scope and calling SendRequestWithTimeout from
+multiple goroutines.
+*/
+type RequestBatcher struct {
+	sync.Mutex
+
+	*BatchingConfig
+	running   bool
+	parentCtx context.Context
+	curBatch  *startedBatch
+}
+
+/*
+BatchingConfig determines how batching is done in a RequestBatcher.
+*/
+type BatchingConfig struct {
+	// Maximum request size of each batch.
+	MaxBatchSize int
+
+	// Maximum wait time before a batch should be executed.
+	BatchTimeout time.Duration
+
+	// User-defined SendF for sending a batch request.
+	// See SendFunc for the type definition of this function.
+	SendF SendFunc
+}
+
+/*
+SendFunc is a function type for sending a batch of requests.
+A batch of requests is a slice of inputs to SendRequestWithTimeout.
+*/
+type SendFunc func(body *[]interface{}) (*[]interface{}, error)
+
+// startedBatch refers to a batch awaiting more requests to come in
+// before having SendFunc applied to its contents
+type startedBatch struct {
+	// Combined batch request
+	body []interface{}
+
+	// subscribers is a registry of the requests (batchSubscriber)
+	// combined to make this batch
+	subscribers []batchSubscriber
+
+	// timer for keeping track of BatchTimeout
+	timer *time.Timer
+}
+
+// singleResponse represents a single response received from SendF
+type singleResponse struct {
+	body interface{}
+	err  error
+}
+
+// batchSubscriber contains the response channel on which a singleResponse
+// is awaited
+type batchSubscriber struct {
+	// singleRequestBody is the original request this subscriber represents
+	singleRequestBody interface{}
+
+	// respCh is the channel created to communicate the result to a waiting
+	// goroutine
+	respCh chan *singleResponse
+}
+
+/*
+NewRequestBatcher creates a new RequestBatcher
+from a Context and a BatchingConfig.
+
+In the typical usage pattern, a RequestBatcher should always be alive, so it
+is safe and recommended to use the background context.
+*/
+func NewRequestBatcher(
+	ctx context.Context,
+	config *BatchingConfig,
+) *RequestBatcher {
+	batcher := &RequestBatcher{
+		BatchingConfig: config,
+		parentCtx:      ctx,
+		running:        true,
+	}
+
+	if batcher.SendF == nil {
+		log.Fatal("Expecting SendF")
+	}
+
+	go func(b *RequestBatcher) {
+		<-b.parentCtx.Done()
+		log.Printf("Parent context cancelled")
+		b.stop()
+	}(batcher)
+
+	return batcher
+}
+
+// stop safely releases all batcher-allocated resources
+func (b *RequestBatcher) stop() {
+	b.Lock()
+	defer b.Unlock()
+	log.Println("Stopping batcher")
+
+	b.running = false
+	if b.curBatch != nil {
+		b.curBatch.timer.Stop()
+		for i := len(b.curBatch.subscribers) - 1; i >= 0; i-- {
+			close(b.curBatch.subscribers[i].respCh)
+		}
+	}
+	log.Println("Batcher stopped")
+}
+
+/*
+SendRequestWithTimeout is a method for making a single request.
+It manages registering the request with the batcher
+and waiting on the response.
+
+Arguments:
+	newRequestBody {*interface{}} -- A request body. SendF will expect
+		a slice of objects like newRequestBody.
+	timeout {time.Duration} -- The maximum time to wait for a response;
+		must be longer than BatchTimeout.
+
+Returns:
+	interface{} -- A response body. SendF's output is expected to be a slice
+		of objects like this.
+	error -- Error
+*/
+func (b *RequestBatcher) SendRequestWithTimeout(
+	newRequestBody *interface{},
+	timeout time.Duration,
+) (interface{}, error) {
+	// Check that the request is valid
+	if newRequestBody == nil {
+		return nil, fmt.Errorf("received `nil` request")
+	}
+
+	if timeout <= b.BatchTimeout {
+		return nil, fmt.Errorf(
+			"timeout period should be longer than batch timeout, %v",
+			b.BatchTimeout,
+		)
+	}
+
+	respCh, err := b.registerRequest(newRequestBody)
+	if err != nil {
+		log.Printf("[ERROR] Failed to register request: %v", err)
+		return nil, fmt.Errorf("failed to register request")
+	}
+
+	ctx, cancel := context.WithTimeout(b.parentCtx, timeout)
+	defer cancel()
+
+	select {
+	case resp := <-respCh:
+		if resp == nil {
+			// The channel was closed without a response, which happens
+			// when the batcher is stopped.
+			return nil, fmt.Errorf("batcher stopped before request was sent")
+		}
+		if resp.err != nil {
+			log.Printf("[ERROR] Failed to process request: %v", resp.err)
+			return nil, resp.err
+		}
+		return resp.body, nil
+
+	case <-ctx.Done():
+		return nil, fmt.Errorf("request timeout after %v", timeout)
+	}
+}
+
+// registerRequest safely determines if a new request should be
+// added to the existing batch or to a new batch
+func (b *RequestBatcher) registerRequest(
+	newRequestBody *interface{},
+) (<-chan *singleResponse, error) {
+	respCh := make(chan *singleResponse, 1)
+	sub := batchSubscriber{
+		singleRequestBody: *newRequestBody,
+		respCh:            respCh,
+	}
+
+	b.Lock()
+	defer b.Unlock()
+
+	if b.curBatch != nil {
+		// Check if the new request can be appended to curBatch
+		if len(b.curBatch.body) < b.MaxBatchSize {
+			// Append request to current batch
+			b.curBatch.body = append(b.curBatch.body, *newRequestBody)
+			b.curBatch.subscribers = append(b.curBatch.subscribers, sub)
+
+			// Check if the current batch is full
+			if len(b.curBatch.body) >= b.MaxBatchSize {
+				// Send current batch
+				b.curBatch.timer.Stop()
+				b.sendCurBatch()
+			}
+
+			return respCh, nil
+		}
+
+		// Send current batch
+		b.curBatch.timer.Stop()
+		b.sendCurBatch()
+	}
+
+	// Create new batch from request
+	b.curBatch = &startedBatch{
+		body:        []interface{}{*newRequestBody},
+		subscribers: []batchSubscriber{sub},
+	}
+
+	// Start a timer to send the batch after the batch timeout
+	b.curBatch.timer = time.AfterFunc(b.BatchTimeout, b.sendCurBatchWithSafety)
+
+	return respCh, nil
+}
+
+// sendCurBatch pops curBatch and sends it; the caller must hold the mutex
+func (b *RequestBatcher) sendCurBatch() {
+	// Acquire batch
+	batch := b.curBatch
+	b.curBatch = nil
+
+	if batch != nil {
+		go func() {
+			b.send(batch)
+		}()
+	}
+}
+
+// sendCurBatchWithSafety pops curBatch and sends it, acquiring the mutex
+// itself
+func (b *RequestBatcher) sendCurBatchWithSafety() {
+	// Acquire batch
+	b.Lock()
+	batch := b.curBatch
+	b.curBatch = nil
+	b.Unlock()
+
+	if batch != nil {
+		go func() {
+			b.send(batch)
+		}()
+	}
+}
+
+// send calls SendF on a startedBatch
+func (b *RequestBatcher) send(batch *startedBatch) {
+
+	// Attempt to apply SendF
+	batchResp, err := b.SendF(&batch.body)
+	if err != nil {
+		for i := len(batch.subscribers) - 1; i >= 0; i-- {
+			batch.subscribers[i].respCh <- &singleResponse{
+				body: nil,
+				err:  err,
+			}
+			close(batch.subscribers[i].respCh)
+		}
+		return
+	}
+
+	// Return an error if the number of entries does not match
+	if len(*batchResp) != len(batch.body) {
+		log.Printf("[ERROR] SendF returned a different number of entries.")
+		for i := len(batch.subscribers) - 1; i >= 0; i-- {
+			batch.subscribers[i].respCh <- &singleResponse{
+				body: nil,
+				err: fmt.Errorf(
+					"SendF returned %d entries, expected %d",
+					len(*batchResp), len(batch.body),
+				),
+			}
+			close(batch.subscribers[i].respCh)
+		}
+		return
+	}
+
+	// On success, place each response into its subscriber's response channel.
+	for i := len(batch.subscribers) - 1; i >= 0; i-- {
+		batch.subscribers[i].respCh <- &singleResponse{
+			body: (*batchResp)[i],
+			err:  nil,
+		}
+		close(batch.subscribers[i].respCh)
+	}
+}
diff --git a/batcher_test.go b/batcher_test.go
new file mode 100644
index 0000000..db3b644
--- /dev/null
+++ b/batcher_test.go
@@ -0,0 +1,167 @@
+package batcher
+
+import (
+	"context"
+	"encoding/json"
+	"math/rand"
+	"testing"
+	"time"
+)
+
+var (
+	ctx              context.Context
+	delayedBatcher   *RequestBatcher
+	immediateBatcher *RequestBatcher
+
+	idleTimeout         time.Duration
+	nonEmptyRequestBody interface{}
+)
+
+func dummySendF(body *[]interface{}) (*[]interface{}, error) {
+	return body, nil
+}
+
+func init() {
+	ctx = context.Background()
+
+	delayedBatcher = NewRequestBatcher(ctx, &BatchingConfig{
+		MaxBatchSize: 32,
+		BatchTimeout: time.Millisecond,
+		SendF:        dummySendF,
+	})
+
+	immediateBatcher = NewRequestBatcher(ctx, &BatchingConfig{
+		MaxBatchSize: 32,
+		BatchTimeout: time.Duration(0),
+		SendF:        dummySendF,
+	})
+
+	idleTimeout = time.Minute
+	nonEmptyRequestBody = "This is some string"
+}
+
+// TestValidRequestBody checks if RequestBatcher accepts valid requests.
+// In this test, all test cases are guaranteed to be JSONifiable.
+func TestValidRequestBody(t *testing.T) {
+	testCases := []struct {
+		desc string
+		body interface{}
+	}{
+		{
+			desc: "String",
+			body: "Hello World",
+		},
+		{
+			desc: "Empty bytestring",
+			body: []byte{},
+		},
+		{
+			desc: "Non-empty bytestring",
+			body: []byte("Hello World"),
+		},
+		{
+			desc: "Integer",
+			body: 1,
+		},
+		{
+			desc: "Float",
+			body: 1.23,
+		},
+		{
+			desc: "Object",
+			body: struct {
+				Word   string `json:"word"`
+				Number int    `json:"num"`
+			}{
+				Word:   "Word",
+				Number: 3,
+			},
+		},
+		{
+			desc: "Empty object",
+			body: struct{}{},
+		},
+	}
+	for _, tC := range testCases {
+		t.Run(tC.desc, func(t *testing.T) {
+			resp, err := delayedBatcher.SendRequestWithTimeout(&tC.body, idleTimeout)
+
+			// Ensure no error
+			if err != nil {
+				t.Errorf("Unexpected error: %v", err)
+			}
+
+			got, _ := json.Marshal(resp)
+			want, _ := json.Marshal(tC.body)
+
+			// Ensure response is consistent
+			if string(got) != string(want) {
+				t.Errorf(
+					"SendRequestWithTimeout output inconsistent."+
+						" Expecting %v, got %v",
+					string(want), string(got),
+				)
+			}
+		})
+	}
+}
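+
+// ExampleNewRequestBatcher is a minimal, illustrative sketch of end-to-end
+// usage; the config values here are placeholder assumptions for this
+// example, and dummySendF simply echoes each batch back unchanged.
+func ExampleNewRequestBatcher() {
+	b := NewRequestBatcher(context.Background(), &BatchingConfig{
+		MaxBatchSize: 4,
+		BatchTimeout: time.Millisecond,
+		SendF:        dummySendF,
+	})
+
+	var req interface{} = "hello"
+	resp, err := b.SendRequestWithTimeout(&req, time.Minute)
+	if err != nil {
+		return // handle the error in real code
+	}
+	_ = resp // resp == "hello", since dummySendF echoes the batch
+}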
+
+// TestInvalidRequestBody checks if RequestBatcher rejects invalid requests
+func TestInvalidRequestBody(t *testing.T) {
+	_, err := delayedBatcher.SendRequestWithTimeout(nil, idleTimeout)
+	if err == nil {
+		t.Error("Expecting error when sending with `nil` body")
+	}
+}
+
+// TestTimeoutTooShort checks if timeouts that are too short are rejected
+func TestTimeoutTooShort(t *testing.T) {
+	_, err := delayedBatcher.SendRequestWithTimeout(
+		&nonEmptyRequestBody,
+		delayedBatcher.BatchTimeout,
+	)
+	if err == nil {
+		t.Errorf(
+			"Expecting error when timeout too short %v",
+			delayedBatcher.BatchTimeout,
+		)
+	}
+}
+
+// BenchmarkSendRequestOverhead benchmarks the overhead costs of the
+// RequestBatcher
+func BenchmarkSendRequestOverhead(b *testing.B) {
+	b.RunParallel(func(pb *testing.PB) {
+		var reqBody interface{}
+		for pb.Next() {
+			reqBody = rand.Float32()
+			resp, err := immediateBatcher.SendRequestWithTimeout(
+				&reqBody, idleTimeout)
+			if err != nil {
+				b.Errorf("Unexpected error: %v", err)
+			}
+			if resp != reqBody {
+				b.Error("Response not the same")
+			}
+		}
+	})
+}
+
+// BenchmarkSendRequestParallel benchmarks the parallelism of the
+// RequestBatcher
+func BenchmarkSendRequestParallel(b *testing.B) {
+	b.RunParallel(func(pb *testing.PB) {
+		var reqBody interface{}
+		for pb.Next() {
+			reqBody = rand.Float32()
+			resp, err := delayedBatcher.SendRequestWithTimeout(&reqBody, idleTimeout)
+			if err != nil {
+				b.Errorf("Unexpected error: %v", err)
+			}
+			if resp != reqBody {
+				b.Error("Response not the same")
+			}
+		}
+	})
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..f18aa4b
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,3 @@
+module github.com/mingruimingrui/batcher
+
+go 1.14
diff --git a/version.go b/version.go
new file mode 100644
index 0000000..3aff01d
--- /dev/null
+++ b/version.go
@@ -0,0 +1,6 @@
+package batcher
+
+var (
+	// Version -- Current version of this package
+	Version = "1.0.0"
+)