Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 41 additions & 23 deletions injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,12 +344,20 @@ def __repr__(self) -> str:
class MultiBinder(Provider, Generic[T]):
"""Provide a list of instances via other Providers."""

__metaclass__ = ABCMeta

_multi_bindings: List['Binding']

def __init__(self, parent: 'Binder') -> None:
self._multi_bindings = []
self._binder = Binder(parent.injector, auto_bind=False, parent=parent)

@abstractmethod
def multibind(
self, interface: type, to: Any, scope: Union['ScopeDecorator', Type['Scope'], None]
) -> None:
raise NotImplementedError

def append(self, provider: Provider[T], scope: Type['Scope']) -> None:
# HACK: generate a pseudo-type for this element in the list.
# This is needed for scopes to work properly. Some, like the Singleton scope,
Expand All @@ -372,6 +380,21 @@ class MultiBindProvider(MultiBinder[List[T]]):
"""Used by :meth:`Binder.multibind` to flatten results of providers that
return sequences."""

def multibind(
self, interface: type, to: Any, scope: Union['ScopeDecorator', Type['Scope'], None]
) -> None:
try:
element_type = get_args(_punch_through_alias(interface))[0]
except IndexError:
raise InvalidInterface(f"Use typing.List[T] or list[T] to specify the element type of the list")
if isinstance(to, list):
for element in to:
element_binding = self._binder.create_binding(element_type, element, scope)
self.append(element_binding.provider, element_binding.scope)
else:
element_binding = self._binder.create_binding(interface, to, scope)
self.append(element_binding.provider, element_binding.scope)

def get(self, injector: 'Injector') -> List[T]:
result: List[T] = []
for provider in self.get_scoped_providers(injector):
Expand All @@ -383,6 +406,23 @@ def get(self, injector: 'Injector') -> List[T]:
class MapBindProvider(MultiBinder[Dict[str, T]]):
"""A provider for map bindings."""

def multibind(
self, interface: type, to: Any, scope: Union['ScopeDecorator', Type['Scope'], None]
) -> None:
try:
value_type = get_args(_punch_through_alias(interface))[1]
except IndexError:
raise InvalidInterface(
f"Use typing.Dict[K, V] or dict[K, V] to specify the value type of the dict"
)
if isinstance(to, dict):
for key, value in to.items():
element_binding = self._binder.create_binding(value_type, value, scope)
self.append(KeyValueProvider(key, element_binding.provider), element_binding.scope)
else:
element_binding = self._binder.create_binding(interface, to, scope)
self.append(element_binding.provider, element_binding.scope)

def get(self, injector: 'Injector') -> Dict[str, T]:
map: Dict[str, T] = {}
for provider in self.get_scoped_providers(injector):
Expand Down Expand Up @@ -549,29 +589,7 @@ def multibind(
:param scope: Optional Scope in which to bind.
"""
multi_binder = self._get_multi_binder(interface)
if isinstance(multi_binder, MultiBindProvider) and isinstance(to, list):
try:
element_type = get_args(_punch_through_alias(interface))[0]
except IndexError:
raise InvalidInterface(
f"Use typing.List[T] or list[T] to specify the element type of the list"
)
for element in to:
element_binding = self.create_binding(element_type, element, scope)
multi_binder.append(element_binding.provider, element_binding.scope)
elif isinstance(multi_binder, MapBindProvider) and isinstance(to, dict):
try:
value_type = get_args(_punch_through_alias(interface))[1]
except IndexError:
raise InvalidInterface(
f"Use typing.Dict[K, V] or dict[K, V] to specify the value type of the dict"
)
for key, value in to.items():
element_binding = self.create_binding(value_type, value, scope)
multi_binder.append(KeyValueProvider(key, element_binding.provider), element_binding.scope)
else:
element_binding = self.create_binding(interface, to, scope)
multi_binder.append(element_binding.provider, element_binding.scope)
multi_binder.multibind(interface, to, scope)

def _get_multi_binder(self, interface: type) -> MultiBinder:
multi_binder: MultiBinder
Expand Down