Skip to content

Commit f1b30da

Browse files
lcolittiGerrit Code Review
authored and
Gerrit Code Review
committed
Merge changes I251088dc,I23d0191e
* changes: Test for a cross-family bytecode comparison bug. Support checking structs for equality.
2 parents 5879932 + 0d9dd0d commit f1b30da

File tree

3 files changed

+116
-10
lines changed

3 files changed

+116
-10
lines changed

tests/net_test/cstruct.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def CalcNumElements(fmt):
5757
return len(elements)
5858

5959

60-
def Struct(name, fmt, fields, substructs={}):
60+
def Struct(name, fmt, fieldnames, substructs={}):
6161
"""Function that returns struct classes."""
6262

6363
class Meta(type):
@@ -77,12 +77,12 @@ class CStruct(object):
7777
# Name of the struct.
7878
_name = name
7979
# List of field names.
80-
_fields = fields
80+
_fieldnames = fieldnames
8181
# Dict mapping field indices to nested struct classes.
8282
_nested = {}
8383

84-
if isinstance(_fields, str):
85-
_fields = _fields.split(" ")
84+
if isinstance(_fieldnames, str):
85+
_fieldnames = _fieldnames.split(" ")
8686

8787
# Parse fmt into _format, converting any S format characters to "XXs",
8888
# where XX is the length of the struct type's packed representation.
@@ -121,14 +121,14 @@ def __init__(self, values):
121121
self._Parse(values)
122122
else:
123123
# Initializing from a tuple.
124-
if len(values) != len(self._fields):
125-
raise TypeError("%s has exactly %d fields (%d given)" %
126-
(self._name, len(self._fields), len(values)))
124+
if len(values) != len(self._fieldnames):
125+
raise TypeError("%s has exactly %d fieldnames (%d given)" %
126+
(self._name, len(self._fieldnames), len(values)))
127127
self._SetValues(values)
128128

129129
def _FieldIndex(self, attr):
130130
try:
131-
return self._fields.index(attr)
131+
return self._fieldnames.index(attr)
132132
except ValueError:
133133
raise AttributeError("'%s' has no attribute '%s'" %
134134
(self._name, attr))
@@ -143,6 +143,15 @@ def __setattr__(self, name, value):
143143
def __len__(cls):
144144
return cls._length
145145

146+
def __ne__(self, other):
147+
return not self.__eq__(other)
148+
149+
def __eq__(self, other):
150+
return (isinstance(other, self.__class__) and
151+
self._name == other._name and
152+
self._fieldnames == other._fieldnames and
153+
self._values == other._values)
154+
146155
@staticmethod
147156
def _MaybePackStruct(value):
148157
if hasattr(value, "__metaclass__"):# and value.__metaclass__ == Meta:
@@ -156,12 +165,14 @@ def Pack(self):
156165

157166
def __str__(self):
158167
def FieldDesc(index, name, value):
159-
if isinstance(value, str) and any (c not in string.printable for c in value):
168+
if isinstance(value, str) and any(
169+
c not in string.printable for c in value):
160170
value = value.encode("hex")
161171
return "%s=%s" % (name, value)
162172

163173
descriptions = [
164-
FieldDesc(i, n, v) for i, (n, v) in enumerate(zip(self._fields, self._values))]
174+
FieldDesc(i, n, v) for i, (n, v) in
175+
enumerate(zip(self._fieldnames, self._values))]
165176

166177
return "%s(%s)" % (self._name, ", ".join(descriptions))
167178

tests/net_test/cstruct_test.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#!/usr/bin/python
2+
#
3+
# Copyright 2016 The Android Open Source Project
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import unittest
18+
19+
import cstruct
20+
21+
22+
# These aren't constants, they're classes. So, pylint: disable=invalid-name
23+
TestStructA = cstruct.Struct("TestStructA", "=BI", "byte1 int2")
24+
TestStructB = cstruct.Struct("TestStructB", "=BI", "byte1 int2")
25+
26+
27+
class CstructTest(unittest.TestCase):
28+
29+
def CheckEquals(self, a, b):
30+
self.assertEquals(a, b)
31+
self.assertEquals(b, a)
32+
assert a == b
33+
assert b == a
34+
assert not (a != b) # pylint: disable=g-comparison-negation,superfluous-parens
35+
assert not (b != a) # pylint: disable=g-comparison-negation,superfluous-parens
36+
37+
def CheckNotEquals(self, a, b):
38+
self.assertNotEquals(a, b)
39+
self.assertNotEquals(b, a)
40+
assert a != b
41+
assert b != a
42+
assert not (a == b) # pylint: disable=g-comparison-negation,superfluous-parens
43+
assert not (b == a) # pylint: disable=g-comparison-negation,superfluous-parens
44+
45+
def testEqAndNe(self):
46+
a1 = TestStructA((1, 2))
47+
a2 = TestStructA((2, 3))
48+
a3 = TestStructA((1, 2))
49+
b = TestStructB((1, 2))
50+
self.CheckNotEquals(a1, b)
51+
self.CheckNotEquals(a2, b)
52+
self.CheckNotEquals(a1, a2)
53+
self.CheckNotEquals(a2, a3)
54+
for i in [a1, a2, a3, b]:
55+
self.CheckEquals(i, i)
56+
self.CheckEquals(a1, a3)
57+
58+
59+
if __name__ == "__main__":
60+
unittest.main()

tests/net_test/sock_diag_test.py

+35
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,41 @@ def testBytecodeCompilation(self):
198198
# TODO: why doesn't comparing the cstructs work?
199199
self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack())
200200

201+
def testCrossFamilyBytecode(self):
202+
"""Checks for a cross-family bug in inet_diag_hostcond matching.
203+
204+
Relevant kernel commits:
205+
android-3.4:
206+
f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
207+
"""
208+
pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
209+
pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")
210+
211+
bytecode4 = self.sock_diag.PackBytecode([
212+
(sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
213+
bytecode6 = self.sock_diag.PackBytecode([
214+
(sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])
215+
216+
# IPv4/v6 filters must never match IPv6/IPv4 sockets...
217+
v4sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4)
218+
self.assertTrue(v4sockets)
219+
self.assertTrue(all(d.family == AF_INET for d, _ in v4sockets))
220+
221+
v6sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6)
222+
self.assertTrue(v6sockets)
223+
self.assertTrue(all(d.family == AF_INET6 for d, _ in v6sockets))
224+
225+
# Except for mapped addresses, which match both IPv4 and IPv6.
226+
pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
227+
"::ffff:127.0.0.1")
228+
diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
229+
v4sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
230+
bytecode4)]
231+
v6sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
232+
bytecode6)]
233+
self.assertTrue(all(d in v4sockets for d in diag_msgs))
234+
self.assertTrue(all(d in v6sockets for d in diag_msgs))
235+
201236
@unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
202237
def testClosesSockets(self):
203238
self.socketpairs = self._CreateLotsOfSockets()

0 commit comments

Comments
 (0)