Skip to content

Commit a3ca7fd

Browse files
Refactor the way we create CDI Hooks
Signed-off-by: Carlos Eduardo Arango Gutierrez <[email protected]>
1 parent ac8f190 commit a3ca7fd

30 files changed

+200
-209
lines changed

internal/discover/compat_libs.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,12 @@ import (
99

1010
// NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook.
1111
// This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version.
12-
func NewCUDACompatHookDiscoverer(logger logger.Interface, nvidiaCDIHookPath string, driver *root.Driver) Discover {
12+
func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, driver *root.Driver) Discover {
1313
_, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
1414
var args []string
1515
if !strings.Contains(cudaVersionPattern, "*") {
1616
args = append(args, "--host-driver-version="+cudaVersionPattern)
1717
}
1818

19-
return CreateNvidiaCDIHook(
20-
nvidiaCDIHookPath,
21-
"enable-cuda-compat",
22-
args...,
23-
)
19+
return hookCreator.Create("enable-cuda-compat", args...)
2420
}

internal/discover/graphics.go

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,21 @@ import (
3636
// TODO: The logic for creating DRM devices should be consolidated between this
3737
// and the logic for generating CDI specs for a single device. This is only used
3838
// when applying OCI spec modifications to an incoming spec in "legacy" mode.
39-
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, nvidiaCDIHookPath string) (Discover, error) {
39+
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator HookCreator) (Discover, error) {
4040
drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot)
4141
if err != nil {
4242
return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err)
4343
}
4444

45-
drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, nvidiaCDIHookPath)
45+
drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, hookCreator)
4646

4747
discover := Merge(drmDeviceNodes, drmByPathSymlinks)
4848
return discover, nil
4949
}
5050

5151
// NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan.
52-
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) (Discover, error) {
53-
libraries := newGraphicsLibrariesDiscoverer(logger, driver, nvidiaCDIHookPath)
52+
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) (Discover, error) {
53+
libraries := newGraphicsLibrariesDiscoverer(logger, driver, hookCreator)
5454

5555
configs := NewMounts(
5656
logger,
@@ -95,13 +95,13 @@ func newVulkanConfigsDiscover(logger logger.Interface, driver *root.Driver) Disc
9595

9696
type graphicsDriverLibraries struct {
9797
Discover
98-
logger logger.Interface
99-
nvidiaCDIHookPath string
98+
logger logger.Interface
99+
hookCreator HookCreator
100100
}
101101

102102
var _ Discover = (*graphicsDriverLibraries)(nil)
103103

104-
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) Discover {
104+
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) Discover {
105105
cudaLibRoot, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
106106

107107
libraries := NewMounts(
@@ -140,9 +140,9 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver
140140
)
141141

142142
return &graphicsDriverLibraries{
143-
Discover: Merge(libraries, xorgLibraries),
144-
logger: logger,
145-
nvidiaCDIHookPath: nvidiaCDIHookPath,
143+
Discover: Merge(libraries, xorgLibraries),
144+
logger: logger,
145+
hookCreator: hookCreator,
146146
}
147147
}
148148

@@ -203,9 +203,9 @@ func (d graphicsDriverLibraries) Hooks() ([]Hook, error) {
203203
return nil, nil
204204
}
205205

206-
hooks := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links)
206+
hook := d.hookCreator.Create("create-symlinks", links...)
207207

208-
return hooks.Hooks()
208+
return hook.Hooks()
209209
}
210210

211211
// isDriverLibrary checks whether the specified filename is a specific driver library.
@@ -275,19 +275,19 @@ func buildXOrgSearchPaths(libRoot string) []string {
275275

276276
type drmDevicesByPath struct {
277277
None
278-
logger logger.Interface
279-
nvidiaCDIHookPath string
280-
devRoot string
281-
devicesFrom Discover
278+
logger logger.Interface
279+
hookCreator HookCreator
280+
devRoot string
281+
devicesFrom Discover
282282
}
283283

284284
// newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer
285-
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, nvidiaCDIHookPath string) Discover {
285+
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator HookCreator) Discover {
286286
d := drmDevicesByPath{
287-
logger: logger,
288-
nvidiaCDIHookPath: nvidiaCDIHookPath,
289-
devRoot: devRoot,
290-
devicesFrom: devices,
287+
logger: logger,
288+
hookCreator: hookCreator,
289+
devRoot: devRoot,
290+
devicesFrom: devices,
291291
}
292292

293293
return &d
@@ -315,13 +315,9 @@ func (d drmDevicesByPath) Hooks() ([]Hook, error) {
315315
args = append(args, "--link", l)
316316
}
317317

318-
hook := CreateNvidiaCDIHook(
319-
d.nvidiaCDIHookPath,
320-
"create-symlinks",
321-
args...,
322-
)
318+
hook := d.hookCreator.Create("create-symlinks", args...)
323319

324-
return []Hook{hook}, nil
320+
return hook.Hooks()
325321
}
326322

