Skip to content

Commit 09ece48

Browse files
committed
Replace tcp_sockopts with socket_factory (#10520)
Instead of TCPConnector taking a list of sockopts to be applied sockets created, take a socket_factory callback that allows the caller to implement socket creation entirely.
1 parent 4399a6c commit 09ece48

10 files changed

+95
-44
lines changed

CHANGES/10474.feature.rst

-2
This file was deleted.

CHANGES/10520.feature.rst

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Added ``socket_factory`` to ``TCPConnector`` to allow specifying custom socket options
2+
-- by :user:`TimMenninger`.

aiohttp/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
WSServerHandshakeError,
4848
request,
4949
)
50+
from .connector import AddrInfoType, SocketFactoryType
5051
from .cookiejar import CookieJar, DummyCookieJar
5152
from .formdata import FormData
5253
from .helpers import BasicAuth, ChainMapProxy, ETag
@@ -112,6 +113,7 @@
112113
__all__: Tuple[str, ...] = (
113114
"hdrs",
114115
# client
116+
"AddrInfoType",
115117
"BaseConnector",
116118
"ClientConnectionError",
117119
"ClientConnectionResetError",
@@ -146,6 +148,7 @@
146148
"ServerDisconnectedError",
147149
"ServerFingerprintMismatch",
148150
"ServerTimeoutError",
151+
"SocketFactoryType",
149152
"SocketTimeoutError",
150153
"TCPConnector",
151154
"TooManyRedirects",

aiohttp/connector.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
DefaultDict,
2121
Deque,
2222
Dict,
23-
Iterable,
2423
Iterator,
2524
List,
2625
Literal,
@@ -34,6 +33,7 @@
3433
)
3534

3635
import aiohappyeyeballs
36+
from aiohappyeyeballs import AddrInfoType, SocketFactoryType
3737

3838
from . import hdrs, helpers
3939
from .abc import AbstractResolver, ResolveResult
@@ -96,7 +96,14 @@
9696
# which first appeared in Python 3.12.7 and 3.13.1
9797

9898

99-
__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")
99+
__all__ = (
100+
"BaseConnector",
101+
"TCPConnector",
102+
"UnixConnector",
103+
"NamedPipeConnector",
104+
"AddrInfoType",
105+
"SocketFactoryType",
106+
)
100107

101108

