7
7
from __future__ import annotations
8
8
9
9
from enum import Enum , IntEnum
10
- from typing import TYPE_CHECKING , Any
10
+ from typing import TYPE_CHECKING , Any , cast
11
11
12
12
from hpack import HeaderTuple
13
13
from hyperframe .frame import AltSvcFrame , ContinuationFrame , DataFrame , Frame , HeadersFrame , PushPromiseFrame , RstStreamFrame , WindowUpdateFrame
@@ -1046,10 +1046,11 @@ def receive_push_promise_in_band(self,
1046
1046
events = self .state_machine .process_input (
1047
1047
StreamInputs .RECV_PUSH_PROMISE ,
1048
1048
)
1049
- events [0 ].pushed_stream_id = promised_stream_id
1049
+ push_event = cast (PushedStreamReceived , events [0 ])
1050
+ push_event .pushed_stream_id = promised_stream_id
1050
1051
1051
1052
hdr_validation_flags = self ._build_hdr_validation_flags (events )
1052
- events [ 0 ] .headers = self ._process_received_headers (
1053
+ push_event .headers = self ._process_received_headers (
1053
1054
headers , hdr_validation_flags , header_encoding ,
1054
1055
)
1055
1056
return [], events
@@ -1083,22 +1084,30 @@ def receive_headers(self,
1083
1084
input_ = StreamInputs .RECV_HEADERS
1084
1085
1085
1086
events = self .state_machine .process_input (input_ )
1087
+ headers_event = cast (
1088
+ RequestReceived | ResponseReceived | TrailersReceived | InformationalResponseReceived ,
1089
+ events [0 ],
1090
+ )
1086
1091
1087
1092
if end_stream :
1088
1093
es_events = self .state_machine .process_input (
1089
1094
StreamInputs .RECV_END_STREAM ,
1090
1095
)
1091
- events [0 ].stream_ended = es_events [0 ]
1096
+ # We ensured it's not an information response at the beginning of the method.
1097
+ cast (
1098
+ RequestReceived | ResponseReceived | TrailersReceived ,
1099
+ headers_event ,
1100
+ ).stream_ended = cast (StreamEnded , es_events [0 ])
1092
1101
events += es_events
1093
1102
1094
1103
self ._initialize_content_length (headers )
1095
1104
1096
- if isinstance (events [ 0 ] , TrailersReceived ) and not end_stream :
1105
+ if isinstance (headers_event , TrailersReceived ) and not end_stream :
1097
1106
msg = "Trailers must have END_STREAM set"
1098
1107
raise ProtocolError (msg )
1099
1108
1100
1109
hdr_validation_flags = self ._build_hdr_validation_flags (events )
1101
- events [ 0 ] .headers = self ._process_received_headers (
1110
+ headers_event .headers = self ._process_received_headers (
1102
1111
headers , hdr_validation_flags , header_encoding ,
1103
1112
)
1104
1113
return [], events
@@ -1112,18 +1121,19 @@ def receive_data(self, data: bytes, end_stream: bool, flow_control_len: int) ->
1112
1121
"set to %d" , self , end_stream , flow_control_len ,
1113
1122
)
1114
1123
events = self .state_machine .process_input (StreamInputs .RECV_DATA )
1124
+ data_event = cast (DataReceived , events [0 ])
1115
1125
self ._inbound_window_manager .window_consumed (flow_control_len )
1116
1126
self ._track_content_length (len (data ), end_stream )
1117
1127
1118
1128
if end_stream :
1119
1129
es_events = self .state_machine .process_input (
1120
1130
StreamInputs .RECV_END_STREAM ,
1121
1131
)
1122
- events [ 0 ] .stream_ended = es_events [0 ]
1132
+ data_event .stream_ended = cast ( StreamEnded , es_events [0 ])
1123
1133
events .extend (es_events )
1124
1134
1125
- events [ 0 ] .data = data
1126
- events [ 0 ] .flow_controlled_length = flow_control_len
1135
+ data_event .data = data
1136
+ data_event .flow_controlled_length = flow_control_len
1127
1137
return [], events
1128
1138
1129
1139
def receive_window_update (self , increment : int ) -> tuple [list [Frame ], list [Event ]]:
@@ -1143,7 +1153,7 @@ def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event
1143
1153
# this should be treated as a *stream* error, not a *connection* error.
1144
1154
# That means we need to catch the error and forcibly close the stream.
1145
1155
if events :
1146
- events [0 ].delta = increment
1156
+ cast ( WindowUpdated , events [0 ]) .delta = increment
1147
1157
try :
1148
1158
self .outbound_flow_control_window = guard_increment_window (
1149
1159
self .outbound_flow_control_window ,
@@ -1226,7 +1236,7 @@ def stream_reset(self, frame: RstStreamFrame) -> tuple[list[Frame], list[Event]]
1226
1236
1227
1237
if events :
1228
1238
# We don't fire an event if this stream is already closed.
1229
- events [0 ].error_code = _error_code_from_int (frame .error_code )
1239
+ cast ( StreamReset , events [0 ]) .error_code = _error_code_from_int (frame .error_code )
1230
1240
1231
1241
return [], events
1232
1242
@@ -1328,7 +1338,7 @@ def _build_headers_frames(self,
1328
1338
def _process_received_headers (self ,
1329
1339
headers : Iterable [Header ],
1330
1340
header_validation_flags : HeaderValidationFlags ,
1331
- header_encoding : bool | str | None ) -> Iterable [Header ]:
1341
+ header_encoding : bool | str | None ) -> list [Header ]:
1332
1342
"""
1333
1343
When headers have been received from the remote peer, run a processing
1334
1344
pipeline on them to transform them into the appropriate form for
0 commit comments