Skip to content

Commit

Permalink
Run the precommit linters against the code to make them operate as ex…
Browse files Browse the repository at this point in the history
…pected
  • Loading branch information
ian-otto committed Mar 21, 2020
1 parent 0b130d2 commit e50fe20
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 211 deletions.
258 changes: 52 additions & 206 deletions cheroot/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,14 @@
from six.moves import urllib

from . import connections, errors, __version__
from ._compat import bton, ntou
from ._compat import bton
from ._compat import IS_PPC
from .workers import threadpool
from .makefile import MakeFile, StreamWriter

__all__ = (
'HTTPRequest', 'HTTPConnection', 'HTTPServer',
'HeaderReader', 'DropUnderscoreHeaderReader',
'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile',
'KnownLengthRFile', 'ChunkedRFile',
'Gateway', 'get_ssl_adapter_class',
)

Expand Down Expand Up @@ -166,174 +165,6 @@
logging.statistics = {}


class HeaderReader:
"""Object for reading headers from an HTTP request.
Interface and default implementation.
"""

def __call__(self, rfile, hdict=None):
"""
Read headers from the given stream into the given header dict.
If hdict is None, a new header dict is created. Returns the populated
header dict.
Headers which are repeated are folded together using a comma if their
specification so dictates.
This function raises ValueError when the read bytes violate the HTTP
spec.
You should probably return "400 Bad Request" if this happens.
"""
if hdict is None:
hdict = {}

while True:
line = rfile.readline()
if not line:
# No more data--illegal end of headers
raise ValueError('Illegal end of headers.')

if line == CRLF:
# Normal end of headers
break
if not line.endswith(CRLF):
raise ValueError('HTTP requires CRLF terminators')

if line[0] in (SPACE, TAB):
# It's a continuation line.
v = line.strip()
else:
try:
k, v = line.split(COLON, 1)
except ValueError:
raise ValueError('Illegal header line.')
v = v.strip()
k = self._transform_key(k)
hname = k

if not self._allow_header(k):
continue

if k in comma_separated_headers:
existing = hdict.get(hname)
if existing:
v = b', '.join((existing, v))
hdict[hname] = v

return hdict

def _allow_header(self, key_name):
return True

def _transform_key(self, key_name):
# TODO: what about TE and WWW-Authenticate?
return key_name.strip().title()


class DropUnderscoreHeaderReader(HeaderReader):
"""Custom HeaderReader to exclude any headers with underscores in them."""

def _allow_header(self, key_name):
orig = super(DropUnderscoreHeaderReader, self)._allow_header(key_name)
return orig and '_' not in key_name


class SizeCheckWrapper:
"""Wraps a file-like object, raising MaxSizeExceeded if too large.
:param rfile: ``file`` of a limited size
:param int maxlen: maximum length of the file being read
"""

def __init__(self, rfile, maxlen):
"""Initialize SizeCheckWrapper instance."""
self.rfile = rfile
self.maxlen = maxlen
self.bytes_read = 0

def _check_length(self):
if self.maxlen and self.bytes_read > self.maxlen:
raise errors.MaxSizeExceeded()

def read(self, size=None):
"""Read a chunk from ``rfile`` buffer and return it.
:param int size: amount of data to read
:returns: chunk from ``rfile``, limited by size if specified
:rtype: bytes
"""
data = self.rfile.read(size)
self.bytes_read += len(data)
self._check_length()
return data

def readline(self, size=None):
"""Read a single line from ``rfile`` buffer and return it.
:param int size: minimum amount of data to read
:returns: one line from ``rfile``
:rtype: bytes
"""
if size is not None:
data = self.rfile.readline(size)
self.bytes_read += len(data)
self._check_length()
return data

# User didn't specify a size ...
# We read the line in chunks to make sure it's not a 100MB line !
res = []
while True:
data = self.rfile.readline(256)
self.bytes_read += len(data)
self._check_length()
res.append(data)
# See https://github.com/cherrypy/cherrypy/issues/421
if len(data) < 256 or data[-1:] == LF:
return EMPTY.join(res)

