Skip to content

Refactor the way we handle Hook Creation #1090

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
*.swo
/coverage.out*
/tests/output/
/nvidia-container-runtime
/nvidia-container-runtime.*
/nvidia-container-runtime-hook
/nvidia-container-toolkit
/nvidia-ctk
/nvidia-*
/shared-*
/release-*
/bin
/toolkit-test
8 changes: 2 additions & 6 deletions internal/discover/compat_libs.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@ import (

// NewCUDACompatHookDiscoverer creates a discoverer for a enable-cuda-compat hook.
// This hook is responsible for setting up CUDA compatibility in the container and depends on the host driver version.
func NewCUDACompatHookDiscoverer(logger logger.Interface, nvidiaCDIHookPath string, driver *root.Driver) Discover {
func NewCUDACompatHookDiscoverer(logger logger.Interface, hookCreator HookCreator, driver *root.Driver) Discover {
_, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)
var args []string
if !strings.Contains(cudaVersionPattern, "*") {
args = append(args, "--host-driver-version="+cudaVersionPattern)
}

return CreateNvidiaCDIHook(
nvidiaCDIHookPath,
"enable-cuda-compat",
args...,
)
return hookCreator.Create("enable-cuda-compat", args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for a follow-up: Create should be modified to accept a HookName instead of arbitrary strings.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but we would need to define the HookName in a new Package, as pkg/nvcdi calls internal/discover , so we can not reference pkg/nvcdi from internal/discover without running into an import loop.
I think...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now have a better idea for this after our conversation, this will be addressed in a followup PR

}
50 changes: 23 additions & 27 deletions internal/discover/graphics.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,21 @@ import (
// TODO: The logic for creating DRM devices should be consolidated between this
// and the logic for generating CDI specs for a single device. This is only used
// when applying OCI spec modifications to an incoming spec in "legacy" mode.
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, nvidiaCDIHookPath string) (Discover, error) {
func NewDRMNodesDiscoverer(logger logger.Interface, devices image.VisibleDevices, devRoot string, hookCreator HookCreator) (Discover, error) {
drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, devices, devRoot)
if err != nil {
return nil, fmt.Errorf("failed to create DRM device discoverer: %v", err)
}

drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, nvidiaCDIHookPath)
drmByPathSymlinks := newCreateDRMByPathSymlinks(logger, drmDeviceNodes, devRoot, hookCreator)

discover := Merge(drmDeviceNodes, drmByPathSymlinks)
return discover, nil
}

// NewGraphicsMountsDiscoverer creates a discoverer for the mounts required by graphics tools such as vulkan.
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) (Discover, error) {
libraries := newGraphicsLibrariesDiscoverer(logger, driver, nvidiaCDIHookPath)
func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) (Discover, error) {
libraries := newGraphicsLibrariesDiscoverer(logger, driver, hookCreator)

configs := NewMounts(
logger,
Expand Down Expand Up @@ -95,13 +95,13 @@ func newVulkanConfigsDiscover(logger logger.Interface, driver *root.Driver) Disc

type graphicsDriverLibraries struct {
Discover
logger logger.Interface
nvidiaCDIHookPath string
logger logger.Interface
hookCreator HookCreator
}

var _ Discover = (*graphicsDriverLibraries)(nil)

func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string) Discover {
func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver, hookCreator HookCreator) Discover {
cudaLibRoot, cudaVersionPattern := getCUDALibRootAndVersionPattern(logger, driver)

libraries := NewMounts(
Expand Down Expand Up @@ -140,9 +140,9 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver
)

return &graphicsDriverLibraries{
Discover: Merge(libraries, xorgLibraries),
logger: logger,
nvidiaCDIHookPath: nvidiaCDIHookPath,
Discover: Merge(libraries, xorgLibraries),
logger: logger,
hookCreator: hookCreator,
}
}

Expand Down Expand Up @@ -203,9 +203,9 @@ func (d graphicsDriverLibraries) Hooks() ([]Hook, error) {
return nil, nil
}

hooks := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links)
hook := d.hookCreator.Create("create-symlinks", links...)

return hooks.Hooks()
return hook.Hooks()
}

