Skip to content

Commit a633092

Browse files
committed
Add incremental updating of open streams count and closed_streams state
1 parent f96b4f5 commit a633092

File tree

3 files changed

+109
-18
lines changed

3 files changed

+109
-18
lines changed

h2/connection.py

+27-15
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,9 @@ def __init__(self, config=None):
292292
self.encoder = Encoder()
293293
self.decoder = Decoder()
294294

295+
self._open_outbound_stream_count = 0
296+
self._open_inbound_stream_count = 0
297+
295298
# This won't always actually do anything: for versions of HPACK older
296299
# than 2.3.0 it does nothing. However, we have to try!
297300
self.decoder.max_header_list_size = self.DEFAULT_MAX_HEADER_LIST_SIZE
@@ -362,6 +365,8 @@ def __init__(self, config=None):
362365
size_limit=self.MAX_CLOSED_STREAMS
363366
)
364367

368+
self._streams_to_close = list()
369+
365370
# The flow control window manager for the connection.
366371
self._inbound_flow_control_window_manager = WindowManager(
367372
max_window_size=self.local_settings.initial_window_size
@@ -383,6 +388,15 @@ def __init__(self, config=None):
383388
ExtensionFrame: self._receive_unknown_frame
384389
}
385390

391+
def _increment_open_streams(self, stream_id, incr):
392+
if stream_id % 2 == 0:
393+
self._open_inbound_stream_count += incr
394+
elif stream_id % 2 == 1:
395+
self._open_outbound_stream_count += incr
396+
397+
def _close_stream(self, stream_id):
398+
self._streams_to_close.append(stream_id)
399+
386400
def _prepare_for_sending(self, frames):
387401
if not frames:
388402
return
@@ -393,22 +407,18 @@ def _open_streams(self, remainder):
393407
"""
394408
A common method of counting number of open streams. Returns the number
395409
of streams that are open *and* that have (stream ID % 2) == remainder.
396-
While it iterates, also deletes any closed streams.
410+
Also cleans up closed streams.
397411
"""
398-
count = 0
399-
to_delete = []
400-
401-
for stream_id, stream in self.streams.items():
402-
if stream.open and (stream_id % 2 == remainder):
403-
count += 1
404-
elif stream.closed:
405-
to_delete.append(stream_id)
406-
407-
for stream_id in to_delete:
412+
for stream_id in self._streams_to_close:
408413
stream = self.streams.pop(stream_id)
409414
self._closed_streams[stream_id] = stream.closed_by
415+
self._streams_to_close = list()
410416

411-
return count
417+
if remainder == 0:
418+
return self._open_inbound_stream_count
419+
elif remainder == 1:
420+
return self._open_outbound_stream_count
421+
return 0
412422

413423
@property
414424
def open_outbound_streams(self):
@@ -467,7 +477,9 @@ def _begin_new_stream(self, stream_id, allowed_ids):
467477
stream_id,
468478
config=self.config,
469479
inbound_window_size=self.local_settings.initial_window_size,
470-
outbound_window_size=self.remote_settings.initial_window_size
480+
outbound_window_size=self.remote_settings.initial_window_size,
481+
increment_open_stream_count_callback=self._increment_open_streams,
482+
close_stream_callback=self._close_stream,
471483
)
472484
self.config.logger.debug("Stream ID %d created", stream_id)
473485
s.max_inbound_frame_size = self.max_inbound_frame_size
@@ -1542,8 +1554,8 @@ def _receive_headers_frame(self, frame):
15421554
max_open_streams = self.local_settings.max_concurrent_streams
15431555
if (self.open_inbound_streams + 1) > max_open_streams:
15441556
raise TooManyStreamsError(
1545-
"Max outbound streams is %d, %d open" %
1546-
(max_open_streams, self.open_outbound_streams)
1557+
"Max inbound streams is %d, %d open" %
1558+
(max_open_streams, self.open_inbound_streams)
15471559
)
15481560

15491561
# Let's decode the headers. We handle headers as bytes internally up

h2/stream.py

+81-2
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,56 @@ def send_alt_svc(self, previous_state):
767767
(H2StreamStateMachine.send_on_closed_stream, StreamState.CLOSED),
768768
}
769769

