
Google Gemini doesn't work with QueryTool #1572

Open
shhlife opened this issue Jan 13, 2025 · 6 comments
Labels
status:blocked Can't proceed due to external factors type:bug Something isn't working


@shhlife

shhlife commented Jan 13, 2025

When using GoogleDriversConfig with off_prompt=True, I'm getting an error: Unknown field for Schema: anyOf

from griptape.configs import Defaults
from griptape.configs.drivers import GoogleDriversConfig
from griptape.structures import Agent
from griptape.tools import QueryTool, WebScraperTool

Defaults.drivers_config = GoogleDriversConfig()

agent = Agent(tools=[WebScraperTool(off_prompt=True), QueryTool()])

agent.run(
    "How does off-prompt work? https://docs.griptape.ai/stable/griptape-framework/structures/task-memory/ "
)

Here's the error:

[01/14/25 11:36:25] INFO     PromptTask 6d49e7b57c0e4fb88b1b379e2e63e8ae
                             Input: How does off-prompt work? https://docs.griptape.ai/stable/griptape-framework/structures/task-memory/  
                    ERROR    PromptTask 6d49e7b57c0e4fb88b1b379e2e63e8ae
                             Unknown field for Schema: anyOf
Traceback (most recent call last):
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\rules\message.py", line 36, in to_proto
    return self._descriptor(**value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Parameter to CopyFrom() must be instance of same class: expected <class 'Schema'> got <class 'dict'>.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\rules\message.py", line 36, in to_proto
    return self._descriptor(**value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Protocol message Schema has no "anyOf" field.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\tasks\base_task.py", line 163, in run
    self.output = self.try_run()
                  ^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\tasks\prompt_task.py", line 205, in try_run
    result = self.prompt_driver.run(self.prompt_stack)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\common\decorators.py", line 18, in decorator
    Observability.observe(
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\observability\observability.py", line 36, in observe
    return driver.observe(call)
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\observability\no_op_observability_driver.py", line 16, in observe
    return call()
           ^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\common\observable.py", line 19, in __call__
    return self.func(*self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\base_prompt_driver.py", line 81, in run
    for attempt in self.retrying():
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\tenacity\__init__.py", line 443, in __iter__
    do = self.iter(retry_state=retry_state)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\tenacity\__init__.py", line 376, in iter
    result = action(retry_state)
             ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\tenacity\__init__.py", line 398, in <lambda>
    self._add_action_func(lambda rs: rs.outcome.result())
                                     ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\AppData\Local\Programs\Python\Python311\Lib\concurrent\futures\_base.py", line 449, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\AppData\Local\Programs\Python\Python311\Lib\concurrent\futures\_base.py", line 401, in __get_result
    raise self._exception
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\base_prompt_driver.py", line 85, in run
    result = self.__process_stream(prompt_stack) if self.stream else self.__process_run(prompt_stack)
                                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\base_prompt_driver.py", line 126, in __process_run
    return self.try_run(prompt_stack)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\common\decorators.py", line 18, in decorator
    Observability.observe(
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\observability\observability.py", line 36, in observe
    return driver.observe(call)
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\observability\no_op_observability_driver.py", line 16, in observe
    return call()
           ^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\common\observable.py", line 19, in __call__
    return self.func(*self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\google_prompt_driver.py", line 79, in try_run
    params = self._base_params(prompt_stack)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\google_prompt_driver.py", line 153, in _base_params
    "tools": self.__to_google_tools(prompt_stack.tools),
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\google_prompt_driver.py", line 193, in __to_google_tools
    tool_declaration = types.FunctionDeclaration(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\google\generativeai\types\content_types.py", line 558, in __init__
    self._proto = protos.FunctionDeclaration(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\message.py", line 728, in __init__
    pb_value = marshal.to_proto(pb_type, value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\marshal.py", line 235, in to_proto
    pb_value = self.get_rule(proto_type=proto_type).to_proto(value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\rules\message.py", line 45, in to_proto
    return self._wrapper(value)._pb
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\message.py", line 728, in __init__
    pb_value = marshal.to_proto(pb_type, value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\marshal.py", line 233, in to_proto
    return {k: self.to_proto(recursive_type, v) for k, v in value.items()}
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\marshal.py", line 233, in <dictcomp>
    return {k: self.to_proto(recursive_type, v) for k, v in value.items()}
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\marshal.py", line 235, in to_proto
    pb_value = self.get_rule(proto_type=proto_type).to_proto(value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\rules\message.py", line 45, in to_proto
    return self._wrapper(value)._pb
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\message.py", line 724, in __init__
    raise ValueError(
ValueError: Unknown field for Schema: anyOf
@collindutter
Member

Can be simplified to:

from griptape.drivers import GooglePromptDriver
from griptape.structures import Agent
from griptape.tools import QueryTool

agent = Agent(
    prompt_driver=GooglePromptDriver(model="gemini-1.5-pro"), tools=[QueryTool()]
)

agent.run()

It's not an issue with off_prompt; it's an issue with QueryTool.

@collindutter collindutter changed the title Google Gemini doesn't work with tools where off_prompt=True Google Gemini doesn't work with QueryTool Jan 13, 2025
@collindutter
Member

It looks like Google Gemini does not support anyOf in its JSON schemas. This is far from ideal, but in the meantime you can "fix" the tool by removing the use of schema.Or.

from __future__ import annotations

from attrs import define
from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.drivers import GooglePromptDriver
from griptape.structures import Agent
from griptape.tools import QueryTool
from griptape.utils.decorators import activity
from schema import Literal, Schema


@define(kw_only=True)
class GeminiQueryTool(QueryTool):
    @activity(
        config={
            "description": "Can be used to search through textual content.",
            "schema": Schema(
                {
                    Literal(
                        "query", description="A natural language search query"
                    ): str,
                    Literal("content"): Schema(
                        {
                            "memory_name": str,
                            "artifact_namespace": str,
                        }
                    ),
                }
            ),
        },
    )
    def query(self, params: dict) -> ListArtifact | ErrorArtifact:
        return super().query(params)


agent = Agent(
    prompt_driver=GooglePromptDriver(model="gemini-1.5-flash"),
    tools=[GeminiQueryTool()],
)

agent.run()
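The patch above hand-writes the activity schema so that content is always an object, dropping the plain-string alternative that schema.Or allowed. The same idea can be sketched as a post-processing step on an already-rendered JSON schema dict: every anyOf is collapsed to a single branch, preferring an object branch. This is a minimal stdlib-only sketch under stated assumptions: collapse_any_of is a hypothetical helper (not part of griptape or the Google SDK), and the sample schema is an assumed approximation of what a schema.Or(str, Schema({...})) parameter renders to, not verbatim library output.

```python
from typing import Any


def collapse_any_of(node: Any) -> Any:
    """Return a copy of a JSON schema with each "anyOf" replaced by one branch.

    Hypothetical workaround: prefers the first branch whose "type" is "object",
    otherwise keeps the first branch. Assumes branches are dict schemas.
    Lossy by design, like the hand-written patch.
    """
    if isinstance(node, list):
        return [collapse_any_of(item) for item in node]
    if not isinstance(node, dict):
        return node

    branches = node.get("anyOf")
    if isinstance(branches, list) and branches:
        chosen = next(
            (b for b in branches if isinstance(b, dict) and b.get("type") == "object"),
            branches[0],
        )
        # Keep sibling keys (e.g. "description") and overlay the chosen branch.
        merged = {key: value for key, value in node.items() if key != "anyOf"}
        merged.update(chosen)
        # Recurse in case the chosen branch contains nested anyOf keywords.
        return collapse_any_of(merged)

    return {key: collapse_any_of(value) for key, value in node.items()}


# Assumed, illustrative rendering of a parameter declared as Or(str, {...}).
query_params = {
    "type": "object",
    "properties": {
        "query": {"type": "string", "description": "A natural language search query"},
        "content": {
            "anyOf": [
                {"type": "string"},
                {
                    "type": "object",
                    "properties": {
                        "memory_name": {"type": "string"},
                        "artifact_namespace": {"type": "string"},
                    },
                    "required": ["memory_name", "artifact_namespace"],
                },
            ]
        },
    },
    "required": ["query", "content"],
}

gemini_params = collapse_any_of(query_params)
print(gemini_params["properties"]["content"]["type"])  # → object
```

Like the manual patch, this silently removes the string form, so the model can only pass the memory_name/artifact_namespace object; where such a step would be hooked into a prompt driver is left open here.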

@shhlife
Author

shhlife commented Jan 13, 2025

yeeks. :)

I can fix that if I'm using it on my own, but I'd rather not try to fix it in ComfyUI, where our customer is hitting it. Is there another fix we can use in the framework, or is this a biggie?

@collindutter collindutter added the status:blocked Can't proceed due to external factors label Jan 13, 2025
@collindutter
Member

The issue boils down to this tool using schema.Or, which turns into anyOf when rendered as a JSON schema. Others are running into it here. All the solutions I can think of would be breaking changes to the framework's QueryTool. Can you just include this patched version of QueryTool in Comfy?
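Because schema.Or renders as anyOf, one way to spot which activity schemas Gemini will reject, before any request is made, is to walk the rendered JSON schema and report every anyOf location. A minimal stdlib-only sketch: find_any_of is a hypothetical helper, obtaining the rendered schema from a griptape tool is left out, and the sample dict is an assumed, simplified rendering of a schema.Or(str, ...) parameter.

```python
from typing import Any, Iterator


def find_any_of(node: Any, path: str = "$") -> Iterator[str]:
    """Yield a JSONPath-like location for every "anyOf" key in a schema.

    Note: a property that is literally named "anyOf" would also be reported.
    """
    if isinstance(node, dict):
        for key, value in node.items():
            child = f"{path}.{key}"
            if key == "anyOf":
                yield child
            yield from find_any_of(value, child)
    elif isinstance(node, list):
        for index, item in enumerate(node):
            yield from find_any_of(item, f"{path}[{index}]")


# Assumed, simplified rendering of a parameter declared with schema.Or(str, {...}).
params_schema = {
    "type": "object",
    "properties": {
        "query": {"type": "string"},
        "content": {"anyOf": [{"type": "string"}, {"type": "object"}]},
    },
}

print(list(find_any_of(params_schema)))  # → ['$.properties.content.anyOf']
```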

@shhlife
Author

shhlife commented Jan 14, 2025

I'm giving it a try, but I'm getting this error:

Traceback (most recent call last):
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\tools\base_tool.py", line 136, in run
    output = self.try_run(activity, subtask, action, output)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\common\decorators.py", line 18, in decorator
    Observability.observe(
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\observability\observability.py", line 36, in observe
    return driver.observe(call)
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\drivers\observability\no_op_observability_driver.py", line 16, in observe
    return call()
           ^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\common\observable.py", line 19, in __call__
    return self.func(*self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\tools\base_tool.py", line 158, in try_run
    activity_result = activity(deepcopy(value))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\utils\decorators.py", line 31, in wrapper
    return func(self, **_build_kwargs(func, params))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\custom_nodes\ComfyUI-Griptape\nodes\patches\gemini_query_tool.py", line 31, in query
    return super().query(params)
           ^^^^^^^^^^^^^
TypeError: super(type, obj): obj must be an instance or subtype of type

Will keep poking around unless you know a quick fix :)

@shhlife
Author

shhlife commented Jan 14, 2025

Resolved it by copying the query code from QueryTool instead of relying on:

    def query(self, params: dict) -> ListArtifact | ErrorArtifact:
        return super().query(params)
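The thread doesn't pin down why super().query(params) failed. One plausible explanation (an assumption, not confirmed here): attrs' @define builds slotted classes by default, and building a slotted class means returning a new class object. attrs repoints the hidden __class__ cell that zero-argument super() relies on for methods it can see directly, but a method wrapped by a decorator such as @activity keeps that cell inside the inner function, where it still refers to the original, discarded class. The stdlib-only simulation below uses two hypothetical stand-ins, rebuild_class for attrs and activity_like for @activity, to reproduce the same kind of TypeError and to show that naming the class explicitly sidesteps it.

```python
import functools
from typing import Any, Callable


def activity_like(func: Callable) -> Callable:
    """Stand-in for a wrapping decorator such as @activity (uses functools.wraps)."""

    @functools.wraps(func)
    def wrapper(self: Any, params: dict) -> Any:
        return func(self, params)

    return wrapper


def rebuild_class(cls: type) -> type:
    """Stand-in for attrs' slotted-class step: returns a NEW class object.

    Like attrs, it repoints __class__ cells found directly on plain functions,
    but it cannot see the cell hidden inside a decorator's inner function.
    """
    body = {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")}
    new_cls = type(cls.__name__, cls.__bases__, body)
    for item in body.values():
        for cell in getattr(item, "__closure__", None) or ():
            try:
                if cell.cell_contents is cls:
                    cell.cell_contents = new_cls
            except ValueError:  # empty cell, nothing to repoint
                continue
    return new_cls


class Base:
    def query(self, params: dict) -> str:
        return f"base:{params['q']}"


@rebuild_class
class Plain(Base):
    def query(self, params: dict) -> str:
        return super().query(params)  # cell was repointed, so this works


@rebuild_class
class Broken(Base):
    @activity_like
    def query(self, params: dict) -> str:
        return super().query(params)  # cell still points at the discarded class


@rebuild_class
class Fixed(Base):
    @activity_like
    def query(self, params: dict) -> str:
        return super(Fixed, self).query(params)  # name resolved at call time


def call_broken() -> str:
    try:
        Broken().query({"q": "c"})
    except TypeError as exc:
        return str(exc)
    return "no error"


print(Plain().query({"q": "a"}))  # → base:a
print(Fixed().query({"q": "b"}))  # → base:b
print(call_broken())  # a TypeError message beginning "super(type, obj): obj"
```

If this diagnosis holds, writing super(GeminiQueryTool, self).query(params), or declaring the class with @define(kw_only=True, slots=False), would probably avoid copying QueryTool's body, but neither variant has been tested against griptape in this thread.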
