diff --git a/.github/workflows/publishing.yml b/.github/workflows/publishing.yml index c3b085fa..f3bf1529 100644 --- a/.github/workflows/publishing.yml +++ b/.github/workflows/publishing.yml @@ -10,7 +10,7 @@ jobs: tests: name: Run tests - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 @@ -23,7 +23,7 @@ jobs: linters: name: Run linters - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: toxenv: [flake8, pydocstyle, mypy, pylint] @@ -40,7 +40,7 @@ jobs: build-sdist: name: Build source tarball needs: [tests, linters] - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 @@ -49,8 +49,9 @@ jobs: - run: | python -m pip install --upgrade build python -m build --sdist - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: + name: cibw-sdist path: ./dist/* build-wheels: @@ -59,45 +60,28 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-22.04, windows-2019, macos-14] + os: [ubuntu-24.04, ubuntu-24.04-arm, windows-2019, macos-14] env: CIBW_SKIP: cp27-* steps: - uses: actions/checkout@v3 - name: Build wheels uses: pypa/cibuildwheel@v2.20.0 - - uses: actions/upload-artifact@v3 - with: - path: ./wheelhouse/*.whl - - build-wheels-linux-aarch64: - name: Build wheels (ubuntu-22.04-aarch64) - needs: [tests, linters] - runs-on: ubuntu-22.04 - env: - CIBW_SKIP: cp27-* - steps: - - uses: actions/checkout@v3 - - name: Set up QEMU - if: runner.os == 'Linux' - uses: docker/setup-qemu-action@v2 - - name: Build wheels - uses: pypa/cibuildwheel@v2.20.0 - env: - CIBW_ARCHS_LINUX: aarch64 - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: + name: cibw-wheels-x86-${{ matrix.os }}-${{ strategy.job-index }} path: ./wheelhouse/*.whl publish: name: Publish on PyPI - needs: [build-sdist, build-wheels, build-wheels-linux-aarch64] - runs-on: ubuntu-22.04 + needs: [build-sdist, build-wheels] + runs-on: ubuntu-24.04 steps: - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: - name: artifact + pattern: cibw-* path: dist + merge-multiple: true - uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ @@ -109,7 +93,7 @@ jobs: publish-docs: name: Publish docs needs: [publish] - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index 54a0e8eb..4d724da1 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -7,6 +7,25 @@ that were made in every particular version. From version 0.7.6 *Dependency Injector* framework strictly follows `Semantic versioning`_ +4.46.0 +------ + +- Add option to disable env var interpolation in configs (`#861 `_) +- Fix ``Closing`` dependency resolution (`#852 `_) +- Add support for ``inspect.iscoroutinefunction()`` in ``Coroutine`` provider (`#830 `_) +- Fix broken wiring of sync inject-decorated methods (`#673 `_) +- Add support for ``typing.Annotated`` (`#721 `_, `#853 `_) +- Documentation updates for movie-lister example (`#747 `_) +- Fix type propagation in ``Provider.provider`` (`#744 `_) + +Many thanks for the contributions to: +- `ZipFile `_ +- `Yegor Statkevich `_ +- `Federico Tomasi `_ +- `Martin Lafrance `_ +- `Philip Bjorge `_ +- `Ilya Kazakov `_ + 4.45.0 -------- - Add Starlette lifespan handler implementation (`#683 `_). diff --git a/docs/providers/configuration.rst b/docs/providers/configuration.rst index 582c0cc1..66c299f3 100644 --- a/docs/providers/configuration.rst +++ b/docs/providers/configuration.rst @@ -366,6 +366,19 @@ See also: :ref:`configuration-strict-mode`. assert container.config.section.option() is None +If you want to disable environment variables interpolation, pass ``envs_required=None``: + +.. code-block:: yaml + :caption: templates.yml + + template_string: 'Hello, ${name}!' + +.. code-block:: python + + >>> container.config.from_yaml("templates.yml", envs_required=None) + >>> container.config.template_string() + 'Hello, ${name}!' + Mandatory and optional sources ------------------------------ diff --git a/docs/tutorials/cli.rst b/docs/tutorials/cli.rst index 8b108e08..88014ff3 100644 --- a/docs/tutorials/cli.rst +++ b/docs/tutorials/cli.rst @@ -911,7 +911,7 @@ Create ``tests.py`` in the ``movies`` package: and put next into it: .. code-block:: python - :emphasize-lines: 36,51 + :emphasize-lines: 41,50 """Tests module.""" @@ -941,13 +941,18 @@ and put next into it: return container - def test_movies_directed_by(container): + @pytest.fixture + def finder_mock(container): finder_mock = mock.Mock() finder_mock.find_all.return_value = [ container.movie("The 33", 2015, "Patricia Riggen"), container.movie("The Jungle Book", 2016, "Jon Favreau"), ] + return finder_mock + + + def test_movies_directed_by(container, finder_mock): with container.finder.override(finder_mock): lister = container.lister() movies = lister.movies_directed_by("Jon Favreau") @@ -956,13 +961,7 @@ and put next into it: assert movies[0].title == "The Jungle Book" - def test_movies_released_in(container): - finder_mock = mock.Mock() - finder_mock.find_all.return_value = [ - container.movie("The 33", 2015, "Patricia Riggen"), - container.movie("The Jungle Book", 2016, "Jon Favreau"), - ] - + def test_movies_released_in(container, finder_mock): with container.finder.override(finder_mock): lister = container.lister() movies = lister.movies_released_in(2015) @@ -995,9 +994,9 @@ You should see: movies/entities.py 7 1 86% movies/finders.py 26 13 50% movies/listers.py 8 0 100% - movies/tests.py 23 0 100% + movies/tests.py 24 0 100% ------------------------------------------ - TOTAL 89 30 66% + TOTAL 90 30 67% .. note:: diff --git a/docs/wiring.rst b/docs/wiring.rst index 2de708d0..74026879 100644 --- a/docs/wiring.rst +++ b/docs/wiring.rst @@ -64,7 +64,7 @@ FastAPI example: @app.api_route("/") @inject - async def index(service: Service = Depends(Provide[Container.service])): + async def index(service: Annotated[Service, Depends(Provide[Container.service])]): value = await service.process() return {"result": value} diff --git a/examples/miniapps/fastapi-redis/fastapiredis/application.py b/examples/miniapps/fastapi-redis/fastapiredis/application.py index f8e4a3bb..52f14366 100644 --- a/examples/miniapps/fastapi-redis/fastapiredis/application.py +++ b/examples/miniapps/fastapi-redis/fastapiredis/application.py @@ -1,18 +1,22 @@ """Application module.""" -from dependency_injector.wiring import inject, Provide -from fastapi import FastAPI, Depends +from typing import Annotated + +from fastapi import Depends, FastAPI + +from dependency_injector.wiring import Provide, inject from .containers import Container from .services import Service - app = FastAPI() @app.api_route("/") @inject -async def index(service: Service = Depends(Provide[Container.service])): +async def index( + service: Annotated[Service, Depends(Provide[Container.service])] +) -> dict[str, str]: value = await service.process() return {"result": value} diff --git a/examples/miniapps/fastapi-simple/fastapi_di_example.py b/examples/miniapps/fastapi-simple/fastapi_di_example.py index 9f3d3f83..6d50499c 100644 --- a/examples/miniapps/fastapi-simple/fastapi_di_example.py +++ b/examples/miniapps/fastapi-simple/fastapi_di_example.py @@ -1,4 +1,7 @@ -from fastapi import FastAPI, Depends +from typing import Annotated + +from fastapi import Depends, FastAPI + from dependency_injector import containers, providers from dependency_injector.wiring import Provide, inject @@ -18,7 +21,9 @@ class Container(containers.DeclarativeContainer): @app.api_route("/") @inject -async def index(service: Service = Depends(Provide[Container.service])): +async def index( + service: Annotated[Service, Depends(Provide[Container.service])] +) -> dict[str, str]: result = await service.process() return {"result": result} diff --git a/examples/miniapps/fastapi-sqlalchemy/webapp/endpoints.py b/examples/miniapps/fastapi-sqlalchemy/webapp/endpoints.py index 4d27101e..e02c2740 100644 --- a/examples/miniapps/fastapi-sqlalchemy/webapp/endpoints.py +++ b/examples/miniapps/fastapi-sqlalchemy/webapp/endpoints.py @@ -1,11 +1,14 @@ """Endpoints module.""" +from typing import Annotated + from fastapi import APIRouter, Depends, Response, status -from dependency_injector.wiring import inject, Provide + +from dependency_injector.wiring import Provide, inject from .containers import Container -from .services import UserService from .repositories import NotFoundError +from .services import UserService router = APIRouter() @@ -13,7 +16,7 @@ @router.get("/users") @inject def get_list( - user_service: UserService = Depends(Provide[Container.user_service]), + user_service: Annotated[UserService, Depends(Provide[Container.user_service])], ): return user_service.get_users() @@ -21,8 +24,8 @@ def get_list( @router.get("/users/{user_id}") @inject def get_by_id( - user_id: int, - user_service: UserService = Depends(Provide[Container.user_service]), + user_id: int, + user_service: Annotated[UserService, Depends(Provide[Container.user_service])], ): try: return user_service.get_user_by_id(user_id) @@ -33,7 +36,7 @@ def get_by_id( @router.post("/users", status_code=status.HTTP_201_CREATED) @inject def add( - user_service: UserService = Depends(Provide[Container.user_service]), + user_service: Annotated[UserService, Depends(Provide[Container.user_service])], ): return user_service.create_user() @@ -41,9 +44,9 @@ def add( @router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @inject def remove( - user_id: int, - user_service: UserService = Depends(Provide[Container.user_service]), -): + user_id: int, + user_service: Annotated[UserService, Depends(Provide[Container.user_service])], +) -> Response: try: user_service.delete_user_by_id(user_id) except NotFoundError: diff --git a/examples/miniapps/fastapi/giphynavigator/endpoints.py b/examples/miniapps/fastapi/giphynavigator/endpoints.py index 2761f203..904eb71d 100644 --- a/examples/miniapps/fastapi/giphynavigator/endpoints.py +++ b/examples/miniapps/fastapi/giphynavigator/endpoints.py @@ -1,13 +1,14 @@ """Endpoints module.""" -from typing import Optional, List +from typing import Annotated, List from fastapi import APIRouter, Depends from pydantic import BaseModel -from dependency_injector.wiring import inject, Provide -from .services import SearchService +from dependency_injector.wiring import Provide, inject + from .containers import Container +from .services import SearchService class Gif(BaseModel): @@ -26,11 +27,15 @@ class Response(BaseModel): @router.get("/", response_model=Response) @inject async def index( - query: Optional[str] = None, - limit: Optional[str] = None, - default_query: str = Depends(Provide[Container.config.default.query]), - default_limit: int = Depends(Provide[Container.config.default.limit.as_int()]), - search_service: SearchService = Depends(Provide[Container.search_service]), + default_query: Annotated[str, Depends(Provide[Container.config.default.query])], + default_limit: Annotated[ + int, Depends(Provide[Container.config.default.limit.as_int()]) + ], + search_service: Annotated[ + SearchService, Depends(Provide[Container.search_service]) + ], + query: str | None = None, + limit: int | None = None, ): query = query or default_query limit = limit or default_limit diff --git a/examples/miniapps/movie-lister/movies/tests.py b/examples/miniapps/movie-lister/movies/tests.py index 1b29d824..1a133a5f 100644 --- a/examples/miniapps/movie-lister/movies/tests.py +++ b/examples/miniapps/movie-lister/movies/tests.py @@ -26,13 +26,18 @@ def container(): return container -def test_movies_directed_by(container): +@pytest.fixture +def finder_mock(container): finder_mock = mock.Mock() finder_mock.find_all.return_value = [ container.movie("The 33", 2015, "Patricia Riggen"), container.movie("The Jungle Book", 2016, "Jon Favreau"), ] + return finder_mock + + +def test_movies_directed_by(container, finder_mock): with container.finder.override(finder_mock): lister = container.lister() movies = lister.movies_directed_by("Jon Favreau") @@ -41,13 +46,7 @@ def test_movies_directed_by(container): assert movies[0].title == "The Jungle Book" -def test_movies_released_in(container): - finder_mock = mock.Mock() - finder_mock.find_all.return_value = [ - container.movie("The 33", 2015, "Patricia Riggen"), - container.movie("The Jungle Book", 2016, "Jon Favreau"), - ] - +def test_movies_released_in(container, finder_mock): with container.finder.override(finder_mock): lister = container.lister() movies = lister.movies_released_in(2015) diff --git a/examples/wiring/example.py b/examples/wiring/example.py index 4221ab13..0e32b192 100644 --- a/examples/wiring/example.py +++ b/examples/wiring/example.py @@ -2,10 +2,10 @@ from dependency_injector import containers, providers from dependency_injector.wiring import Provide, inject +from typing import Annotated -class Service: - ... +class Service: ... class Container(containers.DeclarativeContainer): @@ -13,9 +13,16 @@ class Container(containers.DeclarativeContainer): service = providers.Factory(Service) +# You can place marker on parameter default value @inject -def main(service: Service = Provide[Container.service]) -> None: - ... +def main(service: Service = Provide[Container.service]) -> None: ... + + +# Also, you can place marker with typing.Annotated +@inject +def main_with_annotated( + service: Annotated[Service, Provide[Container.service]] +) -> None: ... if __name__ == "__main__": diff --git a/requirements-dev.txt b/requirements-dev.txt index bc533741..0d759d4e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,5 +18,6 @@ numpy scipy boto3 mypy_boto3_s3 +typing_extensions -r requirements-ext.txt diff --git a/setup.cfg b/setup.cfg index cb18dacf..9bb1e56b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,6 +2,7 @@ max_line_length = 120 max_complexity = 10 exclude = types.py +extend-ignore = E203,E701 per-file-ignores = examples/demo/*: F841 examples/containers/traverse.py: E501 diff --git a/src/dependency_injector/__init__.py b/src/dependency_injector/__init__.py index 14e3c273..faa905f5 100644 --- a/src/dependency_injector/__init__.py +++ b/src/dependency_injector/__init__.py @@ -1,6 +1,6 @@ """Top-level package.""" -__version__ = "4.45.0" +__version__ = "4.46.0" """Version number. :type: str diff --git a/src/dependency_injector/_cwiring.pyx b/src/dependency_injector/_cwiring.pyx index 88b6bc5a..84a5485f 100644 --- a/src/dependency_injector/_cwiring.pyx +++ b/src/dependency_injector/_cwiring.pyx @@ -2,44 +2,39 @@ import asyncio import collections.abc -import functools import inspect import types -from . import providers -from .wiring import _Marker, PatchedCallable +from .wiring import _Marker -from .providers cimport Provider +from .providers cimport Provider, Resource -def _get_sync_patched(fn, patched: PatchedCallable): - @functools.wraps(fn) - def _patched(*args, **kwargs): - cdef object result - cdef dict to_inject - cdef object arg_key - cdef Provider provider +def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): + cdef object result + cdef dict to_inject + cdef object arg_key + cdef Provider provider - to_inject = kwargs.copy() - for arg_key, provider in patched.injections.items(): - if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): - to_inject[arg_key] = provider() + to_inject = kwargs.copy() + for arg_key, provider in injections.items(): + if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): + to_inject[arg_key] = provider() - result = fn(*args, **to_inject) + result = fn(*args, **to_inject) - if patched.closing: - for arg_key, provider in patched.closing.items(): - if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): - continue - if not isinstance(provider, providers.Resource): - continue - provider.shutdown() + if closings: + for arg_key, provider in closings.items(): + if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): + continue + if not isinstance(provider, Resource): + continue + provider.shutdown() - return result - return _patched + return result -async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings): +async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): cdef object result cdef dict to_inject cdef list to_inject_await = [] @@ -69,7 +64,7 @@ async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dic for arg_key, provider in closings.items(): if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker): continue - if not isinstance(provider, providers.Resource): + if not isinstance(provider, Resource): continue shutdown = provider.shutdown() if _isawaitable(shutdown): diff --git a/src/dependency_injector/containers.pyi b/src/dependency_injector/containers.pyi index e773bfb3..4b40fbba 100644 --- a/src/dependency_injector/containers.pyi +++ b/src/dependency_injector/containers.pyi @@ -19,21 +19,24 @@ from typing import ( from .providers import Provider, Self, ProviderParent - C_Base = TypeVar("C_Base", bound="Container") C = TypeVar("C", bound="DeclarativeContainer") C_Overriding = TypeVar("C_Overriding", bound="DeclarativeContainer") T = TypeVar("T") TT = TypeVar("TT") - class WiringConfiguration: modules: List[Any] packages: List[Any] from_package: Optional[str] auto_wire: bool - def __init__(self, modules: Optional[Iterable[Any]] = None, packages: Optional[Iterable[Any]] = None, from_package: Optional[str] = None, auto_wire: bool = True) -> None: ... - + def __init__( + self, + modules: Optional[Iterable[Any]] = None, + packages: Optional[Iterable[Any]] = None, + from_package: Optional[str] = None, + auto_wire: bool = True, + ) -> None: ... class Container: provider_type: Type[Provider] = Provider @@ -51,11 +54,18 @@ class Container: def set_providers(self, **providers: Provider): ... def set_provider(self, name: str, provider: Provider) -> None: ... def override(self, overriding: Union[Container, Type[Container]]) -> None: ... - def override_providers(self, **overriding_providers: Union[Provider, Any]) -> ProvidersOverridingContext[C_Base]: ... + def override_providers( + self, **overriding_providers: Union[Provider, Any] + ) -> ProvidersOverridingContext[C_Base]: ... def reset_last_overriding(self) -> None: ... def reset_override(self) -> None: ... def is_auto_wiring_enabled(self) -> bool: ... - def wire(self, modules: Optional[Iterable[Any]] = None, packages: Optional[Iterable[Any]] = None, from_package: Optional[str] = None) -> None: ... + def wire( + self, + modules: Optional[Iterable[Any]] = None, + packages: Optional[Iterable[Any]] = None, + from_package: Optional[str] = None, + ) -> None: ... def unwire(self) -> None: ... def init_resources(self) -> Optional[Awaitable]: ... def shutdown_resources(self) -> Optional[Awaitable]: ... @@ -64,7 +74,9 @@ class Container: def reset_singletons(self) -> SingletonResetContext[C_Base]: ... def check_dependencies(self) -> None: ... def from_schema(self, schema: Dict[Any, Any]) -> None: ... - def from_yaml_schema(self, filepath: Union[Path, str], loader: Optional[Any]=None) -> None: ... + def from_yaml_schema( + self, filepath: Union[Path, str], loader: Optional[Any] = None + ) -> None: ... def from_json_schema(self, filepath: Union[Path, str]) -> None: ... @overload def resolve_provider_name(self, provider: Provider) -> str: ... @@ -82,10 +94,8 @@ class Container: @overload def traverse(cls, types: Optional[Iterable[Type[TT]]] = None) -> Iterator[TT]: ... - class DynamicContainer(Container): ... - class DeclarativeContainer(Container): cls_providers: ClassVar[Dict[str, Provider]] inherited_providers: ClassVar[Dict[str, Provider]] @@ -93,29 +103,28 @@ class DeclarativeContainer(Container): @classmethod def override(cls, overriding: Union[Container, Type[Container]]) -> None: ... @classmethod - def override_providers(cls, **overriding_providers: Union[Provider, Any]) -> ProvidersOverridingContext[C_Base]: ... + def override_providers( + cls, **overriding_providers: Union[Provider, Any] + ) -> ProvidersOverridingContext[C_Base]: ... @classmethod def reset_last_overriding(cls) -> None: ... @classmethod def reset_override(cls) -> None: ... - class ProvidersOverridingContext(Generic[T]): - def __init__(self, container: T, overridden_providers: Iterable[Union[Provider, Any]]) -> None: ... + def __init__( + self, container: T, overridden_providers: Iterable[Union[Provider, Any]] + ) -> None: ... def __enter__(self) -> T: ... def __exit__(self, *_: Any) -> None: ... - class SingletonResetContext(Generic[T]): def __init__(self, container: T): ... def __enter__(self) -> T: ... def __exit__(self, *_: Any) -> None: ... - -def override(container: Type[C]) -> _Callable[[Type[C_Overriding]], Type[C_Overriding]]: ... - - +def override( + container: Type[C], +) -> _Callable[[Type[C_Overriding]], Type[C_Overriding]]: ... def copy(container: Type[C]) -> _Callable[[Type[C_Overriding]], Type[C_Overriding]]: ... - - def is_container(instance: Any) -> bool: ... diff --git a/src/dependency_injector/ext/aiohttp.py b/src/dependency_injector/ext/aiohttp.py index b132f362..976089c3 100644 --- a/src/dependency_injector/ext/aiohttp.py +++ b/src/dependency_injector/ext/aiohttp.py @@ -38,9 +38,11 @@ class View(providers.Callable): def as_view(self): """Return aiohttp view function.""" + @functools.wraps(self.provides) async def _view(request, *args, **kwargs): return await self.__call__(request, *args, **kwargs) + return _view @@ -49,6 +51,8 @@ class ClassBasedView(providers.Factory): def as_view(self): """Return aiohttp view function.""" + async def _view(request, *args, **kwargs): return await self.__call__(request, *args, **kwargs) + return _view diff --git a/src/dependency_injector/ext/aiohttp.pyi b/src/dependency_injector/ext/aiohttp.pyi index 05c68acc..370cc9b0 100644 --- a/src/dependency_injector/ext/aiohttp.pyi +++ b/src/dependency_injector/ext/aiohttp.pyi @@ -2,22 +2,13 @@ from typing import Awaitable as _Awaitable from dependency_injector import providers - class Application(providers.Singleton): ... - - class Extension(providers.Singleton): ... - - class Middleware(providers.DelegatedCallable): ... - - class MiddlewareFactory(providers.Factory): ... - class View(providers.Callable): def as_view(self) -> _Awaitable: ... - class ClassBasedView(providers.Factory): def as_view(self) -> _Awaitable: ... diff --git a/src/dependency_injector/ext/flask.py b/src/dependency_injector/ext/flask.py index b30b8797..498a9eee 100644 --- a/src/dependency_injector/ext/flask.py +++ b/src/dependency_injector/ext/flask.py @@ -45,6 +45,7 @@ def as_view(self, name): def as_view(provider, name=None): """Transform class-based view provider to view function.""" if isinstance(provider, providers.Factory): + def view(*args, **kwargs): self = provider() return self.dispatch_request(*args, **kwargs) @@ -52,12 +53,13 @@ def view(*args, **kwargs): assert name, 'Argument "endpoint" is required for class-based views' view.__name__ = name elif isinstance(provider, providers.Callable): + def view(*args, **kwargs): return provider(*args, **kwargs) view.__name__ = provider.provides.__name__ else: - raise errors.Error('Undefined provider type') + raise errors.Error("Undefined provider type") view.__doc__ = provider.provides.__doc__ view.__module__ = provider.provides.__module__ @@ -65,14 +67,14 @@ def view(*args, **kwargs): if isinstance(provider.provides, type): view.view_class = provider.provides - if hasattr(provider.provides, 'decorators'): + if hasattr(provider.provides, "decorators"): for decorator in provider.provides.decorators: view = decorator(view) - if hasattr(provider.provides, 'methods'): + if hasattr(provider.provides, "methods"): view.methods = provider.provides.methods - if hasattr(provider.provides, 'provide_automatic_options'): + if hasattr(provider.provides, "provide_automatic_options"): view.provide_automatic_options = provider.provides.provide_automatic_options return view diff --git a/src/dependency_injector/ext/flask.pyi b/src/dependency_injector/ext/flask.pyi index 37bffd36..9b180c89 100644 --- a/src/dependency_injector/ext/flask.pyi +++ b/src/dependency_injector/ext/flask.pyi @@ -3,22 +3,17 @@ from typing import Union, Optional, Callable as _Callable, Any from flask import request as flask_request from dependency_injector import providers - request: providers.Object[flask_request] - class Application(providers.Singleton): ... - - class Extension(providers.Singleton): ... - class View(providers.Callable): def as_view(self) -> _Callable[..., Any]: ... - class ClassBasedView(providers.Factory): def as_view(self, name: str) -> _Callable[..., Any]: ... - -def as_view(provider: Union[View, ClassBasedView], name: Optional[str] = None) -> _Callable[..., Any]: ... +def as_view( + provider: Union[View, ClassBasedView], name: Optional[str] = None +) -> _Callable[..., Any]: ... diff --git a/src/dependency_injector/providers.pyi b/src/dependency_injector/providers.pyi index b7fbf211..e4d62506 100644 --- a/src/dependency_injector/providers.pyi +++ b/src/dependency_injector/providers.pyi @@ -33,7 +33,6 @@ except ImportError: from . import resources - Injection = Any ProviderParent = Union["Provider", Any] T = TypeVar("T") @@ -41,16 +40,13 @@ TT = TypeVar("TT") P = TypeVar("P", bound="Provider") BS = TypeVar("BS", bound="BaseSingleton") - class Provider(Generic[T]): def __init__(self) -> None: ... - @overload def __call__(self, *args: Injection, **kwargs: Injection) -> T: ... @overload def __call__(self, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ... def async_(self, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ... - def __deepcopy__(self, memo: Optional[_Dict[Any, Any]]) -> Provider: ... def __str__(self) -> str: ... def __repr__(self) -> str: ... @@ -67,9 +63,9 @@ class Provider(Generic[T]): def unregister_overrides(self, provider: Union[Provider, Any]) -> None: ... def delegate(self) -> Provider: ... @property - def provider(self) -> Provider: ... + def provider(self) -> Provider[T]: ... @property - def provided(self) -> ProvidedInstance: ... + def provided(self) -> ProvidedInstance[T]: ... def enable_async_mode(self) -> None: ... def disable_async_mode(self) -> None: ... def reset_async_mode(self) -> None: ... @@ -78,9 +74,12 @@ class Provider(Generic[T]): def is_async_mode_undefined(self) -> bool: ... @property def related(self) -> _Iterator[Provider]: ... - def traverse(self, types: Optional[_Iterable[Type[TT]]] = None) -> _Iterator[TT]: ... - def _copy_overridings(self, copied: Provider, memo: Optional[_Dict[Any, Any]]) -> None: ... - + def traverse( + self, types: Optional[_Iterable[Type[TT]]] = None + ) -> _Iterator[TT]: ... + def _copy_overridings( + self, copied: Provider, memo: Optional[_Dict[Any, Any]] + ) -> None: ... class Object(Provider[T]): def __init__(self, provides: Optional[T] = None) -> None: ... @@ -88,7 +87,6 @@ class Object(Provider[T]): def provides(self) -> Optional[T]: ... def set_provides(self, provides: Optional[T]) -> Object: ... - class Self(Provider[T]): def __init__(self, container: Optional[T] = None) -> None: ... def set_container(self, container: T) -> None: ... @@ -96,41 +94,51 @@ class Self(Provider[T]): @property def alt_names(self) -> Tuple[Any]: ... - class Delegate(Provider[Provider]): def __init__(self, provides: Optional[Provider] = None) -> None: ... @property def provides(self) -> Optional[Provider]: ... def set_provides(self, provides: Optional[Provider]) -> Delegate: ... - class Aggregate(Provider[T]): - def __init__(self, provider_dict: Optional[_Dict[Any, Provider[T]]] = None, **provider_kwargs: Provider[T]): ... + def __init__( + self, + provider_dict: Optional[_Dict[Any, Provider[T]]] = None, + **provider_kwargs: Provider[T], + ): ... def __getattr__(self, provider_name: Any) -> Provider[T]: ... - @overload - def __call__(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> T: ... + def __call__( + self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection + ) -> T: ... @overload - def __call__(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ... - def async_(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ... - + def __call__( + self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection + ) -> Awaitable[T]: ... + def async_( + self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection + ) -> Awaitable[T]: ... @property def providers(self) -> _Dict[Any, Provider[T]]: ... - def set_providers(self, provider_dict: Optional[_Dict[Any, Provider[T]]] = None, **provider_kwargs: Provider[T]) -> Aggregate[T]: ... - + def set_providers( + self, + provider_dict: Optional[_Dict[Any, Provider[T]]] = None, + **provider_kwargs: Provider[T], + ) -> Aggregate[T]: ... class Dependency(Provider[T]): - def __init__(self, instance_of: Type[T] = object, default: Optional[Union[Provider, Any]] = None) -> None: ... + def __init__( + self, + instance_of: Type[T] = object, + default: Optional[Union[Provider, Any]] = None, + ) -> None: ... def __getattr__(self, name: str) -> Any: ... - @property def instance_of(self) -> Type[T]: ... def set_instance_of(self, instance_of: Type[T]) -> Dependency[T]: ... - @property def default(self) -> Provider[T]: ... def set_default(self, default: Optional[Union[Provider, Any]]) -> Dependency[T]: ... - @property def is_defined(self) -> bool: ... def provided_by(self, provider: Provider) -> OverridingContext[P]: ... @@ -140,10 +148,8 @@ class Dependency(Provider[T]): def parent_name(self) -> Optional[str]: ... def assign_parent(self, parent: ProviderParent) -> None: ... - class ExternalDependency(Dependency[T]): ... - class DependenciesContainer(Object): def __init__(self, **dependencies: Provider) -> None: ... def __getattr__(self, name: str) -> Provider: ... @@ -156,12 +162,18 @@ class DependenciesContainer(Object): def parent_name(self) -> Optional[str]: ... def assign_parent(self, parent: ProviderParent) -> None: ... - class Callable(Provider[T]): - def __init__(self, provides: Optional[Union[_Callable[..., T], str]] = None, *args: Injection, **kwargs: Injection) -> None: ... + def __init__( + self, + provides: Optional[Union[_Callable[..., T], str]] = None, + *args: Injection, + **kwargs: Injection, + ) -> None: ... @property def provides(self) -> Optional[_Callable[..., T]]: ... - def set_provides(self, provides: Optional[Union[_Callable[..., T], str]]) -> Callable[T]: ... + def set_provides( + self, provides: Optional[Union[_Callable[..., T], str]] + ) -> Callable[T]: ... @property def args(self) -> Tuple[Injection]: ... def add_args(self, *args: Injection) -> Callable[T]: ... @@ -173,32 +185,23 @@ class Callable(Provider[T]): def set_kwargs(self, **kwargs: Injection) -> Callable[T]: ... def clear_kwargs(self) -> Callable[T]: ... - class DelegatedCallable(Callable[T]): ... - class AbstractCallable(Callable[T]): def override(self, provider: Callable) -> OverridingContext[P]: ... - class CallableDelegate(Delegate): def __init__(self, callable: Callable) -> None: ... - class Coroutine(Callable[T]): ... - - class DelegatedCoroutine(Coroutine[T]): ... - class AbstractCoroutine(Coroutine[T]): def override(self, provider: Coroutine) -> OverridingContext[P]: ... - class CoroutineDelegate(Delegate): def __init__(self, coroutine: Coroutine) -> None: ... - class ConfigurationOption(Provider[Any]): UNDEFINED: object def __init__(self, name: Tuple[str], root: Configuration) -> None: ... @@ -212,89 +215,137 @@ class ConfigurationOption(Provider[Any]): def get_name_segments(self) -> Tuple[Union[str, Provider]]: ... def as_int(self) -> TypedConfigurationOption[int]: ... def as_float(self) -> TypedConfigurationOption[float]: ... - def as_(self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection) -> TypedConfigurationOption[T]: ... + def as_( + self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection + ) -> TypedConfigurationOption[T]: ... def required(self) -> ConfigurationOption: ... def is_required(self) -> bool: ... def update(self, value: Any) -> None: ... - def from_ini(self, filepath: Union[Path, str], required: bool = False, envs_required: bool = False) -> None: ... - def from_yaml(self, filepath: Union[Path, str], required: bool = False, loader: Optional[Any] = None, envs_required: bool = False) -> None: ... - def from_json(self, filepath: Union[Path, str], required: bool = False, envs_required: bool = False) -> None: ... - def from_pydantic(self, settings: PydanticSettings, required: bool = False, **kwargs: Any) -> None: ... + def from_ini( + self, + filepath: Union[Path, str], + required: bool = False, + envs_required: Optional[bool] = False, + ) -> None: ... + def from_yaml( + self, + filepath: Union[Path, str], + required: bool = False, + loader: Optional[Any] = None, + envs_required: Optional[bool] = False, + ) -> None: ... + def from_json( + self, + filepath: Union[Path, str], + required: bool = False, + envs_required: Optional[bool] = False, + ) -> None: ... + def from_pydantic( + self, settings: PydanticSettings, required: bool = False, **kwargs: Any + ) -> None: ... def from_dict(self, options: _Dict[str, Any], required: bool = False) -> None: ... - def from_env(self, name: str, default: Optional[Any] = None, required: bool = False, as_: Optional[_Callable[..., Any]] = None) -> None: ... + def from_env( + self, + name: str, + default: Optional[Any] = None, + required: bool = False, + as_: Optional[_Callable[..., Any]] = None, + ) -> None: ... def from_value(self, value: Any) -> None: ... - class TypedConfigurationOption(Callable[T]): @property def option(self) -> ConfigurationOption: ... - class Configuration(Object[Any]): DEFAULT_NAME: str = "config" def __init__( - self, - name: str = DEFAULT_NAME, - default: Optional[Any] = None, - *, - strict: bool = False, - ini_files: Optional[_Iterable[Union[Path, str]]] = None, - yaml_files: Optional[_Iterable[Union[Path, str]]] = None, - json_files: Optional[_Iterable[Union[Path, str]]] = None, - pydantic_settings: Optional[_Iterable[PydanticSettings]] = None, + self, + name: str = DEFAULT_NAME, + default: Optional[Any] = None, + *, + strict: bool = False, + ini_files: Optional[_Iterable[Union[Path, str]]] = None, + yaml_files: Optional[_Iterable[Union[Path, str]]] = None, + json_files: Optional[_Iterable[Union[Path, str]]] = None, + pydantic_settings: Optional[_Iterable[PydanticSettings]] = None, ) -> None: ... - def __enter__(self) -> Configuration : ... + def __enter__(self) -> Configuration: ... def __exit__(self, *exc_info: Any) -> None: ... def __getattr__(self, item: str) -> ConfigurationOption: ... def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ... - def get_name(self) -> str: ... def set_name(self, name: str) -> Configuration: ... - def get_default(self) -> _Dict[Any, Any]: ... def set_default(self, default: _Dict[Any, Any]): ... - def get_strict(self) -> bool: ... def set_strict(self, strict: bool) -> Configuration: ... - def get_children(self) -> _Dict[str, ConfigurationOption]: ... - def set_children(self, children: _Dict[str, ConfigurationOption]) -> Configuration: ... - + def set_children( + self, children: _Dict[str, ConfigurationOption] + ) -> Configuration: ... def get_ini_files(self) -> _List[Union[Path, str]]: ... def set_ini_files(self, files: _Iterable[Union[Path, str]]) -> Configuration: ... - def get_yaml_files(self) -> _List[Union[Path, str]]: ... def set_yaml_files(self, files: _Iterable[Union[Path, str]]) -> Configuration: ... - def get_json_files(self) -> _List[Union[Path, str]]: ... def set_json_files(self, files: _Iterable[Union[Path, str]]) -> Configuration: ... - def get_pydantic_settings(self) -> _List[PydanticSettings]: ... - def set_pydantic_settings(self, settings: _Iterable[PydanticSettings]) -> Configuration: ... - + def set_pydantic_settings( + self, settings: _Iterable[PydanticSettings] + ) -> Configuration: ... def load(self, required: bool = False, envs_required: bool = False) -> None: ... - def get(self, selector: str) -> Any: ... def set(self, selector: str, value: Any) -> OverridingContext[P]: ... def reset_cache(self) -> None: ... def update(self, value: Any) -> None: ... - def from_ini(self, filepath: Union[Path, str], required: bool = False, envs_required: bool = False) -> None: ... - def from_yaml(self, filepath: Union[Path, str], required: bool = False, loader: Optional[Any] = None, envs_required: bool = False) -> None: ... - def from_json(self, filepath: Union[Path, str], required: bool = False, envs_required: bool = False) -> None: ... - def from_pydantic(self, settings: PydanticSettings, required: bool = False, **kwargs: Any) -> None: ... + def from_ini( + self, + filepath: Union[Path, str], + required: bool = False, + envs_required: bool = False, + ) -> None: ... + def from_yaml( + self, + filepath: Union[Path, str], + required: bool = False, + loader: Optional[Any] = None, + envs_required: bool = False, + ) -> None: ... + def from_json( + self, + filepath: Union[Path, str], + required: bool = False, + envs_required: bool = False, + ) -> None: ... + def from_pydantic( + self, settings: PydanticSettings, required: bool = False, **kwargs: Any + ) -> None: ... def from_dict(self, options: _Dict[str, Any], required: bool = False) -> None: ... - def from_env(self, name: str, default: Optional[Any] = None, required: bool = False, as_: Optional[_Callable[..., Any]] = None) -> None: ... + def from_env( + self, + name: str, + default: Optional[Any] = None, + required: bool = False, + as_: Optional[_Callable[..., Any]] = None, + ) -> None: ... def from_value(self, value: Any) -> None: ... - class Factory(Provider[T]): provided_type: Optional[Type] - def __init__(self, provides: Optional[Union[_Callable[..., T], str]] = None, *args: Injection, **kwargs: Injection) -> None: ... + def __init__( + self, + provides: Optional[Union[_Callable[..., T], str]] = None, + *args: Injection, + **kwargs: Injection, + ) -> None: ... @property def cls(self) -> Type[T]: ... @property def provides(self) -> Optional[_Callable[..., T]]: ... - def set_provides(self, provides: Optional[Union[_Callable[..., T], str]]) -> Factory[T]: ... + def set_provides( + self, provides: Optional[Union[_Callable[..., T], str]] + ) -> Factory[T]: ... @property def args(self) -> Tuple[Injection]: ... def add_args(self, *args: Injection) -> Factory[T]: ... @@ -311,33 +362,39 @@ class Factory(Provider[T]): def set_attributes(self, **kwargs: Injection) -> Factory[T]: ... def clear_attributes(self) -> Factory[T]: ... - class DelegatedFactory(Factory[T]): ... - class AbstractFactory(Factory[T]): def override(self, provider: Factory) -> OverridingContext[P]: ... - class FactoryDelegate(Delegate): def __init__(self, factory: Factory): ... - class FactoryAggregate(Aggregate[T]): def __getattr__(self, provider_name: Any) -> Factory[T]: ... @property def factories(self) -> _Dict[Any, Factory[T]]: ... - def set_factories(self, provider_dict: Optional[_Dict[Any, Factory[T]]] = None, **provider_kwargs: Factory[T]) -> FactoryAggregate[T]: ... - + def set_factories( + self, + provider_dict: Optional[_Dict[Any, Factory[T]]] = None, + **provider_kwargs: Factory[T], + ) -> FactoryAggregate[T]: ... class BaseSingleton(Provider[T]): provided_type = Optional[Type] - def __init__(self, provides: Optional[Union[_Callable[..., T], str]] = None, *args: Injection, **kwargs: Injection) -> None: ... + def __init__( + self, + provides: Optional[Union[_Callable[..., T], str]] = None, + *args: Injection, + **kwargs: Injection, + ) -> None: ... @property def cls(self) -> Type[T]: ... @property def provides(self) -> Optional[_Callable[..., T]]: ... - def set_provides(self, provides: Optional[Union[_Callable[..., T], str]]) -> BaseSingleton[T]: ... + def set_provides( + self, provides: Optional[Union[_Callable[..., T], str]] + ) -> BaseSingleton[T]: ... @property def args(self) -> Tuple[Injection]: ... def add_args(self, *args: Injection) -> BaseSingleton[T]: ... @@ -356,36 +413,20 @@ class BaseSingleton(Provider[T]): def reset(self) -> SingletonResetContext[BS]: ... def full_reset(self) -> SingletonFullResetContext[BS]: ... - class Singleton(BaseSingleton[T]): ... - - class DelegatedSingleton(Singleton[T]): ... - - class ThreadSafeSingleton(Singleton[T]): ... - - class DelegatedThreadSafeSingleton(ThreadSafeSingleton[T]): ... - - class ThreadLocalSingleton(BaseSingleton[T]): ... - - class ContextLocalSingleton(BaseSingleton[T]): ... - - class DelegatedThreadLocalSingleton(ThreadLocalSingleton[T]): ... - class AbstractSingleton(BaseSingleton[T]): def override(self, provider: BaseSingleton) -> OverridingContext[P]: ... - class SingletonDelegate(Delegate): def __init__(self, singleton: BaseSingleton): ... - class List(Provider[_List]): def __init__(self, *args: Injection): ... @property @@ -394,29 +435,63 @@ class List(Provider[_List]): def set_args(self, *args: Injection) -> List[T]: ... def clear_args(self) -> List[T]: ... - class Dict(Provider[_Dict]): - def __init__(self, dict_: Optional[_Dict[Any, Injection]] = None, **kwargs: Injection): ... + def __init__( + self, dict_: Optional[_Dict[Any, Injection]] = None, **kwargs: Injection + ): ... @property def kwargs(self) -> _Dict[Any, Injection]: ... - def add_kwargs(self, dict_: Optional[_Dict[Any, Injection]] = None, **kwargs: Injection) -> Dict: ... - def set_kwargs(self, dict_: Optional[_Dict[Any, Injection]] = None, **kwargs: Injection) -> Dict: ... + def add_kwargs( + self, dict_: Optional[_Dict[Any, Injection]] = None, **kwargs: Injection + ) -> Dict: ... + def set_kwargs( + self, dict_: Optional[_Dict[Any, Injection]] = None, **kwargs: Injection + ) -> Dict: ... def clear_kwargs(self) -> Dict: ... - class Resource(Provider[T]): @overload - def __init__(self, provides: Optional[Type[resources.Resource[T]]] = None, *args: Injection, **kwargs: Injection) -> None: ... + def __init__( + self, + provides: Optional[Type[resources.Resource[T]]] = None, + *args: Injection, + **kwargs: Injection, + ) -> None: ... @overload - def __init__(self, provides: Optional[Type[resources.AsyncResource[T]]] = None, *args: Injection, **kwargs: Injection) -> None: ... + def __init__( + self, + provides: Optional[Type[resources.AsyncResource[T]]] = None, + *args: Injection, + **kwargs: Injection, + ) -> None: ... @overload - def __init__(self, provides: Optional[_Callable[..., _Iterator[T]]] = None, *args: Injection, **kwargs: Injection) -> None: ... + def __init__( + self, + provides: Optional[_Callable[..., _Iterator[T]]] = None, + *args: Injection, + **kwargs: Injection, + ) -> None: ... @overload - def __init__(self, provides: Optional[_Callable[..., _AsyncIterator[T]]] = None, *args: Injection, **kwargs: Injection) -> None: ... + def __init__( + self, + provides: Optional[_Callable[..., _AsyncIterator[T]]] = None, + *args: Injection, + **kwargs: Injection, + ) -> None: ... @overload - def __init__(self, provides: Optional[_Callable[..., _Coroutine[Injection, Injection, T]]] = None, *args: Injection, **kwargs: Injection) -> None: ... + def __init__( + self, + provides: Optional[_Callable[..., _Coroutine[Injection, Injection, T]]] = None, + *args: Injection, + **kwargs: Injection, + ) -> None: ... @overload - def __init__(self, provides: Optional[Union[_Callable[..., T], str]] = None, *args: Injection, **kwargs: Injection) -> None: ... + def __init__( + self, + provides: Optional[Union[_Callable[..., T], str]] = None, + *args: Injection, + **kwargs: Injection, + ) -> None: ... @property def provides(self) -> Optional[_Callable[..., Any]]: ... def set_provides(self, provides: Optional[Any]) -> Resource[T]: ... @@ -435,9 +510,13 @@ class Resource(Provider[T]): def init(self) -> Optional[Awaitable[T]]: ... def shutdown(self) -> Optional[Awaitable]: ... - class Container(Provider[T]): - def __init__(self, container_cls: Type[T], container: Optional[T] = None, **overriding_providers: Union[Provider, Any]) -> None: ... + def __init__( + self, + container_cls: Type[T], + container: Optional[T] = None, + **overriding_providers: Union[Provider, Any], + ) -> None: ... def __getattr__(self, name: str) -> Provider: ... @property def container(self) -> T: ... @@ -448,50 +527,51 @@ class Container(Provider[T]): def parent_name(self) -> Optional[str]: ... def assign_parent(self, parent: ProviderParent) -> None: ... - class Selector(Provider[Any]): - def __init__(self, selector: Optional[_Callable[..., Any]] = None, **providers: Provider): ... + def __init__( + self, selector: Optional[_Callable[..., Any]] = None, **providers: Provider + ): ... def __getattr__(self, name: str) -> Provider: ... - @property def selector(self) -> Optional[_Callable[..., Any]]: ... def set_selector(self, selector: Optional[_Callable[..., Any]]) -> Selector: ... - @property def providers(self) -> _Dict[str, Provider]: ... def set_providers(self, **providers: Provider) -> Selector: ... - class ProvidedInstanceFluentInterface: def __getattr__(self, item: Any) -> AttributeGetter: ... def __getitem__(self, item: Any) -> ItemGetter: ... def call(self, *args: Injection, **kwargs: Injection) -> MethodCaller: ... @property def provides(self) -> Optional[Provider]: ... - def set_provides(self, provides: Optional[Provider]) -> ProvidedInstanceFluentInterface: ... - + def set_provides( + self, provides: Optional[Provider] + ) -> ProvidedInstanceFluentInterface: ... class ProvidedInstance(Provider, ProvidedInstanceFluentInterface): def __init__(self, provides: Optional[Provider] = None) -> None: ... - class AttributeGetter(Provider, ProvidedInstanceFluentInterface): - def __init__(self, provides: Optional[Provider] = None, name: Optional[str] = None) -> None: ... + def __init__( + self, provides: Optional[Provider] = None, name: Optional[str] = None + ) -> None: ... @property def name(self) -> Optional[str]: ... def set_name(self, name: Optional[str]) -> ProvidedInstanceFluentInterface: ... - class ItemGetter(Provider, ProvidedInstanceFluentInterface): - def __init__(self, provides: Optional[Provider] = None, name: Optional[str] = None) -> None: ... + def __init__( + self, provides: Optional[Provider] = None, name: Optional[str] = None + ) -> None: ... @property def name(self) -> Optional[str]: ... def set_name(self, name: Optional[str]) -> ProvidedInstanceFluentInterface: ... - class MethodCaller(Provider, ProvidedInstanceFluentInterface): - def __init__(self, provides: Optional[Provider] = None, *args: Injection, **kwargs: Injection) -> None: ... - + def __init__( + self, provides: Optional[Provider] = None, *args: Injection, **kwargs: Injection + ) -> None: ... class OverridingContext(Generic[T]): def __init__(self, overridden: Provider, overriding: Provider): ... @@ -500,61 +580,39 @@ class OverridingContext(Generic[T]): pass ... - class BaseSingletonResetContext(Generic[T]): def __init__(self, provider: T): ... def __enter__(self) -> T: ... def __exit__(self, *_: Any) -> None: ... - -class SingletonResetContext(BaseSingletonResetContext): - ... - - -class SingletonFullResetContext(BaseSingletonResetContext): - ... - +class SingletonResetContext(BaseSingletonResetContext): ... +class SingletonFullResetContext(BaseSingletonResetContext): ... CHILD_PROVIDERS: Tuple[Provider] - def is_provider(instance: Any) -> bool: ... - - def ensure_is_provider(instance: Any) -> Provider: ... - - def is_delegated(instance: Any) -> bool: ... - - def represent_provider(provider: Provider, provides: Any) -> str: ... - - def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None) -> Any: ... - - def deepcopy_args( provider: Provider[Any], args: Tuple[Any, ...], memo: Optional[_Dict[int, Any]] = None, ) -> Tuple[Any, ...]: ... - - def deepcopy_kwargs( provider: Provider[Any], kwargs: _Dict[str, Any], memo: Optional[_Dict[int, Any]] = None, ) -> Dict[str, Any]: ... - - def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ... - - -def traverse(*providers: Provider, types: Optional[_Iterable[Type]]=None) -> _Iterator[Provider]: ... - +def traverse( + *providers: Provider, types: Optional[_Iterable[Type]] = None +) -> _Iterator[Provider]: ... if yaml: class YamlLoader(yaml.SafeLoader): ... + else: class YamlLoader: ... diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx index 39716ea0..a3620350 100644 --- a/src/dependency_injector/providers.pyx +++ b/src/dependency_injector/providers.pyx @@ -2,6 +2,7 @@ from __future__ import absolute_import +import asyncio import copy import errno import functools @@ -14,6 +15,7 @@ import sys import threading import types import warnings +from configparser import ConfigParser as IniConfigParser try: import contextvars @@ -27,21 +29,18 @@ except ImportError: import __builtin__ as builtins try: - import asyncio + from inspect import _is_coroutine_mark as _is_coroutine_marker except ImportError: - asyncio = None - _is_coroutine_marker = None -else: - if sys.version_info >= (3, 5, 3): - import asyncio.coroutines - _is_coroutine_marker = asyncio.coroutines._is_coroutine - else: + try: + # Python >=3.12.0,<3.12.5 + from inspect import _is_coroutine_marker + except ImportError: _is_coroutine_marker = True try: - import ConfigParser as iniconfigparser + from asyncio.coroutines import _is_coroutine except ImportError: - import configparser as iniconfigparser + _is_coroutine = True try: import yaml @@ -99,7 +98,7 @@ config_env_marker_pattern = re.compile( r"\${(?P[^}^{:]+)(?P:?)(?P.*?)}", ) -def _resolve_config_env_markers(config_content, envs_required=False): +cdef str _resolve_config_env_markers(config_content: str, envs_required: bool): """Replace environment variable markers with their values.""" findings = list(config_env_marker_pattern.finditer(config_content)) @@ -118,28 +117,19 @@ def _resolve_config_env_markers(config_content, envs_required=False): return config_content -if sys.version_info[0] == 3: - def _parse_ini_file(filepath, envs_required=False): - parser = iniconfigparser.ConfigParser() - with open(filepath) as config_file: - config_string = _resolve_config_env_markers( - config_file.read(), - envs_required=envs_required, - ) - parser.read_string(config_string) - return parser -else: - import StringIO +cdef object _parse_ini_file(filepath, envs_required: bool | None): + parser = IniConfigParser() + + with open(filepath) as config_file: + config_string = config_file.read() - def _parse_ini_file(filepath, envs_required=False): - parser = iniconfigparser.ConfigParser() - with open(filepath) as config_file: + if envs_required is not None: config_string = _resolve_config_env_markers( - config_file.read(), + config_string, envs_required=envs_required, ) - parser.readfp(StringIO.StringIO(config_string)) - return parser + parser.read_string(config_string) + return parser if yaml: @@ -1475,7 +1465,8 @@ cdef class Coroutine(Callable): some_coroutine.add_kwargs(keyword_argument1=3, keyword_argument=4) """ - _is_coroutine = _is_coroutine_marker + _is_coroutine_marker = _is_coroutine_marker # Python >=3.12 + _is_coroutine = _is_coroutine # Python <3.16 def set_provides(self, provides): """Set provider provides.""" @@ -1713,7 +1704,7 @@ cdef class ConfigurationOption(Provider): try: parser = _parse_ini_file( filepath, - envs_required=envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), + envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), ) except IOError as exception: if required is not False \ @@ -1772,10 +1763,11 @@ cdef class ConfigurationOption(Provider): raise return - config_content = _resolve_config_env_markers( - config_content, - envs_required=envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), - ) + if envs_required is not None: + config_content = _resolve_config_env_markers( + config_content, + envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), + ) config = yaml.load(config_content, loader) current_config = self.__call__() @@ -1810,10 +1802,11 @@ cdef class ConfigurationOption(Provider): raise return - config_content = _resolve_config_env_markers( - config_content, - envs_required=envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), - ) + if envs_required is not None: + config_content = _resolve_config_env_markers( + config_content, + envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), + ) config = json.loads(config_content) current_config = self.__call__() @@ -2266,7 +2259,7 @@ cdef class Configuration(Object): try: parser = _parse_ini_file( filepath, - envs_required=envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), + envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), ) except IOError as exception: if required is not False \ @@ -2325,10 +2318,11 @@ cdef class Configuration(Object): raise return - config_content = _resolve_config_env_markers( - config_content, - envs_required=envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), - ) + if envs_required is not None: + config_content = _resolve_config_env_markers( + config_content, + envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), + ) config = yaml.load(config_content, loader) current_config = self.__call__() @@ -2363,10 +2357,11 @@ cdef class Configuration(Object): raise return - config_content = _resolve_config_env_markers( - config_content, - envs_required=envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), - ) + if envs_required is not None: + config_content = _resolve_config_env_markers( + config_content, + envs_required if envs_required is not UNDEFINED else self._is_strict_mode_enabled(), + ) config = json.loads(config_content) current_config = self.__call__() diff --git a/src/dependency_injector/resources.py b/src/dependency_injector/resources.py index b5946cfa..7d71d4d8 100644 --- a/src/dependency_injector/resources.py +++ b/src/dependency_injector/resources.py @@ -10,18 +10,14 @@ class Resource(Generic[T], metaclass=abc.ABCMeta): @abc.abstractmethod - def init(self, *args, **kwargs) -> Optional[T]: - ... + def init(self, *args, **kwargs) -> Optional[T]: ... - def shutdown(self, resource: Optional[T]) -> None: - ... + def shutdown(self, resource: Optional[T]) -> None: ... class AsyncResource(Generic[T], metaclass=abc.ABCMeta): @abc.abstractmethod - async def init(self, *args, **kwargs) -> Optional[T]: - ... + async def init(self, *args, **kwargs) -> Optional[T]: ... - async def shutdown(self, resource: Optional[T]) -> None: - ... + async def shutdown(self, resource: Optional[T]) -> None: ... diff --git a/src/dependency_injector/schema.py b/src/dependency_injector/schema.py index c224e2d1..8547ebc2 100644 --- a/src/dependency_injector/schema.py +++ b/src/dependency_injector/schema.py @@ -27,9 +27,9 @@ def get_providers(self): return self._container.providers def _create_providers( - self, - provider_schema: ProviderSchema, - container: Optional[containers.Container] = None, + self, + provider_schema: ProviderSchema, + container: Optional[containers.Container] = None, ) -> None: if container is None: container = self._container @@ -57,9 +57,9 @@ def _create_providers( self._create_providers(provider_schema=data, container=provider) def _setup_injections( # noqa: C901 - self, - provider_schema: ProviderSchema, - container: Optional[containers.Container] = None, + self, + provider_schema: ProviderSchema, + container: Optional[containers.Container] = None, ) -> None: if container is None: container = self._container @@ -72,7 +72,7 @@ def _setup_injections( # noqa: C901 provides = data.get("provides") if provides: if isinstance(provides, str) and provides.startswith("container."): - provides = self._resolve_provider(provides[len("container."):]) + provides = self._resolve_provider(provides[len("container.") :]) else: provides = _import_string(provides) provider.set_provides(provides) @@ -83,7 +83,7 @@ def _setup_injections( # noqa: C901 injection = None if isinstance(arg, str) and arg.startswith("container."): - injection = self._resolve_provider(arg[len("container."):]) + injection = self._resolve_provider(arg[len("container.") :]) # TODO: refactoring if isinstance(arg, dict): @@ -91,16 +91,23 @@ def _setup_injections( # noqa: C901 provider_type = _get_provider_cls(arg.get("provider")) provides = arg.get("provides") if provides: - if isinstance(provides, str) and provides.startswith("container."): - provides = self._resolve_provider(provides[len("container."):]) + if isinstance(provides, str) and provides.startswith( + "container." + ): + provides = self._resolve_provider( + provides[len("container.") :] + ) else: provides = _import_string(provides) provider_args.append(provides) for provider_arg in arg.get("args", []): - if isinstance(provider_arg, str) \ - and provider_arg.startswith("container."): + if isinstance( + provider_arg, str + ) and provider_arg.startswith("container."): provider_args.append( - self._resolve_provider(provider_arg[len("container."):]), + self._resolve_provider( + provider_arg[len("container.") :] + ), ) injection = provider_type(*provider_args) @@ -117,7 +124,7 @@ def _setup_injections( # noqa: C901 injection = None if isinstance(arg, str) and arg.startswith("container."): - injection = self._resolve_provider(arg[len("container."):]) + injection = self._resolve_provider(arg[len("container.") :]) # TODO: refactoring if isinstance(arg, dict): @@ -125,16 +132,23 @@ def _setup_injections( # noqa: C901 provider_type = _get_provider_cls(arg.get("provider")) provides = arg.get("provides") if provides: - if isinstance(provides, str) and provides.startswith("container."): - provides = self._resolve_provider(provides[len("container."):]) + if isinstance(provides, str) and provides.startswith( + "container." + ): + provides = self._resolve_provider( + provides[len("container.") :] + ) else: provides = _import_string(provides) provider_args.append(provides) for provider_arg in arg.get("args", []): - if isinstance(provider_arg, str) \ - and provider_arg.startswith("container."): + if isinstance( + provider_arg, str + ) and provider_arg.startswith("container."): provider_args.append( - self._resolve_provider(provider_arg[len("container."):]), + self._resolve_provider( + provider_arg[len("container.") :] + ), ) injection = provider_type(*provider_args) @@ -158,7 +172,7 @@ def _resolve_provider(self, name: str) -> Optional[providers.Provider]: for segment in segments[1:]: parentheses = "" if "(" in segment and ")" in segment: - parentheses = segment[segment.find("("):segment.rfind(")")+1] + parentheses = segment[segment.find("(") : segment.rfind(")") + 1] segment = segment.replace(parentheses, "") try: @@ -190,10 +204,12 @@ def _get_provider_cls(provider_cls_name: str) -> Type[providers.Provider]: if custom_provider_type: return custom_provider_type - raise SchemaError(f"Undefined provider class \"{provider_cls_name}\"") + raise SchemaError(f'Undefined provider class "{provider_cls_name}"') -def _fetch_provider_cls_from_std(provider_cls_name: str) -> Optional[Type[providers.Provider]]: +def _fetch_provider_cls_from_std( + provider_cls_name: str, +) -> Optional[Type[providers.Provider]]: return getattr(providers, provider_cls_name, None) @@ -201,12 +217,16 @@ def _import_provider_cls(provider_cls_name: str) -> Optional[Type[providers.Prov try: cls = _import_string(provider_cls_name) except (ImportError, ValueError) as exception: - raise SchemaError(f"Can not import provider \"{provider_cls_name}\"") from exception + raise SchemaError( + f'Can not import provider "{provider_cls_name}"' + ) from exception except AttributeError: return None else: if isinstance(cls, type) and not issubclass(cls, providers.Provider): - raise SchemaError(f"Provider class \"{cls}\" is not a subclass of providers base class") + raise SchemaError( + f'Provider class "{cls}" is not a subclass of providers base class' + ) return cls diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index b1f01622..5cded9f5 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -1,34 +1,35 @@ """Wiring module.""" import functools -import inspect import importlib import importlib.machinery +import inspect import pkgutil -import warnings import sys +import warnings from types import ModuleType from typing import ( - Optional, - Iterable, - Iterator, - Callable, Any, - Tuple, + Callable, Dict, Generic, - TypeVar, + Iterable, + Iterator, + Optional, + Set, + Tuple, Type, + TypeVar, Union, - Set, cast, ) if sys.version_info < (3, 7): from typing import GenericMeta else: - class GenericMeta(type): - ... + + class GenericMeta(type): ... + # Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362 if sys.version_info >= (3, 9): @@ -36,6 +37,21 @@ class GenericMeta(type): else: GenericAlias = None +if sys.version_info >= (3, 9): + from typing import Annotated, get_args, get_origin +else: + try: + from typing_extensions import Annotated, get_args, get_origin + except ImportError: + Annotated = object() + + # For preventing NameError. Never executes + def get_args(hint): + return () + + def get_origin(tp): + return None + try: import fastapi.params @@ -99,7 +115,9 @@ def __init__(self) -> None: def register_callable(self, patched: "PatchedCallable") -> None: self._callables[patched.patched] = patched - def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]: + def get_callables_from_module( + self, module: ModuleType + ) -> Iterator[Callable[..., Any]]: for patched_callable in self._callables.values(): if not patched_callable.is_in_module(module): continue @@ -114,7 +132,9 @@ def has_callable(self, fn: Callable[..., Any]) -> bool: def register_attribute(self, patched: "PatchedAttribute") -> None: self._attributes.add(patched) - def get_attributes_from_module(self, module: ModuleType) -> Iterator["PatchedAttribute"]: + def get_attributes_from_module( + self, module: ModuleType + ) -> Iterator["PatchedAttribute"]: for attribute in self._attributes: if not attribute.is_in_module(module): continue @@ -139,11 +159,11 @@ class PatchedCallable: ) def __init__( - self, - patched: Optional[Callable[..., Any]] = None, - original: Optional[Callable[..., Any]] = None, - reference_injections: Optional[Dict[Any, Any]] = None, - reference_closing: Optional[Dict[Any, Any]] = None, + self, + patched: Optional[Callable[..., Any]] = None, + original: Optional[Callable[..., Any]] = None, + reference_injections: Optional[Dict[Any, Any]] = None, + reference_closing: Optional[Dict[Any, Any]] = None, ) -> None: self.patched = patched self.original = original @@ -214,18 +234,21 @@ def __init__(self, container) -> None: ) def resolve_provider( - self, - provider: Union[providers.Provider, str], - modifier: Optional["Modifier"] = None, + self, + provider: Union[providers.Provider, str], + modifier: Optional["Modifier"] = None, ) -> Optional[providers.Provider]: if isinstance(provider, providers.Delegate): return self._resolve_delegate(provider) - elif isinstance(provider, ( - providers.ProvidedInstance, - providers.AttributeGetter, - providers.ItemGetter, - providers.MethodCaller, - )): + elif isinstance( + provider, + ( + providers.ProvidedInstance, + providers.AttributeGetter, + providers.ItemGetter, + providers.MethodCaller, + ), + ): return self._resolve_provided_instance(provider) elif isinstance(provider, providers.ConfigurationOption): return self._resolve_config_option(provider) @@ -237,9 +260,9 @@ def resolve_provider( return self._resolve_provider(provider) def _resolve_string_id( - self, - id: str, - modifier: Optional["Modifier"] = None, + self, + id: str, + modifier: Optional["Modifier"] = None, ) -> Optional[providers.Provider]: if id == self.CONTAINER_STRING_ID: return self._container.__self__ @@ -256,16 +279,19 @@ def _resolve_string_id( return provider def _resolve_provided_instance( - self, - original: providers.Provider, + self, + original: providers.Provider, ) -> Optional[providers.Provider]: modifiers = [] - while isinstance(original, ( + while isinstance( + original, + ( providers.ProvidedInstance, providers.AttributeGetter, providers.ItemGetter, providers.MethodCaller, - )): + ), + ): modifiers.insert(0, original) original = original.provides @@ -289,8 +315,8 @@ def _resolve_provided_instance( return new def _resolve_delegate( - self, - original: providers.Delegate, + self, + original: providers.Delegate, ) -> Optional[providers.Provider]: provider = self._resolve_provider(original.provides) if provider: @@ -298,9 +324,9 @@ def _resolve_delegate( return provider def _resolve_config_option( - self, - original: providers.ConfigurationOption, - as_: Any = None, + self, + original: providers.ConfigurationOption, + as_: Any = None, ) -> Optional[providers.Provider]: original_root = original.root new = self._resolve_provider(original_root) @@ -324,8 +350,8 @@ def _resolve_config_option( return new def _resolve_provider( - self, - original: providers.Provider, + self, + original: providers.Provider, ) -> Optional[providers.Provider]: try: return self._map[original] @@ -334,9 +360,9 @@ def _resolve_provider( @classmethod def _create_providers_map( - cls, - current_container: Container, - original_container: Container, + cls, + current_container: Container, + original_container: Container, ) -> Dict[providers.Provider, providers.Provider]: current_providers = current_container.providers current_providers["__self__"] = current_container.__self__ @@ -349,8 +375,9 @@ def _create_providers_map( original_provider = original_providers[provider_name] providers_map[original_provider] = current_provider - if isinstance(current_provider, providers.Container) \ - and isinstance(original_provider, providers.Container): + if isinstance(current_provider, providers.Container) and isinstance( + original_provider, providers.Container + ): subcontainer_map = cls._create_providers_map( current_container=current_provider.container, original_container=original_provider.container, @@ -376,19 +403,21 @@ def _is_werkzeug_local_proxy(self, instance: object) -> bool: return werkzeug and isinstance(instance, werkzeug.local.LocalProxy) def _is_starlette_request_cls(self, instance: object) -> bool: - return starlette \ - and isinstance(instance, type) \ - and _safe_is_subclass(instance, starlette.requests.Request) + return ( + starlette + and isinstance(instance, type) + and _safe_is_subclass(instance, starlette.requests.Request) + ) def _is_builtin(self, instance: object) -> bool: return inspect.isbuiltin(instance) def wire( # noqa: C901 - container: Container, - *, - modules: Optional[Iterable[ModuleType]] = None, - packages: Optional[Iterable[ModuleType]] = None, + container: Container, + *, + modules: Optional[Iterable[ModuleType]] = None, + packages: Optional[Iterable[ModuleType]] = None, ) -> None: """Wire container providers with provided packages and modules.""" modules = [*modules] if modules else [] @@ -418,18 +447,22 @@ def wire( # noqa: C901 else: for cls_member_name, cls_member in cls_members: if _is_marker(cls_member): - _patch_attribute(cls, cls_member_name, cls_member, providers_map) + _patch_attribute( + cls, cls_member_name, cls_member, providers_map + ) elif _is_method(cls_member): - _patch_method(cls, cls_member_name, cls_member, providers_map) + _patch_method( + cls, cls_member_name, cls_member, providers_map + ) for patched in _patched_registry.get_callables_from_module(module): _bind_injections(patched, providers_map) def unwire( # noqa: C901 - *, - modules: Optional[Iterable[ModuleType]] = None, - packages: Optional[Iterable[ModuleType]] = None, + *, + modules: Optional[Iterable[ModuleType]] = None, + packages: Optional[Iterable[ModuleType]] = None, ) -> None: """Wire provided packages and modules with previous wired providers.""" modules = [*modules] if modules else [] @@ -443,7 +476,9 @@ def unwire( # noqa: C901 if inspect.isfunction(member): _unpatch(module, name, member) elif inspect.isclass(member): - for method_name, method in inspect.getmembers(member, inspect.isfunction): + for method_name, method in inspect.getmembers( + member, inspect.isfunction + ): _unpatch(member, method_name, method) for patched in _patched_registry.get_callables_from_module(module): @@ -462,10 +497,10 @@ def inject(fn: F) -> F: def _patch_fn( - module: ModuleType, - name: str, - fn: Callable[..., Any], - providers_map: ProvidersMap, + module: ModuleType, + name: str, + fn: Callable[..., Any], + providers_map: ProvidersMap, ) -> None: if not _is_patched(fn): reference_injections, reference_closing = _fetch_reference_injections(fn) @@ -479,14 +514,16 @@ def _patch_fn( def _patch_method( - cls: Type, - name: str, - method: Callable[..., Any], - providers_map: ProvidersMap, + cls: Type, + name: str, + method: Callable[..., Any], + providers_map: ProvidersMap, ) -> None: - if hasattr(cls, "__dict__") \ - and name in cls.__dict__ \ - and isinstance(cls.__dict__[name], (classmethod, staticmethod)): + if ( + hasattr(cls, "__dict__") + and name in cls.__dict__ + and isinstance(cls.__dict__[name], (classmethod, staticmethod)) + ): method = cls.__dict__[name] fn = method.__func__ else: @@ -507,13 +544,15 @@ def _patch_method( def _unpatch( - module: ModuleType, - name: str, - fn: Callable[..., Any], + module: ModuleType, + name: str, + fn: Callable[..., Any], ) -> None: - if hasattr(module, "__dict__") \ - and name in module.__dict__ \ - and isinstance(module.__dict__[name], (classmethod, staticmethod)): + if ( + hasattr(module, "__dict__") + and name in module.__dict__ + and isinstance(module.__dict__[name], (classmethod, staticmethod)) + ): method = module.__dict__[name] fn = method.__func__ @@ -524,10 +563,10 @@ def _unpatch( def _patch_attribute( - member: Any, - name: str, - marker: "_Marker", - providers_map: ProvidersMap, + member: Any, + name: str, + marker: "_Marker", + providers_map: ProvidersMap, ) -> None: provider = providers_map.resolve_provider(marker.provider, marker.modifier) if provider is None: @@ -548,16 +587,33 @@ def _unpatch_attribute(patched: PatchedAttribute) -> None: setattr(patched.member, patched.name, patched.marker) +def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]: + if get_origin(parameter.annotation) is Annotated: + marker = get_args(parameter.annotation)[1] + else: + marker = parameter.default + + if not isinstance(marker, _Marker) and not _is_fastapi_depends(marker): + return None + + if _is_fastapi_depends(marker): + marker = marker.dependency + + if not isinstance(marker, _Marker): + return None + + return marker + + def _fetch_reference_injections( # noqa: C901 - fn: Callable[..., Any], + fn: Callable[..., Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: # Hotfix, see: # - https://github.com/ets-labs/python-dependency-injector/issues/362 # - https://github.com/ets-labs/python-dependency-injector/issues/398 - if GenericAlias and any(( - fn is GenericAlias, - getattr(fn, "__func__", None) is GenericAlias - )): + if GenericAlias and any( + (fn is GenericAlias, getattr(fn, "__func__", None) is GenericAlias) + ): fn = fn.__init__ try: @@ -573,17 +629,10 @@ def _fetch_reference_injections( # noqa: C901 injections = {} closing = {} for parameter_name, parameter in signature.parameters.items(): - if not isinstance(parameter.default, _Marker) \ - and not _is_fastapi_depends(parameter.default): - continue - - marker = parameter.default + marker = _extract_marker(parameter) - if _is_fastapi_depends(marker): - marker = marker.dependency - - if not isinstance(marker, _Marker): - continue + if marker is None: + continue if isinstance(marker, Closing): marker = marker.provider @@ -593,20 +642,19 @@ def _fetch_reference_injections( # noqa: C901 return injections, closing -def _locate_dependent_closing_args(provider: providers.Provider) -> Dict[str, providers.Provider]: - if not hasattr(provider, "args"): - return {} - - closing_deps = {} - for arg in provider.args: - if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"): +def _locate_dependent_closing_args( + provider: providers.Provider, closing_deps: Dict[str, providers.Provider] +) -> Dict[str, providers.Provider]: + for arg in [ + *getattr(provider, "args", []), + *getattr(provider, "kwargs", {}).values(), + ]: + if not isinstance(arg, providers.Provider): continue + if isinstance(arg, providers.Resource): + closing_deps[str(id(arg))] = arg - if not arg.args and isinstance(arg, providers.Resource): - return {str(id(arg)): arg} - else: - closing_deps += _locate_dependent_closing_args(arg) - return closing_deps + _locate_dependent_closing_args(arg, closing_deps) def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None: @@ -630,7 +678,8 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non if injection in patched_callable.reference_closing: patched_callable.add_closing(injection, provider) - deps = _locate_dependent_closing_args(provider) + deps = {} + _locate_dependent_closing_args(provider, deps) for key, dep in deps.items(): patched_callable.add_closing(key, dep) @@ -647,8 +696,8 @@ def _fetch_modules(package): if not hasattr(package, "__path__") or not hasattr(package, "__name__"): return modules for module_info in pkgutil.walk_packages( - path=package.__path__, - prefix=package.__name__ + ".", + path=package.__path__, + prefix=package.__name__ + ".", ): module = importlib.import_module(module_info.name) modules.append(module) @@ -664,9 +713,9 @@ def _is_marker(member) -> bool: def _get_patched( - fn: F, - reference_injections: Dict[Any, Any], - reference_closing: Dict[Any, Any], + fn: F, + reference_injections: Dict[Any, Any], + reference_closing: Dict[Any, Any], ) -> F: patched_object = PatchedCallable( original=fn, @@ -694,9 +743,11 @@ def _is_patched(fn) -> bool: def _is_declarative_container(instance: Any) -> bool: - return (isinstance(instance, type) - and getattr(instance, "__IS_CONTAINER__", False) is True - and getattr(instance, "declarative_parent", None) is None) + return ( + isinstance(instance, type) + and getattr(instance, "__IS_CONTAINER__", False) is True + and getattr(instance, "declarative_parent", None) is None + ) def _safe_is_subclass(instance: Any, cls: Type) -> bool: @@ -709,11 +760,10 @@ def _safe_is_subclass(instance: Any, cls: Type) -> bool: class Modifier: def modify( - self, - provider: providers.ConfigurationOption, - providers_map: ProvidersMap, - ) -> providers.Provider: - ... + self, + provider: providers.ConfigurationOption, + providers_map: ProvidersMap, + ) -> providers.Provider: ... class TypeModifier(Modifier): @@ -722,9 +772,9 @@ def __init__(self, type_: Type) -> None: self.type_ = type_ def modify( - self, - provider: providers.ConfigurationOption, - providers_map: ProvidersMap, + self, + provider: providers.ConfigurationOption, + providers_map: ProvidersMap, ) -> providers.Provider: return provider.as_(self.type_) @@ -762,9 +812,9 @@ def as_(self, type_: Type) -> "RequiredModifier": return self def modify( - self, - provider: providers.ConfigurationOption, - providers_map: ProvidersMap, + self, + provider: providers.ConfigurationOption, + providers_map: ProvidersMap, ) -> providers.Provider: provider = provider.required() if self.type_modifier: @@ -783,9 +833,9 @@ def __init__(self, id: str) -> None: self.id = id def modify( - self, - provider: providers.ConfigurationOption, - providers_map: ProvidersMap, + self, + provider: providers.ConfigurationOption, + providers_map: ProvidersMap, ) -> providers.Provider: invariant_segment = providers_map.resolve_provider(self.id) return provider[invariant_segment] @@ -818,9 +868,9 @@ def call(self): return self def modify( - self, - provider: providers.Provider, - providers_map: ProvidersMap, + self, + provider: providers.Provider, + providers_map: ProvidersMap, ) -> providers.Provider: provider = provider.provided for type_, value in self.segments: @@ -851,9 +901,9 @@ class _Marker(Generic[T], metaclass=ClassGetItemMeta): __IS_MARKER__ = True def __init__( - self, - provider: Union[providers.Provider, Container, str], - modifier: Optional[Modifier] = None, + self, + provider: Union[providers.Provider, Container, str], + modifier: Optional[Modifier] = None, ) -> None: if _is_declarative_container(provider): provider = provider.__self__ @@ -869,16 +919,13 @@ def __call__(self) -> T: return self -class Provide(_Marker): - ... +class Provide(_Marker): ... -class Provider(_Marker): - ... +class Provider(_Marker): ... -class Closing(_Marker): - ... +class Closing(_Marker): ... class AutoLoader: @@ -928,8 +975,7 @@ def exec_module(self, module): super().exec_module(module) loader.wire_module(module) - class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader): - ... + class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader): ... loader_details = [ (SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES), @@ -982,7 +1028,7 @@ def is_loader_installed() -> bool: _loader = AutoLoader() # Optimizations -from ._cwiring import _get_sync_patched # noqa +from ._cwiring import _sync_inject # noqa from ._cwiring import _async_inject # noqa @@ -998,4 +1044,18 @@ async def _patched(*args, **kwargs): patched.injections, patched.closing, ) - return _patched + + return cast(F, _patched) + + +def _get_sync_patched(fn: F, patched: PatchedCallable) -> F: + @functools.wraps(fn) + def _patched(*args, **kwargs): + return _sync_inject( + fn, + args, + kwargs, + patched.injections, + patched.closing, + ) + return cast(F, _patched) diff --git a/tests/typing/callable.py b/tests/typing/callable.py index 51c8d3a7..09a5a3f8 100644 --- a/tests/typing/callable.py +++ b/tests/typing/callable.py @@ -34,7 +34,7 @@ def create(cls) -> Animal: # Test 5: to check the provided instance interface provider5 = providers.Callable(Animal) -provided5: providers.ProvidedInstance = provider5.provided +provided5: Animal = provider5.provided() attr_getter5: providers.AttributeGetter = provider5.provided.attr item_getter5: providers.ItemGetter = provider5.provided["item"] method_caller: providers.MethodCaller = provider5.provided.method.call(123, arg=324) diff --git a/tests/typing/dict.py b/tests/typing/dict.py index 31205067..f676b67d 100644 --- a/tests/typing/dict.py +++ b/tests/typing/dict.py @@ -34,7 +34,7 @@ a1=providers.Factory(object), a2=providers.Factory(object), ) -provided5: providers.ProvidedInstance = provider5.provided +provided5: dict[Any, Any] = provider5.provided() # Test 6: to check the return type with await diff --git a/tests/typing/factory.py b/tests/typing/factory.py index 132a4c29..089bef07 100644 --- a/tests/typing/factory.py +++ b/tests/typing/factory.py @@ -37,7 +37,7 @@ def create(cls) -> Animal: # Test 5: to check the provided instance interface provider5 = providers.Factory(Animal) -provided5: providers.ProvidedInstance = provider5.provided +provided5: Animal = provider5.provided() attr_getter5: providers.AttributeGetter = provider5.provided.attr item_getter5: providers.ItemGetter = provider5.provided["item"] method_caller5: providers.MethodCaller = provider5.provided.method.call(123, arg=324) diff --git a/tests/typing/list.py b/tests/typing/list.py index 3ceae7cc..d29baadb 100644 --- a/tests/typing/list.py +++ b/tests/typing/list.py @@ -23,7 +23,7 @@ providers.Factory(object), providers.Factory(object), ) -provided3: providers.ProvidedInstance = provider3.provided +provided3: List[Any] = provider3.provided() attr_getter3: providers.AttributeGetter = provider3.provided.attr item_getter3: providers.ItemGetter = provider3.provided["item"] method_caller3: providers.MethodCaller = provider3.provided.method.call(123, arg=324) diff --git a/tests/typing/object.py b/tests/typing/object.py index b099c83b..103071ae 100644 --- a/tests/typing/object.py +++ b/tests/typing/object.py @@ -9,7 +9,7 @@ # Test 2: to check the provided instance interface provider2 = providers.Object(int) -provided2: providers.ProvidedInstance = provider2.provided +provided2: Type[int] = provider2.provided() attr_getter2: providers.AttributeGetter = provider2.provided.attr item_getter2: providers.ItemGetter = provider2.provided["item"] method_caller2: providers.MethodCaller = provider2.provided.method.call(123, arg=324) diff --git a/tests/typing/provider.py b/tests/typing/provider.py index 46dd23a0..18afd69a 100644 --- a/tests/typing/provider.py +++ b/tests/typing/provider.py @@ -3,7 +3,8 @@ # Test 1: to check .provided attribute provider1: providers.Provider[int] = providers.Object(1) -provided: providers.ProvidedInstance = provider1.provided +provided: int = provider1.provided() +provider1_delegate: providers.Provider[int] = provider1.provider # Test 2: to check async mode API provider2: providers.Provider = providers.Provider() diff --git a/tests/typing/singleton.py b/tests/typing/singleton.py index badfe1c6..8dc3df23 100644 --- a/tests/typing/singleton.py +++ b/tests/typing/singleton.py @@ -37,7 +37,7 @@ def create(cls) -> Animal: # Test 5: to check the provided instance interface provider5 = providers.Singleton(Animal) -provided5: providers.ProvidedInstance = provider5.provided +provided5: Animal = provider5.provided() attr_getter5: providers.AttributeGetter = provider5.provided.attr item_getter5: providers.ItemGetter = provider5.provided["item"] method_caller5: providers.MethodCaller = provider5.provided.method.call(123, arg=324) diff --git a/tests/unit/providers/configuration/test_from_ini_with_env_py2_py3.py b/tests/unit/providers/configuration/test_from_ini_with_env_py2_py3.py index e669ce48..96949a67 100644 --- a/tests/unit/providers/configuration/test_from_ini_with_env_py2_py3.py +++ b/tests/unit/providers/configuration/test_from_ini_with_env_py2_py3.py @@ -5,6 +5,23 @@ from pytest import mark, raises +def test_no_env_variable_interpolation(config, ini_config_file_3): + config.from_ini(ini_config_file_3, envs_required=None) + + assert config() == { + "section1": { + "value1": "${CONFIG_TEST_ENV}", + "value2": "${CONFIG_TEST_PATH}/path", + }, + } + assert config.section1() == { + "value1": "${CONFIG_TEST_ENV}", + "value2": "${CONFIG_TEST_PATH}/path", + } + assert config.section1.value1() == "${CONFIG_TEST_ENV}" + assert config.section1.value2() == "${CONFIG_TEST_PATH}/path" + + def test_env_variable_interpolation(config, ini_config_file_3): config.from_ini(ini_config_file_3) diff --git a/tests/unit/providers/configuration/test_from_json_with_env_py2_py3.py b/tests/unit/providers/configuration/test_from_json_with_env_py2_py3.py index 2bd8f9aa..4ec7c4ea 100644 --- a/tests/unit/providers/configuration/test_from_json_with_env_py2_py3.py +++ b/tests/unit/providers/configuration/test_from_json_with_env_py2_py3.py @@ -6,6 +6,23 @@ from pytest import mark, raises +def test_no_env_variable_interpolation(config, json_config_file_3): + config.from_json(json_config_file_3, envs_required=None) + + assert config() == { + "section1": { + "value1": "${CONFIG_TEST_ENV}", + "value2": "${CONFIG_TEST_PATH}/path", + }, + } + assert config.section1() == { + "value1": "${CONFIG_TEST_ENV}", + "value2": "${CONFIG_TEST_PATH}/path", + } + assert config.section1.value1() == "${CONFIG_TEST_ENV}" + assert config.section1.value2() == "${CONFIG_TEST_PATH}/path" + + def test_env_variable_interpolation(config, json_config_file_3): config.from_json(json_config_file_3) diff --git a/tests/unit/providers/configuration/test_from_yaml_with_env_py2_py3.py b/tests/unit/providers/configuration/test_from_yaml_with_env_py2_py3.py index 8e6e1c0d..c047659e 100644 --- a/tests/unit/providers/configuration/test_from_yaml_with_env_py2_py3.py +++ b/tests/unit/providers/configuration/test_from_yaml_with_env_py2_py3.py @@ -6,6 +6,23 @@ from pytest import mark, raises +def test_no_env_variable_interpolation(config, yaml_config_file_3): + config.from_yaml(yaml_config_file_3, envs_required=None) + + assert config() == { + "section1": { + "value1": "${CONFIG_TEST_ENV}", + "value2": "${CONFIG_TEST_PATH}/path", + }, + } + assert config.section1() == { + "value1": "${CONFIG_TEST_ENV}", + "value2": "${CONFIG_TEST_PATH}/path", + } + assert config.section1.value1() == "${CONFIG_TEST_ENV}" + assert config.section1.value2() == "${CONFIG_TEST_PATH}/path" + + def test_env_variable_interpolation(config, yaml_config_file_3): config.from_yaml(yaml_config_file_3) diff --git a/tests/unit/providers/coroutines/test_coroutine_py35.py b/tests/unit/providers/coroutines/test_coroutine_py35.py index 22e794b1..de0e9c67 100644 --- a/tests/unit/providers/coroutines/test_coroutine_py35.py +++ b/tests/unit/providers/coroutines/test_coroutine_py35.py @@ -1,4 +1,5 @@ """Coroutine provider tests.""" +import sys from dependency_injector import providers, errors from pytest import mark, raises @@ -208,3 +209,17 @@ def test_repr(): "".format(repr(example), hex(id(provider))) ) + + +@mark.skipif(sys.version_info > (3, 15), reason="requires Python<3.16") +def test_asyncio_iscoroutinefunction() -> None: + from asyncio.coroutines import iscoroutinefunction + + assert iscoroutinefunction(providers.Coroutine(example)) + + +@mark.skipif(sys.version_info < (3, 12), reason="requires Python>=3.12") +def test_inspect_iscoroutinefunction() -> None: + from inspect import iscoroutinefunction + + assert iscoroutinefunction(providers.Coroutine(example)) diff --git a/tests/unit/samples/wiringfastapi/web.py b/tests/unit/samples/wiringfastapi/web.py index 3cee5450..c1ed5102 100644 --- a/tests/unit/samples/wiringfastapi/web.py +++ b/tests/unit/samples/wiringfastapi/web.py @@ -1,7 +1,11 @@ import sys +from typing_extensions import Annotated + from fastapi import FastAPI, Depends -from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398 +from fastapi import ( + Request, +) # See: https://github.com/ets-labs/python-dependency-injector/issues/398 from fastapi.security import HTTPBasic, HTTPBasicCredentials from dependency_injector import containers, providers from dependency_injector.wiring import inject, Provide @@ -28,11 +32,16 @@ async def index(service: Service = Depends(Provide[Container.service])): return {"result": result} +@app.api_route("/annotated") +@inject +async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]): + result = await service.process() + return {"result": result} + + @app.get("/auth") @inject -def read_current_user( - credentials: HTTPBasicCredentials = Depends(security) -): +def read_current_user(credentials: HTTPBasicCredentials = Depends(security)): return {"username": credentials.username, "password": credentials.password} diff --git a/tests/unit/samples/wiringflask/web.py b/tests/unit/samples/wiringflask/web.py index f273d8aa..8bb44494 100644 --- a/tests/unit/samples/wiringflask/web.py +++ b/tests/unit/samples/wiringflask/web.py @@ -1,3 +1,5 @@ +from typing_extensions import Annotated + from flask import Flask, jsonify, request, current_app, session, g from dependency_injector import containers, providers from dependency_injector.wiring import inject, Provide @@ -26,5 +28,12 @@ def index(service: Service = Provide[Container.service]): return jsonify({"result": result}) +@app.route("/annotated") +@inject +def annotated(service: Annotated[Service, Provide[Container.service]]): + result = service.process() + return jsonify({"result": result}) + + container = Container() container.wire(modules=[__name__]) diff --git a/tests/unit/samples/wiringstringids/resourceclosing.py b/tests/unit/samples/wiringstringids/resourceclosing.py index 6360e15c..c4d1f20f 100644 --- a/tests/unit/samples/wiringstringids/resourceclosing.py +++ b/tests/unit/samples/wiringstringids/resourceclosing.py @@ -1,41 +1,80 @@ +from typing import Any, Dict, List, Optional + from dependency_injector import containers, providers -from dependency_injector.wiring import inject, Provide, Closing +from dependency_injector.wiring import Closing, Provide, inject + + +class Counter: + def __init__(self) -> None: + self._init = 0 + self._shutdown = 0 + + def init(self) -> None: + self._init += 1 + + def shutdown(self) -> None: + self._shutdown += 1 + + def reset(self) -> None: + self._init = 0 + self._shutdown = 0 class Service: - init_counter: int = 0 - shutdown_counter: int = 0 + def __init__(self, counter: Optional[Counter] = None, **dependencies: Any) -> None: + self.counter = counter or Counter() + self.dependencies = dependencies + + def init(self) -> None: + self.counter.init() - @classmethod - def reset_counter(cls): - cls.init_counter = 0 - cls.shutdown_counter = 0 + def shutdown(self) -> None: + self.counter.shutdown() - @classmethod - def init(cls): - cls.init_counter += 1 + @property + def init_counter(self) -> int: + return self.counter._init - @classmethod - def shutdown(cls): - cls.shutdown_counter += 1 + @property + def shutdown_counter(self) -> int: + return self.counter._shutdown class FactoryService: - def __init__(self, service: Service): + def __init__(self, service: Service, service2: Service): self.service = service + self.service2 = service2 + + +class NestedService: + def __init__(self, factory_service: FactoryService): + self.factory_service = factory_service -def init_service(): - service = Service() +def init_service(counter: Counter, _list: List[int], _dict: Dict[str, int]): + service = Service(counter, _list=_list, _dict=_dict) service.init() yield service service.shutdown() class Container(containers.DeclarativeContainer): - - service = providers.Resource(init_service) - factory_service = providers.Factory(FactoryService, service) + counter = providers.Singleton(Counter) + _list = providers.List( + providers.Callable(lambda a: a, a=1), providers.Callable(lambda b: b, 2) + ) + _dict = providers.Dict( + a=providers.Callable(lambda a: a, a=3), b=providers.Callable(lambda b: b, 4) + ) + service = providers.Resource(init_service, counter, _list, _dict=_dict) + service2 = providers.Resource(init_service, counter, _list, _dict=_dict) + factory_service = providers.Factory(FactoryService, service, service2) + factory_service_kwargs = providers.Factory( + FactoryService, + service=service, + service2=service2, + ) + nested_service = providers.Factory(NestedService, factory_service) @inject @@ -44,5 +83,21 @@ def test_function(service: Service = Closing[Provide["service"]]): @inject -def test_function_dependency(factory: FactoryService = Closing[Provide["factory_service"]]): +def test_function_dependency( + factory: FactoryService = Closing[Provide["factory_service"]], +): + return factory + + +@inject +def test_function_dependency_kwargs( + factory: FactoryService = Closing[Provide["factory_service_kwargs"]], +): return factory + + +@inject +def test_function_nested_dependency( + nested: NestedService = Closing[Provide["nested_service"]], +): + return nested diff --git a/tests/unit/wiring/string_ids/test_main_py36.py b/tests/unit/wiring/string_ids/test_main_py36.py index d4c49fe8..8125481a 100644 --- a/tests/unit/wiring/string_ids/test_main_py36.py +++ b/tests/unit/wiring/string_ids/test_main_py36.py @@ -2,13 +2,13 @@ from decimal import Decimal -from dependency_injector import errors -from dependency_injector.wiring import Closing, Provide, Provider, wire from pytest import fixture, mark, raises - from samples.wiringstringids import module, package, resourceclosing -from samples.wiringstringids.service import Service from samples.wiringstringids.container import Container, SubContainer +from samples.wiringstringids.service import Service + +from dependency_injector import errors +from dependency_injector.wiring import Closing, Provide, Provider, wire @fixture(autouse=True) @@ -34,10 +34,11 @@ def subcontainer(): @fixture -def resourceclosing_container(): +def resourceclosing_container(request): container = resourceclosing.Container() container.wire(modules=[resourceclosing]) - yield container + with container.reset_singletons(): + yield container container.unwire() @@ -274,42 +275,65 @@ def test_wire_multiple_containers(): @mark.usefixtures("resourceclosing_container") def test_closing_resource(): - resourceclosing.Service.reset_counter() - result_1 = resourceclosing.test_function() assert isinstance(result_1, resourceclosing.Service) assert result_1.init_counter == 1 assert result_1.shutdown_counter == 1 + assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}} result_2 = resourceclosing.test_function() assert isinstance(result_2, resourceclosing.Service) assert result_2.init_counter == 2 assert result_2.shutdown_counter == 2 + assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}} assert result_1 is not result_2 @mark.usefixtures("resourceclosing_container") def test_closing_dependency_resource(): - resourceclosing.Service.reset_counter() - result_1 = resourceclosing.test_function_dependency() assert isinstance(result_1, resourceclosing.FactoryService) - assert result_1.service.init_counter == 1 - assert result_1.service.shutdown_counter == 1 + assert result_1.service.init_counter == 2 + assert result_1.service.shutdown_counter == 2 result_2 = resourceclosing.test_function_dependency() + assert isinstance(result_2, resourceclosing.FactoryService) - assert result_2.service.init_counter == 2 - assert result_2.service.shutdown_counter == 2 + assert result_2.service.init_counter == 4 + assert result_2.service.shutdown_counter == 4 + + +@mark.usefixtures("resourceclosing_container") +def test_closing_dependency_resource_kwargs(): + result_1 = resourceclosing.test_function_dependency_kwargs() + assert isinstance(result_1, resourceclosing.FactoryService) + assert result_1.service.init_counter == 2 + assert result_1.service.shutdown_counter == 2 + + result_2 = resourceclosing.test_function_dependency_kwargs() + assert isinstance(result_2, resourceclosing.FactoryService) + assert result_2.service.init_counter == 4 + assert result_2.service.shutdown_counter == 4 + + +@mark.usefixtures("resourceclosing_container") +def test_closing_nested_dependency_resource(): + result_1 = resourceclosing.test_function_nested_dependency() + assert isinstance(result_1, resourceclosing.NestedService) + assert result_1.factory_service.service.init_counter == 2 + assert result_1.factory_service.service.shutdown_counter == 2 + + result_2 = resourceclosing.test_function_nested_dependency() + assert isinstance(result_2, resourceclosing.NestedService) + assert result_2.factory_service.service.init_counter == 4 + assert result_2.factory_service.service.shutdown_counter == 4 assert result_1 is not result_2 @mark.usefixtures("resourceclosing_container") def test_closing_resource_bypass_marker_injection(): - resourceclosing.Service.reset_counter() - result_1 = resourceclosing.test_function(service=Closing[Provide["service"]]) assert isinstance(result_1, resourceclosing.Service) assert result_1.init_counter == 1 @@ -325,7 +349,6 @@ def test_closing_resource_bypass_marker_injection(): @mark.usefixtures("resourceclosing_container") def test_closing_resource_context(): - resourceclosing.Service.reset_counter() service = resourceclosing.Service() result_1 = resourceclosing.test_function(service=service) diff --git a/tests/unit/wiring/test_fastapi_py36.py b/tests/unit/wiring/test_fastapi_py36.py index 1e9ff584..491c991c 100644 --- a/tests/unit/wiring/test_fastapi_py36.py +++ b/tests/unit/wiring/test_fastapi_py36.py @@ -4,13 +4,17 @@ # Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir import os + _SAMPLES_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../samples/", - )), + os.path.sep.join( + ( + os.path.dirname(__file__), + "../samples/", + ) + ), ) import sys + sys.path.append(_SAMPLES_DIR) @@ -37,6 +41,19 @@ async def process(self): assert response.json() == {"result": "Foo"} +@mark.asyncio +async def test_depends_with_annotated(async_client: AsyncClient): + class ServiceMock: + async def process(self): + return "Foo" + + with web.container.service.override(ServiceMock()): + response = await async_client.get("/") + + assert response.status_code == 200 + assert response.json() == {"result": "Foo"} + + @mark.asyncio async def test_depends_injection(async_client: AsyncClient): response = await async_client.get("/auth", auth=("john_smith", "secret")) diff --git a/tests/unit/wiring/test_flask_py36.py b/tests/unit/wiring/test_flask_py36.py index 751f04d8..97420275 100644 --- a/tests/unit/wiring/test_flask_py36.py +++ b/tests/unit/wiring/test_flask_py36.py @@ -2,19 +2,25 @@ # Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir import os + _TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), + os.path.sep.join( + ( + os.path.dirname(__file__), + "../", + ) + ), ) _SAMPLES_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../samples/", - )), + os.path.sep.join( + ( + os.path.dirname(__file__), + "../samples/", + ) + ), ) import sys + sys.path.append(_TOP_DIR) sys.path.append(_SAMPLES_DIR) @@ -29,3 +35,13 @@ def test_wiring_with_flask(): assert response.status_code == 200 assert json.loads(response.data) == {"result": "OK"} + + +def test_wiring_with_annotated(): + client = web.app.test_client() + + with web.app.app_context(): + response = client.get("/annotated") + + assert response.status_code == 200 + assert json.loads(response.data) == {"result": "OK"} diff --git a/tests/unit/wiring/test_introspection_py36.py b/tests/unit/wiring/test_introspection_py36.py index c7149602..66b36a80 100644 --- a/tests/unit/wiring/test_introspection_py36.py +++ b/tests/unit/wiring/test_introspection_py36.py @@ -6,6 +6,13 @@ from dependency_injector.wiring import inject +def test_isfunction(): + @inject + def foo(): ... + + assert inspect.isfunction(foo) + + def test_asyncio_iscoroutinefunction(): @inject async def foo():