Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 63b51ef

Browse files
authored
Support IPv6-only SMTP servers (#16155)
Use Twisted HostnameEndpoint to connect to SMTP servers (instead of connectTCP/connectSSL) which properly supports IPv6-only servers.
1 parent 2d72367 commit 63b51ef

File tree

5 files changed

+125
-29
lines changed

5 files changed

+125
-29
lines changed

Diff for: changelog.d/16155.bugfix

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix IPv6-related bugs on SMTP settings, adding groundwork to fix similar issues. Contributed by @evilham and @telmich (ungleich.ch).

Diff for: synapse/handlers/send_email.py

+11-17
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323

2424
import twisted
2525
from twisted.internet.defer import Deferred
26-
from twisted.internet.interfaces import IOpenSSLContextFactory
26+
from twisted.internet.endpoints import HostnameEndpoint
27+
from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory
2728
from twisted.internet.ssl import optionsForClientTLS
2829
from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory
30+
from twisted.protocols.tls import TLSMemoryBIOFactory
2931

3032
from synapse.logging.context import make_deferred_yieldable
3133
from synapse.types import ISynapseReactor
@@ -97,6 +99,7 @@ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory:
9799
**kwargs,
98100
)
99101

102+
factory: IProtocolFactory
100103
if _is_old_twisted:
101104
# before twisted 21.2, we have to override the ESMTPSender protocol to disable
102105
# TLS
@@ -110,22 +113,13 @@ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory:
110113
factory = build_sender_factory(hostname=smtphost if enable_tls else None)
111114

112115
if force_tls:
113-
reactor.connectSSL(
114-
smtphost,
115-
smtpport,
116-
factory,
117-
optionsForClientTLS(smtphost),
118-
timeout=30,
119-
bindAddress=None,
120-
)
121-
else:
122-
reactor.connectTCP(
123-
smtphost,
124-
smtpport,
125-
factory,
126-
timeout=30,
127-
bindAddress=None,
128-
)
116+
factory = TLSMemoryBIOFactory(optionsForClientTLS(smtphost), True, factory)
117+
118+
endpoint = HostnameEndpoint(
119+
reactor, smtphost, smtpport, timeout=30, bindAddress=None
120+
)
121+
122+
await make_deferred_yieldable(endpoint.connect(factory))
129123

130124
await make_deferred_yieldable(d)
131125

Diff for: tests/handlers/test_send_email.py

+59-10
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,40 @@
1313
# limitations under the License.
1414

1515

16-
from typing import Callable, List, Tuple
16+
from typing import Callable, List, Tuple, Type, Union
17+
from unittest.mock import patch
1718

1819
from zope.interface import implementer
1920

2021
from twisted.internet import defer
21-
from twisted.internet.address import IPv4Address
22+
from twisted.internet._sslverify import ClientTLSOptions
23+
from twisted.internet.address import IPv4Address, IPv6Address
2224
from twisted.internet.defer import ensureDeferred
25+
from twisted.internet.interfaces import IProtocolFactory
26+
from twisted.internet.ssl import ContextFactory
2327
from twisted.mail import interfaces, smtp
2428

2529
from tests.server import FakeTransport
2630
from tests.unittest import HomeserverTestCase, override_config
2731

2832

