Skip to content

Commit 21a34fd

Browse files
committed
feat(checkpoint-redis): implement adelete_thread and delete_thread methods (#51)
- Add adelete_thread method to AsyncRedisSaver to delete all checkpoints, blobs, and writes for a thread - Add delete_thread method to RedisSaver for sync operations - Use Redis search indices instead of keys() command for better performance - Batch deletions using Redis pipeline for efficiency
1 parent 35adab0 commit 21a34fd

File tree

3 files changed

+384
-0
lines changed

3 files changed

+384
-0
lines changed

langgraph/checkpoint/redis/__init__.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,78 @@ def _load_pending_sends(
575575
# Extract type and blob pairs
576576
return [(doc.type, doc.blob) for doc in sorted_writes]
577577

578+
def delete_thread(self, thread_id: str) -> None:
579+
"""Delete all checkpoints and writes associated with a specific thread ID.
580+
581+
Args:
582+
thread_id: The thread ID whose checkpoints should be deleted.
583+
"""
584+
storage_safe_thread_id = to_storage_safe_id(thread_id)
585+
586+
# Delete all checkpoints for this thread
587+
checkpoint_query = FilterQuery(
588+
filter_expression=Tag("thread_id") == storage_safe_thread_id,
589+
return_fields=["checkpoint_ns", "checkpoint_id"],
590+
num_results=10000, # Get all checkpoints for this thread
591+
)
592+
593+
checkpoint_results = self.checkpoints_index.search(checkpoint_query)
594+
595+
# Delete all checkpoint-related keys
596+
pipeline = self._redis.pipeline()
597+
598+
for doc in checkpoint_results.docs:
599+
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
600+
checkpoint_id = getattr(doc, "checkpoint_id", "")
601+
602+
# Delete checkpoint key
603+
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
604+
storage_safe_thread_id, checkpoint_ns, checkpoint_id
605+
)
606+
pipeline.delete(checkpoint_key)
607+
608+
# Delete all blobs for this thread
609+
blob_query = FilterQuery(
610+
filter_expression=Tag("thread_id") == storage_safe_thread_id,
611+
return_fields=["checkpoint_ns", "channel", "version"],
612+
num_results=10000,
613+
)
614+
615+
blob_results = self.checkpoint_blobs_index.search(blob_query)
616+
617+
for doc in blob_results.docs:
618+
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
619+
channel = getattr(doc, "channel", "")
620+
version = getattr(doc, "version", "")
621+
622+
blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
623+
storage_safe_thread_id, checkpoint_ns, channel, version
624+
)
625+
pipeline.delete(blob_key)
626+
627+
# Delete all writes for this thread
628+
writes_query = FilterQuery(
629+
filter_expression=Tag("thread_id") == storage_safe_thread_id,
630+
return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
631+
num_results=10000,
632+
)
633+
634+
writes_results = self.checkpoint_writes_index.search(writes_query)
635+
636+
for doc in writes_results.docs:
637+
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
638+
checkpoint_id = getattr(doc, "checkpoint_id", "")
639+
task_id = getattr(doc, "task_id", "")
640+
idx = getattr(doc, "idx", 0)
641+
642+
write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
643+
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
644+
)
645+
pipeline.delete(write_key)
646+
647+
# Execute all deletions
648+
pipeline.execute()
649+
578650

579651
__all__ = [
580652
"__version__",

langgraph/checkpoint/redis/aio.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,3 +925,75 @@ async def _aload_pending_writes(
925925

926926
pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict)
927927
return pending_writes
928+
929+
async def adelete_thread(self, thread_id: str) -> None:
930+
"""Delete all checkpoints and writes associated with a specific thread ID.
931+
932+
Args:
933+
thread_id: The thread ID whose checkpoints should be deleted.
934+
"""
935+
storage_safe_thread_id = to_storage_safe_id(thread_id)
936+
937+
# Delete all checkpoints for this thread
938+
checkpoint_query = FilterQuery(
939+
filter_expression=Tag("thread_id") == storage_safe_thread_id,
940+
return_fields=["checkpoint_ns", "checkpoint_id"],
941+
num_results=10000, # Get all checkpoints for this thread
942+
)
943+
944+
checkpoint_results = await self.checkpoints_index.search(checkpoint_query)
945+
946+
# Delete all checkpoint-related keys
947+
pipeline = self._redis.pipeline()
948+
949+
for doc in checkpoint_results.docs:
950+
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
951+
checkpoint_id = getattr(doc, "checkpoint_id", "")
952+
953+
# Delete checkpoint key
954+
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
955+
storage_safe_thread_id, checkpoint_ns, checkpoint_id
956+
)
957+
pipeline.delete(checkpoint_key)
958+
959+
# Delete all blobs for this thread
960+
blob_query = FilterQuery(
961+
filter_expression=Tag("thread_id") == storage_safe_thread_id,
962+
return_fields=["checkpoint_ns", "channel", "version"],
963+
num_results=10000,
964+
)
965+
966+
blob_results = await self.checkpoint_blobs_index.search(blob_query)
967+
968+
for doc in blob_results.docs:
969+
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
970+
channel = getattr(doc, "channel", "")
971+
version = getattr(doc, "version", "")
972+
973+
blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
974+
storage_safe_thread_id, checkpoint_ns, channel, version
975+
)
976+
pipeline.delete(blob_key)
977+
978+
# Delete all writes for this thread
979+
writes_query = FilterQuery(
980+
filter_expression=Tag("thread_id") == storage_safe_thread_id,
981+
return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
982+
num_results=10000,
983+
)
984+
985+
writes_results = await self.checkpoint_writes_index.search(writes_query)
986+
987+
for doc in writes_results.docs:
988+
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
989+
checkpoint_id = getattr(doc, "checkpoint_id", "")
990+
task_id = getattr(doc, "task_id", "")
991+
idx = getattr(doc, "idx", 0)
992+
993+
write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
994+
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
995+
)
996+
pipeline.delete(write_key)
997+
998+
# Execute all deletions
999+
await pipeline.execute()

