@@ -767,6 +767,56 @@ def send_alt_svc(self, previous_state):
767
767
(H2StreamStateMachine .send_on_closed_stream , StreamState .CLOSED ),
768
768
}
769
769
770
+ """
771
+ Wraps a stream state change function to ensure that we keep
772
+ the parent H2Connection's state in sync
773
+ """
774
+ def sync_state_change (func ):
775
+ def wrapper (self , * args , ** kwargs ):
776
+ # Collect state at the beginning.
777
+ start_state = self .state_machine .state
778
+ started_open = self .open
779
+ started_closed = not started_open
780
+
781
+ # Do the state change (if any).
782
+ result = func (self , * args , ** kwargs )
783
+
784
+ # Collect state at the end.
785
+ end_state = self .state_machine .state
786
+ ended_open = self .open
787
+ ended_closed = not ended_open
788
+
789
+ if end_state == StreamState .CLOSED and start_state != end_state :
790
+ if self ._close_stream_callback :
791
+ self ._close_stream_callback (self .stream_id )
792
+ # Clear callback so we only call this once per stream
793
+ self ._close_stream_callback = None
794
+
795
+ # If we were open, but are now closed, decrement
796
+ # the open stream count, and call the close callback.
797
+ if started_open and ended_closed :
798
+ if self ._decrement_open_stream_count_callback :
799
+ self ._decrement_open_stream_count_callback (self .stream_id ,
800
+ - 1 ,)
801
+ # Clear callback so we only call this once per stream
802
+ self ._decrement_open_stream_count_callback = None
803
+
804
+ if self ._close_stream_callback :
805
+ self ._close_stream_callback (self .stream_id )
806
+ # Clear callback so we only call this once per stream
807
+ self ._close_stream_callback = None
808
+
809
+ # If we were closed, but are now open, increment
810
+ # the open stream count.
811
+ elif started_closed and ended_open :
812
+ if self ._increment_open_stream_count_callback :
813
+ self ._increment_open_stream_count_callback (self .stream_id ,
814
+ 1 ,)
815
+ # Clear callback so we only call this once per stream
816
+ self ._increment_open_stream_count_callback = None
817
+ return result
818
+ return wrapper
819
+
770
820
771
821
class H2Stream (object ):
772
822
"""
@@ -782,18 +832,29 @@ def __init__(self,
782
832
stream_id ,
783
833
config ,
784
834
inbound_window_size ,
785
- outbound_window_size ):
835
+ outbound_window_size ,
836
+ increment_open_stream_count_callback ,
837
+ close_stream_callback ,):
786
838
self .state_machine = H2StreamStateMachine (stream_id )
787
839
self .stream_id = stream_id
788
840
self .max_outbound_frame_size = None
789
841
self .request_method = None
790
842
791
- # The current value of the outbound stream flow control window
843
+ # The current value of the outbound stream flow control window.
792
844
self .outbound_flow_control_window = outbound_window_size
793
845
794
846
# The flow control manager.
795
847
self ._inbound_window_manager = WindowManager (inbound_window_size )
796
848
849
+ # Callback to increment open stream count for the H2Connection.
850
+ self ._increment_open_stream_count_callback = increment_open_stream_count_callback
851
+
852
+ # Callback to decrement open stream count for the H2Connection.
853
+ self ._decrement_open_stream_count_callback = increment_open_stream_count_callback
854
+
855
+ # Callback to clean up state for the H2Connection once we're closed.
856
+ self ._close_stream_callback = close_stream_callback
857
+
797
858
# The expected content length, if any.
798
859
self ._expected_content_length = None
799
860
@@ -850,6 +911,7 @@ def closed_by(self):
850
911
"""
851
912
return self .state_machine .stream_closed_by
852
913
914
+ @sync_state_change
853
915
def upgrade (self , client_side ):
854
916
"""
855
917
Called by the connection to indicate that this stream is the initial
@@ -868,6 +930,7 @@ def upgrade(self, client_side):
868
930
self .state_machine .process_input (input_ )
869
931
return
870
932
933
+ @sync_state_change
871
934
def send_headers (self , headers , encoder , end_stream = False ):
872
935
"""
873
936
Returns a list of HEADERS/CONTINUATION frames to emit as either headers
@@ -917,6 +980,7 @@ def send_headers(self, headers, encoder, end_stream=False):
917
980
918
981
return frames
919
982
983
+ @sync_state_change
920
984
def push_stream_in_band (self , related_stream_id , headers , encoder ):
921
985
"""
922
986
Returns a list of PUSH_PROMISE/CONTINUATION frames to emit as a pushed
@@ -941,6 +1005,7 @@ def push_stream_in_band(self, related_stream_id, headers, encoder):
941
1005
942
1006
return frames
943
1007
1008
+ @sync_state_change
944
1009
def locally_pushed (self ):
945
1010
"""
946
1011
Mark this stream as one that was pushed by this peer. Must be called
@@ -954,6 +1019,7 @@ def locally_pushed(self):
954
1019
assert not events
955
1020
return []
956
1021
1022
+ @sync_state_change
957
1023
def send_data (self , data , end_stream = False , pad_length = None ):
958
1024
"""
959
1025
Prepare some data frames. Optionally end the stream.
@@ -981,6 +1047,7 @@ def send_data(self, data, end_stream=False, pad_length=None):
981
1047
982
1048
return [df ]
983
1049
1050
+ @sync_state_change
984
1051
def end_stream (self ):
985
1052
"""
986
1053
End a stream without sending data.
@@ -992,6 +1059,7 @@ def end_stream(self):
992
1059
df .flags .add ('END_STREAM' )
993
1060
return [df ]
994
1061
1062
+ @sync_state_change
995
1063
def advertise_alternative_service (self , field_value ):
996
1064
"""
997
1065
Advertise an RFC 7838 alternative service. The semantics of this are
@@ -1005,6 +1073,7 @@ def advertise_alternative_service(self, field_value):
1005
1073
asf .field = field_value
1006
1074
return [asf ]
1007
1075
1076
+ @sync_state_change
1008
1077
def increase_flow_control_window (self , increment ):
1009
1078
"""
1010
1079
Increase the size of the flow control window for the remote side.
@@ -1020,6 +1089,7 @@ def increase_flow_control_window(self, increment):
1020
1089
wuf .window_increment = increment
1021
1090
return [wuf ]
1022
1091
1092
+ @sync_state_change
1023
1093
def receive_push_promise_in_band (self ,
1024
1094
promised_stream_id ,
1025
1095
headers ,
@@ -1044,6 +1114,7 @@ def receive_push_promise_in_band(self,
1044
1114
)
1045
1115
return [], events
1046
1116
1117
+ @sync_state_change
1047
1118
def remotely_pushed (self , pushed_headers ):
1048
1119
"""
1049
1120
Mark this stream as one that was pushed by the remote peer. Must be
@@ -1057,6 +1128,7 @@ def remotely_pushed(self, pushed_headers):
1057
1128
self ._authority = authority_from_headers (pushed_headers )
1058
1129
return [], events
1059
1130
1131
+ @sync_state_change
1060
1132
def receive_headers (self , headers , end_stream , header_encoding ):
1061
1133
"""
1062
1134
Receive a set of headers (or trailers).
@@ -1091,6 +1163,7 @@ def receive_headers(self, headers, end_stream, header_encoding):
1091
1163
)
1092
1164
return [], events
1093
1165
1166
+ @sync_state_change
1094
1167
def receive_data (self , data , end_stream , flow_control_len ):
1095
1168
"""
1096
1169
Receive some data.
@@ -1114,6 +1187,7 @@ def receive_data(self, data, end_stream, flow_control_len):
1114
1187
events [0 ].flow_controlled_length = flow_control_len
1115
1188
return [], events
1116
1189
1190
+ @sync_state_change
1117
1191
def receive_window_update (self , increment ):
1118
1192
"""
1119
1193
Handle a WINDOW_UPDATE increment.
@@ -1150,6 +1224,7 @@ def receive_window_update(self, increment):
1150
1224
1151
1225
return frames , events
1152
1226
1227
+ @sync_state_change
1153
1228
def receive_continuation (self ):
1154
1229
"""
1155
1230
A naked CONTINUATION frame has been received. This is always an error,
@@ -1162,6 +1237,7 @@ def receive_continuation(self):
1162
1237
)
1163
1238
assert False , "Should not be reachable"
1164
1239
1240
+ @sync_state_change
1165
1241
def receive_alt_svc (self , frame ):
1166
1242
"""
1167
1243
An Alternative Service frame was received on the stream. This frame
@@ -1189,6 +1265,7 @@ def receive_alt_svc(self, frame):
1189
1265
1190
1266
return [], events
1191
1267
1268
+ @sync_state_change
1192
1269
def reset_stream (self , error_code = 0 ):
1193
1270
"""
1194
1271
Close the stream locally. Reset the stream with an error code.
@@ -1202,6 +1279,7 @@ def reset_stream(self, error_code=0):
1202
1279
rsf .error_code = error_code
1203
1280
return [rsf ]
1204
1281
1282
+ @sync_state_change
1205
1283
def stream_reset (self , frame ):
1206
1284
"""
1207
1285
Handle a stream being reset remotely.
@@ -1217,6 +1295,7 @@ def stream_reset(self, frame):
1217
1295
1218
1296
return [], events
1219
1297
1298
+ @sync_state_change
1220
1299
def acknowledge_received_data (self , acknowledged_size ):
1221
1300
"""
1222
1301
The user has informed us that they've processed some amount of data
0 commit comments