Skip to content

Commit 1c097fe

Browse files
feat(Broadcast): router implementation
The router is used to send messages to other servers (broadcast) or to reply to clients (clientReply)
1 parent 6804a92 commit 1c097fe

File tree

3 files changed

+226
-0
lines changed

3 files changed

+226
-0
lines changed

broadcast/router/connPool.go

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package router
2+
3+
import (
4+
"github.com/relab/gorums/broadcast/dtos"
5+
"sync"
6+
)
7+
8+
// ConnPool is used to persist connection from the server to other clients.
9+
// This significantly increases performance because it reuses connections for separate
10+
// messages.
11+
type ConnPool struct {
12+
mut sync.Mutex
13+
clients map[string]*dtos.Client
14+
}
15+
16+
func newConnPool() *ConnPool {
17+
return &ConnPool{
18+
clients: make(map[string]*dtos.Client),
19+
}
20+
}
21+
22+
func (cp *ConnPool) getClient(addr string) (*dtos.Client, bool) {
23+
cp.mut.Lock()
24+
defer cp.mut.Unlock()
25+
client, ok := cp.clients[addr]
26+
return client, ok
27+
}
28+
29+
func (cp *ConnPool) addClient(addr string, client *dtos.Client) {
30+
cp.mut.Lock()
31+
defer cp.mut.Unlock()
32+
cp.clients[addr] = client
33+
}
34+
35+
func (cp *ConnPool) Close() error {
36+
var err error = nil
37+
for _, client := range cp.clients {
38+
clientErr := client.Close()
39+
if clientErr != nil {
40+
err = clientErr
41+
}
42+
}
43+
return err
44+
}

broadcast/router/consts.go

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package router
2+
3+
import "github.com/relab/gorums/broadcast/dtos"
4+
5+
type ServerHandler func(broadcastMsg *dtos.BroadcastMsg)
6+
7+
const (
8+
// ServerOriginAddr is special origin Addr used in creating a broadcast request from a server
9+
ServerOriginAddr string = "server"
10+
)

