Skip to content

Commit 3defc67

Browse files
committed
Added Ray Train & Pytorch Lightning demo
1 parent eb643ca commit 3defc67

File tree

3 files changed

+288
-0
lines changed

3 files changed

+288
-0
lines changed
+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"In this notebook we are going to run a Ray Train & Pytorch Lightning script using the CodeFlare SDK and Ray Job Submission.\n",
8+
"\n",
9+
"NOTE: For distributed training an external persistent storage option should be set in the `run_config`.\n",
10+
"You can find examples in the `pytorch_lightning.py` script."
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"metadata": {},
17+
"outputs": [],
18+
"source": [
19+
"# Import pieces from codeflare-sdk\n",
20+
"from codeflare_sdk import Cluster, ClusterConfiguration, TokenAuthentication"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"metadata": {},
27+
"outputs": [],
28+
"source": [
29+
"# Create authentication object for user permissions\n",
30+
"# IF unused, SDK will automatically check for default kubeconfig, then in-cluster config\n",
31+
"# KubeConfigFileAuthentication can also be used to specify kubeconfig path manually\n",
32+
"auth = TokenAuthentication(\n",
33+
" token = \"XXXXX\",\n",
34+
" server = \"XXXXX\",\n",
35+
" skip_tls=False\n",
36+
")\n",
37+
"auth.login()"
38+
]
39+
},
40+
{
41+
"cell_type": "markdown",
42+
"metadata": {},
43+
"source": [
44+
"Once again, let's start by running through the same cluster setup as before:\n",
45+
"\n",
46+
"NOTE: We must specify the `image` which will be used in our RayCluster, we recommend you bring your own image which suits your purposes. \n",
47+
"The example here is a community image."
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"# Create and configure our cluster object\n",
57+
"# The SDK will try to find the name of your default local queue based on the annotation \"kueue.x-k8s.io/default-queue\": \"true\" unless you specify the local queue manually below\n",
58+
"cluster = Cluster(ClusterConfiguration(\n",
59+
" name='raytest',\n",
60+
" namespace='default', # Update to your namespace\n",
61+
" num_workers=2,\n",
62+
" min_cpus=2,\n",
63+
" max_cpus=2,\n",
64+
" min_memory=8,\n",
65+
" max_memory=8,\n",
66+
" num_gpus=1,\n",
67+
" head_gpus=1,\n",
68+
" image=\"quay.io/project-codeflare/ray:latest-py39-cu118\",\n",
69+
" write_to_file=True, # When enabled Ray Cluster yaml files are written to /HOME/.codeflare/resources \n",
70+
" # local_queue=\"local-queue-name\" # Specify the local queue manually\n",
71+
"))"
72+
]
73+
},
74+
{
75+
"cell_type": "code",
76+
"execution_count": null,
77+
"metadata": {},
78+
"outputs": [],
79+
"source": [
80+
"# Bring up the cluster\n",
81+
"cluster.up()\n",
82+
"cluster.wait_ready()"
83+
]
84+
},
85+
{
86+
"cell_type": "code",
87+
"execution_count": null,
88+
"metadata": {},
89+
"outputs": [],
90+
"source": [
91+
"cluster.details()"
92+
]
93+
},
94+
{
95+
"cell_type": "code",
96+
"execution_count": null,
97+
"metadata": {},
98+
"outputs": [],
99+
"source": [
100+
"# Initialize the Job Submission Client\n",
101+
"\"\"\"\n",
102+
"The SDK will automatically gather the dashboard address and authenticate using the Ray Job Submission Client\n",
103+
"\"\"\"\n",
104+
"client = cluster.job_client"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": null,
110+
"metadata": {},
111+
"outputs": [],
112+
"source": [
113+
"# Submit an example mnist job using the Job Submission Client\n",
114+
"submission_id = client.submit_job(\n",
115+
" entrypoint=\"python pytorch_lightning.py\",\n",
116+
" runtime_env={\"working_dir\": \"./\",\"pip\": \"requirements_lightning.txt\"},\n",
117+
")\n",
118+
"print(submission_id)"
119+
]
120+
},
121+
{
122+
"cell_type": "code",
123+
"execution_count": null,
124+
"metadata": {},
125+
"outputs": [],
126+
"source": [
127+
"# Get the job's logs\n",
128+
"client.get_job_logs(submission_id)"
129+
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": null,
134+
"metadata": {},
135+
"outputs": [],
136+
"source": [
137+
"# Get the job's status\n",
138+
"client.get_job_status(submission_id)"
139+
]
140+
},
141+
{
142+
"cell_type": "code",
143+
"execution_count": null,
144+
"metadata": {},
145+
"outputs": [],
146+
"source": [
147+
"cluster.down()"
148+
]
149+
},
150+
{
151+
"cell_type": "code",
152+
"execution_count": null,
153+
"metadata": {},
154+
"outputs": [],
155+
"source": [
156+
"auth.logout()"
157+
]
158+
}
159+
],
160+
"metadata": {
161+
"language_info": {
162+
"name": "python"
163+
}
164+
},
165+
"nbformat": 4,
166+
"nbformat_minor": 2
167+
}

