@@ -42,6 +42,33 @@ pub enum ThreadStatus {
42
42
Error ,
43
43
}
44
44
45
+ /// Internal representation of a Lua thread status.
46
+ ///
47
+ /// The number in `New` and `Yielded` variants is the number of arguments pushed
48
+ /// to the thread stack.
49
+ #[ derive( Clone , Copy ) ]
50
+ enum ThreadStatusInner {
51
+ New ( c_int ) ,
52
+ Running ,
53
+ Yielded ( c_int ) ,
54
+ Finished ,
55
+ Error ,
56
+ }
57
+
58
+ impl ThreadStatusInner {
59
+ #[ cfg( feature = "async" ) ]
60
+ #[ inline( always) ]
61
+ fn is_resumable ( self ) -> bool {
62
+ matches ! ( self , ThreadStatusInner :: New ( _) | ThreadStatusInner :: Yielded ( _) )
63
+ }
64
+
65
+ #[ cfg( feature = "async" ) ]
66
+ #[ inline( always) ]
67
+ fn is_yielded ( self ) -> bool {
68
+ matches ! ( self , ThreadStatusInner :: Yielded ( _) )
69
+ }
70
+ }
71
+
45
72
/// Handle to an internal Lua thread (coroutine).
46
73
#[ derive( Clone ) ]
47
74
pub struct Thread ( pub ( crate ) ValueRef , pub ( crate ) * mut ffi:: lua_State ) ;
@@ -60,9 +87,8 @@ unsafe impl Sync for Thread {}
60
87
#[ cfg( feature = "async" ) ]
61
88
#[ cfg_attr( docsrs, doc( cfg( feature = "async" ) ) ) ]
62
89
#[ must_use = "futures do nothing unless you `.await` or poll them" ]
63
- pub struct AsyncThread < A , R > {
90
+ pub struct AsyncThread < R > {
64
91
thread : Thread ,
65
- init_args : Option < A > ,
66
92
ret : PhantomData < R > ,
67
93
recycle : bool ,
68
94
}
@@ -122,17 +148,25 @@ impl Thread {
122
148
R : FromLuaMulti ,
123
149
{
124
150
let lua = self . 0 . lua . lock ( ) ;
125
- if self . status_inner ( & lua) != ThreadStatus :: Resumable {
126
- return Err ( Error :: CoroutineUnresumable ) ;
127
- }
151
+ let mut pushed_nargs = match self . status_inner ( & lua) {
152
+ ThreadStatusInner :: New ( nargs) | ThreadStatusInner :: Yielded ( nargs) => nargs,
153
+ _ => return Err ( Error :: CoroutineUnresumable ) ,
154
+ } ;
128
155
129
156
let state = lua. state ( ) ;
130
157
let thread_state = self . state ( ) ;
131
158
unsafe {
132
159
let _sg = StackGuard :: new ( state) ;
133
160
let _thread_sg = StackGuard :: with_top ( thread_state, 0 ) ;
134
161
135
- let nresults = self . resume_inner ( & lua, args) ?;
162
+ let nargs = args. push_into_stack_multi ( & lua) ?;
163
+ if nargs > 0 {
164
+ check_stack ( thread_state, nargs) ?;
165
+ ffi:: lua_xmove ( state, thread_state, nargs) ;
166
+ pushed_nargs += nargs;
167
+ }
168
+
169
+ let ( _, nresults) = self . resume_inner ( & lua, pushed_nargs) ?;
136
170
check_stack ( state, nresults + 1 ) ?;
137
171
ffi:: lua_xmove ( thread_state, state, nresults) ;
138
172
@@ -143,50 +177,50 @@ impl Thread {
143
177
/// Resumes execution of this thread.
144
178
///
145
179
/// It's similar to `resume()` but leaves `nresults` values on the thread stack.
146
- unsafe fn resume_inner ( & self , lua : & RawLua , args : impl IntoLuaMulti ) -> Result < c_int > {
180
+ unsafe fn resume_inner ( & self , lua : & RawLua , nargs : c_int ) -> Result < ( ThreadStatusInner , c_int ) > {
147
181
let state = lua. state ( ) ;
148
182
let thread_state = self . state ( ) ;
149
-
150
- let nargs = args. push_into_stack_multi ( lua) ?;
151
- if nargs > 0 {
152
- check_stack ( thread_state, nargs) ?;
153
- ffi:: lua_xmove ( state, thread_state, nargs) ;
154
- }
155
-
156
183
let mut nresults = 0 ;
157
184
let ret = ffi:: lua_resume ( thread_state, state, nargs, & mut nresults as * mut c_int ) ;
158
- if ret != ffi:: LUA_OK && ret != ffi:: LUA_YIELD {
159
- if ret == ffi:: LUA_ERRMEM {
185
+ match ret {
186
+ ffi:: LUA_OK => Ok ( ( ThreadStatusInner :: Finished , nresults) ) ,
187
+ ffi:: LUA_YIELD => Ok ( ( ThreadStatusInner :: Yielded ( 0 ) , nresults) ) ,
188
+ ffi:: LUA_ERRMEM => {
160
189
// Don't call error handler for memory errors
161
- return Err ( pop_error ( thread_state, ret) ) ;
190
+ Err ( pop_error ( thread_state, ret) )
191
+ }
192
+ _ => {
193
+ check_stack ( state, 3 ) ?;
194
+ protect_lua ! ( state, 0 , 1 , |state| error_traceback_thread( state, thread_state) ) ?;
195
+ Err ( pop_error ( state, ret) )
162
196
}
163
- check_stack ( state, 3 ) ?;
164
- protect_lua ! ( state, 0 , 1 , |state| error_traceback_thread( state, thread_state) ) ?;
165
- return Err ( pop_error ( state, ret) ) ;
166
197
}
167
-
168
- Ok ( nresults)
169
198
}
170
199
171
200
/// Gets the status of the thread.
172
201
pub fn status ( & self ) -> ThreadStatus {
173
- self . status_inner ( & self . 0 . lua . lock ( ) )
202
+ match self . status_inner ( & self . 0 . lua . lock ( ) ) {
203
+ ThreadStatusInner :: New ( _) | ThreadStatusInner :: Yielded ( _) => ThreadStatus :: Resumable ,
204
+ ThreadStatusInner :: Running => ThreadStatus :: Running ,
205
+ ThreadStatusInner :: Finished => ThreadStatus :: Finished ,
206
+ ThreadStatusInner :: Error => ThreadStatus :: Error ,
207
+ }
174
208
}
175
209
176
210
/// Gets the status of the thread (internal implementation).
177
- pub ( crate ) fn status_inner ( & self , lua : & RawLua ) -> ThreadStatus {
211
+ fn status_inner ( & self , lua : & RawLua ) -> ThreadStatusInner {
178
212
let thread_state = self . state ( ) ;
179
213
if thread_state == lua. state ( ) {
180
214
// The thread is currently running
181
- return ThreadStatus :: Running ;
215
+ return ThreadStatusInner :: Running ;
182
216
}
183
217
let status = unsafe { ffi:: lua_status ( thread_state) } ;
184
- if status != ffi :: LUA_OK && status != ffi:: LUA_YIELD {
185
- ThreadStatus :: Error
186
- } else if status == ffi:: LUA_YIELD || unsafe { ffi :: lua_gettop ( thread_state ) > 0 } {
187
- ThreadStatus :: Resumable
188
- } else {
189
- ThreadStatus :: Finished
218
+ let top = unsafe { ffi:: lua_gettop ( thread_state ) } ;
219
+ match status {
220
+ ffi:: LUA_YIELD => ThreadStatusInner :: Yielded ( top ) ,
221
+ ffi :: LUA_OK if top > 0 => ThreadStatusInner :: New ( top - 1 ) ,
222
+ ffi :: LUA_OK => ThreadStatusInner :: Finished ,
223
+ _ => ThreadStatusInner :: Error ,
190
224
}
191
225
}
192
226
@@ -224,7 +258,7 @@ impl Thread {
224
258
#[ cfg_attr( docsrs, doc( cfg( any( feature = "lua54" , feature = "luau" ) ) ) ) ]
225
259
pub fn reset ( & self , func : crate :: function:: Function ) -> Result < ( ) > {
226
260
let lua = self . 0 . lua . lock ( ) ;
227
- if self . status_inner ( & lua) == ThreadStatus :: Running {
261
+ if matches ! ( self . status_inner( & lua) , ThreadStatusInner :: Running ) {
228
262
return Err ( Error :: runtime ( "cannot reset a running thread" ) ) ;
229
263
}
230
264
@@ -257,7 +291,9 @@ impl Thread {
257
291
258
292
/// Converts [`Thread`] to an [`AsyncThread`] which implements [`Future`] and [`Stream`] traits.
259
293
///
260
- /// `args` are passed as arguments to the thread function for first call.
294
+ /// Only resumable threads can be converted to [`AsyncThread`].
295
+ ///
296
+ /// `args` are pushed to the thread stack and will be used when the thread is resumed.
261
297
/// The object calls [`resume`] while polling and also allow to run Rust futures
262
298
/// to completion using an executor.
263
299
///
@@ -290,7 +326,7 @@ impl Thread {
290
326
/// end)
291
327
/// "#).eval()?;
292
328
///
293
- /// let mut stream = thread.into_async::<i64>(1);
329
+ /// let mut stream = thread.into_async::<i64>(1)? ;
294
330
/// let mut sum = 0;
295
331
/// while let Some(n) = stream.try_next().await? {
296
332
/// sum += n;
@@ -303,15 +339,31 @@ impl Thread {
303
339
/// ```
304
340
#[ cfg( feature = "async" ) ]
305
341
#[ cfg_attr( docsrs, doc( cfg( feature = "async" ) ) ) ]
306
- pub fn into_async < R > ( self , args : impl IntoLuaMulti ) -> AsyncThread < impl IntoLuaMulti , R >
342
+ pub fn into_async < R > ( self , args : impl IntoLuaMulti ) -> Result < AsyncThread < R > >
307
343
where
308
344
R : FromLuaMulti ,
309
345
{
310
- AsyncThread {
311
- thread : self ,
312
- init_args : Some ( args) ,
313
- ret : PhantomData ,
314
- recycle : false ,
346
+ let lua = self . 0 . lua . lock ( ) ;
347
+ if !self . status_inner ( & lua) . is_resumable ( ) {
348
+ return Err ( Error :: CoroutineUnresumable ) ;
349
+ }
350
+
351
+ let state = lua. state ( ) ;
352
+ let thread_state = self . state ( ) ;
353
+ unsafe {
354
+ let _sg = StackGuard :: new ( state) ;
355
+
356
+ let nargs = args. push_into_stack_multi ( & lua) ?;
357
+ if nargs > 0 {
358
+ check_stack ( thread_state, nargs) ?;
359
+ ffi:: lua_xmove ( state, thread_state, nargs) ;
360
+ }
361
+
362
+ Ok ( AsyncThread {
363
+ thread : self ,
364
+ ret : PhantomData ,
365
+ recycle : false ,
366
+ } )
315
367
}
316
368
}
317
369
@@ -392,7 +444,7 @@ impl LuaType for Thread {
392
444
}
393
445
394
446
#[ cfg( feature = "async" ) ]
395
- impl < A , R > AsyncThread < A , R > {
447
+ impl < R > AsyncThread < R > {
396
448
#[ inline]
397
449
pub ( crate ) fn set_recyclable ( & mut self , recyclable : bool ) {
398
450
self . recycle = recyclable;
@@ -401,15 +453,15 @@ impl<A, R> AsyncThread<A, R> {
401
453
402
454
#[ cfg( feature = "async" ) ]
403
455
#[ cfg( any( feature = "lua54" , feature = "luau" ) ) ]
404
- impl < A , R > Drop for AsyncThread < A , R > {
456
+ impl < R > Drop for AsyncThread < R > {
405
457
fn drop ( & mut self ) {
406
458
if self . recycle {
407
459
if let Some ( lua) = self . thread . 0 . lua . try_lock ( ) {
408
460
unsafe {
409
461
// For Lua 5.4 this also closes all pending to-be-closed variables
410
462
if !lua. recycle_thread ( & mut self . thread ) {
411
463
#[ cfg( feature = "lua54" ) ]
412
- if self . thread . status_inner ( & lua) == ThreadStatus :: Error {
464
+ if matches ! ( self . thread. status_inner( & lua) , ThreadStatusInner :: Error ) {
413
465
#[ cfg( not( feature = "vendored" ) ) ]
414
466
ffi:: lua_resetthread ( self . thread . state ( ) ) ;
415
467
#[ cfg( feature = "vendored" ) ]
@@ -423,14 +475,15 @@ impl<A, R> Drop for AsyncThread<A, R> {
423
475
}
424
476
425
477
#[ cfg( feature = "async" ) ]
426
- impl < A : IntoLuaMulti , R : FromLuaMulti > Stream for AsyncThread < A , R > {
478
+ impl < R : FromLuaMulti > Stream for AsyncThread < R > {
427
479
type Item = Result < R > ;
428
480
429
481
fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
430
482
let lua = self . thread . 0 . lua . lock ( ) ;
431
- if self . thread . status_inner ( & lua) != ThreadStatus :: Resumable {
432
- return Poll :: Ready ( None ) ;
433
- }
483
+ let nargs = match self . thread . status_inner ( & lua) {
484
+ ThreadStatusInner :: New ( nargs) | ThreadStatusInner :: Yielded ( nargs) => nargs,
485
+ _ => return Poll :: Ready ( None ) ,
486
+ } ;
434
487
435
488
let state = lua. state ( ) ;
436
489
let thread_state = self . thread . state ( ) ;
@@ -439,36 +492,34 @@ impl<A: IntoLuaMulti, R: FromLuaMulti> Stream for AsyncThread<A, R> {
439
492
let _thread_sg = StackGuard :: with_top ( thread_state, 0 ) ;
440
493
let _wg = WakerGuard :: new ( & lua, cx. waker ( ) ) ;
441
494
442
- // This is safe as we are not moving the whole struct
443
- let this = self . get_unchecked_mut ( ) ;
444
- let nresults = if let Some ( args) = this. init_args . take ( ) {
445
- this. thread . resume_inner ( & lua, args) ?
446
- } else {
447
- this. thread . resume_inner ( & lua, ( ) ) ?
448
- } ;
495
+ let ( status, nresults) = ( self . thread ) . resume_inner ( & lua, nargs) ?;
449
496
450
- if nresults == 1 && is_poll_pending ( thread_state) {
451
- return Poll :: Pending ;
497
+ if status. is_yielded ( ) {
498
+ if nresults == 1 && is_poll_pending ( thread_state) {
499
+ return Poll :: Pending ;
500
+ }
501
+ // Continue polling
502
+ cx. waker ( ) . wake_by_ref ( ) ;
452
503
}
453
504
454
505
check_stack ( state, nresults + 1 ) ?;
455
506
ffi:: lua_xmove ( thread_state, state, nresults) ;
456
507
457
- cx. waker ( ) . wake_by_ref ( ) ;
458
508
Poll :: Ready ( Some ( R :: from_stack_multi ( nresults, & lua) ) )
459
509
}
460
510
}
461
511
}
462
512
463
513
#[ cfg( feature = "async" ) ]
464
- impl < A : IntoLuaMulti , R : FromLuaMulti > Future for AsyncThread < A , R > {
514
+ impl < R : FromLuaMulti > Future for AsyncThread < R > {
465
515
type Output = Result < R > ;
466
516
467
517
fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
468
518
let lua = self . thread . 0 . lua . lock ( ) ;
469
- if self . thread . status_inner ( & lua) != ThreadStatus :: Resumable {
470
- return Poll :: Ready ( Err ( Error :: CoroutineUnresumable ) ) ;
471
- }
519
+ let nargs = match self . thread . status_inner ( & lua) {
520
+ ThreadStatusInner :: New ( nargs) | ThreadStatusInner :: Yielded ( nargs) => nargs,
521
+ _ => return Poll :: Ready ( Err ( Error :: CoroutineUnresumable ) ) ,
522
+ } ;
472
523
473
524
let state = lua. state ( ) ;
474
525
let thread_state = self . thread . state ( ) ;
@@ -477,21 +528,13 @@ impl<A: IntoLuaMulti, R: FromLuaMulti> Future for AsyncThread<A, R> {
477
528
let _thread_sg = StackGuard :: with_top ( thread_state, 0 ) ;
478
529
let _wg = WakerGuard :: new ( & lua, cx. waker ( ) ) ;
479
530
480
- // This is safe as we are not moving the whole struct
481
- let this = self . get_unchecked_mut ( ) ;
482
- let nresults = if let Some ( args) = this. init_args . take ( ) {
483
- this. thread . resume_inner ( & lua, args) ?
484
- } else {
485
- this. thread . resume_inner ( & lua, ( ) ) ?
486
- } ;
487
-
488
- if nresults == 1 && is_poll_pending ( thread_state) {
489
- return Poll :: Pending ;
490
- }
531
+ let ( status, nresults) = self . thread . resume_inner ( & lua, nargs) ?;
491
532
492
- if ffi:: lua_status ( thread_state) == ffi:: LUA_YIELD {
493
- // Ignore value returned via yield()
494
- cx. waker ( ) . wake_by_ref ( ) ;
533
+ if status. is_yielded ( ) {
534
+ if !( nresults == 1 && is_poll_pending ( thread_state) ) {
535
+ // Ignore value returned via yield()
536
+ cx. waker ( ) . wake_by_ref ( ) ;
537
+ }
495
538
return Poll :: Pending ;
496
539
}
497
540
@@ -545,7 +588,7 @@ mod assertions {
545
588
#[ cfg( feature = "send" ) ]
546
589
static_assertions:: assert_impl_all!( Thread : Send , Sync ) ;
547
590
#[ cfg( all( feature = "async" , not( feature = "send" ) ) ) ]
548
- static_assertions:: assert_not_impl_any!( AsyncThread <( ) , ( ) >: Send ) ;
591
+ static_assertions:: assert_not_impl_any!( AsyncThread <( ) >: Send ) ;
549
592
#[ cfg( all( feature = "async" , feature = "send" ) ) ]
550
- static_assertions:: assert_impl_all!( AsyncThread <( ) , ( ) >: Send , Sync ) ;
593
+ static_assertions:: assert_impl_all!( AsyncThread <( ) >: Send , Sync ) ;
551
594
}
0 commit comments