Skip to content

Commit 8359a09

Browse files
rwgkcopybara-github
authored and committed
Remove two pool-membership conditions guarding the C++ equivalent of obj.SerializePartialToString()
The main change in this CL is to remove two conditions in `PyProtoIsCompatible()`:

1. `if (descriptor->file()->pool() != DescriptorPool::generated_pool()) {`
2. `return py_pool->is(GlobalState::instance()->global_pool());`

Rationale for removing these conditions:

* All that matters for protobuf compatibility is that the `full_name` is the same. (Thanks @kmoffett for that insight!)
* Cross-extension-module ABI compatibility is not a concern because only the Python API is used in the relevant code paths, which serialize Python protobuf objects to Python `bytes` (equivalent to calling `obj.SerializePartialToString()` from Python).

All other changes in this CL are secondary: small-scale refactoring, slight naming changes, additional tests for error conditions.

PiperOrigin-RevId: 589898116
1 parent b713501 commit 8359a09

File tree

4 files changed

+103
-78
lines changed

4 files changed

+103
-78
lines changed

pybind11_protobuf/proto_cast_util.cc

+43-62
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "absl/strings/numbers.h"
2222
#include "absl/strings/str_replace.h"
2323
#include "absl/strings/str_split.h"
24+
#include "absl/strings/string_view.h"
2425
#include "absl/types/optional.h"
2526
#include "pybind11_protobuf/check_unknown_fields.h"
2627

@@ -534,10 +535,8 @@ class PythonDescriptorPoolWrapper {
534535
}
535536
}
536537

537-
py::object wire = py_file_descriptor.attr("serialized_pb");
538-
const char* bytes = PYBIND11_BYTES_AS_STRING(wire.ptr());
539-
return output->ParsePartialFromArray(bytes,
540-
PYBIND11_BYTES_SIZE(wire.ptr()));
538+
return output->ParsePartialFromString(
539+
PyBytesAsStringView(py_file_descriptor.attr("serialized_pb")));
541540
}
542541

543542
py::object pool_; // never dereferenced.
@@ -549,6 +548,11 @@ class PythonDescriptorPoolWrapper {
549548

550549
} // namespace
551550

