diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 1fc0df1845..be5caf4929 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -17,6 +17,7 @@ jobs: # This output will be 'true' if files in the 'table_related_paths' list changed, 'false' otherwise. table_paths_changed: ${{ steps.filter.outputs.table_related_paths }} background_cb_changed: ${{ steps.filter.outputs.background_paths }} + backend_cb_changed: ${{ steps.filter.outputs.backend_paths }} steps: - name: Checkout repository uses: actions/checkout@v4 @@ -37,6 +38,9 @@ jobs: - 'tests/background_callback/**' - 'tests/async_tests/**' - 'requirements/**' + backend_paths: + - 'dash/backend/**' + - 'tests/backend/**' build: name: Build Dash Package @@ -271,6 +275,109 @@ jobs: cd bgtests pytest --headless --nopercyfinalize tests/async_tests -v -s + backend-tests: + name: Run Backend Callback Tests (Python ${{ matrix.python-version }}) + needs: [build, changes_filter] + if: | + (github.event_name == 'push' && (github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev')) || + needs.changes_filter.outputs.backend_cb_changed == 'true' + timeout-minutes: 30 + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.12"] + + services: + redis: + image: redis:6 + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + env: + REDIS_URL: redis://localhost:6379 + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + cache: 'npm' + + - name: Install Node.js dependencies + run: npm ci + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Download built Dash packages + uses: actions/download-artifact@v4 + with: + name: dash-packages + path: packages/ + + - name: Install Dash packages + run: | + python -m pip install --upgrade pip wheel + python -m pip install "setuptools<78.0.0" + python -m pip install "selenium==4.32.0" + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[async,ci,testing,dev,celery,diskcache,fastapi,quart]"' \; + + - name: Install Google Chrome + run: | + sudo apt-get update + sudo apt-get install -y google-chrome-stable + + - name: Install ChromeDriver + run: | + echo "Determining Chrome version..." + CHROME_BROWSER_VERSION=$(google-chrome --version) + echo "Installed Chrome Browser version: $CHROME_BROWSER_VERSION" + CHROME_MAJOR_VERSION=$(echo "$CHROME_BROWSER_VERSION" | cut -f 3 -d ' ' | cut -f 1 -d '.') + echo "Detected Chrome Major version: $CHROME_MAJOR_VERSION" + if [ "$CHROME_MAJOR_VERSION" -ge 115 ]; then + echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using CfT endpoint..." + CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://googlechromelabs.github.io/chrome-for-testing/LATEST_RELEASE_${CHROME_MAJOR_VERSION}") + if [ -z "$CHROMEDRIVER_VERSION_STRING" ]; then + echo "Could not automatically find ChromeDriver version for Chrome $CHROME_MAJOR_VERSION via LATEST_RELEASE. Please check CfT endpoints." + exit 1 + fi + CHROMEDRIVER_URL="https://edgedl.me.gvt1.com/edgedl/chrome/chrome-for-testing/${CHROMEDRIVER_VERSION_STRING}/linux64/chromedriver-linux64.zip" + else + echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using older method..." + CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://chromedriver.storage.googleapis.com/LATEST_RELEASE_${CHROME_MAJOR_VERSION}") + CHROMEDRIVER_URL="https://chromedriver.storage.googleapis.com/${CHROMEDRIVER_VERSION_STRING}/chromedriver_linux64.zip" + fi + echo "Using ChromeDriver version string: $CHROMEDRIVER_VERSION_STRING" + echo "Downloading ChromeDriver from: $CHROMEDRIVER_URL" + wget -q -O chromedriver.zip "$CHROMEDRIVER_URL" + unzip -o chromedriver.zip -d /tmp/ + sudo mv /tmp/chromedriver-linux64/chromedriver /usr/local/bin/chromedriver || sudo mv /tmp/chromedriver /usr/local/bin/chromedriver + sudo chmod +x /usr/local/bin/chromedriver + echo "/usr/local/bin" >> $GITHUB_PATH + shell: bash + + - name: Build/Setup test components + run: npm run setup-tests.py + + - name: Run Backend Callback Tests + run: | + mkdir bgtests + cp -r tests bgtests/tests + cd bgtests + touch __init__.py + pytest --headless --nopercyfinalize tests/backend_tests -v -s + table-unit: name: Table Unit/Lint Tests (Python ${{ matrix.python-version }}) needs: [build, changes_filter] diff --git a/dash/_callback.py b/dash/_callback.py index 6cc55b9162..4a714caeac 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,12 +1,8 @@ +from typing import Callable, Optional, Any, List, Tuple, Union +from functools import wraps import collections import hashlib -from functools import wraps - -from typing import Callable, Optional, Any, List, Tuple, Union - - import asyncio -from dash.backend import get_request_adapter from .dependencies import ( handle_callback_args, @@ -39,10 +35,11 @@ clean_property_name, ) -from . import _validate from .background_callback.managers import BaseBackgroundCallbackManager from ._callback_context import context_value from ._no_update import NoUpdate +from . import _validate +from . import backends async def _async_invoke_callback( @@ -176,7 +173,6 @@ def callback( Note that the endpoint will not appear in the list of registered callbacks in the Dash devtools. """ - background_spec = None config_prevent_initial_callbacks = _kwargs.pop( @@ -376,7 +372,8 @@ def _get_callback_manager( " and store results on redis.\n" ) - old_job = get_request_adapter().get_args().getlist("oldJob") + adapter = backends.request_adapter() + old_job = adapter.args.getlist("oldJob") if hasattr(adapter.args, "getlist") else [] if old_job: for job in old_job: @@ -390,6 +387,8 @@ def _setup_background_callback( ): """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) + if not callback_manager: + return to_json({"error": "No background callback manager configured"}) progress_outputs = background.get("progress") @@ -397,14 +396,11 @@ def _setup_background_callback( cache_key = callback_manager.build_cache_key( func, - # Inputs provided as dict is kwargs. func_args if func_args else func_kwargs, background.get("cache_args_to_ignore", []), None if cache_ignore_triggered else callback_ctx.get("triggered_inputs", []), ) - job_fn = callback_manager.func_registry.get(background_key) - ctx_value = AttributeDict(**context_value.get()) ctx_value.ignore_register_page = True ctx_value.pop("background_callback_manager") @@ -436,7 +432,8 @@ def _setup_background_callback( def _progress_background_callback(response, callback_manager, background): progress_outputs = background.get("progress") - cache_key = get_request_adapter().get_args().get("cacheKey") + adapter = backends.request_adapter() + cache_key = adapter.args.get("cacheKey") if progress_outputs: # Get the progress before the result as it would be erased after the results. @@ -453,8 +450,9 @@ def _update_background_callback( """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) - cache_key = get_request_adapter().get_args().get("cacheKey") - job_id = get_request_adapter().get_args().get("job") + adapter = backends.request_adapter() + cache_key = adapter.args.get("cacheKey") if adapter else None + job_id = adapter.args.get("job") if adapter else None _progress_background_callback(response, callback_manager, background) @@ -474,8 +472,9 @@ def _handle_rest_background_callback( multi, has_update=False, ): - cache_key = get_request_adapter().get_args().get("cacheKey") - job_id = get_request_adapter().get_args().get("job") + adapter = backends.request_adapter() + cache_key = adapter.args.get("cacheKey") if adapter else None + job_id = adapter.args.get("job") if adapter else None # Must get job_running after get_result since get_results terminates it. job_running = callback_manager.job_running(job_id) if not job_running and output_value is callback_manager.UNDEFINED: @@ -688,10 +687,11 @@ def add_context(*args, **kwargs): ) response: dict = {"multi": True} - jsonResponse = None + jsonResponse: Optional[str] = None try: if background is not None: - if not get_request_adapter().get_args().get("cacheKey"): + adapter = backends.request_adapter() + if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, background, @@ -762,7 +762,8 @@ async def async_add_context(*args, **kwargs): try: if background is not None: - if not get_request_adapter().get_args().get("cacheKey"): + adapter = backends.request_adapter() + if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, background, diff --git a/dash/_pages.py b/dash/_pages.py index acb26e8791..19a797bcf2 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -318,18 +318,22 @@ def register_page( ) page.update( supplied_title=title, - title=title - if title is not None - else CONFIG.title - if CONFIG.title != "Dash" - else page["name"], + title=( + title + if title is not None + else CONFIG.title + if CONFIG.title != "Dash" + else page["name"] + ), ) page.update( - description=description - if description - else CONFIG.description - if CONFIG.description - else "", + description=( + description + if description + else CONFIG.description + if CONFIG.description + else "" + ), order=order, supplied_order=order, supplied_layout=layout, @@ -390,15 +394,13 @@ def _path_to_page(path_id): def _page_meta_tags(app, request): - request_path = request.get_path() + request_path = request.path start_page, path_variables = _path_to_page(request_path.strip("/")) image = start_page.get("image", "") if image: image = app.get_asset_url(image) - assets_image_url = ( - "".join([request.get_root(), image.lstrip("/")]) if image else None - ) + assets_image_url = "".join([request.root, image.lstrip("/")]) if image else None supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url @@ -413,7 +415,7 @@ def _page_meta_tags(app, request): return [ {"name": "description", "content": description}, {"property": "twitter:card", "content": "summary_large_image"}, - {"property": "twitter:url", "content": request.get_url()}, + {"property": "twitter:url", "content": request.url}, {"property": "twitter:title", "content": title}, {"property": "twitter:description", "content": description}, {"property": "twitter:image", "content": image_url or ""}, diff --git a/dash/_validate.py b/dash/_validate.py index dea19d64c2..76661cef6b 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -8,6 +8,7 @@ from ._grouping import grouping_len, map_grouping from ._no_update import NoUpdate from .development.base_component import Component +from . import backends from . import exceptions from ._utils import ( patch_collections_abc, @@ -585,3 +586,41 @@ def _valid(out): return _valid(output) + + +def check_async(use_async): + if use_async is None: + try: + import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa + + use_async = True + except ImportError: + pass + elif use_async: + try: + import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa + except ImportError as exc: + raise Exception( + "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" + ) from exc + + +def check_backend(backend, inferred_backend): + if backend is not None: + if isinstance(backend, type): + # get_backend returns the backend class for a string + # So we compare the class names + expected_backend_cls, _ = backends.get_backend(inferred_backend) + if ( + backend.__module__ != expected_backend_cls.__module__ + or backend.__name__ != expected_backend_cls.__name__ + ): + raise ValueError( + f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." + ) + elif not isinstance(backend, str): + raise ValueError("Invalid backend argument") + elif backend.lower() != inferred_backend: + raise ValueError( + f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." + ) diff --git a/dash/backend/__init__.py b/dash/backend/__init__.py deleted file mode 100644 index eb1d47bc3f..0000000000 --- a/dash/backend/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# python -import contextvars -from .registry import get_backend # pylint: disable=unused-import - -__all__ = ["set_request_adapter", "get_request_adapter", "get_backend"] - -_request_adapter_var = contextvars.ContextVar("request_adapter") - - -def set_request_adapter(adapter): - _request_adapter_var.set(adapter) - - -def get_request_adapter(): - return _request_adapter_var.get() diff --git a/dash/backend/base_server.py b/dash/backend/base_server.py deleted file mode 100644 index 4855f86ad6..0000000000 --- a/dash/backend/base_server.py +++ /dev/null @@ -1,58 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class BaseDashServer(ABC): - def __call__(self, server, *args, **kwargs) -> Any: - # Default: WSGI - return server(*args, **kwargs) - - @abstractmethod - def create_app( - self, name: str = "__main__", config=None - ) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def register_assets_blueprint( - self, app, blueprint_name: str, assets_url_path: str, assets_folder: str - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def register_error_handlers(self, app) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def add_url_rule( - self, app, rule: str, view_func, endpoint=None, methods=None - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def before_request(self, app, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def after_request(self, app, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def run( - self, app, host: str, port: int, debug: bool, **kwargs - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def make_response( - self, data, mimetype=None, content_type=None - ) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def jsonify(self, obj) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def get_request_adapter(self) -> Any: # pragma: no cover - interface - pass diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py deleted file mode 100644 index 56f2761a3d..0000000000 --- a/dash/backend/fastapi.py +++ /dev/null @@ -1,374 +0,0 @@ -import sys -import mimetypes -import hashlib -import inspect -import pkgutil -from contextvars import copy_context -import importlib.util -import time - -try: - import uvicorn - from fastapi import FastAPI, Request, Response - from fastapi.responses import JSONResponse, PlainTextResponse - from fastapi.staticfiles import StaticFiles - from starlette.responses import Response as StarletteResponse - from starlette.datastructures import MutableHeaders - from pydantic import create_model - from typing import Any, Optional -except ImportError: - uvicorn = None - FastAPI = None - Request = None - Response = None - JSONResponse = None - PlainTextResponse = None - StaticFiles = None - StarletteResponse = None - MutableHeaders = None - create_model = None - Any = None - Optional = None - -from dash.fingerprint import check_fingerprint -from dash import _validate -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.backend import set_request_adapter -from .base_server import BaseDashServer - - -class FastAPIDashServer(BaseDashServer): - def __init__(self): - self.config = {} - super().__init__() - - def __call__(self, server, *args, **kwargs): - # ASGI: (scope, receive, send) - if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: - return server(*args, **kwargs) - raise TypeError("FastAPI app must be called with (scope, receive, send)") - - def create_app(self, name="__main__", config=None): - app = FastAPI() - if config: - for key, value in config.items(): - setattr(app.state, key, value) - return app - - def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder - ): - try: - app.mount( - assets_url_path, - StaticFiles(directory=assets_folder), - name=blueprint_name, - ) - except RuntimeError: - # directory doesnt exist - pass - - def register_error_handlers(self, app): - @app.exception_handler(PreventUpdate) - async def _handle_error(_request, _exc): - return Response(status_code=204) - - @app.exception_handler(InvalidResourceError) - async def _invalid_resources_handler(_request, exc): - return Response(content=exc.args[0], status_code=404) - - def register_prune_error_handler(self, app, secret, get_traceback_func): - @app.exception_handler(Exception) - async def _wrap_errors(_error_request, error): - tb = get_traceback_func(secret, error) - return PlainTextResponse(tb, status_code=500) - - def _html_response_wrapper(self, view_func): - async def wrapped(*_args, **_kwargs): - # If view_func is a function, call it; if it's a string, use it directly - html = view_func() if callable(view_func) else view_func - return Response(content=html, media_type="text/html") - - return wrapped - - def setup_index(self, dash_app): - async def index(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) - return Response(content=dash_app.index(), media_type="text/html") - - # pylint: disable=protected-access - dash_app._add_url("", index, methods=["GET"]) - - def setup_catchall(self, dash_app): - @dash_app.server.on_event("startup") - def _setup_catchall(): - dash_app.enable_dev_tools( - **self.config, first_run=False - ) # do this to make sure dev tools are enabled - - async def catchall(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) - return Response(content=dash_app.index(), media_type="text/html") - - # pylint: disable=protected-access - dash_app._add_url("{path:path}", catchall, methods=["GET"]) - - def add_url_rule( - self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False - ): - if rule == "": - rule = "/" - if isinstance(view_func, str): - # Wrap string or sync function to async FastAPI handler - view_func = self._html_response_wrapper(view_func) - app.add_api_route( - rule, - view_func, - methods=methods or ["GET"], - name=endpoint, - include_in_schema=include_in_schema, - ) - - def before_request(self, app, func): - # FastAPI does not have before_request, but we can use middleware - app.middleware("http")(self._make_before_middleware(func)) - - def after_request(self, app, func): - # FastAPI does not have after_request, but we can use middleware - app.middleware("http")(self._make_after_middleware(func)) - - def run(self, app, host, port, debug, **kwargs): - frame = inspect.stack()[2] - self.config = dict({"debug": debug} if debug else {}, **kwargs) - reload = debug - if reload: - # Dynamically determine the module name from the file path - file_path = frame.filename - module_name = importlib.util.spec_from_file_location("app", file_path).name - uvicorn.run( - f"{module_name}:app.server", - host=host, - port=port, - reload=reload, - **kwargs, - ) - else: - uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) - - def make_response(self, data, mimetype=None, content_type=None): - headers = {} - if mimetype: - headers["content-type"] = mimetype - if content_type: - headers["content-type"] = content_type - return Response(content=data, headers=headers) - - def jsonify(self, obj): - return JSONResponse(content=obj) - - def get_request_adapter(self): - return FastAPIRequestAdapter - - def _make_before_middleware(self, func): - async def middleware(request, call_next): - if func is not None: - if inspect.iscoroutinefunction(func): - await func() - else: - func() - response = await call_next(request) - return response - - return middleware - - def _make_after_middleware(self, func): - async def middleware(request, call_next): - response = await call_next(request) - if func is not None: - if inspect.iscoroutinefunction(func): - await func() - else: - func() - return response - - return middleware - - def serve_component_suites( - self, dash_app, package_name, fingerprinted_path, request - ): - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) - _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) - extension = "." + path_in_pkg.split(".")[-1] - mimetype = mimetypes.types_map.get(extension, "application/octet-stream") - package = sys.modules[package_name] - dash_app.logger.debug( - "serving -- package: %s[%s] resource: %s => location: %s", - package_name, - package.__version__, - path_in_pkg, - package.__path__, - ) - data = pkgutil.get_data(package_name, path_in_pkg) - headers = {} - if has_fingerprint: - headers["Cache-Control"] = "public, max-age=31536000" - return StarletteResponse(content=data, media_type=mimetype, headers=headers) - etag = hashlib.md5(data).hexdigest() if data else "" - headers["ETag"] = etag - if request.headers.get("if-none-match") == etag: - return StarletteResponse(status_code=304) - return StarletteResponse(content=data, media_type=mimetype, headers=headers) - - def setup_component_suites(self, dash_app): - async def serve(request: Request, package_name: str, fingerprinted_path: str): - return self.serve_component_suites( - dash_app, package_name, fingerprinted_path, request - ) - - # pylint: disable=protected-access - dash_app._add_url( - "_dash-component-suites/{package_name}/{fingerprinted_path:path}", - serve, - ) - - # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=False): - async def _dispatch(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) - # pylint: disable=protected-access - body = await request.json() - g = dash_app._initialize_context( - body, adapter - ) # pylint: disable=protected-access - func = dash_app._prepare_callback( - g, body - ) # pylint: disable=protected-access - args = dash_app._inputs_to_vals( - g.inputs_list + g.states_list - ) # pylint: disable=protected-access - ctx = copy_context() - partial_func = dash_app._execute_callback( - func, args, g.outputs_list, g - ) # pylint: disable=protected-access - response_data = ctx.run(partial_func) - if inspect.iscoroutine(response_data): - response_data = await response_data - # Instead of set_data, return a new Response - return Response(content=response_data, media_type="application/json") - - return _dispatch - - def _serve_default_favicon(self): - return Response( - content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" - ) - - def register_timing_hooks(self, app, first_run): - if not first_run: - return - - @app.middleware("http") - async def timing_middleware(request, call_next): - # Before request - request.state.timing_information = { - "__dash_server": {"dur": time.time(), "desc": None} - } - response = await call_next(request) - # After request - timing_information = getattr(request.state, "timing_information", None) - if timing_information is not None: - dash_total = timing_information.get("__dash_server", None) - if dash_total is not None: - dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) - headers = MutableHeaders(response.headers) - for name, info in timing_information.items(): - value = name - if info.get("desc") is not None: - value += f';desc="{info["desc"]}"' - if info.get("dur") is not None: - value += f";dur={info['dur']}" - headers.append("Server-Timing", value) - return response - - def register_callback_api_routes(self, app, callback_api_paths): - """ - Register callback API endpoints on the FastAPI app. - Each key in callback_api_paths is a route, each value is a handler (sync or async). - Dynamically creates a Pydantic model for the handler's parameters and uses it as the body parameter. - """ - for path, handler in callback_api_paths.items(): - endpoint = f"dash_callback_api_{path}" - route = path if path.startswith("/") else f"/{path}" - methods = ["POST"] - sig = inspect.signature(handler) - param_names = list(sig.parameters.keys()) - fields = {name: (Optional[Any], None) for name in param_names} - Model = create_model( - f"Payload_{endpoint}", **fields - ) # pylint: disable=cell-var-from-loop - - # pylint: disable=cell-var-from-loop - async def view_func(request: Request, body: Model): - kwargs = body.dict(exclude_unset=True) - if inspect.iscoroutinefunction(handler): - result = await handler(**kwargs) - else: - result = handler(**kwargs) - return JSONResponse(content=result) - - app.add_api_route( - route, - view_func, - methods=methods, - name=endpoint, - include_in_schema=True, - ) - - -class FastAPIRequestAdapter: - def __init__(self): - self._request = None - - def set_request(self, request: Request): - self._request = request - - def get_root(self): - return str(self._request.base_url) - - def get_args(self): - return self._request.query_params - - async def get_json(self): - return await self._request.json() - - def is_json(self): - return self._request.headers.get("content-type", "").startswith( - "application/json" - ) - - def get_cookies(self, _request=None): - return self._request.cookies - - def get_headers(self): - return self._request.headers - - def get_full_path(self): - return str(self._request.url) - - def get_url(self): - return str(self._request.url) - - def get_remote_addr(self): - return self._request.client.host if self._request.client else None - - def get_origin(self): - return self._request.headers.get("origin") - - def get_path(self): - return self._request.url.path diff --git a/dash/backend/flask.py b/dash/backend/flask.py deleted file mode 100644 index b48225a3c5..0000000000 --- a/dash/backend/flask.py +++ /dev/null @@ -1,278 +0,0 @@ -from contextvars import copy_context -import asyncio -import pkgutil -import sys -import mimetypes -import time -import inspect -import flask -from dash.fingerprint import check_fingerprint -from dash import _validate -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.backend import set_request_adapter -from .base_server import BaseDashServer - - -class FlaskDashServer(BaseDashServer): - def __call__(self, server, *args, **kwargs): - # Always WSGI - return server(*args, **kwargs) - - def create_app(self, name="__main__", config=None): - app = flask.Flask(name) - if config: - app.config.update(config) - return app - - def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder - ): - bp = flask.Blueprint( - blueprint_name, - __name__, - static_folder=assets_folder, - static_url_path=assets_url_path, - ) - app.register_blueprint(bp) - - def register_error_handlers(self, app): - @app.errorhandler(PreventUpdate) - def _handle_error(_): - return "", 204 - - @app.errorhandler(InvalidResourceError) - def _invalid_resources_handler(err): - return err.args[0], 404 - - def register_prune_error_handler(self, app, secret, get_traceback_func): - @app.errorhandler(Exception) - def _wrap_errors(error): - tb = get_traceback_func(secret, error) - return tb, 500 - - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - app.add_url_rule( - rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] - ) - - def before_request(self, app, func): - app.before_request(func) - - def after_request(self, app, func): - app.after_request(func) - - def run(self, app, host, port, debug, **kwargs): - app.run(host=host, port=port, debug=debug, **kwargs) - - def make_response(self, data, mimetype=None, content_type=None): - return flask.Response(data, mimetype=mimetype, content_type=content_type) - - def jsonify(self, obj): - return flask.jsonify(obj) - - def get_request_adapter(self): - return FlaskRequestAdapter - - def setup_catchall(self, dash_app): - def catchall(*args, **kwargs): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - return dash_app.index(*args, **kwargs) - - # pylint: disable=protected-access - dash_app._add_url("", catchall, methods=["GET"]) - - def setup_index(self, dash_app): - def index(*args, **kwargs): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - return dash_app.index(*args, **kwargs) - - # pylint: disable=protected-access - dash_app._add_url("", index, methods=["GET"]) - - def serve_component_suites(self, dash_app, package_name, fingerprinted_path): - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) - _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) - extension = "." + path_in_pkg.split(".")[-1] - mimetype = mimetypes.types_map.get(extension, "application/octet-stream") - package = sys.modules[package_name] - dash_app.logger.debug( - "serving -- package: %s[%s] resource: %s => location: %s", - package_name, - package.__version__, - path_in_pkg, - package.__path__, - ) - data = pkgutil.get_data(package_name, path_in_pkg) - response = flask.Response(data, mimetype=mimetype) - if has_fingerprint: - response.cache_control.max_age = 31536000 # 1 year - else: - response.add_etag() - tag = response.get_etag()[0] - request_etag = flask.request.headers.get("If-None-Match") - if f'"{tag}"' == request_etag: - response = flask.Response(None, status=304) - return response - - def setup_component_suites(self, dash_app): - def serve(package_name, fingerprinted_path): - return self.serve_component_suites( - dash_app, package_name, fingerprinted_path - ) - - # pylint: disable=protected-access - dash_app._add_url( - "_dash-component-suites//", - serve, - ) - - # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=False): - def _dispatch(): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - body = flask.request.get_json() - # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) - ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) - response_data = ctx.run(partial_func) - if asyncio.iscoroutine(response_data): - raise Exception( - "You are trying to use a coroutine without dash[async]. " - "Please install the dependencies via `pip install dash[async]` and ensure " - "that `use_async=False` is not being passed to the app." - ) - g.dash_response.set_data(response_data) - return g.dash_response - - async def _dispatch_async(): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - body = flask.request.get_json() - # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) - ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) - response_data = ctx.run(partial_func) - if asyncio.iscoroutine(response_data): - response_data = await response_data - g.dash_response.set_data(response_data) - return g.dash_response - - if use_async: - return _dispatch_async - return _dispatch - - def _serve_default_favicon(self): - - return flask.Response( - pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" - ) - - def register_timing_hooks(self, app, _first_run): - def _before_request(): - flask.g.timing_information = { - "__dash_server": {"dur": time.time(), "desc": None} - } - - def _after_request(response): - timing_information = flask.g.get("timing_information", None) - if timing_information is None: - return response - dash_total = timing_information.get("__dash_server", None) - if dash_total is not None: - dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) - for name, info in timing_information.items(): - value = name - if info.get("desc") is not None: - value += f';desc="{info["desc"]}"' - if info.get("dur") is not None: - value += f";dur={info['dur']}" - response.headers.add("Server-Timing", value) - return response - - self.before_request(app, _before_request) - self.after_request(app, _after_request) - - def register_callback_api_routes(self, app, callback_api_paths): - """ - Register callback API endpoints on the Flask app. - Each key in callback_api_paths is a route, each value is a handler (sync or async). - The view function parses the JSON body and passes it to the handler. - """ - for path, handler in callback_api_paths.items(): - endpoint = f"dash_callback_api_{path}" - route = path if path.startswith("/") else f"/{path}" - methods = ["POST"] - - if inspect.iscoroutinefunction(handler): - - async def view_func(*args, handler=handler, **kwargs): - data = flask.request.get_json() - result = await handler(**data) if data else await handler() - return flask.jsonify(result) - - else: - - def view_func(*args, handler=handler, **kwargs): - data = flask.request.get_json() - result = handler(**data) if data else handler() - return flask.jsonify(result) - - # Flask 2.x+ supports async views natively - app.add_url_rule( - route, endpoint=endpoint, view_func=view_func, methods=methods - ) - - -class FlaskRequestAdapter: - @staticmethod - def get_args(): - return flask.request.args - - @staticmethod - def get_root(): - return flask.request.url_root - - @staticmethod - def get_json(): - return flask.request.get_json() - - @staticmethod - def is_json(): - return flask.request.is_json - - @staticmethod - def get_cookies(): - return flask.request.cookies - - @staticmethod - def get_headers(): - return flask.request.headers - - @staticmethod - def get_url(): - return flask.request.url - - @staticmethod - def get_full_path(): - return flask.request.full_path - - @staticmethod - def get_remote_addr(): - return flask.request.remote_addr - - @staticmethod - def get_origin(): - return getattr(flask.request, "origin", None) - - @staticmethod - def get_path(): - return flask.request.path diff --git a/dash/backend/quart.py b/dash/backend/quart.py deleted file mode 100644 index c3d42dadee..0000000000 --- a/dash/backend/quart.py +++ /dev/null @@ -1,297 +0,0 @@ -import inspect -import pkgutil -import mimetypes -import sys -import time -from contextvars import copy_context - -try: - import quart - from quart import Quart, Response, jsonify, request, Blueprint -except ImportError: - quart = None - Quart = None - Response = None - jsonify = None - request = None - Blueprint = None -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.backend import set_request_adapter -from dash.fingerprint import check_fingerprint -from dash import _validate -from .base_server import BaseDashServer - - -class QuartDashServer(BaseDashServer): - """Quart implementation of the Dash server factory. - - All Quart/async specific imports are at the top-level (per user request) so - Quart must be installed when this module is imported. - """ - - def __init__(self) -> None: - self.config = {} - super().__init__() - - def __call__(self, server, *args, **kwargs): - return server(*args, **kwargs) - - def create_app(self, name="__main__", config=None): - app = Quart(name) - if config: - for key, value in config.items(): - app.config[key] = value - return app - - def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder - ): - bp = Blueprint( - blueprint_name, - __name__, - static_folder=assets_folder, - static_url_path=assets_url_path, - ) - app.register_blueprint(bp) - - def register_prune_error_handler(self, app, secret, get_traceback_func): - @app.errorhandler(Exception) - async def _wrap_errors(_error_request, error): - tb = get_traceback_func(secret, error) - return tb, 500 - - def register_timing_hooks(self, app, _first_run): # parity with Flask factory - @app.before_request - async def _before_request(): # pragma: no cover - timing infra - quart.g.timing_information = { - "__dash_server": {"dur": time.time(), "desc": None} - } - - @app.after_request - async def _after_request(response): # pragma: no cover - timing infra - timing_information = getattr(quart.g, "timing_information", None) - if timing_information is None: - return response - dash_total = timing_information.get("__dash_server", None) - if dash_total is not None: - dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) - for name, info in timing_information.items(): - value = name - if info.get("desc") is not None: - value += f';desc="{info["desc"]}"' - if info.get("dur") is not None: - value += f";dur={info['dur']}" - # Quart/Werkzeug headers expose 'add' (not 'append') - if hasattr(response.headers, "add"): - response.headers.add("Server-Timing", value) - else: # fallback just in case - response.headers["Server-Timing"] = value - return response - - def register_error_handlers(self, app): - @app.errorhandler(PreventUpdate) - async def _prevent_update(_): - return "", 204 - - @app.errorhandler(InvalidResourceError) - async def _invalid_resource(err): - return err.args[0], 404 - - def _html_response_wrapper(self, view_func): - async def wrapped(*_args, **_kwargs): - html_val = view_func() if callable(view_func) else view_func - if inspect.iscoroutine(html_val): # handle async function returning html - html_val = await html_val - html = str(html_val) - return Response(html, content_type="text/html") - - return wrapped - - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - app.add_url_rule( - rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] - ) - - def setup_index(self, dash_app): - async def index(*args, **kwargs): - adapter = QuartRequestAdapter() - set_request_adapter(adapter) - adapter.set_request() - return Response(dash_app.index(*args, **kwargs), content_type="text/html") - - # pylint: disable=protected-access - dash_app._add_url("", index, methods=["GET"]) - - def setup_catchall(self, dash_app): - async def catchall( - path, *args, **kwargs - ): # noqa: ARG001 - path is unused but kept for route signature, pylint: disable=unused-argument - adapter = QuartRequestAdapter() - set_request_adapter(adapter) - adapter.set_request() - return Response(dash_app.index(*args, **kwargs), content_type="text/html") - - # pylint: disable=protected-access - dash_app._add_url("", catchall, methods=["GET"]) - - def before_request(self, app, func): - app.before_request(func) - - def after_request(self, app, func): - @app.after_request - async def _after(response): - if func is not None: - result = func() - if inspect.iscoroutine(result): # Allow async hooks - await result - return response - - def run(self, app, host, port, debug, **kwargs): - self.config = {"debug": debug, **kwargs} if debug else kwargs - app.run(host=host, port=port, debug=debug, **kwargs) - - def make_response(self, data, mimetype=None, content_type=None): - return Response(data, mimetype=mimetype, content_type=content_type) - - def jsonify(self, obj): - return jsonify(obj) - - def get_request_adapter(self): - return QuartRequestAdapter - - def serve_component_suites( - self, dash_app, package_name, fingerprinted_path - ): # noqa: ARG002 unused req preserved for interface parity - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) - _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) - extension = "." + path_in_pkg.split(".")[-1] - mimetype = mimetypes.types_map.get(extension, "application/octet-stream") - package = sys.modules[package_name] - dash_app.logger.debug( - "serving -- package: %s[%s] resource: %s => location: %s", - package_name, - getattr(package, "__version__", "unknown"), - path_in_pkg, - package.__path__, - ) - data = pkgutil.get_data(package_name, path_in_pkg) - headers = {} - if has_fingerprint: - headers["Cache-Control"] = "public, max-age=31536000" - - return Response(data, content_type=mimetype, headers=headers) - - def setup_component_suites(self, dash_app): - async def serve(package_name, fingerprinted_path): - return self.serve_component_suites( - dash_app, package_name, fingerprinted_path - ) - - # pylint: disable=protected-access - dash_app._add_url( - "_dash-component-suites//", - serve, - ) - - # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=True): # Quart always async - async def _dispatch(): - adapter = QuartRequestAdapter() - set_request_adapter(adapter) - adapter.set_request() - body = await request.get_json() - # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) - # pylint: disable=protected-access - func = dash_app._prepare_callback(g, body) - # pylint: disable=protected-access - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) - ctx = copy_context() - # pylint: disable=protected-access - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) - response_data = ctx.run(partial_func) - if inspect.iscoroutine(response_data): # if user callback is async - response_data = await response_data - return Response(response_data, content_type="application/json") - - return _dispatch - - def register_callback_api_routes(self, app, callback_api_paths): - """ - Register callback API endpoints on the Quart app. - Each key in callback_api_paths is a route, each value is a handler (sync or async). - The view function parses the JSON body and passes it to the handler. - """ - for path, handler in callback_api_paths.items(): - endpoint = f"dash_callback_api_{path}" - route = path if path.startswith("/") else f"/{path}" - methods = ["POST"] - - def _make_view_func(handler): - if inspect.iscoroutinefunction(handler): - - async def async_view_func(*args, **kwargs): - data = await request.get_json() - result = await handler(**data) if data else await handler() - return jsonify(result) - - return async_view_func - - async def sync_view_func(*args, **kwargs): - data = await request.get_json() - result = handler(**data) if data else handler() - return jsonify(result) - - return sync_view_func - - view_func = _make_view_func(handler) - app.add_url_rule( - route, endpoint=endpoint, view_func=view_func, methods=methods - ) - - def _serve_default_favicon(self): - return Response( - pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" - ) - - -class QuartRequestAdapter: - def __init__(self) -> None: - self._request = None - - def set_request(self) -> None: - self._request = request - - # Accessors (instance-based) - def get_root(self): - return self._request.root_url - - def get_args(self): - return self._request.args - - async def get_json(self): - return await self._request.get_json() - - def is_json(self): - return self._request.is_json - - def get_cookies(self): - return self._request.cookies - - def get_headers(self): - return self._request.headers - - def get_full_path(self): - return self._request.full_path - - def get_url(self): - return str(self._request.url) - - def get_remote_addr(self): - return self._request.remote_addr - - def get_origin(self): - return self._request.headers.get("origin") - - def get_path(self): - return self._request.path diff --git a/dash/backend/registry.py b/dash/backend/registry.py deleted file mode 100644 index 4aae9fafc5..0000000000 --- a/dash/backend/registry.py +++ /dev/null @@ -1,29 +0,0 @@ -import importlib - -_backend_imports = { - "flask": ("dash.backend.flask", "FlaskDashServer"), - "fastapi": ("dash.backend.fastapi", "FastAPIDashServer"), - "quart": ("dash.backend.quart", "QuartDashServer"), -} - - -def register_backend(name, module_path, class_name): - """Register a new backend by name.""" - _backend_imports[name.lower()] = (module_path, class_name) - - -def get_backend(name): - try: - module_name, class_name = _backend_imports[name.lower()] - module = importlib.import_module(module_name) - return getattr(module, class_name) - except KeyError as e: - raise ValueError(f"Unknown backend: {name}") from e - except ImportError as e: - raise ImportError( - f"Could not import module '{module_name}' for backend '{name}': {e}" - ) from e - except AttributeError as e: - raise AttributeError( - f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}" - ) from e diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py new file mode 100644 index 0000000000..940c8f18bd --- /dev/null +++ b/dash/backends/__init__.py @@ -0,0 +1,88 @@ +from .base_server import BaseDashServer, RequestAdapter + +from typing import Literal, Any +import importlib + + +request_adapter: RequestAdapter +backend: BaseDashServer + + +_backend_imports = { + "flask": ("dash.backends._flask", "FlaskDashServer", "FlaskRequestAdapter"), + "fastapi": ("dash.backends._fastapi", "FastAPIDashServer", "FastAPIRequestAdapter"), + "quart": ("dash.backends._quart", "QuartDashServer", "QuartRequestAdapter"), +} + + +request_adapter: RequestAdapter +backend: BaseDashServer + + +def get_backend( + name: Literal["flask", "fastapi", "quart"] | str +) -> tuple[BaseDashServer, RequestAdapter]: + module_name, server_class, request_class = _backend_imports[name.lower()] + try: + module = importlib.import_module(module_name) + server = getattr(module, server_class) + request_adapter = getattr(module, request_class) + return server, request_adapter + except KeyError as e: + raise ValueError(f"Unknown backend: {name}") from e + except ImportError as e: + raise ImportError( + f"Could not import module '{module_name}' for backend '{name}': {e}" + ) from e + except AttributeError as e: + raise AttributeError( + f"Module '{module_name}' does not have class '{server_class}' for backend '{name}': {e}" + ) from e + + +def _is_flask_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from flask import Flask + + return isinstance(obj, Flask) + except ImportError: + return False + + +def _is_fastapi_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from fastapi import FastAPI + + return isinstance(obj, FastAPI) + except ImportError: + return False + + +def _is_quart_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from quart import Quart + + return isinstance(obj, Quart) + except ImportError: + return False + + +def get_server_type(server): + if _is_flask_instance(server): + return "flask" + if _is_quart_instance(server): + return "quart" + if _is_fastapi_instance(server): + return "fastapi" + raise ValueError("Invalid backend argument") + + +__all__ = [ + "get_backend", + "request_adapter", + "backend", + "get_server_type", +] diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py new file mode 100644 index 0000000000..f3f9f2df33 --- /dev/null +++ b/dash/backends/_fastapi.py @@ -0,0 +1,559 @@ +from __future__ import annotations + +from contextvars import copy_context, ContextVar +from typing import TYPE_CHECKING, Any, Callable, Dict +import sys +import mimetypes +import hashlib +import inspect +import pkgutil +import time +import traceback +from importlib.util import spec_from_file_location +import json +import os +import re + +from dash.fingerprint import check_fingerprint +from dash import _validate +from dash.exceptions import PreventUpdate +from .base_server import BaseDashServer, RequestAdapter + +from fastapi import FastAPI, Request, Response, Body +from fastapi.responses import JSONResponse +from fastapi.staticfiles import StaticFiles +from starlette.responses import Response as StarletteResponse +from starlette.datastructures import MutableHeaders +from starlette.types import ASGIApp, Scope, Receive, Send +import uvicorn + +if TYPE_CHECKING: # pragma: no cover - typing only + from dash.dash import Dash + + +_current_request_var = ContextVar("dash_current_request", default=None) + + +def set_current_request(req): + return _current_request_var.set(req) + + +def reset_current_request(token): + _current_request_var.reset(token) + + +def get_current_request() -> Request: + req = _current_request_var.get() + if req is None: + raise RuntimeError("No active request in context") + return req + + +class CurrentRequestMiddleware: + def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # type: ignore[name-defined] + # non-http/ws scopes pass through (lifespan etc.) + if scope["type"] not in ("http", "websocket"): + await self.app(scope, receive, send) + return + + request = Request(scope, receive=receive) + token = set_current_request(request) + try: + await self.app(scope, receive, send) + finally: + reset_current_request(token) + + +CONFIG_PATH = "dash_config.json" + + +def save_config(config): + with open(CONFIG_PATH, "w") as f: + json.dump(config, f) + + +def load_config(): + if os.path.exists(CONFIG_PATH): + with open(CONFIG_PATH, "r") as f: + return json.load(f) + return {} + + +class FastAPIDashServer(BaseDashServer): + + def __init__(self, server: FastAPI): + self.config = {} + self.server_type = "fastapi" + self.server: FastAPI = server + self.error_handling_mode = "prune" + super().__init__() + + def __call__(self, *args: Any, **kwargs: Any): + # ASGI: (scope, receive, send) + if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: + return self.server(*args, **kwargs) + raise TypeError("FastAPI app must be called with (scope, receive, send)") + + @staticmethod + def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): + app = FastAPI() + app.add_middleware(CurrentRequestMiddleware) + + if config: + for key, value in config.items(): + setattr(app.state, key, value) + return app + + def register_assets_blueprint( + self, blueprint_name: str, assets_url_path: str, assets_folder: str + ): + try: + self.server.mount( + assets_url_path, + StaticFiles(directory=assets_folder), + name=blueprint_name, + ) + except RuntimeError: + # directory doesnt exist + pass + + def register_error_handlers(self): + self.error_handling_mode = "prune" + + def _get_traceback(self, _secret, error: Exception): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if self.error_handling_mode == "prune": + if not callback_handled: + if "callback invoked" in str(err) and "_callback.py" in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + + # Parse traceback lines to group by file + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split("\n") + current_file = None + card_lines = [] + + for line in lines[:-1]: # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + + cards_html = "" + for filename, card in file_cards: + cards_html += ( + f""" +
+
{filename}
+
"""
+                + "\n".join(card)
+                + """
+
+ """ + ) + + html = f""" + + + + {error_type}: {error_msg} // FastAPI Debugger + + + +
+

