Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions python/google/protobuf/internal/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,20 @@ def sort(self, *args, **kwargs) -> None:
if 'sort_function' in kwargs:
kwargs['cmp'] = kwargs.pop('sort_function')
self._values.sort(*args, **kwargs)
if not self._message_listener.dirty:
self._message_listener.Modified()

def reverse(self) -> None:
self._AssureWritable()
self._values.reverse()
if not self._message_listener.dirty:
self._message_listener.Modified()

def clear(self) -> None:
self._AssureWritable()
self._values.clear()
if not self._message_listener.dirty:
self._message_listener.Modified()


# TODO: Remove this. BaseContainer does *not* conform to
Expand Down
89 changes: 74 additions & 15 deletions python/google/protobuf/internal/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,16 +531,16 @@ def testAssignRepeatedField(self, message_module):
self.assertEqual([1, 2, 3, 4], msg.payload.repeated_int32)

def testRepeatedFieldSelfSliceAssignment(self, message_module):
msg = message_module.NestedTestAllTypes()
msg.payload.repeated_int32[:] = [1, 2, 3, 4]
msg.payload.repeated_int32[:] = msg.payload.repeated_int32
self.assertEqual([1, 2, 3, 4], msg.payload.repeated_int32)
msg = message_module.NestedTestAllTypes()
msg.payload.repeated_int32[:] = [1, 2, 3, 4]
msg.payload.repeated_int32[:] = msg.payload.repeated_int32
self.assertEqual([1, 2, 3, 4], msg.payload.repeated_int32)

def testRepeatedFieldSelfExtend(self, message_module):
msg = message_module.NestedTestAllTypes()
msg.payload.repeated_int32[:] = [1, 2, 3, 4]
msg.payload.repeated_int32.extend(msg.payload.repeated_int32)
self.assertEqual([1, 2, 3, 4] * 2, msg.payload.repeated_int32)
msg = message_module.NestedTestAllTypes()
msg.payload.repeated_int32[:] = [1, 2, 3, 4]
msg.payload.repeated_int32.extend(msg.payload.repeated_int32)
self.assertEqual([1, 2, 3, 4] * 2, msg.payload.repeated_int32)

def testAssignOutOfRange(self, message_module):
msg = message_module.NestedTestAllTypes()
Expand Down Expand Up @@ -781,13 +781,72 @@ def testRepeatedNestedFieldIteration(self, message_module):
)

def testSortEmptyRepeated(self, message_module):
message = message_module.NestedTestAllTypes()
self.assertFalse(message.HasField('child'))
self.assertFalse(message.HasField('payload'))
message.child.repeated_child.sort()
message.payload.repeated_int32.sort()
self.assertFalse(message.HasField('child'))
self.assertFalse(message.HasField('payload'))
if api_implementation.Type() == 'python':
return
msg = message_module.NestedTestAllTypes()
self.assertFalse(msg.HasField('child'))
self.assertFalse(msg.HasField('payload'))
msg.child.repeated_child.sort()
msg.payload.repeated_int32.sort()
self.assertTrue(msg.HasField('child'))
self.assertTrue(msg.HasField('payload'))

def testReverseEmptyRepeated(self, message_module):
if api_implementation.Type() == 'python':
return
msg = message_module.NestedTestAllTypes()
self.assertFalse(msg.HasField('child'))
self.assertFalse(msg.HasField('payload'))
msg.child.repeated_child.reverse()
msg.payload.repeated_int32.reverse()
self.assertTrue(msg.HasField('child'))
self.assertTrue(msg.HasField('payload'))

def testClearEmptyRepeated(self, message_module):
if api_implementation.Type() == 'python':
return
msg = message_module.NestedTestAllTypes()
self.assertFalse(msg.HasField('child'))
self.assertFalse(msg.HasField('payload'))
msg.child.repeated_child.clear()
msg.payload.repeated_int32.clear()
self.assertTrue(msg.HasField('child'))
self.assertTrue(msg.HasField('payload'))

def testDelEmptyRepeated(self, message_module):
if api_implementation.Type() == 'python':
return
msg = message_module.NestedTestAllTypes()
self.assertFalse(msg.HasField('child'))
self.assertFalse(msg.HasField('payload'))
del msg.child.repeated_child[:]
del msg.payload.repeated_int32[:]
self.assertTrue(msg.HasField('child'))
self.assertTrue(msg.HasField('payload'))

