Skip to content

Commit

Permalink
feat: Implement sliding window rate limiter and enhance online store …
Browse files Browse the repository at this point in the history
…tag retrieval
  • Loading branch information
Bhargav Dodla committed Feb 5, 2025
1 parent 6c42453 commit 09302d1
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 148 deletions.
24 changes: 11 additions & 13 deletions sdk/python/feast/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
join_keys=[DUMMY_ENTITY_ID],
)

ONLINE_STORE_TAG_SUFFIX = "online_store_"


@typechecked
class FeatureView(BaseFeatureView):
Expand Down Expand Up @@ -496,20 +498,16 @@ def most_recent_end_time(self) -> Optional[datetime]:
return max([interval[1] for interval in self.materialization_intervals])

@property
def online_store_key_ttl_seconds(self) -> Optional[int]:
def get_online_store_tags(self) -> Dict[str, str]:
"""
Retrieves the online store TTL from the FeatureView's tags.
Retrieves online store specific tags.
Returns:
An integer representing the TTL in seconds, or None if not set.
A dictionary of tags. If no tags are found, returns an empty dictionary.
"""
ttl_str = self.tags.get("online_store_key_ttl_seconds")
if ttl_str:
try:
return int(ttl_str)
except ValueError:
raise ValueError(
f"Invalid online_store_key_ttl_seconds value '{ttl_str}' in tags. It must be an integer representing seconds."
)
else:
return None
tags = {
k.removeprefix(ONLINE_STORE_TAG_SUFFIX): v
for k, v in self.tags.items()
if k.startswith(ONLINE_STORE_TAG_SUFFIX)
}
return tags
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
"""

import logging
import queue
import threading
import time
from datetime import datetime
from functools import partial
from queue import Queue
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple

from cassandra.auth import PlainTextAuthProvider
Expand All @@ -44,6 +43,7 @@
from feast.infra.online_stores.online_store import OnlineStore
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.rate_limiter import SlidingWindowRateLimiter
from feast.repo_config import FeastConfigBaseModel

# Error messages
Expand Down Expand Up @@ -197,67 +197,8 @@ class CassandraLoadBalancingPolicy(FeastConfigBaseModel):
Default: 100.
"""


class RateLimiter:
    """Token-based limiter whose tokens are topped up by a daemon thread.

    Can be used directly via :meth:`allow` or as a context manager that
    waits until a token becomes available.
    """

    def __init__(self, rate: int, interval: float):
        """
        Set up the token counter and start the background refill thread.

        :param rate: Maximum number of requests allowed per interval.
        :param interval: Time interval in seconds (e.g., 1 second for per-second limiting).
        """
        self.rate = rate
        self.interval = interval
        # The limiter starts full: `rate` requests may proceed immediately.
        self.tokens = rate
        self.lock = threading.Lock()
        self._start_refill_thread()

    def _refill_forever(self):
        """Add one token every ``interval / rate`` seconds, capped at ``rate``."""
        while True:
            time.sleep(self.interval / self.rate)
            with self.lock:
                if self.tokens >= self.rate:
                    continue
                self.tokens += 1

    def _start_refill_thread(self):
        """Launch the refill loop in a daemon thread."""
        refill_thread = threading.Thread(target=self._refill_forever, daemon=True)
        refill_thread.start()

    def allow(self) -> bool:
        """Consume a token if one is available; return whether one was taken."""
        with self.lock:
            granted = self.tokens > 0
            if granted:
                self.tokens -= 1
        return granted

    def __enter__(self):
        """Context manager entry: poll until a token can be consumed."""
        while True:
            if self.allow():
                return self
            # Short sleep between polls to avoid a hot spin.
            time.sleep(0.01)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: nothing to release."""
        return None


class SlidingWindowRateLimiter:
    """Grant at most ``max_calls`` acquisitions within any ``period`` seconds.

    A fixed-size ring buffer holds the times of the most recent granted
    acquisitions. A new acquisition is granted only once the oldest recorded
    one has fallen out of the window.

    No locking is performed here; callers that share one instance between
    threads have to synchronize access themselves.
    """

    def __init__(self, max_calls, period):
        """
        :param max_calls: Maximum number of granted acquisitions per window.
            ``0``, ``None`` or a negative value disables rate limiting.
        :param period: Length of the window in seconds.
        """
        # Collapse every "no limit" spelling to 0 so the ring buffer is never
        # indexed while empty (which would raise IndexError).
        self.max_calls = max_calls if max_calls and max_calls > 0 else 0
        self.period = period
        # -inf marks a slot as never used; it is always outside the window,
        # whatever origin the monotonic clock happens to have.
        self.timestamps = [float("-inf")] * self.max_calls
        self.index = 0

    def acquire(self):
        """Try to take one slot without blocking.

        :return: ``True`` if the call may proceed (or limiting is disabled),
            ``False`` if ``max_calls`` acquisitions already happened within
            the last ``period`` seconds.
        """
        if self.max_calls == 0:
            return True

        # A monotonic clock keeps the limiter unaffected by wall-clock jumps.
        now = time.monotonic()
        window_start = now - self.period

        # self.index points at the oldest recorded acquisition.
        if self.timestamps[self.index] <= window_start:
            self.timestamps[self.index] = now
            self.index = (self.index + 1) % self.max_calls
            return True
        return False
write_rate_limit: Optional[StrictInt] = 0
"""The maximum number of write batches per second. Value 0 means no rate limiting. For spark materialization engine, this configuration is per executor task."""


class CassandraOnlineStore(OnlineStore):
Expand Down Expand Up @@ -417,8 +358,6 @@ def online_write_batch(
rows is written to the online store. Can be used to
display progress.
"""
current_time_in_ms = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
print(f"{current_time_in_ms} Starting online_write_batch method...")

