Skip to content

Commit 8e3f647

Browse files
authored
feat: use importlib when deserializing callables (#8648)
1 parent 7b4d9ba commit 8e3f647

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

haystack/utils/callable_serialization.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import inspect
6-
import sys
76
from typing import Callable, Optional
87

98
from haystack import DeserializationError
9+
from haystack.utils.type_serialization import thread_safe_import
1010

1111

1212
def serialize_callable(callable_handle: Callable) -> str:
@@ -37,9 +37,10 @@ def deserialize_callable(callable_handle: str) -> Optional[Callable]:
3737
parts = callable_handle.split(".")
3838
module_name = ".".join(parts[:-1])
3939
function_name = parts[-1]
40-
module = sys.modules.get(module_name, None)
41-
if not module:
42-
raise DeserializationError(f"Could not locate the module of the callable: {module_name}")
40+
try:
41+
module = thread_safe_import(module_name)
42+
except Exception as e:
43+
raise DeserializationError(f"Could not locate the module of the callable: {module_name}") from e
4344
deserialized_callable = getattr(module, function_name, None)
4445
if not deserialized_callable:
4546
raise DeserializationError(f"Could not locate the callable: {function_name}")
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
enhancements:
3+
- |
4+
Improved deserialization of callables by using `importlib` instead of `sys.modules`.
5+
This change allows importing local functions and classes that are not in `sys.modules`
6+
when deserializing callables.

test/utils/test_callable_serialization.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
import pytest
45
import requests
56

7+
from haystack import DeserializationError
68
from haystack.components.generators.utils import print_streaming_chunk
79
from haystack.utils import serialize_callable, deserialize_callable
810

@@ -36,3 +38,10 @@ def test_callable_deserialization_non_local():
3638
result = serialize_callable(requests.api.get)
3739
fn = deserialize_callable(result)
3840
assert fn is requests.api.get
41+
42+
43+
def test_callable_deserialization_error():
44+
with pytest.raises(DeserializationError):
45+
deserialize_callable("this.is.not.a.valid.module")
46+
with pytest.raises(DeserializationError):
47+
deserialize_callable("sys.foobar")

0 commit comments

Comments
 (0)