diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py
index 8d331a9f6..f25d0eb63 100644
--- a/google/generativeai/generative_models.py
+++ b/google/generativeai/generative_models.py
@@ -244,6 +244,7 @@ def generate_content(
         tools: content_types.FunctionLibraryType | None = None,
         tool_config: content_types.ToolConfigType | None = None,
         request_options: helper_types.RequestOptionsType | None = None,
+        extra_headers: dict[str, str] | None = None,
     ) -> generation_types.GenerateContentResponse:
         """A multipurpose function to generate responses from the model.
 
@@ -319,6 +320,14 @@ def generate_content(
         if request_options is None:
             request_options = {}
 
+        # Convert `extra_headers` to metadata format
+        if extra_headers:
+            metadata = [(k, v) for k, v in extra_headers.items()]
+            if "metadata" in request_options:
+                request_options["metadata"].extend(metadata)
+            else:
+                request_options["metadata"] = metadata
+
         try:
             if stream:
                 with generation_types.rewrite_stream_error():
@@ -351,6 +360,7 @@ async def generate_content_async(
         tools: content_types.FunctionLibraryType | None = None,
         tool_config: content_types.ToolConfigType | None = None,
         request_options: helper_types.RequestOptionsType | None = None,
+        extra_headers: dict[str, str] | None = None,
     ) -> generation_types.AsyncGenerateContentResponse:
         """The async version of `GenerativeModel.generate_content`."""
         if not contents:
@@ -373,6 +383,15 @@ async def generate_content_async(
         if request_options is None:
             request_options = {}
 
+        # Convert `extra_headers` to metadata format, merging with any
+        # metadata already present in request_options
+        if extra_headers:
+            metadata = [(k, v) for k, v in extra_headers.items()]
+            if "metadata" in request_options:
+                request_options["metadata"].extend(metadata)
+            else:
+                request_options["metadata"] = metadata
+
         try:
             if stream:
                 with generation_types.rewrite_stream_error():
diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py
index 74469e5b8..b9842e8d9 100644
--- a/tests/test_generative_models.py
+++ b/tests/test_generative_models.py
@@ -48,10 +48,16 @@ def __init__(self, test):
     def generate_content(
         self,
         request: protos.GenerateContentRequest,
+        *,
+        extra_headers: dict[str, str] | None = None,
         **kwargs,
     ) -> protos.GenerateContentResponse:
         self.test.assertIsInstance(request, protos.GenerateContentRequest)
         self.observed_requests.append(request)
+        # Convert extra_headers to metadata format and ensure it's in kwargs
+        if extra_headers:
+            metadata = [(k, v) for k, v in extra_headers.items()]
+            kwargs.setdefault("metadata", []).extend(metadata)
         self.observed_kwargs.append(kwargs)
         response = self.responses["generate_content"].pop(0)
         return response
@@ -59,9 +65,15 @@ def generate_content(
     def stream_generate_content(
         self,
         request: protos.GetModelRequest,
+        *,
+        extra_headers: dict[str, str] | None = None,
         **kwargs,
     ) -> Iterable[protos.GenerateContentResponse]:
         self.observed_requests.append(request)
+        # Convert extra_headers to metadata format and ensure it's in kwargs
+        if extra_headers:
+            metadata = [(k, v) for k, v in extra_headers.items()]
+            kwargs.setdefault("metadata", []).extend(metadata)
         self.observed_kwargs.append(kwargs)
         response = self.responses["stream_generate_content"].pop(0)
         return response
diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py
index b37c65235..5a26d4914 100644
--- a/tests/test_generative_models_async.py
+++ b/tests/test_generative_models_async.py
@@ -53,19 +53,35 @@ def add_client_method(f):
         @add_client_method
         async def generate_content(
            request: protos.GenerateContentRequest,
+            *,
+            extra_headers: dict[str, str] | None = None,
            **kwargs,
         ) -> protos.GenerateContentResponse:
             self.assertIsInstance(request, protos.GenerateContentRequest)
             self.observed_requests.append(request)
+
+            if extra_headers:
+                metadata = [(k, v) for k, v in extra_headers.items()]
+                kwargs.setdefault("metadata", []).extend(metadata)  # Merge with existing metadata if any
+
+            self.observed_kwargs.append(kwargs)
             response = self.responses["generate_content"].pop(0)
             return response
 
         @add_client_method
         async def stream_generate_content(
            request: protos.GetModelRequest,
+            *,
+            extra_headers: dict[str, str] | None = None,
            **kwargs,
         ) -> Iterable[protos.GenerateContentResponse]:
             self.observed_requests.append(request)
+            # Convert extra_headers to metadata format and ensure it's in kwargs
+            if extra_headers:
+                metadata = [(k, v) for k, v in extra_headers.items()]
+                kwargs.setdefault("metadata", []).extend(metadata)
+
+            self.observed_kwargs.append(kwargs)
             response = self.responses["stream_generate_content"].pop(0)
             return response
 
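
A minimal usage sketch of the new `extra_headers` parameter, assuming the patch above is applied. The API key, model name, and header names/values are illustrative placeholders only, not part of this change, and whether the backend acts on a given header is out of scope here. The second call shows the merge behavior: converted headers are appended to metadata the caller already supplied via `request_options`, rather than replacing it.

    import google.generativeai as genai

    genai.configure(api_key="YOUR_API_KEY")  # placeholder key
    model = genai.GenerativeModel("gemini-1.5-flash")  # illustrative model name

    # Each header key/value pair becomes a (key, value) tuple and is forwarded to
    # the underlying client as request metadata via request_options["metadata"].
    response = model.generate_content(
        "Hello",
        extra_headers={"x-custom-tracking-id": "experiment-42"},  # hypothetical header
    )

    # Headers are merged with metadata already passed in request_options,
    # not substituted for it.
    response = model.generate_content(
        "Hello",
        request_options={"metadata": [("existing-key", "existing-value")]},
        extra_headers={"x-custom-tracking-id": "experiment-42"},
    )
    print(response.text)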