Skip to content

Commit 6a826d6

Browse files
Generate async-compatible stubs and servicers
1 parent 1cbc3a8 commit 6a826d6

File tree

7 files changed

+185
-53
lines changed

7 files changed

+185
-53
lines changed

mypy_protobuf/main.py

+88-21
Original file line numberDiff line numberDiff line change
@@ -663,30 +663,77 @@ def _map_key_value_types(
663663

664664
return ktype, vtype
665665

666-
def _callable_type(self, method: d.MethodDescriptorProto) -> str:
666+
def _callable_type(self, method: d.MethodDescriptorProto, is_async: bool = False) -> str:
667+
module = "grpc.aio" if is_async else "grpc"
667668
if method.client_streaming:
668669
if method.server_streaming:
669-
return self._import("grpc", "StreamStreamMultiCallable")
670+
return self._import(module, "StreamStreamMultiCallable")
670671
else:
671-
return self._import("grpc", "StreamUnaryMultiCallable")
672+
return self._import(module, "StreamUnaryMultiCallable")
672673
else:
673674
if method.server_streaming:
674-
return self._import("grpc", "UnaryStreamMultiCallable")
675+
return self._import(module, "UnaryStreamMultiCallable")
675676
else:
676-
return self._import("grpc", "UnaryUnaryMultiCallable")
677+
return self._import(module, "UnaryUnaryMultiCallable")
677678

678-
def _input_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
679+
def _input_type(self, method: d.MethodDescriptorProto) -> str:
679680
result = self._import_message(method.input_type)
680-
if use_stream_iterator and method.client_streaming:
681-
result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
682681
return result
683682

684-
def _output_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
683+
def _servicer_input_type(self, method: d.MethodDescriptorProto) -> str:
684+
result = self._import_message(method.input_type)
685+
if method.client_streaming:
686+
# See write_grpc_async_hacks().
687+
result = f"_MaybeAsyncIterator[{result}]"
688+
return result
689+
690+
def _output_type(self, method: d.MethodDescriptorProto) -> str:
685691
result = self._import_message(method.output_type)
686-
if use_stream_iterator and method.server_streaming:
687-
result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
688692
return result
689693

694+
def _servicer_output_type(self, method: d.MethodDescriptorProto) -> str:
695+
result = self._import_message(method.output_type)
696+
if method.server_streaming:
697+
# Union[Iterator[Resp], AsyncIterator[Resp]] is subtyped by Iterator[Resp] and AsyncIterator[Resp].
698+
# So both can be used in the covariant function return position.
699+
iterator = f"{self._import('typing', 'Iterator')}[{result}]"
700+
aiterator = f"{self._import('typing', 'AsyncIterator')}[{result}]"
701+
result = f"{self._import('typing', 'Union')}[{iterator}, {aiterator}]"
702+
else:
703+
# Union[Resp, Awaitable[Resp]] is subtyped by Resp and Awaitable[Resp].
704+
# So both can be used in the covariant function return position.
705+
# Awaitable[Resp] is equivalent to async def.
706+
awaitable = f"{self._import('typing', 'Awaitable')}[{result}]"
707+
result = f"{self._import('typing', 'Union')}[{result}, {awaitable}]"
708+
return result
709+
710+
def write_grpc_async_hacks(self) -> None:
711+
wl = self._write_line
712+
# _MaybeAsyncIterator[Req] is supertyped by Iterator[Req] and AsyncIterator[Req].
713+
# So both can be used in the contravariant function parameter position.
714+
wl("_T = {}('_T')", self._import("typing", "TypeVar"))
715+
wl("")
716+
wl(
717+
"class _MaybeAsyncIterator({}[_T], {}[_T], metaclass={}):",
718+
self._import("typing", "AsyncIterator"),
719+
self._import("typing", "Iterator"),
720+
self._import("abc", "ABCMeta"),
721+
)
722+
with self._indent():
723+
wl("...")
724+
wl("")
725+
726+
# _ServicerContext is supertyped by grpc.ServicerContext and grpc.aio.ServicerContext
727+
# So both can be used in the contravariant function parameter position.
728+
wl(
729+
"class _ServicerContext({}, {}): # type: ignore",
730+
self._import("grpc", "ServicerContext"),
731+
self._import("grpc.aio", "ServicerContext"),
732+
)
733+
with self._indent():
734+
wl("...")
735+
wl("")
736+
690737
def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
691738
wl = self._write_line
692739
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
@@ -701,20 +748,20 @@ def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: Sour
701748
with self._indent():
702749
wl("self,")
703750
input_name = "request_iterator" if method.client_streaming else "request"
704-
input_type = self._input_type(method)
751+
input_type = self._servicer_input_type(method)
705752
wl(f"{input_name}: {input_type},")
706-
wl("context: {},", self._import("grpc", "ServicerContext"))
753+
wl("context: _ServicerContext,")
707754
wl(
708755
") -> {}:{}",
709-
self._output_type(method),
756+
self._servicer_output_type(method),
710757
" ..." if not self._has_comments(scl) else "",
711758
)
712759
if self._has_comments(scl):
713760
with self._indent():
714761
if not self._write_comments(scl):
715762
wl("...")
716763

717-
def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
764+
def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation, is_async: bool = False) -> None:
718765
wl = self._write_line
719766
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
720767
if not methods:
@@ -723,10 +770,10 @@ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix:
723770
for i, method in methods:
724771
scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
725772