551+
absl::string_view PyBytesAsStringView(py::bytes py_bytes) {
552+
return absl::string_view(PyBytes_AsString(py_bytes.ptr()),
553+
PyBytes_Size(py_bytes.ptr()));
554+
}
555+
552556
void InitializePybindProtoCastUtil() {
553557
assert(PyGILState_Check());
554558
GlobalState::instance();
@@ -593,7 +597,7 @@ const Message* PyProtoGetCppMessagePointer(py::handle src) {
593597
#endif
594598
}
595599

596-
absl::optional<std::string> PyProtoDescriptorName(py::handle py_proto) {
600+
absl::optional<std::string> PyProtoDescriptorFullName(py::handle py_proto) {
597601
assert(PyGILState_Check());
598602
auto py_full_name = ResolveAttrs(py_proto, {"DESCRIPTOR", "full_name"});
599603
if (py_full_name) {
@@ -602,66 +606,42 @@ absl::optional<std::string> PyProtoDescriptorName(py::handle py_proto) {
602606
return absl::nullopt;
603607
}
604608

605-
bool PyProtoIsCompatible(py::handle py_proto, const Descriptor* descriptor) {
606-
assert(PyGILState_Check());
607-
if (descriptor->file()->pool() != DescriptorPool::generated_pool()) {
608-
/// This indicates that the C++ descriptor does not come from the C++
609-
/// DescriptorPool. This may happen if the C++ code has the same proto
610-
/// in different descriptor pools, perhaps from different shared objects,
611-
/// and could be result in undefined behavior.
612-
return false;
613-
}
614-
615-
auto py_descriptor = ResolveAttrs(py_proto, {"DESCRIPTOR"});
616-
if (!py_descriptor) {
617-
// Not a valid protobuf -- missing DESCRIPTOR.
618-
return false;
619-
}
620-
621-
// Test full_name equivalence.
622-
{
623-
auto py_full_name = ResolveAttrs(*py_descriptor, {"full_name"});
624-
if (!py_full_name) {
625-
// Not a valid protobuf -- missing DESCRIPTOR.full_name
626-
return false;
627-
}
628-
auto full_name = CastToOptionalString(*py_full_name);
629-
if (!full_name || *full_name != descriptor->full_name()) {
630-
// Name mismatch.
631-
return false;
632-
}
633-
}
634-
635-
// The C++ descriptor is compiled in (see above assert), so the py_proto
636-
// is expected to be from the global pool, i.e. the DESCRIPTOR.file.pool
637-
// instance is the global python pool, and not a custom pool.
638-
auto py_pool = ResolveAttrs(*py_descriptor, {"file", "pool"});
639-
if (py_pool) {
640-
return py_pool->is(GlobalState::instance()->global_pool());
641-
}
642-
643-
// The py_proto is missing a DESCRIPTOR.file.pool, but the name matches.
644-
// This will not happen with a native python implementation, but does
645-
// occur with the deprecated :proto_casters, and could happen with other
646-
// mocks. Returning true allows the caster to call PyProtoCopyToCProto.
647-
return true;
609+
bool PyProtoHasMatchingFullName(py::handle py_proto,
610+
const Descriptor* descriptor) {
611+
auto full_name = PyProtoDescriptorFullName(py_proto);
612+
return full_name && *full_name == descriptor->full_name();
648613
}
649614

650-
bool PyProtoCopyToCProto(py::handle py_proto, Message* message) {
651-
assert(PyGILState_Check());
652-
auto serialize_fn = ResolveAttrMRO(py_proto, "SerializePartialToString");
615+
py::bytes PyProtoSerializePartialToString(py::handle py_proto,
616+
bool raise_if_error) {
617+
static const char* serialize_fn_name = "SerializePartialToString";
618+
auto serialize_fn = ResolveAttrMRO(py_proto, serialize_fn_name);
653619
if (!serialize_fn) {
654-
throw py::type_error(
655-
"SerializePartialToString method not found; is this a " +
656-
message->GetDescriptor()->full_name());
657-
}
658-
auto wire = (*serialize_fn)();
659-
const char* bytes = PYBIND11_BYTES_AS_STRING(wire.ptr());
660-
if (!bytes) {
661-
throw py::type_error("SerializePartialToString failed; is this a " +
662-
message->GetDescriptor()->full_name());
620+
return py::object();
621+
}
622+
auto serialized_bytes = py::reinterpret_steal<py::object>(
623+
PyObject_CallObject(serialize_fn->ptr(), nullptr));
624+
if (!serialized_bytes) {
625+
if (raise_if_error) {
626+
std::string msg = py::repr(py_proto).cast<std::string>() + "." +
627+
serialize_fn_name + "() function call FAILED";
628+
py::raise_from(PyExc_TypeError, msg.c_str());
629+
throw py::error_already_set();
630+
}
631+
return py::object();
632+
}
633+
if (!PyBytes_Check(serialized_bytes.ptr())) {
634+
if (raise_if_error) {
635+
std::string msg = py::repr(py_proto).cast<std::string>() + "." +
636+
serialize_fn_name +
637+
"() function call is expected to return bytes, but the "
638+
"returned value is " +
639+
py::repr(serialized_bytes).cast<std::string>();
640+
throw py::type_error(msg);
641+
}
642+
return py::object();
663643
}
664-
return message->ParsePartialFromArray(bytes, PYBIND11_BYTES_SIZE(wire.ptr()));
644+
return serialized_bytes;
665645
}
666646

667647
void CProtoCopyToPyProto(Message* message, py::handle py_proto) {
@@ -686,7 +666,8 @@ std::unique_ptr<Message> AllocateCProtoFromPythonSymbolDatabase(
686666
assert(PyGILState_Check());
687667
auto pool = ResolveAttrs(src, {"DESCRIPTOR", "file", "pool"});
688668
if (!pool) {
689-
throw py::type_error("Object is not a valid protobuf");
669+
throw py::type_error(py::repr(src).cast<std::string>() +
670+
" object is not a valid protobuf");
690671
}
691672

692673
auto pool_data =

pybind11_protobuf/proto_cast_util.h

+14-8
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "google/protobuf/descriptor.h"
1616
#include "google/protobuf/message.h"
17+
#include "absl/strings/string_view.h"
1718
#include "absl/types/optional.h"
1819

1920
// PYBIND11_PROTOBUF_ASSUME_FULL_ABI_COMPATIBILITY can be defined by users
@@ -28,6 +29,10 @@
2829

2930
namespace pybind11_protobuf {
3031

32+
// Simple helper. Caller has to ensure that the py_bytes argument outlives the
33+
// returned string_view.
34+
absl::string_view PyBytesAsStringView(pybind11::bytes py_bytes);
35+
3136
// Initialize internal proto cast dependencies, which includes importing
3237
// various protobuf-related modules.
3338
void InitializePybindProtoCastUtil();
@@ -39,22 +44,23 @@ void ImportProtoDescriptorModule(const ::google::protobuf::Descriptor *);
3944
const ::google::protobuf::Message *PyProtoGetCppMessagePointer(pybind11::handle src);
4045

4146
// Returns the protocol buffer's py_proto.DESCRIPTOR.full_name attribute.
42-
absl::optional<std::string> PyProtoDescriptorName(pybind11::handle py_proto);
47+
absl::optional<std::string> PyProtoDescriptorFullName(
48+
pybind11::handle py_proto);
49+
50+
// Returns true if py_proto full name matches descriptor full name.
51+
bool PyProtoHasMatchingFullName(pybind11::handle py_proto,
52+
const ::google::protobuf::Descriptor *descriptor);
4353

44-
// Return whether py_proto is compatible with the C++ descriptor.
45-
// The py_proto name must match the C++ Descriptor::full_name(), and is
46-
// expected to originate from the python default pool, which means that
47-
// this method will return false for dynamic protos.
48-
bool PyProtoIsCompatible(pybind11::handle py_proto,
49-
const ::google::protobuf::Descriptor *descriptor);
54+
// Caller should enforce any type identity that is required.
55+
pybind11::bytes PyProtoSerializePartialToString(pybind11::handle py_proto,
56+
bool raise_if_error);
5057

5158
// Allocates a C++ protocol buffer for a given name.
5259
std::unique_ptr<::google::protobuf::Message> AllocateCProtoFromPythonSymbolDatabase(
5360
pybind11::handle src, const std::string &full_name);
5461

5562
// Serialize the py_proto and deserialize it into the provided message.
5663
// Caller should enforce any type identity that is required.
57-
bool PyProtoCopyToCProto(pybind11::handle py_proto, ::google::protobuf::Message *message);
5864
void CProtoCopyToPyProto(::google::protobuf::Message *message, pybind11::handle py_proto);
5965

6066
// Returns a handle to a python protobuf suitably

pybind11_protobuf/proto_caster_impl.h

+18-8
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,19 @@ struct proto_caster_load_impl {
6363
}
6464
}
6565

66-
// The incoming object is not a compatible fast_cpp_proto, so check whether
67-
// it is otherwise compatible, then serialize it and deserialize into a
68-
// native C++ proto type.
69-
if (!pybind11_protobuf::PyProtoIsCompatible(src,
70-
ProtoType::GetDescriptor())) {
66+
if (!PyProtoHasMatchingFullName(src, ProtoType::GetDescriptor())) {
7167
return false;
7268
}
69+
pybind11::bytes serialized_bytes =
70+
PyProtoSerializePartialToString(src, convert);
71+
if (!serialized_bytes) {
72+
return false;
73+
}
74+
7375
owned = std::unique_ptr<ProtoType>(new ProtoType());
7476
value = owned.get();
75-
return pybind11_protobuf::PyProtoCopyToCProto(src, owned.get());
77+
return owned.get()->ParsePartialFromString(
78+
PyBytesAsStringView(serialized_bytes));
7679
}
7780

7881
// ensure_owned ensures that the owned member contains a copy of the
@@ -108,16 +111,23 @@ struct proto_caster_load_impl<::google::protobuf::Message> {
108111

109112
// `src` is not a C++ proto instance from the generated_pool,
110113
// so create a compatible native C++ proto.
111-
auto descriptor_name = pybind11_protobuf::PyProtoDescriptorName(src);
114+
auto descriptor_name = pybind11_protobuf::PyProtoDescriptorFullName(src);
112115
if (!descriptor_name) {
113116
return false;
114117
}
118+
pybind11::bytes serialized_bytes =
119+
PyProtoSerializePartialToString(src, convert);
120+
if (!serialized_bytes) {
121+
return false;
122+
}
123+
115124
owned.reset(static_cast<ProtoType *>(
116125
pybind11_protobuf::AllocateCProtoFromPythonSymbolDatabase(
117126
src, *descriptor_name)
118127
.release()));
119128
value = owned.get();
120-
return pybind11_protobuf::PyProtoCopyToCProto(src, owned.get());
129+
return owned.get()->ParsePartialFromString(
130+
PyBytesAsStringView(serialized_bytes));
121131
}
122132

123133
// ensure_owned ensures that the owned member contains a copy of the

pybind11_protobuf/tests/pass_by_test.py

+28
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,34 @@ def test_pass_fake2(self, check_method):
170170
def test_overload_fn(self, message_fn, expected):
171171
self.assertEqual(expected, m.fn_overload(message_fn()))
172172

173+
def test_bad_serialize_partial_function_calls(self):
174+
class FakeDescr:
175+
full_name = 'fake_full_name'
176+
177+
class FakeProto:
178+
DESCRIPTOR = FakeDescr()
179+
180+
def __init__(self, serialize_fn_return_value=None):
181+
self.serialize_fn_return_value = serialize_fn_return_value
182+
183+
def SerializePartialToString(self): # pylint: disable=invalid-name
184+
if self.serialize_fn_return_value is None:
185+
raise RuntimeError('Broken serialize_fn.')
186+
return self.serialize_fn_return_value
187+
188+
with self.assertRaisesRegex(
189+
TypeError, r'\.SerializePartialToString\(\) function call FAILED$'
190+
):
191+
m.fn_overload(FakeProto())
192+
with self.assertRaisesRegex(
193+
TypeError,
194+
r'\.SerializePartialToString\(\) function call is expected to return'
195+
r' bytes, but the returned value is \[\]$',
196+
):
197+
m.fn_overload(FakeProto([]))
198+
with self.assertRaisesRegex(TypeError, r' object is not a valid protobuf$'):
199+
m.fn_overload(FakeProto(b''))
200+
173201

174202
if __name__ == '__main__':
175203
absltest.main()

0 commit comments

Comments
 (0)