Skip to content

Commit 13f2b0d

Browse files
test(broadcast): added tests for broadcast
1 parent 15c8d66 commit 13f2b0d

7 files changed

+139
-82
lines changed

broadcast.go

+18-28
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@ package gorums
22

33
import (
44
"context"
5-
"errors"
6-
"log"
7-
"sync"
85
"time"
96

107
"github.com/google/uuid"
@@ -14,24 +11,20 @@ import (
1411
)
1512

1613
type broadcastServer struct {
17-
sync.RWMutex
1814
id string
1915
addr string
2016
broadcastedMsgs map[string]map[string]bool
2117
handlers map[string]broadcastFunc
2218
broadcastChan chan *broadcastMsg
23-
responseChan chan responseMsg
19+
responseChan chan *responseMsg
2420
clientHandlers map[string]func(addr, broadcastID string, req protoreflect.ProtoMessage, opts ...grpc.CallOption) (any, error)
2521
bNew iBroadcastStruct
2622
timeout time.Duration
2723
clientReqs *RequestMap
2824
view serverView
2925
middlewares []func(BroadcastMetadata) error
3026
stopChan chan struct{}
31-
//mutex sync.RWMutex
32-
//returnedToClientMsgs map[string]bool
33-
//clientReqs map[string]*clientRequest
34-
//clientReqsMutex sync.Mutex
27+
async bool
3528
}
3629

3730
func newBroadcastServer() *broadcastServer {
@@ -41,12 +34,10 @@ func newBroadcastServer() *broadcastServer {
4134
clientHandlers: make(map[string]func(addr, broadcastID string, req protoreflect.ProtoMessage, opts ...grpc.CallOption) (any, error)),
4235
broadcastChan: make(chan *broadcastMsg, 1000),
4336
handlers: make(map[string]broadcastFunc),
44-
responseChan: make(chan responseMsg),
37+
responseChan: make(chan *responseMsg),
4538
clientReqs: NewRequestMap(),
4639
middlewares: make([]func(BroadcastMetadata) error, 0),
4740
stopChan: make(chan struct{}, 0),
48-
//returnedToClientMsgs: make(map[string]bool),
49-
//clientReqs: make(map[string]*clientRequest),
5041
}
5142
}
5243

@@ -77,26 +68,25 @@ func (srv *broadcastServer) handleClientResponses() {
7768
}
7869
}
7970

