From 8e3f64717f8c601bd7ef58b68e044a84deceb4eb Mon Sep 17 00:00:00 2001 From: Bohan Qu Date: Fri, 3 Jan 2025 22:06:58 +0800 Subject: [PATCH] feat: use importlib when deserializing callables (#8648) --- haystack/utils/callable_serialization.py | 9 +++++---- ...lib-when-deserializing-callable-1f36f07c4518c2cf.yaml | 6 ++++++ test/utils/test_callable_serialization.py | 9 +++++++++ 3 files changed, 20 insertions(+), 4 deletions(-) create mode 100644 releasenotes/notes/use-importlib-when-deserializing-callable-1f36f07c4518c2cf.yaml diff --git a/haystack/utils/callable_serialization.py b/haystack/utils/callable_serialization.py index 72a57c5ab3..3c6003135e 100644 --- a/haystack/utils/callable_serialization.py +++ b/haystack/utils/callable_serialization.py @@ -3,10 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -import sys from typing import Callable, Optional from haystack import DeserializationError +from haystack.utils.type_serialization import thread_safe_import def serialize_callable(callable_handle: Callable) -> str: @@ -37,9 +37,10 @@ def deserialize_callable(callable_handle: str) -> Optional[Callable]: parts = callable_handle.split(".") module_name = ".".join(parts[:-1]) function_name = parts[-1] - module = sys.modules.get(module_name, None) - if not module: - raise DeserializationError(f"Could not locate the module of the callable: {module_name}") + try: + module = thread_safe_import(module_name) + except Exception as e: + raise DeserializationError(f"Could not locate the module of the callable: {module_name}") from e deserialized_callable = getattr(module, function_name, None) if not deserialized_callable: raise DeserializationError(f"Could not locate the callable: {function_name}") diff --git a/releasenotes/notes/use-importlib-when-deserializing-callable-1f36f07c4518c2cf.yaml b/releasenotes/notes/use-importlib-when-deserializing-callable-1f36f07c4518c2cf.yaml new file mode 100644 index 0000000000..7160479cc5 --- /dev/null +++ b/releasenotes/notes/use-importlib-when-deserializing-callable-1f36f07c4518c2cf.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Improved deserialization of callables by using `importlib` instead of `sys.modules`. + This change allows importing local functions and classes that are not in `sys.modules` + when deserializing callables. diff --git a/test/utils/test_callable_serialization.py b/test/utils/test_callable_serialization.py index c0afafa73e..941aa14cdf 100644 --- a/test/utils/test_callable_serialization.py +++ b/test/utils/test_callable_serialization.py @@ -1,8 +1,10 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +import pytest import requests +from haystack import DeserializationError from haystack.components.generators.utils import print_streaming_chunk from haystack.utils import serialize_callable, deserialize_callable @@ -36,3 +38,10 @@ def test_callable_deserialization_non_local(): result = serialize_callable(requests.api.get) fn = deserialize_callable(result) assert fn is requests.api.get + + +def test_callable_deserialization_error(): + with pytest.raises(DeserializationError): + deserialize_callable("this.is.not.a.valid.module") + with pytest.raises(DeserializationError): + deserialize_callable("sys.foobar")