Skip to content

Commit eb17499

Browse files
committed
Replace tcp_sockopts with socket_factory (#10520)
Instead of TCPConnector taking a list of sockopts to be applied sockets created, take a socket_factory callback that allows the caller to implement socket creation entirely.
1 parent 4399a6c commit eb17499

7 files changed

+65
-37
lines changed

CHANGES/10474.feature.rst

-2
This file was deleted.

CHANGES/10520.feature.rst

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Added ``socket_factory`` to ``TCPConnector`` to allow specifying custom socket options
2+
-- by :user:`TimMenninger`.

aiohttp/connector.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
DefaultDict,
2121
Deque,
2222
Dict,
23-
Iterable,
2423
Iterator,
2524
List,
2625
Literal,
@@ -826,8 +825,9 @@ class TCPConnector(BaseConnector):
826825
the happy eyeballs algorithm, set to None.
827826
interleave - “First Address Family Count” as defined in RFC 8305
828827
loop - Optional event loop.
829-
tcp_sockopts - List of tuples of sockopts applied to underlying
830-
socket
828+
socket_factory - An aiohappyeyeballs.SocketFactoryType function
829+
that, if supplied, will be used to create sockets
830+
given an aiohappyeyeballs.AddrInfoType.
831831
"""
832832

833833
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
@@ -849,7 +849,7 @@ def __init__(
849849
timeout_ceil_threshold: float = 5,
850850
happy_eyeballs_delay: Optional[float] = 0.25,
851851
interleave: Optional[int] = None,
852-
tcp_sockopts: Iterable[Tuple[int, int, Union[int, Buffer]]] = [],
852+
socket_factory: Optional[aiohappyeyeballs.SocketFactoryType] = None,
853853
):
854854
super().__init__(
855855
keepalive_timeout=keepalive_timeout,
@@ -880,7 +880,7 @@ def __init__(
880880
self._happy_eyeballs_delay = happy_eyeballs_delay
881881
self._interleave = interleave
882882
self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
883-
self._tcp_sockopts = tcp_sockopts
883+
self._socket_factory = socket_factory
884884

885885
def _close_immediately(self) -> List[Awaitable[object]]:
886886
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
@@ -1122,9 +1122,8 @@ async def _wrap_create_connection(
11221122
happy_eyeballs_delay=self._happy_eyeballs_delay,
11231123
interleave=self._interleave,
11241124
loop=self._loop,
1125+
socket_factory=self._socket_factory,
11251126
)
1126-
for sockopt in self._tcp_sockopts:
1127-
sock.setsockopt(*sockopt)
11281127
connection = await self._loop.create_connection(
11291128
*args, **kwargs, sock=sock
11301129
)

docs/client_advanced.rst

+15-8
Original file line numberDiff line numberDiff line change
@@ -468,19 +468,26 @@ If your HTTP server uses UNIX domain sockets you can use
468468
session = aiohttp.ClientSession(connector=conn)
469469

470470

471-
Setting socket options
471+
Custom socket creation
472472
^^^^^^^^^^^^^^^^^^^^^^
473473

474-
Socket options passed to the :class:`~aiohttp.TCPConnector` will be passed
475-
to the underlying socket when creating a connection. For example, we may
476-
want to change the conditions under which we consider a connection dead.
477-
The following would change that to 9*7200 = 18 hours::
474+
If the default socket is insufficient for your use case, pass an optional
475+
`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements
476+
`aiohappyeyeballs.SocketFactoryType`. This will be used to create all
477+
sockets for the lifetime of the class object. For example, we may want to
478+
change the conditions under which we consider a connection dead. The
479+
following would make all sockets respect 9*7200 = 18 hours::
478480

479481
import socket
480482

481-
conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
482-
(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200),
483-
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) ])
483+
def socket_factory(addr_info):
484+
family, type_, proto, _, _, _ = addr_info
485+
sock = socket.socket(family=family, type=type_, proto=proto)
486+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
487+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200)
488+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9)
489+
return sock
490+
conn = aiohttp.TCPConnector(socket_factory=socket_factory)
484491

485492

486493
Named pipes in Windows

docs/client_reference.rst

+4-4
Original file line numberDiff line numberDiff line change
@@ -1129,7 +1129,7 @@ is controlled by *force_close* constructor's parameter).
11291129
force_close=False, limit=100, limit_per_host=0, \
11301130
enable_cleanup_closed=False, timeout_ceil_threshold=5, \
11311131
happy_eyeballs_delay=0.25, interleave=None, loop=None, \
1132-
tcp_sockopts=[])
1132+
socket_factory=None)
11331133

11341134
Connector for working with *HTTP* and *HTTPS* via *TCP* sockets.
11351135

@@ -1250,9 +1250,9 @@ is controlled by *force_close* constructor's parameter).
12501250

12511251
.. versionadded:: 3.10
12521252

1253-
:param list tcp_sockopts: options applied to the socket when a connection is
1254-
created. This should be a list of 3-tuples, each a ``(level, optname, value)``.
1255-
Each tuple is deconstructed and passed verbatim to ``<socket>.setsockopt``.
1253+
:param :py:data:``aiohappyeyeballs.SocketFactoryType`` socket_factory: This function takes
1254+
an :py:data:``aiohappyeyeballs.AddrInfoType`` and is used in lieu of ``socket.socket()``
1255+
when creating TCP connections.
12561256

12571257
.. versionadded:: 3.12
12581258

docs/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
"aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None),
8383
"aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None),
8484
"aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None),
85+
"aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None),
8586
}
8687

8788
# Add any paths that contain templates here, relative to this directory.

tests/test_connector.py

+37-16
Original file line numberDiff line numberDiff line change
@@ -3767,27 +3767,48 @@ def test_connect() -> Literal[True]:
37673767
assert raw_response_list == [True, True]
37683768

37693769

3770-
async def test_tcp_connector_setsockopts(
3770+
async def test_tcp_connector_socket_factory(
37713771
loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock
37723772
) -> None:
3773-
"""Check that sockopts get passed to socket"""
3774-
conn = aiohttp.TCPConnector(
3775-
tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)]
3776-
)
3777-
3778-
with mock.patch.object(
3779-
conn._loop, "create_connection", autospec=True, spec_set=True
3780-
) as create_connection:
3781-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
3782-
start_connection.return_value = s
3783-
create_connection.return_value = mock.Mock(), mock.Mock()
3773+
"""Check that socket factory is called"""
3774+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
3775+
start_connection.return_value = s
37843776

3785-
req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)
3777+
local_addr = None
3778+
socket_factory: Callable[[AddrInfoType], socket.socket] = lambda _: s
3779+
happy_eyeballs_delay = 0.123
3780+
interleave = 3
3781+
conn = aiohttp.TCPConnector(
3782+
interleave=interleave,
3783+
local_addr=local_addr,
3784+
happy_eyeballs_delay=happy_eyeballs_delay,
3785+
socket_factory=socket_factory,
3786+
)
37863787

3788+
with mock.patch.object(
3789+
conn._loop,
3790+
"create_connection",
3791+
autospec=True,
3792+
spec_set=True,
3793+
return_value=(mock.Mock(), mock.Mock()),
3794+
):
3795+
host = "127.0.0.1"
3796+
port = 443
3797+
req = ClientRequest("GET", URL(f"https://{host}:{port}"), loop=loop)
37873798
with closing(await conn.connect(req, [], ClientTimeout())):
3788-
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2
3789-
3790-
await conn.close()
3799+
pass
3800+
await conn.close()
3801+
3802+
start_connection.assert_called_with(
3803+
addr_infos=[
3804+
(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port))
3805+
],
3806+
local_addr_infos=local_addr,
3807+
happy_eyeballs_delay=happy_eyeballs_delay,
3808+
interleave=interleave,
3809+
loop=loop,
3810+
socket_factory=socket_factory,
3811+
)
37913812

37923813

37933814
def test_default_ssl_context_creation_without_ssl() -> None:

0 commit comments

Comments
 (0)