diff --git a/dash/_configs.py b/dash/_configs.py index 25a401523b..ac3c5c14c2 100644 --- a/dash/_configs.py +++ b/dash/_configs.py @@ -36,6 +36,7 @@ def load_dash_env_vars(): "DASH_MCP_ENABLED", "DASH_MCP_PATH", "DASH_MCP_EXPOSE_DOCSTRINGS", + "DASH_MCP_AUTHORIZATION_SERVER", "HOST", "PORT", ) diff --git a/dash/dash.py b/dash/dash.py index 37eb7a1ffb..7b11cec230 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -489,6 +489,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches enable_mcp: Optional[bool] = None, mcp_path: Optional[str] = None, mcp_expose_docstrings: Optional[bool] = None, + mcp_authorization_server: Optional[str] = None, **obsolete, ): @@ -609,6 +610,9 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._mcp_path = ( _mcp_path.lstrip("/") if isinstance(_mcp_path, str) else _mcp_path ) + self._mcp_authorization_server = get_combined_config( + "mcp_authorization_server", mcp_authorization_server + ) # list of dependencies - this one is used by the back end for dispatching self.callback_map: dict = {} @@ -829,7 +833,11 @@ def _setup_routes(self): ) try: - enable_mcp_server(self, self._mcp_path) + enable_mcp_server( + self, + self._mcp_path, + mcp_authorization_server=self._mcp_authorization_server, + ) except Exception as e: # pylint: disable=broad-exception-caught self._enable_mcp = False self.logger.warning( diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 07b0520bb9..d218b0b7f9 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -11,7 +11,9 @@ import json import logging import os +from functools import reduce from typing import TYPE_CHECKING, Any +from urllib.parse import urljoin from mcp.types import ( LATEST_PROTOCOL_VERSION, @@ -47,7 +49,61 @@ logger = logging.getLogger(__name__) -def enable_mcp_server(app: Dash, mcp_path: str) -> None: +def _url_from_path(app: Dash, *parts: str) -> str: + """Build an absolute URL by joining path parts onto the current request origin. + + Behind a reverse proxy, TLS terminates at the proxy so + the scheme may report HTTP even when the client connected + over HTTPS. Use HTTPS unless running on localhost. + """ + from urllib.parse import urlparse # pylint: disable=import-outside-toplevel + + adapter = app.backend.request_adapter() + parsed = urlparse(adapter.url) + host = parsed.netloc + is_localhost = host.startswith("localhost") or host.startswith("127.0.0.1") + scheme = "http" if is_localhost else "https" + path = reduce(urljoin, parts, "/") + return f"{scheme}://{host}{path}" + + +def _setup_mcp_oauth(app: Dash, mcp_path: str, mcp_authorization_server: str) -> None: + """Register RFC 9728 Protected Resource Metadata endpoint for MCP. + + Serves discovery metadata so MCP clients can find the authorization + server. Auth enforcement is the responsibility of the hosting platform + (e.g. Plotly Cloud gateway, Dash Embedded, or a reverse proxy). + """ + well_known_path = urljoin("/.well-known/oauth-protected-resource/", mcp_path) + + def _serve_resource_metadata(): + return app.backend.make_response( + json.dumps( + { + "resource": _url_from_path( + app, app.config.requests_pathname_prefix, mcp_path + ), + "authorization_servers": [mcp_authorization_server], + "bearer_methods_supported": ["header"], + } + ), + content_type="application/json", + ) + + # pylint: disable-next=protected-access + app._add_url(well_known_path.lstrip("/"), _serve_resource_metadata) + + logger.info( + "MCP OAuth discovery enabled, authorization server: %s", + mcp_authorization_server, + ) + + +def enable_mcp_server( + app: Dash, + mcp_path: str, + mcp_authorization_server: str | None = None, +) -> None: """Add MCP routes to a Dash app.""" app.mcp_decorated_functions = dict(MCP_DECORATED_FUNCTIONS) @@ -207,6 +263,9 @@ def _handle_not_allowed(): ) app.routes.append(mcp_url) + if mcp_authorization_server: + _setup_mcp_oauth(app, mcp_path, mcp_authorization_server) + logger.info( "MCP routes registered at %s%s", app.config.routes_pathname_prefix, diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py index ab0fd82069..044ee74d25 100644 --- a/dash/mcp/tasks/tasks.py +++ b/dash/mcp/tasks/tasks.py @@ -13,6 +13,7 @@ from dash.development.base_component import ComponentRegistry from dash.mcp.primitives.tools.results import format_callback_response from dash.mcp.types import MCPError +from dash.types import CallbackExecutionResponse def parse_task_id(task_id: str) -> tuple[str, str, str, datetime]: @@ -72,11 +73,13 @@ def get_task(task_id: str) -> GetTaskResult: running = manager.job_running(job_id) progress = manager.get_progress(cache_key) + # Check result_ready before job_running: the process may store its result + # while still technically alive (teardown race), so result_ready wins. status: Literal["working", "completed", "failed"] - if running: - status = "working" - elif manager.result_ready(cache_key): + if manager.result_ready(cache_key): status = "completed" + elif running: + status = "working" else: status = "failed" @@ -125,7 +128,7 @@ def get_task_result(task_id: str) -> Any: output_spec = body.get("outputs", []) callback_ctx = AttributeDict({"updated_props": {}}) - response = {"multi": True} + response: CallbackExecutionResponse = {"multi": True} output_value, has_update, skip = _update_background_callback( error_handler=None,