Skip to content

Commit 240fb9e

Browse files
bdracoTimMenninger
andauthored
[PR #10534/3b9bb1cd backport][3.12] Replace tcp_sockopts with socket_factory (#10574)
replaces and closes #10565 Instead of TCPConnector taking a list of sockopts to be applied sockets created, take a socket_factory callback that allows the caller to implement socket creation entirely. Fixes #10520 <!-- Thank you for your contribution! --> Replace `tcp_sockopts` parameter with a `socket_factory` parameter that is a callback allowing the caller to own socket creation. If passed, all sockets created by `TCPConnector` are expected to come from the `socket_factory` callback. <!-- Please give a short brief about these changes. --> The only users to experience a change in behavior are those who are using the un-released `tcp_sockopts` argument to `TCPConnector`. However, using unreleased code comes with caveat emptor, and is why I felt entitled to remove the option entirely without warning. <!-- Outline any notable behaviour for the end users. --> The burden will be minimal and would only arise if `aiohappyeyeballs` changes their interface. <!-- Stop right there! Pause. Just for a minute... Can you think of anything obvious that would complicate the ongoing development of this project? Try to consider if you'd be able to maintain it throughout the next 5 years. Does it seem viable? Tell us your thoughts! We'd very much love to hear what the consequences of merging this patch might be... This will help us assess if your change is something we'd want to entertain early in the review process. Thank you in advance! --> <!-- Are there any issues opened that will be resolved by merging this change? --> <!-- Remember to prefix with 'Fixes' if it should close the issue (e.g. 'Fixes #123'). --> - [x] I think the code is well written - [x] Unit tests for the changes exist - [x] Documentation reflects the changes - [x] If you provide code modification, please add yourself to `CONTRIBUTORS.txt` * The format is &lt;Name&gt; &lt;Surname&gt;. * Please keep alphabetical order, the file is sorted by names. - [x] Add a new news fragment into the `CHANGES/` folder * name it `<issue_or_pr_num>.<type>.rst` (e.g. `588.bugfix.rst`) * if you don't have an issue number, change it to the pull request number after creating the PR * `.bugfix`: A bug fix for something the maintainers deemed an improper undesired behavior that got corrected to match pre-agreed expectations. * `.feature`: A new behavior, public APIs. That sort of stuff. * `.deprecation`: A declaration of future API removals and breaking changes in behavior. * `.breaking`: When something public is removed in a breaking way. Could be deprecated in an earlier release. * `.doc`: Notable updates to the documentation structure or build process. * `.packaging`: Notes for downstreams about unobvious side effects and tooling. Changes in the test invocation considerations and runtime assumptions. * `.contrib`: Stuff that affects the contributor experience. e.g. Running tests, building the docs, setting up the development environment. * `.misc`: Changes that are hard to assign to any of the above categories. * Make sure to use full sentences with correct case and punctuation, for example: ```rst Fixed issue with non-ascii contents in doctest text files -- by :user:`contributor-gh-handle`. ``` Use the past tense or the present tense a non-imperative mood, referring to what's changed compared to the last released version of this project. --------- Co-authored-by: J. Nick Koston <[email protected]> (cherry picked from commit 3b9bb1c) <!-- Thank you for your contribution! --> ## What do these changes do? <!-- Please give a short brief about these changes. --> ## Are there changes in behavior for the user? <!-- Outline any notable behaviour for the end users. --> ## Is it a substantial burden for the maintainers to support this? <!-- Stop right there! Pause. Just for a minute... Can you think of anything obvious that would complicate the ongoing development of this project? Try to consider if you'd be able to maintain it throughout the next 5 years. Does it seem viable? Tell us your thoughts! We'd very much love to hear what the consequences of merging this patch might be... This will help us assess if your change is something we'd want to entertain early in the review process. Thank you in advance! --> ## Related issue number <!-- Are there any issues opened that will be resolved by merging this change? --> <!-- Remember to prefix with 'Fixes' if it should close the issue (e.g. 'Fixes #123'). --> ## Checklist - [ ] I think the code is well written - [ ] Unit tests for the changes exist - [ ] Documentation reflects the changes - [ ] If you provide code modification, please add yourself to `CONTRIBUTORS.txt` * The format is &lt;Name&gt; &lt;Surname&gt;. * Please keep alphabetical order, the file is sorted by names. - [ ] Add a new news fragment into the `CHANGES/` folder * name it `<issue_or_pr_num>.<type>.rst` (e.g. `588.bugfix.rst`) * if you don't have an issue number, change it to the pull request number after creating the PR * `.bugfix`: A bug fix for something the maintainers deemed an improper undesired behavior that got corrected to match pre-agreed expectations. * `.feature`: A new behavior, public APIs. That sort of stuff. * `.deprecation`: A declaration of future API removals and breaking changes in behavior. * `.breaking`: When something public is removed in a breaking way. Could be deprecated in an earlier release. * `.doc`: Notable updates to the documentation structure or build process. * `.packaging`: Notes for downstreams about unobvious side effects and tooling. Changes in the test invocation considerations and runtime assumptions. * `.contrib`: Stuff that affects the contributor experience. e.g. Running tests, building the docs, setting up the development environment. * `.misc`: Changes that are hard to assign to any of the above categories. * Make sure to use full sentences with correct case and punctuation, for example: ```rst Fixed issue with non-ascii contents in doctest text files -- by :user:`contributor-gh-handle`. ``` Use the past tense or the present tense a non-imperative mood, referring to what's changed compared to the last released version of this project. Co-authored-by: Tim Menninger <[email protected]>
1 parent 8337927 commit 240fb9e

10 files changed

+128
-46
lines changed

CHANGES/10474.feature.rst

-2
This file was deleted.

CHANGES/10520.feature.rst

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Added ``socket_factory`` to :py:class:`aiohttp.TCPConnector` to allow specifying custom socket options
2+
-- by :user:`TimMenninger`.

aiohttp/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@
4747
WSServerHandshakeError,
4848
request,
4949
)
50+
from .connector import (
51+
AddrInfoType as AddrInfoType,
52+
SocketFactoryType as SocketFactoryType,
53+
)
5054
from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar
5155
from .formdata import FormData as FormData
5256
from .helpers import BasicAuth, ChainMapProxy, ETag
@@ -126,6 +130,7 @@
126130
__all__: Tuple[str, ...] = (
127131
"hdrs",
128132
# client
133+
"AddrInfoType",
129134
"BaseConnector",
130135
"ClientConnectionError",
131136
"ClientConnectionResetError",
@@ -161,6 +166,7 @@
161166
"ServerDisconnectedError",
162167
"ServerFingerprintMismatch",
163168
"ServerTimeoutError",
169+
"SocketFactoryType",
164170
"SocketTimeoutError",
165171
"TCPConnector",
166172
"TooManyRedirects",

aiohttp/connector.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
DefaultDict,
2020
Deque,
2121
Dict,
22-
Iterable,
2322
Iterator,
2423
List,
2524
Literal,
@@ -33,6 +32,7 @@
3332
)
3433

