feat(cli): da analyst setup fetches rendered welcome from /api/welcome
This commit is contained in:
parent
ecaa113c68
commit
c604dad9cf
3 changed files with 111 additions and 109 deletions
|
|
@ -5,8 +5,8 @@ import re
|
|||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
|
||||
_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
|
||||
|
|
@ -247,32 +247,6 @@ def _initialize_duckdb(workspace: Path) -> int:
|
|||
return total_rows
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: resolve instance name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_instance_name(server_url: str, token: str) -> str:
|
||||
"""Retrieve instance name from /api/health, fall back to hostname."""
|
||||
import httpx
|
||||
|
||||
server_url = server_url.rstrip("/")
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
try:
|
||||
resp = httpx.get(f"{server_url}/api/health", headers=headers, timeout=10.0)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
name = data.get("instance_name") or data.get("name")
|
||||
if name:
|
||||
return name
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fall back to hostname extracted from URL
|
||||
parsed = urlparse(server_url)
|
||||
return parsed.hostname or "AI Data Analyst"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: install SessionStart/End hooks into a Claude settings file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -320,40 +294,42 @@ def _install_claude_hooks(settings_path: Path) -> None:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: generate CLAUDE.md from template
|
||||
# Helper: generate CLAUDE.md from server-rendered template
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _generate_claude_md(
|
||||
workspace: Path,
|
||||
instance_name: str,
|
||||
server_url: str,
|
||||
sync_interval: str,
|
||||
) -> None:
|
||||
"""Write CLAUDE.md from the template; create CLAUDE.local.md if absent."""
|
||||
# Locate template relative to this file (../../config/claude_md_template.txt)
|
||||
here = Path(__file__).parent
|
||||
template_path = here.parent.parent / "config" / "claude_md_template.txt"
|
||||
def _generate_claude_md(workspace: Path, server_url: str, token: str) -> None:
|
||||
"""Fetch the rendered welcome prompt from the server and write CLAUDE.md.
|
||||
|
||||
if template_path.exists():
|
||||
template = template_path.read_text(encoding="utf-8")
|
||||
else:
|
||||
# Fallback minimal template
|
||||
template = (
|
||||
"# {instance_name} — AI Data Analyst\n\n"
|
||||
"This workspace is connected to {server_url}.\n\n"
|
||||
"- Data on the server refreshes every {sync_interval}\n"
|
||||
Falls back to a minimal embedded template if the server endpoint is
|
||||
unavailable (e.g., older server versions before /api/welcome shipped).
|
||||
"""
|
||||
from urllib.parse import quote
|
||||
|
||||
server_url = server_url.rstrip("/")
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
url = f"{server_url}/api/welcome?server_url={quote(server_url, safe='')}"
|
||||
|
||||
rendered: Optional[str] = None
|
||||
try:
|
||||
resp = httpx.get(url, headers=headers, timeout=15.0)
|
||||
if resp.status_code == 200:
|
||||
rendered = resp.json().get("content")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if rendered is None:
|
||||
# Fallback for older servers — keeps the CLI usable, just less rich.
|
||||
rendered = (
|
||||
"# AI Data Analyst\n\n"
|
||||
f"This workspace is connected to {server_url}.\n\n"
|
||||
"## Rules\n"
|
||||
"- Before computing any business metric: run `da metrics show <category>/<name>`\n"
|
||||
"- Save work output to `user/artifacts/`\n"
|
||||
"- Sync data regularly with `da sync`\n"
|
||||
)
|
||||
|
||||
content = (
|
||||
template
|
||||
.replace("{instance_name}", instance_name)
|
||||
.replace("{server_url}", server_url)
|
||||
.replace("{sync_interval}", sync_interval)
|
||||
)
|
||||
(workspace / "CLAUDE.md").write_text(rendered, encoding="utf-8")
|
||||
|
||||
(workspace / "CLAUDE.md").write_text(content, encoding="utf-8")
|
||||
|
||||
# .claude/CLAUDE.local.md — never overwrite if it already exists
|
||||
local_md = workspace / ".claude" / "CLAUDE.local.md"
|
||||
if not local_md.exists():
|
||||
local_md.write_text(
|
||||
|
|
@ -402,7 +378,6 @@ def _check_data_freshness(workspace: Path) -> str:
|
|||
def setup(
|
||||
server_url: str = typer.Option(..., "--server-url", help="URL of the AI Data Analyst server"),
|
||||
force: bool = typer.Option(False, "--force", help="Re-initialise even if workspace already exists"),
|
||||
sync_interval: str = typer.Option("1 hour", "--sync-interval", help="Data refresh interval shown in CLAUDE.md"),
|
||||
workspace_dir: Optional[str] = typer.Option(None, "--workspace", help="Workspace directory (default: current dir)"),
|
||||
):
|
||||
"""Bootstrap a new analyst workspace from a remote server."""
|
||||
|
|
@ -437,15 +412,13 @@ def setup(
|
|||
typer.echo("Initialising DuckDB views...")
|
||||
total_rows = _initialize_duckdb(workspace)
|
||||
|
||||
# 7. Generate CLAUDE.md
|
||||
typer.echo("Generating CLAUDE.md...")
|
||||
instance_name = _get_instance_name(server_url, token)
|
||||
_generate_claude_md(workspace, instance_name, server_url, sync_interval)
|
||||
# 7. Generate CLAUDE.md (rendered server-side)
|
||||
typer.echo("Fetching welcome prompt from server...")
|
||||
_generate_claude_md(workspace, server_url, token)
|
||||
|
||||
# 8. Summary
|
||||
typer.echo("")
|
||||
typer.echo("Setup complete!")
|
||||
typer.echo(f" Instance : {instance_name}")
|
||||
typer.echo(f" Server : {server_url}")
|
||||
typer.echo(f" Tables : {n_downloaded} downloaded, {total_rows} total rows")
|
||||
typer.echo(f" Workspace: {workspace}")
|
||||
|
|
|
|||
|
|
@ -78,7 +78,6 @@ class TestDetectExistingProject:
|
|||
patch("cli.commands.analyst._download_metadata"), \
|
||||
patch("cli.commands.analyst._download_data", return_value=0), \
|
||||
patch("cli.commands.analyst._initialize_duckdb", return_value=0), \
|
||||
patch("cli.commands.analyst._get_instance_name", return_value="Acme"), \
|
||||
patch("cli.commands.analyst._generate_claude_md"):
|
||||
result = runner.invoke(
|
||||
app,
|
||||
|
|
@ -121,38 +120,36 @@ class TestCreateWorkspace:
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGenerateClaudeMd:
|
||||
def test_template_substitution(self, tmp_workspace):
|
||||
"""Server-side render flow: _generate_claude_md fetches /api/welcome.
|
||||
|
||||
The local-fallback path is exercised by tests/test_cli_analyst_welcome.py.
|
||||
These tests cover the side-effects on the workspace (CLAUDE.local.md,
|
||||
settings.json) and verify the new signature is honored.
|
||||
"""
|
||||
|
||||
def _patch_httpx_404(self, monkeypatch):
|
||||
"""Stub httpx.get to return 404 so _generate_claude_md falls back to embedded text."""
|
||||
import httpx
|
||||
|
||||
def fake_get(url, headers=None, timeout=None):
|
||||
return httpx.Response(
|
||||
status_code=404, json={}, request=httpx.Request("GET", url)
|
||||
)
|
||||
|
||||
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": fake_get}))
|
||||
|
||||
def test_creates_claude_local_md_when_absent(self, tmp_workspace, monkeypatch):
|
||||
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
||||
|
||||
_create_workspace(tmp_workspace)
|
||||
_generate_claude_md(
|
||||
tmp_workspace,
|
||||
instance_name="Acme Corp",
|
||||
server_url="https://data.acme.com",
|
||||
sync_interval="2 hours",
|
||||
)
|
||||
|
||||
content = (tmp_workspace / "CLAUDE.md").read_text(encoding="utf-8")
|
||||
assert "Acme Corp" in content
|
||||
assert "https://data.acme.com" in content
|
||||
assert "2 hours" in content
|
||||
|
||||
def test_creates_claude_local_md_when_absent(self, tmp_workspace):
|
||||
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
||||
|
||||
_create_workspace(tmp_workspace)
|
||||
_generate_claude_md(
|
||||
tmp_workspace,
|
||||
instance_name="Acme",
|
||||
server_url="http://localhost:8000",
|
||||
sync_interval="1 hour",
|
||||
)
|
||||
self._patch_httpx_404(monkeypatch)
|
||||
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
|
||||
|
||||
local_md = tmp_workspace / ".claude" / "CLAUDE.local.md"
|
||||
assert local_md.exists()
|
||||
assert local_md.read_text(encoding="utf-8").strip() != ""
|
||||
|
||||
def test_does_not_overwrite_existing_local_md(self, tmp_workspace):
|
||||
def test_does_not_overwrite_existing_local_md(self, tmp_workspace, monkeypatch):
|
||||
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
||||
|
||||
_create_workspace(tmp_workspace)
|
||||
|
|
@ -160,36 +157,24 @@ class TestGenerateClaudeMd:
|
|||
original_content = "# My custom notes\n\nDo not overwrite me.\n"
|
||||
local_md.write_text(original_content, encoding="utf-8")
|
||||
|
||||
_generate_claude_md(
|
||||
tmp_workspace,
|
||||
instance_name="Acme",
|
||||
server_url="http://localhost:8000",
|
||||
sync_interval="1 hour",
|
||||
)
|
||||
self._patch_httpx_404(monkeypatch)
|
||||
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
|
||||
|
||||
assert local_md.read_text(encoding="utf-8") == original_content
|
||||
|
||||
def test_uses_template_file_if_available(self, tmp_workspace):
|
||||
"""Smoke-test that the real template file is found and substituted."""
|
||||
def test_writes_settings_json(self, tmp_workspace, monkeypatch):
|
||||
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
||||
import json as _json
|
||||
|
||||
_create_workspace(tmp_workspace)
|
||||
_generate_claude_md(
|
||||
tmp_workspace,
|
||||
instance_name="TestCo",
|
||||
server_url="https://test.example.com",
|
||||
sync_interval="30 minutes",
|
||||
)
|
||||
self._patch_httpx_404(monkeypatch)
|
||||
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
|
||||
|
||||
content = (tmp_workspace / "CLAUDE.md").read_text(encoding="utf-8")
|
||||
# Template contains these literals after substitution
|
||||
assert "TestCo" in content
|
||||
assert "https://test.example.com" in content
|
||||
assert "30 minutes" in content
|
||||
# Ensure placeholders are gone
|
||||
assert "{instance_name}" not in content
|
||||
assert "{server_url}" not in content
|
||||
assert "{sync_interval}" not in content
|
||||
settings = _json.loads(
|
||||
(tmp_workspace / ".claude" / "settings.json").read_text(encoding="utf-8")
|
||||
)
|
||||
assert settings["model"] == "sonnet"
|
||||
assert "Read" in settings["permissions"]["allow"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
44
tests/test_cli_analyst_welcome.py
Normal file
44
tests/test_cli_analyst_welcome.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""Integration tests for da analyst setup → /api/welcome wiring."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from cli.commands.analyst import _generate_claude_md
|
||||
|
||||
|
||||
class _MockClient:
|
||||
def __init__(self, responses):
|
||||
self._responses = responses
|
||||
self.calls = []
|
||||
|
||||
def get(self, url, headers=None, timeout=None):
|
||||
self.calls.append(url)
|
||||
body, status = self._responses.get(url, ({}, 404))
|
||||
return httpx.Response(status_code=status, json=body, request=httpx.Request("GET", url))
|
||||
|
||||
|
||||
def test_generate_claude_md_uses_server_render(tmp_path, monkeypatch):
|
||||
workspace = tmp_path / "ws"
|
||||
(workspace / ".claude").mkdir(parents=True)
|
||||
rendered = "# CUSTOM\n\nFrom server.\n"
|
||||
mock = _MockClient({
|
||||
"https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com": (
|
||||
{"content": rendered}, 200
|
||||
),
|
||||
})
|
||||
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
|
||||
_generate_claude_md(workspace, server_url="https://example.com", token="t")
|
||||
assert (workspace / "CLAUDE.md").read_text(encoding="utf-8") == rendered
|
||||
|
||||
|
||||
def test_generate_claude_md_falls_back_on_404(tmp_path, monkeypatch):
|
||||
workspace = tmp_path / "ws"
|
||||
(workspace / ".claude").mkdir(parents=True)
|
||||
mock = _MockClient({}) # everything 404s
|
||||
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
|
||||
_generate_claude_md(workspace, server_url="https://example.com", token="t")
|
||||
body = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
|
||||
assert "AI Data Analyst" in body # embedded fallback contains this string
|
||||
assert "https://example.com" in body
|
||||
Loading…
Reference in a new issue