tests/test_issue_51_adelete_thread.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
"""Test for issue #51 - adelete_thread implementation."""
2+
3+
import pytest
4+
from langchain_core.runnables import RunnableConfig
5+
from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata
6+
7+
from langgraph.checkpoint.redis import RedisSaver
8+
from langgraph.checkpoint.redis.aio import AsyncRedisSaver
9+
10+
11+
@pytest.mark.asyncio
12+
async def test_adelete_thread_implemented(redis_url):
13+
"""Test that adelete_thread method is now implemented in AsyncRedisSaver."""
14+
async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer:
15+
# Create a checkpoint
16+
thread_id = "test-thread-to-delete"
17+
config: RunnableConfig = {
18+
"configurable": {
19+
"thread_id": thread_id,
20+
"checkpoint_ns": "",
21+
"checkpoint_id": "1",
22+
}
23+
}
24+
25+
checkpoint = Checkpoint(
26+
v=1,
27+
id="1",
28+
ts="2024-01-01T00:00:00Z",
29+
channel_values={"messages": ["Test"]},
30+
channel_versions={"messages": "1"},
31+
versions_seen={"agent": {"messages": "1"}},
32+
pending_sends=[],
33+
tasks=[],
34+
)
35+
36+
# Store checkpoint
37+
await checkpointer.aput(
38+
config=config,
39+
checkpoint=checkpoint,
40+
metadata=CheckpointMetadata(source="input", step=0, writes={}),
41+
new_versions={"messages": "1"},
42+
)
43+
44+
# Verify checkpoint exists
45+
result = await checkpointer.aget_tuple(config)
46+
assert result is not None
47+
assert result.checkpoint["id"] == "1"
48+
49+
# Delete the thread
50+
await checkpointer.adelete_thread(thread_id)
51+
52+
# Verify checkpoint is deleted
53+
result = await checkpointer.aget_tuple(config)
54+
assert result is None
55+
56+
57+
def test_delete_thread_implemented(redis_url):
58+
"""Test that delete_thread method is now implemented in RedisSaver."""
59+
with RedisSaver.from_conn_string(redis_url) as checkpointer:
60+
checkpointer.setup() # Initialize Redis indices
61+
62+
# Create a checkpoint
63+
thread_id = "test-thread-to-delete-sync"
64+
config: RunnableConfig = {
65+
"configurable": {
66+
"thread_id": thread_id,
67+
"checkpoint_ns": "",
68+
"checkpoint_id": "1",
69+
}
70+
}
71+
72+
checkpoint = Checkpoint(
73+
v=1,
74+
id="1",
75+
ts="2024-01-01T00:00:00Z",
76+
channel_values={"messages": ["Test"]},
77+
channel_versions={"messages": "1"},
78+
versions_seen={"agent": {"messages": "1"}},
79+
pending_sends=[],
80+
tasks=[],
81+
)
82+
83+
# Store checkpoint
84+
checkpointer.put(
85+
config=config,
86+
checkpoint=checkpoint,
87+
metadata=CheckpointMetadata(source="input", step=0, writes={}),
88+
new_versions={"messages": "1"},
89+
)
90+
91+
# Verify checkpoint exists
92+
result = checkpointer.get_tuple(config)
93+
assert result is not None
94+
assert result.checkpoint["id"] == "1"
95+
96+
# Delete the thread
97+
checkpointer.delete_thread(thread_id)
98+
99+
# Verify checkpoint is deleted
100+
result = checkpointer.get_tuple(config)
101+
assert result is None
102+
103+
104+
@pytest.mark.asyncio
105+
async def test_adelete_thread_comprehensive(redis_url):
106+
"""Comprehensive test for adelete_thread with multiple checkpoints and namespaces."""
107+
async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer:
108+
thread_id = "test-thread-comprehensive"
109+
other_thread_id = "other-thread"
110+
111+
# Create multiple checkpoints for the thread
112+
checkpoints_data = [
113+
("", "1", {"messages": ["First"]}, "input", 0),
114+
("", "2", {"messages": ["Second"]}, "output", 1),
115+
("ns1", "3", {"messages": ["Third"]}, "input", 0),
116+
("ns2", "4", {"messages": ["Fourth"]}, "output", 1),
117+
]
118+
119+
# Also create checkpoints for another thread that should not be deleted
120+
other_checkpoints_data = [
121+
("", "5", {"messages": ["Other1"]}, "input", 0),
122+
("ns1", "6", {"messages": ["Other2"]}, "output", 1),
123+
]
124+
125+
# Store all checkpoints
126+
for ns, cp_id, channel_values, source, step in checkpoints_data:
127+
config: RunnableConfig = {
128+
"configurable": {
129+
"thread_id": thread_id,
130+
"checkpoint_ns": ns,
131+
"checkpoint_id": cp_id,
132+
}
133+
}
134+
135+
checkpoint = Checkpoint(
136+
v=1,
137+
id=cp_id,
138+
ts=f"2024-01-01T00:00:0{cp_id}Z",
139+
channel_values=channel_values,
140+
channel_versions={"messages": "1"},
141+
versions_seen={"agent": {"messages": "1"}},
142+
pending_sends=[],
143+
tasks=[],
144+
)
145+
146+
await checkpointer.aput(
147+
config=config,
148+
checkpoint=checkpoint,
149+
metadata=CheckpointMetadata(source=source, step=step, writes={}),
150+
new_versions={"messages": "1"},
151+
)
152+
153+
# Also add some writes
154+
await checkpointer.aput_writes(
155+
config=config,
156+
writes=[("messages", f"Write for {cp_id}")],
157+
task_id=f"task-{cp_id}",
158+
)
159+
160+
# Store checkpoints for other thread
161+
for ns, cp_id, channel_values, source, step in other_checkpoints_data:
162+
config: RunnableConfig = {
163+
"configurable": {
164+
"thread_id": other_thread_id,
165+
"checkpoint_ns": ns,
166+
"checkpoint_id": cp_id,
167+
}
168+
}
169+
170+
checkpoint = Checkpoint(
171+
v=1,
172+
id=cp_id,
173+
ts=f"2024-01-01T00:00:0{cp_id}Z",
174+
channel_values=channel_values,
175+
channel_versions={"messages": "1"},
176+
versions_seen={"agent": {"messages": "1"}},
177+
pending_sends=[],
178+
tasks=[],
179+
)
180+
181+
await checkpointer.aput(
182+
config=config,
183+
checkpoint=checkpoint,
184+
metadata=CheckpointMetadata(source=source, step=step, writes={}),
185+
new_versions={"messages": "1"},
186+
)
187+
188+
# Verify all checkpoints exist
189+
for ns, cp_id, _, _, _ in checkpoints_data:
190+
config: RunnableConfig = {
191+
"configurable": {
192+
"thread_id": thread_id,
193+
"checkpoint_ns": ns,
194+
"checkpoint_id": cp_id,
195+
}
196+
}
197+
result = await checkpointer.aget_tuple(config)
198+
assert result is not None
199+
assert result.checkpoint["id"] == cp_id
200+
201+
# Verify other thread checkpoints exist
202+
for ns, cp_id, _, _, _ in other_checkpoints_data:
203+
config: RunnableConfig = {
204+
"configurable": {
205+
"thread_id": other_thread_id,
206+
"checkpoint_ns": ns,
207+
"checkpoint_id": cp_id,
208+
}
209+
}
210+
result = await checkpointer.aget_tuple(config)
211+
assert result is not None
212+
assert result.checkpoint["id"] == cp_id
213+
214+
# Delete the thread
215+
await checkpointer.adelete_thread(thread_id)
216+
217+
# Verify all checkpoints for the thread are deleted
218+
for ns, cp_id, _, _, _ in checkpoints_data:
219+
config: RunnableConfig = {
220+
"configurable": {
221+
"thread_id": thread_id,
222+
"checkpoint_ns": ns,
223+
"checkpoint_id": cp_id,
224+
}
225+
}
226+
result = await checkpointer.aget_tuple(config)
227+
assert result is None
228+
229+
# Verify other thread checkpoints still exist
230+
for ns, cp_id, _, _, _ in other_checkpoints_data:
231+
config: RunnableConfig = {
232+
"configurable": {
233+
"thread_id": other_thread_id,
234+
"checkpoint_ns": ns,
235+
"checkpoint_id": cp_id,
236+
}
237+
}
238+
result = await checkpointer.aget_tuple(config)
239+
assert result is not None
240+
assert result.checkpoint["id"] == cp_id

0 commit comments

Comments
 (0)