Commit c58709b

jbschlosser authored and facebook-github-bot committed
Helper function for skipping module parameter / buffer initialization (pytorch#57555)
Summary:
This PR introduces a helper function named `torch.nn.utils.skip_init()` that accepts a module class object + `args` / `kwargs` and instantiates the module while skipping initialization of parameter / buffer values. See discussion at pytorch#29523 for more context.

Example usage:

```python
import torch

m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
print(m.weight)

m2 = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1, device='cuda')
print(m2.weight)

m3 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=5, out_features=1)
print(m3.weight)
```

```
Parameter containing:
tensor([[-3.3011e+28,  4.5915e-41, -3.3009e+28,  4.5915e-41,  0.0000e+00]],
       requires_grad=True)
Parameter containing:
tensor([[-2.5339e+27,  4.5915e-41, -2.5367e+27,  4.5915e-41,  0.0000e+00]],
       device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
       requires_grad=True)
```

Bikeshedding on the name / namespace is welcome, as well as comments on the design itself - just wanted to get something out there for discussion.

Pull Request resolved: pytorch#57555
Reviewed By: zou3519
Differential Revision: D28640613
Pulled By: jbschlosser
fbshipit-source-id: 5654f2e5af5530425ab7a9e357b6ba0d807e967f
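A minimal sketch of the workflow this helper targets: skip the default initialization, then apply a custom scheme before use. The `skip_init` call matches the API above; the particular init functions chosen below are illustrative, not part of this commit:

```python
import torch

# Construct the module without running the default Linear initialization.
m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)

# Weights/bias hold uninitialized memory here, so initialize before use.
torch.nn.init.xavier_uniform_(m.weight)
torch.nn.init.zeros_(m.bias)
```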
1 parent 277f587 commit c58709b

File tree

4 files changed (+65 lines, −0 lines)


docs/source/nn.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -371,6 +371,7 @@ From the ``torch.nn.utils`` module
     remove_weight_norm
     spectral_norm
     remove_spectral_norm
+    skip_init
 
 Parametrizations implemented using the new parametrization functionality
 in :func:`torch.nn.utils.parameterize.register_parametrization`.
```

test/test_nn.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -16236,6 +16236,18 @@ def forward(self, x):
         m.to_empty(device='meta')
         m(input)
 
+    @skipMeta
+    def test_skip_init(self, device):
+        torch.manual_seed(1)
+        m_initialized = torch.nn.Linear(5, 1)
+        m_initialized.to(device)
+
+        torch.manual_seed(1)
+        m_uninitialized = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1, device=device)
+
+        self.assertEqual(m_initialized.weight.device, m_uninitialized.weight.device)
+        self.assertFalse(torch.allclose(m_initialized.weight, m_uninitialized.weight))
+
 class TestModuleGlobalHooks(TestCase):
 
     def tearDown(self):
```
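The test uses seeding to show that initialization really was skipped: with identical seeds, a normally constructed `Linear` would reproduce the same weights if `skip_init` had run the default init. A standalone sketch of the same check (CPU-only for simplicity):

```python
import torch

torch.manual_seed(1)
m_init = torch.nn.Linear(5, 1)  # reset_parameters() consumes RNG state

torch.manual_seed(1)
m_skip = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)  # no RNG consumed

# Identical seeds would yield identical weights if the default init had run;
# differing (uninitialized) values indicate it was skipped.
print(torch.allclose(m_init.weight, m_skip.weight))  # almost surely False
```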

torch/nn/utils/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -6,3 +6,4 @@
 from .fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
 from .memory_format import convert_conv2d_weight_memory_format
 from . import parametrizations
+from .init import skip_init
```
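With this re-export, the helper is reachable via the `torch.nn.utils` namespace as well as by direct import; the two forms below are equivalent:

```python
import torch
from torch.nn.utils import skip_init

m1 = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)  # fully qualified
m2 = skip_init(torch.nn.Linear, 5, 1)                 # direct import
```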

torch/nn/utils/init.py

Lines changed: 51 additions & 0 deletions

```diff
@@ -0,0 +1,51 @@
+import inspect
+import torch
+
+
+def skip_init(module_cls, *args, **kwargs):
+    r"""
+    Given a module class object and args / kwargs, instantiates the module without initializing
+    parameters / buffers. This can be useful if initialization is slow or if custom initialization will
+    be performed, making the default initialization unnecessary. There are some caveats to this, due to
+    the way this function is implemented:
+
+    1. The module must accept a `device` arg in its constructor that is passed to any parameters
+    or buffers created during construction.
+
+    2. The module must not perform any computation on parameters in its constructor except
+    initialization (i.e. functions from :mod:`torch.nn.init`).
+
+    If these conditions are satisfied, the module can be instantiated with parameter / buffer values
+    uninitialized, as if having been created using :func:`torch.empty`.
+
+    Args:
+        module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
+        args: args to pass to the module's constructor
+        kwargs: kwargs to pass to the module's constructor
+
+    Returns:
+        Instantiated module with uninitialized parameters / buffers
+
+    Example::
+
+        >>> import torch
+        >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
+        >>> m.weight
+        Parameter containing:
+        tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
+               requires_grad=True)
+        >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
+        >>> m2.weight
+        Parameter containing:
+        tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24,
+                  4.5915e-41]], requires_grad=True)
+
+    """
+    if not issubclass(module_cls, torch.nn.Module):
+        raise RuntimeError('Expected a Module; got {}'.format(module_cls))
+    if 'device' not in inspect.signature(module_cls).parameters:
+        raise RuntimeError('Module must support a \'device\' arg to skip initialization')
+
+    final_device = kwargs.pop('device', 'cpu')
+    kwargs['device'] = 'meta'
+    return module_cls(*args, **kwargs).to_empty(device=final_device)
```
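The implementation constructs the module on the `meta` device, where parameter creation and `torch.nn.init` calls are shape-only and allocate no real data, then materializes uninitialized storage on the target device via `to_empty()`. That is the source of the two docstring caveats. A sketch of a hypothetical custom module (`MyModule` is illustrative, not part of this commit) that satisfies them:

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self, num_features, device=None):
        super().__init__()
        # Caveat 1: forward `device` to every parameter created here.
        self.weight = torch.nn.Parameter(torch.empty(num_features, device=device))
        # Caveat 2: touch parameters only via torch.nn.init in the constructor.
        torch.nn.init.uniform_(self.weight)

    def forward(self, x):
        return x * self.weight

# Constructed on 'meta', then moved to uninitialized storage on the default
# final device ('cpu').
m = torch.nn.utils.skip_init(MyModule, 4)
torch.nn.init.ones_(m.weight)  # custom initialization applied afterward
```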
