From 0934458afabc2cff74329f5612519acb6b0f8b55 Mon Sep 17 00:00:00 2001 From: Letsch22 Date: Sat, 4 Jun 2022 01:15:34 -0700 Subject: [PATCH] Add multibind class provider --- .gitignore | 1 + injector/__init__.py | 20 ++++++++++++++++++++ injector_test.py | 13 +++++++++++++ 3 files changed, 34 insertions(+) diff --git a/.gitignore b/.gitignore index 0a7c43f..262ed2d 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ coverage.xml /dist/ /injector.egg-info/ /.coverage +venv/ diff --git a/injector/__init__.py b/injector/__init__.py index 3bc6142..157a8a6 100644 --- a/injector/__init__.py +++ b/injector/__init__.py @@ -369,6 +369,17 @@ def __repr__(self) -> str: return '%s(%r)' % (type(self).__name__, self._providers) +@private +class ClassListProvider(Provider[List[T]]): + """Provides a list of instances from a given class.""" + + def __init__(self, cls: Type[T]) -> None: + self._cls = cls + + def get(self, injector: 'Injector') -> List[T]: + return [injector.create_object(self._cls)] + + class MultiBindProvider(ListOfProviders[List[T]]): """Used by :meth:`Binder.multibind` to flatten results of providers that return sequences.""" @@ -377,6 +388,15 @@ def get(self, injector: 'Injector') -> List[T]: return [i for provider in self._providers for i in provider.get(injector)] +class MultiBindClassProvider(MultiBindProvider): + """A provider for a list of instances from a list of classes.""" + + def __init__(self, classes: List[Type[T]]) -> None: + super().__init__() + for cls in classes: + self.append(ClassListProvider(cls)) + + class MapBindProvider(ListOfProviders[Dict[str, T]]): """A provider for map bindings.""" diff --git a/injector_test.py b/injector_test.py index c37295f..68f9c7b 100644 --- a/injector_test.py +++ b/injector_test.py @@ -26,6 +26,7 @@ Binder, CallError, Injector, + MultiBindClassProvider, Scope, InstanceProvider, ClassProvider, @@ -473,6 +474,14 @@ def provide_description(self, age: int, weight: float) -> str: def test_multibind(): + class A: + def print(self) -> str: + return 'A' + + class B(A): + def print(self) -> str: + return 'B' + # First let's have some explicit multibindings def configure(binder): binder.multibind(List[str], to=['not a name']) @@ -483,6 +492,8 @@ def configure(binder): # To see that NewTypes are treated distinctly binder.multibind(Names, to=['Bob']) binder.multibind(Passwords, to={'Bob': 'password1'}) + # To see that MultiBindClassProvider works for lists of types + binder.multibind(List[A], to=MultiBindClassProvider([A, B])) # Then @multiprovider-decorated Module methods class CustomModule(Module): @@ -517,6 +528,8 @@ def provide_passwords(self) -> Passwords: assert injector.get(Dict[str, int]) == {'weight': 12, 'height': 33} assert injector.get(Names) == ['Bob', 'Alice', 'Clarice'] assert injector.get(Passwords) == {'Bob': 'password1', 'Alice': 'aojrioeg3', 'Clarice': 'clarice30'} + assert injector.get(List[A])[0].print() == 'A' + assert injector.get(List[A])[1].print() == 'B' def test_regular_bind_and_provider_dont_work_with_multibind():