
Commit 6331627

fix: handle both bytes and string Redis keys when decode_responses=True

1 parent b045e2b

6 files changed, +239 -14 lines
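
Why the change is needed: redis-py returns raw bytes by default, but a client created with decode_responses=True returns str, and calling .decode() on a str raises AttributeError. A minimal sketch of the failure the previous key.decode() calls ran into (the key name and local URL are illustrative only):

from redis import Redis

client = Redis.from_url("redis://localhost:6379", decode_responses=True)
client.set("checkpoint_blob:thread-1:ns:channel:1", "x")

keys = client.keys("checkpoint_blob:*")  # list[str] when decode_responses=True
keys[0].decode()  # AttributeError: 'str' object has no attribute 'decode'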

langgraph/checkpoint/redis/aio.py

+9 -4

@@ -33,6 +33,7 @@
     EMPTY_ID_SENTINEL,
     from_storage_safe_id,
     from_storage_safe_str,
+    safely_decode,
     to_storage_safe_id,
     to_storage_safe_str,
 )
@@ -212,12 +213,14 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
         # Get the blob keys
         blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*"
         blob_keys = await self._redis.keys(blob_key_pattern)
-        blob_keys = [key.decode() for key in blob_keys]
+        # Use safely_decode to handle both string and bytes responses
+        blob_keys = [safely_decode(key) for key in blob_keys]

         # Also get checkpoint write keys that should have the same TTL
         write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*"
         write_keys = await self._redis.keys(write_key_pattern)
-        write_keys = [key.decode() for key in write_keys]
+        # Use safely_decode to handle both string and bytes responses
+        write_keys = [safely_decode(key) for key in write_keys]

         # Apply TTL to checkpoint, blob keys, and write keys
         ttl_minutes = self.ttl_config.get("default_ttl")
@@ -895,9 +898,11 @@ async def _aload_pending_writes(
             None,
         )
         matching_keys = await self._redis.keys(pattern=writes_key)
+        # Use safely_decode to handle both string and bytes responses
+        decoded_keys = [safely_decode(key) for key in matching_keys]
         parsed_keys = [
-            BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
-            for key in matching_keys
+            BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
+            for key in decoded_keys
         ]
         pending_writes = BaseRedisSaver._load_writes(
             self.serde,

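The normalization pattern used above is the same at every call site: gather keys, then map them through safely_decode so downstream code always sees str. A small self-contained sketch, assuming a local Redis at redis://localhost:6379 and an illustrative key pattern:

from redis import Redis

from langgraph.checkpoint.redis.util import safely_decode

for decode_responses in (False, True):
    client = Redis.from_url("redis://localhost:6379", decode_responses=decode_responses)
    raw_keys = client.keys("checkpoint_blob:*")       # list[bytes] or list[str]
    blob_keys = [safely_decode(k) for k in raw_keys]  # always list[str]
    assert all(isinstance(k, str) for k in blob_keys)
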
langgraph/checkpoint/redis/ashallow.py

+13 -5

@@ -34,6 +34,7 @@
     REDIS_KEY_SEPARATOR,
     BaseRedisSaver,
 )
+from langgraph.checkpoint.redis.util import safely_decode

 SCHEMAS = [
     {
@@ -252,7 +253,9 @@ async def aput(
         # Process each existing blob key to determine if it should be kept or deleted
         if existing_blob_keys:
             for blob_key in existing_blob_keys:
-                key_parts = blob_key.decode().split(REDIS_KEY_SEPARATOR)
+                # Use safely_decode to handle both string and bytes responses
+                decoded_key = safely_decode(blob_key)
+                key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
                 # The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
                 if len(key_parts) >= 5:
                     channel = key_parts[3]
@@ -428,7 +431,8 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
             )
         )
         blob_keys = await self._redis.keys(blob_key_pattern)
-        blob_keys = [key.decode() for key in blob_keys]
+        # Use safely_decode to handle both string and bytes responses
+        blob_keys = [safely_decode(key) for key in blob_keys]

         # Apply TTL
         ttl_minutes = self.ttl_config.get("default_ttl")
@@ -554,7 +558,9 @@ async def aput_writes(
         # Process each existing writes key to determine if it should be kept or deleted
         if existing_writes_keys:
             for write_key in existing_writes_keys:
-                key_parts = write_key.decode().split(REDIS_KEY_SEPARATOR)
+                # Use safely_decode to handle both string and bytes responses
+                decoded_key = safely_decode(write_key)
+                key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
                 # The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
                 if len(key_parts) >= 5:
                     key_checkpoint_id = key_parts[3]
@@ -700,9 +706,11 @@ async def _aload_pending_writes(
             thread_id, checkpoint_ns, checkpoint_id, "*", None
         )
         matching_keys = await self._redis.keys(pattern=writes_key)
+        # Use safely_decode to handle both string and bytes responses
+        decoded_keys = [safely_decode(key) for key in matching_keys]
         parsed_keys = [
-            BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
-            for key in matching_keys
+            BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
+            for key in decoded_keys
         ]
         pending_writes = BaseRedisSaver._load_writes(
             self.serde,

langgraph/checkpoint/redis/base.py

+9 -2

@@ -18,6 +18,7 @@
 from langgraph.checkpoint.serde.types import ChannelProtocol

 from langgraph.checkpoint.redis.util import (
+    safely_decode,
     to_storage_safe_id,
     to_storage_safe_str,
 )
@@ -509,9 +510,12 @@ def _load_pending_writes(
         # Cast the result to List[bytes] to help type checker
         matching_keys: List[bytes] = self._redis.keys(pattern=writes_key)  # type: ignore[assignment]

+        # Use safely_decode to handle both string and bytes responses
+        decoded_keys = [safely_decode(key) for key in matching_keys]
+
         parsed_keys = [
-            BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
-            for key in matching_keys
+            BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
+            for key in decoded_keys
         ]
         pending_writes = BaseRedisSaver._load_writes(
             self.serde,
@@ -541,6 +545,9 @@ def _load_writes(

     @staticmethod
     def _parse_redis_checkpoint_writes_key(redis_key: str) -> dict:
+        # Ensure redis_key is a string
+        redis_key = safely_decode(redis_key)
+
         parts = redis_key.split(REDIS_KEY_SEPARATOR)
         # Ensure we have at least 6 parts
         if len(parts) < 6:

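With the defensive safely_decode call in _parse_redis_checkpoint_writes_key, the parser now accepts either bytes or str. A standalone sketch of the parsing it tolerates, using the key layout named in the diff comments (checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx); the separator value, helper name, and sample key are illustrative assumptions, not the library's actual code:

REDIS_KEY_SEPARATOR = ":"  # assumed value of the real constant


def parse_writes_key(redis_key) -> dict:
    # Mirror of the patched behavior: normalize bytes to str before splitting.
    if isinstance(redis_key, bytes):
        redis_key = redis_key.decode("utf-8")
    prefix, thread_id, checkpoint_ns, checkpoint_id, task_id, idx = redis_key.split(
        REDIS_KEY_SEPARATOR
    )[:6]
    return {
        "thread_id": thread_id,
        "checkpoint_ns": checkpoint_ns,
        "checkpoint_id": checkpoint_id,
        "task_id": task_id,
        "idx": idx,
    }


# Both spellings of the same key now parse identically:
assert parse_writes_key(b"checkpoint_write:t1:ns:c1:task9:0") == parse_writes_key(
    "checkpoint_write:t1:ns:c1:task9:0"
)
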
langgraph/checkpoint/redis/shallow.py

+11 -3

@@ -26,6 +26,7 @@
     REDIS_KEY_SEPARATOR,
     BaseRedisSaver,
 )
+from langgraph.checkpoint.redis.util import safely_decode

 SCHEMAS = [
     {
@@ -179,7 +180,9 @@ def put(
         # Process each existing blob key to determine if it should be kept or deleted
         if existing_blob_keys:
             for blob_key in existing_blob_keys:
-                key_parts = blob_key.decode().split(REDIS_KEY_SEPARATOR)
+                # Use safely_decode to handle both string and bytes responses
+                decoded_key = safely_decode(blob_key)
+                key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
                 # The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
                 if len(key_parts) >= 5:
                     channel = key_parts[3]
@@ -387,7 +390,10 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
                 thread_id, checkpoint_ns
             )
         )
-        blob_keys = [key.decode() for key in self._redis.keys(blob_key_pattern)]
+        # Use safely_decode to handle both string and bytes responses
+        blob_keys = [
+            safely_decode(key) for key in self._redis.keys(blob_key_pattern)
+        ]

         # Apply TTL
         self._apply_ttl_to_keys(checkpoint_key, blob_keys)
@@ -524,7 +530,9 @@ def put_writes(
         # Process each existing writes key to determine if it should be kept or deleted
         if existing_writes_keys:
             for write_key in existing_writes_keys:
-                key_parts = write_key.decode().split(REDIS_KEY_SEPARATOR)
+                # Use safely_decode to handle both string and bytes responses
+                decoded_key = safely_decode(write_key)
+                key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
                 # The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
                 if len(key_parts) >= 5:
                     key_checkpoint_id = key_parts[3]

langgraph/checkpoint/redis/util.py

+49 -0

@@ -5,8 +5,14 @@
 that is lexicographically sortable. Typically, checkpoints that need
 sentinel values are from the first run of the graph, so this should
 generally be correct.
+
+This module also includes utility functions for safely handling Redis responses,
+including handling bytes vs string responses depending on how the Redis client is
+configured with decode_responses.
 """

+from typing import Any, Dict, List, Optional, Set, Tuple, Union
+
 EMPTY_STRING_SENTINEL = "__empty__"
 EMPTY_ID_SENTINEL = "00000000-0000-0000-0000-000000000000"

@@ -81,3 +87,46 @@ def from_storage_safe_id(value: str) -> str:
         return ""
     else:
         return value
+
+
+def safely_decode(obj: Any) -> Any:
+    """
+    Safely decode Redis responses, handling both string and bytes types.
+
+    This is especially useful when working with Redis clients configured with
+    different decode_responses settings. It recursively processes nested
+    data structures (dicts, lists, tuples, sets).
+
+    Based on RedisVL's convert_bytes function (redisvl.redis.utils.convert_bytes)
+    but implemented directly to avoid runtime import issues and ensure consistent
+    behavior with sets and other data structures. See PR #34 and referenced
+    implementation: https://github.com/redis/redis-vl-python/blob/9f22a9ad4c2166af6462b007833b456448714dd9/redisvl/redis/utils.py#L20
+
+    Args:
+        obj: The object to decode. Can be a string, bytes, or a nested structure
+            containing strings/bytes (dict, list, tuple, set).
+
+    Returns:
+        The decoded object with all bytes converted to strings.
+    """
+    if obj is None:
+        return None
+    elif isinstance(obj, bytes):
+        try:
+            return obj.decode("utf-8")
+        except UnicodeDecodeError:
+            # If decoding fails, return the original bytes
+            return obj
+    elif isinstance(obj, str):
+        return obj
+    elif isinstance(obj, dict):
+        return {safely_decode(k): safely_decode(v) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [safely_decode(item) for item in obj]
+    elif isinstance(obj, tuple):
+        return tuple(safely_decode(item) for item in obj)
+    elif isinstance(obj, set):
+        return {safely_decode(item) for item in obj}
+    else:
+        # For other types (int, float, bool, etc.), return as is
+        return obj

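Taken on its own, the new helper normalizes whatever shape the client returns, including nested containers. A short usage sketch (the key name and local URL are illustrative):

from redis import Redis

from langgraph.checkpoint.redis.util import safely_decode

raw = Redis.from_url("redis://localhost:6379")                          # responses come back as bytes
txt = Redis.from_url("redis://localhost:6379", decode_responses=True)   # responses come back as str

raw.set("checkpoint:demo", "value")

# Both clients yield the plain string "value" after safely_decode:
assert safely_decode(raw.get("checkpoint:demo")) == "value"
assert safely_decode(txt.get("checkpoint:demo")) == "value"

# Nested structures are decoded recursively; undecodable bytes pass through unchanged:
assert safely_decode({b"k": [b"a", (b"b", 1)]}) == {"k": ["a", ("b", 1)]}
assert safely_decode(b"\xff\xfe") == b"\xff\xfe"
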
tests/test_decode_responses.py

+148 -0

@@ -0,0 +1,148 @@
+"""Tests for Redis key decoding functionality."""
+
+import os
+import time
+import uuid
+from typing import Any, Dict, Optional
+
+import pytest
+from redis import Redis
+
+from langgraph.checkpoint.redis.util import safely_decode
+
+
+def test_safely_decode_basic_types():
+    """Test safely_decode function with basic type inputs."""
+    # Test with bytes
+    assert safely_decode(b"test") == "test"
+
+    # Test with string
+    assert safely_decode("test") == "test"
+
+    # Test with None
+    assert safely_decode(None) is None
+
+    # Test with other types
+    assert safely_decode(123) == 123
+    assert safely_decode(1.23) == 1.23
+    assert safely_decode(True) is True
+
+
+def test_safely_decode_nested_structures():
+    """Test safely_decode function with nested data structures."""
+    # Test with dictionary
+    assert safely_decode({b"key": b"value"}) == {"key": "value"}
+    assert safely_decode({b"key1": b"value1", "key2": 123}) == {
+        "key1": "value1",
+        "key2": 123,
+    }
+
+    # Test with nested dictionary
+    nested_dict = {b"outer": {b"inner": b"value"}}
+    assert safely_decode(nested_dict) == {"outer": {"inner": "value"}}
+
+    # Test with list
+    assert safely_decode([b"item1", b"item2"]) == ["item1", "item2"]
+
+    # Test with tuple
+    assert safely_decode((b"item1", b"item2")) == ("item1", "item2")
+
+    # Test with set
+    decoded_set = safely_decode({b"item1", b"item2"})
+    assert isinstance(decoded_set, set)
+    assert "item1" in decoded_set
+    assert "item2" in decoded_set
+
+    # Test with complex nested structure
+    complex_struct = {
+        b"key1": [b"list_item1", {b"nested_key": b"nested_value"}],
+        b"key2": (b"tuple_item", 123),
+        b"key3": {b"set_item1", b"set_item2"},
+    }
+    decoded = safely_decode(complex_struct)
+    assert decoded["key1"][0] == "list_item1"
+    assert decoded["key1"][1]["nested_key"] == "nested_value"
+    assert decoded["key2"][0] == "tuple_item"
+    assert decoded["key2"][1] == 123
+    assert isinstance(decoded["key3"], set)
+    assert "set_item1" in decoded["key3"]
+    assert "set_item2" in decoded["key3"]
+
+
+@pytest.mark.parametrize("decode_responses", [True, False])
+def test_safely_decode_with_redis(decode_responses: bool, redis_url):
+    """Test safely_decode function with actual Redis responses using TestContainers."""
+    r = Redis.from_url(redis_url, decode_responses=decode_responses)
+
+    try:
+        # Clean up before test to ensure a clean state
+        r.delete("test:string")
+        r.delete("test:hash")
+        r.delete("test:list")
+        r.delete("test:set")
+
+        # Set up test data
+        r.set("test:string", "value")
+        r.hset("test:hash", mapping={"field1": "value1", "field2": "value2"})
+        r.rpush("test:list", "item1", "item2", "item3")
+        r.sadd("test:set", "member1", "member2")
+
+        # Test string value
+        string_val = r.get("test:string")
+        decoded_string = safely_decode(string_val)
+        assert decoded_string == "value"
+
+        # Test hash value
+        hash_val = r.hgetall("test:hash")
+        decoded_hash = safely_decode(hash_val)
+        assert decoded_hash == {"field1": "value1", "field2": "value2"}
+
+        # Test list value
+        list_val = r.lrange("test:list", 0, -1)
+        decoded_list = safely_decode(list_val)
+        assert decoded_list == ["item1", "item2", "item3"]
+
+        # Test set value
+        set_val = r.smembers("test:set")
+        decoded_set = safely_decode(set_val)
+        assert isinstance(decoded_set, set)
+        assert "member1" in decoded_set
+        assert "member2" in decoded_set
+
+        # Test key fetching
+        keys = r.keys("test:*")
+        decoded_keys = safely_decode(keys)
+        assert sorted(decoded_keys) == sorted(
+            ["test:string", "test:hash", "test:list", "test:set"]
+        )
+
+    finally:
+        # Clean up after test
+        r.delete("test:string")
+        r.delete("test:hash")
+        r.delete("test:list")
+        r.delete("test:set")
+        r.close()
+
+
+def test_safely_decode_unicode_error_handling():
+    """Test safely_decode function with invalid UTF-8 bytes."""
+    # Create bytes that will cause UnicodeDecodeError
+    invalid_utf8 = b"\xff\xfe\xfd"
+
+    # Should return the original bytes if it can't be decoded
+    result = safely_decode(invalid_utf8)
+    assert result == invalid_utf8
+
+    # Test with mixed valid and invalid in a complex structure
+    mixed = {
+        b"valid": b"This is valid UTF-8",
+        b"invalid": invalid_utf8,
+        b"nested": [b"valid", invalid_utf8],
+    }
+
+    result = safely_decode(mixed)
+    assert result["valid"] == "This is valid UTF-8"
+    assert result["invalid"] == invalid_utf8
+    assert result["nested"][0] == "valid"
+    assert result["nested"][1] == invalid_utf8

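The parametrized Redis test above depends on a redis_url pytest fixture that is not part of this diff; its docstring notes the suite runs against TestContainers. A hypothetical sketch of such a fixture (the conftest.py location, fixture scope, and container image are assumptions, not the project's actual setup):

# conftest.py -- hypothetical sketch, not part of this commit
import pytest
from testcontainers.redis import RedisContainer


@pytest.fixture(scope="session")
def redis_url():
    # Start a throwaway Redis container and hand its URL to the tests.
    with RedisContainer("redis:7") as container:
        host = container.get_container_host_ip()
        port = container.get_exposed_port(6379)
        yield f"redis://{host}:{port}"
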