Skip to content

Commit

Permalink
Merge pull request #140 from Raytar/channels-refactor
Browse files Browse the repository at this point in the history
Refactor and rename orderedNodeStream and receiveQueue
  • Loading branch information
meling authored May 10, 2021
2 parents dc76971 + 6b8a3d1 commit a5ad159
Show file tree
Hide file tree
Showing 13 changed files with 284 additions and 608 deletions.
10 changes: 5 additions & 5 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gorums
import (
"context"

"github.com/relab/gorums/ordering"
"google.golang.org/protobuf/reflect/protoreflect"
)

Expand Down Expand Up @@ -31,8 +32,8 @@ func (f *Async) Done() bool {

func (c Configuration) AsyncCall(ctx context.Context, d QuorumCallData) *Async {
expectedReplies := len(c)
md := c.newCall(d.Method)
replyChan, callDone := c.newReply(md, expectedReplies)
md := &ordering.Metadata{MessageID: c.getMsgID(), Method: d.Method}
replyChan := make(chan response, expectedReplies)

for _, n := range c {
msg := d.Message
Expand All @@ -43,13 +44,12 @@ func (c Configuration) AsyncCall(ctx context.Context, d QuorumCallData) *Async {
continue // don't send if no msg
}
}
n.sendQ <- gorumsStreamRequest{ctx: ctx, msg: &Message{Metadata: md, Message: msg}}
n.channel.enqueue(request{ctx: ctx, msg: &Message{Metadata: md, Message: msg}}, replyChan)
}

fut := &Async{c: make(chan struct{}, 1)}

go func() {
defer callDone()
defer close(fut.c)

var (
Expand All @@ -66,7 +66,7 @@ func (c Configuration) AsyncCall(ctx context.Context, d QuorumCallData) *Async {
errs = append(errs, Error{r.nid, r.err})
break
}
replies[r.nid] = r.reply
replies[r.nid] = r.msg
if resp, quorum = d.QuorumFunction(d.Message, replies); quorum {
fut.reply, fut.err = resp, nil
return
Expand Down
211 changes: 211 additions & 0 deletions channel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
package gorums

import (
"context"
"math"
"math/rand"
"sync"
"sync/atomic"
"time"

"github.com/relab/gorums/ordering"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/reflect/protoreflect"
)

type request struct {
ctx context.Context
msg *Message
opts callOptions
}

type response struct {
nid uint32
msg protoreflect.ProtoMessage
err error
}

type channel struct {
sendQ chan request
node *Node // needed for ID and setLastError
rand *rand.Rand
gorumsClient ordering.GorumsClient
gorumsStream ordering.Gorums_NodeStreamClient
streamMut sync.RWMutex
streamBroken atomicFlag
parentCtx context.Context
streamCtx context.Context
cancelStream context.CancelFunc
responseRouter map[uint64]chan<- response
responseMut sync.Mutex
}

func newChannel(n *Node) *channel {
return &channel{
sendQ: make(chan request, n.mgr.opts.sendBuffer),
node: n,
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
responseRouter: make(map[uint64]chan<- response),
}
}

func (c *channel) connect(ctx context.Context, conn *grpc.ClientConn) error {
var err error
c.parentCtx = ctx
c.streamCtx, c.cancelStream = context.WithCancel(c.parentCtx)
c.gorumsClient = ordering.NewGorumsClient(conn)
c.gorumsStream, err = c.gorumsClient.NodeStream(c.streamCtx)
if err != nil {
return err
}
go c.sendMsgs()
go c.recvMsgs()
return nil
}

func (c *channel) routeResponse(msgID uint64, resp response) {
c.responseMut.Lock()
defer c.responseMut.Unlock()
if ch, ok := c.responseRouter[msgID]; ok {
ch <- resp
delete(c.responseRouter, msgID)
}
}

func (c *channel) enqueue(req request, responseChan chan<- response) {
if responseChan != nil {
c.responseMut.Lock()
c.responseRouter[req.msg.Metadata.MessageID] = responseChan
c.responseMut.Unlock()
}
c.sendQ <- req
}

func (c *channel) sendMsg(req request) (err error) {
// unblock the waiting caller unless noSendWaiting is enabled
defer func() {
if req.opts.callType == E_Multicast || req.opts.callType == E_Unicast && !req.opts.noSendWaiting {
c.routeResponse(req.msg.Metadata.MessageID, response{})
}
}()

// don't send if context is already cancelled.
if req.ctx.Err() != nil {
return req.ctx.Err()
}

c.streamMut.RLock()
defer c.streamMut.RUnlock()

done := make(chan struct{}, 1)

// wait for either the message to be sent, or the request context being cancelled.
// if the request context was cancelled, then we most likely have a blocked stream.
go func() {
select {
case <-done:
case <-req.ctx.Done():
c.cancelStream()
}
}()

err = c.gorumsStream.SendMsg(req.msg)
if err != nil {
c.node.setLastErr(err)
c.streamBroken.set()
}
done <- struct{}{}

return err
}

func (c *channel) sendMsgs() {
var req request
for {
select {
case <-c.parentCtx.Done():
return
case req = <-c.sendQ:
}
// return error if stream is broken
if c.streamBroken.get() {
err := status.Errorf(codes.Unavailable, "stream is down")
c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.node.ID(), msg: nil, err: err})
continue
}
// else try to send message
err := c.sendMsg(req)
if err != nil {
// return the error
c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.node.ID(), msg: nil, err: err})
}
}
}

