Skip to content

Commit

Permalink
dwidenoise: Modularise kernel
Browse files Browse the repository at this point in the history
Better separation of code responsible for fetching a batch of input data within a sliding spatial window from the code responsible for the denoising of the image data.
  • Loading branch information
Lestropie committed Nov 4, 2024
1 parent 4d6f20f commit 4424cb4
Showing 1 changed file with 136 additions and 82 deletions.
218 changes: 136 additions & 82 deletions cmd/dwidenoise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,29 +144,98 @@ void usage() {

using real_type = float;

template <typename F = float> class DenoisingFunctor {
// Class to encode return information from kernel
template <class MatrixType> class KernelData {
public:
KernelData(const size_t volumes, const size_t kernel_size)
: centre_index(-1), //
voxel_count(kernel_size), //
X(MatrixType::Zero(volumes, kernel_size)) {} //
size_t centre_index;
size_t voxel_count;
MatrixType X;
};

template <class MatrixType> class KernelBase {
public:
KernelBase() : pos({-1, -1, -1}) {}
KernelBase(const KernelBase &) : pos({-1, -1, -1}) {}

protected:
// Store / restore position of image before / after data loading
std::array<ssize_t, 3> pos;
template <class ImageType> void stash_pos(const ImageType &image) {
for (size_t axis = 0; axis != 3; ++axis)
pos[axis] = image.index(axis);
}
template <class ImageType> void restore_pos(ImageType &image) {
for (size_t axis = 0; axis != 3; ++axis)
image.index(axis) = pos[axis];
}
};

template <class MatrixType> class KernelCube : public KernelBase<MatrixType> {
public:
KernelCube(const std::vector<uint32_t> &extent)
: half_extent({int(extent[0] / 2), int(extent[1] / 2), int(extent[2] / 2)}) {
for (auto e : extent) {
if (!(e % 2))
throw Exception("Size of cubic kernel must be an odd integer");
}
}
KernelCube(const KernelCube &) = default;
template <class ImageType> void operator()(ImageType &image, KernelData<MatrixType> &data) {
assert(data.X.cols() == size());
KernelBase<MatrixType>::stash_pos(image);
size_t k = 0;
for (int z = -half_extent[2]; z <= half_extent[2]; z++) {
image.index(2) = wrapindex(z, 2, image.size(2));
for (int y = -half_extent[1]; y <= half_extent[1]; y++) {
image.index(1) = wrapindex(y, 1, image.size(1));
for (int x = -half_extent[0]; x <= half_extent[0]; x++, k++) {
image.index(0) = wrapindex(x, 0, image.size(0));
data.X.col(k) = image.row(3);
}
}
}
KernelBase<MatrixType>::restore_pos(image);
data.voxel_count = size();
data.centre_index = size() / 2;
}
size_t size() const { return (2 * half_extent[0] + 1) * (2 * half_extent[1] + 1) * (2 * half_extent[2] + 1); }

private:
const std::vector<int> half_extent;

// patch handling at image edges
inline size_t wrapindex(int r, int axis, int max) const {
int rr = KernelBase<MatrixType>::pos[axis] + r;
if (rr < 0)
rr = half_extent[axis] - r;
if (rr >= max)
rr = (max - 1) - half_extent[axis] - r;
return rr;
}
};

template <typename F, class KernelType> class DenoisingFunctor {

public:
using MatrixType = Eigen::Matrix<F, Eigen::Dynamic, Eigen::Dynamic>;
using SValsType = Eigen::VectorXd;

DenoisingFunctor(int ndwi,
const std::vector<uint32_t> &extent,
Image<bool> &mask,
Image<real_type> &noise,
Image<uint16_t> &rank,
bool exp1)
: extent{{extent[0] / 2, extent[1] / 2, extent[2] / 2}},
DenoisingFunctor(
int ndwi, KernelType &kernel, Image<bool> &mask, Image<real_type> &noise, Image<uint16_t> &rank, bool exp1)
: data(ndwi, kernel.size()),
kernel(kernel),
m(ndwi),
n(extent[0] * extent[1] * extent[2]),
n(kernel.size()),
r(std::min(m, n)),
q(std::max(m, n)),
exp1(exp1),
X(m, n),
XtX(r, r),
eig(r),
s(r),
pos{{0, 0, 0}},
mask(mask),
noise(noise),
rankmap(rank) {}
Expand All @@ -180,13 +249,13 @@ template <typename F = float> class DenoisingFunctor {
}

// Load data in local window
load_data(dwi);
kernel(dwi, data);

// Compute Eigendecomposition:
if (m <= n)
XtX.template triangularView<Eigen::Lower>() = X * X.adjoint();
XtX.template triangularView<Eigen::Lower>() = data.X * data.X.adjoint();
else
XtX.template triangularView<Eigen::Lower>() = X.adjoint() * X;
XtX.template triangularView<Eigen::Lower>() = data.X.adjoint() * data.X;
eig.compute(XtX);
// eigenvalues sorted in increasing order:
s = eig.eigenvalues().template cast<double>();
Expand Down Expand Up @@ -215,14 +284,18 @@ template <typename F = float> class DenoisingFunctor {
s.head(cutoff_p).setZero();
s.tail(r - cutoff_p).setOnes();
if (m <= n)
X.col(n / 2) = eig.eigenvectors() * (s.cast<F>().asDiagonal() * (eig.eigenvectors().adjoint() * X.col(n / 2)));
data.X.col(data.centre_index) =
eig.eigenvectors() *
(s.cast<F>().asDiagonal() * (eig.eigenvectors().adjoint() * data.X.col(data.centre_index)));
else
X.col(n / 2) = X * (eig.eigenvectors() * (s.cast<F>().asDiagonal() * eig.eigenvectors().adjoint().col(n / 2)));
data.X.col(data.centre_index) =
data.X *
(eig.eigenvectors() * (s.cast<F>().asDiagonal() * eig.eigenvectors().adjoint().col(data.centre_index)));
}

// Store output
assign_pos_of(dwi).to(out);
out.row(3) = X.col(n / 2);
out.row(3) = data.X.col(data.centre_index);

// store noise map if requested:
if (noise.valid()) {
Expand All @@ -237,85 +310,47 @@ template <typename F = float> class DenoisingFunctor {
}

private:
const std::array<ssize_t, 3> extent;
KernelData<MatrixType> data;
KernelType kernel;
const ssize_t m, n, r, q;
const bool exp1;
MatrixType X;
MatrixType XtX;
Eigen::SelfAdjointEigenSolver<MatrixType> eig;
SValsType s;
std::array<ssize_t, 3> pos;
double sigma2;
Image<bool> mask;
Image<real_type> noise;
Image<uint16_t> rankmap;

template <typename ImageType> void load_data(ImageType &dwi) {
pos[0] = dwi.index(0);
pos[1] = dwi.index(1);
pos[2] = dwi.index(2);
// fill patch
X.setZero();
size_t k = 0;
for (int z = -extent[2]; z <= extent[2]; z++) {
dwi.index(2) = wrapindex(z, 2, dwi.size(2));
for (int y = -extent[1]; y <= extent[1]; y++) {
dwi.index(1) = wrapindex(y, 1, dwi.size(1));
for (int x = -extent[0]; x <= extent[0]; x++, k++) {
dwi.index(0) = wrapindex(x, 0, dwi.size(0));
X.col(k) = dwi.row(3);
}
}
}
// reset image position
dwi.index(0) = pos[0];
dwi.index(1) = pos[1];
dwi.index(2) = pos[2];
}

inline size_t wrapindex(int r, int axis, int max) const {
// patch handling at image edges
int rr = pos[axis] + r;
if (rr < 0)
rr = extent[axis] - r;
if (rr >= max)
rr = (max - 1) - extent[axis] - r;
return rr;
}
};

template <typename T>
template <typename T, class KernelType>
void process_image(Header &data,
Image<bool> &mask,
Image<real_type> &noise,
Image<uint16_t> &rank,
const std::string &output_name,
const std::vector<uint32_t> &extent,
KernelType &kernel,
bool exp1) {
auto input = data.get_image<T>().with_direct_io(3);
// create output
Header header(data);
header.datatype() = DataType::from<T>();
auto output = Image<T>::create(output_name, header);
// run
DenoisingFunctor<T> func(data.size(3), extent, mask, noise, rank, exp1);
DenoisingFunctor<T, KernelType> func(data.size(3), kernel, mask, noise, rank, exp1);
ThreadedLoop("running MP-PCA denoising", data, 0, 3).run(func, input, output);
}

void run() {
auto dwi = Header::open(argument[0]);

if (dwi.ndim() != 4 || dwi.size(3) <= 1)
throw Exception("input image must be 4-dimensional");

Image<bool> mask;
auto opt = get_options("mask");
if (!opt.empty()) {
mask = Image<bool>::open(opt[0][0]);
check_dimensions(mask, dwi, 0, 3);
}

opt = get_options("extent");
template <typename T>
void make_kernel(Header &data,
Image<bool> &mask,
Image<real_type> &noise,
Image<uint16_t> &rank,
const std::string &output_name,
bool exp1) {
using KernelType = KernelCube<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>;

auto opt = get_options("extent");
std::vector<uint32_t> extent;
if (!opt.empty()) {
extent = parse_ints<uint32_t>(opt[0][0]);
Expand All @@ -326,25 +361,44 @@ void run() {
for (int i = 0; i < 3; i++) {
if (!(extent[i] & 1))
throw Exception("-extent must be a (list of) odd numbers");
if (extent[i] > dwi.size(i))
if (extent[i] > data.size(i))
throw Exception("-extent must not exceed the image dimensions");
}
} else {
uint32_t e = 1;
while (e * e * e < dwi.size(3))
while (Math::pow3(e) < data.size(3))
e += 2;
extent = {
std::min(e, uint32_t(dwi.size(0))), std::min(e, uint32_t(dwi.size(1))), std::min(e, uint32_t(dwi.size(2)))};
extent = {std::min(e, uint32_t(data.size(0))), //
std::min(e, uint32_t(data.size(1))), //
std::min(e, uint32_t(data.size(2)))}; //
}
INFO("selected patch size: " + str(extent[0]) + " x " + str(extent[1]) + " x " + str(extent[2]) + ".");

bool exp1 = get_option_value("estimator", 1) == 0; // default: Exp2 (unbiased estimator)
if (std::min<uint32_t>(data.size(3), extent[0] * extent[1] * extent[2]) < 15) {
WARN("The number of volumes or the patch size is small. "
"This may lead to discretisation effects in the noise level "
"and cause inconsistent denoising between adjacent voxels.");
}

KernelType kernel(extent);
process_image<T, KernelType>(data, mask, noise, rank, output_name, kernel, exp1);
}

if (std::min<uint32_t>(dwi.size(3), extent[0] * extent[1] * extent[2]) < 15) {
WARN("The number of volumes or the patch size is small. This may lead to discretisation effects "
"in the noise level and cause inconsistent denoising between adjacent voxels.");
void run() {
auto dwi = Header::open(argument[0]);

if (dwi.ndim() != 4 || dwi.size(3) <= 1)
throw Exception("input image must be 4-dimensional");

Image<bool> mask;
auto opt = get_options("mask");
if (!opt.empty()) {
mask = Image<bool>::open(opt[0][0]);
check_dimensions(mask, dwi, 0, 3);
}

bool exp1 = get_option_value("estimator", 1) == 0; // default: Exp2 (unbiased estimator)

Image<real_type> noise;
opt = get_options("noise");
if (!opt.empty()) {
Expand All @@ -370,19 +424,19 @@ void run() {
switch (prec) {
case 0:
INFO("select real float32 for processing");
process_image<float>(dwi, mask, noise, rank, argument[1], extent, exp1);
make_kernel<float>(dwi, mask, noise, rank, argument[1], exp1);
break;
case 1:
INFO("select real float64 for processing");
process_image<double>(dwi, mask, noise, rank, argument[1], extent, exp1);
make_kernel<double>(dwi, mask, noise, rank, argument[1], exp1);
break;
case 2:
INFO("select complex float32 for processing");
process_image<cfloat>(dwi, mask, noise, rank, argument[1], extent, exp1);
make_kernel<cfloat>(dwi, mask, noise, rank, argument[1], exp1);
break;
case 3:
INFO("select complex float64 for processing");
process_image<cdouble>(dwi, mask, noise, rank, argument[1], extent, exp1);
make_kernel<cdouble>(dwi, mask, noise, rank, argument[1], exp1);
break;
}
}

0 comments on commit 4424cb4

Please sign in to comment.