Skip to content

Commit

Permalink
fix(demo): update matmul
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Kröning <[email protected]>
  • Loading branch information
mkroening committed Apr 10, 2024
1 parent 9cb68d7 commit 3787cc9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 26 deletions.
1 change: 0 additions & 1 deletion examples/demo/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#![allow(dead_code)]
#![feature(thread_id_value)]

#[cfg(target_os = "hermit")]
Expand Down
51 changes: 26 additions & 25 deletions examples/demo/src/matmul.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
//! Parallel matrix multiplication.
//!
//! Taken from <https://github.com/rayon-rs/rayon/blob/main/rayon-demo/src/matmul/mod.rs>.
use std::thread;
use std::time::Instant;

/// Code is derived Rayon's matmul example
/// https://github.com/rayon-rs/rayon/tree/master/rayon-demo/src/matmul
use rayon::prelude::*;

// TODO: Investigate other cache patterns for row-major order that may be more
// parallelizable.
// https://tavianator.com/a-quick-trick-for-faster-naive-matrix-multiplication/
fn seq_matmul(a: &[f32], b: &[f32], dest: &mut [f32]) {
pub fn seq_matmul(a: &[f32], b: &[f32], dest: &mut [f32]) {
// Zero dest, as it may be uninitialized.
for d in dest.iter_mut() {
*d = 0.0;
Expand Down Expand Up @@ -74,7 +72,7 @@ fn test_splayed_counter() {
// Multiply the matrices laid out in z order.
// https://en.wikipedia.org/wiki/Z-order_curve
#[inline(never)]
fn seq_matmulz(a: &[f32], b: &[f32], dest: &mut [f32]) {
pub fn seq_matmulz(a: &[f32], b: &[f32], dest: &mut [f32]) {
// All inputs need to be the same length.
assert!(a.len() == b.len() && a.len() == dest.len());
// Input matrices must be square with each side a power of 2.
Expand Down Expand Up @@ -103,7 +101,8 @@ fn seq_matmulz(a: &[f32], b: &[f32], dest: &mut [f32]) {
}
}

const MULT_CHUNK: usize = 1024;
#[allow(clippy::identity_op)]
const MULT_CHUNK: usize = 1 * 1024;
const LINEAR_CHUNK: usize = 64 * 1024;

fn quarter_chunks(v: &[f32]) -> (&[f32], &[f32], &[f32], &[f32]) {
Expand Down Expand Up @@ -139,6 +138,7 @@ where
(r1, r2, r3, r4)
}

#[allow(clippy::too_many_arguments)]
fn join8<F1, F2, F3, F4, F5, F6, F7, F8, R1, R2, R3, R4, R5, R6, R7, R8>(
f1: F1,
f2: F2,
Expand Down Expand Up @@ -175,7 +175,7 @@ where
}

// Multiply two square power of two matrices, given in Z-order.
fn matmulz(a: &[f32], b: &[f32], dest: &mut [f32]) {
pub fn matmulz(a: &[f32], b: &[f32], dest: &mut [f32]) {
if a.len() <= MULT_CHUNK {
seq_matmulz(a, b, dest);
return;
Expand Down Expand Up @@ -206,7 +206,7 @@ fn matmulz(a: &[f32], b: &[f32], dest: &mut [f32]) {
rmatsum(tmp.as_mut(), dest);
}

fn matmul_strassen(a: &[f32], b: &[f32], dest: &mut [f32]) {
pub fn matmul_strassen(a: &[f32], b: &[f32], dest: &mut [f32]) {
if a.len() <= MULT_CHUNK {
seq_matmulz(a, b, dest);
return;
Expand Down Expand Up @@ -239,7 +239,8 @@ fn matmul_strassen(a: &[f32], b: &[f32], dest: &mut [f32]) {
}

fn raw_buffer(n: usize) -> Vec<f32> {
vec![0.0; n]
// A zero-initialized buffer is fast enough for our purposes.
vec![0f32; n]
}

fn strassen_add2_mul(a1: &[f32], a2: &[f32], b1: &[f32], b2: &[f32]) -> Vec<f32> {
Expand Down Expand Up @@ -381,22 +382,22 @@ fn timed_matmul<F: FnOnce(&[f32], &[f32], &mut [f32])>(size: usize, f: F, name:
nanos
}

const ROW_SIZE: usize = 512;
const SIZE: usize = if cfg!(debug_assertions) { 64 } else { 256 };

pub fn test_matmul_strassen() -> Result<(), ()> {
let ncpus = thread::available_parallelism().unwrap().get();
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(ncpus)
.build()
.unwrap();
let n = ROW_SIZE * ROW_SIZE;
let x = vec![1f32; n];
let y = vec![2f32; n];
let mut z = vec![0f32; n];

let now = Instant::now();
pool.install(|| matmul_strassen(&x, &y, &mut z));
println!("Time to multiply matrix {} s", now.elapsed().as_secs_f64(),);
if SIZE <= 1024 {
// Crappy algorithm takes several minutes on larger inputs.
timed_matmul(SIZE, seq_matmul, "seq row-major");
}
let seq = if SIZE <= 2048 {
timed_matmul(SIZE, seq_matmulz, "seq z-order")
} else {
0
};
let par = timed_matmul(SIZE, matmulz, "par z-order");
timed_matmul(SIZE, matmul_strassen, "par strassen");
let speedup = seq as f64 / par as f64;
println!("speedup: {:.2}x", speedup);

Ok(())
}

0 comments on commit 3787cc9

Please sign in to comment.