From 4d27d5c7d9348478ec886d4eab458f463d5f0ffe Mon Sep 17 00:00:00 2001 From: aaakash06 Date: Sat, 8 Mar 2025 21:52:14 +0545 Subject: [PATCH] Add support for per-request metadata/headers in GenerativeModel.generate_content --- google/generativeai/generative_models.py | 32 ++++++++++++++++++++ tests/test_generative_models.py | 32 ++++++++++++++++++++ tests/test_generative_models_async.py | 38 ++++++++++++++++++++++++ 3 files changed, 102 insertions(+) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 8d331a9f6..9d068b785 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -244,6 +244,7 @@ def generate_content( tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, request_options: helper_types.RequestOptionsType | None = None, + extra_headers: dict[str, str] | None = None, ) -> generation_types.GenerateContentResponse: """A multipurpose function to generate responses from the model. 
@@ -318,7 +319,22 @@ def generate_content( if request_options is None: request_options = {} + else: + request_options = ( + request_options.copy() + if isinstance(request_options, dict) + else vars(request_options).copy() + ) + metadata = list(request_options.get("metadata", [])) + if extra_headers: + metadata += list(extra_headers.items()) + + if metadata: + request_options["metadata"] = metadata + elif "metadata" in request_options: + del request_options["metadata"] + try: if stream: with generation_types.rewrite_stream_error(): @@ -351,6 +367,7 @@ async def generate_content_async( tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, request_options: helper_types.RequestOptionsType | None = None, + extra_headers: dict[str, str] | None = None, ) -> generation_types.AsyncGenerateContentResponse: """The async version of `GenerativeModel.generate_content`.""" if not contents: @@ -372,6 +389,21 @@ async def generate_content_async( if request_options is None: request_options = {} + else: + request_options = ( + request_options.copy() + if isinstance(request_options, dict) + else vars(request_options).copy() + ) + + metadata = list(request_options.get("metadata", [])) + if extra_headers: + metadata += list(extra_headers.items()) + + if metadata: + request_options["metadata"] = metadata + elif "metadata" in request_options: + del request_options["metadata"] try: if stream: diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 74469e5b8..c2d0ac9fc 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -333,6 +333,38 @@ def test_stream_prompt_feedback_not_blocked(self): text = "".join(chunk.text for chunk in response) self.assertEqual(text, "first second") + def test_generate_content_with_extra_headers(self): + """Test that extra_headers are properly passed to the underlying client in generate_content.""" + model = 
generative_models.GenerativeModel(model_name="gemini-1.5-flash") + extra_headers = {"Helicone-User-Id": "test-user-123", "Custom-Header": "test-value"} + + self.responses["generate_content"].append(simple_response("world!")) + model.generate_content("Hello", extra_headers=extra_headers) + + request_options = self.observed_kwargs[-1] + self.assertIn("metadata", request_options) + + metadata = dict(request_options["metadata"]) + self.assertEqual("test-user-123", metadata["Helicone-User-Id"]) + self.assertEqual("test-value", metadata["Custom-Header"]) + + def test_extra_headers_with_existing_request_options(self): + """Test that extra_headers are correctly merged with existing request_options.""" + model = generative_models.GenerativeModel(model_name="gemini-1.5-flash") + extra_headers = {"Helicone-User-Id": "test-user-456"} + request_options = {"timeout": 30, "metadata": [("Existing-Header", "existing-value")]} + + self.responses["generate_content"].append(simple_response("world!")) + model.generate_content("Hello", extra_headers=extra_headers, request_options=request_options) + + observed_options = self.observed_kwargs[-1] + self.assertIn("metadata", observed_options) + + metadata = dict(observed_options["metadata"]) + self.assertEqual("test-user-456", metadata["Helicone-User-Id"]) + self.assertEqual("existing-value", metadata["Existing-Header"]) + self.assertEqual(30, observed_options["timeout"]) + @parameterized.named_parameters( [ dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"), diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py index b37c65235..c7ad69fb6 100644 --- a/tests/test_generative_models_async.py +++ b/tests/test_generative_models_async.py @@ -91,6 +91,44 @@ async def test_basic(self): self.assertEqual(response.text, "world!") + async def test_generate_content_async_with_extra_headers(self): + self.client.generate_content = unittest.mock.AsyncMock() + request = 
unittest.mock.ANY + extra_headers = { + "Helicone-User-Id": "test-user-123", + "Custom-Header": "test-value" + } + + model = generative_models.GenerativeModel(model_name="gemini-1.5-flash") + await model.generate_content_async("Hello", extra_headers=extra_headers) + + expected_request_options = {"metadata": list(extra_headers.items())} + + self.client.generate_content.assert_called_once_with( + request, **expected_request_options + ) + + async def test_extra_headers_with_existing_request_options(self): + self.client.generate_content = unittest.mock.AsyncMock() + request = unittest.mock.ANY + + extra_headers = {"Helicone-User-Id": "test-user-456"} + request_options = { + "timeout": 30, + "metadata": [("Existing-Header", "existing-value")] + } + + model = generative_models.GenerativeModel(model_name="gemini-1.5-flash") + + await model.generate_content_async("Hello", extra_headers=extra_headers, request_options=request_options) + + expected_metadata = [("Existing-Header", "existing-value")] + list(extra_headers.items()) + expected_request_options = {"timeout": 30, "metadata": expected_metadata} + + self.client.generate_content.assert_called_once_with( + request, **expected_request_options + ) + async def test_streaming(self): # Generate text from text prompt model = generative_models.GenerativeModel(model_name="gemini-1.5-flash")