Skip to content

Commit 9d77f94

Browse files
committed
Support checking structs for equality.
Also: - Rename _fields to _fieldnames, because the history of this CL shows that _fields is confusing - Fix some lint errors. Change-Id: I23d0191e6a588820b3697b2d36d70880ea921d8a
1 parent 5879932 commit 9d77f94

File tree

2 files changed

+81
-10
lines changed

2 files changed

+81
-10
lines changed

tests/net_test/cstruct.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def CalcNumElements(fmt):
5757
return len(elements)
5858

5959

60-
def Struct(name, fmt, fields, substructs={}):
60+
def Struct(name, fmt, fieldnames, substructs={}):
6161
"""Function that returns struct classes."""
6262

6363
class Meta(type):
@@ -77,12 +77,12 @@ class CStruct(object):
7777
# Name of the struct.
7878
_name = name
7979
# List of field names.
80-
_fields = fields
80+
_fieldnames = fieldnames
8181
# Dict mapping field indices to nested struct classes.
8282
_nested = {}
8383

84-
if isinstance(_fields, str):
85-
_fields = _fields.split(" ")
84+
if isinstance(_fieldnames, str):
85+
_fieldnames = _fieldnames.split(" ")
8686

8787
# Parse fmt into _format, converting any S format characters to "XXs",
8888
# where XX is the length of the struct type's packed representation.
@@ -121,14 +121,14 @@ def __init__(self, values):
121121
self._Parse(values)
122122
else:
123123
# Initializing from a tuple.
124-
if len(values) != len(self._fields):
125-
raise TypeError("%s has exactly %d fields (%d given)" %
126-
(self._name, len(self._fields), len(values)))
124+
if len(values) != len(self._fieldnames):
125+
raise TypeError("%s has exactly %d fieldnames (%d given)" %
126+
(self._name, len(self._fieldnames), len(values)))
127127
self._SetValues(values)
128128

129129
def _FieldIndex(self, attr):
130130
try:
131-
return self._fields.index(attr)
131+
return self._fieldnames.index(attr)
132132
except ValueError:
133133
raise AttributeError("'%s' has no attribute '%s'" %
134134
(self._name, attr))
@@ -143,6 +143,15 @@ def __setattr__(self, name, value):
143143
def __len__(cls):
144144
return cls._length
145145

146+
def __ne__(self, other):
147+
return not self.__eq__(other)
148+
149+
def __eq__(self, other):
150+
return (isinstance(other, self.__class__) and
151+
self._name == other._name and
152+
self._fieldnames == other._fieldnames and
153+
self._values == other._values)
154+
146155
@staticmethod
147156
def _MaybePackStruct(value):
148157
if hasattr(value, "__metaclass__"):# and value.__metaclass__ == Meta:
@@ -156,12 +165,14 @@ def Pack(self):
156165

157166
def __str__(self):
158167
def FieldDesc(index, name, value):
159-
if isinstance(value, str) and any (c not in string.printable for c in value):
168+
if isinstance(value, str) and any(
169+
c not in string.printable for c in value):
160170
value = value.encode("hex")
161171
return "%s=%s" % (name, value)
162172

163173
descriptions = [
164-
FieldDesc(i, n, v) for i, (n, v) in enumerate(zip(self._fields, self._values))]
174+
FieldDesc(i, n, v) for i, (n, v) in
175+
enumerate(zip(self._fieldnames, self._values))]
165176

166177
return "%s(%s)" % (self._name, ", ".join(descriptions))
167178

tests/net_test/cstruct_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#!/usr/bin/python
2+
#
3+
# Copyright 2016 The Android Open Source Project
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import unittest
18+
19+
import cstruct
20+
21+
22+
# These aren't constants, they're classes. So, pylint: disable=invalid-name
23+
TestStructA = cstruct.Struct("TestStructA", "=BI", "byte1 int2")
24+
TestStructB = cstruct.Struct("TestStructB", "=BI", "byte1 int2")
25+
26+
27+
class CstructTest(unittest.TestCase):
28+
29+
def CheckEquals(self, a, b):
30+
self.assertEquals(a, b)
31+
self.assertEquals(b, a)
32+
assert a == b
33+
assert b == a
34+
assert not (a != b) # pylint: disable=g-comparison-negation,superfluous-parens
35+
assert not (b != a) # pylint: disable=g-comparison-negation,superfluous-parens
36+
37+
def CheckNotEquals(self, a, b):
38+
self.assertNotEquals(a, b)
39+
self.assertNotEquals(b, a)
40+
assert a != b
41+
assert b != a
42+
assert not (a == b) # pylint: disable=g-comparison-negation,superfluous-parens
43+
assert not (b == a) # pylint: disable=g-comparison-negation,superfluous-parens
44+
45+
def testEqAndNe(self):
46+
a1 = TestStructA((1, 2))
47+
a2 = TestStructA((2, 3))
48+
a3 = TestStructA((1, 2))
49+
b = TestStructB((1, 2))
50+
self.CheckNotEquals(a1, b)
51+
self.CheckNotEquals(a2, b)
52+
self.CheckNotEquals(a1, a2)
53+
self.CheckNotEquals(a2, a3)
54+
for i in [a1, a2, a3, b]:
55+
self.CheckEquals(i, i)
56+
self.CheckEquals(a1, a3)
57+
58+
59+
if __name__ == "__main__":
60+
unittest.main()

0 commit comments

Comments
 (0)