Skip to content

Commit

Permalink
[d3d11] Lazy-bind compute shader UAVs
Browse files Browse the repository at this point in the history
And factor UAV counter updates out of binding.
  • Loading branch information
doitsujin committed Feb 20, 2025
1 parent 63900e8 commit 126798b
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 23 deletions.
88 changes: 67 additions & 21 deletions src/d3d11/d3d11_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2169,7 +2169,8 @@ namespace dxvk {
m_state.uav.views[uavId] = nullptr;
m_state.uav.mask.clr(uavId);

BindUnorderedAccessView(DxbcProgramType::ComputeShader, uavId, nullptr, ~0u);
if (!DirtyUnorderedAccessView(DxbcProgramType::ComputeShader, uavId, true))
BindUnorderedAccessView(DxbcProgramType::ComputeShader, uavId, nullptr);
}
}

Expand All @@ -2184,11 +2185,16 @@ namespace dxvk {
auto uav = static_cast<D3D11UnorderedAccessView*>(ppUnorderedAccessViews[i]);
auto ctr = pUAVInitialCounts ? pUAVInitialCounts[i] : ~0u;

if (m_state.uav.views[StartSlot + i] != uav || ctr != ~0u) {
if (ctr != ~0u && uav && uav->HasCounter())
UpdateUnorderedAccessViewCounter(uav, ctr);

if (m_state.uav.views[StartSlot + i] != uav) {
m_state.uav.views[StartSlot + i] = uav;
m_state.uav.mask.set(StartSlot + i, uav != nullptr);

BindUnorderedAccessView(DxbcProgramType::ComputeShader, StartSlot + i, uav, ctr);
if (!DirtyUnorderedAccessView(DxbcProgramType::ComputeShader, StartSlot + i, !uav))
BindUnorderedAccessView(DxbcProgramType::ComputeShader, StartSlot + i, uav);

ResolveCsSrvHazards(uav);
}
}
Expand Down Expand Up @@ -3241,6 +3247,28 @@ namespace dxvk {
}


// Re-binds every unordered access view of the given stage whose slot
// is flagged in both BoundMask and DirtyMask. The processed slots are
// removed from DirtyMask; all other dirty bits are left untouched.
template<typename ContextType>
void D3D11CommonContext<ContextType>::ApplyDirtyUnorderedAccessViews(
        DxbcProgramType   Stage,
  const DxbcBindingMask&  BoundMask,
        DxbcBindingMask&  DirtyMask) {
  uint64_t pendingSlots = BoundMask.uavMask & DirtyMask.uavMask;

  if (pendingSlots) {
    // The compute stage reads from its own UAV array, any other
    // stage reads from the output merger UAV array.
    const auto& uavArray = Stage == DxbcProgramType::ComputeShader
      ? m_state.uav.views
      : m_state.om.uavs;

    // Dirty bits have to be cleared before the views get bound
    DirtyMask.uavMask -= pendingSlots;

    for (uint32_t uavIndex : bit::BitMask(pendingSlots))
      BindUnorderedAccessView(Stage, uavIndex, uavArray[uavIndex].ptr());
  }
}


template<typename ContextType>
void D3D11CommonContext<ContextType>::ApplyDirtyGraphicsBindings() {
auto dirtyMask = m_state.lazy.shadersDirty & m_state.lazy.shadersUsed;
Expand Down Expand Up @@ -3271,6 +3299,7 @@ namespace dxvk {
ApplyDirtySamplers(stage, boundMask, dirtyMask);
ApplyDirtyConstantBuffers(stage, boundMask, dirtyMask);
ApplyDirtyShaderResources(stage, boundMask, dirtyMask);
ApplyDirtyUnorderedAccessViews(stage, boundMask, dirtyMask);

m_state.lazy.shadersDirty.clr(stage);
}
Expand Down Expand Up @@ -3921,8 +3950,7 @@ namespace dxvk {
void D3D11CommonContext<ContextType>::BindUnorderedAccessView(
DxbcProgramType ShaderStage,
UINT Slot,
D3D11UnorderedAccessView* pUav,
UINT Counter) {
D3D11UnorderedAccessView* pUav) {
uint32_t uavSlotId = computeUavBinding(ShaderStage, Slot);
uint32_t ctrSlotId = computeUavCounterBinding(ShaderStage, Slot);

Expand All @@ -3937,19 +3965,8 @@ namespace dxvk {
cCtrSlotId = ctrSlotId,
cStages = stages,
cBufferView = pUav->GetBufferView(),
cCounterView = pUav->GetCounterView(),
cCounterValue = Counter
cCounterView = pUav->GetCounterView()
] (DxvkContext* ctx) mutable {
if (cCounterView != nullptr && cCounterValue != ~0u) {
DxvkBufferSlice counterSlice(cCounterView);

ctx->updateBuffer(
counterSlice.buffer(),
counterSlice.offset(),
sizeof(uint32_t),
&cCounterValue);
}

ctx->bindResourceBufferView(cStages, cUavSlotId,
Forwarder::move(cBufferView));
ctx->bindResourceBufferView(cStages, cCtrSlotId,
Expand Down Expand Up @@ -4432,6 +4449,18 @@ namespace dxvk {
}


// Forwards the UAV slot to DirtyBindingGeneric, using the lazy-binding
// "used" and "dirty" UAV masks of the given stage and a mask with only
// the bit for Slot set. The result of that call is returned as-is; the
// callers in this file bind the view immediately when it is false.
template<typename ContextType>
bool D3D11CommonContext<ContextType>::DirtyUnorderedAccessView(
        DxbcProgramType   ShaderStage,
        uint32_t          Slot,
        bool              IsNull) {
  // Single-bit mask that identifies the slot within the 64-bit UAV mask
  uint64_t slotBit = uint64_t(1u) << Slot;

  return DirtyBindingGeneric(
    ShaderStage,
    m_state.lazy.bindingsUsed[ShaderStage].uavMask,
    m_state.lazy.bindingsDirty[ShaderStage].uavMask,
    slotBit,
    IsNull);
}


template<typename ContextType>
void D3D11CommonContext<ContextType>::DiscardBuffer(
ID3D11Resource* pResource) {
Expand Down Expand Up @@ -4854,7 +4883,7 @@ namespace dxvk {
if (CheckViewOverlap(pView, m_state.om.uavs[i].ptr())) {
m_state.om.uavs[i] = nullptr;

BindUnorderedAccessView(DxbcProgramType::PixelShader, i, nullptr, ~0u);
BindUnorderedAccessView(DxbcProgramType::PixelShader, i, nullptr);
}
}
}
Expand Down Expand Up @@ -4962,7 +4991,7 @@ namespace dxvk {
: m_state.om.maxUav;

for (uint32_t i = 0; i < maxCount; i++)
BindUnorderedAccessView(Stage, i, views[i].ptr(), ~0u);
BindUnorderedAccessView(Stage, i, views[i].ptr());
}


Expand Down Expand Up @@ -5192,10 +5221,13 @@ namespace dxvk {
ctr = pUAVInitialCounts ? pUAVInitialCounts[i - UAVStartSlot] : ~0u;
}

if (m_state.om.uavs[i] != uav || ctr != ~0u) {
if (ctr != ~0u && uav && uav->HasCounter())
UpdateUnorderedAccessViewCounter(uav, ctr);

if (m_state.om.uavs[i] != uav) {
m_state.om.uavs[i] = uav;

BindUnorderedAccessView(DxbcProgramType::PixelShader, i, uav, ctr);
BindUnorderedAccessView(DxbcProgramType::PixelShader, i, uav);
ResolveOmSrvHazards(uav);

if (NumRTVs == D3D11_KEEP_RENDER_TARGETS_AND_DEPTH_STENCIL)
Expand Down Expand Up @@ -5590,6 +5622,20 @@ namespace dxvk {
}


// Emits a CS command that writes CounterValue into the buffer range
// referenced by the counter view of pUav. The counter view is fetched
// and dereferenced without a null check, so pUav must be non-null and
// must own a counter view; the callers in this file test HasCounter()
// before calling.
template<typename ContextType>
void D3D11CommonContext<ContextType>::UpdateUnorderedAccessViewCounter(
        D3D11UnorderedAccessView* pUav,
        uint32_t                  CounterValue) {
  EmitCs([
    cCounterView  = pUav->GetCounterView(),
    cCounterValue = CounterValue
  ] (DxvkContext* ctx) {
    // Overwrite the 32-bit counter at the offset the view points to
    ctx->updateBuffer(
      cCounterView->buffer(),
      cCounterView->info().offset,
      sizeof(uint32_t),
      &cCounterValue);
  });
}


template<typename ContextType>
bool D3D11CommonContext<ContextType>::ValidateRenderTargets(
UINT NumViews,
Expand Down
17 changes: 15 additions & 2 deletions src/d3d11/d3d11_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,11 @@ namespace dxvk {
const DxbcBindingMask& BoundMask,
DxbcBindingMask& DirtyMask);

// Binds all unordered access views of the given stage whose slots
// are set in both BoundMask and DirtyMask, and removes those slots
// from DirtyMask.
void ApplyDirtyUnorderedAccessViews(
DxbcProgramType Stage,
const DxbcBindingMask& BoundMask,
DxbcBindingMask& DirtyMask);

void ApplyDirtyGraphicsBindings();

void ApplyDirtyComputeBindings();
Expand Down Expand Up @@ -904,8 +909,7 @@ namespace dxvk {
void BindUnorderedAccessView(
DxbcProgramType ShaderStage,
UINT Slot,
D3D11UnorderedAccessView* pUav,
UINT Counter);
D3D11UnorderedAccessView* pUav);

VkClearValue ConvertColorValue(
const FLOAT Color[4],
Expand Down Expand Up @@ -957,6 +961,11 @@ namespace dxvk {
uint32_t Slot,
bool IsNull);

// Runs the generic dirty-tracking helper for a single UAV slot of
// the given stage. IsNull tells the helper whether the slot is
// being set to a null view. Callers bind the view right away when
// this returns false.
bool DirtyUnorderedAccessView(
DxbcProgramType ShaderStage,
uint32_t Slot,
bool IsNull);

void DiscardBuffer(
ID3D11Resource* pResource);

Expand Down Expand Up @@ -1115,6 +1124,10 @@ namespace dxvk {
UINT SrcDepthPitch,
UINT CopyFlags);

// Writes CounterValue to the counter buffer of the given view.
// pUav must be non-null and must have a counter view.
void UpdateUnorderedAccessViewCounter(
D3D11UnorderedAccessView* pUav,
uint32_t CounterValue);

bool ValidateRenderTargets(
UINT NumViews,
ID3D11RenderTargetView* const* ppRenderTargetViews,
Expand Down
19 changes: 19 additions & 0 deletions src/d3d11/d3d11_context_imm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,18 @@ namespace dxvk {
for (uint32_t index : bit::BitMask(cDirtyState[dxStage].srvMask[m]))
ctx->bindResourceImageView(vkStage, srvSlot + index + m * 64u, nullptr);
}

// Unbind all dirty unordered access views. Only consider compute
// here since we don't actually lazy-bind graphics UAVs.
if (dxStage == DxbcProgramType::ComputeShader) {
auto uavSlot = computeUavBinding(dxStage, 0);
auto ctrSlot = computeUavCounterBinding(dxStage, 0);

for (uint32_t index : bit::BitMask(cDirtyState[dxStage].uavMask)) {
ctx->bindResourceImageView(vkStage, uavSlot + index, nullptr);
ctx->bindResourceBufferView(vkStage, ctrSlot + index, nullptr);
}
}
}
});

Expand All @@ -1045,6 +1057,13 @@ namespace dxvk {
}
}

if (stage == DxbcProgramType::ComputeShader) {
for (uint32_t index : bit::BitMask(dirtyState[stage].uavMask)) {
if (!m_state.uav.views[index].ptr())
dirtyState[stage].uavMask &= ~(uint64_t(1u) << index);
}
}

if (dirtyState[stage].empty())
m_state.lazy.shadersDirty.clr(stage);
}
Expand Down
4 changes: 4 additions & 0 deletions src/d3d11/d3d11_view_uav.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ namespace dxvk {
return m_info.BindFlags & Flags;
}

// Reports whether this view has a counter view attached,
// i.e. whether m_counterView is non-null.
BOOL HasCounter() const {
  const bool counterPresent = m_counterView != nullptr;
  return counterPresent;
}

D3D11_RESOURCE_DIMENSION GetResourceType() const {
D3D11_RESOURCE_DIMENSION type;
m_resource->GetType(&type);
Expand Down

0 comments on commit 126798b

Please sign in to comment.