Skip to content

Commit

Permalink
fix: component checks failing for components that return dataframes (#…
Browse files Browse the repository at this point in the history
…8873)

* fix: use is not to compare to sentinel value

* chore: release notes

* Update releasenotes/notes/fix-component-checks-with-ambiguous-truth-values-949c447b3702e427.yaml

Co-authored-by: David S. Batista <[email protected]>

* fix: another sentinel value

* test: also test base class

* add pandas as test dependency

* format

* Trigger CI

* mark test with xfail strict=False

---------

Co-authored-by: Sebastian Husch Lee <[email protected]>
Co-authored-by: David S. Batista <[email protected]>
Co-authored-by: anakin87 <[email protected]>
  • Loading branch information
4 people authored Feb 19, 2025
1 parent 93f361e commit 8c54f06
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 9 deletions.
2 changes: 1 addition & 1 deletion haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict
greedy_inputs_to_remove = set()
for socket_name, socket in component["input_sockets"].items():
socket_inputs = component_inputs.get(socket_name, [])
socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] != _NO_OUTPUT_PRODUCED]
socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED]
if socket_inputs:
if not socket.is_variadic:
# We only care about the first input provided to the socket.
Expand Down
16 changes: 11 additions & 5 deletions haystack/core/pipeline/component_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def any_socket_value_from_predecessor_received(socket_inputs: List[Dict[str, Any
:param socket_inputs: Inputs for the component's socket.
"""
# When sender is None, the input was provided from outside the pipeline.
return any(inp["value"] != _NO_OUTPUT_PRODUCED and inp["sender"] is not None for inp in socket_inputs)
return any(inp["value"] is not _NO_OUTPUT_PRODUCED and inp["sender"] is not None for inp in socket_inputs)


def has_user_input(inputs: Dict) -> bool:
Expand Down Expand Up @@ -143,7 +143,7 @@ def any_socket_input_received(socket_inputs: List[Dict]) -> bool:
:param socket_inputs: Inputs for the socket.
"""
return any(inp["value"] != _NO_OUTPUT_PRODUCED for inp in socket_inputs)
return any(inp["value"] is not _NO_OUTPUT_PRODUCED for inp in socket_inputs)


def has_lazy_variadic_socket_received_all_inputs(socket: InputSocket, socket_inputs: List[Dict]) -> bool:
Expand All @@ -155,7 +155,9 @@ def has_lazy_variadic_socket_received_all_inputs(socket: InputSocket, socket_inp
"""
expected_senders = set(socket.senders)
actual_senders = {
sock["sender"] for sock in socket_inputs if sock["value"] != _NO_OUTPUT_PRODUCED and sock["sender"] is not None
sock["sender"]
for sock in socket_inputs
if sock["value"] is not _NO_OUTPUT_PRODUCED and sock["sender"] is not None
}

return expected_senders == actual_senders
Expand All @@ -182,15 +184,19 @@ def has_socket_received_all_inputs(socket: InputSocket, socket_inputs: List[Dict
return False

# The socket is greedy variadic and at least one input was produced, it is complete.
if socket.is_variadic and socket.is_greedy and any(sock["value"] != _NO_OUTPUT_PRODUCED for sock in socket_inputs):
if (
socket.is_variadic
and socket.is_greedy
and any(sock["value"] is not _NO_OUTPUT_PRODUCED for sock in socket_inputs)
):
return True

# The socket is lazy variadic and all expected inputs were produced.
if is_socket_lazy_variadic(socket) and has_lazy_variadic_socket_received_all_inputs(socket, socket_inputs):
return True

# The socket is not variadic and the only expected input is complete.
return not socket.is_variadic and socket_inputs[0]["value"] != _NO_OUTPUT_PRODUCED
return not socket.is_variadic and socket_inputs[0]["value"] is not _NO_OUTPUT_PRODUCED


def all_predecessors_executed(component: Dict, inputs: Dict) -> bool:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ extra-dependencies = [
"langdetect", # TextLanguageRouter and DocumentLanguageClassifier
"openai-whisper>=20231106", # LocalWhisperTranscriber
"arrow>=1.3.0", # Jinja2TimeExtension
"pandas", # Needed for pipeline tests with components that return dataframes

# NamedEntityExtractor
"spacy>=3.8,<3.9",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
fixes:
- |
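Fixed a failure when running pipelines with components that return plain pandas DataFrames.
Socket values are now compared to the no-output sentinel with 'is not' instead of '!=', because the result of '!=' on a DataFrame has an ambiguous truth value and raised an error.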
4 changes: 3 additions & 1 deletion test/components/generators/chat/test_hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,9 @@ def test_live_run_serverless_streaming(self):
not os.environ.get("HF_API_TOKEN", None),
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
)
@pytest.mark.flaky(reruns=3, reruns_delay=10)
@pytest.mark.xfail(
reason="The Hugging Face API can be unstable and this test may fail intermittently", strict=False
)
def test_live_run_with_tools(self, tools):
"""
We test the round trip: generate tool call, pass tool message, generate response.
Expand Down
29 changes: 27 additions & 2 deletions test/core/pipeline/features/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
import pytest
import asyncio
import pandas as pd

from pytest_bdd import when, then, parsers

Expand Down Expand Up @@ -142,15 +143,39 @@ def draw_pipeline(pipeline_data: Tuple[Pipeline, List[PipelineRunData]], request
@then("it should return the expected result")
def check_pipeline_result(pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]):
for res, data in pipeline_result:
assert res.outputs == data.expected_outputs
compare_outputs_with_dataframes(res.outputs, data.expected_outputs)


@then("components are called with the expected inputs")
def check_component_calls(pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]):
for res, data in pipeline_result:
assert res.component_calls == data.expected_component_calls
assert compare_outputs_with_dataframes(res.component_calls, data.expected_component_calls)


@then(parsers.parse("it must have raised {exception_class_name}"))
def check_pipeline_raised(pipeline_result: Exception, exception_class_name: str):
    """
    Check that the captured pipeline result is an exception of the named class.

    :param pipeline_result: The exception object captured from the pipeline run.
    :param exception_class_name: Class name parsed from the step text.
    """
    # Only the bare class name is compared; the module path is not considered.
    raised_class_name = pipeline_result.__class__.__name__
    assert raised_class_name == exception_class_name


def compare_outputs_with_dataframes(actual: Dict, expected: Dict) -> bool:
    """
    Compare two component_calls or pipeline outputs dictionaries where values may contain DataFrames.

    Both arguments are two-level mappings: each outer key maps to an inner dictionary of named values.
    DataFrame values are compared with `DataFrame.equals`; every other value is compared with `==`.

    :param actual: The mapping produced by the run.
    :param expected: The mapping the run is expected to produce.
    :returns: `True` whenever the comparison passes, so the call itself can be wrapped in `assert`.
    :raises AssertionError: If the keys differ at either level or any pair of values is not equal.
    """
    assert actual.keys() == expected.keys()

    for key, actual_data in actual.items():
        expected_data = expected[key]

        assert actual_data.keys() == expected_data.keys()

        for data_key, actual_value in actual_data.items():
            expected_value = expected_data[data_key]

            if isinstance(actual_value, pd.DataFrame) or isinstance(expected_value, pd.DataFrame):
                # `==` involving a DataFrame yields an element-wise frame whose truth value is ambiguous,
                # so a DataFrame/non-DataFrame mismatch must be reported here as a plain assertion failure
                # instead of reaching the generic `==` branch and raising ValueError.
                assert isinstance(actual_value, pd.DataFrame) and isinstance(expected_value, pd.DataFrame), (
                    f"Type mismatch for {key!r}/{data_key!r}: "
                    f"{type(actual_value).__name__} != {type(expected_value).__name__}"
                )
                assert actual_value.equals(expected_value)
            else:
                assert actual_value == expected_value

    return True
1 change: 1 addition & 0 deletions test/core/pipeline/features/pipeline_run.feature
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Feature: Pipeline running
| with a component that has dynamic default inputs |
| with a component that has variadic dynamic default inputs |
| that is a file conversion pipeline with two joiners |
| that has components returning dataframes |

Scenario Outline: Running a bad Pipeline
Given a pipeline <kind>
Expand Down
31 changes: 31 additions & 0 deletions test/core/pipeline/features/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pytest_bdd import scenarios, given
import pytest
import pandas as pd

from haystack import Document, component
from haystack.document_stores.types import DuplicatePolicy
Expand Down Expand Up @@ -5079,3 +5080,33 @@ def pipeline_that_converts_files(pipeline_class):
)
],
)


@given("a pipeline that has components returning dataframes", target_fixture="pipeline_data")
def pipeline_has_components_returning_dataframes(pipeline_class):
    """
    Build a pipeline of two connected components that each return a pandas DataFrame.

    :param pipeline_class: The pipeline class to instantiate.
    :returns: A tuple of the pipeline and a single-element list of run data describing the inputs,
        the expected outputs and the expected component calls.
    """

    def make_frame():
        # Build a new frame on every call so inputs and expectations never share one object.
        return pd.DataFrame({"a": [1, 2], "b": [1, 2]})

    @component
    class DataFramer:
        # Returns a freshly built frame regardless of the frame it receives.
        @component.output_types(dataframe=pd.DataFrame)
        def run(self, dataframe: pd.DataFrame) -> Dict[str, Any]:
            return {"dataframe": make_frame()}

    pipeline = pipeline_class(max_runs_per_component=1)
    for component_name in ("df_1", "df_2"):
        pipeline.add_component(component_name, DataFramer())
    pipeline.connect("df_1", "df_2")

    run_data = PipelineRunData(
        inputs={"df_1": {"dataframe": make_frame()}},
        expected_outputs={"df_2": {"dataframe": make_frame()}},
        expected_component_calls={
            ("df_1", 1): {"dataframe": make_frame()},
            ("df_2", 1): {"dataframe": make_frame()},
        },
    )
    return (pipeline, [run_data])
23 changes: 23 additions & 0 deletions test/core/pipeline/test_component_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic


import pandas as pd


@pytest.fixture
def basic_component():
"""Basic component with one mandatory and one optional input."""
Expand Down Expand Up @@ -130,6 +133,26 @@ def test_component_missing_mandatory_input(self, basic_component):
inputs = {"optional_input": [{"sender": "previous_component", "value": "test"}]}
assert can_component_run(basic_component, inputs) is False

# We added these tests because a component that returned a pandas dataframe caused the pipeline to fail.
# Previously, we compared the value of the socket using '!=' which leads to an error with dataframes.
# Instead, we use 'is not' to compare with the sentinel value.
def test_sockets_with_ambiguous_truth_value(self, basic_component, greedy_variadic_socket, regular_socket):
    """
    Run every socket check against a value with an ambiguous truth value (a pandas DataFrame).
    """
    frame = pd.DataFrame([{"value": 42}])
    socket_inputs = [{"sender": "previous_component", "value": frame}]
    inputs = {"mandatory_input": socket_inputs}
    mandatory_socket = basic_component["input_sockets"]["mandatory_input"]

    assert are_all_sockets_ready(basic_component, inputs, only_check_mandatory=True) is True
    assert any_socket_value_from_predecessor_received(socket_inputs) is True
    assert any_socket_input_received(socket_inputs) is True
    assert has_lazy_variadic_socket_received_all_inputs(mandatory_socket, socket_inputs) is True
    assert has_socket_received_all_inputs(greedy_variadic_socket, socket_inputs) is True
    assert has_socket_received_all_inputs(regular_socket, socket_inputs) is True

def test_component_with_no_trigger_but_all_inputs(self, basic_component):
"""
Test case where all mandatory inputs are present with valid values,
Expand Down
12 changes: 12 additions & 0 deletions test/core/pipeline/test_pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import pytest

import pandas as pd

from haystack import Document
from haystack.core.component import component
from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic, _empty
Expand Down Expand Up @@ -1625,3 +1627,13 @@ def test__consume_component_inputs(self, input_sockets, component_inputs, expect
# Verify
assert consumed == expected_consumed
assert inputs["test_component"] == expected_remaining

def test__consume_component_inputs_with_df(self, regular_input_socket):
    """Check that consuming a component's inputs returns the DataFrame supplied on its socket."""
    frame_data = {"a": [1, 2], "b": [1, 2]}
    # Named `component_spec` so the local does not shadow the imported `component` decorator.
    component_spec = {"input_sockets": {"input1": regular_input_socket}}
    socket_inputs = [{"sender": "sender1", "value": pd.DataFrame(frame_data)}]
    inputs = {"test_component": {"input1": socket_inputs}}

    consumed = PipelineBase._consume_component_inputs("test_component", component_spec, inputs)

    expected_frame = pd.DataFrame(frame_data)
    assert consumed["input1"].equals(expected_frame)

0 comments on commit 8c54f06

Please sign in to comment.