broadcast/router/router.go

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package router
2+
3+
import (
4+
"context"
5+
"errors"
6+
"github.com/relab/gorums/broadcast/dtos"
7+
errs "github.com/relab/gorums/broadcast/errors"
8+
"log/slog"
9+
"sync"
10+
"time"
11+
12+
"github.com/relab/gorums/logging"
13+
"google.golang.org/grpc"
14+
"google.golang.org/grpc/credentials/insecure"
15+
)
16+
17+
// Router is defined as an interface in order to allow mocking it in tests.
18+
type Router interface {
19+
Broadcast(dto *dtos.BroadcastMsg) error
20+
ReplyToClient(dto *dtos.ReplyMsg) error
21+
Connect(addr string)
22+
Close() error
23+
}
24+
25+
type router struct {
26+
mut sync.RWMutex
27+
id uint32
28+
addr string
29+
serverHandlers map[string]ServerHandler // handlers on other servers
30+
clientHandlers map[string]struct{} // specifies what handlers a client has implemented. Used only for BroadcastCalls.
31+
createClient func(addr string, dialOpts []grpc.DialOption) (*dtos.Client, error)
32+
dialOpts []grpc.DialOption
33+
dialTimeout time.Duration
34+
logger *slog.Logger
35+
connPool *ConnPool
36+
allowList map[string]string // whitelist of (address, pubKey) pairs the server can reply to
37+
}
38+
39+
type Config struct {
40+
ID uint32
41+
Addr string
42+
Logger *slog.Logger
43+
CreateClient func(addr string, dialOpts []grpc.DialOption) (*dtos.Client, error)
44+
DialTimeout time.Duration
45+
AllowList map[string]string
46+
DialOpts []grpc.DialOption
47+
}
48+
49+
func NewRouter(config *Config) *router {
50+
if len(config.DialOpts) <= 0 {
51+
config.DialOpts = []grpc.DialOption{
52+
grpc.WithTransportCredentials(insecure.NewCredentials()),
53+
}
54+
}
55+
return &router{
56+
id: config.ID,
57+
addr: config.Addr,
58+
serverHandlers: make(map[string]ServerHandler),
59+
clientHandlers: make(map[string]struct{}),
60+
createClient: config.CreateClient,
61+
dialOpts: config.DialOpts,
62+
dialTimeout: config.DialTimeout,
63+
logger: config.Logger,
64+
allowList: config.AllowList,
65+
connPool: newConnPool(),
66+
}
67+
}
68+
func (r *router) Connect(addr string) {
69+
_, _ = r.getClient(addr)
70+
}
71+
72+
func (r *router) Broadcast(dto *dtos.BroadcastMsg) error {
73+
if handler, ok := r.serverHandlers[dto.Info.Method]; ok {
74+
// it runs an interceptor prior to broadcastCall, hence a different signature.
75+
// see: (srv *broadcastServer) registerBroadcastFunc(method string).
76+
handler(dto)
77+
return nil
78+
}
79+
err := errors.New("handler not found")
80+
r.log("router (broadcast): could not find handler", err, &dto.Info)
81+
return err
82+
}
83+
84+
func (r *router) ReplyToClient(dto *dtos.ReplyMsg) error {
85+
// the client has initiated a broadcast call and the reply should be sent as an RPC
86+
if _, ok := r.clientHandlers[dto.Info.Method]; ok && dto.ClientAddr != "" {
87+
client, err := r.getClient(dto.ClientAddr)
88+
if err != nil {
89+
//r.log("router (reply): could not get client", err, logging.BroadcastID(dto.BroadcastID), logging.NodeAddr(dto.Addr), logging.Method(dto.Method))
90+
return err
91+
}
92+
err = client.SendMsg(r.dialTimeout, dto)
93+
r.log("router (reply): sending reply to client", err, &dto.Info)
94+
return err
95+
}
96+
// the server can receive a broadcast from another server before a client sends a direct message.
97+
// it should thus wait for a potential message from the client. otherwise, it should be removed.
98+
err := errors.New("not routed")
99+
r.log("router (reply): could not find handler", err, &dto.Info)
100+
return err
101+
}
102+
103+
func (r *router) validAddr(addr string) bool {
104+
if addr == "" {
105+
return false
106+
}
107+
if addr == ServerOriginAddr {
108+
return false
109+
}
110+
if r.allowList != nil {
111+
_, ok := r.allowList[addr]
112+
return ok
113+
}
114+
return true
115+
}
116+
117+
func (r *router) getClient(addr string) (*dtos.Client, error) {
118+
if !r.validAddr(addr) {
119+
return nil, errs.InvalidAddrErr{Addr: addr}
120+
}
121+
// fast path:
122+
// read lock because it is likely that we will send many
123+
// messages to the same client.
124+
r.mut.RLock()
125+
if client, ok := r.connPool.getClient(addr); ok {
126+
r.mut.RUnlock()
127+
return client, nil
128+
}
129+
r.mut.RUnlock()
130+
// slow path:
131+
// we need a write lock when adding a new client. This only process
132+
// one at a time and is thus necessary to check if the client has
133+
// already been added again. Otherwise, we can end up creating multiple
134+
// clients.
135+
r.mut.Lock()
136+
defer r.mut.Unlock()
137+
if client, ok := r.connPool.getClient(addr); ok {
138+
return client, nil
139+
}
140+
client, err := r.createClient(addr, r.dialOpts)
141+
if err != nil {
142+
return nil, err
143+
}
144+
r.connPool.addClient(addr, client)
145+
return client, nil
146+
}
147+
148+
func (r *router) log(msg string, err error, info *dtos.Info) {
149+
if r.logger != nil {
150+
args := []slog.Attr{logging.BroadcastID(info.BroadcastID), logging.NodeAddr(info.Addr), logging.Method(info.Method), logging.Err(err), logging.Type("router")}
151+
level := slog.LevelInfo
152+
if err != nil {
153+
level = slog.LevelError
154+
}
155+
r.logger.LogAttrs(context.Background(), level, msg, args...)
156+
}
157+
}
158+
159+
func (r *router) Close() error {
160+
return r.connPool.Close()
161+
}
162+
163+
func (r *router) AddHandler(method string, handler any) {
164+
switch h := handler.(type) {
165+
case ServerHandler:
166+
r.serverHandlers[method] = h
167+
default:
168+
// only needs to know whether the handler exists. routing is done
169+
// client-side using the provided metadata in the request.
170+
r.clientHandlers[method] = struct{}{}
171+
}
172+
}

0 commit comments

Comments
 (0)