3534
import aiohappyeyeballs
35+
from aiohappyeyeballs import AddrInfoType, SocketFactoryType
3636

3737
from . import hdrs, helpers
3838
from .abc import AbstractResolver, ResolveResult
@@ -95,7 +95,14 @@
9595
# which first appeared in Python 3.12.7 and 3.13.1
9696

9797

98-
__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")
98+
__all__ = (
99+
"BaseConnector",
100+
"TCPConnector",
101+
"UnixConnector",
102+
"NamedPipeConnector",
103+
"AddrInfoType",
104+
"SocketFactoryType",
105+
)
99106

100107

101108
if TYPE_CHECKING:
@@ -834,8 +841,9 @@ class TCPConnector(BaseConnector):
834841
the happy eyeballs algorithm, set to None.
835842
interleave - “First Address Family Count” as defined in RFC 8305
836843
loop - Optional event loop.
837-
tcp_sockopts - List of tuples of sockopts applied to underlying
838-
socket
844+
socket_factory - A SocketFactoryType function that, if supplied,
845+
will be used to create sockets given an
846+
AddrInfoType.
839847
"""
840848

841849
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
@@ -861,7 +869,7 @@ def __init__(
861869
timeout_ceil_threshold: float = 5,
862870
happy_eyeballs_delay: Optional[float] = 0.25,
863871
interleave: Optional[int] = None,
864-
tcp_sockopts: Iterable[Tuple[int, int, Union[int, Buffer]]] = [],
872+
socket_factory: Optional[SocketFactoryType] = None,
865873
):
866874
super().__init__(
867875
keepalive_timeout=keepalive_timeout,
@@ -888,7 +896,7 @@ def __init__(
888896
self._happy_eyeballs_delay = happy_eyeballs_delay
889897
self._interleave = interleave
890898
self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
891-
self._tcp_sockopts = tcp_sockopts
899+
self._socket_factory = socket_factory
892900

893901
def close(self) -> Awaitable[None]:
894902
"""Close all ongoing DNS calls."""
@@ -1112,7 +1120,7 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
11121120
async def _wrap_create_connection(
11131121
self,
11141122
*args: Any,
1115-
addr_infos: List[aiohappyeyeballs.AddrInfoType],
1123+
addr_infos: List[AddrInfoType],
11161124
req: ClientRequest,
11171125
timeout: "ClientTimeout",
11181126
client_error: Type[Exception] = ClientConnectorError,
@@ -1129,9 +1137,8 @@ async def _wrap_create_connection(
11291137
happy_eyeballs_delay=self._happy_eyeballs_delay,
11301138
interleave=self._interleave,
11311139
loop=self._loop,
1140+
socket_factory=self._socket_factory,
11321141
)
1133-
for sockopt in self._tcp_sockopts:
1134-
sock.setsockopt(*sockopt)
11351142
connection = await self._loop.create_connection(
11361143
*args, **kwargs, sock=sock
11371144
)
@@ -1331,13 +1338,13 @@ async def _start_tls_connection(
13311338

13321339
def _convert_hosts_to_addr_infos(
13331340
self, hosts: List[ResolveResult]
1334-
) -> List[aiohappyeyeballs.AddrInfoType]:
1341+
) -> List[AddrInfoType]:
13351342
"""Converts the list of hosts to a list of addr_infos.
13361343
13371344
The list of hosts is the result of a DNS lookup. The list of
13381345
addr_infos is the result of a call to `socket.getaddrinfo()`.
13391346
"""
1340-
addr_infos: List[aiohappyeyeballs.AddrInfoType] = []
1347+
addr_infos: List[AddrInfoType] = []
13411348
for hinfo in hosts:
13421349
host = hinfo["host"]
13431350
is_ipv6 = ":" in host

docs/client_advanced.rst

+15-8
Original file line numberDiff line numberDiff line change
@@ -461,19 +461,26 @@ If your HTTP server uses UNIX domain sockets you can use
461461
session = aiohttp.ClientSession(connector=conn)
462462

463463

464-
Setting socket options
464+
Custom socket creation
465465
^^^^^^^^^^^^^^^^^^^^^^
466466

467-
Socket options passed to the :class:`~aiohttp.TCPConnector` will be passed
468-
to the underlying socket when creating a connection. For example, we may
469-
want to change the conditions under which we consider a connection dead.
470-
The following would change that to 9*7200 = 18 hours::
467+
If the default socket is insufficient for your use case, pass an optional
468+
`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements
469+
`SocketFactoryType`. This will be used to create all sockets for the
470+
lifetime of the class object. For example, we may want to change the
471+
conditions under which we consider a connection dead. The following would
472+
make all sockets respect 9*7200 = 18 hours::
471473

472474
import socket
473475

474-
conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
475-
(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200),
476-
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) ])
476+
def socket_factory(addr_info):
477+
family, type_, proto, _, _, _ = addr_info
478+
sock = socket.socket(family=family, type=type_, proto=proto)
479+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
480+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200)
481+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9)
482+
return sock
483+
conn = aiohttp.TCPConnector(socket_factory=socket_factory)
477484

478485

479486
Named pipes in Windows

docs/client_reference.rst

+32-4
Original file line numberDiff line numberDiff line change
@@ -1138,14 +1138,42 @@ is controlled by *force_close* constructor's parameter).
11381138
overridden in subclasses.
11391139

11401140

1141+
.. autodata:: AddrInfoType
1142+
1143+
.. note::
1144+
1145+
Refer to :py:data:`aiohappyeyeballs.AddrInfoType` for more info.
1146+
1147+
.. warning::
1148+
1149+
Be sure to use ``aiohttp.AddrInfoType`` rather than
1150+
``aiohappyeyeballs.AddrInfoType`` to avoid import breakage, as
1151+
it is likely to be removed from ``aiohappyeyeballs`` in the
1152+
future.
1153+
1154+
1155+
.. autodata:: SocketFactoryType
1156+
1157+
.. note::
1158+
1159+
Refer to :py:data:`aiohappyeyeballs.SocketFactoryType` for more info.
1160+
1161+
.. warning::
1162+
1163+
Be sure to use ``aiohttp.SocketFactoryType`` rather than
1164+
``aiohappyeyeballs.SocketFactoryType`` to avoid import breakage,
1165+
as it is likely to be removed from ``aiohappyeyeballs`` in the
1166+
future.
1167+
1168+
11411169
.. class:: TCPConnector(*, ssl=True, verify_ssl=True, fingerprint=None, \
11421170
use_dns_cache=True, ttl_dns_cache=10, \
11431171
family=0, ssl_context=None, local_addr=None, \
11441172
resolver=None, keepalive_timeout=sentinel, \
11451173
force_close=False, limit=100, limit_per_host=0, \
11461174
enable_cleanup_closed=False, timeout_ceil_threshold=5, \
11471175
happy_eyeballs_delay=0.25, interleave=None, loop=None, \
1148-
tcp_sockopts=[])
1176+
socket_factory=None)
11491177

11501178
Connector for working with *HTTP* and *HTTPS* via *TCP* sockets.
11511179

@@ -1266,9 +1294,9 @@ is controlled by *force_close* constructor's parameter).
12661294

12671295
.. versionadded:: 3.10
12681296

1269-
:param list tcp_sockopts: options applied to the socket when a connection is
1270-
created. This should be a list of 3-tuples, each a ``(level, optname, value)``.
1271-
Each tuple is deconstructed and passed verbatim to ``<socket>.setsockopt``.
1297+
:param :py:data:``SocketFactoryType`` socket_factory: This function takes an
1298+
:py:data:``AddrInfoType`` and is used in lieu of ``socket.socket()`` when
1299+
creating TCP connections.
12721300

12731301
.. versionadded:: 3.12
12741302

docs/conf.py

+3
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
# ones.
5555
extensions = [
5656
# stdlib-party extensions:
57+
"sphinx.ext.autodoc",
5758
"sphinx.ext.extlinks",
5859
"sphinx.ext.graphviz",
5960
"sphinx.ext.intersphinx",
@@ -83,6 +84,7 @@
8384
"aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None),
8485
"aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None),
8586
"aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None),
87+
"aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None),
8688
}
8789

8890
# Add any paths that contain templates here, relative to this directory.
@@ -441,6 +443,7 @@
441443
("py:exc", "HTTPMethodNotAllowed"), # undocumented
442444
("py:class", "HTTPMethodNotAllowed"), # undocumented
443445
("py:class", "HTTPUnavailableForLegalReasons"), # undocumented
446+
("py:class", "socket.SocketKind"), # undocumented
444447
]
445448

446449
# -- Options for towncrier_draft extension -----------------------------------

requirements/runtime-deps.in

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Extracted from `setup.cfg` via `make sync-direct-runtime-deps`
22

33
aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin"
4-
aiohappyeyeballs >= 2.3.0
4+
aiohappyeyeballs >= 2.5.0
55
aiosignal >= 1.1.2
66
async-timeout >= 4.0, < 6.0 ; python_version < "3.11"
77
attrs >= 17.3.0

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ zip_safe = False
5151
include_package_data = True
5252

5353
install_requires =
54-
aiohappyeyeballs >= 2.3.0
54+
aiohappyeyeballs >= 2.5.0
5555
aiosignal >= 1.1.2
5656
async-timeout >= 4.0, < 6.0 ; python_version < "3.11"
5757
attrs >= 17.3.0

tests/test_connector.py

+50-19
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,20 @@
1010
from collections import defaultdict, deque
1111
from concurrent import futures
1212
from contextlib import closing, suppress
13-
from typing import Any, DefaultDict, Deque, List, Literal, Optional, Sequence, Tuple
13+
from typing import (
14+
Any,
15+
Callable,
16+
DefaultDict,
17+
Deque,
18+
List,
19+
Literal,
20+
Optional,
21+
Sequence,
22+
Tuple,
23+
)
1424
from unittest import mock
1525

1626
import pytest
17-
from aiohappyeyeballs import AddrInfoType
1827
from pytest_mock import MockerFixture
1928
from yarl import URL
2029

@@ -26,6 +35,7 @@
2635
from aiohttp.connector import (
2736
_SSL_CONTEXT_UNVERIFIED,
2837
_SSL_CONTEXT_VERIFIED,
38+
AddrInfoType,
2939
Connection,
3040
TCPConnector,
3141
_DNSCacheTable,
@@ -3663,27 +3673,48 @@ def test_connect() -> Literal[True]:
36633673
assert raw_response_list == [True, True]
36643674

36653675

3666-
async def test_tcp_connector_setsockopts(
3676+
async def test_tcp_connector_socket_factory(
36673677
loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock
36683678
) -> None:
3669-
"""Check that sockopts get passed to socket"""
3670-
conn = aiohttp.TCPConnector(
3671-
tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)]
3672-
)
3673-
3674-
with mock.patch.object(
3675-
conn._loop, "create_connection", autospec=True, spec_set=True
3676-
) as create_connection:
3677-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
3678-
start_connection.return_value = s
3679-
create_connection.return_value = mock.Mock(), mock.Mock()
3680-
3681-
req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)
3679+
"""Check that socket factory is called"""
3680+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
3681+
start_connection.return_value = s
3682+
3683+
local_addr = None
3684+
socket_factory: Callable[[AddrInfoType], socket.socket] = lambda _: s
3685+
happy_eyeballs_delay = 0.123
3686+
interleave = 3
3687+
conn = aiohttp.TCPConnector(
3688+
interleave=interleave,
3689+
local_addr=local_addr,
3690+
happy_eyeballs_delay=happy_eyeballs_delay,
3691+
socket_factory=socket_factory,
3692+
)
36823693

3694+
with mock.patch.object(
3695+
conn._loop,
3696+
"create_connection",
3697+
autospec=True,
3698+
spec_set=True,
3699+
return_value=(mock.Mock(), mock.Mock()),
3700+
):
3701+
host = "127.0.0.1"
3702+
port = 443
3703+
req = ClientRequest("GET", URL(f"https://{host}:{port}"), loop=loop)
36833704
with closing(await conn.connect(req, [], ClientTimeout())):
3684-
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2
3685-
3686-
await conn.close()
3705+
pass
3706+
await conn.close()
3707+
3708+
start_connection.assert_called_with(
3709+
addr_infos=[
3710+
(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port))
3711+
],
3712+
local_addr_infos=local_addr,
3713+
happy_eyeballs_delay=happy_eyeballs_delay,
3714+
interleave=interleave,
3715+
loop=loop,
3716+
socket_factory=socket_factory,
3717+
)
36873718

36883719

36893720
def test_default_ssl_context_creation_without_ssl() -> None:

0 commit comments

Comments
 (0)