Skip to content

Commit 656a7fc

Browse files
feat(broadcast): added broadcast request auth
1 parent 8dba797 commit 656a7fc

18 files changed

+203
-46
lines changed

authentication/authentication.go

+24
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package authentication
22

33
import (
4+
"bytes"
45
"crypto/ecdsa"
56
"crypto/elliptic"
67
"crypto/rand"
@@ -151,3 +152,26 @@ func (ec *EllipticCurve) EncodeMsg(msg any) ([]byte, error) {
151152
}
152153
return encodedMsg.Bytes(), nil*/
153154
}
155+
156+
func encodeMsg(msg any) ([]byte, error) {
157+
return []byte(fmt.Sprintf("%v", msg)), nil
158+
}
159+
160+
func Verify(pemEncodedPub string, signature, digest []byte, msg any) (bool, error) {
161+
encodedMsg, err := encodeMsg(msg)
162+
if err != nil {
163+
return false, err
164+
}
165+
ec := New(elliptic.P256())
166+
h := sha256.Sum256(encodedMsg)
167+
hash := h[:]
168+
if !bytes.Equal(hash, digest) {
169+
return false, fmt.Errorf("wrong digest")
170+
}
171+
pubKey, err := ec.DecodePublic(pemEncodedPub)
172+
if err != nil {
173+
return false, err
174+
}
175+
ok := ecdsa.VerifyASN1(pubKey, hash, signature)
176+
return ok, nil
177+
}

broadcast.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,13 @@ type BroadcastMetadata struct {
161161
OriginAddr string // address of the origin
162162
OriginMethod string // the first method called by the origin
163163
Method string // the current method
164-
Digest []byte // digest of original message sent by client
165164
Timestamp time.Time // timestamp in seconds when the broadcast request was issued by the client/server
166165
ShardID uint16 // ID of the shard handling the broadcast request
167166
MachineID uint16 // ID of the client/server that issued the broadcast request
168167
SequenceNo uint32 // sequence number of the broadcast request from that particular client/server. Will roll over when reaching max.
168+
OriginDigest []byte
169+
OriginSignature []byte
170+
OriginPubKey string
169171
}
170172

171173
func newBroadcastMetadata(md *ordering.Metadata) BroadcastMetadata {
@@ -184,10 +186,17 @@ func newBroadcastMetadata(md *ordering.Metadata) BroadcastMetadata {
184186
SenderAddr: md.BroadcastMsg.SenderAddr,
185187
OriginAddr: md.BroadcastMsg.OriginAddr,
186188
OriginMethod: md.BroadcastMsg.OriginMethod,
189+
OriginDigest: md.BroadcastMsg.OriginDigest,
190+
OriginSignature: md.BroadcastMsg.OriginSignature,
191+
OriginPubKey: md.BroadcastMsg.OriginPubKey,
187192
Method: m,
188193
Timestamp: broadcast.Epoch().Add(time.Duration(timestamp) * time.Second),
189194
ShardID: shardID,
190195
MachineID: machineID,
191196
SequenceNo: sequenceNo,
192197
}
193198
}
199+
200+
func (md BroadcastMetadata) Verify(msg protoreflect.ProtoMessage) (bool, error) {
201+
return authentication.Verify(md.OriginPubKey, md.OriginSignature, md.OriginDigest, msg)
202+
}

broadcast/consts.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ const (
2020
ServerOriginAddr string = "server"
2121
)
2222

23-
type ServerHandler func(ctx context.Context, in protoreflect.ProtoMessage, broadcastID uint64, originAddr, originMethod string, options BroadcastOptions, id uint32, addr string)
23+
type ServerHandler func(ctx context.Context, in protoreflect.ProtoMessage, broadcastID uint64, originAddr, originMethod string, options BroadcastOptions, id uint32, addr string, originDigest, originSignature []byte, originPubKey string)

broadcast/processor.go

+18-3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ type metadata struct {
3939
IsBroadcastClient bool
4040
SentCancellation bool
4141
HasReceivedClientReq bool
42+
OriginDigest []byte
43+
OriginPubKey string
44+
OriginSignature []byte
4245
}
4346

