Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions src/google/protobuf/compiler/cpp/service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void ServiceGenerator::GenerateDeclarations(io::Printer* printer) {
$pb$::RpcController* $nullable$ controller,
const $pb$::Message* $nonnull$ request,
$pb$::Message* $nonnull$ response,
::google::protobuf::Closure* $nullable$ done) override;
::std::function<void()> $nullable$ done) override;

const $pb$::Message& GetRequestPrototype(
const $pb$::MethodDescriptor* $nonnull$ method) const override;
Expand Down Expand Up @@ -109,6 +109,11 @@ void ServiceGenerator::GenerateMethodSignatures(VirtualOrNot virtual_or_not,
const $input$* $nonnull$ request,
$output$* $nonnull$ response,
::google::protobuf::Closure* $nullable$ done)$ override$;

$virtual $void $name$($pb$::RpcController* $nullable$ controller,
const $input$* $nonnull$ request,
$output$* $nonnull$ response,
::std::function<void()> $nullable$ done)$ override$;
)");
}
}
Expand Down Expand Up @@ -179,6 +184,13 @@ void ServiceGenerator::GenerateNotImplementedMethods(io::Printer* printer) {
controller->SetFailed("Method $name$() not implemented.");
done->Run();
}

void $classname$::$name$($pb$::RpcController* $nullable$ controller,
const $input$* $nonnull$ request,
$output$* $nonnull$ response,
::std::function<void()> $nullable$ done) {
$name$(controller, request, response, $pb$::ToClosure(::std::move(done)));
}
)cc");
}
}
Expand All @@ -190,11 +202,12 @@ void ServiceGenerator::GenerateCallMethod(io::Printer* printer) {
{"cases", [&] { GenerateCallMethodCases(printer); }},
},
R"cc(
void $classname$::CallMethod(
const $pb$::MethodDescriptor* $nonnull$ method,
$pb$::RpcController* $nullable$ controller,
const $pb$::Message* $nonnull$ request,
$pb$::Message* $nonnull$ response, ::google::protobuf::Closure* $nullable$ done) {
void $classname$::CallMethod(const $pb$::MethodDescriptor* $nonnull$
method,
$pb$::RpcController* $nullable$ controller,
const $pb$::Message* $nonnull$ request,
$pb$::Message* $nonnull$ response,
::std::function<void()> $nullable$ done) {
ABSL_DCHECK_EQ(method->service(), $file_level_service_descriptors$[$index$]);
switch (method->index()) {
$cases$;
Expand Down Expand Up @@ -284,6 +297,13 @@ void ServiceGenerator::GenerateStubMethods(io::Printer* printer) {
$pb$::RpcController* $nullable$ controller,
const $input$* $nonnull$ request, $output$* $nonnull$ response,
::google::protobuf::Closure* $nullable$ done) {
$name$(controller, request, response, $pb$::ToFunction(done));
}

void $classname$_Stub::$name$(
$pb$::RpcController* $nullable$ controller,
const $input$* $nonnull$ request, $output$* $nonnull$ response,
::std::function<void()> $nullable$ done) {
channel_->CallMethod(descriptor()->method($index$), controller,
request, response, done);
}
Expand Down
51 changes: 25 additions & 26 deletions src/google/protobuf/compiler/cpp/unittest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,8 @@ class GENERATED_SERVICE_TEST_NAME : public testing::Test {
// implements TestService ----------------------------------------

void Foo(RpcController* controller, const UNITTEST::FooRequest* request,
UNITTEST::FooResponse* response, Closure* done) override {
UNITTEST::FooResponse* response,
std::function<void()> done) override {
ASSERT_FALSE(called_);
called_ = true;
method_ = "Foo";
Expand All @@ -1136,7 +1137,8 @@ class GENERATED_SERVICE_TEST_NAME : public testing::Test {
}

void Bar(RpcController* controller, const UNITTEST::BarRequest* request,
UNITTEST::BarResponse* response, Closure* done) override {
UNITTEST::BarResponse* response,
std::function<void()> done) override {
ASSERT_FALSE(called_);
called_ = true;
method_ = "Bar";
Expand All @@ -1153,7 +1155,7 @@ class GENERATED_SERVICE_TEST_NAME : public testing::Test {
RpcController* controller_;
const Message* request_;
Message* response_;
Closure* done_;
std::function<void()> done_;
};

class MockRpcChannel : public RpcChannel {
Expand All @@ -1177,7 +1179,7 @@ class GENERATED_SERVICE_TEST_NAME : public testing::Test {

void CallMethod(const MethodDescriptor* method, RpcController* controller,
const Message* request, Message* response,
Closure* done) override {
std::function<void()> done) override {
ASSERT_FALSE(called_);
called_ = true;
method_ = method;
Expand All @@ -1194,7 +1196,7 @@ class GENERATED_SERVICE_TEST_NAME : public testing::Test {
RpcController* controller_;
const Message* request_;
Message* response_;
Closure* done_;
std::function<void()> done_;
bool* destroyed_;
};

Expand All @@ -1221,17 +1223,17 @@ class GENERATED_SERVICE_TEST_NAME : public testing::Test {
ADD_FAILURE() << "IsCanceled() not expected during this test.";
return false;
}
void NotifyOnCancel(Closure* callback) override {
void NotifyOnCancel(std::function<void()> callback) override {
ADD_FAILURE() << "NotifyOnCancel() not expected during this test.";
}
};

GENERATED_SERVICE_TEST_NAME()
: descriptor_(UNITTEST::TestService::descriptor()),
foo_(descriptor_->FindMethodByName("Foo")),
bar_(descriptor_->FindMethodByName("Bar")),
stub_(&mock_channel_),
done_(::google::protobuf::NewPermanentCallback(&DoNothing)) {}
: descriptor_(UNITTEST::TestService::descriptor()),
foo_(descriptor_->FindMethodByName("Foo")),
bar_(descriptor_->FindMethodByName("Bar")),
stub_(&mock_channel_),
done_(&DoNothing) {}

void SetUp() override {
ASSERT_TRUE(foo_ != nullptr);
Expand All @@ -1253,7 +1255,7 @@ class GENERATED_SERVICE_TEST_NAME : public testing::Test {
UNITTEST::FooResponse foo_response_;
UNITTEST::BarRequest bar_request_;
UNITTEST::BarResponse bar_response_;
std::unique_ptr<Closure> done_;
std::function<void()> done_;
};

TEST_F(GENERATED_SERVICE_TEST_NAME, GetDescriptor) {
Expand Down Expand Up @@ -1284,21 +1286,20 @@ TEST_F(GENERATED_SERVICE_TEST_NAME, CallMethod) {
// Test that CallMethod() works.

// Call Foo() via CallMethod().
mock_service_.CallMethod(foo_, &mock_controller_,
&foo_request_, &foo_response_, done_.get());
mock_service_.CallMethod(foo_, &mock_controller_, &foo_request_,
&foo_response_, done_);

ASSERT_TRUE(mock_service_.called_);

EXPECT_EQ("Foo" , mock_service_.method_ );
EXPECT_EQ(&mock_controller_, mock_service_.controller_);
EXPECT_EQ(&foo_request_ , mock_service_.request_ );
EXPECT_EQ(&foo_response_ , mock_service_.response_ );
EXPECT_EQ(done_.get() , mock_service_.done_ );
EXPECT_EQ(&foo_response_, mock_service_.response_);

// Try again, but call Bar() instead.
mock_service_.Reset();
mock_service_.CallMethod(bar_, &mock_controller_,
&bar_request_, &bar_response_, done_.get());
mock_service_.CallMethod(bar_, &mock_controller_, &bar_request_,
&bar_response_, done_);

ASSERT_TRUE(mock_service_.called_);
EXPECT_EQ("Bar", mock_service_.method_);
Expand All @@ -1310,14 +1311,14 @@ TEST_F(GENERATED_SERVICE_TEST_NAME, CallMethodTypeFailure) {
#if GTEST_HAS_DEATH_TEST // death tests do not work on Windows yet
EXPECT_DEBUG_DEATH(
mock_service_.CallMethod(foo_, &mock_controller_, &foo_request_,
&bar_response_, done_.get()),
&bar_response_, done_),
"Cannot downcast proto2_unittest.*.BarResponse to "
"proto2_unittest.*.FooResponse");

mock_service_.Reset();
EXPECT_DEBUG_DEATH(
mock_service_.CallMethod(foo_, &mock_controller_, &bar_request_,
&foo_response_, done_.get()),
&foo_response_, done_),
"Cannot downcast proto2_unittest.*.BarRequest to "
"proto2_unittest.*.FooRequest");
#endif // GTEST_HAS_DEATH_TEST
Expand All @@ -1341,19 +1342,18 @@ TEST_F(GENERATED_SERVICE_TEST_NAME, Stub) {
// Test that the stub class works.

// Call Foo() via the stub.
stub_.Foo(&mock_controller_, &foo_request_, &foo_response_, done_.get());
stub_.Foo(&mock_controller_, &foo_request_, &foo_response_, done_);

ASSERT_TRUE(mock_channel_.called_);

EXPECT_EQ(foo_ , mock_channel_.method_ );
EXPECT_EQ(&mock_controller_, mock_channel_.controller_);
EXPECT_EQ(&foo_request_ , mock_channel_.request_ );
EXPECT_EQ(&foo_response_ , mock_channel_.response_ );
EXPECT_EQ(done_.get() , mock_channel_.done_ );
EXPECT_EQ(&foo_response_, mock_channel_.response_);

// Call Bar() via the stub.
mock_channel_.Reset();
stub_.Bar(&mock_controller_, &bar_request_, &bar_response_, done_.get());
stub_.Bar(&mock_controller_, &bar_request_, &bar_response_, done_);

ASSERT_TRUE(mock_channel_.called_);
EXPECT_EQ(bar_, mock_channel_.method_);
Expand Down Expand Up @@ -1388,8 +1388,7 @@ TEST_F(GENERATED_SERVICE_TEST_NAME, NotImplemented) {
ExpectUnimplementedController controller;

// Call Foo.
unimplemented_service.Foo(&controller, &foo_request_, &foo_response_,
done_.get());
unimplemented_service.Foo(&controller, &foo_request_, &foo_response_, done_);

EXPECT_TRUE(controller.called_);
}
Expand Down
57 changes: 57 additions & 0 deletions src/google/protobuf/service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,69 @@

#include "google/protobuf/service.h"

#include <functional>
#include <utility>

#include "google/protobuf/stubs/callback.h"
#include "google/protobuf/stubs/common.h"

namespace google {
namespace protobuf {

Service::~Service() {}
RpcChannel::~RpcChannel() {}
RpcController::~RpcController() {}

namespace {

// A Closure that wraps a std::function<void()> and deletes itself on call.
class FunctionClosure final : public Closure {
public:
explicit FunctionClosure(std::function<void()> f) : f_(std::move(f)) {}
void Run() override {
f_();
delete this;
}

private:
std::function<void()> f_;
};

} // namespace

void RpcController::NotifyOnCancel(std::function<void()> callback) {
NotifyOnCancel(ToClosure(std::move(callback)));
}

void RpcController::NotifyOnCancel(Closure* callback) {
NotifyOnCancel(ToFunction(callback));
}

void RpcChannel::CallMethod(const MethodDescriptor* method,
RpcController* controller, const Message* request,
Message* response, std::function<void()> done) {
CallMethod(method, controller, request, response, ToClosure(std::move(done)));
}

void RpcChannel::CallMethod(const MethodDescriptor* method,
RpcController* controller, const Message* request,
Message* response, Closure* done) {
CallMethod(method, controller, request, response, ToFunction(done));
}

Closure* ToClosure(std::function<void()> f) {
if (f == nullptr) {
return nullptr;
}
return new FunctionClosure(std::move(f));
}

std::function<void()> ToFunction(Closure* callback) {
if (callback == nullptr) {
return std::function<void()>();
}
return [callback] { callback->Run(); };
}

} // namespace protobuf
} // namespace google
27 changes: 22 additions & 5 deletions src/google/protobuf/service.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
// void Foo(google::protobuf::RpcController* controller,
// const MyRequest* request,
// MyResponse* response,
// Closure* done) {
// std::function<void()> done) {
// // ... read request and fill in response ...
// done->Run();
// }
Expand All @@ -60,7 +60,7 @@
//
// // ... fill in request ...
//
// stub.Foo(&controller, request, &response, NewCallback(HandleResponse));
// stub.Foo(&controller, request, &response, HandleResponse);
//
// On Thread-Safety:
//
Expand All @@ -77,6 +77,7 @@
#ifndef GOOGLE_PROTOBUF_SERVICE_H__
#define GOOGLE_PROTOBUF_SERVICE_H__

#include <functional>
#include <string>

#include "google/protobuf/stubs/callback.h"
Expand Down Expand Up @@ -152,7 +153,7 @@ class PROTOBUF_EXPORT Service {
// possibly to get more information about the error.
virtual void CallMethod(const MethodDescriptor* method,
RpcController* controller, const Message* request,
Message* response, Closure* done) = 0;
Message* response, std::function<void()> done) = 0;

// CallMethod() requires that the request and response passed in are of a
// particular subclass of Message. GetRequestPrototype() and
Expand Down Expand Up @@ -235,7 +236,11 @@ class PROTOBUF_EXPORT RpcController {
// will be called immediately.
//
// NotifyOnCancel() must be called no more than once per request.
virtual void NotifyOnCancel(Closure* callback) = 0;
//
// One overload of NotifyOnCancel must be implemented. The default
// implementations call each other to allow for migration between them.
virtual void NotifyOnCancel(std::function<void()> callback);
virtual void NotifyOnCancel(Closure* callback);
};

// Abstract interface for an RPC channel. An RpcChannel represents a
Expand All @@ -258,11 +263,23 @@ class PROTOBUF_EXPORT RpcChannel {
// are less strict in one important way: the request and response objects
// need not be of any specific class as long as their descriptors are
// method->input_type() and method->output_type().
//
// One overload of CallMethod must be implemented. The default
// implementations call each other to allow for migration between them.
virtual void CallMethod(const MethodDescriptor* method,
RpcController* controller, const Message* request,
Message* response, std::function<void()> done);
virtual void CallMethod(const MethodDescriptor* method,
RpcController* controller, const Message* request,
Message* response, Closure* done) = 0;
Message* response, Closure* done);
};

// A helper function for converting between Closure* and std::function<void()>.
Closure* PROTOBUF_EXPORT ToClosure(std::function<void()> f);

// A helper function for converting between Closure* and std::function<void()>.
std::function<void()> PROTOBUF_EXPORT ToFunction(Closure* callback);

} // namespace protobuf
} // namespace google

Expand Down
Loading
Loading