@@ -2046,17 +2046,24 @@ def test_mixed_ndim_error(self):
2046
2046
def test_static_shape_inference (self ):
2047
2047
a = at .tensor (dtype = "int8" , shape = (2 , 3 ))
2048
2048
b = at .tensor (dtype = "int8" , shape = (2 , 5 ))
2049
- assert at .join (1 , a , b ).type .shape == (2 , 8 )
2050
- assert at .join (- 1 , a , b ).type .shape == (2 , 8 )
2049
+
2050
+ res = at .join (1 , a , b ).type .shape
2051
+ assert res == (2 , 8 )
2052
+ assert all (isinstance (s , int ) for s in res )
2053
+
2054
+ res = at .join (- 1 , a , b ).type .shape
2055
+ assert res == (2 , 8 )
2056
+ assert all (isinstance (s , int ) for s in res )
2051
2057
2052
2058
# Check early informative errors from static shape info
2053
2059
with pytest .raises (ValueError , match = "must match exactly" ):
2054
2060
at .join (0 , at .ones ((2 , 3 )), at .ones ((2 , 5 )))
2055
2061
2056
2062
# Check partial inference
2057
2063
d = at .tensor (dtype = "int8" , shape = (2 , None ))
2058
- assert at .join (1 , a , b , d ).type .shape == (2 , None )
2059
- return
2064
+ res = at .join (1 , a , b , d ).type .shape
2065
+ assert res == (2 , None )
2066
+ assert isinstance (res [0 ], int )
2060
2067
2061
2068
def test_split_0elem (self ):
2062
2069
rng = np .random .default_rng (seed = utt .fetch_seed ())
0 commit comments