Skip to content

Commit 61d163e

Browse files
authored
feat: Add timeout, max_retries to all generators and async support to AnthropicVertexChatGenerator (#1952)
* Add timeout and max_retries to AnthropicChatGenerator * Add timeout, max_retries and async client to AnthropicVertexChatGenerator * Add async integration test for AnthropicVertexChatGenerator * Also add timeout and max_retries to regular generator * Update live test
1 parent 15e12c7 commit 61d163e

File tree

6 files changed

+129
-10
lines changed

6 files changed

+129
-10
lines changed

integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ def __init__(
195195
generation_kwargs: Optional[Dict[str, Any]] = None,
196196
ignore_tools_thinking_messages: bool = True,
197197
tools: Optional[Union[List[Tool], Toolset]] = None,
198+
*,
199+
timeout: Optional[float] = None,
200+
max_retries: Optional[int] = None,
198201
):
199202
"""
200203
Creates an instance of AnthropicChatGenerator.
@@ -222,16 +225,32 @@ def __init__(
222225
use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use)
223226
for more details.
224227
:param tools: A list of Tool objects or a Toolset that the model can use. Each tool should have a unique name.
225-
228+
:param timeout:
229+
Timeout for Anthropic client calls. If not set, it defaults to the default set by the Anthropic client.
230+
:param max_retries:
231+
Maximum number of retries to attempt for failed requests. If not set, it defaults to the default set by
232+
the Anthropic client.
226233
"""
227234
_check_duplicate_tool_names(list(tools or [])) # handles Toolset as well
228235

229236
self.api_key = api_key
230237
self.model = model
231238
self.generation_kwargs = generation_kwargs or {}
232239
self.streaming_callback = streaming_callback
233-
self.client = Anthropic(api_key=self.api_key.resolve_value())
234-
self.async_client = AsyncAnthropic(api_key=self.api_key.resolve_value())
240+
self.timeout = timeout
241+
self.max_retries = max_retries
242+
243+
client_kwargs: Dict[str, Any] = {"api_key": api_key.resolve_value()}
244+
# We do this since timeout=None is not the same as not setting it in Anthropic
245+
if timeout is not None:
246+
client_kwargs["timeout"] = timeout
247+
# We do this since max_retries must be an int when passing to Anthropic
248+
if max_retries is not None:
249+
client_kwargs["max_retries"] = max_retries
250+
251+
self.client = Anthropic(**client_kwargs)
252+
self.async_client = AsyncAnthropic(**client_kwargs)
253+
235254
self.ignore_tools_thinking_messages = ignore_tools_thinking_messages
236255
self.tools = tools
237256

@@ -257,6 +276,8 @@ def to_dict(self) -> Dict[str, Any]:
257276
api_key=self.api_key.to_dict(),
258277
ignore_tools_thinking_messages=self.ignore_tools_thinking_messages,
259278
tools=serialize_tools_or_toolset(self.tools),
279+
timeout=self.timeout,
280+
max_retries=self.max_retries,
260281
)
261282

262283
@classmethod

integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_or_toolset_inplace
77
from haystack.utils import deserialize_callable, serialize_callable
88

9-
from anthropic import AnthropicVertex
9+
from anthropic import AnthropicVertex, AsyncAnthropicVertex
1010

1111
from .chat_generator import AnthropicChatGenerator
1212