def readlines(self, sizehint=0):
"""Read all lines from ``rfile`` buffer and return them.
:param int sizehint: hint of minimum amount of data to read
:returns: lines of bytes read from ``rfile``
:rtype: list[bytes]
"""
# Shamelessly stolen from StringIO
total = 0
lines = []
line = self.readline(sizehint)
while line:
lines.append(line)
total += len(line)
if 0 < sizehint <= total:
break
line = self.readline(sizehint)
return lines

def close(self):
"""Release resources allocated for ``rfile``."""
self.rfile.close()

def __iter__(self):
"""Return file iterator."""
return self

def __next__(self):
"""Generate next file chunk."""
data = next(self.rfile)
self.bytes_read += len(data)
self._check_length()
return data

next = __next__


class RFile(ABC):
def __init__(self, rfile, wfile, conn):
self.rfile = rfile
Expand All @@ -358,7 +189,7 @@ def _handle_input(self, data):
if isinstance(evt, h11.Data):
return bytes(evt.data)
else:
return b""
return b''

@abstractmethod
def read(self, size=None):
Expand Down Expand Up @@ -704,11 +535,6 @@ class HTTPRequest:
This value is set automatically inside send_headers."""

header_reader = HeaderReader()
"""
A HeaderReader instance or compatible reader.
"""

def __init__(self, server, conn, proxy_mode=False, strict_mode=True):
"""Initialize HTTP request container instance.
Expand Down Expand Up @@ -745,7 +571,7 @@ def __init__(self, server, conn, proxy_mode=False, strict_mode=True):
self.strict_mode = strict_mode

def _process_next_h11_event(self, read_new=True):
"""Instruct h11 to process data in its buffer and return any events it has available
"""Instruct h11 to process data in its buffer and return any events it has available.
:param read_new: If the method should attempt to read new lines to find an event
:type read_new: bool
Expand All @@ -771,20 +597,22 @@ def parse_request(self):
if isinstance(req_line, h11.Request):
self.started_request = True
self.uri = req_line.target
scheme, authority, path, query, fragment = urllib.parse.urlsplit(self.uri)
self.qs = query
scheme, authority, path, qs, fragment = urllib.parse.urlsplit(self.uri) # noqa
self.qs = qs
self.method = req_line.method
self.request_protocol = b"HTTP/%s" % req_line.http_version
self.request_protocol = b'HTTP/%s' % req_line.http_version
if req_line.http_version > b'1.1':
self.simple_response(505, "Cannot fulfill request")
self.response_protocol = b"HTTP/1.1"
self.simple_response(505, 'Cannot fulfill request')
self.response_protocol = b'HTTP/1.1'
# TODO: oneliner-ify this
self.inheaders = {}
for header in req_line.headers:
self.inheaders[header[0]] = header[1]

if (b'transfer-encoding' in self.inheaders and
self.inheaders[b'transfer-encoding'].lower() == b"chunked"):
if (
b'transfer-encoding' in self.inheaders and
self.inheaders[b'transfer-encoding'].lower() == b'chunked'
):
self.chunked_read = True

uri_is_absolute_form = scheme or authority
Expand All @@ -803,8 +631,9 @@ def parse_request(self):
self.simple_response(405)
return False

# `urlsplit()` above parses "example.com:3128" as path part of URI.
# this is a workaround, which makes it detect netloc correctly
# `urlsplit()` above parses "example.com:3128" as path part
# of URI. This is a workaround, which makes it detect
# netloc correctly
uri_split = urllib.parse.urlsplit(b''.join((b'//', self.uri)))
_scheme, _authority, _path, _qs, _fragment = uri_split
_port = EMPTY
Expand Down Expand Up @@ -844,7 +673,8 @@ def parse_request(self):
"""Absolute URI is only allowed within proxies."""
self.simple_response(
400,
'Absolute URI not allowed if server is not a proxy.',
'Absolute URI not allowed'
' if server is not a proxy.',
)
return False

Expand Down Expand Up @@ -876,7 +706,10 @@ def parse_request(self):
return False
path = QUOTED_SLASH.join(atoms)
if fragment:
self.simple_response(400, "Illegal #fragment in Request-URI.")
self.simple_response(
400,
'Illegal #fragment in Request-URI.',
)