func (c *channel) recvMsgs() {
for {
resp := newMessage(responseType)
c.streamMut.RLock()
err := c.gorumsStream.RecvMsg(resp)
if err != nil {
c.streamBroken.set()
c.streamMut.RUnlock()
c.node.setLastErr(err)
// attempt to reconnect
c.reconnect()
} else {
c.streamMut.RUnlock()
err := status.FromProto(resp.Metadata.GetStatus()).Err()
c.routeResponse(resp.Metadata.MessageID, response{nid: c.node.ID(), msg: resp.Message, err: err})
}

select {
case <-c.parentCtx.Done():
return
default:
}
}
}

func (c *channel) reconnect() {
c.streamMut.Lock()
defer c.streamMut.Unlock()
backoffCfg := c.node.mgr.opts.backoff

var retries float64
for {
var err error

c.streamCtx, c.cancelStream = context.WithCancel(c.parentCtx)
c.gorumsStream, err = c.gorumsClient.NodeStream(c.streamCtx)
if err == nil {
c.streamBroken.clear()
return
}
c.cancelStream()
c.node.setLastErr(err)
delay := float64(backoffCfg.BaseDelay)
max := float64(backoffCfg.MaxDelay)
for r := retries; delay < max && r > 0; r-- {
delay *= backoffCfg.Multiplier
}
delay = math.Min(delay, max)
delay *= 1 + backoffCfg.Jitter*(rand.Float64()*2-1)
select {
case <-time.After(time.Duration(delay)):
retries++
case <-c.parentCtx.Done():
return
}
}
}

type atomicFlag struct {
flag int32
}

func (f *atomicFlag) set() { atomic.StoreInt32(&f.flag, 1) }
func (f *atomicFlag) get() bool { return atomic.LoadInt32(&f.flag) == 1 }
func (f *atomicFlag) clear() { atomic.StoreInt32(&f.flag, 0) }
17 changes: 2 additions & 15 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package gorums

import (
"fmt"

"github.com/relab/gorums/ordering"
)

// Configuration represents a static set of nodes on which quorum calls may be invoked.
Expand Down Expand Up @@ -52,17 +50,6 @@ func (c Configuration) Equal(b Configuration) bool {
return true
}

// newCall returns unique metadata for a method call.
func (c Configuration) newCall(method string) (md *ordering.Metadata) {
// Note that we just use the first node's newCall method since all nodes
// associated with the same manager use the same receiveQueue instance.
return c[0].newCall(method)
}

