"""Analyst bootstrap commands — da analyst setup, da analyst status."""

import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import typer

analyst_app = typer.Typer(help="Analyst workspace bootstrap and status")

# ---------------------------------------------------------------------------
# Helper: detect existing workspace
# ---------------------------------------------------------------------------

# Substring looked for in CLAUDE.md to recognise an analyst workspace.
_CLAUDE_MD_MARKER = "AI Data Analyst"


def _detect_existing_project(workspace: Path) -> bool:
    """Return True if ``workspace/CLAUDE.md`` exists and contains the marker.

    Args:
        workspace: Directory to inspect.

    Returns:
        True when CLAUDE.md is present and its UTF-8 text contains
        ``_CLAUDE_MD_MARKER``; False when the file is absent or lacks it.
    """
    claude_md = workspace / "CLAUDE.md"
    if claude_md.exists():
        content = claude_md.read_text(encoding="utf-8")
        return _CLAUDE_MD_MARKER in content
    return False


# ---------------------------------------------------------------------------
# Helper: connect to instance (health check + authenticate)
# ---------------------------------------------------------------------------

def _connect_to_instance(server_url: str) -> str:
    """Health-check the server, prompt for credentials, save config, return JWT.

    Args:
        server_url: Base URL of the server; a trailing ``/`` is stripped.

    Returns:
        The ``access_token`` value from the ``/auth/token`` response.

    Raises:
        typer.Exit: With code 1 when the health check fails, the
            authentication request fails, or the response carries no
            access token.
    """
    import httpx
    from cli.config import save_config, save_token

    server_url = server_url.rstrip("/")

    # Health check
    try:
        resp = httpx.get(f"{server_url}/api/health", timeout=10.0)
        resp.raise_for_status()
    except Exception as e:
        typer.echo(f"Cannot reach {server_url}: {e}", err=True)
        raise typer.Exit(1)

    typer.echo(f"Connected to {server_url}")

    # Authenticate
    email = typer.prompt("Email")
    password = typer.prompt("Password", hide_input=True)

    try:
        resp = httpx.post(
            f"{server_url}/auth/token",
            json={"email": email, "password": password},
            timeout=15.0,
        )
        resp.raise_for_status()
        data = resp.json()
    except Exception as e:
        typer.echo(f"Authentication failed: {e}", err=True)
        raise typer.Exit(1)

    # A 2xx response with an unexpected body must fail like any other
    # authentication error instead of surfacing a raw KeyError traceback.
    token = data.get("access_token") if isinstance(data, dict) else None
    if not token:
        typer.echo(
            "Authentication failed: server response did not include an access token",
            err=True,
        )
        raise typer.Exit(1)
    role = data.get("role", "analyst")

    save_config({"server": server_url})
    save_token(token, email, role)
    typer.echo(f"Authenticated as {email} (role: {role})")
    return token


# ---------------------------------------------------------------------------
# Helper: create workspace directory structure
# ---------------------------------------------------------------------------

def _create_workspace(workspace: Path) -> None:
    """Create the analyst workspace directory layout.

    Creates ``data/parquet``, ``data/duckdb``, ``data/metadata``,
    ``user/artifacts``, ``user/sessions`` and ``.claude`` under *workspace*.
    Existing directories are left alone, so repeated calls are safe.
    """
    dirs = [
        workspace / "data" / "parquet",
        workspace / "data" / "duckdb",
        workspace / "data" / "metadata",
        workspace / "user" / "artifacts",
        workspace / "user" / "sessions",
        workspace / ".claude",
    ]
    for d in dirs:
        d.mkdir(parents=True, exist_ok=True)


# ---------------------------------------------------------------------------
# Helper: download metadata
# ---------------------------------------------------------------------------

