|
3 | 3 | # SPDX-License-Identifier: Apache-2.0
|
4 | 4 |
|
5 | 5 | import inspect
|
6 |
| -from typing import Callable |
| 6 | +from typing import Any, Callable |
7 | 7 |
|
8 | 8 | from haystack.core.errors import DeserializationError, SerializationError
|
9 | 9 | from haystack.utils.type_serialization import thread_safe_import
|
@@ -50,26 +50,31 @@ def deserialize_callable(callable_handle: str) -> Callable:
|
50 | 50 | :return: The callable
|
51 | 51 | :raises DeserializationError: If the callable cannot be found
|
52 | 52 | """
|
53 |
| - module_name, *attribute_chain = callable_handle.split(".") |
| 53 | + parts = callable_handle.split(".") |
54 | 54 |
|
55 |
| - try: |
56 |
| - current = thread_safe_import(module_name) |
57 |
| - except Exception as e: |
58 |
| - raise DeserializationError(f"Could not locate the module: {module_name}") from e |
59 |
| - |
60 |
| - for attr in attribute_chain: |
| 55 | + for i in range(len(parts), 0, -1): |
| 56 | + module_name = ".".join(parts[:i]) |
61 | 57 | try:
|
62 |
| - attr_value = getattr(current, attr) |
63 |
| - except AttributeError as e: |
64 |
| - raise DeserializationError(f"Could not find attribute '{attr}' in {current.__name__}") from e |
| 58 | + mod: Any = thread_safe_import(module_name) |
| 59 | + except Exception: |
| 60 | + # keep reducing i until we find a valid module import |
| 61 | + continue |
| 62 | + |
| 63 | + attr_value = mod |
| 64 | + for part in parts[i:]: |
| 65 | + try: |
| 66 | + attr_value = getattr(attr_value, part) |
| 67 | + except AttributeError as e: |
| 68 | + raise DeserializationError(f"Could not find attribute '{part}' in {attr_value.__name__}") from e |
65 | 69 |
|
66 | 70 | # when the attribute is a classmethod, we need the underlying function
|
67 | 71 | if isinstance(attr_value, (classmethod, staticmethod)):
|
68 | 72 | attr_value = attr_value.__func__
|
69 | 73 |
|
70 |
| - current = attr_value |
| 74 | + if not callable(attr_value): |
| 75 | + raise DeserializationError(f"The final attribute is not callable: {attr_value}") |
71 | 76 |
|
72 |
| - if not callable(current): |
73 |
| - raise DeserializationError(f"The final attribute is not callable: {current}") |
| 77 | + return attr_value |
74 | 78 |
|
75 |
| - return current |
| 79 | + # Fallback if we never find anything |
| 80 | + raise DeserializationError(f"Could not import '{callable_handle}' as a module or callable.") |
0 commit comments