File tree 2 files changed +14
-1
lines changed
2 files changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -198,7 +198,15 @@ class Det(Op):
198
198
199
199
def make_node (self , x ):
200
200
x = as_tensor_variable (x )
201
- assert x .ndim == 2
201
+ if x .ndim != 2 :
202
+ raise ValueError (
203
+ f"Input passed is not a valid 2D matrix. Current ndim { x .ndim } != 2"
204
+ )
205
+ # Check for known shapes and square matrix
206
+ if None not in x .type .shape and (x .type .shape [0 ] != x .type .shape [1 ]):
207
+ raise ValueError (
208
+ f"Determinant not defined for non-square matrix inputs. Shape received is { x .type .shape } "
209
+ )
202
210
o = scalar (dtype = x .dtype )
203
211
return Apply (self , [x ], [o ])
204
212
Original file line number Diff line number Diff line change @@ -365,6 +365,11 @@ def test_det():
365
365
assert np .allclose (np .linalg .det (r ), f (r ))
366
366
367
367
368
+ def test_det_non_square_raises ():
369
+ with pytest .raises (ValueError , match = "Determinant not defined" ):
370
+ det (tensor ("x" , shape = (5 , 7 )))
371
+
372
+
368
373
def test_det_grad ():
369
374
rng = np .random .default_rng (utt .fetch_seed ())
370
375
You can’t perform that action at this time.
0 commit comments