@@ -384,7 +384,24 @@ async def aput(
         metadata: CheckpointMetadata,
         new_versions: ChannelVersions,
     ) -> RunnableConfig:
-        """Store a checkpoint to Redis."""
+        """Store a checkpoint to Redis with proper transaction handling.
+
+        This method ensures that all Redis operations are performed atomically
+        using Redis transactions. In case of interruption (asyncio.CancelledError),
+        the transaction will be aborted, ensuring consistency.
+
+        Args:
+            config: The config to associate with the checkpoint
+            checkpoint: The checkpoint data to store
+            metadata: Additional metadata to save with the checkpoint
+            new_versions: New channel versions as of this write
+
+        Returns:
+            Updated configuration after storing the checkpoint
+
+        Raises:
+            asyncio.CancelledError: If the operation is cancelled/interrupted
+        """
         configurable = config["configurable"].copy()

         thread_id = configurable.pop("thread_id")
@@ -410,46 +427,63 @@ async def aput(
             }
         }

-        # Store checkpoint data
-        checkpoint_data = {
-            "thread_id": storage_safe_thread_id,
-            "checkpoint_ns": storage_safe_checkpoint_ns,
-            "checkpoint_id": storage_safe_checkpoint_id,
-            "parent_checkpoint_id": storage_safe_checkpoint_id,
-            "checkpoint": self._dump_checkpoint(copy),
-            "metadata": self._dump_metadata(metadata),
-        }
-
-        # store at top-level for filters in list()
-        if all(key in metadata for key in ["source", "step"]):
-            checkpoint_data["source"] = metadata["source"]
-            checkpoint_data["step"] = metadata["step"]  # type: ignore
-
-        await self.checkpoints_index.load(
-            [checkpoint_data],
-            keys=[
-                BaseRedisSaver._make_redis_checkpoint_key(
-                    storage_safe_thread_id,
-                    storage_safe_checkpoint_ns,
-                    storage_safe_checkpoint_id,
-                )
-            ],
-        )
-
-        # Store blob values
-        blobs = self._dump_blobs(
-            storage_safe_thread_id,
-            storage_safe_checkpoint_ns,
-            copy.get("channel_values", {}),
-            new_versions,
-        )
-
-        if blobs:
-            # Unzip the list of tuples into separate lists for keys and data
-            keys, data = zip(*blobs)
-            await self.checkpoint_blobs_index.load(list(data), keys=list(keys))
-
-        return next_config
+        # Store checkpoint data with transaction handling
+        try:
+            # Create a pipeline with transaction=True for atomicity
+            pipeline = self._redis.pipeline(transaction=True)
+
+            # Store checkpoint data
+            checkpoint_data = {
+                "thread_id": storage_safe_thread_id,
+                "checkpoint_ns": storage_safe_checkpoint_ns,
+                "checkpoint_id": storage_safe_checkpoint_id,
+                "parent_checkpoint_id": storage_safe_checkpoint_id,
+                "checkpoint": self._dump_checkpoint(copy),
+                "metadata": self._dump_metadata(metadata),
+            }
+
+            # store at top-level for filters in list()
+            if all(key in metadata for key in ["source", "step"]):
+                checkpoint_data["source"] = metadata["source"]
+                checkpoint_data["step"] = metadata["step"]  # type: ignore
+
+            # Prepare checkpoint key
+            checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
+                storage_safe_thread_id,
+                storage_safe_checkpoint_ns,
+                storage_safe_checkpoint_id,
+            )
+
+            # Add checkpoint data to Redis
+            await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
+
+            # Store blob values
+            blobs = self._dump_blobs(
+                storage_safe_thread_id,
+                storage_safe_checkpoint_ns,
+                copy.get("channel_values", {}),
+                new_versions,
+            )
+
+            if blobs:
+                # Add all blob operations to the pipeline
+                for key, data in blobs:
+                    await pipeline.json().set(key, "$", data)
+
+            # Execute all operations atomically
+            await pipeline.execute()
+
+            return next_config
+
+        except asyncio.CancelledError:
+            # Handle cancellation/interruption
+            # Pipeline will be automatically discarded
+            # Either all operations succeed or none do
+            raise
+
+        except Exception as e:
+            # Re-raise other exceptions
+            raise e

     async def aput_writes(
         self,
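
The aput changes above rely on redis-py's transactional pipeline: every JSON.SET is queued client-side and sent in a single MULTI/EXEC when execute() is awaited, so a cancelled task leaves either a complete checkpoint (document plus blobs) or nothing. Below is a minimal standalone sketch of that pattern, assuming a redis.asyncio.Redis client with the RedisJSON module available; the helper name save_checkpoint_atomically and the checkpoint_doc/blob_docs arguments are illustrative, not part of the library API.

import asyncio

from redis.asyncio import Redis


async def save_checkpoint_atomically(
    redis: Redis,
    checkpoint_key: str,
    checkpoint_doc: dict,
    blob_docs: list[tuple[str, dict]],
) -> None:
    # transaction=True wraps the queued commands in MULTI/EXEC on execute().
    pipe = redis.pipeline(transaction=True)

    # Buffered pipeline calls only queue the command locally; nothing is
    # sent to Redis until execute() runs.
    pipe.json().set(checkpoint_key, "$", checkpoint_doc)
    for key, doc in blob_docs:
        pipe.json().set(key, "$", doc)

    try:
        # All queued commands are applied together inside one transaction.
        await pipe.execute()
    except asyncio.CancelledError:
        # If the task is cancelled before EXEC completes, none of the queued
        # writes were applied, so the stored checkpoint stays consistent.
        raise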
@@ -458,14 +492,23 @@ async def aput_writes(
         task_id: str,
         task_path: str = "",
     ) -> None:
-        """Store intermediate writes linked to a checkpoint using Redis JSON.
+        """Store intermediate writes linked to a checkpoint using Redis JSON with transaction handling.
+
+        This method uses Redis pipeline with transaction=True to ensure atomicity of all
+        write operations. In case of interruption, all operations will be aborted.

         Args:
             config (RunnableConfig): Configuration of the related checkpoint.
             writes (List[Tuple[str, Any]]): List of writes to store.
             task_id (str): Identifier for the task creating the writes.
             task_path (str): Path of the task creating the writes.
+
+        Raises:
+            asyncio.CancelledError: If the operation is cancelled/interrupted
         """
+        if not writes:
+            return
+
         thread_id = config["configurable"]["thread_id"]
         checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
         checkpoint_id = config["configurable"]["checkpoint_id"]
@@ -487,7 +530,14 @@ async def aput_writes(
             }
             writes_objects.append(write_obj)

+        try:
+            # Use a transaction pipeline for atomicity
+            pipeline = self._redis.pipeline(transaction=True)
+
+            # Determine if this is an upsert case
             upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
+
+            # Add all write operations to the pipeline
             for write_obj in writes_objects:
                 key = self._make_redis_checkpoint_writes_key(
                     thread_id,
@@ -496,10 +546,36 @@ async def aput_writes(
                     task_id,
                     write_obj["idx"],  # type: ignore[arg-type]
                 )
-            tx = partial(
-                _write_obj_tx, key=key, write_obj=write_obj, upsert_case=upsert_case
-            )
-            await self._redis.transaction(tx, key)
+
+                if upsert_case:
+                    # For upsert case, we need to check if the key exists and update differently
+                    exists = await self._redis.exists(key)
+                    if exists:
+                        # Update existing key
+                        await pipeline.json().set(key, "$.channel", write_obj["channel"])
+                        await pipeline.json().set(key, "$.type", write_obj["type"])
+                        await pipeline.json().set(key, "$.blob", write_obj["blob"])
+                    else:
+                        # Create new key
+                        await pipeline.json().set(key, "$", write_obj)
+                else:
+                    # For non-upsert case, only set if key doesn't exist
+                    exists = await self._redis.exists(key)
+                    if not exists:
+                        await pipeline.json().set(key, "$", write_obj)
+
+            # Execute all operations atomically
+            await pipeline.execute()
+
+        except asyncio.CancelledError:
+            # Handle cancellation/interruption
+            # Pipeline will be automatically discarded
+            # Either all operations succeed or none do
+            raise
+
+        except Exception as e:
+            # Re-raise other exceptions
+            raise e

     def put_writes(
         self,
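
The aput_writes changes above combine a plain EXISTS check with queued JSON.SET calls: existence is read immediately, the resulting writes are buffered, and everything is applied in one MULTI/EXEC. Here is a condensed sketch of that branching under the same assumptions as the earlier example (redis.asyncio.Redis with RedisJSON loaded); put_writes_atomically and the write_docs shape are illustrative, not the library's API.

import asyncio

from redis.asyncio import Redis


async def put_writes_atomically(
    redis: Redis,
    write_docs: list[tuple[str, dict]],
    upsert: bool,
) -> None:
    pipe = redis.pipeline(transaction=True)

    for key, doc in write_docs:
        # The existence check runs immediately, outside the queued transaction.
        exists = await redis.exists(key)
        if upsert and exists:
            # Refresh only the mutable fields of an existing write document.
            pipe.json().set(key, "$.channel", doc["channel"])
            pipe.json().set(key, "$.type", doc["type"])
            pipe.json().set(key, "$.blob", doc["blob"])
        elif not exists:
            # First write for this key: store the whole document.
            pipe.json().set(key, "$", doc)

    try:
        # Queued JSON.SET calls are applied together at EXEC time.
        await pipe.execute()
    except asyncio.CancelledError:
        raise

Because the EXISTS read happens before the transaction is queued, this is a read-then-write pattern rather than an atomic check-and-set; a stricter variant could WATCH the keys or pass the nx/xx flags to JSON.SET, but that goes beyond what the diff above does.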