-
-
Notifications
You must be signed in to change notification settings - Fork 9.5k
✨ Update internal AsyncExitStack to fix context for dependencies with yield
#4575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bda6359
5b7d83d
c29d85e
017c235
cea488d
0706731
d33cde6
5a3b6f1
23b4ccc
c2e3d49
4531f6a
a69d916
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| from typing import Optional | ||
|
|
||
| from fastapi.concurrency import AsyncExitStack | ||
| from starlette.types import ASGIApp, Receive, Scope, Send | ||
|
|
||
|
|
||
| class AsyncExitStackMiddleware: | ||
| def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None: | ||
| self.app = app | ||
| self.context_name = context_name | ||
|
|
||
| async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
| if AsyncExitStack: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems to me that
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah... 🤔 |
||
| dependency_exception: Optional[Exception] = None | ||
| async with AsyncExitStack() as stack: | ||
| scope[self.context_name] = stack | ||
| try: | ||
| await self.app(scope, receive, send) | ||
| except Exception as e: | ||
| dependency_exception = e | ||
| raise e | ||
| if dependency_exception: | ||
| # This exception was possibly handled by the dependency but it should | ||
| # still bubble up so that the ServerErrorMiddleware can return a 500 | ||
| # or the ExceptionMiddleware can catch and handle any other exceptions | ||
| raise dependency_exception | ||
| else: | ||
| await self.app(scope, receive, send) # pragma: no cover | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| from contextvars import ContextVar | ||
| from typing import Any, Awaitable, Callable, Dict, Optional | ||
|
|
||
| from fastapi import Depends, FastAPI, Request, Response | ||
| from fastapi.testclient import TestClient | ||
|
|
||
| legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar( | ||
| "legacy_request_state_context_var", default=None | ||
| ) | ||
|
|
||
| app = FastAPI() | ||
|
|
||
|
|
||
| async def set_up_request_state_dependency(): | ||
| request_state = {"user": "deadpond"} | ||
| contextvar_token = legacy_request_state_context_var.set(request_state) | ||
| yield request_state | ||
| legacy_request_state_context_var.reset(contextvar_token) | ||
|
|
||
|
|
||
| @app.middleware("http") | ||
| async def custom_middleware( | ||
| request: Request, call_next: Callable[[Request], Awaitable[Response]] | ||
| ): | ||
| response = await call_next(request) | ||
| response.headers["custom"] = "foo" | ||
| return response | ||
|
|
||
|
|
||
| @app.get("/user", dependencies=[Depends(set_up_request_state_dependency)]) | ||
| def get_user(): | ||
| request_state = legacy_request_state_context_var.get() | ||
| assert request_state | ||
| return request_state["user"] | ||
|
|
||
|
|
||
| client = TestClient(app) | ||
|
|
||
|
|
||
| def test_dependency_contextvars(): | ||
| """ | ||
| Check that custom middlewares don't affect the contextvar context for dependencies. | ||
|
|
||
| The code before yield and the code after yield should be run in the same contextvar | ||
| context, so that request_state_context_var.reset(contextvar_token). | ||
|
|
||
| If they are run in a different context, that raises an error. | ||
| """ | ||
| response = client.get("/user") | ||
| assert response.json() == "deadpond" | ||
| assert response.headers["custom"] == "foo" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| import pytest | ||
| from fastapi import Body, Depends, FastAPI, HTTPException | ||
| from fastapi.testclient import TestClient | ||
|
|
||
| initial_fake_database = {"rick": "Rick Sanchez"} | ||
|
|
||
| fake_database = initial_fake_database.copy() | ||
|
|
||
| initial_state = {"except": False, "finally": False} | ||
|
|
||
| state = initial_state.copy() | ||
|
|
||
| app = FastAPI() | ||
|
|
||
|
|
||
| async def get_database(): | ||
| temp_database = fake_database.copy() | ||
| try: | ||
| yield temp_database | ||
| fake_database.update(temp_database) | ||
| except HTTPException: | ||
| state["except"] = True | ||
| finally: | ||
| state["finally"] = True | ||
|
|
||
|
|
||
| @app.put("/invalid-user/{user_id}") | ||
| def put_invalid_user( | ||
| user_id: str, name: str = Body(...), db: dict = Depends(get_database) | ||
| ): | ||
| db[user_id] = name | ||
| raise HTTPException(status_code=400, detail="Invalid user") | ||
|
|
||
|
|
||
| @app.put("/user/{user_id}") | ||
| def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)): | ||
| db[user_id] = name | ||
| return {"message": "OK"} | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def reset_state_and_db(): | ||
| global fake_database | ||
| global state | ||
| fake_database = initial_fake_database.copy() | ||
| state = initial_state.copy() | ||
|
|
||
|
|
||
| client = TestClient(app) | ||
|
|
||
|
|
||
| def test_dependency_gets_exception(): | ||
| assert state["except"] is False | ||
| assert state["finally"] is False | ||
| response = client.put("/invalid-user/rick", json="Morty") | ||
| assert response.status_code == 400, response.text | ||
| assert response.json() == {"detail": "Invalid user"} | ||
| assert state["except"] is True | ||
| assert state["finally"] is True | ||
| assert fake_database["rick"] == "Rick Sanchez" | ||
|
|
||
|
|
||
| def test_dependency_no_exception(): | ||
| assert state["except"] is False | ||
| assert state["finally"] is False | ||
| response = client.put("/user/rick", json="Morty") | ||
| assert response.status_code == 200, response.text | ||
| assert response.json() == {"message": "OK"} | ||
| assert state["except"] is False | ||
| assert state["finally"] is True | ||
| assert fake_database["rick"] == "Morty" |
Uh oh!
There was an error while loading. Please reload this page.