Skip to content

Commit

Permalink
Avoid a copy of the routing table; move routing state
Browse files Browse the repository at this point in the history
The routing table doesn't need to be copied in prepapre, it
can live by reference and then the routing state can be
configured during prepare so the topology is prepared byt
depth, active, etc... can be dynamic.

Small API change implicit.
  • Loading branch information
baconpaul committed Feb 24, 2024
1 parent 5a32bd0 commit 7f256b5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 45 deletions.
80 changes: 36 additions & 44 deletions include/sst/basic-blocks/mod-matrix/ModMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ template <typename ModMatrixTraits> struct FixedLengthRoutingTable : RoutingTabl
static_assert(std::is_same<decltype(TR::FixedMatrixSize), const size_t>::value);

std::array<typename RT::Routing, TR::FixedMatrixSize> routes{};
std::unordered_map<typename TR::TargetIdentifier, bool> isOutputMapped;
std::unordered_map<typename TR::TargetIdentifier, size_t> targetToOutputIndex;
std::unordered_map<typename TR::SourceIdentifier, bool> isSourceUsed;

// fixed API for changing the mod matrix in increasing completeness
void updateDepthAt(size_t position, float depth)
Expand All @@ -132,8 +129,6 @@ template <typename ModMatrixTraits> struct FixedLengthRoutingTable : RoutingTabl
routes[position].source = source;
routes[position].target = target;
routes[position].depth = depth;

updateRoutingState();
}
void updateRoutingAt(size_t position, const typename TR::SourceIdentifier &source,
const typename TR::SourceIdentifier &sourceVia,
Expand All @@ -146,17 +141,38 @@ template <typename ModMatrixTraits> struct FixedLengthRoutingTable : RoutingTabl
routes[position].curve = curve;
routes[position].target = target;
routes[position].depth = depth;

updateRoutingState();
}
};

template <typename ModMatrixTraits> struct FixedMatrix : ModMatrix<ModMatrixTraits>
{
using TR = ModMatrixTraits;
using PT = ModMatrix<ModMatrixTraits>;
using RT = FixedLengthRoutingTable<ModMatrixTraits>;
using RoutingTable = RT;

std::array<float, TR::FixedMatrixSize> matrixOutputs{};

struct RoutingValuePointers
{
bool *active{nullptr};
float *source{nullptr}, *sourceVia{nullptr}, *depth{nullptr}, *target{nullptr};
float depthScale{1.f};
std::function<float(float)> curveFn;
};
std::array<RoutingValuePointers, TR::FixedMatrixSize> routingValuePointers{};

void updateRoutingState()
std::unordered_map<typename TR::TargetIdentifier, bool> isOutputMapped;
std::unordered_map<typename TR::TargetIdentifier, size_t> targetToOutputIndex;
std::unordered_map<typename TR::SourceIdentifier, bool> isSourceUsed;

void updateRoutingState(const RoutingTable &rt)
{
isOutputMapped.clear();
isSourceUsed.clear();
targetToOutputIndex.clear();
size_t outIdx{0};
for (auto &r : routes)
for (auto &r : rt.routes)
{
if (!r.source.has_value() || !r.target.has_value())
continue;
Expand All @@ -173,36 +189,14 @@ template <typename ModMatrixTraits> struct FixedLengthRoutingTable : RoutingTabl
}
}
}
};

