-
Notifications
You must be signed in to change notification settings - Fork 10
feat(checkpoint-redis): implement adelete_thread and delete_thread … #54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -575,6 +575,78 @@ def _load_pending_sends( | |||||
# Extract type and blob pairs | ||||||
return [(doc.type, doc.blob) for doc in sorted_writes] | ||||||
|
||||||
def delete_thread(self, thread_id: str) -> None: | ||||||
"""Delete all checkpoints and writes associated with a specific thread ID. | ||||||
|
||||||
Args: | ||||||
thread_id: The thread ID whose checkpoints should be deleted. | ||||||
""" | ||||||
storage_safe_thread_id = to_storage_safe_id(thread_id) | ||||||
|
||||||
# Delete all checkpoints for this thread | ||||||
checkpoint_query = FilterQuery( | ||||||
filter_expression=Tag("thread_id") == storage_safe_thread_id, | ||||||
return_fields=["checkpoint_ns", "checkpoint_id"], | ||||||
num_results=10000, # Get all checkpoints for this thread | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider using a named constant instead of hardcoding 10000 to enhance clarity and ease future adjustments.
Suggested change
Copilot uses AI. Check for mistakes. Positive Feedback | Negative Feedback |
||||||
) | ||||||
|
||||||
checkpoint_results = self.checkpoints_index.search(checkpoint_query) | ||||||
|
||||||
# Delete all checkpoint-related keys | ||||||
pipeline = self._redis.pipeline() | ||||||
|
||||||
for doc in checkpoint_results.docs: | ||||||
checkpoint_ns = getattr(doc, "checkpoint_ns", "") | ||||||
checkpoint_id = getattr(doc, "checkpoint_id", "") | ||||||
|
||||||
# Delete checkpoint key | ||||||
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key( | ||||||
storage_safe_thread_id, checkpoint_ns, checkpoint_id | ||||||
) | ||||||
pipeline.delete(checkpoint_key) | ||||||
|
||||||
# Delete all blobs for this thread | ||||||
blob_query = FilterQuery( | ||||||
filter_expression=Tag("thread_id") == storage_safe_thread_id, | ||||||
return_fields=["checkpoint_ns", "channel", "version"], | ||||||
num_results=10000, | ||||||
) | ||||||
|
||||||
blob_results = self.checkpoint_blobs_index.search(blob_query) | ||||||
|
||||||
for doc in blob_results.docs: | ||||||
checkpoint_ns = getattr(doc, "checkpoint_ns", "") | ||||||
channel = getattr(doc, "channel", "") | ||||||
version = getattr(doc, "version", "") | ||||||
|
||||||
blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key( | ||||||
storage_safe_thread_id, checkpoint_ns, channel, version | ||||||
) | ||||||
pipeline.delete(blob_key) | ||||||
|
||||||
# Delete all writes for this thread | ||||||
writes_query = FilterQuery( | ||||||
filter_expression=Tag("thread_id") == storage_safe_thread_id, | ||||||
return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"], | ||||||
num_results=10000, | ||||||
) | ||||||
|
||||||
writes_results = self.checkpoint_writes_index.search(writes_query) | ||||||
|
||||||
for doc in writes_results.docs: | ||||||
checkpoint_ns = getattr(doc, "checkpoint_ns", "") | ||||||
checkpoint_id = getattr(doc, "checkpoint_id", "") | ||||||
task_id = getattr(doc, "task_id", "") | ||||||
idx = getattr(doc, "idx", 0) | ||||||
|
||||||
write_key = BaseRedisSaver._make_redis_checkpoint_writes_key( | ||||||
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx | ||||||
) | ||||||
pipeline.delete(write_key) | ||||||
|
||||||
# Execute all deletions | ||||||
pipeline.execute() | ||||||
|
||||||
|
||||||
__all__ = [ | ||||||
"__version__", | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -925,3 +925,75 @@ async def _aload_pending_writes( | |||||
|
||||||
pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict) | ||||||
return pending_writes | ||||||
|
||||||
async def adelete_thread(self, thread_id: str) -> None: | ||||||
"""Delete all checkpoints and writes associated with a specific thread ID. | ||||||
|
||||||
Args: | ||||||
thread_id: The thread ID whose checkpoints should be deleted. | ||||||
""" | ||||||
storage_safe_thread_id = to_storage_safe_id(thread_id) | ||||||
|
||||||
# Delete all checkpoints for this thread | ||||||
checkpoint_query = FilterQuery( | ||||||
filter_expression=Tag("thread_id") == storage_safe_thread_id, | ||||||
return_fields=["checkpoint_ns", "checkpoint_id"], | ||||||
num_results=10000, # Get all checkpoints for this thread | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider defining a named constant for the magic number 10000 to improve maintainability and readability.
Suggested change
Copilot uses AI. Check for mistakes. Positive Feedback | Negative Feedback |
||||||
) | ||||||
|
||||||
checkpoint_results = await self.checkpoints_index.search(checkpoint_query) | ||||||
|
||||||
# Delete all checkpoint-related keys | ||||||
pipeline = self._redis.pipeline() | ||||||
|
||||||
for doc in checkpoint_results.docs: | ||||||
checkpoint_ns = getattr(doc, "checkpoint_ns", "") | ||||||
checkpoint_id = getattr(doc, "checkpoint_id", "") | ||||||
|
||||||
# Delete checkpoint key | ||||||
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key( | ||||||
storage_safe_thread_id, checkpoint_ns, checkpoint_id | ||||||
) | ||||||
pipeline.delete(checkpoint_key) | ||||||
|
||||||
# Delete all blobs for this thread | ||||||
blob_query = FilterQuery( | ||||||
filter_expression=Tag("thread_id") == storage_safe_thread_id, | ||||||
return_fields=["checkpoint_ns", "channel", "version"], | ||||||
num_results=10000, | ||||||
) | ||||||
|
||||||
blob_results = await self.checkpoint_blobs_index.search(blob_query) | ||||||
|
||||||
for doc in blob_results.docs: | ||||||
checkpoint_ns = getattr(doc, "checkpoint_ns", "") | ||||||
channel = getattr(doc, "channel", "") | ||||||
version = getattr(doc, "version", "") | ||||||
|
||||||
blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key( | ||||||
storage_safe_thread_id, checkpoint_ns, channel, version | ||||||
) | ||||||
pipeline.delete(blob_key) | ||||||
|
||||||
# Delete all writes for this thread | ||||||
writes_query = FilterQuery( | ||||||
filter_expression=Tag("thread_id") == storage_safe_thread_id, | ||||||
return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"], | ||||||
num_results=10000, | ||||||
) | ||||||
|
||||||
writes_results = await self.checkpoint_writes_index.search(writes_query) | ||||||
|
||||||
for doc in writes_results.docs: | ||||||
checkpoint_ns = getattr(doc, "checkpoint_ns", "") | ||||||
checkpoint_id = getattr(doc, "checkpoint_id", "") | ||||||
task_id = getattr(doc, "task_id", "") | ||||||
idx = getattr(doc, "idx", 0) | ||||||
|
||||||
write_key = BaseRedisSaver._make_redis_checkpoint_writes_key( | ||||||
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx | ||||||
) | ||||||
pipeline.delete(write_key) | ||||||
|
||||||
# Execute all deletions | ||||||
await pipeline.execute() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,240 @@ | ||
"""Test for issue #51 - adelete_thread implementation.""" | ||
|
||
import pytest | ||
from langchain_core.runnables import RunnableConfig | ||
from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata | ||
|
||
from langgraph.checkpoint.redis import RedisSaver | ||
from langgraph.checkpoint.redis.aio import AsyncRedisSaver | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_adelete_thread_implemented(redis_url): | ||
"""Test that adelete_thread method is now implemented in AsyncRedisSaver.""" | ||
async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer: | ||
# Create a checkpoint | ||
thread_id = "test-thread-to-delete" | ||
config: RunnableConfig = { | ||
"configurable": { | ||
"thread_id": thread_id, | ||
"checkpoint_ns": "", | ||
"checkpoint_id": "1", | ||
} | ||
} | ||
|
||
checkpoint = Checkpoint( | ||
v=1, | ||
id="1", | ||
ts="2024-01-01T00:00:00Z", | ||
channel_values={"messages": ["Test"]}, | ||
channel_versions={"messages": "1"}, | ||
versions_seen={"agent": {"messages": "1"}}, | ||
pending_sends=[], | ||
tasks=[], | ||
) | ||
|
||
# Store checkpoint | ||
await checkpointer.aput( | ||
config=config, | ||
checkpoint=checkpoint, | ||
metadata=CheckpointMetadata(source="input", step=0, writes={}), | ||
new_versions={"messages": "1"}, | ||
) | ||
|
||
# Verify checkpoint exists | ||
result = await checkpointer.aget_tuple(config) | ||
assert result is not None | ||
assert result.checkpoint["id"] == "1" | ||
|
||
# Delete the thread | ||
await checkpointer.adelete_thread(thread_id) | ||
|
||
# Verify checkpoint is deleted | ||
result = await checkpointer.aget_tuple(config) | ||
assert result is None | ||
|
||
|
||
def test_delete_thread_implemented(redis_url): | ||
"""Test that delete_thread method is now implemented in RedisSaver.""" | ||
with RedisSaver.from_conn_string(redis_url) as checkpointer: | ||
checkpointer.setup() # Initialize Redis indices | ||
|
||
# Create a checkpoint | ||
thread_id = "test-thread-to-delete-sync" | ||
config: RunnableConfig = { | ||
"configurable": { | ||
"thread_id": thread_id, | ||
"checkpoint_ns": "", | ||
"checkpoint_id": "1", | ||
} | ||
} | ||
|
||
checkpoint = Checkpoint( | ||
v=1, | ||
id="1", | ||
ts="2024-01-01T00:00:00Z", | ||
channel_values={"messages": ["Test"]}, | ||
channel_versions={"messages": "1"}, | ||
versions_seen={"agent": {"messages": "1"}}, | ||
pending_sends=[], | ||
tasks=[], | ||
) | ||
|
||
# Store checkpoint | ||
checkpointer.put( | ||
config=config, | ||
checkpoint=checkpoint, | ||
metadata=CheckpointMetadata(source="input", step=0, writes={}), | ||
new_versions={"messages": "1"}, | ||
) | ||
|
||
# Verify checkpoint exists | ||
result = checkpointer.get_tuple(config) | ||
assert result is not None | ||
assert result.checkpoint["id"] == "1" | ||
|
||
# Delete the thread | ||
checkpointer.delete_thread(thread_id) | ||
|
||
# Verify checkpoint is deleted | ||
result = checkpointer.get_tuple(config) | ||
assert result is None | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_adelete_thread_comprehensive(redis_url): | ||
"""Comprehensive test for adelete_thread with multiple checkpoints and namespaces.""" | ||
async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer: | ||
thread_id = "test-thread-comprehensive" | ||
other_thread_id = "other-thread" | ||
|
||
# Create multiple checkpoints for the thread | ||
checkpoints_data = [ | ||
("", "1", {"messages": ["First"]}, "input", 0), | ||
("", "2", {"messages": ["Second"]}, "output", 1), | ||
("ns1", "3", {"messages": ["Third"]}, "input", 0), | ||
("ns2", "4", {"messages": ["Fourth"]}, "output", 1), | ||
] | ||
|
||
# Also create checkpoints for another thread that should not be deleted | ||
other_checkpoints_data = [ | ||
("", "5", {"messages": ["Other1"]}, "input", 0), | ||
("ns1", "6", {"messages": ["Other2"]}, "output", 1), | ||
] | ||
|
||
# Store all checkpoints | ||
for ns, cp_id, channel_values, source, step in checkpoints_data: | ||
config: RunnableConfig = { | ||
"configurable": { | ||
"thread_id": thread_id, | ||
"checkpoint_ns": ns, | ||
"checkpoint_id": cp_id, | ||
} | ||
} | ||
|
||
checkpoint = Checkpoint( | ||
v=1, | ||
id=cp_id, | ||
ts=f"2024-01-01T00:00:0{cp_id}Z", | ||
channel_values=channel_values, | ||
channel_versions={"messages": "1"}, | ||
versions_seen={"agent": {"messages": "1"}}, | ||
pending_sends=[], | ||
tasks=[], | ||
) | ||
|
||
await checkpointer.aput( | ||
config=config, | ||
checkpoint=checkpoint, | ||
metadata=CheckpointMetadata(source=source, step=step, writes={}), | ||
new_versions={"messages": "1"}, | ||
) | ||
|
||
# Also add some writes | ||
await checkpointer.aput_writes( | ||
config=config, | ||
writes=[("messages", f"Write for {cp_id}")], | ||
task_id=f"task-{cp_id}", | ||
) | ||
|
||
# Store checkpoints for other thread | ||
for ns, cp_id, channel_values, source, step in other_checkpoints_data: | ||
config: RunnableConfig = { | ||
"configurable": { | ||
"thread_id": other_thread_id, | ||
"checkpoint_ns": ns, | ||
"checkpoint_id": cp_id, | ||
} | ||
} | ||
|
||
checkpoint = Checkpoint( | ||
v=1, | ||
id=cp_id, | ||
ts=f"2024-01-01T00:00:0{cp_id}Z", | ||
channel_values=channel_values, | ||
channel_versions={"messages": "1"}, | ||
versions_seen={"agent": {"messages": "1"}}, | ||
pending_sends=[], | ||
tasks=[], | ||
) | ||
|
||
await checkpointer.aput( | ||
config=config, | ||
checkpoint=checkpoint, | ||
metadata=CheckpointMetadata(source=source, step=step, writes={}), | ||
new_versions={"messages": "1"}, | ||
) | ||
|
||
# Verify all checkpoints exist | ||
for ns, cp_id, _, _, _ in checkpoints_data: | ||
config: RunnableConfig = { | ||
"configurable": { | ||
"thread_id": thread_id, | ||
"checkpoint_ns": ns, | ||
"checkpoint_id": cp_id, | ||
} | ||
} | ||
result = await checkpointer.aget_tuple(config) | ||
assert result is not None | ||
assert result.checkpoint["id"] == cp_id | ||
|
||
# Verify other thread checkpoints exist | ||
for ns, cp_id, _, _, _ in other_checkpoints_data: | ||
config: RunnableConfig = { | ||
"configurable": { | ||
"thread_id": other_thread_id, | ||
"checkpoint_ns": ns, | ||
"checkpoint_id": cp_id, | ||
} | ||
} | ||
result = await checkpointer.aget_tuple(config) | ||
assert result is not None | ||
assert result.checkpoint["id"] == cp_id | ||
|
||
# Delete the thread | ||
await checkpointer.adelete_thread(thread_id) | ||
|
||
# Verify all checkpoints for the thread are deleted | ||
for ns, cp_id, _, _, _ in checkpoints_data: | ||
config: RunnableConfig = { | ||
"configurable": { | ||
"thread_id": thread_id, | ||
"checkpoint_ns": ns, | ||
"checkpoint_id": cp_id, | ||
} | ||
} | ||
result = await checkpointer.aget_tuple(config) | ||
assert result is None | ||
|
||
# Verify other thread checkpoints still exist | ||
for ns, cp_id, _, _, _ in other_checkpoints_data: | ||
config: RunnableConfig = { | ||
"configurable": { | ||
"thread_id": other_thread_id, | ||
"checkpoint_ns": ns, | ||
"checkpoint_id": cp_id, | ||
} | ||
} | ||
result = await checkpointer.aget_tuple(config) | ||
assert result is not None | ||
assert result.checkpoint["id"] == cp_id |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Both async and sync deletion methods share very similar logic; consider refactoring the common deletion routines into a shared utility to reduce duplication.
Copilot uses AI. Check for mistakes.