From 882586bc52957b67975691765357513eab345824 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 09:38:24 -0600 Subject: [PATCH 01/80] Add layout traversal utilities for Dash component trees --- dash/layout.py | 228 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 dash/layout.py diff --git a/dash/layout.py b/dash/layout.py new file mode 100644 index 0000000000..fdca86edca --- /dev/null +++ b/dash/layout.py @@ -0,0 +1,228 @@ +"""Reusable layout utilities for traversing and inspecting Dash component trees.""" + +from __future__ import annotations + +import json +from typing import Any, Generator + +from dash import get_app +from dash._pages import PAGE_REGISTRY +from dash.dependencies import Wildcard +from dash.development.base_component import Component + +_WILDCARD_VALUES = frozenset(w.value for w in Wildcard) + + +def traverse( + start: Component | None = None, +) -> Generator[tuple[Component, tuple[Component, ...]], None, None]: + """Yield ``(component, ancestors)`` for every Component in the tree. + + If ``start`` is ``None``, the full app layout is resolved via + ``dash.get_app()``, preferring ``validation_layout`` for completeness. + """ + if start is None: + app = get_app() + start = getattr(app, "validation_layout", None) or app.get_layout() + + yield from _walk(start, ()) + + +def _walk( + node: Any, + ancestors: tuple[Component, ...], +) -> Generator[tuple[Component, tuple[Component, ...]], None, None]: + if node is None: + return + if isinstance(node, (list, tuple)): + for item in node: + yield from _walk(item, ancestors) + return + if not isinstance(node, Component): + return + + yield node, ancestors + + child_ancestors = (*ancestors, node) + for _prop_name, child in iter_children(node): + yield from _walk(child, child_ancestors) + + +def iter_children( + component: Component, +) -> Generator[tuple[str, Component], None, None]: + """Yield ``(prop_name, child_component)`` for all component-valued props. + + Walks ``children`` plus any props declared in the component's + ``_children_props`` list. Supports nested path expressions like + ``control_groups[].children`` and ``insights.title``. + """ + props_to_walk = ["children"] + getattr(component, "_children_props", []) + for prop_path in props_to_walk: + for child in get_children(component, prop_path): + yield prop_path, child + + +def get_children(component: Any, prop_path: str) -> list[Component]: + """Resolve a ``_children_props`` path expression to child Components. + + Mirrors the dash-renderer's path parsing in ``DashWrapper.tsx``. + Supports: + - ``"children"`` — simple prop + - ``"control_groups[].children"`` — array, then sub-prop per element + - ``"insights.title"`` — nested object prop + """ + clean_path = prop_path.replace("[]", "").replace("{}", "") + + if "." not in prop_path: + return _collect_components(getattr(component, clean_path, None)) + + parts = prop_path.split(".") + array_idx = next((i for i, p in enumerate(parts) if "[]" in p), len(parts)) + front = [p.replace("[]", "").replace("{}", "") for p in parts[: array_idx + 1]] + back = [p.replace("{}", "") for p in parts[array_idx + 1 :]] + + node = _resolve_path(component, front) + if node is None: + return [] + + if back and isinstance(node, (list, tuple)): + results: list[Component] = [] + for element in node: + child = _resolve_path(element, back) + results.extend(_collect_components(child)) + return results + + return _collect_components(node) + + +def _resolve_path(node: Any, keys: list[str]) -> Any: + """Walk a chain of keys through Components and dicts.""" + for key in keys: + if isinstance(node, Component): + node = getattr(node, key, None) + elif isinstance(node, dict): + node = node.get(key) + else: + return None + if node is None: + return None + return node + + +def _collect_components(value: Any) -> list[Component]: + """Extract Components from a value (single, list, or None).""" + if value is None: + return [] + if isinstance(value, Component): + return [value] + if isinstance(value, (list, tuple)): + return [item for item in value if isinstance(item, (Component, list, tuple))] + return [] + + +def find_component( + component_id: str | dict, + layout: Component | None = None, + page: str | None = None, +) -> Component | None: + """Find a component by ID. + + If neither ``layout`` nor ``page`` is provided, searches the full + app layout (preferring ``validation_layout`` for completeness). + """ + if page is not None: + layout = _resolve_page_layout(page) + + if layout is None: + app = get_app() + layout = getattr(app, "validation_layout", None) or app.get_layout() + + for comp, _ in traverse(layout): + if getattr(comp, "id", None) == component_id: + return comp + return None + + +def parse_wildcard_id(pid: Any) -> dict | None: + """Parse a component ID and return it as a dict if it contains a wildcard. + + Accepts string (JSON-encoded) or dict IDs. Returns ``None`` + if the ID is not a wildcard pattern. + + Example:: + + >>> parse_wildcard_id('{"type":"input","index":["ALL"]}') + {"type": "input", "index": ["ALL"]} + >>> parse_wildcard_id("my-dropdown") + None + """ + if isinstance(pid, str) and pid.startswith("{"): + try: + pid = json.loads(pid) + except (json.JSONDecodeError, ValueError): + return None + if not isinstance(pid, dict): + return None + for v in pid.values(): + if isinstance(v, list) and len(v) == 1 and v[0] in _WILDCARD_VALUES: + return pid + return None + + +def find_matching_components(pattern: dict) -> list[Component]: + """Find all components whose dict ID matches a wildcard pattern. + + Non-wildcard keys must match exactly. Wildcard keys are ignored. + """ + non_wildcard_keys = { + k: v + for k, v in pattern.items() + if not (isinstance(v, list) and len(v) == 1 and v[0] in _WILDCARD_VALUES) + } + matches = [] + for comp, _ in traverse(): + comp_id = getattr(comp, "id", None) + if not isinstance(comp_id, dict): + continue + if all(comp_id.get(k) == v for k, v in non_wildcard_keys.items()): + matches.append(comp) + return matches + + +def extract_text(component: Component) -> str: + """Recursively extract plain text from a component's children tree. + + Mimics the browser's ``element.textContent``. + """ + children = getattr(component, "children", None) + if children is None: + return "" + if isinstance(children, str): + return children + if isinstance(children, Component): + return extract_text(children) + if isinstance(children, (list, tuple)): + parts: list[str] = [] + for child in children: + if isinstance(child, str): + parts.append(child) + elif isinstance(child, Component): + parts.append(extract_text(child)) + return "".join(parts).strip() + return "" + + +def _resolve_page_layout(page: str) -> Any | None: + if not PAGE_REGISTRY: + return None + for _module, page_info in PAGE_REGISTRY.items(): + if page_info.get("path") == page: + page_layout = page_info.get("layout") + if callable(page_layout): + try: + page_layout = page_layout() + except (TypeError, RuntimeError): + return None + return page_layout + return None From 4d55ef8eb6c2f68c411b5af65d91254ea917b579 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 10:01:59 -0600 Subject: [PATCH 02/80] Make Dash components compatible with Pydantic types --- dash/development/_py_components_generation.py | 5 +- dash/development/base_component.py | 17 ++++ dash/types.py | 67 ++++++++++++++- requirements/install.txt | 1 + tests/unit/test_layout.py | 83 +++++++++++++++++++ tests/unit/test_pydantic_types.py | 36 ++++++++ 6 files changed, 204 insertions(+), 5 deletions(-) create mode 100644 tests/unit/test_layout.py create mode 100644 tests/unit/test_pydantic_types.py diff --git a/dash/development/_py_components_generation.py b/dash/development/_py_components_generation.py index 73545ea4a5..b597283a04 100644 --- a/dash/development/_py_components_generation.py +++ b/dash/development/_py_components_generation.py @@ -24,6 +24,7 @@ import typing # noqa: F401 from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 from dash.development.base_component import Component, _explicitize_args +from dash.types import NumberType # noqa: F401 {custom_imports} ComponentSingleType = typing.Union[str, int, float, Component, None] ComponentType = typing.Union[ @@ -31,10 +32,6 @@ typing.Sequence[ComponentSingleType], ] -NumberType = typing.Union[ - typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex -] - """ diff --git a/dash/development/base_component.py b/dash/development/base_component.py index 342ad6da6f..70ed17e9de 100644 --- a/dash/development/base_component.py +++ b/dash/development/base_component.py @@ -117,6 +117,23 @@ class Component(metaclass=ComponentMeta): _valid_wildcard_attributes: typing.List[str] available_wildcard_properties: typing.List[str] + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + from pydantic_core import core_schema + return core_schema.any_schema() + + @classmethod + def __get_pydantic_json_schema__(cls, schema, handler): + namespaces = list(ComponentRegistry.namespace_to_package.keys()) + return { + "type": "object", + "properties": { + "type": {"type": "string"}, + "namespace": {"type": "string", "enum": namespaces} if namespaces else {"type": "string"}, + "props": {"type": "object"}, + }, + } + class _UNDEFINED: def __repr__(self): return "undefined" diff --git a/dash/types.py b/dash/types.py index 9a39adb43e..43bf16dc30 100644 --- a/dash/types.py +++ b/dash/types.py @@ -1,4 +1,29 @@ -from typing_extensions import TypedDict, NotRequired +import typing +from typing import Any, Dict, List, Union + +from pydantic import Field, GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic_core import core_schema +from typing_extensions import Annotated, TypedDict, NotRequired + + +class _NumberSchema: # pylint: disable=too-few-public-methods + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: GetCoreSchemaHandler + ) -> Any: + return core_schema.float_schema() + + @classmethod + def __get_pydantic_json_schema__( + cls, _schema: Any, _handler: GetJsonSchemaHandler + ) -> dict: + return {"type": "number"} + + +NumberType = Annotated[ + Union[typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex], + _NumberSchema, +] class RendererHooks(TypedDict): # pylint: disable=too-many-ancestors @@ -8,3 +33,43 @@ class RendererHooks(TypedDict): # pylint: disable=too-many-ancestors request_post: NotRequired[str] callback_resolved: NotRequired[str] request_refresh_jwt: NotRequired[str] + + +class CallbackDependency(TypedDict): + id: Union[str, Dict[str, Any]] + property: str + + +class CallbackInput(TypedDict): + id: Union[str, Dict[str, Any]] + property: str + value: Any + + +class CallbackDispatchBody(TypedDict): + output: str + outputs: List[CallbackDependency] + inputs: List[CallbackInput] + state: List[CallbackInput] + changedPropIds: List[str] + + +CallbackOutput = Annotated[ + Dict[str, Any], + Field( + description="The return values of the callback. A mapping of component & property names to their updated values." + ), +] + +CallbackSideOutput = Annotated[ + Dict[str, Any], + Field( + description="Side-effect updates that the callback performed but did not declare ahead of time. A mapping of component & property names to their updated values." + ), +] + + +class CallbackDispatchResponse(TypedDict): + multi: NotRequired[bool] + response: NotRequired[Dict[str, CallbackOutput]] + sideUpdate: NotRequired[Dict[str, CallbackSideOutput]] diff --git a/requirements/install.txt b/requirements/install.txt index 284f3a5031..f65522326e 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -8,3 +8,4 @@ retrying nest-asyncio setuptools janus>=1.0.0 +pydantic>=2.12.5 diff --git a/tests/unit/test_layout.py b/tests/unit/test_layout.py new file mode 100644 index 0000000000..76a72f7fb4 --- /dev/null +++ b/tests/unit/test_layout.py @@ -0,0 +1,83 @@ +"""Tests for dash.layout — layout traversal and component lookup utilities.""" + +import pytest + +from dash import html, dcc +from dash.layout import ( + traverse, + find_component, + extract_text, + parse_wildcard_id, +) + + +@pytest.fixture +def sample_layout(): + return html.Div( + [ + html.Label("Name:", htmlFor="name-input"), + " ", + dcc.Input(id="name-input", value="World"), + html.Div( + [html.Span(id="deep-child", children="deep text")], + id="inner", + ), + ], + id="root", + ) + + +class TestTraverse: + def test_yields_all_components_with_correct_ancestors(self, sample_layout): + results = { + getattr(c, "id", None): len(ancestors) + for c, ancestors in traverse(sample_layout) + } + assert results["root"] == 0 + assert results["name-input"] == 1 + assert results["deep-child"] == 2 + + def test_empty_layout(self): + results = list(traverse(html.Div())) + assert len(results) == 1 # just the Div itself + + +class TestFindComponent: + def test_finds_by_string_id(self, sample_layout): + comp = find_component("deep-child", layout=sample_layout) + assert comp is not None and comp.id == "deep-child" + + def test_returns_none_for_missing_id(self, sample_layout): + assert find_component("nope", layout=sample_layout) is None + + def test_finds_by_dict_id(self): + layout = html.Div([html.Div(id={"type": "item", "index": 0})]) + assert find_component({"type": "item", "index": 0}, layout=layout) is not None + + +class TestExtractText: + def test_extracts_all_text_content(self, sample_layout): + assert extract_text(sample_layout) == "Name: deep text" + + def test_none_children(self): + assert extract_text(html.Div()) == "" + + +class TestParseWildcardId: + @pytest.mark.parametrize("wildcard", ["ALL", "MATCH", "ALLSMALLER"]) + def test_returns_dict_for_wildcard(self, wildcard): + result = parse_wildcard_id({"type": "input", "index": [wildcard]}) + assert result == {"type": "input", "index": [wildcard]} + + def test_parses_json_string(self): + result = parse_wildcard_id('{"type":"input","index":["ALL"]}') + assert result == {"type": "input", "index": ["ALL"]} + + def test_returns_none_for_plain_string(self): + assert parse_wildcard_id("my-dropdown") is None + + def test_returns_none_for_non_wildcard_dict(self): + assert parse_wildcard_id({"type": "input", "index": 0}) is None + + def test_returns_none_for_invalid_json(self): + assert parse_wildcard_id("{not valid}") is None diff --git a/tests/unit/test_pydantic_types.py b/tests/unit/test_pydantic_types.py new file mode 100644 index 0000000000..75d1dc7f41 --- /dev/null +++ b/tests/unit/test_pydantic_types.py @@ -0,0 +1,36 @@ +"""Tests for dash.types — Pydantic-compatible types and schemas.""" + +from pydantic import TypeAdapter + +from dash.types import NumberType, CallbackDispatchBody, CallbackDispatchResponse +from dash.development.base_component import Component + + +class TestNumberType: + def test_json_schema_is_number(self): + schema = TypeAdapter(NumberType).json_schema() + assert schema["type"] == "number" + + +class TestComponentPydanticSchema: + def test_produces_object_schema(self): + schema = TypeAdapter(Component).json_schema() + assert schema["type"] == "object" + assert "properties" in schema + + def test_schema_has_type_and_props(self): + schema = TypeAdapter(Component).json_schema() + props = schema["properties"] + assert "type" in props + assert "props" in props + + +class TestCallbackDispatchTypes: + def test_dispatch_body_schema(self): + schema = TypeAdapter(CallbackDispatchBody).json_schema() + assert "output" in schema["properties"] + assert "inputs" in schema["properties"] + + def test_dispatch_response_schema(self): + schema = TypeAdapter(CallbackDispatchResponse).json_schema() + assert "response" in schema["properties"] From 998d8d702a14c17a19fe4623c10e5bd56fc6ebde Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 10:31:58 -0600 Subject: [PATCH 03/80] Extract get_layout() from serve_layout() --- dash/dash.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index f0821abef2..9212754303 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -898,15 +898,22 @@ def index_string(self, value: str) -> None: self._index_string = value @with_app_context - def serve_layout(self): - layout = self._layout_value() + def get_layout(self): + """Return the resolved layout with all hooks applied. + This is the canonical way to obtain the app's layout — it + calls the layout function (if callable), includes extra + components, and runs layout hooks. + """ + layout = self._layout_value() for hook in self._hooks.get_hooks("layout"): layout = hook(layout) + return layout + def serve_layout(self): # TODO - Set browser cache limit - pass hash into frontend return self.backend.make_response( - to_json(layout), + to_json(self.get_layout()), mimetype="application/json", ) From 66eda33b1a28f9f7c93f8639c88bbcf21202705b Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 10:46:25 -0600 Subject: [PATCH 04/80] Fix build issues for dash-table and dash-core-components --- components/dash-core-components/package.json | 4 ++-- components/dash-table/package.json | 2 +- dash/dash-renderer/babel.config.js | 2 +- dash/dash-renderer/package.json | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/components/dash-core-components/package.json b/components/dash-core-components/package.json index ffe7de1d2f..e430a00b6a 100644 --- a/components/dash-core-components/package.json +++ b/components/dash-core-components/package.json @@ -27,7 +27,7 @@ "build:js": "webpack --mode production", "build:backends": "dash-generate-components ./src/components dash_core_components -p package-info.json && cp dash_core_components_base/** dash_core_components/ && dash-generate-components ./src/components dash_core_components -p package-info.json -k RangeSlider,Slider,Dropdown,RadioItems,Checklist,DatePickerSingle,DatePickerRange,Input,Link --r-prefix 'dcc' --r-suggests 'dash,dashHtmlComponents,jsonlite,plotly' --jl-prefix 'dcc' && black dash_core_components", "build": "run-s prepublishOnly build:js build:backends", - "postbuild": "es-check es2015 dash_core_components/*.js", + "postbuild": "es-check es2017 dash_core_components/*.js", "build:watch": "watch 'npm run build' src", "format": "run-s private::format.*", "lint": "run-s private::lint.*" @@ -126,6 +126,6 @@ "react-dom": "16 - 19" }, "browserslist": [ - "last 10 years and not dead" + "last 11 years and not dead" ] } diff --git a/components/dash-table/package.json b/components/dash-table/package.json index 295517a4c9..fd905c6f40 100644 --- a/components/dash-table/package.json +++ b/components/dash-table/package.json @@ -119,6 +119,6 @@ "npm": ">=6.1.0" }, "browserslist": [ - "last 10 years and not dead" + "last 11 years and not dead" ] } diff --git a/dash/dash-renderer/babel.config.js b/dash/dash-renderer/babel.config.js index d7b0c89e8e..6e6cc5d957 100644 --- a/dash/dash-renderer/babel.config.js +++ b/dash/dash-renderer/babel.config.js @@ -3,7 +3,7 @@ module.exports = { '@babel/preset-typescript', ['@babel/preset-env', { "targets": { - "browsers": ["last 10 years and not dead"] + "browsers": ["last 11 years and not dead"] } }], '@babel/preset-react' diff --git a/dash/dash-renderer/package.json b/dash/dash-renderer/package.json index a404fa2425..b9d1965f0d 100644 --- a/dash/dash-renderer/package.json +++ b/dash/dash-renderer/package.json @@ -89,6 +89,6 @@ ], "prettier": "@plotly/prettier-config-dash", "browserslist": [ - "last 10 years and not dead" + "last 11 years and not dead" ] } From 0cdb67af951aebbcdd3a3622ead0c6228c08f98e Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 15:13:51 -0600 Subject: [PATCH 05/80] Add CallbackDispatchBody type hints to dispatch methods --- dash/dash.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 9212754303..225c36bb47 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -75,7 +75,7 @@ _import_layouts_from_pages, ) from ._jupyter import jupyter_dash, JupyterDisplayMode -from .types import RendererHooks +from .types import CallbackDispatchBody, RendererHooks RouteCallable = Callable[..., Any] @@ -1469,7 +1469,7 @@ def _inputs_to_vals(self, inputs): return inputs_to_vals(inputs) # pylint: disable=R0915 - def _initialize_context(self, body): + def _initialize_context(self, body: CallbackDispatchBody): """Initialize the global context for the request.""" adapter = self.backend.request_adapter() g = AttributeDict({}) @@ -1492,7 +1492,7 @@ def _initialize_context(self, body): g.updated_props = {} return g - def _prepare_callback(self, g, body): + def _prepare_callback(self, g, body: CallbackDispatchBody): """Prepare callback-related data.""" output = body["output"] try: From b64c3fee5d8bc118945dbdb2bb1e086fc9159717 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 16:28:56 -0600 Subject: [PATCH 06/80] Use python3.8 compatible pydantic --- requirements/install.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/install.txt b/requirements/install.txt index f65522326e..119a8fe523 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -8,4 +8,4 @@ retrying nest-asyncio setuptools janus>=1.0.0 -pydantic>=2.12.5 +pydantic>=2.10 From a85aa6ccbdbfc82c66b87ccecd5d79add60f4bdc Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 16:43:50 -0600 Subject: [PATCH 07/80] lint --- dash/development/base_component.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dash/development/base_component.py b/dash/development/base_component.py index 70ed17e9de..ca52837dfe 100644 --- a/dash/development/base_component.py +++ b/dash/development/base_component.py @@ -120,6 +120,7 @@ class Component(metaclass=ComponentMeta): @classmethod def __get_pydantic_core_schema__(cls, source_type, handler): from pydantic_core import core_schema + return core_schema.any_schema() @classmethod @@ -129,7 +130,9 @@ def __get_pydantic_json_schema__(cls, schema, handler): "type": "object", "properties": { "type": {"type": "string"}, - "namespace": {"type": "string", "enum": namespaces} if namespaces else {"type": "string"}, + "namespace": {"type": "string", "enum": namespaces} + if namespaces + else {"type": "string"}, "props": {"type": "object"}, }, } From 4ecaf7db2597b984ef3ecf17171f331b6db3812d Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 17:03:03 -0600 Subject: [PATCH 08/80] Fix lint error on CI --- dash/development/base_component.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dash/development/base_component.py b/dash/development/base_component.py index ca52837dfe..d7a473c25e 100644 --- a/dash/development/base_component.py +++ b/dash/development/base_component.py @@ -118,13 +118,13 @@ class Component(metaclass=ComponentMeta): available_wildcard_properties: typing.List[str] @classmethod - def __get_pydantic_core_schema__(cls, source_type, handler): - from pydantic_core import core_schema + def __get_pydantic_core_schema__(cls, _source_type, _handler): + from pydantic_core import core_schema # pylint: disable=import-outside-toplevel return core_schema.any_schema() @classmethod - def __get_pydantic_json_schema__(cls, schema, handler): + def __get_pydantic_json_schema__(cls, _schema, _handler): namespaces = list(ComponentRegistry.namespace_to_package.keys()) return { "type": "object", From 069b0494a8b84b496f82e8fc7c51c14f825868b0 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 6 Apr 2026 12:00:13 -0600 Subject: [PATCH 09/80] Rename layout.py to _layout_utils.py --- dash/{layout.py => _layout_utils.py} | 0 tests/unit/test_layout.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename dash/{layout.py => _layout_utils.py} (100%) diff --git a/dash/layout.py b/dash/_layout_utils.py similarity index 100% rename from dash/layout.py rename to dash/_layout_utils.py diff --git a/tests/unit/test_layout.py b/tests/unit/test_layout.py index 76a72f7fb4..64fff724a1 100644 --- a/tests/unit/test_layout.py +++ b/tests/unit/test_layout.py @@ -1,9 +1,9 @@ -"""Tests for dash.layout — layout traversal and component lookup utilities.""" +"""Tests for dash._layout_utils — layout traversal and component lookup utilities.""" import pytest from dash import html, dcc -from dash.layout import ( +from dash._layout_utils import ( traverse, find_component, extract_text, From d684507fe3e29c8aa2225135e632f79faafb114d Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 6 Apr 2026 13:41:38 -0600 Subject: [PATCH 10/80] Make NumberType import backwards-compatible in generated components --- dash/development/_py_components_generation.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dash/development/_py_components_generation.py b/dash/development/_py_components_generation.py index b597283a04..2fd6a6cdb7 100644 --- a/dash/development/_py_components_generation.py +++ b/dash/development/_py_components_generation.py @@ -24,7 +24,12 @@ import typing # noqa: F401 from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 from dash.development.base_component import Component, _explicitize_args -from dash.types import NumberType # noqa: F401 +try: + from dash.types import NumberType # noqa: F401 +except ImportError: + NumberType = typing.Union[ # noqa: F401 + typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex + ] {custom_imports} ComponentSingleType = typing.Union[str, int, float, Component, None] ComponentType = typing.Union[ From 17a754164773824bd24113ca705f5aea26e7df66 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 6 Apr 2026 14:34:19 -0600 Subject: [PATCH 11/80] Fix type checker for NumberType implementation --- dash/development/_py_components_generation.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dash/development/_py_components_generation.py b/dash/development/_py_components_generation.py index 2fd6a6cdb7..1690a1744e 100644 --- a/dash/development/_py_components_generation.py +++ b/dash/development/_py_components_generation.py @@ -22,14 +22,18 @@ import_string = """# AUTO GENERATED FILE - DO NOT EDIT import typing # noqa: F401 +from typing import TYPE_CHECKING # noqa: F401 from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 from dash.development.base_component import Component, _explicitize_args -try: +if TYPE_CHECKING: from dash.types import NumberType # noqa: F401 -except ImportError: - NumberType = typing.Union[ # noqa: F401 - typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex - ] +else: + try: + from dash.types import NumberType # noqa: F401 + except ImportError: + NumberType = typing.Union[ # noqa: F401 + typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex + ] {custom_imports} ComponentSingleType = typing.Union[str, int, float, Component, None] ComponentType = typing.Union[ From dc35bff200091e79701c808fe39ee473260931bd Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 7 Apr 2026 10:30:13 -0600 Subject: [PATCH 12/80] Clean up import for pyright --- dash/development/_py_components_generation.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/dash/development/_py_components_generation.py b/dash/development/_py_components_generation.py index 1690a1744e..87211c7cc8 100644 --- a/dash/development/_py_components_generation.py +++ b/dash/development/_py_components_generation.py @@ -22,18 +22,17 @@ import_string = """# AUTO GENERATED FILE - DO NOT EDIT import typing # noqa: F401 -from typing import TYPE_CHECKING # noqa: F401 from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 from dash.development.base_component import Component, _explicitize_args -if TYPE_CHECKING: +try: from dash.types import NumberType # noqa: F401 -else: - try: - from dash.types import NumberType # noqa: F401 - except ImportError: - NumberType = typing.Union[ # noqa: F401 - typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex - ] +except ImportError: + # Backwards compatibility for dash<=4.1.0 + if typing.TYPE_CHECKING: + raise + NumberType = typing.Union[ # noqa: F401 + typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex + ] {custom_imports} ComponentSingleType = typing.Union[str, int, float, Component, None] ComponentType = typing.Union[ From 00eb203db2f25571f39c9965d039a433312d274e Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 10:37:02 -0600 Subject: [PATCH 13/80] Add callback adapter core for MCP tool generation --- dash/_callback.py | 12 +- dash/mcp/primitives/tools/callback_adapter.py | 461 ++++++++++++++++++ .../tools/callback_adapter_collection.py | 154 ++++++ dash/mcp/primitives/tools/callback_utils.py | 36 ++ .../primitives/tools/descriptions/__init__.py | 7 + .../tools/input_schemas/__init__.py | 5 + .../tools/output_schemas/__init__.py | 5 + dash/mcp/types/__init__.py | 26 + dash/mcp/types/callback_types.py | 33 ++ dash/mcp/types/component_types.py | 20 + dash/mcp/types/exceptions.py | 30 ++ dash/mcp/types/typing_utils.py | 28 ++ requirements/install.txt | 1 + tests/unit/mcp/conftest.py | 6 + tests/unit/mcp/tools/test_callback_adapter.py | 227 +++++++++ .../tools/test_callback_adapter_collection.py | 145 ++++++ 16 files changed, 1193 insertions(+), 3 deletions(-) create mode 100644 dash/mcp/primitives/tools/callback_adapter.py create mode 100644 dash/mcp/primitives/tools/callback_adapter_collection.py create mode 100644 dash/mcp/primitives/tools/callback_utils.py create mode 100644 dash/mcp/primitives/tools/descriptions/__init__.py create mode 100644 dash/mcp/primitives/tools/input_schemas/__init__.py create mode 100644 dash/mcp/primitives/tools/output_schemas/__init__.py create mode 100644 dash/mcp/types/__init__.py create mode 100644 dash/mcp/types/callback_types.py create mode 100644 dash/mcp/types/component_types.py create mode 100644 dash/mcp/types/exceptions.py create mode 100644 dash/mcp/types/typing_utils.py create mode 100644 tests/unit/mcp/conftest.py create mode 100644 tests/unit/mcp/tools/test_callback_adapter.py create mode 100644 tests/unit/mcp/tools/test_callback_adapter_collection.py diff --git a/dash/_callback.py b/dash/_callback.py index a0d5a1021d..96a5e8f16c 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -39,6 +39,7 @@ from .background_callback.managers import BaseBackgroundCallbackManager from ._callback_context import context_value +from .types import CallbackDispatchResponse from ._no_update import NoUpdate from . import _validate @@ -85,6 +86,7 @@ def callback( hidden: Optional[bool] = None, websocket: Optional[bool] = False, persistent: Optional[bool] = False, + mcp_enabled: bool = True, **_kwargs, ) -> Callable[[Callable[Params, ReturnVar]], Callable[Params, ReturnVar]]: """ @@ -242,6 +244,7 @@ def callback( hidden=hidden, websocket=websocket, persistent=persistent, + mcp_enabled=mcp_enabled, ) return cast( @@ -295,6 +298,7 @@ def insert_callback( hidden=None, websocket=False, persistent=False, + mcp_enabled=True, ) -> str: if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks @@ -338,6 +342,7 @@ def insert_callback( "allow_dynamic_callbacks": dynamic_creator, "no_output": no_output, "websocket": websocket, + "mcp_enabled": mcp_enabled, } callback_list.append(callback_spec) @@ -546,7 +551,7 @@ def _prepare_response( output_value, output_spec, multi, - response, + response: CallbackDispatchResponse, callback_ctx, app, original_packages, @@ -677,6 +682,7 @@ def register_callback( hidden=_kwargs.get("hidden", None), websocket=_kwargs.get("websocket", False), persistent=_kwargs.get("persistent", False), + mcp_enabled=_kwargs.get("mcp_enabled", True), ) # pylint: disable=too-many-locals @@ -711,7 +717,7 @@ def add_context(*args, **kwargs): args, kwargs, inputs_state_indices, has_output, insert_output ) - response: dict = {"multi": True} # type: ignore + response: CallbackDispatchResponse = {"multi": True} jsonResponse: Optional[str] = None try: if background is not None: @@ -783,7 +789,7 @@ async def async_add_context(*args, **kwargs): args, kwargs, inputs_state_indices, has_output, insert_output ) - response = {"multi": True} + response: CallbackDispatchResponse = {"multi": True} try: if background is not None: diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py new file mode 100644 index 0000000000..0f50d15c03 --- /dev/null +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -0,0 +1,461 @@ +"""Adapter: Dash callback → MCP tool interface. + +Wraps a raw ``callback_map`` entry and exposes MCP-facing +properties (tool name, params, outputs) lazily. +""" + +from __future__ import annotations + +import inspect +import json +import typing +from functools import cached_property +from typing import Any + +from mcp.types import Tool + +from dash import get_app +from dash.layout import ( + _WILDCARD_VALUES, + find_component, + find_matching_components, + parse_wildcard_id, +) +from dash.mcp.types import is_nullable +from dash._grouping import flatten_grouping +from dash._utils import clean_property_name, split_callback_id +from dash.mcp.types import MCPInput, MCPOutput +from .callback_utils import run_callback +from .descriptions import build_tool_description +from .input_schemas import get_input_schema +from .output_schemas import get_output_schema + + +class CallbackAdapter: + """Adapts a single Dash callback_map entry to the MCP tool interface.""" + + def __init__(self, callback_output_id: str): + self._output_id = callback_output_id + + # ------------------------------------------------------------------- + # Projections + # ------------------------------------------------------------------- + + @cached_property + def as_mcp_tool(self) -> Tool: + """Stub — will be implemented in a future PR.""" + raise NotImplementedError("as_mcp_tool will be implemented in a future PR.") + + def as_callback_body(self, kwargs: dict[str, Any]) -> dict[str, Any]: + """Transforms the given kwargs to a dict suitable for calling this callback. + + Mirrors how the Dash renderer assembles the callback payload — + see ``fillVals()`` in ``dash-renderer/src/actions/callbacks.ts``. + + For pattern-matching callbacks, wildcard deps are expanded into + nested arrays with concrete component IDs. + """ + coerced = {k: _coerce_value(v) for k, v in kwargs.items()} + + raw_inputs = self._cb_info.get("inputs", []) + raw_state = self._cb_info.get("state", []) + n_deps = len(raw_inputs) + len(raw_state) + + flat_values = [None] * n_deps + for i, name in enumerate(self._param_names): + if i < n_deps and name in coerced: + flat_values[i] = coerced[name] + + inputs_with_values = [ + _expand_dep(dep, flat_values[i]) for i, dep in enumerate(raw_inputs) + ] + state_with_values = [ + _expand_dep(dep, flat_values[len(raw_inputs) + i]) + for i, dep in enumerate(raw_state) + ] + + outputs_spec = _expand_output_spec( + self._output_id, self._cb_info, inputs_with_values + ) + + # changedPropIds: only inputs with non-None values. + # This determines ctx.triggered_id in the callback. + changed = [] + for entry in inputs_with_values: + if isinstance(entry, dict) and entry.get("value") is not None: + eid = entry.get("id") + if isinstance(eid, dict): + changed.append( + f"{json.dumps(eid, sort_keys=True)}.{entry['property']}" + ) + elif isinstance(eid, str): + changed.append(f"{eid}.{entry['property']}") + + return { + "output": self._output_id, + "outputs": outputs_spec, + "inputs": inputs_with_values, + "state": state_with_values, + "changedPropIds": changed, + } + + # ------------------------------------------------------------------- + # Public identity and metadata + # ------------------------------------------------------------------- + + @cached_property + def is_valid(self) -> bool: + """Whether all input components exist in the layout.""" + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + for dep in all_deps: + dep_id = str(dep.get("id", "")) + if dep_id.startswith("{"): + continue + if find_component(dep_id) is None: + return False + return True + + @property + def output_id(self) -> str: + return self._output_id + + @property + def tool_name(self) -> str: + return get_app().mcp_callback_map._tool_names_map[self._output_id] + + @cached_property + def prevents_initial_call(self) -> bool: + for cb in get_app()._callback_list: + if cb["output"] == self._output_id: + return cb.get("prevent_initial_call", False) + return False + + # ------------------------------------------------------------------- + # Private: computed fields for the MCP Tool + # ------------------------------------------------------------------- + + @cached_property + def _description(self) -> str: + return build_tool_description(self.outputs, self._docstring) + + @cached_property + def _input_schema(self) -> dict[str, Any]: + properties = {p["name"]: get_input_schema(p) for p in self.inputs} + required = [p["name"] for p in self.inputs if p["required"]] + + input_schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + input_schema["required"] = required + return input_schema + + @cached_property + def _output_schema(self) -> dict[str, Any]: + return get_output_schema() + + # ------------------------------------------------------------------- + # Private: callback metadata + # ------------------------------------------------------------------- + + @cached_property + def _docstring(self) -> str | None: + return getattr(self._original_func, "__doc__", None) + + @cached_property + def _initial_output(self) -> dict[str, dict[str, Any]]: + """Run this callback with initial input values. + + Returns the ``response`` portion of the dispatch result: + ``{component_id: {property: value}}``. + + Skipped for callbacks with ``prevent_initial_call=True``, + matching how the Dash renderer skips them on page load. + """ + if self.prevents_initial_call: + return {} + + callback_map = get_app().mcp_callback_map + kwargs = {} + for p in self.inputs: + upstream = callback_map.find_by_output(p["id_and_prop"]) + if upstream is self: + kwargs[p["name"]] = getattr( + find_component(p["component_id"]), p["property"], None + ) + else: + kwargs[p["name"]] = callback_map.get_initial_value(p["id_and_prop"]) + try: + result = run_callback(self, kwargs) + return result.get("response", {}) + except Exception: + return {} + + def initial_output_value(self, id_and_prop: str) -> Any: + """Return the initial value for a specific output ``"component_id.property"``.""" + component_id, prop = id_and_prop.rsplit(".", 1) + return self._initial_output.get(component_id, {}).get(prop) + + @cached_property + def outputs(self) -> list[MCPOutput]: + if self._cb_info.get("no_output"): + return [] + parsed = split_callback_id(self._output_id) + if isinstance(parsed, dict): + parsed = [parsed] + result: list[MCPOutput] = [] + for p in parsed: + comp_id = p["id"] + prop = clean_property_name(p["property"]) + id_and_prop = f"{comp_id}.{prop}" + comp = find_component(comp_id) + result.append( + { + "id_and_prop": id_and_prop, + "component_id": comp_id, + "property": prop, + "component_type": getattr(comp, "_type", None), + "initial_value": self.initial_output_value(id_and_prop), + "tool_name": self.tool_name, + } + ) + return result + + @cached_property + def inputs(self) -> list[MCPInput]: + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + callback_map = get_app().mcp_callback_map + + result: list[MCPInput] = [] + for dep, name, annotation in zip( + all_deps, self._param_names, self._param_annotations + ): + comp_id = str(dep.get("id", "unknown")) + comp = find_component(comp_id) + prop = dep.get("property", "unknown") + id_and_prop = f"{comp_id}.{prop}" + + upstream_cb = callback_map.find_by_output(id_and_prop) + upstream_output = None + if upstream_cb is not None and upstream_cb is not self: + if not upstream_cb.prevents_initial_call: + for out in upstream_cb.outputs: + if out["id_and_prop"] == id_and_prop: + upstream_output = out + break + + initial_value = ( + upstream_output["initial_value"] + if upstream_output is not None + else getattr(comp, prop, None) + ) + + if annotation is not None: + required = not is_nullable(annotation) + else: + required = initial_value is not None + + result.append( + { + "name": name, + "id_and_prop": id_and_prop, + "component_id": comp_id, + "property": prop, + "annotation": annotation, + "component_type": getattr(comp, "_type", None), + "component": comp, + "required": required, + "initial_value": initial_value, + "upstream_output": upstream_output, + } + ) + return result + + # ------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------- + + @cached_property + def _cb_info(self) -> dict[str, Any]: + return get_app().callback_map[self._output_id] + + @cached_property + def _original_func(self) -> Any | None: + func = self._cb_info.get("callback") + return getattr(func, "__wrapped__", func) + + @cached_property + def _func_signature(self) -> inspect.Signature | None: + if self._original_func is None: + return None + try: + return inspect.signature(self._original_func) + except (ValueError, TypeError): + return None + + @cached_property + def _dep_param_map(self) -> list[tuple[str, str]]: + """(func_param_name, mcp_param_name) per dep, in dep order. + + Single source of truth for mapping deps to param names. + All dict-vs-list branching is confined here. + """ + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + n_deps = len(all_deps) + indices = self._cb_info.get("inputs_state_indices") + + if isinstance(indices, dict): + entries: list[tuple[int, str, str]] = [] + for func_name, idx in indices.items(): + positions = flatten_grouping(idx) + if len(positions) == 1: + entries.append((positions[0], func_name, func_name)) + else: + for pos in positions: + dep = all_deps[pos] if pos < n_deps else {} + comp_id = str(dep.get("id", "unknown")).replace("-", "_") + prop = dep.get("property", "unknown") + entries.append( + (pos, func_name, f"{func_name}_{comp_id}__{prop}") + ) + entries.sort(key=lambda e: e[0]) + result = [(f, m) for _, f, m in entries] + elif self._func_signature is not None: + names = list(self._func_signature.parameters.keys()) + result = [(n, n) for n in names] + else: + result = [] + + while len(result) < n_deps: + fallback = f"param_{len(result)}" + result.append((fallback, fallback)) + return result + + @cached_property + def _param_names(self) -> list[str]: + """MCP param name per dep, in dep order.""" + return [mcp for _, mcp in self._dep_param_map] + + @cached_property + def _param_annotations(self) -> list[Any | None]: + """One annotation per dep, in dep order.""" + if self._func_signature is None: + return [None] * len(self._dep_param_map) + try: + hints = typing.get_type_hints(self._original_func) + except Exception: + hints = getattr(self._original_func, "__annotations__", {}) + return [hints.get(func_name) for func_name, _ in self._dep_param_map] + + +def _expand_dep(dep: dict, value: Any) -> Any: + """Expand a dependency into the dispatch format. + + For regular deps, returns ``{id, property, value}``. + For ALL/ALLSMALLER: passes through the list of ``{id, property, value}`` dicts. + For MATCH: passes through the single ``{id, property, value}`` dict. + """ + pattern = parse_wildcard_id(dep.get("id", "")) + if pattern is None: + return {**dep, "value": value} + + # LLM provides browser-like format + if isinstance(value, list): + return value + if isinstance(value, dict) and "id" in value: + return value + return {**dep, "value": value} + + +def _expand_output_spec(output_id: str, cb_info: dict, resolved_inputs: list) -> Any: + """Build the outputs spec, expanding wildcards to concrete IDs. + + For wildcard outputs, derives concrete IDs from the resolved inputs. + The browser does the same: input and output wildcards resolve against + the same set of matching components. + """ + if cb_info.get("no_output"): + return [] + + parsed = split_callback_id(output_id) + if isinstance(parsed, dict): + parsed = [parsed] + + results = [] + for p in parsed: + pid = p["id"] + prop = clean_property_name(p["property"]) + pattern = parse_wildcard_id(pid) + if pattern is not None: + concrete_ids = _derive_output_ids(pattern, resolved_inputs) + if not concrete_ids: + concrete_ids = [comp.id for comp in find_matching_components(pattern)] + expanded = [{"id": cid, "property": prop} for cid in concrete_ids] + # ALL/ALLSMALLER → nested list; MATCH → single dict + if len(expanded) == 1: + results.append(expanded[0]) + else: + results.append(expanded) + else: + results.append({"id": pid, "property": prop}) + + if len(results) == 1: + return results[0] + return results + + +def _derive_output_ids( + output_pattern: dict, resolved_inputs: list +) -> list[dict] | None: + """Derive concrete output IDs from the resolved input entries. + + Extracts the wildcard key values from the LLM-provided concrete + input IDs and substitutes them into the output pattern. + """ + wildcard_keys = [ + k + for k, v in output_pattern.items() + if isinstance(v, list) and len(v) == 1 and v[0] in _WILDCARD_VALUES + ] + if not wildcard_keys: + return None + + def _substitute(item_id: dict) -> dict | None: + if not isinstance(item_id, dict): + return None + output_id = dict(output_pattern) + for wk in wildcard_keys: + if wk in item_id: + output_id[wk] = item_id[wk] + return output_id + + for entry in resolved_inputs: + # ALL/ALLSMALLER: nested array of {id, property, value} dicts + if isinstance(entry, list) and entry: + concrete_ids = [] + for item in entry: + out = _substitute(item.get("id")) + if out: + concrete_ids.append(out) + if concrete_ids: + return concrete_ids + # MATCH: single {id, property, value} dict + elif isinstance(entry, dict) and isinstance(entry.get("id"), dict): + out = _substitute(entry["id"]) + if out: + return [out] + + return None + + +def _coerce_value(value: Any) -> Any: + """Parse JSON strings back to Python objects. + + MCP tool parameters arrive as strings. This recovers the + intended type (list, dict, number, bool, null) via json.loads. + Plain strings that aren't valid JSON pass through unchanged. + """ + if not isinstance(value, str): + return value + try: + return json.loads(value) + except (json.JSONDecodeError, ValueError): + return value diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py new file mode 100644 index 0000000000..60e9e2efe5 --- /dev/null +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -0,0 +1,154 @@ +"""Collection of CallbackAdapters with cross-adapter queries. + +Stored as a singleton on ``app.mcp_callback_map``. +""" + +from __future__ import annotations + +import hashlib +import re +from functools import cached_property +from typing import Any + +from mcp.types import Tool + +from dash import get_app +from dash._utils import clean_property_name, split_callback_id +from dash.layout import extract_text, find_component, traverse +from .callback_adapter import CallbackAdapter + + +class CallbackAdapterCollection: + def __init__(self, app): + callback_map = getattr(app, "callback_map", {}) + + raw: list[tuple[str, dict]] = [] + for output_id, cb_info in callback_map.items(): + if cb_info.get("mcp_enabled") is False: + continue + if "callback" not in cb_info: + continue + raw.append((output_id, cb_info)) + + self._tool_names_map = self._build_tool_names(raw) + self._callbacks = [ + CallbackAdapter(callback_output_id=output_id) + for output_id in self._tool_names_map + ] + # TODO: enable_mcp_server() will replace this with a direct assignment on app + app.mcp_callback_map = self + + @staticmethod + def _sanitize_name(name: str) -> str: + + max_len = 64 + sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name) + sanitized = re.sub(r"_+", "_", sanitized).strip("_") + if sanitized and sanitized[0].isdigit(): + sanitized = "cb_" + sanitized + full = sanitized or "unnamed_callback" + if len(full) <= max_len: + return full + hash_suffix = hashlib.sha256(full.encode()).hexdigest()[:8] + truncated = sanitized[: max_len - 9].rstrip("_") + return f"{truncated}_{hash_suffix}" + + @classmethod + def _build_tool_names(cls, raw: list[tuple[str, dict]]) -> dict[str, str]: + func_name_counts: dict[str, int] = {} + for _output_id, cb_info in raw: + func = cb_info.get("callback") + original = getattr(func, "__wrapped__", func) + fn = getattr(original, "__name__", "") or "" + func_name_counts[fn] = func_name_counts.get(fn, 0) + 1 + + name_map: dict[str, str] = {} + for output_id, cb_info in raw: + func = cb_info.get("callback") + original = getattr(func, "__wrapped__", func) + fn = getattr(original, "__name__", "") or "" + raw_name = fn if fn and func_name_counts[fn] == 1 else output_id + name_map[output_id] = cls._sanitize_name(raw_name) + return name_map + + def __iter__(self): + return iter(self._callbacks) + + def __len__(self): + return len(self._callbacks) + + def __getitem__(self, index): + return self._callbacks[index] + + def find_by_tool_name(self, name: str) -> CallbackAdapter | None: + for cb in self._callbacks: + if cb.tool_name == name: + return cb + return None + + def find_by_output(self, id_and_prop: str) -> CallbackAdapter | None: + """Find the adapter that outputs to ``id_and_prop`` (``"component_id.property"``).""" + for cb in self._callbacks: + try: + parsed = split_callback_id(cb.output_id) + except ValueError: + continue + if isinstance(parsed, dict): + parsed = [parsed] + for p in parsed: + if f"{p['id']}.{clean_property_name(p['property'])}" == id_and_prop: + return cb + return None + + def get_initial_value(self, id_and_prop: str) -> Any: + """Return the initial value for ``id_and_prop`` (``"component_id.property"``). + + If a callback outputs to this property, runs it (recursively + resolving its inputs). Otherwise returns the layout default. + """ + upstream_cb = self.find_by_output(id_and_prop) + if upstream_cb is not None: + return upstream_cb.initial_output_value(id_and_prop) + else: + component_id, prop = id_and_prop.rsplit(".", 1) + layout_component = find_component(component_id) + return getattr(layout_component, prop, None) + + def as_mcp_tools(self) -> list[Tool]: + """Stub — will be implemented in a future PR.""" + raise NotImplementedError("as_mcp_tools will be implemented in a future PR.") + + @property + def tool_names(self) -> set[str]: + return set(self._tool_names_map.values()) + + @cached_property + def component_label_map(self) -> dict[str, list[str]]: + """Map component ID → list of label texts from html.Label containers + and/or `htmlFor` associations. + """ + layout = get_app().get_layout() + if layout is None: + return {} + + labels: dict[str, list[str]] = {} + for comp, ancestors in traverse(layout): + if getattr(comp, "_type", None) == "Label": + html_for = getattr(comp, "htmlFor", None) + if html_for is not None: + text = extract_text(comp) + if text: + labels.setdefault(str(html_for), []).append(text) + + comp_id = getattr(comp, "id", None) + if comp_id is not None: + for ancestor in reversed(ancestors): + if getattr(ancestor, "_type", None) == "Label": + text = extract_text(ancestor) + if text: + sid = str(comp_id) + if text not in labels.get(sid, []): + labels.setdefault(sid, []).append(text) + break + + return labels diff --git a/dash/mcp/primitives/tools/callback_utils.py b/dash/mcp/primitives/tools/callback_utils.py new file mode 100644 index 0000000000..ec157b6037 --- /dev/null +++ b/dash/mcp/primitives/tools/callback_utils.py @@ -0,0 +1,36 @@ +"""Callback introspection utilities for MCP tools.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from dash import get_app + +if TYPE_CHECKING: + from .callback_adapter import CallbackAdapter + + +def run_callback(callback: CallbackAdapter, kwargs: dict[str, Any]) -> dict[str, Any]: + """Execute a callback via Dash's dispatch pipeline.""" + from dash.mcp.types import CallbackExecutionError + + body = callback.as_callback_body(kwargs) + + app = get_app() + with app.server.test_request_context( + "/_dash-update-component", + method="POST", + data=json.dumps(body, default=str), + content_type="application/json", + ): + response = app.dispatch() + + response_text = response.get_data(as_text=True) + if response.status_code != 200: + raise CallbackExecutionError( + f"Callback {callback.output_id} failed " + f"(HTTP {response.status_code}): {response_text[:500]}" + ) + + return json.loads(response_text) diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py new file mode 100644 index 0000000000..67ec78c9ff --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -0,0 +1,7 @@ +"""Stub — real implementation in a later PR.""" + + +def build_tool_description(outputs, docstring=None): + if docstring: + return docstring.strip() + return "Dash callback" diff --git a/dash/mcp/primitives/tools/input_schemas/__init__.py b/dash/mcp/primitives/tools/input_schemas/__init__.py new file mode 100644 index 0000000000..f306042a0c --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/__init__.py @@ -0,0 +1,5 @@ +"""Stub — real implementation in a later PR.""" + + +def get_input_schema(param): + return {} diff --git a/dash/mcp/primitives/tools/output_schemas/__init__.py b/dash/mcp/primitives/tools/output_schemas/__init__.py new file mode 100644 index 0000000000..d2d70c3552 --- /dev/null +++ b/dash/mcp/primitives/tools/output_schemas/__init__.py @@ -0,0 +1,5 @@ +"""Stub — real implementation in a later PR.""" + + +def get_output_schema(): + return {} diff --git a/dash/mcp/types/__init__.py b/dash/mcp/types/__init__.py new file mode 100644 index 0000000000..af588e0808 --- /dev/null +++ b/dash/mcp/types/__init__.py @@ -0,0 +1,26 @@ +"""MCP types, exceptions, and typing utilities.""" + +from dash.mcp.types.callback_types import MCPInput, MCPOutput +from dash.mcp.types.component_types import ( + ComponentPropertyInfo, + ComponentQueryResult, +) +from dash.mcp.types.exceptions import ( + CallbackExecutionError, + InvalidParamsError, + MCPError, + ToolNotFoundError, +) +from dash.mcp.types.typing_utils import is_nullable + +__all__ = [ + "CallbackExecutionError", + "ComponentPropertyInfo", + "ComponentQueryResult", + "InvalidParamsError", + "MCPError", + "MCPInput", + "MCPOutput", + "ToolNotFoundError", + "is_nullable", +] diff --git a/dash/mcp/types/callback_types.py b/dash/mcp/types/callback_types.py new file mode 100644 index 0000000000..9c65dcb9d8 --- /dev/null +++ b/dash/mcp/types/callback_types.py @@ -0,0 +1,33 @@ +"""Typed dicts for MCP callback adapter data.""" + +from __future__ import annotations + +from typing import Any + +from typing_extensions import TypedDict + + +class MCPOutput(TypedDict): + """A single callback output, with component type and initial value resolved.""" + + id_and_prop: str + component_id: str + property: str + component_type: str | None + initial_value: Any + tool_name: str + + +class MCPInput(TypedDict): + """A single callback parameter (input or state), fully resolved.""" + + name: str + id_and_prop: str + component_id: str + property: str + annotation: Any | None + component_type: str | None + component: Any | None + required: bool + initial_value: Any + upstream_output: MCPOutput | None diff --git a/dash/mcp/types/component_types.py b/dash/mcp/types/component_types.py new file mode 100644 index 0000000000..0cac3ad689 --- /dev/null +++ b/dash/mcp/types/component_types.py @@ -0,0 +1,20 @@ +"""Typed dicts for component data in MCP.""" + +from __future__ import annotations + +from typing import Any + +from typing_extensions import NotRequired, TypedDict + + +class ComponentPropertyInfo(TypedDict): + initial_value: Any + modified_by_tool: list[str] + input_to_tool: list[str] + + +class ComponentQueryResult(TypedDict): + component_id: str + component_type: str + label: NotRequired[list[str] | None] + properties: dict[str, ComponentPropertyInfo] diff --git a/dash/mcp/types/exceptions.py b/dash/mcp/types/exceptions.py new file mode 100644 index 0000000000..7fb962db85 --- /dev/null +++ b/dash/mcp/types/exceptions.py @@ -0,0 +1,30 @@ +"""MCP error types with JSON-RPC error codes.""" + +from __future__ import annotations + + +class MCPError(Exception): + """Base MCP error carrying a JSON-RPC error code.""" + + code = -32603 + + def __init__(self, message: str): + super().__init__(message) + + +class ToolNotFoundError(MCPError): + """Tool name not found in the callback registry.""" + + code = -32601 + + +class InvalidParamsError(MCPError): + """Invalid or missing parameters for a tool call.""" + + code = -32602 + + +class CallbackExecutionError(MCPError): + """Callback raised an exception during execution.""" + + code = -32603 diff --git a/dash/mcp/types/typing_utils.py b/dash/mcp/types/typing_utils.py new file mode 100644 index 0000000000..9a96d4135d --- /dev/null +++ b/dash/mcp/types/typing_utils.py @@ -0,0 +1,28 @@ +"""Shared typing utilities for the MCP layer.""" + +from __future__ import annotations + +import typing +from typing import Any + + +def is_nullable(annotation: Any) -> bool: + """Check if a type annotation includes NoneType (is nullable/Optional).""" + origin = getattr(annotation, "__origin__", None) + args = getattr(annotation, "__args__", ()) + + _is_union = origin is typing.Union + if not _is_union: + try: + import types as _types + + if isinstance(annotation, _types.UnionType): + _is_union = True + args = annotation.__args__ + except AttributeError: + pass + + if _is_union and args: + return type(None) in args + + return False diff --git a/requirements/install.txt b/requirements/install.txt index 119a8fe523..a976ab9010 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -9,3 +9,4 @@ nest-asyncio setuptools janus>=1.0.0 pydantic>=2.10 +mcp>=1.0.0; python_version>="3.10" diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py new file mode 100644 index 0000000000..437a71db5c --- /dev/null +++ b/tests/unit/mcp/conftest.py @@ -0,0 +1,6 @@ +import sys + +collect_ignore_glob = [] + +if sys.version_info < (3, 10): + collect_ignore_glob.append("*") diff --git a/tests/unit/mcp/tools/test_callback_adapter.py b/tests/unit/mcp/tools/test_callback_adapter.py new file mode 100644 index 0000000000..91808d304e --- /dev/null +++ b/tests/unit/mcp/tools/test_callback_adapter.py @@ -0,0 +1,227 @@ +"""Tests for CallbackAdapter.""" + +import pytest +from dash import Dash, Input, Output, dcc, html +from dash._get_app import app_context + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def simple_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label("Your Name", htmlFor="inp"), + dcc.Input(id="inp", type="text"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """Update output.""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def duplicate_names_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def cb(v): + return v + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb(v): # noqa: F811 + return v + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestFromApp: + def test_returns_list(self, simple_app): + assert len(app_context.get().mcp_callback_map) == 1 + + def test_excludes_clientside(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + html.Div(id="cs-out"), + html.Div(id="srv-out"), + ] + ) + app.clientside_callback( + "function(n) { return n; }", + Output("cs-out", "children"), + Input("btn", "n_clicks"), + ) + + @app.callback(Output("srv-out", "children"), Input("btn", "n_clicks")) + def server_cb(n): + return str(n) + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + names = [a.tool_name for a in app.mcp_callback_map] + assert names == ["server_cb"] + + def test_excludes_mcp_disabled(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp"), + html.Div(id="out1"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("inp", "value")) + def visible(val): + return val + + @app.callback( + Output("out2", "children"), Input("inp", "value"), mcp_enabled=False + ) + def hidden(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + names = [a.tool_name for a in app.mcp_callback_map] + assert "visible" in names + assert "hidden" not in names + + +class TestToolName: + def test_uses_func_name(self, simple_app): + assert app_context.get().mcp_callback_map[0].tool_name == "update" + + def test_duplicates_get_unique_names(self, duplicate_names_app): + names = [a.tool_name for a in app_context.get().mcp_callback_map] + assert len(names) == 2 + assert names[0] != names[1] + + +class TestGetInitialValue: + def test_returns_layout_value(self, simple_app): + callback_map = app_context.get().mcp_callback_map + # Input with no value set — returns None (layout default for dcc.Input) + assert callback_map.get_initial_value("inp.value") is None + + def test_returns_set_value(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(selected): + return selected + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert app.mcp_callback_map.get_initial_value("dd.value") == "a" + + def test_initial_callback_makes_param_required(self): + """A param with None in layout but set by an initial callback is required.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + id="country", options=["France", "Germany"], value="France" + ), + dcc.Dropdown(id="city"), # value=None in layout + html.Div(id="out"), + ] + ) + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("country", "value"), + ) + def update_cities(country): + return [{"label": "Paris", "value": "Paris"}], "Paris" + + @app.callback(Output("out", "children"), Input("city", "value")) + def show_city(city): + return f"Selected: {city}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + # city.value is None in layout but "Paris" after initial callback + with app.server.test_request_context(): + show_city_cb = app.mcp_callback_map.find_by_tool_name("show_city") + city_param = show_city_cb.inputs[0] + assert city_param["name"] == "city" + assert city_param["required"] is True # not optional despite None in layout + + +class TestIsValid: + def test_valid_when_inputs_in_layout(self, simple_app): + assert app_context.get().mcp_callback_map[0].is_valid + + def test_invalid_when_input_not_in_layout(self): + app = Dash(__name__) + app.layout = html.Div([html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("nonexistent", "value")) + def update(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert not app.mcp_callback_map[0].is_valid + + def test_pattern_matching_ids_always_valid(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="a"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input({"type": "field", "index": 0}, "value"), + ) + def update(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert app.mcp_callback_map[0].is_valid diff --git a/tests/unit/mcp/tools/test_callback_adapter_collection.py b/tests/unit/mcp/tools/test_callback_adapter_collection.py new file mode 100644 index 0000000000..c120a2df8b --- /dev/null +++ b/tests/unit/mcp/tools/test_callback_adapter_collection.py @@ -0,0 +1,145 @@ +"""Tests for CallbackAdapterCollection.""" + +from dash import Dash, Input, Output, dcc, html +from dash._get_app import app_context + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + + +def _setup(app): + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + +class TestToolNameCollisions: + @staticmethod + def _make_duplicate_cb_app(n=3): + ids = [f"dd{i + 1}" for i in range(n)] + app = Dash(__name__) + app.layout = html.Div( + [ + item + for i in ids + for item in [ + dcc.Dropdown( + id=i, options=[chr(97 + j) for j in range(1)], value="a" + ), + html.Div(id=f"{i}-output"), + ] + ] + ) + for idx, dd_id in enumerate(ids): + + @app.callback(Output(f"{dd_id}-output", "children"), Input(dd_id, "value")) + def cb(value, _id=dd_id): # noqa: F811 + return f"{_id}: {value}" + + return app + + def test_duplicate_func_names_get_unique_tools(self): + app = self._make_duplicate_cb_app(3) + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert len(tool_names) == 3 + assert len(set(tool_names)) == 3, f"Tool names are not unique: {tool_names}" + for name in tool_names: + assert "dd" in name, f"Expected output ID in tool name: {name}" + + def test_unique_func_names_use_func_name(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def alpha_handler(value): + return value + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def beta_handler(value): + return value + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "alpha_handler" in tool_names + assert "beta_handler" in tool_names + + def test_duplicate_func_names_use_output_id(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="out1"), + html.Div(id="out2"), + html.Div(id="out3"), + html.Div(id="in1"), + html.Div(id="in2"), + html.Div(id="in3"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def unique_func(v): + return v + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb(v): + return v + + @app.callback(Output("out3", "children"), Input("in3", "children")) + def cb(v): # noqa: F811 + return v + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "unique_func" in tool_names + non_unique = [n for n in tool_names if n != "unique_func"] + assert len(non_unique) == 2 + assert non_unique[0] != non_unique[1] + + +class TestAllCallbacksVisibleByDefault: + def test_all_callbacks_visible_by_default(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def cb_one(value): + return value + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb_two(value): + return value + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "cb_one" in tool_names + assert "cb_two" in tool_names + + +class TestAdapterCollection: + def test_adapter_has_expected_properties(self): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + return val + + _setup(app) + adapter = app.mcp_callback_map[0] + assert adapter.tool_name == "update" + assert adapter.output_id == "out.children" From fcfa76518d57b0a888b59fc92970a80b68af74fc Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 16:59:57 -0600 Subject: [PATCH 14/80] Fix type errors --- dash/_callback.py | 2 +- dash/types.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/dash/_callback.py b/dash/_callback.py index 96a5e8f16c..fc7f7312c7 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -563,7 +563,7 @@ def _prepare_response( allow_dynamic_callbacks, ): """Prepare the response object based on the callback output.""" - component_ids = collections.defaultdict(dict) + component_ids: dict = collections.defaultdict(dict) if has_output: if not multi: diff --git a/dash/types.py b/dash/types.py index 43bf16dc30..e392a2d599 100644 --- a/dash/types.py +++ b/dash/types.py @@ -73,3 +73,4 @@ class CallbackDispatchResponse(TypedDict): multi: NotRequired[bool] response: NotRequired[Dict[str, CallbackOutput]] sideUpdate: NotRequired[Dict[str, CallbackSideOutput]] + dist: NotRequired[List[Any]] From 8c89d518dfdd7aecaf8aedc7ee7da0a5ad78bc03 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 6 Apr 2026 12:11:03 -0600 Subject: [PATCH 15/80] Fix import path --- dash/mcp/primitives/tools/callback_adapter.py | 2 +- dash/mcp/primitives/tools/callback_adapter_collection.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index 0f50d15c03..743453af10 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -15,7 +15,7 @@ from mcp.types import Tool from dash import get_app -from dash.layout import ( +from dash._layout_utils import ( _WILDCARD_VALUES, find_component, find_matching_components, diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py index 60e9e2efe5..59c1a7ac47 100644 --- a/dash/mcp/primitives/tools/callback_adapter_collection.py +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -14,7 +14,7 @@ from dash import get_app from dash._utils import clean_property_name, split_callback_id -from dash.layout import extract_text, find_component, traverse +from dash._layout_utils import extract_text, find_component, traverse from .callback_adapter import CallbackAdapter From eeadb7f99fce72dc74f19cb89c4dc4040323b84d Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 13 Apr 2026 16:12:57 -0600 Subject: [PATCH 16/80] tighten up types used around callbacks --- dash/_callback.py | 8 +-- dash/dash.py | 6 +-- dash/mcp/primitives/tools/callback_adapter.py | 54 ++++++++----------- dash/mcp/primitives/tools/callback_utils.py | 5 +- dash/types.py | 35 ++++++++++-- tests/unit/test_pydantic_types.py | 8 +-- 6 files changed, 67 insertions(+), 49 deletions(-) diff --git a/dash/_callback.py b/dash/_callback.py index fc7f7312c7..c8d610fa48 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -39,7 +39,7 @@ from .background_callback.managers import BaseBackgroundCallbackManager from ._callback_context import context_value -from .types import CallbackDispatchResponse +from .types import CallbackExecutionResponse from ._no_update import NoUpdate from . import _validate @@ -551,7 +551,7 @@ def _prepare_response( output_value, output_spec, multi, - response: CallbackDispatchResponse, + response: CallbackExecutionResponse, callback_ctx, app, original_packages, @@ -717,7 +717,7 @@ def add_context(*args, **kwargs): args, kwargs, inputs_state_indices, has_output, insert_output ) - response: CallbackDispatchResponse = {"multi": True} + response: CallbackExecutionResponse = {"multi": True} jsonResponse: Optional[str] = None try: if background is not None: @@ -789,7 +789,7 @@ async def async_add_context(*args, **kwargs): args, kwargs, inputs_state_indices, has_output, insert_output ) - response: CallbackDispatchResponse = {"multi": True} + response: CallbackExecutionResponse = {"multi": True} try: if background is not None: diff --git a/dash/dash.py b/dash/dash.py index 225c36bb47..f887598497 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -75,7 +75,7 @@ _import_layouts_from_pages, ) from ._jupyter import jupyter_dash, JupyterDisplayMode -from .types import CallbackDispatchBody, RendererHooks +from .types import CallbackExecutionBody, RendererHooks RouteCallable = Callable[..., Any] @@ -1469,7 +1469,7 @@ def _inputs_to_vals(self, inputs): return inputs_to_vals(inputs) # pylint: disable=R0915 - def _initialize_context(self, body: CallbackDispatchBody): + def _initialize_context(self, body: CallbackExecutionBody): """Initialize the global context for the request.""" adapter = self.backend.request_adapter() g = AttributeDict({}) @@ -1492,7 +1492,7 @@ def _initialize_context(self, body: CallbackDispatchBody): g.updated_props = {} return g - def _prepare_callback(self, g, body: CallbackDispatchBody): + def _prepare_callback(self, g, body: CallbackExecutionBody): """Prepare callback-related data.""" output = body["output"] try: diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index 743453af10..693d3f8e07 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -21,10 +21,17 @@ find_matching_components, parse_wildcard_id, ) -from dash.mcp.types import is_nullable from dash._grouping import flatten_grouping from dash._utils import clean_property_name, split_callback_id -from dash.mcp.types import MCPInput, MCPOutput +from dash.types import ( + CallbackDependency, + CallbackExecutionBody, + CallbackInput, + CallbackOutput, + CallbackOutputTarget, + WildcardId, +) +from dash.mcp.types import MCPInput, MCPOutput, is_nullable from .callback_utils import run_callback from .descriptions import build_tool_description from .input_schemas import get_input_schema @@ -46,7 +53,7 @@ def as_mcp_tool(self) -> Tool: """Stub — will be implemented in a future PR.""" raise NotImplementedError("as_mcp_tool will be implemented in a future PR.") - def as_callback_body(self, kwargs: dict[str, Any]) -> dict[str, Any]: + def as_callback_body(self, kwargs: dict[str, Any]) -> CallbackExecutionBody: """Transforms the given kwargs to a dict suitable for calling this callback. Mirrors how the Dash renderer assembles the callback payload — @@ -55,16 +62,14 @@ def as_callback_body(self, kwargs: dict[str, Any]) -> dict[str, Any]: For pattern-matching callbacks, wildcard deps are expanded into nested arrays with concrete component IDs. """ - coerced = {k: _coerce_value(v) for k, v in kwargs.items()} - raw_inputs = self._cb_info.get("inputs", []) raw_state = self._cb_info.get("state", []) n_deps = len(raw_inputs) + len(raw_state) flat_values = [None] * n_deps for i, name in enumerate(self._param_names): - if i < n_deps and name in coerced: - flat_values[i] = coerced[name] + if i < n_deps and name in kwargs: + flat_values[i] = kwargs[name] inputs_with_values = [ _expand_dep(dep, flat_values[i]) for i, dep in enumerate(raw_inputs) @@ -161,10 +166,10 @@ def _docstring(self) -> str | None: return getattr(self._original_func, "__doc__", None) @cached_property - def _initial_output(self) -> dict[str, dict[str, Any]]: + def _initial_output(self) -> dict[str, CallbackOutput]: """Run this callback with initial input values. - Returns the ``response`` portion of the dispatch result: + Returns the ``response`` portion of the callback result: ``{component_id: {property: value}}``. Skipped for callbacks with ``prevent_initial_call=True``, @@ -346,8 +351,8 @@ def _param_annotations(self) -> list[Any | None]: return [hints.get(func_name) for func_name, _ in self._dep_param_map] -def _expand_dep(dep: dict, value: Any) -> Any: - """Expand a dependency into the dispatch format. +def _expand_dep(dep: CallbackDependency, value: Any) -> CallbackInput | list[CallbackInput]: + """Attach a concrete value to a callback dependency to produce a valid callback input. For regular deps, returns ``{id, property, value}``. For ALL/ALLSMALLER: passes through the list of ``{id, property, value}`` dicts. @@ -365,7 +370,9 @@ def _expand_dep(dep: dict, value: Any) -> Any: return {**dep, "value": value} -def _expand_output_spec(output_id: str, cb_info: dict, resolved_inputs: list) -> Any: +def _expand_output_spec( + output_id: str, cb_info: dict, resolved_inputs: list[CallbackInput], +) -> list[CallbackOutputTarget]: """Build the outputs spec, expanding wildcards to concrete IDs. For wildcard outputs, derives concrete IDs from the resolved inputs. @@ -379,7 +386,7 @@ def _expand_output_spec(output_id: str, cb_info: dict, resolved_inputs: list) -> if isinstance(parsed, dict): parsed = [parsed] - results = [] + results: list[CallbackOutputTarget] = [] for p in parsed: pid = p["id"] prop = clean_property_name(p["property"]) @@ -397,14 +404,12 @@ def _expand_output_spec(output_id: str, cb_info: dict, resolved_inputs: list) -> else: results.append({"id": pid, "property": prop}) - if len(results) == 1: - return results[0] return results def _derive_output_ids( - output_pattern: dict, resolved_inputs: list -) -> list[dict] | None: + output_pattern: WildcardId, resolved_inputs: list[CallbackInput], +) -> list[WildcardId] | None: """Derive concrete output IDs from the resolved input entries. Extracts the wildcard key values from the LLM-provided concrete @@ -418,7 +423,7 @@ def _derive_output_ids( if not wildcard_keys: return None - def _substitute(item_id: dict) -> dict | None: + def _substitute(item_id: WildcardId) -> WildcardId | None: if not isinstance(item_id, dict): return None output_id = dict(output_pattern) @@ -446,16 +451,3 @@ def _substitute(item_id: dict) -> dict | None: return None -def _coerce_value(value: Any) -> Any: - """Parse JSON strings back to Python objects. - - MCP tool parameters arrive as strings. This recovers the - intended type (list, dict, number, bool, null) via json.loads. - Plain strings that aren't valid JSON pass through unchanged. - """ - if not isinstance(value, str): - return value - try: - return json.loads(value) - except (json.JSONDecodeError, ValueError): - return value diff --git a/dash/mcp/primitives/tools/callback_utils.py b/dash/mcp/primitives/tools/callback_utils.py index ec157b6037..86c7e837f0 100644 --- a/dash/mcp/primitives/tools/callback_utils.py +++ b/dash/mcp/primitives/tools/callback_utils.py @@ -6,13 +6,14 @@ from typing import TYPE_CHECKING, Any from dash import get_app +from dash.types import CallbackExecutionResponse if TYPE_CHECKING: from .callback_adapter import CallbackAdapter -def run_callback(callback: CallbackAdapter, kwargs: dict[str, Any]) -> dict[str, Any]: - """Execute a callback via Dash's dispatch pipeline.""" +def run_callback(callback: CallbackAdapter, kwargs: dict[str, Any]) -> CallbackExecutionResponse: + """Execute a callback via the framework.""" from dash.mcp.types import CallbackExecutionError body = callback.as_callback_body(kwargs) diff --git a/dash/types.py b/dash/types.py index e392a2d599..cbc94b8151 100644 --- a/dash/types.py +++ b/dash/types.py @@ -35,20 +35,45 @@ class RendererHooks(TypedDict): # pylint: disable=too-many-ancestors request_refresh_jwt: NotRequired[str] +WildcardId = Dict[str, Any] +"""A pattern-matching component ID, e.g. ``{"type": "item", "index": 0}``.""" + + class CallbackDependency(TypedDict): - id: Union[str, Dict[str, Any]] + id: Union[str, WildcardId] property: str +CallbackOutputTarget = Union[CallbackDependency, List[CallbackDependency]] +"""One callback Output() declaration resolved against the layout. + +For regular callbacks, a single dependency:: + + {"id": "chart", "property": "figure"} + +For pattern-matching callbacks (ALL/ALLSMALLER), a list of concrete +targets that the wildcard expanded to:: + + [ + {"id": {"type": "item", "index": 0}, "property": "children"}, + {"id": {"type": "item", "index": 1}, "property": "children"}, + ] + +For MATCH, a single dependency with a dict id:: + + {"id": {"type": "item", "index": 0}, "property": "children"} +""" + + class CallbackInput(TypedDict): - id: Union[str, Dict[str, Any]] + id: Union[str, WildcardId] property: str value: Any -class CallbackDispatchBody(TypedDict): +class CallbackExecutionBody(TypedDict): output: str - outputs: List[CallbackDependency] + outputs: List[CallbackOutputTarget] inputs: List[CallbackInput] state: List[CallbackInput] changedPropIds: List[str] @@ -69,7 +94,7 @@ class CallbackDispatchBody(TypedDict): ] -class CallbackDispatchResponse(TypedDict): +class CallbackExecutionResponse(TypedDict): multi: NotRequired[bool] response: NotRequired[Dict[str, CallbackOutput]] sideUpdate: NotRequired[Dict[str, CallbackSideOutput]] diff --git a/tests/unit/test_pydantic_types.py b/tests/unit/test_pydantic_types.py index 75d1dc7f41..389e678028 100644 --- a/tests/unit/test_pydantic_types.py +++ b/tests/unit/test_pydantic_types.py @@ -2,7 +2,7 @@ from pydantic import TypeAdapter -from dash.types import NumberType, CallbackDispatchBody, CallbackDispatchResponse +from dash.types import NumberType, CallbackExecutionBody, CallbackExecutionResponse from dash.development.base_component import Component @@ -25,12 +25,12 @@ def test_schema_has_type_and_props(self): assert "props" in props -class TestCallbackDispatchTypes: +class TestCallbackExecutionTypes: def test_dispatch_body_schema(self): - schema = TypeAdapter(CallbackDispatchBody).json_schema() + schema = TypeAdapter(CallbackExecutionBody).json_schema() assert "output" in schema["properties"] assert "inputs" in schema["properties"] def test_dispatch_response_schema(self): - schema = TypeAdapter(CallbackDispatchResponse).json_schema() + schema = TypeAdapter(CallbackExecutionResponse).json_schema() assert "response" in schema["properties"] From 5b09f990d598182cbc644827853a559aa21f8f6e Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 13 Apr 2026 16:23:41 -0600 Subject: [PATCH 17/80] Remove redundant error code --- dash/mcp/types/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/mcp/types/exceptions.py b/dash/mcp/types/exceptions.py index 7fb962db85..4b860f4bba 100644 --- a/dash/mcp/types/exceptions.py +++ b/dash/mcp/types/exceptions.py @@ -27,4 +27,4 @@ class InvalidParamsError(MCPError): class CallbackExecutionError(MCPError): """Callback raised an exception during execution.""" - code = -32603 + pass From fa99ab096748c88c3abc67358418aabb9a16c411 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 14 Apr 2026 09:10:32 -0600 Subject: [PATCH 18/80] lint --- dash/mcp/primitives/tools/callback_adapter.py | 13 ++++++++----- dash/mcp/primitives/tools/callback_utils.py | 4 +++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index 693d3f8e07..9dcb8d959e 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -351,7 +351,9 @@ def _param_annotations(self) -> list[Any | None]: return [hints.get(func_name) for func_name, _ in self._dep_param_map] -def _expand_dep(dep: CallbackDependency, value: Any) -> CallbackInput | list[CallbackInput]: +def _expand_dep( + dep: CallbackDependency, value: Any +) -> CallbackInput | list[CallbackInput]: """Attach a concrete value to a callback dependency to produce a valid callback input. For regular deps, returns ``{id, property, value}``. @@ -371,7 +373,9 @@ def _expand_dep(dep: CallbackDependency, value: Any) -> CallbackInput | list[Cal def _expand_output_spec( - output_id: str, cb_info: dict, resolved_inputs: list[CallbackInput], + output_id: str, + cb_info: dict, + resolved_inputs: list[CallbackInput], ) -> list[CallbackOutputTarget]: """Build the outputs spec, expanding wildcards to concrete IDs. @@ -408,7 +412,8 @@ def _expand_output_spec( def _derive_output_ids( - output_pattern: WildcardId, resolved_inputs: list[CallbackInput], + output_pattern: WildcardId, + resolved_inputs: list[CallbackInput], ) -> list[WildcardId] | None: """Derive concrete output IDs from the resolved input entries. @@ -449,5 +454,3 @@ def _substitute(item_id: WildcardId) -> WildcardId | None: return [out] return None - - diff --git a/dash/mcp/primitives/tools/callback_utils.py b/dash/mcp/primitives/tools/callback_utils.py index 86c7e837f0..0ff4f5f578 100644 --- a/dash/mcp/primitives/tools/callback_utils.py +++ b/dash/mcp/primitives/tools/callback_utils.py @@ -12,7 +12,9 @@ from .callback_adapter import CallbackAdapter -def run_callback(callback: CallbackAdapter, kwargs: dict[str, Any]) -> CallbackExecutionResponse: +def run_callback( + callback: CallbackAdapter, kwargs: dict[str, Any] +) -> CallbackExecutionResponse: """Execute a callback via the framework.""" from dash.mcp.types import CallbackExecutionError From cd1f4d4492a0b310edd57418071c0b5b0ad89d7a Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 12:46:29 -0600 Subject: [PATCH 19/80] Add MCP resource providers for layout, components, pages, and clientside callbacks --- dash/mcp/primitives/resources/__init__.py | 52 ++++++++++ .../resource_clientside_callbacks.py | 95 +++++++++++++++++++ .../resources/resource_components.py | 59 ++++++++++++ .../primitives/resources/resource_layout.py | 44 +++++++++ .../resources/resource_page_layout.py | 77 +++++++++++++++ .../primitives/resources/resource_pages.py | 76 +++++++++++++++ .../test_resource_clientside_callbacks.py | 54 +++++++++++ .../resources/test_resource_layout.py | 59 ++++++++++++ .../resources/test_resource_page_layout.py | 52 ++++++++++ .../resources/test_resource_pages.py | 78 +++++++++++++++ 10 files changed, 646 insertions(+) create mode 100644 dash/mcp/primitives/resources/__init__.py create mode 100644 dash/mcp/primitives/resources/resource_clientside_callbacks.py create mode 100644 dash/mcp/primitives/resources/resource_components.py create mode 100644 dash/mcp/primitives/resources/resource_layout.py create mode 100644 dash/mcp/primitives/resources/resource_page_layout.py create mode 100644 dash/mcp/primitives/resources/resource_pages.py create mode 100644 tests/unit/mcp/primitives/resources/test_resource_clientside_callbacks.py create mode 100644 tests/unit/mcp/primitives/resources/test_resource_layout.py create mode 100644 tests/unit/mcp/primitives/resources/test_resource_page_layout.py create mode 100644 tests/unit/mcp/primitives/resources/test_resource_pages.py diff --git a/dash/mcp/primitives/resources/__init__.py b/dash/mcp/primitives/resources/__init__.py new file mode 100644 index 0000000000..da93feae04 --- /dev/null +++ b/dash/mcp/primitives/resources/__init__.py @@ -0,0 +1,52 @@ +"""MCP resource listing and read handling. + +Each resource module exports: +- ``URI`` — the URI prefix this module handles +- ``get_resource() -> Resource | None`` +- ``get_template() -> ResourceTemplate | None`` +- ``read_resource(uri) -> ReadResourceResult`` + +Dispatch is by prefix match: more specific prefixes must come first. +""" + +from __future__ import annotations + +from mcp.types import ( + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, +) + +from . import ( + resource_clientside_callbacks as _clientside, + resource_components as _components, + resource_layout as _layout, + resource_page_layout as _page_layout, + resource_pages as _pages, +) + +_RESOURCE_MODULES = [_layout, _components, _pages, _clientside, _page_layout] + + +def list_resources() -> ListResourcesResult: + """Build the MCP resources/list response.""" + resources = [ + r for mod in _RESOURCE_MODULES for r in [mod.get_resource()] if r is not None + ] + return ListResourcesResult(resources=resources) + + +def list_resource_templates() -> ListResourceTemplatesResult: + """Build the MCP resources/templates/list response.""" + templates = [ + t for mod in _RESOURCE_MODULES for t in [mod.get_template()] if t is not None + ] + return ListResourceTemplatesResult(resourceTemplates=templates) + + +def read_resource(uri: str) -> ReadResourceResult: + """Dispatch a resources/read request by URI prefix match.""" + for mod in _RESOURCE_MODULES: + if uri.startswith(mod.URI): + return mod.read_resource(uri) + raise ValueError(f"Unknown resource URI: {uri}") diff --git a/dash/mcp/primitives/resources/resource_clientside_callbacks.py b/dash/mcp/primitives/resources/resource_clientside_callbacks.py new file mode 100644 index 0000000000..dbc3009edb --- /dev/null +++ b/dash/mcp/primitives/resources/resource_clientside_callbacks.py @@ -0,0 +1,95 @@ +"""Clientside callbacks resource.""" + +from __future__ import annotations + +import json +from typing import Any + +from mcp.types import ( + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) + +from dash import get_app +from dash._utils import clean_property_name, split_callback_id + +URI = "dash://clientside-callbacks" + + +def get_resource() -> Resource | None: + if not _get_clientside_callbacks(): + return None + return Resource( + uri=URI, + name="dash_clientside_callbacks", + description=( + "Actions the user can take manually in the browser " + "to affect clientside state. Inputs describe the " + "components that can be changed to trigger an effect. " + "Outputs describe the components that will change " + "in response." + ), + mimeType="application/json", + ) + + +def get_template() -> ResourceTemplate | None: + return None + + +def read_resource(uri: str = "") -> ReadResourceResult: + data = { + "description": ( + "These are actions that the user can take manually in the " + "browser to affect the clientside state. Inputs describe " + "the components that can be changed to trigger an effect. " + "Outputs describe the components that will change in " + "response to the effect." + ), + "callbacks": _get_clientside_callbacks(), + } + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=URI, + mimeType="application/json", + text=json.dumps(data, default=str), + ) + ] + ) + + +def _get_clientside_callbacks() -> list[dict[str, Any]]: + """Get input/output mappings for clientside callbacks.""" + app = get_app() + callbacks = [] + callback_map = getattr(app, "callback_map", {}) + + for output_id, callback_info in callback_map.items(): + if "callback" in callback_info: + continue + normalize_deps = lambda deps: [ + { + "component_id": str(d.get("id", "unknown")), + "property": d.get("property", "unknown"), + } + for d in deps + ] + parsed = split_callback_id(output_id) + if isinstance(parsed, dict): + parsed = [parsed] + outputs = [ + {"component_id": p["id"], "property": clean_property_name(p["property"])} + for p in parsed + ] + callbacks.append( + { + "outputs": outputs, + "inputs": normalize_deps(callback_info.get("inputs", [])), + "state": normalize_deps(callback_info.get("state", [])), + } + ) + + return callbacks diff --git a/dash/mcp/primitives/resources/resource_components.py b/dash/mcp/primitives/resources/resource_components.py new file mode 100644 index 0000000000..e6441d7aee --- /dev/null +++ b/dash/mcp/primitives/resources/resource_components.py @@ -0,0 +1,59 @@ +"""Component list resource.""" + +from __future__ import annotations + +import json + +from mcp.types import ( + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) + +from dash import get_app +from dash.layout import traverse + +URI = "dash://components" + + +def get_resource() -> Resource | None: + return Resource( + uri=URI, + name="dash_components", + description=( + "All components with IDs in the app layout. " + "Use get_dash_component with any of these IDs " + "to inspect their properties and values. " + "See dash://layout for the tree structure showing " + "how these components are nested in the page." + ), + mimeType="application/json", + ) + + +def get_template() -> ResourceTemplate | None: + return None + + +def read_resource(uri: str = "") -> ReadResourceResult: + app = get_app() + layout = app.get_layout() + components = sorted( + [ + {"id": str(comp.id), "type": getattr(comp, "_type", type(comp).__name__)} + for comp, _ in traverse(layout) + if getattr(comp, "id", None) is not None + ], + key=lambda c: c["id"], + ) + + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=URI, + mimeType="application/json", + text=json.dumps(components), + ) + ] + ) diff --git a/dash/mcp/primitives/resources/resource_layout.py b/dash/mcp/primitives/resources/resource_layout.py new file mode 100644 index 0000000000..01d0be046d --- /dev/null +++ b/dash/mcp/primitives/resources/resource_layout.py @@ -0,0 +1,44 @@ +"""Layout tree resource for the whole app.""" + +from __future__ import annotations + +from mcp.types import ( + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) + +from dash import get_app +from dash._utils import to_json + +URI = "dash://layout" + + +def get_resource() -> Resource | None: + return Resource( + uri=URI, + name="dash_app_layout", + description=( + "Full component tree of the Dash app. " + "See dash://components for a compact list of component IDs." + ), + mimeType="application/json", + ) + + +def get_template() -> ResourceTemplate | None: + return None + + +def read_resource(uri: str = "") -> ReadResourceResult: + app = get_app() + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=URI, + mimeType="application/json", + text=to_json(app.get_layout()), + ) + ] + ) diff --git a/dash/mcp/primitives/resources/resource_page_layout.py b/dash/mcp/primitives/resources/resource_page_layout.py new file mode 100644 index 0000000000..d82d366298 --- /dev/null +++ b/dash/mcp/primitives/resources/resource_page_layout.py @@ -0,0 +1,77 @@ +"""Per-page layout resource template for multi-page apps.""" + +from __future__ import annotations + +from mcp.types import ( + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) + +from dash._utils import to_json + +URI = "dash://page-layout/" +_URI_TEMPLATE = "dash://page-layout/{path}" + + +def get_resource() -> Resource | None: + return None + + +def get_template() -> ResourceTemplate | None: + if not _has_pages(): + return None + return ResourceTemplate( + uriTemplate=_URI_TEMPLATE, + name="dash_page_layout", + description="Component tree for a specific page in the app.", + mimeType="application/json", + ) + + +def read_resource(uri: str) -> ReadResourceResult: + path = uri[len(URI) :] + if not path.startswith("/"): + path = "/" + path + + try: + from dash._pages import PAGE_REGISTRY + except ImportError: + raise ValueError("Dash Pages is not available.") + + page_layout = None + for _module, page in PAGE_REGISTRY.items(): + if page.get("path") == path: + page_layout = page.get("layout") + break + + if page_layout is None: + raise ValueError(f"Page not found: {path}") + + if callable(page_layout): + page_layout = page_layout() + + if isinstance(page_layout, (list, tuple)): + from dash import html + + page_layout = html.Div(list(page_layout)) + + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=uri, + mimeType="application/json", + text=to_json(page_layout), + ) + ] + ) + + +def _has_pages() -> bool: + try: + from dash._pages import PAGE_REGISTRY + + return bool(PAGE_REGISTRY) + except ImportError: + return False diff --git a/dash/mcp/primitives/resources/resource_pages.py b/dash/mcp/primitives/resources/resource_pages.py new file mode 100644 index 0000000000..51a61b9f00 --- /dev/null +++ b/dash/mcp/primitives/resources/resource_pages.py @@ -0,0 +1,76 @@ +"""Pages resource for multi-page apps.""" + +from __future__ import annotations + +import json + +from mcp.types import ( + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) + +URI = "dash://pages" + + +def _has_pages() -> bool: + try: + from dash._pages import PAGE_REGISTRY + + return bool(PAGE_REGISTRY) + except ImportError: + return False + + +def get_resource() -> Resource | None: + if not _has_pages(): + return None + return Resource( + uri=URI, + name="dash_app_pages", + description=( + "List of all pages in this multi-page Dash app " + "with paths, names, titles, and descriptions." + ), + mimeType="application/json", + ) + + +def get_template() -> ResourceTemplate | None: + return None + + +def read_resource(uri: str = "") -> ReadResourceResult: + try: + from dash._pages import PAGE_REGISTRY + except ImportError: + return ReadResourceResult( + contents=[ + TextResourceContents(uri=URI, mimeType="application/json", text="[]") + ] + ) + + pages = [] + for module, page in PAGE_REGISTRY.items(): + title = page.get("title", "") + description = page.get("description", "") + pages.append( + { + "module": module, + "path": page.get("path", ""), + "name": page.get("name", ""), + "title": title if not callable(title) else page.get("name", ""), + "description": description if not callable(description) else "", + } + ) + + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=URI, + mimeType="application/json", + text=json.dumps(pages, default=str), + ) + ] + ) diff --git a/tests/unit/mcp/primitives/resources/test_resource_clientside_callbacks.py b/tests/unit/mcp/primitives/resources/test_resource_clientside_callbacks.py new file mode 100644 index 0000000000..3ba2ce7996 --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_clientside_callbacks.py @@ -0,0 +1,54 @@ +"""Tests for the dash://clientside-callbacks resource.""" + +import json + +from dash import Dash, Input, Output, clientside_callback, html + +from dash.mcp.primitives.resources import list_resources, read_resource + + +class TestClientsideCallbacksResource: + @staticmethod + def _make_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn", children="Click"), + html.Div(id="out"), + html.Div(id="server-out"), + ] + ) + + clientside_callback( + "function(n) { return n; }", + Output("out", "children"), + Input("btn", "n_clicks"), + ) + + @app.callback(Output("server-out", "children"), Input("btn", "n_clicks")) + def server_cb(n): + return str(n) + + with app.server.test_request_context(): + app._setup_server() + + return app + + def test_resource_listed(self): + app = self._make_app() + with app.server.test_request_context(): + result = list_resources() + uris = [str(r.uri) for r in result.resources] + assert "dash://clientside-callbacks" in uris + + def test_resource_read(self): + app = self._make_app() + with app.server.test_request_context(): + result = read_resource("dash://clientside-callbacks") + data = json.loads(result.contents[0].text) + assert "description" in data + callbacks = data["callbacks"] + assert len(callbacks) == 1 + assert callbacks[0]["inputs"][0]["component_id"] == "btn" + assert callbacks[0]["inputs"][0]["property"] == "n_clicks" + assert callbacks[0]["outputs"][0]["component_id"] == "out" diff --git a/tests/unit/mcp/primitives/resources/test_resource_layout.py b/tests/unit/mcp/primitives/resources/test_resource_layout.py new file mode 100644 index 0000000000..ade207b1f3 --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_layout.py @@ -0,0 +1,59 @@ +"""Tests for the dash://layout resource.""" + +import json +from unittest.mock import patch + +from dash import Dash, dcc, html + +from dash.mcp.primitives.resources import list_resources, read_resource + +EXPECTED_LAYOUT = { + "type": "Div", + "namespace": "dash_html_components", + "props": { + "children": [ + { + "type": "Dropdown", + "namespace": "dash_core_components", + "props": { + "id": "test-dd", + "options": ["a", "b"], + "value": "a", + }, + }, + { + "type": "Div", + "namespace": "dash_html_components", + "props": { + "children": None, + "id": "output", + }, + }, + ] + }, +} + + +class TestLayoutResource: + def test_listed_in_resources(self): + app = Dash(__name__) + app.layout = html.Div(id="main") + with app.server.test_request_context(): + result = list_resources() + uris = [str(r.uri) for r in result.resources] + assert "dash://layout" in uris + + def test_read_returns_layout(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="test-dd", options=["a", "b"], value="a"), + html.Div(id="output"), + ] + ) + with app.server.test_request_context(): + with patch.object(app, "get_layout", wraps=app.get_layout) as mock: + result = read_resource("dash://layout") + mock.assert_called_once() + layout = json.loads(result.contents[0].text) + assert layout == EXPECTED_LAYOUT diff --git a/tests/unit/mcp/primitives/resources/test_resource_page_layout.py b/tests/unit/mcp/primitives/resources/test_resource_page_layout.py new file mode 100644 index 0000000000..88ffd82118 --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_page_layout.py @@ -0,0 +1,52 @@ +"""Tests for the dash://page-layout/{path} resource template.""" + +import json +from unittest.mock import patch + +from dash import Dash, dcc, html + +from dash.mcp.primitives.resources import read_resource + +EXPECTED_PAGE_LAYOUT = { + "type": "Div", + "namespace": "dash_html_components", + "props": { + "children": [ + { + "type": "Dropdown", + "namespace": "dash_core_components", + "props": { + "id": "page-dd", + "options": ["a", "b"], + "value": "a", + }, + } + ] + }, +} + + +class TestPageLayoutResource: + def test_read_page_layout(self): + app = Dash(__name__) + app.layout = html.Div(id="main") + + page_layout = html.Div( + [ + dcc.Dropdown(id="page-dd", options=["a", "b"], value="a"), + ] + ) + fake_registry = { + "pages.test": { + "path": "/test", + "name": "Test", + "title": "Test Page", + "description": "", + "layout": page_layout, + }, + } + with app.server.test_request_context(): + with patch("dash._pages.PAGE_REGISTRY", fake_registry): + result = read_resource("dash://page-layout/test") + layout = json.loads(result.contents[0].text) + assert layout == EXPECTED_PAGE_LAYOUT diff --git a/tests/unit/mcp/primitives/resources/test_resource_pages.py b/tests/unit/mcp/primitives/resources/test_resource_pages.py new file mode 100644 index 0000000000..b2307d6fef --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_pages.py @@ -0,0 +1,78 @@ +"""Tests for the dash://pages resource.""" + +import json +from unittest.mock import patch + +from dash import Dash, html + +from dash.mcp.primitives.resources import list_resources, read_resource + +EXPECTED_PAGES = [ + { + "path": "/", + "name": "Home", + "title": "Home Page", + "description": "The landing page", + "module": "pages.home", + }, + { + "path": "/analytics", + "name": "Analytics", + "title": "Analytics Dashboard", + "description": "View analytics", + "module": "pages.analytics", + }, +] + + +class TestPagesResource: + @staticmethod + def _make_app(): + app = Dash(__name__) + app.layout = html.Div(id="main") + return app + + def test_listed_for_multi_page_app(self): + app = self._make_app() + fake_registry = { + "pages.home": { + "path": "/", + "name": "Home", + "title": "Home", + "description": "", + } + } + with app.server.test_request_context(): + with patch("dash._pages.PAGE_REGISTRY", fake_registry): + result = list_resources() + uris = [str(r.uri) for r in result.resources] + assert "dash://pages" in uris + + def test_returns_page_info(self): + app = self._make_app() + fake_registry = { + "pages.home": EXPECTED_PAGES[0], + "pages.analytics": EXPECTED_PAGES[1], + } + with app.server.test_request_context(): + with patch("dash._pages.PAGE_REGISTRY", fake_registry): + result = read_resource("dash://pages") + content = json.loads(result.contents[0].text) + assert content == EXPECTED_PAGES + + def test_callable_title_falls_back_to_name(self): + app = self._make_app() + fake_registry = { + "pages.dynamic": { + "path": "/item/", + "name": "Item Detail", + "title": lambda **kwargs: f"Item {kwargs.get('item_id', '')}", + "description": lambda **kwargs: f"Details for {kwargs.get('item_id', '')}", + }, + } + with app.server.test_request_context(): + with patch("dash._pages.PAGE_REGISTRY", fake_registry): + result = read_resource("dash://pages") + page = json.loads(result.contents[0].text)[0] + assert page["title"] == "Item Detail" + assert page["description"] == "" From 2f00c0e5bd403b833c05ae1f508c32102dd20543 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 6 Apr 2026 12:13:28 -0600 Subject: [PATCH 20/80] Fix import path --- dash/mcp/primitives/resources/resource_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/mcp/primitives/resources/resource_components.py b/dash/mcp/primitives/resources/resource_components.py index e6441d7aee..8cf366f95c 100644 --- a/dash/mcp/primitives/resources/resource_components.py +++ b/dash/mcp/primitives/resources/resource_components.py @@ -12,7 +12,7 @@ ) from dash import get_app -from dash.layout import traverse +from dash._layout_utils import traverse URI = "dash://components" From 5e472b2f40dac44dca4848a481b2cc97babb39d6 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 13 Apr 2026 16:43:56 -0600 Subject: [PATCH 21/80] Implement base class for each resource type --- dash/mcp/primitives/resources/__init__.py | 44 ++++--- dash/mcp/primitives/resources/base.py | 28 +++++ .../resource_clientside_callbacks.py | 81 +++++++------ .../resources/resource_components.py | 87 +++++++------- .../primitives/resources/resource_layout.py | 61 +++++----- .../resources/resource_page_layout.py | 111 ++++++++---------- .../primitives/resources/resource_pages.py | 91 ++++++-------- .../resources/test_resource_page_layout.py | 5 +- .../resources/test_resource_pages.py | 15 ++- 9 files changed, 263 insertions(+), 260 deletions(-) create mode 100644 dash/mcp/primitives/resources/base.py diff --git a/dash/mcp/primitives/resources/__init__.py b/dash/mcp/primitives/resources/__init__.py index da93feae04..a65e376e6f 100644 --- a/dash/mcp/primitives/resources/__init__.py +++ b/dash/mcp/primitives/resources/__init__.py @@ -1,13 +1,4 @@ -"""MCP resource listing and read handling. - -Each resource module exports: -- ``URI`` — the URI prefix this module handles -- ``get_resource() -> Resource | None`` -- ``get_template() -> ResourceTemplate | None`` -- ``read_resource(uri) -> ReadResourceResult`` - -Dispatch is by prefix match: more specific prefixes must come first. -""" +"""MCP resource listing and read handling.""" from __future__ import annotations @@ -17,21 +8,26 @@ ReadResourceResult, ) -from . import ( - resource_clientside_callbacks as _clientside, - resource_components as _components, - resource_layout as _layout, - resource_page_layout as _page_layout, - resource_pages as _pages, -) +from .base import MCPResourceProvider +from .resource_clientside_callbacks import ClientsideCallbacksResource +from .resource_components import ComponentsResource +from .resource_layout import LayoutResource +from .resource_page_layout import PageLayoutResource +from .resource_pages import PagesResource -_RESOURCE_MODULES = [_layout, _components, _pages, _clientside, _page_layout] +_RESOURCE_PROVIDERS: list[type[MCPResourceProvider]] = [ + LayoutResource, + ComponentsResource, + PagesResource, + ClientsideCallbacksResource, + PageLayoutResource, +] def list_resources() -> ListResourcesResult: """Build the MCP resources/list response.""" resources = [ - r for mod in _RESOURCE_MODULES for r in [mod.get_resource()] if r is not None + r for p in _RESOURCE_PROVIDERS for r in [p.get_resource()] if r is not None ] return ListResourcesResult(resources=resources) @@ -39,14 +35,14 @@ def list_resources() -> ListResourcesResult: def list_resource_templates() -> ListResourceTemplatesResult: """Build the MCP resources/templates/list response.""" templates = [ - t for mod in _RESOURCE_MODULES for t in [mod.get_template()] if t is not None + t for p in _RESOURCE_PROVIDERS for t in [p.get_template()] if t is not None ] return ListResourceTemplatesResult(resourceTemplates=templates) def read_resource(uri: str) -> ReadResourceResult: - """Dispatch a resources/read request by URI prefix match.""" - for mod in _RESOURCE_MODULES: - if uri.startswith(mod.URI): - return mod.read_resource(uri) + """Route a resources/read request by URI prefix match.""" + for p in _RESOURCE_PROVIDERS: + if uri.startswith(p.uri): + return p.read_resource(uri) raise ValueError(f"Unknown resource URI: {uri}") diff --git a/dash/mcp/primitives/resources/base.py b/dash/mcp/primitives/resources/base.py new file mode 100644 index 0000000000..e63ffc9681 --- /dev/null +++ b/dash/mcp/primitives/resources/base.py @@ -0,0 +1,28 @@ +"""Base class for MCP resource providers.""" + +from __future__ import annotations + +from mcp.types import ReadResourceResult, Resource, ResourceTemplate + + +class MCPResourceProvider: + """Base class for MCP resource providers. + + Subclasses must set ``uri`` and implement ``read_resource``. + Override ``get_resource`` and/or ``get_template`` to advertise + the resource in ``resources/list`` or ``resources/templates/list``. + """ + + uri: str + + @classmethod + def get_resource(cls) -> Resource | None: + return None + + @classmethod + def get_template(cls) -> ResourceTemplate | None: + return None + + @classmethod + def read_resource(cls, uri: str) -> ReadResourceResult: + raise NotImplementedError diff --git a/dash/mcp/primitives/resources/resource_clientside_callbacks.py b/dash/mcp/primitives/resources/resource_clientside_callbacks.py index dbc3009edb..127c0f9adc 100644 --- a/dash/mcp/primitives/resources/resource_clientside_callbacks.py +++ b/dash/mcp/primitives/resources/resource_clientside_callbacks.py @@ -8,57 +8,56 @@ from mcp.types import ( ReadResourceResult, Resource, - ResourceTemplate, TextResourceContents, ) from dash import get_app from dash._utils import clean_property_name, split_callback_id -URI = "dash://clientside-callbacks" +from .base import MCPResourceProvider -def get_resource() -> Resource | None: - if not _get_clientside_callbacks(): - return None - return Resource( - uri=URI, - name="dash_clientside_callbacks", - description=( - "Actions the user can take manually in the browser " - "to affect clientside state. Inputs describe the " - "components that can be changed to trigger an effect. " - "Outputs describe the components that will change " - "in response." - ), - mimeType="application/json", - ) - - -def get_template() -> ResourceTemplate | None: - return None +class ClientsideCallbacksResource(MCPResourceProvider): + uri = "dash://clientside-callbacks" + @classmethod + def get_resource(cls) -> Resource | None: + if not _get_clientside_callbacks(): + return None + return Resource( + uri=cls.uri, + name="dash_clientside_callbacks", + description=( + "Actions the user can take manually in the browser " + "to affect clientside state. Inputs describe the " + "components that can be changed to trigger an effect. " + "Outputs describe the components that will change " + "in response." + ), + mimeType="application/json", + ) -def read_resource(uri: str = "") -> ReadResourceResult: - data = { - "description": ( - "These are actions that the user can take manually in the " - "browser to affect the clientside state. Inputs describe " - "the components that can be changed to trigger an effect. " - "Outputs describe the components that will change in " - "response to the effect." - ), - "callbacks": _get_clientside_callbacks(), - } - return ReadResourceResult( - contents=[ - TextResourceContents( - uri=URI, - mimeType="application/json", - text=json.dumps(data, default=str), - ) - ] - ) + @classmethod + def read_resource(cls, uri: str = "") -> ReadResourceResult: + data = { + "description": ( + "These are actions that the user can take manually in the " + "browser to affect the clientside state. Inputs describe " + "the components that can be changed to trigger an effect. " + "Outputs describe the components that will change in " + "response to the effect." + ), + "callbacks": _get_clientside_callbacks(), + } + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=cls.uri, + mimeType="application/json", + text=json.dumps(data, default=str), + ) + ] + ) def _get_clientside_callbacks() -> list[dict[str, Any]]: diff --git a/dash/mcp/primitives/resources/resource_components.py b/dash/mcp/primitives/resources/resource_components.py index 8cf366f95c..8175aab72e 100644 --- a/dash/mcp/primitives/resources/resource_components.py +++ b/dash/mcp/primitives/resources/resource_components.py @@ -7,53 +7,52 @@ from mcp.types import ( ReadResourceResult, Resource, - ResourceTemplate, TextResourceContents, ) from dash import get_app from dash._layout_utils import traverse -URI = "dash://components" - - -def get_resource() -> Resource | None: - return Resource( - uri=URI, - name="dash_components", - description=( - "All components with IDs in the app layout. " - "Use get_dash_component with any of these IDs " - "to inspect their properties and values. " - "See dash://layout for the tree structure showing " - "how these components are nested in the page." - ), - mimeType="application/json", - ) - - -def get_template() -> ResourceTemplate | None: - return None - - -def read_resource(uri: str = "") -> ReadResourceResult: - app = get_app() - layout = app.get_layout() - components = sorted( - [ - {"id": str(comp.id), "type": getattr(comp, "_type", type(comp).__name__)} - for comp, _ in traverse(layout) - if getattr(comp, "id", None) is not None - ], - key=lambda c: c["id"], - ) - - return ReadResourceResult( - contents=[ - TextResourceContents( - uri=URI, - mimeType="application/json", - text=json.dumps(components), - ) - ] - ) +from .base import MCPResourceProvider + + +class ComponentsResource(MCPResourceProvider): + uri = "dash://components" + + @classmethod + def get_resource(cls) -> Resource | None: + return Resource( + uri=cls.uri, + name="dash_components", + description=( + "All components with IDs in the app layout. " + "Use get_dash_component with any of these IDs " + "to inspect their properties and values. " + "See dash://layout for the tree structure showing " + "how these components are nested in the page." + ), + mimeType="application/json", + ) + + @classmethod + def read_resource(cls, uri: str = "") -> ReadResourceResult: + app = get_app() + layout = app.get_layout() + components = sorted( + [ + {"id": str(comp.id), "type": getattr(comp, "_type", type(comp).__name__)} + for comp, _ in traverse(layout) + if getattr(comp, "id", None) is not None + ], + key=lambda c: c["id"], + ) + + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=cls.uri, + mimeType="application/json", + text=json.dumps(components), + ) + ] + ) diff --git a/dash/mcp/primitives/resources/resource_layout.py b/dash/mcp/primitives/resources/resource_layout.py index 01d0be046d..753e2b9229 100644 --- a/dash/mcp/primitives/resources/resource_layout.py +++ b/dash/mcp/primitives/resources/resource_layout.py @@ -5,40 +5,39 @@ from mcp.types import ( ReadResourceResult, Resource, - ResourceTemplate, TextResourceContents, ) from dash import get_app from dash._utils import to_json -URI = "dash://layout" - - -def get_resource() -> Resource | None: - return Resource( - uri=URI, - name="dash_app_layout", - description=( - "Full component tree of the Dash app. " - "See dash://components for a compact list of component IDs." - ), - mimeType="application/json", - ) - - -def get_template() -> ResourceTemplate | None: - return None - - -def read_resource(uri: str = "") -> ReadResourceResult: - app = get_app() - return ReadResourceResult( - contents=[ - TextResourceContents( - uri=URI, - mimeType="application/json", - text=to_json(app.get_layout()), - ) - ] - ) +from .base import MCPResourceProvider + + +class LayoutResource(MCPResourceProvider): + uri = "dash://layout" + + @classmethod + def get_resource(cls) -> Resource | None: + return Resource( + uri=cls.uri, + name="dash_app_layout", + description=( + "Full component tree of the Dash app. " + "See dash://components for a compact list of component IDs." + ), + mimeType="application/json", + ) + + @classmethod + def read_resource(cls, uri: str = "") -> ReadResourceResult: + app = get_app() + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=cls.uri, + mimeType="application/json", + text=to_json(app.get_layout()), + ) + ] + ) diff --git a/dash/mcp/primitives/resources/resource_page_layout.py b/dash/mcp/primitives/resources/resource_page_layout.py index d82d366298..02f322be25 100644 --- a/dash/mcp/primitives/resources/resource_page_layout.py +++ b/dash/mcp/primitives/resources/resource_page_layout.py @@ -4,74 +4,61 @@ from mcp.types import ( ReadResourceResult, - Resource, ResourceTemplate, TextResourceContents, ) +from dash._pages import PAGE_REGISTRY from dash._utils import to_json -URI = "dash://page-layout/" -_URI_TEMPLATE = "dash://page-layout/{path}" - - -def get_resource() -> Resource | None: - return None - - -def get_template() -> ResourceTemplate | None: - if not _has_pages(): - return None - return ResourceTemplate( - uriTemplate=_URI_TEMPLATE, - name="dash_page_layout", - description="Component tree for a specific page in the app.", - mimeType="application/json", - ) - - -def read_resource(uri: str) -> ReadResourceResult: - path = uri[len(URI) :] - if not path.startswith("/"): - path = "/" + path - - try: - from dash._pages import PAGE_REGISTRY - except ImportError: - raise ValueError("Dash Pages is not available.") - - page_layout = None - for _module, page in PAGE_REGISTRY.items(): - if page.get("path") == path: - page_layout = page.get("layout") - break - - if page_layout is None: - raise ValueError(f"Page not found: {path}") - - if callable(page_layout): - page_layout = page_layout() - - if isinstance(page_layout, (list, tuple)): - from dash import html - - page_layout = html.Div(list(page_layout)) - - return ReadResourceResult( - contents=[ - TextResourceContents( - uri=uri, - mimeType="application/json", - text=to_json(page_layout), - ) - ] - ) +from .base import MCPResourceProvider +_URI_TEMPLATE = "dash://page-layout/{path}" -def _has_pages() -> bool: - try: - from dash._pages import PAGE_REGISTRY - return bool(PAGE_REGISTRY) - except ImportError: - return False +class PageLayoutResource(MCPResourceProvider): + uri = "dash://page-layout/" + + @classmethod + def get_template(cls) -> ResourceTemplate | None: + if not PAGE_REGISTRY: + return None + return ResourceTemplate( + uriTemplate=_URI_TEMPLATE, + name="dash_page_layout", + description="Component tree for a specific page in the app.", + mimeType="application/json", + ) + + @classmethod + def read_resource(cls, uri: str) -> ReadResourceResult: + path = uri[len(cls.uri):] + if not path.startswith("/"): + path = "/" + path + + page_layout = None + for _module, page in PAGE_REGISTRY.items(): + if page.get("path") == path: + page_layout = page.get("layout") + break + + if page_layout is None: + raise ValueError(f"Page not found: {path}") + + if callable(page_layout): + page_layout = page_layout() + + if isinstance(page_layout, (list, tuple)): + from dash import html + + page_layout = html.Div(list(page_layout)) + + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=uri, + mimeType="application/json", + text=to_json(page_layout), + ) + ] + ) diff --git a/dash/mcp/primitives/resources/resource_pages.py b/dash/mcp/primitives/resources/resource_pages.py index 51a61b9f00..27c39013f3 100644 --- a/dash/mcp/primitives/resources/resource_pages.py +++ b/dash/mcp/primitives/resources/resource_pages.py @@ -7,70 +7,53 @@ from mcp.types import ( ReadResourceResult, Resource, - ResourceTemplate, TextResourceContents, ) -URI = "dash://pages" +from dash._pages import PAGE_REGISTRY +from .base import MCPResourceProvider -def _has_pages() -> bool: - try: - from dash._pages import PAGE_REGISTRY - return bool(PAGE_REGISTRY) - except ImportError: - return False +class PagesResource(MCPResourceProvider): + uri = "dash://pages" + @classmethod + def get_resource(cls) -> Resource | None: + if not PAGE_REGISTRY: + return None + return Resource( + uri=cls.uri, + name="dash_app_pages", + description=( + "List of all pages in this multi-page Dash app " + "with paths, names, titles, and descriptions." + ), + mimeType="application/json", + ) -def get_resource() -> Resource | None: - if not _has_pages(): - return None - return Resource( - uri=URI, - name="dash_app_pages", - description=( - "List of all pages in this multi-page Dash app " - "with paths, names, titles, and descriptions." - ), - mimeType="application/json", - ) - - -def get_template() -> ResourceTemplate | None: - return None - + @classmethod + def read_resource(cls, uri: str = "") -> ReadResourceResult: + pages = [] + for module, page in PAGE_REGISTRY.items(): + title = page.get("title", "") + description = page.get("description", "") + pages.append( + { + "module": module, + "path": page.get("path", ""), + "name": page.get("name", ""), + "title": title if not callable(title) else page.get("name", ""), + "description": description if not callable(description) else "", + } + ) -def read_resource(uri: str = "") -> ReadResourceResult: - try: - from dash._pages import PAGE_REGISTRY - except ImportError: return ReadResourceResult( contents=[ - TextResourceContents(uri=URI, mimeType="application/json", text="[]") + TextResourceContents( + uri=cls.uri, + mimeType="application/json", + text=json.dumps(pages, default=str), + ) ] ) - - pages = [] - for module, page in PAGE_REGISTRY.items(): - title = page.get("title", "") - description = page.get("description", "") - pages.append( - { - "module": module, - "path": page.get("path", ""), - "name": page.get("name", ""), - "title": title if not callable(title) else page.get("name", ""), - "description": description if not callable(description) else "", - } - ) - - return ReadResourceResult( - contents=[ - TextResourceContents( - uri=URI, - mimeType="application/json", - text=json.dumps(pages, default=str), - ) - ] - ) diff --git a/tests/unit/mcp/primitives/resources/test_resource_page_layout.py b/tests/unit/mcp/primitives/resources/test_resource_page_layout.py index 88ffd82118..f4e9caac5d 100644 --- a/tests/unit/mcp/primitives/resources/test_resource_page_layout.py +++ b/tests/unit/mcp/primitives/resources/test_resource_page_layout.py @@ -46,7 +46,10 @@ def test_read_page_layout(self): }, } with app.server.test_request_context(): - with patch("dash._pages.PAGE_REGISTRY", fake_registry): + with patch( + "dash.mcp.primitives.resources.resource_page_layout.PAGE_REGISTRY", + fake_registry, + ): result = read_resource("dash://page-layout/test") layout = json.loads(result.contents[0].text) assert layout == EXPECTED_PAGE_LAYOUT diff --git a/tests/unit/mcp/primitives/resources/test_resource_pages.py b/tests/unit/mcp/primitives/resources/test_resource_pages.py index b2307d6fef..22e6e798fc 100644 --- a/tests/unit/mcp/primitives/resources/test_resource_pages.py +++ b/tests/unit/mcp/primitives/resources/test_resource_pages.py @@ -43,7 +43,10 @@ def test_listed_for_multi_page_app(self): } } with app.server.test_request_context(): - with patch("dash._pages.PAGE_REGISTRY", fake_registry): + with patch( + "dash.mcp.primitives.resources.resource_pages.PAGE_REGISTRY", + fake_registry, + ): result = list_resources() uris = [str(r.uri) for r in result.resources] assert "dash://pages" in uris @@ -55,7 +58,10 @@ def test_returns_page_info(self): "pages.analytics": EXPECTED_PAGES[1], } with app.server.test_request_context(): - with patch("dash._pages.PAGE_REGISTRY", fake_registry): + with patch( + "dash.mcp.primitives.resources.resource_pages.PAGE_REGISTRY", + fake_registry, + ): result = read_resource("dash://pages") content = json.loads(result.contents[0].text) assert content == EXPECTED_PAGES @@ -71,7 +77,10 @@ def test_callable_title_falls_back_to_name(self): }, } with app.server.test_request_context(): - with patch("dash._pages.PAGE_REGISTRY", fake_registry): + with patch( + "dash.mcp.primitives.resources.resource_pages.PAGE_REGISTRY", + fake_registry, + ): result = read_resource("dash://pages") page = json.loads(result.contents[0].text)[0] assert page["title"] == "Item Detail" From f8d9739d3298345743d0cab66db84abcb3d04802 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 16 Apr 2026 11:51:20 -0600 Subject: [PATCH 22/80] lint --- dash/mcp/primitives/resources/resource_components.py | 5 ++++- dash/mcp/primitives/resources/resource_page_layout.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/dash/mcp/primitives/resources/resource_components.py b/dash/mcp/primitives/resources/resource_components.py index 8175aab72e..9d035a855f 100644 --- a/dash/mcp/primitives/resources/resource_components.py +++ b/dash/mcp/primitives/resources/resource_components.py @@ -40,7 +40,10 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: layout = app.get_layout() components = sorted( [ - {"id": str(comp.id), "type": getattr(comp, "_type", type(comp).__name__)} + { + "id": str(comp.id), + "type": getattr(comp, "_type", type(comp).__name__), + } for comp, _ in traverse(layout) if getattr(comp, "id", None) is not None ], diff --git a/dash/mcp/primitives/resources/resource_page_layout.py b/dash/mcp/primitives/resources/resource_page_layout.py index 02f322be25..c1218a57d0 100644 --- a/dash/mcp/primitives/resources/resource_page_layout.py +++ b/dash/mcp/primitives/resources/resource_page_layout.py @@ -32,7 +32,7 @@ def get_template(cls) -> ResourceTemplate | None: @classmethod def read_resource(cls, uri: str) -> ReadResourceResult: - path = uri[len(cls.uri):] + path = uri[len(cls.uri) :] if not path.startswith("/"): path = "/" + path From 6d4e2a0c2badc3b8c06f17ae675a57366f6d8358 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 22 Apr 2026 11:22:22 -0600 Subject: [PATCH 23/80] Fix pylint errors --- dash/mcp/__init__.py | 0 dash/mcp/primitives/__init__.py | 0 .../resources/resource_page_layout.py | 3 +- dash/mcp/primitives/tools/__init__.py | 0 dash/mcp/primitives/tools/callback_adapter.py | 8 ++--- .../tools/callback_adapter_collection.py | 29 +++++++++++-------- dash/mcp/primitives/tools/callback_utils.py | 3 +- .../primitives/tools/descriptions/__init__.py | 2 +- .../tools/input_schemas/__init__.py | 2 +- dash/mcp/types/exceptions.py | 2 -- dash/mcp/types/typing_utils.py | 2 +- 11 files changed, 26 insertions(+), 25 deletions(-) create mode 100644 dash/mcp/__init__.py create mode 100644 dash/mcp/primitives/__init__.py create mode 100644 dash/mcp/primitives/tools/__init__.py diff --git a/dash/mcp/__init__.py b/dash/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dash/mcp/primitives/__init__.py b/dash/mcp/primitives/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dash/mcp/primitives/resources/resource_page_layout.py b/dash/mcp/primitives/resources/resource_page_layout.py index c1218a57d0..613f0b41b9 100644 --- a/dash/mcp/primitives/resources/resource_page_layout.py +++ b/dash/mcp/primitives/resources/resource_page_layout.py @@ -8,6 +8,7 @@ TextResourceContents, ) +from dash import html from dash._pages import PAGE_REGISTRY from dash._utils import to_json @@ -49,8 +50,6 @@ def read_resource(cls, uri: str) -> ReadResourceResult: page_layout = page_layout() if isinstance(page_layout, (list, tuple)): - from dash import html - page_layout = html.Div(list(page_layout)) return ReadResourceResult( diff --git a/dash/mcp/primitives/tools/__init__.py b/dash/mcp/primitives/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index 9dcb8d959e..1ed30cdad8 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -126,11 +126,11 @@ def output_id(self) -> str: @property def tool_name(self) -> str: - return get_app().mcp_callback_map._tool_names_map[self._output_id] + return get_app().mcp_callback_map._tool_names_map[self._output_id] # pylint: disable=protected-access @cached_property def prevents_initial_call(self) -> bool: - for cb in get_app()._callback_list: + for cb in get_app()._callback_list: # pylint: disable=protected-access if cb["output"] == self._output_id: return cb.get("prevent_initial_call", False) return False @@ -191,7 +191,7 @@ def _initial_output(self) -> dict[str, CallbackOutput]: try: result = run_callback(self, kwargs) return result.get("response", {}) - except Exception: + except Exception: # pylint: disable=broad-exception-caught return {} def initial_output_value(self, id_and_prop: str) -> Any: @@ -346,7 +346,7 @@ def _param_annotations(self) -> list[Any | None]: return [None] * len(self._dep_param_map) try: hints = typing.get_type_hints(self._original_func) - except Exception: + except Exception: # pylint: disable=broad-exception-caught hints = getattr(self._original_func, "__annotations__", {}) return [hints.get(func_name) for func_name, _ in self._dep_param_map] diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py index 59c1a7ac47..8e5769124b 100644 --- a/dash/mcp/primitives/tools/callback_adapter_collection.py +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -109,10 +109,9 @@ def get_initial_value(self, id_and_prop: str) -> Any: upstream_cb = self.find_by_output(id_and_prop) if upstream_cb is not None: return upstream_cb.initial_output_value(id_and_prop) - else: - component_id, prop = id_and_prop.rsplit(".", 1) - layout_component = find_component(component_id) - return getattr(layout_component, prop, None) + component_id, prop = id_and_prop.rsplit(".", 1) + layout_component = find_component(component_id) + return getattr(layout_component, prop, None) def as_mcp_tools(self) -> list[Tool]: """Stub — will be implemented in a future PR.""" @@ -142,13 +141,19 @@ def component_label_map(self) -> dict[str, list[str]]: comp_id = getattr(comp, "id", None) if comp_id is not None: - for ancestor in reversed(ancestors): - if getattr(ancestor, "_type", None) == "Label": - text = extract_text(ancestor) - if text: - sid = str(comp_id) - if text not in labels.get(sid, []): - labels.setdefault(sid, []).append(text) - break + self._add_ancestor_label(comp_id, ancestors, labels) return labels + + @staticmethod + def _add_ancestor_label(comp_id, ancestors, labels: dict[str, list[str]]) -> None: + """Record the text of the nearest Label ancestor for ``comp_id``, if any.""" + for ancestor in reversed(ancestors): + if getattr(ancestor, "_type", None) != "Label": + continue + text = extract_text(ancestor) + if text: + sid = str(comp_id) + if text not in labels.get(sid, []): + labels.setdefault(sid, []).append(text) + return diff --git a/dash/mcp/primitives/tools/callback_utils.py b/dash/mcp/primitives/tools/callback_utils.py index 0ff4f5f578..361b27fc32 100644 --- a/dash/mcp/primitives/tools/callback_utils.py +++ b/dash/mcp/primitives/tools/callback_utils.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any from dash import get_app +from dash.mcp.types import CallbackExecutionError from dash.types import CallbackExecutionResponse if TYPE_CHECKING: @@ -16,8 +17,6 @@ def run_callback( callback: CallbackAdapter, kwargs: dict[str, Any] ) -> CallbackExecutionResponse: """Execute a callback via the framework.""" - from dash.mcp.types import CallbackExecutionError - body = callback.as_callback_body(kwargs) app = get_app() diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py index 67ec78c9ff..b3c0dd3527 100644 --- a/dash/mcp/primitives/tools/descriptions/__init__.py +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -1,7 +1,7 @@ """Stub — real implementation in a later PR.""" -def build_tool_description(outputs, docstring=None): +def build_tool_description(outputs, docstring=None): # pylint: disable=unused-argument if docstring: return docstring.strip() return "Dash callback" diff --git a/dash/mcp/primitives/tools/input_schemas/__init__.py b/dash/mcp/primitives/tools/input_schemas/__init__.py index f306042a0c..968363ff99 100644 --- a/dash/mcp/primitives/tools/input_schemas/__init__.py +++ b/dash/mcp/primitives/tools/input_schemas/__init__.py @@ -1,5 +1,5 @@ """Stub — real implementation in a later PR.""" -def get_input_schema(param): +def get_input_schema(param): # pylint: disable=unused-argument return {} diff --git a/dash/mcp/types/exceptions.py b/dash/mcp/types/exceptions.py index 4b860f4bba..578fa51cc9 100644 --- a/dash/mcp/types/exceptions.py +++ b/dash/mcp/types/exceptions.py @@ -26,5 +26,3 @@ class InvalidParamsError(MCPError): class CallbackExecutionError(MCPError): """Callback raised an exception during execution.""" - - pass diff --git a/dash/mcp/types/typing_utils.py b/dash/mcp/types/typing_utils.py index 9a96d4135d..e685f5808b 100644 --- a/dash/mcp/types/typing_utils.py +++ b/dash/mcp/types/typing_utils.py @@ -14,7 +14,7 @@ def is_nullable(annotation: Any) -> bool: _is_union = origin is typing.Union if not _is_union: try: - import types as _types + import types as _types # pylint: disable=import-outside-toplevel if isinstance(annotation, _types.UnionType): _is_union = True From ef1ea3fbb5bf383a4ac3d85f96124bb3a9b83671 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 13:05:07 -0600 Subject: [PATCH 24/80] Implement callbacks as tools with rich input/output schema and description generation --- dash/mcp/primitives/tools/callback_adapter.py | 13 +- .../tools/callback_adapter_collection.py | 3 +- .../primitives/tools/descriptions/__init__.py | 35 +- .../descriptions/description_docstring.py | 15 + .../tools/descriptions/description_outputs.py | 56 +++ .../tools/input_schemas/__init__.py | 44 +- .../input_descriptions/__init__.py | 31 ++ .../description_component_props.py | 81 ++++ .../description_docstrings.py | 71 +++ .../description_html_labels.py | 23 + .../schema_callback_type_annotations.py | 67 +++ .../schema_component_proptypes.py | 32 ++ .../schema_component_proptypes_overrides.py | 70 +++ .../tools/output_schemas/__init__.py | 28 +- .../schema_callback_response.py | 16 + tests/unit/mcp/conftest.py | 81 ++++ .../unit/mcp/tools/input_schemas/__init__.py | 0 .../input_descriptions/__init__.py | 0 .../input_descriptions/test_descriptions.py | 424 ++++++++++++++++++ .../tools/input_schemas/test_input_schemas.py | 331 ++++++++++++++ .../test_schema_component_proptypes.py | 15 + tests/unit/mcp/tools/test_callback_adapter.py | 167 ++++++- tests/unit/mcp/tools/test_tool_schema.py | 64 +++ 23 files changed, 1652 insertions(+), 15 deletions(-) create mode 100644 dash/mcp/primitives/tools/descriptions/description_docstring.py create mode 100644 dash/mcp/primitives/tools/descriptions/description_outputs.py create mode 100644 dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py create mode 100644 dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py create mode 100644 dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py create mode 100644 dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py create mode 100644 dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py create mode 100644 dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py create mode 100644 dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py create mode 100644 dash/mcp/primitives/tools/output_schemas/schema_callback_response.py create mode 100644 tests/unit/mcp/tools/input_schemas/__init__.py create mode 100644 tests/unit/mcp/tools/input_schemas/input_descriptions/__init__.py create mode 100644 tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py create mode 100644 tests/unit/mcp/tools/input_schemas/test_input_schemas.py create mode 100644 tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py create mode 100644 tests/unit/mcp/tools/test_tool_schema.py diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index 1ed30cdad8..c7044de3f4 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -50,8 +50,17 @@ def __init__(self, callback_output_id: str): @cached_property def as_mcp_tool(self) -> Tool: - """Stub — will be implemented in a future PR.""" - raise NotImplementedError("as_mcp_tool will be implemented in a future PR.") + """Transforms the internal Dash callback to a structured MCP tool. + + This tool can be serialized for LLM consumption or used internally for + its computed data. + """ + return Tool( + name=self.tool_name, + description=self._description, + inputSchema=self._input_schema, + outputSchema=self._output_schema, + ) def as_callback_body(self, kwargs: dict[str, Any]) -> CallbackExecutionBody: """Transforms the given kwargs to a dict suitable for calling this callback. diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py index 8e5769124b..0304394f63 100644 --- a/dash/mcp/primitives/tools/callback_adapter_collection.py +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -114,8 +114,7 @@ def get_initial_value(self, id_and_prop: str) -> Any: return getattr(layout_component, prop, None) def as_mcp_tools(self) -> list[Tool]: - """Stub — will be implemented in a future PR.""" - raise NotImplementedError("as_mcp_tools will be implemented in a future PR.") + return [cb.as_mcp_tool for cb in self._callbacks if cb.is_valid] @property def tool_names(self) -> set[str]: diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py index b3c0dd3527..d464677251 100644 --- a/dash/mcp/primitives/tools/descriptions/__init__.py +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -1,7 +1,32 @@ -"""Stub — real implementation in a later PR.""" +"""Tool-level description generation for MCP tools. +Each source shares the same signature: +``(outputs, docstring) -> list[str]`` -def build_tool_description(outputs, docstring=None): # pylint: disable=unused-argument - if docstring: - return docstring.strip() - return "Dash callback" +This is distinct from per-parameter descriptions +(in ``input_schemas/input_descriptions/``) which populate +``inputSchema.properties.{param}.description``. +""" + +from __future__ import annotations + +from typing import Any + +from .description_docstring import callback_docstring +from .description_outputs import output_summary + +_SOURCES = [ + output_summary, + callback_docstring, +] + + +def build_tool_description( + outputs: list[dict[str, Any]], + docstring: str | None = None, +) -> str: + """Build a human-readable description for an MCP tool.""" + lines: list[str] = [] + for source in _SOURCES: + lines.extend(source(outputs, docstring)) + return "\n".join(lines) if lines else "Dash callback" diff --git a/dash/mcp/primitives/tools/descriptions/description_docstring.py b/dash/mcp/primitives/tools/descriptions/description_docstring.py new file mode 100644 index 0000000000..71cf4d3d5a --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/description_docstring.py @@ -0,0 +1,15 @@ +"""Callback docstring for tool descriptions.""" + +from __future__ import annotations + +from typing import Any + + +def callback_docstring( + outputs: list[dict[str, Any]], + docstring: str | None = None, +) -> list[str]: + """Return the callback's docstring as description lines.""" + if docstring: + return ["", docstring.strip()] + return [] diff --git a/dash/mcp/primitives/tools/descriptions/description_outputs.py b/dash/mcp/primitives/tools/descriptions/description_outputs.py new file mode 100644 index 0000000000..c174c177f4 --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/description_outputs.py @@ -0,0 +1,56 @@ +"""Output summary for tool descriptions.""" + +from __future__ import annotations + +from typing import Any + +_OUTPUT_SEMANTICS: dict[tuple[str | None, str], str] = { + ("Graph", "figure"): "Returns chart/visualization data", + ("DataTable", "data"): "Returns tabular data", + ("DataTable", "columns"): "Returns table column definitions", + ("Dropdown", "options"): "Returns selection options", + ("Dropdown", "value"): "Updates a selection value", + ("RadioItems", "options"): "Returns selection options", + ("Checklist", "options"): "Returns selection options", + ("Store", "data"): "Returns stored data", + ("Download", "data"): "Returns downloadable content", + ("Markdown", "children"): "Returns formatted text", + (None, "figure"): "Returns chart/visualization data", + (None, "data"): "Returns data", + (None, "options"): "Returns selection options", + (None, "columns"): "Returns column definitions", + (None, "children"): "Returns content", + (None, "value"): "Returns a value", + (None, "style"): "Updates styling", + (None, "disabled"): "Updates enabled/disabled state", +} + + +def output_summary( + outputs: list[dict[str, Any]], + docstring: str | None = None, +) -> list[str]: + """Produce a short summary of what the callback outputs represent.""" + if not outputs: + return ["Dash callback"] + + lines: list[str] = [] + for out in outputs: + comp_id = out["component_id"] + prop = out["property"] + comp_type = out.get("component_type") + + semantic = _OUTPUT_SEMANTICS.get((comp_type, prop)) + if semantic is None: + semantic = _OUTPUT_SEMANTICS.get((None, prop)) + + if semantic is not None: + lines.append(f"- {comp_id}.{prop}: {semantic}") + else: + lines.append(f"- {comp_id}.{prop}") + + n = len(outputs) + if n == 1: + return [lines[0].lstrip("- ")] + header = f"Returns {n} output{'s' if n > 1 else ''}:" + return [header] + lines diff --git a/dash/mcp/primitives/tools/input_schemas/__init__.py b/dash/mcp/primitives/tools/input_schemas/__init__.py index 968363ff99..2c1646f56a 100644 --- a/dash/mcp/primitives/tools/input_schemas/__init__.py +++ b/dash/mcp/primitives/tools/input_schemas/__init__.py @@ -1,5 +1,43 @@ -"""Stub — real implementation in a later PR.""" +"""Input schema generation for MCP tool inputSchema fields. +Mirrors ``output_schemas/`` which generates ``outputSchema``. -def get_input_schema(param): # pylint: disable=unused-argument - return {} +Each source is tried in priority order. All share the same signature: +``(param: MCPInput) -> dict | None``. +""" + +from __future__ import annotations + +from typing import Any + +from dash.mcp.types import MCPInput +from .schema_callback_type_annotations import annotation_to_schema +from .schema_component_proptypes_overrides import get_override_schema +from .schema_component_proptypes import get_component_prop_schema +from .input_descriptions import get_property_description + +_SOURCES = [ + annotation_to_schema, + get_override_schema, + get_component_prop_schema, +] + + +def get_input_schema(param: MCPInput) -> dict[str, Any]: + """Return the complete JSON Schema for a callback input parameter. + + Type sources provide ``type``/``enum`` (first non-None wins). + Description is assembled by ``input_descriptions``. + """ + schema: dict[str, Any] = {} + for source in _SOURCES: + result = source(param) + if result is not None: + schema = result + break + + description = get_property_description(param) + if description: + schema = {**schema, "description": description} + + return schema diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py new file mode 100644 index 0000000000..e1d1e9f47c --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py @@ -0,0 +1,31 @@ +"""Per-property description generation for MCP tool input parameters. + +Each source shares the same signature: +``(param: MCPInput) -> list[str]`` + +Sources are tried in order from most generic to most instance-specific. +All sources that produce lines are combined. +""" + +from __future__ import annotations + +from dash.mcp.types import MCPInput +from .description_component_props import component_props_description +from .description_docstrings import docstring_prop_description +from .description_html_labels import label_description + +_SOURCES = [ + docstring_prop_description, + label_description, + component_props_description, +] + + +def get_property_description(param: MCPInput) -> str | None: + """Build a complete description string for a callback input parameter.""" + lines: list[str] = [] + if not param.get("required", True): + lines.append("Input is optional.") + for source in _SOURCES: + lines.extend(source(param)) + return "\n".join(lines) if lines else None diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py new file mode 100644 index 0000000000..6934918260 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py @@ -0,0 +1,81 @@ +"""Generic component property descriptions. + +Generate a description for each component prop that has a value (either set +directly in the layout or by an upstream callback). +""" + +from __future__ import annotations + +from typing import Any + +from dash import get_app +from dash.mcp.types import MCPInput + +_MAX_VALUE_LENGTH = 200 + +_MCP_EXCLUDED_PROPS = {"id", "className", "style"} + +_PROP_TEMPLATES: dict[tuple[str | None, str], str] = { + ("Store", "storage_type"): ( + "storage_type: {value}. Describes how to store the value client-side" + "'memory' resets on page refresh. " + "'session' persists for the duration of this session. " + "'local' persists on disk until explicitly cleared." + ), +} + + +def component_props_description(param: MCPInput) -> list[str]: + component = param.get("component") + if component is None: + return [] + + component_id = param["component_id"] + cbmap = get_app().mcp_callback_map + prop_lines: list[str] = [] + + for prop_name in getattr(component, "_prop_names", []): + if prop_name in _MCP_EXCLUDED_PROPS: + continue + + upstream = cbmap.find_by_output(f"{component_id}.{prop_name}") + if upstream is not None and not upstream.prevents_initial_call: + value = upstream.initial_output_value(f"{component_id}.{prop_name}") + else: + value = getattr(component, prop_name, None) + tool_name = upstream.tool_name if upstream is not None else None + + if value is None and tool_name is None: + continue + + component_type = param.get("component_type") + template = _PROP_TEMPLATES.get((component_type, prop_name)) + formatted_value = ( + _truncate_large_values(value, component_id, prop_name) + if value is not None + else None + ) + + if template and formatted_value is not None: + line = template.format(value=formatted_value) + elif formatted_value is not None: + line = f"{prop_name}: {formatted_value}" + else: + line = prop_name + + if tool_name: + line += f" (can be updated by tool: `{tool_name}`)" + + prop_lines.append(line) + + if not prop_lines: + return [] + return [f"Component properties for {component_id}:"] + prop_lines + + +def _truncate_large_values(value: Any, component_id: str, prop_name: str) -> str: + text = repr(value) + if len(text) > _MAX_VALUE_LENGTH: + hint = f"Use get_dash_component('{component_id}', '{prop_name}') for the full value" + return f"{text[:_MAX_VALUE_LENGTH]}... ({hint})" + return text diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py new file mode 100644 index 0000000000..1f67c3c0f2 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py @@ -0,0 +1,71 @@ +"""Extract property descriptions from component class docstrings. + +Dash component classes have structured docstrings generated by +``dash-generate-components`` in the format:: + + Keyword arguments: + + - prop_name (type_string; optional): + Description text that may span + multiple lines. + +This module parses that format and returns the first sentence of the +description for a given property. +""" + +from __future__ import annotations + +import re + +from dash.mcp.types import MCPInput + +_PROP_RE = re.compile( + r"^[ ]*- (\w+) \([^)]+\):\s*\n((?:[ ]+.+\n)*)", + re.MULTILINE, +) + +_cache: dict[type, dict[str, str]] = {} + +_SENTENCE_END = re.compile(r"(?<=[.!?])\s") + + +def docstring_prop_description(param: MCPInput) -> list[str]: + component = param.get("component") + if component is None: + return [] + desc = _get_prop_description(type(component), param["property"]) + return [desc] if desc else [] + + +def _get_prop_description(cls: type, prop: str) -> str | None: + props = _parse_docstring(cls) + return props.get(prop) + + +def _parse_docstring(cls: type) -> dict[str, str]: + if cls in _cache: + return _cache[cls] + + doc = getattr(cls, "__doc__", None) + if not doc: + _cache[cls] = {} + return _cache[cls] + + props: dict[str, str] = {} + for match in _PROP_RE.finditer(doc): + prop_name = match.group(1) + raw_desc = match.group(2) + lines = [line.strip() for line in raw_desc.strip().splitlines()] + desc = " ".join(lines) + if desc: + props[prop_name] = _first_sentence(desc) + + _cache[cls] = props + return props + + +def _first_sentence(text: str) -> str: + m = _SENTENCE_END.search(text) + if m: + return text[: m.start() + 1].rstrip() + return text diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py new file mode 100644 index 0000000000..2c9cd8dea9 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py @@ -0,0 +1,23 @@ +"""Label-based property descriptions. + +Reads the label map from the ``CallbackAdapterCollection``, +which builds it from the layout using ``htmlFor`` and +containment associations. +""" + +from __future__ import annotations + +from dash import get_app +from dash.mcp.types import MCPInput + + +def label_description(param: MCPInput) -> list[str]: + """Return the label text for this component, if any.""" + component_id = param.get("component_id") + if not component_id: + return [] + label_map = get_app().mcp_callback_map.component_label_map + texts = label_map.get(component_id, []) + if texts: + return [f"Labeled with: {'; '.join(texts)}"] + return [] diff --git a/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py b/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py new file mode 100644 index 0000000000..aee5b17c6f --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py @@ -0,0 +1,67 @@ +"""Map callback function type annotations to JSON Schema. + +When a callback function has explicit type annotations, those take +priority over all other schema sources (static overrides, component +introspection). + +Unlike component annotations (where nullable means "not required"), +callback annotations preserve ``null`` in the schema type when the +user writes ``Optional[X]`` — the user is explicitly saying the +value can be null. + +Also provides ``annotation_to_json_schema``, the shared low-level +converter used by both callback and component annotation pipelines. +""" + +from __future__ import annotations + +import inspect +from typing import Any + +from pydantic import TypeAdapter + +from dash.development.base_component import Component +from dash.mcp.types import MCPInput, is_nullable + + +def annotation_to_json_schema(annotation: type) -> dict[str, Any] | None: + """Convert a Python type annotation to a JSON Schema dict. + + Returns ``None`` if the annotation cannot be translated. + """ + if annotation is inspect.Parameter.empty or annotation is type(None): + return None + + if isinstance(annotation, type) and issubclass(annotation, Component): + return {"type": "string"} + + try: + return TypeAdapter(annotation).json_schema() + except Exception: + return None + + +def annotation_to_schema(param: MCPInput) -> dict[str, Any] | None: + """Convert a callback parameter's type annotation to a JSON Schema dict. + + Returns ``None`` if the annotation is not recognised, meaning the + caller should fall through to the next schema source. + + ``Optional[X]`` produces ``{"type": ["X", "null"]}`` — the user + explicitly chose a nullable type. + """ + annotation = param.get("annotation") + if annotation is None: + return None + schema = annotation_to_json_schema(annotation) + if schema is None: + return None + + if is_nullable(annotation) and schema: + t = schema.get("type") + if isinstance(t, str): + schema = {**schema, "type": [t, "null"]} + elif isinstance(t, list) and "null" not in t: + schema = {**schema, "type": [*t, "null"]} + + return schema diff --git a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py new file mode 100644 index 0000000000..151e391cf4 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py @@ -0,0 +1,32 @@ +"""Derive JSON Schema from a component's ``__init__`` type annotations.""" + +from __future__ import annotations + +import inspect +from typing import Any + +from dash.mcp.types import MCPInput +from .schema_callback_type_annotations import annotation_to_json_schema + + +def get_component_prop_schema(param: MCPInput) -> dict[str, Any] | None: + """Return the JSON Schema for a component property. + + Inspects the ``__init__`` signature of the component's class. + Returns ``None`` if the prop has no annotation. + """ + component = param.get("component") + prop = param["property"] + if component is None: + return None + + try: + sig = inspect.signature(type(component).__init__) + except (ValueError, TypeError): + return None + + sig_param = sig.parameters.get(prop) + if sig_param is None or sig_param.annotation is inspect.Parameter.empty: + return None + + return annotation_to_json_schema(sig_param.annotation) diff --git a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py new file mode 100644 index 0000000000..25086896e7 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py @@ -0,0 +1,70 @@ +"""A place to manually define Schemas that override component-defined prop types +where type generation produces insufficient results. +""" + +from __future__ import annotations + +from typing import Any + +from dash.mcp.types import MCPInput +from .schema_component_proptypes import get_component_prop_schema + +_DATE_SCHEMA = { + "type": "string", + "format": "date", + "pattern": r"^\d{4}-\d{2}-\d{2}$", +} + + +def _compute_dropdown_value_schema(param: MCPInput) -> dict[str, Any] | None: + """Dropdown values are an array if `multi=True`; scalar values otherwise.""" + schema = get_component_prop_schema(param) + if schema is None: + return None + + component = param.get("component") + t = schema.get("type") + if not isinstance(t, list): + return schema + + if getattr(component, "multi", False): + items_schema = schema.get("items", {}) + return ( + {"type": "array", "items": items_schema} + if items_schema + else {"type": "array"} + ) + + scalar_types = [x for x in t if x != "array"] + refined = dict(schema) + refined["type"] = scalar_types[0] if len(scalar_types) == 1 else scalar_types + refined.pop("items", None) + return refined + + +_OVERRIDES: dict[tuple[str, str], dict[str, Any] | callable] = { + ("DatePickerSingle", "date"): _DATE_SCHEMA, + ("DatePickerRange", "start_date"): _DATE_SCHEMA, + ("DatePickerRange", "end_date"): _DATE_SCHEMA, + # Graph — annotation says "object", we add structured properties. + ("Graph", "figure"): { + "type": "object", + "properties": { + "data": {"type": "array", "items": {"type": "object"}}, + "layout": {"type": "object"}, + "frames": {"type": "array", "items": {"type": "object"}}, + }, + }, + ("Dropdown", "value"): _compute_dropdown_value_schema, +} + + +def get_override_schema(param: MCPInput) -> dict[str, Any] | None: + """Return a schema override, or None to fall through to introspection.""" + key = (param.get("component_type"), param["property"]) + override = _OVERRIDES.get(key) + if override is None: + return None + if callable(override): + return override(param) + return dict(override) diff --git a/dash/mcp/primitives/tools/output_schemas/__init__.py b/dash/mcp/primitives/tools/output_schemas/__init__.py index d2d70c3552..41ddfd8d49 100644 --- a/dash/mcp/primitives/tools/output_schemas/__init__.py +++ b/dash/mcp/primitives/tools/output_schemas/__init__.py @@ -1,5 +1,29 @@ -"""Stub — real implementation in a later PR.""" +"""Output schema generation for MCP tool outputSchema fields. +Mirrors ``input_schemas/`` which generates ``inputSchema``. -def get_output_schema(): +Each source shares the same signature: ``() -> dict | None``. +""" + +from __future__ import annotations + +from typing import Any + +from .schema_callback_response import callback_response_schema + +_SOURCES = [ + callback_response_schema, +] + + +def get_output_schema() -> dict[str, Any]: + """Return the JSON Schema for a callback tool's output. + + Tries each source in order, returning the first non-None result. + Falls back to ``{}`` (any type). + """ + for source in _SOURCES: + schema = source() + if schema is not None: + return schema return {} diff --git a/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py b/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py new file mode 100644 index 0000000000..e61a482cba --- /dev/null +++ b/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py @@ -0,0 +1,16 @@ +"""Output schema derived from CallbackDispatchResponse.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import TypeAdapter + +from dash.types import CallbackDispatchResponse + +_schema = TypeAdapter(CallbackDispatchResponse).json_schema() + + +def callback_response_schema() -> dict[str, Any]: + """Return the JSON Schema for a callback dispatch response.""" + return _schema diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py index 437a71db5c..97b8d9c137 100644 --- a/tests/unit/mcp/conftest.py +++ b/tests/unit/mcp/conftest.py @@ -4,3 +4,84 @@ if sys.version_info < (3, 10): collect_ignore_glob.append("*") + +"""Shared helpers for MCP unit tests. + +These helpers work directly with Tool objects from CallbackAdapterCollection, +avoiding the MCP server so they can be used before the server is wired up. +""" + +from dash import Dash, Input, Output, html +from dash._get_app import app_context +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + +BUILTINS = {"get_dash_component"} + + +def _setup_mcp(app): + """Set up MCP for an app in tests.""" + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +def _make_app(**kwargs): + """Create a minimal Dash app with a layout and one callback.""" + app = Dash(__name__, **kwargs) + app.layout = html.Div( + [ + html.Div(id="my-input"), + html.Div(id="my-output"), + ] + ) + + @app.callback(Output("my-output", "children"), Input("my-input", "children")) + def update_output(value): + """Test callback docstring.""" + return f"echo: {value}" + + return _setup_mcp(app) + + +def _tools_list(app): + """Return tools as Tool objects via as_mcp_tools().""" + _setup_mcp(app) + with app.server.test_request_context(): + return app.mcp_callback_map.as_mcp_tools() + + +def _user_tool(tools): + """Return the first tool that isn't a builtin.""" + return next(t for t in tools if t.name not in BUILTINS) + + +def _app_with_callback(component, input_prop="value", output_id="out"): + """Create a Dash app with one callback using ``component`` as Input.""" + app = Dash(__name__) + app.layout = html.Div([component, html.Div(id=output_id)]) + + @app.callback(Output(output_id, "children"), Input(component.id, input_prop)) + def update(val): + return f"got: {val}" + + return _setup_mcp(app) + + +def _schema_for(tool, param_name=None): + """Extract the JSON schema dict for a parameter, without description.""" + props = tool.inputSchema["properties"] + if param_name is None: + param_name = next(iter(props)) + schema = dict(props[param_name]) + schema.pop("description", None) + return schema + + +def _desc_for(tool, param_name=None): + """Extract the description string for a parameter, or ''.""" + props = tool.inputSchema["properties"] + if param_name is None: + param_name = next(iter(props)) + return props[param_name].get("description", "") diff --git a/tests/unit/mcp/tools/input_schemas/__init__.py b/tests/unit/mcp/tools/input_schemas/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/mcp/tools/input_schemas/input_descriptions/__init__.py b/tests/unit/mcp/tools/input_schemas/input_descriptions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py b/tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py new file mode 100644 index 0000000000..bc6758d2d3 --- /dev/null +++ b/tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py @@ -0,0 +1,424 @@ +"""Description tests — verifies per-property description generation. + +Tests are organized by description source: +- Labels (htmlFor, containment, text extraction) +- Component-specific (date pickers, sliders) +- Options (Dropdown, RadioItems, Checklist) +- Generic props (placeholder, default value, min/max/step) +- Chained callbacks (dynamic prop/options detection) +- Combinations (label + component-specific) +""" + +import pytest + +from dash import Dash, Input, Output, dcc, html + +from tests.unit.mcp.conftest import ( + _app_with_callback, + _desc_for, + _tools_list, + _user_tool, +) + + +def _app_with_layout(layout, *inputs): + app = Dash(__name__) + app.layout = layout + + @app.callback( + Output("out", "children"), + [Input(cid, prop) for cid, prop in inputs], + ) + def update(*args): + return str(args) + + return app + + +def _tool_for(component, input_prop="value"): + app = _app_with_callback(component, input_prop=input_prop) + return _user_tool(_tools_list(app)) + + +# --------------------------------------------------------------------------- +# Labels +# --------------------------------------------------------------------------- + + +class TestLabels: + def test_html_for(self): + app = _app_with_layout( + html.Div( + [ + html.Label("Your Name", htmlFor="inp"), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Your Name" in _desc_for(tool) + + def test_html_for_not_adjacent(self): + app = _app_with_layout( + html.Div( + [ + html.Div(html.Label("Remote Label", htmlFor="inp")), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Remote Label" in _desc_for(tool) + + def test_containment(self): + app = _app_with_layout( + html.Div( + [ + html.Label( + [ + "Pick a city", + dcc.Dropdown(id="city_dd", options=["NYC", "LA"]), + ] + ), + html.Div(id="out"), + ] + ), + ("city_dd", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Pick a city" in _desc_for(tool) + + def test_deeply_nested_containment(self): + app = _app_with_layout( + html.Div( + [ + html.Label( + [ + html.Span("Nested Label"), + html.Div(dcc.Input(id="nested_inp")), + ] + ), + html.Div(id="out"), + ] + ), + ("nested_inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Nested Label" in _desc_for(tool) + + def test_both_htmlfor_and_containment_captured(self): + app = _app_with_layout( + html.Div( + [ + html.Label(["Containment Label", dcc.Input(id="inp")]), + html.Label("HtmlFor Label", htmlFor="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "HtmlFor Label" in desc + assert "Containment Label" in desc + + def test_deep_text_extraction(self): + app = _app_with_layout( + html.Div( + [ + html.Label( + html.Div(html.Span(html.B("Deep Text"))), + htmlFor="inp", + ), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Deep Text" in _desc_for(tool) + + def test_multiple_text_nodes(self): + app = _app_with_layout( + html.Div( + [ + html.Label( + [html.B("First"), " ", html.I("Second")], + htmlFor="inp", + ), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Labeled with: First Second" in desc + + def test_unrelated_label_excluded(self): + app = _app_with_layout( + html.Div( + [ + html.Label("Other Field", htmlFor="other"), + dcc.Input(id="other"), + dcc.Input(id="target"), + html.Div(id="out"), + ] + ), + ("target", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Other Field" not in (desc or "") + + +# --------------------------------------------------------------------------- +# Component-specific: date pickers +# --------------------------------------------------------------------------- + + +class TestDatePickerDescriptions: + def test_single_full_range(self): + dp = dcc.DatePickerSingle( + id="dp", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "2020-01-01" in desc + assert "2025-12-31" in desc + + def test_single_min_only(self): + dp = dcc.DatePickerSingle(id="dp", min_date_allowed="2020-01-01") + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "min_date_allowed: '2020-01-01'" in desc + + def test_single_default_date(self): + dp = dcc.DatePickerSingle(id="dp", date="2024-06-15") + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "date: '2024-06-15'" in desc + + def test_range_with_constraints(self): + dpr = dcc.DatePickerRange( + id="dpr", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + desc = _desc_for(_tool_for(dpr, "start_date"), "val") + assert "2020-01-01" in desc + + +# --------------------------------------------------------------------------- +# Component-specific: sliders +# --------------------------------------------------------------------------- + + +class TestSliderDescriptions: + def test_min_max(self): + sl = dcc.Slider(id="sl", min=0, max=100) + desc = _desc_for(_tool_for(sl), "val") + assert "min: 0" in desc + assert "max: 100" in desc + + def test_step(self): + sl = dcc.Slider(id="sl", min=0, max=100, step=5) + desc = _desc_for(_tool_for(sl), "val") + assert "step: 5" in desc + + def test_default_value(self): + sl = dcc.Slider(id="sl", min=0, max=100, value=50) + desc = _desc_for(_tool_for(sl), "val") + assert "value: 50" in desc + + def test_marks(self): + sl = dcc.Slider(id="sl", min=0, max=100, marks={0: "Low", 100: "High"}) + desc = _desc_for(_tool_for(sl), "val") + assert "marks: {0: 'Low', 100: 'High'}" in desc + + def test_range_slider_min_max(self): + rs = dcc.RangeSlider(id="rs", min=0, max=100) + desc = _desc_for(_tool_for(rs), "val") + assert "min: 0" in desc + assert "max: 100" in desc + + +# --------------------------------------------------------------------------- +# Options (parametrized across Dropdown, RadioItems, Checklist) +# --------------------------------------------------------------------------- + + +_OPTIONS_COMPONENTS = [ + ("Dropdown", lambda **kw: dcc.Dropdown(id="comp", **kw), "comp"), + ("RadioItems", lambda **kw: dcc.RadioItems(id="comp", **kw), "comp"), + ("Checklist", lambda **kw: dcc.Checklist(id="comp", **kw), "comp"), +] + + +class TestOptionsDescriptions: + @pytest.mark.parametrize( + "name,factory,cid", _OPTIONS_COMPONENTS, ids=[c[0] for c in _OPTIONS_COMPONENTS] + ) + def test_options_shown(self, name, factory, cid): + comp = factory(options=["X", "Y", "Z"]) + desc = _desc_for(_tool_for(comp), "val") + assert "options: ['X', 'Y', 'Z']" in desc + + @pytest.mark.parametrize( + "name,factory,cid", _OPTIONS_COMPONENTS, ids=[c[0] for c in _OPTIONS_COMPONENTS] + ) + def test_default_shown(self, name, factory, cid): + value = ["a"] if name == "Checklist" else "a" + comp = factory(options=["a", "b"], value=value) + desc = _desc_for(_tool_for(comp), "val") + assert f"value: {value!r}" in desc + + def test_dropdown_dict_options(self): + dd = dcc.Dropdown( + id="dd", + options=[ + {"label": "New York", "value": "NYC"}, + ], + ) + assert "NYC" in _desc_for(_tool_for(dd), "val") + + def test_store_storage_type_template(self): + store = dcc.Store(id="store", storage_type="session") + app = _app_with_callback(store, input_prop="data") + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool, "val") + assert ( + "storage_type: 'session'. Describes how to store the value client-side" + in desc + ) + + def test_many_options_truncated(self): + dd = dcc.Dropdown(id="big", options=[str(i) for i in range(50)], value="0") + app = _app_with_callback(dd) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool, "val") + assert "options:" in desc + assert "Use get_dash_component('big', 'options') for the full value" in desc + + +# --------------------------------------------------------------------------- +# Generic props +# --------------------------------------------------------------------------- + + +class TestGenericDescriptions: + def test_placeholder(self): + inp = dcc.Input(id="inp", placeholder="Enter your name") + assert "placeholder: 'Enter your name'" in _desc_for(_tool_for(inp), "val") + + def test_numeric_min_max(self): + inp = dcc.Input(id="inp", type="number", min=0, max=999) + desc = _desc_for(_tool_for(inp), "val") + assert "min: 0" in desc + assert "max: 999" in desc + + def test_step(self): + inp = dcc.Input(id="inp", type="number", min=0, max=100, step=0.1) + assert "step: 0.1" in _desc_for(_tool_for(inp), "val") + + def test_default_value(self): + inp = dcc.Input(id="inp", value="hello") + desc = _desc_for(_tool_for(inp), "val") + assert "value: 'hello'" in desc + + def test_non_text_type(self): + inp = dcc.Input(id="inp", type="email") + assert "type: 'email'" in _desc_for(_tool_for(inp), "val") + + def test_store_default(self): + store = dcc.Store(id="store", data={"key": "value"}) + app = _app_with_callback(store, input_prop="data") + tool = _user_tool(_tools_list(app)) + assert "data: {'key': 'value'}" in _desc_for(tool, "val") + + +# --------------------------------------------------------------------------- +# Chained callbacks +# --------------------------------------------------------------------------- + + +class TestChainedCallbacks: + def test_options_set_by_upstream(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=["US", "CA"], value="US"), + dcc.Dropdown(id="city", options=[], value=None), + html.Div(id="result"), + ] + ) + + @app.callback(Output("city", "options"), Input("country", "value")) + def update_cities(country): + return ["NYC", "LA"] if country == "US" else ["Toronto"] + + @app.callback(Output("result", "children"), Input("city", "value")) + def show_city(city): + return city + + tools = _tools_list(app) + tool = next(t for t in tools if "show_city" in t.name) + desc = _desc_for(tool, "city") + assert "can be updated by tool: `update_cities`" in desc + assert "options:" in desc + + def test_value_set_by_upstream(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="source", value=""), + html.Div(id="derived", children=""), + html.Div(id="result"), + ] + ) + + @app.callback(Output("derived", "children"), Input("source", "value")) + def compute_derived(val): + return f"derived: {val}" + + @app.callback(Output("result", "children"), Input("derived", "children")) + def use_derived(val): + return val + + tools = _tools_list(app) + tool = next(t for t in tools if "use_derived" in t.name) + desc = _desc_for(tool, "val") + assert "can be updated by tool: `compute_derived`" in desc + + +# --------------------------------------------------------------------------- +# Combinations +# --------------------------------------------------------------------------- + + +class TestCombinations: + def test_label_with_date_picker(self): + dp = dcc.DatePickerSingle( + id="dp", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + app = _app_with_layout( + html.Div( + [ + html.Label("Departure Date", htmlFor="dp"), + dp, + html.Div(id="out"), + ] + ), + ("dp", "date"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Departure Date" in desc + assert "2020-01-01" in desc diff --git a/tests/unit/mcp/tools/input_schemas/test_input_schemas.py b/tests/unit/mcp/tools/input_schemas/test_input_schemas.py new file mode 100644 index 0000000000..5350bd955e --- /dev/null +++ b/tests/unit/mcp/tools/input_schemas/test_input_schemas.py @@ -0,0 +1,331 @@ +"""Input schema tests — verifies JSON Schema generation for component properties. + +Tests are organized by concern: +- Static overrides (date pickers, graph, interval, sliders) +- Component introspection (representative samples — full type coverage in test_json_prop_typing) +- Callback annotation overrides (highest priority) +- Required/nullable behavior +""" + +import pytest +from typing import Optional + +from dash import Dash, Input, Output, State, dcc, html + +from tests.unit.mcp.conftest import ( + _app_with_callback, + _schema_for, + _tools_list, + _user_tool, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_schema(component_type, prop): + _factories = { + "DatePickerSingle": lambda: dcc.DatePickerSingle(id="dp"), + "DatePickerRange": lambda: dcc.DatePickerRange(id="dpr"), + "Graph": lambda: dcc.Graph(id="graph"), + "Interval": lambda: dcc.Interval(id="intv"), + "Input": lambda: dcc.Input(id="inp"), + "Textarea": lambda: dcc.Textarea(id="ta"), + "Tabs": lambda: dcc.Tabs(id="tabs"), + "Dropdown": lambda: dcc.Dropdown(id="dd"), + "RadioItems": lambda: dcc.RadioItems(id="ri"), + "Checklist": lambda: dcc.Checklist(id="cl"), + "Store": lambda: dcc.Store(id="store"), + "Upload": lambda: dcc.Upload(id="upload"), + "Slider": lambda: dcc.Slider(id="sl"), + "RangeSlider": lambda: dcc.RangeSlider(id="rs"), + } + app = _app_with_callback(_factories[component_type](), input_prop=prop) + tool = _user_tool(_tools_list(app)) + return _schema_for(tool) + + +# --------------------------------------------------------------------------- +# Static overrides take priority over introspection +# --------------------------------------------------------------------------- + + +class TestStaticOverrides: + """Verify that overrides win over component introspection.""" + + def test_override_beats_introspection(self): + schema = _get_schema("DatePickerSingle", "date") + # Introspection would return None for this prop; + # override provides a date format with pattern + assert schema["type"] == "string" + assert schema["format"] == "date" + assert "pattern" in schema + + +# --------------------------------------------------------------------------- +# Introspection — representative samples (not exhaustive per-component) +# --------------------------------------------------------------------------- + +INTROSPECTION_CASES = [ + # (component_type, prop, expected_schema) — one per distinct type shape + ( + "Input", + "value", + {"anyOf": [{"type": "string"}, {"type": "number"}, {"type": "null"}]}, + ), + ( + "Input", + "disabled", + { + "anyOf": [ + {"type": "boolean"}, + {"const": "disabled", "type": "string"}, + {"const": "DISABLED", "type": "string"}, + {"type": "null"}, + ] + }, + ), + ("Input", "n_submit", {"anyOf": [{"type": "number"}, {"type": "null"}]}), + ( + "Dropdown", + "value", + { + "anyOf": [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + { + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + ] + }, + "type": "array", + }, + {"type": "null"}, + ] + }, + ), + ("Dropdown", "options", {"anyOf": [{}, {"type": "null"}]}), + ( + "Checklist", + "value", + { + "anyOf": [ + { + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + ] + }, + "type": "array", + }, + {"type": "null"}, + ] + }, + ), + ( + "Store", + "data", + { + "anyOf": [ + {"additionalProperties": True, "type": "object"}, + {"items": {}, "type": "array"}, + {"type": "number"}, + {"type": "string"}, + {"type": "boolean"}, + {"type": "null"}, + ] + }, + ), + ( + "Upload", + "contents", + { + "anyOf": [ + {"type": "string"}, + {"items": {"type": "string"}, "type": "array"}, + {"type": "null"}, + ] + }, + ), + ( + "RangeSlider", + "value", + {"anyOf": [{"items": {"type": "number"}, "type": "array"}, {"type": "null"}]}, + ), + ("Tabs", "value", {"anyOf": [{"type": "string"}, {"type": "null"}]}), +] + + +class TestIntrospection: + """Representative introspection tests — full type coverage in test_json_prop_typing.""" + + @pytest.mark.parametrize( + "component_type,prop,expected", + INTROSPECTION_CASES, + ids=[f"{c}.{p}" for c, p, _ in INTROSPECTION_CASES], + ) + def test_introspected_schema(self, component_type, prop, expected): + assert _get_schema(component_type, prop) == expected + + +# --------------------------------------------------------------------------- +# Callback annotation overrides +# --------------------------------------------------------------------------- + + +def _app_with_annotated_callback(annotation_type, input_prop="disabled"): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + if annotation_type is None: + + @app.callback(Output("out", "children"), Input("inp", input_prop)) + def update(val): + return str(val) + + else: + + @app.callback(Output("out", "children"), Input("inp", input_prop)) + def update(val: annotation_type): + return str(val) + + return app + + +ANNOTATION_CASES = [ + (str, "disabled", {"type": "string"}), + (int, "value", {"type": "integer"}), + (float, "value", {"type": "number"}), + (bool, "value", {"type": "boolean"}), + (list, "value", {"items": {}, "type": "array"}), + (dict, "value", {"additionalProperties": True, "type": "object"}), + (Optional[int], "value", {"anyOf": [{"type": "integer"}, {"type": "null"}]}), + (Optional[str], "value", {"anyOf": [{"type": "string"}, {"type": "null"}]}), +] + + +class TestAnnotationOverrides: + """Callback type annotations override component schemas.""" + + @pytest.mark.parametrize( + "ann,prop,expected", + ANNOTATION_CASES, + ids=[ + f"{a.__name__ if hasattr(a, '__name__') else a}-{p}" + for a, p, _ in ANNOTATION_CASES + ], + ) + def test_annotation(self, ann, prop, expected): + app = _app_with_annotated_callback(ann, input_prop=prop) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == expected + + def test_no_annotation_uses_introspection(self): + app = _app_with_annotated_callback(None) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == { + "anyOf": [ + {"type": "boolean"}, + {"const": "disabled", "type": "string"}, + {"const": "DISABLED", "type": "string"}, + {"type": "null"}, + ] + } + + +class TestAnnotationNullability: + """Annotations control nullable vs non-nullable schemas.""" + + def test_str_removes_null(self): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: str): + return val + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == {"type": "string"} + + def test_optional_preserves_null(self): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: Optional[str]): + return val or "" + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == { + "anyOf": [{"type": "string"}, {"type": "null"}] + } + + def test_optional_param_not_required(self): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: Optional[str]): + return val or "" + + tool = _user_tool(_tools_list(app)) + assert "val" not in tool.inputSchema.get("required", []) + + +class TestAnnotationWithState: + """Annotations work for State parameters too.""" + + def test_state_annotation_overrides(self): + app = Dash(__name__) + app.layout = html.Div( + [dcc.Input(id="inp"), dcc.Store(id="store"), html.Div(id="out")] + ) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + State("store", "data"), + ) + def update(val: str, data: dict): + return str(val) + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == {"type": "string"} + assert _schema_for(tool, "data") == { + "additionalProperties": True, + "type": "object", + } + + def test_partial_annotations(self): + app = Dash(__name__) + app.layout = html.Div( + [dcc.Input(id="inp"), dcc.Store(id="store"), html.Div(id="out")] + ) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + State("store", "data"), + ) + def update(val: int, data): + return str(val) + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == {"type": "integer"} + assert _schema_for(tool, "data") == { + "anyOf": [ + {"additionalProperties": True, "type": "object"}, + {"items": {}, "type": "array"}, + {"type": "number"}, + {"type": "string"}, + {"type": "boolean"}, + {"type": "null"}, + ] + } diff --git a/tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py b/tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py new file mode 100644 index 0000000000..10b6ae5543 --- /dev/null +++ b/tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py @@ -0,0 +1,15 @@ +"""Tests for schema_component_proptypes. + +Only tests our custom logic — pydantic's type-to-schema conversion +is tested by pydantic itself. +""" + +from dash.development.base_component import Component +from dash.mcp.primitives.tools.input_schemas.schema_callback_type_annotations import ( + annotation_to_json_schema, +) + + +class TestComponentTypes: + def test_component_type_maps_to_string(self): + assert annotation_to_json_schema(Component) == {"type": "string"} diff --git a/tests/unit/mcp/tools/test_callback_adapter.py b/tests/unit/mcp/tools/test_callback_adapter.py index 91808d304e..dc3fc041fc 100644 --- a/tests/unit/mcp/tools/test_callback_adapter.py +++ b/tests/unit/mcp/tools/test_callback_adapter.py @@ -1,8 +1,9 @@ """Tests for CallbackAdapter.""" import pytest -from dash import Dash, Input, Output, dcc, html +from dash import Dash, Input, Output, State, dcc, html from dash._get_app import app_context +from mcp.types import Tool from dash.mcp.primitives.tools.callback_adapter_collection import ( CallbackAdapterCollection, @@ -35,6 +36,68 @@ def update(val): return app +@pytest.fixture +def multi_output_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + dcc.Dropdown(id="dd2"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("dd2", "options"), + Output("out", "children"), + Input("dd", "value"), + ) + def update(val): + return [], val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def state_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("btn", "n_clicks"), + State("inp", "value"), + ) + def update(clicks, val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def typed_app(): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val: str): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + @pytest.fixture def duplicate_names_app(): app = Dash(__name__) @@ -131,6 +194,52 @@ def test_duplicates_get_unique_names(self, duplicate_names_app): assert names[0] != names[1] +class TestTool: + def test_returns_tool_instance(self, simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert isinstance(tool, Tool) + assert tool.name == "update" + + def test_description_includes_docstring(self, simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "Update output." in tool.description + + def test_description_includes_output_target(self, simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "out.children" in tool.description + + def test_param_name_from_function_signature(self, simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "val" in tool.inputSchema["properties"] + + def test_param_has_label_description(self, simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + desc = tool.inputSchema["properties"]["val"].get("description", "") + assert "Your Name" in desc + + def test_state_params_included(self, state_app): + with state_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + props = tool.inputSchema["properties"] + assert set(props.keys()) == {"clicks", "val"} + + def test_multi_output_description(self, multi_output_app): + with multi_output_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "dd2.options" in tool.description + assert "out.children" in tool.description + + def test_typed_annotation_narrows_schema(self, typed_app): + with typed_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert tool.inputSchema["properties"]["val"]["type"] == "string" + + class TestGetInitialValue: def test_returns_layout_value(self, simple_app): callback_map = app_context.get().mcp_callback_map @@ -225,3 +334,59 @@ def update(val): app_context.set(app) app.mcp_callback_map = CallbackAdapterCollection(app) assert app.mcp_callback_map[0].is_valid + + +class TestNoInfiniteLoop: + @pytest.mark.timeout(5) + def test_initial_output_does_not_loop(self): + """Building a tool must not trigger infinite re-entry in _initial_output.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Slider(id="sl", min=0, max=10, value=5), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("sl", "value")) + def show(value): + return f"Value: {value}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert tool.name == "show" + + @pytest.mark.timeout(5) + def test_chained_callbacks_do_not_loop(self): + """Chained callbacks with initial value resolution must not loop.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Slider(id="sl", min=0, max=10, value=5), + dcc.Slider(id="sl2", min=0, max=10), + html.Div(id="out"), + ] + ) + + @app.callback(Output("sl2", "value"), Input("sl", "value")) + def sync(v): + return v + + @app.callback( + Output("out", "children"), + Input("sl", "value"), + Input("sl2", "value"), + ) + def show(v1, v2): + return f"{v1} + {v2}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + for cb in app.mcp_callback_map: + tool = cb.as_mcp_tool + assert tool.name is not None diff --git a/tests/unit/mcp/tools/test_tool_schema.py b/tests/unit/mcp/tools/test_tool_schema.py new file mode 100644 index 0000000000..49b639834c --- /dev/null +++ b/tests/unit/mcp/tools/test_tool_schema.py @@ -0,0 +1,64 @@ +"""Tool schema tests — what a Dash MCP tool looks like. + +The EXPECTED_TOOL dict below is the canonical reference for the shape of +a callback-generated MCP tool. It doubles as human-readable documentation +and as a test fixture. + +Reference: https://modelcontextprotocol.io/specification/2025-11-25/server/tools +""" + +from tests.unit.mcp.conftest import ( + _make_app, + _tools_list, + _user_tool, +) + +from pydantic import TypeAdapter +from dash.development.base_component import Component +from dash.types import CallbackDispatchResponse + +_DASH_COMPONENT_SCHEMA = TypeAdapter(Component).json_schema() + +EXPECTED_TOOL = { + "name": "update_output", + "description": ( + "my-output.children: Returns content\n" "\n" "Test callback docstring." + ), + "inputSchema": { + "type": "object", + "properties": { + "value": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"type": "number"}, + _DASH_COMPONENT_SCHEMA, + { + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"type": "number"}, + _DASH_COMPONENT_SCHEMA, + {"type": "null"}, + ] + }, + "type": "array", + }, + {"type": "null"}, + ], + "description": "Input is optional.\nThe children of this component.", + }, + }, + }, + "outputSchema": TypeAdapter(CallbackDispatchResponse).json_schema(), +} + + +class TestToolSchema: + """Verify that the generated tool matches EXPECTED_TOOL exactly.""" + + def test_full_tool(self): + """The entire tool dict matches the expected shape.""" + tool = _user_tool(_tools_list(_make_app())) + assert tool.model_dump(exclude_none=True) == EXPECTED_TOOL From 9060647c0a565597ff7e6cb637ff14cd1133140c Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 8 Apr 2026 10:51:47 -0600 Subject: [PATCH 25/80] Refactor description sources to accept CallbackAdapter instances --- dash/mcp/primitives/tools/callback_adapter.py | 2 +- .../primitives/tools/descriptions/__init__.py | 16 +++++++++------- .../tools/descriptions/description_docstring.py | 11 ++++++----- .../tools/descriptions/description_outputs.py | 11 ++++++----- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index c7044de3f4..d142223700 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -150,7 +150,7 @@ def prevents_initial_call(self) -> bool: @cached_property def _description(self) -> str: - return build_tool_description(self.outputs, self._docstring) + return build_tool_description(self) @cached_property def _input_schema(self) -> dict[str, Any]: diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py index d464677251..29cc2840d0 100644 --- a/dash/mcp/primitives/tools/descriptions/__init__.py +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -1,7 +1,7 @@ """Tool-level description generation for MCP tools. Each source shares the same signature: -``(outputs, docstring) -> list[str]`` +``(adapter: CallbackAdapter) -> list[str]`` This is distinct from per-parameter descriptions (in ``input_schemas/input_descriptions/``) which populate @@ -10,23 +10,25 @@ from __future__ import annotations -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING from .description_docstring import callback_docstring from .description_outputs import output_summary +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + _SOURCES = [ output_summary, callback_docstring, ] -def build_tool_description( - outputs: list[dict[str, Any]], - docstring: str | None = None, -) -> str: +def build_tool_description(adapter: CallbackAdapter) -> str: """Build a human-readable description for an MCP tool.""" lines: list[str] = [] for source in _SOURCES: - lines.extend(source(outputs, docstring)) + lines.extend(source(adapter)) return "\n".join(lines) if lines else "Dash callback" diff --git a/dash/mcp/primitives/tools/descriptions/description_docstring.py b/dash/mcp/primitives/tools/descriptions/description_docstring.py index 71cf4d3d5a..21dbeed804 100644 --- a/dash/mcp/primitives/tools/descriptions/description_docstring.py +++ b/dash/mcp/primitives/tools/descriptions/description_docstring.py @@ -2,14 +2,15 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter -def callback_docstring( - outputs: list[dict[str, Any]], - docstring: str | None = None, -) -> list[str]: + +def callback_docstring(adapter: CallbackAdapter) -> list[str]: """Return the callback's docstring as description lines.""" + docstring = adapter._docstring if docstring: return ["", docstring.strip()] return [] diff --git a/dash/mcp/primitives/tools/descriptions/description_outputs.py b/dash/mcp/primitives/tools/descriptions/description_outputs.py index c174c177f4..986344c75c 100644 --- a/dash/mcp/primitives/tools/descriptions/description_outputs.py +++ b/dash/mcp/primitives/tools/descriptions/description_outputs.py @@ -2,7 +2,10 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter _OUTPUT_SEMANTICS: dict[tuple[str | None, str], str] = { ("Graph", "figure"): "Returns chart/visualization data", @@ -26,11 +29,9 @@ } -def output_summary( - outputs: list[dict[str, Any]], - docstring: str | None = None, -) -> list[str]: +def output_summary(adapter: CallbackAdapter) -> list[str]: """Produce a short summary of what the callback outputs represent.""" + outputs = adapter.outputs if not outputs: return ["Dash callback"] From 554ddc0bc00b21ca24147344891c11212011e227 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 8 Apr 2026 12:12:17 -0600 Subject: [PATCH 26/80] Fix pylint error --- tests/unit/mcp/conftest.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py index 97b8d9c137..83f6e5378c 100644 --- a/tests/unit/mcp/conftest.py +++ b/tests/unit/mcp/conftest.py @@ -1,21 +1,17 @@ +"""Shared helpers for MCP unit tests.""" + import sys -collect_ignore_glob = [] +from dash import Dash, Input, Output, html +from dash._get_app import app_context +collect_ignore_glob = [] if sys.version_info < (3, 10): collect_ignore_glob.append("*") - -"""Shared helpers for MCP unit tests. - -These helpers work directly with Tool objects from CallbackAdapterCollection, -avoiding the MCP server so they can be used before the server is wired up. -""" - -from dash import Dash, Input, Output, html -from dash._get_app import app_context -from dash.mcp.primitives.tools.callback_adapter_collection import ( - CallbackAdapterCollection, -) +else: + from dash.mcp.primitives.tools.callback_adapter_collection import ( # pylint: disable=wrong-import-position + CallbackAdapterCollection, + ) BUILTINS = {"get_dash_component"} From 2cfa19328a2cfaeef0a23c71cd6e764d3af82852 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 16 Apr 2026 12:55:37 -0600 Subject: [PATCH 27/80] Fix regression in CallbackAdapter --- dash/mcp/primitives/tools/callback_adapter.py | 7 ++++++- .../tools/output_schemas/schema_callback_response.py | 6 +++--- tests/unit/mcp/tools/test_tool_schema.py | 4 ++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index d142223700..c97c7721ce 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -385,7 +385,7 @@ def _expand_output_spec( output_id: str, cb_info: dict, resolved_inputs: list[CallbackInput], -) -> list[CallbackOutputTarget]: +) -> CallbackOutputTarget | list[CallbackOutputTarget]: """Build the outputs spec, expanding wildcards to concrete IDs. For wildcard outputs, derives concrete IDs from the resolved inputs. @@ -417,6 +417,11 @@ def _expand_output_spec( else: results.append({"id": pid, "property": prop}) + # Mirror the Dash renderer: single-output callbacks send a bare dict, + # multi-output callbacks send a list. The framework's output value + # matching depends on this shape. + if len(results) == 1: + return results[0] return results diff --git a/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py b/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py index e61a482cba..6962fb4a4f 100644 --- a/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py +++ b/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py @@ -1,4 +1,4 @@ -"""Output schema derived from CallbackDispatchResponse.""" +"""Output schema derived from CallbackExecutionResponse.""" from __future__ import annotations @@ -6,9 +6,9 @@ from pydantic import TypeAdapter -from dash.types import CallbackDispatchResponse +from dash.types import CallbackExecutionResponse -_schema = TypeAdapter(CallbackDispatchResponse).json_schema() +_schema = TypeAdapter(CallbackExecutionResponse).json_schema() def callback_response_schema() -> dict[str, Any]: diff --git a/tests/unit/mcp/tools/test_tool_schema.py b/tests/unit/mcp/tools/test_tool_schema.py index 49b639834c..b39dfe08c9 100644 --- a/tests/unit/mcp/tools/test_tool_schema.py +++ b/tests/unit/mcp/tools/test_tool_schema.py @@ -15,7 +15,7 @@ from pydantic import TypeAdapter from dash.development.base_component import Component -from dash.types import CallbackDispatchResponse +from dash.types import CallbackExecutionResponse _DASH_COMPONENT_SCHEMA = TypeAdapter(Component).json_schema() @@ -51,7 +51,7 @@ }, }, }, - "outputSchema": TypeAdapter(CallbackDispatchResponse).json_schema(), + "outputSchema": TypeAdapter(CallbackExecutionResponse).json_schema(), } From dce8e9aa2bd6b1a5767a0fd83394ae38a1fba2b9 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 16 Apr 2026 13:09:05 -0600 Subject: [PATCH 28/80] Refactor tool descriptions/schemas to use a base class (just as resources do) --- .../primitives/tools/descriptions/__init__.py | 21 ++-- .../mcp/primitives/tools/descriptions/base.py | 21 ++++ .../descriptions/description_docstring.py | 15 ++- .../tools/descriptions/description_outputs.py | 67 +++++++------ .../tools/input_schemas/__init__.py | 25 ++--- .../primitives/tools/input_schemas/base.py | 20 ++++ .../input_descriptions/__init__.py | 25 +++-- .../input_schemas/input_descriptions/base.py | 18 ++++ .../description_component_props.py | 98 ++++++++++--------- .../description_docstrings.py | 18 ++-- .../description_html_labels.py | 21 ++-- .../schema_callback_type_annotations.py | 40 ++++---- .../schema_component_proptypes.py | 33 ++++--- .../schema_component_proptypes_overrides.py | 26 ++--- 14 files changed, 267 insertions(+), 181 deletions(-) create mode 100644 dash/mcp/primitives/tools/descriptions/base.py create mode 100644 dash/mcp/primitives/tools/input_schemas/base.py create mode 100644 dash/mcp/primitives/tools/input_schemas/input_descriptions/base.py diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py index 29cc2840d0..b32238992c 100644 --- a/dash/mcp/primitives/tools/descriptions/__init__.py +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -1,7 +1,7 @@ """Tool-level description generation for MCP tools. -Each source shares the same signature: -``(adapter: CallbackAdapter) -> list[str]`` +Each source is a ``ToolDescriptionSource`` subclass that can add text +to the tool's description. All sources are accumulated. This is distinct from per-parameter descriptions (in ``input_schemas/input_descriptions/``) which populate @@ -10,25 +10,24 @@ from __future__ import annotations -from __future__ import annotations - from typing import TYPE_CHECKING -from .description_docstring import callback_docstring -from .description_outputs import output_summary +from .base import ToolDescriptionSource +from .description_docstring import DocstringDescription +from .description_outputs import OutputSummaryDescription if TYPE_CHECKING: from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter -_SOURCES = [ - output_summary, - callback_docstring, +_SOURCES: list[type[ToolDescriptionSource]] = [ + OutputSummaryDescription, + DocstringDescription, ] -def build_tool_description(adapter: CallbackAdapter) -> str: +def build_tool_description(callback: CallbackAdapter) -> str: """Build a human-readable description for an MCP tool.""" lines: list[str] = [] for source in _SOURCES: - lines.extend(source(adapter)) + lines.extend(source.describe(callback)) return "\n".join(lines) if lines else "Dash callback" diff --git a/dash/mcp/primitives/tools/descriptions/base.py b/dash/mcp/primitives/tools/descriptions/base.py new file mode 100644 index 0000000000..c069f67918 --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/base.py @@ -0,0 +1,21 @@ +"""Base class for tool-level description sources.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + + +class ToolDescriptionSource: + """A source of text that can describe an MCP tool. + + Subclasses implement ``describe`` to return strings that will be + joined into the tool's ``description`` field. All sources are + accumulated — every source can add text to the overall description. + """ + + @classmethod + def describe(cls, callback: CallbackAdapter) -> list[str]: + raise NotImplementedError diff --git a/dash/mcp/primitives/tools/descriptions/description_docstring.py b/dash/mcp/primitives/tools/descriptions/description_docstring.py index 21dbeed804..9bcc697248 100644 --- a/dash/mcp/primitives/tools/descriptions/description_docstring.py +++ b/dash/mcp/primitives/tools/descriptions/description_docstring.py @@ -4,13 +4,18 @@ from typing import TYPE_CHECKING +from .base import ToolDescriptionSource + if TYPE_CHECKING: from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter -def callback_docstring(adapter: CallbackAdapter) -> list[str]: +class DocstringDescription(ToolDescriptionSource): """Return the callback's docstring as description lines.""" - docstring = adapter._docstring - if docstring: - return ["", docstring.strip()] - return [] + + @classmethod + def describe(cls, callback: CallbackAdapter) -> list[str]: + docstring = callback._docstring + if docstring: + return ["", docstring.strip()] + return [] diff --git a/dash/mcp/primitives/tools/descriptions/description_outputs.py b/dash/mcp/primitives/tools/descriptions/description_outputs.py index 986344c75c..06371d7717 100644 --- a/dash/mcp/primitives/tools/descriptions/description_outputs.py +++ b/dash/mcp/primitives/tools/descriptions/description_outputs.py @@ -4,54 +4,53 @@ from typing import TYPE_CHECKING +from .base import ToolDescriptionSource + if TYPE_CHECKING: from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter _OUTPUT_SEMANTICS: dict[tuple[str | None, str], str] = { - ("Graph", "figure"): "Returns chart/visualization data", ("DataTable", "data"): "Returns tabular data", ("DataTable", "columns"): "Returns table column definitions", - ("Dropdown", "options"): "Returns selection options", - ("Dropdown", "value"): "Updates a selection value", - ("RadioItems", "options"): "Returns selection options", - ("Checklist", "options"): "Returns selection options", - ("Store", "data"): "Returns stored data", + ("Store", "data"): "Returns data to be remembered client-side", ("Download", "data"): "Returns downloadable content", ("Markdown", "children"): "Returns formatted text", (None, "figure"): "Returns chart/visualization data", - (None, "data"): "Returns data", - (None, "options"): "Returns selection options", + (None, "options"): "Returns available options", (None, "columns"): "Returns column definitions", (None, "children"): "Returns content", - (None, "value"): "Returns a value", + (None, "value"): "Returns the current value", (None, "style"): "Updates styling", (None, "disabled"): "Updates enabled/disabled state", } -def output_summary(adapter: CallbackAdapter) -> list[str]: +class OutputSummaryDescription(ToolDescriptionSource): """Produce a short summary of what the callback outputs represent.""" - outputs = adapter.outputs - if not outputs: - return ["Dash callback"] - - lines: list[str] = [] - for out in outputs: - comp_id = out["component_id"] - prop = out["property"] - comp_type = out.get("component_type") - - semantic = _OUTPUT_SEMANTICS.get((comp_type, prop)) - if semantic is None: - semantic = _OUTPUT_SEMANTICS.get((None, prop)) - - if semantic is not None: - lines.append(f"- {comp_id}.{prop}: {semantic}") - else: - lines.append(f"- {comp_id}.{prop}") - - n = len(outputs) - if n == 1: - return [lines[0].lstrip("- ")] - header = f"Returns {n} output{'s' if n > 1 else ''}:" - return [header] + lines + + @classmethod + def describe(cls, callback: CallbackAdapter) -> list[str]: + outputs = callback.outputs + if not outputs: + return ["Dash callback"] + + lines: list[str] = [] + for out in outputs: + comp_id = out["component_id"] + prop = out["property"] + comp_type = out.get("component_type") + + semantic = _OUTPUT_SEMANTICS.get((comp_type, prop)) + if semantic is None: + semantic = _OUTPUT_SEMANTICS.get((None, prop)) + + if semantic is not None: + lines.append(f"- {comp_id}.{prop}: {semantic}") + else: + lines.append(f"- {comp_id}.{prop}") + + n = len(outputs) + if n == 1: + return [lines[0].lstrip("- ")] + header = f"Returns {n} output{'s' if n > 1 else ''}:" + return [header] + lines diff --git a/dash/mcp/primitives/tools/input_schemas/__init__.py b/dash/mcp/primitives/tools/input_schemas/__init__.py index 2c1646f56a..9fa82eda55 100644 --- a/dash/mcp/primitives/tools/input_schemas/__init__.py +++ b/dash/mcp/primitives/tools/input_schemas/__init__.py @@ -1,9 +1,8 @@ """Input schema generation for MCP tool inputSchema fields. -Mirrors ``output_schemas/`` which generates ``outputSchema``. - -Each source is tried in priority order. All share the same signature: -``(param: MCPInput) -> dict | None``. +Each source is an ``InputSchemaSource`` subclass that can type +an input parameter. Sources are tried in priority order — first +non-None wins. """ from __future__ import annotations @@ -11,15 +10,17 @@ from typing import Any from dash.mcp.types import MCPInput -from .schema_callback_type_annotations import annotation_to_schema -from .schema_component_proptypes_overrides import get_override_schema -from .schema_component_proptypes import get_component_prop_schema + +from .base import InputSchemaSource +from .schema_callback_type_annotations import AnnotationSchema +from .schema_component_proptypes_overrides import OverrideSchema +from .schema_component_proptypes import ComponentPropSchema from .input_descriptions import get_property_description -_SOURCES = [ - annotation_to_schema, - get_override_schema, - get_component_prop_schema, +_SOURCES: list[type[InputSchemaSource]] = [ + AnnotationSchema, + OverrideSchema, + ComponentPropSchema, ] @@ -31,7 +32,7 @@ def get_input_schema(param: MCPInput) -> dict[str, Any]: """ schema: dict[str, Any] = {} for source in _SOURCES: - result = source(param) + result = source.get_schema(param) if result is not None: schema = result break diff --git a/dash/mcp/primitives/tools/input_schemas/base.py b/dash/mcp/primitives/tools/input_schemas/base.py new file mode 100644 index 0000000000..42fe2352b6 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/base.py @@ -0,0 +1,20 @@ +"""Base class for input schema sources.""" + +from __future__ import annotations + +from typing import Any + +from dash.mcp.types import MCPInput + + +class InputSchemaSource: + """A source of JSON Schema that can type an MCP tool input parameter. + + Subclasses implement ``get_schema`` to return a JSON Schema dict + for the parameter, or ``None`` if this source cannot determine the + type. Sources are tried in priority order — first non-None wins. + """ + + @classmethod + def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: + raise NotImplementedError diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py index e1d1e9f47c..4bc6d8e984 100644 --- a/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py @@ -1,23 +1,22 @@ """Per-property description generation for MCP tool input parameters. -Each source shares the same signature: -``(param: MCPInput) -> list[str]`` - -Sources are tried in order from most generic to most instance-specific. -All sources that produce lines are combined. +Each source is an ``InputDescriptionSource`` subclass that can add +text to a parameter's description. All sources are accumulated. """ from __future__ import annotations from dash.mcp.types import MCPInput -from .description_component_props import component_props_description -from .description_docstrings import docstring_prop_description -from .description_html_labels import label_description -_SOURCES = [ - docstring_prop_description, - label_description, - component_props_description, +from .base import InputDescriptionSource +from .description_component_props import ComponentPropsDescription +from .description_docstrings import DocstringPropDescription +from .description_html_labels import LabelDescription + +_SOURCES: list[type[InputDescriptionSource]] = [ + DocstringPropDescription, + LabelDescription, + ComponentPropsDescription, ] @@ -27,5 +26,5 @@ def get_property_description(param: MCPInput) -> str | None: if not param.get("required", True): lines.append("Input is optional.") for source in _SOURCES: - lines.extend(source(param)) + lines.extend(source.describe(param)) return "\n".join(lines) if lines else None diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/base.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/base.py new file mode 100644 index 0000000000..6bfd62da04 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/base.py @@ -0,0 +1,18 @@ +"""Base class for per-parameter description sources.""" + +from __future__ import annotations + +from dash.mcp.types import MCPInput + + +class InputDescriptionSource: + """A source of text that can describe an MCP tool input parameter. + + Subclasses implement ``describe`` to return strings that will be + added to the callback parameter's description. All sources + are accumulated — every source can add text to the overall description. + """ + + @classmethod + def describe(cls, param: MCPInput) -> list[str]: + raise NotImplementedError diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py index 6934918260..58b4b4627e 100644 --- a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py @@ -11,6 +11,8 @@ from dash import get_app from dash.mcp.types import MCPInput +from .base import InputDescriptionSource + _MAX_VALUE_LENGTH = 200 _MCP_EXCLUDED_PROPS = {"id", "className", "style"} @@ -25,52 +27,56 @@ } -def component_props_description(param: MCPInput) -> list[str]: - component = param.get("component") - if component is None: - return [] - - component_id = param["component_id"] - cbmap = get_app().mcp_callback_map - prop_lines: list[str] = [] - - for prop_name in getattr(component, "_prop_names", []): - if prop_name in _MCP_EXCLUDED_PROPS: - continue - - upstream = cbmap.find_by_output(f"{component_id}.{prop_name}") - if upstream is not None and not upstream.prevents_initial_call: - value = upstream.initial_output_value(f"{component_id}.{prop_name}") - else: - value = getattr(component, prop_name, None) - tool_name = upstream.tool_name if upstream is not None else None - - if value is None and tool_name is None: - continue - - component_type = param.get("component_type") - template = _PROP_TEMPLATES.get((component_type, prop_name)) - formatted_value = ( - _truncate_large_values(value, component_id, prop_name) - if value is not None - else None - ) - - if template and formatted_value is not None: - line = template.format(value=formatted_value) - elif formatted_value is not None: - line = f"{prop_name}: {formatted_value}" - else: - line = prop_name - - if tool_name: - line += f" (can be updated by tool: `{tool_name}`)" - - prop_lines.append(line) - - if not prop_lines: - return [] - return [f"Component properties for {component_id}:"] + prop_lines +class ComponentPropsDescription(InputDescriptionSource): + """Describe component properties with their current values.""" + + @classmethod + def describe(cls, param: MCPInput) -> list[str]: + component = param.get("component") + if component is None: + return [] + + component_id = param["component_id"] + cbmap = get_app().mcp_callback_map + prop_lines: list[str] = [] + + for prop_name in getattr(component, "_prop_names", []): + if prop_name in _MCP_EXCLUDED_PROPS: + continue + + upstream = cbmap.find_by_output(f"{component_id}.{prop_name}") + if upstream is not None and not upstream.prevents_initial_call: + value = upstream.initial_output_value(f"{component_id}.{prop_name}") + else: + value = getattr(component, prop_name, None) + tool_name = upstream.tool_name if upstream is not None else None + + if value is None and tool_name is None: + continue + + component_type = param.get("component_type") + template = _PROP_TEMPLATES.get((component_type, prop_name)) + formatted_value = ( + _truncate_large_values(value, component_id, prop_name) + if value is not None + else None + ) + + if template and formatted_value is not None: + line = template.format(value=formatted_value) + elif formatted_value is not None: + line = f"{prop_name}: {formatted_value}" + else: + line = prop_name + + if tool_name: + line += f" (can be updated by tool: `{tool_name}`)" + + prop_lines.append(line) + + if not prop_lines: + return [] + return [f"Component properties for {component_id}:"] + prop_lines def _truncate_large_values(value: Any, component_id: str, prop_name: str) -> str: diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py index 1f67c3c0f2..23045625bf 100644 --- a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py @@ -19,6 +19,8 @@ from dash.mcp.types import MCPInput +from .base import InputDescriptionSource + _PROP_RE = re.compile( r"^[ ]*- (\w+) \([^)]+\):\s*\n((?:[ ]+.+\n)*)", re.MULTILINE, @@ -29,12 +31,16 @@ _SENTENCE_END = re.compile(r"(?<=[.!?])\s") -def docstring_prop_description(param: MCPInput) -> list[str]: - component = param.get("component") - if component is None: - return [] - desc = _get_prop_description(type(component), param["property"]) - return [desc] if desc else [] +class DocstringPropDescription(InputDescriptionSource): + """Extract property description from the component's docstring.""" + + @classmethod + def describe(cls, param: MCPInput) -> list[str]: + component = param.get("component") + if component is None: + return [] + desc = _get_prop_description(type(component), param["property"]) + return [desc] if desc else [] def _get_prop_description(cls: type, prop: str) -> str | None: diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py index 2c9cd8dea9..111e1eaaf7 100644 --- a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py @@ -10,14 +10,19 @@ from dash import get_app from dash.mcp.types import MCPInput +from .base import InputDescriptionSource -def label_description(param: MCPInput) -> list[str]: + +class LabelDescription(InputDescriptionSource): """Return the label text for this component, if any.""" - component_id = param.get("component_id") - if not component_id: + + @classmethod + def describe(cls, param: MCPInput) -> list[str]: + component_id = param.get("component_id") + if not component_id: + return [] + label_map = get_app().mcp_callback_map.component_label_map + texts = label_map.get(component_id, []) + if texts: + return [f"Labeled with: {'; '.join(texts)}"] return [] - label_map = get_app().mcp_callback_map.component_label_map - texts = label_map.get(component_id, []) - if texts: - return [f"Labeled with: {'; '.join(texts)}"] - return [] diff --git a/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py b/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py index aee5b17c6f..9cf73653fa 100644 --- a/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py +++ b/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py @@ -23,6 +23,8 @@ from dash.development.base_component import Component from dash.mcp.types import MCPInput, is_nullable +from .base import InputSchemaSource + def annotation_to_json_schema(annotation: type) -> dict[str, Any] | None: """Convert a Python type annotation to a JSON Schema dict. @@ -41,27 +43,23 @@ def annotation_to_json_schema(annotation: type) -> dict[str, Any] | None: return None -def annotation_to_schema(param: MCPInput) -> dict[str, Any] | None: - """Convert a callback parameter's type annotation to a JSON Schema dict. - - Returns ``None`` if the annotation is not recognised, meaning the - caller should fall through to the next schema source. +class AnnotationSchema(InputSchemaSource): + """Derive JSON Schema from the callback parameter's type annotation.""" - ``Optional[X]`` produces ``{"type": ["X", "null"]}`` — the user - explicitly chose a nullable type. - """ - annotation = param.get("annotation") - if annotation is None: - return None - schema = annotation_to_json_schema(annotation) - if schema is None: - return None + @classmethod + def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: + annotation = param.get("annotation") + if annotation is None: + return None + schema = annotation_to_json_schema(annotation) + if schema is None: + return None - if is_nullable(annotation) and schema: - t = schema.get("type") - if isinstance(t, str): - schema = {**schema, "type": [t, "null"]} - elif isinstance(t, list) and "null" not in t: - schema = {**schema, "type": [*t, "null"]} + if is_nullable(annotation) and schema: + t = schema.get("type") + if isinstance(t, str): + schema = {**schema, "type": [t, "null"]} + elif isinstance(t, list) and "null" not in t: + schema = {**schema, "type": [*t, "null"]} - return schema + return schema diff --git a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py index 151e391cf4..d7f72d81ff 100644 --- a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py +++ b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py @@ -6,27 +6,32 @@ from typing import Any from dash.mcp.types import MCPInput + +from .base import InputSchemaSource from .schema_callback_type_annotations import annotation_to_json_schema -def get_component_prop_schema(param: MCPInput) -> dict[str, Any] | None: - """Return the JSON Schema for a component property. +class ComponentPropSchema(InputSchemaSource): + """Derive JSON Schema from a component's ``__init__`` type annotations. Inspects the ``__init__`` signature of the component's class. Returns ``None`` if the prop has no annotation. """ - component = param.get("component") - prop = param["property"] - if component is None: - return None - try: - sig = inspect.signature(type(component).__init__) - except (ValueError, TypeError): - return None + @classmethod + def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: + component = param.get("component") + prop = param["property"] + if component is None: + return None + + try: + sig = inspect.signature(type(component).__init__) + except (ValueError, TypeError): + return None - sig_param = sig.parameters.get(prop) - if sig_param is None or sig_param.annotation is inspect.Parameter.empty: - return None + sig_param = sig.parameters.get(prop) + if sig_param is None or sig_param.annotation is inspect.Parameter.empty: + return None - return annotation_to_json_schema(sig_param.annotation) + return annotation_to_json_schema(sig_param.annotation) diff --git a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py index 25086896e7..984d493d69 100644 --- a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py +++ b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py @@ -7,7 +7,9 @@ from typing import Any from dash.mcp.types import MCPInput -from .schema_component_proptypes import get_component_prop_schema + +from .base import InputSchemaSource +from .schema_component_proptypes import ComponentPropSchema _DATE_SCHEMA = { "type": "string", @@ -18,7 +20,7 @@ def _compute_dropdown_value_schema(param: MCPInput) -> dict[str, Any] | None: """Dropdown values are an array if `multi=True`; scalar values otherwise.""" - schema = get_component_prop_schema(param) + schema = ComponentPropSchema.get_schema(param) if schema is None: return None @@ -46,7 +48,6 @@ def _compute_dropdown_value_schema(param: MCPInput) -> dict[str, Any] | None: ("DatePickerSingle", "date"): _DATE_SCHEMA, ("DatePickerRange", "start_date"): _DATE_SCHEMA, ("DatePickerRange", "end_date"): _DATE_SCHEMA, - # Graph — annotation says "object", we add structured properties. ("Graph", "figure"): { "type": "object", "properties": { @@ -59,12 +60,15 @@ def _compute_dropdown_value_schema(param: MCPInput) -> dict[str, Any] | None: } -def get_override_schema(param: MCPInput) -> dict[str, Any] | None: +class OverrideSchema(InputSchemaSource): """Return a schema override, or None to fall through to introspection.""" - key = (param.get("component_type"), param["property"]) - override = _OVERRIDES.get(key) - if override is None: - return None - if callable(override): - return override(param) - return dict(override) + + @classmethod + def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: + key = (param.get("component_type"), param["property"]) + override = _OVERRIDES.get(key) + if override is None: + return None + if callable(override): + return override(param) + return dict(override) From e29cb9cd40cd38676317d2745ec60c23f4d67c84 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 21 Apr 2026 11:24:12 -0600 Subject: [PATCH 29/80] Disable docstrings by default in MCP tool descriptions --- dash/_callback.py | 5 ++ .../descriptions/description_docstring.py | 19 +++++++- tests/unit/mcp/conftest.py | 6 ++- tests/unit/mcp/tools/test_callback_adapter.py | 46 +++++++++++++++++-- 4 files changed, 70 insertions(+), 6 deletions(-) diff --git a/dash/_callback.py b/dash/_callback.py index c8d610fa48..637a332905 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -87,6 +87,7 @@ def callback( websocket: Optional[bool] = False, persistent: Optional[bool] = False, mcp_enabled: bool = True, + mcp_expose_docstring: Optional[bool] = None, **_kwargs, ) -> Callable[[Callable[Params, ReturnVar]], Callable[Params, ReturnVar]]: """ @@ -245,6 +246,7 @@ def callback( websocket=websocket, persistent=persistent, mcp_enabled=mcp_enabled, + mcp_expose_docstring=mcp_expose_docstring, ) return cast( @@ -299,6 +301,7 @@ def insert_callback( websocket=False, persistent=False, mcp_enabled=True, + mcp_expose_docstring=None, ) -> str: if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks @@ -343,6 +346,7 @@ def insert_callback( "no_output": no_output, "websocket": websocket, "mcp_enabled": mcp_enabled, + "mcp_expose_docstring": mcp_expose_docstring, } callback_list.append(callback_spec) @@ -683,6 +687,7 @@ def register_callback( websocket=_kwargs.get("websocket", False), persistent=_kwargs.get("persistent", False), mcp_enabled=_kwargs.get("mcp_enabled", True), + mcp_expose_docstring=_kwargs.get("mcp_expose_docstring"), ) # pylint: disable=too-many-locals diff --git a/dash/mcp/primitives/tools/descriptions/description_docstring.py b/dash/mcp/primitives/tools/descriptions/description_docstring.py index 9bcc697248..3641d3e1f7 100644 --- a/dash/mcp/primitives/tools/descriptions/description_docstring.py +++ b/dash/mcp/primitives/tools/descriptions/description_docstring.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING +from dash import get_app + from .base import ToolDescriptionSource if TYPE_CHECKING: @@ -11,11 +13,26 @@ class DocstringDescription(ToolDescriptionSource): - """Return the callback's docstring as description lines.""" + """Return the callback's docstring as description lines. + + Gated behind an opt-in flag: docstrings may contain sensitive + implementation details that the browser never surfaces to users, + so we don't expose them to MCP clients unless the author opts in + — either per-callback or app-wide. + """ @classmethod def describe(cls, callback: CallbackAdapter) -> list[str]: + if not cls._is_exposed(callback): + return [] docstring = callback._docstring if docstring: return ["", docstring.strip()] return [] + + @classmethod + def _is_exposed(cls, callback: CallbackAdapter) -> bool: + per_callback = callback._cb_info.get("mcp_expose_docstring") + if per_callback is not None: + return per_callback + return get_app().config.get("mcp_expose_docstrings", False) diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py index 83f6e5378c..7f85e4af9d 100644 --- a/tests/unit/mcp/conftest.py +++ b/tests/unit/mcp/conftest.py @@ -33,7 +33,11 @@ def _make_app(**kwargs): ] ) - @app.callback(Output("my-output", "children"), Input("my-input", "children")) + @app.callback( + Output("my-output", "children"), + Input("my-input", "children"), + mcp_expose_docstring=True, + ) def update_output(value): """Test callback docstring.""" return f"echo: {value}" diff --git a/tests/unit/mcp/tools/test_callback_adapter.py b/tests/unit/mcp/tools/test_callback_adapter.py index dc3fc041fc..41d9b18c21 100644 --- a/tests/unit/mcp/tools/test_callback_adapter.py +++ b/tests/unit/mcp/tools/test_callback_adapter.py @@ -201,10 +201,48 @@ def test_returns_tool_instance(self, simple_app): assert isinstance(tool, Tool) assert tool.name == "update" - def test_description_includes_docstring(self, simple_app): - with simple_app.server.test_request_context(): - tool = app_context.get().mcp_callback_map[0].as_mcp_tool - assert "Update output." in tool.description + def test_docstring_hidden_by_default(self): + """Callback docstrings are not exposed to MCP by default.""" + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """sensitive callback docstring text that must not leak to LLMs""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "sensitive callback docstring text that must not leak to LLMs" + not in tool.description + ) + + def test_docstring_exposed_when_opted_in_per_callback(self): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + mcp_expose_docstring=True, + ) + def update(val): + """intentionally-exposed callback docstring text for the LLM""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "intentionally-exposed callback docstring text for the LLM" + in tool.description + ) def test_description_includes_output_target(self, simple_app): with simple_app.server.test_request_context(): From 991dfa52ad0a9eaef5eba920e42958f5e541b0b9 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 22 Apr 2026 11:29:50 -0600 Subject: [PATCH 30/80] lint --- .github/workflows/testing.yml | 2 + .pylintrc | 2 +- dash/mcp/primitives/tools/callback_adapter.py | 3 +- .../descriptions/description_docstring.py | 3 +- .../schema_callback_type_annotations.py | 2 +- package.json | 2 +- .../unit/mcp/tools/input_schemas/__init__.py | 0 .../input_descriptions/__init__.py | 0 .../input_descriptions/test_descriptions.py | 424 ----------------- .../tools/input_schemas/test_input_schemas.py | 331 ------------- .../test_schema_component_proptypes.py | 15 - tests/unit/mcp/tools/test_callback_adapter.py | 430 ----------------- .../mcp/tools/test_mcp_callback_adapter.py | 182 +++++++ .../mcp/tools/test_mcp_input_descriptions.py | 445 ++++++++++++++++++ .../unit/mcp/tools/test_mcp_input_schemas.py | 270 +++++++++++ tests/unit/mcp/tools/test_mcp_tools.py | 358 ++++++++++++++ tests/unit/mcp/tools/test_tool_schema.py | 64 --- 17 files changed, 1264 insertions(+), 1269 deletions(-) delete mode 100644 tests/unit/mcp/tools/input_schemas/__init__.py delete mode 100644 tests/unit/mcp/tools/input_schemas/input_descriptions/__init__.py delete mode 100644 tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py delete mode 100644 tests/unit/mcp/tools/input_schemas/test_input_schemas.py delete mode 100644 tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py delete mode 100644 tests/unit/mcp/tools/test_callback_adapter.py create mode 100644 tests/unit/mcp/tools/test_mcp_callback_adapter.py create mode 100644 tests/unit/mcp/tools/test_mcp_input_descriptions.py create mode 100644 tests/unit/mcp/tools/test_mcp_input_schemas.py create mode 100644 tests/unit/mcp/tools/test_mcp_tools.py delete mode 100644 tests/unit/mcp/tools/test_tool_schema.py diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 2359695404..358f9fd2d2 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -122,6 +122,8 @@ jobs: echo "DISPLAY=:99" >> $GITHUB_ENV - name: Run lint + env: + PYLINT_EXTRA_ARGS: ${{ matrix.python-version == '3.8' && '--ignored-modules=mcp' || '' }} run: npm run lint - name: Run unit tests diff --git a/.pylintrc b/.pylintrc index 7ffb5576b7..39e142048d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -431,7 +431,7 @@ max-returns=6 max-statements=50 # Minimum number of public methods for a class (see R0903). -min-public-methods=2 +min-public-methods=1 [IMPORTS] diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index c97c7721ce..c94ba32f38 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -135,7 +135,8 @@ def output_id(self) -> str: @property def tool_name(self) -> str: - return get_app().mcp_callback_map._tool_names_map[self._output_id] # pylint: disable=protected-access + # pylint: disable-next=protected-access + return get_app().mcp_callback_map._tool_names_map[self._output_id] @cached_property def prevents_initial_call(self) -> bool: diff --git a/dash/mcp/primitives/tools/descriptions/description_docstring.py b/dash/mcp/primitives/tools/descriptions/description_docstring.py index 3641d3e1f7..c34d527077 100644 --- a/dash/mcp/primitives/tools/descriptions/description_docstring.py +++ b/dash/mcp/primitives/tools/descriptions/description_docstring.py @@ -25,13 +25,14 @@ class DocstringDescription(ToolDescriptionSource): def describe(cls, callback: CallbackAdapter) -> list[str]: if not cls._is_exposed(callback): return [] - docstring = callback._docstring + docstring = callback._docstring # pylint: disable=protected-access if docstring: return ["", docstring.strip()] return [] @classmethod def _is_exposed(cls, callback: CallbackAdapter) -> bool: + # pylint: disable-next=protected-access per_callback = callback._cb_info.get("mcp_expose_docstring") if per_callback is not None: return per_callback diff --git a/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py b/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py index 9cf73653fa..b862b124d6 100644 --- a/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py +++ b/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py @@ -39,7 +39,7 @@ def annotation_to_json_schema(annotation: type) -> dict[str, Any] | None: try: return TypeAdapter(annotation).json_schema() - except Exception: + except Exception: # pylint: disable=broad-exception-caught return None diff --git a/package.json b/package.json index f44a0f2805..0f24f25d81 100644 --- a/package.json +++ b/package.json @@ -14,7 +14,7 @@ "private::build.jupyterlab": "cd @plotly/dash-jupyterlab && jlpm install && jlpm build:pack", "private::lint.black": "black dash tests --exclude 'metadata_test.py|node_modules' --check", "private::lint.flake8": "flake8 dash tests", - "private::lint.pylint-dash": "pylint dash setup.py --rcfile=.pylintrc", + "private::lint.pylint-dash": "pylint dash setup.py --rcfile=.pylintrc ${PYLINT_EXTRA_ARGS:-}", "private::lint.pylint-tests": "pylint tests/unit tests/integration -d all -e C0410,C0413,W0109 --rcfile=.pylintrc", "private::lint.renderer": "cd dash/dash-renderer && npm run lint", "private::test.setup-components": "cd @plotly/dash-test-components && npm ci && npm run build", diff --git a/tests/unit/mcp/tools/input_schemas/__init__.py b/tests/unit/mcp/tools/input_schemas/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/unit/mcp/tools/input_schemas/input_descriptions/__init__.py b/tests/unit/mcp/tools/input_schemas/input_descriptions/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py b/tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py deleted file mode 100644 index bc6758d2d3..0000000000 --- a/tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py +++ /dev/null @@ -1,424 +0,0 @@ -"""Description tests — verifies per-property description generation. - -Tests are organized by description source: -- Labels (htmlFor, containment, text extraction) -- Component-specific (date pickers, sliders) -- Options (Dropdown, RadioItems, Checklist) -- Generic props (placeholder, default value, min/max/step) -- Chained callbacks (dynamic prop/options detection) -- Combinations (label + component-specific) -""" - -import pytest - -from dash import Dash, Input, Output, dcc, html - -from tests.unit.mcp.conftest import ( - _app_with_callback, - _desc_for, - _tools_list, - _user_tool, -) - - -def _app_with_layout(layout, *inputs): - app = Dash(__name__) - app.layout = layout - - @app.callback( - Output("out", "children"), - [Input(cid, prop) for cid, prop in inputs], - ) - def update(*args): - return str(args) - - return app - - -def _tool_for(component, input_prop="value"): - app = _app_with_callback(component, input_prop=input_prop) - return _user_tool(_tools_list(app)) - - -# --------------------------------------------------------------------------- -# Labels -# --------------------------------------------------------------------------- - - -class TestLabels: - def test_html_for(self): - app = _app_with_layout( - html.Div( - [ - html.Label("Your Name", htmlFor="inp"), - dcc.Input(id="inp"), - html.Div(id="out"), - ] - ), - ("inp", "value"), - ) - tool = _user_tool(_tools_list(app)) - assert "Your Name" in _desc_for(tool) - - def test_html_for_not_adjacent(self): - app = _app_with_layout( - html.Div( - [ - html.Div(html.Label("Remote Label", htmlFor="inp")), - dcc.Input(id="inp"), - html.Div(id="out"), - ] - ), - ("inp", "value"), - ) - tool = _user_tool(_tools_list(app)) - assert "Remote Label" in _desc_for(tool) - - def test_containment(self): - app = _app_with_layout( - html.Div( - [ - html.Label( - [ - "Pick a city", - dcc.Dropdown(id="city_dd", options=["NYC", "LA"]), - ] - ), - html.Div(id="out"), - ] - ), - ("city_dd", "value"), - ) - tool = _user_tool(_tools_list(app)) - assert "Pick a city" in _desc_for(tool) - - def test_deeply_nested_containment(self): - app = _app_with_layout( - html.Div( - [ - html.Label( - [ - html.Span("Nested Label"), - html.Div(dcc.Input(id="nested_inp")), - ] - ), - html.Div(id="out"), - ] - ), - ("nested_inp", "value"), - ) - tool = _user_tool(_tools_list(app)) - assert "Nested Label" in _desc_for(tool) - - def test_both_htmlfor_and_containment_captured(self): - app = _app_with_layout( - html.Div( - [ - html.Label(["Containment Label", dcc.Input(id="inp")]), - html.Label("HtmlFor Label", htmlFor="inp"), - html.Div(id="out"), - ] - ), - ("inp", "value"), - ) - tool = _user_tool(_tools_list(app)) - desc = _desc_for(tool) - assert "HtmlFor Label" in desc - assert "Containment Label" in desc - - def test_deep_text_extraction(self): - app = _app_with_layout( - html.Div( - [ - html.Label( - html.Div(html.Span(html.B("Deep Text"))), - htmlFor="inp", - ), - dcc.Input(id="inp"), - html.Div(id="out"), - ] - ), - ("inp", "value"), - ) - tool = _user_tool(_tools_list(app)) - assert "Deep Text" in _desc_for(tool) - - def test_multiple_text_nodes(self): - app = _app_with_layout( - html.Div( - [ - html.Label( - [html.B("First"), " ", html.I("Second")], - htmlFor="inp", - ), - dcc.Input(id="inp"), - html.Div(id="out"), - ] - ), - ("inp", "value"), - ) - tool = _user_tool(_tools_list(app)) - desc = _desc_for(tool) - assert "Labeled with: First Second" in desc - - def test_unrelated_label_excluded(self): - app = _app_with_layout( - html.Div( - [ - html.Label("Other Field", htmlFor="other"), - dcc.Input(id="other"), - dcc.Input(id="target"), - html.Div(id="out"), - ] - ), - ("target", "value"), - ) - tool = _user_tool(_tools_list(app)) - desc = _desc_for(tool) - assert "Other Field" not in (desc or "") - - -# --------------------------------------------------------------------------- -# Component-specific: date pickers -# --------------------------------------------------------------------------- - - -class TestDatePickerDescriptions: - def test_single_full_range(self): - dp = dcc.DatePickerSingle( - id="dp", - min_date_allowed="2020-01-01", - max_date_allowed="2025-12-31", - ) - desc = _desc_for(_tool_for(dp, "date"), "val") - assert "2020-01-01" in desc - assert "2025-12-31" in desc - - def test_single_min_only(self): - dp = dcc.DatePickerSingle(id="dp", min_date_allowed="2020-01-01") - desc = _desc_for(_tool_for(dp, "date"), "val") - assert "min_date_allowed: '2020-01-01'" in desc - - def test_single_default_date(self): - dp = dcc.DatePickerSingle(id="dp", date="2024-06-15") - desc = _desc_for(_tool_for(dp, "date"), "val") - assert "date: '2024-06-15'" in desc - - def test_range_with_constraints(self): - dpr = dcc.DatePickerRange( - id="dpr", - min_date_allowed="2020-01-01", - max_date_allowed="2025-12-31", - ) - desc = _desc_for(_tool_for(dpr, "start_date"), "val") - assert "2020-01-01" in desc - - -# --------------------------------------------------------------------------- -# Component-specific: sliders -# --------------------------------------------------------------------------- - - -class TestSliderDescriptions: - def test_min_max(self): - sl = dcc.Slider(id="sl", min=0, max=100) - desc = _desc_for(_tool_for(sl), "val") - assert "min: 0" in desc - assert "max: 100" in desc - - def test_step(self): - sl = dcc.Slider(id="sl", min=0, max=100, step=5) - desc = _desc_for(_tool_for(sl), "val") - assert "step: 5" in desc - - def test_default_value(self): - sl = dcc.Slider(id="sl", min=0, max=100, value=50) - desc = _desc_for(_tool_for(sl), "val") - assert "value: 50" in desc - - def test_marks(self): - sl = dcc.Slider(id="sl", min=0, max=100, marks={0: "Low", 100: "High"}) - desc = _desc_for(_tool_for(sl), "val") - assert "marks: {0: 'Low', 100: 'High'}" in desc - - def test_range_slider_min_max(self): - rs = dcc.RangeSlider(id="rs", min=0, max=100) - desc = _desc_for(_tool_for(rs), "val") - assert "min: 0" in desc - assert "max: 100" in desc - - -# --------------------------------------------------------------------------- -# Options (parametrized across Dropdown, RadioItems, Checklist) -# --------------------------------------------------------------------------- - - -_OPTIONS_COMPONENTS = [ - ("Dropdown", lambda **kw: dcc.Dropdown(id="comp", **kw), "comp"), - ("RadioItems", lambda **kw: dcc.RadioItems(id="comp", **kw), "comp"), - ("Checklist", lambda **kw: dcc.Checklist(id="comp", **kw), "comp"), -] - - -class TestOptionsDescriptions: - @pytest.mark.parametrize( - "name,factory,cid", _OPTIONS_COMPONENTS, ids=[c[0] for c in _OPTIONS_COMPONENTS] - ) - def test_options_shown(self, name, factory, cid): - comp = factory(options=["X", "Y", "Z"]) - desc = _desc_for(_tool_for(comp), "val") - assert "options: ['X', 'Y', 'Z']" in desc - - @pytest.mark.parametrize( - "name,factory,cid", _OPTIONS_COMPONENTS, ids=[c[0] for c in _OPTIONS_COMPONENTS] - ) - def test_default_shown(self, name, factory, cid): - value = ["a"] if name == "Checklist" else "a" - comp = factory(options=["a", "b"], value=value) - desc = _desc_for(_tool_for(comp), "val") - assert f"value: {value!r}" in desc - - def test_dropdown_dict_options(self): - dd = dcc.Dropdown( - id="dd", - options=[ - {"label": "New York", "value": "NYC"}, - ], - ) - assert "NYC" in _desc_for(_tool_for(dd), "val") - - def test_store_storage_type_template(self): - store = dcc.Store(id="store", storage_type="session") - app = _app_with_callback(store, input_prop="data") - tool = _user_tool(_tools_list(app)) - desc = _desc_for(tool, "val") - assert ( - "storage_type: 'session'. Describes how to store the value client-side" - in desc - ) - - def test_many_options_truncated(self): - dd = dcc.Dropdown(id="big", options=[str(i) for i in range(50)], value="0") - app = _app_with_callback(dd) - tool = _user_tool(_tools_list(app)) - desc = _desc_for(tool, "val") - assert "options:" in desc - assert "Use get_dash_component('big', 'options') for the full value" in desc - - -# --------------------------------------------------------------------------- -# Generic props -# --------------------------------------------------------------------------- - - -class TestGenericDescriptions: - def test_placeholder(self): - inp = dcc.Input(id="inp", placeholder="Enter your name") - assert "placeholder: 'Enter your name'" in _desc_for(_tool_for(inp), "val") - - def test_numeric_min_max(self): - inp = dcc.Input(id="inp", type="number", min=0, max=999) - desc = _desc_for(_tool_for(inp), "val") - assert "min: 0" in desc - assert "max: 999" in desc - - def test_step(self): - inp = dcc.Input(id="inp", type="number", min=0, max=100, step=0.1) - assert "step: 0.1" in _desc_for(_tool_for(inp), "val") - - def test_default_value(self): - inp = dcc.Input(id="inp", value="hello") - desc = _desc_for(_tool_for(inp), "val") - assert "value: 'hello'" in desc - - def test_non_text_type(self): - inp = dcc.Input(id="inp", type="email") - assert "type: 'email'" in _desc_for(_tool_for(inp), "val") - - def test_store_default(self): - store = dcc.Store(id="store", data={"key": "value"}) - app = _app_with_callback(store, input_prop="data") - tool = _user_tool(_tools_list(app)) - assert "data: {'key': 'value'}" in _desc_for(tool, "val") - - -# --------------------------------------------------------------------------- -# Chained callbacks -# --------------------------------------------------------------------------- - - -class TestChainedCallbacks: - def test_options_set_by_upstream(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="country", options=["US", "CA"], value="US"), - dcc.Dropdown(id="city", options=[], value=None), - html.Div(id="result"), - ] - ) - - @app.callback(Output("city", "options"), Input("country", "value")) - def update_cities(country): - return ["NYC", "LA"] if country == "US" else ["Toronto"] - - @app.callback(Output("result", "children"), Input("city", "value")) - def show_city(city): - return city - - tools = _tools_list(app) - tool = next(t for t in tools if "show_city" in t.name) - desc = _desc_for(tool, "city") - assert "can be updated by tool: `update_cities`" in desc - assert "options:" in desc - - def test_value_set_by_upstream(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="source", value=""), - html.Div(id="derived", children=""), - html.Div(id="result"), - ] - ) - - @app.callback(Output("derived", "children"), Input("source", "value")) - def compute_derived(val): - return f"derived: {val}" - - @app.callback(Output("result", "children"), Input("derived", "children")) - def use_derived(val): - return val - - tools = _tools_list(app) - tool = next(t for t in tools if "use_derived" in t.name) - desc = _desc_for(tool, "val") - assert "can be updated by tool: `compute_derived`" in desc - - -# --------------------------------------------------------------------------- -# Combinations -# --------------------------------------------------------------------------- - - -class TestCombinations: - def test_label_with_date_picker(self): - dp = dcc.DatePickerSingle( - id="dp", - min_date_allowed="2020-01-01", - max_date_allowed="2025-12-31", - ) - app = _app_with_layout( - html.Div( - [ - html.Label("Departure Date", htmlFor="dp"), - dp, - html.Div(id="out"), - ] - ), - ("dp", "date"), - ) - tool = _user_tool(_tools_list(app)) - desc = _desc_for(tool) - assert "Departure Date" in desc - assert "2020-01-01" in desc diff --git a/tests/unit/mcp/tools/input_schemas/test_input_schemas.py b/tests/unit/mcp/tools/input_schemas/test_input_schemas.py deleted file mode 100644 index 5350bd955e..0000000000 --- a/tests/unit/mcp/tools/input_schemas/test_input_schemas.py +++ /dev/null @@ -1,331 +0,0 @@ -"""Input schema tests — verifies JSON Schema generation for component properties. - -Tests are organized by concern: -- Static overrides (date pickers, graph, interval, sliders) -- Component introspection (representative samples — full type coverage in test_json_prop_typing) -- Callback annotation overrides (highest priority) -- Required/nullable behavior -""" - -import pytest -from typing import Optional - -from dash import Dash, Input, Output, State, dcc, html - -from tests.unit.mcp.conftest import ( - _app_with_callback, - _schema_for, - _tools_list, - _user_tool, -) - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _get_schema(component_type, prop): - _factories = { - "DatePickerSingle": lambda: dcc.DatePickerSingle(id="dp"), - "DatePickerRange": lambda: dcc.DatePickerRange(id="dpr"), - "Graph": lambda: dcc.Graph(id="graph"), - "Interval": lambda: dcc.Interval(id="intv"), - "Input": lambda: dcc.Input(id="inp"), - "Textarea": lambda: dcc.Textarea(id="ta"), - "Tabs": lambda: dcc.Tabs(id="tabs"), - "Dropdown": lambda: dcc.Dropdown(id="dd"), - "RadioItems": lambda: dcc.RadioItems(id="ri"), - "Checklist": lambda: dcc.Checklist(id="cl"), - "Store": lambda: dcc.Store(id="store"), - "Upload": lambda: dcc.Upload(id="upload"), - "Slider": lambda: dcc.Slider(id="sl"), - "RangeSlider": lambda: dcc.RangeSlider(id="rs"), - } - app = _app_with_callback(_factories[component_type](), input_prop=prop) - tool = _user_tool(_tools_list(app)) - return _schema_for(tool) - - -# --------------------------------------------------------------------------- -# Static overrides take priority over introspection -# --------------------------------------------------------------------------- - - -class TestStaticOverrides: - """Verify that overrides win over component introspection.""" - - def test_override_beats_introspection(self): - schema = _get_schema("DatePickerSingle", "date") - # Introspection would return None for this prop; - # override provides a date format with pattern - assert schema["type"] == "string" - assert schema["format"] == "date" - assert "pattern" in schema - - -# --------------------------------------------------------------------------- -# Introspection — representative samples (not exhaustive per-component) -# --------------------------------------------------------------------------- - -INTROSPECTION_CASES = [ - # (component_type, prop, expected_schema) — one per distinct type shape - ( - "Input", - "value", - {"anyOf": [{"type": "string"}, {"type": "number"}, {"type": "null"}]}, - ), - ( - "Input", - "disabled", - { - "anyOf": [ - {"type": "boolean"}, - {"const": "disabled", "type": "string"}, - {"const": "DISABLED", "type": "string"}, - {"type": "null"}, - ] - }, - ), - ("Input", "n_submit", {"anyOf": [{"type": "number"}, {"type": "null"}]}), - ( - "Dropdown", - "value", - { - "anyOf": [ - {"type": "string"}, - {"type": "number"}, - {"type": "boolean"}, - { - "items": { - "anyOf": [ - {"type": "string"}, - {"type": "number"}, - {"type": "boolean"}, - ] - }, - "type": "array", - }, - {"type": "null"}, - ] - }, - ), - ("Dropdown", "options", {"anyOf": [{}, {"type": "null"}]}), - ( - "Checklist", - "value", - { - "anyOf": [ - { - "items": { - "anyOf": [ - {"type": "string"}, - {"type": "number"}, - {"type": "boolean"}, - ] - }, - "type": "array", - }, - {"type": "null"}, - ] - }, - ), - ( - "Store", - "data", - { - "anyOf": [ - {"additionalProperties": True, "type": "object"}, - {"items": {}, "type": "array"}, - {"type": "number"}, - {"type": "string"}, - {"type": "boolean"}, - {"type": "null"}, - ] - }, - ), - ( - "Upload", - "contents", - { - "anyOf": [ - {"type": "string"}, - {"items": {"type": "string"}, "type": "array"}, - {"type": "null"}, - ] - }, - ), - ( - "RangeSlider", - "value", - {"anyOf": [{"items": {"type": "number"}, "type": "array"}, {"type": "null"}]}, - ), - ("Tabs", "value", {"anyOf": [{"type": "string"}, {"type": "null"}]}), -] - - -class TestIntrospection: - """Representative introspection tests — full type coverage in test_json_prop_typing.""" - - @pytest.mark.parametrize( - "component_type,prop,expected", - INTROSPECTION_CASES, - ids=[f"{c}.{p}" for c, p, _ in INTROSPECTION_CASES], - ) - def test_introspected_schema(self, component_type, prop, expected): - assert _get_schema(component_type, prop) == expected - - -# --------------------------------------------------------------------------- -# Callback annotation overrides -# --------------------------------------------------------------------------- - - -def _app_with_annotated_callback(annotation_type, input_prop="disabled"): - app = Dash(__name__) - app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) - - if annotation_type is None: - - @app.callback(Output("out", "children"), Input("inp", input_prop)) - def update(val): - return str(val) - - else: - - @app.callback(Output("out", "children"), Input("inp", input_prop)) - def update(val: annotation_type): - return str(val) - - return app - - -ANNOTATION_CASES = [ - (str, "disabled", {"type": "string"}), - (int, "value", {"type": "integer"}), - (float, "value", {"type": "number"}), - (bool, "value", {"type": "boolean"}), - (list, "value", {"items": {}, "type": "array"}), - (dict, "value", {"additionalProperties": True, "type": "object"}), - (Optional[int], "value", {"anyOf": [{"type": "integer"}, {"type": "null"}]}), - (Optional[str], "value", {"anyOf": [{"type": "string"}, {"type": "null"}]}), -] - - -class TestAnnotationOverrides: - """Callback type annotations override component schemas.""" - - @pytest.mark.parametrize( - "ann,prop,expected", - ANNOTATION_CASES, - ids=[ - f"{a.__name__ if hasattr(a, '__name__') else a}-{p}" - for a, p, _ in ANNOTATION_CASES - ], - ) - def test_annotation(self, ann, prop, expected): - app = _app_with_annotated_callback(ann, input_prop=prop) - tool = _user_tool(_tools_list(app)) - assert _schema_for(tool, "val") == expected - - def test_no_annotation_uses_introspection(self): - app = _app_with_annotated_callback(None) - tool = _user_tool(_tools_list(app)) - assert _schema_for(tool, "val") == { - "anyOf": [ - {"type": "boolean"}, - {"const": "disabled", "type": "string"}, - {"const": "DISABLED", "type": "string"}, - {"type": "null"}, - ] - } - - -class TestAnnotationNullability: - """Annotations control nullable vs non-nullable schemas.""" - - def test_str_removes_null(self): - app = Dash(__name__) - app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) - - @app.callback(Output("out", "children"), Input("dd", "value")) - def update(val: str): - return val - - tool = _user_tool(_tools_list(app)) - assert _schema_for(tool, "val") == {"type": "string"} - - def test_optional_preserves_null(self): - app = Dash(__name__) - app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) - - @app.callback(Output("out", "children"), Input("dd", "value")) - def update(val: Optional[str]): - return val or "" - - tool = _user_tool(_tools_list(app)) - assert _schema_for(tool, "val") == { - "anyOf": [{"type": "string"}, {"type": "null"}] - } - - def test_optional_param_not_required(self): - app = Dash(__name__) - app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) - - @app.callback(Output("out", "children"), Input("dd", "value")) - def update(val: Optional[str]): - return val or "" - - tool = _user_tool(_tools_list(app)) - assert "val" not in tool.inputSchema.get("required", []) - - -class TestAnnotationWithState: - """Annotations work for State parameters too.""" - - def test_state_annotation_overrides(self): - app = Dash(__name__) - app.layout = html.Div( - [dcc.Input(id="inp"), dcc.Store(id="store"), html.Div(id="out")] - ) - - @app.callback( - Output("out", "children"), - Input("inp", "value"), - State("store", "data"), - ) - def update(val: str, data: dict): - return str(val) - - tool = _user_tool(_tools_list(app)) - assert _schema_for(tool, "val") == {"type": "string"} - assert _schema_for(tool, "data") == { - "additionalProperties": True, - "type": "object", - } - - def test_partial_annotations(self): - app = Dash(__name__) - app.layout = html.Div( - [dcc.Input(id="inp"), dcc.Store(id="store"), html.Div(id="out")] - ) - - @app.callback( - Output("out", "children"), - Input("inp", "value"), - State("store", "data"), - ) - def update(val: int, data): - return str(val) - - tool = _user_tool(_tools_list(app)) - assert _schema_for(tool, "val") == {"type": "integer"} - assert _schema_for(tool, "data") == { - "anyOf": [ - {"additionalProperties": True, "type": "object"}, - {"items": {}, "type": "array"}, - {"type": "number"}, - {"type": "string"}, - {"type": "boolean"}, - {"type": "null"}, - ] - } diff --git a/tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py b/tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py deleted file mode 100644 index 10b6ae5543..0000000000 --- a/tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Tests for schema_component_proptypes. - -Only tests our custom logic — pydantic's type-to-schema conversion -is tested by pydantic itself. -""" - -from dash.development.base_component import Component -from dash.mcp.primitives.tools.input_schemas.schema_callback_type_annotations import ( - annotation_to_json_schema, -) - - -class TestComponentTypes: - def test_component_type_maps_to_string(self): - assert annotation_to_json_schema(Component) == {"type": "string"} diff --git a/tests/unit/mcp/tools/test_callback_adapter.py b/tests/unit/mcp/tools/test_callback_adapter.py deleted file mode 100644 index 41d9b18c21..0000000000 --- a/tests/unit/mcp/tools/test_callback_adapter.py +++ /dev/null @@ -1,430 +0,0 @@ -"""Tests for CallbackAdapter.""" - -import pytest -from dash import Dash, Input, Output, State, dcc, html -from dash._get_app import app_context -from mcp.types import Tool - -from dash.mcp.primitives.tools.callback_adapter_collection import ( - CallbackAdapterCollection, -) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def simple_app(): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Label("Your Name", htmlFor="inp"), - dcc.Input(id="inp", type="text"), - html.Div(id="out"), - ] - ) - - @app.callback(Output("out", "children"), Input("inp", "value")) - def update(val): - """Update output.""" - return val - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - return app - - -@pytest.fixture -def multi_output_app(): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="dd", options=["a", "b"], value="a"), - dcc.Dropdown(id="dd2"), - html.Div(id="out"), - ] - ) - - @app.callback( - Output("dd2", "options"), - Output("out", "children"), - Input("dd", "value"), - ) - def update(val): - return [], val - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - return app - - -@pytest.fixture -def state_app(): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Button(id="btn"), - dcc.Input(id="inp"), - html.Div(id="out"), - ] - ) - - @app.callback( - Output("out", "children"), - Input("btn", "n_clicks"), - State("inp", "value"), - ) - def update(clicks, val): - return val - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - return app - - -@pytest.fixture -def typed_app(): - app = Dash(__name__) - app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) - - @app.callback(Output("out", "children"), Input("inp", "value")) - def update(val: str): - return val - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - return app - - -@pytest.fixture -def duplicate_names_app(): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Div(id="in1"), - html.Div(id="out1"), - html.Div(id="in2"), - html.Div(id="out2"), - ] - ) - - @app.callback(Output("out1", "children"), Input("in1", "children")) - def cb(v): - return v - - @app.callback(Output("out2", "children"), Input("in2", "children")) - def cb(v): # noqa: F811 - return v - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - return app - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -class TestFromApp: - def test_returns_list(self, simple_app): - assert len(app_context.get().mcp_callback_map) == 1 - - def test_excludes_clientside(self): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Button(id="btn"), - html.Div(id="cs-out"), - html.Div(id="srv-out"), - ] - ) - app.clientside_callback( - "function(n) { return n; }", - Output("cs-out", "children"), - Input("btn", "n_clicks"), - ) - - @app.callback(Output("srv-out", "children"), Input("btn", "n_clicks")) - def server_cb(n): - return str(n) - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - - names = [a.tool_name for a in app.mcp_callback_map] - assert names == ["server_cb"] - - def test_excludes_mcp_disabled(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="inp"), - html.Div(id="out1"), - html.Div(id="out2"), - ] - ) - - @app.callback(Output("out1", "children"), Input("inp", "value")) - def visible(val): - return val - - @app.callback( - Output("out2", "children"), Input("inp", "value"), mcp_enabled=False - ) - def hidden(val): - return val - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - names = [a.tool_name for a in app.mcp_callback_map] - assert "visible" in names - assert "hidden" not in names - - -class TestToolName: - def test_uses_func_name(self, simple_app): - assert app_context.get().mcp_callback_map[0].tool_name == "update" - - def test_duplicates_get_unique_names(self, duplicate_names_app): - names = [a.tool_name for a in app_context.get().mcp_callback_map] - assert len(names) == 2 - assert names[0] != names[1] - - -class TestTool: - def test_returns_tool_instance(self, simple_app): - with simple_app.server.test_request_context(): - tool = app_context.get().mcp_callback_map[0].as_mcp_tool - assert isinstance(tool, Tool) - assert tool.name == "update" - - def test_docstring_hidden_by_default(self): - """Callback docstrings are not exposed to MCP by default.""" - app = Dash(__name__) - app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) - - @app.callback(Output("out", "children"), Input("inp", "value")) - def update(val): - """sensitive callback docstring text that must not leak to LLMs""" - return val - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - - with app.server.test_request_context(): - tool = app.mcp_callback_map[0].as_mcp_tool - assert ( - "sensitive callback docstring text that must not leak to LLMs" - not in tool.description - ) - - def test_docstring_exposed_when_opted_in_per_callback(self): - app = Dash(__name__) - app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) - - @app.callback( - Output("out", "children"), - Input("inp", "value"), - mcp_expose_docstring=True, - ) - def update(val): - """intentionally-exposed callback docstring text for the LLM""" - return val - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - - with app.server.test_request_context(): - tool = app.mcp_callback_map[0].as_mcp_tool - assert ( - "intentionally-exposed callback docstring text for the LLM" - in tool.description - ) - - def test_description_includes_output_target(self, simple_app): - with simple_app.server.test_request_context(): - tool = app_context.get().mcp_callback_map[0].as_mcp_tool - assert "out.children" in tool.description - - def test_param_name_from_function_signature(self, simple_app): - with simple_app.server.test_request_context(): - tool = app_context.get().mcp_callback_map[0].as_mcp_tool - assert "val" in tool.inputSchema["properties"] - - def test_param_has_label_description(self, simple_app): - with simple_app.server.test_request_context(): - tool = app_context.get().mcp_callback_map[0].as_mcp_tool - desc = tool.inputSchema["properties"]["val"].get("description", "") - assert "Your Name" in desc - - def test_state_params_included(self, state_app): - with state_app.server.test_request_context(): - tool = app_context.get().mcp_callback_map[0].as_mcp_tool - props = tool.inputSchema["properties"] - assert set(props.keys()) == {"clicks", "val"} - - def test_multi_output_description(self, multi_output_app): - with multi_output_app.server.test_request_context(): - tool = app_context.get().mcp_callback_map[0].as_mcp_tool - assert "dd2.options" in tool.description - assert "out.children" in tool.description - - def test_typed_annotation_narrows_schema(self, typed_app): - with typed_app.server.test_request_context(): - tool = app_context.get().mcp_callback_map[0].as_mcp_tool - assert tool.inputSchema["properties"]["val"]["type"] == "string" - - -class TestGetInitialValue: - def test_returns_layout_value(self, simple_app): - callback_map = app_context.get().mcp_callback_map - # Input with no value set — returns None (layout default for dcc.Input) - assert callback_map.get_initial_value("inp.value") is None - - def test_returns_set_value(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="dd", options=["a", "b"], value="a"), - html.Div(id="out"), - ] - ) - - @app.callback(Output("out", "children"), Input("dd", "value")) - def update(selected): - return selected - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - assert app.mcp_callback_map.get_initial_value("dd.value") == "a" - - def test_initial_callback_makes_param_required(self): - """A param with None in layout but set by an initial callback is required.""" - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown( - id="country", options=["France", "Germany"], value="France" - ), - dcc.Dropdown(id="city"), # value=None in layout - html.Div(id="out"), - ] - ) - - @app.callback( - Output("city", "options"), - Output("city", "value"), - Input("country", "value"), - ) - def update_cities(country): - return [{"label": "Paris", "value": "Paris"}], "Paris" - - @app.callback(Output("out", "children"), Input("city", "value")) - def show_city(city): - return f"Selected: {city}" - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - - # city.value is None in layout but "Paris" after initial callback - with app.server.test_request_context(): - show_city_cb = app.mcp_callback_map.find_by_tool_name("show_city") - city_param = show_city_cb.inputs[0] - assert city_param["name"] == "city" - assert city_param["required"] is True # not optional despite None in layout - - -class TestIsValid: - def test_valid_when_inputs_in_layout(self, simple_app): - assert app_context.get().mcp_callback_map[0].is_valid - - def test_invalid_when_input_not_in_layout(self): - app = Dash(__name__) - app.layout = html.Div([html.Div(id="out")]) - - @app.callback(Output("out", "children"), Input("nonexistent", "value")) - def update(val): - return val - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - assert not app.mcp_callback_map[0].is_valid - - def test_pattern_matching_ids_always_valid(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id={"type": "field", "index": 0}, value="a"), - html.Div(id="out"), - ] - ) - - @app.callback( - Output("out", "children"), - Input({"type": "field", "index": 0}, "value"), - ) - def update(val): - return val - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - assert app.mcp_callback_map[0].is_valid - - -class TestNoInfiniteLoop: - @pytest.mark.timeout(5) - def test_initial_output_does_not_loop(self): - """Building a tool must not trigger infinite re-entry in _initial_output.""" - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Slider(id="sl", min=0, max=10, value=5), - html.Div(id="out"), - ] - ) - - @app.callback(Output("out", "children"), Input("sl", "value")) - def show(value): - return f"Value: {value}" - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - - with app.server.test_request_context(): - tool = app.mcp_callback_map[0].as_mcp_tool - assert tool.name == "show" - - @pytest.mark.timeout(5) - def test_chained_callbacks_do_not_loop(self): - """Chained callbacks with initial value resolution must not loop.""" - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Slider(id="sl", min=0, max=10, value=5), - dcc.Slider(id="sl2", min=0, max=10), - html.Div(id="out"), - ] - ) - - @app.callback(Output("sl2", "value"), Input("sl", "value")) - def sync(v): - return v - - @app.callback( - Output("out", "children"), - Input("sl", "value"), - Input("sl2", "value"), - ) - def show(v1, v2): - return f"{v1} + {v2}" - - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - - with app.server.test_request_context(): - for cb in app.mcp_callback_map: - tool = cb.as_mcp_tool - assert tool.name is not None diff --git a/tests/unit/mcp/tools/test_mcp_callback_adapter.py b/tests/unit/mcp/tools/test_mcp_callback_adapter.py new file mode 100644 index 0000000000..82dfa0956b --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_callback_adapter.py @@ -0,0 +1,182 @@ +"""CallbackAdapter behavior: initial value resolution, validation, loop prevention.""" + +import pytest +from dash import Dash, Input, Output, dcc, html +from dash._get_app import app_context + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + + +@pytest.fixture +def simple_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label("Your Name", htmlFor="inp"), + dcc.Input(id="inp", type="text"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """Update output.""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +def test_mcpc001_returns_layout_value(simple_app): + callback_map = app_context.get().mcp_callback_map + # Input with no value set — returns None (layout default for dcc.Input) + assert callback_map.get_initial_value("inp.value") is None + + +def test_mcpc002_returns_set_value(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(selected): + return selected + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert app.mcp_callback_map.get_initial_value("dd.value") == "a" + + +def test_mcpc003_initial_callback_makes_param_required(): + """A param with None in layout but set by an initial callback is required.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=["France", "Germany"], value="France"), + dcc.Dropdown(id="city"), # value=None in layout + html.Div(id="out"), + ] + ) + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("country", "value"), + ) + def update_cities(country): + return [{"label": "Paris", "value": "Paris"}], "Paris" + + @app.callback(Output("out", "children"), Input("city", "value")) + def show_city(city): + return f"Selected: {city}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + # city.value is None in layout but "Paris" after initial callback + with app.server.test_request_context(): + show_city_cb = app.mcp_callback_map.find_by_tool_name("show_city") + city_param = show_city_cb.inputs[0] + assert city_param["name"] == "city" + assert city_param["required"] is True # not optional despite None in layout + + +def test_mcpc004_valid_when_inputs_in_layout(simple_app): + assert app_context.get().mcp_callback_map[0].is_valid + + +def test_mcpc005_invalid_when_input_not_in_layout(): + app = Dash(__name__) + app.layout = html.Div([html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("nonexistent", "value")) + def update(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert not app.mcp_callback_map[0].is_valid + + +def test_mcpc006_pattern_matching_ids_always_valid(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="a"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input({"type": "field", "index": 0}, "value"), + ) + def update(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert app.mcp_callback_map[0].is_valid + + +@pytest.mark.timeout(5) +def test_mcpc007_initial_output_does_not_loop(): + """Building a tool must not trigger infinite re-entry in _initial_output.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Slider(id="sl", min=0, max=10, value=5), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("sl", "value")) + def show(value): + return f"Value: {value}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert tool.name == "show" + + +@pytest.mark.timeout(5) +def test_mcpc008_chained_callbacks_do_not_loop(): + """Chained callbacks with initial value resolution must not loop.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Slider(id="sl", min=0, max=10, value=5), + dcc.Slider(id="sl2", min=0, max=10), + html.Div(id="out"), + ] + ) + + @app.callback(Output("sl2", "value"), Input("sl", "value")) + def sync(v): + return v + + @app.callback( + Output("out", "children"), + Input("sl", "value"), + Input("sl2", "value"), + ) + def show(v1, v2): + return f"{v1} + {v2}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + for cb in app.mcp_callback_map: + tool = cb.as_mcp_tool + assert tool.name is not None diff --git a/tests/unit/mcp/tools/test_mcp_input_descriptions.py b/tests/unit/mcp/tools/test_mcp_input_descriptions.py new file mode 100644 index 0000000000..55ac4df49a --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_input_descriptions.py @@ -0,0 +1,445 @@ +"""Input descriptions — human-readable per-property descriptions for MCP tool inputs. + +Covers: +- Labels (htmlFor, containment, text extraction) +- Component-specific (date pickers, sliders) +- Options (Dropdown, RadioItems, Checklist) +- Generic props (placeholder, default value, min/max/step) +- Chained callbacks (dynamic prop/options detection) +- Combinations (label + component-specific) +""" + +import pytest + +from dash import Dash, Input, Output, dcc, html + +from tests.unit.mcp.conftest import ( + _app_with_callback, + _desc_for, + _tools_list, + _user_tool, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _app_with_layout(layout, *inputs): + app = Dash(__name__) + app.layout = layout + + @app.callback( + Output("out", "children"), + [Input(cid, prop) for cid, prop in inputs], + ) + def update(*args): + return str(args) + + return app + + +def _tool_for(component, input_prop="value"): + app = _app_with_callback(component, input_prop=input_prop) + return _user_tool(_tools_list(app)) + + +# --------------------------------------------------------------------------- +# Labels (htmlFor, containment, text extraction) +# --------------------------------------------------------------------------- + + +def test_mcpd001_label_html_for(): + app = _app_with_layout( + html.Div( + [ + html.Label("Your Name", htmlFor="inp"), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Your Name" in _desc_for(tool) + + +def test_mcpd002_label_html_for_not_adjacent(): + app = _app_with_layout( + html.Div( + [ + html.Div(html.Label("Remote Label", htmlFor="inp")), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Remote Label" in _desc_for(tool) + + +def test_mcpd003_label_containment(): + app = _app_with_layout( + html.Div( + [ + html.Label( + [ + "Pick a city", + dcc.Dropdown(id="city_dd", options=["NYC", "LA"]), + ] + ), + html.Div(id="out"), + ] + ), + ("city_dd", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Pick a city" in _desc_for(tool) + + +def test_mcpd004_label_deeply_nested_containment(): + app = _app_with_layout( + html.Div( + [ + html.Label( + [ + html.Span("Nested Label"), + html.Div(dcc.Input(id="nested_inp")), + ] + ), + html.Div(id="out"), + ] + ), + ("nested_inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Nested Label" in _desc_for(tool) + + +def test_mcpd005_label_both_htmlfor_and_containment_captured(): + app = _app_with_layout( + html.Div( + [ + html.Label(["Containment Label", dcc.Input(id="inp")]), + html.Label("HtmlFor Label", htmlFor="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "HtmlFor Label" in desc + assert "Containment Label" in desc + + +def test_mcpd006_label_deep_text_extraction(): + app = _app_with_layout( + html.Div( + [ + html.Label( + html.Div(html.Span(html.B("Deep Text"))), + htmlFor="inp", + ), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Deep Text" in _desc_for(tool) + + +def test_mcpd007_label_multiple_text_nodes(): + app = _app_with_layout( + html.Div( + [ + html.Label( + [html.B("First"), " ", html.I("Second")], + htmlFor="inp", + ), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Labeled with: First Second" in desc + + +def test_mcpd008_label_unrelated_excluded(): + app = _app_with_layout( + html.Div( + [ + html.Label("Other Field", htmlFor="other"), + dcc.Input(id="other"), + dcc.Input(id="target"), + html.Div(id="out"), + ] + ), + ("target", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Other Field" not in (desc or "") + + +# --------------------------------------------------------------------------- +# Component-specific: date pickers +# --------------------------------------------------------------------------- + + +def test_mcpd009_date_picker_single_full_range(): + dp = dcc.DatePickerSingle( + id="dp", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "2020-01-01" in desc + assert "2025-12-31" in desc + + +def test_mcpd010_date_picker_single_min_only(): + dp = dcc.DatePickerSingle(id="dp", min_date_allowed="2020-01-01") + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "min_date_allowed: '2020-01-01'" in desc + + +def test_mcpd011_date_picker_single_default_date(): + dp = dcc.DatePickerSingle(id="dp", date="2024-06-15") + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "date: '2024-06-15'" in desc + + +def test_mcpd012_date_picker_range_with_constraints(): + dpr = dcc.DatePickerRange( + id="dpr", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + desc = _desc_for(_tool_for(dpr, "start_date"), "val") + assert "2020-01-01" in desc + + +# --------------------------------------------------------------------------- +# Component-specific: sliders +# --------------------------------------------------------------------------- + + +def test_mcpd013_slider_min_max(): + sl = dcc.Slider(id="sl", min=0, max=100) + desc = _desc_for(_tool_for(sl), "val") + assert "min: 0" in desc + assert "max: 100" in desc + + +def test_mcpd014_slider_step(): + sl = dcc.Slider(id="sl", min=0, max=100, step=5) + desc = _desc_for(_tool_for(sl), "val") + assert "step: 5" in desc + + +def test_mcpd015_slider_default_value(): + sl = dcc.Slider(id="sl", min=0, max=100, value=50) + desc = _desc_for(_tool_for(sl), "val") + assert "value: 50" in desc + + +def test_mcpd016_slider_marks(): + sl = dcc.Slider(id="sl", min=0, max=100, marks={0: "Low", 100: "High"}) + desc = _desc_for(_tool_for(sl), "val") + assert "marks: {0: 'Low', 100: 'High'}" in desc + + +def test_mcpd017_range_slider_min_max(): + rs = dcc.RangeSlider(id="rs", min=0, max=100) + desc = _desc_for(_tool_for(rs), "val") + assert "min: 0" in desc + assert "max: 100" in desc + + +# --------------------------------------------------------------------------- +# Options (parametrized across Dropdown, RadioItems, Checklist) +# --------------------------------------------------------------------------- + + +_OPTIONS_COMPONENTS = [ + ("Dropdown", lambda **kw: dcc.Dropdown(id="comp", **kw), "comp"), + ("RadioItems", lambda **kw: dcc.RadioItems(id="comp", **kw), "comp"), + ("Checklist", lambda **kw: dcc.Checklist(id="comp", **kw), "comp"), +] + + +@pytest.mark.parametrize( + "name,factory,cid", _OPTIONS_COMPONENTS, ids=[c[0] for c in _OPTIONS_COMPONENTS] +) +def test_mcpd018_options_shown(name, factory, cid): + comp = factory(options=["X", "Y", "Z"]) + desc = _desc_for(_tool_for(comp), "val") + assert "options: ['X', 'Y', 'Z']" in desc + + +@pytest.mark.parametrize( + "name,factory,cid", _OPTIONS_COMPONENTS, ids=[c[0] for c in _OPTIONS_COMPONENTS] +) +def test_mcpd019_default_shown(name, factory, cid): + value = ["a"] if name == "Checklist" else "a" + comp = factory(options=["a", "b"], value=value) + desc = _desc_for(_tool_for(comp), "val") + assert f"value: {value!r}" in desc + + +def test_mcpd020_dropdown_dict_options(): + dd = dcc.Dropdown( + id="dd", + options=[ + {"label": "New York", "value": "NYC"}, + ], + ) + assert "NYC" in _desc_for(_tool_for(dd), "val") + + +def test_mcpd021_store_storage_type_template(): + store = dcc.Store(id="store", storage_type="session") + app = _app_with_callback(store, input_prop="data") + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool, "val") + assert ( + "storage_type: 'session'. Describes how to store the value client-side" in desc + ) + + +def test_mcpd022_many_options_truncated(): + dd = dcc.Dropdown(id="big", options=[str(i) for i in range(50)], value="0") + app = _app_with_callback(dd) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool, "val") + assert "options:" in desc + assert "Use get_dash_component('big', 'options') for the full value" in desc + + +# --------------------------------------------------------------------------- +# Generic props (placeholder, default, numeric min/max/step) +# --------------------------------------------------------------------------- + + +def test_mcpd023_generic_placeholder(): + inp = dcc.Input(id="inp", placeholder="Enter your name") + assert "placeholder: 'Enter your name'" in _desc_for(_tool_for(inp), "val") + + +def test_mcpd024_generic_numeric_min_max(): + inp = dcc.Input(id="inp", type="number", min=0, max=999) + desc = _desc_for(_tool_for(inp), "val") + assert "min: 0" in desc + assert "max: 999" in desc + + +def test_mcpd025_generic_step(): + inp = dcc.Input(id="inp", type="number", min=0, max=100, step=0.1) + assert "step: 0.1" in _desc_for(_tool_for(inp), "val") + + +def test_mcpd026_generic_default_value(): + inp = dcc.Input(id="inp", value="hello") + desc = _desc_for(_tool_for(inp), "val") + assert "value: 'hello'" in desc + + +def test_mcpd027_generic_non_text_type(): + inp = dcc.Input(id="inp", type="email") + assert "type: 'email'" in _desc_for(_tool_for(inp), "val") + + +def test_mcpd028_generic_store_default(): + store = dcc.Store(id="store", data={"key": "value"}) + app = _app_with_callback(store, input_prop="data") + tool = _user_tool(_tools_list(app)) + assert "data: {'key': 'value'}" in _desc_for(tool, "val") + + +# --------------------------------------------------------------------------- +# Chained callbacks — descriptions reflect upstream dependencies +# --------------------------------------------------------------------------- + + +def test_mcpd029_chained_options_set_by_upstream(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=["US", "CA"], value="US"), + dcc.Dropdown(id="city", options=[], value=None), + html.Div(id="result"), + ] + ) + + @app.callback(Output("city", "options"), Input("country", "value")) + def update_cities(country): + return ["NYC", "LA"] if country == "US" else ["Toronto"] + + @app.callback(Output("result", "children"), Input("city", "value")) + def show_city(city): + return city + + tools = _tools_list(app) + tool = next(t for t in tools if "show_city" in t.name) + desc = _desc_for(tool, "city") + assert "can be updated by tool: `update_cities`" in desc + assert "options:" in desc + + +def test_mcpd030_chained_value_set_by_upstream(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="source", value=""), + html.Div(id="derived", children=""), + html.Div(id="result"), + ] + ) + + @app.callback(Output("derived", "children"), Input("source", "value")) + def compute_derived(val): + return f"derived: {val}" + + @app.callback(Output("result", "children"), Input("derived", "children")) + def use_derived(val): + return val + + tools = _tools_list(app) + tool = next(t for t in tools if "use_derived" in t.name) + desc = _desc_for(tool, "val") + assert "can be updated by tool: `compute_derived`" in desc + + +# --------------------------------------------------------------------------- +# Combinations — label + component-specific +# --------------------------------------------------------------------------- + + +def test_mcpd031_combination_label_with_date_picker(): + dp = dcc.DatePickerSingle( + id="dp", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + app = _app_with_layout( + html.Div( + [ + html.Label("Departure Date", htmlFor="dp"), + dp, + html.Div(id="out"), + ] + ), + ("dp", "date"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Departure Date" in desc + assert "2020-01-01" in desc diff --git a/tests/unit/mcp/tools/test_mcp_input_schemas.py b/tests/unit/mcp/tools/test_mcp_input_schemas.py new file mode 100644 index 0000000000..d0b5e86923 --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_input_schemas.py @@ -0,0 +1,270 @@ +"""Input schema generation — JSON Schema for callback input parameters. + +Covers: +- Static overrides (DatePicker, Graph, Interval, Slider) +- Component introspection (representative per-type samples) +- Callback annotation overrides (highest priority) +- Required / nullable behavior +- Component type → JSON schema mapping +""" + +import pytest +from typing import Optional + +from dash import Dash, Input, Output, State, dcc, html +from dash.development.base_component import Component +from dash.mcp.primitives.tools.input_schemas.schema_callback_type_annotations import ( + annotation_to_json_schema, +) + +from tests.unit.mcp.conftest import ( + _app_with_callback, + _schema_for, + _tools_list, + _user_tool, +) + + +# --------------------------------------------------------------------------- +# Schema building blocks (JSON Schema primitives) +# --------------------------------------------------------------------------- + +STRING = {"type": "string"} +NUMBER = {"type": "number"} +INTEGER = {"type": "integer"} +BOOLEAN = {"type": "boolean"} +NULL = {"type": "null"} +OBJECT = {"additionalProperties": True, "type": "object"} + + +def nullable(*schemas): + """``{anyOf: [*schemas, {type: null}]}`` — a common nullable-type shape.""" + return {"anyOf": [*schemas, NULL]} + + +def array_of(*item_schemas): + """Array of a single schema, or a union when multiple are passed.""" + items = item_schemas[0] if len(item_schemas) == 1 else {"anyOf": list(item_schemas)} + return {"items": items, "type": "array"} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_schema(component_type, prop): + _factories = { + "DatePickerSingle": lambda: dcc.DatePickerSingle(id="dp"), + "DatePickerRange": lambda: dcc.DatePickerRange(id="dpr"), + "Graph": lambda: dcc.Graph(id="graph"), + "Interval": lambda: dcc.Interval(id="intv"), + "Input": lambda: dcc.Input(id="inp"), + "Textarea": lambda: dcc.Textarea(id="ta"), + "Tabs": lambda: dcc.Tabs(id="tabs"), + "Dropdown": lambda: dcc.Dropdown(id="dd"), + "RadioItems": lambda: dcc.RadioItems(id="ri"), + "Checklist": lambda: dcc.Checklist(id="cl"), + "Store": lambda: dcc.Store(id="store"), + "Upload": lambda: dcc.Upload(id="upload"), + "Slider": lambda: dcc.Slider(id="sl"), + "RangeSlider": lambda: dcc.RangeSlider(id="rs"), + } + app = _app_with_callback(_factories[component_type](), input_prop=prop) + tool = _user_tool(_tools_list(app)) + return _schema_for(tool) + + +def _app_with_annotated_callback(annotation_type, input_prop="disabled"): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + if annotation_type is None: + + @app.callback(Output("out", "children"), Input("inp", input_prop)) + def update(val): + return str(val) + + else: + + @app.callback(Output("out", "children"), Input("inp", input_prop)) + def update(val: annotation_type): + return str(val) + + return app + + +# --------------------------------------------------------------------------- +# Test cases +# --------------------------------------------------------------------------- + +# (component_type, prop, expected_schema) — representative per-type samples +INTROSPECTION_CASES = [ + ("Input", "value", nullable(STRING, NUMBER)), + ( + "Input", + "disabled", + nullable( + BOOLEAN, + {"const": "disabled", "type": "string"}, + {"const": "DISABLED", "type": "string"}, + ), + ), + ("Input", "n_submit", nullable(NUMBER)), + ( + "Dropdown", + "value", + nullable(STRING, NUMBER, BOOLEAN, array_of(STRING, NUMBER, BOOLEAN)), + ), + ("Dropdown", "options", nullable({})), + ("Checklist", "value", nullable(array_of(STRING, NUMBER, BOOLEAN))), + ("Store", "data", nullable(OBJECT, array_of({}), NUMBER, STRING, BOOLEAN)), + ("Upload", "contents", nullable(STRING, array_of(STRING))), + ("RangeSlider", "value", nullable(array_of(NUMBER))), + ("Tabs", "value", nullable(STRING)), +] + +# (annotation, prop, expected_schema) — callback annotations override introspection +ANNOTATION_CASES = [ + (str, "disabled", STRING), + (int, "value", INTEGER), + (float, "value", NUMBER), + (bool, "value", BOOLEAN), + (list, "value", array_of({})), + (dict, "value", OBJECT), + (Optional[int], "value", nullable(INTEGER)), + (Optional[str], "value", nullable(STRING)), +] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpi001_override_beats_introspection(): + """Static override wins over component introspection.""" + schema = _get_schema("DatePickerSingle", "date") + # Introspection would return None for this prop; + # override provides a date format with pattern + assert schema["type"] == "string" + assert schema["format"] == "date" + assert "pattern" in schema + + +@pytest.mark.parametrize( + "component_type,prop,expected", + INTROSPECTION_CASES, + ids=[f"{c}.{p}" for c, p, _ in INTROSPECTION_CASES], +) +def test_mcpi002_introspected_schema(component_type, prop, expected): + """Representative introspection tests across component types.""" + assert _get_schema(component_type, prop) == expected + + +@pytest.mark.parametrize( + "ann,prop,expected", + ANNOTATION_CASES, + ids=[ + f"{a.__name__ if hasattr(a, '__name__') else a}-{p}" + for a, p, _ in ANNOTATION_CASES + ], +) +def test_mcpi003_annotation(ann, prop, expected): + """Callback type annotations override component schemas.""" + app = _app_with_annotated_callback(ann, input_prop=prop) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == expected + + +def test_mcpi004_no_annotation_uses_introspection(): + app = _app_with_annotated_callback(None) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == nullable( + BOOLEAN, + {"const": "disabled", "type": "string"}, + {"const": "DISABLED", "type": "string"}, + ) + + +def test_mcpi005_str_removes_null(): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: str): + return val + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == STRING + + +def test_mcpi006_optional_preserves_null(): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: Optional[str]): + return val or "" + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == nullable(STRING) + + +def test_mcpi007_optional_param_not_required(): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: Optional[str]): + return val or "" + + tool = _user_tool(_tools_list(app)) + assert "val" not in tool.inputSchema.get("required", []) + + +def test_mcpi008_state_annotation_overrides(): + """Annotations work for State parameters too.""" + app = Dash(__name__) + app.layout = html.Div( + [dcc.Input(id="inp"), dcc.Store(id="store"), html.Div(id="out")] + ) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + State("store", "data"), + ) + def update(val: str, data: dict): + return str(val) + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == STRING + assert _schema_for(tool, "data") == OBJECT + + +def test_mcpi009_partial_annotations(): + """Some annotated, some not — introspection fills in the rest.""" + app = Dash(__name__) + app.layout = html.Div( + [dcc.Input(id="inp"), dcc.Store(id="store"), html.Div(id="out")] + ) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + State("store", "data"), + ) + def update(val: int, data): + return str(val) + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == INTEGER + assert _schema_for(tool, "data") == nullable( + OBJECT, array_of({}), NUMBER, STRING, BOOLEAN + ) + + +def test_mcpi010_component_type_maps_to_string(): + """Component annotation type maps to string schema.""" + assert annotation_to_json_schema(Component) == STRING diff --git a/tests/unit/mcp/tools/test_mcp_tools.py b/tests/unit/mcp/tools/test_mcp_tools.py new file mode 100644 index 0000000000..cacaf13b14 --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_tools.py @@ -0,0 +1,358 @@ +"""Tool construction: how Dash callbacks become MCP Tool objects. + +Covers the CallbackAdapter → Tool pipeline: list building (from_app), +tool name generation, and the resulting Tool object's shape (description, +input schema, param metadata). + +Reference: https://modelcontextprotocol.io/specification/2025-11-25/server/tools +""" + +import pytest +from dash import Dash, Input, Output, State, dcc, html +from dash._get_app import app_context +from dash.development.base_component import Component +from dash.types import CallbackExecutionResponse +from mcp.types import Tool +from pydantic import TypeAdapter + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + +from tests.unit.mcp.conftest import ( + _make_app, + _tools_list, + _user_tool, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def simple_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label("Your Name", htmlFor="inp"), + dcc.Input(id="inp", type="text"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """Update output.""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def multi_output_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + dcc.Dropdown(id="dd2"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("dd2", "options"), + Output("out", "children"), + Input("dd", "value"), + ) + def update(val): + return [], val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def state_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("btn", "n_clicks"), + State("inp", "value"), + ) + def update(clicks, val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def typed_app(): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val: str): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def duplicate_names_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def cb(v): + return v + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb(v): # noqa: F811 + return v + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +# --------------------------------------------------------------------------- +# Tests — building the callback list from an app +# --------------------------------------------------------------------------- + + +def test_mcpt001_returns_list(simple_app): + assert len(app_context.get().mcp_callback_map) == 1 + + +def test_mcpt002_excludes_clientside(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + html.Div(id="cs-out"), + html.Div(id="srv-out"), + ] + ) + app.clientside_callback( + "function(n) { return n; }", + Output("cs-out", "children"), + Input("btn", "n_clicks"), + ) + + @app.callback(Output("srv-out", "children"), Input("btn", "n_clicks")) + def server_cb(n): + return str(n) + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + names = [a.tool_name for a in app.mcp_callback_map] + assert names == ["server_cb"] + + +def test_mcpt003_excludes_mcp_disabled(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp"), + html.Div(id="out1"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("inp", "value")) + def visible(val): + return val + + @app.callback(Output("out2", "children"), Input("inp", "value"), mcp_enabled=False) + def hidden(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + names = [a.tool_name for a in app.mcp_callback_map] + assert "visible" in names + assert "hidden" not in names + + +# --------------------------------------------------------------------------- +# Tests — tool name generation +# --------------------------------------------------------------------------- + + +def test_mcpt004_uses_func_name(simple_app): + assert app_context.get().mcp_callback_map[0].tool_name == "update" + + +def test_mcpt005_duplicates_get_unique_names(duplicate_names_app): + names = [a.tool_name for a in app_context.get().mcp_callback_map] + assert len(names) == 2 + assert names[0] != names[1] + + +# --------------------------------------------------------------------------- +# Tests — Tool object shape (description, input schema, params) +# --------------------------------------------------------------------------- + + +def test_mcpt006_returns_tool_instance(simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert isinstance(tool, Tool) + assert tool.name == "update" + + +def test_mcpt007_docstring_hidden_by_default(): + """Callback docstrings are not exposed to MCP by default.""" + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """sensitive callback docstring text that must not leak to LLMs""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "sensitive callback docstring text that must not leak to LLMs" + not in tool.description + ) + + +def test_mcpt008_docstring_exposed_when_opted_in_per_callback(): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + mcp_expose_docstring=True, + ) + def update(val): + """intentionally-exposed callback docstring text for the LLM""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "intentionally-exposed callback docstring text for the LLM" in tool.description + ) + + +def test_mcpt009_description_includes_output_target(simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "out.children" in tool.description + + +def test_mcpt010_param_name_from_function_signature(simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "val" in tool.inputSchema["properties"] + + +def test_mcpt011_param_has_label_description(simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + desc = tool.inputSchema["properties"]["val"].get("description", "") + assert "Your Name" in desc + + +def test_mcpt012_state_params_included(state_app): + with state_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + props = tool.inputSchema["properties"] + assert set(props.keys()) == {"clicks", "val"} + + +def test_mcpt013_multi_output_description(multi_output_app): + with multi_output_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "dd2.options" in tool.description + assert "out.children" in tool.description + + +def test_mcpt014_typed_annotation_narrows_schema(typed_app): + with typed_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert tool.inputSchema["properties"]["val"]["type"] == "string" + + +# --------------------------------------------------------------------------- +# Tests — end-to-end Tool shape +# --------------------------------------------------------------------------- + + +_DASH_COMPONENT_SCHEMA = TypeAdapter(Component).json_schema() + +EXPECTED_TOOL = { + "name": "update_output", + "description": ( + "my-output.children: Returns content\n" "\n" "Test callback docstring." + ), + "inputSchema": { + "type": "object", + "properties": { + "value": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"type": "number"}, + _DASH_COMPONENT_SCHEMA, + { + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"type": "number"}, + _DASH_COMPONENT_SCHEMA, + {"type": "null"}, + ] + }, + "type": "array", + }, + {"type": "null"}, + ], + "description": "Input is optional.\nThe children of this component.", + }, + }, + }, + "outputSchema": TypeAdapter(CallbackExecutionResponse).json_schema(), +} + + +def test_mcpt015_full_tool(): + """The entire tool dict matches the expected shape end-to-end.""" + tool = _user_tool(_tools_list(_make_app())) + assert tool.model_dump(exclude_none=True) == EXPECTED_TOOL diff --git a/tests/unit/mcp/tools/test_tool_schema.py b/tests/unit/mcp/tools/test_tool_schema.py deleted file mode 100644 index b39dfe08c9..0000000000 --- a/tests/unit/mcp/tools/test_tool_schema.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Tool schema tests — what a Dash MCP tool looks like. - -The EXPECTED_TOOL dict below is the canonical reference for the shape of -a callback-generated MCP tool. It doubles as human-readable documentation -and as a test fixture. - -Reference: https://modelcontextprotocol.io/specification/2025-11-25/server/tools -""" - -from tests.unit.mcp.conftest import ( - _make_app, - _tools_list, - _user_tool, -) - -from pydantic import TypeAdapter -from dash.development.base_component import Component -from dash.types import CallbackExecutionResponse - -_DASH_COMPONENT_SCHEMA = TypeAdapter(Component).json_schema() - -EXPECTED_TOOL = { - "name": "update_output", - "description": ( - "my-output.children: Returns content\n" "\n" "Test callback docstring." - ), - "inputSchema": { - "type": "object", - "properties": { - "value": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"}, - {"type": "number"}, - _DASH_COMPONENT_SCHEMA, - { - "items": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"}, - {"type": "number"}, - _DASH_COMPONENT_SCHEMA, - {"type": "null"}, - ] - }, - "type": "array", - }, - {"type": "null"}, - ], - "description": "Input is optional.\nThe children of this component.", - }, - }, - }, - "outputSchema": TypeAdapter(CallbackExecutionResponse).json_schema(), -} - - -class TestToolSchema: - """Verify that the generated tool matches EXPECTED_TOOL exactly.""" - - def test_full_tool(self): - """The entire tool dict matches the expected shape.""" - tool = _user_tool(_tools_list(_make_app())) - assert tool.model_dump(exclude_none=True) == EXPECTED_TOOL From 4d7411c6ab40b8357f0a20fabe3ee9fb02a47355 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Fri, 24 Apr 2026 12:19:51 -0600 Subject: [PATCH 31/80] Move all hard-coded text into a `prop_roles.py` file where it can be managed centrally --- .../tools/descriptions/description_outputs.py | 31 ++-- .../schema_component_proptypes_overrides.py | 73 ++------- dash/mcp/primitives/tools/prop_roles.py | 151 ++++++++++++++++++ .../unit/mcp/tools/test_mcp_input_schemas.py | 29 +++- 4 files changed, 200 insertions(+), 84 deletions(-) create mode 100644 dash/mcp/primitives/tools/prop_roles.py diff --git a/dash/mcp/primitives/tools/descriptions/description_outputs.py b/dash/mcp/primitives/tools/descriptions/description_outputs.py index 06371d7717..b7bf55e81c 100644 --- a/dash/mcp/primitives/tools/descriptions/description_outputs.py +++ b/dash/mcp/primitives/tools/descriptions/description_outputs.py @@ -4,25 +4,18 @@ from typing import TYPE_CHECKING +from ..prop_roles import iter_prop_roles from .base import ToolDescriptionSource if TYPE_CHECKING: from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter -_OUTPUT_SEMANTICS: dict[tuple[str | None, str], str] = { - ("DataTable", "data"): "Returns tabular data", - ("DataTable", "columns"): "Returns table column definitions", - ("Store", "data"): "Returns data to be remembered client-side", - ("Download", "data"): "Returns downloadable content", - ("Markdown", "children"): "Returns formatted text", - (None, "figure"): "Returns chart/visualization data", - (None, "options"): "Returns available options", - (None, "columns"): "Returns column definitions", - (None, "children"): "Returns content", - (None, "value"): "Returns the current value", - (None, "style"): "Updates styling", - (None, "disabled"): "Updates enabled/disabled state", -} + +def _describe_output(comp_type: str | None, prop: str) -> str | None: + for role in iter_prop_roles(): + if role.description is not None and role.matches(comp_type, prop): + return role.description + return None class OutputSummaryDescription(ToolDescriptionSource): @@ -38,14 +31,10 @@ def describe(cls, callback: CallbackAdapter) -> list[str]: for out in outputs: comp_id = out["component_id"] prop = out["property"] - comp_type = out.get("component_type") - - semantic = _OUTPUT_SEMANTICS.get((comp_type, prop)) - if semantic is None: - semantic = _OUTPUT_SEMANTICS.get((None, prop)) + description = _describe_output(out.get("component_type"), prop) - if semantic is not None: - lines.append(f"- {comp_id}.{prop}: {semantic}") + if description is not None: + lines.append(f"- {comp_id}.{prop}: {description}") else: lines.append(f"- {comp_id}.{prop}") diff --git a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py index 984d493d69..e3d5b65756 100644 --- a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py +++ b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py @@ -1,5 +1,8 @@ -"""A place to manually define Schemas that override component-defined prop types -where type generation produces insufficient results. +"""Input schema overrides drawn from the ``PropRole`` registry. + +Looks up the parameter's ``(component_type, property)`` in the shared +registry and returns any attached ``input_schema``. Used when default +type introspection produces insufficient results. """ from __future__ import annotations @@ -8,56 +11,8 @@ from dash.mcp.types import MCPInput +from ..prop_roles import iter_prop_roles from .base import InputSchemaSource -from .schema_component_proptypes import ComponentPropSchema - -_DATE_SCHEMA = { - "type": "string", - "format": "date", - "pattern": r"^\d{4}-\d{2}-\d{2}$", -} - - -def _compute_dropdown_value_schema(param: MCPInput) -> dict[str, Any] | None: - """Dropdown values are an array if `multi=True`; scalar values otherwise.""" - schema = ComponentPropSchema.get_schema(param) - if schema is None: - return None - - component = param.get("component") - t = schema.get("type") - if not isinstance(t, list): - return schema - - if getattr(component, "multi", False): - items_schema = schema.get("items", {}) - return ( - {"type": "array", "items": items_schema} - if items_schema - else {"type": "array"} - ) - - scalar_types = [x for x in t if x != "array"] - refined = dict(schema) - refined["type"] = scalar_types[0] if len(scalar_types) == 1 else scalar_types - refined.pop("items", None) - return refined - - -_OVERRIDES: dict[tuple[str, str], dict[str, Any] | callable] = { - ("DatePickerSingle", "date"): _DATE_SCHEMA, - ("DatePickerRange", "start_date"): _DATE_SCHEMA, - ("DatePickerRange", "end_date"): _DATE_SCHEMA, - ("Graph", "figure"): { - "type": "object", - "properties": { - "data": {"type": "array", "items": {"type": "object"}}, - "layout": {"type": "object"}, - "frames": {"type": "array", "items": {"type": "object"}}, - }, - }, - ("Dropdown", "value"): _compute_dropdown_value_schema, -} class OverrideSchema(InputSchemaSource): @@ -65,10 +20,12 @@ class OverrideSchema(InputSchemaSource): @classmethod def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: - key = (param.get("component_type"), param["property"]) - override = _OVERRIDES.get(key) - if override is None: - return None - if callable(override): - return override(param) - return dict(override) + component_type = param.get("component_type") + prop = param["property"] + for role in iter_prop_roles(): + if role.input_schema is None or not role.matches(component_type, prop): + continue + if callable(role.input_schema): + return role.input_schema(param) + return dict(role.input_schema) + return None diff --git a/dash/mcp/primitives/tools/prop_roles.py b/dash/mcp/primitives/tools/prop_roles.py new file mode 100644 index 0000000000..eface31147 --- /dev/null +++ b/dash/mcp/primitives/tools/prop_roles.py @@ -0,0 +1,151 @@ +"""Canonical registry of semantic roles for Dash component props. + +A ``PropRole`` bundles the set of ``(component_type, property)`` pairs +that play the same role with the metadata attached to that role: +an LLM-facing description, an input JSON Schema, etc. Tool descriptions, +input-schema overrides, and result formatters all consume this registry +so they can't drift. + +Use ``ANY_COMPONENT`` as the component_type sentinel to match any component with +the given property name. + +Declaration order matters: ``iter_prop_roles()`` yields roles in the +order they're defined in this module, and the first match wins. List +concrete-match roles before wildcard-match roles that share a prop +name (e.g. ``MARKDOWN`` before ``CONTENT`` for ``children``). +""" + +from __future__ import annotations + +from typing import Any, Callable, Dict, Iterator, NamedTuple, Union + +from typing_extensions import TypeAlias + +from dash.mcp.types import MCPInput + +PropSchema = Union[ + Dict[str, Any], + Callable[[MCPInput], Dict[str, Any]], +] + +COMPONENT: TypeAlias = Union[str, None] +ANY_COMPONENT: None = None +PROP: TypeAlias = str + + +class PropRole(NamedTuple): + identifiers: set[tuple[COMPONENT, PROP]] + description: str | None = None + input_schema: PropSchema | None = None + + def matches(self, component_type: COMPONENT, prop: PROP) -> bool: + """True if this role applies to the given ``(component_type, prop)``. + + Matches either a concrete entry or an ``ANY_COMPONENT`` wildcard + entry in ``identifiers``. Shared by every consumer so all metadata + fields apply uniformly to every identifier in the role. + """ + return (component_type, prop) in self.identifiers or ( + ANY_COMPONENT, + prop, + ) in self.identifiers + + +def _compute_dropdown_value_schema(param: MCPInput) -> dict[str, Any]: + """Dropdown values are an array if ``multi=True``; scalar otherwise.""" + _DROPDOWN_SCALAR_TYPE = { + "anyOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}] + } + component = param.get("component") + if getattr(component, "multi", False): + return {"type": "array", "items": _DROPDOWN_SCALAR_TYPE} + return _DROPDOWN_SCALAR_TYPE + + +TABULAR = PropRole( + identifiers={("DataTable", "data"), ("AgGrid", "rowData")}, + description="Returns tabular data", +) + +DATE = PropRole( + identifiers={ + ("DatePickerSingle", "date"), + ("DatePickerRange", "start_date"), + ("DatePickerRange", "end_date"), + }, + input_schema={ + "type": "string", + "format": "date", + "pattern": r"^\d{4}-\d{2}-\d{2}$", + }, +) + +DROPDOWN_VALUE = PropRole( + identifiers={("Dropdown", "value")}, + input_schema=_compute_dropdown_value_schema, +) + +STORE_DATA = PropRole( + identifiers={("Store", "data")}, + description="Returns data to be remembered client-side", +) + +DOWNLOAD = PropRole( + identifiers={("Download", "data")}, + description="Returns downloadable content", +) + +MARKDOWN = PropRole( + identifiers={("Markdown", "children")}, + description="Returns formatted text", +) + +GENERIC_FIGURE = PropRole( + identifiers={(ANY_COMPONENT, "figure")}, + description="Returns chart/visualization data", + input_schema={ + "type": "object", + "properties": { + "data": {"type": "array", "items": {"type": "object"}}, + "layout": {"type": "object"}, + "frames": {"type": "array", "items": {"type": "object"}}, + }, + }, +) + +GENERIC_CONTENT = PropRole( + identifiers={(ANY_COMPONENT, "children")}, + description="Returns content", +) + +GENERIC_VALUE = PropRole( + identifiers={(ANY_COMPONENT, "value")}, + description="Returns the current value", +) + +GENERIC_OPTIONS = PropRole( + identifiers={(ANY_COMPONENT, "options")}, + description="Returns available options", +) + +GENERIC_COLUMNS = PropRole( + identifiers={(ANY_COMPONENT, "columns")}, + description="Returns column definitions", +) + +GENERIC_STYLE = PropRole( + identifiers={(ANY_COMPONENT, "style")}, + description="Updates styling", +) + +GENERIC_DISABLED = PropRole( + identifiers={(ANY_COMPONENT, "disabled")}, + description="Updates enabled/disabled state", +) + + +def iter_prop_roles() -> Iterator[PropRole]: + """Yield every PropRole defined in this module in declaration order.""" + for value in globals().values(): + if isinstance(value, PropRole): + yield value diff --git a/tests/unit/mcp/tools/test_mcp_input_schemas.py b/tests/unit/mcp/tools/test_mcp_input_schemas.py index d0b5e86923..ade87bbc79 100644 --- a/tests/unit/mcp/tools/test_mcp_input_schemas.py +++ b/tests/unit/mcp/tools/test_mcp_input_schemas.py @@ -111,11 +111,6 @@ def update(val: annotation_type): ), ), ("Input", "n_submit", nullable(NUMBER)), - ( - "Dropdown", - "value", - nullable(STRING, NUMBER, BOOLEAN, array_of(STRING, NUMBER, BOOLEAN)), - ), ("Dropdown", "options", nullable({})), ("Checklist", "value", nullable(array_of(STRING, NUMBER, BOOLEAN))), ("Store", "data", nullable(OBJECT, array_of({}), NUMBER, STRING, BOOLEAN)), @@ -152,6 +147,13 @@ def test_mcpi001_override_beats_introspection(): assert "pattern" in schema +def test_mcpi013_graph_figure_uses_plotly_schema_override(): + """Graph.figure matches the FIGURE role's schema override (concrete via wildcard).""" + schema = _get_schema("Graph", "figure") + assert schema["type"] == "object" + assert set(schema["properties"]) == {"data", "layout", "frames"} + + @pytest.mark.parametrize( "component_type,prop,expected", INTROSPECTION_CASES, @@ -268,3 +270,20 @@ def update(val: int, data): def test_mcpi010_component_type_maps_to_string(): """Component annotation type maps to string schema.""" assert annotation_to_json_schema(Component) == STRING + + +def test_mcpi011_dropdown_value_multi_false_narrows_to_scalar(): + """Dropdown.value with multi=False narrows to a scalar union.""" + app = _app_with_callback(dcc.Dropdown(id="dd")) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool) == {"anyOf": [STRING, NUMBER, BOOLEAN]} + + +def test_mcpi012_dropdown_value_multi_true_narrows_to_array(): + """Dropdown.value with multi=True narrows to an array of scalars.""" + app = _app_with_callback(dcc.Dropdown(id="dd", multi=True)) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool) == { + "type": "array", + "items": {"anyOf": [STRING, NUMBER, BOOLEAN]}, + } From b35749582d3f77936fad9a011bea7674d458217b Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 14:05:37 -0600 Subject: [PATCH 32/80] Add pattern-matching callback support to input schemas and descriptions --- .../tools/input_schemas/__init__.py | 2 + .../input_descriptions/__init__.py | 2 + .../description_pattern_matching.py | 74 +++++++++++ .../input_schemas/schema_pattern_matching.py | 89 +++++++++++++ .../input_schemas/test_pattern_matching.py | 121 ++++++++++++++++++ 5 files changed, 288 insertions(+) create mode 100644 dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py create mode 100644 dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py create mode 100644 tests/unit/mcp/tools/input_schemas/test_pattern_matching.py diff --git a/dash/mcp/primitives/tools/input_schemas/__init__.py b/dash/mcp/primitives/tools/input_schemas/__init__.py index 9fa82eda55..e037c5793f 100644 --- a/dash/mcp/primitives/tools/input_schemas/__init__.py +++ b/dash/mcp/primitives/tools/input_schemas/__init__.py @@ -12,12 +12,14 @@ from dash.mcp.types import MCPInput from .base import InputSchemaSource +from .schema_pattern_matching import PatternMatchingSchema from .schema_callback_type_annotations import AnnotationSchema from .schema_component_proptypes_overrides import OverrideSchema from .schema_component_proptypes import ComponentPropSchema from .input_descriptions import get_property_description _SOURCES: list[type[InputSchemaSource]] = [ + PatternMatchingSchema, AnnotationSchema, OverrideSchema, ComponentPropSchema, diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py index 4bc6d8e984..ebba3b4af8 100644 --- a/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py @@ -12,11 +12,13 @@ from .description_component_props import ComponentPropsDescription from .description_docstrings import DocstringPropDescription from .description_html_labels import LabelDescription +from .description_pattern_matching import PatternMatchingDescription _SOURCES: list[type[InputDescriptionSource]] = [ DocstringPropDescription, LabelDescription, ComponentPropsDescription, + PatternMatchingDescription, ] diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py new file mode 100644 index 0000000000..69a19829f7 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py @@ -0,0 +1,74 @@ +"""Description for pattern-matching callback inputs. + +Explains that the input corresponds to a pattern-matching callback +(ALL, MATCH, ALLSMALLER) and describes the expected format. +See: https://dash.plotly.com/pattern-matching-callbacks +""" + +from __future__ import annotations + +import json + +from dash.dependencies import Wildcard +from dash.mcp.types import MCPInput + +from .base import InputDescriptionSource + +_WILDCARD_VALUES = frozenset(w.value for w in Wildcard) + + +class PatternMatchingDescription(InputDescriptionSource): + """Describe pattern-matching behavior for wildcard inputs.""" + + @classmethod + def describe(cls, param: MCPInput) -> list[str]: + dep_id = _parse_dep_id(param["component_id"]) + if dep_id is None: + return [] + + wildcard_key, wildcard_type = _find_wildcard(dep_id) + if wildcard_key is None: + return [] + + non_wildcard = {k: v for k, v in dep_id.items() if k != wildcard_key} + pattern_desc = ", ".join(f'{k}="{v}"' for k, v in non_wildcard.items()) + prop = param["property"] + + wildcard_descriptions = { + "ALL": ( + f"Pattern-matching input (ALL): provide an array of `{prop}` values, " + f"one per component matching {{{pattern_desc}}}. " + f"All matching components are included." + ), + "MATCH": ( + f"Pattern-matching input (MATCH): provide the `{prop}` value " + f"for the specific component matching {{{pattern_desc}}} " + f"that triggered this callback." + ), + "ALLSMALLER": ( + f"Pattern-matching input (ALLSMALLER): provide an array of `{prop}` values " + f"from components matching {{{pattern_desc}}} " + f"whose `{wildcard_key}` is smaller than the triggering component's `{wildcard_key}`." + ), + } + + desc = wildcard_descriptions.get(wildcard_type) + return [desc] if desc else [] + + +def _parse_dep_id(component_id: str) -> dict | None: + if not component_id.startswith("{"): + return None + try: + return json.loads(component_id) + except (json.JSONDecodeError, ValueError): + return None + + +def _find_wildcard(dep_id: dict) -> tuple[str | None, str | None]: + """Return (key, wildcard_type) for the first wildcard found.""" + for key, value in dep_id.items(): + if isinstance(value, list) and len(value) == 1: + if value[0] in _WILDCARD_VALUES: + return key, value[0] + return None, None diff --git a/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py new file mode 100644 index 0000000000..e6b095370f --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py @@ -0,0 +1,89 @@ +"""Schema for pattern-matching callback inputs (ALL, MATCH, ALLSMALLER). + +When a callback input uses a wildcard ID, the callback receives a +list of values — one per matching component. This source detects +wildcard IDs and produces an array schema. If matching components +exist in the layout, the item type is inferred from a concrete match. +""" + +from __future__ import annotations + +import json +from typing import Any + +from dash._layout_utils import find_matching_components, _WILDCARD_VALUES +from dash.mcp.types import MCPInput + +from .base import InputSchemaSource + + +class PatternMatchingSchema(InputSchemaSource): + """Return a schema for pattern-matching inputs. + + For ALL/ALLSMALLER: array of ``{id, property, value}`` objects. + For MATCH: a single ``{id, property, value}`` object. + """ + + @classmethod + def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: + dep_id = _parse_dep_id(param["component_id"]) + if dep_id is None: + return None + + wildcard_type = _get_wildcard_type(dep_id) + if wildcard_type is None: + return None + + value_schema = _infer_value_schema(param) + + item_schema: dict[str, Any] = { + "type": "object", + "properties": { + "id": {"type": "object"}, + "property": {"type": "string"}, + "value": value_schema or {}, + }, + "required": ["id", "property", "value"], + } + + if wildcard_type == "MATCH": + return item_schema + + return {"type": "array", "items": item_schema} + + +def _parse_dep_id(component_id: str) -> dict | None: + if not component_id.startswith("{"): + return None + try: + return json.loads(component_id) + except (json.JSONDecodeError, ValueError): + return None + + +def _get_wildcard_type(dep_id: dict) -> str | None: + """Return the wildcard type (ALL, MATCH, ALLSMALLER) or None.""" + for value in dep_id.values(): + if isinstance(value, list) and len(value) == 1: + if value[0] in _WILDCARD_VALUES: + return value[0] + return None + + +def _infer_value_schema(param: MCPInput) -> dict[str, Any] | None: + """Infer the JSON Schema for the ``value`` field from a matching component.""" + matches = find_matching_components(_parse_dep_id(param["component_id"])) + if not matches: + return None + + from . import get_input_schema + + concrete_param: MCPInput = { + **param, + "component": matches[0], + "component_id": str(getattr(matches[0], "id", "")), + "component_type": getattr(matches[0], "_type", None), + } + schema = get_input_schema(concrete_param) + schema.pop("description", None) + return schema or None diff --git a/tests/unit/mcp/tools/input_schemas/test_pattern_matching.py b/tests/unit/mcp/tools/input_schemas/test_pattern_matching.py new file mode 100644 index 0000000000..37b8a642ee --- /dev/null +++ b/tests/unit/mcp/tools/input_schemas/test_pattern_matching.py @@ -0,0 +1,121 @@ +"""Tests for pattern-matching schema and description generation.""" + +from dash import Dash, html, Input, Output, ALL, MATCH + +from tests.unit.mcp.conftest import _tools_list, _user_tool, _schema_for, _desc_for + + +class TestPatternMatchingSchema: + def test_all_produces_array_schema(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}, children="A"), + html.Div(id={"type": "item", "index": 1}, children="B"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": ALL}, "children"), + ) + def combine(values): + return ", ".join(values) + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "array" + assert schema["items"]["type"] == "object" + assert "value" in schema["items"]["properties"] + + def test_match_produces_object_schema(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}, children="A"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": MATCH}, "children"), + ) + def echo(value): + return value + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "object" + assert "value" in schema["properties"] + + def test_annotation_narrows_value_schema(self): + from dash import dcc + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id={"type": "filter", "index": 0}, options=["a", "b"]), + dcc.Dropdown(id={"type": "filter", "index": 1}, options=["c", "d"]), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "filter", "index": ALL}, "options"), + ) + def combine(options: list[str]): + return str(options) + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "array" + value_schema = schema["items"]["properties"]["value"] + # Annotation narrows value to list[str] instead of the broad introspected type + assert value_schema == {"items": {"type": "string"}, "type": "array"} + + +class TestPatternMatchingDescription: + def test_all_description(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": ALL}, "children"), + ) + def combine(values): + return str(values) + + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Pattern-matching input (ALL)" in desc + assert 'type="item"' in desc + + def test_match_description(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": MATCH}, "children"), + ) + def echo(value): + return value + + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Pattern-matching input (MATCH)" in desc + assert 'type="item"' in desc From 6906b3d3715632ecce6c67b023fd02e301d4ffdf Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 16 Apr 2026 13:24:03 -0600 Subject: [PATCH 33/80] clean up duplicated code --- .../description_pattern_matching.py | 17 ++------------- .../input_schemas/schema_pattern_matching.py | 21 +++++++------------ 2 files changed, 10 insertions(+), 28 deletions(-) diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py index 69a19829f7..221423aa50 100644 --- a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py @@ -7,22 +7,18 @@ from __future__ import annotations -import json - -from dash.dependencies import Wildcard +from dash._layout_utils import _WILDCARD_VALUES, parse_wildcard_id from dash.mcp.types import MCPInput from .base import InputDescriptionSource -_WILDCARD_VALUES = frozenset(w.value for w in Wildcard) - class PatternMatchingDescription(InputDescriptionSource): """Describe pattern-matching behavior for wildcard inputs.""" @classmethod def describe(cls, param: MCPInput) -> list[str]: - dep_id = _parse_dep_id(param["component_id"]) + dep_id = parse_wildcard_id(param["component_id"]) if dep_id is None: return [] @@ -56,15 +52,6 @@ def describe(cls, param: MCPInput) -> list[str]: return [desc] if desc else [] -def _parse_dep_id(component_id: str) -> dict | None: - if not component_id.startswith("{"): - return None - try: - return json.loads(component_id) - except (json.JSONDecodeError, ValueError): - return None - - def _find_wildcard(dep_id: dict) -> tuple[str | None, str | None]: """Return (key, wildcard_type) for the first wildcard found.""" for key, value in dep_id.items(): diff --git a/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py index e6b095370f..52e16cf58b 100644 --- a/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py +++ b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py @@ -8,10 +8,13 @@ from __future__ import annotations -import json from typing import Any -from dash._layout_utils import find_matching_components, _WILDCARD_VALUES +from dash._layout_utils import ( + _WILDCARD_VALUES, + find_matching_components, + parse_wildcard_id, +) from dash.mcp.types import MCPInput from .base import InputSchemaSource @@ -26,7 +29,7 @@ class PatternMatchingSchema(InputSchemaSource): @classmethod def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: - dep_id = _parse_dep_id(param["component_id"]) + dep_id = parse_wildcard_id(param["component_id"]) if dep_id is None: return None @@ -52,15 +55,6 @@ def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: return {"type": "array", "items": item_schema} -def _parse_dep_id(component_id: str) -> dict | None: - if not component_id.startswith("{"): - return None - try: - return json.loads(component_id) - except (json.JSONDecodeError, ValueError): - return None - - def _get_wildcard_type(dep_id: dict) -> str | None: """Return the wildcard type (ALL, MATCH, ALLSMALLER) or None.""" for value in dep_id.values(): @@ -72,10 +66,11 @@ def _get_wildcard_type(dep_id: dict) -> str | None: def _infer_value_schema(param: MCPInput) -> dict[str, Any] | None: """Infer the JSON Schema for the ``value`` field from a matching component.""" - matches = find_matching_components(_parse_dep_id(param["component_id"])) + matches = find_matching_components(parse_wildcard_id(param["component_id"])) if not matches: return None + # pylint: disable-next=cyclic-import,import-outside-toplevel from . import get_input_schema concrete_param: MCPInput = { From 6a52dba6c68a3fb4e4ba28f14ea7e85d930aa472 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 23 Apr 2026 14:04:40 -0600 Subject: [PATCH 34/80] Refactor unit tests to conform to existing test patterns --- .../input_schemas/test_pattern_matching.py | 121 ---------------- .../mcp/tools/test_mcp_pattern_matching.py | 136 ++++++++++++++++++ 2 files changed, 136 insertions(+), 121 deletions(-) delete mode 100644 tests/unit/mcp/tools/input_schemas/test_pattern_matching.py create mode 100644 tests/unit/mcp/tools/test_mcp_pattern_matching.py diff --git a/tests/unit/mcp/tools/input_schemas/test_pattern_matching.py b/tests/unit/mcp/tools/input_schemas/test_pattern_matching.py deleted file mode 100644 index 37b8a642ee..0000000000 --- a/tests/unit/mcp/tools/input_schemas/test_pattern_matching.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Tests for pattern-matching schema and description generation.""" - -from dash import Dash, html, Input, Output, ALL, MATCH - -from tests.unit.mcp.conftest import _tools_list, _user_tool, _schema_for, _desc_for - - -class TestPatternMatchingSchema: - def test_all_produces_array_schema(self): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Div(id={"type": "item", "index": 0}, children="A"), - html.Div(id={"type": "item", "index": 1}, children="B"), - html.Div(id="result"), - ] - ) - - @app.callback( - Output("result", "children"), - Input({"type": "item", "index": ALL}, "children"), - ) - def combine(values): - return ", ".join(values) - - tool = _user_tool(_tools_list(app)) - schema = _schema_for(tool) - assert schema["type"] == "array" - assert schema["items"]["type"] == "object" - assert "value" in schema["items"]["properties"] - - def test_match_produces_object_schema(self): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Div(id={"type": "item", "index": 0}, children="A"), - html.Div(id="result"), - ] - ) - - @app.callback( - Output("result", "children"), - Input({"type": "item", "index": MATCH}, "children"), - ) - def echo(value): - return value - - tool = _user_tool(_tools_list(app)) - schema = _schema_for(tool) - assert schema["type"] == "object" - assert "value" in schema["properties"] - - def test_annotation_narrows_value_schema(self): - from dash import dcc - - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id={"type": "filter", "index": 0}, options=["a", "b"]), - dcc.Dropdown(id={"type": "filter", "index": 1}, options=["c", "d"]), - html.Div(id="result"), - ] - ) - - @app.callback( - Output("result", "children"), - Input({"type": "filter", "index": ALL}, "options"), - ) - def combine(options: list[str]): - return str(options) - - tool = _user_tool(_tools_list(app)) - schema = _schema_for(tool) - assert schema["type"] == "array" - value_schema = schema["items"]["properties"]["value"] - # Annotation narrows value to list[str] instead of the broad introspected type - assert value_schema == {"items": {"type": "string"}, "type": "array"} - - -class TestPatternMatchingDescription: - def test_all_description(self): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Div(id={"type": "item", "index": 0}), - html.Div(id="result"), - ] - ) - - @app.callback( - Output("result", "children"), - Input({"type": "item", "index": ALL}, "children"), - ) - def combine(values): - return str(values) - - tool = _user_tool(_tools_list(app)) - desc = _desc_for(tool) - assert "Pattern-matching input (ALL)" in desc - assert 'type="item"' in desc - - def test_match_description(self): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Div(id={"type": "item", "index": 0}), - html.Div(id="result"), - ] - ) - - @app.callback( - Output("result", "children"), - Input({"type": "item", "index": MATCH}, "children"), - ) - def echo(value): - return value - - tool = _user_tool(_tools_list(app)) - desc = _desc_for(tool) - assert "Pattern-matching input (MATCH)" in desc - assert 'type="item"' in desc diff --git a/tests/unit/mcp/tools/test_mcp_pattern_matching.py b/tests/unit/mcp/tools/test_mcp_pattern_matching.py new file mode 100644 index 0000000000..25c8de1f8d --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_pattern_matching.py @@ -0,0 +1,136 @@ +"""Pattern-matching callback support — schemas and descriptions for wildcard IDs. + +Covers callbacks whose Input/Output use ``ALL``, ``MATCH``, or ``ALLSMALLER`` +wildcards in dict-based component IDs: the input is typed as an array (ALL) +or object (MATCH) of ``{id, property, value}`` entries, and descriptions +surface the wildcard kind and ID pattern. +""" + +from dash import Dash, html, Input, Output, ALL, MATCH, dcc + +from tests.unit.mcp.conftest import _tools_list, _user_tool, _schema_for, _desc_for + + +# --------------------------------------------------------------------------- +# Schema shape +# --------------------------------------------------------------------------- + + +def test_mcpm001_all_produces_array_schema(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}, children="A"), + html.Div(id={"type": "item", "index": 1}, children="B"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": ALL}, "children"), + ) + def combine(values): + return ", ".join(values) + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "array" + assert schema["items"]["type"] == "object" + assert "value" in schema["items"]["properties"] + + +def test_mcpm002_match_produces_object_schema(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}, children="A"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": MATCH}, "children"), + ) + def echo(value): + return value + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "object" + assert "value" in schema["properties"] + + +def test_mcpm003_annotation_narrows_value_schema(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id={"type": "filter", "index": 0}, options=["a", "b"]), + dcc.Dropdown(id={"type": "filter", "index": 1}, options=["c", "d"]), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "filter", "index": ALL}, "options"), + ) + def combine(options: list[str]): + return str(options) + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "array" + value_schema = schema["items"]["properties"]["value"] + # Annotation narrows value to list[str] instead of the broad introspected type + assert value_schema == {"items": {"type": "string"}, "type": "array"} + + +# --------------------------------------------------------------------------- +# Descriptions +# --------------------------------------------------------------------------- + + +def test_mcpm004_all_description(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": ALL}, "children"), + ) + def combine(values): + return str(values) + + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Pattern-matching input (ALL)" in desc + assert 'type="item"' in desc + + +def test_mcpm005_match_description(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": MATCH}, "children"), + ) + def echo(value): + return value + + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Pattern-matching input (MATCH)" in desc + assert 'type="item"' in desc From 7ce9939fbdc4b33bf7e7aa970b1f3556da24dadb Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 14:23:37 -0600 Subject: [PATCH 35/80] Add result formatters for Plotly figures and tabular data --- dash/mcp/primitives/tools/results/__init__.py | 52 ++++++++++ .../tools/results/result_dataframe.py | 55 +++++++++++ .../tools/results/result_plotly_figure.py | 52 ++++++++++ .../tools/results/test_callback_response.py | 98 +++++++++++++++++++ .../unit/mcp/tools/results/test_dataframe.py | 63 ++++++++++++ .../mcp/tools/results/test_plotly_figure.py | 55 +++++++++++ 6 files changed, 375 insertions(+) create mode 100644 dash/mcp/primitives/tools/results/__init__.py create mode 100644 dash/mcp/primitives/tools/results/result_dataframe.py create mode 100644 dash/mcp/primitives/tools/results/result_plotly_figure.py create mode 100644 tests/unit/mcp/tools/results/test_callback_response.py create mode 100644 tests/unit/mcp/tools/results/test_dataframe.py create mode 100644 tests/unit/mcp/tools/results/test_plotly_figure.py diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py new file mode 100644 index 0000000000..e2f91a67a8 --- /dev/null +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -0,0 +1,52 @@ +"""Tool result formatting for MCP tools/call responses. + +Each result formatter shares the same signature: +``(output: MCPOutput, value: Any) -> list[TextContent | ImageContent]`` + +Formatters decide for themselves whether they care about a given output. +The structuredContent is always the full dispatch response. +""" + +from __future__ import annotations + +import json +from typing import Any + +from mcp.types import CallToolResult, TextContent + +from dash.types import CallbackDispatchResponse +from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + +from .result_dataframe import dataframe_result +from .result_plotly_figure import plotly_figure_result + +_RESULT_FORMATTERS = [ + plotly_figure_result, + dataframe_result, +] + + +def format_callback_response( + response: CallbackDispatchResponse, + callback: CallbackAdapter, +) -> CallToolResult: + """Format a dispatch response as a CallToolResult. + + The response is always returned as structuredContent. Result + formatters are called per output property and may add additional + content items (images, markdown, etc.). + """ + content: list[Any] = [ + TextContent(type="text", text=json.dumps(response, default=str)), + ] + + resp = response.get("response") or {} + for callback_output in callback.outputs: + value = resp.get(callback_output["component_id"], {}).get(callback_output["property"]) + for result_fn in _RESULT_FORMATTERS: + content.extend(result_fn(callback_output, value)) + + return CallToolResult( + content=content, + structuredContent=response, + ) diff --git a/dash/mcp/primitives/tools/results/result_dataframe.py b/dash/mcp/primitives/tools/results/result_dataframe.py new file mode 100644 index 0000000000..652c31589b --- /dev/null +++ b/dash/mcp/primitives/tools/results/result_dataframe.py @@ -0,0 +1,55 @@ +"""Tabular data result: render as a markdown table. + +Detects tabular output by component type and prop name: +- DataTable.data +- AgGrid.rowData +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import TextContent + +from dash.mcp.types import MCPOutput + +MAX_ROWS = 50 + +_TABULAR_PROPS = { + ("DataTable", "data"), + ("AgGrid", "rowData"), +} + + +def _to_markdown_table(rows: list[dict], max_rows: int = MAX_ROWS) -> str: + """Render a list of row dicts as a markdown table.""" + columns = list(rows[0].keys()) + total = len(rows) + + lines: list[str] = [] + lines.append(f"*{total} rows \u00d7 {len(columns)} columns*") + lines.append("") + lines.append("| " + " | ".join(columns) + " |") + lines.append("| " + " | ".join("---" for _ in columns) + " |") + + for row in rows[:max_rows]: + cells = [ + str(row.get(col, "")).replace("|", "\\|").replace("\n", " ") + for col in columns + ] + lines.append("| " + " | ".join(cells) + " |") + + if total > max_rows: + lines.append(f"\n(\u2026 {total - max_rows} more rows)") + + return "\n".join(lines) + + +def dataframe_result(callback_output: MCPOutput, callback_output_value: Any) -> list: + """Produce a markdown table for tabular component output values.""" + key = (callback_output.get("component_type"), callback_output.get("property")) + if key not in _TABULAR_PROPS: + return [] + if not isinstance(callback_output_value, list) or not callback_output_value or not isinstance(callback_output_value[0], dict): + return [] + return [TextContent(type="text", text=_to_markdown_table(callback_output_value))] diff --git a/dash/mcp/primitives/tools/results/result_plotly_figure.py b/dash/mcp/primitives/tools/results/result_plotly_figure.py new file mode 100644 index 0000000000..d837e17eed --- /dev/null +++ b/dash/mcp/primitives/tools/results/result_plotly_figure.py @@ -0,0 +1,52 @@ +"""Plotly figure tool result: rendered image.""" + +from __future__ import annotations + +import base64 +import logging +from typing import Any + +from mcp.types import ImageContent + +from dash.mcp.types import MCPOutput + +logger = logging.getLogger(__name__) + +IMAGE_WIDTH = 700 +IMAGE_HEIGHT = 450 + + +def _render_image(figure: Any) -> ImageContent | None: + """Render the figure as a base64 PNG ImageContent. + + Returns None if kaleido is not installed. + """ + try: + img_bytes = figure.to_image( + format="png", + width=IMAGE_WIDTH, + height=IMAGE_HEIGHT, + ) + except (ValueError, ImportError): + logger.debug("MCP: kaleido not available, skipping image render") + return None + + b64 = base64.b64encode(img_bytes).decode("ascii") + return ImageContent(type="image", data=b64, mimeType="image/png") + + +def plotly_figure_result(callback_output: MCPOutput, callback_output_value: Any) -> list: + """Produce a rendered PNG for Graph.figure output values.""" + if callback_output.get("component_type") != "Graph" or callback_output.get("property") != "figure": + return [] + if not isinstance(callback_output_value, dict): + return [] + + try: + import plotly.graph_objects as go + except ImportError: + return [] + + fig = go.Figure(callback_output_value) + image = _render_image(fig) + return [image] if image is not None else [] diff --git a/tests/unit/mcp/tools/results/test_callback_response.py b/tests/unit/mcp/tools/results/test_callback_response.py new file mode 100644 index 0000000000..ff8cca5e20 --- /dev/null +++ b/tests/unit/mcp/tools/results/test_callback_response.py @@ -0,0 +1,98 @@ +"""Tests for the callback response formatter.""" + +from unittest.mock import Mock + +from dash.mcp.primitives.tools.results import format_callback_response + + +def _mock_callback(outputs=None): + cb = Mock() + cb.outputs = outputs or [] + return cb + + +class TestFormatCallbackResponse: + def test_wraps_as_structured_content(self): + response = { + "multi": True, + "response": {"out": {"children": "hello"}}, + } + result = format_callback_response(response, _mock_callback()) + assert result.structuredContent == response + + def test_content_has_json_text_fallback(self): + """Per MCP spec, structuredContent SHOULD include a TextContent fallback.""" + response = {"multi": True, "response": {}} + result = format_callback_response(response, _mock_callback()) + assert len(result.content) >= 1 + assert result.content[0].type == "text" + assert '"multi": true' in result.content[0].text + + def test_is_error_defaults_false(self): + response = {"multi": True, "response": {}} + result = format_callback_response(response, _mock_callback()) + assert result.isError is False + + def test_preserves_side_update(self): + response = { + "multi": True, + "response": {"out": {"children": "x"}}, + "sideUpdate": {"other": {"value": 42}}, + } + result = format_callback_response(response, _mock_callback()) + assert result.structuredContent["sideUpdate"] == {"other": {"value": 42}} + + def test_datatable_result_includes_markdown_table(self): + response = { + "multi": True, + "response": { + "my-table": {"data": [{"name": "Alice", "age": 30}]}, + }, + } + outputs = [ + { + "component_id": "my-table", + "component_type": "DataTable", + "property": "data", + "id_and_prop": "my-table.data", + "initial_value": None, + "tool_name": "update", + } + ] + result = format_callback_response(response, _mock_callback(outputs)) + texts = [c.text for c in result.content if c.type == "text"] + assert any("| name | age |" in t for t in texts) + + def test_plotly_figure_includes_image(self): + from unittest.mock import patch + + try: + import plotly.graph_objects as go + except ImportError: + return + + response = { + "multi": True, + "response": { + "my-graph": { + "figure": { + "data": [{"type": "bar", "x": ["A"], "y": [1]}], + "layout": {}, + } + } + }, + } + outputs = [ + { + "component_id": "my-graph", + "component_type": "Graph", + "property": "figure", + "id_and_prop": "my-graph.figure", + "initial_value": None, + "tool_name": "update", + } + ] + with patch.object(go.Figure, "to_image", return_value=b"\x89PNGfake"): + result = format_callback_response(response, _mock_callback(outputs)) + images = [c for c in result.content if c.type == "image"] + assert len(images) == 1 diff --git a/tests/unit/mcp/tools/results/test_dataframe.py b/tests/unit/mcp/tools/results/test_dataframe.py new file mode 100644 index 0000000000..a7f9e42fca --- /dev/null +++ b/tests/unit/mcp/tools/results/test_dataframe.py @@ -0,0 +1,63 @@ +"""Tests for the tabular data result formatter.""" + +from dash.mcp.primitives.tools.results.result_dataframe import ( + MAX_ROWS, + dataframe_result, +) + +EXPECTED_TABLE = ( + "*2 rows \u00d7 2 columns*\n" + "\n" + "| name | age |\n" + "| --- | --- |\n" + "| Alice | 30 |\n" + "| Bob | 25 |" +) + +SAMPLE_ROWS = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}] + +DATATABLE_OUTPUT = { + "component_type": "DataTable", + "property": "data", + "component_id": "t", + "id_and_prop": "t.data", + "initial_value": None, + "tool_name": "update", +} + +AGGRID_OUTPUT = { + "component_type": "AgGrid", + "property": "rowData", + "component_id": "g", + "id_and_prop": "g.rowData", + "initial_value": None, + "tool_name": "update", +} + + +class TestDataframeResult: + def test_datatable_data_renders_markdown(self): + result = dataframe_result(DATATABLE_OUTPUT, SAMPLE_ROWS) + assert len(result) == 1 + assert result[0].text == EXPECTED_TABLE + + def test_aggrid_rowdata_renders_markdown(self): + result = dataframe_result(AGGRID_OUTPUT, SAMPLE_ROWS) + assert len(result) == 1 + assert result[0].text == EXPECTED_TABLE + + def test_ignores_non_tabular_props(self): + non_tabular = {**DATATABLE_OUTPUT, "property": "columns"} + assert dataframe_result(non_tabular, SAMPLE_ROWS) == [] + + def test_ignores_empty_or_non_dict_rows(self): + assert dataframe_result(DATATABLE_OUTPUT, []) == [] + assert dataframe_result(DATATABLE_OUTPUT, ["a", "b"]) == [] + + def test_truncates_large_tables(self): + rows = [{"i": n} for n in range(MAX_ROWS + 50)] + result = dataframe_result(DATATABLE_OUTPUT, rows) + text = result[0].text + assert f"| {MAX_ROWS - 1} |" in text + assert f"| {MAX_ROWS} |" not in text + assert "50 more rows" in text diff --git a/tests/unit/mcp/tools/results/test_plotly_figure.py b/tests/unit/mcp/tools/results/test_plotly_figure.py new file mode 100644 index 0000000000..8e336ba687 --- /dev/null +++ b/tests/unit/mcp/tools/results/test_plotly_figure.py @@ -0,0 +1,55 @@ +"""Tests for the Plotly figure tool result formatter.""" + +import base64 +from unittest.mock import patch + +import pytest + +from dash.mcp.primitives.tools.results.result_plotly_figure import ( + plotly_figure_result, +) + +go = pytest.importorskip("plotly.graph_objects") + +FAKE_PNG = b"\x89PNG\r\n\x1a\nfakedata" +FAKE_B64 = base64.b64encode(FAKE_PNG).decode("ascii") + +GRAPH_FIGURE_OUTPUT = { + "component_type": "Graph", + "property": "figure", + "component_id": "g", + "id_and_prop": "g.figure", + "initial_value": None, + "tool_name": "update", +} + + +class TestPlotlyFigureResult: + def test_returns_image_when_kaleido_available(self): + fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() + with patch.object(go.Figure, "to_image", return_value=FAKE_PNG): + result = plotly_figure_result(GRAPH_FIGURE_OUTPUT, fig_dict) + assert len(result) == 1 + assert result[0].type == "image" + assert result[0].data == FAKE_B64 + + def test_returns_empty_when_kaleido_unavailable(self): + fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() + with patch.object(go.Figure, "to_image", side_effect=ImportError): + result = plotly_figure_result(GRAPH_FIGURE_OUTPUT, fig_dict) + assert result == [] + + def test_ignores_non_graph_components(self): + output = { + **GRAPH_FIGURE_OUTPUT, + "component_type": "Div", + "property": "children", + } + assert plotly_figure_result(output, {}) == [] + + def test_ignores_non_figure_props(self): + output = {**GRAPH_FIGURE_OUTPUT, "property": "clickData"} + assert plotly_figure_result(output, {}) == [] + + def test_ignores_non_dict_values(self): + assert plotly_figure_result(GRAPH_FIGURE_OUTPUT, "not a dict") == [] From acf75a0f3f1ae6aa3bb2cced31e884a309b10f4c Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 16 Apr 2026 13:34:23 -0600 Subject: [PATCH 36/80] Update type names --- dash/mcp/primitives/tools/results/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py index e2f91a67a8..ed21178c0a 100644 --- a/dash/mcp/primitives/tools/results/__init__.py +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -14,7 +14,7 @@ from mcp.types import CallToolResult, TextContent -from dash.types import CallbackDispatchResponse +from dash.types import CallbackExecutionResponse from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter from .result_dataframe import dataframe_result @@ -27,7 +27,7 @@ def format_callback_response( - response: CallbackDispatchResponse, + response: CallbackExecutionResponse, callback: CallbackAdapter, ) -> CallToolResult: """Format a dispatch response as a CallToolResult. From d01d43ae5548e072db81e38c76dc8b83dd24a79d Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 16 Apr 2026 13:43:01 -0600 Subject: [PATCH 37/80] Refactor tool result formatters to use a base class (just as resources do) --- dash/mcp/primitives/tools/results/__init__.py | 24 ++++++------- dash/mcp/primitives/tools/results/base.py | 24 +++++++++++++ .../tools/results/result_dataframe.py | 27 +++++++++----- .../tools/results/result_plotly_figure.py | 35 +++++++++++-------- .../unit/mcp/tools/results/test_dataframe.py | 14 ++++---- .../mcp/tools/results/test_plotly_figure.py | 12 +++---- 6 files changed, 88 insertions(+), 48 deletions(-) create mode 100644 dash/mcp/primitives/tools/results/base.py diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py index ed21178c0a..c0232d6028 100644 --- a/dash/mcp/primitives/tools/results/__init__.py +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -1,10 +1,7 @@ """Tool result formatting for MCP tools/call responses. -Each result formatter shares the same signature: -``(output: MCPOutput, value: Any) -> list[TextContent | ImageContent]`` - -Formatters decide for themselves whether they care about a given output. -The structuredContent is always the full dispatch response. +Each formatter is a ``ResultFormatter`` subclass that can enrich +a tool result with additional content. All formatters are accumulated. """ from __future__ import annotations @@ -17,12 +14,13 @@ from dash.types import CallbackExecutionResponse from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter -from .result_dataframe import dataframe_result -from .result_plotly_figure import plotly_figure_result +from .base import ResultFormatter +from .result_dataframe import DataFrameResult +from .result_plotly_figure import PlotlyFigureResult -_RESULT_FORMATTERS = [ - plotly_figure_result, - dataframe_result, +_RESULT_FORMATTERS: list[type[ResultFormatter]] = [ + PlotlyFigureResult, + DataFrameResult, ] @@ -30,7 +28,7 @@ def format_callback_response( response: CallbackExecutionResponse, callback: CallbackAdapter, ) -> CallToolResult: - """Format a dispatch response as a CallToolResult. + """Format a callback response as a CallToolResult. The response is always returned as structuredContent. Result formatters are called per output property and may add additional @@ -43,8 +41,8 @@ def format_callback_response( resp = response.get("response") or {} for callback_output in callback.outputs: value = resp.get(callback_output["component_id"], {}).get(callback_output["property"]) - for result_fn in _RESULT_FORMATTERS: - content.extend(result_fn(callback_output, value)) + for formatter in _RESULT_FORMATTERS: + content.extend(formatter.format(callback_output, value)) return CallToolResult( content=content, diff --git a/dash/mcp/primitives/tools/results/base.py b/dash/mcp/primitives/tools/results/base.py new file mode 100644 index 0000000000..1f7714ff6b --- /dev/null +++ b/dash/mcp/primitives/tools/results/base.py @@ -0,0 +1,24 @@ +"""Base class for result formatters.""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import ImageContent, TextContent + +from dash.mcp.types import MCPOutput + + +class ResultFormatter: + """A formatter that can enrich an MCP tool result with additional content. + + Subclasses implement ``format`` to return content items (text, images) + for a specific callback output. All formatters are accumulated — every + formatter can add content to the overall tool result. + """ + + @classmethod + def format( + cls, output: MCPOutput, returned_output_value: Any + ) -> list[TextContent | ImageContent]: + raise NotImplementedError diff --git a/dash/mcp/primitives/tools/results/result_dataframe.py b/dash/mcp/primitives/tools/results/result_dataframe.py index 652c31589b..b7113f5d82 100644 --- a/dash/mcp/primitives/tools/results/result_dataframe.py +++ b/dash/mcp/primitives/tools/results/result_dataframe.py @@ -9,10 +9,12 @@ from typing import Any -from mcp.types import TextContent +from mcp.types import ImageContent, TextContent from dash.mcp.types import MCPOutput +from .base import ResultFormatter + MAX_ROWS = 50 _TABULAR_PROPS = { @@ -45,11 +47,20 @@ def _to_markdown_table(rows: list[dict], max_rows: int = MAX_ROWS) -> str: return "\n".join(lines) -def dataframe_result(callback_output: MCPOutput, callback_output_value: Any) -> list: +class DataFrameResult(ResultFormatter): """Produce a markdown table for tabular component output values.""" - key = (callback_output.get("component_type"), callback_output.get("property")) - if key not in _TABULAR_PROPS: - return [] - if not isinstance(callback_output_value, list) or not callback_output_value or not isinstance(callback_output_value[0], dict): - return [] - return [TextContent(type="text", text=_to_markdown_table(callback_output_value))] + + @classmethod + def format( + cls, output: MCPOutput, returned_output_value: Any + ) -> list[TextContent | ImageContent]: + key = (output.get("component_type"), output.get("property")) + if key not in _TABULAR_PROPS: + return [] + if ( + not isinstance(returned_output_value, list) + or not returned_output_value + or not isinstance(returned_output_value[0], dict) + ): + return [] + return [TextContent(type="text", text=_to_markdown_table(returned_output_value))] diff --git a/dash/mcp/primitives/tools/results/result_plotly_figure.py b/dash/mcp/primitives/tools/results/result_plotly_figure.py index d837e17eed..fff3a3de89 100644 --- a/dash/mcp/primitives/tools/results/result_plotly_figure.py +++ b/dash/mcp/primitives/tools/results/result_plotly_figure.py @@ -6,10 +6,12 @@ import logging from typing import Any -from mcp.types import ImageContent +from mcp.types import ImageContent, TextContent from dash.mcp.types import MCPOutput +from .base import ResultFormatter + logger = logging.getLogger(__name__) IMAGE_WIDTH = 700 @@ -35,18 +37,23 @@ def _render_image(figure: Any) -> ImageContent | None: return ImageContent(type="image", data=b64, mimeType="image/png") -def plotly_figure_result(callback_output: MCPOutput, callback_output_value: Any) -> list: +class PlotlyFigureResult(ResultFormatter): """Produce a rendered PNG for Graph.figure output values.""" - if callback_output.get("component_type") != "Graph" or callback_output.get("property") != "figure": - return [] - if not isinstance(callback_output_value, dict): - return [] - - try: - import plotly.graph_objects as go - except ImportError: - return [] - fig = go.Figure(callback_output_value) - image = _render_image(fig) - return [image] if image is not None else [] + @classmethod + def format( + cls, output: MCPOutput, returned_output_value: Any + ) -> list[TextContent | ImageContent]: + if output.get("component_type") != "Graph" or output.get("property") != "figure": + return [] + if not isinstance(returned_output_value, dict): + return [] + + try: + import plotly.graph_objects as go + except ImportError: + return [] + + fig = go.Figure(returned_output_value) + image = _render_image(fig) + return [image] if image is not None else [] diff --git a/tests/unit/mcp/tools/results/test_dataframe.py b/tests/unit/mcp/tools/results/test_dataframe.py index a7f9e42fca..65aef74d31 100644 --- a/tests/unit/mcp/tools/results/test_dataframe.py +++ b/tests/unit/mcp/tools/results/test_dataframe.py @@ -2,7 +2,7 @@ from dash.mcp.primitives.tools.results.result_dataframe import ( MAX_ROWS, - dataframe_result, + DataFrameResult, ) EXPECTED_TABLE = ( @@ -37,26 +37,26 @@ class TestDataframeResult: def test_datatable_data_renders_markdown(self): - result = dataframe_result(DATATABLE_OUTPUT, SAMPLE_ROWS) + result = DataFrameResult.format(DATATABLE_OUTPUT, SAMPLE_ROWS) assert len(result) == 1 assert result[0].text == EXPECTED_TABLE def test_aggrid_rowdata_renders_markdown(self): - result = dataframe_result(AGGRID_OUTPUT, SAMPLE_ROWS) + result = DataFrameResult.format(AGGRID_OUTPUT, SAMPLE_ROWS) assert len(result) == 1 assert result[0].text == EXPECTED_TABLE def test_ignores_non_tabular_props(self): non_tabular = {**DATATABLE_OUTPUT, "property": "columns"} - assert dataframe_result(non_tabular, SAMPLE_ROWS) == [] + assert DataFrameResult.format(non_tabular, SAMPLE_ROWS) == [] def test_ignores_empty_or_non_dict_rows(self): - assert dataframe_result(DATATABLE_OUTPUT, []) == [] - assert dataframe_result(DATATABLE_OUTPUT, ["a", "b"]) == [] + assert DataFrameResult.format(DATATABLE_OUTPUT, []) == [] + assert DataFrameResult.format(DATATABLE_OUTPUT, ["a", "b"]) == [] def test_truncates_large_tables(self): rows = [{"i": n} for n in range(MAX_ROWS + 50)] - result = dataframe_result(DATATABLE_OUTPUT, rows) + result = DataFrameResult.format(DATATABLE_OUTPUT, rows) text = result[0].text assert f"| {MAX_ROWS - 1} |" in text assert f"| {MAX_ROWS} |" not in text diff --git a/tests/unit/mcp/tools/results/test_plotly_figure.py b/tests/unit/mcp/tools/results/test_plotly_figure.py index 8e336ba687..e3c42af303 100644 --- a/tests/unit/mcp/tools/results/test_plotly_figure.py +++ b/tests/unit/mcp/tools/results/test_plotly_figure.py @@ -6,7 +6,7 @@ import pytest from dash.mcp.primitives.tools.results.result_plotly_figure import ( - plotly_figure_result, + PlotlyFigureResult, ) go = pytest.importorskip("plotly.graph_objects") @@ -28,7 +28,7 @@ class TestPlotlyFigureResult: def test_returns_image_when_kaleido_available(self): fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() with patch.object(go.Figure, "to_image", return_value=FAKE_PNG): - result = plotly_figure_result(GRAPH_FIGURE_OUTPUT, fig_dict) + result = PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, fig_dict) assert len(result) == 1 assert result[0].type == "image" assert result[0].data == FAKE_B64 @@ -36,7 +36,7 @@ def test_returns_image_when_kaleido_available(self): def test_returns_empty_when_kaleido_unavailable(self): fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() with patch.object(go.Figure, "to_image", side_effect=ImportError): - result = plotly_figure_result(GRAPH_FIGURE_OUTPUT, fig_dict) + result = PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, fig_dict) assert result == [] def test_ignores_non_graph_components(self): @@ -45,11 +45,11 @@ def test_ignores_non_graph_components(self): "component_type": "Div", "property": "children", } - assert plotly_figure_result(output, {}) == [] + assert PlotlyFigureResult.format(output, {}) == [] def test_ignores_non_figure_props(self): output = {**GRAPH_FIGURE_OUTPUT, "property": "clickData"} - assert plotly_figure_result(output, {}) == [] + assert PlotlyFigureResult.format(output, {}) == [] def test_ignores_non_dict_values(self): - assert plotly_figure_result(GRAPH_FIGURE_OUTPUT, "not a dict") == [] + assert PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, "not a dict") == [] From 92cf27b428120d0bf710028291ec6814a69da6ca Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 16 Apr 2026 17:19:34 -0600 Subject: [PATCH 38/80] lint --- dash/mcp/primitives/tools/results/__init__.py | 4 +++- dash/mcp/primitives/tools/results/result_dataframe.py | 4 +++- dash/mcp/primitives/tools/results/result_plotly_figure.py | 5 ++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py index c0232d6028..ae3517919c 100644 --- a/dash/mcp/primitives/tools/results/__init__.py +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -40,7 +40,9 @@ def format_callback_response( resp = response.get("response") or {} for callback_output in callback.outputs: - value = resp.get(callback_output["component_id"], {}).get(callback_output["property"]) + value = resp.get(callback_output["component_id"], {}).get( + callback_output["property"] + ) for formatter in _RESULT_FORMATTERS: content.extend(formatter.format(callback_output, value)) diff --git a/dash/mcp/primitives/tools/results/result_dataframe.py b/dash/mcp/primitives/tools/results/result_dataframe.py index b7113f5d82..04b1d84b3e 100644 --- a/dash/mcp/primitives/tools/results/result_dataframe.py +++ b/dash/mcp/primitives/tools/results/result_dataframe.py @@ -63,4 +63,6 @@ def format( or not isinstance(returned_output_value[0], dict) ): return [] - return [TextContent(type="text", text=_to_markdown_table(returned_output_value))] + return [ + TextContent(type="text", text=_to_markdown_table(returned_output_value)) + ] diff --git a/dash/mcp/primitives/tools/results/result_plotly_figure.py b/dash/mcp/primitives/tools/results/result_plotly_figure.py index fff3a3de89..ad2c057f89 100644 --- a/dash/mcp/primitives/tools/results/result_plotly_figure.py +++ b/dash/mcp/primitives/tools/results/result_plotly_figure.py @@ -44,7 +44,10 @@ class PlotlyFigureResult(ResultFormatter): def format( cls, output: MCPOutput, returned_output_value: Any ) -> list[TextContent | ImageContent]: - if output.get("component_type") != "Graph" or output.get("property") != "figure": + if ( + output.get("component_type") != "Graph" + or output.get("property") != "figure" + ): return [] if not isinstance(returned_output_value, dict): return [] From 13d0fbfefec3b4f0b816a625b3c9a17afa2fd331 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 22 Apr 2026 15:31:11 -0600 Subject: [PATCH 39/80] Make import top-level --- dash/mcp/primitives/tools/results/result_plotly_figure.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/dash/mcp/primitives/tools/results/result_plotly_figure.py b/dash/mcp/primitives/tools/results/result_plotly_figure.py index ad2c057f89..d3b98376f9 100644 --- a/dash/mcp/primitives/tools/results/result_plotly_figure.py +++ b/dash/mcp/primitives/tools/results/result_plotly_figure.py @@ -6,6 +6,7 @@ import logging from typing import Any +import plotly.graph_objects as go from mcp.types import ImageContent, TextContent from dash.mcp.types import MCPOutput @@ -52,11 +53,6 @@ def format( if not isinstance(returned_output_value, dict): return [] - try: - import plotly.graph_objects as go - except ImportError: - return [] - fig = go.Figure(returned_output_value) image = _render_image(fig) return [image] if image is not None else [] From ef8fc34cbbb82f3eadbd47b1c71faa2cde2c4896 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 23 Apr 2026 14:13:52 -0600 Subject: [PATCH 40/80] Refactor unit tests to conform to existing test patterns --- .../tools/results/test_callback_response.py | 98 ------- .../unit/mcp/tools/results/test_dataframe.py | 63 ----- .../mcp/tools/results/test_plotly_figure.py | 55 ---- .../test_mcp_formatted_output_results.py | 240 ++++++++++++++++++ 4 files changed, 240 insertions(+), 216 deletions(-) delete mode 100644 tests/unit/mcp/tools/results/test_callback_response.py delete mode 100644 tests/unit/mcp/tools/results/test_dataframe.py delete mode 100644 tests/unit/mcp/tools/results/test_plotly_figure.py create mode 100644 tests/unit/mcp/tools/test_mcp_formatted_output_results.py diff --git a/tests/unit/mcp/tools/results/test_callback_response.py b/tests/unit/mcp/tools/results/test_callback_response.py deleted file mode 100644 index ff8cca5e20..0000000000 --- a/tests/unit/mcp/tools/results/test_callback_response.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Tests for the callback response formatter.""" - -from unittest.mock import Mock - -from dash.mcp.primitives.tools.results import format_callback_response - - -def _mock_callback(outputs=None): - cb = Mock() - cb.outputs = outputs or [] - return cb - - -class TestFormatCallbackResponse: - def test_wraps_as_structured_content(self): - response = { - "multi": True, - "response": {"out": {"children": "hello"}}, - } - result = format_callback_response(response, _mock_callback()) - assert result.structuredContent == response - - def test_content_has_json_text_fallback(self): - """Per MCP spec, structuredContent SHOULD include a TextContent fallback.""" - response = {"multi": True, "response": {}} - result = format_callback_response(response, _mock_callback()) - assert len(result.content) >= 1 - assert result.content[0].type == "text" - assert '"multi": true' in result.content[0].text - - def test_is_error_defaults_false(self): - response = {"multi": True, "response": {}} - result = format_callback_response(response, _mock_callback()) - assert result.isError is False - - def test_preserves_side_update(self): - response = { - "multi": True, - "response": {"out": {"children": "x"}}, - "sideUpdate": {"other": {"value": 42}}, - } - result = format_callback_response(response, _mock_callback()) - assert result.structuredContent["sideUpdate"] == {"other": {"value": 42}} - - def test_datatable_result_includes_markdown_table(self): - response = { - "multi": True, - "response": { - "my-table": {"data": [{"name": "Alice", "age": 30}]}, - }, - } - outputs = [ - { - "component_id": "my-table", - "component_type": "DataTable", - "property": "data", - "id_and_prop": "my-table.data", - "initial_value": None, - "tool_name": "update", - } - ] - result = format_callback_response(response, _mock_callback(outputs)) - texts = [c.text for c in result.content if c.type == "text"] - assert any("| name | age |" in t for t in texts) - - def test_plotly_figure_includes_image(self): - from unittest.mock import patch - - try: - import plotly.graph_objects as go - except ImportError: - return - - response = { - "multi": True, - "response": { - "my-graph": { - "figure": { - "data": [{"type": "bar", "x": ["A"], "y": [1]}], - "layout": {}, - } - } - }, - } - outputs = [ - { - "component_id": "my-graph", - "component_type": "Graph", - "property": "figure", - "id_and_prop": "my-graph.figure", - "initial_value": None, - "tool_name": "update", - } - ] - with patch.object(go.Figure, "to_image", return_value=b"\x89PNGfake"): - result = format_callback_response(response, _mock_callback(outputs)) - images = [c for c in result.content if c.type == "image"] - assert len(images) == 1 diff --git a/tests/unit/mcp/tools/results/test_dataframe.py b/tests/unit/mcp/tools/results/test_dataframe.py deleted file mode 100644 index 65aef74d31..0000000000 --- a/tests/unit/mcp/tools/results/test_dataframe.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Tests for the tabular data result formatter.""" - -from dash.mcp.primitives.tools.results.result_dataframe import ( - MAX_ROWS, - DataFrameResult, -) - -EXPECTED_TABLE = ( - "*2 rows \u00d7 2 columns*\n" - "\n" - "| name | age |\n" - "| --- | --- |\n" - "| Alice | 30 |\n" - "| Bob | 25 |" -) - -SAMPLE_ROWS = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}] - -DATATABLE_OUTPUT = { - "component_type": "DataTable", - "property": "data", - "component_id": "t", - "id_and_prop": "t.data", - "initial_value": None, - "tool_name": "update", -} - -AGGRID_OUTPUT = { - "component_type": "AgGrid", - "property": "rowData", - "component_id": "g", - "id_and_prop": "g.rowData", - "initial_value": None, - "tool_name": "update", -} - - -class TestDataframeResult: - def test_datatable_data_renders_markdown(self): - result = DataFrameResult.format(DATATABLE_OUTPUT, SAMPLE_ROWS) - assert len(result) == 1 - assert result[0].text == EXPECTED_TABLE - - def test_aggrid_rowdata_renders_markdown(self): - result = DataFrameResult.format(AGGRID_OUTPUT, SAMPLE_ROWS) - assert len(result) == 1 - assert result[0].text == EXPECTED_TABLE - - def test_ignores_non_tabular_props(self): - non_tabular = {**DATATABLE_OUTPUT, "property": "columns"} - assert DataFrameResult.format(non_tabular, SAMPLE_ROWS) == [] - - def test_ignores_empty_or_non_dict_rows(self): - assert DataFrameResult.format(DATATABLE_OUTPUT, []) == [] - assert DataFrameResult.format(DATATABLE_OUTPUT, ["a", "b"]) == [] - - def test_truncates_large_tables(self): - rows = [{"i": n} for n in range(MAX_ROWS + 50)] - result = DataFrameResult.format(DATATABLE_OUTPUT, rows) - text = result[0].text - assert f"| {MAX_ROWS - 1} |" in text - assert f"| {MAX_ROWS} |" not in text - assert "50 more rows" in text diff --git a/tests/unit/mcp/tools/results/test_plotly_figure.py b/tests/unit/mcp/tools/results/test_plotly_figure.py deleted file mode 100644 index e3c42af303..0000000000 --- a/tests/unit/mcp/tools/results/test_plotly_figure.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Tests for the Plotly figure tool result formatter.""" - -import base64 -from unittest.mock import patch - -import pytest - -from dash.mcp.primitives.tools.results.result_plotly_figure import ( - PlotlyFigureResult, -) - -go = pytest.importorskip("plotly.graph_objects") - -FAKE_PNG = b"\x89PNG\r\n\x1a\nfakedata" -FAKE_B64 = base64.b64encode(FAKE_PNG).decode("ascii") - -GRAPH_FIGURE_OUTPUT = { - "component_type": "Graph", - "property": "figure", - "component_id": "g", - "id_and_prop": "g.figure", - "initial_value": None, - "tool_name": "update", -} - - -class TestPlotlyFigureResult: - def test_returns_image_when_kaleido_available(self): - fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() - with patch.object(go.Figure, "to_image", return_value=FAKE_PNG): - result = PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, fig_dict) - assert len(result) == 1 - assert result[0].type == "image" - assert result[0].data == FAKE_B64 - - def test_returns_empty_when_kaleido_unavailable(self): - fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() - with patch.object(go.Figure, "to_image", side_effect=ImportError): - result = PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, fig_dict) - assert result == [] - - def test_ignores_non_graph_components(self): - output = { - **GRAPH_FIGURE_OUTPUT, - "component_type": "Div", - "property": "children", - } - assert PlotlyFigureResult.format(output, {}) == [] - - def test_ignores_non_figure_props(self): - output = {**GRAPH_FIGURE_OUTPUT, "property": "clickData"} - assert PlotlyFigureResult.format(output, {}) == [] - - def test_ignores_non_dict_values(self): - assert PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, "not a dict") == [] diff --git a/tests/unit/mcp/tools/test_mcp_formatted_output_results.py b/tests/unit/mcp/tools/test_mcp_formatted_output_results.py new file mode 100644 index 0000000000..c255e9ba5c --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_formatted_output_results.py @@ -0,0 +1,240 @@ +"""Formatted tool-output results: structured content + text/image fallbacks. + +Covers: +- ``format_callback_response`` — wraps the callback result as + ``structuredContent`` and delegates to per-output formatters. +- ``DataFrameResult`` — renders DataTable/AgGrid rows as markdown tables. +- ``PlotlyFigureResult`` — renders ``dcc.Graph.figure`` values as PNG images + (via kaleido; skipped gracefully when kaleido is unavailable). +""" + +import base64 +from unittest.mock import Mock, patch + +import plotly.graph_objects as go # type: ignore[import-untyped] + +from dash.mcp.primitives.tools.results import format_callback_response +from dash.mcp.primitives.tools.results.result_dataframe import ( + MAX_ROWS, + DataFrameResult, +) +from dash.mcp.primitives.tools.results.result_plotly_figure import ( + PlotlyFigureResult, +) + + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _mock_callback(outputs=None): + cb = Mock() + cb.outputs = outputs or [] + return cb + + +EXPECTED_TABLE = ( + "*2 rows \u00d7 2 columns*\n" + "\n" + "| name | age |\n" + "| --- | --- |\n" + "| Alice | 30 |\n" + "| Bob | 25 |" +) + +SAMPLE_ROWS = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}] + +DATATABLE_OUTPUT = { + "component_type": "DataTable", + "property": "data", + "component_id": "t", + "id_and_prop": "t.data", + "initial_value": None, + "tool_name": "update", +} + +AGGRID_OUTPUT = { + "component_type": "AgGrid", + "property": "rowData", + "component_id": "g", + "id_and_prop": "g.rowData", + "initial_value": None, + "tool_name": "update", +} + +FAKE_PNG = b"\x89PNG\r\n\x1a\nfakedata" +FAKE_B64 = base64.b64encode(FAKE_PNG).decode("ascii") + +GRAPH_FIGURE_OUTPUT = { + "component_type": "Graph", + "property": "figure", + "component_id": "g", + "id_and_prop": "g.figure", + "initial_value": None, + "tool_name": "update", +} + + +# --------------------------------------------------------------------------- +# format_callback_response +# --------------------------------------------------------------------------- + + +def test_mcpr001_wraps_as_structured_content(): + response = { + "multi": True, + "response": {"out": {"children": "hello"}}, + } + result = format_callback_response(response, _mock_callback()) + assert result.structuredContent == response + + +def test_mcpr002_content_has_json_text_fallback(): + """Per MCP spec, structuredContent SHOULD include a TextContent fallback.""" + response = {"multi": True, "response": {}} + result = format_callback_response(response, _mock_callback()) + assert len(result.content) >= 1 + assert result.content[0].type == "text" + assert '"multi": true' in result.content[0].text + + +def test_mcpr003_is_error_defaults_false(): + response = {"multi": True, "response": {}} + result = format_callback_response(response, _mock_callback()) + assert result.isError is False + + +def test_mcpr004_preserves_side_update(): + response = { + "multi": True, + "response": {"out": {"children": "x"}}, + "sideUpdate": {"other": {"value": 42}}, + } + result = format_callback_response(response, _mock_callback()) + assert result.structuredContent["sideUpdate"] == {"other": {"value": 42}} + + +def test_mcpr005_datatable_result_includes_markdown_table(): + response = { + "multi": True, + "response": { + "my-table": {"data": [{"name": "Alice", "age": 30}]}, + }, + } + outputs = [ + { + "component_id": "my-table", + "component_type": "DataTable", + "property": "data", + "id_and_prop": "my-table.data", + "initial_value": None, + "tool_name": "update", + } + ] + result = format_callback_response(response, _mock_callback(outputs)) + texts = [c.text for c in result.content if c.type == "text"] + assert any("| name | age |" in t for t in texts) + + +def test_mcpr006_plotly_figure_includes_image(): + response = { + "multi": True, + "response": { + "my-graph": { + "figure": { + "data": [{"type": "bar", "x": ["A"], "y": [1]}], + "layout": {}, + } + } + }, + } + outputs = [ + { + "component_id": "my-graph", + "component_type": "Graph", + "property": "figure", + "id_and_prop": "my-graph.figure", + "initial_value": None, + "tool_name": "update", + } + ] + with patch.object(go.Figure, "to_image", return_value=b"\x89PNGfake"): + result = format_callback_response(response, _mock_callback(outputs)) + images = [c for c in result.content if c.type == "image"] + assert len(images) == 1 + + +# --------------------------------------------------------------------------- +# DataFrameResult (DataTable / AgGrid markdown rendering) +# --------------------------------------------------------------------------- + + +def test_mcpr007_datatable_data_renders_markdown(): + result = DataFrameResult.format(DATATABLE_OUTPUT, SAMPLE_ROWS) + assert len(result) == 1 + assert result[0].text == EXPECTED_TABLE + + +def test_mcpr008_aggrid_rowdata_renders_markdown(): + result = DataFrameResult.format(AGGRID_OUTPUT, SAMPLE_ROWS) + assert len(result) == 1 + assert result[0].text == EXPECTED_TABLE + + +def test_mcpr009_ignores_non_tabular_props(): + non_tabular = {**DATATABLE_OUTPUT, "property": "columns"} + assert DataFrameResult.format(non_tabular, SAMPLE_ROWS) == [] + + +def test_mcpr010_ignores_empty_or_non_dict_rows(): + assert DataFrameResult.format(DATATABLE_OUTPUT, []) == [] + assert DataFrameResult.format(DATATABLE_OUTPUT, ["a", "b"]) == [] + + +def test_mcpr011_truncates_large_tables(): + rows = [{"i": n} for n in range(MAX_ROWS + 50)] + result = DataFrameResult.format(DATATABLE_OUTPUT, rows) + text = result[0].text + assert f"| {MAX_ROWS - 1} |" in text + assert f"| {MAX_ROWS} |" not in text + assert "50 more rows" in text + + +# --------------------------------------------------------------------------- +# PlotlyFigureResult (Graph.figure → PNG image) +# --------------------------------------------------------------------------- + + +def test_mcpr012_returns_image_when_kaleido_available(): + fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() + with patch.object(go.Figure, "to_image", return_value=FAKE_PNG): + result = PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, fig_dict) + assert len(result) == 1 + assert result[0].type == "image" + assert result[0].data == FAKE_B64 + + +def test_mcpr013_returns_empty_when_kaleido_unavailable(): + fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() + with patch.object(go.Figure, "to_image", side_effect=ImportError): + result = PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, fig_dict) + assert result == [] + + +def test_mcpr014_ignores_non_graph_components(): + output = { + **GRAPH_FIGURE_OUTPUT, + "component_type": "Div", + "property": "children", + } + assert PlotlyFigureResult.format(output, {}) == [] + + +def test_mcpr015_ignores_non_figure_props(): + output = {**GRAPH_FIGURE_OUTPUT, "property": "clickData"} + assert PlotlyFigureResult.format(output, {}) == [] + + +def test_mcpr016_ignores_non_dict_values(): + assert PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, "not a dict") == [] From 88589cbfd178c22df79f035c92c7f61a671efec4 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 30 Apr 2026 09:20:37 -0600 Subject: [PATCH 41/80] Use TABULAR/PLOTLY_FIGURE PropRole.matches() in formatters --- dash/mcp/primitives/tools/prop_roles.py | 2 +- dash/mcp/primitives/tools/results/result_dataframe.py | 9 ++------- .../mcp/primitives/tools/results/result_plotly_figure.py | 6 +++--- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/dash/mcp/primitives/tools/prop_roles.py b/dash/mcp/primitives/tools/prop_roles.py index eface31147..64fdc8f76d 100644 --- a/dash/mcp/primitives/tools/prop_roles.py +++ b/dash/mcp/primitives/tools/prop_roles.py @@ -100,7 +100,7 @@ def _compute_dropdown_value_schema(param: MCPInput) -> dict[str, Any]: description="Returns formatted text", ) -GENERIC_FIGURE = PropRole( +PLOTLY_FIGURE = PropRole( identifiers={(ANY_COMPONENT, "figure")}, description="Returns chart/visualization data", input_schema={ diff --git a/dash/mcp/primitives/tools/results/result_dataframe.py b/dash/mcp/primitives/tools/results/result_dataframe.py index 04b1d84b3e..1d93ca74aa 100644 --- a/dash/mcp/primitives/tools/results/result_dataframe.py +++ b/dash/mcp/primitives/tools/results/result_dataframe.py @@ -13,15 +13,11 @@ from dash.mcp.types import MCPOutput +from ..prop_roles import TABULAR from .base import ResultFormatter MAX_ROWS = 50 -_TABULAR_PROPS = { - ("DataTable", "data"), - ("AgGrid", "rowData"), -} - def _to_markdown_table(rows: list[dict], max_rows: int = MAX_ROWS) -> str: """Render a list of row dicts as a markdown table.""" @@ -54,8 +50,7 @@ class DataFrameResult(ResultFormatter): def format( cls, output: MCPOutput, returned_output_value: Any ) -> list[TextContent | ImageContent]: - key = (output.get("component_type"), output.get("property")) - if key not in _TABULAR_PROPS: + if not TABULAR.matches(output.get("component_type"), output.get("property")): return [] if ( not isinstance(returned_output_value, list) diff --git a/dash/mcp/primitives/tools/results/result_plotly_figure.py b/dash/mcp/primitives/tools/results/result_plotly_figure.py index d3b98376f9..4a62131573 100644 --- a/dash/mcp/primitives/tools/results/result_plotly_figure.py +++ b/dash/mcp/primitives/tools/results/result_plotly_figure.py @@ -11,6 +11,7 @@ from dash.mcp.types import MCPOutput +from ..prop_roles import PLOTLY_FIGURE from .base import ResultFormatter logger = logging.getLogger(__name__) @@ -45,9 +46,8 @@ class PlotlyFigureResult(ResultFormatter): def format( cls, output: MCPOutput, returned_output_value: Any ) -> list[TextContent | ImageContent]: - if ( - output.get("component_type") != "Graph" - or output.get("property") != "figure" + if not PLOTLY_FIGURE.matches( + output.get("component_type"), output.get("property") ): return [] if not isinstance(returned_output_value, dict): From 55f6514b48a7cce05f3bfb6727559957ca8fbbd6 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 30 Apr 2026 10:02:11 -0600 Subject: [PATCH 42/80] Tighten formatter types: required-field access + plotly stub ignore --- dash/mcp/primitives/tools/results/result_dataframe.py | 2 +- dash/mcp/primitives/tools/results/result_plotly_figure.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/dash/mcp/primitives/tools/results/result_dataframe.py b/dash/mcp/primitives/tools/results/result_dataframe.py index 1d93ca74aa..66290791a9 100644 --- a/dash/mcp/primitives/tools/results/result_dataframe.py +++ b/dash/mcp/primitives/tools/results/result_dataframe.py @@ -50,7 +50,7 @@ class DataFrameResult(ResultFormatter): def format( cls, output: MCPOutput, returned_output_value: Any ) -> list[TextContent | ImageContent]: - if not TABULAR.matches(output.get("component_type"), output.get("property")): + if not TABULAR.matches(output.get("component_type"), output["property"]): return [] if ( not isinstance(returned_output_value, list) diff --git a/dash/mcp/primitives/tools/results/result_plotly_figure.py b/dash/mcp/primitives/tools/results/result_plotly_figure.py index 4a62131573..63e4adbb31 100644 --- a/dash/mcp/primitives/tools/results/result_plotly_figure.py +++ b/dash/mcp/primitives/tools/results/result_plotly_figure.py @@ -6,7 +6,7 @@ import logging from typing import Any -import plotly.graph_objects as go +import plotly.graph_objects as go # type: ignore[import-untyped] from mcp.types import ImageContent, TextContent from dash.mcp.types import MCPOutput @@ -46,9 +46,7 @@ class PlotlyFigureResult(ResultFormatter): def format( cls, output: MCPOutput, returned_output_value: Any ) -> list[TextContent | ImageContent]: - if not PLOTLY_FIGURE.matches( - output.get("component_type"), output.get("property") - ): + if not PLOTLY_FIGURE.matches(output.get("component_type"), output["property"]): return [] if not isinstance(returned_output_value, dict): return [] From 0be83d3f623e3d86c5b3792d6b23eb52fabe4a9f Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Fri, 8 May 2026 09:35:47 -0600 Subject: [PATCH 43/80] code review feedback --- .../tools/results/result_dataframe.py | 24 +++++++++---------- .../tools/results/result_plotly_figure.py | 5 ++-- .../test_mcp_formatted_output_results.py | 14 +++++------ 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/dash/mcp/primitives/tools/results/result_dataframe.py b/dash/mcp/primitives/tools/results/result_dataframe.py index 66290791a9..8a2130387c 100644 --- a/dash/mcp/primitives/tools/results/result_dataframe.py +++ b/dash/mcp/primitives/tools/results/result_dataframe.py @@ -9,7 +9,7 @@ from typing import Any -from mcp.types import ImageContent, TextContent +from mcp.types import TextContent from dash.mcp.types import MCPOutput @@ -22,23 +22,23 @@ def _to_markdown_table(rows: list[dict], max_rows: int = MAX_ROWS) -> str: """Render a list of row dicts as a markdown table.""" columns = list(rows[0].keys()) - total = len(rows) + total_rows = len(rows) lines: list[str] = [] - lines.append(f"*{total} rows \u00d7 {len(columns)} columns*") + lines.append(f"*{total_rows} rows \u00d7 {len(columns)} columns*") lines.append("") - lines.append("| " + " | ".join(columns) + " |") - lines.append("| " + " | ".join("---" for _ in columns) + " |") + lines.append(" | ".join(columns)) + lines.append(" | ".join("---" for _ in columns)) for row in rows[:max_rows]: cells = [ str(row.get(col, "")).replace("|", "\\|").replace("\n", " ") for col in columns ] - lines.append("| " + " | ".join(cells) + " |") + lines.append(" | ".join(cells)) - if total > max_rows: - lines.append(f"\n(\u2026 {total - max_rows} more rows)") + if total_rows > max_rows: + lines.append(f"\n(\u2026 {total_rows - max_rows} more rows)") return "\n".join(lines) @@ -47,14 +47,12 @@ class DataFrameResult(ResultFormatter): """Produce a markdown table for tabular component output values.""" @classmethod - def format( - cls, output: MCPOutput, returned_output_value: Any - ) -> list[TextContent | ImageContent]: + def format(cls, output: MCPOutput, returned_output_value: Any) -> list[TextContent]: if not TABULAR.matches(output.get("component_type"), output["property"]): return [] if ( - not isinstance(returned_output_value, list) - or not returned_output_value + not returned_output_value + or not isinstance(returned_output_value, list) or not isinstance(returned_output_value[0], dict) ): return [] diff --git a/dash/mcp/primitives/tools/results/result_plotly_figure.py b/dash/mcp/primitives/tools/results/result_plotly_figure.py index 63e4adbb31..b7f4273933 100644 --- a/dash/mcp/primitives/tools/results/result_plotly_figure.py +++ b/dash/mcp/primitives/tools/results/result_plotly_figure.py @@ -21,7 +21,8 @@ def _render_image(figure: Any) -> ImageContent | None: - """Render the figure as a base64 PNG ImageContent. + """ + Render the figure as a base64 PNG ImageContent. Returns None if kaleido is not installed. """ @@ -48,7 +49,7 @@ def format( ) -> list[TextContent | ImageContent]: if not PLOTLY_FIGURE.matches(output.get("component_type"), output["property"]): return [] - if not isinstance(returned_output_value, dict): + if not returned_output_value or not isinstance(returned_output_value, dict): return [] fig = go.Figure(returned_output_value) diff --git a/tests/unit/mcp/tools/test_mcp_formatted_output_results.py b/tests/unit/mcp/tools/test_mcp_formatted_output_results.py index c255e9ba5c..da931009cd 100644 --- a/tests/unit/mcp/tools/test_mcp_formatted_output_results.py +++ b/tests/unit/mcp/tools/test_mcp_formatted_output_results.py @@ -37,10 +37,10 @@ def _mock_callback(outputs=None): EXPECTED_TABLE = ( "*2 rows \u00d7 2 columns*\n" "\n" - "| name | age |\n" - "| --- | --- |\n" - "| Alice | 30 |\n" - "| Bob | 25 |" + "name | age\n" + "--- | ---\n" + "Alice | 30\n" + "Bob | 25" ) SAMPLE_ROWS = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}] @@ -134,7 +134,7 @@ def test_mcpr005_datatable_result_includes_markdown_table(): ] result = format_callback_response(response, _mock_callback(outputs)) texts = [c.text for c in result.content if c.type == "text"] - assert any("| name | age |" in t for t in texts) + assert any("name | age" in t for t in texts) def test_mcpr006_plotly_figure_includes_image(): @@ -196,8 +196,8 @@ def test_mcpr011_truncates_large_tables(): rows = [{"i": n} for n in range(MAX_ROWS + 50)] result = DataFrameResult.format(DATATABLE_OUTPUT, rows) text = result[0].text - assert f"| {MAX_ROWS - 1} |" in text - assert f"| {MAX_ROWS} |" not in text + assert f"\n{MAX_ROWS - 1}\n" in text + assert f"\n{MAX_ROWS}\n" not in text assert "50 more rows" in text From bcc86c80e2432ccd2392b905d449aee9e4da07f1 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 15:02:04 -0600 Subject: [PATCH 44/80] Add get_dash_component tool and callback tool dispatch pipeline --- dash/mcp/primitives/tools/__init__.py | 43 ++++++ .../tools/tool_get_dash_component.py | 123 ++++++++++++++++ dash/mcp/primitives/tools/tools_callbacks.py | 47 ++++++ tests/unit/mcp/conftest.py | 12 +- .../mcp/tools/test_tool_get_dash_component.py | 117 +++++++++++++++ tests/unit/mcp/tools/test_tools_callbacks.py | 137 ++++++++++++++++++ 6 files changed, 477 insertions(+), 2 deletions(-) create mode 100644 dash/mcp/primitives/tools/tool_get_dash_component.py create mode 100644 dash/mcp/primitives/tools/tools_callbacks.py create mode 100644 tests/unit/mcp/tools/test_tool_get_dash_component.py create mode 100644 tests/unit/mcp/tools/test_tools_callbacks.py diff --git a/dash/mcp/primitives/tools/__init__.py b/dash/mcp/primitives/tools/__init__.py index e69de29bb2..64f89dc3d0 100644 --- a/dash/mcp/primitives/tools/__init__.py +++ b/dash/mcp/primitives/tools/__init__.py @@ -0,0 +1,43 @@ +"""MCP tool listing and call handling. + +Each tool module exports: +- ``get_tool_names() -> set[str]`` +- ``get_tools() -> list[Tool]`` +- ``call_tool(tool_name, arguments) -> CallToolResult`` + +The __init__ assembles the list and dispatches calls by name. +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import CallToolResult, ListToolsResult + +from dash.mcp.types import ToolNotFoundError + +from . import tool_get_dash_component as _get_component +from . import tools_callbacks as _callbacks + +_TOOL_MODULES = [_callbacks, _get_component] + + +def list_tools() -> ListToolsResult: + """Build the MCP tools/list response.""" + tools = [] + for mod in _TOOL_MODULES: + tools.extend(mod.get_tools()) + return ListToolsResult(tools=tools) + + +def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + """Dispatch a tools/call request by tool name.""" + for mod in _TOOL_MODULES: + if tool_name in mod.get_tool_names(): + result = mod.call_tool(tool_name, arguments) + return result + raise ToolNotFoundError( + f"Tool not found: {tool_name}." + " The app's callbacks may have changed." + " Please call tools/list to refresh your tool list." + ) diff --git a/dash/mcp/primitives/tools/tool_get_dash_component.py b/dash/mcp/primitives/tools/tool_get_dash_component.py new file mode 100644 index 0000000000..8584242333 --- /dev/null +++ b/dash/mcp/primitives/tools/tool_get_dash_component.py @@ -0,0 +1,123 @@ +"""Built-in tool: get_dash_component.""" + +from __future__ import annotations + +import json +from typing import Any + +from mcp.types import CallToolResult, TextContent, Tool +from pydantic import Field, TypeAdapter +from typing_extensions import Annotated, NotRequired, TypedDict + +from dash import get_app +from dash._layout_utils import find_component +from dash.mcp.types import ComponentPropertyInfo, ComponentQueryResult + + +class _ComponentQueryInput(TypedDict): + component_id: Annotated[str, Field(description="The component ID to query")] + property: NotRequired[ + Annotated[ + str, + Field( + description="The property name to read (e.g. 'options', 'value'). Omit to list all defined properties." + ), + ] + ] + + +_INPUT_SCHEMA = TypeAdapter(_ComponentQueryInput).json_schema() +_OUTPUT_SCHEMA = TypeAdapter(ComponentQueryResult).json_schema() + +NAME = "get_dash_component" + + +def get_tool_names() -> set[str]: + return {NAME} + + +def get_tools() -> list[Tool]: + return [_build_tool()] + + +def _build_tool() -> Tool: + return Tool( + name=NAME, + description=( + "Get a component's properties, values, and tool relationships. " + "If property is omitted, returns all defined properties. " + "If property is specified, returns only that property. " + "See the dash://components resource for available component IDs." + ), + inputSchema=_INPUT_SCHEMA, + outputSchema=_OUTPUT_SCHEMA, + ) + + +def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + comp_id = arguments.get("component_id", "") + if not comp_id: + raise ValueError("component_id is required") + + prop_filter = arguments.get("property", "") + component = find_component(comp_id) + + if component is None: + callback_map = get_app().mcp_callback_map + rendering_tools = [ + cb.tool_name + for cb in callback_map + if any(out["component_id"] == comp_id for out in cb.outputs) + ] + msg = f"Component '{comp_id}' not found in static layout." + if rendering_tools: + msg += f" However, the following tools would modify it: {rendering_tools}." + msg += " Use the dash://components resource to see statically available component IDs." + return CallToolResult( + content=[TextContent(type="text", text=msg)], + isError=True, + ) + + callback_map = get_app().mcp_callback_map + + properties: dict[str, ComponentPropertyInfo] = {} + for prop_name in getattr(component, "_prop_names", []): + if prop_filter and prop_name != prop_filter: + continue + + value = callback_map.get_initial_value(f"{comp_id}.{prop_name}") + if value is None: + value = getattr(component, prop_name, None) + if value is None: + continue + + modified_by: list[str] = [] + input_to: list[str] = [] + id_and_prop = f"{comp_id}.{prop_name}" + for cb in callback_map: + for out in cb.outputs: + if out["id_and_prop"] == id_and_prop: + modified_by.append(cb.tool_name) + for inp in cb.inputs: + if inp["id_and_prop"] == id_and_prop: + input_to.append(cb.tool_name) + + properties[prop_name] = ComponentPropertyInfo( + initial_value=value, + modified_by_tool=modified_by, + input_to_tool=input_to, + ) + + labels = callback_map.component_label_map.get(comp_id, []) + + structured: ComponentQueryResult = ComponentQueryResult( + component_id=comp_id, + component_type=type(component).__name__, + label=labels if labels else None, + properties=properties, + ) + + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(structured, default=str))], + structuredContent=structured, + ) diff --git a/dash/mcp/primitives/tools/tools_callbacks.py b/dash/mcp/primitives/tools/tools_callbacks.py new file mode 100644 index 0000000000..ba08795d35 --- /dev/null +++ b/dash/mcp/primitives/tools/tools_callbacks.py @@ -0,0 +1,47 @@ +"""Dynamic callback tools for MCP. + +Handles listing, naming, and executing callback-based tools. +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import CallToolResult, TextContent, Tool + +from dash import get_app +from dash.mcp.types import CallbackExecutionError, ToolNotFoundError + +from .results import format_callback_response + + +def get_tool_names() -> set[str]: + return get_app().mcp_callback_map.tool_names + + +def get_tools() -> list[Tool]: + """Return one Tool per server-callable callback.""" + return get_app().mcp_callback_map.as_mcp_tools() + + +def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + """Execute a callback tool by name.""" + from .callback_utils import run_callback + + callback_map = get_app().mcp_callback_map + cb = callback_map.find_by_tool_name(tool_name) + if cb is None: + raise ToolNotFoundError( + f"Tool not found: {tool_name}." + " The app's callbacks may have changed." + " Please call tools/list to refresh your tool list." + ) + + try: + dispatch_response = run_callback(cb, arguments) + except CallbackExecutionError as e: + return CallToolResult( + content=[TextContent(type="text", text=str(e))], + isError=True, + ) + return format_callback_response(dispatch_response, cb) diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py index 7f85e4af9d..38b2f8c00e 100644 --- a/tests/unit/mcp/conftest.py +++ b/tests/unit/mcp/conftest.py @@ -9,6 +9,7 @@ if sys.version_info < (3, 10): collect_ignore_glob.append("*") else: + from dash.mcp.primitives.tools import call_tool, list_tools # pylint: disable=wrong-import-position from dash.mcp.primitives.tools.callback_adapter_collection import ( # pylint: disable=wrong-import-position CallbackAdapterCollection, ) @@ -46,10 +47,10 @@ def update_output(value): def _tools_list(app): - """Return tools as Tool objects via as_mcp_tools().""" + """Return all tools (callbacks + builtins) as Tool objects.""" _setup_mcp(app) with app.server.test_request_context(): - return app.mcp_callback_map.as_mcp_tools() + return list_tools().tools def _user_tool(tools): @@ -85,3 +86,10 @@ def _desc_for(tool, param_name=None): if param_name is None: param_name = next(iter(props)) return props[param_name].get("description", "") + + +def _call_tool(app, tool_name, arguments=None): + """Call a tool via the dispatch pipeline and return the CallToolResult.""" + _setup_mcp(app) + with app.server.test_request_context(): + return call_tool(tool_name, arguments or {}) diff --git a/tests/unit/mcp/tools/test_tool_get_dash_component.py b/tests/unit/mcp/tools/test_tool_get_dash_component.py new file mode 100644 index 0000000000..5a8a454068 --- /dev/null +++ b/tests/unit/mcp/tools/test_tool_get_dash_component.py @@ -0,0 +1,117 @@ +"""Tests for the get_dash_component built-in tool.""" + +from dash import Dash, Input, Output, dcc, html + +from tests.unit.mcp.conftest import _call_tool, _make_app, _tools_list + + +class TestGetDashComponent: + def test_present_in_tools_list(self): + app = _make_app() + tool_names = [t.name for t in _tools_list(app)] + assert "get_dash_component" in tool_names + + def test_returns_structured_output_with_prop(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="my-dd", options=["a", "b"], value="b"), + ] + ) + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + "property": "value", + }, + ) + sc = result.structuredContent + assert sc["component_id"] == "my-dd" + assert sc["component_type"] == "Dropdown" + assert "value" in sc["properties"] + assert sc["properties"]["value"]["initial_value"] == "b" + assert "options" not in sc["properties"] + + def test_returns_all_props_without_property(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="my-dd", options=["a", "b"], value="b"), + ] + ) + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + }, + ) + sc = result.structuredContent + assert "options" in sc["properties"] + assert "value" in sc["properties"] + assert sc["properties"]["value"]["initial_value"] == "b" + + def test_includes_label(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label("Pick one", htmlFor="my-dd"), + dcc.Dropdown(id="my-dd", options=["a", "b"], value="a"), + ] + ) + + @app.callback(Output("my-dd", "value"), Input("my-dd", "options")) + def noop(o): + return "a" + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + }, + ) + sc = result.structuredContent + assert sc["label"] == ["Pick one"] + + def test_includes_tool_references(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val): + return val + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "dd", + "property": "value", + }, + ) + prop_info = result.structuredContent["properties"]["value"] + assert "update" in prop_info["input_to_tool"] + + def test_missing_id_returns_hint(self): + app = _make_app() + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "nonexistent", + "property": "value", + }, + ) + text = result.content[0].text + assert "nonexistent" in text + assert "not found" in text + assert "dash://components" in text diff --git a/tests/unit/mcp/tools/test_tools_callbacks.py b/tests/unit/mcp/tools/test_tools_callbacks.py new file mode 100644 index 0000000000..751f85bedd --- /dev/null +++ b/tests/unit/mcp/tools/test_tools_callbacks.py @@ -0,0 +1,137 @@ +"""Tool definition tests — MCP spec compliance and Dash conventions. + +Verifies that generated tools conform to the MCP specification (2025-11-25) +and Dash-specific conventions. Focuses on shape/structure, not inputSchema +values (those are covered by input_schemas/). + +Reference: https://modelcontextprotocol.io/specification/2025-11-25/server/tools +""" + +import re + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) +from dash.mcp.primitives.tools.descriptions import build_tool_description + +from tests.unit.mcp.conftest import ( + _make_app, + _tools_list, +) + +_TOOL_NAME_RE = re.compile(r"^[A-Za-z0-9_\-.]+$") + + +class TestToolSpecCompliance: + """Every tool must conform to the MCP 2025-11-25 specification.""" + + def test_all_tools_conform_to_mcp_spec(self): + tools = _tools_list(_make_app()) + names = [t.name for t in tools] + + assert len(names) == len(set(names)), f"Duplicate tool names: {names}" + + for tool in tools: + assert tool.name + assert tool.inputSchema + assert 1 <= len(tool.name) <= 128 + assert _TOOL_NAME_RE.match(tool.name), f"Invalid tool name: {tool.name}" + + schema = tool.inputSchema + assert isinstance(schema, dict) + assert schema.get("type") == "object" + assert isinstance(schema.get("properties", {}), dict) + + required = set(schema.get("required", [])) + props = set(schema.get("properties", {}).keys()) + assert ( + required <= props + ), f"{tool.name}: required {required - props} not in properties" + + +class TestBuiltinToolDefinitions: + def _tools(self): + return _tools_list(_make_app()) + + def _builtin(self, name): + return next(t for t in self._tools() if t.name == name) + + def test_query_component_always_present(self): + names = {t.name for t in self._tools()} + assert "get_dash_component" in names + + def test_query_component_has_required_params(self): + tool = self._builtin("get_dash_component") + assert "component_id" in tool.inputSchema["properties"] + assert "property" in tool.inputSchema["properties"] + assert set(tool.inputSchema.get("required", [])) == {"component_id"} + + +class TestSanitizeToolName: + def test_simple_name(self): + assert ( + CallbackAdapterCollection._sanitize_name("update_output") == "update_output" + ) + + def test_special_characters_replaced(self): + assert ( + CallbackAdapterCollection._sanitize_name("my-func.name") == "my_func_name" + ) + + def test_leading_digit(self): + assert CallbackAdapterCollection._sanitize_name("123func") == "cb_123func" + + def test_empty_name(self): + assert CallbackAdapterCollection._sanitize_name("") == "unnamed_callback" + + def test_consecutive_underscores_collapsed(self): + assert CallbackAdapterCollection._sanitize_name("a---b___c") == "a_b_c" + + def test_long_name_truncated_to_64_chars(self): + result = CallbackAdapterCollection._sanitize_name("a" * 200) + assert len(result) <= 64 + assert result[-8:].isalnum() + + def test_long_name_uniqueness(self): + result_a = CallbackAdapterCollection._sanitize_name("a" * 200) + result_b = CallbackAdapterCollection._sanitize_name("b" * 200) + assert result_a != result_b + + def test_short_name_not_truncated(self): + assert CallbackAdapterCollection._sanitize_name("short_name") == "short_name" + + +class TestOutputSemanticSummary: + """Test the _OUTPUT_SEMANTICS mapping in description_outputs.py. + + Other description tests (docstring, output target, multi-output) are + covered by TestTool in test_callback_adapter.py using real adapters. + """ + + @staticmethod + def _adapter_with_outputs(outputs, docstring=None): + from unittest.mock import Mock + adapter = Mock() + adapter.outputs = outputs + adapter._docstring = docstring + return adapter + + @staticmethod + def _out(comp_id, prop, comp_type=None): + return { + "id_and_prop": f"{comp_id}.{prop}", + "component_id": comp_id, + "property": prop, + "component_type": comp_type, + "initial_value": None, + } + + def test_semantic_summary_with_component_type(self): + adapter = self._adapter_with_outputs([self._out("my-graph", "figure", "Graph")]) + desc = build_tool_description(adapter) + assert "Returns chart/visualization data" in desc + + def test_semantic_summary_fallback_by_property(self): + adapter = self._adapter_with_outputs([self._out("unknown-id", "figure")]) + desc = build_tool_description(adapter) + assert "Returns chart/visualization data" in desc From a95b3bf15692dc8953427c631d330723b975a01e Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 16 Apr 2026 17:16:43 -0600 Subject: [PATCH 45/80] Refactor tools to use a base class (just as resources do) --- dash/mcp/primitives/tools/__init__.py | 33 ++-- dash/mcp/primitives/tools/base.py | 28 +++ .../tools/tool_get_dash_component.py | 170 +++++++++--------- dash/mcp/primitives/tools/tools_callbacks.py | 67 +++---- 4 files changed, 165 insertions(+), 133 deletions(-) create mode 100644 dash/mcp/primitives/tools/base.py diff --git a/dash/mcp/primitives/tools/__init__.py b/dash/mcp/primitives/tools/__init__.py index 64f89dc3d0..7fa1f4aefb 100644 --- a/dash/mcp/primitives/tools/__init__.py +++ b/dash/mcp/primitives/tools/__init__.py @@ -1,12 +1,4 @@ -"""MCP tool listing and call handling. - -Each tool module exports: -- ``get_tool_names() -> set[str]`` -- ``get_tools() -> list[Tool]`` -- ``call_tool(tool_name, arguments) -> CallToolResult`` - -The __init__ assembles the list and dispatches calls by name. -""" +"""MCP tool listing and call handling.""" from __future__ import annotations @@ -16,26 +8,29 @@ from dash.mcp.types import ToolNotFoundError -from . import tool_get_dash_component as _get_component -from . import tools_callbacks as _callbacks +from .base import MCPToolProvider +from .tool_get_dash_component import GetDashComponentTool +from .tools_callbacks import CallbackTools -_TOOL_MODULES = [_callbacks, _get_component] +_TOOL_PROVIDERS: list[type[MCPToolProvider]] = [ + CallbackTools, + GetDashComponentTool, +] def list_tools() -> ListToolsResult: """Build the MCP tools/list response.""" tools = [] - for mod in _TOOL_MODULES: - tools.extend(mod.get_tools()) + for provider in _TOOL_PROVIDERS: + tools.extend(provider.list_tools()) return ListToolsResult(tools=tools) def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: - """Dispatch a tools/call request by tool name.""" - for mod in _TOOL_MODULES: - if tool_name in mod.get_tool_names(): - result = mod.call_tool(tool_name, arguments) - return result + """Route a tools/call request by tool name.""" + for provider in _TOOL_PROVIDERS: + if tool_name in provider.get_tool_names(): + return provider.call_tool(tool_name, arguments) raise ToolNotFoundError( f"Tool not found: {tool_name}." " The app's callbacks may have changed." diff --git a/dash/mcp/primitives/tools/base.py b/dash/mcp/primitives/tools/base.py new file mode 100644 index 0000000000..60fa7374d6 --- /dev/null +++ b/dash/mcp/primitives/tools/base.py @@ -0,0 +1,28 @@ +"""Base class for MCP tool providers.""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import CallToolResult, Tool + + +class MCPToolProvider: + """A provider of one or more MCP tools. + + Subclasses implement ``list_tools`` to return the tools they provide, + ``get_tool_names`` to advertise those names for routing, and + ``call_tool`` to execute a tool by name. + """ + + @classmethod + def get_tool_names(cls) -> set[str]: + raise NotImplementedError + + @classmethod + def list_tools(cls) -> list[Tool]: + raise NotImplementedError + + @classmethod + def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + raise NotImplementedError diff --git a/dash/mcp/primitives/tools/tool_get_dash_component.py b/dash/mcp/primitives/tools/tool_get_dash_component.py index 8584242333..f41c9bec68 100644 --- a/dash/mcp/primitives/tools/tool_get_dash_component.py +++ b/dash/mcp/primitives/tools/tool_get_dash_component.py @@ -13,6 +13,8 @@ from dash._layout_utils import find_component from dash.mcp.types import ComponentPropertyInfo, ComponentQueryResult +from .base import MCPToolProvider + class _ComponentQueryInput(TypedDict): component_id: Annotated[str, Field(description="The component ID to query")] @@ -32,92 +34,94 @@ class _ComponentQueryInput(TypedDict): NAME = "get_dash_component" -def get_tool_names() -> set[str]: - return {NAME} - - -def get_tools() -> list[Tool]: - return [_build_tool()] - - -def _build_tool() -> Tool: - return Tool( - name=NAME, - description=( - "Get a component's properties, values, and tool relationships. " - "If property is omitted, returns all defined properties. " - "If property is specified, returns only that property. " - "See the dash://components resource for available component IDs." - ), - inputSchema=_INPUT_SCHEMA, - outputSchema=_OUTPUT_SCHEMA, - ) - - -def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: - comp_id = arguments.get("component_id", "") - if not comp_id: - raise ValueError("component_id is required") +class GetDashComponentTool(MCPToolProvider): + """Inspects a component's properties and its tool relationships.""" + + @classmethod + def get_tool_names(cls) -> set[str]: + return {NAME} + + @classmethod + def list_tools(cls) -> list[Tool]: + return [ + Tool( + name=NAME, + description=( + "Get a component's properties, values, and tool relationships. " + "If property is omitted, returns all defined properties. " + "If property is specified, returns only that property. " + "See the dash://components resource for available component IDs." + ), + inputSchema=_INPUT_SCHEMA, + outputSchema=_OUTPUT_SCHEMA, + ) + ] - prop_filter = arguments.get("property", "") - component = find_component(comp_id) + @classmethod + def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + comp_id = arguments.get("component_id", "") + if not comp_id: + raise ValueError("component_id is required") + + prop_filter = arguments.get("property", "") + component = find_component(comp_id) + + if component is None: + callback_map = get_app().mcp_callback_map + rendering_tools = [ + cb.tool_name + for cb in callback_map + if any(out["component_id"] == comp_id for out in cb.outputs) + ] + msg = f"Component '{comp_id}' not found in static layout." + if rendering_tools: + msg += f" However, the following tools would modify it: {rendering_tools}." + msg += " Use the dash://components resource to see statically available component IDs." + return CallToolResult( + content=[TextContent(type="text", text=msg)], + isError=True, + ) - if component is None: callback_map = get_app().mcp_callback_map - rendering_tools = [ - cb.tool_name - for cb in callback_map - if any(out["component_id"] == comp_id for out in cb.outputs) - ] - msg = f"Component '{comp_id}' not found in static layout." - if rendering_tools: - msg += f" However, the following tools would modify it: {rendering_tools}." - msg += " Use the dash://components resource to see statically available component IDs." - return CallToolResult( - content=[TextContent(type="text", text=msg)], - isError=True, - ) - callback_map = get_app().mcp_callback_map - - properties: dict[str, ComponentPropertyInfo] = {} - for prop_name in getattr(component, "_prop_names", []): - if prop_filter and prop_name != prop_filter: - continue - - value = callback_map.get_initial_value(f"{comp_id}.{prop_name}") - if value is None: - value = getattr(component, prop_name, None) - if value is None: - continue - - modified_by: list[str] = [] - input_to: list[str] = [] - id_and_prop = f"{comp_id}.{prop_name}" - for cb in callback_map: - for out in cb.outputs: - if out["id_and_prop"] == id_and_prop: - modified_by.append(cb.tool_name) - for inp in cb.inputs: - if inp["id_and_prop"] == id_and_prop: - input_to.append(cb.tool_name) - - properties[prop_name] = ComponentPropertyInfo( - initial_value=value, - modified_by_tool=modified_by, - input_to_tool=input_to, + properties: dict[str, ComponentPropertyInfo] = {} + for prop_name in getattr(component, "_prop_names", []): + if prop_filter and prop_name != prop_filter: + continue + + value = callback_map.get_initial_value(f"{comp_id}.{prop_name}") + if value is None: + value = getattr(component, prop_name, None) + if value is None: + continue + + modified_by: list[str] = [] + input_to: list[str] = [] + id_and_prop = f"{comp_id}.{prop_name}" + for cb in callback_map: + for out in cb.outputs: + if out["id_and_prop"] == id_and_prop: + modified_by.append(cb.tool_name) + for inp in cb.inputs: + if inp["id_and_prop"] == id_and_prop: + input_to.append(cb.tool_name) + + properties[prop_name] = ComponentPropertyInfo( + initial_value=value, + modified_by_tool=modified_by, + input_to_tool=input_to, + ) + + labels = callback_map.component_label_map.get(comp_id, []) + + structured: ComponentQueryResult = ComponentQueryResult( + component_id=comp_id, + component_type=type(component).__name__, + label=labels if labels else None, + properties=properties, ) - labels = callback_map.component_label_map.get(comp_id, []) - - structured: ComponentQueryResult = ComponentQueryResult( - component_id=comp_id, - component_type=type(component).__name__, - label=labels if labels else None, - properties=properties, - ) - - return CallToolResult( - content=[TextContent(type="text", text=json.dumps(structured, default=str))], - structuredContent=structured, - ) + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(structured, default=str))], + structuredContent=structured, + ) diff --git a/dash/mcp/primitives/tools/tools_callbacks.py b/dash/mcp/primitives/tools/tools_callbacks.py index ba08795d35..91bc360059 100644 --- a/dash/mcp/primitives/tools/tools_callbacks.py +++ b/dash/mcp/primitives/tools/tools_callbacks.py @@ -1,6 +1,6 @@ """Dynamic callback tools for MCP. -Handles listing, naming, and executing callback-based tools. +Exposes every server-callable callback as an MCP tool. """ from __future__ import annotations @@ -12,36 +12,41 @@ from dash import get_app from dash.mcp.types import CallbackExecutionError, ToolNotFoundError +from .base import MCPToolProvider from .results import format_callback_response -def get_tool_names() -> set[str]: - return get_app().mcp_callback_map.tool_names - - -def get_tools() -> list[Tool]: - """Return one Tool per server-callable callback.""" - return get_app().mcp_callback_map.as_mcp_tools() - - -def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: - """Execute a callback tool by name.""" - from .callback_utils import run_callback - - callback_map = get_app().mcp_callback_map - cb = callback_map.find_by_tool_name(tool_name) - if cb is None: - raise ToolNotFoundError( - f"Tool not found: {tool_name}." - " The app's callbacks may have changed." - " Please call tools/list to refresh your tool list." - ) - - try: - dispatch_response = run_callback(cb, arguments) - except CallbackExecutionError as e: - return CallToolResult( - content=[TextContent(type="text", text=str(e))], - isError=True, - ) - return format_callback_response(dispatch_response, cb) +class CallbackTools(MCPToolProvider): + """Exposes every server-callable callback as an MCP tool.""" + + @classmethod + def get_tool_names(cls) -> set[str]: + return get_app().mcp_callback_map.tool_names + + @classmethod + def list_tools(cls) -> list[Tool]: + """Return one Tool per server-callable callback.""" + return get_app().mcp_callback_map.as_mcp_tools() + + @classmethod + def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + """Execute a callback tool by name.""" + from .callback_utils import run_callback + + callback_map = get_app().mcp_callback_map + cb = callback_map.find_by_tool_name(tool_name) + if cb is None: + raise ToolNotFoundError( + f"Tool not found: {tool_name}." + " The app's callbacks may have changed." + " Please call tools/list to refresh your tool list." + ) + + try: + callback_response = run_callback(cb, arguments) + except CallbackExecutionError as e: + return CallToolResult( + content=[TextContent(type="text", text=str(e))], + isError=True, + ) + return format_callback_response(callback_response, cb) From 9e4cc2cbf468921f30cbb872f4d6b90bda917295 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 16 Apr 2026 17:39:36 -0600 Subject: [PATCH 46/80] lint --- .../tools/tool_get_dash_component.py | 8 +- dash/mcp/primitives/tools/tools_callbacks.py | 3 +- tests/unit/development/metadata_test.py | 197 +++++++++++------- tests/unit/mcp/conftest.py | 5 +- tests/unit/mcp/tools/test_tools_callbacks.py | 1 + 5 files changed, 134 insertions(+), 80 deletions(-) diff --git a/dash/mcp/primitives/tools/tool_get_dash_component.py b/dash/mcp/primitives/tools/tool_get_dash_component.py index f41c9bec68..69b6276d5a 100644 --- a/dash/mcp/primitives/tools/tool_get_dash_component.py +++ b/dash/mcp/primitives/tools/tool_get_dash_component.py @@ -75,7 +75,9 @@ def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: ] msg = f"Component '{comp_id}' not found in static layout." if rendering_tools: - msg += f" However, the following tools would modify it: {rendering_tools}." + msg += ( + f" However, the following tools would modify it: {rendering_tools}." + ) msg += " Use the dash://components resource to see statically available component IDs." return CallToolResult( content=[TextContent(type="text", text=msg)], @@ -122,6 +124,8 @@ def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: ) return CallToolResult( - content=[TextContent(type="text", text=json.dumps(structured, default=str))], + content=[ + TextContent(type="text", text=json.dumps(structured, default=str)) + ], structuredContent=structured, ) diff --git a/dash/mcp/primitives/tools/tools_callbacks.py b/dash/mcp/primitives/tools/tools_callbacks.py index 91bc360059..716b777326 100644 --- a/dash/mcp/primitives/tools/tools_callbacks.py +++ b/dash/mcp/primitives/tools/tools_callbacks.py @@ -13,6 +13,7 @@ from dash.mcp.types import CallbackExecutionError, ToolNotFoundError from .base import MCPToolProvider +from .callback_utils import run_callback from .results import format_callback_response @@ -31,8 +32,6 @@ def list_tools(cls) -> list[Tool]: @classmethod def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: """Execute a callback tool by name.""" - from .callback_utils import run_callback - callback_map = get_app().mcp_callback_map cb = callback_map.find_by_tool_name(tool_name) if cb is None: diff --git a/tests/unit/development/metadata_test.py b/tests/unit/development/metadata_test.py index 0ee96efdeb..e2938af338 100644 --- a/tests/unit/development/metadata_test.py +++ b/tests/unit/development/metadata_test.py @@ -1,7 +1,7 @@ # AUTO GENERATED FILE - DO NOT EDIT import typing # noqa: F401 -from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 +from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 from dash.development.base_component import Component, _explicitize_args ComponentType = typing.Union[ @@ -20,126 +20,120 @@ class Table(Component): """A Table component. -This is a description of the component. -It's multiple lines long. + This is a description of the component. + It's multiple lines long. -Keyword arguments: + Keyword arguments: -- children (a list of or a singular dash component, string or number; optional) + - children (a list of or a singular dash component, string or number; optional) -- id (string; optional) + - id (string; optional) -- aria-* (string; optional) + - aria-* (string; optional) -- customArrayProp (list; optional) + - customArrayProp (list; optional) -- customProp (optional) + - customProp (optional) -- data-* (string; optional) + - data-* (string; optional) -- in (string; optional) + - in (string; optional) -- optionalAny (boolean | number | string | dict | list; optional) + - optionalAny (boolean | number | string | dict | list; optional) -- optionalArray (list; optional): - Description of optionalArray. + - optionalArray (list; optional): + Description of optionalArray. -- optionalArrayOf (list of numbers; optional) + - optionalArrayOf (list of numbers; optional) -- optionalBool (boolean; optional) + - optionalBool (boolean; optional) -- optionalElement (dash component; optional) + - optionalElement (dash component; optional) -- optionalEnum (a value equal to: 'News', 'Photos'; optional) + - optionalEnum (a value equal to: 'News', 'Photos'; optional) -- optionalNode (a list of or a singular dash component, string or number; optional) + - optionalNode (a list of or a singular dash component, string or number; optional) -- optionalNumber (number; default 42) + - optionalNumber (number; default 42) -- optionalObject (dict; optional) + - optionalObject (dict; optional) -- optionalObjectOf (dict with strings as keys and values of type number; optional) + - optionalObjectOf (dict with strings as keys and values of type number; optional) -- optionalObjectWithExactAndNestedDescription (dict; optional) + - optionalObjectWithExactAndNestedDescription (dict; optional) - `optionalObjectWithExactAndNestedDescription` is a dict with keys: + `optionalObjectWithExactAndNestedDescription` is a dict with keys: - - color (string; optional) + - color (string; optional) - - fontSize (number; optional) + - fontSize (number; optional) - - figure (dict; optional): - Figure is a plotly graph object. + - figure (dict; optional): + Figure is a plotly graph object. - `figure` is a dict with keys: + `figure` is a dict with keys: - - data (list of dicts; optional): - data is a collection of traces. + - data (list of dicts; optional): + data is a collection of traces. - - layout (dict; optional): - layout describes the rest of the figure. + - layout (dict; optional): + layout describes the rest of the figure. -- optionalObjectWithShapeAndNestedDescription (dict; optional) + - optionalObjectWithShapeAndNestedDescription (dict; optional) - `optionalObjectWithShapeAndNestedDescription` is a dict with keys: + `optionalObjectWithShapeAndNestedDescription` is a dict with keys: - - color (string; optional) + - color (string; optional) - - fontSize (number; optional) + - fontSize (number; optional) - - figure (dict; optional): - Figure is a plotly graph object. + - figure (dict; optional): + Figure is a plotly graph object. - `figure` is a dict with keys: + `figure` is a dict with keys: - - data (list of dicts; optional): - data is a collection of traces. + - data (list of dicts; optional): + data is a collection of traces. - - layout (dict; optional): - layout describes the rest of the figure. + - layout (dict; optional): + layout describes the rest of the figure. -- optionalString (string; default 'hello world') + - optionalString (string; default 'hello world') -- optionalUnion (string | number; optional)""" - _children_props = ['optionalNode', 'optionalElement'] - _base_nodes = ['optionalNode', 'optionalElement', 'children'] - _namespace = 'TableComponents' - _type = 'Table' + - optionalUnion (string | number; optional)""" + + _children_props = ["optionalNode", "optionalElement"] + _base_nodes = ["optionalNode", "optionalElement", "children"] + _namespace = "TableComponents" + _type = "Table" OptionalObjectWithExactAndNestedDescriptionFigure = TypedDict( "OptionalObjectWithExactAndNestedDescriptionFigure", - { - "data": NotRequired[typing.Sequence[dict]], - "layout": NotRequired[dict] - } + {"data": NotRequired[typing.Sequence[dict]], "layout": NotRequired[dict]}, ) OptionalObjectWithExactAndNestedDescription = TypedDict( "OptionalObjectWithExactAndNestedDescription", - { + { "color": NotRequired[str], "fontSize": NotRequired[NumberType], - "figure": NotRequired["OptionalObjectWithExactAndNestedDescriptionFigure"] - } + "figure": NotRequired["OptionalObjectWithExactAndNestedDescriptionFigure"], + }, ) OptionalObjectWithShapeAndNestedDescriptionFigure = TypedDict( "OptionalObjectWithShapeAndNestedDescriptionFigure", - { - "data": NotRequired[typing.Sequence[dict]], - "layout": NotRequired[dict] - } + {"data": NotRequired[typing.Sequence[dict]], "layout": NotRequired[dict]}, ) OptionalObjectWithShapeAndNestedDescription = TypedDict( "OptionalObjectWithShapeAndNestedDescription", - { + { "color": NotRequired[str], "fontSize": NotRequired[NumberType], - "figure": NotRequired["OptionalObjectWithShapeAndNestedDescriptionFigure"] - } + "figure": NotRequired["OptionalObjectWithShapeAndNestedDescriptionFigure"], + }, ) - def __init__( self, children: typing.Optional[ComponentType] = None, @@ -154,26 +148,79 @@ def __init__( optionalElement: typing.Optional[Component] = None, optionalMessage: typing.Optional[typing.Any] = None, optionalEnum: typing.Optional[Literal["News", "Photos"]] = None, - optionalUnion: typing.Optional[typing.Union[str, NumberType, typing.Any]] = None, + optionalUnion: typing.Optional[ + typing.Union[str, NumberType, typing.Any] + ] = None, optionalArrayOf: typing.Optional[typing.Sequence[NumberType]] = None, - optionalObjectOf: typing.Optional[typing.Dict[typing.Union[str, float, int], NumberType]] = None, - optionalObjectWithExactAndNestedDescription: typing.Optional["OptionalObjectWithExactAndNestedDescription"] = None, - optionalObjectWithShapeAndNestedDescription: typing.Optional["OptionalObjectWithShapeAndNestedDescription"] = None, + optionalObjectOf: typing.Optional[ + typing.Dict[typing.Union[str, float, int], NumberType] + ] = None, + optionalObjectWithExactAndNestedDescription: typing.Optional[ + "OptionalObjectWithExactAndNestedDescription" + ] = None, + optionalObjectWithShapeAndNestedDescription: typing.Optional[ + "OptionalObjectWithShapeAndNestedDescription" + ] = None, optionalAny: typing.Optional[typing.Any] = None, customProp: typing.Optional[typing.Any] = None, customArrayProp: typing.Optional[typing.Sequence[typing.Any]] = None, id: typing.Optional[typing.Union[str, dict]] = None, **kwargs ): - self._prop_names = ['children', 'id', 'aria-*', 'customArrayProp', 'customProp', 'data-*', 'in', 'optionalAny', 'optionalArray', 'optionalArrayOf', 'optionalBool', 'optionalElement', 'optionalEnum', 'optionalNode', 'optionalNumber', 'optionalObject', 'optionalObjectOf', 'optionalObjectWithExactAndNestedDescription', 'optionalObjectWithShapeAndNestedDescription', 'optionalString', 'optionalUnion'] - self._valid_wildcard_attributes = ['data-', 'aria-'] - self.available_properties = ['children', 'id', 'aria-*', 'customArrayProp', 'customProp', 'data-*', 'in', 'optionalAny', 'optionalArray', 'optionalArrayOf', 'optionalBool', 'optionalElement', 'optionalEnum', 'optionalNode', 'optionalNumber', 'optionalObject', 'optionalObjectOf', 'optionalObjectWithExactAndNestedDescription', 'optionalObjectWithShapeAndNestedDescription', 'optionalString', 'optionalUnion'] - self.available_wildcard_properties = ['data-', 'aria-'] - _explicit_args = kwargs.pop('_explicit_args') + self._prop_names = [ + "children", + "id", + "aria-*", + "customArrayProp", + "customProp", + "data-*", + "in", + "optionalAny", + "optionalArray", + "optionalArrayOf", + "optionalBool", + "optionalElement", + "optionalEnum", + "optionalNode", + "optionalNumber", + "optionalObject", + "optionalObjectOf", + "optionalObjectWithExactAndNestedDescription", + "optionalObjectWithShapeAndNestedDescription", + "optionalString", + "optionalUnion", + ] + self._valid_wildcard_attributes = ["data-", "aria-"] + self.available_properties = [ + "children", + "id", + "aria-*", + "customArrayProp", + "customProp", + "data-*", + "in", + "optionalAny", + "optionalArray", + "optionalArrayOf", + "optionalBool", + "optionalElement", + "optionalEnum", + "optionalNode", + "optionalNumber", + "optionalObject", + "optionalObjectOf", + "optionalObjectWithExactAndNestedDescription", + "optionalObjectWithShapeAndNestedDescription", + "optionalString", + "optionalUnion", + ] + self.available_wildcard_properties = ["data-", "aria-"] + _explicit_args = kwargs.pop("_explicit_args") _locals = locals() _locals.update(kwargs) # For wildcard attrs and excess named props - args = {k: _locals[k] for k in _explicit_args if k != 'children'} + args = {k: _locals[k] for k in _explicit_args if k != "children"} super(Table, self).__init__(children=children, **args) + setattr(Table, "__init__", _explicitize_args(Table.__init__)) diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py index 38b2f8c00e..2f7fbc1898 100644 --- a/tests/unit/mcp/conftest.py +++ b/tests/unit/mcp/conftest.py @@ -9,7 +9,10 @@ if sys.version_info < (3, 10): collect_ignore_glob.append("*") else: - from dash.mcp.primitives.tools import call_tool, list_tools # pylint: disable=wrong-import-position + from dash.mcp.primitives.tools import ( + call_tool, + list_tools, + ) # pylint: disable=wrong-import-position from dash.mcp.primitives.tools.callback_adapter_collection import ( # pylint: disable=wrong-import-position CallbackAdapterCollection, ) diff --git a/tests/unit/mcp/tools/test_tools_callbacks.py b/tests/unit/mcp/tools/test_tools_callbacks.py index 751f85bedd..fe57c42cce 100644 --- a/tests/unit/mcp/tools/test_tools_callbacks.py +++ b/tests/unit/mcp/tools/test_tools_callbacks.py @@ -111,6 +111,7 @@ class TestOutputSemanticSummary: @staticmethod def _adapter_with_outputs(outputs, docstring=None): from unittest.mock import Mock + adapter = Mock() adapter.outputs = outputs adapter._docstring = docstring From c079065dc31fad2962deb49a670711de9736231f Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 23 Apr 2026 14:24:38 -0600 Subject: [PATCH 47/80] Refactor unit tests to conform to existing test patterns --- .../tools/test_mcp_tool_get_dash_component.py | 125 ++++++++++++++ .../mcp/tools/test_mcp_tools_callbacks.py | 152 ++++++++++++++++++ .../mcp/tools/test_tool_get_dash_component.py | 117 -------------- tests/unit/mcp/tools/test_tools_callbacks.py | 138 ---------------- 4 files changed, 277 insertions(+), 255 deletions(-) create mode 100644 tests/unit/mcp/tools/test_mcp_tool_get_dash_component.py create mode 100644 tests/unit/mcp/tools/test_mcp_tools_callbacks.py delete mode 100644 tests/unit/mcp/tools/test_tool_get_dash_component.py delete mode 100644 tests/unit/mcp/tools/test_tools_callbacks.py diff --git a/tests/unit/mcp/tools/test_mcp_tool_get_dash_component.py b/tests/unit/mcp/tools/test_mcp_tool_get_dash_component.py new file mode 100644 index 0000000000..4cf11555be --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_tool_get_dash_component.py @@ -0,0 +1,125 @@ +"""The built-in ``get_dash_component`` tool. + +Lets LLMs inspect a component's current properties and their relationships +to callbacks (which tools write/read each prop). +""" + +from dash import Dash, Input, Output, dcc, html + +from tests.unit.mcp.conftest import _call_tool, _make_app, _tools_list + + +def test_mcpg001_present_in_tools_list(): + app = _make_app() + tool_names = [t.name for t in _tools_list(app)] + assert "get_dash_component" in tool_names + + +def test_mcpg002_returns_structured_output_with_prop(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="my-dd", options=["a", "b"], value="b"), + ] + ) + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + "property": "value", + }, + ) + sc = result.structuredContent + assert sc["component_id"] == "my-dd" + assert sc["component_type"] == "Dropdown" + assert "value" in sc["properties"] + assert sc["properties"]["value"]["initial_value"] == "b" + assert "options" not in sc["properties"] + + +def test_mcpg003_returns_all_props_without_property(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="my-dd", options=["a", "b"], value="b"), + ] + ) + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + }, + ) + sc = result.structuredContent + assert "options" in sc["properties"] + assert "value" in sc["properties"] + assert sc["properties"]["value"]["initial_value"] == "b" + + +def test_mcpg004_includes_label(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label("Pick one", htmlFor="my-dd"), + dcc.Dropdown(id="my-dd", options=["a", "b"], value="a"), + ] + ) + + @app.callback(Output("my-dd", "value"), Input("my-dd", "options")) + def noop(o): + return "a" + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + }, + ) + sc = result.structuredContent + assert sc["label"] == ["Pick one"] + + +def test_mcpg005_includes_tool_references(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val): + return val + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "dd", + "property": "value", + }, + ) + prop_info = result.structuredContent["properties"]["value"] + assert "update" in prop_info["input_to_tool"] + + +def test_mcpg006_missing_id_returns_hint(): + app = _make_app() + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "nonexistent", + "property": "value", + }, + ) + text = result.content[0].text + assert "nonexistent" in text + assert "not found" in text + assert "dash://components" in text diff --git a/tests/unit/mcp/tools/test_mcp_tools_callbacks.py b/tests/unit/mcp/tools/test_mcp_tools_callbacks.py new file mode 100644 index 0000000000..7148803bc1 --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_tools_callbacks.py @@ -0,0 +1,152 @@ +"""Dynamic callback tools: MCP spec compliance, tool naming, output summaries. + +Verifies that generated tools conform to the MCP 2025-11-25 specification +and Dash-specific conventions. Focuses on shape/structure, tool-name +sanitization, and ``_OUTPUT_SEMANTICS`` fallback summaries; input-schema +values are covered by ``test_mcp_input_schemas``. + +Reference: https://modelcontextprotocol.io/specification/2025-11-25/server/tools +""" + +import re +from unittest.mock import Mock + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) +from dash.mcp.primitives.tools.descriptions import build_tool_description + +from tests.unit.mcp.conftest import ( + _make_app, + _tools_list, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_TOOL_NAME_RE = re.compile(r"^[A-Za-z0-9_\-.]+$") + + +def _adapter_with_outputs(outputs, docstring=None): + adapter = Mock() + adapter.outputs = outputs + adapter._docstring = docstring + return adapter + + +def _out(comp_id, prop, comp_type=None): + return { + "id_and_prop": f"{comp_id}.{prop}", + "component_id": comp_id, + "property": prop, + "component_type": comp_type, + "initial_value": None, + } + + +# --------------------------------------------------------------------------- +# MCP spec compliance — every generated tool must satisfy the spec +# --------------------------------------------------------------------------- + + +def test_mcptc001_all_tools_conform_to_mcp_spec(): + tools = _tools_list(_make_app()) + names = [t.name for t in tools] + + assert len(names) == len(set(names)), f"Duplicate tool names: {names}" + + for tool in tools: + assert tool.name + assert tool.inputSchema + assert 1 <= len(tool.name) <= 128 + assert _TOOL_NAME_RE.match(tool.name), f"Invalid tool name: {tool.name}" + + schema = tool.inputSchema + assert isinstance(schema, dict) + assert schema.get("type") == "object" + assert isinstance(schema.get("properties", {}), dict) + + required = set(schema.get("required", [])) + props = set(schema.get("properties", {}).keys()) + assert ( + required <= props + ), f"{tool.name}: required {required - props} not in properties" + + +# --------------------------------------------------------------------------- +# Built-in tools +# --------------------------------------------------------------------------- + + +def test_mcptc002_query_component_always_present(): + names = {t.name for t in _tools_list(_make_app())} + assert "get_dash_component" in names + + +def test_mcptc003_query_component_has_required_params(): + tool = next(t for t in _tools_list(_make_app()) if t.name == "get_dash_component") + assert "component_id" in tool.inputSchema["properties"] + assert "property" in tool.inputSchema["properties"] + assert set(tool.inputSchema.get("required", [])) == {"component_id"} + + +# --------------------------------------------------------------------------- +# Tool-name sanitization (CallbackAdapterCollection._sanitize_name) +# --------------------------------------------------------------------------- + + +def test_mcptc004_sanitize_simple_name(): + assert CallbackAdapterCollection._sanitize_name("update_output") == "update_output" + + +def test_mcptc005_sanitize_special_characters_replaced(): + assert CallbackAdapterCollection._sanitize_name("my-func.name") == "my_func_name" + + +def test_mcptc006_sanitize_leading_digit(): + assert CallbackAdapterCollection._sanitize_name("123func") == "cb_123func" + + +def test_mcptc007_sanitize_empty_name(): + assert CallbackAdapterCollection._sanitize_name("") == "unnamed_callback" + + +def test_mcptc008_sanitize_consecutive_underscores_collapsed(): + assert CallbackAdapterCollection._sanitize_name("a---b___c") == "a_b_c" + + +def test_mcptc009_sanitize_long_name_truncated_to_64_chars(): + result = CallbackAdapterCollection._sanitize_name("a" * 200) + assert len(result) <= 64 + assert result[-8:].isalnum() + + +def test_mcptc010_sanitize_long_name_uniqueness(): + result_a = CallbackAdapterCollection._sanitize_name("a" * 200) + result_b = CallbackAdapterCollection._sanitize_name("b" * 200) + assert result_a != result_b + + +def test_mcptc011_sanitize_short_name_not_truncated(): + assert CallbackAdapterCollection._sanitize_name("short_name") == "short_name" + + +# --------------------------------------------------------------------------- +# Output semantic summary (``_OUTPUT_SEMANTICS`` in description_outputs.py). +# Other description tests (docstring, output target, multi-output) are +# covered by test_mcp_tools using real adapters. +# --------------------------------------------------------------------------- + + +def test_mcptc012_semantic_summary_with_component_type(): + adapter = _adapter_with_outputs([_out("my-graph", "figure", "Graph")]) + desc = build_tool_description(adapter) + assert "Returns chart/visualization data" in desc + + +def test_mcptc013_semantic_summary_fallback_by_property(): + adapter = _adapter_with_outputs([_out("unknown-id", "figure")]) + desc = build_tool_description(adapter) + assert "Returns chart/visualization data" in desc diff --git a/tests/unit/mcp/tools/test_tool_get_dash_component.py b/tests/unit/mcp/tools/test_tool_get_dash_component.py deleted file mode 100644 index 5a8a454068..0000000000 --- a/tests/unit/mcp/tools/test_tool_get_dash_component.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Tests for the get_dash_component built-in tool.""" - -from dash import Dash, Input, Output, dcc, html - -from tests.unit.mcp.conftest import _call_tool, _make_app, _tools_list - - -class TestGetDashComponent: - def test_present_in_tools_list(self): - app = _make_app() - tool_names = [t.name for t in _tools_list(app)] - assert "get_dash_component" in tool_names - - def test_returns_structured_output_with_prop(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="my-dd", options=["a", "b"], value="b"), - ] - ) - - result = _call_tool( - app, - "get_dash_component", - { - "component_id": "my-dd", - "property": "value", - }, - ) - sc = result.structuredContent - assert sc["component_id"] == "my-dd" - assert sc["component_type"] == "Dropdown" - assert "value" in sc["properties"] - assert sc["properties"]["value"]["initial_value"] == "b" - assert "options" not in sc["properties"] - - def test_returns_all_props_without_property(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="my-dd", options=["a", "b"], value="b"), - ] - ) - - result = _call_tool( - app, - "get_dash_component", - { - "component_id": "my-dd", - }, - ) - sc = result.structuredContent - assert "options" in sc["properties"] - assert "value" in sc["properties"] - assert sc["properties"]["value"]["initial_value"] == "b" - - def test_includes_label(self): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Label("Pick one", htmlFor="my-dd"), - dcc.Dropdown(id="my-dd", options=["a", "b"], value="a"), - ] - ) - - @app.callback(Output("my-dd", "value"), Input("my-dd", "options")) - def noop(o): - return "a" - - result = _call_tool( - app, - "get_dash_component", - { - "component_id": "my-dd", - }, - ) - sc = result.structuredContent - assert sc["label"] == ["Pick one"] - - def test_includes_tool_references(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="dd", options=["a", "b"], value="a"), - html.Div(id="out"), - ] - ) - - @app.callback(Output("out", "children"), Input("dd", "value")) - def update(val): - return val - - result = _call_tool( - app, - "get_dash_component", - { - "component_id": "dd", - "property": "value", - }, - ) - prop_info = result.structuredContent["properties"]["value"] - assert "update" in prop_info["input_to_tool"] - - def test_missing_id_returns_hint(self): - app = _make_app() - result = _call_tool( - app, - "get_dash_component", - { - "component_id": "nonexistent", - "property": "value", - }, - ) - text = result.content[0].text - assert "nonexistent" in text - assert "not found" in text - assert "dash://components" in text diff --git a/tests/unit/mcp/tools/test_tools_callbacks.py b/tests/unit/mcp/tools/test_tools_callbacks.py deleted file mode 100644 index fe57c42cce..0000000000 --- a/tests/unit/mcp/tools/test_tools_callbacks.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Tool definition tests — MCP spec compliance and Dash conventions. - -Verifies that generated tools conform to the MCP specification (2025-11-25) -and Dash-specific conventions. Focuses on shape/structure, not inputSchema -values (those are covered by input_schemas/). - -Reference: https://modelcontextprotocol.io/specification/2025-11-25/server/tools -""" - -import re - -from dash.mcp.primitives.tools.callback_adapter_collection import ( - CallbackAdapterCollection, -) -from dash.mcp.primitives.tools.descriptions import build_tool_description - -from tests.unit.mcp.conftest import ( - _make_app, - _tools_list, -) - -_TOOL_NAME_RE = re.compile(r"^[A-Za-z0-9_\-.]+$") - - -class TestToolSpecCompliance: - """Every tool must conform to the MCP 2025-11-25 specification.""" - - def test_all_tools_conform_to_mcp_spec(self): - tools = _tools_list(_make_app()) - names = [t.name for t in tools] - - assert len(names) == len(set(names)), f"Duplicate tool names: {names}" - - for tool in tools: - assert tool.name - assert tool.inputSchema - assert 1 <= len(tool.name) <= 128 - assert _TOOL_NAME_RE.match(tool.name), f"Invalid tool name: {tool.name}" - - schema = tool.inputSchema - assert isinstance(schema, dict) - assert schema.get("type") == "object" - assert isinstance(schema.get("properties", {}), dict) - - required = set(schema.get("required", [])) - props = set(schema.get("properties", {}).keys()) - assert ( - required <= props - ), f"{tool.name}: required {required - props} not in properties" - - -class TestBuiltinToolDefinitions: - def _tools(self): - return _tools_list(_make_app()) - - def _builtin(self, name): - return next(t for t in self._tools() if t.name == name) - - def test_query_component_always_present(self): - names = {t.name for t in self._tools()} - assert "get_dash_component" in names - - def test_query_component_has_required_params(self): - tool = self._builtin("get_dash_component") - assert "component_id" in tool.inputSchema["properties"] - assert "property" in tool.inputSchema["properties"] - assert set(tool.inputSchema.get("required", [])) == {"component_id"} - - -class TestSanitizeToolName: - def test_simple_name(self): - assert ( - CallbackAdapterCollection._sanitize_name("update_output") == "update_output" - ) - - def test_special_characters_replaced(self): - assert ( - CallbackAdapterCollection._sanitize_name("my-func.name") == "my_func_name" - ) - - def test_leading_digit(self): - assert CallbackAdapterCollection._sanitize_name("123func") == "cb_123func" - - def test_empty_name(self): - assert CallbackAdapterCollection._sanitize_name("") == "unnamed_callback" - - def test_consecutive_underscores_collapsed(self): - assert CallbackAdapterCollection._sanitize_name("a---b___c") == "a_b_c" - - def test_long_name_truncated_to_64_chars(self): - result = CallbackAdapterCollection._sanitize_name("a" * 200) - assert len(result) <= 64 - assert result[-8:].isalnum() - - def test_long_name_uniqueness(self): - result_a = CallbackAdapterCollection._sanitize_name("a" * 200) - result_b = CallbackAdapterCollection._sanitize_name("b" * 200) - assert result_a != result_b - - def test_short_name_not_truncated(self): - assert CallbackAdapterCollection._sanitize_name("short_name") == "short_name" - - -class TestOutputSemanticSummary: - """Test the _OUTPUT_SEMANTICS mapping in description_outputs.py. - - Other description tests (docstring, output target, multi-output) are - covered by TestTool in test_callback_adapter.py using real adapters. - """ - - @staticmethod - def _adapter_with_outputs(outputs, docstring=None): - from unittest.mock import Mock - - adapter = Mock() - adapter.outputs = outputs - adapter._docstring = docstring - return adapter - - @staticmethod - def _out(comp_id, prop, comp_type=None): - return { - "id_and_prop": f"{comp_id}.{prop}", - "component_id": comp_id, - "property": prop, - "component_type": comp_type, - "initial_value": None, - } - - def test_semantic_summary_with_component_type(self): - adapter = self._adapter_with_outputs([self._out("my-graph", "figure", "Graph")]) - desc = build_tool_description(adapter) - assert "Returns chart/visualization data" in desc - - def test_semantic_summary_fallback_by_property(self): - adapter = self._adapter_with_outputs([self._out("unknown-id", "figure")]) - desc = build_tool_description(adapter) - assert "Returns chart/visualization data" in desc From 99887535ff22a0c5c4522e7f551dabdd93bc32f0 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Fri, 8 May 2026 15:48:58 -0600 Subject: [PATCH 48/80] Implement a mapping of "id+prop" to callbacks that are outputs/inputs to those pairs --- .../tools/callback_adapter_collection.py | 38 ++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py index 0304394f63..e740e7f0b8 100644 --- a/dash/mcp/primitives/tools/callback_adapter_collection.py +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -86,8 +86,14 @@ def find_by_tool_name(self, name: str) -> CallbackAdapter | None: return cb return None - def find_by_output(self, id_and_prop: str) -> CallbackAdapter | None: - """Find the adapter that outputs to ``id_and_prop`` (``"component_id.property"``).""" + @cached_property + def outputs_by_prop(self) -> dict[str, list[CallbackAdapter]]: + """Index ``id_and_prop`` → callbacks outputting it. + + Mirrors the dash-renderer's ``outputMap`` (see + ``dash-renderer/src/actions/dependencies.js``). + """ + idx: dict[str, list[CallbackAdapter]] = {} for cb in self._callbacks: try: parsed = split_callback_id(cb.output_id) @@ -96,9 +102,31 @@ def find_by_output(self, id_and_prop: str) -> CallbackAdapter | None: if isinstance(parsed, dict): parsed = [parsed] for p in parsed: - if f"{p['id']}.{clean_property_name(p['property'])}" == id_and_prop: - return cb - return None + key = f"{p['id']}.{clean_property_name(p['property'])}" + idx.setdefault(key, []).append(cb) + return idx + + @cached_property + def inputs_by_prop(self) -> dict[str, list[CallbackAdapter]]: + """Index ``id_and_prop`` → callbacks consuming it as input/state. + + Mirrors the dash-renderer's ``inputMap`` (see + ``dash-renderer/src/actions/dependencies.js``). + Many callbacks may share a key. + """ + idx: dict[str, list[CallbackAdapter]] = {} + for cb in self._callbacks: + # pylint: disable-next=protected-access + deps = cb._cb_info.get("inputs", []) + cb._cb_info.get("state", []) + for dep in deps: + key = f"{dep.get('id', 'unknown')}.{dep.get('property', 'unknown')}" + idx.setdefault(key, []).append(cb) + return idx + + def find_by_output(self, id_and_prop: str) -> CallbackAdapter | None: + """Find the adapter that outputs to ``id_and_prop`` (``"component_id.property"``).""" + candidates = self.outputs_by_prop.get(id_and_prop, []) + return candidates[0] if candidates else None def get_initial_value(self, id_and_prop: str) -> Any: """Return the initial value for ``id_and_prop`` (``"component_id.property"``). From c2f06ae930e2dc24b3898b75c74e3935b1ab5ca6 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Fri, 8 May 2026 15:49:06 -0600 Subject: [PATCH 49/80] Code review feedback --- .../tools/tool_get_dash_component.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/dash/mcp/primitives/tools/tool_get_dash_component.py b/dash/mcp/primitives/tools/tool_get_dash_component.py index 69b6276d5a..5dffc2cf58 100644 --- a/dash/mcp/primitives/tools/tool_get_dash_component.py +++ b/dash/mcp/primitives/tools/tool_get_dash_component.py @@ -61,13 +61,16 @@ def list_tools(cls) -> list[Tool]: def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: comp_id = arguments.get("component_id", "") if not comp_id: - raise ValueError("component_id is required") + return CallToolResult( + content=[TextContent(type="text", text="component_id is required")], + isError=True, + ) prop_filter = arguments.get("property", "") component = find_component(comp_id) + callback_map = get_app().mcp_callback_map if component is None: - callback_map = get_app().mcp_callback_map rendering_tools = [ cb.tool_name for cb in callback_map @@ -84,29 +87,24 @@ def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: isError=True, ) - callback_map = get_app().mcp_callback_map - properties: dict[str, ComponentPropertyInfo] = {} for prop_name in getattr(component, "_prop_names", []): if prop_filter and prop_name != prop_filter: continue - value = callback_map.get_initial_value(f"{comp_id}.{prop_name}") + id_and_prop = f"{comp_id}.{prop_name}" + value = callback_map.get_initial_value(id_and_prop) if value is None: value = getattr(component, prop_name, None) if value is None: continue - modified_by: list[str] = [] - input_to: list[str] = [] - id_and_prop = f"{comp_id}.{prop_name}" - for cb in callback_map: - for out in cb.outputs: - if out["id_and_prop"] == id_and_prop: - modified_by.append(cb.tool_name) - for inp in cb.inputs: - if inp["id_and_prop"] == id_and_prop: - input_to.append(cb.tool_name) + modified_by = [ + cb.tool_name for cb in callback_map.outputs_by_prop.get(id_and_prop, []) + ] + input_to = [ + cb.tool_name for cb in callback_map.inputs_by_prop.get(id_and_prop, []) + ] properties[prop_name] = ComponentPropertyInfo( initial_value=value, From 34a4d980fd0526e829bf7896bb0c6d2d38a0e7b7 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 12 May 2026 17:18:56 -0600 Subject: [PATCH 50/80] Make run_callback compatible with 4.2.0 changes --- dash/mcp/primitives/tools/callback_utils.py | 33 +++++++++++---------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/dash/mcp/primitives/tools/callback_utils.py b/dash/mcp/primitives/tools/callback_utils.py index 361b27fc32..24a8a693fb 100644 --- a/dash/mcp/primitives/tools/callback_utils.py +++ b/dash/mcp/primitives/tools/callback_utils.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +from contextvars import copy_context from typing import TYPE_CHECKING, Any from dash import get_app @@ -16,23 +17,25 @@ def run_callback( callback: CallbackAdapter, kwargs: dict[str, Any] ) -> CallbackExecutionResponse: - """Execute a callback via the framework.""" - body = callback.as_callback_body(kwargs) + """Execute a callback via the framework. + Must be called from inside an active request handler; the backend's + request adapter reads cookies/headers/args from the current request. + """ + body = callback.as_callback_body(kwargs) app = get_app() - with app.server.test_request_context( - "/_dash-update-component", - method="POST", - data=json.dumps(body, default=str), - content_type="application/json", - ): - response = app.dispatch() - - response_text = response.get_data(as_text=True) - if response.status_code != 200: + + try: + # pylint: disable=protected-access + cb_ctx = app._initialize_context(body) + func = app._prepare_callback(cb_ctx, body) + args = app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) + ctx = copy_context() + partial_func = app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) + response_text = ctx.run(partial_func) + except Exception as err: raise CallbackExecutionError( - f"Callback {callback.output_id} failed " - f"(HTTP {response.status_code}): {response_text[:500]}" - ) + f"Callback {callback.output_id} failed: {err}" + ) from err return json.loads(response_text) From 227eeedb671b5fa2493815597316e169a64fdedb Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 15:35:13 -0600 Subject: [PATCH 51/80] Wire MCP server, SSE transport, and Dash app integration --- dash/_configs.py | 2 + dash/dash.py | 31 + dash/mcp/__init__.py | 7 + dash/mcp/_server.py | 277 +++++ dash/mcp/_sse.py | 67 ++ dash/mcp/notifications/__init__.py | 7 + .../notification_tools_changed.py | 30 + dash/mcp/primitives/__init__.py | 17 + .../tools/callback_adapter_collection.py | 2 - tests/integration/mcp/conftest.py | 53 + .../primitives/resources/test_resources.py | 51 + .../tools/test_callback_signatures.py | 958 ++++++++++++++++++ .../tools/test_duplicate_outputs.py | 128 +++ .../primitives/tools/test_input_schemas.py | 66 ++ .../tools/test_tool_get_dash_component.py | 54 + .../mcp/primitives/tools/test_tools_list.py | 118 +++ tests/integration/mcp/test_server.py | 304 ++++++ tests/unit/mcp/test_server.py | 92 ++ tests/unit/mcp/tools/test_run_callback.py | 246 +++++ 19 files changed, 2508 insertions(+), 2 deletions(-) create mode 100644 dash/mcp/_server.py create mode 100644 dash/mcp/_sse.py create mode 100644 dash/mcp/notifications/__init__.py create mode 100644 dash/mcp/notifications/notification_tools_changed.py create mode 100644 tests/integration/mcp/conftest.py create mode 100644 tests/integration/mcp/primitives/resources/test_resources.py create mode 100644 tests/integration/mcp/primitives/tools/test_callback_signatures.py create mode 100644 tests/integration/mcp/primitives/tools/test_duplicate_outputs.py create mode 100644 tests/integration/mcp/primitives/tools/test_input_schemas.py create mode 100644 tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py create mode 100644 tests/integration/mcp/primitives/tools/test_tools_list.py create mode 100644 tests/integration/mcp/test_server.py create mode 100644 tests/unit/mcp/test_server.py create mode 100644 tests/unit/mcp/tools/test_run_callback.py diff --git a/dash/_configs.py b/dash/_configs.py index 107b8308f5..0e1ab75505 100644 --- a/dash/_configs.py +++ b/dash/_configs.py @@ -33,6 +33,8 @@ def load_dash_env_vars(): "DASH_DISABLE_VERSION_CHECK", "DASH_PRUNE_ERRORS", "DASH_COMPRESS", + "DASH_MCP_ENABLED", + "DASH_MCP_PATH", "HOST", "PORT", ) diff --git a/dash/dash.py b/dash/dash.py index f887598497..c6fa5a6bf3 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -486,6 +486,8 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches websocket_callbacks: Optional[bool] = False, websocket_allowed_origins: Optional[List[str]] = None, websocket_inactivity_timeout: Optional[int] = 300000, + enable_mcp: Optional[bool] = None, + mcp_path: Optional[str] = None, **obsolete, ): @@ -597,6 +599,13 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # keep title as a class property for backwards compatibility self.title = title + # MCP (Model Context Protocol) configuration + self._enable_mcp = get_combined_config("mcp_enabled", enable_mcp, True) + _mcp_path = get_combined_config("mcp_path", mcp_path, "_mcp") + self._mcp_path = ( + _mcp_path.lstrip("/") if isinstance(_mcp_path, str) else _mcp_path + ) + # list of dependencies - this one is used by the back end for dispatching self.callback_map: dict = {} # same deps as a list to catch duplicate outputs, and to send to the front end @@ -809,6 +818,21 @@ def _setup_routes(self): hook.data["methods"], ) + if self._enable_mcp: + from .mcp import ( # pylint: disable=import-outside-toplevel + enable_mcp_server, + ) + + try: + enable_mcp_server(self, self._mcp_path) + except Exception as e: # pylint: disable=broad-exception-caught + self._enable_mcp = False + self.logger.warning( + "MCP server could not be started at '%s': %s", + self._mcp_path, + e, + ) + def setup_apis(self): """ Register API endpoints for all callbacks defined using `dash.callback`. @@ -2452,6 +2476,13 @@ def verify_url_part(served_part, url_part, part_name): if not jupyter_dash or not jupyter_dash.in_ipython: self.logger.info("Dash is running on %s://%s%s%s\n", *display_url) + if self._enable_mcp: + self.logger.info( + " * MCP available at %s://%s%s%s%s\n", + *display_url[:3], + self.config.routes_pathname_prefix, + self._mcp_path, + ) if self.config.extra_hot_reload_paths: extra_files = flask_run_options["extra_files"] = [] diff --git a/dash/mcp/__init__.py b/dash/mcp/__init__.py index e69de29bb2..2677ea141b 100644 --- a/dash/mcp/__init__.py +++ b/dash/mcp/__init__.py @@ -0,0 +1,7 @@ +"""Dash MCP (Model Context Protocol) server integration.""" + +from dash.mcp._server import enable_mcp_server + +__all__ = [ + enable_mcp_server, +] diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py new file mode 100644 index 0000000000..1c6279290b --- /dev/null +++ b/dash/mcp/_server.py @@ -0,0 +1,277 @@ +"""Flask route setup, Streamable HTTP transport, and MCP message handling.""" + +from __future__ import annotations + +import atexit +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any + +from flask import Response, request + +from dash.mcp.types import MCPError + +if TYPE_CHECKING: + from dash import Dash + +from dash import get_app + +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ErrorData, + Implementation, + InitializeResult, + JSONRPCError, + JSONRPCResponse, + ResourcesCapability, + ServerCapabilities, + ToolsCapability, +) + +from dash.version import __version__ +from dash.mcp._sse import ( + close_sse_stream, + create_sse_stream, + shutdown_all_streams, +) +from dash.mcp.primitives import ( + call_tool, + list_resource_templates, + list_resources, + list_tools, + read_resource, +) +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + +logger = logging.getLogger(__name__) + + +def enable_mcp_server(app: Dash, mcp_path: str) -> None: + """ + Add MCP routes to a Dash/Flask app. + + Registers a single Streamable HTTP endpoint for the MCP protocol. + Uses ``app._add_url()`` so that ``routes_pathname_prefix`` is applied + automatically. + + Args: + app: The Dash application instance. + mcp_path: Route prefix for MCP endpoints. + """ + # Session storage: session_id -> metadata + sessions: dict[str, dict[str, Any]] = {} + + def _create_session() -> str: + sid = str(uuid.uuid4()) + sessions[sid] = {} + return sid + + # -- Streamable HTTP endpoint -------------------------------------------- + + def mcp_handler() -> Response: + if request.method == "POST": + return _handle_post() + if request.method == "GET": + return _handle_get() + if request.method == "DELETE": + return _handle_delete() + return Response( + json.dumps({"error": "Method not allowed"}), + content_type="application/json", + status=405, + ) + + def _handle_get() -> Response: + session_id = request.headers.get("mcp-session-id") + if not session_id or session_id not in sessions: + return Response( + json.dumps({"error": "Session not found"}), + content_type="application/json", + status=404, + ) + return create_sse_stream(sessions, session_id) + + def _handle_post() -> Response: + content_type = request.content_type or "" + if "application/json" not in content_type: + return Response( + json.dumps({"error": "Content-Type must be application/json"}), + content_type="application/json", + status=415, + ) + + try: + data = request.get_json() + except Exception: + return Response( + json.dumps({"error": "Invalid JSON"}), + content_type="application/json", + status=400, + ) + + method = data.get("method", "") + request_id = data.get("id") + session_id = request.headers.get("mcp-session-id") + + stale_session = False + if method == "initialize": + session_id = _create_session() + elif session_id and session_id not in sessions: + stale_session = True + sessions[session_id] = {} + elif not session_id: + session_id = _create_session() + + response_data = _process_mcp_message(data) + + if response_data is None: + return Response("", status=202) + + if stale_session: + _inject_warning(response_data, _STALE_SESSION_WARNING) + + return Response( + json.dumps(response_data), + content_type="application/json", + status=200, + headers={"mcp-session-id": session_id}, + ) + + def _handle_delete() -> Response: + session_id = request.headers.get("mcp-session-id") + if not session_id or session_id not in sessions: + return Response( + json.dumps({"error": "Session not found"}), + content_type="application/json", + status=404, + ) + close_sse_stream(sessions[session_id]) + del sessions[session_id] + logger.info("MCP session terminated: %s", session_id) + return Response("", status=204) + + # -- Register routes ----------------------------------------------------- + + from dash._get_app import with_app_context_factory + + app._add_url( + mcp_path, with_app_context_factory(mcp_handler, app), ["GET", "POST", "DELETE"] + ) + + # Close all SSE streams on server shutdown so MCP clients see a + # clean stream end and can reconnect promptly. + atexit.register(shutdown_all_streams, sessions) + + logger.info( + "MCP routes registered at %s%s", + app.config.routes_pathname_prefix, + mcp_path, + ) + + +_STALE_SESSION_WARNING = ( + "[Warning: your session was not recognised" + " — the app may have restarted." + " Please call tools/list to refresh your tool list." + " Please ask the user to reconnect to the MCP server.]" +) + + +def _inject_warning(response_data: dict[str, Any], warning: str) -> None: + """Append a warning to a JSON-RPC response dict. + + For successful ``tools/call`` responses the warning is added as an + extra text content block so the agent sees it alongside the result. + For error responses the warning is appended to the error message. + Other responses (tools/list, resources/*) are left unchanged — the + JSON-RPC spec forbids extra top-level keys. + """ + # tools/call success: result has a "content" list + result = response_data.get("result") + if isinstance(result, dict) and isinstance(result.get("content"), list): + result["content"].append({"type": "text", "text": warning}) + return + + # Error response + error = response_data.get("error") + if isinstance(error, dict) and "message" in error: + error["message"] += " " + warning + + +def _handle_initialize() -> InitializeResult: + return InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities( + tools=ToolsCapability(listChanged=True), + resources=ResourcesCapability(), + ), + serverInfo=Implementation(name="Plotly Dash", version=__version__), + instructions=( + "This is a Dash web application. " + "Dash apps are stateless: calling a tool executes " + "a callback and returns its result to you, but does " + "NOT update the user's browser. " + "Use tool results to answer questions about what " + "the app would produce for given inputs." + ), + ) + + +def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: + """ + Process an MCP JSON-RPC message and return the response dict. + + Returns ``None`` for notifications (no ``id`` field). + """ + method = data.get("method", "") + params = data.get("params", {}) or {} + request_id = data.get("id") + + app = get_app() + if not hasattr(app, "mcp_callback_map"): + app.mcp_callback_map = CallbackAdapterCollection(app) + + mcp_methods = { + "initialize": _handle_initialize, + "tools/list": lambda: list_tools(), + "tools/call": lambda: call_tool( + params.get("name", ""), params.get("arguments", {}) + ), + "resources/list": lambda: list_resources(), + "resources/templates/list": lambda: list_resource_templates(), + "resources/read": lambda: read_resource(params.get("uri", "")), + } + + try: + handler = mcp_methods.get(method) + if handler is None: + if method.startswith("notifications/"): + return None + raise ValueError(f"Unknown method: {method}") + + result = handler() + + response = JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + result=result.model_dump(exclude_none=True, mode="json"), + ) + return response.model_dump(exclude_none=True, mode="json") + + except MCPError as e: + logger.error("MCP error: %s", e) + return JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=e.code, message=str(e)), + ).model_dump(exclude_none=True) + except Exception as e: + logger.error("MCP error: %s", e, exc_info=True) + return JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=-32603, message=f"{type(e).__name__}: {e}"), + ).model_dump(exclude_none=True) diff --git a/dash/mcp/_sse.py b/dash/mcp/_sse.py new file mode 100644 index 0000000000..4928dc68b2 --- /dev/null +++ b/dash/mcp/_sse.py @@ -0,0 +1,67 @@ +"""SSE stream generation and queue management.""" + +from __future__ import annotations + +import queue +from typing import Any + +from flask import Response + + +def create_sse_stream(sessions: dict[str, dict[str, Any]], session_id: str) -> Response: + """Create a Server-Sent Events stream for the given session. + + Stores a :class:`queue.Queue` in ``sessions[session_id]["sse_queue"]`` + and returns a Flask streaming ``Response``. The generator yields + events pushed to the queue, with keepalive comments every 30 seconds. + """ + event_queue: queue.Queue[str | None] = queue.Queue() + # Replace any prior SSE queue for this session (client reconnect). + sessions[session_id]["sse_queue"] = event_queue + + def _generate(): + try: + while True: + try: + event = event_queue.get(timeout=30) + if event is None: + return # Sentinel: server closing stream + yield f"event: message\ndata: {event}\n\n" + except queue.Empty: + yield ": keepalive\n\n" + except GeneratorExit: + pass + finally: + # Clean up queue reference if it's still ours. + if sessions.get(session_id, {}).get("sse_queue") is event_queue: + sessions[session_id].pop("sse_queue", None) + + return Response( + _generate(), + content_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "mcp-session-id": session_id, + }, + ) + + +def close_sse_stream(session_data: dict[str, Any]) -> None: + """Send a sentinel to shut down the session's SSE stream cleanly.""" + sse_queue = session_data.get("sse_queue") + if sse_queue is not None: + try: + sse_queue.put_nowait(None) + except queue.Full: + pass + + +def shutdown_all_streams(sessions: dict[str, dict[str, Any]]) -> None: + """Close all active SSE streams. + + Called during server shutdown (via ``atexit``) so that connected + MCP clients see a clean stream end and can reconnect promptly. + """ + for session_data in list(sessions.values()): + close_sse_stream(session_data) diff --git a/dash/mcp/notifications/__init__.py b/dash/mcp/notifications/__init__.py new file mode 100644 index 0000000000..b1fe9e8665 --- /dev/null +++ b/dash/mcp/notifications/__init__.py @@ -0,0 +1,7 @@ +"""Server-initiated MCP notifications.""" + +from .notification_tools_changed import broadcast_tools_changed + +__all__ = [ + "broadcast_tools_changed", +] diff --git a/dash/mcp/notifications/notification_tools_changed.py b/dash/mcp/notifications/notification_tools_changed.py new file mode 100644 index 0000000000..1970667d1a --- /dev/null +++ b/dash/mcp/notifications/notification_tools_changed.py @@ -0,0 +1,30 @@ +"""Tool list change notifications.""" + +from __future__ import annotations + +import json +import queue +from typing import Any + + +def broadcast_tools_changed( + sessions: dict[str, dict[str, Any]], +) -> None: + """Push a tools/list_changed notification to all active SSE streams. + + Not called automatically yet — available for future hot-reload + or dynamic callback registration. + """ + notification = json.dumps( + { + "jsonrpc": "2.0", + "method": "notifications/tools/list_changed", + } + ) + for data in sessions.values(): + sse_queue = data.get("sse_queue") + if sse_queue is not None: + try: + sse_queue.put_nowait(notification) + except queue.Full: + pass diff --git a/dash/mcp/primitives/__init__.py b/dash/mcp/primitives/__init__.py index e69de29bb2..b14839f1e1 100644 --- a/dash/mcp/primitives/__init__.py +++ b/dash/mcp/primitives/__init__.py @@ -0,0 +1,17 @@ +from .resources import ( + list_resources, + list_resource_templates, + read_resource, +) +from .tools import ( + call_tool, + list_tools, +) + +__all__ = [ + call_tool, + list_resources, + list_resource_templates, + list_tools, + read_resource, +] diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py index e740e7f0b8..4fdaeabe9c 100644 --- a/dash/mcp/primitives/tools/callback_adapter_collection.py +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -35,8 +35,6 @@ def __init__(self, app): CallbackAdapter(callback_output_id=output_id) for output_id in self._tool_names_map ] - # TODO: enable_mcp_server() will replace this with a direct assignment on app - app.mcp_callback_map = self @staticmethod def _sanitize_name(name: str) -> str: diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py new file mode 100644 index 0000000000..0f212d1763 --- /dev/null +++ b/tests/integration/mcp/conftest.py @@ -0,0 +1,53 @@ +"""Shared helpers for MCP integration tests.""" + +import requests + + +def _mcp_post(server_url, method, params=None, session_id=None, request_id=1): + headers = {"Content-Type": "application/json"} + if session_id: + headers["mcp-session-id"] = session_id + return requests.post( + f"{server_url}/_mcp", + json={ + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params or {}, + }, + headers=headers, + timeout=5, + ) + + +def _mcp_session(server_url): + resp = _mcp_post(server_url, "initialize") + resp.raise_for_status() + return resp.headers["mcp-session-id"] + + +def _mcp_tools(server_url): + sid = _mcp_session(server_url) + resp = _mcp_post(server_url, "tools/list", session_id=sid, request_id=2) + resp.raise_for_status() + return resp.json()["result"]["tools"] + + +def _mcp_call_tool(server_url, tool_name, arguments=None): + sid = _mcp_session(server_url) + resp = _mcp_post( + server_url, + "tools/call", + {"name": tool_name, "arguments": arguments or {}}, + session_id=sid, + request_id=2, + ) + resp.raise_for_status() + return resp.json() + + +def _mcp_method(server_url, method, params=None): + sid = _mcp_session(server_url) + resp = _mcp_post(server_url, method, params, session_id=sid, request_id=2) + resp.raise_for_status() + return resp.json() diff --git a/tests/integration/mcp/primitives/resources/test_resources.py b/tests/integration/mcp/primitives/resources/test_resources.py new file mode 100644 index 0000000000..dfc1e09f9b --- /dev/null +++ b/tests/integration/mcp/primitives/resources/test_resources.py @@ -0,0 +1,51 @@ +"""Integration tests for MCP resources.""" + +import json + +from dash import Dash, dcc, html + +from tests.integration.mcp.conftest import _mcp_method + + +def test_resources_list_includes_layout(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"], value="a"), + html.Div(id="out"), + ] + ) + + dash_duo.start_server(app) + result = _mcp_method(dash_duo.server.url, "resources/list") + + assert "result" in result + uris = [r["uri"] for r in result["result"]["resources"]] + assert "dash://layout" in uris + + +def test_read_layout_resource(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="res-dd", options=["x", "y"], value="x"), + html.Div(id="out"), + ] + ) + + dash_duo.start_server(app) + result = _mcp_method( + dash_duo.server.url, + "resources/read", + {"uri": "dash://layout"}, + ) + + assert "result" in result + layout = json.loads(result["result"]["contents"][0]["text"]) + assert layout["type"] == "Div" + children = layout["props"]["children"] + dd = next( + c for c in children if isinstance(c, dict) and c.get("type") == "Dropdown" + ) + assert dd["props"]["id"] == "res-dd" + assert dd["props"]["options"] == ["x", "y"] diff --git a/tests/integration/mcp/primitives/tools/test_callback_signatures.py b/tests/integration/mcp/primitives/tools/test_callback_signatures.py new file mode 100644 index 0000000000..db325f2046 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_callback_signatures.py @@ -0,0 +1,958 @@ +""" +Integration tests for all Dash callback signature types. + +Each test verifies that: +1. The MCP tool schema accurately reflects the callback's parameters +2. Calling the tool with those parameters produces the expected result + +Assertions are derived from the callback definition, not the implementation. + +See: https://dash.plotly.com/flexible-callback-signatures +""" + +from dash import Dash, Input, Output, State, dcc, html + +from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools + + +def _find_tool(tools, name): + return next(t for t in tools if t["name"] == name) + + +def _get_response(result): + return result["result"]["structuredContent"]["response"] + + +def test_positional_callback(dash_duo): + """Standard positional: Input("fruit", "value") → param named 'value'.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="fruit", options=["apple", "banana"], value="apple"), + html.Div(id="out"), + ] + ) + + # Callback: 1 Input → 1 param named "value" (from function signature) + # Returns string → Output("out", "children") + @app.callback(Output("out", "children"), Input("fruit", "value")) + def show_fruit(value): + return f"Selected: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Selected: apple") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_fruit") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"value"} + assert any(s.get("type") == "string" for s in props["value"]["anyOf"]) + + # Tool description reflects initial state + value_desc = props["value"].get("description", "") + assert "value: 'apple'" in value_desc + assert "options: ['apple', 'banana']" in value_desc + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "apple"}) + response = _get_response(result) + assert response["out"]["children"] == "Selected: apple" + + # MCP tool with different inputs + result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "banana"}) + response = _get_response(result) + assert response["out"]["children"] == "Selected: banana" + + +def test_positional_with_state(dash_duo): + """Positional with State: Input + State both appear as params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp", value="hello"), + html.Div(id="out"), + ] + ) + + # Callback: 1 Input + 1 State → 2 params named "n_clicks" and "value" + @app.callback( + Output("out", "children"), + Input("btn", "n_clicks"), + State("inp", "value"), + ) + def update(n_clicks, value): + return f"Clicked {n_clicks} with {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Clicked None with hello") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "update") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"n_clicks", "value"} + assert any(s.get("type") == "number" for s in props["n_clicks"]["anyOf"]) + + # Tool description reflects initial state + assert "value: 'hello'" in props["value"].get("description", "") + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool( + dash_duo.server.url, "update", {"n_clicks": None, "value": "hello"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Clicked None with hello" + + result = _mcp_call_tool( + dash_duo.server.url, "update", {"n_clicks": 3, "value": "world"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Clicked 3 with world" + + +def test_multi_output_positional(dash_duo): + """Multi-output: returns tuple → both outputs updated in response.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="test"), + html.Div(id="out1"), + html.Div(id="out2"), + ] + ) + + # Callback: 1 Input → 2 Outputs via tuple return + @app.callback( + Output("out1", "children"), + Output("out2", "children"), + Input("inp", "value"), + ) + def split_case(value): + return value.upper(), value.lower() + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out1", "TEST") + dash_duo.wait_for_text_to_equal("#out2", "test") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "split_case") + props = tool["inputSchema"]["properties"] + assert set(props.keys()) == {"value"} + + # Tool description reflects initial state + assert "value: 'test'" in props["value"].get("description", "") + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool(dash_duo.server.url, "split_case", {"value": "test"}) + response = _get_response(result) + assert response["out1"]["children"] == "TEST" + assert response["out2"]["children"] == "test" + + +def test_dict_based_inputs_and_state(dash_duo): + """Dict-based: inputs=dict(trigger=...), state=dict(name=...) → dict keys are param names.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="name-input", value="world"), + html.Div(id="out"), + ] + ) + + # Callback: dict keys "trigger" and "name" become param names + @app.callback( + Output("out", "children"), + inputs=dict(trigger=Input("btn", "n_clicks")), + state=dict(name=State("name-input", "value")), + ) + def greet(trigger, name): + return f"Hello, {name}!" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Hello, world!") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "greet") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"trigger", "name"} + assert any(s.get("type") == "number" for s in props["trigger"]["anyOf"]) + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool( + dash_duo.server.url, "greet", {"trigger": None, "name": "world"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Hello, world!" + + result = _mcp_call_tool( + dash_duo.server.url, "greet", {"trigger": 1, "name": "Dash"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Hello, Dash!" + + +def test_dict_based_outputs(dash_duo): + """Dict-based outputs: output=dict(...) → callback returns dict, both outputs updated.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hello"), + html.Div(id="upper-out"), + html.Div(id="lower-out"), + ] + ) + + # Callback: dict output keys "upper" and "lower" map to components + @app.callback( + output=dict( + upper=Output("upper-out", "children"), + lower=Output("lower-out", "children"), + ), + inputs=dict(val=Input("inp", "value")), + ) + def transform(val): + return dict(upper=val.upper(), lower=val.lower()) + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#upper-out", "HELLO") + dash_duo.wait_for_text_to_equal("#lower-out", "hello") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "transform") + props = tool["inputSchema"]["properties"] + assert set(props.keys()) == {"val"} + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool(dash_duo.server.url, "transform", {"val": "hello"}) + response = _get_response(result) + assert response["upper-out"]["children"] == "HELLO" + assert response["lower-out"]["children"] == "hello" + + +def test_mixed_input_state_in_inputs(dash_duo): + """Mixed: State inside inputs=dict alongside Input → all appear as params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="first", value="Jane"), + dcc.Input(id="last", value="Doe"), + html.Div(id="out"), + ] + ) + + # Callback: Input and State mixed in same dict → all keys are params + @app.callback( + Output("out", "children"), + inputs=dict( + clicks=Input("btn", "n_clicks"), + first=State("first", "value"), + last=State("last", "value"), + ), + ) + def full_name(clicks, first, last): + return f"{first} {last}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Jane Doe") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "full_name") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"clicks", "first", "last"} + assert any(s.get("type") == "number" for s in props["clicks"]["anyOf"]) + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool( + dash_duo.server.url, + "full_name", + {"clicks": None, "first": "Jane", "last": "Doe"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "Jane Doe" + + result = _mcp_call_tool( + dash_duo.server.url, + "full_name", + {"clicks": 1, "first": "John", "last": "Smith"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "John Smith" + + +def test_tuple_grouped_inputs(dash_duo): + """Tuple grouping: pair=(Input("a",...), Input("b",...)) → expands to two named params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="a", value="1"), + dcc.Input(id="b", value="2"), + html.Div(id="out"), + ] + ) + + # Callback: tuple group "pair" maps to 2 deps → 2 params named pair___ + @app.callback( + Output("out", "children"), + inputs=dict(pair=(Input("a", "value"), Input("b", "value"))), + ) + def combine(pair): + return f"{pair[0]}+{pair[1]}" + + dash_duo.start_server(app) + tool = _find_tool(_mcp_tools(dash_duo.server.url), "combine") + props = tool["inputSchema"]["properties"] + + # Tuple expands: one param per dep, named with group prefix + component info + assert set(props.keys()) == {"pair_a__value", "pair_b__value"} + for schema in props.values(): + assert any(s.get("type") == "string" for s in schema["anyOf"]) + + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"pair_a__value": "x", "pair_b__value": "y"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "x+y" + + +def test_initial_values_from_chained_callbacks(dash_duo): + """Querying components reflects post-initial-callback values. + + 3-link chain: country (default "France") → update_states → + state (should become "Ile-de-France") → update_cities → + city (should become "Paris"). + """ + DATA = { + "France": { + "Ile-de-France": ["Paris", "Versailles"], + "Provence": ["Marseille", "Nice"], + }, + "Germany": { + "Bavaria": ["Munich", "Nuremberg"], + "Berlin": ["Berlin"], + }, + } + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=list(DATA.keys()), value="France"), + dcc.Dropdown(id="state"), + dcc.Dropdown(id="city"), + ] + ) + + @app.callback( + Output("state", "options"), + Output("state", "value"), + Input("country", "value"), + ) + def update_states(country): + if not country: + return [], None + states = list(DATA[country].keys()) + return [{"label": s, "value": s} for s in states], states[0] + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("state", "value"), + Input("country", "value"), + ) + def update_cities(state, country): + if not state or not country: + return [], None + cities = DATA[country][state] + return [{"label": c, "value": c} for c in cities], cities[0] + + dash_duo.start_server(app) + + # Tool descriptions should reflect post-initial-callback state + tools = _mcp_tools(dash_duo.server.url) + update_cities_tool = _find_tool(tools, "update_cities") + state_desc = update_cities_tool["inputSchema"]["properties"]["state"].get( + "description", "" + ) + # state.value was set to "Ile-de-France" by update_states initial callback + assert "Ile-de-France" in state_desc + + # state.value should be "Ile-de-France" (first state for France) + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "state", "property": "value"}, + ) + state_props = result["result"]["structuredContent"]["properties"] + assert state_props["value"]["initial_value"] == "Ile-de-France" + + # city.value should be "Paris" (first city for Ile-de-France) + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "city", "property": "value"}, + ) + city_props = result["result"]["structuredContent"]["properties"] + assert city_props["value"]["initial_value"] == "Paris" + + +def test_dict_based_reordered_state_input(dash_duo): + """Dict-based callback with State before Input: call works, schema types correct. + + State is listed before Input in the dict. The callback should still + work correctly via MCP, and the schema types should match the + function annotations (name: str, trigger: int), not be swapped. + """ + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp", value="World"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + inputs=dict(name=State("inp", "value"), trigger=Input("btn", "n_clicks")), + ) + def greet(name: str, trigger: int): + return f"Hello {name}" + + dash_duo.start_server(app) + + # First: verify the callback actually works with these args + result = _mcp_call_tool( + dash_duo.server.url, + "greet", + {"name": "Dash", "trigger": 1}, + ) + assert _get_response(result)["out"]["children"] == "Hello Dash" + + # Second: verify schema types match annotations + tool = _find_tool(_mcp_tools(dash_duo.server.url), "greet") + props = tool["inputSchema"]["properties"] + assert props["trigger"]["type"] == "integer" + assert props["name"]["type"] == "string" + + # Third: verify each param describes the correct component + trigger_desc = props["trigger"].get("description", "") + assert "number of times that this element has been clicked on" in trigger_desc + name_desc = props["name"].get("description", "") + assert "The value of the input" in name_desc + + +def test_pattern_matching_callback(dash_duo): + """Pattern-matching dict IDs: tool works with correct params and results.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="hello"), + dcc.Input(id={"type": "field", "index": 1}, value="world"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input({"type": "field", "index": 0}, "value"), + Input({"type": "field", "index": 1}, "value"), + ) + def combine(first, second): + return f"{first} {second}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "combine") + assert tool is not None + props = tool["inputSchema"]["properties"] + assert "first" in props + assert "second" in props + + # Verify initial output matches what the browser shows + dash_duo.wait_for_text_to_equal("#out", "hello world") + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"first": "hello", "second": "world"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "hello world" + + # Verify with different values + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"first": "foo", "second": "bar"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "foo bar" + + +def test_pattern_matching_with_all_wildcard(dash_duo): + """ALL wildcard: one callback receives values from all matching components.""" + from dash import ALL + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "input", "index": 0}, value="alpha"), + dcc.Input(id={"type": "input", "index": 1}, value="beta"), + html.Div(id="summary"), + ] + ) + + @app.callback( + Output("summary", "children"), + Input({"type": "input", "index": ALL}, "value"), + ) + def summarize(values): + return ", ".join(v for v in values if v) + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#summary", "alpha, beta") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "summarize") + assert tool is not None + + # Schema must describe values as an array of {id, property, value} objects + values_schema = tool["inputSchema"]["properties"]["values"] + assert ( + values_schema["type"] == "array" + ), f"ALL wildcard param should be typed as array, got: {values_schema}" + assert "items" in values_schema, "Array schema should include items definition" + items = values_schema["items"] + assert items["type"] == "object" + assert "id" in items["properties"] + assert "value" in items["properties"] + assert "Pattern-matching input (ALL)" in values_schema.get( + "description", "" + ), "ALL wildcard param description should explain the pattern-matching behavior" + + # MCP tool call with browser-like format: concrete IDs + values + result = _mcp_call_tool( + dash_duo.server.url, + "summarize", + { + "values": [ + { + "id": {"type": "input", "index": 0}, + "property": "value", + "value": "alpha", + }, + { + "id": {"type": "input", "index": 1}, + "property": "value", + "value": "beta", + }, + ] + }, + ) + response = _get_response(result) + assert response["summary"]["children"] == "alpha, beta" + + # Different values + result = _mcp_call_tool( + dash_duo.server.url, + "summarize", + { + "values": [ + { + "id": {"type": "input", "index": 0}, + "property": "value", + "value": "one", + }, + { + "id": {"type": "input", "index": 1}, + "property": "value", + "value": "two", + }, + ] + }, + ) + response = _get_response(result) + assert response["summary"]["children"] == "one, two" + + +def test_pattern_matching_mixed_outputs(dash_duo): + """Mixed outputs: one regular + one ALL wildcard in the same callback.""" + from dash import ALL + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="a"), + dcc.Input(id={"type": "field", "index": 1}, value="b"), + html.Div(id={"type": "echo", "index": 0}), + html.Div(id={"type": "echo", "index": 1}), + html.Div(id="total"), + ] + ) + + @app.callback( + Output({"type": "echo", "index": ALL}, "children"), + Output("total", "children"), + Input({"type": "field", "index": ALL}, "value"), + ) + def echo_and_total(values): + echoes = [f"Echo: {v}" for v in values] + total = f"Total: {len(values)} items" + return echoes, total + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#total", "Total: 2 items") + + result = _mcp_call_tool( + dash_duo.server.url, + "echo_and_total", + { + "values": [ + { + "id": {"type": "field", "index": 0}, + "property": "value", + "value": "x", + }, + { + "id": {"type": "field", "index": 1}, + "property": "value", + "value": "y", + }, + ] + }, + ) + response = _get_response(result) + assert response["total"]["children"] == "Total: 2 items" + + +def test_pattern_matching_with_match_wildcard(dash_duo): + """MATCH wildcard: callback fires per-component with matching index. + + Based on https://dash.plotly.com/pattern-matching-callbacks + """ + from dash import MATCH + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + ["NYC", "MTL", "LA", "TOKYO"], + "NYC", + id={"type": "city-dd", "index": 0}, + ), + html.Div(id={"type": "city-out", "index": 0}), + dcc.Dropdown( + ["NYC", "MTL", "LA", "TOKYO"], + "LA", + id={"type": "city-dd", "index": 1}, + ), + html.Div(id={"type": "city-out", "index": 1}), + ] + ) + + @app.callback( + Output({"type": "city-out", "index": MATCH}, "children"), + Input({"type": "city-dd", "index": MATCH}, "value"), + ) + def show_city(value): + return f"Selected: {value}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") + assert tool is not None + + # Schema describes MATCH input + value_schema = tool["inputSchema"]["properties"]["value"] + assert "Pattern-matching input (MATCH)" in value_schema.get("description", "") + + # Call with concrete ID for index 0 (MATCH takes a single entry, not an array) + result = _mcp_call_tool( + dash_duo.server.url, + "show_city", + { + "value": { + "id": {"type": "city-dd", "index": 0}, + "property": "value", + "value": "MTL", + } + }, + ) + response = _get_response(result) + # Find the output key containing "city-out" (Dash may serialize dict IDs differently) + out_key = next(k for k in response if "city-out" in k) + assert response[out_key]["children"] == "Selected: MTL" + + +def test_pattern_matching_with_allsmaller_wildcard(dash_duo): + """ALLSMALLER wildcard: receives values from components with smaller index. + + Based on https://dash.plotly.com/pattern-matching-callbacks + """ + from dash import MATCH, ALLSMALLER + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + ["France", "Germany", "Japan"], + "France", + id={"type": "country-dd", "index": 0}, + ), + html.Div(id={"type": "country-out", "index": 0}), + dcc.Dropdown( + ["France", "Germany", "Japan"], + "Germany", + id={"type": "country-dd", "index": 1}, + ), + html.Div(id={"type": "country-out", "index": 1}), + dcc.Dropdown( + ["France", "Germany", "Japan"], + "Japan", + id={"type": "country-dd", "index": 2}, + ), + html.Div(id={"type": "country-out", "index": 2}), + ] + ) + + @app.callback( + Output({"type": "country-out", "index": MATCH}, "children"), + Input({"type": "country-dd", "index": MATCH}, "value"), + Input({"type": "country-dd", "index": ALLSMALLER}, "value"), + ) + def show_countries(current, previous): + all_selected = [current] + list(reversed(previous)) + return f"All: {', '.join(all_selected)}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_countries") + assert tool is not None + + # Schema describes both MATCH and ALLSMALLER inputs + props = tool["inputSchema"]["properties"] + assert "Pattern-matching input (MATCH)" in props["current"].get("description", "") + assert "Pattern-matching input (ALLSMALLER)" in props["previous"].get( + "description", "" + ) + + # Call for index 2: MATCH is a single dict, ALLSMALLER is a list + result = _mcp_call_tool( + dash_duo.server.url, + "show_countries", + { + "current": { + "id": {"type": "country-dd", "index": 2}, + "property": "value", + "value": "Japan", + }, + "previous": [ + { + "id": {"type": "country-dd", "index": 0}, + "property": "value", + "value": "France", + }, + { + "id": {"type": "country-dd", "index": 1}, + "property": "value", + "value": "Germany", + }, + ], + }, + ) + response = _get_response(result) + out_key = next(k for k in response if "country-out" in k) + assert response[out_key]["children"] == "All: Japan, Germany, France" + + +def test_prevent_initial_call_uses_layout_default(dash_duo): + """prevent_initial_call=True: initial value stays as the layout default. + + The dropdown has value="original" in the layout. The callback has + prevent_initial_call=True so it doesn't run on page load. The MCP + tool description should show value: 'a' (layout default). + """ + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b", "c"], value="a"), + html.Div(id="out", children="not yet"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("dd", "value"), + prevent_initial_call=True, + ) + def update(val): + return f"Changed to: {val}" + + dash_duo.start_server(app) + # Browser shows layout default — callback hasn't fired + dash_duo.wait_for_text_to_equal("#out", "not yet") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "update") + val_desc = tool["inputSchema"]["properties"]["val"].get("description", "") + + # Tool description reflects layout default, not callback output + assert "value: 'a'" in val_desc + + +def test_initial_callback_overrides_layout_value(dash_duo): + """Initial callback overrides layout value in tool description. + + The city dropdown has value="default-city" in the layout. + update_city runs on page load (no prevent_initial_call) and + sets city.value to "Paris". The MCP tool should show "Paris" + as the default, not "default-city". + """ + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=["France", "Germany"], value="France"), + dcc.Dropdown(id="city", options=[], value="default-city"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("country", "value"), + ) + def update_city(country): + if country == "France": + return [{"label": "Paris", "value": "Paris"}], "Paris" + return [{"label": "Berlin", "value": "Berlin"}], "Berlin" + + @app.callback(Output("out", "children"), Input("city", "value")) + def show_city(city): + return f"City: {city}" + + dash_duo.start_server(app) + # Browser shows "Paris" — the initial callback overrode "default-city" + dash_duo.wait_for_text_to_equal("#out", "City: Paris") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") + city_desc = tool["inputSchema"]["properties"]["city"].get("description", "") + + # Tool description should show the post-initial-callback value + assert "value: 'Paris'" in city_desc + assert "default-city" not in city_desc + + +def test_callback_context_triggered_id(dash_duo): + """Callbacks using dash.ctx.triggered_id work via MCP. + + Based on https://dash.plotly.com/determining-which-callback-input-changed + """ + from dash import ctx + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button("Button 1", id="btn-1"), + html.Button("Button 2", id="btn-2"), + html.Button("Button 3", id="btn-3"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn-1", "n_clicks"), + Input("btn-2", "n_clicks"), + Input("btn-3", "n_clicks"), + ) + def display(btn1, btn2, btn3): + if not ctx.triggered_id: + return "No button clicked yet" + return f"Last clicked: {ctx.triggered_id}" + + dash_duo.start_server(app) + + # Browser initial state: no button clicked + dash_duo.wait_for_text_to_equal("#output", "No button clicked yet") + + # Tool should have all three button params + tool = _find_tool(_mcp_tools(dash_duo.server.url), "display") + props = tool["inputSchema"]["properties"] + assert "btn1" in props + assert "btn2" in props + assert "btn3" in props + + # Click btn-2 via MCP — ctx.triggered_id should be "btn-2" + result = _mcp_call_tool( + dash_duo.server.url, + "display", + {"btn1": None, "btn2": 1, "btn3": None}, + ) + response = _get_response(result) + assert response["output"]["children"] == "Last clicked: btn-2" + + # Click btn-3 via MCP + result = _mcp_call_tool( + dash_duo.server.url, + "display", + {"btn1": None, "btn2": None, "btn3": 5}, + ) + response = _get_response(result) + assert response["output"]["children"] == "Last clicked: btn-3" + + +def test_no_output_callback_does_not_crash_tools_list(dash_duo): + """A callback with no Output should not crash tools/list. + + No-output callbacks use set_props for side effects. They produce + a hash-only output_id with no dot separator. + """ + from dash import set_props + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button("Log", id="log-btn"), + dcc.Dropdown(id="picker", options=["a", "b"], value="a"), + html.Div(id="display"), + ] + ) + + @app.callback(Input("log-btn", "n_clicks"), prevent_initial_call=True) + def log_click(n): + set_props("display", {"children": f"Logged {n} clicks"}) + + @app.callback(Output("display", "children"), Input("picker", "value")) + def show_selection(val): + return f"Selected: {val}" + + dash_duo.start_server(app) + + tools = _mcp_tools(dash_duo.server.url) + tool_names = [t["name"] for t in tools] + + # show_selection should appear as a tool + assert "show_selection" in tool_names + + # log_click has no declared output but uses set_props — still a valid tool + assert "log_click" in tool_names + + # Call log_click — sideUpdate should show the set_props effect + result = _mcp_call_tool( + dash_duo.server.url, + "log_click", + {"n": 3}, + ) + structured = result["result"]["structuredContent"] + assert "sideUpdate" in structured + assert structured["sideUpdate"]["display"]["children"] == "Logged 3 clicks" + + # get_dash_component shows show_selection as modifier (declared output). + # log_click uses set_props which bypasses the declarative graph — + # its effect is only visible via sideUpdate in tool call results. + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "display", "property": "children"}, + ) + prop_info = result["result"]["structuredContent"]["properties"]["children"] + assert "show_selection" in prop_info["modified_by_tool"] diff --git a/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py b/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py new file mode 100644 index 0000000000..4ad00641f8 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py @@ -0,0 +1,128 @@ +"""Integration test for duplicate callback outputs. + +Multiple callbacks can output to the same component.property +when using ``allow_duplicate=True``. The MCP server must handle +this correctly — both callbacks should appear as tools, and +calling either should work. +""" + +from dash import Dash, Input, Output, dcc, html + +from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools + + +def _find_tool(tools, name): + return next((t for t in tools if t["name"] == name), None) + + +def _get_response(result): + return result["result"]["structuredContent"]["response"] + + +def test_duplicate_outputs_both_tools_listed(dash_duo): + """Both callbacks outputting to the same component appear as tools.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + first_tool = _find_tool(tools, "greet_by_first") + last_tool = _find_tool(tools, "greet_by_last") + + assert first_tool is not None, "greet_by_first should be listed" + assert last_tool is not None, "greet_by_last should be listed" + + +def test_duplicate_outputs_both_callable(dash_duo): + """Both callbacks can be called and produce correct results.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + result1 = _mcp_call_tool(dash_duo.server.url, "greet_by_first", {"first": "Alice"}) + assert _get_response(result1)["greeting"]["children"] == "Hello, Alice!" + + result2 = _mcp_call_tool(dash_duo.server.url, "greet_by_last", {"last": "Smith"}) + assert _get_response(result2)["greeting"]["children"] == "Hi, Smith!" + + +def test_duplicate_outputs_find_by_output_returns_primary(dash_duo): + """find_by_output returns the primary (non-duplicate) callback.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + # Query the component — should reflect initial callback (greet_by_first) + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "greeting", "property": "children"}, + ) + structured = result["result"]["structuredContent"] + assert structured["properties"]["children"]["initial_value"] == "Hello, Jane!" diff --git a/tests/integration/mcp/primitives/tools/test_input_schemas.py b/tests/integration/mcp/primitives/tools/test_input_schemas.py new file mode 100644 index 0000000000..6ee3510ddd --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_input_schemas.py @@ -0,0 +1,66 @@ +""" +Integration tests for MCP tool schema generation. + +Starts a real Dash server via ``dash_duo`` and verifies that tools +are generated with correct inputSchema, descriptions, and labels. +""" + +from dash import Dash, Input, Output, dcc, html + +from tests.integration.mcp.conftest import _mcp_tools + + +def test_mcp_tool_with_label_and_date_picker_schema(dash_duo): + """Full assertion on a tool with an html.Label and DatePickerSingle constraints.""" + + # -- Test data: change these to update the test -- + label_text = "Departure Date" + component_id = "dp" + min_date = "2020-01-01" + max_date = "2025-12-31" + default_date = "2024-06-15" + func_name = "select_date" + param_name = "date" # function parameter name + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label(label_text, htmlFor=component_id), + dcc.DatePickerSingle( + id=component_id, + min_date_allowed=min_date, + max_date_allowed=max_date, + date=default_date, + ), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input(component_id, "date")) + def select_date(date): + return f"Selected: {date}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + # Find the callback tool + tool = next(t for t in tools if t["name"] not in ("get_dash_component",)) + + # -- Tool-level fields -- + assert func_name in tool["name"] + + # -- inputSchema structure -- + schema = tool["inputSchema"] + assert schema["type"] == "object" + assert param_name in schema["required"] + assert param_name in schema["properties"] + + # -- Property schema: type + format + description -- + prop = schema["properties"][param_name] + assert prop["type"] == "string" + assert prop["format"] == "date" + + # description includes all source values (label, constraints, default) + desc = prop["description"] + for expected in (label_text, min_date, max_date, default_date): + assert expected in desc, f"Expected {expected!r} in description: {desc!r}" diff --git a/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py b/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py new file mode 100644 index 0000000000..97472a16d7 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py @@ -0,0 +1,54 @@ +"""Integration tests for the get_dash_component tool.""" + +from dash import Dash, dcc, html + +from tests.integration.mcp.conftest import _mcp_call_tool + +EXPECTED_DROPDOWN_OPTIONS = { + "component_id": "my-dropdown", + "component_type": "Dropdown", + "label": None, + "properties": { + "options": { + "initial_value": [ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + "modified_by_tool": [], + "input_to_tool": [], + }, + }, +} + + +def test_query_component_returns_structured_output(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + id="my-dropdown", + options=[ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + value="NYC", + ), + ] + ) + + dash_duo.start_server(app) + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "my-dropdown", "property": "options"}, + ) + + assert "result" in result, f"Expected result in response: {result}" + structured = result["result"]["structuredContent"] + assert structured["component_id"] == EXPECTED_DROPDOWN_OPTIONS["component_id"] + assert structured["component_type"] == EXPECTED_DROPDOWN_OPTIONS["component_type"] + assert ( + structured["properties"]["options"] + == EXPECTED_DROPDOWN_OPTIONS["properties"]["options"] + ) diff --git a/tests/integration/mcp/primitives/tools/test_tools_list.py b/tests/integration/mcp/primitives/tools/test_tools_list.py new file mode 100644 index 0000000000..dc3d977146 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_tools_list.py @@ -0,0 +1,118 @@ +"""Integration tests for tools/list — naming, dedup, and spec compliance.""" + +from dash import Dash, Input, Output, dcc, html + +from tests.integration.mcp.conftest import _mcp_tools + + +def test_tool_names_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + for param_name in tool.get("inputSchema", {}).get("properties", {}): + assert len(param_name) <= 64, f"Param name exceeds 64 chars: {param_name}" + + +def test_long_callback_ids_within_64_chars(dash_duo): + app = Dash(__name__) + long_id = "a" * 120 + app.layout = html.Div( + [ + dcc.Input(id=long_id, value="test"), + html.Div(id=f"{long_id}-output"), + ] + ) + + @app.callback(Output(f"{long_id}-output", "children"), Input(long_id, "value")) + def process(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_pattern_matching_ids_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div( + [ + dcc.Input( + id={"type": "filter-input", "index": i, "category": "primary"}, + value=f"val-{i}", + ) + for i in range(3) + ] + ), + html.Div(id="pm-output"), + ] + ) + + @app.callback( + Output("pm-output", "children"), + Input({"type": "filter-input", "index": 0, "category": "primary"}, "value"), + ) + def filter_update(v0): + return str(v0) + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_duplicate_func_names_produce_unique_tools(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd1", options=["a"], value="a"), + html.Div(id="dd1-output"), + dcc.Dropdown(id="dd2", options=["b"], value="b"), + html.Div(id="dd2-output"), + dcc.Dropdown(id="dd3", options=["c"], value="c"), + html.Div(id="dd3-output"), + ] + ) + + @app.callback(Output("dd1-output", "children"), Input("dd1", "value")) + def cb(value): + return f"first: {value}" + + @app.callback(Output("dd2-output", "children"), Input("dd2", "value")) + def cb(value): # noqa: F811 + return f"second: {value}" + + @app.callback(Output("dd3-output", "children"), Input("dd3", "value")) + def cb(value): # noqa: F811 + return f"third: {value}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + cb_tools = [t for t in tools if t["name"] not in ("get_dash_component",)] + tool_names = [t["name"] for t in cb_tools] + + assert ( + len(tool_names) == 3 + ), f"Expected 3 callback tools, got {len(tool_names)}: {tool_names}" + assert len(set(tool_names)) == 3, f"Tool names not unique: {tool_names}" + + +def test_builtin_tools_always_present(dash_duo): + app = Dash(__name__) + app.layout = html.Div(id="root") + + dash_duo.start_server(app) + tool_names = [t["name"] for t in _mcp_tools(dash_duo.server.url)] + assert "get_dash_component" in tool_names diff --git a/tests/integration/mcp/test_server.py b/tests/integration/mcp/test_server.py new file mode 100644 index 0000000000..7af88bfbff --- /dev/null +++ b/tests/integration/mcp/test_server.py @@ -0,0 +1,304 @@ +"""Integration tests for the MCP Streamable HTTP endpoint. + +These tests use Flask's test_client to exercise the HTTP transport layer +(POST/GET/DELETE at /_mcp), session management, content-type handling, +and route registration/configuration. +""" + +import json +import os + +from dash import Dash, Input, Output, html +from mcp.types import LATEST_PROTOCOL_VERSION + +MCP_PATH = "_mcp" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_app(**kwargs): + """Create a minimal Dash app with a layout and one callback.""" + app = Dash(__name__, **kwargs) + app.layout = html.Div( + [ + html.Div(id="my-input"), + html.Div(id="my-output"), + ] + ) + + @app.callback(Output("my-output", "children"), Input("my-input", "children")) + def update_output(value): + """Test callback docstring.""" + return f"echo: {value}" + + return app + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestMCPEndpoint: + """Tests for the Streamable HTTP MCP endpoint at /_mcp.""" + + def test_post_initialize_creates_session(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert "mcp-session-id" in r.headers + data = json.loads(r.data) + assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + + def test_post_without_session_auto_assigns(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert "mcp-session-id" in r.headers + data = json.loads(r.data) + assert "tools" in data["result"] + + def test_stale_session_error_includes_hint(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + { + "jsonrpc": "2.0", + "method": "tools/call", + "id": 1, + "params": {"name": "no_such_tool", "arguments": {}}, + } + ), + content_type="application/json", + headers={"mcp-session-id": "old-session-from-before-restart"}, + ) + assert r.status_code == 200 + data = json.loads(r.data) + assert "session was not recognised" in data["error"]["message"] + assert "tools/list" in data["error"]["message"] + + def test_post_with_valid_session(self): + app = _make_app() + client = app.server.test_client() + # Initialize to get session + r1 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = r1.headers["mcp-session-id"] + # Use session for tools/list + r2 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 2, "params": {}} + ), + content_type="application/json", + headers={"mcp-session-id": session_id}, + ) + assert r2.status_code == 200 + data = json.loads(r2.data) + assert "result" in data + assert "tools" in data["result"] + + def test_notification_returns_202(self): + app = _make_app() + client = app.server.test_client() + # Initialize to get session + r1 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = r1.headers["mcp-session-id"] + # Send notification (no id field) + r2 = client.post( + f"/{MCP_PATH}", + data=json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized"}), + content_type="application/json", + headers={"mcp-session-id": session_id}, + ) + assert r2.status_code == 202 + + def test_delete_terminates_session(self): + app = _make_app() + client = app.server.test_client() + # Initialize + r1 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = r1.headers["mcp-session-id"] + # Delete + r2 = client.delete( + f"/{MCP_PATH}", + headers={"mcp-session-id": session_id}, + ) + assert r2.status_code == 204 + # Post-delete requests still succeed + r3 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 2, "params": {}} + ), + content_type="application/json", + headers={"mcp-session-id": session_id}, + ) + assert r3.status_code == 200 + + def test_delete_nonexistent_session_returns_404(self): + app = _make_app() + client = app.server.test_client() + r = client.delete( + f"/{MCP_PATH}", + headers={"mcp-session-id": "nonexistent"}, + ) + assert r.status_code == 404 + + def test_get_without_session_returns_404(self): + app = _make_app() + client = app.server.test_client() + r = client.get(f"/{MCP_PATH}") + assert r.status_code == 404 + + def test_get_with_stale_session_returns_404(self): + app = _make_app() + client = app.server.test_client() + r = client.get( + f"/{MCP_PATH}", + headers={"mcp-session-id": "nonexistent"}, + ) + assert r.status_code == 404 + + def test_get_returns_sse_stream(self): + app = _make_app() + client = app.server.test_client() + # First create a session via POST initialize + init = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = init.headers["mcp-session-id"] + # GET with valid session returns SSE stream + r = client.get( + f"/{MCP_PATH}", + headers={"mcp-session-id": session_id}, + ) + assert r.status_code == 200 + assert r.content_type == "text/event-stream" + assert r.headers.get("Cache-Control") == "no-cache" + + def test_post_rejects_wrong_content_type(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data="not json", + content_type="text/plain", + ) + assert r.status_code == 415 + + def test_routes_not_registered_when_disabled(self): + app = _make_app(enable_mcp=False) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + # With MCP disabled, the route doesn't exist — response is HTML, not JSON + assert r.content_type != "application/json" + + def test_routes_respect_pathname_prefix(self): + app = _make_app(routes_pathname_prefix="/app/") + client = app.server.test_client() + + ok = client.post( + f"/app/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert ok.status_code == 200 + + miss = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert miss.status_code == 404 + + def test_enable_mcp_env_var_false(self): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.content_type != "application/json" + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old + + def test_constructor_overrides_env_var(self): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app(enable_mcp=True) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert b"protocolVersion" in r.data + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old diff --git a/tests/unit/mcp/test_server.py b/tests/unit/mcp/test_server.py new file mode 100644 index 0000000000..93238faf19 --- /dev/null +++ b/tests/unit/mcp/test_server.py @@ -0,0 +1,92 @@ +"""Tests for MCP server (_server.py) — JSON-RPC message processing.""" + +from dash._get_app import app_context +from dash.mcp._server import _process_mcp_message +from mcp.types import LATEST_PROTOCOL_VERSION + +from tests.unit.mcp.conftest import _make_app, _setup_mcp + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool(app, tool_name, arguments=None, request_id=1): + return _mcp( + app, "tools/call", {"name": tool_name, "arguments": arguments or {}}, request_id + ) + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + result = _call_tool(app, tool_name, arguments) + structured = result["result"]["structuredContent"] + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +class TestProcessMCPMessage: + def test_initialize(self): + app = _make_app() + result = _mcp(app, "initialize") + + assert result is not None + assert result["id"] == 1 + assert result["jsonrpc"] == "2.0" + assert result["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + assert "serverInfo" in result["result"] + + def test_initialize_advertises_list_changed(self): + app = _make_app() + result = _mcp(app, "initialize") + caps = result["result"]["capabilities"] + assert caps["tools"]["listChanged"] is True + + def test_tools_call(self): + app = _make_app() + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update_output" in t["name"]) + + result = _call_tool(app, tool_name, {"value": "hello"}, request_id=2) + + assert result is not None + assert result["id"] == 2 + assert _call_tool_output(app, tool_name, {"value": "hello"}) == "echo: hello" + + def test_tools_call_unknown_tool_returns_error(self): + app = _make_app() + result = _call_tool(app, "nonexistent_tool") + + assert result is not None + assert "error" in result + assert result["error"]["code"] == -32601 + + def test_unknown_method_returns_error(self): + app = _make_app() + result = _mcp(app, "unknown/method") + + assert result is not None + assert "error" in result + + def test_notification_returns_none(self): + app = _make_app() + data = {"jsonrpc": "2.0", "method": "notifications/initialized"} + with app.server.test_request_context(): + app_context.set(app) + result = _process_mcp_message(data) + assert result is None diff --git a/tests/unit/mcp/tools/test_run_callback.py b/tests/unit/mcp/tools/test_run_callback.py new file mode 100644 index 0000000000..00f4e5b7b1 --- /dev/null +++ b/tests/unit/mcp/tools/test_run_callback.py @@ -0,0 +1,246 @@ +"""Tests for callback dispatch execution via MCP tools.""" + +from dash import Dash, Input, Output, State, dcc, html +from dash.exceptions import PreventUpdate +from dash.mcp._server import _process_mcp_message + +from tests.unit.mcp.conftest import _setup_mcp + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool_structured(app, tool_name, arguments=None): + result = _mcp(app, "tools/call", {"name": tool_name, "arguments": arguments or {}}) + return result["result"]["structuredContent"] + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + structured = _call_tool_structured(app, tool_name, arguments) + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +class TestRunCallback: + def test_multi_output(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + dcc.Dropdown(id="dd2"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("dd2", "options"), + Output("out", "children"), + Input("dd", "value"), + ) + def update(val): + return [{"label": val, "value": val}], f"selected: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + structured = _call_tool_structured(app, tool_name, {"val": "b"}) + assert structured["response"]["dd2"]["options"] == [ + {"label": "b", "value": "b"} + ] + assert structured["response"]["out"]["children"] == "selected: b" + + def test_omitted_kwargs_default_to_none(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"]), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("dd", "value"), + State("inp", "value"), + ) + def update(selected, text): + return f"{selected}-{text}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"selected": "a"}, "out") == "a-None" + + def test_no_output_callback(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + html.Div(id="display"), + ] + ) + + @app.callback(Input("btn", "n_clicks")) + def server_cb(n): + from dash import set_props + + set_props("display", {"children": f"Clicked {n} times"}) + + tools = _tools_list(app) + tool_names = [t["name"] for t in tools] + assert "server_cb" in tool_names + + def test_prevent_update(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hello"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + if val == "block": + raise PreventUpdate + return f"got: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"val": "test"}, "out") == "got: test" + + def test_with_state(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="trigger"), + html.Div(id="store"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input("trigger", "children"), + State("store", "children"), + ) + def with_state(trigger, store): + return f"{trigger}-{store}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + { + "trigger": "click", + "store": "data", + }, + "result", + ) + == "click-data" + ) + + def test_dict_inputs(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="x-input", value="hello"), + dcc.Input(id="y-input", value="world"), + html.Div(id="dict-out"), + ] + ) + + @app.callback( + Output("dict-out", "children"), + inputs={ + "x_val": Input("x-input", "value"), + "y_val": Input("y-input", "value"), + }, + ) + def combine(**kwargs): + return f"{kwargs['x_val']}-{kwargs['y_val']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "combine" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + { + "x_val": "foo", + "y_val": "bar", + }, + "dict-out", + ) + == "foo-bar" + ) + + def test_positional_inputs(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="a-input", value="A"), + html.Div(id="pos-out"), + ] + ) + + @app.callback(Output("pos-out", "children"), Input("a-input", "value")) + def echo(val): + return f"got:{val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "echo" in t["name"]) + assert ( + _call_tool_output(app, tool_name, {"val": "test"}, "pos-out") == "got:test" + ) + + def test_dict_inputs_with_state(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hi"), + html.Div(id="st", children="state-val"), + html.Div(id="ds-out"), + ] + ) + + @app.callback( + Output("ds-out", "children"), + inputs={"trigger": Input("inp", "value")}, + state={"kept": State("st", "children")}, + ) + def with_dict_state(**kwargs): + return f"{kwargs['trigger']}+{kwargs['kept']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_dict_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + { + "trigger": "hey", + "kept": "saved", + }, + "ds-out", + ) + == "hey+saved" + ) From 001a7b7f815b9aa3cedbe4ec311791453fc878b8 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 9 Apr 2026 09:40:45 -0600 Subject: [PATCH 52/80] Enforce session management per MCP spec (404 for unknown sessions, 400 for missing session) --- dash/mcp/_server.py | 17 ++++++----- tests/integration/mcp/test_server.py | 43 ++++++---------------------- 2 files changed, 18 insertions(+), 42 deletions(-) diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 1c6279290b..95d00578a2 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -116,23 +116,26 @@ def _handle_post() -> Response: request_id = data.get("id") session_id = request.headers.get("mcp-session-id") - stale_session = False if method == "initialize": session_id = _create_session() elif session_id and session_id not in sessions: - stale_session = True - sessions[session_id] = {} + return Response( + json.dumps({"error": "Session not found. Please reinitialize."}), + content_type="application/json", + status=404, + ) elif not session_id: - session_id = _create_session() + return Response( + json.dumps({"error": "Missing session ID. Send an initialize request first."}), + content_type="application/json", + status=400, + ) response_data = _process_mcp_message(data) if response_data is None: return Response("", status=202) - if stale_session: - _inject_warning(response_data, _STALE_SESSION_WARNING) - return Response( json.dumps(response_data), content_type="application/json", diff --git a/tests/integration/mcp/test_server.py b/tests/integration/mcp/test_server.py index 7af88bfbff..8917d0f5ab 100644 --- a/tests/integration/mcp/test_server.py +++ b/tests/integration/mcp/test_server.py @@ -60,7 +60,7 @@ def test_post_initialize_creates_session(self): data = json.loads(r.data) assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION - def test_post_without_session_auto_assigns(self): + def test_post_without_session_returns_400(self): app = _make_app() client = app.server.test_client() r = client.post( @@ -70,12 +70,9 @@ def test_post_without_session_auto_assigns(self): ), content_type="application/json", ) - assert r.status_code == 200 - assert "mcp-session-id" in r.headers - data = json.loads(r.data) - assert "tools" in data["result"] + assert r.status_code == 400 - def test_stale_session_error_includes_hint(self): + def test_stale_session_returns_404(self): app = _make_app() client = app.server.test_client() r = client.post( @@ -83,18 +80,15 @@ def test_stale_session_error_includes_hint(self): data=json.dumps( { "jsonrpc": "2.0", - "method": "tools/call", + "method": "tools/list", "id": 1, - "params": {"name": "no_such_tool", "arguments": {}}, + "params": {}, } ), content_type="application/json", headers={"mcp-session-id": "old-session-from-before-restart"}, ) - assert r.status_code == 200 - data = json.loads(r.data) - assert "session was not recognised" in data["error"]["message"] - assert "tools/list" in data["error"]["message"] + assert r.status_code == 404 def test_post_with_valid_session(self): app = _make_app() @@ -161,7 +155,7 @@ def test_delete_terminates_session(self): headers={"mcp-session-id": session_id}, ) assert r2.status_code == 204 - # Post-delete requests still succeed + # Post-delete requests return 404 r3 = client.post( f"/{MCP_PATH}", data=json.dumps( @@ -170,7 +164,7 @@ def test_delete_terminates_session(self): content_type="application/json", headers={"mcp-session-id": session_id}, ) - assert r3.status_code == 200 + assert r3.status_code == 404 def test_delete_nonexistent_session_returns_404(self): app = _make_app() @@ -196,27 +190,6 @@ def test_get_with_stale_session_returns_404(self): ) assert r.status_code == 404 - def test_get_returns_sse_stream(self): - app = _make_app() - client = app.server.test_client() - # First create a session via POST initialize - init = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - session_id = init.headers["mcp-session-id"] - # GET with valid session returns SSE stream - r = client.get( - f"/{MCP_PATH}", - headers={"mcp-session-id": session_id}, - ) - assert r.status_code == 200 - assert r.content_type == "text/event-stream" - assert r.headers.get("Cache-Control") == "no-cache" - def test_post_rejects_wrong_content_type(self): app = _make_app() client = app.server.test_client() From 3fe035bfdc5883d71f424fed112cf867b37a6b27 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 15 Apr 2026 16:55:23 -0600 Subject: [PATCH 53/80] remove unused code --- dash/mcp/_server.py | 33 +++------------------------------ 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 95d00578a2..1060d3b27f 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -126,7 +126,9 @@ def _handle_post() -> Response: ) elif not session_id: return Response( - json.dumps({"error": "Missing session ID. Send an initialize request first."}), + json.dumps( + {"error": "Missing session ID. Send an initialize request first."} + ), content_type="application/json", status=400, ) @@ -175,35 +177,6 @@ def _handle_delete() -> Response: ) -_STALE_SESSION_WARNING = ( - "[Warning: your session was not recognised" - " — the app may have restarted." - " Please call tools/list to refresh your tool list." - " Please ask the user to reconnect to the MCP server.]" -) - - -def _inject_warning(response_data: dict[str, Any], warning: str) -> None: - """Append a warning to a JSON-RPC response dict. - - For successful ``tools/call`` responses the warning is added as an - extra text content block so the agent sees it alongside the result. - For error responses the warning is appended to the error message. - Other responses (tools/list, resources/*) are left unchanged — the - JSON-RPC spec forbids extra top-level keys. - """ - # tools/call success: result has a "content" list - result = response_data.get("result") - if isinstance(result, dict) and isinstance(result.get("content"), list): - result["content"].append({"type": "text", "text": warning}) - return - - # Error response - error = response_data.get("error") - if isinstance(error, dict) and "message" in error: - error["message"] += " " + warning - - def _handle_initialize() -> InitializeResult: return InitializeResult( protocolVersion=LATEST_PROTOCOL_VERSION, From 3b4494496d10b2de874ad5cdd7ed41b290fa12af Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 16 Apr 2026 18:31:33 -0600 Subject: [PATCH 54/80] Fix types for mypy --- dash/mcp/__init__.py | 2 +- dash/mcp/_server.py | 3 +- dash/mcp/primitives/__init__.py | 10 ++-- .../resource_clientside_callbacks.py | 5 +- .../resources/resource_components.py | 7 +-- .../primitives/resources/resource_layout.py | 5 +- .../resources/resource_page_layout.py | 3 +- .../primitives/resources/resource_pages.py | 5 +- dash/mcp/primitives/tools/callback_adapter.py | 47 +++++++++++-------- .../description_pattern_matching.py | 2 +- .../input_schemas/schema_pattern_matching.py | 5 +- dash/mcp/primitives/tools/results/__init__.py | 2 +- .../tools/tool_get_dash_component.py | 2 +- dash/types.py | 9 ++-- 14 files changed, 63 insertions(+), 44 deletions(-) diff --git a/dash/mcp/__init__.py b/dash/mcp/__init__.py index 2677ea141b..2bc4757f13 100644 --- a/dash/mcp/__init__.py +++ b/dash/mcp/__init__.py @@ -3,5 +3,5 @@ from dash.mcp._server import enable_mcp_server __all__ = [ - enable_mcp_server, + "enable_mcp_server", ] diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 1060d3b27f..24bbef4aeb 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -204,7 +204,8 @@ def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: """ method = data.get("method", "") params = data.get("params", {}) or {} - request_id = data.get("id") + _id = data.get("id") + request_id: str | int = _id if isinstance(_id, (str, int)) else "" app = get_app() if not hasattr(app, "mcp_callback_map"): diff --git a/dash/mcp/primitives/__init__.py b/dash/mcp/primitives/__init__.py index b14839f1e1..e6b46a9af3 100644 --- a/dash/mcp/primitives/__init__.py +++ b/dash/mcp/primitives/__init__.py @@ -9,9 +9,9 @@ ) __all__ = [ - call_tool, - list_resources, - list_resource_templates, - list_tools, - read_resource, + "call_tool", + "list_resources", + "list_resource_templates", + "list_tools", + "read_resource", ] diff --git a/dash/mcp/primitives/resources/resource_clientside_callbacks.py b/dash/mcp/primitives/resources/resource_clientside_callbacks.py index 127c0f9adc..a8c0a0076a 100644 --- a/dash/mcp/primitives/resources/resource_clientside_callbacks.py +++ b/dash/mcp/primitives/resources/resource_clientside_callbacks.py @@ -10,6 +10,7 @@ Resource, TextResourceContents, ) +from pydantic import AnyUrl from dash import get_app from dash._utils import clean_property_name, split_callback_id @@ -25,7 +26,7 @@ def get_resource(cls) -> Resource | None: if not _get_clientside_callbacks(): return None return Resource( - uri=cls.uri, + uri=AnyUrl(cls.uri), name="dash_clientside_callbacks", description=( "Actions the user can take manually in the browser " @@ -52,7 +53,7 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: return ReadResourceResult( contents=[ TextResourceContents( - uri=cls.uri, + uri=AnyUrl(cls.uri), mimeType="application/json", text=json.dumps(data, default=str), ) diff --git a/dash/mcp/primitives/resources/resource_components.py b/dash/mcp/primitives/resources/resource_components.py index 9d035a855f..1f80c8bda2 100644 --- a/dash/mcp/primitives/resources/resource_components.py +++ b/dash/mcp/primitives/resources/resource_components.py @@ -9,6 +9,7 @@ Resource, TextResourceContents, ) +from pydantic import AnyUrl from dash import get_app from dash._layout_utils import traverse @@ -22,7 +23,7 @@ class ComponentsResource(MCPResourceProvider): @classmethod def get_resource(cls) -> Resource | None: return Resource( - uri=cls.uri, + uri=AnyUrl(cls.uri), name="dash_components", description=( "All components with IDs in the app layout. " @@ -41,7 +42,7 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: components = sorted( [ { - "id": str(comp.id), + "id": str(getattr(comp, "id", None)), "type": getattr(comp, "_type", type(comp).__name__), } for comp, _ in traverse(layout) @@ -53,7 +54,7 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: return ReadResourceResult( contents=[ TextResourceContents( - uri=cls.uri, + uri=AnyUrl(cls.uri), mimeType="application/json", text=json.dumps(components), ) diff --git a/dash/mcp/primitives/resources/resource_layout.py b/dash/mcp/primitives/resources/resource_layout.py index 753e2b9229..7659d1fd8f 100644 --- a/dash/mcp/primitives/resources/resource_layout.py +++ b/dash/mcp/primitives/resources/resource_layout.py @@ -7,6 +7,7 @@ Resource, TextResourceContents, ) +from pydantic import AnyUrl from dash import get_app from dash._utils import to_json @@ -20,7 +21,7 @@ class LayoutResource(MCPResourceProvider): @classmethod def get_resource(cls) -> Resource | None: return Resource( - uri=cls.uri, + uri=AnyUrl(cls.uri), name="dash_app_layout", description=( "Full component tree of the Dash app. " @@ -35,7 +36,7 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: return ReadResourceResult( contents=[ TextResourceContents( - uri=cls.uri, + uri=AnyUrl(cls.uri), mimeType="application/json", text=to_json(app.get_layout()), ) diff --git a/dash/mcp/primitives/resources/resource_page_layout.py b/dash/mcp/primitives/resources/resource_page_layout.py index 613f0b41b9..bbfca411bc 100644 --- a/dash/mcp/primitives/resources/resource_page_layout.py +++ b/dash/mcp/primitives/resources/resource_page_layout.py @@ -7,6 +7,7 @@ ResourceTemplate, TextResourceContents, ) +from pydantic import AnyUrl from dash import html from dash._pages import PAGE_REGISTRY @@ -55,7 +56,7 @@ def read_resource(cls, uri: str) -> ReadResourceResult: return ReadResourceResult( contents=[ TextResourceContents( - uri=uri, + uri=AnyUrl(uri), mimeType="application/json", text=to_json(page_layout), ) diff --git a/dash/mcp/primitives/resources/resource_pages.py b/dash/mcp/primitives/resources/resource_pages.py index 27c39013f3..21fa27679f 100644 --- a/dash/mcp/primitives/resources/resource_pages.py +++ b/dash/mcp/primitives/resources/resource_pages.py @@ -9,6 +9,7 @@ Resource, TextResourceContents, ) +from pydantic import AnyUrl from dash._pages import PAGE_REGISTRY @@ -23,7 +24,7 @@ def get_resource(cls) -> Resource | None: if not PAGE_REGISTRY: return None return Resource( - uri=cls.uri, + uri=AnyUrl(cls.uri), name="dash_app_pages", description=( "List of all pages in this multi-page Dash app " @@ -51,7 +52,7 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: return ReadResourceResult( contents=[ TextResourceContents( - uri=cls.uri, + uri=AnyUrl(cls.uri), mimeType="application/json", text=json.dumps(pages, default=str), ) diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index c94ba32f38..3000541a06 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -10,7 +10,7 @@ import json import typing from functools import cached_property -from typing import Any +from typing import Any, cast from mcp.types import Tool @@ -27,6 +27,7 @@ CallbackDependency, CallbackExecutionBody, CallbackInput, + CallbackInputs, CallbackOutput, CallbackOutputTarget, WildcardId, @@ -361,9 +362,7 @@ def _param_annotations(self) -> list[Any | None]: return [hints.get(func_name) for func_name, _ in self._dep_param_map] -def _expand_dep( - dep: CallbackDependency, value: Any -) -> CallbackInput | list[CallbackInput]: +def _expand_dep(dep: CallbackDependency, value: Any) -> CallbackInputs: """Attach a concrete value to a callback dependency to produce a valid callback input. For regular deps, returns ``{id, property, value}``. @@ -372,20 +371,20 @@ def _expand_dep( """ pattern = parse_wildcard_id(dep.get("id", "")) if pattern is None: - return {**dep, "value": value} + return CallbackInput(id=dep["id"], property=dep["property"], value=value) # LLM provides browser-like format if isinstance(value, list): - return value + return cast(list[CallbackInput], value) if isinstance(value, dict) and "id" in value: - return value - return {**dep, "value": value} + return cast(CallbackInput, value) + return CallbackInput(id=dep["id"], property=dep["property"], value=value) def _expand_output_spec( output_id: str, cb_info: dict, - resolved_inputs: list[CallbackInput], + resolved_inputs: list[CallbackInputs], ) -> CallbackOutputTarget | list[CallbackOutputTarget]: """Build the outputs spec, expanding wildcards to concrete IDs. @@ -408,15 +407,19 @@ def _expand_output_spec( if pattern is not None: concrete_ids = _derive_output_ids(pattern, resolved_inputs) if not concrete_ids: - concrete_ids = [comp.id for comp in find_matching_components(pattern)] - expanded = [{"id": cid, "property": prop} for cid in concrete_ids] + concrete_ids = [ + getattr(comp, "id") for comp in find_matching_components(pattern) + ] + expanded: list[CallbackDependency] = [ + CallbackDependency(id=cid, property=prop) for cid in concrete_ids + ] # ALL/ALLSMALLER → nested list; MATCH → single dict if len(expanded) == 1: results.append(expanded[0]) else: results.append(expanded) else: - results.append({"id": pid, "property": prop}) + results.append(CallbackDependency(id=pid, property=prop)) # Mirror the Dash renderer: single-output callbacks send a bare dict, # multi-output callbacks send a list. The framework's output value @@ -428,7 +431,7 @@ def _expand_output_spec( def _derive_output_ids( output_pattern: WildcardId, - resolved_inputs: list[CallbackInput], + resolved_inputs: list[CallbackInputs], ) -> list[WildcardId] | None: """Derive concrete output IDs from the resolved input entries. @@ -457,15 +460,19 @@ def _substitute(item_id: WildcardId) -> WildcardId | None: if isinstance(entry, list) and entry: concrete_ids = [] for item in entry: - out = _substitute(item.get("id")) - if out: - concrete_ids.append(out) + item_id = item.get("id") + if isinstance(item_id, dict): + out = _substitute(item_id) + if out: + concrete_ids.append(out) if concrete_ids: return concrete_ids # MATCH: single {id, property, value} dict - elif isinstance(entry, dict) and isinstance(entry.get("id"), dict): - out = _substitute(entry["id"]) - if out: - return [out] + elif isinstance(entry, dict): + entry_id = entry.get("id") + if isinstance(entry_id, dict): + out = _substitute(entry_id) + if out: + return [out] return None diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py index 221423aa50..d9a1a5a26a 100644 --- a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py @@ -23,7 +23,7 @@ def describe(cls, param: MCPInput) -> list[str]: return [] wildcard_key, wildcard_type = _find_wildcard(dep_id) - if wildcard_key is None: + if wildcard_key is None or wildcard_type is None: return [] non_wildcard = {k: v for k, v in dep_id.items() if k != wildcard_key} diff --git a/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py index 52e16cf58b..093dc197b8 100644 --- a/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py +++ b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py @@ -66,7 +66,10 @@ def _get_wildcard_type(dep_id: dict) -> str | None: def _infer_value_schema(param: MCPInput) -> dict[str, Any] | None: """Infer the JSON Schema for the ``value`` field from a matching component.""" - matches = find_matching_components(parse_wildcard_id(param["component_id"])) + pattern = parse_wildcard_id(param["component_id"]) + if pattern is None: + return None + matches = find_matching_components(pattern) if not matches: return None diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py index ae3517919c..09e86410a7 100644 --- a/dash/mcp/primitives/tools/results/__init__.py +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -48,5 +48,5 @@ def format_callback_response( return CallToolResult( content=content, - structuredContent=response, + structuredContent=dict(response), ) diff --git a/dash/mcp/primitives/tools/tool_get_dash_component.py b/dash/mcp/primitives/tools/tool_get_dash_component.py index 5dffc2cf58..8c131c4288 100644 --- a/dash/mcp/primitives/tools/tool_get_dash_component.py +++ b/dash/mcp/primitives/tools/tool_get_dash_component.py @@ -125,5 +125,5 @@ def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: content=[ TextContent(type="text", text=json.dumps(structured, default=str)) ], - structuredContent=structured, + structuredContent=dict(structured), ) diff --git a/dash/types.py b/dash/types.py index cbc94b8151..9da246b16c 100644 --- a/dash/types.py +++ b/dash/types.py @@ -71,11 +71,14 @@ class CallbackInput(TypedDict): value: Any +CallbackInputs = Union[CallbackInput, List[CallbackInput]] + + class CallbackExecutionBody(TypedDict): output: str - outputs: List[CallbackOutputTarget] - inputs: List[CallbackInput] - state: List[CallbackInput] + outputs: Union[CallbackOutputTarget, List[CallbackOutputTarget]] + inputs: List[CallbackInputs] + state: List[CallbackInputs] changedPropIds: List[str] From 9dc1f104140c55e2e8eb3d63bce3e267c8665093 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 20 Apr 2026 15:30:16 -0600 Subject: [PATCH 55/80] Remove unused SSE code for initial MCP implementation --- dash/mcp/_server.py | 87 +++---------- dash/mcp/_sse.py | 67 ---------- dash/mcp/notifications/__init__.py | 7 -- .../notification_tools_changed.py | 30 ----- tests/integration/mcp/conftest.py | 22 +--- tests/integration/mcp/test_server.py | 116 ++---------------- tests/unit/mcp/test_server.py | 6 - 7 files changed, 30 insertions(+), 305 deletions(-) delete mode 100644 dash/mcp/_sse.py delete mode 100644 dash/mcp/notifications/__init__.py delete mode 100644 dash/mcp/notifications/notification_tools_changed.py diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 24bbef4aeb..c00ec6c398 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -2,10 +2,8 @@ from __future__ import annotations -import atexit import json import logging -import uuid from typing import TYPE_CHECKING, Any from flask import Response, request @@ -30,11 +28,6 @@ ) from dash.version import __version__ -from dash.mcp._sse import ( - close_sse_stream, - create_sse_stream, - shutdown_all_streams, -) from dash.mcp.primitives import ( call_tool, list_resource_templates, @@ -50,25 +43,7 @@ def enable_mcp_server(app: Dash, mcp_path: str) -> None: - """ - Add MCP routes to a Dash/Flask app. - - Registers a single Streamable HTTP endpoint for the MCP protocol. - Uses ``app._add_url()`` so that ``routes_pathname_prefix`` is applied - automatically. - - Args: - app: The Dash application instance. - mcp_path: Route prefix for MCP endpoints. - """ - # Session storage: session_id -> metadata - sessions: dict[str, dict[str, Any]] = {} - - def _create_session() -> str: - sid = str(uuid.uuid4()) - sessions[sid] = {} - return sid - + """Add MCP routes to a Dash/Flask app.""" # -- Streamable HTTP endpoint -------------------------------------------- def mcp_handler() -> Response: @@ -85,14 +60,13 @@ def mcp_handler() -> Response: ) def _handle_get() -> Response: - session_id = request.headers.get("mcp-session-id") - if not session_id or session_id not in sessions: - return Response( - json.dumps({"error": "Session not found"}), - content_type="application/json", - status=404, - ) - return create_sse_stream(sessions, session_id) + # MCP spec allows servers to opt out of GET-initiated SSE streams + # by returning 405. We don't push server-initiated events. + return Response( + json.dumps({"error": "Method not allowed"}), + content_type="application/json", + status=405, + ) def _handle_post() -> Response: content_type = request.content_type or "" @@ -112,27 +86,6 @@ def _handle_post() -> Response: status=400, ) - method = data.get("method", "") - request_id = data.get("id") - session_id = request.headers.get("mcp-session-id") - - if method == "initialize": - session_id = _create_session() - elif session_id and session_id not in sessions: - return Response( - json.dumps({"error": "Session not found. Please reinitialize."}), - content_type="application/json", - status=404, - ) - elif not session_id: - return Response( - json.dumps( - {"error": "Missing session ID. Send an initialize request first."} - ), - content_type="application/json", - status=400, - ) - response_data = _process_mcp_message(data) if response_data is None: @@ -142,21 +95,15 @@ def _handle_post() -> Response: json.dumps(response_data), content_type="application/json", status=200, - headers={"mcp-session-id": session_id}, ) def _handle_delete() -> Response: - session_id = request.headers.get("mcp-session-id") - if not session_id or session_id not in sessions: - return Response( - json.dumps({"error": "Session not found"}), - content_type="application/json", - status=404, - ) - close_sse_stream(sessions[session_id]) - del sessions[session_id] - logger.info("MCP session terminated: %s", session_id) - return Response("", status=204) + # No sessions to terminate — server is stateless. + return Response( + json.dumps({"error": "Method not allowed"}), + content_type="application/json", + status=405, + ) # -- Register routes ----------------------------------------------------- @@ -166,10 +113,6 @@ def _handle_delete() -> Response: mcp_path, with_app_context_factory(mcp_handler, app), ["GET", "POST", "DELETE"] ) - # Close all SSE streams on server shutdown so MCP clients see a - # clean stream end and can reconnect promptly. - atexit.register(shutdown_all_streams, sessions) - logger.info( "MCP routes registered at %s%s", app.config.routes_pathname_prefix, @@ -181,7 +124,7 @@ def _handle_initialize() -> InitializeResult: return InitializeResult( protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=ServerCapabilities( - tools=ToolsCapability(listChanged=True), + tools=ToolsCapability(listChanged=False), resources=ResourcesCapability(), ), serverInfo=Implementation(name="Plotly Dash", version=__version__), diff --git a/dash/mcp/_sse.py b/dash/mcp/_sse.py deleted file mode 100644 index 4928dc68b2..0000000000 --- a/dash/mcp/_sse.py +++ /dev/null @@ -1,67 +0,0 @@ -"""SSE stream generation and queue management.""" - -from __future__ import annotations - -import queue -from typing import Any - -from flask import Response - - -def create_sse_stream(sessions: dict[str, dict[str, Any]], session_id: str) -> Response: - """Create a Server-Sent Events stream for the given session. - - Stores a :class:`queue.Queue` in ``sessions[session_id]["sse_queue"]`` - and returns a Flask streaming ``Response``. The generator yields - events pushed to the queue, with keepalive comments every 30 seconds. - """ - event_queue: queue.Queue[str | None] = queue.Queue() - # Replace any prior SSE queue for this session (client reconnect). - sessions[session_id]["sse_queue"] = event_queue - - def _generate(): - try: - while True: - try: - event = event_queue.get(timeout=30) - if event is None: - return # Sentinel: server closing stream - yield f"event: message\ndata: {event}\n\n" - except queue.Empty: - yield ": keepalive\n\n" - except GeneratorExit: - pass - finally: - # Clean up queue reference if it's still ours. - if sessions.get(session_id, {}).get("sse_queue") is event_queue: - sessions[session_id].pop("sse_queue", None) - - return Response( - _generate(), - content_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "mcp-session-id": session_id, - }, - ) - - -def close_sse_stream(session_data: dict[str, Any]) -> None: - """Send a sentinel to shut down the session's SSE stream cleanly.""" - sse_queue = session_data.get("sse_queue") - if sse_queue is not None: - try: - sse_queue.put_nowait(None) - except queue.Full: - pass - - -def shutdown_all_streams(sessions: dict[str, dict[str, Any]]) -> None: - """Close all active SSE streams. - - Called during server shutdown (via ``atexit``) so that connected - MCP clients see a clean stream end and can reconnect promptly. - """ - for session_data in list(sessions.values()): - close_sse_stream(session_data) diff --git a/dash/mcp/notifications/__init__.py b/dash/mcp/notifications/__init__.py deleted file mode 100644 index b1fe9e8665..0000000000 --- a/dash/mcp/notifications/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Server-initiated MCP notifications.""" - -from .notification_tools_changed import broadcast_tools_changed - -__all__ = [ - "broadcast_tools_changed", -] diff --git a/dash/mcp/notifications/notification_tools_changed.py b/dash/mcp/notifications/notification_tools_changed.py deleted file mode 100644 index 1970667d1a..0000000000 --- a/dash/mcp/notifications/notification_tools_changed.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Tool list change notifications.""" - -from __future__ import annotations - -import json -import queue -from typing import Any - - -def broadcast_tools_changed( - sessions: dict[str, dict[str, Any]], -) -> None: - """Push a tools/list_changed notification to all active SSE streams. - - Not called automatically yet — available for future hot-reload - or dynamic callback registration. - """ - notification = json.dumps( - { - "jsonrpc": "2.0", - "method": "notifications/tools/list_changed", - } - ) - for data in sessions.values(): - sse_queue = data.get("sse_queue") - if sse_queue is not None: - try: - sse_queue.put_nowait(notification) - except queue.Full: - pass diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py index 0f212d1763..b833bc1dea 100644 --- a/tests/integration/mcp/conftest.py +++ b/tests/integration/mcp/conftest.py @@ -3,10 +3,7 @@ import requests -def _mcp_post(server_url, method, params=None, session_id=None, request_id=1): - headers = {"Content-Type": "application/json"} - if session_id: - headers["mcp-session-id"] = session_id +def _mcp_post(server_url, method, params=None, request_id=1): return requests.post( f"{server_url}/_mcp", json={ @@ -15,39 +12,28 @@ def _mcp_post(server_url, method, params=None, session_id=None, request_id=1): "id": request_id, "params": params or {}, }, - headers=headers, + headers={"Content-Type": "application/json"}, timeout=5, ) -def _mcp_session(server_url): - resp = _mcp_post(server_url, "initialize") - resp.raise_for_status() - return resp.headers["mcp-session-id"] - - def _mcp_tools(server_url): - sid = _mcp_session(server_url) - resp = _mcp_post(server_url, "tools/list", session_id=sid, request_id=2) + resp = _mcp_post(server_url, "tools/list") resp.raise_for_status() return resp.json()["result"]["tools"] def _mcp_call_tool(server_url, tool_name, arguments=None): - sid = _mcp_session(server_url) resp = _mcp_post( server_url, "tools/call", {"name": tool_name, "arguments": arguments or {}}, - session_id=sid, - request_id=2, ) resp.raise_for_status() return resp.json() def _mcp_method(server_url, method, params=None): - sid = _mcp_session(server_url) - resp = _mcp_post(server_url, method, params, session_id=sid, request_id=2) + resp = _mcp_post(server_url, method, params) resp.raise_for_status() return resp.json() diff --git a/tests/integration/mcp/test_server.py b/tests/integration/mcp/test_server.py index 8917d0f5ab..4f0d0fca00 100644 --- a/tests/integration/mcp/test_server.py +++ b/tests/integration/mcp/test_server.py @@ -45,7 +45,7 @@ def update_output(value): class TestMCPEndpoint: """Tests for the Streamable HTTP MCP endpoint at /_mcp.""" - def test_post_initialize_creates_session(self): + def test_post_initialize_returns_protocol_version(self): app = _make_app() client = app.server.test_client() r = client.post( @@ -56,11 +56,10 @@ def test_post_initialize_creates_session(self): content_type="application/json", ) assert r.status_code == 200 - assert "mcp-session-id" in r.headers data = json.loads(r.data) assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION - def test_post_without_session_returns_400(self): + def test_post_tools_list(self): app = _make_app() client = app.server.test_client() r = client.post( @@ -70,125 +69,32 @@ def test_post_without_session_returns_400(self): ), content_type="application/json", ) - assert r.status_code == 400 - - def test_stale_session_returns_404(self): - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - { - "jsonrpc": "2.0", - "method": "tools/list", - "id": 1, - "params": {}, - } - ), - content_type="application/json", - headers={"mcp-session-id": "old-session-from-before-restart"}, - ) - assert r.status_code == 404 - - def test_post_with_valid_session(self): - app = _make_app() - client = app.server.test_client() - # Initialize to get session - r1 = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - session_id = r1.headers["mcp-session-id"] - # Use session for tools/list - r2 = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "tools/list", "id": 2, "params": {}} - ), - content_type="application/json", - headers={"mcp-session-id": session_id}, - ) - assert r2.status_code == 200 - data = json.loads(r2.data) + assert r.status_code == 200 + data = json.loads(r.data) assert "result" in data assert "tools" in data["result"] def test_notification_returns_202(self): app = _make_app() client = app.server.test_client() - # Initialize to get session - r1 = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - session_id = r1.headers["mcp-session-id"] - # Send notification (no id field) - r2 = client.post( + r = client.post( f"/{MCP_PATH}", data=json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized"}), content_type="application/json", - headers={"mcp-session-id": session_id}, ) - assert r2.status_code == 202 + assert r.status_code == 202 - def test_delete_terminates_session(self): + def test_delete_returns_405(self): app = _make_app() client = app.server.test_client() - # Initialize - r1 = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - session_id = r1.headers["mcp-session-id"] - # Delete - r2 = client.delete( - f"/{MCP_PATH}", - headers={"mcp-session-id": session_id}, - ) - assert r2.status_code == 204 - # Post-delete requests return 404 - r3 = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "tools/list", "id": 2, "params": {}} - ), - content_type="application/json", - headers={"mcp-session-id": session_id}, - ) - assert r3.status_code == 404 + r = client.delete(f"/{MCP_PATH}") + assert r.status_code == 405 - def test_delete_nonexistent_session_returns_404(self): - app = _make_app() - client = app.server.test_client() - r = client.delete( - f"/{MCP_PATH}", - headers={"mcp-session-id": "nonexistent"}, - ) - assert r.status_code == 404 - - def test_get_without_session_returns_404(self): + def test_get_returns_405(self): app = _make_app() client = app.server.test_client() r = client.get(f"/{MCP_PATH}") - assert r.status_code == 404 - - def test_get_with_stale_session_returns_404(self): - app = _make_app() - client = app.server.test_client() - r = client.get( - f"/{MCP_PATH}", - headers={"mcp-session-id": "nonexistent"}, - ) - assert r.status_code == 404 + assert r.status_code == 405 def test_post_rejects_wrong_content_type(self): app = _make_app() diff --git a/tests/unit/mcp/test_server.py b/tests/unit/mcp/test_server.py index 93238faf19..23c99c50ad 100644 --- a/tests/unit/mcp/test_server.py +++ b/tests/unit/mcp/test_server.py @@ -51,12 +51,6 @@ def test_initialize(self): assert result["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION assert "serverInfo" in result["result"] - def test_initialize_advertises_list_changed(self): - app = _make_app() - result = _mcp(app, "initialize") - caps = result["result"]["capabilities"] - assert caps["tools"]["listChanged"] is True - def test_tools_call(self): app = _make_app() tools = _tools_list(app) From f6403287b1c9e965087f699078e8f938aa91c5d5 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 21 Apr 2026 11:41:01 -0600 Subject: [PATCH 56/80] add app-level config for exposing callback docstrings in MCP tools --- dash/_configs.py | 1 + dash/dash.py | 4 +++ tests/unit/mcp/tools/test_mcp_tools.py | 45 ++++++++++++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/dash/_configs.py b/dash/_configs.py index 0e1ab75505..25a401523b 100644 --- a/dash/_configs.py +++ b/dash/_configs.py @@ -35,6 +35,7 @@ def load_dash_env_vars(): "DASH_COMPRESS", "DASH_MCP_ENABLED", "DASH_MCP_PATH", + "DASH_MCP_EXPOSE_DOCSTRINGS", "HOST", "PORT", ) diff --git a/dash/dash.py b/dash/dash.py index c6fa5a6bf3..3912c638cf 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -488,6 +488,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches websocket_inactivity_timeout: Optional[int] = 300000, enable_mcp: Optional[bool] = None, mcp_path: Optional[str] = None, + mcp_expose_docstrings: Optional[bool] = None, **obsolete, ): @@ -569,6 +570,9 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches hide_all_callbacks=False, csrf_token_name=csrf_token_name, csrf_header_name=csrf_header_name, + mcp_expose_docstrings=get_combined_config( + "mcp_expose_docstrings", mcp_expose_docstrings, False + ), ) self.config.set_read_only( [ diff --git a/tests/unit/mcp/tools/test_mcp_tools.py b/tests/unit/mcp/tools/test_mcp_tools.py index cacaf13b14..3255809982 100644 --- a/tests/unit/mcp/tools/test_mcp_tools.py +++ b/tests/unit/mcp/tools/test_mcp_tools.py @@ -309,6 +309,51 @@ def test_mcpt014_typed_annotation_narrows_schema(typed_app): assert tool.inputSchema["properties"]["val"]["type"] == "string" +def test_mcpt016_app_level_opt_in_exposes_docstrings(): + """Dash(mcp_expose_docstrings=True) exposes docstrings for all callbacks.""" + app = Dash(__name__, mcp_expose_docstrings=True) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """intentionally-exposed callback docstring text for the LLM""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "intentionally-exposed callback docstring text for the LLM" in tool.description + ) + + +def test_mcpt017_per_callback_false_overrides_app_level_opt_in(): + """Per-callback mcp_expose_docstring=False wins over app-level opt-in.""" + app = Dash(__name__, mcp_expose_docstrings=True) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + mcp_expose_docstring=False, + ) + def update(val): + """sensitive callback docstring text that must not leak to LLMs""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "sensitive callback docstring text that must not leak to LLMs" + not in tool.description + ) + + # --------------------------------------------------------------------------- # Tests — end-to-end Tool shape # --------------------------------------------------------------------------- From d108faf070aed5abf06e76c15219711fb696dc9a Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 21 Apr 2026 11:48:43 -0600 Subject: [PATCH 57/80] Disable MCP server by default --- dash/dash.py | 2 +- tests/integration/mcp/conftest.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 3912c638cf..c34dca722b 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -604,7 +604,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self.title = title # MCP (Model Context Protocol) configuration - self._enable_mcp = get_combined_config("mcp_enabled", enable_mcp, True) + self._enable_mcp = get_combined_config("mcp_enabled", enable_mcp, False) _mcp_path = get_combined_config("mcp_path", mcp_path, "_mcp") self._mcp_path = ( _mcp_path.lstrip("/") if isinstance(_mcp_path, str) else _mcp_path diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py index b833bc1dea..c81ffceb48 100644 --- a/tests/integration/mcp/conftest.py +++ b/tests/integration/mcp/conftest.py @@ -1,8 +1,15 @@ """Shared helpers for MCP integration tests.""" +import pytest import requests +@pytest.fixture(autouse=True) +def _enable_mcp_for_integration_tests(monkeypatch): + """MCP is off by default; integration tests need it on.""" + monkeypatch.setenv("DASH_MCP_ENABLED", "true") + + def _mcp_post(server_url, method, params=None, request_id=1): return requests.post( f"{server_url}/_mcp", From eb94e143654f56c0632cdd98d2b5debd35b63194 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 22 Apr 2026 16:04:45 -0600 Subject: [PATCH 58/80] lint --- dash/_layout_utils.py | 2 +- dash/mcp/_server.py | 36 +++++++++---------- dash/mcp/primitives/tools/callback_adapter.py | 2 +- tests/integration/mcp/conftest.py | 6 ++++ 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/dash/_layout_utils.py b/dash/_layout_utils.py index fdca86edca..d421771afd 100644 --- a/dash/_layout_utils.py +++ b/dash/_layout_utils.py @@ -117,7 +117,7 @@ def _collect_components(value: Any) -> list[Component]: if isinstance(value, Component): return [value] if isinstance(value, (list, tuple)): - return [item for item in value if isinstance(item, (Component, list, tuple))] + return [item for item in value if isinstance(item, Component)] return [] diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index c00ec6c398..07b35cd373 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -1,5 +1,9 @@ """Flask route setup, Streamable HTTP transport, and MCP message handling.""" +# pylint: disable=cyclic-import +# The MCP server imports dash primitives to dispatch callbacks, and dash +# lazy-imports this module to wire the MCP endpoint. Cycle is managed here. + from __future__ import annotations import json @@ -7,14 +11,6 @@ from typing import TYPE_CHECKING, Any from flask import Response, request - -from dash.mcp.types import MCPError - -if TYPE_CHECKING: - from dash import Dash - -from dash import get_app - from mcp.types import ( LATEST_PROTOCOL_VERSION, ErrorData, @@ -27,7 +23,8 @@ ToolsCapability, ) -from dash.version import __version__ +from dash import get_app +from dash._get_app import with_app_context_factory from dash.mcp.primitives import ( call_tool, list_resource_templates, @@ -38,6 +35,11 @@ from dash.mcp.primitives.tools.callback_adapter_collection import ( CallbackAdapterCollection, ) +from dash.mcp.types import MCPError +from dash.version import __version__ + +if TYPE_CHECKING: + from dash import Dash logger = logging.getLogger(__name__) @@ -77,9 +79,8 @@ def _handle_post() -> Response: status=415, ) - try: - data = request.get_json() - except Exception: + data = request.get_json(silent=True) + if data is None: return Response( json.dumps({"error": "Invalid JSON"}), content_type="application/json", @@ -107,8 +108,7 @@ def _handle_delete() -> Response: # -- Register routes ----------------------------------------------------- - from dash._get_app import with_app_context_factory - + # pylint: disable-next=protected-access app._add_url( mcp_path, with_app_context_factory(mcp_handler, app), ["GET", "POST", "DELETE"] ) @@ -156,12 +156,12 @@ def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: mcp_methods = { "initialize": _handle_initialize, - "tools/list": lambda: list_tools(), + "tools/list": list_tools, "tools/call": lambda: call_tool( params.get("name", ""), params.get("arguments", {}) ), - "resources/list": lambda: list_resources(), - "resources/templates/list": lambda: list_resource_templates(), + "resources/list": list_resources, + "resources/templates/list": list_resource_templates, "resources/read": lambda: read_resource(params.get("uri", "")), } @@ -188,7 +188,7 @@ def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: id=request_id, error=ErrorData(code=e.code, message=str(e)), ).model_dump(exclude_none=True) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught logger.error("MCP error: %s", e, exc_info=True) return JSONRPCError( jsonrpc="2.0", diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index 3000541a06..8130c6a8a1 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -375,7 +375,7 @@ def _expand_dep(dep: CallbackDependency, value: Any) -> CallbackInputs: # LLM provides browser-like format if isinstance(value, list): - return cast(list[CallbackInput], value) + return cast("list[CallbackInput]", value) if isinstance(value, dict) and "id" in value: return cast(CallbackInput, value) return CallbackInput(id=dep["id"], property=dep["property"], value=value) diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py index c81ffceb48..5030211ed8 100644 --- a/tests/integration/mcp/conftest.py +++ b/tests/integration/mcp/conftest.py @@ -1,8 +1,14 @@ """Shared helpers for MCP integration tests.""" +import sys + import pytest import requests +collect_ignore_glob = [] +if sys.version_info < (3, 10): + collect_ignore_glob.append("*") + @pytest.fixture(autouse=True) def _enable_mcp_for_integration_tests(monkeypatch): From fdaf822496048c14af4a92c72ced7cf357d92152 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 23 Apr 2026 10:51:46 -0600 Subject: [PATCH 59/80] fix leaky state in tests --- tests/integration/mcp/conftest.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py index 5030211ed8..0db1b775e8 100644 --- a/tests/integration/mcp/conftest.py +++ b/tests/integration/mcp/conftest.py @@ -5,6 +5,8 @@ import pytest import requests +from dash import _get_app + collect_ignore_glob = [] if sys.version_info < (3, 10): collect_ignore_glob.append("*") @@ -16,6 +18,17 @@ def _enable_mcp_for_integration_tests(monkeypatch): monkeypatch.setenv("DASH_MCP_ENABLED", "true") +@pytest.fixture(autouse=True) +def _reset_dash_app_state(): + """Reset Dash module-level state after each MCP test. + + TODO: this can be removed when 4.2 backend work lands + """ + yield + _get_app.APP = None + _get_app.app_context.set(None) + + def _mcp_post(server_url, method, params=None, request_id=1): return requests.post( f"{server_url}/_mcp", From bfc61f23e832798699e3b35d3f7c3195c68d94f6 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 23 Apr 2026 14:49:31 -0600 Subject: [PATCH 60/80] Refactor unit tests to conform to existing test patterns --- .../tools/test_duplicate_outputs.py | 128 ----- .../primitives/tools/test_input_schemas.py | 66 --- .../tools/test_tool_get_dash_component.py | 54 -- .../mcp/primitives/tools/test_tools_list.py | 118 ---- ...tures.py => test_mcp_callback_behavior.py} | 508 ++++++++++++++---- tests/integration/mcp/test_mcp_endpoint.py | 189 +++++++ ...est_resources.py => test_mcp_resources.py} | 6 +- tests/integration/mcp/test_server.py | 183 ------- tests/unit/mcp/test_mcp_server.py | 99 ++++ tests/unit/mcp/test_server.py | 86 --- tests/unit/mcp/tools/test_mcp_run_callback.py | 253 +++++++++ tests/unit/mcp/tools/test_run_callback.py | 246 --------- 12 files changed, 947 insertions(+), 989 deletions(-) delete mode 100644 tests/integration/mcp/primitives/tools/test_duplicate_outputs.py delete mode 100644 tests/integration/mcp/primitives/tools/test_input_schemas.py delete mode 100644 tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py delete mode 100644 tests/integration/mcp/primitives/tools/test_tools_list.py rename tests/integration/mcp/{primitives/tools/test_callback_signatures.py => test_mcp_callback_behavior.py} (66%) create mode 100644 tests/integration/mcp/test_mcp_endpoint.py rename tests/integration/mcp/{primitives/resources/test_resources.py => test_mcp_resources.py} (86%) delete mode 100644 tests/integration/mcp/test_server.py create mode 100644 tests/unit/mcp/test_mcp_server.py delete mode 100644 tests/unit/mcp/test_server.py create mode 100644 tests/unit/mcp/tools/test_mcp_run_callback.py delete mode 100644 tests/unit/mcp/tools/test_run_callback.py diff --git a/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py b/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py deleted file mode 100644 index 4ad00641f8..0000000000 --- a/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Integration test for duplicate callback outputs. - -Multiple callbacks can output to the same component.property -when using ``allow_duplicate=True``. The MCP server must handle -this correctly — both callbacks should appear as tools, and -calling either should work. -""" - -from dash import Dash, Input, Output, dcc, html - -from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools - - -def _find_tool(tools, name): - return next((t for t in tools if t["name"] == name), None) - - -def _get_response(result): - return result["result"]["structuredContent"]["response"] - - -def test_duplicate_outputs_both_tools_listed(dash_duo): - """Both callbacks outputting to the same component appear as tools.""" - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="first-name", value="Jane"), - dcc.Input(id="last-name", value="Doe"), - html.Div(id="greeting"), - ] - ) - - @app.callback( - Output("greeting", "children"), - Input("first-name", "value"), - ) - def greet_by_first(first): - return f"Hello, {first}!" - - @app.callback( - Output("greeting", "children", allow_duplicate=True), - Input("last-name", "value"), - prevent_initial_call=True, - ) - def greet_by_last(last): - return f"Hi, {last}!" - - dash_duo.start_server(app) - tools = _mcp_tools(dash_duo.server.url) - - first_tool = _find_tool(tools, "greet_by_first") - last_tool = _find_tool(tools, "greet_by_last") - - assert first_tool is not None, "greet_by_first should be listed" - assert last_tool is not None, "greet_by_last should be listed" - - -def test_duplicate_outputs_both_callable(dash_duo): - """Both callbacks can be called and produce correct results.""" - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="first-name", value="Jane"), - dcc.Input(id="last-name", value="Doe"), - html.Div(id="greeting"), - ] - ) - - @app.callback( - Output("greeting", "children"), - Input("first-name", "value"), - ) - def greet_by_first(first): - return f"Hello, {first}!" - - @app.callback( - Output("greeting", "children", allow_duplicate=True), - Input("last-name", "value"), - prevent_initial_call=True, - ) - def greet_by_last(last): - return f"Hi, {last}!" - - dash_duo.start_server(app) - - result1 = _mcp_call_tool(dash_duo.server.url, "greet_by_first", {"first": "Alice"}) - assert _get_response(result1)["greeting"]["children"] == "Hello, Alice!" - - result2 = _mcp_call_tool(dash_duo.server.url, "greet_by_last", {"last": "Smith"}) - assert _get_response(result2)["greeting"]["children"] == "Hi, Smith!" - - -def test_duplicate_outputs_find_by_output_returns_primary(dash_duo): - """find_by_output returns the primary (non-duplicate) callback.""" - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="first-name", value="Jane"), - dcc.Input(id="last-name", value="Doe"), - html.Div(id="greeting"), - ] - ) - - @app.callback( - Output("greeting", "children"), - Input("first-name", "value"), - ) - def greet_by_first(first): - return f"Hello, {first}!" - - @app.callback( - Output("greeting", "children", allow_duplicate=True), - Input("last-name", "value"), - prevent_initial_call=True, - ) - def greet_by_last(last): - return f"Hi, {last}!" - - dash_duo.start_server(app) - - # Query the component — should reflect initial callback (greet_by_first) - result = _mcp_call_tool( - dash_duo.server.url, - "get_dash_component", - {"component_id": "greeting", "property": "children"}, - ) - structured = result["result"]["structuredContent"] - assert structured["properties"]["children"]["initial_value"] == "Hello, Jane!" diff --git a/tests/integration/mcp/primitives/tools/test_input_schemas.py b/tests/integration/mcp/primitives/tools/test_input_schemas.py deleted file mode 100644 index 6ee3510ddd..0000000000 --- a/tests/integration/mcp/primitives/tools/test_input_schemas.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Integration tests for MCP tool schema generation. - -Starts a real Dash server via ``dash_duo`` and verifies that tools -are generated with correct inputSchema, descriptions, and labels. -""" - -from dash import Dash, Input, Output, dcc, html - -from tests.integration.mcp.conftest import _mcp_tools - - -def test_mcp_tool_with_label_and_date_picker_schema(dash_duo): - """Full assertion on a tool with an html.Label and DatePickerSingle constraints.""" - - # -- Test data: change these to update the test -- - label_text = "Departure Date" - component_id = "dp" - min_date = "2020-01-01" - max_date = "2025-12-31" - default_date = "2024-06-15" - func_name = "select_date" - param_name = "date" # function parameter name - - app = Dash(__name__) - app.layout = html.Div( - [ - html.Label(label_text, htmlFor=component_id), - dcc.DatePickerSingle( - id=component_id, - min_date_allowed=min_date, - max_date_allowed=max_date, - date=default_date, - ), - html.Div(id="out"), - ] - ) - - @app.callback(Output("out", "children"), Input(component_id, "date")) - def select_date(date): - return f"Selected: {date}" - - dash_duo.start_server(app) - tools = _mcp_tools(dash_duo.server.url) - - # Find the callback tool - tool = next(t for t in tools if t["name"] not in ("get_dash_component",)) - - # -- Tool-level fields -- - assert func_name in tool["name"] - - # -- inputSchema structure -- - schema = tool["inputSchema"] - assert schema["type"] == "object" - assert param_name in schema["required"] - assert param_name in schema["properties"] - - # -- Property schema: type + format + description -- - prop = schema["properties"][param_name] - assert prop["type"] == "string" - assert prop["format"] == "date" - - # description includes all source values (label, constraints, default) - desc = prop["description"] - for expected in (label_text, min_date, max_date, default_date): - assert expected in desc, f"Expected {expected!r} in description: {desc!r}" diff --git a/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py b/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py deleted file mode 100644 index 97472a16d7..0000000000 --- a/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Integration tests for the get_dash_component tool.""" - -from dash import Dash, dcc, html - -from tests.integration.mcp.conftest import _mcp_call_tool - -EXPECTED_DROPDOWN_OPTIONS = { - "component_id": "my-dropdown", - "component_type": "Dropdown", - "label": None, - "properties": { - "options": { - "initial_value": [ - {"label": "New York", "value": "NYC"}, - {"label": "Montreal", "value": "MTL"}, - ], - "modified_by_tool": [], - "input_to_tool": [], - }, - }, -} - - -def test_query_component_returns_structured_output(dash_duo): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown( - id="my-dropdown", - options=[ - {"label": "New York", "value": "NYC"}, - {"label": "Montreal", "value": "MTL"}, - ], - value="NYC", - ), - ] - ) - - dash_duo.start_server(app) - - result = _mcp_call_tool( - dash_duo.server.url, - "get_dash_component", - {"component_id": "my-dropdown", "property": "options"}, - ) - - assert "result" in result, f"Expected result in response: {result}" - structured = result["result"]["structuredContent"] - assert structured["component_id"] == EXPECTED_DROPDOWN_OPTIONS["component_id"] - assert structured["component_type"] == EXPECTED_DROPDOWN_OPTIONS["component_type"] - assert ( - structured["properties"]["options"] - == EXPECTED_DROPDOWN_OPTIONS["properties"]["options"] - ) diff --git a/tests/integration/mcp/primitives/tools/test_tools_list.py b/tests/integration/mcp/primitives/tools/test_tools_list.py deleted file mode 100644 index dc3d977146..0000000000 --- a/tests/integration/mcp/primitives/tools/test_tools_list.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Integration tests for tools/list — naming, dedup, and spec compliance.""" - -from dash import Dash, Input, Output, dcc, html - -from tests.integration.mcp.conftest import _mcp_tools - - -def test_tool_names_within_64_chars(dash_duo): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="dd", options=["a"], value="a"), - html.Div(id="out"), - ] - ) - - @app.callback(Output("out", "children"), Input("dd", "value")) - def update(val): - return val - - dash_duo.start_server(app) - for tool in _mcp_tools(dash_duo.server.url): - assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" - for param_name in tool.get("inputSchema", {}).get("properties", {}): - assert len(param_name) <= 64, f"Param name exceeds 64 chars: {param_name}" - - -def test_long_callback_ids_within_64_chars(dash_duo): - app = Dash(__name__) - long_id = "a" * 120 - app.layout = html.Div( - [ - dcc.Input(id=long_id, value="test"), - html.Div(id=f"{long_id}-output"), - ] - ) - - @app.callback(Output(f"{long_id}-output", "children"), Input(long_id, "value")) - def process(val): - return val - - dash_duo.start_server(app) - for tool in _mcp_tools(dash_duo.server.url): - assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" - - -def test_pattern_matching_ids_within_64_chars(dash_duo): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Div( - [ - dcc.Input( - id={"type": "filter-input", "index": i, "category": "primary"}, - value=f"val-{i}", - ) - for i in range(3) - ] - ), - html.Div(id="pm-output"), - ] - ) - - @app.callback( - Output("pm-output", "children"), - Input({"type": "filter-input", "index": 0, "category": "primary"}, "value"), - ) - def filter_update(v0): - return str(v0) - - dash_duo.start_server(app) - for tool in _mcp_tools(dash_duo.server.url): - assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" - - -def test_duplicate_func_names_produce_unique_tools(dash_duo): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="dd1", options=["a"], value="a"), - html.Div(id="dd1-output"), - dcc.Dropdown(id="dd2", options=["b"], value="b"), - html.Div(id="dd2-output"), - dcc.Dropdown(id="dd3", options=["c"], value="c"), - html.Div(id="dd3-output"), - ] - ) - - @app.callback(Output("dd1-output", "children"), Input("dd1", "value")) - def cb(value): - return f"first: {value}" - - @app.callback(Output("dd2-output", "children"), Input("dd2", "value")) - def cb(value): # noqa: F811 - return f"second: {value}" - - @app.callback(Output("dd3-output", "children"), Input("dd3", "value")) - def cb(value): # noqa: F811 - return f"third: {value}" - - dash_duo.start_server(app) - tools = _mcp_tools(dash_duo.server.url) - cb_tools = [t for t in tools if t["name"] not in ("get_dash_component",)] - tool_names = [t["name"] for t in cb_tools] - - assert ( - len(tool_names) == 3 - ), f"Expected 3 callback tools, got {len(tool_names)}: {tool_names}" - assert len(set(tool_names)) == 3, f"Tool names not unique: {tool_names}" - - -def test_builtin_tools_always_present(dash_duo): - app = Dash(__name__) - app.layout = html.Div(id="root") - - dash_duo.start_server(app) - tool_names = [t["name"] for t in _mcp_tools(dash_duo.server.url)] - assert "get_dash_component" in tool_names diff --git a/tests/integration/mcp/primitives/tools/test_callback_signatures.py b/tests/integration/mcp/test_mcp_callback_behavior.py similarity index 66% rename from tests/integration/mcp/primitives/tools/test_callback_signatures.py rename to tests/integration/mcp/test_mcp_callback_behavior.py index db325f2046..7778111386 100644 --- a/tests/integration/mcp/primitives/tools/test_callback_signatures.py +++ b/tests/integration/mcp/test_mcp_callback_behavior.py @@ -1,29 +1,56 @@ +"""Callback behaviors surfaced through MCP tools (end-to-end). + +Covers the full pipeline — a real Dash server via ``dash_duo`` + the MCP +HTTP endpoint — for every callback signature variant and the surrounding +tool-list conventions: + +- Positional, dict-based, and tuple-grouped ``inputs`` / ``state`` / + ``output`` forms. +- ``State``, multi-output, ``PreventUpdate``-style no-output callbacks, + ``ctx.triggered_id``, pattern-matching (``ALL``/``MATCH``/``ALLSMALLER``). +- Initial values: ``prevent_initial_call`` vs. initial-callback overrides. +- Duplicate outputs (``allow_duplicate=True``) appearing as separate tools. +- ``tools/list`` naming rules (64-char limit, uniqueness, built-ins). +- A representative input-schema smoke test (label + DatePicker). +- ``get_dash_component`` structured output via HTTP. """ -Integration tests for all Dash callback signature types. -Each test verifies that: -1. The MCP tool schema accurately reflects the callback's parameters -2. Calling the tool with those parameters produces the expected result +from dash import ( + ALL, + ALLSMALLER, + MATCH, + Dash, + Input, + Output, + State, + ctx, + dcc, + html, + set_props, +) -Assertions are derived from the callback definition, not the implementation. - -See: https://dash.plotly.com/flexible-callback-signatures -""" +from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools -from dash import Dash, Input, Output, State, dcc, html -from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- def _find_tool(tools, name): - return next(t for t in tools if t["name"] == name) + return next((t for t in tools if t["name"] == name), None) def _get_response(result): return result["result"]["structuredContent"]["response"] -def test_positional_callback(dash_duo): +# --------------------------------------------------------------------------- +# Callback signatures — positional, multi-output, State, dict-based, tuples +# --------------------------------------------------------------------------- + + +def test_mcpb001_positional_callback(dash_duo): """Standard positional: Input("fruit", "value") → param named 'value'.""" app = Dash(__name__) app.layout = html.Div( @@ -33,8 +60,6 @@ def test_positional_callback(dash_duo): ] ) - # Callback: 1 Input → 1 param named "value" (from function signature) - # Returns string → Output("out", "children") @app.callback(Output("out", "children"), Input("fruit", "value")) def show_fruit(value): return f"Selected: {value}" @@ -48,23 +73,20 @@ def show_fruit(value): assert set(props.keys()) == {"value"} assert any(s.get("type") == "string" for s in props["value"]["anyOf"]) - # Tool description reflects initial state value_desc = props["value"].get("description", "") assert "value: 'apple'" in value_desc assert "options: ['apple', 'banana']" in value_desc - # MCP tool with initial inputs matches browser result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "apple"}) response = _get_response(result) assert response["out"]["children"] == "Selected: apple" - # MCP tool with different inputs result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "banana"}) response = _get_response(result) assert response["out"]["children"] == "Selected: banana" -def test_positional_with_state(dash_duo): +def test_mcpb002_positional_with_state(dash_duo): """Positional with State: Input + State both appear as params.""" app = Dash(__name__) app.layout = html.Div( @@ -75,7 +97,6 @@ def test_positional_with_state(dash_duo): ] ) - # Callback: 1 Input + 1 State → 2 params named "n_clicks" and "value" @app.callback( Output("out", "children"), Input("btn", "n_clicks"), @@ -93,10 +114,8 @@ def update(n_clicks, value): assert set(props.keys()) == {"n_clicks", "value"} assert any(s.get("type") == "number" for s in props["n_clicks"]["anyOf"]) - # Tool description reflects initial state assert "value: 'hello'" in props["value"].get("description", "") - # MCP tool with initial inputs matches browser result = _mcp_call_tool( dash_duo.server.url, "update", {"n_clicks": None, "value": "hello"} ) @@ -110,7 +129,7 @@ def update(n_clicks, value): assert response["out"]["children"] == "Clicked 3 with world" -def test_multi_output_positional(dash_duo): +def test_mcpb003_multi_output_positional(dash_duo): """Multi-output: returns tuple → both outputs updated in response.""" app = Dash(__name__) app.layout = html.Div( @@ -121,7 +140,6 @@ def test_multi_output_positional(dash_duo): ] ) - # Callback: 1 Input → 2 Outputs via tuple return @app.callback( Output("out1", "children"), Output("out2", "children"), @@ -138,18 +156,16 @@ def split_case(value): props = tool["inputSchema"]["properties"] assert set(props.keys()) == {"value"} - # Tool description reflects initial state assert "value: 'test'" in props["value"].get("description", "") - # MCP tool with initial inputs matches browser result = _mcp_call_tool(dash_duo.server.url, "split_case", {"value": "test"}) response = _get_response(result) assert response["out1"]["children"] == "TEST" assert response["out2"]["children"] == "test" -def test_dict_based_inputs_and_state(dash_duo): - """Dict-based: inputs=dict(trigger=...), state=dict(name=...) → dict keys are param names.""" +def test_mcpb004_dict_based_inputs_and_state(dash_duo): + """Dict-based: inputs=dict(trigger=...), state=dict(name=...) → dict keys are params.""" app = Dash(__name__) app.layout = html.Div( [ @@ -159,7 +175,6 @@ def test_dict_based_inputs_and_state(dash_duo): ] ) - # Callback: dict keys "trigger" and "name" become param names @app.callback( Output("out", "children"), inputs=dict(trigger=Input("btn", "n_clicks")), @@ -177,7 +192,6 @@ def greet(trigger, name): assert set(props.keys()) == {"trigger", "name"} assert any(s.get("type") == "number" for s in props["trigger"]["anyOf"]) - # MCP tool with initial inputs matches browser result = _mcp_call_tool( dash_duo.server.url, "greet", {"trigger": None, "name": "world"} ) @@ -191,7 +205,7 @@ def greet(trigger, name): assert response["out"]["children"] == "Hello, Dash!" -def test_dict_based_outputs(dash_duo): +def test_mcpb005_dict_based_outputs(dash_duo): """Dict-based outputs: output=dict(...) → callback returns dict, both outputs updated.""" app = Dash(__name__) app.layout = html.Div( @@ -202,7 +216,6 @@ def test_dict_based_outputs(dash_duo): ] ) - # Callback: dict output keys "upper" and "lower" map to components @app.callback( output=dict( upper=Output("upper-out", "children"), @@ -221,14 +234,13 @@ def transform(val): props = tool["inputSchema"]["properties"] assert set(props.keys()) == {"val"} - # MCP tool with initial inputs matches browser result = _mcp_call_tool(dash_duo.server.url, "transform", {"val": "hello"}) response = _get_response(result) assert response["upper-out"]["children"] == "HELLO" assert response["lower-out"]["children"] == "hello" -def test_mixed_input_state_in_inputs(dash_duo): +def test_mcpb006_mixed_input_state_in_inputs(dash_duo): """Mixed: State inside inputs=dict alongside Input → all appear as params.""" app = Dash(__name__) app.layout = html.Div( @@ -240,7 +252,6 @@ def test_mixed_input_state_in_inputs(dash_duo): ] ) - # Callback: Input and State mixed in same dict → all keys are params @app.callback( Output("out", "children"), inputs=dict( @@ -261,7 +272,6 @@ def full_name(clicks, first, last): assert set(props.keys()) == {"clicks", "first", "last"} assert any(s.get("type") == "number" for s in props["clicks"]["anyOf"]) - # MCP tool with initial inputs matches browser result = _mcp_call_tool( dash_duo.server.url, "full_name", @@ -279,7 +289,7 @@ def full_name(clicks, first, last): assert response["out"]["children"] == "John Smith" -def test_tuple_grouped_inputs(dash_duo): +def test_mcpb007_tuple_grouped_inputs(dash_duo): """Tuple grouping: pair=(Input("a",...), Input("b",...)) → expands to two named params.""" app = Dash(__name__) app.layout = html.Div( @@ -290,7 +300,6 @@ def test_tuple_grouped_inputs(dash_duo): ] ) - # Callback: tuple group "pair" maps to 2 deps → 2 params named pair___ @app.callback( Output("out", "children"), inputs=dict(pair=(Input("a", "value"), Input("b", "value"))), @@ -302,7 +311,6 @@ def combine(pair): tool = _find_tool(_mcp_tools(dash_duo.server.url), "combine") props = tool["inputSchema"]["properties"] - # Tuple expands: one param per dep, named with group prefix + component info assert set(props.keys()) == {"pair_a__value", "pair_b__value"} for schema in props.values(): assert any(s.get("type") == "string" for s in schema["anyOf"]) @@ -316,7 +324,7 @@ def combine(pair): assert response["out"]["children"] == "x+y" -def test_initial_values_from_chained_callbacks(dash_duo): +def test_mcpb008_initial_values_from_chained_callbacks(dash_duo): """Querying components reflects post-initial-callback values. 3-link chain: country (default "France") → update_states → @@ -368,16 +376,13 @@ def update_cities(state, country): dash_duo.start_server(app) - # Tool descriptions should reflect post-initial-callback state tools = _mcp_tools(dash_duo.server.url) update_cities_tool = _find_tool(tools, "update_cities") state_desc = update_cities_tool["inputSchema"]["properties"]["state"].get( "description", "" ) - # state.value was set to "Ile-de-France" by update_states initial callback assert "Ile-de-France" in state_desc - # state.value should be "Ile-de-France" (first state for France) result = _mcp_call_tool( dash_duo.server.url, "get_dash_component", @@ -386,7 +391,6 @@ def update_cities(state, country): state_props = result["result"]["structuredContent"]["properties"] assert state_props["value"]["initial_value"] == "Ile-de-France" - # city.value should be "Paris" (first city for Ile-de-France) result = _mcp_call_tool( dash_duo.server.url, "get_dash_component", @@ -396,7 +400,7 @@ def update_cities(state, country): assert city_props["value"]["initial_value"] == "Paris" -def test_dict_based_reordered_state_input(dash_duo): +def test_mcpb009_dict_based_reordered_state_input(dash_duo): """Dict-based callback with State before Input: call works, schema types correct. State is listed before Input in the dict. The callback should still @@ -421,7 +425,6 @@ def greet(name: str, trigger: int): dash_duo.start_server(app) - # First: verify the callback actually works with these args result = _mcp_call_tool( dash_duo.server.url, "greet", @@ -429,20 +432,23 @@ def greet(name: str, trigger: int): ) assert _get_response(result)["out"]["children"] == "Hello Dash" - # Second: verify schema types match annotations tool = _find_tool(_mcp_tools(dash_duo.server.url), "greet") props = tool["inputSchema"]["properties"] assert props["trigger"]["type"] == "integer" assert props["name"]["type"] == "string" - # Third: verify each param describes the correct component trigger_desc = props["trigger"].get("description", "") assert "number of times that this element has been clicked on" in trigger_desc name_desc = props["name"].get("description", "") assert "The value of the input" in name_desc -def test_pattern_matching_callback(dash_duo): +# --------------------------------------------------------------------------- +# Pattern-matching callbacks (ALL / MATCH / ALLSMALLER) +# --------------------------------------------------------------------------- + + +def test_mcpb010_pattern_matching_callback(dash_duo): """Pattern-matching dict IDs: tool works with correct params and results.""" app = Dash(__name__) app.layout = html.Div( @@ -469,7 +475,6 @@ def combine(first, second): assert "first" in props assert "second" in props - # Verify initial output matches what the browser shows dash_duo.wait_for_text_to_equal("#out", "hello world") result = _mcp_call_tool( dash_duo.server.url, @@ -479,7 +484,6 @@ def combine(first, second): response = _get_response(result) assert response["out"]["children"] == "hello world" - # Verify with different values result = _mcp_call_tool( dash_duo.server.url, "combine", @@ -489,10 +493,8 @@ def combine(first, second): assert response["out"]["children"] == "foo bar" -def test_pattern_matching_with_all_wildcard(dash_duo): +def test_mcpb011_pattern_matching_with_all_wildcard(dash_duo): """ALL wildcard: one callback receives values from all matching components.""" - from dash import ALL - app = Dash(__name__) app.layout = html.Div( [ @@ -515,7 +517,6 @@ def summarize(values): tool = _find_tool(_mcp_tools(dash_duo.server.url), "summarize") assert tool is not None - # Schema must describe values as an array of {id, property, value} objects values_schema = tool["inputSchema"]["properties"]["values"] assert ( values_schema["type"] == "array" @@ -529,7 +530,6 @@ def summarize(values): "description", "" ), "ALL wildcard param description should explain the pattern-matching behavior" - # MCP tool call with browser-like format: concrete IDs + values result = _mcp_call_tool( dash_duo.server.url, "summarize", @@ -551,7 +551,6 @@ def summarize(values): response = _get_response(result) assert response["summary"]["children"] == "alpha, beta" - # Different values result = _mcp_call_tool( dash_duo.server.url, "summarize", @@ -574,10 +573,8 @@ def summarize(values): assert response["summary"]["children"] == "one, two" -def test_pattern_matching_mixed_outputs(dash_duo): +def test_mcpb012_pattern_matching_mixed_outputs(dash_duo): """Mixed outputs: one regular + one ALL wildcard in the same callback.""" - from dash import ALL - app = Dash(__name__) app.layout = html.Div( [ @@ -624,13 +621,11 @@ def echo_and_total(values): assert response["total"]["children"] == "Total: 2 items" -def test_pattern_matching_with_match_wildcard(dash_duo): +def test_mcpb013_pattern_matching_with_match_wildcard(dash_duo): """MATCH wildcard: callback fires per-component with matching index. Based on https://dash.plotly.com/pattern-matching-callbacks """ - from dash import MATCH - app = Dash(__name__) app.layout = html.Div( [ @@ -661,11 +656,9 @@ def show_city(value): tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") assert tool is not None - # Schema describes MATCH input value_schema = tool["inputSchema"]["properties"]["value"] assert "Pattern-matching input (MATCH)" in value_schema.get("description", "") - # Call with concrete ID for index 0 (MATCH takes a single entry, not an array) result = _mcp_call_tool( dash_duo.server.url, "show_city", @@ -678,18 +671,15 @@ def show_city(value): }, ) response = _get_response(result) - # Find the output key containing "city-out" (Dash may serialize dict IDs differently) out_key = next(k for k in response if "city-out" in k) assert response[out_key]["children"] == "Selected: MTL" -def test_pattern_matching_with_allsmaller_wildcard(dash_duo): +def test_mcpb014_pattern_matching_with_allsmaller_wildcard(dash_duo): """ALLSMALLER wildcard: receives values from components with smaller index. Based on https://dash.plotly.com/pattern-matching-callbacks """ - from dash import MATCH, ALLSMALLER - app = Dash(__name__) app.layout = html.Div( [ @@ -728,14 +718,12 @@ def show_countries(current, previous): tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_countries") assert tool is not None - # Schema describes both MATCH and ALLSMALLER inputs props = tool["inputSchema"]["properties"] assert "Pattern-matching input (MATCH)" in props["current"].get("description", "") assert "Pattern-matching input (ALLSMALLER)" in props["previous"].get( "description", "" ) - # Call for index 2: MATCH is a single dict, ALLSMALLER is a list result = _mcp_call_tool( dash_duo.server.url, "show_countries", @@ -764,13 +752,13 @@ def show_countries(current, previous): assert response[out_key]["children"] == "All: Japan, Germany, France" -def test_prevent_initial_call_uses_layout_default(dash_duo): - """prevent_initial_call=True: initial value stays as the layout default. +# --------------------------------------------------------------------------- +# Initial values: prevent_initial_call vs. initial-callback overrides +# --------------------------------------------------------------------------- - The dropdown has value="original" in the layout. The callback has - prevent_initial_call=True so it doesn't run on page load. The MCP - tool description should show value: 'a' (layout default). - """ + +def test_mcpb015_prevent_initial_call_uses_layout_default(dash_duo): + """prevent_initial_call=True: initial value stays as the layout default.""" app = Dash(__name__) app.layout = html.Div( [ @@ -788,24 +776,16 @@ def update(val): return f"Changed to: {val}" dash_duo.start_server(app) - # Browser shows layout default — callback hasn't fired dash_duo.wait_for_text_to_equal("#out", "not yet") tool = _find_tool(_mcp_tools(dash_duo.server.url), "update") val_desc = tool["inputSchema"]["properties"]["val"].get("description", "") - # Tool description reflects layout default, not callback output assert "value: 'a'" in val_desc -def test_initial_callback_overrides_layout_value(dash_duo): - """Initial callback overrides layout value in tool description. - - The city dropdown has value="default-city" in the layout. - update_city runs on page load (no prevent_initial_call) and - sets city.value to "Paris". The MCP tool should show "Paris" - as the default, not "default-city". - """ +def test_mcpb016_initial_callback_overrides_layout_value(dash_duo): + """Initial callback overrides layout value in tool description.""" app = Dash(__name__) app.layout = html.Div( [ @@ -830,24 +810,20 @@ def show_city(city): return f"City: {city}" dash_duo.start_server(app) - # Browser shows "Paris" — the initial callback overrode "default-city" dash_duo.wait_for_text_to_equal("#out", "City: Paris") tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") city_desc = tool["inputSchema"]["properties"]["city"].get("description", "") - # Tool description should show the post-initial-callback value assert "value: 'Paris'" in city_desc assert "default-city" not in city_desc -def test_callback_context_triggered_id(dash_duo): +def test_mcpb017_callback_context_triggered_id(dash_duo): """Callbacks using dash.ctx.triggered_id work via MCP. Based on https://dash.plotly.com/determining-which-callback-input-changed """ - from dash import ctx - app = Dash(__name__) app.layout = html.Div( [ @@ -871,17 +847,14 @@ def display(btn1, btn2, btn3): dash_duo.start_server(app) - # Browser initial state: no button clicked dash_duo.wait_for_text_to_equal("#output", "No button clicked yet") - # Tool should have all three button params tool = _find_tool(_mcp_tools(dash_duo.server.url), "display") props = tool["inputSchema"]["properties"] assert "btn1" in props assert "btn2" in props assert "btn3" in props - # Click btn-2 via MCP — ctx.triggered_id should be "btn-2" result = _mcp_call_tool( dash_duo.server.url, "display", @@ -890,7 +863,6 @@ def display(btn1, btn2, btn3): response = _get_response(result) assert response["output"]["children"] == "Last clicked: btn-2" - # Click btn-3 via MCP result = _mcp_call_tool( dash_duo.server.url, "display", @@ -900,14 +872,12 @@ def display(btn1, btn2, btn3): assert response["output"]["children"] == "Last clicked: btn-3" -def test_no_output_callback_does_not_crash_tools_list(dash_duo): +def test_mcpb018_no_output_callback_does_not_crash_tools_list(dash_duo): """A callback with no Output should not crash tools/list. No-output callbacks use set_props for side effects. They produce a hash-only output_id with no dot separator. """ - from dash import set_props - app = Dash(__name__) app.layout = html.Div( [ @@ -930,13 +900,9 @@ def show_selection(val): tools = _mcp_tools(dash_duo.server.url) tool_names = [t["name"] for t in tools] - # show_selection should appear as a tool assert "show_selection" in tool_names - - # log_click has no declared output but uses set_props — still a valid tool assert "log_click" in tool_names - # Call log_click — sideUpdate should show the set_props effect result = _mcp_call_tool( dash_duo.server.url, "log_click", @@ -946,9 +912,6 @@ def show_selection(val): assert "sideUpdate" in structured assert structured["sideUpdate"]["display"]["children"] == "Logged 3 clicks" - # get_dash_component shows show_selection as modifier (declared output). - # log_click uses set_props which bypasses the declarative graph — - # its effect is only visible via sideUpdate in tool call results. result = _mcp_call_tool( dash_duo.server.url, "get_dash_component", @@ -956,3 +919,338 @@ def show_selection(val): ) prop_info = result["result"]["structuredContent"]["properties"]["children"] assert "show_selection" in prop_info["modified_by_tool"] + + +# --------------------------------------------------------------------------- +# Duplicate outputs (allow_duplicate=True) +# --------------------------------------------------------------------------- + + +def test_mcpb019_duplicate_outputs_both_tools_listed(dash_duo): + """Both callbacks outputting to the same component appear as tools.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + first_tool = _find_tool(tools, "greet_by_first") + last_tool = _find_tool(tools, "greet_by_last") + + assert first_tool is not None, "greet_by_first should be listed" + assert last_tool is not None, "greet_by_last should be listed" + + +def test_mcpb020_duplicate_outputs_both_callable(dash_duo): + """Both callbacks can be called and produce correct results.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + result1 = _mcp_call_tool(dash_duo.server.url, "greet_by_first", {"first": "Alice"}) + assert _get_response(result1)["greeting"]["children"] == "Hello, Alice!" + + result2 = _mcp_call_tool(dash_duo.server.url, "greet_by_last", {"last": "Smith"}) + assert _get_response(result2)["greeting"]["children"] == "Hi, Smith!" + + +def test_mcpb021_duplicate_outputs_find_by_output_returns_primary(dash_duo): + """find_by_output returns the primary (non-duplicate) callback.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "greeting", "property": "children"}, + ) + structured = result["result"]["structuredContent"] + assert structured["properties"]["children"]["initial_value"] == "Hello, Jane!" + + +# --------------------------------------------------------------------------- +# tools/list — naming rules (64-char limit, uniqueness, built-ins) +# --------------------------------------------------------------------------- + + +def test_mcpb022_tool_names_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + for param_name in tool.get("inputSchema", {}).get("properties", {}): + assert len(param_name) <= 64, f"Param name exceeds 64 chars: {param_name}" + + +def test_mcpb023_long_callback_ids_within_64_chars(dash_duo): + app = Dash(__name__) + long_id = "a" * 120 + app.layout = html.Div( + [ + dcc.Input(id=long_id, value="test"), + html.Div(id=f"{long_id}-output"), + ] + ) + + @app.callback(Output(f"{long_id}-output", "children"), Input(long_id, "value")) + def process(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_mcpb024_pattern_matching_ids_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div( + [ + dcc.Input( + id={"type": "filter-input", "index": i, "category": "primary"}, + value=f"val-{i}", + ) + for i in range(3) + ] + ), + html.Div(id="pm-output"), + ] + ) + + @app.callback( + Output("pm-output", "children"), + Input({"type": "filter-input", "index": 0, "category": "primary"}, "value"), + ) + def filter_update(v0): + return str(v0) + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_mcpb025_duplicate_func_names_produce_unique_tools(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd1", options=["a"], value="a"), + html.Div(id="dd1-output"), + dcc.Dropdown(id="dd2", options=["b"], value="b"), + html.Div(id="dd2-output"), + dcc.Dropdown(id="dd3", options=["c"], value="c"), + html.Div(id="dd3-output"), + ] + ) + + @app.callback(Output("dd1-output", "children"), Input("dd1", "value")) + def cb(value): + return f"first: {value}" + + @app.callback(Output("dd2-output", "children"), Input("dd2", "value")) + def cb(value): # noqa: F811 + return f"second: {value}" + + @app.callback(Output("dd3-output", "children"), Input("dd3", "value")) + def cb(value): # noqa: F811 + return f"third: {value}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + cb_tools = [t for t in tools if t["name"] not in ("get_dash_component",)] + tool_names = [t["name"] for t in cb_tools] + + assert ( + len(tool_names) == 3 + ), f"Expected 3 callback tools, got {len(tool_names)}: {tool_names}" + assert len(set(tool_names)) == 3, f"Tool names not unique: {tool_names}" + + +def test_mcpb026_builtin_tools_always_present(dash_duo): + app = Dash(__name__) + app.layout = html.Div(id="root") + + dash_duo.start_server(app) + tool_names = [t["name"] for t in _mcp_tools(dash_duo.server.url)] + assert "get_dash_component" in tool_names + + +# --------------------------------------------------------------------------- +# Input schema smoke test + get_dash_component HTTP structured output +# --------------------------------------------------------------------------- + + +def test_mcpb027_mcp_tool_with_label_and_date_picker_schema(dash_duo): + """Full assertion on a tool with an html.Label and DatePickerSingle constraints.""" + label_text = "Departure Date" + component_id = "dp" + min_date = "2020-01-01" + max_date = "2025-12-31" + default_date = "2024-06-15" + func_name = "select_date" + param_name = "date" # function parameter name + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label(label_text, htmlFor=component_id), + dcc.DatePickerSingle( + id=component_id, + min_date_allowed=min_date, + max_date_allowed=max_date, + date=default_date, + ), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input(component_id, "date")) + def select_date(date): + return f"Selected: {date}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + tool = next(t for t in tools if t["name"] not in ("get_dash_component",)) + + assert func_name in tool["name"] + + schema = tool["inputSchema"] + assert schema["type"] == "object" + assert param_name in schema["required"] + assert param_name in schema["properties"] + + prop = schema["properties"][param_name] + assert prop["type"] == "string" + assert prop["format"] == "date" + + desc = prop["description"] + for expected in (label_text, min_date, max_date, default_date): + assert expected in desc, f"Expected {expected!r} in description: {desc!r}" + + +EXPECTED_DROPDOWN_OPTIONS = { + "component_id": "my-dropdown", + "component_type": "Dropdown", + "label": None, + "properties": { + "options": { + "initial_value": [ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + "modified_by_tool": [], + "input_to_tool": [], + }, + }, +} + + +def test_mcpb028_query_component_returns_structured_output(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + id="my-dropdown", + options=[ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + value="NYC", + ), + ] + ) + + dash_duo.start_server(app) + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "my-dropdown", "property": "options"}, + ) + + assert "result" in result, f"Expected result in response: {result}" + structured = result["result"]["structuredContent"] + assert structured["component_id"] == EXPECTED_DROPDOWN_OPTIONS["component_id"] + assert structured["component_type"] == EXPECTED_DROPDOWN_OPTIONS["component_type"] + assert ( + structured["properties"]["options"] + == EXPECTED_DROPDOWN_OPTIONS["properties"]["options"] + ) diff --git a/tests/integration/mcp/test_mcp_endpoint.py b/tests/integration/mcp/test_mcp_endpoint.py new file mode 100644 index 0000000000..44b358c25d --- /dev/null +++ b/tests/integration/mcp/test_mcp_endpoint.py @@ -0,0 +1,189 @@ +"""MCP Streamable HTTP endpoint — transport-layer behavior. + +Uses Flask's test_client to exercise POST/GET/DELETE at /_mcp, +session management, content-type handling, and route registration +driven by ``enable_mcp`` / ``DASH_MCP_ENABLED`` / ``routes_pathname_prefix``. +""" + +import json +import os + +from dash import Dash, Input, Output, html +from mcp.types import LATEST_PROTOCOL_VERSION + +MCP_PATH = "_mcp" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_app(**kwargs): + """Create a minimal Dash app with a layout and one callback.""" + app = Dash(__name__, **kwargs) + app.layout = html.Div( + [ + html.Div(id="my-input"), + html.Div(id="my-output"), + ] + ) + + @app.callback(Output("my-output", "children"), Input("my-input", "children")) + def update_output(value): + """Test callback docstring.""" + return f"echo: {value}" + + return app + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpe001_post_initialize_returns_protocol_version(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + data = json.loads(r.data) + assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + + +def test_mcpe002_post_tools_list(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + data = json.loads(r.data) + assert "result" in data + assert "tools" in data["result"] + + +def test_mcpe003_notification_returns_202(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized"}), + content_type="application/json", + ) + assert r.status_code == 202 + + +def test_mcpe004_delete_returns_405(): + app = _make_app() + client = app.server.test_client() + r = client.delete(f"/{MCP_PATH}") + assert r.status_code == 405 + + +def test_mcpe005_get_returns_405(): + app = _make_app() + client = app.server.test_client() + r = client.get(f"/{MCP_PATH}") + assert r.status_code == 405 + + +def test_mcpe006_post_rejects_wrong_content_type(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data="not json", + content_type="text/plain", + ) + assert r.status_code == 415 + + +def test_mcpe007_routes_not_registered_when_disabled(): + app = _make_app(enable_mcp=False) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + # With MCP disabled, the route doesn't exist — response is HTML, not JSON + assert r.content_type != "application/json" + + +def test_mcpe008_routes_respect_pathname_prefix(): + app = _make_app(routes_pathname_prefix="/app/") + client = app.server.test_client() + + ok = client.post( + f"/app/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert ok.status_code == 200 + + miss = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert miss.status_code == 404 + + +def test_mcpe009_enable_mcp_env_var_false(): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.content_type != "application/json" + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old + + +def test_mcpe010_constructor_overrides_env_var(): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app(enable_mcp=True) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert b"protocolVersion" in r.data + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old diff --git a/tests/integration/mcp/primitives/resources/test_resources.py b/tests/integration/mcp/test_mcp_resources.py similarity index 86% rename from tests/integration/mcp/primitives/resources/test_resources.py rename to tests/integration/mcp/test_mcp_resources.py index dfc1e09f9b..41519578d1 100644 --- a/tests/integration/mcp/primitives/resources/test_resources.py +++ b/tests/integration/mcp/test_mcp_resources.py @@ -1,4 +1,4 @@ -"""Integration tests for MCP resources.""" +"""MCP resources — ``resources/list`` and ``resources/read`` via HTTP.""" import json @@ -7,7 +7,7 @@ from tests.integration.mcp.conftest import _mcp_method -def test_resources_list_includes_layout(dash_duo): +def test_mcpz001_resources_list_includes_layout(dash_duo): app = Dash(__name__) app.layout = html.Div( [ @@ -24,7 +24,7 @@ def test_resources_list_includes_layout(dash_duo): assert "dash://layout" in uris -def test_read_layout_resource(dash_duo): +def test_mcpz002_read_layout_resource(dash_duo): app = Dash(__name__) app.layout = html.Div( [ diff --git a/tests/integration/mcp/test_server.py b/tests/integration/mcp/test_server.py deleted file mode 100644 index 4f0d0fca00..0000000000 --- a/tests/integration/mcp/test_server.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Integration tests for the MCP Streamable HTTP endpoint. - -These tests use Flask's test_client to exercise the HTTP transport layer -(POST/GET/DELETE at /_mcp), session management, content-type handling, -and route registration/configuration. -""" - -import json -import os - -from dash import Dash, Input, Output, html -from mcp.types import LATEST_PROTOCOL_VERSION - -MCP_PATH = "_mcp" - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_app(**kwargs): - """Create a minimal Dash app with a layout and one callback.""" - app = Dash(__name__, **kwargs) - app.layout = html.Div( - [ - html.Div(id="my-input"), - html.Div(id="my-output"), - ] - ) - - @app.callback(Output("my-output", "children"), Input("my-input", "children")) - def update_output(value): - """Test callback docstring.""" - return f"echo: {value}" - - return app - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -class TestMCPEndpoint: - """Tests for the Streamable HTTP MCP endpoint at /_mcp.""" - - def test_post_initialize_returns_protocol_version(self): - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert r.status_code == 200 - data = json.loads(r.data) - assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION - - def test_post_tools_list(self): - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "tools/list", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert r.status_code == 200 - data = json.loads(r.data) - assert "result" in data - assert "tools" in data["result"] - - def test_notification_returns_202(self): - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized"}), - content_type="application/json", - ) - assert r.status_code == 202 - - def test_delete_returns_405(self): - app = _make_app() - client = app.server.test_client() - r = client.delete(f"/{MCP_PATH}") - assert r.status_code == 405 - - def test_get_returns_405(self): - app = _make_app() - client = app.server.test_client() - r = client.get(f"/{MCP_PATH}") - assert r.status_code == 405 - - def test_post_rejects_wrong_content_type(self): - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data="not json", - content_type="text/plain", - ) - assert r.status_code == 415 - - def test_routes_not_registered_when_disabled(self): - app = _make_app(enable_mcp=False) - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - # With MCP disabled, the route doesn't exist — response is HTML, not JSON - assert r.content_type != "application/json" - - def test_routes_respect_pathname_prefix(self): - app = _make_app(routes_pathname_prefix="/app/") - client = app.server.test_client() - - ok = client.post( - f"/app/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert ok.status_code == 200 - - miss = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert miss.status_code == 404 - - def test_enable_mcp_env_var_false(self): - old = os.environ.get("DASH_MCP_ENABLED") - try: - os.environ["DASH_MCP_ENABLED"] = "false" - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert r.content_type != "application/json" - finally: - if old is None: - os.environ.pop("DASH_MCP_ENABLED", None) - else: - os.environ["DASH_MCP_ENABLED"] = old - - def test_constructor_overrides_env_var(self): - old = os.environ.get("DASH_MCP_ENABLED") - try: - os.environ["DASH_MCP_ENABLED"] = "false" - app = _make_app(enable_mcp=True) - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert r.status_code == 200 - assert b"protocolVersion" in r.data - finally: - if old is None: - os.environ.pop("DASH_MCP_ENABLED", None) - else: - os.environ["DASH_MCP_ENABLED"] = old diff --git a/tests/unit/mcp/test_mcp_server.py b/tests/unit/mcp/test_mcp_server.py new file mode 100644 index 0000000000..f4bb595dce --- /dev/null +++ b/tests/unit/mcp/test_mcp_server.py @@ -0,0 +1,99 @@ +"""MCP server JSON-RPC message processing (``_process_mcp_message``).""" + +from dash._get_app import app_context +from dash.mcp._server import _process_mcp_message +from mcp.types import LATEST_PROTOCOL_VERSION + +from tests.unit.mcp.conftest import _make_app, _setup_mcp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool(app, tool_name, arguments=None, request_id=1): + return _mcp( + app, "tools/call", {"name": tool_name, "arguments": arguments or {}}, request_id + ) + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + result = _call_tool(app, tool_name, arguments) + structured = result["result"]["structuredContent"] + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcps001_initialize(): + app = _make_app() + result = _mcp(app, "initialize") + + assert result is not None + assert result["id"] == 1 + assert result["jsonrpc"] == "2.0" + assert result["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + assert "serverInfo" in result["result"] + + +def test_mcps002_tools_call(): + app = _make_app() + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update_output" in t["name"]) + + result = _call_tool(app, tool_name, {"value": "hello"}, request_id=2) + + assert result is not None + assert result["id"] == 2 + assert _call_tool_output(app, tool_name, {"value": "hello"}) == "echo: hello" + + +def test_mcps003_tools_call_unknown_tool_returns_error(): + app = _make_app() + result = _call_tool(app, "nonexistent_tool") + + assert result is not None + assert "error" in result + assert result["error"]["code"] == -32601 + + +def test_mcps004_unknown_method_returns_error(): + app = _make_app() + result = _mcp(app, "unknown/method") + + assert result is not None + assert "error" in result + + +def test_mcps005_notification_returns_none(): + app = _make_app() + data = {"jsonrpc": "2.0", "method": "notifications/initialized"} + with app.server.test_request_context(): + app_context.set(app) + result = _process_mcp_message(data) + assert result is None diff --git a/tests/unit/mcp/test_server.py b/tests/unit/mcp/test_server.py deleted file mode 100644 index 23c99c50ad..0000000000 --- a/tests/unit/mcp/test_server.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Tests for MCP server (_server.py) — JSON-RPC message processing.""" - -from dash._get_app import app_context -from dash.mcp._server import _process_mcp_message -from mcp.types import LATEST_PROTOCOL_VERSION - -from tests.unit.mcp.conftest import _make_app, _setup_mcp - - -def _msg(method, params=None, request_id=1): - d = {"jsonrpc": "2.0", "method": method, "id": request_id} - d["params"] = params if params is not None else {} - return d - - -def _mcp(app, method, params=None, request_id=1): - with app.server.test_request_context(): - _setup_mcp(app) - return _process_mcp_message(_msg(method, params, request_id)) - - -def _tools_list(app): - return _mcp(app, "tools/list")["result"]["tools"] - - -def _call_tool(app, tool_name, arguments=None, request_id=1): - return _mcp( - app, "tools/call", {"name": tool_name, "arguments": arguments or {}}, request_id - ) - - -def _call_tool_output( - app, tool_name, arguments=None, component_id=None, prop="children" -): - result = _call_tool(app, tool_name, arguments) - structured = result["result"]["structuredContent"] - response = structured["response"] - if component_id is None: - component_id = next(iter(response)) - return response[component_id][prop] - - -class TestProcessMCPMessage: - def test_initialize(self): - app = _make_app() - result = _mcp(app, "initialize") - - assert result is not None - assert result["id"] == 1 - assert result["jsonrpc"] == "2.0" - assert result["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION - assert "serverInfo" in result["result"] - - def test_tools_call(self): - app = _make_app() - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "update_output" in t["name"]) - - result = _call_tool(app, tool_name, {"value": "hello"}, request_id=2) - - assert result is not None - assert result["id"] == 2 - assert _call_tool_output(app, tool_name, {"value": "hello"}) == "echo: hello" - - def test_tools_call_unknown_tool_returns_error(self): - app = _make_app() - result = _call_tool(app, "nonexistent_tool") - - assert result is not None - assert "error" in result - assert result["error"]["code"] == -32601 - - def test_unknown_method_returns_error(self): - app = _make_app() - result = _mcp(app, "unknown/method") - - assert result is not None - assert "error" in result - - def test_notification_returns_none(self): - app = _make_app() - data = {"jsonrpc": "2.0", "method": "notifications/initialized"} - with app.server.test_request_context(): - app_context.set(app) - result = _process_mcp_message(data) - assert result is None diff --git a/tests/unit/mcp/tools/test_mcp_run_callback.py b/tests/unit/mcp/tools/test_mcp_run_callback.py new file mode 100644 index 0000000000..e345b9682e --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_run_callback.py @@ -0,0 +1,253 @@ +"""Callback dispatch execution via MCP tools (``run_callback``). + +Exercises how the MCP tool pipeline runs a Dash callback through +``_process_mcp_message`` with various signatures: multi-output, State, +positional vs. dict-based ``inputs``, ``PreventUpdate``, and no-output +set_props-style callbacks. +""" + +from dash import Dash, Input, Output, State, dcc, html, set_props +from dash.exceptions import PreventUpdate +from dash.mcp._server import _process_mcp_message + +from tests.unit.mcp.conftest import _setup_mcp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool_structured(app, tool_name, arguments=None): + result = _mcp(app, "tools/call", {"name": tool_name, "arguments": arguments or {}}) + return result["result"]["structuredContent"] + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + structured = _call_tool_structured(app, tool_name, arguments) + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpx001_multi_output(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + dcc.Dropdown(id="dd2"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("dd2", "options"), + Output("out", "children"), + Input("dd", "value"), + ) + def update(val): + return [{"label": val, "value": val}], f"selected: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + structured = _call_tool_structured(app, tool_name, {"val": "b"}) + assert structured["response"]["dd2"]["options"] == [{"label": "b", "value": "b"}] + assert structured["response"]["out"]["children"] == "selected: b" + + +def test_mcpx002_omitted_kwargs_default_to_none(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"]), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("dd", "value"), + State("inp", "value"), + ) + def update(selected, text): + return f"{selected}-{text}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"selected": "a"}, "out") == "a-None" + + +def test_mcpx003_no_output_callback(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + html.Div(id="display"), + ] + ) + + @app.callback(Input("btn", "n_clicks")) + def server_cb(n): + set_props("display", {"children": f"Clicked {n} times"}) + + tools = _tools_list(app) + tool_names = [t["name"] for t in tools] + assert "server_cb" in tool_names + + +def test_mcpx004_prevent_update(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hello"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + if val == "block": + raise PreventUpdate + return f"got: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"val": "test"}, "out") == "got: test" + + +def test_mcpx005_with_state(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="trigger"), + html.Div(id="store"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input("trigger", "children"), + State("store", "children"), + ) + def with_state(trigger, store): + return f"{trigger}-{store}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + {"trigger": "click", "store": "data"}, + "result", + ) + == "click-data" + ) + + +def test_mcpx006_dict_inputs(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="x-input", value="hello"), + dcc.Input(id="y-input", value="world"), + html.Div(id="dict-out"), + ] + ) + + @app.callback( + Output("dict-out", "children"), + inputs={ + "x_val": Input("x-input", "value"), + "y_val": Input("y-input", "value"), + }, + ) + def combine(**kwargs): + return f"{kwargs['x_val']}-{kwargs['y_val']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "combine" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + {"x_val": "foo", "y_val": "bar"}, + "dict-out", + ) + == "foo-bar" + ) + + +def test_mcpx007_positional_inputs(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="a-input", value="A"), + html.Div(id="pos-out"), + ] + ) + + @app.callback(Output("pos-out", "children"), Input("a-input", "value")) + def echo(val): + return f"got:{val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "echo" in t["name"]) + assert _call_tool_output(app, tool_name, {"val": "test"}, "pos-out") == "got:test" + + +def test_mcpx008_dict_inputs_with_state(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hi"), + html.Div(id="st", children="state-val"), + html.Div(id="ds-out"), + ] + ) + + @app.callback( + Output("ds-out", "children"), + inputs={"trigger": Input("inp", "value")}, + state={"kept": State("st", "children")}, + ) + def with_dict_state(**kwargs): + return f"{kwargs['trigger']}+{kwargs['kept']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_dict_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + {"trigger": "hey", "kept": "saved"}, + "ds-out", + ) + == "hey+saved" + ) diff --git a/tests/unit/mcp/tools/test_run_callback.py b/tests/unit/mcp/tools/test_run_callback.py deleted file mode 100644 index 00f4e5b7b1..0000000000 --- a/tests/unit/mcp/tools/test_run_callback.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Tests for callback dispatch execution via MCP tools.""" - -from dash import Dash, Input, Output, State, dcc, html -from dash.exceptions import PreventUpdate -from dash.mcp._server import _process_mcp_message - -from tests.unit.mcp.conftest import _setup_mcp - - -def _msg(method, params=None, request_id=1): - d = {"jsonrpc": "2.0", "method": method, "id": request_id} - d["params"] = params if params is not None else {} - return d - - -def _mcp(app, method, params=None, request_id=1): - with app.server.test_request_context(): - _setup_mcp(app) - return _process_mcp_message(_msg(method, params, request_id)) - - -def _tools_list(app): - return _mcp(app, "tools/list")["result"]["tools"] - - -def _call_tool_structured(app, tool_name, arguments=None): - result = _mcp(app, "tools/call", {"name": tool_name, "arguments": arguments or {}}) - return result["result"]["structuredContent"] - - -def _call_tool_output( - app, tool_name, arguments=None, component_id=None, prop="children" -): - structured = _call_tool_structured(app, tool_name, arguments) - response = structured["response"] - if component_id is None: - component_id = next(iter(response)) - return response[component_id][prop] - - -class TestRunCallback: - def test_multi_output(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="dd", options=["a", "b"], value="a"), - dcc.Dropdown(id="dd2"), - html.Div(id="out"), - ] - ) - - @app.callback( - Output("dd2", "options"), - Output("out", "children"), - Input("dd", "value"), - ) - def update(val): - return [{"label": val, "value": val}], f"selected: {val}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "update" in t["name"]) - structured = _call_tool_structured(app, tool_name, {"val": "b"}) - assert structured["response"]["dd2"]["options"] == [ - {"label": "b", "value": "b"} - ] - assert structured["response"]["out"]["children"] == "selected: b" - - def test_omitted_kwargs_default_to_none(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="dd", options=["a"]), - dcc.Input(id="inp"), - html.Div(id="out"), - ] - ) - - @app.callback( - Output("out", "children"), - Input("dd", "value"), - State("inp", "value"), - ) - def update(selected, text): - return f"{selected}-{text}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "update" in t["name"]) - assert _call_tool_output(app, tool_name, {"selected": "a"}, "out") == "a-None" - - def test_no_output_callback(self): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Button(id="btn"), - html.Div(id="display"), - ] - ) - - @app.callback(Input("btn", "n_clicks")) - def server_cb(n): - from dash import set_props - - set_props("display", {"children": f"Clicked {n} times"}) - - tools = _tools_list(app) - tool_names = [t["name"] for t in tools] - assert "server_cb" in tool_names - - def test_prevent_update(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="inp", value="hello"), - html.Div(id="out"), - ] - ) - - @app.callback(Output("out", "children"), Input("inp", "value")) - def update(val): - if val == "block": - raise PreventUpdate - return f"got: {val}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "update" in t["name"]) - assert _call_tool_output(app, tool_name, {"val": "test"}, "out") == "got: test" - - def test_with_state(self): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Div(id="trigger"), - html.Div(id="store"), - html.Div(id="result"), - ] - ) - - @app.callback( - Output("result", "children"), - Input("trigger", "children"), - State("store", "children"), - ) - def with_state(trigger, store): - return f"{trigger}-{store}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "with_state" in t["name"]) - assert ( - _call_tool_output( - app, - tool_name, - { - "trigger": "click", - "store": "data", - }, - "result", - ) - == "click-data" - ) - - def test_dict_inputs(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="x-input", value="hello"), - dcc.Input(id="y-input", value="world"), - html.Div(id="dict-out"), - ] - ) - - @app.callback( - Output("dict-out", "children"), - inputs={ - "x_val": Input("x-input", "value"), - "y_val": Input("y-input", "value"), - }, - ) - def combine(**kwargs): - return f"{kwargs['x_val']}-{kwargs['y_val']}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "combine" in t["name"]) - assert ( - _call_tool_output( - app, - tool_name, - { - "x_val": "foo", - "y_val": "bar", - }, - "dict-out", - ) - == "foo-bar" - ) - - def test_positional_inputs(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="a-input", value="A"), - html.Div(id="pos-out"), - ] - ) - - @app.callback(Output("pos-out", "children"), Input("a-input", "value")) - def echo(val): - return f"got:{val}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "echo" in t["name"]) - assert ( - _call_tool_output(app, tool_name, {"val": "test"}, "pos-out") == "got:test" - ) - - def test_dict_inputs_with_state(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="inp", value="hi"), - html.Div(id="st", children="state-val"), - html.Div(id="ds-out"), - ] - ) - - @app.callback( - Output("ds-out", "children"), - inputs={"trigger": Input("inp", "value")}, - state={"kept": State("st", "children")}, - ) - def with_dict_state(**kwargs): - return f"{kwargs['trigger']}+{kwargs['kept']}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "with_dict_state" in t["name"]) - assert ( - _call_tool_output( - app, - tool_name, - { - "trigger": "hey", - "kept": "saved", - }, - "ds-out", - ) - == "hey+saved" - ) From 8c7d7208003aa9a0078ec764b0ad8e54ee420044 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 4 May 2026 16:02:31 -0600 Subject: [PATCH 61/80] Add decorator to make any function mcp-enabled --- dash/dash.py | 1 + dash/mcp/__init__.py | 2 + dash/mcp/_decorator.py | 51 ++++++ dash/mcp/_server.py | 5 + dash/mcp/primitives/tools/__init__.py | 2 + .../tools/tool_decorated_mcp_functions.py | 148 +++++++++++++++ .../test_tool_decorated_mcp_functions.py | 160 +++++++++++++++++ tests/unit/mcp/conftest.py | 10 +- .../mcp/tools/test_mcp_enabled_decorator.py | 170 ++++++++++++++++++ 9 files changed, 547 insertions(+), 2 deletions(-) create mode 100644 dash/mcp/_decorator.py create mode 100644 dash/mcp/primitives/tools/tool_decorated_mcp_functions.py create mode 100644 tests/integration/mcp/primitives/tools/test_tool_decorated_mcp_functions.py create mode 100644 tests/unit/mcp/tools/test_mcp_enabled_decorator.py diff --git a/dash/dash.py b/dash/dash.py index c34dca722b..37eb7a1ffb 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -615,6 +615,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # same deps as a list to catch duplicate outputs, and to send to the front end self._callback_list: list = [] self.callback_api_paths: dict = {} + self.mcp_decorated_functions: dict = {} # list of inline scripts self._inline_scripts: list = [] diff --git a/dash/mcp/__init__.py b/dash/mcp/__init__.py index 2bc4757f13..2e6edffdb2 100644 --- a/dash/mcp/__init__.py +++ b/dash/mcp/__init__.py @@ -1,7 +1,9 @@ """Dash MCP (Model Context Protocol) server integration.""" +from dash.mcp._decorator import mcp_enabled from dash.mcp._server import enable_mcp_server __all__ = [ "enable_mcp_server", + "mcp_enabled", ] diff --git a/dash/mcp/_decorator.py b/dash/mcp/_decorator.py new file mode 100644 index 0000000000..1b85316207 --- /dev/null +++ b/dash/mcp/_decorator.py @@ -0,0 +1,51 @@ +"""Decorator to expose plain Python functions as MCP tools.""" + +from __future__ import annotations + +import functools +from typing import Any, Callable, Optional + +from typing_extensions import TypedDict + + +class MCPToolRegistration(TypedDict): + fn: Callable[..., Any] + expose_docstring: Optional[bool] + + +MCP_DECORATED_FUNCTIONS: dict[str, MCPToolRegistration] = {} + + +def mcp_enabled( + func: Callable[..., Any] | None = None, + *, + name: str | None = None, + expose_docstring: Optional[bool] = None, +) -> Callable[..., Any]: + """Mark a function as an MCP tool. + + Supports both bare and parameterised usage:: + + @mcp_enabled + def my_tool(x: int) -> str: ... + + @mcp_enabled(name="custom_name", expose_docstring=True) + def my_tool(x: int) -> str: ... + """ + + def _wrap(fn: Callable[..., Any]) -> Callable[..., Any]: + tool_name = name if name else fn.__name__ + MCP_DECORATED_FUNCTIONS[tool_name] = MCPToolRegistration( + fn=fn, + expose_docstring=expose_docstring, + ) + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return fn(*args, **kwargs) + + return wrapper + + if func is not None: + return _wrap(func) + return _wrap diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 07b35cd373..64323c09af 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -25,6 +25,7 @@ from dash import get_app from dash._get_app import with_app_context_factory +from dash.mcp._decorator import MCP_DECORATED_FUNCTIONS from dash.mcp.primitives import ( call_tool, list_resource_templates, @@ -46,6 +47,10 @@ def enable_mcp_server(app: Dash, mcp_path: str) -> None: """Add MCP routes to a Dash/Flask app.""" + + app.mcp_decorated_functions = dict(MCP_DECORATED_FUNCTIONS) + MCP_DECORATED_FUNCTIONS.clear() + # -- Streamable HTTP endpoint -------------------------------------------- def mcp_handler() -> Response: diff --git a/dash/mcp/primitives/tools/__init__.py b/dash/mcp/primitives/tools/__init__.py index 7fa1f4aefb..b8f12a1dbd 100644 --- a/dash/mcp/primitives/tools/__init__.py +++ b/dash/mcp/primitives/tools/__init__.py @@ -9,12 +9,14 @@ from dash.mcp.types import ToolNotFoundError from .base import MCPToolProvider +from .tool_decorated_mcp_functions import DecoratedFunctionTools from .tool_get_dash_component import GetDashComponentTool from .tools_callbacks import CallbackTools _TOOL_PROVIDERS: list[type[MCPToolProvider]] = [ CallbackTools, GetDashComponentTool, + DecoratedFunctionTools, ] diff --git a/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py b/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py new file mode 100644 index 0000000000..c135455c88 --- /dev/null +++ b/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py @@ -0,0 +1,148 @@ +"""MCP tools backed by ``@mcp_enabled``-decorated functions.""" + +from __future__ import annotations + +import inspect +import json +import typing +from typing import Any, Callable + +from mcp.types import CallToolResult, TextContent, Tool + +from dash import get_app +from dash.mcp._decorator import MCPToolRegistration +from dash.mcp.primitives.tools.input_schemas import get_input_schema +from dash.mcp.primitives.tools.input_schemas.schema_callback_type_annotations import ( + annotation_to_json_schema, +) +from dash.mcp.types import MCPInput, is_nullable + +from .base import MCPToolProvider + + +def _build_inputs(fn: Callable[..., Any]) -> list[MCPInput]: + """Build an ``MCPInput`` from each of the function's arguments.""" + try: + hints = typing.get_type_hints(fn) + except Exception: # pylint: disable=broad-exception-caught + hints = getattr(fn, "__annotations__", {}) + + sig = inspect.signature(fn) + inputs: list[MCPInput] = [] + + for name, param in sig.parameters.items(): + annotation = hints.get(name) + + has_default = param.default is not inspect.Parameter.empty + required = not has_default and ( + annotation is None or not is_nullable(annotation) + ) + + inputs.append( + MCPInput( + name=name, + id_and_prop="", + component_id="", + property="", + annotation=annotation, + component_type=None, + component=None, + required=required, + initial_value=param.default if has_default else None, + upstream_output=None, + ) + ) + return inputs + + +def _build_output_schema(fn: Callable[..., Any]) -> dict[str, Any]: + """Build a JSON Schema ``outputSchema`` from the return annotation. + + The schema wraps the return type in ``{"result": }`` to match + the object that ``call_tool`` returns as ``structuredContent``. + """ + try: + hints = typing.get_type_hints(fn) + except Exception: # pylint: disable=broad-exception-caught + hints = getattr(fn, "__annotations__", {}) + + ret = hints.get("return") + if ret is None: + return {} + + inner = annotation_to_json_schema(ret) + if inner is None: + return {} + + return { + "type": "object", + "properties": {"result": inner}, + "required": ["result"], + } + + +def _build_tool(tool_name: str, reg: MCPToolRegistration) -> Tool: + fn = reg["fn"] + inputs = _build_inputs(fn) + properties = {p["name"]: get_input_schema(p) for p in inputs} + required = [p["name"] for p in inputs if p["required"]] + + input_schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + input_schema["required"] = required + + expose_docstring = reg["expose_docstring"] + if expose_docstring is None: + expose_docstring = get_app().config.get("mcp_expose_docstrings", False) + + description = "MCP tool" + if expose_docstring: + docstring = getattr(fn, "__doc__", None) + if docstring: + description = docstring.strip() + + return Tool( + name=tool_name, + description=description, + inputSchema=input_schema, + outputSchema=_build_output_schema(fn), + ) + + +class DecoratedFunctionTools(MCPToolProvider): + """Exposes ``@mcp_enabled``-decorated functions as MCP tools.""" + + @classmethod + def _registry(cls) -> dict[str, MCPToolRegistration]: + return get_app().mcp_decorated_functions + + @classmethod + def get_tool_names(cls) -> set[str]: + return set(cls._registry().keys()) + + @classmethod + def list_tools(cls) -> list[Tool]: + return [_build_tool(name, reg) for name, reg in cls._registry().items()] + + @classmethod + def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + reg = cls._registry().get(tool_name) + if reg is None: + return CallToolResult( + content=[TextContent(type="text", text=f"Tool not found: {tool_name}")], + isError=True, + ) + fn = reg["fn"] + try: + result = fn(**arguments) + except Exception as exc: # pylint: disable=broad-exception-caught + return CallToolResult( + content=[TextContent(type="text", text=f"{type(exc).__name__}: {exc}")], + isError=True, + ) + + serialized = json.dumps(result, default=str) + return CallToolResult( + content=[TextContent(type="text", text=serialized)], + structuredContent={"result": result}, + ) diff --git a/tests/integration/mcp/primitives/tools/test_tool_decorated_mcp_functions.py b/tests/integration/mcp/primitives/tools/test_tool_decorated_mcp_functions.py new file mode 100644 index 0000000000..afbaf415a8 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_tool_decorated_mcp_functions.py @@ -0,0 +1,160 @@ +"""Integration tests for @mcp_enabled decorated function tools.""" + +import json +from typing import Optional + +from dash import Dash, html +from dash.mcp import mcp_enabled + +from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools + +BUILTINS = {"get_dash_component"} + + +def test_mcpd001_bare_decorator_appears_as_tool(dash_duo): + @mcp_enabled + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tools = _mcp_tools(dash_duo.server.url) + names = [t["name"] for t in tools] + assert "add_numbers" in names + + tool = next(t for t in tools if t["name"] == "add_numbers") + assert "Add two numbers" not in tool["description"] + + +def test_mcpd002_expose_docstring(dash_duo): + @mcp_enabled(expose_docstring=True) + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tool = next( + t for t in _mcp_tools(dash_duo.server.url) if t["name"] == "add_numbers" + ) + assert "Add two numbers together" in tool["description"] + + +def test_mcpd003_custom_name_overrides_function_name(dash_duo): + @mcp_enabled(name="sum_values") + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tools = _mcp_tools(dash_duo.server.url) + names = [t["name"] for t in tools] + assert "sum_values" in names + assert "add_numbers" not in names + + +def test_mcpd004_typed_params_produce_schema(dash_duo): + @mcp_enabled + def greet(name: str, times: int) -> str: + """Greet someone.""" + return name * times + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tool = next(t for t in _mcp_tools(dash_duo.server.url) if t["name"] == "greet") + schema = tool["inputSchema"] + assert schema["type"] == "object" + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["times"]["type"] == "integer" + assert set(schema["required"]) == {"name", "times"} + + +def test_mcpd005_optional_param_not_required(dash_duo): + @mcp_enabled + def search(query: str, limit: Optional[int] = None) -> str: + """Search for things.""" + return query + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tool = next(t for t in _mcp_tools(dash_duo.server.url) if t["name"] == "search") + schema = tool["inputSchema"] + assert "query" in schema["required"] + assert "limit" not in schema["required"] + + +def test_mcpd006_return_annotation_becomes_output_schema(dash_duo): + @mcp_enabled + def compute(x: int) -> str: + """Compute something.""" + return str(x) + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + tool = next(t for t in _mcp_tools(dash_duo.server.url) if t["name"] == "compute") + assert tool["outputSchema"]["type"] == "object" + assert tool["outputSchema"]["properties"]["result"]["type"] == "string" + + +def test_mcpd007_call_returns_result(dash_duo): + @mcp_enabled + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + resp = _mcp_call_tool(dash_duo.server.url, "multiply", {"a": 3, "b": 7}) + result = resp["result"] + assert result["isError"] is not True + text = result["content"][0]["text"] + assert json.loads(text) == 21 + + +def test_mcpd008_call_with_custom_name(dash_duo): + @mcp_enabled(name="product") + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + resp = _mcp_call_tool(dash_duo.server.url, "product", {"a": 4, "b": 5}) + result = resp["result"] + assert result["isError"] is not True + text = result["content"][0]["text"] + assert json.loads(text) == 20 + + +def test_mcpd009_call_error_returns_is_error(dash_duo): + @mcp_enabled + def fail_hard(x: int) -> str: + """Always fails.""" + raise ValueError("boom") + + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + resp = _mcp_call_tool(dash_duo.server.url, "fail_hard", {"x": 1}) + result = resp["result"] + assert result["isError"] is True + assert "boom" in result["content"][0]["text"] diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py index 2f7fbc1898..a3ddd191aa 100644 --- a/tests/unit/mcp/conftest.py +++ b/tests/unit/mcp/conftest.py @@ -9,10 +9,13 @@ if sys.version_info < (3, 10): collect_ignore_glob.append("*") else: - from dash.mcp.primitives.tools import ( + from dash.mcp._decorator import ( # pylint: disable=wrong-import-position + MCP_DECORATED_FUNCTIONS, + ) + from dash.mcp.primitives.tools import ( # pylint: disable=wrong-import-position call_tool, list_tools, - ) # pylint: disable=wrong-import-position + ) from dash.mcp.primitives.tools.callback_adapter_collection import ( # pylint: disable=wrong-import-position CallbackAdapterCollection, ) @@ -23,6 +26,9 @@ def _setup_mcp(app): """Set up MCP for an app in tests.""" app_context.set(app) + if MCP_DECORATED_FUNCTIONS: + app.mcp_decorated_functions = dict(MCP_DECORATED_FUNCTIONS) + MCP_DECORATED_FUNCTIONS.clear() app.mcp_callback_map = CallbackAdapterCollection(app) return app diff --git a/tests/unit/mcp/tools/test_mcp_enabled_decorator.py b/tests/unit/mcp/tools/test_mcp_enabled_decorator.py new file mode 100644 index 0000000000..c033205f5d --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_enabled_decorator.py @@ -0,0 +1,170 @@ +"""Unit tests for the @mcp_enabled decorator and its tool integration.""" + +from typing import Optional + +from dash import Dash, html +from dash.mcp import mcp_enabled + +from tests.unit.mcp.conftest import _setup_mcp, _tools_list, _call_tool + +BUILTINS = {"get_dash_component"} + + +def _make_app_with_decorated(): + app = Dash(__name__) + app.layout = html.Div(id="root") + return _setup_mcp(app) + + +def _user_tools(app): + return [t for t in _tools_list(app) if t.name not in BUILTINS] + + +def test_mcpd001_bare_decorator_preserves_function(): + @mcp_enabled + def double(x: int) -> int: + return x * 2 + + assert double(5) == 10 + + +def test_mcpd002_parameterised_decorator_preserves_function(): + @mcp_enabled(name="doubler") + def double(x: int) -> int: + return x * 2 + + assert double(5) == 10 + + +def test_mcpd003_bare_decorator_appears_as_tool(): + @mcp_enabled + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = _make_app_with_decorated() + names = {t.name for t in _user_tools(app)} + assert "add_numbers" in names + + +def test_mcpd004_custom_name_overrides_function_name(): + @mcp_enabled(name="sum_values") + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = _make_app_with_decorated() + names = {t.name for t in _user_tools(app)} + assert "sum_values" in names + assert "add_numbers" not in names + + +def test_mcpd005_docstring_hidden_by_default(): + @mcp_enabled + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "add_numbers") + assert "Add two numbers" not in tool.description + + +def test_mcpd006_docstring_exposed_when_opted_in(): + @mcp_enabled(expose_docstring=True) + def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "add_numbers") + assert "Add two numbers together" in tool.description + + +def test_mcpd007_typed_params_produce_schema(): + @mcp_enabled + def greet(name: str, times: int) -> str: + """Greet someone.""" + return name * times + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "greet") + schema = tool.inputSchema + assert schema["type"] == "object" + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["times"]["type"] == "integer" + assert set(schema["required"]) == {"name", "times"} + + +def test_mcpd008_optional_param_not_required(): + @mcp_enabled + def search(query: str, limit: Optional[int] = None) -> str: + """Search for things.""" + return query + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "search") + schema = tool.inputSchema + assert "query" in schema["required"] + assert "limit" not in schema["required"] + + +def test_mcpd009_typed_param_with_default_not_required(): + @mcp_enabled + def filter_range(min_val: float = -180.0, max_val: float = 180.0) -> list[str]: + """Filter by range.""" + return [] + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "filter_range") + schema = tool.inputSchema + assert "required" not in schema or "min_val" not in schema.get("required", []) + assert "required" not in schema or "max_val" not in schema.get("required", []) + + +def test_mcpd010_return_annotation_becomes_output_schema(): + @mcp_enabled + def compute(x: int) -> str: + """Compute something.""" + return str(x) + + app = _make_app_with_decorated() + tool = next(t for t in _user_tools(app) if t.name == "compute") + assert tool.outputSchema["type"] == "object" + assert tool.outputSchema["properties"]["result"]["type"] == "string" + + +def test_mcpd011_call_returns_result(): + @mcp_enabled + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + app = _make_app_with_decorated() + result = _call_tool(app, "multiply", {"a": 3, "b": 7}) + assert result.isError is not True + assert result.structuredContent["result"] == 21 + + +def test_mcpd012_call_with_custom_name(): + @mcp_enabled(name="product") + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + app = _make_app_with_decorated() + result = _call_tool(app, "product", {"a": 4, "b": 5}) + assert result.isError is not True + assert result.structuredContent["result"] == 20 + + +def test_mcpd013_call_error_returns_is_error(): + @mcp_enabled + def fail_hard(x: int) -> str: + """Always fails.""" + raise ValueError("boom") + + app = _make_app_with_decorated() + result = _call_tool(app, "fail_hard", {"x": 1}) + assert result.isError is True + assert "boom" in result.content[0].text From 4f7329ca865aa7c5c5df6f533380d388bae3dde3 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Fri, 8 May 2026 15:24:29 -0600 Subject: [PATCH 62/80] Notify connected agents of server restarts and hot-reloads so they can reload --- dash/mcp/_server.py | 86 ++++++++- .../tools/results/result_dataframe.py | 2 +- tests/integration/mcp/test_mcp_session.py | 128 ++++++++++++++ tests/unit/mcp/test_mcp_session.py | 164 ++++++++++++++++++ 4 files changed, 372 insertions(+), 8 deletions(-) create mode 100644 tests/integration/mcp/test_mcp_session.py create mode 100644 tests/unit/mcp/test_mcp_session.py diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 64323c09af..f7632b134b 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -8,6 +8,7 @@ import json import logging +import uuid from typing import TYPE_CHECKING, Any from flask import Response, request @@ -51,6 +52,24 @@ def enable_mcp_server(app: Dash, mcp_path: str) -> None: app.mcp_decorated_functions = dict(MCP_DECORATED_FUNCTIONS) MCP_DECORATED_FUNCTIONS.clear() + _session_id: str | None = None + + def _get_or_create_session_id() -> str: + """Read the hot-reload hash or generate a stable fallback.""" + # pylint: disable=protected-access + reload_hash = app._hot_reload.hash + return reload_hash if reload_hash is not None else uuid.uuid4().hex + + def _is_session_stale(client_session_id: str | None) -> bool: + """True when the client's session doesn't match or the hash changed.""" + if client_session_id != _session_id: + return True + # pylint: disable=protected-access + reload_hash = app._hot_reload.hash + if reload_hash is None: + return False + return reload_hash != _session_id + # -- Streamable HTTP endpoint -------------------------------------------- def mcp_handler() -> Response: @@ -75,6 +94,42 @@ def _handle_get() -> Response: status=405, ) + def _check_session(method: str) -> bool: + """Validate the session header. + + Raises ``ValueError`` when the header is missing. + Returns ``True`` when the session was stale and transparently + recovered, or ``False`` when the session is valid. + """ + nonlocal _session_id + if method == "initialize": + _session_id = _get_or_create_session_id() + return False + client_session_id = request.headers.get("Mcp-Session-Id") + if _session_id is not None and not client_session_id: + raise ValueError("Missing Mcp-Session-Id header") + if _is_session_stale(client_session_id): + _session_id = _get_or_create_session_id() + logger.debug("MCP session recovered: %s", _session_id) + return True + return False + + def _json_response(*messages: dict) -> Response: + """Wrap one or more JSON-RPC messages in a Flask Response. + + A single message is serialised as a JSON object; multiple + messages are serialised as a JSON array. + """ + body = messages[0] if len(messages) == 1 else list(messages) + resp = Response( + json.dumps(body), + content_type="application/json", + status=200, + ) + if _session_id is not None: + resp.headers["Mcp-Session-Id"] = _session_id + return resp + def _handle_post() -> Response: content_type = request.content_type or "" if "application/json" not in content_type: @@ -92,16 +147,33 @@ def _handle_post() -> Response: status=400, ) + method = data.get("method", "") + + try: + is_stale_session = _check_session(method) + except ValueError as err: + return Response( + json.dumps({"error": str(err)}), + content_type="application/json", + status=400, + ) + response_data = _process_mcp_message(data) if response_data is None: return Response("", status=202) - return Response( - json.dumps(response_data), - content_type="application/json", - status=200, - ) + if is_stale_session: + return _json_response( + {"jsonrpc": "2.0", "method": "notifications/tools/list_changed"}, + { + "jsonrpc": "2.0", + "method": "notifications/resources/list_changed", + }, + response_data, + ) + + return _json_response(response_data) def _handle_delete() -> Response: # No sessions to terminate — server is stateless. @@ -129,8 +201,8 @@ def _handle_initialize() -> InitializeResult: return InitializeResult( protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=ServerCapabilities( - tools=ToolsCapability(listChanged=False), - resources=ResourcesCapability(), + tools=ToolsCapability(listChanged=True), + resources=ResourcesCapability(listChanged=True), ), serverInfo=Implementation(name="Plotly Dash", version=__version__), instructions=( diff --git a/dash/mcp/primitives/tools/results/result_dataframe.py b/dash/mcp/primitives/tools/results/result_dataframe.py index 8a2130387c..780bb264ef 100644 --- a/dash/mcp/primitives/tools/results/result_dataframe.py +++ b/dash/mcp/primitives/tools/results/result_dataframe.py @@ -47,7 +47,7 @@ class DataFrameResult(ResultFormatter): """Produce a markdown table for tabular component output values.""" @classmethod - def format(cls, output: MCPOutput, returned_output_value: Any) -> list[TextContent]: + def format(cls, output: MCPOutput, returned_output_value: Any) -> list[TextContent]: # type: ignore[override] if not TABULAR.matches(output.get("component_type"), output["property"]): return [] if ( diff --git a/tests/integration/mcp/test_mcp_session.py b/tests/integration/mcp/test_mcp_session.py new file mode 100644 index 0000000000..f3cc2bfbd8 --- /dev/null +++ b/tests/integration/mcp/test_mcp_session.py @@ -0,0 +1,128 @@ +"""MCP session lifecycle — end-to-end over a real Dash server. + +Exercises the full MCP session flow (initialize → operate → hot-reload +recovery) against a live ``dash_duo`` server using real HTTP requests. +Unit-level checks (status codes, header mechanics) live in +``tests/unit/mcp/test_mcp_session.py``; these tests verify the broader +behavioral contract. +""" + +import requests + +from dash import Dash, Input, Output, html + +from tests.integration.mcp.conftest import _mcp_post + + +def _mcp_post_with_session( + server_url, method, params=None, request_id=1, session_id=None +): + """Like ``_mcp_post`` but forwards an ``Mcp-Session-Id`` header.""" + headers = {"Content-Type": "application/json"} + if session_id is not None: + headers["Mcp-Session-Id"] = session_id + return requests.post( + f"{server_url}/_mcp", + json={ + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params or {}, + }, + headers=headers, + timeout=5, + ) + + +def test_mcpse_e2e001_full_session_lifecycle(dash_duo): + """Initialize → tools/list → tools/call with session headers throughout.""" + app = Dash(__name__) + app.layout = html.Div([html.Div(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "children")) + def echo(v): + return f"echo: {v}" + + dash_duo.start_server(app) + url = dash_duo.server.url + + init = _mcp_post_with_session(url, "initialize") + assert init.status_code == 200 + sid = init.headers.get("Mcp-Session-Id") + assert sid + + notif = _mcp_post_with_session( + url, "notifications/initialized", session_id=sid, request_id=None + ) + assert notif.status_code == 202 + + tools_resp = _mcp_post_with_session(url, "tools/list", session_id=sid, request_id=2) + assert tools_resp.status_code == 200 + tools = tools_resp.json()["result"]["tools"] + assert any("echo" in t["name"] for t in tools) + + tool_name = next(t["name"] for t in tools if "echo" in t["name"]) + call_resp = _mcp_post_with_session( + url, + "tools/call", + params={"name": tool_name, "arguments": {"v": "hello"}}, + session_id=sid, + request_id=3, + ) + assert call_resp.status_code == 200 + assert call_resp.headers.get("Mcp-Session-Id") == sid + + +def test_mcpse_e2e002_stale_session_recovers_with_notifications(dash_duo): + """Simulate a hot-reload hash change and verify transparent recovery.""" + app = Dash(__name__) + app.layout = html.Div([html.Div(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "children")) + def echo(v): + return f"echo: {v}" + + dash_duo.start_server(app) + url = dash_duo.server.url + + app._hot_reload.hash = "original_hash" + + init = _mcp_post_with_session(url, "initialize") + sid = init.headers["Mcp-Session-Id"] + assert sid == "original_hash" + + resp = _mcp_post_with_session(url, "tools/list", session_id=sid, request_id=2) + assert resp.status_code == 200 + + app._hot_reload.hash = "new_hash" + + resp = _mcp_post_with_session(url, "tools/list", session_id=sid, request_id=3) + assert resp.status_code == 200 + new_sid = resp.headers["Mcp-Session-Id"] + assert new_sid == "new_hash" + + data = resp.json() + assert isinstance(data, list) + assert len(data) == 3 + assert data[0]["method"] == "notifications/tools/list_changed" + assert data[1]["method"] == "notifications/resources/list_changed" + assert "result" in data[2] + assert "tools" in data[2]["result"] + + resp = _mcp_post_with_session(url, "tools/list", session_id=new_sid, request_id=4) + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data, dict) + assert "result" in data + + +def test_mcpse_e2e003_capabilities_advertise_list_changed(dash_duo): + """Server capabilities include listChanged for tools and resources.""" + app = Dash(__name__) + app.layout = html.Div(id="root") + dash_duo.start_server(app) + + resp = _mcp_post(dash_duo.server.url, "initialize") + caps = resp.json()["result"]["capabilities"] + assert caps["tools"]["listChanged"] is True + assert caps["resources"]["listChanged"] is True diff --git a/tests/unit/mcp/test_mcp_session.py b/tests/unit/mcp/test_mcp_session.py new file mode 100644 index 0000000000..029fd284ad --- /dev/null +++ b/tests/unit/mcp/test_mcp_session.py @@ -0,0 +1,164 @@ +"""MCP session management — ``Mcp-Session-Id`` header and hot-reload hash.""" + +import json +import sys + +import pytest + +if sys.version_info < (3, 10): + pytest.skip("MCP requires Python 3.10+", allow_module_level=True) + +from tests.unit.mcp.conftest import _make_app # pylint: disable=wrong-import-position + + +def _make_mcp_app(**kwargs): + return _make_app(enable_mcp=True, **kwargs) + + +def _post(client, method, params=None, request_id=1, session_id=None): + """POST a JSON-RPC message to the MCP endpoint.""" + headers = {"Content-Type": "application/json"} + if session_id is not None: + headers["Mcp-Session-Id"] = session_id + body = {"jsonrpc": "2.0", "method": method, "id": request_id} + body["params"] = params if params is not None else {} + return client.post("/_mcp", data=json.dumps(body), headers=headers) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpse001_initialize_returns_session_id(): + app = _make_mcp_app() + with app.server.test_client() as client: + resp = _post(client, "initialize") + assert resp.status_code == 200 + session_id = resp.headers.get("Mcp-Session-Id") + assert session_id is not None + assert len(session_id) > 0 + + +def test_mcpse002_request_without_session_after_init_returns_400(): + app = _make_mcp_app() + with app.server.test_client() as client: + _post(client, "initialize") + resp = _post(client, "tools/list") + assert resp.status_code == 400 + + +def test_mcpse003_request_with_valid_session_succeeds(): + app = _make_mcp_app() + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + session_id = init_resp.headers["Mcp-Session-Id"] + resp = _post(client, "tools/list", session_id=session_id) + assert resp.status_code == 200 + data = json.loads(resp.data) + assert "result" in data + + +def test_mcpse004_stale_session_recovers_transparently(): + app = _make_mcp_app() + app._hot_reload.hash = "hash_v1" + + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + old_session = init_resp.headers["Mcp-Session-Id"] + assert old_session == "hash_v1" + + app._hot_reload.hash = "hash_v2" + + resp = _post(client, "tools/list", session_id=old_session) + assert resp.status_code == 200 + + data = json.loads(resp.data) + assert isinstance(data, list) + assert len(data) == 3 + + new_session = resp.headers.get("Mcp-Session-Id") + assert new_session is not None + assert new_session == "hash_v2" + + +def test_mcpse005_stale_session_includes_list_changed_notifications(): + app = _make_mcp_app() + app._hot_reload.hash = "hash_v1" + + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + old_session = init_resp.headers["Mcp-Session-Id"] + + app._hot_reload.hash = "hash_v2" + + resp = _post(client, "tools/list", session_id=old_session) + data = json.loads(resp.data) + + assert data[0]["method"] == "notifications/tools/list_changed" + assert data[1]["method"] == "notifications/resources/list_changed" + assert "result" in data[2] + + +def test_mcpse006_reinitialize_after_hot_reload_gets_new_session(): + app = _make_mcp_app() + app._hot_reload.hash = "hash_v1" + + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + old_session = init_resp.headers["Mcp-Session-Id"] + assert old_session == "hash_v1" + + app._hot_reload.hash = "hash_v2" + + # Stale request triggers transparent recovery. + resp = _post(client, "tools/list", session_id=old_session) + assert resp.status_code == 200 + recovered_session = resp.headers["Mcp-Session-Id"] + assert recovered_session == "hash_v2" + + # Re-initialize picks up the new hash. + init_resp2 = _post(client, "initialize") + assert init_resp2.status_code == 200 + new_session = init_resp2.headers["Mcp-Session-Id"] + assert new_session == "hash_v2" + + # Subsequent requests with the new session work. + resp = _post(client, "tools/list", session_id=new_session) + assert resp.status_code == 200 + data = json.loads(resp.data) + assert "result" in data + + +def test_mcpse007_no_session_required_before_first_initialize(): + app = _make_mcp_app() + with app.server.test_client() as client: + resp = _post(client, "tools/list") + assert resp.status_code == 200 + + +def test_mcpse008_production_mode_generates_stable_session(): + app = _make_mcp_app() + assert app._hot_reload.hash is None + + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + session_id = init_resp.headers["Mcp-Session-Id"] + assert session_id is not None + + resp = _post(client, "tools/list", session_id=session_id) + assert resp.status_code == 200 + + resp = _post(client, "tools/list", session_id=session_id) + assert resp.status_code == 200 + + +def test_mcpse009_session_header_on_every_response(): + app = _make_mcp_app() + with app.server.test_client() as client: + init_resp = _post(client, "initialize") + session_id = init_resp.headers["Mcp-Session-Id"] + + resp = _post(client, "tools/list", session_id=session_id) + assert resp.status_code == 200 + assert resp.headers.get("Mcp-Session-Id") == session_id From cf8e9d92bf4df58232bcda16cdd08310b688e9b7 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 12 May 2026 18:08:09 -0600 Subject: [PATCH 63/80] Migrate MCP server to use v4.2.0 backend abstractions --- dash/mcp/_server.py | 78 ++++++++++++++----------------- tests/integration/mcp/conftest.py | 13 ------ 2 files changed, 36 insertions(+), 55 deletions(-) diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index f7632b134b..5e708da25c 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -11,7 +11,6 @@ import uuid from typing import TYPE_CHECKING, Any -from flask import Response, request from mcp.types import ( LATEST_PROTOCOL_VERSION, ErrorData, @@ -25,7 +24,6 @@ ) from dash import get_app -from dash._get_app import with_app_context_factory from dash.mcp._decorator import MCP_DECORATED_FUNCTIONS from dash.mcp.primitives import ( call_tool, @@ -47,7 +45,7 @@ def enable_mcp_server(app: Dash, mcp_path: str) -> None: - """Add MCP routes to a Dash/Flask app.""" + """Add MCP routes to a Dash app.""" app.mcp_decorated_functions = dict(MCP_DECORATED_FUNCTIONS) MCP_DECORATED_FUNCTIONS.clear() @@ -72,28 +70,6 @@ def _is_session_stale(client_session_id: str | None) -> bool: # -- Streamable HTTP endpoint -------------------------------------------- - def mcp_handler() -> Response: - if request.method == "POST": - return _handle_post() - if request.method == "GET": - return _handle_get() - if request.method == "DELETE": - return _handle_delete() - return Response( - json.dumps({"error": "Method not allowed"}), - content_type="application/json", - status=405, - ) - - def _handle_get() -> Response: - # MCP spec allows servers to opt out of GET-initiated SSE streams - # by returning 405. We don't push server-initiated events. - return Response( - json.dumps({"error": "Method not allowed"}), - content_type="application/json", - status=405, - ) - def _check_session(method: str) -> bool: """Validate the session header. @@ -102,10 +78,11 @@ def _check_session(method: str) -> bool: recovered, or ``False`` when the session is valid. """ nonlocal _session_id + adapter = app.backend.request_adapter() if method == "initialize": _session_id = _get_or_create_session_id() return False - client_session_id = request.headers.get("Mcp-Session-Id") + client_session_id = adapter.headers.get("Mcp-Session-Id") if _session_id is not None and not client_session_id: raise ValueError("Missing Mcp-Session-Id header") if _is_session_stale(client_session_id): @@ -114,14 +91,14 @@ def _check_session(method: str) -> bool: return True return False - def _json_response(*messages: dict) -> Response: - """Wrap one or more JSON-RPC messages in a Flask Response. + def _json_response(*messages: dict): + """Wrap one or more JSON-RPC messages in a response. A single message is serialised as a JSON object; multiple messages are serialised as a JSON array. """ body = messages[0] if len(messages) == 1 else list(messages) - resp = Response( + resp = app.backend.make_response( json.dumps(body), content_type="application/json", status=200, @@ -130,18 +107,19 @@ def _json_response(*messages: dict) -> Response: resp.headers["Mcp-Session-Id"] = _session_id return resp - def _handle_post() -> Response: - content_type = request.content_type or "" + def _handle_post() -> Any: + adapter = app.backend.request_adapter() + content_type = adapter.headers.get("Content-Type", "") if "application/json" not in content_type: - return Response( + return app.backend.make_response( json.dumps({"error": "Content-Type must be application/json"}), content_type="application/json", status=415, ) - data = request.get_json(silent=True) + data = adapter.get_json() if data is None: - return Response( + return app.backend.make_response( json.dumps({"error": "Invalid JSON"}), content_type="application/json", status=400, @@ -152,7 +130,7 @@ def _handle_post() -> Response: try: is_stale_session = _check_session(method) except ValueError as err: - return Response( + return app.backend.make_response( json.dumps({"error": str(err)}), content_type="application/json", status=400, @@ -161,7 +139,7 @@ def _handle_post() -> Response: response_data = _process_mcp_message(data) if response_data is None: - return Response("", status=202) + return app.backend.make_response("", status=202) if is_stale_session: return _json_response( @@ -175,20 +153,36 @@ def _handle_post() -> Response: return _json_response(response_data) - def _handle_delete() -> Response: - # No sessions to terminate — server is stateless. - return Response( + def _handle_not_allowed(): + """Return 405 for GET and DELETE (MCP SSE not supported).""" + return app.backend.make_response( json.dumps({"error": "Method not allowed"}), content_type="application/json", status=405, ) # -- Register routes ----------------------------------------------------- + # Separate registrations per HTTP method so the handler never needs to + # inspect the request method. Distinct endpoint names are required by + # Flask / Werkzeug when the same URL rule is registered more than once. - # pylint: disable-next=protected-access - app._add_url( - mcp_path, with_app_context_factory(mcp_handler, app), ["GET", "POST", "DELETE"] + mcp_url = app.config.routes_pathname_prefix + mcp_path + app.backend.add_url_rule( + mcp_url, view_func=_handle_post, endpoint=f"{mcp_url}:POST", methods=["POST"] + ) + app.backend.add_url_rule( + mcp_url, + view_func=_handle_not_allowed, + endpoint=f"{mcp_url}:GET", + methods=["GET"], + ) + app.backend.add_url_rule( + mcp_url, + view_func=_handle_not_allowed, + endpoint=f"{mcp_url}:DELETE", + methods=["DELETE"], ) + app.routes.append(mcp_url) logger.info( "MCP routes registered at %s%s", diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py index 0db1b775e8..5030211ed8 100644 --- a/tests/integration/mcp/conftest.py +++ b/tests/integration/mcp/conftest.py @@ -5,8 +5,6 @@ import pytest import requests -from dash import _get_app - collect_ignore_glob = [] if sys.version_info < (3, 10): collect_ignore_glob.append("*") @@ -18,17 +16,6 @@ def _enable_mcp_for_integration_tests(monkeypatch): monkeypatch.setenv("DASH_MCP_ENABLED", "true") -@pytest.fixture(autouse=True) -def _reset_dash_app_state(): - """Reset Dash module-level state after each MCP test. - - TODO: this can be removed when 4.2 backend work lands - """ - yield - _get_app.APP = None - _get_app.app_context.set(None) - - def _mcp_post(server_url, method, params=None, request_id=1): return requests.post( f"{server_url}/_mcp", From 6c3fec8bf14b70048abed750b45dde9e736f62c5 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 13 May 2026 15:17:54 -0600 Subject: [PATCH 64/80] Fix async (Quart) route registration --- dash/mcp/_server.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 5e708da25c..ac8b58de5a 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -7,6 +7,7 @@ from __future__ import annotations import json +import inspect import logging import uuid from typing import TYPE_CHECKING, Any @@ -108,6 +109,14 @@ def _json_response(*messages: dict): return resp def _handle_post() -> Any: + adapter = app.backend.request_adapter() + return _handle_mcp_request(adapter.get_json()) + + async def _handle_post_async() -> Any: + adapter = app.backend.request_adapter() + return _handle_mcp_request(await adapter.get_json()) + + def _handle_mcp_request(data) -> Any: adapter = app.backend.request_adapter() content_type = adapter.headers.get("Content-Type", "") if "application/json" not in content_type: @@ -117,7 +126,6 @@ def _handle_post() -> Any: status=415, ) - data = adapter.get_json() if data is None: return app.backend.make_response( json.dumps({"error": "Invalid JSON"}), @@ -165,10 +173,16 @@ def _handle_not_allowed(): # Separate registrations per HTTP method so the handler never needs to # inspect the request method. Distinct endpoint names are required by # Flask / Werkzeug when the same URL rule is registered more than once. - + if inspect.iscoroutinefunction(app.backend.request_adapter.get_json): + post_handler = _handle_post_async + else: + post_handler = _handle_post mcp_url = app.config.routes_pathname_prefix + mcp_path app.backend.add_url_rule( - mcp_url, view_func=_handle_post, endpoint=f"{mcp_url}:POST", methods=["POST"] + mcp_url, + view_func=post_handler, + endpoint=f"{mcp_url}:POST", + methods=["POST"], ) app.backend.add_url_rule( mcp_url, From 5e4775536e4bf7eddb794a2f672afbcf8e98f002 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 13 May 2026 17:07:51 -0600 Subject: [PATCH 65/80] Fix session management with multiple workers --- dash/mcp/_server.py | 27 ++++++++++++++++++--------- tests/unit/mcp/test_mcp_session.py | 8 -------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index ac8b58de5a..9548342348 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -6,10 +6,12 @@ from __future__ import annotations -import json +import hashlib import inspect +import json import logging -import uuid +import os +import threading from typing import TYPE_CHECKING, Any from mcp.types import ( @@ -51,13 +53,22 @@ def enable_mcp_server(app: Dash, mcp_path: str) -> None: app.mcp_decorated_functions = dict(MCP_DECORATED_FUNCTIONS) MCP_DECORATED_FUNCTIONS.clear() - _session_id: str | None = None - def _get_or_create_session_id() -> str: - """Read the hot-reload hash or generate a stable fallback.""" + """ + Creates a shared session ID shared across all clients. The session is + used to notify clients of app restarts so they can refresh their view + of the app. + When hot-reloading is enabled, the reload_hash is used + Otherwise, the parent PID is used because it is a stable identifier + across different worker processes. + """ # pylint: disable=protected-access reload_hash = app._hot_reload.hash - return reload_hash if reload_hash is not None else uuid.uuid4().hex + if reload_hash is not None: + return reload_hash + return hashlib.sha256(f"dash-mcp-{os.getppid()}".encode()).hexdigest()[:32] + + _session_id: str = _get_or_create_session_id() def _is_session_stale(client_session_id: str | None) -> bool: """True when the client's session doesn't match or the hash changed.""" @@ -84,9 +95,7 @@ def _check_session(method: str) -> bool: _session_id = _get_or_create_session_id() return False client_session_id = adapter.headers.get("Mcp-Session-Id") - if _session_id is not None and not client_session_id: - raise ValueError("Missing Mcp-Session-Id header") - if _is_session_stale(client_session_id): + if client_session_id and _is_session_stale(client_session_id): _session_id = _get_or_create_session_id() logger.debug("MCP session recovered: %s", _session_id) return True diff --git a/tests/unit/mcp/test_mcp_session.py b/tests/unit/mcp/test_mcp_session.py index 029fd284ad..0c878eb3da 100644 --- a/tests/unit/mcp/test_mcp_session.py +++ b/tests/unit/mcp/test_mcp_session.py @@ -40,14 +40,6 @@ def test_mcpse001_initialize_returns_session_id(): assert len(session_id) > 0 -def test_mcpse002_request_without_session_after_init_returns_400(): - app = _make_mcp_app() - with app.server.test_client() as client: - _post(client, "initialize") - resp = _post(client, "tools/list") - assert resp.status_code == 400 - - def test_mcpse003_request_with_valid_session_succeeds(): app = _make_mcp_app() with app.server.test_client() as client: From 9566de02baaec62c19a72a7a2c5f479bb645b81e Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 13 May 2026 17:31:20 -0600 Subject: [PATCH 66/80] Fix CI errors after rebase --- dash/backends/ws.py | 5 +++-- dash/mcp/_server.py | 1 - tests/integration/mcp/conftest.py | 10 ++++++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/dash/backends/ws.py b/dash/backends/ws.py index 041241823e..3e9739e16c 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -20,6 +20,7 @@ import janus from dash.exceptions import PreventUpdate, WebsocketDisconnected +from dash.types import CallbackExecutionBody from dash._utils import to_json if TYPE_CHECKING: @@ -149,7 +150,7 @@ async def get_prop( def create_ws_context( - payload: dict, + payload: CallbackExecutionBody, response_adapter: "ResponseAdapter", websocket_callback: DashWebsocketCallback, ): @@ -276,7 +277,7 @@ def on_done(f: concurrent.futures.Future) -> None: def run_callback_in_executor( executor: ThreadPoolExecutor, dash_app: "dash.Dash", - payload: dict, + payload: CallbackExecutionBody, ws_callback: DashWebsocketCallback, response_adapter: "ResponseAdapter", ) -> concurrent.futures.Future: diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 9548342348..aad2ce4bac 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -11,7 +11,6 @@ import json import logging import os -import threading from typing import TYPE_CHECKING, Any from mcp.types import ( diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py index 5030211ed8..aad3f1b1b5 100644 --- a/tests/integration/mcp/conftest.py +++ b/tests/integration/mcp/conftest.py @@ -5,6 +5,8 @@ import pytest import requests +from dash import _get_app + collect_ignore_glob = [] if sys.version_info < (3, 10): collect_ignore_glob.append("*") @@ -16,6 +18,14 @@ def _enable_mcp_for_integration_tests(monkeypatch): monkeypatch.setenv("DASH_MCP_ENABLED", "true") +@pytest.fixture(autouse=True) +def _reset_dash_app_state(): + """Reset Dash module-level state after each MCP test.""" + yield + _get_app.APP = None + _get_app.app_context.set(None) + + def _mcp_post(server_url, method, params=None, request_id=1): return requests.post( f"{server_url}/_mcp", From 296b7ac8a89c035d9747bc186cd83f0f7a6d1bf5 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 14 May 2026 09:28:47 -0600 Subject: [PATCH 67/80] code review feedback --- dash/mcp/primitives/resources/resource_components.py | 2 +- dash/mcp/primitives/tools/callback_adapter.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/dash/mcp/primitives/resources/resource_components.py b/dash/mcp/primitives/resources/resource_components.py index 1f80c8bda2..e14cf80745 100644 --- a/dash/mcp/primitives/resources/resource_components.py +++ b/dash/mcp/primitives/resources/resource_components.py @@ -42,7 +42,7 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: components = sorted( [ { - "id": str(getattr(comp, "id", None)), + "id": str(getattr(comp, "id")), "type": getattr(comp, "_type", type(comp).__name__), } for comp, _ in traverse(layout) diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index 8130c6a8a1..9af56bb879 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -363,7 +363,8 @@ def _param_annotations(self) -> list[Any | None]: def _expand_dep(dep: CallbackDependency, value: Any) -> CallbackInputs: - """Attach a concrete value to a callback dependency to produce a valid callback input. + """ + Attach a concrete value to a callback dependency to produce a valid callback input. For regular deps, returns ``{id, property, value}``. For ALL/ALLSMALLER: passes through the list of ``{id, property, value}`` dicts. From 15f2111c03b167287ff06fd1a7e19dc1a837b349 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 8 Apr 2026 17:06:29 -0600 Subject: [PATCH 68/80] Implement background callback support --- dash/_callback.py | 5 + dash/mcp/_server.py | 8 +- dash/mcp/primitives/tools/__init__.py | 16 +- dash/mcp/primitives/tools/base.py | 6 +- .../primitives/tools/descriptions/__init__.py | 2 + .../description_background_callbacks.py | 31 ++ dash/mcp/primitives/tools/results/__init__.py | 27 +- .../primitives/tools/tool_background_tasks.py | 105 ++++++ .../tools/tool_decorated_mcp_functions.py | 4 +- .../tools/tool_get_dash_component.py | 7 +- dash/mcp/primitives/tools/tools_callbacks.py | 21 +- dash/mcp/tasks/__init__.py | 5 + dash/mcp/tasks/tasks.py | 147 +++++++++ requirements/install.txt | 2 +- .../mcp/test_background_callbacks.py | 133 ++++++++ .../mcp/tools/test_background_callbacks.py | 300 ++++++++++++++++++ 16 files changed, 805 insertions(+), 14 deletions(-) create mode 100644 dash/mcp/primitives/tools/descriptions/description_background_callbacks.py create mode 100644 dash/mcp/primitives/tools/tool_background_tasks.py create mode 100644 dash/mcp/tasks/__init__.py create mode 100644 dash/mcp/tasks/tasks.py create mode 100644 tests/integration/mcp/test_background_callbacks.py create mode 100644 tests/unit/mcp/tools/test_background_callbacks.py diff --git a/dash/_callback.py b/dash/_callback.py index 637a332905..c77419e068 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,6 +1,7 @@ import collections import hashlib import inspect +from datetime import datetime, timezone from functools import wraps from typing import Callable, Optional, Any, List, Tuple, Union, Dict, TypeVar, cast @@ -445,6 +446,10 @@ def _setup_background_callback( ctx_value, ) + callback_manager.handle.set( + f"{cache_key}-created_at", datetime.now(timezone.utc).isoformat() + ) + data = { "cacheKey": cache_key, "job": job, diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index aad2ce4bac..07b0520bb9 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -34,6 +34,7 @@ list_tools, read_resource, ) +from dash.mcp.tasks import get_task, get_task_result, cancel_task from dash.mcp.primitives.tools.callback_adapter_collection import ( CallbackAdapterCollection, ) @@ -251,11 +252,16 @@ def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: "initialize": _handle_initialize, "tools/list": list_tools, "tools/call": lambda: call_tool( - params.get("name", ""), params.get("arguments", {}) + tool_name=params.get("name", ""), + arguments=params.get("arguments", {}), + task=params.get("task"), ), "resources/list": list_resources, "resources/templates/list": list_resource_templates, "resources/read": lambda: read_resource(params.get("uri", "")), + "tasks/get": lambda: get_task(task_id=params.get("taskId", "")), + "tasks/result": lambda: get_task_result(task_id=params.get("taskId", "")), + "tasks/cancel": lambda: cancel_task(task_id=params.get("taskId", "")), } try: diff --git a/dash/mcp/primitives/tools/__init__.py b/dash/mcp/primitives/tools/__init__.py index b8f12a1dbd..eea7af43c1 100644 --- a/dash/mcp/primitives/tools/__init__.py +++ b/dash/mcp/primitives/tools/__init__.py @@ -4,17 +4,19 @@ from typing import Any -from mcp.types import CallToolResult, ListToolsResult +from mcp.types import CallToolResult, CreateTaskResult, ListToolsResult from dash.mcp.types import ToolNotFoundError from .base import MCPToolProvider +from .tool_background_tasks import BackgroundTaskTools from .tool_decorated_mcp_functions import DecoratedFunctionTools from .tool_get_dash_component import GetDashComponentTool from .tools_callbacks import CallbackTools _TOOL_PROVIDERS: list[type[MCPToolProvider]] = [ CallbackTools, + BackgroundTaskTools, GetDashComponentTool, DecoratedFunctionTools, ] @@ -28,11 +30,17 @@ def list_tools() -> ListToolsResult: return ListToolsResult(tools=tools) -def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: - """Route a tools/call request by tool name.""" +def call_tool( + tool_name: str, arguments: dict[str, Any], task: dict | None = None +) -> CallToolResult | CreateTaskResult: + """Route a tools/call request by tool name. + + The optional ``task`` parameter (per MCP Tasks protocol) is passed + through to providers that support background callbacks. + """ for provider in _TOOL_PROVIDERS: if tool_name in provider.get_tool_names(): - return provider.call_tool(tool_name, arguments) + return provider.call_tool(tool_name, arguments, task=task) raise ToolNotFoundError( f"Tool not found: {tool_name}." " The app's callbacks may have changed." diff --git a/dash/mcp/primitives/tools/base.py b/dash/mcp/primitives/tools/base.py index 60fa7374d6..f7a5c54aac 100644 --- a/dash/mcp/primitives/tools/base.py +++ b/dash/mcp/primitives/tools/base.py @@ -4,7 +4,7 @@ from typing import Any -from mcp.types import CallToolResult, Tool +from mcp.types import CallToolResult, CreateTaskResult, Tool class MCPToolProvider: @@ -24,5 +24,7 @@ def list_tools(cls) -> list[Tool]: raise NotImplementedError @classmethod - def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + def call_tool( + cls, tool_name: str, arguments: dict[str, Any], task: dict | None = None + ) -> CallToolResult | CreateTaskResult: raise NotImplementedError diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py index b32238992c..a4227868d8 100644 --- a/dash/mcp/primitives/tools/descriptions/__init__.py +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING from .base import ToolDescriptionSource +from .description_background_callbacks import BackgroundCallbackDescription from .description_docstring import DocstringDescription from .description_outputs import OutputSummaryDescription @@ -22,6 +23,7 @@ _SOURCES: list[type[ToolDescriptionSource]] = [ OutputSummaryDescription, DocstringDescription, + BackgroundCallbackDescription, ] diff --git a/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py b/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py new file mode 100644 index 0000000000..3823e06e0b --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py @@ -0,0 +1,31 @@ +"""Description for background (long-running) callbacks. + +Informs the LLM that the tool returns a taskId immediately +and must be polled via the background task result tool. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .base import ToolDescriptionSource + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + + +class BackgroundCallbackDescription(ToolDescriptionSource): + """Add async polling instructions for background callbacks.""" + + @classmethod + def describe(cls, callback: CallbackAdapter) -> list[str]: + if not callback._cb_info.get("background"): + return [] + from ..tool_background_tasks import GET_RESULT_TOOL_NAME + + return [ + "", + "This is a long-running background operation. " + "It returns a taskId immediately. " + f"Call tool `{GET_RESULT_TOOL_NAME}` with the taskId to poll for the result.", + ] diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py index 09e86410a7..10f364959b 100644 --- a/dash/mcp/primitives/tools/results/__init__.py +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -9,7 +9,7 @@ import json from typing import Any -from mcp.types import CallToolResult, TextContent +from mcp.types import CallToolResult, CreateTaskResult, TextContent from dash.types import CallbackExecutionResponse from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter @@ -50,3 +50,28 @@ def format_callback_response( content=content, structuredContent=dict(response), ) + + +def task_result_to_tool_result(create_task_result: CreateTaskResult) -> CallToolResult: + """Wrap a CreateTaskResult as a CallToolResult with polling instructions. + + MCP Tasks are not yet supported by LLM clients, so this converts the + task metadata into a tool response that guides the LLM to poll via + the get_background_task_result tool. + """ + task = create_task_result.task + return CallToolResult( + content=[TextContent( + type="text", + text=json.dumps({ + "taskId": task.taskId, + "status": task.status, + "pollInterval": task.pollInterval, + "message": ( + "This is a long-running background callback. " + "Call the get_background_task_result tool with this taskId " + "to poll for the result." + ), + }), + )], + ) diff --git a/dash/mcp/primitives/tools/tool_background_tasks.py b/dash/mcp/primitives/tools/tool_background_tasks.py new file mode 100644 index 0000000000..37f76a9d4d --- /dev/null +++ b/dash/mcp/primitives/tools/tool_background_tasks.py @@ -0,0 +1,105 @@ +"""Built-in tools for background callback task lifecycle. + +Thin wrappers around the spec-aligned core in dash.mcp.tasks. +Only registered when the app has background callbacks. +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import CallToolResult, TextContent, Tool + +from dash import get_app +from dash.mcp.tasks import get_task, get_task_result, cancel_task + +from .base import MCPToolProvider + + +GET_RESULT_TOOL_NAME = "get_background_task_result" +CANCEL_TOOL_NAME = "cancel_background_task" + + +def _has_background_callbacks() -> bool: + return any( + cb_info.get("background") + for cb_info in get_app().callback_map.values() + ) + + +class BackgroundTaskTools(MCPToolProvider): + """Built-in tools for polling and cancelling background callback tasks. + + Only registered when the app has background callbacks. + """ + + @classmethod + def get_tool_names(cls) -> set[str]: + if not _has_background_callbacks(): + return set() + return {GET_RESULT_TOOL_NAME, CANCEL_TOOL_NAME} + + @classmethod + def list_tools(cls) -> list[Tool]: + if not _has_background_callbacks(): + return [] + return [ + Tool( + name=GET_RESULT_TOOL_NAME, + description=( + "Poll for the result of a long-running background callback. " + "Pass the taskId returned by the original tool call. " + "If the task is still running, call this tool again. " + "If complete, returns the callback result." + ), + inputSchema={ + "type": "object", + "properties": { + "taskId": { + "type": "string", + "description": "The taskId returned by the background callback tool.", + }, + }, + "required": ["taskId"], + }, + ), + Tool( + name=CANCEL_TOOL_NAME, + description="Cancel a running background callback.", + inputSchema={ + "type": "object", + "properties": { + "taskId": { + "type": "string", + "description": "The taskId of the background task to cancel.", + }, + }, + "required": ["taskId"], + }, + ), + ] + + @classmethod + def call_tool( + cls, + tool_name: str, + arguments: dict[str, Any], + task: dict | None = None, + ) -> CallToolResult: + task_id = arguments.get("taskId", "") + + if tool_name == GET_RESULT_TOOL_NAME: + task_status = get_task(task_id) + if task_status.status == "completed": + return get_task_result(task_id) + return CallToolResult( + content=[TextContent(type="text", text=task_status.model_dump_json())], + ) + + if tool_name == CANCEL_TOOL_NAME: + result = cancel_task(task_id) + return CallToolResult( + content=[TextContent(type="text", text=result.model_dump_json())], + ) + + raise ValueError(f"Unknown tool: {tool_name}") diff --git a/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py b/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py index c135455c88..0b3edbbcbe 100644 --- a/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py +++ b/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py @@ -125,7 +125,9 @@ def list_tools(cls) -> list[Tool]: return [_build_tool(name, reg) for name, reg in cls._registry().items()] @classmethod - def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + def call_tool( + cls, tool_name: str, arguments: dict[str, Any], task: dict | None = None + ) -> CallToolResult: reg = cls._registry().get(tool_name) if reg is None: return CallToolResult( diff --git a/dash/mcp/primitives/tools/tool_get_dash_component.py b/dash/mcp/primitives/tools/tool_get_dash_component.py index 8c131c4288..d8ba9b6e49 100644 --- a/dash/mcp/primitives/tools/tool_get_dash_component.py +++ b/dash/mcp/primitives/tools/tool_get_dash_component.py @@ -58,7 +58,12 @@ def list_tools(cls) -> list[Tool]: ] @classmethod - def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + def call_tool( + cls, + tool_name: str, + arguments: dict[str, Any], + task: dict | None = None, + ) -> CallToolResult: comp_id = arguments.get("component_id", "") if not comp_id: return CallToolResult( diff --git a/dash/mcp/primitives/tools/tools_callbacks.py b/dash/mcp/primitives/tools/tools_callbacks.py index 716b777326..990f90e285 100644 --- a/dash/mcp/primitives/tools/tools_callbacks.py +++ b/dash/mcp/primitives/tools/tools_callbacks.py @@ -7,14 +7,15 @@ from typing import Any -from mcp.types import CallToolResult, TextContent, Tool +from mcp.types import CallToolResult, CreateTaskResult, TextContent, Tool from dash import get_app +from dash.mcp.tasks import create_task from dash.mcp.types import CallbackExecutionError, ToolNotFoundError from .base import MCPToolProvider from .callback_utils import run_callback -from .results import format_callback_response +from .results import format_callback_response, task_result_to_tool_result class CallbackTools(MCPToolProvider): @@ -30,7 +31,12 @@ def list_tools(cls) -> list[Tool]: return get_app().mcp_callback_map.as_mcp_tools() @classmethod - def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + def call_tool( + cls, + tool_name: str, + arguments: dict[str, Any], + task: dict | None = None, + ) -> CallToolResult | CreateTaskResult: """Execute a callback tool by name.""" callback_map = get_app().mcp_callback_map cb = callback_map.find_by_tool_name(tool_name) @@ -41,6 +47,8 @@ def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: " Please call tools/list to refresh your tool list." ) + is_background = bool(cb._cb_info.get("background")) + try: callback_response = run_callback(cb, arguments) except CallbackExecutionError as e: @@ -48,4 +56,11 @@ def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: content=[TextContent(type="text", text=str(e))], isError=True, ) + + if is_background: + task_result = create_task(callback_response, cb) + if task is not None: + return task_result + return task_result_to_tool_result(task_result) + return format_callback_response(callback_response, cb) diff --git a/dash/mcp/tasks/__init__.py b/dash/mcp/tasks/__init__.py new file mode 100644 index 0000000000..8b78741d60 --- /dev/null +++ b/dash/mcp/tasks/__init__.py @@ -0,0 +1,5 @@ +"""MCP Tasks — lifecycle management for background callback execution.""" + +from .tasks import create_task, get_task, get_task_result, cancel_task + +__all__ = ["create_task", "get_task", "get_task_result", "cancel_task"] diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py new file mode 100644 index 0000000000..6217fc7b1e --- /dev/null +++ b/dash/mcp/tasks/tasks.py @@ -0,0 +1,147 @@ +"""Handler functions for MCP tasks/* methods.""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from typing import Any + +from mcp.types import CreateTaskResult, GetTaskResult, Task + +from dash import get_app +from dash.mcp.primitives.tools.results import format_callback_response +from dash.mcp.types import MCPError + + +def parse_task_id(task_id: str) -> tuple[str, str, str]: + """Parse a taskId into (tool_name, job_id, cache_key).""" + return task_id.split(":", 2) + + +def _get_callback_manager(): + """Get the background callback manager from the app's callback_map.""" + app = get_app() + for cb_info in app.callback_map.values(): + manager = cb_info.get("manager") + if manager is not None: + return manager + return None + + +def create_task(dispatch_response: dict[str, Any], callback) -> CreateTaskResult: + """Create a Task from a background callback's initial dispatch response.""" + cache_key = dispatch_response["cacheKey"] + job_id = str(dispatch_response["job"]) + task_id = f"{callback.tool_name}:{job_id}:{cache_key}" + interval = callback._cb_info.get("background", {}).get("interval", 1000) + now = datetime.now(timezone.utc) + return CreateTaskResult( + task=Task( + taskId=task_id, + status="working", + createdAt=now, + lastUpdatedAt=now, + ttl=None, + pollInterval=interval, + ), + ) + + +def get_task(task_id: str) -> GetTaskResult: + """Handle tasks/get — derive status from the callback manager.""" + tool_name, job_id, cache_key = parse_task_id(task_id) + + manager = _get_callback_manager() + if manager is None: + return GetTaskResult( + taskId=task_id, + status="failed", + statusMessage="No background callback manager configured.", + createdAt=datetime.now(timezone.utc), + lastUpdatedAt=datetime.now(timezone.utc), + ttl=None, + ) + + running = manager.job_running(job_id) + progress = manager.get_progress(cache_key) + + if running: + status = "working" + elif manager.result_ready(cache_key): + status = "completed" + else: + status = "failed" + + adapter = get_app().mcp_callback_map.find_by_tool_name(tool_name) + interval = None + if adapter is not None: + interval = adapter._cb_info.get("background", {}).get("interval", 1000) + + now = datetime.now(timezone.utc) + return GetTaskResult( + taskId=task_id, + status=status, + statusMessage=str(progress) if progress else None, + createdAt=datetime.fromisoformat(manager.handle.get(f"{cache_key}-created_at") or now.isoformat()), + lastUpdatedAt=now, + ttl=manager.expire * 1000 if manager.expire else None, + pollInterval=interval, + ) + + +def get_task_result(task_id: str) -> Any: + """Handle tasks/result — retrieve and format the callback result. + + Mirrors the Dash renderer: calls get_result() which clears from cache. + """ + tool_name, job_id, cache_key = parse_task_id(task_id) + + manager = _get_callback_manager() + if manager is None: + raise MCPError("No background callback manager configured.") + + # Mirror the renderer: dispatch with cacheKey/job query params. + # The framework handles result retrieval, wrapping, and cleanup. + adapter = get_app().mcp_callback_map.find_by_tool_name(tool_name) + body = adapter.as_callback_body({}) + app = get_app() + + with app.server.test_request_context( + f"/_dash-update-component?cacheKey={cache_key}&job={job_id}", + method="POST", + data=json.dumps(body, default=str), + content_type="application/json", + ): + response = app.dispatch() + + response_data = json.loads(response.get_data(as_text=True)) + + if "response" not in response_data: + raise MCPError("Task result not ready. Poll tasks/get until status is 'completed'.") + + return format_callback_response(response_data, adapter) + + +def cancel_task(task_id: str) -> Any: + """Handle tasks/cancel — terminate the background job. + + Same underlying mechanism as the renderer's cancelJob query param. + """ + from mcp.types import CancelTaskResult + + tool_name, job_id, cache_key = parse_task_id(task_id) + + manager = _get_callback_manager() + if manager is None: + raise MCPError("No background callback manager configured.") + + manager.terminate_job(job_id) + + now = datetime.now(timezone.utc) + return CancelTaskResult( + taskId=task_id, + status="cancelled", + createdAt=datetime.fromisoformat(manager.handle.get(f"{cache_key}-created_at") or now.isoformat()), + lastUpdatedAt=now, + ttl=manager.expire * 1000 if manager.expire else None, + ) diff --git a/requirements/install.txt b/requirements/install.txt index a976ab9010..1dedc8662c 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -9,4 +9,4 @@ nest-asyncio setuptools janus>=1.0.0 pydantic>=2.10 -mcp>=1.0.0; python_version>="3.10" +mcp>=1.23.0; python_version>="3.10" diff --git a/tests/integration/mcp/test_background_callbacks.py b/tests/integration/mcp/test_background_callbacks.py new file mode 100644 index 0000000000..75ae009aeb --- /dev/null +++ b/tests/integration/mcp/test_background_callbacks.py @@ -0,0 +1,133 @@ +"""Integration tests for background callback support via MCP.""" + +import json +import time + +import diskcache +from dash import Dash, Input, Output, html +from dash.background_callback.managers.diskcache_manager import DiskcacheManager + +MCP_PATH = "_mcp" + + +def _make_background_app(): + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def slow_callback(value): + time.sleep(0.5) + return f"done: {value}" + + return app + + +def _post(client, method, params=None, session_id=None, request_id=1): + headers = {"Content-Type": "application/json"} + if session_id: + headers["mcp-session-id"] = session_id + return client.post( + f"/{MCP_PATH}", + data=json.dumps( + { + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params or {}, + } + ), + headers=headers, + ) + + +def _init_session(client): + r = _post(client, "initialize") + return r.headers["mcp-session-id"] + + +class TestBackgroundCallbackLifecycle: + """Full lifecycle: trigger → poll → get result, over HTTP.""" + + def test_trigger_poll_and_retrieve(self): + app = _make_background_app() + client = app.server.test_client() + sid = _init_session(client) + + # Trigger + r = _post( + client, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + }, + session_id=sid, + ) + assert r.status_code == 200 + data = json.loads(r.data) + task_info = json.loads(data["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + assert task_info["status"] == "working" + + # Poll — should be working initially + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + session_id=sid, + request_id=2, + ) + assert r.status_code == 200 + + # Wait for completion + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + 5 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + # Get result + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + session_id=sid, + request_id=3, + ) + assert r.status_code == 200 + data = json.loads(r.data) + text = data["result"]["content"][0]["text"] + assert "done:" in text + + def test_background_tools_in_tools_list(self): + app = _make_background_app() + client = app.server.test_client() + sid = _init_session(client) + + r = _post(client, "tools/list", session_id=sid) + data = json.loads(r.data) + names = [t["name"] for t in data["result"]["tools"]] + assert "get_background_task_result" in names + assert "cancel_background_task" in names + assert "slow_callback" in names diff --git a/tests/unit/mcp/tools/test_background_callbacks.py b/tests/unit/mcp/tools/test_background_callbacks.py new file mode 100644 index 0000000000..8dc99164f5 --- /dev/null +++ b/tests/unit/mcp/tools/test_background_callbacks.py @@ -0,0 +1,300 @@ +"""Tests for background callback support via MCP Tasks API.""" + +import time + +from dash import Dash, Input, Output, html +from dash._get_app import app_context +from dash.background_callback.managers.diskcache_manager import DiskcacheManager +from dash.mcp._server import _process_mcp_message +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + + +def _setup_mcp(app): + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _make_background_app(): + import diskcache + + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def slow_callback(value): + """A background callback.""" + time.sleep(0.3) + return f"done: {value}" + + return app + + +class TestCancelBackgroundTaskTool: + """cancel_background_task tool wrapper.""" + + def test_cancel_via_tool(self): + import json + + app = _make_background_app() + trigger = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + }, + ) + task_id = json.loads(trigger["result"]["content"][0]["text"])["taskId"] + + cancel = _mcp( + app, + "tools/call", + { + "name": "cancel_background_task", + "arguments": {"taskId": task_id}, + }, + ) + assert cancel["result"].get("isError") is not True + + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + assert not manager.job_running(job_id) + + +class TestBackgroundToolRegistration: + """Background task tools only appear when app has background callbacks.""" + + def test_present_with_background_callbacks(self): + app = _make_background_app() + tools = _mcp(app, "tools/list")["result"]["tools"] + names = [t["name"] for t in tools] + assert "get_background_task_result" in names + assert "cancel_background_task" in names + + def test_absent_without_background_callbacks(self): + app = Dash(__name__) + app.layout = html.Div([html.Div(id="in"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("in", "children")) + def normal_cb(v): + return v + + tools = _mcp(app, "tools/list")["result"]["tools"] + names = [t["name"] for t in tools] + assert "get_background_task_result" not in names + assert "cancel_background_task" not in names + + +class TestGetBackgroundTaskResult: + """get_background_task_result tool: poll and retrieve results.""" + + def _trigger(self, app): + import json + + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + }, + ) + return json.loads(result["result"]["content"][0]["text"])["taskId"] + + def test_returns_working_while_running(self): + app = _make_background_app() + task_id = self._trigger(app) + poll = _mcp( + app, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + ) + text = poll["result"]["content"][0]["text"] + assert "working" in text.lower() + + def test_returns_result_when_complete(self): + app = _make_background_app() + task_id = self._trigger(app) + _, job_id, _ = task_id.split(":", 2) + + # Wait for completion + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + 3 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + result = _mcp( + app, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + ) + text = result["result"]["content"][0]["text"] + assert "done:" in text + + +class TestBackgroundCallbackTrigger: + """Calling a background callback tool returns taskId immediately.""" + + def test_returns_task_id(self): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + }, + ) + text = result["result"]["content"][0]["text"] + assert "taskId" in text + assert "slow_callback:" in text + + +class TestTasksGet: + """tasks/get derives status from the callback manager.""" + + def test_working_status_while_running(self): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + + # Immediately poll — job should still be running + get_result = _mcp(app, "tasks/get", {"taskId": task_id}) + assert get_result["result"]["status"] == "working" + assert get_result["result"]["taskId"] == task_id + + +class TestTasksResult: + """tasks/result retrieves and formats the callback result.""" + + def test_returns_formatted_result_when_complete(self): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + tool_name, job_id, cache_key = task_id.split(":", 2) + + # Wait for the background job to finish + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + 3 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + # Fetch the result + result = _mcp(app, "tasks/result", {"taskId": task_id}) + assert "content" in result["result"] + text = result["result"]["content"][0]["text"] + assert "done:" in text + + +class TestTasksCancel: + """tasks/cancel terminates the background job.""" + + def test_cancel_terminates_job(self): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + + cancel_result = _mcp(app, "tasks/cancel", {"taskId": task_id}) + assert "error" not in cancel_result + + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + assert not manager.job_running(job_id) + + +class TestBackgroundCallbackWithTask: + """When tools/call includes task metadata, return CreateTaskResult.""" + + def test_returns_create_task_result(self): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task = result["result"]["task"] + assert task["status"] == "working" + assert "taskId" in task + assert "pollInterval" in task + + def test_task_id_encodes_tool_name_job_id_cache_key(self): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = result["result"]["task"]["taskId"] + tool_name, job_id, cache_key = task_id.split(":", 2) + assert tool_name == "slow_callback" + assert len(cache_key) == 64 # SHA256 hex From ee4e48c822fcae0f0edaa7b844ab7af91b5a29b3 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 8 Apr 2026 17:43:51 -0600 Subject: [PATCH 69/80] Ensure that background callback expiry is communicated in MCP tool --- dash/_callback.py | 4 +- dash/mcp/tasks/tasks.py | 5 +- .../mcp/test_background_callbacks.py | 133 -------- .../mcp/test_mcp_background_tasks.py | 298 +++++++++++++++++ .../mcp/tools/test_background_callbacks.py | 300 ------------------ .../tools/test_mcp_background_callbacks.py | 277 ++++++++++++++++ 6 files changed, 582 insertions(+), 435 deletions(-) delete mode 100644 tests/integration/mcp/test_background_callbacks.py create mode 100644 tests/integration/mcp/test_mcp_background_tasks.py delete mode 100644 tests/unit/mcp/tools/test_background_callbacks.py create mode 100644 tests/unit/mcp/tools/test_mcp_background_callbacks.py diff --git a/dash/_callback.py b/dash/_callback.py index c77419e068..c7e4cb4102 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -447,7 +447,9 @@ def _setup_background_callback( ) callback_manager.handle.set( - f"{cache_key}-created_at", datetime.now(timezone.utc).isoformat() + f"{cache_key}-created_at", + datetime.now(timezone.utc).isoformat(), + expire=callback_manager.expire, ) data = { diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py index 6217fc7b1e..877b84795b 100644 --- a/dash/mcp/tasks/tasks.py +++ b/dash/mcp/tasks/tasks.py @@ -138,10 +138,13 @@ def cancel_task(task_id: str) -> Any: manager.terminate_job(job_id) now = datetime.now(timezone.utc) + created_at = manager.handle.get(f"{cache_key}-created_at") + manager.handle.delete(f"{cache_key}-created_at") + return CancelTaskResult( taskId=task_id, status="cancelled", - createdAt=datetime.fromisoformat(manager.handle.get(f"{cache_key}-created_at") or now.isoformat()), + createdAt=datetime.fromisoformat(created_at) if created_at else now, lastUpdatedAt=now, ttl=manager.expire * 1000 if manager.expire else None, ) diff --git a/tests/integration/mcp/test_background_callbacks.py b/tests/integration/mcp/test_background_callbacks.py deleted file mode 100644 index 75ae009aeb..0000000000 --- a/tests/integration/mcp/test_background_callbacks.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Integration tests for background callback support via MCP.""" - -import json -import time - -import diskcache -from dash import Dash, Input, Output, html -from dash.background_callback.managers.diskcache_manager import DiskcacheManager - -MCP_PATH = "_mcp" - - -def _make_background_app(): - cache = diskcache.Cache() - manager = DiskcacheManager(cache) - - app = Dash(__name__) - app.layout = html.Div( - [ - html.Div(id="input"), - html.Div(id="output"), - ] - ) - - @app.callback( - Output("output", "children"), - Input("input", "children"), - background=True, - manager=manager, - ) - def slow_callback(value): - time.sleep(0.5) - return f"done: {value}" - - return app - - -def _post(client, method, params=None, session_id=None, request_id=1): - headers = {"Content-Type": "application/json"} - if session_id: - headers["mcp-session-id"] = session_id - return client.post( - f"/{MCP_PATH}", - data=json.dumps( - { - "jsonrpc": "2.0", - "method": method, - "id": request_id, - "params": params or {}, - } - ), - headers=headers, - ) - - -def _init_session(client): - r = _post(client, "initialize") - return r.headers["mcp-session-id"] - - -class TestBackgroundCallbackLifecycle: - """Full lifecycle: trigger → poll → get result, over HTTP.""" - - def test_trigger_poll_and_retrieve(self): - app = _make_background_app() - client = app.server.test_client() - sid = _init_session(client) - - # Trigger - r = _post( - client, - "tools/call", - { - "name": "slow_callback", - "arguments": {"value": "hello"}, - }, - session_id=sid, - ) - assert r.status_code == 200 - data = json.loads(r.data) - task_info = json.loads(data["result"]["content"][0]["text"]) - task_id = task_info["taskId"] - assert task_info["status"] == "working" - - # Poll — should be working initially - r = _post( - client, - "tools/call", - { - "name": "get_background_task_result", - "arguments": {"taskId": task_id}, - }, - session_id=sid, - request_id=2, - ) - assert r.status_code == 200 - - # Wait for completion - _, job_id, _ = task_id.split(":", 2) - manager = app.callback_map["output.children"]["manager"] - deadline = time.time() + 5 - while time.time() < deadline: - if not manager.job_running(job_id): - break - time.sleep(0.1) - - # Get result - r = _post( - client, - "tools/call", - { - "name": "get_background_task_result", - "arguments": {"taskId": task_id}, - }, - session_id=sid, - request_id=3, - ) - assert r.status_code == 200 - data = json.loads(r.data) - text = data["result"]["content"][0]["text"] - assert "done:" in text - - def test_background_tools_in_tools_list(self): - app = _make_background_app() - client = app.server.test_client() - sid = _init_session(client) - - r = _post(client, "tools/list", session_id=sid) - data = json.loads(r.data) - names = [t["name"] for t in data["result"]["tools"]] - assert "get_background_task_result" in names - assert "cancel_background_task" in names - assert "slow_callback" in names diff --git a/tests/integration/mcp/test_mcp_background_tasks.py b/tests/integration/mcp/test_mcp_background_tasks.py new file mode 100644 index 0000000000..e3e0c0acdf --- /dev/null +++ b/tests/integration/mcp/test_mcp_background_tasks.py @@ -0,0 +1,298 @@ +"""Background callback support through the MCP HTTP endpoint. + +End-to-end flows: trigger a background callback, poll via +``get_background_task_result``, observe progress (``set_progress``), +confirm the cache-expiry behavior, and verify the background-only tools +appear in ``tools/list``. +""" + +import json +import re +import time +from datetime import datetime + +import diskcache +from dash import Dash, Input, Output, html +from dash.background_callback.managers.diskcache_manager import DiskcacheManager + +MCP_PATH = "_mcp" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_background_app(): + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def slow_callback(value): + time.sleep(0.5) + return f"done: {value}" + + return app + + +def _post(client, method, params=None, request_id=1): + return client.post( + f"/{MCP_PATH}", + data=json.dumps( + { + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params or {}, + } + ), + headers={"Content-Type": "application/json"}, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpbg012_trigger_poll_and_retrieve(): + app = _make_background_app() + client = app.server.test_client() + + # Trigger + r = _post( + client, + "tools/call", + {"name": "slow_callback", "arguments": {"value": "hello"}}, + ) + assert r.status_code == 200 + data = json.loads(r.data) + task_info = json.loads(data["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + assert task_info["status"] == "working" + + # Read createdAt from the callback manager directly + _, _, cache_key = task_id.split(":", 2) + stored_created_at = app.callback_map["output.children"]["manager"].handle.get( + f"{cache_key}-created_at" + ) + assert stored_created_at is not None + + # Poll — should be working, with createdAt matching the stored value + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=2, + ) + assert r.status_code == 200 + poll_data = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + assert datetime.fromisoformat(poll_data["createdAt"]) == datetime.fromisoformat( + stored_created_at + ) + + # Wait for completion + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + 5 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + # Get result + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=3, + ) + assert r.status_code == 200 + data = json.loads(r.data) + text = data["result"]["content"][0]["text"] + assert "done:" in text + + +def test_mcpbg013_result_expires(): + """Result and createdAt are available until the cache expires.""" + cache = diskcache.Cache() + manager = DiskcacheManager(cache, cache_by=[lambda: "fixed"], expire=2) + + app = Dash(__name__) + app.layout = html.Div([html.Div(id="input"), html.Div(id="output")]) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def fast_cb(value): + return f"done: {value}" + + client = app.server.test_client() + + # Trigger + r = _post( + client, + "tools/call", + {"name": "fast_cb", "arguments": {"value": "hi"}}, + ) + task_info = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + _, job_id, cache_key = task_id.split(":", 2) + + # Wait for job to finish + deadline = time.time() + 3 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + # First retrieval — result and createdAt available + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=2, + ) + text = json.loads(r.data)["result"]["content"][0]["text"] + assert "done:" in text + created_at = manager.handle.get(f"{cache_key}-created_at") + assert created_at is not None + + # Second retrieval — still available (cache_by keeps it) + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=3, + ) + text = json.loads(r.data)["result"]["content"][0]["text"] + assert "done:" in text + assert manager.handle.get(f"{cache_key}-created_at") == created_at + + # Wait for expiry + time.sleep(2.5) + + # After expiry — tool reports failure, createdAt is fresh (stored value gone) + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=4, + ) + poll_data = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + assert poll_data["status"] == "failed" + assert datetime.fromisoformat(poll_data["createdAt"]) > datetime.fromisoformat( + created_at + ) + + +def test_mcpbg014_progress_in_poll_response(): + """Progress reported via set_progress appears in poll statusMessage.""" + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="status"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + progress=Output("status", "children"), + background=True, + manager=manager, + interval=200, + ) + def progress_cb(set_progress, value): + for i in range(10): + set_progress(f"Step {i + 1} of 10") + time.sleep(0.2) + return f"done: {value}" + + client = app.server.test_client() + + # Trigger + r = _post( + client, + "tools/call", + {"name": "progress_cb", "arguments": {"value": "hi"}}, + ) + task_info = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + + # Poll and collect all progress messages + progress_pattern = re.compile(r"Step \d+ of 10") + progress_messages = [] + deadline = time.time() + 10 + while time.time() < deadline: + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=2, + ) + text = json.loads(r.data)["result"]["content"][0]["text"] + try: + poll_data = json.loads(text) + msg = poll_data.get("statusMessage") + if msg is not None: + progress_messages.append(msg) + if poll_data.get("status") == "completed": + break + except (json.JSONDecodeError, KeyError): + break + time.sleep(0.3) + + assert len(progress_messages) > 0, "Expected progress updates during polling" + for msg in progress_messages: + assert progress_pattern.search(msg), f"Unexpected progress format: {msg}" + + +def test_mcpbg015_background_tools_in_tools_list(): + app = _make_background_app() + client = app.server.test_client() + r = _post(client, "tools/list") + data = json.loads(r.data) + names = [t["name"] for t in data["result"]["tools"]] + assert "get_background_task_result" in names + assert "cancel_background_task" in names + assert "slow_callback" in names diff --git a/tests/unit/mcp/tools/test_background_callbacks.py b/tests/unit/mcp/tools/test_background_callbacks.py deleted file mode 100644 index 8dc99164f5..0000000000 --- a/tests/unit/mcp/tools/test_background_callbacks.py +++ /dev/null @@ -1,300 +0,0 @@ -"""Tests for background callback support via MCP Tasks API.""" - -import time - -from dash import Dash, Input, Output, html -from dash._get_app import app_context -from dash.background_callback.managers.diskcache_manager import DiskcacheManager -from dash.mcp._server import _process_mcp_message -from dash.mcp.primitives.tools.callback_adapter_collection import ( - CallbackAdapterCollection, -) - - -def _setup_mcp(app): - app_context.set(app) - app.mcp_callback_map = CallbackAdapterCollection(app) - return app - - -def _msg(method, params=None, request_id=1): - d = {"jsonrpc": "2.0", "method": method, "id": request_id} - d["params"] = params if params is not None else {} - return d - - -def _mcp(app, method, params=None, request_id=1): - with app.server.test_request_context(): - _setup_mcp(app) - return _process_mcp_message(_msg(method, params, request_id)) - - -def _make_background_app(): - import diskcache - - cache = diskcache.Cache() - manager = DiskcacheManager(cache) - - app = Dash(__name__) - app.layout = html.Div( - [ - html.Div(id="input"), - html.Div(id="output"), - ] - ) - - @app.callback( - Output("output", "children"), - Input("input", "children"), - background=True, - manager=manager, - ) - def slow_callback(value): - """A background callback.""" - time.sleep(0.3) - return f"done: {value}" - - return app - - -class TestCancelBackgroundTaskTool: - """cancel_background_task tool wrapper.""" - - def test_cancel_via_tool(self): - import json - - app = _make_background_app() - trigger = _mcp( - app, - "tools/call", - { - "name": "slow_callback", - "arguments": {"value": "hello"}, - }, - ) - task_id = json.loads(trigger["result"]["content"][0]["text"])["taskId"] - - cancel = _mcp( - app, - "tools/call", - { - "name": "cancel_background_task", - "arguments": {"taskId": task_id}, - }, - ) - assert cancel["result"].get("isError") is not True - - _, job_id, _ = task_id.split(":", 2) - manager = app.callback_map["output.children"]["manager"] - assert not manager.job_running(job_id) - - -class TestBackgroundToolRegistration: - """Background task tools only appear when app has background callbacks.""" - - def test_present_with_background_callbacks(self): - app = _make_background_app() - tools = _mcp(app, "tools/list")["result"]["tools"] - names = [t["name"] for t in tools] - assert "get_background_task_result" in names - assert "cancel_background_task" in names - - def test_absent_without_background_callbacks(self): - app = Dash(__name__) - app.layout = html.Div([html.Div(id="in"), html.Div(id="out")]) - - @app.callback(Output("out", "children"), Input("in", "children")) - def normal_cb(v): - return v - - tools = _mcp(app, "tools/list")["result"]["tools"] - names = [t["name"] for t in tools] - assert "get_background_task_result" not in names - assert "cancel_background_task" not in names - - -class TestGetBackgroundTaskResult: - """get_background_task_result tool: poll and retrieve results.""" - - def _trigger(self, app): - import json - - result = _mcp( - app, - "tools/call", - { - "name": "slow_callback", - "arguments": {"value": "hello"}, - }, - ) - return json.loads(result["result"]["content"][0]["text"])["taskId"] - - def test_returns_working_while_running(self): - app = _make_background_app() - task_id = self._trigger(app) - poll = _mcp( - app, - "tools/call", - { - "name": "get_background_task_result", - "arguments": {"taskId": task_id}, - }, - ) - text = poll["result"]["content"][0]["text"] - assert "working" in text.lower() - - def test_returns_result_when_complete(self): - app = _make_background_app() - task_id = self._trigger(app) - _, job_id, _ = task_id.split(":", 2) - - # Wait for completion - manager = app.callback_map["output.children"]["manager"] - deadline = time.time() + 3 - while time.time() < deadline: - if not manager.job_running(job_id): - break - time.sleep(0.1) - - result = _mcp( - app, - "tools/call", - { - "name": "get_background_task_result", - "arguments": {"taskId": task_id}, - }, - ) - text = result["result"]["content"][0]["text"] - assert "done:" in text - - -class TestBackgroundCallbackTrigger: - """Calling a background callback tool returns taskId immediately.""" - - def test_returns_task_id(self): - app = _make_background_app() - result = _mcp( - app, - "tools/call", - { - "name": "slow_callback", - "arguments": {"value": "hello"}, - }, - ) - text = result["result"]["content"][0]["text"] - assert "taskId" in text - assert "slow_callback:" in text - - -class TestTasksGet: - """tasks/get derives status from the callback manager.""" - - def test_working_status_while_running(self): - app = _make_background_app() - create_result = _mcp( - app, - "tools/call", - { - "name": "slow_callback", - "arguments": {"value": "hello"}, - "task": {"ttl": 60000}, - }, - ) - task_id = create_result["result"]["task"]["taskId"] - - # Immediately poll — job should still be running - get_result = _mcp(app, "tasks/get", {"taskId": task_id}) - assert get_result["result"]["status"] == "working" - assert get_result["result"]["taskId"] == task_id - - -class TestTasksResult: - """tasks/result retrieves and formats the callback result.""" - - def test_returns_formatted_result_when_complete(self): - app = _make_background_app() - create_result = _mcp( - app, - "tools/call", - { - "name": "slow_callback", - "arguments": {"value": "hello"}, - "task": {"ttl": 60000}, - }, - ) - task_id = create_result["result"]["task"]["taskId"] - tool_name, job_id, cache_key = task_id.split(":", 2) - - # Wait for the background job to finish - manager = app.callback_map["output.children"]["manager"] - deadline = time.time() + 3 - while time.time() < deadline: - if not manager.job_running(job_id): - break - time.sleep(0.1) - - # Fetch the result - result = _mcp(app, "tasks/result", {"taskId": task_id}) - assert "content" in result["result"] - text = result["result"]["content"][0]["text"] - assert "done:" in text - - -class TestTasksCancel: - """tasks/cancel terminates the background job.""" - - def test_cancel_terminates_job(self): - app = _make_background_app() - create_result = _mcp( - app, - "tools/call", - { - "name": "slow_callback", - "arguments": {"value": "hello"}, - "task": {"ttl": 60000}, - }, - ) - task_id = create_result["result"]["task"]["taskId"] - - cancel_result = _mcp(app, "tasks/cancel", {"taskId": task_id}) - assert "error" not in cancel_result - - _, job_id, _ = task_id.split(":", 2) - manager = app.callback_map["output.children"]["manager"] - assert not manager.job_running(job_id) - - -class TestBackgroundCallbackWithTask: - """When tools/call includes task metadata, return CreateTaskResult.""" - - def test_returns_create_task_result(self): - app = _make_background_app() - result = _mcp( - app, - "tools/call", - { - "name": "slow_callback", - "arguments": {"value": "hello"}, - "task": {"ttl": 60000}, - }, - ) - task = result["result"]["task"] - assert task["status"] == "working" - assert "taskId" in task - assert "pollInterval" in task - - def test_task_id_encodes_tool_name_job_id_cache_key(self): - app = _make_background_app() - result = _mcp( - app, - "tools/call", - { - "name": "slow_callback", - "arguments": {"value": "hello"}, - "task": {"ttl": 60000}, - }, - ) - task_id = result["result"]["task"]["taskId"] - tool_name, job_id, cache_key = task_id.split(":", 2) - assert tool_name == "slow_callback" - assert len(cache_key) == 64 # SHA256 hex diff --git a/tests/unit/mcp/tools/test_mcp_background_callbacks.py b/tests/unit/mcp/tools/test_mcp_background_callbacks.py new file mode 100644 index 0000000000..bc6f6b32ed --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_background_callbacks.py @@ -0,0 +1,277 @@ +"""Background callback support via MCP Tasks. + +Covers both layers: +- Layer 1 (``dash/mcp/tasks/``): ``tasks/get``, ``tasks/result``, ``tasks/cancel`` + derived on-demand from the callback manager. +- Layer 2 (tool wrappers): ``get_background_task_result`` and + ``cancel_background_task`` — only registered when the app has + background callbacks. +""" + +import json +import time + +import diskcache +from dash import Dash, Input, Output, html +from dash.background_callback.managers.diskcache_manager import DiskcacheManager +from dash.mcp._server import _process_mcp_message + +from tests.unit.mcp.conftest import _setup_mcp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _make_background_app(): + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def slow_callback(value): + """A background callback.""" + time.sleep(0.3) + return f"done: {value}" + + return app + + +def _trigger_task(app): + """Call slow_callback via tools/call and return its taskId.""" + result = _mcp( + app, + "tools/call", + {"name": "slow_callback", "arguments": {"value": "hello"}}, + ) + return json.loads(result["result"]["content"][0]["text"])["taskId"] + + +def _wait_for_completion(app, task_id, timeout=3): + """Block until the callback manager reports the job is no longer running.""" + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + timeout + while time.time() < deadline: + if not manager.job_running(job_id): + return + time.sleep(0.1) + + +# --------------------------------------------------------------------------- +# Tool-layer: cancel_background_task, get_background_task_result, registration +# --------------------------------------------------------------------------- + + +def test_mcpbg001_cancel_via_tool(): + app = _make_background_app() + task_id = _trigger_task(app) + + cancel = _mcp( + app, + "tools/call", + { + "name": "cancel_background_task", + "arguments": {"taskId": task_id}, + }, + ) + assert cancel["result"].get("isError") is not True + + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + assert not manager.job_running(job_id) + + +def test_mcpbg002_present_with_background_callbacks(): + app = _make_background_app() + tools = _mcp(app, "tools/list")["result"]["tools"] + names = [t["name"] for t in tools] + assert "get_background_task_result" in names + assert "cancel_background_task" in names + + +def test_mcpbg003_absent_without_background_callbacks(): + app = Dash(__name__) + app.layout = html.Div([html.Div(id="in"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("in", "children")) + def normal_cb(v): + return v + + tools = _mcp(app, "tools/list")["result"]["tools"] + names = [t["name"] for t in tools] + assert "get_background_task_result" not in names + assert "cancel_background_task" not in names + + +def test_mcpbg004_returns_working_while_running(): + app = _make_background_app() + task_id = _trigger_task(app) + poll = _mcp( + app, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + ) + text = poll["result"]["content"][0]["text"] + assert "working" in text.lower() + + +def test_mcpbg005_returns_result_when_complete(): + app = _make_background_app() + task_id = _trigger_task(app) + _wait_for_completion(app, task_id) + + result = _mcp( + app, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + ) + text = result["result"]["content"][0]["text"] + assert "done:" in text + + +def test_mcpbg006_returns_task_id(): + """Calling a background callback tool returns a taskId immediately.""" + app = _make_background_app() + result = _mcp( + app, + "tools/call", + {"name": "slow_callback", "arguments": {"value": "hello"}}, + ) + text = result["result"]["content"][0]["text"] + assert "taskId" in text + assert "slow_callback:" in text + + +# --------------------------------------------------------------------------- +# Tasks-protocol layer: tasks/get, tasks/result, tasks/cancel +# --------------------------------------------------------------------------- + + +def test_mcpbg007_tasks_get_working_status_while_running(): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + + get_result = _mcp(app, "tasks/get", {"taskId": task_id}) + assert get_result["result"]["status"] == "working" + assert get_result["result"]["taskId"] == task_id + + +def test_mcpbg008_tasks_result_returns_formatted_result(): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + _wait_for_completion(app, task_id) + + result = _mcp(app, "tasks/result", {"taskId": task_id}) + assert "content" in result["result"] + text = result["result"]["content"][0]["text"] + assert "done:" in text + + +def test_mcpbg009_tasks_cancel_terminates_job(): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + + cancel_result = _mcp(app, "tasks/cancel", {"taskId": task_id}) + assert "error" not in cancel_result + + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + assert not manager.job_running(job_id) + + +# --------------------------------------------------------------------------- +# tools/call with task metadata → CreateTaskResult + taskId encoding +# --------------------------------------------------------------------------- + + +def test_mcpbg010_returns_create_task_result(): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task = result["result"]["task"] + assert task["status"] == "working" + assert "taskId" in task + assert "pollInterval" in task + + +def test_mcpbg011_task_id_encodes_tool_name_job_id_cache_key(): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = result["result"]["task"]["taskId"] + tool_name, _job_id, cache_key = task_id.split(":", 2) + assert tool_name == "slow_callback" + assert len(cache_key) == 64 # SHA256 hex From bf442f789362ed7e251859affe82ff83c99e55d8 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 30 Apr 2026 10:49:28 -0600 Subject: [PATCH 70/80] lint --- .../description_background_callbacks.py | 3 +- dash/mcp/primitives/tools/results/__init__.py | 34 +++++++++++-------- .../primitives/tools/tool_background_tasks.py | 5 +-- dash/mcp/primitives/tools/tools_callbacks.py | 1 + dash/mcp/tasks/tasks.py | 16 +++++---- 5 files changed, 34 insertions(+), 25 deletions(-) diff --git a/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py b/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py index 3823e06e0b..eada24a01e 100644 --- a/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py +++ b/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING +from ..tool_background_tasks import GET_RESULT_TOOL_NAME from .base import ToolDescriptionSource if TYPE_CHECKING: @@ -19,9 +20,9 @@ class BackgroundCallbackDescription(ToolDescriptionSource): @classmethod def describe(cls, callback: CallbackAdapter) -> list[str]: + # pylint: disable-next=protected-access if not callback._cb_info.get("background"): return [] - from ..tool_background_tasks import GET_RESULT_TOOL_NAME return [ "", diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py index 10f364959b..12f9507da7 100644 --- a/dash/mcp/primitives/tools/results/__init__.py +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -7,17 +7,19 @@ from __future__ import annotations import json -from typing import Any +from typing import TYPE_CHECKING, Any from mcp.types import CallToolResult, CreateTaskResult, TextContent from dash.types import CallbackExecutionResponse -from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter from .base import ResultFormatter from .result_dataframe import DataFrameResult from .result_plotly_figure import PlotlyFigureResult +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + _RESULT_FORMATTERS: list[type[ResultFormatter]] = [ PlotlyFigureResult, DataFrameResult, @@ -61,17 +63,21 @@ def task_result_to_tool_result(create_task_result: CreateTaskResult) -> CallTool """ task = create_task_result.task return CallToolResult( - content=[TextContent( - type="text", - text=json.dumps({ - "taskId": task.taskId, - "status": task.status, - "pollInterval": task.pollInterval, - "message": ( - "This is a long-running background callback. " - "Call the get_background_task_result tool with this taskId " - "to poll for the result." + content=[ + TextContent( + type="text", + text=json.dumps( + { + "taskId": task.taskId, + "status": task.status, + "pollInterval": task.pollInterval, + "message": ( + "This is a long-running background callback. " + "Call the get_background_task_result tool with this taskId " + "to poll for the result." + ), + } ), - }), - )], + ) + ], ) diff --git a/dash/mcp/primitives/tools/tool_background_tasks.py b/dash/mcp/primitives/tools/tool_background_tasks.py index 37f76a9d4d..a4dffa23f4 100644 --- a/dash/mcp/primitives/tools/tool_background_tasks.py +++ b/dash/mcp/primitives/tools/tool_background_tasks.py @@ -21,10 +21,7 @@ def _has_background_callbacks() -> bool: - return any( - cb_info.get("background") - for cb_info in get_app().callback_map.values() - ) + return any(cb_info.get("background") for cb_info in get_app().callback_map.values()) class BackgroundTaskTools(MCPToolProvider): diff --git a/dash/mcp/primitives/tools/tools_callbacks.py b/dash/mcp/primitives/tools/tools_callbacks.py index 990f90e285..97970c5df7 100644 --- a/dash/mcp/primitives/tools/tools_callbacks.py +++ b/dash/mcp/primitives/tools/tools_callbacks.py @@ -47,6 +47,7 @@ def call_tool( " Please call tools/list to refresh your tool list." ) + # pylint: disable-next=protected-access is_background = bool(cb._cb_info.get("background")) try: diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py index 877b84795b..aab1a98f85 100644 --- a/dash/mcp/tasks/tasks.py +++ b/dash/mcp/tasks/tasks.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone from typing import Any -from mcp.types import CreateTaskResult, GetTaskResult, Task +from mcp.types import CancelTaskResult, CreateTaskResult, GetTaskResult, Task from dash import get_app from dash.mcp.primitives.tools.results import format_callback_response @@ -33,6 +33,7 @@ def create_task(dispatch_response: dict[str, Any], callback) -> CreateTaskResult cache_key = dispatch_response["cacheKey"] job_id = str(dispatch_response["job"]) task_id = f"{callback.tool_name}:{job_id}:{cache_key}" + # pylint: disable-next=protected-access interval = callback._cb_info.get("background", {}).get("interval", 1000) now = datetime.now(timezone.utc) return CreateTaskResult( @@ -75,6 +76,7 @@ def get_task(task_id: str) -> GetTaskResult: adapter = get_app().mcp_callback_map.find_by_tool_name(tool_name) interval = None if adapter is not None: + # pylint: disable-next=protected-access interval = adapter._cb_info.get("background", {}).get("interval", 1000) now = datetime.now(timezone.utc) @@ -82,7 +84,9 @@ def get_task(task_id: str) -> GetTaskResult: taskId=task_id, status=status, statusMessage=str(progress) if progress else None, - createdAt=datetime.fromisoformat(manager.handle.get(f"{cache_key}-created_at") or now.isoformat()), + createdAt=datetime.fromisoformat( + manager.handle.get(f"{cache_key}-created_at") or now.isoformat() + ), lastUpdatedAt=now, ttl=manager.expire * 1000 if manager.expire else None, pollInterval=interval, @@ -117,7 +121,9 @@ def get_task_result(task_id: str) -> Any: response_data = json.loads(response.get_data(as_text=True)) if "response" not in response_data: - raise MCPError("Task result not ready. Poll tasks/get until status is 'completed'.") + raise MCPError( + "Task result not ready. Poll tasks/get until status is 'completed'." + ) return format_callback_response(response_data, adapter) @@ -127,9 +133,7 @@ def cancel_task(task_id: str) -> Any: Same underlying mechanism as the renderer's cancelJob query param. """ - from mcp.types import CancelTaskResult - - tool_name, job_id, cache_key = parse_task_id(task_id) + _tool_name, job_id, cache_key = parse_task_id(task_id) manager = _get_callback_manager() if manager is None: From 576145c3ff67553f7dc52598d458ff3eced68b37 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Fri, 8 May 2026 16:28:51 -0600 Subject: [PATCH 71/80] Install diskcache extra in lint-unit and typing CI jobs --- .github/workflows/testing.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 358f9fd2d2..bb9b10a016 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -108,7 +108,7 @@ jobs: run: | python -m pip install --upgrade pip wheel python -m pip install "setuptools<80.0.0" - find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[dev,ci,testing]"' \; + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[dev,ci,testing,diskcache]"' \; - name: Install dash-renderer dependencies working-directory: dash/dash-renderer @@ -231,7 +231,7 @@ jobs: run: | python -m pip install --upgrade pip wheel python -m pip install "setuptools<80.0.0" - find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[ci,testing,dev]"' \; + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[ci,testing,dev,diskcache]"' \; - name: Build/Setup test components run: npm run setup-tests.py # TODO build the packages and save them to packages/ in build job From 0fef9b8bdbc907047a901baa270752bd8632bb47 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Fri, 8 May 2026 16:45:16 -0600 Subject: [PATCH 72/80] Fix mypy errors in CI --- dash/background_callback/managers/diskcache_manager.py | 2 +- dash/mcp/primitives/tools/tools_callbacks.py | 2 +- dash/mcp/tasks/tasks.py | 6 ++++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dash/background_callback/managers/diskcache_manager.py b/dash/background_callback/managers/diskcache_manager.py index db7cd112bc..2ded7eb48e 100644 --- a/dash/background_callback/managers/diskcache_manager.py +++ b/dash/background_callback/managers/diskcache_manager.py @@ -36,7 +36,7 @@ def __init__(self, cache=None, cache_by=None, expire=None): is determined by the default behavior of the ``cache`` instance. """ try: - import diskcache # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel + import diskcache # type: ignore[import-not-found,import-untyped] # pylint: disable=import-outside-toplevel import psutil # type: ignore[import-untyped] # noqa: F401,E402 pylint: disable=import-outside-toplevel,unused-import,unused-variable,import-error import multiprocess # type: ignore[import-untyped] # noqa: F401,E402 pylint: disable=import-outside-toplevel,unused-import,unused-variable except ImportError as missing_imports: diff --git a/dash/mcp/primitives/tools/tools_callbacks.py b/dash/mcp/primitives/tools/tools_callbacks.py index 97970c5df7..2a2d866ea9 100644 --- a/dash/mcp/primitives/tools/tools_callbacks.py +++ b/dash/mcp/primitives/tools/tools_callbacks.py @@ -59,7 +59,7 @@ def call_tool( ) if is_background: - task_result = create_task(callback_response, cb) + task_result = create_task(dict(callback_response), cb) if task is not None: return task_result return task_result_to_tool_result(task_result) diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py index aab1a98f85..763f25352b 100644 --- a/dash/mcp/tasks/tasks.py +++ b/dash/mcp/tasks/tasks.py @@ -4,7 +4,7 @@ import json from datetime import datetime, timezone -from typing import Any +from typing import Any, Literal from mcp.types import CancelTaskResult, CreateTaskResult, GetTaskResult, Task @@ -15,7 +15,8 @@ def parse_task_id(task_id: str) -> tuple[str, str, str]: """Parse a taskId into (tool_name, job_id, cache_key).""" - return task_id.split(":", 2) + tool_name, job_id, cache_key = task_id.split(":", 2) + return tool_name, job_id, cache_key def _get_callback_manager(): @@ -66,6 +67,7 @@ def get_task(task_id: str) -> GetTaskResult: running = manager.job_running(job_id) progress = manager.get_progress(cache_key) + status: Literal["working", "completed", "failed"] if running: status = "working" elif manager.result_ready(cache_key): From fd8c8d972f9fb58b539aa24166f9249ebe67c79d Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 12 May 2026 09:01:57 -0600 Subject: [PATCH 73/80] Fix bug storing metadata on background callbacks --- dash/_callback.py | 6 - dash/mcp/tasks/tasks.py | 60 +++++----- .../mcp/test_mcp_background_tasks.py | 112 +++++++++++------- .../tools/test_mcp_background_callbacks.py | 3 +- 4 files changed, 100 insertions(+), 81 deletions(-) diff --git a/dash/_callback.py b/dash/_callback.py index c7e4cb4102..a8c5a93996 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -446,12 +446,6 @@ def _setup_background_callback( ctx_value, ) - callback_manager.handle.set( - f"{cache_key}-created_at", - datetime.now(timezone.utc).isoformat(), - expire=callback_manager.expire, - ) - data = { "cacheKey": cache_key, "job": job, diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py index 763f25352b..2d5df45b23 100644 --- a/dash/mcp/tasks/tasks.py +++ b/dash/mcp/tasks/tasks.py @@ -13,30 +13,33 @@ from dash.mcp.types import MCPError -def parse_task_id(task_id: str) -> tuple[str, str, str]: - """Parse a taskId into (tool_name, job_id, cache_key).""" - tool_name, job_id, cache_key = task_id.split(":", 2) - return tool_name, job_id, cache_key +def parse_task_id(task_id: str) -> tuple[str, str, str, datetime]: + """Parse a taskId into (tool_name, job_id, cache_key, created_at).""" + tool_name, job_id, rest = task_id.split(":", 2) + cache_key, created_epoch = rest.rsplit(":", 1) + created_at = datetime.fromtimestamp(int(created_epoch), tz=timezone.utc) + return tool_name, job_id, cache_key, created_at -def _get_callback_manager(): - """Get the background callback manager from the app's callback_map.""" - app = get_app() - for cb_info in app.callback_map.values(): - manager = cb_info.get("manager") - if manager is not None: - return manager - return None +def _get_callback_manager(tool_name: str): + """Resolve the background callback manager for a specific MCP tool.""" + adapter = get_app().mcp_callback_map.find_by_tool_name(tool_name) + if adapter is None: + return None + # pylint: disable-next=protected-access + return adapter._cb_info.get("manager") def create_task(dispatch_response: dict[str, Any], callback) -> CreateTaskResult: """Create a Task from a background callback's initial dispatch response.""" cache_key = dispatch_response["cacheKey"] job_id = str(dispatch_response["job"]) - task_id = f"{callback.tool_name}:{job_id}:{cache_key}" + now = datetime.now(timezone.utc) + # taskId encodes creation time so subsequent polls return a stable + # createdAt without server-side storage (works across gunicorn workers). + task_id = f"{callback.tool_name}:{job_id}:{cache_key}:{int(now.timestamp())}" # pylint: disable-next=protected-access interval = callback._cb_info.get("background", {}).get("interval", 1000) - now = datetime.now(timezone.utc) return CreateTaskResult( task=Task( taskId=task_id, @@ -51,15 +54,15 @@ def create_task(dispatch_response: dict[str, Any], callback) -> CreateTaskResult def get_task(task_id: str) -> GetTaskResult: """Handle tasks/get — derive status from the callback manager.""" - tool_name, job_id, cache_key = parse_task_id(task_id) + tool_name, job_id, cache_key, created_at = parse_task_id(task_id) - manager = _get_callback_manager() + manager = _get_callback_manager(tool_name) if manager is None: return GetTaskResult( taskId=task_id, status="failed", statusMessage="No background callback manager configured.", - createdAt=datetime.now(timezone.utc), + createdAt=created_at, lastUpdatedAt=datetime.now(timezone.utc), ttl=None, ) @@ -81,15 +84,12 @@ def get_task(task_id: str) -> GetTaskResult: # pylint: disable-next=protected-access interval = adapter._cb_info.get("background", {}).get("interval", 1000) - now = datetime.now(timezone.utc) return GetTaskResult( taskId=task_id, status=status, statusMessage=str(progress) if progress else None, - createdAt=datetime.fromisoformat( - manager.handle.get(f"{cache_key}-created_at") or now.isoformat() - ), - lastUpdatedAt=now, + createdAt=created_at, + lastUpdatedAt=datetime.now(timezone.utc), ttl=manager.expire * 1000 if manager.expire else None, pollInterval=interval, ) @@ -100,9 +100,9 @@ def get_task_result(task_id: str) -> Any: Mirrors the Dash renderer: calls get_result() which clears from cache. """ - tool_name, job_id, cache_key = parse_task_id(task_id) + tool_name, job_id, cache_key, _created_at = parse_task_id(task_id) - manager = _get_callback_manager() + manager = _get_callback_manager(tool_name) if manager is None: raise MCPError("No background callback manager configured.") @@ -135,22 +135,18 @@ def cancel_task(task_id: str) -> Any: Same underlying mechanism as the renderer's cancelJob query param. """ - _tool_name, job_id, cache_key = parse_task_id(task_id) + tool_name, job_id, _cache_key, created_at = parse_task_id(task_id) - manager = _get_callback_manager() + manager = _get_callback_manager(tool_name) if manager is None: raise MCPError("No background callback manager configured.") manager.terminate_job(job_id) - now = datetime.now(timezone.utc) - created_at = manager.handle.get(f"{cache_key}-created_at") - manager.handle.delete(f"{cache_key}-created_at") - return CancelTaskResult( taskId=task_id, status="cancelled", - createdAt=datetime.fromisoformat(created_at) if created_at else now, - lastUpdatedAt=now, + createdAt=created_at, + lastUpdatedAt=datetime.now(timezone.utc), ttl=manager.expire * 1000 if manager.expire else None, ) diff --git a/tests/integration/mcp/test_mcp_background_tasks.py b/tests/integration/mcp/test_mcp_background_tasks.py index e3e0c0acdf..8f8b946754 100644 --- a/tests/integration/mcp/test_mcp_background_tasks.py +++ b/tests/integration/mcp/test_mcp_background_tasks.py @@ -9,7 +9,6 @@ import json import re import time -from datetime import datetime import diskcache from dash import Dash, Input, Output, html @@ -84,14 +83,7 @@ def test_mcpbg012_trigger_poll_and_retrieve(): task_id = task_info["taskId"] assert task_info["status"] == "working" - # Read createdAt from the callback manager directly - _, _, cache_key = task_id.split(":", 2) - stored_created_at = app.callback_map["output.children"]["manager"].handle.get( - f"{cache_key}-created_at" - ) - assert stored_created_at is not None - - # Poll — should be working, with createdAt matching the stored value + # Poll — should be working r = _post( client, "tools/call", @@ -103,12 +95,10 @@ def test_mcpbg012_trigger_poll_and_retrieve(): ) assert r.status_code == 200 poll_data = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) - assert datetime.fromisoformat(poll_data["createdAt"]) == datetime.fromisoformat( - stored_created_at - ) + assert poll_data["status"] == "working" # Wait for completion - _, job_id, _ = task_id.split(":", 2) + job_id = task_id.split(":")[1] manager = app.callback_map["output.children"]["manager"] deadline = time.time() + 5 while time.time() < deadline: @@ -133,7 +123,7 @@ def test_mcpbg012_trigger_poll_and_retrieve(): def test_mcpbg013_result_expires(): - """Result and createdAt are available until the cache expires.""" + """Result is retrievable until the cache expires, then reports failure.""" cache = diskcache.Cache() manager = DiskcacheManager(cache, cache_by=[lambda: "fixed"], expire=2) @@ -151,7 +141,6 @@ def fast_cb(value): client = app.server.test_client() - # Trigger r = _post( client, "tools/call", @@ -159,16 +148,15 @@ def fast_cb(value): ) task_info = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) task_id = task_info["taskId"] - _, job_id, cache_key = task_id.split(":", 2) + job_id = task_id.split(":")[1] - # Wait for job to finish deadline = time.time() + 3 while time.time() < deadline: if not manager.job_running(job_id): break time.sleep(0.1) - # First retrieval — result and createdAt available + # Before expiry — result available r = _post( client, "tools/call", @@ -178,29 +166,11 @@ def fast_cb(value): }, request_id=2, ) - text = json.loads(r.data)["result"]["content"][0]["text"] - assert "done:" in text - created_at = manager.handle.get(f"{cache_key}-created_at") - assert created_at is not None + assert "done:" in json.loads(r.data)["result"]["content"][0]["text"] - # Second retrieval — still available (cache_by keeps it) - r = _post( - client, - "tools/call", - { - "name": "get_background_task_result", - "arguments": {"taskId": task_id}, - }, - request_id=3, - ) - text = json.loads(r.data)["result"]["content"][0]["text"] - assert "done:" in text - assert manager.handle.get(f"{cache_key}-created_at") == created_at - - # Wait for expiry time.sleep(2.5) - # After expiry — tool reports failure, createdAt is fresh (stored value gone) + # After expiry — tool reports failure r = _post( client, "tools/call", @@ -208,13 +178,10 @@ def fast_cb(value): "name": "get_background_task_result", "arguments": {"taskId": task_id}, }, - request_id=4, + request_id=3, ) poll_data = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) assert poll_data["status"] == "failed" - assert datetime.fromisoformat(poll_data["createdAt"]) > datetime.fromisoformat( - created_at - ) def test_mcpbg014_progress_in_poll_response(): @@ -296,3 +263,64 @@ def test_mcpbg015_background_tools_in_tools_list(): assert "get_background_task_result" in names assert "cancel_background_task" in names assert "slow_callback" in names + + +def test_mcpbg016_per_callback_manager_lookup(): + """``tasks/get`` uses the manager attached to the specific callback.""" + manager_a = DiskcacheManager(diskcache.Cache()) + manager_b = DiskcacheManager(diskcache.Cache()) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input_a"), + html.Div(id="output_a"), + html.Div(id="input_b"), + html.Div(id="output_b"), + ] + ) + + @app.callback( + Output("output_a", "children"), + Input("input_a", "children"), + background=True, + manager=manager_a, + ) + def callback_a(value): + time.sleep(0.5) + return f"a: {value}" + + @app.callback( + Output("output_b", "children"), + Input("input_b", "children"), + background=True, + manager=manager_b, + ) + def callback_b(value): + time.sleep(0.5) + return f"b: {value}" + + client = app.server.test_client() + + r = _post( + client, + "tools/call", + {"name": "callback_b", "arguments": {"value": "hello"}}, + ) + assert r.status_code == 200 + task_info = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + cache_key = task_id.split(":")[2] + + deadline = time.time() + 5 + while time.time() < deadline: + if manager_b.result_ready(cache_key): + break + time.sleep(0.1) + + assert manager_b.result_ready(cache_key) + assert not manager_a.result_ready(cache_key) + + r = _post(client, "tasks/get", {"taskId": task_id}, request_id=2) + assert r.status_code == 200 + assert json.loads(r.data)["result"]["status"] == "completed" diff --git a/tests/unit/mcp/tools/test_mcp_background_callbacks.py b/tests/unit/mcp/tools/test_mcp_background_callbacks.py index bc6f6b32ed..ffd4fb5bac 100644 --- a/tests/unit/mcp/tools/test_mcp_background_callbacks.py +++ b/tests/unit/mcp/tools/test_mcp_background_callbacks.py @@ -272,6 +272,7 @@ def test_mcpbg011_task_id_encodes_tool_name_job_id_cache_key(): }, ) task_id = result["result"]["task"]["taskId"] - tool_name, _job_id, cache_key = task_id.split(":", 2) + tool_name, _job_id, cache_key, created_epoch = task_id.split(":") assert tool_name == "slow_callback" assert len(cache_key) == 64 # SHA256 hex + assert created_epoch.isdigit() From 3cc14d8f4fd940246015ba58ec5f4f4c96a5f203 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 13 May 2026 11:00:28 -0600 Subject: [PATCH 74/80] Enable programmatic querying of background callback status/result; remove Flask dependency in MCP implementation --- dash/_callback.py | 48 +++++++++++++++++++++-------- dash/mcp/tasks/tasks.py | 67 ++++++++++++++++++++++++++++++----------- 2 files changed, 85 insertions(+), 30 deletions(-) diff --git a/dash/_callback.py b/dash/_callback.py index a8c5a93996..037f3d189b 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,7 +1,6 @@ import collections import hashlib import inspect -from datetime import datetime, timezone from functools import wraps from typing import Callable, Optional, Any, List, Tuple, Union, Dict, TypeVar, cast @@ -463,10 +462,13 @@ def _setup_background_callback( return to_json(data) -def _progress_background_callback(response, callback_manager, background): +def _progress_background_callback( + response, callback_manager, background, cache_key=None +): progress_outputs = background.get("progress") - adapter = get_app().backend.request_adapter() - cache_key = adapter.args.get("cacheKey") + if cache_key is None: + adapter = get_app().backend.request_adapter() + cache_key = adapter.args.get("cacheKey") if progress_outputs: # Get the progress before the result as it would be erased after the results. @@ -478,21 +480,38 @@ def _progress_background_callback(response, callback_manager, background): def _update_background_callback( - error_handler, callback_ctx, response, kwargs, background, multi + error_handler, + callback_ctx, + response, + kwargs, + background, + multi, + cache_key=None, + job_id=None, ): """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) - adapter = get_app().backend.request_adapter() - cache_key = adapter.args.get("cacheKey") if adapter else None - job_id = adapter.args.get("job") if adapter else None + if cache_key is None or job_id is None: + adapter = get_app().backend.request_adapter() + cache_key = cache_key or (adapter.args.get("cacheKey") if adapter else None) + job_id = job_id or (adapter.args.get("job") if adapter else None) - _progress_background_callback(response, callback_manager, background) + _progress_background_callback( + response, callback_manager, background, cache_key=cache_key + ) output_value = callback_manager.get_result(cache_key, job_id) return _handle_rest_background_callback( - output_value, callback_manager, response, error_handler, callback_ctx, multi + output_value, + callback_manager, + response, + error_handler, + callback_ctx, + multi, + cache_key=cache_key, + job_id=job_id, ) @@ -504,10 +523,13 @@ def _handle_rest_background_callback( callback_ctx, multi, has_update=False, + cache_key=None, + job_id=None, ): - adapter = get_app().backend.request_adapter() - cache_key = adapter.args.get("cacheKey") if adapter else None - job_id = adapter.args.get("job") if adapter else None + if cache_key is None or job_id is None: + adapter = get_app().backend.request_adapter() + cache_key = cache_key or (adapter.args.get("cacheKey") if adapter else None) + job_id = job_id or (adapter.args.get("job") if adapter else None) # Must get job_running after get_result since get_results terminates it. job_running = callback_manager.job_running(job_id) if not job_running and output_value is callback_manager.UNDEFINED: diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py index 2d5df45b23..ab0fd82069 100644 --- a/dash/mcp/tasks/tasks.py +++ b/dash/mcp/tasks/tasks.py @@ -2,13 +2,15 @@ from __future__ import annotations -import json from datetime import datetime, timezone from typing import Any, Literal from mcp.types import CancelTaskResult, CreateTaskResult, GetTaskResult, Task from dash import get_app +from dash._callback import _update_background_callback, _prepare_response +from dash._utils import AttributeDict, split_callback_id +from dash.development.base_component import ComponentRegistry from dash.mcp.primitives.tools.results import format_callback_response from dash.mcp.types import MCPError @@ -98,7 +100,10 @@ def get_task(task_id: str) -> GetTaskResult: def get_task_result(task_id: str) -> Any: """Handle tasks/result — retrieve and format the callback result. - Mirrors the Dash renderer: calls get_result() which clears from cache. + Uses the framework's background callback retrieval functions + (``_update_background_callback`` and ``_prepare_response``) with + ``cache_key`` and ``job_id`` supplied directly, bypassing the + request adapter query-param lookup. """ tool_name, job_id, cache_key, _created_at = parse_task_id(task_id) @@ -106,28 +111,56 @@ def get_task_result(task_id: str) -> Any: if manager is None: raise MCPError("No background callback manager configured.") - # Mirror the renderer: dispatch with cacheKey/job query params. - # The framework handles result retrieval, wrapping, and cleanup. - adapter = get_app().mcp_callback_map.find_by_tool_name(tool_name) - body = adapter.as_callback_body({}) app = get_app() + adapter = app.mcp_callback_map.find_by_tool_name(tool_name) + # pylint: disable-next=protected-access + cb_info = adapter._cb_info + background = cb_info.get("background") + has_output = not cb_info.get("no_output") + multi = adapter.output_id.startswith("..") + output = split_callback_id(adapter.output_id) - with app.server.test_request_context( - f"/_dash-update-component?cacheKey={cache_key}&job={job_id}", - method="POST", - data=json.dumps(body, default=str), - content_type="application/json", - ): - response = app.dispatch() - - response_data = json.loads(response.get_data(as_text=True)) + # Build the body to get output_spec, same as as_callback_body + body = adapter.as_callback_body({}) + output_spec = body.get("outputs", []) + + callback_ctx = AttributeDict({"updated_props": {}}) + response = {"multi": True} + + output_value, has_update, skip = _update_background_callback( + error_handler=None, + callback_ctx=callback_ctx, + response=response, + kwargs={"background_callback_manager": manager}, + background=background, + multi=multi, + cache_key=cache_key, + job_id=job_id, + ) - if "response" not in response_data: + if skip: + # Result not ready — still polling raise MCPError( "Task result not ready. Poll tasks/get until status is 'completed'." ) - return format_callback_response(response_data, adapter) + _prepare_response( + output_value, + output_spec, + multi, + response, + callback_ctx, + app, + set(ComponentRegistry.registry), + background, + has_update, + has_output, + output, + adapter.output_id, + cb_info.get("allow_dynamic_callbacks"), + ) + + return format_callback_response(response, adapter) def cancel_task(task_id: str) -> Any: From 9d5fa6181623a14502b5ecb00ea6b5489907a330 Mon Sep 17 00:00:00 2001 From: robertclaus Date: Fri, 15 May 2026 15:55:25 -0500 Subject: [PATCH 75/80] Fix get_task status race and CallbackExecutionResponse type annotation - Reorder status checks in get_task: result_ready takes priority over job_running. A background process can write its result to the cache before the process fully exits, causing a false "working" status even after the result is available (fixes test_mcpbg016). - Annotate response as CallbackExecutionResponse so mypy accepts it as the argument to _prepare_response and format_callback_response (fixes Typing Tests CI failure). Co-Authored-By: Claude Sonnet 4.6 --- dash/mcp/tasks/tasks.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py index ab0fd82069..044ee74d25 100644 --- a/dash/mcp/tasks/tasks.py +++ b/dash/mcp/tasks/tasks.py @@ -13,6 +13,7 @@ from dash.development.base_component import ComponentRegistry from dash.mcp.primitives.tools.results import format_callback_response from dash.mcp.types import MCPError +from dash.types import CallbackExecutionResponse def parse_task_id(task_id: str) -> tuple[str, str, str, datetime]: @@ -72,11 +73,13 @@ def get_task(task_id: str) -> GetTaskResult: running = manager.job_running(job_id) progress = manager.get_progress(cache_key) + # Check result_ready before job_running: the process may store its result + # while still technically alive (teardown race), so result_ready wins. status: Literal["working", "completed", "failed"] - if running: - status = "working" - elif manager.result_ready(cache_key): + if manager.result_ready(cache_key): status = "completed" + elif running: + status = "working" else: status = "failed" @@ -125,7 +128,7 @@ def get_task_result(task_id: str) -> Any: output_spec = body.get("outputs", []) callback_ctx = AttributeDict({"updated_props": {}}) - response = {"multi": True} + response: CallbackExecutionResponse = {"multi": True} output_value, has_update, skip = _update_background_callback( error_handler=None, From ef778a213bcc85615a70d57df7a4306d27dd3122 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 19 May 2026 11:52:47 -0600 Subject: [PATCH 76/80] code review feedback --- dash/mcp/primitives/tools/results/__init__.py | 3 +- .../primitives/tools/tool_background_tasks.py | 34 +++++++++++-------- dash/mcp/tasks/tasks.py | 11 ++++-- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py index 12f9507da7..843ddbb4f2 100644 --- a/dash/mcp/primitives/tools/results/__init__.py +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -55,7 +55,8 @@ def format_callback_response( def task_result_to_tool_result(create_task_result: CreateTaskResult) -> CallToolResult: - """Wrap a CreateTaskResult as a CallToolResult with polling instructions. + """ + Wrap a CreateTaskResult as a CallToolResult with polling instructions. MCP Tasks are not yet supported by LLM clients, so this converts the task metadata into a tool response that guides the LLM to poll via diff --git a/dash/mcp/primitives/tools/tool_background_tasks.py b/dash/mcp/primitives/tools/tool_background_tasks.py index a4dffa23f4..5f447074e4 100644 --- a/dash/mcp/primitives/tools/tool_background_tasks.py +++ b/dash/mcp/primitives/tools/tool_background_tasks.py @@ -85,18 +85,22 @@ def call_tool( ) -> CallToolResult: task_id = arguments.get("taskId", "") - if tool_name == GET_RESULT_TOOL_NAME: - task_status = get_task(task_id) - if task_status.status == "completed": - return get_task_result(task_id) - return CallToolResult( - content=[TextContent(type="text", text=task_status.model_dump_json())], - ) - - if tool_name == CANCEL_TOOL_NAME: - result = cancel_task(task_id) - return CallToolResult( - content=[TextContent(type="text", text=result.model_dump_json())], - ) - - raise ValueError(f"Unknown tool: {tool_name}") + match tool_name: + case name if name == GET_RESULT_TOOL_NAME: + task_status = get_task(task_id) + if task_status.status == "completed": + return get_task_result(task_id) + return CallToolResult( + content=[ + TextContent(type="text", text=task_status.model_dump_json()) + ], + ) + + case name if name == CANCEL_TOOL_NAME: + result = cancel_task(task_id) + return CallToolResult( + content=[TextContent(type="text", text=result.model_dump_json())], + ) + + case _: + raise ValueError(f"Unknown tool: {tool_name}") diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py index 044ee74d25..5989dfde55 100644 --- a/dash/mcp/tasks/tasks.py +++ b/dash/mcp/tasks/tasks.py @@ -18,9 +18,12 @@ def parse_task_id(task_id: str) -> tuple[str, str, str, datetime]: """Parse a taskId into (tool_name, job_id, cache_key, created_at).""" - tool_name, job_id, rest = task_id.split(":", 2) - cache_key, created_epoch = rest.rsplit(":", 1) - created_at = datetime.fromtimestamp(int(created_epoch), tz=timezone.utc) + try: + tool_name, job_id, rest = task_id.split(":", 2) + cache_key, created_epoch = rest.rsplit(":", 1) + created_at = datetime.fromtimestamp(int(created_epoch), tz=timezone.utc) + except (ValueError, TypeError) as exc: + raise MCPError(f"Malformed taskId: {task_id!r}") from exc return tool_name, job_id, cache_key, created_at @@ -116,6 +119,8 @@ def get_task_result(task_id: str) -> Any: app = get_app() adapter = app.mcp_callback_map.find_by_tool_name(tool_name) + if adapter is None: + raise MCPError(f"Task not found: {tool_name}") # pylint: disable-next=protected-access cb_info = adapter._cb_info background = cb_info.get("background") From b78fe2a73a5e510ae7b34581ce52942f1e99c40b Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 20 May 2026 09:36:57 -0600 Subject: [PATCH 77/80] Revert match/case refactor due to compatibility problems --- .../primitives/tools/tool_background_tasks.py | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/dash/mcp/primitives/tools/tool_background_tasks.py b/dash/mcp/primitives/tools/tool_background_tasks.py index 5f447074e4..a4dffa23f4 100644 --- a/dash/mcp/primitives/tools/tool_background_tasks.py +++ b/dash/mcp/primitives/tools/tool_background_tasks.py @@ -85,22 +85,18 @@ def call_tool( ) -> CallToolResult: task_id = arguments.get("taskId", "") - match tool_name: - case name if name == GET_RESULT_TOOL_NAME: - task_status = get_task(task_id) - if task_status.status == "completed": - return get_task_result(task_id) - return CallToolResult( - content=[ - TextContent(type="text", text=task_status.model_dump_json()) - ], - ) - - case name if name == CANCEL_TOOL_NAME: - result = cancel_task(task_id) - return CallToolResult( - content=[TextContent(type="text", text=result.model_dump_json())], - ) - - case _: - raise ValueError(f"Unknown tool: {tool_name}") + if tool_name == GET_RESULT_TOOL_NAME: + task_status = get_task(task_id) + if task_status.status == "completed": + return get_task_result(task_id) + return CallToolResult( + content=[TextContent(type="text", text=task_status.model_dump_json())], + ) + + if tool_name == CANCEL_TOOL_NAME: + result = cancel_task(task_id) + return CallToolResult( + content=[TextContent(type="text", text=result.model_dump_json())], + ) + + raise ValueError(f"Unknown tool: {tool_name}") From 16da30a25e30376d1b341c2eab0fbf3d835bd3ed Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 20 May 2026 09:48:15 -0600 Subject: [PATCH 78/80] automatically load env vars with direnv --- .envrc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.envrc b/.envrc index 96b5c94a1b..22cf211038 100644 --- a/.envrc +++ b/.envrc @@ -6,3 +6,8 @@ for venv_dir in .venv .env venv env .venv* env* ENV; do break fi done + +# Include optional, local-only environment overrides +if [ -f .env ]; then + dotenv +fi From 1861302c6c90646ae8c2bd1b491f9955539e2484 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 21 May 2026 14:40:29 -0600 Subject: [PATCH 79/80] bump version and changelog --- CHANGELOG.md | 13 +++++++++++++ dash/version.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 13c0fc585b..80966c823d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,19 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#2462](https://github.com/plotly/dash/issues/2462) Allow `MATCH` in `Input`/`State` when the callback's `Output` has no wildcards (fixed-id Output, no Output, or `ALL`-only wildcard Output). `ALLSMALLER` still requires a corresponding `MATCH` in an Output. - [#3759](https://github.com/plotly/dash/pull/3759) Fix the issue where `Patch` objects cannot be updated via `set_props()` in `websocket` callback. Fix [#3742](https://github.com/plotly/dash/issues/3742) +## [4.3.0rc0] - 2026-05-21 + +## Added +- [#3710](https://github.com/plotly/dash/pull/3710) Framework utilities, types for interacting with layout +- [#3711](https://github.com/plotly/dash/pull/3711) `CallbackAdapter` for representing callback-related data in MCP-friendly format +- [#3712](https://github.com/plotly/dash/pull/3712) MCP `Resources` for exposing app layout, components, and pages +- [#3731](https://github.com/plotly/dash/pull/3731) Expose callbacks as MCP `Tools` +- [#3747](https://github.com/plotly/dash/pull/3747) Support pattern-matching callbacks in MCP tools +- [#3748](https://github.com/plotly/dash/pull/3748) Format callback results for LLM consumption (rendered graphs, markdown tables) +- [#3749](https://github.com/plotly/dash/pull/3749) `get_dash_component` MCP tool and callback execution +- [#3750](https://github.com/plotly/dash/pull/3750) MCP server routes and Streamable HTTP transport +- [#3766](https://github.com/plotly/dash/pull/3766) Support background callbacks in MCP tools + ## [4.2.0rc3] - 2026-05-12 - [#3771](https://github.com/plotly/dash/pull/3771) Add persistent callbacks and no inputs/no outputs callback support. diff --git a/dash/version.py b/dash/version.py index d39d52c3b8..c8f491bcb2 100644 --- a/dash/version.py +++ b/dash/version.py @@ -1 +1 @@ -__version__ = "4.2.0rc3" +__version__ = "4.3.0rc0" From 3558bd3afa84c92b2bef835bc785feb8454bd28b Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 21 May 2026 14:49:59 -0600 Subject: [PATCH 80/80] changelog --- CHANGELOG.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 80966c823d..36c31ec8e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,15 +23,15 @@ This project adheres to [Semantic Versioning](https://semver.org/). ## [4.3.0rc0] - 2026-05-21 ## Added -- [#3710](https://github.com/plotly/dash/pull/3710) Framework utilities, types for interacting with layout -- [#3711](https://github.com/plotly/dash/pull/3711) `CallbackAdapter` for representing callback-related data in MCP-friendly format -- [#3712](https://github.com/plotly/dash/pull/3712) MCP `Resources` for exposing app layout, components, and pages -- [#3731](https://github.com/plotly/dash/pull/3731) Expose callbacks as MCP `Tools` -- [#3747](https://github.com/plotly/dash/pull/3747) Support pattern-matching callbacks in MCP tools -- [#3748](https://github.com/plotly/dash/pull/3748) Format callback results for LLM consumption (rendered graphs, markdown tables) -- [#3749](https://github.com/plotly/dash/pull/3749) `get_dash_component` MCP tool and callback execution -- [#3750](https://github.com/plotly/dash/pull/3750) MCP server routes and Streamable HTTP transport -- [#3766](https://github.com/plotly/dash/pull/3766) Support background callbacks in MCP tools +- [#3710](https://github.com/plotly/dash/pull/3710) MCP: Framework utilities, types for interacting with layout +- [#3711](https://github.com/plotly/dash/pull/3711) MCP: `CallbackAdapter` for representing callback-related data in MCP-friendly format +- [#3712](https://github.com/plotly/dash/pull/3712) MCP: `Resources` for exposing app layout, components, and pages +- [#3731](https://github.com/plotly/dash/pull/3731) MCP: Expose callbacks as `Tools` +- [#3747](https://github.com/plotly/dash/pull/3747) MCP: Support pattern-matching callbacks in Tools +- [#3748](https://github.com/plotly/dash/pull/3748) MCP: Format callback results for LLM consumption (rendered graphs, markdown tables) +- [#3749](https://github.com/plotly/dash/pull/3749) MCP: `get_dash_component` Tool and callback execution +- [#3750](https://github.com/plotly/dash/pull/3750) MCP: Server routes, `mcp_enabled` function decorator, and Streamable HTTP transport +- [#3766](https://github.com/plotly/dash/pull/3766) MCP: Support background callbacks in Tools ## [4.2.0rc3] - 2026-05-12