Skip to content

Commit a29b5ae

Browse files
author
Bhargav Dodla
committed
fix: Using rate limiter
1 parent c74d4f3 commit a29b5ae

File tree

1 file changed

+95
-9
lines changed

1 file changed

+95
-9
lines changed

sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py

Lines changed: 95 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
"""
2020

2121
import logging
22+
import threading
23+
import time
2224
from datetime import datetime
25+
from functools import partial
2326
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
2427

2528
from cassandra.auth import PlainTextAuthProvider
@@ -194,6 +197,50 @@ class CassandraLoadBalancingPolicy(FeastConfigBaseModel):
194197
"""
195198

196199

200+
class RateLimiter:
    """Thread-safe token-bucket rate limiter usable as a context manager.

    The bucket starts full with ``rate`` tokens and regains one token every
    ``interval / rate`` seconds, never exceeding ``rate``. Tokens are
    replenished lazily from a monotonic clock each time the limiter is used,
    so no background thread is started (and none is leaked) per instance.
    """

    def __init__(self, rate: int, interval: float):
        """
        Initialize the rate limiter.

        :param rate: Maximum number of requests allowed per interval.
        :param interval: Time interval in seconds (e.g., 1 second for per-second limiting).
        :raises ValueError: If ``rate`` or ``interval`` is not positive.
        """
        if rate <= 0:
            raise ValueError(f"rate must be a positive number, got {rate!r}")
        if interval <= 0:
            raise ValueError(f"interval must be a positive number, got {interval!r}")
        self.rate = rate
        self.tokens = rate
        self.interval = interval
        self.lock = threading.Lock()
        # Seconds needed to earn a single token.
        self._step = interval / rate
        # Monotonic timestamp of the most recent refill tick that was applied.
        self._last_refill = time.monotonic()

    def _refill(self) -> None:
        """Credit every whole token earned since the last tick.

        Must be called with ``self.lock`` held.
        """
        elapsed = time.monotonic() - self._last_refill
        earned = int(elapsed / self._step)
        if earned <= 0:
            return
        if self.tokens < self.rate:
            self.tokens = min(self.rate, self.tokens + earned)
        # Advance by whole steps only, so the fractional remainder keeps
        # counting towards the next token and the tick phase is preserved.
        self._last_refill += earned * self._step

    def _seconds_until_next_token(self) -> float:
        """Return how long to wait (>= 0) until the next refill tick is due."""
        with self.lock:
            remaining = self._last_refill + self._step - time.monotonic()
        return max(remaining, 0.0)

    def allow(self) -> bool:
        """Check if a request can proceed, consuming one token if so."""
        with self.lock:
            self._refill()
            if self.tokens > 0:
                self.tokens -= 1
                return True
            return False

    def __enter__(self):
        """Context manager entry: Wait until a token is available."""
        while not self.allow():
            # Sleep exactly until the next token is due instead of polling.
            time.sleep(self._seconds_until_next_token())
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: No cleanup needed."""
        pass
242+
243+
197244
class CassandraOnlineStore(OnlineStore):
198245
"""
199246
Cassandra/Astra DB online store implementation for Feast.
@@ -351,7 +398,26 @@ def online_write_batch(
351398
rows is written to the online store. Can be used to
352399
display progress.
353400
"""
401+
402+
def on_success(result, semaphore, futures, future, lock):
    """Success callback for an asynchronous batch write.

    Releases one permit on ``semaphore`` and removes ``future`` from the
    shared ``futures`` list while holding ``lock``. ``result`` is accepted
    because the callback receives it positionally, but it is not used.
    """
    semaphore.release()
    with lock:
        futures.remove(future)


def on_failure(exc, semaphore, futures, future, lock):
    """Error callback for an asynchronous batch write.

    Performs the same bookkeeping as the success callback (release one
    permit, remove ``future`` from ``futures`` under ``lock``), logs the
    failure together with the traceback of ``exc``, and raises.

    :raises Exception: Always, chained to ``exc``.
    """
    semaphore.release()
    with lock:
        futures.remove(future)
    # This runs outside an ``except`` block, so ``logger.exception`` would
    # have no active exception to report; attach ``exc`` explicitly instead.
    logger.error("Error writing a batch: %s", exc, exc_info=exc)
    # NOTE(review): this raise surfaces in whichever thread invokes the
    # callback, which may not be the caller that submitted the write --
    # confirm that write failures actually reach that caller.
    raise Exception("Error writing a batch") from exc
414+
354415
write_concurrency = config.online_store.write_concurrency
416+
417+
rate_limiter = RateLimiter(rate=write_concurrency, interval=1)
418+
semaphore = threading.Semaphore(write_concurrency)
419+
lock = threading.Lock()
420+
355421
project = config.project
356422
ttl = (
357423
table.online_store_key_ttl_seconds
@@ -385,23 +451,43 @@ def online_write_batch(
385451
timestamp,
386452
)
387453
batch.add(insert_cql, params)
388-
futures.append(session.execute_async(batch))
454+
with rate_limiter:
455+
semaphore.acquire()
456+
future = session.execute_async(batch)
457+
futures.append(future)
458+
future.add_callbacks(
459+
partial(
460+
on_success,
461+
semaphore=semaphore,
462+
futures=futures,
463+
future=future,
464+
lock=lock,
465+
),
466+
partial(
467+
on_failure,
468+
semaphore=semaphore,
469+
futures=futures,
470+
future=future,
471+
lock=lock,
472+
),
473+
)
389474

390475
# TODO: Make this efficient by leveraging continuous writes rather
391476
# than blocking until all writes are done. We may need to rate limit
392477
# the writes to reduce the impact on read performance.
393-
if len(futures) >= write_concurrency:
394-
# Raises exception if at least one of the batch fails
395-
self._wait_for_futures(futures)
396-
futures.clear()
478+
# if len(futures) >= write_concurrency:
479+
# # Raises exception if at least one of the batch fails
480+
# self._wait_for_futures(futures)
481+
# futures.clear()
397482

398483
# this happens N-1 times, will be corrected outside:
399484
if progress:
400485
progress(1)
401-
402-
if len(futures) > 0:
403-
self._wait_for_futures(futures)
404-
futures.clear()
486+
while futures:
487+
time.sleep(0.001)
488+
# if len(futures) > 0:
489+
# self._wait_for_futures(futures)
490+
# futures.clear()
405491

406492
# correction for the last missing call to `progress`:
407493
if progress:

0 commit comments

Comments
 (0)