-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.go
More file actions
225 lines (184 loc) Β· 5.97 KB
/
Copy pathserver.go
File metadata and controls
225 lines (184 loc) Β· 5.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
package hornet
import (
"context"
"fmt"
"log/slog"
"os"
"reflect"
"strings"
"sync"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// serviceInfo wraps information about a service. It is very similar to
// grpc.ServiceDesc and is constructed from it for internal purposes.
type serviceInfo struct {
	// serviceImpl contains the implementation for the methods in this
	// service. It is passed as the first argument to each method handler and
	// may be nil if RegisterService was called with a nil implementation.
	serviceImpl any
	// methods maps a method name (grpc.MethodDesc.MethodName) to its
	// descriptor. Only the unary methods from the service description are
	// stored here; stream descriptions are not copied.
	methods map[string]*grpc.MethodDesc
}
// serverOptions collects the tunable settings of a Server. NewServer starts
// from defaultServerOptions and lets each ServerOption adjust a copy of it.
type serverOptions struct {
	// logger is where the server writes its diagnostic messages.
	logger *slog.Logger
}

// defaultServerOptions holds the baseline settings used by NewServer before
// any ServerOption is applied. The logger is whatever slog.Default returns
// when this package is initialized.
var defaultServerOptions = serverOptions{logger: slog.Default()}
// Compile-time check that *Server can be handed to generated
// RegisterXxxServer functions as a grpc.ServiceRegistrar.
var _ grpc.ServiceRegistrar = (*Server)(nil)

// Server is a gRPC server that implements the [PluginHandler] interface to
// be used in Wasm plugins. It supports unary RPCs only.
//
// It is similar to grpc.Server, but it does not accept connections from a
// net.Listener. Instead, it implements the [PluginHandler] interface to
// process the bytes sent to the plugin as a gRPC request.
//
// Use NewServer to create one; the zero value has a nil services map.
type Server struct {
	// opts holds the settings resolved by NewServer.
	opts serverOptions
	mu   sync.Mutex // guards following fields
	// services maps a fully-qualified service name (grpc.ServiceDesc.ServiceName)
	// to its registered implementation and method descriptors.
	services map[string]*serviceInfo
}
// NewServer returns a Server with no services registered. It starts from the
// package defaults and then applies each option in opt, in the order given.
// Services are added afterwards through RegisterService (usually via the
// generated RegisterXxxServer helpers).
func NewServer(opt ...ServerOption) *Server {
	srv := &Server{
		opts:     defaultServerOptions,
		services: make(map[string]*serviceInfo),
	}
	for _, option := range opt {
		option.applyServer(&srv.opts)
	}
	return srv
}
// RegisterService registers a service and its implementation with the
// server. Generated IDL code calls it, and in a Wasm plugin it must run
// during package initialization. When ss is non-nil, its dynamic type must
// implement the interface that sd.HandlerType points to.
//
// Any registration problem is logged and the process exits with status 1.
// This matches grpc-go and keeps the plugin from running in an invalid
// state.
//
// Stream RPCs are not supported. If sd describes stream methods, a warning
// is logged and those methods are ignored.
func (s *Server) RegisterService(sd *grpc.ServiceDesc, ss any) {
	if ss != nil {
		want := reflect.TypeOf(sd.HandlerType).Elem()
		got := reflect.TypeOf(ss)
		if !got.Implements(want) {
			s.opts.logger.Error("proto: Server.RegisterService found an incompatible handler type", "want", want, "got", got)
			// A mismatched implementation is a programming error; grpc-go
			// treats it as fatal too.
			os.Exit(1)
		}
	}

	if err := s.register(sd, ss); err != nil {
		s.opts.logger.Error("proto: Server.RegisterService failed", "error", err)
		// Same fatal policy as grpc-go's registration path.
		os.Exit(1)
	}
}
// register records the implementation ss under sd.ServiceName, along with a
// lookup table of sd's unary methods keyed by method name. It fails if a
// service with the same name was already registered. Stream descriptions in
// sd only trigger a warning; they are not stored.
func (s *Server) register(sd *grpc.ServiceDesc, ss any) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	name := sd.ServiceName
	s.opts.logger.Debug("Registering service", "service", name)

	if _, dup := s.services[name]; dup {
		return fmt.Errorf("found duplicate service registration: %q", name)
	}
	if len(sd.Streams) != 0 {
		s.opts.logger.Warn("proto: Server.RegisterService found stream service, "+
			"streams are not supported in Wasm plugins", "service", name)
	}

	// Point at the descriptors inside sd.Methods rather than copying them.
	byName := make(map[string]*grpc.MethodDesc)
	for i := range sd.Methods {
		desc := &sd.Methods[i]
		byName[desc.MethodName] = desc
	}

	s.services[name] = &serviceInfo{
		serviceImpl: ss,
		methods:     byName,
	}
	return nil
}
// Handle implements the [PluginHandler] interface and processes the bytes
// sent to the plugin as a gRPC request.
//
// fn is the full method name in the form "/package.Service/Method". The
// leading slash is optional, as in grpc-go. reqBytes is the encoded request
// message.
//
// The returned bytes begin with a marker byte:
//   - 0, followed by the encoded response, on success;
//   - 1, followed by an encoded status, on failure.
//
// Malformed names, unknown services and unknown methods are reported with
// codes.Unimplemented.
//
// NOTE: the response is encoded into reqBytes' backing array to avoid an
// allocation, so the caller's request buffer is overwritten.
func (s *Server) Handle(fn string, reqBytes []byte) []byte {
	// Start a new context for each request.
	ctx := context.Background()

	// Split the name like grpc-go does: drop one optional leading slash,
	// then split on the last slash. Checking pos before slicing avoids a
	// panic on names such as "/Method", where the only slash is at index 0.
	name := strings.TrimPrefix(fn, "/")
	pos := strings.LastIndex(name, "/")
	if pos == -1 {
		return s.handleError(
			status.New(codes.Unimplemented, "malformed method name"),
			"method", fn,
		)
	}
	service, method := name[:pos], name[pos+1:]

	// services is guarded by mu. Hold the lock only for the lookup so a
	// handler can never deadlock against registration.
	s.mu.Lock()
	srv, ok := s.services[service]
	s.mu.Unlock()
	if !ok {
		return s.handleError(
			status.New(codes.Unimplemented, "unknown service"),
			"service", service,
		)
	}

	// srv.methods is only written by register before the service is
	// published, so reading it needs no lock.
	sd, ok := srv.methods[method]
	if !ok {
		return s.handleError(
			status.New(codes.Unimplemented, "unknown method"),
			"service", service, "method", method,
		)
	}

	decFn := func(v any) error {
		return protoUnmarshal(reqBytes, v)
	}
	// No interceptor is supported, hence the nil final argument.
	resp, err := sd.Handler(srv.serviceImpl, ctx, decFn, nil)
	if err != nil {
		st, ok := status.FromError(err)
		if !ok {
			// Map context errors (canceled, deadline exceeded) to their
			// gRPC codes; anything else becomes codes.Unknown.
			st = status.FromContextError(err)
			err = st.Err()
		}
		return s.handleError(st, "service", service, "method", method, "error", err)
	}

	// NB: We overwrite the request bytes to reuse the same bytes buffer and
	// possibly avoid allocations. This relies on decoding having copied
	// everything the response needs out of reqBytes.
	respBytes, err := protoMarshalAppend(reqBytes[:0], resp)
	if err != nil {
		return s.handleError(
			status.New(codes.Internal, "error marshalling response"),
			"service", service, "method", method, "response", resp, "error", err,
		)
	}
	return respBytes
}
// handleError logs st at debug level, together with the key/value pairs in
// args, and turns it into an error response for the client. The response
// is a single 1 byte (the error marker) followed by the protobuf encoding of
// st.Proto().
//
// It panics if the status message cannot be marshalled. That should never
// happen, and there is no other way to report a failure to the client.
func (s *Server) handleError(st *status.Status, args ...any) []byte {
	s.opts.logger.Debug("ERROR: proto: Server.Handle "+st.Message(), args...)

	// The first byte tells the client if it's an error or a valid response.
	payload, err := proto.MarshalOptions{}.MarshalAppend([]byte{1}, st.Proto())
	if err == nil {
		return payload
	}
	panic(fmt.Errorf("proto: error marshalling error: %w", err))
}
// protoMarshalAppend appends a success response for v to data. The response
// is a 0 marker byte followed by v's protobuf encoding, and the extended
// slice is returned.
//
// v must implement proto.Message. If it does not, data is returned
// unchanged together with an error.
func protoMarshalAppend(data []byte, v any) ([]byte, error) {
	msg, isMsg := v.(proto.Message)
	if !isMsg {
		return data, fmt.Errorf("proto: error marshalling data: expected proto.Message, got %T", v)
	}

	// The leading 0 tells the client this is a valid response, not an error.
	out, err := proto.MarshalOptions{}.MarshalAppend(append(data, 0), msg)
	if err != nil {
		return out, fmt.Errorf("proto: error marshalling data: %w", err)
	}
	return out, nil
}
// protoUnmarshal decodes the protobuf bytes in data into v. v must implement
// proto.Message. It is used as the decode callback passed to generated
// method handlers.
func protoUnmarshal(data []byte, v any) error {
	if msg, ok := v.(proto.Message); ok {
		if err := proto.Unmarshal(data, msg); err != nil {
			return fmt.Errorf("proto: error unmarshalling data: %w", err)
		}
		return nil
	}
	return fmt.Errorf("proto: error unmarshalling data: expected proto.Message, got %T", v)
}