Skip to content

Commit efca1d6

Browse files
kausvfacebook-github-bot
authored andcommitted
Handle meta tensors in FX quantization (#2622)
Summary: Pull Request resolved: #2622 X-link: pytorch/pytorch#142262 If module being quantized contains a some meta tensors and some tensors with actual device, we should not fail quantization. Quantization should also not fail if new quantized module is created on a meta device. If devices contain meta, copying from meta to meta is not necessary, copying from another device to meta can be skipped. Reviewed By: emlin Differential Revision: D66895899 fbshipit-source-id: bba8de9ddc5f86292521985dc588f9dbe14b4b4c
1 parent dab48e4 commit efca1d6

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

torchrec/quant/embedding_modules.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,18 @@
1010
import copy
1111
import itertools
1212
from collections import defaultdict
13-
from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
13+
from typing import (
14+
Any,
15+
Callable,
16+
cast,
17+
Dict,
18+
List,
19+
Optional,
20+
Sequence,
21+
Tuple,
22+
Type,
23+
Union,
24+
)
1425

1526
import torch
1627
import torch.nn as nn
@@ -971,6 +982,27 @@ def __init__(
971982
) in self._managed_collision_collection._managed_collision_modules.values():
972983
managed_collision_module.reset_inference_mode()
973984

985+
def to(
986+
self, *args: List[Any], **kwargs: Dict[str, Any]
987+
) -> "QuantManagedCollisionEmbeddingCollection":
988+
device, dtype, non_blocking, _ = torch._C._nn._parse_to(
989+
*args, # pyre-ignore
990+
**kwargs, # pyre-ignore
991+
)
992+
for param in self.parameters():
993+
if param.device.type != "meta":
994+
param.to(device)
995+
996+
for buffer in self.buffers():
997+
if buffer.device.type != "meta":
998+
buffer.to(device)
999+
# Skip device movement and continue with other args
1000+
super().to(
1001+
dtype=dtype,
1002+
non_blocking=non_blocking,
1003+
)
1004+
return self
1005+
9741006
def forward(
9751007
self,
9761008
features: KeyedJaggedTensor,

0 commit comments

Comments
 (0)