def testImmutabilityCheckComesFirstEmptyRepeated(self, message_module):
if api_implementation.Type() == 'python':
return
msg = message_module.TestAllTypes()
options = msg.DESCRIPTOR.GetOptions()
self.assertEqual(0, len(options.uninterpreted_option))

with self.assertRaises(
(TypeError, AttributeError, message.FrozenInstanceError)
):
options.uninterpreted_option.sort()
with self.assertRaises(
(TypeError, AttributeError, message.FrozenInstanceError)
):
options.uninterpreted_option.reverse()
with self.assertRaises(
(TypeError, AttributeError, message.FrozenInstanceError)
):
options.uninterpreted_option.clear()
with self.assertRaises(
(TypeError, AttributeError, message.FrozenInstanceError)
):
del options.uninterpreted_option[:]

def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
"""Check some different types with the default comparator."""
Expand Down
44 changes: 12 additions & 32 deletions python/google/protobuf/pyext/repeated_composite_container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,6 @@ int AssignSubscript(RepeatedCompositeContainer* self, PyObject* slice,
return -1;
}

// TODO: b/517235198 - Reify even for empty sequences.
int status = cmessage::CheckRepeatedFieldDeletion(
self->parent, self->parent_field_descriptor, slice);
if (status < 0) return -1;
if (status > 0) return 0;

if (cmessage::AssureWritable(self->parent) == nullptr) return -1;

