forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_content_store.py
135 lines (118 loc) · 4.82 KB
/
test_content_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Owner(s): ["oncall: pt2"]
import torch
from torch._prims.debug_prims import load_tensor_reader
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocm,
TemporaryDirectoryName,
TestCase,
)
from torch.utils._content_store import (
ContentStoreReader,
ContentStoreWriter,
hash_storage,
)
class TestContentStore(TestCase):
    """Device-generic tests for the content store and ``debugprims.load_tensor``.

    Each test method receives a ``device`` string; the per-device test
    classes are generated at module level by
    ``instantiate_device_type_tests``.
    """

    def test_basic(self, device):
        """Round-trip tensors through a store, before and after an in-place edit."""

        def storage_ref(t):
            # Comparable handle on the storage backing ``t``.
            return StorageWeakRef(t.untyped_storage())

        # Fixture: two independent tensors plus a (2, 2) view created from x.
        x = torch.randn(4, device=device)
        y = torch.randn(6, device=device)
        z = x.view(2, 2)

        with TemporaryDirectoryName() as store_dir:
            store_writer = ContentStoreWriter(store_dir)

            # First snapshot, written in x, y, z order.
            for key, value in (("x", x), ("y", y), ("z", z)):
                store_writer.write_tensor(key, value)

            # Mutate x in place through ``.data`` (the original comment calls
            # this "VC UNTRACKED" -- presumably version-counter untracked),
            # then write a second snapshot under new names.
            x.data.add_(1)
            for key, value in (("x2", x), ("y2", y), ("z2", z)):
                store_writer.write_tensor(key, value)
            del store_writer

            store_reader = ContentStoreReader(store_dir)

            # The first snapshot predates the +1, so x and z read back one
            # lower than their current values; y was never touched.
            old_x, old_y, old_z = (
                store_reader.read_tensor(key) for key in ("x", "y", "z")
            )
            self.assertEqual(old_x + 1, x)
            self.assertEqual(old_y, y)
            self.assertEqual(old_z + 1, z)
            # x and z were written from the same storage and must read back
            # sharing one storage as well.
            self.assertEqual(storage_ref(old_x), storage_ref(old_z))

            # The second snapshot matches the current values exactly.
            new_x, new_y, new_z = (
                store_reader.read_tensor(key) for key in ("x2", "y2", "z2")
            )
            self.assertEqual(new_x, x)
            self.assertEqual(new_y, y)
            self.assertEqual(new_z, z)
            # y did not change between snapshots, so both reads of it are
            # expected to resolve to the same storage.
            self.assertEqual(storage_ref(new_y), storage_ref(old_y))

    def test_scalar(self, device):
        """Hashing the storage of a 0-dim tensor should not raise."""
        scalar = torch.tensor(2, device=device)
        hash_storage(scalar.untyped_storage())

    @torch._dynamo.config.patch(recompile_limit=1)
    def test_repeated_hash(self, device):
        """Hash several fresh scalar storages under ``recompile_limit=1``.

        Repeated hashing must not trigger a recompile in dynamo; if it did,
        ``prims.xor_sum`` would be executed in eager, which fails.
        """
        for _ in range(4):
            scalar = torch.tensor(2, device=device)
            hash_storage(scalar.untyped_storage())

    @skipIfRocm
    def test_load_tensor(self, device):
        """Exercise ``debugprims.load_tensor`` with a reader, under fake mode, and as fp64."""

        def load_x(dtype=torch.float32):
            # Every load in this test asks for the same name/size/stride and
            # differs only in the requested dtype.
            return torch.ops.debugprims.load_tensor.default(
                "x", (4,), (1,), dtype=dtype, device=device
            )

        def storage_ref(t):
            # Comparable handle on the storage backing ``t``.
            return StorageWeakRef(t.untyped_storage())

        with TemporaryDirectoryName() as store_dir:
            store_writer = ContentStoreWriter(store_dir)
            x = torch.randn(4, device=device)

            def assert_same_meta_as_x(t):
                # Metadata-only comparison: size, stride, dtype and device.
                self.assertEqual(t.size(), x.size())
                self.assertEqual(t.stride(), x.stride())
                self.assertEqual(t.dtype, x.dtype)
                self.assertEqual(t.device, x.device)

            store_writer.write_tensor("x", x)

            with load_tensor_reader(store_dir):
                # Two consecutive loads both reproduce the saved values ...
                x2 = load_x()
                self.assertEqual(x, x2)
                x3 = load_x()
                self.assertEqual(x, x3)
                # ... but must not alias the source tensor or each other.
                self.assertNotEqual(storage_ref(x), storage_ref(x2))
                self.assertNotEqual(storage_ref(x2), storage_ref(x3))

                # Check fake tensor mode works too
                with FakeTensorMode():
                    x4 = load_x()
                    self.assertIsInstance(x4, FakeTensor)
                    assert_same_meta_as_x(x4)

                # Check fp64 works on non-MPS platforms, since MPS doesn't
                # currently support fp64.
                if not device.startswith("mps"):
                    x5 = load_x(torch.float64)
                    self.assertEqual(x5.float(), x)
                    self.assertEqual(x5.dtype, torch.float64)

        # NOTE(review): the original nesting of this last load was lost with
        # the file's indentation. Only its metadata is asserted, so it is
        # placed outside the reader/tempdir contexts -- confirm against the
        # upstream file.
        x6 = load_x()
        assert_same_meta_as_x(x6)
# Generate the per-device variants of TestContentStore into this module's
# namespace, with MPS and XPU explicitly allowed.
instantiate_device_type_tests(
    TestContentStore,
    globals(),
    allow_mps=True,
    allow_xpu=True,
)


# Run the tests when the file is executed directly.
if __name__ == "__main__":
    run_tests()