Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions fastapi/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ def get_flat_dependant(
return flat_dependant


def get_flat_params(dependant: Dependant) -> List[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
return (
flat_dependant.path_params
+ flat_dependant.query_params
+ flat_dependant.header_params
+ flat_dependant.cookie_params
)


def is_scalar_field(field: ModelField) -> bool:
field_info = get_field_info(field)
if not (
Expand Down
80 changes: 59 additions & 21 deletions fastapi/openapi/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import http.client
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast

from fastapi import routing
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import (
METHODS_WITH_BODY,
Expand All @@ -15,11 +16,14 @@
from fastapi.utils import (
generate_operation_id_for_path,
get_field_info,
get_flat_models_from_routes,
get_model_definitions,
)
from pydantic import BaseModel
from pydantic.schema import field_schema, get_model_name_map
from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
get_model_name_map,
)
from pydantic.utils import lenient_issubclass
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
Expand Down Expand Up @@ -64,16 +68,6 @@
}


def get_openapi_params(dependant: Dependant) -> List[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
return (
flat_dependant.path_params
+ flat_dependant.query_params
+ flat_dependant.header_params
+ flat_dependant.cookie_params
)


def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
security_definitions = {}
operation_security = []
Expand All @@ -90,17 +84,22 @@ def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, L


def get_openapi_operation_parameters(
*,
all_route_params: Sequence[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
) -> List[Dict[str, Any]]:
parameters = []
for param in all_route_params:
field_info = get_field_info(param)
field_info = cast(Param, field_info)
# ignore mypy error until enum schemas are released
parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map={})[0],
"schema": field_schema(
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
)[0],
}
if field_info.description:
parameter["description"] = field_info.description
Expand All @@ -111,13 +110,16 @@ def get_openapi_operation_parameters(


def get_openapi_operation_request_body(
*, body_field: Optional[ModelField], model_name_map: Dict[Type[BaseModel], str]
*,
body_field: Optional[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
) -> Optional[Dict]:
if not body_field:
return None
assert isinstance(body_field, ModelField)
# ignore mypy error until enum schemas are released
body_schema, _, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
)
field_info = cast(Body, get_field_info(body_field))
request_media_type = field_info.media_type
Expand Down Expand Up @@ -176,8 +178,10 @@ def get_openapi_path(
operation.setdefault("security", []).extend(operation_security)
if security_definitions:
security_schemes.update(security_definitions)
all_route_params = get_openapi_params(route.dependant)
operation_parameters = get_openapi_operation_parameters(all_route_params)
all_route_params = get_flat_params(route.dependant)
operation_parameters = get_openapi_operation_parameters(
all_route_params=all_route_params, model_name_map=model_name_map
)
parameters.extend(operation_parameters)
if parameters:
operation["parameters"] = list(
Expand Down Expand Up @@ -270,6 +274,38 @@ def get_openapi_path(
return path, security_schemes, definitions


def get_flat_models_from_routes(
routes: Sequence[BaseRoute],
) -> Set[Union[Type[BaseModel], Type[Enum]]]:
body_fields_from_routes: List[ModelField] = []
responses_from_routes: List[ModelField] = []
request_fields_from_routes: List[ModelField] = []
callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field:
assert isinstance(
route.body_field, ModelField
), "A request body must be a Pydantic Field"
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
if route.response_fields:
responses_from_routes.extend(route.response_fields.values())
if route.callbacks:
callback_flat_models |= get_flat_models_from_routes(route.callbacks)
params = get_flat_params(route.dependant)
request_fields_from_routes.extend(params)

flat_models = callback_flat_models | get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes + request_fields_from_routes,
known_models=set(),
)
return flat_models


def get_openapi(
*,
title: str,
Expand All @@ -286,9 +322,11 @@ def get_openapi(
components: Dict[str, Dict] = {}
paths: Dict[str, Dict] = {}
flat_models = get_flat_models_from_routes(routes)
model_name_map = get_model_name_map(flat_models)
# ignore mypy error until enum schemas are released
model_name_map = get_model_name_map(flat_models) # type: ignore
# ignore mypy error until enum schemas are released
definitions = get_model_definitions(
flat_models=flat_models, model_name_map=model_name_map
flat_models=flat_models, model_name_map=model_name_map # type: ignore
)
for route in routes:
if isinstance(route, routing.APIRoute):
Expand Down
39 changes: 8 additions & 31 deletions fastapi/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import functools
import re
from dataclasses import is_dataclass
from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast
from enum import Enum
from typing import Any, Dict, Optional, Set, Type, Union, cast

import fastapi
from fastapi import routing
from fastapi.logger import logger
from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseConfig, BaseModel, create_model
from pydantic.class_validators import Validator
from pydantic.schema import get_flat_models_from_fields, model_process_schema
from pydantic.schema import model_process_schema
from pydantic.utils import lenient_issubclass
from starlette.routing import BaseRoute

try:
from pydantic.fields import FieldInfo, ModelField, UndefinedType
Expand Down Expand Up @@ -50,38 +49,16 @@ def warning_response_model_skip_defaults_deprecated() -> None:
)


def get_flat_models_from_routes(routes: Sequence[BaseRoute]) -> Set[Type[BaseModel]]:
body_fields_from_routes: List[ModelField] = []
responses_from_routes: List[ModelField] = []
callback_flat_models: Set[Type[BaseModel]] = set()
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field:
assert isinstance(
route.body_field, ModelField
), "A request body must be a Pydantic Field"
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
if route.response_fields:
responses_from_routes.extend(route.response_fields.values())
if route.callbacks:
callback_flat_models |= get_flat_models_from_routes(route.callbacks)
flat_models = callback_flat_models | get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes, known_models=set()
)
return flat_models


def get_model_definitions(
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
*,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
definitions: Dict[str, Dict] = {}
for model in flat_models:
# ignore mypy error until enum schemas are released
m_schema, m_definitions, m_nested_models = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
)
definitions.update(m_definitions)
model_name = model_name_map[model]
Expand Down
2 changes: 1 addition & 1 deletion scripts/format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ set -x

autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place docs_src fastapi tests scripts --exclude=__init__.py
black fastapi tests docs_src scripts
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --apply fastapi tests docs_src scripts
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --thirdparty pydantic --thirdparty starlette --apply fastapi tests docs_src scripts
2 changes: 1 addition & 1 deletion scripts/lint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ set -x

mypy fastapi
black fastapi tests --check
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi fastapi tests
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi --thirdparty fastapi --thirdparty pydantic --thirdparty starlette fastapi tests
8 changes: 7 additions & 1 deletion tests/test_tutorial/test_path_params/test_tutorial005.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
"parameters": [
{
"required": True,
"schema": {"$ref": "#/definitions/ModelName"},
"schema": {"$ref": "#/components/schemas/ModelName"},
"name": "model_name",
"in": "path",
}
Expand Down Expand Up @@ -124,6 +124,12 @@
}
},
},
"ModelName": {
"title": "ModelName",
"enum": ["alexnet", "resnet", "lenet"],
"type": "string",
"description": "An enumeration.",
},
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
Expand Down