Skip to content

Commit 56fd4f4

Browse files
committed
Add support for requirements checks to CDI
Signed-off-by: Arjun <agadiyar@nvidia.com>
1 parent 3cfea27 commit 56fd4f4

3 files changed

Lines changed: 64 additions & 35 deletions

File tree

internal/modifier/cdi.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ func (f *Factory) newCDIModifier(isJitCDI bool) (oci.SpecModifier, error) {
5656
return nil, nil
5757
}
5858

59+
if err := checkRequirements(f.logger, f.image); err != nil {
60+
return nil, fmt.Errorf("requirements not met: %w", err)
61+
}
62+
5963
automaticDevices := filterAutomaticDevices(devices)
6064
if len(automaticDevices) != len(devices) && len(automaticDevices) > 0 {
6165
return nil, fmt.Errorf("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices")

internal/modifier/csv.go

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@ import (
2020
"fmt"
2121

2222
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
23-
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
24-
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2523
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
26-
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
2724
)
2825

2926
// newCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper.
@@ -43,38 +40,6 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
4340
return f.newAutomaticCDISpecModifier(devices)
4441
}
4542

46-
func checkRequirements(logger logger.Interface, image *image.CUDA) error {
47-
if image == nil || image.HasDisableRequire() {
48-
// TODO: We could print the real value here instead
49-
logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true)
50-
return nil
51-
}
52-
53-
imageRequirements, err := image.GetRequirements()
54-
if err != nil {
55-
// TODO: Should we treat this as a failure, or just issue a warning?
56-
return fmt.Errorf("failed to get image requirements: %v", err)
57-
}
58-
59-
r := requirements.New(logger, imageRequirements)
60-
61-
cudaVersion, err := cuda.Version()
62-
if err != nil {
63-
logger.Warningf("Failed to get CUDA version: %v", err)
64-
} else {
65-
r.AddVersionProperty(requirements.CUDA, cudaVersion)
66-
}
67-
68-
compteCapability, err := cuda.ComputeCapability(0)
69-
if err != nil {
70-
logger.Warningf("Failed to get CUDA Compute Capability: %v", err)
71-
} else {
72-
r.AddVersionProperty(requirements.ARCH, compteCapability)
73-
}
74-
75-
return r.Assert()
76-
}
77-
7843
type csvDevices image.CUDA
7944

8045
func (d csvDevices) DeviceRequests() []string {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/**
2+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package modifier
18+
19+
import (
20+
"fmt"
21+
22+
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
23+
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
24+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
25+
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
26+
)
27+
28+
// checkRequirements evaluates NVIDIA_REQUIRE_* constraints using the host
29+
// CUDA driver version and the compute capability of CUDA device 0. This
30+
// matches the subset enforced in CSV mode and is shared with CDI / JIT-CDI
31+
// modes (driver and brand constraints are not populated here).
32+
func checkRequirements(logger logger.Interface, image *image.CUDA) error {
33+
if image == nil || image.HasDisableRequire() {
34+
logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true)
35+
return nil
36+
}
37+
38+
imageRequirements, err := image.GetRequirements()
39+
if err != nil {
40+
return fmt.Errorf("failed to get image requirements: %v", err)
41+
}
42+
43+
r := requirements.New(logger, imageRequirements)
44+
45+
cudaVersion, err := cuda.Version()
46+
if err != nil {
47+
logger.Warningf("Failed to get CUDA version: %v", err)
48+
} else {
49+
r.AddVersionProperty(requirements.CUDA, cudaVersion)
50+
}
51+
52+
compteCapability, err := cuda.ComputeCapability(0)
53+
if err != nil {
54+
logger.Warningf("Failed to get CUDA Compute Capability: %v", err)
55+
} else {
56+
r.AddVersionProperty(requirements.ARCH, compteCapability)
57+
}
58+
59+
return r.Assert()
60+
}

0 commit comments

Comments
 (0)