Skip to content

Commit 1a91365

Browse files
authored
fix: callables can be deserialized from fully qualified import path (#8788)
* fix: callables can be deserialized from fully qualified import path * fix: license header * fix: format * fix: types * fix? types * test: extend test case * format * add release notes
1 parent 379711f commit 1a91365

File tree

4 files changed

+43
-15
lines changed

4 files changed

+43
-15
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
6+
def callable_to_deserialize(hello: str) -> str:
7+
"""
8+
A function to test callable deserialization.
9+
"""
10+
return f"{hello}, world!"

Diff for: haystack/utils/callable_serialization.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import inspect
6-
from typing import Callable
6+
from typing import Any, Callable
77

88
from haystack.core.errors import DeserializationError, SerializationError
99
from haystack.utils.type_serialization import thread_safe_import
@@ -50,26 +50,31 @@ def deserialize_callable(callable_handle: str) -> Callable:
5050
:return: The callable
5151
:raises DeserializationError: If the callable cannot be found
5252
"""
53-
module_name, *attribute_chain = callable_handle.split(".")
53+
parts = callable_handle.split(".")
5454

55-
try:
56-
current = thread_safe_import(module_name)
57-
except Exception as e:
58-
raise DeserializationError(f"Could not locate the module: {module_name}") from e
59-
60-
for attr in attribute_chain:
55+
for i in range(len(parts), 0, -1):
56+
module_name = ".".join(parts[:i])
6157
try:
62-
attr_value = getattr(current, attr)
63-
except AttributeError as e:
64-
raise DeserializationError(f"Could not find attribute '{attr}' in {current.__name__}") from e
58+
mod: Any = thread_safe_import(module_name)
59+
except Exception:
60+
# keep reducing i until we find a valid module import
61+
continue
62+
63+
attr_value = mod
64+
for part in parts[i:]:
65+
try:
66+
attr_value = getattr(attr_value, part)
67+
except AttributeError as e:
68+
raise DeserializationError(f"Could not find attribute '{part}' in {attr_value.__name__}") from e
6569

6670
# when the attribute is a classmethod, we need the underlying function
6771
if isinstance(attr_value, (classmethod, staticmethod)):
6872
attr_value = attr_value.__func__
6973

70-
current = attr_value
74+
if not callable(attr_value):
75+
raise DeserializationError(f"The final attribute is not callable: {attr_value}")
7176

72-
if not callable(current):
73-
raise DeserializationError(f"The final attribute is not callable: {current}")
77+
return attr_value
7478

75-
return current
79+
# Fallback if we never find anything
80+
raise DeserializationError(f"Could not import '{callable_handle}' as a module or callable.")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Callable deserialization now works for all fully qualified import paths.

Diff for: test/utils/test_callable_serialization.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
45
import pytest
56
import requests
67
from haystack.core.errors import DeserializationError, SerializationError
78
from haystack.components.generators.utils import print_streaming_chunk
9+
from haystack.testing.callable_serialization.random_callable import callable_to_deserialize
810
from haystack.utils import serialize_callable, deserialize_callable
911

1012

@@ -40,6 +42,13 @@ def test_callable_serialization_non_local():
4042
assert result == "requests.api.get"
4143

4244

45+
def test_fully_qualified_import_deserialization():
46+
func = deserialize_callable("haystack.testing.callable_serialization.random_callable.callable_to_deserialize")
47+
48+
assert func is callable_to_deserialize
49+
assert func("Hello") == "Hello, world!"
50+
51+
4352
def test_callable_serialization_instance_methods_fail():
4453
with pytest.raises(SerializationError):
4554
serialize_callable(TestClass.my_method)

0 commit comments

Comments
 (0)