Skip to content

Commit 18d333f

Browse files
committed
Revert add prefetch for USM hip allocations a6b8fa66b537753415d24076f1025c040110c332
1 parent 2b77f79 commit 18d333f

File tree

4 files changed

+2
-79
lines changed

4 files changed

+2
-79
lines changed

source/adapters/hip/context.hpp

-53
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#pragma once
1111

1212
#include <set>
13-
#include <unordered_map>
1413

1514
#include "common.hpp"
1615
#include "device.hpp"
@@ -106,61 +105,9 @@ struct ur_context_handle_t_ {
106105

107106
ur_usm_pool_handle_t getOwningURPool(umf_memory_pool_t *UMFPool);
108107

109-
/// We need to keep track of USM mappings in AMD HIP, as certain extra
110-
/// synchronization *is* actually required for correctness.
111-
/// During kernel enqueue we must dispatch a prefetch for each kernel argument
112-
/// that points to a USM mapping to ensure the mapping is correctly
113-
/// populated on the device (https://github.com/intel/llvm/issues/7252). Thus,
114-
/// we keep track of mappings in the context, and then check against them just
115-
/// before the kernel is launched. The stream against which the kernel is
116-
/// launched is not known until enqueue time, but the USM mappings can happen
117-
/// at any time. Thus, they are tracked on the context used for the urUSM*
118-
/// mapping.
119-
///
120-
/// The three utility function are simple wrappers around a mapping from a
121-
/// pointer to a size.
122-
void addUSMMapping(void *Ptr, size_t Size) {
123-
std::lock_guard<std::mutex> Guard(Mutex);
124-
assert(USMMappings.find(Ptr) == USMMappings.end() &&
125-
"mapping already exists");
126-
USMMappings[Ptr] = Size;
127-
}
128-
129-
void removeUSMMapping(const void *Ptr) {
130-
std::lock_guard<std::mutex> guard(Mutex);
131-
auto It = USMMappings.find(Ptr);
132-
if (It != USMMappings.end())
133-
USMMappings.erase(It);
134-
}
135-
136-
std::pair<const void *, size_t> getUSMMapping(const void *Ptr) {
137-
std::lock_guard<std::mutex> Guard(Mutex);
138-
auto It = USMMappings.find(Ptr);
139-
// The simple case is the fast case...
140-
if (It != USMMappings.end())
141-
return *It;
142-
143-
// ... but in the failure case we have to fall back to a full scan to search
144-
// for "offset" pointers in case the user passes in the middle of an
145-
// allocation. We have to do some not-so-ordained-by-the-standard ordered
146-
// comparisons of pointers here, but it'll work on all platforms we support.
147-
uintptr_t PtrVal = (uintptr_t)Ptr;
148-
for (std::pair<const void *, size_t> Pair : USMMappings) {
149-
uintptr_t BaseAddr = (uintptr_t)Pair.first;
150-
uintptr_t EndAddr = BaseAddr + Pair.second;
151-
if (PtrVal > BaseAddr && PtrVal < EndAddr) {
152-
// If we've found something now, offset *must* be nonzero
153-
assert(Pair.second);
154-
return Pair;
155-
}
156-
}
157-
return {nullptr, 0};
158-
}
159-
160108
private:
161109
std::mutex Mutex;
162110
std::vector<deleter_data> ExtendedDeleters;
163-
std::unordered_map<const void *, size_t> USMMappings;
164111
std::set<ur_usm_pool_handle_t> PoolHandles;
165112
};
166113

source/adapters/hip/enqueue.cpp

+1-10
Original file line numberDiff line numberDiff line change
@@ -258,22 +258,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
258258
try {
259259
ur_device_handle_t Dev = hQueue->getDevice();
260260
ScopedContext Active(Dev);
261-
ur_context_handle_t Ctx = hQueue->getContext();
262261

263262
uint32_t StreamToken;
264263
ur_stream_quard Guard;
265264
hipStream_t HIPStream = hQueue->getNextComputeStream(
266265
numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
267266
hipFunction_t HIPFunc = hKernel->get();
268267

269-
hipDevice_t HIPDev = Dev->get();
270-
for (const void *P : hKernel->getPtrArgs()) {
271-
auto [Addr, Size] = Ctx->getUSMMapping(P);
272-
if (!Addr)
273-
continue;
274-
if (hipMemPrefetchAsync(Addr, Size, HIPDev, HIPStream) != hipSuccess)
275-
return UR_RESULT_ERROR_INVALID_KERNEL_ARGS;
276-
}
277268
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
278269
phEventWaitList);
279270

@@ -315,7 +306,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
315306
int DeviceMaxLocalMem = 0;
316307
UR_CHECK_ERROR(hipDeviceGetAttribute(
317308
&DeviceMaxLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock,
318-
HIPDev));
309+
Dev->get()));
319310

320311
static const int EnvVal = std::atoi(LocalMemSzPtr);
321312
if (EnvVal <= 0 || EnvVal > DeviceMaxLocalMem) {

source/adapters/hip/kernel.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
259259
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer(
260260
ur_kernel_handle_t hKernel, uint32_t argIndex,
261261
const ur_kernel_arg_pointer_properties_t *, const void *pArgValue) {
262-
hKernel->setKernelPtrArg(argIndex, sizeof(pArgValue), pArgValue);
262+
hKernel->setKernelArg(argIndex, sizeof(pArgValue), pArgValue);
263263
return UR_RESULT_SUCCESS;
264264
}
265265

source/adapters/hip/kernel.hpp

-15
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#include <atomic>
1515
#include <cassert>
1616
#include <numeric>
17-
#include <set>
1817

1918
#include "program.hpp"
2019

@@ -58,7 +57,6 @@ struct ur_kernel_handle_t_ {
5857
args_size_t ParamSizes;
5958
args_index_t Indices;
6059
args_size_t OffsetPerIndex;
61-
std::set<const void *> PtrArgs;
6260

6361
std::uint32_t ImplicitOffsetArgs[3] = {0, 0, 0};
6462

@@ -179,19 +177,6 @@ struct ur_kernel_handle_t_ {
179177
Args.addArg(Index, Size, Arg);
180178
}
181179

182-
/// We track all pointer arguments to be able to issue prefetches at enqueue
183-
/// time
184-
void setKernelPtrArg(int Index, size_t Size, const void *PtrArg) {
185-
Args.PtrArgs.insert(*static_cast<void *const *>(PtrArg));
186-
setKernelArg(Index, Size, PtrArg);
187-
}
188-
189-
bool isPtrArg(const void *ptr) {
190-
return Args.PtrArgs.find(ptr) != Args.PtrArgs.end();
191-
}
192-
193-
std::set<const void *> &getPtrArgs() { return Args.PtrArgs; }
194-
195180
void setKernelLocalArg(int Index, size_t Size) {
196181
Args.addLocalArg(Index, Size);
197182
}

0 commit comments

Comments
 (0)