diff --git a/app/api/welcome.py b/app/api/welcome.py
index 9a55e4b..9d279cf 100644
--- a/app/api/welcome.py
+++ b/app/api/welcome.py
@@ -1,9 +1,9 @@
-"""REST endpoints for the analyst-onboarding welcome prompt.
+"""REST endpoints for the agent-setup-prompt banner.
-- GET /api/welcome : render for the calling user (auth required)
-- GET /api/admin/welcome-template : raw template + shipped default (admin)
+- GET /api/admin/welcome-template : raw template override (admin)
- PUT /api/admin/welcome-template : set override (admin)
-- DELETE /api/admin/welcome-template : reset to default (admin)
+- DELETE /api/admin/welcome-template : reset to default / no banner (admin)
+- POST /api/admin/welcome-template/preview : live preview without persisting (admin)
"""
import datetime
@@ -11,14 +11,14 @@ import logging
from typing import Optional
import duckdb
-from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
-from jinja2 import Environment, StrictUndefined, TemplateError, TemplateSyntaxError
+from fastapi import APIRouter, Depends, HTTPException, Request, Response
+from jinja2 import Environment, StrictUndefined, TemplateError
from pydantic import BaseModel, Field
from app.auth.access import require_admin
-from app.auth.dependencies import _get_db, get_current_user
+from app.auth.dependencies import _get_db
from src.repositories.welcome_template import WelcomeTemplateRepository
-from src.welcome_template import _load_default_template, build_context, render_welcome
+from src.welcome_template import build_context, render_agent_prompt_banner
logger = logging.getLogger(__name__)
@@ -27,16 +27,11 @@ router = APIRouter(tags=["welcome"])
# Stub context used to validate that a saved template renders end-to-end,
# not just that it parses. Mirrors the shape of build_context() output.
+# user may be None for anonymous visitors; the stub uses an authenticated
+# user so templates that reference user.* fields are validated.
_VALIDATION_STUB_CONTEXT = {
"instance": {"name": "Example", "subtitle": "Example Org"},
"server": {"url": "https://example.com", "hostname": "example.com"},
- "sync_interval": "1 hour",
- "data_source": {"type": "local"},
- "tables": [{"name": "example", "description": "", "query_mode": "local"}],
- "metrics": {"count": 0, "categories": []},
- "marketplaces": [
- {"slug": "example", "name": "Example Marketplace", "plugins": [{"name": "x"}]}
- ],
"user": {
"id": "u",
"email": "user@example.com",
@@ -49,13 +44,12 @@ _VALIDATION_STUB_CONTEXT = {
}
-class WelcomeResponse(BaseModel):
+class BannerResponse(BaseModel):
content: str
class TemplateGetResponse(BaseModel):
content: Optional[str]
- default: str
updated_at: Optional[str] = None
updated_by: Optional[str] = None
@@ -68,24 +62,6 @@ class TemplatePreviewRequest(BaseModel):
content: str = Field(..., min_length=1, max_length=200_000)
-@router.get("/api/welcome", response_model=WelcomeResponse)
-async def get_welcome(
- server_url: str = Query(..., description="The server URL the analyst is bootstrapping against"),
- user: dict = Depends(get_current_user),
- conn: duckdb.DuckDBPyConnection = Depends(_get_db),
-):
- """Render the welcome prompt for the calling user. Returns rendered markdown."""
- try:
- rendered = render_welcome(conn, user=user, server_url=server_url)
- except TemplateError as e:
- logger.warning("Welcome render failed: %s", e, exc_info=True)
- raise HTTPException(
- status_code=500,
- detail="Welcome template render failed. An admin can fix it at /admin/agent-prompt.",
- )
- return WelcomeResponse(content=rendered)
-
-
@router.get("/api/admin/welcome-template", response_model=TemplateGetResponse)
async def admin_get_template(
user: dict = Depends(require_admin),
@@ -94,7 +70,6 @@ async def admin_get_template(
row = WelcomeTemplateRepository(conn).get()
return TemplateGetResponse(
content=row["content"],
- default=_load_default_template(),
updated_at=row["updated_at"].isoformat() if row["updated_at"] else None,
updated_by=row["updated_by"],
)
@@ -106,11 +81,13 @@ async def admin_put_template(
user: dict = Depends(require_admin),
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
):
- env = Environment(undefined=StrictUndefined)
+ # Validate with autoescape=True (matches runtime environment) and
+ # StrictUndefined so unknown placeholders are caught at save time.
+ env = Environment(undefined=StrictUndefined, autoescape=True)
try:
template = env.from_string(payload.content)
# Render against a stub context so undefined placeholders or runtime
- # errors are caught here, not when an analyst calls /api/welcome.
+ # errors are caught here, not when /setup renders for a real user.
template.render(**_VALIDATION_STUB_CONTEXT)
except TemplateError as e:
raise HTTPException(status_code=400, detail=f"Template invalid: {e}")
@@ -127,7 +104,7 @@ async def admin_reset_template(
return Response(status_code=204)
-@router.post("/api/admin/welcome-template/preview", response_model=WelcomeResponse)
+@router.post("/api/admin/welcome-template/preview", response_model=BannerResponse)
async def admin_preview_template(
payload: TemplatePreviewRequest,
request: Request,
@@ -137,11 +114,13 @@ async def admin_preview_template(
"""Render arbitrary template content against the live context for the
calling admin, without persisting. Used by the /admin/agent-prompt editor's
Preview button so admins can see their edits before saving."""
- env = Environment(undefined=StrictUndefined, autoescape=False)
+ env = Environment(undefined=StrictUndefined, autoescape=True)
try:
template = env.from_string(payload.content)
- ctx = build_context(conn, user=user, server_url=str(request.base_url).rstrip("/"))
+ ctx = build_context(
+ user=user, server_url=str(request.base_url).rstrip("/")
+ )
rendered = template.render(**ctx)
except TemplateError as e:
raise HTTPException(status_code=400, detail=f"Template invalid: {e}")
- return WelcomeResponse(content=rendered)
+ return BannerResponse(content=rendered)
diff --git a/app/web/router.py b/app/web/router.py
index 0ab31b7..e962513 100644
--- a/app/web/router.py
+++ b/app/web/router.py
@@ -727,13 +727,17 @@ async def setup_page(
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
):
"""Setup instructions for the local agent (CLI + Claude Code)."""
+ from src.welcome_template import render_agent_prompt_banner
+
base_url = str(request.base_url).rstrip("/")
+ banner_html = render_agent_prompt_banner(conn, user=user, server_url=base_url)
ctx = _build_context(
request,
user=user,
conn=conn,
server_url=base_url,
agnes_version=os.environ.get("AGNES_VERSION", "dev"),
+ banner_html=banner_html,
)
return templates.TemplateResponse(request, "install.html", ctx)
@@ -901,14 +905,12 @@ async def admin_agent_prompt_page(
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
):
from src.repositories.welcome_template import WelcomeTemplateRepository
- from src.welcome_template import _load_default_template
row = WelcomeTemplateRepository(conn).get()
ctx = _build_context(
request,
user=user,
current=row["content"] or "",
- default_template=_load_default_template(),
updated_at=row["updated_at"],
updated_by=row["updated_by"],
is_override=row["content"] is not None,
diff --git a/app/web/templates/install.html b/app/web/templates/install.html
index 39ffea6..6a92763 100644
--- a/app/web/templates/install.html
+++ b/app/web/templates/install.html
@@ -628,6 +628,21 @@
.manual-summary { padding: 14px 18px; }
.manual-body { padding: 16px 18px 18px; }
}
+
+ /* ── Admin-configured banner (above setup commands) ── */
+ .setup-banner {
+ background: var(--background, #f6f7f9);
+ border: 1px solid var(--border, #e1e4e8);
+ border-left: 3px solid var(--primary, #0073D1);
+ border-radius: 8px;
+ padding: 14px 18px;
+ margin-bottom: 20px;
+ font-size: 14px;
+ line-height: 1.6;
+ color: var(--text-primary);
+ }
+ .setup-banner > *:first-child { margin-top: 0; }
+ .setup-banner > *:last-child { margin-bottom: 0; }
{% include '_theme.html' %}
@@ -648,6 +663,9 @@
+
+ {% if banner_html %}
{{ banner_html | safe }}
{% endif %}
+
Getting started
diff --git a/cli/commands/analyst.py b/cli/commands/analyst.py
index db37ef6..57f8e3d 100644
--- a/cli/commands/analyst.py
+++ b/cli/commands/analyst.py
@@ -294,51 +294,15 @@ def _install_claude_hooks(settings_path: Path) -> None:
# ---------------------------------------------------------------------------
-# Helper: generate CLAUDE.md from server-rendered template
+# Helper: initialise Claude workspace (.claude/ directory)
# ---------------------------------------------------------------------------
-def _generate_claude_md(workspace: Path, server_url: str, token: str) -> None:
- """Fetch the rendered welcome prompt from the server and write CLAUDE.md.
+def _init_claude_workspace(workspace: Path) -> None:
+ """Initialise the .claude/ directory with placeholder files and hooks.
- Falls back to a minimal embedded template if the server endpoint is
- unavailable (e.g., older server versions before /api/welcome shipped).
+ Does NOT write CLAUDE.md — workspace-context customisation is handled
+ server-side via the banner on /setup, not as a file in the workspace.
"""
- from urllib.parse import quote
-
- server_url = server_url.rstrip("/")
- headers = {"Authorization": f"Bearer {token}"}
- url = f"{server_url}/api/welcome?server_url={quote(server_url, safe='')}"
-
- rendered: Optional[str] = None
- try:
- resp = httpx.get(url, headers=headers, timeout=15.0)
- if resp.status_code == 200:
- rendered = resp.json().get("content")
- elif resp.status_code != 404:
- typer.echo(
- f" Warning: server returned {resp.status_code} for /api/welcome; "
- "using minimal fallback. Tell your admin if this persists.",
- err=True,
- )
- except Exception as e:
- typer.echo(
- f" Warning: couldn't fetch welcome prompt ({e}); using minimal fallback.",
- err=True,
- )
-
- if rendered is None:
- # Fallback for older servers — keeps the CLI usable, just less rich.
- rendered = (
- "# AI Data Analyst\n\n"
- f"This workspace is connected to {server_url}.\n\n"
- "## Rules\n"
- "- Before computing any business metric: run `da metrics show /`\n"
- "- Save work output to `user/artifacts/`\n"
- "- Sync data regularly with `da sync`\n"
- )
-
- (workspace / "CLAUDE.md").write_text(rendered, encoding="utf-8")
-
local_md = workspace / ".claude" / "CLAUDE.local.md"
if not local_md.exists():
local_md.write_text(
@@ -421,9 +385,9 @@ def setup(
typer.echo("Initialising DuckDB views...")
total_rows = _initialize_duckdb(workspace)
- # 7. Generate CLAUDE.md (rendered server-side)
- typer.echo("Fetching welcome prompt from server...")
- _generate_claude_md(workspace, server_url, token)
+ # 7. Initialise Claude workspace (.claude/ hooks + placeholder)
+ typer.echo("Initializing Claude workspace...")
+ _init_claude_workspace(workspace)
# 8. Summary
typer.echo("")
diff --git a/src/welcome_template.py b/src/welcome_template.py
index 595f727..d8a6c3e 100644
--- a/src/welcome_template.py
+++ b/src/welcome_template.py
@@ -1,139 +1,106 @@
-"""Render the analyst-onboarding welcome prompt (CLAUDE.md).
+"""Render the agent-setup-prompt banner shown on /setup.
-Two layers:
- 1. Template source — admin override from welcome_template.content,
- or the shipped default at config/claude_md_template.txt.
- 2. Render context — built from instance config, table_registry,
- metric_definitions, and the calling user's RBAC-filtered marketplaces.
+The banner is a small HTML snippet admin-editable at /admin/agent-prompt.
+It appears above the bash bootstrap commands on the /setup page and is
+intended for org-specific operational notes (VPN warning, support channel,
+data classification reminder, platform requirements).
-The Jinja2 environment uses StrictUndefined so that any typo in the
-template raises immediately rather than rendering empty strings.
+Default: no banner (empty string). Admins override via the welcome_template
+DB table (singleton, content TEXT).
+
+Security: output is HTML-sanitized after render (script/iframe/event-handler
+strip). The Jinja2 environment uses StrictUndefined with autoescape=True so
+template typos raise immediately rather than silently emitting empty HTML.
"""
# See also: surfaced as the "Agent Setup Prompt" admin editor at /admin/agent-prompt.
from __future__ import annotations
+import logging
+import re
from datetime import date, datetime, timezone
-from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import duckdb
-from jinja2 import Environment, StrictUndefined
+from jinja2 import Environment, StrictUndefined, TemplateError
from app.instance_config import (
- get_data_source_type,
get_instance_name,
get_instance_subtitle,
- get_sync_interval,
)
-from src.marketplace_filter import resolve_allowed_plugins
from src.repositories.welcome_template import WelcomeTemplateRepository
-_DEFAULT_TEMPLATE_PATH = (
- Path(__file__).resolve().parent.parent / "config" / "claude_md_template.txt"
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# HTML sanitization
+# ---------------------------------------------------------------------------
+
+_RE_SCRIPT = re.compile(
+ r" blocks (any content)
+ - tags
+ - on*= event handler attributes (onclick=, onload=, etc.)
+ - javascript: / data: URI schemes in href/src/action attributes
"""
- try:
- allowed = resolve_allowed_plugins(conn, user)
- except duckdb.CatalogException:
- return []
- if not allowed:
- return []
+ html = _RE_SCRIPT.sub("", html)
+ html = _RE_IFRAME.sub("", html)
+ html = _RE_ON_EVENT.sub("", html)
+ html = _RE_JS_URI.sub(
+ lambda m: m.group(0).replace(m.group(1), "#"), html
+ )
+ return html
- # Build slug → display name lookup from registry
- slugs = list({p["marketplace_slug"] for p in allowed})
- placeholders = ",".join(["?"] * len(slugs))
- try:
- name_rows = conn.execute(
- f"SELECT id, name FROM marketplace_registry WHERE id IN ({placeholders})",
- slugs,
- ).fetchall()
- slug_to_name: dict[str, str] = {r[0]: r[1] for r in name_rows}
- except duckdb.CatalogException:
- slug_to_name = {}
-
- grouped: dict[str, dict[str, Any]] = {}
- for plugin in allowed:
- slug = plugin["marketplace_slug"]
- bucket = grouped.setdefault(
- slug,
- {
- "slug": slug,
- "name": slug_to_name.get(slug, slug),
- "plugins": [],
- },
- )
- bucket["plugins"].append({"name": plugin["original_name"]})
-
- return list(grouped.values())
+# ---------------------------------------------------------------------------
+# Render context
+# ---------------------------------------------------------------------------
def build_context(
- conn: duckdb.DuckDBPyConnection,
*,
- user: dict[str, Any],
+ user: dict[str, Any] | None,
server_url: str,
) -> dict[str, Any]:
- """Compose the Jinja2 render context. Pure, no side effects.
+ """Compose the Jinja2 render context for the banner.
- Note: ``now`` is tz-aware UTC; DB-sourced timestamps elsewhere in the
- codebase are naive (DuckDB stores ``TIMESTAMP``, not ``TIMESTAMPTZ``).
- Don't subtract or compare them inside templates without normalising.
+ Intentionally small: instance identity, server URL, and the requesting
+ user (may be None for anonymous /setup visitors). No tables, metrics, or
+ marketplaces — the banner is for org-operational notes, not data-catalog
+ content.
+
+ Note: ``now`` is tz-aware UTC.
"""
now = datetime.now(timezone.utc)
parsed = urlparse(server_url)
+ user_ctx: dict[str, Any] | None = None
+ if user:
+ user_ctx = {
+ "id": user.get("id", ""),
+ "email": user.get("email", ""),
+ "name": user.get("name") or "",
+ "is_admin": bool(user.get("is_admin")),
+ "groups": user.get("groups") or [],
+ }
return {
"instance": {
"name": get_instance_name(),
@@ -143,36 +110,44 @@ def build_context(
"url": server_url,
"hostname": parsed.hostname or "",
},
- "sync_interval": get_sync_interval(),
- "data_source": {"type": get_data_source_type()},
- "tables": _list_tables(conn),
- "metrics": _metrics_summary(conn),
- "marketplaces": _marketplaces_for_user(conn, user),
- "user": {
- "id": user.get("id", ""),
- "email": user.get("email", ""),
- "name": user.get("name") or "",
- "is_admin": bool(user.get("is_admin")),
- "groups": user.get("groups") or [],
- },
+ "user": user_ctx,
"now": now,
"today": date.today().isoformat(),
}
-def _resolve_template_source(conn: duckdb.DuckDBPyConnection) -> str:
- row = WelcomeTemplateRepository(conn).get()
- return row["content"] if row.get("content") else _load_default_template()
+# ---------------------------------------------------------------------------
+# Banner renderer
+# ---------------------------------------------------------------------------
-
-def render_welcome(
+def render_agent_prompt_banner(
conn: duckdb.DuckDBPyConnection,
*,
- user: dict[str, Any],
+ user: dict[str, Any] | None,
server_url: str,
) -> str:
- """Resolve the active template and render it for the given user."""
- source = _resolve_template_source(conn)
- env = Environment(undefined=StrictUndefined, autoescape=False)
- template = env.from_string(source)
- return template.render(**build_context(conn, user=user, server_url=server_url))
+ """Render the admin-configured HTML banner for the /setup page.
+
+ Returns an empty string when no override is set (default = no banner).
+ Render failures are swallowed (logged) and return empty string so a
+ broken template never blocks the /setup page from rendering.
+ """
+ row = WelcomeTemplateRepository(conn).get()
+ content = row.get("content")
+ if not content:
+ return ""
+
+ try:
+ env = Environment(undefined=StrictUndefined, autoescape=True)
+ template = env.from_string(content)
+ ctx = build_context(user=user, server_url=server_url)
+ rendered = template.render(**ctx)
+ return _sanitize_banner_html(rendered)
+ except TemplateError as exc:
+ logger.warning(
+ "Agent-prompt banner render failed (template error): %s", exc
+ )
+ return ""
+ except Exception:
+ logger.exception("Agent-prompt banner render failed (unexpected)")
+ return ""
diff --git a/tests/test_analyst_bootstrap.py b/tests/test_analyst_bootstrap.py
index a0715ec..3c4b8f5 100644
--- a/tests/test_analyst_bootstrap.py
+++ b/tests/test_analyst_bootstrap.py
@@ -78,7 +78,7 @@ class TestDetectExistingProject:
patch("cli.commands.analyst._download_metadata"), \
patch("cli.commands.analyst._download_data", return_value=0), \
patch("cli.commands.analyst._initialize_duckdb", return_value=0), \
- patch("cli.commands.analyst._generate_claude_md"):
+ patch("cli.commands.analyst._init_claude_workspace"):
result = runner.invoke(
app,
["analyst", "setup", "--server-url", "http://localhost:8000", "--force"],
@@ -116,59 +116,52 @@ class TestCreateWorkspace:
# ---------------------------------------------------------------------------
-# TestGenerateClaudeMd
+# TestInitClaudeWorkspace
# ---------------------------------------------------------------------------
-class TestGenerateClaudeMd:
- """Server-side render flow: _generate_claude_md fetches /api/welcome.
-
- The local-fallback path is exercised by tests/test_cli_analyst_welcome.py.
- These tests cover the side-effects on the workspace (CLAUDE.local.md,
- settings.json) and verify the new signature is honored.
+class TestInitClaudeWorkspace:
+ """Tests for _init_claude_workspace: no CLAUDE.md written, but
+ .claude/CLAUDE.local.md placeholder and settings.json hooks are created.
"""
- def _patch_httpx_404(self, monkeypatch):
- """Stub httpx.get to return 404 so _generate_claude_md falls back to embedded text."""
- import httpx
-
- def fake_get(url, headers=None, timeout=None):
- return httpx.Response(
- status_code=404, json={}, request=httpx.Request("GET", url)
- )
-
- monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": fake_get}))
-
- def test_creates_claude_local_md_when_absent(self, tmp_workspace, monkeypatch):
- from cli.commands.analyst import _create_workspace, _generate_claude_md
+ def test_does_not_write_claude_md(self, tmp_workspace):
+ from cli.commands.analyst import _create_workspace, _init_claude_workspace
_create_workspace(tmp_workspace)
- self._patch_httpx_404(monkeypatch)
- _generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
+ _init_claude_workspace(tmp_workspace)
+
+ assert not (tmp_workspace / "CLAUDE.md").exists(), (
+ "CLAUDE.md must NOT be written by _init_claude_workspace"
+ )
+
+ def test_creates_claude_local_md_when_absent(self, tmp_workspace):
+ from cli.commands.analyst import _create_workspace, _init_claude_workspace
+
+ _create_workspace(tmp_workspace)
+ _init_claude_workspace(tmp_workspace)
local_md = tmp_workspace / ".claude" / "CLAUDE.local.md"
assert local_md.exists()
assert local_md.read_text(encoding="utf-8").strip() != ""
- def test_does_not_overwrite_existing_local_md(self, tmp_workspace, monkeypatch):
- from cli.commands.analyst import _create_workspace, _generate_claude_md
+ def test_does_not_overwrite_existing_local_md(self, tmp_workspace):
+ from cli.commands.analyst import _create_workspace, _init_claude_workspace
_create_workspace(tmp_workspace)
local_md = tmp_workspace / ".claude" / "CLAUDE.local.md"
original_content = "# My custom notes\n\nDo not overwrite me.\n"
local_md.write_text(original_content, encoding="utf-8")
- self._patch_httpx_404(monkeypatch)
- _generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
+ _init_claude_workspace(tmp_workspace)
assert local_md.read_text(encoding="utf-8") == original_content
- def test_writes_settings_json(self, tmp_workspace, monkeypatch):
- from cli.commands.analyst import _create_workspace, _generate_claude_md
+ def test_writes_settings_json(self, tmp_workspace):
+ from cli.commands.analyst import _create_workspace, _init_claude_workspace
import json as _json
_create_workspace(tmp_workspace)
- self._patch_httpx_404(monkeypatch)
- _generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
+ _init_claude_workspace(tmp_workspace)
settings = _json.loads(
(tmp_workspace / ".claude" / "settings.json").read_text(encoding="utf-8")
@@ -176,6 +169,21 @@ class TestGenerateClaudeMd:
assert settings["model"] == "sonnet"
assert "Read" in settings["permissions"]["allow"]
+ def test_installs_session_hooks(self, tmp_workspace):
+ """SessionStart and SessionEnd hooks must be present in settings.json."""
+ from cli.commands.analyst import _create_workspace, _init_claude_workspace
+ import json as _json
+
+ _create_workspace(tmp_workspace)
+ _init_claude_workspace(tmp_workspace)
+
+ settings = _json.loads(
+ (tmp_workspace / ".claude" / "settings.json").read_text(encoding="utf-8")
+ )
+ hooks = settings.get("hooks", {})
+ assert "SessionStart" in hooks
+ assert "SessionEnd" in hooks
+
# ---------------------------------------------------------------------------
# TestReturningSession
diff --git a/tests/test_cli_analyst_welcome.py b/tests/test_cli_analyst_welcome.py
deleted file mode 100644
index 41edc42..0000000
--- a/tests/test_cli_analyst_welcome.py
+++ /dev/null
@@ -1,89 +0,0 @@
-"""Integration tests for da analyst setup → /api/welcome wiring."""
-
-import json
-from pathlib import Path
-
-import httpx
-import pytest
-
-from cli.commands.analyst import _generate_claude_md
-
-
-class _MockClient:
- def __init__(self, responses):
- self._responses = responses
- self.calls = []
-
- def get(self, url, headers=None, timeout=None):
- self.calls.append(url)
- body, status = self._responses.get(url, ({}, 404))
- return httpx.Response(status_code=status, json=body, request=httpx.Request("GET", url))
-
-
-def _ws(tmp_path: Path) -> Path:
- workspace = tmp_path / "ws"
- (workspace / ".claude").mkdir(parents=True)
- return workspace
-
-
-def test_generate_claude_md_uses_server_render(tmp_path, monkeypatch):
- workspace = _ws(tmp_path)
- rendered = "# CUSTOM\n\nFrom server.\n"
- mock = _MockClient({
- "https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com": (
- {"content": rendered}, 200
- ),
- })
- monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
- _generate_claude_md(workspace, server_url="https://example.com", token="t")
-
- assert (workspace / "CLAUDE.md").read_text(encoding="utf-8") == rendered
- # Workspace side-effects are created on the success path too.
- assert (workspace / ".claude" / "CLAUDE.local.md").exists()
- settings = json.loads((workspace / ".claude" / "settings.json").read_text(encoding="utf-8"))
- assert settings["model"] == "sonnet"
-
-
-def test_generate_claude_md_falls_back_on_404(tmp_path, monkeypatch):
- workspace = _ws(tmp_path)
- mock = _MockClient({}) # everything 404s
- monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
- _generate_claude_md(workspace, server_url="https://example.com", token="t")
- body = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
- assert "AI Data Analyst" in body
- assert "https://example.com" in body
-
-
-def test_generate_claude_md_falls_back_on_null_content(tmp_path, monkeypatch):
- """Server returns 200 but malformed body (`content: null`). CLI must use fallback."""
- workspace = _ws(tmp_path)
- mock = _MockClient({
- "https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com": (
- {"content": None}, 200
- ),
- })
- monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
- _generate_claude_md(workspace, server_url="https://example.com", token="t")
- body = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
- # Embedded fallback contains these literals
- assert "AI Data Analyst" in body
- assert "https://example.com" in body
-
-
-def test_generate_claude_md_warns_on_5xx(tmp_path, monkeypatch, capsys):
- """500 from server → embedded fallback, with a stderr warning so operators can diagnose."""
- workspace = _ws(tmp_path)
- mock = _MockClient({
- "https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com": (
- {"detail": "boom"}, 500
- ),
- })
- monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
- _generate_claude_md(workspace, server_url="https://example.com", token="t")
-
- body = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
- assert "AI Data Analyst" in body # fallback used
-
- captured = capsys.readouterr()
- assert "500" in captured.err
- assert "fallback" in captured.err.lower()
diff --git a/tests/test_welcome_template_api.py b/tests/test_welcome_template_api.py
index 5440865..e883e6f 100644
--- a/tests/test_welcome_template_api.py
+++ b/tests/test_welcome_template_api.py
@@ -1,4 +1,8 @@
-"""End-to-end tests for /api/welcome and /api/admin/welcome-template."""
+"""End-to-end tests for /api/admin/welcome-template (banner editor endpoints).
+
+GET /api/welcome has been removed — the analyst-facing endpoint is gone.
+These tests cover only the admin CRUD + preview endpoints.
+"""
import duckdb
@@ -10,7 +14,8 @@ def _auth(token: str) -> dict[str, str]:
return {"Authorization": f"Bearer {token}"}
-def test_get_welcome_returns_rendered_markdown(seeded_app):
+def test_get_welcome_endpoint_removed(seeded_app):
+ """GET /api/welcome must return 404 — the endpoint was deleted."""
c = seeded_app["client"]
token = seeded_app["analyst_token"]
resp = c.get(
@@ -18,48 +23,39 @@ def test_get_welcome_returns_rendered_markdown(seeded_app):
params={"server_url": "https://example.com"},
headers=_auth(token),
)
- assert resp.status_code == 200
- body = resp.json()
- assert "content" in body
- assert "AI Data Analyst" in body["content"]
- assert "https://example.com" in body["content"]
+ assert resp.status_code == 404
-def test_get_welcome_requires_auth(seeded_app):
+def test_admin_get_template_initially_null(seeded_app):
c = seeded_app["client"]
- resp = c.get("/api/welcome", params={"server_url": "https://example.com"})
- assert resp.status_code == 401
+ admin = _auth(seeded_app["admin_token"])
+
+ r = c.get("/api/admin/welcome-template", headers=admin)
+ assert r.status_code == 200
+ body = r.json()
+ assert body["content"] is None
+ # No longer returns a `default` field — banner default is empty
+ assert "default" not in body or body.get("default") is None
def test_admin_can_set_and_reset_template(seeded_app):
c = seeded_app["client"]
admin = _auth(seeded_app["admin_token"])
- # GET initial state
- r = c.get("/api/admin/welcome-template", headers=admin)
- assert r.status_code == 200
- body = r.json()
- assert body["content"] is None
- # The shipped default starts with the Jinja2 comment block.
- assert body["default"].startswith("{#")
-
# PUT override
r = c.put(
"/api/admin/welcome-template",
- json={"content": "Hello {{ user.email }}"},
+ json={"content": "
Hello {{ user.email }}
"},
headers=admin,
)
assert r.status_code == 200
- # Verify rendered output uses override
- r = c.get(
- "/api/welcome",
- params={"server_url": "https://example.com"},
- headers=admin, # admin user can also call /api/welcome
- )
- assert r.json()["content"].startswith("Hello ")
+ # GET reflects override
+ r = c.get("/api/admin/welcome-template", headers=admin)
+ assert r.status_code == 200
+ assert r.json()["content"] == "
Hello {{ user.email }}
"
- # DELETE = reset
+ # DELETE = reset (no banner)
r = c.delete("/api/admin/welcome-template", headers=admin)
assert r.status_code == 204
r = c.get("/api/admin/welcome-template", headers=admin)
@@ -69,7 +65,7 @@ def test_admin_can_set_and_reset_template(seeded_app):
def test_non_admin_cannot_edit_template(seeded_app):
c = seeded_app["client"]
analyst = _auth(seeded_app["analyst_token"])
- r = c.put("/api/admin/welcome-template", json={"content": "x"}, headers=analyst)
+ r = c.put("/api/admin/welcome-template", json={"content": "
x
"}, headers=analyst)
assert r.status_code == 403
@@ -86,58 +82,29 @@ def test_invalid_jinja2_returns_400(seeded_app):
def test_put_rejects_undefined_placeholder(seeded_app):
- """Templates that parse but reference unknown placeholders must be rejected
- at PUT time so an admin can fix the typo immediately rather than after an
- analyst's bootstrap blows up."""
+ """Templates that reference unknown placeholders must be rejected at PUT time."""
c = seeded_app["client"]
admin = _auth(seeded_app["admin_token"])
r = c.put(
"/api/admin/welcome-template",
- json={"content": "Hello {{ user.emial }}"}, # typo, would fail StrictUndefined at render
+ json={"content": "
{{ user.emial }}
"}, # typo
headers=admin,
)
assert r.status_code == 400
assert "emial" in r.json()["detail"] or "undefined" in r.json()["detail"].lower()
-def test_get_welcome_500_includes_reset_hint_on_render_failure(seeded_app, monkeypatch):
- """If an override slips through validation and fails at render time, the
- user-visible 500 must point at /admin/agent-prompt rather than leaking a
- Jinja stack trace."""
- # Stub render_welcome to raise a TemplateError so we exercise the
- # exception path without needing a malformed override (PUT validation
- # blocks those now).
- from jinja2 import UndefinedError
- import app.api.welcome as welcome_module
-
- def fake_render(*args, **kwargs):
- raise UndefinedError("'foo' is undefined")
-
- monkeypatch.setattr(welcome_module, "render_welcome", fake_render)
-
- c = seeded_app["client"]
- admin = _auth(seeded_app["admin_token"])
- r = c.get(
- "/api/welcome",
- params={"server_url": "https://example.com"},
- headers=admin,
- )
- assert r.status_code == 500
- assert "/admin/agent-prompt" in r.json()["detail"]
-
-
-def test_admin_preview_renders_arbitrary_content(seeded_app):
- """Preview endpoint must render the supplied content (not whatever's
- stored), so the admin UI can show pre-save preview."""
+def test_admin_preview_renders_html(seeded_app):
+ """Preview endpoint renders supplied HTML content without persisting."""
c = seeded_app["client"]
admin = _auth(seeded_app["admin_token"])
r = c.post(
"/api/admin/welcome-template/preview",
- json={"content": "# Preview {{ user.email }}"},
+ json={"content": "
"},
headers=analyst,
)
assert r.status_code == 403
def test_validation_stub_matches_build_context_shape(seeded_app, tmp_path, monkeypatch):
- """If build_context grows new keys, _VALIDATION_STUB_CONTEXT must too —
- otherwise admins can save templates referencing keys the PUT validator
- accepts but the live render rejects."""
+ """_VALIDATION_STUB_CONTEXT top-level keys must match build_context() output.
+
+ If build_context gains new keys, the stub must track them so admins can
+ save templates that reference those keys without hitting a live-render
+ rejection after the PUT validation accepted them.
+ """
from app.api.welcome import _VALIDATION_STUB_CONTEXT
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- db_path = tmp_path / "system.duckdb"
- conn = duckdb.connect(str(db_path))
- _ensure_schema(conn)
-
- user = {"id": "u1", "email": "admin@test.com", "name": "Admin", "is_admin": True, "groups": ["Admin"]}
- real_ctx = build_context(conn, user=user, server_url="https://example.com")
- conn.close()
+ user = {
+ "id": "u1",
+ "email": "admin@test.com",
+ "name": "Admin",
+ "is_admin": True,
+ "groups": ["Admin"],
+ }
+ real_ctx = build_context(user=user, server_url="https://example.com")
# Top-level keys must match
assert set(_VALIDATION_STUB_CONTEXT.keys()) == set(real_ctx.keys()), (
@@ -184,9 +154,13 @@ def test_validation_stub_matches_build_context_shape(seeded_app, tmp_path, monke
f"real has: {set(real_ctx.keys())}"
)
- # One level deep for nested dicts
- for key in ("instance", "server", "user"):
- if isinstance(real_ctx.get(key), dict):
- assert set(_VALIDATION_STUB_CONTEXT[key].keys()) == set(real_ctx[key].keys()), (
- f"_VALIDATION_STUB_CONTEXT[{key!r}] drifted from build_context output"
- )
+ # One level deep for nested dicts (user may be None in real_ctx — compare stub shape)
+ for key in ("instance", "server"):
+ assert set(_VALIDATION_STUB_CONTEXT[key].keys()) == set(real_ctx[key].keys()), (
+ f"_VALIDATION_STUB_CONTEXT[{key!r}] drifted from build_context output"
+ )
+ # user sub-keys
+ if real_ctx.get("user") and _VALIDATION_STUB_CONTEXT.get("user"):
+ assert set(_VALIDATION_STUB_CONTEXT["user"].keys()) == set(real_ctx["user"].keys()), (
+ "_VALIDATION_STUB_CONTEXT['user'] drifted from build_context output"
+ )
diff --git a/tests/test_welcome_template_renderer.py b/tests/test_welcome_template_renderer.py
index 02ce4b1..473a3a8 100644
--- a/tests/test_welcome_template_renderer.py
+++ b/tests/test_welcome_template_renderer.py
@@ -1,14 +1,15 @@
-"""Unit tests for the welcome-prompt renderer."""
-
-import uuid
-from pathlib import Path
+"""Unit tests for the agent-setup-prompt banner renderer."""
import duckdb
import pytest
from src.db import _ensure_schema
from src.repositories.welcome_template import WelcomeTemplateRepository
-from src.welcome_template import build_context, render_welcome
+from src.welcome_template import (
+ _sanitize_banner_html,
+ build_context,
+ render_agent_prompt_banner,
+)
@pytest.fixture
@@ -22,144 +23,188 @@ def conn(tmp_path, monkeypatch):
def _user(email="alice@example.com"):
- return {"id": "u1", "email": email, "name": "Alice", "is_admin": False, "groups": ["Everyone"]}
+ return {
+ "id": "u1",
+ "email": email,
+ "name": "Alice",
+ "is_admin": False,
+ "groups": ["Everyone"],
+ }
-def test_renders_default_when_no_override(conn):
- out = render_welcome(conn, user=_user(), server_url="https://example.com")
- assert "AI Data Analyst" in out
- assert "https://example.com" in out
- assert "Alice" in out
+# ---------------------------------------------------------------------------
+# Default (no override) → empty string
+# ---------------------------------------------------------------------------
+def test_returns_empty_when_no_override(conn):
+ out = render_agent_prompt_banner(conn, user=_user(), server_url="https://example.com")
+ assert out == ""
+
+
+# ---------------------------------------------------------------------------
+# Override renders correctly
+# ---------------------------------------------------------------------------
def test_renders_override(conn):
WelcomeTemplateRepository(conn).set(
- "# {{ instance.name }} for {{ user.email }}",
+ "
Welcome to {{ instance.name }}!
",
updated_by="admin@example.com",
)
- out = render_welcome(conn, user=_user(), server_url="https://example.com")
- assert out.startswith("# AI Data Analyst for alice@example.com")
+ out = render_agent_prompt_banner(conn, user=_user(), server_url="https://example.com")
+ assert "
Welcome to" in out
+ # instance.name comes from instance_config — any non-empty string is fine
+ assert "!" in out
-def test_strict_undefined_raises_on_missing_placeholder(conn):
+def test_renders_user_placeholder(conn):
+ WelcomeTemplateRepository(conn).set(
+ "
Hello {{ user.email }}
",
+ updated_by="admin@example.com",
+ )
+ out = render_agent_prompt_banner(
+ conn, user=_user("bob@example.com"), server_url="https://example.com"
+ )
+ assert "bob@example.com" in out
+
+
+def test_renders_server_placeholder(conn):
+ WelcomeTemplateRepository(conn).set(
+ "
Server: {{ server.url }}
",
+ updated_by="admin@example.com",
+ )
+ out = render_agent_prompt_banner(
+ conn, user=_user(), server_url="https://myserver.example.com"
+ )
+ assert "https://myserver.example.com" in out
+
+
+# ---------------------------------------------------------------------------
+# Anonymous user (user=None)
+# ---------------------------------------------------------------------------
+
+def test_renders_with_anonymous_user(conn):
+ WelcomeTemplateRepository(conn).set(
+ "{% if user %}
Hi {{ user.email }}
{% else %}
Please sign in.
{% endif %}",
+ updated_by="admin@example.com",
+ )
+ out = render_agent_prompt_banner(conn, user=None, server_url="https://example.com")
+ assert "Please sign in." in out
+ assert "Hi" not in out
+
+
+def test_returns_empty_for_none_user_with_no_override(conn):
+ out = render_agent_prompt_banner(conn, user=None, server_url="https://example.com")
+ assert out == ""
+
+
+# ---------------------------------------------------------------------------
+# Build context shape
+# ---------------------------------------------------------------------------
+
+def test_context_exposes_documented_keys():
+ ctx = build_context(user=_user(), server_url="https://example.com")
+ for key in ("instance", "server", "user", "now", "today"):
+ assert key in ctx, f"missing context key: {key}"
+ assert "tables" not in ctx
+ assert "metrics" not in ctx
+ assert "marketplaces" not in ctx
+ assert "sync_interval" not in ctx
+ assert "data_source" not in ctx
+
+
+def test_context_user_none():
+ ctx = build_context(user=None, server_url="https://example.com")
+ assert ctx["user"] is None
+
+
+def test_context_instance_keys():
+ ctx = build_context(user=_user(), server_url="https://example.com")
+ assert "name" in ctx["instance"]
+ assert "subtitle" in ctx["instance"]
+
+
+def test_context_server_keys():
+ ctx = build_context(user=_user(), server_url="https://example.com")
+ assert ctx["server"]["url"] == "https://example.com"
+ assert ctx["server"]["hostname"] == "example.com"
+
+
+# ---------------------------------------------------------------------------
+# HTML sanitization
+# ---------------------------------------------------------------------------
+
+def test_sanitize_strips_script_tag():
+ html = '
Hello
'
+ result = _sanitize_banner_html(html)
+ assert "
ok
'
+ result = _sanitize_banner_html(html)
+ assert "evil" not in result
+ assert "
ok
" in result
+
+
+def test_sanitize_strips_iframe():
+ html = '
text
'
+ result = _sanitize_banner_html(html)
+ assert "
" in result
+
+
+def test_sanitize_strips_event_handlers():
+ html = ''
+ result = _sanitize_banner_html(html)
+ assert "onclick" not in result
+ assert "evil" not in result
+ assert "Click me" in result
+
+
+def test_sanitize_strips_onload_on_img():
+ html = ''
+ result = _sanitize_banner_html(html)
+ assert "onload" not in result
+ assert "steal" not in result
+
+
+def test_sanitize_strips_javascript_uri():
+ html = 'click'
+ result = _sanitize_banner_html(html)
+ assert "javascript:" not in result
+
+
+def test_sanitize_allows_safe_html():
+ html = "