Skip to content

Commit d3cd1df

Browse files
committed
[cudacompat] update version comparision to account for leading zeros in version strings
Any version string passed as an argument to semver.Compare() must be a valid semantic version. It is common for NVIDIA driver versions to have leading zeros in the MINOR or PATCH portion of a version string, e.g. 575.57.08. As a result, a call to semver.Compare("575.57.08", "575.10.10") would incorrectly return -1 because the first argument is not a valid semantic version. And from https://pkg.go.dev/golang.org/x/mod/semver#Compare: ''' An invalid semantic version string is considered less than a valid one. All invalid semantic version strings compare equal to each other. ''' Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
1 parent 60b67f2 commit d3cd1df

2 files changed

Lines changed: 112 additions & 3 deletions

File tree

cmd/nvidia-cdi-hook/cudacompat/cuda-elf-header.go

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ func (h *compatElfHeader) UseCompat(compatDriverVersion string, hostDriverVersio
143143
return false
144144
}
145145

146-
return semver.Compare(normalizeVersion(compatDriverVersion), normalizeVersion(hostDriverVersion)) > 0
146+
return compareVersions(compatDriverVersion, hostDriverVersion) > 0
147147
}
148148

149149
type cudaVersion string
@@ -155,9 +155,28 @@ func (containerVersion cudaVersion) UseCompat(hostVersion string) bool {
155155
return false
156156
}
157157

158-
return semver.Compare(normalizeVersion(containerVersion), normalizeVersion(hostVersion)) > 0
158+
return compareVersions(containerVersion, hostVersion) > 0
159159
}
160160

161+
func compareVersions[T string | cudaVersion, O string | cudaVersion](this T, other O) int {
162+
return semver.Compare(normalizeVersion(this), normalizeVersion(other))
163+
}
164+
165+
// normalizeVersion converts the given version into a valid semantic version.
166+
// This function will always return a string in the format of vMAJOR.MINOR.PATCH
167+
// It accounts for version strings that have leading zeros, which is common
168+
// in NVIDIA driver version strings. For example, 570.211.01 will be converted to
169+
// v570.22.1
161170
func normalizeVersion[T string | cudaVersion](v T) string {
162-
return "v" + strings.TrimPrefix(string(v), "v")
171+
majorMinorPatch := []string{"0", "0", "0"}
172+
versionParts := strings.SplitN(strings.TrimPrefix(string(v), "v"), ".", 3)
173+
for i, versionPart := range versionParts {
174+
trimmed := strings.TrimLeft(versionPart, "0")
175+
if trimmed == "" {
176+
trimmed = "0"
177+
}
178+
majorMinorPatch[i] = trimmed
179+
}
180+
181+
return "v" + strings.Join(majorMinorPatch, ".")
163182
}

cmd/nvidia-cdi-hook/cudacompat/cuda-elf-header_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,93 @@ func TestUseCompat(t *testing.T) {
183183
})
184184
}
185185
}
186+
187+
func TestCompareVersions(t *testing.T) {
188+
testCases := []struct {
189+
description string
190+
a string
191+
b string
192+
expected int
193+
}{
194+
{
195+
description: "empty",
196+
expected: 0,
197+
},
198+
{
199+
description: "less than",
200+
a: "1.2.3",
201+
b: "2.4.5",
202+
expected: -1,
203+
},
204+
{
205+
description: "equal",
206+
a: "1.1.1",
207+
b: "1.1.1",
208+
expected: 0,
209+
},
210+
{
211+
description: "equal with leading zeros in version string",
212+
a: "1.1.1",
213+
b: "1.01.1",
214+
expected: 0,
215+
},
216+
{
217+
description: "greater than",
218+
a: "2.4.5",
219+
b: "2.4.4",
220+
expected: 1,
221+
},
222+
}
223+
for _, tc := range testCases {
224+
t.Run(tc.description, func(t *testing.T) {
225+
require.EqualValues(t, tc.expected, compareVersions(tc.a, tc.b))
226+
})
227+
}
228+
229+
}
230+
231+
func TestNormalizeVersion(t *testing.T) {
232+
testCases := []struct {
233+
description string
234+
input string
235+
expected string
236+
}{
237+
{
238+
description: "empty",
239+
input: "",
240+
expected: "v0.0.0",
241+
},
242+
{
243+
description: "major is 0",
244+
input: "v0.1.2",
245+
expected: "v0.1.2",
246+
},
247+
{
248+
description: "major only",
249+
input: "1",
250+
expected: "v1.0.0",
251+
},
252+
{
253+
description: "major and minor only",
254+
input: "1.1",
255+
expected: "v1.1.0",
256+
},
257+
{
258+
description: "zero-padded version",
259+
input: "01.02.03",
260+
expected: "v1.2.3",
261+
},
262+
{
263+
description: "valid semantic version",
264+
input: "v1.2.3-4+567",
265+
expected: "v1.2.3-4+567",
266+
},
267+
}
268+
269+
for _, tc := range testCases {
270+
t.Run(tc.description, func(t *testing.T) {
271+
output := normalizeVersion(tc.input)
272+
require.EqualValues(t, tc.expected, output)
273+
})
274+
}
275+
}

0 commit comments

Comments
 (0)