diff --git a/lib/sycamore/poetry.lock b/lib/sycamore/poetry.lock
index 090426ac7..1ab201e2c 100644
--- a/lib/sycamore/poetry.lock
+++ b/lib/sycamore/poetry.lock
@@ -9954,4 +9954,4 @@ weaviate = ["weaviate-client"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "fda09c7256e9192143850b2748859c429b88eb8d2db0a957d489a7abe954409d"
+content-hash = "b3d1c82fdc7f6d50b745d3acdc7474b375a0d89ecf9e3ec02286910d15ee0b91"
diff --git a/lib/sycamore/sycamore/llms/gemini.py b/lib/sycamore/sycamore/llms/gemini.py
index 73dca060a..76a2bde5a 100644
--- a/lib/sycamore/sycamore/llms/gemini.py
+++ b/lib/sycamore/sycamore/llms/gemini.py
@@ -44,7 +44,7 @@ class Gemini(LLM):
         cache: A cache object to use for caching results.
     """
 
-    @requires_modules("google-genai")
+    @requires_modules("google-genai", extra="google-genai")
     def __init__(
         self,
         model_name: Union[GeminiModels, str],
@@ -82,6 +82,9 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]
             **(llm_kwargs or {}),
         }
         config["max_output_tokens"] = config.get("max_output_tokens", DEFAULT_MAX_TOKENS)
+        if prompt.response_format:
+            config["response_mime_type"] = "application/json"
+            config["response_schema"] = prompt.response_format
         content_list = []
         for message in prompt.messages:
             if message.role == "system":
diff --git a/lib/sycamore/sycamore/transforms/summarize_images.py b/lib/sycamore/sycamore/transforms/summarize_images.py
index 7ffca78ab..dfe237211 100644
--- a/lib/sycamore/sycamore/transforms/summarize_images.py
+++ b/lib/sycamore/sycamore/transforms/summarize_images.py
@@ -93,7 +93,7 @@ def summarize_image(self, image: Image.Image, context: Optional[str]) -> str:
             The summarized image as a string.
         """
         messages = []
-        if context is not None and self.include_context:
+        if context is not None:
             messages = [RenderedMessage(role="system", content=context)]
         messages.append(RenderedMessage(role="user", content=self.prompt, images=[image]))
 
@@ -141,11 +141,12 @@ class GeminiImageSummarizer(LLMImageSummarizer):
 
     def __init__(
         self,
-        gemini_model: Gemini,
+        gemini_model: Optional[Gemini] = None,
         prompt: Optional[str] = None,
         include_context: bool = True,
     ):
-
+        if gemini_model is None:
+            gemini_model = Gemini(model_name=self.model)
         super().__init__(llm=gemini_model, prompt=prompt, include_context=include_context)
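
Usage sketch (reviewer note, not part of the patch): the new branch in Gemini.get_generate_kwargs forwards prompt.response_format into Gemini's JSON mode by setting response_mime_type to "application/json" and response_schema to the given format. Below is a minimal sketch of how a caller might exercise it. The import paths, the GeminiModels member name, and the RenderedPrompt/RenderedMessage constructor fields are assumptions inferred from identifiers visible in this diff, not verified against this revision; running it also needs the google-genai extra referenced by the updated requires_modules decorator.

# Sketch only; module paths and enum member below are assumptions, not verified.
from pydantic import BaseModel

from sycamore.llms.gemini import Gemini, GeminiModels                # assumed import path
from sycamore.llms.prompts import RenderedMessage, RenderedPrompt    # assumed import path


class InvoiceSummary(BaseModel):
    # Hypothetical schema for illustration; it reaches Gemini via the new
    # config["response_schema"] assignment in get_generate_kwargs.
    vendor: str
    total: float


# Assumed enum member; any GeminiModels value or model-name string should work.
llm = Gemini(model_name=GeminiModels.GEMINI_2_FLASH)

prompt = RenderedPrompt(
    messages=[RenderedMessage(role="user", content="Summarize this invoice as JSON.")],
    response_format=InvoiceSummary,  # non-None, so the new branch enables JSON mode
)

# The returned kwargs should now carry response_mime_type="application/json"
# and response_schema=InvoiceSummary alongside max_output_tokens.
kwargs = llm.get_generate_kwargs(prompt, llm_kwargs=None)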