327323
// getSpecificLinkArgs returns the required specific links that need to be created

internal/discover/graphics_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
func TestGraphicsLibrariesDiscoverer(t *testing.T) {
2727
logger, _ := testlog.NewNullLogger()
28+
hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook")
2829

2930
testCases := []struct {
3031
description string
@@ -136,9 +137,9 @@ func TestGraphicsLibrariesDiscoverer(t *testing.T) {
136137
for _, tc := range testCases {
137138
t.Run(tc.description, func(t *testing.T) {
138139
d := &graphicsDriverLibraries{
139-
Discover: tc.libraries,
140-
logger: logger,
141-
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
140+
Discover: tc.libraries,
141+
logger: logger,
142+
hookCreator: hookCreator,
142143
}
143144

144145
devices, err := d.Devices()

internal/discover/hooks.go

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,54 +25,66 @@ import (
2525
var _ Discover = (*Hook)(nil)
2626

2727
// Devices returns an empty list of devices for a Hook discoverer.
28-
func (h Hook) Devices() ([]Device, error) {
28+
func (h *Hook) Devices() ([]Device, error) {
2929
return nil, nil
3030
}
3131

3232
// Mounts returns an empty list of mounts for a Hook discoverer.
33-
func (h Hook) Mounts() ([]Mount, error) {
33+
func (h *Hook) Mounts() ([]Mount, error) {
3434
return nil, nil
3535
}
3636

3737
// Hooks allows the Hook type to also implement the Discoverer interface.
3838
// It returns a single hook
39-
func (h Hook) Hooks() ([]Hook, error) {
40-
return []Hook{h}, nil
39+
func (h *Hook) Hooks() ([]Hook, error) {
40+
if h == nil {
41+
return nil, nil
42+
}
43+
44+
return []Hook{*h}, nil
4145
}
4246

43-
// CreateCreateSymlinkHook creates a hook which creates a symlink from link -> target.
44-
func CreateCreateSymlinkHook(nvidiaCDIHookPath string, links []string) Discover {
45-
if len(links) == 0 {
46-
return None{}
47-
}
47+
// Option is a function that configures the nvcdilib
48+
type Option func(*CDIHook)
4849

49-
var args []string
50-
for _, link := range links {
51-
args = append(args, "--link", link)
52-
}
53-
return CreateNvidiaCDIHook(
54-
nvidiaCDIHookPath,
55-
"create-symlinks",
56-
args...,
57-
)
50+
type CDIHook struct {
51+
nvidiaCDIHookPath string
5852
}
5953

60-
// CreateNvidiaCDIHook creates a hook which invokes the NVIDIA Container CLI hook subcommand.
61-
func CreateNvidiaCDIHook(nvidiaCDIHookPath string, hookName string, additionalArgs ...string) Hook {
62-
return cdiHook(nvidiaCDIHookPath).Create(hookName, additionalArgs...)
54+
type HookCreator interface {
55+
Create(string, ...string) *Hook
6356
}
6457

65-
type cdiHook string
58+
func NewHookCreator(nvidiaCDIHookPath string) HookCreator {
59+
CDIHook := &CDIHook{
60+
nvidiaCDIHookPath: nvidiaCDIHookPath,
61+
}
6662

67-
func (c cdiHook) Create(name string, args ...string) Hook {
68-
return Hook{
63+
return CDIHook
64+
}
65+
66+
func (c CDIHook) Create(name string, args ...string) *Hook {
67+
if name == "create-symlinks" {
68+
if len(args) == 0 {
69+
return nil
70+
}
71+
72+
links := []string{}
73+
for _, arg := range args {
74+
links = append(links, "--link", arg)
75+
}
76+
args = links
77+
}
78+
79+
return &Hook{
6980
Lifecycle: cdi.CreateContainerHook,
70-
Path: string(c),
81+
Path: c.nvidiaCDIHookPath,
7182
Args: append(c.requiredArgs(name), args...),
7283
}
7384
}
74-
func (c cdiHook) requiredArgs(name string) []string {
75-
base := filepath.Base(string(c))
85+
86+
func (c CDIHook) requiredArgs(name string) []string {
87+
base := filepath.Base(c.nvidiaCDIHookPath)
7688
if base == "nvidia-ctk" {
7789
return []string{base, "hook", name}
7890
}

internal/discover/ldconfig.go

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,23 @@ import (
2525
)
2626

2727
// NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. A logger can also be specified
28-
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCDIHookPath, ldconfigPath string) (Discover, error) {
28+
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator HookCreator, ldconfigPath string) (Discover, error) {
2929
d := ldconfig{
30-
logger: logger,
31-
nvidiaCDIHookPath: nvidiaCDIHookPath,
32-
ldconfigPath: ldconfigPath,
33-
mountsFrom: mounts,
30+
logger: logger,
31+
hookCreator: hookCreator,
32+
ldconfigPath: ldconfigPath,
33+
mountsFrom: mounts,
3434
}
3535

3636
return &d, nil
3737
}
3838

3939
type ldconfig struct {
4040
None
41-
logger logger.Interface
42-
nvidiaCDIHookPath string
43-
ldconfigPath string
44-
mountsFrom Discover
41+
logger logger.Interface
42+
hookCreator HookCreator
43+
ldconfigPath string
44+
mountsFrom Discover
4545
}
4646

4747
// Hooks checks the required mounts for libraries and returns a hook to update the LDcache for the discovered paths.
@@ -50,16 +50,18 @@ func (d ldconfig) Hooks() ([]Hook, error) {
5050
if err != nil {
5151
return nil, fmt.Errorf("failed to discover mounts for ldcache update: %v", err)
5252
}
53-
h := CreateLDCacheUpdateHook(
54-
d.nvidiaCDIHookPath,
53+
54+
h := createLDCacheUpdateHook(
55+
d.hookCreator,
5556
d.ldconfigPath,
5657
getLibraryPaths(mounts),
5758
)
58-
return []Hook{h}, nil
59+
60+
return h.Hooks()
5961
}
6062

61-
// CreateLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache
62-
func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []string) Hook {
63+
// createLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache
64+
func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries []string) *Hook {
6365
var args []string
6466

6567
if ldconfig != "" {
@@ -70,13 +72,7 @@ func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []str
7072
args = append(args, "--folder", f)
7173
}
7274

73-
hook := CreateNvidiaCDIHook(
74-
executable,
75-
"update-ldcache",
76-
args...,
77-
)
78-
79-
return hook
75+
return hookCreator.Create("update-ldcache", args...)
8076
}
8177

8278
// getLibraryPaths extracts the library dirs from the specified mounts

internal/discover/ldconfig_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ const (
3131

3232
func TestLDCacheUpdateHook(t *testing.T) {
3333
logger, _ := testlog.NewNullLogger()
34+
hookCreator := NewHookCreator(testNvidiaCDIHookPath)
3435

3536
testCases := []struct {
3637
description string
@@ -97,7 +98,7 @@ func TestLDCacheUpdateHook(t *testing.T) {
9798
Lifecycle: "createContainer",
9899
}
99100

100-
d, err := NewLDCacheUpdateHook(logger, mountMock, testNvidiaCDIHookPath, tc.ldconfigPath)
101+
d, err := NewLDCacheUpdateHook(logger, mountMock, hookCreator, tc.ldconfigPath)
101102
require.NoError(t, err)
102103

103104
hooks, err := d.Hooks()

internal/discover/symlinks.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,20 @@ import (
2323

2424
type additionalSymlinks struct {
2525
Discover
26-
version string
27-
nvidiaCDIHookPath string
26+
version string
27+
hookCreator HookCreator
2828
}
2929

3030
// WithDriverDotSoSymlinks decorates the provided discoverer.
3131
// A hook is added that checks for specific driver symlinks that need to be created.
32-
func WithDriverDotSoSymlinks(mounts Discover, version string, nvidiaCDIHookPath string) Discover {
32+
func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator HookCreator) Discover {
3333
if version == "" {
3434
version = "*.*"
3535
}
3636
return &additionalSymlinks{
37-
Discover: mounts,
38-
nvidiaCDIHookPath: nvidiaCDIHookPath,
39-
version: version,
37+
Discover: mounts,
38+
hookCreator: hookCreator,
39+
version: version,
4040
}
4141
}
4242

@@ -73,8 +73,12 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) {
7373
return hooks, nil
7474
}
7575

76-
hook := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links).(Hook)
77-
return append(hooks, hook), nil
76+
createSymlinkHooks, err := d.hookCreator.Create("create-symlinks", links...).Hooks()
77+
if err != nil {
78+
return nil, fmt.Errorf("failed to create symlink hook: %v", err)
79+
}
80+
81+
return append(hooks, createSymlinkHooks...), nil
7882
}
7983

8084
// getLinksForMount maps the path to created links if any.

internal/discover/symlinks_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,13 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
306306
},
307307
}
308308

309+
hookCreator := NewHookCreator("/path/to/nvidia-cdi-hook")
309310
for _, tc := range testCases {
310311
t.Run(tc.description, func(t *testing.T) {
311312
d := WithDriverDotSoSymlinks(
312313
tc.discover,
313314
tc.version,
314-
"/path/to/nvidia-cdi-hook",
315+
hookCreator,
315316
)
316317

317318
devices, err := d.Devices()

0 commit comments

Comments
 (0)