Commit fad3031

davidriazati authored and facebook-github-bot committed

Fix type hints for None constants (pytorch#23029)

Summary: The type hint was being ignored when emitting `None` constants. This change also de-duplicates some testing code.

Pull Request resolved: pytorch#23029
Pulled By: driazati
Differential Revision: D16364572
fbshipit-source-id: 64f3abd3e37ee49c209480a85ed4f1b8802e5d93
1 parent 2891784 commit fad3031
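
As a quick illustration of the user-visible fix, here is a minimal sketch mirroring the test added to test_jit_py3.py below (it assumes a PyTorch build that includes this change):

```python
import torch
from typing import List, Optional

@torch.jit.script
def fn():
    a: List[int] = []
    b: torch.Tensor = torch.ones(2, 2)
    # Before this fix, the type hint on a None constant was dropped, so the
    # Optional annotation could fail to apply; now the hint is respected.
    c: Optional[torch.Tensor] = None
    for _ in range(10):
        a.append(4)
        c = torch.ones(2, 2)
    return a, b, c

print(fn())
```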

File tree: 5 files changed, +68 additions, -101 deletions

Diff for: test/jit_utils.py (+18)

@@ -19,11 +19,14 @@
 from contextlib import contextmanager
 from functools import reduce
 from itertools import chain
+from torch._six import StringIO
+
 import inspect
 import io
 import math
 import os
 import pickle
+import sys
 import tempfile
 import textwrap

@@ -39,6 +42,21 @@ class JitTestCase(TestCase):
     _do_cuda_memory_leak_check = True
     _restored_warnings = False

+    class capture_stdout(list):
+        """
+        Replace sys.stdout with a temporary StringIO
+        """
+        def __enter__(self):
+            self.sys_stdout = sys.stdout
+            self.stringio = StringIO()
+            sys.stdout = self.stringio
+            return self
+
+        def __exit__(self, *args):
+            self.append(str(self.stringio.getvalue()))
+            del self.stringio
+            sys.stdout = self.sys_stdout
+
     def setHooks(self):
         torch._C._jit_set_emit_hooks(self.emitModuleHook, self.emitFunctionHook)
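For orientation, a standalone sketch of how the relocated helper behaves (importing io.StringIO directly and the final assert are illustrative assumptions; in the test suite the class lives on JitTestCase and imports StringIO via torch._six):

```python
import sys
from io import StringIO  # stand-in for torch._six.StringIO on Python 3

class capture_stdout(list):
    """Replace sys.stdout with a temporary StringIO."""
    def __enter__(self):
        self.sys_stdout = sys.stdout
        self.stringio = StringIO()
        sys.stdout = self.stringio
        return self

    def __exit__(self, *args):
        # Stash everything printed as the list's single element.
        self.append(str(self.stringio.getvalue()))
        del self.stringio
        sys.stdout = self.sys_stdout

with capture_stdout() as captured:
    print("hello")
assert captured[0] == "hello\n"
```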
Diff for: test/test_jit.py (+2 -46)

@@ -2850,21 +2850,6 @@ def test_print_op_module(self):


 class TestScript(JitTestCase):
-    class capture_stdout(list):
-        """
-        Replace sys.stdout with a temporary StringIO
-        """
-        def __enter__(self):
-            self.sys_stdout = sys.stdout
-            self.stringio = StringIO()
-            sys.stdout = self.stringio
-            return self
-
-        def __exit__(self, *args):
-            self.append(str(self.stringio.getvalue()))
-            del self.stringio
-            sys.stdout = self.sys_stdout
-
     def test_sequence_parsing(self):
         tests = [
             ("return [x, x,]", True),
@@ -3194,35 +3179,6 @@ def annotate_none_no_optional():
         self.checkScript(annotate_none, ())
         self.checkScript(annotate_none_no_optional, ())

-    @unittest.skipIf(True, "Python 3 required")
-    def test_type_annotate_py3(self):
-        code = dedent("""
-        import torch
-        def fn():
-            a : List[int] = []
-            b : torch.Tensor = torch.ones(2, 2)
-            for _ in range(10):
-                a.append(4)
-            return a, b
-        """)
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            script_path = os.path.join(tmp_dir, 'script.py')
-            with open(script_path, 'w') as f:
-                f.write(code)
-            fn = get_fn('test_type_annotate_py3', script_path)
-
-            self.checkScript(fn, ())
-
-        code = dedent("""
-        def wrong_type():
-            wrong : List[int] = [0.5]
-            return wrong
-        """)
-
-        with self.assertRaisesRegex(RuntimeError, "Lists must contain only a single type"):
-            cu = torch.jit.CompilationUnit(code)
-
     def test_robust_op_resolution(self):
         neg = torch.add  # misleading name to make sure we resolve by function

@@ -7862,7 +7818,7 @@ def foo(i):
         v = torch.rand(10, 3)
         self.checkScript(foo, (v,))

-        with self.assertRaisesRegex(RuntimeError, r"variable 'a' previously has type Tuple"):
+        with self.assertRaisesRegex(RuntimeError, r"Variable 'a' previously has type Tuple"):
             @torch.jit.script
             def mixtypes(x):
                 a = (x, x)
@@ -7890,7 +7846,7 @@ def diff_type_used():
             c0 = 1.0
             return c0

-        with self.assertRaisesRegex(RuntimeError, "variable 'c0' previously has type float"):
+        with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously has type float"):
             @torch.jit.script
             def diff_existing_type(x):
                 c0 = 1.0
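
The two regex updates above track a capitalization change in the compiler's diagnostic ("variable" becomes "Variable"; see the compiler.cpp diff below). A minimal sketch of code that triggers it (the message text beyond the tested prefix is an assumption):

```python
import torch

def mixtypes(x):
    a = (x, x)  # 'a' starts out as a tuple of tensors...
    a = x       # ...then is reassigned to a tensor: the type change is rejected
    return a

# torch.jit.script(mixtypes) raises a RuntimeError matching
# "Variable 'a' previously has type Tuple..."
```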

Diff for: test/test_jit_py3.py (+23 -45)

@@ -1,54 +1,12 @@
-import sys
-import torch
-from torch.testing import FileCheck
 from common_utils import run_tests
-from contextlib import contextmanager
 from jit_utils import JitTestCase
+from torch.testing import FileCheck
 from typing import NamedTuple, List

-WINDOWS = sys.platform == 'win32'
+import torch

-class TestScriptPy3(JitTestCase):
-    @contextmanager
-    def capture_stdout(self):
-        # No idea how to capture stdout from C++ on Windows
-        if WINDOWS:
-            yield ['']
-            return
-        import os
-        import fcntl
-        import errno
-        sys.stdout.flush()
-        stdout_fd = os.dup(1)
-        r, w = os.pipe()
-        try:
-            # Override stdout with r - dup is guaranteed to return the lowest free fd
-            os.close(1)
-            os.dup(w)
-
-            captured_stdout = ['']
-            yield captured_stdout
-            sys.stdout.flush()  # Make sure that Python hasn't buffered anything
-
-            # Do the ugly dance to read all the data that was written into the pipe
-            fcntl.fcntl(r, fcntl.F_SETFL, os.O_NONBLOCK)
-            total_stdout = ''
-            while True:
-                try:
-                    total_stdout += os.read(r, 1000).decode('ascii')
-                except OSError as e:
-                    if e.errno != errno.EAGAIN:
-                        raise
-                    break
-            captured_stdout[0] = total_stdout
-        finally:
-            # Revert the change, and clean up all fds
-            os.close(1)
-            os.dup(stdout_fd)
-            os.close(stdout_fd)
-            os.close(r)
-            os.close(w)

+class TestScriptPy3(JitTestCase):
     def test_joined_str(self):
         def func(x):
             hello, test = "Hello", "test"
@@ -212,5 +170,25 @@ def forward(self):
         for name in ['a', 'b', 'c']:
             self.assertEqual(getattr(out_loaded, name), getattr(out, name))

+    def test_type_annotate_py3(self):
+        def fn():
+            a : List[int] = []
+            b : torch.Tensor = torch.ones(2, 2)
+            c : Optional[torch.Tensor] = None
+            for _ in range(10):
+                a.append(4)
+                c = torch.ones(2, 2)
+            return a, b, c
+
+        self.checkScript(fn, ())
+
+        def wrong_type():
+            wrong : List[int] = [0.5]
+            return wrong
+
+        with self.assertRaisesRegex(RuntimeError, "Lists must contain only a single type"):
+            torch.jit.script(wrong_type)
+
+
 if __name__ == '__main__':
     run_tests()

Diff for: torch/csrc/jit/constants.cpp (+14 -3)

@@ -90,9 +90,20 @@ c10::optional<Value*> tryInsertConstant(
     n->setScope(*scope);
   if (result_type) {
     auto inferred_type = n->output()->type();
-    // Retain more type information in case of tensor constant
-    if (!(inferred_type->isSubtypeOf(TensorType::get()) &&
-          result_type->isSubtypeOf(inferred_type))) {
+
+    if (inferred_type->isSubtypeOf(NoneType::get()) &&
+        !inferred_type->isSubtypeOf(result_type)) {
+      // None doesn't subtype Optional, but an Optional can be None, so handle
+      // that here
+      if (result_type->kind() == TypeKind::OptionalType) {
+        n->output()->setType(result_type);
+      } else {
+        // Implicitly wrap non-optionals
+        n->output()->setType(OptionalType::create(result_type));
+      }
+    } else if (!(inferred_type->isSubtypeOf(TensorType::get()) &&
+                 result_type->isSubtypeOf(inferred_type))) {
+      // Retain more type information in case of tensor constant
       n->output()->setType(result_type);
     }
   }
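
In effect, a None constant emitted with a type hint now keeps (or is wrapped into) an Optional type instead of staying plain NoneType. A hedged sketch from the Python side (the use of torch.jit.annotate here is illustrative and not part of this diff):

```python
import torch
from typing import Optional

@torch.jit.script
def returns_optional(flag: bool) -> Optional[int]:
    # The None constant is emitted with an Optional[int] hint, which
    # tryInsertConstant now applies to the constant's output type.
    result = torch.jit.annotate(Optional[int], None)
    if flag:
        result = 1
    return result

print(returns_optional(True), returns_optional(False))
```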

Diff for: torch/csrc/jit/script/compiler.cpp (+11 -7)

@@ -307,20 +307,21 @@ struct Environment {
     }
     if (!as_simple_value->type()->isSubtypeOf(
             unshapedType(simple_parent->type()))) {
-      std::stringstream errMsg;
-      errMsg << "variable '" << name << "' previously has type "
+      auto error = ErrorReport(loc);
+      error << "Variable '" << name << "' previously has type "
             << simple_parent->type()->python_str()
             << " but is now being assigned to a value of type "
             << as_simple_value->type()->python_str();
+
       // Special-cased error msg if we're trying to assign to a tensor list.
       if (simple_parent->type()->kind() == TypeKind::ListType &&
           as_simple_value->type()->kind() == TypeKind::ListType) {
-        errMsg << "\n. (Note: empty lists are constructed as Tensor[]; "
+        error << "\n. (Note: empty lists are constructed as Tensor[]; "
               << "if you want an empty list of a different type, "
               << "use `torch.jit.annotate(List[T], [])`, "
               << "where `T` is the type of elements in the list)";
       }
-      throw ErrorReport(loc) << errMsg.str();
+      throw error;
     }
   }
   if (as_simple_value) {
@@ -2455,7 +2456,7 @@ struct to_ir {
       return graph->insertConstant(false, nullptr, tree->range());
     } break;
     case TK_NONE: {
-      return graph->insertConstant(IValue(), nullptr, tree->range());
+      return graph->insertConstant(IValue(), type_hint, tree->range());
     } break;
     case TK_SUBSCRIPT: {
       return emitSubscript(Subscript(tree));
@@ -2542,7 +2543,6 @@ struct to_ir {
     } break;
     default:
       throw ErrorReport(tree) << "Cannot emit expr for: " << tree;
-      break;
   }
 }

@@ -2694,7 +2694,11 @@ struct to_ir {
        ++dim;
        continue;
      }
-      auto index = emitExpr(subscript_expr, OptionalType::ofTensor());
+      TypePtr type_hint = OptionalType::ofTensor();
+      if (subscript_expr.kind() == TK_NONE) {
+        type_hint = NoneType::get();
+      }
+      auto index = emitExpr(subscript_expr, type_hint);
      if (index->type() == IntType::get()) {
        // NB: note, select squeezes out a dimension,
        // so dim is **not** incremented
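
The emitSubscript change above covers None used as a tensor index, which inserts a new dimension; with this fix, the None index is emitted with a NoneType hint rather than being widened to Optional[Tensor]. A small sketch (assuming a post-fix build):

```python
import torch

@torch.jit.script
def add_dim(x):
    # Indexing with None behaves like unsqueeze: a size-1 dim is inserted.
    return x[None]

print(add_dim(torch.rand(3)).shape)  # torch.Size([1, 3])
```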
