Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added basic docs #5

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 84 additions & 3 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ impl SessionMode {
}
}

/// net data holder. multiple sessions could share same net.
#[repr(transparent)]
pub struct Interpreter {
pub(crate) inner: *mut mnn_sys::Interpreter,
Expand All @@ -136,6 +137,11 @@ impl Drop for Interpreter {
}

impl Interpreter {
/// Create an net/interpreter from a file.
///
/// `path`: the file path of the model
///
/// return: the created net/interpreter
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
ensure!(path.exists(), ErrorKind::IOError; path.to_string_lossy().to_string(), "File not found");
Expand All @@ -149,6 +155,11 @@ impl Interpreter {
})
}

/// Create an net/interpreter from a buffer.
///
/// `bytes`: the buffer of the model
///
/// return: the created net/interpreter
pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self> {
let bytes = bytes.as_ref();
let size = bytes.len();
Expand All @@ -161,14 +172,31 @@ impl Interpreter {
})
}

/// Set session mode
///
/// `mode`: the session mode
///
/// **Warning:**
/// It should be called before create session!
pub fn set_session_mode(&mut self, mode: SessionMode) {
unsafe { mnn_sys::Interpreter_setSessionMode(self.inner, mode.to_mnn_sys()) }
}

///call this function to get tensors ready.
///
///output tensor buffer (host or deviceId) should be retrieved after resize of any input tensor.
///
///`session`: the session to be prepared
pub fn resize_session(&self, session: &mut crate::Session) {
unsafe { mnn_sys::Interpreter_resizeSession(self.inner, session.inner) }
}

/// Resize session and reallocate the buffer.
///
/// `session`: the session to be prepared.
///
/// # Note
/// NeedRelloc is default to 1, 1 means need realloc!
pub fn resize_session_reallocate(&self, session: &mut crate::Session) {
unsafe { mnn_sys::Interpreter_resizeSessionWithFlag(self.inner, session.inner, 1i32) }
}
Expand Down Expand Up @@ -206,6 +234,11 @@ impl Interpreter {
}
}

/// Create a session with session config. Session will be managed in net/interpreter.
///
/// `schedule` : the config of the session
///
/// return: the created session
pub fn create_session(
&mut self,
schedule: crate::ScheduleConfig,
Expand All @@ -221,6 +254,11 @@ impl Interpreter {
})
}

/// Create multi-path session with schedule configs and user-specified runtime. created session will be managed in net/interpreter.
///
/// `schedule` : the config of the session
///
/// return: the created session
pub fn create_multipath_session(
&mut self,
schedule: impl IntoIterator<Item = ScheduleConfig>,
Expand All @@ -238,6 +276,7 @@ impl Interpreter {
})
}