4447
func (p *BroadcastProcessor) run(msg *Content) {
@@ -50,6 +53,9 @@ func (p *BroadcastProcessor) run(msg *Content) {
5053
SendFn: msg.SendFn,
5154
Sent: false,
5255
SentCancellation: false,
56+
OriginDigest: msg.OriginDigest,
57+
OriginSignature: msg.OriginSignature,
58+
OriginPubKey: msg.OriginPubKey,
5359
}
5460
// methods is placed here and not in the metadata as an optimization strategy.
5561
// Testing shows that it does not allocate memory for it on the heap.
@@ -102,7 +108,7 @@ func (p *BroadcastProcessor) handleCancellation(bMsg *Msg, metadata *metadata) b
102108
p.log("broadcast: sent cancellation", nil, logging.MsgType(bMsg.MsgType.String()), logging.Stopping(false))
103109
metadata.SentCancellation = true
104110
go func(broadcastID uint64, cancellationMsg *cancellation) {
105-
_ = p.router.Send(broadcastID, "", "", cancellationMsg)
111+
_ = p.router.Send(broadcastID, "", "", nil, nil, "", cancellationMsg)
106112
}(p.broadcastID, bMsg.Cancellation)
107113
}
108114
return false
@@ -114,7 +120,7 @@ func (p *BroadcastProcessor) handleBroadcast(bMsg *Msg, methods []string, metada
114120
if !bMsg.Msg.options.AllowDuplication && alreadyBroadcasted(methods, bMsg.Method) {
115121
return false
116122
}
117-
err := p.router.Send(p.broadcastID, metadata.OriginAddr, metadata.OriginMethod, bMsg.Msg)
123+
err := p.router.Send(p.broadcastID, metadata.OriginAddr, metadata.OriginMethod, metadata.OriginDigest, metadata.OriginSignature, metadata.OriginPubKey, bMsg.Msg)
118124
p.log("broadcast: sending broadcast", err, logging.MsgType(bMsg.MsgType.String()), logging.Method(bMsg.Method), logging.Stopping(false), logging.IsBroadcastCall(metadata.isBroadcastCall()))
119125

120126
p.updateOrder(bMsg.Method, bMsg.Msg.options.ProgressTo)
@@ -126,7 +132,7 @@ func (p *BroadcastProcessor) handleReply(bMsg *Msg, metadata *metadata) bool {
126132
// BroadcastCall if origin addr is non-empty.
127133
if metadata.isBroadcastCall() {
128134
go func(broadcastID uint64, originAddr, originMethod string, replyMsg *reply) {
129-
err := p.router.Send(broadcastID, originAddr, originMethod, replyMsg)
135+
err := p.router.Send(broadcastID, originAddr, originMethod, metadata.OriginDigest, metadata.OriginSignature, metadata.OriginPubKey, replyMsg)
130136
p.log("broadcast: sent reply to client", err, logging.Method(originMethod), logging.MsgType(bMsg.MsgType.String()), logging.Stopping(true), logging.IsBroadcastCall(metadata.isBroadcastCall()))
131137
}(p.broadcastID, metadata.OriginAddr, metadata.OriginMethod, bMsg.Reply)
132138
// the request is done becuase we have sent a reply to the client
@@ -267,6 +273,15 @@ func (m *metadata) update(new *Content) {
267273
if m.OriginMethod == "" && new.OriginMethod != "" {
268274
m.OriginMethod = new.OriginMethod
269275
}
276+
if m.OriginPubKey == "" && new.OriginPubKey != "" {
277+
m.OriginPubKey = new.OriginPubKey
278+
}
279+
if m.OriginSignature == nil && new.OriginSignature != nil {
280+
m.OriginSignature = new.OriginSignature
281+
}
282+
if m.OriginDigest == nil && new.OriginDigest != nil {
283+
m.OriginDigest = new.OriginDigest
284+
}
270285
if m.SendFn == nil && new.SendFn != nil {
271286
m.SendFn = new.SendFn
272287
m.IsBroadcastClient = new.IsBroadcastClient

broadcast/processor_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type mockRouter struct {
2121
resp protoreflect.ProtoMessage
2222
}
2323

24-
func (r *mockRouter) Send(broadcastID uint64, addr, method string, req msg) error {
24+
func (r *mockRouter) Send(broadcastID uint64, addr, method string, originDigest, originSignature []byte, originPubKey string, req msg) error {
2525
switch val := req.(type) {
2626
case *broadcastMsg:
2727
r.reqType = "Broadcast"

broadcast/router.go

+9-9
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ import (
1515

1616
type Client struct {
1717
Addr string
18-
SendMsg func(broadcastID uint64, method string, msg protoreflect.ProtoMessage, timeout time.Duration) error
18+
SendMsg func(broadcastID uint64, method string, msg protoreflect.ProtoMessage, timeout time.Duration, originDigest, originSignature []byte, originPubKey string) error
1919
Close func() error
2020
}
2121

2222
type Router interface {
23-
Send(broadcastID uint64, addr, method string, req msg) error
23+
Send(broadcastID uint64, addr, method string, originDigest, originSignature []byte, originPubKey string, req msg) error
2424
Connect(addr string)
2525
}
2626

@@ -61,15 +61,15 @@ func (r *BroadcastRouter) registerState(state *BroadcastState) {
6161

6262
type msg interface{}
6363

64-
func (r *BroadcastRouter) Send(broadcastID uint64, addr, method string, req msg) error {
64+
func (r *BroadcastRouter) Send(broadcastID uint64, addr, method string, originDigest, originSignature []byte, originPubKey string, req msg) error {
6565
if r.addr == "" {
6666
panic("The listen addr on the broadcast server cannot be empty. Use the WithListenAddr() option when creating the server.")
6767
}
6868
switch val := req.(type) {
6969
case *broadcastMsg:
70-
return r.routeBroadcast(broadcastID, addr, method, val)
70+
return r.routeBroadcast(broadcastID, addr, method, val, originDigest, originSignature, originPubKey)
7171
case *reply:
72-
return r.routeClientReply(broadcastID, addr, method, val)
72+
return r.routeClientReply(broadcastID, addr, method, val, originDigest, originSignature, originPubKey)
7373
case *cancellation:
7474
r.canceler(broadcastID, val.srvAddrs)
7575
return nil
@@ -83,27 +83,27 @@ func (r *BroadcastRouter) Connect(addr string) {
8383
_, _ = r.getClient(addr)
8484
}
8585

86-
func (r *BroadcastRouter) routeBroadcast(broadcastID uint64, addr, method string, msg *broadcastMsg) error {
86+
func (r *BroadcastRouter) routeBroadcast(broadcastID uint64, addr, method string, msg *broadcastMsg, originDigest, originSignature []byte, originPubKey string) error {
8787
if handler, ok := r.serverHandlers[msg.method]; ok {
8888
// it runs an interceptor prior to broadcastCall, hence a different signature.
8989
// see: (srv *broadcastServer) registerBroadcastFunc(method string).
90-
handler(msg.ctx, msg.request, broadcastID, addr, method, msg.options, r.id, r.addr)
90+
handler(msg.ctx, msg.request, broadcastID, addr, method, msg.options, r.id, r.addr, originDigest, originSignature, originPubKey)
9191
return nil
9292
}
9393
err := errors.New("handler not found")
9494
r.log("router (broadcast): could not find handler", err, logging.BroadcastID(broadcastID), logging.NodeAddr(addr), logging.Method(method))
9595
return err
9696
}
9797

98-
func (r *BroadcastRouter) routeClientReply(broadcastID uint64, addr, method string, resp *reply) error {
98+
func (r *BroadcastRouter) routeClientReply(broadcastID uint64, addr, method string, resp *reply, originDigest, originSignature []byte, originPubKey string) error {
9999
// the client has initiated a broadcast call and the reply should be sent as an RPC
100100
if _, ok := r.clientHandlers[method]; ok && addr != "" {
101101
client, err := r.getClient(addr)
102102
if err != nil {
103103
r.log("router (reply): could not get client", err, logging.BroadcastID(broadcastID), logging.NodeAddr(addr), logging.Method(method))
104104
return err
105105
}
106-
err = client.SendMsg(broadcastID, method, resp.getResponse(), r.dialTimeout)
106+
err = client.SendMsg(broadcastID, method, resp.getResponse(), r.dialTimeout, originDigest, originSignature, originPubKey)
107107
r.log("router (reply): sending reply to client", err, logging.BroadcastID(broadcastID), logging.NodeAddr(addr), logging.Method(method))
108108
return err
109109
}

broadcast/shard_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type slowRouter struct {
1616
resp protoreflect.ProtoMessage
1717
}
1818

19-
func (r *slowRouter) Send(broadcastID uint64, addr, method string, req msg) error {
19+
func (r *slowRouter) Send(broadcastID uint64, addr, method string, originDigest, originSignature []byte, originPubKey string, req msg) error {
2020
time.Sleep(1 * time.Second)
2121
switch val := req.(type) {
2222
case *broadcastMsg:

broadcast/state.go

+3
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ type Content struct {
132132
IsCancellation bool
133133
OriginAddr string
134134
OriginMethod string
135+
OriginPubKey string
136+
OriginSignature []byte
137+
OriginDigest []byte
135138
ViewNumber uint64
136139
SenderAddr string
137140
CurrentMethod string

broadcastcall.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ type BroadcastCallData struct {
1818
SenderAddr string
1919
OriginAddr string
2020
OriginMethod string
21+
OriginPubKey string
22+
OriginSignature []byte
23+
OriginDigest []byte
2124
ServerAddresses []string
2225
SkipSelf bool
2326
}
@@ -46,10 +49,13 @@ func (c RawConfiguration) BroadcastCall(ctx context.Context, d BroadcastCallData
4649
SenderAddr: d.SenderAddr,
4750
OriginAddr: d.OriginAddr,
4851
OriginMethod: d.OriginMethod,
52+
OriginPubKey: d.OriginPubKey,
53+
OriginSignature: d.OriginSignature,
54+
OriginDigest: d.OriginDigest,
4955
}}
5056
msg := &Message{Metadata: md, Message: d.Message}
5157
o := getCallOptions(E_Broadcast, opts)
52-
c.sign(msg)
58+
c.sign(msg, o.signOrigin)
5359

5460
var replyChan chan response
5561
if !o.noSendWaiting {

channel_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func TestChannelUnsuccessfulConnection(t *testing.T) {
113113
}
114114

115115
func TestChannelReconnection(t *testing.T) {
116-
srvAddr := "127.0.0.1:5000"
116+
srvAddr := "127.0.0.1:5005"
117117
// wait to start the server
118118
startServer, stopServer := testServerSetup(t, srvAddr, dummySrv())
119119
node, err := NewRawNode(srvAddr)

clientserver.go

+72-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"sync"
99
"time"
1010

11+
"github.com/relab/gorums/authentication"
1112
"github.com/relab/gorums/broadcast"
1213
"github.com/relab/gorums/logging"
1314
"github.com/relab/gorums/ordering"
@@ -34,6 +35,8 @@ type ClientServer struct {
3435
grpcServer *grpc.Server
3536
handlers map[string]requestHandler
3637
logger *slog.Logger
38+
auth *authentication.EllipticCurve
39+
allowList map[string]string
3740
ordering.UnimplementedGorumsServer
3841
}
3942

@@ -186,6 +189,10 @@ func (s *ClientServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error {
186189
if err != nil {
187190
return err
188191
}
192+
err = s.verify(req)
193+
if err != nil {
194+
continue
195+
}
189196
if handler, ok := s.handlers[req.Metadata.Method]; ok {
190197
go handler(ServerCtx{Context: ctx, once: new(sync.Once), mut: &mut}, req, nil)
191198
mut.Lock()
@@ -221,6 +228,12 @@ func NewClientServer(lis net.Listener, opts ...ServerOption) *ClientServer {
221228
}
222229
ordering.RegisterGorumsServer(srv.grpcServer, srv)
223230
srv.lis = lis
231+
if serverOpts.auth != nil {
232+
srv.auth = serverOpts.auth
233+
}
234+
if serverOpts.allowList != nil {
235+
srv.allowList = serverOpts.allowList
236+
}
224237
return srv
225238
}
226239

@@ -239,6 +252,58 @@ func (srv *ClientServer) Serve(listener net.Listener) error {
239252
return srv.grpcServer.Serve(listener)
240253
}
241254

255+
func (srv *ClientServer) encodeMsg(req *Message) ([]byte, error) {
256+
// we must not consider the signature field when validating.
257+
// also the msgType must be set to requestType.
258+
signature := make([]byte, len(req.Metadata.AuthMsg.Signature))
259+
copy(signature, req.Metadata.AuthMsg.Signature)
260+
reqType := req.msgType
261+
req.Metadata.AuthMsg.Signature = nil
262+
req.msgType = 0
263+
encodedMsg, err := srv.auth.EncodeMsg(*req)
264+
req.Metadata.AuthMsg.Signature = make([]byte, len(signature))
265+
copy(req.Metadata.AuthMsg.Signature, signature)
266+
req.msgType = reqType
267+
return encodedMsg, err
268+
}
269+
270+
func (srv *ClientServer) verify(req *Message) error {
271+
if srv.auth == nil {
272+
return nil
273+
}
274+
if req.Metadata.AuthMsg == nil {
275+
return fmt.Errorf("missing authMsg")
276+
}
277+
if req.Metadata.AuthMsg.Signature == nil {
278+
return fmt.Errorf("missing signature")
279+
}
280+
if req.Metadata.AuthMsg.PublicKey == "" {
281+
return fmt.Errorf("missing publicKey")
282+
}
283+
authMsg := req.Metadata.AuthMsg
284+
if srv.allowList != nil {
285+
pemEncodedPub, ok := srv.allowList[authMsg.Sender]
286+
if !ok {
287+
return fmt.Errorf("not allowed")
288+
}
289+
if pemEncodedPub != authMsg.PublicKey {
290+
return fmt.Errorf("publicKey did not match")
291+
}
292+
}
293+
encodedMsg, err := srv.encodeMsg(req)
294+
if err != nil {
295+
return err
296+
}
297+
valid, err := srv.auth.VerifySignature(authMsg.PublicKey, encodedMsg, authMsg.Signature)
298+
if err != nil {
299+
return err
300+
}
301+
if !valid {
302+
return fmt.Errorf("invalid signature")
303+
}
304+
return nil
305+
}
306+
242307
func createClient(addr string, dialOpts []grpc.DialOption) (*broadcast.Client, error) {
243308
// necessary to ensure correct marshalling and unmarshalling of gorums messages
244309
// TODO: find a better solution
@@ -258,13 +323,16 @@ func createClient(addr string, dialOpts []grpc.DialOption) (*broadcast.Client, e
258323
}
259324
return &broadcast.Client{
260325
Addr: node.Address(),
261-
SendMsg: func(broadcastID uint64, method string, msg protoreflect.ProtoMessage, timeout time.Duration) error {
326+
SendMsg: func(broadcastID uint64, method string, msg protoreflect.ProtoMessage, timeout time.Duration, originDigest, originSignature []byte, originPubKey string) error {
262327
ctx, cancel := context.WithTimeout(context.Background(), timeout)
263328
defer cancel()
264329
cd := CallData{
265-
Method: method,
266-
Message: msg,
267-
BroadcastID: broadcastID,
330+
Method: method,
331+
Message: msg,
332+
BroadcastID: broadcastID,
333+
OriginDigest: originDigest,
334+
OriginSignature: originSignature,
335+
OriginPubKey: originPubKey,
268336
}
269337
_, err := node.RPCCall(ctx, cd)
270338
return err

0 commit comments

Comments
 (0)