feat(admin-prompt): variant C — banner on /setup, drop CLAUDE.md generation

- src/welcome_template.py: rewrite as HTML banner renderer
  (render_agent_prompt_banner); drop _list_tables, _metrics_summary,
  _marketplaces_for_user, render_welcome, _load_default_template.
  build_context now exposes only instance/server/user/now/today.
  _sanitize_banner_html strips script/iframe/on*/javascript: post-render.
- app/api/welcome.py: drop get_welcome handler, WelcomeResponse, old
  _VALIDATION_STUB_CONTEXT. Admin endpoints stay at same URLs; validation
  stub updated to match new slim context. Preview now uses autoescape=True.
- app/web/router.py: setup_page calls render_agent_prompt_banner and passes
  banner_html to install.html; admin_agent_prompt_page drops _load_default_template.
- app/web/templates/install.html: add .setup-banner CSS + banner block above hero.
- cli/commands/analyst.py: replace _generate_claude_md with _init_claude_workspace;
  no CLAUDE.md written, only .claude/CLAUDE.local.md placeholder + settings.json hooks.
- tests: delete test_cli_analyst_welcome.py (it tested the deleted endpoint/function);
  rewrite TestGenerateClaudeMd → TestInitClaudeWorkspace; update api test to
  assert /api/welcome returns 404 and remove welcome-fetch tests.
This commit is contained in:
ZdenekSrotyr 2026-05-02 22:18:12 +02:00
parent 60386b9c3c
commit 8db4c1645b
9 changed files with 416 additions and 540 deletions

View file

