Skip to content

Commit ae76cdf

Browse files
committed
PyTorch typify for NoneType
1 parent 2fb9f0e commit ae76cdf

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

pytensor/link/pytorch/dispatch/basic.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from functools import singledispatch
2+
from types import NoneType
23

34
import torch
45

@@ -12,8 +13,11 @@
1213
@singledispatch
1314
def pytorch_typify(data, dtype=None, **kwargs):
1415
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
15-
if data is not None:
16-
return torch.as_tensor(data, dtype=dtype)
16+
return torch.as_tensor(data, dtype=dtype)
17+
18+
19+
@pytorch_typify.register(NoneType)
20+
def pytorch_typify_None(data, **kwargs):
1721
return None
1822

1923

0 commit comments

Comments
 (0)