770+
"""
771+
Wraps a stream state change function to ensure that we keep
772+
the parent H2Connection's state in sync
773+
"""
774+
def sync_state_change(func):
775+
def wrapper(self, *args, **kwargs):
776+
# Collect state at the beginning.
777+
start_state = self.state_machine.state
778+
started_open = self.open
779+
started_closed = not started_open
780+
781+
# Do the state change (if any).
782+
result = func(self, *args, **kwargs)
783+
784+
# Collect state at the end.
785+
end_state = self.state_machine.state
786+
ended_open = self.open
787+
ended_closed = not ended_open
788+
789+
if end_state == StreamState.CLOSED and start_state != end_state:
790+
if self._close_stream_callback:
791+
self._close_stream_callback(self.stream_id)
792+
# Clear callback so we only call this once per stream
793+
self._close_stream_callback = None
794+
795+
# If we were open, but are now closed, decrement
796+
# the open stream count, and call the close callback.
797+
if started_open and ended_closed:
798+
if self._decrement_open_stream_count_callback:
799+
self._decrement_open_stream_count_callback(self.stream_id,
800+
-1,)
801+
# Clear callback so we only call this once per stream
802+
self._decrement_open_stream_count_callback = None
803+
804+
if self._close_stream_callback:
805+
self._close_stream_callback(self.stream_id)
806+
# Clear callback so we only call this once per stream
807+
self._close_stream_callback = None
808+
809+
# If we were closed, but are now open, increment
810+
# the open stream count.
811+
elif started_closed and ended_open:
812+
if self._increment_open_stream_count_callback:
813+
self._increment_open_stream_count_callback(self.stream_id,
814+
1,)
815+
# Clear callback so we only call this once per stream
816+
self._increment_open_stream_count_callback = None
817+
return result
818+
return wrapper
819+
770820

771821
class H2Stream(object):
772822
"""
@@ -782,18 +832,29 @@ def __init__(self,
782832
stream_id,
783833
config,
784834
inbound_window_size,
785-
outbound_window_size):
835+
outbound_window_size,
836+
increment_open_stream_count_callback,
837+
close_stream_callback,):
786838
self.state_machine = H2StreamStateMachine(stream_id)
787839
self.stream_id = stream_id
788840
self.max_outbound_frame_size = None
789841
self.request_method = None
790842

791-
# The current value of the outbound stream flow control window
843+
# The current value of the outbound stream flow control window.
792844
self.outbound_flow_control_window = outbound_window_size
793845

794846
# The flow control manager.
795847
self._inbound_window_manager = WindowManager(inbound_window_size)
796848

849+
# Callback to increment open stream count for the H2Connection.
850+
self._increment_open_stream_count_callback = increment_open_stream_count_callback
851+
852+
# Callback to decrement open stream count for the H2Connection.
853+
self._decrement_open_stream_count_callback = increment_open_stream_count_callback
854+
855+
# Callback to clean up state for the H2Connection once we're closed.
856+
self._close_stream_callback = close_stream_callback
857+
797858
# The expected content length, if any.
798859
self._expected_content_length = None
799860

@@ -850,6 +911,7 @@ def closed_by(self):
850911
"""
851912
return self.state_machine.stream_closed_by
852913

914+
@sync_state_change
853915
def upgrade(self, client_side):
854916
"""
855917
Called by the connection to indicate that this stream is the initial
@@ -868,6 +930,7 @@ def upgrade(self, client_side):
868930
self.state_machine.process_input(input_)
869931
return
870932

933+
@sync_state_change
871934
def send_headers(self, headers, encoder, end_stream=False):
872935
"""
873936
Returns a list of HEADERS/CONTINUATION frames to emit as either headers
@@ -917,6 +980,7 @@ def send_headers(self, headers, encoder, end_stream=False):
917980

918981
return frames
919982

983+
@sync_state_change
920984
def push_stream_in_band(self, related_stream_id, headers, encoder):
921985
"""
922986
Returns a list of PUSH_PROMISE/CONTINUATION frames to emit as a pushed
@@ -941,6 +1005,7 @@ def push_stream_in_band(self, related_stream_id, headers, encoder):
9411005

9421006
return frames
9431007

1008+
@sync_state_change
9441009
def locally_pushed(self):
9451010
"""
9461011
Mark this stream as one that was pushed by this peer. Must be called
@@ -954,6 +1019,7 @@ def locally_pushed(self):
9541019
assert not events
9551020
return []
9561021

