Skip to content

Commit

Permalink
refactor: improve serialization/deserialization of callables (to handle class methods and static methods) (#8683)
Browse files Browse the repository at this point in the history

* progress

* refinements

* tidy up

* release note
  • Loading branch information
anakin87 authored Jan 8, 2025
1 parent e6059e6 commit 5539f6c
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 76 deletions.
58 changes: 43 additions & 15 deletions haystack/utils/callable_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Callable, Optional
from typing import Callable

from haystack import DeserializationError
from haystack.core.errors import DeserializationError, SerializationError
from haystack.utils.type_serialization import thread_safe_import


Expand All @@ -16,32 +16,60 @@ def serialize_callable(callable_handle: Callable) -> str:
:param callable_handle: The callable to serialize
:return: The full path of the callable
"""
module = inspect.getmodule(callable_handle)
try:
full_arg_spec = inspect.getfullargspec(callable_handle)
is_instance_method = bool(full_arg_spec.args and full_arg_spec.args[0] == "self")
except TypeError:
is_instance_method = False
if is_instance_method:
raise SerializationError("Serialization of instance methods is not supported.")

# __qualname__ contains the fully qualified path we need for classmethods and staticmethods
qualname = getattr(callable_handle, "__qualname__", "")
if "<lambda>" in qualname:
raise SerializationError("Serialization of lambdas is not supported.")
if "<locals>" in qualname:
raise SerializationError("Serialization of nested functions is not supported.")

name = qualname or callable_handle.__name__

# Get the full package path of the function
module = inspect.getmodule(callable_handle)
if module is not None:
full_path = f"{module.__name__}.{callable_handle.__name__}"
full_path = f"{module.__name__}.{name}"
else:
full_path = callable_handle.__name__
full_path = name
return full_path


def deserialize_callable(callable_handle: str) -> Callable:
    """
    Deserializes a callable given its full import path as a string.

    The longest importable dotted prefix of the path is used as the module; the remaining parts are resolved
    as a chain of attributes on it (for example a class followed by one of its classmethods or staticmethods).

    :param callable_handle: The full path of the callable_handle
    :return: The callable
    :raises DeserializationError: If the callable cannot be found
    """
    parts = callable_handle.split(".")

    # Importing only the first component (e.g. "pkg" for "pkg.sub.mod.func") and walking the rest with getattr
    # fails whenever a submodule has not already been imported by its parent package. Instead, try the longest
    # dotted prefix first and shorten it until an import succeeds. For dotted paths the last component is always
    # kept as an attribute, since a callable can never be the module itself.
    # NOTE(review): assumes thread_safe_import accepts dotted module names like importlib.import_module - confirm.
    longest_prefix = max(len(parts) - 1, 1)
    import_error = None
    for prefix_length in range(longest_prefix, 0, -1):
        module_name = ".".join(parts[:prefix_length])
        try:
            current = thread_safe_import(module_name)
        except Exception as e:
            # remember the failure; after the loop this is the error raised for the top-level package
            import_error = e
            continue
        attribute_chain = parts[prefix_length:]
        break
    else:
        raise DeserializationError(f"Could not locate the module: {parts[0]}") from import_error

    for attr in attribute_chain:
        try:
            attr_value = getattr(current, attr)
        except AttributeError as e:
            # not every intermediate object is guaranteed to expose __name__
            owner = getattr(current, "__name__", repr(current))
            raise DeserializationError(f"Could not find attribute '{attr}' in {owner}") from e

        # when the attribute is a classmethod, we need the underlying function
        if isinstance(attr_value, (classmethod, staticmethod)):
            attr_value = attr_value.__func__

        current = attr_value

    if not callable(current):
        raise DeserializationError(f"The final attribute is not callable: {current}")

    return current
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
enhancements:
- |
Improve serialization and deserialization of callables.
We now allow serialization of classmethods and staticmethods
and explicitly prohibit serialization of instance methods, lambdas, and nested functions.
3 changes: 2 additions & 1 deletion test/components/generators/chat/test_hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_to_dict(self, model_info_mock):
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=lambda x: x,
streaming_callback=streaming_callback_handler,
chat_template="irrelevant",
)

Expand All @@ -155,6 +155,7 @@ def test_to_dict(self, model_info_mock):
assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
assert "token" not in init_params["huggingface_pipeline_kwargs"]
assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
assert init_params["streaming_callback"] == "chat.test_hugging_face_local.streaming_callback_handler"
assert init_params["chat_template"] == "irrelevant"

def test_from_dict(self, model_info_mock):
Expand Down
33 changes: 3 additions & 30 deletions test/components/generators/chat/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest

from typing import Iterator

import logging
import os
import json
from datetime import datetime

from openai import OpenAIError
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.chat import chat_completion_chunk
from openai import Stream

from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import StreamingChunk
from haystack.utils.auth import Secret
from haystack.dataclasses import ChatMessage, Tool, ToolCall, ChatRole, TextContent
from haystack.dataclasses import ChatMessage, Tool, ToolCall
from haystack.components.generators.chat.openai import OpenAIChatGenerator


Expand Down Expand Up @@ -212,31 +210,6 @@ def test_to_dict_with_parameters(self, monkeypatch):
},
}

def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = OpenAIChatGenerator(
model="gpt-4o-mini",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "gpt-4o-mini",
"organization": None,
"api_base_url": "test-base-url",
"max_retries": None,
"timeout": None,
"streaming_callback": "chat.test_openai.<lambda>",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"tools": None,
"tools_strict": False,
},
}

def test_from_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
data = {
Expand Down
22 changes: 0 additions & 22 deletions test/components/generators/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,28 +90,6 @@ def test_to_dict_with_parameters(self, monkeypatch):
},
}

def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = OpenAIGenerator(
model="gpt-4o-mini",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "gpt-4o-mini",
"system_prompt": None,
"organization": None,
"api_base_url": "test-base-url",
"streaming_callback": "test_openai.<lambda>",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}

def test_from_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
data = {
Expand Down
3 changes: 0 additions & 3 deletions test/components/preprocessors/test_document_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,6 @@ def test_from_dict_with_splitting_function(self):
Test the from_dict class method of the DocumentSplitter class when a custom splitting function is provided.
"""

def custom_split(text):
return text.split(".")

data = {
"type": "haystack.components.preprocessors.document_splitter.DocumentSplitter",
"init_parameters": {"split_by": "function", "splitting_function": serialize_callable(custom_split)},
Expand Down
63 changes: 58 additions & 5 deletions test/utils/test_callable_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import requests

from haystack import DeserializationError
from haystack.core.errors import DeserializationError, SerializationError
from haystack.components.generators.utils import print_streaming_chunk
from haystack.utils import serialize_callable, deserialize_callable

Expand All @@ -13,6 +12,19 @@ def some_random_callable_for_testing(some_ignored_arg: str):
pass


# Fixture class for the callable (de)serialization tests in this file: it exposes
# one classmethod, one staticmethod and one instance method, all with empty bodies.
class TestClass:
@classmethod
def class_method(cls):
pass

@staticmethod
def static_method():
pass

def my_method(self):
pass


def test_callable_serialization():
result = serialize_callable(some_random_callable_for_testing)
assert result == "test_callable_serialization.some_random_callable_for_testing"
Expand All @@ -28,6 +40,28 @@ def test_callable_serialization_non_local():
assert result == "requests.api.get"


# serialize_callable must raise SerializationError for an instance method, both when
# it is looked up on the class (TestClass.my_method) and when it is bound to an instance.
def test_callable_serialization_instance_methods_fail():
with pytest.raises(SerializationError):
serialize_callable(TestClass.my_method)

instance = TestClass()
with pytest.raises(SerializationError):
serialize_callable(instance.my_method)


# serialize_callable must raise SerializationError when given a lambda.
def test_lambda_serialization_fail():
with pytest.raises(SerializationError):
serialize_callable(lambda x: x)


# serialize_callable must raise SerializationError for a function defined inside
# another function (here: my_fun, local to this test).
def test_nested_function_serialization_fail():
def my_fun():
pass

with pytest.raises(SerializationError):
serialize_callable(my_fun)


def test_callable_deserialization():
result = serialize_callable(some_random_callable_for_testing)
fn = deserialize_callable(result)
Expand All @@ -40,8 +74,27 @@ def test_callable_deserialization_non_local():
assert fn is requests.api.get


def test_callable_deserialization_error():
# Round trip: serializing TestClass.class_method and deserializing the resulting
# string must yield an object that compares equal to TestClass.class_method.
def test_classmethod_serialization_deserialization():
result = serialize_callable(TestClass.class_method)
fn = deserialize_callable(result)
assert fn == TestClass.class_method


# Round trip: serializing TestClass.static_method and deserializing the resulting
# string must yield an object that compares equal to TestClass.static_method.
def test_staticmethod_serialization_deserialization():
result = serialize_callable(TestClass.static_method)
fn = deserialize_callable(result)
assert fn == TestClass.static_method


def test_callable_deserialization_errors():
# module does not exist
with pytest.raises(DeserializationError):
deserialize_callable("this.is.not.a.valid.module")
deserialize_callable("nonexistent_module.function")

# function does not exist
with pytest.raises(DeserializationError):
deserialize_callable("os.nonexistent_function")

# attribute is not callable
with pytest.raises(DeserializationError):
deserialize_callable("sys.foobar")
deserialize_callable("os.name")

0 comments on commit 5539f6c

Please sign in to comment.