diff --git a/channel_test.go b/channel_test.go index 9eab63148..30263730d 100644 --- a/channel_test.go +++ b/channel_test.go @@ -31,8 +31,8 @@ import ( func TestReadWriteMessage(t *testing.T) { var ( w, r = net.Pipe() - ch = newChannel(w) - rch = newChannel(r) + ch = newChannel(w, 0) + rch = newChannel(r, 0) messages = [][]byte{ []byte("hello"), []byte("this is a test"), @@ -90,7 +90,7 @@ func TestReadWriteMessage(t *testing.T) { func TestMessageOversize(t *testing.T) { var ( w, _ = net.Pipe() - wch = newChannel(w) + wch = newChannel(w, 0) msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) errs = make(chan error, 1) ) diff --git a/server_test.go b/server_test.go index cf34986d6..d947db6a6 100644 --- a/server_test.go +++ b/server_test.go @@ -19,6 +19,7 @@ package ttrpc import ( "bytes" "context" + "crypto/md5" "errors" "fmt" "net" @@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (* } // testingServer is what would be implemented by the user of this package. -type testingServer struct{} +type testingServer struct { + echoOnce bool +} func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) { - tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)} + tp := &internal.TestPayload{} + if s.echoOnce { + tp.Foo = req.Foo + } else { + tp.Foo = strings.Repeat(req.Foo, 2) + } if dl, ok := ctx.Deadline(); ok { tp.Deadline = dl.UnixNano() } @@ -299,37 +307,152 @@ func TestServerClose(t *testing.T) { } func TestOversizeCall(t *testing.T) { - var ( - ctx = context.Background() - server = mustServer(t)(NewServer()) - addr, listener = newTestListener(t) - errs = make(chan error, 1) - client, cleanup = newTestClient(t, addr) - ) - defer cleanup() - defer listener.Close() - go func() { - errs <- server.Serve(ctx, listener) - }() + type testCase struct { + name string + echoOnce bool + clientLimit int + serverLimit int + requestSize int + clientFail bool + serverFail bool + } + + overhead := getWireMessageOverhead(t) + + runTest := func(t *testing.T, tc *testCase) { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer(WithServerWireMessageLimit(tc.serverLimit))) + addr, listener = newTestListener(t) + errs = make(chan error, 1) + client, cleanup = newTestClient(t, addr, WithClientWireMessageLimit(tc.clientLimit)) + ) + defer cleanup() + defer listener.Close() + go func() { + errs <- server.Serve(ctx, listener) + }() + + registerTestingService(server, &testingServer{echoOnce: tc.echoOnce}) + + req := &internal.TestPayload{ + Foo: strings.Repeat("a", tc.requestSize), + } + rsp := &internal.TestPayload{} + + err := client.Call(ctx, serviceName, "Test", req, rsp) + if tc.clientFail { + if err == nil { + t.Fatalf("expected error from oversized message") + } else if status, ok := status.FromError(err); !ok { + t.Fatalf("expected status present in error: %v", err) + } else if status.Code() != codes.ResourceExhausted { + t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) + } + } else if tc.serverFail { + if err == nil { + t.Fatalf("expected error from server-side oversized message") + } + } else { + if err != nil { + t.Fatalf("expected success, got error %v", err) + } + } - registerTestingService(server, &testingServer{}) + if err := server.Shutdown(ctx); err != nil { + t.Fatal(err) + } + if err := <-errs; err != ErrServerClosed { + t.Fatal(err) + } + } - tp := &internal.TestPayload{ - Foo: strings.Repeat("a", 1+messageLengthMax), + for _, tc := range []*testCase{ + { + name: "default limits, fitting request and response", + echoOnce: true, + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit - overhead, + }, + { + name: "default limits, oversized request", + echoOnce: true, + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit, + clientFail: true, + }, + { + name: "default limits, oversized response", + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit / 2, + serverFail: true, + }, + { + name: "8K limits, fitting 4K request and response", + echoOnce: true, + clientLimit: 8 * 1024, + serverLimit: 8 * 1024, + requestSize: 4 * 1024, + }, + { + name: "8K limits, fitting cc. 4K request and response", + echoOnce: true, + clientLimit: 4 * 1024, + serverLimit: 4 * 1024, + requestSize: 4*1024 - overhead, + }, + { + name: "4K limits, non-fitting 4K response", + echoOnce: true, + clientLimit: 4*1024 + overhead, + serverLimit: 4 * 1024, + requestSize: 4 * 1024, + serverFail: true, + }, + { + name: "too small limits, adjusted to minimum accepted limit", + echoOnce: true, + clientLimit: 4, + serverLimit: 4, + requestSize: 4*1024 - overhead, + }, + { + name: "maximum allowed protocol limit", + echoOnce: true, + clientLimit: MaxMessageLengthLimit, + serverLimit: MaxMessageLengthLimit, + requestSize: MaxMessageLengthLimit - overhead, + }, + } { + t.Run(tc.name, func(t *testing.T) { + runTest(t, tc) + }) } - if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil { - t.Fatalf("expected error from oversized message") - } else if status, ok := status.FromError(err); !ok { - t.Fatalf("expected status present in error: %v", err) - } else if status.Code() != codes.ResourceExhausted { - t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) +} + +func getWireMessageOverhead(t *testing.T) int { + emptyReq, err := codec{}.Marshal(&Request{ + Service: serviceName, + Method: "Test", + }) + if err != nil { + t.Fatalf("failed to marshal empty request: %v", err) } - if err := server.Shutdown(ctx); err != nil { - t.Fatal(err) + emptyRsp, err := codec{}.Marshal(&Response{ + Status: status.New(codes.OK, "").Proto(), + }) + if err != nil { + t.Fatalf("failed to marshal empty response: %v", err) } - if err := <-errs; err != ErrServerClosed { - t.Fatal(err) + + if reqLen, rspLen := len(emptyReq), len(emptyRsp); reqLen > rspLen { + return reqLen + messageHeaderLength + } else { + return rspLen + messageHeaderLength } } @@ -551,13 +674,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func } func newTestListener(t testing.TB) (string, net.Listener) { - var prefix string + var ( + name = t.Name() + prefix string + ) // Abstracts sockets are only available on Linux. if runtime.GOOS == "linux" { prefix = "\x00" + } else { + if split := strings.SplitN(name, "/", 2); len(split) == 2 { + name = split[0] + "-" + fmt.Sprintf("%x", md5.Sum([]byte(split[1]))) + } } - addr := prefix + t.Name() + addr := prefix + name listener, err := net.Listen("unix", addr) if err != nil { t.Fatal(err)