726-
wl("{}: {}[", method.name, self._callable_type(method))
773+
wl("{}: {}[", method.name, self._callable_type(method, is_async=is_async))
727774
with self._indent():
728-
wl("{},", self._input_type(method, False))
729-
wl("{},", self._output_type(method, False))
775+
wl("{},", self._input_type(method))
776+
wl("{},", self._output_type(method))
730777
wl("]")
731778
self._write_comments(scl)
732779

@@ -743,17 +790,34 @@ def write_grpc_services(
743790
scl = scl_prefix + [i]
744791

745792
# The stub client
746-
wl(f"class {service.name}Stub:")
793+
wl(
794+
"class {}Stub:",
795+
service.name,
796+
)
747797
with self._indent():
748798
if self._write_comments(scl):
749799
wl("")
800+
# To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
801+
channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]"
750802
wl(
751803
"def __init__(self, channel: {}) -> None: ...",
752-
self._import("grpc", "Channel"),
804+
channel
753805
)
754806
self.write_grpc_stub_methods(service, scl)
755807
wl("")
756808

809+
# The (fake) async stub client
810+
wl(
811+
"class {}AsyncStub:",
812+
service.name,
813+
)
814+
with self._indent():
815+
if self._write_comments(scl):
816+
wl("")
817+
# No __init__ since this isn't a real class (yet), and requires manual casting to work.
818+
self.write_grpc_stub_methods(service, scl, is_async=True)
819+
wl("")
820+
757821
# The service definition interface
758822
wl(
759823
"class {}Servicer(metaclass={}):",
@@ -765,11 +829,13 @@ def write_grpc_services(
765829
wl("")
766830
self.write_grpc_methods(service, scl)
767831
wl("")
832+
server = self._import('grpc', 'Server')
833+
aserver = self._import('grpc.aio', 'Server')
768834
wl(
769835
"def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
770836
service.name,
771837
service.name,
772-
self._import("grpc", "Server"),
838+
f"{self._import('typing', 'Union')}[{server}, {aserver}]",
773839
)
774840
wl("")
775841

@@ -960,6 +1026,7 @@ def generate_mypy_grpc_stubs(
9601026
relax_strict_optional_primitives,
9611027
grpc=True,
9621028
)
1029+
pkg_writer.write_grpc_async_hacks()
9631030
pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
9641031

9651032
assert name == fd.name

stubtest_allowlist.txt

+4
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ testproto.readme_enum_pb2._?MyEnum(EnumTypeWrapper)?
4848
testproto.nested.nested_pb2.AnotherNested._?NestedEnum(EnumTypeWrapper)?
4949
testproto.nested.nested_pb2.AnotherNested.NestedMessage._?NestedEnum2(EnumTypeWrapper)?
5050

51+
# Our fake async stubs are not there at runtime (yet)
52+
testproto.grpc.dummy_pb2_grpc.DummyServiceAsyncStub
53+
testproto.grpc.import_pb2_grpc.SimpleServiceAsyncStub
54+
5155
# Part of an "EXPERIMENTAL API" according to comment. Not documented.
5256
testproto.grpc.dummy_pb2_grpc.DummyService
5357
testproto.grpc.import_pb2_grpc.SimpleService

test/generated/testproto/grpc/dummy_pb2_grpc.pyi

+46-13
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,23 @@
33
isort:skip_file
44
https://github.com/vmagamedov/grpclib/blob/master/tests/dummy.proto"""
55
import abc
6-
import collections.abc
76
import grpc
7+
import grpc.aio
88
import testproto.grpc.dummy_pb2
9+
import typing
10+
11+
_T = typing.TypeVar('_T')
12+
13+
class _MaybeAsyncIterator(typing.AsyncIterator[_T], typing.Iterator[_T], metaclass=abc.ABCMeta):
14+
...
15+
16+
class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type: ignore
17+
...
918

1019
class DummyServiceStub:
1120
"""DummyService"""
1221

13-
def __init__(self, channel: grpc.Channel) -> None: ...
22+
def __init__(self, channel: typing.Union[grpc.Channel, grpc.aio.Channel]) -> None: ...
1423
UnaryUnary: grpc.UnaryUnaryMultiCallable[
1524
testproto.grpc.dummy_pb2.DummyRequest,
1625
testproto.grpc.dummy_pb2.DummyReply,
@@ -32,36 +41,60 @@ class DummyServiceStub:
3241
]
3342
"""StreamStream"""
3443

44+
class DummyServiceAsyncStub:
45+
"""DummyService"""
46+
47+
UnaryUnary: grpc.aio.UnaryUnaryMultiCallable[
48+
testproto.grpc.dummy_pb2.DummyRequest,
49+
testproto.grpc.dummy_pb2.DummyReply,
50+
]
51+
"""UnaryUnary"""
52+
UnaryStream: grpc.aio.UnaryStreamMultiCallable[
53+
testproto.grpc.dummy_pb2.DummyRequest,
54+
testproto.grpc.dummy_pb2.DummyReply,
55+
]
56+
"""UnaryStream"""
57+
StreamUnary: grpc.aio.StreamUnaryMultiCallable[
58+
testproto.grpc.dummy_pb2.DummyRequest,
59+
testproto.grpc.dummy_pb2.DummyReply,
60+
]
61+
"""StreamUnary"""
62+
StreamStream: grpc.aio.StreamStreamMultiCallable[
63+
testproto.grpc.dummy_pb2.DummyRequest,
64+
testproto.grpc.dummy_pb2.DummyReply,
65+
]
66+
"""StreamStream"""
67+
3568
class DummyServiceServicer(metaclass=abc.ABCMeta):
3669
"""DummyService"""
3770

3871
@abc.abstractmethod
3972
def UnaryUnary(
4073
self,
4174
request: testproto.grpc.dummy_pb2.DummyRequest,
42-
context: grpc.ServicerContext,
43-
) -> testproto.grpc.dummy_pb2.DummyReply:
75+
context: _ServicerContext,
76+
) -> typing.Union[testproto.grpc.dummy_pb2.DummyReply, typing.Awaitable[testproto.grpc.dummy_pb2.DummyReply]]:
4477
"""UnaryUnary"""
4578
@abc.abstractmethod
4679
def UnaryStream(
4780
self,
4881
request: testproto.grpc.dummy_pb2.DummyRequest,
49-
context: grpc.ServicerContext,
50-
) -> collections.abc.Iterator[testproto.grpc.dummy_pb2.DummyReply]:
82+
context: _ServicerContext,
83+
) -> typing.Union[typing.Iterator[testproto.grpc.dummy_pb2.DummyReply], typing.AsyncIterator[testproto.grpc.dummy_pb2.DummyReply]]:
5184
"""UnaryStream"""
5285
@abc.abstractmethod
5386
def StreamUnary(
5487
self,
55-
request_iterator: collections.abc.Iterator[testproto.grpc.dummy_pb2.DummyRequest],
56-
context: grpc.ServicerContext,
57-
) -> testproto.grpc.dummy_pb2.DummyReply:
88+
request_iterator: _MaybeAsyncIterator[testproto.grpc.dummy_pb2.DummyRequest],
89+
context: _ServicerContext,
90+
) -> typing.Union[testproto.grpc.dummy_pb2.DummyReply, typing.Awaitable[testproto.grpc.dummy_pb2.DummyReply]]:
5891
"""StreamUnary"""
5992
@abc.abstractmethod
6093
def StreamStream(
6194
self,
62-
request_iterator: collections.abc.Iterator[testproto.grpc.dummy_pb2.DummyRequest],
63-
context: grpc.ServicerContext,
64-
) -> collections.abc.Iterator[testproto.grpc.dummy_pb2.DummyReply]:
95+
request_iterator: _MaybeAsyncIterator[testproto.grpc.dummy_pb2.DummyRequest],
96+
context: _ServicerContext,
97+
) -> typing.Union[typing.Iterator[testproto.grpc.dummy_pb2.DummyReply], typing.AsyncIterator[testproto.grpc.dummy_pb2.DummyReply]]:
6598
"""StreamStream"""
6699

