
Commit c2b2341

Add TPU Pod to doc (skypilot-org#1318)
* add pod to doc
* Apply suggestions from code review
  Co-authored-by: Zongheng Yang <[email protected]>
* comments
* comments
* update bucket note
* Apply suggestions from code review
  Co-authored-by: Zongheng Yang <[email protected]>
* update
* update
* fix
* fix
* comments
* fix

Co-authored-by: Zongheng Yang <[email protected]>
1 parent 1af147a commit c2b2341

2 files changed: +181, -66 lines

docs/source/reference/tpu.rst

Lines changed: 181 additions & 66 deletions
@@ -1,29 +1,20 @@

.. _tpu:

=========
Cloud TPU
=========

SkyPilot supports running jobs on Google's `Cloud TPU <https://cloud.google.com/tpu>`_, a specialized hardware accelerator for ML workloads.


Free TPUs via TPU Research Cloud (TRC)
======================================

ML researchers and students are encouraged to apply for free TPU access through the `TPU Research Cloud (TRC) <https://sites.research.google/trc/about/>`_ program!


Getting TPUs in one command
===========================

Like :ref:`GPUs <interactive-nodes>`, SkyPilot provides a simple command to quickly get TPUs for development:
@@ -35,48 +26,132 @@

   sky tpunode --instance-type n1-highmem-16  # Change the host VM type to n1-highmem-16
   sky tpunode --tpu-vm                       # Use TPU VM (instead of TPU Node)

After the command finishes, you will be dropped into a TPU host VM and can start developing code right away.
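For a quick sanity check that the TPU cores are attached (a minimal sketch, assuming ``jax[tpu]`` is installed on the VM), you can list the devices:

.. code-block:: python

   # Minimal check, assuming jax[tpu] is installed on the TPU VM.
   import jax

   print(jax.devices())        # the attached TPU cores
   print(jax.device_count())   # e.g., 8 on a v2-8 / v3-8 TPU VM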
Below, we show examples of using SkyPilot to run MNIST training on (1) TPU VMs and (2) TPU Nodes.

TPU Architectures
=================

Two different TPU architectures are available on GCP:

- `TPU VMs <https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu-vm>`_
- `TPU Nodes <https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu-node>`_

Both are supported by SkyPilot. We recommend TPU VMs, the newer architecture encouraged by GCP.

The two architectures differ as follows.
For TPU VMs, you can directly SSH into the "TPU host" VM that is physically connected to the TPU device.
For TPU Nodes, a user VM (an `n1` instance) must be separately provisioned to communicate with an inaccessible TPU host over gRPC.
More details can be found in the GCP `documentation <https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu-arch>`_.

TPU VMs
-------

To use TPU VMs, set the following in a task YAML's ``resources`` field:

.. code-block:: yaml

   resources:
      accelerators: tpu-v2-8
      accelerator_args:
         tpu_vm: True
         runtime_version: tpu-vm-base  # optional

The ``accelerators`` field specifies the TPU type, and the :code:`accelerator_args` dict includes a :code:`tpu_vm` bool (defaults to false, meaning a TPU Node is used) and an optional ``runtime_version`` field.
To show what TPU types are supported, run :code:`sky show-gpus`.

Here is a complete task YAML that runs `MNIST training <https://cloud.google.com/tpu/docs/run-calculation-jax#running_jax_code_on_a_tpu_vm>`_ on a TPU VM using JAX.

.. code-block:: yaml

   name: mnist-tpu-vm

   resources:
      accelerators: tpu-v2-8
      accelerator_args:
         tpu_vm: True
         runtime_version: tpu-vm-base

   setup: |
      git clone https://github.com/google/flax.git

      conda activate flax
      if [ $? -eq 0 ]; then
         echo 'conda env exists'
      else
         conda create -n flax python=3.8 -y
         conda activate flax
         # Make sure to install TPU related packages in a conda env to avoid package conflicts.
         pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
         pip install --upgrade clu
         pip install -e flax
      fi

   run: |
      conda activate flax
      cd flax/examples/mnist
      python3 main.py --workdir=/tmp/mnist \
        --config=configs/default.py \
        --config.learning_rate=0.05 \
        --config.num_epochs=10

This YAML lives under the `SkyPilot repo <https://github.com/skypilot-org/skypilot/tree/master/examples/tpu>`_ (``examples/tpu/tpuvm_mnist.yaml``), or you can paste it into a local file.

Launch it with:

.. code-block:: console

   $ sky launch examples/tpu/tpuvm_mnist.yaml -c mycluster

You should see the following outputs when the job finishes.

.. code-block:: console

   $ sky launch examples/tpu/tpuvm_mnist.yaml -c mycluster
   ...
   (mnist-tpu-vm pid=10155) I0823 07:49:25.468526 139641357117440 train.py:146] epoch: 9, train_loss: 0.0120, train_accuracy: 99.64, test_loss: 0.0278, test_accuracy: 99.02
   (mnist-tpu-vm pid=10155) I0823 07:49:26.966874 139641357117440 train.py:146] epoch: 10, train_loss: 0.0095, train_accuracy: 99.73, test_loss: 0.0264, test_accuracy: 99.19


TPU Nodes
---------

In a TPU Node, a normal CPU VM (an `n1` instance) needs to be provisioned to communicate with the TPU host/device.

To use a TPU Node, set the following in a task YAML's ``resources`` field:

.. code-block:: yaml

   resources:
      instance_type: n1-highmem-8
      accelerators: tpu-v2-8
      accelerator_args:
         runtime_version: 2.5.0  # optional, TPU runtime version.

The above YAML considers :code:`n1-highmem-8` as the host machine and :code:`tpu-v2-8` as the TPU node resource.
You can modify the host instance type or the TPU type.

Here is a complete task YAML that runs `MNIST training <https://cloud.google.com/tpu/docs/run-calculation-jax#running_jax_code_on_a_tpu_vm>`_ on a TPU Node using TensorFlow.

.. code-block:: yaml

   name: mnist-tpu-node

   resources:
      accelerators: tpu-v2-8
      accelerator_args:
         runtime_version: 2.5.0  # optional, TPU runtime version.

   # TPU node requires loading data from a GCS bucket.
   # We use SkyPilot Storage to mount a GCS bucket to /dataset.
   file_mounts:
      /dataset:
         name: mnist-tpu-node
         store: gcs
         mode: MOUNT

   setup: |
      git clone https://github.com/tensorflow/models.git
@@ -89,7 +164,6 @@

         pip install tensorflow==2.5.0 tensorflow-datasets tensorflow-model-optimization cloud-tpu-client
      fi

   run: |
      conda activate mnist
      cd models/official/legacy/image_classification/
@@ -110,17 +184,20 @@

.. note::

   TPU node requires loading data from a GCS bucket. The :code:`file_mounts` spec above simplifies this by using :ref:`SkyPilot Storage <sky-storage>` to create a new bucket or mount an existing one.
   If you encounter a bucket :code:`Permission denied` error,
   make sure the bucket is created in the same region as the Host VM/TPU Nodes and that IAM permissions for Cloud TPU are
   correctly set up (follow the instructions `here <https://cloud.google.com/tpu/docs/storage-buckets#using_iam_permissions_for_alternative>`_).

.. note::

   The special environment variable :code:`$TPU_NAME` is automatically set by SkyPilot at run time, so it can be used in the ``run`` commands.
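   For instance (a minimal sketch of standard TensorFlow 2.x usage, not SkyPilot-specific), a training script launched by the ``run`` commands could read it to connect to the TPU:

   .. code-block:: python

      import os
      import tensorflow as tf

      # SkyPilot exports TPU_NAME at run time; pass it to the cluster resolver.
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=os.environ['TPU_NAME'])
      tf.config.experimental_connect_to_cluster(resolver)
      tf.tpu.experimental.initialize_tpu_system(resolver)
      strategy = tf.distribute.TPUStrategy(resolver)  # build and compile the model under this strategy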

This YAML lives under the `SkyPilot repo <https://github.com/skypilot-org/skypilot/tree/master/examples/tpu>`_ (``examples/tpu/tpu_node_mnist.yaml``). Launch it with:

.. code-block:: console

   $ sky launch examples/tpu/tpu_node_mnist.yaml -c mycluster
   ...
   (mnist-tpu-node pid=28961) Epoch 9/10
   (mnist-tpu-node pid=28961) 58/58 [==============================] - 1s 19ms/step - loss: 0.1181 - sparse_categorical_accuracy: 0.9646 - val_loss: 0.0921 - val_sparse_categorical_accuracy: 0.9719
@@ -131,63 +208,101 @@

Using TPU Pods
==============

A `TPU Pod <https://cloud.google.com/tpu/docs/training-on-tpu-pods>`_ is a collection of TPU devices connected by dedicated high-speed network interfaces for high-performance training.

To use a TPU Pod, simply change the ``accelerators`` field in the task YAML (e.g., :code:`v2-8` -> :code:`v2-32`).

.. code-block:: yaml
   :emphasize-lines: 2-2

   resources:
      accelerators: tpu-v2-32  # Pods have > 8 cores (the last number)
      accelerator_args:
         runtime_version: tpu-vm-base
         tpu_vm: True

.. note::

   Both TPU architectures, TPU VMs and TPU Nodes, can be used with TPU Pods. The example below is based on TPU VMs.

To show all available TPU Pod types, run :code:`sky show-gpus` (more than 8 cores means Pods):

.. code-block:: console

   GOOGLE_TPU    AVAILABLE_QUANTITIES
   tpu-v2-8      1
   tpu-v2-32     1
   tpu-v2-128    1
   tpu-v2-256    1
   tpu-v2-512    1
   tpu-v3-8      1
   tpu-v3-32     1
   tpu-v3-64     1
   tpu-v3-128    1
   tpu-v3-256    1
   tpu-v3-512    1
   tpu-v3-1024   1
   tpu-v3-2048   1

After creating a TPU Pod, multiple host VMs (e.g., :code:`v2-32` comes with 4 host VMs) are launched.
Normally, the user needs to SSH into all hosts (depending on the architecture used, either the ``n1`` user VMs or the TPU host VMs) to prepare files, set up environments, and
then launch the job on each host, which is a tedious and error-prone process.

SkyPilot automates away this complexity. From your laptop, a single :code:`sky launch` command will:

- sync the workdir and ``file_mounts`` to every host of the pod; and
- execute the setup/run commands on every host.

Here is a task YAML for a cifar10 training job on a :code:`v2-32` TPU Pod with JAX (`code repo <https://github.com/infwinston/tpu-example>`_):

.. code-block:: yaml

   name: cifar-tpu-pod

   resources:
      accelerators: tpu-v2-32
      accelerator_args:
         runtime_version: tpu-vm-base
         tpu_vm: True

   setup: |
      git clone https://github.com/infwinston/tpu-example.git
      cd tpu-example
      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      pip install -r requirements.txt

   run: |
      python -u tpu-example/train.py

Launch it with:

.. code-block:: console

   $ sky launch examples/tpu/cifar_pod.yaml -c mycluster

You should see the following output.

.. code-block:: console

   (node-0 pid=57977, ip=10.164.0.24) JAX process: 1 / 4
   (node-3 pid=57963, ip=10.164.0.26) JAX process: 3 / 4
   (node-2 pid=57922, ip=10.164.0.25) JAX process: 2 / 4
   (node-1 pid=63223) JAX process: 0 / 4
   ...
   (node-0 pid=57977, ip=10.164.0.24) [ 1000/100000] time 0.034 ( 0.063) data 0.008 ( 0.008) loss 1.215 ( 1.489) acc 68.750 (46.163)

.. note::

   By default, outputs from all hosts are shown with the ``node-<i>`` prefix. Use :code:`jax.process_index()` to control which host prints messages.
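   For example (a minimal sketch, not taken from the example repo), a common pattern is to log only from the first host:

   .. code-block:: python

      import jax

      # Each host in the pod runs the same script; only host 0 logs.
      if jax.process_index() == 0:
          print(f'hosts: {jax.process_count()}, global TPU cores: {jax.device_count()}')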

To submit more jobs to the same TPU Pod, use :code:`sky exec`:

.. code-block:: console

   $ sky exec mycluster examples/tpu/cifar_pod.yaml
The other changed file was renamed without content changes.
