Skip to content

Commit 90e8411

Browse files
authored
[DirectX] Lower @llvm.dx.typedBufferStore to DXIL ops
The `@llvm.dx.typedBufferStore` intrinsic is lowered to `@dx.op.bufferStore`. Pull Request: #104253
1 parent 7fb19cb commit 90e8411

File tree

6 files changed

+265
-23
lines changed

6 files changed

+265
-23
lines changed

llvm/docs/DirectX/DXILResources.rst

+53-4
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,60 @@ Examples:
361361
- ``i32``
362362
- Index into the buffer
363363

364+
Texture and Typed Buffer Stores
365+
-------------------------------
366+
367+
*relevant types: Textures and TypedBuffer*
368+
369+
The `TextureStore`_ and `BufferStore`_ DXIL operations always write all four
370+
32-bit components to a texture or a typed buffer. While both operations include
371+
a mask parameter, it is specified that the mask must cover all components when
372+
used with these types.
373+
374+
The store operations that we define as intrinsics behave similarly, and will
375+
only accept writes to the whole of the contained type. This differs from the
376+
loads above, but this makes sense to do from a semantics-preserving point of
377+
view. Thus, texture and buffer stores may only operate on 4-element vectors of
378+
types that are 32 bits or fewer, such as ``<4 x i32>``, ``<4 x float>``, and
379+
``<4 x half>``, and 2-element vectors of 64-bit types like ``<2 x double>`` and
380+
``<2 x i64>``.
381+
382+
.. _BufferStore: https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#bufferstore
383+
.. _TextureStore: https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#texturestore
384+
364385
Examples:
365386

366-
.. code-block:: llvm
387+
.. list-table:: ``@llvm.dx.typedBufferStore``
388+
:header-rows: 1
367389

368-
%ret = call {<4 x float>, i1}
369-
@llvm.dx.typedBufferLoad.checkbit.v4f32.tdx.TypedBuffer_v4f32_0_0_0t(
370-
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %index)
390+
* - Argument
391+
-
392+
- Type
393+
- Description
394+
* - Return value
395+
-
396+
- ``void``
397+
-
398+
* - ``%buffer``
399+
- 0
400+
- ``target(dx.TypedBuffer, ...)``
401+
- The buffer to store into
402+
* - ``%index``
403+
- 1
404+
- ``i32``
405+
- Index into the buffer
406+
* - ``%data``
407+
- 2
408+
- A 4- or 2-element vector of the type of the buffer
409+
- The data to store
410+
411+
Examples:
412+
413+
.. code-block:: llvm
371414
415+
call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(
416+
target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buf, i32 %index, <4 x float> %data)
417+
call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f16_1_0_0t.v4f16(
418+
target("dx.TypedBuffer", <4 x half>, 1, 0, 0) %buf, i32 %index, <4 x half> %data)
419+
call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v2f64_1_0_0t.v2f64(
420+
target("dx.TypedBuffer", <2 x double>, 1, 0, 0) %buf, i32 %index, <2 x double> %data)

llvm/include/llvm/IR/IntrinsicsDirectX.td

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def int_dx_handle_fromBinding
3232

3333
def int_dx_typedBufferLoad
3434
: DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty]>;
35+
def int_dx_typedBufferStore
36+
: DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i32_ty, llvm_anyvector_ty]>;
3537

3638
// Cast between target extension handle types and dxil-style opaque handles
3739
def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;

llvm/lib/Target/DirectX/DXIL.td

+12
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,18 @@ def BufferLoad : DXILOp<68, bufferLoad> {
707707
let stages = [Stages<DXIL1_0, [all_stages]>];
708708
}
709709

