Commit 003c7d3

Add CPU allocation test for multiple GPU distributed run (#15829)
### Add CPU allocation test for non-CPU devices distributed run

When the CUDA EP is enabled in distributed training, CPU memory is still used for some node outputs. We already had distributed-run test coverage, but it did not cover the case where some nodes store their output tensors on the CPU device. As a result, we hit this regression twice in the past months:
- #14050
- #15823

This test is added to avoid future regressions. The test graph looks like this:

![image](https://user-images.githubusercontent.com/10530022/236594940-70c68a55-18bf-4e09-bbf5-8a64895d3045.png)
1 parent 817d70a commit 003c7d3
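
For reference, the trigger for CPU-resident node outputs is ordinary shape-dependent indexing inside `forward`. The sketch below is illustrative only and not part of this commit (the class name `ShapeDependentSlice` is made up); it condenses what `CustomEmbeddings` in the diff does: calling `.size()` and slicing a registered buffer, which the ONNX exporter lowers to Shape/Gather/Slice nodes whose small int64 outputs typically stay in CPU memory even when the CUDA EP runs the rest of the graph.

```python
import torch
import torch.nn as nn


class ShapeDependentSlice(nn.Module):
    """Hypothetical module: shape-dependent slicing that yields CPU-resident node outputs when exported."""

    def __init__(self, max_len: int = 16, dim: int = 10):
        super().__init__()
        # Fixed buffer that gets sliced by a length computed at runtime from the input shape.
        self.register_buffer("position_ids", torch.arange(max_len).expand((1, -1)), persistent=False)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_length, dim)
        seq_length = x.size(1)  # exported as Shape -> Gather; the value is a small int64 tensor
        ids = self.position_ids[:, :seq_length]  # Slice fed by that runtime-computed length
        return self.proj(x) + ids.unsqueeze(-1).to(x.dtype)
```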

File tree

1 file changed: +99 -0 lines changed

orttraining/orttraining/test/python/orttraining_test_ortmodule_pytorch_ddp.py

@@ -146,6 +146,104 @@ def demo_checkpoint(rank, world_size, use_ort_module):
     cleanup()
 
 
+"""
+CustomEmbedding is adapted from
+https://github.com/huggingface/transformers/blob/312b104ff65514736c0475814fec19e47425b0b5/src/transformers/models/distilbert/modeling_distilbert.py#L91.
+"""
+
+
+class CustomEmbeddings(nn.Module):
+    def __init__(self):
+        super().__init__()
+        vocab_size = 511
+        dim = 10
+        pad_token_id = 0
+        max_position_embeddings = 16
+        self.word_embeddings = nn.Embedding(vocab_size, dim, padding_idx=pad_token_id)
+        self.position_embeddings = nn.Embedding(max_position_embeddings, dim)
+        self.LayerNorm = nn.LayerNorm(dim, eps=1e-12)
+        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False)
+
+    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+        input_embeds = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
+        seq_length = input_embeds.size(1)
+        position_ids = self.position_ids[:, :seq_length]
+        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
+
+        embeddings = input_embeds + position_embeddings  # (bs, max_seq_length, dim)
+        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
+        return embeddings
+
+
+"""
+This module calls `CustomEmbeddings`, which generates a series of nodes (Shape->Gather->Unsqueeze...) whose
+output tensors are allocated in CPU memory. This lets the test catch failures in CPU device allocation and
+usage when the CUDA EP is enabled.
+"""
+
+
+class AnotherLayerOfToyModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.embedding = CustomEmbeddings()
+        self.t = ToyModel()
+
+    def forward(self, x):
+        embed_val = self.embedding(x)
+        return self.t(embed_val)
+
+
+"""
+`Mixed device allocation` here means the ORT backend allocates output tensors on CPU for some nodes and
+on CUDA for other nodes. This test helps catch regressions when the ORT allocation planner logic is changed
+and bugs are introduced.
+"""
+
+
+def demo_mixed_device_allocation_training(rank, world_size, use_ort_module):
+    torch.manual_seed(0)
+    print(f"Running basic DDP example on rank {rank}.")
+    setup(rank, world_size)
+    device = "cuda:" + str(rank)
+
+    # create a model and move it to GPU with id rank
+    model = AnotherLayerOfToyModel().to(device)
+    if use_ort_module:
+        model = ORTModule(model)
+        print(f" Rank {rank} uses ORTModule.")
+    else:
+        print(f" Rank {rank} uses Pytorch's nn.Module.")
+
+    ddp_model = DDP(model, device_ids=[device])
+
+    loss_fn = nn.MSELoss()
+    optimizer = optim.Adagrad(ddp_model.parameters(), lr=0.01)
+
+    batch = 2
+    max_seq_length = 16
+    x = torch.randint(1, 511, (batch, max_seq_length)).to(device)
+    y = torch.randn(batch, max_seq_length, 5).to(device)
+
+    loss_history = []
+
+    for i in range(5):
+        optimizer.zero_grad()
+        p = ddp_model(x)
+        loss = loss_fn(p, y)
+        with torch.no_grad():
+            print(f" Rank {rank} at iteration {i} has loss {loss}.")
+        loss.backward()
+        optimizer.step()
+        with torch.no_grad():
+            loss_history.append(torch.unsqueeze(loss, 0))
+
+    loss_history = torch.cat(loss_history).cpu()
+    expected_loss_history = torch.FloatTensor([1.1589857340, 1.0975260735, 1.0628030300, 1.0386666059, 1.0196533203])
+
+    assert torch.allclose(expected_loss_history, loss_history)
+
+    cleanup()
+
+
 def run_demo(demo_fn, world_size, use_ort_module):
     mp.spawn(demo_fn, args=(world_size, use_ort_module), nprocs=world_size, join=True)

@@ -160,3 +258,4 @@ def parse_args():
     args = parse_args()
     run_demo(demo_basic, 4, args.use_ort_module)
     run_demo(demo_checkpoint, 4, args.use_ort_module)
+    run_demo(demo_mixed_device_allocation_training, 4, args.use_ort_module)
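
To exercise only the new scenario locally, something like the following driver should work. This is a sketch, not part of the commit: it assumes a machine with 4 visible CUDA devices (since `run_demo` spawns 4 ranks) and that the script is importable from its directory as `orttraining_test_ortmodule_pytorch_ddp`.

```python
# Hypothetical driver: runs only the new mixed-device-allocation scenario with ORTModule enabled.
from orttraining_test_ortmodule_pytorch_ddp import demo_mixed_device_allocation_training, run_demo

if __name__ == "__main__":
    # 4 ranks -> requires 4 CUDA devices; use_ort_module=True routes the model through ORTModule.
    run_demo(demo_mixed_device_allocation_training, 4, use_ort_module=True)
```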
