Skip to content

Commit ff90fb7

Browse files
authored
Cuda visible devices logic fix #12 (#13)
* Fixed an error regarding absence of CUDA_VISIBLE_DEVICES * Fixed values in array with occupied devices
1 parent 3f1bc8a commit ff90fb7

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

src/device_selection.cu

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <vector>
33
#include <string>
44
#include <sstream>
5+
#include <utility>
56
#include <unistd.h>
67

78
#include <nvml.h>
@@ -82,44 +83,52 @@ vector<int> readCudaVisibleDevices() {
8283
return res;
8384
}
8485

85-
vector<int> getAvailableDevices() {
86+
pair<vector<int>, bool> getAvailableDevices() {
8687
vector<int> visibleDevices = readCudaVisibleDevices();
8788

89+
// cout << ">>>>> visibleDevices.empty()" << visibleDevices.empty() << endl;
90+
8891
if (visibleDevices.empty())
89-
return getAllPhysicallyAvailableDevices();
92+
return make_pair(getAllPhysicallyAvailableDevices(), false);
9093

91-
return visibleDevices;
94+
return make_pair(visibleDevices, true);
9295
}
9396

9497
extern "C"
95-
int occupyDevices(int requestedDevicesCount, int * occupiedDevicesIdxs, char * errorMsg) {
98+
int occupyDevices(int requestedDevicesCount, int * occupiedDevicesIdxs, char * errorMsgOut) {
9699
try {
97-
vector<int> availableDevcices = getAvailableDevices();
100+
auto availableDevcicesPair = getAvailableDevices();
101+
auto availableDevices = availableDevcicesPair.first;
102+
bool cudaVisibleDevciesSetProperly = availableDevcicesPair.second;
98103

99-
if ((int)availableDevcices.size() < requestedDevicesCount) {
104+
if ((int)availableDevices.size() < requestedDevicesCount) {
100105
string msg = "There are not as many free devices as requested. Requested devices count: "
101-
+ to_string(requestedDevicesCount) + ". Available devices count: " + to_string(availableDevcices.size()) + ".";
106+
+ to_string(requestedDevicesCount) + ". Available devices count: " + to_string(availableDevices.size()) + ".";
102107

103-
memcpy(errorMsg, msg.c_str(), msg.length());
108+
memcpy(errorMsgOut, msg.c_str(), msg.length());
104109

105110
return -1;
106111
}
107112

108113
int nextDeviceIdx = 0;
109114
for (int i = 0; i < requestedDevicesCount; i++) {
110-
int deviceIdx = availableDevcices[i];
111-
gpuErrchk( cudaSetDevice(i), deviceIdx, errorMsg );
115+
int deviceIdx = -1;
116+
if (cudaVisibleDevciesSetProperly)
117+
deviceIdx = i;
118+
else
119+
deviceIdx = availableDevices[i];
120+
gpuErrchk( cudaSetDevice(deviceIdx), deviceIdx, errorMsgOut );
112121

113122
//call some API functions to really occupy device (I'm lazy to look for more elegant way to do it)
114123
char * ddata;
115-
gpuErrchk( cudaMalloc(&ddata, 1), deviceIdx, errorMsg );
116-
gpuErrchk( cudaFree(ddata), deviceIdx, errorMsg );
124+
gpuErrchk( cudaMalloc(&ddata, 1), deviceIdx, errorMsgOut );
125+
gpuErrchk( cudaFree(ddata), deviceIdx, errorMsgOut );
117126

118-
occupiedDevicesIdxs[nextDeviceIdx++] = i;
127+
occupiedDevicesIdxs[nextDeviceIdx++] = deviceIdx;
119128
}
120129
} catch (const std::exception& e) {
121130
auto msg = string(e.what());
122-
memcpy(errorMsg, msg.c_str(), msg.length());
131+
memcpy(errorMsgOut, msg.c_str(), msg.length());
123132
return -1;
124133
}
125134

0 commit comments

Comments
 (0)