710+
// DXIL operation 69 (bufferStore). It takes a handle, two i32 coordinates,
// four values of the overload type and an i8 mask, and produces no result.
def BufferStore : DXILOp<69, bufferStore> {
711+
let Doc = "writes to an RWTypedBuffer";
712+
// Handle, Coord0, Coord1, Val0, Val1, Val2, Val3, Mask
713+
let arguments = [
714+
HandleTy, Int32Ty, Int32Ty, OverloadTy, OverloadTy, OverloadTy, OverloadTy,
715+
Int8Ty
716+
];
717+
let result = VoidTy;
718+
// Only 16-bit and 32-bit scalar overloads are listed here; no 64-bit
// overload (double / i64) is declared for this operation.
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, Int16Ty, Int32Ty]>];
719+
let stages = [Stages<DXIL1_0, [all_stages]>];
720+
}
721+
710722
def ThreadId : DXILOp<93, threadId> {
711723
let Doc = "Reads the thread ID";
712724
let LLVMIntrinsic = int_dx_thread_id;

llvm/lib/Target/DirectX/DXILOpLowering.cpp

+69-19
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,11 @@ class OpLowerer {
8282
public:
8383
OpLowerer(Module &M, DXILResourceMap &DRM) : M(M), OpBuilder(M), DRM(DRM) {}
8484

85-
void replaceFunction(Function &F,
86-
llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
85+
/// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If
86+
/// there is an error replacing a call, we emit a diagnostic and return true.
87+
[[nodiscard]] bool
88+
replaceFunction(Function &F,
89+
llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
8790
for (User *U : make_early_inc_range(F.users())) {
8891
CallInst *CI = dyn_cast<CallInst>(U);
8992
if (!CI)
@@ -94,16 +97,18 @@ class OpLowerer {
9497
DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
9598
CI->getDebugLoc());
9699
M.getContext().diagnose(Diag);
97-
continue;
100+
return true;
98101
}
99102
}
100103
if (F.user_empty())
101104
F.eraseFromParent();
105+
return false;
102106
}
103107

104-
void replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp) {
108+
[[nodiscard]]
109+
bool replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp) {
105110
bool IsVectorArgExpansion = isVectorArgExpansion(F);
106-
replaceFunction(F, [&](CallInst *CI) -> Error {
111+
return replaceFunction(F, [&](CallInst *CI) -> Error {
107112
SmallVector<Value *> Args;
108113
OpBuilder.getIRB().SetInsertPoint(CI);
109114
if (IsVectorArgExpansion) {
@@ -175,12 +180,12 @@ class OpLowerer {
175180
CleanupCasts.clear();
176181
}
177182

178-
void lowerToCreateHandle(Function &F) {
183+
[[nodiscard]] bool lowerToCreateHandle(Function &F) {
179184
IRBuilder<> &IRB = OpBuilder.getIRB();
180185
Type *Int8Ty = IRB.getInt8Ty();
181186
Type *Int32Ty = IRB.getInt32Ty();
182187

183-
replaceFunction(F, [&](CallInst *CI) -> Error {
188+
return replaceFunction(F, [&](CallInst *CI) -> Error {
184189
IRB.SetInsertPoint(CI);
185190

186191
auto *It = DRM.find(CI);
@@ -205,10 +210,10 @@ class OpLowerer {
205210
});
206211
}
207212

208-
void lowerToBindAndAnnotateHandle(Function &F) {
213+
[[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) {
209214
IRBuilder<> &IRB = OpBuilder.getIRB();
210215

211-
replaceFunction(F, [&](CallInst *CI) -> Error {
216+
return replaceFunction(F, [&](CallInst *CI) -> Error {
212217
IRB.SetInsertPoint(CI);
213218

214219
auto *It = DRM.find(CI);
@@ -251,12 +256,11 @@ class OpLowerer {
251256

252257
/// Lower `dx.handle.fromBinding` intrinsics depending on the shader model and
253258
/// taking into account binding information from DXILResourceAnalysis.
254-
void lowerHandleFromBinding(Function &F) {
259+
bool lowerHandleFromBinding(Function &F) {
255260
Triple TT(Triple(M.getTargetTriple()));
256261
if (TT.getDXILVersion() < VersionTuple(1, 6))
257-
lowerToCreateHandle(F);
258-
else
259-
lowerToBindAndAnnotateHandle(F);
262+
return lowerToCreateHandle(F);
263+
return lowerToBindAndAnnotateHandle(F);
260264
}
261265

262266
/// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
@@ -342,11 +346,11 @@ class OpLowerer {
342346
return Error::success();
343347
}
344348

345-
void lowerTypedBufferLoad(Function &F) {
349+
[[nodiscard]] bool lowerTypedBufferLoad(Function &F) {
346350
IRBuilder<> &IRB = OpBuilder.getIRB();
347351
Type *Int32Ty = IRB.getInt32Ty();
348352

349-
replaceFunction(F, [&](CallInst *CI) -> Error {
353+
return replaceFunction(F, [&](CallInst *CI) -> Error {
350354
IRB.SetInsertPoint(CI);
351355

352356
Value *Handle =
@@ -368,8 +372,51 @@ class OpLowerer {
368372
});
369373
}
370374

375+
/// Lower every call to the typed buffer store intrinsic \c F into an
/// \c OpCode::BufferStore operation.
///
/// For each call, the resource operand (argument 0) is cast to the DXIL
/// handle type, argument 1 is used as the first coordinate, the second
/// coordinate is an undef i32, and the data vector (argument 2) is split into
/// four scalars with extractelement. The mask is the constant 0xF.
///
/// The data operand must be a fixed vector of exactly four elements; any
/// other shape produces an error for that call. Note that this check also
/// rejects 2-element vectors such as <2 x double>.
///
/// \returns the result of \c replaceFunction for \c F.
[[nodiscard]] bool lowerTypedBufferStore(Function &F) {
376+
IRBuilder<> &IRB = OpBuilder.getIRB();
377+
Type *Int8Ty = IRB.getInt8Ty();
378+
Type *Int32Ty = IRB.getInt32Ty();
379+
380+
return replaceFunction(F, [&](CallInst *CI) -> Error {
381+
// Emit the replacement instructions immediately before the original call.
IRB.SetInsertPoint(CI);
382+
383+
Value *Handle =
384+
createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
385+
Value *Index0 = CI->getArgOperand(1);
386+
// The intrinsic supplies a single index, so the second coordinate is undef.
Value *Index1 = UndefValue::get(Int32Ty);
387+
// For typed stores, the mask must always cover all four elements.
388+
Constant *Mask = ConstantInt::get(Int8Ty, 0xF);
389+
390+
Value *Data = CI->getArgOperand(2);
391+
auto *DataTy = dyn_cast<FixedVectorType>(Data->getType());
392+
// Anything other than a fixed 4-element vector is reported as an error
// instead of being lowered.
if (!DataTy || DataTy->getNumElements() != 4)
393+
return make_error<StringError>(
394+
"typedBufferStore data must be a vector of 4 elements",
395+
inconvertibleErrorCode());
396+
// Split the vector into the four scalar value operands of the op.
Value *Data0 =
397+
IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 0));
398+
Value *Data1 =
399+
IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 1));
400+
Value *Data2 =
401+
IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 2));
402+
Value *Data3 =
403+
IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 3));
404+
405+
// Operand order: handle, coord0, coord1, val0..val3, mask.
std::array<Value *, 8> Args{Handle, Index0, Index1, Data0,
406+
Data1, Data2, Data3, Mask};
407+
Expected<CallInst *> OpCall =
408+
OpBuilder.tryCreateOp(OpCode::BufferStore, Args);
409+
if (Error E = OpCall.takeError())
410+
return E;
411+
412+
// The store returns void, so there are no uses to rewrite; just drop the
// original intrinsic call.
CI->eraseFromParent();
413+
return Error::success();
414+
});
415+
}
416+
371417
bool lowerIntrinsics() {
372418
bool Updated = false;
419+
bool HasErrors = false;
373420

374421
for (Function &F : make_early_inc_range(M.functions())) {
375422
if (!F.isDeclaration())
@@ -380,19 +427,22 @@ class OpLowerer {
380427
continue;
381428
#define DXIL_OP_INTRINSIC(OpCode, Intrin) \
382429
case Intrin: \
383-
replaceFunctionWithOp(F, OpCode); \
430+
HasErrors |= replaceFunctionWithOp(F, OpCode); \
384431
break;
385432
#include "DXILOperation.inc"
386433
case Intrinsic::dx_handle_fromBinding:
387-
lowerHandleFromBinding(F);
434+
HasErrors |= lowerHandleFromBinding(F);
388435
break;
389436
case Intrinsic::dx_typedBufferLoad:
390-
lowerTypedBufferLoad(F);
437+
HasErrors |= lowerTypedBufferLoad(F);
438+
break;
439+
case Intrinsic::dx_typedBufferStore:
440+
HasErrors |= lowerTypedBufferStore(F);
391441
break;
392442
}
393443
Updated = true;
394444
}
395-
if (Updated)
445+
if (Updated && !HasErrors)
396446
cleanupHandleCasts();
397447

398448
return Updated;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
; We use llc for this test so that we don't abort after the first error.
2+
; RUN: not llc %s -o /dev/null 2>&1 | FileCheck %s
3+
4+
target triple = "dxil-pc-shadermodel6.6-compute"
5+
6+
; CHECK: error:
7+
; CHECK-SAME: in function storetoomany
8+
; CHECK-SAME: typedBufferStore data must be a vector of 4 elements
9+
define void @storetoomany(<5 x float> %data, i32 %index) {
10+
%buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
11+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
12+
i32 0, i32 0, i32 1, i32 0, i1 false)
13+
14+
call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v5f32(
15+
target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer,
16+
i32 %index, <5 x float> %data)
17+
18+
ret void
19+
}
20+
21+
; CHECK: error:
22+
; CHECK-SAME: in function storetoofew
23+
; CHECK-SAME: typedBufferStore data must be a vector of 4 elements
24+
define void @storetoofew(<3 x i32> %data, i32 %index) {
25+
%buffer = call target("dx.TypedBuffer", <4 x i32>, 1, 0, 0)
26+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_1_0_0(
27+
i32 0, i32 0, i32 1, i32 0, i1 false)
28+
29+
call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4i32_1_0_0t.v3i32(
30+
target("dx.TypedBuffer", <4 x i32>, 1, 0, 0) %buffer,
31+
i32 %index, <3 x i32> %data)
32+
33+
ret void
34+
}
35+
36+
declare void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v5f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0), i32, <5 x float>)
37+
declare void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4i32_1_0_0t.v3i32(target("dx.TypedBuffer", <4 x i32>, 1, 0, 0), i32, <3 x i32>)
+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
; RUN: opt -S -dxil-op-lower %s | FileCheck %s
2+
3+
target triple = "dxil-pc-shadermodel6.6-compute"
4+
5+
define void @storefloat(<4 x float> %data, i32 %index) {
6+
7+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
8+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
9+
%buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
10+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
11+
i32 0, i32 0, i32 1, i32 0, i1 false)
12+
13+
; The temporary casts should all have been cleaned up
14+
; CHECK-NOT: %dx.cast_handle
15+
16+
; CHECK: [[DATA0_0:%.*]] = extractelement <4 x float> %data, i32 0
17+
; CHECK: [[DATA0_1:%.*]] = extractelement <4 x float> %data, i32 1
18+
; CHECK: [[DATA0_2:%.*]] = extractelement <4 x float> %data, i32 2
19+
; CHECK: [[DATA0_3:%.*]] = extractelement <4 x float> %data, i32 3
20+
; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float [[DATA0_0]], float [[DATA0_1]], float [[DATA0_2]], float [[DATA0_3]], i8 15)
21+
call void @llvm.dx.typedBufferStore(
22+
target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer,
23+
i32 %index, <4 x float> %data)
24+
25+
ret void
26+
}
27+
28+
define void @storeint(<4 x i32> %data, i32 %index) {
29+
30+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
31+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
32+
%buffer = call target("dx.TypedBuffer", <4 x i32>, 1, 0, 0)
33+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_1_0_0(
34+
i32 0, i32 0, i32 1, i32 0, i1 false)
35+
36+
; CHECK: [[DATA0_0:%.*]] = extractelement <4 x i32> %data, i32 0
37+
; CHECK: [[DATA0_1:%.*]] = extractelement <4 x i32> %data, i32 1
38+
; CHECK: [[DATA0_2:%.*]] = extractelement <4 x i32> %data, i32 2
39+
; CHECK: [[DATA0_3:%.*]] = extractelement <4 x i32> %data, i32 3
40+
; CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, i32 [[DATA0_0]], i32 [[DATA0_1]], i32 [[DATA0_2]], i32 [[DATA0_3]], i8 15)
41+
call void @llvm.dx.typedBufferStore(
42+
target("dx.TypedBuffer", <4 x i32>, 1, 0, 0) %buffer,
43+
i32 %index, <4 x i32> %data)
44+
45+
ret void
46+
}
47+
48+
define void @storehalf(<4 x half> %data, i32 %index) {
49+
50+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
51+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
52+
%buffer = call target("dx.TypedBuffer", <4 x half>, 1, 0, 0)
53+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_1_0_0(
54+
i32 0, i32 0, i32 1, i32 0, i1 false)
55+
56+
; The temporary casts should all have been cleaned up
57+
; CHECK-NOT: %dx.cast_handle
58+
59+
; CHECK: [[DATA0_0:%.*]] = extractelement <4 x half> %data, i32 0
60+
; CHECK: [[DATA0_1:%.*]] = extractelement <4 x half> %data, i32 1
61+
; CHECK: [[DATA0_2:%.*]] = extractelement <4 x half> %data, i32 2
62+
; CHECK: [[DATA0_3:%.*]] = extractelement <4 x half> %data, i32 3
63+
; CHECK: call void @dx.op.bufferStore.f16(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, half [[DATA0_0]], half [[DATA0_1]], half [[DATA0_2]], half [[DATA0_3]], i8 15)
64+
call void @llvm.dx.typedBufferStore(
65+
target("dx.TypedBuffer", <4 x half>, 1, 0, 0) %buffer,
66+
i32 %index, <4 x half> %data)
67+
68+
ret void
69+
}
70+
71+
define void @storei16(<4 x i16> %data, i32 %index) {
72+
73+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
74+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
75+
%buffer = call target("dx.TypedBuffer", <4 x i16>, 1, 0, 0)
76+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_1_0_0(
77+
i32 0, i32 0, i32 1, i32 0, i1 false)
78+
79+
; The temporary casts should all have been cleaned up
80+
; CHECK-NOT: %dx.cast_handle
81+
82+
; CHECK: [[DATA0_0:%.*]] = extractelement <4 x i16> %data, i32 0
83+
; CHECK: [[DATA0_1:%.*]] = extractelement <4 x i16> %data, i32 1
84+
; CHECK: [[DATA0_2:%.*]] = extractelement <4 x i16> %data, i32 2
85+
; CHECK: [[DATA0_3:%.*]] = extractelement <4 x i16> %data, i32 3
86+
; CHECK: call void @dx.op.bufferStore.i16(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, i16 [[DATA0_0]], i16 [[DATA0_1]], i16 [[DATA0_2]], i16 [[DATA0_3]], i8 15)
87+
call void @llvm.dx.typedBufferStore(
88+
target("dx.TypedBuffer", <4 x i16>, 1, 0, 0) %buffer,
89+
i32 %index, <4 x i16> %data)
90+
91+
ret void
92+
}

0 commit comments

Comments
 (0)