def on_success(result, concurrent_queue):
concurrent_queue.get_nowait()
Expand All @@ -429,24 +368,25 @@ def on_failure(exc, concurrent_queue):
print(f"Error writing a batch: {exc}")
raise Exception("Error writing a batch") from exc

write_concurrency = config.online_store.write_concurrency
rate_limiter = SlidingWindowRateLimiter(write_concurrency, 1)
concurrent_queue: queue.Queue = queue.Queue(maxsize=write_concurrency)
current_time_in_ms = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
print(f"{current_time_in_ms} Using write concurrency: {write_concurrency}")
override_configs = table.get_online_store_tags

project = config.project
write_concurrency = config.online_store.get_override_config(
"write_concurrency", override_configs
)
ttl = (
table.online_store_key_ttl_seconds
or config.online_store.key_ttl_seconds
config.online_store.get_override_config(
"key_ttl_seconds", table.get_online_store_tags
)
or 0
)
session: Session = self._get_session(config)

current_time_in_ms = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
print(
f"{current_time_in_ms} {session.session_id} Starting online_write_batch session created..."
write_rate_limit = config.online_store.get_override_config(
"write_rate_limit", override_configs
)
concurrent_queue: Queue = Queue(maxsize=write_concurrency)
rate_limiter = SlidingWindowRateLimiter(write_rate_limit, 1)

session: Session = self._get_session(config)
keyspace: str = self._keyspace
fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table)

Expand Down Expand Up @@ -475,16 +415,8 @@ def on_failure(exc, concurrent_queue):

# Wait until the rate limiter allows the next batch to be sent
if not rate_limiter.acquire():
current_time_in_ms = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[
:-3
]
print(f"{current_time_in_ms} Rate limit exceeded, waiting...")
while not rate_limiter.acquire():
time.sleep(0.001)
current_time_in_ms = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[
:-3
]
print(f"{current_time_in_ms} Rate limit released, continuing...")

future = session.execute_async(batch)
concurrent_queue.put(future)
Expand All @@ -503,22 +435,12 @@ def on_failure(exc, concurrent_queue):
if progress:
progress(1)
# Wait for all tasks to complete
current_time_in_ms = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
print(
f"{current_time_in_ms} {session.session_id} waiting for queue to be empty"
)
while not concurrent_queue.empty():
time.sleep(0.001)
current_time_in_ms = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
print(f"{current_time_in_ms} {session.session_id} queue is empty")

