diff --git a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py index a9cbd87f9d2..6cceb9924ba 100644 --- a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py +++ b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py @@ -18,9 +18,10 @@ Cassandra/Astra DB online store for Feast. """ -import atexit import logging from datetime import datetime +from functools import partial +from threading import Condition, Lock, Semaphore from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple from cassandra.auth import PlainTextAuthProvider @@ -192,6 +193,10 @@ class CassandraLoadBalancingPolicy(FeastConfigBaseModel): """ +def get_current_time_in_ms(): + return datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + + class CassandraOnlineStore(OnlineStore): """ Cassandra/Astra DB online store implementation for Feast. @@ -209,125 +214,114 @@ class CassandraOnlineStore(OnlineStore): _keyspace: str = "feast_keyspace" _prepared_statements: Dict[str, PreparedStatement] = {} - def shutdown(self): - """ - Shutdown the Cassandra cluster and session. - """ - if self._session: - if not self._session.is_shutdown: - self._session.shutdown() - if self._cluster: - if not self._cluster.is_shutdown: - self._cluster.shutdown() - - def _get_cluster(self, config: RepoConfig): - """ - Establish the database connection, if not yet created, - and return it. - - Also perform basic config validation checks. - """ - - online_store_config = config.online_store - if not isinstance(online_store_config, CassandraOnlineStoreConfig): - raise CassandraInvalidConfig(E_CASSANDRA_UNEXPECTED_CONFIGURATION_CLASS) - - if self._cluster: - if not self._cluster.is_shutdown: - print("Reusing existing cluster..") - return self._cluster - else: - self._cluster = None - print("Setting _cluster to None..") - else: - print("Creating a new cluster..") - if not self._cluster: - # configuration consistency checks - hosts = online_store_config.hosts - secure_bundle_path = online_store_config.secure_bundle_path - port = online_store_config.port or 9042 - keyspace = online_store_config.keyspace - username = online_store_config.username - password = online_store_config.password - protocol_version = online_store_config.protocol_version - - db_directions = hosts or secure_bundle_path - if not db_directions or not keyspace: - raise CassandraInvalidConfig(E_CASSANDRA_NOT_CONFIGURED) - if hosts and secure_bundle_path: - raise CassandraInvalidConfig(E_CASSANDRA_MISCONFIGURED) - if (username is None) ^ (password is None): - raise CassandraInvalidConfig(E_CASSANDRA_INCONSISTENT_AUTH) - - if username is not None: - auth_provider = PlainTextAuthProvider( - username=username, - password=password, - ) - else: - auth_provider = None - - # handling of load-balancing policy (optional) - if online_store_config.load_balancing: - # construct a proper execution profile embedding - # the configured LB policy - _lbp_name = online_store_config.load_balancing.load_balancing_policy - if _lbp_name == "DCAwareRoundRobinPolicy": - lb_policy = DCAwareRoundRobinPolicy( - local_dc=online_store_config.load_balancing.local_dc, - ) - elif _lbp_name == "TokenAwarePolicy(DCAwareRoundRobinPolicy)": - lb_policy = TokenAwarePolicy( - DCAwareRoundRobinPolicy( - local_dc=online_store_config.load_balancing.local_dc, - ) - ) - else: - raise CassandraInvalidConfig(E_CASSANDRA_UNKNOWN_LB_POLICY) - - # wrap it up in a map of ex.profiles with a default - exe_profile = ExecutionProfile( - request_timeout=online_store_config.request_timeout, - load_balancing_policy=lb_policy, - ) - execution_profiles = {EXEC_PROFILE_DEFAULT: exe_profile} - else: - execution_profiles = None - - # additional optional keyword args to Cluster - cluster_kwargs = { - k: v - for k, v in { - "protocol_version": protocol_version, - "execution_profiles": execution_profiles, - }.items() - if v is not None - } - - # creation of Cluster (Cassandra vs. Astra) - if hosts: - self._cluster = Cluster( - hosts, - port=port, - auth_provider=auth_provider, - idle_heartbeat_interval=0, - idle_heartbeat_timeout=0, - **cluster_kwargs, - ) - else: - # we use 'secure_bundle_path' - self._cluster = Cluster( - cloud={"secure_connect_bundle": secure_bundle_path}, - auth_provider=auth_provider, - **cluster_kwargs, - ) - - # creation of Session - self._keyspace = keyspace - # self._session = self._cluster.connect(self._keyspace) - atexit.register(self._cluster.shutdown) - - return self._cluster + # def _get_cluster(self, config: RepoConfig): + # """ + # Establish the database connection, if not yet created, + # and return it. + + # Also perform basic config validation checks. + # """ + + # online_store_config = config.online_store + # if not isinstance(online_store_config, CassandraOnlineStoreConfig): + # raise CassandraInvalidConfig(E_CASSANDRA_UNEXPECTED_CONFIGURATION_CLASS) + + # if self._cluster: + # if not self._cluster.is_shutdown: + # print("Reusing existing cluster..") + # return self._cluster + # else: + # self._cluster = None + # print("Setting _cluster to None..") + # else: + # print("Creating a new cluster..") + # if not self._cluster: + # # configuration consistency checks + # hosts = online_store_config.hosts + # secure_bundle_path = online_store_config.secure_bundle_path + # port = online_store_config.port or 9042 + # keyspace = online_store_config.keyspace + # username = online_store_config.username + # password = online_store_config.password + # protocol_version = online_store_config.protocol_version + + # db_directions = hosts or secure_bundle_path + # if not db_directions or not keyspace: + # raise CassandraInvalidConfig(E_CASSANDRA_NOT_CONFIGURED) + # if hosts and secure_bundle_path: + # raise CassandraInvalidConfig(E_CASSANDRA_MISCONFIGURED) + # if (username is None) ^ (password is None): + # raise CassandraInvalidConfig(E_CASSANDRA_INCONSISTENT_AUTH) + + # if username is not None: + # auth_provider = PlainTextAuthProvider( + # username=username, + # password=password, + # ) + # else: + # auth_provider = None + + # # handling of load-balancing policy (optional) + # if online_store_config.load_balancing: + # # construct a proper execution profile embedding + # # the configured LB policy + # _lbp_name = online_store_config.load_balancing.load_balancing_policy + # if _lbp_name == "DCAwareRoundRobinPolicy": + # lb_policy = DCAwareRoundRobinPolicy( + # local_dc=online_store_config.load_balancing.local_dc, + # ) + # elif _lbp_name == "TokenAwarePolicy(DCAwareRoundRobinPolicy)": + # lb_policy = TokenAwarePolicy( + # DCAwareRoundRobinPolicy( + # local_dc=online_store_config.load_balancing.local_dc, + # ) + # ) + # else: + # raise CassandraInvalidConfig(E_CASSANDRA_UNKNOWN_LB_POLICY) + + # # wrap it up in a map of ex.profiles with a default + # exe_profile = ExecutionProfile( + # request_timeout=online_store_config.request_timeout, + # load_balancing_policy=lb_policy, + # ) + # execution_profiles = {EXEC_PROFILE_DEFAULT: exe_profile} + # else: + # execution_profiles = None + + # # additional optional keyword args to Cluster + # cluster_kwargs = { + # k: v + # for k, v in { + # "protocol_version": protocol_version, + # "execution_profiles": execution_profiles, + # }.items() + # if v is not None + # } + + # # creation of Cluster (Cassandra vs. Astra) + # if hosts: + # self._cluster = Cluster( + # hosts, + # port=port, + # auth_provider=auth_provider, + # idle_heartbeat_interval=0, + # idle_heartbeat_timeout=0, + # **cluster_kwargs, + # ) + # else: + # # we use 'secure_bundle_path' + # self._cluster = Cluster( + # cloud={"secure_connect_bundle": secure_bundle_path}, + # auth_provider=auth_provider, + # **cluster_kwargs, + # ) + + # # creation of Session + # self._keyspace = keyspace + # # self._session = self._cluster.connect(self._keyspace) + # atexit.register(self._cluster.shutdown) + + # return self._cluster def _get_session(self, config: RepoConfig): """ @@ -343,13 +337,15 @@ def _get_session(self, config: RepoConfig): if self._session: if not self._session.is_shutdown: - print("Reusing existing session..") + print(f"{get_current_time_in_ms()} Reusing existing session..") return self._session else: self._session = None - print("Setting a session to None. Creating a new session..") + print( + f"{get_current_time_in_ms()} Setting a session to None. Creating a new session.." + ) else: - print("Creating a new session..") + print(f"{get_current_time_in_ms()} Creating a new session..") if not self._session: # configuration consistency checks hosts = online_store_config.hosts @@ -439,47 +435,55 @@ def _get_session(self, config: RepoConfig): def __del__(self): """ - One may be tempted to reclaim resources and do, here: - if self._session: - self._session.shutdown() - But *beware*, DON'T DO THIS. - Indeed this could destroy the session object before some internal - tasks runs in other threads (this is handled internally in the - Cassandra driver). + Shutting down the session and cluster objects. If you don't do this, + you would notice increase in connection spikes on the cluster. Once shutdown, + you can't use the session object anymore. You'd get a RuntimeError "cannot schedule new futures after shutdown". """ - print("Calling CassandraOnlineStore __del__() method") + print( + f"{get_current_time_in_ms()} Calling CassandraOnlineStore __del__() method" + ) + if self._session: + if not self._session.is_shutdown: + self._session.shutdown() + print(f"{get_current_time_in_ms()} Session is shutdown") + if self._cluster: if not self._cluster.is_shutdown: - current_datetime_with_ms = datetime.now().strftime( - "%Y-%m-%d %H:%M:%S.%f" - )[:-3] - print( - f"{current_datetime_with_ms} {self._cluster.client_id}: Cluster is still active" - ) - print( - f"{current_datetime_with_ms} {self._cluster.client_id}: Del Control Connection Host {self._cluster.get_control_connection_host()}" - ) - for connection in self._cluster.get_connection_holders(): - print( - f"{self._cluster.client_id}: Del Connection ID: {connection.get_connections()}" - ) - self._cluster.control_connection.shutdown() - if not self._cluster.scheduler.is_shutdown: - print(f"{self._cluster.client_id}: Shutting down scheduler") - self._cluster.scheduler.shutdown() - # self._cluster.shutdown() - current_datetime_with_ms = datetime.now().strftime( - "%Y-%m-%d %H:%M:%S.%f" - )[:-3] - print( - f"{current_datetime_with_ms} {self._cluster.client_id}: Done __del__(): Cluster is shutdown" - ) + self._cluster.shutdown() + print(f"{get_current_time_in_ms()} Cluster is shutdown") - else: - print(f"{self._cluster.client_id}: Cluster is not active") - else: - print("Cluster object doesn't exists.") + # if self._cluster: + # if not self._cluster.is_shutdown: + # current_datetime_with_ms = datetime.now().strftime( + # "%Y-%m-%d %H:%M:%S.%f" + # )[:-3] + # print( + # f"{current_datetime_with_ms} {self._cluster.client_id}: Cluster is still active" + # ) + # print( + # f"{current_datetime_with_ms} {self._cluster.client_id}: Del Control Connection Host {self._cluster.get_control_connection_host()}" + # ) + # for connection in self._cluster.get_connection_holders(): + # print( + # f"{self._cluster.client_id}: Del Connection ID: {connection.get_connections()}" + # ) + # self._cluster.control_connection.shutdown() + # if not self._cluster.scheduler.is_shutdown: + # print(f"{self._cluster.client_id}: Shutting down scheduler") + # self._cluster.scheduler.shutdown() + # # self._cluster.shutdown() + # current_datetime_with_ms = datetime.now().strftime( + # "%Y-%m-%d %H:%M:%S.%f" + # )[:-3] + # print( + # f"{current_datetime_with_ms} {self._cluster.client_id}: Done __del__(): Cluster is shutdown" + # ) + + # else: + # print(f"{self._cluster.client_id}: Cluster is not active") + # else: + # print("Cluster object doesn't exists.") # pass def online_write_batch_connector( @@ -544,10 +548,13 @@ def online_write_batch( rows is written to the online store. Can be used to display progress. """ - current_datetime_with_ms = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + logger.info(f"Started writing data of size {len(data)} to CassandraOnlineStore") print( - f"{current_datetime_with_ms} Called CassandraOnlineStore online_write_batch method with data size: {len(data)}" + f"{get_current_time_in_ms()} Started writing data of size {len(data)} to CassandraOnlineStore" ) + active_tasks = 0 + lock = Lock() + condition = Condition(lock) # def clusterStatus(cluster, client_id): # if cluster is None: @@ -575,14 +582,30 @@ def online_write_batch( # f"{current_datetime_with_ms} {cluster.client_id}: Initial Control Connection Host {self._cluster.get_control_connection_host()}" # ) + def on_success(result, semaphore): + global active_tasks + with condition: + active_tasks -= 1 + if active_tasks == 0: + print(f"{get_current_time_in_ms()} Notifying all tasks to complete") + condition.notify_all() + semaphore.release() + + def on_failure(exc): + logger.error(f"Error writing a batch: {exc}") + print(f"Error writing a batch: {exc}") + raise Exception("Error writing a batch") from exc + session: Session = self._get_session(config) keyspace: str = self._keyspace fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table) - futures = [] + # futures = [] insert_cql = self._get_cql_statement( config, "insert4", fqtable=fqtable, session=session ) + + semaphore = Semaphore(config.online_store.write_concurrency) for entity_key, values, timestamp, created_ts in data: batch = BatchStatement(batch_type=BatchType.UNLOGGED) entity_key_bin = serialize_entity_key( @@ -597,31 +620,24 @@ def online_write_batch( timestamp, ) batch.add(insert_cql, params) + with condition: + active_tasks += 1 + semaphore.acquire() + future = session.execute_async(batch) + future.add_callbacks(partial(on_success, semaphore=semaphore), on_failure) + # this happens N-1 times, will be corrected outside: if progress: progress(1) - futures.append(session.execute_async(batch)) - if len(futures) >= config.online_store.write_concurrency: - # Raises exception if at least one of the batch fails - try: - for future in futures: - future.result() - futures = [] - except Exception as exc: - logger.error(f"Error writing a batch: {exc}") - print(f"Error writing a batch: {exc}") - raise Exception("Error writing a batch") from exc - - if len(futures) > 0: - try: - for future in futures: - future.result() - futures = [] - except Exception as exc: - logger.error(f"Error writing a batch: {exc}") - print(f"Error writing a batch: {exc}") - raise Exception("Error writing a batch") from exc + # Wait for all tasks to complete + with condition: + while active_tasks > 0: + print( + f"{get_current_time_in_ms()} Waiting for active tasks to complete" + ) + condition.wait() + # with cluster.connect(keyspace) as session: # for connection in cluster.get_connection_holders(): # current_datetime_with_ms = datetime.now().strftime( @@ -677,6 +693,8 @@ def online_write_batch( # print( # f"{current_datetime_with_ms} {cluster.client_id}: Done Calling CassandraOnlineStore online_write_batch method" # ) + logger.info("Done writing data to CassandraOnlineStore") + print(f"{get_current_time_in_ms()} Done writing data to CassandraOnlineStore") # correction for the last missing call to `progress`: if progress: progress(1)