
Commit 1ea9b6d
feat(redis): implement TTL support and upgrade langgraph to ^0.3.0 (#18, #23)

- Add Time-To-Live (TTL) support to the Redis store, implemented with Redis's native TTL commands.
- Upgrade the langgraph dependency to ^0.3.0, with proper import handling for create_react_agent, and fix assorted type errors to keep linting clean.
- Add null checks for connection_args to satisfy mypy type checking.
- Handle the REDIS_URL environment variable directly in our code.
1 parent: 4ea899f

File tree: 11 files changed, +1274 −841 lines


langgraph/checkpoint/redis/aio.py (+11 −3)

@@ -4,6 +4,7 @@
 
 import asyncio
 import json
+import os
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from functools import partial
@@ -89,9 +90,16 @@ def configure_client(
     ) -> None:
         """Configure the Redis client."""
         self._owns_its_client = redis_client is None
-        self._redis = redis_client or RedisConnectionFactory.get_async_redis_connection(
-            redis_url, **connection_args
-        )
+
+        # Use direct AsyncRedis.from_url to avoid the deprecated get_async_redis_connection
+        if redis_client is None:
+            if not redis_url:
+                redis_url = os.environ.get("REDIS_URL")
+            if not redis_url:
+                raise ValueError("REDIS_URL env var not set")
+            self._redis = AsyncRedis.from_url(redis_url, **(connection_args or {}))
+        else:
+            self._redis = redis_client
 
     def create_indexes(self) -> None:
         """Create indexes without connecting to Redis."""

langgraph/checkpoint/redis/ashallow.py (+11 −3)

@@ -4,6 +4,7 @@
 
 import asyncio
 import json
+import os
 from contextlib import asynccontextmanager
 from functools import partial
 from types import TracebackType
@@ -546,9 +547,16 @@ def configure_client(
     ) -> None:
         """Configure the Redis client."""
         self._owns_its_client = redis_client is None
-        self._redis = redis_client or RedisConnectionFactory.get_async_redis_connection(
-            redis_url, **connection_args
-        )
+
+        # Use direct AsyncRedis.from_url to avoid the deprecated get_async_redis_connection
+        if redis_client is None:
+            if not redis_url:
+                redis_url = os.environ.get("REDIS_URL")
+            if not redis_url:
+                raise ValueError("REDIS_URL env var not set")
+            self._redis = AsyncRedis.from_url(redis_url, **(connection_args or {}))
+        else:
+            self._redis = redis_client
 
     def create_indexes(self) -> None:
         """Create indexes without connecting to Redis."""

langgraph/checkpoint/redis/version.py (+1 −1)

@@ -1,5 +1,5 @@
 from redisvl.version import __version__ as __redisvl_version__
 
-__version__ = "0.0.3"
+__version__ = "0.0.4"
 __lib_name__ = f"langgraph-checkpoint-redis_v{__version__}"
 __full_lib_name__ = f"redis-py(redisvl_v{__redisvl_version__};{__lib_name__})"

langgraph/store/redis/__init__.py (+146 −8)
@@ -18,6 +18,7 @@
     PutOp,
     Result,
     SearchOp,
+    TTLConfig,
 )
 from redis import Redis
 from redis.commands.search.query import Query
@@ -70,14 +71,19 @@ class RedisStore(BaseStore, BaseRedisStore[Redis, SearchIndex]):
     vector similarity search support.
     """
 
+    # Enable TTL support
+    supports_ttl = True
+    ttl_config: Optional[TTLConfig] = None
+
     def __init__(
         self,
         conn: Redis,
         *,
         index: Optional[IndexConfig] = None,
+        ttl: Optional[dict[str, Any]] = None,
     ) -> None:
         BaseStore.__init__(self)
-        BaseRedisStore.__init__(self, conn, index=index)
+        BaseRedisStore.__init__(self, conn, index=index, ttl=ttl)
 
     @classmethod
     @contextmanager
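TTLConfig is imported alongside PutOp and SearchOp in the hunk above; the code in this commit only ever reads its "default_ttl" key (in minutes). A hedged sketch of a compatible shape — any field other than default_ttl is an assumption about the upstream type, not verified against langgraph:

from typing import Optional, TypedDict


class TTLConfigSketch(TypedDict, total=False):
    default_ttl: Optional[float]  # minutes; the only key this commit consults
    refresh_on_read: bool  # assumed field for illustration only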
@@ -86,12 +92,13 @@ def from_conn_string(
         conn_string: str,
         *,
         index: Optional[IndexConfig] = None,
+        ttl: Optional[dict[str, Any]] = None,
     ) -> Iterator[RedisStore]:
         """Create store from Redis connection string."""
         client = None
         try:
             client = RedisConnectionFactory.get_redis_connection(conn_string)
-            yield cls(client, index=index)
+            yield cls(client, index=index, ttl=ttl)
         finally:
             if client:
                 client.close()
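A usage sketch of the new ttl parameter, assuming default_ttl is given in minutes (the refresh logic below converts it with ttl_minutes * 60); index/vector setup is omitted:

from langgraph.store.redis import RedisStore

with RedisStore.from_conn_string(
    "redis://localhost:6379",
    ttl={"default_ttl": 60},  # expire items an hour after their TTL was last set
) as store:
    store.put(("users", "alice"), "prefs", {"theme": "dark"})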
@@ -186,15 +193,64 @@ def _batch_get_ops(
         results: list[Result],
     ) -> None:
         """Execute GET operations in batch."""
+        refresh_keys_by_idx: dict[int, list[str]] = (
+            {}
+        )  # Track keys that need TTL refreshed by op index
+
         for query, _, namespace, items in self._get_batch_GET_ops_queries(get_ops):
             res = self.store_index.search(Query(query))
             # Parse JSON from each document
             key_to_row = {
-                json.loads(doc.json)["key"]: json.loads(doc.json) for doc in res.docs
+                json.loads(doc.json)["key"]: (json.loads(doc.json), doc.id)
+                for doc in res.docs
             }
+
             for idx, key in items:
                 if key in key_to_row:
-                    results[idx] = _row_to_item(namespace, key_to_row[key])
+                    data, doc_id = key_to_row[key]
+                    results[idx] = _row_to_item(namespace, data)
+
+                    # Find the corresponding operation by looking it up in the operation list
+                    # This is needed because idx is the index in the overall operation list
+                    op_idx = None
+                    for i, (local_idx, op) in enumerate(get_ops):
+                        if local_idx == idx:
+                            op_idx = i
+                            break
+
+                    if op_idx is not None:
+                        op = get_ops[op_idx][1]
+                        if hasattr(op, "refresh_ttl") and op.refresh_ttl:
+                            if idx not in refresh_keys_by_idx:
+                                refresh_keys_by_idx[idx] = []
+                            refresh_keys_by_idx[idx].append(doc_id)
+
+                            # Also add vector keys for the same document
+                            doc_uuid = doc_id.split(":")[-1]
+                            vector_key = (
+                                f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid}"
+                            )
+                            refresh_keys_by_idx[idx].append(vector_key)
+
+        # Now refresh TTLs for any keys that need it
+        if refresh_keys_by_idx and self.ttl_config:
+            # Get default TTL from config
+            ttl_minutes = None
+            if "default_ttl" in self.ttl_config:
+                ttl_minutes = self.ttl_config.get("default_ttl")
+
+            if ttl_minutes is not None:
+                ttl_seconds = int(ttl_minutes * 60)
+                pipeline = self._redis.pipeline()
+
+                for keys in refresh_keys_by_idx.values():
+                    for key in keys:
+                        # Only refresh TTL if the key exists and has a TTL
+                        ttl = self._redis.ttl(key)
+                        if ttl > 0:  # Only refresh if key exists and has TTL
+                            pipeline.expire(key, ttl_seconds)
+
+                pipeline.execute()
 
     def _batch_put_ops(
         self,
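The refresh step at the end of that hunk is a sliding-expiration pattern: the TTL is re-armed only on keys that already carry one, so keys written without a TTL stay persistent. In isolation, as a sketch:

from redis import Redis


def refresh_ttls(r: Redis, keys: list[str], ttl_seconds: int) -> None:
    pipe = r.pipeline()
    for key in keys:
        # TTL returns -2 for a missing key and -1 for a key with no expiry;
        # only keys that already have a positive TTL get re-armed.
        if r.ttl(key) > 0:
            pipe.expire(key, ttl_seconds)
    pipe.execute()

Note the TTL reads happen eagerly outside the pipeline while the EXPIRE writes are batched; a key deleted in between simply makes its queued EXPIRE a no-op.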
@@ -219,20 +275,35 @@ def _batch_put_ops(
         doc_ids: dict[tuple[str, str], str] = {}
         store_docs: list[RedisDocument] = []
         store_keys: list[str] = []
+        ttl_tracking: dict[str, tuple[list[str], Optional[float]]] = (
+            {}
+        )  # Tracks keys that need TTL + their TTL values
 
         # Generate IDs for PUT operations
         for _, op in put_ops:
             if op.value is not None:
                 generated_doc_id = str(ULID())
                 namespace = _namespace_to_text(op.namespace)
                 doc_ids[(namespace, op.key)] = generated_doc_id
+                # Track TTL for this document if specified
+                if hasattr(op, "ttl") and op.ttl is not None:
+                    main_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{generated_doc_id}"
+                    ttl_tracking[main_key] = ([], op.ttl)
 
         # Load store docs with explicit keys
         for doc in operations:
             store_key = (doc["prefix"], doc["key"])
             doc_id = doc_ids[store_key]
+            # Remove TTL fields - they're not needed with Redis native TTL
+            if "ttl_minutes" in doc:
+                doc.pop("ttl_minutes", None)
+            if "expires_at" in doc:
+                doc.pop("expires_at", None)
+
             store_docs.append(doc)
-            store_keys.append(f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}")
+            redis_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
+            store_keys.append(redis_key)
+
         if store_docs:
             self.store_index.load(store_docs, keys=store_keys)
 
@@ -260,12 +331,21 @@ def _batch_put_ops(
                         "updated_at": datetime.now(timezone.utc).timestamp(),
                     }
                 )
-                vector_keys.append(
-                    f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
-                )
+                vector_key = f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
+                vector_keys.append(vector_key)
+
+                # Add this vector key to the related keys list for TTL
+                main_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
+                if main_key in ttl_tracking:
+                    ttl_tracking[main_key][0].append(vector_key)
+
         if vector_docs:
             self.vector_index.load(vector_docs, keys=vector_keys)
 
+        # Now apply TTLs after all documents are loaded
+        for main_key, (related_keys, ttl_minutes) in ttl_tracking.items():
+            self._apply_ttl_to_keys(main_key, related_keys, ttl_minutes)
+
     def _batch_search_ops(
         self,
         search_ops: list[tuple[int, SearchOp]],
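_apply_ttl_to_keys itself is not part of this diff (presumably it is defined on BaseRedisStore). A plausible sketch of the assumed behavior — expire the main document key and its related vector keys together — not the library's actual implementation:

from typing import Optional

from redis import Redis


def apply_ttl_to_keys(
    r: Redis,
    main_key: str,
    related_keys: list[str],
    ttl_minutes: Optional[float],
) -> None:
    # No TTL requested: leave the keys persistent.
    if ttl_minutes is None:
        return
    ttl_seconds = int(ttl_minutes * 60)
    pipe = r.pipeline()
    pipe.expire(main_key, ttl_seconds)
    for key in related_keys:  # e.g. the document's vector key
        pipe.expire(key, ttl_seconds)
    pipe.execute()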
@@ -316,6 +396,8 @@ def _batch_search_ops(
 
         # Process results maintaining order and applying filters
         items = []
+        refresh_keys = []  # Track keys that need TTL refreshed
+
         for store_key, store_doc in zip(result_map.keys(), store_docs):
             if store_doc:
                 vector_result = result_map[store_key]
@@ -345,6 +427,16 @@ def _batch_search_ops(
                 if not matches:
                     continue
 
+                # If refresh_ttl is true, add to list for refreshing
+                if op.refresh_ttl:
+                    refresh_keys.append(store_key)
+                    # Also find associated vector keys with same ID
+                    doc_id = store_key.split(":")[-1]
+                    vector_key = (
+                        f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
+                    )
+                    refresh_keys.append(vector_key)
+
                 items.append(
                     _row_to_search_item(
                         _decode_ns(store_doc["prefix"]),
@@ -353,13 +445,31 @@ def _batch_search_ops(
                     )
                 )
 
+            # Refresh TTL if requested
+            if op.refresh_ttl and refresh_keys and self.ttl_config:
+                # Get default TTL from config
+                ttl_minutes = None
+                if "default_ttl" in self.ttl_config:
+                    ttl_minutes = self.ttl_config.get("default_ttl")
+
+                if ttl_minutes is not None:
+                    ttl_seconds = int(ttl_minutes * 60)
+                    pipeline = self._redis.pipeline()
+                    for key in refresh_keys:
+                        # Only refresh TTL if the key exists and has a TTL
+                        ttl = self._redis.ttl(key)
+                        if ttl > 0:  # Only refresh if key exists and has TTL
+                            pipeline.expire(key, ttl_seconds)
+                    pipeline.execute()
+
             results[idx] = items
         else:
             # Regular search
             query = Query(query_str)
             # Get all potential matches for filtering
             res = self.store_index.search(query)
             items = []
+            refresh_keys = []  # Track keys that need TTL refreshed
 
             for doc in res.docs:
                 data = json.loads(doc.json)
@@ -378,13 +488,41 @@ def _batch_search_ops(
                         break
                 if not matches:
                     continue
+
+                # If refresh_ttl is true, add the key to refresh list
+                if op.refresh_ttl:
+                    refresh_keys.append(doc.id)
+                    # Also find associated vector keys with same ID
+                    doc_id = doc.id.split(":")[-1]
+                    vector_key = (
+                        f"{STORE_VECTOR_PREFIX}{REDIS_KEY_SEPARATOR}{doc_id}"
+                    )
+                    refresh_keys.append(vector_key)
+
                 items.append(_row_to_search_item(_decode_ns(data["prefix"]), data))
 
             # Apply pagination after filtering
             if params:
                 limit, offset = params
                 items = items[offset : offset + limit]
 
+            # Refresh TTL if requested
+            if op.refresh_ttl and refresh_keys and self.ttl_config:
+                # Get default TTL from config
+                ttl_minutes = None
+                if "default_ttl" in self.ttl_config:
+                    ttl_minutes = self.ttl_config.get("default_ttl")
+
+                if ttl_minutes is not None:
+                    ttl_seconds = int(ttl_minutes * 60)
+                    pipeline = self._redis.pipeline()
+                    for key in refresh_keys:
+                        # Only refresh TTL if the key exists and has a TTL
+                        ttl = self._redis.ttl(key)
+                        if ttl > 0:  # Only refresh if key exists and has TTL
+                            pipeline.expire(key, ttl_seconds)
+                    pipeline.execute()
+
             results[idx] = items
 
     async def abatch(self, ops: Iterable[Op]) -> list[Result]:
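Taken together, an end-to-end sketch of the read-side behavior, assuming get() and search() forward a refresh_ttl flag into the GetOp/SearchOp handled above:

from langgraph.store.redis import RedisStore

with RedisStore.from_conn_string(
    "redis://localhost:6379",
    ttl={"default_ttl": 30},  # minutes
) as store:
    store.put(("cache",), "greeting", {"text": "hello"})
    # A read with refresh_ttl=True re-arms the 30-minute window via EXPIRE...
    item = store.get(("cache",), "greeting", refresh_ttl=True)
    # ...while refresh_ttl=False leaves the countdown untouched.
    hits = store.search(("cache",), refresh_ttl=False)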
