diff --git a/.envrc b/.envrc index 96b5c94a1b..22cf211038 100644 --- a/.envrc +++ b/.envrc @@ -6,3 +6,8 @@ for venv_dir in .venv .env venv env .venv* env* ENV; do break fi done + +# Include optional, local-only environment overrides +if [ -f .env ]; then + dotenv +fi diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 2359695404..bb9b10a016 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -108,7 +108,7 @@ jobs: run: | python -m pip install --upgrade pip wheel python -m pip install "setuptools<80.0.0" - find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[dev,ci,testing]"' \; + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[dev,ci,testing,diskcache]"' \; - name: Install dash-renderer dependencies working-directory: dash/dash-renderer @@ -122,6 +122,8 @@ jobs: echo "DISPLAY=:99" >> $GITHUB_ENV - name: Run lint + env: + PYLINT_EXTRA_ARGS: ${{ matrix.python-version == '3.8' && '--ignored-modules=mcp' || '' }} run: npm run lint - name: Run unit tests @@ -229,7 +231,7 @@ jobs: run: | python -m pip install --upgrade pip wheel python -m pip install "setuptools<80.0.0" - find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[ci,testing,dev]"' \; + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[ci,testing,dev,diskcache]"' \; - name: Build/Setup test components run: npm run setup-tests.py # TODO build the packages and save them to packages/ in build job diff --git a/.pylintrc b/.pylintrc index 7ffb5576b7..39e142048d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -431,7 +431,7 @@ max-returns=6 max-statements=50 # Minimum number of public methods for a class (see R0903). -min-public-methods=2 +min-public-methods=1 [IMPORTS] diff --git a/CHANGELOG.md b/CHANGELOG.md index 13c0fc585b..36c31ec8e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,19 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#2462](https://github.com/plotly/dash/issues/2462) Allow `MATCH` in `Input`/`State` when the callback's `Output` has no wildcards (fixed-id Output, no Output, or `ALL`-only wildcard Output). `ALLSMALLER` still requires a corresponding `MATCH` in an Output. - [#3759](https://github.com/plotly/dash/pull/3759) Fix the issue where `Patch` objects cannot be updated via `set_props()` in `websocket` callback. Fix [#3742](https://github.com/plotly/dash/issues/3742) +## [4.3.0rc0] - 2026-05-21 + +## Added +- [#3710](https://github.com/plotly/dash/pull/3710) MCP: Framework utilities, types for interacting with layout +- [#3711](https://github.com/plotly/dash/pull/3711) MCP: `CallbackAdapter` for representing callback-related data in MCP-friendly format +- [#3712](https://github.com/plotly/dash/pull/3712) MCP: `Resources` for exposing app layout, components, and pages +- [#3731](https://github.com/plotly/dash/pull/3731) MCP: Expose callbacks as `Tools` +- [#3747](https://github.com/plotly/dash/pull/3747) MCP: Support pattern-matching callbacks in Tools +- [#3748](https://github.com/plotly/dash/pull/3748) MCP: Format callback results for LLM consumption (rendered graphs, markdown tables) +- [#3749](https://github.com/plotly/dash/pull/3749) MCP: `get_dash_component` Tool and callback execution +- [#3750](https://github.com/plotly/dash/pull/3750) MCP: Server routes, `mcp_enabled` function decorator, and Streamable HTTP transport +- [#3766](https://github.com/plotly/dash/pull/3766) MCP: Support background callbacks in Tools + ## [4.2.0rc3] - 2026-05-12 - [#3771](https://github.com/plotly/dash/pull/3771) Add persistent callbacks and no inputs/no outputs callback support. diff --git a/components/dash-core-components/package.json b/components/dash-core-components/package.json index ffe7de1d2f..e430a00b6a 100644 --- a/components/dash-core-components/package.json +++ b/components/dash-core-components/package.json @@ -27,7 +27,7 @@ "build:js": "webpack --mode production", "build:backends": "dash-generate-components ./src/components dash_core_components -p package-info.json && cp dash_core_components_base/** dash_core_components/ && dash-generate-components ./src/components dash_core_components -p package-info.json -k RangeSlider,Slider,Dropdown,RadioItems,Checklist,DatePickerSingle,DatePickerRange,Input,Link --r-prefix 'dcc' --r-suggests 'dash,dashHtmlComponents,jsonlite,plotly' --jl-prefix 'dcc' && black dash_core_components", "build": "run-s prepublishOnly build:js build:backends", - "postbuild": "es-check es2015 dash_core_components/*.js", + "postbuild": "es-check es2017 dash_core_components/*.js", "build:watch": "watch 'npm run build' src", "format": "run-s private::format.*", "lint": "run-s private::lint.*" @@ -126,6 +126,6 @@ "react-dom": "16 - 19" }, "browserslist": [ - "last 10 years and not dead" + "last 11 years and not dead" ] } diff --git a/components/dash-table/package.json b/components/dash-table/package.json index 295517a4c9..fd905c6f40 100644 --- a/components/dash-table/package.json +++ b/components/dash-table/package.json @@ -119,6 +119,6 @@ "npm": ">=6.1.0" }, "browserslist": [ - "last 10 years and not dead" + "last 11 years and not dead" ] } diff --git a/dash/_callback.py b/dash/_callback.py index a0d5a1021d..037f3d189b 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -39,6 +39,7 @@ from .background_callback.managers import BaseBackgroundCallbackManager from ._callback_context import context_value +from .types import CallbackExecutionResponse from ._no_update import NoUpdate from . import _validate @@ -85,6 +86,8 @@ def callback( hidden: Optional[bool] = None, websocket: Optional[bool] = False, persistent: Optional[bool] = False, + mcp_enabled: bool = True, + mcp_expose_docstring: Optional[bool] = None, **_kwargs, ) -> Callable[[Callable[Params, ReturnVar]], Callable[Params, ReturnVar]]: """ @@ -242,6 +245,8 @@ def callback( hidden=hidden, websocket=websocket, persistent=persistent, + mcp_enabled=mcp_enabled, + mcp_expose_docstring=mcp_expose_docstring, ) return cast( @@ -295,6 +300,8 @@ def insert_callback( hidden=None, websocket=False, persistent=False, + mcp_enabled=True, + mcp_expose_docstring=None, ) -> str: if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks @@ -338,6 +345,8 @@ def insert_callback( "allow_dynamic_callbacks": dynamic_creator, "no_output": no_output, "websocket": websocket, + "mcp_enabled": mcp_enabled, + "mcp_expose_docstring": mcp_expose_docstring, } callback_list.append(callback_spec) @@ -453,10 +462,13 @@ def _setup_background_callback( return to_json(data) -def _progress_background_callback(response, callback_manager, background): +def _progress_background_callback( + response, callback_manager, background, cache_key=None +): progress_outputs = background.get("progress") - adapter = get_app().backend.request_adapter() - cache_key = adapter.args.get("cacheKey") + if cache_key is None: + adapter = get_app().backend.request_adapter() + cache_key = adapter.args.get("cacheKey") if progress_outputs: # Get the progress before the result as it would be erased after the results. @@ -468,21 +480,38 @@ def _progress_background_callback(response, callback_manager, background): def _update_background_callback( - error_handler, callback_ctx, response, kwargs, background, multi + error_handler, + callback_ctx, + response, + kwargs, + background, + multi, + cache_key=None, + job_id=None, ): """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) - adapter = get_app().backend.request_adapter() - cache_key = adapter.args.get("cacheKey") if adapter else None - job_id = adapter.args.get("job") if adapter else None + if cache_key is None or job_id is None: + adapter = get_app().backend.request_adapter() + cache_key = cache_key or (adapter.args.get("cacheKey") if adapter else None) + job_id = job_id or (adapter.args.get("job") if adapter else None) - _progress_background_callback(response, callback_manager, background) + _progress_background_callback( + response, callback_manager, background, cache_key=cache_key + ) output_value = callback_manager.get_result(cache_key, job_id) return _handle_rest_background_callback( - output_value, callback_manager, response, error_handler, callback_ctx, multi + output_value, + callback_manager, + response, + error_handler, + callback_ctx, + multi, + cache_key=cache_key, + job_id=job_id, ) @@ -494,10 +523,13 @@ def _handle_rest_background_callback( callback_ctx, multi, has_update=False, + cache_key=None, + job_id=None, ): - adapter = get_app().backend.request_adapter() - cache_key = adapter.args.get("cacheKey") if adapter else None - job_id = adapter.args.get("job") if adapter else None + if cache_key is None or job_id is None: + adapter = get_app().backend.request_adapter() + cache_key = cache_key or (adapter.args.get("cacheKey") if adapter else None) + job_id = job_id or (adapter.args.get("job") if adapter else None) # Must get job_running after get_result since get_results terminates it. job_running = callback_manager.job_running(job_id) if not job_running and output_value is callback_manager.UNDEFINED: @@ -546,7 +578,7 @@ def _prepare_response( output_value, output_spec, multi, - response, + response: CallbackExecutionResponse, callback_ctx, app, original_packages, @@ -558,7 +590,7 @@ def _prepare_response( allow_dynamic_callbacks, ): """Prepare the response object based on the callback output.""" - component_ids = collections.defaultdict(dict) + component_ids: dict = collections.defaultdict(dict) if has_output: if not multi: @@ -677,6 +709,8 @@ def register_callback( hidden=_kwargs.get("hidden", None), websocket=_kwargs.get("websocket", False), persistent=_kwargs.get("persistent", False), + mcp_enabled=_kwargs.get("mcp_enabled", True), + mcp_expose_docstring=_kwargs.get("mcp_expose_docstring"), ) # pylint: disable=too-many-locals @@ -711,7 +745,7 @@ def add_context(*args, **kwargs): args, kwargs, inputs_state_indices, has_output, insert_output ) - response: dict = {"multi": True} # type: ignore + response: CallbackExecutionResponse = {"multi": True} jsonResponse: Optional[str] = None try: if background is not None: @@ -783,7 +817,7 @@ async def async_add_context(*args, **kwargs): args, kwargs, inputs_state_indices, has_output, insert_output ) - response = {"multi": True} + response: CallbackExecutionResponse = {"multi": True} try: if background is not None: diff --git a/dash/_configs.py b/dash/_configs.py index 107b8308f5..25a401523b 100644 --- a/dash/_configs.py +++ b/dash/_configs.py @@ -33,6 +33,9 @@ def load_dash_env_vars(): "DASH_DISABLE_VERSION_CHECK", "DASH_PRUNE_ERRORS", "DASH_COMPRESS", + "DASH_MCP_ENABLED", + "DASH_MCP_PATH", + "DASH_MCP_EXPOSE_DOCSTRINGS", "HOST", "PORT", ) diff --git a/dash/_layout_utils.py b/dash/_layout_utils.py new file mode 100644 index 0000000000..d421771afd --- /dev/null +++ b/dash/_layout_utils.py @@ -0,0 +1,228 @@ +"""Reusable layout utilities for traversing and inspecting Dash component trees.""" + +from __future__ import annotations + +import json +from typing import Any, Generator + +from dash import get_app +from dash._pages import PAGE_REGISTRY +from dash.dependencies import Wildcard +from dash.development.base_component import Component + +_WILDCARD_VALUES = frozenset(w.value for w in Wildcard) + + +def traverse( + start: Component | None = None, +) -> Generator[tuple[Component, tuple[Component, ...]], None, None]: + """Yield ``(component, ancestors)`` for every Component in the tree. + + If ``start`` is ``None``, the full app layout is resolved via + ``dash.get_app()``, preferring ``validation_layout`` for completeness. + """ + if start is None: + app = get_app() + start = getattr(app, "validation_layout", None) or app.get_layout() + + yield from _walk(start, ()) + + +def _walk( + node: Any, + ancestors: tuple[Component, ...], +) -> Generator[tuple[Component, tuple[Component, ...]], None, None]: + if node is None: + return + if isinstance(node, (list, tuple)): + for item in node: + yield from _walk(item, ancestors) + return + if not isinstance(node, Component): + return + + yield node, ancestors + + child_ancestors = (*ancestors, node) + for _prop_name, child in iter_children(node): + yield from _walk(child, child_ancestors) + + +def iter_children( + component: Component, +) -> Generator[tuple[str, Component], None, None]: + """Yield ``(prop_name, child_component)`` for all component-valued props. + + Walks ``children`` plus any props declared in the component's + ``_children_props`` list. Supports nested path expressions like + ``control_groups[].children`` and ``insights.title``. + """ + props_to_walk = ["children"] + getattr(component, "_children_props", []) + for prop_path in props_to_walk: + for child in get_children(component, prop_path): + yield prop_path, child + + +def get_children(component: Any, prop_path: str) -> list[Component]: + """Resolve a ``_children_props`` path expression to child Components. + + Mirrors the dash-renderer's path parsing in ``DashWrapper.tsx``. + Supports: + - ``"children"`` — simple prop + - ``"control_groups[].children"`` — array, then sub-prop per element + - ``"insights.title"`` — nested object prop + """ + clean_path = prop_path.replace("[]", "").replace("{}", "") + + if "." not in prop_path: + return _collect_components(getattr(component, clean_path, None)) + + parts = prop_path.split(".") + array_idx = next((i for i, p in enumerate(parts) if "[]" in p), len(parts)) + front = [p.replace("[]", "").replace("{}", "") for p in parts[: array_idx + 1]] + back = [p.replace("{}", "") for p in parts[array_idx + 1 :]] + + node = _resolve_path(component, front) + if node is None: + return [] + + if back and isinstance(node, (list, tuple)): + results: list[Component] = [] + for element in node: + child = _resolve_path(element, back) + results.extend(_collect_components(child)) + return results + + return _collect_components(node) + + +def _resolve_path(node: Any, keys: list[str]) -> Any: + """Walk a chain of keys through Components and dicts.""" + for key in keys: + if isinstance(node, Component): + node = getattr(node, key, None) + elif isinstance(node, dict): + node = node.get(key) + else: + return None + if node is None: + return None + return node + + +def _collect_components(value: Any) -> list[Component]: + """Extract Components from a value (single, list, or None).""" + if value is None: + return [] + if isinstance(value, Component): + return [value] + if isinstance(value, (list, tuple)): + return [item for item in value if isinstance(item, Component)] + return [] + + +def find_component( + component_id: str | dict, + layout: Component | None = None, + page: str | None = None, +) -> Component | None: + """Find a component by ID. + + If neither ``layout`` nor ``page`` is provided, searches the full + app layout (preferring ``validation_layout`` for completeness). + """ + if page is not None: + layout = _resolve_page_layout(page) + + if layout is None: + app = get_app() + layout = getattr(app, "validation_layout", None) or app.get_layout() + + for comp, _ in traverse(layout): + if getattr(comp, "id", None) == component_id: + return comp + return None + + +def parse_wildcard_id(pid: Any) -> dict | None: + """Parse a component ID and return it as a dict if it contains a wildcard. + + Accepts string (JSON-encoded) or dict IDs. Returns ``None`` + if the ID is not a wildcard pattern. + + Example:: + + >>> parse_wildcard_id('{"type":"input","index":["ALL"]}') + {"type": "input", "index": ["ALL"]} + >>> parse_wildcard_id("my-dropdown") + None + """ + if isinstance(pid, str) and pid.startswith("{"): + try: + pid = json.loads(pid) + except (json.JSONDecodeError, ValueError): + return None + if not isinstance(pid, dict): + return None + for v in pid.values(): + if isinstance(v, list) and len(v) == 1 and v[0] in _WILDCARD_VALUES: + return pid + return None + + +def find_matching_components(pattern: dict) -> list[Component]: + """Find all components whose dict ID matches a wildcard pattern. + + Non-wildcard keys must match exactly. Wildcard keys are ignored. + """ + non_wildcard_keys = { + k: v + for k, v in pattern.items() + if not (isinstance(v, list) and len(v) == 1 and v[0] in _WILDCARD_VALUES) + } + matches = [] + for comp, _ in traverse(): + comp_id = getattr(comp, "id", None) + if not isinstance(comp_id, dict): + continue + if all(comp_id.get(k) == v for k, v in non_wildcard_keys.items()): + matches.append(comp) + return matches + + +def extract_text(component: Component) -> str: + """Recursively extract plain text from a component's children tree. + + Mimics the browser's ``element.textContent``. + """ + children = getattr(component, "children", None) + if children is None: + return "" + if isinstance(children, str): + return children + if isinstance(children, Component): + return extract_text(children) + if isinstance(children, (list, tuple)): + parts: list[str] = [] + for child in children: + if isinstance(child, str): + parts.append(child) + elif isinstance(child, Component): + parts.append(extract_text(child)) + return "".join(parts).strip() + return "" + + +def _resolve_page_layout(page: str) -> Any | None: + if not PAGE_REGISTRY: + return None + for _module, page_info in PAGE_REGISTRY.items(): + if page_info.get("path") == page: + page_layout = page_info.get("layout") + if callable(page_layout): + try: + page_layout = page_layout() + except (TypeError, RuntimeError): + return None + return page_layout + return None diff --git a/dash/backends/ws.py b/dash/backends/ws.py index 041241823e..3e9739e16c 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -20,6 +20,7 @@ import janus from dash.exceptions import PreventUpdate, WebsocketDisconnected +from dash.types import CallbackExecutionBody from dash._utils import to_json if TYPE_CHECKING: @@ -149,7 +150,7 @@ async def get_prop( def create_ws_context( - payload: dict, + payload: CallbackExecutionBody, response_adapter: "ResponseAdapter", websocket_callback: DashWebsocketCallback, ): @@ -276,7 +277,7 @@ def on_done(f: concurrent.futures.Future) -> None: def run_callback_in_executor( executor: ThreadPoolExecutor, dash_app: "dash.Dash", - payload: dict, + payload: CallbackExecutionBody, ws_callback: DashWebsocketCallback, response_adapter: "ResponseAdapter", ) -> concurrent.futures.Future: diff --git a/dash/background_callback/managers/diskcache_manager.py b/dash/background_callback/managers/diskcache_manager.py index db7cd112bc..2ded7eb48e 100644 --- a/dash/background_callback/managers/diskcache_manager.py +++ b/dash/background_callback/managers/diskcache_manager.py @@ -36,7 +36,7 @@ def __init__(self, cache=None, cache_by=None, expire=None): is determined by the default behavior of the ``cache`` instance. """ try: - import diskcache # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel + import diskcache # type: ignore[import-not-found,import-untyped] # pylint: disable=import-outside-toplevel import psutil # type: ignore[import-untyped] # noqa: F401,E402 pylint: disable=import-outside-toplevel,unused-import,unused-variable,import-error import multiprocess # type: ignore[import-untyped] # noqa: F401,E402 pylint: disable=import-outside-toplevel,unused-import,unused-variable except ImportError as missing_imports: diff --git a/dash/dash-renderer/babel.config.js b/dash/dash-renderer/babel.config.js index d7b0c89e8e..6e6cc5d957 100644 --- a/dash/dash-renderer/babel.config.js +++ b/dash/dash-renderer/babel.config.js @@ -3,7 +3,7 @@ module.exports = { '@babel/preset-typescript', ['@babel/preset-env', { "targets": { - "browsers": ["last 10 years and not dead"] + "browsers": ["last 11 years and not dead"] } }], '@babel/preset-react' diff --git a/dash/dash-renderer/package.json b/dash/dash-renderer/package.json index a404fa2425..b9d1965f0d 100644 --- a/dash/dash-renderer/package.json +++ b/dash/dash-renderer/package.json @@ -89,6 +89,6 @@ ], "prettier": "@plotly/prettier-config-dash", "browserslist": [ - "last 10 years and not dead" + "last 11 years and not dead" ] } diff --git a/dash/dash.py b/dash/dash.py index f0821abef2..37eb7a1ffb 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -75,7 +75,7 @@ _import_layouts_from_pages, ) from ._jupyter import jupyter_dash, JupyterDisplayMode -from .types import RendererHooks +from .types import CallbackExecutionBody, RendererHooks RouteCallable = Callable[..., Any] @@ -486,6 +486,9 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches websocket_callbacks: Optional[bool] = False, websocket_allowed_origins: Optional[List[str]] = None, websocket_inactivity_timeout: Optional[int] = 300000, + enable_mcp: Optional[bool] = None, + mcp_path: Optional[str] = None, + mcp_expose_docstrings: Optional[bool] = None, **obsolete, ): @@ -567,6 +570,9 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches hide_all_callbacks=False, csrf_token_name=csrf_token_name, csrf_header_name=csrf_header_name, + mcp_expose_docstrings=get_combined_config( + "mcp_expose_docstrings", mcp_expose_docstrings, False + ), ) self.config.set_read_only( [ @@ -597,11 +603,19 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # keep title as a class property for backwards compatibility self.title = title + # MCP (Model Context Protocol) configuration + self._enable_mcp = get_combined_config("mcp_enabled", enable_mcp, False) + _mcp_path = get_combined_config("mcp_path", mcp_path, "_mcp") + self._mcp_path = ( + _mcp_path.lstrip("/") if isinstance(_mcp_path, str) else _mcp_path + ) + # list of dependencies - this one is used by the back end for dispatching self.callback_map: dict = {} # same deps as a list to catch duplicate outputs, and to send to the front end self._callback_list: list = [] self.callback_api_paths: dict = {} + self.mcp_decorated_functions: dict = {} # list of inline scripts self._inline_scripts: list = [] @@ -809,6 +823,21 @@ def _setup_routes(self): hook.data["methods"], ) + if self._enable_mcp: + from .mcp import ( # pylint: disable=import-outside-toplevel + enable_mcp_server, + ) + + try: + enable_mcp_server(self, self._mcp_path) + except Exception as e: # pylint: disable=broad-exception-caught + self._enable_mcp = False + self.logger.warning( + "MCP server could not be started at '%s': %s", + self._mcp_path, + e, + ) + def setup_apis(self): """ Register API endpoints for all callbacks defined using `dash.callback`. @@ -898,15 +927,22 @@ def index_string(self, value: str) -> None: self._index_string = value @with_app_context - def serve_layout(self): - layout = self._layout_value() + def get_layout(self): + """Return the resolved layout with all hooks applied. + This is the canonical way to obtain the app's layout — it + calls the layout function (if callable), includes extra + components, and runs layout hooks. + """ + layout = self._layout_value() for hook in self._hooks.get_hooks("layout"): layout = hook(layout) + return layout + def serve_layout(self): # TODO - Set browser cache limit - pass hash into frontend return self.backend.make_response( - to_json(layout), + to_json(self.get_layout()), mimetype="application/json", ) @@ -1462,7 +1498,7 @@ def _inputs_to_vals(self, inputs): return inputs_to_vals(inputs) # pylint: disable=R0915 - def _initialize_context(self, body): + def _initialize_context(self, body: CallbackExecutionBody): """Initialize the global context for the request.""" adapter = self.backend.request_adapter() g = AttributeDict({}) @@ -1485,7 +1521,7 @@ def _initialize_context(self, body): g.updated_props = {} return g - def _prepare_callback(self, g, body): + def _prepare_callback(self, g, body: CallbackExecutionBody): """Prepare callback-related data.""" output = body["output"] try: @@ -2445,6 +2481,13 @@ def verify_url_part(served_part, url_part, part_name): if not jupyter_dash or not jupyter_dash.in_ipython: self.logger.info("Dash is running on %s://%s%s%s\n", *display_url) + if self._enable_mcp: + self.logger.info( + " * MCP available at %s://%s%s%s%s\n", + *display_url[:3], + self.config.routes_pathname_prefix, + self._mcp_path, + ) if self.config.extra_hot_reload_paths: extra_files = flask_run_options["extra_files"] = [] diff --git a/dash/development/_py_components_generation.py b/dash/development/_py_components_generation.py index 73545ea4a5..87211c7cc8 100644 --- a/dash/development/_py_components_generation.py +++ b/dash/development/_py_components_generation.py @@ -24,6 +24,15 @@ import typing # noqa: F401 from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 from dash.development.base_component import Component, _explicitize_args +try: + from dash.types import NumberType # noqa: F401 +except ImportError: + # Backwards compatibility for dash<=4.1.0 + if typing.TYPE_CHECKING: + raise + NumberType = typing.Union[ # noqa: F401 + typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex + ] {custom_imports} ComponentSingleType = typing.Union[str, int, float, Component, None] ComponentType = typing.Union[ @@ -31,10 +40,6 @@ typing.Sequence[ComponentSingleType], ] -NumberType = typing.Union[ - typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex -] - """ diff --git a/dash/development/base_component.py b/dash/development/base_component.py index 342ad6da6f..d7a473c25e 100644 --- a/dash/development/base_component.py +++ b/dash/development/base_component.py @@ -117,6 +117,26 @@ class Component(metaclass=ComponentMeta): _valid_wildcard_attributes: typing.List[str] available_wildcard_properties: typing.List[str] + @classmethod + def __get_pydantic_core_schema__(cls, _source_type, _handler): + from pydantic_core import core_schema # pylint: disable=import-outside-toplevel + + return core_schema.any_schema() + + @classmethod + def __get_pydantic_json_schema__(cls, _schema, _handler): + namespaces = list(ComponentRegistry.namespace_to_package.keys()) + return { + "type": "object", + "properties": { + "type": {"type": "string"}, + "namespace": {"type": "string", "enum": namespaces} + if namespaces + else {"type": "string"}, + "props": {"type": "object"}, + }, + } + class _UNDEFINED: def __repr__(self): return "undefined" diff --git a/dash/mcp/__init__.py b/dash/mcp/__init__.py new file mode 100644 index 0000000000..2e6edffdb2 --- /dev/null +++ b/dash/mcp/__init__.py @@ -0,0 +1,9 @@ +"""Dash MCP (Model Context Protocol) server integration.""" + +from dash.mcp._decorator import mcp_enabled +from dash.mcp._server import enable_mcp_server + +__all__ = [ + "enable_mcp_server", + "mcp_enabled", +] diff --git a/dash/mcp/_decorator.py b/dash/mcp/_decorator.py new file mode 100644 index 0000000000..1b85316207 --- /dev/null +++ b/dash/mcp/_decorator.py @@ -0,0 +1,51 @@ +"""Decorator to expose plain Python functions as MCP tools.""" + +from __future__ import annotations + +import functools +from typing import Any, Callable, Optional + +from typing_extensions import TypedDict + + +class MCPToolRegistration(TypedDict): + fn: Callable[..., Any] + expose_docstring: Optional[bool] + + +MCP_DECORATED_FUNCTIONS: dict[str, MCPToolRegistration] = {} + + +def mcp_enabled( + func: Callable[..., Any] | None = None, + *, + name: str | None = None, + expose_docstring: Optional[bool] = None, +) -> Callable[..., Any]: + """Mark a function as an MCP tool. + + Supports both bare and parameterised usage:: + + @mcp_enabled + def my_tool(x: int) -> str: ... + + @mcp_enabled(name="custom_name", expose_docstring=True) + def my_tool(x: int) -> str: ... + """ + + def _wrap(fn: Callable[..., Any]) -> Callable[..., Any]: + tool_name = name if name else fn.__name__ + MCP_DECORATED_FUNCTIONS[tool_name] = MCPToolRegistration( + fn=fn, + expose_docstring=expose_docstring, + ) + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return fn(*args, **kwargs) + + return wrapper + + if func is not None: + return _wrap(func) + return _wrap diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py new file mode 100644 index 0000000000..07b0520bb9 --- /dev/null +++ b/dash/mcp/_server.py @@ -0,0 +1,296 @@ +"""Flask route setup, Streamable HTTP transport, and MCP message handling.""" + +# pylint: disable=cyclic-import +# The MCP server imports dash primitives to dispatch callbacks, and dash +# lazy-imports this module to wire the MCP endpoint. Cycle is managed here. + +from __future__ import annotations + +import hashlib +import inspect +import json +import logging +import os +from typing import TYPE_CHECKING, Any + +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ErrorData, + Implementation, + InitializeResult, + JSONRPCError, + JSONRPCResponse, + ResourcesCapability, + ServerCapabilities, + ToolsCapability, +) + +from dash import get_app +from dash.mcp._decorator import MCP_DECORATED_FUNCTIONS +from dash.mcp.primitives import ( + call_tool, + list_resource_templates, + list_resources, + list_tools, + read_resource, +) +from dash.mcp.tasks import get_task, get_task_result, cancel_task +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) +from dash.mcp.types import MCPError +from dash.version import __version__ + +if TYPE_CHECKING: + from dash import Dash + +logger = logging.getLogger(__name__) + + +def enable_mcp_server(app: Dash, mcp_path: str) -> None: + """Add MCP routes to a Dash app.""" + + app.mcp_decorated_functions = dict(MCP_DECORATED_FUNCTIONS) + MCP_DECORATED_FUNCTIONS.clear() + + def _get_or_create_session_id() -> str: + """ + Creates a shared session ID shared across all clients. The session is + used to notify clients of app restarts so they can refresh their view + of the app. + When hot-reloading is enabled, the reload_hash is used + Otherwise, the parent PID is used because it is a stable identifier + across different worker processes. + """ + # pylint: disable=protected-access + reload_hash = app._hot_reload.hash + if reload_hash is not None: + return reload_hash + return hashlib.sha256(f"dash-mcp-{os.getppid()}".encode()).hexdigest()[:32] + + _session_id: str = _get_or_create_session_id() + + def _is_session_stale(client_session_id: str | None) -> bool: + """True when the client's session doesn't match or the hash changed.""" + if client_session_id != _session_id: + return True + # pylint: disable=protected-access + reload_hash = app._hot_reload.hash + if reload_hash is None: + return False + return reload_hash != _session_id + + # -- Streamable HTTP endpoint -------------------------------------------- + + def _check_session(method: str) -> bool: + """Validate the session header. + + Raises ``ValueError`` when the header is missing. + Returns ``True`` when the session was stale and transparently + recovered, or ``False`` when the session is valid. + """ + nonlocal _session_id + adapter = app.backend.request_adapter() + if method == "initialize": + _session_id = _get_or_create_session_id() + return False + client_session_id = adapter.headers.get("Mcp-Session-Id") + if client_session_id and _is_session_stale(client_session_id): + _session_id = _get_or_create_session_id() + logger.debug("MCP session recovered: %s", _session_id) + return True + return False + + def _json_response(*messages: dict): + """Wrap one or more JSON-RPC messages in a response. + + A single message is serialised as a JSON object; multiple + messages are serialised as a JSON array. + """ + body = messages[0] if len(messages) == 1 else list(messages) + resp = app.backend.make_response( + json.dumps(body), + content_type="application/json", + status=200, + ) + if _session_id is not None: + resp.headers["Mcp-Session-Id"] = _session_id + return resp + + def _handle_post() -> Any: + adapter = app.backend.request_adapter() + return _handle_mcp_request(adapter.get_json()) + + async def _handle_post_async() -> Any: + adapter = app.backend.request_adapter() + return _handle_mcp_request(await adapter.get_json()) + + def _handle_mcp_request(data) -> Any: + adapter = app.backend.request_adapter() + content_type = adapter.headers.get("Content-Type", "") + if "application/json" not in content_type: + return app.backend.make_response( + json.dumps({"error": "Content-Type must be application/json"}), + content_type="application/json", + status=415, + ) + + if data is None: + return app.backend.make_response( + json.dumps({"error": "Invalid JSON"}), + content_type="application/json", + status=400, + ) + + method = data.get("method", "") + + try: + is_stale_session = _check_session(method) + except ValueError as err: + return app.backend.make_response( + json.dumps({"error": str(err)}), + content_type="application/json", + status=400, + ) + + response_data = _process_mcp_message(data) + + if response_data is None: + return app.backend.make_response("", status=202) + + if is_stale_session: + return _json_response( + {"jsonrpc": "2.0", "method": "notifications/tools/list_changed"}, + { + "jsonrpc": "2.0", + "method": "notifications/resources/list_changed", + }, + response_data, + ) + + return _json_response(response_data) + + def _handle_not_allowed(): + """Return 405 for GET and DELETE (MCP SSE not supported).""" + return app.backend.make_response( + json.dumps({"error": "Method not allowed"}), + content_type="application/json", + status=405, + ) + + # -- Register routes ----------------------------------------------------- + # Separate registrations per HTTP method so the handler never needs to + # inspect the request method. Distinct endpoint names are required by + # Flask / Werkzeug when the same URL rule is registered more than once. + if inspect.iscoroutinefunction(app.backend.request_adapter.get_json): + post_handler = _handle_post_async + else: + post_handler = _handle_post + mcp_url = app.config.routes_pathname_prefix + mcp_path + app.backend.add_url_rule( + mcp_url, + view_func=post_handler, + endpoint=f"{mcp_url}:POST", + methods=["POST"], + ) + app.backend.add_url_rule( + mcp_url, + view_func=_handle_not_allowed, + endpoint=f"{mcp_url}:GET", + methods=["GET"], + ) + app.backend.add_url_rule( + mcp_url, + view_func=_handle_not_allowed, + endpoint=f"{mcp_url}:DELETE", + methods=["DELETE"], + ) + app.routes.append(mcp_url) + + logger.info( + "MCP routes registered at %s%s", + app.config.routes_pathname_prefix, + mcp_path, + ) + + +def _handle_initialize() -> InitializeResult: + return InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities( + tools=ToolsCapability(listChanged=True), + resources=ResourcesCapability(listChanged=True), + ), + serverInfo=Implementation(name="Plotly Dash", version=__version__), + instructions=( + "This is a Dash web application. " + "Dash apps are stateless: calling a tool executes " + "a callback and returns its result to you, but does " + "NOT update the user's browser. " + "Use tool results to answer questions about what " + "the app would produce for given inputs." + ), + ) + + +def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: + """ + Process an MCP JSON-RPC message and return the response dict. + + Returns ``None`` for notifications (no ``id`` field). + """ + method = data.get("method", "") + params = data.get("params", {}) or {} + _id = data.get("id") + request_id: str | int = _id if isinstance(_id, (str, int)) else "" + + app = get_app() + if not hasattr(app, "mcp_callback_map"): + app.mcp_callback_map = CallbackAdapterCollection(app) + + mcp_methods = { + "initialize": _handle_initialize, + "tools/list": list_tools, + "tools/call": lambda: call_tool( + tool_name=params.get("name", ""), + arguments=params.get("arguments", {}), + task=params.get("task"), + ), + "resources/list": list_resources, + "resources/templates/list": list_resource_templates, + "resources/read": lambda: read_resource(params.get("uri", "")), + "tasks/get": lambda: get_task(task_id=params.get("taskId", "")), + "tasks/result": lambda: get_task_result(task_id=params.get("taskId", "")), + "tasks/cancel": lambda: cancel_task(task_id=params.get("taskId", "")), + } + + try: + handler = mcp_methods.get(method) + if handler is None: + if method.startswith("notifications/"): + return None + raise ValueError(f"Unknown method: {method}") + + result = handler() + + response = JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + result=result.model_dump(exclude_none=True, mode="json"), + ) + return response.model_dump(exclude_none=True, mode="json") + + except MCPError as e: + logger.error("MCP error: %s", e) + return JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=e.code, message=str(e)), + ).model_dump(exclude_none=True) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("MCP error: %s", e, exc_info=True) + return JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=-32603, message=f"{type(e).__name__}: {e}"), + ).model_dump(exclude_none=True) diff --git a/dash/mcp/primitives/__init__.py b/dash/mcp/primitives/__init__.py new file mode 100644 index 0000000000..e6b46a9af3 --- /dev/null +++ b/dash/mcp/primitives/__init__.py @@ -0,0 +1,17 @@ +from .resources import ( + list_resources, + list_resource_templates, + read_resource, +) +from .tools import ( + call_tool, + list_tools, +) + +__all__ = [ + "call_tool", + "list_resources", + "list_resource_templates", + "list_tools", + "read_resource", +] diff --git a/dash/mcp/primitives/resources/__init__.py b/dash/mcp/primitives/resources/__init__.py new file mode 100644 index 0000000000..a65e376e6f --- /dev/null +++ b/dash/mcp/primitives/resources/__init__.py @@ -0,0 +1,48 @@ +"""MCP resource listing and read handling.""" + +from __future__ import annotations + +from mcp.types import ( + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, +) + +from .base import MCPResourceProvider +from .resource_clientside_callbacks import ClientsideCallbacksResource +from .resource_components import ComponentsResource +from .resource_layout import LayoutResource +from .resource_page_layout import PageLayoutResource +from .resource_pages import PagesResource + +_RESOURCE_PROVIDERS: list[type[MCPResourceProvider]] = [ + LayoutResource, + ComponentsResource, + PagesResource, + ClientsideCallbacksResource, + PageLayoutResource, +] + + +def list_resources() -> ListResourcesResult: + """Build the MCP resources/list response.""" + resources = [ + r for p in _RESOURCE_PROVIDERS for r in [p.get_resource()] if r is not None + ] + return ListResourcesResult(resources=resources) + + +def list_resource_templates() -> ListResourceTemplatesResult: + """Build the MCP resources/templates/list response.""" + templates = [ + t for p in _RESOURCE_PROVIDERS for t in [p.get_template()] if t is not None + ] + return ListResourceTemplatesResult(resourceTemplates=templates) + + +def read_resource(uri: str) -> ReadResourceResult: + """Route a resources/read request by URI prefix match.""" + for p in _RESOURCE_PROVIDERS: + if uri.startswith(p.uri): + return p.read_resource(uri) + raise ValueError(f"Unknown resource URI: {uri}") diff --git a/dash/mcp/primitives/resources/base.py b/dash/mcp/primitives/resources/base.py new file mode 100644 index 0000000000..e63ffc9681 --- /dev/null +++ b/dash/mcp/primitives/resources/base.py @@ -0,0 +1,28 @@ +"""Base class for MCP resource providers.""" + +from __future__ import annotations + +from mcp.types import ReadResourceResult, Resource, ResourceTemplate + + +class MCPResourceProvider: + """Base class for MCP resource providers. + + Subclasses must set ``uri`` and implement ``read_resource``. + Override ``get_resource`` and/or ``get_template`` to advertise + the resource in ``resources/list`` or ``resources/templates/list``. + """ + + uri: str + + @classmethod + def get_resource(cls) -> Resource | None: + return None + + @classmethod + def get_template(cls) -> ResourceTemplate | None: + return None + + @classmethod + def read_resource(cls, uri: str) -> ReadResourceResult: + raise NotImplementedError diff --git a/dash/mcp/primitives/resources/resource_clientside_callbacks.py b/dash/mcp/primitives/resources/resource_clientside_callbacks.py new file mode 100644 index 0000000000..a8c0a0076a --- /dev/null +++ b/dash/mcp/primitives/resources/resource_clientside_callbacks.py @@ -0,0 +1,95 @@ +"""Clientside callbacks resource.""" + +from __future__ import annotations + +import json +from typing import Any + +from mcp.types import ( + ReadResourceResult, + Resource, + TextResourceContents, +) +from pydantic import AnyUrl + +from dash import get_app +from dash._utils import clean_property_name, split_callback_id + +from .base import MCPResourceProvider + + +class ClientsideCallbacksResource(MCPResourceProvider): + uri = "dash://clientside-callbacks" + + @classmethod + def get_resource(cls) -> Resource | None: + if not _get_clientside_callbacks(): + return None + return Resource( + uri=AnyUrl(cls.uri), + name="dash_clientside_callbacks", + description=( + "Actions the user can take manually in the browser " + "to affect clientside state. Inputs describe the " + "components that can be changed to trigger an effect. " + "Outputs describe the components that will change " + "in response." + ), + mimeType="application/json", + ) + + @classmethod + def read_resource(cls, uri: str = "") -> ReadResourceResult: + data = { + "description": ( + "These are actions that the user can take manually in the " + "browser to affect the clientside state. Inputs describe " + "the components that can be changed to trigger an effect. " + "Outputs describe the components that will change in " + "response to the effect." + ), + "callbacks": _get_clientside_callbacks(), + } + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=AnyUrl(cls.uri), + mimeType="application/json", + text=json.dumps(data, default=str), + ) + ] + ) + + +def _get_clientside_callbacks() -> list[dict[str, Any]]: + """Get input/output mappings for clientside callbacks.""" + app = get_app() + callbacks = [] + callback_map = getattr(app, "callback_map", {}) + + for output_id, callback_info in callback_map.items(): + if "callback" in callback_info: + continue + normalize_deps = lambda deps: [ + { + "component_id": str(d.get("id", "unknown")), + "property": d.get("property", "unknown"), + } + for d in deps + ] + parsed = split_callback_id(output_id) + if isinstance(parsed, dict): + parsed = [parsed] + outputs = [ + {"component_id": p["id"], "property": clean_property_name(p["property"])} + for p in parsed + ] + callbacks.append( + { + "outputs": outputs, + "inputs": normalize_deps(callback_info.get("inputs", [])), + "state": normalize_deps(callback_info.get("state", [])), + } + ) + + return callbacks diff --git a/dash/mcp/primitives/resources/resource_components.py b/dash/mcp/primitives/resources/resource_components.py new file mode 100644 index 0000000000..e14cf80745 --- /dev/null +++ b/dash/mcp/primitives/resources/resource_components.py @@ -0,0 +1,62 @@ +"""Component list resource.""" + +from __future__ import annotations + +import json + +from mcp.types import ( + ReadResourceResult, + Resource, + TextResourceContents, +) +from pydantic import AnyUrl + +from dash import get_app +from dash._layout_utils import traverse + +from .base import MCPResourceProvider + + +class ComponentsResource(MCPResourceProvider): + uri = "dash://components" + + @classmethod + def get_resource(cls) -> Resource | None: + return Resource( + uri=AnyUrl(cls.uri), + name="dash_components", + description=( + "All components with IDs in the app layout. " + "Use get_dash_component with any of these IDs " + "to inspect their properties and values. " + "See dash://layout for the tree structure showing " + "how these components are nested in the page." + ), + mimeType="application/json", + ) + + @classmethod + def read_resource(cls, uri: str = "") -> ReadResourceResult: + app = get_app() + layout = app.get_layout() + components = sorted( + [ + { + "id": str(getattr(comp, "id")), + "type": getattr(comp, "_type", type(comp).__name__), + } + for comp, _ in traverse(layout) + if getattr(comp, "id", None) is not None + ], + key=lambda c: c["id"], + ) + + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=AnyUrl(cls.uri), + mimeType="application/json", + text=json.dumps(components), + ) + ] + ) diff --git a/dash/mcp/primitives/resources/resource_layout.py b/dash/mcp/primitives/resources/resource_layout.py new file mode 100644 index 0000000000..7659d1fd8f --- /dev/null +++ b/dash/mcp/primitives/resources/resource_layout.py @@ -0,0 +1,44 @@ +"""Layout tree resource for the whole app.""" + +from __future__ import annotations + +from mcp.types import ( + ReadResourceResult, + Resource, + TextResourceContents, +) +from pydantic import AnyUrl + +from dash import get_app +from dash._utils import to_json + +from .base import MCPResourceProvider + + +class LayoutResource(MCPResourceProvider): + uri = "dash://layout" + + @classmethod + def get_resource(cls) -> Resource | None: + return Resource( + uri=AnyUrl(cls.uri), + name="dash_app_layout", + description=( + "Full component tree of the Dash app. " + "See dash://components for a compact list of component IDs." + ), + mimeType="application/json", + ) + + @classmethod + def read_resource(cls, uri: str = "") -> ReadResourceResult: + app = get_app() + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=AnyUrl(cls.uri), + mimeType="application/json", + text=to_json(app.get_layout()), + ) + ] + ) diff --git a/dash/mcp/primitives/resources/resource_page_layout.py b/dash/mcp/primitives/resources/resource_page_layout.py new file mode 100644 index 0000000000..bbfca411bc --- /dev/null +++ b/dash/mcp/primitives/resources/resource_page_layout.py @@ -0,0 +1,64 @@ +"""Per-page layout resource template for multi-page apps.""" + +from __future__ import annotations + +from mcp.types import ( + ReadResourceResult, + ResourceTemplate, + TextResourceContents, +) +from pydantic import AnyUrl + +from dash import html +from dash._pages import PAGE_REGISTRY +from dash._utils import to_json + +from .base import MCPResourceProvider + +_URI_TEMPLATE = "dash://page-layout/{path}" + + +class PageLayoutResource(MCPResourceProvider): + uri = "dash://page-layout/" + + @classmethod + def get_template(cls) -> ResourceTemplate | None: + if not PAGE_REGISTRY: + return None + return ResourceTemplate( + uriTemplate=_URI_TEMPLATE, + name="dash_page_layout", + description="Component tree for a specific page in the app.", + mimeType="application/json", + ) + + @classmethod + def read_resource(cls, uri: str) -> ReadResourceResult: + path = uri[len(cls.uri) :] + if not path.startswith("/"): + path = "/" + path + + page_layout = None + for _module, page in PAGE_REGISTRY.items(): + if page.get("path") == path: + page_layout = page.get("layout") + break + + if page_layout is None: + raise ValueError(f"Page not found: {path}") + + if callable(page_layout): + page_layout = page_layout() + + if isinstance(page_layout, (list, tuple)): + page_layout = html.Div(list(page_layout)) + + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=AnyUrl(uri), + mimeType="application/json", + text=to_json(page_layout), + ) + ] + ) diff --git a/dash/mcp/primitives/resources/resource_pages.py b/dash/mcp/primitives/resources/resource_pages.py new file mode 100644 index 0000000000..21fa27679f --- /dev/null +++ b/dash/mcp/primitives/resources/resource_pages.py @@ -0,0 +1,60 @@ +"""Pages resource for multi-page apps.""" + +from __future__ import annotations + +import json + +from mcp.types import ( + ReadResourceResult, + Resource, + TextResourceContents, +) +from pydantic import AnyUrl + +from dash._pages import PAGE_REGISTRY + +from .base import MCPResourceProvider + + +class PagesResource(MCPResourceProvider): + uri = "dash://pages" + + @classmethod + def get_resource(cls) -> Resource | None: + if not PAGE_REGISTRY: + return None + return Resource( + uri=AnyUrl(cls.uri), + name="dash_app_pages", + description=( + "List of all pages in this multi-page Dash app " + "with paths, names, titles, and descriptions." + ), + mimeType="application/json", + ) + + @classmethod + def read_resource(cls, uri: str = "") -> ReadResourceResult: + pages = [] + for module, page in PAGE_REGISTRY.items(): + title = page.get("title", "") + description = page.get("description", "") + pages.append( + { + "module": module, + "path": page.get("path", ""), + "name": page.get("name", ""), + "title": title if not callable(title) else page.get("name", ""), + "description": description if not callable(description) else "", + } + ) + + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=AnyUrl(cls.uri), + mimeType="application/json", + text=json.dumps(pages, default=str), + ) + ] + ) diff --git a/dash/mcp/primitives/tools/__init__.py b/dash/mcp/primitives/tools/__init__.py new file mode 100644 index 0000000000..eea7af43c1 --- /dev/null +++ b/dash/mcp/primitives/tools/__init__.py @@ -0,0 +1,48 @@ +"""MCP tool listing and call handling.""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import CallToolResult, CreateTaskResult, ListToolsResult + +from dash.mcp.types import ToolNotFoundError + +from .base import MCPToolProvider +from .tool_background_tasks import BackgroundTaskTools +from .tool_decorated_mcp_functions import DecoratedFunctionTools +from .tool_get_dash_component import GetDashComponentTool +from .tools_callbacks import CallbackTools + +_TOOL_PROVIDERS: list[type[MCPToolProvider]] = [ + CallbackTools, + BackgroundTaskTools, + GetDashComponentTool, + DecoratedFunctionTools, +] + + +def list_tools() -> ListToolsResult: + """Build the MCP tools/list response.""" + tools = [] + for provider in _TOOL_PROVIDERS: + tools.extend(provider.list_tools()) + return ListToolsResult(tools=tools) + + +def call_tool( + tool_name: str, arguments: dict[str, Any], task: dict | None = None +) -> CallToolResult | CreateTaskResult: + """Route a tools/call request by tool name. + + The optional ``task`` parameter (per MCP Tasks protocol) is passed + through to providers that support background callbacks. + """ + for provider in _TOOL_PROVIDERS: + if tool_name in provider.get_tool_names(): + return provider.call_tool(tool_name, arguments, task=task) + raise ToolNotFoundError( + f"Tool not found: {tool_name}." + " The app's callbacks may have changed." + " Please call tools/list to refresh your tool list." + ) diff --git a/dash/mcp/primitives/tools/base.py b/dash/mcp/primitives/tools/base.py new file mode 100644 index 0000000000..f7a5c54aac --- /dev/null +++ b/dash/mcp/primitives/tools/base.py @@ -0,0 +1,30 @@ +"""Base class for MCP tool providers.""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import CallToolResult, CreateTaskResult, Tool + + +class MCPToolProvider: + """A provider of one or more MCP tools. + + Subclasses implement ``list_tools`` to return the tools they provide, + ``get_tool_names`` to advertise those names for routing, and + ``call_tool`` to execute a tool by name. + """ + + @classmethod + def get_tool_names(cls) -> set[str]: + raise NotImplementedError + + @classmethod + def list_tools(cls) -> list[Tool]: + raise NotImplementedError + + @classmethod + def call_tool( + cls, tool_name: str, arguments: dict[str, Any], task: dict | None = None + ) -> CallToolResult | CreateTaskResult: + raise NotImplementedError diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py new file mode 100644 index 0000000000..9af56bb879 --- /dev/null +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -0,0 +1,479 @@ +"""Adapter: Dash callback → MCP tool interface. + +Wraps a raw ``callback_map`` entry and exposes MCP-facing +properties (tool name, params, outputs) lazily. +""" + +from __future__ import annotations + +import inspect +import json +import typing +from functools import cached_property +from typing import Any, cast + +from mcp.types import Tool + +from dash import get_app +from dash._layout_utils import ( + _WILDCARD_VALUES, + find_component, + find_matching_components, + parse_wildcard_id, +) +from dash._grouping import flatten_grouping +from dash._utils import clean_property_name, split_callback_id +from dash.types import ( + CallbackDependency, + CallbackExecutionBody, + CallbackInput, + CallbackInputs, + CallbackOutput, + CallbackOutputTarget, + WildcardId, +) +from dash.mcp.types import MCPInput, MCPOutput, is_nullable +from .callback_utils import run_callback +from .descriptions import build_tool_description +from .input_schemas import get_input_schema +from .output_schemas import get_output_schema + + +class CallbackAdapter: + """Adapts a single Dash callback_map entry to the MCP tool interface.""" + + def __init__(self, callback_output_id: str): + self._output_id = callback_output_id + + # ------------------------------------------------------------------- + # Projections + # ------------------------------------------------------------------- + + @cached_property + def as_mcp_tool(self) -> Tool: + """Transforms the internal Dash callback to a structured MCP tool. + + This tool can be serialized for LLM consumption or used internally for + its computed data. + """ + return Tool( + name=self.tool_name, + description=self._description, + inputSchema=self._input_schema, + outputSchema=self._output_schema, + ) + + def as_callback_body(self, kwargs: dict[str, Any]) -> CallbackExecutionBody: + """Transforms the given kwargs to a dict suitable for calling this callback. + + Mirrors how the Dash renderer assembles the callback payload — + see ``fillVals()`` in ``dash-renderer/src/actions/callbacks.ts``. + + For pattern-matching callbacks, wildcard deps are expanded into + nested arrays with concrete component IDs. + """ + raw_inputs = self._cb_info.get("inputs", []) + raw_state = self._cb_info.get("state", []) + n_deps = len(raw_inputs) + len(raw_state) + + flat_values = [None] * n_deps + for i, name in enumerate(self._param_names): + if i < n_deps and name in kwargs: + flat_values[i] = kwargs[name] + + inputs_with_values = [ + _expand_dep(dep, flat_values[i]) for i, dep in enumerate(raw_inputs) + ] + state_with_values = [ + _expand_dep(dep, flat_values[len(raw_inputs) + i]) + for i, dep in enumerate(raw_state) + ] + + outputs_spec = _expand_output_spec( + self._output_id, self._cb_info, inputs_with_values + ) + + # changedPropIds: only inputs with non-None values. + # This determines ctx.triggered_id in the callback. + changed = [] + for entry in inputs_with_values: + if isinstance(entry, dict) and entry.get("value") is not None: + eid = entry.get("id") + if isinstance(eid, dict): + changed.append( + f"{json.dumps(eid, sort_keys=True)}.{entry['property']}" + ) + elif isinstance(eid, str): + changed.append(f"{eid}.{entry['property']}") + + return { + "output": self._output_id, + "outputs": outputs_spec, + "inputs": inputs_with_values, + "state": state_with_values, + "changedPropIds": changed, + } + + # ------------------------------------------------------------------- + # Public identity and metadata + # ------------------------------------------------------------------- + + @cached_property + def is_valid(self) -> bool: + """Whether all input components exist in the layout.""" + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + for dep in all_deps: + dep_id = str(dep.get("id", "")) + if dep_id.startswith("{"): + continue + if find_component(dep_id) is None: + return False + return True + + @property + def output_id(self) -> str: + return self._output_id + + @property + def tool_name(self) -> str: + # pylint: disable-next=protected-access + return get_app().mcp_callback_map._tool_names_map[self._output_id] + + @cached_property + def prevents_initial_call(self) -> bool: + for cb in get_app()._callback_list: # pylint: disable=protected-access + if cb["output"] == self._output_id: + return cb.get("prevent_initial_call", False) + return False + + # ------------------------------------------------------------------- + # Private: computed fields for the MCP Tool + # ------------------------------------------------------------------- + + @cached_property + def _description(self) -> str: + return build_tool_description(self) + + @cached_property + def _input_schema(self) -> dict[str, Any]: + properties = {p["name"]: get_input_schema(p) for p in self.inputs} + required = [p["name"] for p in self.inputs if p["required"]] + + input_schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + input_schema["required"] = required + return input_schema + + @cached_property + def _output_schema(self) -> dict[str, Any]: + return get_output_schema() + + # ------------------------------------------------------------------- + # Private: callback metadata + # ------------------------------------------------------------------- + + @cached_property + def _docstring(self) -> str | None: + return getattr(self._original_func, "__doc__", None) + + @cached_property + def _initial_output(self) -> dict[str, CallbackOutput]: + """Run this callback with initial input values. + + Returns the ``response`` portion of the callback result: + ``{component_id: {property: value}}``. + + Skipped for callbacks with ``prevent_initial_call=True``, + matching how the Dash renderer skips them on page load. + """ + if self.prevents_initial_call: + return {} + + callback_map = get_app().mcp_callback_map + kwargs = {} + for p in self.inputs: + upstream = callback_map.find_by_output(p["id_and_prop"]) + if upstream is self: + kwargs[p["name"]] = getattr( + find_component(p["component_id"]), p["property"], None + ) + else: + kwargs[p["name"]] = callback_map.get_initial_value(p["id_and_prop"]) + try: + result = run_callback(self, kwargs) + return result.get("response", {}) + except Exception: # pylint: disable=broad-exception-caught + return {} + + def initial_output_value(self, id_and_prop: str) -> Any: + """Return the initial value for a specific output ``"component_id.property"``.""" + component_id, prop = id_and_prop.rsplit(".", 1) + return self._initial_output.get(component_id, {}).get(prop) + + @cached_property + def outputs(self) -> list[MCPOutput]: + if self._cb_info.get("no_output"): + return [] + parsed = split_callback_id(self._output_id) + if isinstance(parsed, dict): + parsed = [parsed] + result: list[MCPOutput] = [] + for p in parsed: + comp_id = p["id"] + prop = clean_property_name(p["property"]) + id_and_prop = f"{comp_id}.{prop}" + comp = find_component(comp_id) + result.append( + { + "id_and_prop": id_and_prop, + "component_id": comp_id, + "property": prop, + "component_type": getattr(comp, "_type", None), + "initial_value": self.initial_output_value(id_and_prop), + "tool_name": self.tool_name, + } + ) + return result + + @cached_property + def inputs(self) -> list[MCPInput]: + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + callback_map = get_app().mcp_callback_map + + result: list[MCPInput] = [] + for dep, name, annotation in zip( + all_deps, self._param_names, self._param_annotations + ): + comp_id = str(dep.get("id", "unknown")) + comp = find_component(comp_id) + prop = dep.get("property", "unknown") + id_and_prop = f"{comp_id}.{prop}" + + upstream_cb = callback_map.find_by_output(id_and_prop) + upstream_output = None + if upstream_cb is not None and upstream_cb is not self: + if not upstream_cb.prevents_initial_call: + for out in upstream_cb.outputs: + if out["id_and_prop"] == id_and_prop: + upstream_output = out + break + + initial_value = ( + upstream_output["initial_value"] + if upstream_output is not None + else getattr(comp, prop, None) + ) + + if annotation is not None: + required = not is_nullable(annotation) + else: + required = initial_value is not None + + result.append( + { + "name": name, + "id_and_prop": id_and_prop, + "component_id": comp_id, + "property": prop, + "annotation": annotation, + "component_type": getattr(comp, "_type", None), + "component": comp, + "required": required, + "initial_value": initial_value, + "upstream_output": upstream_output, + } + ) + return result + + # ------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------- + + @cached_property + def _cb_info(self) -> dict[str, Any]: + return get_app().callback_map[self._output_id] + + @cached_property + def _original_func(self) -> Any | None: + func = self._cb_info.get("callback") + return getattr(func, "__wrapped__", func) + + @cached_property + def _func_signature(self) -> inspect.Signature | None: + if self._original_func is None: + return None + try: + return inspect.signature(self._original_func) + except (ValueError, TypeError): + return None + + @cached_property + def _dep_param_map(self) -> list[tuple[str, str]]: + """(func_param_name, mcp_param_name) per dep, in dep order. + + Single source of truth for mapping deps to param names. + All dict-vs-list branching is confined here. + """ + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + n_deps = len(all_deps) + indices = self._cb_info.get("inputs_state_indices") + + if isinstance(indices, dict): + entries: list[tuple[int, str, str]] = [] + for func_name, idx in indices.items(): + positions = flatten_grouping(idx) + if len(positions) == 1: + entries.append((positions[0], func_name, func_name)) + else: + for pos in positions: + dep = all_deps[pos] if pos < n_deps else {} + comp_id = str(dep.get("id", "unknown")).replace("-", "_") + prop = dep.get("property", "unknown") + entries.append( + (pos, func_name, f"{func_name}_{comp_id}__{prop}") + ) + entries.sort(key=lambda e: e[0]) + result = [(f, m) for _, f, m in entries] + elif self._func_signature is not None: + names = list(self._func_signature.parameters.keys()) + result = [(n, n) for n in names] + else: + result = [] + + while len(result) < n_deps: + fallback = f"param_{len(result)}" + result.append((fallback, fallback)) + return result + + @cached_property + def _param_names(self) -> list[str]: + """MCP param name per dep, in dep order.""" + return [mcp for _, mcp in self._dep_param_map] + + @cached_property + def _param_annotations(self) -> list[Any | None]: + """One annotation per dep, in dep order.""" + if self._func_signature is None: + return [None] * len(self._dep_param_map) + try: + hints = typing.get_type_hints(self._original_func) + except Exception: # pylint: disable=broad-exception-caught + hints = getattr(self._original_func, "__annotations__", {}) + return [hints.get(func_name) for func_name, _ in self._dep_param_map] + + +def _expand_dep(dep: CallbackDependency, value: Any) -> CallbackInputs: + """ + Attach a concrete value to a callback dependency to produce a valid callback input. + + For regular deps, returns ``{id, property, value}``. + For ALL/ALLSMALLER: passes through the list of ``{id, property, value}`` dicts. + For MATCH: passes through the single ``{id, property, value}`` dict. + """ + pattern = parse_wildcard_id(dep.get("id", "")) + if pattern is None: + return CallbackInput(id=dep["id"], property=dep["property"], value=value) + + # LLM provides browser-like format + if isinstance(value, list): + return cast("list[CallbackInput]", value) + if isinstance(value, dict) and "id" in value: + return cast(CallbackInput, value) + return CallbackInput(id=dep["id"], property=dep["property"], value=value) + + +def _expand_output_spec( + output_id: str, + cb_info: dict, + resolved_inputs: list[CallbackInputs], +) -> CallbackOutputTarget | list[CallbackOutputTarget]: + """Build the outputs spec, expanding wildcards to concrete IDs. + + For wildcard outputs, derives concrete IDs from the resolved inputs. + The browser does the same: input and output wildcards resolve against + the same set of matching components. + """ + if cb_info.get("no_output"): + return [] + + parsed = split_callback_id(output_id) + if isinstance(parsed, dict): + parsed = [parsed] + + results: list[CallbackOutputTarget] = [] + for p in parsed: + pid = p["id"] + prop = clean_property_name(p["property"]) + pattern = parse_wildcard_id(pid) + if pattern is not None: + concrete_ids = _derive_output_ids(pattern, resolved_inputs) + if not concrete_ids: + concrete_ids = [ + getattr(comp, "id") for comp in find_matching_components(pattern) + ] + expanded: list[CallbackDependency] = [ + CallbackDependency(id=cid, property=prop) for cid in concrete_ids + ] + # ALL/ALLSMALLER → nested list; MATCH → single dict + if len(expanded) == 1: + results.append(expanded[0]) + else: + results.append(expanded) + else: + results.append(CallbackDependency(id=pid, property=prop)) + + # Mirror the Dash renderer: single-output callbacks send a bare dict, + # multi-output callbacks send a list. The framework's output value + # matching depends on this shape. + if len(results) == 1: + return results[0] + return results + + +def _derive_output_ids( + output_pattern: WildcardId, + resolved_inputs: list[CallbackInputs], +) -> list[WildcardId] | None: + """Derive concrete output IDs from the resolved input entries. + + Extracts the wildcard key values from the LLM-provided concrete + input IDs and substitutes them into the output pattern. + """ + wildcard_keys = [ + k + for k, v in output_pattern.items() + if isinstance(v, list) and len(v) == 1 and v[0] in _WILDCARD_VALUES + ] + if not wildcard_keys: + return None + + def _substitute(item_id: WildcardId) -> WildcardId | None: + if not isinstance(item_id, dict): + return None + output_id = dict(output_pattern) + for wk in wildcard_keys: + if wk in item_id: + output_id[wk] = item_id[wk] + return output_id + + for entry in resolved_inputs: + # ALL/ALLSMALLER: nested array of {id, property, value} dicts + if isinstance(entry, list) and entry: + concrete_ids = [] + for item in entry: + item_id = item.get("id") + if isinstance(item_id, dict): + out = _substitute(item_id) + if out: + concrete_ids.append(out) + if concrete_ids: + return concrete_ids + # MATCH: single {id, property, value} dict + elif isinstance(entry, dict): + entry_id = entry.get("id") + if isinstance(entry_id, dict): + out = _substitute(entry_id) + if out: + return [out] + + return None diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py new file mode 100644 index 0000000000..4fdaeabe9c --- /dev/null +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -0,0 +1,184 @@ +"""Collection of CallbackAdapters with cross-adapter queries. + +Stored as a singleton on ``app.mcp_callback_map``. +""" + +from __future__ import annotations + +import hashlib +import re +from functools import cached_property +from typing import Any + +from mcp.types import Tool + +from dash import get_app +from dash._utils import clean_property_name, split_callback_id +from dash._layout_utils import extract_text, find_component, traverse +from .callback_adapter import CallbackAdapter + + +class CallbackAdapterCollection: + def __init__(self, app): + callback_map = getattr(app, "callback_map", {}) + + raw: list[tuple[str, dict]] = [] + for output_id, cb_info in callback_map.items(): + if cb_info.get("mcp_enabled") is False: + continue + if "callback" not in cb_info: + continue + raw.append((output_id, cb_info)) + + self._tool_names_map = self._build_tool_names(raw) + self._callbacks = [ + CallbackAdapter(callback_output_id=output_id) + for output_id in self._tool_names_map + ] + + @staticmethod + def _sanitize_name(name: str) -> str: + + max_len = 64 + sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name) + sanitized = re.sub(r"_+", "_", sanitized).strip("_") + if sanitized and sanitized[0].isdigit(): + sanitized = "cb_" + sanitized + full = sanitized or "unnamed_callback" + if len(full) <= max_len: + return full + hash_suffix = hashlib.sha256(full.encode()).hexdigest()[:8] + truncated = sanitized[: max_len - 9].rstrip("_") + return f"{truncated}_{hash_suffix}" + + @classmethod + def _build_tool_names(cls, raw: list[tuple[str, dict]]) -> dict[str, str]: + func_name_counts: dict[str, int] = {} + for _output_id, cb_info in raw: + func = cb_info.get("callback") + original = getattr(func, "__wrapped__", func) + fn = getattr(original, "__name__", "") or "" + func_name_counts[fn] = func_name_counts.get(fn, 0) + 1 + + name_map: dict[str, str] = {} + for output_id, cb_info in raw: + func = cb_info.get("callback") + original = getattr(func, "__wrapped__", func) + fn = getattr(original, "__name__", "") or "" + raw_name = fn if fn and func_name_counts[fn] == 1 else output_id + name_map[output_id] = cls._sanitize_name(raw_name) + return name_map + + def __iter__(self): + return iter(self._callbacks) + + def __len__(self): + return len(self._callbacks) + + def __getitem__(self, index): + return self._callbacks[index] + + def find_by_tool_name(self, name: str) -> CallbackAdapter | None: + for cb in self._callbacks: + if cb.tool_name == name: + return cb + return None + + @cached_property + def outputs_by_prop(self) -> dict[str, list[CallbackAdapter]]: + """Index ``id_and_prop`` → callbacks outputting it. + + Mirrors the dash-renderer's ``outputMap`` (see + ``dash-renderer/src/actions/dependencies.js``). + """ + idx: dict[str, list[CallbackAdapter]] = {} + for cb in self._callbacks: + try: + parsed = split_callback_id(cb.output_id) + except ValueError: + continue + if isinstance(parsed, dict): + parsed = [parsed] + for p in parsed: + key = f"{p['id']}.{clean_property_name(p['property'])}" + idx.setdefault(key, []).append(cb) + return idx + + @cached_property + def inputs_by_prop(self) -> dict[str, list[CallbackAdapter]]: + """Index ``id_and_prop`` → callbacks consuming it as input/state. + + Mirrors the dash-renderer's ``inputMap`` (see + ``dash-renderer/src/actions/dependencies.js``). + Many callbacks may share a key. + """ + idx: dict[str, list[CallbackAdapter]] = {} + for cb in self._callbacks: + # pylint: disable-next=protected-access + deps = cb._cb_info.get("inputs", []) + cb._cb_info.get("state", []) + for dep in deps: + key = f"{dep.get('id', 'unknown')}.{dep.get('property', 'unknown')}" + idx.setdefault(key, []).append(cb) + return idx + + def find_by_output(self, id_and_prop: str) -> CallbackAdapter | None: + """Find the adapter that outputs to ``id_and_prop`` (``"component_id.property"``).""" + candidates = self.outputs_by_prop.get(id_and_prop, []) + return candidates[0] if candidates else None + + def get_initial_value(self, id_and_prop: str) -> Any: + """Return the initial value for ``id_and_prop`` (``"component_id.property"``). + + If a callback outputs to this property, runs it (recursively + resolving its inputs). Otherwise returns the layout default. + """ + upstream_cb = self.find_by_output(id_and_prop) + if upstream_cb is not None: + return upstream_cb.initial_output_value(id_and_prop) + component_id, prop = id_and_prop.rsplit(".", 1) + layout_component = find_component(component_id) + return getattr(layout_component, prop, None) + + def as_mcp_tools(self) -> list[Tool]: + return [cb.as_mcp_tool for cb in self._callbacks if cb.is_valid] + + @property + def tool_names(self) -> set[str]: + return set(self._tool_names_map.values()) + + @cached_property + def component_label_map(self) -> dict[str, list[str]]: + """Map component ID → list of label texts from html.Label containers + and/or `htmlFor` associations. + """ + layout = get_app().get_layout() + if layout is None: + return {} + + labels: dict[str, list[str]] = {} + for comp, ancestors in traverse(layout): + if getattr(comp, "_type", None) == "Label": + html_for = getattr(comp, "htmlFor", None) + if html_for is not None: + text = extract_text(comp) + if text: + labels.setdefault(str(html_for), []).append(text) + + comp_id = getattr(comp, "id", None) + if comp_id is not None: + self._add_ancestor_label(comp_id, ancestors, labels) + + return labels + + @staticmethod + def _add_ancestor_label(comp_id, ancestors, labels: dict[str, list[str]]) -> None: + """Record the text of the nearest Label ancestor for ``comp_id``, if any.""" + for ancestor in reversed(ancestors): + if getattr(ancestor, "_type", None) != "Label": + continue + text = extract_text(ancestor) + if text: + sid = str(comp_id) + if text not in labels.get(sid, []): + labels.setdefault(sid, []).append(text) + return diff --git a/dash/mcp/primitives/tools/callback_utils.py b/dash/mcp/primitives/tools/callback_utils.py new file mode 100644 index 0000000000..24a8a693fb --- /dev/null +++ b/dash/mcp/primitives/tools/callback_utils.py @@ -0,0 +1,41 @@ +"""Callback introspection utilities for MCP tools.""" + +from __future__ import annotations + +import json +from contextvars import copy_context +from typing import TYPE_CHECKING, Any + +from dash import get_app +from dash.mcp.types import CallbackExecutionError +from dash.types import CallbackExecutionResponse + +if TYPE_CHECKING: + from .callback_adapter import CallbackAdapter + + +def run_callback( + callback: CallbackAdapter, kwargs: dict[str, Any] +) -> CallbackExecutionResponse: + """Execute a callback via the framework. + + Must be called from inside an active request handler; the backend's + request adapter reads cookies/headers/args from the current request. + """ + body = callback.as_callback_body(kwargs) + app = get_app() + + try: + # pylint: disable=protected-access + cb_ctx = app._initialize_context(body) + func = app._prepare_callback(cb_ctx, body) + args = app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) + ctx = copy_context() + partial_func = app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) + response_text = ctx.run(partial_func) + except Exception as err: + raise CallbackExecutionError( + f"Callback {callback.output_id} failed: {err}" + ) from err + + return json.loads(response_text) diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py new file mode 100644 index 0000000000..a4227868d8 --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -0,0 +1,35 @@ +"""Tool-level description generation for MCP tools. + +Each source is a ``ToolDescriptionSource`` subclass that can add text +to the tool's description. All sources are accumulated. + +This is distinct from per-parameter descriptions +(in ``input_schemas/input_descriptions/``) which populate +``inputSchema.properties.{param}.description``. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .base import ToolDescriptionSource +from .description_background_callbacks import BackgroundCallbackDescription +from .description_docstring import DocstringDescription +from .description_outputs import OutputSummaryDescription + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + +_SOURCES: list[type[ToolDescriptionSource]] = [ + OutputSummaryDescription, + DocstringDescription, + BackgroundCallbackDescription, +] + + +def build_tool_description(callback: CallbackAdapter) -> str: + """Build a human-readable description for an MCP tool.""" + lines: list[str] = [] + for source in _SOURCES: + lines.extend(source.describe(callback)) + return "\n".join(lines) if lines else "Dash callback" diff --git a/dash/mcp/primitives/tools/descriptions/base.py b/dash/mcp/primitives/tools/descriptions/base.py new file mode 100644 index 0000000000..c069f67918 --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/base.py @@ -0,0 +1,21 @@ +"""Base class for tool-level description sources.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + + +class ToolDescriptionSource: + """A source of text that can describe an MCP tool. + + Subclasses implement ``describe`` to return strings that will be + joined into the tool's ``description`` field. All sources are + accumulated — every source can add text to the overall description. + """ + + @classmethod + def describe(cls, callback: CallbackAdapter) -> list[str]: + raise NotImplementedError diff --git a/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py b/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py new file mode 100644 index 0000000000..eada24a01e --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py @@ -0,0 +1,32 @@ +"""Description for background (long-running) callbacks. + +Informs the LLM that the tool returns a taskId immediately +and must be polled via the background task result tool. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..tool_background_tasks import GET_RESULT_TOOL_NAME +from .base import ToolDescriptionSource + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + + +class BackgroundCallbackDescription(ToolDescriptionSource): + """Add async polling instructions for background callbacks.""" + + @classmethod + def describe(cls, callback: CallbackAdapter) -> list[str]: + # pylint: disable-next=protected-access + if not callback._cb_info.get("background"): + return [] + + return [ + "", + "This is a long-running background operation. " + "It returns a taskId immediately. " + f"Call tool `{GET_RESULT_TOOL_NAME}` with the taskId to poll for the result.", + ] diff --git a/dash/mcp/primitives/tools/descriptions/description_docstring.py b/dash/mcp/primitives/tools/descriptions/description_docstring.py new file mode 100644 index 0000000000..c34d527077 --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/description_docstring.py @@ -0,0 +1,39 @@ +"""Callback docstring for tool descriptions.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from dash import get_app + +from .base import ToolDescriptionSource + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + + +class DocstringDescription(ToolDescriptionSource): + """Return the callback's docstring as description lines. + + Gated behind an opt-in flag: docstrings may contain sensitive + implementation details that the browser never surfaces to users, + so we don't expose them to MCP clients unless the author opts in + — either per-callback or app-wide. + """ + + @classmethod + def describe(cls, callback: CallbackAdapter) -> list[str]: + if not cls._is_exposed(callback): + return [] + docstring = callback._docstring # pylint: disable=protected-access + if docstring: + return ["", docstring.strip()] + return [] + + @classmethod + def _is_exposed(cls, callback: CallbackAdapter) -> bool: + # pylint: disable-next=protected-access + per_callback = callback._cb_info.get("mcp_expose_docstring") + if per_callback is not None: + return per_callback + return get_app().config.get("mcp_expose_docstrings", False) diff --git a/dash/mcp/primitives/tools/descriptions/description_outputs.py b/dash/mcp/primitives/tools/descriptions/description_outputs.py new file mode 100644 index 0000000000..b7bf55e81c --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/description_outputs.py @@ -0,0 +1,45 @@ +"""Output summary for tool descriptions.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..prop_roles import iter_prop_roles +from .base import ToolDescriptionSource + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + + +def _describe_output(comp_type: str | None, prop: str) -> str | None: + for role in iter_prop_roles(): + if role.description is not None and role.matches(comp_type, prop): + return role.description + return None + + +class OutputSummaryDescription(ToolDescriptionSource): + """Produce a short summary of what the callback outputs represent.""" + + @classmethod + def describe(cls, callback: CallbackAdapter) -> list[str]: + outputs = callback.outputs + if not outputs: + return ["Dash callback"] + + lines: list[str] = [] + for out in outputs: + comp_id = out["component_id"] + prop = out["property"] + description = _describe_output(out.get("component_type"), prop) + + if description is not None: + lines.append(f"- {comp_id}.{prop}: {description}") + else: + lines.append(f"- {comp_id}.{prop}") + + n = len(outputs) + if n == 1: + return [lines[0].lstrip("- ")] + header = f"Returns {n} output{'s' if n > 1 else ''}:" + return [header] + lines diff --git a/dash/mcp/primitives/tools/input_schemas/__init__.py b/dash/mcp/primitives/tools/input_schemas/__init__.py new file mode 100644 index 0000000000..e037c5793f --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/__init__.py @@ -0,0 +1,46 @@ +"""Input schema generation for MCP tool inputSchema fields. + +Each source is an ``InputSchemaSource`` subclass that can type +an input parameter. Sources are tried in priority order — first +non-None wins. +""" + +from __future__ import annotations + +from typing import Any + +from dash.mcp.types import MCPInput + +from .base import InputSchemaSource +from .schema_pattern_matching import PatternMatchingSchema +from .schema_callback_type_annotations import AnnotationSchema +from .schema_component_proptypes_overrides import OverrideSchema +from .schema_component_proptypes import ComponentPropSchema +from .input_descriptions import get_property_description + +_SOURCES: list[type[InputSchemaSource]] = [ + PatternMatchingSchema, + AnnotationSchema, + OverrideSchema, + ComponentPropSchema, +] + + +def get_input_schema(param: MCPInput) -> dict[str, Any]: + """Return the complete JSON Schema for a callback input parameter. + + Type sources provide ``type``/``enum`` (first non-None wins). + Description is assembled by ``input_descriptions``. + """ + schema: dict[str, Any] = {} + for source in _SOURCES: + result = source.get_schema(param) + if result is not None: + schema = result + break + + description = get_property_description(param) + if description: + schema = {**schema, "description": description} + + return schema diff --git a/dash/mcp/primitives/tools/input_schemas/base.py b/dash/mcp/primitives/tools/input_schemas/base.py new file mode 100644 index 0000000000..42fe2352b6 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/base.py @@ -0,0 +1,20 @@ +"""Base class for input schema sources.""" + +from __future__ import annotations + +from typing import Any + +from dash.mcp.types import MCPInput + + +class InputSchemaSource: + """A source of JSON Schema that can type an MCP tool input parameter. + + Subclasses implement ``get_schema`` to return a JSON Schema dict + for the parameter, or ``None`` if this source cannot determine the + type. Sources are tried in priority order — first non-None wins. + """ + + @classmethod + def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: + raise NotImplementedError diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py new file mode 100644 index 0000000000..ebba3b4af8 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py @@ -0,0 +1,32 @@ +"""Per-property description generation for MCP tool input parameters. + +Each source is an ``InputDescriptionSource`` subclass that can add +text to a parameter's description. All sources are accumulated. +""" + +from __future__ import annotations + +from dash.mcp.types import MCPInput + +from .base import InputDescriptionSource +from .description_component_props import ComponentPropsDescription +from .description_docstrings import DocstringPropDescription +from .description_html_labels import LabelDescription +from .description_pattern_matching import PatternMatchingDescription + +_SOURCES: list[type[InputDescriptionSource]] = [ + DocstringPropDescription, + LabelDescription, + ComponentPropsDescription, + PatternMatchingDescription, +] + + +def get_property_description(param: MCPInput) -> str | None: + """Build a complete description string for a callback input parameter.""" + lines: list[str] = [] + if not param.get("required", True): + lines.append("Input is optional.") + for source in _SOURCES: + lines.extend(source.describe(param)) + return "\n".join(lines) if lines else None diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/base.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/base.py new file mode 100644 index 0000000000..6bfd62da04 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/base.py @@ -0,0 +1,18 @@ +"""Base class for per-parameter description sources.""" + +from __future__ import annotations + +from dash.mcp.types import MCPInput + + +class InputDescriptionSource: + """A source of text that can describe an MCP tool input parameter. + + Subclasses implement ``describe`` to return strings that will be + added to the callback parameter's description. All sources + are accumulated — every source can add text to the overall description. + """ + + @classmethod + def describe(cls, param: MCPInput) -> list[str]: + raise NotImplementedError diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py new file mode 100644 index 0000000000..58b4b4627e --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py @@ -0,0 +1,87 @@ +"""Generic component property descriptions. + +Generate a description for each component prop that has a value (either set +directly in the layout or by an upstream callback). +""" + +from __future__ import annotations + +from typing import Any + +from dash import get_app +from dash.mcp.types import MCPInput + +from .base import InputDescriptionSource + +_MAX_VALUE_LENGTH = 200 + +_MCP_EXCLUDED_PROPS = {"id", "className", "style"} + +_PROP_TEMPLATES: dict[tuple[str | None, str], str] = { + ("Store", "storage_type"): ( + "storage_type: {value}. Describes how to store the value client-side" + "'memory' resets on page refresh. " + "'session' persists for the duration of this session. " + "'local' persists on disk until explicitly cleared." + ), +} + + +class ComponentPropsDescription(InputDescriptionSource): + """Describe component properties with their current values.""" + + @classmethod + def describe(cls, param: MCPInput) -> list[str]: + component = param.get("component") + if component is None: + return [] + + component_id = param["component_id"] + cbmap = get_app().mcp_callback_map + prop_lines: list[str] = [] + + for prop_name in getattr(component, "_prop_names", []): + if prop_name in _MCP_EXCLUDED_PROPS: + continue + + upstream = cbmap.find_by_output(f"{component_id}.{prop_name}") + if upstream is not None and not upstream.prevents_initial_call: + value = upstream.initial_output_value(f"{component_id}.{prop_name}") + else: + value = getattr(component, prop_name, None) + tool_name = upstream.tool_name if upstream is not None else None + + if value is None and tool_name is None: + continue + + component_type = param.get("component_type") + template = _PROP_TEMPLATES.get((component_type, prop_name)) + formatted_value = ( + _truncate_large_values(value, component_id, prop_name) + if value is not None + else None + ) + + if template and formatted_value is not None: + line = template.format(value=formatted_value) + elif formatted_value is not None: + line = f"{prop_name}: {formatted_value}" + else: + line = prop_name + + if tool_name: + line += f" (can be updated by tool: `{tool_name}`)" + + prop_lines.append(line) + + if not prop_lines: + return [] + return [f"Component properties for {component_id}:"] + prop_lines + + +def _truncate_large_values(value: Any, component_id: str, prop_name: str) -> str: + text = repr(value) + if len(text) > _MAX_VALUE_LENGTH: + hint = f"Use get_dash_component('{component_id}', '{prop_name}') for the full value" + return f"{text[:_MAX_VALUE_LENGTH]}... ({hint})" + return text diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py new file mode 100644 index 0000000000..23045625bf --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py @@ -0,0 +1,77 @@ +"""Extract property descriptions from component class docstrings. + +Dash component classes have structured docstrings generated by +``dash-generate-components`` in the format:: + + Keyword arguments: + + - prop_name (type_string; optional): + Description text that may span + multiple lines. + +This module parses that format and returns the first sentence of the +description for a given property. +""" + +from __future__ import annotations + +import re + +from dash.mcp.types import MCPInput + +from .base import InputDescriptionSource + +_PROP_RE = re.compile( + r"^[ ]*- (\w+) \([^)]+\):\s*\n((?:[ ]+.+\n)*)", + re.MULTILINE, +) + +_cache: dict[type, dict[str, str]] = {} + +_SENTENCE_END = re.compile(r"(?<=[.!?])\s") + + +class DocstringPropDescription(InputDescriptionSource): + """Extract property description from the component's docstring.""" + + @classmethod + def describe(cls, param: MCPInput) -> list[str]: + component = param.get("component") + if component is None: + return [] + desc = _get_prop_description(type(component), param["property"]) + return [desc] if desc else [] + + +def _get_prop_description(cls: type, prop: str) -> str | None: + props = _parse_docstring(cls) + return props.get(prop) + + +def _parse_docstring(cls: type) -> dict[str, str]: + if cls in _cache: + return _cache[cls] + + doc = getattr(cls, "__doc__", None) + if not doc: + _cache[cls] = {} + return _cache[cls] + + props: dict[str, str] = {} + for match in _PROP_RE.finditer(doc): + prop_name = match.group(1) + raw_desc = match.group(2) + lines = [line.strip() for line in raw_desc.strip().splitlines()] + desc = " ".join(lines) + if desc: + props[prop_name] = _first_sentence(desc) + + _cache[cls] = props + return props + + +def _first_sentence(text: str) -> str: + m = _SENTENCE_END.search(text) + if m: + return text[: m.start() + 1].rstrip() + return text diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py new file mode 100644 index 0000000000..111e1eaaf7 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py @@ -0,0 +1,28 @@ +"""Label-based property descriptions. + +Reads the label map from the ``CallbackAdapterCollection``, +which builds it from the layout using ``htmlFor`` and +containment associations. +""" + +from __future__ import annotations + +from dash import get_app +from dash.mcp.types import MCPInput + +from .base import InputDescriptionSource + + +class LabelDescription(InputDescriptionSource): + """Return the label text for this component, if any.""" + + @classmethod + def describe(cls, param: MCPInput) -> list[str]: + component_id = param.get("component_id") + if not component_id: + return [] + label_map = get_app().mcp_callback_map.component_label_map + texts = label_map.get(component_id, []) + if texts: + return [f"Labeled with: {'; '.join(texts)}"] + return [] diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py new file mode 100644 index 0000000000..d9a1a5a26a --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py @@ -0,0 +1,61 @@ +"""Description for pattern-matching callback inputs. + +Explains that the input corresponds to a pattern-matching callback +(ALL, MATCH, ALLSMALLER) and describes the expected format. +See: https://dash.plotly.com/pattern-matching-callbacks +""" + +from __future__ import annotations + +from dash._layout_utils import _WILDCARD_VALUES, parse_wildcard_id +from dash.mcp.types import MCPInput + +from .base import InputDescriptionSource + + +class PatternMatchingDescription(InputDescriptionSource): + """Describe pattern-matching behavior for wildcard inputs.""" + + @classmethod + def describe(cls, param: MCPInput) -> list[str]: + dep_id = parse_wildcard_id(param["component_id"]) + if dep_id is None: + return [] + + wildcard_key, wildcard_type = _find_wildcard(dep_id) + if wildcard_key is None or wildcard_type is None: + return [] + + non_wildcard = {k: v for k, v in dep_id.items() if k != wildcard_key} + pattern_desc = ", ".join(f'{k}="{v}"' for k, v in non_wildcard.items()) + prop = param["property"] + + wildcard_descriptions = { + "ALL": ( + f"Pattern-matching input (ALL): provide an array of `{prop}` values, " + f"one per component matching {{{pattern_desc}}}. " + f"All matching components are included." + ), + "MATCH": ( + f"Pattern-matching input (MATCH): provide the `{prop}` value " + f"for the specific component matching {{{pattern_desc}}} " + f"that triggered this callback." + ), + "ALLSMALLER": ( + f"Pattern-matching input (ALLSMALLER): provide an array of `{prop}` values " + f"from components matching {{{pattern_desc}}} " + f"whose `{wildcard_key}` is smaller than the triggering component's `{wildcard_key}`." + ), + } + + desc = wildcard_descriptions.get(wildcard_type) + return [desc] if desc else [] + + +def _find_wildcard(dep_id: dict) -> tuple[str | None, str | None]: + """Return (key, wildcard_type) for the first wildcard found.""" + for key, value in dep_id.items(): + if isinstance(value, list) and len(value) == 1: + if value[0] in _WILDCARD_VALUES: + return key, value[0] + return None, None diff --git a/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py b/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py new file mode 100644 index 0000000000..b862b124d6 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py @@ -0,0 +1,65 @@ +"""Map callback function type annotations to JSON Schema. + +When a callback function has explicit type annotations, those take +priority over all other schema sources (static overrides, component +introspection). + +Unlike component annotations (where nullable means "not required"), +callback annotations preserve ``null`` in the schema type when the +user writes ``Optional[X]`` — the user is explicitly saying the +value can be null. + +Also provides ``annotation_to_json_schema``, the shared low-level +converter used by both callback and component annotation pipelines. +""" + +from __future__ import annotations + +import inspect +from typing import Any + +from pydantic import TypeAdapter + +from dash.development.base_component import Component +from dash.mcp.types import MCPInput, is_nullable + +from .base import InputSchemaSource + + +def annotation_to_json_schema(annotation: type) -> dict[str, Any] | None: + """Convert a Python type annotation to a JSON Schema dict. + + Returns ``None`` if the annotation cannot be translated. + """ + if annotation is inspect.Parameter.empty or annotation is type(None): + return None + + if isinstance(annotation, type) and issubclass(annotation, Component): + return {"type": "string"} + + try: + return TypeAdapter(annotation).json_schema() + except Exception: # pylint: disable=broad-exception-caught + return None + + +class AnnotationSchema(InputSchemaSource): + """Derive JSON Schema from the callback parameter's type annotation.""" + + @classmethod + def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: + annotation = param.get("annotation") + if annotation is None: + return None + schema = annotation_to_json_schema(annotation) + if schema is None: + return None + + if is_nullable(annotation) and schema: + t = schema.get("type") + if isinstance(t, str): + schema = {**schema, "type": [t, "null"]} + elif isinstance(t, list) and "null" not in t: + schema = {**schema, "type": [*t, "null"]} + + return schema diff --git a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py new file mode 100644 index 0000000000..d7f72d81ff --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py @@ -0,0 +1,37 @@ +"""Derive JSON Schema from a component's ``__init__`` type annotations.""" + +from __future__ import annotations + +import inspect +from typing import Any + +from dash.mcp.types import MCPInput + +from .base import InputSchemaSource +from .schema_callback_type_annotations import annotation_to_json_schema + + +class ComponentPropSchema(InputSchemaSource): + """Derive JSON Schema from a component's ``__init__`` type annotations. + + Inspects the ``__init__`` signature of the component's class. + Returns ``None`` if the prop has no annotation. + """ + + @classmethod + def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: + component = param.get("component") + prop = param["property"] + if component is None: + return None + + try: + sig = inspect.signature(type(component).__init__) + except (ValueError, TypeError): + return None + + sig_param = sig.parameters.get(prop) + if sig_param is None or sig_param.annotation is inspect.Parameter.empty: + return None + + return annotation_to_json_schema(sig_param.annotation) diff --git a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py new file mode 100644 index 0000000000..e3d5b65756 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py @@ -0,0 +1,31 @@ +"""Input schema overrides drawn from the ``PropRole`` registry. + +Looks up the parameter's ``(component_type, property)`` in the shared +registry and returns any attached ``input_schema``. Used when default +type introspection produces insufficient results. +""" + +from __future__ import annotations + +from typing import Any + +from dash.mcp.types import MCPInput + +from ..prop_roles import iter_prop_roles +from .base import InputSchemaSource + + +class OverrideSchema(InputSchemaSource): + """Return a schema override, or None to fall through to introspection.""" + + @classmethod + def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: + component_type = param.get("component_type") + prop = param["property"] + for role in iter_prop_roles(): + if role.input_schema is None or not role.matches(component_type, prop): + continue + if callable(role.input_schema): + return role.input_schema(param) + return dict(role.input_schema) + return None diff --git a/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py new file mode 100644 index 0000000000..093dc197b8 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py @@ -0,0 +1,87 @@ +"""Schema for pattern-matching callback inputs (ALL, MATCH, ALLSMALLER). + +When a callback input uses a wildcard ID, the callback receives a +list of values — one per matching component. This source detects +wildcard IDs and produces an array schema. If matching components +exist in the layout, the item type is inferred from a concrete match. +""" + +from __future__ import annotations + +from typing import Any + +from dash._layout_utils import ( + _WILDCARD_VALUES, + find_matching_components, + parse_wildcard_id, +) +from dash.mcp.types import MCPInput + +from .base import InputSchemaSource + + +class PatternMatchingSchema(InputSchemaSource): + """Return a schema for pattern-matching inputs. + + For ALL/ALLSMALLER: array of ``{id, property, value}`` objects. + For MATCH: a single ``{id, property, value}`` object. + """ + + @classmethod + def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: + dep_id = parse_wildcard_id(param["component_id"]) + if dep_id is None: + return None + + wildcard_type = _get_wildcard_type(dep_id) + if wildcard_type is None: + return None + + value_schema = _infer_value_schema(param) + + item_schema: dict[str, Any] = { + "type": "object", + "properties": { + "id": {"type": "object"}, + "property": {"type": "string"}, + "value": value_schema or {}, + }, + "required": ["id", "property", "value"], + } + + if wildcard_type == "MATCH": + return item_schema + + return {"type": "array", "items": item_schema} + + +def _get_wildcard_type(dep_id: dict) -> str | None: + """Return the wildcard type (ALL, MATCH, ALLSMALLER) or None.""" + for value in dep_id.values(): + if isinstance(value, list) and len(value) == 1: + if value[0] in _WILDCARD_VALUES: + return value[0] + return None + + +def _infer_value_schema(param: MCPInput) -> dict[str, Any] | None: + """Infer the JSON Schema for the ``value`` field from a matching component.""" + pattern = parse_wildcard_id(param["component_id"]) + if pattern is None: + return None + matches = find_matching_components(pattern) + if not matches: + return None + + # pylint: disable-next=cyclic-import,import-outside-toplevel + from . import get_input_schema + + concrete_param: MCPInput = { + **param, + "component": matches[0], + "component_id": str(getattr(matches[0], "id", "")), + "component_type": getattr(matches[0], "_type", None), + } + schema = get_input_schema(concrete_param) + schema.pop("description", None) + return schema or None diff --git a/dash/mcp/primitives/tools/output_schemas/__init__.py b/dash/mcp/primitives/tools/output_schemas/__init__.py new file mode 100644 index 0000000000..41ddfd8d49 --- /dev/null +++ b/dash/mcp/primitives/tools/output_schemas/__init__.py @@ -0,0 +1,29 @@ +"""Output schema generation for MCP tool outputSchema fields. + +Mirrors ``input_schemas/`` which generates ``inputSchema``. + +Each source shares the same signature: ``() -> dict | None``. +""" + +from __future__ import annotations + +from typing import Any + +from .schema_callback_response import callback_response_schema + +_SOURCES = [ + callback_response_schema, +] + + +def get_output_schema() -> dict[str, Any]: + """Return the JSON Schema for a callback tool's output. + + Tries each source in order, returning the first non-None result. + Falls back to ``{}`` (any type). + """ + for source in _SOURCES: + schema = source() + if schema is not None: + return schema + return {} diff --git a/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py b/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py new file mode 100644 index 0000000000..6962fb4a4f --- /dev/null +++ b/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py @@ -0,0 +1,16 @@ +"""Output schema derived from CallbackExecutionResponse.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import TypeAdapter + +from dash.types import CallbackExecutionResponse + +_schema = TypeAdapter(CallbackExecutionResponse).json_schema() + + +def callback_response_schema() -> dict[str, Any]: + """Return the JSON Schema for a callback dispatch response.""" + return _schema diff --git a/dash/mcp/primitives/tools/prop_roles.py b/dash/mcp/primitives/tools/prop_roles.py new file mode 100644 index 0000000000..64fdc8f76d --- /dev/null +++ b/dash/mcp/primitives/tools/prop_roles.py @@ -0,0 +1,151 @@ +"""Canonical registry of semantic roles for Dash component props. + +A ``PropRole`` bundles the set of ``(component_type, property)`` pairs +that play the same role with the metadata attached to that role: +an LLM-facing description, an input JSON Schema, etc. Tool descriptions, +input-schema overrides, and result formatters all consume this registry +so they can't drift. + +Use ``ANY_COMPONENT`` as the component_type sentinel to match any component with +the given property name. + +Declaration order matters: ``iter_prop_roles()`` yields roles in the +order they're defined in this module, and the first match wins. List +concrete-match roles before wildcard-match roles that share a prop +name (e.g. ``MARKDOWN`` before ``CONTENT`` for ``children``). +""" + +from __future__ import annotations + +from typing import Any, Callable, Dict, Iterator, NamedTuple, Union + +from typing_extensions import TypeAlias + +from dash.mcp.types import MCPInput + +PropSchema = Union[ + Dict[str, Any], + Callable[[MCPInput], Dict[str, Any]], +] + +COMPONENT: TypeAlias = Union[str, None] +ANY_COMPONENT: None = None +PROP: TypeAlias = str + + +class PropRole(NamedTuple): + identifiers: set[tuple[COMPONENT, PROP]] + description: str | None = None + input_schema: PropSchema | None = None + + def matches(self, component_type: COMPONENT, prop: PROP) -> bool: + """True if this role applies to the given ``(component_type, prop)``. + + Matches either a concrete entry or an ``ANY_COMPONENT`` wildcard + entry in ``identifiers``. Shared by every consumer so all metadata + fields apply uniformly to every identifier in the role. + """ + return (component_type, prop) in self.identifiers or ( + ANY_COMPONENT, + prop, + ) in self.identifiers + + +def _compute_dropdown_value_schema(param: MCPInput) -> dict[str, Any]: + """Dropdown values are an array if ``multi=True``; scalar otherwise.""" + _DROPDOWN_SCALAR_TYPE = { + "anyOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}] + } + component = param.get("component") + if getattr(component, "multi", False): + return {"type": "array", "items": _DROPDOWN_SCALAR_TYPE} + return _DROPDOWN_SCALAR_TYPE + + +TABULAR = PropRole( + identifiers={("DataTable", "data"), ("AgGrid", "rowData")}, + description="Returns tabular data", +) + +DATE = PropRole( + identifiers={ + ("DatePickerSingle", "date"), + ("DatePickerRange", "start_date"), + ("DatePickerRange", "end_date"), + }, + input_schema={ + "type": "string", + "format": "date", + "pattern": r"^\d{4}-\d{2}-\d{2}$", + }, +) + +DROPDOWN_VALUE = PropRole( + identifiers={("Dropdown", "value")}, + input_schema=_compute_dropdown_value_schema, +) + +STORE_DATA = PropRole( + identifiers={("Store", "data")}, + description="Returns data to be remembered client-side", +) + +DOWNLOAD = PropRole( + identifiers={("Download", "data")}, + description="Returns downloadable content", +) + +MARKDOWN = PropRole( + identifiers={("Markdown", "children")}, + description="Returns formatted text", +) + +PLOTLY_FIGURE = PropRole( + identifiers={(ANY_COMPONENT, "figure")}, + description="Returns chart/visualization data", + input_schema={ + "type": "object", + "properties": { + "data": {"type": "array", "items": {"type": "object"}}, + "layout": {"type": "object"}, + "frames": {"type": "array", "items": {"type": "object"}}, + }, + }, +) + +GENERIC_CONTENT = PropRole( + identifiers={(ANY_COMPONENT, "children")}, + description="Returns content", +) + +GENERIC_VALUE = PropRole( + identifiers={(ANY_COMPONENT, "value")}, + description="Returns the current value", +) + +GENERIC_OPTIONS = PropRole( + identifiers={(ANY_COMPONENT, "options")}, + description="Returns available options", +) + +GENERIC_COLUMNS = PropRole( + identifiers={(ANY_COMPONENT, "columns")}, + description="Returns column definitions", +) + +GENERIC_STYLE = PropRole( + identifiers={(ANY_COMPONENT, "style")}, + description="Updates styling", +) + +GENERIC_DISABLED = PropRole( + identifiers={(ANY_COMPONENT, "disabled")}, + description="Updates enabled/disabled state", +) + + +def iter_prop_roles() -> Iterator[PropRole]: + """Yield every PropRole defined in this module in declaration order.""" + for value in globals().values(): + if isinstance(value, PropRole): + yield value diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py new file mode 100644 index 0000000000..843ddbb4f2 --- /dev/null +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -0,0 +1,84 @@ +"""Tool result formatting for MCP tools/call responses. + +Each formatter is a ``ResultFormatter`` subclass that can enrich +a tool result with additional content. All formatters are accumulated. +""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from mcp.types import CallToolResult, CreateTaskResult, TextContent + +from dash.types import CallbackExecutionResponse + +from .base import ResultFormatter +from .result_dataframe import DataFrameResult +from .result_plotly_figure import PlotlyFigureResult + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + +_RESULT_FORMATTERS: list[type[ResultFormatter]] = [ + PlotlyFigureResult, + DataFrameResult, +] + + +def format_callback_response( + response: CallbackExecutionResponse, + callback: CallbackAdapter, +) -> CallToolResult: + """Format a callback response as a CallToolResult. + + The response is always returned as structuredContent. Result + formatters are called per output property and may add additional + content items (images, markdown, etc.). + """ + content: list[Any] = [ + TextContent(type="text", text=json.dumps(response, default=str)), + ] + + resp = response.get("response") or {} + for callback_output in callback.outputs: + value = resp.get(callback_output["component_id"], {}).get( + callback_output["property"] + ) + for formatter in _RESULT_FORMATTERS: + content.extend(formatter.format(callback_output, value)) + + return CallToolResult( + content=content, + structuredContent=dict(response), + ) + + +def task_result_to_tool_result(create_task_result: CreateTaskResult) -> CallToolResult: + """ + Wrap a CreateTaskResult as a CallToolResult with polling instructions. + + MCP Tasks are not yet supported by LLM clients, so this converts the + task metadata into a tool response that guides the LLM to poll via + the get_background_task_result tool. + """ + task = create_task_result.task + return CallToolResult( + content=[ + TextContent( + type="text", + text=json.dumps( + { + "taskId": task.taskId, + "status": task.status, + "pollInterval": task.pollInterval, + "message": ( + "This is a long-running background callback. " + "Call the get_background_task_result tool with this taskId " + "to poll for the result." + ), + } + ), + ) + ], + ) diff --git a/dash/mcp/primitives/tools/results/base.py b/dash/mcp/primitives/tools/results/base.py new file mode 100644 index 0000000000..1f7714ff6b --- /dev/null +++ b/dash/mcp/primitives/tools/results/base.py @@ -0,0 +1,24 @@ +"""Base class for result formatters.""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import ImageContent, TextContent + +from dash.mcp.types import MCPOutput + + +class ResultFormatter: + """A formatter that can enrich an MCP tool result with additional content. + + Subclasses implement ``format`` to return content items (text, images) + for a specific callback output. All formatters are accumulated — every + formatter can add content to the overall tool result. + """ + + @classmethod + def format( + cls, output: MCPOutput, returned_output_value: Any + ) -> list[TextContent | ImageContent]: + raise NotImplementedError diff --git a/dash/mcp/primitives/tools/results/result_dataframe.py b/dash/mcp/primitives/tools/results/result_dataframe.py new file mode 100644 index 0000000000..780bb264ef --- /dev/null +++ b/dash/mcp/primitives/tools/results/result_dataframe.py @@ -0,0 +1,61 @@ +"""Tabular data result: render as a markdown table. + +Detects tabular output by component type and prop name: +- DataTable.data +- AgGrid.rowData +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import TextContent + +from dash.mcp.types import MCPOutput + +from ..prop_roles import TABULAR +from .base import ResultFormatter + +MAX_ROWS = 50 + + +def _to_markdown_table(rows: list[dict], max_rows: int = MAX_ROWS) -> str: + """Render a list of row dicts as a markdown table.""" + columns = list(rows[0].keys()) + total_rows = len(rows) + + lines: list[str] = [] + lines.append(f"*{total_rows} rows \u00d7 {len(columns)} columns*") + lines.append("") + lines.append(" | ".join(columns)) + lines.append(" | ".join("---" for _ in columns)) + + for row in rows[:max_rows]: + cells = [ + str(row.get(col, "")).replace("|", "\\|").replace("\n", " ") + for col in columns + ] + lines.append(" | ".join(cells)) + + if total_rows > max_rows: + lines.append(f"\n(\u2026 {total_rows - max_rows} more rows)") + + return "\n".join(lines) + + +class DataFrameResult(ResultFormatter): + """Produce a markdown table for tabular component output values.""" + + @classmethod + def format(cls, output: MCPOutput, returned_output_value: Any) -> list[TextContent]: # type: ignore[override] + if not TABULAR.matches(output.get("component_type"), output["property"]): + return [] + if ( + not returned_output_value + or not isinstance(returned_output_value, list) + or not isinstance(returned_output_value[0], dict) + ): + return [] + return [ + TextContent(type="text", text=_to_markdown_table(returned_output_value)) + ] diff --git a/dash/mcp/primitives/tools/results/result_plotly_figure.py b/dash/mcp/primitives/tools/results/result_plotly_figure.py new file mode 100644 index 0000000000..b7f4273933 --- /dev/null +++ b/dash/mcp/primitives/tools/results/result_plotly_figure.py @@ -0,0 +1,57 @@ +"""Plotly figure tool result: rendered image.""" + +from __future__ import annotations + +import base64 +import logging +from typing import Any + +import plotly.graph_objects as go # type: ignore[import-untyped] +from mcp.types import ImageContent, TextContent + +from dash.mcp.types import MCPOutput + +from ..prop_roles import PLOTLY_FIGURE +from .base import ResultFormatter + +logger = logging.getLogger(__name__) + +IMAGE_WIDTH = 700 +IMAGE_HEIGHT = 450 + + +def _render_image(figure: Any) -> ImageContent | None: + """ + Render the figure as a base64 PNG ImageContent. + + Returns None if kaleido is not installed. + """ + try: + img_bytes = figure.to_image( + format="png", + width=IMAGE_WIDTH, + height=IMAGE_HEIGHT, + ) + except (ValueError, ImportError): + logger.debug("MCP: kaleido not available, skipping image render") + return None + + b64 = base64.b64encode(img_bytes).decode("ascii") + return ImageContent(type="image", data=b64, mimeType="image/png") + + +class PlotlyFigureResult(ResultFormatter): + """Produce a rendered PNG for Graph.figure output values.""" + + @classmethod + def format( + cls, output: MCPOutput, returned_output_value: Any + ) -> list[TextContent | ImageContent]: + if not PLOTLY_FIGURE.matches(output.get("component_type"), output["property"]): + return [] + if not returned_output_value or not isinstance(returned_output_value, dict): + return [] + + fig = go.Figure(returned_output_value) + image = _render_image(fig) + return [image] if image is not None else [] diff --git a/dash/mcp/primitives/tools/tool_background_tasks.py b/dash/mcp/primitives/tools/tool_background_tasks.py new file mode 100644 index 0000000000..a4dffa23f4 --- /dev/null +++ b/dash/mcp/primitives/tools/tool_background_tasks.py @@ -0,0 +1,102 @@ +"""Built-in tools for background callback task lifecycle. + +Thin wrappers around the spec-aligned core in dash.mcp.tasks. +Only registered when the app has background callbacks. +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import CallToolResult, TextContent, Tool + +from dash import get_app +from dash.mcp.tasks import get_task, get_task_result, cancel_task + +from .base import MCPToolProvider + + +GET_RESULT_TOOL_NAME = "get_background_task_result" +CANCEL_TOOL_NAME = "cancel_background_task" + + +def _has_background_callbacks() -> bool: + return any(cb_info.get("background") for cb_info in get_app().callback_map.values()) + + +class BackgroundTaskTools(MCPToolProvider): + """Built-in tools for polling and cancelling background callback tasks. + + Only registered when the app has background callbacks. + """ + + @classmethod + def get_tool_names(cls) -> set[str]: + if not _has_background_callbacks(): + return set() + return {GET_RESULT_TOOL_NAME, CANCEL_TOOL_NAME} + + @classmethod + def list_tools(cls) -> list[Tool]: + if not _has_background_callbacks(): + return [] + return [ + Tool( + name=GET_RESULT_TOOL_NAME, + description=( + "Poll for the result of a long-running background callback. " + "Pass the taskId returned by the original tool call. " + "If the task is still running, call this tool again. " + "If complete, returns the callback result." + ), + inputSchema={ + "type": "object", + "properties": { + "taskId": { + "type": "string", + "description": "The taskId returned by the background callback tool.", + }, + }, + "required": ["taskId"], + }, + ), + Tool( + name=CANCEL_TOOL_NAME, + description="Cancel a running background callback.", + inputSchema={ + "type": "object", + "properties": { + "taskId": { + "type": "string", + "description": "The taskId of the background task to cancel.", + }, + }, + "required": ["taskId"], + }, + ), + ] + + @classmethod + def call_tool( + cls, + tool_name: str, + arguments: dict[str, Any], + task: dict | None = None, + ) -> CallToolResult: + task_id = arguments.get("taskId", "") + + if tool_name == GET_RESULT_TOOL_NAME: + task_status = get_task(task_id) + if task_status.status == "completed": + return get_task_result(task_id) + return CallToolResult( + content=[TextContent(type="text", text=task_status.model_dump_json())], + ) + + if tool_name == CANCEL_TOOL_NAME: + result = cancel_task(task_id) + return CallToolResult( + content=[TextContent(type="text", text=result.model_dump_json())], + ) + + raise ValueError(f"Unknown tool: {tool_name}") diff --git a/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py b/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py new file mode 100644 index 0000000000..0b3edbbcbe --- /dev/null +++ b/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py @@ -0,0 +1,150 @@ +"""MCP tools backed by ``@mcp_enabled``-decorated functions.""" + +from __future__ import annotations + +import inspect +import json +import typing +from typing import Any, Callable + +from mcp.types import CallToolResult, TextContent, Tool + +from dash import get_app +from dash.mcp._decorator import MCPToolRegistration +from dash.mcp.primitives.tools.input_schemas import get_input_schema +from dash.mcp.primitives.tools.input_schemas.schema_callback_type_annotations import ( + annotation_to_json_schema, +) +from dash.mcp.types import MCPInput, is_nullable + +from .base import MCPToolProvider + + +def _build_inputs(fn: Callable[..., Any]) -> list[MCPInput]: + """Build an ``MCPInput`` from each of the function's arguments.""" + try: + hints = typing.get_type_hints(fn) + except Exception: # pylint: disable=broad-exception-caught + hints = getattr(fn, "__annotations__", {}) + + sig = inspect.signature(fn) + inputs: list[MCPInput] = [] + + for name, param in sig.parameters.items(): + annotation = hints.get(name) + + has_default = param.default is not inspect.Parameter.empty + required = not has_default and ( + annotation is None or not is_nullable(annotation) + ) + + inputs.append( + MCPInput( + name=name, + id_and_prop="", + component_id="", + property="", + annotation=annotation, + component_type=None, + component=None, + required=required, + initial_value=param.default if has_default else None, + upstream_output=None, + ) + ) + return inputs + + +def _build_output_schema(fn: Callable[..., Any]) -> dict[str, Any]: + """Build a JSON Schema ``outputSchema`` from the return annotation. + + The schema wraps the return type in ``{"result": }`` to match + the object that ``call_tool`` returns as ``structuredContent``. + """ + try: + hints = typing.get_type_hints(fn) + except Exception: # pylint: disable=broad-exception-caught + hints = getattr(fn, "__annotations__", {}) + + ret = hints.get("return") + if ret is None: + return {} + + inner = annotation_to_json_schema(ret) + if inner is None: + return {} + + return { + "type": "object", + "properties": {"result": inner}, + "required": ["result"], + } + + +def _build_tool(tool_name: str, reg: MCPToolRegistration) -> Tool: + fn = reg["fn"] + inputs = _build_inputs(fn) + properties = {p["name"]: get_input_schema(p) for p in inputs} + required = [p["name"] for p in inputs if p["required"]] + + input_schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + input_schema["required"] = required + + expose_docstring = reg["expose_docstring"] + if expose_docstring is None: + expose_docstring = get_app().config.get("mcp_expose_docstrings", False) + + description = "MCP tool" + if expose_docstring: + docstring = getattr(fn, "__doc__", None) + if docstring: + description = docstring.strip() + + return Tool( + name=tool_name, + description=description, + inputSchema=input_schema, + outputSchema=_build_output_schema(fn), + ) + + +class DecoratedFunctionTools(MCPToolProvider): + """Exposes ``@mcp_enabled``-decorated functions as MCP tools.""" + + @classmethod + def _registry(cls) -> dict[str, MCPToolRegistration]: + return get_app().mcp_decorated_functions + + @classmethod + def get_tool_names(cls) -> set[str]: + return set(cls._registry().keys()) + + @classmethod + def list_tools(cls) -> list[Tool]: + return [_build_tool(name, reg) for name, reg in cls._registry().items()] + + @classmethod + def call_tool( + cls, tool_name: str, arguments: dict[str, Any], task: dict | None = None + ) -> CallToolResult: + reg = cls._registry().get(tool_name) + if reg is None: + return CallToolResult( + content=[TextContent(type="text", text=f"Tool not found: {tool_name}")], + isError=True, + ) + fn = reg["fn"] + try: + result = fn(**arguments) + except Exception as exc: # pylint: disable=broad-exception-caught + return CallToolResult( + content=[TextContent(type="text", text=f"{type(exc).__name__}: {exc}")], + isError=True, + ) + + serialized = json.dumps(result, default=str) + return CallToolResult( + content=[TextContent(type="text", text=serialized)], + structuredContent={"result": result}, + ) diff --git a/dash/mcp/primitives/tools/tool_get_dash_component.py b/dash/mcp/primitives/tools/tool_get_dash_component.py new file mode 100644 index 0000000000..d8ba9b6e49 --- /dev/null +++ b/dash/mcp/primitives/tools/tool_get_dash_component.py @@ -0,0 +1,134 @@ +"""Built-in tool: get_dash_component.""" + +from __future__ import annotations + +import json +from typing import Any + +from mcp.types import CallToolResult, TextContent, Tool +from pydantic import Field, TypeAdapter +from typing_extensions import Annotated, NotRequired, TypedDict + +from dash import get_app +from dash._layout_utils import find_component +from dash.mcp.types import ComponentPropertyInfo, ComponentQueryResult + +from .base import MCPToolProvider + + +class _ComponentQueryInput(TypedDict): + component_id: Annotated[str, Field(description="The component ID to query")] + property: NotRequired[ + Annotated[ + str, + Field( + description="The property name to read (e.g. 'options', 'value'). Omit to list all defined properties." + ), + ] + ] + + +_INPUT_SCHEMA = TypeAdapter(_ComponentQueryInput).json_schema() +_OUTPUT_SCHEMA = TypeAdapter(ComponentQueryResult).json_schema() + +NAME = "get_dash_component" + + +class GetDashComponentTool(MCPToolProvider): + """Inspects a component's properties and its tool relationships.""" + + @classmethod + def get_tool_names(cls) -> set[str]: + return {NAME} + + @classmethod + def list_tools(cls) -> list[Tool]: + return [ + Tool( + name=NAME, + description=( + "Get a component's properties, values, and tool relationships. " + "If property is omitted, returns all defined properties. " + "If property is specified, returns only that property. " + "See the dash://components resource for available component IDs." + ), + inputSchema=_INPUT_SCHEMA, + outputSchema=_OUTPUT_SCHEMA, + ) + ] + + @classmethod + def call_tool( + cls, + tool_name: str, + arguments: dict[str, Any], + task: dict | None = None, + ) -> CallToolResult: + comp_id = arguments.get("component_id", "") + if not comp_id: + return CallToolResult( + content=[TextContent(type="text", text="component_id is required")], + isError=True, + ) + + prop_filter = arguments.get("property", "") + component = find_component(comp_id) + callback_map = get_app().mcp_callback_map + + if component is None: + rendering_tools = [ + cb.tool_name + for cb in callback_map + if any(out["component_id"] == comp_id for out in cb.outputs) + ] + msg = f"Component '{comp_id}' not found in static layout." + if rendering_tools: + msg += ( + f" However, the following tools would modify it: {rendering_tools}." + ) + msg += " Use the dash://components resource to see statically available component IDs." + return CallToolResult( + content=[TextContent(type="text", text=msg)], + isError=True, + ) + + properties: dict[str, ComponentPropertyInfo] = {} + for prop_name in getattr(component, "_prop_names", []): + if prop_filter and prop_name != prop_filter: + continue + + id_and_prop = f"{comp_id}.{prop_name}" + value = callback_map.get_initial_value(id_and_prop) + if value is None: + value = getattr(component, prop_name, None) + if value is None: + continue + + modified_by = [ + cb.tool_name for cb in callback_map.outputs_by_prop.get(id_and_prop, []) + ] + input_to = [ + cb.tool_name for cb in callback_map.inputs_by_prop.get(id_and_prop, []) + ] + + properties[prop_name] = ComponentPropertyInfo( + initial_value=value, + modified_by_tool=modified_by, + input_to_tool=input_to, + ) + + labels = callback_map.component_label_map.get(comp_id, []) + + structured: ComponentQueryResult = ComponentQueryResult( + component_id=comp_id, + component_type=type(component).__name__, + label=labels if labels else None, + properties=properties, + ) + + return CallToolResult( + content=[ + TextContent(type="text", text=json.dumps(structured, default=str)) + ], + structuredContent=dict(structured), + ) diff --git a/dash/mcp/primitives/tools/tools_callbacks.py b/dash/mcp/primitives/tools/tools_callbacks.py new file mode 100644 index 0000000000..2a2d866ea9 --- /dev/null +++ b/dash/mcp/primitives/tools/tools_callbacks.py @@ -0,0 +1,67 @@ +"""Dynamic callback tools for MCP. + +Exposes every server-callable callback as an MCP tool. +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import CallToolResult, CreateTaskResult, TextContent, Tool + +from dash import get_app +from dash.mcp.tasks import create_task +from dash.mcp.types import CallbackExecutionError, ToolNotFoundError + +from .base import MCPToolProvider +from .callback_utils import run_callback +from .results import format_callback_response, task_result_to_tool_result + + +class CallbackTools(MCPToolProvider): + """Exposes every server-callable callback as an MCP tool.""" + + @classmethod + def get_tool_names(cls) -> set[str]: + return get_app().mcp_callback_map.tool_names + + @classmethod + def list_tools(cls) -> list[Tool]: + """Return one Tool per server-callable callback.""" + return get_app().mcp_callback_map.as_mcp_tools() + + @classmethod + def call_tool( + cls, + tool_name: str, + arguments: dict[str, Any], + task: dict | None = None, + ) -> CallToolResult | CreateTaskResult: + """Execute a callback tool by name.""" + callback_map = get_app().mcp_callback_map + cb = callback_map.find_by_tool_name(tool_name) + if cb is None: + raise ToolNotFoundError( + f"Tool not found: {tool_name}." + " The app's callbacks may have changed." + " Please call tools/list to refresh your tool list." + ) + + # pylint: disable-next=protected-access + is_background = bool(cb._cb_info.get("background")) + + try: + callback_response = run_callback(cb, arguments) + except CallbackExecutionError as e: + return CallToolResult( + content=[TextContent(type="text", text=str(e))], + isError=True, + ) + + if is_background: + task_result = create_task(dict(callback_response), cb) + if task is not None: + return task_result + return task_result_to_tool_result(task_result) + + return format_callback_response(callback_response, cb) diff --git a/dash/mcp/tasks/__init__.py b/dash/mcp/tasks/__init__.py new file mode 100644 index 0000000000..8b78741d60 --- /dev/null +++ b/dash/mcp/tasks/__init__.py @@ -0,0 +1,5 @@ +"""MCP Tasks — lifecycle management for background callback execution.""" + +from .tasks import create_task, get_task, get_task_result, cancel_task + +__all__ = ["create_task", "get_task", "get_task_result", "cancel_task"] diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py new file mode 100644 index 0000000000..5989dfde55 --- /dev/null +++ b/dash/mcp/tasks/tasks.py @@ -0,0 +1,193 @@ +"""Handler functions for MCP tasks/* methods.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any, Literal + +from mcp.types import CancelTaskResult, CreateTaskResult, GetTaskResult, Task + +from dash import get_app +from dash._callback import _update_background_callback, _prepare_response +from dash._utils import AttributeDict, split_callback_id +from dash.development.base_component import ComponentRegistry +from dash.mcp.primitives.tools.results import format_callback_response +from dash.mcp.types import MCPError +from dash.types import CallbackExecutionResponse + + +def parse_task_id(task_id: str) -> tuple[str, str, str, datetime]: + """Parse a taskId into (tool_name, job_id, cache_key, created_at).""" + try: + tool_name, job_id, rest = task_id.split(":", 2) + cache_key, created_epoch = rest.rsplit(":", 1) + created_at = datetime.fromtimestamp(int(created_epoch), tz=timezone.utc) + except (ValueError, TypeError) as exc: + raise MCPError(f"Malformed taskId: {task_id!r}") from exc + return tool_name, job_id, cache_key, created_at + + +def _get_callback_manager(tool_name: str): + """Resolve the background callback manager for a specific MCP tool.""" + adapter = get_app().mcp_callback_map.find_by_tool_name(tool_name) + if adapter is None: + return None + # pylint: disable-next=protected-access + return adapter._cb_info.get("manager") + + +def create_task(dispatch_response: dict[str, Any], callback) -> CreateTaskResult: + """Create a Task from a background callback's initial dispatch response.""" + cache_key = dispatch_response["cacheKey"] + job_id = str(dispatch_response["job"]) + now = datetime.now(timezone.utc) + # taskId encodes creation time so subsequent polls return a stable + # createdAt without server-side storage (works across gunicorn workers). + task_id = f"{callback.tool_name}:{job_id}:{cache_key}:{int(now.timestamp())}" + # pylint: disable-next=protected-access + interval = callback._cb_info.get("background", {}).get("interval", 1000) + return CreateTaskResult( + task=Task( + taskId=task_id, + status="working", + createdAt=now, + lastUpdatedAt=now, + ttl=None, + pollInterval=interval, + ), + ) + + +def get_task(task_id: str) -> GetTaskResult: + """Handle tasks/get — derive status from the callback manager.""" + tool_name, job_id, cache_key, created_at = parse_task_id(task_id) + + manager = _get_callback_manager(tool_name) + if manager is None: + return GetTaskResult( + taskId=task_id, + status="failed", + statusMessage="No background callback manager configured.", + createdAt=created_at, + lastUpdatedAt=datetime.now(timezone.utc), + ttl=None, + ) + + running = manager.job_running(job_id) + progress = manager.get_progress(cache_key) + + # Check result_ready before job_running: the process may store its result + # while still technically alive (teardown race), so result_ready wins. + status: Literal["working", "completed", "failed"] + if manager.result_ready(cache_key): + status = "completed" + elif running: + status = "working" + else: + status = "failed" + + adapter = get_app().mcp_callback_map.find_by_tool_name(tool_name) + interval = None + if adapter is not None: + # pylint: disable-next=protected-access + interval = adapter._cb_info.get("background", {}).get("interval", 1000) + + return GetTaskResult( + taskId=task_id, + status=status, + statusMessage=str(progress) if progress else None, + createdAt=created_at, + lastUpdatedAt=datetime.now(timezone.utc), + ttl=manager.expire * 1000 if manager.expire else None, + pollInterval=interval, + ) + + +def get_task_result(task_id: str) -> Any: + """Handle tasks/result — retrieve and format the callback result. + + Uses the framework's background callback retrieval functions + (``_update_background_callback`` and ``_prepare_response``) with + ``cache_key`` and ``job_id`` supplied directly, bypassing the + request adapter query-param lookup. + """ + tool_name, job_id, cache_key, _created_at = parse_task_id(task_id) + + manager = _get_callback_manager(tool_name) + if manager is None: + raise MCPError("No background callback manager configured.") + + app = get_app() + adapter = app.mcp_callback_map.find_by_tool_name(tool_name) + if adapter is None: + raise MCPError(f"Task not found: {tool_name}") + # pylint: disable-next=protected-access + cb_info = adapter._cb_info + background = cb_info.get("background") + has_output = not cb_info.get("no_output") + multi = adapter.output_id.startswith("..") + output = split_callback_id(adapter.output_id) + + # Build the body to get output_spec, same as as_callback_body + body = adapter.as_callback_body({}) + output_spec = body.get("outputs", []) + + callback_ctx = AttributeDict({"updated_props": {}}) + response: CallbackExecutionResponse = {"multi": True} + + output_value, has_update, skip = _update_background_callback( + error_handler=None, + callback_ctx=callback_ctx, + response=response, + kwargs={"background_callback_manager": manager}, + background=background, + multi=multi, + cache_key=cache_key, + job_id=job_id, + ) + + if skip: + # Result not ready — still polling + raise MCPError( + "Task result not ready. Poll tasks/get until status is 'completed'." + ) + + _prepare_response( + output_value, + output_spec, + multi, + response, + callback_ctx, + app, + set(ComponentRegistry.registry), + background, + has_update, + has_output, + output, + adapter.output_id, + cb_info.get("allow_dynamic_callbacks"), + ) + + return format_callback_response(response, adapter) + + +def cancel_task(task_id: str) -> Any: + """Handle tasks/cancel — terminate the background job. + + Same underlying mechanism as the renderer's cancelJob query param. + """ + tool_name, job_id, _cache_key, created_at = parse_task_id(task_id) + + manager = _get_callback_manager(tool_name) + if manager is None: + raise MCPError("No background callback manager configured.") + + manager.terminate_job(job_id) + + return CancelTaskResult( + taskId=task_id, + status="cancelled", + createdAt=created_at, + lastUpdatedAt=datetime.now(timezone.utc), + ttl=manager.expire * 1000 if manager.expire else None, + ) diff --git a/dash/mcp/types/__init__.py b/dash/mcp/types/__init__.py new file mode 100644 index 0000000000..af588e0808 --- /dev/null +++ b/dash/mcp/types/__init__.py @@ -0,0 +1,26 @@ +"""MCP types, exceptions, and typing utilities.""" + +from dash.mcp.types.callback_types import MCPInput, MCPOutput +from dash.mcp.types.component_types import ( + ComponentPropertyInfo, + ComponentQueryResult, +) +from dash.mcp.types.exceptions import ( + CallbackExecutionError, + InvalidParamsError, + MCPError, + ToolNotFoundError, +) +from dash.mcp.types.typing_utils import is_nullable + +__all__ = [ + "CallbackExecutionError", + "ComponentPropertyInfo", + "ComponentQueryResult", + "InvalidParamsError", + "MCPError", + "MCPInput", + "MCPOutput", + "ToolNotFoundError", + "is_nullable", +] diff --git a/dash/mcp/types/callback_types.py b/dash/mcp/types/callback_types.py new file mode 100644 index 0000000000..9c65dcb9d8 --- /dev/null +++ b/dash/mcp/types/callback_types.py @@ -0,0 +1,33 @@ +"""Typed dicts for MCP callback adapter data.""" + +from __future__ import annotations + +from typing import Any + +from typing_extensions import TypedDict + + +class MCPOutput(TypedDict): + """A single callback output, with component type and initial value resolved.""" + + id_and_prop: str + component_id: str + property: str + component_type: str | None + initial_value: Any + tool_name: str + + +class MCPInput(TypedDict): + """A single callback parameter (input or state), fully resolved.""" + + name: str + id_and_prop: str + component_id: str + property: str + annotation: Any | None + component_type: str | None + component: Any | None + required: bool + initial_value: Any + upstream_output: MCPOutput | None diff --git a/dash/mcp/types/component_types.py b/dash/mcp/types/component_types.py new file mode 100644 index 0000000000..0cac3ad689 --- /dev/null +++ b/dash/mcp/types/component_types.py @@ -0,0 +1,20 @@ +"""Typed dicts for component data in MCP.""" + +from __future__ import annotations + +from typing import Any + +from typing_extensions import NotRequired, TypedDict + + +class ComponentPropertyInfo(TypedDict): + initial_value: Any + modified_by_tool: list[str] + input_to_tool: list[str] + + +class ComponentQueryResult(TypedDict): + component_id: str + component_type: str + label: NotRequired[list[str] | None] + properties: dict[str, ComponentPropertyInfo] diff --git a/dash/mcp/types/exceptions.py b/dash/mcp/types/exceptions.py new file mode 100644 index 0000000000..578fa51cc9 --- /dev/null +++ b/dash/mcp/types/exceptions.py @@ -0,0 +1,28 @@ +"""MCP error types with JSON-RPC error codes.""" + +from __future__ import annotations + + +class MCPError(Exception): + """Base MCP error carrying a JSON-RPC error code.""" + + code = -32603 + + def __init__(self, message: str): + super().__init__(message) + + +class ToolNotFoundError(MCPError): + """Tool name not found in the callback registry.""" + + code = -32601 + + +class InvalidParamsError(MCPError): + """Invalid or missing parameters for a tool call.""" + + code = -32602 + + +class CallbackExecutionError(MCPError): + """Callback raised an exception during execution.""" diff --git a/dash/mcp/types/typing_utils.py b/dash/mcp/types/typing_utils.py new file mode 100644 index 0000000000..e685f5808b --- /dev/null +++ b/dash/mcp/types/typing_utils.py @@ -0,0 +1,28 @@ +"""Shared typing utilities for the MCP layer.""" + +from __future__ import annotations + +import typing +from typing import Any + + +def is_nullable(annotation: Any) -> bool: + """Check if a type annotation includes NoneType (is nullable/Optional).""" + origin = getattr(annotation, "__origin__", None) + args = getattr(annotation, "__args__", ()) + + _is_union = origin is typing.Union + if not _is_union: + try: + import types as _types # pylint: disable=import-outside-toplevel + + if isinstance(annotation, _types.UnionType): + _is_union = True + args = annotation.__args__ + except AttributeError: + pass + + if _is_union and args: + return type(None) in args + + return False diff --git a/dash/types.py b/dash/types.py index 9a39adb43e..9da246b16c 100644 --- a/dash/types.py +++ b/dash/types.py @@ -1,4 +1,29 @@ -from typing_extensions import TypedDict, NotRequired +import typing +from typing import Any, Dict, List, Union + +from pydantic import Field, GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic_core import core_schema +from typing_extensions import Annotated, TypedDict, NotRequired + + +class _NumberSchema: # pylint: disable=too-few-public-methods + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: GetCoreSchemaHandler + ) -> Any: + return core_schema.float_schema() + + @classmethod + def __get_pydantic_json_schema__( + cls, _schema: Any, _handler: GetJsonSchemaHandler + ) -> dict: + return {"type": "number"} + + +NumberType = Annotated[ + Union[typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex], + _NumberSchema, +] class RendererHooks(TypedDict): # pylint: disable=too-many-ancestors @@ -8,3 +33,72 @@ class RendererHooks(TypedDict): # pylint: disable=too-many-ancestors request_post: NotRequired[str] callback_resolved: NotRequired[str] request_refresh_jwt: NotRequired[str] + + +WildcardId = Dict[str, Any] +"""A pattern-matching component ID, e.g. ``{"type": "item", "index": 0}``.""" + + +class CallbackDependency(TypedDict): + id: Union[str, WildcardId] + property: str + + +CallbackOutputTarget = Union[CallbackDependency, List[CallbackDependency]] +"""One callback Output() declaration resolved against the layout. + +For regular callbacks, a single dependency:: + + {"id": "chart", "property": "figure"} + +For pattern-matching callbacks (ALL/ALLSMALLER), a list of concrete +targets that the wildcard expanded to:: + + [ + {"id": {"type": "item", "index": 0}, "property": "children"}, + {"id": {"type": "item", "index": 1}, "property": "children"}, + ] + +For MATCH, a single dependency with a dict id:: + + {"id": {"type": "item", "index": 0}, "property": "children"} +""" + + +class CallbackInput(TypedDict): + id: Union[str, WildcardId] + property: str + value: Any + + +CallbackInputs = Union[CallbackInput, List[CallbackInput]] + + +class CallbackExecutionBody(TypedDict): + output: str + outputs: Union[CallbackOutputTarget, List[CallbackOutputTarget]] + inputs: List[CallbackInputs] + state: List[CallbackInputs] + changedPropIds: List[str] + + +CallbackOutput = Annotated[ + Dict[str, Any], + Field( + description="The return values of the callback. A mapping of component & property names to their updated values." + ), +] + +CallbackSideOutput = Annotated[ + Dict[str, Any], + Field( + description="Side-effect updates that the callback performed but did not declare ahead of time. A mapping of component & property names to their updated values." + ), +] + + +class CallbackExecutionResponse(TypedDict): + multi: NotRequired[bool] + response: NotRequired[Dict[str, CallbackOutput]] + sideUpdate: NotRequired[Dict[str, CallbackSideOutput]] + dist: NotRequired[List[Any]] diff --git a/dash/version.py b/dash/version.py index d39d52c3b8..c8f491bcb2 100644 --- a/dash/version.py +++ b/dash/version.py @@ -1 +1 @@ -__version__ = "4.2.0rc3" +__version__ = "4.3.0rc0" diff --git a/package.json b/package.json index f44a0f2805..0f24f25d81 100644 --- a/package.json +++ b/package.json @@ -14,7 +14,7 @@ "private::build.jupyterlab": "cd @plotly/dash-jupyterlab && jlpm install && jlpm build:pack", "private::lint.black": "black dash tests --exclude 'metadata_test.py|node_modules' --check", "private::lint.flake8": "flake8 dash tests", - "private::lint.pylint-dash": "pylint dash setup.py --rcfile=.pylintrc", + "private::lint.pylint-dash": "pylint dash setup.py --rcfile=.pylintrc ${PYLINT_EXTRA_ARGS:-}", "private::lint.pylint-tests": "pylint tests/unit tests/integration -d all -e C0410,C0413,W0109 --rcfile=.pylintrc", "private::lint.renderer": "cd dash/dash-renderer && npm run lint", "private::test.setup-components": "cd @plotly/dash-test-components && npm ci && npm run build", diff --git a/requirements/install.txt b/requirements/install.txt index 284f3a5031..1dedc8662c 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -8,3 +8,5 @@ retrying nest-asyncio setuptools janus>=1.0.0 +pydantic>=2.10 +mcp>=1.23.0; python_version>="3.10" diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py new file mode 100644 index 0000000000..aad3f1b1b5 --- /dev/null +++ b/tests/integration/mcp/conftest.py @@ -0,0 +1,62 @@ +"""Shared helpers for MCP integration tests.""" + +import sys + +import pytest +import requests + +from dash import _get_app + +collect_ignore_glob = [] +if sys.version_info < (3, 10): + collect_ignore_glob.append("*") + + +@pytest.fixture(autouse=True) +def _enable_mcp_for_integration_tests(monkeypatch): + """MCP is off by default; integration tests need it on.""" + monkeypatch.setenv("DASH_MCP_ENABLED", "true") + + +@pytest.fixture(autouse=True) +def _reset_dash_app_state(): + """Reset Dash module-level state after each MCP test.""" + yield + _get_app.APP = None + _get_app.app_context.set(None) + + +def _mcp_post(server_url, method, params=None, request_id=1): + return requests.post( + f"{server_url}/_mcp", + json={ + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params or {}, + }, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + + +def _mcp_tools(server_url): + resp = _mcp_post(server_url, "tools/list") + resp.raise_for_status() + return resp.json()["result"]["tools"] + + +def _mcp_call_tool(server_url, tool_name, arguments=None): + resp = _mcp_post( + server_url, + "tools/call", + {"name": tool_name, "arguments": arguments or {}}, + ) + resp.raise_for_status() + return resp.json() + + +def _mcp_method(server_url, method, params=None): + resp = _mcp_post(server_url, method, params) + resp.raise_for_status() + return resp.json() diff --git a/tests/integration/mcp/primitives/tools/test_tool_decorated_mcp_functions.py b/tests/integration/mcp/primitives/tools/test_tool_decorated_mcp_functions.py new file mode 100644 index 0000000000..afbaf415a8 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_tool_decorated_mcp_functions.py @@ -0,0 +1,160 @@ +"""Integration tests for @mcp_enabled decorated function tools.""" + +import json +from typing import Optional + +from dash import Dash, html +from dash.mcp import mcp_enabled + +from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools + +BUILTINS = {"get_dash_component"} + + +def test_mcpd001_bare_decorator_appears_as_tool(dash_duo): + @mcp_enabled + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tools = _mcp_tools(dash_duo.server.url) + names = [t["name"] for t in tools] + assert "add_numbers" in names + + tool = next(t for t in tools if t["name"] == "add_numbers") + assert "Add two numbers" not in tool["description"] + + +def test_mcpd002_expose_docstring(dash_duo): + @mcp_enabled(expose_docstring=True) + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tool = next( + t for t in _mcp_tools(dash_duo.server.url) if t["name"] == "add_numbers" + ) + assert "Add two numbers together" in tool["description"] + + +def test_mcpd003_custom_name_overrides_function_name(dash_duo): + @mcp_enabled(name="sum_values") + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tools = _mcp_tools(dash_duo.server.url) + names = [t["name"] for t in tools] + assert "sum_values" in names + assert "add_numbers" not in names + + +def test_mcpd004_typed_params_produce_schema(dash_duo): + @mcp_enabled + def greet(name: str, times: int) -> str: + """Greet someone.""" + return name * times + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tool = next(t for t in _mcp_tools(dash_duo.server.url) if t["name"] == "greet") + schema = tool["inputSchema"] + assert schema["type"] == "object" + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["times"]["type"] == "integer" + assert set(schema["required"]) == {"name", "times"} + + +def test_mcpd005_optional_param_not_required(dash_duo): + @mcp_enabled + def search(query: str, limit: Optional[int] = None) -> str: + """Search for things.""" + return query + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tool = next(t for t in _mcp_tools(dash_duo.server.url) if t["name"] == "search") + schema = tool["inputSchema"] + assert "query" in schema["required"] + assert "limit" not in schema["required"] + + +def test_mcpd006_return_annotation_becomes_output_schema(dash_duo): + @mcp_enabled + def compute(x: int) -> str: + """Compute something.""" + return str(x) + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tool = next(t for t in _mcp_tools(dash_duo.server.url) if t["name"] == "compute") + assert tool["outputSchema"]["type"] == "object" + assert tool["outputSchema"]["properties"]["result"]["type"] == "string" + + +def test_mcpd007_call_returns_result(dash_duo): + @mcp_enabled + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + resp = _mcp_call_tool(dash_duo.server.url, "multiply", {"a": 3, "b": 7}) + result = resp["result"] + assert result["isError"] is not True + text = result["content"][0]["text"] + assert json.loads(text) == 21 + + +def test_mcpd008_call_with_custom_name(dash_duo): + @mcp_enabled(name="product") + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + resp = _mcp_call_tool(dash_duo.server.url, "product", {"a": 4, "b": 5}) + result = resp["result"] + assert result["isError"] is not True + text = result["content"][0]["text"] + assert json.loads(text) == 20 + + +def test_mcpd009_call_error_returns_is_error(dash_duo): + @mcp_enabled + def fail_hard(x: int) -> str: + """Always fails.""" + raise ValueError("boom") + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + resp = _mcp_call_tool(dash_duo.server.url, "fail_hard", {"x": 1}) + result = resp["result"] + assert result["isError"] is True + assert "boom" in result["content"][0]["text"] diff --git a/tests/integration/mcp/test_mcp_background_tasks.py b/tests/integration/mcp/test_mcp_background_tasks.py new file mode 100644 index 0000000000..8f8b946754 --- /dev/null +++ b/tests/integration/mcp/test_mcp_background_tasks.py @@ -0,0 +1,326 @@ +"""Background callback support through the MCP HTTP endpoint. + +End-to-end flows: trigger a background callback, poll via +``get_background_task_result``, observe progress (``set_progress``), +confirm the cache-expiry behavior, and verify the background-only tools +appear in ``tools/list``. +""" + +import json +import re +import time + +import diskcache +from dash import Dash, Input, Output, html +from dash.background_callback.managers.diskcache_manager import DiskcacheManager + +MCP_PATH = "_mcp" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_background_app(): + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def slow_callback(value): + time.sleep(0.5) + return f"done: {value}" + + return app + + +def _post(client, method, params=None, request_id=1): + return client.post( + f"/{MCP_PATH}", + data=json.dumps( + { + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params or {}, + } + ), + headers={"Content-Type": "application/json"}, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpbg012_trigger_poll_and_retrieve(): + app = _make_background_app() + client = app.server.test_client() + + # Trigger + r = _post( + client, + "tools/call", + {"name": "slow_callback", "arguments": {"value": "hello"}}, + ) + assert r.status_code == 200 + data = json.loads(r.data) + task_info = json.loads(data["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + assert task_info["status"] == "working" + + # Poll — should be working + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=2, + ) + assert r.status_code == 200 + poll_data = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + assert poll_data["status"] == "working" + + # Wait for completion + job_id = task_id.split(":")[1] + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + 5 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + # Get result + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=3, + ) + assert r.status_code == 200 + data = json.loads(r.data) + text = data["result"]["content"][0]["text"] + assert "done:" in text + + +def test_mcpbg013_result_expires(): + """Result is retrievable until the cache expires, then reports failure.""" + cache = diskcache.Cache() + manager = DiskcacheManager(cache, cache_by=[lambda: "fixed"], expire=2) + + app = Dash(__name__) + app.layout = html.Div([html.Div(id="input"), html.Div(id="output")]) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def fast_cb(value): + return f"done: {value}" + + client = app.server.test_client() + + r = _post( + client, + "tools/call", + {"name": "fast_cb", "arguments": {"value": "hi"}}, + ) + task_info = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + job_id = task_id.split(":")[1] + + deadline = time.time() + 3 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + # Before expiry — result available + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=2, + ) + assert "done:" in json.loads(r.data)["result"]["content"][0]["text"] + + time.sleep(2.5) + + # After expiry — tool reports failure + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=3, + ) + poll_data = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + assert poll_data["status"] == "failed" + + +def test_mcpbg014_progress_in_poll_response(): + """Progress reported via set_progress appears in poll statusMessage.""" + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="status"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + progress=Output("status", "children"), + background=True, + manager=manager, + interval=200, + ) + def progress_cb(set_progress, value): + for i in range(10): + set_progress(f"Step {i + 1} of 10") + time.sleep(0.2) + return f"done: {value}" + + client = app.server.test_client() + + # Trigger + r = _post( + client, + "tools/call", + {"name": "progress_cb", "arguments": {"value": "hi"}}, + ) + task_info = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + + # Poll and collect all progress messages + progress_pattern = re.compile(r"Step \d+ of 10") + progress_messages = [] + deadline = time.time() + 10 + while time.time() < deadline: + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=2, + ) + text = json.loads(r.data)["result"]["content"][0]["text"] + try: + poll_data = json.loads(text) + msg = poll_data.get("statusMessage") + if msg is not None: + progress_messages.append(msg) + if poll_data.get("status") == "completed": + break + except (json.JSONDecodeError, KeyError): + break + time.sleep(0.3) + + assert len(progress_messages) > 0, "Expected progress updates during polling" + for msg in progress_messages: + assert progress_pattern.search(msg), f"Unexpected progress format: {msg}" + + +def test_mcpbg015_background_tools_in_tools_list(): + app = _make_background_app() + client = app.server.test_client() + r = _post(client, "tools/list") + data = json.loads(r.data) + names = [t["name"] for t in data["result"]["tools"]] + assert "get_background_task_result" in names + assert "cancel_background_task" in names + assert "slow_callback" in names + + +def test_mcpbg016_per_callback_manager_lookup(): + """``tasks/get`` uses the manager attached to the specific callback.""" + manager_a = DiskcacheManager(diskcache.Cache()) + manager_b = DiskcacheManager(diskcache.Cache()) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input_a"), + html.Div(id="output_a"), + html.Div(id="input_b"), + html.Div(id="output_b"), + ] + ) + + @app.callback( + Output("output_a", "children"), + Input("input_a", "children"), + background=True, + manager=manager_a, + ) + def callback_a(value): + time.sleep(0.5) + return f"a: {value}" + + @app.callback( + Output("output_b", "children"), + Input("input_b", "children"), + background=True, + manager=manager_b, + ) + def callback_b(value): + time.sleep(0.5) + return f"b: {value}" + + client = app.server.test_client() + + r = _post( + client, + "tools/call", + {"name": "callback_b", "arguments": {"value": "hello"}}, + ) + assert r.status_code == 200 + task_info = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + cache_key = task_id.split(":")[2] + + deadline = time.time() + 5 + while time.time() < deadline: + if manager_b.result_ready(cache_key): + break + time.sleep(0.1) + + assert manager_b.result_ready(cache_key) + assert not manager_a.result_ready(cache_key) + + r = _post(client, "tasks/get", {"taskId": task_id}, request_id=2) + assert r.status_code == 200 + assert json.loads(r.data)["result"]["status"] == "completed" diff --git a/tests/integration/mcp/test_mcp_callback_behavior.py b/tests/integration/mcp/test_mcp_callback_behavior.py new file mode 100644 index 0000000000..7778111386 --- /dev/null +++ b/tests/integration/mcp/test_mcp_callback_behavior.py @@ -0,0 +1,1256 @@ +"""Callback behaviors surfaced through MCP tools (end-to-end). + +Covers the full pipeline — a real Dash server via ``dash_duo`` + the MCP +HTTP endpoint — for every callback signature variant and the surrounding +tool-list conventions: + +- Positional, dict-based, and tuple-grouped ``inputs`` / ``state`` / + ``output`` forms. +- ``State``, multi-output, ``PreventUpdate``-style no-output callbacks, + ``ctx.triggered_id``, pattern-matching (``ALL``/``MATCH``/``ALLSMALLER``). +- Initial values: ``prevent_initial_call`` vs. initial-callback overrides. +- Duplicate outputs (``allow_duplicate=True``) appearing as separate tools. +- ``tools/list`` naming rules (64-char limit, uniqueness, built-ins). +- A representative input-schema smoke test (label + DatePicker). +- ``get_dash_component`` structured output via HTTP. +""" + +from dash import ( + ALL, + ALLSMALLER, + MATCH, + Dash, + Input, + Output, + State, + ctx, + dcc, + html, + set_props, +) + +from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _find_tool(tools, name): + return next((t for t in tools if t["name"] == name), None) + + +def _get_response(result): + return result["result"]["structuredContent"]["response"] + + +# --------------------------------------------------------------------------- +# Callback signatures — positional, multi-output, State, dict-based, tuples +# --------------------------------------------------------------------------- + + +def test_mcpb001_positional_callback(dash_duo): + """Standard positional: Input("fruit", "value") → param named 'value'.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="fruit", options=["apple", "banana"], value="apple"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("fruit", "value")) + def show_fruit(value): + return f"Selected: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Selected: apple") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_fruit") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"value"} + assert any(s.get("type") == "string" for s in props["value"]["anyOf"]) + + value_desc = props["value"].get("description", "") + assert "value: 'apple'" in value_desc + assert "options: ['apple', 'banana']" in value_desc + + result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "apple"}) + response = _get_response(result) + assert response["out"]["children"] == "Selected: apple" + + result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "banana"}) + response = _get_response(result) + assert response["out"]["children"] == "Selected: banana" + + +def test_mcpb002_positional_with_state(dash_duo): + """Positional with State: Input + State both appear as params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp", value="hello"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("btn", "n_clicks"), + State("inp", "value"), + ) + def update(n_clicks, value): + return f"Clicked {n_clicks} with {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Clicked None with hello") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "update") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"n_clicks", "value"} + assert any(s.get("type") == "number" for s in props["n_clicks"]["anyOf"]) + + assert "value: 'hello'" in props["value"].get("description", "") + + result = _mcp_call_tool( + dash_duo.server.url, "update", {"n_clicks": None, "value": "hello"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Clicked None with hello" + + result = _mcp_call_tool( + dash_duo.server.url, "update", {"n_clicks": 3, "value": "world"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Clicked 3 with world" + + +def test_mcpb003_multi_output_positional(dash_duo): + """Multi-output: returns tuple → both outputs updated in response.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="test"), + html.Div(id="out1"), + html.Div(id="out2"), + ] + ) + + @app.callback( + Output("out1", "children"), + Output("out2", "children"), + Input("inp", "value"), + ) + def split_case(value): + return value.upper(), value.lower() + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out1", "TEST") + dash_duo.wait_for_text_to_equal("#out2", "test") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "split_case") + props = tool["inputSchema"]["properties"] + assert set(props.keys()) == {"value"} + + assert "value: 'test'" in props["value"].get("description", "") + + result = _mcp_call_tool(dash_duo.server.url, "split_case", {"value": "test"}) + response = _get_response(result) + assert response["out1"]["children"] == "TEST" + assert response["out2"]["children"] == "test" + + +def test_mcpb004_dict_based_inputs_and_state(dash_duo): + """Dict-based: inputs=dict(trigger=...), state=dict(name=...) → dict keys are params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="name-input", value="world"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + inputs=dict(trigger=Input("btn", "n_clicks")), + state=dict(name=State("name-input", "value")), + ) + def greet(trigger, name): + return f"Hello, {name}!" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Hello, world!") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "greet") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"trigger", "name"} + assert any(s.get("type") == "number" for s in props["trigger"]["anyOf"]) + + result = _mcp_call_tool( + dash_duo.server.url, "greet", {"trigger": None, "name": "world"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Hello, world!" + + result = _mcp_call_tool( + dash_duo.server.url, "greet", {"trigger": 1, "name": "Dash"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Hello, Dash!" + + +def test_mcpb005_dict_based_outputs(dash_duo): + """Dict-based outputs: output=dict(...) → callback returns dict, both outputs updated.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hello"), + html.Div(id="upper-out"), + html.Div(id="lower-out"), + ] + ) + + @app.callback( + output=dict( + upper=Output("upper-out", "children"), + lower=Output("lower-out", "children"), + ), + inputs=dict(val=Input("inp", "value")), + ) + def transform(val): + return dict(upper=val.upper(), lower=val.lower()) + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#upper-out", "HELLO") + dash_duo.wait_for_text_to_equal("#lower-out", "hello") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "transform") + props = tool["inputSchema"]["properties"] + assert set(props.keys()) == {"val"} + + result = _mcp_call_tool(dash_duo.server.url, "transform", {"val": "hello"}) + response = _get_response(result) + assert response["upper-out"]["children"] == "HELLO" + assert response["lower-out"]["children"] == "hello" + + +def test_mcpb006_mixed_input_state_in_inputs(dash_duo): + """Mixed: State inside inputs=dict alongside Input → all appear as params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="first", value="Jane"), + dcc.Input(id="last", value="Doe"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + inputs=dict( + clicks=Input("btn", "n_clicks"), + first=State("first", "value"), + last=State("last", "value"), + ), + ) + def full_name(clicks, first, last): + return f"{first} {last}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Jane Doe") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "full_name") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"clicks", "first", "last"} + assert any(s.get("type") == "number" for s in props["clicks"]["anyOf"]) + + result = _mcp_call_tool( + dash_duo.server.url, + "full_name", + {"clicks": None, "first": "Jane", "last": "Doe"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "Jane Doe" + + result = _mcp_call_tool( + dash_duo.server.url, + "full_name", + {"clicks": 1, "first": "John", "last": "Smith"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "John Smith" + + +def test_mcpb007_tuple_grouped_inputs(dash_duo): + """Tuple grouping: pair=(Input("a",...), Input("b",...)) → expands to two named params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="a", value="1"), + dcc.Input(id="b", value="2"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + inputs=dict(pair=(Input("a", "value"), Input("b", "value"))), + ) + def combine(pair): + return f"{pair[0]}+{pair[1]}" + + dash_duo.start_server(app) + tool = _find_tool(_mcp_tools(dash_duo.server.url), "combine") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"pair_a__value", "pair_b__value"} + for schema in props.values(): + assert any(s.get("type") == "string" for s in schema["anyOf"]) + + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"pair_a__value": "x", "pair_b__value": "y"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "x+y" + + +def test_mcpb008_initial_values_from_chained_callbacks(dash_duo): + """Querying components reflects post-initial-callback values. + + 3-link chain: country (default "France") → update_states → + state (should become "Ile-de-France") → update_cities → + city (should become "Paris"). + """ + DATA = { + "France": { + "Ile-de-France": ["Paris", "Versailles"], + "Provence": ["Marseille", "Nice"], + }, + "Germany": { + "Bavaria": ["Munich", "Nuremberg"], + "Berlin": ["Berlin"], + }, + } + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=list(DATA.keys()), value="France"), + dcc.Dropdown(id="state"), + dcc.Dropdown(id="city"), + ] + ) + + @app.callback( + Output("state", "options"), + Output("state", "value"), + Input("country", "value"), + ) + def update_states(country): + if not country: + return [], None + states = list(DATA[country].keys()) + return [{"label": s, "value": s} for s in states], states[0] + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("state", "value"), + Input("country", "value"), + ) + def update_cities(state, country): + if not state or not country: + return [], None + cities = DATA[country][state] + return [{"label": c, "value": c} for c in cities], cities[0] + + dash_duo.start_server(app) + + tools = _mcp_tools(dash_duo.server.url) + update_cities_tool = _find_tool(tools, "update_cities") + state_desc = update_cities_tool["inputSchema"]["properties"]["state"].get( + "description", "" + ) + assert "Ile-de-France" in state_desc + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "state", "property": "value"}, + ) + state_props = result["result"]["structuredContent"]["properties"] + assert state_props["value"]["initial_value"] == "Ile-de-France" + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "city", "property": "value"}, + ) + city_props = result["result"]["structuredContent"]["properties"] + assert city_props["value"]["initial_value"] == "Paris" + + +def test_mcpb009_dict_based_reordered_state_input(dash_duo): + """Dict-based callback with State before Input: call works, schema types correct. + + State is listed before Input in the dict. The callback should still + work correctly via MCP, and the schema types should match the + function annotations (name: str, trigger: int), not be swapped. + """ + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp", value="World"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + inputs=dict(name=State("inp", "value"), trigger=Input("btn", "n_clicks")), + ) + def greet(name: str, trigger: int): + return f"Hello {name}" + + dash_duo.start_server(app) + + result = _mcp_call_tool( + dash_duo.server.url, + "greet", + {"name": "Dash", "trigger": 1}, + ) + assert _get_response(result)["out"]["children"] == "Hello Dash" + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "greet") + props = tool["inputSchema"]["properties"] + assert props["trigger"]["type"] == "integer" + assert props["name"]["type"] == "string" + + trigger_desc = props["trigger"].get("description", "") + assert "number of times that this element has been clicked on" in trigger_desc + name_desc = props["name"].get("description", "") + assert "The value of the input" in name_desc + + +# --------------------------------------------------------------------------- +# Pattern-matching callbacks (ALL / MATCH / ALLSMALLER) +# --------------------------------------------------------------------------- + + +def test_mcpb010_pattern_matching_callback(dash_duo): + """Pattern-matching dict IDs: tool works with correct params and results.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="hello"), + dcc.Input(id={"type": "field", "index": 1}, value="world"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input({"type": "field", "index": 0}, "value"), + Input({"type": "field", "index": 1}, "value"), + ) + def combine(first, second): + return f"{first} {second}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "combine") + assert tool is not None + props = tool["inputSchema"]["properties"] + assert "first" in props + assert "second" in props + + dash_duo.wait_for_text_to_equal("#out", "hello world") + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"first": "hello", "second": "world"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "hello world" + + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"first": "foo", "second": "bar"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "foo bar" + + +def test_mcpb011_pattern_matching_with_all_wildcard(dash_duo): + """ALL wildcard: one callback receives values from all matching components.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "input", "index": 0}, value="alpha"), + dcc.Input(id={"type": "input", "index": 1}, value="beta"), + html.Div(id="summary"), + ] + ) + + @app.callback( + Output("summary", "children"), + Input({"type": "input", "index": ALL}, "value"), + ) + def summarize(values): + return ", ".join(v for v in values if v) + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#summary", "alpha, beta") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "summarize") + assert tool is not None + + values_schema = tool["inputSchema"]["properties"]["values"] + assert ( + values_schema["type"] == "array" + ), f"ALL wildcard param should be typed as array, got: {values_schema}" + assert "items" in values_schema, "Array schema should include items definition" + items = values_schema["items"] + assert items["type"] == "object" + assert "id" in items["properties"] + assert "value" in items["properties"] + assert "Pattern-matching input (ALL)" in values_schema.get( + "description", "" + ), "ALL wildcard param description should explain the pattern-matching behavior" + + result = _mcp_call_tool( + dash_duo.server.url, + "summarize", + { + "values": [ + { + "id": {"type": "input", "index": 0}, + "property": "value", + "value": "alpha", + }, + { + "id": {"type": "input", "index": 1}, + "property": "value", + "value": "beta", + }, + ] + }, + ) + response = _get_response(result) + assert response["summary"]["children"] == "alpha, beta" + + result = _mcp_call_tool( + dash_duo.server.url, + "summarize", + { + "values": [ + { + "id": {"type": "input", "index": 0}, + "property": "value", + "value": "one", + }, + { + "id": {"type": "input", "index": 1}, + "property": "value", + "value": "two", + }, + ] + }, + ) + response = _get_response(result) + assert response["summary"]["children"] == "one, two" + + +def test_mcpb012_pattern_matching_mixed_outputs(dash_duo): + """Mixed outputs: one regular + one ALL wildcard in the same callback.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="a"), + dcc.Input(id={"type": "field", "index": 1}, value="b"), + html.Div(id={"type": "echo", "index": 0}), + html.Div(id={"type": "echo", "index": 1}), + html.Div(id="total"), + ] + ) + + @app.callback( + Output({"type": "echo", "index": ALL}, "children"), + Output("total", "children"), + Input({"type": "field", "index": ALL}, "value"), + ) + def echo_and_total(values): + echoes = [f"Echo: {v}" for v in values] + total = f"Total: {len(values)} items" + return echoes, total + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#total", "Total: 2 items") + + result = _mcp_call_tool( + dash_duo.server.url, + "echo_and_total", + { + "values": [ + { + "id": {"type": "field", "index": 0}, + "property": "value", + "value": "x", + }, + { + "id": {"type": "field", "index": 1}, + "property": "value", + "value": "y", + }, + ] + }, + ) + response = _get_response(result) + assert response["total"]["children"] == "Total: 2 items" + + +def test_mcpb013_pattern_matching_with_match_wildcard(dash_duo): + """MATCH wildcard: callback fires per-component with matching index. + + Based on https://dash.plotly.com/pattern-matching-callbacks + """ + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + ["NYC", "MTL", "LA", "TOKYO"], + "NYC", + id={"type": "city-dd", "index": 0}, + ), + html.Div(id={"type": "city-out", "index": 0}), + dcc.Dropdown( + ["NYC", "MTL", "LA", "TOKYO"], + "LA", + id={"type": "city-dd", "index": 1}, + ), + html.Div(id={"type": "city-out", "index": 1}), + ] + ) + + @app.callback( + Output({"type": "city-out", "index": MATCH}, "children"), + Input({"type": "city-dd", "index": MATCH}, "value"), + ) + def show_city(value): + return f"Selected: {value}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") + assert tool is not None + + value_schema = tool["inputSchema"]["properties"]["value"] + assert "Pattern-matching input (MATCH)" in value_schema.get("description", "") + + result = _mcp_call_tool( + dash_duo.server.url, + "show_city", + { + "value": { + "id": {"type": "city-dd", "index": 0}, + "property": "value", + "value": "MTL", + } + }, + ) + response = _get_response(result) + out_key = next(k for k in response if "city-out" in k) + assert response[out_key]["children"] == "Selected: MTL" + + +def test_mcpb014_pattern_matching_with_allsmaller_wildcard(dash_duo): + """ALLSMALLER wildcard: receives values from components with smaller index. + + Based on https://dash.plotly.com/pattern-matching-callbacks + """ + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + ["France", "Germany", "Japan"], + "France", + id={"type": "country-dd", "index": 0}, + ), + html.Div(id={"type": "country-out", "index": 0}), + dcc.Dropdown( + ["France", "Germany", "Japan"], + "Germany", + id={"type": "country-dd", "index": 1}, + ), + html.Div(id={"type": "country-out", "index": 1}), + dcc.Dropdown( + ["France", "Germany", "Japan"], + "Japan", + id={"type": "country-dd", "index": 2}, + ), + html.Div(id={"type": "country-out", "index": 2}), + ] + ) + + @app.callback( + Output({"type": "country-out", "index": MATCH}, "children"), + Input({"type": "country-dd", "index": MATCH}, "value"), + Input({"type": "country-dd", "index": ALLSMALLER}, "value"), + ) + def show_countries(current, previous): + all_selected = [current] + list(reversed(previous)) + return f"All: {', '.join(all_selected)}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_countries") + assert tool is not None + + props = tool["inputSchema"]["properties"] + assert "Pattern-matching input (MATCH)" in props["current"].get("description", "") + assert "Pattern-matching input (ALLSMALLER)" in props["previous"].get( + "description", "" + ) + + result = _mcp_call_tool( + dash_duo.server.url, + "show_countries", + { + "current": { + "id": {"type": "country-dd", "index": 2}, + "property": "value", + "value": "Japan", + }, + "previous": [ + { + "id": {"type": "country-dd", "index": 0}, + "property": "value", + "value": "France", + }, + { + "id": {"type": "country-dd", "index": 1}, + "property": "value", + "value": "Germany", + }, + ], + }, + ) + response = _get_response(result) + out_key = next(k for k in response if "country-out" in k) + assert response[out_key]["children"] == "All: Japan, Germany, France" + + +# --------------------------------------------------------------------------- +# Initial values: prevent_initial_call vs. initial-callback overrides +# --------------------------------------------------------------------------- + + +def test_mcpb015_prevent_initial_call_uses_layout_default(dash_duo): + """prevent_initial_call=True: initial value stays as the layout default.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b", "c"], value="a"), + html.Div(id="out", children="not yet"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("dd", "value"), + prevent_initial_call=True, + ) + def update(val): + return f"Changed to: {val}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "not yet") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "update") + val_desc = tool["inputSchema"]["properties"]["val"].get("description", "") + + assert "value: 'a'" in val_desc + + +def test_mcpb016_initial_callback_overrides_layout_value(dash_duo): + """Initial callback overrides layout value in tool description.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=["France", "Germany"], value="France"), + dcc.Dropdown(id="city", options=[], value="default-city"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("country", "value"), + ) + def update_city(country): + if country == "France": + return [{"label": "Paris", "value": "Paris"}], "Paris" + return [{"label": "Berlin", "value": "Berlin"}], "Berlin" + + @app.callback(Output("out", "children"), Input("city", "value")) + def show_city(city): + return f"City: {city}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "City: Paris") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") + city_desc = tool["inputSchema"]["properties"]["city"].get("description", "") + + assert "value: 'Paris'" in city_desc + assert "default-city" not in city_desc + + +def test_mcpb017_callback_context_triggered_id(dash_duo): + """Callbacks using dash.ctx.triggered_id work via MCP. + + Based on https://dash.plotly.com/determining-which-callback-input-changed + """ + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button("Button 1", id="btn-1"), + html.Button("Button 2", id="btn-2"), + html.Button("Button 3", id="btn-3"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn-1", "n_clicks"), + Input("btn-2", "n_clicks"), + Input("btn-3", "n_clicks"), + ) + def display(btn1, btn2, btn3): + if not ctx.triggered_id: + return "No button clicked yet" + return f"Last clicked: {ctx.triggered_id}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "No button clicked yet") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "display") + props = tool["inputSchema"]["properties"] + assert "btn1" in props + assert "btn2" in props + assert "btn3" in props + + result = _mcp_call_tool( + dash_duo.server.url, + "display", + {"btn1": None, "btn2": 1, "btn3": None}, + ) + response = _get_response(result) + assert response["output"]["children"] == "Last clicked: btn-2" + + result = _mcp_call_tool( + dash_duo.server.url, + "display", + {"btn1": None, "btn2": None, "btn3": 5}, + ) + response = _get_response(result) + assert response["output"]["children"] == "Last clicked: btn-3" + + +def test_mcpb018_no_output_callback_does_not_crash_tools_list(dash_duo): + """A callback with no Output should not crash tools/list. + + No-output callbacks use set_props for side effects. They produce + a hash-only output_id with no dot separator. + """ + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button("Log", id="log-btn"), + dcc.Dropdown(id="picker", options=["a", "b"], value="a"), + html.Div(id="display"), + ] + ) + + @app.callback(Input("log-btn", "n_clicks"), prevent_initial_call=True) + def log_click(n): + set_props("display", {"children": f"Logged {n} clicks"}) + + @app.callback(Output("display", "children"), Input("picker", "value")) + def show_selection(val): + return f"Selected: {val}" + + dash_duo.start_server(app) + + tools = _mcp_tools(dash_duo.server.url) + tool_names = [t["name"] for t in tools] + + assert "show_selection" in tool_names + assert "log_click" in tool_names + + result = _mcp_call_tool( + dash_duo.server.url, + "log_click", + {"n": 3}, + ) + structured = result["result"]["structuredContent"] + assert "sideUpdate" in structured + assert structured["sideUpdate"]["display"]["children"] == "Logged 3 clicks" + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "display", "property": "children"}, + ) + prop_info = result["result"]["structuredContent"]["properties"]["children"] + assert "show_selection" in prop_info["modified_by_tool"] + + +# --------------------------------------------------------------------------- +# Duplicate outputs (allow_duplicate=True) +# --------------------------------------------------------------------------- + + +def test_mcpb019_duplicate_outputs_both_tools_listed(dash_duo): + """Both callbacks outputting to the same component appear as tools.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + first_tool = _find_tool(tools, "greet_by_first") + last_tool = _find_tool(tools, "greet_by_last") + + assert first_tool is not None, "greet_by_first should be listed" + assert last_tool is not None, "greet_by_last should be listed" + + +def test_mcpb020_duplicate_outputs_both_callable(dash_duo): + """Both callbacks can be called and produce correct results.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + result1 = _mcp_call_tool(dash_duo.server.url, "greet_by_first", {"first": "Alice"}) + assert _get_response(result1)["greeting"]["children"] == "Hello, Alice!" + + result2 = _mcp_call_tool(dash_duo.server.url, "greet_by_last", {"last": "Smith"}) + assert _get_response(result2)["greeting"]["children"] == "Hi, Smith!" + + +def test_mcpb021_duplicate_outputs_find_by_output_returns_primary(dash_duo): + """find_by_output returns the primary (non-duplicate) callback.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "greeting", "property": "children"}, + ) + structured = result["result"]["structuredContent"] + assert structured["properties"]["children"]["initial_value"] == "Hello, Jane!" + + +# --------------------------------------------------------------------------- +# tools/list — naming rules (64-char limit, uniqueness, built-ins) +# --------------------------------------------------------------------------- + + +def test_mcpb022_tool_names_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + for param_name in tool.get("inputSchema", {}).get("properties", {}): + assert len(param_name) <= 64, f"Param name exceeds 64 chars: {param_name}" + + +def test_mcpb023_long_callback_ids_within_64_chars(dash_duo): + app = Dash(__name__) + long_id = "a" * 120 + app.layout = html.Div( + [ + dcc.Input(id=long_id, value="test"), + html.Div(id=f"{long_id}-output"), + ] + ) + + @app.callback(Output(f"{long_id}-output", "children"), Input(long_id, "value")) + def process(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_mcpb024_pattern_matching_ids_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div( + [ + dcc.Input( + id={"type": "filter-input", "index": i, "category": "primary"}, + value=f"val-{i}", + ) + for i in range(3) + ] + ), + html.Div(id="pm-output"), + ] + ) + + @app.callback( + Output("pm-output", "children"), + Input({"type": "filter-input", "index": 0, "category": "primary"}, "value"), + ) + def filter_update(v0): + return str(v0) + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_mcpb025_duplicate_func_names_produce_unique_tools(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd1", options=["a"], value="a"), + html.Div(id="dd1-output"), + dcc.Dropdown(id="dd2", options=["b"], value="b"), + html.Div(id="dd2-output"), + dcc.Dropdown(id="dd3", options=["c"], value="c"), + html.Div(id="dd3-output"), + ] + ) + + @app.callback(Output("dd1-output", "children"), Input("dd1", "value")) + def cb(value): + return f"first: {value}" + + @app.callback(Output("dd2-output", "children"), Input("dd2", "value")) + def cb(value): # noqa: F811 + return f"second: {value}" + + @app.callback(Output("dd3-output", "children"), Input("dd3", "value")) + def cb(value): # noqa: F811 + return f"third: {value}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + cb_tools = [t for t in tools if t["name"] not in ("get_dash_component",)] + tool_names = [t["name"] for t in cb_tools] + + assert ( + len(tool_names) == 3 + ), f"Expected 3 callback tools, got {len(tool_names)}: {tool_names}" + assert len(set(tool_names)) == 3, f"Tool names not unique: {tool_names}" + + +def test_mcpb026_builtin_tools_always_present(dash_duo): + app = Dash(__name__) + app.layout = html.Div(id="root") + + dash_duo.start_server(app) + tool_names = [t["name"] for t in _mcp_tools(dash_duo.server.url)] + assert "get_dash_component" in tool_names + + +# --------------------------------------------------------------------------- +# Input schema smoke test + get_dash_component HTTP structured output +# --------------------------------------------------------------------------- + + +def test_mcpb027_mcp_tool_with_label_and_date_picker_schema(dash_duo): + """Full assertion on a tool with an html.Label and DatePickerSingle constraints.""" + label_text = "Departure Date" + component_id = "dp" + min_date = "2020-01-01" + max_date = "2025-12-31" + default_date = "2024-06-15" + func_name = "select_date" + param_name = "date" # function parameter name + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label(label_text, htmlFor=component_id), + dcc.DatePickerSingle( + id=component_id, + min_date_allowed=min_date, + max_date_allowed=max_date, + date=default_date, + ), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input(component_id, "date")) + def select_date(date): + return f"Selected: {date}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + tool = next(t for t in tools if t["name"] not in ("get_dash_component",)) + + assert func_name in tool["name"] + + schema = tool["inputSchema"] + assert schema["type"] == "object" + assert param_name in schema["required"] + assert param_name in schema["properties"] + + prop = schema["properties"][param_name] + assert prop["type"] == "string" + assert prop["format"] == "date" + + desc = prop["description"] + for expected in (label_text, min_date, max_date, default_date): + assert expected in desc, f"Expected {expected!r} in description: {desc!r}" + + +EXPECTED_DROPDOWN_OPTIONS = { + "component_id": "my-dropdown", + "component_type": "Dropdown", + "label": None, + "properties": { + "options": { + "initial_value": [ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + "modified_by_tool": [], + "input_to_tool": [], + }, + }, +} + + +def test_mcpb028_query_component_returns_structured_output(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + id="my-dropdown", + options=[ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + value="NYC", + ), + ] + ) + + dash_duo.start_server(app) + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "my-dropdown", "property": "options"}, + ) + + assert "result" in result, f"Expected result in response: {result}" + structured = result["result"]["structuredContent"] + assert structured["component_id"] == EXPECTED_DROPDOWN_OPTIONS["component_id"] + assert structured["component_type"] == EXPECTED_DROPDOWN_OPTIONS["component_type"] + assert ( + structured["properties"]["options"] + == EXPECTED_DROPDOWN_OPTIONS["properties"]["options"] + ) diff --git a/tests/integration/mcp/test_mcp_endpoint.py b/tests/integration/mcp/test_mcp_endpoint.py new file mode 100644 index 0000000000..44b358c25d --- /dev/null +++ b/tests/integration/mcp/test_mcp_endpoint.py @@ -0,0 +1,189 @@ +"""MCP Streamable HTTP endpoint — transport-layer behavior. + +Uses Flask's test_client to exercise POST/GET/DELETE at /_mcp, +session management, content-type handling, and route registration +driven by ``enable_mcp`` / ``DASH_MCP_ENABLED`` / ``routes_pathname_prefix``. +""" + +import json +import os + +from dash import Dash, Input, Output, html +from mcp.types import LATEST_PROTOCOL_VERSION + +MCP_PATH = "_mcp" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_app(**kwargs): + """Create a minimal Dash app with a layout and one callback.""" + app = Dash(__name__, **kwargs) + app.layout = html.Div( + [ + html.Div(id="my-input"), + html.Div(id="my-output"), + ] + ) + + @app.callback(Output("my-output", "children"), Input("my-input", "children")) + def update_output(value): + """Test callback docstring.""" + return f"echo: {value}" + + return app + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpe001_post_initialize_returns_protocol_version(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + data = json.loads(r.data) + assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + + +def test_mcpe002_post_tools_list(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + data = json.loads(r.data) + assert "result" in data + assert "tools" in data["result"] + + +def test_mcpe003_notification_returns_202(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized"}), + content_type="application/json", + ) + assert r.status_code == 202 + + +def test_mcpe004_delete_returns_405(): + app = _make_app() + client = app.server.test_client() + r = client.delete(f"/{MCP_PATH}") + assert r.status_code == 405 + + +def test_mcpe005_get_returns_405(): + app = _make_app() + client = app.server.test_client() + r = client.get(f"/{MCP_PATH}") + assert r.status_code == 405 + + +def test_mcpe006_post_rejects_wrong_content_type(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data="not json", + content_type="text/plain", + ) + assert r.status_code == 415 + + +def test_mcpe007_routes_not_registered_when_disabled(): + app = _make_app(enable_mcp=False) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + # With MCP disabled, the route doesn't exist — response is HTML, not JSON + assert r.content_type != "application/json" + + +def test_mcpe008_routes_respect_pathname_prefix(): + app = _make_app(routes_pathname_prefix="/app/") + client = app.server.test_client() + + ok = client.post( + f"/app/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert ok.status_code == 200 + + miss = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert miss.status_code == 404 + + +def test_mcpe009_enable_mcp_env_var_false(): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.content_type != "application/json" + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old + + +def test_mcpe010_constructor_overrides_env_var(): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app(enable_mcp=True) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert b"protocolVersion" in r.data + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old diff --git a/tests/integration/mcp/test_mcp_resources.py b/tests/integration/mcp/test_mcp_resources.py new file mode 100644 index 0000000000..41519578d1 --- /dev/null +++ b/tests/integration/mcp/test_mcp_resources.py @@ -0,0 +1,51 @@ +"""MCP resources — ``resources/list`` and ``resources/read`` via HTTP.""" + +import json + +from dash import Dash, dcc, html + +from tests.integration.mcp.conftest import _mcp_method + + +def test_mcpz001_resources_list_includes_layout(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"], value="a"), + html.Div(id="out"), + ] + ) + + dash_duo.start_server(app) + result = _mcp_method(dash_duo.server.url, "resources/list") + + assert "result" in result + uris = [r["uri"] for r in result["result"]["resources"]] + assert "dash://layout" in uris + + +def test_mcpz002_read_layout_resource(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="res-dd", options=["x", "y"], value="x"), + html.Div(id="out"), + ] + ) + + dash_duo.start_server(app) + result = _mcp_method( + dash_duo.server.url, + "resources/read", + {"uri": "dash://layout"}, + ) + + assert "result" in result + layout = json.loads(result["result"]["contents"][0]["text"]) + assert layout["type"] == "Div" + children = layout["props"]["children"] + dd = next( + c for c in children if isinstance(c, dict) and c.get("type") == "Dropdown" + ) + assert dd["props"]["id"] == "res-dd" + assert dd["props"]["options"] == ["x", "y"] diff --git a/tests/integration/mcp/test_mcp_session.py b/tests/integration/mcp/test_mcp_session.py new file mode 100644 index 0000000000..f3cc2bfbd8 --- /dev/null +++ b/tests/integration/mcp/test_mcp_session.py @@ -0,0 +1,128 @@ +"""MCP session lifecycle — end-to-end over a real Dash server. + +Exercises the full MCP session flow (initialize → operate → hot-reload +recovery) against a live ``dash_duo`` server using real HTTP requests. +Unit-level checks (status codes, header mechanics) live in +``tests/unit/mcp/test_mcp_session.py``; these tests verify the broader +behavioral contract. +""" + +import requests + +from dash import Dash, Input, Output, html + +from tests.integration.mcp.conftest import _mcp_post + + +def _mcp_post_with_session( + server_url, method, params=None, request_id=1, session_id=None +): + """Like ``_mcp_post`` but forwards an ``Mcp-Session-Id`` header.""" + headers = {"Content-Type": "application/json"} + if session_id is not None: + headers["Mcp-Session-Id"] = session_id + return requests.post( + f"{server_url}/_mcp", + json={ + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params or {}, + }, + headers=headers, + timeout=5, + ) + + +def test_mcpse_e2e001_full_session_lifecycle(dash_duo): + """Initialize → tools/list → tools/call with session headers throughout.""" + app = Dash(__name__) + app.layout = html.Div([html.Div(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "children")) + def echo(v): + return f"echo: {v}" + + dash_duo.start_server(app) + url = dash_duo.server.url + + init = _mcp_post_with_session(url, "initialize") + assert init.status_code == 200 + sid = init.headers.get("Mcp-Session-Id") + assert sid + + notif = _mcp_post_with_session( + url, "notifications/initialized", session_id=sid, request_id=None + ) + assert notif.status_code == 202 + + tools_resp = _mcp_post_with_session(url, "tools/list", session_id=sid, request_id=2) + assert tools_resp.status_code == 200 + tools = tools_resp.json()["result"]["tools"] + assert any("echo" in t["name"] for t in tools) + + tool_name = next(t["name"] for t in tools if "echo" in t["name"]) + call_resp = _mcp_post_with_session( + url, + "tools/call", + params={"name": tool_name, "arguments": {"v": "hello"}}, + session_id=sid, + request_id=3, + ) + assert call_resp.status_code == 200 + assert call_resp.headers.get("Mcp-Session-Id") == sid + + +def test_mcpse_e2e002_stale_session_recovers_with_notifications(dash_duo): + """Simulate a hot-reload hash change and verify transparent recovery.""" + app = Dash(__name__) + app.layout = html.Div([html.Div(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "children")) + def echo(v): + return f"echo: {v}" + + dash_duo.start_server(app) + url = dash_duo.server.url + + app._hot_reload.hash = "original_hash" + + init = _mcp_post_with_session(url, "initialize") + sid = init.headers["Mcp-Session-Id"] + assert sid == "original_hash" + + resp = _mcp_post_with_session(url, "tools/list", session_id=sid, request_id=2) + assert resp.status_code == 200 + + app._hot_reload.hash = "new_hash" + + resp = _mcp_post_with_session(url, "tools/list", session_id=sid, request_id=3) + assert resp.status_code == 200 + new_sid = resp.headers["Mcp-Session-Id"] + assert new_sid == "new_hash" + + data = resp.json() + assert isinstance(data, list) + assert len(data) == 3 + assert data[0]["method"] == "notifications/tools/list_changed" + assert data[1]["method"] == "notifications/resources/list_changed" + assert "result" in data[2] + assert "tools" in data[2]["result"] + + resp = _mcp_post_with_session(url, "tools/list", session_id=new_sid, request_id=4) + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, dict) + assert "result" in data + + +def test_mcpse_e2e003_capabilities_advertise_list_changed(dash_duo): + """Server capabilities include listChanged for tools and resources.""" + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + resp = _mcp_post(dash_duo.server.url, "initialize") + caps = resp.json()["result"]["capabilities"] + assert caps["tools"]["listChanged"] is True + assert caps["resources"]["listChanged"] is True diff --git a/tests/unit/development/metadata_test.py b/tests/unit/development/metadata_test.py index 0ee96efdeb..e2938af338 100644 --- a/tests/unit/development/metadata_test.py +++ b/tests/unit/development/metadata_test.py @@ -1,7 +1,7 @@ # AUTO GENERATED FILE - DO NOT EDIT import typing # noqa: F401 -from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 +from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 from dash.development.base_component import Component, _explicitize_args ComponentType = typing.Union[ @@ -20,126 +20,120 @@ class Table(Component): """A Table component. -This is a description of the component. -It's multiple lines long. + This is a description of the component. + It's multiple lines long. -Keyword arguments: + Keyword arguments: -- children (a list of or a singular dash component, string or number; optional) + - children (a list of or a singular dash component, string or number; optional) -- id (string; optional) + - id (string; optional) -- aria-* (string; optional) + - aria-* (string; optional) -- customArrayProp (list; optional) + - customArrayProp (list; optional) -- customProp (optional) + - customProp (optional) -- data-* (string; optional) + - data-* (string; optional) -- in (string; optional) + - in (string; optional) -- optionalAny (boolean | number | string | dict | list; optional) + - optionalAny (boolean | number | string | dict | list; optional) -- optionalArray (list; optional): - Description of optionalArray. + - optionalArray (list; optional): + Description of optionalArray. -- optionalArrayOf (list of numbers; optional) + - optionalArrayOf (list of numbers; optional) -- optionalBool (boolean; optional) + - optionalBool (boolean; optional) -- optionalElement (dash component; optional) + - optionalElement (dash component; optional) -- optionalEnum (a value equal to: 'News', 'Photos'; optional) + - optionalEnum (a value equal to: 'News', 'Photos'; optional) -- optionalNode (a list of or a singular dash component, string or number; optional) + - optionalNode (a list of or a singular dash component, string or number; optional) -- optionalNumber (number; default 42) + - optionalNumber (number; default 42) -- optionalObject (dict; optional) + - optionalObject (dict; optional) -- optionalObjectOf (dict with strings as keys and values of type number; optional) + - optionalObjectOf (dict with strings as keys and values of type number; optional) -- optionalObjectWithExactAndNestedDescription (dict; optional) + - optionalObjectWithExactAndNestedDescription (dict; optional) - `optionalObjectWithExactAndNestedDescription` is a dict with keys: + `optionalObjectWithExactAndNestedDescription` is a dict with keys: - - color (string; optional) + - color (string; optional) - - fontSize (number; optional) + - fontSize (number; optional) - - figure (dict; optional): - Figure is a plotly graph object. + - figure (dict; optional): + Figure is a plotly graph object. - `figure` is a dict with keys: + `figure` is a dict with keys: - - data (list of dicts; optional): - data is a collection of traces. + - data (list of dicts; optional): + data is a collection of traces. - - layout (dict; optional): - layout describes the rest of the figure. + - layout (dict; optional): + layout describes the rest of the figure. -- optionalObjectWithShapeAndNestedDescription (dict; optional) + - optionalObjectWithShapeAndNestedDescription (dict; optional) - `optionalObjectWithShapeAndNestedDescription` is a dict with keys: + `optionalObjectWithShapeAndNestedDescription` is a dict with keys: - - color (string; optional) + - color (string; optional) - - fontSize (number; optional) + - fontSize (number; optional) - - figure (dict; optional): - Figure is a plotly graph object. + - figure (dict; optional): + Figure is a plotly graph object. - `figure` is a dict with keys: + `figure` is a dict with keys: - - data (list of dicts; optional): - data is a collection of traces. + - data (list of dicts; optional): + data is a collection of traces. - - layout (dict; optional): - layout describes the rest of the figure. + - layout (dict; optional): + layout describes the rest of the figure. -- optionalString (string; default 'hello world') + - optionalString (string; default 'hello world') -- optionalUnion (string | number; optional)""" - _children_props = ['optionalNode', 'optionalElement'] - _base_nodes = ['optionalNode', 'optionalElement', 'children'] - _namespace = 'TableComponents' - _type = 'Table' + - optionalUnion (string | number; optional)""" + + _children_props = ["optionalNode", "optionalElement"] + _base_nodes = ["optionalNode", "optionalElement", "children"] + _namespace = "TableComponents" + _type = "Table" OptionalObjectWithExactAndNestedDescriptionFigure = TypedDict( "OptionalObjectWithExactAndNestedDescriptionFigure", - { - "data": NotRequired[typing.Sequence[dict]], - "layout": NotRequired[dict] - } + {"data": NotRequired[typing.Sequence[dict]], "layout": NotRequired[dict]}, ) OptionalObjectWithExactAndNestedDescription = TypedDict( "OptionalObjectWithExactAndNestedDescription", - { + { "color": NotRequired[str], "fontSize": NotRequired[NumberType], - "figure": NotRequired["OptionalObjectWithExactAndNestedDescriptionFigure"] - } + "figure": NotRequired["OptionalObjectWithExactAndNestedDescriptionFigure"], + }, ) OptionalObjectWithShapeAndNestedDescriptionFigure = TypedDict( "OptionalObjectWithShapeAndNestedDescriptionFigure", - { - "data": NotRequired[typing.Sequence[dict]], - "layout": NotRequired[dict] - } + {"data": NotRequired[typing.Sequence[dict]], "layout": NotRequired[dict]}, ) OptionalObjectWithShapeAndNestedDescription = TypedDict( "OptionalObjectWithShapeAndNestedDescription", - { + { "color": NotRequired[str], "fontSize": NotRequired[NumberType], - "figure": NotRequired["OptionalObjectWithShapeAndNestedDescriptionFigure"] - } + "figure": NotRequired["OptionalObjectWithShapeAndNestedDescriptionFigure"], + }, ) - def __init__( self, children: typing.Optional[ComponentType] = None, @@ -154,26 +148,79 @@ def __init__( optionalElement: typing.Optional[Component] = None, optionalMessage: typing.Optional[typing.Any] = None, optionalEnum: typing.Optional[Literal["News", "Photos"]] = None, - optionalUnion: typing.Optional[typing.Union[str, NumberType, typing.Any]] = None, + optionalUnion: typing.Optional[ + typing.Union[str, NumberType, typing.Any] + ] = None, optionalArrayOf: typing.Optional[typing.Sequence[NumberType]] = None, - optionalObjectOf: typing.Optional[typing.Dict[typing.Union[str, float, int], NumberType]] = None, - optionalObjectWithExactAndNestedDescription: typing.Optional["OptionalObjectWithExactAndNestedDescription"] = None, - optionalObjectWithShapeAndNestedDescription: typing.Optional["OptionalObjectWithShapeAndNestedDescription"] = None, + optionalObjectOf: typing.Optional[ + typing.Dict[typing.Union[str, float, int], NumberType] + ] = None, + optionalObjectWithExactAndNestedDescription: typing.Optional[ + "OptionalObjectWithExactAndNestedDescription" + ] = None, + optionalObjectWithShapeAndNestedDescription: typing.Optional[ + "OptionalObjectWithShapeAndNestedDescription" + ] = None, optionalAny: typing.Optional[typing.Any] = None, customProp: typing.Optional[typing.Any] = None, customArrayProp: typing.Optional[typing.Sequence[typing.Any]] = None, id: typing.Optional[typing.Union[str, dict]] = None, **kwargs ): - self._prop_names = ['children', 'id', 'aria-*', 'customArrayProp', 'customProp', 'data-*', 'in', 'optionalAny', 'optionalArray', 'optionalArrayOf', 'optionalBool', 'optionalElement', 'optionalEnum', 'optionalNode', 'optionalNumber', 'optionalObject', 'optionalObjectOf', 'optionalObjectWithExactAndNestedDescription', 'optionalObjectWithShapeAndNestedDescription', 'optionalString', 'optionalUnion'] - self._valid_wildcard_attributes = ['data-', 'aria-'] - self.available_properties = ['children', 'id', 'aria-*', 'customArrayProp', 'customProp', 'data-*', 'in', 'optionalAny', 'optionalArray', 'optionalArrayOf', 'optionalBool', 'optionalElement', 'optionalEnum', 'optionalNode', 'optionalNumber', 'optionalObject', 'optionalObjectOf', 'optionalObjectWithExactAndNestedDescription', 'optionalObjectWithShapeAndNestedDescription', 'optionalString', 'optionalUnion'] - self.available_wildcard_properties = ['data-', 'aria-'] - _explicit_args = kwargs.pop('_explicit_args') + self._prop_names = [ + "children", + "id", + "aria-*", + "customArrayProp", + "customProp", + "data-*", + "in", + "optionalAny", + "optionalArray", + "optionalArrayOf", + "optionalBool", + "optionalElement", + "optionalEnum", + "optionalNode", + "optionalNumber", + "optionalObject", + "optionalObjectOf", + "optionalObjectWithExactAndNestedDescription", + "optionalObjectWithShapeAndNestedDescription", + "optionalString", + "optionalUnion", + ] + self._valid_wildcard_attributes = ["data-", "aria-"] + self.available_properties = [ + "children", + "id", + "aria-*", + "customArrayProp", + "customProp", + "data-*", + "in", + "optionalAny", + "optionalArray", + "optionalArrayOf", + "optionalBool", + "optionalElement", + "optionalEnum", + "optionalNode", + "optionalNumber", + "optionalObject", + "optionalObjectOf", + "optionalObjectWithExactAndNestedDescription", + "optionalObjectWithShapeAndNestedDescription", + "optionalString", + "optionalUnion", + ] + self.available_wildcard_properties = ["data-", "aria-"] + _explicit_args = kwargs.pop("_explicit_args") _locals = locals() _locals.update(kwargs) # For wildcard attrs and excess named props - args = {k: _locals[k] for k in _explicit_args if k != 'children'} + args = {k: _locals[k] for k in _explicit_args if k != "children"} super(Table, self).__init__(children=children, **args) + setattr(Table, "__init__", _explicitize_args(Table.__init__)) diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py new file mode 100644 index 0000000000..a3ddd191aa --- /dev/null +++ b/tests/unit/mcp/conftest.py @@ -0,0 +1,104 @@ +"""Shared helpers for MCP unit tests.""" + +import sys + +from dash import Dash, Input, Output, html +from dash._get_app import app_context + +collect_ignore_glob = [] +if sys.version_info < (3, 10): + collect_ignore_glob.append("*") +else: + from dash.mcp._decorator import ( # pylint: disable=wrong-import-position + MCP_DECORATED_FUNCTIONS, + ) + from dash.mcp.primitives.tools import ( # pylint: disable=wrong-import-position + call_tool, + list_tools, + ) + from dash.mcp.primitives.tools.callback_adapter_collection import ( # pylint: disable=wrong-import-position + CallbackAdapterCollection, + ) + +BUILTINS = {"get_dash_component"} + + +def _setup_mcp(app): + """Set up MCP for an app in tests.""" + app_context.set(app) + if MCP_DECORATED_FUNCTIONS: + app.mcp_decorated_functions = dict(MCP_DECORATED_FUNCTIONS) + MCP_DECORATED_FUNCTIONS.clear() + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +def _make_app(**kwargs): + """Create a minimal Dash app with a layout and one callback.""" + app = Dash(__name__, **kwargs) + app.layout = html.Div( + [ + html.Div(id="my-input"), + html.Div(id="my-output"), + ] + ) + + @app.callback( + Output("my-output", "children"), + Input("my-input", "children"), + mcp_expose_docstring=True, + ) + def update_output(value): + """Test callback docstring.""" + return f"echo: {value}" + + return _setup_mcp(app) + + +def _tools_list(app): + """Return all tools (callbacks + builtins) as Tool objects.""" + _setup_mcp(app) + with app.server.test_request_context(): + return list_tools().tools + + +def _user_tool(tools): + """Return the first tool that isn't a builtin.""" + return next(t for t in tools if t.name not in BUILTINS) + + +def _app_with_callback(component, input_prop="value", output_id="out"): + """Create a Dash app with one callback using ``component`` as Input.""" + app = Dash(__name__) + app.layout = html.Div([component, html.Div(id=output_id)]) + + @app.callback(Output(output_id, "children"), Input(component.id, input_prop)) + def update(val): + return f"got: {val}" + + return _setup_mcp(app) + + +def _schema_for(tool, param_name=None): + """Extract the JSON schema dict for a parameter, without description.""" + props = tool.inputSchema["properties"] + if param_name is None: + param_name = next(iter(props)) + schema = dict(props[param_name]) + schema.pop("description", None) + return schema + + +def _desc_for(tool, param_name=None): + """Extract the description string for a parameter, or ''.""" + props = tool.inputSchema["properties"] + if param_name is None: + param_name = next(iter(props)) + return props[param_name].get("description", "") + + +def _call_tool(app, tool_name, arguments=None): + """Call a tool via the dispatch pipeline and return the CallToolResult.""" + _setup_mcp(app) + with app.server.test_request_context(): + return call_tool(tool_name, arguments or {}) diff --git a/tests/unit/mcp/primitives/resources/test_resource_clientside_callbacks.py b/tests/unit/mcp/primitives/resources/test_resource_clientside_callbacks.py new file mode 100644 index 0000000000..3ba2ce7996 --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_clientside_callbacks.py @@ -0,0 +1,54 @@ +"""Tests for the dash://clientside-callbacks resource.""" + +import json + +from dash import Dash, Input, Output, clientside_callback, html + +from dash.mcp.primitives.resources import list_resources, read_resource + + +class TestClientsideCallbacksResource: + @staticmethod + def _make_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn", children="Click"), + html.Div(id="out"), + html.Div(id="server-out"), + ] + ) + + clientside_callback( + "function(n) { return n; }", + Output("out", "children"), + Input("btn", "n_clicks"), + ) + + @app.callback(Output("server-out", "children"), Input("btn", "n_clicks")) + def server_cb(n): + return str(n) + + with app.server.test_request_context(): + app._setup_server() + + return app + + def test_resource_listed(self): + app = self._make_app() + with app.server.test_request_context(): + result = list_resources() + uris = [str(r.uri) for r in result.resources] + assert "dash://clientside-callbacks" in uris + + def test_resource_read(self): + app = self._make_app() + with app.server.test_request_context(): + result = read_resource("dash://clientside-callbacks") + data = json.loads(result.contents[0].text) + assert "description" in data + callbacks = data["callbacks"] + assert len(callbacks) == 1 + assert callbacks[0]["inputs"][0]["component_id"] == "btn" + assert callbacks[0]["inputs"][0]["property"] == "n_clicks" + assert callbacks[0]["outputs"][0]["component_id"] == "out" diff --git a/tests/unit/mcp/primitives/resources/test_resource_layout.py b/tests/unit/mcp/primitives/resources/test_resource_layout.py new file mode 100644 index 0000000000..ade207b1f3 --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_layout.py @@ -0,0 +1,59 @@ +"""Tests for the dash://layout resource.""" + +import json +from unittest.mock import patch + +from dash import Dash, dcc, html + +from dash.mcp.primitives.resources import list_resources, read_resource + +EXPECTED_LAYOUT = { + "type": "Div", + "namespace": "dash_html_components", + "props": { + "children": [ + { + "type": "Dropdown", + "namespace": "dash_core_components", + "props": { + "id": "test-dd", + "options": ["a", "b"], + "value": "a", + }, + }, + { + "type": "Div", + "namespace": "dash_html_components", + "props": { + "children": None, + "id": "output", + }, + }, + ] + }, +} + + +class TestLayoutResource: + def test_listed_in_resources(self): + app = Dash(__name__) + app.layout = html.Div(id="main") + with app.server.test_request_context(): + result = list_resources() + uris = [str(r.uri) for r in result.resources] + assert "dash://layout" in uris + + def test_read_returns_layout(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="test-dd", options=["a", "b"], value="a"), + html.Div(id="output"), + ] + ) + with app.server.test_request_context(): + with patch.object(app, "get_layout", wraps=app.get_layout) as mock: + result = read_resource("dash://layout") + mock.assert_called_once() + layout = json.loads(result.contents[0].text) + assert layout == EXPECTED_LAYOUT diff --git a/tests/unit/mcp/primitives/resources/test_resource_page_layout.py b/tests/unit/mcp/primitives/resources/test_resource_page_layout.py new file mode 100644 index 0000000000..f4e9caac5d --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_page_layout.py @@ -0,0 +1,55 @@ +"""Tests for the dash://page-layout/{path} resource template.""" + +import json +from unittest.mock import patch + +from dash import Dash, dcc, html + +from dash.mcp.primitives.resources import read_resource + +EXPECTED_PAGE_LAYOUT = { + "type": "Div", + "namespace": "dash_html_components", + "props": { + "children": [ + { + "type": "Dropdown", + "namespace": "dash_core_components", + "props": { + "id": "page-dd", + "options": ["a", "b"], + "value": "a", + }, + } + ] + }, +} + + +class TestPageLayoutResource: + def test_read_page_layout(self): + app = Dash(__name__) + app.layout = html.Div(id="main") + + page_layout = html.Div( + [ + dcc.Dropdown(id="page-dd", options=["a", "b"], value="a"), + ] + ) + fake_registry = { + "pages.test": { + "path": "/test", + "name": "Test", + "title": "Test Page", + "description": "", + "layout": page_layout, + }, + } + with app.server.test_request_context(): + with patch( + "dash.mcp.primitives.resources.resource_page_layout.PAGE_REGISTRY", + fake_registry, + ): + result = read_resource("dash://page-layout/test") + layout = json.loads(result.contents[0].text) + assert layout == EXPECTED_PAGE_LAYOUT diff --git a/tests/unit/mcp/primitives/resources/test_resource_pages.py b/tests/unit/mcp/primitives/resources/test_resource_pages.py new file mode 100644 index 0000000000..22e6e798fc --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_pages.py @@ -0,0 +1,87 @@ +"""Tests for the dash://pages resource.""" + +import json +from unittest.mock import patch + +from dash import Dash, html + +from dash.mcp.primitives.resources import list_resources, read_resource + +EXPECTED_PAGES = [ + { + "path": "/", + "name": "Home", + "title": "Home Page", + "description": "The landing page", + "module": "pages.home", + }, + { + "path": "/analytics", + "name": "Analytics", + "title": "Analytics Dashboard", + "description": "View analytics", + "module": "pages.analytics", + }, +] + + +class TestPagesResource: + @staticmethod + def _make_app(): + app = Dash(__name__) + app.layout = html.Div(id="main") + return app + + def test_listed_for_multi_page_app(self): + app = self._make_app() + fake_registry = { + "pages.home": { + "path": "/", + "name": "Home", + "title": "Home", + "description": "", + } + } + with app.server.test_request_context(): + with patch( + "dash.mcp.primitives.resources.resource_pages.PAGE_REGISTRY", + fake_registry, + ): + result = list_resources() + uris = [str(r.uri) for r in result.resources] + assert "dash://pages" in uris + + def test_returns_page_info(self): + app = self._make_app() + fake_registry = { + "pages.home": EXPECTED_PAGES[0], + "pages.analytics": EXPECTED_PAGES[1], + } + with app.server.test_request_context(): + with patch( + "dash.mcp.primitives.resources.resource_pages.PAGE_REGISTRY", + fake_registry, + ): + result = read_resource("dash://pages") + content = json.loads(result.contents[0].text) + assert content == EXPECTED_PAGES + + def test_callable_title_falls_back_to_name(self): + app = self._make_app() + fake_registry = { + "pages.dynamic": { + "path": "/item/", + "name": "Item Detail", + "title": lambda **kwargs: f"Item {kwargs.get('item_id', '')}", + "description": lambda **kwargs: f"Details for {kwargs.get('item_id', '')}", + }, + } + with app.server.test_request_context(): + with patch( + "dash.mcp.primitives.resources.resource_pages.PAGE_REGISTRY", + fake_registry, + ): + result = read_resource("dash://pages") + page = json.loads(result.contents[0].text)[0] + assert page["title"] == "Item Detail" + assert page["description"] == "" diff --git a/tests/unit/mcp/test_mcp_server.py b/tests/unit/mcp/test_mcp_server.py new file mode 100644 index 0000000000..f4bb595dce --- /dev/null +++ b/tests/unit/mcp/test_mcp_server.py @@ -0,0 +1,99 @@ +"""MCP server JSON-RPC message processing (``_process_mcp_message``).""" + +from dash._get_app import app_context +from dash.mcp._server import _process_mcp_message +from mcp.types import LATEST_PROTOCOL_VERSION + +from tests.unit.mcp.conftest import _make_app, _setup_mcp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool(app, tool_name, arguments=None, request_id=1): + return _mcp( + app, "tools/call", {"name": tool_name, "arguments": arguments or {}}, request_id + ) + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + result = _call_tool(app, tool_name, arguments) + structured = result["result"]["structuredContent"] + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcps001_initialize(): + app = _make_app() + result = _mcp(app, "initialize") + + assert result is not None + assert result["id"] == 1 + assert result["jsonrpc"] == "2.0" + assert result["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + assert "serverInfo" in result["result"] + + +def test_mcps002_tools_call(): + app = _make_app() + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update_output" in t["name"]) + + result = _call_tool(app, tool_name, {"value": "hello"}, request_id=2) + + assert result is not None + assert result["id"] == 2 + assert _call_tool_output(app, tool_name, {"value": "hello"}) == "echo: hello" + + +def test_mcps003_tools_call_unknown_tool_returns_error(): + app = _make_app() + result = _call_tool(app, "nonexistent_tool") + + assert result is not None + assert "error" in result + assert result["error"]["code"] == -32601 + + +def test_mcps004_unknown_method_returns_error(): + app = _make_app() + result = _mcp(app, "unknown/method") + + assert result is not None + assert "error" in result + + +def test_mcps005_notification_returns_none(): + app = _make_app() + data = {"jsonrpc": "2.0", "method": "notifications/initialized"} + with app.server.test_request_context(): + app_context.set(app) + result = _process_mcp_message(data) + assert result is None diff --git a/tests/unit/mcp/test_mcp_session.py b/tests/unit/mcp/test_mcp_session.py new file mode 100644 index 0000000000..0c878eb3da --- /dev/null +++ b/tests/unit/mcp/test_mcp_session.py @@ -0,0 +1,156 @@ +"""MCP session management — ``Mcp-Session-Id`` header and hot-reload hash.""" + +import json +import sys + +import pytest + +if sys.version_info < (3, 10): + pytest.skip("MCP requires Python 3.10+", allow_module_level=True) + +from tests.unit.mcp.conftest import _make_app # pylint: disable=wrong-import-position + + +def _make_mcp_app(**kwargs): + return _make_app(enable_mcp=True, **kwargs) + + +def _post(client, method, params=None, request_id=1, session_id=None): + """POST a JSON-RPC message to the MCP endpoint.""" + headers = {"Content-Type": "application/json"} + if session_id is not None: + headers["Mcp-Session-Id"] = session_id + body = {"jsonrpc": "2.0", "method": method, "id": request_id} + body["params"] = params if params is not None else {} + return client.post("/_mcp", data=json.dumps(body), headers=headers) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpse001_initialize_returns_session_id(): + app = _make_mcp_app() + with app.server.test_client() as client: + resp = _post(client, "initialize") + assert resp.status_code == 200 + session_id = resp.headers.get("Mcp-Session-Id") + assert session_id is not None + assert len(session_id) > 0 + + +def test_mcpse003_request_with_valid_session_succeeds(): + app = _make_mcp_app() + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + session_id = init_resp.headers["Mcp-Session-Id"] + resp = _post(client, "tools/list", session_id=session_id) + assert resp.status_code == 200 + data = json.loads(resp.data) + assert "result" in data + + +def test_mcpse004_stale_session_recovers_transparently(): + app = _make_mcp_app() + app._hot_reload.hash = "hash_v1" + + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + old_session = init_resp.headers["Mcp-Session-Id"] + assert old_session == "hash_v1" + + app._hot_reload.hash = "hash_v2" + + resp = _post(client, "tools/list", session_id=old_session) + assert resp.status_code == 200 + + data = json.loads(resp.data) + assert isinstance(data, list) + assert len(data) == 3 + + new_session = resp.headers.get("Mcp-Session-Id") + assert new_session is not None + assert new_session == "hash_v2" + + +def test_mcpse005_stale_session_includes_list_changed_notifications(): + app = _make_mcp_app() + app._hot_reload.hash = "hash_v1" + + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + old_session = init_resp.headers["Mcp-Session-Id"] + + app._hot_reload.hash = "hash_v2" + + resp = _post(client, "tools/list", session_id=old_session) + data = json.loads(resp.data) + + assert data[0]["method"] == "notifications/tools/list_changed" + assert data[1]["method"] == "notifications/resources/list_changed" + assert "result" in data[2] + + +def test_mcpse006_reinitialize_after_hot_reload_gets_new_session(): + app = _make_mcp_app() + app._hot_reload.hash = "hash_v1" + + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + old_session = init_resp.headers["Mcp-Session-Id"] + assert old_session == "hash_v1" + + app._hot_reload.hash = "hash_v2" + + # Stale request triggers transparent recovery. + resp = _post(client, "tools/list", session_id=old_session) + assert resp.status_code == 200 + recovered_session = resp.headers["Mcp-Session-Id"] + assert recovered_session == "hash_v2" + + # Re-initialize picks up the new hash. + init_resp2 = _post(client, "initialize") + assert init_resp2.status_code == 200 + new_session = init_resp2.headers["Mcp-Session-Id"] + assert new_session == "hash_v2" + + # Subsequent requests with the new session work. + resp = _post(client, "tools/list", session_id=new_session) + assert resp.status_code == 200 + data = json.loads(resp.data) + assert "result" in data + + +def test_mcpse007_no_session_required_before_first_initialize(): + app = _make_mcp_app() + with app.server.test_client() as client: + resp = _post(client, "tools/list") + assert resp.status_code == 200 + + +def test_mcpse008_production_mode_generates_stable_session(): + app = _make_mcp_app() + assert app._hot_reload.hash is None + + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + session_id = init_resp.headers["Mcp-Session-Id"] + assert session_id is not None + + resp = _post(client, "tools/list", session_id=session_id) + assert resp.status_code == 200 + + resp = _post(client, "tools/list", session_id=session_id) + assert resp.status_code == 200 + + +def test_mcpse009_session_header_on_every_response(): + app = _make_mcp_app() + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + session_id = init_resp.headers["Mcp-Session-Id"] + + resp = _post(client, "tools/list", session_id=session_id) + assert resp.status_code == 200 + assert resp.headers.get("Mcp-Session-Id") == session_id diff --git a/tests/unit/mcp/tools/test_callback_adapter_collection.py b/tests/unit/mcp/tools/test_callback_adapter_collection.py new file mode 100644 index 0000000000..c120a2df8b --- /dev/null +++ b/tests/unit/mcp/tools/test_callback_adapter_collection.py @@ -0,0 +1,145 @@ +"""Tests for CallbackAdapterCollection.""" + +from dash import Dash, Input, Output, dcc, html +from dash._get_app import app_context + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + + +def _setup(app): + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + +class TestToolNameCollisions: + @staticmethod + def _make_duplicate_cb_app(n=3): + ids = [f"dd{i + 1}" for i in range(n)] + app = Dash(__name__) + app.layout = html.Div( + [ + item + for i in ids + for item in [ + dcc.Dropdown( + id=i, options=[chr(97 + j) for j in range(1)], value="a" + ), + html.Div(id=f"{i}-output"), + ] + ] + ) + for idx, dd_id in enumerate(ids): + + @app.callback(Output(f"{dd_id}-output", "children"), Input(dd_id, "value")) + def cb(value, _id=dd_id): # noqa: F811 + return f"{_id}: {value}" + + return app + + def test_duplicate_func_names_get_unique_tools(self): + app = self._make_duplicate_cb_app(3) + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert len(tool_names) == 3 + assert len(set(tool_names)) == 3, f"Tool names are not unique: {tool_names}" + for name in tool_names: + assert "dd" in name, f"Expected output ID in tool name: {name}" + + def test_unique_func_names_use_func_name(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def alpha_handler(value): + return value + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def beta_handler(value): + return value + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "alpha_handler" in tool_names + assert "beta_handler" in tool_names + + def test_duplicate_func_names_use_output_id(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="out1"), + html.Div(id="out2"), + html.Div(id="out3"), + html.Div(id="in1"), + html.Div(id="in2"), + html.Div(id="in3"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def unique_func(v): + return v + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb(v): + return v + + @app.callback(Output("out3", "children"), Input("in3", "children")) + def cb(v): # noqa: F811 + return v + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "unique_func" in tool_names + non_unique = [n for n in tool_names if n != "unique_func"] + assert len(non_unique) == 2 + assert non_unique[0] != non_unique[1] + + +class TestAllCallbacksVisibleByDefault: + def test_all_callbacks_visible_by_default(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def cb_one(value): + return value + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb_two(value): + return value + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "cb_one" in tool_names + assert "cb_two" in tool_names + + +class TestAdapterCollection: + def test_adapter_has_expected_properties(self): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + return val + + _setup(app) + adapter = app.mcp_callback_map[0] + assert adapter.tool_name == "update" + assert adapter.output_id == "out.children" diff --git a/tests/unit/mcp/tools/test_mcp_background_callbacks.py b/tests/unit/mcp/tools/test_mcp_background_callbacks.py new file mode 100644 index 0000000000..ffd4fb5bac --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_background_callbacks.py @@ -0,0 +1,278 @@ +"""Background callback support via MCP Tasks. + +Covers both layers: +- Layer 1 (``dash/mcp/tasks/``): ``tasks/get``, ``tasks/result``, ``tasks/cancel`` + derived on-demand from the callback manager. +- Layer 2 (tool wrappers): ``get_background_task_result`` and + ``cancel_background_task`` — only registered when the app has + background callbacks. +""" + +import json +import time + +import diskcache +from dash import Dash, Input, Output, html +from dash.background_callback.managers.diskcache_manager import DiskcacheManager +from dash.mcp._server import _process_mcp_message + +from tests.unit.mcp.conftest import _setup_mcp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _make_background_app(): + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def slow_callback(value): + """A background callback.""" + time.sleep(0.3) + return f"done: {value}" + + return app + + +def _trigger_task(app): + """Call slow_callback via tools/call and return its taskId.""" + result = _mcp( + app, + "tools/call", + {"name": "slow_callback", "arguments": {"value": "hello"}}, + ) + return json.loads(result["result"]["content"][0]["text"])["taskId"] + + +def _wait_for_completion(app, task_id, timeout=3): + """Block until the callback manager reports the job is no longer running.""" + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + timeout + while time.time() < deadline: + if not manager.job_running(job_id): + return + time.sleep(0.1) + + +# --------------------------------------------------------------------------- +# Tool-layer: cancel_background_task, get_background_task_result, registration +# --------------------------------------------------------------------------- + + +def test_mcpbg001_cancel_via_tool(): + app = _make_background_app() + task_id = _trigger_task(app) + + cancel = _mcp( + app, + "tools/call", + { + "name": "cancel_background_task", + "arguments": {"taskId": task_id}, + }, + ) + assert cancel["result"].get("isError") is not True + + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + assert not manager.job_running(job_id) + + +def test_mcpbg002_present_with_background_callbacks(): + app = _make_background_app() + tools = _mcp(app, "tools/list")["result"]["tools"] + names = [t["name"] for t in tools] + assert "get_background_task_result" in names + assert "cancel_background_task" in names + + +def test_mcpbg003_absent_without_background_callbacks(): + app = Dash(__name__) + app.layout = html.Div([html.Div(id="in"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("in", "children")) + def normal_cb(v): + return v + + tools = _mcp(app, "tools/list")["result"]["tools"] + names = [t["name"] for t in tools] + assert "get_background_task_result" not in names + assert "cancel_background_task" not in names + + +def test_mcpbg004_returns_working_while_running(): + app = _make_background_app() + task_id = _trigger_task(app) + poll = _mcp( + app, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + ) + text = poll["result"]["content"][0]["text"] + assert "working" in text.lower() + + +def test_mcpbg005_returns_result_when_complete(): + app = _make_background_app() + task_id = _trigger_task(app) + _wait_for_completion(app, task_id) + + result = _mcp( + app, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + ) + text = result["result"]["content"][0]["text"] + assert "done:" in text + + +def test_mcpbg006_returns_task_id(): + """Calling a background callback tool returns a taskId immediately.""" + app = _make_background_app() + result = _mcp( + app, + "tools/call", + {"name": "slow_callback", "arguments": {"value": "hello"}}, + ) + text = result["result"]["content"][0]["text"] + assert "taskId" in text + assert "slow_callback:" in text + + +# --------------------------------------------------------------------------- +# Tasks-protocol layer: tasks/get, tasks/result, tasks/cancel +# --------------------------------------------------------------------------- + + +def test_mcpbg007_tasks_get_working_status_while_running(): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + + get_result = _mcp(app, "tasks/get", {"taskId": task_id}) + assert get_result["result"]["status"] == "working" + assert get_result["result"]["taskId"] == task_id + + +def test_mcpbg008_tasks_result_returns_formatted_result(): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + _wait_for_completion(app, task_id) + + result = _mcp(app, "tasks/result", {"taskId": task_id}) + assert "content" in result["result"] + text = result["result"]["content"][0]["text"] + assert "done:" in text + + +def test_mcpbg009_tasks_cancel_terminates_job(): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + + cancel_result = _mcp(app, "tasks/cancel", {"taskId": task_id}) + assert "error" not in cancel_result + + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + assert not manager.job_running(job_id) + + +# --------------------------------------------------------------------------- +# tools/call with task metadata → CreateTaskResult + taskId encoding +# --------------------------------------------------------------------------- + + +def test_mcpbg010_returns_create_task_result(): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task = result["result"]["task"] + assert task["status"] == "working" + assert "taskId" in task + assert "pollInterval" in task + + +def test_mcpbg011_task_id_encodes_tool_name_job_id_cache_key(): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = result["result"]["task"]["taskId"] + tool_name, _job_id, cache_key, created_epoch = task_id.split(":") + assert tool_name == "slow_callback" + assert len(cache_key) == 64 # SHA256 hex + assert created_epoch.isdigit() diff --git a/tests/unit/mcp/tools/test_mcp_callback_adapter.py b/tests/unit/mcp/tools/test_mcp_callback_adapter.py new file mode 100644 index 0000000000..82dfa0956b --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_callback_adapter.py @@ -0,0 +1,182 @@ +"""CallbackAdapter behavior: initial value resolution, validation, loop prevention.""" + +import pytest +from dash import Dash, Input, Output, dcc, html +from dash._get_app import app_context + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + + +@pytest.fixture +def simple_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label("Your Name", htmlFor="inp"), + dcc.Input(id="inp", type="text"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """Update output.""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +def test_mcpc001_returns_layout_value(simple_app): + callback_map = app_context.get().mcp_callback_map + # Input with no value set — returns None (layout default for dcc.Input) + assert callback_map.get_initial_value("inp.value") is None + + +def test_mcpc002_returns_set_value(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(selected): + return selected + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert app.mcp_callback_map.get_initial_value("dd.value") == "a" + + +def test_mcpc003_initial_callback_makes_param_required(): + """A param with None in layout but set by an initial callback is required.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=["France", "Germany"], value="France"), + dcc.Dropdown(id="city"), # value=None in layout + html.Div(id="out"), + ] + ) + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("country", "value"), + ) + def update_cities(country): + return [{"label": "Paris", "value": "Paris"}], "Paris" + + @app.callback(Output("out", "children"), Input("city", "value")) + def show_city(city): + return f"Selected: {city}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + # city.value is None in layout but "Paris" after initial callback + with app.server.test_request_context(): + show_city_cb = app.mcp_callback_map.find_by_tool_name("show_city") + city_param = show_city_cb.inputs[0] + assert city_param["name"] == "city" + assert city_param["required"] is True # not optional despite None in layout + + +def test_mcpc004_valid_when_inputs_in_layout(simple_app): + assert app_context.get().mcp_callback_map[0].is_valid + + +def test_mcpc005_invalid_when_input_not_in_layout(): + app = Dash(__name__) + app.layout = html.Div([html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("nonexistent", "value")) + def update(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert not app.mcp_callback_map[0].is_valid + + +def test_mcpc006_pattern_matching_ids_always_valid(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="a"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input({"type": "field", "index": 0}, "value"), + ) + def update(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert app.mcp_callback_map[0].is_valid + + +@pytest.mark.timeout(5) +def test_mcpc007_initial_output_does_not_loop(): + """Building a tool must not trigger infinite re-entry in _initial_output.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Slider(id="sl", min=0, max=10, value=5), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("sl", "value")) + def show(value): + return f"Value: {value}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert tool.name == "show" + + +@pytest.mark.timeout(5) +def test_mcpc008_chained_callbacks_do_not_loop(): + """Chained callbacks with initial value resolution must not loop.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Slider(id="sl", min=0, max=10, value=5), + dcc.Slider(id="sl2", min=0, max=10), + html.Div(id="out"), + ] + ) + + @app.callback(Output("sl2", "value"), Input("sl", "value")) + def sync(v): + return v + + @app.callback( + Output("out", "children"), + Input("sl", "value"), + Input("sl2", "value"), + ) + def show(v1, v2): + return f"{v1} + {v2}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + for cb in app.mcp_callback_map: + tool = cb.as_mcp_tool + assert tool.name is not None diff --git a/tests/unit/mcp/tools/test_mcp_enabled_decorator.py b/tests/unit/mcp/tools/test_mcp_enabled_decorator.py new file mode 100644 index 0000000000..c033205f5d --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_enabled_decorator.py @@ -0,0 +1,170 @@ +"""Unit tests for the @mcp_enabled decorator and its tool integration.""" + +from typing import Optional + +from dash import Dash, html +from dash.mcp import mcp_enabled + +from tests.unit.mcp.conftest import _setup_mcp, _tools_list, _call_tool + +BUILTINS = {"get_dash_component"} + + +def _make_app_with_decorated(): + app = Dash(__name__) + app.layout = html.Div(id="root") + return _setup_mcp(app) + + +def _user_tools(app): + return [t for t in _tools_list(app) if t.name not in BUILTINS] + + +def test_mcpd001_bare_decorator_preserves_function(): + @mcp_enabled + def double(x: int) -> int: + return x * 2 + + assert double(5) == 10 + + +def test_mcpd002_parameterised_decorator_preserves_function(): + @mcp_enabled(name="doubler") + def double(x: int) -> int: + return x * 2 + + assert double(5) == 10 + + +def test_mcpd003_bare_decorator_appears_as_tool(): + @mcp_enabled + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = _make_app_with_decorated() + names = {t.name for t in _user_tools(app)} + assert "add_numbers" in names + + +def test_mcpd004_custom_name_overrides_function_name(): + @mcp_enabled(name="sum_values") + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = _make_app_with_decorated() + names = {t.name for t in _user_tools(app)} + assert "sum_values" in names + assert "add_numbers" not in names + + +def test_mcpd005_docstring_hidden_by_default(): + @mcp_enabled + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "add_numbers") + assert "Add two numbers" not in tool.description + + +def test_mcpd006_docstring_exposed_when_opted_in(): + @mcp_enabled(expose_docstring=True) + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "add_numbers") + assert "Add two numbers together" in tool.description + + +def test_mcpd007_typed_params_produce_schema(): + @mcp_enabled + def greet(name: str, times: int) -> str: + """Greet someone.""" + return name * times + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "greet") + schema = tool.inputSchema + assert schema["type"] == "object" + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["times"]["type"] == "integer" + assert set(schema["required"]) == {"name", "times"} + + +def test_mcpd008_optional_param_not_required(): + @mcp_enabled + def search(query: str, limit: Optional[int] = None) -> str: + """Search for things.""" + return query + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "search") + schema = tool.inputSchema + assert "query" in schema["required"] + assert "limit" not in schema["required"] + + +def test_mcpd009_typed_param_with_default_not_required(): + @mcp_enabled + def filter_range(min_val: float = -180.0, max_val: float = 180.0) -> list[str]: + """Filter by range.""" + return [] + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "filter_range") + schema = tool.inputSchema + assert "required" not in schema or "min_val" not in schema.get("required", []) + assert "required" not in schema or "max_val" not in schema.get("required", []) + + +def test_mcpd010_return_annotation_becomes_output_schema(): + @mcp_enabled + def compute(x: int) -> str: + """Compute something.""" + return str(x) + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "compute") + assert tool.outputSchema["type"] == "object" + assert tool.outputSchema["properties"]["result"]["type"] == "string" + + +def test_mcpd011_call_returns_result(): + @mcp_enabled + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + app = _make_app_with_decorated() + result = _call_tool(app, "multiply", {"a": 3, "b": 7}) + assert result.isError is not True + assert result.structuredContent["result"] == 21 + + +def test_mcpd012_call_with_custom_name(): + @mcp_enabled(name="product") + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + app = _make_app_with_decorated() + result = _call_tool(app, "product", {"a": 4, "b": 5}) + assert result.isError is not True + assert result.structuredContent["result"] == 20 + + +def test_mcpd013_call_error_returns_is_error(): + @mcp_enabled + def fail_hard(x: int) -> str: + """Always fails.""" + raise ValueError("boom") + + app = _make_app_with_decorated() + result = _call_tool(app, "fail_hard", {"x": 1}) + assert result.isError is True + assert "boom" in result.content[0].text diff --git a/tests/unit/mcp/tools/test_mcp_formatted_output_results.py b/tests/unit/mcp/tools/test_mcp_formatted_output_results.py new file mode 100644 index 0000000000..da931009cd --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_formatted_output_results.py @@ -0,0 +1,240 @@ +"""Formatted tool-output results: structured content + text/image fallbacks. + +Covers: +- ``format_callback_response`` — wraps the callback result as + ``structuredContent`` and delegates to per-output formatters. +- ``DataFrameResult`` — renders DataTable/AgGrid rows as markdown tables. +- ``PlotlyFigureResult`` — renders ``dcc.Graph.figure`` values as PNG images + (via kaleido; skipped gracefully when kaleido is unavailable). +""" + +import base64 +from unittest.mock import Mock, patch + +import plotly.graph_objects as go # type: ignore[import-untyped] + +from dash.mcp.primitives.tools.results import format_callback_response +from dash.mcp.primitives.tools.results.result_dataframe import ( + MAX_ROWS, + DataFrameResult, +) +from dash.mcp.primitives.tools.results.result_plotly_figure import ( + PlotlyFigureResult, +) + + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _mock_callback(outputs=None): + cb = Mock() + cb.outputs = outputs or [] + return cb + + +EXPECTED_TABLE = ( + "*2 rows \u00d7 2 columns*\n" + "\n" + "name | age\n" + "--- | ---\n" + "Alice | 30\n" + "Bob | 25" +) + +SAMPLE_ROWS = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}] + +DATATABLE_OUTPUT = { + "component_type": "DataTable", + "property": "data", + "component_id": "t", + "id_and_prop": "t.data", + "initial_value": None, + "tool_name": "update", +} + +AGGRID_OUTPUT = { + "component_type": "AgGrid", + "property": "rowData", + "component_id": "g", + "id_and_prop": "g.rowData", + "initial_value": None, + "tool_name": "update", +} + +FAKE_PNG = b"\x89PNG\r\n\x1a\nfakedata" +FAKE_B64 = base64.b64encode(FAKE_PNG).decode("ascii") + +GRAPH_FIGURE_OUTPUT = { + "component_type": "Graph", + "property": "figure", + "component_id": "g", + "id_and_prop": "g.figure", + "initial_value": None, + "tool_name": "update", +} + + +# --------------------------------------------------------------------------- +# format_callback_response +# --------------------------------------------------------------------------- + + +def test_mcpr001_wraps_as_structured_content(): + response = { + "multi": True, + "response": {"out": {"children": "hello"}}, + } + result = format_callback_response(response, _mock_callback()) + assert result.structuredContent == response + + +def test_mcpr002_content_has_json_text_fallback(): + """Per MCP spec, structuredContent SHOULD include a TextContent fallback.""" + response = {"multi": True, "response": {}} + result = format_callback_response(response, _mock_callback()) + assert len(result.content) >= 1 + assert result.content[0].type == "text" + assert '"multi": true' in result.content[0].text + + +def test_mcpr003_is_error_defaults_false(): + response = {"multi": True, "response": {}} + result = format_callback_response(response, _mock_callback()) + assert result.isError is False + + +def test_mcpr004_preserves_side_update(): + response = { + "multi": True, + "response": {"out": {"children": "x"}}, + "sideUpdate": {"other": {"value": 42}}, + } + result = format_callback_response(response, _mock_callback()) + assert result.structuredContent["sideUpdate"] == {"other": {"value": 42}} + + +def test_mcpr005_datatable_result_includes_markdown_table(): + response = { + "multi": True, + "response": { + "my-table": {"data": [{"name": "Alice", "age": 30}]}, + }, + } + outputs = [ + { + "component_id": "my-table", + "component_type": "DataTable", + "property": "data", + "id_and_prop": "my-table.data", + "initial_value": None, + "tool_name": "update", + } + ] + result = format_callback_response(response, _mock_callback(outputs)) + texts = [c.text for c in result.content if c.type == "text"] + assert any("name | age" in t for t in texts) + + +def test_mcpr006_plotly_figure_includes_image(): + response = { + "multi": True, + "response": { + "my-graph": { + "figure": { + "data": [{"type": "bar", "x": ["A"], "y": [1]}], + "layout": {}, + } + } + }, + } + outputs = [ + { + "component_id": "my-graph", + "component_type": "Graph", + "property": "figure", + "id_and_prop": "my-graph.figure", + "initial_value": None, + "tool_name": "update", + } + ] + with patch.object(go.Figure, "to_image", return_value=b"\x89PNGfake"): + result = format_callback_response(response, _mock_callback(outputs)) + images = [c for c in result.content if c.type == "image"] + assert len(images) == 1 + + +# --------------------------------------------------------------------------- +# DataFrameResult (DataTable / AgGrid markdown rendering) +# --------------------------------------------------------------------------- + + +def test_mcpr007_datatable_data_renders_markdown(): + result = DataFrameResult.format(DATATABLE_OUTPUT, SAMPLE_ROWS) + assert len(result) == 1 + assert result[0].text == EXPECTED_TABLE + + +def test_mcpr008_aggrid_rowdata_renders_markdown(): + result = DataFrameResult.format(AGGRID_OUTPUT, SAMPLE_ROWS) + assert len(result) == 1 + assert result[0].text == EXPECTED_TABLE + + +def test_mcpr009_ignores_non_tabular_props(): + non_tabular = {**DATATABLE_OUTPUT, "property": "columns"} + assert DataFrameResult.format(non_tabular, SAMPLE_ROWS) == [] + + +def test_mcpr010_ignores_empty_or_non_dict_rows(): + assert DataFrameResult.format(DATATABLE_OUTPUT, []) == [] + assert DataFrameResult.format(DATATABLE_OUTPUT, ["a", "b"]) == [] + + +def test_mcpr011_truncates_large_tables(): + rows = [{"i": n} for n in range(MAX_ROWS + 50)] + result = DataFrameResult.format(DATATABLE_OUTPUT, rows) + text = result[0].text + assert f"\n{MAX_ROWS - 1}\n" in text + assert f"\n{MAX_ROWS}\n" not in text + assert "50 more rows" in text + + +# --------------------------------------------------------------------------- +# PlotlyFigureResult (Graph.figure → PNG image) +# --------------------------------------------------------------------------- + + +def test_mcpr012_returns_image_when_kaleido_available(): + fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() + with patch.object(go.Figure, "to_image", return_value=FAKE_PNG): + result = PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, fig_dict) + assert len(result) == 1 + assert result[0].type == "image" + assert result[0].data == FAKE_B64 + + +def test_mcpr013_returns_empty_when_kaleido_unavailable(): + fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() + with patch.object(go.Figure, "to_image", side_effect=ImportError): + result = PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, fig_dict) + assert result == [] + + +def test_mcpr014_ignores_non_graph_components(): + output = { + **GRAPH_FIGURE_OUTPUT, + "component_type": "Div", + "property": "children", + } + assert PlotlyFigureResult.format(output, {}) == [] + + +def test_mcpr015_ignores_non_figure_props(): + output = {**GRAPH_FIGURE_OUTPUT, "property": "clickData"} + assert PlotlyFigureResult.format(output, {}) == [] + + +def test_mcpr016_ignores_non_dict_values(): + assert PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, "not a dict") == [] diff --git a/tests/unit/mcp/tools/test_mcp_input_descriptions.py b/tests/unit/mcp/tools/test_mcp_input_descriptions.py new file mode 100644 index 0000000000..55ac4df49a --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_input_descriptions.py @@ -0,0 +1,445 @@ +"""Input descriptions — human-readable per-property descriptions for MCP tool inputs. + +Covers: +- Labels (htmlFor, containment, text extraction) +- Component-specific (date pickers, sliders) +- Options (Dropdown, RadioItems, Checklist) +- Generic props (placeholder, default value, min/max/step) +- Chained callbacks (dynamic prop/options detection) +- Combinations (label + component-specific) +""" + +import pytest + +from dash import Dash, Input, Output, dcc, html + +from tests.unit.mcp.conftest import ( + _app_with_callback, + _desc_for, + _tools_list, + _user_tool, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _app_with_layout(layout, *inputs): + app = Dash(__name__) + app.layout = layout + + @app.callback( + Output("out", "children"), + [Input(cid, prop) for cid, prop in inputs], + ) + def update(*args): + return str(args) + + return app + + +def _tool_for(component, input_prop="value"): + app = _app_with_callback(component, input_prop=input_prop) + return _user_tool(_tools_list(app)) + + +# --------------------------------------------------------------------------- +# Labels (htmlFor, containment, text extraction) +# --------------------------------------------------------------------------- + + +def test_mcpd001_label_html_for(): + app = _app_with_layout( + html.Div( + [ + html.Label("Your Name", htmlFor="inp"), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Your Name" in _desc_for(tool) + + +def test_mcpd002_label_html_for_not_adjacent(): + app = _app_with_layout( + html.Div( + [ + html.Div(html.Label("Remote Label", htmlFor="inp")), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Remote Label" in _desc_for(tool) + + +def test_mcpd003_label_containment(): + app = _app_with_layout( + html.Div( + [ + html.Label( + [ + "Pick a city", + dcc.Dropdown(id="city_dd", options=["NYC", "LA"]), + ] + ), + html.Div(id="out"), + ] + ), + ("city_dd", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Pick a city" in _desc_for(tool) + + +def test_mcpd004_label_deeply_nested_containment(): + app = _app_with_layout( + html.Div( + [ + html.Label( + [ + html.Span("Nested Label"), + html.Div(dcc.Input(id="nested_inp")), + ] + ), + html.Div(id="out"), + ] + ), + ("nested_inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Nested Label" in _desc_for(tool) + + +def test_mcpd005_label_both_htmlfor_and_containment_captured(): + app = _app_with_layout( + html.Div( + [ + html.Label(["Containment Label", dcc.Input(id="inp")]), + html.Label("HtmlFor Label", htmlFor="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "HtmlFor Label" in desc + assert "Containment Label" in desc + + +def test_mcpd006_label_deep_text_extraction(): + app = _app_with_layout( + html.Div( + [ + html.Label( + html.Div(html.Span(html.B("Deep Text"))), + htmlFor="inp", + ), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Deep Text" in _desc_for(tool) + + +def test_mcpd007_label_multiple_text_nodes(): + app = _app_with_layout( + html.Div( + [ + html.Label( + [html.B("First"), " ", html.I("Second")], + htmlFor="inp", + ), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Labeled with: First Second" in desc + + +def test_mcpd008_label_unrelated_excluded(): + app = _app_with_layout( + html.Div( + [ + html.Label("Other Field", htmlFor="other"), + dcc.Input(id="other"), + dcc.Input(id="target"), + html.Div(id="out"), + ] + ), + ("target", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Other Field" not in (desc or "") + + +# --------------------------------------------------------------------------- +# Component-specific: date pickers +# --------------------------------------------------------------------------- + + +def test_mcpd009_date_picker_single_full_range(): + dp = dcc.DatePickerSingle( + id="dp", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "2020-01-01" in desc + assert "2025-12-31" in desc + + +def test_mcpd010_date_picker_single_min_only(): + dp = dcc.DatePickerSingle(id="dp", min_date_allowed="2020-01-01") + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "min_date_allowed: '2020-01-01'" in desc + + +def test_mcpd011_date_picker_single_default_date(): + dp = dcc.DatePickerSingle(id="dp", date="2024-06-15") + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "date: '2024-06-15'" in desc + + +def test_mcpd012_date_picker_range_with_constraints(): + dpr = dcc.DatePickerRange( + id="dpr", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + desc = _desc_for(_tool_for(dpr, "start_date"), "val") + assert "2020-01-01" in desc + + +# --------------------------------------------------------------------------- +# Component-specific: sliders +# --------------------------------------------------------------------------- + + +def test_mcpd013_slider_min_max(): + sl = dcc.Slider(id="sl", min=0, max=100) + desc = _desc_for(_tool_for(sl), "val") + assert "min: 0" in desc + assert "max: 100" in desc + + +def test_mcpd014_slider_step(): + sl = dcc.Slider(id="sl", min=0, max=100, step=5) + desc = _desc_for(_tool_for(sl), "val") + assert "step: 5" in desc + + +def test_mcpd015_slider_default_value(): + sl = dcc.Slider(id="sl", min=0, max=100, value=50) + desc = _desc_for(_tool_for(sl), "val") + assert "value: 50" in desc + + +def test_mcpd016_slider_marks(): + sl = dcc.Slider(id="sl", min=0, max=100, marks={0: "Low", 100: "High"}) + desc = _desc_for(_tool_for(sl), "val") + assert "marks: {0: 'Low', 100: 'High'}" in desc + + +def test_mcpd017_range_slider_min_max(): + rs = dcc.RangeSlider(id="rs", min=0, max=100) + desc = _desc_for(_tool_for(rs), "val") + assert "min: 0" in desc + assert "max: 100" in desc + + +# --------------------------------------------------------------------------- +# Options (parametrized across Dropdown, RadioItems, Checklist) +# --------------------------------------------------------------------------- + + +_OPTIONS_COMPONENTS = [ + ("Dropdown", lambda **kw: dcc.Dropdown(id="comp", **kw), "comp"), + ("RadioItems", lambda **kw: dcc.RadioItems(id="comp", **kw), "comp"), + ("Checklist", lambda **kw: dcc.Checklist(id="comp", **kw), "comp"), +] + + +@pytest.mark.parametrize( + "name,factory,cid", _OPTIONS_COMPONENTS, ids=[c[0] for c in _OPTIONS_COMPONENTS] +) +def test_mcpd018_options_shown(name, factory, cid): + comp = factory(options=["X", "Y", "Z"]) + desc = _desc_for(_tool_for(comp), "val") + assert "options: ['X', 'Y', 'Z']" in desc + + +@pytest.mark.parametrize( + "name,factory,cid", _OPTIONS_COMPONENTS, ids=[c[0] for c in _OPTIONS_COMPONENTS] +) +def test_mcpd019_default_shown(name, factory, cid): + value = ["a"] if name == "Checklist" else "a" + comp = factory(options=["a", "b"], value=value) + desc = _desc_for(_tool_for(comp), "val") + assert f"value: {value!r}" in desc + + +def test_mcpd020_dropdown_dict_options(): + dd = dcc.Dropdown( + id="dd", + options=[ + {"label": "New York", "value": "NYC"}, + ], + ) + assert "NYC" in _desc_for(_tool_for(dd), "val") + + +def test_mcpd021_store_storage_type_template(): + store = dcc.Store(id="store", storage_type="session") + app = _app_with_callback(store, input_prop="data") + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool, "val") + assert ( + "storage_type: 'session'. Describes how to store the value client-side" in desc + ) + + +def test_mcpd022_many_options_truncated(): + dd = dcc.Dropdown(id="big", options=[str(i) for i in range(50)], value="0") + app = _app_with_callback(dd) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool, "val") + assert "options:" in desc + assert "Use get_dash_component('big', 'options') for the full value" in desc + + +# --------------------------------------------------------------------------- +# Generic props (placeholder, default, numeric min/max/step) +# --------------------------------------------------------------------------- + + +def test_mcpd023_generic_placeholder(): + inp = dcc.Input(id="inp", placeholder="Enter your name") + assert "placeholder: 'Enter your name'" in _desc_for(_tool_for(inp), "val") + + +def test_mcpd024_generic_numeric_min_max(): + inp = dcc.Input(id="inp", type="number", min=0, max=999) + desc = _desc_for(_tool_for(inp), "val") + assert "min: 0" in desc + assert "max: 999" in desc + + +def test_mcpd025_generic_step(): + inp = dcc.Input(id="inp", type="number", min=0, max=100, step=0.1) + assert "step: 0.1" in _desc_for(_tool_for(inp), "val") + + +def test_mcpd026_generic_default_value(): + inp = dcc.Input(id="inp", value="hello") + desc = _desc_for(_tool_for(inp), "val") + assert "value: 'hello'" in desc + + +def test_mcpd027_generic_non_text_type(): + inp = dcc.Input(id="inp", type="email") + assert "type: 'email'" in _desc_for(_tool_for(inp), "val") + + +def test_mcpd028_generic_store_default(): + store = dcc.Store(id="store", data={"key": "value"}) + app = _app_with_callback(store, input_prop="data") + tool = _user_tool(_tools_list(app)) + assert "data: {'key': 'value'}" in _desc_for(tool, "val") + + +# --------------------------------------------------------------------------- +# Chained callbacks — descriptions reflect upstream dependencies +# --------------------------------------------------------------------------- + + +def test_mcpd029_chained_options_set_by_upstream(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=["US", "CA"], value="US"), + dcc.Dropdown(id="city", options=[], value=None), + html.Div(id="result"), + ] + ) + + @app.callback(Output("city", "options"), Input("country", "value")) + def update_cities(country): + return ["NYC", "LA"] if country == "US" else ["Toronto"] + + @app.callback(Output("result", "children"), Input("city", "value")) + def show_city(city): + return city + + tools = _tools_list(app) + tool = next(t for t in tools if "show_city" in t.name) + desc = _desc_for(tool, "city") + assert "can be updated by tool: `update_cities`" in desc + assert "options:" in desc + + +def test_mcpd030_chained_value_set_by_upstream(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="source", value=""), + html.Div(id="derived", children=""), + html.Div(id="result"), + ] + ) + + @app.callback(Output("derived", "children"), Input("source", "value")) + def compute_derived(val): + return f"derived: {val}" + + @app.callback(Output("result", "children"), Input("derived", "children")) + def use_derived(val): + return val + + tools = _tools_list(app) + tool = next(t for t in tools if "use_derived" in t.name) + desc = _desc_for(tool, "val") + assert "can be updated by tool: `compute_derived`" in desc + + +# --------------------------------------------------------------------------- +# Combinations — label + component-specific +# --------------------------------------------------------------------------- + + +def test_mcpd031_combination_label_with_date_picker(): + dp = dcc.DatePickerSingle( + id="dp", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + app = _app_with_layout( + html.Div( + [ + html.Label("Departure Date", htmlFor="dp"), + dp, + html.Div(id="out"), + ] + ), + ("dp", "date"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Departure Date" in desc + assert "2020-01-01" in desc diff --git a/tests/unit/mcp/tools/test_mcp_input_schemas.py b/tests/unit/mcp/tools/test_mcp_input_schemas.py new file mode 100644 index 0000000000..ade87bbc79 --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_input_schemas.py @@ -0,0 +1,289 @@ +"""Input schema generation — JSON Schema for callback input parameters. + +Covers: +- Static overrides (DatePicker, Graph, Interval, Slider) +- Component introspection (representative per-type samples) +- Callback annotation overrides (highest priority) +- Required / nullable behavior +- Component type → JSON schema mapping +""" + +import pytest +from typing import Optional + +from dash import Dash, Input, Output, State, dcc, html +from dash.development.base_component import Component +from dash.mcp.primitives.tools.input_schemas.schema_callback_type_annotations import ( + annotation_to_json_schema, +) + +from tests.unit.mcp.conftest import ( + _app_with_callback, + _schema_for, + _tools_list, + _user_tool, +) + + +# --------------------------------------------------------------------------- +# Schema building blocks (JSON Schema primitives) +# --------------------------------------------------------------------------- + +STRING = {"type": "string"} +NUMBER = {"type": "number"} +INTEGER = {"type": "integer"} +BOOLEAN = {"type": "boolean"} +NULL = {"type": "null"} +OBJECT = {"additionalProperties": True, "type": "object"} + + +def nullable(*schemas): + """``{anyOf: [*schemas, {type: null}]}`` — a common nullable-type shape.""" + return {"anyOf": [*schemas, NULL]} + + +def array_of(*item_schemas): + """Array of a single schema, or a union when multiple are passed.""" + items = item_schemas[0] if len(item_schemas) == 1 else {"anyOf": list(item_schemas)} + return {"items": items, "type": "array"} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_schema(component_type, prop): + _factories = { + "DatePickerSingle": lambda: dcc.DatePickerSingle(id="dp"), + "DatePickerRange": lambda: dcc.DatePickerRange(id="dpr"), + "Graph": lambda: dcc.Graph(id="graph"), + "Interval": lambda: dcc.Interval(id="intv"), + "Input": lambda: dcc.Input(id="inp"), + "Textarea": lambda: dcc.Textarea(id="ta"), + "Tabs": lambda: dcc.Tabs(id="tabs"), + "Dropdown": lambda: dcc.Dropdown(id="dd"), + "RadioItems": lambda: dcc.RadioItems(id="ri"), + "Checklist": lambda: dcc.Checklist(id="cl"), + "Store": lambda: dcc.Store(id="store"), + "Upload": lambda: dcc.Upload(id="upload"), + "Slider": lambda: dcc.Slider(id="sl"), + "RangeSlider": lambda: dcc.RangeSlider(id="rs"), + } + app = _app_with_callback(_factories[component_type](), input_prop=prop) + tool = _user_tool(_tools_list(app)) + return _schema_for(tool) + + +def _app_with_annotated_callback(annotation_type, input_prop="disabled"): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + if annotation_type is None: + + @app.callback(Output("out", "children"), Input("inp", input_prop)) + def update(val): + return str(val) + + else: + + @app.callback(Output("out", "children"), Input("inp", input_prop)) + def update(val: annotation_type): + return str(val) + + return app + + +# --------------------------------------------------------------------------- +# Test cases +# --------------------------------------------------------------------------- + +# (component_type, prop, expected_schema) — representative per-type samples +INTROSPECTION_CASES = [ + ("Input", "value", nullable(STRING, NUMBER)), + ( + "Input", + "disabled", + nullable( + BOOLEAN, + {"const": "disabled", "type": "string"}, + {"const": "DISABLED", "type": "string"}, + ), + ), + ("Input", "n_submit", nullable(NUMBER)), + ("Dropdown", "options", nullable({})), + ("Checklist", "value", nullable(array_of(STRING, NUMBER, BOOLEAN))), + ("Store", "data", nullable(OBJECT, array_of({}), NUMBER, STRING, BOOLEAN)), + ("Upload", "contents", nullable(STRING, array_of(STRING))), + ("RangeSlider", "value", nullable(array_of(NUMBER))), + ("Tabs", "value", nullable(STRING)), +] + +# (annotation, prop, expected_schema) — callback annotations override introspection +ANNOTATION_CASES = [ + (str, "disabled", STRING), + (int, "value", INTEGER), + (float, "value", NUMBER), + (bool, "value", BOOLEAN), + (list, "value", array_of({})), + (dict, "value", OBJECT), + (Optional[int], "value", nullable(INTEGER)), + (Optional[str], "value", nullable(STRING)), +] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpi001_override_beats_introspection(): + """Static override wins over component introspection.""" + schema = _get_schema("DatePickerSingle", "date") + # Introspection would return None for this prop; + # override provides a date format with pattern + assert schema["type"] == "string" + assert schema["format"] == "date" + assert "pattern" in schema + + +def test_mcpi013_graph_figure_uses_plotly_schema_override(): + """Graph.figure matches the FIGURE role's schema override (concrete via wildcard).""" + schema = _get_schema("Graph", "figure") + assert schema["type"] == "object" + assert set(schema["properties"]) == {"data", "layout", "frames"} + + +@pytest.mark.parametrize( + "component_type,prop,expected", + INTROSPECTION_CASES, + ids=[f"{c}.{p}" for c, p, _ in INTROSPECTION_CASES], +) +def test_mcpi002_introspected_schema(component_type, prop, expected): + """Representative introspection tests across component types.""" + assert _get_schema(component_type, prop) == expected + + +@pytest.mark.parametrize( + "ann,prop,expected", + ANNOTATION_CASES, + ids=[ + f"{a.__name__ if hasattr(a, '__name__') else a}-{p}" + for a, p, _ in ANNOTATION_CASES + ], +) +def test_mcpi003_annotation(ann, prop, expected): + """Callback type annotations override component schemas.""" + app = _app_with_annotated_callback(ann, input_prop=prop) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == expected + + +def test_mcpi004_no_annotation_uses_introspection(): + app = _app_with_annotated_callback(None) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == nullable( + BOOLEAN, + {"const": "disabled", "type": "string"}, + {"const": "DISABLED", "type": "string"}, + ) + + +def test_mcpi005_str_removes_null(): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: str): + return val + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == STRING + + +def test_mcpi006_optional_preserves_null(): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: Optional[str]): + return val or "" + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == nullable(STRING) + + +def test_mcpi007_optional_param_not_required(): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: Optional[str]): + return val or "" + + tool = _user_tool(_tools_list(app)) + assert "val" not in tool.inputSchema.get("required", []) + + +def test_mcpi008_state_annotation_overrides(): + """Annotations work for State parameters too.""" + app = Dash(__name__) + app.layout = html.Div( + [dcc.Input(id="inp"), dcc.Store(id="store"), html.Div(id="out")] + ) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + State("store", "data"), + ) + def update(val: str, data: dict): + return str(val) + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == STRING + assert _schema_for(tool, "data") == OBJECT + + +def test_mcpi009_partial_annotations(): + """Some annotated, some not — introspection fills in the rest.""" + app = Dash(__name__) + app.layout = html.Div( + [dcc.Input(id="inp"), dcc.Store(id="store"), html.Div(id="out")] + ) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + State("store", "data"), + ) + def update(val: int, data): + return str(val) + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == INTEGER + assert _schema_for(tool, "data") == nullable( + OBJECT, array_of({}), NUMBER, STRING, BOOLEAN + ) + + +def test_mcpi010_component_type_maps_to_string(): + """Component annotation type maps to string schema.""" + assert annotation_to_json_schema(Component) == STRING + + +def test_mcpi011_dropdown_value_multi_false_narrows_to_scalar(): + """Dropdown.value with multi=False narrows to a scalar union.""" + app = _app_with_callback(dcc.Dropdown(id="dd")) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool) == {"anyOf": [STRING, NUMBER, BOOLEAN]} + + +def test_mcpi012_dropdown_value_multi_true_narrows_to_array(): + """Dropdown.value with multi=True narrows to an array of scalars.""" + app = _app_with_callback(dcc.Dropdown(id="dd", multi=True)) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool) == { + "type": "array", + "items": {"anyOf": [STRING, NUMBER, BOOLEAN]}, + } diff --git a/tests/unit/mcp/tools/test_mcp_pattern_matching.py b/tests/unit/mcp/tools/test_mcp_pattern_matching.py new file mode 100644 index 0000000000..25c8de1f8d --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_pattern_matching.py @@ -0,0 +1,136 @@ +"""Pattern-matching callback support — schemas and descriptions for wildcard IDs. + +Covers callbacks whose Input/Output use ``ALL``, ``MATCH``, or ``ALLSMALLER`` +wildcards in dict-based component IDs: the input is typed as an array (ALL) +or object (MATCH) of ``{id, property, value}`` entries, and descriptions +surface the wildcard kind and ID pattern. +""" + +from dash import Dash, html, Input, Output, ALL, MATCH, dcc + +from tests.unit.mcp.conftest import _tools_list, _user_tool, _schema_for, _desc_for + + +# --------------------------------------------------------------------------- +# Schema shape +# --------------------------------------------------------------------------- + + +def test_mcpm001_all_produces_array_schema(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}, children="A"), + html.Div(id={"type": "item", "index": 1}, children="B"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": ALL}, "children"), + ) + def combine(values): + return ", ".join(values) + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "array" + assert schema["items"]["type"] == "object" + assert "value" in schema["items"]["properties"] + + +def test_mcpm002_match_produces_object_schema(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}, children="A"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": MATCH}, "children"), + ) + def echo(value): + return value + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "object" + assert "value" in schema["properties"] + + +def test_mcpm003_annotation_narrows_value_schema(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id={"type": "filter", "index": 0}, options=["a", "b"]), + dcc.Dropdown(id={"type": "filter", "index": 1}, options=["c", "d"]), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "filter", "index": ALL}, "options"), + ) + def combine(options: list[str]): + return str(options) + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "array" + value_schema = schema["items"]["properties"]["value"] + # Annotation narrows value to list[str] instead of the broad introspected type + assert value_schema == {"items": {"type": "string"}, "type": "array"} + + +# --------------------------------------------------------------------------- +# Descriptions +# --------------------------------------------------------------------------- + + +def test_mcpm004_all_description(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": ALL}, "children"), + ) + def combine(values): + return str(values) + + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Pattern-matching input (ALL)" in desc + assert 'type="item"' in desc + + +def test_mcpm005_match_description(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": MATCH}, "children"), + ) + def echo(value): + return value + + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Pattern-matching input (MATCH)" in desc + assert 'type="item"' in desc diff --git a/tests/unit/mcp/tools/test_mcp_run_callback.py b/tests/unit/mcp/tools/test_mcp_run_callback.py new file mode 100644 index 0000000000..e345b9682e --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_run_callback.py @@ -0,0 +1,253 @@ +"""Callback dispatch execution via MCP tools (``run_callback``). + +Exercises how the MCP tool pipeline runs a Dash callback through +``_process_mcp_message`` with various signatures: multi-output, State, +positional vs. dict-based ``inputs``, ``PreventUpdate``, and no-output +set_props-style callbacks. +""" + +from dash import Dash, Input, Output, State, dcc, html, set_props +from dash.exceptions import PreventUpdate +from dash.mcp._server import _process_mcp_message + +from tests.unit.mcp.conftest import _setup_mcp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool_structured(app, tool_name, arguments=None): + result = _mcp(app, "tools/call", {"name": tool_name, "arguments": arguments or {}}) + return result["result"]["structuredContent"] + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + structured = _call_tool_structured(app, tool_name, arguments) + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpx001_multi_output(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + dcc.Dropdown(id="dd2"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("dd2", "options"), + Output("out", "children"), + Input("dd", "value"), + ) + def update(val): + return [{"label": val, "value": val}], f"selected: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + structured = _call_tool_structured(app, tool_name, {"val": "b"}) + assert structured["response"]["dd2"]["options"] == [{"label": "b", "value": "b"}] + assert structured["response"]["out"]["children"] == "selected: b" + + +def test_mcpx002_omitted_kwargs_default_to_none(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"]), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("dd", "value"), + State("inp", "value"), + ) + def update(selected, text): + return f"{selected}-{text}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"selected": "a"}, "out") == "a-None" + + +def test_mcpx003_no_output_callback(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + html.Div(id="display"), + ] + ) + + @app.callback(Input("btn", "n_clicks")) + def server_cb(n): + set_props("display", {"children": f"Clicked {n} times"}) + + tools = _tools_list(app) + tool_names = [t["name"] for t in tools] + assert "server_cb" in tool_names + + +def test_mcpx004_prevent_update(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hello"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + if val == "block": + raise PreventUpdate + return f"got: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"val": "test"}, "out") == "got: test" + + +def test_mcpx005_with_state(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="trigger"), + html.Div(id="store"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input("trigger", "children"), + State("store", "children"), + ) + def with_state(trigger, store): + return f"{trigger}-{store}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + {"trigger": "click", "store": "data"}, + "result", + ) + == "click-data" + ) + + +def test_mcpx006_dict_inputs(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="x-input", value="hello"), + dcc.Input(id="y-input", value="world"), + html.Div(id="dict-out"), + ] + ) + + @app.callback( + Output("dict-out", "children"), + inputs={ + "x_val": Input("x-input", "value"), + "y_val": Input("y-input", "value"), + }, + ) + def combine(**kwargs): + return f"{kwargs['x_val']}-{kwargs['y_val']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "combine" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + {"x_val": "foo", "y_val": "bar"}, + "dict-out", + ) + == "foo-bar" + ) + + +def test_mcpx007_positional_inputs(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="a-input", value="A"), + html.Div(id="pos-out"), + ] + ) + + @app.callback(Output("pos-out", "children"), Input("a-input", "value")) + def echo(val): + return f"got:{val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "echo" in t["name"]) + assert _call_tool_output(app, tool_name, {"val": "test"}, "pos-out") == "got:test" + + +def test_mcpx008_dict_inputs_with_state(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hi"), + html.Div(id="st", children="state-val"), + html.Div(id="ds-out"), + ] + ) + + @app.callback( + Output("ds-out", "children"), + inputs={"trigger": Input("inp", "value")}, + state={"kept": State("st", "children")}, + ) + def with_dict_state(**kwargs): + return f"{kwargs['trigger']}+{kwargs['kept']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_dict_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + {"trigger": "hey", "kept": "saved"}, + "ds-out", + ) + == "hey+saved" + ) diff --git a/tests/unit/mcp/tools/test_mcp_tool_get_dash_component.py b/tests/unit/mcp/tools/test_mcp_tool_get_dash_component.py new file mode 100644 index 0000000000..4cf11555be --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_tool_get_dash_component.py @@ -0,0 +1,125 @@ +"""The built-in ``get_dash_component`` tool. + +Lets LLMs inspect a component's current properties and their relationships +to callbacks (which tools write/read each prop). +""" + +from dash import Dash, Input, Output, dcc, html + +from tests.unit.mcp.conftest import _call_tool, _make_app, _tools_list + + +def test_mcpg001_present_in_tools_list(): + app = _make_app() + tool_names = [t.name for t in _tools_list(app)] + assert "get_dash_component" in tool_names + + +def test_mcpg002_returns_structured_output_with_prop(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="my-dd", options=["a", "b"], value="b"), + ] + ) + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + "property": "value", + }, + ) + sc = result.structuredContent + assert sc["component_id"] == "my-dd" + assert sc["component_type"] == "Dropdown" + assert "value" in sc["properties"] + assert sc["properties"]["value"]["initial_value"] == "b" + assert "options" not in sc["properties"] + + +def test_mcpg003_returns_all_props_without_property(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="my-dd", options=["a", "b"], value="b"), + ] + ) + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + }, + ) + sc = result.structuredContent + assert "options" in sc["properties"] + assert "value" in sc["properties"] + assert sc["properties"]["value"]["initial_value"] == "b" + + +def test_mcpg004_includes_label(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label("Pick one", htmlFor="my-dd"), + dcc.Dropdown(id="my-dd", options=["a", "b"], value="a"), + ] + ) + + @app.callback(Output("my-dd", "value"), Input("my-dd", "options")) + def noop(o): + return "a" + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + }, + ) + sc = result.structuredContent + assert sc["label"] == ["Pick one"] + + +def test_mcpg005_includes_tool_references(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val): + return val + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "dd", + "property": "value", + }, + ) + prop_info = result.structuredContent["properties"]["value"] + assert "update" in prop_info["input_to_tool"] + + +def test_mcpg006_missing_id_returns_hint(): + app = _make_app() + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "nonexistent", + "property": "value", + }, + ) + text = result.content[0].text + assert "nonexistent" in text + assert "not found" in text + assert "dash://components" in text diff --git a/tests/unit/mcp/tools/test_mcp_tools.py b/tests/unit/mcp/tools/test_mcp_tools.py new file mode 100644 index 0000000000..3255809982 --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_tools.py @@ -0,0 +1,403 @@ +"""Tool construction: how Dash callbacks become MCP Tool objects. + +Covers the CallbackAdapter → Tool pipeline: list building (from_app), +tool name generation, and the resulting Tool object's shape (description, +input schema, param metadata). + +Reference: https://modelcontextprotocol.io/specification/2025-11-25/server/tools +""" + +import pytest +from dash import Dash, Input, Output, State, dcc, html +from dash._get_app import app_context +from dash.development.base_component import Component +from dash.types import CallbackExecutionResponse +from mcp.types import Tool +from pydantic import TypeAdapter + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + +from tests.unit.mcp.conftest import ( + _make_app, + _tools_list, + _user_tool, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def simple_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label("Your Name", htmlFor="inp"), + dcc.Input(id="inp", type="text"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """Update output.""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def multi_output_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + dcc.Dropdown(id="dd2"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("dd2", "options"), + Output("out", "children"), + Input("dd", "value"), + ) + def update(val): + return [], val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def state_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("btn", "n_clicks"), + State("inp", "value"), + ) + def update(clicks, val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def typed_app(): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val: str): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def duplicate_names_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def cb(v): + return v + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb(v): # noqa: F811 + return v + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +# --------------------------------------------------------------------------- +# Tests — building the callback list from an app +# --------------------------------------------------------------------------- + + +def test_mcpt001_returns_list(simple_app): + assert len(app_context.get().mcp_callback_map) == 1 + + +def test_mcpt002_excludes_clientside(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + html.Div(id="cs-out"), + html.Div(id="srv-out"), + ] + ) + app.clientside_callback( + "function(n) { return n; }", + Output("cs-out", "children"), + Input("btn", "n_clicks"), + ) + + @app.callback(Output("srv-out", "children"), Input("btn", "n_clicks")) + def server_cb(n): + return str(n) + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + names = [a.tool_name for a in app.mcp_callback_map] + assert names == ["server_cb"] + + +def test_mcpt003_excludes_mcp_disabled(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp"), + html.Div(id="out1"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("inp", "value")) + def visible(val): + return val + + @app.callback(Output("out2", "children"), Input("inp", "value"), mcp_enabled=False) + def hidden(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + names = [a.tool_name for a in app.mcp_callback_map] + assert "visible" in names + assert "hidden" not in names + + +# --------------------------------------------------------------------------- +# Tests — tool name generation +# --------------------------------------------------------------------------- + + +def test_mcpt004_uses_func_name(simple_app): + assert app_context.get().mcp_callback_map[0].tool_name == "update" + + +def test_mcpt005_duplicates_get_unique_names(duplicate_names_app): + names = [a.tool_name for a in app_context.get().mcp_callback_map] + assert len(names) == 2 + assert names[0] != names[1] + + +# --------------------------------------------------------------------------- +# Tests — Tool object shape (description, input schema, params) +# --------------------------------------------------------------------------- + + +def test_mcpt006_returns_tool_instance(simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert isinstance(tool, Tool) + assert tool.name == "update" + + +def test_mcpt007_docstring_hidden_by_default(): + """Callback docstrings are not exposed to MCP by default.""" + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """sensitive callback docstring text that must not leak to LLMs""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "sensitive callback docstring text that must not leak to LLMs" + not in tool.description + ) + + +def test_mcpt008_docstring_exposed_when_opted_in_per_callback(): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + mcp_expose_docstring=True, + ) + def update(val): + """intentionally-exposed callback docstring text for the LLM""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "intentionally-exposed callback docstring text for the LLM" in tool.description + ) + + +def test_mcpt009_description_includes_output_target(simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "out.children" in tool.description + + +def test_mcpt010_param_name_from_function_signature(simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "val" in tool.inputSchema["properties"] + + +def test_mcpt011_param_has_label_description(simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + desc = tool.inputSchema["properties"]["val"].get("description", "") + assert "Your Name" in desc + + +def test_mcpt012_state_params_included(state_app): + with state_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + props = tool.inputSchema["properties"] + assert set(props.keys()) == {"clicks", "val"} + + +def test_mcpt013_multi_output_description(multi_output_app): + with multi_output_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "dd2.options" in tool.description + assert "out.children" in tool.description + + +def test_mcpt014_typed_annotation_narrows_schema(typed_app): + with typed_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert tool.inputSchema["properties"]["val"]["type"] == "string" + + +def test_mcpt016_app_level_opt_in_exposes_docstrings(): + """Dash(mcp_expose_docstrings=True) exposes docstrings for all callbacks.""" + app = Dash(__name__, mcp_expose_docstrings=True) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """intentionally-exposed callback docstring text for the LLM""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "intentionally-exposed callback docstring text for the LLM" in tool.description + ) + + +def test_mcpt017_per_callback_false_overrides_app_level_opt_in(): + """Per-callback mcp_expose_docstring=False wins over app-level opt-in.""" + app = Dash(__name__, mcp_expose_docstrings=True) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + mcp_expose_docstring=False, + ) + def update(val): + """sensitive callback docstring text that must not leak to LLMs""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "sensitive callback docstring text that must not leak to LLMs" + not in tool.description + ) + + +# --------------------------------------------------------------------------- +# Tests — end-to-end Tool shape +# --------------------------------------------------------------------------- + + +_DASH_COMPONENT_SCHEMA = TypeAdapter(Component).json_schema() + +EXPECTED_TOOL = { + "name": "update_output", + "description": ( + "my-output.children: Returns content\n" "\n" "Test callback docstring." + ), + "inputSchema": { + "type": "object", + "properties": { + "value": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"type": "number"}, + _DASH_COMPONENT_SCHEMA, + { + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"type": "number"}, + _DASH_COMPONENT_SCHEMA, + {"type": "null"}, + ] + }, + "type": "array", + }, + {"type": "null"}, + ], + "description": "Input is optional.\nThe children of this component.", + }, + }, + }, + "outputSchema": TypeAdapter(CallbackExecutionResponse).json_schema(), +} + + +def test_mcpt015_full_tool(): + """The entire tool dict matches the expected shape end-to-end.""" + tool = _user_tool(_tools_list(_make_app())) + assert tool.model_dump(exclude_none=True) == EXPECTED_TOOL diff --git a/tests/unit/mcp/tools/test_mcp_tools_callbacks.py b/tests/unit/mcp/tools/test_mcp_tools_callbacks.py new file mode 100644 index 0000000000..7148803bc1 --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_tools_callbacks.py @@ -0,0 +1,152 @@ +"""Dynamic callback tools: MCP spec compliance, tool naming, output summaries. + +Verifies that generated tools conform to the MCP 2025-11-25 specification +and Dash-specific conventions. Focuses on shape/structure, tool-name +sanitization, and ``_OUTPUT_SEMANTICS`` fallback summaries; input-schema +values are covered by ``test_mcp_input_schemas``. + +Reference: https://modelcontextprotocol.io/specification/2025-11-25/server/tools +""" + +import re +from unittest.mock import Mock + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) +from dash.mcp.primitives.tools.descriptions import build_tool_description + +from tests.unit.mcp.conftest import ( + _make_app, + _tools_list, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_TOOL_NAME_RE = re.compile(r"^[A-Za-z0-9_\-.]+$") + + +def _adapter_with_outputs(outputs, docstring=None): + adapter = Mock() + adapter.outputs = outputs + adapter._docstring = docstring + return adapter + + +def _out(comp_id, prop, comp_type=None): + return { + "id_and_prop": f"{comp_id}.{prop}", + "component_id": comp_id, + "property": prop, + "component_type": comp_type, + "initial_value": None, + } + + +# --------------------------------------------------------------------------- +# MCP spec compliance — every generated tool must satisfy the spec +# --------------------------------------------------------------------------- + + +def test_mcptc001_all_tools_conform_to_mcp_spec(): + tools = _tools_list(_make_app()) + names = [t.name for t in tools] + + assert len(names) == len(set(names)), f"Duplicate tool names: {names}" + + for tool in tools: + assert tool.name + assert tool.inputSchema + assert 1 <= len(tool.name) <= 128 + assert _TOOL_NAME_RE.match(tool.name), f"Invalid tool name: {tool.name}" + + schema = tool.inputSchema + assert isinstance(schema, dict) + assert schema.get("type") == "object" + assert isinstance(schema.get("properties", {}), dict) + + required = set(schema.get("required", [])) + props = set(schema.get("properties", {}).keys()) + assert ( + required <= props + ), f"{tool.name}: required {required - props} not in properties" + + +# --------------------------------------------------------------------------- +# Built-in tools +# --------------------------------------------------------------------------- + + +def test_mcptc002_query_component_always_present(): + names = {t.name for t in _tools_list(_make_app())} + assert "get_dash_component" in names + + +def test_mcptc003_query_component_has_required_params(): + tool = next(t for t in _tools_list(_make_app()) if t.name == "get_dash_component") + assert "component_id" in tool.inputSchema["properties"] + assert "property" in tool.inputSchema["properties"] + assert set(tool.inputSchema.get("required", [])) == {"component_id"} + + +# --------------------------------------------------------------------------- +# Tool-name sanitization (CallbackAdapterCollection._sanitize_name) +# --------------------------------------------------------------------------- + + +def test_mcptc004_sanitize_simple_name(): + assert CallbackAdapterCollection._sanitize_name("update_output") == "update_output" + + +def test_mcptc005_sanitize_special_characters_replaced(): + assert CallbackAdapterCollection._sanitize_name("my-func.name") == "my_func_name" + + +def test_mcptc006_sanitize_leading_digit(): + assert CallbackAdapterCollection._sanitize_name("123func") == "cb_123func" + + +def test_mcptc007_sanitize_empty_name(): + assert CallbackAdapterCollection._sanitize_name("") == "unnamed_callback" + + +def test_mcptc008_sanitize_consecutive_underscores_collapsed(): + assert CallbackAdapterCollection._sanitize_name("a---b___c") == "a_b_c" + + +def test_mcptc009_sanitize_long_name_truncated_to_64_chars(): + result = CallbackAdapterCollection._sanitize_name("a" * 200) + assert len(result) <= 64 + assert result[-8:].isalnum() + + +def test_mcptc010_sanitize_long_name_uniqueness(): + result_a = CallbackAdapterCollection._sanitize_name("a" * 200) + result_b = CallbackAdapterCollection._sanitize_name("b" * 200) + assert result_a != result_b + + +def test_mcptc011_sanitize_short_name_not_truncated(): + assert CallbackAdapterCollection._sanitize_name("short_name") == "short_name" + + +# --------------------------------------------------------------------------- +# Output semantic summary (``_OUTPUT_SEMANTICS`` in description_outputs.py). +# Other description tests (docstring, output target, multi-output) are +# covered by test_mcp_tools using real adapters. +# --------------------------------------------------------------------------- + + +def test_mcptc012_semantic_summary_with_component_type(): + adapter = _adapter_with_outputs([_out("my-graph", "figure", "Graph")]) + desc = build_tool_description(adapter) + assert "Returns chart/visualization data" in desc + + +def test_mcptc013_semantic_summary_fallback_by_property(): + adapter = _adapter_with_outputs([_out("unknown-id", "figure")]) + desc = build_tool_description(adapter) + assert "Returns chart/visualization data" in desc diff --git a/tests/unit/test_layout.py b/tests/unit/test_layout.py new file mode 100644 index 0000000000..64fff724a1 --- /dev/null +++ b/tests/unit/test_layout.py @@ -0,0 +1,83 @@ +"""Tests for dash._layout_utils — layout traversal and component lookup utilities.""" + +import pytest + +from dash import html, dcc +from dash._layout_utils import ( + traverse, + find_component, + extract_text, + parse_wildcard_id, +) + + +@pytest.fixture +def sample_layout(): + return html.Div( + [ + html.Label("Name:", htmlFor="name-input"), + " ", + dcc.Input(id="name-input", value="World"), + html.Div( + [html.Span(id="deep-child", children="deep text")], + id="inner", + ), + ], + id="root", + ) + + +class TestTraverse: + def test_yields_all_components_with_correct_ancestors(self, sample_layout): + results = { + getattr(c, "id", None): len(ancestors) + for c, ancestors in traverse(sample_layout) + } + assert results["root"] == 0 + assert results["name-input"] == 1 + assert results["deep-child"] == 2 + + def test_empty_layout(self): + results = list(traverse(html.Div())) + assert len(results) == 1 # just the Div itself + + +class TestFindComponent: + def test_finds_by_string_id(self, sample_layout): + comp = find_component("deep-child", layout=sample_layout) + assert comp is not None and comp.id == "deep-child" + + def test_returns_none_for_missing_id(self, sample_layout): + assert find_component("nope", layout=sample_layout) is None + + def test_finds_by_dict_id(self): + layout = html.Div([html.Div(id={"type": "item", "index": 0})]) + assert find_component({"type": "item", "index": 0}, layout=layout) is not None + + +class TestExtractText: + def test_extracts_all_text_content(self, sample_layout): + assert extract_text(sample_layout) == "Name: deep text" + + def test_none_children(self): + assert extract_text(html.Div()) == "" + + +class TestParseWildcardId: + @pytest.mark.parametrize("wildcard", ["ALL", "MATCH", "ALLSMALLER"]) + def test_returns_dict_for_wildcard(self, wildcard): + result = parse_wildcard_id({"type": "input", "index": [wildcard]}) + assert result == {"type": "input", "index": [wildcard]} + + def test_parses_json_string(self): + result = parse_wildcard_id('{"type":"input","index":["ALL"]}') + assert result == {"type": "input", "index": ["ALL"]} + + def test_returns_none_for_plain_string(self): + assert parse_wildcard_id("my-dropdown") is None + + def test_returns_none_for_non_wildcard_dict(self): + assert parse_wildcard_id({"type": "input", "index": 0}) is None + + def test_returns_none_for_invalid_json(self): + assert parse_wildcard_id("{not valid}") is None diff --git a/tests/unit/test_pydantic_types.py b/tests/unit/test_pydantic_types.py new file mode 100644 index 0000000000..389e678028 --- /dev/null +++ b/tests/unit/test_pydantic_types.py @@ -0,0 +1,36 @@ +"""Tests for dash.types — Pydantic-compatible types and schemas.""" + +from pydantic import TypeAdapter + +from dash.types import NumberType, CallbackExecutionBody, CallbackExecutionResponse +from dash.development.base_component import Component + + +class TestNumberType: + def test_json_schema_is_number(self): + schema = TypeAdapter(NumberType).json_schema() + assert schema["type"] == "number" + + +class TestComponentPydanticSchema: + def test_produces_object_schema(self): + schema = TypeAdapter(Component).json_schema() + assert schema["type"] == "object" + assert "properties" in schema + + def test_schema_has_type_and_props(self): + schema = TypeAdapter(Component).json_schema() + props = schema["properties"] + assert "type" in props + assert "props" in props + + +class TestCallbackExecutionTypes: + def test_dispatch_body_schema(self): + schema = TypeAdapter(CallbackExecutionBody).json_schema() + assert "output" in schema["properties"] + assert "inputs" in schema["properties"] + + def test_dispatch_response_schema(self): + schema = TypeAdapter(CallbackExecutionResponse).json_schema() + assert "response" in schema["properties"]