Skip to content

Commit 2f36148

Browse files
committed
Improve error handling around various Win32 APIs
Audited the codebase for suspicious calls and improved the error handling.
1 parent 19cdeaa commit 2f36148

File tree

4 files changed

+49
-31
lines changed

4 files changed

+49
-31
lines changed

Sources/SWBUtil/Library.swift

+39-7
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,18 @@ public enum Library: Sendable {
6464
}
6565

6666
public static func locate<T>(_ pointer: T.Type) -> Path {
67-
let outPointer: UnsafeMutablePointer<CInterop.PlatformChar>
6867
#if os(Windows)
6968
var handle: HMODULE?
70-
GetModuleHandleExW(DWORD(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT), unsafeBitCast(pointer, to: LPCWSTR?.self), &handle)
71-
let capacity = 260
72-
outPointer = .allocate(capacity: capacity)
73-
defer { outPointer.deallocate() }
74-
GetModuleFileNameW(handle, outPointer, DWORD(capacity))
69+
guard GetModuleHandleExW(DWORD(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT), unsafeBitCast(pointer, to: LPCWSTR?.self), &handle) else {
70+
return Path("")
71+
}
72+
do {
73+
return try Path(SWB_GetModuleFileNameW(handle))
74+
} catch {
75+
return Path("")
76+
}
7577
#else
78+
let outPointer: UnsafeMutablePointer<CInterop.PlatformChar>
7679
var info = Dl_info()
7780
#if os(Android)
7881
dladdr(unsafeBitCast(pointer, to: UnsafeMutableRawPointer.self), &info)
@@ -81,8 +84,8 @@ public enum Library: Sendable {
8184
dladdr(unsafeBitCast(pointer, to: UnsafeMutableRawPointer?.self), &info)
8285
outPointer = UnsafeMutablePointer(mutating: info.dli_fname)
8386
#endif
84-
#endif
8587
return Path(platformString: outPointer)
88+
#endif
8689
}
8790
}
8891

@@ -114,3 +117,32 @@ public struct LibraryHandle: @unchecked Sendable {
114117
self.rawValue = rawValue
115118
}
116119
}
120+
121+
#if os(Windows)
122+
@_spi(Testing) public func SWB_GetModuleFileNameW(_ hModule: HMODULE?) throws -> String {
123+
#if DEBUG
124+
var bufferCount = Int(1) // force looping
125+
#else
126+
var bufferCount = Int(MAX_PATH)
127+
#endif
128+
while true {
129+
if let result = try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: bufferCount, { buffer in
130+
switch (GetModuleFileNameW(hModule, buffer.baseAddress!, DWORD(buffer.count)), GetLastError()) {
131+
case (1..<DWORD(bufferCount), DWORD(ERROR_SUCCESS)):
132+
guard let result = String.decodeCString(buffer.baseAddress!, as: UTF16.self)?.result else {
133+
throw Win32Error(DWORD(ERROR_ILLEGAL_CHARACTER))
134+
}
135+
return result
136+
case (DWORD(bufferCount), DWORD(ERROR_INSUFFICIENT_BUFFER)):
137+
bufferCount += Int(MAX_PATH)
138+
return nil
139+
case (_, let errorCode):
140+
throw Win32Error(errorCode)
141+
}
142+
}) {
143+
return result
144+
}
145+
}
146+
preconditionFailure("unreachable")
147+
}
148+
#endif

Sources/SWBUtil/POSIX.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public enum POSIX: Sendable {
3535
}
3636
return try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(dwLength)) {
3737
switch GetEnvironmentVariableW(wName, $0.baseAddress!, DWORD($0.count)) {
38-
case dwLength - 1:
38+
case 1..<dwLength:
3939
return String(decodingCString: $0.baseAddress!, as: CInterop.PlatformUnicodeEncoding.self)
4040
case 0 where GetLastError() == ERROR_ENVVAR_NOT_FOUND:
4141
return nil

Sources/SWBUtil/ProcessInfo.swift

+4-2
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ extension ProcessInfo {
6262
var capacity = UNLEN + 1
6363
let pointer = UnsafeMutablePointer<CInterop.PlatformChar>.allocate(capacity: Int(capacity))
6464
defer { pointer.deallocate() }
65-
GetUserNameW(pointer, &capacity)
66-
return String(platformString: pointer)
65+
if GetUserNameW(pointer, &capacity) {
66+
return String(platformString: pointer)
67+
}
68+
return ""
6769
#else
6870
let uid = geteuid().orIfZero(getuid())
6971
return (getpwuid(uid)?.pointee.pw_name).map { String(cString: $0) } ?? String(uid)

Tests/SwiftBuildTests/ConsoleCommands/CLIConnection.swift

+5-21
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import Foundation
1414
import SWBTestSupport
15-
import SWBUtil
15+
@_spi(Testing) import SWBUtil
1616
import SWBCore
1717
import SWBLibc
1818
import SwiftBuild
@@ -319,25 +319,7 @@ fileprivate func swiftRuntimePath() throws -> Path? {
319319
guard let handle = GetModuleHandleW(wName) else {
320320
throw Win32Error(GetLastError())
321321
}
322-
323-
var capacity = MAX_PATH
324-
var path = ""
325-
while path.isEmpty {
326-
try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(capacity)) {
327-
let dwLength = GetModuleFileNameW(handle, $0.baseAddress!, DWORD($0.count))
328-
switch dwLength {
329-
case 0:
330-
throw Win32Error(GetLastError())
331-
default:
332-
if GetLastError() == ERROR_INSUFFICIENT_BUFFER {
333-
capacity *= 2
334-
} else {
335-
path = String(decodingCString: $0.baseAddress!, as: CInterop.PlatformUnicodeEncoding.self)
336-
}
337-
}
338-
}
339-
}
340-
return Path(path).dirname
322+
return try Path(SWB_GetModuleFileNameW(handle)).dirname
341323
}
342324
#else
343325
return nil
@@ -352,10 +334,12 @@ fileprivate func systemRoot() throws -> Path? {
352334
}
353335
return try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(dwLength)) {
354336
switch GetWindowsDirectoryW($0.baseAddress!, DWORD($0.count)) {
337+
case 1..<dwLength:
338+
return Path(String(decodingCString: $0.baseAddress!, as: CInterop.PlatformUnicodeEncoding.self))
355339
case 0:
356340
throw Win32Error(GetLastError())
357341
default:
358-
return Path(String(decodingCString: $0.baseAddress!, as: CInterop.PlatformUnicodeEncoding.self))
342+
throw Win32Error(DWORD(ERROR_INSUFFICIENT_BUFFER))
359343
}
360344
}
361345
#else

0 commit comments

Comments
 (0)