Skip to content

Commit dc0285c

Browse files
committed
Add mlflow example for PyTorch DDP
1 parent 57c9136 commit dc0285c

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

mnist_ddp_mlflow.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ def train(num_epochs):
4747
verbose = dist.get_rank() == 0 # print only on global_rank==0
4848
if verbose:
4949
mlflow.set_tracking_uri("/scratch/project_2001659/mvsjober/mlruns")
50-
slurm_id = os.getenv("SLURM_JOB_ID")
51-
if slurm_id:
52-
mlflow.start_run(run_name=slurm_id)
50+
#mlflow.set_tracking_uri("sqlite:////scratch/project_2001659/mvsjober/mlruns.db")
51+
#mlflow.set_tracking_uri('https://mats-mlflow2.rahtiapp.fi/')
52+
53+
mlflow.start_run(run_name=os.getenv("SLURM_JOB_ID"))
5354

5455
model = ConvNet().cuda()
5556
batch_size = 100

run-ddp-gpu1-mlflow.sh

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/bin/bash
2+
#SBATCH --account=project_2001659
3+
#SBATCH --partition=gputest
4+
#SBATCH --ntasks=1
5+
#SBATCH --cpus-per-task=10
6+
#SBATCH --mem=64G
7+
#SBATCH --time=15
8+
#SBATCH --gres=gpu:v100:1
9+
10+
module purge
11+
module load pytorch
12+
13+
# Old way with torch.distributed.run
14+
# srun python3 -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=4 mnist_ddp.py --epochs=100
15+
16+
# New way with torchrun
17+
srun torchrun --standalone --nnodes=1 --nproc_per_node=1 mnist_ddp_mlflow.py --epochs=100

0 commit comments

Comments
 (0)