Skip to content

feat(checkpoint-redis): implement adelete_thread and delete_thread … #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions langgraph/checkpoint/redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,78 @@ def _load_pending_sends(
# Extract type and blob pairs
return [(doc.type, doc.blob) for doc in sorted_writes]

def delete_thread(self, thread_id: str) -> None:
Copy link
Preview

Copilot AI Jun 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Both async and sync deletion methods share very similar logic; consider refactoring the common deletion routines into a shared utility to reduce duplication.

Copilot uses AI. Check for mistakes.

"""Delete all checkpoints and writes associated with a specific thread ID.

Args:
thread_id: The thread ID whose checkpoints should be deleted.
"""
storage_safe_thread_id = to_storage_safe_id(thread_id)

# Delete all checkpoints for this thread
checkpoint_query = FilterQuery(
filter_expression=Tag("thread_id") == storage_safe_thread_id,
return_fields=["checkpoint_ns", "checkpoint_id"],
num_results=10000, # Get all checkpoints for this thread
Copy link
Preview

Copilot AI Jun 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using a named constant instead of hardcoding 10000 to enhance clarity and ease future adjustments.

Suggested change
num_results=10000, # Get all checkpoints for this thread
num_results=DEFAULT_NUM_RESULTS, # Get all checkpoints for this thread

Copilot uses AI. Check for mistakes.

)

checkpoint_results = self.checkpoints_index.search(checkpoint_query)

# Delete all checkpoint-related keys
pipeline = self._redis.pipeline()

for doc in checkpoint_results.docs:
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
checkpoint_id = getattr(doc, "checkpoint_id", "")

# Delete checkpoint key
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
storage_safe_thread_id, checkpoint_ns, checkpoint_id
)
pipeline.delete(checkpoint_key)

# Delete all blobs for this thread
blob_query = FilterQuery(
filter_expression=Tag("thread_id") == storage_safe_thread_id,
return_fields=["checkpoint_ns", "channel", "version"],
num_results=10000,
)

blob_results = self.checkpoint_blobs_index.search(blob_query)

for doc in blob_results.docs:
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
channel = getattr(doc, "channel", "")
version = getattr(doc, "version", "")

blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
storage_safe_thread_id, checkpoint_ns, channel, version
)
pipeline.delete(blob_key)

# Delete all writes for this thread
writes_query = FilterQuery(
filter_expression=Tag("thread_id") == storage_safe_thread_id,
return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
num_results=10000,
)

writes_results = self.checkpoint_writes_index.search(writes_query)

for doc in writes_results.docs:
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
checkpoint_id = getattr(doc, "checkpoint_id", "")
task_id = getattr(doc, "task_id", "")
idx = getattr(doc, "idx", 0)

write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
)
pipeline.delete(write_key)

# Execute all deletions
pipeline.execute()


__all__ = [
"__version__",
Expand Down
72 changes: 72 additions & 0 deletions langgraph/checkpoint/redis/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,3 +925,75 @@ async def _aload_pending_writes(

pending_writes = BaseRedisSaver._load_writes(self.serde, writes_dict)
return pending_writes

async def adelete_thread(self, thread_id: str) -> None:
"""Delete all checkpoints and writes associated with a specific thread ID.

Args:
thread_id: The thread ID whose checkpoints should be deleted.
"""
storage_safe_thread_id = to_storage_safe_id(thread_id)

# Delete all checkpoints for this thread
checkpoint_query = FilterQuery(
filter_expression=Tag("thread_id") == storage_safe_thread_id,
return_fields=["checkpoint_ns", "checkpoint_id"],
num_results=10000, # Get all checkpoints for this thread
Copy link
Preview

Copilot AI Jun 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider defining a named constant for the magic number 10000 to improve maintainability and readability.

Suggested change
num_results=10000, # Get all checkpoints for this thread
num_results=DEFAULT_NUM_RESULTS, # Get all checkpoints for this thread

Copilot uses AI. Check for mistakes.

)

checkpoint_results = await self.checkpoints_index.search(checkpoint_query)

# Delete all checkpoint-related keys
pipeline = self._redis.pipeline()

for doc in checkpoint_results.docs:
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
checkpoint_id = getattr(doc, "checkpoint_id", "")

# Delete checkpoint key
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
storage_safe_thread_id, checkpoint_ns, checkpoint_id
)
pipeline.delete(checkpoint_key)

# Delete all blobs for this thread
blob_query = FilterQuery(
filter_expression=Tag("thread_id") == storage_safe_thread_id,
return_fields=["checkpoint_ns", "channel", "version"],
num_results=10000,
)

blob_results = await self.checkpoint_blobs_index.search(blob_query)

for doc in blob_results.docs:
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
channel = getattr(doc, "channel", "")
version = getattr(doc, "version", "")

blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
storage_safe_thread_id, checkpoint_ns, channel, version
)
pipeline.delete(blob_key)