return cmessage::DeleteRepeatedField(self->parent,
Expand Down Expand Up @@ -361,13 +355,12 @@ static PyObject* ToStr(PyObject* pself) {
// ---------------------------------------------------------------------
// sort()

static void ReorderAttached(RepeatedCompositeContainer* self,
static bool ReorderAttached(RepeatedCompositeContainer* self,
PyObject* child_list) {
const Py_ssize_t length = Length(reinterpret_cast<PyObject*>(self));
if (length == 0) return;

Message* message = cmessage::AssureWritable(self->parent);
if (message == nullptr) return;
if (message == nullptr) return false;
const Py_ssize_t length = Length(reinterpret_cast<PyObject*>(self));
if (length == 0) return true;
const Reflection* reflection = message->GetReflection();
const FieldDescriptor* descriptor = self->parent_field_descriptor;

Expand All @@ -379,10 +372,11 @@ static void ReorderAttached(RepeatedCompositeContainer* self,
CMessage* child_cmsg =
reinterpret_cast<CMessage*>(PyList_GET_ITEM(child_list, i));
Message* child_message = cmessage::AssureWritable(child_cmsg);
if (child_message == nullptr) return;
if (child_message == nullptr) return false;
reflection->UnsafeArenaAddAllocatedMessage(message, descriptor,
child_message);
}
return true;
}

// Returns 0 if successful; returns -1 and sets an exception if
Expand All @@ -398,17 +392,16 @@ static int SortPythonMessages(RepeatedCompositeContainer* self, PyObject* args,
if (m == nullptr) return -1;
if (ScopedPyObjectPtr(PyObject_Call(m.get(), args, kwds)) == nullptr)
return -1;
ReorderAttached(self, child_list.get());
if (!ReorderAttached(self, child_list.get())) {
return -1;
}
return 0;
}

static PyObject* Sort(PyObject* pself, PyObject* args, PyObject* kwds) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);

if (self->parent->state == python::MESSAGE_FROZEN) {
return SetContainerFrozenError();
}

// Support the old sort_function argument for backwards
// compatibility.
Expand All @@ -422,10 +415,6 @@ static PyObject* Sort(PyObject* pself, PyObject* args, PyObject* kwds) {
}
}

// TODO: b/517235198 - Reify even for empty sequences.
if (Length(pself) == 0) {
Py_RETURN_NONE;
}

if (SortPythonMessages(self, args, kwds) < 0) {
return nullptr;
Expand All @@ -447,22 +436,17 @@ static int ReversePythonMessages(RepeatedCompositeContainer* self) {
if (ScopedPyObjectPtr(
PyObject_CallMethod(child_list.get(), "reverse", nullptr)) == nullptr)
return -1;
ReorderAttached(self, child_list.get());
if (!ReorderAttached(self, child_list.get())) {
return -1;
}
return 0;
}

static PyObject* Reverse(PyObject* pself) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);

if (self->parent->state == python::MESSAGE_FROZEN) {
return SetContainerFrozenError();
}

// TODO: b/517235198 - Reify even for empty sequences.
if (Length(pself) == 0) {
Py_RETURN_NONE;
}

if (ReversePythonMessages(self) < 0) {
return nullptr;
Expand All @@ -474,10 +458,6 @@ static PyObject* Reverse(PyObject* pself) {
static PyObject* Clear(PyObject* pself) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);
// TODO: b/517235198 - Reify even for empty sequences.
if (Length(pself) == 0) {
Py_RETURN_NONE;
}

CMessage* cmessage = self->parent;
Message* message = cmessage::AssureWritable(cmessage);
Expand Down
17 changes: 0 additions & 17 deletions python/google/protobuf/pyext/repeated_scalar_container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1081,9 +1081,6 @@ static PyObject* Sort(PyObject* pself, PyObject* args, PyObject* kwds) {
RepeatedScalarContainer* self =
reinterpret_cast<RepeatedScalarContainer*>(pself);

if (self->parent->state == python::MESSAGE_FROZEN) {
return SetContainerFrozenError();
}

// Support the old sort_function argument for backwards
// compatibility.
Expand All @@ -1105,9 +1102,6 @@ static PyObject* Sort(PyObject* pself, PyObject* args, PyObject* kwds) {
if (list == nullptr) {
return nullptr;
}
if (PyList_GET_SIZE(list.get()) == 0) {
Py_RETURN_NONE;
}
ScopedPyObjectPtr m(PyObject_GetAttrString(list.get(), "sort"));
if (m == nullptr) {
return nullptr;
Expand All @@ -1128,14 +1122,7 @@ static PyObject* Reverse(PyObject* pself) {
RepeatedScalarContainer* self =
reinterpret_cast<RepeatedScalarContainer*>(pself);

if (self->parent->state == python::MESSAGE_FROZEN) {
return SetContainerFrozenError();
}

// TODO: b/517235198 - Reify even for empty sequences.
if (Len(pself) == 0) {
Py_RETURN_NONE;
}

ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr));
if (full_slice == nullptr) {
Expand All @@ -1161,10 +1148,6 @@ static PyObject* Clear(PyObject* pself) {
RepeatedScalarContainer* self =
reinterpret_cast<RepeatedScalarContainer*>(pself);

// TODO: b/517235198 - Reify even for empty sequences.
if (Len(pself) == 0) {
Py_RETURN_NONE;
}

CMessage* cmessage = self->parent;
Message* message = cmessage::AssureWritable(cmessage);
Expand Down
7 changes: 0 additions & 7 deletions python/repeated.c
Original file line number Diff line number Diff line change
Expand Up @@ -869,12 +869,7 @@ static PyObject* PyUpb_RepeatedContainer_Sort(PyObject* pself, PyObject* args,
}
}

if (PyUpb_RepeatedContainer_IsFrozen((PyUpb_RepeatedContainer*)pself)) {
return PyUpb_SetFrozenErrorWithMsg("Container is immutable");
}

// TODO:b/517235198 - Reify even for empty sequences.
if (PyUpb_RepeatedContainer_Length(pself) == 0) Py_RETURN_NONE;

upb_Array* arr = PyUpb_RepeatedContainer_AssureWritable(pself);
if (!arr) return NULL;
Expand Down Expand Up @@ -918,8 +913,6 @@ static PyObject* PyUpb_RepeatedContainer_Reverse(PyObject* _self) {

static PyObject* PyUpb_RepeatedContainer_Clear(PyObject* _self) {
Py_ssize_t size = PyUpb_RepeatedContainer_Length(_self);
// TODO: b/517235198 - Reify even for empty sequences.
if (size == 0) Py_RETURN_NONE;

PyUpb_RepeatedContainer* self = (PyUpb_RepeatedContainer*)_self;
upb_Array* arr = PyUpb_RepeatedContainer_AssureWritable(_self);
Expand Down
Loading