102109
if TYPE_CHECKING:
@@ -826,8 +833,9 @@ class TCPConnector(BaseConnector):
826833
the happy eyeballs algorithm, set to None.
827834
interleave - “First Address Family Count” as defined in RFC 8305
828835
loop - Optional event loop.
829-
tcp_sockopts - List of tuples of sockopts applied to underlying
830-
socket
836+
socket_factory - A SocketFactoryType function that, if supplied,
837+
will be used to create sockets given an
838+
AddrInfoType.
831839
"""
832840

833841
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
@@ -849,7 +857,7 @@ def __init__(
849857
timeout_ceil_threshold: float = 5,
850858
happy_eyeballs_delay: Optional[float] = 0.25,
851859
interleave: Optional[int] = None,
852-
tcp_sockopts: Iterable[Tuple[int, int, Union[int, Buffer]]] = [],
860+
socket_factory: Optional[SocketFactoryType] = None,
853861
):
854862
super().__init__(
855863
keepalive_timeout=keepalive_timeout,
@@ -880,7 +888,7 @@ def __init__(
880888
self._happy_eyeballs_delay = happy_eyeballs_delay
881889
self._interleave = interleave
882890
self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
883-
self._tcp_sockopts = tcp_sockopts
891+
self._socket_factory = socket_factory
884892

885893
def _close_immediately(self) -> List[Awaitable[object]]:
886894
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
@@ -1105,7 +1113,7 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
11051113
async def _wrap_create_connection(
11061114
self,
11071115
*args: Any,
1108-
addr_infos: List[aiohappyeyeballs.AddrInfoType],
1116+
addr_infos: List[AddrInfoType],
11091117
req: ClientRequest,
11101118
timeout: "ClientTimeout",
11111119
client_error: Type[Exception] = ClientConnectorError,
@@ -1122,9 +1130,8 @@ async def _wrap_create_connection(
11221130
happy_eyeballs_delay=self._happy_eyeballs_delay,
11231131
interleave=self._interleave,
11241132
loop=self._loop,
1133+
socket_factory=self._socket_factory,
11251134
)
1126-
for sockopt in self._tcp_sockopts:
1127-
sock.setsockopt(*sockopt)
11281135
connection = await self._loop.create_connection(
11291136
*args, **kwargs, sock=sock
11301137
)
@@ -1256,13 +1263,13 @@ async def _start_tls_connection(
12561263

12571264
def _convert_hosts_to_addr_infos(
12581265
self, hosts: List[ResolveResult]
1259-
) -> List[aiohappyeyeballs.AddrInfoType]:
1266+
) -> List[AddrInfoType]:
12601267
"""Converts the list of hosts to a list of addr_infos.
12611268
12621269
The list of hosts is the result of a DNS lookup. The list of
12631270
addr_infos is the result of a call to `socket.getaddrinfo()`.
12641271
"""
1265-
addr_infos: List[aiohappyeyeballs.AddrInfoType] = []
1272+
addr_infos: List[AddrInfoType] = []
12661273
for hinfo in hosts:
12671274
host = hinfo["host"]
12681275
is_ipv6 = ":" in host

docs/client_advanced.rst

+15-8
Original file line numberDiff line numberDiff line change
@@ -468,19 +468,26 @@ If your HTTP server uses UNIX domain sockets you can use
468468
session = aiohttp.ClientSession(connector=conn)
469469

470470

471-
Setting socket options
471+
Custom socket creation
472472
^^^^^^^^^^^^^^^^^^^^^^
473473

474-
Socket options passed to the :class:`~aiohttp.TCPConnector` will be passed
475-
to the underlying socket when creating a connection. For example, we may
476-
want to change the conditions under which we consider a connection dead.
477-
The following would change that to 9*7200 = 18 hours::
474+
If the default socket is insufficient for your use case, pass an optional
475+
`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements
476+
`SocketFactoryType`. This will be used to create all sockets for the
477+
lifetime of the class object. For example, we may want to change the
478+
conditions under which we consider a connection dead. The following would
479+
make all sockets respect 9*7200 = 18 hours::
478480

479481
import socket
480482

481-
conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
482-
(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200),
483-
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) ])
483+
def socket_factory(addr_info):
484+
family, type_, proto, _, _, _ = addr_info
485+
sock = socket.socket(family=family, type=type_, proto=proto)
486+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
487+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200)
488+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9)
489+
return sock
490+
conn = aiohttp.TCPConnector(socket_factory=socket_factory)
484491

485492

486493
Named pipes in Windows

docs/client_reference.rst

+14-4
Original file line numberDiff line numberDiff line change
@@ -1122,14 +1122,24 @@ is controlled by *force_close* constructor's parameter).
11221122
overridden in subclasses.
11231123

11241124

1125+
.. autodata:: AddrInfoType
1126+
1127+
Refer to :py:data:`aiohappyeyeballs.AddrInfoType`
1128+
1129+
1130+
.. autodata:: SocketFactoryType
1131+
1132+
Refer to :py:data:`aiohappyeyeballs.SocketFactoryType`
1133+
1134+
11251135
.. class:: TCPConnector(*, ssl=True, verify_ssl=True, fingerprint=None, \
11261136
use_dns_cache=True, ttl_dns_cache=10, \
11271137
family=0, ssl_context=None, local_addr=None, \
11281138
resolver=None, keepalive_timeout=sentinel, \
11291139
force_close=False, limit=100, limit_per_host=0, \
11301140
enable_cleanup_closed=False, timeout_ceil_threshold=5, \
11311141
happy_eyeballs_delay=0.25, interleave=None, loop=None, \
1132-
tcp_sockopts=[])
1142+
socket_factory=None)
11331143

11341144
Connector for working with *HTTP* and *HTTPS* via *TCP* sockets.
11351145

@@ -1250,9 +1260,9 @@ is controlled by *force_close* constructor's parameter).
12501260

12511261
.. versionadded:: 3.10
12521262

1253-
:param list tcp_sockopts: options applied to the socket when a connection is
1254-
created. This should be a list of 3-tuples, each a ``(level, optname, value)``.
1255-
Each tuple is deconstructed and passed verbatim to ``<socket>.setsockopt``.
1263+
:param :py:data:``SocketFactoryType`` socket_factory: This function takes an
1264+
:py:data:``AddrInfoType`` and is used in lieu of ``socket.socket()`` when
1265+
creating TCP connections.
12561266

12571267
.. versionadded:: 3.12
12581268

docs/conf.py

+3
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
# ones.
5454
extensions = [
5555
# stdlib-party extensions:
56+
"sphinx.ext.autodoc",
5657
"sphinx.ext.extlinks",
5758
"sphinx.ext.graphviz",
5859
"sphinx.ext.intersphinx",
@@ -82,6 +83,7 @@
8283
"aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None),
8384
"aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None),
8485
"aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None),
86+
"aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None),
8587
}
8688

8789
# Add any paths that contain templates here, relative to this directory.
@@ -425,6 +427,7 @@
425427
("py:class", "cgi.FieldStorage"), # undocumented
426428
("py:meth", "aiohttp.web.UrlDispatcher.register_resource"), # undocumented
427429
("py:func", "aiohttp_debugtoolbar.setup"), # undocumented
430+
("py:class", "socket.SocketKind"), # undocumented
428431
]
429432

430433
# -- Options for towncrier_draft extension -----------------------------------

requirements/runtime-deps.in

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Extracted from `setup.cfg` via `make sync-direct-runtime-deps`
22

33
aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin"
4-
aiohappyeyeballs >= 2.3.0
4+
aiohappyeyeballs >= 2.5.0
55
aiosignal >= 1.1.2
66
async-timeout >= 4.0, < 6.0 ; python_version < "3.11"
77
Brotli; platform_python_implementation == 'CPython'

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ zip_safe = False
5151
include_package_data = True
5252

5353
install_requires =
54-
aiohappyeyeballs >= 2.3.0
54+
aiohappyeyeballs >= 2.5.0
5555
aiosignal >= 1.1.2
5656
async-timeout >= 4.0, < 6.0 ; python_version < "3.11"
5757
frozenlist >= 1.1.1

tests/test_connector.py

+38-17
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from unittest import mock
2727

2828
import pytest
29-
from aiohappyeyeballs import AddrInfoType
3029
from pytest_mock import MockerFixture
3130
from yarl import URL
3231

@@ -44,6 +43,7 @@
4443
from aiohttp.connector import (
4544
_SSL_CONTEXT_UNVERIFIED,
4645
_SSL_CONTEXT_VERIFIED,
46+
AddrInfoType,
4747
Connection,
4848
TCPConnector,
4949
_DNSCacheTable,
@@ -3767,27 +3767,48 @@ def test_connect() -> Literal[True]:
37673767
assert raw_response_list == [True, True]
37683768

37693769

3770-
async def test_tcp_connector_setsockopts(
3770+
async def test_tcp_connector_socket_factory(
37713771
loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock
37723772
) -> None:
3773-
"""Check that sockopts get passed to socket"""
3774-
conn = aiohttp.TCPConnector(
3775-
tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)]
3776-
)
3777-
3778-
with mock.patch.object(
3779-
conn._loop, "create_connection", autospec=True, spec_set=True
3780-
) as create_connection:
3781-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
3782-
start_connection.return_value = s
3783-
create_connection.return_value = mock.Mock(), mock.Mock()
3773+
"""Check that socket factory is called"""
3774+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
3775+
start_connection.return_value = s
37843776

3785-
req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)
3777+
local_addr = None
3778+
socket_factory: Callable[[AddrInfoType], socket.socket] = lambda _: s
3779+
happy_eyeballs_delay = 0.123
3780+
interleave = 3
3781+
conn = aiohttp.TCPConnector(
3782+
interleave=interleave,
3783+
local_addr=local_addr,
3784+
happy_eyeballs_delay=happy_eyeballs_delay,
3785+
socket_factory=socket_factory,
3786+
)
37863787

3788+
with mock.patch.object(
3789+
conn._loop,
3790+
"create_connection",
3791+
autospec=True,
3792+
spec_set=True,
3793+
return_value=(mock.Mock(), mock.Mock()),
3794+
):
3795+
host = "127.0.0.1"
3796+
port = 443
3797+
req = ClientRequest("GET", URL(f"https://{host}:{port}"), loop=loop)
37873798
with closing(await conn.connect(req, [], ClientTimeout())):
3788-
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2
3789-
3790-
await conn.close()
3799+
pass
3800+
await conn.close()
3801+
3802+
start_connection.assert_called_with(
3803+
addr_infos=[
3804+
(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port))
3805+
],
3806+
local_addr_infos=local_addr,
3807+
happy_eyeballs_delay=happy_eyeballs_delay,
3808+
interleave=interleave,
3809+
loop=loop,
3810+
socket_factory=socket_factory,
3811+
)
37913812

37923813

37933814
def test_default_ssl_context_creation_without_ssl() -> None:

0 commit comments

Comments
 (0)