From aaa79171bf17a8b706f5150bd730c0a4ec7b5837 Mon Sep 17 00:00:00 2001 From: Radiant <69520693+RadiantUwU@users.noreply.github.com> Date: Sat, 25 Jan 2025 23:18:20 +0200 Subject: [PATCH] Add yielding support. --- src/error.rs | 10 ++++++++++ src/state.rs | 26 ++++++++++++++++++++++++++ src/state/extra.rs | 5 +++++ src/state/raw.rs | 6 ++++++ src/state/util.rs | 17 ++++++++++++++++- tests/thread.rs | 12 ++++++++++++ 6 files changed, 75 insertions(+), 1 deletion(-) diff --git a/src/error.rs b/src/error.rs index 1f243967..7cae9b0a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -205,6 +205,13 @@ pub enum Error { /// Underlying error. cause: Arc, }, + // Yield. + // + // Not an error. + // Returning `Err(Yielding)` from a Rust callback will yield with no values. + // See `Lua::yield(args: impl FromLuaMulti) -> Result` for more info. + // If it cannot yield, it will raise an error. + Yielding, } /// A specialized `Result` type used by `mlua`'s API. @@ -321,6 +328,9 @@ impl fmt::Display for Error { Error::WithContext { context, cause } => { writeln!(fmt, "{context}")?; write!(fmt, "{cause}") + }, + Error::Yielding => { + write!(fmt, "yield across Rust/Lua boundary") } } } diff --git a/src/state.rs b/src/state.rs index 35d9e4ec..15607142 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,5 +1,6 @@ use std::any::TypeId; use std::cell::{BorrowError, BorrowMutError, RefCell}; +use std::convert::Infallible; use std::marker::PhantomData; use std::ops::Deref; use std::os::raw::c_int; @@ -1966,6 +1967,31 @@ impl Lua { pub(crate) unsafe fn raw_lua(&self) -> &RawLua { &*self.raw.data_ptr() } + /// Yields arguments + /// + /// If this function cannot yield, it will raise a runtime error. It yields by using + /// the Error::Yielding error to rapidly exit all of the try blocks. + /// + /// Note: On lua 5.1, 5.2, and JIT, this function will unable to know if it can yield + /// or not until it reaches the Lua state. + pub fn yield_args(&self, args: impl IntoLuaMulti) -> Result<()> { + let raw = self.lock(); + #[cfg(not(any(feature = "lua51", feature = "lua52", feature = "luajit")))] + if !raw.is_yieldable() { + return Err(Error::runtime("cannot yield across Rust/Lua boundary.")) + } + unsafe { + raw.extra.get().as_mut().unwrap_unchecked().yielded_values = args.into_lua_multi(self)?; + } + Err(Error::Yielding) + } + + /// Checks if Lua is currently allowed to yield. + #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] + #[inline] + pub(crate) fn is_yieldable(&self) -> bool { + self.lock().is_yieldable() + } } impl WeakLua { diff --git a/src/state/extra.rs b/src/state/extra.rs index d1823b5c..c5243710 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -18,6 +18,7 @@ use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, Wrapp #[cfg(any(feature = "luau", doc))] use crate::chunk::Compiler; +use crate::MultiValue; #[cfg(feature = "async")] use {futures_util::task::noop_waker_ref, std::ptr::NonNull, std::task::Waker}; @@ -87,6 +88,9 @@ pub(crate) struct ExtraData { pub(super) compiler: Option, #[cfg(feature = "luau-jit")] pub(super) enable_jit: bool, + + // Values currently being yielded from Lua.yield() + pub(super) yielded_values: MultiValue, } impl Drop for ExtraData { @@ -182,6 +186,7 @@ impl ExtraData { compiler: None, #[cfg(feature = "luau-jit")] enable_jit: true, + yielded_values: MultiValue::default(), })); // Store it in the registry diff --git a/src/state/raw.rs b/src/state/raw.rs index 0731f846..b114906d 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -342,6 +342,12 @@ impl RawLua { } } + #[cfg(not(any(feature = "lua51", feature="lua52", feature = "luajit")))] + #[inline] + pub(crate) fn is_yieldable(&self) -> bool { + unsafe { ffi::lua_isyieldable(self.state()) != 0 } + } + pub(crate) unsafe fn load_chunk_inner( &self, state: *mut ffi::lua_State, diff --git a/src/state/util.rs b/src/state/util.rs index ec701eaf..8a09138d 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -1,3 +1,4 @@ +use std::mem::take; use std::os::raw::c_int; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr; @@ -6,6 +7,7 @@ use std::sync::Arc; use crate::error::{Error, Result}; use crate::state::{ExtraData, RawLua}; use crate::util::{self, get_internal_metatable, WrappedFailure}; +use crate::IntoLuaMulti; pub(super) struct StateGuard<'a>(&'a RawLua, *mut ffi::lua_State); @@ -107,7 +109,20 @@ where prealloc_failure.release(state, extra); r } - Ok(Err(err)) => { + Ok(Err(mut err)) => { + if let Error::Yielding = err { + let raw = extra.as_ref().unwrap_unchecked().raw_lua(); + let values = take(&mut extra.as_mut().unwrap_unchecked().yielded_values); + match values.push_into_stack_multi(raw) { + Ok(nargs) => { + ffi::lua_yield(state, nargs); + unreachable!() + }, + Err(new_err) => { + err = new_err; + } + } + } let wrapped_error = prealloc_failure.r#use(state, extra); // Build `CallbackError` with traceback diff --git a/tests/thread.rs b/tests/thread.rs index 7ece2b56..fc548c4c 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -219,6 +219,18 @@ fn test_coroutine_panic() { } } +#[test] +fn test_yieldability() { + let lua = Lua::new(); + + let always_yield = lua.create_function(|lua, ()| { + lua.yield_args((42, "69420")) + }).unwrap(); + + let thread = lua.create_thread(always_yield).unwrap(); + assert_eq!(thread.resume::<(i32, String)>(()).unwrap(), (42, String::from("69420"))); +} + #[test] fn test_thread_pointer() -> Result<()> { let lua = Lua::new();