
Commit f396de0

Upgrade pytest to 8.3.5 (#646)
* Put non-forked tests last
* Upgrade pytest dependency
* Remove timeouts from pytest
* Replace Python 3.8 with 3.12 in setup.py
* Remove the image from the app
* Add timeouts to process cleanup in test_cli_run_server_identity_path
* Clean up DHT and averager processes in test_adaptive_compression
1 parent 9a76360 · commit f396de0

9 files changed: +112 −107 lines

modal_ci.py (+1 −5)

@@ -51,7 +51,7 @@
 )


-app = modal.App("hivemind-ci", image=image)
+app = modal.App("hivemind-ci")

 codecov_secret = modal.Secret.from_dict(
     {
@@ -99,9 +99,6 @@ def run_tests():
             "-v",
             "-n",
             "8",
-            "--dist",
-            "worksteal",
-            "--timeout=60",
             "tests",
         ],
         check=True,
@@ -120,7 +117,6 @@ def run_codecov():
             "hivemind",
             "--cov-config=pyproject.toml",
             "-v",
-            "--timeout=60",
             "tests",
         ],
         check=True,
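
Note: dropping "--dist worksteal" lets pytest-xdist fall back to its default "--dist load" scheduling for "-n 8", and dropping "--timeout=60" removes the global per-test limit that the pytest-timeout plugin enforced (per the commit message, cleanup timeouts move into the tests themselves). A rough local equivalent of the trimmed invocation, sketched with only the arguments visible in this hunk (any earlier arguments in the real list are omitted here):

    # Sketch only: mirrors the visible CI arguments after this change.
    import subprocess

    subprocess.run(
        ["pytest", "-v", "-n", "8", "tests"],
        check=True,
    )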

requirements-dev.txt (+1 −1)

@@ -1,4 +1,4 @@
-pytest==6.2.5 # see https://github.com/pytest-dev/pytest/issues/9621
+pytest==8.3.5
 pytest-forked
 pytest-asyncio==0.16.0
 pytest-cov

setup.py (+1 −1)

@@ -183,10 +183,10 @@ def run(self):
         "Intended Audience :: Science/Research",
         "License :: OSI Approved :: MIT License",
         "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
         "Topic :: Scientific/Engineering",
         "Topic :: Scientific/Engineering :: Mathematics",
         "Topic :: Scientific/Engineering :: Artificial Intelligence",

tests/test_compression.py (+7 −2)

@@ -221,6 +221,8 @@ def test_adaptive_compression():
         greater_equal=STATE_FP16,
     )

+    dht_instances = launch_dht_instances(2)
+
     averager1 = hivemind.TrainingAverager(
         opt=torch.optim.Adam(make_params()),
         average_parameters=True,
@@ -232,7 +234,7 @@ def test_adaptive_compression():
         target_group_size=2,
         part_size_bytes=5_000,
         start=True,
-        dht=hivemind.DHT(start=True),
+        dht=dht_instances[0],
     )

     averager2 = hivemind.TrainingAverager(
@@ -246,7 +248,7 @@ def test_adaptive_compression():
         target_group_size=2,
         part_size_bytes=5_000,
         start=True,
-        dht=hivemind.DHT(initial_peers=averager1.dht.get_visible_maddrs(), start=True),
+        dht=dht_instances[1],
     )

     futures = [averager1.step(wait=False), averager2.step(wait=False)]
@@ -266,3 +268,6 @@ def test_adaptive_compression():
     assert STATE_FP16.mp_counter.value == len([tensor for tensor in state_tensors if tensor.numel() >= 500])
     assert STATE_FP32.mp_counter.value == len([tensor for tensor in state_tensors if tensor.numel() < 500])
     assert STATE_FP16.mp_part_size.value == STATE_FP32.mp_part_size.value == 0  # not partitioned
+
+    for instance in [averager1, averager2] + dht_instances:
+        instance.shutdown()
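
Note: the explicit shutdown loop exists so the DHT and averager child processes do not outlive the test now that the pytest-level timeout is gone ("Clean up DHT and averager processes in test_adaptive_compression" in the commit message). A slightly more defensive variant, shown only as a sketch and not as what the commit does, would put the shutdowns in a finally block so they also run when an assertion fails:

    # Sketch: same cleanup as above, but guaranteed even on assertion failure.
    # launch_dht_instances comes from hivemind's test utilities; the averager
    # construction and assertions are elided behind hypothetical helpers.
    dht_instances = launch_dht_instances(2)
    averager1, averager2 = make_averagers(dht_instances)  # hypothetical helper
    try:
        run_compression_assertions(averager1, averager2)  # hypothetical helper
    finally:
        for instance in [averager1, averager2] + dht_instances:
            instance.shutdown()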

tests/test_dht.py (+14 −14)

@@ -12,20 +12,6 @@
 from test_utils.networking import get_free_port


-@pytest.mark.asyncio
-async def test_startup_error():
-    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
-        hivemind.DHT(
-            initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
-            start=True,
-        )
-
-    dht = hivemind.DHT(start=True, await_ready=False)
-    with pytest.raises(concurrent.futures.TimeoutError):
-        dht.wait_until_ready(timeout=0.01)
-    dht.shutdown()
-
-
 @pytest.mark.forked
 def test_get_store(n_peers=10):
     peers = launch_dht_instances(n_peers)
@@ -122,3 +108,17 @@ async def test_dht_get_visible_maddrs():

     assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     dht.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_startup_error():
+    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
+        hivemind.DHT(
+            initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
+            start=True,
+        )
+
+    dht = hivemind.DHT(start=True, await_ready=False)
+    with pytest.raises(concurrent.futures.TimeoutError):
+        dht.wait_until_ready(timeout=0.01)
+    dht.shutdown()
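
Note: this move (and the analogous moves in the test modules below) implements the "Put non-forked tests last" item from the commit message: tests without @pytest.mark.forked now sit after the forked ones in each file. For reference, the distinction the ordering relies on looks roughly like this (a minimal sketch; the forked marker comes from the pytest-forked plugin listed in requirements-dev.txt):

    import pytest

    @pytest.mark.forked
    def test_isolated():
        # pytest-forked runs this test body in a forked child process,
        # so any global state it creates dies with the child.
        ...

    def test_in_main_process():
        # No marker: this runs in the main pytest worker process and is
        # the kind of test the commit moves to the end of each module.
        ...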

tests/test_dht_crypto.py (+18 −18)

@@ -88,24 +88,6 @@ def get_signed_record(conn: mp.connection.Connection) -> DHTRecord:
     return record


-def test_signing_in_different_process():
-    parent_conn, child_conn = mp.Pipe()
-    process = mp.Process(target=get_signed_record, args=[child_conn])
-    process.start()
-
-    validator = RSASignatureValidator()
-    parent_conn.send(validator)
-
-    record = DHTRecord(
-        key=b"key", subkey=b"subkey" + validator.local_public_key, value=b"value", expiration_time=get_dht_time() + 10
-    )
-    parent_conn.send(record)
-
-    signed_record = parent_conn.recv()
-    assert b"[signature:" in signed_record.value
-    assert validator.validate(signed_record)
-
-
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_signatures():
@@ -134,3 +116,21 @@ async def test_dhtnode_signatures():
     store_ok = await mallory.store(key, b"updated_fake_value", hivemind.get_dht_time() + 10, subkey=subkey)
     assert not store_ok
     assert (await alice.get(key, latest=True)).value[subkey].value == b"updated_true_value"
+
+
+def test_signing_in_different_process():
+    parent_conn, child_conn = mp.Pipe()
+    process = mp.Process(target=get_signed_record, args=[child_conn])
+    process.start()
+
+    validator = RSASignatureValidator()
+    parent_conn.send(validator)
+
+    record = DHTRecord(
+        key=b"key", subkey=b"subkey" + validator.local_public_key, value=b"value", expiration_time=get_dht_time() + 10
+    )
+    parent_conn.send(record)
+
+    signed_record = parent_conn.recv()
+    assert b"[signature:" in signed_record.value
+    assert validator.validate(signed_record)

tests/test_dht_experts.py (+31 −31)

@@ -126,6 +126,37 @@ def test_dht_single_node():
     node.shutdown()


+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_negative_caching(n_peers=10):
+    dht_kwargs = {"cache_locally": False}
+
+    peers = [hivemind.DHT(start=True, **dht_kwargs)]
+    initial_peers = peers[0].get_visible_maddrs()
+    peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
+
+    writer_peer = random.choice(peers)
+    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], get_dht_time() + 30).values())
+
+    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
+    neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
+    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix="ffn.", grid_size=(10, 10, 10), negative_caching=True)
+    # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
+    assert len(beam_search.get_initial_beam(scores=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], beam_size=3)) == 2
+
+    node = await DHTNode.create(initial_peers=neighbors)
+    fetched = await asyncio.gather(*(node.get(f"ffn.{i}.") for i in range(10)))
+    for i in range(6):
+        assert fetched[i] is not None, f"node should have cached ffn.{i}."
+    for i in range(6, len(fetched)):
+        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
+
+    await node.shutdown()
+    neg_caching_peer.shutdown()
+    for peer in peers:
+        peer.shutdown()
+
+
 def test_uid_patterns():
     valid_experts = [
         "expert.1",
@@ -188,34 +219,3 @@ def test_uid_patterns():
         assert not is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
     for pfx in invalid_prefixes:
         assert not is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_negative_caching(n_peers=10):
-    dht_kwargs = {"cache_locally": False}
-
-    peers = [hivemind.DHT(start=True, **dht_kwargs)]
-    initial_peers = peers[0].get_visible_maddrs()
-    peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
-
-    writer_peer = random.choice(peers)
-    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], get_dht_time() + 30).values())
-
-    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
-    neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
-    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix="ffn.", grid_size=(10, 10, 10), negative_caching=True)
-    # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
-    assert len(beam_search.get_initial_beam(scores=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], beam_size=3)) == 2
-
-    node = await DHTNode.create(initial_peers=neighbors)
-    fetched = await asyncio.gather(*(node.get(f"ffn.{i}.") for i in range(10)))
-    for i in range(6):
-        assert fetched[i] is not None, f"node should have cached ffn.{i}."
-    for i in range(6, len(fetched)):
-        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
-
-    await node.shutdown()
-    neg_caching_peer.shutdown()
-    for peer in peers:
-        peer.shutdown()

tests/test_dht_validation.py (+28 −28)

@@ -29,6 +29,34 @@ def validators_for_app():
     }


+@pytest.mark.forked
+def test_dht_add_validators(validators_for_app):
+    # One app may create a DHT with its validators
+    dht = hivemind.DHT(start=False, record_validators=validators_for_app["A"])
+
+    # While the DHT process is not started, you can't send a command to append new validators
+    with pytest.raises(RuntimeError):
+        dht.add_validators(validators_for_app["B"])
+    dht.run_in_background(await_ready=True)
+
+    # After starting the process, other apps may add new validators to the existing DHT
+    dht.add_validators(validators_for_app["B"])
+
+    assert dht.store("field_a", b"bytes_value", hivemind.get_dht_time() + 10)
+    assert dht.get("field_a", latest=True).value == b"bytes_value"
+
+    assert not dht.store("field_a", 666, hivemind.get_dht_time() + 10)
+    assert dht.get("field_a", latest=True).value == b"bytes_value"
+
+    local_public_key = validators_for_app["A"][0].local_public_key
+    assert dht.store("field_b", 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
+    dictionary = dht.get("field_b", latest=True).value
+    assert len(dictionary) == 1 and dictionary[local_public_key].value == 777
+
+    assert not dht.store("unknown_key", 666, hivemind.get_dht_time() + 10)
+    assert dht.get("unknown_key", latest=True) is None
+
+
 def test_composite_validator(validators_for_app):
     validator = CompositeValidator(validators_for_app["A"])
     assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator]
@@ -63,31 +91,3 @@ def test_composite_validator(validators_for_app):
     assert signed_record.value.count(b"[signature:") == 0
     # Expect failed validation since `unknown_key` is not a part of any schema
     assert not validator.validate(signed_record)
-
-
-@pytest.mark.forked
-def test_dht_add_validators(validators_for_app):
-    # One app may create a DHT with its validators
-    dht = hivemind.DHT(start=False, record_validators=validators_for_app["A"])
-
-    # While the DHT process is not started, you can't send a command to append new validators
-    with pytest.raises(RuntimeError):
-        dht.add_validators(validators_for_app["B"])
-    dht.run_in_background(await_ready=True)
-
-    # After starting the process, other apps may add new validators to the existing DHT
-    dht.add_validators(validators_for_app["B"])
-
-    assert dht.store("field_a", b"bytes_value", hivemind.get_dht_time() + 10)
-    assert dht.get("field_a", latest=True).value == b"bytes_value"
-
-    assert not dht.store("field_a", 666, hivemind.get_dht_time() + 10)
-    assert dht.get("field_a", latest=True).value == b"bytes_value"
-
-    local_public_key = validators_for_app["A"][0].local_public_key
-    assert dht.store("field_b", 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
-    dictionary = dht.get("field_b", latest=True).value
-    assert len(dictionary) == 1 and dictionary[local_public_key].value == 777
-
-    assert not dht.store("unknown_key", 666, hivemind.get_dht_time() + 10)
-    assert dht.get("unknown_key", latest=True) is None

tests/test_start_server.py (+11 −7)

@@ -9,6 +9,15 @@
 from hivemind.moe.server import background_server


+def cleanup_process(process, timeout=5):
+    try:
+        process.terminate()
+        process.wait(timeout=timeout)  # Add timeout to wait
+    except:
+        process.kill()
+        process.wait(timeout=timeout)
+
+
 @pytest.mark.xfail(reason="Flaky test", strict=False)
 def test_background_server_identity_path():
     with TemporaryDirectory() as tempdir:
@@ -96,10 +105,5 @@ def test_cli_run_server_identity_path():
     assert addrs_1 != addrs_3
     assert addrs_2 != addrs_3

-    server_1_proc.terminate()
-    server_2_proc.terminate()
-    server_3_proc.terminate()
-
-    server_1_proc.wait()
-    server_2_proc.wait()
-    server_3_proc.wait()
+    for p in [server_1_proc, server_2_proc, server_3_proc]:
+        cleanup_process(p)
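
Note: cleanup_process bounds the wait after terminate() and falls back to kill() if the process does not exit in time; this is where the "Add timeouts to process cleanup in test_cli_run_server_identity_path" item lands. A hedged usage sketch outside the test (the command and arguments are illustrative, not taken from this diff):

    # Sketch: applying cleanup_process to a subprocess.Popen handle.
    import subprocess

    proc = subprocess.Popen(["hivemind-server", "--num_experts", "1"])  # illustrative command
    try:
        pass  # ... interact with the server ...
    finally:
        cleanup_process(proc)  # terminate, then kill if still alive after 5 s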
