
Commit 8522fbc

Sky benchmark (skypilot-org#832)
1 parent 06e2b2b commit 8522fbc

25 files changed, +2493 -66 lines

examples/benchmark/keras_asr.yaml (new file, +30 lines)

name: ljspeech-asr

resources:
  candidates:
  - {accelerators: T4}
  - {accelerators: V100}

workdir: ./examples/benchmark/keras_asr

setup: |
  conda create -n keras python=3.8 -y
  conda activate keras

  # Install SkyCallback
  git clone git@github.com:sky-proj/sky.git
  pip install sky/sky/callbacks/

  # User setup
  pip install numpy pandas tensorflow
  git clone https://github.com/keras-team/keras-io.git
  cd keras-io
  git checkout 49a16474cc5bbf86792bb7557a70d13fdb7a9c97

  # Apply the patch to enable SkyCallback
  git apply ../callback.patch

run: |
  conda activate keras
  cd keras-io/examples/audio/
  python transformer_asr.py

examples/benchmark/keras_asr/callback.patch (new file, +21 lines)

diff --git a/examples/audio/transformer_asr.py b/examples/audio/transformer_asr.py
index 8cd3e04..5edf885 100644
--- a/examples/audio/transformer_asr.py
+++ b/examples/audio/transformer_asr.py
@@ -35,6 +35,7 @@ from glob import glob
 import tensorflow as tf
 from tensorflow import keras
 from tensorflow.keras import layers
+from sky_callback import SkyKerasCallback


 """
@@ -520,7 +521,7 @@ learning_rate = CustomSchedule(
 optimizer = keras.optimizers.Adam(learning_rate)
 model.compile(optimizer=optimizer, loss=loss_fn)

-history = model.fit(ds, validation_data=val_ds, callbacks=[display_cb], epochs=1)
+history = model.fit(ds, validation_data=val_ds, callbacks=[display_cb, SkyKerasCallback()], epochs=1)

 """
 In practice, you should train for around 100 epochs or more.
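
This patch is the whole integration: the example only has to add the callback to model.fit. The callback class itself ships in sky/sky/callbacks/ and is not shown in this commit; the class below is an illustrative sketch of what a Keras timing/progress callback of this kind could look like, not the actual SkyKerasCallback.

import time

from tensorflow import keras


class TimingKerasCallback(keras.callbacks.Callback):
    """Illustrative only: records per-batch wall-clock time during training."""

    def __init__(self):
        super().__init__()
        self.batch_times = []
        self._batch_start = None

    def on_train_batch_begin(self, batch, logs=None):
        self._batch_start = time.time()

    def on_train_batch_end(self, batch, logs=None):
        self.batch_times.append(time.time() - self._batch_start)

    def on_epoch_end(self, epoch, logs=None):
        if self.batch_times:
            avg = sum(self.batch_times) / len(self.batch_times)
            print(f"epoch {epoch}: avg batch time {avg:.3f}s")

Registering such a callback in callbacks=[...], exactly as the patch does, is the only change the training script needs.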

examples/benchmark/lightning_gan.yaml (new file, +30 lines)

name: mnist-gan

resources:
  candidates:
  - {accelerators: T4}
  - {accelerators: V100}

workdir: ./examples/benchmark/lightning_gan

setup: |
  conda create -n pl python=3.8 -y
  conda activate pl

  # Install SkyCallback
  git clone git@github.com:sky-proj/sky.git
  pip install sky/sky/callbacks/

  # User setup
  pip install "torchvision" "pytorch-lightning>=1.4" "torch>=1.6, <1.9"
  git clone https://github.com/Lightning-AI/tutorials.git
  cd tutorials
  git checkout e22e229921a97ea241277e19e0eaddedc35808cb

  # Apply the patch to enable SkyCallback
  git apply ../callback.patch

run: |
  conda activate pl
  cd tutorials/lightning_examples/basic-gan/
  python gan.py

examples/benchmark/lightning_gan/callback.patch (new file, +21 lines)

diff --git a/lightning_examples/basic-gan/gan.py b/lightning_examples/basic-gan/gan.py
index 24520fa..4a1e988 100644
--- a/lightning_examples/basic-gan/gan.py
+++ b/lightning_examples/basic-gan/gan.py
@@ -11,6 +11,7 @@ from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.callbacks.progress import TQDMProgressBar
 from torch.utils.data import DataLoader, random_split
 from torchvision.datasets import MNIST
+from sky_callback import SkyLightningCallback

 PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
 BATCH_SIZE = 256 if torch.cuda.is_available() else 64
@@ -253,7 +254,7 @@ trainer = Trainer(
     accelerator="auto",
     devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
     max_epochs=5,
-    callbacks=[TQDMProgressBar(refresh_rate=20)],
+    callbacks=[TQDMProgressBar(refresh_rate=20), SkyLightningCallback()],
 )
 trainer.fit(model, dm)
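
As in the Keras example, only the callbacks list passed to Trainer changes. A minimal sketch of a Lightning callback in the same spirit (an illustrative assumption only; the real SkyLightningCallback comes from sky/sky/callbacks/):

import time

from pytorch_lightning.callbacks import Callback


class TimingLightningCallback(Callback):
    """Illustrative only: measures per-batch training time and prints an epoch average."""

    def __init__(self):
        self.batch_times = []
        self._batch_start = None

    def on_train_batch_start(self, trainer, pl_module, *args, **kwargs):
        self._batch_start = time.time()

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        self.batch_times.append(time.time() - self._batch_start)

    def on_train_epoch_end(self, trainer, pl_module, *args, **kwargs):
        if self.batch_times:
            avg = sum(self.batch_times) / len(self.batch_times)
            print(f"epoch {trainer.current_epoch}: avg batch time {avg:.3f}s")

The *args/**kwargs signatures keep the sketch usable across the pytorch-lightning>=1.4 range pinned in the YAML, since the exact hook signatures changed between minor releases.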

examples/benchmark/timm.yaml (new file, +37 lines)

name: resnet50-randaug

resources:
  candidates:
  - {accelerators: T4:4}
  - {accelerators: V100:4}

workdir: ./examples/benchmark/timm

setup: |
  conda create -n timm python=3.8 -y
  conda activate timm

  # Install SkyCallback
  git clone git@github.com:sky-proj/sky.git
  pip install sky/sky/callbacks/

  # User setup
  git clone https://github.com/rwightman/pytorch-image-models.git timm
  cd timm
  git checkout v0.5.4
  pip install -r requirements.txt

  # Apply the patch to enable SkyCallback
  git apply ../callback.patch

  # Apply the patch to use a dummy ImageNet dataset to avoid data downloading
  git apply ../dummy_dataset.patch

run: |
  conda activate timm
  cd timm
  python3 -m torch.distributed.launch --nproc_per_node=4 train.py \
    -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 \
    --amp --remode pixel --reprob 0.6 --aug-splits 3 \
    --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd \
    --dist-bn reduce

examples/benchmark/timm/callback.patch (new file, +35 lines)

diff --git a/train.py b/train.py
index 6e3b058..8c61ed4 100755
--- a/train.py
+++ b/train.py
@@ -58,6 +58,9 @@ try:
 except ImportError:
     has_wandb = False

+import sky_callback
+from sky_callback import step_iterator
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')

@@ -609,6 +612,11 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)

+    sky_callback.init(
+        global_rank=args.rank,
+        total_steps=num_epochs * len(loader_train),
+    )
+
     try:
         for epoch in range(start_epoch, num_epochs):
             if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
@@ -674,7 +682,7 @@ def train_one_epoch(
     end = time.time()
     last_idx = len(loader) - 1
     num_updates = epoch * len(loader)
-    for batch_idx, (input, target) in enumerate(loader):
+    for batch_idx, (input, target) in step_iterator(enumerate(loader)):
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
         if not args.prefetcher:
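
Because timm's training loop is a plain for loop rather than a framework with callback hooks, the patch uses two entry points: sky_callback.init(...) once before training and step_iterator(...) around the batch loop. The sketch below shows one way such a pair could be structured; it is an assumption for illustration only, since the real sky_callback package lives in sky/sky/callbacks/ and is not part of this commit.

import time

# Module-level state shared between init() and step_iterator().
_state = {"global_rank": 0, "total_steps": None, "step": 0, "start": None}


def init(global_rank=0, total_steps=None):
    """Record benchmark metadata before the training loop starts (assumed API)."""
    _state.update(global_rank=global_rank, total_steps=total_steps,
                  step=0, start=time.time())


def step_iterator(iterable):
    """Wrap an iterable of training steps and report progress on rank 0."""
    for item in iterable:
        yield item
        _state["step"] += 1
        if _state["global_rank"] == 0 and _state["step"] % 100 == 0:
            elapsed = time.time() - _state["start"]
            print(f"step {_state['step']}/{_state['total_steps']}: "
                  f"{_state['step'] / elapsed:.1f} steps/s")

Wrapping enumerate(loader) as the patch does means the loop body is untouched; only per-step accounting is added around it.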

examples/benchmark/timm/dummy_dataset.patch (new file, +71 lines)

# A patch file to replace ImageNet with a dummy dataset.
# Use only for benchmarking purposes.

diff --git a/train.py b/train.py
index 6e3b058..8ddbcdd 100755
--- a/train.py
+++ b/train.py
@@ -61,6 +61,34 @@ except ImportError:
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')

+
+class DummyImageDataset(torch.utils.data.Dataset):
+    """Dummy dataset with synthetic images."""
+    _IMAGE_HEIGHT = 3072
+    _IMAGE_WIDTH = 2304
+
+    def __init__(self, num_images, num_classes):
+        import numpy as np
+        from PIL import Image
+        imarray = np.random.rand(self._IMAGE_HEIGHT, self._IMAGE_WIDTH, 3) * 255
+        self.img = Image.fromarray(imarray.astype('uint8')).convert('RGB')
+        self.num_images = num_images
+        self.num_classes = num_classes
+        self.transform = None
+        self.target_transform = None
+
+    def __len__(self):
+        return self.num_images
+
+    def __getitem__(self, idx):
+        if self.transform is not None:
+            img = self.transform(self.img)
+        target = idx % self.num_classes
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return img, target
+
+
 # The first arg parser parses out only the --config argument, this argument is used to
 # load a yaml file containing key-values that override the defaults for the main parser below
 config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
@@ -71,8 +99,6 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

 # Dataset parameters
-parser.add_argument('data_dir', metavar='DIR',
-                    help='path to dataset')
 parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                     help='dataset type (default: ImageFolder/ImageTar if empty)')
 parser.add_argument('--train-split', metavar='NAME', default='train',
@@ -486,17 +512,8 @@ def main():
     _logger.info('Scheduled epochs: {}'.format(num_epochs))

     # create the train and eval datasets
-    dataset_train = create_dataset(
-        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
-        class_map=args.class_map,
-        download=args.dataset_download,
-        batch_size=args.batch_size,
-        repeats=args.epoch_repeats)
-    dataset_eval = create_dataset(
-        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
-        class_map=args.class_map,
-        download=args.dataset_download,
-        batch_size=args.batch_size)
+    dataset_train = DummyImageDataset(num_images=1231167, num_classes=1000)
+    dataset_eval = DummyImageDataset(num_images=50000, num_classes=1000)

     # setup mixup / cutmix
     collate_fn = None
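
Note that DummyImageDataset.__getitem__ relies on the data loader later assigning a transform (timm's create_loader does this); if transform stayed None, img would be unbound. The snippet below is a condensed, self-contained variant of the same idea for a quick shape check outside timm; the class name, image size, and fixed ToTensor transform are illustrative and not part of the commit.

import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class SyntheticImageDataset(Dataset):
    """Serves one synthetic image repeatedly, with labels cycling over the classes."""

    def __init__(self, num_images, num_classes, size=(224, 224)):
        arr = (np.random.rand(size[0], size[1], 3) * 255).astype('uint8')
        self.img = Image.fromarray(arr).convert('RGB')
        self.num_images = num_images
        self.num_classes = num_classes
        # Fixed transform so the dataset works without timm's create_loader.
        self.transform = transforms.ToTensor()

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        return self.transform(self.img), idx % self.num_classes


# One batch is enough to confirm the shapes fed to the model.
loader = DataLoader(SyntheticImageDataset(1024, 1000), batch_size=64)
images, targets = next(iter(loader))
print(images.shape, targets.shape)  # torch.Size([64, 3, 224, 224]) torch.Size([64])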

examples/benchmark/transformers_qa.yaml (new file, +42 lines)

name: squad_v2

resources:
  candidates:
  - {accelerators: T4:8}
  - {accelerators: V100:8}

workdir: ./examples/benchmark/transformers_qa

setup: |
  conda create -n hf python=3.8 -y
  conda activate hf

  # Install SkyCallback
  git clone git@github.com:sky-proj/sky.git
  pip install sky/sky/callbacks/

  # User setup
  pip install transformers
  git clone https://github.com/huggingface/transformers.git
  cd transformers
  git checkout v4.20.0
  pip install -r examples/pytorch/question-answering/requirements.txt

  # Apply the patch to enable SkyCallback
  git apply ../callback.patch

run: |
  conda activate hf
  cd transformers/examples/pytorch/question-answering/
  python run_qa.py \
    --model_name_or_path bert-base-uncased \
    --dataset_name squad_v2 \
    --do_train \
    --do_eval \
    --per_device_train_batch_size 12 \
    --learning_rate 3e-5 \
    --num_train_epochs 2 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --version_2_with_negative \
    --output_dir outputs/

examples/benchmark/transformers_qa/callback.patch (new file, +20 lines)

diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py
index f8f2ad7db..82fd64221 100755
--- a/examples/pytorch/question-answering/run_qa.py
+++ b/examples/pytorch/question-answering/run_qa.py
@@ -26,6 +26,7 @@ from typing import Optional

 import datasets
 from datasets import load_dataset, load_metric
+from sky_callback import SkyTransformersCallback

 import transformers
 from trainer_qa import QuestionAnsweringTrainer
@@ -609,6 +610,7 @@ def main():
         data_collator=data_collator,
         post_process_function=post_processing_function,
         compute_metrics=compute_metrics,
+        callbacks=[SkyTransformersCallback()],
     )

     # Training
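
Here the integration point is the Trainer's callbacks argument. For reference, a minimal TrainerCallback that reports throughput could look like the sketch below; it is an illustrative assumption, not the actual SkyTransformersCallback.

import time

from transformers import TrainerCallback


class TimingTransformersCallback(TrainerCallback):
    """Illustrative only: prints training throughput at each logging step."""

    def on_train_begin(self, args, state, control, **kwargs):
        self._start = time.time()

    def on_step_end(self, args, state, control, **kwargs):
        if state.is_world_process_zero and state.global_step % args.logging_steps == 0:
            elapsed = time.time() - self._start
            print(f"step {state.global_step}/{state.max_steps}: "
                  f"{state.global_step / elapsed:.2f} steps/s")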

sky/__init__.py (+2 lines)

@@ -3,6 +3,7 @@

 # Keep this order to avoid cyclic imports
 from sky import backends
+from sky import benchmark
 from sky import clouds
 from sky.clouds.service_catalog import list_accelerators
 from sky.dag import Dag, DagContext
@@ -33,6 +34,7 @@
     'Resources',
     'Task',
     'backends',
+    'benchmark',
     'launch',
     'exec',
     'list_accelerators',

sky/benchmark/__init__.py (new file, whitespace-only)
