Skip to content

Commit 5a780b2

Browse files
Redesign BaseModel with unified inheritance and safety validation (#195)
- Replace separate _BaseModelPyodide/_BaseModelStd with single BaseModel class - Use __init_subclass__ for consistent dataclass transformation across environments - Add comprehensive inheritance validation: - Prevent dataclass mixins while inheriting from BaseModel - Block dangerous auto_dataclass=False when parent was auto-dataclassed - Require BaseModel or BaseModel subclass as primary base - Support dataclass options (frozen=True, etc.) via kwargs passthrough - Add __auto_dataclass tracking attribute for inheritance state - Preserve custom __init__ methods while maintaining dataclass benefits Fix custom dataclass coder support for list[Output] returns: - Auto-register DataclassCoder when BaseModel subclasses are created - Add custom_output test runner and schema validation - Fix Go test expected output formatting with JSON round-trip serialization - Remove manual dataclass_coder import requirement for users Remove obsolete pyodide-specific code and add comprehensive test suite covering all inheritance scenarios. This creates a safer, more predictable BaseModel system that prevents broken field inheritance while enabling seamless custom dataclass serialization.
1 parent add3a36 commit 5a780b2

File tree

7 files changed

+903
-97
lines changed

7 files changed

+903
-97
lines changed

internal/tests/coder_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,45 @@ func TestPredictionChatCoderSucceeded(t *testing.T) {
8888
assert.Equal(t, expectedOutput, predictionResponse.Output)
8989
assert.Equal(t, server.PredictionSucceeded, predictionResponse.Status)
9090
}
91+
92+
func TestPredictionCustomOutputCoder(t *testing.T) {
93+
t.Parallel()
94+
if *legacyCog {
95+
t.Skip("legacy Cog does not support custom coder")
96+
}
97+
98+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
99+
procedureMode: false,
100+
explicitShutdown: true,
101+
uploadURL: "",
102+
module: "custom_output",
103+
predictorClass: "Predictor",
104+
})
105+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
106+
107+
input := map[string]any{"i": 3}
108+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
109+
resp, err := http.DefaultClient.Do(req)
110+
require.NoError(t, err)
111+
defer resp.Body.Close()
112+
assert.Equal(t, http.StatusOK, resp.StatusCode)
113+
body, err := io.ReadAll(resp.Body)
114+
require.NoError(t, err)
115+
var predictionResponse server.PredictionResponse
116+
err = json.Unmarshal(body, &predictionResponse)
117+
require.NoError(t, err)
118+
119+
// Create expected output using JSON round-trip to match server serialization
120+
expectedItems := []map[string]any{
121+
{"x": 3, "y": "a"},
122+
{"x": 2, "y": "a"},
123+
{"x": 1, "y": "a"},
124+
}
125+
expectedJSON, err := json.Marshal(expectedItems)
126+
require.NoError(t, err)
127+
var expectedOutput []any
128+
err = json.Unmarshal(expectedJSON, &expectedOutput)
129+
require.NoError(t, err)
130+
assert.Equal(t, expectedOutput, predictionResponse.Output)
131+
assert.Equal(t, server.PredictionSucceeded, predictionResponse.Status)
132+
}

python/coglet/api.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import pathlib
2-
import sys
32
from abc import ABC, abstractmethod
43
from dataclasses import dataclass, is_dataclass
54
from typing import Any, AsyncIterator, Iterator, List, Optional, Type, TypeVar, Union
@@ -97,39 +96,48 @@ class Input:
9796
deprecated: Optional[bool] = None
9897

9998

100-
# pyodide does not recognise `BaseModel` with the `__new__` keyword as a dataclass while regular python does.
101-
# to get around this, we hijack `__init__subclass` instead to make sure the subclass of a base model is recognised
102-
# as a dataclass. In addition to this, we provide this `BaseModel` only on pyodide instances, normal python still gets
103-
# the regular `BaseModel``.
104-
class _BaseModelPyodide:
105-
def __init_subclass__(cls, **kwargs):
106-
dc_keys = {
107-
'init',
108-
'repr',
109-
'eq',
110-
'order',
111-
'unsafe_hash',
112-
'frozen',
113-
'match_args',
114-
'kw_only',
115-
'slots',
116-
'weakref_slot',
117-
}
118-
dc_opts = {k: kwargs.pop(k) for k in list(kwargs) if k in dc_keys}
119-
super().__init_subclass__(**kwargs)
120-
if not is_dataclass(cls):
121-
dataclass(**dc_opts)(cls)
122-
123-
124-
class _BaseModelStd:
125-
def __new__(cls, *args, **kwargs):
126-
# This does not work with frozen=True
127-
# Also user might want to mutate the output class
128-
dcls = dataclass()(cls)
129-
return super().__new__(dcls)
130-
131-
132-
BaseModel = _BaseModelPyodide if 'pyodide' in sys.modules else _BaseModelStd
99+
class BaseModel:
100+
def __init_subclass__(
101+
cls, *, auto_dataclass: bool = True, init: bool = True, **kwargs
102+
):
103+
# BaseModel is parented to `object` so we have nothing to pass up to it, we pass the kwargs to dataclass() only.
104+
super().__init_subclass__()
105+
106+
# For sanity, the primary base class must inherit from BaseModel
107+
if not issubclass(cls.__bases__[0], BaseModel):
108+
raise TypeError(
109+
f'Primary base class of "{cls.__name__}" must inherit from BaseModel'
110+
)
111+
elif not auto_dataclass:
112+
try:
113+
if (
114+
cls.__bases__[0] != BaseModel
115+
and cls.__bases__[0].__auto_dataclass is True # type: ignore[attr-defined]
116+
):
117+
raise ValueError(
118+
f'Primary base class of "{cls.__name__}" ("{cls.__bases__[0].__name__}") has auto_dataclass=True, but "{cls.__name__}" has auto_dataclass=False. This creates broken field inheritance.'
119+
)
120+
except AttributeError:
121+
raise RuntimeError(
122+
f'Primary base class of "{cls.__name__}" is a child of a child of `BaseModel`, but `auto_dataclass` tracking does not exist. This is likely a bug or other programming error.'
123+
)
124+
125+
for base in cls.__bases__[1:]:
126+
if is_dataclass(base):
127+
raise TypeError(
128+
f'Cannot mixin dataclass "{base.__name__}" while inheriting from `BaseModel`'
129+
)
130+
131+
# Once manual dataclass handling is enabled, we never apply the auto dataclass logic again,
132+
# it becomes the responsibility of the user to ensure that all dataclass semantics are handled.
133+
if not auto_dataclass:
134+
cls.__auto_dataclass = False # type: ignore[attr-defined]
135+
return
136+
137+
# all children should be dataclass'd, this is the only way to ensure that the dataclass inheritence
138+
# is handled properly.
139+
dataclass(init=init, **kwargs)(cls)
140+
cls.__auto_dataclass = True # type: ignore[attr-defined]
133141

134142

135143
########################################
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from cog import BaseModel, BasePredictor
2+
3+
4+
class CustomOut(BaseModel):
5+
x: int
6+
y: str
7+
8+
9+
class Predictor(BasePredictor):
10+
test_inputs = {'i': 3}
11+
12+
def predict(self, i: int) -> list[CustomOut]:
13+
outputs: list[CustomOut] = []
14+
while i > 0:
15+
outputs.append(CustomOut(x=i, y='a'))
16+
i -= 1
17+
return outputs

0 commit comments

Comments
 (0)