fix: Using rate limiter
Bhargav Dodla committed Feb 5, 2025
1 parent c74d4f3 commit a29b5ae
Showing 1 changed file with 95 additions and 9 deletions.
@@ -19,7 +19,10 @@
"""

import logging
import threading
import time
from datetime import datetime
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple

from cassandra.auth import PlainTextAuthProvider
@@ -194,6 +197,50 @@ class CassandraLoadBalancingPolicy(FeastConfigBaseModel):
"""


class RateLimiter:
    def __init__(self, rate: int, interval: float):
        """
        Initialize the rate limiter.
        :param rate: Maximum number of requests allowed per interval.
        :param interval: Time interval in seconds (e.g., 1 second for per-second limiting).
        """
        self.rate = rate
        self.tokens = rate
        self.interval = interval
        self.lock = threading.Lock()
        self._start_refill_thread()

    def _start_refill_thread(self):
        """Refills tokens periodically in a background thread."""

        def refill():
            while True:
                time.sleep(self.interval / self.rate)
                with self.lock:
                    if self.tokens < self.rate:
                        self.tokens += 1

        threading.Thread(target=refill, daemon=True).start()

    def allow(self) -> bool:
        """Check if a request can proceed."""
        with self.lock:
            if self.tokens > 0:
                self.tokens -= 1
                return True
            return False

    def __enter__(self):
        """Context manager entry: Wait until a token is available."""
        while not self.allow():
            time.sleep(0.01)  # Small delay to prevent busy-waiting
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: No cleanup needed."""
        pass
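A minimal usage sketch of the RateLimiter above: it assumes a limit of five acquisitions per second, and the loop length and timing printout are illustrative choices rather than part of the commit. It relies only on the time import already present in this module.

if __name__ == "__main__":
    limiter = RateLimiter(rate=5, interval=1.0)
    start = time.monotonic()
    for i in range(8):
        with limiter:  # blocks until a token is available
            print(f"request {i} acquired at t={time.monotonic() - start:.2f}s")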


class CassandraOnlineStore(OnlineStore):
"""
Cassandra/Astra DB online store implementation for Feast.
@@ -351,7 +398,26 @@ def online_write_batch(
            rows is written to the online store. Can be used to
            display progress.
        """

        def on_success(result, semaphore, futures, future, lock):
            semaphore.release()
            with lock:
                futures.remove(future)

        def on_failure(exc, semaphore, futures, future, lock):
            semaphore.release()
            with lock:
                futures.remove(future)
            logger.exception(f"Error writing a batch: {exc}")
            print(f"Error writing a batch: {exc}")
            raise Exception("Error writing a batch") from exc

        write_concurrency = config.online_store.write_concurrency

        rate_limiter = RateLimiter(rate=write_concurrency, interval=1)
        semaphore = threading.Semaphore(write_concurrency)
        lock = threading.Lock()

        project = config.project
        ttl = (
            table.online_store_key_ttl_seconds
@@ -385,23 +451,43 @@ def online_write_batch(
                    timestamp,
                )
                batch.add(insert_cql, params)
            futures.append(session.execute_async(batch))
            with rate_limiter:
                semaphore.acquire()
                future = session.execute_async(batch)
                futures.append(future)
                future.add_callbacks(
                    partial(
                        on_success,
                        semaphore=semaphore,
                        futures=futures,
                        future=future,
                        lock=lock,
                    ),
                    partial(
                        on_failure,
                        semaphore=semaphore,
                        futures=futures,
                        future=future,
                        lock=lock,
                    ),
                )

            # TODO: Make this efficient by leveraging continuous writes rather
            # than blocking until all writes are done. We may need to rate limit
            # the writes to reduce the impact on read performance.
            if len(futures) >= write_concurrency:
                # Raises exception if at least one of the batch fails
                self._wait_for_futures(futures)
                futures.clear()
            # if len(futures) >= write_concurrency:
            #     # Raises exception if at least one of the batch fails
            #     self._wait_for_futures(futures)
            #     futures.clear()

            # this happens N-1 times, will be corrected outside:
            if progress:
                progress(1)

        if len(futures) > 0:
            self._wait_for_futures(futures)
            futures.clear()
        while futures:
            time.sleep(0.001)
        # if len(futures) > 0:
        #     self._wait_for_futures(futures)
        #     futures.clear()

        # correction for the last missing call to `progress`:
        if progress:
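To make the control flow of the diff easier to follow in isolation, here is a self-contained sketch of the throttling pattern this commit introduces: a token bucket (the RateLimiter class from the diff, assumed to be in scope) caps the submission rate, a semaphore bounds the number of in-flight writes, and completion callbacks release the semaphore and remove the finished future from the tracking list. MockFuture, MockSession, write_concurrency = 4, and the simulated latencies are hypothetical stand-ins for the cassandra-driver session and the real configuration.

import random
import threading
import time
from functools import partial


class MockFuture:
    """Hypothetical stand-in for a cassandra-driver ResponseFuture."""

    def __init__(self):
        self._state_lock = threading.Lock()
        self._done = False
        self._callback = None

    def add_callbacks(self, callback, errback):
        # errback is accepted but unused: the mock writes never fail.
        with self._state_lock:
            self._callback = callback
            already_done = self._done
        if already_done:
            callback(None)

    def _succeed(self):
        with self._state_lock:
            self._done = True
            callback = self._callback
        if callback is not None:
            callback(None)


class MockSession:
    """Hypothetical session whose writes complete on background threads."""

    def execute_async(self, batch):
        future = MockFuture()

        def finish():
            time.sleep(random.uniform(0.01, 0.05))  # simulated write latency
            future._succeed()

        threading.Thread(target=finish, daemon=True).start()
        return future


def on_success(result, semaphore, futures, future, lock):
    semaphore.release()
    with lock:
        futures.remove(future)


write_concurrency = 4
rate_limiter = RateLimiter(rate=write_concurrency, interval=1)
semaphore = threading.Semaphore(write_concurrency)
lock = threading.Lock()
futures = []
session = MockSession()

for batch in range(10):
    with rate_limiter:       # wait for a token: caps submissions per second
        semaphore.acquire()  # wait for a slot: caps in-flight writes
        future = session.execute_async(batch)
        with lock:
            futures.append(future)
        callback = partial(
            on_success, semaphore=semaphore, futures=futures, future=future, lock=lock
        )
        future.add_callbacks(callback, callback)  # same handler; mock never fails

while futures:  # drain: the callbacks empty the list as writes complete
    time.sleep(0.001)
print("all mock writes completed")

The drain loop at the end mirrors the `while futures: time.sleep(0.001)` wait in the commit: the completion callbacks, not the submitting thread, are responsible for emptying the tracking list.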