@@ -68,6 +68,9 @@ def __init__(
6868
generation_kwargs: Optional[Dict[str, Any]] = None,
6969
ignore_tools_thinking_messages: bool = True,
7070
tools: Optional[List[Tool]] = None,
71+
*,
72+
timeout: Optional[float] = None,
73+
max_retries: Optional[int] = None,
7174
):
7275
"""
7376
Creates an instance of AnthropicVertexChatGenerator.
@@ -96,6 +99,11 @@ def __init__(
9699
use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use)
97100
for more details.
98101
:param tools: A list of Tool objects that the model can use. Each tool should have a unique name.
102+
:param timeout:
103+
Timeout for Anthropic client calls. If not set, it defaults to the default set by the Anthropic client.
104+
:param max_retries:
105+
Maximum number of retries to attempt for failed requests. If not set, it defaults to the default set by
106+
the Anthropic client.
99107
"""
100108
_check_duplicate_tool_names(tools)
101109
self.region = region or os.environ.get("REGION")
@@ -105,9 +113,20 @@ def __init__(
105113
self.streaming_callback = streaming_callback
106114
self.ignore_tools_thinking_messages = ignore_tools_thinking_messages
107115
self.tools = tools
116+
self.timeout = timeout
117+
self.max_retries = max_retries
108118

109-
# mypy is not happy that we override the type of the client
110-
self.client = AnthropicVertex(region=self.region, project_id=self.project_id) # type: ignore
119+
client_kwargs: Dict[str, Any] = {"region": self.region, "project_id": self.project_id}
120+
# We do this since timeout=None is not the same as not setting it in Anthropic
121+
if timeout is not None:
122+
client_kwargs["timeout"] = timeout
123+
# We do this since max_retries must be an int when passing to Anthropic
124+
if max_retries is not None:
125+
client_kwargs["max_retries"] = max_retries
126+
127+
# mypy is not happy that we override the type of the clients
128+
self.client = AnthropicVertex(**client_kwargs) # type: ignore[assignment]
129+
self.async_client = AsyncAnthropicVertex(**client_kwargs) # type: ignore[assignment]
111130

112131
def to_dict(self) -> Dict[str, Any]:
113132
"""
@@ -128,6 +147,8 @@ def to_dict(self) -> Dict[str, Any]:
128147
generation_kwargs=self.generation_kwargs,
129148
ignore_tools_thinking_messages=self.ignore_tools_thinking_messages,
130149
tools=serialized_tools,
150+
timeout=self.timeout,
151+
max_retries=self.max_retries,
131152
)
132153

133154
@classmethod

integrations/anthropic/src/haystack_integrations/components/generators/anthropic/generator.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def __init__(
6464
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
6565
system_prompt: Optional[str] = None,
6666
generation_kwargs: Optional[Dict[str, Any]] = None,
67+
*,
68+
timeout: Optional[float] = None,
69+
max_retries: Optional[int] = None,
6770
):
6871
"""
6972
Initialize the AnthropicGenerator.
@@ -79,12 +82,24 @@ def __init__(
7982
self.generation_kwargs = generation_kwargs or {}
8083
self.streaming_callback = streaming_callback
8184
self.system_prompt = system_prompt
82-
self.client = Anthropic(api_key=self.api_key.resolve_value())
85+
self.timeout = timeout
86+
self.max_retries = max_retries
87+
8388
self.include_thinking = self.generation_kwargs.pop("include_thinking", True)
8489
self.thinking_tag = self.generation_kwargs.pop("thinking_tag", "thinking")
8590
self.thinking_tag_start = f"<{self.thinking_tag}>" if self.thinking_tag else ""
8691
self.thinking_tag_end = f"</{self.thinking_tag}>\n\n" if self.thinking_tag else "\n\n"
8792

93+
client_kwargs: Dict[str, Any] = {"api_key": api_key.resolve_value()}
94+
# We do this since timeout=None is not the same as not setting it in Anthropic
95+
if timeout is not None:
96+
client_kwargs["timeout"] = timeout
97+
# We do this since max_retries must be an int when passing to Anthropic
98+
if max_retries is not None:
99+
client_kwargs["max_retries"] = max_retries
100+
101+
self.client = Anthropic(**client_kwargs)
102+
88103
def _get_telemetry_data(self) -> Dict[str, Any]:
89104
"""
90105
Get telemetry data for the component.
@@ -103,11 +118,13 @@ def to_dict(self) -> Dict[str, Any]:
103118
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
104119
return default_to_dict(
105120
self,
121+
api_key=self.api_key.to_dict(),
106122
model=self.model,
107123
streaming_callback=callback_name,
108124
system_prompt=self.system_prompt,
109125
generation_kwargs=self.generation_kwargs,
110-
api_key=self.api_key.to_dict(),
126+
timeout=self.timeout,
127+
max_retries=self.max_retries,
111128
)
112129

113130
@classmethod

integrations/anthropic/tests/test_chat_generator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ def test_to_dict_default(self, monkeypatch):
165165
"ignore_tools_thinking_messages": True,
166166
"generation_kwargs": {},
167167
"tools": None,
168+
"timeout": None,
169+
"max_retries": None,
168170
},
169171
}
170172

@@ -181,6 +183,8 @@ def test_to_dict_with_parameters(self, monkeypatch):
181183
streaming_callback=print_streaming_chunk,
182184
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
183185
tools=[tool],
186+
timeout=10.0,
187+
max_retries=1,
184188
)
185189
data = component.to_dict()
186190

@@ -207,6 +211,8 @@ def test_to_dict_with_parameters(self, monkeypatch):
207211
"type": "haystack.tools.tool.Tool",
208212
}
209213
],
214+
"timeout": 10.0,
215+
"max_retries": 1,
210216
},
211217
}
212218

@@ -697,6 +703,8 @@ def test_serde_in_pipeline(self):
697703
},
698704
}
699705
],
706+
"timeout": None,
707+
"max_retries": None,
700708
},
701709
}
702710
},
@@ -775,7 +783,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
775783
self.responses += chunk.content if chunk.content else ""
776784

777785
callback = Callback()
778-
component = AnthropicChatGenerator(streaming_callback=callback)
786+
component = AnthropicChatGenerator(streaming_callback=callback, timeout=30.0, max_retries=1)
779787
results = component.run([ChatMessage.from_user("What's the capital of France?")])
780788

781789
assert len(results["replies"]) == 1

integrations/anthropic/tests/test_generator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def test_to_dict_default(self, monkeypatch):
4747
"streaming_callback": None,
4848
"system_prompt": None,
4949
"generation_kwargs": {},
50+
"timeout": None,
51+
"max_retries": None,
5052
},
5153
}
5254

@@ -57,6 +59,8 @@ def test_to_dict_with_parameters(self, monkeypatch):
5759
streaming_callback=print_streaming_chunk,
5860
system_prompt="test-prompt",
5961
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
62+
timeout=10.0,
63+
max_retries=1,
6064
)
6165
data = component.to_dict()
6266
assert data == {
@@ -67,6 +71,8 @@ def test_to_dict_with_parameters(self, monkeypatch):
6771
"system_prompt": "test-prompt",
6872
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
6973
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
74+
"timeout": 10.0,
75+
"max_retries": 1,
7076
},
7177
}
7278

@@ -80,6 +86,8 @@ def test_from_dict(self, monkeypatch):
8086
"system_prompt": "test-prompt",
8187
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
8288
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
89+
"timeout": None,
90+
"max_retries": None,
8391
},
8492
}
8593
component = AnthropicGenerator.from_dict(data)
@@ -88,6 +96,8 @@ def test_from_dict(self, monkeypatch):
8896
assert component.system_prompt == "test-prompt"
8997
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
9098
assert component.api_key == Secret.from_env_var("ANTHROPIC_API_KEY")
99+
assert component.timeout is None
100+
assert component.max_retries is None
91101

92102
def test_from_dict_fail_wo_env_var(self, monkeypatch):
93103
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
@@ -99,6 +109,8 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
99109
"system_prompt": "test-prompt",
100110
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
101111
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
112+
"timeout": None,
113+
"max_retries": None,
102114
},
103115
}
104116
with pytest.raises(ValueError, match="None of the .* environment variables are set"):

integrations/anthropic/tests/test_vertex_chat_generator.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def test_to_dict_default(self):
5858
"generation_kwargs": {},
5959
"ignore_tools_thinking_messages": True,
6060
"tools": None,
61+
"timeout": None,
62+
"max_retries": None,
6163
},
6264
}
6365

@@ -67,6 +69,9 @@ def test_to_dict_with_parameters(self):
6769
project_id="test-project-id",
6870
streaming_callback=print_streaming_chunk,
6971
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
72+
ignore_tools_thinking_messages=False,
73+
timeout=10.0,
74+
max_retries=1,
7075
)
7176
data = component.to_dict()
7277
assert data == {
@@ -80,8 +85,10 @@ def test_to_dict_with_parameters(self):
8085
"model": "claude-3-5-sonnet@20240620",
8186
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
8287
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
83-
"ignore_tools_thinking_messages": True,
88+
"ignore_tools_thinking_messages": False,
8489
"tools": None,
90+
"timeout": 10.0,
91+
"max_retries": 1,
8592
},
8693
}
8794

@@ -98,6 +105,9 @@ def test_from_dict(self):
98105
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
99106
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
100107
"ignore_tools_thinking_messages": True,
108+
"tools": None,
109+
"timeout": None,
110+
"max_retries": None,
101111
},
102112
}
103113
component = AnthropicVertexChatGenerator.from_dict(data)
@@ -106,6 +116,9 @@ def test_from_dict(self):
106116
assert component.project_id == "test-project-id"
107117
assert component.streaming_callback is print_streaming_chunk
108118
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
119+
assert component.ignore_tools_thinking_messages is True
120+
assert component.timeout is None
121+
assert component.max_retries is None
109122

110123
def test_run(self, chat_messages, mock_chat_completion):
111124
component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id")
@@ -173,3 +186,30 @@ def test_default_inference_params(self, chat_messages):
173186

174187
# Anthropic messages API is similar for AnthropicVertex and Anthropic endpoint,
175188
# remaining tests are skipped for AnthropicVertexChatGenerator as they are already tested in AnthropicChatGenerator.
189+
190+
191+
class TestAnthropicVertexChatGeneratorAsync:
192+
@pytest.mark.asyncio
193+
@pytest.mark.skipif(
194+
not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)),
195+
reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.",
196+
)
197+
@pytest.mark.integration
198+
async def test_live_run_async(self):
199+
"""
200+
Integration test that the async run method of AnthropicVertexChatGenerator works correctly
201+
"""
202+
component = AnthropicVertexChatGenerator(
203+
region=os.environ.get("REGION"),
204+
project_id=os.environ.get("PROJECT_ID"),
205+
model="claude-3-5-sonnet@20240620",
206+
)
207+
results = await component.run_async(messages=[ChatMessage.from_user("What's the capital of France?")])
208+
assert len(results["replies"]) == 1
209+
message: ChatMessage = results["replies"][0]
210+
assert "Paris" in message.text
211+
assert "claude-3-5-sonnet-20240620" in message.meta["model"]
212+
assert message.meta["finish_reason"] == "end_turn"
213+
214+
# Anthropic messages API is similar for AnthropicVertex and Anthropic endpoint,
215+
# remaining tests are skipped for AnthropicVertexChatGenerator as they are already tested in AnthropicChatGenerator.

0 commit comments

Comments
 (0)