Skip to content

Commit 5a2805e

Browse files
committed
feat(nvidia): build pytorch to get older cuda compute capabilities and setup arm64 support
1 parent b5a9e87 commit 5a2805e

File tree

23 files changed

+121
-60
lines changed

23 files changed

+121
-60
lines changed

.github/workflows/ci.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,16 @@ jobs:
3535
runs-on: ubuntu-latest
3636
steps:
3737
- uses: actions/checkout@v3
38-
- run: docker build --file test/images/nvidia-training/Dockerfile test/images/nvidia-training
38+
- run: |
39+
docker build --file test/images/nvidia-training/Dockerfile test/images/nvidia-training \
40+
--build-arg PYTORCH_BUILD_ENV="MAX_JOBS=8 BUILD_TEST=0 USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 USE_DISTRIBUTED=0"
3941
build-image-nvidia-inference:
4042
runs-on: ubuntu-latest
4143
steps:
4244
- uses: actions/checkout@v3
43-
- run: docker build --file test/images/nvidia-inference/Dockerfile test/images/nvidia-inference
45+
- run: |
46+
docker build --file test/images/nvidia-inference/Dockerfile test/images/nvidia-inference \
47+
--build-arg PYTORCH_BUILD_ENV="MAX_JOBS=8 BUILD_TEST=0 USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 USE_DISTRIBUTED=0"
4448
build-image-neuron-training:
4549
runs-on: ubuntu-latest
4650
steps:

internal/deployers/eksapi/kubeconfig.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package eksapi
22

33
import (
44
"bytes"
5+
"fmt"
56
"os"
67
"text/template"
7-
"fmt"
88

99
"k8s.io/klog"
1010
)

test/cases/nvidia-training/bert_training_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@ func TestBertTraining(t *testing.T) {
6969
ObjectMeta: metav1.ObjectMeta{Name: "bert-training-launcher", Namespace: "default"},
7070
}
7171
err := wait.For(fwext.NewConditionExtension(cfg.Client().Resources()).JobSucceeded(job),
72-
wait.WithTimeout(time.Minute*20))
72+
wait.WithTimeout(time.Minute*20),
73+
wait.WithContext(ctx),
74+
)
7375
if err != nil {
74-
t.Fatal(err)
76+
t.Error(err)
7577
}
7678

7779
err = printJobLogs(ctx, cfg, "default", "bert-training-launcher")

