@@ -385,20 +385,20 @@ async def aput(
         new_versions: ChannelVersions,
     ) -> RunnableConfig:
         """Store a checkpoint to Redis with proper transaction handling.
-
+
         This method ensures that all Redis operations are performed atomically
         using Redis transactions. In case of interruption (asyncio.CancelledError),
         the transaction will be aborted, ensuring consistency.
-
+
         Args:
             config: The config to associate with the checkpoint
             checkpoint: The checkpoint data to store
             metadata: Additional metadata to save with the checkpoint
             new_versions: New channel versions as of this write
-
+
         Returns:
             Updated configuration after storing the checkpoint
-
+
         Raises:
             asyncio.CancelledError: If the operation is cancelled/interrupted
         """
@@ -431,7 +431,7 @@ async def aput(
         try:
             # Create a pipeline with transaction=True for atomicity
             pipeline = self._redis.pipeline(transaction=True)
-
+
             # Store checkpoint data
             checkpoint_data = {
                 "thread_id": storage_safe_thread_id,
@@ -441,46 +441,46 @@ async def aput(
                 "checkpoint": self._dump_checkpoint(copy),
                 "metadata": self._dump_metadata(metadata),
             }
-
+
             # store at top-level for filters in list()
             if all(key in metadata for key in ["source", "step"]):
                 checkpoint_data["source"] = metadata["source"]
                 checkpoint_data["step"] = metadata["step"]  # type: ignore
-
+
             # Prepare checkpoint key
             checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
                 storage_safe_thread_id,
                 storage_safe_checkpoint_ns,
                 storage_safe_checkpoint_id,
             )
-
+
             # Add checkpoint data to Redis
             await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
-
+
             # Store blob values
             blobs = self._dump_blobs(
                 storage_safe_thread_id,
                 storage_safe_checkpoint_ns,
                 copy.get("channel_values", {}),
                 new_versions,
             )
-
+
             if blobs:
                 # Add all blob operations to the pipeline
                 for key, data in blobs:
                     await pipeline.json().set(key, "$", data)
-
+
             # Execute all operations atomically
             await pipeline.execute()
-
+
             return next_config
-
+
         except asyncio.CancelledError:
             # Handle cancellation/interruption
             # Pipeline will be automatically discarded
             # Either all operations succeed or none do
             raise
-
+
         except Exception as e:
             # Re-raise other exceptions
             raise e
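To see why the CancelledError path needs no cleanup, consider a caller that times out a checkpoint write. This is a hypothetical sketch; `saver` and its arguments are assumed, not defined in this diff:

```python
import asyncio


async def put_with_timeout(saver, config, checkpoint, metadata, versions):
    """Hypothetical caller: cancel aput() if it takes too long."""
    try:
        # wait_for() cancels aput() on timeout. If the cancellation lands
        # before pipeline.execute(), no command has been sent to Redis;
        # if it lands during execute(), MULTI/EXEC still applies the
        # queued commands all-or-nothing.
        return await asyncio.wait_for(
            saver.aput(config, checkpoint, metadata, versions), timeout=1.0
        )
    except asyncio.TimeoutError:
        # No partial state to repair: the checkpoint JSON document and its
        # blobs either all landed, or none of them did.
        return None
```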
@@ -502,13 +502,13 @@ async def aput_writes(
             writes (List[Tuple[str, Any]]): List of writes to store.
             task_id (str): Identifier for the task creating the writes.
             task_path (str): Path of the task creating the writes.
-
+
         Raises:
             asyncio.CancelledError: If the operation is cancelled/interrupted
         """
         if not writes:
             return
-
+
         thread_id = config["configurable"]["thread_id"]
         checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
         checkpoint_id = config["configurable"]["checkpoint_id"]
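For reference, the three `configurable` keys read here imply a config of this shape (the ids below are made up for illustration):

```python
config = {
    "configurable": {
        "thread_id": "thread-1",
        "checkpoint_ns": "",  # optional; .get() falls back to ""
        "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
    }
}
```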
@@ -533,10 +533,10 @@ async def aput_writes(
         try:
             # Use a transaction pipeline for atomicity
             pipeline = self._redis.pipeline(transaction=True)
-
+
             # Determine if this is an upsert case
             upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
-
+
             # Add all write operations to the pipeline
             for write_obj in writes_objects:
                 key = self._make_redis_checkpoint_writes_key(
@@ -546,13 +546,15 @@ async def aput_writes(
                     task_id,
                     write_obj["idx"],  # type: ignore[arg-type]
                 )
-
+
                 if upsert_case:
                     # For upsert case, we need to check if the key exists and update differently
                     exists = await self._redis.exists(key)
                     if exists:
                         # Update existing key
-                        await pipeline.json().set(key, "$.channel", write_obj["channel"])
+                        await pipeline.json().set(
+                            key, "$.channel", write_obj["channel"]
+                        )
                         await pipeline.json().set(key, "$.type", write_obj["type"])
                         await pipeline.json().set(key, "$.blob", write_obj["blob"])
                     else:
@@ -563,16 +565,16 @@ async def aput_writes(
                     exists = await self._redis.exists(key)
                     if not exists:
                         await pipeline.json().set(key, "$", write_obj)
-
+
             # Execute all operations atomically
             await pipeline.execute()
-
+
         except asyncio.CancelledError:
-            # Handle cancellation/interruption
+            # Handle cancellation/interruption
             # Pipeline will be automatically discarded
             # Either all operations succeed or none do
             raise
-
+
         except Exception as e:
             # Re-raise other exceptions
             raise e
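The upsert branch keys off `WRITES_IDX_MAP` from `langgraph.checkpoint.base`, which maps special channels (errors, interrupts, and similar) to fixed indices so repeated writes overwrite in place rather than accumulate. A small sketch of the check, with an assumed regular channel name:

```python
from langgraph.checkpoint.base import WRITES_IDX_MAP

# Every write targeting a special channel -> update existing keys in place.
special = next(iter(WRITES_IDX_MAP))
assert all(w[0] in WRITES_IDX_MAP for w in [(special, "payload")])

# Any regular channel in the batch -> insert-only (skip keys that exist).
assert not all(w[0] in WRITES_IDX_MAP for w in [("messages", "payload")])
```

Note that the `exists()` checks run outside the transaction: the pipeline guarantees the queued JSON writes commit together, not that the existence test and the write form one atomic step.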