Skip to content

Commit 29baf39

Browse files
committed
Added Ray Train & Pytorch Lightning demo
1 parent 3864bcf commit 29baf39

File tree

3 files changed

+277
-0
lines changed

3 files changed

+277
-0
lines changed
+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"In this notebook we are going to run a Ray Train & PyTorch Lightning script using the CodeFlare SDK and Ray Job Submission.\n",
8+
"\n",
9+
"NOTE: For distributed training an external persistent storage option should be set in the `run_config`.\n",
10+
"You can find examples in the `pytorch_lightning.py` script."
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"metadata": {},
17+
"outputs": [],
18+
"source": [
19+
"# Import pieces from codeflare-sdk\n",
20+
"from codeflare_sdk import Cluster, ClusterConfiguration, TokenAuthentication"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"metadata": {},
27+
"outputs": [],
28+
"source": [
29+
"# Create authentication object for user permissions\n",
30+
"# IF unused, SDK will automatically check for default kubeconfig, then in-cluster config\n",
31+
"# KubeConfigFileAuthentication can also be used to specify kubeconfig path manually\n",
32+
"auth = TokenAuthentication(\n",
33+
" token = \"XXXXX\",\n",
34+
" server = \"XXXXX\",\n",
35+
" skip_tls=False\n",
36+
")\n",
37+
"auth.login()"
38+
]
39+
},
40+
{
41+
"cell_type": "markdown",
42+
"metadata": {},
43+
"source": [
44+
"Once again, let's start by running through the same cluster setup as before:\n",
45+
"\n",
46+
"NOTE: We must specify the `image` which will be used in our RayCluster. We recommend that you bring your own image which suits your purposes. \n",
47+
"The example here is a community image."
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"# Create and configure our cluster object\n",
57+
"# The SDK will try to find the name of your default local queue based on the annotation \"kueue.x-k8s.io/default-queue\": \"true\" unless you specify the local queue manually below\n",
58+
"cluster = Cluster(ClusterConfiguration(\n",
59+
" name='raytest',\n",
60+
" namespace='default', # Update to your namespace\n",
61+
" num_workers=2,\n",
62+
" min_cpus=2,\n",
63+
" max_cpus=2,\n",
64+
" min_memory=8,\n",
65+
" max_memory=8,\n",
66+
" num_gpus=1,\n",
67+
" head_gpus=1,\n",
68+
" image=\"quay.io/project-codeflare/ray:2.20.0-py39-cu118\",\n",
69+
" write_to_file=True, # When enabled Ray Cluster yaml files are written to /HOME/.codeflare/resources \n",
70+
" # local_queue=\"local-queue-name\" # Specify the local queue manually\n",
71+
"))"
72+
]
73+
},
74+
{
75+
"cell_type": "code",
76+
"execution_count": null,
77+
"metadata": {},
78+
"outputs": [],
79+
"source": [
80+
"# Bring up the cluster\n",
81+
"cluster.up()\n",
82+
"cluster.wait_ready()"
83+
]
84+
},
85+
{
86+
"cell_type": "code",
87+
"execution_count": null,
88+
"metadata": {},
89+
"outputs": [],
90+
"source": [
91+
"cluster.details()"
92+
]
93+
},
94+
{
95+
"cell_type": "markdown",
96+
"metadata": {},
97+
"source": [
98+
"Note: For this example, external S3-compatible storage is required. Please refer to our [documentation](https://github.com/project-codeflare/codeflare-sdk/blob/main/docs/s3-compatible-storage.md) for steps on how to configure this training script."
99+
]
100+
},
101+
{
102+
"cell_type": "code",
103+
"execution_count": null,
104+
"metadata": {},
105+
"outputs": [],
106+
"source": [
107+
"# Initialize the Job Submission Client\n",
108+
"\"\"\"\n",
109+
"The SDK will automatically gather the dashboard address and authenticate using the Ray Job Submission Client\n",
110+
"\"\"\"\n",
111+
"client = cluster.job_client"
112+
]
113+
},
114+
{
115+
"cell_type": "code",
116+
"execution_count": null,
117+
"metadata": {},
118+
"outputs": [],
119+
"source": [
120+
"# Submit an example mnist job using the Job Submission Client\n",
121+
"submission_id = client.submit_job(\n",
122+
" entrypoint=\"python pytorch_lightning.py\",\n",
123+
" runtime_env={\"working_dir\": \"./\",\"pip\": \"requirements_lightning.txt\"},\n",
124+
")\n",
125+
"print(submission_id)"
126+
]
127+
},
128+
{
129+
"cell_type": "code",
130+
"execution_count": null,
131+
"metadata": {},
132+
"outputs": [],
133+
"source": [
134+
"# Get the job's logs\n",
135+
"client.get_job_logs(submission_id)"
136+
]
137+
},
138+
{
139+
"cell_type": "code",
140+
"execution_count": null,
141+
"metadata": {},
142+
"outputs": [],
143+
"source": [
144+
"# Get the job's status\n",
145+
"client.get_job_status(submission_id)"
146+
]
147+
},
148+
{
149+
"cell_type": "code",
150+
"execution_count": null,
151+
"metadata": {},
152+
"outputs": [],
153+
"source": [
154+
"cluster.down()"
155+
]
156+
},
157+
{
158+
"cell_type": "code",
159+
"execution_count": null,
160+
"metadata": {},
161+
"outputs": [],
162+
"source": [
163+
"auth.logout()"
164+
]
165+
}
166+
],
167+
"metadata": {
168+
"language_info": {
169+
"name": "python"
170+
}
171+
},
172+
"nbformat": 4,
173+
"nbformat_minor": 2
174+
}

