From b392211de1d7dace51c03f1e30d8643c433e1be1 Mon Sep 17 00:00:00 2001 From: Amit Sheokand Date: Wed, 23 Oct 2024 12:33:21 +0530 Subject: [PATCH] added basic docs --- src/interpreter.rs | 87 ++++++++++++++++++++++++++++++++++-- src/schedule.rs | 108 +++++++++++++++++++++++++++++++++++++++++++++ src/session.rs | 10 +++++ 3 files changed, 202 insertions(+), 3 deletions(-) diff --git a/src/interpreter.rs b/src/interpreter.rs index f04b466..3d63312 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -121,6 +121,7 @@ impl SessionMode { } } +/// net data holder. multiple sessions could share same net. #[repr(transparent)] pub struct Interpreter { pub(crate) inner: *mut mnn_sys::Interpreter, @@ -136,6 +137,11 @@ impl Drop for Interpreter { } impl Interpreter { + /// Create an net/interpreter from a file. + /// + /// `path`: the file path of the model + /// + /// return: the created net/interpreter pub fn from_file(path: impl AsRef) -> Result { let path = path.as_ref(); ensure!(path.exists(), ErrorKind::IOError; path.to_string_lossy().to_string(), "File not found"); @@ -149,6 +155,11 @@ impl Interpreter { }) } + /// Create an net/interpreter from a buffer. + /// + /// `bytes`: the buffer of the model + /// + /// return: the created net/interpreter pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result { let bytes = bytes.as_ref(); let size = bytes.len(); @@ -161,14 +172,31 @@ impl Interpreter { }) } + /// Set session mode + /// + /// `mode`: the session mode + /// + /// **Warning:** + /// It should be called before create session! pub fn set_session_mode(&mut self, mode: SessionMode) { unsafe { mnn_sys::Interpreter_setSessionMode(self.inner, mode.to_mnn_sys()) } } + ///call this function to get tensors ready. + /// + ///output tensor buffer (host or deviceId) should be retrieved after resize of any input tensor. + /// + ///`session`: the session to be prepared pub fn resize_session(&self, session: &mut crate::Session) { unsafe { mnn_sys::Interpreter_resizeSession(self.inner, session.inner) } } + /// Resize session and reallocate the buffer. + /// + /// `session`: the session to be prepared. + /// + /// # Note + /// NeedRelloc is default to 1, 1 means need realloc! pub fn resize_session_reallocate(&self, session: &mut crate::Session) { unsafe { mnn_sys::Interpreter_resizeSessionWithFlag(self.inner, session.inner, 1i32) } } @@ -206,6 +234,11 @@ impl Interpreter { } } + /// Create a session with session config. Session will be managed in net/interpreter. + /// + /// `schedule` : the config of the session + /// + /// return: the created session pub fn create_session( &mut self, schedule: crate::ScheduleConfig, @@ -221,6 +254,11 @@ impl Interpreter { }) } + /// Create multi-path session with schedule configs and user-specified runtime. created session will be managed in net/interpreter. + /// + /// `schedule` : the config of the session + /// + /// return: the created session pub fn create_multipath_session( &mut self, schedule: impl IntoIterator, @@ -238,6 +276,7 @@ impl Interpreter { }) } + /// Print all input and output tensors info. pub fn model_print_io(path: impl AsRef) -> Result<()> { let path = path.as_ref(); crate::ensure!(path.exists(), ErrorKind::IOError); @@ -247,11 +286,23 @@ impl Interpreter { Ok(()) } + /// Get the input tensor of the session. + /// + /// `session`: the session to get input tensor + /// + /// return: List of input tensors pub fn inputs(&self, session: &crate::Session) -> TensorList { let inputs = unsafe { mnn_sys::Interpreter_getSessionInputAll(self.inner, session.inner) }; TensorList::from_ptr(inputs) } + /// Get the input tensor of the session by name. + /// + /// `session`: the session to get input tensor from + /// + /// `name`: the name of the input tensor + /// + /// return: the input tensor pub fn input<'s, H: HalideType>( &self, session: &'s crate::Session, @@ -291,7 +342,7 @@ impl Interpreter { } /// # Safety - /// We Still don't know the safety guarantees of this function so it's marked unsafe + /// **Warning** We Still don't know the safety guarantees of this function so it's marked unsafe pub unsafe fn input_unresized<'s, H: HalideType>( &self, session: &'s crate::Session, @@ -314,7 +365,7 @@ impl Interpreter { } /// # Safety - /// Very unsafe since it doesn't check the type of the tensor + /// Very **unsafe** since it doesn't check the type of the tensor /// as well as the shape of the tensor pub unsafe fn input_unchecked<'s, H: HalideType>( &self, @@ -329,6 +380,10 @@ impl Interpreter { } /// Get the output tensor of a session by name + /// + /// `session` : the session to get output tensor from + /// + /// `name` : the name of the output tensor pub fn output<'s, H: HalideType>( &self, session: &'s crate::Session, @@ -366,6 +421,7 @@ impl Interpreter { Ok(RawTensor::from_ptr(output)) } + /// Run a session pub fn run_session(&mut self, session: &crate::session::Session) -> Result<()> { profile!("Running session"; { let ret = unsafe { mnn_sys::Interpreter_runSession(self.inner, session.inner) }; @@ -377,6 +433,15 @@ impl Interpreter { }) } + /// Run a session with a callback + /// + /// `session` : the session to run + /// + /// `before` : a callback before each op. return true to run the op; return false to skip the op. + /// + /// `after` : a callback after each op. return true to continue running; return false to interrupt the session. + /// + /// `sync` : synchronously wait for finish of execution or not. pub fn run_session_with_callback( &mut self, session: &crate::session::Session, @@ -403,12 +468,24 @@ impl Interpreter { Ok(()) } + /// Get all output tensors of a session pub fn outputs(&self, session: &crate::session::Session) -> TensorList { let outputs = unsafe { mnn_sys::Interpreter_getSessionOutputAll(self.inner, session.inner) }; TensorList::from_ptr(outputs) } + /// If the cache exist, try to load cache from file. + /// After createSession, try to save cache to file. + /// + /// `cache_file` : the file path to save or load cache. + /// + /// `key_size` : the size of key + /// + /// # Note + /// The API should be called before create session. + /// + /// Key Depercerate, keeping for future use! pub fn set_cache_file(&mut self, path: impl AsRef, key_size: usize) -> Result<()> { let path = path.as_ref(); let path = dunce::simplified(path); @@ -417,6 +494,8 @@ impl Interpreter { unsafe { mnn_sys::Interpreter_setCacheFile(self.inner, c_path.as_ptr(), key_size) } Ok(()) } + + /// Update cache file pub fn update_cache_file(&mut self, session: &mut crate::session::Session) -> Result<()> { MNNError::from_error_code(unsafe { mnn_sys::Interpreter_updateCacheFile(self.inner, session.inner) @@ -433,6 +512,7 @@ impl Interpreter { }); } + /// Get memory usage of a session in MB pub fn memory(&self, session: &crate::session::Session) -> Result { let mut memory = 0f32; let memory_ptr = &mut memory as *mut f32; @@ -447,6 +527,7 @@ impl Interpreter { Ok(memory) } + /// Get float operation needed in session in M pub fn flops(&self, session: &crate::Session) -> Result { let mut flop = 0.0f32; let flop_ptr = &mut flop as *mut f32; @@ -535,7 +616,7 @@ impl<'t, 'tl> TensorInfo<'t, 'tl> { } /// # Safety - /// The shape is not checked so it's marked unsafe since futher calls to interpreter might be unsafe with this + /// The shape is not checked so it's marked unsafe since futher calls to interpreter might be **unsafe** with this pub unsafe fn tensor_unresized(&self) -> Result>>> { debug_assert!(!self.tensor_info.is_null()); unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) }; diff --git a/src/schedule.rs b/src/schedule.rs index e7693b7..4201de2 100644 --- a/src/schedule.rs +++ b/src/schedule.rs @@ -3,6 +3,34 @@ use std::{ffi::CString, mem::ManuallyDrop}; use crate::{prelude::*, BackendConfig}; +/// The `ForwardType` enum is used to specify the backend that will be used for forward computation +/// in the MNN framework. Each variant corresponds to a different backend, which may be enabled +/// or disabled based on the features enabled in the build configuration. +/// +/// # Variants +/// +/// - `All`: Use all available backends. +/// - `Auto`: Automatically select the best backend based on the current environment and hardware. +/// - `CPU`: Use the CPU for computation. +/// - `Metal`: Use the Metal backend for computation (requires the `metal` feature). +/// - `OpenCL`: Use the OpenCL backend for computation (requires the `opencl` feature). +/// - `OpenGL`: Use the OpenGL backend for computation (requires the `opengl` feature). +/// - `Vulkan`: Use the Vulkan backend for computation (requires the `vulkan` feature). +/// - `CoreML`: Use the CoreML backend for computation (requires the `coreml` feature). +/// +/// # Example +/// +/// ```rust +/// use mnn_rs::schedule::ForwardType; +/// +/// let forward_type = ForwardType::Auto; +/// println!("Selected forward type: {:?}", forward_type); +/// ``` +/// +/// # Note +/// +/// The availability of certain variants depends on the features enabled during the build. +/// For example, the `Metal` variant is only available if the `metal` feature is enabled. #[derive(Debug, Copy, Clone, Default, PartialEq, Eq)] pub enum ForwardType { All, @@ -22,6 +50,7 @@ pub enum ForwardType { } impl ForwardType { + /// Convert the `ForwardType` enum to the corresponding C++ `MNNForwardType` enum. fn to_mnn_sys(self) -> MNNForwardType { match self { ForwardType::Auto => MNNForwardType::MNN_FORWARD_AUTO, @@ -89,6 +118,49 @@ impl core::str::FromStr for ForwardType { } } +/// Configuration for scheduling the forward computation in MNN. +/// +/// The `ScheduleConfig` struct is used to configure various parameters for scheduling the forward +/// computation in the MNN framework. It allows setting the type of backend, the number of threads, +/// the mode of computation, and other options. +/// +/// # Example +/// +/// ```rust +/// use mnn_rs::schedule::{ScheduleConfig, ForwardType}; +/// +/// let mut config = ScheduleConfig::new(); +/// config.set_type(ForwardType::Auto); +/// config.set_num_threads(4); +/// config.set_mode(0); +/// ``` +/// +/// # Fields +/// +/// - `inner`: A raw pointer to the underlying `MNNScheduleConfig` structure. +/// - `backend_config`: Specifies backend-specific configurations. +/// - `__marker`: A marker to ensure the struct is `!Send` by default. +/// +/// # Methods +/// +/// - `new() -> Self`: Creates a new `ScheduleConfig` with default settings. +/// - `as_ptr_mut(&mut self) -> *mut MNNScheduleConfig`: Returns a mutable raw pointer to the underlying `MNNScheduleConfig`. +/// - `set_save_tensors(&mut self, save_tensors: &[&str]) -> Result<()>`: Sets the tensors to be saved during computation. +/// - `set_type(&mut self, forward_type: ForwardType)`: Sets the type of backend to be used for computation. +/// - `set_num_threads(&mut self, num_threads: i32)`: Sets the number of threads to be used for computation. +/// - `set_mode(&mut self, mode: i32)`: Sets the mode of computation. +/// - `set_backup_type(&mut self, backup_type: ForwardType)`: Sets the backup type of backend to be used if the primary backend fails. +/// - `set_backend_config(&mut self, backend_config: impl Into>)`: Sets the backend-specific configuration. +/// +/// # Safety +/// +/// The `ScheduleConfig` struct contains raw pointers and interacts with the underlying C API of MNN. +/// Users should be cautious when using this struct to avoid undefined behavior. +/// +/// # Warning +/// +/// **Warning:** The `Drop` implementation for `ScheduleConfig` ensures that the underlying `MNNScheduleConfig` +/// is properly destroyed when the struct goes out of scope. Users should not manually free the `inner` pointer. // #[derive(Debug)] pub struct ScheduleConfig { pub(crate) inner: *mut MNNScheduleConfig, @@ -113,10 +185,12 @@ impl Default for ScheduleConfig { } impl ScheduleConfig { + /// Returns a mutable raw pointer to the underlying `MNNScheduleConfig`. pub fn as_ptr_mut(&mut self) -> *mut MNNScheduleConfig { self.inner } + /// Creates a new `ScheduleConfig` with default settings. pub fn new() -> Self { unsafe { let inner = mnnsc_create(); @@ -128,6 +202,15 @@ impl ScheduleConfig { } } + /// Sets the tensors to be saved during computation. + /// + /// # Arguments + /// + /// - `save_tensors`: A slice of tensor names to be saved. + /// + /// # Errors + /// + /// Returns an error if any of the tensor names contain null bytes. pub fn set_save_tensors(&mut self, save_tensors: &[&str]) -> Result<()> { let vec_cstring = save_tensors .iter() @@ -141,30 +224,55 @@ impl ScheduleConfig { Ok(()) } + /// Sets the type of backend to be used for computation. + /// + /// # Arguments + /// + /// - `forward_type`: The type of backend to be used. pub fn set_type(&mut self, forward_type: ForwardType) { unsafe { mnnsc_set_type(self.inner, forward_type.to_mnn_sys()); } } + /// Sets the number of threads to be used for computation. + /// + /// # Arguments + /// + /// - `num_threads`: The number of threads to be used. pub fn set_num_threads(&mut self, num_threads: i32) { unsafe { mnnsc_set_num_threads(self.inner, num_threads); } } + /// Sets the mode of computation. + /// + /// # Arguments + /// + /// - `mode`: The mode of computation. pub fn set_mode(&mut self, mode: i32) { unsafe { mnnsc_set_mode(self.inner, mode); } } + /// Sets the backup type of backend to be used if the primary backend fails. + /// + /// # Arguments + /// + /// - `backup_type`: The backup type of backend to be used. pub fn set_backup_type(&mut self, backup_type: ForwardType) { unsafe { mnnsc_set_backup_type(self.inner, backup_type.to_mnn_sys()); } } + /// Sets the backend-specific configuration. + /// + /// # Arguments + /// + /// - `backend_config`: specifies additional backend-specific configurations. pub fn set_backend_config(&mut self, backend_config: impl Into>) { self.backend_config = backend_config.into(); let ptr = if let Some(ref b) = self.backend_config { diff --git a/src/session.rs b/src/session.rs index 4dd91fb..b905072 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,13 +1,22 @@ use crate::prelude::*; +/// A session is a context in which a computation graph is executed. +/// +/// Inference unit. multiple sessions could share one net/interpreter. pub struct Session { + /// Pointer to the underlying MNN session. pub(crate) inner: *mut mnn_sys::Session, + /// Internal session configurations. pub(crate) __session_internals: crate::SessionInternals, + /// Marker to ensure the struct is not Send or Sync. pub(crate) __marker: PhantomData<()>, } +/// Enum representing the internal configurations of a session. pub enum SessionInternals { + /// Single session configuration. Single(crate::ScheduleConfig), + /// Multiple session configurations. MultiSession(crate::ScheduleConfigs), } @@ -25,6 +34,7 @@ impl Session { } impl Drop for Session { + /// Custom drop implementation to ensure the underlying MNN session is properly destroyed. fn drop(&mut self) { unsafe { mnn_sys::Session_destroy(self.inner) } }