@ -1,9 +1,9 @@
"""REST endpoints for the analyst-onboarding welcome prompt.
"""REST endpoints for the agent-setup-prompt banner.
- GET /api/welcome : render for the calling user (auth required)
- GET /api/admin/welcome-template : raw template + shipped default (admin)
- GET /api/admin/welcome-template : raw template override (admin)
- PUT /api/admin/welcome-template : set override (admin)
- DELETE /api/admin/welcome-template : reset to default (admin)
- DELETE /api/admin/welcome-template : reset to default / no banner (admin)
- POST /api/admin/welcome-template/preview : live preview without persisting (admin)
"""
import datetime
@ -11,14 +11,14 @@ import logging
from typing import Optional
import duckdb
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
from jinja2 import Environment, StrictUndefined, TemplateError, TemplateSyntaxError
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from jinja2 import Environment, StrictUndefined, TemplateError
from pydantic import BaseModel, Field
from app.auth.access import require_admin
from app.auth.dependencies import _get_db, get_current_user
from app.auth.dependencies import _get_db
from src.repositories.welcome_template import WelcomeTemplateRepository
from src.welcome_template import _load_default_template, build_context, render_welcome
from src.welcome_template import build_context, render_agent_prompt_banner
logger = logging.getLogger(__name__)
@ -27,16 +27,11 @@ router = APIRouter(tags=["welcome"])
# Stub context used to validate that a saved template renders end-to-end,
# not just that it parses. Mirrors the shape of build_context() output.
# user may be None for anonymous visitors; the stub uses an authenticated
# user so templates that reference user.* fields are validated.
_VALIDATION_STUB_CONTEXT = {
"instance": {"name": "Example", "subtitle": "Example Org"},
"server": {"url": "https://example.com", "hostname": "example.com"},
"sync_interval": "1 hour",
"data_source": {"type": "local"},
"tables": [{"name": "example", "description": "", "query_mode": "local"}],
"metrics": {"count": 0, "categories": []},
"marketplaces": [
{"slug": "example", "name": "Example Marketplace", "plugins": [{"name": "x"}]}
],
"user": {
"id": "u",
"email": "user@example.com",
@ -49,13 +44,12 @@ _VALIDATION_STUB_CONTEXT = {
}
class WelcomeResponse(BaseModel):
class BannerResponse(BaseModel):
content: str
class TemplateGetResponse(BaseModel):
content: Optional[str]
default: str
updated_at: Optional[str] = None
updated_by: Optional[str] = None
@ -68,24 +62,6 @@ class TemplatePreviewRequest(BaseModel):
content: str = Field(..., min_length=1, max_length=200_000)
@router.get("/api/welcome", response_model=WelcomeResponse)
async def get_welcome(
    server_url: str = Query(..., description="The server URL the analyst is bootstrapping against"),
    user: dict = Depends(get_current_user),
    conn: duckdb.DuckDBPyConnection = Depends(_get_db),
):
    """Render the welcome prompt for the calling user.

    Delegates to ``render_welcome`` with the authenticated user and the
    ``server_url`` query parameter, and wraps the result in a
    ``WelcomeResponse``.

    Raises:
        HTTPException: status 500 when rendering raises a Jinja2
            ``TemplateError``. The detail points at /admin/agent-prompt
            instead of exposing the template error itself.
    """
    try:
        rendered = render_welcome(conn, user=user, server_url=server_url)
    except TemplateError as e:
        logger.warning("Welcome render failed: %s", e, exc_info=True)
        # Chain explicitly so the traceback shows the template error as the
        # direct cause rather than as an error raised "during handling".
        raise HTTPException(
            status_code=500,
            detail="Welcome template render failed. An admin can fix it at /admin/agent-prompt.",
        ) from e
    return WelcomeResponse(content=rendered)
@router.get("/api/admin/welcome-template", response_model=TemplateGetResponse)
async def admin_get_template(
user: dict = Depends(require_admin),
@ -94,7 +70,6 @@ async def admin_get_template(
row = WelcomeTemplateRepository(conn).get()
return TemplateGetResponse(
content=row["content"],
default=_load_default_template(),
updated_at=row["updated_at"].isoformat() if row["updated_at"] else None,
updated_by=row["updated_by"],
)
@ -106,11 +81,13 @@ async def admin_put_template(
user: dict = Depends(require_admin),
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
):
env = Environment(undefined=StrictUndefined)
# Validate with autoescape=True (matches runtime environment) and
# StrictUndefined so unknown placeholders are caught at save time.
env = Environment(undefined=StrictUndefined, autoescape=True)
try:
template = env.from_string(payload.content)
# Render against a stub context so undefined placeholders or runtime
# errors are caught here, not when an analyst calls /api/welcome.
# errors are caught here, not when /setup renders for a real user.
template.render(**_VALIDATION_STUB_CONTEXT)
except TemplateError as e:
raise HTTPException(status_code=400, detail=f"Template invalid: {e}")
@ -127,7 +104,7 @@ async def admin_reset_template(
return Response(status_code=204)
@router.post("/api/admin/welcome-template/preview", response_model=WelcomeResponse)
@router.post("/api/admin/welcome-template/preview", response_model=BannerResponse)
async def admin_preview_template(
payload: TemplatePreviewRequest,
request: Request,
@ -137,11 +114,13 @@ async def admin_preview_template(
"""Render arbitrary template content against the live context for the
calling admin, without persisting. Used by the /admin/agent-prompt editor's
Preview button so admins can see their edits before saving."""
env = Environment(undefined=StrictUndefined, autoescape=False)
env = Environment(undefined=StrictUndefined, autoescape=True)
try:
template = env.from_string(payload.content)
ctx = build_context(conn, user=user, server_url=str(request.base_url).rstrip("/"))
ctx = build_context(
user=user, server_url=str(request.base_url).rstrip("/")
)
rendered = template.render(**ctx)
except TemplateError as e:
raise HTTPException(status_code=400, detail=f"Template invalid: {e}")
return WelcomeResponse(content=rendered)
return BannerResponse(content=rendered)

View file

@ -727,13 +727,17 @@ async def setup_page(
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
):
"""Setup instructions for the local agent (CLI + Claude Code)."""
from src.welcome_template import render_agent_prompt_banner
base_url = str(request.base_url).rstrip("/")
banner_html = render_agent_prompt_banner(conn, user=user, server_url=base_url)
ctx = _build_context(
request,
user=user,
conn=conn,
server_url=base_url,
agnes_version=os.environ.get("AGNES_VERSION", "dev"),
banner_html=banner_html,
)
return templates.TemplateResponse(request, "install.html", ctx)
@ -901,14 +905,12 @@ async def admin_agent_prompt_page(
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
):
from src.repositories.welcome_template import WelcomeTemplateRepository
from src.welcome_template import _load_default_template
row = WelcomeTemplateRepository(conn).get()
ctx = _build_context(
request,
user=user,
current=row["content"] or "",
default_template=_load_default_template(),
updated_at=row["updated_at"],
updated_by=row["updated_by"],
is_override=row["content"] is not None,

View file

@ -628,6 +628,21 @@
.manual-summary { padding: 14px 18px; }
.manual-body { padding: 16px 18px 18px; }
}
/* ── Admin-configured banner (above setup commands) ── */
/* Box that wraps the admin-authored banner. The background and border
   colours read theme variables and fall back to the literal value given
   as the second argument of var(). */
.setup-banner {
background: var(--background, #f6f7f9);
border: 1px solid var(--border, #e1e4e8);
/* Overrides the left edge only: 3px in the primary colour. */
border-left: 3px solid var(--primary, #0073D1);
border-radius: 8px;
padding: 14px 18px;
margin-bottom: 20px;
font-size: 14px;
line-height: 1.6;
/* No literal fallback is given for this variable, unlike the ones above. */
color: var(--text-primary);
}
/* Zero the top margin of the first direct child and the bottom margin of
   the last direct child of the banner. */
.setup-banner > *:first-child { margin-top: 0; }
.setup-banner > *:last-child { margin-bottom: 0; }
</style>
{% include '_theme.html' %}
</head>
@ -648,6 +663,9 @@
<main class="main">
<!-- ═══════════════ ADMIN BANNER (optional) ═══════════════ -->
{% if banner_html %}<div class="setup-banner">{{ banner_html | safe }}</div>{% endif %}
<!-- ═══════════════ HERO ═══════════════ -->
<section class="hero">
<div class="hero-eyebrow">Getting started</div>

View file

@ -294,51 +294,15 @@ def _install_claude_hooks(settings_path: Path) -> None:
# ---------------------------------------------------------------------------
# Helper: generate CLAUDE.md from server-rendered template
# Helper: initialise Claude workspace (.claude/ directory)
# ---------------------------------------------------------------------------
def _generate_claude_md(workspace: Path, server_url: str, token: str) -> None:
"""Fetch the rendered welcome prompt from the server and write CLAUDE.md.
def _init_claude_workspace(workspace: Path) -> None:
"""Initialise the .claude/ directory with placeholder files and hooks.
Falls back to a minimal embedded template if the server endpoint is
unavailable (e.g., older server versions before /api/welcome shipped).
Does NOT write CLAUDE.md — workspace-context customisation is handled
server-side via the banner on /setup, not as a file in the workspace.
"""
from urllib.parse import quote
server_url = server_url.rstrip("/")
headers = {"Authorization": f"Bearer {token}"}
url = f"{server_url}/api/welcome?server_url={quote(server_url, safe='')}"
rendered: Optional[str] = None
try:
resp = httpx.get(url, headers=headers, timeout=15.0)
if resp.status_code == 200:
rendered = resp.json().get("content")
elif resp.status_code != 404:
typer.echo(
f" Warning: server returned {resp.status_code} for /api/welcome; "
"using minimal fallback. Tell your admin if this persists.",
err=True,
)
except Exception as e:
typer.echo(
f" Warning: couldn't fetch welcome prompt ({e}); using minimal fallback.",
err=True,
)
if rendered is None:
# Fallback for older servers — keeps the CLI usable, just less rich.
rendered = (
"# AI Data Analyst\n\n"
f"This workspace is connected to {server_url}.\n\n"
"## Rules\n"
"- Before computing any business metric: run `da metrics show <category>/<name>`\n"
"- Save work output to `user/artifacts/`\n"
"- Sync data regularly with `da sync`\n"
)
(workspace / "CLAUDE.md").write_text(rendered, encoding="utf-8")
local_md = workspace / ".claude" / "CLAUDE.local.md"
if not local_md.exists():
local_md.write_text(
@ -421,9 +385,9 @@ def setup(
typer.echo("Initialising DuckDB views...")
total_rows = _initialize_duckdb(workspace)
# 7. Generate CLAUDE.md (rendered server-side)
typer.echo("Fetching welcome prompt from server...")
_generate_claude_md(workspace, server_url, token)
# 7. Initialise Claude workspace (.claude/ hooks + placeholder)
typer.echo("Initializing Claude workspace...")
_init_claude_workspace(workspace)
# 8. Summary
typer.echo("")

View file

@ -1,139 +1,106 @@
"""Render the analyst-onboarding welcome prompt (CLAUDE.md).
"""Render the agent-setup-prompt banner shown on /setup.
Two layers:
1. Template source — admin override from welcome_template.content,
or the shipped default at config/claude_md_template.txt.
2. Render context — built from instance config, table_registry,
metric_definitions, and the calling user's RBAC-filtered marketplaces.
The banner is a small HTML snippet admin-editable at /admin/agent-prompt.
It appears above the bash bootstrap commands on the /setup page and is
intended for org-specific operational notes (VPN warning, support channel,
data classification reminder, platform requirements).
The Jinja2 environment uses StrictUndefined so that any typo in the
template raises immediately rather than rendering empty strings.
Default: no banner (empty string). Admins override via the welcome_template
DB table (singleton, content TEXT).
Security: output is HTML-sanitized after render (script/iframe/event-handler
strip). The Jinja2 environment uses StrictUndefined with autoescape=True so
template typos raise immediately rather than silently emitting empty HTML.
"""
# See also: surfaced as the "Agent Setup Prompt" admin editor at /admin/agent-prompt.
from __future__ import annotations
import logging
import re
from datetime import date, datetime, timezone
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import duckdb
from jinja2 import Environment, StrictUndefined
from jinja2 import Environment, StrictUndefined, TemplateError
from app.instance_config import (
get_data_source_type,
get_instance_name,
get_instance_subtitle,
get_sync_interval,
)
from src.marketplace_filter import resolve_allowed_plugins
from src.repositories.welcome_template import WelcomeTemplateRepository
_DEFAULT_TEMPLATE_PATH = (
Path(__file__).resolve().parent.parent / "config" / "claude_md_template.txt"
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# HTML sanitization
# ---------------------------------------------------------------------------
_RE_SCRIPT = re.compile(
r"<script[\s\S]*?</script>", re.IGNORECASE
)
_RE_IFRAME = re.compile(
r"<iframe[\s\S]*?(?:</iframe>|/>)", re.IGNORECASE
)
_RE_ON_EVENT = re.compile(
r"\s+on\w+\s*=\s*(?:\"[^\"]*\"|'[^']*'|[^\s>]+)", re.IGNORECASE
)
_RE_JS_URI = re.compile(
r"""(?:href|src|action)\s*=\s*(?:"|')(javascript:|data:)""", re.IGNORECASE
)
def _load_default_template() -> str:
    """Return the default template text.

    Reads ``_DEFAULT_TEMPLATE_PATH`` as UTF-8. If the file is not there, a
    short embedded template is returned instead.

    The read is attempted directly rather than guarded by ``exists()``, so
    a file that disappears between the check and the read can no longer
    raise out of this function.
    """
    try:
        return _DEFAULT_TEMPLATE_PATH.read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        # Last-resort embedded fallback if the OSS template file is missing
        # from the install (e.g., partial Docker COPY).
        return (
            "# {{ instance.name }} — AI Data Analyst\n\n"
            "This workspace is connected to {{ server.url }}.\n"
            "Data refreshes every {{ sync_interval }}.\n"
        )
def _sanitize_banner_html(html: str) -> str:
"""Strip dangerous constructs from admin-authored HTML.
Defense-in-depth only — admins are trusted, but this prevents accidental
XSS from copy-pasted snippets reaching the public /setup page.
def _list_tables(conn: duckdb.DuckDBPyConnection) -> list[dict[str, Any]]:
    """Return every row of ``table_registry`` as a dict, ordered by name.

    Each dict has the keys ``name``, ``description`` and ``query_mode``.
    A falsy description becomes ``""`` and a falsy query mode becomes
    ``"local"``. An empty list is returned when the query raises
    ``duckdb.CatalogException``.
    """
    query = """SELECT name, description, query_mode
FROM table_registry
ORDER BY name"""
    try:
        fetched = conn.execute(query).fetchall()
    except duckdb.CatalogException:
        return []

    tables: list[dict[str, Any]] = []
    for record in fetched:
        tables.append(
            {
                "name": record[0],
                "description": record[1] or "",
                "query_mode": record[2] or "local",
            }
        )
    return tables
def _metrics_summary(conn: duckdb.DuckDBPyConnection) -> dict[str, Any]:
    """Summarise ``metric_definitions`` as a total count plus category names.

    ``count`` is the sum of the per-category row counts; ``categories`` is
    the sorted list of distinct truthy category values. When the query
    raises ``duckdb.CatalogException`` the summary is zero / empty.
    """
    try:
        per_category = conn.execute(
            "SELECT category, COUNT(*) FROM metric_definitions GROUP BY category"
        ).fetchall()
    except duckdb.CatalogException:
        return {"count": 0, "categories": []}

    total = 0
    seen = set()
    for record in per_category:
        total += record[1]
        if record[0]:
            seen.add(record[0])
    return {"count": total, "categories": sorted(seen)}
def _marketplaces_for_user(
conn: duckdb.DuckDBPyConnection, user: dict[str, Any]
) -> list[dict[str, Any]]:
"""Return marketplaces with the plugins the user is allowed to see.
Delegates RBAC filtering entirely to resolve_allowed_plugins, which
returns List[dict] with marketplace_slug, original_name, etc.
Results are grouped by marketplace slug; display names are fetched
from marketplace_registry in a single query.
Strips:
- <script>...</script> blocks (any content)
- <iframe>...</iframe> tags
- on*= event handler attributes (onclick=, onload=, etc.)
- javascript: / data: URI schemes in href/src/action attributes
"""
try:
allowed = resolve_allowed_plugins(conn, user)
except duckdb.CatalogException:
return []
if not allowed:
return []
html = _RE_SCRIPT.sub("", html)
html = _RE_IFRAME.sub("", html)
html = _RE_ON_EVENT.sub("", html)
html = _RE_JS_URI.sub(
lambda m: m.group(0).replace(m.group(1), "#"), html
)
return html
# Build slug → display name lookup from registry
slugs = list({p["marketplace_slug"] for p in allowed})
placeholders = ",".join(["?"] * len(slugs))
try:
name_rows = conn.execute(
f"SELECT id, name FROM marketplace_registry WHERE id IN ({placeholders})",
slugs,
).fetchall()
slug_to_name: dict[str, str] = {r[0]: r[1] for r in name_rows}
except duckdb.CatalogException:
slug_to_name = {}
grouped: dict[str, dict[str, Any]] = {}
for plugin in allowed:
slug = plugin["marketplace_slug"]
bucket = grouped.setdefault(
slug,
{
"slug": slug,
"name": slug_to_name.get(slug, slug),
"plugins": [],
},
)
bucket["plugins"].append({"name": plugin["original_name"]})
return list(grouped.values())
# ---------------------------------------------------------------------------
# Render context
# ---------------------------------------------------------------------------
def build_context(
conn: duckdb.DuckDBPyConnection,
*,
user: dict[str, Any],
user: dict[str, Any] | None,
server_url: str,
) -> dict[str, Any]:
"""Compose the Jinja2 render context. Pure, no side effects.
"""Compose the Jinja2 render context for the banner.
Note: ``now`` is tz-aware UTC; DB-sourced timestamps elsewhere in the
codebase are naive (DuckDB stores ``TIMESTAMP``, not ``TIMESTAMPTZ``).
Don't subtract or compare them inside templates without normalising.
Intentionally small: instance identity, server URL, and the requesting
user (may be None for anonymous /setup visitors). No tables, metrics, or
marketplaces — the banner is for org-operational notes, not data-catalog
content.
Note: ``now`` is tz-aware UTC.
"""
now = datetime.now(timezone.utc)
parsed = urlparse(server_url)
user_ctx: dict[str, Any] | None = None
if user:
user_ctx = {
"id": user.get("id", ""),
"email": user.get("email", ""),
"name": user.get("name") or "",
"is_admin": bool(user.get("is_admin")),
"groups": user.get("groups") or [],
}
return {
"instance": {
"name": get_instance_name(),
@ -143,36 +110,44 @@ def build_context(
"url": server_url,
"hostname": parsed.hostname or "",
},
"sync_interval": get_sync_interval(),
"data_source": {"type": get_data_source_type()},
"tables": _list_tables(conn),
"metrics": _metrics_summary(conn),
"marketplaces": _marketplaces_for_user(conn, user),
"user": {
"id": user.get("id", ""),
"email": user.get("email", ""),
"name": user.get("name") or "",
"is_admin": bool(user.get("is_admin")),
"groups": user.get("groups") or [],
},
"user": user_ctx,
"now": now,
"today": date.today().isoformat(),
}
def _resolve_template_source(conn: duckdb.DuckDBPyConnection) -> str:
    """Pick the template text to render.

    Returns the stored ``content`` from ``WelcomeTemplateRepository`` when
    it is truthy, otherwise the result of ``_load_default_template()``.
    """
    stored = WelcomeTemplateRepository(conn).get()
    if not stored.get("content"):
        return _load_default_template()
    return stored["content"]
# ---------------------------------------------------------------------------
# Banner renderer
# ---------------------------------------------------------------------------
def render_welcome(
def render_agent_prompt_banner(
conn: duckdb.DuckDBPyConnection,
*,
user: dict[str, Any],
user: dict[str, Any] | None,
server_url: str,
) -> str:
"""Resolve the active template and render it for the given user."""
source = _resolve_template_source(conn)
env = Environment(undefined=StrictUndefined, autoescape=False)
template = env.from_string(source)
return template.render(**build_context(conn, user=user, server_url=server_url))
"""Render the admin-configured HTML banner for the /setup page.
Returns an empty string when no override is set (default = no banner).
Render failures are swallowed (logged) and return empty string so a
broken template never blocks the /setup page from rendering.
"""
row = WelcomeTemplateRepository(conn).get()
content = row.get("content")
if not content:
return ""
try:
env = Environment(undefined=StrictUndefined, autoescape=True)
template = env.from_string(content)
ctx = build_context(user=user, server_url=server_url)
rendered = template.render(**ctx)
return _sanitize_banner_html(rendered)
except TemplateError as exc:
logger.warning(
"Agent-prompt banner render failed (template error): %s", exc
)
return ""
except Exception:
logger.exception("Agent-prompt banner render failed (unexpected)")
return ""

View file

@ -78,7 +78,7 @@ class TestDetectExistingProject:
patch("cli.commands.analyst._download_metadata"), \
patch("cli.commands.analyst._download_data", return_value=0), \
patch("cli.commands.analyst._initialize_duckdb", return_value=0), \
patch("cli.commands.analyst._generate_claude_md"):
patch("cli.commands.analyst._init_claude_workspace"):
result = runner.invoke(
app,
["analyst", "setup", "--server-url", "http://localhost:8000", "--force"],
@ -116,59 +116,52 @@ class TestCreateWorkspace:
# ---------------------------------------------------------------------------
# TestGenerateClaudeMd
# TestInitClaudeWorkspace
# ---------------------------------------------------------------------------
class TestGenerateClaudeMd:
"""Server-side render flow: _generate_claude_md fetches /api/welcome.
The local-fallback path is exercised by tests/test_cli_analyst_welcome.py.
These tests cover the side-effects on the workspace (CLAUDE.local.md,
settings.json) and verify the new signature is honored.
class TestInitClaudeWorkspace:
"""Tests for _init_claude_workspace: no CLAUDE.md written, but
.claude/CLAUDE.local.md placeholder and settings.json hooks are created.
"""
def _patch_httpx_404(self, monkeypatch):
"""Stub httpx.get to return 404 so _generate_claude_md falls back to embedded text."""
import httpx
def fake_get(url, headers=None, timeout=None):
return httpx.Response(
status_code=404, json={}, request=httpx.Request("GET", url)
)
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": fake_get}))
def test_creates_claude_local_md_when_absent(self, tmp_workspace, monkeypatch):
from cli.commands.analyst import _create_workspace, _generate_claude_md
def test_does_not_write_claude_md(self, tmp_workspace):
from cli.commands.analyst import _create_workspace, _init_claude_workspace
_create_workspace(tmp_workspace)
self._patch_httpx_404(monkeypatch)
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
_init_claude_workspace(tmp_workspace)
assert not (tmp_workspace / "CLAUDE.md").exists(), (
"CLAUDE.md must NOT be written by _init_claude_workspace"
)
def test_creates_claude_local_md_when_absent(self, tmp_workspace):
from cli.commands.analyst import _create_workspace, _init_claude_workspace
_create_workspace(tmp_workspace)
_init_claude_workspace(tmp_workspace)
local_md = tmp_workspace / ".claude" / "CLAUDE.local.md"
assert local_md.exists()
assert local_md.read_text(encoding="utf-8").strip() != ""
def test_does_not_overwrite_existing_local_md(self, tmp_workspace, monkeypatch):
from cli.commands.analyst import _create_workspace, _generate_claude_md
def test_does_not_overwrite_existing_local_md(self, tmp_workspace):
from cli.commands.analyst import _create_workspace, _init_claude_workspace
_create_workspace(tmp_workspace)
local_md = tmp_workspace / ".claude" / "CLAUDE.local.md"
original_content = "# My custom notes\n\nDo not overwrite me.\n"
local_md.write_text(original_content, encoding="utf-8")
self._patch_httpx_404(monkeypatch)
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
_init_claude_workspace(tmp_workspace)
assert local_md.read_text(encoding="utf-8") == original_content
def test_writes_settings_json(self, tmp_workspace, monkeypatch):
from cli.commands.analyst import _create_workspace, _generate_claude_md
def test_writes_settings_json(self, tmp_workspace):
from cli.commands.analyst import _create_workspace, _init_claude_workspace
import json as _json
_create_workspace(tmp_workspace)
self._patch_httpx_404(monkeypatch)
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
_init_claude_workspace(tmp_workspace)
settings = _json.loads(
(tmp_workspace / ".claude" / "settings.json").read_text(encoding="utf-8")
@ -176,6 +169,21 @@ class TestGenerateClaudeMd:
assert settings["model"] == "sonnet"
assert "Read" in settings["permissions"]["allow"]
def test_installs_session_hooks(self, tmp_workspace):
    """SessionStart and SessionEnd hooks must be present in settings.json."""
    from cli.commands.analyst import _create_workspace, _init_claude_workspace
    import json as _json

    _create_workspace(tmp_workspace)
    _init_claude_workspace(tmp_workspace)

    settings_path = tmp_workspace / ".claude" / "settings.json"
    settings = _json.loads(settings_path.read_text(encoding="utf-8"))
    installed = settings.get("hooks", {})
    for event in ("SessionStart", "SessionEnd"):
        assert event in installed
# ---------------------------------------------------------------------------
# TestReturningSession

View file

@ -1,89 +0,0 @@
"""Integration tests for da analyst setup → /api/welcome wiring."""
import json
from pathlib import Path
import httpx
import pytest
from cli.commands.analyst import _generate_claude_md
class _MockClient:
def __init__(self, responses):
self._responses = responses
self.calls = []
def get(self, url, headers=None, timeout=None):
self.calls.append(url)
body, status = self._responses.get(url, ({}, 404))
return httpx.Response(status_code=status, json=body, request=httpx.Request("GET", url))
def _ws(tmp_path: Path) -> Path:
workspace = tmp_path / "ws"
(workspace / ".claude").mkdir(parents=True)
return workspace
def test_generate_claude_md_uses_server_render(tmp_path, monkeypatch):
    """A 200 response's ``content`` is written verbatim to CLAUDE.md."""
    workspace = _ws(tmp_path)
    expected = "# CUSTOM\n\nFrom server.\n"
    welcome_url = "https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com"
    client = _MockClient({welcome_url: ({"content": expected}, 200)})
    monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": client.get}))

    _generate_claude_md(workspace, server_url="https://example.com", token="t")

    assert (workspace / "CLAUDE.md").read_text(encoding="utf-8") == expected
    # Workspace side-effects are created on the success path too.
    claude_dir = workspace / ".claude"
    assert (claude_dir / "CLAUDE.local.md").exists()
    settings = json.loads((claude_dir / "settings.json").read_text(encoding="utf-8"))
    assert settings["model"] == "sonnet"
def test_generate_claude_md_falls_back_on_404(tmp_path, monkeypatch):
    """With every request answered 404, CLAUDE.md holds the fallback text."""
    workspace = _ws(tmp_path)
    client = _MockClient({})  # everything 404s
    monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": client.get}))

    _generate_claude_md(workspace, server_url="https://example.com", token="t")

    written = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
    for needle in ("AI Data Analyst", "https://example.com"):
        assert needle in written
def test_generate_claude_md_falls_back_on_null_content(tmp_path, monkeypatch):
    """Server returns 200 but malformed body (`content: null`). CLI must use fallback."""
    workspace = _ws(tmp_path)
    welcome_url = "https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com"
    client = _MockClient({welcome_url: ({"content": None}, 200)})
    monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": client.get}))

    _generate_claude_md(workspace, server_url="https://example.com", token="t")

    written = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
    # Embedded fallback contains these literals
    for needle in ("AI Data Analyst", "https://example.com"):
        assert needle in written
def test_generate_claude_md_warns_on_5xx(tmp_path, monkeypatch, capsys):
    """500 from server → embedded fallback, with a stderr warning so operators can diagnose."""
    workspace = _ws(tmp_path)
    welcome_url = "https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com"
    client = _MockClient({welcome_url: ({"detail": "boom"}, 500)})
    monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": client.get}))

    _generate_claude_md(workspace, server_url="https://example.com", token="t")

    written = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
    assert "AI Data Analyst" in written  # fallback used

    stderr = capsys.readouterr().err
    assert "500" in stderr
    assert "fallback" in stderr.lower()

View file

@ -1,4 +1,8 @@
"""End-to-end tests for /api/welcome and /api/admin/welcome-template."""
"""End-to-end tests for /api/admin/welcome-template (banner editor endpoints).
GET /api/welcome has been removed — the analyst-facing endpoint is gone.
These tests cover only the admin CRUD + preview endpoints.
"""
import duckdb
@ -10,7 +14,8 @@ def _auth(token: str) -> dict[str, str]:
return {"Authorization": f"Bearer {token}"}
def test_get_welcome_returns_rendered_markdown(seeded_app):
def test_get_welcome_endpoint_removed(seeded_app):
"""GET /api/welcome must return 404 — the endpoint was deleted."""
c = seeded_app["client"]
token = seeded_app["analyst_token"]
resp = c.get(
@ -18,48 +23,39 @@ def test_get_welcome_returns_rendered_markdown(seeded_app):
params={"server_url": "https://example.com"},
headers=_auth(token),
)
assert resp.status_code == 200
body = resp.json()
assert "content" in body
assert "AI Data Analyst" in body["content"]
assert "https://example.com" in body["content"]
assert resp.status_code == 404
def test_get_welcome_requires_auth(seeded_app):
def test_admin_get_template_initially_null(seeded_app):
c = seeded_app["client"]
resp = c.get("/api/welcome", params={"server_url": "https://example.com"})
assert resp.status_code == 401
admin = _auth(seeded_app["admin_token"])
r = c.get("/api/admin/welcome-template", headers=admin)
assert r.status_code == 200
body = r.json()
assert body["content"] is None
# No longer returns a `default` field — banner default is empty
assert "default" not in body or body.get("default") is None
def test_admin_can_set_and_reset_template(seeded_app):
c = seeded_app["client"]
admin = _auth(seeded_app["admin_token"])
# GET initial state
r = c.get("/api/admin/welcome-template", headers=admin)
assert r.status_code == 200
body = r.json()
assert body["content"] is None
# The shipped default starts with the Jinja2 comment block.
assert body["default"].startswith("{#")
# PUT override
r = c.put(
"/api/admin/welcome-template",
json={"content": "Hello {{ user.email }}"},
json={"content": "<p>Hello {{ user.email }}</p>"},
headers=admin,
)
assert r.status_code == 200
# Verify rendered output uses override
r = c.get(
"/api/welcome",
params={"server_url": "https://example.com"},
headers=admin, # admin user can also call /api/welcome
)
assert r.json()["content"].startswith("Hello ")
# GET reflects override
r = c.get("/api/admin/welcome-template", headers=admin)
assert r.status_code == 200
assert r.json()["content"] == "<p>Hello {{ user.email }}</p>"
# DELETE = reset
# DELETE = reset (no banner)
r = c.delete("/api/admin/welcome-template", headers=admin)
assert r.status_code == 204
r = c.get("/api/admin/welcome-template", headers=admin)
@ -69,7 +65,7 @@ def test_admin_can_set_and_reset_template(seeded_app):
def test_non_admin_cannot_edit_template(seeded_app):
c = seeded_app["client"]
analyst = _auth(seeded_app["analyst_token"])
r = c.put("/api/admin/welcome-template", json={"content": "x"}, headers=analyst)
r = c.put("/api/admin/welcome-template", json={"content": "<p>x</p>"}, headers=analyst)
assert r.status_code == 403
@ -86,58 +82,29 @@ def test_invalid_jinja2_returns_400(seeded_app):
def test_put_rejects_undefined_placeholder(seeded_app):
"""Templates that parse but reference unknown placeholders must be rejected
at PUT time so an admin can fix the typo immediately rather than after an
analyst's bootstrap blows up."""
"""Templates that reference unknown placeholders must be rejected at PUT time."""
c = seeded_app["client"]
admin = _auth(seeded_app["admin_token"])
r = c.put(
"/api/admin/welcome-template",
json={"content": "Hello {{ user.emial }}"}, # typo, would fail StrictUndefined at render
json={"content": "<p>{{ user.emial }}</p>"}, # typo
headers=admin,
)
assert r.status_code == 400
assert "emial" in r.json()["detail"] or "undefined" in r.json()["detail"].lower()
def test_get_welcome_500_includes_reset_hint_on_render_failure(seeded_app, monkeypatch):
"""If an override slips through validation and fails at render time, the
user-visible 500 must point at /admin/agent-prompt rather than leaking a
Jinja stack trace."""
# Stub render_welcome to raise a TemplateError so we exercise the
# exception path without needing a malformed override (PUT validation
# blocks those now).
from jinja2 import UndefinedError
import app.api.welcome as welcome_module
def fake_render(*args, **kwargs):
raise UndefinedError("'foo' is undefined")
monkeypatch.setattr(welcome_module, "render_welcome", fake_render)
c = seeded_app["client"]
admin = _auth(seeded_app["admin_token"])
r = c.get(
"/api/welcome",
params={"server_url": "https://example.com"},
headers=admin,
)
assert r.status_code == 500
assert "/admin/agent-prompt" in r.json()["detail"]
def test_admin_preview_renders_arbitrary_content(seeded_app):
"""Preview endpoint must render the supplied content (not whatever's
stored), so the admin UI can show pre-save preview."""
def test_admin_preview_renders_html(seeded_app):
    """Preview endpoint renders supplied HTML content without persisting."""
    c = seeded_app["client"]
    admin = _auth(seeded_app["admin_token"])
    r = c.post(
        "/api/admin/welcome-template/preview",
        json={"content": "<p>Preview for {{ user.email }}</p>"},
        headers=admin,
    )
    assert r.status_code == 200
    # The placeholder must be substituted with the calling admin's email.
    assert r.json()["content"].startswith("<p>Preview for admin@test.com")
def test_preview_rejects_invalid_template(seeded_app):
@ -156,26 +123,29 @@ def test_preview_requires_admin(seeded_app):
analyst = _auth(seeded_app["analyst_token"])
r = c.post(
"/api/admin/welcome-template/preview",
json={"content": "# x"},
json={"content": "<p>x</p>"},
headers=analyst,
)
assert r.status_code == 403
def test_validation_stub_matches_build_context_shape(seeded_app, tmp_path, monkeypatch):
"""If build_context grows new keys, _VALIDATION_STUB_CONTEXT must too —
otherwise admins can save templates referencing keys the PUT validator
accepts but the live render rejects."""
"""_VALIDATION_STUB_CONTEXT top-level keys must match build_context() output.
If build_context gains new keys, the stub must track them so admins can
save templates that reference those keys without hitting a live-render
rejection after the PUT validation accepted them.
"""
from app.api.welcome import _VALIDATION_STUB_CONTEXT
monkeypatch.setenv("DATA_DIR", str(tmp_path))
db_path = tmp_path / "system.duckdb"
conn = duckdb.connect(str(db_path))
_ensure_schema(conn)
user = {"id": "u1", "email": "admin@test.com", "name": "Admin", "is_admin": True, "groups": ["Admin"]}
real_ctx = build_context(conn, user=user, server_url="https://example.com")
conn.close()
user = {
"id": "u1",
"email": "admin@test.com",
"name": "Admin",
"is_admin": True,
"groups": ["Admin"],
}
real_ctx = build_context(user=user, server_url="https://example.com")
# Top-level keys must match
assert set(_VALIDATION_STUB_CONTEXT.keys()) == set(real_ctx.keys()), (
@ -184,9 +154,13 @@ def test_validation_stub_matches_build_context_shape(seeded_app, tmp_path, monke
f"real has: {set(real_ctx.keys())}"
)
# One level deep for nested dicts
for key in ("instance", "server", "user"):
if isinstance(real_ctx.get(key), dict):
assert set(_VALIDATION_STUB_CONTEXT[key].keys()) == set(real_ctx[key].keys()), (
f"_VALIDATION_STUB_CONTEXT[{key!r}] drifted from build_context output"
)
# One level deep for nested dicts (user may be None in real_ctx — compare stub shape)
for key in ("instance", "server"):
assert set(_VALIDATION_STUB_CONTEXT[key].keys()) == set(real_ctx[key].keys()), (
f"_VALIDATION_STUB_CONTEXT[{key!r}] drifted from build_context output"
)
# user sub-keys
if real_ctx.get("user") and _VALIDATION_STUB_CONTEXT.get("user"):
assert set(_VALIDATION_STUB_CONTEXT["user"].keys()) == set(real_ctx["user"].keys()), (
"_VALIDATION_STUB_CONTEXT['user'] drifted from build_context output"
)

View file

@ -1,14 +1,15 @@
"""Unit tests for the welcome-prompt renderer."""
import uuid
from pathlib import Path
"""Unit tests for the agent-setup-prompt banner renderer."""
import duckdb
import pytest
from src.db import _ensure_schema
from src.repositories.welcome_template import WelcomeTemplateRepository
from src.welcome_template import build_context, render_welcome
from src.welcome_template import (
_sanitize_banner_html,
build_context,
render_agent_prompt_banner,
)
@pytest.fixture
@ -22,144 +23,188 @@ def conn(tmp_path, monkeypatch):
def _user(email="alice@example.com"):
return {"id": "u1", "email": email, "name": "Alice", "is_admin": False, "groups": ["Everyone"]}
return {
"id": "u1",
"email": email,
"name": "Alice",
"is_admin": False,
"groups": ["Everyone"],
}
def test_renders_default_when_no_override(conn):
out = render_welcome(conn, user=_user(), server_url="https://example.com")
assert "AI Data Analyst" in out
assert "https://example.com" in out
assert "Alice" in out
# ---------------------------------------------------------------------------
# Default (no override) → empty string
# ---------------------------------------------------------------------------
def test_returns_empty_when_no_override(conn):
    """With no override stored, the banner renderer must return ``""``."""
    banner = render_agent_prompt_banner(
        conn,
        user=_user(),
        server_url="https://example.com",
    )
    assert banner == ""
# ---------------------------------------------------------------------------
# Override renders correctly
# ---------------------------------------------------------------------------
def test_renders_override(conn):
WelcomeTemplateRepository(conn).set(
"# {{ instance.name }} for {{ user.email }}",
"<p>Welcome to {{ instance.name }}!</p>",
updated_by="admin@example.com",
)
out = render_welcome(conn, user=_user(), server_url="https://example.com")
assert out.startswith("# AI Data Analyst for alice@example.com")
out = render_agent_prompt_banner(conn, user=_user(), server_url="https://example.com")
assert "<p>Welcome to" in out
# instance.name comes from instance_config — any non-empty string is fine
assert "!" in out
def test_strict_undefined_raises_on_missing_placeholder(conn):
def test_renders_user_placeholder(conn):
    """The rendered banner must contain the email of the user passed in."""
    repo = WelcomeTemplateRepository(conn)
    repo.set("<p>Hello {{ user.email }}</p>", updated_by="admin@example.com")
    bob = _user("bob@example.com")
    banner = render_agent_prompt_banner(
        conn, user=bob, server_url="https://example.com"
    )
    assert "bob@example.com" in banner
def test_renders_server_placeholder(conn):
    """The rendered banner must contain the ``server_url`` passed in."""
    repo = WelcomeTemplateRepository(conn)
    repo.set("<p>Server: {{ server.url }}</p>", updated_by="admin@example.com")
    banner = render_agent_prompt_banner(
        conn,
        user=_user(),
        server_url="https://myserver.example.com",
    )
    assert "https://myserver.example.com" in banner
# ---------------------------------------------------------------------------
# Anonymous user (user=None)
# ---------------------------------------------------------------------------
def test_renders_with_anonymous_user(conn):
    """With ``user=None`` only the template's ``{% else %}`` text may appear."""
    template = "{% if user %}<p>Hi {{ user.email }}</p>{% else %}<p>Please sign in.</p>{% endif %}"
    WelcomeTemplateRepository(conn).set(template, updated_by="admin@example.com")
    banner = render_agent_prompt_banner(
        conn, user=None, server_url="https://example.com"
    )
    assert "Please sign in." in banner
    assert "Hi" not in banner
def test_returns_empty_for_none_user_with_no_override(conn):
    """No override plus ``user=None`` must still yield an empty string."""
    banner = render_agent_prompt_banner(
        conn,
        user=None,
        server_url="https://example.com",
    )
    assert banner == ""
# ---------------------------------------------------------------------------
# Build context shape
# ---------------------------------------------------------------------------
def test_context_exposes_documented_keys():
    """The context has the five documented keys and none of the listed extras."""
    context = build_context(user=_user(), server_url="https://example.com")
    documented = ("instance", "server", "user", "now", "today")
    for key in documented:
        assert key in context, f"missing context key: {key}"
    # Keys that must not appear in the context.
    absent = ("tables", "metrics", "marketplaces", "sync_interval", "data_source")
    for name in absent:
        assert name not in context
def test_context_user_none():
    """Passing ``user=None`` must surface as ``None`` under the ``user`` key."""
    context = build_context(user=None, server_url="https://example.com")
    assert context["user"] is None
def test_context_instance_keys():
    """``ctx["instance"]`` must provide both ``name`` and ``subtitle``."""
    instance = build_context(user=_user(), server_url="https://example.com")["instance"]
    assert "name" in instance
    assert "subtitle" in instance
def test_context_server_keys():
    """``ctx["server"]`` must carry the given URL and its hostname."""
    server = build_context(user=_user(), server_url="https://example.com")["server"]
    assert server["url"] == "https://example.com"
    assert server["hostname"] == "example.com"
# ---------------------------------------------------------------------------
# HTML sanitization
# ---------------------------------------------------------------------------
def test_sanitize_strips_script_tag():
    """A ``<script>`` element and its body are removed; the sibling ``<p>`` stays."""
    cleaned = _sanitize_banner_html('<p>Hello</p><script>alert("xss")</script>')
    for banned in ("<script>", "alert"):
        assert banned not in cleaned
    assert "<p>Hello</p>" in cleaned
def test_sanitize_strips_script_with_attributes():
    """A ``<script>`` tag that carries attributes is stripped together with its body."""
    dirty = '<script type="text/javascript">evil()</script><p>ok</p>'
    cleaned = _sanitize_banner_html(dirty)
    assert "evil" not in cleaned
    assert "<p>ok</p>" in cleaned
def test_sanitize_strips_iframe():
    """``<iframe>`` elements are removed while surrounding markup is kept."""
    dirty = '<p>text</p><iframe src="https://evil.example.com"></iframe>'
    cleaned = _sanitize_banner_html(dirty)
    assert "<iframe" not in cleaned
    assert "<p>text</p>" in cleaned
def test_sanitize_strips_event_handlers():
    """An inline ``onclick`` handler and its value are dropped; the label text stays."""
    cleaned = _sanitize_banner_html('<button onclick="evil()">Click me</button>')
    for banned in ("onclick", "evil"):
        assert banned not in cleaned
    assert "Click me" in cleaned
def test_sanitize_strips_onload_on_img():
    """An ``onload`` attribute on ``<img>`` and its handler code are removed."""
    cleaned = _sanitize_banner_html('<img src="x" onload="steal()" alt="test">')
    for banned in ("onload", "steal"):
        assert banned not in cleaned
def test_sanitize_strips_javascript_uri():
    """A ``javascript:`` URI in an ``href`` must not survive sanitization."""
    dirty = '<a href="javascript:alert(1)">click</a>'
    cleaned = _sanitize_banner_html(dirty)
    assert "javascript:" not in cleaned
def test_sanitize_allows_safe_html():
    """Plain paragraph and link markup passes through the sanitizer."""
    benign = "<p>VPN required. Contact <a href='https://support.example.com'>support</a>.</p>"
    cleaned = _sanitize_banner_html(benign)
    for kept in ("<p>", "<a href", "support"):
        assert kept in cleaned
# ---------------------------------------------------------------------------
# Render failure → empty string (not exception)
# ---------------------------------------------------------------------------
def test_render_failure_returns_empty_not_exception(conn):
    """A stored template that fails to render must produce ``""``, not an exception."""
    # StrictUndefined: referencing an unknown variable raises at render time.
    WelcomeTemplateRepository(conn).set(
        "{{ does_not_exist }}", updated_by="admin@example.com"
    )
    out = render_agent_prompt_banner(
        conn, user=_user(), server_url="https://example.com"
    )
    # Must return empty string, not raise
    assert out == ""
def test_context_exposes_documented_keys(conn):
ctx = build_context(conn, user=_user(), server_url="https://example.com")
for top in ("instance", "server", "sync_interval", "data_source",
"tables", "metrics", "marketplaces", "user", "now", "today"):
assert top in ctx, f"missing top-level key: {top}"
def test_render_marketplaces_filtered_by_rbac(conn, monkeypatch):
"""Two users with different group memberships render different marketplace lists."""
from app.resource_types import ResourceType
# ── Seed two marketplaces ────────────────────────────────────────────
conn.execute(
"""INSERT INTO marketplace_registry (id, name, url) VALUES
('mkt-a', 'Marketplace A', 'https://github.com/example/mkt-a'),
('mkt-b', 'Marketplace B', 'https://github.com/example/mkt-b')"""
)
# Two plugins per marketplace
for mkt, plugins in [("mkt-a", ["plugin-1", "plugin-2"]), ("mkt-b", ["plugin-3", "plugin-4"])]:
for p in plugins:
conn.execute(
"INSERT INTO marketplace_plugins (marketplace_id, name) VALUES (?, ?)",
[mkt, p],
)
# ── Seed two non-system groups ──────────────────────────────────────
gid_a = str(uuid.uuid4())
gid_b = str(uuid.uuid4())
conn.execute(
"INSERT INTO user_groups (id, name) VALUES (?, ?), (?, ?)",
[gid_a, "group-a", gid_b, "group-b"],
)
# ── Grant mkt-a/* to group-a and mkt-b/* to group-b ─────────────────
rtype = ResourceType.MARKETPLACE_PLUGIN.value
for mkt, gid, plugins in [
("mkt-a", gid_a, ["plugin-1", "plugin-2"]),
("mkt-b", gid_b, ["plugin-3", "plugin-4"]),
]:
for p in plugins:
conn.execute(
"INSERT INTO resource_grants (id, group_id, resource_type, resource_id) "
"VALUES (?, ?, ?, ?)",
[str(uuid.uuid4()), gid, rtype, f"{mkt}/{p}"],
)
# ── Seed two users, each in their own group + Everyone ───────────────
everyone_gid = conn.execute(
"SELECT id FROM user_groups WHERE name = 'Everyone'"
).fetchone()[0]
conn.execute(
"INSERT INTO users (id, email, name, active) VALUES "
"('user-a', 'user-a@example.com', 'User A', TRUE), "
"('user-b', 'user-b@example.com', 'User B', TRUE)"
)
for uid, gid in [("user-a", gid_a), ("user-b", gid_b)]:
conn.execute(
"INSERT INTO user_group_members (user_id, group_id, source) VALUES (?, ?, ?)",
[uid, gid, "admin"],
)
conn.execute(
"INSERT INTO user_group_members (user_id, group_id, source) VALUES (?, ?, ?)",
[uid, everyone_gid, "system_seed"],
)
# ── Render for each user ─────────────────────────────────────────────
def test_sanitize_applied_after_render(conn):
"""A template that produces <script> output is sanitized before return."""
WelcomeTemplateRepository(conn).set(
"{% for m in marketplaces %}{{ m.slug }}: "
"{% for p in m.plugins %}{{ p.name }} {% endfor %}{% endfor %}",
"<script>evil()</script><p>safe content</p>",
updated_by="admin@example.com",
)
user_a = {"id": "user-a", "email": "user-a@example.com", "name": "User A", "is_admin": False, "groups": ["group-a"]}
user_b = {"id": "user-b", "email": "user-b@example.com", "name": "User B", "is_admin": False, "groups": ["group-b"]}
out_a = render_welcome(conn, user=user_a, server_url="https://example.com")
out_b = render_welcome(conn, user=user_b, server_url="https://example.com")
# user-a sees mkt-a plugins only
assert "mkt-a" in out_a
assert "plugin-1" in out_a
assert "mkt-b" not in out_a
assert "plugin-3" not in out_a
# user-b sees mkt-b plugins only
assert "mkt-b" in out_b
assert "plugin-3" in out_b
assert "mkt-a" not in out_b
assert "plugin-1" not in out_b
def test_render_tolerates_missing_optional_tables(tmp_path, monkeypatch):
"""A bare DuckDB without table_registry / marketplace_registry must still render."""
monkeypatch.setenv("DATA_DIR", str(tmp_path))
db_path = tmp_path / "bare.duckdb"
bare = duckdb.connect(str(db_path))
# Only seed the welcome_template singleton manually; no other tables.
bare.execute(
"""CREATE TABLE welcome_template (
id INTEGER PRIMARY KEY DEFAULT 1,
content TEXT,
updated_at TIMESTAMP,
updated_by VARCHAR
)"""
)
bare.execute("INSERT INTO welcome_template (id, content) VALUES (1, NULL)")
out = render_welcome(bare, user=_user(), server_url="https://example.com")
bare.close()
assert "AI Data Analyst" in out # default template still renders
# No tables → "_No tables registered yet_" branch from the default template
assert "No tables registered yet" in out
out = render_agent_prompt_banner(conn, user=_user(), server_url="https://example.com")
assert "<script>" not in out
assert "evil" not in out
assert "<p>safe content</p>" in out