diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 31cdceab..4058246a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -77,7 +77,7 @@ jobs: strategy: fail-fast: false matrix: - primitive: [sampling, tools, resources, prompts, elicitation, notifications] + primitive: [sampling, tools, resources, prompts, elicitation, notifications, auth] steps: - uses: actions/checkout@v3 - name: Set up Python 3.11 diff --git a/README.md b/README.md index 4458cfc8..22010e97 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ | Supports | | | :------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| **Primitives** | [![Tools](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=primitive-tools&label=Tools&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) [![Resources](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=primitive-resources&label=Resources&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) [![Prompts](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=primitive-prompts&label=Prompts&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) [![Sampling](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=primitive-sampling&label=Sampling&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) [![Elicitation](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=primitive-elicitation&label=Elicitation&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) | +| **Primitives** | [![Tools](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=primitive-tools&label=Tools&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) [![Resources](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=primitive-resources&label=Resources&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) [![Prompts](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=primitive-prompts&label=Prompts&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) [![Sampling](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=primitive-sampling&label=Sampling&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) [![Elicitation](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=primitive-elicitation&label=Elicitation&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) [![Authentication](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=primitive-authentication&label=Authentication&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) | | **Transports** | [![Stdio](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=transport-stdio&label=Stdio&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) [![SSE](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=transport-sse&label=SSE&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) [![Streamable HTTP](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/tests.yml?job=transport-streamableHttp&label=Streamable%20HTTP&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/tests.yml) | ## Features diff --git a/docs/client/authentication.mdx b/docs/client/authentication.mdx new file mode 100644 index 00000000..612f0feb --- /dev/null +++ b/docs/client/authentication.mdx @@ -0,0 +1,416 @@ +--- +title: "Authentication" +description: "mcp-use supports multiple authentication methods for MCP servers, including OAuth 2.0 with automatic Dynamic Client Registration (DCR), bearer tokens, and custom authentication providers." +icon: "key" +--- + +## Quick Start + +### OAuth Authentication + +For servers that support OAuth, you can use Dynamic Client Registration (automatic) or pre-registered clients: + +```python +from mcp_use import MCPClient, MCPAgent +from langchain_openai import ChatOpenAI + +# Dynamic Client Registration (automatic) +config = { + "mcpServers": { + "linear": { + "url": "https://mcp.linear.app/sse", + # It's not needed to specify auth section + } + } +} + +# Or with pre-registered client +config = { + "mcpServers": { + "my_server": { + "url": "https://api.example.com/mcp/", + "auth": { + "client_id": "your-client-id", + "client_secret": "your-client-secret", + } + } + } +} + +# Create client and agent +client = MCPClient(config=config) +llm = ChatOpenAI(model="gpt-4o-mini", temperature=0) +agent = MCPAgent(llm=llm, client=client) + +# Use the agent +response = await agent.run("Your query here") +print(response) +``` + +### Bearer Token Authentication + +For servers requiring API keys: + +```python +config = { + "mcpServers": { + "api": { + "url": "https://api.example.com/mcp/sse", + "auth": "sk-your-api-key-here" + } + } +} +``` + +### Custom Port Configuration + +You can specify a custom port for OAuth callbacks to avoid conflicts: + +```python +config = { + "mcpServers": { + "my_server": { + "url": "https://api.example.com/mcp/", + "auth": { + "client_id": "your-client-id", + "client_secret": "your-client-secret", + "callback_port": 8082, # Use custom port instead of default 8080 + } + } + } +} +``` + +## Authentication Methods + +### 1. OAuth 2.0 Authentication + +OAuth 2.0 is the most common authentication method for MCP servers. mcp-use supports: + +- **Dynamic Client Registration (DCR)** - Automatic client registration +- **Pre-registered Clients** - Using existing OAuth applications +- **Custom OAuth Providers** - With explicit metadata + +#### Dynamic Client Registration + +For servers that support DCR, you don't need to register a client manually: + +```python +config = { + "mcpServers": { + "linear": { + "url": "https://mcp.linear.app/sse", + "auth": { + "scope": "read write" # Optional scopes + } + } + } +} +``` + +#### Pre-registered OAuth Client + +For servers requiring manual client registration: + +```python +config = { + "mcpServers": { + "example": { + "url": "https://api.example.com/mcp/", + "auth": { + "client_id": "your-registered-client-id", + "client_secret": "your-client-secret", # Optional + "callback_port": 8081, # Optional custom port + } + } + } +} +``` + +#### OAuth Provider with Metadata + +For servers with known OAuth endpoints, provide metadata upfront: + +```python +config = { + "mcpServers": { + "linear": { + "url": "https://mcp.linear.app/sse", + "auth": { + "oauth_provider": { + "id": "linear", + "display_name": "Linear", + "metadata": { + "issuer": "https://mcp.linear.app", + "authorization_endpoint": "https://mcp.linear.app/authorize", + "token_endpoint": "https://mcp.linear.app/token", + "registration_endpoint": "https://mcp.linear.app/register" + } + } + } + } + } +} +``` + +### 2. Bearer Token Authentication + +For servers requiring simple API keys or bearer tokens: + +```python +config = { + "mcpServers": { + "api": { + "url": "https://api.example.com/mcp/sse", + "auth": "sk-your-api-key-here" + } + } +} +``` + +### 3. Custom Authentication + +For servers requiring custom authentication methods: + +```python +from httpx import BasicAuth, DigestAuth + +config = { + "mcpServers": { + "secure": { + "url": "https://secure.example.com/mcp/sse", + "auth": BasicAuth("username", "password") + }, + "digest": { + "url": "https://digest.example.com/mcp/sse", + "auth": DigestAuth("username", "password") + } + } +} +``` + +### 4. No Authentication + +For public servers or servers without authentication: + +```python +config = { + "mcpServers": { + "public": { + "url": "https://public.example.com/mcp/sse" + # No auth config - will attempt discovery or continue without auth + } + } +} +``` + +## Complete Examples + +### GitHub MCP Server Example + +The GitHub MCP server requires OAuth authentication. You'll need to create a GitHub OAuth App first: + +1. **Create a GitHub OAuth App**: + - Go to [GitHub OAuth Apps](https://github.com/settings/applications/new) + - Set **Application name**: `your-app-name` + - Set **Homepage URL**: `http://localhost:8080` (or your custom port) + - Set **Authorization callback URL**: `http://localhost:8080/callback` (or your custom port) + - Click "Register application" + - Copy your **Client ID** and **Client Secret** + +2. **Configure mcp-use**: + +```python +import asyncio +from mcp_use import MCPClient, MCPAgent +from langchain_openai import ChatOpenAIxw + +async def github_example(): + # GitHub MCP server configuration + config = { + "mcpServers": { + "github": { + "url": "https://api.githubcopilot.com/mcp/", + "auth": { + "client_id": "your-github-client-id", + "client_secret": "your-github-client-secret", + "scope": "repo", # Needed for GitHub + "callback_port": 8080, # The same port as the callback on OAuth app + } + } + } + } + + # Create client and agent + client = MCPClient(config=config) + llm = ChatOpenAI(model="gpt-4o-mini", temperature=0) + agent = MCPAgent(llm=llm, client=client, max_steps=30) + + # Example queries + queries = [ + "Your queries" + ] + + for query in queries: + print(f"\n🔍 Query: {query}") + response = await agent.run(query) + print(f"📝 Response: {response}") + +if __name__ == "__main__": + asyncio.run(github_example()) +``` + +### Multi-Server Configuration + +You can mix different authentication methods across servers: + +```python +from httpx import BasicAuth + +config = { + "mcpServers": { + "github": { + "url": "https://api.githubcopilot.com/mcp/", + "auth": { + "client_id": "your-github-client-id", + "client_secret": "your-github-client-secret", + "scope": "repo", + "callback_port": 8082, # Remember to use the same on GitHub + } + }, + "linear": { + "url": "https://mcp.linear.app/sse", + # DCR + }, + "api": { + "url": "https://api.example.com/mcp/sse", + "auth": "sk-api-key" # Bearer token + }, + "secure": { + "url": "https://secure.example.com/mcp/sse", + "auth": BasicAuth("username", "password") # Custom auth + } + } +} + +client = MCPClient(config=config) +``` + +## OAuth Flow Process + +When OAuth authentication is required: + +1. **Browser Opens**: Your default browser opens to the authorization page +2. **Grant Access**: Review and approve the requested permissions +3. **Automatic Redirect**: You're redirected to a local callback URL +4. **Token Storage**: Access tokens are stored securely in `~/.mcp_use/tokens/` + +## Token Storage + +Authentication data is stored securely: +- **Access Tokens**: `~/.mcp_use/tokens/{server_domain}.json` +- **Client Registrations**: `~/.mcp_use/tokens/registrations/{server_domain}_registration.json` + +## Configuration Options + +### OAuth Configuration Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `client_id` | string | No* | OAuth client ID (required if not using DCR) | +| `client_secret` | string | No | OAuth client secret (required if not using DCR) | +| `scope` | string | No | OAuth scopes to request | +| `callback_port` | integer | No | Port for OAuth callback (default: 8080) | +| `oauth_provider` | object | No | OAuth provider metadata | + +*Required unless using Dynamic Client Registration + +### Port Configuration + +- **Default Port**: 8080 +- **Custom Ports**: Any available port (e.g., 8081, 8082, 3000) +- **Port Conflicts**: mcp-use will check if the port is available before starting OAuth flow + +## Troubleshooting + +### Common Issues + +#### OAuth Discovery Fails + +If a server doesn't support OAuth discovery: +- Provide an `oauth_provider` with metadata +- Use a pre-registered `client_id` +- Check if the server requires different authentication + +#### "Invalid redirect URI" Error + +Solutions: +- Use Dynamic Client Registration (omit `client_id`) +- Register your app with supported redirect URIs +- Check if your provider supports wildcard redirect URIs +- Ensure callback URL matches your OAuth app configuration + +#### Port Already in Use + +If you get a port conflict error: +```python +# Use a different port +"callback_port": 8081 # or any other available port +``` + +#### GitHub OAuth Issues + +For GitHub specifically: +- Ensure your OAuth app callback URL matches: `http://localhost:8080/callback` (or your custom port) +- Use correct scopes: `repo`, `read:user`, etc. +- Check that your GitHub OAuth app is properly configured + +### Debugging + +Enable debug logging to see detailed authentication flow: + +```python +from mcp_use import set_debug +set_debug(2) # Enable verbose logging +``` + +## Security Best Practices + +- **Token Storage**: Tokens are stored with restricted permissions +- **Version Control**: Never commit authentication files to version control +- **CSRF Protection**: OAuth flow uses state parameter for CSRF protection +- **Localhost Callbacks**: All callbacks use localhost (127.0.0.1) for security +- **Isolation**: Each server's authentication is isolated +- **Environment Variables**: Use environment variables for sensitive data: + +```python +import os + +config = { + "mcpServers": { + "github": { + "url": "https://api.githubcopilot.com/mcp/", + "auth": { + "client_id": os.getenv("GITHUB_CLIENT_ID"), + "client_secret": os.getenv("GITHUB_CLIENT_SECRET"), + "scope": "repo", + } + } + } +} +``` + +## Example Servers that support OAuth + +### OAuth with DCR Support +- **Linear**: `https://mcp.linear.app/sse` +- **Asana**: `https://mcp.asana.com/sse` +- **Atlassian**: `https://mcp.atlassian.com/v1/sse` + +### OAuth with Manual Registration +- **GitHub**: `https://api.githubcopilot.com/mcp/` + +### Bearer Token +- Most API-based MCP servers + +Check your server's documentation for specific authentication requirements and supported methods. diff --git a/docs/docs.json b/docs/docs.json index 6acf897d..20740390 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -77,7 +77,8 @@ "client/sampling", "client/elicitation", "client/notifications", - "client/logging" + "client/logging", + "client/authentication" ] } ] diff --git a/examples/simple_oauth_example.py b/examples/simple_oauth_example.py new file mode 100644 index 00000000..cc1e2378 --- /dev/null +++ b/examples/simple_oauth_example.py @@ -0,0 +1,37 @@ +from langchain_openai import ChatOpenAI + +from mcp_use import MCPAgent, MCPClient + +# This example demonstrates OAuth with Dynamic Client Registration (DCR) +# The client will automatically register itself with the Linear MCP server +# No manual client_id configuration required! + +# Clean MCP configuration - no auth details in the server config +linear_config = {"mcpServers": {"linear": {"url": "https://mcp.linear.app/sse"}}} + + +async def main(): + # Create client with OAuth-enabled configuration at the client level + # Option 1: Dynamic Client Registration (empty dict) + client = MCPClient(config=linear_config) + + # Option 2: If you already have a registered client_id, you can use it: + # client = MCPClient( + # config=linear_config, + # auth={ + # "client_id": "YOUR_CLIENT_ID", # Use your pre-registered client ID + # "client_secret": "YOUR_SECRET", # Only if required + # } + # ) + + llm = ChatOpenAI(model="gpt-4o-mini", temperature=0) + agent = MCPAgent(llm=llm, client=client) + + response = await agent.run(query="What are my latest linear tickets") + print(response) + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/mcp_use/auth/__init__.py b/mcp_use/auth/__init__.py new file mode 100644 index 00000000..561b46d5 --- /dev/null +++ b/mcp_use/auth/__init__.py @@ -0,0 +1,6 @@ +"""Authentication support for MCP clients.""" + +from .bearer import BearerAuth +from .oauth import OAuth + +__all__ = ["BearerAuth", "OAuth"] diff --git a/mcp_use/auth/bearer.py b/mcp_use/auth/bearer.py new file mode 100644 index 00000000..98c2ede7 --- /dev/null +++ b/mcp_use/auth/bearer.py @@ -0,0 +1,17 @@ +"""Bearer token authentication support.""" + +from collections.abc import Generator + +import httpx +from pydantic import BaseModel, SecretStr + + +class BearerAuth(httpx.Auth, BaseModel): + """Bearer token authentication for HTTP requests.""" + + token: SecretStr + + def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]: + """Apply bearer token authentication to the request.""" + request.headers["Authorization"] = f"Bearer {self.token.get_secret_value()}" + yield request diff --git a/mcp_use/auth/oauth.py b/mcp_use/auth/oauth.py new file mode 100644 index 00000000..f5d96d99 --- /dev/null +++ b/mcp_use/auth/oauth.py @@ -0,0 +1,625 @@ +"""OAuth authentication support for MCP clients.""" + +import json +import secrets +import webbrowser +from datetime import UTC, datetime, timedelta +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +import httpx +from authlib.integrations.httpx_client import AsyncOAuth2Client +from authlib.oauth2 import OAuth2Error +from pydantic import BaseModel, Field, HttpUrl, SecretStr + +from ..exceptions import OAuthAuthenticationError, OAuthDiscoveryError +from ..logging import logger +from .bearer import BearerAuth +from .oauth_callback import OAuthCallbackServer + + +class ServerOAuthMetadata(BaseModel): + """OAuth metadata from MCP server with flexible field support. + It is essentially a configuration that tells MCP client: + + - Where to send users for authorization + - Where to exchange the codes for tokens + - Which OAuth features are supported + - Where to register new users with DCR""" + + issuer: HttpUrl # The OAuth server's identity + authorization_endpoint: HttpUrl # URL with endpoint for client auth + token_endpoint: HttpUrl # URL with endpoint for tokens' exchange + userinfo_endpoint: HttpUrl | None = None + revocation_endpoint: HttpUrl | None = None + introspection_endpoint: HttpUrl | None = None + registration_endpoint: HttpUrl | None = None # Endpoint for DCR + jwks_uri: HttpUrl | None = None + response_types_supported: list[str] = Field(default_factory=lambda: ["code"]) + subject_types_supported: list[str] = Field(default_factory=lambda: ["public"]) + id_token_signing_alg_values_supported: list[str] = Field(default_factory=lambda: ["RS256"]) + scopes_supported: list[str] | None = None # Which permissions are supported + token_endpoint_auth_methods_supported: list[str] = Field(default_factory=lambda: ["client_secret_basic"]) + claims_supported: list[str] | None = None + code_challenge_methods_supported: list[str] | None = None + + class Config: + extra = "allow" # Allow additional fields + + +class OAuthClientProvider(BaseModel): + """OAuth client provider configuration for a specific server. + + This contains all the information needed to authenticate with an OAuth server + without needing to discover metadata or register clients dynamically.""" + + id: str # Unique identifier + display_name: str + metadata: ServerOAuthMetadata | dict[str, Any] + + @property + def oauth_metadata(self) -> ServerOAuthMetadata: + """Get OAuth metadata as ServerOAuthMetadata instance.""" + if isinstance(self.metadata, dict): + return ServerOAuthMetadata(**self.metadata) + return self.metadata + + +class TokenData(BaseModel): + """OAuth token data. + + This is the information received after + successfull authentication""" + + access_token: str # Actual credential used for requests + token_type: str = "Bearer" + expires_at: float | None = None + refresh_token: str | None = None + scope: str | None = None + + +class ClientRegistrationResponse(BaseModel): + """Dynamic Client Registration response. + + It represents the response from an OAuth server + when you dinamically register a new OAuth client.""" + + client_id: str + client_secret: str | None = None + client_id_issued_at: int | None = None + client_secret_expires_at: int | None = None + redirect_uris: list[str] | None = None # Where auth server should redirect after auth + grant_types: list[str] | None = None # Which oauth flows it uses + response_types: list[str] | None = None + client_name: str | None = None + token_endpoint_auth_method: str | None = None + + class Config: + extra = "allow" # Allow additional fields from server + + +class FileTokenStorage: + """File-based token storage. + + It's responsible for: + + - Saving OAuth tokens to disk after auth + - Loading saved tokens when the app restarts + - Deleting tokens when they're revoked + - Organizing tokens by server URL""" + + def __init__(self, base_dir: Path | None = None): + """Initialize token storage. + + Args: + base_dir: Base directory for token storage. Defaults to ~/.mcp_use/tokens + """ + self.base_dir = base_dir or Path.home() / ".mcp_use" / "tokens" + logger.debug(f"FileTokenStorage initialized with base_dir: {self.base_dir}") + self.base_dir.mkdir(parents=True, exist_ok=True) + + def _get_token_path(self, server_url: str) -> Path: + """Get token file path for a server.""" + # Create a safe filename from the URL + parsed = urlparse(server_url) + filename = f"{parsed.netloc}_{parsed.path.replace('/', '_')}.json" + path = self.base_dir / filename + logger.debug(f"Token path for server '{server_url}' is '{path}'") + return path + + async def save_tokens(self, server_url: str, tokens: dict[str, Any]) -> None: + """Save tokens to file.""" + token_path = self._get_token_path(server_url) + logger.debug(f"Saving tokens for '{server_url}' to '{token_path}'") + token_data = TokenData(**tokens) + token_path.write_text(token_data.model_dump_json()) + logger.debug(f"Tokens saved successfully for '{server_url}'") + + async def load_tokens(self, server_url: str) -> TokenData | None: + """Load tokens from file.""" + token_path = self._get_token_path(server_url) + logger.debug(f"Attempting to load tokens for '{server_url}' from '{token_path}'") + if not token_path.exists(): + logger.debug(f"Token file not found: '{token_path}'") + return None + + try: + data = json.loads(token_path.read_text()) + token_data = TokenData(**data) + logger.debug(f"Successfully loaded tokens for '{server_url}'") + return token_data + except (json.JSONDecodeError, ValueError) as e: + logger.debug(f"Failed to load or parse token file '{token_path}': {e}") + return None + + async def delete_tokens(self, server_url: str) -> None: + """Delete tokens for a server.""" + token_path = self._get_token_path(server_url) + logger.debug(f"Deleting tokens for '{server_url}' at '{token_path}'") + if token_path.exists(): + token_path.unlink() + logger.debug(f"Token file '{token_path}' deleted.") + else: + logger.debug(f"Token file '{token_path}' not found, nothing to delete.") + + +class OAuth: + """OAuth authentication handler for MCP clients. + + This is the main class that handles all the authentication + It has several features: + + - Discovers OAuth server capabilities automatically + - Registers client dynamically when possible + - Manages token storage and refresh automaticlly""" + + def __init__( + self, + server_url: str, + token_storage: FileTokenStorage | None = None, + scope: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + callback_port: int | None = None, + oauth_provider: OAuthClientProvider | None = None, + ): + """Initialize OAuth handler. + + Args: + server_url: The MCP server URL + token_storage: Token storage implementation. Defaults to FileTokenStorage + scope: OAuth scopes to request + client_id: OAuth client ID. If not provided, will attempt dynamic registration + client_secret: OAuth client secret (for confidential clients) + callback_port: Port for local callback server, if empty, 8080 is used + oauth_provider: OAuth client provider to prevent metadata discovery + """ + logger.debug(f"Initializing OAuth for server: {server_url}") + self.server_url = server_url + self.token_storage = token_storage or FileTokenStorage() + self.scope = scope + self.client_id = client_id + self.client_secret = client_secret + + if callback_port: + self.callback_port = callback_port + logger.info(f"Using custom callback port {self.callback_port} provided in config") + else: + self.callback_port = 8080 + logger.info(f"Using default callback port {self.callback_port}") + + # Set the default redirect uri + self.redirect_uri = f"http://localhost:{self.callback_port}/callback" + self._oauth_provider = oauth_provider + self._metadata: ServerOAuthMetadata | None = None + + if self._oauth_provider: + self._metadata = self._oauth_provider.oauth_metadata + logger.debug(f"Using OAuth provider {self._oauth_provider.id} with metadata") + + self._client: AsyncOAuth2Client | None = None + self._bearer_auth: BearerAuth | None = None + logger.debug(f"OAuth initialized with scope='{self.scope}', client_id='{self.client_id}'") + + async def initialize(self, client: httpx.AsyncClient) -> BearerAuth | None: + """Initialize OAuth and return bearer auth if tokens exist.""" + logger.debug(f"OAuth.initialize called for {self.server_url}") + # Try to load existing tokens + logger.debug("Attempting to load existing tokens") + token_data = await self.token_storage.load_tokens(self.server_url) + if token_data: + logger.debug("Found existing tokens, checking validity") + if self._is_token_valid(token_data): + logger.debug("Existing token is valid, creating BearerAuth") + self._bearer_auth = BearerAuth(token=SecretStr(token_data.access_token)) + logger.debug("OAuth.initialize returning existing valid BearerAuth") + return self._bearer_auth + else: + logger.debug("Existing token is expired") + else: + logger.debug("No existing tokens found") + + # Discover OAuth metadata + if not self._metadata: + logger.debug("No valid token, proceeding to discover OAuth metadata") + await self._discover_metadata(client) + else: + logger.debug("Using provided OAuth metadata, skipping discovery") + + logger.debug("OAuth.initialize finished, no valid token available yet") + return None + + async def authenticate(self) -> BearerAuth: + """Perform OAuth authentication flow.""" + logger.debug("OAuth.authenticate called") + if not self._metadata: + logger.error("OAuth.authenticate called before metadata was discovered.") + raise OAuthAuthenticationError("OAuth metadata not discovered") + + # The port check should be done now. OAuth servers + # register client_id with also redirect_uri, so we + # have to ensure port is available before DCR + try: + import socket + + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", self.callback_port)) + sock.close() + logger.debug(f"Using registered port {self.callback_port} for callback") + except (ValueError, OSError) as exception: + logger.error(f"The port {self.callback_port} is not available! Try using a different port!") + raise exception + + # Try to get client_id - either from config or dynamic registration + client_id = self.client_id + client_secret = self.client_secret + registration = None # Track if we used DCR + + if not client_id: + logger.debug("No client_id provided, attempting dynamic client registration") + # Try to load previously registered client + registration = await self._load_client_registration() + + if registration: + logger.debug("Using previously registered client") + client_id = registration.client_id + client_secret = registration.client_secret + else: + # Attempt dynamic registration + registration = await self._try_dynamic_registration() + if registration: + logger.debug("Dynamic registration successful") + client_id = registration.client_id + client_secret = registration.client_secret + # Store for future use + await self._store_client_registration(registration) + else: + logger.error("Dynamic client registration failed or not supported") + raise OAuthAuthenticationError( + "OAuth requires a client_id. Server does not support dynamic registration. " + "Please provide one in the auth configuration. " + "Example: {'auth': {'client_id': 'your-registered-client-id'}}" + ) + + logger.debug(f"Using client_id: {client_id}") + + # Create OAuth client + logger.debug("Creating AsyncOAuth2Client") + self._client = AsyncOAuth2Client( + client_id=client_id, + client_secret=client_secret, + redirect_uri=self.redirect_uri, + scope=self.scope, + ) + + # Start callback server + logger.debug("Starting OAuth callback server") + + callback_server = OAuthCallbackServer(port=self.callback_port) + redirect_uri = await callback_server.start() + self._client.redirect_uri = redirect_uri + logger.debug(f"Callback server started, redirect_uri: {redirect_uri}") + + # Generate state for CSRF protection + state = secrets.token_urlsafe(32) + logger.debug(f"Generated state for CSRF protection: {state}") + + # Build authorization URL + logger.debug("Creating authorization URL") + auth_url, _ = self._client.create_authorization_url( + str(self._metadata.authorization_endpoint), + state=state, + ) + + logger.debug("OAuth flow started:") + logger.debug(f" Client ID: {client_id}") + logger.debug(f" Authorization endpoint: {self._metadata.authorization_endpoint}") + logger.debug(f" Redirect URI: {redirect_uri}") + logger.debug(f" Scope: {self.scope}") + + # Open browser for authorization + print(f"Opening browser for authorization: {auth_url}") + webbrowser.open(auth_url) + + # Wait for callback + logger.debug("Waiting for authorization code from callback server") + try: + response = await callback_server.wait_for_code() + logger.debug("Received response from callback server") + except TimeoutError as e: + logger.error(f"OAuth callback timed out: {e}") + raise OAuthAuthenticationError(f"OAuth timeout: {e}") from e + + if response.error: + logger.error("OAuth authorization failed:") + logger.error(f" Error: {response.error}") + logger.error(f" Description: {response.error_description}") + logger.error(" The OAuth server returned this error, likely because:") + logger.error(f" 1. The client_id '{client_id}' is not registered with the OAuth server") + logger.error(" 2. The redirect_uri doesn't match the registered one") + logger.error(" 3. The requested scopes are invalid") + raise OAuthAuthenticationError(f"{response.error}: {response.error_description}") + + if not response.code: + logger.error("Callback response did not contain an authorization code") + raise OAuthAuthenticationError("No authorization code received") + + logger.debug(f"Received authorization code: {response.code[:10]}...") + + # Verify state + logger.debug(f"Verifying state. Expected: {state}, Got: {response.state}") + if response.state != state: + logger.error("State mismatch in OAuth callback. Possible CSRF attack.") + raise OAuthAuthenticationError("Invalid state parameter - possible CSRF attack") + logger.debug("State verified successfully") + + # Exchange code for tokens + logger.debug("Exchanging authorization code for tokens") + try: + token_response = await self._client.fetch_token( + str(self._metadata.token_endpoint), + authorization_response=f"{redirect_uri}?code={response.code}&state={response.state}", + grant_type="authorization_code", + ) + logger.debug("Successfully fetched tokens") + except OAuth2Error as e: + logger.error(f"Token exchange failed: {e}") + raise OAuthAuthenticationError(f"Token exchange failed: {e}") from e + + # Save tokens + logger.debug("Saving fetched tokens") + await self.token_storage.save_tokens(self.server_url, token_response) + + # Create bearer auth + logger.debug("Creating BearerAuth with new access token") + self._bearer_auth = BearerAuth(token=SecretStr(token_response["access_token"])) + return self._bearer_auth + + async def _discover_metadata(self, client: httpx.AsyncClient) -> None: + """Discover OAuth metadata from server.""" + logger.debug(f"Discovering OAuth metadata for {self.server_url}") + # Try well-known endpoint first + parsed = urlparse(self.server_url) + + # Edge case for GH that doesn't have metadata discovery + if parsed.netloc == "api.githubcopilot.com": + logger.debug("Detected GitHub MCP server, using its metadata") + issuer = "https://github.com/login/oauth" + authorization_endpoint = "https://github.com/login/oauth/authorize" + token_endpoint = "https://github.com/login/oauth/access_token" + self._metadata = ServerOAuthMetadata( + issuer=issuer, authorization_endpoint=authorization_endpoint, token_endpoint=token_endpoint + ) + return + + base_url = f"{parsed.scheme}://{parsed.netloc}" + well_known_url = f"{base_url}/.well-known/oauth-authorization-server" + + try: + logger.debug(f"Trying OAuth metadata discovery at: {well_known_url}") + response = await client.get(well_known_url) + response.raise_for_status() + metadata = response.json() + self._metadata = ServerOAuthMetadata(**metadata) + logger.debug("Successfully discovered OAuth metadata") + logger.debug(f" Authorization endpoint: {self._metadata.authorization_endpoint}") + logger.debug(f" Token endpoint: {self._metadata.token_endpoint}") + return + except (httpx.HTTPError, ValueError) as e: + logger.debug(f"Failed to discover OAuth metadata at {well_known_url}: {e}") + pass + + # Try OpenID Connect discovery + oidc_url = f"{base_url}/.well-known/openid-configuration" + logger.debug(f"Trying OpenID Connect discovery at: {oidc_url}") + try: + response = await client.get(oidc_url) + response.raise_for_status() + metadata = response.json() + self._metadata = ServerOAuthMetadata(**metadata) + logger.debug("Successfully discovered OIDC metadata") + logger.debug(f" Authorization endpoint: {self._metadata.authorization_endpoint}") + logger.debug(f" Token endpoint: {self._metadata.token_endpoint}") + return + except (httpx.HTTPError, ValueError) as e: + logger.debug(f"Failed to discover OIDC metadata at {oidc_url}: {e}") + pass + + # If discovery fails, we'll need the metadata from somewhere else + logger.error(f"Failed to discover OAuth/OIDC metadata for {self.server_url}") + raise OAuthDiscoveryError( + f"Failed to discover OAuth metadata for {self.server_url}. " + "Server must support OAuth metadata discovery at " + "/.well-known/oauth-authorization-server or /.well-known/openid-configuration" + ) + + def _is_token_valid(self, token_data: TokenData) -> bool: + """Check if token is still valid.""" + logger.debug("Checking token validity") + if not token_data.expires_at: + logger.debug("Token has no expiration time, assuming it's valid.") + return True # No expiration info, assume valid + + # Check if token expires in more than 60 seconds + expires_at = datetime.fromtimestamp(token_data.expires_at, tz=UTC) + now = datetime.now(tz=UTC) + is_valid = expires_at > now + timedelta(seconds=60) + logger.debug(f"Token expires at {expires_at}, current time is {now}. Valid: {is_valid}") + return is_valid + + async def _try_dynamic_registration(self) -> ClientRegistrationResponse | None: + """Try Dynamic Client Registration if supported by the server.""" + if not self._metadata or not self._metadata.registration_endpoint: + logger.debug("No registration endpoint available, skipping DCR") + return None + + logger.info("Attempting Dynamic Client Registration") + logger.debug(f"DCR endpoint: {self._metadata.registration_endpoint}") + + registration_data = { + "client_name": "mcp-use", + "redirect_uris": [self.redirect_uri], + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", # Public client + "application_type": "native", + } + + # Add scope if specified + if self.scope: + registration_data["scope"] = self.scope + + logger.debug(f"DCR request payload: {registration_data}") + try: + async with httpx.AsyncClient() as client: + response = await client.post( + str(self._metadata.registration_endpoint), + json=registration_data, + headers={"Content-Type": "application/json"}, + ) + logger.debug(f"DCR response status: {response.status_code}") + response.raise_for_status() + + # Parse registration response + reg_response_data = response.json() + logger.debug(f"DCR response body: {reg_response_data}") + reg_response = ClientRegistrationResponse(**reg_response_data) + + # Update our credentials + self.client_id = reg_response.client_id + self.client_secret = reg_response.client_secret + + logger.info(f"Dynamic Client Registration successful: {self.client_id}") + + # Store the registered client info for future use + await self._store_client_registration(reg_response) + + return reg_response + + except httpx.HTTPError as e: + logger.warning(f"Dynamic Client Registration failed: {e}") + # Log the response if available + if hasattr(e, "response") and e.response: + logger.debug(f"DCR response: {e.response.status_code} - {e.response.text}") + return None + except Exception as e: + logger.warning(f"Unexpected error during DCR: {e}") + return None + + async def _store_client_registration(self, registration: ClientRegistrationResponse) -> None: + """Store client registration data for future use.""" + logger.debug("Storing client registration data") + # Store alongside tokens in a separate file + storage_path = self.token_storage.base_dir / "registrations" + storage_path.mkdir(parents=True, exist_ok=True) + + # Create a safe filename from the server URL + parsed = urlparse(self.server_url) + filename = f"{parsed.netloc}_{parsed.path.replace('/', '_')}_registration.json" + reg_path = storage_path / filename + logger.debug(f"Storing client registration to '{reg_path}'") + + # Store registration data + reg_path.write_text(registration.model_dump_json()) + logger.debug("Client registration data stored successfully") + + async def _load_client_registration(self) -> ClientRegistrationResponse | None: + """Load previously registered client credentials if available.""" + logger.debug("Attempting to load client registration data") + storage_path = self.token_storage.base_dir / "registrations" + + # Create a safe filename from the server URL + parsed = urlparse(self.server_url) + filename = f"{parsed.netloc}_{parsed.path.replace('/', '_')}_registration.json" + reg_path = storage_path / filename + logger.debug(f"Checking for client registration file at '{reg_path}'") + + if reg_path.exists(): + logger.debug("Client registration file found") + try: + data = json.loads(reg_path.read_text()) + reg_response = ClientRegistrationResponse(**data) + + # Check if registration is still valid (if expiry info provided) + if reg_response.client_secret_expires_at: + expires_at = datetime.fromtimestamp(reg_response.client_secret_expires_at, tz=UTC) + now = datetime.now(tz=UTC) + logger.debug(f"Checking client registration expiry. Expires at: {expires_at}, Now: {now}") + if expires_at <= now: + logger.debug("Stored client registration has expired") + return None + + self.client_id = reg_response.client_id + self.client_secret = reg_response.client_secret + logger.debug(f"Loaded stored client registration: {self.client_id}") + return reg_response + + except Exception as e: + logger.debug(f"Failed to load client registration: {e}") + else: + logger.debug("Client registration file not found") + + return None + + async def refresh_token(self) -> BearerAuth | None: + """Refresh the access token if possible.""" + logger.debug("Attempting to refresh token") + token_data = await self.token_storage.load_tokens(self.server_url) + if not token_data or not token_data.refresh_token: + logger.debug("No token data or refresh token found, cannot refresh.") + return None + + if not self._metadata: + logger.debug("No OAuth metadata available, cannot refresh token.") + return None + + if not self._client: + if not self.client_id: + logger.debug("Cannot refresh token without client_id") + return None + logger.debug("Creating temporary AsyncOAuth2Client for token refresh") + self._client = AsyncOAuth2Client(client_id=self.client_id, client_secret=self.client_secret) + + logger.debug("Calling client.refresh_token") + try: + token_response = await self._client.refresh_token( + str(self._metadata.token_endpoint), + refresh_token=token_data.refresh_token, + ) + logger.debug("Token refresh successful") + + # Save new tokens + logger.debug("Saving new tokens after refresh") + await self.token_storage.save_tokens(self.server_url, token_response) + + # Update bearer auth + logger.debug("Updating BearerAuth with new access token") + self._bearer_auth = BearerAuth(token=SecretStr(token_response["access_token"])) + return self._bearer_auth + + except OAuth2Error as e: + logger.warning(f"Token refresh failed: {e}. Re-authentication is required.") + # Refresh failed, need to re-authenticate + return None diff --git a/mcp_use/auth/oauth_callback.py b/mcp_use/auth/oauth_callback.py new file mode 100644 index 00000000..eb6c66bd --- /dev/null +++ b/mcp_use/auth/oauth_callback.py @@ -0,0 +1,214 @@ +"""OAuth callback server implementation.""" + +import asyncio +from dataclasses import dataclass + +import anyio +import uvicorn +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import HTMLResponse +from starlette.routing import Route + +from ..logging import logger + + +@dataclass +class CallbackResponse: + """Response data from OAuth callback.""" + + code: str | None = None # Authorization code (success) + state: str | None = None # CSRF protection token + error: str | None = None # Errors code (if failed) + error_description: str | None = None + error_uri: str | None = None + + +class OAuthCallbackServer: + """Local server to handle OAuth callback.""" + + def __init__(self, port: int): + """Initialize the callback server. + + Args: + port: Port to listen on. + """ + self.port = port + self.redirect_uri: str | None = None + # Thread safe way to pass callback data to the main OAuth flow + self.response_queue: asyncio.Queue[CallbackResponse] = asyncio.Queue(maxsize=1) + self.server: uvicorn.Server | None = None + self._shutdown_event = anyio.Event() + + async def start(self) -> str: + """Start the callback server and return the redirect URI.""" + app = self._create_app() + + # Create the server + config = uvicorn.Config( + app, + host="127.0.0.1", + port=self.port, + log_level="error", # Suppress uvicorn logs + ) + self.server = uvicorn.Server(config) + + # Start server in background + self._server_task = asyncio.create_task(self.server.serve()) + + # Wait a moment for server to start + await asyncio.sleep(0.1) + + self.redirect_uri = f"http://localhost:{self.port}/callback" + return self.redirect_uri + + async def wait_for_code(self, timeout: float = 300) -> CallbackResponse: + """Wait for the OAuth callback with a timeout (default 5 minutes).""" + try: + response = await asyncio.wait_for(self.response_queue.get(), timeout=timeout) + return response + except TimeoutError: + raise TimeoutError(f"OAuth callback not received within {timeout} seconds") from None + finally: + await self.shutdown() + + async def shutdown(self): + """Shutdown the callback server.""" + self._shutdown_event.set() + if self.server: + self.server.should_exit = True + if hasattr(self, "_server_task"): + try: + await asyncio.wait_for(self._server_task, timeout=5.0) + except TimeoutError: + self._server_task.cancel() + + def _create_app(self) -> Starlette: + """Create the Starlette application.""" + + async def callback(request: Request) -> HTMLResponse: + """Handle the OAuth callback.""" + params = request.query_params + + # Extract OAuth parameters + response = CallbackResponse( + code=params.get("code"), + state=params.get("state"), + error=params.get("error"), + error_description=params.get("error_description"), + error_uri=params.get("error_uri"), + ) + + # Log the callback response + logger.debug( + f"OAuth callback received: error={response.error}, error_description={response.error_description}" + ) + if response.code: + logger.debug("OAuth callback received authorization code") + else: + logger.error(f"OAuth callback error: {response.error} - {response.error_description}") + + # Put response in queue + try: + self.response_queue.put_nowait(response) + except asyncio.QueueFull: + pass # Ignore if queue is already full + + # Return success page + if response.code: + html = self._success_html() + else: + html = self._error_html(response.error, response.error_description) + + return HTMLResponse(content=html) + + routes = [Route("/callback", callback)] + return Starlette(routes=routes) + + def _success_html(self) -> str: + """HTML response for successful authorization.""" + return """ + + + + Codestin Search App + + + +
+
+

Authorization Successful!

+

You can now close this window and return to your application.

+
+ + + + """ + + def _error_html(self, error: str | None, description: str | None) -> str: + """HTML response for authorization error.""" + error_msg = error or "Unknown error" + desc_msg = description or "Authorization was not completed successfully." + + return f""" + + + + Codestin Search App + + + +
+
+

Authorization Error

+

{error_msg}

+

{desc_msg}

+
+ + + """ diff --git a/mcp_use/client.py b/mcp_use/client.py index 30571dd4..548a89f0 100644 --- a/mcp_use/client.py +++ b/mcp_use/client.py @@ -192,7 +192,7 @@ async def create_session(self, server_name: str, auto_initialize: bool = True) - server_config = servers[server_name] - # Create connector with options + # Create connector with options and client-level auth connector = create_connector_from_config( server_config, sandbox=self.sandbox, diff --git a/mcp_use/config.py b/mcp_use/config.py index 10e1e370..0ae41479 100644 --- a/mcp_use/config.py +++ b/mcp_use/config.py @@ -79,7 +79,7 @@ def create_connector_from_config( return HttpConnector( base_url=server_config["url"], headers=server_config.get("headers", None), - auth_token=server_config.get("auth_token", None), + auth=server_config.get("auth", {}), timeout=server_config.get("timeout", 5), sse_read_timeout=server_config.get("sse_read_timeout", 60 * 5), sampling_callback=sampling_callback, @@ -93,7 +93,7 @@ def create_connector_from_config( return WebSocketConnector( url=server_config["ws_url"], headers=server_config.get("headers", None), - auth_token=server_config.get("auth_token", None), + auth=server_config.get("auth", {}), ) raise ValueError("Cannot determine connector type from config") diff --git a/mcp_use/connectors/http.py b/mcp_use/connectors/http.py index 2db4e405..b7fbeaa7 100644 --- a/mcp_use/connectors/http.py +++ b/mcp_use/connectors/http.py @@ -5,11 +5,17 @@ through HTTP APIs with SSE or Streamable HTTP for transport. """ +from typing import Any + import httpx from mcp import ClientSession from mcp.client.session import ElicitationFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT from mcp.shared.exceptions import McpError +from mcp_use.auth.oauth import OAuthClientProvider + +from ..auth import BearerAuth, OAuth +from ..exceptions import OAuthAuthenticationError, OAuthDiscoveryError from ..logging import logger from ..task_managers import SseConnectionManager, StreamableHttpConnectionManager from .base import BaseConnector @@ -25,10 +31,10 @@ class HttpConnector(BaseConnector): def __init__( self, base_url: str, - auth_token: str | None = None, headers: dict[str, str] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, + auth: str | dict[str, Any] | httpx.Auth | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, message_handler: MessageHandlerFnT | None = None, @@ -38,10 +44,13 @@ def __init__( Args: base_url: The base URL of the MCP HTTP API. - auth_token: Optional authentication token. headers: Optional additional headers. timeout: Timeout for HTTP operations in seconds. sse_read_timeout: Timeout for SSE read operations in seconds. + auth: Authentication method - can be: + - A string token: Use Bearer token authentication + - A dict with OAuth config: {"client_id": "...", "client_secret": "...", "scope": "..."} + - An httpx.Auth object: Use custom authentication sampling_callback: Optional sampling callback. elicitation_callback: Optional elicitation callback. """ @@ -52,12 +61,57 @@ def __init__( logging_callback=logging_callback, ) self.base_url = base_url.rstrip("/") - self.auth_token = auth_token self.headers = headers or {} - if auth_token: - self.headers["Authorization"] = f"Bearer {auth_token}" self.timeout = timeout self.sse_read_timeout = sse_read_timeout + self._auth: httpx.Auth | None = None + self._oauth: OAuth | None = None + + # Handle authentication + if auth is not None: + self._set_auth(auth) + + def _set_auth(self, auth: str | dict[str, Any] | httpx.Auth) -> None: + """Set authentication method. + + Args: + auth: Authentication method - can be: + - A string token: Use Bearer token authentication + - A dict with OAuth config: {"client_id": "...", "client_secret": "...", "scope": "..."} + - An httpx.Auth object: Use custom authentication + """ + if isinstance(auth, str): + # Treat as bearer token + self._auth = BearerAuth(token=auth) + self.headers["Authorization"] = f"Bearer {auth}" + elif isinstance(auth, dict): + # Check if this is an OAuth provider configuration + if "oauth_provider" in auth: + oauth_provider = auth["oauth_provider"] + if isinstance(oauth_provider, dict): + oauth_provider = OAuthClientProvider(**oauth_provider) + self._oauth = OAuth( + self.base_url, + scope=auth.get("scope"), + client_id=auth.get("client_id"), + client_secret=auth.get("client_secret"), + callback_port=auth.get("callback_port"), + oauth_provider=oauth_provider, + ) + self._oauth_config = auth + else: + self._oauth = OAuth( + self.base_url, + scope=auth.get("scope"), + client_id=auth.get("client_id"), + client_secret=auth.get("client_secret"), + callback_port=auth.get("callback_port"), + ) + self._oauth_config = auth + elif isinstance(auth, httpx.Auth): + self._auth = auth + else: + raise ValueError(f"Invalid auth type: {type(auth)}") async def connect(self) -> None: """Establish a connection to the MCP implementation.""" @@ -65,6 +119,29 @@ async def connect(self) -> None: logger.debug("Already connected to MCP implementation") return + # Handle OAuth if needed + if self._oauth: + try: + # Create a temporary client for OAuth metadata discovery + async with httpx.AsyncClient() as client: + bearer_auth = await self._oauth.initialize(client) + if not bearer_auth: + # Need to perform OAuth flow + logger.info("OAuth authentication required") + bearer_auth = await self._oauth.authenticate() + + # Update auth and headers + self._auth = bearer_auth + self.headers["Authorization"] = f"Bearer {bearer_auth.token.get_secret_value()}" + except OAuthDiscoveryError: + # OAuth discovery failed - it means server doesn't support OAuth default urls + logger.debug("OAuth discovery failed, continuing without initialization.") + self._oauth = None + self._auth = None + except OAuthAuthenticationError as e: + logger.error(f"OAuth initialization failed: {e}") + raise + # Try streamable HTTP first (new transport), fall back to SSE (old transport) # This implements backwards compatibility per MCP specification self.transport_type = None @@ -74,7 +151,7 @@ async def connect(self) -> None: # First, try the new streamable HTTP transport logger.debug(f"Attempting streamable HTTP connection to: {self.base_url}") connection_manager = StreamableHttpConnectionManager( - self.base_url, self.headers, self.timeout, self.sse_read_timeout + self.base_url, self.headers, self.timeout, self.sse_read_timeout, auth=self._auth ) # Test if this is a streamable HTTP server by attempting initialization @@ -95,9 +172,9 @@ async def connect(self) -> None: try: # Try to initialize - this is where streamable HTTP vs SSE difference should show up result = await test_client.initialize() + logger.debug(f"Streamable HTTP initialization result: {result}") # If we get here, streamable HTTP works - self.client_session = test_client self.transport_type = "streamable HTTP" self._initialized = True # Mark as initialized since we just called initialize() @@ -127,18 +204,12 @@ async def connect(self) -> None: self._prompts = [] except McpError as mcp_error: - # This is a protocol error, not a transport error - # The server is reachable and speaking MCP, but rejecting our request - logger.error("MCP protocol error during initialization: %s", mcp_error) - + logger.error("MCP protocol error during initialization: %s", mcp_error.error) # Clean up the test client try: await test_client.__aexit__(None, None, None) except Exception: pass - - # Don't try SSE fallback for protocol errors - the server is working, - # it just doesn't like our request raise mcp_error except Exception as init_error: @@ -147,7 +218,16 @@ async def connect(self) -> None: await test_client.__aexit__(None, None, None) except Exception: pass - raise init_error + + if isinstance(init_error, httpx.HTTPStatusError): + if init_error.response.status_code in [401, 403, 407]: # Authentication error using status + # Server requires authentication but OAuth discovery failed + raise OAuthAuthenticationError( + f"Server requires authentication (HTTP {init_error.response.status_code}) " + "but OAuth discovery failed. Please provide OAuth configuration manually." + ) from init_error + else: + raise init_error except Exception as streamable_error: logger.debug(f"Streamable HTTP failed: {streamable_error}") @@ -160,15 +240,16 @@ async def connect(self) -> None: pass # Check if this is a 4xx error that indicates we should try SSE fallback + # HACK: Still sometimes StreamableHTTP will return other errors, so we still try to fallback to SSE should_fallback = False if isinstance(streamable_error, httpx.HTTPStatusError): if streamable_error.response.status_code in [404, 405]: should_fallback = True + logger.debug("Streamable HTTP failed: 404/ 405 Not Found/ Method Not Allowed") elif "405 Method Not Allowed" in str(streamable_error) or "404 Not Found" in str(streamable_error): should_fallback = True else: - # For other errors, still try fallback but they might indicate - # real connectivity issues + logger.debug("Streamable HTTP failed, falling back to SSE") should_fallback = True if should_fallback: @@ -176,7 +257,7 @@ async def connect(self) -> None: # Fall back to the old SSE transport logger.debug(f"Attempting SSE fallback connection to: {self.base_url}") connection_manager = SseConnectionManager( - self.base_url, self.headers, self.timeout, self.sse_read_timeout + self.base_url, self.headers, self.timeout, self.sse_read_timeout, auth=self._auth ) read_stream, write_stream = await connection_manager.start() @@ -195,10 +276,17 @@ async def connect(self) -> None: self.transport_type = "SSE" except Exception as sse_error: - logger.error( - f"Both transport methods failed. Streamable HTTP: {streamable_error}, SSE: {sse_error}" - ) - raise sse_error + if isinstance(sse_error, httpx.HTTPStatusError): + if sse_error.response.status_code in [401, 403, 407]: + raise OAuthAuthenticationError( + f"Server requires authentication (HTTP {sse_error.response.status_code}) " + "but OAuth discovery failed. Please provide OAuth configuration manually." + ) from sse_error + else: + logger.error( + f"Both transport methods failed. Streamable HTTP: {streamable_error}, SSE: {sse_error}" + ) + raise sse_error else: raise streamable_error diff --git a/mcp_use/connectors/websocket.py b/mcp_use/connectors/websocket.py index 3ca3ec0d..7ead33fb 100644 --- a/mcp_use/connectors/websocket.py +++ b/mcp_use/connectors/websocket.py @@ -10,6 +10,7 @@ import uuid from typing import Any +import httpx from mcp.types import Tool from websockets import ClientConnection @@ -28,21 +29,29 @@ class WebSocketConnector(BaseConnector): def __init__( self, url: str, - auth_token: str | None = None, headers: dict[str, str] | None = None, + auth: str | dict[str, Any] | httpx.Auth | None = None, ): """Initialize a new WebSocket connector. Args: url: The WebSocket URL to connect to. - auth_token: Optional authentication token. headers: Optional additional headers. + auth: Authentication method - can be: + - A string token: Use Bearer token authentication + - A dict: Not supported for WebSocket (will log warning) + - An httpx.Auth object: Not supported for WebSocket (will log warning) """ self.url = url - self.auth_token = auth_token self.headers = headers or {} - if auth_token: - self.headers["Authorization"] = f"Bearer {auth_token}" + + # Handle authentication - WebSocket only supports bearer tokens + # An auth field it's not needed + if auth is not None: + if isinstance(auth, str): + self.headers["Authorization"] = f"Bearer {auth}" + else: + logger.warning("WebSocket connector only supports bearer token authentication") self.ws: ClientConnection | None = None self._connection_manager: ConnectionManager | None = None diff --git a/mcp_use/exceptions.py b/mcp_use/exceptions.py new file mode 100644 index 00000000..6c943a78 --- /dev/null +++ b/mcp_use/exceptions.py @@ -0,0 +1,31 @@ +"""MCP-use exceptions.""" + + +class MCPError(Exception): + """Base exception for MCP-use.""" + + pass + + +class OAuthDiscoveryError(MCPError): + """OAuth discovery auth metadata error""" + + pass + + +class OAuthAuthenticationError(MCPError): + """OAuth authentication-related errors""" + + pass + + +class ConnectionError(MCPError): + """Connection-related errors.""" + + pass + + +class ConfigurationError(MCPError): + """Configuration-related errors.""" + + pass diff --git a/mcp_use/task_managers/sse.py b/mcp_use/task_managers/sse.py index 95a90c41..f5442ff4 100644 --- a/mcp_use/task_managers/sse.py +++ b/mcp_use/task_managers/sse.py @@ -7,6 +7,7 @@ from typing import Any +import httpx from mcp.client.sse import sse_client from ..logging import logger @@ -27,6 +28,7 @@ def __init__( headers: dict[str, str] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, + auth: httpx.Auth | None = None, ): """Initialize a new SSE connection manager. @@ -35,12 +37,14 @@ def __init__( headers: Optional HTTP headers timeout: Timeout for HTTP operations in seconds sse_read_timeout: Timeout for SSE read operations in seconds + auth: Optional httpx.Auth instance for authentication """ super().__init__() self.url = url self.headers = headers or {} self.timeout = timeout self.sse_read_timeout = sse_read_timeout + self.auth = auth self._sse_ctx = None async def _establish_connection(self) -> tuple[Any, Any]: @@ -58,6 +62,7 @@ async def _establish_connection(self) -> tuple[Any, Any]: headers=self.headers, timeout=self.timeout, sse_read_timeout=self.sse_read_timeout, + auth=self.auth, ) # Enter the context manager diff --git a/mcp_use/task_managers/streamable_http.py b/mcp_use/task_managers/streamable_http.py index d41040c1..493c3139 100644 --- a/mcp_use/task_managers/streamable_http.py +++ b/mcp_use/task_managers/streamable_http.py @@ -8,6 +8,7 @@ from datetime import timedelta from typing import Any +import httpx from mcp.client.streamable_http import streamablehttp_client from ..logging import logger @@ -28,6 +29,7 @@ def __init__( headers: dict[str, str] | None = None, timeout: float = 5, read_timeout: float = 60 * 5, + auth: httpx.Auth | None = None, ): """Initialize a new streamable HTTP connection manager. @@ -36,12 +38,14 @@ def __init__( headers: Optional HTTP headers timeout: Timeout for HTTP operations in seconds read_timeout: Timeout for HTTP read operations in seconds + auth: Optional httpx.Auth instance for authentication """ super().__init__() self.url = url self.headers = headers or {} self.timeout = timedelta(seconds=timeout) self.read_timeout = timedelta(seconds=read_timeout) + self.auth = auth self._http_ctx = None async def _establish_connection(self) -> tuple[Any, Any]: @@ -59,6 +63,7 @@ async def _establish_connection(self) -> tuple[Any, Any]: headers=self.headers, timeout=self.timeout, sse_read_timeout=self.read_timeout, + auth=self.auth, ) # Enter the context manager. Ignoring the session id callback diff --git a/pyproject.toml b/pyproject.toml index 2193808b..b8eb26ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "python-dotenv>=1.0.0", "posthog>=4.8.0", "scarf-sdk>=0.1.0", + "authlib>=1.6.3", ] [project.optional-dependencies] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d18709b3..0da277fe 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -36,3 +36,33 @@ async def primitive_server(): process.kill() process.wait() logger.info("Primitive server cleanup complete.") + + +@pytest.fixture(scope="session") +async def auth_server(): + """Starts the auth_server.py as a subprocess for integration tests.""" + server_path = Path(__file__).parent / "servers_for_testing" / "auth_server.py" + logger.info(f"Starting auth server: python {server_path}") + + process = subprocess.Popen( + ["python", str(server_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Allow server to initialize + await asyncio.sleep(2) + + yield "http://127.0.0.1:8081" + + logger.info("Cleaning up auth server process") + if process.poll() is None: + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + logger.warning("Server process did not terminate gracefully, killing.") + process.kill() + process.wait() + logger.info("Auth server cleanup complete.") diff --git a/tests/integration/primitives/test_auth.py b/tests/integration/primitives/test_auth.py new file mode 100644 index 00000000..7fcd70f5 --- /dev/null +++ b/tests/integration/primitives/test_auth.py @@ -0,0 +1,200 @@ +import token +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from httpx import Auth, Request + +from mcp_use import MCPClient, set_debug +from mcp_use.auth.bearer import BearerAuth + +set_debug(2) + + +@pytest.mark.asyncio +async def test_bearer_auth(auth_server): + """Test Bearer token authentication.""" + config = {"mcpServers": {"AuthServer": {"url": f"{auth_server}/mcp", "auth": "valid_token"}}} + client = MCPClient(config) + try: + await client.create_all_sessions() + session = client.get_session("AuthServer") + + assert session.connector._auth is not None + + # Test that we can call protected tools + result = await session.call_tool(name="protected_tool", arguments={}) + assert result.content[0].text == "Authenticated access granted!" + + # Test that we can call regular tools + result = await session.call_tool(name="add", arguments={"a": 5, "b": 3}) + assert result.content[0].text == "8" + finally: + await client.close_all_sessions() + + +@pytest.mark.asyncio +@patch("mcp_use.auth.oauth.secrets.token_urlsafe") +@patch("mcp_use.auth.oauth.webbrowser.open") +@patch("mcp_use.auth.oauth.OAuthCallbackServer") +async def test_oauth_provider(mock_callback_server_class, mock_webbrowser_open, mock_token_urlsafe, auth_server): + """Test OAuth with pre-configured provider metadata.""" + await clean_token() + + # Mock the callback server + mock_callback_server = AsyncMock() + mock_callback_server.start.return_value = "http://127.0.0.1:8080/callback" + mock_callback_server.wait_for_code.return_value = MagicMock( + code="test_auth_code_12345", state="test_state_67890", error=None, error_description=None + ) + mock_callback_server_class.return_value = mock_callback_server + mock_webbrowser_open.return_value = None + + # Avoid CSRF error during testing + mock_token_urlsafe.return_value = "test_state_67890" + + config = { + "mcpServers": { + "AuthServer": { + "url": f"{auth_server}/mcp", + "auth": { + "oauth_provider": { + "id": "auth_server", + "display_name": "AuthServer", + "metadata": { + "issuer": "http://127.0.0.1:8081", + "authorization_endpoint": "http://127.0.0.1:8081/oauth/authorize", + "token_endpoint": "http://127.0.0.1:8081/oauth/token", + "registration_endpoint": "http://127.0.0.1:8081/oauth/register", + }, + } + }, + } + } + } + client = MCPClient(config) + try: + await client.create_all_sessions() + session = client.get_session("AuthServer") + + # Verify OAuth metadata was discovered + assert session.connector._oauth is not None, "OAuth should be initialized" + assert session.connector._oauth._metadata is not None, "OAuth metadata should be discovered" + assert str(session.connector._oauth._metadata.issuer) == "http://127.0.0.1:8081/" + assert str(session.connector._oauth._metadata.authorization_endpoint) == "http://127.0.0.1:8081/oauth/authorize" + assert str(session.connector._oauth._metadata.registration_endpoint) == "http://127.0.0.1:8081/oauth/register" + assert str(session.connector._oauth._metadata.token_endpoint) == "http://127.0.0.1:8081/oauth/token" + + mock_webbrowser_open.assert_called_once() + mock_callback_server_class.assert_called_once() + mock_callback_server.start.assert_called_once() + mock_callback_server.wait_for_code.assert_called_once() + + assert session.connector._auth is not None + + # Test that we can call protected tools + result = await session.call_tool(name="protected_tool", arguments={}) + assert result.content[0].text == "Authenticated access granted!" + + # Test that we can call regular tools + result = await session.call_tool(name="add", arguments={"a": 5, "b": 3}) + assert result.content[0].text == "8" + finally: + await client.close_all_sessions() + + +@pytest.mark.asyncio +async def test_custom_client(auth_server): + "Test that custom httpx.Auth objects works in auth field." + + custom_auth = BearerAuth(token="valid_token") + config = {"mcpServers": {"AuthServer": {"url": f"{auth_server}/mcp", "auth": custom_auth}}} + + client = MCPClient(config) + try: + await client.create_all_sessions() + session = client.get_session("AuthServer") + + # Verify the custom BearerAuth is being used + assert session.connector._auth == custom_auth + + # Test that we can call protected tools + result = await session.call_tool(name="protected_tool", arguments={}) + assert result.content[0].text == "Authenticated access granted!" + + # Test that we can call regular tools + result = await session.call_tool(name="add", arguments={"a": 5, "b": 3}) + assert result.content[0].text == "8" + + finally: + await client.close_all_sessions() + + +@pytest.mark.asyncio +@patch("mcp_use.auth.oauth.secrets.token_urlsafe") +@patch("mcp_use.auth.oauth.webbrowser.open") +@patch("mcp_use.auth.oauth.OAuthCallbackServer") +async def test_oauth_complete_flow(mock_callback_server_class, mock_webbrowser_open, mock_token_urlsafe, auth_server): + """Test OAuth complete flow, with metadata discovery, DCR and auth token.""" + await clean_token() + + # Mock the callback server + mock_callback_server = AsyncMock() + mock_callback_server.start.return_value = "http://127.0.0.1:8080/callback" + mock_callback_server.wait_for_code.return_value = MagicMock( + code="test_auth_code_12345", state="test_state_67890", error=None, error_description=None + ) + mock_callback_server_class.return_value = mock_callback_server + mock_webbrowser_open.return_value = None + + # Avoid CSRF error during testing + mock_token_urlsafe.return_value = "test_state_67890" + + config = { + "mcpServers": { + "AuthServer": { + "url": f"{auth_server}/mcp", + } + } + } + client = MCPClient(config) + try: + await client.create_all_sessions() + session = client.get_session("AuthServer") + + # Verify OAuth metadata was discovered + assert session.connector._oauth is not None, "OAuth should be initialized" + assert session.connector._oauth._metadata is not None, "OAuth metadata should be discovered" + assert str(session.connector._oauth._metadata.issuer) == "http://127.0.0.1:8081/" + assert str(session.connector._oauth._metadata.authorization_endpoint) == "http://127.0.0.1:8081/oauth/authorize" + assert str(session.connector._oauth._metadata.registration_endpoint) == "http://127.0.0.1:8081/oauth/register" + assert str(session.connector._oauth._metadata.token_endpoint) == "http://127.0.0.1:8081/oauth/token" + + mock_webbrowser_open.assert_called_once() + mock_callback_server_class.assert_called_once() + mock_callback_server.start.assert_called_once() + mock_callback_server.wait_for_code.assert_called_once() + + assert session.connector._auth is not None + + # Test that we can call protected tools + result = await session.call_tool(name="protected_tool", arguments={}) + assert result.content[0].text == "Authenticated access granted!" + + # Test that we can call regular tools + result = await session.call_tool(name="add", arguments={"a": 5, "b": 3}) + assert result.content[0].text == "8" + finally: + await client.close_all_sessions() + + +async def clean_token(): + # Clear any existing tokens for this server before the test + token_dir = Path.home() / ".mcp_use" / "tokens" + if token_dir.exists(): + for token_file in token_dir.glob("*127.0.0.1:8081__mcp.json"): + # Clear the file + token_file.unlink() + registrations_dir = token_dir / "registrations" + for registration_file in registrations_dir.glob("*127.0.0.1:8081__mcp_registration.json"): + registration_file.unlink() diff --git a/tests/integration/servers_for_testing/auth_server.py b/tests/integration/servers_for_testing/auth_server.py new file mode 100644 index 00000000..48d3a798 --- /dev/null +++ b/tests/integration/servers_for_testing/auth_server.py @@ -0,0 +1,140 @@ +import argparse + +from fastmcp import Context, FastMCP +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse + +mcp = FastMCP(name="AuthServer") + +# We set all the fields to ensure that models +# correctly bind data + +# OAuth metadata configuration +OAUTH_METADATA_RESPONSE = { + "issuer": "http://127.0.0.1:8081", + "authorization_endpoint": "http://127.0.0.1:8081/oauth/authorize", + "token_endpoint": "http://127.0.0.1:8081/oauth/token", + "userinfo_endpoint": "http://127.0.0.1:8081/oauth/userinfo", + "revocation_endpoint": "http://127.0.0.1:8081/oauth/revocation", + "introspection_endpoint": "http://127.0.0.1:8081/oauth/introspection", + "registration_endpoint": "http://127.0.0.1:8081/oauth/register", + "jwks_uri": "http://127.0.0.1:8081/oauth/jwks", + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + "scopes_supported": ["openid", "profile", "email"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "none"], + "claims_supported": ["bla", "blabla"], + "code_challenge_methods_supported": ["S256", "plain"], +} + +DCR_RESPONSE = { + "client_id": "renvins", + "client_secret": "what a secret", + "client_id_issued_at": 0, + "client_secret_expires_at": 0, + "redirect_uris": ["what a uri"], + "grant_types": ["what a grant"], + "response_types": ["good_response"], + "client_name": "renvins_better", + "token_endpoint_auth_method": "code", +} + +AUTH_CODE_RESPONSE = { + "code": "test_auth_code_12345", +} + +TOKEN_RESPONSE = { + "access_token": "valid_token", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "openid profile email", +} + +SAMPLE_VALID_TOKEN = "valid_token" + + +def verify_auth(ctx: Context) -> str: + """Verify auth token from context.""" + request = ctx.get_http_request() + auth_header = request.headers.get("Authorization") if request else None + + if not auth_header or not auth_header.startswith("Bearer "): + raise Exception("Missing or invalid Authorization header") + + token = auth_header.split(" ")[1] + if token != SAMPLE_VALID_TOKEN: + raise Exception("Invalid token") + + return token + + +@mcp.custom_route("/.well-known/oauth-authorization-server", methods=["GET"]) +async def oauth_metadata(request: Request) -> JSONResponse: + """Serve OAuth 2.0 Authorization Server Metadata.""" + return JSONResponse(OAUTH_METADATA_RESPONSE) + + +@mcp.custom_route("/.well-known/openid-configuration", methods=["GET"]) +async def oidc_metadata(request: Request) -> JSONResponse: + """Serve OpenID Connect Discovery metadata.""" + return JSONResponse(OAUTH_METADATA_RESPONSE) + + +@mcp.custom_route("/oauth/register", methods=["POST"]) +async def dynamic_regisration(request: Request) -> JSONResponse: + """Serve client DCR data""" + return JSONResponse(DCR_RESPONSE) + + +@mcp.custom_route("/oauth/authorize", methods=["GET"]) +async def oauth_authorize(request: Request) -> RedirectResponse: + """OAuth authorization endpoint - returns pre-created authorization code""" + # Get the redirect_uri and state from query params + params = dict(request.query_params) + redirect_uri = params.get("redirect_uri", "http://127.0.0.1:8080/callback") + state = params.get("state") + + # Redirect with pre-created authorization code + redirect_url = f"{redirect_uri}?code={AUTH_CODE_RESPONSE['code']}&state={state}" + return RedirectResponse(url=redirect_url, status_code=302) + + +@mcp.custom_route("/oauth/token", methods=["POST"]) +async def oauth_token(request: Request) -> JSONResponse: + """OAuth token endpoint - returns pre-created token""" + return JSONResponse(TOKEN_RESPONSE) + + +# Protected tool that requires auth +@mcp.tool() +async def protected_tool(ctx: Context) -> str: + """A tool that requires authentication.""" + verify_auth(ctx) + return "Authenticated access granted!" + + +# Simple math tool for testing +@mcp.tool() +def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run MCP auth test server.") + parser.add_argument( + "--transport", + type=str, + choices=["streamable-http", "sse"], + default="streamable-http", + help="MCP transport type to use (default: streamable-http)", + ) + args = parser.parse_args() + + print(f"Starting MCP auth server with transport: {args.transport}") + + if args.transport == "streamable-http": + mcp.run(transport="streamable-http", host="127.0.0.1", port=8081) + elif args.transport == "sse": + mcp.run(transport="sse", host="127.0.0.1", port=8081) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index e80c67c1..674034c6 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -8,6 +8,7 @@ import unittest from unittest.mock import patch +from mcp_use.auth import BearerAuth from mcp_use.config import create_connector_from_config, load_config_file from mcp_use.connectors import HttpConnector, SandboxConnector, StdioConnector, WebSocketConnector from mcp_use.types.sandbox import SandboxOptions @@ -47,7 +48,7 @@ def test_create_http_connector(self): server_config = { "url": "http://test.com", "headers": {"Content-Type": "application/json"}, - "auth_token": "test_token", + "auth": "test_token", } connector = create_connector_from_config(server_config) @@ -58,14 +59,15 @@ def test_create_http_connector(self): connector.headers, {"Content-Type": "application/json", "Authorization": "Bearer test_token"}, ) - self.assertEqual(connector.auth_token, "test_token") + self.assertIsInstance(connector._auth, BearerAuth) + self.assertEqual(connector._auth.token.get_secret_value(), "test_token") def test_create_http_connector_with_options(self): """Test creating an HTTP connector with options.""" server_config = { "url": "http://test.com", "headers": {"Content-Type": "application/json"}, - "auth_token": "test_token", + "auth": "test_token", } options: SandboxOptions = { "api_key": "test_key", @@ -80,7 +82,8 @@ def test_create_http_connector_with_options(self): connector.headers, {"Content-Type": "application/json", "Authorization": "Bearer test_token"}, ) - self.assertEqual(connector.auth_token, "test_token") + self.assertIsInstance(connector._auth, BearerAuth) + self.assertEqual(connector._auth.token.get_secret_value(), "test_token") def test_create_http_connector_minimal(self): """Test creating an HTTP connector with minimal config.""" @@ -91,14 +94,14 @@ def test_create_http_connector_minimal(self): self.assertIsInstance(connector, HttpConnector) self.assertEqual(connector.base_url, "http://test.com") self.assertEqual(connector.headers, {}) - self.assertIsNone(connector.auth_token) + self.assertIsNone(connector._auth) def test_create_websocket_connector(self): """Test creating a WebSocket connector from config.""" server_config = { "ws_url": "ws://test.com", "headers": {"Content-Type": "application/json"}, - "auth_token": "test_token", + "auth": "test_token", } connector = create_connector_from_config(server_config) @@ -109,14 +112,13 @@ def test_create_websocket_connector(self): connector.headers, {"Content-Type": "application/json", "Authorization": "Bearer test_token"}, ) - self.assertEqual(connector.auth_token, "test_token") def test_create_websocket_connector_with_options(self): """Test creating a WebSocket connector with options.""" server_config = { "ws_url": "ws://test.com", "headers": {"Content-Type": "application/json"}, - "auth_token": "test_token", + "auth": "test_token", } options: SandboxOptions = { "api_key": "test_key", @@ -131,7 +133,6 @@ def test_create_websocket_connector_with_options(self): connector.headers, {"Content-Type": "application/json", "Authorization": "Bearer test_token"}, ) - self.assertEqual(connector.auth_token, "test_token") def test_create_websocket_connector_minimal(self): """Test creating a WebSocket connector with minimal config.""" @@ -142,7 +143,6 @@ def test_create_websocket_connector_minimal(self): self.assertIsInstance(connector, WebSocketConnector) self.assertEqual(connector.url, "ws://test.com") self.assertEqual(connector.headers, {}) - self.assertIsNone(connector.auth_token) def test_create_stdio_connector(self): """Test creating a stdio connector from config.""" diff --git a/tests/unit/test_http_connector.py b/tests/unit/test_http_connector.py index 3f19f278..df4ed868 100644 --- a/tests/unit/test_http_connector.py +++ b/tests/unit/test_http_connector.py @@ -10,6 +10,7 @@ from mcp import McpError from mcp.types import EmptyResult, ErrorData, Prompt, Resource, Tool +from mcp_use.auth.bearer import BearerAuth from mcp_use.connectors.http import HttpConnector from mcp_use.task_managers import SseConnectionManager @@ -23,7 +24,7 @@ def test_init_minimal(self, _): connector = HttpConnector(base_url="http://localhost:8000") self.assertEqual(connector.base_url, "http://localhost:8000") - self.assertIsNone(connector.auth_token) + self.assertIsNone(connector._auth) self.assertEqual(connector.headers, {}) self.assertIsNone(connector.client_session) self.assertIsNone(connector._connection_manager) @@ -32,10 +33,11 @@ def test_init_minimal(self, _): def test_init_with_auth_token(self, _): """Test initialization with auth token.""" - connector = HttpConnector(base_url="http://localhost:8000", auth_token="test_token") + connector = HttpConnector(base_url="http://localhost:8000", auth="test_token") self.assertEqual(connector.base_url, "http://localhost:8000") - self.assertEqual(connector.auth_token, "test_token") + self.assertIsInstance(connector._auth, BearerAuth) + self.assertEqual(connector._auth.token.get_secret_value(), "test_token") self.assertEqual(connector.headers, {"Authorization": "Bearer test_token"}) self.assertIsNone(connector.client_session) self.assertIsNone(connector._connection_manager) @@ -48,7 +50,7 @@ def test_init_with_headers(self, _): connector = HttpConnector(base_url="http://localhost:8000", headers=headers) self.assertEqual(connector.base_url, "http://localhost:8000") - self.assertIsNone(connector.auth_token) + self.assertIsNone(connector._auth) self.assertEqual(connector.headers, headers) self.assertIsNone(connector.client_session) self.assertIsNone(connector._connection_manager) @@ -58,13 +60,13 @@ def test_init_with_headers(self, _): def test_init_with_auth_token_and_headers(self, _): """Test initialization with both auth token and headers.""" headers = {"Content-Type": "application/json", "Accept": "application/json"} - connector = HttpConnector(base_url="http://localhost:8000", auth_token="test_token", headers=headers) + connector = HttpConnector(base_url="http://localhost:8000", auth="test_token", headers=headers) expected_headers = headers.copy() expected_headers["Authorization"] = "Bearer test_token" self.assertEqual(connector.base_url, "http://localhost:8000") - self.assertEqual(connector.auth_token, "test_token") + self.assertEqual(connector._auth.token.get_secret_value(), "test_token") self.assertEqual(connector.headers, expected_headers) self.assertIsNone(connector.client_session) self.assertIsNone(connector._connection_manager) @@ -191,7 +193,7 @@ async def test_connect_with_streamable_http(self, mock_client_session_class, moc await self.connector.connect() # Verify streamable HTTP connection manager was used - mock_cm_class.assert_called_once_with("http://localhost:8000", {}, 5, 300) + mock_cm_class.assert_called_once_with("http://localhost:8000", {}, 5, 300, auth=None) mock_cm_instance.start.assert_called_once() # Verify client session was created and initialized