Skip to content

Commit d9c566a

Browse files
authored
Handle missing config keys during checkpoint search (#10)
* Handle missing config keys during checkpoint search When queried for checkpoints of a root graph, the `config` object the checkpointer receives will not have a `checkpoint_id` or `checkpoint_ns` key. Currently, we interpret their absence incorrectly by searching for the string literal "None", which doesn't find a matching checkpoint. Unfortunately, older RediSearch versions don't support the INDEXEMPTY and INDEXMISSING options introduced in version 2.10, and the library this package uses, RedisVL, doesn't support those options yet. In this PR, we introduce sentinel values for strings and IDs that allow querying for empty values. This solves the problem with some loss of elegance and grace.
1 parent 2763812 commit d9c566a

File tree

9 files changed

+472
-120
lines changed

9 files changed

+472
-120
lines changed

.github/workflows/test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@ on:
77
branches:
88
- main
99

10+
schedule:
11+
- cron: "0 2 * * *" # 2 AM UTC nightly
12+
1013
workflow_dispatch:
1114

15+
1216
env:
1317
POETRY_VERSION: "1.8.3"
1418

langgraph/checkpoint/redis/__init__.py

Lines changed: 91 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@
2424
from langgraph.checkpoint.redis.ashallow import AsyncShallowRedisSaver
2525
from langgraph.checkpoint.redis.base import BaseRedisSaver
2626
from langgraph.checkpoint.redis.shallow import ShallowRedisSaver
27+
from langgraph.checkpoint.redis.util import (
28+
EMPTY_ID_SENTINEL,
29+
from_storage_safe_id,
30+
from_storage_safe_str,
31+
to_storage_safe_id,
32+
to_storage_safe_str,
33+
)
2734
from langgraph.checkpoint.redis.version import __lib_name__, __version__
2835

2936

@@ -79,12 +86,21 @@ def list(
7986
filter_expression = []
8087
if config:
8188
filter_expression.append(
82-
Tag("thread_id") == config["configurable"]["thread_id"]
89+
Tag("thread_id")
90+
== to_storage_safe_id(config["configurable"]["thread_id"])
8391
)
92+
93+
# Reproducing the logic from the Postgres implementation, we'll
94+
# search for checkpoints with any namespace, including an empty
95+
# string, while `checkpoint_id` has to have a value.
8496
if checkpoint_ns := config["configurable"].get("checkpoint_ns"):
85-
filter_expression.append(Tag("checkpoint_ns") == checkpoint_ns)
97+
filter_expression.append(
98+
Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns)
99+
)
86100
if checkpoint_id := get_checkpoint_id(config):
87-
filter_expression.append(Tag("checkpoint_id") == checkpoint_id)
101+
filter_expression.append(
102+
Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id)
103+
)
88104

89105
if filter:
90106
for k, v in filter.items():
@@ -122,9 +138,10 @@ def list(
122138

123139
# Process the results
124140
for doc in results.docs:
125-
thread_id = str(getattr(doc, "thread_id", ""))
126-
checkpoint_ns = str(getattr(doc, "checkpoint_ns", ""))
127-
checkpoint_id = str(getattr(doc, "checkpoint_id", ""))
141+
thread_id = from_storage_safe_id(doc["thread_id"])
142+
checkpoint_ns = from_storage_safe_str(doc["checkpoint_ns"])
143+
checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
144+
parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])
128145