def _download_metadata(workspace: Path, server_url: str, token: str) -> None:
    """Fetch catalog tables and metrics from the server; save as JSON files.

    Writes ``schema.json``, ``metrics.json`` and ``last_sync.json`` into
    ``workspace/data/metadata`` (the directory must already exist). A failed
    fetch prints a warning to stderr and stores an empty list for that file;
    ``last_sync.json`` is written with the current UTC time in either case.

    Args:
        workspace: Root of the analyst workspace.
        server_url: Base URL of the server; a trailing ``/`` is stripped.
        token: Bearer token sent in the ``Authorization`` header.
    """
    import httpx

    server_url = server_url.rstrip("/")
    headers = {"Authorization": f"Bearer {token}"}
    metadata_dir = workspace / "data" / "metadata"

    # Catalog tables
    try:
        resp = httpx.get(f"{server_url}/api/catalog/tables", headers=headers, timeout=30.0)
        resp.raise_for_status()
        tables = resp.json()
    except Exception as e:
        typer.echo(f"Warning: could not fetch catalog tables: {e}", err=True)
        tables = []

    (metadata_dir / "schema.json").write_text(
        json.dumps(tables, indent=2, ensure_ascii=False), encoding="utf-8"
    )

    # Metrics
    try:
        resp = httpx.get(f"{server_url}/api/metrics", headers=headers, timeout=30.0)
        resp.raise_for_status()
        metrics = resp.json()
    except Exception as e:
        typer.echo(f"Warning: could not fetch metrics: {e}", err=True)
        metrics = []

    (metadata_dir / "metrics.json").write_text(
        json.dumps(metrics, indent=2, ensure_ascii=False), encoding="utf-8"
    )

    # Write last_sync timestamp
    last_sync = {"synced_at": datetime.now(timezone.utc).isoformat()}
    (metadata_dir / "last_sync.json").write_text(
        json.dumps(last_sync, indent=2), encoding="utf-8"
    )


# ---------------------------------------------------------------------------
# Helper: download parquet data
# ---------------------------------------------------------------------------

def _download_data(workspace: Path, server_url: str, token: str) -> int:
    """Stream parquets for each registered table. Returns count of files downloaded.

    Table ids come from the ``tables`` entry of ``/api/sync/manifest``. Each
    table is streamed to ``<id>.parquet.part`` and renamed to ``<id>.parquet``
    only after the whole body has been written, so an interrupted transfer
    never leaves a truncated file that a later run would skip as
    "already exists".

    Args:
        workspace: Root of the analyst workspace.
        server_url: Base URL of the server; a trailing ``/`` is stripped.
        token: Bearer token sent in the ``Authorization`` header.

    Returns:
        Number of tables downloaded by this call. Tables that already exist
        locally, are skipped as unsafe, or fail to download are not counted.
        Returns 0 when the manifest cannot be fetched.
    """
    import httpx

    server_url = server_url.rstrip("/")
    headers = {"Authorization": f"Bearer {token}"}
    parquet_dir = workspace / "data" / "parquet"

    # Fetch manifest to know which tables exist
    try:
        resp = httpx.get(f"{server_url}/api/sync/manifest", headers=headers, timeout=30.0)
        resp.raise_for_status()
        manifest = resp.json()
    except Exception as e:
        typer.echo(f"Warning: could not fetch data manifest: {e}", err=True)
        return 0

    tables = manifest.get("tables", {})
    downloaded = 0

    for table_id in tables:
        # The id is supplied by the server and becomes a filename; refuse
        # anything that could resolve outside the parquet directory.
        file_stem = f"{table_id}"
        if (
            not file_stem
            or file_stem in (".", "..")
            or "/" in file_stem
            or "\\" in file_stem
        ):
            typer.echo(f" Warning: skipping {table_id} (unsafe table id)", err=True)
            continue

        target = parquet_dir / f"{file_stem}.parquet"
        if target.exists():
            typer.echo(f" Skipping {table_id} (already exists)")
            continue

        partial = target.with_name(target.name + ".part")
        try:
            with httpx.stream(
                "GET",
                f"{server_url}/api/data/{table_id}/download",
                headers=headers,
                timeout=300.0,
            ) as stream_resp:
                stream_resp.raise_for_status()
                with open(partial, "wb") as fh:
                    for chunk in stream_resp.iter_bytes(chunk_size=65536):
                        fh.write(chunk)
            # Publish the file under its final name only once it is complete.
            partial.replace(target)
            downloaded += 1
            typer.echo(f" Downloaded {table_id}")
        except Exception as e:
            # Best-effort cleanup of the partial file; the download failure
            # itself is what gets reported.
            try:
                partial.unlink()
            except OSError:
                pass
            typer.echo(f" Warning: could not download {table_id}: {e}", err=True)

    return downloaded


