Add support for per-request metadata/headers in GenerativeModel.gener… #707

Closed
wants to merge 1 commit into from
32 changes: 32 additions & 0 deletions google/generativeai/generative_models.py
@@ -244,6 +244,7 @@ def generate_content(
tools: content_types.FunctionLibraryType | None = None,
tool_config: content_types.ToolConfigType | None = None,
request_options: helper_types.RequestOptionsType | None = None,
extra_headers: dict[str, str] | None = None,
) -> generation_types.GenerateContentResponse:
"""A multipurpose function to generate responses from the model.

@@ -318,7 +319,22 @@

if request_options is None:
request_options = {}
else:
request_options = (
request_options.copy()
if isinstance(request_options, dict)
else vars(request_options).copy()
)

metadata = list(request_options.get("metadata", []))
if extra_headers:
metadata += list(extra_headers.items())

if metadata:
request_options["metadata"] = metadata
elif "metadata" in request_options:
del request_options["metadata"]

try:
if stream:
with generation_types.rewrite_stream_error():
@@ -351,6 +367,7 @@ async def generate_content_async(
tools: content_types.FunctionLibraryType | None = None,
tool_config: content_types.ToolConfigType | None = None,
request_options: helper_types.RequestOptionsType | None = None,
extra_headers: dict[str, str] | None = None,
) -> generation_types.AsyncGenerateContentResponse:
"""The async version of `GenerativeModel.generate_content`."""
if not contents:
@@ -372,6 +389,21 @@

if request_options is None:
request_options = {}
else:
request_options = (
request_options.copy()
if isinstance(request_options, dict)
else vars(request_options).copy()
)

metadata = list(request_options.get("metadata", []))
if extra_headers:
metadata += list(extra_headers.items())

if metadata:
request_options["metadata"] = metadata
elif "metadata" in request_options:
del request_options["metadata"]

try:
if stream:
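For reference, the metadata merge above can be expressed as a standalone sketch. The function name is illustrative and not part of the library, and this dict-only version omits the vars()-based copy the diff performs for non-dict RequestOptions objects:

def merge_extra_headers(request_options, extra_headers):
    # Work on a copy so the caller's request_options is never mutated.
    options = dict(request_options or {})
    metadata = list(options.get("metadata", []))
    if extra_headers:
        # Each header becomes a (key, value) pair appended after any
        # caller-supplied metadata entries.
        metadata += list(extra_headers.items())
    if metadata:
        options["metadata"] = metadata
    else:
        options.pop("metadata", None)
    return options

assert merge_extra_headers(
    {"timeout": 30, "metadata": [("Existing-Header", "existing-value")]},
    {"Helicone-User-Id": "test-user-456"},
) == {
    "timeout": 30,
    "metadata": [
        ("Existing-Header", "existing-value"),
        ("Helicone-User-Id", "test-user-456"),
    ],
}

Because existing metadata entries come first and extra_headers are appended, callers who set both see their request_options metadata preserved, with the per-request headers added after it.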
32 changes: 32 additions & 0 deletions tests/test_generative_models.py
@@ -333,6 +333,38 @@ def test_stream_prompt_feedback_not_blocked(self):
text = "".join(chunk.text for chunk in response)
self.assertEqual(text, "first second")

def test_generate_content_with_extra_headers(self):
"""Test that extra_headers are properly passed to the underlying client in generate_content."""
model = generative_models.GenerativeModel(model_name="gemini-1.5-flash")
extra_headers = {"Helicone-User-Id": "test-user-123", "Custom-Header": "test-value"}

self.responses["generate_content"].append(simple_response("world!"))
model.generate_content("Hello", extra_headers=extra_headers)

request_options = self.observed_kwargs[-1]
self.assertIn("metadata", request_options)

metadata = dict(request_options["metadata"])
self.assertEqual("test-user-123", metadata["Helicone-User-Id"])
self.assertEqual("test-value", metadata["Custom-Header"])

def test_extra_headers_with_existing_request_options(self):
"""Test that extra_headers are correctly merged with existing request_options."""
model = generative_models.GenerativeModel(model_name="gemini-1.5-flash")
extra_headers = {"Helicone-User-Id": "test-user-456"}
request_options = {"timeout": 30, "metadata": [("Existing-Header", "existing-value")]}

self.responses["generate_content"].append(simple_response("world!"))
model.generate_content("Hello", extra_headers=extra_headers, request_options=request_options)

observed_options = self.observed_kwargs[-1]
self.assertIn("metadata", observed_options)

metadata = dict(observed_options["metadata"])
self.assertEqual("test-user-456", metadata["Helicone-User-Id"])
self.assertEqual("existing-value", metadata["Existing-Header"])
self.assertEqual(30, observed_options["timeout"])

@parameterized.named_parameters(
[
dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"),
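A minimal usage sketch of the behavior these tests pin down, assuming this PR is applied; API-key configuration is elided and the header names are only examples:

import google.generativeai as genai

model = genai.GenerativeModel("gemini-1.5-flash")

# extra_headers entries ride along as per-request metadata, merged after
# any metadata already present in request_options.
response = model.generate_content(
    "Hello",
    extra_headers={"Helicone-User-Id": "user-123"},
    request_options={"timeout": 30},
)
print(response.text)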
38 changes: 38 additions & 0 deletions tests/test_generative_models_async.py
@@ -91,6 +91,44 @@ async def test_basic(self):

self.assertEqual(response.text, "world!")

async def test_generate_content_async_with_extra_headers(self):
"""Test that extra_headers are passed to the underlying client in generate_content_async."""
self.client.generate_content = unittest.mock.AsyncMock()
request = unittest.mock.ANY
extra_headers = {
"Helicone-User-Id": "test-user-123",
"Custom-Header": "test-value"
}

model = generative_models.GenerativeModel(model_name="gemini-1.5-flash")
await model.generate_content_async("Hello", extra_headers=extra_headers)

expected_request_options = {"metadata": list(extra_headers.items())}

self.client.generate_content.assert_called_once_with(
request, **expected_request_options
)

async def test_extra_headers_with_existing_request_options(self):
"""Test that extra_headers are merged with existing request_options in generate_content_async."""
self.client.generate_content = unittest.mock.AsyncMock()
request = unittest.mock.ANY

extra_headers = {"Helicone-User-Id": "test-user-456"}
request_options = {
"timeout": 30,
"metadata": [("Existing-Header", "existing-value")]
}

model = generative_models.GenerativeModel(model_name="gemini-1.5-flash")

await model.generate_content_async("Hello", extra_headers=extra_headers, request_options=request_options)

expected_metadata = [("Existing-Header", "existing-value")] + list(extra_headers.items())
expected_request_options = {"timeout": 30, "metadata": expected_metadata}

self.client.generate_content.assert_called_once_with(
request, **expected_request_options
)

async def test_streaming(self):
# Generate text from text prompt
model = generative_models.GenerativeModel(model_name="gemini-1.5-flash")
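The async counterpart of the earlier usage sketch, again assuming this PR is applied (configuration elided, header names are examples):

import asyncio

import google.generativeai as genai

async def main():
    model = genai.GenerativeModel("gemini-1.5-flash")
    # The async path performs the same metadata merge before invoking
    # the underlying client.
    response = await model.generate_content_async(
        "Hello",
        extra_headers={"Helicone-User-Id": "user-123"},
    )
    print(response.text)

asyncio.run(main())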