Skip to content

Commit

Permalink
fix: Improve session shutdown logging and enhance error handling in CassandraOnlineStore
Browse files Browse the repository at this point in the history
  • Loading branch information
Bhargav Dodla committed Feb 7, 2025
1 parent d518e9f commit 6f353b3
Showing 1 changed file with 53 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -336,14 +336,16 @@ def __del__(self):
You'd get a RuntimeError "cannot schedule new futures after shutdown".
"""
print("Called __del__")
session_id = None
if self._session:
if not self._session.is_shutdown:
print("Shutting down session..")
session_id = self._session.session_id
print(f"{session_id} Shutting down session..")
self._session.shutdown()

if self._cluster:
if not self._cluster.is_shutdown:
print("Shutting down cluster..")
print(f"{session_id} {self._cluster.client_id} Shutting down cluster..")
self._cluster.shutdown()

def online_write_batch(
Expand Down Expand Up @@ -373,10 +375,12 @@ def online_write_batch(
def on_success(result, concurrent_queue):
    # Success callback registered via future.add_callbacks for an async
    # Cassandra write: pop one entry from the queue tracking in-flight
    # requests, freeing its slot. `result` (the driver's result object)
    # is not inspected here.
    concurrent_queue.get_nowait()

def on_failure(exc, concurrent_queue, session_id):
    """Failure callback for an async Cassandra batch write.

    Frees the in-flight slot in ``concurrent_queue``, logs the failure
    tagged with ``session_id`` for traceability across sessions, and
    re-raises so the surrounding write loop aborts.

    Args:
        exc: The exception raised by the driver for this batch.
        concurrent_queue: Queue tracking in-flight async requests.
        session_id: Identifier of the Cassandra session, included in
            the log and error message to correlate failures.

    Raises:
        Exception: Always — chained from ``exc`` so the original
            driver error is preserved in the traceback.
    """
    # Release the slot even on failure, otherwise the producer loop
    # waiting on queue emptiness would stall.
    concurrent_queue.get_nowait()
    logger.exception(f"Error writing a batch {session_id}: {exc}")
    raise Exception(
        f"Exception raised while writing a batch {session_id}"
    ) from exc

override_configs = table.get_online_store_tags

Expand All @@ -397,6 +401,7 @@ def on_failure(exc, concurrent_queue):
rate_limiter = SlidingWindowRateLimiter(write_rate_limit, 1)

session: Session = self._get_session(config)
session_id = session.session_id
keyspace: str = self._keyspace
fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table)

Expand All @@ -408,55 +413,52 @@ def on_failure(exc, concurrent_queue):
session=session,
)

try:
for entity_key, values, timestamp, created_ts in data:
batch = BatchStatement(batch_type=BatchType.UNLOGGED)
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for feature_name, val in values.items():
params: Tuple[str, bytes, str, datetime] = (
feature_name,
val.SerializeToString(),
entity_key_bin,
timestamp,
)
batch.add(insert_cql, params)

# Wait until the rate limiter allows
if not rate_limiter.acquire():
while not rate_limiter.acquire():
time.sleep(0.001)
if not session.is_shutdown:
future = session.execute_async(batch)
concurrent_queue.put(future)
future.add_callbacks(
partial(
on_success,
concurrent_queue=concurrent_queue,
),
partial(
on_failure,
concurrent_queue=concurrent_queue,
),
)
else:
raise Exception("Session is shutdown and cannot be used")

# this happens N-1 times, will be corrected outside:
if progress:
progress(1)
# Wait for all tasks to complete
while not concurrent_queue.empty():
time.sleep(0.001)
for entity_key, values, timestamp, created_ts in data:
batch = BatchStatement(batch_type=BatchType.UNLOGGED)
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for feature_name, val in values.items():
params: Tuple[str, bytes, str, datetime] = (
feature_name,
val.SerializeToString(),
entity_key_bin,
timestamp,
)
batch.add(insert_cql, params)

# Wait until the rate limiter allows
if not rate_limiter.acquire():
while not rate_limiter.acquire():
time.sleep(0.001)
if not session.is_shutdown:
future = session.execute_async(batch)
concurrent_queue.put(future)
future.add_callbacks(
partial(
on_success,
concurrent_queue=concurrent_queue,
),
partial(
on_failure,
concurrent_queue=concurrent_queue,
session_id=session_id,
),
)
else:
raise Exception(f"{session_id} Session is shutdown and cannot be used")

# correction for the last missing call to `progress`:
# this happens N-1 times, will be corrected outside:
if progress:
progress(1)
except Exception as e:
logger.exception(f"Unknown exception: {e}")
raise Exception("Unknown exception") from e
# Wait for all tasks to complete
while not concurrent_queue.empty():
time.sleep(0.001)

# correction for the last missing call to `progress`:
if progress:
progress(1)

def online_read(
self,
Expand Down

0 comments on commit 6f353b3

Please sign in to comment.