Skip to content

Commit 5879932

Browse files
lcolittiGerrit Code Review
authored and
Gerrit Code Review
committed
Merge "Add code and tests for inet_diag bytecode."
2 parents 83ced5b + 093d6d4 commit 5879932

File tree

3 files changed

+165
-11
lines changed

3 files changed

+165
-11
lines changed

tests/net_test/netlink.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,12 @@ def _NlAttrU32(self, nla_type, value):
7979
def _GetConstantName(self, module, value, prefix):
8080
thismodule = sys.modules[module]
8181
for name in dir(thismodule):
82+
if name.startswith("INET_DIAG_BC"):
83+
break
8284
if (name.startswith(prefix) and
8385
not name.startswith(prefix + "F_") and
84-
name.isupper() and
85-
getattr(thismodule, name) == value):
86-
return name
86+
name.isupper() and getattr(thismodule, name) == value):
87+
return name
8788
return value
8889

8990
def _Decode(self, command, msg, nla_type, nla_data):

tests/net_test/sock_diag.py

+109-5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import errno
2222
from socket import * # pylint: disable=wildcard-import
23+
import struct
2324

2425
import cstruct
2526
import net_test
@@ -37,6 +38,9 @@
3738
# Message types.
3839
TCPDIAG_GETSOCK = 18
3940

41+
# Request attributes.
42+
INET_DIAG_REQ_BYTECODE = 1
43+
4044
# Extensions.
4145
INET_DIAG_NONE = 0
4246
INET_DIAG_MEMINFO = 1
@@ -49,6 +53,17 @@
4953
INET_DIAG_SHUTDOWN = 8
5054
INET_DIAG_DCTCPINFO = 9
5155

56+
# Bytecode operations.
57+
INET_DIAG_BC_NOP = 0
58+
INET_DIAG_BC_JMP = 1
59+
INET_DIAG_BC_S_GE = 2
60+
INET_DIAG_BC_S_LE = 3
61+
INET_DIAG_BC_D_GE = 4
62+
INET_DIAG_BC_D_LE = 5
63+
INET_DIAG_BC_AUTO = 6
64+
INET_DIAG_BC_S_COND = 7
65+
INET_DIAG_BC_D_COND = 8
66+
5267
# Data structure formats.
5368
# These aren't constants, they're classes. So, pylint: disable=invalid-name
5469
InetDiagSockId = cstruct.Struct(
@@ -62,6 +77,9 @@
6277
[InetDiagSockId])
6378
InetDiagMeminfo = cstruct.Struct(
6479
"InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
80+
InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
81+
InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
82+
"family prefix_len port")
6583

6684
SkMeminfo = cstruct.Struct(
6785
"SkMeminfo", "=IIIIIIII",
@@ -133,22 +151,108 @@ def MaybeDebugCommand(self, command, data):
133151
def _EmptyInetDiagSockId():
134152
return InetDiagSockId(("\x00" * len(InetDiagSockId)))
135153

136-
def Dump(self, diag_req):
137-
out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, "")
154+
def PackBytecode(self, instructions):
155+
"""Compiles instructions to inet_diag bytecode.
156+
157+
The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
158+
and no are relative jump offsets measured in instructions. The yes branch
159+
is taken if the instruction matches.
160+
161+
To accept, jump 1 past the last instruction. To reject, jump 2 past the
162+
last instruction.
163+
164+
The target of a no jump is only valid if it is reachable by following
165+
only yes jumps from the first instruction - see inet_diag_bc_audit and
166+
valid_cc. This means that if cond1 and cond2 are two mutually exclusive
167+
filter terms, it is not possible to implement cond1 OR cond2 using:
168+
169+
...
170+
cond1 2 1 arg
171+
cond2 1 2 arg
172+
accept
173+
reject
174+
175+
but only using:
176+
177+
...
178+
cond1 1 2 arg
179+
jmp 1 2
180+
cond2 1 2 arg
181+
accept
182+
reject
183+
184+
The jmp instruction ignores yes and always jumps to no, but yes must be 1
185+
or the bytecode won't validate. It doesn't have to be jmp - any instruction
186+
that is guaranteed not to match on real data will do.
187+
188+
Args:
189+
instructions: list of instruction tuples
190+
191+
Returns:
192+
A string, the raw bytecode.
193+
"""
194+
args = []
195+
positions = [0]
196+
197+
for op, yes, no, arg in instructions:
198+
199+
if yes <= 0 or no <= 0:
200+
raise ValueError("Jumps must be > 0")
201+
202+
if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
203+
arg = ""
204+
elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
205+
INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
206+
arg = "\x00\x00" + struct.pack("=H", arg)
207+
elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
208+
addr, prefixlen, port = arg
209+
family = AF_INET6 if ":" in addr else AF_INET
210+
addr = inet_pton(family, addr)
211+
arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr
212+
else:
213+
raise ValueError("Unsupported opcode %d" % op)
214+
215+
args.append(arg)
216+
length = len(InetDiagBcOp) + len(arg)
217+
positions.append(positions[-1] + length)
218+
219+
# Reject label.
220+
positions.append(positions[-1] + 4) # Why 4? Because the kernel uses 4.
221+
assert len(args) == len(instructions) == len(positions) - 2
222+
223+
# print positions
224+
225+
packed = ""
226+
for i, (op, yes, no, arg) in enumerate(instructions):
227+
yes = positions[i + yes] - positions[i]
228+
no = positions[i + no] - positions[i]
229+
instruction = InetDiagBcOp((op, yes, no)).Pack() + args[i]
230+
#print "%3d: %d %3d %3d %s %s" % (positions[i], op, yes, no,
231+
# arg, instruction.encode("hex"))
232+
packed += instruction
233+
#print
234+
235+
return packed
236+
237+
def Dump(self, diag_req, bytecode=""):
238+
out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)
138239
return out
139240

