23
23
import typing
24
24
from typing import Any , Callable , Union , get_type_hints , get_origin , get_args
25
25
from typing_extensions import TypedDict , is_typeddict
26
+ import dataclasses
26
27
27
28
import pydantic
28
29
@@ -334,9 +335,11 @@ def to_contents(contents: ContentsType) -> list[protos.Content]:
334
335
return contents
335
336
336
337
337
- def _schema_for_class (cls : TypedDict ) -> dict [str , Any ]:
338
+ def _schema_for_class (cls : type ) -> dict [str , Any ]:
338
339
schema = _build_schema ("dummy" , {"dummy" : (cls , pydantic .Field ())})
339
340
properties = schema ["properties" ]["dummy" ]
341
+
342
+ # Handling TypedDict
340
343
if is_typeddict (cls ):
341
344
required_keys = []
342
345
type_hints = get_type_hints (cls )
@@ -347,6 +350,26 @@ def _schema_for_class(cls: TypedDict) -> dict[str, Any]:
347
350
continue
348
351
required_keys .append (key )
349
352
properties ["required" ] = required_keys
353
+
354
+ # Handling dataclasses
355
+ elif dataclasses .is_dataclass (cls ):
356
+ required_keys = []
357
+ for field in dataclasses .fields (cls ):
358
+ if field .default is dataclasses .MISSING and field .default_factory is dataclasses .MISSING :
359
+ required_keys .append (field .name ) # Field is required if it has no default value
360
+ properties ["required" ] = required_keys
361
+
362
+ # Handling Pydantic models
363
+ elif issubclass (cls , pydantic .BaseModel ):
364
+ required_keys = [name for name , field in cls .__fields__ .items () if field .is_required ()]
365
+ properties ["required" ] = required_keys
366
+
367
+ # Bug that it sets default values in case default exists
368
+ # TODO: Should be handled in the schema generation or not be allowed
369
+
370
+ for key in properties ["properties" ]:
371
+ if 'default' in properties ["properties" ][key ]:
372
+ properties ["properties" ][key ].pop ('default' )
350
373
return properties
351
374
352
375
0 commit comments