feat(cli): da analyst setup fetches rendered welcome from /api/welcome

This commit is contained in:
ZdenekSrotyr 2026-04-30 19:25:52 +02:00
parent ecaa113c68
commit c604dad9cf
3 changed files with 111 additions and 109 deletions

View file

@ -5,8 +5,8 @@ import re
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from urllib.parse import urlparse
import httpx
import typer import typer
_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$") _SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
@ -247,32 +247,6 @@ def _initialize_duckdb(workspace: Path) -> int:
return total_rows return total_rows
# ---------------------------------------------------------------------------
# Helper: resolve instance name
# ---------------------------------------------------------------------------
def _get_instance_name(server_url: str, token: str) -> str:
    """Resolve a human-readable instance name for the server.

    Queries ``{server_url}/api/health`` and reads ``instance_name`` (or,
    failing that, ``name``) from the JSON body. The lookup is best-effort:
    on a transport error, a non-200 status, or an unusable payload, the
    hostname parsed from ``server_url`` is returned instead, and if no
    hostname can be extracted, the literal ``"AI Data Analyst"``.

    Args:
        server_url: Base URL of the server; trailing slashes are ignored.
        token: Bearer token sent in the ``Authorization`` header.

    Returns:
        The name reported by the server, or one of the fallbacks above.
    """
    import httpx

    server_url = server_url.rstrip("/")
    headers = {"Authorization": f"Bearer {token}"}
    try:
        resp = httpx.get(f"{server_url}/api/health", headers=headers, timeout=10.0)
        if resp.status_code == 200:
            data = resp.json()
            # Only a JSON object can carry the name; any other body shape is
            # treated as "no name" instead of relying on a blanket except.
            if isinstance(data, dict):
                name = data.get("instance_name") or data.get("name")
                if isinstance(name, str) and name.strip():
                    return name
    except (httpx.HTTPError, httpx.InvalidURL, ValueError):
        # Transport/protocol failures and undecodable JSON: the name is only
        # cosmetic, so fall through to the hostname rather than aborting.
        pass

    # Fall back to the hostname extracted from the URL.
    parsed = urlparse(server_url)
    return parsed.hostname or "AI Data Analyst"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helper: install SessionStart/End hooks into a Claude settings file # Helper: install SessionStart/End hooks into a Claude settings file
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -320,40 +294,42 @@ def _install_claude_hooks(settings_path: Path) -> None:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helper: generate CLAUDE.md from template # Helper: generate CLAUDE.md from server-rendered template
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _generate_claude_md( def _generate_claude_md(workspace: Path, server_url: str, token: str) -> None:
workspace: Path, """Fetch the rendered welcome prompt from the server and write CLAUDE.md.
instance_name: str,
server_url: str,
sync_interval: str,
) -> None:
"""Write CLAUDE.md from the template; create CLAUDE.local.md if absent."""
# Locate template relative to this file (../../config/claude_md_template.txt)
here = Path(__file__).parent
template_path = here.parent.parent / "config" / "claude_md_template.txt"
if template_path.exists(): Falls back to a minimal embedded template if the server endpoint is
template = template_path.read_text(encoding="utf-8") unavailable (e.g., older server versions before /api/welcome shipped).
else: """
# Fallback minimal template from urllib.parse import quote
template = (
"# {instance_name} — AI Data Analyst\n\n" server_url = server_url.rstrip("/")
"This workspace is connected to {server_url}.\n\n" headers = {"Authorization": f"Bearer {token}"}
"- Data on the server refreshes every {sync_interval}\n" url = f"{server_url}/api/welcome?server_url={quote(server_url, safe='')}"
rendered: Optional[str] = None
try:
resp = httpx.get(url, headers=headers, timeout=15.0)
if resp.status_code == 200:
rendered = resp.json().get("content")
except Exception:
pass
if rendered is None:
# Fallback for older servers — keeps the CLI usable, just less rich.
rendered = (
"# AI Data Analyst\n\n"
f"This workspace is connected to {server_url}.\n\n"
"## Rules\n"
"- Before computing any business metric: run `da metrics show <category>/<name>`\n"
"- Save work output to `user/artifacts/`\n"
"- Sync data regularly with `da sync`\n"
) )
content = ( (workspace / "CLAUDE.md").write_text(rendered, encoding="utf-8")
template
.replace("{instance_name}", instance_name)
.replace("{server_url}", server_url)
.replace("{sync_interval}", sync_interval)
)
(workspace / "CLAUDE.md").write_text(content, encoding="utf-8")
# .claude/CLAUDE.local.md — never overwrite if it already exists
local_md = workspace / ".claude" / "CLAUDE.local.md" local_md = workspace / ".claude" / "CLAUDE.local.md"
if not local_md.exists(): if not local_md.exists():
local_md.write_text( local_md.write_text(
@ -402,7 +378,6 @@ def _check_data_freshness(workspace: Path) -> str:
def setup( def setup(
server_url: str = typer.Option(..., "--server-url", help="URL of the AI Data Analyst server"), server_url: str = typer.Option(..., "--server-url", help="URL of the AI Data Analyst server"),
force: bool = typer.Option(False, "--force", help="Re-initialise even if workspace already exists"), force: bool = typer.Option(False, "--force", help="Re-initialise even if workspace already exists"),
sync_interval: str = typer.Option("1 hour", "--sync-interval", help="Data refresh interval shown in CLAUDE.md"),
workspace_dir: Optional[str] = typer.Option(None, "--workspace", help="Workspace directory (default: current dir)"), workspace_dir: Optional[str] = typer.Option(None, "--workspace", help="Workspace directory (default: current dir)"),
): ):
"""Bootstrap a new analyst workspace from a remote server.""" """Bootstrap a new analyst workspace from a remote server."""
@ -437,15 +412,13 @@ def setup(
typer.echo("Initialising DuckDB views...") typer.echo("Initialising DuckDB views...")
total_rows = _initialize_duckdb(workspace) total_rows = _initialize_duckdb(workspace)
# 7. Generate CLAUDE.md # 7. Generate CLAUDE.md (rendered server-side)
typer.echo("Generating CLAUDE.md...") typer.echo("Fetching welcome prompt from server...")
instance_name = _get_instance_name(server_url, token) _generate_claude_md(workspace, server_url, token)
_generate_claude_md(workspace, instance_name, server_url, sync_interval)
# 8. Summary # 8. Summary
typer.echo("") typer.echo("")
typer.echo("Setup complete!") typer.echo("Setup complete!")
typer.echo(f" Instance : {instance_name}")
typer.echo(f" Server : {server_url}") typer.echo(f" Server : {server_url}")
typer.echo(f" Tables : {n_downloaded} downloaded, {total_rows} total rows") typer.echo(f" Tables : {n_downloaded} downloaded, {total_rows} total rows")
typer.echo(f" Workspace: {workspace}") typer.echo(f" Workspace: {workspace}")

View file

@ -78,7 +78,6 @@ class TestDetectExistingProject:
patch("cli.commands.analyst._download_metadata"), \ patch("cli.commands.analyst._download_metadata"), \
patch("cli.commands.analyst._download_data", return_value=0), \ patch("cli.commands.analyst._download_data", return_value=0), \
patch("cli.commands.analyst._initialize_duckdb", return_value=0), \ patch("cli.commands.analyst._initialize_duckdb", return_value=0), \
patch("cli.commands.analyst._get_instance_name", return_value="Acme"), \
patch("cli.commands.analyst._generate_claude_md"): patch("cli.commands.analyst._generate_claude_md"):
result = runner.invoke( result = runner.invoke(
app, app,
@ -121,38 +120,36 @@ class TestCreateWorkspace:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestGenerateClaudeMd: class TestGenerateClaudeMd:
def test_template_substitution(self, tmp_workspace): """Server-side render flow: _generate_claude_md fetches /api/welcome.
The local-fallback path is exercised by tests/test_cli_analyst_welcome.py.
These tests cover the side-effects on the workspace (CLAUDE.local.md,
settings.json) and verify the new signature is honored.
"""
def _patch_httpx_404(self, monkeypatch):
    """Swap ``cli.commands.analyst.httpx`` for a stub whose ``get`` always answers 404.

    With every request returning 404, ``_generate_claude_md`` falls back to
    its embedded text.
    """
    import httpx

    def always_404(url, headers=None, timeout=None):
        # Signature mirrors the keyword arguments accepted from the caller.
        request = httpx.Request("GET", url)
        return httpx.Response(status_code=404, json={}, request=request)

    stub_module = type("_M", (), {"get": always_404})
    monkeypatch.setattr("cli.commands.analyst.httpx", stub_module)
def test_creates_claude_local_md_when_absent(self, tmp_workspace, monkeypatch):
from cli.commands.analyst import _create_workspace, _generate_claude_md from cli.commands.analyst import _create_workspace, _generate_claude_md
_create_workspace(tmp_workspace) _create_workspace(tmp_workspace)
_generate_claude_md( self._patch_httpx_404(monkeypatch)
tmp_workspace, _generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
instance_name="Acme Corp",
server_url="https://data.acme.com",
sync_interval="2 hours",
)
content = (tmp_workspace / "CLAUDE.md").read_text(encoding="utf-8")
assert "Acme Corp" in content
assert "https://data.acme.com" in content
assert "2 hours" in content
def test_creates_claude_local_md_when_absent(self, tmp_workspace):
from cli.commands.analyst import _create_workspace, _generate_claude_md
_create_workspace(tmp_workspace)
_generate_claude_md(
tmp_workspace,
instance_name="Acme",
server_url="http://localhost:8000",
sync_interval="1 hour",
)
local_md = tmp_workspace / ".claude" / "CLAUDE.local.md" local_md = tmp_workspace / ".claude" / "CLAUDE.local.md"
assert local_md.exists() assert local_md.exists()
assert local_md.read_text(encoding="utf-8").strip() != "" assert local_md.read_text(encoding="utf-8").strip() != ""
def test_does_not_overwrite_existing_local_md(self, tmp_workspace): def test_does_not_overwrite_existing_local_md(self, tmp_workspace, monkeypatch):
from cli.commands.analyst import _create_workspace, _generate_claude_md from cli.commands.analyst import _create_workspace, _generate_claude_md
_create_workspace(tmp_workspace) _create_workspace(tmp_workspace)
@ -160,36 +157,24 @@ class TestGenerateClaudeMd:
original_content = "# My custom notes\n\nDo not overwrite me.\n" original_content = "# My custom notes\n\nDo not overwrite me.\n"
local_md.write_text(original_content, encoding="utf-8") local_md.write_text(original_content, encoding="utf-8")
_generate_claude_md( self._patch_httpx_404(monkeypatch)
tmp_workspace, _generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
instance_name="Acme",
server_url="http://localhost:8000",
sync_interval="1 hour",
)
assert local_md.read_text(encoding="utf-8") == original_content assert local_md.read_text(encoding="utf-8") == original_content
def test_uses_template_file_if_available(self, tmp_workspace): def test_writes_settings_json(self, tmp_workspace, monkeypatch):
"""Smoke-test that the real template file is found and substituted."""
from cli.commands.analyst import _create_workspace, _generate_claude_md from cli.commands.analyst import _create_workspace, _generate_claude_md
import json as _json
_create_workspace(tmp_workspace) _create_workspace(tmp_workspace)
_generate_claude_md( self._patch_httpx_404(monkeypatch)
tmp_workspace, _generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
instance_name="TestCo",
server_url="https://test.example.com",
sync_interval="30 minutes",
)
content = (tmp_workspace / "CLAUDE.md").read_text(encoding="utf-8") settings = _json.loads(
# Template contains these literals after substitution (tmp_workspace / ".claude" / "settings.json").read_text(encoding="utf-8")
assert "TestCo" in content )
assert "https://test.example.com" in content assert settings["model"] == "sonnet"
assert "30 minutes" in content assert "Read" in settings["permissions"]["allow"]
# Ensure placeholders are gone
assert "{instance_name}" not in content
assert "{server_url}" not in content
assert "{sync_interval}" not in content
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -0,0 +1,44 @@
"""Integration tests for da analyst setup → /api/welcome wiring."""
from pathlib import Path
import httpx
import pytest
from cli.commands.analyst import _generate_claude_md
class _MockClient:
def __init__(self, responses):
self._responses = responses
self.calls = []
def get(self, url, headers=None, timeout=None):
self.calls.append(url)
body, status = self._responses.get(url, ({}, 404))
return httpx.Response(status_code=status, json=body, request=httpx.Request("GET", url))
def test_generate_claude_md_uses_server_render(tmp_path, monkeypatch):
    """CLAUDE.md contains exactly the ``content`` served by /api/welcome."""
    server_body = "# CUSTOM\n\nFrom server.\n"
    welcome_url = (
        "https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com"
    )
    client = _MockClient({welcome_url: ({"content": server_body}, 200)})

    ws_dir = tmp_path / "ws"
    (ws_dir / ".claude").mkdir(parents=True)

    # Replace the module-level ``httpx`` with a stub exposing only ``get``.
    stub_httpx = type("_M", (), {"get": client.get})
    monkeypatch.setattr("cli.commands.analyst.httpx", stub_httpx)

    _generate_claude_md(ws_dir, server_url="https://example.com", token="t")

    claude_md = ws_dir / "CLAUDE.md"
    assert claude_md.read_text(encoding="utf-8") == server_body
def test_generate_claude_md_falls_back_on_404(tmp_path, monkeypatch):
    """A 404 from /api/welcome yields the embedded fallback CLAUDE.md.

    Also checks that the welcome endpoint was actually requested, so an
    implementation that skipped the HTTP call and always wrote the fallback
    would not pass.
    """
    workspace = tmp_path / "ws"
    (workspace / ".claude").mkdir(parents=True)
    mock = _MockClient({})  # everything 404s
    monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))

    _generate_claude_md(workspace, server_url="https://example.com", token="t")

    # The fallback must come after trying the server, not instead of it.
    expected_url = (
        "https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com"
    )
    assert expected_url in mock.calls

    body = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
    assert "AI Data Analyst" in body  # embedded fallback contains this string
    assert "https://example.com" in body