diff --git a/mypy_protobuf/main.py b/mypy_protobuf/main.py index 2529489d..05d6e2ba 100644 --- a/mypy_protobuf/main.py +++ b/mypy_protobuf/main.py @@ -149,14 +149,14 @@ def __init__( self.from_imports: Dict[str, Set[Tuple[str, Optional[str]]]] = defaultdict(set) self.locals: Set[str] = set() - def _import(self, path: str, name: str) -> str: + def _import(self, path: str, name: str, alias: Optional[str] = None) -> str: """Imports a stdlib path and returns a handle to it eg. self._import("typing", "Optional") -> "Optional" """ imp = path.replace("/", ".") if self.readable_stubs: - self.from_imports[imp].add((name, None)) - return name + self.from_imports[imp].add((name, alias)) + return alias or name else: self.imports.add(imp) return imp + "." + name @@ -656,7 +656,9 @@ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto) -> None: l("{}] = ...", self._output_type(method, False)) l("") - def write_grpc_services(self, services: Iterable[d.ServiceDescriptorProto]) -> None: + def write_grpc_services( + self, services: Iterable[d.ServiceDescriptorProto], use_aio: bool + ) -> None: l = self._write_line l( "from .{} import *", @@ -682,11 +684,19 @@ def write_grpc_services(self, services: Iterable[d.ServiceDescriptorProto]) -> N with self._indent(): self.write_grpc_methods(service) l("") + if use_aio: + server_type_hint = "{}[{}, {}]".format( + self._import("typing", "Union"), + self._import("grpc", "Server"), + self._import("grpc.aio", "Server", "AioServer"), + ) + else: + server_type_hint = "{}".format(self._import("grpc", "Server")) l( "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...", service.name, service.name, - self._import("grpc", "Server"), + server_type_hint, ) l("") @@ -802,12 +812,13 @@ def generate_mypy_grpc_stubs( quiet: bool, readable_stubs: bool, relax_strict_optional_primitives: bool, + use_aio: bool, ) -> None: for name, fd in descriptors.to_generate.items(): pkg_writer = PkgWriter( fd, descriptors, readable_stubs, relax_strict_optional_primitives ) - pkg_writer.write_grpc_services(fd.service) + pkg_writer.write_grpc_services(fd.service, use_aio) assert name == fd.name assert fd.name.endswith(".proto") @@ -867,6 +878,7 @@ def grpc() -> None: "quiet" in request.parameter, "readable_stubs" in request.parameter, "relax_strict_optional_primitives" in request.parameter, + "use_aio" in request.parameter, )