@@ -663,30 +663,77 @@ def _map_key_value_types(
663
663
664
664
return ktype , vtype
665
665
666
- def _callable_type (self , method : d .MethodDescriptorProto ) -> str :
666
+ def _callable_type (self , method : d .MethodDescriptorProto , is_async : bool = False ) -> str :
667
+ module = "grpc.aio" if is_async else "grpc"
667
668
if method .client_streaming :
668
669
if method .server_streaming :
669
- return self ._import ("grpc" , "StreamStreamMultiCallable" )
670
+ return self ._import (module , "StreamStreamMultiCallable" )
670
671
else :
671
- return self ._import ("grpc" , "StreamUnaryMultiCallable" )
672
+ return self ._import (module , "StreamUnaryMultiCallable" )
672
673
else :
673
674
if method .server_streaming :
674
- return self ._import ("grpc" , "UnaryStreamMultiCallable" )
675
+ return self ._import (module , "UnaryStreamMultiCallable" )
675
676
else :
676
- return self ._import ("grpc" , "UnaryUnaryMultiCallable" )
677
+ return self ._import (module , "UnaryUnaryMultiCallable" )
677
678
678
- def _input_type (self , method : d .MethodDescriptorProto , use_stream_iterator : bool = True ) -> str :
679
+ def _input_type (self , method : d .MethodDescriptorProto ) -> str :
679
680
result = self ._import_message (method .input_type )
680
- if use_stream_iterator and method .client_streaming :
681
- result = f"{ self ._import ('collections.abc' , 'Iterator' )} [{ result } ]"
682
681
return result
683
682
684
- def _output_type (self , method : d .MethodDescriptorProto , use_stream_iterator : bool = True ) -> str :
683
+ def _servicer_input_type (self , method : d .MethodDescriptorProto ) -> str :
684
+ result = self ._import_message (method .input_type )
685
+ if method .client_streaming :
686
+ # See write_grpc_async_hacks().
687
+ result = f"_MaybeAsyncIterator[{ result } ]"
688
+ return result
689
+
690
+ def _output_type (self , method : d .MethodDescriptorProto ) -> str :
685
691
result = self ._import_message (method .output_type )
686
- if use_stream_iterator and method .server_streaming :
687
- result = f"{ self ._import ('collections.abc' , 'Iterator' )} [{ result } ]"
688
692
return result
689
693
694
+ def _servicer_output_type (self , method : d .MethodDescriptorProto ) -> str :
695
+ result = self ._import_message (method .output_type )
696
+ if method .server_streaming :
697
+ # Union[Iterator[Resp], AsyncIterator[Resp]] is subtyped by Iterator[Resp] and AsyncIterator[Resp].
698
+ # So both can be used in the covariant function return position.
699
+ iterator = f"{ self ._import ('typing' , 'Iterator' )} [{ result } ]"
700
+ aiterator = f"{ self ._import ('typing' , 'AsyncIterator' )} [{ result } ]"
701
+ result = f"{ self ._import ('typing' , 'Union' )} [{ iterator } , { aiterator } ]"
702
+ else :
703
+ # Union[Resp, Awaitable[Resp]] is subtyped by Resp and Awaitable[Resp].
704
+ # So both can be used in the covariant function return position.
705
+ # Awaitable[Resp] is equivalent to async def.
706
+ awaitable = f"{ self ._import ('typing' , 'Awaitable' )} [{ result } ]"
707
+ result = f"{ self ._import ('typing' , 'Union' )} [{ result } , { awaitable } ]"
708
+ return result
709
+
710
+ def write_grpc_async_hacks (self ) -> None :
711
+ wl = self ._write_line
712
+ # _MaybeAsyncIterator[Req] is supertyped by Iterator[Req] and AsyncIterator[Req].
713
+ # So both can be used in the contravariant function parameter position.
714
+ wl ("_T = {}('_T')" , self ._import ("typing" , "TypeVar" ))
715
+ wl ("" )
716
+ wl (
717
+ "class _MaybeAsyncIterator({}[_T], {}[_T], metaclass={}):" ,
718
+ self ._import ("typing" , "AsyncIterator" ),
719
+ self ._import ("typing" , "Iterator" ),
720
+ self ._import ("abc" , "ABCMeta" ),
721
+ )
722
+ with self ._indent ():
723
+ wl ("..." )
724
+ wl ("" )
725
+
726
+ # _ServicerContext is supertyped by grpc.ServicerContext and grpc.aio.ServicerContext
727
+ # So both can be used in the contravariant function parameter position.
728
+ wl (
729
+ "class _ServicerContext({}, {}): # type: ignore" ,
730
+ self ._import ("grpc" , "ServicerContext" ),
731
+ self ._import ("grpc.aio" , "ServicerContext" ),
732
+ )
733
+ with self ._indent ():
734
+ wl ("..." )
735
+ wl ("" )
736
+
690
737
def write_grpc_methods (self , service : d .ServiceDescriptorProto , scl_prefix : SourceCodeLocation ) -> None :
691
738
wl = self ._write_line
692
739
methods = [(i , m ) for i , m in enumerate (service .method ) if m .name not in PYTHON_RESERVED ]
@@ -701,20 +748,20 @@ def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: Sour
701
748
with self ._indent ():
702
749
wl ("self," )
703
750
input_name = "request_iterator" if method .client_streaming else "request"
704
- input_type = self ._input_type (method )
751
+ input_type = self ._servicer_input_type (method )
705
752
wl (f"{ input_name } : { input_type } ," )
706
- wl ("context: {}," , self . _import ( "grpc" , "ServicerContext" ) )
753
+ wl ("context: _ServicerContext," )
707
754
wl (
708
755
") -> {}:{}" ,
709
- self ._output_type (method ),
756
+ self ._servicer_output_type (method ),
710
757
" ..." if not self ._has_comments (scl ) else "" ,
711
758
)
712
759
if self ._has_comments (scl ):
713
760
with self ._indent ():
714
761
if not self ._write_comments (scl ):
715
762
wl ("..." )
716
763
717
- def write_grpc_stub_methods (self , service : d .ServiceDescriptorProto , scl_prefix : SourceCodeLocation ) -> None :
764
+ def write_grpc_stub_methods (self , service : d .ServiceDescriptorProto , scl_prefix : SourceCodeLocation , is_async : bool = False ) -> None :
718
765
wl = self ._write_line
719
766
methods = [(i , m ) for i , m in enumerate (service .method ) if m .name not in PYTHON_RESERVED ]
720
767
if not methods :
@@ -723,10 +770,10 @@ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix:
723
770
for i , method in methods :
724
771
scl = scl_prefix + [d .ServiceDescriptorProto .METHOD_FIELD_NUMBER , i ]
725
772
726
- wl ("{}: {}[" , method .name , self ._callable_type (method ))
773
+ wl ("{}: {}[" , method .name , self ._callable_type (method , is_async = is_async ))
727
774
with self ._indent ():
728
- wl ("{}," , self ._input_type (method , False ))
729
- wl ("{}," , self ._output_type (method , False ))
775
+ wl ("{}," , self ._input_type (method ))
776
+ wl ("{}," , self ._output_type (method ))
730
777
wl ("]" )
731
778
self ._write_comments (scl )
732
779
@@ -743,17 +790,34 @@ def write_grpc_services(
743
790
scl = scl_prefix + [i ]
744
791
745
792
# The stub client
746
- wl (f"class { service .name } Stub:" )
793
+ wl (
794
+ "class {}Stub:" ,
795
+ service .name ,
796
+ )
747
797
with self ._indent ():
748
798
if self ._write_comments (scl ):
749
799
wl ("" )
800
+ # To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
801
+ channel = f"{ self ._import ('typing' , 'Union' )} [{ self ._import ('grpc' , 'Channel' )} , { self ._import ('grpc.aio' , 'Channel' )} ]"
750
802
wl (
751
803
"def __init__(self, channel: {}) -> None: ..." ,
752
- self . _import ( "grpc" , "Channel" ),
804
+ channel
753
805
)
754
806
self .write_grpc_stub_methods (service , scl )
755
807
wl ("" )
756
808
809
+ # The (fake) async stub client
810
+ wl (
811
+ "class {}AsyncStub:" ,
812
+ service .name ,
813
+ )
814
+ with self ._indent ():
815
+ if self ._write_comments (scl ):
816
+ wl ("" )
817
+ # No __init__ since this isn't a real class (yet), and requires manual casting to work.
818
+ self .write_grpc_stub_methods (service , scl , is_async = True )
819
+ wl ("" )
820
+
757
821
# The service definition interface
758
822
wl (
759
823
"class {}Servicer(metaclass={}):" ,
@@ -765,11 +829,13 @@ def write_grpc_services(
765
829
wl ("" )
766
830
self .write_grpc_methods (service , scl )
767
831
wl ("" )
832
+ server = self ._import ('grpc' , 'Server' )
833
+ aserver = self ._import ('grpc.aio' , 'Server' )
768
834
wl (
769
835
"def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ..." ,
770
836
service .name ,
771
837
service .name ,
772
- self ._import ("grpc" , "Server" ) ,
838
+ f" { self ._import ('typing' , 'Union' ) } [ { server } , { aserver } ]" ,
773
839
)
774
840
wl ("" )
775
841
@@ -960,6 +1026,7 @@ def generate_mypy_grpc_stubs(
960
1026
relax_strict_optional_primitives ,
961
1027
grpc = True ,
962
1028
)
1029
+ pkg_writer .write_grpc_async_hacks ()
963
1030
pkg_writer .write_grpc_services (fd .service , [d .FileDescriptorProto .SERVICE_FIELD_NUMBER ])
964
1031
965
1032
assert name == fd .name
0 commit comments