Skip to content

Commit 5b734e0

Browse files
refactor(protoc): moved some of the generated code
1 parent a741341 commit 5b734e0

11 files changed

+282
-339
lines changed

broadcast_test.go

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package gorums_test
2+
3+
import (
4+
"context"
5+
"net"
6+
"testing"
7+
"time"
8+
9+
"github.com/relab/gorums"
10+
"google.golang.org/grpc"
11+
"google.golang.org/grpc/credentials/insecure"
12+
"google.golang.org/grpc/metadata"
13+
)
14+
15+
func createBroadcastServer(ownAddr string, srvAddrs []string) error {
16+
srv := gorums.NewServer()
17+
18+
lis, err := net.Listen("tcp", ":0")
19+
if err != nil {
20+
return err
21+
}
22+
23+
go func() { _ = srv.Serve(lis) }()
24+
defer srv.Stop()
25+
26+
srv.RegisterView(ownAddr, srvAddrs)
27+
srv.ListenForBroadcast()
28+
return nil
29+
}
30+
31+
func TestBroadcast(t *testing.T) {
32+
var message string
33+
signal := make(chan struct{})
34+
35+
srv := gorums.NewServer(gorums.WithConnectCallback(func(ctx context.Context) {
36+
m, ok := metadata.FromIncomingContext(ctx)
37+
if !ok {
38+
return
39+
}
40+
message = m.Get("message")[0]
41+
signal <- struct{}{}
42+
}))
43+
44+
lis, err := net.Listen("tcp", ":0")
45+
if err != nil {
46+
t.Fatal(err)
47+
}
48+
49+
go func() { _ = srv.Serve(lis) }()
50+
defer srv.Stop()
51+
52+
md := metadata.New(map[string]string{"message": "hello"})
53+
54+
mgr := gorums.NewRawManager(
55+
gorums.WithDialTimeout(time.Second),
56+
gorums.WithMetadata(md),
57+
gorums.WithGrpcDialOptions(
58+
grpc.WithBlock(),
59+
grpc.WithTransportCredentials(insecure.NewCredentials()),
60+
),
61+
)
62+
defer mgr.Close()
63+
64+
node, err := gorums.NewRawNode(lis.Addr().String())
65+
if err != nil {
66+
t.Fatal(err)
67+
}
68+
69+
if err = mgr.AddNode(node); err != nil {
70+
t.Fatal(err)
71+
}
72+
73+
select {
74+
case <-time.After(100 * time.Millisecond):
75+
case <-signal:
76+
}
77+
78+
if message != "hello" {
79+
t.Errorf("incorrect message: got '%s', want 'hello'", message)
80+
}
81+
}