# Delete all writes for this thread
writes_query = FilterQuery(
filter_expression=Tag("thread_id") == storage_safe_thread_id,
return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
num_results=10000,
)

writes_results = await self.checkpoint_writes_index.search(writes_query)

for doc in writes_results.docs:
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
checkpoint_id = getattr(doc, "checkpoint_id", "")
task_id = getattr(doc, "task_id", "")
idx = getattr(doc, "idx", 0)

write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
)
pipeline.delete(write_key)

# Execute all deletions
await pipeline.execute()
240 changes: 240 additions & 0 deletions tests/test_issue_51_adelete_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
"""Test for issue #51 - adelete_thread implementation."""

import pytest
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata

from langgraph.checkpoint.redis import RedisSaver
from langgraph.checkpoint.redis.aio import AsyncRedisSaver


@pytest.mark.asyncio
async def test_adelete_thread_implemented(redis_url):
"""Test that adelete_thread method is now implemented in AsyncRedisSaver."""
async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer:
# Create a checkpoint
thread_id = "test-thread-to-delete"
config: RunnableConfig = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
"checkpoint_id": "1",
}
}

checkpoint = Checkpoint(
v=1,
id="1",
ts="2024-01-01T00:00:00Z",
channel_values={"messages": ["Test"]},
channel_versions={"messages": "1"},
versions_seen={"agent": {"messages": "1"}},
pending_sends=[],
tasks=[],
)

# Store checkpoint
await checkpointer.aput(
config=config,
checkpoint=checkpoint,
metadata=CheckpointMetadata(source="input", step=0, writes={}),
new_versions={"messages": "1"},
)

# Verify checkpoint exists
result = await checkpointer.aget_tuple(config)
assert result is not None
assert result.checkpoint["id"] == "1"

# Delete the thread
await checkpointer.adelete_thread(thread_id)

# Verify checkpoint is deleted
result = await checkpointer.aget_tuple(config)
assert result is None


def test_delete_thread_implemented(redis_url):
"""Test that delete_thread method is now implemented in RedisSaver."""
with RedisSaver.from_conn_string(redis_url) as checkpointer:
checkpointer.setup() # Initialize Redis indices

# Create a checkpoint
thread_id = "test-thread-to-delete-sync"
config: RunnableConfig = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
"checkpoint_id": "1",
}
}

checkpoint = Checkpoint(
v=1,
id="1",
ts="2024-01-01T00:00:00Z",
channel_values={"messages": ["Test"]},
channel_versions={"messages": "1"},
versions_seen={"agent": {"messages": "1"}},
pending_sends=[],
tasks=[],
)

# Store checkpoint
checkpointer.put(
config=config,
checkpoint=checkpoint,
metadata=CheckpointMetadata(source="input", step=0, writes={}),
new_versions={"messages": "1"},
)

# Verify checkpoint exists
result = checkpointer.get_tuple(config)
assert result is not None
assert result.checkpoint["id"] == "1"

# Delete the thread
checkpointer.delete_thread(thread_id)

