Skip to content

Commit

Permalink
feat(api): Added input_unresized method for getting unresized/dynamic…
Browse files Browse the repository at this point in the history
… tensors
  • Loading branch information
uttarayan21 committed Oct 7, 2024
1 parent 4de5fbd commit dbf27ce
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ target
*.mnn
*.ppm
lama
*.json
*.json
*.cache
13 changes: 12 additions & 1 deletion examples/inspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,18 @@ pub fn main() -> anyhow::Result<()> {
config.set_type(cli.forward);
let mut session = time!(interpreter.create_session(config)?; "create session");
interpreter.update_cache_file(&mut session)?;

let inputs = interpreter.inputs(&session);
let mut first = inputs
.iter()
.next()
.expect("No input")
.tensor::<f32>()
.unwrap();
let shape = first.shape();
interpreter.resize_tensor(&mut first, shape);
interpreter.resize_session(&mut session);
drop(first);
drop(inputs);
let mut current = 0;
time!(loop {
interpreter.inputs(&session).iter().for_each(|x| {
Expand Down
58 changes: 58 additions & 0 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,45 @@ impl Interpreter {
Ok(RawTensor::from_ptr(input))
}

/// * Safety
/// We Still don't know the safety guarantees of this function so it's marked unsafe
pub unsafe fn input_unresized<'s, H: HalideType>(
&self,
session: &'s crate::Session,
name: impl AsRef<str>,
) -> Result<Tensor<RefMut<'s, Device<H>>>> {
let name = name.as_ref();
let c_name = std::ffi::CString::new(name).change_context(ErrorKind::AsciiError)?;
let input = unsafe {
mnn_sys::Interpreter_getSessionInput(self.inner, session.inner, c_name.as_ptr())
};
ensure!(!input.is_null(), ErrorKind::TensorError; format!("Input tensor \"{name}\" not found"));
let tensor = unsafe { Tensor::from_ptr(input) };
ensure!(
tensor.is_type_of::<H>(),
ErrorKind::HalideTypeMismatch {
got: std::any::type_name::<H>(),
}
);
Ok(tensor)
}

/// * Safety
/// Very unsafe since it doesn't check the type of the tensor
/// as well as the shape of the tensor
pub unsafe fn input_unchecked<'s, H: HalideType>(
&self,
session: &'s crate::Session,
name: impl AsRef<str>,
) -> Tensor<RefMut<'s, Device<H>>> {
let name = name.as_ref();
let c_name = std::ffi::CString::new(name).unwrap();
let input =
mnn_sys::Interpreter_getSessionInput(self.inner, session.inner, c_name.as_ptr());
Tensor::from_ptr(input)
}

/// Get the output tensor of a session by name
pub fn output<'s, H: HalideType>(
&self,
session: &'s crate::Session,
Expand Down Expand Up @@ -429,6 +468,25 @@ impl<'t, 'tl> TensorInfo<'t, 'tl> {
Ok(tensor)
}

/// * Safety
/// The shape is not checked so it's marked unsafe since futher calls to interpreter might be
/// unsafe with this
pub unsafe fn tensor_unresized<H: HalideType>(&self) -> Result<Tensor<RefMut<'t, Device<H>>>>
where
H: HalideType,
{
debug_assert!(!self.tensor_info.is_null());
unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) };
let tensor = unsafe { Tensor::from_ptr((*self.tensor_info).tensor.cast()) };
ensure!(
tensor.is_type_of::<H>(),
ErrorKind::HalideTypeMismatch {
got: std::any::type_name::<H>(),
}
);
Ok(tensor)
}

pub fn raw_tensor(&self) -> RawTensor<'t> {
debug_assert!(!self.tensor_info.is_null());
unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) };
Expand Down
5 changes: 0 additions & 5 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,3 @@ impl Drop for Session {
unsafe { mnn_sys::Session_destroy(self.inner) }
}
}

// pub struct SessionInterpreter {
// session: Session<'static>,
// interpreter: Interpreter,
// }
2 changes: 2 additions & 0 deletions src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ impl<T: HostTensorType> Tensor<T>
where
T::H: HalideType,
{
/// Try to map the device tensor to the host memory and get the slice
pub fn try_host(&self) -> Result<&[T::H]> {
let size = self.element_size();
ensure!(
Expand All @@ -331,6 +332,7 @@ where
Ok(result)
}

/// Try to map the device tensor to the host memory and get the mutable slice
pub fn try_host_mut(&mut self) -> Result<&mut [T::H]> {
let size = self.element_size();
ensure!(
Expand Down

0 comments on commit dbf27ce

Please sign in to comment.