129146
# Fetch channel_values
130147
channel_values = self.get_channel_values(
@@ -135,11 +152,11 @@ def list(
135152

136153
# Fetch pending_sends from parent checkpoint
137154
pending_sends = []
138-
if doc["parent_checkpoint_id"]:
155+
if parent_checkpoint_id:
139156
pending_sends = self._load_pending_sends(
140157
thread_id=thread_id,
141158
checkpoint_ns=checkpoint_ns,
142-
parent_checkpoint_id=doc["parent_checkpoint_id"],
159+
parent_checkpoint_id=parent_checkpoint_id,
143160
)
144161

145162
# Fetch and parse metadata
@@ -163,7 +180,7 @@ def list(
163180
"configurable": {
164181
"thread_id": thread_id,
165182
"checkpoint_ns": checkpoint_ns,
166-
"checkpoint_id": doc["checkpoint_id"],
183+
"checkpoint_id": checkpoint_id,
167184
}
168185
}
169186

@@ -194,49 +211,60 @@ def put(
194211
) -> RunnableConfig:
195212
"""Store a checkpoint to Redis."""
196213
configurable = config["configurable"].copy()
214+
197215
thread_id = configurable.pop("thread_id")
198216
checkpoint_ns = configurable.pop("checkpoint_ns")
199-
checkpoint_id = configurable.pop(
200-
"checkpoint_id", configurable.pop("thread_ts", None)
217+
checkpoint_id = checkpoint_id = configurable.pop(
218+
"checkpoint_id", configurable.pop("thread_ts", "")
201219
)
202220

221+
# For values we store in Redis, we need to convert empty strings to the
222+
# sentinel value.
223+
storage_safe_thread_id = to_storage_safe_id(thread_id)
224+
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
225+
storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)
226+
203227
copy = checkpoint.copy()
228+
# When we return the config, we need to preserve empty strings that
229+
# were passed in, instead of the sentinel value.
204230
next_config = {
205231
"configurable": {
206232
"thread_id": thread_id,
207233
"checkpoint_ns": checkpoint_ns,
208-
"checkpoint_id": checkpoint["id"],
234+
"checkpoint_id": checkpoint_id,
209235
}
210236
}
211237

212-
# Store checkpoint data
238+
# Store checkpoint data.
213239
checkpoint_data = {
214-
"thread_id": thread_id,
215-
"checkpoint_ns": checkpoint_ns,
216-
"checkpoint_id": checkpoint["id"],
217-
"parent_checkpoint_id": checkpoint_id,
240+
"thread_id": storage_safe_thread_id,
241+
"checkpoint_ns": storage_safe_checkpoint_ns,
242+
"checkpoint_id": storage_safe_checkpoint_id,
243+
"parent_checkpoint_id": storage_safe_checkpoint_id,
218244
"checkpoint": self._dump_checkpoint(copy),
219245
"metadata": self._dump_metadata(metadata),
220246
}
221247

222248
# store at top-level for filters in list()
223249
if all(key in metadata for key in ["source", "step"]):
224250
checkpoint_data["source"] = metadata["source"]
225-
checkpoint_data["step"] = metadata["step"]
251+
checkpoint_data["step"] = metadata["step"] # type: ignore
226252

227253
self.checkpoints_index.load(
228254
[checkpoint_data],
229255
keys=[
230256
BaseRedisSaver._make_redis_checkpoint_key(
231-
thread_id, checkpoint_ns, checkpoint["id"]
257+
storage_safe_thread_id,
258+
storage_safe_checkpoint_ns,
259+
storage_safe_checkpoint_id,
232260
)
233261
],
234262
)
235263

236-
# Store blob values
264+
# Store blob values.
237265
blobs = self._dump_blobs(
238-
thread_id,
239-
checkpoint_ns,
266+
storage_safe_thread_id,
267+
storage_safe_checkpoint_ns,
240268
copy.get("channel_values", {}),
241269
new_versions,
242270
)
@@ -258,19 +286,22 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
258286
Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
259287
"""
260288
thread_id = config["configurable"]["thread_id"]
261-
checkpoint_id = str(get_checkpoint_id(config))
289+
checkpoint_id = get_checkpoint_id(config)
262290
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
263291

264-
if checkpoint_id:
292+
ascending = True
293+
294+
if checkpoint_id and checkpoint_id != EMPTY_ID_SENTINEL:
265295
checkpoint_filter_expression = (
266-
(Tag("thread_id") == thread_id)
267-
& (Tag("checkpoint_ns") == checkpoint_ns)
268-
& (Tag("checkpoint_id") == checkpoint_id)
296+
(Tag("thread_id") == to_storage_safe_id(thread_id))
297+
& (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns))
298+
& (Tag("checkpoint_id") == to_storage_safe_id(checkpoint_id))
269299
)
270300
else:
271-
checkpoint_filter_expression = (Tag("thread_id") == thread_id) & (
272-
Tag("checkpoint_ns") == checkpoint_ns
273-
)
301+
checkpoint_filter_expression = (
302+
Tag("thread_id") == to_storage_safe_id(thread_id)
303+
) & (Tag("checkpoint_ns") == to_storage_safe_str(checkpoint_ns))
304+
ascending = False
274305

275306
# Construct the query
276307
checkpoints_query = FilterQuery(
@@ -285,29 +316,33 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
285316
],
286317
num_results=1,
287318
)
288-
checkpoints_query.sort_by("checkpoint_id", asc=False)
319+
checkpoints_query.sort_by("checkpoint_id", asc=ascending)
289320

290321
# Execute the query
291322
results = self.checkpoints_index.search(checkpoints_query)
292323
if not results.docs:
293324
return None
294325

295326
doc = results.docs[0]
327+
doc_thread_id = from_storage_safe_id(doc["thread_id"])
328+
doc_checkpoint_ns = from_storage_safe_str(doc["checkpoint_ns"])
329+
doc_checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
330+
doc_parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])
296331

297332
# Fetch channel_values
298333
channel_values = self.get_channel_values(
299-
thread_id=doc["thread_id"],
300-
checkpoint_ns=doc["checkpoint_ns"],
301-
checkpoint_id=doc["checkpoint_id"],
334+
thread_id=doc_thread_id,
335+
checkpoint_ns=doc_checkpoint_ns,
336+
checkpoint_id=doc_checkpoint_id,
302337
)
303338

304339
# Fetch pending_sends from parent checkpoint
305340
pending_sends = []
306-
if doc["parent_checkpoint_id"]:
341+
if doc_parent_checkpoint_id:
307342
pending_sends = self._load_pending_sends(
308-
thread_id=thread_id,
309-
checkpoint_ns=checkpoint_ns,
310-
parent_checkpoint_id=doc["parent_checkpoint_id"],
343+
thread_id=doc_thread_id,
344+
checkpoint_ns=doc_checkpoint_ns,
345+
parent_checkpoint_id=doc_parent_checkpoint_id,
311346
)
312347

313348
# Fetch and parse metadata
@@ -329,7 +364,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
329364
"configurable": {
330365
"thread_id": thread_id,
331366
"checkpoint_ns": checkpoint_ns,
332-
"checkpoint_id": doc["checkpoint_id"],
367+
"checkpoint_id": doc_checkpoint_id,
333368
}
334369
}
335370

@@ -340,7 +375,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
340375
)
341376

342377
pending_writes = self._load_pending_writes(
343-
thread_id, checkpoint_ns, checkpoint_id
378+
thread_id, checkpoint_ns, doc_checkpoint_id
344379
)
345380

346381
return CheckpointTuple(
@@ -379,10 +414,14 @@ def get_channel_values(
379414
self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""
380415
) -> dict[str, Any]:
381416
"""Retrieve channel_values dictionary with properly constructed message objects."""
417+
storage_safe_thread_id = to_storage_safe_id(thread_id)
418+
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
419+
storage_safe_checkpoint_id = to_storage_safe_id(checkpoint_id)
420+
382421
checkpoint_query = FilterQuery(
383-
filter_expression=(Tag("thread_id") == thread_id)
384-
& (Tag("checkpoint_ns") == checkpoint_ns)
385-
& (Tag("checkpoint_id") == checkpoint_id),
422+
filter_expression=(Tag("thread_id") == storage_safe_thread_id)
423+
& (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
424+
& (Tag("checkpoint_id") == storage_safe_checkpoint_id),
386425
return_fields=["$.checkpoint.channel_versions"],
387426
num_results=1,
388427
)
@@ -400,8 +439,8 @@ def get_channel_values(
400439
channel_values = {}
401440
for channel, version in channel_versions.items():
402441
blob_query = FilterQuery(
403-
filter_expression=(Tag("thread_id") == thread_id)
404-
& (Tag("checkpoint_ns") == checkpoint_ns)
442+
filter_expression=(Tag("thread_id") == storage_safe_thread_id)
443+
& (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
405444
& (Tag("channel") == channel)
406445
& (Tag("version") == version),
407446
return_fields=["type", "$.blob"],
@@ -437,11 +476,15 @@ def _load_pending_sends(
437476
Returns:
438477
List of (type, blob) tuples representing pending sends
439478
"""
479+
storage_safe_thread_id = to_storage_safe_str(thread_id)
480+
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
481+
storage_safe_parent_checkpoint_id = to_storage_safe_str(parent_checkpoint_id)
482+
440483
# Query checkpoint_writes for parent checkpoint's TASKS channel
441484
parent_writes_query = FilterQuery(
442-
filter_expression=(Tag("thread_id") == thread_id)
443-
& (Tag("checkpoint_ns") == checkpoint_ns)
444-
& (Tag("checkpoint_id") == parent_checkpoint_id)
485+
filter_expression=(Tag("thread_id") == storage_safe_thread_id)
486+
& (Tag("checkpoint_ns") == storage_safe_checkpoint_ns)
487+
& (Tag("checkpoint_id") == storage_safe_parent_checkpoint_id)
445488
& (Tag("channel") == TASKS),
446489
return_fields=["type", "blob", "task_path", "task_id", "idx"],
447490
num_results=100, # Adjust as needed

0 commit comments

Comments
 (0)