# Verify checkpoint is deleted
result = checkpointer.get_tuple(config)
assert result is None


@pytest.mark.asyncio
async def test_adelete_thread_comprehensive(redis_url):
"""Comprehensive test for adelete_thread with multiple checkpoints and namespaces."""
async with AsyncRedisSaver.from_conn_string(redis_url) as checkpointer:
thread_id = "test-thread-comprehensive"
other_thread_id = "other-thread"

# Create multiple checkpoints for the thread
checkpoints_data = [
("", "1", {"messages": ["First"]}, "input", 0),
("", "2", {"messages": ["Second"]}, "output", 1),
("ns1", "3", {"messages": ["Third"]}, "input", 0),
("ns2", "4", {"messages": ["Fourth"]}, "output", 1),
]

# Also create checkpoints for another thread that should not be deleted
other_checkpoints_data = [
("", "5", {"messages": ["Other1"]}, "input", 0),
("ns1", "6", {"messages": ["Other2"]}, "output", 1),
]

# Store all checkpoints
for ns, cp_id, channel_values, source, step in checkpoints_data:
config: RunnableConfig = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": ns,
"checkpoint_id": cp_id,
}
}

checkpoint = Checkpoint(
v=1,
id=cp_id,
ts=f"2024-01-01T00:00:0{cp_id}Z",
channel_values=channel_values,
channel_versions={"messages": "1"},
versions_seen={"agent": {"messages": "1"}},
pending_sends=[],
tasks=[],
)

await checkpointer.aput(
config=config,
checkpoint=checkpoint,
metadata=CheckpointMetadata(source=source, step=step, writes={}),
new_versions={"messages": "1"},
)

# Also add some writes
await checkpointer.aput_writes(
config=config,
writes=[("messages", f"Write for {cp_id}")],
task_id=f"task-{cp_id}",
)

# Store checkpoints for other thread
for ns, cp_id, channel_values, source, step in other_checkpoints_data:
config: RunnableConfig = {
"configurable": {
"thread_id": other_thread_id,
"checkpoint_ns": ns,
"checkpoint_id": cp_id,
}
}

checkpoint = Checkpoint(
v=1,
id=cp_id,
ts=f"2024-01-01T00:00:0{cp_id}Z",
channel_values=channel_values,
channel_versions={"messages": "1"},
versions_seen={"agent": {"messages": "1"}},
pending_sends=[],
tasks=[],
)

await checkpointer.aput(
config=config,
checkpoint=checkpoint,
metadata=CheckpointMetadata(source=source, step=step, writes={}),
new_versions={"messages": "1"},
)

# Verify all checkpoints exist
for ns, cp_id, _, _, _ in checkpoints_data:
config: RunnableConfig = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": ns,
"checkpoint_id": cp_id,
}
}
result = await checkpointer.aget_tuple(config)
assert result is not None
assert result.checkpoint["id"] == cp_id

# Verify other thread checkpoints exist
for ns, cp_id, _, _, _ in other_checkpoints_data:
config: RunnableConfig = {
"configurable": {
"thread_id": other_thread_id,
"checkpoint_ns": ns,
"checkpoint_id": cp_id,
}
}
result = await checkpointer.aget_tuple(config)
assert result is not None
assert result.checkpoint["id"] == cp_id

# Delete the thread
await checkpointer.adelete_thread(thread_id)

# Verify all checkpoints for the thread are deleted
for ns, cp_id, _, _, _ in checkpoints_data:
config: RunnableConfig = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": ns,
"checkpoint_id": cp_id,
}
}
result = await checkpointer.aget_tuple(config)
assert result is None

# Verify other thread checkpoints still exist
for ns, cp_id, _, _, _ in other_checkpoints_data:
config: RunnableConfig = {
"configurable": {
"thread_id": other_thread_id,
"checkpoint_ns": ns,
"checkpoint_id": cp_id,
}
}
result = await checkpointer.aget_tuple(config)
assert result is not None
assert result.checkpoint["id"] == cp_id