Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions fastapi/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,16 +1101,18 @@ def openapi(self) -> dict[str, Any]:

def setup(self) -> None:
if self.openapi_url:
urls = (server_data.get("url") for server_data in self.servers)
server_urls = {url for url in urls if url}

async def openapi(req: Request) -> JSONResponse:
root_path = req.scope.get("root_path", "").rstrip("/")
if root_path not in server_urls:
if root_path and self.root_path_in_servers:
self.servers.insert(0, {"url": root_path})
server_urls.add(root_path)
return JSONResponse(self.openapi())
schema = self.openapi()
if root_path and self.root_path_in_servers:
server_urls = {s.get("url") for s in schema.get("servers", [])}
if root_path not in server_urls:
schema = dict(schema)
schema["servers"] = [{"url": root_path}] + schema.get(
"servers", []
)
return JSONResponse(schema)

self.add_route(self.openapi_url, openapi, include_in_schema=False)
if self.openapi_url and self.docs_url:
Expand Down
18 changes: 16 additions & 2 deletions fastapi/openapi/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,20 @@
from fastapi.encoders import jsonable_encoder
from starlette.responses import HTMLResponse


def _html_safe_json(value: Any) -> str:
"""Serialize a value to JSON with HTML special characters escaped.

This prevents injection when the JSON is embedded inside a <script> tag.
"""
return (
json.dumps(value)
.replace("<", "\\u003c")
.replace(">", "\\u003e")
.replace("&", "\\u0026")
)


swagger_ui_default_parameters: Annotated[
dict[str, Any],
Doc(
Expand Down Expand Up @@ -155,7 +169,7 @@ def get_swagger_ui_html(
"""

for key, value in current_swagger_ui_parameters.items():
html += f"{json.dumps(key)}: {json.dumps(jsonable_encoder(value))},\n"
html += f"{_html_safe_json(key)}: {_html_safe_json(jsonable_encoder(value))},\n"

if oauth2_redirect_url:
html += f"oauth2RedirectUrl: window.location.origin + '{oauth2_redirect_url}',"
Expand All @@ -169,7 +183,7 @@ def get_swagger_ui_html(

if init_oauth:
html += f"""
ui.initOAuth({json.dumps(jsonable_encoder(init_oauth))})
ui.initOAuth({_html_safe_json(jsonable_encoder(init_oauth))})
"""

html += """
Expand Down
75 changes: 75 additions & 0 deletions tests/test_openapi_cache_root_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient


def test_root_path_does_not_persist_across_requests():
app = FastAPI()

@app.get("/")
def read_root(): # pragma: no cover
return {"ok": True}

# Attacker request with a spoofed root_path
attacker_client = TestClient(app, root_path="/evil-api")
response1 = attacker_client.get("/openapi.json")
data1 = response1.json()
assert any(s.get("url") == "/evil-api" for s in data1.get("servers", []))

# Subsequent legitimate request with no root_path
clean_client = TestClient(app)
response2 = clean_client.get("/openapi.json")
data2 = response2.json()
servers = [s.get("url") for s in data2.get("servers", [])]
assert "/evil-api" not in servers


def test_multiple_different_root_paths_do_not_accumulate():
app = FastAPI()

@app.get("/")
def read_root(): # pragma: no cover
return {"ok": True}

for prefix in ["/path-a", "/path-b", "/path-c"]:
c = TestClient(app, root_path=prefix)
c.get("/openapi.json")

# A clean request should not have any of them
clean_client = TestClient(app)
response = clean_client.get("/openapi.json")
data = response.json()
servers = [s.get("url") for s in data.get("servers", [])]
for prefix in ["/path-a", "/path-b", "/path-c"]:
assert prefix not in servers, (
f"root_path '{prefix}' leaked into clean request: {servers}"
)


def test_legitimate_root_path_still_appears():
app = FastAPI()

@app.get("/")
def read_root(): # pragma: no cover
return {"ok": True}

client = TestClient(app, root_path="/api/v1")
response = client.get("/openapi.json")
data = response.json()
servers = [s.get("url") for s in data.get("servers", [])]
assert "/api/v1" in servers


def test_configured_servers_not_mutated():
configured_servers = [{"url": "https://prod.example.com"}]
app = FastAPI(servers=configured_servers)

@app.get("/")
def read_root(): # pragma: no cover
return {"ok": True}

# Request with a rogue root_path
attacker_client = TestClient(app, root_path="/evil")
attacker_client.get("/openapi.json")

# The original servers list must be untouched
assert configured_servers == [{"url": "https://prod.example.com"}]
37 changes: 37 additions & 0 deletions tests/test_swagger_ui_escape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from fastapi.openapi.docs import get_swagger_ui_html


def test_init_oauth_html_chars_are_escaped():
xss_payload = "Evil</script><script>alert(1)</script>"
html = get_swagger_ui_html(
openapi_url="/openapi.json",
title="Test",
init_oauth={"appName": xss_payload},
)
body = html.body.decode()

assert "</script><script>" not in body
assert "\\u003c/script\\u003e\\u003cscript\\u003e" in body


def test_swagger_ui_parameters_html_chars_are_escaped():
html = get_swagger_ui_html(
openapi_url="/openapi.json",
title="Test",
swagger_ui_parameters={"customKey": "<img src=x onerror=alert(1)>"},
)
body = html.body.decode()
assert "<img src=x onerror=alert(1)>" not in body
assert "\\u003cimg" in body


def test_normal_init_oauth_still_works():
html = get_swagger_ui_html(
openapi_url="/openapi.json",
title="Test",
init_oauth={"clientId": "my-client", "appName": "My App"},
)
body = html.body.decode()
assert '"clientId": "my-client"' in body
assert '"appName": "My App"' in body
assert "ui.initOAuth" in body