-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathdeferred_init.py
63 lines (51 loc) · 1.93 KB
/
deferred_init.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import PythonKeyTracer, ProxyTorchDispatchMode
from torch.fx import Graph, GraphModule
# Limitations:
# - initialization cannot refer to external tensors
# - parameters are these weird ProxyTensors, should have a custom class for
# these placeholders
# - DCE is likely not sound, needs to be implemented more carefully by
# understanding aliasing relationships
# - only top level module is rematerialized
# - we lose parameter-ness and requires_grad-ness
# - no version counter safety to guard against input mutation
def deferred_init(f, *args, **kwargs):
    """Run ``f(*args, **kwargs)`` under fake-tensor and proxy tracing modes.

    A fresh ``PythonKeyTracer`` is prepared with an empty FX ``Graph``, a
    blank ``torch.nn.Module`` root and an empty ``tensor_attrs`` dict. ``f``
    is then executed inside a ``FakeTensorMode`` and a
    ``ProxyTorchDispatchMode`` bound to that tracer, so the tensors ``f``
    creates are fake tensors and the tensor ops it performs are recorded
    into the tracer's graph. The tracer is attached to the returned object
    as ``_deferred`` so that ``materialize_module`` can replay the graph later.

    Args:
        f: Callable to run, e.g. a ``torch.nn.Module`` subclass.
        *args: Positional arguments forwarded unchanged to ``f``.
        **kwargs: Keyword arguments forwarded unchanged to ``f``.

    Returns:
        Whatever ``f`` returns, with a ``_deferred`` attribute set on it
        (so the result must accept new attributes).
    """
    tracer = PythonKeyTracer()
    # NOTE(review): torch.fx.Graph's first positional parameter is
    # ``owning_module``; passing the tracer here looks unintended -- confirm.
    tracer.graph = Graph(tracer)
    tracer.root = torch.nn.Module()
    tracer.tensor_attrs = {}

    fake_mode = FakeTensorMode(allow_fallback_kernels=True)
    record_mode = ProxyTorchDispatchMode(tracer, tracing_mode="real")
    with fake_mode:
        with record_mode:
            produced = f(*args, **kwargs)

    produced._deferred = tracer
    return produced
def materialize_module(m):
    """Replace the deferred placeholder tensors of ``m`` with real tensors.

    ``m`` must carry a ``_deferred`` tracer, as attached by ``deferred_init``.
    Every non-None parameter and buffer of ``m`` and of all its submodules
    (all of them were traced into the same recorded graph) is located in the
    graph through ``t.proxy.node``. The graph is replayed once with real
    tensors, and the results are written back into the owning
    ``_parameters`` / ``_buffers`` dicts under the same keys.

    * Parameters are re-wrapped as ``torch.nn.Parameter`` and keep the
      placeholder's ``requires_grad``. Buffers are stored as plain tensors.
    * A placeholder that appears in several slots (e.g. tied weights) is
      materialized once, and the same object is stored in every slot.

    Side effects: an output node is appended to the recorded graph, and
    ``m._deferred`` is deleted, so this can be called only once per module.

    Raises:
        AttributeError: if ``m`` has no ``_deferred`` tracer, or if a
            parameter/buffer has no ``proxy`` (i.e. it was not created during
            the deferred trace).
    """
    deferred = m._deferred

    # Gather every slot to fill: (owning dict, key, placeholder, is_param).
    slots = []
    out_index = {}  # id(placeholder) -> position in the graph's outputs
    outputs = []
    for mod in m.modules():
        for tensors, is_param in ((mod._parameters, True), (mod._buffers, False)):
            for k, t in tensors.items():
                if t is None:
                    continue
                if id(t) not in out_index:
                    out_index[id(t)] = len(outputs)
                    outputs.append(t.proxy.node)
                slots.append((tensors, k, t, is_param))

    deferred.graph.output(outputs)
    # Deliberately no eliminate_dead_code(): an in-place op applied through a
    # view of a parameter yields a node with no users, which DCE can drop
    # because it does not model aliasing. Replaying the whole recorded graph
    # is always correct.
    recomp = GraphModule(deferred.root, deferred.graph)
    results = recomp()

    replaced = {}  # (id(placeholder), is_param) -> materialized tensor
    for tensors, k, t, is_param in slots:
        key = (id(t), is_param)
        if key not in replaced:
            value = results[out_index[id(t)]]
            if is_param:
                value = torch.nn.Parameter(value, requires_grad=t.requires_grad)
            replaced[key] = value
        # Assigning an existing key; the dict's size does not change.
        tensors[k] = replaced[key]

    del m._deferred
if __name__ == "__main__":
    # Demo: build an nn.Linear under deferred init (parameters are
    # placeholders), then materialize it and print the weight before/after.
    # Guarded so that importing this module does not run the demo.
    m = deferred_init(torch.nn.Linear, 3, 5)
    print(m.weight)  # placeholder produced during deferred init
    materialize_module(m)
    print(m.weight)  # real tensor after replaying the recorded graph