Skip to content

Commit e734e6b

Browse files
wochinge authored and julian-risch committed
fix: fix deserialization issues in multi-threading environments (#8651)
1 parent 55e7fbf commit e734e6b

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

haystack/core/pipeline/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import importlib
65
import itertools
76
from collections import defaultdict
87
from copy import deepcopy
@@ -26,7 +25,7 @@
2625
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
2726
from haystack.core.type_utils import _type_name, _types_are_compatible
2827
from haystack.marshal import Marshaller, YamlMarshaller
29-
from haystack.utils import is_in_jupyter
28+
from haystack.utils import is_in_jupyter, type_serialization
3029

3130
from .descriptions import find_pipeline_inputs, find_pipeline_outputs
3231
from .draw import _to_mermaid_image
@@ -161,7 +160,7 @@ def from_dict(
161160
# Import the module first...
162161
module, _ = component_data["type"].rsplit(".", 1)
163162
logger.debug("Trying to import module {module_name}", module_name=module)
164-
importlib.import_module(module)
163+
type_serialization.thread_safe_import(module)
165164
# ...then try again
166165
if component_data["type"] not in component.registry:
167166
raise PipelineError(

haystack/utils/type_serialization.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@
66
import inspect
77
import sys
88
import typing
9+
from threading import Lock
10+
from types import ModuleType
911
from typing import Any, get_args, get_origin
1012

1113
from haystack import DeserializationError
1214

15+
_import_lock = Lock()
16+
1317

1418
def serialize_type(target: Any) -> str:
1519
"""
@@ -132,7 +136,7 @@ def parse_generic_args(args_str):
132136
module = sys.modules.get(module_name)
133137
if not module:
134138
try:
135-
module = importlib.import_module(module_name)
139+
module = thread_safe_import(module_name)
136140
except ImportError as e:
137141
raise DeserializationError(f"Could not import the module: {module_name}") from e
138142

@@ -141,3 +145,17 @@ def parse_generic_args(args_str):
141145
raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")
142146

143147
return deserialized_type
148+
149+
150+
def thread_safe_import(module_name: str) -> ModuleType:
    """
    Import a module while holding the module-level import lock.

    Imports issued concurrently from several threads can race with each other. Every import routed
    through this function runs inside `_import_lock`, so only one of them executes at a time.
    Single-threaded callers only pay for acquiring an uncontended lock.

    NOTE(review): `_import_lock` is created as a plain `threading.Lock`, which is not reentrant. If
    importing `module_name` ends up calling this function again on the same thread, that nested call
    would block - confirm that no imported module does this at import time.

    :param module_name: the module to import
    :returns: the module object returned by `importlib.import_module`
    """
    with _import_lock:
        imported_module = importlib.import_module(module_name)
    return imported_module
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Fixes issues with deserialization of components in multi-threaded environments.

0 commit comments

Comments
 (0)