Skip to content

Commit 4b21fc5

Browse files
committed
Implement thread creation deletion event callback.
1 parent cd4091f commit 4b21fc5

File tree

6 files changed

+90
-2
lines changed

6 files changed

+90
-2
lines changed

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ pub use crate::traits::{
113113
FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, LuaNativeFn, LuaNativeFnMut, ObjectLike,
114114
};
115115
pub use crate::types::{
116-
AppDataRef, AppDataRefMut, Either, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState,
116+
AppDataRef, AppDataRefMut, Either, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState, ThreadEventInfo
117117
};
118118
pub use crate::userdata::{
119119
AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMetatable, UserDataMethods, UserDataRef,

src/prelude.rs

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ pub use crate::{
1414
UserDataMethods as LuaUserDataMethods, UserDataRef as LuaUserDataRef,
1515
UserDataRefMut as LuaUserDataRefMut, UserDataRegistry as LuaUserDataRegistry, Value as LuaValue,
1616
VmState as LuaVmState,
17+
ThreadEventInfo as LuaThreadEventInfo
1718
};
1819

1920
#[cfg(not(feature = "luau"))]

src/state.rs

+49-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::thread::Thread;
2121
use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti};
2222
use crate::types::{
2323
AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex,
24-
ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak,
24+
ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak, ThreadEventInfo
2525
};
2626
use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage};
2727
use crate::util::{
@@ -671,6 +671,54 @@ impl Lua {
671671
}
672672
}
673673

674+
/// Sets a callback that will be called by Luau whenever a thread is created/destroyed.
675+
///
676+
/// Often used for keeping track of threads.
677+
#[cfg(any(feature = "luau", doc))]
678+
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
679+
pub fn set_thread_event_callback<F>(&self, callback: F)
680+
where
681+
F: Fn(&Lua, ThreadEventInfo) -> Result<()> + MaybeSend + 'static,
682+
{
683+
use std::rc::Rc;
684+
685+
unsafe extern "C-unwind" fn userthread_proc(parent: *mut ffi::lua_State, state: *mut ffi::lua_State) {
686+
let raw_lua: &RawLua = (*ExtraData::get(state)).raw_lua();
687+
raw_lua.push_ref_thread(parent);
688+
callback_error_ext(state, ptr::null_mut(), move |extra, _| {
689+
let userthread_cb = (*extra).userthread_callback.clone();
690+
let userthread_cb = mlua_expect!(userthread_cb, "no userthread callback set in userthread_proc");
691+
let _guard = StateGuard::new((*extra).raw_lua(), state);
692+
let event_info = match (*extra).raw_lua().pop_value() {
693+
Value::Thread(thr) => ThreadEventInfo::Created(thr),
694+
Value::Nil => ThreadEventInfo::Destroying,
695+
_ => unreachable!()
696+
};
697+
userthread_cb((*extra).lua(), event_info)
698+
});
699+
}
700+
701+
// Set interrupt callback
702+
let lua = self.lock();
703+
unsafe {
704+
(*lua.extra.get()).userthread_callback = Some(Rc::new(callback));
705+
(*ffi::lua_callbacks(lua.main_state())).userthread = Some(userthread_proc);
706+
}
707+
}
708+
709+
/// Removes any thread event function previously set by `set_thread_event_callback`.
710+
///
711+
/// This function has no effect if a callback was not previously set.
712+
#[cfg(any(feature = "luau", doc))]
713+
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
714+
pub fn remove_thread_event_callback(&self) {
715+
let lua = self.lock();
716+
unsafe {
717+
(*lua.extra.get()).userthread_callback = None;
718+
(*ffi::lua_callbacks(lua.main_state())).userthread = None;
719+
}
720+
}
721+
674722
/// Sets the warning function to be used by Lua to emit warnings.
675723
///
676724
/// Requires `feature = "lua54"`

src/state/extra.rs

+4
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ pub(crate) struct ExtraData {
8080
pub(super) warn_callback: Option<crate::types::WarnCallback>,
8181
#[cfg(feature = "luau")]
8282
pub(super) interrupt_callback: Option<crate::types::InterruptCallback>,
83+
#[cfg(feature = "luau")]
84+
pub(super) userthread_callback: Option<crate::types::ThreadEventCallback>,
8385

8486
#[cfg(feature = "luau")]
8587
pub(super) sandboxed: bool,
@@ -177,6 +179,8 @@ impl ExtraData {
177179
#[cfg(feature = "luau")]
178180
interrupt_callback: None,
179181
#[cfg(feature = "luau")]
182+
userthread_callback: None,
183+
#[cfg(feature = "luau")]
180184
sandboxed: false,
181185
#[cfg(feature = "luau")]
182186
compiler: None,

src/state/raw.rs

+16
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,22 @@ impl RawLua {
556556
value.push_into_stack(self)
557557
}
558558

559+
pub(crate) unsafe fn push_ref_thread(&self, ref_thread: *mut ffi::lua_State) -> Result<()> {
560+
let state = self.state();
561+
let _sg = StackGuard::new(state);
562+
check_stack(state, 2)?;
563+
let _sg = StackGuard::new(ref_thread);
564+
check_stack(ref_thread, 1)?;
565+
566+
if self.unlikely_memory_error() {
567+
ffi::lua_pushthread(ref_thread)
568+
} else {
569+
protect_lua!(ref_thread, 0, 1, |ref_thread| ffi::lua_pushthread(ref_thread))?
570+
};
571+
ffi::lua_xmove(ref_thread, self.ref_thread(), 1);
572+
Ok(())
573+
}
574+
559575
/// Pushes a `Value` (by reference) onto the Lua stack.
560576
///
561577
/// Uses 2 stack spaces, does not call `checkstack`.

src/types.rs

+19
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::error::Result;
66
#[cfg(not(feature = "luau"))]
77
use crate::hook::Debug;
88
use crate::state::{ExtraData, Lua, RawLua};
9+
use crate::thread::Thread;
910

1011
// Re-export mutex wrappers
1112
pub(crate) use sync::{ArcReentrantMutexGuard, ReentrantMutex, ReentrantMutexGuard, XRc, XWeak};
@@ -73,6 +74,17 @@ pub enum VmState {
7374
Yield,
7475
}
7576

77+
/// Information about a thread event.
78+
///
79+
/// For creating a thread, it contains the thread that created it.
80+
///
81+
/// This is useful when it is required for tracking all threads and where they come from.
82+
#[cfg(feature = "luau")]
83+
pub enum ThreadEventInfo {
84+
Created(Thread),
85+
Destroying
86+
}
87+
7688
#[cfg(all(feature = "send", not(feature = "luau")))]
7789
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;
7890

@@ -85,6 +97,13 @@ pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState> + Send>;
8597
#[cfg(all(not(feature = "send"), feature = "luau"))]
8698
pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState>>;
8799

100+
#[cfg(all(feature = "send", feature = "luau"))]
101+
pub(crate) type ThreadEventCallback = Rc<dyn Fn(&Lua, ThreadEventInfo) -> Result<()> + Send>;
102+
103+
#[cfg(all(not(feature = "send"), feature = "luau"))]
104+
pub(crate) type ThreadEventCallback = Rc<dyn Fn(&Lua, ThreadEventInfo) -> Result<()>>;
105+
106+
88107
#[cfg(all(feature = "send", feature = "lua54"))]
89108
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &str, bool) -> Result<()> + Send>;
90109

0 commit comments

Comments
 (0)