Skip to content

Commit 5c8afae

Browse files
authored
added check for square matrix in make_node to Det (#861)
1 parent ffca031 commit 5c8afae

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

pytensor/tensor/nlinalg.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,15 @@ class Det(Op):
198198

199199
def make_node(self, x):
200200
x = as_tensor_variable(x)
201-
assert x.ndim == 2
201+
if x.ndim != 2:
202+
raise ValueError(
203+
f"Input passed is not a valid 2D matrix. Current ndim {x.ndim} != 2"
204+
)
205+
# Check for known shapes and square matrix
206+
if None not in x.type.shape and (x.type.shape[0] != x.type.shape[1]):
207+
raise ValueError(
208+
f"Determinant not defined for non-square matrix inputs. Shape received is {x.type.shape}"
209+
)
202210
o = scalar(dtype=x.dtype)
203211
return Apply(self, [x], [o])
204212

tests/tensor/test_nlinalg.py

+5
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,11 @@ def test_det():
365365
assert np.allclose(np.linalg.det(r), f(r))
366366

367367

368+
def test_det_non_square_raises():
369+
with pytest.raises(ValueError, match="Determinant not defined"):
370+
det(tensor("x", shape=(5, 7)))
371+
372+
368373
def test_det_grad():
369374
rng = np.random.default_rng(utt.fetch_seed())
370375

0 commit comments

Comments
 (0)