Diff for: demo-notebooks/guided-demos/pytorch_lightning.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import os
2+
import tempfile
3+
4+
import torch
5+
from torch.utils.data import DataLoader, DistributedSampler
6+
from torchvision.models import resnet18
7+
from torchvision.datasets import FashionMNIST
8+
from torchvision.transforms import ToTensor, Normalize, Compose
9+
import lightning.pytorch as pl
10+
11+
import ray.train.lightning
12+
from ray.train.torch import TorchTrainer
13+
14+
# Based on https://docs.ray.io/en/latest/train/getting-started-pytorch-lightning.html
15+
16+
"""
17+
Note: This example requires an S3 compatible storage bucket for distributed training. Please visit our documentation for more information -> https://github.com/project-codeflare/codeflare-sdk/blob/main/docs/s3-compatible-storage.md
18+
"""
19+
20+
21+
# Model, Loss, Optimizer
22+
class ImageClassifier(pl.LightningModule):
    """Lightning module wrapping a 10-class ``resnet18``.

    The network's first convolution is replaced with one that takes a single
    input channel. Training uses cross-entropy loss and an Adam optimizer.
    """

    def __init__(self):
        super().__init__()
        # Build the backbone first, then swap its stem convolution for a
        # 1-input-channel version before attaching it to the module.
        backbone = resnet18(num_classes=10)
        backbone.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.model = backbone
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one ``(inputs, targets)`` batch and log it as ``"loss"``."""
        inputs, targets = batch
        predictions = self.forward(inputs)
        batch_loss = self.criterion(predictions, targets)
        self.log("loss", batch_loss, on_step=True, prog_bar=True)
        return batch_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over the network's parameters (lr=0.001)."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        return optimizer
43+
44+
45+
def train_func():
    """Training loop that Ray Train runs on each worker.

    Downloads the FashionMNIST training split into a ``data`` folder under the
    system temp directory, builds a ``DistributedSampler`` from the Ray Train
    context's world size and rank, and fits an ``ImageClassifier`` for 10
    epochs with a Lightning trainer wired to Ray's Lightning integrations.
    """
    # Data
    preprocessing = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    dataset_root = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(
        root=dataset_root, train=True, download=True, transform=preprocessing
    )

    # Training
    model = ImageClassifier()

    # Look up the Ray Train context once; it supplies this worker's rank and
    # the total number of workers for the sampler.
    train_context = ray.train.get_context()
    sampler = DistributedSampler(
        train_data,
        num_replicas=train_context.get_world_size(),
        rank=train_context.get_world_rank(),
    )

    # Ordering is delegated to the sampler, so the loader itself does not shuffle.
    train_dataloader = DataLoader(
        train_data, batch_size=128, shuffle=False, sampler=sampler
    )

    # [1] Configure PyTorch Lightning Trainer.
    ddp_strategy = ray.train.lightning.RayDDPStrategy()
    ray_environment = ray.train.lightning.RayLightningEnvironment()
    report_callback = ray.train.lightning.RayTrainReportCallback()
    lightning_trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="auto",
        strategy=ddp_strategy,
        plugins=[ray_environment],
        callbacks=[report_callback],
        # [1a] Optionally, disable the default checkpointing behavior
        # in favor of the `RayTrainReportCallback` above.
        enable_checkpointing=False,
    )
    lightning_trainer = ray.train.lightning.prepare_trainer(lightning_trainer)
    lightning_trainer.fit(model, train_dataloaders=train_dataloader)
79+
80+
81+
# [2] Configure scaling and resource requirements. Set the number of workers to the total number of GPUs on your Ray Cluster.
scaling_config = ray.train.ScalingConfig(num_workers=3, use_gpu=True)

# [2a] Optionally configure persistent storage for checkpoints and results.
# Distributed training needs external storage (e.g. an S3-compatible bucket,
# see the documentation link near the top of this file). To enable it, set the
# TRAIN_STORAGE_PATH environment variable for the job, for example through the
# job's `runtime_env={"env_vars": {...}}`, to a URI such as "s3://bucket/prefix".
# When the variable is unset or empty, `run_config` stays None and Ray Train
# falls back to its defaults, so the script behaves exactly as it did before.
storage_path = os.environ.get("TRAIN_STORAGE_PATH")
run_config = ray.train.RunConfig(storage_path=storage_path) if storage_path else None

# [3] Launch distributed training job.
trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    run_config=run_config,
)
result: ray.train.Result = trainer.fit()

# [4] Load the trained model.
# `result.checkpoint` is optional; fail with an explicit message rather than an
# AttributeError on None when no checkpoint was reported.
if result.checkpoint is None:
    raise RuntimeError(
        "Training finished without reporting a checkpoint; there is no model to load."
    )

with result.checkpoint.as_directory() as checkpoint_dir:
    model = ImageClassifier.load_from_checkpoint(
        os.path.join(
            checkpoint_dir,
            ray.train.lightning.RayTrainReportCallback.CHECKPOINT_NAME,
        ),
    )
+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
torch==2.3.0
2+
torchvision==0.18.0
3+
lightning==2.2.5
4+
ray[train]==2.20.0
5+
s3fs==2024.6.0

0 commit comments

Comments
 (0)