Skip to content

Commit

Permalink
Merge pull request #143 from relab/decouple-node-from-channel
Browse files Browse the repository at this point in the history
Decouple node from channel
  • Loading branch information
meling authored May 10, 2021
2 parents a5ad159 + f685e26 commit d1d3428
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 57 deletions.
45 changes: 36 additions & 9 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/relab/gorums/ordering"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/reflect/protoreflect"
Expand All @@ -29,7 +30,11 @@ type response struct {

type channel struct {
sendQ chan request
node *Node // needed for ID and setLastError
nodeID uint32
mu sync.Mutex
lastError error
latency time.Duration
backoffCfg backoff.Config
rand *rand.Rand
gorumsClient ordering.GorumsClient
gorumsStream ordering.Gorums_NodeStreamClient
Expand All @@ -45,7 +50,9 @@ type channel struct {
func newChannel(n *Node) *channel {
return &channel{
sendQ: make(chan request, n.mgr.opts.sendBuffer),
node: n,
backoffCfg: n.mgr.opts.backoff,
nodeID: n.ID(),
latency: -1 * time.Second,
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
responseRouter: make(map[uint64]chan<- response),
}
Expand Down Expand Up @@ -113,7 +120,7 @@ func (c *channel) sendMsg(req request) (err error) {

err = c.gorumsStream.SendMsg(req.msg)
if err != nil {
c.node.setLastErr(err)
c.setLastErr(err)
c.streamBroken.set()
}
done <- struct{}{}
Expand All @@ -132,14 +139,14 @@ func (c *channel) sendMsgs() {
// return error if stream is broken
if c.streamBroken.get() {
err := status.Errorf(codes.Unavailable, "stream is down")
c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.node.ID(), msg: nil, err: err})
c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.nodeID, msg: nil, err: err})
continue
}
// else try to send message
err := c.sendMsg(req)
if err != nil {
// return the error
c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.node.ID(), msg: nil, err: err})
c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.nodeID, msg: nil, err: err})
}
}
}
Expand All @@ -152,13 +159,13 @@ func (c *channel) recvMsgs() {
if err != nil {
c.streamBroken.set()
c.streamMut.RUnlock()
c.node.setLastErr(err)
c.setLastErr(err)
// attempt to reconnect
c.reconnect()
} else {
c.streamMut.RUnlock()
err := status.FromProto(resp.Metadata.GetStatus()).Err()
c.routeResponse(resp.Metadata.MessageID, response{nid: c.node.ID(), msg: resp.Message, err: err})
c.routeResponse(resp.Metadata.MessageID, response{nid: c.nodeID, msg: resp.Message, err: err})
}

select {
Expand All @@ -172,7 +179,7 @@ func (c *channel) recvMsgs() {
func (c *channel) reconnect() {
c.streamMut.Lock()
defer c.streamMut.Unlock()
backoffCfg := c.node.mgr.opts.backoff
backoffCfg := c.backoffCfg

var retries float64
for {
Expand All @@ -185,7 +192,7 @@ func (c *channel) reconnect() {
return
}
c.cancelStream()
c.node.setLastErr(err)
c.setLastErr(err)
delay := float64(backoffCfg.BaseDelay)
max := float64(backoffCfg.MaxDelay)
for r := retries; delay < max && r > 0; r-- {
Expand All @@ -202,6 +209,26 @@ func (c *channel) reconnect() {
}
}

// setLastErr records err as the most recent error seen on this channel.
// The write is performed while holding c.mu.
func (c *channel) setLastErr(err error) {
	c.mu.Lock()
	c.lastError = err
	c.mu.Unlock()
}

// lastErr reports the error most recently stored on this channel, or nil if
// none has been stored. The field is read while holding c.mu.
func (c *channel) lastErr() error {
	c.mu.Lock()
	err := c.lastError
	c.mu.Unlock()
	return err
}

// channelLatency reports the latency value currently stored on this channel.
// The field is read while holding c.mu.
func (c *channel) channelLatency() time.Duration {
	c.mu.Lock()
	d := c.latency
	c.mu.Unlock()
	return d
}

type atomicFlag struct {
flag int32
}
Expand Down
53 changes: 19 additions & 34 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"net"
"sort"
"strconv"
"sync"
"time"

"google.golang.org/grpc"
Expand All @@ -20,14 +19,11 @@ const nilAngleString = "<nil>"
// can be performed.
type Node struct {
// Only assigned at creation.
id uint32
addr string
conn *grpc.ClientConn
cancel func()
mu sync.Mutex
lastErr error
latency time.Duration
mgr *Manager
id uint32
addr string
conn *grpc.ClientConn
cancel func()
mgr *Manager

// the default channel
channel *channel
Expand All @@ -42,9 +38,8 @@ func NewNode(addr string) (*Node, error) {
h := fnv.New32a()
_, _ = h.Write([]byte(tcpAddr.String()))
return &Node{
id: h.Sum32(),
addr: tcpAddr.String(),
latency: -1 * time.Second,
id: h.Sum32(),
addr: tcpAddr.String(),
}, nil
}

Expand All @@ -55,9 +50,8 @@ func NewNodeWithID(addr string, id uint32) (*Node, error) {
return nil, fmt.Errorf("node error: '%s' error: %v", addr, err)
}
return &Node{
id: id,
addr: tcpAddr.String(),
latency: -1 * time.Second,
id: id,
addr: tcpAddr.String(),
}, nil
}

Expand Down Expand Up @@ -136,28 +130,19 @@ func (n *Node) String() string {
// includes id, network address and latency information.
func (n *Node) FullString() string {
if n != nil {
n.mu.Lock()
defer n.mu.Unlock()
return fmt.Sprintf(
"node %d | addr: %s | latency: %v",
n.id, n.addr, n.latency,
)
return fmt.Sprintf("node %d | addr: %s", n.id, n.addr)
}
return nilAngleString
}

func (n *Node) setLastErr(err error) {
n.mu.Lock()
defer n.mu.Unlock()
n.lastErr = err
// LastErr returns the last error encountered (if any) for this node.
func (n *Node) LastErr() error {
return n.channel.lastErr()
}

// LastErr returns the last error encountered (if any) when invoking a remote
// procedure call on this node.
func (n *Node) LastErr() error {
n.mu.Lock()
defer n.mu.Unlock()
return n.lastErr
// Latency returns the latency between the client and this node.
func (n *Node) Latency() time.Duration {
return n.channel.channelLatency()
}

type lessFunc func(n1, n2 *Node) bool
Expand Down Expand Up @@ -194,8 +179,8 @@ func (ms *MultiSorter) Swap(i, j int) {
}

// Less is part of sort.Interface. It is implemented by looping along the
// less functions until it finds a comparison that is either Less or
// !Less. Note that it can call the less functions twice per call. We
// less functions until it finds a comparison that is either Less or not
// Less. Note that it can call the less functions twice per call. We
// could change the functions to return -1, 0, 1 and reduce the
// number of calls for greater efficiency: an exercise for the reader.
func (ms *MultiSorter) Less(i, j int) bool {
Expand Down Expand Up @@ -235,7 +220,7 @@ var Port = func(n1, n2 *Node) bool {
// LastNodeError sorts nodes by their LastErr() status in increasing order. A
// node with LastErr() != nil is larger than a node with LastErr() == nil.
var LastNodeError = func(n1, n2 *Node) bool {
if n1.lastErr != nil && n2.lastErr == nil {
if n1.channel.lastErr() != nil && n2.channel.lastErr() == nil {
return false
}
return true
Expand Down
36 changes: 22 additions & 14 deletions node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,32 @@ import (
func TestNodeSort(t *testing.T) {
nodes := []*Node{
{
id: 100,
lastErr: nil,
latency: time.Second,
id: 100,
channel: &channel{
lastError: nil,
latency: time.Second,
},
},
{
id: 101,
lastErr: errors.New("some error"),
latency: 250 * time.Millisecond,
id: 101,
channel: &channel{
lastError: errors.New("some error"),
latency: 250 * time.Millisecond,
},
},
{
id: 42,
lastErr: nil,
latency: 300 * time.Millisecond,
id: 42,
channel: &channel{
lastError: nil,
latency: 300 * time.Millisecond,
},
},
{
id: 99,
lastErr: errors.New("some error"),
latency: 500 * time.Millisecond,
id: 99,
channel: &channel{
lastError: errors.New("some error"),
latency: 500 * time.Millisecond,
},
},
}

Expand All @@ -43,7 +51,7 @@ func TestNodeSort(t *testing.T) {

OrderedBy(LastNodeError).Sort(nodes)
for i := n - 1; i > 0; i-- {
if nodes[i].lastErr == nil && nodes[i-1].lastErr != nil {
if nodes[i].LastErr() == nil && nodes[i-1].LastErr() != nil {
t.Error("by error: not sorted")
printNodes(t, nodes)
}
Expand All @@ -55,7 +63,7 @@ func printNodes(t *testing.T, nodes []*Node) {
for i, n := range nodes {
nodeStr := fmt.Sprintf(
"%d: node %d | addr: %s | latency: %v | err: %v",
i, n.id, n.addr, n.latency, n.lastErr)
i, n.id, n.addr, n.Latency(), n.LastErr())
t.Logf("%s", nodeStr)
}
}

0 comments on commit d1d3428

Please sign in to comment.