diff --git a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py index 863a17b55..90d70f867 100644 --- a/src/mcp/server/auth/errors.py +++ b/src/mcp/server/auth/errors.py @@ -4,11 +4,18 @@ Corresponds to TypeScript file: src/server/auth/errors.ts """ -from typing import Dict +from typing import TypedDict from pydantic import ValidationError +class OAuthErrorResponse(TypedDict): + """OAuth error response format.""" + + error: str + error_description: str + + class OAuthError(Exception): """ Base class for all OAuth errors. @@ -22,7 +29,7 @@ def __init__(self, message: str): super().__init__(message) self.message = message - def to_response_object(self) -> Dict[str, str]: + def to_response_object(self) -> OAuthErrorResponse: """Convert error to JSON response object.""" return {"error": self.error_code, "error_description": self.message} @@ -146,5 +153,9 @@ class InsufficientScopeError(OAuthError): error_code = "insufficient_scope" + def stringify_pydantic_error(validation_error: ValidationError) -> str: - return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors()) \ No newline at end of file + return "\n".join( + f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" + for e in validation_error.errors() + ) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 9d0b3c1d3..31a9eee21 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -4,8 +4,9 @@ Corresponds to TypeScript file: src/server/auth/handlers/authorize.ts """ +import logging from typing import Callable, Literal, Optional, Union -from urllib.parse import parse_qs, urlencode, urlparse, urlunparse +from urllib.parse import urlencode, urlparse, urlunparse from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError from starlette.datastructures import FormData, QueryParams @@ -13,16 +14,17 @@ from starlette.responses import RedirectResponse, Response from mcp.server.auth.errors import ( - InvalidClientError, InvalidRequestError, OAuthError, stringify_pydantic_error, ) -from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider, construct_redirect_uri -from mcp.shared.auth import OAuthClientInformationFull from mcp.server.auth.json_response import PydanticJSONResponse - -import logging +from mcp.server.auth.provider import ( + AuthorizationParams, + OAuthServerProvider, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull logger = logging.getLogger(__name__) @@ -48,7 +50,6 @@ class AuthorizationRequest(BaseModel): description="Optional scope; if specified, should be " "a space-separated list of scope strings", ) - def validate_scope( @@ -80,15 +81,19 @@ def validate_redirect_uri( raise InvalidRequestError( "redirect_uri must be specified when client has multiple registered URIs" ) + + ErrorCode = Literal[ - "invalid_request", - "unauthorized_client", - "access_denied", - "unsupported_response_type", - "invalid_scope", - "server_error", - "temporarily_unavailable" - ] + "invalid_request", + "unauthorized_client", + "access_denied", + "unsupported_response_type", + "invalid_scope", + "server_error", + "temporarily_unavailable", +] + + class ErrorResponse(BaseModel): error: ErrorCode error_description: str @@ -96,7 +101,10 @@ class ErrorResponse(BaseModel): # must be set if provided in the request state: Optional[str] -def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> Optional[str]: + +def best_effort_extract_string( + key: str, params: None | FormData | QueryParams +) -> Optional[str]: if params is None: return None value = params.get(key) @@ -104,6 +112,7 @@ def best_effort_extract_string(key: str, params: None | FormData | QueryParams) return value return None + class AnyHttpUrlModel(RootModel): root: AnyHttpUrl @@ -118,18 +127,24 @@ async def authorization_handler(request: Request) -> Response: client = None params = None - async def error_response(error: ErrorCode, error_description: str, attempt_load_client: bool = True): + async def error_response( + error: ErrorCode, error_description: str, attempt_load_client: bool = True + ): nonlocal client, redirect_uri, state if client is None and attempt_load_client: # make last-ditch attempt to load the client client_id = best_effort_extract_string("client_id", params) - client = client_id and await provider.clients_store.get_client(client_id) + client = client_id and await provider.clients_store.get_client( + client_id + ) if redirect_uri is None and client: # make last-ditch effort to load the redirect uri if params is not None and "redirect_uri" not in params: raw_redirect_uri = None else: - raw_redirect_uri = AnyHttpUrlModel.model_validate(best_effort_extract_string("redirect_uri", params)).root + raw_redirect_uri = AnyHttpUrlModel.model_validate( + best_effort_extract_string("redirect_uri", params) + ).root try: redirect_uri = validate_redirect_uri(raw_redirect_uri, client) except (ValidationError, InvalidRequestError): @@ -146,7 +161,9 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ if redirect_uri and client: return RedirectResponse( - url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)), + url=construct_redirect_uri( + str(redirect_uri), **error_resp.model_dump(exclude_none=True) + ), status_code=302, headers={"Cache-Control": "no-store"}, ) @@ -156,7 +173,7 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ content=error_resp, headers={"Cache-Control": "no-store"}, ) - + try: # Parse request parameters if request.method == "GET": @@ -165,20 +182,22 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ else: # Parse form data for POST requests params = await request.form() - + # Save state if it exists, even before validation state = best_effort_extract_string("state", params) - + try: auth_request = AuthorizationRequest.model_validate(params) state = auth_request.state # Update with validated state except ValidationError as validation_error: error: ErrorCode = "invalid_request" for e in validation_error.errors(): - if e['loc'] == ('response_type',) and e['type'] == 'literal_error': + if e["loc"] == ("response_type",) and e["type"] == "literal_error": error = "unsupported_response_type" break - return await error_response(error, stringify_pydantic_error(validation_error)) + return await error_response( + error, stringify_pydantic_error(validation_error) + ) # Get client information client = await provider.clients_store.get_client(auth_request.client_id) @@ -190,7 +209,6 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ attempt_load_client=False, ) - # Validate redirect_uri against client's registered URIs try: redirect_uri = validate_redirect_uri(auth_request.redirect_uri, client) @@ -200,7 +218,7 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ error="invalid_request", error_description=validation_error.message, ) - + # Validate scope - for scope errors, we can redirect try: scopes = validate_scope(auth_request.scope, client) @@ -210,7 +228,7 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ error="invalid_scope", error_description=validation_error.message, ) - + # Setup authorization parameters auth_params = AuthorizationParams( state=state, @@ -218,20 +236,22 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ code_challenge=auth_request.code_challenge, redirect_uri=redirect_uri, ) - + # Let the provider pick the next URI to redirect to response = RedirectResponse( url="", status_code=302, headers={"Cache-Control": "no-store"} ) - response.headers["location"] = await provider.authorize( - client, auth_params - ) + response.headers["location"] = await provider.authorize(client, auth_params) return response - + except Exception as validation_error: # Catch-all for unexpected errors - logger.exception("Unexpected error in authorization_handler", exc_info=validation_error) - return await error_response(error="server_error", error_description="An unexpected error occurred") + logger.exception( + "Unexpected error in authorization_handler", exc_info=validation_error + ) + return await error_response( + error="server_error", error_description="An unexpected error occurred" + ) return authorization_handler @@ -240,7 +260,7 @@ def create_error_redirect( redirect_uri: AnyUrl, error: Union[Exception, ErrorResponse] ) -> str: parsed_uri = urlparse(str(redirect_uri)) - + if isinstance(error, ErrorResponse): # Convert ErrorResponse to dict error_dict = error.model_dump(exclude_none=True) @@ -251,7 +271,7 @@ def create_error_redirect( query_params[key] = str(value) else: query_params[key] = value - + elif isinstance(error, OAuthError): query_params = {"error": error.error_code, "error_description": str(error)} else: diff --git a/src/mcp/server/auth/handlers/metadata.py b/src/mcp/server/auth/handlers/metadata.py index 11a9c904d..c50789df1 100644 --- a/src/mcp/server/auth/handlers/metadata.py +++ b/src/mcp/server/auth/handlers/metadata.py @@ -4,13 +4,15 @@ Corresponds to TypeScript file: src/server/auth/handlers/metadata.ts """ -from typing import Any, Callable, Dict +from typing import Callable from starlette.requests import Request from starlette.responses import JSONResponse, Response +from mcp.shared.auth import OAuthMetadata -def create_metadata_handler(metadata: Dict[str, Any]) -> Callable: + +def create_metadata_handler(metadata: OAuthMetadata) -> Callable: """ Create a handler for OAuth 2.0 Authorization Server Metadata. @@ -33,8 +35,9 @@ async def metadata_handler(request: Request) -> Response: Returns: JSON response with the authorization server metadata """ - # Remove any None values from metadata - clean_metadata = {k: v for k, v in metadata.items() if v is not None} + # Convert metadata to dict and remove any None values + metadata_dict = metadata.model_dump() + clean_metadata = {k: v for k, v in metadata_dict.items() if v is not None} return JSONResponse( content=clean_metadata, diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 4378dc949..cac764a5e 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -14,7 +14,6 @@ from starlette.responses import JSONResponse, Response from mcp.server.auth.errors import ( - InvalidRequestError, OAuthError, ServerError, stringify_pydantic_error, @@ -23,66 +22,93 @@ from mcp.server.auth.provider import OAuthRegisteredClientsStore from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata + class ErrorResponse(BaseModel): - error: Literal["invalid_redirect_uri", "invalid_client_metadata", "invalid_software_statement", "unapproved_software_statement"] + error: Literal[ + "invalid_redirect_uri", + "invalid_client_metadata", + "invalid_software_statement", + "unapproved_software_statement", + ] error_description: str def create_registration_handler( clients_store: OAuthRegisteredClientsStore, client_secret_expiry_seconds: int | None ) -> Callable: + """ + Create a handler for OAuth 2.0 Dynamic Client Registration. + + Corresponds to clientRegistrationHandler in src/server/auth/handlers/register.ts + + Args: + clients_store: The store for registered clients + client_secret_expiry_seconds: Optional expiry time for client secrets + + Returns: + A Starlette endpoint handler function + """ + async def registration_handler(request: Request) -> Response: - # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 + """ + Handler for the OAuth 2.0 Dynamic Client Registration endpoint. + + Args: + request: The Starlette request + + Returns: + JSON response with client information or error + """ try: # Parse request body as JSON - body = await request.json() - client_metadata = OAuthClientMetadata.model_validate(body) - except ValidationError as validation_error: - return PydanticJSONResponse(content=ErrorResponse( - error="invalid_client_metadata", - error_description=stringify_pydantic_error(validation_error) - ), status_code=400) - raise InvalidRequestError(f"Invalid client metadata: {str(e)}") - - client_id = str(uuid4()) - client_secret = None - if client_metadata.token_endpoint_auth_method != "none": - # cryptographically secure random 32-byte hex string - client_secret = secrets.token_hex(32) - - client_id_issued_at = int(time.time()) - client_secret_expires_at = ( - client_id_issued_at + client_secret_expiry_seconds - if client_secret_expiry_seconds is not None - else None - ) - - client_info = OAuthClientInformationFull( - client_id=client_id, - client_id_issued_at=client_id_issued_at, - client_secret=client_secret, - client_secret_expires_at=client_secret_expires_at, - # passthrough information from the client request - redirect_uris=client_metadata.redirect_uris, - token_endpoint_auth_method=client_metadata.token_endpoint_auth_method, - grant_types=client_metadata.grant_types, - response_types=client_metadata.response_types, - client_name=client_metadata.client_name, - client_uri=client_metadata.client_uri, - logo_uri=client_metadata.logo_uri, - scope=client_metadata.scope, - contacts=client_metadata.contacts, - tos_uri=client_metadata.tos_uri, - policy_uri=client_metadata.policy_uri, - jwks_uri=client_metadata.jwks_uri, - jwks=client_metadata.jwks, - software_id=client_metadata.software_id, - software_version=client_metadata.software_version, - ) - # Register client - client = await clients_store.register_client(client_info) - - # Return client information - return PydanticJSONResponse(content=client, status_code=201) + try: + body = await request.json() + client_metadata = OAuthClientMetadata.model_validate(body) + except ValidationError as validation_error: + return PydanticJSONResponse( + content=ErrorResponse( + error="invalid_client_metadata", + error_description=stringify_pydantic_error(validation_error), + ), + status_code=400, + ) + + client_id = str(uuid4()) + client_secret = None + if client_metadata.token_endpoint_auth_method != "none": + # cryptographically secure random 32-byte hex string + client_secret = secrets.token_hex(32) + + client_id_issued_at = int(time.time()) + client_secret_expires_at = ( + client_id_issued_at + client_secret_expiry_seconds + if client_secret_expiry_seconds is not None + else None + ) + + client_info = OAuthClientInformationFull( + client_id=client_id, + client_id_issued_at=client_id_issued_at, + client_secret=client_secret, + client_secret_expires_at=client_secret_expires_at, + **client_metadata.model_dump(exclude_unset=True), + policy_uri=client_metadata.policy_uri, + jwks_uri=client_metadata.jwks_uri, + jwks=client_metadata.jwks, + software_id=client_metadata.software_id, + software_version=client_metadata.software_version, + ) + # Register client + client = await clients_store.register_client(client_info) + if not client: + raise ServerError("Failed to register client") + + # Return client information + return PydanticJSONResponse(content=client, status_code=201) + + except OAuthError as e: + # Handle OAuth errors + status_code = 500 if isinstance(e, ServerError) else 400 + return JSONResponse(status_code=status_code, content=e.to_response_object()) return registration_handler diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 3ede08c1f..5a640b8eb 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -4,7 +4,6 @@ Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts """ -from tokenize import Token from typing import Callable from pydantic import ValidationError @@ -12,16 +11,15 @@ from starlette.responses import Response from mcp.server.auth.errors import ( - InvalidRequestError, stringify_pydantic_error, ) +from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.middleware.client_auth import ( ClientAuthenticator, ClientAuthRequest, ) -from mcp.server.auth.provider import OAuthServerProvider, OAuthTokenRevocationRequest -from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.shared.auth import TokenErrorResponse +from mcp.server.auth.provider import OAuthServerProvider +from mcp.shared.auth import OAuthTokenRevocationRequest, TokenErrorResponse class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest): @@ -31,6 +29,19 @@ class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest): def create_revocation_handler( provider: OAuthServerProvider, client_authenticator: ClientAuthenticator ) -> Callable: + """ + Create a handler for OAuth 2.0 Token Revocation. + + Corresponds to revocationHandler in src/server/auth/handlers/revoke.ts + + Args: + provider: The OAuth server provider + client_authenticator: The client authenticator + + Returns: + A Starlette endpoint handler function + """ + async def revocation_handler(request: Request) -> Response: """ Handler for the OAuth 2.0 Token Revocation endpoint. @@ -39,17 +50,25 @@ async def revocation_handler(request: Request) -> Response: form_data = await request.form() revocation_request = RevocationRequest.model_validate(dict(form_data)) except ValidationError as e: - return PydanticJSONResponse(status_code=400,content=TokenErrorResponse( - error="invalid_request", - error_description=stringify_pydantic_error(e) - )) + return PydanticJSONResponse( + status_code=400, + content=TokenErrorResponse( + error="invalid_request", + error_description=stringify_pydantic_error(e), + ), + ) # Authenticate client client_auth_result = await client_authenticator(revocation_request) # Revoke token if provider.revoke_token: - await provider.revoke_token(client_auth_result, revocation_request) + # Convert RevocationRequest to OAuthTokenRevocationRequest + oauth_revocation_request = OAuthTokenRevocationRequest( + token=revocation_request.token, + token_type_hint=revocation_request.token_type_hint, + ) + await provider.revoke_token(client_auth_result, oauth_revocation_request) # Return successful empty response return Response( diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 0c8efe929..287e96a72 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -7,13 +7,12 @@ import base64 import hashlib import time -from typing import Annotated, Callable, Literal, Optional, Union +from typing import Annotated, Callable, Literal, Union from pydantic import AnyHttpUrl, Field, RootModel, ValidationError from starlette.requests import Request from mcp.server.auth.errors import ( - InvalidRequestError, stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse @@ -41,7 +40,7 @@ class RefreshTokenRequest(ClientAuthRequest): # See https://datatracker.ietf.org/doc/html/rfc6749#section-6 grant_type: Literal["refresh_token"] refresh_token: str = Field(..., description="The refresh token") - scope: Optional[str] = Field(None, description="Optional scope parameter") + scope: str | None = Field(None, description="Optional scope parameter") class TokenRequest(RootModel): @@ -54,11 +53,24 @@ class TokenRequest(RootModel): def create_token_handler( provider: OAuthServerProvider, client_authenticator: ClientAuthenticator ) -> Callable: + """ + Create a handler for the OAuth 2.0 Token endpoint. + + Corresponds to tokenHandler in src/server/auth/handlers/token.ts + + Args: + provider: The OAuth server provider + client_authenticator: The client authenticator + + Returns: + A Starlette endpoint handler function + """ + def response(obj: TokenSuccessResponse | TokenErrorResponse): status_code = 200 if isinstance(obj, TokenErrorResponse): status_code = 400 - + return PydanticJSONResponse( content=obj, status_code=status_code, @@ -69,23 +81,35 @@ def response(obj: TokenSuccessResponse | TokenErrorResponse): ) async def token_handler(request: Request): + """ + Handler for the OAuth 2.0 Token endpoint. + + Args: + request: The Starlette request + + Returns: + JSON response with tokens or error + """ try: form_data = await request.form() token_request = TokenRequest.model_validate(dict(form_data)).root except ValidationError as validation_error: - return response(TokenErrorResponse( - error="invalid_request", - error_description=stringify_pydantic_error(validation_error) - - )) + return response( + TokenErrorResponse( + error="invalid_request", + error_description=stringify_pydantic_error(validation_error), + ) + ) client_info = await client_authenticator(token_request) if token_request.grant_type not in client_info.grant_types: - return response(TokenErrorResponse( - error="unsupported_grant_type", - error_description=f"Unsupported grant type (supported grant types are " - f"{client_info.grant_types})" - )) + return response( + TokenErrorResponse( + error="unsupported_grant_type", + error_description=f"Unsupported grant type (supported grant types are " + f"{client_info.grant_types})", + ) + ) tokens: TokenSuccessResponse @@ -96,37 +120,47 @@ async def token_handler(request: Request): ) if auth_code is None or auth_code.client_id != token_request.client_id: # if the authoriation code belongs to a different client, pretend it doesn't exist - return response(TokenErrorResponse( - error="invalid_grant", - error_description=f"authorization code does not exist" - )) + return response( + TokenErrorResponse( + error="invalid_grant", + error_description="authorization code does not exist", + ) + ) # make auth codes expire after a deadline # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 if auth_code.expires_at < time.time(): - return response(TokenErrorResponse( - error="invalid_grant", - error_description=f"authorization code has expired" - )) + return response( + TokenErrorResponse( + error="invalid_grant", + error_description="authorization code has expired", + ) + ) # verify redirect_uri doesn't change between /authorize and /tokens # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 if token_request.redirect_uri != auth_code.redirect_uri: - return response(TokenErrorResponse( - error="invalid_request", - error_description=f"redirect_uri did not match redirect_uri used when authorization code was created" - )) + return response( + TokenErrorResponse( + error="invalid_request", + error_description="redirect_uri did not match redirect_uri used when authorization code was created", + ) + ) # Verify PKCE code verifier sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() - hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=") + hashed_code_verifier = ( + base64.urlsafe_b64encode(sha256).decode().rstrip("=") + ) if hashed_code_verifier != auth_code.code_challenge: # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 - return response(TokenErrorResponse( - error="invalid_grant", - error_description=f"incorrect code_verifier" - )) + return response( + TokenErrorResponse( + error="invalid_grant", + error_description="incorrect code_verifier", + ) + ) # Exchange authorization code for tokens tokens = await provider.exchange_authorization_code( @@ -134,30 +168,45 @@ async def token_handler(request: Request): ) case RefreshTokenRequest(): - refresh_token = await provider.load_refresh_token(client_info, token_request.refresh_token) - if refresh_token is None or refresh_token.client_id != token_request.client_id: + refresh_token = await provider.load_refresh_token( + client_info, token_request.refresh_token + ) + if ( + refresh_token is None + or refresh_token.client_id != token_request.client_id + ): # if the authoriation code belongs to a different client, pretend it doesn't exist - return response(TokenErrorResponse( - error="invalid_grant", - error_description=f"refresh token does not exist" - )) + return response( + TokenErrorResponse( + error="invalid_grant", + error_description="refresh token does not exist", + ) + ) if refresh_token.expires_at and refresh_token.expires_at < time.time(): # if the authoriation code belongs to a different client, pretend it doesn't exist - return response(TokenErrorResponse( - error="invalid_grant", - error_description=f"refresh token has expired" - )) + return response( + TokenErrorResponse( + error="invalid_grant", + error_description="refresh token has expired", + ) + ) # Parse scopes if provided - scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes + scopes = ( + token_request.scope.split(" ") + if token_request.scope + else refresh_token.scopes + ) for scope in scopes: if scope not in refresh_token.scopes: - return response(TokenErrorResponse( - error="invalid_scope", - error_description=f"cannot request scope `{scope}` not provided by refresh token" - )) + return response( + TokenErrorResponse( + error="invalid_scope", + error_description=f"cannot request scope `{scope}` not provided by refresh token", + ) + ) # Exchange refresh token for new tokens tokens = await provider.exchange_refresh_token( diff --git a/src/mcp/server/auth/handlers/types.py b/src/mcp/server/auth/handlers/types.py new file mode 100644 index 000000000..bb3e15aa6 --- /dev/null +++ b/src/mcp/server/auth/handlers/types.py @@ -0,0 +1,8 @@ +from typing import Protocol + +from starlette.requests import Request +from starlette.responses import Response + + +class HandlerFn(Protocol): + async def __call__(self, request: Request) -> Response: ... diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index f24aefca2..fc357ae8f 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -5,7 +5,7 @@ """ import time -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict from pydantic import BaseModel from starlette.exceptions import HTTPException @@ -27,14 +27,13 @@ class ClientAuthRequest(BaseModel): """ client_id: str - client_secret: Optional[str] = None + client_secret: str | None = None class ClientAuthenticator: """ ClientAuthenticator is a callable which validates requests from a client - application, - used to verify /token and /revoke calls. + application, used to verify /token and /revoke calls. If, during registration, the client requested to be issued a secret, the authenticator asserts that /token and /register calls must be authenticated with that same token. @@ -116,7 +115,7 @@ async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None # Add client authentication to the request try: - client = await self.client_auth(request) + client = await self.client_auth(ClientAuthRequest.model_validate(request)) # Store the client in the request state request.state.client = client except HTTPException: diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 3013ae439..8b7558e89 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -7,7 +7,7 @@ from typing import List, Literal, Optional, Protocol from urllib.parse import parse_qs, urlencode, urlparse, urlunparse -from pydantic import AnyHttpUrl, AnyUrl, BaseModel +from pydantic import AnyHttpUrl, BaseModel from mcp.server.auth.types import AuthInfo from mcp.shared.auth import ( @@ -28,6 +28,7 @@ class AuthorizationParams(BaseModel): code_challenge: str redirect_uri: AnyHttpUrl + class AuthorizationCode(BaseModel): code: str scopes: list[str] @@ -36,6 +37,7 @@ class AuthorizationCode(BaseModel): code_challenge: str redirect_uri: AnyHttpUrl + class RefreshToken(BaseModel): token: str client_id: str @@ -51,6 +53,7 @@ class OAuthTokenRevocationRequest(BaseModel): token: str token_type_hint: Optional[Literal["access_token", "refresh_token"]] = None + class OAuthRegisteredClientsStore(Protocol): """ Interface for storing and retrieving registered OAuth clients. @@ -72,12 +75,15 @@ async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFul async def register_client( self, client_info: OAuthClientInformationFull - ) -> None: + ) -> OAuthClientInformationFull | None: """ - Registers a new client + Registers a new client and returns client information. Args: client_info: The client metadata to register. + + Returns: + The client information, or None if registration failed. """ ... @@ -118,7 +124,7 @@ async def authorize( | | | | Redirect | |redirect_uri|<-----+ +------------------+ | | - +------------+ + +------------+ Implementations will need to define another handler on the MCP server return flow to perform the second redirect, and generates and stores an authorization @@ -161,8 +167,9 @@ async def exchange_authorization_code( """ ... - async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: - ... + async def load_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: str + ) -> RefreshToken | None: ... async def exchange_refresh_token( self, @@ -209,6 +216,7 @@ async def revoke_token( """ ... + def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: parsed_uri = urlparse(redirect_uri_base) query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs] @@ -216,7 +224,5 @@ def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: if v is not None: query_params.append((k, v)) - redirect_uri = urlunparse( - parsed_uri._replace(query=urlencode(query_params)) - ) - return redirect_uri \ No newline at end of file + redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params))) + return redirect_uri diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/router.py index 4dfa8e6ae..f9b16c9b5 100644 --- a/src/mcp/server/auth/router.py +++ b/src/mcp/server/auth/router.py @@ -5,7 +5,7 @@ """ from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Optional from pydantic import AnyUrl from starlette.routing import Route, Router @@ -18,6 +18,7 @@ ClientAuthenticator, ) from mcp.server.auth.provider import OAuthServerProvider +from mcp.shared.auth import OAuthMetadata @dataclass @@ -146,7 +147,7 @@ def build_metadata( service_documentation_url: Optional[AnyUrl], client_registration_options: ClientRegistrationOptions, revocation_options: RevocationOptions, -) -> Dict[str, Any]: +) -> OAuthMetadata: issuer_url_str = str(issuer_url).rstrip("/") # Create metadata metadata = { @@ -171,4 +172,4 @@ def build_metadata( metadata["revocation_endpoint"] = f"{issuer_url_str}{REVOCATION_PATH}" metadata["revocation_endpoint_auth_methods_supported"] = ["client_secret_post"] - return metadata + return OAuthMetadata.model_validate(metadata) diff --git a/src/mcp/server/auth/types.py b/src/mcp/server/auth/types.py index f0593d864..3edc4cb93 100644 --- a/src/mcp/server/auth/types.py +++ b/src/mcp/server/auth/types.py @@ -20,3 +20,7 @@ class AuthInfo(BaseModel): client_id: str scopes: List[str] expires_at: Optional[int] = None + user_id: Optional[str] = None + + class Config: + extra = "ignore" diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index db36bffad..2ebc4b2cb 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -44,7 +44,6 @@ async def handle_sse(request): from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send -from typing_extensions import deprecated import mcp.types as types @@ -79,7 +78,7 @@ def __init__(self, endpoint: str) -> None: self._read_stream_writers = {} logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") - @deprecated("use connect_sse_v2 instead") + # @deprecated("use connect_sse_v2 instead") @asynccontextmanager async def connect_sse(self, scope: Scope, receive: Receive, send: Send): if scope["type"] != "http": diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 9bcdaef15..2127575ee 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -14,7 +14,14 @@ class TokenErrorResponse(BaseModel): See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 """ - error: Literal["invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope"] + error: Literal[ + "invalid_request", + "invalid_client", + "invalid_grant", + "unauthorized_client", + "unsupported_grant_type", + "invalid_scope", + ] error_description: Optional[str] = None error_uri: Optional[AnyHttpUrl] = None @@ -102,7 +109,15 @@ class OAuthClientRegistrationError(BaseModel): error_description: Optional[str] = None +class OAuthTokenRevocationRequest(BaseModel): + """ + RFC 7009 OAuth 2.0 Token Revocation request. + + Corresponds to OAuthTokenRevocationRequestSchema in src/shared/auth.ts + """ + token: str + token_type_hint: Optional[str] = None class OAuthMetadata(BaseModel): diff --git a/tests/server/fastmcp/auth/streaming_asgi_transport.py b/tests/server/fastmcp/auth/streaming_asgi_transport.py index eb1ba4342..8bd7c70de 100644 --- a/tests/server/fastmcp/auth/streaming_asgi_transport.py +++ b/tests/server/fastmcp/auth/streaming_asgi_transport.py @@ -6,7 +6,6 @@ the connection is closed. """ -import asyncio import typing from typing import Any, Dict, Tuple @@ -161,18 +160,21 @@ async def process_messages() -> None: response_complete.set() # Create tasks for running the app and processing messages - asyncio.create_task(run_app()) - asyncio.create_task(process_messages()) + tg = anyio.create_task_group() - # Wait for the initial response or timeout - await initial_response_ready.wait() + async with tg: + tg.start_soon(run_app) + tg.start_soon(process_messages) - # Create a streaming response - return Response( - status_code, - headers=response_headers, - stream=StreamingASGIResponseStream(content_receive_channel), - ) + # Wait for the initial response or timeout + await initial_response_ready.wait() + + # Create a streaming response + return Response( + status_code, + headers=response_headers, + stream=StreamingASGIResponseStream(content_receive_channel), + ) class StreamingASGIResponseStream(AsyncByteStream): diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 785c5a7ad..368b09317 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -38,7 +38,6 @@ from mcp.shared.auth import ( OAuthClientInformationFull, TokenSuccessResponse, - TokenErrorResponse, ) from mcp.types import JSONRPCRequest @@ -79,15 +78,17 @@ async def authorize( # code and completes the redirect code = AuthorizationCode( code=f"code_{int(time.time())}", - client_id= client.client_id, - code_challenge= params.code_challenge, - redirect_uri= params.redirect_uri, + client_id=client.client_id, + code_challenge=params.code_challenge, + redirect_uri=params.redirect_uri, expires_at=time.time() + 300, - scopes=params.scopes or ["read", "write"] + scopes=params.scopes or ["read", "write"], ) self.auth_codes[code.code] = code - return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state) + return construct_redirect_uri( + str(params.redirect_uri), code=code.code, state=params.state + ) async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str @@ -107,8 +108,8 @@ async def exchange_authorization_code( # Store the tokens self.tokens[access_token] = AuthInfo( token=access_token, - client_id= client.client_id, - scopes= authorization_code.scopes, + client_id=client.client_id, + scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, ) @@ -125,14 +126,16 @@ async def exchange_authorization_code( refresh_token=refresh_token, ) - async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + async def load_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: str + ) -> RefreshToken | None: old_access_token = self.refresh_tokens.get(refresh_token) if old_access_token is None: return None token_info = self.tokens.get(old_access_token) if token_info is None: return None - + # Create a RefreshToken object that matches what is expected in later code refresh_obj = RefreshToken( token=refresh_token, @@ -140,7 +143,7 @@ async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_t scopes=token_info.scopes, expires_at=token_info.expires_at, ) - + return refresh_obj async def exchange_refresh_token( @@ -269,10 +272,10 @@ def test_client(auth_app) -> httpx.AsyncClient: @pytest.fixture async def registered_client(test_client: httpx.AsyncClient, request): """Create and register a test client. - + Parameters can be customized via indirect parameterization: - @pytest.mark.parametrize("registered_client", - [{"grant_types": ["authorization_code"]}], + @pytest.mark.parametrize("registered_client", + [{"grant_types": ["authorization_code"]}], indirect=True) """ # Default client metadata @@ -281,14 +284,14 @@ async def registered_client(test_client: httpx.AsyncClient, request): "client_name": "Test Client", "grant_types": ["authorization_code", "refresh_token"], } - + # Override with any parameters from the test if hasattr(request, "param") and request.param: client_metadata.update(request.param) - + response = await test_client.post("/register", json=client_metadata) assert response.status_code == 201, f"Failed to register client: {response.content}" - + client_info = response.json() return client_info @@ -302,17 +305,17 @@ def pkce_challenge(): .decode() .rstrip("=") ) - + return {"code_verifier": code_verifier, "code_challenge": code_challenge} @pytest.fixture async def auth_code(test_client, registered_client, pkce_challenge, request): """Get an authorization code. - + Parameters can be customized via indirect parameterization: - @pytest.mark.parametrize("auth_code", - [{"redirect_uri": "https://client.example.com/other-callback"}], + @pytest.mark.parametrize("auth_code", + [{"redirect_uri": "https://client.example.com/other-callback"}], indirect=True) """ # Default authorize params @@ -324,22 +327,22 @@ async def auth_code(test_client, registered_client, pkce_challenge, request): "code_challenge_method": "S256", "state": "test_state", } - + # Override with any parameters from the test if hasattr(request, "param") and request.param: auth_params.update(request.param) - + response = await test_client.get("/authorize", params=auth_params) assert response.status_code == 302, f"Failed to get auth code: {response.content}" - + # Extract the authorization code redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "code" in query_params, f"No code in response: {query_params}" auth_code = query_params["code"][0] - + return { "code": auth_code, "redirect_uri": auth_params["redirect_uri"], @@ -350,10 +353,10 @@ async def auth_code(test_client, registered_client, pkce_challenge, request): @pytest.fixture async def tokens(test_client, registered_client, auth_code, pkce_challenge, request): """Exchange authorization code for tokens. - + Parameters can be customized via indirect parameterization: - @pytest.mark.parametrize("tokens", - [{"code_verifier": "wrong_verifier"}], + @pytest.mark.parametrize("tokens", + [{"code_verifier": "wrong_verifier"}], indirect=True) """ # Default token request params @@ -365,13 +368,13 @@ async def tokens(test_client, registered_client, auth_code, pkce_challenge, requ "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": auth_code["redirect_uri"], } - + # Override with any parameters from the test if hasattr(request, "param") and request.param: token_params.update(request.param) - + response = await test_client.post("/token", data=token_params) - + # Don't assert success here since some tests will intentionally cause errors return { "response": response, @@ -408,7 +411,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): "refresh_token", ] assert metadata["service_documentation"] == "https://docs.example.com" - + @pytest.mark.anyio async def test_token_validation_error(self, test_client: httpx.AsyncClient): """Test token endpoint error - validation error.""" @@ -422,10 +425,14 @@ async def test_token_validation_error(self, test_client: httpx.AsyncClient): ) error_response = response.json() assert error_response["error"] == "invalid_request" - assert "error_description" in error_response # Contains validation error messages - + assert ( + "error_description" in error_response + ) # Contains validation error messages + @pytest.mark.anyio - @pytest.mark.parametrize("registered_client", [{"grant_types": ["authorization_code"]}], indirect=True) + @pytest.mark.parametrize( + "registered_client", [{"grant_types": ["authorization_code"]}], indirect=True + ) async def test_token_unsupported_grant_type(self, test_client, registered_client): """Test token endpoint error - unsupported grant type.""" # Try to use refresh_token grant type with a client that only supports authorization_code @@ -442,9 +449,11 @@ async def test_token_unsupported_grant_type(self, test_client, registered_client error_response = response.json() assert error_response["error"] == "unsupported_grant_type" assert "supported grant types" in error_response["error_description"] - + @pytest.mark.anyio - async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge): + async def test_token_invalid_auth_code( + self, test_client, registered_client, pkce_challenge + ): """Test token endpoint error - authorization code does not exist.""" # Try to use a non-existent authorization code response = await test_client.post( @@ -464,29 +473,36 @@ async def test_token_invalid_auth_code(self, test_client, registered_client, pkc assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" - assert "authorization code does not exist" in error_response["error_description"] - + assert ( + "authorization code does not exist" in error_response["error_description"] + ) + @pytest.mark.anyio async def test_token_expired_auth_code( - self, test_client, registered_client, auth_code, pkce_challenge, mock_oauth_provider + self, + test_client, + registered_client, + auth_code, + pkce_challenge, + mock_oauth_provider, ): """Test token endpoint error - authorization code has expired.""" # Get the current time for our time mocking current_time = time.time() - - # Find the auth code object + + # Find the auth code object code_value = auth_code["code"] found_code = None for code_obj in mock_oauth_provider.auth_codes.values(): if code_obj.code == code_value: found_code = code_obj break - + assert found_code is not None - + # Authorization codes are typically short-lived (5 minutes = 300 seconds) # So we'll mock time to be 10 minutes (600 seconds) in the future - with unittest.mock.patch('time.time', return_value=current_time + 600): + with unittest.mock.patch("time.time", return_value=current_time + 600): # Try to use the expired authorization code response = await test_client.post( "/token", @@ -502,14 +518,26 @@ async def test_token_expired_auth_code( assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" - assert "authorization code has expired" in error_response["error_description"] - + assert ( + "authorization code has expired" in error_response["error_description"] + ) + @pytest.mark.anyio - @pytest.mark.parametrize("registered_client", - [{"redirect_uris": ["https://client.example.com/callback", - "https://client.example.com/other-callback"]}], - indirect=True) - async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge): + @pytest.mark.parametrize( + "registered_client", + [ + { + "redirect_uris": [ + "https://client.example.com/callback", + "https://client.example.com/other-callback", + ] + } + ], + indirect=True, + ) + async def test_token_redirect_uri_mismatch( + self, test_client, registered_client, auth_code, pkce_challenge + ): """Test token endpoint error - redirect URI mismatch.""" # Try to use the code with a different redirect URI response = await test_client.post( @@ -527,9 +555,11 @@ async def test_token_redirect_uri_mismatch(self, test_client, registered_client, error_response = response.json() assert error_response["error"] == "invalid_request" assert "redirect_uri did not match" in error_response["error_description"] - + @pytest.mark.anyio - async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code): + async def test_token_code_verifier_mismatch( + self, test_client, registered_client, auth_code + ): """Test token endpoint error - PKCE code verifier mismatch.""" # Try to use the code with an incorrect code verifier response = await test_client.post( @@ -547,7 +577,7 @@ async def test_token_code_verifier_mismatch(self, test_client, registered_client error_response = response.json() assert error_response["error"] == "invalid_grant" assert "incorrect code_verifier" in error_response["error_description"] - + @pytest.mark.anyio async def test_token_invalid_refresh_token(self, test_client, registered_client): """Test token endpoint error - refresh token does not exist.""" @@ -565,15 +595,20 @@ async def test_token_invalid_refresh_token(self, test_client, registered_client) error_response = response.json() assert error_response["error"] == "invalid_grant" assert "refresh token does not exist" in error_response["error_description"] - + @pytest.mark.anyio async def test_token_expired_refresh_token( - self, test_client, registered_client, auth_code, pkce_challenge, mock_oauth_provider + self, + test_client, + registered_client, + auth_code, + pkce_challenge, + mock_oauth_provider, ): """Test token endpoint error - refresh token has expired.""" # Step 1: First, let's create a token and refresh token at the current time current_time = time.time() - + # Exchange authorization code for tokens normally token_response = await test_client.post( "/token", @@ -589,10 +624,12 @@ async def test_token_expired_refresh_token( assert token_response.status_code == 200 tokens = token_response.json() refresh_token = tokens["refresh_token"] - + # Step 2: Now let's time travel forward 4 hours (tokens expire in 1 hour by default) # Mock the time.time() function to return a value 4 hours in the future - with unittest.mock.patch('time.time', return_value=current_time + 14400): # 4 hours = 14400 seconds + with unittest.mock.patch( + "time.time", return_value=current_time + 14400 + ): # 4 hours = 14400 seconds # Try to use the refresh token which should now be considered expired response = await test_client.post( "/token", @@ -603,13 +640,13 @@ async def test_token_expired_refresh_token( "refresh_token": refresh_token, }, ) - + # In the "future", the token should be considered expired assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" assert "refresh token has expired" in error_response["error_description"] - + @pytest.mark.anyio async def test_token_invalid_scope( self, test_client, registered_client, auth_code, pkce_challenge @@ -628,10 +665,10 @@ async def test_token_invalid_scope( }, ) assert token_response.status_code == 200 - + tokens = token_response.json() refresh_token = tokens["refresh_token"] - + # Try to use refresh token with an invalid scope response = await test_client.post( "/token", @@ -675,7 +712,7 @@ async def test_client_registration( # assert await mock_oauth_provider.clients_store.get_client( # client_info["client_id"] # ) is not None - + @pytest.mark.anyio async def test_client_registration_missing_required_fields( self, test_client: httpx.AsyncClient @@ -696,7 +733,7 @@ async def test_client_registration_missing_required_fields( assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" assert error_data["error_description"] == "redirect_uris: Field required" - + @pytest.mark.anyio async def test_client_registration_invalid_uri( self, test_client: httpx.AsyncClient @@ -716,8 +753,11 @@ async def test_client_registration_invalid_uri( error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" - assert error_data["error_description"] == "redirect_uris.0: Input should be a valid URL, relative URL without a base" - + assert ( + error_data["error_description"] + == "redirect_uris.0: Input should be a valid URL, relative URL without a base" + ) + @pytest.mark.anyio async def test_client_registration_empty_redirect_uris( self, test_client: httpx.AsyncClient @@ -736,11 +776,17 @@ async def test_client_registration_empty_redirect_uris( error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" - assert error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0" - + assert ( + error_data["error_description"] + == "redirect_uris: List should have at least 1 item after validation, not 0" + ) + @pytest.mark.anyio async def test_authorize_form_post( - self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge + self, + test_client: httpx.AsyncClient, + mock_oauth_provider: MockOAuthProvider, + pkce_challenge, ): """Test the authorization endpoint using POST with form-encoded data.""" # Register a client @@ -781,7 +827,10 @@ async def test_authorize_form_post( @pytest.mark.anyio async def test_authorization_get( - self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge + self, + test_client: httpx.AsyncClient, + mock_oauth_provider: MockOAuthProvider, + pkce_challenge, ): """Test the full authorization flow.""" # 1. Register a client @@ -884,9 +933,13 @@ async def test_authorization_get( assert response.status_code == 200 # Verify that the token was revoked - assert await mock_oauth_provider.load_access_token( - new_token_response["access_token"] - ) is None + assert ( + await mock_oauth_provider.load_access_token( + new_token_response["access_token"] + ) + is None + ) + @pytest.mark.anyio async def test_revoke_invalid_token(self, test_client, registered_client): """Test revoking an invalid token.""" @@ -900,6 +953,7 @@ async def test_revoke_invalid_token(self, test_client, registered_client): ) # per RFC, this should return 200 even if the token is invalid assert response.status_code == 200 + @pytest.mark.anyio async def test_revoke_with_malformed_token(self, test_client, registered_client): response = await test_client.post( @@ -908,23 +962,22 @@ async def test_revoke_with_malformed_token(self, test_client, registered_client) "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "token": 123, - "token_type_hint": "asdf" + "token_type_hint": "asdf", }, ) assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_request" assert "token_type_hint" in error_response["error_description"] - - - class TestFastMCPWithAuth: """Test FastMCP server with authentication.""" @pytest.mark.anyio - async def test_fastmcp_with_auth(self, mock_oauth_provider: MockOAuthProvider, pkce_challenge): + async def test_fastmcp_with_auth( + self, mock_oauth_provider: MockOAuthProvider, pkce_challenge + ): """Test creating a FastMCP server with authentication.""" # Create FastMCP server with auth provider mcp = FastMCP( @@ -1053,15 +1106,17 @@ def test_tool(x: int) -> str: assert set(sse_data["result"]["capabilities"].keys()) == set( ("experimental", "prompts", "resources", "tools") ) - - + + class TestAuthorizeEndpointErrors: """Test error handling in the OAuth authorization endpoint.""" - + @pytest.mark.anyio - async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): + async def test_authorize_missing_client_id( + self, test_client: httpx.AsyncClient, pkce_challenge + ): """Test authorization endpoint with missing client_id. - + According to the OAuth2.0 spec, if client_id is missing, the server should inform the resource owner and NOT redirect. """ @@ -1073,19 +1128,21 @@ async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, "redirect_uri": "https://client.example.com/callback", "state": "test_state", "code_challenge": pkce_challenge["code_challenge"], - "code_challenge_method": "S256" + "code_challenge_method": "S256", }, ) - + # Should NOT redirect, should show an error page assert response.status_code == 400 # The response should include an error message about missing client_id assert "client_id" in response.text.lower() - + @pytest.mark.anyio - async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): + async def test_authorize_invalid_client_id( + self, test_client: httpx.AsyncClient, pkce_challenge + ): """Test authorization endpoint with invalid client_id. - + According to the OAuth2.0 spec, if client_id is invalid, the server should inform the resource owner and NOT redirect. """ @@ -1097,24 +1154,24 @@ async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, "redirect_uri": "https://client.example.com/callback", "state": "test_state", "code_challenge": pkce_challenge["code_challenge"], - "code_challenge_method": "S256" + "code_challenge_method": "S256", }, ) - + # Should NOT redirect, should show an error page assert response.status_code == 400 # The response should include an error message about invalid client_id assert "client" in response.text.lower() - + @pytest.mark.anyio async def test_authorize_missing_redirect_uri( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with missing redirect_uri. - + If client has only one registered redirect_uri, it can be omitted. """ - + response = await test_client.get( "/authorize", params={ @@ -1126,22 +1183,22 @@ async def test_authorize_missing_redirect_uri( "state": "test_state", }, ) - + # Should redirect to the registered redirect_uri assert response.status_code == 302, response.content redirect_url = response.headers["location"] assert redirect_url.startswith("https://client.example.com/callback") - + @pytest.mark.anyio async def test_authorize_invalid_redirect_uri( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with invalid redirect_uri. - + According to the OAuth2.0 spec, if redirect_uri is invalid or doesn't match, the server should inform the resource owner and NOT redirect. """ - + response = await test_client.get( "/authorize", params={ @@ -1153,25 +1210,33 @@ async def test_authorize_invalid_redirect_uri( "state": "test_state", }, ) - + # Should NOT redirect, should show an error page assert response.status_code == 400, response.content # The response should include an error message about redirect_uri mismatch assert "redirect" in response.text.lower() @pytest.mark.anyio - @pytest.mark.parametrize("registered_client", - [{"redirect_uris": ["https://client.example.com/callback", - "https://client.example.com/other-callback"]}], - indirect=True) + @pytest.mark.parametrize( + "registered_client", + [ + { + "redirect_uris": [ + "https://client.example.com/callback", + "https://client.example.com/other-callback", + ] + } + ], + indirect=True, + ) async def test_authorize_missing_redirect_uri_multiple_registered( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with missing redirect_uri when client has multiple registered URIs. - + If client has multiple registered redirect_uris, redirect_uri must be provided. """ - + response = await test_client.get( "/authorize", params={ @@ -1183,22 +1248,22 @@ async def test_authorize_missing_redirect_uri_multiple_registered( "state": "test_state", }, ) - + # Should NOT redirect, should return a 400 error assert response.status_code == 400 # The response should include an error message about missing redirect_uri assert "redirect_uri" in response.text.lower() - + @pytest.mark.anyio async def test_authorize_unsupported_response_type( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with unsupported response_type. - + According to the OAuth2.0 spec, for other errors like unsupported_response_type, the server should redirect with error parameters. """ - + response = await test_client.get( "/authorize", params={ @@ -1210,28 +1275,28 @@ async def test_authorize_unsupported_response_type( "state": "test_state", }, ) - + # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "error" in query_params assert query_params["error"][0] == "unsupported_response_type" # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" - + @pytest.mark.anyio async def test_authorize_missing_response_type( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with missing response_type. - + Missing required parameter should result in invalid_request error. """ - + response = await test_client.get( "/authorize", params={ @@ -1243,25 +1308,25 @@ async def test_authorize_missing_response_type( "state": "test_state", }, ) - + # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "error" in query_params assert query_params["error"][0] == "invalid_request" # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" - + @pytest.mark.anyio async def test_authorize_missing_pkce_challenge( self, test_client: httpx.AsyncClient, registered_client ): """Test authorization endpoint with missing PKCE code_challenge. - + Missing PKCE parameters should result in invalid_request error. """ response = await test_client.get( @@ -1274,28 +1339,28 @@ async def test_authorize_missing_pkce_challenge( # using default URL }, ) - + # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "error" in query_params assert query_params["error"][0] == "invalid_request" # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" - + @pytest.mark.anyio async def test_authorize_invalid_scope( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with invalid scope. - + Invalid scope should redirect with invalid_scope error. """ - + response = await test_client.get( "/authorize", params={ @@ -1308,13 +1373,13 @@ async def test_authorize_invalid_scope( "state": "test_state", }, ) - + # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "error" in query_params assert query_params["error"][0] == "invalid_scope" # State should be preserved