140-
def DumpAllInetSockets(self, protocol, sock_id=None, ext=0,
241+
def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
141242
states=ALL_NON_TIME_WAIT):
142243
"""Dumps IPv4 or IPv6 sockets matching the specified parameters."""
143244
# DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
144245
# results in ENOENT.
145246
if sock_id is None:
146247
sock_id = self._EmptyInetDiagSockId()
147248

249+
if bytecode:
250+
bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)
251+
148252
sockets = []
149253
for family in [AF_INET, AF_INET6]:
150254
diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
151-
sockets += self.Dump(diag_req)
255+
sockets += self.Dump(diag_req, bytecode)
152256

153257
return sockets
154258

@@ -255,6 +359,6 @@ def CloseSocketFromFd(self, s):
255359
sock_id.dport = 443
256360
ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
257361
states = 0xffffffff
258-
diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP,
362+
diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, "",
259363
sock_id=sock_id, ext=ext, states=states)
260364
print diag_msgs

tests/net_test/sock_diag_test.py

+52-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232

3333
NUM_SOCKETS = 100
34+
NO_BYTECODE = ""
3435

3536
# TODO: Backport SOCK_DESTROY and delete this.
3637
HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4)
@@ -115,7 +116,7 @@ def assertSockDiagMatchesSocket(self, s, diag_msg):
115116

116117
def testFindsAllMySockets(self):
117118
self.socketpairs = self._CreateLotsOfSockets()
118-
sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP)
119+
sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
119120
self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
120121

121122
# Find the cookies for all of our sockets.
@@ -149,6 +150,54 @@ def testFindsAllMySockets(self):
149150
diag_msg = self.sock_diag.GetSockDiag(req)
150151
self.assertSockDiagMatchesSocket(sock, diag_msg)
151152

153+
def testBytecodeCompilation(self):
154+
instructions = [
155+
(sock_diag.INET_DIAG_BC_S_GE, 1, 8, 0), # 0
156+
(sock_diag.INET_DIAG_BC_D_LE, 1, 7, 0xffff), # 8
157+
(sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)), # 16
158+
(sock_diag.INET_DIAG_BC_JMP, 1, 3, None), # 44
159+
(sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)), # 48
160+
(sock_diag.INET_DIAG_BC_D_LE, 1, 3, 0x6665), # not used # 64
161+
(sock_diag.INET_DIAG_BC_NOP, 1, 1, None), # 72
162+
# 76 acc
163+
# 80 rej
164+
]
165+
bytecode = self.sock_diag.PackBytecode(instructions)
166+
expected = (
167+
"0208500000000000"
168+
"050848000000ffff"
169+
"071c20000a800000ffffffff00000000000000000000000000000001"
170+
"01041c00"
171+
"0718200002200000ffffffff7f000001"
172+
"0508100000006566"
173+
"00040400"
174+
)
175+
self.assertMultiLineEqual(expected, bytecode.encode("hex"))
176+
self.assertEquals(76, len(bytecode))
177+
self.socketpairs = self._CreateLotsOfSockets()
178+
filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
179+
allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
180+
self.assertEquals(len(allsockets), len(filteredsockets))
181+
182+
# Pick a few sockets in hash table order, and check that the bytecode we
183+
# compiled selects them properly.
184+
for socketpair in self.socketpairs.values()[:20]:
185+
for s in socketpair:
186+
diag_msg = self.sock_diag.FindSockDiagFromFd(s)
187+
instructions = [
188+
(sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
189+
(sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
190+
(sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
191+
(sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
192+
]
193+
bytecode = self.sock_diag.PackBytecode(instructions)
194+
self.assertEquals(32, len(bytecode))
195+
sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
196+
self.assertEquals(1, len(sockets))
197+
198+
# TODO: why doesn't comparing the cstructs work?
199+
self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack())
200+
152201
@unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
153202
def testClosesSockets(self):
154203
self.socketpairs = self._CreateLotsOfSockets()
@@ -356,7 +405,7 @@ def FindChildSockets(self, s):
356405
req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
357406
req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
358407
req.id.cookie = "\x00" * 8
359-
children = self.sock_diag.Dump(req)
408+
children = self.sock_diag.Dump(req, NO_BYTECODE)
360409
return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
361410
for d, _ in children]
362411

@@ -486,7 +535,7 @@ def testIpv4MappedSynRecvSocket(self):
486535
sock_id.sport = self.port
487536
states = 1 << sock_diag.TCP_SYN_RECV
488537
req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
489-
children = self.sock_diag.Dump(req)
538+
children = self.sock_diag.Dump(req, NO_BYTECODE)
490539

491540
self.assertTrue(children)
492541
for child, unused_args in children:

0 commit comments

Comments
 (0)