clientserver.go

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
package gorums
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"strings"
8+
9+
"github.com/google/uuid"
10+
"google.golang.org/grpc"
11+
"google.golang.org/grpc/credentials/insecure"
12+
"google.golang.org/grpc/metadata"
13+
"google.golang.org/protobuf/reflect/protoreflect"
14+
)
15+
16+
type ReplySpecHandler func([]protoreflect.ProtoMessage) (protoreflect.ProtoMessage, bool)
17+
18+
type ClientResponse struct {
19+
broadcastID string
20+
data protoreflect.ProtoMessage
21+
}
22+
23+
type ClientRequest struct {
24+
broadcastID string
25+
doneChan chan protoreflect.ProtoMessage
26+
handler ReplySpecHandler
27+
}
28+
29+
type ClientServer struct {
30+
respChan chan *ClientResponse
31+
reqChan chan *ClientRequest
32+
resps map[string][]protoreflect.ProtoMessage
33+
doneChans map[string]chan protoreflect.ProtoMessage
34+
handlers map[string]func(resps []protoreflect.ProtoMessage) (protoreflect.ProtoMessage, bool)
35+
listenAddr string
36+
}
37+
38+
func NewClientServer(listenAddr string) (*ClientServer, net.Listener, error) {
39+
srv := &ClientServer{
40+
respChan: make(chan *ClientResponse, 10),
41+
reqChan: make(chan *ClientRequest),
42+
resps: make(map[string][]protoreflect.ProtoMessage),
43+
doneChans: make(map[string]chan protoreflect.ProtoMessage),
44+
handlers: make(map[string]func(resps []protoreflect.ProtoMessage) (protoreflect.ProtoMessage, bool)),
45+
}
46+
lis, err := net.Listen("tcp", listenAddr)
47+
for err != nil {
48+
return nil, nil, err
49+
}
50+
srv.listenAddr = lis.Addr().String()
51+
go srv.handle()
52+
return srv, lis, nil
53+
}
54+
55+
func (srv *ClientServer) AddRequest(ctx context.Context, in protoreflect.ProtoMessage, handler ReplySpecHandler) (chan protoreflect.ProtoMessage, QuorumCallData) {
56+
broadcastID := uuid.New().String()
57+
cd := QuorumCallData{
58+
Message: in,
59+
Method: "protos.UniformBroadcast.SaveStudent",
60+
61+
BroadcastID: broadcastID,
62+
Sender: BroadcastClient,
63+
OriginAddr: srv.listenAddr,
64+
}
65+
doneChan := make(chan protoreflect.ProtoMessage)
66+
srv.reqChan <- &ClientRequest{
67+
broadcastID: broadcastID,
68+
doneChan: doneChan,
69+
handler: handler,
70+
}
71+
return doneChan, cd
72+
}
73+
74+
func (srv *ClientServer) AddResponse(ctx context.Context, resp protoreflect.ProtoMessage) error {
75+
md, ok := metadata.FromIncomingContext(ctx)
76+
if !ok {
77+
return fmt.Errorf("no metadata")
78+
}
79+
broadcastID := ""
80+
val := md.Get(BroadcastID)
81+
if val != nil && len(val) >= 1 {
82+
broadcastID = val[0]
83+
}
84+
if broadcastID == "" {
85+
return fmt.Errorf("no broadcastID")
86+
}
87+
srv.respChan <- &ClientResponse{
88+
broadcastID: broadcastID,
89+
data: resp,
90+
}
91+
return nil
92+
}
93+
94+
func (srv *ClientServer) handle() {
95+
for {
96+
select {
97+
case resp := <-srv.respChan:
98+
if _, ok := srv.resps[resp.broadcastID]; !ok {
99+
continue
100+
}
101+
srv.resps[resp.broadcastID] = append(srv.resps[resp.broadcastID], resp.data)
102+
response, done := srv.handlers[resp.broadcastID](srv.resps[resp.broadcastID])
103+
if done {
104+
srv.doneChans[resp.broadcastID] <- response
105+
close(srv.doneChans[resp.broadcastID])
106+
delete(srv.resps, resp.broadcastID)
107+
delete(srv.doneChans, resp.broadcastID)
108+
delete(srv.handlers, resp.broadcastID)
109+
}
110+
case req := <-srv.reqChan:
111+
srv.resps[req.broadcastID] = make([]protoreflect.ProtoMessage, 0)
112+
srv.doneChans[req.broadcastID] = req.doneChan
113+
srv.handlers[req.broadcastID] = req.handler
114+
}
115+
}
116+
}
117+
118+
func ConvertToType[T protoreflect.ProtoMessage](handler func([]T) (T, bool)) func(d []protoreflect.ProtoMessage) (protoreflect.ProtoMessage, bool) {
119+
return func(d []protoreflect.ProtoMessage) (protoreflect.ProtoMessage, bool) {
120+
data := make([]T, len(d))
121+
for i, elem := range d {
122+
data[i] = elem.(T)
123+
}
124+
return handler(data)
125+
}
126+
}
127+
128+
func ServerClientRPC(method string) func(addr, broadcastID string, in protoreflect.ProtoMessage, opts ...grpc.CallOption) (any, error) {
129+
return func(addr, broadcastID string, in protoreflect.ProtoMessage, opts ...grpc.CallOption) (any, error) {
130+
tmp := strings.Split(method, ".")
131+
m := ""
132+
if len(tmp) >= 1 {
133+
m = tmp[len(tmp)-1]
134+
}
135+
method = "/protos.ClientServer/Client" + m
136+
cc, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
137+
if err != nil {
138+
return nil, err
139+
}
140+
out := new(any)
141+
md := metadata.Pairs(BroadcastID, broadcastID)
142+
ctx := metadata.NewOutgoingContext(context.Background(), md)
143+
err = cc.Invoke(ctx, method, in, out, opts...)
144+
if err != nil {
145+
return nil, err
146+
}
147+
return nil, nil
148+
}
149+
}