if not path.startswith(FORWARD_SLASH):
path = FORWARD_SLASH + path
Expand All @@ -887,10 +720,8 @@ def parse_request(self):
self.close_connection = True
else:
# TODO
raise NotImplementedError("Only expecting Request object here")
raise NotImplementedError('Only expecting Request object here')
except h11.RemoteProtocolError as e:
# NEED INFO: Should we adjust the tests to match h11's exception text
# Or should we continue to shim like this:
err_map = {
'bad Content-Length': 'Malformed Content-Length Header.',
'illegal request line': 'Malformed Request-Line.',
Expand All @@ -904,7 +735,10 @@ def respond(self):
"""Call the gateway and write its iterable output."""
mrbs = self.server.max_request_body_size
if self.chunked_read:
self.rfile = ChunkedRFile(self.h_conn, self.conn.rfile, self.conn.wfile, mrbs)
self.rfile = ChunkedRFile(
self.h_conn,
self.conn.rfile, self.conn.wfile, mrbs,
)
else:
cl = int(self.inheaders.get(b'content-length', 0))
if mrbs and mrbs < cl:
Expand All @@ -915,29 +749,36 @@ def respond(self):
'maximum allowed bytes.',
)
return
self.rfile = KnownLengthRFile(self.h_conn, self.conn.rfile, self.conn.wfile, cl)
self.rfile = KnownLengthRFile(
self.h_conn,
self.conn.rfile, self.conn.wfile, cl,
)

# client may still be in send body, lets find out and figure out what to do with that
# if self.h_conn.their_state == h11.SEND_BODY:
if self.h_conn.client_is_waiting_for_100_continue:
mini_headers = ()
go_ahead = h11.InformationalResponse(status_code=100, headers=mini_headers)
go_ahead = h11.InformationalResponse(
status_code=100,
headers=mini_headers,
)
bytes_out = self.h_conn.send(go_ahead)
self.conn.wfile.write(bytes_out)
try:
self.server.gateway(self).respond()
self.ready and self.ensure_headers_sent()
except errors.MaxSizeExceeded as e:
self.simple_response(413, "Request Entity Too Large")
self.simple_response(413, 'Request Entity Too Large')
self.close_connection = True
return

while self.h_conn.their_state is h11.SEND_BODY and self.h_conn.our_state is not h11.ERROR:
# empty their buffer ?
while (
self.h_conn.their_state is h11.SEND_BODY
and self.h_conn.our_state is not h11.ERROR
):
data = self.rfile.read()
self._process_next_h11_event(read_new=False)
if data == EMPTY and self.h_conn.their_state is h11.SEND_BODY:
# they didn't send a full body, kill connection, set our state to ERROR
# they didn't send a full body, kill connection,
# set our state to ERROR
self.h_conn.send_failed()

# If we haven't sent our end-of-message data, send it now
Expand All @@ -946,7 +787,10 @@ def respond(self):
self.conn.wfile.write(bytes_out)

# prep for next req cycle if it's available
if self.h_conn.our_state is h11.DONE and self.h_conn.their_state is h11.DONE:
if (
self.h_conn.our_state is h11.DONE
and self.h_conn.their_state is h11.DONE
):
self.h_conn.start_next_cycle()
self.close_connection = False
else:
Expand All @@ -959,14 +803,14 @@ def simple_response(self, status, msg=None):
msg = ''
status = str(status)
headers = [
('Content-Type', 'text/plain')
('Content-Type', 'text/plain'),
]

self.outheaders = headers
self.status = status
self.send_headers()
if msg:
self.write(bytes(msg, encoding="ascii"))
self.write(bytes(msg, encoding='ascii'))

evt = h11.EndOfMessage()
bytes_out = self.h_conn.send(evt)
Expand Down Expand Up @@ -1005,8 +849,10 @@ def send_headers(self):
self.server.server_name.encode('ISO-8859-1'),
))

res = h11.Response(status_code=status, headers=self.outheaders,
http_version=self.response_protocol[5:], reason=self.status[3:])
res = h11.Response(
status_code=status, headers=self.outheaders,
http_version=self.response_protocol[5:], reason=self.status[3:],
)
res_bytes = self.h_conn.send(res)
self.conn.wfile.write(res_bytes)

Expand Down
Loading

0 comments on commit e50fe20

Please sign in to comment.