Skip to content

Commit 68fa4bb

Browse files
committed
Add support for requirements checks to CDI
Signed-off-by: Arjun <agadiyar@nvidia.com>
1 parent 3cfea27 commit 68fa4bb

3 files changed

Lines changed: 209 additions & 36 deletions

File tree

internal/modifier/cdi.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ func (f *Factory) newCDIModifier(isJitCDI bool) (oci.SpecModifier, error) {
5151
defaultKind,
5252
)
5353
devices := deviceRequestor.DeviceRequests()
54+
55+
// Run before the empty-device return so NVIDIA_REQUIRE_* is still enforced when
56+
// len(devices)==0 (e.g. CRI CDI injection without matching spec signals). When
57+
// there are no requirements, checkRequirements returns immediately.
58+
if err := checkRequirements(f.logger, f.image, f.driver); err != nil {
59+
return nil, fmt.Errorf("requirements not met: %w", err)
60+
}
61+
5462
if len(devices) == 0 {
5563
f.logger.Debugf("No devices requested; no modification required.")
5664
return nil, nil

internal/modifier/csv.go

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@ import (
2020
"fmt"
2121

2222
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
23-
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
24-
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2523
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
26-
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
2724
)
2825

2926
// newCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper.
@@ -36,45 +33,13 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
3633
}
3734
f.logger.Infof("Constructing modifier from config: %+v", *f.cfg)
3835

39-
if err := checkRequirements(f.logger, f.image); err != nil {
36+
if err := checkRequirements(f.logger, f.image, f.driver); err != nil {
4037
return nil, fmt.Errorf("requirements not met: %v", err)
4138
}
4239

4340
return f.newAutomaticCDISpecModifier(devices)
4441
}
4542

46-
func checkRequirements(logger logger.Interface, image *image.CUDA) error {
47-
if image == nil || image.HasDisableRequire() {
48-
// TODO: We could print the real value here instead
49-
logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true)
50-
return nil
51-
}
52-
53-
imageRequirements, err := image.GetRequirements()
54-
if err != nil {
55-
// TODO: Should we treat this as a failure, or just issue a warning?
56-
return fmt.Errorf("failed to get image requirements: %v", err)
57-
}
58-
59-
r := requirements.New(logger, imageRequirements)
60-
61-
cudaVersion, err := cuda.Version()
62-
if err != nil {
63-
logger.Warningf("Failed to get CUDA version: %v", err)
64-
} else {
65-
r.AddVersionProperty(requirements.CUDA, cudaVersion)
66-
}
67-
68-
compteCapability, err := cuda.ComputeCapability(0)
69-
if err != nil {
70-
logger.Warningf("Failed to get CUDA Compute Capability: %v", err)
71-
} else {
72-
r.AddVersionProperty(requirements.ARCH, compteCapability)
73-
}
74-
75-
return r.Assert()
76-
}
77-
7843
type csvDevices image.CUDA
7944

