@@ -21,6 +21,7 @@ import (
2121 "fmt"
2222 "os"
2323 "path/filepath"
24+ "slices"
2425 "strings"
2526 "sync"
2627
@@ -43,8 +44,8 @@ type Driver struct {
4344
4445 // version caches the driver version.
4546 version string
46- // driverLibDirectory caches the path to parent of the driver libraries
47- driverLibDirectory string
47+ // driverLibDirectories caches the paths to parent of the driver libraries
48+ driverLibDirectories [] string
4849}
4950
5051// New creates a new Driver root using the specified options.
@@ -70,13 +71,13 @@ func New(opts ...Option) *Driver {
7071 }
7172
7273 d := & Driver {
73- logger : o .logger ,
74- Root : o .Root ,
75- DevRoot : o .DevRoot ,
76- librarySearchPaths : o .librarySearchPaths ,
77- configSearchPaths : o .configSearchPaths ,
78- version : driverVersion ,
79- driverLibDirectory : "" ,
74+ logger : o .logger ,
75+ Root : o .Root ,
76+ DevRoot : o .DevRoot ,
77+ librarySearchPaths : o .librarySearchPaths ,
78+ configSearchPaths : o .configSearchPaths ,
79+ version : driverVersion ,
80+ driverLibDirectories : nil ,
8081 }
8182
8283 return d
@@ -97,34 +98,36 @@ func (r *Driver) Version() (string, error) {
9798 return r .version , nil
9899}
99100
100- // GetDriverLibDirectory returns the cached directory where the driver libs are
101- // found if possible.
101+ // GetDriverLibDirectories returns the cached directories where the driver libs
102+ // are found if possible.
102103// If this has not yet been initialized, the path is first detected and then returned.
103- func (r * Driver ) GetDriverLibDirectory () (string , error ) {
104+ func (r * Driver ) GetDriverLibDirectories () ([] string , error ) {
104105 r .Lock ()
105106 defer r .Unlock ()
106107
107- if r . driverLibDirectory == "" {
108+ if len ( r . driverLibDirectories ) == 0 {
108109 if err := r .updateInfo (); err != nil {
109- return "" , err
110+ return nil , err
110111 }
111112 }
112113
113- return r .driverLibDirectory , nil
114+ return r .driverLibDirectories , nil
114115}
115116
116117func (r * Driver ) DriverLibraryLocator (additionalDirs ... string ) (lookup.Locator , error ) {
117- libcudasoParentDirPath , err := r .GetDriverLibDirectory ()
118+ libcudasoParentDirPaths , err := r .GetDriverLibDirectories ()
118119 if err != nil {
119120 return nil , fmt .Errorf ("failed to get libcuda.so parent directory: %w" , err )
120121 }
121122
122- searchPaths := [] string { libcudasoParentDirPath }
123+ searchPaths := slices . Clone ( libcudasoParentDirPaths )
123124 for _ , dir := range additionalDirs {
124125 if strings .HasPrefix (dir , "/" ) {
125126 searchPaths = append (searchPaths , dir )
126127 } else {
127- searchPaths = append (searchPaths , filepath .Join (libcudasoParentDirPath , dir ))
128+ for _ , libcudasoParentDirPath := range libcudasoParentDirPaths {
129+ searchPaths = append (searchPaths , filepath .Join (libcudasoParentDirPath , dir ))
130+ }
128131 }
129132 }
130133
@@ -141,16 +144,33 @@ func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator,
141144}
142145
143146func (r * Driver ) updateInfo () error {
144- driverLibPath , version , err := r .inferVersion ()
147+ _ , version , err := r .inferVersion ()
145148 if err != nil {
146149 return err
147150 }
148151 if r .version != "" && r .version != version {
149152 return fmt .Errorf ("unexpected version detected: %v != %v" , r .version , version )
150153 }
151154
155+ versionedDriverLibPaths , err := r .Libraries ().Locate ("lib*.so." + version )
156+ if err != nil {
157+ return fmt .Errorf ("failed to locate versioned driver libraries: %w" , err )
158+ }
159+
160+ var uniqueDirs []string
161+ seen := make (map [string ]bool )
162+
163+ for _ , path := range versionedDriverLibPaths {
164+ dir := filepath .Dir (path )
165+ if seen [dir ] {
166+ continue
167+ }
168+ seen [dir ] = true
169+ uniqueDirs = append (uniqueDirs , r .RelativeToRoot (dir ))
170+ }
171+
152172 r .version = version
153- r .driverLibDirectory = r . RelativeToRoot ( filepath . Dir ( driverLibPath ))
173+ r .driverLibDirectories = uniqueDirs
154174
155175 return nil
156176}
@@ -167,7 +187,7 @@ func (r *Driver) inferVersion() (string, string, error) {
167187 for _ , driverLib := range []string {"libcuda.so." , "libnvidia-ml.so." } {
168188 driverLibPaths , err := r .Libraries ().Locate (driverLib + versionSuffix )
169189 if err != nil {
170- errs = errors .Join (errs , fmt .Errorf ("failed to locate libcuda.so : %w" , err ))
190+ errs = errors .Join (errs , fmt .Errorf ("failed to locate %q : %w" , driverLib , err ))
171191 continue
172192 }
173193 driverLibPath := driverLibPaths [0 ]
0 commit comments