Skip to content

Commit 358d73f

Browse files
authored
[DevTSAN] Support device thread sanitizer for device globals (#17548)
1. Add a global '__TsanDeviceGlobalMetadata' to record each device global's information. 2. Read the global metadata once the program has been built/linked, and then poison the related shadow memory.
1 parent 946e9b8 commit 358d73f

File tree

7 files changed

+325
-16
lines changed

7 files changed

+325
-16
lines changed

llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp

+72-2
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,19 @@ struct ThreadSanitizerOnSpirv {
118118

119119
void initialize();
120120

121-
void instrumentKernelsMetadata();
121+
void instrumentModule();
122122

123123
void appendDebugInfoToArgs(Instruction *I, SmallVectorImpl<Value *> &Args);
124124

125125
private:
126+
void instrumentGlobalVariables();
127+
128+
void instrumentKernelsMetadata();
129+
126130
bool isSupportedSPIRKernel(Function &F);
127131

132+
bool isUnsupportedDeviceGlobal(const GlobalVariable &G);
133+
128134
GlobalVariable *GetOrCreateGlobalString(StringRef Name, StringRef Value,
129135
unsigned AddressSpace);
130136

@@ -243,7 +249,7 @@ PreservedAnalyses ModuleThreadSanitizerPass::run(Module &M,
243249
return PreservedAnalyses::all();
244250
if (Triple(M.getTargetTriple()).isSPIROrSPIRV()) {
245251
ThreadSanitizerOnSpirv Spirv(M);
246-
Spirv.instrumentKernelsMetadata();
252+
Spirv.instrumentModule();
247253
} else
248254
insertModuleCtor(M);
249255
return PreservedAnalyses::none();
@@ -327,6 +333,70 @@ bool ThreadSanitizerOnSpirv::isSupportedSPIRKernel(Function &F) {
327333
return true;
328334
}
329335

336+
bool ThreadSanitizerOnSpirv::isUnsupportedDeviceGlobal(
337+
const GlobalVariable &G) {
338+
if (G.user_empty())
339+
return true;
340+
// Skip instrumenting on "__TsanKernelMetadata" etc.
341+
if (G.getName().starts_with("__Tsan"))
342+
return true;
343+
if (G.getName().starts_with("__tsan_"))
344+
return true;
345+
if (G.getName().starts_with("__spirv_BuiltIn"))
346+
return true;
347+
if (G.getName().starts_with("__usid_str"))
348+
return true;
349+
// TODO: Will support global variable with local address space later.
350+
if (G.getAddressSpace() == kSpirOffloadLocalAS)
351+
return true;
352+
// Global variables have constant value or constant address space will not
353+
// trigger race condition.
354+
if (G.isConstant() || G.getAddressSpace() == kSpirOffloadConstantAS)
355+
return true;
356+
return false;
357+
}
358+
359+
/// Module-level instrumentation entry point.
///
/// Runs the two module-wide steps in a fixed order: the device-global step
/// first, the kernel-metadata step second.
void ThreadSanitizerOnSpirv::instrumentModule() {
  // Step 1: device globals.
  instrumentGlobalVariables();
  // Step 2: kernel metadata.
  instrumentKernelsMetadata();
}
363+
364+
void ThreadSanitizerOnSpirv::instrumentGlobalVariables() {
365+
SmallVector<Constant *, 8> DeviceGlobalMetadata;
366+
367+
// Device global metadata is described by a structure
368+
// size_t device_global_size
369+
// size_t beginning address of the device global
370+
StructType *StructTy = StructType::get(IntptrTy, IntptrTy);
371+
372+
for (auto &G : M.globals()) {
373+
if (isUnsupportedDeviceGlobal(G)) {
374+
for (auto *User : G.users())
375+
if (auto *Inst = dyn_cast<Instruction>(User))
376+
Inst->setNoSanitizeMetadata();
377+
continue;
378+
}
379+
380+
DeviceGlobalMetadata.push_back(ConstantStruct::get(
381+
StructTy,
382+
ConstantInt::get(IntptrTy, DL.getTypeAllocSize(G.getValueType())),
383+
ConstantExpr::getPointerCast(&G, IntptrTy)));
384+
}
385+
386+
if (DeviceGlobalMetadata.empty())
387+
return;
388+
389+
// Create meta data global to record device globals' information
390+
ArrayType *ArrayTy = ArrayType::get(StructTy, DeviceGlobalMetadata.size());
391+
Constant *MetadataInitializer =
392+
ConstantArray::get(ArrayTy, DeviceGlobalMetadata);
393+
GlobalVariable *MsanDeviceGlobalMetadata = new GlobalVariable(
394+
M, MetadataInitializer->getType(), false, GlobalValue::AppendingLinkage,
395+
MetadataInitializer, "__TsanDeviceGlobalMetadata", nullptr,
396+
GlobalValue::NotThreadLocal, 1);
397+
MsanDeviceGlobalMetadata->setUnnamedAddr(GlobalValue::UnnamedAddr::Local);
398+
}
399+
330400
void ThreadSanitizerOnSpirv::instrumentKernelsMetadata() {
331401
SmallVector<Constant *, 8> SpirKernelsMetadata;
332402

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
; RUN: opt < %s -passes='function(tsan),module(tsan-module)' -tsan-instrument-func-entry-exit=0 -tsan-instrument-memintrinsics=0 -S | FileCheck %s
2+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
3+
target triple = "spir64-unknown-unknown"
4+
5+
@dev_global = external addrspace(1) global { [4 x i32] }
6+
@dev_global_no_users = dso_local addrspace(1) global { [4 x i32] } zeroinitializer
7+
@.str = external addrspace(1) constant [59 x i8]
8+
@__spirv_BuiltInGlobalInvocationId = external addrspace(1) constant <3 x i64>
9+
10+
; CHECK: @__TsanDeviceGlobalMetadata
11+
; CHECK-NOT: @dev_global_no_users
12+
; CHECK-NOT: @.str
13+
; CHECK-NOT: @__spirv_BuiltInGlobalInvocationId
14+
; CHECK-SAME: @dev_global
15+
16+
; The only function in this test. Its single call passes @dev_global
; (addrspacecast from addrspace(1) to addrspace(4)) as an argument, which makes
; @dev_global the one global in this file that has a user.
define spir_func void @test() {
17+
entry:
18+
%call = call spir_func ptr addrspace(4) null(ptr addrspace(4) addrspacecast (ptr addrspace(1) @dev_global to ptr addrspace(4)), i64 0)
19+
ret void
20+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// REQUIRES: linux, cpu || (gpu && level_zero)
2+
// ALLOW_RETRIES: 10
3+
// RUN: %{build} %device_tsan_flags -O0 -g -o %t1.out
4+
// RUN: %{run} %t1.out 2>&1 | FileCheck %s
5+
// RUN: %{build} %device_tsan_flags -O2 -g -o %t2.out
6+
// RUN: %{run} %t2.out 2>&1 | FileCheck %s
7+
#include <sycl/detail/core.hpp>
8+
#include <sycl/ext/oneapi/device_global/device_global.hpp>
9+
10+
using namespace sycl;
11+
using namespace sycl::ext::oneapi;
12+
using namespace sycl::ext::oneapi::experimental;
13+
14+
sycl::ext::oneapi::experimental::device_global<
15+
int[4], decltype(properties(device_image_scope, host_access_read_write))>
16+
dev_global;
17+
18+
// Submits one parallel_for over nd_range<1>(32, 8) in which every work-item
// increments dev_global[0] without synchronization, then waits for it. The
// FileCheck directives after the submission expect a data-race report for a
// 4-byte write attributed to the kernel line.
int main() {
19+
sycl::queue Q;
20+
21+
Q.submit([&](sycl::handler &h) {
22+
h.parallel_for<class Test>(sycl::nd_range<1>(32, 8),
23+
[=](sycl::nd_item<1>) { dev_global[0]++; });
24+
}).wait();
25+
// CHECK: WARNING: DeviceSanitizer: data race
26+
// CHECK-NEXT: When write of size 4 at 0x{{.*}} in kernel <{{.*}}Test>
27+
// CHECK-NEXT: #0 {{.*}}check_device_global.cpp:[[@LINE-4]]
28+
29+
return 0;
30+
}

unified-runtime/source/loader/layers/sanitizer/tsan/tsan_ddi.cpp

+132
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,95 @@ ur_result_t urContextRelease(
130130
return UR_RESULT_SUCCESS;
131131
}
132132

133+
///////////////////////////////////////////////////////////////////////////////
134+
/// @brief Intercept function for urProgramBuild
135+
ur_result_t urProgramBuild(
136+
/// [in] handle of the context object
137+
ur_context_handle_t hContext,
138+
/// [in] handle of the program object
139+
ur_program_handle_t hProgram,
140+
/// [in] string of build options
141+
const char *pOptions) {
142+
getContext()->logger.debug("==== urProgramBuild");
143+
144+
UR_CALL(
145+
getContext()->urDdiTable.Program.pfnBuild(hContext, hProgram, pOptions));
146+
147+
UR_CALL(getTsanInterceptor()->registerProgram(hProgram));
148+
149+
return UR_RESULT_SUCCESS;
150+
}
151+
152+
///////////////////////////////////////////////////////////////////////////////
153+
/// @brief Intercept function for urProgramBuildExp
154+
ur_result_t urProgramBuildExp(
155+
/// [in] Handle of the program to build.
156+
ur_program_handle_t hProgram,
157+
/// [in] number of devices
158+
uint32_t numDevices,
159+
/// [in][range(0, numDevices)] pointer to array of device handles
160+
ur_device_handle_t *phDevices,
161+
/// [in][optional] pointer to build options null-terminated string.
162+
const char *pOptions) {
163+
getContext()->logger.debug("==== urProgramBuildExp");
164+
165+
UR_CALL(getContext()->urDdiTable.ProgramExp.pfnBuildExp(hProgram, numDevices,
166+
phDevices, pOptions));
167+
UR_CALL(getTsanInterceptor()->registerProgram(hProgram));
168+
169+
return UR_RESULT_SUCCESS;
170+
}
171+
172+
///////////////////////////////////////////////////////////////////////////////
173+
/// @brief Intercept function for urProgramLink
174+
ur_result_t urProgramLink(
175+
/// [in] handle of the context instance.
176+
ur_context_handle_t hContext,
177+
/// [in] number of program handles in `phPrograms`.
178+
uint32_t count,
179+
/// [in][range(0, count)] pointer to array of program handles.
180+
const ur_program_handle_t *phPrograms,
181+
/// [in][optional] pointer to linker options null-terminated string.
182+
const char *pOptions,
183+
/// [out] pointer to handle of program object created.
184+
ur_program_handle_t *phProgram) {
185+
getContext()->logger.debug("==== urProgramLink");
186+
187+
UR_CALL(getContext()->urDdiTable.Program.pfnLink(hContext, count, phPrograms,
188+
pOptions, phProgram));
189+
190+
UR_CALL(getTsanInterceptor()->registerProgram(*phProgram));
191+
192+
return UR_RESULT_SUCCESS;
193+
}
194+
195+
///////////////////////////////////////////////////////////////////////////////
196+
/// @brief Intercept function for urProgramLinkExp
197+
ur_result_t urProgramLinkExp(
198+
/// [in] handle of the context instance.
199+
ur_context_handle_t hContext,
200+
/// [in] number of devices
201+
uint32_t numDevices,
202+
/// [in][range(0, numDevices)] pointer to array of device handles
203+
ur_device_handle_t *phDevices,
204+
/// [in] number of program handles in `phPrograms`.
205+
uint32_t count,
206+
/// [in][range(0, count)] pointer to array of program handles.
207+
const ur_program_handle_t *phPrograms,
208+
/// [in][optional] pointer to linker options null-terminated string.
209+
const char *pOptions,
210+
/// [out] pointer to handle of program object created.
211+
ur_program_handle_t *phProgram) {
212+
getContext()->logger.debug("==== urProgramLinkExp");
213+
214+
UR_CALL(getContext()->urDdiTable.ProgramExp.pfnLinkExp(
215+
hContext, numDevices, phDevices, count, phPrograms, pOptions, phProgram));
216+
217+
UR_CALL(getTsanInterceptor()->registerProgram(*phProgram));
218+
219+
return UR_RESULT_SUCCESS;
220+
}
221+
133222
///////////////////////////////////////////////////////////////////////////////
134223
/// @brief Intercept function for urUSMDeviceAlloc
135224
__urdlllocal ur_result_t UR_APICALL urUSMDeviceAlloc(
@@ -283,6 +372,39 @@ __urdlllocal ur_result_t UR_APICALL urGetContextProcAddrTable(
283372
return result;
284373
}
285374

375+
///////////////////////////////////////////////////////////////////////////////
376+
/// @brief Exported function for filling application's Program table
377+
/// with current process' addresses
378+
///
379+
/// @returns
380+
/// - ::UR_RESULT_SUCCESS
381+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
382+
ur_result_t urGetProgramProcAddrTable(
383+
/// [in,out] pointer to table of DDI function pointers
384+
ur_program_dditable_t *pDdiTable) {
385+
pDdiTable->pfnBuild = ur_sanitizer_layer::tsan::urProgramBuild;
386+
pDdiTable->pfnLink = ur_sanitizer_layer::tsan::urProgramLink;
387+
388+
return UR_RESULT_SUCCESS;
389+
}
390+
391+
/// @brief Exported function for filling application's ProgramExp table
392+
/// with current process' addresses
393+
///
394+
/// @returns
395+
/// - ::UR_RESULT_SUCCESS
396+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
397+
ur_result_t urGetProgramExpProcAddrTable(
398+
/// [in,out] pointer to table of DDI function pointers
399+
ur_program_exp_dditable_t *pDdiTable) {
400+
ur_result_t result = UR_RESULT_SUCCESS;
401+
402+
pDdiTable->pfnBuildExp = ur_sanitizer_layer::tsan::urProgramBuildExp;
403+
pDdiTable->pfnLinkExp = ur_sanitizer_layer::tsan::urProgramLinkExp;
404+
405+
return result;
406+
}
407+
286408
///////////////////////////////////////////////////////////////////////////////
287409
/// @brief Exported function for filling application's USM table
288410
/// with current process' addresses
@@ -363,6 +485,16 @@ ur_result_t initTsanDDITable(ur_dditable_t *dditable) {
363485
UR_API_VERSION_CURRENT, &dditable->Context);
364486
}
365487

488+
if (UR_RESULT_SUCCESS == result) {
489+
result =
490+
ur_sanitizer_layer::tsan::urGetProgramProcAddrTable(&dditable->Program);
491+
}
492+
493+
if (UR_RESULT_SUCCESS == result) {
494+
result = ur_sanitizer_layer::tsan::urGetProgramExpProcAddrTable(
495+
&dditable->ProgramExp);
496+
}
497+
366498
if (UR_RESULT_SUCCESS == result) {
367499
result = ur_sanitizer_layer::tsan::urGetUSMProcAddrTable(
368500
UR_API_VERSION_CURRENT, &dditable->USM);

0 commit comments

Comments
 (0)