Skip to content

Commit e1d0dc6

Browse files
committed
better error message in dataclass_array_container
1 parent 5c85680 commit e1d0dc6

File tree

3 files changed

+86
-6
lines changed

3 files changed

+86
-6
lines changed

arraycontext/container/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ def is_array_container_type(cls: type) -> bool:
173173
function will say that :class:`numpy.ndarray` is an array container
174174
type, only object arrays *actually are* array containers.
175175
"""
176+
assert isinstance(cls, type), \
177+
f"must pass a type, not an instance: '{cls!r}'"
178+
assert hasattr(cls, "__mro__"), "'cls' has no attribute '__mro__': "
179+
176180
return (
177181
cls is ArrayContainer
178182
or (serialize_container.dispatch(cls)

arraycontext/container/dataclass.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,31 @@ def dataclass_array_container(cls: type) -> type:
4949
from dataclasses import is_dataclass
5050
assert is_dataclass(cls)
5151

52-
array_fields = [
53-
f for f in fields(cls) if is_array_container_type(f.type)]
54-
non_array_fields = [
55-
f for f in fields(cls) if not is_array_container_type(f.type)]
52+
def is_array_field(f):
53+
if __debug__:
54+
if not f.init:
55+
raise ValueError(
56+
f"'init=False' field not allowed: '{f.name}'")
57+
58+
if isinstance(f.type, str):
59+
raise TypeError(
60+
f"string annotation on field '{f.name}' not supported")
61+
62+
from typing import _SpecialForm
63+
if isinstance(f.type, _SpecialForm):
64+
raise TypeError(
65+
f"typing annotation not supported on field '{f.name}': "
66+
f"'{f.type!r}'")
67+
68+
if not isinstance(f.type, type):
69+
raise TypeError(
70+
f"field '{f.name}' not an instance of 'type': "
71+
f"'{f.type!r}'")
72+
73+
return is_array_container_type(f.type)
74+
75+
from pytools import partition
76+
array_fields, non_array_fields = partition(is_array_field, fields(cls))
5677

5778
if not array_fields:
5879
raise ValueError(f"'{cls}' must have fields with array container type "

test/test_utils.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,16 @@
2222
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
2323
THE SOFTWARE.
2424
"""
25+
import pytest
26+
27+
import numpy as np
2528

2629
import logging
2730
logger = logging.getLogger(__name__)
2831

2932

33+
# {{{ test_pt_actx_key_stringification_uniqueness
34+
3035
def test_pt_actx_key_stringification_uniqueness():
3136
from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
3237

@@ -36,13 +41,63 @@ def test_pt_actx_key_stringification_uniqueness():
3641
assert (_ary_container_key_stringifier(("tup", 3, "endtup"))
3742
!= _ary_container_key_stringifier(((3,),)))
3843

44+
# }}}
45+
46+
47+
# {{{ test_dataclass_array_container
48+
49+
def test_dataclass_array_container():
50+
from typing import Optional
51+
from dataclasses import dataclass, field
52+
from arraycontext import dataclass_array_container
53+
54+
# {{{ string fields
55+
56+
@dataclass
57+
class ArrayContainerWithStringTypes:
58+
x: np.ndarray
59+
y: "np.ndarray"
60+
61+
with pytest.raises(TypeError):
62+
# NOTE: cannot have string annotations in container
63+
dataclass_array_container(ArrayContainerWithStringTypes)
64+
65+
# }}}
66+
67+
# {{{ optional fields
68+
69+
@dataclass
70+
class ArrayContainerWithOptional:
71+
x: np.ndarray
72+
y: Optional[np.ndarray]
73+
74+
with pytest.raises(TypeError):
75+
# NOTE: cannot have wrapped annotations (here by `Optional`)
76+
dataclass_array_container(ArrayContainerWithOptional)
77+
78+
# }}}
79+
80+
# {{{ field(init=False)
81+
82+
@dataclass
83+
class ArrayContainerWithInitFalse:
84+
x: np.ndarray
85+
y: np.ndarray = field(default=np.zeros(42), init=False, repr=False)
86+
87+
with pytest.raises(ValueError):
88+
# NOTE: init=False fields are not allowed
89+
dataclass_array_container(ArrayContainerWithInitFalse)
90+
91+
# }}}
92+
93+
# }}}
94+
3995

4096
if __name__ == "__main__":
4197
import sys
4298
if len(sys.argv) > 1:
4399
exec(sys.argv[1])
44100
else:
45-
from pytest import main
46-
main([__file__])
101+
pytest.main([__file__])
47102

48103
# vim: fdm=marker

0 commit comments

Comments
 (0)