feat(cli): da analyst setup fetches rendered welcome from /api/welcome
This commit is contained in:
parent
ecaa113c68
commit
c604dad9cf
3 changed files with 111 additions and 109 deletions
|
|
@ -5,8 +5,8 @@ import re
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
|
import httpx
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
|
_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
|
||||||
|
|
@ -247,32 +247,6 @@ def _initialize_duckdb(workspace: Path) -> int:
|
||||||
return total_rows
|
return total_rows
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helper: resolve instance name
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _get_instance_name(server_url: str, token: str) -> str:
|
|
||||||
"""Retrieve instance name from /api/health, fall back to hostname."""
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
server_url = server_url.rstrip("/")
|
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
|
|
||||||
try:
|
|
||||||
resp = httpx.get(f"{server_url}/api/health", headers=headers, timeout=10.0)
|
|
||||||
if resp.status_code == 200:
|
|
||||||
data = resp.json()
|
|
||||||
name = data.get("instance_name") or data.get("name")
|
|
||||||
if name:
|
|
||||||
return name
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Fall back to hostname extracted from URL
|
|
||||||
parsed = urlparse(server_url)
|
|
||||||
return parsed.hostname or "AI Data Analyst"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helper: install SessionStart/End hooks into a Claude settings file
|
# Helper: install SessionStart/End hooks into a Claude settings file
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -320,40 +294,42 @@ def _install_claude_hooks(settings_path: Path) -> None:
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helper: generate CLAUDE.md from template
|
# Helper: generate CLAUDE.md from server-rendered template
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _generate_claude_md(
|
def _generate_claude_md(workspace: Path, server_url: str, token: str) -> None:
|
||||||
workspace: Path,
|
"""Fetch the rendered welcome prompt from the server and write CLAUDE.md.
|
||||||
instance_name: str,
|
|
||||||
server_url: str,
|
|
||||||
sync_interval: str,
|
|
||||||
) -> None:
|
|
||||||
"""Write CLAUDE.md from the template; create CLAUDE.local.md if absent."""
|
|
||||||
# Locate template relative to this file (../../config/claude_md_template.txt)
|
|
||||||
here = Path(__file__).parent
|
|
||||||
template_path = here.parent.parent / "config" / "claude_md_template.txt"
|
|
||||||
|
|
||||||
if template_path.exists():
|
Falls back to a minimal embedded template if the server endpoint is
|
||||||
template = template_path.read_text(encoding="utf-8")
|
unavailable (e.g., older server versions before /api/welcome shipped).
|
||||||
else:
|
"""
|
||||||
# Fallback minimal template
|
from urllib.parse import quote
|
||||||
template = (
|
|
||||||
"# {instance_name} — AI Data Analyst\n\n"
|
server_url = server_url.rstrip("/")
|
||||||
"This workspace is connected to {server_url}.\n\n"
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
"- Data on the server refreshes every {sync_interval}\n"
|
url = f"{server_url}/api/welcome?server_url={quote(server_url, safe='')}"
|
||||||
|
|
||||||
|
rendered: Optional[str] = None
|
||||||
|
try:
|
||||||
|
resp = httpx.get(url, headers=headers, timeout=15.0)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
rendered = resp.json().get("content")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if rendered is None:
|
||||||
|
# Fallback for older servers — keeps the CLI usable, just less rich.
|
||||||
|
rendered = (
|
||||||
|
"# AI Data Analyst\n\n"
|
||||||
|
f"This workspace is connected to {server_url}.\n\n"
|
||||||
|
"## Rules\n"
|
||||||
|
"- Before computing any business metric: run `da metrics show <category>/<name>`\n"
|
||||||
|
"- Save work output to `user/artifacts/`\n"
|
||||||
|
"- Sync data regularly with `da sync`\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
content = (
|
(workspace / "CLAUDE.md").write_text(rendered, encoding="utf-8")
|
||||||
template
|
|
||||||
.replace("{instance_name}", instance_name)
|
|
||||||
.replace("{server_url}", server_url)
|
|
||||||
.replace("{sync_interval}", sync_interval)
|
|
||||||
)
|
|
||||||
|
|
||||||
(workspace / "CLAUDE.md").write_text(content, encoding="utf-8")
|
|
||||||
|
|
||||||
# .claude/CLAUDE.local.md — never overwrite if it already exists
|
|
||||||
local_md = workspace / ".claude" / "CLAUDE.local.md"
|
local_md = workspace / ".claude" / "CLAUDE.local.md"
|
||||||
if not local_md.exists():
|
if not local_md.exists():
|
||||||
local_md.write_text(
|
local_md.write_text(
|
||||||
|
|
@ -402,7 +378,6 @@ def _check_data_freshness(workspace: Path) -> str:
|
||||||
def setup(
|
def setup(
|
||||||
server_url: str = typer.Option(..., "--server-url", help="URL of the AI Data Analyst server"),
|
server_url: str = typer.Option(..., "--server-url", help="URL of the AI Data Analyst server"),
|
||||||
force: bool = typer.Option(False, "--force", help="Re-initialise even if workspace already exists"),
|
force: bool = typer.Option(False, "--force", help="Re-initialise even if workspace already exists"),
|
||||||
sync_interval: str = typer.Option("1 hour", "--sync-interval", help="Data refresh interval shown in CLAUDE.md"),
|
|
||||||
workspace_dir: Optional[str] = typer.Option(None, "--workspace", help="Workspace directory (default: current dir)"),
|
workspace_dir: Optional[str] = typer.Option(None, "--workspace", help="Workspace directory (default: current dir)"),
|
||||||
):
|
):
|
||||||
"""Bootstrap a new analyst workspace from a remote server."""
|
"""Bootstrap a new analyst workspace from a remote server."""
|
||||||
|
|
@ -437,15 +412,13 @@ def setup(
|
||||||
typer.echo("Initialising DuckDB views...")
|
typer.echo("Initialising DuckDB views...")
|
||||||
total_rows = _initialize_duckdb(workspace)
|
total_rows = _initialize_duckdb(workspace)
|
||||||
|
|
||||||
# 7. Generate CLAUDE.md
|
# 7. Generate CLAUDE.md (rendered server-side)
|
||||||
typer.echo("Generating CLAUDE.md...")
|
typer.echo("Fetching welcome prompt from server...")
|
||||||
instance_name = _get_instance_name(server_url, token)
|
_generate_claude_md(workspace, server_url, token)
|
||||||
_generate_claude_md(workspace, instance_name, server_url, sync_interval)
|
|
||||||
|
|
||||||
# 8. Summary
|
# 8. Summary
|
||||||
typer.echo("")
|
typer.echo("")
|
||||||
typer.echo("Setup complete!")
|
typer.echo("Setup complete!")
|
||||||
typer.echo(f" Instance : {instance_name}")
|
|
||||||
typer.echo(f" Server : {server_url}")
|
typer.echo(f" Server : {server_url}")
|
||||||
typer.echo(f" Tables : {n_downloaded} downloaded, {total_rows} total rows")
|
typer.echo(f" Tables : {n_downloaded} downloaded, {total_rows} total rows")
|
||||||
typer.echo(f" Workspace: {workspace}")
|
typer.echo(f" Workspace: {workspace}")
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,6 @@ class TestDetectExistingProject:
|
||||||
patch("cli.commands.analyst._download_metadata"), \
|
patch("cli.commands.analyst._download_metadata"), \
|
||||||
patch("cli.commands.analyst._download_data", return_value=0), \
|
patch("cli.commands.analyst._download_data", return_value=0), \
|
||||||
patch("cli.commands.analyst._initialize_duckdb", return_value=0), \
|
patch("cli.commands.analyst._initialize_duckdb", return_value=0), \
|
||||||
patch("cli.commands.analyst._get_instance_name", return_value="Acme"), \
|
|
||||||
patch("cli.commands.analyst._generate_claude_md"):
|
patch("cli.commands.analyst._generate_claude_md"):
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
app,
|
app,
|
||||||
|
|
@ -121,38 +120,36 @@ class TestCreateWorkspace:
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
class TestGenerateClaudeMd:
|
class TestGenerateClaudeMd:
|
||||||
def test_template_substitution(self, tmp_workspace):
|
"""Server-side render flow: _generate_claude_md fetches /api/welcome.
|
||||||
|
|
||||||
|
The local-fallback path is exercised by tests/test_cli_analyst_welcome.py.
|
||||||
|
These tests cover the side-effects on the workspace (CLAUDE.local.md,
|
||||||
|
settings.json) and verify the new signature is honored.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _patch_httpx_404(self, monkeypatch):
|
||||||
|
"""Stub httpx.get to return 404 so _generate_claude_md falls back to embedded text."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
def fake_get(url, headers=None, timeout=None):
|
||||||
|
return httpx.Response(
|
||||||
|
status_code=404, json={}, request=httpx.Request("GET", url)
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": fake_get}))
|
||||||
|
|
||||||
|
def test_creates_claude_local_md_when_absent(self, tmp_workspace, monkeypatch):
|
||||||
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
||||||
|
|
||||||
_create_workspace(tmp_workspace)
|
_create_workspace(tmp_workspace)
|
||||||
_generate_claude_md(
|
self._patch_httpx_404(monkeypatch)
|
||||||
tmp_workspace,
|
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
|
||||||
instance_name="Acme Corp",
|
|
||||||
server_url="https://data.acme.com",
|
|
||||||
sync_interval="2 hours",
|
|
||||||
)
|
|
||||||
|
|
||||||
content = (tmp_workspace / "CLAUDE.md").read_text(encoding="utf-8")
|
|
||||||
assert "Acme Corp" in content
|
|
||||||
assert "https://data.acme.com" in content
|
|
||||||
assert "2 hours" in content
|
|
||||||
|
|
||||||
def test_creates_claude_local_md_when_absent(self, tmp_workspace):
|
|
||||||
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
|
||||||
|
|
||||||
_create_workspace(tmp_workspace)
|
|
||||||
_generate_claude_md(
|
|
||||||
tmp_workspace,
|
|
||||||
instance_name="Acme",
|
|
||||||
server_url="http://localhost:8000",
|
|
||||||
sync_interval="1 hour",
|
|
||||||
)
|
|
||||||
|
|
||||||
local_md = tmp_workspace / ".claude" / "CLAUDE.local.md"
|
local_md = tmp_workspace / ".claude" / "CLAUDE.local.md"
|
||||||
assert local_md.exists()
|
assert local_md.exists()
|
||||||
assert local_md.read_text(encoding="utf-8").strip() != ""
|
assert local_md.read_text(encoding="utf-8").strip() != ""
|
||||||
|
|
||||||
def test_does_not_overwrite_existing_local_md(self, tmp_workspace):
|
def test_does_not_overwrite_existing_local_md(self, tmp_workspace, monkeypatch):
|
||||||
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
||||||
|
|
||||||
_create_workspace(tmp_workspace)
|
_create_workspace(tmp_workspace)
|
||||||
|
|
@ -160,36 +157,24 @@ class TestGenerateClaudeMd:
|
||||||
original_content = "# My custom notes\n\nDo not overwrite me.\n"
|
original_content = "# My custom notes\n\nDo not overwrite me.\n"
|
||||||
local_md.write_text(original_content, encoding="utf-8")
|
local_md.write_text(original_content, encoding="utf-8")
|
||||||
|
|
||||||
_generate_claude_md(
|
self._patch_httpx_404(monkeypatch)
|
||||||
tmp_workspace,
|
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
|
||||||
instance_name="Acme",
|
|
||||||
server_url="http://localhost:8000",
|
|
||||||
sync_interval="1 hour",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert local_md.read_text(encoding="utf-8") == original_content
|
assert local_md.read_text(encoding="utf-8") == original_content
|
||||||
|
|
||||||
def test_uses_template_file_if_available(self, tmp_workspace):
|
def test_writes_settings_json(self, tmp_workspace, monkeypatch):
|
||||||
"""Smoke-test that the real template file is found and substituted."""
|
|
||||||
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
from cli.commands.analyst import _create_workspace, _generate_claude_md
|
||||||
|
import json as _json
|
||||||
|
|
||||||
_create_workspace(tmp_workspace)
|
_create_workspace(tmp_workspace)
|
||||||
_generate_claude_md(
|
self._patch_httpx_404(monkeypatch)
|
||||||
tmp_workspace,
|
_generate_claude_md(tmp_workspace, server_url="http://localhost:8000", token="t")
|
||||||
instance_name="TestCo",
|
|
||||||
server_url="https://test.example.com",
|
|
||||||
sync_interval="30 minutes",
|
|
||||||
)
|
|
||||||
|
|
||||||
content = (tmp_workspace / "CLAUDE.md").read_text(encoding="utf-8")
|
settings = _json.loads(
|
||||||
# Template contains these literals after substitution
|
(tmp_workspace / ".claude" / "settings.json").read_text(encoding="utf-8")
|
||||||
assert "TestCo" in content
|
)
|
||||||
assert "https://test.example.com" in content
|
assert settings["model"] == "sonnet"
|
||||||
assert "30 minutes" in content
|
assert "Read" in settings["permissions"]["allow"]
|
||||||
# Ensure placeholders are gone
|
|
||||||
assert "{instance_name}" not in content
|
|
||||||
assert "{server_url}" not in content
|
|
||||||
assert "{sync_interval}" not in content
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
44
tests/test_cli_analyst_welcome.py
Normal file
44
tests/test_cli_analyst_welcome.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""Integration tests for da analyst setup → /api/welcome wiring."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cli.commands.analyst import _generate_claude_md
|
||||||
|
|
||||||
|
|
||||||
|
class _MockClient:
|
||||||
|
def __init__(self, responses):
|
||||||
|
self._responses = responses
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def get(self, url, headers=None, timeout=None):
|
||||||
|
self.calls.append(url)
|
||||||
|
body, status = self._responses.get(url, ({}, 404))
|
||||||
|
return httpx.Response(status_code=status, json=body, request=httpx.Request("GET", url))
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_claude_md_uses_server_render(tmp_path, monkeypatch):
|
||||||
|
workspace = tmp_path / "ws"
|
||||||
|
(workspace / ".claude").mkdir(parents=True)
|
||||||
|
rendered = "# CUSTOM\n\nFrom server.\n"
|
||||||
|
mock = _MockClient({
|
||||||
|
"https://example.com/api/welcome?server_url=https%3A%2F%2Fexample.com": (
|
||||||
|
{"content": rendered}, 200
|
||||||
|
),
|
||||||
|
})
|
||||||
|
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
|
||||||
|
_generate_claude_md(workspace, server_url="https://example.com", token="t")
|
||||||
|
assert (workspace / "CLAUDE.md").read_text(encoding="utf-8") == rendered
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_claude_md_falls_back_on_404(tmp_path, monkeypatch):
|
||||||
|
workspace = tmp_path / "ws"
|
||||||
|
(workspace / ".claude").mkdir(parents=True)
|
||||||
|
mock = _MockClient({}) # everything 404s
|
||||||
|
monkeypatch.setattr("cli.commands.analyst.httpx", type("_M", (), {"get": mock.get}))
|
||||||
|
_generate_claude_md(workspace, server_url="https://example.com", token="t")
|
||||||
|
body = (workspace / "CLAUDE.md").read_text(encoding="utf-8")
|
||||||
|
assert "AI Data Analyst" in body # embedded fallback contains this string
|
||||||
|
assert "https://example.com" in body
|
||||||
Loading…
Reference in a new issue