Skip to content

Commit a2a3524

Browse files
committed
Bump websockets to >=14.2
Using new websockets asyncio implementation
1 parent fbe03c4 commit a2a3524

File tree

9 files changed

+48
-59
lines changed

9 files changed

+48
-59
lines changed

gql/transport/common/adapters/websockets.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Dict, Optional, Union
44

55
import websockets
6-
from websockets.client import WebSocketClientProtocol
6+
from websockets import ClientConnection
77
from websockets.datastructures import Headers, HeadersLike
88

99
from ...exceptions import TransportConnectionFailed, TransportProtocolError
@@ -40,7 +40,7 @@ def __init__(
4040
self._headers: Optional[HeadersLike] = headers
4141
self.ssl = ssl
4242

43-
self.websocket: Optional[WebSocketClientProtocol] = None
43+
self.websocket: Optional[ClientConnection] = None
4444
self._response_headers: Optional[Headers] = None
4545

4646
async def connect(self) -> None:
@@ -57,7 +57,7 @@ async def connect(self) -> None:
5757
# Set default arguments used in the websockets.connect call
5858
connect_args: Dict[str, Any] = {
5959
"ssl": ssl,
60-
"extra_headers": self.headers,
60+
"additional_headers": self.headers,
6161
}
6262

6363
if self.subprotocols:
@@ -68,11 +68,13 @@ async def connect(self) -> None:
6868

6969
# Connection to the specified url
7070
try:
71-
self.websocket = await websockets.client.connect(self.url, **connect_args)
71+
self.websocket = await websockets.connect(self.url, **connect_args)
7272
except Exception as e:
7373
raise TransportConnectionFailed("Connect failed") from e
7474

75-
self._response_headers = self.websocket.response_headers
75+
assert self.websocket.response is not None
76+
77+
self._response_headers = self.websocket.response.headers
7678

7779
async def send(self, message: str) -> None:
7880
"""Send message to the WebSocket server.

gql/transport/common/base.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,10 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
482482
# We should always have an active websocket connection here
483483
assert self._connected
484484

485+
# Saving exception to raise it later if trying to use the transport
486+
# after it has already closed.
487+
self.close_exception = e
488+
485489
# Properly shut down liveness checker if enabled
486490
if self.check_keep_alive_task is not None:
487491
# More info: https://stackoverflow.com/a/43810272/1113207
@@ -492,18 +496,16 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
492496
# Calling the subclass close hook
493497
await self._close_hook()
494498

495-
# Saving exception to raise it later if trying to use the transport
496-
# after it has already closed.
497-
self.close_exception = e
498-
499499
if clean_close:
500500
log.debug("_close_coro: starting clean_close")
501501
try:
502502
await self._clean_close(e)
503503
except Exception as exc: # pragma: no cover
504504
log.warning("Ignoring exception in _clean_close: " + repr(exc))
505505

506-
log.debug("_close_coro: sending exception to listeners")
506+
log.debug(
507+
f"_close_coro: sending exception to {len(self.listeners)} listeners"
508+
)
507509

508510
# Send an exception to all remaining listeners
509511
for query_id, listener in self.listeners.items():

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
]
5252

5353
install_websockets_requires = [
54-
"websockets>=10.1,<14",
54+
"websockets>=14.2,<16",
5555
]
5656

5757
install_botocore_requires = [

tests/conftest.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def __init__(self, with_ssl: bool = False):
197197

198198
async def start(self, handler, extra_serve_args=None):
199199

200-
import websockets.server
200+
import websockets
201201

202202
print("Starting server")
203203

@@ -209,16 +209,21 @@ async def start(self, handler, extra_serve_args=None):
209209
extra_serve_args["ssl"] = ssl_context
210210

211211
# Adding dummy response headers
212-
extra_serve_args["extra_headers"] = {"dummy": "test1234"}
212+
extra_headers = {"dummy": "test1234"}
213+
214+
def process_response(connection, request, response):
215+
response.headers.update(extra_headers)
216+
return response
213217

214218
# Start a server with a random open port
215-
self.start_server = websockets.server.serve(
216-
handler, "127.0.0.1", 0, **extra_serve_args
219+
self.server = await websockets.serve(
220+
handler,
221+
"127.0.0.1",
222+
0,
223+
process_response=process_response,
224+
**extra_serve_args,
217225
)
218226

219-
# Wait that the server is started
220-
self.server = await self.start_server
221-
222227
# Get hostname and port
223228
hostname, port = self.server.sockets[0].getsockname()[:2] # type: ignore
224229
assert hostname == "127.0.0.1"
@@ -603,32 +608,14 @@ async def graphqlws_server(request):
603608

604609
subprotocol = "graphql-transport-ws"
605610

606-
from websockets.server import WebSocketServerProtocol
607-
608-
class CustomSubprotocol(WebSocketServerProtocol):
609-
def select_subprotocol(self, client_subprotocols, server_subprotocols):
610-
print(f"Client subprotocols: {client_subprotocols!r}")
611-
print(f"Server subprotocols: {server_subprotocols!r}")
612-
613-
return subprotocol
614-
615-
def process_subprotocol(self, headers, available_subprotocols):
616-
# Overwriting available subprotocols
617-
available_subprotocols = [subprotocol]
618-
619-
print(f"headers: {headers!r}")
620-
# print (f"Available subprotocols: {available_subprotocols!r}")
621-
622-
return super().process_subprotocol(headers, available_subprotocols)
623-
624611
server_handler = get_server_handler(request)
625612

626613
try:
627614
test_server = WebSocketServer()
628615

629616
# Starting the server with the fixture param as the handler function
630617
await test_server.start(
631-
server_handler, extra_serve_args={"create_protocol": CustomSubprotocol}
618+
server_handler, extra_serve_args={"subprotocols": [subprotocol]}
632619
)
633620

634621
yield test_server

tests/test_aiohttp_websocket_exceptions.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,10 @@ async def test_aiohttp_websocket_server_does_not_send_ack(server, query_str):
118118

119119
url = f"ws://{server.hostname}:{server.port}/graphql"
120120

121-
sample_transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1)
121+
transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1)
122122

123123
with pytest.raises(asyncio.TimeoutError):
124-
async with Client(transport=sample_transport):
124+
async with Client(transport=transport):
125125
pass
126126

127127

@@ -261,10 +261,10 @@ async def test_aiohttp_websocket_server_does_not_ack(server):
261261
url = f"ws://{server.hostname}:{server.port}/graphql"
262262
print(f"url = {url}")
263263

264-
sample_transport = AIOHTTPWebsocketsTransport(url=url)
264+
transport = AIOHTTPWebsocketsTransport(url=url)
265265

266266
with pytest.raises(TransportProtocolError):
267-
async with Client(transport=sample_transport):
267+
async with Client(transport=transport):
268268
pass
269269

270270

@@ -281,10 +281,10 @@ async def test_aiohttp_websocket_server_closing_directly(server):
281281
url = f"ws://{server.hostname}:{server.port}/graphql"
282282
print(f"url = {url}")
283283

284-
sample_transport = AIOHTTPWebsocketsTransport(url=url)
284+
transport = AIOHTTPWebsocketsTransport(url=url)
285285

286286
with pytest.raises(TransportConnectionFailed):
287-
async with Client(transport=sample_transport):
287+
async with Client(transport=transport):
288288
pass
289289

290290

@@ -323,10 +323,10 @@ async def test_aiohttp_websocket_server_sending_invalid_query_errors(server):
323323
url = f"ws://{server.hostname}:{server.port}/graphql"
324324
print(f"url = {url}")
325325

326-
sample_transport = AIOHTTPWebsocketsTransport(url=url)
326+
transport = AIOHTTPWebsocketsTransport(url=url)
327327

328328
# Invalid server message is ignored
329-
async with Client(transport=sample_transport):
329+
async with Client(transport=transport):
330330
await asyncio.sleep(2 * MS)
331331

332332

@@ -342,9 +342,9 @@ async def test_aiohttp_websocket_non_regression_bug_105(server):
342342
url = f"ws://{server.hostname}:{server.port}/graphql"
343343
print(f"url = {url}")
344344

345-
sample_transport = AIOHTTPWebsocketsTransport(url=url)
345+
transport = AIOHTTPWebsocketsTransport(url=url)
346346

347-
client = Client(transport=sample_transport)
347+
client = Client(transport=transport)
348348

349349
# Create a coroutine which start the connection with the transport but does nothing
350350
async def client_connect(client):

tests/test_appsync_websockets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ async def realtime_appsync_server_template(ws):
139139
)
140140
return
141141

142-
path = ws.path
142+
path = ws.request.path
143143

144144
print(f"path = {path}")
145145

tests/test_graphqlws_exceptions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ async def test_graphqlws_server_does_not_send_ack(graphqlws_server, query_str):
111111

112112
url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql"
113113

114-
sample_transport = WebsocketsTransport(url=url, ack_timeout=1)
114+
transport = WebsocketsTransport(url=url, ack_timeout=1)
115115

116116
with pytest.raises(asyncio.TimeoutError):
117-
async with Client(transport=sample_transport):
117+
async with Client(transport=transport):
118118
pass
119119

120120

@@ -212,10 +212,10 @@ async def test_graphqlws_server_does_not_ack(graphqlws_server):
212212
url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql"
213213
print(f"url = {url}")
214214

215-
sample_transport = WebsocketsTransport(url=url)
215+
transport = WebsocketsTransport(url=url)
216216

217217
with pytest.raises(TransportProtocolError):
218-
async with Client(transport=sample_transport):
218+
async with Client(transport=transport):
219219
pass
220220

221221

@@ -231,10 +231,10 @@ async def test_graphqlws_server_closing_directly(graphqlws_server):
231231
url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql"
232232
print(f"url = {url}")
233233

234-
sample_transport = WebsocketsTransport(url=url)
234+
transport = WebsocketsTransport(url=url)
235235

236236
with pytest.raises(TransportConnectionFailed):
237-
async with Client(transport=sample_transport):
237+
async with Client(transport=transport):
238238
pass
239239

240240

@@ -251,7 +251,7 @@ async def test_graphqlws_server_closing_after_ack(client_and_graphqlws_server):
251251

252252
query = gql("query { hello }")
253253

254-
with pytest.raises(TransportConnectionFailed):
254+
with pytest.raises(TransportClosed):
255255
await session.execute(query)
256256

257257
await session.transport.wait_closed()

tests/test_websocket_exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ async def test_websocket_server_closing_after_ack(client_and_server):
296296

297297
query = gql("query { hello }")
298298

299-
with pytest.raises(TransportConnectionFailed):
299+
with pytest.raises(TransportClosed):
300300
await session.execute(query)
301301

302302
await session.transport.wait_closed()

tests/test_websocket_query.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,7 @@ async def test_websocket_using_ssl_connection(ws_ssl_server):
112112

113113
async with Client(transport=transport) as session:
114114

115-
assert isinstance(
116-
transport.adapter.websocket, websockets.client.WebSocketClientProtocol
117-
)
115+
assert isinstance(transport.adapter.websocket, websockets.ClientConnection)
118116

119117
query1 = gql(query1_str)
120118

0 commit comments

Comments
 (0)