From 51a48cba4799cdc3c354a42e7025f58bffd63c81 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 10 May 2021 19:24:53 +0200 Subject: [PATCH 1/2] Decoupled node from channel --- channel.go | 45 +++++++++++++++++++++++++++++++++++--------- node.go | 53 +++++++++++++++++++--------------------------------- node_test.go | 36 +++++++++++++++++++++-------------- 3 files changed, 77 insertions(+), 57 deletions(-) diff --git a/channel.go b/channel.go index b9d98a4b..d1c59182 100644 --- a/channel.go +++ b/channel.go @@ -10,6 +10,7 @@ import ( "github.com/relab/gorums/ordering" "google.golang.org/grpc" + "google.golang.org/grpc/backoff" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/reflect/protoreflect" @@ -29,7 +30,11 @@ type response struct { type channel struct { sendQ chan request - node *Node // needed for ID and setLastError + nodeID uint32 + mu sync.Mutex + lastErr error + latency time.Duration + backoffCfg backoff.Config rand *rand.Rand gorumsClient ordering.GorumsClient gorumsStream ordering.Gorums_NodeStreamClient @@ -45,7 +50,9 @@ type channel struct { func newChannel(n *Node) *channel { return &channel{ sendQ: make(chan request, n.mgr.opts.sendBuffer), - node: n, + backoffCfg: n.mgr.opts.backoff, + nodeID: n.ID(), + latency: -1 * time.Second, rand: rand.New(rand.NewSource(time.Now().UnixNano())), responseRouter: make(map[uint64]chan<- response), } @@ -113,7 +120,7 @@ func (c *channel) sendMsg(req request) (err error) { err = c.gorumsStream.SendMsg(req.msg) if err != nil { - c.node.setLastErr(err) + c.setLastErr(err) c.streamBroken.set() } done <- struct{}{} @@ -132,14 +139,14 @@ func (c *channel) sendMsgs() { // return error if stream is broken if c.streamBroken.get() { err := status.Errorf(codes.Unavailable, "stream is down") - c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.node.ID(), msg: nil, err: err}) + c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.nodeID, msg: nil, err: err}) continue } // else try to send message err := c.sendMsg(req) if err != nil { // return the error - c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.node.ID(), msg: nil, err: err}) + c.routeResponse(req.msg.Metadata.MessageID, response{nid: c.nodeID, msg: nil, err: err}) } } } @@ -152,13 +159,13 @@ func (c *channel) recvMsgs() { if err != nil { c.streamBroken.set() c.streamMut.RUnlock() - c.node.setLastErr(err) + c.setLastErr(err) // attempt to reconnect c.reconnect() } else { c.streamMut.RUnlock() err := status.FromProto(resp.Metadata.GetStatus()).Err() - c.routeResponse(resp.Metadata.MessageID, response{nid: c.node.ID(), msg: resp.Message, err: err}) + c.routeResponse(resp.Metadata.MessageID, response{nid: c.nodeID, msg: resp.Message, err: err}) } select { @@ -172,7 +179,7 @@ func (c *channel) recvMsgs() { func (c *channel) reconnect() { c.streamMut.Lock() defer c.streamMut.Unlock() - backoffCfg := c.node.mgr.opts.backoff + backoffCfg := c.backoffCfg var retries float64 for { @@ -185,7 +192,7 @@ func (c *channel) reconnect() { return } c.cancelStream() - c.node.setLastErr(err) + c.setLastErr(err) delay := float64(backoffCfg.BaseDelay) max := float64(backoffCfg.MaxDelay) for r := retries; delay < max && r > 0; r-- { @@ -202,6 +209,26 @@ func (c *channel) reconnect() { } } +func (c *channel) setLastErr(err error) { + c.mu.Lock() + defer c.mu.Unlock() + c.lastErr = err +} + +// LastErr returns the last error encountered (if any) when using this channel. +func (c *channel) LastErr() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.lastErr +} + +// Latency returns the latency between the client and this channel. +func (c *channel) Latency() time.Duration { + c.mu.Lock() + defer c.mu.Unlock() + return c.latency +} + type atomicFlag struct { flag int32 } diff --git a/node.go b/node.go index d9dcdaa7..4d9329f6 100644 --- a/node.go +++ b/node.go @@ -7,7 +7,6 @@ import ( "net" "sort" "strconv" - "sync" "time" "google.golang.org/grpc" @@ -20,14 +19,11 @@ const nilAngleString = "" // can be performed. type Node struct { // Only assigned at creation. - id uint32 - addr string - conn *grpc.ClientConn - cancel func() - mu sync.Mutex - lastErr error - latency time.Duration - mgr *Manager + id uint32 + addr string + conn *grpc.ClientConn + cancel func() + mgr *Manager // the default channel channel *channel @@ -42,9 +38,8 @@ func NewNode(addr string) (*Node, error) { h := fnv.New32a() _, _ = h.Write([]byte(tcpAddr.String())) return &Node{ - id: h.Sum32(), - addr: tcpAddr.String(), - latency: -1 * time.Second, + id: h.Sum32(), + addr: tcpAddr.String(), }, nil } @@ -55,9 +50,8 @@ func NewNodeWithID(addr string, id uint32) (*Node, error) { return nil, fmt.Errorf("node error: '%s' error: %v", addr, err) } return &Node{ - id: id, - addr: tcpAddr.String(), - latency: -1 * time.Second, + id: id, + addr: tcpAddr.String(), }, nil } @@ -136,28 +130,19 @@ func (n *Node) String() string { // includes id, network address and latency information. func (n *Node) FullString() string { if n != nil { - n.mu.Lock() - defer n.mu.Unlock() - return fmt.Sprintf( - "node %d | addr: %s | latency: %v", - n.id, n.addr, n.latency, - ) + return fmt.Sprintf("node %d | addr: %s", n.id, n.addr) } return nilAngleString } -func (n *Node) setLastErr(err error) { - n.mu.Lock() - defer n.mu.Unlock() - n.lastErr = err +// LastErr returns the last error encountered (if any) for this node. +func (n *Node) LastErr() error { + return n.channel.LastErr() } -// LastErr returns the last error encountered (if any) when invoking a remote -// procedure call on this node. -func (n *Node) LastErr() error { - n.mu.Lock() - defer n.mu.Unlock() - return n.lastErr +// Latency returns the latency between the client and this node. +func (n *Node) Latency() time.Duration { + return n.channel.Latency() } type lessFunc func(n1, n2 *Node) bool @@ -194,8 +179,8 @@ func (ms *MultiSorter) Swap(i, j int) { } // Less is part of sort.Interface. It is implemented by looping along the -// less functions until it finds a comparison that is either Less or -// !Less. Note that it can call the less functions twice per call. We +// less functions until it finds a comparison that is either Less or not +// Less. Note that it can call the less functions twice per call. We // could change the functions to return -1, 0, 1 and reduce the // number of calls for greater efficiency: an exercise for the reader. func (ms *MultiSorter) Less(i, j int) bool { @@ -235,7 +220,7 @@ var Port = func(n1, n2 *Node) bool { // LastNodeError sorts nodes by their LastErr() status in increasing order. A // node with LastErr() != nil is larger than a node with LastErr() == nil. var LastNodeError = func(n1, n2 *Node) bool { - if n1.lastErr != nil && n2.lastErr == nil { + if n1.channel.LastErr() != nil && n2.channel.LastErr() == nil { return false } return true diff --git a/node_test.go b/node_test.go index d22f7fa2..d6af35d4 100644 --- a/node_test.go +++ b/node_test.go @@ -10,24 +10,32 @@ import ( func TestNodeSort(t *testing.T) { nodes := []*Node{ { - id: 100, - lastErr: nil, - latency: time.Second, + id: 100, + channel: &channel{ + lastErr: nil, + latency: time.Second, + }, }, { - id: 101, - lastErr: errors.New("some error"), - latency: 250 * time.Millisecond, + id: 101, + channel: &channel{ + lastErr: errors.New("some error"), + latency: 250 * time.Millisecond, + }, }, { - id: 42, - lastErr: nil, - latency: 300 * time.Millisecond, + id: 42, + channel: &channel{ + lastErr: nil, + latency: 300 * time.Millisecond, + }, }, { - id: 99, - lastErr: errors.New("some error"), - latency: 500 * time.Millisecond, + id: 99, + channel: &channel{ + lastErr: errors.New("some error"), + latency: 500 * time.Millisecond, + }, }, } @@ -43,7 +51,7 @@ func TestNodeSort(t *testing.T) { OrderedBy(LastNodeError).Sort(nodes) for i := n - 1; i > 0; i-- { - if nodes[i].lastErr == nil && nodes[i-1].lastErr != nil { + if nodes[i].LastErr() == nil && nodes[i-1].LastErr() != nil { t.Error("by error: not sorted") printNodes(t, nodes) } @@ -55,7 +63,7 @@ func printNodes(t *testing.T, nodes []*Node) { for i, n := range nodes { nodeStr := fmt.Sprintf( "%d: node %d | addr: %s | latency: %v | err: %v", - i, n.id, n.addr, n.latency, n.lastErr) + i, n.id, n.addr, n.Latency(), n.LastErr()) t.Logf("%s", nodeStr) } } From f685e26fe536ec139ce61dc12ada33527d463e40 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Mon, 10 May 2021 19:38:10 +0200 Subject: [PATCH 2/2] Made channel methods private --- channel.go | 14 +++++++------- node.go | 6 +++--- node_test.go | 16 ++++++++-------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/channel.go b/channel.go index d1c59182..cffda694 100644 --- a/channel.go +++ b/channel.go @@ -32,7 +32,7 @@ type channel struct { sendQ chan request nodeID uint32 mu sync.Mutex - lastErr error + lastError error latency time.Duration backoffCfg backoff.Config rand *rand.Rand @@ -212,18 +212,18 @@ func (c *channel) reconnect() { func (c *channel) setLastErr(err error) { c.mu.Lock() defer c.mu.Unlock() - c.lastErr = err + c.lastError = err } -// LastErr returns the last error encountered (if any) when using this channel. -func (c *channel) LastErr() error { +// lastErr returns the last error encountered (if any) when using this channel. +func (c *channel) lastErr() error { c.mu.Lock() defer c.mu.Unlock() - return c.lastErr + return c.lastError } -// Latency returns the latency between the client and this channel. -func (c *channel) Latency() time.Duration { +// channelLatency returns the latency between the client and this channel. +func (c *channel) channelLatency() time.Duration { c.mu.Lock() defer c.mu.Unlock() return c.latency diff --git a/node.go b/node.go index 4d9329f6..52309417 100644 --- a/node.go +++ b/node.go @@ -137,12 +137,12 @@ func (n *Node) FullString() string { // LastErr returns the last error encountered (if any) for this node. func (n *Node) LastErr() error { - return n.channel.LastErr() + return n.channel.lastErr() } // Latency returns the latency between the client and this node. func (n *Node) Latency() time.Duration { - return n.channel.Latency() + return n.channel.channelLatency() } type lessFunc func(n1, n2 *Node) bool @@ -220,7 +220,7 @@ var Port = func(n1, n2 *Node) bool { // LastNodeError sorts nodes by their LastErr() status in increasing order. A // node with LastErr() != nil is larger than a node with LastErr() == nil. var LastNodeError = func(n1, n2 *Node) bool { - if n1.channel.LastErr() != nil && n2.channel.LastErr() == nil { + if n1.channel.lastErr() != nil && n2.channel.lastErr() == nil { return false } return true diff --git a/node_test.go b/node_test.go index d6af35d4..3a260411 100644 --- a/node_test.go +++ b/node_test.go @@ -12,29 +12,29 @@ func TestNodeSort(t *testing.T) { { id: 100, channel: &channel{ - lastErr: nil, - latency: time.Second, + lastError: nil, + latency: time.Second, }, }, { id: 101, channel: &channel{ - lastErr: errors.New("some error"), - latency: 250 * time.Millisecond, + lastError: errors.New("some error"), + latency: 250 * time.Millisecond, }, }, { id: 42, channel: &channel{ - lastErr: nil, - latency: 300 * time.Millisecond, + lastError: nil, + latency: 300 * time.Millisecond, }, }, { id: 99, channel: &channel{ - lastErr: errors.New("some error"), - latency: 500 * time.Millisecond, + lastError: errors.New("some error"), + latency: 500 * time.Millisecond, }, }, }