cmd/protoc-gen-gorums/dev/config.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@ import (
88
// procedure calls may be invoked.
99
type Configuration struct {
1010
gorums.RawConfiguration
11-
nodes []*Node
12-
qspec QuorumSpec
13-
srv *clientServerImpl
14-
listenAddr string
15-
replySpec ReplySpec
11+
nodes []*Node
12+
qspec QuorumSpec
13+
srv *clientServerImpl
14+
replySpec ReplySpec
1615
}
1716

1817
// ConfigurationFromRaw returns a new Configuration from the given raw configuration and QuorumSpec.

cmd/protoc-gen-gorums/dev/server.go

+11-96
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,8 @@
11
package dev
22

33
import (
4-
context "context"
5-
"net"
6-
"strings"
7-
84
"github.com/relab/gorums"
95
grpc "google.golang.org/grpc"
10-
"google.golang.org/grpc/credentials/insecure"
11-
"google.golang.org/grpc/metadata"
12-
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
136
)
147

158
type Server struct {
@@ -61,101 +54,23 @@ func (b *Broadcast) GetMetadata() gorums.BroadcastMetadata {
6154
return b.metadata
6255
}
6356

64-
type clientResponse struct {
65-
broadcastID string
66-
data protoreflect.ProtoMessage
67-
}
68-
69-
type clientRequest struct {
70-
broadcastID string
71-
doneChan chan protoreflect.ProtoMessage
72-
handler func([]protoreflect.ProtoMessage) (protoreflect.ProtoMessage, error)
73-
}
74-
7557
type clientServerImpl struct {
58+
*gorums.ClientServer
7659
grpcServer *grpc.Server
77-
respChan chan *clientResponse
78-
reqChan chan *clientRequest
79-
resps map[string][]protoreflect.ProtoMessage
80-
doneChans map[string]chan protoreflect.ProtoMessage
81-
handlers map[string]func(resps []protoreflect.ProtoMessage) (protoreflect.ProtoMessage, error)
8260
}
8361

84-
func (c *Configuration) RegisterClientServer(listenAddr string, replySpec ReplySpec) {
85-
var opts []grpc.ServerOption
86-
srv := &clientServerImpl{
62+
func (c *Configuration) RegisterClientServer(listenAddr string, replySpec ReplySpec, opts ...grpc.ServerOption) error {
63+
srvImpl := &clientServerImpl{
8764
grpcServer: grpc.NewServer(opts...),
88-
respChan: make(chan *clientResponse, 10),
89-
reqChan: make(chan *clientRequest),
90-
resps: make(map[string][]protoreflect.ProtoMessage),
91-
doneChans: make(map[string]chan protoreflect.ProtoMessage),
92-
handlers: make(map[string]func(resps []protoreflect.ProtoMessage) (protoreflect.ProtoMessage, error)),
9365
}
94-
lis, err := net.Listen("tcp", listenAddr)
95-
for err != nil {
96-
return
66+
srv, lis, err := gorums.NewClientServer(listenAddr)
67+
if err != nil {
68+
return err
9769
}
98-
c.listenAddr = lis.Addr().String()
99-
srv.grpcServer.RegisterService(&clientServer_ServiceDesc, srv)
100-
go srv.grpcServer.Serve(lis)
101-
go srv.handle()
102-
c.srv = srv
70+
srvImpl.grpcServer.RegisterService(&clientServer_ServiceDesc, srvImpl)
71+
go srvImpl.grpcServer.Serve(lis)
72+
srvImpl.ClientServer = srv
73+
c.srv = srvImpl
10374
c.replySpec = replySpec
104-
}
105-
106-
func (srv *clientServerImpl) handle() {
107-
for {
108-
select {
109-
case resp := <-srv.respChan:
110-
if _, ok := srv.resps[resp.broadcastID]; !ok {
111-
continue
112-
}
113-
srv.resps[resp.broadcastID] = append(srv.resps[resp.broadcastID], resp.data)
114-
response, err := srv.handlers[resp.broadcastID](srv.resps[resp.broadcastID])
115-
if err == nil {
116-
srv.doneChans[resp.broadcastID] <- response
117-
close(srv.doneChans[resp.broadcastID])
118-
delete(srv.resps, resp.broadcastID)
119-
delete(srv.doneChans, resp.broadcastID)
120-
delete(srv.handlers, resp.broadcastID)
121-
}
122-
case req := <-srv.reqChan:
123-
srv.resps[req.broadcastID] = make([]protoreflect.ProtoMessage, 0)
124-
srv.doneChans[req.broadcastID] = req.doneChan
125-
srv.handlers[req.broadcastID] = req.handler
126-
}
127-
}
128-
}
129-
130-
func convertToType[T protoreflect.ProtoMessage](handler func([]T) (T, error)) func(d []protoreflect.ProtoMessage) (protoreflect.ProtoMessage, error) {
131-
return func(d []protoreflect.ProtoMessage) (protoreflect.ProtoMessage, error) {
132-
data := make([]T, len(d))
133-
for i, elem := range d {
134-
data[i] = elem.(T)
135-
}
136-
return handler(data)
137-
}
138-
}
139-
140-
func _serverClientRPC(method string) func(addr, broadcastID string, in protoreflect.ProtoMessage, opts ...grpc.CallOption) (any, error) {
141-
return func(addr, broadcastID string, in protoreflect.ProtoMessage, opts ...grpc.CallOption) (any, error) {
142-
tmp := strings.Split(method, ".")
143-
m := ""
144-
if len(tmp) >= 1 {
145-
m = tmp[len(tmp)-1]
146-
}
147-
method = "/protos.ClientServer/Client" + m
148-
cc, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
149-
if err != nil {
150-
return nil, err
151-
}
152-
out := new(any)
153-
md := metadata.Pairs(gorums.BroadcastID, broadcastID)
154-
ctx := metadata.NewOutgoingContext(context.Background(), md)
155-
err = cc.Invoke(ctx, method, in, out, opts...)
156-
if err != nil {
157-
return nil, err
158-
}
159-
return nil, nil
160-
}
75+
return nil
16176
}

0 commit comments

Comments
 (0)