@@ -82,8 +82,11 @@ class OpLowerer {
82
82
public:
83
83
OpLowerer (Module &M, DXILResourceMap &DRM) : M(M), OpBuilder(M), DRM(DRM) {}
84
84
85
- void replaceFunction (Function &F,
86
- llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
85
+ // / Replace every call to \c F using \c ReplaceCall, and then erase \c F. If
86
+ // / there is an error replacing a call, we emit a diagnostic and return true.
87
+ [[nodiscard]] bool
88
+ replaceFunction (Function &F,
89
+ llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
87
90
for (User *U : make_early_inc_range (F.users ())) {
88
91
CallInst *CI = dyn_cast<CallInst>(U);
89
92
if (!CI)
@@ -94,16 +97,18 @@ class OpLowerer {
94
97
DiagnosticInfoUnsupported Diag (*CI->getFunction (), Message,
95
98
CI->getDebugLoc ());
96
99
M.getContext ().diagnose (Diag);
97
- continue ;
100
+ return true ;
98
101
}
99
102
}
100
103
if (F.user_empty ())
101
104
F.eraseFromParent ();
105
+ return false ;
102
106
}
103
107
104
- void replaceFunctionWithOp (Function &F, dxil::OpCode DXILOp) {
108
+ [[nodiscard]]
109
+ bool replaceFunctionWithOp (Function &F, dxil::OpCode DXILOp) {
105
110
bool IsVectorArgExpansion = isVectorArgExpansion (F);
106
- replaceFunction (F, [&](CallInst *CI) -> Error {
111
+ return replaceFunction (F, [&](CallInst *CI) -> Error {
107
112
SmallVector<Value *> Args;
108
113
OpBuilder.getIRB ().SetInsertPoint (CI);
109
114
if (IsVectorArgExpansion) {
@@ -175,12 +180,12 @@ class OpLowerer {
175
180
CleanupCasts.clear ();
176
181
}
177
182
178
- void lowerToCreateHandle (Function &F) {
183
+ [[nodiscard]] bool lowerToCreateHandle (Function &F) {
179
184
IRBuilder<> &IRB = OpBuilder.getIRB ();
180
185
Type *Int8Ty = IRB.getInt8Ty ();
181
186
Type *Int32Ty = IRB.getInt32Ty ();
182
187
183
- replaceFunction (F, [&](CallInst *CI) -> Error {
188
+ return replaceFunction (F, [&](CallInst *CI) -> Error {
184
189
IRB.SetInsertPoint (CI);
185
190
186
191
auto *It = DRM.find (CI);
@@ -205,10 +210,10 @@ class OpLowerer {
205
210
});
206
211
}
207
212
208
- void lowerToBindAndAnnotateHandle (Function &F) {
213
+ [[nodiscard]] bool lowerToBindAndAnnotateHandle (Function &F) {
209
214
IRBuilder<> &IRB = OpBuilder.getIRB ();
210
215
211
- replaceFunction (F, [&](CallInst *CI) -> Error {
216
+ return replaceFunction (F, [&](CallInst *CI) -> Error {
212
217
IRB.SetInsertPoint (CI);
213
218
214
219
auto *It = DRM.find (CI);
@@ -251,12 +256,11 @@ class OpLowerer {
251
256
252
257
// / Lower `dx.handle.fromBinding` intrinsics depending on the shader model and
253
258
// / taking into account binding information from DXILResourceAnalysis.
254
- void lowerHandleFromBinding (Function &F) {
259
+ bool lowerHandleFromBinding (Function &F) {
255
260
Triple TT (Triple (M.getTargetTriple ()));
256
261
if (TT.getDXILVersion () < VersionTuple (1 , 6 ))
257
- lowerToCreateHandle (F);
258
- else
259
- lowerToBindAndAnnotateHandle (F);
262
+ return lowerToCreateHandle (F);
263
+ return lowerToBindAndAnnotateHandle (F);
260
264
}
261
265
262
266
// / Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
@@ -342,11 +346,11 @@ class OpLowerer {
342
346
return Error::success ();
343
347
}
344
348
345
- void lowerTypedBufferLoad (Function &F) {
349
+ [[nodiscard]] bool lowerTypedBufferLoad (Function &F) {
346
350
IRBuilder<> &IRB = OpBuilder.getIRB ();
347
351
Type *Int32Ty = IRB.getInt32Ty ();
348
352
349
- replaceFunction (F, [&](CallInst *CI) -> Error {
353
+ return replaceFunction (F, [&](CallInst *CI) -> Error {
350
354
IRB.SetInsertPoint (CI);
351
355
352
356
Value *Handle =
@@ -368,8 +372,51 @@ class OpLowerer {
368
372
});
369
373
}
370
374
375
+ [[nodiscard]] bool lowerTypedBufferStore (Function &F) {
376
+ IRBuilder<> &IRB = OpBuilder.getIRB ();
377
+ Type *Int8Ty = IRB.getInt8Ty ();
378
+ Type *Int32Ty = IRB.getInt32Ty ();
379
+
380
+ return replaceFunction (F, [&](CallInst *CI) -> Error {
381
+ IRB.SetInsertPoint (CI);
382
+
383
+ Value *Handle =
384
+ createTmpHandleCast (CI->getArgOperand (0 ), OpBuilder.getHandleType ());
385
+ Value *Index0 = CI->getArgOperand (1 );
386
+ Value *Index1 = UndefValue::get (Int32Ty);
387
+ // For typed stores, the mask must always cover all four elements.
388
+ Constant *Mask = ConstantInt::get (Int8Ty, 0xF );
389
+
390
+ Value *Data = CI->getArgOperand (2 );
391
+ auto *DataTy = dyn_cast<FixedVectorType>(Data->getType ());
392
+ if (!DataTy || DataTy->getNumElements () != 4 )
393
+ return make_error<StringError>(
394
+ " typedBufferStore data must be a vector of 4 elements" ,
395
+ inconvertibleErrorCode ());
396
+ Value *Data0 =
397
+ IRB.CreateExtractElement (Data, ConstantInt::get (Int32Ty, 0 ));
398
+ Value *Data1 =
399
+ IRB.CreateExtractElement (Data, ConstantInt::get (Int32Ty, 1 ));
400
+ Value *Data2 =
401
+ IRB.CreateExtractElement (Data, ConstantInt::get (Int32Ty, 2 ));
402
+ Value *Data3 =
403
+ IRB.CreateExtractElement (Data, ConstantInt::get (Int32Ty, 3 ));
404
+
405
+ std::array<Value *, 8 > Args{Handle , Index0, Index1, Data0,
406
+ Data1, Data2, Data3, Mask};
407
+ Expected<CallInst *> OpCall =
408
+ OpBuilder.tryCreateOp (OpCode::BufferStore, Args);
409
+ if (Error E = OpCall.takeError ())
410
+ return E;
411
+
412
+ CI->eraseFromParent ();
413
+ return Error::success ();
414
+ });
415
+ }
416
+
371
417
bool lowerIntrinsics () {
372
418
bool Updated = false ;
419
+ bool HasErrors = false ;
373
420
374
421
for (Function &F : make_early_inc_range (M.functions ())) {
375
422
if (!F.isDeclaration ())
@@ -380,19 +427,22 @@ class OpLowerer {
380
427
continue ;
381
428
#define DXIL_OP_INTRINSIC (OpCode, Intrin ) \
382
429
case Intrin: \
383
- replaceFunctionWithOp (F, OpCode); \
430
+ HasErrors |= replaceFunctionWithOp (F, OpCode); \
384
431
break ;
385
432
#include " DXILOperation.inc"
386
433
case Intrinsic::dx_handle_fromBinding:
387
- lowerHandleFromBinding (F);
434
+ HasErrors |= lowerHandleFromBinding (F);
388
435
break ;
389
436
case Intrinsic::dx_typedBufferLoad:
390
- lowerTypedBufferLoad (F);
437
+ HasErrors |= lowerTypedBufferLoad (F);
438
+ break ;
439
+ case Intrinsic::dx_typedBufferStore:
440
+ HasErrors |= lowerTypedBufferStore (F);
391
441
break ;
392
442
}
393
443
Updated = true ;
394
444
}
395
- if (Updated)
445
+ if (Updated && !HasErrors )
396
446
cleanupHandleCasts ();
397
447
398
448
return Updated;
0 commit comments