{error_type}

+
+

{error_type}: {error_msg}

+
+

Traceback (most recent call last)

+ {cards_html} +
{error_type}: {error_msg}
+
+

This is the Copy/Paste friendly version of the traceback.

+ +
+
+ The debugger caught an exception in your ASGI application. You can now + look at the traceback which led to the error. +
+ +
+ + + """ + return html + + def register_prune_error_handler(self, _secret, prune_errors): + if prune_errors: + self.error_handling_mode = "prune" + else: + self.error_handling_mode = "raise" + + def _html_response_wrapper(self, view_func: Callable[..., Any] | str): + async def wrapped(*_args, **_kwargs): + # If view_func is a function, call it; if it's a string, use it directly + html = view_func() if callable(view_func) else view_func + return Response(content=html, media_type="text/html") + + return wrapped + + def setup_index(self, dash_app: Dash): + async def index(request: Request): + return Response(content=dash_app.index(), media_type="text/html") + + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) + + def setup_catchall(self, dash_app: Dash): + @self.server.on_event("startup") + def _setup_catchall(): + dash_app.enable_dev_tools( + **self.config, first_run=False + ) # do this to make sure dev tools are enabled + + async def catchall(request: Request): + return Response(content=dash_app.index(), media_type="text/html") + + # pylint: disable=protected-access + dash_app._add_url("{path:path}", catchall, methods=["GET"]) + + def add_url_rule( + self, + rule: str, + view_func: Callable[..., Any] | str, + endpoint: str | None = None, + methods: list[str] | None = None, + include_in_schema: bool = False, + ): + if rule == "": + rule = "/" + if isinstance(view_func, str): + # Wrap string or sync function to async FastAPI handler + view_func = self._html_response_wrapper(view_func) + self.server.add_api_route( + rule, + view_func, + methods=methods or ["GET"], + name=endpoint, + include_in_schema=include_in_schema, + ) + + def before_request(self, func: Callable[[], Any] | None): + # FastAPI does not have before_request, but we can use middleware + self.server.middleware("http")(self._make_before_middleware(func)) + + def after_request(self, func: Callable[[], Any] | None): + # FastAPI does not have after_request, but we can use middleware + self.server.middleware("http")(self._make_after_middleware(func)) + + def run(self, dash_app: Dash, host, port, debug, **kwargs): + frame = inspect.stack()[2] + config = dict( + {"debug": debug} if debug else {}, + **{ + f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items() + }, # pylint: disable=protected-access + ) + save_config(config) + if debug: + if kwargs.get("reload") is None: + kwargs["reload"] = True + if kwargs.get("reload"): + # Dynamically determine the module name from the file path + file_path = frame.filename + spec = spec_from_file_location("app", file_path) + module_name = spec.name if spec and getattr(spec, "name", None) else "app" + uvicorn.run( + f"{module_name}:app.server", + host=host, + port=port, + **kwargs, + ) + else: + uvicorn.run(self.server, host=host, port=port, **kwargs) + + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + ): + headers = {} + if mimetype: + headers["content-type"] = mimetype + if content_type: + headers["content-type"] = content_type + return Response(content=data, headers=headers) + + def jsonify(self, obj: Any): + return JSONResponse(content=obj) + + def _make_before_middleware(self, func: Callable[[], Any] | None): + async def middleware(request, call_next): + try: + response = await call_next(request) + return response + except PreventUpdate: + # No content, nothing to update + return Response(status_code=204) + except Exception as e: + if self.error_handling_mode in ["raise", "prune"]: + # Prune the traceback to remove internal Dash calls + tb = self._get_traceback(None, e) + return Response(content=tb, media_type="text/html", status_code=500) + return JSONResponse( + status_code=500, + content={"error": "InternalServerError", "message": str(e.args[0])}, + ) + + return middleware + + def _make_after_middleware(self, func: Callable[[], Any] | None): + async def middleware(request, call_next): + response = await call_next(request) + if func is not None: + if inspect.iscoroutinefunction(func): + await func() + else: + func() + return response + + return middleware + + def serve_component_suites( + self, + dash_app: Dash, + package_name: str, + fingerprinted_path: str, + request: Request, + ): + + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if request.headers.get("if-none-match") == etag: + return StarletteResponse(status_code=304) + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + + def setup_component_suites(self, dash_app: Dash): + async def serve(request: Request, package_name: str, fingerprinted_path: str): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, request + ) + + # pylint: disable=protected-access + dash_app._add_url( + "_dash-component-suites/{package_name}/{fingerprinted_path:path}", + serve, + ) + + # pylint: disable=unused-argument + def dispatch(self, dash_app: Dash): + + async def _dispatch(request: Request): + # pylint: disable=protected-access + body = await request.json() + g = dash_app._initialize_context(body) # pylint: disable=protected-access + func = dash_app._prepare_callback( + g, body + ) # pylint: disable=protected-access + args = dash_app._inputs_to_vals( + g.inputs_list + g.states_list + ) # pylint: disable=protected-access + ctx = copy_context() + partial_func = dash_app._execute_callback( + func, args, g.outputs_list, g + ) # pylint: disable=protected-access + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): + response_data = await response_data + # Instead of set_data, return a new Response + return Response(content=response_data, media_type="application/json") + + return _dispatch + + def _serve_default_favicon(self): + return Response( + content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" + ) + + def register_timing_hooks(self, first_run: bool): + if not first_run: + return + + @self.server.middleware("http") + async def timing_middleware(request: Request, call_next): + # Before request + request.state.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } + response = await call_next(request) + # After request + timing_information = getattr(request.state, "timing_information", None) + if timing_information is not None: + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + headers = MutableHeaders(response.headers) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + headers.append("Server-Timing", value) + return response + + def register_callback_api_routes(self, callback_api_paths: Dict[str, Callable[..., Any]]): + """ + Register callback API endpoints on the FastAPI app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + Accepts a JSON body (dict) and filters keys based on the handler's signature. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + sig = inspect.signature(handler) + param_names = list(sig.parameters.keys()) + + async def view_func(request: Request, body: dict = Body(...)): + # Only pass expected params; ignore extras + kwargs = { + k: v for k, v in body.items() if k in param_names and v is not None + } + if inspect.iscoroutinefunction(handler): + result = await handler(**kwargs) + else: + result = handler(**kwargs) + return JSONResponse(content=result) + + self.server.add_api_route( + route, + view_func, + methods=methods, + name=endpoint, + include_in_schema=True, + ) + + +class FastAPIRequestAdapter(RequestAdapter): + + def __init__(self): + self._request: Request = get_current_request() + super().__init__() + + def __call__(self): + self._request = get_current_request() + return self + + @property + def root(self): + return str(self._request.base_url) + + @property + def args(self): + return self._request.query_params + + @property + def is_json(self): + return self._request.headers.get("content-type", "").startswith( + "application/json" + ) + + @property + def cookies(self): + return self._request.cookies + + @property + def headers(self): + return self._request.headers + + @property + def full_path(self): + return str(self._request.url) + + @property + def url(self): + return str(self._request.url) + + @property + def remote_addr(self): + client = getattr(self._request, "client", None) + return getattr(client, "host", None) + + @property + def origin(self): + return self._request.headers.get("origin") + + @property + def path(self): + return self._request.url.path + + async def get_json(self): # async method retained + return await self._request.json() diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py new file mode 100644 index 0000000000..5a1385d574 --- /dev/null +++ b/dash/backends/_flask.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +from contextvars import copy_context +from typing import TYPE_CHECKING, Any, Callable, Dict +import asyncio +import pkgutil +import sys +import mimetypes +import time +import inspect +import traceback +from flask import ( + Flask, + Blueprint, + Response, + request, + jsonify, + g as flask_g, +) + +from dash.fingerprint import check_fingerprint +from dash import _validate +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash._callback import _invoke_callback, _async_invoke_callback +from .base_server import BaseDashServer, RequestAdapter + +if TYPE_CHECKING: # pragma: no cover - typing only + from dash import Dash + + +class FlaskDashServer(BaseDashServer): + + def __init__(self, server: Flask) -> None: + self.server: Flask = server + self.server_type = "flask" + super().__init__() + + def __call__(self, *args: Any, **kwargs: Any): + # Always WSGI + return self.server(*args, **kwargs) + + @staticmethod + def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): + app = Flask(name) + if config: + app.config.update(config) + return app + + def register_assets_blueprint( + self, blueprint_name: str, assets_url_path: str, assets_folder: str + ): + bp = Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + self.server.register_blueprint(bp) + + def register_error_handlers(self): + @self.server.errorhandler(PreventUpdate) + def _handle_error(_): + return "", 204 + + @self.server.errorhandler(InvalidResourceError) + def _invalid_resources_handler(err): + return err.args[0], 404 + + def _get_traceback(self, secret, error: Exception): + try: + from werkzeug.debug import ( + tbtools, + ) # pylint: disable=import-outside-toplevel + except ImportError: + tbtools = None + + def _get_skip(error): + tb = error.__traceback__ + skip = 1 + while tb.tb_next is not None: + skip += 1 + tb = tb.tb_next + if tb.tb_frame.f_code in [ + _invoke_callback.__code__, + _async_invoke_callback.__code__, + ]: + return skip + return skip + + def _do_skip(error): + tb = error.__traceback__ + while tb.tb_next is not None: + if tb.tb_frame.f_code in [ + _invoke_callback.__code__, + _async_invoke_callback.__code__, + ]: + return tb.tb_next + tb = tb.tb_next + return error.__traceback__ + + if hasattr(tbtools, "get_current_traceback"): + return tbtools.get_current_traceback(skip=_get_skip(error)).render_full() + if hasattr(tbtools, "DebugTraceback"): + return tbtools.DebugTraceback( + error, skip=_get_skip(error) + ).render_debugger_html(True, secret, True) + return "".join(traceback.format_exception(type(error), error, _do_skip(error))) + + def register_prune_error_handler(self, secret, prune_errors): + if prune_errors: + + @self.server.errorhandler(Exception) + def _wrap_errors(error): + tb = self._get_traceback(secret, error) + return tb, 500 + + def add_url_rule( + self, + rule: str, + view_func: Callable[..., Any], + endpoint: str | None = None, + methods: list[str] | None = None, + ): + self.server.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) + + def before_request(self, func: Callable[[], Any]): + # Flask expects a callable; user responsibility not to pass None + self.server.before_request(func) + + def after_request(self, func: Callable[[Any], Any]): + # Flask after_request expects a function(response) -> response + self.server.after_request(func) + + def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: Any): + self.server.run(host=host, port=port, debug=debug, **kwargs) + + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + ): + return Response(data, mimetype=mimetype, content_type=content_type) + + def jsonify(self, obj: Any): + return jsonify(obj) + + def setup_catchall(self, dash_app: Dash): + def catchall(*args, **kwargs): + return dash_app.index(*args, **kwargs) + + # pylint: disable=protected-access + dash_app._add_url("", catchall, methods=["GET"]) + + def setup_index(self, dash_app: Dash): + def index(*args, **kwargs): + return dash_app.index(*args, **kwargs) + + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) + + def serve_component_suites( + self, dash_app: Dash, package_name: str, fingerprinted_path: str + ): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + response = Response(data, mimetype=mimetype) + if has_fingerprint: + response.cache_control.max_age = 31536000 # 1 year + else: + response.add_etag() + tag = response.get_etag()[0] + request_etag = request.headers.get("If-None-Match") + if f'"{tag}"' == request_etag: + response = Response(None, status=304) + return response + + def setup_component_suites(self, dash_app: Dash): + def serve(package_name, fingerprinted_path): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path + ) + + # pylint: disable=protected-access + dash_app._add_url( + "_dash-component-suites//", + serve, + ) + + # pylint: disable=unused-argument + def dispatch(self, dash_app: Dash): + def _dispatch(): + body = request.get_json() + # pylint: disable=protected-access + cb_ctx = dash_app._initialize_context(body) + func = dash_app._prepare_callback(cb_ctx, body) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + raise Exception( + "You are trying to use a coroutine without dash[async]. " + "Please install the dependencies via `pip install dash[async]` and ensure " + "that `use_async=False` is not being passed to the app." + ) + cb_ctx.dash_response.set_data(response_data) + return cb_ctx.dash_response + + async def _dispatch_async(): + body = request.get_json() + # pylint: disable=protected-access + cb_ctx = dash_app._initialize_context(body) + func = dash_app._prepare_callback(cb_ctx, body) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + response_data = await response_data + cb_ctx.dash_response.set_data(response_data) + return cb_ctx.dash_response + + if dash_app._use_async: + return _dispatch_async + return _dispatch + + def _serve_default_favicon(self): + return Response( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" + ) + + def register_timing_hooks(self, _first_run: bool): + # Define timing hooks inside method scope and register them + def _before_request() -> None: + flask_g.timing_information = { # type: ignore[attr-defined] + "__dash_server": {"dur": time.time(), "desc": None} + } + + def _after_request(response: Response): # type: ignore[name-defined] + timing_information = flask_g.get("timing_information", None) # type: ignore[attr-defined] + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + response.headers.add("Server-Timing", value) + return response + + self.before_request(_before_request) + self.after_request(_after_request) + + def register_callback_api_routes(self, callback_api_paths: Dict[str, Callable[..., Any]]): + """ + Register callback API endpoints on the Flask app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + The view function parses the JSON body and passes it to the handler. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + + if inspect.iscoroutinefunction(handler): + + async def _async_view_func(*args, handler=handler, **kwargs): + data = request.get_json() + result = await handler(**data) if data else await handler() + return jsonify(result) + + view_func = _async_view_func + else: + + def _sync_view_func(*args, handler=handler, **kwargs): + data = request.get_json() + result = handler(**data) if data else handler() + return jsonify(result) + + view_func = _sync_view_func + + view_func = _sync_view_func + + # Flask 2.x+ supports async views natively + self.server.add_url_rule( + route, endpoint=endpoint, view_func=view_func, methods=methods + ) + + +class FlaskRequestAdapter(RequestAdapter): + """Flask implementation using property-based accessors.""" + + def __init__(self) -> None: + # Store the request LocalProxy so we can reference it consistently + self._request = request + super().__init__() + + def __call__(self, *args: Any, **kwds: Any): + return self + + @property + def args(self): + return self._request.args + + @property + def root(self): + return self._request.url_root + + def get_json(self): # kept as method + return self._request.get_json() + + @property + def is_json(self): + return self._request.is_json + + @property + def cookies(self): + return self._request.cookies + + @property + def headers(self): + return self._request.headers + + @property + def url(self): + return self._request.url + + @property + def full_path(self): + return self._request.full_path + + @property + def remote_addr(self): + return self._request.remote_addr + + @property + def origin(self): + return getattr(self._request, "origin", None) + + @property + def path(self): + return self._request.path diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py new file mode 100644 index 0000000000..a462d07af6 --- /dev/null +++ b/dash/backends/_quart.py @@ -0,0 +1,452 @@ +from __future__ import annotations +from contextvars import copy_context +import typing as _t +import traceback +import mimetypes +import inspect +import pkgutil +import time +import sys +import re + +# Attempt top-level Quart imports; allow absence if user not using quart backend +from quart import ( + Quart, + Response, + jsonify, + request, + Blueprint, + g, +) + +if _t.TYPE_CHECKING: + from dash import Dash + +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.fingerprint import check_fingerprint +from dash import _validate +from .base_server import BaseDashServer + + +class QuartDashServer(BaseDashServer): + + def __init__(self, server: Quart) -> None: + self.server_type = "quart" + self.server: Quart = server + self.config = {} + self.error_handling_mode = "prune" + super().__init__() + + def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] + return self.server(*args, **kwargs) + + @staticmethod + def create_app(name: str = "__main__", config: _t.Optional[_t.Dict[str, _t.Any]] = None): + if Quart is None: + raise RuntimeError( + "Quart is not installed. Install with 'pip install quart' to use the quart backend." + ) + app = Quart(name) # type: ignore + if config: + for key, value in config.items(): + app.config[key] = value + return app + + def register_assets_blueprint( + self, blueprint_name: str, assets_url_path: str, assets_folder: str # type: ignore[name-defined] + ): + + bp = Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + self.server.register_blueprint(bp) + + def _get_traceback(self, _secret, error: Exception): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if self.error_handling_mode == "prune": + if not callback_handled: + if "callback invoked" in str(err) and "_callback.py" in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + + # Parse traceback lines to group by file + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split("\n") + current_file = None + card_lines = [] + + for line in lines[:-1]: # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + + cards_html = "" + for filename, card in file_cards: + cards_html += ( + f""" +
+
{filename}
+
"""
+                + "\n".join(card)
+                + """
+
+ """ + ) + + html = f""" + + + + {error_type}: {error_msg} // Quart Debugger + + + +
+

