feat: Row-level TTL and cleanup
Bhargav Dodla committed Feb 3, 2025
1 parent beac9e7 commit 4470a5f
Showing 1 changed file with 17 additions and 98 deletions.
@@ -65,7 +65,7 @@
INSERT_CQL_4_TEMPLATE = (
"INSERT INTO {fqtable} (feature_name,"
" value, entity_key, event_ts) VALUES"
" (?, ?, ?, ?);"
" (?, ?, ?, ?) USING TTL {ttl};"
)

SELECT_CQL_TEMPLATE = "SELECT {columns} FROM {fqtable} WHERE entity_key = ?;"
@@ -78,7 +78,7 @@
event_ts TIMESTAMP,
created_ts TIMESTAMP,
PRIMARY KEY ((entity_key), feature_name)
) WITH CLUSTERING ORDER BY (feature_name ASC) AND default_time_to_live={ttl};
) WITH CLUSTERING ORDER BY (feature_name ASC);
"""

DROP_TABLE_CQL_TEMPLATE = "DROP TABLE IF EXISTS {fqtable};"
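
The templates above move expiry from the table definition (the old default_time_to_live clause) to every insert (USING TTL {ttl}), so the TTL is decided per write. A minimal sketch of the rendered statement, assuming a resolved TTL of 86400 seconds and an illustrative table name (neither appears in this commit):

insert_cql = INSERT_CQL_4_TEMPLATE.format(
    fqtable="feast_keyspace.my_project_driver_stats", ttl=86400
)
# Renders to:
#   INSERT INTO feast_keyspace.my_project_driver_stats (feature_name,
#    value, entity_key, event_ts) VALUES (?, ?, ?, ?) USING TTL 86400;
# A TTL of 0 means the rows never expire.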
@@ -194,10 +194,6 @@ class CassandraLoadBalancingPolicy(FeastConfigBaseModel):
"""


def get_current_time_in_ms():
return datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]


class CassandraOnlineStore(OnlineStore):
"""
Cassandra/Astra DB online store implementation for Feast.
@@ -229,15 +225,9 @@ def _get_session(self, config: RepoConfig):

if self._session:
if not self._session.is_shutdown:
print(f"{get_current_time_in_ms()} Reusing existing session..")
return self._session
else:
self._session = None
print(
f"{get_current_time_in_ms()} Setting a session to None. Creating a new session.."
)
else:
print(f"{get_current_time_in_ms()} Creating a new session..")
if not self._session:
# configuration consistency checks
hosts = online_store_config.hosts
@@ -307,8 +297,6 @@ def _get_session(self, config: RepoConfig):
hosts,
port=port,
auth_provider=auth_provider,
idle_heartbeat_interval=0,
idle_heartbeat_timeout=0,
**cluster_kwargs,
)
else:
@@ -332,61 +320,13 @@ def __del__(self):
you can't use the session object anymore.
You'd get a RuntimeError "cannot schedule new futures after shutdown".
"""
print("Calling CassandraOnlineStore __del__() method")
if self._session:
if not self._session.is_shutdown:
self._session.shutdown()
current_time_in_ms = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[
:-3
]
print(f"{current_time_in_ms} Session is shutdown")

if self._cluster:
if not self._cluster.is_shutdown:
self._cluster.shutdown()
current_time_in_ms = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[
:-3
]
print(f"{current_time_in_ms} Cluster is shutdown")

def online_write_batch_connector(
self,
config: RepoConfig,
table: FeatureView,
data: List[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
) -> List[Tuple[str, bytes, str, datetime]]:
"""
Write a batch of features of several entities to the database.
Args:
config: The RepoConfig for the current FeatureStore.
table: Feast FeatureView.
data: a list of quadruplets containing Feature data. Each
quadruplet contains an Entity Key, a dict containing feature
values, an event timestamp for the row, and
the created timestamp for the row if it exists.
progress: Optional function to be called once every mini-batch of
rows is written to the online store. Can be used to
display progress.
"""
data_list = []
for entity_key, values, timestamp, created_ts in data:
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for feature_name, val in values.items():
params: Tuple[str, bytes, str, datetime] = (
feature_name,
val.SerializeToString(),
entity_key_bin,
timestamp,
)
data_list.append(params)
return data_list
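
The quadruplet shape described in the docstring above is the same one consumed by online_write_batch below. A minimal sketch of a single entry in data, using Feast's protobuf types (the entity and feature names are illustrative, not taken from this commit):

from datetime import datetime

from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto

# (entity key, {feature name: value}, event timestamp, created timestamp)
row = (
    EntityKeyProto(
        join_keys=["driver_id"],
        entity_values=[ValueProto(int64_val=1001)],
    ),
    {"trips_today": ValueProto(int64_val=5)},
    datetime.utcnow(),
    None,
)
data = [row]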

def online_write_batch(
self,
@@ -411,32 +351,26 @@ def online_write_batch(
rows is written to the online store. Can be used to
display progress.
"""
logger.info(f"Started writing data of size {len(data)} to CassandraOnlineStore")
print(
f"{get_current_time_in_ms()} Started writing data of size {len(data)} to CassandraOnlineStore"
)
write_concurrency = config.online_store.write_concurrency
project = config.project

# def on_success(result, semaphore):
# semaphore.release()

# def on_failure(exc, semaphore):
# semaphore.release()
# logger.exception(f"Error writing a batch: {exc}")
# print(f"Error writing a batch: {exc}")
# raise Exception("Error writing a batch") from exc

ttl = (
table.online_store_key_ttl_seconds
or config.online_store.key_ttl_seconds
or 0
)
session: Session = self._get_session(config)
keyspace: str = self._keyspace
fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table)

futures = []
insert_cql = self._get_cql_statement(
config, "insert4", fqtable=fqtable, session=session
config,
"insert4",
fqtable=fqtable,
ttl=ttl,
session=session,
)

# semaphore = Semaphore(write_concurrency)
for entity_key, values, timestamp, created_ts in data:
batch = BatchStatement(batch_type=BatchType.UNLOGGED)
entity_key_bin = serialize_entity_key(
Expand All @@ -451,14 +385,11 @@ def online_write_batch(
timestamp,
)
batch.add(insert_cql, params)
# semaphore.acquire()
# future = session.execute_async(batch)
# future.add_callbacks(
# partial(on_success, semaphore=semaphore),
# partial(on_failure, semaphore=semaphore),
# )
futures.append(session.execute_async(batch))

# TODO: Make this efficient by leveraging continuous writes rather
# than blocking until all writes are done. We may need to rate limit
# the writes to reduce the impact on read performance.
if len(futures) >= write_concurrency:
# Raises an exception if at least one of the batches fails
try:
@@ -484,13 +415,6 @@
print(f"Error writing a batch: {exc}")
raise Exception("Error writing a batch") from exc

# Wait for all tasks to complete
# while semaphore._value < write_concurrency:
# print(f"{get_current_time_in_ms()} Waiting for active tasks to complete")
# time.sleep(0.01)

logger.info("Done writing data to CassandraOnlineStore")
print(f"{get_current_time_in_ms()} Done writing data to CassandraOnlineStore")
# correction for the last missing call to `progress`:
if progress:
progress(1)
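
The TODO above asks for continuous, rate-limited writes instead of blocking whenever the futures list reaches write_concurrency. One possible shape for that, sketched here with an illustrative helper name and mirroring the commented-out semaphore code this commit removes, is to cap the number of in-flight batches and release the cap from the driver's callbacks:

from threading import Semaphore

def _write_with_backpressure(session, batches, write_concurrency):
    # Illustrative sketch, not part of this commit: at most
    # `write_concurrency` batches are in flight at any moment.
    semaphore = Semaphore(write_concurrency)
    errors = []

    def on_success(_result):
        semaphore.release()

    def on_failure(exc):
        errors.append(exc)
        semaphore.release()

    for batch in batches:
        semaphore.acquire()  # blocks only once the cap is reached
        future = session.execute_async(batch)
        future.add_callbacks(on_success, on_failure)

    # Drain: reacquire every permit so all outstanding callbacks have fired.
    for _ in range(write_concurrency):
        semaphore.acquire()
    if errors:
        raise Exception("Error writing a batch") from errors[0]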
@@ -663,13 +587,8 @@ def _create_table(self, config: RepoConfig, project: str, table: FeatureView):
session: Session = self._get_session(config)
keyspace: str = self._keyspace
fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table)
ttl = (
table.online_store_key_ttl_seconds
or config.online_store.key_ttl_seconds
or 0
)
create_cql = self._get_cql_statement(config, "create", fqtable, ttl=ttl)
logger.info(f"Creating table {fqtable} with TTL {ttl}.")
create_cql = self._get_cql_statement(config, "create", fqtable)
logger.info(f"Creating table {fqtable} in keyspace {keyspace}.")
session.execute(create_cql)

def _get_cql_statement(