67-
def add_DummyServiceServicer_to_server(servicer: DummyServiceServicer, server: grpc.Server) -> None: ...
100+
def add_DummyServiceServicer_to_server(servicer: DummyServiceServicer, server: typing.Union[grpc.Server, grpc.aio.Server]) -> None: ...

test/generated/testproto/grpc/import_pb2_grpc.pyi

+36-8
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,22 @@ isort:skip_file
55
import abc
66
import google.protobuf.empty_pb2
77
import grpc
8+
import grpc.aio
89
import testproto.test_pb2
10+
import typing
11+
12+
_T = typing.TypeVar('_T')
13+
14+
class _MaybeAsyncIterator(typing.AsyncIterator[_T], typing.Iterator[_T], metaclass=abc.ABCMeta):
15+
...
16+
17+
class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type: ignore
18+
...
919

1020
class SimpleServiceStub:
1121
"""SimpleService"""
1222

13-
def __init__(self, channel: grpc.Channel) -> None: ...
23+
def __init__(self, channel: typing.Union[grpc.Channel, grpc.aio.Channel]) -> None: ...
1424
UnaryUnary: grpc.UnaryUnaryMultiCallable[
1525
google.protobuf.empty_pb2.Empty,
1626
testproto.test_pb2.Simple1,
@@ -26,28 +36,46 @@ class SimpleServiceStub:
2636
google.protobuf.empty_pb2.Empty,
2737
]
2838

