Skip to content

Commit 77bdf03

Browse files
committed
Simplify untangling of import cycles
Three imports at the bottom, related to type annotations, is a lesser evil compared to dozens of module-prefixed identifiers, a departure from the coding style of this library. Refs #989.
1 parent 9fcd8fd commit 77bdf03

File tree

7 files changed

+74
-78
lines changed

7 files changed

+74
-78
lines changed

src/websockets/exceptions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import typing
3434
import warnings
3535

36-
from . import frames, http11
3736
from .imports import lazy_import
3837

3938

@@ -376,3 +375,6 @@ class InvalidState(WebSocketException, AssertionError):
376375
"WebSocketProtocolError": ".legacy.exceptions",
377376
},
378377
)
378+
379+
# At the bottom to break import cycles created by type annotations.
380+
from . import frames, http11 # noqa: E402

src/websockets/extensions/base.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Sequence
44

5-
from .. import frames
5+
from ..frames import Frame
66
from ..typing import ExtensionName, ExtensionParameter
77

88

@@ -18,12 +18,7 @@ class Extension:
1818
name: ExtensionName
1919
"""Extension identifier."""
2020

21-
def decode(
22-
self,
23-
frame: frames.Frame,
24-
*,
25-
max_size: int | None = None,
26-
) -> frames.Frame:
21+
def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
2722
"""
2823
Decode an incoming frame.
2924
@@ -40,7 +35,7 @@ def decode(
4035
"""
4136
raise NotImplementedError
4237

43-
def encode(self, frame: frames.Frame) -> frames.Frame:
38+
def encode(self, frame: Frame) -> Frame:
4439
"""
4540
Encode an outgoing frame.
4641

src/websockets/extensions/permessage_deflate.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,15 @@
44
import zlib
55
from typing import Any, Sequence
66

7-
from .. import exceptions, frames
7+
from .. import frames
8+
from ..exceptions import (
9+
DuplicateParameter,
10+
InvalidParameterName,
11+
InvalidParameterValue,
12+
NegotiationError,
13+
PayloadTooBig,
14+
ProtocolError,
15+
)
816
from ..typing import ExtensionName, ExtensionParameter
917
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory
1018

@@ -129,9 +137,9 @@ def decode(
129137
try:
130138
data = self.decoder.decompress(data, max_length)
131139
except zlib.error as exc:
132-
raise exceptions.ProtocolError("decompression failed") from exc
140+
raise ProtocolError("decompression failed") from exc
133141
if self.decoder.unconsumed_tail:
134-
raise exceptions.PayloadTooBig(f"over size limit (? > {max_size} bytes)")
142+
raise PayloadTooBig(f"over size limit (? > {max_size} bytes)")
135143

136144
# Allow garbage collection of the decoder if it won't be reused.
137145
if frame.fin and self.remote_no_context_takeover:
@@ -215,40 +223,40 @@ def _extract_parameters(
215223
for name, value in params:
216224
if name == "server_no_context_takeover":
217225
if server_no_context_takeover:
218-
raise exceptions.DuplicateParameter(name)
226+
raise DuplicateParameter(name)
219227
if value is None:
220228
server_no_context_takeover = True
221229
else:
222-
raise exceptions.InvalidParameterValue(name, value)
230+
raise InvalidParameterValue(name, value)
223231

224232
elif name == "client_no_context_takeover":
225233
if client_no_context_takeover:
226-
raise exceptions.DuplicateParameter(name)
234+
raise DuplicateParameter(name)
227235
if value is None:
228236
client_no_context_takeover = True
229237
else:
230-
raise exceptions.InvalidParameterValue(name, value)
238+
raise InvalidParameterValue(name, value)
231239

232240
elif name == "server_max_window_bits":
233241
if server_max_window_bits is not None:
234-
raise exceptions.DuplicateParameter(name)
242+
raise DuplicateParameter(name)
235243
if value in _MAX_WINDOW_BITS_VALUES:
236244
server_max_window_bits = int(value)
237245
else:
238-
raise exceptions.InvalidParameterValue(name, value)
246+
raise InvalidParameterValue(name, value)
239247

240248
elif name == "client_max_window_bits":
241249
if client_max_window_bits is not None:
242-
raise exceptions.DuplicateParameter(name)
250+
raise DuplicateParameter(name)
243251
if is_server and value is None: # only in handshake requests
244252
client_max_window_bits = True
245253
elif value in _MAX_WINDOW_BITS_VALUES:
246254
client_max_window_bits = int(value)
247255
else:
248-
raise exceptions.InvalidParameterValue(name, value)
256+
raise InvalidParameterValue(name, value)
249257

250258
else:
251-
raise exceptions.InvalidParameterName(name)
259+
raise InvalidParameterName(name)
252260

253261
return (
254262
server_no_context_takeover,
@@ -340,7 +348,7 @@ def process_response_params(
340348
341349
"""
342350
if any(other.name == self.name for other in accepted_extensions):
343-
raise exceptions.NegotiationError(f"received duplicate {self.name}")
351+
raise NegotiationError(f"received duplicate {self.name}")
344352

