24
24
from langgraph .checkpoint .redis .ashallow import AsyncShallowRedisSaver
25
25
from langgraph .checkpoint .redis .base import BaseRedisSaver
26
26
from langgraph .checkpoint .redis .shallow import ShallowRedisSaver
27
+ from langgraph .checkpoint .redis .util import (
28
+ EMPTY_ID_SENTINEL ,
29
+ from_storage_safe_id ,
30
+ from_storage_safe_str ,
31
+ to_storage_safe_id ,
32
+ to_storage_safe_str ,
33
+ )
27
34
from langgraph .checkpoint .redis .version import __lib_name__ , __version__
28
35
29
36
@@ -79,12 +86,21 @@ def list(
79
86
filter_expression = []
80
87
if config :
81
88
filter_expression .append (
82
- Tag ("thread_id" ) == config ["configurable" ]["thread_id" ]
89
+ Tag ("thread_id" )
90
+ == to_storage_safe_id (config ["configurable" ]["thread_id" ])
83
91
)
92
+
93
+ # Reproducing the logic from the Postgres implementation, we'll
94
+ # search for checkpoints with any namespace, including an empty
95
+ # string, while `checkpoint_id` has to have a value.
84
96
if checkpoint_ns := config ["configurable" ].get ("checkpoint_ns" ):
85
- filter_expression .append (Tag ("checkpoint_ns" ) == checkpoint_ns )
97
+ filter_expression .append (
98
+ Tag ("checkpoint_ns" ) == to_storage_safe_str (checkpoint_ns )
99
+ )
86
100
if checkpoint_id := get_checkpoint_id (config ):
87
- filter_expression .append (Tag ("checkpoint_id" ) == checkpoint_id )
101
+ filter_expression .append (
102
+ Tag ("checkpoint_id" ) == to_storage_safe_id (checkpoint_id )
103
+ )
88
104
89
105
if filter :
90
106
for k , v in filter .items ():
@@ -122,9 +138,10 @@ def list(
122
138
123
139
# Process the results
124
140
for doc in results .docs :
125
- thread_id = str (getattr (doc , "thread_id" , "" ))
126
- checkpoint_ns = str (getattr (doc , "checkpoint_ns" , "" ))
127
- checkpoint_id = str (getattr (doc , "checkpoint_id" , "" ))
141
+ thread_id = from_storage_safe_id (doc ["thread_id" ])
142
+ checkpoint_ns = from_storage_safe_str (doc ["checkpoint_ns" ])
143
+ checkpoint_id = from_storage_safe_id (doc ["checkpoint_id" ])
144
+ parent_checkpoint_id = from_storage_safe_id (doc ["parent_checkpoint_id" ])
128
145
129
146
# Fetch channel_values
130
147
channel_values = self .get_channel_values (
@@ -135,11 +152,11 @@ def list(
135
152
136
153
# Fetch pending_sends from parent checkpoint
137
154
pending_sends = []
138
- if doc [ " parent_checkpoint_id" ] :
155
+ if parent_checkpoint_id :
139
156
pending_sends = self ._load_pending_sends (
140
157
thread_id = thread_id ,
141
158
checkpoint_ns = checkpoint_ns ,
142
- parent_checkpoint_id = doc [ " parent_checkpoint_id" ] ,
159
+ parent_checkpoint_id = parent_checkpoint_id ,
143
160
)
144
161
145
162
# Fetch and parse metadata
@@ -163,7 +180,7 @@ def list(
163
180
"configurable" : {
164
181
"thread_id" : thread_id ,
165
182
"checkpoint_ns" : checkpoint_ns ,
166
- "checkpoint_id" : doc [ " checkpoint_id" ] ,
183
+ "checkpoint_id" : checkpoint_id ,
167
184
}
168
185
}
169
186
@@ -194,49 +211,60 @@ def put(
194
211
) -> RunnableConfig :
195
212
"""Store a checkpoint to Redis."""
196
213
configurable = config ["configurable" ].copy ()
214
+
197
215
thread_id = configurable .pop ("thread_id" )
198
216
checkpoint_ns = configurable .pop ("checkpoint_ns" )
199
- checkpoint_id = configurable .pop (
200
- "checkpoint_id" , configurable .pop ("thread_ts" , None )
217
+ checkpoint_id = checkpoint_id = configurable .pop (
218
+ "checkpoint_id" , configurable .pop ("thread_ts" , "" )
201
219
)
202
220
221
+ # For values we store in Redis, we need to convert empty strings to the
222
+ # sentinel value.
223
+ storage_safe_thread_id = to_storage_safe_id (thread_id )
224
+ storage_safe_checkpoint_ns = to_storage_safe_str (checkpoint_ns )
225
+ storage_safe_checkpoint_id = to_storage_safe_id (checkpoint_id )
226
+
203
227
copy = checkpoint .copy ()
228
+ # When we return the config, we need to preserve empty strings that
229
+ # were passed in, instead of the sentinel value.
204
230
next_config = {
205
231
"configurable" : {
206
232
"thread_id" : thread_id ,
207
233
"checkpoint_ns" : checkpoint_ns ,
208
- "checkpoint_id" : checkpoint [ "id" ] ,
234
+ "checkpoint_id" : checkpoint_id ,
209
235
}
210
236
}
211
237
212
- # Store checkpoint data
238
+ # Store checkpoint data.
213
239
checkpoint_data = {
214
- "thread_id" : thread_id ,
215
- "checkpoint_ns" : checkpoint_ns ,
216
- "checkpoint_id" : checkpoint [ "id" ] ,
217
- "parent_checkpoint_id" : checkpoint_id ,
240
+ "thread_id" : storage_safe_thread_id ,
241
+ "checkpoint_ns" : storage_safe_checkpoint_ns ,
242
+ "checkpoint_id" : storage_safe_checkpoint_id ,
243
+ "parent_checkpoint_id" : storage_safe_checkpoint_id ,
218
244
"checkpoint" : self ._dump_checkpoint (copy ),
219
245
"metadata" : self ._dump_metadata (metadata ),
220
246
}
221
247
222
248
# store at top-level for filters in list()
223
249
if all (key in metadata for key in ["source" , "step" ]):
224
250
checkpoint_data ["source" ] = metadata ["source" ]
225
- checkpoint_data ["step" ] = metadata ["step" ]
251
+ checkpoint_data ["step" ] = metadata ["step" ] # type: ignore
226
252
227
253
self .checkpoints_index .load (
228
254
[checkpoint_data ],
229
255
keys = [
230
256
BaseRedisSaver ._make_redis_checkpoint_key (
231
- thread_id , checkpoint_ns , checkpoint ["id" ]
257
+ storage_safe_thread_id ,
258
+ storage_safe_checkpoint_ns ,
259
+ storage_safe_checkpoint_id ,
232
260
)
233
261
],
234
262
)
235
263
236
- # Store blob values
264
+ # Store blob values.
237
265
blobs = self ._dump_blobs (
238
- thread_id ,
239
- checkpoint_ns ,
266
+ storage_safe_thread_id ,
267
+ storage_safe_checkpoint_ns ,
240
268
copy .get ("channel_values" , {}),
241
269
new_versions ,
242
270
)
@@ -258,19 +286,22 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
258
286
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
259
287
"""
260
288
thread_id = config ["configurable" ]["thread_id" ]
261
- checkpoint_id = str ( get_checkpoint_id (config ) )
289
+ checkpoint_id = get_checkpoint_id (config )
262
290
checkpoint_ns = config ["configurable" ].get ("checkpoint_ns" , "" )
263
291
264
- if checkpoint_id :
292
+ ascending = True
293
+
294
+ if checkpoint_id and checkpoint_id != EMPTY_ID_SENTINEL :
265
295
checkpoint_filter_expression = (
266
- (Tag ("thread_id" ) == thread_id )
267
- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
268
- & (Tag ("checkpoint_id" ) == checkpoint_id )
296
+ (Tag ("thread_id" ) == to_storage_safe_id ( thread_id ) )
297
+ & (Tag ("checkpoint_ns" ) == to_storage_safe_str ( checkpoint_ns ) )
298
+ & (Tag ("checkpoint_id" ) == to_storage_safe_id ( checkpoint_id ) )
269
299
)
270
300
else :
271
- checkpoint_filter_expression = (Tag ("thread_id" ) == thread_id ) & (
272
- Tag ("checkpoint_ns" ) == checkpoint_ns
273
- )
301
+ checkpoint_filter_expression = (
302
+ Tag ("thread_id" ) == to_storage_safe_id (thread_id )
303
+ ) & (Tag ("checkpoint_ns" ) == to_storage_safe_str (checkpoint_ns ))
304
+ ascending = False
274
305
275
306
# Construct the query
276
307
checkpoints_query = FilterQuery (
@@ -285,29 +316,33 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
285
316
],
286
317
num_results = 1 ,
287
318
)
288
- checkpoints_query .sort_by ("checkpoint_id" , asc = False )
319
+ checkpoints_query .sort_by ("checkpoint_id" , asc = ascending )
289
320
290
321
# Execute the query
291
322
results = self .checkpoints_index .search (checkpoints_query )
292
323
if not results .docs :
293
324
return None
294
325
295
326
doc = results .docs [0 ]
327
+ doc_thread_id = from_storage_safe_id (doc ["thread_id" ])
328
+ doc_checkpoint_ns = from_storage_safe_str (doc ["checkpoint_ns" ])
329
+ doc_checkpoint_id = from_storage_safe_id (doc ["checkpoint_id" ])
330
+ doc_parent_checkpoint_id = from_storage_safe_id (doc ["parent_checkpoint_id" ])
296
331
297
332
# Fetch channel_values
298
333
channel_values = self .get_channel_values (
299
- thread_id = doc [ "thread_id" ] ,
300
- checkpoint_ns = doc [ "checkpoint_ns" ] ,
301
- checkpoint_id = doc [ "checkpoint_id" ] ,
334
+ thread_id = doc_thread_id ,
335
+ checkpoint_ns = doc_checkpoint_ns ,
336
+ checkpoint_id = doc_checkpoint_id ,
302
337
)
303
338
304
339
# Fetch pending_sends from parent checkpoint
305
340
pending_sends = []
306
- if doc [ "parent_checkpoint_id" ] :
341
+ if doc_parent_checkpoint_id :
307
342
pending_sends = self ._load_pending_sends (
308
- thread_id = thread_id ,
309
- checkpoint_ns = checkpoint_ns ,
310
- parent_checkpoint_id = doc [ "parent_checkpoint_id" ] ,
343
+ thread_id = doc_thread_id ,
344
+ checkpoint_ns = doc_checkpoint_ns ,
345
+ parent_checkpoint_id = doc_parent_checkpoint_id ,
311
346
)
312
347
313
348
# Fetch and parse metadata
@@ -329,7 +364,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
329
364
"configurable" : {
330
365
"thread_id" : thread_id ,
331
366
"checkpoint_ns" : checkpoint_ns ,
332
- "checkpoint_id" : doc [ "checkpoint_id" ] ,
367
+ "checkpoint_id" : doc_checkpoint_id ,
333
368
}
334
369
}
335
370
@@ -340,7 +375,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
340
375
)
341
376
342
377
pending_writes = self ._load_pending_writes (
343
- thread_id , checkpoint_ns , checkpoint_id
378
+ thread_id , checkpoint_ns , doc_checkpoint_id
344
379
)
345
380
346
381
return CheckpointTuple (
@@ -379,10 +414,14 @@ def get_channel_values(
379
414
self , thread_id : str , checkpoint_ns : str = "" , checkpoint_id : str = ""
380
415
) -> dict [str , Any ]:
381
416
"""Retrieve channel_values dictionary with properly constructed message objects."""
417
+ storage_safe_thread_id = to_storage_safe_id (thread_id )
418
+ storage_safe_checkpoint_ns = to_storage_safe_str (checkpoint_ns )
419
+ storage_safe_checkpoint_id = to_storage_safe_id (checkpoint_id )
420
+
382
421
checkpoint_query = FilterQuery (
383
- filter_expression = (Tag ("thread_id" ) == thread_id )
384
- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
385
- & (Tag ("checkpoint_id" ) == checkpoint_id ),
422
+ filter_expression = (Tag ("thread_id" ) == storage_safe_thread_id )
423
+ & (Tag ("checkpoint_ns" ) == storage_safe_checkpoint_ns )
424
+ & (Tag ("checkpoint_id" ) == storage_safe_checkpoint_id ),
386
425
return_fields = ["$.checkpoint.channel_versions" ],
387
426
num_results = 1 ,
388
427
)
@@ -400,8 +439,8 @@ def get_channel_values(
400
439
channel_values = {}
401
440
for channel , version in channel_versions .items ():
402
441
blob_query = FilterQuery (
403
- filter_expression = (Tag ("thread_id" ) == thread_id )
404
- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
442
+ filter_expression = (Tag ("thread_id" ) == storage_safe_thread_id )
443
+ & (Tag ("checkpoint_ns" ) == storage_safe_checkpoint_ns )
405
444
& (Tag ("channel" ) == channel )
406
445
& (Tag ("version" ) == version ),
407
446
return_fields = ["type" , "$.blob" ],
@@ -437,11 +476,15 @@ def _load_pending_sends(
437
476
Returns:
438
477
List of (type, blob) tuples representing pending sends
439
478
"""
479
+ storage_safe_thread_id = to_storage_safe_str (thread_id )
480
+ storage_safe_checkpoint_ns = to_storage_safe_str (checkpoint_ns )
481
+ storage_safe_parent_checkpoint_id = to_storage_safe_str (parent_checkpoint_id )
482
+
440
483
# Query checkpoint_writes for parent checkpoint's TASKS channel
441
484
parent_writes_query = FilterQuery (
442
- filter_expression = (Tag ("thread_id" ) == thread_id )
443
- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
444
- & (Tag ("checkpoint_id" ) == parent_checkpoint_id )
485
+ filter_expression = (Tag ("thread_id" ) == storage_safe_thread_id )
486
+ & (Tag ("checkpoint_ns" ) == storage_safe_checkpoint_ns )
487
+ & (Tag ("checkpoint_id" ) == storage_safe_parent_checkpoint_id )
445
488
& (Tag ("channel" ) == TASKS ),
446
489
return_fields = ["type" , "blob" , "task_path" , "task_id" , "idx" ],
447
490
num_results = 100 , # Adjust as needed
0 commit comments