Skip to content

Commit

Permalink
Merge pull request #1 from neodsp/dev
Browse files Browse the repository at this point in the history
FftConvolver is now generic for f32 and f64
  • Loading branch information
steckes authored Oct 19, 2023
2 parents c2438f8 + 25b5685 commit 7f09eb8
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 77 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fft-convolver"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
authors = ["Stephan Eckes <[email protected]>"]
description = "Audio convolution algorithm in pure Rust for real time audio processing"
Expand All @@ -17,3 +17,4 @@ homepage = "https://neodsp.com/"
realfft = "3.0.1"
rustfft = "6.0.1"
thiserror = "1.0.37"
num = "0.4.1"
28 changes: 16 additions & 12 deletions src/fft.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,48 @@
use realfft::{ComplexToReal, FftError, RealFftPlanner, RealToComplex};
use rustfft::num_complex::Complex;
use rustfft::{num_complex::Complex, FftNum};
use std::sync::Arc;

pub struct Fft {
fft_forward: Arc<dyn RealToComplex<f32>>,
fft_inverse: Arc<dyn ComplexToReal<f32>>,
pub struct Fft<F: FftNum> {
fft_forward: Arc<dyn RealToComplex<F>>,
fft_inverse: Arc<dyn ComplexToReal<F>>,
}

impl std::fmt::Debug for Fft {
impl<F: FftNum> std::fmt::Debug for Fft<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "")
}
}

