|
19 | 19 | """
|
20 | 20 |
|
21 | 21 | import logging
|
| 22 | +import threading |
| 23 | +import time |
22 | 24 | from datetime import datetime
|
| 25 | +from functools import partial |
23 | 26 | from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
|
24 | 27 |
|
25 | 28 | from cassandra.auth import PlainTextAuthProvider
|
@@ -194,6 +197,50 @@ class CassandraLoadBalancingPolicy(FeastConfigBaseModel):
|
194 | 197 | """
|
195 | 198 |
|
196 | 199 |
|
class RateLimiter:
    """Simple token-bucket rate limiter refilled by a background daemon thread.

    Allows at most ``rate`` operations per ``interval`` seconds. Tokens are
    replenished one at a time (every ``interval / rate`` seconds), which
    smooths the permitted request rate rather than refilling in bursts.
    Can be used as a context manager to block until a token is available.
    """

    def __init__(self, rate: int, interval: float):
        """
        Initialize the rate limiter.

        :param rate: Maximum number of requests allowed per interval.
        :param interval: Time interval in seconds (e.g., 1 second for
            per-second limiting).
        """
        self.rate = rate
        # Bucket starts full, so an initial burst of up to `rate` is allowed.
        self.tokens = rate
        self.interval = interval
        # Guards `tokens` against concurrent access from the refill thread
        # and callers of allow().
        self.lock = threading.Lock()
        self._start_refill_thread()

    def _start_refill_thread(self):
        """Refill tokens periodically in a background daemon thread.

        The thread runs for the life of the process (daemon=True), adding
        one token every ``interval / rate`` seconds, capped at ``rate``.
        """
        # Hoist the invariant sleep period out of the loop.
        period = self.interval / self.rate

        def refill():
            while True:
                time.sleep(period)
                with self.lock:
                    if self.tokens < self.rate:
                        self.tokens += 1

        threading.Thread(target=refill, daemon=True).start()

    def allow(self) -> bool:
        """Consume one token if available.

        :return: True if the request may proceed, False if rate-limited.
        """
        with self.lock:
            if self.tokens > 0:
                self.tokens -= 1
                return True
            return False

    def __enter__(self):
        """Context manager entry: wait until a token is available."""
        while not self.allow():
            time.sleep(0.01)  # Small delay to prevent busy-waiting
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: no cleanup needed (tokens are not returned)."""
        pass
| 243 | + |
197 | 244 | class CassandraOnlineStore(OnlineStore):
|
198 | 245 | """
|
199 | 246 | Cassandra/Astra DB online store implementation for Feast.
|
@@ -351,7 +398,26 @@ def online_write_batch(
|
351 | 398 | rows is written to the online store. Can be used to
|
352 | 399 | display progress.
|
353 | 400 | """
|
| 401 | + |
| 402 | + def on_success(result, semaphore, futures, future, lock): |
| 403 | + semaphore.release() |
| 404 | + with lock: |
| 405 | + futures.remove(future) |
| 406 | + |
| 407 | + def on_failure(exc, semaphore, futures, future, lock): |
| 408 | + semaphore.release() |
| 409 | + with lock: |
| 410 | + futures.remove(future) |
| 411 | + logger.exception(f"Error writing a batch: {exc}") |
| 412 | + print(f"Error writing a batch: {exc}") |
| 413 | + raise Exception("Error writing a batch") from exc |
| 414 | + |
354 | 415 | write_concurrency = config.online_store.write_concurrency
|
| 416 | + |
| 417 | + rate_limiter = RateLimiter(rate=write_concurrency, interval=1) |
| 418 | + semaphore = threading.Semaphore(write_concurrency) |
| 419 | + lock = threading.Lock() |
| 420 | + |
355 | 421 | project = config.project
|
356 | 422 | ttl = (
|
357 | 423 | table.online_store_key_ttl_seconds
|
@@ -385,23 +451,43 @@ def online_write_batch(
|
385 | 451 | timestamp,
|
386 | 452 | )
|
387 | 453 | batch.add(insert_cql, params)
|
388 |
| - futures.append(session.execute_async(batch)) |
| 454 | + with rate_limiter: |
| 455 | + semaphore.acquire() |
| 456 | + future = session.execute_async(batch) |
| 457 | + futures.append(future) |
| 458 | + future.add_callbacks( |
| 459 | + partial( |
| 460 | + on_success, |
| 461 | + semaphore=semaphore, |
| 462 | + futures=futures, |
| 463 | + future=future, |
| 464 | + lock=lock, |
| 465 | + ), |
| 466 | + partial( |
| 467 | + on_failure, |
| 468 | + semaphore=semaphore, |
| 469 | + futures=futures, |
| 470 | + future=future, |
| 471 | + lock=lock, |
| 472 | + ), |
| 473 | + ) |
389 | 474 |
|
390 | 475 | # TODO: Make this efficient by leveraging continuous writes rather
|
391 | 476 | # than blocking until all writes are done. We may need to rate limit
|
392 | 477 | # the writes to reduce the impact on read performance.
|
393 |
| - if len(futures) >= write_concurrency: |
394 |
| - # Raises exception if at least one of the batch fails |
395 |
| - self._wait_for_futures(futures) |
396 |
| - futures.clear() |
| 478 | + # if len(futures) >= write_concurrency: |
| 479 | + # # Raises exception if at least one of the batch fails |
| 480 | + # self._wait_for_futures(futures) |
| 481 | + # futures.clear() |
397 | 482 |
|
398 | 483 | # this happens N-1 times, will be corrected outside:
|
399 | 484 | if progress:
|
400 | 485 | progress(1)
|
401 |
| - |
402 |
| - if len(futures) > 0: |
403 |
| - self._wait_for_futures(futures) |
404 |
| - futures.clear() |
| 486 | + while futures: |
| 487 | + time.sleep(0.001) |
| 488 | + # if len(futures) > 0: |
| 489 | + # self._wait_for_futures(futures) |
| 490 | + # futures.clear() |
405 | 491 |
|
406 | 492 | # correction for the last missing call to `progress`:
|
407 | 493 | if progress:
|
|
0 commit comments