345353
# Request parameters are available in instance variables.
346354

@@ -366,7 +374,7 @@ def process_response_params(
366374

367375
if self.server_no_context_takeover:
368376
if not server_no_context_takeover:
369-
raise exceptions.NegotiationError("expected server_no_context_takeover")
377+
raise NegotiationError("expected server_no_context_takeover")
370378

371379
# client_no_context_takeover
372380
#
@@ -396,9 +404,9 @@ def process_response_params(
396404

397405
else:
398406
if server_max_window_bits is None:
399-
raise exceptions.NegotiationError("expected server_max_window_bits")
407+
raise NegotiationError("expected server_max_window_bits")
400408
elif server_max_window_bits > self.server_max_window_bits:
401-
raise exceptions.NegotiationError("unsupported server_max_window_bits")
409+
raise NegotiationError("unsupported server_max_window_bits")
402410

403411
# client_max_window_bits
404412

@@ -414,7 +422,7 @@ def process_response_params(
414422

415423
if self.client_max_window_bits is None:
416424
if client_max_window_bits is not None:
417-
raise exceptions.NegotiationError("unexpected client_max_window_bits")
425+
raise NegotiationError("unexpected client_max_window_bits")
418426

419427
elif self.client_max_window_bits is True:
420428
pass
@@ -423,7 +431,7 @@ def process_response_params(
423431
if client_max_window_bits is None:
424432
client_max_window_bits = self.client_max_window_bits
425433
elif client_max_window_bits > self.client_max_window_bits:
426-
raise exceptions.NegotiationError("unsupported client_max_window_bits")
434+
raise NegotiationError("unsupported client_max_window_bits")
427435

428436
return PerMessageDeflate(
429437
server_no_context_takeover, # remote_no_context_takeover
@@ -534,7 +542,7 @@ def process_request_params(
534542
535543
"""
536544
if any(other.name == self.name for other in accepted_extensions):
537-
raise exceptions.NegotiationError(f"skipped duplicate {self.name}")
545+
raise NegotiationError(f"skipped duplicate {self.name}")
538546

539547
# Load request parameters in local variables.
540548
(
@@ -613,7 +621,7 @@ def process_request_params(
613621
else:
614622
if client_max_window_bits is None:
615623
if self.require_client_max_window_bits:
616-
raise exceptions.NegotiationError("required client_max_window_bits")
624+
raise NegotiationError("required client_max_window_bits")
617625
elif client_max_window_bits is True:
618626
client_max_window_bits = self.client_max_window_bits
619627
elif self.client_max_window_bits < client_max_window_bits:

src/websockets/frames.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import struct
99
from typing import Callable, Generator, Sequence
1010

11-
from . import exceptions, extensions
11+
from .exceptions import PayloadTooBig, ProtocolError
1212

1313

1414
try:
@@ -239,10 +239,10 @@ def parse(
239239
try:
240240
opcode = Opcode(head1 & 0b00001111)
241241
except ValueError as exc:
242-
raise exceptions.ProtocolError("invalid opcode") from exc
242+
raise ProtocolError("invalid opcode") from exc
243243

244244
if (True if head2 & 0b10000000 else False) != mask:
245-
raise exceptions.ProtocolError("incorrect masking")
245+
raise ProtocolError("incorrect masking")
246246

247247
length = head2 & 0b01111111
248248
if length == 126:
@@ -252,9 +252,7 @@ def parse(
252252
data = yield from read_exact(8)
253253
(length,) = struct.unpack("!Q", data)
254254
if max_size is not None and length > max_size:
255-
raise exceptions.PayloadTooBig(
256-
f"over size limit ({length} > {max_size} bytes)"
257-
)
255+
raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)")
258256
if mask:
259257
mask_bytes = yield from read_exact(4)
260258

@@ -342,13 +340,13 @@ def check(self) -> None:
342340
343341
"""
344342
if self.rsv1 or self.rsv2 or self.rsv3:
345-
raise exceptions.ProtocolError("reserved bits must be 0")
343+
raise ProtocolError("reserved bits must be 0")
346344

347345
if self.opcode in CTRL_OPCODES:
348346
if len(self.data) > 125:
349-
raise exceptions.ProtocolError("control frame too long")
347+
raise ProtocolError("control frame too long")
350348
if not self.fin:
351-
raise exceptions.ProtocolError("fragmented control frame")
349+
raise ProtocolError("fragmented control frame")
352350

353351

354352
@dataclasses.dataclass
@@ -405,7 +403,7 @@ def parse(cls, data: bytes) -> Close:
405403
elif len(data) == 0:
406404
return cls(CloseCode.NO_STATUS_RCVD, "")
407405
else:
408-
raise exceptions.ProtocolError("close frame too short")
406+
raise ProtocolError("close frame too short")
409407

410408
def serialize(self) -> bytes:
411409
"""
@@ -424,4 +422,8 @@ def check(self) -> None:
424422
425423
"""
426424
if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000):
427-
raise exceptions.ProtocolError("invalid status code")
425+
raise ProtocolError("invalid status code")
426+
427+
428+
# At the bottom to break import cycles created by type annotations.
429+
from . import extensions # noqa: E402

src/websockets/headers.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
from typing import Callable, Sequence, TypeVar, cast
88

9-
from . import exceptions
9+
from .exceptions import InvalidHeaderFormat, InvalidHeaderValue
1010
from .typing import (
1111
ConnectionOption,
1212
ExtensionHeader,
@@ -108,7 +108,7 @@ def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]:
108108
"""
109109
match = _token_re.match(header, pos)
110110
if match is None:
111-
raise exceptions.InvalidHeaderFormat(header_name, "expected token", header, pos)
111+
raise InvalidHeaderFormat(header_name, "expected token", header, pos)
112112
return match.group(), match.end()
113113

114114

@@ -132,9 +132,7 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, i
132132
"""
133133
match = _quoted_string_re.match(header, pos)
134134
if match is None:
135-
raise exceptions.InvalidHeaderFormat(
136-
header_name, "expected quoted string", header, pos
137-
)
135+
raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos)
138136
return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end()
139137

140138

@@ -206,9 +204,7 @@ def parse_list(
206204
if peek_ahead(header, pos) == ",":
207205
pos = parse_OWS(header, pos + 1)
208206
else:
209-
raise exceptions.InvalidHeaderFormat(
210-
header_name, "expected comma", header, pos
211-
)
207+
raise InvalidHeaderFormat(header_name, "expected comma", header, pos)
212208

213209
# Remove extra delimiters before the next item.
214210
while peek_ahead(header, pos) == ",":
@@ -276,9 +272,7 @@ def parse_upgrade_protocol(
276272
"""
277273
match = _protocol_re.match(header, pos)
278274
if match is None:
279-
raise exceptions.InvalidHeaderFormat(
280-
header_name, "expected protocol", header, pos
281-
)
275+
raise InvalidHeaderFormat(header_name, "expected protocol", header, pos)
282276
return cast(UpgradeProtocol, match.group()), match.end()
283277

284278

@@ -324,7 +318,7 @@ def parse_extension_item_param(
324318
# the value after quoted-string unescaping MUST conform to
325319
# the 'token' ABNF.
326320
if _token_re.fullmatch(value) is None:
327-
raise exceptions.InvalidHeaderFormat(
321+
raise InvalidHeaderFormat(
328322
header_name, "invalid quoted header content", header, pos_before
329323
)
330324
else:
@@ -510,9 +504,7 @@ def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]:
510504
"""
511505
match = _token68_re.match(header, pos)
512506
if match is None:
513-
raise exceptions.InvalidHeaderFormat(
514-
header_name, "expected token68", header, pos
515-
)
507+
raise InvalidHeaderFormat(header_name, "expected token68", header, pos)
516508
return match.group(), match.end()
517509

518510

@@ -522,7 +514,7 @@ def parse_end(header: str, pos: int, header_name: str) -> None:
522514
523515
"""
524516
if pos < len(header):
525-
raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos)
517+
raise InvalidHeaderFormat(header_name, "trailing data", header, pos)
526518

527519

528520
def parse_authorization_basic(header: str) -> tuple[str, str]:
@@ -543,12 +535,12 @@ def parse_authorization_basic(header: str) -> tuple[str, str]:
543535
# https://datatracker.ietf.org/doc/html/rfc7617#section-2
544536
scheme, pos = parse_token(header, 0, "Authorization")
545537
if scheme.lower() != "basic":
546-
raise exceptions.InvalidHeaderValue(
538+
raise InvalidHeaderValue(
547539
"Authorization",
548540
f"unsupported scheme: {scheme}",
549541
)
550542
if peek_ahead(header, pos) != " ":
551-
raise exceptions.InvalidHeaderFormat(
543+
raise InvalidHeaderFormat(
552544
"Authorization", "expected space after scheme", header, pos
553545
)
554546
pos += 1
@@ -558,14 +550,14 @@ def parse_authorization_basic(header: str) -> tuple[str, str]:
558550
try:
559551
user_pass = base64.b64decode(basic_credentials.encode()).decode()
560552
except binascii.Error:
561-
raise exceptions.InvalidHeaderValue(
553+
raise InvalidHeaderValue(
562554
"Authorization",
563555
"expected base64-encoded credentials",
564556
) from None
565557
try:
566558
username, password = user_pass.split(":", 1)
567559
except ValueError:
568-
raise exceptions.InvalidHeaderValue(
560+
raise InvalidHeaderValue(
569561
"Authorization",
570562
"expected username:password credentials",
571563
) from None

0 commit comments

Comments
 (0)