80-
func (srv *broadcastServer) handle(response responseMsg) {
71+
func (srv *broadcastServer) handle(response *responseMsg) {
8172
broadcastID := response.getBroadcastID()
82-
req, handled := srv.clientReqs.GetSet(broadcastID)
73+
req, handled := srv.clientReqs.GetAndSetHandled(broadcastID)
8374
if handled {
8475
// this server has not received a request directly from a client
8576
// hence, the response should be ignored
8677
// already handled and can not be removed yet. It is possible to get duplicates.
8778
return
8879
}
89-
select {
90-
case <-req.ctx.Done():
91-
// client request has been cancelled by client
92-
log.Println("CLIENT REQUEST HAS BEEN CANCELLED")
93-
return
94-
default:
95-
}
80+
//select {
81+
//case <-req.ctx.Done():
82+
// // client request has been cancelled by client
83+
// log.Println("CLIENT REQUEST HAS BEEN CANCELLED")
84+
// return
85+
//default:
86+
//}
9687
if !response.valid() {
9788
// the response is old and should have timed out, but may not due to scheduling.
9889
// the timeout msg should arrive soon.
99-
log.Println("NOT VALID")
10090
return
10191
}
10292
if handler, ok := srv.clientHandlers[req.metadata.BroadcastMsg.OriginMethod]; ok && req.metadata.BroadcastMsg.OriginAddr != "" {
@@ -121,17 +111,17 @@ func (srv *broadcastServer) clientReturn(resp ResponseTypes, err error, metadata
121111
}
122112

123113
func (srv *broadcastServer) returnToClient(broadcastID string, resp ResponseTypes, err error) {
124-
srv.Lock()
125-
defer srv.Unlock()
114+
//srv.Lock()
115+
//defer srv.Unlock()
126116
if !srv.alreadyReturnedToClient(broadcastID) {
127117
srv.responseChan <- newResponseMessage(resp, err, broadcastID, clientResponse, srv.timeout)
128118
}
129119
}
130120

131-
func (srv *broadcastServer) timeoutClientResponse(ctx ServerCtx, in *Message, finished chan<- *Message) {
132-
time.Sleep(srv.timeout)
133-
srv.responseChan <- newResponseMessage(protoreflect.ProtoMessage(nil), errors.New("server timed out"), in.Metadata.BroadcastMsg.GetBroadcastID(), timeout, srv.timeout)
134-
}
121+
//func (srv *broadcastServer) timeoutClientResponse(ctx ServerCtx, in *Message, finished chan<- *Message) {
122+
// time.Sleep(srv.timeout)
123+
// srv.responseChan <- newResponseMessage(protoreflect.ProtoMessage(nil), errors.New("server timed out"), in.Metadata.BroadcastMsg.GetBroadcastID(), timeout, srv.timeout)
124+
//}
135125

136126
func (srv *broadcastServer) alreadyReturnedToClient(broadcastID string) bool {
137127
req, ok := srv.clientReqs.Get(broadcastID)

broadcastTypes.go

+14-22
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,6 @@ type defaultImplementationFunc[T RequestTypes, V ResponseTypes] func(ServerCtx,
3838

3939
type implementationFunc[T RequestTypes, V iBroadcastStruct] func(ServerCtx, T, V)
4040

41-
type responseMsg interface {
42-
getResponse() ResponseTypes
43-
getError() error
44-
getBroadcastID() string
45-
valid() bool
46-
getType() respType
47-
}
48-
4941
type respType int
5042

5143
const (
@@ -55,7 +47,7 @@ const (
5547
done
5648
)
5749

58-
type responseMessage struct {
50+
type responseMsg struct {
5951
response ResponseTypes
6052
err error
6153
broadcastID string
@@ -64,8 +56,8 @@ type responseMessage struct {
6456
respType respType
6557
}
6658

67-
func newResponseMessage(response ResponseTypes, err error, broadcastID string, respType respType, ttl time.Duration) *responseMessage {
68-
return &responseMessage{
59+
func newResponseMessage(response ResponseTypes, err error, broadcastID string, respType respType, ttl time.Duration) *responseMsg {
60+
return &responseMsg{
6961
response: response,
7062
err: err,
7163
broadcastID: broadcastID,
@@ -75,23 +67,23 @@ func newResponseMessage(response ResponseTypes, err error, broadcastID string, r
7567
}
7668
}
7769

78-
func (r *responseMessage) getResponse() ResponseTypes {
70+
func (r *responseMsg) getResponse() ResponseTypes {
7971
return r.response
8072
}
8173

82-
func (r *responseMessage) getError() error {
74+
func (r *responseMsg) getError() error {
8375
return r.err
8476
}
8577

86-
func (r *responseMessage) getBroadcastID() string {
78+
func (r *responseMsg) getBroadcastID() string {
8779
return r.broadcastID
8880
}
8981

90-
func (r *responseMessage) valid() bool {
82+
func (r *responseMsg) valid() bool {
9183
return r.respType == clientResponse && time.Since(r.timestamp) <= r.ttl
9284
}
9385

94-
func (r *responseMessage) getType() respType {
86+
func (r *responseMsg) getType() respType {
9587
return r.respType
9688
}
9789

@@ -260,13 +252,13 @@ func (list *RequestMap) Get(identifier string) (clientRequest, bool) {
260252
return elem, ok
261253
}
262254

263-
func (list *RequestMap) Set(identifier string, elem clientRequest) {
264-
list.mutex.Lock()
265-
defer list.mutex.Unlock()
266-
list.data[identifier] = elem
267-
}
255+
//func (list *RequestMap) Set(identifier string, elem clientRequest) {
256+
// list.mutex.Lock()
257+
// defer list.mutex.Unlock()
258+
// list.data[identifier] = elem
259+
//}
268260

269-
func (list *RequestMap) GetSet(identifier string) (clientRequest, bool) {
261+
func (list *RequestMap) GetAndSetHandled(identifier string) (clientRequest, bool) {
270262
list.mutex.Lock()
271263
defer list.mutex.Unlock()
272264
elem, ok := list.data[identifier]

broadcast_test.go

+83-3
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,19 @@ func (b *testBroadcast) Broadcast(req *testBroadcastRequest, opts ...BroadcastOp
6262
for _, opt := range opts {
6363
opt(&data)
6464
}
65-
b.sp.BroadcastHandler("broadcast", req, b.metadata, data)
65+
b.sp.BroadcastHandler("Broadcast", req, b.metadata, data)
66+
}
67+
68+
func (b *testBroadcast) CanBroadcast(req *testBroadcastRequest, opts ...BroadcastOption) {
69+
data := NewBroadcastOptions()
70+
for _, opt := range opts {
71+
opt(&data)
72+
}
73+
b.sp.BroadcastHandler("CanBroadcast", req, b.metadata, data)
74+
}
75+
76+
func (b *testBroadcast) SendToClient(resp protoreflect.ProtoMessage, err error) {
77+
b.sp.ReturnToClientHandler(resp, err, b.metadata)
6678
}
6779

6880
//
@@ -175,6 +187,10 @@ type testBroadcastServer struct {
175187
req *testBroadcastRequest
176188
}
177189

190+
func (srv *testBroadcastServer) SendToClient(resp protoreflect.ProtoMessage, err error, broadcastID string) {
191+
srv.RetToClient(resp, err, broadcastID)
192+
}
193+
178194
func newTestBroadcastServer() *testBroadcastServer {
179195
srv := &testBroadcastServer{
180196
Server: NewServer(),
@@ -194,6 +210,18 @@ func (srv *testBroadcastServer) Broadcast(ctx ServerCtx, request *testBroadcastR
194210
//broadcast.Broadcast(&testBroadcastRequest{})
195211
}
196212

213+
func (srv *testBroadcastServer) CanBroadcast(ctx ServerCtx, request *testBroadcastRequest, broadcast *testBroadcast) {
214+
srv.numMsgs++
215+
srv.req = request
216+
go broadcast.Broadcast(&testBroadcastRequest{})
217+
}
218+
219+
func (srv *testBroadcastServer) SendToClientBroadcast(ctx ServerCtx, request *testBroadcastRequest, broadcast *testBroadcast) {
220+
srv.numMsgs++
221+
srv.req = request
222+
go broadcast.SendToClient(&testBroadcastRequest{}, nil)
223+
}
224+
197225
func createReq(val string) (ServerCtx, *Message, chan<- *Message) {
198226
var mut sync.Mutex
199227
mut.Lock()
@@ -216,16 +244,16 @@ func createReq(val string) (ServerCtx, *Message, chan<- *Message) {
216244
}
217245

218246
func TestBroadcastHandler(t *testing.T) {
219-
handlerName := "broadcast"
247+
handlerName := "Broadcast"
220248

221249
// create a server
222250
srv := newTestBroadcastServer()
223251
// register the broadcast handler. Similar to proto option: broadcast
224252
srv.RegisterHandler(handlerName, BroadcastHandler(srv.Broadcast, srv.Server))
225253

226-
// create a request
227254
vals := []string{"test1", "test2", "test3"}
228255
for _, val := range vals {
256+
// create the request
229257
srvCtx, req, finished := createReq(val)
230258
// call the server handler
231259
srv.srv.handlers[handlerName](srvCtx, req, finished)
@@ -235,3 +263,55 @@ func TestBroadcastHandler(t *testing.T) {
235263
}
236264
}
237265
}
266+
267+
func TestCanBroadcastHandler(t *testing.T) {
268+
handlerName := "CanBroadcast"
269+
270+
// create a server
271+
srv := newTestBroadcastServer()
272+
// register the broadcast handler. Similar to proto option: broadcast
273+
srv.RegisterHandler(handlerName, BroadcastHandler(srv.CanBroadcast, srv.Server))
274+
275+
vals := []string{"test1", "test2", "test3"}
276+
for _, val := range vals {
277+
// create the request
278+
srvCtx, req, finished := createReq(val)
279+
// call the server handler
280+
srv.srv.handlers[handlerName](srvCtx, req, finished)
281+
broadcastMsg := <-srv.broadcastSrv.broadcastChan
282+
283+
if broadcastMsg == nil {
284+
t.Errorf("broadcastMsg should not be nil")
285+
continue
286+
}
287+
if broadcastMsg.broadcastID != req.Metadata.BroadcastMsg.BroadcastID {
288+
t.Errorf("broadcastID = %v, expected %s", broadcastMsg.broadcastID, req.Metadata.BroadcastMsg.BroadcastID)
289+
}
290+
}
291+
}
292+
293+
func TestBroadcastSendToClient(t *testing.T) {
294+
handlerName := "SendToClientBroadcast"
295+
296+
// create a server
297+
srv := newTestBroadcastServer()
298+
// register the broadcast handler. Similar to proto option: broadcast
299+
srv.RegisterHandler(handlerName, BroadcastHandler(srv.SendToClientBroadcast, srv.Server))
300+
301+
vals := []string{"test1", "test2", "test3"}
302+
for _, val := range vals {
303+
// create the request
304+
srvCtx, req, finished := createReq(val)
305+
// call the server handler
306+
srv.srv.handlers[handlerName](srvCtx, req, finished)
307+
responseMsg := <-srv.broadcastSrv.responseChan
308+
309+
if responseMsg == nil {
310+
t.Errorf("responseMsg should not be nil")
311+
continue
312+
}
313+
if responseMsg.getBroadcastID() != req.Metadata.BroadcastMsg.BroadcastID {
314+
t.Errorf("broadcastID = %v, expected %s", responseMsg.getBroadcastID(), req.Metadata.BroadcastMsg.BroadcastID)
315+
}
316+
}
317+
}

cmd/protoc-gen-gorums/dev/zorums_broadcast_gorums.pb.go

+9-9
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cmd/protoc-gen-gorums/dev/zorums_qspec_gorums.pb.go

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cmd/protoc-gen-gorums/gengorums/template_broadcast.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ var broadcastVar = `
77
var broadcastSignature = `func (b *Broadcast) {{.Method.GoName}}(req *{{in .GenFile .Method}}, opts... gorums.BroadcastOption) {`
88

99
var broadcastBody = `
10-
data := gorums.NewBroadcastOptions()
10+
options := gorums.NewBroadcastOptions()
1111
for _, opt := range opts {
12-
opt(&data)
12+
opt(&options)
1313
}
14-
b.sp.BroadcastHandler("{{.Method.Desc.FullName}}", req, b.metadata, data)
14+
b.sp.BroadcastHandler("{{.Method.Desc.FullName}}", req, b.metadata, options)
1515
}
1616
`
1717

0 commit comments

Comments
 (0)