Skip to content

Commit 425239d

Browse files
authored
fix: patch annotations inside a container type (#5167)
* fix: patch annotations inside a container Signed-off-by: Frost Ming <[email protected]>
1 parent 3fe9653 commit 425239d

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

src/_bentoml_sdk/_pydantic.py

+36
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing as t
44

55
from pydantic._internal import _known_annotated_metadata
6+
from pydantic._internal._typing_extra import is_annotated
67

78
from .typing_utils import get_args
89
from .typing_utils import get_origin
@@ -194,3 +195,38 @@ def pathlib_prepare_pydantic_annotations(
194195
# PIL image
195196
pil_prepare_pydantic_annotations,
196197
]
198+
199+
SUPPORTED_CONTAINER_TYPES = [
200+
t.Union,
201+
list,
202+
t.List,
203+
dict,
204+
t.Dict,
205+
t.AsyncGenerator,
206+
t.AsyncIterable,
207+
t.AsyncIterator,
208+
t.Generator,
209+
t.Iterable,
210+
t.Iterator,
211+
]
212+
213+
214+
def patch_annotation(annotation: t.Any, model_config: ConfigDict) -> t.Any:
215+
import typing_extensions as te
216+
217+
origin, args = te.get_origin(annotation), te.get_args(annotation)
218+
if origin in SUPPORTED_CONTAINER_TYPES:
219+
patched_args = [patch_annotation(arg, model_config) for arg in args]
220+
return origin[tuple(patched_args)]
221+
222+
if is_annotated(annotation):
223+
source, *annotations = args
224+
else:
225+
source = annotation
226+
annotations = []
227+
for method in CUSTOM_PREPARE_METHODS:
228+
result = method(source, annotations, model_config)
229+
if result is None:
230+
continue
231+
return t.Annotated[(result[0], *result[1])] # type: ignore
232+
return annotation

src/_bentoml_sdk/io_models.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -139,20 +139,10 @@ def mime_type(cls) -> str:
139139

140140
@classmethod
141141
def __get_pydantic_core_schema__(cls: type[BaseModel], source, handler):
142-
from ._pydantic import CUSTOM_PREPARE_METHODS
142+
from ._pydantic import patch_annotation
143143

144144
for _, info in cls.model_fields.items():
145-
if is_annotated(info.annotation):
146-
origin, *args = get_args(info.annotation)
147-
else:
148-
origin = info.annotation
149-
args = []
150-
for method in CUSTOM_PREPARE_METHODS:
151-
result = method(origin, args, cls.model_config)
152-
if result is None:
153-
continue
154-
info.annotation = t.Annotated[(result[0], *result[1])] # type: ignore
155-
break
145+
info.annotation = patch_annotation(info.annotation, cls.model_config)
156146

157147
return super().__get_pydantic_core_schema__(source, handler)
158148

0 commit comments

Comments
 (0)