39+
class SimpleServiceAsyncStub:
40+
"""SimpleService"""
41+
42+
UnaryUnary: grpc.aio.UnaryUnaryMultiCallable[
43+
google.protobuf.empty_pb2.Empty,
44+
testproto.test_pb2.Simple1,
45+
]
46+
"""UnaryUnary"""
47+
UnaryStream: grpc.aio.UnaryUnaryMultiCallable[
48+
testproto.test_pb2.Simple1,
49+
google.protobuf.empty_pb2.Empty,
50+
]
51+
"""UnaryStream"""
52+
NoComment: grpc.aio.UnaryUnaryMultiCallable[
53+
testproto.test_pb2.Simple1,
54+
google.protobuf.empty_pb2.Empty,
55+
]
56+
2957
class SimpleServiceServicer(metaclass=abc.ABCMeta):
3058
"""SimpleService"""
3159

3260
@abc.abstractmethod
3361
def UnaryUnary(
3462
self,
3563
request: google.protobuf.empty_pb2.Empty,
36-
context: grpc.ServicerContext,
37-
) -> testproto.test_pb2.Simple1:
64+
context: _ServicerContext,
65+
) -> typing.Union[testproto.test_pb2.Simple1, typing.Awaitable[testproto.test_pb2.Simple1]]:
3866
"""UnaryUnary"""
3967
@abc.abstractmethod
4068
def UnaryStream(
4169
self,
4270
request: testproto.test_pb2.Simple1,
43-
context: grpc.ServicerContext,
44-
) -> google.protobuf.empty_pb2.Empty:
71+
context: _ServicerContext,
72+
) -> typing.Union[google.protobuf.empty_pb2.Empty, typing.Awaitable[google.protobuf.empty_pb2.Empty]]:
4573
"""UnaryStream"""
4674
@abc.abstractmethod
4775
def NoComment(
4876
self,
4977
request: testproto.test_pb2.Simple1,
50-
context: grpc.ServicerContext,
51-
) -> google.protobuf.empty_pb2.Empty: ...
78+
context: _ServicerContext,
79+
) -> typing.Union[google.protobuf.empty_pb2.Empty, typing.Awaitable[google.protobuf.empty_pb2.Empty]]: ...
5280

53-
def add_SimpleServiceServicer_to_server(servicer: SimpleServiceServicer, server: grpc.Server) -> None: ...
81+
def add_SimpleServiceServicer_to_server(servicer: SimpleServiceServicer, server: typing.Union[grpc.Server, grpc.aio.Server]) -> None: ...

test/test_grpc_async_usage.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ async def test_grpc() -> None:
5353
server = make_server()
5454
await server.start()
5555
async with grpc.aio.insecure_channel(ADDRESS) as channel:
56-
client = dummy_pb2_grpc.DummyServiceStub(channel)
56+
client: dummy_pb2_grpc.DummyServiceAsyncStub = dummy_pb2_grpc.DummyServiceStub(channel) # type: ignore
5757
request = dummy_pb2.DummyRequest(value="cprg")
5858
result1 = await client.UnaryUnary(request)
5959
result2 = client.UnaryStream(dummy_pb2.DummyRequest(value=result1.value))

0 commit comments

Comments
 (0)