
Commit 8d508ed

gau-nernst authored and jainapurva committed
[low-bit optim] Fix load state dict when device is different (#1021)
* fix serialization
* fix pytorch 2.3
* fix typo
* update note
1 parent b7126d8 · commit 8d508ed

5 files changed: +118 −11 lines

test/prototype/test_low_bit_optim.py

Lines changed: 26 additions & 3 deletions
@@ -97,7 +97,30 @@ def test_optim_smoke(self, optim_name, dtype, device):
             optim.step()
             optim.zero_grad()
 
-    @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")
+        # test serialization. also test the case CUDA optim loads CPU state dict
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(optim.state_dict(), f.name)
+            state_dict = torch.load(f.name, map_location="cpu")
+
+        model2 = copy.deepcopy(model)
+        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
+        optim2.load_state_dict(state_dict)
+
+        for _ in range(2):
+            x = torch.randn(4, 32, device=device, dtype=dtype)
+
+            model(x).sum().backward()
+            optim.step()
+            optim.zero_grad()
+
+            model2(x).sum().backward()
+            optim2.step()
+            optim2.zero_grad()
+
+        for p1, p2 in zip(model.parameters(), model2.parameters()):
+            torch.testing.assert_close(p2, p1)
+
+    @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])

@@ -129,7 +152,7 @@ def test_optim_8bit_correctness(self, optim_name):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
 
     # this will not run in CI because we can't install lpmm
-    @pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle")
+    @pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
    @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])

@@ -205,7 +228,7 @@ def test_optim_cpu_offload_save_load(self):
         # save checkpoint. make sure it can be serialized by torch.save()
         with tempfile.NamedTemporaryFile() as file:
             torch.save(optim1.state_dict(), file.name)
-            state_dict = torch.load(file.name)
+            state_dict = torch.load(file.name, map_location="cpu")
 
         # resume training
         model2 = copy.deepcopy(model1)

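The new test above exercises the end-to-end flow that motivated this fix: checkpoint a low-bit optimizer, then load a CPU-mapped state dict back into a CUDA optimizer. A minimal standalone sketch of that flow follows; it assumes torchao.prototype.low_bit_optim exposes AdamW8bit (any of the low-bit optimizers should behave the same way) and that a CUDA device is available, and it mirrors the test rather than adding anything new.

# minimal sketch, assuming torchao.prototype.low_bit_optim.AdamW8bit and a CUDA device
import copy
import tempfile

import torch
from torchao.prototype import low_bit_optim

model = torch.nn.Linear(32, 32, device="cuda")
optim = low_bit_optim.AdamW8bit(model.parameters())

# one step so the quantized optimizer state actually exists
model(torch.randn(4, 32, device="cuda")).sum().backward()
optim.step()
optim.zero_grad()

with tempfile.NamedTemporaryFile() as f:
    torch.save(optim.state_dict(), f.name)
    # map_location="cpu" is the interesting case: the quantized state tensors
    # land on CPU and must be moved back to CUDA during load_state_dict()
    state_dict = torch.load(f.name, map_location="cpu")

model2 = copy.deepcopy(model)
optim2 = low_bit_optim.AdamW8bit(model2.parameters())
optim2.load_state_dict(state_dict)  # the cross-device case this commit fixes
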
torchao/prototype/low_bit_optim/adam.py

Lines changed: 3 additions & 0 deletions
@@ -109,6 +109,9 @@ def step(self, closure=None):
 
 # this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default
 # and param tensor subclass that implements aten.add_.Tensor, and aten.addcdiv_.default
+# NOTE: right now all of our optimizer state subclasses will dequant to FP32, thus adam computation
+# will be done in FP32 (not purposely). we should explicitly cast all inputs to FP32 to ensure FP32
+# computation. will need to benchmark to ensure no slowdown.
 def single_param_adam(
     p: Tensor,
     grad: Tensor,

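The NOTE added above says the Adam math currently runs in FP32 only as a side effect of dequantization, and that the inputs should eventually be cast explicitly. Below is a hypothetical plain-tensor sketch of that idea; the function name and formula are illustrative only and are not torchao's single_param_adam, which routes the same math through the subclasses' lerp/copy_ handlers instead of dequantizing by hand.

# hypothetical sketch of "explicitly cast all inputs to FP32"; not torchao code
import torch
from torch import Tensor

def adam_update_fp32(p: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor,
                     lr: float, beta1: float, beta2: float, eps: float, step: int) -> None:
    # upcast once so the math is FP32 regardless of how the inputs are stored
    p_f32 = p.float()
    grad_f32 = grad.float()
    exp_avg_f32 = exp_avg.float().lerp(grad_f32, 1 - beta1)
    exp_avg_sq_f32 = exp_avg_sq.float().lerp(grad_f32.square(), 1 - beta2)

    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    denom = (exp_avg_sq_f32 / bias_corr2).sqrt() + eps
    p_f32 -= lr * (exp_avg_f32 / bias_corr1) / denom

    # copy_() casts back to each tensor's storage dtype on the way out
    p.copy_(p_f32)
    exp_avg.copy_(exp_avg_f32)
    exp_avg_sq.copy_(exp_avg_sq_f32)
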
torchao/prototype/low_bit_optim/subclass_4bit.py

Lines changed: 37 additions & 3 deletions
@@ -2,7 +2,8 @@
 
 import torch
 from torch import Tensor
-from torchao.utils import TorchAOBaseTensor
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4
 
 from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap
 

@@ -60,8 +61,9 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No
     def dequantize(self, output_dtype=None):
         codes = torch.stack([self.codes >> 4, self.codes & 0b1111], dim=-1) # unpack
         float_data = dequant_with_qmap(codes, self.qmap, self.scale)
-        dtype = output_dtype or torch.get_default_dtype()
-        return float_data.view(self._shape).to(dtype)
+        if output_dtype is not None:
+            float_data = float_data.to(output_dtype)
+        return float_data.view(self._shape)
 
     @classmethod
     def zeros(cls, shape, signed: bool = True, block_size: int = 128, device=None):

@@ -80,6 +82,24 @@ def __repr__(self):
         )
 
 
+# in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when
+# dtype is the same but device is different. thus, we must override .to() method instead.
+if not TORCH_VERSION_AT_LEAST_2_4:
+    def _to(self, *args, **kwargs):
+        # ignore other args/kwargs
+        device = kwargs.pop("device", None)
+        return OptimState4bit(
+            self.codes.to(device),
+            self.scale.to(device),
+            self.qmap.to(device),
+            self.signed,
+            self.shape,
+        )
+
+    OptimState4bit.to = _to
+    del _to # make sure to not re-use
+
+
 @OptimState4bit.implements(aten.copy_.default)
 def _(func, types, args, kwargs):
     dst = args[0]

@@ -107,6 +127,20 @@ def _(func, types, args, kwargs):
     return dst
 
 
+@OptimState4bit.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    # ignore dtype
+    device = kwargs.get("device", None)
+    out = OptimState4bit(
+        args[0].codes.to(device=device),
+        args[0].scale.to(device=device),
+        args[0].qmap.to(device=device),
+        args[0].signed,
+        args[0].shape,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, out)
+
+
 @OptimState4bit.implements(aten.lerp.Scalar)
 def _(func, types, args, kwargs):
     args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args]

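A short usage sketch of the device move this file now supports, assuming OptimState4bit is importable from the module path above and a CUDA device is present. On PyTorch >= 2.4 the call dispatches the aten._to_copy.default handler; on older versions the monkey-patched .to() runs instead, but the caller sees the same behavior either way.

# sketch, assuming the import path below and a CUDA device
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit

state = OptimState4bit.zeros((256, 256), signed=True, block_size=128, device="cpu")

# moves codes/scale/qmap to CUDA while keeping the quantized representation
state_cuda = state.to(device="cuda")
assert isinstance(state_cuda, OptimState4bit)
assert state_cuda.codes.device.type == "cuda"
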
torchao/prototype/low_bit_optim/subclass_8bit.py

Lines changed: 36 additions & 3 deletions
@@ -1,6 +1,7 @@
 import torch
 from torch import Tensor
-from torchao.utils import TorchAOBaseTensor
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4
 
 from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap
 

@@ -49,8 +50,10 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No
         return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)
 
     def dequantize(self, output_dtype=None):
-        dtype = output_dtype or torch.get_default_dtype()
-        return dequant_with_qmap(self.codes, self.qmap, self.scale).to(dtype)
+        float_data = dequant_with_qmap(self.codes, self.qmap, self.scale)
+        if output_dtype is not None:
+            float_data = float_data.to(output_dtype)
+        return float_data
 
     @classmethod
     def zeros(cls, shape, signed: bool = True, block_size: int = 256, device=None):

@@ -66,6 +69,23 @@ def __repr__(self):
         )
 
 
+# in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when
+# dtype is the same but device is different. thus, we must override .to() method instead.
+if not TORCH_VERSION_AT_LEAST_2_4:
+    def _to(self, *args, **kwargs):
+        # ignore other args/kwargs
+        device = kwargs.pop("device", None)
+        return OptimState8bit(
+            self.codes.to(device),
+            self.scale.to(device),
+            self.qmap.to(device),
+            self.signed,
+        )
+
+    OptimState8bit.to = _to
+    del _to # make sure to not re-use
+
+
 @OptimState8bit.implements(aten.copy_.default)
 def _(func, types, args, kwargs):
     dst = args[0]

@@ -89,6 +109,19 @@ def _(func, types, args, kwargs):
     return dst
 
 
+@OptimState8bit.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    # ignore dtype
+    device = kwargs.get("device", None)
+    out = OptimState8bit(
+        args[0].codes.to(device=device),
+        args[0].scale.to(device=device),
+        args[0].qmap.to(device=device),
+        args[0].signed,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, out)
+
+
 @OptimState8bit.implements(aten.lerp.Scalar)
 def _(func, types, args, kwargs):
     args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]

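The 8-bit handlers mirror the 4-bit ones, with one detail worth spelling out: the device argument is honored while a dtype argument is deliberately dropped (the "# ignore dtype" comment above), so a move never silently dequantizes the state. A small illustrative sketch, assuming the import path below and a CUDA build:

# sketch, assuming the import path below and a CUDA device
import torch
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit

state = OptimState8bit.zeros((1024,), signed=True, block_size=256, device="cpu")

# device is honored; the dtype request is ignored by the handlers above
moved = state.to(device="cuda", dtype=torch.bfloat16)

assert isinstance(moved, OptimState8bit)
assert moved.codes.dtype == state.codes.dtype  # still quantized codes
assert moved.codes.device.type == "cuda"       # but on the new device
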
torchao/prototype/low_bit_optim/subclass_fp8.py

Lines changed: 16 additions & 2 deletions
@@ -1,5 +1,6 @@
 import torch
 from torch import Tensor
+from torch.utils._python_dispatch import return_and_correct_aliasing
 from torchao.utils import TorchAOBaseTensor
 
 

@@ -21,6 +22,7 @@ def quantize_fp8(input: Tensor, block_size: int):
 
 # NOTE: FP8 sign bit is redundant for unsigned optim state.
 # we may investigate how to use it to increase range/precision for unsigned optim state.
+# https://arxiv.org/abs/2409.12517 uses FP8 E5M2 for 2nd Adam buffer
 class OptimStateFp8(TorchAOBaseTensor):
     tensor_attrs = ["codes", "scale"]
 

@@ -56,8 +58,9 @@ def dequantize(self, output_dtype=None):
         float_data = self.codes.float()
         float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1)
 
-        dtype = output_dtype or torch.get_default_dtype()
-        return float_data.view(self.codes.shape).to(dtype)
+        if output_dtype is not None:
+            float_data = float_data.to(output_dtype)
+        return float_data.view(self.codes.shape)
 
     @classmethod
     def zeros(cls, shape, block_size: int = 256, device=None):

@@ -93,6 +96,17 @@ def _(func, types, args, kwargs):
     return dst
 
 
+@OptimStateFp8.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    # ignore dtype
+    device = kwargs.get("device", None)
+    out = OptimStateFp8(
+        args[0].codes.to(device=device),
+        args[0].scale.to(device=device),
+    )
+    return return_and_correct_aliasing(func, args, kwargs, out)
+
+
 @OptimStateFp8.implements(aten.lerp.Scalar)
 def _(func, types, args, kwargs):
     args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]

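For reference, the dequantize() change above decodes a block-wise FP8 layout: one FP32 scale per block_size chunk of FP8 codes. The toy round-trip below illustrates that layout on plain tensors; it uses ad-hoc absmax scaling and is not the library's quantize_fp8 helper. It assumes torch.float8_e4m3fn is available (PyTorch >= 2.1).

# toy block-wise FP8 round-trip on plain tensors; not library code
import torch

block_size = 4
x = torch.randn(2 * block_size) * 10

# quantize: one absmax scale per block, codes stored in FP8
blocks = x.view(-1, block_size)
scale = blocks.abs().amax(dim=1) / torch.finfo(torch.float8_e4m3fn).max
codes = (blocks / scale.view(-1, 1)).to(torch.float8_e4m3fn)

# dequantize: same shape of computation as OptimStateFp8.dequantize()
float_data = codes.float().view(-1, block_size) * scale.view(-1, 1)
assert torch.allclose(float_data.view(x.shape), x, rtol=0.1, atol=0.1)
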