Skip to content

Commit 76b6844

Browse files
committed
Always raise subclasses of InvalidHandshake.
Also shorten messages for InvalidHeader exceptions.
1 parent b42a65c commit 76b6844

File tree

8 files changed

+29
-41
lines changed

8 files changed

+29
-41
lines changed

src/websockets/client.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,7 @@ def process_response(self, response: Response) -> None:
176176
except KeyError as exc:
177177
raise InvalidHeader("Sec-WebSocket-Accept") from exc
178178
except MultipleValuesError as exc:
179-
raise InvalidHeader(
180-
"Sec-WebSocket-Accept",
181-
"more than one Sec-WebSocket-Accept header found",
182-
) from exc
179+
raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc
183180

184181
if s_w_accept != accept_key(self.key):
185182
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
@@ -225,7 +222,7 @@ def process_extensions(self, headers: Headers) -> list[Extension]:
225222

226223
if extensions:
227224
if self.available_extensions is None:
228-
raise InvalidHandshake("no extensions supported")
225+
raise NegotiationError("no extensions supported")
229226

230227
parsed_extensions: list[ExtensionHeader] = sum(
231228
[parse_extension(header_value) for header_value in extensions], []
@@ -280,15 +277,17 @@ def process_subprotocol(self, headers: Headers) -> Subprotocol | None:
280277

281278
if subprotocols:
282279
if self.available_subprotocols is None:
283-
raise InvalidHandshake("no subprotocols supported")
280+
raise NegotiationError("no subprotocols supported")
284281

285282
parsed_subprotocols: Sequence[Subprotocol] = sum(
286283
[parse_subprotocol(header_value) for header_value in subprotocols], []
287284
)
288285

289286
if len(parsed_subprotocols) > 1:
290-
subprotocols_display = ", ".join(parsed_subprotocols)
291-
raise InvalidHandshake(f"multiple subprotocols: {subprotocols_display}")
287+
raise InvalidHeader(
288+
"Sec-WebSocket-Protocol",
289+
f"multiple values: {', '.join(parsed_subprotocols)}",
290+
)
292291

293292
subprotocol = parsed_subprotocols[0]
294293

src/websockets/datastructures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
class MultipleValuesError(LookupError):
1919
"""
20-
Exception raised when :class:`Headers` has more than one value for a key.
20+
Exception raised when :class:`Headers` has multiple values for a key.
2121
2222
"""
2323

src/websockets/legacy/client.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
from ..asyncio.compatibility import asyncio_timeout
2020
from ..datastructures import Headers, HeadersLike
2121
from ..exceptions import (
22-
InvalidHandshake,
2322
InvalidHeader,
23+
InvalidHeaderValue,
2424
NegotiationError,
25-
2625
SecurityError,
2726
)
2827
from ..extensions import ClientExtensionFactory, Extension
@@ -181,7 +180,7 @@ def process_extensions(
181180

182181
if header_values:
183182
if available_extensions is None:
184-
raise InvalidHandshake("no extensions supported")
183+
raise NegotiationError("no extensions supported")
185184

186185
parsed_header_values: list[ExtensionHeader] = sum(
187186
[parse_extension(header_value) for header_value in header_values], []
@@ -235,15 +234,17 @@ def process_subprotocol(
235234

236235
if header_values:
237236
if available_subprotocols is None:
238-
raise InvalidHandshake("no subprotocols supported")
237+
raise NegotiationError("no subprotocols supported")
239238

240239
parsed_header_values: Sequence[Subprotocol] = sum(
241240
[parse_subprotocol(header_value) for header_value in header_values], []
242241
)
243242

244243
if len(parsed_header_values) > 1:
245-
subprotocols = ", ".join(parsed_header_values)
246-
raise InvalidHandshake(f"multiple subprotocols: {subprotocols}")
244+
raise InvalidHeaderValue(
245+
"Sec-WebSocket-Protocol",
246+
f"multiple values: {', '.join(parsed_header_values)}",
247+
)
247248

248249
subprotocol = parsed_header_values[0]
249250

src/websockets/legacy/handshake.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ def check_request(headers: Headers) -> str:
7676
except KeyError as exc:
7777
raise InvalidHeader("Sec-WebSocket-Key") from exc
7878
except MultipleValuesError as exc:
79-
raise InvalidHeader(
80-
"Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found"
81-
) from exc
79+
raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc
8280

8381
try:
8482
raw_key = base64.b64decode(s_w_key.encode(), validate=True)
@@ -92,9 +90,7 @@ def check_request(headers: Headers) -> str:
9290
except KeyError as exc:
9391
raise InvalidHeader("Sec-WebSocket-Version") from exc
9492
except MultipleValuesError as exc:
95-
raise InvalidHeader(
96-
"Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found"
97-
) from exc
93+
raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
9894

9995
if s_w_version != "13":
10096
raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version)
@@ -156,9 +152,7 @@ def check_response(headers: Headers, key: str) -> None:
156152
except KeyError as exc:
157153
raise InvalidHeader("Sec-WebSocket-Accept") from exc
158154
except MultipleValuesError as exc:
159-
raise InvalidHeader(
160-
"Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found"
161-
) from exc
155+
raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc
162156

163157
if s_w_accept != accept(key):
164158
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)

src/websockets/legacy/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def process_origin(
397397
try:
398398
origin = headers.get("Origin")
399399
except MultipleValuesError as exc:
400-
raise InvalidHeader("Origin", "more than one Origin header found") from exc
400+
raise InvalidHeader("Origin", "multiple values") from exc
401401
if origin is not None:
402402
origin = cast(Origin, origin)
403403
if origins is not None:

src/websockets/server.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,7 @@ def process_request(
257257
except KeyError as exc:
258258
raise InvalidHeader("Sec-WebSocket-Key") from exc
259259
except MultipleValuesError as exc:
260-
raise InvalidHeader(
261-
"Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found"
262-
) from exc
260+
raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc
263261

264262
try:
265263
raw_key = base64.b64decode(key.encode(), validate=True)
@@ -273,10 +271,7 @@ def process_request(
273271
except KeyError as exc:
274272
raise InvalidHeader("Sec-WebSocket-Version") from exc
275273
except MultipleValuesError as exc:
276-
raise InvalidHeader(
277-
"Sec-WebSocket-Version",
278-
"more than one Sec-WebSocket-Version header found",
279-
) from exc
274+
raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
280275

281276
if version != "13":
282277
raise InvalidHeaderValue("Sec-WebSocket-Version", version)
@@ -315,7 +310,7 @@ def process_origin(self, headers: Headers) -> Origin | None:
315310
try:
316311
origin = headers.get("Origin")
317312
except MultipleValuesError as exc:
318-
raise InvalidHeader("Origin", "more than one Origin header found") from exc
313+
raise InvalidHeader("Origin", "multiple values") from exc
319314
if origin is not None:
320315
origin = cast(Origin, origin)
321316
if self.origins is not None:

tests/test_client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,7 @@ def test_multiple_accept(self):
342342
raise client.handshake_exc
343343
self.assertEqual(
344344
str(raised.exception),
345-
"invalid Sec-WebSocket-Accept header: "
346-
"more than one Sec-WebSocket-Accept header found",
345+
"invalid Sec-WebSocket-Accept header: multiple values",
347346
)
348347

349348
def test_invalid_accept(self):
@@ -556,7 +555,9 @@ def test_multiple_subprotocols(self):
556555
with self.assertRaises(InvalidHandshake) as raised:
557556
raise client.handshake_exc
558557
self.assertEqual(
559-
str(raised.exception), "multiple subprotocols: superchat, chat"
558+
str(raised.exception),
559+
"invalid Sec-WebSocket-Protocol header: "
560+
"multiple values: superchat, chat",
560561
)
561562

562563
def test_supported_subprotocol(self):

tests/test_server.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,7 @@ def test_multiple_key(self):
305305
raise server.handshake_exc
306306
self.assertEqual(
307307
str(raised.exception),
308-
"invalid Sec-WebSocket-Key header: "
309-
"more than one Sec-WebSocket-Key header found",
308+
"invalid Sec-WebSocket-Key header: multiple values",
310309
)
311310

312311
def test_invalid_key(self):
@@ -366,8 +365,7 @@ def test_multiple_version(self):
366365
raise server.handshake_exc
367366
self.assertEqual(
368367
str(raised.exception),
369-
"invalid Sec-WebSocket-Version header: "
370-
"more than one Sec-WebSocket-Version header found",
368+
"invalid Sec-WebSocket-Version header: multiple values",
371369
)
372370

373371
def test_invalid_version(self):
@@ -437,7 +435,7 @@ def test_multiple_origin(self):
437435
raise server.handshake_exc
438436
self.assertEqual(
439437
str(raised.exception),
440-
"invalid Origin header: more than one Origin header found",
438+
"invalid Origin header: multiple values",
441439
)
442440

443441
def test_supported_origin(self):

0 commit comments

Comments
 (0)