/// Print all input and output tensors info.
pub fn model_print_io(path: impl AsRef<Path>) -> Result<()> {
let path = path.as_ref();
crate::ensure!(path.exists(), ErrorKind::IOError);
Expand All @@ -247,11 +286,23 @@ impl Interpreter {
Ok(())
}

/// Get the input tensor of the session.
///
/// `session`: the session to get input tensor
///
/// return: List of input tensors
pub fn inputs(&self, session: &crate::Session) -> TensorList {
let inputs = unsafe { mnn_sys::Interpreter_getSessionInputAll(self.inner, session.inner) };
TensorList::from_ptr(inputs)
}

/// Get the input tensor of the session by name.
///
/// `session`: the session to get input tensor from
///
/// `name`: the name of the input tensor
///
/// return: the input tensor
pub fn input<'s, H: HalideType>(
&self,
session: &'s crate::Session,
Expand Down Expand Up @@ -291,7 +342,7 @@ impl Interpreter {
}

/// # Safety
/// We Still don't know the safety guarantees of this function so it's marked unsafe
/// **Warning** We Still don't know the safety guarantees of this function so it's marked unsafe
pub unsafe fn input_unresized<'s, H: HalideType>(
&self,
session: &'s crate::Session,
Expand All @@ -314,7 +365,7 @@ impl Interpreter {
}

/// # Safety
/// Very unsafe since it doesn't check the type of the tensor
/// Very **unsafe** since it doesn't check the type of the tensor
/// as well as the shape of the tensor
pub unsafe fn input_unchecked<'s, H: HalideType>(
&self,
Expand All @@ -329,6 +380,10 @@ impl Interpreter {
}

/// Get the output tensor of a session by name
///
/// `session` : the session to get output tensor from
///
/// `name` : the name of the output tensor
pub fn output<'s, H: HalideType>(
&self,
session: &'s crate::Session,
Expand Down Expand Up @@ -366,6 +421,7 @@ impl Interpreter {
Ok(RawTensor::from_ptr(output))
}

/// Run a session
pub fn run_session(&mut self, session: &crate::session::Session) -> Result<()> {
profile!("Running session"; {
let ret = unsafe { mnn_sys::Interpreter_runSession(self.inner, session.inner) };
Expand All @@ -377,6 +433,15 @@ impl Interpreter {
})
}

/// Run a session with a callback
///
/// `session` : the session to run
///
/// `before` : a callback before each op. return true to run the op; return false to skip the op.
///
/// `after` : a callback after each op. return true to continue running; return false to interrupt the session.
///
/// `sync` : synchronously wait for finish of execution or not.
pub fn run_session_with_callback(
&mut self,
session: &crate::session::Session,
Expand All @@ -403,12 +468,24 @@ impl Interpreter {
Ok(())
}

/// Get all output tensors of a session
pub fn outputs(&self, session: &crate::session::Session) -> TensorList {
let outputs =
unsafe { mnn_sys::Interpreter_getSessionOutputAll(self.inner, session.inner) };
TensorList::from_ptr(outputs)
}

/// If the cache exist, try to load cache from file.
/// After createSession, try to save cache to file.
///
/// `cache_file` : the file path to save or load cache.
///
/// `key_size` : the size of key
///
/// # Note
/// The API should be called before create session.
///
/// Key Depercerate, keeping for future use!
pub fn set_cache_file(&mut self, path: impl AsRef<Path>, key_size: usize) -> Result<()> {
let path = path.as_ref();
let path = dunce::simplified(path);
Expand All @@ -417,6 +494,8 @@ impl Interpreter {
unsafe { mnn_sys::Interpreter_setCacheFile(self.inner, c_path.as_ptr(), key_size) }
Ok(())
}

/// Update cache file
pub fn update_cache_file(&mut self, session: &mut crate::session::Session) -> Result<()> {
MNNError::from_error_code(unsafe {
mnn_sys::Interpreter_updateCacheFile(self.inner, session.inner)
Expand All @@ -433,6 +512,7 @@ impl Interpreter {
});
}

/// Get memory usage of a session in MB
pub fn memory(&self, session: &crate::session::Session) -> Result<f32> {
let mut memory = 0f32;
let memory_ptr = &mut memory as *mut f32;
Expand All @@ -447,6 +527,7 @@ impl Interpreter {
Ok(memory)
}

/// Get float operation needed in session in M
pub fn flops(&self, session: &crate::Session) -> Result<f32> {
let mut flop = 0.0f32;
let flop_ptr = &mut flop as *mut f32;
Expand Down Expand Up @@ -535,7 +616,7 @@ impl<'t, 'tl> TensorInfo<'t, 'tl> {
}

/// # Safety
/// The shape is not checked so it's marked unsafe since futher calls to interpreter might be unsafe with this
/// The shape is not checked so it's marked unsafe since futher calls to interpreter might be **unsafe** with this
pub unsafe fn tensor_unresized<H: HalideType>(&self) -> Result<Tensor<RefMut<'t, Device<H>>>> {
debug_assert!(!self.tensor_info.is_null());
unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) };
Expand Down
Loading
Loading