7
7
8
8
use heck:: ToKebabCase ;
9
9
use heck:: { ToSnakeCase , ToUpperCamelCase } ;
10
- use proc_macro2:: TokenStream ;
10
+ use proc_macro2:: { Literal , TokenStream } ;
11
11
use quote:: format_ident;
12
12
use quote:: quote;
13
13
use std:: collections:: HashSet ;
14
14
use std:: path:: PathBuf ;
15
15
use tauri_bindgen_core:: { Generate , GeneratorBuilder , TypeInfo , TypeInfos } ;
16
16
use tauri_bindgen_gen_rust:: { print_generics, BorrowMode , FnSig , RustGenerator } ;
17
- use wit_parser:: { Function , FunctionResult , Interface , Type , TypeDefKind } ;
17
+ use wit_parser:: { Function , Interface , Type , TypeDefKind } ;
18
18
19
19
#[ derive( Default , Debug , Clone ) ]
20
20
#[ cfg_attr( feature = "clap" , derive( clap:: Args ) ) ]
@@ -252,7 +252,7 @@ impl Host {
252
252
253
253
let functions = functions. map ( |func| {
254
254
let sig = FnSig {
255
- async_ : false ,
255
+ async_ : self . opts . async_ ,
256
256
unsafe_ : false ,
257
257
private : true ,
258
258
self_arg : Some ( quote ! ( & self ) ) ,
@@ -266,7 +266,13 @@ impl Host {
266
266
267
267
let sized = sized. then_some ( quote ! ( : Sized ) ) ;
268
268
269
+ let async_trait = self
270
+ . opts
271
+ . async_
272
+ . then_some ( quote ! { #[ :: tauri_bindgen_host:: async_trait] } ) ;
273
+
269
274
quote ! {
275
+ #async_trait
270
276
pub trait #ident #sized {
271
277
#( #additional_items) *
272
278
#( #functions) *
@@ -304,114 +310,145 @@ impl Host {
304
310
}
305
311
}
306
312
307
- fn print_add_to_router < ' a > (
308
- & self ,
309
- mod_ident : & str ,
310
- functions : impl Iterator < Item = & ' a Function > ,
311
- methods : impl Iterator < Item = ( & ' a str , & ' a Function ) > ,
312
- ) -> TokenStream {
313
- let trait_ident = format_ident ! ( "{}" , mod_ident. to_upper_camel_case( ) ) ;
314
-
315
- let mod_name = mod_ident. to_snake_case ( ) ;
316
-
317
- let functions = functions. map ( |func| {
318
- let func_name = func. id . to_snake_case ( ) ;
319
- let func_ident = format_ident ! ( "{}" , func_name) ;
320
-
321
- let params = self . print_function_params ( & func. params , & BorrowMode :: Owned ) ;
322
-
323
- let param_idents = func
324
- . params
325
- . iter ( )
326
- . map ( |( ident, _) | { format_ident ! ( "{}" , ident) } ) ;
327
-
328
- let result = match func. result . as_ref ( ) {
329
- Some ( FunctionResult :: Anon ( ty) ) => {
330
- let ty = self . print_ty ( ty, & BorrowMode :: Owned ) ;
331
-
332
- quote ! { #ty }
333
- }
334
- Some ( FunctionResult :: Named ( types) ) if types. len ( ) == 1 => {
335
- let ( _, ty) = & types[ 0 ] ;
336
- let ty = self . print_ty ( ty, & BorrowMode :: Owned ) ;
337
-
338
- quote ! { #ty }
339
- }
340
- Some ( FunctionResult :: Named ( types) ) => {
341
- let types = types. iter ( ) . map ( |( _, ty) | self . print_ty ( ty, & BorrowMode :: Owned ) ) ;
313
+ fn print_router_fn_definition ( & self , mod_name : & str , func : & Function ) -> TokenStream {
314
+ let func_name = func. ident . to_snake_case ( ) ;
315
+ let func_ident = format_ident ! ( "{}" , func_name) ;
342
316
343
- quote ! { ( #( #types) , * ) }
344
- }
345
- _ => quote ! { ( ) } ,
346
- } ;
317
+ let param_decl = match func. params . len ( ) {
318
+ 0 => quote ! { ( ) } ,
319
+ 1 => {
320
+ let ty = & func. params . first ( ) . unwrap ( ) . 1 ;
321
+ let ty = self . print_ty ( ty, & BorrowMode :: Owned ) ;
322
+ quote ! { #ty }
323
+ }
324
+ _ => {
325
+ let tys = func
326
+ . params
327
+ . iter ( )
328
+ . map ( |( _, ty) | self . print_ty ( ty, & BorrowMode :: Owned ) ) ;
329
+ quote ! { ( #( #tys) , * ) }
330
+ }
331
+ } ;
332
+
333
+ let param_acc = match func. params . len ( ) {
334
+ 0 => quote ! { } ,
335
+ 1 => quote ! { p } ,
336
+ _ => {
337
+ let ids = func. params . iter ( ) . enumerate ( ) . map ( |( i, _) | {
338
+ let i = Literal :: usize_unsuffixed ( i) ;
339
+ quote ! { p. #i }
340
+ } ) ;
341
+ quote ! { #( #ids) , * }
342
+ }
343
+ } ;
347
344
345
+ if self . opts . async_ {
346
+ quote ! {
347
+ let get_cx = :: std:: sync:: Arc :: clone( & wrapped_get_cx) ;
348
+ router. define_async(
349
+ #mod_name,
350
+ #func_name,
351
+ move |ctx: :: tauri_bindgen_host:: ipc_router_wip:: Caller <T >, p: #param_decl| {
352
+ let get_cx = get_cx. clone( ) ;
353
+ Box :: pin( async move {
354
+ let ctx = get_cx( ctx. data( ) ) ;
355
+ Ok ( ctx. #func_ident( #param_acc) . await )
356
+ } )
357
+ } ) ?;
358
+ }
359
+ } else {
348
360
quote ! {
349
361
let get_cx = :: std:: sync:: Arc :: clone( & wrapped_get_cx) ;
350
- router. func_wrap (
362
+ router. define (
351
363
#mod_name,
352
364
#func_name,
353
- move |ctx: :: tauri_bindgen_host:: ipc_router_wip:: Caller <T >, #params| -> :: tauri_bindgen_host :: anyhow :: Result <#result> {
365
+ move |ctx: :: tauri_bindgen_host:: ipc_router_wip:: Caller <T >, p : #param_decl| {
354
366
let ctx = get_cx( ctx. data( ) ) ;
355
367
356
- Ok ( ctx. #func_ident( #( #param_idents ) , * ) )
368
+ Ok ( ctx. #func_ident( #param_acc ) )
357
369
} ,
358
370
) ?;
359
371
}
360
- } ) ;
361
-
362
- let methods = methods. map ( |( resource_name, method) | {
363
- let func_name = method. id . to_snake_case ( ) ;
364
- let func_ident = format_ident ! ( "{}" , func_name) ;
365
-
366
- let params = self . print_function_params ( & method. params , & BorrowMode :: Owned ) ;
367
-
368
- let param_idents = method
369
- . params
370
- . iter ( )
371
- . map ( |( ident, _) | format_ident ! ( "{}" , ident) ) ;
372
-
373
- let result = match method. result . as_ref ( ) {
374
- Some ( FunctionResult :: Anon ( ty) ) => {
375
- let ty = self . print_ty ( ty, & BorrowMode :: Owned ) ;
376
-
377
- quote ! { #ty }
378
- }
379
- Some ( FunctionResult :: Named ( types) ) if types. len ( ) == 1 => {
380
- let ( _, ty) = & types[ 0 ] ;
381
- let ty = self . print_ty ( ty, & BorrowMode :: Owned ) ;
372
+ }
373
+ }
382
374
383
- quote ! { #ty }
384
- }
385
- Some ( FunctionResult :: Named ( types) ) => {
386
- let types = types
387
- . iter ( )
388
- . map ( |( _, ty) | self . print_ty ( ty, & BorrowMode :: Owned ) ) ;
375
+ fn print_router_method_definition (
376
+ & self ,
377
+ mod_name : & str ,
378
+ resource_name : & str ,
379
+ method : & Function ,
380
+ ) -> TokenStream {
381
+ let func_name = method. ident . to_snake_case ( ) ;
382
+ let func_ident = format_ident ! ( "{}" , func_name) ;
389
383
390
- quote ! { ( #( #types) , * ) }
391
- }
392
- _ => quote ! { ( ) } ,
393
- } ;
384
+ let param_decl = method
385
+ . params
386
+ . iter ( )
387
+ . map ( |( _, ty) | self . print_ty ( ty, & BorrowMode :: Owned ) ) ;
388
+
389
+ let param_acc = match method. params . len ( ) {
390
+ 0 => quote ! { } ,
391
+ 1 => quote ! { p. 1 } ,
392
+ _ => {
393
+ let ids = method. params . iter ( ) . enumerate ( ) . map ( |( i, _) | {
394
+ let i = Literal :: usize_unsuffixed ( i + 1 ) ;
395
+ quote ! { p. #i }
396
+ } ) ;
397
+ quote ! { #( #ids) , * }
398
+ }
399
+ } ;
394
400
395
- let mod_name = format ! ( "{mod_name}::resource::{resource_name}" ) ;
396
- let get_r_ident = format_ident ! ( "get_{}" , resource_name. to_snake_case( ) ) ;
401
+ let mod_name = format ! ( "{mod_name}::resource::{resource_name}" ) ;
402
+ let get_r_ident = format_ident ! ( "get_{}" , resource_name. to_snake_case( ) ) ;
397
403
404
+ if self . opts . async_ {
398
405
quote ! {
399
406
let get_cx = :: std:: sync:: Arc :: clone( & wrapped_get_cx) ;
400
- router. func_wrap(
407
+ router. define_async(
408
+ #mod_name,
409
+ #func_name,
410
+ move |ctx: :: tauri_bindgen_host:: ipc_router_wip:: Caller <T >, p: ( :: tauri_bindgen_host:: ResourceId , #( #param_decl) , * ) | {
411
+ let get_cx = get_cx. clone( ) ;
412
+ Box :: pin( async move {
413
+ let ctx = get_cx( ctx. data( ) ) ;
414
+ let r = ctx. #get_r_ident( p. 0 ) ?;
415
+ Ok ( r. #func_ident( #param_acc) . await )
416
+ } )
417
+ } ) ?;
418
+ }
419
+ } else {
420
+ quote ! {
421
+ let get_cx = :: std:: sync:: Arc :: clone( & wrapped_get_cx) ;
422
+ router. define(
401
423
#mod_name,
402
424
#func_name,
403
425
move |
404
426
ctx: :: tauri_bindgen_host:: ipc_router_wip:: Caller <T >,
405
- this_rid: :: tauri_bindgen_host:: ResourceId ,
406
- #params
407
- | -> :: tauri_bindgen_host:: anyhow:: Result <#result> {
427
+ p: ( :: tauri_bindgen_host:: ResourceId , #( #param_decl) , * )
428
+ | {
408
429
let ctx = get_cx( ctx. data( ) ) ;
409
- let r = ctx. #get_r_ident( this_rid) ?;
410
-
411
- Ok ( r. #func_ident( #( #param_idents) , * ) )
430
+ let r = ctx. #get_r_ident( p. 0 ) ?;
431
+ Ok ( r. #func_ident( #param_acc) )
412
432
} ,
413
433
) ?;
414
434
}
435
+ }
436
+ }
437
+
438
+ fn print_add_to_router < ' a > (
439
+ & self ,
440
+ mod_ident : & str ,
441
+ functions : impl Iterator < Item = & ' a Function > ,
442
+ methods : impl Iterator < Item = ( & ' a str , & ' a Function ) > ,
443
+ ) -> TokenStream {
444
+ let trait_ident = format_ident ! ( "{}" , mod_ident. to_upper_camel_case( ) ) ;
445
+
446
+ let mod_name = mod_ident. to_snake_case ( ) ;
447
+
448
+ let functions = functions. map ( |func| self . print_router_fn_definition ( & mod_name, func) ) ;
449
+
450
+ let methods = methods. map ( |( resource_name, method) | {
451
+ self . print_router_method_definition ( & mod_name, resource_name, method)
415
452
} ) ;
416
453
417
454
quote ! {
@@ -420,6 +457,7 @@ impl Host {
420
457
get_cx: impl Fn ( & T ) -> & U + Send + Sync + ' static ,
421
458
) -> Result <( ) , :: tauri_bindgen_host:: ipc_router_wip:: Error >
422
459
where
460
+ T : Send + Sync + ' static ,
423
461
U : #trait_ident + Send + Sync + ' static ,
424
462
{
425
463
let wrapped_get_cx = :: std:: sync:: Arc :: new( get_cx) ;
0 commit comments