From ce22e7539823d113bf889dbcc7d4e171ec5b6f68 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Fri, 29 Nov 2024 15:21:22 +0530 Subject: [PATCH 01/22] feat: Update to mnn 3.0.0 --- flake.lock | 8 ++++---- flake.nix | 10 ++++++---- mnn-sys/vendor | 2 +- src/schedule.rs | 22 +++++++++++++++------- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/flake.lock b/flake.lock index eec96e3..fa6f8b2 100644 --- a/flake.lock +++ b/flake.lock @@ -109,16 +109,16 @@ "mnn-src": { "flake": false, "locked": { - "lastModified": 1728910345, - "narHash": "sha256-jsCdiFW8oIlKISo/+qrMDy8/RHblrgaQdSuQhoiOd8M=", + "lastModified": 1732022377, + "narHash": "sha256-Ui6b83zI1MaiuXeGYNftxAKoGgwVLj6rwSd7H4aVs/8=", "owner": "alibaba", "repo": "MNN", - "rev": "a74551b4f34b46ce7027c64e800d49fcab497261", + "rev": "707b8a41b25e3d0b7c4a39cd81109d7074ca3c28", "type": "github" }, "original": { "owner": "alibaba", - "ref": "2.9.6", + "ref": "3.0.0", "repo": "MNN", "type": "github" } diff --git a/flake.nix b/flake.nix index f60d9e3..cc972d8 100644 --- a/flake.nix +++ b/flake.nix @@ -22,7 +22,7 @@ flake = false; }; mnn-src = { - url = "github:alibaba/MNN/2.9.6"; + url = "github:alibaba/MNN/3.0.0"; flake = false; }; }; @@ -47,7 +47,7 @@ rust-overlay.overlays.default (final: prev: { mnn = mnn-overlay.packages.${system}.mnn.override { - version = "2.9.6"; + version = "3.0.0"; src = mnn-src; buildConverter = true; enableVulkan = false; @@ -199,7 +199,9 @@ pname = "inspect"; cargoExtraArgs = "--example inspect" - + (lib.optionalString pkgs.stdenv.isDarwin " --features opencl" + lib.optionalString pkgs.stdenv.isAarch64 ",metal,coreml"); + + ( + lib.optionalString pkgs.stdenv.isDarwin " --features opencl" # + lib.optionalString pkgs.stdenv.isAarch64 ",metal,coreml" + ); }); default = mnn; }; @@ -218,7 +220,7 @@ git git-lfs llvm - mnn + # mnn nushell rust-bindgen rustToolchainWithRustAnalyzer diff --git a/mnn-sys/vendor b/mnn-sys/vendor index a74551b..707b8a4 160000 --- a/mnn-sys/vendor +++ 
b/mnn-sys/vendor @@ -1 +1 @@ -Subproject commit a74551b4f34b46ce7027c64e800d49fcab497261 +Subproject commit 707b8a41b25e3d0b7c4a39cd81109d7074ca3c28 diff --git a/src/schedule.rs b/src/schedule.rs index 293d720..da7e570 100644 --- a/src/schedule.rs +++ b/src/schedule.rs @@ -227,7 +227,7 @@ impl ScheduleConfig { /// # Errors /// /// Returns an error if any of the tensor names contain null bytes. - pub fn set_save_tensors(&mut self, save_tensors: &[&str]) -> Result<()> { + pub fn set_save_tensors(&mut self, save_tensors: &[&str]) -> Result<&mut Self> { let vec_cstring = save_tensors .iter() .map(|s| std::ffi::CString::new(*s).map_err(|e| error!(ErrorKind::AsciiError, e))) @@ -237,7 +237,7 @@ impl ScheduleConfig { .map(|s: &CString| s.as_c_str().as_ptr()) .collect::>(); unsafe { mnnsc_set_save_tensors(self.inner, vec_cstr.as_ptr(), vec_cstr.len()) } - Ok(()) + Ok(self) } /// Sets the type of backend to be used for computation. @@ -245,10 +245,11 @@ impl ScheduleConfig { /// # Arguments /// /// - `forward_type`: The type of backend to be used. - pub fn set_type(&mut self, forward_type: ForwardType) { + pub fn set_type(&mut self, forward_type: ForwardType) -> &mut Self { unsafe { mnnsc_set_type(self.inner, forward_type.to_mnn_sys()); } + self } /// Sets the number of threads to be used for computation. @@ -256,10 +257,11 @@ impl ScheduleConfig { /// # Arguments /// /// - `num_threads`: The number of threads to be used. - pub fn set_num_threads(&mut self, num_threads: i32) { + pub fn set_num_threads(&mut self, num_threads: i32) -> &mut Self { unsafe { mnnsc_set_num_threads(self.inner, num_threads); } + self } /// Sets the mode of computation. @@ -267,10 +269,11 @@ impl ScheduleConfig { /// # Arguments /// /// - `mode`: The mode of computation. - pub fn set_mode(&mut self, mode: i32) { + pub fn set_mode(&mut self, mode: i32) -> &mut Self { unsafe { mnnsc_set_mode(self.inner, mode); } + self } /// Sets the backup type of backend to be used if the primary backend fails. 
@@ -278,10 +281,11 @@ impl ScheduleConfig { /// # Arguments /// /// - `backup_type`: The backup type of backend to be used. - pub fn set_backup_type(&mut self, backup_type: ForwardType) { + pub fn set_backup_type(&mut self, backup_type: ForwardType) -> &mut Self { unsafe { mnnsc_set_backup_type(self.inner, backup_type.to_mnn_sys()); } + self } /// Sets the backend-specific configuration. @@ -289,7 +293,10 @@ impl ScheduleConfig { /// # Arguments /// /// - `backend_config`: specifies additional backend-specific configurations. - pub fn set_backend_config(&mut self, backend_config: impl Into>) { + pub fn set_backend_config( + &mut self, + backend_config: impl Into>, + ) -> &mut Self { self.backend_config = backend_config.into(); let ptr = if let Some(ref b) = self.backend_config { b.inner @@ -299,6 +306,7 @@ impl ScheduleConfig { unsafe { mnnsc_set_backend_config(self.inner, ptr); } + self } } From dd5ee5d587d2027181b87b8b1ccd19c1a1f87266 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Mon, 2 Dec 2024 15:40:55 +0530 Subject: [PATCH 02/22] feat: Emit tracing events from mnn --- Cargo.lock | 19 +-- Cargo.toml | 6 +- examples/inspect.rs | 8 ++ flake.nix | 6 +- mnn-sys/Cargo.toml | 2 + mnn-sys/mnn_c/mnndefine.h | 29 +++++ mnn-sys/patches/mnn-tracing.patch | 34 ++++++ mnn-sys/src/lib.rs | 1 + mnn-sys/src/tracing.rs | 184 ++++++++++++++++++++++++++++++ 9 files changed, 278 insertions(+), 11 deletions(-) create mode 100644 mnn-sys/patches/mnn-tracing.patch create mode 100644 mnn-sys/src/tracing.rs diff --git a/Cargo.lock b/Cargo.lock index 5b9aee7..8b4501b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -450,6 +450,7 @@ dependencies = [ "serde", "thiserror", "tracing", + "tracing-subscriber", ] [[package]] @@ -487,7 +488,9 @@ dependencies = [ "fs_extra", "itertools", "libc", + "once_cell", "tap", + "tracing-core", ] [[package]] @@ -861,9 +864,9 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -872,9 +875,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", @@ -883,9 +886,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", "valuable", @@ -904,9 +907,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.18" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ "matchers", "nu-ansi-term 0.46.0", diff --git a/Cargo.toml b/Cargo.toml index ab65104..38fa5fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ mnn-sys = { version = "0.1", path = "mnn-sys", features = [] } thiserror = "2.0" error-stack.workspace = true oneshot = "0.1" -tracing = { version = "0.1.40", optional = true } +tracing = { version = "0.1.40" } dunce = "1.0.5" serde = { version = "1.0", features = ["derive"], optional = true } @@ -33,7 +33,8 @@ crt_static = ["mnn-sys/crt_static"] # Disable mnn-threadpool to enable this openmp = ["mnn-sys/openmp"] mnn-threadpool = 
["mnn-sys/mnn-threadpool"] -tracing = ["dep:tracing"] +# To keep compatibility with older versions newer ones automatically use it +tracing = [] profile = ["tracing"] serde = ["dep:serde"] @@ -45,6 +46,7 @@ anyhow = "1.0" bytemuck = "1.17" clap = { version = "4.5", features = ["derive"] } divan = "0.1.14" +tracing-subscriber = "0.3.19" # mnn-sync = { path = "mnn-sync" } [[bench]] diff --git a/examples/inspect.rs b/examples/inspect.rs index f7ffb9c..44b967d 100644 --- a/examples/inspect.rs +++ b/examples/inspect.rs @@ -48,6 +48,14 @@ pub fn main() -> anyhow::Result<()> { let mut interpreter = Interpreter::from_file(&cli.model)?; interpreter.set_cache_file(cli.model.with_extension("cache"), 128)?; + tracing_subscriber::fmt() + .event_format( + tracing_subscriber::fmt::format() + .with_file(true) + .with_line_number(true), + ) + .init(); + let mut config = ScheduleConfig::new(); config.set_type(cli.forward); let mut session = time!(interpreter.create_session(config)?; "create session"); diff --git a/flake.nix b/flake.nix index cc972d8..648a255 100644 --- a/flake.nix +++ b/flake.nix @@ -86,7 +86,11 @@ craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain rustToolchainWithLLvmTools; src = lib.sources.sourceFilesBySuffices ./. 
[".rs" ".toml" ".patch" ".mnn" ".h" ".cpp" ".svg" "lock"]; - MNN_SRC = mnn-src; + MNN_SRC = pkgs.applyPatches { + name = "mnn-src"; + src = mnn-src; + patches = [./mnn-sys/patches/mnn-tracing.patch]; + }; commonArgs = { inherit src MNN_SRC; pname = "mnn"; diff --git a/mnn-sys/Cargo.toml b/mnn-sys/Cargo.toml index b0cb6ce..8a3befb 100644 --- a/mnn-sys/Cargo.toml +++ b/mnn-sys/Cargo.toml @@ -31,3 +31,5 @@ crt_static = [] [dependencies] libc = "0.2.155" +once_cell = "1.20.2" +tracing-core = "0.1.33" diff --git a/mnn-sys/mnn_c/mnndefine.h b/mnn-sys/mnn_c/mnndefine.h index e69de29..6b0057d 100644 --- a/mnn-sys/mnn_c/mnndefine.h +++ b/mnn-sys/mnn_c/mnndefine.h @@ -0,0 +1,29 @@ +#ifndef MNNDEFINE_H +#define MNNDEFINE_H +#include + +enum class Level { + Info = 0, + Error = 1, +}; + +extern "C" { +void mnn_ffi_emit(const char *file, size_t line, Level level, + const char *message); +} + +#define MNN_PRINT(format, ...) \ + { \ + char logtmp[4096]; \ + snprintf(logtmp, 4096, format, ##__VA_ARGS__); \ + mnn_ffi_emit(__FILE__, __LINE__, Level::Info, logtmp); \ + } + +#define MNN_ERROR(format, ...) \ + { \ + char logtmp[4096]; \ + snprintf(logtmp, 4096, format, ##__VA_ARGS__); \ + mnn_ffi_emit(__FILE__, __LINE__, Level::Error, logtmp); \ + } + +#endif diff --git a/mnn-sys/patches/mnn-tracing.patch b/mnn-sys/patches/mnn-tracing.patch new file mode 100644 index 0000000..7939766 --- /dev/null +++ b/mnn-sys/patches/mnn-tracing.patch @@ -0,0 +1,34 @@ +diff --git a/include/MNN/MNNDefine.h b/include/MNN/MNNDefine.h +index 8f30cd68..77407812 100644 +--- a/include/MNN/MNNDefine.h ++++ b/include/MNN/MNNDefine.h +@@ -35,8 +35,27 @@ + #define MNN_PRINT(format, ...) syslog(LOG_WARNING, format, ##__VA_ARGS__); fprintf(stderr, format, ##__VA_ARGS__) + #define MNN_ERROR(format, ...) syslog(LOG_WARNING, format, ##__VA_ARGS__); fprintf(stderr, format, ##__VA_ARGS__) + #else +-#define MNN_PRINT(format, ...) printf(format, ##__VA_ARGS__) +-#define MNN_ERROR(format, ...) 
printf(format, ##__VA_ARGS__) ++enum class Level { ++ Info = 0, ++ Error = 1, ++}; ++extern "C" { ++void mnn_ffi_emit(const char *file, size_t line, Level level, ++ const char *message); ++} ++#define MNN_PRINT(format, ...) \ ++ { \ ++ char logtmp[4096]; \ ++ snprintf(logtmp, 4096, format, ##__VA_ARGS__); \ ++ mnn_ffi_emit(__FILE__, __LINE__, Level::Info, logtmp); \ ++ } ++ ++#define MNN_ERROR(format, ...) \ ++ { \ ++ char logtmp[4096]; \ ++ snprintf(logtmp, 4096, format, ##__VA_ARGS__); \ ++ mnn_ffi_emit(__FILE__, __LINE__, Level::Error, logtmp); \ ++ } + #endif + + #ifdef DEBUG diff --git a/mnn-sys/src/lib.rs b/mnn-sys/src/lib.rs index ce2bafb..c2bf3b6 100644 --- a/mnn-sys/src/lib.rs +++ b/mnn-sys/src/lib.rs @@ -1,4 +1,5 @@ use std::ffi::CStr; +mod tracing; pub mod cpp { #![allow(non_upper_case_globals)] diff --git a/mnn-sys/src/tracing.rs b/mnn-sys/src/tracing.rs new file mode 100644 index 0000000..0e77309 --- /dev/null +++ b/mnn-sys/src/tracing.rs @@ -0,0 +1,184 @@ +// This is mostly adapted from tracing-gstreamer crate's implementation +use once_cell::sync::OnceCell; +use std::sync::atomic::AtomicUsize; +use std::sync::{PoisonError, RwLock}; +use std::{collections::BTreeMap, ffi::c_char}; +use tracing_core::{field::FieldSet, identify_callsite, Callsite, Interest, Kind, Metadata}; + +pub const CALLSITE_INTEREST_NEVER: usize = 1; +pub const CALLSITE_INTEREST_SOMETIMES: usize = 2; +pub const CALLSITE_INTEREST_ALWAYS: usize = 3; + +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)] +pub enum Level { + Info = 0, + Error = 1, +} + +impl From for tracing_core::Level { + fn from(value: Level) -> Self { + match value { + Level::Info => tracing_core::Level::INFO, + Level::Error => tracing_core::Level::ERROR, + } + } +} + +pub struct DynamicCallsites { + callsites: RwLock, +} + +type Map = BTreeMap, &'static MnnCallsite>; + +impl DynamicCallsites { + pub(crate) fn get() -> &'static Self { + static MAP: OnceCell = OnceCell::new(); + 
MAP.get_or_init(|| DynamicCallsites { + callsites: RwLock::new(Map::new()), + }) + } + + fn callsite_for( + &'static self, + level: Level, + line: Option, + file: Option<&'static str>, + ) -> &'static MnnCallsite { + let mut guard = self + .callsites + .write() + .unwrap_or_else(PoisonError::into_inner); + let lookup_key = Key { level, line, file }; + if let Some(callsite) = guard.get(&lookup_key) { + return callsite; + } + let callsite = MnnCallsite::make_static(&lookup_key); + let key = Key::<'static> { + level, + line, + file: callsite.metadata.file(), + }; + guard.insert(key, callsite); + tracing_core::callsite::register(callsite); + callsite + } +} + +impl Callsite for MnnCallsite { + fn set_interest(&self, interest: Interest) { + self.interest.store( + match () { + _ if interest.is_never() => CALLSITE_INTEREST_NEVER, + _ if interest.is_always() => CALLSITE_INTEREST_ALWAYS, + _ => CALLSITE_INTEREST_SOMETIMES, + }, + std::sync::atomic::Ordering::Release, + ); + } + + fn metadata(&self) -> &Metadata<'_> { + &self.metadata + } +} +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct Key<'k> { + level: Level, + line: Option, + file: Option<&'k str>, +} + +impl DynamicCallsites {} + +pub struct MnnCallsite { + interest: AtomicUsize, + metadata: Metadata<'static>, +} + +impl MnnCallsite { + pub fn make_static(key: &Key<'static>) -> &'static Self { + unsafe { + use std::alloc::GlobalAlloc as _; + let callsite_layout = std::alloc::Layout::new::(); + let alloc = std::alloc::System.alloc(callsite_layout); + let callsite = alloc as *mut MnnCallsite; + // No allocation for string required as they are static by default + callsite.write(MnnCallsite { + interest: AtomicUsize::new(0), + metadata: Metadata::new( + "", + "mnn_ffi_emit", + key.level.into(), + key.file, + key.line, + None, + FieldSet::new(&["message"], identify_callsite!(&*callsite)), + Kind::EVENT, + ), + }); + &*callsite + } + } + + pub(crate) fn interest(&self) -> Interest { + match 
self.interest.load(std::sync::atomic::Ordering::Acquire) { + CALLSITE_INTEREST_NEVER => Interest::never(), + CALLSITE_INTEREST_SOMETIMES => Interest::sometimes(), + CALLSITE_INTEREST_ALWAYS => Interest::always(), + _ => panic!("attempting to obtain callsite's interest before its been set"), + } + } +} + +#[no_mangle] +extern "C" fn mnn_ffi_emit( + file: *const c_char, + line: libc::size_t, + level: Level, + message: *const c_char, +) { + std::panic::catch_unwind(|| { + let file: &'static str = unsafe { + core::ffi::CStr::from_ptr(file) + .to_str() + .expect("Invalid filename for C file") + }; + + let callsite = DynamicCallsites::get().callsite_for(level, Some(line as u32), Some(file)); + // let interest = callsite.interest + let interest = callsite.interest(); + if interest.is_never() { + return; + } + let meta = callsite.metadata(); + tracing_core::dispatcher::get_default(move |dispatcher| { + if !dispatcher.enabled(meta) { + return; + } + let fields = meta.fields(); + let message = unsafe { + std::ffi::CStr::from_ptr(message) + .to_str() + .expect("Invalid message for C message") + }; + + let message_value = + &tracing_core::field::display(message) as &dyn tracing_core::field::Value; + let message_field = fields + .into_iter() + .next() + .expect("Failed to get message field"); + let values = &[(&message_field, Some(message_value))]; + let valueset = fields.value_set(values); + + let event = tracing_core::Event::new(meta, &valueset); + + dispatcher.event(&event); + }); + }) + .unwrap_or_else(|_e| { + eprintln!("Panic in mnn_ffi_emit aborting"); + // Cannot let the panic escape the ffi boundary + std::process::abort(); + }) +} From 1a31f2caeb22fbbd403057710ac6b07b65c93e1f Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Mon, 2 Dec 2024 16:31:42 +0530 Subject: [PATCH 03/22] feat: Apply the patch at compile time with diffy --- mnn-sys/build.rs | 5 ++--- mnn-sys/patches/mnn-tracing.patch | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git 
a/mnn-sys/build.rs b/mnn-sys/build.rs index 2a5e158..e46e7d5 100644 --- a/mnn-sys/build.rs +++ b/mnn-sys/build.rs @@ -76,11 +76,10 @@ fn main() -> Result<()> { .copy_inside(true), ) .context("Failed to copy vendor")?; - let intptr = vendor.join("include").join("MNN").join("HalideRuntime.h"); + let intptr = vendor.join("include").join("MNN").join("MNNDefine.h"); #[cfg(unix)] std::fs::set_permissions(&intptr, std::fs::Permissions::from_mode(0o644))?; - // try_patch_file("patches/halide_type_t_64.patch", intptr) - // .context("Failed to patch vendor")?; + try_patch_file("patches/mnn-tracing.patch", &intptr).context("Failed to patch vendor")?; use itertools::Itertools; let intptr_contents = std::fs::read_to_string(&intptr)?; diff --git a/mnn-sys/patches/mnn-tracing.patch b/mnn-sys/patches/mnn-tracing.patch index 7939766..55a4b3f 100644 --- a/mnn-sys/patches/mnn-tracing.patch +++ b/mnn-sys/patches/mnn-tracing.patch @@ -1,4 +1,3 @@ -diff --git a/include/MNN/MNNDefine.h b/include/MNN/MNNDefine.h index 8f30cd68..77407812 100644 --- a/include/MNN/MNNDefine.h +++ b/include/MNN/MNNDefine.h From 1561857267bb22564517d4d095ed1f512f4fada6 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Mon, 2 Dec 2024 16:38:44 +0530 Subject: [PATCH 04/22] feat: rename benchmarks --- benches/mnn-bench.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benches/mnn-bench.rs b/benches/mnn-bench.rs index 8616f3a..3c59838 100644 --- a/benches/mnn-bench.rs +++ b/benches/mnn-bench.rs @@ -4,7 +4,7 @@ mod mnn_realesr_bench_with_ones { use divan::*; use mnn::*; #[divan::bench] - pub fn mnn_benchmark_cpu(bencher: Bencher) { + pub fn mnn_realesr_benchmark_cpu(bencher: Bencher) { let mut net = Interpreter::from_file("tests/assets/realesr.mnn").unwrap(); let mut config = ScheduleConfig::new(); config.set_type(ForwardType::CPU); @@ -18,7 +18,7 @@ mod mnn_realesr_bench_with_ones { #[cfg(feature = "opencl")] #[divan::bench] - pub fn mnn_benchmark_opencl(bencher: Bencher) { + pub fn 
mnn_realesr_benchmark_opencl(bencher: Bencher) { let mut net = Interpreter::from_file("tests/assets/realesr.mnn").unwrap(); let mut config = ScheduleConfig::new(); config.set_type(ForwardType::OpenCL); From b56c75078a8cb2bce99e4b0b86b7706da8123d91 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Mon, 2 Dec 2024 16:56:02 +0530 Subject: [PATCH 05/22] feat(tracing): Don't use diffy::patch since that's not working properly --- Cargo.toml | 6 ++---- mnn-sys/build.rs | 40 +++++++++++++++++++++++++++++++++------ mnn-sys/mnn_c/mnndefine.h | 29 ---------------------------- mnn-sys/src/tracing.rs | 3 +-- 4 files changed, 37 insertions(+), 41 deletions(-) delete mode 100644 mnn-sys/mnn_c/mnndefine.h diff --git a/Cargo.toml b/Cargo.toml index 38fa5fa..4f1c672 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ mnn-sys = { version = "0.1", path = "mnn-sys", features = [] } thiserror = "2.0" error-stack.workspace = true oneshot = "0.1" -tracing = { version = "0.1.40" } +tracing = { version = "0.1.40", optional = true } dunce = "1.0.5" serde = { version = "1.0", features = ["derive"], optional = true } @@ -33,8 +33,7 @@ crt_static = ["mnn-sys/crt_static"] # Disable mnn-threadpool to enable this openmp = ["mnn-sys/openmp"] mnn-threadpool = ["mnn-sys/mnn-threadpool"] -# To keep compatibility with older versions newer ones automatically use it -tracing = [] +tracing = ["dep:tracing"] profile = ["tracing"] serde = ["dep:serde"] @@ -47,7 +46,6 @@ bytemuck = "1.17" clap = { version = "4.5", features = ["derive"] } divan = "0.1.14" tracing-subscriber = "0.3.19" -# mnn-sync = { path = "mnn-sync" } [[bench]] name = "mnn-bench" diff --git a/mnn-sys/build.rs b/mnn-sys/build.rs index e46e7d5..38bd483 100644 --- a/mnn-sys/build.rs +++ b/mnn-sys/build.rs @@ -1,7 +1,7 @@ use ::tap::*; use anyhow::*; -#[cfg(unix)] -use std::os::unix::fs::PermissionsExt; +// #[cfg(unix)] +// use std::os::unix::fs::PermissionsExt; use std::{ path::{Path, PathBuf}, sync::LazyLock, @@ -39,6 +39,30 @@ static 
MNN_COMPILE: LazyLock = LazyLock::new(|| { const HALIDE_SEARCH: &str = r#"HALIDE_ATTRIBUTE_ALIGN(1) halide_type_code_t code; // halide_type_code_t"#; +const TRACING_SEARCH: &str = "#define MNN_PRINT(format, ...) printf(format, ##__VA_ARGS__)\n#define MNN_ERROR(format, ...) printf(format, ##__VA_ARGS__)"; +const TRACING_REPLACE: &str = r#" +enum class Level { + Info = 0, + Error = 1, +}; +extern "C" { +void mnn_ffi_emit(const char *file, size_t line, Level level, + const char *message); +} +#define MNN_PRINT(format, ...) \ + { \ + char logtmp[4096]; \ + snprintf(logtmp, 4096, format, ##__VA_ARGS__); \ + mnn_ffi_emit(__FILE__, __LINE__, Level::Info, logtmp); \ + } + +#define MNN_ERROR(format, ...) \ + { \ + char logtmp[4096]; \ + snprintf(logtmp, 4096, format, ##__VA_ARGS__); \ + mnn_ffi_emit(__FILE__, __LINE__, Level::Error, logtmp); \ + } +"#; fn ensure_vendor_exists(vendor: impl AsRef) -> Result<()> { if vendor @@ -77,9 +101,9 @@ fn main() -> Result<()> { ) .context("Failed to copy vendor")?; let intptr = vendor.join("include").join("MNN").join("MNNDefine.h"); - #[cfg(unix)] - std::fs::set_permissions(&intptr, std::fs::Permissions::from_mode(0o644))?; - try_patch_file("patches/mnn-tracing.patch", &intptr).context("Failed to patch vendor")?; + // #[cfg(unix)] + // std::fs::set_permissions(&intptr, std::fs::Permissions::from_mode(0o644))?; + // try_patch_file("patches/mnn-tracing.patch", &intptr).context("Failed to patch vendor")?; use itertools::Itertools; let intptr_contents = std::fs::read_to_string(&intptr)?; @@ -95,7 +119,11 @@ fn main() -> Result<()> { .filter(|(c_idx, _)| !(*c_idx == idx - 1 || (idx + 1..=idx + 3).contains(c_idx))) .map(|(_, c)| c) .collect::>(); - std::fs::write(intptr, patched.join("\n"))?; + + std::fs::write( + intptr, + patched.join("\n").replace(TRACING_SEARCH, TRACING_REPLACE), + )?; } } diff --git a/mnn-sys/mnn_c/mnndefine.h b/mnn-sys/mnn_c/mnndefine.h deleted file mode 100644 index 6b0057d..0000000 --- a/mnn-sys/mnn_c/mnndefine.h +++ 
/dev/null @@ -1,29 +0,0 @@ -#ifndef MNNDEFINE_H -#define MNNDEFINE_H -#include - -enum class Level { - Info = 0, - Error = 1, -}; - -extern "C" { -void mnn_ffi_emit(const char *file, size_t line, Level level, - const char *message); -} - -#define MNN_PRINT(format, ...) \ - { \ - char logtmp[4096]; \ - snprintf(logtmp, 4096, format, ##__VA_ARGS__); \ - mnn_ffi_emit(__FILE__, __LINE__, Level::Info, logtmp); \ - } - -#define MNN_ERROR(format, ...) \ - { \ - char logtmp[4096]; \ - snprintf(logtmp, 4096, format, ##__VA_ARGS__); \ - mnn_ffi_emit(__FILE__, __LINE__, Level::Error, logtmp); \ - } - -#endif diff --git a/mnn-sys/src/tracing.rs b/mnn-sys/src/tracing.rs index 0e77309..7f298db 100644 --- a/mnn-sys/src/tracing.rs +++ b/mnn-sys/src/tracing.rs @@ -162,8 +162,7 @@ extern "C" fn mnn_ffi_emit( .expect("Invalid message for C message") }; - let message_value = - &tracing_core::field::display(message) as &dyn tracing_core::field::Value; + let message_value = &message as &dyn tracing_core::field::Value; let message_field = fields .into_iter() .next() From 0bd1c167c3b9f1b54b50ea9b0a21b9931b4b2107 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Tue, 3 Dec 2024 11:50:08 +0530 Subject: [PATCH 06/22] feat: Don't try to patch MNNDefine with the HalideRuntime patch --- mnn-sys/build.rs | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/mnn-sys/build.rs b/mnn-sys/build.rs index 38bd483..cb57614 100644 --- a/mnn-sys/build.rs +++ b/mnn-sys/build.rs @@ -1,7 +1,7 @@ use ::tap::*; use anyhow::*; -// #[cfg(unix)] -// use std::os::unix::fs::PermissionsExt; +#[cfg(unix)] +use std::os::unix::fs::PermissionsExt; use std::{ path::{Path, PathBuf}, sync::LazyLock, @@ -100,10 +100,9 @@ fn main() -> Result<()> { .copy_inside(true), ) .context("Failed to copy vendor")?; - let intptr = vendor.join("include").join("MNN").join("MNNDefine.h"); - // #[cfg(unix)] - // std::fs::set_permissions(&intptr, std::fs::Permissions::from_mode(0o644))?; - // 
try_patch_file("patches/mnn-tracing.patch", &intptr).context("Failed to patch vendor")?; + let intptr = vendor.join("include").join("MNN").join("HalideRuntime.h"); + #[cfg(unix)] + std::fs::set_permissions(&intptr, std::fs::Permissions::from_mode(0o644))?; use itertools::Itertools; let intptr_contents = std::fs::read_to_string(&intptr)?; @@ -120,11 +119,15 @@ fn main() -> Result<()> { .map(|(_, c)| c) .collect::>(); - std::fs::write( - intptr, - patched.join("\n").replace(TRACING_SEARCH, TRACING_REPLACE), - )?; + std::fs::write(intptr, patched.join("\n"))?; } + + let mnn_define = vendor.join("include").join("MNN").join("MNNDefine.hpp"); + let patched = + std::fs::read_to_string(&mnn_define)?.replace(TRACING_SEARCH, TRACING_REPLACE); + #[cfg(unix)] + std::fs::set_permissions(&mnn_define, std::fs::Permissions::from_mode(0o644))?; + std::fs::write(mnn_define, patched)?; } if *MNN_COMPILE { From ec4bee2fd41d225c5d517af169926ebe46df7f15 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Tue, 3 Dec 2024 13:09:57 +0530 Subject: [PATCH 07/22] fix(mnn-sys): Fix buildscript typo MNNDefine.hpp -> MNNDefine.h --- mnn-sys/build.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mnn-sys/build.rs b/mnn-sys/build.rs index cb57614..0395609 100644 --- a/mnn-sys/build.rs +++ b/mnn-sys/build.rs @@ -122,7 +122,7 @@ fn main() -> Result<()> { std::fs::write(intptr, patched.join("\n"))?; } - let mnn_define = vendor.join("include").join("MNN").join("MNNDefine.hpp"); + let mnn_define = vendor.join("include").join("MNN").join("MNNDefine.h"); let patched = std::fs::read_to_string(&mnn_define)?.replace(TRACING_SEARCH, TRACING_REPLACE); #[cfg(unix)] From 7a15b719d6ce0e59a659d9321d8525dde7b0bc8a Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Tue, 3 Dec 2024 17:54:20 +0530 Subject: [PATCH 08/22] broken(coreml): Gatherv2 support in progress --- .ignore | 1 - examples/inspect.rs | 10 +++++-- flake.lock | 64 ++++++++++++++++++++++----------------------- flake.nix | 28 
+++++++------------- mnn-sys/vendor | 2 +- 5 files changed, 50 insertions(+), 55 deletions(-) delete mode 100644 .ignore diff --git a/.ignore b/.ignore deleted file mode 100644 index d88084b..0000000 --- a/.ignore +++ /dev/null @@ -1 +0,0 @@ -mnn-sys/vendor diff --git a/examples/inspect.rs b/examples/inspect.rs index 44b967d..eb162a9 100644 --- a/examples/inspect.rs +++ b/examples/inspect.rs @@ -18,6 +18,8 @@ pub struct Cli { input_data_type: DataType, #[clap(short, long, default_value = "1")] loops: usize, + #[clap(short, long)] + no_cache: bool, } #[derive(Debug, Clone, clap::ValueEnum)] @@ -46,7 +48,9 @@ pub fn main() -> anyhow::Result<()> { use clap::Parser; let cli = Cli::parse(); let mut interpreter = Interpreter::from_file(&cli.model)?; - interpreter.set_cache_file(cli.model.with_extension("cache"), 128)?; + if !cli.no_cache { + interpreter.set_cache_file(cli.model.with_extension("cache"), 128)?; + } tracing_subscriber::fmt() .event_format( @@ -59,7 +63,9 @@ pub fn main() -> anyhow::Result<()> { let mut config = ScheduleConfig::new(); config.set_type(cli.forward); let mut session = time!(interpreter.create_session(config)?; "create session"); - interpreter.update_cache_file(&mut session)?; + if !cli.no_cache { + interpreter.update_cache_file(&mut session)?; + } let mut current = 0; println!("--------------------------------Info--------------------------------"); diff --git a/flake.lock b/flake.lock index fa6f8b2..4f7e25e 100644 --- a/flake.lock +++ b/flake.lock @@ -3,11 +3,11 @@ "advisory-db": { "flake": false, "locked": { - "lastModified": 1731271136, - "narHash": "sha256-VsrCHM1gP8YqBTQWBQ0TmFNAFv3lBA0PvtWh8/sA9n4=", + "lastModified": 1732819720, + "narHash": "sha256-6H7mKBKw3VErpGcCGEamBYJsopvqqdFmJhl8slfCtOQ=", "owner": "rustsec", "repo": "advisory-db", - "rev": "509528f6775ad69ab114f1e4b37b4359cae5cef4", + "rev": "9dc4a0bb102451e3c71e1b639068aec5a3e1f5f3", "type": "github" }, "original": { @@ -18,11 +18,11 @@ }, "crane": { "locked": { - "lastModified": 
1731098351, - "narHash": "sha256-HQkYvKvaLQqNa10KEFGgWHfMAbWBfFp+4cAgkut+NNE=", + "lastModified": 1733016477, + "narHash": "sha256-Hh0khbqBeCtiNS0SJgqdWrQDem9WlPEc2KF5pAY+st0=", "owner": "ipetkov", "repo": "crane", - "rev": "ef80ead953c1b28316cc3f8613904edc2eb90c28", + "rev": "76d64e779e2fbaf172110038492343a8c4e29b55", "type": "github" }, "original": { @@ -36,11 +36,11 @@ "systems": "systems" }, "locked": { - "lastModified": 1726560853, - "narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=", + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", "owner": "numtide", "repo": "flake-utils", - "rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", "type": "github" }, "original": { @@ -54,11 +54,11 @@ "systems": "systems_2" }, "locked": { - "lastModified": 1710146030, - "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", "owner": "numtide", "repo": "flake-utils", - "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", "type": "github" }, "original": { @@ -70,16 +70,16 @@ "mnn": { "flake": false, "locked": { - "lastModified": 1728910345, - "narHash": "sha256-jsCdiFW8oIlKISo/+qrMDy8/RHblrgaQdSuQhoiOd8M=", + "lastModified": 1733133656, + "narHash": "sha256-5DWM5nA3riAkuuhEFdq5tbOUg/mkTOPBhb/m4w4lYB4=", "owner": "alibaba", "repo": "MNN", - "rev": "a74551b4f34b46ce7027c64e800d49fcab497261", + "rev": "98ba45b4561f68cbf6b5a39debb80eb91c64cc32", "type": "github" }, "original": { "owner": "alibaba", - "ref": "2.9.6", + "ref": "3.0.1", "repo": "MNN", "type": "github" } @@ -93,11 +93,11 @@ ] }, "locked": { - "lastModified": 1729601933, - "narHash": "sha256-iV9Whjepp2CmozEbQ026l/fuxx/jXOJGkhu6VVkz/Ec=", + "lastModified": 1733219994, + "narHash": "sha256-H97GEgSR+2L8qGkSU0KpnLUreoCntjq0wSqGt1DsBqw=", "owner": 
"uttarayan21", "repo": "mnn-nix-overlay", - "rev": "1b2c0b9708bb0b0312cf4cdd33dd76894da25842", + "rev": "a9947ae71989061914b0cdd4e6a4dbc8b82620d3", "type": "github" }, "original": { @@ -109,16 +109,16 @@ "mnn-src": { "flake": false, "locked": { - "lastModified": 1732022377, - "narHash": "sha256-Ui6b83zI1MaiuXeGYNftxAKoGgwVLj6rwSd7H4aVs/8=", + "lastModified": 1733133656, + "narHash": "sha256-5DWM5nA3riAkuuhEFdq5tbOUg/mkTOPBhb/m4w4lYB4=", "owner": "alibaba", "repo": "MNN", - "rev": "707b8a41b25e3d0b7c4a39cd81109d7074ca3c28", + "rev": "98ba45b4561f68cbf6b5a39debb80eb91c64cc32", "type": "github" }, "original": { "owner": "alibaba", - "ref": "3.0.0", + "ref": "3.0.1", "repo": "MNN", "type": "github" } @@ -130,11 +130,11 @@ ] }, "locked": { - "lastModified": 1729742964, - "narHash": "sha256-B4mzTcQ0FZHdpeWcpDYPERtyjJd/NIuaQ9+BV1h+MpA=", + "lastModified": 1731952509, + "narHash": "sha256-p4gB3Rhw8R6Ak4eMl8pqjCPOLCZRqaehZxdZ/mbFClM=", "owner": "nix-community", "repo": "nix-github-actions", - "rev": "e04df33f62cdcf93d73e9a04142464753a16db67", + "rev": "7b5f051df789b6b20d259924d349a9ba3319b226", "type": "github" }, "original": { @@ -145,11 +145,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1731676054, - "narHash": "sha256-OZiZ3m8SCMfh3B6bfGC/Bm4x3qc1m2SVEAlkV6iY7Yg=", + "lastModified": 1733015953, + "narHash": "sha256-t4BBVpwG9B4hLgc6GUBuj3cjU7lP/PJfpTHuSqE+crk=", "owner": "nixos", "repo": "nixpkgs", - "rev": "5e4fbfb6b3de1aa2872b76d49fafc942626e2add", + "rev": "ac35b104800bff9028425fec3b6e8a41de2bbfff", "type": "github" }, "original": { @@ -178,11 +178,11 @@ ] }, "locked": { - "lastModified": 1731464916, - "narHash": "sha256-WZ5rpjr/wCt7yBOUsvDE2i22hYz9g8W921jlwVktRQ4=", + "lastModified": 1733193245, + "narHash": "sha256-nwvKoPi3S6XyliqBRuC+01QFF0k94ZOvnoZtbGi/ObM=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "2c19bad6e881b5a154cafb7f9106879b5b356d1f", + "rev": "3458f7f946ba61d1a1069aedcc17d7b7616f23cd", "type": "github" }, "original": { diff --git 
a/flake.nix b/flake.nix index 648a255..6db55f3 100644 --- a/flake.nix +++ b/flake.nix @@ -22,7 +22,7 @@ flake = false; }; mnn-src = { - url = "github:alibaba/MNN/3.0.0"; + url = "github:alibaba/MNN/3.0.1"; flake = false; }; }; @@ -47,11 +47,9 @@ rust-overlay.overlays.default (final: prev: { mnn = mnn-overlay.packages.${system}.mnn.override { - version = "3.0.0"; src = mnn-src; buildConverter = true; - enableVulkan = false; - # enableMetal = true; + enableMetal = true; enableOpencl = true; }; cargo-audit = pkgs.rustPlatform.buildRustPackage rec { @@ -109,12 +107,8 @@ opencl-headers ]) ++ (lib.optionals pkgs.stdenv.isDarwin [ - darwin.apple_sdk.frameworks.OpenCL - ] - ++ (lib.optionals pkgs.stdenv.isAarch64 [ - darwin.apple_sdk.frameworks.Metal - darwin.apple_sdk.frameworks.CoreML - ])); + apple-sdk_15 + ]); }; cargoArtifacts = craneLib.buildPackage commonArgs; in { @@ -204,7 +198,7 @@ cargoExtraArgs = "--example inspect" + ( - lib.optionalString pkgs.stdenv.isDarwin " --features opencl" # + lib.optionalString pkgs.stdenv.isAarch64 ",metal,coreml" + lib.optionalString pkgs.stdenv.isDarwin " --features opencl,metal,coreml" # + lib.optionalString pkgs.stdenv.isAarch64 ",metal,coreml" ); }); default = mnn; @@ -213,6 +207,7 @@ devShells = { default = pkgs.mkShell (commonArgs // { + MNN_SRC = null; packages = with pkgs; [ cargo-audit @@ -224,18 +219,13 @@ git git-lfs llvm - # mnn + llvmPackages.lldb + mnn nushell rust-bindgen + google-cloud-sdk rustToolchainWithRustAnalyzer ] - ++ (lib.optionals pkgs.stdenv.isDarwin [ - darwin.apple_sdk.frameworks.OpenCL - ] - ++ (lib.optionals pkgs.stdenv.isAarch64 [ - darwin.apple_sdk.frameworks.Metal - darwin.apple_sdk.frameworks.CoreML - ])) ++ (lib.optionals pkgs.stdenv.isLinux [ cargo-llvm-cov ]); diff --git a/mnn-sys/vendor b/mnn-sys/vendor index 707b8a4..a74551b 160000 --- a/mnn-sys/vendor +++ b/mnn-sys/vendor @@ -1 +1 @@ -Subproject commit 707b8a41b25e3d0b7c4a39cd81109d7074ca3c28 +Subproject commit 
a74551b4f34b46ce7027c64e800d49fcab497261 From f0f0844c41e3607d9376808037bf0ada0e0a4e4c Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Wed, 4 Dec 2024 13:30:46 +0530 Subject: [PATCH 09/22] feat: Added test cases to ensure outputs for multiple backends are same --- Cargo.lock | 1 + Cargo.toml | 2 ++ flake.nix | 5 ++++- mnn-sys/build.rs | 1 + mnn-sys/vendor | 2 +- tests/backend.rs | 41 +++++++++++++++++++++++++++++++++++++++++ 6 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 tests/backend.rs diff --git a/Cargo.lock b/Cargo.lock index 8b4501b..60146ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -451,6 +451,7 @@ dependencies = [ "thiserror", "tracing", "tracing-subscriber", + "tracing-test", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 4f1c672..50caace 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,9 @@ anyhow = "1.0" bytemuck = "1.17" clap = { version = "4.5", features = ["derive"] } divan = "0.1.14" +tracing = "0.1.40" tracing-subscriber = "0.3.19" +tracing-test = { version = "0.2.5", features = ["no-env-filter"] } [[bench]] name = "mnn-bench" diff --git a/flake.nix b/flake.nix index 6db55f3..4283706 100644 --- a/flake.nix +++ b/flake.nix @@ -101,7 +101,9 @@ pkg-config ]; buildInputs = with pkgs; - [] + [ + mnn + ] ++ (lib.optionals pkgs.stdenv.isLinux [ ocl-icd opencl-headers @@ -208,6 +210,7 @@ default = pkgs.mkShell (commonArgs // { MNN_SRC = null; + LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver"; packages = with pkgs; [ cargo-audit diff --git a/mnn-sys/build.rs b/mnn-sys/build.rs index 0395609..bff15e7 100644 --- a/mnn-sys/build.rs +++ b/mnn-sys/build.rs @@ -91,6 +91,7 @@ fn main() -> Result<()> { ensure_vendor_exists(&source)?; let vendor = out_dir.join("vendor"); + std::fs::remove_dir_all(&vendor).ok(); if !vendor.exists() { fs_extra::dir::copy( &source, diff --git a/mnn-sys/vendor b/mnn-sys/vendor index a74551b..f80e36e 160000 --- 
a/mnn-sys/vendor +++ b/mnn-sys/vendor @@ -1 +1 @@ -Subproject commit a74551b4f34b46ce7027c64e800d49fcab497261 +Subproject commit f80e36e5146253503f6a078ded1e628eb1746489 diff --git a/tests/backend.rs b/tests/backend.rs new file mode 100644 index 0000000..081d857 --- /dev/null +++ b/tests/backend.rs @@ -0,0 +1,41 @@ +pub mod common; +use common::*; +use mnn::ForwardType; +use tracing_test::traced_test; + +#[test] +#[traced_test] +fn compare_cpu_and_coreml_outputs() { + let mut net = mnn::Interpreter::from_file("tests/assets/realesr.mnn").unwrap(); + let cpu_config = ScheduleConfig::new(); + let mut coreml_config = ScheduleConfig::new(); + let mut bc = BackendConfig::new(); + coreml_config.set_type(ForwardType::CoreML); + let cpu_session = net.create_session(cpu_config).unwrap(); + let coreml_session = net.create_session(coreml_config).unwrap(); + net.inputs(&cpu_session).iter().for_each(|x| { + let mut tensor = x.tensor::().expect("No tensor"); + tensor.fill(1.0f32); + }); + net.inputs(&coreml_session).iter().for_each(|x| { + let mut tensor = x.tensor::().expect("No tensor"); + tensor.fill(1.0f32); + }); + + net.run_session(&cpu_session).unwrap(); + net.run_session(&coreml_session).unwrap(); + + let cpu_outputs = net.outputs(&cpu_session); + let coreml_outputs = net.outputs(&coreml_session); + + cpu_outputs + .iter() + .zip(coreml_outputs.iter()) + .for_each(|(cpu, coreml)| { + let cpu_tensor = cpu.tensor::().expect("No tensor"); + let coreml_tensor = coreml.tensor::().expect("No tensor"); + let cpu = cpu_tensor.create_host_tensor_from_device(true); + let coreml = coreml_tensor.create_host_tensor_from_device(true); + assert_eq!(cpu.host(), coreml.host()); + }); +} From 576b97e701d4ed4f0e8b06b7b4d3108d52007fff Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Mon, 9 Dec 2024 18:15:30 +0530 Subject: [PATCH 10/22] feat: Added testing tool --- Cargo.lock | 251 ++++++++++++++++- Cargo.toml | 2 +- flake.lock | 24 +- flake.nix | 8 +- mnn-sys/build.rs | 26 +- 
mnn-sys/mnn_c/backend_c.cpp | 13 + mnn-sys/mnn_c/backend_c.h | 4 + mnn-sys/mnn_c/schedule_c.cpp | 7 + mnn-sys/mnn_c/schedule_c.h | 2 + mnn-sys/vendor | 2 +- src/backend.rs | 91 +++++- src/schedule.rs | 70 ++++- src/tensor.rs | 25 ++ tools/bencher/Cargo.toml | 30 ++ tools/bencher/src/cli.rs | 133 +++++++++ tools/bencher/src/main.rs | 521 +++++++++++++++++++++++++++++++++++ 16 files changed, 1167 insertions(+), 42 deletions(-) create mode 100644 tools/bencher/Cargo.toml create mode 100644 tools/bencher/src/cli.rs create mode 100644 tools/bencher/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 60146ed..5abe199 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,18 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -11,13 +23,19 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "annotate-snippets" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccaf7e9dfbb6ab22c82e473cd1a8a7bd313c19a5b7e40970f3d89ef5a5c9e81e" dependencies = [ - "unicode-width", + "unicode-width 0.1.14", "yansi-term", ] @@ -82,6 +100,27 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "bencher" +version = "0.1.0" +dependencies = [ + "chumsky", + "clap", + "clap-verbosity-flag", + "console", + "dunce", + "error-stack", + "indicatif", + "mnn", + "same-file", + "serde", + "serde_json", + 
"tempfile", + "thiserror", + "tracing", + "tracing-subscriber", +] + [[package]] name = "bindgen" version = "0.70.1" @@ -147,6 +186,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chumsky" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eebd66744a15ded14960ab4ccdbfb51ad3b81f51f3f04a80adac98c985396c9" +dependencies = [ + "hashbrown", + "stacker", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -160,19 +209,29 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.21" +version = "4.5.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" +checksum = "69371e34337c4c984bbe322360c2547210bf632eb2814bbe78a6e87a2935bd2b" dependencies = [ "clap_builder", "clap_derive", ] +[[package]] +name = "clap-verbosity-flag" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54381ae56ad222eea3f529c692879e9c65e07945ae48d3dc4d1cb18dbec8cf44" +dependencies = [ + "clap", + "tracing-core", +] + [[package]] name = "clap_builder" -version = "4.5.21" +version = "4.5.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" +checksum = "6e24c1b4099818523236a8ca881d2b45db98dadfb4625cf6608c12069fcbbde1" dependencies = [ "anstream", "anstyle", @@ -219,6 +278,19 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf0a07a401f374238ab8e2f11a104d2851bf9ce711ec69804834de8af45c7af" +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", 
+ "libc", + "unicode-width 0.1.14", + "windows-sys 0.52.0", +] + [[package]] name = "diffy" version = "0.4.0" @@ -265,6 +337,12 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "errno" version = "0.3.9" @@ -283,8 +361,15 @@ checksum = "fe413319145d1063f080f27556fd30b1d70b01e2ba10c2a6e40d4be982ffc5d1" dependencies = [ "anyhow", "rustc_version", + "serde", ] +[[package]] +name = "fastrand" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" + [[package]] name = "flume" version = "0.11.1" @@ -320,12 +405,35 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] + [[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "indicatif" +version = "0.17.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width 0.2.0", + "web-time", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -341,6 +449,12 @@ dependencies = [ "either", ] 
+[[package]] +name = "itoa" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" + [[package]] name = "jobserver" version = "0.1.32" @@ -587,6 +701,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "once_cell" version = "1.20.2" @@ -645,6 +765,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "psm" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "200b9ff220857e53e184257720a14553b2f4aa02577d2ed9842d45d4b9654810" +dependencies = [ + "cc", +] + [[package]] name = "quote" version = "1.0.37" @@ -738,6 +867,21 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -770,6 +914,18 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.133" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -800,6 +956,19 @@ dependencies = [ "lock_api", ] +[[package]] +name = "stacker" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"799c883d55abdb5e98af1a7b3f23b9b6de8ecada0ecac058672d7635eb48ca7b" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + [[package]] name = "strsim" version = "0.11.1" @@ -823,6 +992,19 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tempfile" +version = "3.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "terminal_size" version = "0.4.0" @@ -835,18 +1017,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.3" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" +checksum = "2f49a1853cf82743e3b7950f77e0f4d622ca36cf4317cba00c767838bac8d490" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.3" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +checksum = "8381894bb3efe0c4acac3ded651301ceee58a15d47c2e34885ed1908ad667061" dependencies = [ "proc-macro2", "quote", @@ -957,6 +1139,12 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + [[package]] name = "utf8parse" version = "0.2.2" @@ -969,6 +1157,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1030,6 +1224,16 @@ version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "winapi" version = "0.3.9" @@ -1046,6 +1250,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -1142,3 +1355,23 @@ checksum = "fe5c30ade05e61656247b2e334a031dfd0cc466fadef865bdcdea8d537951bf1" dependencies = [ "winapi", ] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 
50caace..f8db97d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = [".", "mnn-bridge", "mnn-sync", "mnn-sys"] +members = [".", "mnn-bridge", "mnn-sync", "mnn-sys", "tools/bencher"] [workspace.package] license = "Apache-2.0" diff --git a/flake.lock b/flake.lock index 4f7e25e..0b10dd1 100644 --- a/flake.lock +++ b/flake.lock @@ -3,11 +3,11 @@ "advisory-db": { "flake": false, "locked": { - "lastModified": 1732819720, - "narHash": "sha256-6H7mKBKw3VErpGcCGEamBYJsopvqqdFmJhl8slfCtOQ=", + "lastModified": 1733371256, + "narHash": "sha256-gWvibGRlB+SMgqTOblVPpkcIAcl0LppLz1dBukEyXoY=", "owner": "rustsec", "repo": "advisory-db", - "rev": "9dc4a0bb102451e3c71e1b639068aec5a3e1f5f3", + "rev": "463107188fc02ccaddefc8f4a65746afa06bb7fa", "type": "github" }, "original": { @@ -18,11 +18,11 @@ }, "crane": { "locked": { - "lastModified": 1733016477, - "narHash": "sha256-Hh0khbqBeCtiNS0SJgqdWrQDem9WlPEc2KF5pAY+st0=", + "lastModified": 1733286231, + "narHash": "sha256-mlIDSv1/jqWnH8JTiOV7GMUNPCXL25+6jmD+7hdxx5o=", "owner": "ipetkov", "repo": "crane", - "rev": "76d64e779e2fbaf172110038492343a8c4e29b55", + "rev": "af1556ecda8bcf305820f68ec2f9d77b41d9cc80", "type": "github" }, "original": { @@ -145,11 +145,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1733015953, - "narHash": "sha256-t4BBVpwG9B4hLgc6GUBuj3cjU7lP/PJfpTHuSqE+crk=", + "lastModified": 1733212471, + "narHash": "sha256-M1+uCoV5igihRfcUKrr1riygbe73/dzNnzPsmaLCmpo=", "owner": "nixos", "repo": "nixpkgs", - "rev": "ac35b104800bff9028425fec3b6e8a41de2bbfff", + "rev": "55d15ad12a74eb7d4646254e13638ad0c4128776", "type": "github" }, "original": { @@ -178,11 +178,11 @@ ] }, "locked": { - "lastModified": 1733193245, - "narHash": "sha256-nwvKoPi3S6XyliqBRuC+01QFF0k94ZOvnoZtbGi/ObM=", + "lastModified": 1733366051, + "narHash": "sha256-Zlas3LFqrW8bVVrZYgkzS4VNkZgtZ/hsbYhO0GtKLys=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "3458f7f946ba61d1a1069aedcc17d7b7616f23cd", + "rev": 
"ba5ed0362eaae83fe8925a2d5cfcf356ff22f70f", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 4283706..d508ae2 100644 --- a/flake.nix +++ b/flake.nix @@ -75,7 +75,7 @@ extensions = ["rust-src" "llvm-tools"]; }; rustToolchainWithRustAnalyzer = pkgs.rust-bin.stable.${version}.default.override ({ - extensions = ["rust-src" "rust-analyzer"]; + extensions = ["rust-docs" "rust-src" "rust-analyzer"]; } // (lib.optionalAttrs pkgs.stdenv.isDarwin { targets = ["aarch64-apple-darwin" "x86_64-apple-darwin"]; @@ -102,14 +102,14 @@ ]; buildInputs = with pkgs; [ - mnn - ] + mnn + ] ++ (lib.optionals pkgs.stdenv.isLinux [ ocl-icd opencl-headers ]) ++ (lib.optionals pkgs.stdenv.isDarwin [ - apple-sdk_15 + apple-sdk_13 ]); }; cargoArtifacts = craneLib.buildPackage commonArgs; diff --git a/mnn-sys/build.rs b/mnn-sys/build.rs index bff15e7..636dc01 100644 --- a/mnn-sys/build.rs +++ b/mnn-sys/build.rs @@ -91,7 +91,7 @@ fn main() -> Result<()> { ensure_vendor_exists(&source)?; let vendor = out_dir.join("vendor"); - std::fs::remove_dir_all(&vendor).ok(); + // std::fs::remove_dir_all(&vendor).ok(); if !vendor.exists() { fs_extra::dir::copy( &source, @@ -374,18 +374,18 @@ pub fn build_cmake(path: impl AsRef, install: impl AsRef) -> Result< Ok(()) } -pub fn try_patch_file(patch: impl AsRef, file: impl AsRef) -> Result<()> { - let patch = dunce::canonicalize(patch)?; - rerun_if_changed(&patch); - let patch = std::fs::read_to_string(&patch)?; - let patch = diffy::Patch::from_str(&patch)?; - let file_path = file.as_ref(); - let file = std::fs::read_to_string(file_path).context("Failed to read input file")?; - let patched_file = - diffy::apply(&file, &patch).context("Failed to apply patches using diffy")?; - std::fs::write(file_path, patched_file)?; - Ok(()) -} +// pub fn try_patch_file(patch: impl AsRef, file: impl AsRef) -> Result<()> { +// let patch = dunce::canonicalize(patch)?; +// rerun_if_changed(&patch); +// let patch = std::fs::read_to_string(&patch)?; +// let 
patch = diffy::Patch::from_str(&patch)?; +// let file_path = file.as_ref(); +// let file = std::fs::read_to_string(file_path).context("Failed to read input file")?; +// let patched_file = +// diffy::apply(&file, &patch).context("Failed to apply patches using diffy")?; +// std::fs::write(file_path, patched_file)?; +// Ok(()) +// } pub fn rerun_if_changed(path: impl AsRef) { println!("cargo:rerun-if-changed={}", path.as_ref().display()); diff --git a/mnn-sys/mnn_c/backend_c.cpp b/mnn-sys/mnn_c/backend_c.cpp index ff802d2..298dab7 100644 --- a/mnn-sys/mnn_c/backend_c.cpp +++ b/mnn-sys/mnn_c/backend_c.cpp @@ -43,3 +43,16 @@ void mnnbc_reset(MNNBackendConfig *config) { MNN::BackendConfig::Precision_Normal; reinterpret_cast(config)->sharedContext = nullptr; } + +MemoryMode mnnbc_get_memory_mode(MNNBackendConfig *config) { + return static_cast( + reinterpret_cast(config)->memory); +} +PowerMode mnnbc_get_power_mode(MNNBackendConfig *config) { + return static_cast( + reinterpret_cast(config)->power); +} +PrecisionMode mnnbc_get_precision_mode(MNNBackendConfig *config) { + return static_cast( + reinterpret_cast(config)->precision); +} diff --git a/mnn-sys/mnn_c/backend_c.h b/mnn-sys/mnn_c/backend_c.h index 9f2ae4e..51212fa 100644 --- a/mnn-sys/mnn_c/backend_c.h +++ b/mnn-sys/mnn_c/backend_c.h @@ -37,6 +37,10 @@ void mnnbc_set_shared_context(MNNBackendConfig *config, void *shared_context); void mnnbc_set_flags(MNNBackendConfig *config, size_t flags); void mnnbc_reset(MNNBackendConfig *config); +MemoryMode mnnbc_get_memory_mode(MNNBackendConfig *config); +PowerMode mnnbc_get_power_mode(MNNBackendConfig *config); +PrecisionMode mnnbc_get_precision_mode(MNNBackendConfig *config); + #ifdef __cplusplus } #endif diff --git a/mnn-sys/mnn_c/schedule_c.cpp b/mnn-sys/mnn_c/schedule_c.cpp index 5fc10b4..491c309 100644 --- a/mnn-sys/mnn_c/schedule_c.cpp +++ b/mnn-sys/mnn_c/schedule_c.cpp @@ -54,3 +54,10 @@ void mnnsc_set_backend_config(MNNScheduleConfig *config, 
mnn_config->backendConfig = reinterpret_cast(backendConfig); } + +MNNForwardType mnnsc_get_type(MNNScheduleConfig *config) { + return reinterpret_cast(config)->type; +} +MNNForwardType mnnsc_get_backup_type(MNNScheduleConfig *config) { + return reinterpret_cast(config)->backupType; +} diff --git a/mnn-sys/mnn_c/schedule_c.h b/mnn-sys/mnn_c/schedule_c.h index 048f8bb..648ed04 100644 --- a/mnn-sys/mnn_c/schedule_c.h +++ b/mnn-sys/mnn_c/schedule_c.h @@ -23,6 +23,8 @@ void mnnsc_set_backup_type(MNNScheduleConfig *config, MNNForwardType backupType); void mnnsc_set_backend_config(MNNScheduleConfig *config, MNNBackendConfig *backendConfig); +MNNForwardType mnnsc_get_type(MNNScheduleConfig *config); +MNNForwardType mnnsc_get_backup_type(MNNScheduleConfig *config); #ifdef __cplusplus } diff --git a/mnn-sys/vendor b/mnn-sys/vendor index f80e36e..b03cd53 160000 --- a/mnn-sys/vendor +++ b/mnn-sys/vendor @@ -1 +1 @@ -Subproject commit f80e36e5146253503f6a078ded1e628eb1746489 +Subproject commit b03cd53191c586cc94a94b76f85b904b654d8d78 diff --git a/src/backend.rs b/src/backend.rs index 85c790f..1f633bb 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -3,13 +3,37 @@ use std::str::FromStr; use mnn_sys::*; -#[derive(Debug)] #[repr(transparent)] pub struct BackendConfig { pub(crate) inner: *mut MNNBackendConfig, __marker: core::marker::PhantomData<()>, } +impl core::fmt::Debug for BackendConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BackendConfig") + .field("memory", &self.get_memory_mode()) + .field("power", &self.get_power_mode()) + .field("precision", &self.get_precision_mode()) + .finish() + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for BackendConfig { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("BackendConfig", 3)?; + state.serialize_field("memory", &self.get_memory_mode())?; + 
state.serialize_field("power", &self.get_power_mode())?; + state.serialize_field("precision", &self.get_precision_mode())?; + state.end() + } +} + impl Clone for BackendConfig { fn clone(&self) -> Self { unsafe { @@ -52,6 +76,23 @@ impl PowerMode { Self::High => mnn_sys::PowerMode::Power_High, } } + + pub fn to_str(self) -> &'static str { + match self { + Self::Low => "low", + Self::Normal => "normal", + Self::High => "high", + } + } + + fn from_mnn_sys(mode: mnn_sys::PowerMode) -> Self { + match mode { + mnn_sys::PowerMode::Power_Low => Self::Low, + mnn_sys::PowerMode::Power_Normal => Self::Normal, + mnn_sys::PowerMode::Power_High => Self::High, + _ => Self::Normal, + } + } } impl FromStr for PowerMode { @@ -114,6 +155,23 @@ impl MemoryMode { Self::High => mnn_sys::MemoryMode::Memory_High, } } + + pub fn to_str(self) -> &'static str { + match self { + Self::Low => "low", + Self::Normal => "normal", + Self::High => "high", + } + } + + fn from_mnn_sys(mode: mnn_sys::MemoryMode) -> Self { + match mode { + mnn_sys::MemoryMode::Memory_Low => Self::Low, + mnn_sys::MemoryMode::Memory_Normal => Self::Normal, + mnn_sys::MemoryMode::Memory_High => Self::High, + _ => Self::Normal, + } + } } #[derive(Debug, Clone, Copy)] @@ -133,6 +191,25 @@ impl PrecisionMode { Self::High => mnn_sys::PrecisionMode::Precision_High, } } + + pub fn to_str(self) -> &'static str { + match self { + Self::LowBf16 => "low_bf16", + Self::Low => "low", + Self::Normal => "normal", + Self::High => "high", + } + } + + fn from_mnn_sys(mode: mnn_sys::PrecisionMode) -> Self { + match mode { + mnn_sys::PrecisionMode::Precision_Low_BF16 => Self::LowBf16, + mnn_sys::PrecisionMode::Precision_Low => Self::Low, + mnn_sys::PrecisionMode::Precision_Normal => Self::Normal, + mnn_sys::PrecisionMode::Precision_High => Self::High, + _ => Self::Normal, + } + } } impl BackendConfig { @@ -154,6 +231,10 @@ impl BackendConfig { } } + pub fn get_memory_mode(&self) -> MemoryMode { + unsafe { 
MemoryMode::from_mnn_sys(mnn_sys::mnnbc_get_memory_mode(self.inner)) } + } + /// Sets the [PowerMode] for the backend pub fn set_power_mode(&mut self, mode: PowerMode) { unsafe { @@ -161,6 +242,10 @@ impl BackendConfig { } } + pub fn get_power_mode(&self) -> PowerMode { + unsafe { PowerMode::from_mnn_sys(mnn_sys::mnnbc_get_power_mode(self.inner)) } + } + /// Sets the [PrecisionMode] for the backend pub fn set_precision_mode(&mut self, mode: PrecisionMode) { unsafe { @@ -168,6 +253,10 @@ impl BackendConfig { } } + pub fn get_precision_mode(&self) -> PrecisionMode { + unsafe { PrecisionMode::from_mnn_sys(mnn_sys::mnnbc_get_precision_mode(self.inner)) } + } + /// Sets the flags for the backend /// What the flag represents is depends on each backend or isn't documented pub fn set_flags(&mut self, flags: usize) { diff --git a/src/schedule.rs b/src/schedule.rs index da7e570..f120bde 100644 --- a/src/schedule.rs +++ b/src/schedule.rs @@ -72,6 +72,25 @@ impl ForwardType { } } + fn from_mnn_sys(mode: MNNForwardType) -> Self { + match mode { + MNNForwardType::MNN_FORWARD_AUTO => ForwardType::Auto, + MNNForwardType::MNN_FORWARD_ALL => ForwardType::All, + MNNForwardType::MNN_FORWARD_CPU => ForwardType::CPU, + #[cfg(feature = "metal")] + MNNForwardType::MNN_FORWARD_METAL => ForwardType::Metal, + #[cfg(feature = "opencl")] + MNNForwardType::MNN_FORWARD_OPENCL => ForwardType::OpenCL, + #[cfg(feature = "opengl")] + MNNForwardType::MNN_FORWARD_OPENGL => ForwardType::OpenGL, + #[cfg(feature = "vulkan")] + MNNForwardType::MNN_FORWARD_VULKAN => ForwardType::Vulkan, + #[cfg(feature = "coreml")] + MNNForwardType::MNN_FORWARD_NN => ForwardType::CoreML, + _ => ForwardType::Auto, + } + } + fn list() -> Vec<&'static str> { vec![ "auto", @@ -89,6 +108,24 @@ impl ForwardType { "coreml", ] } + + pub fn to_str(self) -> &'static str { + match self { + ForwardType::Auto => "auto", + ForwardType::All => "all", + ForwardType::CPU => "cpu", + #[cfg(feature = "metal")] + ForwardType::Metal => 
"metal", + #[cfg(feature = "opencl")] + ForwardType::OpenCL => "opencl", + #[cfg(feature = "opengl")] + ForwardType::OpenGL => "opengl", + #[cfg(feature = "vulkan")] + ForwardType::Vulkan => "vulkan", + #[cfg(feature = "coreml")] + ForwardType::CoreML => "coreml", + } + } } impl core::str::FromStr for ForwardType { @@ -164,13 +201,34 @@ impl core::str::FromStr for ForwardType { /// /// **Warning:** The `Drop` implementation for `ScheduleConfig` ensures that the underlying `MNNScheduleConfig` /// is properly destroyed when the struct goes out of scope. Users should not manually free the `inner` pointer. -#[derive(Debug)] pub struct ScheduleConfig { pub(crate) inner: *mut MNNScheduleConfig, pub(crate) backend_config: Option, pub(crate) __marker: core::marker::PhantomData<()>, } +impl core::fmt::Debug for ScheduleConfig { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("ScheduleConfig") + .field("type", &self.get_type()) + .field("backup_type", &self.get_backup_type()) + .field("backend_config", &self.backend_config) + .finish() + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for ScheduleConfig { + fn serialize(&self, serializer: S) -> Result { + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("ScheduleConfig", 3)?; + state.serialize_field("type", &self.get_type())?; + state.serialize_field("backup_type", &self.get_backup_type())?; + state.serialize_field("backend_config", &self.backend_config)?; + state.end() + } +} + impl Clone for ScheduleConfig { fn clone(&self) -> Self { unsafe { @@ -252,6 +310,11 @@ impl ScheduleConfig { self } + /// Gets the type of backend to be used for computation. + pub fn get_type(&self) -> ForwardType { + unsafe { ForwardType::from_mnn_sys(mnnsc_get_type(self.inner)) } + } + /// Sets the number of threads to be used for computation. 
/// /// # Arguments @@ -288,6 +351,11 @@ impl ScheduleConfig { self } + /// Gets the backup type of backend to be used if the primary backend fails. + pub fn get_backup_type(&self) -> ForwardType { + unsafe { ForwardType::from_mnn_sys(mnnsc_get_backup_type(self.inner)) } + } + /// Sets the backend-specific configuration. /// /// # Arguments diff --git a/src/tensor.rs b/src/tensor.rs index 087e3b8..53a6d36 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -696,6 +696,14 @@ impl<'r> RawTensor<'r> { } } + pub fn size(&self) -> usize { + unsafe { mnn_sys::Tensor_usize(self.inner) } + } + + pub fn element_size(&self) -> usize { + unsafe { mnn_sys::Tensor_elementSize(self.inner) as usize } + } + pub fn dimensions(&self) -> usize { unsafe { mnn_sys::Tensor_dimensions(self.inner) as usize } } @@ -722,6 +730,23 @@ impl<'r> RawTensor<'r> { } } + /// # Safety + /// This is very unsafe do not use this unless you know what you are doing + /// Gives a raw pointer to the tensor's data + /// P.S. I don't know what I'm doing + pub unsafe fn unchecked_host_ptr(&self) -> *mut c_void { + let data = mnn_sys::Tensor_host_mut(self.inner); + debug_assert!(data.is_null()); + data + } + + /// # Safety + /// This is very unsafe do not use this unless you know what you are doing + /// Gives a mutable byte slice to the tensor's data + pub unsafe fn unchecked_host_bytes(&self) -> &mut [u8] { + core::slice::from_raw_parts_mut(self.unchecked_host_ptr().cast(), self.size()) + } + /// # Safety /// This is very unsafe do not use this unless you know what you are doing pub unsafe fn to_concrete(self) -> super::Tensor diff --git a/tools/bencher/Cargo.toml b/tools/bencher/Cargo.toml new file mode 100644 index 0000000..5f6826a --- /dev/null +++ b/tools/bencher/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "bencher" +version = "0.1.0" +edition = "2021" +license.workspace = true + +[target."aarch64-apple-darwin".dependencies] +mnn = { workspace = true, features = ["metal", "opencl", "serde"] } + 
+[target."x86_64-apple-darwin".dependencies] +mnn = { workspace = true, features = ["opencl", "serde"] } + +[target."cfg(windows)".dependencies] +mnn = { workspace = true, features = ["opencl", "serde"] } + +[dependencies] +chumsky = "0.9.3" +clap = { version = "4.5.22", features = ["derive"] } +clap-verbosity-flag = { version = "3.0.1", features = ["tracing"], default-features = false } +console = "0.15.8" +dunce = "1.0.5" +error-stack = { workspace = true, features = ["serde"] } +indicatif = "0.17.9" +same-file = "1.0.6" +serde = { version = "1.0.215", features = ["derive"] } +serde_json = "1.0.133" +tempfile = "3.14.0" +thiserror = "2.0.4" +tracing = "0.1.41" +tracing-subscriber = "0.3.19" diff --git a/tools/bencher/src/cli.rs b/tools/bencher/src/cli.rs new file mode 100644 index 0000000..8dc31fc --- /dev/null +++ b/tools/bencher/src/cli.rs @@ -0,0 +1,133 @@ +use std::path::PathBuf; + +use chumsky::prelude::*; +// fn parse() -> impl Parser> +// fn models() -> impl Parser> { +// let model = super::ModelIO::parser(); +// let comma = char(',').skip_many1(); +// let models = model.sep_by(comma); +// models +// } +pub enum ModelIOArgs { + Path(PathBuf), + Assert(PathBuf), + InputType(super::DataTypes), + OutputType(super::DataTypes), +} + +// pub fn arg<'a, T: Clone + 'a, E: chumsky::Error<&'a T>>( +// s: T, +// ) -> chumsky::primitive::Just<&'a [T], &'a [T], E> { +// just(&[s]) +// } +macro_rules! 
arg { + ($s:expr) => { + just::<&str, _, Simple<&str>>($s) + }; +} + +fn models<'a>() -> impl Parser<&'a str, Vec, Error = Simple<&'a str>> { + let assert = choice((arg!("--assert"), arg!("-a"))) + .then(path()) + .map(|(_, p)| p); + let data_type = choice(( + arg!("f32").to(super::DataTypes::F32), + arg!("u8").to(super::DataTypes::U8), + )); + let input_type = choice((arg!("--input-type"), arg!("-i"))) + .then(data_type) + .map(|(_, t)| t); + let output_type = choice((arg!("--output-type"), arg!("-o"))) + .then(choice(( + arg!("f32").to(super::DataTypes::F32), + arg!("u8").to(super::DataTypes::U8), + ))) + .map(|(_, t)| t); + let args = choice(( + // path.map(|p| ModelIOArgs::Path(p)), + assert.map(|p| ModelIOArgs::Assert(p)), + input_type.map(|t| ModelIOArgs::InputType(t)), + output_type.map(|t| ModelIOArgs::OutputType(t)), + )) + .repeated(); + let mios = path().then(args).map(|(p, margs)| { + let mut mio = super::ModelIO::default(); + mio.path = p; + margs.into_iter().for_each(|arg| match arg { + ModelIOArgs::Path(p) => mio.path = p, + ModelIOArgs::Assert(p) => mio.assert = Some(p), + ModelIOArgs::InputType(t) => mio.input_type = t, + ModelIOArgs::OutputType(t) => mio.output_type = t, + }); + mio + }); + mios.repeated() +} + +#[derive(Debug, Clone)] +pub enum Flags { + Verbose, + Warmup(u8), + Output(PathBuf), + Exec, +} +fn flags<'a>() -> impl Parser<&'a str, Vec, Error = Simple<&'a str>> { + choice(( + choice((arg!("--verbose"), arg!("-v"))).to(Flags::Verbose), + choice((arg!("--warmup"), arg!("-w"))) + .ignore_then(any().from_str().unwrapped()) + .map(Flags::Warmup), + )) + .repeated() +} + +fn path<'i>() -> impl Parser<&'i str, PathBuf, Error = Simple<&'i str>> { + any().map(|c| PathBuf::from(c)) +} + +impl super::Cli { + pub fn try_from_env() -> super::Result { + // let args: Vec<_> = std::env::args() + // // .enumerate() + // // .map(|(i, a)| (a, i..i + 1)) + // .collect(); + // let args_str: Vec<_> = args + // .iter() + // // .enumerate() + // // 
.map(|(i, item)| (item.as_str(), i..i + 1)) + // .map(|i| i.as_str()) + // .collect(); + let args = std::env::args().collect::>(); + let args_str = args.iter().map(|i| i.as_str()).collect::>(); + + let mio = path() + .then(choice((models().to(()), flags().to(())))) + .parse(args_str); + + // let mio = super::ModelIO::parse().parse(args_str.as_slice()); + dbg!(mio.unwrap()); + todo!() + } +} +#[derive(Debug, Clone, ValueEnum, Default)] +pub enum DataTypes { + #[default] + F32, + U8, +} + +#[derive(Debug, Clone, Args, Default)] +pub struct ModelIO { + path: PathBuf, + #[clap(short, long)] + assert: Option, + #[clap(short, long, default_value = "f32")] + input_type: DataTypes, + #[clap(short, long, default_value = "f32")] + output_type: DataTypes, +} +impl AsRef for ModelIO { + fn as_ref(&self) -> &Path { + &self.path + } +} diff --git a/tools/bencher/src/main.rs b/tools/bencher/src/main.rs new file mode 100644 index 0000000..57e9018 --- /dev/null +++ b/tools/bencher/src/main.rs @@ -0,0 +1,521 @@ +use console::Term; +use error_stack::*; +use indicatif::{MultiProgress, ProgressBar}; +use mnn::ScheduleConfig; +use std::{ + collections::BTreeMap, + io::IsTerminal, + path::{Path, PathBuf}, + time::Duration, +}; +use thiserror::Error; +use tracing_subscriber::{layer::SubscriberExt as _, util::SubscriberInitExt as _}; +#[derive(Debug, Clone, Error, Copy)] +#[error("BenchError: Failed to bench")] +pub struct BenchError; +use clap::*; + +pub trait ResultExtCC: ResultExt + Sized { + #[track_caller] + fn cc(self, context: C) -> core::result::Result> { + self.change_context(context) + } +} + +impl ResultExtCC for T where T: ResultExt {} + +#[derive(Debug, Clone, Parser)] +pub struct Generate { + models: Vec, + // Always generate with cpu by default + #[clap(short, long, default_value = "cpu")] + forward: mnn::ForwardType, + #[clap(short, long, default_value = "high")] + power: mnn::PowerMode, + #[clap(short, long, default_value = "high")] + precision: mnn::PrecisionMode, + 
#[clap(short, long, default_value = "high")] + memory: mnn::MemoryMode, +} + +#[derive(Debug, Clone, Parser)] +pub struct Cli { + #[clap(subcommand)] + subcommand: Subcommand, + #[command(flatten)] + verbose: clap_verbosity_flag::Verbosity, +} + +#[derive(Debug, Clone, Subcommand)] +pub enum Subcommand { + Bench(Bench), + Generate(Generate), +} + +#[derive(Debug, Clone, Parser)] +pub struct Bench { + models: Vec, + #[clap(flatten)] + sc_items: ScheduleConfigItems, + #[clap(short, long, default_value = "10")] + warmup: u8, + #[clap(short, long)] + output: Option, + /// Run in exec mode i.e. run the self binary with the given arguments individually. This + /// provides a way to bypass segmentation faults in the library. + #[clap(short, long)] + exec: bool, +} + +#[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct Config { + inputs: BTreeMap, + outputs: BTreeMap, +} + +impl Config { + pub fn find(model: impl AsRef) -> Result { + let model = model.as_ref(); + let config = model.with_extension("json"); + let config = std::fs::read(config).cc(BenchError)?; + let config: Config = serde_json::from_slice(&config).cc(BenchError)?; + Ok(config) + } +} + +#[derive(Debug, Clone, Args)] +pub struct ScheduleConfigItems { + /// Comma separated list of forward types (cpu / opencl / metal / coreml) + #[clap(short, long, value_delimiter = ',', num_args= 1.., default_value = "cpu")] + forward: Vec, + /// Comma separated list of power modes (low / high / normal) + #[clap(short = 'P', long,value_delimiter = ',', num_args= 1.., default_value = "normal")] + power: Vec, + /// Comma separated list of precision modes (low / high / normal) + #[clap(short, long,value_delimiter = ',', num_args= 1.., default_value = "normal")] + precision: Vec, + /// Comma separated list of memory modes (low / high / normal) + #[clap(short, long,value_delimiter = ',', num_args= 1.., default_value = "normal")] + memory: Vec, +} + +pub struct ScheduleConfigItem { + pub forward: 
mnn::ForwardType, + pub power: mnn::PowerMode, + pub precision: mnn::PrecisionMode, + pub memory: mnn::MemoryMode, +} + +impl ScheduleConfigItem { + pub fn new( + forward: mnn::ForwardType, + power: mnn::PowerMode, + precision: mnn::PrecisionMode, + memory: mnn::MemoryMode, + ) -> Self { + Self { + forward, + power, + precision, + memory, + } + } + + pub fn into_schedule_config(self) -> ScheduleConfig { + let mut sc = mnn::ScheduleConfig::new(); + let mut bc = mnn::BackendConfig::new(); + bc.set_power_mode(self.power); + bc.set_precision_mode(self.precision); + bc.set_memory_mode(self.memory); + sc.set_type(self.forward).set_backend_config(bc); + sc + } +} + +impl ScheduleConfigItems { + pub fn is_empty(&self) -> bool { + self.forward.is_empty() + || self.power.is_empty() + || self.precision.is_empty() + || self.memory.is_empty() + } + + pub fn is_single(&self) -> bool { + self.combinations() == 1 + } + + pub fn combinations(&self) -> usize { + self.forward.len() * self.power.len() * self.precision.len() * self.memory.len() + } +} + +impl IntoIterator for ScheduleConfigItems { + type Item = ScheduleConfigItem; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + let outputs: Vec = self + .forward + .iter() + .map(|f| { + self.power.iter().map(|p| { + self.precision.iter().map(|pr| { + self.memory + .iter() + .map(|m| ScheduleConfigItem::new(*f, *p, *pr, *m)) + }) + }) + }) + .flatten() + .flatten() + .flatten() + .collect(); + outputs.into_iter() + } +} +type Result> = core::result::Result; + +#[derive(Debug, serde::Serialize)] +pub struct Metrics { + pub model: PathBuf, + pub metrics: Vec, +} + +#[derive(Debug)] +pub struct Metric { + pub memory: f32, // in MiB + pub flops: f32, // in Mflops + pub initial_load_time: Duration, // in ms + pub cached_load_time: Duration, // in ms + pub inference_time: Duration, // in ms + pub schedule_config: ScheduleConfig, +} + +impl serde::Serialize for Metric { + fn serialize(&self, serializer: S) -> 
Result { + use serde::ser::SerializeStruct as _; + let mut state = serializer.serialize_struct("Metric", 6)?; + state.serialize_field("memory", &format!("{:.0}MiB", self.memory))?; + state.serialize_field("flops", &format!("{:.0}M", self.flops))?; + state.serialize_field( + "initial_load_time", + &format!("{}ms", self.initial_load_time.as_millis()), + )?; + state.serialize_field( + "cached_load_time", + &format!("{}ms", self.cached_load_time.as_millis()), + )?; + state.serialize_field( + "inference_time", + &format!("{}ms", self.inference_time.as_millis()), + )?; + state.serialize_field("schedule_config", &self.schedule_config)?; + state.end() + } +} + +pub fn main() -> Result<()> { + let cli = Cli::parse(); + // let cli = Bench::parse(); + // let indicatif_layer = IndicatifLayer::new(); + tracing_subscriber::registry() + .with(cli.verbose.tracing_level_filter()) + // .with(tracing_subscriber::fmt::layer().with_writer(Term::stderr)) + .init(); + + match cli.subcommand { + Subcommand::Bench(cli) => bench_main(cli)?, + Subcommand::Generate(cli) => generate_main(cli)?, + } + + Ok(()) +} + +pub fn generate_main(_cli: Generate) -> Result<()> { + Ok(()) +} + +pub fn bench_main(cli: Bench) -> Result<()> { + let multi_progress = indicatif::MultiProgress::new(); + let output = if !cli.exec { + let results = bench_all(cli.models.iter(), cli.sc_items, cli.warmup, &multi_progress); + serde_json::to_string_pretty(&results).cc(BenchError)? + } else { + let results = exec_bench_all(cli.models.iter(), cli.sc_items, cli.warmup, &multi_progress)?; + serde_json::to_string_pretty(&results).cc(BenchError)? + }; + use std::io::Write; + if let Some(out_f) = cli.output { + std::fs::File::create(out_f) + .cc(BenchError)? 
+ .write_all(output.as_bytes()) + .cc(BenchError)?; + } else { + Term::stdout().write_all(output.as_bytes()).cc(BenchError)?; + } + Ok(()) +} + +pub fn exec_bench_all<'a>( + models: impl Iterator, + sc_items: ScheduleConfigItems, + warmup: u8, + mp: &MultiProgress, +) -> Result>> { + let self_exe = std::env::current_exe().cc(BenchError)?; + let result: Vec> = models + .map(|m| { + let pb = indicatif::ProgressBar::new(sc_items.combinations() as u64) + .with_prefix(format!("{}", m.file_name().unwrap().to_string_lossy())) + .with_style( + indicatif::ProgressStyle::default_bar() + .template("{prefix} {bar:80} {pos}/{len} {msg}") + .expect("Failed to build progress bar style"), + ); + mp.insert(0, pb.clone()); + sc_items + .clone() + .into_iter() + .map({ + |sc| { + pb.set_message(format!( + "{:?}:power->{:?}:precision->{:?}:memory->{:?}", + sc.forward, sc.power, sc.precision, sc.memory + )); + let out = exec_bench(&self_exe, warmup, sc, m, &mp); + pb.inc(1); + out + } + }) + .collect::>() + }) + .flatten() + .collect(); + Ok(result) +} + +pub fn exec_bench( + exec: &Path, + w: u8, + sc: ScheduleConfigItem, + model: impl AsRef, + mp: &MultiProgress, +) -> Result { + let mut child = std::process::Command::new(exec) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .arg("bench") + .arg(model.as_ref()) + .arg("--memory") + .arg(sc.memory.to_str()) + .arg("--power") + .arg(sc.power.to_str()) + .arg("--precision") + .arg(sc.precision.to_str()) + .arg("--forward") + .arg(sc.forward.to_str()) + .arg("--warmup") + .arg(w.to_string()) + .spawn() + .cc(BenchError)?; + let child_stderr = child.stderr.take().expect("Failed to get stderr"); + let child_stdout = child.stdout.take().expect("Failed to get stdout"); + let progress = p_read(child_stderr); + progress.enable_steady_tick(Duration::from_millis(200)); + mp.insert(0, progress.clone()); + let output = child.wait().cc(BenchError)?; + if !output.success() { + return Err(Report::new(BenchError) + 
.attach_printable(format!("Failed to execute {exec}", exec = exec.display()))); + } + progress.finish_and_clear(); + let metrics = serde_json::from_reader(child_stdout).cc(BenchError)?; + Ok(metrics) +} + +pub fn bench_all( + models: impl Iterator>, + sc_items: ScheduleConfigItems, + warmup: u8, + multi_progress: &MultiProgress, +) -> Vec> { + let result: Vec> = models + .map(|m| -> Result { + // Check create_session_time without cache + let pb = indicatif::ProgressBar::new(sc_items.combinations() as u64) + .with_prefix(format!( + "{}", + m.as_ref().file_name().unwrap().to_string_lossy() + )) + .with_style(if sc_items.is_single() { + indicatif::ProgressStyle::default_bar() + .template("{prefix} {msg}") + .expect("Failed to build progress bar style") + } else { + indicatif::ProgressStyle::default_bar() + .template("{prefix} {bar:80} {pos}/{len} {msg}") + .expect("Failed to build progress bar style") + }); + + multi_progress.add(pb.clone()); + let metrics = sc_items + .clone() + .into_iter() + .map(|sc| { + pb.set_message(format!( + "{:?}:power->{:?}:precision->{:?}:memory->{:?}", + sc.forward, sc.power, sc.precision, sc.memory + )); + let o = bench( + warmup, + sc.into_schedule_config(), + m.as_ref(), + &multi_progress, + ) + .cc(BenchError); + pb.inc(1); + o + }) + .collect::>>() + .cc(BenchError)?; + Ok(Metrics { + model: dunce::canonicalize(m).cc(BenchError)?, + metrics, + }) + }) + .collect(); + result +} + +// #[tracing::instrument(skip(model))] +pub fn bench( + w: u8, + sc: ScheduleConfig, + model: impl AsRef, + mp: &MultiProgress, +) -> Result { + let bar = indicatif::ProgressBar::new_spinner(); + mp.insert(0, bar.clone()); + bar.enable_steady_tick(Duration::from_millis(300)); + let not_terminal = !std::io::stdout().is_terminal(); + + tracing::info!("Benching {:?}", sc); + let mut net = mnn::Interpreter::from_file(&model).cc(BenchError)?; + + bar.set_message("Creating session without cache"); + not_terminal.then(|| eprintln!("Creating session without 
cache")); + let (mut uncached, initial_load_time) = timeit(|| { + tracing::trace!("Creating session without cache"); + net.create_session(sc.clone()) + }) + .cc(BenchError)?; + let temp_file = temp_file_path()?; + net.set_cache_file(&temp_file, 128).cc(BenchError)?; + net.update_cache_file(&mut uncached).cc(BenchError)?; + drop(uncached); + drop(net); + let mut net = mnn::Interpreter::from_file(&model).cc(BenchError)?; + net.set_cache_file(&temp_file, 128).cc(BenchError)?; + bar.set_message("Creating session with cache"); + not_terminal.then(|| eprintln!("Creating session with cache")); + let (session, cached_load_time) = timeit(|| { + tracing::trace!("Creating session with cache {temp_file:?}"); + net.create_session(sc.clone()) + }) + .cc(BenchError)?; + for c in 0..w { + bar.set_message(format!("Warming up {c}")); + not_terminal.then(|| eprintln!("Warming up {c}")); + net.run_session(&session).cc(BenchError)?; + } + let config = Config::find(&model).cc(BenchError).unwrap_or_default(); + for (name, path) in config.inputs.iter() { + let input = std::fs::read(path).cc(BenchError)?; + bar.set_message(format!("Setting input {name}")); + not_terminal.then(|| eprintln!("Setting input {name}")); + unsafe { + net.raw_input(&session, name) + .cc(BenchError)? + .unchecked_host_bytes() + .copy_from_slice(&input); + } + } + let (_, inference_time) = timeit(|| -> Result<()> { + bar.set_message("Running session"); + not_terminal.then(|| eprintln!("Running session")); + net.run_session(&session).cc(BenchError)?; + net.wait(&session); + Ok(()) + }) + .cc(BenchError)?; + + for (name, path) in config.outputs.iter() { + bar.set_message(format!("Checking output {name}")); + not_terminal.then(|| eprintln!("Checking output {name}")); + let output = unsafe { + net.raw_output(&session, name) + .cc(BenchError)? 
+ .unchecked_host_bytes() + .to_vec() + }; + assert_eq!( + output.len(), + std::fs::metadata(path).cc(BenchError)?.len() as usize + ); + assert_eq!(output, std::fs::read(path).cc(BenchError)?); + } + let memory = net.memory(&session).cc(BenchError)?; + let flops = net.flops(&session).cc(BenchError)?; + temp_file.close().cc(BenchError)?; + Ok(Metric { + schedule_config: sc, + memory, + flops, + initial_load_time, + cached_load_time, + inference_time, + }) +} + +pub fn timeit Result, T, E>(f: F) -> Result<(T, Duration), E> { + let start = std::time::Instant::now(); + let result = f()?; + let duration = start.elapsed(); + Ok((result, duration)) +} + +pub fn temp_file_path() -> Result { + Ok(tempfile::NamedTempFile::new() + .cc(BenchError)? + .into_temp_path()) +} + +pub fn p_read(reader: impl std::io::Read + Send + Sync + 'static) -> ProgressBar { + let bar = ProgressBar::new_spinner().with_style( + indicatif::ProgressStyle::default_bar() + .template("{spinner} {msg}") + .expect("Failed to build progress bar style"), + ); + let bar_ = bar.clone(); + + std::thread::spawn(move || { + use std::io::BufRead; + let mut reader = std::io::BufReader::new(reader); + let mut buffer = String::new(); + while reader + .read_line(&mut buffer) + .cc(BenchError) + .expect("Failed to read line") + > 0 + { + buffer.ends_with('\n').then(|| buffer.pop()); + bar.set_message(buffer.clone()); + buffer.clear(); + std::thread::sleep(Duration::from_millis(100)); + if bar.is_finished() { + break; + } + } + }); + bar_ +} From 5125aa24be0d35684232377388371344fde37222 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Mon, 9 Dec 2024 18:33:55 +0530 Subject: [PATCH 11/22] feat: Added completions --- Cargo.lock | 97 ++++----------------------------------- tools/bencher/Cargo.toml | 6 ++- tools/bencher/src/main.rs | 15 ++++++ 3 files changed, 29 insertions(+), 89 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5abe199..242ab40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,18 +2,6 @@ # It is not 
intended for manual editing. version = 4 -[[package]] -name = "ahash" -version = "0.8.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" -dependencies = [ - "cfg-if", - "once_cell", - "version_check", - "zerocopy", -] - [[package]] name = "aho-corasick" version = "1.1.3" @@ -23,12 +11,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "allocator-api2" -version = "0.2.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" - [[package]] name = "annotate-snippets" version = "0.9.2" @@ -104,9 +86,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" name = "bencher" version = "0.1.0" dependencies = [ - "chumsky", "clap", "clap-verbosity-flag", + "clap_complete", "console", "dunce", "error-stack", @@ -186,16 +168,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "chumsky" -version = "0.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eebd66744a15ded14960ab4ccdbfb51ad3b81f51f3f04a80adac98c985396c9" -dependencies = [ - "hashbrown", - "stacker", -] - [[package]] name = "clang-sys" version = "1.8.1" @@ -240,6 +212,15 @@ dependencies = [ "terminal_size", ] +[[package]] +name = "clap_complete" +version = "4.5.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9647a559c112175f17cf724dc72d3645680a883c58481332779192b0d8e7a01" +dependencies = [ + "clap", +] + [[package]] name = "clap_derive" version = "4.5.18" @@ -405,16 +386,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" -[[package]] -name = "hashbrown" -version = "0.14.5" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "ahash", - "allocator-api2", -] - [[package]] name = "heck" version = "0.5.0" @@ -765,15 +736,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "psm" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "200b9ff220857e53e184257720a14553b2f4aa02577d2ed9842d45d4b9654810" -dependencies = [ - "cc", -] - [[package]] name = "quote" version = "1.0.37" @@ -956,19 +918,6 @@ dependencies = [ "lock_api", ] -[[package]] -name = "stacker" -version = "0.1.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799c883d55abdb5e98af1a7b3f23b9b6de8ecada0ecac058672d7635eb48ca7b" -dependencies = [ - "cc", - "cfg-if", - "libc", - "psm", - "windows-sys 0.59.0", -] - [[package]] name = "strsim" version = "0.11.1" @@ -1157,12 +1106,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" -[[package]] -name = "version_check" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1355,23 +1298,3 @@ checksum = "fe5c30ade05e61656247b2e334a031dfd0cc466fadef865bdcdea8d537951bf1" dependencies = [ "winapi", ] - -[[package]] -name = "zerocopy" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" -dependencies = [ - "zerocopy-derive", -] - -[[package]] -name = "zerocopy-derive" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" -dependencies = [ - 
"proc-macro2", - "quote", - "syn", -] diff --git a/tools/bencher/Cargo.toml b/tools/bencher/Cargo.toml index 5f6826a..342db2a 100644 --- a/tools/bencher/Cargo.toml +++ b/tools/bencher/Cargo.toml @@ -14,9 +14,11 @@ mnn = { workspace = true, features = ["opencl", "serde"] } mnn = { workspace = true, features = ["opencl", "serde"] } [dependencies] -chumsky = "0.9.3" clap = { version = "4.5.22", features = ["derive"] } -clap-verbosity-flag = { version = "3.0.1", features = ["tracing"], default-features = false } +clap-verbosity-flag = { version = "3.0.1", features = [ + "tracing", +], default-features = false } +clap_complete = "4.5.38" console = "0.15.8" dunce = "1.0.5" error-stack = { workspace = true, features = ["serde"] } diff --git a/tools/bencher/src/main.rs b/tools/bencher/src/main.rs index 57e9018..a28218d 100644 --- a/tools/bencher/src/main.rs +++ b/tools/bencher/src/main.rs @@ -50,6 +50,12 @@ pub struct Cli { pub enum Subcommand { Bench(Bench), Generate(Generate), + Completions(Completions), +} +#[derive(Debug, Clone, Parser)] +pub struct Completions { + #[clap(short, long)] + pub shell: clap_complete::Shell, } #[derive(Debug, Clone, Parser)] @@ -226,6 +232,15 @@ pub fn main() -> Result<()> { match cli.subcommand { Subcommand::Bench(cli) => bench_main(cli)?, Subcommand::Generate(cli) => generate_main(cli)?, + Subcommand::Completions(cli) => { + use clap_complete::aot::{generate, Generator, Shell}; + + let shell = cli.shell; + let mut cmd = Cli::command(); + let name = cmd.get_name().to_string(); + generate(shell, &mut cmd, name, &mut std::io::stdout()); + // Cli::command().gen_completions_to("mnn", shell, &mut std::io::stdout()); + } } Ok(()) From 2cf1bd10587037a2edd0de21445cc2361678d065 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Tue, 10 Dec 2024 18:05:05 +0530 Subject: [PATCH 12/22] feat: Added benchmark result generator and checker --- Cargo.lock | 49 ++++++++++ flake.lock | 24 ++--- src/backend.rs | 25 +++++ src/interpreter.rs | 1 + src/schedule.rs 
| 25 +++++ src/tensor.rs | 40 ++++++++ tools/bencher/Cargo.toml | 5 +- tools/bencher/src/main.rs | 193 +++++++++++++++++++++++++++++++++----- 8 files changed, 325 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 242ab40..e743591 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -86,6 +86,7 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" name = "bencher" version = "0.1.0" dependencies = [ + "bytemuck", "clap", "clap-verbosity-flag", "clap_complete", @@ -94,6 +95,8 @@ dependencies = [ "error-stack", "indicatif", "mnn", + "ndarray 0.16.1", + "num", "same-file", "serde", "serde_json", @@ -645,6 +648,30 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-complex" version = "0.4.6" @@ -663,6 +690,28 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" diff --git a/flake.lock b/flake.lock 
index 0b10dd1..edf96f4 100644 --- a/flake.lock +++ b/flake.lock @@ -3,11 +3,11 @@ "advisory-db": { "flake": false, "locked": { - "lastModified": 1733371256, - "narHash": "sha256-gWvibGRlB+SMgqTOblVPpkcIAcl0LppLz1dBukEyXoY=", + "lastModified": 1733749954, + "narHash": "sha256-2Ug80Uf/oUujxgh02Iy5vTG0V+Ab9+YUHuRLRY0ayiY=", "owner": "rustsec", "repo": "advisory-db", - "rev": "463107188fc02ccaddefc8f4a65746afa06bb7fa", + "rev": "ec9ce28714bb38d77a2223e7266df705500a7f11", "type": "github" }, "original": { @@ -18,11 +18,11 @@ }, "crane": { "locked": { - "lastModified": 1733286231, - "narHash": "sha256-mlIDSv1/jqWnH8JTiOV7GMUNPCXL25+6jmD+7hdxx5o=", + "lastModified": 1733688869, + "narHash": "sha256-KrhxxFj1CjESDrL5+u/zsVH0K+Ik9tvoac/oFPoxSB8=", "owner": "ipetkov", "repo": "crane", - "rev": "af1556ecda8bcf305820f68ec2f9d77b41d9cc80", + "rev": "604637106e420ad99907cae401e13ab6b452e7d9", "type": "github" }, "original": { @@ -145,11 +145,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1733212471, - "narHash": "sha256-M1+uCoV5igihRfcUKrr1riygbe73/dzNnzPsmaLCmpo=", + "lastModified": 1733581040, + "narHash": "sha256-Qn3nPMSopRQJgmvHzVqPcE3I03zJyl8cSbgnnltfFDY=", "owner": "nixos", "repo": "nixpkgs", - "rev": "55d15ad12a74eb7d4646254e13638ad0c4128776", + "rev": "22c3f2cf41a0e70184334a958e6b124fb0ce3e01", "type": "github" }, "original": { @@ -178,11 +178,11 @@ ] }, "locked": { - "lastModified": 1733366051, - "narHash": "sha256-Zlas3LFqrW8bVVrZYgkzS4VNkZgtZ/hsbYhO0GtKLys=", + "lastModified": 1733798086, + "narHash": "sha256-XHIh0h84xDnjkqampyNI/r2FAkKmwbL719ZsygiJHKE=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "ba5ed0362eaae83fe8925a2d5cfcf356ff22f70f", + "rev": "8a19e07800d64462913f3dbf5c9a20ea7b50e6cd", "type": "github" }, "original": { diff --git a/src/backend.rs b/src/backend.rs index 1f633bb..65e00d9 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -231,6 +231,11 @@ impl BackendConfig { } } + pub fn with_memory_mode(mut self, mode: MemoryMode) -> Self { + 
self.set_memory_mode(mode); + self + } + pub fn get_memory_mode(&self) -> MemoryMode { unsafe { MemoryMode::from_mnn_sys(mnn_sys::mnnbc_get_memory_mode(self.inner)) } } @@ -242,6 +247,11 @@ impl BackendConfig { } } + pub fn with_power_mode(mut self, mode: PowerMode) -> Self { + self.set_power_mode(mode); + self + } + pub fn get_power_mode(&self) -> PowerMode { unsafe { PowerMode::from_mnn_sys(mnn_sys::mnnbc_get_power_mode(self.inner)) } } @@ -253,6 +263,11 @@ impl BackendConfig { } } + pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self { + self.set_precision_mode(mode); + self + } + pub fn get_precision_mode(&self) -> PrecisionMode { unsafe { PrecisionMode::from_mnn_sys(mnn_sys::mnnbc_get_precision_mode(self.inner)) } } @@ -265,6 +280,11 @@ impl BackendConfig { } } + pub fn with_flags(mut self, flags: usize) -> Self { + self.set_flags(flags); + self + } + /// # Safety /// This just binds to the underlying unsafe api and should be used only if you know what you /// are doing @@ -273,4 +293,9 @@ impl BackendConfig { mnn_sys::mnnbc_set_shared_context(self.inner, shared_context); } } + + pub unsafe fn with_shared_context(mut self, shared_context: *mut libc::c_void) -> Self { + self.set_shared_context(shared_context); + self + } } diff --git a/src/interpreter.rs b/src/interpreter.rs index 80c1a25..70a0e52 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -654,6 +654,7 @@ impl<'t, 'tl> TensorInfo<'t, 'tl> { Ok(tensor) } + /// This function return's the raw tensor without any sort of type-checking or shape-checking pub fn raw_tensor(&self) -> RawTensor<'t> { debug_assert!(!self.tensor_info.is_null()); unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) }; diff --git a/src/schedule.rs b/src/schedule.rs index f120bde..8c314a2 100644 --- a/src/schedule.rs +++ b/src/schedule.rs @@ -310,6 +310,11 @@ impl ScheduleConfig { self } + pub fn with_type(mut self, forward_type: ForwardType) -> Self { + self.set_type(forward_type); + self + } + /// 
Gets the type of backend to be used for computation. pub fn get_type(&self) -> ForwardType { unsafe { ForwardType::from_mnn_sys(mnnsc_get_type(self.inner)) } @@ -327,6 +332,11 @@ impl ScheduleConfig { self } + pub fn with_num_threads(mut self, num_threads: i32) -> Self { + self.set_num_threads(num_threads); + self + } + /// Sets the mode of computation. /// /// # Arguments @@ -339,6 +349,11 @@ impl ScheduleConfig { self } + pub fn with_mode(mut self, mode: i32) -> Self { + self.set_mode(mode); + self + } + /// Sets the backup type of backend to be used if the primary backend fails. /// /// # Arguments @@ -351,6 +366,11 @@ impl ScheduleConfig { self } + pub fn with_backup_type(mut self, backup_type: ForwardType) -> Self { + self.set_backup_type(backup_type); + self + } + /// Gets the backup type of backend to be used if the primary backend fails. pub fn get_backup_type(&self) -> ForwardType { unsafe { ForwardType::from_mnn_sys(mnnsc_get_backup_type(self.inner)) } @@ -376,6 +396,11 @@ impl ScheduleConfig { } self } + + pub fn with_backend_config(mut self, backend_config: impl Into>) -> Self { + self.set_backend_config(backend_config); + self + } } #[derive(Debug)] diff --git a/src/tensor.rs b/src/tensor.rs index 53a6d36..54bec06 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -282,6 +282,15 @@ where let htc = halide_type_of::(); unsafe { Tensor_isTypeOf(self.tensor, htc) } } + + pub unsafe fn into_raw(self) -> RawTensor<'static> { + let out = RawTensor { + inner: self.tensor, + __marker: PhantomData, + }; + core::mem::forget(self); + out + } } impl Tensor where @@ -686,10 +695,40 @@ pub struct RawTensor<'r> { // } impl<'r> RawTensor<'r> { + pub fn create_host_tensor_from_device(&self, copy_data: bool) -> RawTensor<'static> { + let tensor = + unsafe { mnn_sys::Tensor_createHostTensorFromDevice(self.inner, copy_data as i32) }; + // crate::ensure!(!tensor.is_null(), ErrorKind::TensorError); + assert!(!tensor.is_null()); + RawTensor { + inner: tensor, + __marker: 
PhantomData, + } + } + + /// Copies the data from a host tensor to the self tensor + pub fn copy_from_host_tensor(&mut self, tensor: &RawTensor) -> Result<()> { + let ret = unsafe { Tensor_copyFromHostTensor(self.inner, tensor.inner) }; + crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret)); + Ok(()) + } + + /// Copies the data from the self tensor to a host tensor + pub fn copy_to_host_tensor(&self, tensor: &mut RawTensor) -> Result<()> { + let ret = unsafe { Tensor_copyToHostTensor(self.inner, tensor.inner) }; + crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret)); + Ok(()) + } + pub fn shape(&self) -> TensorShape { unsafe { mnn_sys::Tensor_shape(self.inner) }.into() } + pub fn get_dimension_type(&self) -> DimensionType { + debug_assert!(!self.inner.is_null()); + From::from(unsafe { mnn_sys::Tensor_getDimensionType(self.inner) }) + } + pub fn destroy(self) { unsafe { mnn_sys::Tensor_destroy(self.inner); @@ -735,6 +774,7 @@ impl<'r> RawTensor<'r> { /// Gives a raw pointer to the tensor's data /// P.S. 
I don't know what I'm doing pub unsafe fn unchecked_host_ptr(&self) -> *mut c_void { + debug_assert!(!self.inner.is_null()); let data = mnn_sys::Tensor_host_mut(self.inner); debug_assert!(data.is_null()); data diff --git a/tools/bencher/Cargo.toml b/tools/bencher/Cargo.toml index 342db2a..c428bcc 100644 --- a/tools/bencher/Cargo.toml +++ b/tools/bencher/Cargo.toml @@ -14,7 +14,8 @@ mnn = { workspace = true, features = ["opencl", "serde"] } mnn = { workspace = true, features = ["opencl", "serde"] } [dependencies] -clap = { version = "4.5.22", features = ["derive"] } +bytemuck = { version = "1.20.0", features = ["extern_crate_alloc"] } +clap = { version = "4.5.22", features = ["derive", "unstable-v5"] } clap-verbosity-flag = { version = "3.0.1", features = [ "tracing", ], default-features = false } @@ -23,6 +24,8 @@ console = "0.15.8" dunce = "1.0.5" error-stack = { workspace = true, features = ["serde"] } indicatif = "0.17.9" +ndarray = "0.16.1" +num = "0.4.3" same-file = "1.0.6" serde = { version = "1.0.215", features = ["derive"] } serde_json = "1.0.133" diff --git a/tools/bencher/src/main.rs b/tools/bencher/src/main.rs index a28218d..8718048 100644 --- a/tools/bencher/src/main.rs +++ b/tools/bencher/src/main.rs @@ -26,16 +26,27 @@ impl ResultExtCC for T where T: ResultExt {} #[derive(Debug, Clone, Parser)] pub struct Generate { + #[arg(required = true)] models: Vec, // Always generate with cpu by default - #[clap(short, long, default_value = "cpu")] + #[arg(short, long, default_value = "cpu")] forward: mnn::ForwardType, - #[clap(short, long, default_value = "high")] + #[arg(short, long, default_value = "high")] power: mnn::PowerMode, - #[clap(short, long, default_value = "high")] + #[arg(short, long, default_value = "high")] precision: mnn::PrecisionMode, - #[clap(short, long, default_value = "high")] + #[arg(short, long, default_value = "high")] memory: mnn::MemoryMode, + // #[arg(flatten)] + // output_types: Vec, + #[arg(short, long)] + output_type: DataType, 
+} + +#[derive(Debug, Clone, Args)] +pub struct TypedOutput { + name: String, + data_type: DataType, } #[derive(Debug, Clone, Parser)] @@ -54,29 +65,94 @@ pub enum Subcommand { } #[derive(Debug, Clone, Parser)] pub struct Completions { - #[clap(short, long)] + #[arg(short, long)] pub shell: clap_complete::Shell, } #[derive(Debug, Clone, Parser)] pub struct Bench { + #[arg(required = true)] models: Vec, - #[clap(flatten)] + #[command(flatten)] sc_items: ScheduleConfigItems, - #[clap(short, long, default_value = "10")] + #[arg(short, long, default_value = "10")] warmup: u8, - #[clap(short, long)] + #[arg(short, long)] output: Option, /// Run in exec mode i.e. run the self binary with the given arguments individually. This /// provides a way to bypass segmentation faults in the library. - #[clap(short, long)] + #[arg(short, long)] exec: bool, } #[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Config { inputs: BTreeMap, - outputs: BTreeMap, + outputs: BTreeMap, +} + +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct ConfigData { + data_type: DataType, + path: PathBuf, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, ValueEnum)] +pub enum DataType { + // Float16, + Float32, + Int32, + Int64, + Int8, + Uint8, +} + +impl DataType { + pub fn mas(&self, lhs: &[u8], rhs: &[u8]) -> f64 { + match self { + // Self::Float16 => Self::mean_absolute_error_bytes::(lhs, rhs), + Self::Float32 => Self::mean_absolute_error_bytes::(lhs, rhs), + Self::Int32 => Self::mean_absolute_error_bytes::(lhs, rhs), + Self::Int64 => Self::mean_absolute_error_bytes::(lhs, rhs), + Self::Int8 => Self::mean_absolute_error_bytes::(lhs, rhs), + Self::Uint8 => Self::mean_absolute_error_bytes::(lhs, rhs), + } + } + + pub fn mean_absolute_error_bytes< + T: core::ops::Sub + + PartialOrd + + Copy + + core::ops::Add + + num::cast::AsPrimitive + + bytemuck::Pod, + >( + lhs: &[u8], + rhs: &[u8], + ) -> 
f64 { + let lhs = bytemuck::cast_slice(lhs); + let rhs = bytemuck::cast_slice(rhs); + Self::mean_absolute_error::(lhs, rhs) + } + + pub fn mean_absolute_error< + T: core::ops::Sub + + PartialOrd + + Copy + + core::ops::Add + + num::cast::AsPrimitive, + >( + lhs: impl AsRef<[T]>, + rhs: impl AsRef<[T]>, + ) -> f64 { + let (sum, count) = lhs + .as_ref() + .iter() + .zip(rhs.as_ref()) + .map(|(&l, &r)| if l > r { l - r } else { r - l }) + .fold((0f64, 0usize), |(acc, count), x| (acc + x.as_(), count + 1)); + sum / count as f64 + } } impl Config { @@ -92,16 +168,16 @@ impl Config { #[derive(Debug, Clone, Args)] pub struct ScheduleConfigItems { /// Comma separated list of forward types (cpu / opencl / metal / coreml) - #[clap(short, long, value_delimiter = ',', num_args= 1.., default_value = "cpu")] + #[arg(short, long, value_delimiter = ',', num_args= 1.., default_value = "cpu")] forward: Vec, /// Comma separated list of power modes (low / high / normal) - #[clap(short = 'P', long,value_delimiter = ',', num_args= 1.., default_value = "normal")] + #[arg(short = 'P', long,value_delimiter = ',', num_args= 1.., default_value = "normal")] power: Vec, /// Comma separated list of precision modes (low / high / normal) - #[clap(short, long,value_delimiter = ',', num_args= 1.., default_value = "normal")] + #[arg(short, long,value_delimiter = ',', num_args= 1.., default_value = "normal")] precision: Vec, /// Comma separated list of memory modes (low / high / normal) - #[clap(short, long,value_delimiter = ',', num_args= 1.., default_value = "normal")] + #[arg(short, long,value_delimiter = ',', num_args= 1.., default_value = "normal")] memory: Vec, } @@ -195,6 +271,7 @@ pub struct Metric { pub cached_load_time: Duration, // in ms pub inference_time: Duration, // in ms pub schedule_config: ScheduleConfig, + pub outputs: BTreeMap, // mean absolute error } impl serde::Serialize for Metric { @@ -216,6 +293,7 @@ impl serde::Serialize for Metric { &format!("{}ms", 
self.inference_time.as_millis()), )?; state.serialize_field("schedule_config", &self.schedule_config)?; + state.serialize_field("outputs", &self.outputs)?; state.end() } } @@ -239,14 +317,77 @@ pub fn main() -> Result<()> { let mut cmd = Cli::command(); let name = cmd.get_name().to_string(); generate(shell, &mut cmd, name, &mut std::io::stdout()); - // Cli::command().gen_completions_to("mnn", shell, &mut std::io::stdout()); } } Ok(()) } -pub fn generate_main(_cli: Generate) -> Result<()> { +pub fn generate_main(cli: Generate) -> Result<()> { + for model in cli.models { + let mut cfg = Config { + inputs: Default::default(), + outputs: Default::default(), + }; + let mut net = mnn::Interpreter::from_file(&model).cc(BenchError)?; + let sc = ScheduleConfig::new() + .with_type(cli.forward) + .with_backend_config( + mnn::BackendConfig::new() + .with_power_mode(cli.power) + .with_precision_mode(cli.precision) + .with_memory_mode(cli.memory), + ); + let session = net.create_session(sc).cc(BenchError)?; + let inputs = net.inputs(&session); + for input in &inputs { + let model_name = model + .file_stem() + .expect("Failed to get model name") + .to_string_lossy(); + let name = format!("{}_input_{}.bin", model_name, input.name()); + let path = model.with_file_name(name); + let tensor = input.raw_tensor(); + unsafe { + tensor.unchecked_host_bytes().fill(1); + std::fs::write(&path, tensor.unchecked_host_bytes()).cc(BenchError)?; + } + cfg.inputs.insert( + input.name().to_string(), + dunce::canonicalize(path).cc(BenchError)?, + ); + } + drop(inputs); + + net.run_session(&session); + + let outputs = net.outputs(&session); + for output in &outputs { + let model_name = model + .file_stem() + .expect("Failed to get model name") + .to_string_lossy(); + let name = format!("{}_output_{}.bin", model_name, output.name()); + let path = model.with_file_name(name); + let tensor = output.raw_tensor(); + unsafe { + let out = tensor.unchecked_host_bytes(); + std::fs::write(&path, 
out).cc(BenchError)?; + } + cfg.outputs.insert( + output.name().to_string(), + ConfigData { + data_type: cli.output_type, + path: dunce::canonicalize(path).cc(BenchError)?, + }, + ); + } + std::fs::write( + model.with_extension("json"), + serde_json::to_string_pretty(&cfg).cc(BenchError)?, + ) + .cc(BenchError)?; + } Ok(()) } @@ -464,20 +605,23 @@ pub fn bench( }) .cc(BenchError)?; + let mut outputs = BTreeMap::new(); for (name, path) in config.outputs.iter() { bar.set_message(format!("Checking output {name}")); not_terminal.then(|| eprintln!("Checking output {name}")); let output = unsafe { - net.raw_output(&session, name) + let out = net + .raw_output(&session, name) .cc(BenchError)? - .unchecked_host_bytes() - .to_vec() + .create_host_tensor_from_device(true); + + out.unchecked_host_bytes().to_vec() }; - assert_eq!( - output.len(), - std::fs::metadata(path).cc(BenchError)?.len() as usize - ); - assert_eq!(output, std::fs::read(path).cc(BenchError)?); + if let Some(cd) = config.outputs.get(name) { + let expected = std::fs::read(&cd.path).cc(BenchError)?; + let mas = cd.data_type.mas(&output, &expected); + outputs.insert(name.clone(), mas); + } } let memory = net.memory(&session).cc(BenchError)?; let flops = net.flops(&session).cc(BenchError)?; @@ -489,6 +633,7 @@ pub fn bench( initial_load_time, cached_load_time, inference_time, + outputs, }) } From 552850397fa9dc7394997929d451360743704f2f Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Fri, 13 Dec 2024 16:42:10 +0530 Subject: [PATCH 13/22] feat: Added nix support for building bencher --- flake.lock | 12 ++++++------ flake.nix | 28 +++++++++++++--------------- tools/bencher/Cargo.toml | 2 +- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/flake.lock b/flake.lock index edf96f4..82135aa 100644 --- a/flake.lock +++ b/flake.lock @@ -145,11 +145,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1733581040, - "narHash": "sha256-Qn3nPMSopRQJgmvHzVqPcE3I03zJyl8cSbgnnltfFDY=", + "lastModified": 
1733940404, + "narHash": "sha256-Pj39hSoUA86ZePPF/UXiYHHM7hMIkios8TYG29kQT4g=", "owner": "nixos", "repo": "nixpkgs", - "rev": "22c3f2cf41a0e70184334a958e6b124fb0ce3e01", + "rev": "5d67ea6b4b63378b9c13be21e2ec9d1afc921713", "type": "github" }, "original": { @@ -178,11 +178,11 @@ ] }, "locked": { - "lastModified": 1733798086, - "narHash": "sha256-XHIh0h84xDnjkqampyNI/r2FAkKmwbL719ZsygiJHKE=", + "lastModified": 1734057252, + "narHash": "sha256-fpSFuiW+O2L0ru2GrXBS0wcAYV9+yDE0Gf800UsWutY=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "8a19e07800d64462913f3dbf5c9a20ea7b50e6cd", + "rev": "1f56a5c88e4dcaa0ab1ba04c4bc5a977cff840b2", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index d508ae2..fc0a2e3 100644 --- a/flake.nix +++ b/flake.nix @@ -52,17 +52,6 @@ enableMetal = true; enableOpencl = true; }; - cargo-audit = pkgs.rustPlatform.buildRustPackage rec { - version = "0.21.0"; - pname = "cargo-audit"; - src = pkgs.fetchCrate { - inherit pname version; - sha256 = "sha256-oMXpJE49If4QKE80ZKhRpMRPh3Bl517a2Ez/1VcaQJQ="; - }; - cargoLock = rec { - lockFile = "${src}/Cargo.lock"; - }; - }; }) ]; }; @@ -113,7 +102,7 @@ ]); }; cargoArtifacts = craneLib.buildPackage commonArgs; - in { + in rec { checks = { mnn-clippy = craneLib.cargoClippy (commonArgs @@ -203,6 +192,12 @@ lib.optionalString pkgs.stdenv.isDarwin " --features opencl,metal,coreml" # + lib.optionalString pkgs.stdenv.isAarch64 ",metal,coreml" ); }); + bencher = craneLib.buildPackage (commonArgs + // { + inherit cargoArtifacts; + pname = "bencher"; + cargoExtraArgs = "--package bencher"; + }); default = mnn; }; @@ -229,9 +224,12 @@ google-cloud-sdk rustToolchainWithRustAnalyzer ] - ++ (lib.optionals pkgs.stdenv.isLinux [ - cargo-llvm-cov - ]); + ++ ( + lib.optionals pkgs.stdenv.isLinux [ + cargo-llvm-cov + ] + ) + ++ [packages.bencher]; }); }; } diff --git a/tools/bencher/Cargo.toml b/tools/bencher/Cargo.toml index c428bcc..c427421 100644 --- a/tools/bencher/Cargo.toml +++ 
b/tools/bencher/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" license.workspace = true [target."aarch64-apple-darwin".dependencies] -mnn = { workspace = true, features = ["metal", "opencl", "serde"] } +mnn = { workspace = true, features = ["opencl", "serde", "metal", "coreml"] } [target."x86_64-apple-darwin".dependencies] mnn = { workspace = true, features = ["opencl", "serde"] } From c78f7b3ed50974b77bd22a0c882dd594cd03f471 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Fri, 13 Dec 2024 18:42:36 +0530 Subject: [PATCH 14/22] fix: Fixed issue with models not creating cpu tensor using gpu backend --- flake.nix | 2 +- tools/bencher/Cargo.toml | 2 +- tools/bencher/src/main.rs | 22 +++++++++++++--------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/flake.nix b/flake.nix index fc0a2e3..6b24720 100644 --- a/flake.nix +++ b/flake.nix @@ -229,7 +229,7 @@ cargo-llvm-cov ] ) - ++ [packages.bencher]; + ++ (with packages; [bencher inspect]); }); }; } diff --git a/tools/bencher/Cargo.toml b/tools/bencher/Cargo.toml index c427421..b38025f 100644 --- a/tools/bencher/Cargo.toml +++ b/tools/bencher/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" license.workspace = true [target."aarch64-apple-darwin".dependencies] -mnn = { workspace = true, features = ["opencl", "serde", "metal", "coreml"] } +mnn = { workspace = true, features = ["opencl", "serde", "metal"] } [target."x86_64-apple-darwin".dependencies] mnn = { workspace = true, features = ["opencl", "serde"] } diff --git a/tools/bencher/src/main.rs b/tools/bencher/src/main.rs index 8718048..e991d00 100644 --- a/tools/bencher/src/main.rs +++ b/tools/bencher/src/main.rs @@ -110,7 +110,6 @@ pub enum DataType { impl DataType { pub fn mas(&self, lhs: &[u8], rhs: &[u8]) -> f64 { match self { - // Self::Float16 => Self::mean_absolute_error_bytes::(lhs, rhs), Self::Float32 => Self::mean_absolute_error_bytes::(lhs, rhs), Self::Int32 => Self::mean_absolute_error_bytes::(lhs, rhs), Self::Int64 => 
Self::mean_absolute_error_bytes::(lhs, rhs), @@ -347,7 +346,7 @@ pub fn generate_main(cli: Generate) -> Result<()> { .to_string_lossy(); let name = format!("{}_input_{}.bin", model_name, input.name()); let path = model.with_file_name(name); - let tensor = input.raw_tensor(); + let tensor = input.raw_tensor().create_host_tensor_from_device(false); unsafe { tensor.unchecked_host_bytes().fill(1); std::fs::write(&path, tensor.unchecked_host_bytes()).cc(BenchError)?; @@ -371,8 +370,13 @@ pub fn generate_main(cli: Generate) -> Result<()> { let path = model.with_file_name(name); let tensor = output.raw_tensor(); unsafe { - let out = tensor.unchecked_host_bytes(); - std::fs::write(&path, out).cc(BenchError)?; + std::fs::write( + &path, + tensor + .create_host_tensor_from_device(true) + .unchecked_host_bytes(), + ) + .cc(BenchError)?; } cfg.outputs.insert( output.name().to_string(), @@ -592,6 +596,7 @@ pub fn bench( unsafe { net.raw_input(&session, name) .cc(BenchError)? + .create_host_tensor_from_device(false) .unchecked_host_bytes() .copy_from_slice(&input); } @@ -610,12 +615,11 @@ pub fn bench( bar.set_message(format!("Checking output {name}")); not_terminal.then(|| eprintln!("Checking output {name}")); let output = unsafe { - let out = net - .raw_output(&session, name) + net.raw_output(&session, name) .cc(BenchError)? 
- .create_host_tensor_from_device(true); - - out.unchecked_host_bytes().to_vec() + .create_host_tensor_from_device(true) + .unchecked_host_bytes() + .to_vec() }; if let Some(cd) = config.outputs.get(name) { let expected = std::fs::read(&cd.path).cc(BenchError)?; From 004e3ea869629916c5773aa56e834717b4e15b01 Mon Sep 17 00:00:00 2001 From: Uttarayan Mondal Date: Mon, 16 Dec 2024 13:00:00 +0530 Subject: [PATCH 15/22] feat: Use the average of 5 inference runs --- tools/bencher/src/main.rs | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tools/bencher/src/main.rs b/tools/bencher/src/main.rs index e991d00..0910fa5 100644 --- a/tools/bencher/src/main.rs +++ b/tools/bencher/src/main.rs @@ -310,7 +310,7 @@ pub fn main() -> Result<()> { Subcommand::Bench(cli) => bench_main(cli)?, Subcommand::Generate(cli) => generate_main(cli)?, Subcommand::Completions(cli) => { - use clap_complete::aot::{generate, Generator, Shell}; + use clap_complete::aot::generate; let shell = cli.shell; let mut cmd = Cli::command(); @@ -358,7 +358,7 @@ pub fn generate_main(cli: Generate) -> Result<()> { } drop(inputs); - net.run_session(&session); + net.run_session(&session).cc(BenchError)?; let outputs = net.outputs(&session); for output in &outputs { @@ -588,6 +588,15 @@ pub fn bench( not_terminal.then(|| eprintln!("Warming up {c}")); net.run_session(&session).cc(BenchError)?; } + let (_, inference_time) = timeit(|| -> Result<()> { + for c in 0..5 { + bar.set_message(format!("Running inference {c}")); + not_terminal.then(|| eprintln!("Running inference {c}")); + net.run_session(&session).cc(BenchError)?; + } + Ok(()) + })?; + let inference_time = inference_time / 5; let config = Config::find(&model).cc(BenchError).unwrap_or_default(); for (name, path) in config.inputs.iter() { let input = std::fs::read(path).cc(BenchError)?; @@ -601,7 +610,7 @@ pub fn bench( .copy_from_slice(&input); } } - let (_, inference_time) = timeit(|| -> Result<()> { + let (_, _) = timeit(|| 
-> Result<()> { bar.set_message("Running session"); not_terminal.then(|| eprintln!("Running session")); net.run_session(&session).cc(BenchError)?; @@ -611,7 +620,7 @@ pub fn bench( .cc(BenchError)?; let mut outputs = BTreeMap::new(); - for (name, path) in config.outputs.iter() { + for (name, _path) in config.outputs.iter() { bar.set_message(format!("Checking output {name}")); not_terminal.then(|| eprintln!("Checking output {name}")); let output = unsafe { From 78d262aa80c9794e886a1849389ee077bc993b56 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Tue, 17 Dec 2024 17:56:11 +0530 Subject: [PATCH 16/22] fix: Fix the issue with debugging invalid filling of tensors --- tools/bencher/src/main.rs | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tools/bencher/src/main.rs b/tools/bencher/src/main.rs index 0910fa5..804f910 100644 --- a/tools/bencher/src/main.rs +++ b/tools/bencher/src/main.rs @@ -124,13 +124,16 @@ impl DataType { + Copy + core::ops::Add + num::cast::AsPrimitive - + bytemuck::Pod, + + bytemuck::Pod + + core::fmt::Debug, >( lhs: &[u8], rhs: &[u8], ) -> f64 { let lhs = bytemuck::cast_slice(lhs); let rhs = bytemuck::cast_slice(rhs); + + assert_eq!(lhs.len(), rhs.len(), "lhs and rhs have different lengths"); Self::mean_absolute_error::(lhs, rhs) } @@ -346,10 +349,12 @@ pub fn generate_main(cli: Generate) -> Result<()> { .to_string_lossy(); let name = format!("{}_input_{}.bin", model_name, input.name()); let path = model.with_file_name(name); - let tensor = input.raw_tensor().create_host_tensor_from_device(false); + let mut tensor = input.raw_tensor().create_host_tensor_from_device(false); unsafe { - tensor.unchecked_host_bytes().fill(1); - std::fs::write(&path, tensor.unchecked_host_bytes()).cc(BenchError)?; + let host = tensor.create_host_tensor_from_device(false); + host.unchecked_host_bytes().fill(1); + tensor.copy_from_host_tensor(&host); + std::fs::write(&path, host.unchecked_host_bytes()).cc(BenchError)?; } 
cfg.inputs.insert( input.name().to_string(), @@ -603,11 +608,11 @@ pub fn bench( bar.set_message(format!("Setting input {name}")); not_terminal.then(|| eprintln!("Setting input {name}")); unsafe { - net.raw_input(&session, name) - .cc(BenchError)? - .create_host_tensor_from_device(false) - .unchecked_host_bytes() - .copy_from_slice(&input); + let mut tensor = net.raw_input(&session, name).cc(BenchError)?; + let host = tensor.create_host_tensor_from_device(false); + host.unchecked_host_bytes().copy_from_slice(&input); + tensor.copy_from_host_tensor(&host); + drop(host); } } let (_, _) = timeit(|| -> Result<()> { @@ -632,9 +637,15 @@ pub fn bench( }; if let Some(cd) = config.outputs.get(name) { let expected = std::fs::read(&cd.path).cc(BenchError)?; + assert_eq!( + output.len(), + expected.len(), + "Failed to compare sizes of output and expected" + ); let mas = cd.data_type.mas(&output, &expected); outputs.insert(name.clone(), mas); } + drop(output); } let memory = net.memory(&session).cc(BenchError)?; let flops = net.flops(&session).cc(BenchError)?; From 2b1ea417be9ad5c3f814c6f9c52194cbc8652839 Mon Sep 17 00:00:00 2001 From: K-prog Date: Tue, 17 Dec 2024 19:04:23 +0530 Subject: [PATCH 17/22] fix: dont duplicate host tensor --- tools/bencher/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/bencher/src/main.rs b/tools/bencher/src/main.rs index 804f910..67e7683 100644 --- a/tools/bencher/src/main.rs +++ b/tools/bencher/src/main.rs @@ -349,7 +349,7 @@ pub fn generate_main(cli: Generate) -> Result<()> { .to_string_lossy(); let name = format!("{}_input_{}.bin", model_name, input.name()); let path = model.with_file_name(name); - let mut tensor = input.raw_tensor().create_host_tensor_from_device(false); + let mut tensor = input.raw_tensor(); unsafe { let host = tensor.create_host_tensor_from_device(false); host.unchecked_host_bytes().fill(1); From 46cba3392588eb05df76d3613d5572526bdadd79 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: 
Fri, 20 Dec 2024 15:01:40 +0530 Subject: [PATCH 18/22] feat: Document everything --- flake.nix | 6 +- mnn-sys/mnn_c/interpreter_c.cpp | 2 +- mnn-sys/mnn_c/interpreter_c.h | 2 +- mnn-sys/src/lib.rs | 1 + src/backend.rs | 34 ++++- src/error.rs | 32 ++++- src/interpreter.rs | 228 ++++++----------------------- src/lib.rs | 12 +- src/schedule.rs | 20 ++- src/session.rs | 12 +- src/tensor.rs | 247 +++++++------------------------- src/tensor/list.rs | 148 +++++++++++++++++++ src/tensor/raw.rs | 140 ++++++++++++++++++ tests/backend.rs | 1 + tools/bencher/src/main.rs | 8 +- 15 files changed, 484 insertions(+), 409 deletions(-) create mode 100644 src/tensor/list.rs create mode 100644 src/tensor/raw.rs diff --git a/flake.nix b/flake.nix index 6b24720..817255d 100644 --- a/flake.nix +++ b/flake.nix @@ -218,7 +218,7 @@ git-lfs llvm llvmPackages.lldb - mnn + # mnn nushell rust-bindgen google-cloud-sdk @@ -228,8 +228,8 @@ lib.optionals pkgs.stdenv.isLinux [ cargo-llvm-cov ] - ) - ++ (with packages; [bencher inspect]); + ); + # ++ (with packages; [bencher inspect]); }); }; } diff --git a/mnn-sys/mnn_c/interpreter_c.cpp b/mnn-sys/mnn_c/interpreter_c.cpp index 2fe130f..0e8da2e 100644 --- a/mnn-sys/mnn_c/interpreter_c.cpp +++ b/mnn-sys/mnn_c/interpreter_c.cpp @@ -353,7 +353,7 @@ const char *OperatorInfo_name(const void *op) { const char *OperatorInfo_type(const void *op) { return reinterpret_cast(op)->type().c_str(); } -const float OperatorInfo_flops(const void *op) { +float OperatorInfo_flops(const void *op) { return reinterpret_cast(op)->flops(); } } // extern "C" diff --git a/mnn-sys/mnn_c/interpreter_c.h b/mnn-sys/mnn_c/interpreter_c.h index 40c053b..bcd82af 100644 --- a/mnn-sys/mnn_c/interpreter_c.h +++ b/mnn-sys/mnn_c/interpreter_c.h @@ -196,7 +196,7 @@ const char *Interpreter_uuid(const Interpreter *interpreter); const char *OperatorInfo_name(const void *op); const char *OperatorInfo_type(const void *op); -const float OperatorInfo_flops(const void *op); +float 
OperatorInfo_flops(const void *op); #ifdef __cplusplus } diff --git a/mnn-sys/src/lib.rs b/mnn-sys/src/lib.rs index c2bf3b6..1e87349 100644 --- a/mnn-sys/src/lib.rs +++ b/mnn-sys/src/lib.rs @@ -11,6 +11,7 @@ mod sys { #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] + #![allow(clippy::manual_c_str_literals)] include!(concat!(env!("OUT_DIR"), "/mnn_c.rs")); } pub use sys::*; diff --git a/src/backend.rs b/src/backend.rs index 65e00d9..bbb3868 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,8 +1,14 @@ +//! The backend module contains the data types for the backend configuration + use crate::prelude::*; use std::str::FromStr; use mnn_sys::*; +/// BackendConfig is a struct that holds the configuration for the backend +/// memory: [MemoryMode] +/// power: [PowerMode] +/// precision: [PrecisionMode] #[repr(transparent)] pub struct BackendConfig { pub(crate) inner: *mut MNNBackendConfig, @@ -60,11 +66,15 @@ impl Default for BackendConfig { } } +/// PowerModes depend on if the specific backend has support for it #[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum PowerMode { + /// Low power mode Low, + /// Normal power mode Normal, + /// High power mode High, } @@ -77,6 +87,7 @@ impl PowerMode { } } + /// Returns a string representation of the power mode pub fn to_str(self) -> &'static str { match self { Self::Low => "low", @@ -139,11 +150,15 @@ impl FromStr for PrecisionMode { } } +/// MemoryModes depend on if the specific backend has support for it #[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum MemoryMode { + /// Low memory mode Low, + /// Normal memory mode Normal, + /// High memory mode High, } @@ -156,6 +171,7 @@ impl MemoryMode { } } + /// Returns a string representation of the memory mode pub fn to_str(self) -> &'static str { match self { Self::Low => "low", @@ -174,16 +190,21 @@ impl 
MemoryMode { } } +/// PrecisionModes depend on if the specific backend has support for it #[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum PrecisionMode { + /// Normal precision mode Normal = 0, + /// High precision mode High, + /// Low precision mode Low, + /// Low precision mode with BF16 LowBf16, } impl PrecisionMode { - fn to_mnn_sys(self) -> mnn_sys::PrecisionMode { + pub(crate) fn to_mnn_sys(self) -> mnn_sys::PrecisionMode { match self { Self::LowBf16 => mnn_sys::PrecisionMode::Precision_Low_BF16, Self::Low => mnn_sys::PrecisionMode::Precision_Low, @@ -192,6 +213,7 @@ impl PrecisionMode { } } + /// Returns a string representation of the precision mode pub fn to_str(self) -> &'static str { match self { Self::LowBf16 => "low_bf16", @@ -231,11 +253,13 @@ impl BackendConfig { } } + /// Sets the [MemoryMode] for the backend pub fn with_memory_mode(mut self, mode: MemoryMode) -> Self { self.set_memory_mode(mode); self } + /// Gets the [MemoryMode] for the backend pub fn get_memory_mode(&self) -> MemoryMode { unsafe { MemoryMode::from_mnn_sys(mnn_sys::mnnbc_get_memory_mode(self.inner)) } } @@ -247,11 +271,13 @@ impl BackendConfig { } } + /// Sets the [PowerMode] for the backend pub fn with_power_mode(mut self, mode: PowerMode) -> Self { self.set_power_mode(mode); self } + /// Gets the [PowerMode] for the backend pub fn get_power_mode(&self) -> PowerMode { unsafe { PowerMode::from_mnn_sys(mnn_sys::mnnbc_get_power_mode(self.inner)) } } @@ -263,11 +289,13 @@ impl BackendConfig { } } + /// Sets the [PrecisionMode] for the backend pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self { self.set_precision_mode(mode); self } + /// Gets the [PrecisionMode] for the backend pub fn get_precision_mode(&self) -> PrecisionMode { unsafe { PrecisionMode::from_mnn_sys(mnn_sys::mnnbc_get_precision_mode(self.inner)) } } @@ -280,6 +308,7 @@ impl BackendConfig { } } + /// Sets the flags for the backend pub fn 
with_flags(mut self, flags: usize) -> Self { self.set_flags(flags); self @@ -294,6 +323,9 @@ impl BackendConfig { } } + /// # Safety + /// This just binds to the underlying unsafe api and should be used only if you know what you + /// are doing pub unsafe fn with_shared_context(mut self, shared_context: *mut libc::c_void) -> Self { self.set_shared_context(shared_context); self diff --git a/src/error.rs b/src/error.rs index c938c8a..d17922e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,7 +1,9 @@ use mnn_sys::ErrorCode; +#[doc(hidden)] pub type Result = core::result::Result; +/// Error type container for MNN pub struct MNNError { kind: error_stack::Report, } @@ -21,44 +23,66 @@ impl core::fmt::Debug for MNNError { impl std::error::Error for MNNError {} // pub type MNNError = error_stack::Report; +/// Error types for MNN #[derive(thiserror::Error, Debug)] pub enum ErrorKind { + /// Internal error (from MNN library) #[error("Internal error: {0:?}")] InternalError(ErrorCode), + /// Mismatching Size for input #[error("Invalid input: expected {expected}, got {got}")] - SizeMismatch { expected: usize, got: usize }, + SizeMismatch { + /// Expected size + expected: usize, + /// Provided size + got: usize, + }, + /// Failed to copy tensor #[error("Failed to copy tensor")] TensorCopyFailed(i32), + /// I/O Error #[error("IO Error")] IOError, + /// Interpreter Error #[error("Interpreter Error")] InterpreterError, + /// ASCII Error (path, name, etc had invalid characters) #[error("Ascii Error")] AsciiError, + /// HalideType mismatch (e.g. 
trying to convert from a float tensor to an int tensor) #[error("HalideType mismatch: got {got}")] - HalideTypeMismatch { got: &'static str }, + HalideTypeMismatch { + /// HalideType that was received + got: &'static str, + }, + /// Failed to parse the Argument #[error("Parse Error")] ParseError, + /// Error with mnn-sync crate #[error("Sync Error")] SyncError, + /// Error with some tensor #[error("Tensor Error")] TensorError, + /// Tried to run a dynamic tensor without resizing it first #[error("Dynamic Tensor Error: Tensor needs to be resized before using")] DynamicTensorError, } impl MNNError { #[track_caller] + #[doc(hidden)] pub fn new(kind: ErrorKind) -> Self { let kind = error_stack::Report::new(kind); Self { kind } } #[track_caller] - pub fn from_error_code(code: ErrorCode) -> Self { + pub(crate) fn from_error_code(code: ErrorCode) -> Self { Self::new(ErrorKind::InternalError(code)) } + /// Return the inner [error_stack::Report] containing the error #[inline(always)] pub fn into_inner(self) -> error_stack::Report { self.kind @@ -123,7 +147,7 @@ impl From> for MNNError { } impl MNNError { - pub fn attach_printable( + pub(crate) fn attach_printable( self, printable: impl core::fmt::Display + core::fmt::Debug + Send + Sync + 'static, ) -> Self { diff --git a/src/interpreter.rs b/src/interpreter.rs index 70a0e52..9d2d0c0 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -1,3 +1,5 @@ +//! The interpreter module provides the `Interpreter` struct which is used to load and run models. 
+use crate::tensor::list::TensorList; use std::{ffi::CStr, path::Path, sync::Arc}; use crate::{ @@ -5,10 +7,10 @@ use crate::{ }; use mnn_sys::HalideType; -pub type TensorCallbackT = Box bool>; +pub(crate) type TensorCallbackT = Box bool>; #[repr(transparent)] -pub struct TensorCallback { +pub(crate) struct TensorCallback { inner: Arc, } @@ -21,7 +23,7 @@ impl Default for TensorCallback { } impl TensorCallback { - pub fn from_ptr(f: *mut libc::c_void) -> Self { + pub(crate) fn from_ptr(f: *mut libc::c_void) -> Self { debug_assert!(!f.is_null()); unsafe { Self { @@ -30,11 +32,12 @@ impl TensorCallback { } } - pub fn into_ptr(self) -> *mut libc::c_void { + pub(crate) fn into_ptr(self) -> *mut libc::c_void { Arc::into_raw(self.inner) as *mut libc::c_void } - pub fn identity() -> impl Fn(&[RawTensor], OperatorInfo) -> bool { + #[cfg(test)] + pub(crate) fn identity() -> impl Fn(&[RawTensor], OperatorInfo) -> bool { |_, _| true } } @@ -72,6 +75,12 @@ impl core::ops::Deref for TensorCallback { } } +/// The session mode to be used +/// The items are mostly untested and are only documented 1:1 to the C++ codebase +/// The only two items tested are +/// - `Debug` +/// - `Release` +/// Which work fine #[derive(Debug, Copy, Clone)] #[cfg_attr(windows, repr(i32))] #[cfg_attr(unix, repr(u32))] @@ -202,6 +211,7 @@ impl Interpreter { unsafe { mnn_sys::Interpreter_resizeSessionWithFlag(self.inner, session.inner, 1i32) } } + /// Resize the tensor using the given shape pub fn resize_tensor(&self, tensor: &mut Tensor, dims: impl AsTensorShape) { let dims = dims.as_tensor_shape(); let dims_len = dims.size; @@ -215,22 +225,27 @@ impl Interpreter { } } + /// Resize tensor by + /// - N -> batch + /// - C -> channel + /// - H -> height + /// - W -> width pub fn resize_tensor_by_nchw( &self, tensor: &mut Tensor, - batch: i32, - channel: i32, - height: i32, - width: i32, + batch: u16, + channel: u16, + height: u16, + width: u16, ) { unsafe { mnn_sys::Interpreter_resizeTensorByNCHW( 
self.inner, tensor.tensor, - batch, - channel, - height, - width, + batch.into(), + channel.into(), + height.into(), + width.into(), ) } } @@ -302,7 +317,7 @@ impl Interpreter { /// `session`: the session to get input tensor /// /// return: List of input tensors - pub fn inputs(&self, session: &crate::Session) -> TensorList { + pub fn inputs<'i>(&self, session: &'i crate::Session) -> TensorList<'i> { let inputs = unsafe { mnn_sys::Interpreter_getSessionInputAll(self.inner, session.inner) }; TensorList::from_ptr(inputs) } @@ -338,6 +353,7 @@ impl Interpreter { Ok(tensor) } + /// Get the raw input tensor of a session by name pub fn raw_input<'s>( &self, session: &'s crate::Session, @@ -421,6 +437,7 @@ impl Interpreter { Ok(tensor) } + /// Get the raw output tensor of a session by name pub fn raw_output<'s>( &self, session: &'s crate::Session, @@ -483,7 +500,7 @@ impl Interpreter { } /// Get all output tensors of a session - pub fn outputs(&self, session: &crate::session::Session) -> TensorList { + pub fn outputs<'o>(&self, session: &'o crate::session::Session) -> TensorList<'o> { let outputs = unsafe { mnn_sys::Interpreter_getSessionOutputAll(self.inner, session.inner) }; TensorList::from_ptr(outputs) @@ -566,6 +583,7 @@ impl Interpreter { Ok(flop) } + /// Get the resize status pub fn resize_status(&self, session: &crate::Session) -> Result { let mut resize_status = 0i32; let ptr = &mut resize_status as *mut i32; @@ -591,186 +609,18 @@ impl Interpreter { } } +/// The status of the resize operation #[derive(Debug, Copy, Clone, PartialEq, Eq)] #[repr(C)] pub enum ResizeStatus { + /// No resize needed None = 0, + /// Need to malloc memory NeedMalloc = 1, + /// Need to resize memory NeedResize = 2, } -#[repr(transparent)] -pub struct TensorInfo<'t, 'tl> { - pub(crate) tensor_info: *mut mnn_sys::TensorInfo, - pub(crate) __marker: PhantomData<&'tl TensorList<'t>>, -} - -impl core::fmt::Debug for TensorInfo<'_, '_> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> 
core::fmt::Result { - let tensor = self.raw_tensor(); - let shape = tensor.shape(); - f.debug_struct("TensorInfo") - .field("name", &self.name()) - .field("tensor", &shape) - .finish() - } -} - -impl<'t, 'tl> TensorInfo<'t, 'tl> { - pub fn name(&self) -> &'tl str { - debug_assert!(!self.tensor_info.is_null()); - unsafe { (*self.tensor_info).name.to_cstr() } - .to_str() - .expect("Tensor name is not utf-8") - } - - pub fn tensor(&self) -> Result>>> { - debug_assert!(!self.tensor_info.is_null()); - unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) }; - let tensor = unsafe { Tensor::from_ptr((*self.tensor_info).tensor.cast()) }; - let shape = tensor.shape(); - ensure!(!shape.as_ref().contains(&-1), ErrorKind::DynamicTensorError); - ensure!( - tensor.is_type_of::(), - ErrorKind::HalideTypeMismatch { - got: std::any::type_name::(), - } - ); - Ok(tensor) - } - - /// # Safety - /// The shape is not checked so it's marked unsafe since futher calls to interpreter might be **unsafe** with this - pub unsafe fn tensor_unresized(&self) -> Result>>> { - debug_assert!(!self.tensor_info.is_null()); - unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) }; - let tensor = unsafe { Tensor::from_ptr((*self.tensor_info).tensor.cast()) }; - ensure!( - tensor.is_type_of::(), - ErrorKind::HalideTypeMismatch { - got: std::any::type_name::(), - } - ); - Ok(tensor) - } - - /// This function return's the raw tensor without any sort of type-checking or shape-checking - pub fn raw_tensor(&self) -> RawTensor<'t> { - debug_assert!(!self.tensor_info.is_null()); - unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) }; - RawTensor::from_ptr(unsafe { (*self.tensor_info).tensor.cast() }) - } -} - -#[repr(transparent)] -pub struct TensorList<'t> { - pub(crate) inner: *const mnn_sys::TensorInfoArray, - pub(crate) __marker: PhantomData<&'t Interpreter>, -} - -impl core::fmt::Debug for TensorList<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> core::fmt::Result { - 
f.debug_list().entries(self.iter()).finish() - } -} - -impl Drop for TensorList<'_> { - fn drop(&mut self) { - unsafe { mnn_sys::destroyTensorInfoArray(self.inner.cast_mut()) } - } -} - -impl<'t> TensorList<'t> { - pub fn from_ptr(inner: *const mnn_sys::TensorInfoArray) -> Self { - Self { - inner, - __marker: PhantomData, - } - } - - pub fn size(&self) -> usize { - unsafe { (*self.inner).size } - } - - pub fn get(&self, index: usize) -> Option> { - if index >= self.size() { - None - } else { - let gtinfo = unsafe { mnn_sys::getTensorInfoArray(self.inner, index) }; - if !gtinfo.is_null() { - Some(TensorInfo { - tensor_info: gtinfo, - __marker: PhantomData, - }) - } else { - None - } - } - } - - pub fn iter(&self) -> TensorListIter { - TensorListIter { - tensor_list: self, - idx: 0, - } - } -} - -impl<'t, 'tl: 't> IntoIterator for &'tl TensorList<'t> { - type Item = TensorInfo<'t, 'tl>; - type IntoIter = TensorListIter<'t, 'tl>; - - fn into_iter(self) -> Self::IntoIter { - TensorListIter { - tensor_list: self, - idx: 0, - } - } -} - -pub struct TensorListIter<'t, 'tl> { - tensor_list: &'tl TensorList<'t>, - idx: usize, -} -impl<'t, 'tl> Iterator for TensorListIter<'t, 'tl> { - type Item = TensorInfo<'t, 'tl>; - fn next(&mut self) -> Option { - let idx = self.idx; - self.idx += 1; - self.tensor_list.get(idx) - } -} - -// #[no_mangle] -// extern "C" fn rust_closure_callback_runner( -// f: *mut libc::c_void, -// tensors: *const *mut mnn_sys::Tensor, -// tensor_count: usize, -// name: *const libc::c_char, -// ) -> libc::c_int { -// let tensors = unsafe { std::slice::from_raw_parts(tensors.cast(), tensor_count) }; -// let name = unsafe { std::ffi::CStr::from_ptr(name) }; -// let f: TensorCallback = unsafe { Box::from_raw(f.cast::()) }; -// let ret = f(tensors, name) as libc::c_int; -// core::mem::forget(f); -// ret -// } - -// #[test] -// fn test_extern_c_rust_closure_callback_runner() { -// let f = |_tensors: &[RawTensor], name: &CStr| -> bool { -// println!("Callback: 
{:?}", name); -// true -// }; -// let f: Box = Box::new(Box::new(f)); -// let f = Box::into_raw(f).cast(); -// let tensors = [std::ptr::null_mut()]; -// let name = std::ffi::CString::new("Test").unwrap(); -// let ret = rust_closure_callback_runner(f, tensors.as_ptr(), tensors.len(), name.as_ptr()) -// as libc::c_int; -// assert_eq!(ret, 0); -// } - #[no_mangle] extern "C" fn rust_closure_callback_runner_op( f: *mut libc::c_void, @@ -790,6 +640,7 @@ extern "C" fn rust_closure_callback_runner_op( ret } +/// A struct that holds information about an operator #[repr(transparent)] pub struct OperatorInfo<'op> { pub(crate) inner: *mut libc::c_void, @@ -807,14 +658,17 @@ impl core::fmt::Debug for OperatorInfo<'_> { } impl OperatorInfo<'_> { + /// Get the name of the operator pub fn name(&self) -> &CStr { unsafe { CStr::from_ptr(mnn_sys::OperatorInfo_name(self.inner)) } } + /// Get the type of the operator pub fn type_name(&self) -> &CStr { unsafe { CStr::from_ptr(mnn_sys::OperatorInfo_type(self.inner)) } } + /// Get the number of flops of the operator pub fn flops(&self) -> f32 { unsafe { mnn_sys::OperatorInfo_flops(self.inner) } } diff --git a/src/lib.rs b/src/lib.rs index f15942d..f990720 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(missing_docs)] //! //! Ergonomic rust bindings for [MNN](https://github.com/alibaba/MNN) //! @@ -59,16 +60,23 @@ //! - 🚸 - Some models work //! 
- ❌ - Doesn't work +/// Re-export of whole mnn-sys pub mod ffi { pub use mnn_sys::*; } +mod profile; + pub mod backend; +/// Error handling pub mod error; +/// MNN::Interpreter related items pub mod interpreter; -pub mod profile; +/// Schedule configuration pub mod schedule; +/// MNN::Session related items pub mod session; +/// MNN::Tensor related items pub mod tensor; pub use backend::*; @@ -81,10 +89,12 @@ pub use tensor::*; pub use ffi::HalideType; pub use ffi::MapType; +/// Re-export of commonly used items pub mod prelude { pub use crate::error::*; pub(crate) use crate::profile::profile; pub use core::marker::PhantomData; pub use error_stack::{Report, ResultExt}; pub use libc::*; + pub use mnn_sys::{HalideType, MapType}; } diff --git a/src/schedule.rs b/src/schedule.rs index 8c314a2..6fe6955 100644 --- a/src/schedule.rs +++ b/src/schedule.rs @@ -36,18 +36,23 @@ use crate::{prelude::*, BackendConfig}; #[derive(Debug, Copy, Clone, Default, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum ForwardType { + /// Use all available backends. All, #[default] + /// Try to automatically select the best backend based on the current environment and hardware. Auto, + /// Use the CPU for computation. CPU, #[cfg(feature = "metal")] + /// Use the Metal backend for computation. Metal, #[cfg(feature = "opencl")] + /// Use the OpenCL backend for computation. OpenCL, - #[cfg(feature = "opengl")] - OpenGL, + /// Use the Vulkan backend for computation. #[cfg(feature = "vulkan")] Vulkan, + /// Use the CoreML backend for computation. #[cfg(feature = "coreml")] CoreML, } @@ -91,6 +96,7 @@ impl ForwardType { } } + /// List all available `ForwardType` variants as string slices. fn list() -> Vec<&'static str> { vec![ "auto", @@ -109,6 +115,7 @@ impl ForwardType { ] } + /// Convert the `ForwardType` enum to a string slice. 
pub fn to_str(self) -> &'static str { match self { ForwardType::Auto => "auto", @@ -310,6 +317,7 @@ impl ScheduleConfig { self } + /// Sets the type of backend to be used for computation. pub fn with_type(mut self, forward_type: ForwardType) -> Self { self.set_type(forward_type); self @@ -332,6 +340,7 @@ impl ScheduleConfig { self } + /// Sets the number of threads to be used for computation. pub fn with_num_threads(mut self, num_threads: i32) -> Self { self.set_num_threads(num_threads); self @@ -349,6 +358,7 @@ impl ScheduleConfig { self } + /// Sets the mode of computation. pub fn with_mode(mut self, mode: i32) -> Self { self.set_mode(mode); self @@ -366,6 +376,7 @@ impl ScheduleConfig { self } + /// Sets the backup type of backend to be used if the primary backend fails. pub fn with_backup_type(mut self, backup_type: ForwardType) -> Self { self.set_backup_type(backup_type); self @@ -397,12 +408,14 @@ impl ScheduleConfig { self } + /// Sets the backend-specific configuration. pub fn with_backend_config(mut self, backend_config: impl Into>) -> Self { self.set_backend_config(backend_config); self } } +/// A list of `ScheduleConfig` objects to be used for scheduling the forward computation in MNN. #[derive(Debug)] pub struct ScheduleConfigs { pub(crate) inner: Vec<*const MNNScheduleConfig>, @@ -420,12 +433,14 @@ impl Drop for ScheduleConfigs { } impl ScheduleConfigs { + /// Pushed a new `ScheduleConfig` to the list of configurations. pub fn push(&mut self, config: ScheduleConfig) { let mut config = ManuallyDrop::new(config); self.inner.push(config.inner); self.backend_configs.push(config.backend_config.take()); } + /// Creates a new (empty) `ScheduleConfigs` with the specified capacity. pub fn with_capacity(capacity: usize) -> Self { Self { inner: Vec::with_capacity(capacity), @@ -433,6 +448,7 @@ impl ScheduleConfigs { } } + /// Creates a new (empty) `ScheduleConfigs` with default settings. 
pub const fn new() -> Self { Self { inner: Vec::new(), diff --git a/src/session.rs b/src/session.rs index 5be36db..9e2fae4 100644 --- a/src/session.rs +++ b/src/session.rs @@ -30,17 +30,7 @@ pub enum SessionInternals { } impl Session { - // pub unsafe fn from_ptr(session: *mut mnn_sys::Session) -> Self { - // Self { - // session, - // __marker: PhantomData, - // } - // } - - // pub fn as_ptr_mut(&self) -> *mut mnn_sys::Session { - // self.session - // } - // + /// Calls the destroy function on the underlying MNN session. pub fn destroy(&mut self) { unsafe { mnn_sys::Interpreter_releaseSession(self.net, self.inner); diff --git a/src/tensor.rs b/src/tensor.rs index 54bec06..7a61a54 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,5 +1,9 @@ use crate::prelude::*; use core::marker::PhantomData; +use mnn_sys::*; +pub(crate) mod list; +mod raw; +pub use raw::RawTensor; use mnn_sys::HalideType; @@ -15,21 +19,32 @@ macro_rules! seal { } seal!(Host, Device, Ref<'_, T>, RefMut<'_, T>); +/// A trait to represent the type of a tensor pub trait TensorType: seal::Sealed { + /// The halide type of the tensor type H; + /// Check if the tensor is owned fn owned() -> bool; + /// Check if the tensor is borrowed fn borrowed() -> bool { !Self::owned() } + /// Check if the tensor is allocated in the host fn host() -> bool; + /// Check if the tensor is allocated in the device fn device() -> bool { !Self::host() } } +/// A tensor that is owned pub trait OwnedTensorType: TensorType {} +/// A tensor that is borrowed pub trait RefTensorType: TensorType {} +/// A tensor that is allocated in the cpu / host platform pub trait HostTensorType: TensorType {} +/// A tensor that is allocated in the device / gpu platform pub trait DeviceTensorType: TensorType {} +/// A tensor that is mutable pub trait MutableTensorType: TensorType {} impl TensorType for Host { @@ -84,20 +99,25 @@ impl MutableTensorType for RefMut<'_, T> {} impl RefTensorType for Ref<'_, T> {} impl RefTensorType for RefMut<'_, T> 
{} +/// A tensor that is owned by the cpu / host platform pub struct Host { pub(crate) __marker: PhantomData, } +/// A tensor that is owned by the device / gpu platform pub struct Device { pub(crate) __marker: PhantomData, } +/// A reference to a any tensor pub struct Ref<'t, T> { pub(crate) __marker: PhantomData<&'t [T]>, } +/// A mutable reference to a any tensor pub struct RefMut<'t, T> { pub(crate) __marker: PhantomData<&'t mut [T]>, } +/// A generic tensor that can of host / device / owned / borrowed pub struct Tensor { pub(crate) tensor: *mut mnn_sys::Tensor, __marker: PhantomData, @@ -114,6 +134,7 @@ impl Drop for Tensor { } impl Tensor> { + /// Get's a reference to an owned host tensor pub fn as_ref(&self) -> Tensor>> { Tensor { tensor: self.tensor, @@ -123,6 +144,7 @@ impl Tensor> { } impl Tensor> { + /// Get's a reference to an owned device tensor pub fn as_ref(&self) -> Tensor>> { Tensor { tensor: self.tensor, @@ -131,8 +153,6 @@ impl Tensor> { } } -use mnn_sys::*; - /// The type of the tensor dimension /// If you are manually specifying the shapes then this doesn't really matter /// N -> Batch size @@ -150,10 +170,13 @@ pub enum DimensionType { } impl DimensionType { + /// Tensorflow style dimensions or NHWC pub const NHWC: Self = Self::TensorFlow; + /// Caffe style dimensions or NCHW pub const NCHW: Self = Self::Caffe; + /// Caffe style dimensions with channel packed in 4 bytes or NC4HW4 pub const NC4HW4: Self = Self::CaffeC4; - pub fn to_mnn_sys(&self) -> mnn_sys::DimensionType { + pub(crate) fn to_mnn_sys(&self) -> mnn_sys::DimensionType { match self { DimensionType::Caffe => mnn_sys::DimensionType::CAFFE, DimensionType::CaffeC4 => mnn_sys::DimensionType::CAFFE_C4, @@ -201,6 +224,7 @@ where Ok(()) } + /// Get the device id of the tensor pub fn device_id(&self) -> u64 { unsafe { Tensor_deviceId(self.tensor) } } @@ -210,40 +234,49 @@ where unsafe { Tensor_shape(self.tensor) }.into() } + /// Get the dimensions of the tensor pub fn dimensions(&self) -> 
usize { unsafe { Tensor_dimensions(self.tensor) as usize } } + /// Get the width of the tensor pub fn width(&self) -> u32 { unsafe { Tensor_width(self.tensor) as u32 } } + /// Get the height of the tensor pub fn height(&self) -> u32 { unsafe { Tensor_height(self.tensor) as u32 } } + /// Get the channel size of the tensor pub fn channel(&self) -> u32 { unsafe { Tensor_channel(self.tensor) as u32 } } + /// Get the batch size of the tensor pub fn batch(&self) -> u32 { unsafe { Tensor_batch(self.tensor) as u32 } } + /// Get the size of the tensor when counted by bytes pub fn size(&self) -> usize { unsafe { Tensor_usize(self.tensor) } } + /// Get the size of the tensor when counted by elements pub fn element_size(&self) -> usize { unsafe { Tensor_elementSize(self.tensor) as usize } } + /// Print the shape of the tensor pub fn print_shape(&self) { unsafe { Tensor_printShape(self.tensor); } } + /// Print the tensor pub fn print(&self) { unsafe { Tensor_print(self.tensor); @@ -269,20 +302,25 @@ where Tensor_buffer_mut(self.tensor) } + /// Get the dimension type of the tensor pub fn get_dimension_type(&self) -> DimensionType { debug_assert!(!self.tensor.is_null()); From::from(unsafe { Tensor_getDimensionType(self.tensor) }) } + /// Get the data type of the tensor pub fn get_type(&self) -> mnn_sys::halide_type_t { unsafe { Tensor_getType(self.tensor) } } + /// Check if the tensor is of the specified data type pub fn is_type_of(&self) -> bool { let htc = halide_type_of::(); unsafe { Tensor_isTypeOf(self.tensor, htc) } } + /// # Safety + /// This is very unsafe do not use this unless you know what you are doing pub unsafe fn into_raw(self) -> RawTensor<'static> { let out = RawTensor { inner: self.tensor, @@ -296,6 +334,7 @@ impl Tensor where T::H: HalideType, { + /// Fill the tensor with the specified value pub fn fill(&mut self, value: T::H) where T::H: Copy, @@ -359,10 +398,12 @@ where Ok(result) } + /// Get the host memory slice of the tensor pub fn host(&self) -> &[T::H] { 
self.try_host().expect("Failed to get tensor host") } + /// Get the mutable host memory slice of the tensor pub fn host_mut(&mut self) -> &mut [T::H] { self.try_host_mut().expect("Failed to get tensor host_mut") } @@ -372,12 +413,15 @@ impl Tensor where T::H: HalideType, { + /// Try to wait for the device tensor to finish processing pub fn wait(&self, map_type: MapType, finish: bool) { unsafe { Tensor_wait(self.tensor, map_type, finish as i32); } } + /// Create a host tensor from the device tensor with same dimensions and data type and + /// optionally copy the data from the device tensor pub fn create_host_tensor_from_device(&self, copy_data: bool) -> Tensor> { let shape = self.shape(); let dm_type = self.get_dimension_type(); @@ -395,6 +439,7 @@ impl Tensor where T::H: HalideType, { + /// Create a new tensor with the specified shape and dimension type pub fn new(shape: impl AsTensorShape, dm_type: DimensionType) -> Self { let shape = shape.as_tensor_shape(); let tensor = unsafe { @@ -423,50 +468,6 @@ where } } -// impl Tensor> { -// pub fn new(shape: &[i32], data: &mut [T]) -> Self { -// let tensor = unsafe { -// }; -// debug_assert!(!tensor.is_null()); -// Self { -// tensor, -// __marker: PhantomData, -// } -// } -// -// // pub fn new_with_host_data(shape: &[usize], data: &[T::H]) -> Self { -// // let tensor = unsafe { -// // Tensor_createHostTensorWithData( -// // shape.as_ptr(), -// // shape.len() as i32, -// // data.as_ptr().cast(), -// // data.len() as i32, -// // ) -// // }; -// // debug_assert!(!tensor.is_null()); -// // Self { -// // tensor, -// // __marker: PhantomData, -// // } -// // } -// -// // pub fn new_with_host_data_mut(shape: &[usize], data: &mut [T::H]) -> Self { -// // let tensor = unsafe { -// // Tensor_createHostTensorWithData( -// // shape.as_ptr(), -// // shape.len() as i32, -// // data.as_mut_ptr().cast(), -// // data.len() as i32, -// // ) -// // }; -// // debug_assert!(!tensor.is_null()); -// // Self { -// // tensor, -// // __marker: 
PhantomData, -// // } -// // } -// } - impl Clone for Tensor where T::H: HalideType, @@ -480,6 +481,7 @@ where } } +/// A tensor shape #[derive(Clone, Copy)] #[repr(C)] pub struct TensorShape { @@ -539,7 +541,9 @@ impl core::fmt::Debug for TensorShape { } } +/// A trait to convert any array-like type to a tensor shape pub trait AsTensorShape { + /// Convert the array-like type to a tensor shape fn as_tensor_shape(&self) -> TensorShape; } @@ -600,6 +604,7 @@ impl Tensor where T::H: HalideType, { + /// Try to create a ref tensor from any array-like type pub fn borrowed(shape: impl AsTensorShape, input: impl AsRef<[T::H]>) -> Self { let shape = shape.as_tensor_shape(); let input = input.as_ref(); @@ -619,6 +624,7 @@ where } } + /// Try to create a mutable ref tensor from any array-like type pub fn borrowed_mut(shape: impl AsTensorShape, mut input: impl AsMut<[T::H]>) -> Self { let shape = shape.as_tensor_shape(); let input = input.as_mut(); @@ -656,150 +662,3 @@ pub fn test_tensor_borrow_mut() { tensor.host_mut().fill(1); assert_eq!(data, &[1, 1, 1, 1, 1, 1]); } - -pub struct Dyn { - __marker: PhantomData, -} -impl seal::Sealed for Dyn {} - -impl super::TensorType for Dyn { - type H = T::H; - fn host() -> bool { - T::host() - } - fn device() -> bool { - T::device() - } - fn owned() -> bool { - T::owned() - } - fn borrowed() -> bool { - T::borrowed() - } -} - -/// A raw tensor type that doesn't have any guarantees -/// and will be unconditionally dropped -#[repr(transparent)] -pub struct RawTensor<'r> { - pub(crate) inner: *mut mnn_sys::Tensor, - pub(crate) __marker: PhantomData<&'r ()>, -} - -// impl<'r> core::ops::Drop for RawTensor<'r> { -// fn drop(&mut self) { -// unsafe { -// mnn_sys::Tensor_destroy(self.inner); -// } -// } -// } - -impl<'r> RawTensor<'r> { - pub fn create_host_tensor_from_device(&self, copy_data: bool) -> RawTensor<'static> { - let tensor = - unsafe { mnn_sys::Tensor_createHostTensorFromDevice(self.inner, copy_data as i32) }; - // 
crate::ensure!(!tensor.is_null(), ErrorKind::TensorError); - assert!(!tensor.is_null()); - RawTensor { - inner: tensor, - __marker: PhantomData, - } - } - - /// Copies the data from a host tensor to the self tensor - pub fn copy_from_host_tensor(&mut self, tensor: &RawTensor) -> Result<()> { - let ret = unsafe { Tensor_copyFromHostTensor(self.inner, tensor.inner) }; - crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret)); - Ok(()) - } - - /// Copies the data from the self tensor to a host tensor - pub fn copy_to_host_tensor(&self, tensor: &mut RawTensor) -> Result<()> { - let ret = unsafe { Tensor_copyToHostTensor(self.inner, tensor.inner) }; - crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret)); - Ok(()) - } - - pub fn shape(&self) -> TensorShape { - unsafe { mnn_sys::Tensor_shape(self.inner) }.into() - } - - pub fn get_dimension_type(&self) -> DimensionType { - debug_assert!(!self.inner.is_null()); - From::from(unsafe { mnn_sys::Tensor_getDimensionType(self.inner) }) - } - - pub fn destroy(self) { - unsafe { - mnn_sys::Tensor_destroy(self.inner); - } - } - - pub fn size(&self) -> usize { - unsafe { mnn_sys::Tensor_usize(self.inner) } - } - - pub fn element_size(&self) -> usize { - unsafe { mnn_sys::Tensor_elementSize(self.inner) as usize } - } - - pub fn dimensions(&self) -> usize { - unsafe { mnn_sys::Tensor_dimensions(self.inner) as usize } - } - - pub fn width(&self) -> u32 { - unsafe { mnn_sys::Tensor_width(self.inner) as u32 } - } - - pub fn height(&self) -> u32 { - unsafe { mnn_sys::Tensor_height(self.inner) as u32 } - } - - pub fn channel(&self) -> u32 { - unsafe { mnn_sys::Tensor_channel(self.inner) as u32 } - } - - pub fn is_dynamic_unsized(&self) -> bool { - self.shape().as_ref().contains(&-1) - } - - pub fn wait(&self, map_type: MapType, finish: bool) { - unsafe { - Tensor_wait(self.inner, map_type, finish as i32); - } - } - - /// # Safety - /// This is very unsafe do not use this unless you know what you are doing - /// Gives a raw pointer 
to the tensor's data - /// P.S. I don't know what I'm doing - pub unsafe fn unchecked_host_ptr(&self) -> *mut c_void { - debug_assert!(!self.inner.is_null()); - let data = mnn_sys::Tensor_host_mut(self.inner); - debug_assert!(data.is_null()); - data - } - - /// # Safety - /// This is very unsafe do not use this unless you know what you are doing - /// Gives a mutable byte slice to the tensor's data - pub unsafe fn unchecked_host_bytes(&self) -> &mut [u8] { - core::slice::from_raw_parts_mut(self.unchecked_host_ptr().cast(), self.size()) - } - - /// # Safety - /// This is very unsafe do not use this unless you know what you are doing - pub unsafe fn to_concrete(self) -> super::Tensor - where - T::H: HalideType, - { - super::Tensor::from_ptr(self.inner) - } - - pub(crate) fn from_ptr(inner: *mut mnn_sys::Tensor) -> Self { - Self { - inner, - __marker: PhantomData, - } - } -} diff --git a/src/tensor/list.rs b/src/tensor/list.rs new file mode 100644 index 0000000..13fa729 --- /dev/null +++ b/src/tensor/list.rs @@ -0,0 +1,148 @@ +#![deny(missing_docs)] +use crate::{prelude::*, Device, RawTensor, RefMut, Tensor}; +use mnn_sys::HalideType; + +#[repr(transparent)] +pub struct TensorList<'t> { + pub(crate) inner: *const mnn_sys::TensorInfoArray, + pub(crate) __marker: PhantomData<&'t ()>, +} + +impl core::fmt::Debug for TensorList<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_list().entries(self.iter()).finish() + } +} + +impl Drop for TensorList<'_> { + fn drop(&mut self) { + unsafe { mnn_sys::destroyTensorInfoArray(self.inner.cast_mut()) } + } +} + +impl<'t> TensorList<'t> { + pub(crate) fn from_ptr(inner: *const mnn_sys::TensorInfoArray) -> Self { + Self { + inner, + __marker: PhantomData, + } + } + + /// Returns the size of the tensor list + pub fn size(&self) -> usize { + unsafe { (*self.inner).size } + } + + /// Get the tensor at the given index + pub fn get(&self, index: usize) -> Option> { + if index >= self.size() { + None + 
} else { + let gtinfo = unsafe { mnn_sys::getTensorInfoArray(self.inner, index) }; + if !gtinfo.is_null() { + Some(TensorInfo { + tensor_info: gtinfo, + __marker: PhantomData, + }) + } else { + None + } + } + } + + /// Get an iterator over the tensor list + pub fn iter(&self) -> TensorListIter { + TensorListIter { + tensor_list: self, + idx: 0, + } + } +} + +impl<'t, 'tl: 't> IntoIterator for &'tl TensorList<'t> { + type Item = TensorInfo<'t, 'tl>; + type IntoIter = TensorListIter<'t, 'tl>; + + fn into_iter(self) -> Self::IntoIter { + TensorListIter { + tensor_list: self, + idx: 0, + } + } +} + +pub struct TensorListIter<'t, 'tl> { + tensor_list: &'tl TensorList<'t>, + idx: usize, +} +impl<'t, 'tl> Iterator for TensorListIter<'t, 'tl> { + type Item = TensorInfo<'t, 'tl>; + fn next(&mut self) -> Option { + let idx = self.idx; + self.idx += 1; + self.tensor_list.get(idx) + } +} + +#[repr(transparent)] +pub struct TensorInfo<'t, 'tl> { + pub(crate) tensor_info: *mut mnn_sys::TensorInfo, + pub(crate) __marker: PhantomData<&'tl TensorList<'t>>, +} + +impl core::fmt::Debug for TensorInfo<'_, '_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let tensor = self.raw_tensor(); + let shape = tensor.shape(); + f.debug_struct("TensorInfo") + .field("name", &self.name()) + .field("tensor", &shape) + .finish() + } +} + +impl<'t, 'tl> TensorInfo<'t, 'tl> { + pub fn name(&self) -> &'tl str { + debug_assert!(!self.tensor_info.is_null()); + unsafe { (*self.tensor_info).name.to_cstr() } + .to_str() + .expect("Tensor name is not utf-8") + } + + pub fn tensor(&self) -> Result>>> { + debug_assert!(!self.tensor_info.is_null()); + unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) }; + let tensor = unsafe { Tensor::from_ptr((*self.tensor_info).tensor.cast()) }; + let shape = tensor.shape(); + ensure!(!shape.as_ref().contains(&-1), ErrorKind::DynamicTensorError); + ensure!( + tensor.is_type_of::(), + ErrorKind::HalideTypeMismatch { + got: 
std::any::type_name::(), + } + ); + Ok(tensor) + } + + /// # Safety + /// The shape is not checked so it's marked unsafe since futher calls to interpreter might be **unsafe** with this + pub unsafe fn tensor_unresized(&self) -> Result>>> { + debug_assert!(!self.tensor_info.is_null()); + unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) }; + let tensor = unsafe { Tensor::from_ptr((*self.tensor_info).tensor.cast()) }; + ensure!( + tensor.is_type_of::(), + ErrorKind::HalideTypeMismatch { + got: std::any::type_name::(), + } + ); + Ok(tensor) + } + + /// This function return's the raw tensor without any sort of type-checking or shape-checking + pub fn raw_tensor(&self) -> RawTensor<'t> { + debug_assert!(!self.tensor_info.is_null()); + unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) }; + RawTensor::from_ptr(unsafe { (*self.tensor_info).tensor.cast() }) + } +} diff --git a/src/tensor/raw.rs b/src/tensor/raw.rs new file mode 100644 index 0000000..b536699 --- /dev/null +++ b/src/tensor/raw.rs @@ -0,0 +1,140 @@ +use crate::prelude::*; +use core::marker::PhantomData; +use mnn_sys::HalideType; +/// A raw tensor type that doesn't have any guarantees +/// and will be unconditionally dropped +#[repr(transparent)] +pub struct RawTensor<'r> { + pub(crate) inner: *mut mnn_sys::Tensor, + pub(crate) __marker: PhantomData<&'r ()>, +} + +// impl<'r> core::ops::Drop for RawTensor<'r> { +// fn drop(&mut self) { +// unsafe { +// mnn_sys::Tensor_destroy(self.inner); +// } +// } +// } + +impl RawTensor<'_> { + /// Creates a new host tensor from the device tensor + pub fn create_host_tensor_from_device(&self, copy_data: bool) -> RawTensor<'static> { + let tensor = + unsafe { mnn_sys::Tensor_createHostTensorFromDevice(self.inner, copy_data as i32) }; + // crate::ensure!(!tensor.is_null(), ErrorKind::TensorError); + assert!(!tensor.is_null()); + RawTensor { + inner: tensor, + __marker: PhantomData, + } + } + + /// Copies the data from a host tensor to the self tensor + 
pub fn copy_from_host_tensor(&mut self, tensor: &RawTensor) -> Result<()> { + let ret = unsafe { mnn_sys::Tensor_copyFromHostTensor(self.inner, tensor.inner) }; + crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret)); + Ok(()) + } + + /// Copies the data from the self tensor to a host tensor + pub fn copy_to_host_tensor(&self, tensor: &mut RawTensor) -> Result<()> { + let ret = unsafe { mnn_sys::Tensor_copyToHostTensor(self.inner, tensor.inner) }; + crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret)); + Ok(()) + } + + /// Returns the shape of the tensor + pub fn shape(&self) -> crate::TensorShape { + unsafe { mnn_sys::Tensor_shape(self.inner) }.into() + } + + /// Returns the dimension type of the tensor + pub fn get_dimension_type(&self) -> super::DimensionType { + debug_assert!(!self.inner.is_null()); + From::from(unsafe { mnn_sys::Tensor_getDimensionType(self.inner) }) + } + + /// Cleans up the tensor by calling the destructor of the tensor + pub fn destroy(self) { + unsafe { + mnn_sys::Tensor_destroy(self.inner); + } + } + + /// Returns the size of the tensor when counted by bytes + pub fn size(&self) -> usize { + unsafe { mnn_sys::Tensor_usize(self.inner) } + } + + /// Returns the size of the tensor when counted by elements + pub fn element_size(&self) -> usize { + unsafe { mnn_sys::Tensor_elementSize(self.inner) as usize } + } + + /// Returns the number of dimensions of the tensor + pub fn dimensions(&self) -> usize { + unsafe { mnn_sys::Tensor_dimensions(self.inner) as usize } + } + + /// Returns the width of the tensor + pub fn width(&self) -> u32 { + unsafe { mnn_sys::Tensor_width(self.inner) as u32 } + } + + /// Returns the height of the tensor + pub fn height(&self) -> u32 { + unsafe { mnn_sys::Tensor_height(self.inner) as u32 } + } + + /// Returns the channel of the tensor + pub fn channel(&self) -> u32 { + unsafe { mnn_sys::Tensor_channel(self.inner) as u32 } + } + + /// Returns true if the tensor is unsized and dynamic (needs to be resized 
to work) + pub fn is_dynamic_unsized(&self) -> bool { + self.shape().as_ref().contains(&-1) + } + + /// Waits for the tensor to be ready + pub fn wait(&self, map_type: MapType, finish: bool) { + unsafe { + mnn_sys::Tensor_wait(self.inner, map_type, finish as i32); + } + } + + /// # Safety + /// This is very unsafe do not use this unless you know what you are doing + /// Gives a raw pointer to the tensor's data + /// P.S. I don't know what I'm doing + pub unsafe fn unchecked_host_ptr(&self) -> *mut c_void { + debug_assert!(!self.inner.is_null()); + let data = mnn_sys::Tensor_host_mut(self.inner); + debug_assert!(data.is_null()); + data + } + + /// # Safety + /// This is very unsafe do not use this unless you know what you are doing + /// Gives a mutable byte slice to the tensor's data + pub unsafe fn unchecked_host_bytes(&mut self) -> &mut [u8] { + core::slice::from_raw_parts_mut(self.unchecked_host_ptr().cast(), self.size()) + } + + /// # Safety + /// This is very unsafe do not use this unless you know what you are doing + pub unsafe fn to_concrete(self) -> super::Tensor + where + T::H: HalideType, + { + super::Tensor::from_ptr(self.inner) + } + + pub(crate) fn from_ptr(inner: *mut mnn_sys::Tensor) -> Self { + Self { + inner, + __marker: PhantomData, + } + } +} diff --git a/tests/backend.rs b/tests/backend.rs index 081d857..1fac1b6 100644 --- a/tests/backend.rs +++ b/tests/backend.rs @@ -3,6 +3,7 @@ use common::*; use mnn::ForwardType; use tracing_test::traced_test; +#[cfg(feature = "coreml")] #[test] #[traced_test] fn compare_cpu_and_coreml_outputs() { diff --git a/tools/bencher/src/main.rs b/tools/bencher/src/main.rs index 67e7683..0ec7af0 100644 --- a/tools/bencher/src/main.rs +++ b/tools/bencher/src/main.rs @@ -351,9 +351,9 @@ pub fn generate_main(cli: Generate) -> Result<()> { let path = model.with_file_name(name); let mut tensor = input.raw_tensor(); unsafe { - let host = tensor.create_host_tensor_from_device(false); + let mut host = 
tensor.create_host_tensor_from_device(false); host.unchecked_host_bytes().fill(1); - tensor.copy_from_host_tensor(&host); + tensor.copy_from_host_tensor(&host).cc(BenchError)?; std::fs::write(&path, host.unchecked_host_bytes()).cc(BenchError)?; } cfg.inputs.insert( @@ -609,9 +609,9 @@ pub fn bench( not_terminal.then(|| eprintln!("Setting input {name}")); unsafe { let mut tensor = net.raw_input(&session, name).cc(BenchError)?; - let host = tensor.create_host_tensor_from_device(false); + let mut host = tensor.create_host_tensor_from_device(false); host.unchecked_host_bytes().copy_from_slice(&input); - tensor.copy_from_host_tensor(&host); + tensor.copy_from_host_tensor(&host).cc(BenchError)?; drop(host); } } From ffbf063589bd9f502eb9582e357dc494b444e1cd Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Fri, 20 Dec 2024 15:22:52 +0530 Subject: [PATCH 19/22] chore: Fix clippy lints --- src/interpreter.rs | 1 - src/tensor.rs | 6 +++--- tests/backend.rs | 1 + 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/interpreter.rs b/src/interpreter.rs index 9d2d0c0..8526bd3 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -80,7 +80,6 @@ impl core::ops::Deref for TensorCallback { /// The only two items tested are /// - `Debug` /// - `Release` -/// Which work fine #[derive(Debug, Copy, Clone)] #[cfg_attr(windows, repr(i32))] #[cfg_attr(unix, repr(u32))] diff --git a/src/tensor.rs b/src/tensor.rs index 7a61a54..8592cbd 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -176,7 +176,7 @@ impl DimensionType { pub const NCHW: Self = Self::Caffe; /// Caffe style dimensions with channel packed in 4 bytes or NC4HW4 pub const NC4HW4: Self = Self::CaffeC4; - pub(crate) fn to_mnn_sys(&self) -> mnn_sys::DimensionType { + pub(crate) fn to_mnn_sys(self) -> mnn_sys::DimensionType { match self { DimensionType::Caffe => mnn_sys::DimensionType::CAFFE, DimensionType::CaffeC4 => mnn_sys::DimensionType::CAFFE_C4, @@ -646,7 +646,7 @@ where } #[test] -pub fn 
test_tensor_borrowed() { +fn test_tensor_borrowed() { let shape = [1, 2, 3]; let data = vec![1, 2, 3, 4, 5, 6]; let tensor = Tensor::>>::borrowed(&shape, &data); @@ -655,7 +655,7 @@ pub fn test_tensor_borrowed() { } #[test] -pub fn test_tensor_borrow_mut() { +fn test_tensor_borrow_mut() { let shape = [1, 2, 3]; let mut data = vec![1, 2, 3, 4, 5, 6]; let mut tensor = Tensor::>>::borrowed_mut(&shape, &mut data); diff --git a/tests/backend.rs b/tests/backend.rs index 1fac1b6..4cde667 100644 --- a/tests/backend.rs +++ b/tests/backend.rs @@ -1,3 +1,4 @@ +#![allow(unused_imports)] pub mod common; use common::*; use mnn::ForwardType; From 954c631c4a462cebbee934b283ceacc12f7c61c3 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Fri, 20 Dec 2024 15:39:00 +0530 Subject: [PATCH 20/22] fix(flake): Don't use mnn as a buildInput on common args --- flake.nix | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/flake.nix b/flake.nix index 817255d..6ed5717 100644 --- a/flake.nix +++ b/flake.nix @@ -90,9 +90,7 @@ pkg-config ]; buildInputs = with pkgs; - [ - mnn - ] + [] ++ (lib.optionals pkgs.stdenv.isLinux [ ocl-icd opencl-headers @@ -218,7 +216,6 @@ git-lfs llvm llvmPackages.lldb - # mnn nushell rust-bindgen google-cloud-sdk @@ -229,7 +226,6 @@ cargo-llvm-cov ] ); - # ++ (with packages; [bencher inspect]); }); }; } From 3d61bc6c47988cff57844a9aad6e3ca58d51f440 Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Fri, 20 Dec 2024 16:35:55 +0530 Subject: [PATCH 21/22] feat: Added some more tests --- src/backend.rs | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/backend.rs b/src/backend.rs index bbb3868..7f9a45e 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -67,7 +67,7 @@ impl Default for BackendConfig { } /// PowerModes depend on if the specific backend has support for it -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, 
serde::Deserialize))] pub enum PowerMode { /// Low power mode @@ -151,7 +151,7 @@ impl FromStr for PrecisionMode { } /// MemoryModes depend on if the specific backend has support for it -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum MemoryMode { /// Low memory mode @@ -191,7 +191,7 @@ impl MemoryMode { } /// PrecisionModes depend on if the specific backend has support for it -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum PrecisionMode { /// Normal precision mode @@ -331,3 +331,22 @@ impl BackendConfig { self } } + +#[test] +fn test_backend_config() { + let mut config = BackendConfig::new(); + config.set_memory_mode(MemoryMode::Low); + config.set_power_mode(PowerMode::Low); + config.set_precision_mode(PrecisionMode::Low); + let config = std::hint::black_box(config.clone()); + assert_eq!(config.get_memory_mode(), MemoryMode::Low); + assert_eq!(config.get_power_mode(), PowerMode::Low); + assert_eq!(config.get_precision_mode(), PrecisionMode::Low); + let config = config + .with_memory_mode(MemoryMode::Normal) + .with_power_mode(PowerMode::Normal) + .with_precision_mode(PrecisionMode::Normal); + assert_eq!(config.get_memory_mode(), MemoryMode::Normal); + assert_eq!(config.get_power_mode(), PowerMode::Normal); + assert_eq!(config.get_precision_mode(), PrecisionMode::Normal); +} From e337d711caad7c8e5204bdb606901b5e4d92b5fa Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Fri, 20 Dec 2024 16:37:02 +0530 Subject: [PATCH 22/22] feat: Added cachix push script --- tools/cachix/push.sh | 2 ++ 1 file changed, 2 insertions(+) create mode 100755 tools/cachix/push.sh diff --git a/tools/cachix/push.sh b/tools/cachix/push.sh new file mode 100755 index 0000000..0425b5c --- /dev/null +++ b/tools/cachix/push.sh @@ -0,0 +1,2 @@ +cachix watch-exec 
mnn-rs -- nix flake check --system x86_64-linux --max-jobs 0 +cachix watch-exec mnn-rs -- nix flake check --system aarch64-darwin