# correction for the last missing call to `progress`:
if progress:
progress(1)
current_time_in_ms = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
print(
f"{current_time_in_ms} {session.session_id} Done online_write_batch method..."
)

def online_read(
self,
Expand Down Expand Up @@ -651,7 +573,9 @@ def _read_rows_by_entity_keys(
session,
select_cql,
((entity_key_bin,) for entity_key_bin in entity_key_bins),
concurrency=config.online_store.read_concurrency,
concurrency=config.online_store.get_override_config(
"read_concurrency", table.get_online_store_tags
),
)
# execute_concurrent_with_args returns a sequence
# of (success, result_or_exception) pairs:
Expand Down
5 changes: 3 additions & 2 deletions sdk/python/feast/infra/online_stores/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,9 @@ def online_write_batch(
pipe.hset(redis_key_bin, mapping=entity_hset)

ttl = (
table.online_store_key_ttl_seconds
or online_store_config.key_ttl_seconds
online_store_config.get_override_config(
"key_ttl_seconds", table.get_online_store_tags
)
or None
)
if ttl:
Expand Down
21 changes: 21 additions & 0 deletions sdk/python/feast/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import time


class SlidingWindowRateLimiter:
    """Grant at most ``max_calls`` acquisitions within any ``period`` seconds.

    A fixed-size ring buffer holds the times of the most recent granted
    acquisitions. A new acquisition is granted only once the oldest recorded
    one has fallen out of the window.

    No locking is performed here; callers that share one instance between
    threads have to synchronize access themselves.
    """

    def __init__(self, max_calls, period):
        """
        :param max_calls: Maximum number of granted acquisitions per window.
            ``0``, ``None`` or a negative value disables rate limiting.
        :param period: Length of the window in seconds.
        """
        # Collapse every "no limit" spelling to 0: ``[0] * None`` raises
        # TypeError and a negative count would leave an empty buffer that
        # acquire() then indexes.
        self.max_calls = max_calls if max_calls and max_calls > 0 else 0
        self.period = period
        # -inf marks a slot as never used; it is always outside the window,
        # whatever origin the monotonic clock happens to have.
        self.timestamps = [float("-inf")] * self.max_calls
        self.index = 0

    def acquire(self):
        """Try to take one slot without blocking.

        :return: ``True`` if the call may proceed (or limiting is disabled),
            ``False`` if ``max_calls`` acquisitions already happened within
            the last ``period`` seconds.
        """
        if self.max_calls == 0:
            return True

        # A monotonic clock keeps the limiter from stalling when the wall
        # clock is adjusted backwards.
        now = time.monotonic()
        window_start = now - self.period

        # self.index points at the oldest recorded acquisition.
        if self.timestamps[self.index] <= window_start:
            self.timestamps[self.index] = now
            self.index = (self.index + 1) % self.max_calls
            return True
        return False
31 changes: 28 additions & 3 deletions sdk/python/feast/repo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import warnings
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union, get_args, get_origin

import yaml
from pydantic import (
Expand Down Expand Up @@ -113,6 +113,31 @@ class FeastConfigBaseModel(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

def get_override_config(self, key: str, overrides: Optional[Dict[str, str]] = None):
    """
    Return the value of config field ``key``, preferring a string override.

    When ``overrides`` contains ``key``, its string value is converted to the
    type declared for that field on the model and returned. If the field is
    unknown, has no usable type, or the conversion fails, a warning is logged
    and the value configured on this object is returned instead.

    Args:
        key: Name of the configuration field to look up.
        overrides: Optional mapping of field name to string value that takes
            precedence over the configured value.

    Returns:
        The converted override, otherwise ``getattr(self, key, None)``.
    """
    # Imported locally: only this method needs these names.
    import types
    from typing import Annotated

    def _unwrap_annotated(tp):
        # Annotated[T, ...] -> T; anything else is returned unchanged.
        return get_args(tp)[0] if get_origin(tp) is Annotated else tp

    if overrides and key in overrides:
        raw_value = overrides[key]

        # .get() so an override for a non-field key falls back instead of
        # raising KeyError.
        field = self.model_fields.get(key)
        annotation = _unwrap_annotated(field.annotation) if field is not None else None
        base_type = get_origin(annotation) or annotation

        # typing.Union covers Optional[X]; types.UnionType covers the PEP 604
        # spelling ``X | None`` where the interpreter provides it.
        if base_type is Union or base_type is getattr(types, "UnionType", Union):
            # First non-None member; None when the union has no such member.
            base_type = _unwrap_annotated(
                next((t for t in get_args(annotation) if t is not type(None)), None)
            )

        if base_type:
            try:
                if base_type is bool:
                    # bool("false") would be True, so compare the text instead.
                    return raw_value.lower() == "true"
                return base_type(raw_value)
            except (ValueError, TypeError) as e:
                # Not every typing construct exposes __name__.
                type_name = getattr(base_type, "__name__", base_type)
                _logger.warning(
                    f"Could not convert override {key}='{raw_value}' to {type_name}. Exception is {e}"
                )
        else:
            _logger.warning(
                f"Could not find base_type for {key}. Please check your online store configuration."
            )

    return getattr(self, key, None)


class RegistryConfig(FeastBaseModel):
"""Metadata Store Configuration. Configuration that relates to reading from and writing to the Feast registry."""
Expand Down Expand Up @@ -382,13 +407,13 @@ def _validate_auth_config(cls, values: Any) -> Any:
elif values["auth"]["type"] not in ALLOWED_AUTH_TYPES:
raise ValueError(
f'auth configuration has invalid authentication type={values["auth"]["type"]}. Possible '
f'values={ALLOWED_AUTH_TYPES}'
f"values={ALLOWED_AUTH_TYPES}"
)
elif isinstance(values["auth"], AuthConfig):
if values["auth"].type not in ALLOWED_AUTH_TYPES:
raise ValueError(
f'auth configuration has invalid authentication type={values["auth"].type}. Possible '
f'values={ALLOWED_AUTH_TYPES}'
f"values={ALLOWED_AUTH_TYPES}"
)
return values

Expand Down
25 changes: 13 additions & 12 deletions sdk/python/tests/unit/infra/scaffolding/test_repo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def test_repo_config_init_expedia_provider():
assert isinstance(c.offline_store, SparkOfflineStoreConfig)


def test_repo_config_init_expedia_provider_with_onlne_store_config():
def test_repo_config_init_expedia_provider_with_online_store_config():
c = _test_config(
dedent(
"""
Expand All @@ -390,7 +390,7 @@ def test_repo_config_init_expedia_provider_with_onlne_store_config():
assert isinstance(c.offline_store, SparkOfflineStoreConfig)


def test_repo_config_init_expedia_provider_with_online_store_config():
def test_override_online_store_config():
c = _test_config(
dedent(
"""
Expand All @@ -407,13 +407,14 @@ def test_repo_config_init_expedia_provider_with_online_store_config():
),
expect_error=None,
)
assert c.registry_config == "registry.db"
assert c.offline_config["type"] == "spark"
assert c.online_config["type"] == "redis"
assert c.online_config["connection_string"] == "localhost:6380"
assert c.online_config["redis_type"] == "redis_cluster"
assert c.online_config["key_ttl_seconds"] == 123456
assert c.batch_engine_config == "spark.engine"
assert isinstance(c.online_store, RedisOnlineStoreConfig)
assert isinstance(c.batch_engine, SparkMaterializationEngineConfig)
assert isinstance(c.offline_store, SparkOfflineStoreConfig)
assert c.online_store.get_override_config("key_ttl_seconds") == 123456
assert (
c.online_store.get_override_config(
"key_ttl_seconds", {"key_ttl_seconds": "654321"}
)
== 654321
)
assert (
c.online_store.get_override_config("key_ttl_seconds", {"batch_size": "10"})
== 123456
)
Loading

0 comments on commit 09302d1

Please sign in to comment.