diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index 681affce..b292d2f7 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -231,30 +231,36 @@ def function_schema( takes_context = False filtered_params = [] - if params: - first_name, first_param = params[0] - # Prefer the evaluated type hint if available - ann = type_hints.get(first_name, first_param.annotation) + # Helper function to check if a parameter is a special method parameter + def is_special_param(name: str) -> bool: + return name in ("self", "cls") + + # Helper function to check if a parameter is a context parameter + def is_context_param(name: str, param: inspect.Parameter) -> bool: + ann = type_hints.get(name, param.annotation) if ann != inspect._empty: origin = get_origin(ann) or ann - if origin is RunContextWrapper: - takes_context = True # Mark that the function takes context - else: - filtered_params.append((first_name, first_param)) - else: + return origin is RunContextWrapper + return False + + if params: + first_name, first_param = params[0] + + # Handle special first parameter cases + if is_context_param(first_name, first_param): + takes_context = True + elif not is_special_param(first_name): filtered_params.append((first_name, first_param)) - # For parameters other than the first, raise error if any use RunContextWrapper. + # For parameters other than the first, handle special cases and context for name, param in params[1:]: - ann = type_hints.get(name, param.annotation) - if ann != inspect._empty: - origin = get_origin(ann) or ann - if origin is RunContextWrapper: - raise UserError( - f"RunContextWrapper param found at non-first position in function" - f" {func.__name__}" - ) - filtered_params.append((name, param)) + if is_context_param(name, param): + raise UserError( + f"RunContextWrapper param found at non-first position in function" + f" {func.__name__}" + ) + if not is_special_param(name): + filtered_params.append((name, param)) # We will collect field definitions for create_model as a dict: # field_name -> (type_annotation, default_value_or_Field(...)) diff --git a/tests/test_function_schema.py b/tests/test_function_schema.py index ef1e9c22..52414f93 100644 --- a/tests/test_function_schema.py +++ b/tests/test_function_schema.py @@ -431,6 +431,90 @@ def func(**kwargs: dict[str, int]): assert properties.get("kwargs").get("additionalProperties").get("type") == "integer" +def test_context_with_special_params(): + """Test that context parameter works correctly with special parameters (self/cls).""" + class TestClass: + def instance_method_with_context(self, ctx: RunContextWrapper[str], a: int) -> str: + return f"instance {a}" + + @classmethod + def class_method_with_context(cls, ctx: RunContextWrapper[str], a: int) -> str: + return f"class {a}" + + # Test instance method + instance = TestClass() + func_schema = function_schema(instance.instance_method_with_context) + assert func_schema.takes_context + assert func_schema.params_json_schema.get("title") == "instance_method_with_context_args" + + # Verify only 'a' is in the schema, not 'self' or 'ctx' + properties = func_schema.params_json_schema.get("properties", {}) + assert "a" in properties + assert "self" not in properties + assert "ctx" not in properties + + # Test class method + func_schema = function_schema(TestClass.class_method_with_context) + assert func_schema.takes_context + assert func_schema.params_json_schema.get("title") == "class_method_with_context_args" + + # Verify only 'a' is in the schema, not 'cls' or 'ctx' + properties = func_schema.params_json_schema.get("properties", {}) + assert "a" in properties + assert "cls" not in properties + assert "ctx" not in properties + + # Test actual function calls + context = RunContextWrapper(context="test") + + # Test instance method call + parsed = func_schema.params_pydantic_model(**{"a": 42}) + args, kwargs_dict = func_schema.to_call_args(parsed) + result = instance.instance_method_with_context(context, *args, **kwargs_dict) + assert result == "instance 42" + + # Test class method call + parsed = func_schema.params_pydantic_model(**{"a": 42}) + args, kwargs_dict = func_schema.to_call_args(parsed) + result = TestClass.class_method_with_context(context, *args, **kwargs_dict) + assert result == "class 42" + + +def test_context_with_other_params(): + """Test that context parameter works correctly with other parameters.""" + def func_with_context_and_params( + ctx: RunContextWrapper[str], + a: int, + b: str = "default", + ) -> str: + return f"{a} {b}" + + func_schema = function_schema(func_with_context_and_params) + assert func_schema.takes_context + assert func_schema.params_json_schema.get("title") == "func_with_context_and_params_args" + + # Verify schema only contains 'a' and 'b', not 'ctx' + properties = func_schema.params_json_schema.get("properties", {}) + assert "a" in properties + assert "b" in properties + assert "ctx" not in properties + + # Test function call + context = RunContextWrapper(context="test") + + # Test with default value + parsed = func_schema.params_pydantic_model(**{"a": 42}) + args, kwargs_dict = func_schema.to_call_args(parsed) + result = func_with_context_and_params(context, *args, **kwargs_dict) + assert result == "42 default" + + # Test with explicit value + parsed = func_schema.params_pydantic_model(**{"a": 42, "b": "explicit"}) + args, kwargs_dict = func_schema.to_call_args(parsed) + result = func_with_context_and_params(context, *args, **kwargs_dict) + assert result == "42 explicit" + + def test_schema_with_mapping_raises_strict_mode_error(): """A mapping type is not allowed in strict mode. Same for dicts. Ensure we raise a UserError."""