# ---------------------------------------------------------------------------
# Helper: initialise DuckDB
# ---------------------------------------------------------------------------

def _quote_sql_identifier(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier, doubling embedded quotes."""
    return '"' + name.replace('"', '""') + '"'


def _quote_sql_literal(value: str) -> str:
    """Return *value* as a single-quoted SQL string literal, doubling embedded quotes."""
    return "'" + value.replace("'", "''") + "'"


def _initialize_duckdb(workspace: Path) -> int:
    """Create DuckDB views over parquets. Returns total row count across all views.

    Opens ``workspace/data/duckdb/analytics.duckdb`` (creating the parent
    directory if needed) and, for every ``*.parquet`` file in
    ``workspace/data/parquet``, drops and recreates a view named after the
    file stem that selects from ``read_parquet`` on the file's absolute path.

    View names and paths are escaped before being placed into SQL, so a
    workspace path containing an apostrophe or a file stem containing a
    double quote no longer breaks the statements.

    Args:
        workspace: Root of the analyst workspace.

    Returns:
        Sum of ``count(*)`` over all views that were created successfully.
        A file whose view cannot be created prints a warning to stderr and
        contributes nothing to the total.
    """
    import duckdb

    parquet_dir = workspace / "data" / "parquet"
    db_path = workspace / "data" / "duckdb" / "analytics.duckdb"
    db_path.parent.mkdir(parents=True, exist_ok=True)

    conn = duckdb.connect(str(db_path))
    total_rows = 0

    try:
        for pq_file in parquet_dir.glob("*.parquet"):
            view_name = pq_file.stem
            view_sql = _quote_sql_identifier(view_name)
            path_sql = _quote_sql_literal(str(pq_file.resolve()))
            try:
                conn.execute(f"DROP VIEW IF EXISTS {view_sql}")
                conn.execute(
                    f"CREATE VIEW {view_sql} AS SELECT * FROM read_parquet({path_sql})"
                )
                count = conn.execute(f"SELECT count(*) FROM {view_sql}").fetchone()[0]
                total_rows += count
            except Exception as e:
                typer.echo(f" Warning: could not create view for {view_name}: {e}", err=True)
    finally:
        # Release the database file even if something outside the per-file
        # handler fails (for example while listing the directory).
        conn.close()

    return total_rows


# ---------------------------------------------------------------------------
# Helper: resolve instance name
# ---------------------------------------------------------------------------

def _get_instance_name(server_url: str, token: str) -> str:
    """Retrieve instance name from /api/health, fall back to hostname.

    Args:
        server_url: Base URL of the server; a trailing ``/`` is stripped.
        token: Bearer token sent in the ``Authorization`` header.

    Returns:
        The ``instance_name`` (or, failing that, ``name``) field of a 200
        JSON response when it is a non-empty string; otherwise the hostname
        parsed from *server_url*, or ``"AI Data Analyst"`` when the URL has
        no hostname.
    """
    import httpx

    server_url = server_url.rstrip("/")
    headers = {"Authorization": f"Bearer {token}"}

    try:
        resp = httpx.get(f"{server_url}/api/health", headers=headers, timeout=10.0)
        if resp.status_code == 200:
            data = resp.json()
            if isinstance(data, dict):
                name = data.get("instance_name") or data.get("name")
                # Only a real string can be substituted into templates later.
                if isinstance(name, str) and name.strip():
                    return name.strip()
    except Exception:
        # The name is cosmetic: any network or decoding problem falls
        # through to the hostname fallback below.
        pass

    # Fall back to hostname extracted from URL
    parsed = urlparse(server_url)
    return parsed.hostname or "AI Data Analyst"


# ---------------------------------------------------------------------------
# Helper: generate CLAUDE.md from template
#
--------------------------------------------------------------------------- + +def _generate_claude_md( + workspace: Path, + instance_name: str, + server_url: str, + sync_interval: str, +) -> None: + """Write CLAUDE.md from the template; create CLAUDE.local.md if absent.""" + # Locate template relative to this file (../../config/claude_md_template.txt) + here = Path(__file__).parent + template_path = here.parent.parent / "config" / "claude_md_template.txt" + + if template_path.exists(): + template = template_path.read_text(encoding="utf-8") + else: + # Fallback minimal template + template = ( + "# {instance_name} — AI Data Analyst\n\n" + "This workspace is connected to {server_url}.\n\n" + "- Data on the server refreshes every {sync_interval}\n" + ) + + content = ( + template + .replace("{instance_name}", instance_name) + .replace("{server_url}", server_url) + .replace("{sync_interval}", sync_interval) + ) + + (workspace / "CLAUDE.md").write_text(content, encoding="utf-8") + + # .claude/CLAUDE.local.md — never overwrite if it already exists + local_md = workspace / ".claude" / "CLAUDE.local.md" + if not local_md.exists(): + local_md.write_text( + "# My Notes\n\n" + "Personal notes for this workspace. 
Uploaded to the server on `da sync --upload-only`.\n", + encoding="utf-8", + ) + + +# --------------------------------------------------------------------------- +# Helper: data freshness check (for returning-session detection) +# --------------------------------------------------------------------------- + +def _check_data_freshness(workspace: Path) -> str: + """Return 'fresh', 'stale' (>24 h old), or 'missing'.""" + last_sync_file = workspace / "data" / "metadata" / "last_sync.json" + if not last_sync_file.exists(): + return "missing" + + try: + data = json.loads(last_sync_file.read_text(encoding="utf-8")) + synced_at_str = data.get("synced_at", "") + if not synced_at_str: + return "missing" + synced_at = datetime.fromisoformat(synced_at_str) + age_hours = (datetime.now(timezone.utc) - synced_at).total_seconds() / 3600 + return "stale" if age_hours > 24 else "fresh" + except Exception: + return "missing" + + +# --------------------------------------------------------------------------- +# Command: da analyst setup +# --------------------------------------------------------------------------- + +@analyst_app.command() +def setup( + server_url: str = typer.Option(..., "--server-url", help="URL of the AI Data Analyst server"), + force: bool = typer.Option(False, "--force", help="Re-initialise even if workspace already exists"), + sync_interval: str = typer.Option("1 hour", "--sync-interval", help="Data refresh interval shown in CLAUDE.md"), + workspace_dir: Optional[str] = typer.Option(None, "--workspace", help="Workspace directory (default: current dir)"), +): + """Bootstrap a new analyst workspace from a remote server.""" + workspace = Path(workspace_dir).resolve() if workspace_dir else Path.cwd() + + # 1. Detect existing project + if _detect_existing_project(workspace) and not force: + typer.echo( + "Existing analyst workspace detected. 
Use --force to re-initialise.", + err=True, + ) + raise typer.Exit(1) + + typer.echo(f"Setting up analyst workspace in: {workspace}") + + # 2. Connect to instance + token = _connect_to_instance(server_url) + + # 3. Create workspace directory structure + typer.echo("Creating workspace directories...") + _create_workspace(workspace) + + # 4. Download metadata + typer.echo("Downloading metadata...") + _download_metadata(workspace, server_url, token) + + # 5. Download data + typer.echo("Downloading data...") + n_downloaded = _download_data(workspace, server_url, token) + + # 6. Initialise DuckDB + typer.echo("Initialising DuckDB views...") + total_rows = _initialize_duckdb(workspace) + + # 7. Generate CLAUDE.md + typer.echo("Generating CLAUDE.md...") + instance_name = _get_instance_name(server_url, token) + _generate_claude_md(workspace, instance_name, server_url, sync_interval) + + # 8. Summary + typer.echo("") + typer.echo("Setup complete!") + typer.echo(f" Instance : {instance_name}") + typer.echo(f" Server : {server_url}") + typer.echo(f" Tables : {n_downloaded} downloaded, {total_rows} total rows") + typer.echo(f" Workspace: {workspace}") + typer.echo("") + typer.echo("Next steps:") + typer.echo(" da sync — refresh data") + typer.echo(" da metrics list — explore available metrics") + + +# --------------------------------------------------------------------------- +# Command: da analyst status +# --------------------------------------------------------------------------- + +@analyst_app.command() +def status( + workspace_dir: Optional[str] = typer.Option(None, "--workspace", help="Workspace directory (default: current dir)"), + as_json: bool = typer.Option(False, "--json", help="Output as JSON"), +): + """Show workspace status and data freshness for returning sessions.""" + workspace = Path(workspace_dir).resolve() if workspace_dir else Path.cwd() + + exists = _detect_existing_project(workspace) + freshness = _check_data_freshness(workspace) + + # Count parquet 
files + parquet_dir = workspace / "data" / "parquet" + parquet_count = len(list(parquet_dir.glob("*.parquet"))) if parquet_dir.exists() else 0 + + # Last sync timestamp + last_sync_file = workspace / "data" / "metadata" / "last_sync.json" + last_sync = "never" + if last_sync_file.exists(): + try: + data = json.loads(last_sync_file.read_text(encoding="utf-8")) + last_sync = data.get("synced_at", "never") + except Exception: + pass + + info = { + "workspace": str(workspace), + "initialized": exists, + "freshness": freshness, + "parquet_tables": parquet_count, + "last_sync": last_sync, + } + + if as_json: + typer.echo(json.dumps(info, indent=2)) + return + + typer.echo(f"Workspace : {workspace}") + typer.echo(f"Initialized: {'yes' if exists else 'no'}") + typer.echo(f"Data freshness: {freshness}") + typer.echo(f"Parquet tables: {parquet_count}") + typer.echo(f"Last sync: {last_sync}") + + if freshness == "stale": + typer.echo("") + typer.echo("Data is stale (>24 h). Run: da sync") + elif freshness == "missing": + typer.echo("") + typer.echo("No data found. 
Run: da analyst setup --server-url ") diff --git a/cli/config.py b/cli/config.py index e75961a..5e8236b 100644 --- a/cli/config.py +++ b/cli/config.py @@ -58,3 +58,15 @@ def get_sync_state() -> dict: def save_sync_state(state: dict): state_file = _config_dir() / "sync_state.json" state_file.write_text(json.dumps(state, indent=2)) + + +def save_config(data: dict): + """Persist server URL and other config to config.yaml.""" + import yaml + + config_file = _config_dir() / "config.yaml" + existing = {} + if config_file.exists(): + existing = yaml.safe_load(config_file.read_text()) or {} + existing.update(data) + config_file.write_text(yaml.dump(existing, default_flow_style=False)) diff --git a/cli/main.py b/cli/main.py index 427e598..ffe0329 100644 --- a/cli/main.py +++ b/cli/main.py @@ -16,6 +16,7 @@ from cli.commands.setup import setup_app from cli.commands.server import server_app from cli.commands.explore import explore_app from cli.commands.metrics import metrics_app +from cli.commands.analyst import analyst_app app = typer.Typer( name="da", @@ -35,6 +36,7 @@ app.add_typer(setup_app, name="setup") app.add_typer(server_app, name="server") app.add_typer(explore_app, name="explore") app.add_typer(metrics_app, name="metrics") +app.add_typer(analyst_app, name="analyst") if __name__ == "__main__": diff --git a/tests/test_analyst_bootstrap.py b/tests/test_analyst_bootstrap.py new file mode 100644 index 0000000..4d9aef8 --- /dev/null +++ b/tests/test_analyst_bootstrap.py @@ -0,0 +1,250 @@ +"""Tests for analyst bootstrap flow.""" + +import json +import os +from datetime import datetime, timedelta, timezone +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from typer.testing import CliRunner + +from cli.main import app + +runner = CliRunner() + + +@pytest.fixture(autouse=True) +def tmp_workspace(tmp_path, monkeypatch): + monkeypatch.setenv("DATA_DIR", str(tmp_path / "data")) + monkeypatch.setenv("DA_CONFIG_DIR", str(tmp_path / 
"config")) + monkeypatch.setenv("DA_LOCAL_DIR", str(tmp_path / "local")) + monkeypatch.setenv("JWT_SECRET_KEY", "test-secret") + (tmp_path / "data").mkdir() + (tmp_path / "config").mkdir() + (tmp_path / "local").mkdir() + ws = tmp_path / "workspace" + ws.mkdir() + monkeypatch.chdir(ws) + yield ws + + +# --------------------------------------------------------------------------- +# TestDetectExistingProject +# --------------------------------------------------------------------------- + +class TestDetectExistingProject: + def test_no_claude_md_returns_false(self, tmp_workspace): + from cli.commands.analyst import _detect_existing_project + + assert _detect_existing_project(tmp_workspace) is False + + def test_claude_md_with_marker_returns_true(self, tmp_workspace): + from cli.commands.analyst import _detect_existing_project + + (tmp_workspace / "CLAUDE.md").write_text( + "# Acme — AI Data Analyst\n\nThis workspace is connected to http://localhost:8000.\n", + encoding="utf-8", + ) + assert _detect_existing_project(tmp_workspace) is True + + def test_claude_md_without_marker_returns_false(self, tmp_workspace): + from cli.commands.analyst import _detect_existing_project + + (tmp_workspace / "CLAUDE.md").write_text( + "# Some Other Project\n\nNot an analyst workspace.\n", + encoding="utf-8", + ) + assert _detect_existing_project(tmp_workspace) is False + + def test_setup_blocked_when_existing_without_force(self, tmp_workspace): + """Setup must exit(1) when workspace exists and --force not supplied.""" + (tmp_workspace / "CLAUDE.md").write_text( + "# Acme — AI Data Analyst\nThis workspace is connected to http://localhost:8000.\n", + encoding="utf-8", + ) + result = runner.invoke(app, ["analyst", "setup", "--server-url", "http://localhost:8000"]) + assert result.exit_code == 1 + assert "force" in result.output.lower() or "force" in (result.stderr or "").lower() + + def test_setup_proceeds_with_force(self, tmp_workspace): + """--force bypasses existing-project 
detection.""" + (tmp_workspace / "CLAUDE.md").write_text( + "# Acme — AI Data Analyst\nThis workspace is connected to http://localhost:8000.\n", + encoding="utf-8", + ) + + with patch("cli.commands.analyst._connect_to_instance", return_value="tok"), \ + patch("cli.commands.analyst._download_metadata"), \ + patch("cli.commands.analyst._download_data", return_value=0), \ + patch("cli.commands.analyst._initialize_duckdb", return_value=0), \ + patch("cli.commands.analyst._get_instance_name", return_value="Acme"), \ + patch("cli.commands.analyst._generate_claude_md"): + result = runner.invoke( + app, + ["analyst", "setup", "--server-url", "http://localhost:8000", "--force"], + ) + assert result.exit_code == 0 + + +# --------------------------------------------------------------------------- +# TestCreateWorkspace +# --------------------------------------------------------------------------- + +class TestCreateWorkspace: + def test_creates_all_directories(self, tmp_workspace): + from cli.commands.analyst import _create_workspace + + _create_workspace(tmp_workspace) + + expected = [ + tmp_workspace / "data" / "parquet", + tmp_workspace / "data" / "duckdb", + tmp_workspace / "data" / "metadata", + tmp_workspace / "user" / "artifacts", + tmp_workspace / "user" / "sessions", + tmp_workspace / ".claude", + ] + for d in expected: + assert d.is_dir(), f"Expected directory missing: {d}" + + def test_idempotent(self, tmp_workspace): + """Calling _create_workspace twice should not raise.""" + from cli.commands.analyst import _create_workspace + + _create_workspace(tmp_workspace) + _create_workspace(tmp_workspace) # should not raise + + +# --------------------------------------------------------------------------- +# TestGenerateClaudeMd +# --------------------------------------------------------------------------- + +class TestGenerateClaudeMd: + def test_template_substitution(self, tmp_workspace): + from cli.commands.analyst import _create_workspace, _generate_claude_md + + 
_create_workspace(tmp_workspace) + _generate_claude_md( + tmp_workspace, + instance_name="Acme Corp", + server_url="https://data.acme.com", + sync_interval="2 hours", + ) + + content = (tmp_workspace / "CLAUDE.md").read_text(encoding="utf-8") + assert "Acme Corp" in content + assert "https://data.acme.com" in content + assert "2 hours" in content + + def test_creates_claude_local_md_when_absent(self, tmp_workspace): + from cli.commands.analyst import _create_workspace, _generate_claude_md + + _create_workspace(tmp_workspace) + _generate_claude_md( + tmp_workspace, + instance_name="Acme", + server_url="http://localhost:8000", + sync_interval="1 hour", + ) + + local_md = tmp_workspace / ".claude" / "CLAUDE.local.md" + assert local_md.exists() + assert local_md.read_text(encoding="utf-8").strip() != "" + + def test_does_not_overwrite_existing_local_md(self, tmp_workspace): + from cli.commands.analyst import _create_workspace, _generate_claude_md + + _create_workspace(tmp_workspace) + local_md = tmp_workspace / ".claude" / "CLAUDE.local.md" + original_content = "# My custom notes\n\nDo not overwrite me.\n" + local_md.write_text(original_content, encoding="utf-8") + + _generate_claude_md( + tmp_workspace, + instance_name="Acme", + server_url="http://localhost:8000", + sync_interval="1 hour", + ) + + assert local_md.read_text(encoding="utf-8") == original_content + + def test_uses_template_file_if_available(self, tmp_workspace): + """Smoke-test that the real template file is found and substituted.""" + from cli.commands.analyst import _create_workspace, _generate_claude_md + + _create_workspace(tmp_workspace) + _generate_claude_md( + tmp_workspace, + instance_name="TestCo", + server_url="https://test.example.com", + sync_interval="30 minutes", + ) + + content = (tmp_workspace / "CLAUDE.md").read_text(encoding="utf-8") + # Template contains these literals after substitution + assert "TestCo" in content + assert "https://test.example.com" in content + assert "30 minutes" 
in content + # Ensure placeholders are gone + assert "{instance_name}" not in content + assert "{server_url}" not in content + assert "{sync_interval}" not in content + + +# --------------------------------------------------------------------------- +# TestReturningSession +# --------------------------------------------------------------------------- + +class TestReturningSession: + def test_missing_when_no_last_sync_file(self, tmp_workspace): + from cli.commands.analyst import _check_data_freshness + + assert _check_data_freshness(tmp_workspace) == "missing" + + def test_fresh_when_recent_sync(self, tmp_workspace): + from cli.commands.analyst import _create_workspace, _check_data_freshness + + _create_workspace(tmp_workspace) + synced_at = datetime.now(timezone.utc).isoformat() + (tmp_workspace / "data" / "metadata" / "last_sync.json").write_text( + json.dumps({"synced_at": synced_at}), encoding="utf-8" + ) + assert _check_data_freshness(tmp_workspace) == "fresh" + + def test_stale_when_old_sync(self, tmp_workspace): + from cli.commands.analyst import _create_workspace, _check_data_freshness + + _create_workspace(tmp_workspace) + old_time = (datetime.now(timezone.utc) - timedelta(hours=25)).isoformat() + (tmp_workspace / "data" / "metadata" / "last_sync.json").write_text( + json.dumps({"synced_at": old_time}), encoding="utf-8" + ) + assert _check_data_freshness(tmp_workspace) == "stale" + + def test_status_command_output(self, tmp_workspace): + result = runner.invoke(app, ["analyst", "status"]) + assert result.exit_code == 0 + assert "freshness" in result.output.lower() or "Data freshness" in result.output + + def test_status_command_json(self, tmp_workspace): + result = runner.invoke(app, ["analyst", "status", "--json"]) + assert result.exit_code == 0 + data = json.loads(result.output) + assert "freshness" in data + assert data["freshness"] == "missing" + assert "parquet_tables" in data + + def test_status_fresh_after_setup_metadata(self, tmp_workspace): + from 
cli.commands.analyst import _create_workspace + + _create_workspace(tmp_workspace) + synced_at = datetime.now(timezone.utc).isoformat() + (tmp_workspace / "data" / "metadata" / "last_sync.json").write_text( + json.dumps({"synced_at": synced_at}), encoding="utf-8" + ) + + result = runner.invoke(app, ["analyst", "status", "--json"]) + assert result.exit_code == 0 + data = json.loads(result.output) + assert data["freshness"] == "fresh" diff --git a/tests/test_cli.py b/tests/test_cli.py index f36de9b..55ad02c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -245,3 +245,8 @@ class TestMetricsHelp: assert "list" in result.output assert "show" in result.output assert "import" in result.output + + def test_analyst_help(self): + result = runner.invoke(app, ["analyst", "--help"]) + assert result.exit_code == 0 + assert "setup" in result.output