{error_type}

+
+

{error_type}: {error_msg}

+
+

Traceback (most recent call last)

+ {cards_html} +
{error_type}: {error_msg}
+
+

This is the Copy/Paste friendly version of the traceback.

+ +
+
+ The debugger caught an exception in your ASGI application. You can now + look at the traceback which led to the error. +
+ +
+ + + """ + return html + + def register_prune_error_handler(self, secret, prune_errors): + if prune_errors: + self.error_handling_mode = "prune" + else: + self.error_handling_mode = "raise" + + @self.server.errorhandler(Exception) + async def _wrap_errors(error): + tb = self._get_traceback(secret, error) + return Response(tb, status=500, content_type="text/html") + + def register_timing_hooks(self, _first_run: bool): # type: ignore[name-defined] parity with Flask factory + @self.server.before_request + async def _before_request(): # pragma: no cover - timing infra + if g is not None: + g.timing_information = { # type: ignore[attr-defined] + "__dash_server": {"dur": time.time(), "desc": None} + } + + @self.server.after_request + async def _after_request(response): # pragma: no cover - timing infra + timing_information = ( + getattr(g, "timing_information", None) if g is not None else None + ) + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + # Quart/Werkzeug headers expose 'add' (not 'append') + if hasattr(response.headers, "add"): + response.headers.add("Server-Timing", value) + else: # fallback just in case + response.headers["Server-Timing"] = value + return response + + def register_error_handlers(self): # type: ignore[name-defined] + @self.server.errorhandler(PreventUpdate) + async def _prevent_update(_): + return "", 204 + + @self.server.errorhandler(InvalidResourceError) + async def _invalid_resource(err): + return err.args[0], 404 + + def _html_response_wrapper(self, view_func: _t.Callable[..., _t.Any] | str): + + async def wrapped(*_args, **_kwargs): + html_val = view_func() if callable(view_func) else view_func + if inspect.iscoroutine(html_val): # handle async function returning html + html_val = await html_val + html = str(html_val) + return Response(html, content_type="text/html") + + return wrapped + + def add_url_rule( + self, + rule: str, + view_func: _t.Callable[..., _t.Any], + endpoint: str | None = None, + methods: list[str] | None = None, + ): + self.server.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) + + def setup_index(self, dash_app: Dash): # type: ignore[name-defined] + + async def index(*args, **kwargs): + return Response(dash_app.index(*args, **kwargs), content_type="text/html") # type: ignore[arg-type] + + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) + + def setup_catchall(self, dash_app: Dash): + + async def catchall( + path: str, *args, **kwargs + ): # noqa: ARG001 - path is unused but kept for route signature, pylint: disable=unused-argument + return Response(dash_app.index(*args, **kwargs), content_type="text/html") # type: ignore[arg-type] + + # pylint: disable=protected-access + dash_app._add_url("", catchall, methods=["GET"]) + + def before_request(self, func: _t.Callable[[], _t.Any]): + self.server.before_request(func) + + def after_request(self, func: _t.Callable[[], _t.Any]): + @self.server.after_request + async def _after(response): + if func is not None: + result = func() + if inspect.iscoroutine(result): # Allow async hooks + await result + return response + + def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): + self.config = {"debug": debug, **kwargs} if debug else kwargs + self.server.run(host=host, port=port, debug=debug, **kwargs) + + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + ): + if Response is None: + raise RuntimeError("Quart not installed; cannot generate Response") + return Response(data, mimetype=mimetype, content_type=content_type) + + def jsonify(self, obj): + return jsonify(obj) + + def serve_component_suites( + self, dash_app: Dash, package_name: str, fingerprinted_path: str + ): # noqa: ARG002 unused req preserved for interface parity + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + getattr(package, "__version__", "unknown"), + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + + if Response is None: + raise RuntimeError("Quart not installed; cannot generate Response") + return Response(data, content_type=mimetype, headers=headers) + + def setup_component_suites(self, dash_app: Dash): + async def serve(package_name, fingerprinted_path): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path + ) + + # pylint: disable=protected-access + dash_app._add_url( + "_dash-component-suites//", + serve, + ) + + # pylint: disable=unused-argument + def dispatch(self, dash_app: Dash): # type: ignore[name-defined] Quart always async + + async def _dispatch(): + adapter = QuartRequestAdapter() + body = await adapter.get_json() + # pylint: disable=protected-access + g = dash_app._initialize_context(body) + # pylint: disable=protected-access + func = dash_app._prepare_callback(g, body) + # pylint: disable=protected-access + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + # pylint: disable=protected-access + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): # if user callback is async + response_data = await response_data + return Response(response_data, content_type="application/json") # type: ignore[arg-type] + + return _dispatch + + def register_callback_api_routes(self, callback_api_paths: _t.Dict[str, _t.Callable[..., _t.Any]]): + """ + Register callback API endpoints on the Quart app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + The view function parses the JSON body and passes it to the handler. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + + def _make_view_func(handler): + if inspect.iscoroutinefunction(handler): + + async def async_view_func(*args, **kwargs): + if request is None: + raise RuntimeError( + "Quart not installed; request unavailable" + ) + data = await request.get_json() + result = await handler(**data) if data else await handler() + return jsonify(result) # type: ignore[arg-type] + + return async_view_func + + async def sync_view_func(*args, **kwargs): + if request is None: + raise RuntimeError("Quart not installed; request unavailable") + data = await request.get_json() + result = handler(**data) if data else handler() + return jsonify(result) # type: ignore[arg-type] + + return sync_view_func + + view_func = _make_view_func(handler) + self.server.add_url_rule( + route, endpoint=endpoint, view_func=view_func, methods=methods + ) + + def _serve_default_favicon(self): + if Response is None: + raise RuntimeError("Quart not installed; cannot generate Response") + return Response( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" + ) + + +class QuartRequestAdapter: + def __init__(self) -> None: + self._request = request # type: ignore[assignment] + if self._request is None: + raise RuntimeError("Quart not installed; cannot access request context") + + @property + def request(self) -> _t.Any: + return self._request + + @property + def root(self): + return self.request.root_url + + @property + def args(self): + return self.request.args + + @property + def is_json(self): + return self.request.is_json + + @property + def cookies(self): + return self.request.cookies + + @property + def headers(self): + return self.request.headers + + @property + def full_path(self): + return self.request.full_path + + @property + def url(self): + return str(self.request.url) + + @property + def remote_addr(self): + return self.request.remote_addr + + @property + def origin(self): + return self.request.headers.get("origin") + + @property + def path(self): + return self.request.path + + async def get_json(self): + return await self.request.get_json() diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py new file mode 100644 index 0000000000..1c47548ad0 --- /dev/null +++ b/dash/backends/base_server.py @@ -0,0 +1,119 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class BaseDashServer(ABC): + server_type: str + server: Any + config: dict[str, Any] + + def __call__(self, *args, **kwargs) -> Any: + # Default: WSGI + return self.server(*args, **kwargs) + + @staticmethod + @abstractmethod + def create_app( + name: str = "__main__", config=None + ) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def register_assets_blueprint( + self, blueprint_name: str, assets_url_path: str, assets_folder: str + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def register_error_handlers(self) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def add_url_rule( + self, rule: str, view_func, endpoint=None, methods=None + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def before_request(self, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def after_request(self, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def run( + self, dash_app, host: str, port: int, debug: bool, **kwargs + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def make_response( + self, data, mimetype=None, content_type=None + ) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def jsonify(self, obj) -> Any: # pragma: no cover - interface + pass + + +class RequestAdapter(ABC): + def __call__(self) -> Any: + return self + + # Properties to be implemented in concrete adapters + @property # pragma: no cover - interface + @abstractmethod + def root(self) -> str: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def args(self): + raise NotImplementedError() + + @abstractmethod # kept as method (may be sync or async) + def get_json(self): # pragma: no cover - interface + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def is_json(self) -> bool: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def cookies(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def headers(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def full_path(self) -> str: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def url(self) -> str: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def remote_addr(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def origin(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def path(self) -> str: + raise NotImplementedError() diff --git a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js index 176cb2c6f8..db4c6ddd2b 100644 --- a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js +++ b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js @@ -121,13 +121,18 @@ function BackendError({error, base}) { const MAX_MESSAGE_LENGTH = 40; /* eslint-disable no-inline-comments */ function UnconnectedErrorContent({error, base}) { + // Helper to detect full HTML document + const isFullHtmlDoc = + typeof error.html === 'string' && + error.html.trim().toLowerCase().startsWith(' - {/* - * 40 is a rough heuristic - if longer than 40 then the - * message might overflow into ellipses in the title above & - * will need to be displayed in full in this error body - */} + {/* Frontend error message */} {typeof error.message !== 'string' || error.message.length < MAX_MESSAGE_LENGTH ? null : (
@@ -137,6 +142,7 @@ function UnconnectedErrorContent({error, base}) {
)} + {/* Frontend stack trace */} {typeof error.stack !== 'string' ? null : (
@@ -149,7 +155,6 @@ function UnconnectedErrorContent({error, base}) { browser's console.) - {error.stack.split('\n').map((line, i) => (

{line}

))} @@ -157,24 +162,30 @@ function UnconnectedErrorContent({error, base}) {
)} - {/* Backend Error */} - {typeof error.html !== 'string' ? null : error.html - .substring(0, '
- {/* Embed werkzeug debugger in an iframe to prevent - CSS leaking - werkzeug HTML includes a bunch - of CSS on base html elements like `` - */}
- ) : ( + ) : isHtmlFragment ? ( + // Backend error: HTML fragment +
+
+
+ ) : typeof error.html === 'string' ? ( + // Backend error: plain text
-
{error.html}
+
+
{error.html}
+
- )} + ) : null}
); } diff --git a/dash/dash.py b/dash/dash.py index 18ad1c2367..1ed05657dc 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -14,7 +14,6 @@ import mimetypes import hashlib import base64 -import traceback from urllib.parse import urlparse from typing import Any, Callable, Dict, Optional, Union, Sequence, Literal, List @@ -26,7 +25,6 @@ from dash import dcc from dash import html from dash import dash_table - from .fingerprint import build_fingerprint from .resources import Scripts, Css from .dependencies import ( @@ -39,7 +37,7 @@ ProxyError, DuplicateCallback, ) -from .backend import get_request_adapter, get_backend +from .backends import get_backend from .version import __version__ from ._configs import get_combined_config, pathname_configs, pages_folder_config from ._utils import ( @@ -64,6 +62,7 @@ from . import _validate from . import _watch from . import _get_app +from . import backends from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group @@ -155,93 +154,6 @@ page_container = None -def _is_flask_instance(obj): - try: - # pylint: disable=import-outside-toplevel - from flask import Flask - - return isinstance(obj, Flask) - except ImportError: - return False - - -def _is_fastapi_instance(obj): - try: - # pylint: disable=import-outside-toplevel - from fastapi import FastAPI - - return isinstance(obj, FastAPI) - except ImportError: - return False - - -def _is_quart_instance(obj): - try: - # pylint: disable=import-outside-toplevel - from quart import Quart - - return isinstance(obj, Quart) - except ImportError: - return False - - -def _get_traceback(secret, error: Exception): - try: - # pylint: disable=import-outside-toplevel - from werkzeug.debug import tbtools - except ImportError: - tbtools = None - - def _get_skip(error): - from dash._callback import ( # pylint: disable=import-outside-toplevel - _invoke_callback, - _async_invoke_callback, - ) - - tb = error.__traceback__ - skip = 1 - while tb.tb_next is not None: - skip += 1 - tb = tb.tb_next - if tb.tb_frame.f_code in [ - _invoke_callback.__code__, - _async_invoke_callback.__code__, - ]: - return skip - - return skip - - def _do_skip(error): - from dash._callback import ( # pylint: disable=import-outside-toplevel - _invoke_callback, - _async_invoke_callback, - ) - - tb = error.__traceback__ - while tb.tb_next is not None: - if tb.tb_frame.f_code in [ - _invoke_callback.__code__, - _async_invoke_callback.__code__, - ]: - return tb.tb_next - tb = tb.tb_next - return error.__traceback__ - - # werkzeug<2.1.0 - if hasattr(tbtools, "get_current_traceback"): - return tbtools.get_current_traceback( # type: ignore - skip=_get_skip(error) - ).render_full() - - if hasattr(tbtools, "DebugTraceback"): - # pylint: disable=no-member - return tbtools.DebugTraceback( # type: ignore - error, skip=_get_skip(error) - ).render_debugger_html(True, secret, True) - - return "".join(traceback.format_exception(type(error), error, _do_skip(error))) - - # Singleton signal to not update an output, alternative to PreventUpdate no_update = _callback.NoUpdate() # pylint: disable=protected-access @@ -504,74 +416,41 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches **obsolete, ): - if use_async is None: - try: - import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa - - use_async = True - except ImportError: - pass - elif use_async: - try: - import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa - except ImportError as exc: - raise Exception( - "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" - ) from exc - + _validate.check_async(use_async) _validate.check_obsolete(obsolete) caller_name: str = name if name is not None else get_caller_name() # Determine backend if backend is None: - backend_cls = get_backend("flask") + backend_cls, request_cls = get_backend("flask") elif isinstance(backend, str): - backend_cls = get_backend(backend) + backend_cls, request_cls = get_backend(backend) elif isinstance(backend, type): backend_cls = backend + _, request_cls = get_backend(backend.server_type) else: raise ValueError("Invalid backend argument") # Determine server and backend instance if server not in (None, True, False): # User provided a server instance (e.g., Flask, Quart, FastAPI) - if _is_flask_instance(server): - inferred_backend = "flask" - elif _is_quart_instance(server): - inferred_backend = "quart" - elif _is_fastapi_instance(server): - inferred_backend = "fastapi" - else: - raise ValueError("Unsupported server type") - # Validate that backend matches server type if both are provided - if backend is not None: - if isinstance(backend, type): - # get_backend returns the backend class for a string - # So we compare the class names - expected_backend_cls = get_backend(inferred_backend) - if ( - backend.__module__ != expected_backend_cls.__module__ - or backend.__name__ != expected_backend_cls.__name__ - ): - raise ValueError( - f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." - ) - elif not isinstance(backend, str): - raise ValueError("Invalid backend argument") - elif backend.lower() != inferred_backend: - raise ValueError( - f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." - ) - backend_cls = get_backend(inferred_backend) + inferred_backend = backends.get_server_type(server) + _validate.check_backend(backend, inferred_backend) + backend_cls, request_cls = get_backend(inferred_backend) if name is None: caller_name = getattr(server, "name", caller_name) - self.backend = backend_cls() + + self.backend = backend_cls(server) self.server = server + backends.backend = self.backend # type: ignore + backends.request_adapter = request_cls else: # No server instance provided, create backend and let backend create server - self.backend = backend_cls() - self.server = self.backend.create_app(caller_name) # type: ignore + self.server = backend_cls.create_app(caller_name) # type: ignore + self.backend = backend_cls(self.server) + backends.backend = self.backend + backends.request_adapter = request_cls base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix @@ -768,7 +647,6 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: bp_prefix = config.routes_pathname_prefix.replace("/", "_").replace(".", "_") assets_blueprint_name = f"{bp_prefix}dash_assets" self.backend.register_assets_blueprint( - self.server, assets_blueprint_name, config.routes_pathname_prefix + self.config.assets_url_path.lstrip("/"), self.config.assets_folder, @@ -790,8 +668,9 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: raise ImportError( "To use the compress option, you need to install dash[compress]" ) from error - self.backend.register_error_handlers(self.server) - self.backend.before_request(self.server, self._setup_server) + + self.backend.register_error_handlers() + self.backend.before_request(self._setup_server) self._setup_routes() _get_app.APP = self self.enable_pages() @@ -800,7 +679,6 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> None: full_name = self.config.routes_pathname_prefix + name self.backend.add_url_rule( - self.server, full_name, view_func=view_func, endpoint=full_name, @@ -814,7 +692,7 @@ def _setup_routes(self): self._add_url("_dash-dependencies", self.dependencies) self._add_url( "_dash-update-component", - self.backend.dispatch(self.server, self, self._use_async), + self.backend.dispatch(self), ["POST"], ) self._add_url("_reload-hash", self.serve_reload_hash) @@ -861,7 +739,7 @@ def setup_apis(self): self.callback_api_paths[k] = _callback.GLOBAL_API_PATHS.pop(k) # Delegate to the server factory for route registration - self.backend.register_callback_api_routes(self.server, self.callback_api_paths) + self.backend.register_callback_api_routes(self.callback_api_paths) def _setup_plotlyjs(self): # pylint: disable=import-outside-toplevel @@ -1101,9 +979,11 @@ def _generate_css_dist_html(self): return "\n".join( [ - format_tag("link", link, opened=True) - if isinstance(link, dict) - else f'' + ( + format_tag("link", link, opened=True) + if isinstance(link, dict) + else f'' + ) for link in (external_links + links) ] ) @@ -1157,9 +1037,11 @@ def _generate_scripts_html(self) -> str: return "\n".join( [ - format_tag("script", src) - if isinstance(src, dict) - else f'' + ( + format_tag("script", src) + if isinstance(src, dict) + else f'' + ) for src in srcs ] + [f"" for src in self._inline_scripts] @@ -1197,11 +1079,8 @@ def index(self, *_args, **_kwargs): metas = self._generate_meta() renderer = self._generate_renderer() title = self.title - try: - request = get_request_adapter() - except LookupError: - # no request context - request = None + # Refactored: direct access to global request adapter + request = backends.request_adapter() if self.use_pages and self.config.include_pages_meta and request: metas = _page_meta_tags(self, request) + metas @@ -1415,8 +1294,9 @@ def _inputs_to_vals(self, inputs): return inputs_to_vals(inputs) # pylint: disable=R0915 - def _initialize_context(self, body, adapter): + def _initialize_context(self, body): """Initialize the global context for the request.""" + adapter = backends.request_adapter() g = AttributeDict({}) g.inputs_list = body.get("inputs", []) g.states_list = body.get("state", []) @@ -1430,12 +1310,12 @@ def _initialize_context(self, body, adapter): g.dash_response = self.backend.make_response( mimetype="application/json", data=None ) - g.cookies = dict(adapter.get_cookies()) - g.headers = dict(adapter.get_headers()) - g.args = adapter.get_args() - g.path = adapter.get_full_path() - g.remote = adapter.get_remote_addr() - g.origin = adapter.get_origin() + g.cookies = dict(adapter.cookies) + g.headers = dict(adapter.headers) + g.args = adapter.args + g.path = adapter.full_path + g.remote = adapter.remote_addr + g.origin = adapter.origin g.updated_props = {} return g @@ -2022,15 +1902,21 @@ def enable_dev_tools( packages[index] = dash_spec component_packages_dist = [ - dash_test_path # type: ignore[reportPossiblyUnboundVariable] - if isinstance(package, ModuleSpec) - else os.path.dirname(package.path) # type: ignore[reportAttributeAccessIssue] - if hasattr(package, "path") - else os.path.dirname( - package._path[0] # type: ignore[reportAttributeAccessIssue]; pylint: disable=protected-access - ) - if hasattr(package, "_path") - else package.filename # type: ignore[reportAttributeAccessIssue] + ( + dash_test_path # type: ignore[reportPossiblyUnboundVariable] + if isinstance(package, ModuleSpec) + else ( + os.path.dirname(package.path) # type: ignore[reportAttributeAccessIssue] + if hasattr(package, "path") + else ( + os.path.dirname( + package._path[0] # type: ignore[reportAttributeAccessIssue]; pylint: disable=protected-access + ) + if hasattr(package, "_path") + else package.filename + ) + ) + ) # type: ignore[reportAttributeAccessIssue] for package in packages ] @@ -2061,11 +1947,11 @@ def enable_dev_tools( elif dev_tools.prune_errors: secret = gen_salt(20) self.backend.register_prune_error_handler( - self.server, secret, _get_traceback + secret, dev_tools.prune_errors ) if debug and dev_tools.ui: - self.backend.register_timing_hooks(self.server, first_run) + self.backend.register_timing_hooks(first_run) if ( debug @@ -2349,8 +2235,8 @@ def verify_url_part(served_part, url_part, part_name): server_url=jupyter_server_url, ) else: - self.backend.run( - self.server, host=host, port=port, debug=debug, **flask_run_options + backends.backend.run( + dash_app=self, host=host, port=port, debug=debug, **flask_run_options ) def enable_pages(self) -> None: @@ -2422,9 +2308,11 @@ async def update(pathname_, search_, **states): if not self.config.suppress_callback_exceptions: self.validation_layout = html.Div( [ - asyncio.run(execute_async_function(page["layout"])) - if callable(page["layout"]) - else page["layout"] + ( + asyncio.run(execute_async_function(page["layout"])) + if callable(page["layout"]) + else page["layout"] + ) for page in _pages.PAGE_REGISTRY.values() ] + [ @@ -2493,9 +2381,11 @@ def update(pathname_, search_, **states): ] self.validation_layout = html.Div( [ - page["layout"]() - if callable(page["layout"]) - else page["layout"] + ( + page["layout"]() + if callable(page["layout"]) + else page["layout"] + ) for page in _pages.PAGE_REGISTRY.values() ] + layout @@ -2514,7 +2404,7 @@ def update(pathname_, search_, **states): Input(_ID_STORE, "data"), ) - self.backend.before_request(self.server, router) + self.backend.before_request(router) def __call__(self, *args, **kwargs): - return self.backend.__call__(self.server, *args, **kwargs) + return self.backend.__call__(*args, **kwargs) diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py index dc88afe844..2956f1a4c0 100644 --- a/dash/testing/application_runners.py +++ b/dash/testing/application_runners.py @@ -171,7 +171,13 @@ def run(): self.port = options["port"] try: - app.run(threaded=True, **options) + module = app.server.__class__.__module__ + # FastAPI support + if not module.startswith("flask"): + app.run(**options) + # Dash/Flask fallback + else: + app.run(threaded=True, **options) except SystemExit: logger.info("Server stopped") except Exception as error: @@ -229,7 +235,13 @@ def target(): options = kwargs.copy() try: - app.run(threaded=True, **options) + module = app.server.__class__.__module__ + # FastAPI support + if not module.startswith("flask"): + app.run(**options) + # Dash/Flask fallback + else: + app.run(threaded=True, **options) except SystemExit: logger.info("Server stopped") raise diff --git a/dash_config.json b/dash_config.json new file mode 100644 index 0000000000..3afa0d11f1 --- /dev/null +++ b/dash_config.json @@ -0,0 +1 @@ +{"debug": true, "dev_tools_ui": true, "dev_tools_props_check": true, "dev_tools_serve_dev_bundles": true, "dev_tools_hot_reload": true, "dev_tools_silence_routes_logging": true, "dev_tools_prune_errors": true, "dev_tools_hot_reload_interval": 3.0, "dev_tools_hot_reload_watch_interval": 0.5, "dev_tools_hot_reload_max_retry": 8, "dev_tools_disable_version_check": false} \ No newline at end of file diff --git a/package.json b/package.json index e78e279c1b..b7416dbb34 100644 --- a/package.json +++ b/package.json @@ -44,7 +44,7 @@ "setup-tests.R": "run-s private::test.R.deploy-*", "citest.integration": "run-s setup-tests.py private::test.integration-*", "citest.unit": "run-s private::test.unit-**", - "test": "pytest && cd dash/dash-renderer && npm run test", + "test": "pytest --ignore=tests/backend_tests && cd dash/dash-renderer && npm run test", "first-build": "cd dash/dash-renderer && npm i && cd ../../ && cd components/dash-html-components && npm i && npm run extract && cd ../../ && npm run build" }, "devDependencies": { diff --git a/quart_app.py b/quart_app.py new file mode 100644 index 0000000000..54d40add56 --- /dev/null +++ b/quart_app.py @@ -0,0 +1,23 @@ +from dash import Dash, html, Input, Output +from dash import dcc +from dash import backends + +app = Dash(__name__, backend="quart") + +app.layout = html.Div( + [ + html.H2("Quart Server Factory Example"), + html.Div("Type below to see async callback update."), + dcc.Input(id="text", value="hello", autoComplete="off"), + html.Div(id="echo"), + ] +) + + +@app.callback(Output("echo", "children"), Input("text", "value")) +def update_echo(val): + return f"You typed: {val}" if val else "Type something" + + +if __name__ == "__main__": + app.run(debug=True) diff --git a/requirements/fastapi.txt b/requirements/fastapi.txt new file mode 100644 index 0000000000..97dc7cd8c1 --- /dev/null +++ b/requirements/fastapi.txt @@ -0,0 +1,2 @@ +fastapi +uvicorn diff --git a/requirements/quart.txt b/requirements/quart.txt new file mode 100644 index 0000000000..60af440c9c --- /dev/null +++ b/requirements/quart.txt @@ -0,0 +1 @@ +quart diff --git a/setup.py b/setup.py index 7ed781c20d..950bcbe14d 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,9 @@ def read_req_file(req_type): "testing": read_req_file("testing"), "celery": read_req_file("celery"), "diskcache": read_req_file("diskcache"), - "compress": read_req_file("compress") + "compress": read_req_file("compress"), + "fastapi": read_req_file("fastapi"), + "quart": read_req_file("quart"), }, entry_points={ "console_scripts": [ diff --git a/tests/backend_tests/__init__.py b/tests/backend_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py new file mode 100644 index 0000000000..5fbd28dfd9 --- /dev/null +++ b/tests/backend_tests/test_preconfig_backends.py @@ -0,0 +1,217 @@ +import pytest +from dash import Dash, Input, Output, html, dcc + + +@pytest.mark.parametrize( + "backend,fixture,input_value", + [ + ("fastapi", "dash_duo", "Hello FastAPI!"), + ("quart", "dash_duo_mp", "Hello Quart!"), + ], +) +def test_backend_basic_callback(request, backend, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + if backend == "fastapi": + from fastapi import FastAPI + + server = FastAPI() + else: + import quart + + server = quart.Quart(__name__) + app = Dash(__name__, server=server) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("input", "value")) + def update_output(value): + return f"You typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"You typed: {input_value}") + dash_duo.find_element("#input").clear() + dash_duo.find_element("#input").send_keys(f"{backend.title()} Test") + dash_duo.wait_for_text_to_equal("#output", f"You typed: {backend.title()} Test") + assert dash_duo.get_logs() == [] + + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs", + [ + ( + "fastapi", + "dash_duo", + {"debug": True, "reload": False, "dev_tools_ui": True}, + ), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + }, + ), + ], +) +def test_backend_error_handling(request, backend, fixture, start_server_kwargs): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + +def get_error_html(dash_duo, index): + # error is in an iframe so is annoying to read out - get it from the store + return dash_duo.driver.execute_script( + "return store.getState().error.backEnd[{}].error.html;".format(index) + ) + + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs, error_msg", + [ + ( + "fastapi", + "dash_duo", + { + "debug": True, + "dev_tools_ui": True, + "dev_tools_prune_errors": False, + "reload": False, + }, + "fastapi.py", + ), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + "dev_tools_prune_errors": False, + }, + "quart.py", + ), + ], +) +def test_backend_error_handling_no_prune( + request, backend, fixture, start_server_kwargs, error_msg +): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "backend" in error0 and error_msg in error0 + + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs, error_msg", + [ + ("fastapi", "dash_duo", {"debug": True, "reload": False}, "fastapi.py"), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + }, + "quart.py", + ), + ], +) +def test_backend_error_handling_prune( + request, backend, fixture, start_server_kwargs, error_msg +): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "dash/backend" not in error0 and error_msg not in error0 + + +@pytest.mark.parametrize( + "backend,fixture,input_value", + [ + ("fastapi", "dash_duo", "Background FastAPI!"), + ("quart", "dash_duo_mp", "Background Quart!"), + ], +) +def test_backend_background_callback(request, backend, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + import diskcache + + cache = diskcache.Cache("./cache") + from dash.background_callback import DiskcacheManager + + background_callback_manager = DiskcacheManager(cache) + + app = Dash( + __name__, + backend=backend, + background_callback_manager=background_callback_manager, + ) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) + + @app.callback( + Output("output", "children"), Input("input", "value"), background=True + ) + def update_output_bg(value): + return f"Background typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}") + dash_duo.find_element("#input").clear() + dash_duo.find_element("#input").send_keys(f"{backend.title()} BG Test") + dash_duo.wait_for_text_to_equal( + "#output", f"Background typed: {backend.title()} BG Test" + ) + assert dash_duo.get_logs() == []