template <typename ModMatrixTraits> struct FixedMatrix : ModMatrix<ModMatrixTraits>
{
using TR = ModMatrixTraits;
using PT = ModMatrix<ModMatrixTraits>;
using RT = FixedLengthRoutingTable<ModMatrixTraits>;
using RoutingTable = RT;

RoutingTable routingTable;

std::array<float, TR::FixedMatrixSize> matrixOutputs{};

struct RoutingValuePointers
{
bool *active{nullptr};
float *source{nullptr}, *sourceVia{nullptr}, *depth{nullptr}, *target{nullptr};
float depthScale{1.f};
std::function<float(float)> curveFn;
};
std::array<RoutingValuePointers, TR::FixedMatrixSize> routingValuePointers{};

void prepare(RT rt)
void prepare(RT &rt)
{
this->routingTable = rt;
this->routingTable.updateRoutingState();
updateRoutingState(rt);

int idx{0};
std::unordered_set<typename TR::TargetIdentifier> depthMaps;
for (auto &r : routingTable.routes)
for (auto &r : rt.routes)
{
if (!r.source.has_value() && !r.target.has_value())
continue;
Expand All @@ -214,8 +208,7 @@ template <typename ModMatrixTraits> struct FixedMatrix : ModMatrix<ModMatrixTrai
{
continue;
}
if (this->routingTable.targetToOutputIndex.find(*r.target) ==
this->routingTable.targetToOutputIndex.end())
if (this->targetToOutputIndex.find(*r.target) == this->targetToOutputIndex.end())
{
continue;
}
Expand Down Expand Up @@ -244,7 +237,7 @@ template <typename ModMatrixTraits> struct FixedMatrix : ModMatrix<ModMatrixTrai
rv.curveFn = nullptr;
}

rv.target = &matrixOutputs[routingTable.targetToOutputIndex.at(*r.target)];
rv.target = &matrixOutputs[targetToOutputIndex.at(*r.target)];
}

if constexpr (ModMatrix<TR>::canSelfModulate)
Expand All @@ -253,9 +246,8 @@ template <typename ModMatrixTraits> struct FixedMatrix : ModMatrix<ModMatrixTrai
{
auto depthIndex = TR::getTargetModMatrixElement(m);
assert(depthIndex < routingValuePointers.size());
routingValuePointers[depthIndex].depth =
&matrixOutputs[routingTable.targetToOutputIndex.at(m)];
this->baseValues.insert_or_assign(m, this->routingTable.routes[depthIndex].depth);
routingValuePointers[depthIndex].depth = &matrixOutputs[targetToOutputIndex.at(m)];
this->baseValues.insert_or_assign(m, rt.routes[depthIndex].depth);
}
}
}
Expand All @@ -264,7 +256,7 @@ template <typename ModMatrixTraits> struct FixedMatrix : ModMatrix<ModMatrixTrai
{
std::fill(matrixOutputs.begin(), matrixOutputs.end(), 0.f);

for (const auto &[tgt, outIdx] : routingTable.targetToOutputIndex)
for (const auto &[tgt, outIdx] : targetToOutputIndex)
{
matrixOutputs[outIdx] = this->baseValues.at(tgt);
}
Expand Down Expand Up @@ -297,14 +289,14 @@ template <typename ModMatrixTraits> struct FixedMatrix : ModMatrix<ModMatrixTrai

const float *getTargetValuePointer(const typename TR::TargetIdentifier &s) const
{
auto f = routingTable.isOutputMapped.find(s);
if (f == routingTable.isOutputMapped.end() || !f->second)
auto f = isOutputMapped.find(s);
if (f == isOutputMapped.end() || !f->second)
{
return &this->baseValues.at(s);
}
else
{
return &matrixOutputs[routingTable.targetToOutputIndex.at(s)];
return &matrixOutputs[targetToOutputIndex.at(s)];
}
}
float getTargetValue(const typename TR::TargetIdentifier &s) const
Expand Down
2 changes: 1 addition & 1 deletion tests/mod_matrix_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ TEST_CASE("Routing Activation", "[mod-matrix]")
REQUIRE(*t3P == Approx(t3V + 0.5 * barSVal).margin(1e-5));
REQUIRE(*t3PP == Approx(t3PV - 0.5 * fooSVal).margin(1e-5));

m.routingTable.routes[0].active = false;
rt.routes[0].active = false;
m.process();

REQUIRE(*t3P == Approx(t3V).margin(1e-5));
Expand Down

0 comments on commit 7f256b5

Please sign in to comment.