-
Notifications
You must be signed in to change notification settings - Fork 470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add RayJob training example using pytorch resnet image classifier #2107
Merged
kevin85421
merged 8 commits into
ray-project:master
from
andrewsykim:pytorch-lightning-image-classifier
May 16, 2024
Merged
Changes from 7 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
712ae25
add sample YAML for pytorch distributed training with GCSFuse for che…
andrewsykim 06eb03b
add resume from checkpoint
andrewsykim 9377db6
add pytorch resnet image classifier example
andrewsykim a948b8f
use TorchTrainer.can_restore
andrewsykim 03cf52d
update placeholders
andrewsykim 7b9759d
address comments from Kai-Hsun
andrewsykim d8799fc
fix missing minReplicas and maxRelicas
andrewsykim b13a91b
revert changes to text classifier example
andrewsykim File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
225 changes: 225 additions & 0 deletions
225
...nfig/samples/pytorch-resnet-image-classifier/fine-tune-pytorch-resnet-image-classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
import os | ||
import warnings | ||
from tempfile import TemporaryDirectory | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from torch.utils.data import DataLoader | ||
from torchvision import datasets, models, transforms | ||
import numpy as np | ||
|
||
import ray.train as train | ||
from ray.train.torch import TorchTrainer | ||
from ray.train import ScalingConfig, RunConfig, CheckpointConfig, Checkpoint | ||
|
||
# Data augmentation and normalization for training | ||
# Just normalization for validation | ||
data_transforms = { | ||
"train": transforms.Compose( | ||
[ | ||
transforms.RandomResizedCrop(224), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | ||
] | ||
), | ||
"val": transforms.Compose( | ||
[ | ||
transforms.Resize(224), | ||
transforms.CenterCrop(224), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | ||
] | ||
), | ||
} | ||
|
||
def download_datasets(): | ||
os.system( | ||
"wget https://download.pytorch.org/tutorial/hymenoptera_data.zip >/dev/null 2>&1" | ||
) | ||
os.system("unzip hymenoptera_data.zip >/dev/null 2>&1") | ||
|
||
# Download and build torch datasets | ||
def build_datasets(): | ||
torch_datasets = {} | ||
for split in ["train", "val"]: | ||
torch_datasets[split] = datasets.ImageFolder( | ||
os.path.join("./hymenoptera_data", split), data_transforms[split] | ||
) | ||
return torch_datasets | ||
|
||
def initialize_model(): | ||
# Load pretrained model params | ||
model = models.resnet50(pretrained=True) | ||
|
||
# Replace the original classifier with a new Linear layer | ||
num_features = model.fc.in_features | ||
model.fc = nn.Linear(num_features, 2) | ||
|
||
# Ensure all params get updated during finetuning | ||
for param in model.parameters(): | ||
param.requires_grad = True | ||
return model | ||
|
||
|
||
def evaluate(logits, labels): | ||
_, preds = torch.max(logits, 1) | ||
corrects = torch.sum(preds == labels).item() | ||
return corrects | ||
|
||
train_loop_config = { | ||
"input_size": 224, # Input image size (224 x 224) | ||
"batch_size": 32, # Batch size for training | ||
"num_epochs": 10, # Number of epochs to train for | ||
"lr": 0.001, # Learning Rate | ||
"momentum": 0.9, # SGD optimizer momentum | ||
} | ||
|
||
def train_loop_per_worker(configs): | ||
warnings.filterwarnings("ignore") | ||
|
||
# Calculate the batch size for a single worker | ||
worker_batch_size = configs["batch_size"] | ||
|
||
# Download dataset once on local rank 0 worker | ||
if train.get_context().get_local_rank() == 0: | ||
download_datasets() | ||
torch.distributed.barrier() | ||
|
||
# Build datasets on each worker | ||
torch_datasets = build_datasets() | ||
|
||
# Prepare dataloader for each worker | ||
dataloaders = dict() | ||
dataloaders["train"] = DataLoader( | ||
torch_datasets["train"], batch_size=worker_batch_size, shuffle=True | ||
) | ||
dataloaders["val"] = DataLoader( | ||
torch_datasets["val"], batch_size=worker_batch_size, shuffle=False | ||
) | ||
|
||
# Distribute | ||
dataloaders["train"] = train.torch.prepare_data_loader(dataloaders["train"]) | ||
dataloaders["val"] = train.torch.prepare_data_loader(dataloaders["val"]) | ||
|
||
device = train.torch.get_device() | ||
|
||
# Prepare DDP Model, optimizer, and loss function. | ||
model = initialize_model() | ||
|
||
# Reload from checkpoint if exists. | ||
start_epoch = 0 | ||
checkpoint = train.get_checkpoint() | ||
if checkpoint: | ||
with checkpoint.as_directory() as checkpoint_dir: | ||
state_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt")) | ||
model.load_state_dict(state_dict['model']) | ||
|
||
start_epoch = state_dict['epoch'] + 1 | ||
|
||
model = train.torch.prepare_model(model) | ||
|
||
optimizer = optim.SGD( | ||
model.parameters(), lr=configs["lr"], momentum=configs["momentum"] | ||
) | ||
criterion = nn.CrossEntropyLoss() | ||
|
||
# Start training loops | ||
for epoch in range(start_epoch, configs["num_epochs"]): | ||
# Each epoch has a training and validation phase | ||
for phase in ["train", "val"]: | ||
if phase == "train": | ||
model.train() # Set model to training mode | ||
else: | ||
model.eval() # Set model to evaluate mode | ||
|
||
running_loss = 0.0 | ||
running_corrects = 0 | ||
|
||
if train.get_context().get_world_size() > 1: | ||
dataloaders[phase].sampler.set_epoch(epoch) | ||
|
||
for inputs, labels in dataloaders[phase]: | ||
inputs = inputs.to(device) | ||
labels = labels.to(device) | ||
|
||
# zero the parameter gradients | ||
optimizer.zero_grad() | ||
|
||
# forward | ||
with torch.set_grad_enabled(phase == "train"): | ||
# Get model outputs and calculate loss | ||
outputs = model(inputs) | ||
loss = criterion(outputs, labels) | ||
|
||
# backward + optimize only if in training phase | ||
if phase == "train": | ||
loss.backward() | ||
optimizer.step() | ||
|
||
# calculate statistics | ||
running_loss += loss.item() * inputs.size(0) | ||
running_corrects += evaluate(outputs, labels) | ||
|
||
size = len(torch_datasets[phase]) | ||
epoch_loss = running_loss / size | ||
epoch_acc = running_corrects / size | ||
|
||
if train.get_context().get_world_rank() == 0: | ||
print( | ||
"Epoch {}-{} Loss: {:.4f} Acc: {:.4f}".format( | ||
epoch, phase, epoch_loss, epoch_acc | ||
) | ||
) | ||
|
||
# Report metrics and checkpoint every epoch | ||
if phase == "val": | ||
with TemporaryDirectory() as tmpdir: | ||
state_dict = { | ||
"epoch": epoch, | ||
"model": model.module.state_dict(), | ||
"optimizer_state_dict": optimizer.state_dict(), | ||
} | ||
|
||
# In standard DDP training, where the model is the same across all ranks, | ||
# only the global rank 0 worker needs to save and report the checkpoint | ||
if train.get_context().get_world_rank() == 0: | ||
torch.save(state_dict, os.path.join(tmpdir, "checkpoint.pt")) | ||
|
||
train.report( | ||
metrics={"loss": epoch_loss, "acc": epoch_acc}, | ||
checkpoint=Checkpoint.from_directory(tmpdir), | ||
) | ||
|
||
if __name__ == "__main__": | ||
num_workers = int(os.environ.get("NUM_WORKERS", "4")) | ||
scaling_config = ScalingConfig( | ||
num_workers=num_workers, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1} | ||
) | ||
|
||
checkpoint_config = CheckpointConfig(num_to_keep=3) | ||
run_config = RunConfig( | ||
name="finetune-resnet", | ||
storage_path="/mnt/cluster_storage", | ||
checkpoint_config=checkpoint_config, | ||
) | ||
|
||
experiment_path = os.path.expanduser("/mnt/cluster_storage/finetune-resnet") | ||
if TorchTrainer.can_restore(experiment_path): | ||
trainer = TorchTrainer.restore(experiment_path, | ||
train_loop_per_worker=train_loop_per_worker, | ||
train_loop_config=train_loop_config, | ||
scaling_config=scaling_config, | ||
run_config=run_config, | ||
) | ||
else: | ||
trainer = TorchTrainer( | ||
train_loop_per_worker=train_loop_per_worker, | ||
train_loop_config=train_loop_config, | ||
scaling_config=scaling_config, | ||
run_config=run_config, | ||
) | ||
|
||
result = trainer.fit() | ||
print(result) |
110 changes: 110 additions & 0 deletions
110
...ator/config/samples/pytorch-resnet-image-classifier/ray-job.pytorch-image-classifier.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# This RayJob is based on the "Finetuning a Pytorch Image Classifier with Ray Train" example in the Ray documentation. | ||
# See https://docs.ray.io/en/latest/train/examples/pytorch/pytorch_resnet_finetune.html for more details. | ||
apiVersion: ray.io/v1 | ||
kind: RayJob | ||
metadata: | ||
generateName: pytorch-image-classifier- | ||
spec: | ||
shutdownAfterJobFinishes: true | ||
entrypoint: python ray-operator/config/samples/pytorch-resnet-image-classifier/fine-tune-pytorch-resnet-image-classifier.py | ||
runtimeEnvYAML: | | ||
pip: | ||
- numpy | ||
- datasets | ||
- torch | ||
- torchvision | ||
- transformers>=4.19.1 | ||
working_dir: "https://github.com/ray-project/kuberay/archive/master.zip" | ||
rayClusterSpec: | ||
rayVersion: '2.9.0' | ||
headGroupSpec: | ||
rayStartParams: | ||
dashboard-host: '0.0.0.0' | ||
template: | ||
metadata: | ||
annotations: | ||
gke-gcsfuse/volumes: "true" | ||
gke-gcsfuse/cpu-limit: "0" | ||
gke-gcsfuse/memory-limit: 5Gi | ||
gke-gcsfuse/ephemeral-storage-limit: 10Gi | ||
spec: | ||
serviceAccountName: pytorch-distributed-training | ||
containers: | ||
- name: ray-head | ||
image: rayproject/ray:2.9.0 | ||
env: | ||
- name: NUM_WORKERS | ||
value: "4" | ||
ports: | ||
- containerPort: 6379 | ||
name: gcs-server | ||
- containerPort: 8265 | ||
name: dashboard | ||
- containerPort: 10001 | ||
name: client | ||
resources: | ||
limits: | ||
cpu: "1" | ||
memory: "8G" | ||
requests: | ||
cpu: "1" | ||
memory: "8G" | ||
volumeMounts: | ||
- mountPath: /tmp/ray | ||
name: ray-logs | ||
- mountPath: /mnt/cluster_storage | ||
name: cluster-storage | ||
volumes: | ||
- name: ray-logs | ||
emptyDir: {} | ||
- name: cluster-storage | ||
csi: | ||
driver: gcsfuse.csi.storage.gke.io | ||
volumeAttributes: | ||
bucketName: GCS_BUCKET | ||
mountOptions: "implicit-dirs,uid=1000,gid=100" | ||
workerGroupSpecs: | ||
- replicas: 4 | ||
minReplicas: 4 | ||
maxReplicas: 4 | ||
groupName: gpu-group | ||
rayStartParams: | ||
dashboard-host: '0.0.0.0' | ||
template: | ||
metadata: | ||
annotations: | ||
gke-gcsfuse/volumes: "true" | ||
gke-gcsfuse/cpu-limit: "0" | ||
gke-gcsfuse/memory-limit: 5Gi | ||
gke-gcsfuse/ephemeral-storage-limit: 10Gi | ||
spec: | ||
serviceAccountName: pytorch-distributed-training | ||
tolerations: | ||
- key: "nvidia.com/gpu" | ||
operator: "Exists" | ||
effect: "NoSchedule" | ||
containers: | ||
- name: ray-worker | ||
image: rayproject/ray-ml:2.9.0-gpu | ||
resources: | ||
limits: | ||
memory: "8G" | ||
nvidia.com/gpu: "1" | ||
requests: | ||
cpu: "1" | ||
memory: "8G" | ||
nvidia.com/gpu: "1" | ||
volumeMounts: | ||
- mountPath: /tmp/ray | ||
name: ray-logs | ||
- mountPath: /mnt/cluster_storage | ||
name: cluster-storage | ||
volumes: | ||
- name: ray-logs | ||
emptyDir: {} | ||
- name: cluster-storage | ||
csi: | ||
driver: gcsfuse.csi.storage.gke.io | ||
volumeAttributes: | ||
bucketName: GCS_BUCKET | ||
mountOptions: "implicit-dirs,uid=1000,gid=100" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import ray | ||
import torch | ||
import os | ||
import numpy as np | ||
import pytorch_lightning as pl | ||
import torch.nn.functional as F | ||
|
@@ -14,7 +15,7 @@ | |
RayTrainReportCallback, | ||
) | ||
from ray.train.torch import TorchTrainer | ||
from ray.train import RunConfig, ScalingConfig, CheckpointConfig, DataConfig | ||
from ray.train import RunConfig, ScalingConfig, CheckpointConfig, DataConfig, Checkpoint | ||
|
||
class SentimentModel(pl.LightningModule): | ||
def __init__(self, lr=2e-5, eps=1e-8): | ||
|
@@ -112,7 +113,6 @@ def train_func(config): | |
) | ||
|
||
trainer = prepare_trainer(trainer) | ||
|
||
trainer.fit(model, train_dataloaders=train_ds_loader, val_dataloaders=val_ds_loader) | ||
|
||
|
||
|
@@ -131,15 +131,16 @@ def train_func(config): | |
# The checkpoints and metrics are reported by `RayTrainReportCallback` | ||
run_config = RunConfig( | ||
name="ptl-sent-classification", | ||
storage_path="/mnt/cluster_storage", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need this change? This change may cause users fail to follow this doc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry, this was leftover from the last change, reverting it |
||
checkpoint_config=CheckpointConfig( | ||
num_to_keep=2, | ||
checkpoint_score_attribute="matthews_correlation", | ||
checkpoint_score_order="max", | ||
), | ||
) | ||
|
||
# Schedule 2 workers for DDP training (1 GPU/worker by default) | ||
scaling_config = ScalingConfig(num_workers=1, use_gpu=True) | ||
num_workers = int(os.environ.get("NUM_WORKERS", "1")) | ||
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=True) | ||
|
||
trainer = TorchTrainer( | ||
train_loop_per_worker=train_func, | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I run
kubectl create -f ...
, I got the following error. We should specifyminReplicas
andmaxReplicas
here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huh, that's weird -- I didn't run into that error, maybe I didn't have validation enabled? I added both
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. Both of them are required.
kuberay/ray-operator/apis/ray/v1/raycluster_types.go
Lines 53 to 58 in 7fb46ab
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those have kubebuilder defaults though, should it error if not specified?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my understanding, the kubebuilder defaults will never be used because there is no
omitempty
in eitherMinReplicas
orMaxReplicas
.