@@ -126,6 +126,37 @@ def test_dht_single_node():
126
126
node .shutdown ()
127
127
128
128
129
+ @pytest .mark .forked
130
+ @pytest .mark .asyncio
131
+ async def test_negative_caching (n_peers = 10 ):
132
+ dht_kwargs = {"cache_locally" : False }
133
+
134
+ peers = [hivemind .DHT (start = True , ** dht_kwargs )]
135
+ initial_peers = peers [0 ].get_visible_maddrs ()
136
+ peers += [hivemind .DHT (initial_peers = initial_peers , start = True , ** dht_kwargs ) for _ in range (n_peers - 1 )]
137
+
138
+ writer_peer = random .choice (peers )
139
+ assert all (declare_experts (writer_peer , ["ffn.1.2.3" , "ffn.3.4.5" ], get_dht_time () + 30 ).values ())
140
+
141
+ neighbors = sum ([peer .get_visible_maddrs () for peer in random .sample (peers , min (3 , len (peers )))], [])
142
+ neg_caching_peer = hivemind .DHT (initial_peers = neighbors , start = True , ** dht_kwargs )
143
+ beam_search = MoEBeamSearcher (neg_caching_peer , uid_prefix = "ffn." , grid_size = (10 , 10 , 10 ), negative_caching = True )
144
+ # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
145
+ assert len (beam_search .get_initial_beam (scores = [0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 ], beam_size = 3 )) == 2
146
+
147
+ node = await DHTNode .create (initial_peers = neighbors )
148
+ fetched = await asyncio .gather (* (node .get (f"ffn.{ i } ." ) for i in range (10 )))
149
+ for i in range (6 ):
150
+ assert fetched [i ] is not None , f"node should have cached ffn.{ i } ."
151
+ for i in range (6 , len (fetched )):
152
+ assert fetched [i ] is None , f"node shouldn't have cached ffn.{ i } ."
153
+
154
+ await node .shutdown ()
155
+ neg_caching_peer .shutdown ()
156
+ for peer in peers :
157
+ peer .shutdown ()
158
+
159
+
129
160
def test_uid_patterns ():
130
161
valid_experts = [
131
162
"expert.1" ,
@@ -188,34 +219,3 @@ def test_uid_patterns():
188
219
assert not is_valid_uid (uid ), f"UID { uid } is not valid, but was perceived as valid"
189
220
for pfx in invalid_prefixes :
190
221
assert not is_valid_prefix (pfx ), f"Prefix { pfx } is not valid, but was perceived as valid"
191
-
192
-
193
- @pytest .mark .forked
194
- @pytest .mark .asyncio
195
- async def test_negative_caching (n_peers = 10 ):
196
- dht_kwargs = {"cache_locally" : False }
197
-
198
- peers = [hivemind .DHT (start = True , ** dht_kwargs )]
199
- initial_peers = peers [0 ].get_visible_maddrs ()
200
- peers += [hivemind .DHT (initial_peers = initial_peers , start = True , ** dht_kwargs ) for _ in range (n_peers - 1 )]
201
-
202
- writer_peer = random .choice (peers )
203
- assert all (declare_experts (writer_peer , ["ffn.1.2.3" , "ffn.3.4.5" ], get_dht_time () + 30 ).values ())
204
-
205
- neighbors = sum ([peer .get_visible_maddrs () for peer in random .sample (peers , min (3 , len (peers )))], [])
206
- neg_caching_peer = hivemind .DHT (initial_peers = neighbors , start = True , ** dht_kwargs )
207
- beam_search = MoEBeamSearcher (neg_caching_peer , uid_prefix = "ffn." , grid_size = (10 , 10 , 10 ), negative_caching = True )
208
- # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
209
- assert len (beam_search .get_initial_beam (scores = [0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 ], beam_size = 3 )) == 2
210
-
211
- node = await DHTNode .create (initial_peers = neighbors )
212
- fetched = await asyncio .gather (* (node .get (f"ffn.{ i } ." ) for i in range (10 )))
213
- for i in range (6 ):
214
- assert fetched [i ] is not None , f"node should have cached ffn.{ i } ."
215
- for i in range (6 , len (fetched )):
216
- assert fetched [i ] is None , f"node shouldn't have cached ffn.{ i } ."
217
-
218
- await node .shutdown ()
219
- neg_caching_peer .shutdown ()
220
- for peer in peers :
221
- peer .shutdown ()
0 commit comments