// newReply returns a channel for receiving replies
// and a done function to be called for clean up.
func (c Configuration) newReply(md *ordering.Metadata, maxReplies int) (replyChan chan *gorumsStreamResult, done func()) {
// Note that we just use the first node's newReply method since all nodes
// associated with the same manager use the same receiveQueue instance.
return c[0].newReply(md, maxReplies)
func (c Configuration) getMsgID() uint64 {
return c[0].mgr.getMsgID()
}
15 changes: 7 additions & 8 deletions correctable.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"sync"

"github.com/relab/gorums/ordering"
"google.golang.org/protobuf/reflect/protoreflect"
)

Expand Down Expand Up @@ -82,9 +83,9 @@ type CorrectableCallData struct {

func (c Configuration) CorrectableCall(ctx context.Context, d CorrectableCallData) *Correctable {
expectedReplies := len(c)
md := c.newCall(d.Method)
replyChan, callDone := c.newReply(md, expectedReplies)
md := &ordering.Metadata{MessageID: c.getMsgID(), Method: d.Method}

replyChan := make(chan response, expectedReplies)
for _, n := range c {
msg := d.Message
if d.PerNodeArgFn != nil {
Expand All @@ -94,14 +95,12 @@ func (c Configuration) CorrectableCall(ctx context.Context, d CorrectableCallDat
continue // don't send if no msg
}
}
n.sendQ <- gorumsStreamRequest{ctx: ctx, msg: &Message{Metadata: md, Message: msg}}
n.channel.enqueue(request{ctx: ctx, msg: &Message{Metadata: md, Message: msg}}, replyChan)
}

corr := &Correctable{donech: make(chan struct{}, 1)}

go func() {
defer callDone()

var (
resp protoreflect.ProtoMessage
errs []Error
Expand All @@ -118,15 +117,15 @@ func (c Configuration) CorrectableCall(ctx context.Context, d CorrectableCallDat
errs = append(errs, Error{r.nid, r.err})
break
}
replies[r.nid] = r.reply
replies[r.nid] = r.msg
if resp, rlevel, quorum = d.QuorumFunction(d.Message, replies); quorum {
if quorum {
corr.set(r.reply, rlevel, nil, true)
corr.set(r.msg, rlevel, nil, true)
return
}
if rlevel > clevel {
clevel = rlevel
corr.set(r.reply, rlevel, nil, false)
corr.set(r.msg, rlevel, nil, false)
}
}
case <-ctx.Done():
Expand Down
16 changes: 10 additions & 6 deletions mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"log"
"sync"
"sync/atomic"

"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
Expand All @@ -18,8 +19,7 @@ type Manager struct {
closeOnce sync.Once
logger *log.Logger
opts managerOptions

*receiveQueue
nextMsgID uint64
}

// NewManager returns a new Manager for managing connection to nodes added
Expand All @@ -28,9 +28,8 @@ type Manager struct {
// You should use the `NewManager` function in the generated code instead.
func NewManager(opts ...ManagerOption) *Manager {
m := &Manager{
lookup: make(map[uint32]*Node),
receiveQueue: newReceiveQueue(),
opts: newManagerOptions(),
lookup: make(map[uint32]*Node),
opts: newManagerOptions(),
}
for _, opt := range opts {
opt(&m.opts)
Expand Down Expand Up @@ -116,7 +115,7 @@ func (m *Manager) AddNode(node *Node) error {
if m.logger != nil {
m.logger.Printf("connecting to %s with id %d\n", node, node.id)
}
if err := node.connect(m.receiveQueue, m.opts); err != nil {
if err := node.connect(m); err != nil {
return fmt.Errorf("connection failed for %s: %w", node, err)
}

Expand All @@ -126,3 +125,8 @@ func (m *Manager) AddNode(node *Node) error {
m.nodes = append(m.nodes, node)
return nil
}

// getMsgID returns a unique message ID.
func (m *Manager) getMsgID() uint64 {
return atomic.AddUint64(&m.nextMsgID, 1)
}
Loading

0 comments on commit a5ad159

Please sign in to comment.