8045
func (d csvDevices) DeviceRequests() []string {
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
/**
2+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package modifier
18+
19+
import (
20+
"fmt"
21+
"strconv"
22+
"strings"
23+
24+
"github.com/NVIDIA/go-nvml/pkg/nvml"
25+
"golang.org/x/mod/semver"
26+
27+
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
28+
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
29+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
30+
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
31+
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
32+
)
33+
34+
// checkRequirements evaluates NVIDIA_REQUIRE_* constraints using the host
35+
// CUDA driver API version from libcuda, the NVIDIA display driver version from
36+
// the driver root (libcuda / libnvidia-ml soname), the compute capability of
37+
// CUDA device 0, and (when requirements reference brand) the GPU product brand
38+
// from NVML. It is used for CSV and CDI / JIT-CDI modes.
39+
func checkRequirements(logger logger.Interface, image *image.CUDA, driver *root.Driver) error {
40+
if image == nil || image.HasDisableRequire() {
41+
logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true)
42+
return nil
43+
}
44+
45+
imageRequirements, err := image.GetRequirements()
46+
if err != nil {
47+
return fmt.Errorf("failed to get image requirements: %v", err)
48+
}
49+
if len(imageRequirements) == 0 {
50+
return nil
51+
}
52+
53+
r := requirements.New(logger, imageRequirements)
54+
55+
cudaVersion, err := cuda.Version()
56+
if err != nil {
57+
logger.Warningf("Failed to get CUDA version: %v", err)
58+
} else {
59+
r.AddVersionProperty(requirements.CUDA, cudaVersion)
60+
}
61+
62+
compteCapability, err := cuda.ComputeCapability(0)
63+
if err != nil {
64+
logger.Warningf("Failed to get CUDA Compute Capability: %v", err)
65+
} else {
66+
r.AddVersionProperty(requirements.ARCH, compteCapability)
67+
}
68+
69+
driverVersion, err := driver.Version()
70+
if err != nil {
71+
logger.Warningf("Failed to get NVIDIA driver version: %v", err)
72+
} else {
73+
normalized, normErr := normalizeDriverVersionForSemver(driverVersion)
74+
if normErr != nil {
75+
logger.Warningf("NVIDIA driver version %q is not semver-normalizable: %v", driverVersion, normErr)
76+
} else {
77+
r.AddVersionProperty(requirements.DRIVER, normalized)
78+
}
79+
}
80+
81+
brand, err := getBrandFromNVML(driver)
82+
if err != nil {
83+
logger.Warningf("Failed to get GPU brand from NVML: %v", err)
84+
} else {
85+
r.AddStringProperty(requirements.BRAND, brand)
86+
}
87+
88+
return r.Assert()
89+
}
90+
91+
// normalizeDriverVersionForSemver converts a driver version taken from a
92+
// libcuda / libnvidia-ml soname suffix into a form accepted by
93+
// golang.org/x/mod/semver (no leading zeros in numeric segments)
94+
func normalizeDriverVersionForSemver(raw string) (string, error) {
95+
raw = strings.TrimSpace(raw)
96+
if raw == "" {
97+
return "", fmt.Errorf("empty driver version")
98+
}
99+
parts := strings.Split(raw, ".")
100+
out := make([]string, 0, len(parts))
101+
for _, p := range parts {
102+
if p == "" {
103+
return "", fmt.Errorf("empty version segment in %q", raw)
104+
}
105+
if strings.TrimLeft(p, "0123456789") != "" {
106+
return "", fmt.Errorf("non-numeric version segment %q in %q", p, raw)
107+
}
108+
n, err := strconv.ParseUint(p, 10, 64)
109+
if err != nil {
110+
return "", fmt.Errorf("invalid version segment %q in %q: %w", p, raw, err)
111+
}
112+
out = append(out, strconv.FormatUint(n, 10))
113+
}
114+
normalized := strings.Join(out, ".")
115+
if !semver.IsValid("v" + normalized) {
116+
return "", fmt.Errorf("normalized driver version %q is not valid semver", normalized)
117+
}
118+
return normalized, nil
119+
}
120+
121+
// getBrandFromNVML returns a lowercase brand token for the first visible GPU
122+
// (index 0), using NVML. When driver is non-nil, NVML is loaded from the
123+
// versioned libnvidia-ml under the driver root when possible.
124+
func getBrandFromNVML(driver *root.Driver) (string, error) {
125+
var lib nvml.Interface
126+
var opts []nvml.LibraryOption
127+
v, err := driver.Version()
128+
if err == nil && v != "" && v != "*.*" {
129+
paths, err := driver.Libraries().Locate("libnvidia-ml.so." + v)
130+
if err == nil && len(paths) > 0 {
131+
opts = append(opts, nvml.WithLibraryPath(paths[0]))
132+
}
133+
}
134+
135+
lib = nvml.New(opts...)
136+
if ret := lib.Init(); ret != nvml.SUCCESS {
137+
return "", fmt.Errorf("nvml.Init: %s", lib.ErrorString(ret))
138+
}
139+
defer func() {
140+
_ = lib.Shutdown()
141+
}()
142+
143+
device, ret := lib.DeviceGetHandleByIndex(0)
144+
if ret != nvml.SUCCESS {
145+
return "", fmt.Errorf("nvml.DeviceGetHandleByIndex(0): %s", lib.ErrorString(ret))
146+
}
147+
148+
brandType, ret := lib.DeviceGetBrand(device)
149+
if ret != nvml.SUCCESS {
150+
return "", fmt.Errorf("nvml.DeviceGetBrand: %s", lib.ErrorString(ret))
151+
}
152+
brand, ok := brandTypeToRequirementString(brandType)
153+
if !ok {
154+
return "", fmt.Errorf("unknown NVML brand type %v", brandType)
155+
}
156+
return brand, nil
157+
}
158+
159+
// brandTypeToRequirementString maps NVML brand enums to lowercase tokens
160+
// consistent with typical NVIDIA_REQUIRE_* image constraints.
161+
func brandTypeToRequirementString(b nvml.BrandType) (string, bool) {
162+
switch b {
163+
case nvml.BRAND_UNKNOWN:
164+
return "", false
165+
case nvml.BRAND_QUADRO:
166+
return "quadro", true
167+
case nvml.BRAND_TESLA:
168+
return "tesla", true
169+
case nvml.BRAND_NVS:
170+
return "nvs", true
171+
case nvml.BRAND_GRID:
172+
return "grid", true
173+
case nvml.BRAND_GEFORCE:
174+
return "geforce", true
175+
case nvml.BRAND_TITAN:
176+
return "titan", true
177+
case nvml.BRAND_NVIDIA_VAPPS:
178+
return "nvidiavapps", true
179+
case nvml.BRAND_NVIDIA_VPC:
180+
return "nvidiavpc", true
181+
case nvml.BRAND_NVIDIA_VCS:
182+
return "nvidiavcs", true
183+
case nvml.BRAND_NVIDIA_VWS:
184+
return "nvidiavws", true
185+
case nvml.BRAND_NVIDIA_CLOUD_GAMING:
186+
return "nvidiacloudgaming", true
187+
case nvml.BRAND_QUADRO_RTX:
188+
return "quadrortx", true
189+
case nvml.BRAND_NVIDIA_RTX:
190+
return "nvidiartx", true
191+
case nvml.BRAND_NVIDIA:
192+
return "nvidia", true
193+
case nvml.BRAND_GEFORCE_RTX:
194+
return "geforcertx", true
195+
case nvml.BRAND_TITAN_RTX:
196+
return "titanrtx", true
197+
default:
198+
return "", false
199+
}
200+
}

0 commit comments

Comments
 (0)