Skip to content

Commit 2801c74

Browse files
Mattias WighMattias Wigh
Mattias Wigh
authored and
Mattias Wigh
committed
Needs to set required fields as required Handle cases dataclasses, TypedDict and Pydantic models
1 parent f30c36d commit 2801c74

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

google/generativeai/types/content_types.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import typing
2424
from typing import Any, Callable, Union, get_type_hints, get_origin, get_args
2525
from typing_extensions import TypedDict, is_typeddict
26+
import dataclasses
2627

2728
import pydantic
2829

@@ -334,9 +335,11 @@ def to_contents(contents: ContentsType) -> list[protos.Content]:
334335
return contents
335336

336337

337-
def _schema_for_class(cls: TypedDict) -> dict[str, Any]:
338+
def _schema_for_class(cls: type) -> dict[str, Any]:
338339
schema = _build_schema("dummy", {"dummy": (cls, pydantic.Field())})
339340
properties = schema["properties"]["dummy"]
341+
342+
# Handling TypedDict
340343
if is_typeddict(cls):
341344
required_keys = []
342345
type_hints = get_type_hints(cls)
@@ -347,6 +350,26 @@ def _schema_for_class(cls: TypedDict) -> dict[str, Any]:
347350
continue
348351
required_keys.append(key)
349352
properties["required"] = required_keys
353+
354+
# Handling dataclasses
355+
elif dataclasses.is_dataclass(cls):
356+
required_keys = []
357+
for field in dataclasses.fields(cls):
358+
if field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING:
359+
required_keys.append(field.name) # Field is required if it has no default value
360+
properties["required"] = required_keys
361+
362+
# Handling Pydantic models
363+
elif issubclass(cls, pydantic.BaseModel):
364+
required_keys = [name for name, field in cls.__fields__.items() if field.is_required()]
365+
properties["required"] = required_keys
366+
367+
# Bug that it sets default values in case default exists
368+
# TODO: Should be handled in the schema generation or not be allowed
369+
370+
for key in properties["properties"]:
371+
if 'default' in properties["properties"][key]:
372+
properties["properties"][key].pop('default')
350373
return properties
351374

352375

0 commit comments

Comments
 (0)