|
2 | 2 | #include <vector>
|
3 | 3 | #include <string>
|
4 | 4 | #include <sstream>
|
| 5 | +#include <utility> |
5 | 6 | #include <unistd.h>
|
6 | 7 |
|
7 | 8 | #include <nvml.h>
|
@@ -82,44 +83,52 @@ vector<int> readCudaVisibleDevices() {
|
82 | 83 | return res;
|
83 | 84 | }
|
84 | 85 |
|
85 |
| -vector<int> getAvailableDevices() { |
| 86 | +pair<vector<int>, bool> getAvailableDevices() { |
86 | 87 | vector<int> visibleDevices = readCudaVisibleDevices();
|
87 | 88 |
|
| 89 | + // cout << ">>>>> visibleDevices.empty()" << visibleDevices.empty() << endl; |
| 90 | + |
88 | 91 | if (visibleDevices.empty())
|
89 |
| - return getAllPhysicallyAvailableDevices(); |
| 92 | + return make_pair(getAllPhysicallyAvailableDevices(), false); |
90 | 93 |
|
91 |
| - return visibleDevices; |
| 94 | + return make_pair(visibleDevices, true); |
92 | 95 | }
|
93 | 96 |
|
94 | 97 | extern "C"
|
95 |
| -int occupyDevices(int requestedDevicesCount, int * occupiedDevicesIdxs, char * errorMsg) { |
| 98 | +int occupyDevices(int requestedDevicesCount, int * occupiedDevicesIdxs, char * errorMsgOut) { |
96 | 99 | try {
|
97 |
| - vector<int> availableDevcices = getAvailableDevices(); |
| 100 | + auto availableDevcicesPair = getAvailableDevices(); |
| 101 | + auto availableDevices = availableDevcicesPair.first; |
| 102 | + bool cudaVisibleDevciesSetProperly = availableDevcicesPair.second; |
98 | 103 |
|
99 |
| - if ((int)availableDevcices.size() < requestedDevicesCount) { |
| 104 | + if ((int)availableDevices.size() < requestedDevicesCount) { |
100 | 105 | string msg = "There are not as many free devices as requested. Requested devices count: "
|
101 |
| - + to_string(requestedDevicesCount) + ". Available devices count: " + to_string(availableDevcices.size()) + "."; |
| 106 | + + to_string(requestedDevicesCount) + ". Available devices count: " + to_string(availableDevices.size()) + "."; |
102 | 107 |
|
103 |
| - memcpy(errorMsg, msg.c_str(), msg.length()); |
| 108 | + memcpy(errorMsgOut, msg.c_str(), msg.length()); |
104 | 109 |
|
105 | 110 | return -1;
|
106 | 111 | }
|
107 | 112 |
|
108 | 113 | int nextDeviceIdx = 0;
|
109 | 114 | for (int i = 0; i < requestedDevicesCount; i++) {
|
110 |
| - int deviceIdx = availableDevcices[i]; |
111 |
| - gpuErrchk( cudaSetDevice(i), deviceIdx, errorMsg ); |
| 115 | + int deviceIdx = -1; |
| 116 | + if (cudaVisibleDevciesSetProperly) |
| 117 | + deviceIdx = i; |
| 118 | + else |
| 119 | + deviceIdx = availableDevices[i]; |
| 120 | + gpuErrchk( cudaSetDevice(deviceIdx), deviceIdx, errorMsgOut ); |
112 | 121 |
|
113 | 122 | //call some API functions to really occupy device (I'm lazy to look for more elegant way to do it)
|
114 | 123 | char * ddata;
|
115 |
| - gpuErrchk( cudaMalloc(&ddata, 1), deviceIdx, errorMsg ); |
116 |
| - gpuErrchk( cudaFree(ddata), deviceIdx, errorMsg ); |
| 124 | + gpuErrchk( cudaMalloc(&ddata, 1), deviceIdx, errorMsgOut ); |
| 125 | + gpuErrchk( cudaFree(ddata), deviceIdx, errorMsgOut ); |
117 | 126 |
|
118 |
| - occupiedDevicesIdxs[nextDeviceIdx++] = i; |
| 127 | + occupiedDevicesIdxs[nextDeviceIdx++] = deviceIdx; |
119 | 128 | }
|
120 | 129 | } catch (const std::exception& e) {
|
121 | 130 | auto msg = string(e.what());
|
122 |
| - memcpy(errorMsg, msg.c_str(), msg.length()); |
| 131 | + memcpy(errorMsgOut, msg.c_str(), msg.length()); |
123 | 132 | return -1;
|
124 | 133 | }
|
125 | 134 |
|
|
0 commit comments