feat: server-side jinja2 renderer for welcome prompt

This commit is contained in:
ZdenekSrotyr 2026-04-30 18:50:43 +02:00
parent d055417377
commit 51f287a81a
2 changed files with 219 additions and 0 deletions

163
src/welcome_template.py Normal file
View file

@ -0,0 +1,163 @@
"""Render the analyst-onboarding welcome prompt (CLAUDE.md).
Two layers:
1. Template source admin override from welcome_template.content,
or the shipped default at config/claude_md_template.txt.
2. Render context built from instance config, table_registry,
metric_definitions, and the calling user's RBAC-filtered marketplaces.
The Jinja2 environment uses StrictUndefined so that any typo in the
template raises immediately rather than rendering empty strings.
"""
from __future__ import annotations
from datetime import date, datetime, timezone
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import duckdb
from jinja2 import Environment, StrictUndefined
from app.instance_config import (
get_data_source_type,
get_instance_name,
get_instance_subtitle,
get_sync_interval,
)
from src.marketplace_filter import resolve_allowed_plugins
from src.repositories.welcome_template import WelcomeTemplateRepository
_DEFAULT_TEMPLATE_PATH = (
Path(__file__).resolve().parent.parent / "config" / "claude_md_template.txt"
)
def _load_default_template() -> str:
if _DEFAULT_TEMPLATE_PATH.exists():
return _DEFAULT_TEMPLATE_PATH.read_text(encoding="utf-8")
# Last-resort embedded fallback if the OSS template file is missing
# from the install (e.g., partial Docker COPY).
return (
"# {{ instance.name }} — AI Data Analyst\n\n"
"This workspace is connected to {{ server.url }}.\n"
"Data refreshes every {{ sync_interval }}.\n"
)
def _list_tables(conn: duckdb.DuckDBPyConnection) -> list[dict[str, Any]]:
rows = conn.execute(
"""SELECT name, description, query_mode
FROM table_registry
ORDER BY name"""
).fetchall()
return [
{"name": r[0], "description": r[1] or "", "query_mode": r[2] or "local"}
for r in rows
]
def _metrics_summary(conn: duckdb.DuckDBPyConnection) -> dict[str, Any]:
try:
rows = conn.execute(
"SELECT category, COUNT(*) FROM metric_definitions GROUP BY category"
).fetchall()
except duckdb.CatalogException:
return {"count": 0, "categories": []}
return {
"count": sum(r[1] for r in rows),
"categories": sorted({r[0] for r in rows if r[0]}),
}
def _marketplaces_for_user(
conn: duckdb.DuckDBPyConnection, user: dict[str, Any]
) -> list[dict[str, Any]]:
"""Return marketplaces with the plugins the user is allowed to see.
Delegates RBAC filtering entirely to resolve_allowed_plugins, which
returns List[dict] with marketplace_slug, original_name, etc.
Results are grouped by marketplace slug; display names are fetched
from marketplace_registry in a single query.
"""
allowed = resolve_allowed_plugins(conn, user)
if not allowed:
return []
# Build slug → display name lookup from registry
slugs = list({p["marketplace_slug"] for p in allowed})
placeholders = ",".join(["?"] * len(slugs))
name_rows = conn.execute(
f"SELECT id, name FROM marketplace_registry WHERE id IN ({placeholders})",
slugs,
).fetchall()
slug_to_name: dict[str, str] = {r[0]: r[1] for r in name_rows}
grouped: dict[str, dict[str, Any]] = {}
for plugin in allowed:
slug = plugin["marketplace_slug"]
bucket = grouped.setdefault(
slug,
{
"slug": slug,
"name": slug_to_name.get(slug, slug),
"plugins": [],
},
)
bucket["plugins"].append({"name": plugin["original_name"]})
return list(grouped.values())
def build_context(
conn: duckdb.DuckDBPyConnection,
*,
user: dict[str, Any],
server_url: str,
) -> dict[str, Any]:
"""Compose the Jinja2 render context. Pure, no side effects."""
now = datetime.now(timezone.utc)
parsed = urlparse(server_url)
return {
"instance": {
"name": get_instance_name(),
"subtitle": get_instance_subtitle(),
},
"server": {
"url": server_url,
"hostname": parsed.hostname or "",
},
"sync_interval": get_sync_interval(),
"data_source": {"type": get_data_source_type()},
"tables": _list_tables(conn),
"metrics": _metrics_summary(conn),
"marketplaces": _marketplaces_for_user(conn, user),
"user": {
"id": user.get("id", ""),
"email": user.get("email", ""),
"name": user.get("name") or "",
"is_admin": bool(user.get("is_admin")),
"groups": user.get("groups") or [],
},
"now": now,
"today": date.today().isoformat(),
}
def _resolve_template_source(conn: duckdb.DuckDBPyConnection) -> str:
row = WelcomeTemplateRepository(conn).get()
return row["content"] if row.get("content") else _load_default_template()
def render_welcome(
conn: duckdb.DuckDBPyConnection,
*,
user: dict[str, Any],
server_url: str,
) -> str:
"""Resolve the active template and render it for the given user."""
source = _resolve_template_source(conn)
env = Environment(undefined=StrictUndefined, autoescape=False)
template = env.from_string(source)
return template.render(**build_context(conn, user=user, server_url=server_url))

View file

@ -0,0 +1,56 @@
"""Unit tests for the welcome-prompt renderer."""
from pathlib import Path
import duckdb
import pytest
from src.db import _ensure_schema
from src.repositories.welcome_template import WelcomeTemplateRepository
from src.welcome_template import build_context, render_welcome
@pytest.fixture
def conn(tmp_path, monkeypatch):
monkeypatch.setenv("DATA_DIR", str(tmp_path))
db_path = tmp_path / "system.duckdb"
c = duckdb.connect(str(db_path))
_ensure_schema(c)
yield c
c.close()
def _user(email="alice@example.com"):
return {"id": "u1", "email": email, "name": "Alice", "is_admin": False, "groups": ["Everyone"]}
def test_renders_default_when_no_override(conn):
out = render_welcome(conn, user=_user(), server_url="https://example.com")
assert "AI Data Analyst" in out
assert "https://example.com" in out
assert "Alice" in out
def test_renders_override(conn):
WelcomeTemplateRepository(conn).set(
"# {{ instance.name }} for {{ user.email }}",
updated_by="admin@example.com",
)
out = render_welcome(conn, user=_user(), server_url="https://example.com")
assert out.startswith("# AI Data Analyst for alice@example.com")
def test_strict_undefined_raises_on_missing_placeholder(conn):
WelcomeTemplateRepository(conn).set(
"{{ does_not_exist }}", updated_by="admin@example.com"
)
with pytest.raises(Exception) as exc_info:
render_welcome(conn, user=_user(), server_url="https://example.com")
assert "does_not_exist" in str(exc_info.value)
def test_context_exposes_documented_keys(conn):
ctx = build_context(conn, user=_user(), server_url="https://example.com")
for top in ("instance", "server", "sync_interval", "data_source",
"tables", "metrics", "marketplaces", "user", "now", "today"):
assert top in ctx, f"missing top-level key: {top}"