diff --git a/app/api/welcome.py b/app/api/welcome.py index 9a55e4b..9d279cf 100644 --- a/app/api/welcome.py +++ b/app/api/welcome.py @@ -1,9 +1,9 @@ -"""REST endpoints for the analyst-onboarding welcome prompt. +"""REST endpoints for the agent-setup-prompt banner. -- GET /api/welcome : render for the calling user (auth required) -- GET /api/admin/welcome-template : raw template + shipped default (admin) +- GET /api/admin/welcome-template : raw template override (admin) - PUT /api/admin/welcome-template : set override (admin) -- DELETE /api/admin/welcome-template : reset to default (admin) +- DELETE /api/admin/welcome-template : reset to default / no banner (admin) +- POST /api/admin/welcome-template/preview : live preview without persisting (admin) """ import datetime @@ -11,14 +11,14 @@ import logging from typing import Optional import duckdb -from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response -from jinja2 import Environment, StrictUndefined, TemplateError, TemplateSyntaxError +from fastapi import APIRouter, Depends, HTTPException, Request, Response +from jinja2 import Environment, StrictUndefined, TemplateError from pydantic import BaseModel, Field from app.auth.access import require_admin -from app.auth.dependencies import _get_db, get_current_user +from app.auth.dependencies import _get_db from src.repositories.welcome_template import WelcomeTemplateRepository -from src.welcome_template import _load_default_template, build_context, render_welcome +from src.welcome_template import build_context, render_agent_prompt_banner logger = logging.getLogger(__name__) @@ -27,16 +27,11 @@ router = APIRouter(tags=["welcome"]) # Stub context used to validate that a saved template renders end-to-end, # not just that it parses. Mirrors the shape of build_context() output. +# user may be None for anonymous visitors; the stub uses an authenticated +# user so templates that reference user.* fields are validated. 
_VALIDATION_STUB_CONTEXT = { "instance": {"name": "Example", "subtitle": "Example Org"}, "server": {"url": "https://example.com", "hostname": "example.com"}, - "sync_interval": "1 hour", - "data_source": {"type": "local"}, - "tables": [{"name": "example", "description": "", "query_mode": "local"}], - "metrics": {"count": 0, "categories": []}, - "marketplaces": [ - {"slug": "example", "name": "Example Marketplace", "plugins": [{"name": "x"}]} - ], "user": { "id": "u", "email": "user@example.com", @@ -49,13 +44,12 @@ _VALIDATION_STUB_CONTEXT = { } -class WelcomeResponse(BaseModel): +class BannerResponse(BaseModel): content: str class TemplateGetResponse(BaseModel): content: Optional[str] - default: str updated_at: Optional[str] = None updated_by: Optional[str] = None @@ -68,24 +62,6 @@ class TemplatePreviewRequest(BaseModel): content: str = Field(..., min_length=1, max_length=200_000) -@router.get("/api/welcome", response_model=WelcomeResponse) -async def get_welcome( - server_url: str = Query(..., description="The server URL the analyst is bootstrapping against"), - user: dict = Depends(get_current_user), - conn: duckdb.DuckDBPyConnection = Depends(_get_db), -): - """Render the welcome prompt for the calling user. Returns rendered markdown.""" - try: - rendered = render_welcome(conn, user=user, server_url=server_url) - except TemplateError as e: - logger.warning("Welcome render failed: %s", e, exc_info=True) - raise HTTPException( - status_code=500, - detail="Welcome template render failed. 
An admin can fix it at /admin/agent-prompt.", - ) - return WelcomeResponse(content=rendered) - - @router.get("/api/admin/welcome-template", response_model=TemplateGetResponse) async def admin_get_template( user: dict = Depends(require_admin), @@ -94,7 +70,6 @@ async def admin_get_template( row = WelcomeTemplateRepository(conn).get() return TemplateGetResponse( content=row["content"], - default=_load_default_template(), updated_at=row["updated_at"].isoformat() if row["updated_at"] else None, updated_by=row["updated_by"], ) @@ -106,11 +81,13 @@ async def admin_put_template( user: dict = Depends(require_admin), conn: duckdb.DuckDBPyConnection = Depends(_get_db), ): - env = Environment(undefined=StrictUndefined) + # Validate with autoescape=True (matches runtime environment) and + # StrictUndefined so unknown placeholders are caught at save time. + env = Environment(undefined=StrictUndefined, autoescape=True) try: template = env.from_string(payload.content) # Render against a stub context so undefined placeholders or runtime - # errors are caught here, not when an analyst calls /api/welcome. + # errors are caught here, not when /setup renders for a real user. template.render(**_VALIDATION_STUB_CONTEXT) except TemplateError as e: raise HTTPException(status_code=400, detail=f"Template invalid: {e}") @@ -127,7 +104,7 @@ async def admin_reset_template( return Response(status_code=204) -@router.post("/api/admin/welcome-template/preview", response_model=WelcomeResponse) +@router.post("/api/admin/welcome-template/preview", response_model=BannerResponse) async def admin_preview_template( payload: TemplatePreviewRequest, request: Request, @@ -137,11 +114,13 @@ async def admin_preview_template( """Render arbitrary template content against the live context for the calling admin, without persisting. 
Used by the /admin/agent-prompt editor's Preview button so admins can see their edits before saving.""" - env = Environment(undefined=StrictUndefined, autoescape=False) + env = Environment(undefined=StrictUndefined, autoescape=True) try: template = env.from_string(payload.content) - ctx = build_context(conn, user=user, server_url=str(request.base_url).rstrip("/")) + ctx = build_context( + user=user, server_url=str(request.base_url).rstrip("/") + ) rendered = template.render(**ctx) except TemplateError as e: raise HTTPException(status_code=400, detail=f"Template invalid: {e}") - return WelcomeResponse(content=rendered) + return BannerResponse(content=rendered) diff --git a/app/web/router.py b/app/web/router.py index 0ab31b7..e962513 100644 --- a/app/web/router.py +++ b/app/web/router.py @@ -727,13 +727,17 @@ async def setup_page( conn: duckdb.DuckDBPyConnection = Depends(_get_db), ): """Setup instructions for the local agent (CLI + Claude Code).""" + from src.welcome_template import render_agent_prompt_banner + base_url = str(request.base_url).rstrip("/") + banner_html = render_agent_prompt_banner(conn, user=user, server_url=base_url) ctx = _build_context( request, user=user, conn=conn, server_url=base_url, agnes_version=os.environ.get("AGNES_VERSION", "dev"), + banner_html=banner_html, ) return templates.TemplateResponse(request, "install.html", ctx) @@ -901,14 +905,12 @@ async def admin_agent_prompt_page( conn: duckdb.DuckDBPyConnection = Depends(_get_db), ): from src.repositories.welcome_template import WelcomeTemplateRepository - from src.welcome_template import _load_default_template row = WelcomeTemplateRepository(conn).get() ctx = _build_context( request, user=user, current=row["content"] or "", - default_template=_load_default_template(), updated_at=row["updated_at"], updated_by=row["updated_by"], is_override=row["content"] is not None, diff --git a/app/web/templates/install.html b/app/web/templates/install.html index 39ffea6..6a92763 100644 --- 
a/app/web/templates/install.html +++ b/app/web/templates/install.html @@ -628,6 +628,21 @@ .manual-summary { padding: 14px 18px; } .manual-body { padding: 16px 18px 18px; } } + + /* ── Admin-configured banner (above setup commands) ── */ + .setup-banner { + background: var(--background, #f6f7f9); + border: 1px solid var(--border, #e1e4e8); + border-left: 3px solid var(--primary, #0073D1); + border-radius: 8px; + padding: 14px 18px; + margin-bottom: 20px; + font-size: 14px; + line-height: 1.6; + color: var(--text-primary); + } + .setup-banner > *:first-child { margin-top: 0; } + .setup-banner > *:last-child { margin-bottom: 0; } {% include '_theme.html' %} @@ -648,6 +663,9 @@
+ + {% if banner_html %}<div class="setup-banner">{{ banner_html | safe }}</div>{% endif %} +
Getting started
diff --git a/cli/commands/analyst.py b/cli/commands/analyst.py index db37ef6..57f8e3d 100644 --- a/cli/commands/analyst.py +++ b/cli/commands/analyst.py @@ -294,51 +294,15 @@ def _install_claude_hooks(settings_path: Path) -> None: # --------------------------------------------------------------------------- -# Helper: generate CLAUDE.md from server-rendered template +# Helper: initialise Claude workspace (.claude/ directory) # --------------------------------------------------------------------------- -def _generate_claude_md(workspace: Path, server_url: str, token: str) -> None: - """Fetch the rendered welcome prompt from the server and write CLAUDE.md. +def _init_claude_workspace(workspace: Path) -> None: + """Initialise the .claude/ directory with placeholder files and hooks. - Falls back to a minimal embedded template if the server endpoint is - unavailable (e.g., older server versions before /api/welcome shipped). + Does NOT write CLAUDE.md — workspace-context customisation is handled + server-side via the banner on /setup, not as a file in the workspace. """ - from urllib.parse import quote - - server_url = server_url.rstrip("/") - headers = {"Authorization": f"Bearer {token}"} - url = f"{server_url}/api/welcome?server_url={quote(server_url, safe='')}" - - rendered: Optional[str] = None - try: - resp = httpx.get(url, headers=headers, timeout=15.0) - if resp.status_code == 200: - rendered = resp.json().get("content") - elif resp.status_code != 404: - typer.echo( - f" Warning: server returned {resp.status_code} for /api/welcome; " - "using minimal fallback. Tell your admin if this persists.", - err=True, - ) - except Exception as e: - typer.echo( - f" Warning: couldn't fetch welcome prompt ({e}); using minimal fallback.", - err=True, - ) - - if rendered is None: - # Fallback for older servers — keeps the CLI usable, just less rich. 
- rendered = ( - "# AI Data Analyst\n\n" - f"This workspace is connected to {server_url}.\n\n" - "## Rules\n" - "- Before computing any business metric: run `da metrics show /`\n" - "- Save work output to `user/artifacts/`\n" - "- Sync data regularly with `da sync`\n" - ) - - (workspace / "CLAUDE.md").write_text(rendered, encoding="utf-8") - local_md = workspace / ".claude" / "CLAUDE.local.md" if not local_md.exists(): local_md.write_text( @@ -421,9 +385,9 @@ def setup( typer.echo("Initialising DuckDB views...") total_rows = _initialize_duckdb(workspace) - # 7. Generate CLAUDE.md (rendered server-side) - typer.echo("Fetching welcome prompt from server...") - _generate_claude_md(workspace, server_url, token) + # 7. Initialise Claude workspace (.claude/ hooks + placeholder) + typer.echo("Initializing Claude workspace...") + _init_claude_workspace(workspace) # 8. Summary typer.echo("") diff --git a/src/welcome_template.py b/src/welcome_template.py index 595f727..d8a6c3e 100644 --- a/src/welcome_template.py +++ b/src/welcome_template.py @@ -1,139 +1,106 @@ -"""Render the analyst-onboarding welcome prompt (CLAUDE.md). +"""Render the agent-setup-prompt banner shown on /setup. -Two layers: - 1. Template source — admin override from welcome_template.content, - or the shipped default at config/claude_md_template.txt. - 2. Render context — built from instance config, table_registry, - metric_definitions, and the calling user's RBAC-filtered marketplaces. +The banner is a small HTML snippet admin-editable at /admin/agent-prompt. +It appears above the bash bootstrap commands on the /setup page and is +intended for org-specific operational notes (VPN warning, support channel, +data classification reminder, platform requirements). -The Jinja2 environment uses StrictUndefined so that any typo in the -template raises immediately rather than rendering empty strings. +Default: no banner (empty string). Admins override via the welcome_template +DB table (singleton, content TEXT). 
+ +Security: output is HTML-sanitized after render (script/iframe/event-handler +strip). The Jinja2 environment uses StrictUndefined with autoescape=True so +template typos raise immediately rather than silently emitting empty HTML. """ # See also: surfaced as the "Agent Setup Prompt" admin editor at /admin/agent-prompt. from __future__ import annotations +import logging +import re from datetime import date, datetime, timezone -from pathlib import Path from typing import Any from urllib.parse import urlparse import duckdb -from jinja2 import Environment, StrictUndefined +from jinja2 import Environment, StrictUndefined, TemplateError from app.instance_config import ( - get_data_source_type, get_instance_name, get_instance_subtitle, - get_sync_interval, ) -from src.marketplace_filter import resolve_allowed_plugins from src.repositories.welcome_template import WelcomeTemplateRepository -_DEFAULT_TEMPLATE_PATH = ( - Path(__file__).resolve().parent.parent / "config" / "claude_md_template.txt" +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# HTML sanitization +# --------------------------------------------------------------------------- + +_RE_SCRIPT = re.compile( + r"", re.IGNORECASE +) +_RE_IFRAME = re.compile( + r"|/>)", re.IGNORECASE +) +_RE_ON_EVENT = re.compile( + r"\s+on\w+\s*=\s*(?:\"[^\"]*\"|'[^']*'|[^\s>]+)", re.IGNORECASE +) +_RE_JS_URI = re.compile( + r"""(?:href|src|action)\s*=\s*(?:"|')(javascript:|data:)""", re.IGNORECASE ) -def _load_default_template() -> str: - if _DEFAULT_TEMPLATE_PATH.exists(): - return _DEFAULT_TEMPLATE_PATH.read_text(encoding="utf-8") - # Last-resort embedded fallback if the OSS template file is missing - # from the install (e.g., partial Docker COPY). 
- return ( - "# {{ instance.name }} — AI Data Analyst\n\n" - "This workspace is connected to {{ server.url }}.\n" - "Data refreshes every {{ sync_interval }}.\n" - ) +def _sanitize_banner_html(html: str) -> str: + """Strip dangerous constructs from admin-authored HTML. + Defense-in-depth only — admins are trusted, but this prevents accidental + XSS from copy-pasted snippets reaching the public /setup page. -def _list_tables(conn: duckdb.DuckDBPyConnection) -> list[dict[str, Any]]: - try: - rows = conn.execute( - """SELECT name, description, query_mode - FROM table_registry - ORDER BY name""" - ).fetchall() - except duckdb.CatalogException: - return [] - return [ - {"name": r[0], "description": r[1] or "", "query_mode": r[2] or "local"} - for r in rows - ] - - -def _metrics_summary(conn: duckdb.DuckDBPyConnection) -> dict[str, Any]: - try: - rows = conn.execute( - "SELECT category, COUNT(*) FROM metric_definitions GROUP BY category" - ).fetchall() - except duckdb.CatalogException: - return {"count": 0, "categories": []} - return { - "count": sum(r[1] for r in rows), - "categories": sorted({r[0] for r in rows if r[0]}), - } - - -def _marketplaces_for_user( - conn: duckdb.DuckDBPyConnection, user: dict[str, Any] -) -> list[dict[str, Any]]: - """Return marketplaces with the plugins the user is allowed to see. - - Delegates RBAC filtering entirely to resolve_allowed_plugins, which - returns List[dict] with marketplace_slug, original_name, etc. - Results are grouped by marketplace slug; display names are fetched - from marketplace_registry in a single query. + Strips: + - blocks (any content) + - tags + - on*= event handler attributes (onclick=, onload=, etc.) 
+ - javascript: / data: URI schemes in href/src/action attributes """ - try: - allowed = resolve_allowed_plugins(conn, user) - except duckdb.CatalogException: - return [] - if not allowed: - return [] + html = _RE_SCRIPT.sub("", html) + html = _RE_IFRAME.sub("", html) + html = _RE_ON_EVENT.sub("", html) + html = _RE_JS_URI.sub( + lambda m: m.group(0).replace(m.group(1), "#"), html + ) + return html - # Build slug → display name lookup from registry - slugs = list({p["marketplace_slug"] for p in allowed}) - placeholders = ",".join(["?"] * len(slugs)) - try: - name_rows = conn.execute( - f"SELECT id, name FROM marketplace_registry WHERE id IN ({placeholders})", - slugs, - ).fetchall() - slug_to_name: dict[str, str] = {r[0]: r[1] for r in name_rows} - except duckdb.CatalogException: - slug_to_name = {} - - grouped: dict[str, dict[str, Any]] = {} - for plugin in allowed: - slug = plugin["marketplace_slug"] - bucket = grouped.setdefault( - slug, - { - "slug": slug, - "name": slug_to_name.get(slug, slug), - "plugins": [], - }, - ) - bucket["plugins"].append({"name": plugin["original_name"]}) - - return list(grouped.values()) +# --------------------------------------------------------------------------- +# Render context +# --------------------------------------------------------------------------- def build_context( - conn: duckdb.DuckDBPyConnection, *, - user: dict[str, Any], + user: dict[str, Any] | None, server_url: str, ) -> dict[str, Any]: - """Compose the Jinja2 render context. Pure, no side effects. + """Compose the Jinja2 render context for the banner. - Note: ``now`` is tz-aware UTC; DB-sourced timestamps elsewhere in the - codebase are naive (DuckDB stores ``TIMESTAMP``, not ``TIMESTAMPTZ``). - Don't subtract or compare them inside templates without normalising. + Intentionally small: instance identity, server URL, and the requesting + user (may be None for anonymous /setup visitors). 
No tables, metrics, or + marketplaces — the banner is for org-operational notes, not data-catalog + content. + + Note: ``now`` is tz-aware UTC. """ now = datetime.now(timezone.utc) parsed = urlparse(server_url) + user_ctx: dict[str, Any] | None = None + if user: + user_ctx = { + "id": user.get("id", ""), + "email": user.get("email", ""), + "name": user.get("name") or "", + "is_admin": bool(user.get("is_admin")), + "groups": user.get("groups") or [], + } return { "instance": { "name": get_instance_name(), @@ -143,36 +110,44 @@ def build_context( "url": server_url, "hostname": parsed.hostname or "", }, - "sync_interval": get_sync_interval(), - "data_source": {"type": get_data_source_type()}, - "tables": _list_tables(conn), - "metrics": _metrics_summary(conn), - "marketplaces": _marketplaces_for_user(conn, user), - "user": { - "id": user.get("id", ""), - "email": user.get("email", ""), - "name": user.get("name") or "", - "is_admin": bool(user.get("is_admin")), - "groups": user.get("groups") or [], - }, + "user": user_ctx, "now": now, "today": date.today().isoformat(), } -def _resolve_template_source(conn: duckdb.DuckDBPyConnection) -> str: - row = WelcomeTemplateRepository(conn).get() - return row["content"] if row.get("content") else _load_default_template() +# --------------------------------------------------------------------------- +# Banner renderer +# --------------------------------------------------------------------------- - -def render_welcome( +def render_agent_prompt_banner( conn: duckdb.DuckDBPyConnection, *, - user: dict[str, Any], + user: dict[str, Any] | None, server_url: str, ) -> str: - """Resolve the active template and render it for the given user.""" - source = _resolve_template_source(conn) - env = Environment(undefined=StrictUndefined, autoescape=False) - template = env.from_string(source) - return template.render(**build_context(conn, user=user, server_url=server_url)) + """Render the admin-configured HTML banner for the /setup page. 
+ + Returns an empty string when no override is set (default = no banner). + Render failures are swallowed (logged) and return empty string so a + broken template never blocks the /setup page from rendering. + """ + row = WelcomeTemplateRepository(conn).get() + content = row.get("content") + if not content: + return "" + + try: + env = Environment(undefined=StrictUndefined, autoescape=True) + template = env.from_string(content) + ctx = build_context(user=user, server_url=server_url) + rendered = template.render(**ctx) + return _sanitize_banner_html(rendered) + except TemplateError as exc: + logger.warning( + "Agent-prompt banner render failed (template error): %s", exc + ) + return "" + except Exception: + logger.exception("Agent-prompt banner render failed (unexpected)") + return "" diff --git a/tests/test_analyst_bootstrap.py b/tests/test_analyst_bootstrap.py index a0715ec..3c4b8f5 100644 --- a/tests/test_analyst_bootstrap.py +++ b/tests/test_analyst_bootstrap.py @@ -78,7 +78,7 @@ class TestDetectExistingProject: patch("cli.commands.analyst._download_metadata"), \ patch("cli.commands.analyst._download_data", return_value=0), \ patch("cli.commands.analyst._initialize_duckdb", return_value=0), \ - patch("cli.commands.analyst._generate_claude_md"): + patch("cli.commands.analyst._init_claude_workspace"): result = runner.invoke( app, ["analyst", "setup", "--server-url", "http://localhost:8000", "--force"], @@ -116,59 +116,52 @@ class TestCreateWorkspace: # --------------------------------------------------------------------------- -# TestGenerateClaudeMd +# TestInitClaudeWorkspace # --------------------------------------------------------------------------- -class TestGenerateClaudeMd: - """Server-side render flow: _generate_claude_md fetches /api/welcome. - - The local-fallback path is exercised by tests/test_cli_analyst_welcome.py. - These tests cover the side-effects on the workspace (CLAUDE.local.md, - settings.json) and verify the new signature is honored. 
+class TestInitClaudeWorkspace: + """Tests for _init_claude_workspace: no CLAUDE.md written, but + .claude/CLAUDE.local.md placeholder and settings.json hooks are created. """ - def _patch_httpx_404(self, monkeypatch): - """Stub httpx.get to return 404 so _generate_claude_md falls back to embedded text.""" - import httpx - - def fake_get(url, headers=None, timeout=None): - return httpx.Response( - status_code=404, json={}, request=httpx.Request("GET", url) - ) - - monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": fake_get})) - - def test_creates_claude_local_md_when_absent(self, tmp_workspace, monkeypatch): - from cli.commands.analyst import _create_workspace, _generate_claude_md + def test_does_not_write_claude_md(self, tmp_workspace): + from cli.commands.analyst import _create_workspace, _init_claude_workspace _create_workspace(tmp_workspace) - self._patch_httpx_404(monkeypatch) - _generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t") + _init_claude_workspace(tmp_workspace) + + assert not (tmp_workspace / "CLAUDE.md").exists(), ( + "CLAUDE.md must NOT be written by _init_claude_workspace" + ) + + def test_creates_claude_local_md_when_absent(self, tmp_workspace): + from cli.commands.analyst import _create_workspace, _init_claude_workspace + + _create_workspace(tmp_workspace) + _init_claude_workspace(tmp_workspace) local_md = tmp_workspace / ".claude" / "CLAUDE.local.md" assert local_md.exists() assert local_md.read_text(encoding="utf-8").strip() != "" - def test_does_not_overwrite_existing_local_md(self, tmp_workspace, monkeypatch): - from cli.commands.analyst import _create_workspace, _generate_claude_md + def test_does_not_overwrite_existing_local_md(self, tmp_workspace): + from cli.commands.analyst import _create_workspace, _init_claude_workspace _create_workspace(tmp_workspace) local_md = tmp_workspace / ".claude" / "CLAUDE.local.md" original_content = "# My custom notes\n\nDo not overwrite me.\n" 
local_md.write_text(original_content, encoding="utf-8") - self._patch_httpx_404(monkeypatch) - _generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t") + _init_claude_workspace(tmp_workspace) assert local_md.read_text(encoding="utf-8") == original_content - def test_writes_settings_json(self, tmp_workspace, monkeypatch): - from cli.commands.analyst import _create_workspace, _generate_claude_md + def test_writes_settings_json(self, tmp_workspace): + from cli.commands.analyst import _create_workspace, _init_claude_workspace import json as _json _create_workspace(tmp_workspace) - self._patch_httpx_404(monkeypatch) - _generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t") + _init_claude_workspace(tmp_workspace) settings = _json.loads( (tmp_workspace / ".claude" / "settings.json").read_text(encoding="utf-8") @@ -176,6 +169,21 @@ class TestGenerateClaudeMd: assert settings["model"] == "sonnet" assert "Read" in settings["permissions"]["allow"] + def test_installs_session_hooks(self, tmp_workspace): + """SessionStart and SessionEnd hooks must be present in settings.json.""" + from cli.commands.analyst import _create_workspace, _init_claude_workspace + import json as _json + + _create_workspace(tmp_workspace) + _init_claude_workspace(tmp_workspace) + + settings = _json.loads( + (tmp_workspace / ".claude" / "settings.json").read_text(encoding="utf-8") + ) + hooks = settings.get("hooks", {}) + assert "SessionStart" in hooks + assert "SessionEnd" in hooks + # --------------------------------------------------------------------------- # TestReturningSession diff --git a/tests/test_cli_analyst_welcome.py b/tests/test_cli_analyst_welcome.py deleted file mode 100644 index 41edc42..0000000 --- a/tests/test_cli_analyst_welcome.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Integration tests for da analyst setup → /api/welcome wiring.""" - -import json -from pathlib import Path - -import httpx -import pytest - -from cli.commands.analyst 
import _generate_claude_md - - -class _MockClient: - def __init__(self, responses): - self._responses = responses - self.calls = [] - - def get(self, url, headers=None, timeout=None): - self.calls.append(url) - body, status = self._responses.get(url, ({}, 404)) - return httpx.Response(status_code=status, json=body, request=httpx.Request("GET", url)) - - -def _ws(tmp_path: Path) -> Path: - workspace = tmp_path / "ws" - (workspace / ".claude").mkdir(parents=True) - return workspace - - -def test_generate_claude_md_uses_server_render(tmp_path, monkeypatch): - workspace = _ws(tmp_path) - rendered = "# CUSTOM\n\nFrom server.\n" - mock = _MockClient({ - "https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com": ( - {"content": rendered}, 200 - ), - }) - monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get})) - _generate_claude_md(workspace, server_url="https://example.com", token="t") - - assert (workspace / "CLAUDE.md").read_text(encoding="utf-8") == rendered - # Workspace side-effects are created on the success path too. - assert (workspace / ".claude" / "CLAUDE.local.md").exists() - settings = json.loads((workspace / ".claude" / "settings.json").read_text(encoding="utf-8")) - assert settings["model"] == "sonnet" - - -def test_generate_claude_md_falls_back_on_404(tmp_path, monkeypatch): - workspace = _ws(tmp_path) - mock = _MockClient({}) # everything 404s - monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get})) - _generate_claude_md(workspace, server_url="https://example.com", token="t") - body = (workspace / "CLAUDE.md").read_text(encoding="utf-8") - assert "AI Data Analyst" in body - assert "https://example.com" in body - - -def test_generate_claude_md_falls_back_on_null_content(tmp_path, monkeypatch): - """Server returns 200 but malformed body (`content: null`). 
CLI must use fallback.""" - workspace = _ws(tmp_path) - mock = _MockClient({ - "https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com": ( - {"content": None}, 200 - ), - }) - monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get})) - _generate_claude_md(workspace, server_url="https://example.com", token="t") - body = (workspace / "CLAUDE.md").read_text(encoding="utf-8") - # Embedded fallback contains these literals - assert "AI Data Analyst" in body - assert "https://example.com" in body - - -def test_generate_claude_md_warns_on_5xx(tmp_path, monkeypatch, capsys): - """500 from server → embedded fallback, with a stderr warning so operators can diagnose.""" - workspace = _ws(tmp_path) - mock = _MockClient({ - "https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com": ( - {"detail": "boom"}, 500 - ), - }) - monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get})) - _generate_claude_md(workspace, server_url="https://example.com", token="t") - - body = (workspace / "CLAUDE.md").read_text(encoding="utf-8") - assert "AI Data Analyst" in body # fallback used - - captured = capsys.readouterr() - assert "500" in captured.err - assert "fallback" in captured.err.lower() diff --git a/tests/test_welcome_template_api.py b/tests/test_welcome_template_api.py index 5440865..e883e6f 100644 --- a/tests/test_welcome_template_api.py +++ b/tests/test_welcome_template_api.py @@ -1,4 +1,8 @@ -"""End-to-end tests for /api/welcome and /api/admin/welcome-template.""" +"""End-to-end tests for /api/admin/welcome-template (banner editor endpoints). + +GET /api/welcome has been removed — the analyst-facing endpoint is gone. +These tests cover only the admin CRUD + preview endpoints. 
+""" import duckdb @@ -10,7 +14,8 @@ def _auth(token: str) -> dict[str, str]: return {"Authorization": f"Bearer {token}"} -def test_get_welcome_returns_rendered_markdown(seeded_app): +def test_get_welcome_endpoint_removed(seeded_app): + """GET /api/welcome must return 404 — the endpoint was deleted.""" c = seeded_app["client"] token = seeded_app["analyst_token"] resp = c.get( @@ -18,48 +23,39 @@ def test_get_welcome_returns_rendered_markdown(seeded_app): params={"server_url": "https://example.com"}, headers=_auth(token), ) - assert resp.status_code == 200 - body = resp.json() - assert "content" in body - assert "AI Data Analyst" in body["content"] - assert "https://example.com" in body["content"] + assert resp.status_code == 404 -def test_get_welcome_requires_auth(seeded_app): +def test_admin_get_template_initially_null(seeded_app): c = seeded_app["client"] - resp = c.get("/api/welcome", params={"server_url": "https://example.com"}) - assert resp.status_code == 401 + admin = _auth(seeded_app["admin_token"]) + + r = c.get("/api/admin/welcome-template", headers=admin) + assert r.status_code == 200 + body = r.json() + assert body["content"] is None + # No longer returns a `default` field — banner default is empty + assert "default" not in body or body.get("default") is None def test_admin_can_set_and_reset_template(seeded_app): c = seeded_app["client"] admin = _auth(seeded_app["admin_token"]) - # GET initial state - r = c.get("/api/admin/welcome-template", headers=admin) - assert r.status_code == 200 - body = r.json() - assert body["content"] is None - # The shipped default starts with the Jinja2 comment block. - assert body["default"].startswith("{#") - # PUT override r = c.put( "/api/admin/welcome-template", - json={"content": "Hello {{ user.email }}"}, + json={"content": "

Hello {{ user.email }}

"}, headers=admin, ) assert r.status_code == 200 - # Verify rendered output uses override - r = c.get( - "/api/welcome", - params={"server_url": "https://example.com"}, - headers=admin, # admin user can also call /api/welcome - ) - assert r.json()["content"].startswith("Hello ") + # GET reflects override + r = c.get("/api/admin/welcome-template", headers=admin) + assert r.status_code == 200 + assert r.json()["content"] == "

Hello {{ user.email }}

" - # DELETE = reset + # DELETE = reset (no banner) r = c.delete("/api/admin/welcome-template", headers=admin) assert r.status_code == 204 r = c.get("/api/admin/welcome-template", headers=admin) @@ -69,7 +65,7 @@ def test_admin_can_set_and_reset_template(seeded_app): def test_non_admin_cannot_edit_template(seeded_app): c = seeded_app["client"] analyst = _auth(seeded_app["analyst_token"]) - r = c.put("/api/admin/welcome-template", json={"content": "x"}, headers=analyst) + r = c.put("/api/admin/welcome-template", json={"content": "

x

"}, headers=analyst) assert r.status_code == 403 @@ -86,58 +82,29 @@ def test_invalid_jinja2_returns_400(seeded_app): def test_put_rejects_undefined_placeholder(seeded_app): - """Templates that parse but reference unknown placeholders must be rejected - at PUT time so an admin can fix the typo immediately rather than after an - analyst's bootstrap blows up.""" + """Templates that reference unknown placeholders must be rejected at PUT time.""" c = seeded_app["client"] admin = _auth(seeded_app["admin_token"]) r = c.put( "/api/admin/welcome-template", - json={"content": "Hello {{ user.emial }}"}, # typo, would fail StrictUndefined at render + json={"content": "

{{ user.emial }}

"}, # typo headers=admin, ) assert r.status_code == 400 assert "emial" in r.json()["detail"] or "undefined" in r.json()["detail"].lower() -def test_get_welcome_500_includes_reset_hint_on_render_failure(seeded_app, monkeypatch): - """If an override slips through validation and fails at render time, the - user-visible 500 must point at /admin/agent-prompt rather than leaking a - Jinja stack trace.""" - # Stub render_welcome to raise a TemplateError so we exercise the - # exception path without needing a malformed override (PUT validation - # blocks those now). - from jinja2 import UndefinedError - import app.api.welcome as welcome_module - - def fake_render(*args, **kwargs): - raise UndefinedError("'foo' is undefined") - - monkeypatch.setattr(welcome_module, "render_welcome", fake_render) - - c = seeded_app["client"] - admin = _auth(seeded_app["admin_token"]) - r = c.get( - "/api/welcome", - params={"server_url": "https://example.com"}, - headers=admin, - ) - assert r.status_code == 500 - assert "/admin/agent-prompt" in r.json()["detail"] - - -def test_admin_preview_renders_arbitrary_content(seeded_app): - """Preview endpoint must render the supplied content (not whatever's - stored), so the admin UI can show pre-save preview.""" +def test_admin_preview_renders_html(seeded_app): + """Preview endpoint renders supplied HTML content without persisting.""" c = seeded_app["client"] admin = _auth(seeded_app["admin_token"]) r = c.post( "/api/admin/welcome-template/preview", - json={"content": "# Preview {{ user.email }}"}, + json={"content": "

Preview for {{ user.email }}

"}, headers=admin, ) assert r.status_code == 200 - assert r.json()["content"].startswith("# Preview admin@test.com") + assert r.json()["content"].startswith("

Preview for admin@test.com") def test_preview_rejects_invalid_template(seeded_app): @@ -156,26 +123,29 @@ def test_preview_requires_admin(seeded_app): analyst = _auth(seeded_app["analyst_token"]) r = c.post( "/api/admin/welcome-template/preview", - json={"content": "# x"}, + json={"content": "

x

"}, headers=analyst, ) assert r.status_code == 403 def test_validation_stub_matches_build_context_shape(seeded_app, tmp_path, monkeypatch): - """If build_context grows new keys, _VALIDATION_STUB_CONTEXT must too — - otherwise admins can save templates referencing keys the PUT validator - accepts but the live render rejects.""" + """_VALIDATION_STUB_CONTEXT top-level keys must match build_context() output. + + If build_context gains new keys, the stub must track them so admins can + save templates that reference those keys without hitting a live-render + rejection after the PUT validation accepted them. + """ from app.api.welcome import _VALIDATION_STUB_CONTEXT - monkeypatch.setenv("DATA_DIR", str(tmp_path)) - db_path = tmp_path / "system.duckdb" - conn = duckdb.connect(str(db_path)) - _ensure_schema(conn) - - user = {"id": "u1", "email": "admin@test.com", "name": "Admin", "is_admin": True, "groups": ["Admin"]} - real_ctx = build_context(conn, user=user, server_url="https://example.com") - conn.close() + user = { + "id": "u1", + "email": "admin@test.com", + "name": "Admin", + "is_admin": True, + "groups": ["Admin"], + } + real_ctx = build_context(user=user, server_url="https://example.com") # Top-level keys must match assert set(_VALIDATION_STUB_CONTEXT.keys()) == set(real_ctx.keys()), ( @@ -184,9 +154,13 @@ def test_validation_stub_matches_build_context_shape(seeded_app, tmp_path, monke f"real has: {set(real_ctx.keys())}" ) - # One level deep for nested dicts - for key in ("instance", "server", "user"): - if isinstance(real_ctx.get(key), dict): - assert set(_VALIDATION_STUB_CONTEXT[key].keys()) == set(real_ctx[key].keys()), ( - f"_VALIDATION_STUB_CONTEXT[{key!r}] drifted from build_context output" - ) + # One level deep for nested dicts (user may be None in real_ctx — compare stub shape) + for key in ("instance", "server"): + assert set(_VALIDATION_STUB_CONTEXT[key].keys()) == set(real_ctx[key].keys()), ( + f"_VALIDATION_STUB_CONTEXT[{key!r}] drifted from 
build_context output" + ) + # user sub-keys + if real_ctx.get("user") and _VALIDATION_STUB_CONTEXT.get("user"): + assert set(_VALIDATION_STUB_CONTEXT["user"].keys()) == set(real_ctx["user"].keys()), ( + "_VALIDATION_STUB_CONTEXT['user'] drifted from build_context output" + ) diff --git a/tests/test_welcome_template_renderer.py b/tests/test_welcome_template_renderer.py index 02ce4b1..473a3a8 100644 --- a/tests/test_welcome_template_renderer.py +++ b/tests/test_welcome_template_renderer.py @@ -1,14 +1,15 @@ -"""Unit tests for the welcome-prompt renderer.""" - -import uuid -from pathlib import Path +"""Unit tests for the agent-setup-prompt banner renderer.""" import duckdb import pytest from src.db import _ensure_schema from src.repositories.welcome_template import WelcomeTemplateRepository -from src.welcome_template import build_context, render_welcome +from src.welcome_template import ( + _sanitize_banner_html, + build_context, + render_agent_prompt_banner, +) @pytest.fixture @@ -22,144 +23,188 @@ def conn(tmp_path, monkeypatch): def _user(email="alice@example.com"): - return {"id": "u1", "email": email, "name": "Alice", "is_admin": False, "groups": ["Everyone"]} + return { + "id": "u1", + "email": email, + "name": "Alice", + "is_admin": False, + "groups": ["Everyone"], + } -def test_renders_default_when_no_override(conn): - out = render_welcome(conn, user=_user(), server_url="https://example.com") - assert "AI Data Analyst" in out - assert "https://example.com" in out - assert "Alice" in out +# --------------------------------------------------------------------------- +# Default (no override) → empty string +# --------------------------------------------------------------------------- +def test_returns_empty_when_no_override(conn): + out = render_agent_prompt_banner(conn, user=_user(), server_url="https://example.com") + assert out == "" + + +# --------------------------------------------------------------------------- +# Override renders correctly +# 
--------------------------------------------------------------------------- def test_renders_override(conn): WelcomeTemplateRepository(conn).set( - "# {{ instance.name }} for {{ user.email }}", + "

Welcome to {{ instance.name }}!

", updated_by="admin@example.com", ) - out = render_welcome(conn, user=_user(), server_url="https://example.com") - assert out.startswith("# AI Data Analyst for alice@example.com") + out = render_agent_prompt_banner(conn, user=_user(), server_url="https://example.com") + assert "

Welcome to" in out + # instance.name comes from instance_config — any non-empty string is fine + assert "!" in out -def test_strict_undefined_raises_on_missing_placeholder(conn): +def test_renders_user_placeholder(conn): + WelcomeTemplateRepository(conn).set( + "

Hello {{ user.email }}

", + updated_by="admin@example.com", + ) + out = render_agent_prompt_banner( + conn, user=_user("bob@example.com"), server_url="https://example.com" + ) + assert "bob@example.com" in out + + +def test_renders_server_placeholder(conn): + WelcomeTemplateRepository(conn).set( + "

Server: {{ server.url }}

", + updated_by="admin@example.com", + ) + out = render_agent_prompt_banner( + conn, user=_user(), server_url="https://myserver.example.com" + ) + assert "https://myserver.example.com" in out + + +# --------------------------------------------------------------------------- +# Anonymous user (user=None) +# --------------------------------------------------------------------------- + +def test_renders_with_anonymous_user(conn): + WelcomeTemplateRepository(conn).set( + "{% if user %}

Hi {{ user.email }}

{% else %}

Please sign in.

{% endif %}", + updated_by="admin@example.com", + ) + out = render_agent_prompt_banner(conn, user=None, server_url="https://example.com") + assert "Please sign in." in out + assert "Hi" not in out + + +def test_returns_empty_for_none_user_with_no_override(conn): + out = render_agent_prompt_banner(conn, user=None, server_url="https://example.com") + assert out == "" + + +# --------------------------------------------------------------------------- +# Build context shape +# --------------------------------------------------------------------------- + +def test_context_exposes_documented_keys(): + ctx = build_context(user=_user(), server_url="https://example.com") + for key in ("instance", "server", "user", "now", "today"): + assert key in ctx, f"missing context key: {key}" + assert "tables" not in ctx + assert "metrics" not in ctx + assert "marketplaces" not in ctx + assert "sync_interval" not in ctx + assert "data_source" not in ctx + + +def test_context_user_none(): + ctx = build_context(user=None, server_url="https://example.com") + assert ctx["user"] is None + + +def test_context_instance_keys(): + ctx = build_context(user=_user(), server_url="https://example.com") + assert "name" in ctx["instance"] + assert "subtitle" in ctx["instance"] + + +def test_context_server_keys(): + ctx = build_context(user=_user(), server_url="https://example.com") + assert ctx["server"]["url"] == "https://example.com" + assert ctx["server"]["hostname"] == "example.com" + + +# --------------------------------------------------------------------------- +# HTML sanitization +# --------------------------------------------------------------------------- + +def test_sanitize_strips_script_tag(): + html = '

Hello

' + result = _sanitize_banner_html(html) + assert "

ok

' + result = _sanitize_banner_html(html) + assert "evil" not in result + assert "

ok

" in result + + +def test_sanitize_strips_iframe(): + html = '

text

' + result = _sanitize_banner_html(html) + assert "text

" in result + + +def test_sanitize_strips_event_handlers(): + html = '' + result = _sanitize_banner_html(html) + assert "onclick" not in result + assert "evil" not in result + assert "Click me" in result + + +def test_sanitize_strips_onload_on_img(): + html = 'test' + result = _sanitize_banner_html(html) + assert "onload" not in result + assert "steal" not in result + + +def test_sanitize_strips_javascript_uri(): + html = 'click' + result = _sanitize_banner_html(html) + assert "javascript:" not in result + + +def test_sanitize_allows_safe_html(): + html = "

VPN required. Contact support.

" + result = _sanitize_banner_html(html) + assert "

" in result + assert " output is sanitized before return.""" WelcomeTemplateRepository(conn).set( - "{% for m in marketplaces %}{{ m.slug }}: " - "{% for p in m.plugins %}{{ p.name }} {% endfor %}{% endfor %}", + "

safe content

", updated_by="admin@example.com", ) - - user_a = {"id": "user-a", "email": "user-a@example.com", "name": "User A", "is_admin": False, "groups": ["group-a"]} - user_b = {"id": "user-b", "email": "user-b@example.com", "name": "User B", "is_admin": False, "groups": ["group-b"]} - - out_a = render_welcome(conn, user=user_a, server_url="https://example.com") - out_b = render_welcome(conn, user=user_b, server_url="https://example.com") - - # user-a sees mkt-a plugins only - assert "mkt-a" in out_a - assert "plugin-1" in out_a - assert "mkt-b" not in out_a - assert "plugin-3" not in out_a - - # user-b sees mkt-b plugins only - assert "mkt-b" in out_b - assert "plugin-3" in out_b - assert "mkt-a" not in out_b - assert "plugin-1" not in out_b - - -def test_render_tolerates_missing_optional_tables(tmp_path, monkeypatch): - """A bare DuckDB without table_registry / marketplace_registry must still render.""" - monkeypatch.setenv("DATA_DIR", str(tmp_path)) - db_path = tmp_path / "bare.duckdb" - bare = duckdb.connect(str(db_path)) - # Only seed the welcome_template singleton manually; no other tables. - bare.execute( - """CREATE TABLE welcome_template ( - id INTEGER PRIMARY KEY DEFAULT 1, - content TEXT, - updated_at TIMESTAMP, - updated_by VARCHAR - )""" - ) - bare.execute("INSERT INTO welcome_template (id, content) VALUES (1, NULL)") - - out = render_welcome(bare, user=_user(), server_url="https://example.com") - bare.close() - assert "AI Data Analyst" in out # default template still renders - # No tables → "_No tables registered yet_" branch from the default template - assert "No tables registered yet" in out + out = render_agent_prompt_banner(conn, user=_user(), server_url="https://example.com") + assert "