Commit 5e59b51

[float8nocompile] add e2e fsdp test (#1523)
1 parent f90b29e commit 5e59b51

File tree

2 files changed: +97 −3 lines

torchao/prototype/float8nocompile/.gitignore (−3)

This file was deleted.

torchao/prototype/float8nocompile/test/fsdp_test.py (+97)

@@ -0,0 +1,97 @@
######################################################################
#
# To run these unit tests, use the following command:
#
# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test/fsdp_test.py
#
#######################################################################
import os

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")


class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2048, 4096, bias=False),
            nn.Linear(4096, 16, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


@pytest.fixture
def model1():
    torch.manual_seed(0)
    return TestModel()


@pytest.fixture
def model2():
    torch.manual_seed(0)
    return TestModel()


def test_model_weights_and_gradients(model1, model2):
    assert torch.cuda.is_available()
    device = torch.device("cuda")

    setup_distributed()

    model1 = model1.to(torch.bfloat16).to(device)
    model2 = model2.to(torch.bfloat16).to(device)

    # compare production float8 linear conversion with no-compile version
    convert_to_float8_training(model2)
    convert_to_float8_nocompile_training(model1)

    # distributed training with FSDP2
    fully_shard(model1)
    fully_shard(model2)

    input_tensor = torch.randn(
        16, 2048, requires_grad=True, dtype=torch.bfloat16, device=device
    )
    input_copy1 = input_tensor.clone().detach().requires_grad_(True)
    input_copy2 = input_tensor.clone().detach().requires_grad_(True)

    loss_fn = nn.MSELoss()

    output1 = model1(input_copy1)
    output2 = model2(input_copy2)

    loss1 = loss_fn(output1, torch.zeros_like(output1))
    loss2 = loss_fn(output2, torch.zeros_like(output2))

    loss1.backward()
    loss2.backward()

    # compare the outputs, weight gradients, and input gradients
    assert torch.allclose(output1, output2, atol=0, rtol=0)
    assert torch.allclose(input_copy1.grad, input_copy2.grad, atol=0, rtol=0)
    for param1, param2 in zip(model1.parameters(), model2.parameters()):
        assert torch.equal(param1.grad, param2.grad)

    dist.destroy_process_group()
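As the header comment notes, this test must be launched with torchrun rather than plain pytest, so that RANK and WORLD_SIZE are set in the environment before setup_distributed() initializes the NCCL process group. For example, on a 2-GPU host (NUM_GPUS=2 is only an illustrative value), and presumably run from the float8nocompile directory given the relative path in the comment:

torchrun --nproc_per_node=2 -m pytest test/fsdp_test.py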