impl Fft {
pub fn default() -> Self {
let mut planner = RealFftPlanner::<f32>::new();
impl<F: FftNum> Default for Fft<F> {
fn default() -> Self {
let mut planner = RealFftPlanner::new();
Self {
fft_forward: planner.plan_fft_forward(0),
fft_inverse: planner.plan_fft_inverse(0),
}
}
}

impl<F: FftNum> Fft<F> {
pub fn init(&mut self, length: usize) {
let mut planner = RealFftPlanner::<f32>::new();
let mut planner = RealFftPlanner::new();
self.fft_forward = planner.plan_fft_forward(length);
self.fft_inverse = planner.plan_fft_inverse(length);
}

pub fn forward(&self, input: &mut [f32], output: &mut [Complex<f32>]) -> Result<(), FftError> {
pub fn forward(&self, input: &mut [F], output: &mut [Complex<F>]) -> Result<(), FftError> {
self.fft_forward.process(input, output)?;
Ok(())
}

pub fn inverse(&self, input: &mut [Complex<f32>], output: &mut [f32]) -> Result<(), FftError> {
pub fn inverse(&self, input: &mut [Complex<F>], output: &mut [F]) -> Result<(), FftError> {
self.fft_inverse.process(input, output)?;

// FFT Normalization
let len = output.len();
output.iter_mut().for_each(|bin| *bin /= len as f32);
output.iter_mut().for_each(|bin| {
*bin = *bin / F::from_usize(len).expect("usize can be converted to FftNum");
});

Ok(())
}
Expand Down
93 changes: 47 additions & 46 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ use crate::fft::Fft;
use crate::utilities::{
complex_multiply_accumulate, complex_size, copy_and_pad, next_power_of_2, sum,
};
use num::Zero;
use realfft::FftError;
use rustfft::num_complex::Complex;
use rustfft::FftNum;
use thiserror::Error;

#[derive(Error, Debug)]
Expand Down Expand Up @@ -38,47 +40,49 @@ pub enum FFTConvolverProcessError {
/// performed during processing (all necessary allocations and preparations take
/// place during initialization).
#[derive(Debug)]
pub struct FFTConvolver {
pub struct FFTConvolver<F: FftNum> {
ir_len: usize,
block_size: usize,
seg_size: usize,
seg_count: usize,
active_seg_count: usize,
fft_complex_size: usize,
segments: Vec<Vec<Complex<f32>>>,
segments_ir: Vec<Vec<Complex<f32>>>,
fft_buffer: Vec<f32>,
fft: Fft,
pre_multiplied: Vec<Complex<f32>>,
conv: Vec<Complex<f32>>,
overlap: Vec<f32>,
segments: Vec<Vec<Complex<F>>>,
segments_ir: Vec<Vec<Complex<F>>>,
fft_buffer: Vec<F>,
fft: Fft<F>,
pre_multiplied: Vec<Complex<F>>,
conv: Vec<Complex<F>>,
overlap: Vec<F>,
current: usize,
input_buffer: Vec<f32>,
input_buffer: Vec<F>,
input_buffer_fill: usize,
}

impl FFTConvolver {
pub fn default() -> Self {
impl<F: FftNum> Default for FFTConvolver<F> {
fn default() -> Self {
Self {
ir_len: 0,
block_size: 0,
seg_size: 0,
seg_count: 0,
active_seg_count: 0,
fft_complex_size: 0,
segments: Vec::new(),
segments_ir: Vec::new(),
fft_buffer: Vec::new(),
fft: Fft::default(),
pre_multiplied: Vec::new(),
conv: Vec::new(),
overlap: Vec::new(),
current: 0,
input_buffer: Vec::new(),
input_buffer_fill: 0,
ir_len: Default::default(),
block_size: Default::default(),
seg_size: Default::default(),
seg_count: Default::default(),
active_seg_count: Default::default(),
fft_complex_size: Default::default(),
segments: Default::default(),
segments_ir: Default::default(),
fft_buffer: Default::default(),
fft: Default::default(),
pre_multiplied: Default::default(),
conv: Default::default(),
overlap: Default::default(),
current: Default::default(),
input_buffer: Default::default(),
input_buffer_fill: Default::default(),
}
}
}

impl<F: FftNum> FFTConvolver<F> {
/// Resets the convolver and discards the set impulse response
pub fn reset(&mut self) {
*self = Self::default();
Expand All @@ -95,7 +99,7 @@ impl FFTConvolver {
pub fn init(
&mut self,
block_size: usize,
impulse_response: &[f32],
impulse_response: &[F],
) -> Result<(), FFTConvolverInitError> {
self.reset();

Expand All @@ -117,17 +121,15 @@ impl FFTConvolver {

// FFT
self.fft.init(self.seg_size);
self.fft_buffer.resize(self.seg_size, 0.);
self.fft_buffer.resize(self.seg_size, F::zero());

// prepare segments
self.segments.resize(
self.seg_count,
vec![Complex::new(0., 0.); self.fft_complex_size],
);
self.segments
.resize(self.seg_count, vec![Complex::zero(); self.fft_complex_size]);

// prepare ir
for i in 0..self.seg_count {
let mut segment = vec![Complex::new(0., 0.); self.fft_complex_size];
let mut segment = vec![Complex::zero(); self.fft_complex_size];
let remaining = self.ir_len - (i * self.block_size);
let size_copy = if remaining >= self.block_size {
self.block_size
Expand All @@ -145,13 +147,12 @@ impl FFTConvolver {

// prepare convolution buffers
self.pre_multiplied
.resize(self.fft_complex_size, Complex::new(0., 0.));
self.conv
.resize(self.fft_complex_size, Complex::new(0., 0.));
self.overlap.resize(self.block_size, 0.);
.resize(self.fft_complex_size, Complex::zero());
self.conv.resize(self.fft_complex_size, Complex::zero());
self.overlap.resize(self.block_size, F::zero());

// prepare input buffer
self.input_buffer.resize(self.block_size, 0.);
self.input_buffer.resize(self.block_size, F::zero());
self.input_buffer_fill = 0;

// reset current position
Expand All @@ -168,11 +169,11 @@ impl FFTConvolver {
/// * `output` - The convolution result
pub fn process(
&mut self,
input: &[f32],
output: &mut [f32],
input: &[F],
output: &mut [F],
) -> Result<(), FFTConvolverProcessError> {
if self.active_seg_count == 0 {
output.fill(0.);
output.fill(F::zero());
return Ok(());
}

Expand All @@ -194,13 +195,13 @@ impl FFTConvolver {
.fft
.forward(&mut self.fft_buffer, &mut self.segments[self.current])
{
output.fill(0.);
output.fill(F::zero());
return Err(err.into());
}

// complex multiplication
if input_buffer_was_empty {
self.pre_multiplied.fill(Complex { re: 0., im: 0. });
self.pre_multiplied.fill(Complex::zero());
for i in 1..self.active_seg_count {
let index_ir = i;
let index_audio = (self.current + i) % self.active_seg_count;
Expand All @@ -220,7 +221,7 @@ impl FFTConvolver {

// Backward FFT
if let Err(err) = self.fft.inverse(&mut self.conv, &mut self.fft_buffer) {
output.fill(0.);
output.fill(F::zero());
return Err(err.into());
}

Expand All @@ -235,7 +236,7 @@ impl FFTConvolver {
self.input_buffer_fill += processing;
if self.input_buffer_fill == self.block_size {
// Input buffer is empty again now
self.input_buffer.fill(0.);
self.input_buffer.fill(F::zero());
self.input_buffer_fill = 0;
// Save the overlap
self.overlap
Expand Down
46 changes: 28 additions & 18 deletions src/utilities.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use rustfft::num_complex::Complex;
use rustfft::{num_complex::Complex, FftNum};

pub fn next_power_of_2(value: usize) -> usize {
let mut new_value = 1;
Expand All @@ -14,38 +14,48 @@ pub fn complex_size(size: usize) -> usize {
(size / 2) + 1
}

pub fn copy_and_pad(dst: &mut [f32], src: &[f32], src_size: usize) {
pub fn copy_and_pad<F: FftNum>(dst: &mut [F], src: &[F], src_size: usize) {
assert!(dst.len() >= src_size);
dst[0..src_size].clone_from_slice(&src[0..src_size]);
dst[src_size..].iter_mut().for_each(|value| *value = 0.);
dst[src_size..]
.iter_mut()
.for_each(|value| *value = F::zero());
}

pub fn complex_multiply_accumulate(
result: &mut [Complex<f32>],
a: &[Complex<f32>],
b: &[Complex<f32>],
pub fn complex_multiply_accumulate<F: FftNum>(
result: &mut [Complex<F>],
a: &[Complex<F>],
b: &[Complex<F>],
) {
assert_eq!(result.len(), a.len());
assert_eq!(result.len(), b.len());
let len = result.len();
let end4 = 4 * (len / 4);
for i in (0..end4).step_by(4) {
result[i + 0].re += a[i + 0].re * b[i + 0].re - a[i + 0].im * b[i + 0].im;
result[i + 1].re += a[i + 1].re * b[i + 1].re - a[i + 1].im * b[i + 1].im;
result[i + 2].re += a[i + 2].re * b[i + 2].re - a[i + 2].im * b[i + 2].im;
result[i + 3].re += a[i + 3].re * b[i + 3].re - a[i + 3].im * b[i + 3].im;
result[i + 0].im += a[i + 0].re * b[i + 0].im + a[i + 0].im * b[i + 0].re;
result[i + 1].im += a[i + 1].re * b[i + 1].im + a[i + 1].im * b[i + 1].re;
result[i + 2].im += a[i + 2].re * b[i + 2].im + a[i + 2].im * b[i + 2].re;
result[i + 3].im += a[i + 3].re * b[i + 3].im + a[i + 3].im * b[i + 3].re;
result[i + 0].re =
result[i + 0].re + (a[i + 0].re * b[i + 0].re - a[i + 0].im * b[i + 0].im);
result[i + 1].re =
result[i + 1].re + (a[i + 1].re * b[i + 1].re - a[i + 1].im * b[i + 1].im);
result[i + 2].re =
result[i + 2].re + (a[i + 2].re * b[i + 2].re - a[i + 2].im * b[i + 2].im);
result[i + 3].re =
result[i + 3].re + (a[i + 3].re * b[i + 3].re - a[i + 3].im * b[i + 3].im);
result[i + 0].im =
result[i + 0].im + (a[i + 0].re * b[i + 0].im + a[i + 0].im * b[i + 0].re);
result[i + 1].im =
result[i + 1].im + (a[i + 1].re * b[i + 1].im + a[i + 1].im * b[i + 1].re);
result[i + 2].im =
result[i + 2].im + (a[i + 2].re * b[i + 2].im + a[i + 2].im * b[i + 2].re);
result[i + 3].im =
result[i + 3].im + (a[i + 3].re * b[i + 3].im + a[i + 3].im * b[i + 3].re);
}
for i in end4..len {
result[i].re += a[i].re * b[i].re - a[i].im * b[i].im;
result[i].im += a[i].re * b[i].im + a[i].im * b[i].re;
result[i].re = result[i].re + (a[i].re * b[i].re - a[i].im * b[i].im);
result[i].im = result[i].im + (a[i].re * b[i].im + a[i].im * b[i].re);
}
}

pub fn sum(result: &mut [f32], a: &[f32], b: &[f32]) {
pub fn sum<F: FftNum>(result: &mut [F], a: &[F], b: &[F]) {
assert_eq!(result.len(), a.len());
assert_eq!(result.len(), b.len());
let len = result.len();
Expand Down

0 comments on commit 7f09eb8

Please sign in to comment.