Skip to content

Commit 10b9e37

Browse files
committed
Change AsyncThread<A, R> to AsyncThread<R>.
Push arguments in `Thread::into_async()` to the thread during the call instead of first poll. The pushed arguments will be automatically used on resume. Fixes #508 and relates to #500.
1 parent cd4091f commit 10b9e37

File tree

3 files changed

+156
-83
lines changed

3 files changed

+156
-83
lines changed

src/function.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,10 @@ impl Function {
161161
{
162162
let lua = self.0.lua.lock();
163163
let thread_res = unsafe {
164-
lua.create_recycled_thread(self).map(|th| {
165-
let mut th = th.into_async(args);
164+
lua.create_recycled_thread(self).and_then(|th| {
165+
let mut th = th.into_async(args)?;
166166
th.set_recyclable(true);
167-
th
167+
Ok(th)
168168
})
169169
};
170170
async move { thread_res?.await }

src/thread.rs

+120-77
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,33 @@ pub enum ThreadStatus {
4242
Error,
4343
}
4444

45+
/// Internal representation of a Lua thread status.
46+
///
47+
/// The number in `New` and `Yielded` variants is the number of arguments pushed
48+
/// to the thread stack.
49+
#[derive(Clone, Copy)]
50+
enum ThreadStatusInner {
51+
New(c_int),
52+
Running,
53+
Yielded(c_int),
54+
Finished,
55+
Error,
56+
}
57+
58+
impl ThreadStatusInner {
59+
#[cfg(feature = "async")]
60+
#[inline(always)]
61+
fn is_resumable(self) -> bool {
62+
matches!(self, ThreadStatusInner::New(_) | ThreadStatusInner::Yielded(_))
63+
}
64+
65+
#[cfg(feature = "async")]
66+
#[inline(always)]
67+
fn is_yielded(self) -> bool {
68+
matches!(self, ThreadStatusInner::Yielded(_))
69+
}
70+
}
71+
4572
/// Handle to an internal Lua thread (coroutine).
4673
#[derive(Clone)]
4774
pub struct Thread(pub(crate) ValueRef, pub(crate) *mut ffi::lua_State);
@@ -60,9 +87,8 @@ unsafe impl Sync for Thread {}
6087
#[cfg(feature = "async")]
6188
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
6289
#[must_use = "futures do nothing unless you `.await` or poll them"]
63-
pub struct AsyncThread<A, R> {
90+
pub struct AsyncThread<R> {
6491
thread: Thread,
65-
init_args: Option<A>,
6692
ret: PhantomData<R>,
6793
recycle: bool,
6894
}
@@ -122,17 +148,25 @@ impl Thread {
122148
R: FromLuaMulti,
123149
{
124150
let lua = self.0.lua.lock();
125-
if self.status_inner(&lua) != ThreadStatus::Resumable {
126-
return Err(Error::CoroutineUnresumable);
127-
}
151+
let mut pushed_nargs = match self.status_inner(&lua) {
152+
ThreadStatusInner::New(nargs) | ThreadStatusInner::Yielded(nargs) => nargs,
153+
_ => return Err(Error::CoroutineUnresumable),
154+
};
128155

129156
let state = lua.state();
130157
let thread_state = self.state();
131158
unsafe {
132159
let _sg = StackGuard::new(state);
133160
let _thread_sg = StackGuard::with_top(thread_state, 0);
134161

135-
let nresults = self.resume_inner(&lua, args)?;
162+
let nargs = args.push_into_stack_multi(&lua)?;
163+
if nargs > 0 {
164+
check_stack(thread_state, nargs)?;
165+
ffi::lua_xmove(state, thread_state, nargs);
166+
pushed_nargs += nargs;
167+
}
168+
169+
let (_, nresults) = self.resume_inner(&lua, pushed_nargs)?;
136170
check_stack(state, nresults + 1)?;
137171
ffi::lua_xmove(thread_state, state, nresults);
138172

@@ -143,50 +177,50 @@ impl Thread {
143177
/// Resumes execution of this thread.
144178
///
145179
/// It's similar to `resume()` but leaves `nresults` values on the thread stack.
146-
unsafe fn resume_inner(&self, lua: &RawLua, args: impl IntoLuaMulti) -> Result<c_int> {
180+
unsafe fn resume_inner(&self, lua: &RawLua, nargs: c_int) -> Result<(ThreadStatusInner, c_int)> {
147181
let state = lua.state();
148182
let thread_state = self.state();
149-
150-
let nargs = args.push_into_stack_multi(lua)?;
151-
if nargs > 0 {
152-
check_stack(thread_state, nargs)?;
153-
ffi::lua_xmove(state, thread_state, nargs);
154-
}
155-
156183
let mut nresults = 0;
157184
let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int);
158-
if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD {
159-
if ret == ffi::LUA_ERRMEM {
185+
match ret {
186+
ffi::LUA_OK => Ok((ThreadStatusInner::Finished, nresults)),
187+
ffi::LUA_YIELD => Ok((ThreadStatusInner::Yielded(0), nresults)),
188+
ffi::LUA_ERRMEM => {
160189
// Don't call error handler for memory errors
161-
return Err(pop_error(thread_state, ret));
190+
Err(pop_error(thread_state, ret))
191+
}
192+
_ => {
193+
check_stack(state, 3)?;
194+
protect_lua!(state, 0, 1, |state| error_traceback_thread(state, thread_state))?;
195+
Err(pop_error(state, ret))
162196
}
163-
check_stack(state, 3)?;
164-
protect_lua!(state, 0, 1, |state| error_traceback_thread(state, thread_state))?;
165-
return Err(pop_error(state, ret));
166197
}
167-
168-
Ok(nresults)
169198
}
170199

171200
/// Gets the status of the thread.
172201
pub fn status(&self) -> ThreadStatus {
173-
self.status_inner(&self.0.lua.lock())
202+
match self.status_inner(&self.0.lua.lock()) {
203+
ThreadStatusInner::New(_) | ThreadStatusInner::Yielded(_) => ThreadStatus::Resumable,
204+
ThreadStatusInner::Running => ThreadStatus::Running,
205+
ThreadStatusInner::Finished => ThreadStatus::Finished,
206+
ThreadStatusInner::Error => ThreadStatus::Error,
207+
}
174208
}
175209

176210
/// Gets the status of the thread (internal implementation).
177-
pub(crate) fn status_inner(&self, lua: &RawLua) -> ThreadStatus {
211+
fn status_inner(&self, lua: &RawLua) -> ThreadStatusInner {
178212
let thread_state = self.state();
179213
if thread_state == lua.state() {
180214
// The thread is currently running
181-
return ThreadStatus::Running;
215+
return ThreadStatusInner::Running;
182216
}
183217
let status = unsafe { ffi::lua_status(thread_state) };
184-
if status != ffi::LUA_OK && status != ffi::LUA_YIELD {
185-
ThreadStatus::Error
186-
} else if status == ffi::LUA_YIELD || unsafe { ffi::lua_gettop(thread_state) > 0 } {
187-
ThreadStatus::Resumable
188-
} else {
189-
ThreadStatus::Finished
218+
let top = unsafe { ffi::lua_gettop(thread_state) };
219+
match status {
220+
ffi::LUA_YIELD => ThreadStatusInner::Yielded(top),
221+
ffi::LUA_OK if top > 0 => ThreadStatusInner::New(top - 1),
222+
ffi::LUA_OK => ThreadStatusInner::Finished,
223+
_ => ThreadStatusInner::Error,
190224
}
191225
}
192226

@@ -224,7 +258,7 @@ impl Thread {
224258
#[cfg_attr(docsrs, doc(cfg(any(feature = "lua54", feature = "luau"))))]
225259
pub fn reset(&self, func: crate::function::Function) -> Result<()> {
226260
let lua = self.0.lua.lock();
227-
if self.status_inner(&lua) == ThreadStatus::Running {
261+
if matches!(self.status_inner(&lua), ThreadStatusInner::Running) {
228262
return Err(Error::runtime("cannot reset a running thread"));
229263
}
230264

@@ -257,7 +291,9 @@ impl Thread {
257291

258292
/// Converts [`Thread`] to an [`AsyncThread`] which implements [`Future`] and [`Stream`] traits.
259293
///
260-
/// `args` are passed as arguments to the thread function for first call.
294+
/// Only resumable threads can be converted to [`AsyncThread`].
295+
///
296+
/// `args` are pushed to the thread stack and will be used when the thread is resumed.
261297
/// The object calls [`resume`] while polling and also allow to run Rust futures
262298
/// to completion using an executor.
263299
///
@@ -290,7 +326,7 @@ impl Thread {
290326
/// end)
291327
/// "#).eval()?;
292328
///
293-
/// let mut stream = thread.into_async::<i64>(1);
329+
/// let mut stream = thread.into_async::<i64>(1)?;
294330
/// let mut sum = 0;
295331
/// while let Some(n) = stream.try_next().await? {
296332
/// sum += n;
@@ -303,15 +339,31 @@ impl Thread {
303339
/// ```
304340
#[cfg(feature = "async")]
305341
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
306-
pub fn into_async<R>(self, args: impl IntoLuaMulti) -> AsyncThread<impl IntoLuaMulti, R>
342+
pub fn into_async<R>(self, args: impl IntoLuaMulti) -> Result<AsyncThread<R>>
307343
where
308344
R: FromLuaMulti,
309345
{
310-
AsyncThread {
311-
thread: self,
312-
init_args: Some(args),
313-
ret: PhantomData,
314-
recycle: false,
346+
let lua = self.0.lua.lock();
347+
if !self.status_inner(&lua).is_resumable() {
348+
return Err(Error::CoroutineUnresumable);
349+
}
350+
351+
let state = lua.state();
352+
let thread_state = self.state();
353+
unsafe {
354+
let _sg = StackGuard::new(state);
355+
356+
let nargs = args.push_into_stack_multi(&lua)?;
357+
if nargs > 0 {
358+
check_stack(thread_state, nargs)?;
359+
ffi::lua_xmove(state, thread_state, nargs);
360+
}
361+
362+
Ok(AsyncThread {
363+
thread: self,
364+
ret: PhantomData,
365+
recycle: false,
366+
})
315367
}
316368
}
317369

@@ -392,7 +444,7 @@ impl LuaType for Thread {
392444
}
393445

394446
#[cfg(feature = "async")]
395-
impl<A, R> AsyncThread<A, R> {
447+
impl<R> AsyncThread<R> {
396448
#[inline]
397449
pub(crate) fn set_recyclable(&mut self, recyclable: bool) {
398450
self.recycle = recyclable;
@@ -401,15 +453,15 @@ impl<A, R> AsyncThread<A, R> {
401453

402454
#[cfg(feature = "async")]
403455
#[cfg(any(feature = "lua54", feature = "luau"))]
404-
impl<A, R> Drop for AsyncThread<A, R> {
456+
impl<R> Drop for AsyncThread<R> {
405457
fn drop(&mut self) {
406458
if self.recycle {
407459
if let Some(lua) = self.thread.0.lua.try_lock() {
408460
unsafe {
409461
// For Lua 5.4 this also closes all pending to-be-closed variables
410462
if !lua.recycle_thread(&mut self.thread) {
411463
#[cfg(feature = "lua54")]
412-
if self.thread.status_inner(&lua) == ThreadStatus::Error {
464+
if matches!(self.thread.status_inner(&lua), ThreadStatusInner::Error) {
413465
#[cfg(not(feature = "vendored"))]
414466
ffi::lua_resetthread(self.thread.state());
415467
#[cfg(feature = "vendored")]
@@ -423,14 +475,15 @@ impl<A, R> Drop for AsyncThread<A, R> {
423475
}
424476

425477
#[cfg(feature = "async")]
426-
impl<A: IntoLuaMulti, R: FromLuaMulti> Stream for AsyncThread<A, R> {
478+
impl<R: FromLuaMulti> Stream for AsyncThread<R> {
427479
type Item = Result<R>;
428480

429481
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
430482
let lua = self.thread.0.lua.lock();
431-
if self.thread.status_inner(&lua) != ThreadStatus::Resumable {
432-
return Poll::Ready(None);
433-
}
483+
let nargs = match self.thread.status_inner(&lua) {
484+
ThreadStatusInner::New(nargs) | ThreadStatusInner::Yielded(nargs) => nargs,
485+
_ => return Poll::Ready(None),
486+
};
434487

435488
let state = lua.state();
436489
let thread_state = self.thread.state();
@@ -439,36 +492,34 @@ impl<A: IntoLuaMulti, R: FromLuaMulti> Stream for AsyncThread<A, R> {
439492
let _thread_sg = StackGuard::with_top(thread_state, 0);
440493
let _wg = WakerGuard::new(&lua, cx.waker());
441494

442-
// This is safe as we are not moving the whole struct
443-
let this = self.get_unchecked_mut();
444-
let nresults = if let Some(args) = this.init_args.take() {
445-
this.thread.resume_inner(&lua, args)?
446-
} else {
447-
this.thread.resume_inner(&lua, ())?
448-
};
495+
let (status, nresults) = (self.thread).resume_inner(&lua, nargs)?;
449496

450-
if nresults == 1 && is_poll_pending(thread_state) {
451-
return Poll::Pending;
497+
if status.is_yielded() {
498+
if nresults == 1 && is_poll_pending(thread_state) {
499+
return Poll::Pending;
500+
}
501+
// Continue polling
502+
cx.waker().wake_by_ref();
452503
}
453504

454505
check_stack(state, nresults + 1)?;
455506
ffi::lua_xmove(thread_state, state, nresults);
456507

457-
cx.waker().wake_by_ref();
458508
Poll::Ready(Some(R::from_stack_multi(nresults, &lua)))
459509
}
460510
}
461511
}
462512

463513
#[cfg(feature = "async")]
464-
impl<A: IntoLuaMulti, R: FromLuaMulti> Future for AsyncThread<A, R> {
514+
impl<R: FromLuaMulti> Future for AsyncThread<R> {
465515
type Output = Result<R>;
466516

467517
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
468518
let lua = self.thread.0.lua.lock();
469-
if self.thread.status_inner(&lua) != ThreadStatus::Resumable {
470-
return Poll::Ready(Err(Error::CoroutineUnresumable));
471-
}
519+
let nargs = match self.thread.status_inner(&lua) {
520+
ThreadStatusInner::New(nargs) | ThreadStatusInner::Yielded(nargs) => nargs,
521+
_ => return Poll::Ready(Err(Error::CoroutineUnresumable)),
522+
};
472523

473524
let state = lua.state();
474525
let thread_state = self.thread.state();
@@ -477,21 +528,13 @@ impl<A: IntoLuaMulti, R: FromLuaMulti> Future for AsyncThread<A, R> {
477528
let _thread_sg = StackGuard::with_top(thread_state, 0);
478529
let _wg = WakerGuard::new(&lua, cx.waker());
479530

480-
// This is safe as we are not moving the whole struct
481-
let this = self.get_unchecked_mut();
482-
let nresults = if let Some(args) = this.init_args.take() {
483-
this.thread.resume_inner(&lua, args)?
484-
} else {
485-
this.thread.resume_inner(&lua, ())?
486-
};
487-
488-
if nresults == 1 && is_poll_pending(thread_state) {
489-
return Poll::Pending;
490-
}
531+
let (status, nresults) = self.thread.resume_inner(&lua, nargs)?;
491532

492-
if ffi::lua_status(thread_state) == ffi::LUA_YIELD {
493-
// Ignore value returned via yield()
494-
cx.waker().wake_by_ref();
533+
if status.is_yielded() {
534+
if !(nresults == 1 && is_poll_pending(thread_state)) {
535+
// Ignore value returned via yield()
536+
cx.waker().wake_by_ref();
537+
}
495538
return Poll::Pending;
496539
}
497540

@@ -545,7 +588,7 @@ mod assertions {
545588
#[cfg(feature = "send")]
546589
static_assertions::assert_impl_all!(Thread: Send, Sync);
547590
#[cfg(all(feature = "async", not(feature = "send")))]
548-
static_assertions::assert_not_impl_any!(AsyncThread<(), ()>: Send);
591+
static_assertions::assert_not_impl_any!(AsyncThread<()>: Send);
549592
#[cfg(all(feature = "async", feature = "send"))]
550-
static_assertions::assert_impl_all!(AsyncThread<(), ()>: Send, Sync);
593+
static_assertions::assert_impl_all!(AsyncThread<()>: Send, Sync);
551594
}

0 commit comments

Comments
 (0)