Skip to content

Commit db20683

Browse files
committed
fix up type errors with stricter state machine
1 parent 748eef9 commit db20683

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

src/h2/events.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .settings import ChangedSetting, SettingCodes, Settings, _setting_code_from_int
1717

1818
if TYPE_CHECKING: # pragma: no cover
19-
from hpack import HeaderTuple
19+
from hpack.struct import Header
2020
from hyperframe.frame import Frame
2121

2222
from .errors import ErrorCodes
@@ -52,7 +52,7 @@ def __init__(self) -> None:
5252
self.stream_id: int | None = None
5353

5454
#: The request headers.
55-
self.headers: list[HeaderTuple] | None = None
55+
self.headers: list[Header] | None = None
5656

5757
#: If this request also ended the stream, the associated
5858
#: :class:`StreamEnded <h2.events.StreamEnded>` event will be available
@@ -91,7 +91,7 @@ def __init__(self) -> None:
9191
self.stream_id: int | None = None
9292

9393
#: The response headers.
94-
self.headers: list[HeaderTuple] | None = None
94+
self.headers: list[Header] | None = None
9595

9696
#: If this response also ended the stream, the associated
9797
#: :class:`StreamEnded <h2.events.StreamEnded>` event will be available
@@ -133,7 +133,7 @@ def __init__(self) -> None:
133133
self.stream_id: int | None = None
134134

135135
#: The trailers themselves.
136-
self.headers: list[HeaderTuple] | None = None
136+
self.headers: list[Header] | None = None
137137

138138
#: Trailers always end streams. This property has the associated
139139
#: :class:`StreamEnded <h2.events.StreamEnded>` in it.
@@ -237,7 +237,7 @@ def __init__(self) -> None:
237237
self.stream_id: int | None = None
238238

239239
#: The headers for this informational response.
240-
self.headers: list[HeaderTuple] | None = None
240+
self.headers: list[Header] | None = None
241241

242242
#: If this response also had associated priority information, the
243243
#: associated :class:`PriorityUpdated <h2.events.PriorityUpdated>`
@@ -460,7 +460,7 @@ def __init__(self) -> None:
460460
self.parent_stream_id: int | None = None
461461

462462
#: The request headers, sent by the remote party in the push.
463-
self.headers: list[HeaderTuple] | None = None
463+
self.headers: list[Header] | None = None
464464

465465
def __repr__(self) -> str:
466466
return (

src/h2/stream.py

+22-12
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from __future__ import annotations
88

99
from enum import Enum, IntEnum
10-
from typing import TYPE_CHECKING, Any
10+
from typing import TYPE_CHECKING, Any, cast
1111

1212
from hpack import HeaderTuple
1313
from hyperframe.frame import AltSvcFrame, ContinuationFrame, DataFrame, Frame, HeadersFrame, PushPromiseFrame, RstStreamFrame, WindowUpdateFrame
@@ -1046,10 +1046,11 @@ def receive_push_promise_in_band(self,
10461046
events = self.state_machine.process_input(
10471047
StreamInputs.RECV_PUSH_PROMISE,
10481048
)
1049-
events[0].pushed_stream_id = promised_stream_id
1049+
push_event = cast(PushedStreamReceived, events[0])
1050+
push_event.pushed_stream_id = promised_stream_id
10501051

10511052
hdr_validation_flags = self._build_hdr_validation_flags(events)
1052-
events[0].headers = self._process_received_headers(
1053+
push_event.headers = self._process_received_headers(
10531054
headers, hdr_validation_flags, header_encoding,
10541055
)
10551056
return [], events
@@ -1083,22 +1084,30 @@ def receive_headers(self,
10831084
input_ = StreamInputs.RECV_HEADERS
10841085

10851086
events = self.state_machine.process_input(input_)
1087+
headers_event = cast(
1088+
RequestReceived | ResponseReceived | TrailersReceived | InformationalResponseReceived,
1089+
events[0],
1090+
)
10861091

10871092
if end_stream:
10881093
es_events = self.state_machine.process_input(
10891094
StreamInputs.RECV_END_STREAM,
10901095
)
1091-
events[0].stream_ended = es_events[0]
1096+
# We ensured it's not an information response at the beginning of the method.
1097+
cast(
1098+
RequestReceived | ResponseReceived | TrailersReceived,
1099+
headers_event,
1100+
).stream_ended = cast(StreamEnded, es_events[0])
10921101
events += es_events
10931102

10941103
self._initialize_content_length(headers)
10951104

1096-
if isinstance(events[0], TrailersReceived) and not end_stream:
1105+
if isinstance(headers_event, TrailersReceived) and not end_stream:
10971106
msg = "Trailers must have END_STREAM set"
10981107
raise ProtocolError(msg)
10991108

11001109
hdr_validation_flags = self._build_hdr_validation_flags(events)
1101-
events[0].headers = self._process_received_headers(
1110+
headers_event.headers = self._process_received_headers(
11021111
headers, hdr_validation_flags, header_encoding,
11031112
)
11041113
return [], events
@@ -1112,18 +1121,19 @@ def receive_data(self, data: bytes, end_stream: bool, flow_control_len: int) ->
11121121
"set to %d", self, end_stream, flow_control_len,
11131122
)
11141123
events = self.state_machine.process_input(StreamInputs.RECV_DATA)
1124+
data_event = cast(DataReceived, events[0])
11151125
self._inbound_window_manager.window_consumed(flow_control_len)
11161126
self._track_content_length(len(data), end_stream)
11171127

11181128
if end_stream:
11191129
es_events = self.state_machine.process_input(
11201130
StreamInputs.RECV_END_STREAM,
11211131
)
1122-
events[0].stream_ended = es_events[0]
1132+
data_event.stream_ended = cast(StreamEnded, es_events[0])
11231133
events.extend(es_events)
11241134

1125-
events[0].data = data
1126-
events[0].flow_controlled_length = flow_control_len
1135+
data_event.data = data
1136+
data_event.flow_controlled_length = flow_control_len
11271137
return [], events
11281138

11291139
def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event]]:
@@ -1143,7 +1153,7 @@ def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event
11431153
# this should be treated as a *stream* error, not a *connection* error.
11441154
# That means we need to catch the error and forcibly close the stream.
11451155
if events:
1146-
events[0].delta = increment
1156+
cast(WindowUpdated, events[0]).delta = increment
11471157
try:
11481158
self.outbound_flow_control_window = guard_increment_window(
11491159
self.outbound_flow_control_window,
@@ -1226,7 +1236,7 @@ def stream_reset(self, frame: RstStreamFrame) -> tuple[list[Frame], list[Event]]
12261236

12271237
if events:
12281238
# We don't fire an event if this stream is already closed.
1229-
events[0].error_code = _error_code_from_int(frame.error_code)
1239+
cast(StreamReset, events[0]).error_code = _error_code_from_int(frame.error_code)
12301240

12311241
return [], events
12321242

@@ -1328,7 +1338,7 @@ def _build_headers_frames(self,
13281338
def _process_received_headers(self,
13291339
headers: Iterable[Header],
13301340
header_validation_flags: HeaderValidationFlags,
1331-
header_encoding: bool | str | None) -> Iterable[Header]:
1341+
header_encoding: bool | str | None) -> list[Header]:
13321342
"""
13331343
When headers have been received from the remote peer, run a processing
13341344
pipeline on them to transform them into the appropriate form for

0 commit comments

Comments
 (0)