feat(admin-prompt): variant C — banner on /setup, drop CLAUDE.md generation
- src/welcome_template.py: rewrite as HTML banner renderer (render_agent_prompt_banner); drop _list_tables, _metrics_summary, _marketplaces_for_user, render_welcome, _load_default_template. build_context now exposes only instance/server/user/now/today. _sanitize_banner_html strips script/iframe/on*/javascript: post-render. - app/api/welcome.py: drop get_welcome handler, WelcomeResponse, old _VALIDATION_STUB_CONTEXT. Admin endpoints stay at same URLs; validation stub updated to match new slim context. Preview now uses autoescape=True. - app/web/router.py: setup_page calls render_agent_prompt_banner and passes banner_html to install.html; admin_agent_prompt_page drops _load_default_template. - app/web/templates/install.html: add .setup-banner CSS + banner block above hero. - cli/commands/analyst.py: replace _generate_claude_md with _init_claude_workspace; no CLAUDE.md written, only .claude/CLAUDE.local.md placeholder + settings.json hooks. - tests: delete test_cli_analyst_welcome.py (tests deleted endpoint/function); rewrite TestGenerateClaudeMd → TestInitClaudeWorkspace; update api test to assert /api/welcome returns 404 and remove welcome-fetch tests.
This commit is contained in:
parent
60386b9c3c
commit
8db4c1645b
9 changed files with 416 additions and 540 deletions
|
|
@ -1,9 +1,9 @@
|
|||
"""REST endpoints for the analyst-onboarding welcome prompt.
|
||||
"""REST endpoints for the agent-setup-prompt banner.
|
||||
|
||||
- GET /api/welcome : render for the calling user (auth required)
|
||||
- GET /api/admin/welcome-template : raw template + shipped default (admin)
|
||||
- GET /api/admin/welcome-template : raw template override (admin)
|
||||
- PUT /api/admin/welcome-template : set override (admin)
|
||||
- DELETE /api/admin/welcome-template : reset to default (admin)
|
||||
- DELETE /api/admin/welcome-template : reset to default / no banner (admin)
|
||||
- POST /api/admin/welcome-template/preview : live preview without persisting (admin)
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
|
@ -11,14 +11,14 @@ import logging
|
|||
from typing import Optional
|
||||
|
||||
import duckdb
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
|
||||
from jinja2 import Environment, StrictUndefined, TemplateError, TemplateSyntaxError
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from jinja2 import Environment, StrictUndefined, TemplateError
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.auth.access import require_admin
|
||||
from app.auth.dependencies import _get_db, get_current_user
|
||||
from app.auth.dependencies import _get_db
|
||||
from src.repositories.welcome_template import WelcomeTemplateRepository
|
||||
from src.welcome_template import _load_default_template, build_context, render_welcome
|
||||
from src.welcome_template import build_context, render_agent_prompt_banner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -27,16 +27,11 @@ router = APIRouter(tags=["welcome"])
|
|||
|
||||
# Stub context used to validate that a saved template renders end-to-end,
|
||||
# not just that it parses. Mirrors the shape of build_context() output.
|
||||
# user may be None for anonymous visitors; the stub uses an authenticated
|
||||
# user so templates that reference user.* fields are validated.
|
||||
_VALIDATION_STUB_CONTEXT = {
|
||||
"instance": {"name": "Example", "subtitle": "Example Org"},
|
||||
"server": {"url": "https://example.com", "hostname": "example.com"},
|
||||
"sync_interval": "1 hour",
|
||||
"data_source": {"type": "local"},
|
||||
"tables": [{"name": "example", "description": "", "query_mode": "local"}],
|
||||
"metrics": {"count": 0, "categories": []},
|
||||
"marketplaces": [
|
||||
{"slug": "example", "name": "Example Marketplace", "plugins": [{"name": "x"}]}
|
||||
],
|
||||
"user": {
|
||||
"id": "u",
|
||||
"email": "user@example.com",
|
||||
|
|
@ -49,13 +44,12 @@ _VALIDATION_STUB_CONTEXT = {
|
|||
}
|
||||
|
||||
|
||||
class WelcomeResponse(BaseModel):
|
||||
class BannerResponse(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class TemplateGetResponse(BaseModel):
|
||||
content: Optional[str]
|
||||
default: str
|
||||
updated_at: Optional[str] = None
|
||||
updated_by: Optional[str] = None
|
||||
|
||||
|
|
@ -68,24 +62,6 @@ class TemplatePreviewRequest(BaseModel):
|
|||
content: str = Field(..., min_length=1, max_length=200_000)
|
||||
|
||||
|
||||
@router.get("/api/welcome", response_model=WelcomeResponse)
|
||||
async def get_welcome(
|
||||
server_url: str = Query(..., description="The server URL the analyst is bootstrapping against"),
|
||||
user: dict = Depends(get_current_user),
|
||||
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
|
||||
):
|
||||
"""Render the welcome prompt for the calling user. Returns rendered markdown."""
|
||||
try:
|
||||
rendered = render_welcome(conn, user=user, server_url=server_url)
|
||||
except TemplateError as e:
|
||||
logger.warning("Welcome render failed: %s", e, exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Welcome template render failed. An admin can fix it at /admin/agent-prompt.",
|
||||
)
|
||||
return WelcomeResponse(content=rendered)
|
||||
|
||||
|
||||
@router.get("/api/admin/welcome-template", response_model=TemplateGetResponse)
|
||||
async def admin_get_template(
|
||||
user: dict = Depends(require_admin),
|
||||
|
|
@ -94,7 +70,6 @@ async def admin_get_template(
|
|||
row = WelcomeTemplateRepository(conn).get()
|
||||
return TemplateGetResponse(
|
||||
content=row["content"],
|
||||
default=_load_default_template(),
|
||||
updated_at=row["updated_at"].isoformat() if row["updated_at"] else None,
|
||||
updated_by=row["updated_by"],
|
||||
)
|
||||
|
|
@ -106,11 +81,13 @@ async def admin_put_template(
|
|||
user: dict = Depends(require_admin),
|
||||
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
|
||||
):
|
||||
env = Environment(undefined=StrictUndefined)
|
||||
# Validate with autoescape=True (matches runtime environment) and
|
||||
# StrictUndefined so unknown placeholders are caught at save time.
|
||||
env = Environment(undefined=StrictUndefined, autoescape=True)
|
||||
try:
|
||||
template = env.from_string(payload.content)
|
||||
# Render against a stub context so undefined placeholders or runtime
|
||||
# errors are caught here, not when an analyst calls /api/welcome.
|
||||
# errors are caught here, not when /setup renders for a real user.
|
||||
template.render(**_VALIDATION_STUB_CONTEXT)
|
||||
except TemplateError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Template invalid: {e}")
|
||||
|
|
@ -127,7 +104,7 @@ async def admin_reset_template(
|
|||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/api/admin/welcome-template/preview", response_model=WelcomeResponse)
|
||||
@router.post("/api/admin/welcome-template/preview", response_model=BannerResponse)
|
||||
async def admin_preview_template(
|
||||
payload: TemplatePreviewRequest,
|
||||
request: Request,
|
||||
|
|
@ -137,11 +114,13 @@ async def admin_preview_template(
|
|||
"""Render arbitrary template content against the live context for the
|
||||
calling admin, without persisting. Used by the /admin/agent-prompt editor's
|
||||
Preview button so admins can see their edits before saving."""
|
||||
env = Environment(undefined=StrictUndefined, autoescape=False)
|
||||
env = Environment(undefined=StrictUndefined, autoescape=True)
|
||||
try:
|
||||
template = env.from_string(payload.content)
|
||||
ctx = build_context(conn, user=user, server_url=str(request.base_url).rstrip("/"))
|
||||
ctx = build_context(
|
||||
user=user, server_url=str(request.base_url).rstrip("/")
|
||||
)
|
||||
rendered = template.render(**ctx)
|
||||
except TemplateError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Template invalid: {e}")
|
||||
return WelcomeResponse(content=rendered)
|
||||
return BannerResponse(content=rendered)
|
||||
|
|
|
|||
|
|
@ -727,13 +727,17 @@ async def setup_page(
|
|||
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
|
||||
):
|
||||
"""Setup instructions for the local agent (CLI + Claude Code)."""
|
||||
from src.welcome_template import render_agent_prompt_banner
|
||||
|
||||
base_url = str(request.base_url).rstrip("/")
|
||||
banner_html = render_agent_prompt_banner(conn, user=user, server_url=base_url)
|
||||
ctx = _build_context(
|
||||
request,
|
||||
user=user,
|
||||
conn=conn,
|
||||
server_url=base_url,
|
||||
agnes_version=os.environ.get("AGNES_VERSION", "dev"),
|
||||
banner_html=banner_html,
|
||||
)
|
||||
return templates.TemplateResponse(request, "install.html", ctx)
|
||||
|
||||
|
|
@ -901,14 +905,12 @@ async def admin_agent_prompt_page(
|
|||
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
|
||||
):
|
||||
from src.repositories.welcome_template import WelcomeTemplateRepository
|
||||
from src.welcome_template import _load_default_template
|
||||
|
||||
row = WelcomeTemplateRepository(conn).get()
|
||||
ctx = _build_context(
|
||||
request,
|
||||
user=user,
|
||||
current=row["content"] or "",
|
||||
default_template=_load_default_template(),
|
||||
updated_at=row["updated_at"],
|
||||
updated_by=row["updated_by"],
|
||||
is_override=row["content"] is not None,
|
||||
|
|
|
|||
|
|
@ -628,6 +628,21 @@
|
|||
.manual-summary { padding: 14px 18px; }
|
||||
.manual-body { padding: 16px 18px 18px; }
|
||||
}
|
||||
|
||||
/* ── Admin-configured banner (above setup commands) ── */
|
||||
.setup-banner {
|
||||
background: var(--background, #f6f7f9);
|
||||
border: 1px solid var(--border, #e1e4e8);
|
||||
border-left: 3px solid var(--primary, #0073D1);
|
||||
border-radius: 8px;
|
||||
padding: 14px 18px;
|
||||
margin-bottom: 20px;
|
||||
font-size: 14px;
|
||||
line-height: 1.6;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
.setup-banner > *:first-child { margin-top: 0; }
|
||||
.setup-banner > *:last-child { margin-bottom: 0; }
|
||||
</style>
|
||||
{% include '_theme.html' %}
|
||||
</head>
|
||||
|
|
@ -648,6 +663,9 @@
|
|||
|
||||
<main class="main">
|
||||
|
||||
<!-- ═══════════════ ADMIN BANNER (optional) ═══════════════ -->
|
||||
{% if banner_html %}<div class="setup-banner">{{ banner_html | safe }}</div>{% endif %}
|
||||
|
||||
<!-- ═══════════════ HERO ═══════════════ -->
|
||||
<section class="hero">
|
||||
<div class="hero-eyebrow">Getting started</div>
|
||||
|
|
|
|||
|
|
@ -294,51 +294,15 @@ def _install_claude_hooks(settings_path: Path) -> None:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: generate CLAUDE.md from server-rendered template
|
||||
# Helper: initialise Claude workspace (.claude/ directory)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _generate_claude_md(workspace: Path, server_url: str, token: str) -> None:
|
||||
"""Fetch the rendered welcome prompt from the server and write CLAUDE.md.
|
||||
def _init_claude_workspace(workspace: Path) -> None:
|
||||
"""Initialise the .claude/ directory with placeholder files and hooks.
|
||||
|
||||
Falls back to a minimal embedded template if the server endpoint is
|
||||
unavailable (e.g., older server versions before /api/welcome shipped).
|
||||
Does NOT write CLAUDE.md — workspace-context customisation is handled
|
||||
server-side via the banner on /setup, not as a file in the workspace.
|
||||
"""
|
||||
from urllib.parse import quote
|
||||
|
||||
server_url = server_url.rstrip("/")
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
url = f"{server_url}/api/welcome?server_url={quote(server_url, safe='')}"
|
||||
|
||||
rendered: Optional[str] = None
|
||||
try:
|
||||
resp = httpx.get(url, headers=headers, timeout=15.0)
|
||||
if resp.status_code == 200:
|
||||
rendered = resp.json().get("content")
|
||||
elif resp.status_code != 404:
|
||||
typer.echo(
|
||||
f" Warning: server returned {resp.status_code} for /api/welcome; "
|
||||
"using minimal fallback. Tell your admin if this persists.",
|
||||
err=True,
|
||||
)
|
||||
except Exception as e:
|
||||
typer.echo(
|
||||
f" Warning: couldn't fetch welcome prompt ({e}); using minimal fallback.",
|
||||
err=True,
|
||||
)
|
||||
|
||||
if rendered is None:
|
||||
# Fallback for older servers — keeps the CLI usable, just less rich.
|
||||
rendered = (
|
||||
"# AI Data Analyst\n\n"
|
||||
f"This workspace is connected to {server_url}.\n\n"
|
||||
"## Rules\n"
|
||||
"- Before computing any business metric: run `da metrics show <category>/<name>`\n"
|
||||
"- Save work output to `user/artifacts/`\n"
|
||||
"- Sync data regularly with `da sync`\n"
|
||||
)
|
||||
|
||||
(workspace / "CLAUDE.md").write_text(rendered, encoding="utf-8")
|
||||
|
||||
local_md = workspace / ".claude" / "CLAUDE.local.md"
|
||||
if not local_md.exists():
|
||||
local_md.write_text(
|
||||
|
|
@ -421,9 +385,9 @@ def setup(
|
|||
typer.echo("Initialising DuckDB views...")
|
||||
total_rows = _initialize_duckdb(workspace)
|
||||
|
||||
# 7. Generate CLAUDE.md (rendered server-side)
|
||||
typer.echo("Fetching welcome prompt from server...")
|
||||
_generate_claude_md(workspace, server_url, token)
|
||||
# 7. Initialise Claude workspace (.claude/ hooks + placeholder)
|
||||
typer.echo("Initializing Claude workspace...")
|
||||
_init_claude_workspace(workspace)
|
||||
|
||||
# 8. Summary
|
||||
typer.echo("")
|
||||
|
|
|
|||
|
|
@ -1,139 +1,106 @@
|
|||
"""Render the analyst-onboarding welcome prompt (CLAUDE.md).
|
||||
"""Render the agent-setup-prompt banner shown on /setup.
|
||||
|
||||
Two layers:
|
||||
1. Template source — admin override from welcome_template.content,
|
||||
or the shipped default at config/claude_md_template.txt.
|
||||
2. Render context — built from instance config, table_registry,
|
||||
metric_definitions, and the calling user's RBAC-filtered marketplaces.
|
||||
The banner is a small HTML snippet admin-editable at /admin/agent-prompt.
|
||||
It appears above the bash bootstrap commands on the /setup page and is
|
||||
intended for org-specific operational notes (VPN warning, support channel,
|
||||
data classification reminder, platform requirements).
|
||||
|
||||
The Jinja2 environment uses StrictUndefined so that any typo in the
|
||||
template raises immediately rather than rendering empty strings.
|
||||
Default: no banner (empty string). Admins override via the welcome_template
|
||||
DB table (singleton, content TEXT).
|
||||
|
||||
Security: output is HTML-sanitized after render (script/iframe/event-handler
|
||||
strip). The Jinja2 environment uses StrictUndefined with autoescape=True so
|
||||
template typos raise immediately rather than silently emitting empty HTML.
|
||||
"""
|
||||
# See also: surfaced as the "Agent Setup Prompt" admin editor at /admin/agent-prompt.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import date, datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import duckdb
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
from jinja2 import Environment, StrictUndefined, TemplateError
|
||||
|
||||
from app.instance_config import (
|
||||
get_data_source_type,
|
||||
get_instance_name,
|
||||
get_instance_subtitle,
|
||||
get_sync_interval,
|
||||
)
|
||||
from src.marketplace_filter import resolve_allowed_plugins
|
||||
from src.repositories.welcome_template import WelcomeTemplateRepository
|
||||
|
||||
_DEFAULT_TEMPLATE_PATH = (
|
||||
Path(__file__).resolve().parent.parent / "config" / "claude_md_template.txt"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTML sanitization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_RE_SCRIPT = re.compile(
|
||||
r"<script[\s\S]*?</script>", re.IGNORECASE
|
||||
)
|
||||
_RE_IFRAME = re.compile(
|
||||
r"<iframe[\s\S]*?(?:</iframe>|/>)", re.IGNORECASE
|
||||
)
|
||||
_RE_ON_EVENT = re.compile(
|
||||
r"\s+on\w+\s*=\s*(?:\"[^\"]*\"|'[^']*'|[^\s>]+)", re.IGNORECASE
|
||||
)
|
||||
_RE_JS_URI = re.compile(
|
||||
r"""(?:href|src|action)\s*=\s*(?:"|')(javascript:|data:)""", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
def _load_default_template() -> str:
|
||||
if _DEFAULT_TEMPLATE_PATH.exists():
|
||||
return _DEFAULT_TEMPLATE_PATH.read_text(encoding="utf-8")
|
||||
# Last-resort embedded fallback if the OSS template file is missing
|
||||
# from the install (e.g., partial Docker COPY).
|
||||
return (
|
||||
"# {{ instance.name }} — AI Data Analyst\n\n"
|
||||
"This workspace is connected to {{ server.url }}.\n"
|
||||
"Data refreshes every {{ sync_interval }}.\n"
|
||||
)
|
||||
def _sanitize_banner_html(html: str) -> str:
|
||||
"""Strip dangerous constructs from admin-authored HTML.
|
||||
|
||||
Defense-in-depth only — admins are trusted, but this prevents accidental
|
||||
XSS from copy-pasted snippets reaching the public /setup page.
|
||||
|
||||
def _list_tables(conn: duckdb.DuckDBPyConnection) -> list[dict[str, Any]]:
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""SELECT name, description, query_mode
|
||||
FROM table_registry
|
||||
ORDER BY name"""
|
||||
).fetchall()
|
||||
except duckdb.CatalogException:
|
||||
return []
|
||||
return [
|
||||
{"name": r[0], "description": r[1] or "", "query_mode": r[2] or "local"}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
def _metrics_summary(conn: duckdb.DuckDBPyConnection) -> dict[str, Any]:
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"SELECT category, COUNT(*) FROM metric_definitions GROUP BY category"
|
||||
).fetchall()
|
||||
except duckdb.CatalogException:
|
||||
return {"count": 0, "categories": []}
|
||||
return {
|
||||
"count": sum(r[1] for r in rows),
|
||||
"categories": sorted({r[0] for r in rows if r[0]}),
|
||||
}
|
||||
|
||||
|
||||
def _marketplaces_for_user(
|
||||
conn: duckdb.DuckDBPyConnection, user: dict[str, Any]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return marketplaces with the plugins the user is allowed to see.
|
||||
|
||||
Delegates RBAC filtering entirely to resolve_allowed_plugins, which
|
||||
returns List[dict] with marketplace_slug, original_name, etc.
|
||||
Results are grouped by marketplace slug; display names are fetched
|
||||
from marketplace_registry in a single query.
|
||||
Strips:
|
||||
- <script>…</script> blocks (any content)
|
||||
- <iframe>…</iframe> tags
|
||||
- on*= event handler attributes (onclick=, onload=, etc.)
|
||||
- javascript: / data: URI schemes in href/src/action attributes
|
||||
"""
|
||||
try:
|
||||
allowed = resolve_allowed_plugins(conn, user)
|
||||
except duckdb.CatalogException:
|
||||
return []
|
||||
if not allowed:
|
||||
return []
|
||||
html = _RE_SCRIPT.sub("", html)
|
||||
html = _RE_IFRAME.sub("", html)
|
||||
html = _RE_ON_EVENT.sub("", html)
|
||||
html = _RE_JS_URI.sub(
|
||||
lambda m: m.group(0).replace(m.group(1), "#"), html
|
||||
)
|
||||
return html
|
||||
|
||||
# Build slug → display name lookup from registry
|
||||
slugs = list({p["marketplace_slug"] for p in allowed})
|
||||
placeholders = ",".join(["?"] * len(slugs))
|
||||
try:
|
||||
name_rows = conn.execute(
|
||||
f"SELECT id, name FROM marketplace_registry WHERE id IN ({placeholders})",
|
||||
slugs,
|
||||
).fetchall()
|
||||
slug_to_name: dict[str, str] = {r[0]: r[1] for r in name_rows}
|
||||
except duckdb.CatalogException:
|
||||
slug_to_name = {}
|
||||
|
||||
grouped: dict[str, dict[str, Any]] = {}
|
||||
for plugin in allowed:
|
||||
slug = plugin["marketplace_slug"]
|
||||
bucket = grouped.setdefault(
|
||||
slug,
|
||||
{
|
||||
"slug": slug,
|
||||
"name": slug_to_name.get(slug, slug),
|
||||
"plugins": [],
|
||||
},
|
||||
)
|
||||
bucket["plugins"].append({"name": plugin["original_name"]})
|
||||
|
||||
return list(grouped.values())
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Render context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_context(
|
||||
conn: duckdb.DuckDBPyConnection,
|
||||
*,
|
||||
user: dict[str, Any],
|
||||
user: dict[str, Any] | None,
|
||||
server_url: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Compose the Jinja2 render context. Pure, no side effects.
|
||||
"""Compose the Jinja2 render context for the banner.
|
||||
|
||||
Note: ``now`` is tz-aware UTC; DB-sourced timestamps elsewhere in the
|
||||
codebase are naive (DuckDB stores ``TIMESTAMP``, not ``TIMESTAMPTZ``).
|
||||
Don't subtract or compare them inside templates without normalising.
|
||||
Intentionally small: instance identity, server URL, and the requesting
|
||||
user (may be None for anonymous /setup visitors). No tables, metrics, or
|
||||
marketplaces — the banner is for org-operational notes, not data-catalog
|
||||
content.
|
||||
|
||||
Note: ``now`` is tz-aware UTC.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
parsed = urlparse(server_url)
|
||||
user_ctx: dict[str, Any] | None = None
|
||||
if user:
|
||||
user_ctx = {
|
||||
"id": user.get("id", ""),
|
||||
"email": user.get("email", ""),
|
||||
"name": user.get("name") or "",
|
||||
"is_admin": bool(user.get("is_admin")),
|
||||
"groups": user.get("groups") or [],
|
||||
}
|
||||
return {
|
||||
"instance": {
|
||||
"name": get_instance_name(),
|
||||
|
|
@ -143,36 +110,44 @@ def build_context(
|
|||
"url": server_url,
|
||||
"hostname": parsed.hostname or "",
|
||||
},
|
||||
"sync_interval": get_sync_interval(),
|
||||
"data_source": {"type": get_data_source_type()},
|
||||
"tables": _list_tables(conn),
|
||||
"metrics": _metrics_summary(conn),
|
||||
"marketplaces": _marketplaces_for_user(conn, user),
|
||||
"user": {
|
||||
"id": user.get("id", ""),
|
||||
"email": user.get("email", ""),
|
||||
"name": user.get("name") or "",
|
||||
"is_admin": bool(user.get("is_admin")),
|
||||
"groups": user.get("groups") or [],
|
||||
},
|
||||
"user": user_ctx,
|
||||
"now": now,
|
||||
"today": date.today().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_template_source(conn: duckdb.DuckDBPyConnection) -> str:
|
||||
row = WelcomeTemplateRepository(conn).get()
|
||||
return row["content"] if row.get("content") else _load_default_template()
|
||||
# ---------------------------------------------------------------------------
|
||||
# Banner renderer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def render_welcome(
|
||||
def render_agent_prompt_banner(
|
||||
conn: duckdb.DuckDBPyConnection,
|
||||
*,
|
||||
user: dict[str, Any],
|
||||
user: dict[str, Any] | None,
|
||||
server_url: str,
|
||||
) -> str:
|
||||
"""Resolve the active template and render it for the given user."""
|
||||
source = _resolve_template_source(conn)
|
||||
env = Environment(undefined=StrictUndefined, autoescape=False)
|
||||
template = env.from_string(source)
|
||||
return template.render(**build_context(conn, user=user, server_url=server_url))
|
||||
"""Render the admin-configured HTML banner for the /setup page.
|
||||
|
||||
Returns an empty string when no override is set (default = no banner).
|
||||
Render failures are swallowed (logged) and return empty string so a
|
||||
broken template never blocks the /setup page from rendering.
|
||||
"""
|
||||
row = WelcomeTemplateRepository(conn).get()
|
||||
content = row.get("content")
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
try:
|
||||
env = Environment(undefined=StrictUndefined, autoescape=True)
|
||||
template = env.from_string(content)
|
||||
ctx = build_context(user=user, server_url=server_url)
|
||||
rendered = template.render(**ctx)
|
||||
return _sanitize_banner_html(rendered)
|
||||
except TemplateError as exc:
|
||||
logger.warning(
|
||||
"Agent-prompt banner render failed (template error): %s", exc
|
||||
)
|
||||
return ""
|
||||
except Exception:
|
||||
logger.exception("Agent-prompt banner render failed (unexpected)")
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class TestDetectExistingProject:
|
|||
patch("cli.commands.analyst._download_metadata"), \
|
||||
patch("cli.commands.analyst._download_data", return_value=0), \
|
||||
patch("cli.commands.analyst._initialize_duckdb", return_value=0), \
|
||||
patch("cli.commands.analyst._generate_claude_md"):
|
||||
patch("cli.commands.analyst._init_claude_workspace"):
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["analyst", "setup", "--server-url", "http://localhost:8000", "--force"],
|
||||
|
|
@ -116,59 +116,52 @@ class TestCreateWorkspace:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestGenerateClaudeMd
|
||||
# TestInitClaudeWorkspace
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGenerateClaudeMd:
|
||||
"""Server-side render flow: _generate_claude_md fetches /api/welcome.
|
||||
|
||||
The local-fallback path is exercised by tests/test_cli_analyst_welcome.py.
|
||||
These tests cover the side-effects on the workspace (CLAUDE.local.md,
|
||||
settings.json) and verify the new signature is honored.
|
||||
class TestInitClaudeWorkspace:
|
||||
"""Tests for _init_claude_workspace: no CLAUDE.md written, but
|
||||
.claude/CLAUDE.local.md placeholder and settings.json hooks are created.
|
||||
"""
|
||||
|
||||
def _patch_httpx_404(self, monkeypatch):
|
||||
"""Stub httpx.get to return 404 so _generate_claude_md falls back to embedded text."""
|
||||
import httpx
|
||||
|
||||
def fake_get(url, headers=None, timeout=None):
|
||||
return httpx.Response(
|
||||
status_code=404, json={}, request=httpx.Request("GET", url)
|
||||
)
|
||||
|
||||
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": fake_get}))
|
||||
|
||||
def test_creates_claude_local_md_when_absent(self, tmp_workspace, monkeypatch):
|
||||
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
||||
def test_does_not_write_claude_md(self, tmp_workspace):
|
||||
from cli.commands.analyst import _create_workspace, _init_claude_workspace
|
||||
|
||||
_create_workspace(tmp_workspace)
|
||||
self._patch_httpx_404(monkeypatch)
|
||||
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
|
||||
_init_claude_workspace(tmp_workspace)
|
||||
|
||||
assert not (tmp_workspace / "CLAUDE.md").exists(), (
|
||||
"CLAUDE.md must NOT be written by _init_claude_workspace"
|
||||
)
|
||||
|
||||
def test_creates_claude_local_md_when_absent(self, tmp_workspace):
|
||||
from cli.commands.analyst import _create_workspace, _init_claude_workspace
|
||||
|
||||
_create_workspace(tmp_workspace)
|
||||
_init_claude_workspace(tmp_workspace)
|
||||
|
||||
local_md = tmp_workspace / ".claude" / "CLAUDE.local.md"
|
||||
assert local_md.exists()
|
||||
assert local_md.read_text(encoding="utf-8").strip() != ""
|
||||
|
||||
def test_does_not_overwrite_existing_local_md(self, tmp_workspace, monkeypatch):
|
||||
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
||||
def test_does_not_overwrite_existing_local_md(self, tmp_workspace):
|
||||
from cli.commands.analyst import _create_workspace, _init_claude_workspace
|
||||
|
||||
_create_workspace(tmp_workspace)
|
||||
local_md = tmp_workspace / ".claude" / "CLAUDE.local.md"
|
||||
original_content = "# My custom notes\n\nDo not overwrite me.\n"
|
||||
local_md.write_text(original_content, encoding="utf-8")
|
||||
|
||||
self._patch_httpx_404(monkeypatch)
|
||||
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
|
||||
_init_claude_workspace(tmp_workspace)
|
||||
|
||||
assert local_md.read_text(encoding="utf-8") == original_content
|
||||
|
||||
def test_writes_settings_json(self, tmp_workspace, monkeypatch):
|
||||
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
||||
def test_writes_settings_json(self, tmp_workspace):
|
||||
from cli.commands.analyst import _create_workspace, _init_claude_workspace
|
||||
import json as _json
|
||||
|
||||
_create_workspace(tmp_workspace)
|
||||
self._patch_httpx_404(monkeypatch)
|
||||
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
|
||||
_init_claude_workspace(tmp_workspace)
|
||||
|
||||
settings = _json.loads(
|
||||
(tmp_workspace / ".claude" / "settings.json").read_text(encoding="utf-8")
|
||||
|
|
@ -176,6 +169,21 @@ class TestGenerateClaudeMd:
|
|||
assert settings["model"] == "sonnet"
|
||||
assert "Read" in settings["permissions"]["allow"]
|
||||
|
||||
def test_installs_session_hooks(self, tmp_workspace):
|
||||
"""SessionStart and SessionEnd hooks must be present in settings.json."""
|
||||
from cli.commands.analyst import _create_workspace, _init_claude_workspace
|
||||
import json as _json
|
||||
|
||||
_create_workspace(tmp_workspace)
|
||||
_init_claude_workspace(tmp_workspace)
|
||||
|
||||
settings = _json.loads(
|
||||
(tmp_workspace / ".claude" / "settings.json").read_text(encoding="utf-8")
|
||||
)
|
||||
hooks = settings.get("hooks", {})
|
||||
assert "SessionStart" in hooks
|
||||
assert "SessionEnd" in hooks
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestReturningSession
|
||||
|
|
|
|||
|
|
@ -1,89 +0,0 @@
|
|||
"""Integration tests for da analyst setup → /api/welcome wiring."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from cli.commands.analyst import _generate_claude_md
|
||||
|
||||
|
||||
class _MockClient:
|
||||
def __init__(self, responses):
|
||||
self._responses = responses
|
||||
self.calls = []
|
||||
|
||||
def get(self, url, headers=None, timeout=None):
|
||||
self.calls.append(url)
|
||||
body, status = self._responses.get(url, ({}, 404))
|
||||
return httpx.Response(status_code=status, json=body, request=httpx.Request("GET", url))
|
||||
|
||||
|
||||
def _ws(tmp_path: Path) -> Path:
|
||||
workspace = tmp_path / "ws"
|
||||
(workspace / ".claude").mkdir(parents=True)
|
||||
return workspace
|
||||
|
||||
|
||||
def test_generate_claude_md_uses_server_render(tmp_path, monkeypatch):
|
||||
workspace = _ws(tmp_path)
|
||||
rendered = "# CUSTOM\n\nFrom server.\n"
|
||||
mock = _MockClient({
|
||||
"https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com": (
|
||||
{"content": rendered}, 200
|
||||
),
|
||||
})
|
||||
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
|
||||
_generate_claude_md(workspace, server_url="https://example.com", token="t")
|
||||
|
||||
assert (workspace / "CLAUDE.md").read_text(encoding="utf-8") == rendered
|
||||
# Workspace side-effects are created on the success path too.
|
||||
assert (workspace / ".claude" / "CLAUDE.local.md").exists()
|
||||
settings = json.loads((workspace / ".claude" / "settings.json").read_text(encoding="utf-8"))
|
||||
assert settings["model"] == "sonnet"
|
||||
|
||||
|
||||
def test_generate_claude_md_falls_back_on_404(tmp_path, monkeypatch):
|
||||
workspace = _ws(tmp_path)
|
||||
mock = _MockClient({}) # everything 404s
|
||||
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
|
||||
_generate_claude_md(workspace, server_url="https://example.com", token="t")
|
||||
body = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
|
||||
assert "AI Data Analyst" in body
|
||||
assert "https://example.com" in body
|
||||
|
||||
|
||||
def test_generate_claude_md_falls_back_on_null_content(tmp_path, monkeypatch):
|
||||
"""Server returns 200 but malformed body (`content: null`). CLI must use fallback."""
|
||||
workspace = _ws(tmp_path)
|
||||
mock = _MockClient({
|
||||
"https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com": (
|
||||
{"content": None}, 200
|
||||
),
|
||||
})
|
||||
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
|
||||
_generate_claude_md(workspace, server_url="https://example.com", token="t")
|
||||
body = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
|
||||
# Embedded fallback contains these literals
|
||||
assert "AI Data Analyst" in body
|
||||
assert "https://example.com" in body
|
||||
|
||||
|
||||
def test_generate_claude_md_warns_on_5xx(tmp_path, monkeypatch, capsys):
|
||||
"""500 from server → embedded fallback, with a stderr warning so operators can diagnose."""
|
||||
workspace = _ws(tmp_path)
|
||||
mock = _MockClient({
|
||||
"https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com": (
|
||||
{"detail": "boom"}, 500
|
||||
),
|
||||
})
|
||||
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
|
||||
_generate_claude_md(workspace, server_url="https://example.com", token="t")
|
||||
|
||||
body = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
|
||||
assert "AI Data Analyst" in body # fallback used
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "500" in captured.err
|
||||
assert "fallback" in captured.err.lower()
|
||||
|
|
@ -1,4 +1,8 @@
|
|||
"""End-to-end tests for /api/welcome and /api/admin/welcome-template."""
|
||||
"""End-to-end tests for /api/admin/welcome-template (banner editor endpoints).
|
||||
|
||||
GET /api/welcome has been removed — the analyst-facing endpoint is gone.
|
||||
These tests cover only the admin CRUD + preview endpoints.
|
||||
"""
|
||||
|
||||
import duckdb
|
||||
|
||||
|
|
@ -10,7 +14,8 @@ def _auth(token: str) -> dict[str, str]:
|
|||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def test_get_welcome_returns_rendered_markdown(seeded_app):
|
||||
def test_get_welcome_endpoint_removed(seeded_app):
|
||||
"""GET /api/welcome must return 404 — the endpoint was deleted."""
|
||||
c = seeded_app["client"]
|
||||
token = seeded_app["analyst_token"]
|
||||
resp = c.get(
|
||||
|
|
@ -18,48 +23,39 @@ def test_get_welcome_returns_rendered_markdown(seeded_app):
|
|||
params={"server_url": "https://example.com"},
|
||||
headers=_auth(token),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "content" in body
|
||||
assert "AI Data Analyst" in body["content"]
|
||||
assert "https://example.com" in body["content"]
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_get_welcome_requires_auth(seeded_app):
|
||||
def test_admin_get_template_initially_null(seeded_app):
|
||||
c = seeded_app["client"]
|
||||
resp = c.get("/api/welcome", params={"server_url": "https://example.com"})
|
||||
assert resp.status_code == 401
|
||||
admin = _auth(seeded_app["admin_token"])
|
||||
|
||||
r = c.get("/api/admin/welcome-template", headers=admin)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["content"] is None
|
||||
# No longer returns a `default` field — banner default is empty
|
||||
assert "default" not in body or body.get("default") is None
|
||||
|
||||
|
||||
def test_admin_can_set_and_reset_template(seeded_app):
|
||||
c = seeded_app["client"]
|
||||
admin = _auth(seeded_app["admin_token"])
|
||||
|
||||
# GET initial state
|
||||
r = c.get("/api/admin/welcome-template", headers=admin)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["content"] is None
|
||||
# The shipped default starts with the Jinja2 comment block.
|
||||
assert body["default"].startswith("{#")
|
||||
|
||||
# PUT override
|
||||
r = c.put(
|
||||
"/api/admin/welcome-template",
|
||||
json={"content": "Hello {{ user.email }}"},
|
||||
json={"content": "<p>Hello {{ user.email }}</p>"},
|
||||
headers=admin,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
# Verify rendered output uses override
|
||||
r = c.get(
|
||||
"/api/welcome",
|
||||
params={"server_url": "https://example.com"},
|
||||
headers=admin, # admin user can also call /api/welcome
|
||||
)
|
||||
assert r.json()["content"].startswith("Hello ")
|
||||
# GET reflects override
|
||||
r = c.get("/api/admin/welcome-template", headers=admin)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["content"] == "<p>Hello {{ user.email }}</p>"
|
||||
|
||||
# DELETE = reset
|
||||
# DELETE = reset (no banner)
|
||||
r = c.delete("/api/admin/welcome-template", headers=admin)
|
||||
assert r.status_code == 204
|
||||
r = c.get("/api/admin/welcome-template", headers=admin)
|
||||
|
|
@ -69,7 +65,7 @@ def test_admin_can_set_and_reset_template(seeded_app):
|
|||
def test_non_admin_cannot_edit_template(seeded_app):
|
||||
c = seeded_app["client"]
|
||||
analyst = _auth(seeded_app["analyst_token"])
|
||||
r = c.put("/api/admin/welcome-template", json={"content": "x"}, headers=analyst)
|
||||
r = c.put("/api/admin/welcome-template", json={"content": "<p>x</p>"}, headers=analyst)
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
|
|
@ -86,58 +82,29 @@ def test_invalid_jinja2_returns_400(seeded_app):
|
|||
|
||||
|
||||
def test_put_rejects_undefined_placeholder(seeded_app):
|
||||
"""Templates that parse but reference unknown placeholders must be rejected
|
||||
at PUT time so an admin can fix the typo immediately rather than after an
|
||||
analyst's bootstrap blows up."""
|
||||
"""Templates that reference unknown placeholders must be rejected at PUT time."""
|
||||
c = seeded_app["client"]
|
||||
admin = _auth(seeded_app["admin_token"])
|
||||
r = c.put(
|
||||
"/api/admin/welcome-template",
|
||||
json={"content": "Hello {{ user.emial }}"}, # typo, would fail StrictUndefined at render
|
||||
json={"content": "<p>{{ user.emial }}</p>"}, # typo
|
||||
headers=admin,
|
||||
)
|
||||
assert r.status_code == 400
|
||||
assert "emial" in r.json()["detail"] or "undefined" in r.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_get_welcome_500_includes_reset_hint_on_render_failure(seeded_app, monkeypatch):
|
||||
"""If an override slips through validation and fails at render time, the
|
||||
user-visible 500 must point at /admin/agent-prompt rather than leaking a
|
||||
Jinja stack trace."""
|
||||
# Stub render_welcome to raise a TemplateError so we exercise the
|
||||
# exception path without needing a malformed override (PUT validation
|
||||
# blocks those now).
|
||||
from jinja2 import UndefinedError
|
||||
import app.api.welcome as welcome_module
|
||||
|
||||
def fake_render(*args, **kwargs):
|
||||
raise UndefinedError("'foo' is undefined")
|
||||
|
||||
monkeypatch.setattr(welcome_module, "render_welcome", fake_render)
|
||||
|
||||
c = seeded_app["client"]
|
||||
admin = _auth(seeded_app["admin_token"])
|
||||
r = c.get(
|
||||
"/api/welcome",
|
||||
params={"server_url": "https://example.com"},
|
||||
headers=admin,
|
||||
)
|
||||
assert r.status_code == 500
|
||||
assert "/admin/agent-prompt" in r.json()["detail"]
|
||||
|
||||
|
||||
def test_admin_preview_renders_arbitrary_content(seeded_app):
|
||||
"""Preview endpoint must render the supplied content (not whatever's
|
||||
stored), so the admin UI can show pre-save preview."""
|
||||
def test_admin_preview_renders_html(seeded_app):
|
||||
"""Preview endpoint renders supplied HTML content without persisting."""
|
||||
c = seeded_app["client"]
|
||||
admin = _auth(seeded_app["admin_token"])
|
||||
r = c.post(
|
||||
"/api/admin/welcome-template/preview",
|
||||
json={"content": "# Preview {{ user.email }}"},
|
||||
json={"content": "<p>Preview for {{ user.email }}</p>"},
|
||||
headers=admin,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["content"].startswith("# Preview admin@test.com")
|
||||
assert r.json()["content"].startswith("<p>Preview for admin@test.com")
|
||||
|
||||
|
||||
def test_preview_rejects_invalid_template(seeded_app):
|
||||
|
|
@ -156,26 +123,29 @@ def test_preview_requires_admin(seeded_app):
|
|||
analyst = _auth(seeded_app["analyst_token"])
|
||||
r = c.post(
|
||||
"/api/admin/welcome-template/preview",
|
||||
json={"content": "# x"},
|
||||
json={"content": "<p>x</p>"},
|
||||
headers=analyst,
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
def test_validation_stub_matches_build_context_shape(seeded_app, tmp_path, monkeypatch):
|
||||
"""If build_context grows new keys, _VALIDATION_STUB_CONTEXT must too —
|
||||
otherwise admins can save templates referencing keys the PUT validator
|
||||
accepts but the live render rejects."""
|
||||
"""_VALIDATION_STUB_CONTEXT top-level keys must match build_context() output.
|
||||
|
||||
If build_context gains new keys, the stub must track them so admins can
|
||||
save templates that reference those keys without hitting a live-render
|
||||
rejection after the PUT validation accepted them.
|
||||
"""
|
||||
from app.api.welcome import _VALIDATION_STUB_CONTEXT
|
||||
|
||||
monkeypatch.setenv("DATA_DIR", str(tmp_path))
|
||||
db_path = tmp_path / "system.duckdb"
|
||||
conn = duckdb.connect(str(db_path))
|
||||
_ensure_schema(conn)
|
||||
|
||||
user = {"id": "u1", "email": "admin@test.com", "name": "Admin", "is_admin": True, "groups": ["Admin"]}
|
||||
real_ctx = build_context(conn, user=user, server_url="https://example.com")
|
||||
conn.close()
|
||||
user = {
|
||||
"id": "u1",
|
||||
"email": "admin@test.com",
|
||||
"name": "Admin",
|
||||
"is_admin": True,
|
||||
"groups": ["Admin"],
|
||||
}
|
||||
real_ctx = build_context(user=user, server_url="https://example.com")
|
||||
|
||||
# Top-level keys must match
|
||||
assert set(_VALIDATION_STUB_CONTEXT.keys()) == set(real_ctx.keys()), (
|
||||
|
|
@ -184,9 +154,13 @@ def test_validation_stub_matches_build_context_shape(seeded_app, tmp_path, monke
|
|||
f"real has: {set(real_ctx.keys())}"
|
||||
)
|
||||
|
||||
# One level deep for nested dicts
|
||||
for key in ("instance", "server", "user"):
|
||||
if isinstance(real_ctx.get(key), dict):
|
||||
assert set(_VALIDATION_STUB_CONTEXT[key].keys()) == set(real_ctx[key].keys()), (
|
||||
f"_VALIDATION_STUB_CONTEXT[{key!r}] drifted from build_context output"
|
||||
)
|
||||
# One level deep for nested dicts (user may be None in real_ctx — compare stub shape)
|
||||
for key in ("instance", "server"):
|
||||
assert set(_VALIDATION_STUB_CONTEXT[key].keys()) == set(real_ctx[key].keys()), (
|
||||
f"_VALIDATION_STUB_CONTEXT[{key!r}] drifted from build_context output"
|
||||
)
|
||||
# user sub-keys
|
||||
if real_ctx.get("user") and _VALIDATION_STUB_CONTEXT.get("user"):
|
||||
assert set(_VALIDATION_STUB_CONTEXT["user"].keys()) == set(real_ctx["user"].keys()), (
|
||||
"_VALIDATION_STUB_CONTEXT['user'] drifted from build_context output"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
"""Unit tests for the welcome-prompt renderer."""
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
"""Unit tests for the agent-setup-prompt banner renderer."""
|
||||
|
||||
import duckdb
|
||||
import pytest
|
||||
|
||||
from src.db import _ensure_schema
|
||||
from src.repositories.welcome_template import WelcomeTemplateRepository
|
||||
from src.welcome_template import build_context, render_welcome
|
||||
from src.welcome_template import (
|
||||
_sanitize_banner_html,
|
||||
build_context,
|
||||
render_agent_prompt_banner,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -22,144 +23,188 @@ def conn(tmp_path, monkeypatch):
|
|||
|
||||
|
||||
def _user(email="alice@example.com"):
|
||||
return {"id": "u1", "email": email, "name": "Alice", "is_admin": False, "groups": ["Everyone"]}
|
||||
return {
|
||||
"id": "u1",
|
||||
"email": email,
|
||||
"name": "Alice",
|
||||
"is_admin": False,
|
||||
"groups": ["Everyone"],
|
||||
}
|
||||
|
||||
|
||||
def test_renders_default_when_no_override(conn):
|
||||
out = render_welcome(conn, user=_user(), server_url="https://example.com")
|
||||
assert "AI Data Analyst" in out
|
||||
assert "https://example.com" in out
|
||||
assert "Alice" in out
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default (no override) → empty string
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_returns_empty_when_no_override(conn):
|
||||
out = render_agent_prompt_banner(conn, user=_user(), server_url="https://example.com")
|
||||
assert out == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Override renders correctly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_renders_override(conn):
|
||||
WelcomeTemplateRepository(conn).set(
|
||||
"# {{ instance.name }} for {{ user.email }}",
|
||||
"<p>Welcome to {{ instance.name }}!</p>",
|
||||
updated_by="admin@example.com",
|
||||
)
|
||||
out = render_welcome(conn, user=_user(), server_url="https://example.com")
|
||||
assert out.startswith("# AI Data Analyst for alice@example.com")
|
||||
out = render_agent_prompt_banner(conn, user=_user(), server_url="https://example.com")
|
||||
assert "<p>Welcome to" in out
|
||||
# instance.name comes from instance_config — any non-empty string is fine
|
||||
assert "!" in out
|
||||
|
||||
|
||||
def test_strict_undefined_raises_on_missing_placeholder(conn):
|
||||
def test_renders_user_placeholder(conn):
|
||||
WelcomeTemplateRepository(conn).set(
|
||||
"<p>Hello {{ user.email }}</p>",
|
||||
updated_by="admin@example.com",
|
||||
)
|
||||
out = render_agent_prompt_banner(
|
||||
conn, user=_user("bob@example.com"), server_url="https://example.com"
|
||||
)
|
||||
assert "bob@example.com" in out
|
||||
|
||||
|
||||
def test_renders_server_placeholder(conn):
|
||||
WelcomeTemplateRepository(conn).set(
|
||||
"<p>Server: {{ server.url }}</p>",
|
||||
updated_by="admin@example.com",
|
||||
)
|
||||
out = render_agent_prompt_banner(
|
||||
conn, user=_user(), server_url="https://myserver.example.com"
|
||||
)
|
||||
assert "https://myserver.example.com" in out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anonymous user (user=None)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_renders_with_anonymous_user(conn):
|
||||
WelcomeTemplateRepository(conn).set(
|
||||
"{% if user %}<p>Hi {{ user.email }}</p>{% else %}<p>Please sign in.</p>{% endif %}",
|
||||
updated_by="admin@example.com",
|
||||
)
|
||||
out = render_agent_prompt_banner(conn, user=None, server_url="https://example.com")
|
||||
assert "Please sign in." in out
|
||||
assert "Hi" not in out
|
||||
|
||||
|
||||
def test_returns_empty_for_none_user_with_no_override(conn):
|
||||
out = render_agent_prompt_banner(conn, user=None, server_url="https://example.com")
|
||||
assert out == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build context shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_context_exposes_documented_keys():
|
||||
ctx = build_context(user=_user(), server_url="https://example.com")
|
||||
for key in ("instance", "server", "user", "now", "today"):
|
||||
assert key in ctx, f"missing context key: {key}"
|
||||
assert "tables" not in ctx
|
||||
assert "metrics" not in ctx
|
||||
assert "marketplaces" not in ctx
|
||||
assert "sync_interval" not in ctx
|
||||
assert "data_source" not in ctx
|
||||
|
||||
|
||||
def test_context_user_none():
|
||||
ctx = build_context(user=None, server_url="https://example.com")
|
||||
assert ctx["user"] is None
|
||||
|
||||
|
||||
def test_context_instance_keys():
|
||||
ctx = build_context(user=_user(), server_url="https://example.com")
|
||||
assert "name" in ctx["instance"]
|
||||
assert "subtitle" in ctx["instance"]
|
||||
|
||||
|
||||
def test_context_server_keys():
|
||||
ctx = build_context(user=_user(), server_url="https://example.com")
|
||||
assert ctx["server"]["url"] == "https://example.com"
|
||||
assert ctx["server"]["hostname"] == "example.com"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTML sanitization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_sanitize_strips_script_tag():
|
||||
html = '<p>Hello</p><script>alert("xss")</script>'
|
||||
result = _sanitize_banner_html(html)
|
||||
assert "<script>" not in result
|
||||
assert "alert" not in result
|
||||
assert "<p>Hello</p>" in result
|
||||
|
||||
|
||||
def test_sanitize_strips_script_with_attributes():
|
||||
html = '<script type="text/javascript">evil()</script><p>ok</p>'
|
||||
result = _sanitize_banner_html(html)
|
||||
assert "evil" not in result
|
||||
assert "<p>ok</p>" in result
|
||||
|
||||
|
||||
def test_sanitize_strips_iframe():
|
||||
html = '<p>text</p><iframe src="https://evil.example.com"></iframe>'
|
||||
result = _sanitize_banner_html(html)
|
||||
assert "<iframe" not in result
|
||||
assert "<p>text</p>" in result
|
||||
|
||||
|
||||
def test_sanitize_strips_event_handlers():
|
||||
html = '<button onclick="evil()">Click me</button>'
|
||||
result = _sanitize_banner_html(html)
|
||||
assert "onclick" not in result
|
||||
assert "evil" not in result
|
||||
assert "Click me" in result
|
||||
|
||||
|
||||
def test_sanitize_strips_onload_on_img():
|
||||
html = '<img src="x" onload="steal()" alt="test">'
|
||||
result = _sanitize_banner_html(html)
|
||||
assert "onload" not in result
|
||||
assert "steal" not in result
|
||||
|
||||
|
||||
def test_sanitize_strips_javascript_uri():
|
||||
html = '<a href="javascript:alert(1)">click</a>'
|
||||
result = _sanitize_banner_html(html)
|
||||
assert "javascript:" not in result
|
||||
|
||||
|
||||
def test_sanitize_allows_safe_html():
|
||||
html = "<p>VPN required. Contact <a href='https://support.example.com'>support</a>.</p>"
|
||||
result = _sanitize_banner_html(html)
|
||||
assert "<p>" in result
|
||||
assert "<a href" in result
|
||||
assert "support" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Render failure → empty string (not exception)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_render_failure_returns_empty_not_exception(conn):
|
||||
# StrictUndefined: referencing an unknown variable raises at render time.
|
||||
WelcomeTemplateRepository(conn).set(
|
||||
"{{ does_not_exist }}", updated_by="admin@example.com"
|
||||
)
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
render_welcome(conn, user=_user(), server_url="https://example.com")
|
||||
assert "does_not_exist" in str(exc_info.value)
|
||||
out = render_agent_prompt_banner(conn, user=_user(), server_url="https://example.com")
|
||||
# Must return empty string, not raise
|
||||
assert out == ""
|
||||
|
||||
|
||||
def test_context_exposes_documented_keys(conn):
|
||||
ctx = build_context(conn, user=_user(), server_url="https://example.com")
|
||||
for top in ("instance", "server", "sync_interval", "data_source",
|
||||
"tables", "metrics", "marketplaces", "user", "now", "today"):
|
||||
assert top in ctx, f"missing top-level key: {top}"
|
||||
|
||||
|
||||
def test_render_marketplaces_filtered_by_rbac(conn, monkeypatch):
|
||||
"""Two users with different group memberships render different marketplace lists."""
|
||||
from app.resource_types import ResourceType
|
||||
|
||||
# ── Seed two marketplaces ────────────────────────────────────────────
|
||||
conn.execute(
|
||||
"""INSERT INTO marketplace_registry (id, name, url) VALUES
|
||||
('mkt-a', 'Marketplace A', 'https://github.com/example/mkt-a'),
|
||||
('mkt-b', 'Marketplace B', 'https://github.com/example/mkt-b')"""
|
||||
)
|
||||
# Two plugins per marketplace
|
||||
for mkt, plugins in [("mkt-a", ["plugin-1", "plugin-2"]), ("mkt-b", ["plugin-3", "plugin-4"])]:
|
||||
for p in plugins:
|
||||
conn.execute(
|
||||
"INSERT INTO marketplace_plugins (marketplace_id, name) VALUES (?, ?)",
|
||||
[mkt, p],
|
||||
)
|
||||
|
||||
# ── Seed two non-system groups ──────────────────────────────────────
|
||||
gid_a = str(uuid.uuid4())
|
||||
gid_b = str(uuid.uuid4())
|
||||
conn.execute(
|
||||
"INSERT INTO user_groups (id, name) VALUES (?, ?), (?, ?)",
|
||||
[gid_a, "group-a", gid_b, "group-b"],
|
||||
)
|
||||
|
||||
# ── Grant mkt-a/* to group-a and mkt-b/* to group-b ─────────────────
|
||||
rtype = ResourceType.MARKETPLACE_PLUGIN.value
|
||||
for mkt, gid, plugins in [
|
||||
("mkt-a", gid_a, ["plugin-1", "plugin-2"]),
|
||||
("mkt-b", gid_b, ["plugin-3", "plugin-4"]),
|
||||
]:
|
||||
for p in plugins:
|
||||
conn.execute(
|
||||
"INSERT INTO resource_grants (id, group_id, resource_type, resource_id) "
|
||||
"VALUES (?, ?, ?, ?)",
|
||||
[str(uuid.uuid4()), gid, rtype, f"{mkt}/{p}"],
|
||||
)
|
||||
|
||||
# ── Seed two users, each in their own group + Everyone ───────────────
|
||||
everyone_gid = conn.execute(
|
||||
"SELECT id FROM user_groups WHERE name = 'Everyone'"
|
||||
).fetchone()[0]
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO users (id, email, name, active) VALUES "
|
||||
"('user-a', 'user-a@example.com', 'User A', TRUE), "
|
||||
"('user-b', 'user-b@example.com', 'User B', TRUE)"
|
||||
)
|
||||
for uid, gid in [("user-a", gid_a), ("user-b", gid_b)]:
|
||||
conn.execute(
|
||||
"INSERT INTO user_group_members (user_id, group_id, source) VALUES (?, ?, ?)",
|
||||
[uid, gid, "admin"],
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO user_group_members (user_id, group_id, source) VALUES (?, ?, ?)",
|
||||
[uid, everyone_gid, "system_seed"],
|
||||
)
|
||||
|
||||
# ── Render for each user ─────────────────────────────────────────────
|
||||
def test_sanitize_applied_after_render(conn):
|
||||
"""A template that produces <script> output is sanitized before return."""
|
||||
WelcomeTemplateRepository(conn).set(
|
||||
"{% for m in marketplaces %}{{ m.slug }}: "
|
||||
"{% for p in m.plugins %}{{ p.name }} {% endfor %}{% endfor %}",
|
||||
"<script>evil()</script><p>safe content</p>",
|
||||
updated_by="admin@example.com",
|
||||
)
|
||||
|
||||
user_a = {"id": "user-a", "email": "user-a@example.com", "name": "User A", "is_admin": False, "groups": ["group-a"]}
|
||||
user_b = {"id": "user-b", "email": "user-b@example.com", "name": "User B", "is_admin": False, "groups": ["group-b"]}
|
||||
|
||||
out_a = render_welcome(conn, user=user_a, server_url="https://example.com")
|
||||
out_b = render_welcome(conn, user=user_b, server_url="https://example.com")
|
||||
|
||||
# user-a sees mkt-a plugins only
|
||||
assert "mkt-a" in out_a
|
||||
assert "plugin-1" in out_a
|
||||
assert "mkt-b" not in out_a
|
||||
assert "plugin-3" not in out_a
|
||||
|
||||
# user-b sees mkt-b plugins only
|
||||
assert "mkt-b" in out_b
|
||||
assert "plugin-3" in out_b
|
||||
assert "mkt-a" not in out_b
|
||||
assert "plugin-1" not in out_b
|
||||
|
||||
|
||||
def test_render_tolerates_missing_optional_tables(tmp_path, monkeypatch):
|
||||
"""A bare DuckDB without table_registry / marketplace_registry must still render."""
|
||||
monkeypatch.setenv("DATA_DIR", str(tmp_path))
|
||||
db_path = tmp_path / "bare.duckdb"
|
||||
bare = duckdb.connect(str(db_path))
|
||||
# Only seed the welcome_template singleton manually; no other tables.
|
||||
bare.execute(
|
||||
"""CREATE TABLE welcome_template (
|
||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||
content TEXT,
|
||||
updated_at TIMESTAMP,
|
||||
updated_by VARCHAR
|
||||
)"""
|
||||
)
|
||||
bare.execute("INSERT INTO welcome_template (id, content) VALUES (1, NULL)")
|
||||
|
||||
out = render_welcome(bare, user=_user(), server_url="https://example.com")
|
||||
bare.close()
|
||||
assert "AI Data Analyst" in out # default template still renders
|
||||
# No tables → "_No tables registered yet_" branch from the default template
|
||||
assert "No tables registered yet" in out
|
||||
out = render_agent_prompt_banner(conn, user=_user(), server_url="https://example.com")
|
||||
assert "<script>" not in out
|
||||
assert "evil" not in out
|
||||
assert "<p>safe content</p>" in out
|
||||
|
|
|
|||
Loading…
Reference in a new issue