Diff for: demo-notebooks/guided-demos/pytorch_lightning.py

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import os
2+
import tempfile
3+
4+
import torch
5+
from torch.utils.data import DataLoader
6+
from torchvision.models import resnet18
7+
from torchvision.datasets import FashionMNIST
8+
from torchvision.transforms import ToTensor, Normalize, Compose
9+
import lightning.pytorch as pl
10+
11+
import ray.train.lightning
12+
from ray.train.torch import TorchTrainer
13+
14+
# Based on https://docs.ray.io/en/latest/train/getting-started-pytorch-lightning.html
15+
16+
"""
17+
# For S3 persistent storage replace the following environment variables with your AWS credentials then uncomment the S3 run_config
18+
# See here for information on how to set up an S3 bucket https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html
19+
20+
os.environ["AWS_ACCESS_KEY_ID"] = "XXXXXXXX"
21+
os.environ["AWS_SECRET_ACCESS_KEY"] = "XXXXXXXX"
22+
os.environ["AWS_DEFAULT_REGION"] = "XXXXXXXX"
23+
"""
24+
25+
"""
26+
# For Minio persistent storage uncomment the following code and fill in the name, password and API URL then uncomment the minio run_config.
27+
# See here for information on how to set up a minio bucket https://ai-on-openshift.io/tools-and-applications/minio/minio/
28+
29+
def get_minio_run_config():
30+
import s3fs
31+
import pyarrow.fs
32+
33+
s3_fs = s3fs.S3FileSystem(
34+
key = os.getenv('MINIO_ACCESS_KEY', "XXXXX"),
35+
secret = os.getenv('MINIO_SECRET_ACCESS_KEY', "XXXXX"),
36+
endpoint_url = os.getenv('MINIO_URL', "XXXXX")
37+
)
38+
39+
custom_fs = pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(s3_fs))
40+
41+
run_config = ray.train.RunConfig(storage_path='training', storage_filesystem=custom_fs)
42+
return run_config
43+
"""
44+
45+
46+
# Model, Loss, Optimizer
47+
class ImageClassifier(pl.LightningModule):
48+
def __init__(self):
49+
super(ImageClassifier, self).__init__()
50+
self.model = resnet18(num_classes=10)
51+
self.model.conv1 = torch.nn.Conv2d(
52+
1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
53+
)
54+
self.criterion = torch.nn.CrossEntropyLoss()
55+
56+
def forward(self, x):
57+
return self.model(x)
58+
59+
def training_step(self, batch, batch_idx):
60+
x, y = batch
61+
outputs = self.forward(x)
62+
loss = self.criterion(outputs, y)
63+
self.log("loss", loss, on_step=True, prog_bar=True)
64+
return loss
65+
66+
def configure_optimizers(self):
67+
return torch.optim.Adam(self.model.parameters(), lr=0.001)
68+
69+
70+
def train_func():
71+
# Data
72+
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
73+
data_dir = os.path.join(tempfile.gettempdir(), "data")
74+
train_data = FashionMNIST(
75+
root=data_dir, train=True, download=True, transform=transform
76+
)
77+
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)
78+
79+
# Training
80+
model = ImageClassifier()
81+
# [1] Configure PyTorch Lightning Trainer.
82+
trainer = pl.Trainer(
83+
max_epochs=10,
84+
devices="auto",
85+
accelerator="auto",
86+
strategy=ray.train.lightning.RayDDPStrategy(),
87+
plugins=[ray.train.lightning.RayLightningEnvironment()],
88+
callbacks=[ray.train.lightning.RayTrainReportCallback()],
89+
# [1a] Optionally, disable the default checkpointing behavior
90+
# in favor of the `RayTrainReportCallback` above.
91+
enable_checkpointing=False,
92+
)
93+
trainer = ray.train.lightning.prepare_trainer(trainer)
94+
trainer.fit(model, train_dataloaders=train_dataloader)
95+
96+
97+
# [2] Configure scaling and resource requirements. Set the number of workers to the total number of GPUs on your Ray Cluster.
98+
scaling_config = ray.train.ScalingConfig(num_workers=3, use_gpu=True)
99+
100+
# [3] Launch distributed training job.
101+
trainer = TorchTrainer(
102+
train_func,
103+
scaling_config=scaling_config,
104+
# run_config = ray.train.RunConfig(storage_path="s3://BUCKET_NAME/SUB_PATH/", name="unique_run_name") # Uncomment and update the S3 URI for S3 persistent storage.
105+
# run_config=get_minio_run_config(), # Uncomment for minio persistent storage.
106+
)
107+
result: ray.train.Result = trainer.fit()
108+
109+
# [4] Load the trained model.
110+
with result.checkpoint.as_directory() as checkpoint_dir:
111+
model = ImageClassifier.load_from_checkpoint(
112+
os.path.join(
113+
checkpoint_dir,
114+
ray.train.lightning.RayTrainReportCallback.CHECKPOINT_NAME,
115+
),
116+
)
+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
torch==2.3.0
2+
torchvision==0.18.0
3+
lightning==2.2.5
4+
ray[train]==2.20.0
5+
s3fs==2024.6.0

0 commit comments

Comments
 (0)