33+
def TestingESMTPTLSClientFactory(
34+
contextFactory: ContextFactory,
35+
_connectWrapped: bool,
36+
wrappedProtocol: IProtocolFactory,
37+
) -> IProtocolFactory:
38+
"""We use this to pass through in testing without using TLS, but
39+
saving the context information to check that it would have happened.
40+
41+
Note that this is what the MemoryReactor does on connectSSL.
42+
It only saves the contextFactory, but starts the connection with the
43+
underlying Factory.
44+
See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
45+
46+
wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined]
47+
return wrappedProtocol
48+
49+
2950
@implementer(interfaces.IMessageDelivery)
3051
class _DummyMessageDelivery:
3152
def __init__(self) -> None:
@@ -75,7 +96,13 @@ def connectionLost(self) -> None:
7596
pass
7697

7798

78-
class SendEmailHandlerTestCase(HomeserverTestCase):
99+
class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
100+
ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
101+
102+
def setUp(self) -> None:
103+
super().setUp()
104+
self.reactor.lookups["localhost"] = "127.0.0.1"
105+
79106
def test_send_email(self) -> None:
80107
"""Happy-path test that we can send email to a non-TLS server."""
81108
h = self.hs.get_send_email_handler()
@@ -89,7 +116,7 @@ def test_send_email(self) -> None:
89116
(host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
90117
0
91118
]
92-
self.assertEqual(host, "localhost")
119+
self.assertEqual(host, self.reactor.lookups["localhost"])
93120
self.assertEqual(port, 25)
94121

95122
# wire it up to an SMTP server
@@ -105,7 +132,9 @@ def test_send_email(self) -> None:
105132
FakeTransport(
106133
client_protocol,
107134
self.reactor,
108-
peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
135+
peer_address=self.ip_class(
136+
"TCP", self.reactor.lookups["localhost"], 1234
137+
),
109138
)
110139
)
111140

@@ -118,6 +147,10 @@ def test_send_email(self) -> None:
118147
self.assertEqual(str(user), "[email protected]")
119148
self.assertIn(b"Subject: test subject", msg)
120149

150+
@patch(
151+
"synapse.handlers.send_email.TLSMemoryBIOFactory",
152+
TestingESMTPTLSClientFactory,
153+
)
121154
@override_config(
122155
{
123156
"email": {
@@ -135,17 +168,23 @@ def test_send_email_force_tls(self) -> None:
135168
)
136169
)
137170
# there should be an attempt to connect to localhost:465
138-
self.assertEqual(len(self.reactor.sslClients), 1)
171+
self.assertEqual(len(self.reactor.tcpClients), 1)
139172
(
140173
host,
141174
port,
142175
client_factory,
143-
contextFactory,
144176
_timeout,
145177
_bindAddress,
146-
) = self.reactor.sslClients[0]
147-
self.assertEqual(host, "localhost")
178+
) = self.reactor.tcpClients[0]
179+
self.assertEqual(host, self.reactor.lookups["localhost"])
148180
self.assertEqual(port, 465)
181+
# We need to make sure that TLS is happenning
182+
self.assertIsInstance(
183+
client_factory._wrappedFactory._testingContextFactory,
184+
ClientTLSOptions,
185+
)
186+
# And since we use endpoints, they go through reactor.connectTCP
187+
# which works differently to connectSSL on the testing reactor
149188

150189
# wire it up to an SMTP server
151190
message_delivery = _DummyMessageDelivery()
@@ -160,7 +199,9 @@ def test_send_email_force_tls(self) -> None:
160199
FakeTransport(
161200
client_protocol,
162201
self.reactor,
163-
peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
202+
peer_address=self.ip_class(
203+
"TCP", self.reactor.lookups["localhost"], 1234
204+
),
164205
)
165206
)
166207

@@ -172,3 +213,11 @@ def test_send_email_force_tls(self) -> None:
172213
user, msg = message_delivery.messages.pop()
173214
self.assertEqual(str(user), "[email protected]")
174215
self.assertIn(b"Subject: test subject", msg)
216+
217+
218+
class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
219+
ip_class = IPv6Address
220+
221+
def setUp(self) -> None:
222+
super().setUp()
223+
self.reactor.lookups["localhost"] = "::1"

Diff for: tests/server.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import hashlib
15+
import ipaddress
1516
import json
1617
import logging
1718
import os
@@ -45,7 +46,7 @@
4546
from typing_extensions import ParamSpec
4647
from zope.interface import implementer
4748

48-
from twisted.internet import address, threads, udp
49+
from twisted.internet import address, tcp, threads, udp
4950
from twisted.internet._resolver import SimpleResolverComplexifier
5051
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
5152
from twisted.internet.error import DNSLookupError
@@ -567,6 +568,8 @@ def connectTCP(
567568
conn = super().connectTCP(
568569
host, port, factory, timeout=timeout, bindAddress=None
569570
)
571+
if self.lookups and host in self.lookups:
572+
validate_connector(conn, self.lookups[host])
570573

571574
callback = self._tcp_callbacks.get((host, port))
572575
if callback:
@@ -599,6 +602,55 @@ def advance(self, amount: float) -> None:
599602
super().advance(0)
600603

601604

605+
def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
606+
"""Try to validate the obtained connector as it would happen when
607+
synapse is running and the conection will be established.
608+
609+
This method will raise a useful exception when necessary, else it will
610+
just do nothing.
611+
612+
This is in order to help catch quirks related to reactor.connectTCP,
613+
since when called directly, the connector's destination will be of type
614+
IPv4Address, with the hostname as the literal host that was given (which
615+
could be an IPv6-only host or an IPv6 literal).
616+
617+
But when called from reactor.connectTCP *through* e.g. an Endpoint, the
618+
connector's destination will contain the specific IP address with the
619+
correct network stack class.
620+
621+
Note that testing code paths that use connectTCP directly should not be
622+
affected by this check, unless they specifically add a test with a
623+
matching reactor.lookups[HOSTNAME] = "IPv6Literal", where reactor is of
624+
type ThreadedMemoryReactorClock.
625+
For an example of implementing such tests, see test/handlers/send_email.py.
626+
"""
627+
destination = connector.getDestination()
628+
629+
# We use address.IPv{4,6}Address to check what the reactor thinks it is
630+
# is sending but check for validity with ipaddress.IPv{4,6}Address
631+
# because they fail with IPs on the wrong network stack.
632+
cls_mapping = {
633+
address.IPv4Address: ipaddress.IPv4Address,
634+
address.IPv6Address: ipaddress.IPv6Address,
635+
}
636+
637+
cls = cls_mapping.get(destination.__class__)
638+
639+
if cls is not None:
640+
try:
641+
cls(expected_ip)
642+
except Exception as exc:
643+
raise ValueError(
644+
"Invalid IP type and resolution for %s. Expected %s to be %s"
645+
% (destination, expected_ip, cls.__name__)
646+
) from exc
647+
else:
648+
raise ValueError(
649+
"Unknown address type %s for %s"
650+
% (destination.__class__.__name__, destination)
651+
)
652+
653+
602654
class ThreadPool:
603655
"""
604656
Threadless thread pool.

Diff for: tests/unittest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ class HomeserverTestCase(TestCase):
313313
servlets: List of servlet registration function.
314314
user_id (str): The user ID to assume if auth is hijacked.
315315
hijack_auth: Whether to hijack auth to return the user specified
316-
in user_id.
316+
in user_id.
317317
"""
318318

319319
hijack_auth: ClassVar[bool] = True

0 commit comments

Comments
 (0)