test/cases/nvidia-training/main_test.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"log"
1010
"os"
11+
"os/signal"
1112
"slices"
1213
"testing"
1314
"time"
@@ -37,7 +38,10 @@ func TestMain(m *testing.M) {
3738
if err != nil {
3839
log.Fatalf("failed to initialize test environment: %v", err)
3940
}
40-
testenv = env.NewWithConfig(cfg)
41+
42+
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
43+
defer cancel()
44+
testenv = env.NewWithConfig(cfg).WithContext(ctx)
4145

4246
manifests := [][]byte{
4347
nvidiaDevicePluginManifest,
@@ -147,16 +151,11 @@ func checkNodeTypes(ctx context.Context, config *envconf.Config) (context.Contex
147151
return ctx, fmt.Errorf("no nodes found in the cluster")
148152
}
149153

150-
singleNodeType := true
151154
for i := 1; i < len(nodes.Items); i++ {
152155
if nodes.Items[i].Labels["node.kubernetes.io/instance-type"] != nodes.Items[i-1].Labels["node.kubernetes.io/instance-type"] {
153-
singleNodeType = false
154-
break
156+
return ctx, fmt.Errorf("node types are not the same, all node types must be the same in the cluster")
155157
}
156158
}
157-
if !singleNodeType {
158-
return ctx, fmt.Errorf("node types are not the same, all node types must be the same in the cluster")
159-
}
160159

161160
if *nodeType != "" {
162161
count := 0

test/cases/nvidia/main_test.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ import (
99
"fmt"
1010
"log"
1111
"os"
12+
"os/signal"
1213
"slices"
1314
"testing"
14-
"time"
1515

1616
fwext "github.com/aws/aws-k8s-tester/internal/e2e"
1717
"github.com/aws/aws-sdk-go-v2/aws"
@@ -31,6 +31,7 @@ var (
3131
installDevicePlugin *bool
3232
efaEnabled *bool
3333
nvidiaTestImage *string
34+
pytorchImage *string
3435
skipUnitTestSubcommand *string
3536
nodeCount int
3637
gpuPerNode int
@@ -99,15 +100,11 @@ func checkNodeTypes(ctx context.Context, config *envconf.Config) (context.Contex
99100
return ctx, err
100101
}
101102

102-
singleNodeType := true
103103
for i := 1; i < len(nodes.Items)-1; i++ {
104104
if nodes.Items[i].Labels["node.kubernetes.io/instance-type"] != nodes.Items[i-1].Labels["node.kubernetes.io/instance-type"] {
105-
singleNodeType = false
105+
return ctx, fmt.Errorf("Node types are not the same, all node types must be the same in the cluster")
106106
}
107107
}
108-
if !singleNodeType {
109-
return ctx, fmt.Errorf("Node types are not the same, all node types must be the same in the cluster")
110-
}
111108

112109
if *nodeType != "" {
113110
for _, v := range nodes.Items {
@@ -135,6 +132,7 @@ func checkNodeTypes(ctx context.Context, config *envconf.Config) (context.Contex
135132
func TestMain(m *testing.M) {
136133
nodeType = flag.String("nodeType", "", "node type for the tests")
137134
nvidiaTestImage = flag.String("nvidiaTestImage", "", "nccl test image for nccl tests")
135+
pytorchImage = flag.String("pytorchImage", "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-ec2", "pytorch cuda image for single node tests")
138136
efaEnabled = flag.Bool("efaEnabled", false, "enable efa tests")
139137
installDevicePlugin = flag.Bool("installDevicePlugin", true, "install nvidia device plugin")
140138
skipUnitTestSubcommand = flag.String("skipUnitTestSubcommand", "", "optional command to skip specified unit test, `-s test1|test2|...`")
@@ -143,7 +141,7 @@ func TestMain(m *testing.M) {
143141
log.Fatalf("failed to initialize test environment: %v", err)
144142
}
145143
testenv = env.NewWithConfig(cfg)
146-
ctx, cancel := context.WithTimeout(context.Background(), 55*time.Minute)
144+
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
147145
defer cancel()
148146
testenv = testenv.WithContext(ctx)
149147

test/cases/nvidia/manifests/job-hpc-benchmarks.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,15 @@ spec:
3535
- hpl.sh
3636
- --mem-affinity
3737
- 0:0:0:0:1:1:1:1
38+
# --cpu-affinity needs to be tuned depending on the number of CPUs
39+
# available on the instance type.
3840
- --cpu-affinity
3941
- 0-13:14-27:28-41:42-55:56-69:70-83:84-97:98-111
4042
- --no-multinode
4143
- --dat
4244
- hpl-linux-x86_64/sample-dat/HPL-dgx-1N.dat
45+
# TODO: the path differs for arm64
46+
#- hpl-linux-aarch64-gpu/sample-dat/HPL-dgx-1N.dat
4347
volumeMounts:
4448
- mountPath: /dev/shm
4549
name: dshm

test/cases/nvidia/manifests/job-unit-test-single-node.yaml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@ spec:
1717
- /bin/bash
1818
- ./gpu_unit_tests/unit_test
1919
env:
20-
- name: SKIP_TESTS_SUBCOMMAND
21-
value: {{.SkipTestSubcommand}}
20+
- name: SKIP_TESTS_SUBCOMMAND
21+
value: {{.SkipTestSubcommand}}
22+
# because we started building these from source, this is just a
23+
# regular binary.
24+
- name: DEMO_SUITE_DIR
25+
value: /usr/bin
2226
imagePullPolicy: Always
2327
resources:
2428
limits:
@@ -29,4 +33,4 @@ spec:
2933
cpu: "1"
3034
memory: 1Gi
3135
restartPolicy: Never
32-
backoffLimit: 4
36+
backoffLimit: 4

test/cases/nvidia/manifests/mpi-job-pytorch-training-single-node.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ spec:
1616
spec:
1717
restartPolicy: OnFailure
1818
containers:
19-
- image: 763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-ec2
19+
- image: {{.PytorchTestImage}}
2020
name: gpu-test
2121
command:
2222
- mpirun
@@ -48,7 +48,7 @@ spec:
4848
- MXNET_CUDNN_AUTOTUNE_DEFAULT=0
4949
- python
5050
- -c
51-
- import os; os.system("git clone https://github.com/pytorch/examples.git /pytorch-examples"); os.system("git -C pytorch-examples checkout 0f0c9131ca5c79d1332dce1f4c06fe942fbdc665"); os.system("python /pytorch-examples/mnist/main.py --epochs 1")
51+
- import os; os.system("git clone https://github.com/pytorch/examples.git pytorch-examples"); os.system("git -C pytorch-examples checkout 0f0c9131ca5c79d1332dce1f4c06fe942fbdc665"); os.system("python pytorch-examples/mnist/main.py --epochs 1")
5252
resources:
5353
limits:
5454
nvidia.com/gpu: 1

test/cases/nvidia/mpi_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,15 @@ func TestMPIJobPytorchTraining(t *testing.T) {
4848
WithLabel("hardware", "gpu").
4949
Setup(func(ctx context.Context, t *testing.T, cfg *envconf.Config) context.Context {
5050
t.Log("Applying single node manifest")
51-
err := fwext.ApplyManifests(cfg.Client().RESTConfig(), mpiJobPytorchTrainingSingleNodeManifest)
51+
renderedSingleNodeManifest, err := fwext.RenderManifests(mpiJobPytorchTrainingSingleNodeManifest, struct {
52+
PytorchTestImage string
53+
}{
54+
PytorchTestImage: *pytorchImage,
55+
})
56+
if err != nil {
57+
t.Fatal(err)
58+
}
59+
err = fwext.ApplyManifests(cfg.Client().RESTConfig(), renderedSingleNodeManifest)
5260
if err != nil {
5361
t.Fatal(err)
5462
}

test/cases/nvidia/unit_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,13 @@ func TestSingleNodeUnitTest(t *testing.T) {
7575
ObjectMeta: metav1.ObjectMeta{Name: "unit-test-job", Namespace: "default"},
7676
})
7777
if err != nil {
78-
t.Fatal(err)
78+
t.Error(err)
7979
}
8080
t.Log("Test log for unit-test-job:")
8181
t.Log(log)
8282
err = fwext.DeleteManifests(cfg.Client().RESTConfig(), renderedJobUnitTestSingleNodeManifest)
8383
if err != nil {
84-
t.Fatal(err)
84+
t.Error(err)
8585
}
8686
return ctx
8787
}).
@@ -120,13 +120,13 @@ func TestSingleNodeUnitTest(t *testing.T) {
120120
ObjectMeta: metav1.ObjectMeta{Name: "hpc-benckmarks-job", Namespace: "default"},
121121
})
122122
if err != nil {
123-
t.Fatal(err)
123+
t.Error(err)
124124
}
125125
t.Log("Test log for hpc-benckmarks-job:")
126126
t.Log(log)
127127
err = fwext.DeleteManifests(cfg.Client().RESTConfig(), renderedJobHpcBenchmarksSingleNodeManifest)
128128
if err != nil {
129-
t.Fatal(err)
129+
t.Error(err)
130130
}
131131
return ctx
132132
}).

0 commit comments

Comments
 (0)