18
18
PutOp ,
19
19
Result ,
20
20
SearchOp ,
21
+ TTLConfig ,
21
22
)
22
23
from redis import Redis
23
24
from redis .commands .search .query import Query
@@ -70,14 +71,19 @@ class RedisStore(BaseStore, BaseRedisStore[Redis, SearchIndex]):
70
71
vector similarity search support.
71
72
"""
72
73
74
+ # Enable TTL support
75
+ supports_ttl = True
76
+ ttl_config : Optional [TTLConfig ] = None
77
+
73
78
def __init__(
    self,
    conn: Redis,
    *,
    index: Optional[IndexConfig] = None,
    ttl: Optional[dict[str, Any]] = None,
) -> None:
    """Initialize a Redis-backed store over an existing client connection.

    Args:
        conn: An already-constructed synchronous ``Redis`` client; the store
            does not open or own the connection lifecycle here.
        index: Optional index configuration forwarded to ``BaseRedisStore``.
        ttl: Optional TTL settings dict forwarded to ``BaseRedisStore``.
            NOTE(review): presumably supports a ``"default_ttl"`` key expressed
            in minutes, matching how ``ttl_config`` is read elsewhere in this
            class — confirm against ``BaseRedisStore.__init__``.
    """
    # Both bases are initialized explicitly rather than via a super() chain;
    # BaseStore takes no extra state, while BaseRedisStore receives the
    # connection plus the index and TTL configuration.
    BaseStore.__init__(self)
    BaseRedisStore.__init__(self, conn, index=index, ttl=ttl)
81
87
82
88
@classmethod
83
89
@contextmanager
@@ -86,12 +92,13 @@ def from_conn_string(
86
92
conn_string : str ,
87
93
* ,
88
94
index : Optional [IndexConfig ] = None ,
95
+ ttl : Optional [dict [str , Any ]] = None ,
89
96
) -> Iterator [RedisStore ]:
90
97
"""Create store from Redis connection string."""
91
98
client = None
92
99
try :
93
100
client = RedisConnectionFactory .get_redis_connection (conn_string )
94
- yield cls (client , index = index )
101
+ yield cls (client , index = index , ttl = ttl )
95
102
finally :
96
103
if client :
97
104
client .close ()
@@ -186,15 +193,64 @@ def _batch_get_ops(
186
193
results : list [Result ],
187
194
) -> None :
188
195
"""Execute GET operations in batch."""
196
+ refresh_keys_by_idx : dict [int , list [str ]] = (
197
+ {}
198
+ ) # Track keys that need TTL refreshed by op index
199
+
189
200
for query , _ , namespace , items in self ._get_batch_GET_ops_queries (get_ops ):
190
201
res = self .store_index .search (Query (query ))
191
202
# Parse JSON from each document
192
203
key_to_row = {
193
- json .loads (doc .json )["key" ]: json .loads (doc .json ) for doc in res .docs
204
+ json .loads (doc .json )["key" ]: (json .loads (doc .json ), doc .id )
205
+ for doc in res .docs
194
206
}
207
+
195
208
for idx , key in items :
196
209
if key in key_to_row :
197
- results [idx ] = _row_to_item (namespace , key_to_row [key ])
210
+ data , doc_id = key_to_row [key ]
211
+ results [idx ] = _row_to_item (namespace , data )
212
+
213
+ # Find the corresponding operation by looking it up in the operation list
214
+ # This is needed because idx is the index in the overall operation list
215
+ op_idx = None
216
+ for i , (local_idx , op ) in enumerate (get_ops ):
217
+ if local_idx == idx :
218
+ op_idx = i
219
+ break
220
+
221
+ if op_idx is not None :
222
+ op = get_ops [op_idx ][1 ]
223
+ if hasattr (op , "refresh_ttl" ) and op .refresh_ttl :
224
+ if idx not in refresh_keys_by_idx :
225
+ refresh_keys_by_idx [idx ] = []
226
+ refresh_keys_by_idx [idx ].append (doc_id )
227
+
228
+ # Also add vector keys for the same document
229
+ doc_uuid = doc_id .split (":" )[- 1 ]
230
+ vector_key = (
231
+ f"{ STORE_VECTOR_PREFIX } { REDIS_KEY_SEPARATOR } { doc_uuid } "
232
+ )
233
+ refresh_keys_by_idx [idx ].append (vector_key )
234
+
235
+ # Now refresh TTLs for any keys that need it
236
+ if refresh_keys_by_idx and self .ttl_config :
237
+ # Get default TTL from config
238
+ ttl_minutes = None
239
+ if "default_ttl" in self .ttl_config :
240
+ ttl_minutes = self .ttl_config .get ("default_ttl" )
241
+
242
+ if ttl_minutes is not None :
243
+ ttl_seconds = int (ttl_minutes * 60 )
244
+ pipeline = self ._redis .pipeline ()
245
+
246
+ for keys in refresh_keys_by_idx .values ():
247
+ for key in keys :
248
+ # Only refresh TTL if the key exists and has a TTL
249
+ ttl = self ._redis .ttl (key )
250
+ if ttl > 0 : # Only refresh if key exists and has TTL
251
+ pipeline .expire (key , ttl_seconds )
252
+
253
+ pipeline .execute ()
198
254
199
255
def _batch_put_ops (
200
256
self ,
@@ -219,20 +275,35 @@ def _batch_put_ops(
219
275
doc_ids : dict [tuple [str , str ], str ] = {}
220
276
store_docs : list [RedisDocument ] = []
221
277
store_keys : list [str ] = []
278
+ ttl_tracking : dict [str , tuple [list [str ], Optional [float ]]] = (
279
+ {}
280
+ ) # Tracks keys that need TTL + their TTL values
222
281
223
282
# Generate IDs for PUT operations
224
283
for _ , op in put_ops :
225
284
if op .value is not None :
226
285
generated_doc_id = str (ULID ())
227
286
namespace = _namespace_to_text (op .namespace )
228
287
doc_ids [(namespace , op .key )] = generated_doc_id
288
+ # Track TTL for this document if specified
289
+ if hasattr (op , "ttl" ) and op .ttl is not None :
290
+ main_key = f"{ STORE_PREFIX } { REDIS_KEY_SEPARATOR } { generated_doc_id } "
291
+ ttl_tracking [main_key ] = ([], op .ttl )
229
292
230
293
# Load store docs with explicit keys
231
294
for doc in operations :
232
295
store_key = (doc ["prefix" ], doc ["key" ])
233
296
doc_id = doc_ids [store_key ]
297
+ # Remove TTL fields - they're not needed with Redis native TTL
298
+ if "ttl_minutes" in doc :
299
+ doc .pop ("ttl_minutes" , None )
300
+ if "expires_at" in doc :
301
+ doc .pop ("expires_at" , None )
302
+
234
303
store_docs .append (doc )
235
- store_keys .append (f"{ STORE_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } " )
304
+ redis_key = f"{ STORE_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
305
+ store_keys .append (redis_key )
306
+
236
307
if store_docs :
237
308
self .store_index .load (store_docs , keys = store_keys )
238
309
@@ -260,12 +331,21 @@ def _batch_put_ops(
260
331
"updated_at" : datetime .now (timezone .utc ).timestamp (),
261
332
}
262
333
)
263
- vector_keys .append (
264
- f"{ STORE_VECTOR_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
265
- )
334
+ vector_key = f"{ STORE_VECTOR_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
335
+ vector_keys .append (vector_key )
336
+
337
+ # Add this vector key to the related keys list for TTL
338
+ main_key = f"{ STORE_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
339
+ if main_key in ttl_tracking :
340
+ ttl_tracking [main_key ][0 ].append (vector_key )
341
+
266
342
if vector_docs :
267
343
self .vector_index .load (vector_docs , keys = vector_keys )
268
344
345
+ # Now apply TTLs after all documents are loaded
346
+ for main_key , (related_keys , ttl_minutes ) in ttl_tracking .items ():
347
+ self ._apply_ttl_to_keys (main_key , related_keys , ttl_minutes )
348
+
269
349
def _batch_search_ops (
270
350
self ,
271
351
search_ops : list [tuple [int , SearchOp ]],
@@ -316,6 +396,8 @@ def _batch_search_ops(
316
396
317
397
# Process results maintaining order and applying filters
318
398
items = []
399
+ refresh_keys = [] # Track keys that need TTL refreshed
400
+
319
401
for store_key , store_doc in zip (result_map .keys (), store_docs ):
320
402
if store_doc :
321
403
vector_result = result_map [store_key ]
@@ -345,6 +427,16 @@ def _batch_search_ops(
345
427
if not matches :
346
428
continue
347
429
430
+ # If refresh_ttl is true, add to list for refreshing
431
+ if op .refresh_ttl :
432
+ refresh_keys .append (store_key )
433
+ # Also find associated vector keys with same ID
434
+ doc_id = store_key .split (":" )[- 1 ]
435
+ vector_key = (
436
+ f"{ STORE_VECTOR_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
437
+ )
438
+ refresh_keys .append (vector_key )
439
+
348
440
items .append (
349
441
_row_to_search_item (
350
442
_decode_ns (store_doc ["prefix" ]),
@@ -353,13 +445,31 @@ def _batch_search_ops(
353
445
)
354
446
)
355
447
448
+ # Refresh TTL if requested
449
+ if op .refresh_ttl and refresh_keys and self .ttl_config :
450
+ # Get default TTL from config
451
+ ttl_minutes = None
452
+ if "default_ttl" in self .ttl_config :
453
+ ttl_minutes = self .ttl_config .get ("default_ttl" )
454
+
455
+ if ttl_minutes is not None :
456
+ ttl_seconds = int (ttl_minutes * 60 )
457
+ pipeline = self ._redis .pipeline ()
458
+ for key in refresh_keys :
459
+ # Only refresh TTL if the key exists and has a TTL
460
+ ttl = self ._redis .ttl (key )
461
+ if ttl > 0 : # Only refresh if key exists and has TTL
462
+ pipeline .expire (key , ttl_seconds )
463
+ pipeline .execute ()
464
+
356
465
results [idx ] = items
357
466
else :
358
467
# Regular search
359
468
query = Query (query_str )
360
469
# Get all potential matches for filtering
361
470
res = self .store_index .search (query )
362
471
items = []
472
+ refresh_keys = [] # Track keys that need TTL refreshed
363
473
364
474
for doc in res .docs :
365
475
data = json .loads (doc .json )
@@ -378,13 +488,41 @@ def _batch_search_ops(
378
488
break
379
489
if not matches :
380
490
continue
491
+
492
+ # If refresh_ttl is true, add the key to refresh list
493
+ if op .refresh_ttl :
494
+ refresh_keys .append (doc .id )
495
+ # Also find associated vector keys with same ID
496
+ doc_id = doc .id .split (":" )[- 1 ]
497
+ vector_key = (
498
+ f"{ STORE_VECTOR_PREFIX } { REDIS_KEY_SEPARATOR } { doc_id } "
499
+ )
500
+ refresh_keys .append (vector_key )
501
+
381
502
items .append (_row_to_search_item (_decode_ns (data ["prefix" ]), data ))
382
503
383
504
# Apply pagination after filtering
384
505
if params :
385
506
limit , offset = params
386
507
items = items [offset : offset + limit ]
387
508
509
+ # Refresh TTL if requested
510
+ if op .refresh_ttl and refresh_keys and self .ttl_config :
511
+ # Get default TTL from config
512
+ ttl_minutes = None
513
+ if "default_ttl" in self .ttl_config :
514
+ ttl_minutes = self .ttl_config .get ("default_ttl" )
515
+
516
+ if ttl_minutes is not None :
517
+ ttl_seconds = int (ttl_minutes * 60 )
518
+ pipeline = self ._redis .pipeline ()
519
+ for key in refresh_keys :
520
+ # Only refresh TTL if the key exists and has a TTL
521
+ ttl = self ._redis .ttl (key )
522
+ if ttl > 0 : # Only refresh if key exists and has TTL
523
+ pipeline .expire (key , ttl_seconds )
524
+ pipeline .execute ()
525
+
388
526
results [idx ] = items
389
527
390
528
async def abatch (self , ops : Iterable [Op ]) -> list [Result ]:
0 commit comments