
Commit 0f70571

fix(redis): implement transaction handling for Redis checkpointing (#11)
- Add transaction handling to AsyncRedisSaver.aput and aput_writes methods
- Add transaction handling to AsyncShallowRedisSaver.aput method
- Fix typing issue in shallow.py
- Add comprehensive tests for interruption handling
- Ensure atomic operations in Redis using pipeline with transaction=True
- Proper handling of asyncio.CancelledError during interruptions
1 parent 6bfdcd0 commit 0f70571

File tree

5 files changed: +791 −179 lines
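For context on the pattern the commit message describes: every JSON.SET for a checkpoint and its blobs is queued on a single redis-py asyncio pipeline opened with transaction=True, so a cancellation discards the whole batch before anything reaches Redis. The minimal sketch below shows that shape in isolation; the function name, key names, and payload layout are illustrative placeholders (not the saver's real schema), and it assumes a Redis server with the RedisJSON module loaded.

import asyncio

from redis.asyncio import Redis


async def save_checkpoint_atomically(
    redis: Redis,
    checkpoint_key: str,
    checkpoint_doc: dict,
    blob_docs: list[tuple[str, dict]],
) -> None:
    """Queue all writes on one MULTI/EXEC pipeline so they land together or not at all."""
    try:
        # transaction=True wraps the queued commands in MULTI/EXEC.
        pipe = redis.pipeline(transaction=True)

        # Queue the checkpoint document (requires the RedisJSON module).
        await pipe.json().set(checkpoint_key, "$", checkpoint_doc)

        # Queue every blob document on the same pipeline.
        for key, doc in blob_docs:
            await pipe.json().set(key, "$", doc)

        # Nothing is sent to Redis until execute(); then everything is, atomically.
        await pipe.execute()
    except asyncio.CancelledError:
        # If the task is cancelled before execute() runs, the queued commands
        # are simply dropped with the pipeline, so no partial checkpoint is stored.
        raise

In the commit, this same shape is applied inside AsyncRedisSaver.aput and aput_writes using the saver's own key helpers and serializers, as the diff below shows.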

Diff for: langgraph/checkpoint/redis/aio.py

+122 −46
@@ -384,7 +384,24 @@ async def aput(
         metadata: CheckpointMetadata,
         new_versions: ChannelVersions,
     ) -> RunnableConfig:
-        """Store a checkpoint to Redis."""
+        """Store a checkpoint to Redis with proper transaction handling.
+
+        This method ensures that all Redis operations are performed atomically
+        using Redis transactions. In case of interruption (asyncio.CancelledError),
+        the transaction will be aborted, ensuring consistency.
+
+        Args:
+            config: The config to associate with the checkpoint
+            checkpoint: The checkpoint data to store
+            metadata: Additional metadata to save with the checkpoint
+            new_versions: New channel versions as of this write
+
+        Returns:
+            Updated configuration after storing the checkpoint
+
+        Raises:
+            asyncio.CancelledError: If the operation is cancelled/interrupted
+        """
         configurable = config["configurable"].copy()

         thread_id = configurable.pop("thread_id")
@@ -410,46 +427,63 @@ async def aput(
             }
         }

-        # Store checkpoint data
-        checkpoint_data = {
-            "thread_id": storage_safe_thread_id,
-            "checkpoint_ns": storage_safe_checkpoint_ns,
-            "checkpoint_id": storage_safe_checkpoint_id,
-            "parent_checkpoint_id": storage_safe_checkpoint_id,
-            "checkpoint": self._dump_checkpoint(copy),
-            "metadata": self._dump_metadata(metadata),
-        }
-
-        # store at top-level for filters in list()
-        if all(key in metadata for key in ["source", "step"]):
-            checkpoint_data["source"] = metadata["source"]
-            checkpoint_data["step"] = metadata["step"]  # type: ignore
-
-        await self.checkpoints_index.load(
-            [checkpoint_data],
-            keys=[
-                BaseRedisSaver._make_redis_checkpoint_key(
-                    storage_safe_thread_id,
-                    storage_safe_checkpoint_ns,
-                    storage_safe_checkpoint_id,
-                )
-            ],
-        )
-
-        # Store blob values
-        blobs = self._dump_blobs(
-            storage_safe_thread_id,
-            storage_safe_checkpoint_ns,
-            copy.get("channel_values", {}),
-            new_versions,
-        )
-
-        if blobs:
-            # Unzip the list of tuples into separate lists for keys and data
-            keys, data = zip(*blobs)
-            await self.checkpoint_blobs_index.load(list(data), keys=list(keys))
-
-        return next_config
+        # Store checkpoint data with transaction handling
+        try:
+            # Create a pipeline with transaction=True for atomicity
+            pipeline = self._redis.pipeline(transaction=True)
+
+            # Store checkpoint data
+            checkpoint_data = {
+                "thread_id": storage_safe_thread_id,
+                "checkpoint_ns": storage_safe_checkpoint_ns,
+                "checkpoint_id": storage_safe_checkpoint_id,
+                "parent_checkpoint_id": storage_safe_checkpoint_id,
+                "checkpoint": self._dump_checkpoint(copy),
+                "metadata": self._dump_metadata(metadata),
+            }
+
+            # store at top-level for filters in list()
+            if all(key in metadata for key in ["source", "step"]):
+                checkpoint_data["source"] = metadata["source"]
+                checkpoint_data["step"] = metadata["step"]  # type: ignore
+
+            # Prepare checkpoint key
+            checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
+                storage_safe_thread_id,
+                storage_safe_checkpoint_ns,
+                storage_safe_checkpoint_id,
+            )
+
+            # Add checkpoint data to Redis
+            await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
+
+            # Store blob values
+            blobs = self._dump_blobs(
+                storage_safe_thread_id,
+                storage_safe_checkpoint_ns,
+                copy.get("channel_values", {}),
+                new_versions,
+            )
+
+            if blobs:
+                # Add all blob operations to the pipeline
+                for key, data in blobs:
+                    await pipeline.json().set(key, "$", data)
+
+            # Execute all operations atomically
+            await pipeline.execute()
+
+            return next_config
+
+        except asyncio.CancelledError:
+            # Handle cancellation/interruption
+            # Pipeline will be automatically discarded
+            # Either all operations succeed or none do
+            raise
+
+        except Exception as e:
+            # Re-raise other exceptions
+            raise e

     async def aput_writes(
         self,
@@ -458,14 +492,23 @@ async def aput_writes(
         task_id: str,
         task_path: str = "",
     ) -> None:
-        """Store intermediate writes linked to a checkpoint using Redis JSON.
+        """Store intermediate writes linked to a checkpoint using Redis JSON with transaction handling.
+
+        This method uses Redis pipeline with transaction=True to ensure atomicity of all
+        write operations. In case of interruption, all operations will be aborted.

         Args:
             config (RunnableConfig): Configuration of the related checkpoint.
             writes (List[Tuple[str, Any]]): List of writes to store.
             task_id (str): Identifier for the task creating the writes.
             task_path (str): Path of the task creating the writes.
+
+        Raises:
+            asyncio.CancelledError: If the operation is cancelled/interrupted
         """
+        if not writes:
+            return
+
         thread_id = config["configurable"]["thread_id"]
         checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
         checkpoint_id = config["configurable"]["checkpoint_id"]
@@ -487,7 +530,14 @@ async def aput_writes(
             }
             writes_objects.append(write_obj)

+        try:
+            # Use a transaction pipeline for atomicity
+            pipeline = self._redis.pipeline(transaction=True)
+
+            # Determine if this is an upsert case
             upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
+
+            # Add all write operations to the pipeline
             for write_obj in writes_objects:
                 key = self._make_redis_checkpoint_writes_key(
                     thread_id,
@@ -496,10 +546,36 @@ async def aput_writes(
                     task_id,
                     write_obj["idx"],  # type: ignore[arg-type]
                 )
-            tx = partial(
-                _write_obj_tx, key=key, write_obj=write_obj, upsert_case=upsert_case
-            )
-            await self._redis.transaction(tx, key)
+
+                if upsert_case:
+                    # For upsert case, we need to check if the key exists and update differently
+                    exists = await self._redis.exists(key)
+                    if exists:
+                        # Update existing key
+                        await pipeline.json().set(key, "$.channel", write_obj["channel"])
+                        await pipeline.json().set(key, "$.type", write_obj["type"])
+                        await pipeline.json().set(key, "$.blob", write_obj["blob"])
+                    else:
+                        # Create new key
+                        await pipeline.json().set(key, "$", write_obj)
+                else:
+                    # For non-upsert case, only set if key doesn't exist
+                    exists = await self._redis.exists(key)
+                    if not exists:
+                        await pipeline.json().set(key, "$", write_obj)
+
+            # Execute all operations atomically
+            await pipeline.execute()
+
+        except asyncio.CancelledError:
+            # Handle cancellation/interruption
+            # Pipeline will be automatically discarded
+            # Either all operations succeed or none do
+            raise
+
+        except Exception as e:
+            # Re-raise other exceptions
+            raise e

     def put_writes(
         self,
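The commit message also mentions tests for interruption handling. Below is a hedged sketch of what such a test could look like, assuming pytest-asyncio is installed; the make_async_saver fixture and the bare-bones checkpoint dict are hypothetical placeholders, not the repository's actual fixtures. It cancels an in-flight aput and then checks that no checkpoint became visible for that thread.

import asyncio

import pytest


@pytest.mark.asyncio
async def test_cancelled_aput_leaves_no_partial_checkpoint(make_async_saver):
    # make_async_saver is a hypothetical fixture yielding a connected AsyncRedisSaver.
    saver = make_async_saver()
    config = {"configurable": {"thread_id": "t-interrupt", "checkpoint_ns": ""}}
    # Illustrative, minimal checkpoint payload; the real tests would build a proper Checkpoint.
    checkpoint = {
        "id": "chk-1",
        "ts": "2024-01-01T00:00:00+00:00",
        "channel_values": {},
        "channel_versions": {},
        "versions_seen": {},
        "pending_sends": [],
    }

    task = asyncio.create_task(
        saver.aput(config, checkpoint, {"source": "input", "step": 0}, {})
    )
    await asyncio.sleep(0)  # let the coroutine start before interrupting it
    task.cancel()

    with pytest.raises(asyncio.CancelledError):
        await task

    # Because aput queues everything in a single MULTI/EXEC pipeline, a call
    # cancelled before EXEC completes should leave nothing behind for this thread.
    assert await saver.aget(config) is None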
