Commit c099de4

[Perf] Add a CPU-based training workload (#2116)
1 parent 296d480 commit c099de4

File tree

2 files changed: 220 additions & 0 deletions
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
apiVersion: ray.io/v1
kind: RayJob
metadata:
  name: rayjob-pytorch-mnist
spec:
  shutdownAfterJobFinishes: false
  entrypoint: python ray-operator/config/samples/pytorch-mnist/ray_train_pytorch_mnist.py
  runtimeEnvYAML: |
    pip:
      - torch
      - torchvision
    working_dir: "https://github.com/ray-project/kuberay/archive/master.zip"
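    # The working_dir archive is downloaded from GitHub at runtime, so the
    # entrypoint script path above resolves inside the kuberay repository
    # archive rather than inside the container image.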

  # rayClusterSpec specifies the RayCluster instance to be created by the RayJob controller.
  rayClusterSpec:
    rayVersion: '2.9.0'
    headGroupSpec:
      rayStartParams: {}
      # Pod template
      template:
        spec:
          containers:
            - name: ray-head
              image: rayproject/ray:2.9.0
              ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265 # Ray dashboard
                  name: dashboard
                - containerPort: 10001
                  name: client
              resources:
                limits:
                  cpu: "2"
                  memory: "4Gi"
                requests:
                  cpu: "2"
                  memory: "4Gi"
    workerGroupSpecs:
      - replicas: 4
        minReplicas: 1
        maxReplicas: 5
        groupName: small-group
        rayStartParams: {}
        # Pod template
        template:
          spec:
            containers:
              - name: ray-worker
                image: rayproject/ray:2.9.0
                resources:
                  limits:
                    cpu: "2"
                    memory: "4Gi"
                  requests:
                    cpu: "2"
                    memory: "4Gi"
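
For context, once this cluster is up the RayJob controller submits the entrypoint through the Ray Jobs API. A minimal sketch of the equivalent manual submission (not part of this commit), assuming the Ray dashboard has been made reachable at 127.0.0.1:8265, for example with kubectl port-forward:

from ray.job_submission import JobSubmissionClient

# Assumed address: the head pod's dashboard port forwarded to localhost.
client = JobSubmissionClient("http://127.0.0.1:8265")

# Same entrypoint and runtime environment as in the manifest above.
job_id = client.submit_job(
    entrypoint="python ray-operator/config/samples/pytorch-mnist/ray_train_pytorch_mnist.py",
    runtime_env={
        "pip": ["torch", "torchvision"],
        "working_dir": "https://github.com/ray-project/kuberay/archive/master.zip",
    },
)
print(f"Submitted job: {job_id}")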
Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
"""
Reference: https://docs.ray.io/en/master/train/examples/pytorch/torch_fashion_mnist_example.html

This script is a modified version of the original PyTorch Fashion MNIST
example. It uses only CPU resources to train the MNIST model. See
`ScalingConfig` for more details.
"""
import os
from typing import Dict

import torch
from filelock import FileLock
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Normalize, ToTensor
from tqdm import tqdm

import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def get_dataloaders(batch_size):
    # Transform to normalize the input images
    transform = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

    with FileLock(os.path.expanduser("~/data.lock")):
        # Download training data from open datasets
        training_data = datasets.FashionMNIST(
            root="~/data",
            train=True,
            download=True,
            transform=transform,
        )

        # Download test data from open datasets
        test_data = datasets.FashionMNIST(
            root="~/data",
            train=False,
            download=True,
            transform=transform,
        )

    # Create data loaders
    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    return train_dataloader, test_dataloader


# Model Definition
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


def train_func_per_worker(config: Dict):
    lr = config["lr"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]

    # Get dataloaders inside the worker training function
    train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size)

    # [1] Prepare Dataloader for distributed training
    # Shard the datasets among workers and move batches to the correct device
    # =======================================================================
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader)

    model = NeuralNetwork()

    # [2] Prepare and wrap your model with DistributedDataParallel
    # Move the model to the correct GPU/CPU device
    # ============================================================
    model = ray.train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Model training loop
    for epoch in range(epochs):
        if ray.train.get_context().get_world_size() > 1:
            # Required for the distributed sampler to shuffle properly across epochs.
            train_dataloader.sampler.set_epoch(epoch)

        model.train()
        for X, y in tqdm(train_dataloader, desc=f"Train Epoch {epoch}"):
            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        test_loss, num_correct, num_total = 0, 0, 0
        with torch.no_grad():
            for X, y in tqdm(test_dataloader, desc=f"Test Epoch {epoch}"):
                pred = model(X)
                loss = loss_fn(pred, y)

                test_loss += loss.item()
                num_total += y.shape[0]
                num_correct += (pred.argmax(1) == y).sum().item()

        test_loss /= len(test_dataloader)
        accuracy = num_correct / num_total

        # [3] Report metrics to Ray Train
        # ===============================
        ray.train.report(metrics={"loss": test_loss, "accuracy": accuracy})


def train_fashion_mnist(num_workers=2, use_gpu=False):
    global_batch_size = 32

    train_config = {
        "lr": 1e-3,
        "epochs": 10,
        "batch_size_per_worker": global_batch_size // num_workers,
    }

    # Configure computation resources
    scaling_config = ScalingConfig(
        num_workers=num_workers,
        use_gpu=use_gpu,
        resources_per_worker={"CPU": 2}
    )
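    # CPU-only sizing: with use_gpu=False, each Ray Train worker reserves
    # 2 CPUs and no GPUs, so num_workers=4 needs 8 CPUs in total, which the
    # four 2-CPU worker pods in the RayJob manifest provide.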

    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        scaling_config=scaling_config,
    )

    # [4] Start distributed training
    # Run `train_func_per_worker` on all workers
    # =============================================
    result = trainer.fit()
    print(f"Training result: {result}")


if __name__ == "__main__":
    train_fashion_mnist(num_workers=4)
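A note on batch sizing, derived from the script above: with num_workers=4, the global batch size of 32 becomes 32 // 4 = 8 samples per worker per step, and prepare_data_loader wraps the loaders with a DistributedSampler, so each worker iterates over its own quarter of Fashion-MNIST's 60,000 training images (15,000 per epoch).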
