Skip to content

Commit aa86a3a

Browse files
committed
PYTHON-5306 - Fix use of public MongoClient attributes before connection
1 parent e6a4a71 commit aa86a3a

File tree

4 files changed

+134
-16
lines changed

4 files changed

+134
-16
lines changed

pymongo/asynchronous/mongo_client.py

+31-8
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
)
110110
from pymongo.read_preferences import ReadPreference, _ServerMode
111111
from pymongo.results import ClientBulkWriteResult
112+
from pymongo.server_description import ServerDescription
112113
from pymongo.server_selectors import writable_server_selector
113114
from pymongo.server_type import SERVER_TYPE
114115
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
@@ -779,7 +780,7 @@ def __init__(
779780
keyword_opts["document_class"] = doc_class
780781
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
781782

782-
seeds = set()
783+
self._seeds = set()
783784
is_srv = False
784785
username = None
785786
password = None
@@ -804,18 +805,18 @@ def __init__(
804805
srv_max_hosts=srv_max_hosts,
805806
)
806807
is_srv = entity.startswith(SRV_SCHEME)
807-
seeds.update(res["nodelist"])
808+
self._seeds.update(res["nodelist"])
808809
username = res["username"] or username
809810
password = res["password"] or password
810811
dbase = res["database"] or dbase
811812
opts = res["options"]
812813
fqdn = res["fqdn"]
813814
else:
814-
seeds.update(split_hosts(entity, self._port))
815-
if not seeds:
815+
self._seeds.update(split_hosts(entity, self._port))
816+
if not self._seeds:
816817
raise ConfigurationError("need to specify at least one host")
817818

818-
for hostname in [node[0] for node in seeds]:
819+
for hostname in [node[0] for node in self._seeds]:
819820
if _detect_external_db(hostname):
820821
break
821822

@@ -838,7 +839,7 @@ def __init__(
838839
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
839840

840841
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
841-
opts = self._normalize_and_validate_options(opts, seeds)
842+
opts = self._normalize_and_validate_options(opts, self._seeds)
842843

843844
# Username and password passed as kwargs override user info in URI.
844845
username = opts.get("username", username)
@@ -857,7 +858,7 @@ def __init__(
857858
"username": username,
858859
"password": password,
859860
"dbase": dbase,
860-
"seeds": seeds,
861+
"seeds": self._seeds,
861862
"fqdn": fqdn,
862863
"srv_service_name": srv_service_name,
863864
"pool_class": pool_class,
@@ -874,7 +875,7 @@ def __init__(
874875
)
875876

876877
if not is_srv:
877-
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
878+
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
878879

879880
self._opened = False
880881
self._closed = False
@@ -1205,6 +1206,18 @@ def topology_description(self) -> TopologyDescription:
12051206
12061207
.. versionadded:: 4.0
12071208
"""
1209+
if self._topology is None:
1210+
servers = {
1211+
(host, self._port): ServerDescription((host, self._port)) for host in self._seeds
1212+
}
1213+
return TopologyDescription(
1214+
TOPOLOGY_TYPE.Unknown,
1215+
servers,
1216+
None,
1217+
None,
1218+
None,
1219+
TopologySettings(),
1220+
)
12081221
return self._topology.description
12091222

12101223
@property
@@ -1218,6 +1231,8 @@ def nodes(self) -> FrozenSet[_Address]:
12181231
to any servers, or a network partition causes it to lose connection
12191232
to all servers.
12201233
"""
1234+
if self._topology is None:
1235+
return frozenset()
12211236
description = self._topology.description
12221237
return frozenset(s.address for s in description.known_servers)
12231238

@@ -1576,6 +1591,8 @@ async def address(self) -> Optional[tuple[str, int]]:
15761591
15771592
.. versionadded:: 3.0
15781593
"""
1594+
if self._topology is None:
1595+
await self._get_topology()
15791596
topology_type = self._topology._description.topology_type
15801597
if (
15811598
topology_type == TOPOLOGY_TYPE.Sharded
@@ -1598,6 +1615,8 @@ async def primary(self) -> Optional[tuple[str, int]]:
15981615
.. versionadded:: 3.0
15991616
AsyncMongoClient gained this property in version 3.0.
16001617
"""
1618+
if self._topology is None:
1619+
await self._get_topology()
16011620
return await self._topology.get_primary() # type: ignore[return-value]
16021621

16031622
@property
@@ -1611,6 +1630,8 @@ async def secondaries(self) -> set[_Address]:
16111630
.. versionadded:: 3.0
16121631
AsyncMongoClient gained this property in version 3.0.
16131632
"""
1633+
if self._topology is None:
1634+
await self._get_topology()
16141635
return await self._topology.get_secondaries()
16151636

16161637
@property
@@ -1621,6 +1642,8 @@ async def arbiters(self) -> set[_Address]:
16211642
connected to a replica set, there are no arbiters, or this client was
16221643
created without the `replicaSet` option.
16231644
"""
1645+
if self._topology is None:
1646+
await self._get_topology()
16241647
return await self._topology.get_arbiters()
16251648

16261649
@property

pymongo/synchronous/mongo_client.py

+31-8
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
)
102102
from pymongo.read_preferences import ReadPreference, _ServerMode
103103
from pymongo.results import ClientBulkWriteResult
104+
from pymongo.server_description import ServerDescription
104105
from pymongo.server_selectors import writable_server_selector
105106
from pymongo.server_type import SERVER_TYPE
106107
from pymongo.synchronous import client_session, database, uri_parser
@@ -777,7 +778,7 @@ def __init__(
777778
keyword_opts["document_class"] = doc_class
778779
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
779780

780-
seeds = set()
781+
self._seeds = set()
781782
is_srv = False
782783
username = None
783784
password = None
@@ -802,18 +803,18 @@ def __init__(
802803
srv_max_hosts=srv_max_hosts,
803804
)
804805
is_srv = entity.startswith(SRV_SCHEME)
805-
seeds.update(res["nodelist"])
806+
self._seeds.update(res["nodelist"])
806807
username = res["username"] or username
807808
password = res["password"] or password
808809
dbase = res["database"] or dbase
809810
opts = res["options"]
810811
fqdn = res["fqdn"]
811812
else:
812-
seeds.update(split_hosts(entity, self._port))
813-
if not seeds:
813+
self._seeds.update(split_hosts(entity, self._port))
814+
if not self._seeds:
814815
raise ConfigurationError("need to specify at least one host")
815816

816-
for hostname in [node[0] for node in seeds]:
817+
for hostname in [node[0] for node in self._seeds]:
817818
if _detect_external_db(hostname):
818819
break
819820

@@ -836,7 +837,7 @@ def __init__(
836837
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
837838

838839
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
839-
opts = self._normalize_and_validate_options(opts, seeds)
840+
opts = self._normalize_and_validate_options(opts, self._seeds)
840841

841842
# Username and password passed as kwargs override user info in URI.
842843
username = opts.get("username", username)
@@ -855,7 +856,7 @@ def __init__(
855856
"username": username,
856857
"password": password,
857858
"dbase": dbase,
858-
"seeds": seeds,
859+
"seeds": self._seeds,
859860
"fqdn": fqdn,
860861
"srv_service_name": srv_service_name,
861862
"pool_class": pool_class,
@@ -872,7 +873,7 @@ def __init__(
872873
)
873874

874875
if not is_srv:
875-
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
876+
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
876877

877878
self._opened = False
878879
self._closed = False
@@ -1203,6 +1204,18 @@ def topology_description(self) -> TopologyDescription:
12031204
12041205
.. versionadded:: 4.0
12051206
"""
1207+
if self._topology is None:
1208+
servers = {
1209+
(host, self._port): ServerDescription((host, self._port)) for host in self._seeds
1210+
}
1211+
return TopologyDescription(
1212+
TOPOLOGY_TYPE.Unknown,
1213+
servers,
1214+
None,
1215+
None,
1216+
None,
1217+
TopologySettings(),
1218+
)
12061219
return self._topology.description
12071220

12081221
@property
@@ -1216,6 +1229,8 @@ def nodes(self) -> FrozenSet[_Address]:
12161229
to any servers, or a network partition causes it to lose connection
12171230
to all servers.
12181231
"""
1232+
if self._topology is None:
1233+
return frozenset()
12191234
description = self._topology.description
12201235
return frozenset(s.address for s in description.known_servers)
12211236

@@ -1570,6 +1585,8 @@ def address(self) -> Optional[tuple[str, int]]:
15701585
15711586
.. versionadded:: 3.0
15721587
"""
1588+
if self._topology is None:
1589+
self._get_topology()
15731590
topology_type = self._topology._description.topology_type
15741591
if (
15751592
topology_type == TOPOLOGY_TYPE.Sharded
@@ -1592,6 +1609,8 @@ def primary(self) -> Optional[tuple[str, int]]:
15921609
.. versionadded:: 3.0
15931610
MongoClient gained this property in version 3.0.
15941611
"""
1612+
if self._topology is None:
1613+
self._get_topology()
15951614
return self._topology.get_primary() # type: ignore[return-value]
15961615

15971616
@property
@@ -1605,6 +1624,8 @@ def secondaries(self) -> set[_Address]:
16051624
.. versionadded:: 3.0
16061625
MongoClient gained this property in version 3.0.
16071626
"""
1627+
if self._topology is None:
1628+
self._get_topology()
16081629
return self._topology.get_secondaries()
16091630

16101631
@property
@@ -1615,6 +1636,8 @@ def arbiters(self) -> set[_Address]:
16151636
connected to a replica set, there are no arbiters, or this client was
16161637
created without the `replicaSet` option.
16171638
"""
1639+
if self._topology is None:
1640+
self._get_topology()
16181641
return self._topology.get_arbiters()
16191642

16201643
@property

test/asynchronous/test_client.py

+36
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,30 @@ async def test_constants(self):
816816
async def test_init_disconnected(self):
817817
host, port = await async_client_context.host, await async_client_context.port
818818
c = await self.async_rs_or_single_client(connect=False)
819+
# nodes returns an empty set if not connected
820+
self.assertEqual(c.nodes, frozenset())
821+
# topology_description returns the initial seed description if not connected
822+
topology_description = c.topology_description
823+
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
824+
self.assertEqual(
825+
topology_description.server_descriptions(),
826+
{(host, port): ServerDescription((host, port))},
827+
)
828+
# address causes client to block until connected
829+
self.assertIsNotNone(await c.address)
830+
c = await self.async_rs_or_single_client(connect=False)
831+
# primary causes client to block until connected
832+
await c.primary
833+
self.assertIsNotNone(c._topology)
834+
c = await self.async_rs_or_single_client(connect=False)
835+
# secondaries causes client to block until connected
836+
await c.secondaries
837+
self.assertIsNotNone(c._topology)
838+
c = await self.async_rs_or_single_client(connect=False)
839+
# arbiters causes client to block until connected
840+
await c.arbiters
841+
self.assertIsNotNone(c._topology)
842+
c = await self.async_rs_or_single_client(connect=False)
819843
# is_primary causes client to block until connected
820844
self.assertIsInstance(await c.is_primary, bool)
821845
c = await self.async_rs_or_single_client(connect=False)
@@ -2170,6 +2194,18 @@ async def test_uuid_queries(self):
21702194
self.assertEqual(2, len(docs))
21712195
await coll.drop()
21722196

2197+
async def test_unconnected_client_properties_with_srv(self):
2198+
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", connect=False)
2199+
self.assertEqual(client.nodes, frozenset())
2200+
topology_description = client.topology_description
2201+
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
2202+
self.assertEqual(
2203+
topology_description.server_descriptions(),
2204+
{("unknown", None): ServerDescription(("unknown", None))},
2205+
)
2206+
await client.aconnect()
2207+
self.assertEqual(await client.address, None)
2208+
21732209

21742210
class TestExhaustCursor(AsyncIntegrationTest):
21752211
"""Test that clients properly handle errors from exhaust cursors."""

test/test_client.py

+36
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,30 @@ def test_constants(self):
791791
def test_init_disconnected(self):
792792
host, port = client_context.host, client_context.port
793793
c = self.rs_or_single_client(connect=False)
794+
# nodes returns an empty set if not connected
795+
self.assertEqual(c.nodes, frozenset())
796+
# topology_description returns the initial seed description if not connected
797+
topology_description = c.topology_description
798+
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
799+
self.assertEqual(
800+
topology_description.server_descriptions(),
801+
{(host, port): ServerDescription((host, port))},
802+
)
803+
# address causes client to block until connected
804+
self.assertIsNotNone(c.address)
805+
c = self.rs_or_single_client(connect=False)
806+
# primary causes client to block until connected
807+
c.primary
808+
self.assertIsNotNone(c._topology)
809+
c = self.rs_or_single_client(connect=False)
810+
# secondaries causes client to block until connected
811+
c.secondaries
812+
self.assertIsNotNone(c._topology)
813+
c = self.rs_or_single_client(connect=False)
814+
# arbiters causes client to block until connected
815+
c.arbiters
816+
self.assertIsNotNone(c._topology)
817+
c = self.rs_or_single_client(connect=False)
794818
# is_primary causes client to block until connected
795819
self.assertIsInstance(c.is_primary, bool)
796820
c = self.rs_or_single_client(connect=False)
@@ -2127,6 +2151,18 @@ def test_uuid_queries(self):
21272151
self.assertEqual(2, len(docs))
21282152
coll.drop()
21292153

2154+
def test_unconnected_client_properties_with_srv(self):
2155+
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", connect=False)
2156+
self.assertEqual(client.nodes, frozenset())
2157+
topology_description = client.topology_description
2158+
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
2159+
self.assertEqual(
2160+
topology_description.server_descriptions(),
2161+
{("unknown", None): ServerDescription(("unknown", None))},
2162+
)
2163+
client._connect()
2164+
self.assertEqual(client.address, None)
2165+
21302166

21312167
class TestExhaustCursor(IntegrationTest):
21322168
"""Test that clients properly handle errors from exhaust cursors."""

0 commit comments

Comments
 (0)