1022+
@sync_state_change
9571023
def send_data(self, data, end_stream=False, pad_length=None):
9581024
"""
9591025
Prepare some data frames. Optionally end the stream.
@@ -981,6 +1047,7 @@ def send_data(self, data, end_stream=False, pad_length=None):
9811047

9821048
return [df]
9831049

1050+
@sync_state_change
9841051
def end_stream(self):
9851052
"""
9861053
End a stream without sending data.
@@ -992,6 +1059,7 @@ def end_stream(self):
9921059
df.flags.add('END_STREAM')
9931060
return [df]
9941061

1062+
@sync_state_change
9951063
def advertise_alternative_service(self, field_value):
9961064
"""
9971065
Advertise an RFC 7838 alternative service. The semantics of this are
@@ -1005,6 +1073,7 @@ def advertise_alternative_service(self, field_value):
10051073
asf.field = field_value
10061074
return [asf]
10071075

1076+
@sync_state_change
10081077
def increase_flow_control_window(self, increment):
10091078
"""
10101079
Increase the size of the flow control window for the remote side.
@@ -1020,6 +1089,7 @@ def increase_flow_control_window(self, increment):
10201089
wuf.window_increment = increment
10211090
return [wuf]
10221091

1092+
@sync_state_change
10231093
def receive_push_promise_in_band(self,
10241094
promised_stream_id,
10251095
headers,
@@ -1044,6 +1114,7 @@ def receive_push_promise_in_band(self,
10441114
)
10451115
return [], events
10461116

1117+
@sync_state_change
10471118
def remotely_pushed(self, pushed_headers):
10481119
"""
10491120
Mark this stream as one that was pushed by the remote peer. Must be
@@ -1057,6 +1128,7 @@ def remotely_pushed(self, pushed_headers):
10571128
self._authority = authority_from_headers(pushed_headers)
10581129
return [], events
10591130

1131+
@sync_state_change
10601132
def receive_headers(self, headers, end_stream, header_encoding):
10611133
"""
10621134
Receive a set of headers (or trailers).
@@ -1091,6 +1163,7 @@ def receive_headers(self, headers, end_stream, header_encoding):
10911163
)
10921164
return [], events
10931165

1166+
@sync_state_change
10941167
def receive_data(self, data, end_stream, flow_control_len):
10951168
"""
10961169
Receive some data.
@@ -1114,6 +1187,7 @@ def receive_data(self, data, end_stream, flow_control_len):
11141187
events[0].flow_controlled_length = flow_control_len
11151188
return [], events
11161189

1190+
@sync_state_change
11171191
def receive_window_update(self, increment):
11181192
"""
11191193
Handle a WINDOW_UPDATE increment.
@@ -1150,6 +1224,7 @@ def receive_window_update(self, increment):
11501224

11511225
return frames, events
11521226

1227+
@sync_state_change
11531228
def receive_continuation(self):
11541229
"""
11551230
A naked CONTINUATION frame has been received. This is always an error,
@@ -1162,6 +1237,7 @@ def receive_continuation(self):
11621237
)
11631238
assert False, "Should not be reachable"
11641239

1240+
@sync_state_change
11651241
def receive_alt_svc(self, frame):
11661242
"""
11671243
An Alternative Service frame was received on the stream. This frame
@@ -1189,6 +1265,7 @@ def receive_alt_svc(self, frame):
11891265

11901266
return [], events
11911267

1268+
@sync_state_change
11921269
def reset_stream(self, error_code=0):
11931270
"""
11941271
Close the stream locally. Reset the stream with an error code.
@@ -1202,6 +1279,7 @@ def reset_stream(self, error_code=0):
12021279
rsf.error_code = error_code
12031280
return [rsf]
12041281

1282+
@sync_state_change
12051283
def stream_reset(self, frame):
12061284
"""
12071285
Handle a stream being reset remotely.
@@ -1217,6 +1295,7 @@ def stream_reset(self, frame):
12171295

12181296
return [], events
12191297

1298+
@sync_state_change
12201299
def acknowledge_received_data(self, acknowledged_size):
12211300
"""
12221301
The user has informed us that they've processed some amount of data

test/test_basic_logic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1851,7 +1851,7 @@ def test_stream_repr(self):
18511851
"""
18521852
Ensure stream string representation is appropriate.
18531853
"""
1854-
s = h2.stream.H2Stream(4, None, 12, 14)
1854+
s = h2.stream.H2Stream(4, None, 12, 14, None, None)
18551855
assert repr(s) == "<H2Stream id:4 state:<StreamState.IDLE: 0>>"
18561856

18571857

0 commit comments

Comments
 (0)