Skip to content

Commit 288d2c4

Browse files
committed
Restore typeify
1 parent ac35367 commit 288d2c4

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,8 @@
2626

2727

2828
@singledispatch
29-
def pytorch_typify(data, dtype=None, **kwargs):
30-
if dtype is None:
31-
return data
32-
else:
33-
return torch.tensor(data, dtype=dtype)
29+
def pytorch_typify(data, **kwargs):
30+
raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
3431

3532

3633
@pytorch_typify.register(np.ndarray)
@@ -40,6 +37,7 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs):
4037

4138

4239
@pytorch_typify.register(slice)
40+
@pytorch_typify.register(dict)
4341
@pytorch_typify.register(NoneType)
4442
@pytorch_typify.register(np.number)
4543
def pytorch_typify_no_conversion_needed(data, **kwargs):

0 commit comments

Comments
 (0)