diff --git a/.circleci/config.yml b/.circleci/config.yml
new file mode 100644
index 0000000..8ab9c5d
--- /dev/null
+++ b/.circleci/config.yml
@@ -0,0 +1,15 @@
+version: 2
+
+jobs:
+  build:
+    docker:
+      - image: circleci/golang:1.14
+
+    working_directory: /go/src/github.com/mingruimingrui/batcher
+
+    steps:
+      - checkout
+      - run: go test -v
+      - run: go test -v -bench=. -run ^$ -cpu=1
+      - run: go test -v -bench=. -run ^$ -cpu=31
+      - run: go test -v -bench=BenchmarkSendRequestParallel -run=^$ -cpu=32
diff --git a/README.md b/README.md
index 28c39f4..430aa59 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,85 @@
-# batcher
-A Go library for batching requests
+`batcher` is a package for grouping individual requests into batch
+requests.
+
+For some live services, batching is necessary to maximize throughput.
+In particular, services that perform disk I/O or leverage GPUs benefit
+from batching because the overhead cost of each request is large and
+near constant.
+Batching is typically achieved with the help of a queue system.
+
+`batcher` takes this design pattern and formalizes it into a template that
+lets developers conveniently incorporate batching into their live services.
+
+- [Installation and Docs](#installation-and-docs)
+- [Usage](#usage)
+- [BatchingConfig](#batchingconfig)
+
+
+# Installation and Docs
+
+Within a module, this package can be installed with `go get`:
+
+```
+go get github.com/mingruimingrui/batcher
+```
+
+Auto-generated documentation is available at
+https://pkg.go.dev/github.com/mingruimingrui/batcher.
+
+
+# Usage
+
+Usage of this library typically begins with the creation of a new
+`RequestBatcher`, usually declared at global scope.
+
+```golang
+var (
+    requestBatcher *batcher.RequestBatcher
+    ...
+)
+
+...
+
+func main() {
+    requestBatcher = batcher.NewRequestBatcher(
+        context.Background(),
+        &batcher.BatchingConfig{
+            MaxBatchSize: ...,
+            BatchTimeout: ...,
+
+            // SendF is a function that handles a batch of requests
+            SendF: func(body *[]interface{}) (*[]interface{}, error) {
+                ...
+            },
+        },
+    )
+
+    ...
+}
+```
+
+To submit a request, use the `SendRequestWithTimeout` method:
+
+```golang
+resp, err := requestBatcher.SendRequestWithTimeout(&someRequestBody, someTimeoutDuration)
+```
+
+When multiple requests are sent concurrently from multiple goroutines,
+the `RequestBatcher` groups them into a single batch so that `SendF` is
+called as few times as possible.
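+
+As a minimal sketch of this behavior (the loop bound, request bodies, and
+one-second timeout are illustrative, and `requestBatcher` is the variable
+set up in `main` above), eight goroutines each submit one request and the
+batcher coalesces them into one or a few `SendF` calls. Note that the
+per-request timeout must be longer than `BatchTimeout`.
+
+```golang
+var wg sync.WaitGroup
+for i := 0; i < 8; i++ {
+    wg.Add(1)
+    go func(i int) {
+        defer wg.Done()
+        // SendRequestWithTimeout takes a *interface{}.
+        var body interface{} = i
+        resp, err := requestBatcher.SendRequestWithTimeout(&body, time.Second)
+        if err != nil {
+            log.Printf("request %d failed: %v", i, err)
+            return
+        }
+        log.Printf("request %d got response: %v", i, resp)
+    }(i)
+}
+wg.Wait()
+```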
+
+
+# BatchingConfig
+
+`batcher.BatchingConfig` controls batching behavior and accepts the
+following fields (a complete example follows the list).
+
+- **`MaxBatchSize`** `{int}`
+  The maximum number of requests per batch.
+
+- **`BatchTimeout`** `{time.Duration}`
+  The maximum wait time before a batch is sent.
+
+- **`SendF`** `{func(body *[]interface{}) (*[]interface{}, error)}`
+  A user-defined function that processes a batch of requests and returns
+  one response per request.
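+
+Putting the fields together, a minimal working configuration might look
+like the sketch below. The literal values and the echo-style `SendF` are
+illustrative only; a real service would perform batched inference, disk
+reads, or similar work inside `SendF`.
+
+```golang
+requestBatcher := batcher.NewRequestBatcher(
+    context.Background(),
+    &batcher.BatchingConfig{
+        // Send a batch once 32 requests have accumulated...
+        MaxBatchSize: 32,
+        // ...or once the first request in it has waited 10ms.
+        BatchTimeout: 10 * time.Millisecond,
+        // Echo each request back as its own response.
+        SendF: func(body *[]interface{}) (*[]interface{}, error) {
+            return body, nil
+        },
+    },
+)
+```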
diff --git a/batcher.go b/batcher.go
new file mode 100644
index 0000000..1890969
--- /dev/null
+++ b/batcher.go
@@ -0,0 +1,310 @@
+/*
+batcher is a library for batching requests.
+
+This library encapsulates the process of grouping requests into batches
+for asynchronous batch processing.
+
+This implementation is adapted from the Google batch API
+https://github.com/terraform-providers/terraform-provider-google/blob/master/google/batcher.go
+
+However, there are a number of notable differences:
+- Usage assumes 1 batcher for 1 API (instead of 1 batcher for multiple APIs)
+- Clients should only receive their own response, and not the batch response
+ their request is sent with
+- Config conventions follow the framework as defined in
+ https://github.com/tensorflow/serving/tree/master/tensorflow_serving/batching#batch-scheduling-parameters-and-tuning
+*/
+
+package batcher
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "sync"
+ "time"
+)
+
+/*
+RequestBatcher handles receiving of new requests, and all the background
+asynchronous tasks to batch and send batch.
+
+A new RequestBatcher should be created using the NewRequestBatcher function.
+
+Expected usage pattern of the RequestBatcher involves declaring a
+RequestBatcher on global scope and calling SendRequestWithTimeout from
+multiple goroutines.
+*/
+type RequestBatcher struct {
+ sync.Mutex
+
+ *BatchingConfig
+ running bool
+ parentCtx context.Context
+ curBatch *startedBatch
+}
+
+/*
+BatchingConfig determines how batching is done in a RequestBatcher.
+*/
+type BatchingConfig struct {
+ // Maximum request size of each batch.
+ MaxBatchSize int
+
+ // Maximum wait time before batch should be executed.
+ BatchTimeout time.Duration
+
+ // User defined SendF for sending a batch request.
+ // See SendFunc for type definition of this function.
+ SendF SendFunc
+}
+
+/*
+SendFunc is a function type for sending a batch of requests.
+A batch of requests is a slice of inputs to SendRequestWithTimeout.
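+
+For illustration, an identity SendFunc that echoes each request back as
+its own response (any function with this signature works, but it must
+return exactly one response per request, in request order):
+
+	func echoSendF(body *[]interface{}) (*[]interface{}, error) {
+		return body, nil
+	}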
+*/
+type SendFunc func(body *[]interface{}) (*[]interface{}, error)
+
+// startedBatch refers to a batch awaiting more requests to come in
+// before having SendFunc applied to its contents
+type startedBatch struct {
+ // Combined batch request
+ body []interface{}
+
+ // subscribers is a registry of the requests (batchSubscriber)
+ // combined to make this batch
+ subscribers []batchSubscriber
+
+ // timer for keeping track of BatchTimeout
+ timer *time.Timer
+}
+
+// singleResponse represents a single response received from SendF
+type singleResponse struct {
+ body interface{}
+ err error
+}
+
+// batchSubscriber holds the response channel on which a waiting
+// goroutine receives its singleResponse
+type batchSubscriber struct {
+ // singleRequestBody is the original request this subscriber represents
+ singleRequestBody interface{}
+
+ // respCh is the channel created to communicate the result to a waiting
+ // goroutine
+ respCh chan *singleResponse
+}
+
+/*
+NewRequestBatcher creates a new RequestBatcher
+from a Context and a BatchingConfig.
+
+In the typical usage pattern, a RequestBatcher should stay alive for the
+lifetime of the process, so it is safe and recommended to use the
+background context.
+*/
+func NewRequestBatcher(
+ ctx context.Context,
+ config *BatchingConfig,
+) *RequestBatcher {
+ batcher := &RequestBatcher{
+ BatchingConfig: config,
+ parentCtx: ctx,
+ running: true,
+ }
+
+ if batcher.SendF == nil {
+ log.Fatal("Expecting SendF")
+ }
+
+ go func(b *RequestBatcher) {
+ <-b.parentCtx.Done()
+ log.Printf("Parent context cancelled")
+ b.stop()
+ }(batcher)
+
+ return batcher
+}
+
+// stop safely releases all batcher-allocated resources
+func (b *RequestBatcher) stop() {
+ b.Lock()
+ defer b.Unlock()
+ log.Println("Stopping batcher")
+
+ b.running = false
+ if b.curBatch != nil {
+ b.curBatch.timer.Stop()
+ for i := len(b.curBatch.subscribers) - 1; i >= 0; i-- {
+ close(b.curBatch.subscribers[i].respCh)
+ }
+ }
+ log.Println("Batcher stopped")
+}
+
+/*
+SendRequestWithTimeout is a method for making a single request.
+It registers the request with the batcher and waits for the response.
+
+Arguments:
+	newRequestBody {*interface{}} -- A single request body. SendF will
+	receive a slice of objects like newRequestBody.
+
+Returns:
+	interface{} -- A single response body, taken from the slice returned
+	by SendF.
+	error -- Error
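+
+For illustration (the body and timeout are hypothetical), a single call
+looks like:
+
+	var body interface{} = "hello"
+	resp, err := b.SendRequestWithTimeout(&body, time.Second)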
+*/
+func (b *RequestBatcher) SendRequestWithTimeout(
+ newRequestBody *interface{},
+ timeout time.Duration,
+) (interface{}, error) {
+ // Check that request is valid
+ if newRequestBody == nil {
+ return nil, fmt.Errorf("Received `nil` request")
+ }
+
+	if timeout <= b.BatchTimeout {
+		return nil, fmt.Errorf(
+			"Timeout period should be longer than the batch timeout, %v",
+			b.BatchTimeout,
+		)
+	}
+
+ respCh, err := b.registerRequest(newRequestBody)
+ if err != nil {
+ log.Printf("[ERROR] Failed to register request: %v", err)
+ return nil, fmt.Errorf("Failed to register request")
+ }
+
+ ctx, cancel := context.WithTimeout(b.parentCtx, timeout)
+ defer cancel()
+
+ select {
+ case resp := <-respCh:
+ if resp.err != nil {
+ log.Printf("[ERROR] Failed to process request: %v", resp.err)
+ return nil, resp.err
+ }
+ return resp.body, nil
+
+ case <-ctx.Done():
+ return nil, fmt.Errorf("Request timeout after %v", timeout)
+ }
+}
+
+// registerRequest safely determines whether a new request should be
+// added to the existing batch or to a new one
+func (b *RequestBatcher) registerRequest(
+ newRequestBody *interface{},
+) (<-chan *singleResponse, error) {
+ respCh := make(chan *singleResponse, 1)
+ sub := batchSubscriber{
+ singleRequestBody: *newRequestBody,
+ respCh: respCh,
+ }
+
+ b.Lock()
+ defer b.Unlock()
+
+ if b.curBatch != nil {
+ // Check if new request can be appended to curBatch
+ if len(b.curBatch.body) < b.MaxBatchSize {
+ // Append request to current batch
+ b.curBatch.body = append(b.curBatch.body, *newRequestBody)
+ b.curBatch.subscribers = append(b.curBatch.subscribers, sub)
+
+ // Check if current batch is full
+ if len(b.curBatch.body) >= b.MaxBatchSize {
+ // Send current batch
+ b.curBatch.timer.Stop()
+ b.sendCurBatch()
+ }
+
+ return respCh, nil
+ }
+
+ // Send current batch
+ b.curBatch.timer.Stop()
+ b.sendCurBatch()
+ }
+
+ // Create new batch from request
+ b.curBatch = &startedBatch{
+ body: []interface{}{*newRequestBody},
+ subscribers: []batchSubscriber{sub},
+ }
+
+ // Start a timer to send request after batch timeout
+ b.curBatch.timer = time.AfterFunc(b.BatchTimeout, b.sendCurBatchWithSafety)
+
+ return respCh, nil
+}
+
+// sendCurBatch pops curBatch and sends it; the caller must hold the mutex
+func (b *RequestBatcher) sendCurBatch() {
+ // Acquire batch
+ batch := b.curBatch
+ b.curBatch = nil
+
+ if batch != nil {
+ go func() {
+ b.send(batch)
+ }()
+ }
+}
+
+// sendCurBatchWithSafety pops curBatch and sends it, acquiring the mutex itself
+func (b *RequestBatcher) sendCurBatchWithSafety() {
+ // Acquire batch
+ b.Lock()
+ batch := b.curBatch
+ b.curBatch = nil
+ b.Unlock()
+
+ if batch != nil {
+ go func() {
+ b.send(batch)
+ }()
+ }
+}
+
+// send calls SendF on a startedBatch
+func (b *RequestBatcher) send(batch *startedBatch) {
+
+ // Attempt to apply SendF
+ batchResp, err := b.SendF(&batch.body)
+ if err != nil {
+ for i := len(batch.subscribers) - 1; i >= 0; i-- {
+ batch.subscribers[i].respCh <- &singleResponse{
+ body: nil,
+ err: err,
+ }
+ close(batch.subscribers[i].respCh)
+ }
+ return
+ }
+
+	// Error out if the number of response entries mismatches
+	if len(*batchResp) != len(batch.body) {
+		log.Printf("[ERROR] SendF returned a different number of entries.")
+ for i := len(batch.subscribers) - 1; i >= 0; i-- {
+ batch.subscribers[i].respCh <- &singleResponse{
+ body: nil,
+			err:  fmt.Errorf("SendF returned a different number of entries"),
+ }
+ close(batch.subscribers[i].respCh)
+ }
+ return
+ }
+
+ // On success, place response into subscribed response queues.
+ for i := len(batch.subscribers) - 1; i >= 0; i-- {
+ batch.subscribers[i].respCh <- &singleResponse{
+ body: (*batchResp)[i],
+ err: nil,
+ }
+ close(batch.subscribers[i].respCh)
+ }
+}
diff --git a/batcher_test.go b/batcher_test.go
new file mode 100644
index 0000000..db3b644
--- /dev/null
+++ b/batcher_test.go
@@ -0,0 +1,167 @@
+package batcher
+
+import (
+ "context"
+ "encoding/json"
+ "math/rand"
+ "testing"
+ "time"
+)
+
+var (
+ ctx context.Context
+ delayedBatcher *RequestBatcher
+ immediateBatcher *RequestBatcher
+
+ idleTimeout time.Duration
+ nonEmptyRequestBody interface{}
+)
+
+func dummySendF(body *[]interface{}) (*[]interface{}, error) {
+ return body, nil
+}
+
+func init() {
+ ctx = context.Background()
+
+ delayedBatcher = NewRequestBatcher(ctx, &BatchingConfig{
+ MaxBatchSize: 32,
+ BatchTimeout: time.Millisecond,
+ SendF: dummySendF,
+ })
+
+ immediateBatcher = NewRequestBatcher(ctx, &BatchingConfig{
+ MaxBatchSize: 32,
+ BatchTimeout: time.Duration(0),
+ SendF: dummySendF,
+ })
+
+ idleTimeout = time.Minute
+ nonEmptyRequestBody = "This is some string"
+}
+
+// TestValidRequestBody checks that RequestBatcher accepts valid requests.
+// All test case bodies in this test are guaranteed to be JSON-serializable.
+func TestValidRequestBody(t *testing.T) {
+ testCases := []struct {
+ desc string
+ body interface{}
+ }{
+ {
+ desc: "String",
+ body: "Hello World",
+ },
+ {
+ desc: "Empty bytestring",
+ body: []byte{},
+ },
+ {
+ desc: "Non-empty bytestring",
+ body: []byte("Hello World"),
+ },
+ {
+ desc: "Integer",
+ body: 1,
+ },
+ {
+ desc: "Float",
+ body: 1.23,
+ },
+ {
+ desc: "Object",
+ body: struct {
+ Word string `json:"word"`
+ Number int `json:"num"`
+ }{
+ Word: "Word",
+ Number: 3,
+ },
+ },
+ {
+ desc: "Empty object",
+ body: struct{}{},
+ },
+ }
+ for _, tC := range testCases {
+ t.Run(tC.desc, func(t *testing.T) {
+ resp, err := delayedBatcher.SendRequestWithTimeout(&tC.body, idleTimeout)
+
+ // Ensure no error
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ A, _ := json.Marshal(resp)
+ B, _ := json.Marshal(tC.body)
+
+ // Ensure response is consistent
+			if string(A) != string(B) {
+				t.Errorf(
+					"SendRequestWithTimeout output inconsistent."+
+						" Expecting %v, got %v",
+					string(B), string(A),
+				)
+			}
+
+ })
+ }
+}
+
+// TestInvalidRequestBody checks if RequestBatcher rejects invalid requests
+func TestInvalidRequestBody(t *testing.T) {
+ _, err := delayedBatcher.SendRequestWithTimeout(nil, idleTimeout)
+ if err == nil {
+ t.Error("Expecting error when sending with `nil` body")
+ }
+}
+
+// TestTimeoutTooShort checks if timeouts that are too short are rejected
+func TestTimeoutTooShort(t *testing.T) {
+ _, err := delayedBatcher.SendRequestWithTimeout(
+ &nonEmptyRequestBody,
+ delayedBatcher.BatchTimeout,
+ )
+ if err == nil {
+ t.Errorf(
+ "Expecting error when timeout too short %v",
+ delayedBatcher.BatchTimeout,
+ )
+ }
+}
+
+// BenchmarkSendRequestOverhead benchmarks the overhead costs of the
+// RequestBatcher
+func BenchmarkSendRequestOverhead(b *testing.B) {
+ b.RunParallel(func(pb *testing.PB) {
+ var reqBody interface{}
+ for pb.Next() {
+ reqBody = rand.Float32()
+ resp, err := immediateBatcher.SendRequestWithTimeout(
+ &reqBody, idleTimeout)
+ if err != nil {
+ b.Errorf("Unexpected error: %v", err)
+ }
+ if resp != reqBody {
+ b.Error("Response not the same")
+ }
+ }
+ })
+}
+
+// BenchmarkSendRequestParallel benchmarks the parallelism of the
+// RequestBatcher
+func BenchmarkSendRequestParallel(b *testing.B) {
+ b.RunParallel(func(pb *testing.PB) {
+ var reqBody interface{}
+ for pb.Next() {
+ reqBody = rand.Float32()
+ resp, err := delayedBatcher.SendRequestWithTimeout(&reqBody, idleTimeout)
+ if err != nil {
+ b.Errorf("Unexpected error: %v", err)
+ }
+ if resp != reqBody {
+ b.Error("Response not the same")
+ }
+ }
+ })
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..f18aa4b
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,3 @@
+module github.com/mingruimingrui/batcher
+
+go 1.14
diff --git a/version.go b/version.go
new file mode 100644
index 0000000..3aff01d
--- /dev/null
+++ b/version.go
@@ -0,0 +1,6 @@
+package batcher
+
+var (
+ // Version -- Current version of this package
+ Version string = "1.0.0"
+)