// isDriverLibrary checks whether the specified filename is a specific driver library.
Expand Down Expand Up @@ -275,19 +275,19 @@ func buildXOrgSearchPaths(libRoot string) []string {

type drmDevicesByPath struct {
None
logger logger.Interface
nvidiaCDIHookPath string
devRoot string
devicesFrom Discover
logger logger.Interface
hookCreator HookCreator
devRoot string
devicesFrom Discover
}

// newCreateDRMByPathSymlinks creates a discoverer for a hook to create the by-path symlinks for DRM devices discovered by the specified devices discoverer
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, nvidiaCDIHookPath string) Discover {
func newCreateDRMByPathSymlinks(logger logger.Interface, devices Discover, devRoot string, hookCreator HookCreator) Discover {
d := drmDevicesByPath{
logger: logger,
nvidiaCDIHookPath: nvidiaCDIHookPath,
devRoot: devRoot,
devicesFrom: devices,
logger: logger,
hookCreator: hookCreator,
devRoot: devRoot,
devicesFrom: devices,
}

return &d
Expand Down Expand Up @@ -315,13 +315,9 @@ func (d drmDevicesByPath) Hooks() ([]Hook, error) {
args = append(args, "--link", l)
}

hook := CreateNvidiaCDIHook(
d.nvidiaCDIHookPath,
"create-symlinks",
args...,
)
hook := d.hookCreator.Create("create-symlinks", args...)

return []Hook{hook}, nil
return hook.Hooks()
}

// getSpecificLinkArgs returns the required specific links that need to be created
Expand Down
7 changes: 4 additions & 3 deletions internal/discover/graphics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

func TestGraphicsLibrariesDiscoverer(t *testing.T) {
logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook")

testCases := []struct {
description string
Expand Down Expand Up @@ -136,9 +137,9 @@ func TestGraphicsLibrariesDiscoverer(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
d := &graphicsDriverLibraries{
Discover: tc.libraries,
logger: logger,
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
Discover: tc.libraries,
logger: logger,
hookCreator: hookCreator,
}

devices, err := d.Devices()
Expand Down
66 changes: 39 additions & 27 deletions internal/discover/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,54 +25,66 @@ import (
var _ Discover = (*Hook)(nil)

// Devices returns an empty list of devices for a Hook discoverer.
func (h Hook) Devices() ([]Device, error) {
func (h *Hook) Devices() ([]Device, error) {
return nil, nil
}

// Mounts returns an empty list of mounts for a Hook discoverer.
func (h Hook) Mounts() ([]Mount, error) {
func (h *Hook) Mounts() ([]Mount, error) {
return nil, nil
}

// Hooks allows the Hook type to also implement the Discoverer interface.
// It returns a single hook
func (h Hook) Hooks() ([]Hook, error) {
return []Hook{h}, nil
func (h *Hook) Hooks() ([]Hook, error) {
if h == nil {
return nil, nil
}

return []Hook{*h}, nil
}

// CreateCreateSymlinkHook creates a hook which creates a symlink from link -> target.
func CreateCreateSymlinkHook(nvidiaCDIHookPath string, links []string) Discover {
if len(links) == 0 {
return None{}
}
// Option is a function that configures the nvcdilib
type Option func(*CDIHook)

var args []string
for _, link := range links {
args = append(args, "--link", link)
}
return CreateNvidiaCDIHook(
nvidiaCDIHookPath,
"create-symlinks",
args...,
)
type CDIHook struct {
nvidiaCDIHookPath string
}

// CreateNvidiaCDIHook creates a hook which invokes the NVIDIA Container CLI hook subcommand.
func CreateNvidiaCDIHook(nvidiaCDIHookPath string, hookName string, additionalArgs ...string) Hook {
return cdiHook(nvidiaCDIHookPath).Create(hookName, additionalArgs...)
type HookCreator interface {
Create(string, ...string) *Hook
}

type cdiHook string
func NewHookCreator(nvidiaCDIHookPath string) HookCreator {
CDIHook := &CDIHook{
nvidiaCDIHookPath: nvidiaCDIHookPath,
}

func (c cdiHook) Create(name string, args ...string) Hook {
return Hook{
return CDIHook
}

func (c CDIHook) Create(name string, args ...string) *Hook {
if name == "create-symlinks" {
if len(args) == 0 {
return nil
}

links := []string{}
for _, arg := range args {
links = append(links, "--link", arg)
}
args = links
}

return &Hook{
Lifecycle: cdi.CreateContainerHook,
Path: string(c),
Path: c.nvidiaCDIHookPath,
Args: append(c.requiredArgs(name), args...),
}
}
func (c cdiHook) requiredArgs(name string) []string {
base := filepath.Base(string(c))

func (c CDIHook) requiredArgs(name string) []string {
base := filepath.Base(c.nvidiaCDIHookPath)
if base == "nvidia-ctk" {
return []string{base, "hook", name}
}
Expand Down
38 changes: 17 additions & 21 deletions internal/discover/ldconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,23 @@ import (
)

// NewLDCacheUpdateHook creates a discoverer that updates the ldcache for the specified mounts. A logger can also be specified
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, nvidiaCDIHookPath, ldconfigPath string) (Discover, error) {
func NewLDCacheUpdateHook(logger logger.Interface, mounts Discover, hookCreator HookCreator, ldconfigPath string) (Discover, error) {
d := ldconfig{
logger: logger,
nvidiaCDIHookPath: nvidiaCDIHookPath,
ldconfigPath: ldconfigPath,
mountsFrom: mounts,
logger: logger,
hookCreator: hookCreator,
ldconfigPath: ldconfigPath,
mountsFrom: mounts,
}

return &d, nil
}

type ldconfig struct {
None
logger logger.Interface
nvidiaCDIHookPath string
ldconfigPath string
mountsFrom Discover
logger logger.Interface
hookCreator HookCreator
ldconfigPath string
mountsFrom Discover
}

// Hooks checks the required mounts for libraries and returns a hook to update the LDcache for the discovered paths.
Expand All @@ -50,16 +50,18 @@ func (d ldconfig) Hooks() ([]Hook, error) {
if err != nil {
return nil, fmt.Errorf("failed to discover mounts for ldcache update: %v", err)
}
h := CreateLDCacheUpdateHook(
d.nvidiaCDIHookPath,

h := createLDCacheUpdateHook(
d.hookCreator,
d.ldconfigPath,
getLibraryPaths(mounts),
)
return []Hook{h}, nil

return h.Hooks()
}

// CreateLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache
func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []string) Hook {
// createLDCacheUpdateHook locates the NVIDIA Container Toolkit CLI and creates a hook for updating the LD Cache
func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries []string) *Hook {
var args []string

if ldconfig != "" {
Expand All @@ -70,13 +72,7 @@ func CreateLDCacheUpdateHook(executable string, ldconfig string, libraries []str
args = append(args, "--folder", f)
}

hook := CreateNvidiaCDIHook(
executable,
"update-ldcache",
args...,
)

return hook
return hookCreator.Create("update-ldcache", args...)
}

// getLibraryPaths extracts the library dirs from the specified mounts
Expand Down
3 changes: 2 additions & 1 deletion internal/discover/ldconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const (

func TestLDCacheUpdateHook(t *testing.T) {
logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator(testNvidiaCDIHookPath)

testCases := []struct {
description string
Expand Down Expand Up @@ -97,7 +98,7 @@ func TestLDCacheUpdateHook(t *testing.T) {
Lifecycle: "createContainer",
}

d, err := NewLDCacheUpdateHook(logger, mountMock, testNvidiaCDIHookPath, tc.ldconfigPath)
d, err := NewLDCacheUpdateHook(logger, mountMock, hookCreator, tc.ldconfigPath)
require.NoError(t, err)

hooks, err := d.Hooks()
Expand Down
20 changes: 12 additions & 8 deletions internal/discover/symlinks.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@ import (

type additionalSymlinks struct {
Discover
version string
nvidiaCDIHookPath string
version string
hookCreator HookCreator
}

// WithDriverDotSoSymlinks decorates the provided discoverer.
// A hook is added that checks for specific driver symlinks that need to be created.
func WithDriverDotSoSymlinks(mounts Discover, version string, nvidiaCDIHookPath string) Discover {
func WithDriverDotSoSymlinks(mounts Discover, version string, hookCreator HookCreator) Discover {
if version == "" {
version = "*.*"
}
return &additionalSymlinks{
Discover: mounts,
nvidiaCDIHookPath: nvidiaCDIHookPath,
version: version,
Discover: mounts,
hookCreator: hookCreator,
version: version,
}
}

Expand Down Expand Up @@ -73,8 +73,12 @@ func (d *additionalSymlinks) Hooks() ([]Hook, error) {
return hooks, nil
}

hook := CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links).(Hook)
return append(hooks, hook), nil
createSymlinkHooks, err := d.hookCreator.Create("create-symlinks", links...).Hooks()
if err != nil {
return nil, fmt.Errorf("failed to create symlink hook: %v", err)
}

return append(hooks, createSymlinkHooks...), nil
}

// getLinksForMount maps the path to created links if any.
Expand Down
Loading