diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 86d636e..d2b7537 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -121,6 +121,22 @@ jobs: echo "Version: ${VERSION}" echo "Versioned tag: ${TAG}" + - name: Extract package version from pyproject.toml + id: pkgver + run: | + # Single source of truth for the product version: the + # pyproject.toml [project] table. The CalVer "${YEAR_MONTH}.${N}" + # claimed above stays as the git / image tag (release identity), + # but AGNES_VERSION — what /api/version, /cli/latest, and `da + # --version` all expose — tracks the package version. + VERSION=$(grep '^version' pyproject.toml | head -1 | sed -E 's/^version\s*=\s*"([^"]+)".*/\1/') + if [ -z "$VERSION" ]; then + echo "::error::Could not extract version from pyproject.toml" + exit 1 + fi + echo "version=${VERSION}" >> "$GITHUB_OUTPUT" + echo "Package version: ${VERSION}" + - name: Log in to GHCR uses: docker/login-action@v4 with: @@ -133,9 +149,10 @@ jobs: with: push: true build-args: | - AGNES_VERSION=${{ steps.meta.outputs.version }} + AGNES_VERSION=${{ steps.pkgver.outputs.version }} RELEASE_CHANNEL=${{ steps.meta.outputs.channel }} AGNES_COMMIT_SHA=${{ github.sha }} + AGNES_TAG=${{ steps.meta.outputs.versioned_tag }} tags: | ghcr.io/${{ github.repository }}:${{ steps.meta.outputs.channel }} ghcr.io/${{ github.repository }}:${{ steps.meta.outputs.versioned_tag }} diff --git a/Dockerfile b/Dockerfile index 4b88690..098aaef 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,9 +7,11 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv ARG AGNES_VERSION=dev ARG RELEASE_CHANNEL=dev ARG AGNES_COMMIT_SHA=unknown +ARG AGNES_TAG=unknown ENV AGNES_VERSION=${AGNES_VERSION} ENV RELEASE_CHANNEL=${RELEASE_CHANNEL} ENV AGNES_COMMIT_SHA=${AGNES_COMMIT_SHA} +ENV AGNES_TAG=${AGNES_TAG} WORKDIR /app diff --git a/app/api/cli_artifacts.py b/app/api/cli_artifacts.py index d61d049..a0af651 100644 --- a/app/api/cli_artifacts.py +++ b/app/api/cli_artifacts.py @@ -39,6 +39,29 @@ def _find_wheel() -> Path | None: return wheels[-1] if wheels else None +@router.get("/cli/latest") +async def cli_latest(): + """Metadata for the currently-shipped CLI wheel. + + Consumed by `da` CLI's auto-update check so it can warn when a newer + version is on the server. Public + cacheable — no secrets here. + Returns `version=None` when the server has no wheel yet (dev image that + didn't run `uv build`). + """ + wheel = _find_wheel() + if not wheel: + return {"version": None, "wheel_filename": None, "download_url_path": None} + # PEP 427 wheel filename: {name}-{version}(-{build})?-{py}-{abi}-{plat}.whl + # The version is the second `-`-separated token. + parts = wheel.stem.split("-") + version = parts[1] if len(parts) >= 2 else None + return { + "version": version, + "wheel_filename": wheel.name, + "download_url_path": f"/cli/wheel/{wheel.name}", + } + + @router.get("/cli/download") async def cli_download(): wheel = _find_wheel() @@ -58,25 +81,17 @@ async def cli_download(): ) -@router.get("/cli/agnes.whl") -async def cli_wheel_stable(): - """Stable `.whl` URL alias so `uv tool install /cli/agnes.whl` works. +@router.get("/cli/wheel/{wheel_name}") +async def cli_wheel_versioned(wheel_name: str): + """Serve the currently-present wheel at a PEP 427-compliant URL. - `uv tool install` inspects the URL path to decide how to treat the resource - and only accepts it as a wheel when the path ends in `.whl`. The existing - `/cli/download` path does not, which forces users through a multi-step - curl + tmpfile + install + rm dance. This alias collapses that into a - single `uv tool install` invocation. + Only the exact filename of the current wheel is honoured; any other + `wheel_name` returns 404. No filesystem lookup is done from user input — + the path param is only compared against `_find_wheel().name`. """ wheel = _find_wheel() - if not wheel: - raise HTTPException( - status_code=404, - detail=( - "CLI wheel not found in dist dir. Build it with `uv build --wheel` " - "or run the official docker image (which builds on image-build)." - ), - ) + if not wheel or wheel.name != wheel_name: + raise HTTPException(status_code=404, detail="Wheel not found") return FileResponse( path=str(wheel), filename=wheel.name, diff --git a/app/main.py b/app/main.py index 32ee07b..056c5fd 100644 --- a/app/main.py +++ b/app/main.py @@ -2,17 +2,58 @@ import logging from contextlib import asynccontextmanager +from importlib.metadata import PackageNotFoundError +from importlib.metadata import version as _pkg_version from pathlib import Path from urllib.parse import quote import os + +def _app_version() -> str: + """Product version for FastAPI title / OpenAPI schema. + + Single source of truth is `pyproject.toml` `[project].version`; we read + it back via `importlib.metadata` at runtime so `/docs`, `/openapi.json`, + `/api/version`, `/cli/latest`, and `da --version` can never drift. + """ + try: + return _pkg_version("agnes-the-ai-analyst") + except PackageNotFoundError: + return "dev" + from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse from fastapi.staticfiles import StaticFiles from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.middleware.gzip import GZipMiddleware from starlette.middleware.sessions import SessionMiddleware +from starlette.types import ASGIApp, Receive, Scope, Send + + +class _SelectiveGZipMiddleware: + """GZipMiddleware wrapper that skips a set of path prefixes. + + Parquet-serving endpoints send responses that are already columnar- + compressed (parquet's internal codec) and — for /api/data — can reach + hundreds of MB. Gzipping them on the way out costs CPU and latency with + no meaningful size reduction. Skip those paths; every other endpoint + (JSON manifests, HTML previews, install.sh) still gets compressed. + """ + + def __init__(self, app: ASGIApp, minimum_size: int = 1024, skip_prefixes: tuple[str, ...] = ()) -> None: + self._raw = app + self._gzip = GZipMiddleware(app, minimum_size=minimum_size) + self._skip_prefixes = skip_prefixes + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope.get("type") == "http": + path = scope.get("path", "") + if any(path.startswith(p) for p in self._skip_prefixes): + await self._raw(scope, receive, send) + return + await self._gzip(scope, receive, send) from app.auth.router import router as auth_router from app.api.health import router as health_router @@ -51,10 +92,20 @@ def create_app() -> FastAPI: app = FastAPI( title="AI Data Analyst", description="Data distribution platform for AI analytical systems", - version="2.0.0", + version=_app_version(), lifespan=lifespan, ) + # Compress JSON / HTML responses on the wire. Parquet downloads are + # excluded — they're already columnar-compressed and re-gzipping them + # just burns CPU with no size win. minimum_size=1024 keeps tiny + # responses uncompressed too (cheaper than the header overhead). + app.add_middleware( + _SelectiveGZipMiddleware, + minimum_size=1024, + skip_prefixes=("/api/data/", "/cli/wheel/", "/cli/download"), + ) + # Session middleware (required for OAuth state) from app.secrets import get_session_secret session_secret = get_session_secret() diff --git a/app/web/router.py b/app/web/router.py index 146d347..5d5859a 100644 --- a/app/web/router.py +++ b/app/web/router.py @@ -155,7 +155,13 @@ def _build_context(request: Request, user: Optional[dict] = None, **extra) -> di # Lines + server_url for the "Setup a new Claude Code" preview/clipboard # partial; single source of truth lives in app/web/setup_instructions.py. - from app.web.setup_instructions import SETUP_INSTRUCTIONS_LINES + # Resolve the wheel filename server-side so the URL in the setup snippet + # is a PEP 427-compliant path — `uv tool install` rejects bare `agnes.whl`. + from app.web.setup_instructions import resolve_lines + from app.api.cli_artifacts import _find_wheel + _wheel = _find_wheel() + _wheel_filename = _wheel.name if _wheel else "agnes.whl" + setup_instructions_lines = resolve_lines(_wheel_filename) ctx_server_url = str(request.base_url).rstrip("/") ctx = { @@ -168,7 +174,7 @@ def _build_context(request: Request, user: Optional[dict] = None, **extra) -> di "get_flashed_messages": lambda **kwargs: [], "url_for": lambda endpoint, **kw: _url_for_shim(endpoint, **kw), "session": _FlexDict({"user": user}) if user else _FlexDict(), - "setup_instructions_lines": SETUP_INSTRUCTIONS_LINES, + "setup_instructions_lines": setup_instructions_lines, "server_url": ctx_server_url, } # Flex all extra context values for template compatibility diff --git a/app/web/setup_instructions.py b/app/web/setup_instructions.py index bedb9e0..877f6ba 100644 --- a/app/web/setup_instructions.py +++ b/app/web/setup_instructions.py @@ -4,9 +4,11 @@ Both the JS-embedded clipboard renderer (`_claude_setup_instructions.jinja`) and the read-only HTML preview on the dashboard and /install pages consume these lines. Keep it in Python so there is exactly ONE place that edits. -Placeholders `{server_url}` and `{token}` are substituted at render time. -For the preview we substitute `{token}` with a user-visible placeholder -string styled distinctly in the HTML preview. +Placeholders `{server_url}`, `{token}`, and `{wheel_filename}` are substituted +at render time. `{wheel_filename}` is pre-substituted server-side via +`resolve_lines()` because `uv tool install` validates the PEP 427 filename +*in the URL path* before fetching, so a stable alias like `agnes.whl` fails +with "Must have a version" — we need the real versioned filename inlined. """ from __future__ import annotations @@ -21,7 +23,7 @@ SETUP_INSTRUCTIONS_LINES: list[str] = [ "Run these, in order. If any step fails, paste the exact error back and stop.", "", "1) Install the CLI:", - " uv tool install --force {server_url}/cli/agnes.whl", + " uv tool install --force {server_url}/cli/wheel/{wheel_filename}", "", " If uv is not installed yet:", " curl -LsSf https://astral.sh/uv/install.sh | sh", @@ -68,12 +70,27 @@ SETUP_INSTRUCTIONS_LINES: list[str] = [ ] -def render_setup_instructions(server_url: str, token: str) -> str: +def resolve_lines(wheel_filename: str) -> list[str]: + """Return the template lines with `{wheel_filename}` pre-substituted. + + Called by the web router before passing the lines to the Jinja partial + (both preview and JS modes). Keeps the client side from having to know + the wheel filename and keeps the two renderers byte-identical. + + Fallback: callers pass `"agnes.whl"` when no wheel is present on disk. + The resulting URL (`/cli/wheel/agnes.whl`) will 404 at download time, but + the instruction text still renders so operators can see the snippet shape + and diagnose the missing wheel on the server. + """ + return [line.replace("{wheel_filename}", wheel_filename) for line in SETUP_INSTRUCTIONS_LINES] + + +def render_setup_instructions(server_url: str, token: str, wheel_filename: str = "agnes.whl") -> str: """Render the setup instructions as a single string. Used server-side for tests and any non-JS rendering path. The browser clipboard flow uses the JS renderer embedded in the Jinja partial; both - must produce byte-identical output for a given (server_url, token). + must produce byte-identical output for a given (server_url, token, wheel). """ - text = "\n".join(SETUP_INSTRUCTIONS_LINES) + text = "\n".join(resolve_lines(wheel_filename)) return text.replace("{server_url}", server_url).replace("{token}", token) diff --git a/cli/client.py b/cli/client.py index 37134e2..79bcef3 100644 --- a/cli/client.py +++ b/cli/client.py @@ -1,11 +1,20 @@ """HTTP client wrapper for CLI — handles auth, retries, streaming.""" +import os +import time +from pathlib import Path from typing import Optional import httpx from cli.config import get_server_url, get_token +# Retry policy for transient failures during stream downloads. Scoped to +# network issues and 5xx — 4xx (auth, 404, 400) is NOT retried. Tunable via +# env for tests; defaults sit in the "one flaky network blip" window. +_RETRY_ATTEMPTS = int(os.environ.get("DA_STREAM_RETRIES", "3")) +_RETRY_BACKOFFS_S = (0.3, 1.0, 3.0) # seconds before attempt 2, 3, 4 + def get_client(timeout: float = 30.0) -> httpx.Client: """Get an authenticated httpx client.""" @@ -40,16 +49,51 @@ def api_patch(path: str, **kwargs) -> httpx.Response: return client.patch(path, **kwargs) +def _is_transient(exc: Exception) -> bool: + """Worth retrying? Network blip or 5xx — yes. Auth / 4xx — no.""" + if isinstance(exc, (httpx.ConnectError, httpx.ReadError, httpx.WriteError, + httpx.RemoteProtocolError, httpx.TimeoutException)): + return True + if isinstance(exc, httpx.HTTPStatusError): + return 500 <= exc.response.status_code < 600 + return False + + def stream_download(path: str, target_path: str, progress_callback=None) -> int: - """Stream download a file from the API. Returns bytes written.""" - with get_client(timeout=300.0) as client: - with client.stream("GET", path) as response: - response.raise_for_status() - total = 0 - with open(target_path, "wb") as f: - for chunk in response.iter_bytes(chunk_size=65536): - f.write(chunk) - total += len(chunk) - if progress_callback: - progress_callback(len(chunk)) + """Stream a file to `target_path` atomically and with retries. + + Durability properties: + - Writes to `target_path + ".tmp"`, then `os.replace` on success. The + real target file never exists in a half-written state. + - Retries up to `_RETRY_ATTEMPTS` times on transient errors (network + blip, 5xx); 4xx (auth/404) is raised immediately. + - No hash check here — that's done in the sync command against the + manifest hash, because only the caller knows the expected value. + """ + tmp_path = Path(f"{target_path}.tmp") + last_exc: Optional[Exception] = None + for attempt in range(_RETRY_ATTEMPTS + 1): + try: + tmp_path.unlink(missing_ok=True) + with get_client(timeout=300.0) as client: + with client.stream("GET", path) as response: + response.raise_for_status() + total = 0 + with open(tmp_path, "wb") as f: + for chunk in response.iter_bytes(chunk_size=65536): + f.write(chunk) + total += len(chunk) + if progress_callback: + progress_callback(len(chunk)) + # os.replace is atomic on POSIX and Windows for same-filesystem moves. + os.replace(tmp_path, target_path) return total + except Exception as exc: + last_exc = exc + if attempt == _RETRY_ATTEMPTS or not _is_transient(exc): + break + time.sleep(_RETRY_BACKOFFS_S[min(attempt, len(_RETRY_BACKOFFS_S) - 1)]) + # Clean up any leftover tmp, then surface the last exception. + tmp_path.unlink(missing_ok=True) + assert last_exc is not None + raise last_exc diff --git a/cli/commands/sync.py b/cli/commands/sync.py index b4da0f1..4106173 100644 --- a/cli/commands/sync.py +++ b/cli/commands/sync.py @@ -1,11 +1,19 @@ """Sync commands — da sync.""" +import hashlib import json import os from pathlib import Path import typer -from rich.progress import Progress, SpinnerColumn, TextColumn +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, +) from cli.client import api_get, api_post, stream_download from cli.config import get_sync_state, save_sync_state @@ -23,14 +31,25 @@ def sync( upload_only: bool = typer.Option(False, "--upload-only", help="Only upload sessions/artifacts"), docs_only: bool = typer.Option(False, "--docs-only", help="Only sync documentation"), as_json: bool = typer.Option(False, "--json", help="Output as JSON"), + dry_run: bool = typer.Option( + False, + "--dry-run", + help="Show what would be synced without downloading, uploading, or writing local state.", + ), ): """Sync data between server and local machine.""" if upload_only: - _upload(as_json) + _upload(as_json, dry_run=dry_run) return - with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}")) as progress: - # 1. Get manifest + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + ) as progress: + # 1. Get manifest — indeterminate spinner (total unknown until manifest lands) task = progress.add_task("Fetching manifest...", total=None) try: resp = api_get("/api/sync/manifest") @@ -57,35 +76,66 @@ def sync( if server_hash != local_hash or tid not in local_tables or not server_hash: to_download.append(tid) - progress.update(task, description=f"Found {len(to_download)} tables to sync") + # Switch the bar from indeterminate to "X/N" progress once we know the total. + progress.update( + task, + description=f"Found {len(to_download)} tables to sync", + total=len(to_download) or None, + completed=0, + ) - # 3. Download parquets + # 3. Dry-run short-circuit — report what would happen, touch nothing on disk. + if dry_run: + progress.update(task, description="Dry run — nothing will be downloaded") + _print_dry_run_plan(to_download, server_tables, len(server_tables), as_json) + return + + # 4. Download parquets local_dir = _local_data_dir() parquet_dir = local_dir / "server" / "parquet" parquet_dir.mkdir(parents=True, exist_ok=True) results = {"downloaded": [], "skipped": [], "errors": []} - for tid in to_download: - progress.update(task, description=f"Downloading {tid}...") + total = len(to_download) + for idx, tid in enumerate(to_download, start=1): + progress.update(task, description=f"[{idx}/{total}] Downloading {tid}...") target = parquet_dir / f"{tid}.parquet" + expected_hash = server_tables[tid].get("hash", "") try: stream_download(f"/api/data/{tid}/download", str(target)) + # Integrity check against the manifest hash (server uses MD5 + # over the parquet — see app/api/sync.py:_file_hash). A + # structural PAR1 check is kept as a fallback for when the + # manifest hash is empty (legacy snapshots). + if expected_hash: + actual_hash = _md5_file(target) + if actual_hash != expected_hash: + target.unlink(missing_ok=True) + raise ValueError( + f"hash mismatch: expected {expected_hash[:12]}…, got {actual_hash[:12]}…" + ) + elif not _is_valid_parquet(target): + target.unlink(missing_ok=True) + raise ValueError( + "downloaded file is not a valid parquet (missing PAR1 magic bytes)" + ) local_tables[tid] = { - "hash": server_tables[tid].get("hash", ""), + "hash": expected_hash, "rows": server_tables[tid].get("rows", 0), "size_bytes": server_tables[tid].get("size_bytes", 0), } results["downloaded"].append(tid) except Exception as e: results["errors"].append({"table": tid, "error": str(e)}) + progress.advance(task, 1) - # 4. Save local state + # 5. Save local state from datetime import datetime, timezone local_state["tables"] = local_tables local_state["last_sync"] = datetime.now(timezone.utc).isoformat() save_sync_state(local_state) - # 5. Rebuild DuckDB views + # 6. Rebuild DuckDB views if results["downloaded"]: progress.update(task, description="Rebuilding DuckDB views...") _rebuild_duckdb_views(local_dir, parquet_dir) @@ -105,6 +155,66 @@ def sync( typer.echo(f" {err['table']}: {err['error']}") +def _print_dry_run_plan( + to_download: list[str], + server_tables: dict, + total_tables: int, + as_json: bool, +) -> None: + """Render the dry-run plan for the download flow (no disk writes). + + Pairs table IDs with their manifest `size_bytes` / `rows` so the operator + can judge cost before committing to the real sync. + """ + total_bytes = sum(server_tables.get(tid, {}).get("size_bytes", 0) or 0 for tid in to_download) + plan = [ + { + "table": tid, + "rows": server_tables.get(tid, {}).get("rows", 0) or 0, + "size_bytes": server_tables.get(tid, {}).get("size_bytes", 0) or 0, + } + for tid in to_download + ] + if as_json: + typer.echo(json.dumps( + { + "dry_run": True, + "would_download": plan, + "summary": { + "tables_total": total_tables, + "tables_to_download": len(to_download), + "tables_skipped_unchanged": total_tables - len(to_download), + "bytes_total": total_bytes, + }, + }, + indent=2, + )) + return + + typer.echo(f"Dry run — would download {len(to_download)} tables ({_fmt_bytes(total_bytes)})") + typer.echo(f"Skipped (unchanged): {total_tables - len(to_download)}") + for row in plan: + typer.echo(f" {row['table']} rows={row['rows']} size={_fmt_bytes(row['size_bytes'])}") + + +def _fmt_bytes(n: int) -> str: + """Human-readable byte size. + + Every named unit must appear inside the loop so `n` gets divided one + more time than the label it's attached to. Otherwise the fallback + reports 1 unit-of-next-magnitude as "1024.0 ". + """ + if n < 1024: + return f"{n} B" + value = float(n) + for unit in ("KiB", "MiB", "GiB", "TiB", "PiB", "EiB"): + value /= 1024 + if value < 1024: + return f"{value:.1f} {unit}" + # Beyond EiB is astronomical — just keep dividing and label as EiB. + return f"{value:.1f} EiB" + + def _rebuild_duckdb_views(local_dir: Path, parquet_dir: Path): """Recreate DuckDB views from downloaded parquets. Preserve user tables.""" import duckdb @@ -132,24 +242,102 @@ def _rebuild_duckdb_views(local_dir: Path, parquet_dir: Path): except Exception: pass - # Create views for each parquet file + # Create views for each parquet file. One broken file (corrupt download, + # partial write left over from a previous sync, …) must not abort the + # whole rebuild — skip it with a warning and keep going. + skipped_broken: list[str] = [] for pq_file in parquet_dir.rglob("*.parquet"): view_name = pq_file.stem if view_name in existing_tables: continue # don't shadow user tables + if not _is_valid_parquet(pq_file): + skipped_broken.append(view_name) + continue abs_path = str(pq_file.resolve()) - conn.execute(f"CREATE VIEW \"{view_name}\" AS SELECT * FROM read_parquet('{abs_path}')") + try: + conn.execute(f"CREATE VIEW \"{view_name}\" AS SELECT * FROM read_parquet('{abs_path}')") + except duckdb.Error: + skipped_broken.append(view_name) conn.close() + if skipped_broken: + typer.echo( + f"Warning: skipped {len(skipped_broken)} broken parquet file(s) during view rebuild:", + err=True, + ) + for name in skipped_broken: + typer.echo(f" - {name}.parquet", err=True) -def _upload(as_json: bool): - """Upload sessions and CLAUDE.local.md to server.""" + +def _md5_file(path: Path) -> str: + """MD5 of a file, same chunking as app/api/sync.py:_file_hash so the + client-side verification matches the manifest hash byte-for-byte.""" + h = hashlib.md5() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + h.update(chunk) + return h.hexdigest() + + +def _is_valid_parquet(path: Path) -> bool: + """Cheap structural check — parquet files begin and end with `PAR1`. + + Used as a fallback when the manifest has no hash (legacy snapshots) and + during view rebuild to skip obviously-broken files. Does not guarantee + the footer is well-formed — that's DuckDB's job at CREATE VIEW time. + """ + try: + size = path.stat().st_size + if size < 8: + return False + with open(path, "rb") as f: + head = f.read(4) + f.seek(-4, 2) + tail = f.read(4) + return head == b"PAR1" and tail == b"PAR1" + except OSError: + return False + + +def _upload(as_json: bool, dry_run: bool = False): + """Upload sessions and CLAUDE.local.md to server. + + When `dry_run=True`, enumerate what would be uploaded without hitting the + API or mutating anything on disk. + """ local_dir = _local_data_dir() + sessions_dir = local_dir / "user" / "sessions" + local_md = local_dir / ".claude" / "CLAUDE.local.md" + + if dry_run: + session_files = sorted(str(f) for f in sessions_dir.glob("*.jsonl")) if sessions_dir.exists() else [] + plan = { + "dry_run": True, + "would_upload": { + "sessions": session_files, + "local_md": str(local_md) if local_md.exists() else None, + }, + "summary": { + "sessions_count": len(session_files), + "local_md_present": local_md.exists(), + }, + } + if as_json: + typer.echo(json.dumps(plan, indent=2)) + return + typer.echo(f"Dry run — would upload {len(session_files)} session file(s)") + for f in session_files: + typer.echo(f" {f}") + if local_md.exists(): + typer.echo(f"Would upload CLAUDE.local.md ({local_md})") + else: + typer.echo("No CLAUDE.local.md to upload") + return + results = {"sessions": 0, "local_md": False} # Upload sessions - sessions_dir = local_dir / "user" / "sessions" if sessions_dir.exists(): for f in sessions_dir.glob("*.jsonl"): try: @@ -161,7 +349,6 @@ def _upload(as_json: bool): pass # Upload CLAUDE.local.md - local_md = local_dir / ".claude" / "CLAUDE.local.md" if local_md.exists(): content = local_md.read_text(encoding="utf-8") try: diff --git a/cli/main.py b/cli/main.py index ffe0329..77b80b3 100644 --- a/cli/main.py +++ b/cli/main.py @@ -3,6 +3,9 @@ Primary interface for AI agents. Install: uv tool install data-analyst """ +from importlib.metadata import PackageNotFoundError +from importlib.metadata import version as _pkg_version + import typer from cli.commands.auth import auth_app @@ -18,12 +21,67 @@ from cli.commands.explore import explore_app from cli.commands.metrics import metrics_app from cli.commands.analyst import analyst_app + +def _cli_version() -> str: + """Return the installed CLI version from package metadata. + + Falls back to `"unknown"` when the package is not installed (e.g. running + from a source checkout without `uv pip install -e .`). Deliberately does + not read pyproject.toml at runtime — that file is not shipped with the + wheel and the metadata lookup is the canonical source. + """ + try: + return _pkg_version("agnes-the-ai-analyst") + except PackageNotFoundError: + return "unknown" + + +def _version_callback(value: bool) -> None: + if value: + typer.echo(f"da {_cli_version()}") + raise typer.Exit() + + app = typer.Typer( name="da", help="AI Data Analyst CLI — data sync, queries, and admin for AI agents", no_args_is_help=True, ) + +@app.callback() +def _root( + version: bool = typer.Option( + None, + "--version", + "-V", + callback=_version_callback, + is_eager=True, + help="Show the CLI version and exit.", + ), +) -> None: + """Root callback — carries the --version option and fires the auto-update check. + + Update check runs before subcommand dispatch but after the --version flag + (which exits early). It's best-effort: any failure is swallowed so a bad + network never blocks a working `da` command. Disable with + `DA_NO_UPDATE_CHECK=1`. + """ + _maybe_warn_outdated() + + +def _maybe_warn_outdated() -> None: + """Hit /cli/latest on the configured server (cached 24h) and emit a + one-line stderr warning if the installed CLI is older. Never raises.""" + try: + from cli.config import get_server_url + from cli.update_check import check, format_outdated_notice + info = check(get_server_url()) + if info and info.is_outdated(): + typer.echo(format_outdated_notice(info), err=True) + except Exception: + pass # best-effort: never fail a command on the probe + # Register subcommands app.add_typer(auth_app, name="auth") app.add_typer(sync_app, name="sync") diff --git a/cli/update_check.py b/cli/update_check.py new file mode 100644 index 0000000..cf4cb87 --- /dev/null +++ b/cli/update_check.py @@ -0,0 +1,190 @@ +"""Auto-check for a newer CLI version on the configured server. + +Runs in the root typer callback before subcommand dispatch. Failure is +silent — we never block a working `da` command on a best-effort version +probe. Result is cached in `$DA_CONFIG_DIR/update_check.json` for 24h so +we don't hammer the server on every invocation. + +Disable with `DA_NO_UPDATE_CHECK=1`. +""" + +from __future__ import annotations + +import json +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from cli.config import _config_dir + +_CACHE_FILENAME = "update_check.json" +_CACHE_TTL_SECONDS = 24 * 60 * 60 # 24h on a successful probe +_NEGATIVE_CACHE_TTL_SECONDS = 5 * 60 # 5min on a failed probe, to avoid +# re-probing 3s of silence (drop-packet networks: corporate firewall, VPN) +# on every `da` invocation. +_REQUEST_TIMEOUT_SECONDS = 3.0 # keep startup snappy + + +@dataclass(frozen=True) +class UpdateInfo: + installed: str + latest: Optional[str] + download_url: Optional[str] + + def is_outdated(self) -> bool: + if not self.latest or self.installed == "unknown": + return False + # Directional: only warn when installed < latest. `!=` would also + # fire when the CLI is *newer* than the server (e.g. after a server + # rollback) and prompt the user to downgrade. + return _version_lt(self.installed, self.latest) + + +def _version_lt(installed: str, latest: str) -> bool: + """Is `installed` strictly older than `latest`? + + Prefer packaging.version.Version (PEP 440, handles pre-release tags). + Fall back to a naive dotted-int tuple for the simple N.N.N case if + packaging is somehow unavailable. Unparseable strings return False — + we'd rather miss an upgrade hint than prompt a silent downgrade. + """ + try: + from packaging.version import InvalidVersion, Version + try: + return Version(installed) < Version(latest) + except InvalidVersion: + pass + except ImportError: + pass + try: + a = tuple(int(x) for x in installed.split(".")) + b = tuple(int(x) for x in latest.split(".")) + return a < b + except ValueError: + return False + + +def is_disabled() -> bool: + return os.environ.get("DA_NO_UPDATE_CHECK", "").lower() in ("1", "true", "yes") + + +def _installed_version() -> str: + from importlib.metadata import PackageNotFoundError + from importlib.metadata import version as _pkg_version + try: + return _pkg_version("agnes-the-ai-analyst") + except PackageNotFoundError: + return "unknown" + + +def _cache_path() -> Path: + return _config_dir() / _CACHE_FILENAME + + +def _read_cache() -> Optional[dict]: + p = _cache_path() + if not p.exists(): + return None + try: + return json.loads(p.read_text()) + except (OSError, json.JSONDecodeError): + return None + + +def _write_cache(entry: dict) -> None: + p = _cache_path() + try: + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text(json.dumps(entry)) + except OSError: + pass # best-effort — cache failure must not break the flow + + +def _fetch_latest(server_url: str) -> Optional[dict]: + """Hit /cli/latest with a short timeout. Returns None on any failure.""" + import httpx + try: + with httpx.Client(base_url=server_url, timeout=_REQUEST_TIMEOUT_SECONDS) as c: + resp = c.get("/cli/latest") + resp.raise_for_status() + return resp.json() + except Exception: + return None + + +def check(server_url: Optional[str]) -> Optional[UpdateInfo]: + """Return UpdateInfo if a check ran (cached or fresh), else None. + + Silent on every failure path: no server configured, CLI package not + installed, network down, malformed response, cache unreadable. + """ + if is_disabled() or not server_url: + return None + + installed = _installed_version() + if installed == "unknown": + return None # can't compare without a known local version + + cache = _read_cache() + now = time.time() + if ( + cache + and cache.get("installed") == installed + and cache.get("server_url") == server_url + and isinstance(cache.get("checked_at"), (int, float)) + ): + age = now - cache["checked_at"] + cached_latest = cache.get("latest") + # Positive cache — keep for 24h. Negative cache (failed probe, + # latest=None) — keep for 5min so we don't re-probe the 3s + # timeout on every command when the server is silently dropping. + ttl = _CACHE_TTL_SECONDS if cached_latest else _NEGATIVE_CACHE_TTL_SECONDS + if age < ttl: + if cached_latest is None: + return None + return UpdateInfo( + installed=installed, + latest=cached_latest, + download_url=cache.get("download_url"), + ) + + payload = _fetch_latest(server_url) + if not payload: + # Negative cache — avoid re-probing on every invocation. + _write_cache({ + "installed": installed, + "server_url": server_url, + "latest": None, + "download_url": None, + "checked_at": now, + }) + return None + + latest = payload.get("version") + dl = payload.get("download_url_path") + download_url = f"{server_url.rstrip('/')}{dl}" if dl else None + + _write_cache({ + "installed": installed, + "server_url": server_url, + "latest": latest, + "download_url": download_url, + "checked_at": now, + }) + return UpdateInfo(installed=installed, latest=latest, download_url=download_url) + + +def format_outdated_notice(info: UpdateInfo) -> str: + """One-line stderr warning when the CLI is out of date. + + `download_url` may be absent (stale cache entry written by an older client, + or server returned a version without a download path). Don't emit the + literal string "None" into a copy-pasteable command — drop the upgrade + snippet in that case. + """ + msg = f"[update] da {info.installed} is out of date — latest on this server is {info.latest}." + if info.download_url: + msg += f" Upgrade: uv tool install --force {info.download_url}" + return msg diff --git a/pyproject.toml b/pyproject.toml index 5127eed..e237627 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "agnes-the-ai-analyst" -version = "2.0.0" +version = "2.1.0" description = "Agnes — AI Data Analyst platform for AI analytical systems" requires-python = ">=3.11,<3.14" license = "MIT" diff --git a/tests/test_app_version.py b/tests/test_app_version.py new file mode 100644 index 0000000..cd869bd --- /dev/null +++ b/tests/test_app_version.py @@ -0,0 +1,36 @@ +"""Pin that the FastAPI `version=` is read dynamically from package metadata. + +The OpenAPI schema (`/openapi.json`, `/docs`) advertises this version. A +hardcoded literal — the previous state — silently drifts from +`pyproject.toml` on every bump, leaving `/openapi.json` reporting a stale +version while `/api/version`, `/cli/latest`, and `da --version` all +report the bumped one. +""" + +from unittest.mock import patch + + +def test_app_version_reads_package_metadata(): + """`_app_version()` must call importlib.metadata.version with the + canonical package name, not return a hardcoded literal.""" + with patch("app.main._pkg_version", return_value="9.9.9") as mock_pkg_ver: + from app.main import _app_version + assert _app_version() == "9.9.9" + mock_pkg_ver.assert_called_once_with("agnes-the-ai-analyst") + + +def test_app_version_falls_back_to_dev_when_package_missing(): + """Source-checkout without install → report 'dev', not crash.""" + from importlib.metadata import PackageNotFoundError + with patch("app.main._pkg_version", side_effect=PackageNotFoundError): + from app.main import _app_version + assert _app_version() == "dev" + + +def test_fastapi_app_version_matches_package_metadata(): + """End-to-end: what FastAPI stores in `app.version` is whatever + `_app_version()` returned — not a stale literal.""" + with patch("app.main._pkg_version", return_value="7.7.7"): + from app.main import create_app + app = create_app() + assert app.version == "7.7.7" diff --git a/tests/test_cli.py b/tests/test_cli.py index 99c9f0b..bd814e0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -59,6 +59,26 @@ class TestCLIHelp: assert result.exit_code == 0 +class TestCLIVersion: + def test_version_long_flag(self): + result = runner.invoke(app, ["--version"]) + assert result.exit_code == 0 + assert result.output.startswith("da ") + # Version string must be non-empty after the `da ` prefix. + assert result.output.strip() != "da" + + def test_version_short_flag(self): + result = runner.invoke(app, ["-V"]) + assert result.exit_code == 0 + assert result.output.startswith("da ") + + def test_version_exits_before_subcommand_resolution(self): + """Eager callback must run even when an unknown subcommand follows.""" + result = runner.invoke(app, ["--version", "bogus-subcommand"]) + assert result.exit_code == 0 + assert "da " in result.output + + class TestSkills: def test_list_skills(self): result = runner.invoke(app, ["skills", "list"]) diff --git a/tests/test_cli_artifacts.py b/tests/test_cli_artifacts.py index f82c4e1..7618966 100644 --- a/tests/test_cli_artifacts.py +++ b/tests/test_cli_artifacts.py @@ -43,10 +43,8 @@ def test_cli_download_serves_wheel_when_present(monkeypatch, tmp_path): assert resp.content.startswith(b"PK") -def test_cli_agnes_whl_alias_serves_same_bytes_as_download(monkeypatch, tmp_path): - """`/cli/agnes.whl` is a stable alias over `/cli/download` whose URL path - ends in `.whl`, which `uv tool install` requires to treat the resource as - a wheel. Both endpoints must serve identical bytes.""" +def test_cli_wheel_versioned_serves_current_wheel(monkeypatch, tmp_path): + """`/cli/wheel/{filename}` serves the current wheel and matches `/cli/download` bytes.""" wheel = tmp_path / "agnes_fake-1.0-py3-none-any.whl" wheel.write_bytes(b"PK\x03\x04fake-wheel-bytes-agnes") monkeypatch.setenv("AGNES_CLI_DIST_DIR", str(tmp_path)) @@ -54,23 +52,40 @@ def test_cli_agnes_whl_alias_serves_same_bytes_as_download(monkeypatch, tmp_path from app.main import app client = TestClient(app) - resp_alias = client.get("/cli/agnes.whl") - assert resp_alias.status_code == 200 - assert resp_alias.headers["content-type"] == "application/octet-stream" - assert resp_alias.content == wheel.read_bytes() + resp = client.get("/cli/wheel/agnes_fake-1.0-py3-none-any.whl") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "application/octet-stream" + assert resp.content == wheel.read_bytes() resp_download = client.get("/cli/download") assert resp_download.status_code == 200 - assert resp_alias.content == resp_download.content + assert resp.content == resp_download.content -def test_cli_agnes_whl_alias_404_when_no_wheel(monkeypatch, tmp_path): - """Alias returns 404 with a helpful message when no wheel is present.""" +def test_cli_wheel_versioned_rejects_other_filenames(monkeypatch, tmp_path): + """Arbitrary `wheel_name` values must 404 — no filesystem lookup from user input.""" + wheel = tmp_path / "agnes_fake-1.0-py3-none-any.whl" + wheel.write_bytes(b"PK\x03\x04") monkeypatch.setenv("AGNES_CLI_DIST_DIR", str(tmp_path)) from fastapi.testclient import TestClient from app.main import app client = TestClient(app) - resp = client.get("/cli/agnes.whl") + + resp_wrong = client.get("/cli/wheel/other-2.0-py3-none-any.whl") + assert resp_wrong.status_code == 404 + + +def test_cli_agnes_whl_alias_is_gone(monkeypatch, tmp_path): + """The bareword alias was removed — it never worked with `uv tool install` + (uv validates the filename before fetching) and only confused users. The + only CLI wheel URL is now `/cli/wheel/{filename}`.""" + wheel = tmp_path / "agnes_fake-1.0-py3-none-any.whl" + wheel.write_bytes(b"PK\x03\x04") + monkeypatch.setenv("AGNES_CLI_DIST_DIR", str(tmp_path)) + from fastapi.testclient import TestClient + from app.main import app + client = TestClient(app) + resp = client.get("/cli/agnes.whl", follow_redirects=False) assert resp.status_code == 404 diff --git a/tests/test_cli_sync.py b/tests/test_cli_sync.py index d14379f..8ec979b 100644 --- a/tests/test_cli_sync.py +++ b/tests/test_cli_sync.py @@ -1,5 +1,6 @@ """Tests for da sync command.""" +import hashlib import json import pytest from unittest.mock import patch, MagicMock, call @@ -27,19 +28,32 @@ def _resp(status_code=200, json_data=None): return r +# Hash of the fake parquet payload below — matches what sync.py would compute. +_FAKE_PARQUET_BYTES = b"PAR1" + b"\x00" * 32 + b"PAR1" +_FAKE_PARQUET_MD5 = hashlib.md5(_FAKE_PARQUET_BYTES).hexdigest() + MANIFEST = { "tables": { - "orders": {"hash": "abc123", "rows": 100, "size_bytes": 2048}, - "customers": {"hash": "def456", "rows": 50, "size_bytes": 1024}, + # Hashes match _FAKE_PARQUET_BYTES so happy-path tests pass the + # manifest-hash integrity check. + "orders": {"hash": _FAKE_PARQUET_MD5, "rows": 100, "size_bytes": 2048}, + "customers": {"hash": _FAKE_PARQUET_MD5, "rows": 50, "size_bytes": 1024}, } } +def _fake_stream_download(path, target, *args, **kwargs): + """Drop-in replacement for cli.commands.sync.stream_download that writes + the well-known fake parquet to the target path.""" + with open(target, "wb") as f: + f.write(_FAKE_PARQUET_BYTES) + + class TestSyncHappyPath: def test_sync_downloads_all_tables(self, tmp_config): """Sync with no local state downloads all tables.""" with patch("cli.commands.sync.api_get", return_value=_resp(200, MANIFEST)): - with patch("cli.commands.sync.stream_download") as mock_dl: + with patch("cli.commands.sync.stream_download", side_effect=_fake_stream_download) as mock_dl: with patch("cli.commands.sync._rebuild_duckdb_views"): result = runner.invoke(app, ["sync"]) assert result.exit_code == 0 @@ -49,7 +63,7 @@ class TestSyncHappyPath: def test_sync_specific_table(self, tmp_config): """--table flag limits download to one table.""" with patch("cli.commands.sync.api_get", return_value=_resp(200, MANIFEST)): - with patch("cli.commands.sync.stream_download") as mock_dl: + with patch("cli.commands.sync.stream_download", side_effect=_fake_stream_download) as mock_dl: with patch("cli.commands.sync._rebuild_duckdb_views"): result = runner.invoke(app, ["sync", "--table", "orders"]) assert result.exit_code == 0 @@ -60,7 +74,7 @@ class TestSyncHappyPath: def test_sync_json_output(self, tmp_config): """--json flag produces valid JSON output (rich spinner may precede JSON).""" with patch("cli.commands.sync.api_get", return_value=_resp(200, MANIFEST)): - with patch("cli.commands.sync.stream_download"): + with patch("cli.commands.sync.stream_download", side_effect=_fake_stream_download): with patch("cli.commands.sync._rebuild_duckdb_views"): result = runner.invoke(app, ["sync", "--json"]) assert result.exit_code == 0 @@ -102,8 +116,8 @@ class TestSyncErrors: """Tables with matching hashes are not re-downloaded.""" state = { "tables": { - "orders": {"hash": "abc123"}, - "customers": {"hash": "def456"}, + "orders": {"hash": _FAKE_PARQUET_MD5}, + "customers": {"hash": _FAKE_PARQUET_MD5}, } } with patch("cli.commands.sync.get_sync_state", return_value=state): @@ -114,3 +128,235 @@ class TestSyncErrors: # Nothing to download — both hashes match assert mock_dl.call_count == 0 assert "Downloaded: 0" in result.output + + +class TestFmtBytes: + """_fmt_bytes must label magnitudes correctly — the fallback unit has + to match the final loop exit, not be a fixed label.""" + + def test_small_and_medium_sizes(self): + from cli.commands.sync import _fmt_bytes + assert _fmt_bytes(0) == "0 B" + assert _fmt_bytes(512) == "512 B" + assert _fmt_bytes(2048) == "2.0 KiB" + assert _fmt_bytes(2 * 1024**2) == "2.0 MiB" + assert _fmt_bytes(5 * 1024**3) == "5.0 GiB" + assert _fmt_bytes(3 * 1024**4) == "3.0 TiB" + + def test_pib_and_eib_are_labelled_correctly(self): + """Off-by-unit regression: 1 PiB must render as '1.0 PiB', not '1024.0 PiB'.""" + from cli.commands.sync import _fmt_bytes + assert _fmt_bytes(1024**5) == "1.0 PiB" + assert _fmt_bytes(2 * 1024**5) == "2.0 PiB" + # Fallback unit at the very top. + assert _fmt_bytes(1024**6) == "1.0 EiB" + + +class TestSyncDurability: + """Durability & integrity layer: hash check, PAR1 fallback, broken-rebuild recovery.""" + + def _write(self, tmp_config, tid: str, body: bytes) -> None: + (tmp_config / "local" / "server" / "parquet").mkdir(parents=True, exist_ok=True) + (tmp_config / "local" / "server" / "parquet" / f"{tid}.parquet").write_bytes(body) + + def test_hash_mismatch_recorded_as_error(self, tmp_config): + """If manifest hash is present and does not match the downloaded bytes, + the file must be discarded and the error recorded.""" + def bad_stream(path, target, *a, **kw): + with open(target, "wb") as f: + f.write(b"PAR1" + b"\xaa" * 50 + b"PAR1") # valid PAR1, wrong hash + + with patch("cli.commands.sync.api_get", return_value=_resp(200, MANIFEST)): + with patch("cli.commands.sync.stream_download", side_effect=bad_stream): + with patch("cli.commands.sync._rebuild_duckdb_views") as mock_rebuild: + result = runner.invoke(app, ["sync"]) + assert result.exit_code == 0 + assert "Downloaded: 0" in result.output + assert "Errors: 2" in result.output + assert "hash mismatch" in result.output + assert mock_rebuild.call_count == 0 + + def test_par1_fallback_when_manifest_hash_missing(self, tmp_config): + """Legacy manifests without `hash` must fall back to the PAR1 structural check.""" + manifest_no_hash = {"tables": {"orders": {"hash": "", "rows": 10, "size_bytes": 16}}} + + def html_stream(path, target, *a, **kw): + with open(target, "wb") as f: + f.write(b"oops") + + with patch("cli.commands.sync.api_get", return_value=_resp(200, manifest_no_hash)): + with patch("cli.commands.sync.stream_download", side_effect=html_stream): + with patch("cli.commands.sync._rebuild_duckdb_views"): + result = runner.invoke(app, ["sync"]) + assert "PAR1" in result.output # fallback message appears + assert "Downloaded: 0" in result.output + + def test_rebuild_skips_broken_parquet_without_aborting(self, tmp_config): + """Pre-existing broken parquet must not kill the whole rebuild.""" + self._write(tmp_config, "broken", b"not-parquet-at-all") + self._write(tmp_config, "also_bad", b"PAR1" + b"\x00" * 10 + b"PAR1") + + from cli.commands.sync import _rebuild_duckdb_views + local_dir = tmp_config / "local" + parquet_dir = local_dir / "server" / "parquet" + # Must not raise — both files are garbage but the function recovers. + _rebuild_duckdb_views(local_dir, parquet_dir) + + +class TestStreamDownloadAtomicAndRetry: + """stream_download: atomic tmp→rename, retries on transient errors, no retry on 4xx.""" + + def test_atomic_write_via_tmp_then_rename(self, tmp_path, monkeypatch): + """Target file must not exist before os.replace runs; writes go to .tmp first.""" + monkeypatch.setenv("DA_CONFIG_DIR", str(tmp_path)) + monkeypatch.setenv("DA_SERVER_URL", "http://localhost:9999") + + target = tmp_path / "x.parquet" + observed_paths: list[str] = [] + + class FakeStream: + def __init__(self, chunks): + self._chunks = chunks + def raise_for_status(self): pass + def iter_bytes(self, chunk_size=65536): + # Observe target path at the moment of writing. + observed_paths.append(str(target) + " exists=" + str(target.exists())) + yield from self._chunks + def __enter__(self): return self + def __exit__(self, *a): pass + + class FakeClient: + def __init__(self, *a, **kw): pass + def stream(self, method, path): return FakeStream([b"PAR1", b"\x00" * 10, b"PAR1"]) + def __enter__(self): return self + def __exit__(self, *a): pass + + import cli.client as client_mod + monkeypatch.setattr(client_mod, "get_client", lambda timeout=30.0: FakeClient()) + client_mod.stream_download("/ignored", str(target)) + assert target.exists() + assert not (tmp_path / "x.parquet.tmp").exists() + # The target did NOT exist while iter_bytes was pumping — only the .tmp did. + assert all("exists=False" in p for p in observed_paths) + + def test_retries_on_transient_error(self, tmp_path, monkeypatch): + """Transient network errors (ConnectError) trigger retry; eventual success is transparent.""" + monkeypatch.setenv("DA_CONFIG_DIR", str(tmp_path)) + monkeypatch.setenv("DA_SERVER_URL", "http://localhost:9999") + monkeypatch.setenv("DA_STREAM_RETRIES", "3") + + target = tmp_path / "x.parquet" + calls = {"n": 0} + + import httpx + class FakeStream: + def raise_for_status(self): pass + def iter_bytes(self, chunk_size=65536): + yield b"PAR1" + b"\x00" * 4 + b"PAR1" + def __enter__(self): return self + def __exit__(self, *a): pass + + class FakeClient: + def stream(self, method, path): + calls["n"] += 1 + if calls["n"] < 3: + raise httpx.ConnectError("flap") + return FakeStream() + def __enter__(self): return self + def __exit__(self, *a): pass + + import cli.client as client_mod + monkeypatch.setattr(client_mod, "get_client", lambda timeout=30.0: FakeClient()) + # Speed up test — drop sleep to zero. + monkeypatch.setattr(client_mod, "_RETRY_BACKOFFS_S", (0.0, 0.0, 0.0)) + + client_mod.stream_download("/ignored", str(target)) + assert calls["n"] == 3 # 2 failures + 1 success + assert target.exists() + + def test_no_retry_on_4xx(self, tmp_path, monkeypatch): + """4xx (auth, 404) must surface immediately — retries are for transient issues only.""" + monkeypatch.setenv("DA_CONFIG_DIR", str(tmp_path)) + monkeypatch.setenv("DA_SERVER_URL", "http://localhost:9999") + + import httpx + calls = {"n": 0} + + class FakeResponse: + status_code = 404 + def raise_for_status(self): + raise httpx.HTTPStatusError( + "404", request=MagicMock(), response=MagicMock(status_code=404) + ) + def iter_bytes(self, chunk_size=65536): + return iter([]) + def __enter__(self): return self + def __exit__(self, *a): pass + + class FakeClient: + def stream(self, method, path): + calls["n"] += 1 + return FakeResponse() + def __enter__(self): return self + def __exit__(self, *a): pass + + import cli.client as client_mod + monkeypatch.setattr(client_mod, "get_client", lambda timeout=30.0: FakeClient()) + monkeypatch.setattr(client_mod, "_RETRY_BACKOFFS_S", (0.0, 0.0, 0.0)) + + with pytest.raises(httpx.HTTPStatusError): + client_mod.stream_download("/ignored", str(tmp_path / "x.parquet")) + assert calls["n"] == 1 # no retry on 4xx + + +class TestSyncDryRun: + def test_dry_run_skips_download_and_state_writes(self, tmp_config): + """--dry-run must not call stream_download, save_sync_state, or _rebuild_duckdb_views.""" + with patch("cli.commands.sync.api_get", return_value=_resp(200, MANIFEST)): + with patch("cli.commands.sync.stream_download") as mock_dl: + with patch("cli.commands.sync.save_sync_state") as mock_save: + with patch("cli.commands.sync._rebuild_duckdb_views") as mock_rebuild: + result = runner.invoke(app, ["sync", "--dry-run"]) + assert result.exit_code == 0 + assert mock_dl.call_count == 0 + assert mock_save.call_count == 0 + assert mock_rebuild.call_count == 0 + assert "Dry run" in result.output + # Table ids from the MANIFEST fixture must show up in the plan. + assert "orders" in result.output + assert "customers" in result.output + + def test_dry_run_json_output_shape(self, tmp_config): + """--dry-run --json emits a parseable plan with dry_run=True and a summary.""" + with patch("cli.commands.sync.api_get", return_value=_resp(200, MANIFEST)): + with patch("cli.commands.sync.stream_download"): + result = runner.invoke(app, ["sync", "--dry-run", "--json"]) + assert result.exit_code == 0 + json_start = result.output.find("{") + assert json_start >= 0 + # Rich Progress may emit additional lines after the JSON block, so use + # raw_decode to stop at the object boundary. + data, _ = json.JSONDecoder().raw_decode(result.output[json_start:]) + assert data["dry_run"] is True + assert data["summary"]["tables_to_download"] == 2 + assert data["summary"]["bytes_total"] == 2048 + 1024 + tables = [row["table"] for row in data["would_download"]] + assert set(tables) == {"orders", "customers"} + + def test_dry_run_respects_table_filter(self, tmp_config): + """--dry-run --table X only lists that one table in the plan.""" + with patch("cli.commands.sync.api_get", return_value=_resp(200, MANIFEST)): + with patch("cli.commands.sync.stream_download") as mock_dl: + result = runner.invoke(app, ["sync", "--dry-run", "--table", "orders"]) + assert result.exit_code == 0 + assert mock_dl.call_count == 0 + assert "orders" in result.output + assert "customers" not in result.output + + def test_dry_run_upload_only_does_not_hit_api(self, tmp_config): + """--upload-only --dry-run must not call api_post.""" + with patch("cli.commands.sync.api_post") as mock_post: + result = runner.invoke(app, ["sync", "--upload-only", "--dry-run"]) + assert result.exit_code == 0 + assert mock_post.call_count == 0 + assert "Dry run" in result.output or "would upload" in result.output.lower() diff --git a/tests/test_cli_update_check.py b/tests/test_cli_update_check.py new file mode 100644 index 0000000..4bbfb78 --- /dev/null +++ b/tests/test_cli_update_check.py @@ -0,0 +1,250 @@ +"""Tests for the CLI auto-update check (cli/update_check.py).""" + +import json +from unittest.mock import patch + +import pytest +from typer.testing import CliRunner + +from cli.main import app + +runner = CliRunner() + + +@pytest.fixture(autouse=True) +def tmp_config(tmp_path, monkeypatch): + monkeypatch.setenv("DA_CONFIG_DIR", str(tmp_path)) + # Point CLI at a fake server so get_server_url() returns something stable. + monkeypatch.setenv("DA_SERVER", "http://server.test:8000") + yield tmp_path + + +def test_check_returns_none_when_disabled(tmp_config): + import os + os.environ["DA_NO_UPDATE_CHECK"] = "1" + try: + from cli import update_check + assert update_check.check("http://server.test:8000") is None + finally: + del os.environ["DA_NO_UPDATE_CHECK"] + + +def test_check_returns_none_when_server_url_missing(tmp_config): + from cli import update_check + assert update_check.check("") is None + assert update_check.check(None) is None # type: ignore[arg-type] + + +def test_check_returns_none_when_installed_version_unknown(tmp_config): + from cli import update_check + with patch("cli.update_check._installed_version", return_value="unknown"): + assert update_check.check("http://server.test:8000") is None + + +def test_check_fresh_fetch_and_cache_write(tmp_config): + from cli import update_check + + payload = { + "version": "2.1.0", + "wheel_filename": "agnes_the_ai_analyst-2.1.0-py3-none-any.whl", + "download_url_path": "/cli/wheel/agnes_the_ai_analyst-2.1.0-py3-none-any.whl", + } + with patch("cli.update_check._installed_version", return_value="2.0.0"): + with patch("cli.update_check._fetch_latest", return_value=payload): + info = update_check.check("http://server.test:8000") + + assert info is not None + assert info.installed == "2.0.0" + assert info.latest == "2.1.0" + assert info.download_url == ( + "http://server.test:8000/cli/wheel/agnes_the_ai_analyst-2.1.0-py3-none-any.whl" + ) + assert info.is_outdated() is True + + # Cache file was written and re-reading it returns the same latest. + cache = json.loads((tmp_config / "update_check.json").read_text()) + assert cache["installed"] == "2.0.0" + assert cache["latest"] == "2.1.0" + + +def test_check_uses_cache_within_ttl(tmp_config): + """Cached entry within 24h skips the network fetch.""" + from cli import update_check + + # Seed a fresh cache entry. + (tmp_config / "update_check.json").write_text(json.dumps({ + "installed": "2.0.0", + "server_url": "http://server.test:8000", + "latest": "2.0.5", + "download_url": "http://server.test:8000/cli/wheel/agnes_the_ai_analyst-2.0.5-py3-none-any.whl", + "checked_at": __import__("time").time(), # now + })) + + with patch("cli.update_check._installed_version", return_value="2.0.0"): + with patch("cli.update_check._fetch_latest") as mock_fetch: + info = update_check.check("http://server.test:8000") + + assert mock_fetch.call_count == 0 # cache hit + assert info.latest == "2.0.5" + assert info.is_outdated() is True + + +def test_check_invalidates_cache_when_installed_version_changed(tmp_config): + """User ran a fresh install after the cache was written — re-probe.""" + from cli import update_check + + # Seed cache claiming the installed version was 1.9.0. + (tmp_config / "update_check.json").write_text(json.dumps({ + "installed": "1.9.0", + "server_url": "http://server.test:8000", + "latest": "2.0.0", + "download_url": "http://server.test:8000/cli/wheel/x.whl", + "checked_at": __import__("time").time(), + })) + + payload = {"version": "2.1.0", "download_url_path": "/cli/wheel/y.whl"} + with patch("cli.update_check._installed_version", return_value="2.0.0"): + with patch("cli.update_check._fetch_latest", return_value=payload) as mock_fetch: + info = update_check.check("http://server.test:8000") + + assert mock_fetch.call_count == 1 # cache was invalidated + assert info.latest == "2.1.0" + + +def test_check_handles_network_failure_silently(tmp_config): + """A probe that errors out returns None; no exception leaks.""" + from cli import update_check + with patch("cli.update_check._installed_version", return_value="2.0.0"): + with patch("cli.update_check._fetch_latest", return_value=None): + assert update_check.check("http://server.test:8000") is None + + +def test_negative_cache_avoids_reprobe_on_repeated_failure(tmp_config): + """Two consecutive check() calls after a failed probe must fire the + network once — the second call hits the 5-minute negative cache.""" + from cli import update_check + + with patch("cli.update_check._installed_version", return_value="2.0.0"): + with patch("cli.update_check._fetch_latest", return_value=None) as mock_fetch: + assert update_check.check("http://server.test:8000") is None + # Second call within the negative-cache window. + assert update_check.check("http://server.test:8000") is None + + assert mock_fetch.call_count == 1 # no re-probe + + +def test_negative_cache_expires_after_ttl(tmp_config): + """After the negative TTL elapses, the probe fires again.""" + import time + import json as _json + + from cli import update_check + + # Seed a stale negative-cache entry (older than 5min). + stale_ts = time.time() - (update_check._NEGATIVE_CACHE_TTL_SECONDS + 60) + (tmp_config / "update_check.json").write_text(_json.dumps({ + "installed": "2.0.0", + "server_url": "http://server.test:8000", + "latest": None, + "download_url": None, + "checked_at": stale_ts, + })) + + payload = {"version": "2.1.0", "download_url_path": "/cli/wheel/x.whl"} + with patch("cli.update_check._installed_version", return_value="2.0.0"): + with patch("cli.update_check._fetch_latest", return_value=payload) as mock_fetch: + info = update_check.check("http://server.test:8000") + + assert mock_fetch.call_count == 1 # cache expired, refetch + assert info is not None + assert info.latest == "2.1.0" + + +def test_is_outdated_false_when_same_version(tmp_config): + from cli.update_check import UpdateInfo + info = UpdateInfo(installed="2.0.0", latest="2.0.0", download_url="…") + assert info.is_outdated() is False + + +def test_is_outdated_false_when_latest_unknown(tmp_config): + from cli.update_check import UpdateInfo + info = UpdateInfo(installed="2.0.0", latest=None, download_url=None) + assert info.is_outdated() is False + + +def test_is_outdated_true_when_installed_older(tmp_config): + from cli.update_check import UpdateInfo + info = UpdateInfo(installed="2.0.0", latest="2.1.0", download_url="…") + assert info.is_outdated() is True + + +def test_is_outdated_false_when_installed_newer_than_server(tmp_config): + """After a server rollback the CLI may be ahead — don't prompt a downgrade.""" + from cli.update_check import UpdateInfo + info = UpdateInfo(installed="2.1.0", latest="2.0.0", download_url="…") + assert info.is_outdated() is False + + +def test_is_outdated_uses_pep440_comparison(tmp_config): + """`10.0.0 > 2.1.0` — must not be tripped by lexicographic string compare.""" + from cli.update_check import UpdateInfo + newer_on_server = UpdateInfo(installed="2.1.0", latest="10.0.0", download_url="…") + older_on_server = UpdateInfo(installed="10.0.0", latest="2.1.0", download_url="…") + assert newer_on_server.is_outdated() is True + assert older_on_server.is_outdated() is False + + +def test_is_outdated_false_for_unparseable_strings(tmp_config): + """Unparseable versions default to False — we'd rather miss an upgrade + hint than suggest a bogus downgrade.""" + from cli.update_check import UpdateInfo + info = UpdateInfo(installed="nightly-abc", latest="nightly-def", download_url="…") + assert info.is_outdated() is False + + +def test_format_outdated_notice_drops_upgrade_line_when_no_download_url(tmp_config): + """`download_url=None` must NOT produce literal "None" in the copy-pasteable command.""" + from cli.update_check import UpdateInfo, format_outdated_notice + info = UpdateInfo(installed="2.0.0", latest="2.1.0", download_url=None) + msg = format_outdated_notice(info) + assert "None" not in msg + assert "uv tool install" not in msg + assert "2.0.0" in msg and "2.1.0" in msg + + +def test_format_outdated_notice_includes_upgrade_command_when_url_present(tmp_config): + from cli.update_check import UpdateInfo, format_outdated_notice + info = UpdateInfo( + installed="2.0.0", + latest="2.1.0", + download_url="http://s/cli/wheel/a-2.1.0-py3-none-any.whl", + ) + msg = format_outdated_notice(info) + assert "uv tool install --force http://s/cli/wheel/a-2.1.0-py3-none-any.whl" in msg + + +class TestRootCallbackIntegration: + """The root callback must not crash a command when the probe fails, and + must emit a stderr warning when the server advertises a newer version.""" + + def test_probe_failure_does_not_break_command(self, tmp_config): + with patch("cli.update_check.check", side_effect=RuntimeError("boom")): + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + + def test_outdated_warning_is_emitted(self, tmp_config, capsys): + """Unit-test the warning hook directly: `--help` is eager and bypasses + the callback body, so we test `_maybe_warn_outdated` itself, which + is what every real subcommand dispatch triggers.""" + from cli.main import _maybe_warn_outdated + from cli.update_check import UpdateInfo + info = UpdateInfo( + installed="2.0.0", + latest="2.1.0", + download_url="http://server.test:8000/cli/wheel/x.whl", + ) + with patch("cli.update_check.check", return_value=info): + _maybe_warn_outdated() + captured = capsys.readouterr() + assert "[update]" in captured.err + assert "2.1.0" in captured.err diff --git a/tests/test_selective_gzip.py b/tests/test_selective_gzip.py new file mode 100644 index 0000000..7c17779 --- /dev/null +++ b/tests/test_selective_gzip.py @@ -0,0 +1,87 @@ +"""Tests for the SelectiveGZipMiddleware path-skip logic in app/main.py. + +Key property: parquet-serving endpoints must not be gzipped on the wire, +but JSON / HTML endpoints above the minimum-size threshold must be. +""" + +import pytest +from fastapi.testclient import TestClient + + +@pytest.fixture +def isolated_client(tmp_path, monkeypatch): + """Fresh FastAPI app with its own tmp DATA_DIR so DuckDB locks don't + collide with a concurrently-running dev container.""" + monkeypatch.setenv("DATA_DIR", str(tmp_path)) + monkeypatch.setenv("TESTING", "1") + monkeypatch.setenv("JWT_SECRET_KEY", "test-secret-key-min-32-characters!!") + (tmp_path / "state").mkdir() + (tmp_path / "analytics").mkdir() + (tmp_path / "extracts").mkdir() + from src.db import close_system_db + close_system_db() + from app.main import create_app + yield TestClient(create_app()) + close_system_db() + + +def test_parquet_path_is_not_gzipped(isolated_client, tmp_path, monkeypatch): + """/cli/wheel/... must return the raw bytes without Content-Encoding: gzip.""" + wheel = tmp_path / "agnes_fake-1.0-py3-none-any.whl" + wheel.write_bytes(b"PK\x03\x04" + b"x" * 4096) + monkeypatch.setenv("AGNES_CLI_DIST_DIR", str(tmp_path)) + + resp = isolated_client.get( + f"/cli/wheel/{wheel.name}", + headers={"Accept-Encoding": "gzip"}, + ) + assert resp.status_code == 200 + assert "gzip" not in resp.headers.get("content-encoding", "") + assert resp.content.startswith(b"PK") + + +def test_install_page_is_gzipped(isolated_client): + """/install is HTML above the threshold — gzip should kick in when the + client advertises gzip support. TestClient may decompress transparently, + so we accept either the header or readable body as proof that the + middleware decided to handle the response (i.e. did not skip).""" + resp = isolated_client.get("/install", headers={"Accept-Encoding": "gzip"}) + assert resp.status_code == 200 + enc = resp.headers.get("content-encoding", "") + # Either we see the encoding on the wire OR TestClient auto-decoded it. + assert "gzip" in enc or "install" in resp.text.lower() + + +def test_no_accept_encoding_means_no_gzip_anywhere(isolated_client): + """Client that doesn't advertise gzip gets uncompressed body.""" + resp = isolated_client.get("/install", headers={"Accept-Encoding": "identity"}) + assert resp.status_code == 200 + assert "gzip" not in resp.headers.get("content-encoding", "") + + +def test_selective_gzip_wrapper_dispatches_on_prefix(): + """Direct unit test of the wrapper's path-based branch without standing up + the whole FastAPI app — verifies the skip list is honoured.""" + from app.main import _SelectiveGZipMiddleware + + calls = {"raw": 0, "gzip": 0} + + async def raw_app(scope, receive, send): + calls["raw"] += 1 + + wrapper = _SelectiveGZipMiddleware(raw_app, minimum_size=10, skip_prefixes=("/api/data/",)) + # Monkey-patch the gzip inner so we can count hits without running middleware. + async def stub_gzip(scope, receive, send): + calls["gzip"] += 1 + wrapper._gzip = stub_gzip + + import asyncio + # Path that matches the skip prefix → raw app + asyncio.run(wrapper({"type": "http", "path": "/api/data/orders/download"}, None, None)) + assert calls == {"raw": 1, "gzip": 0} + # Path that does not → gzip app + asyncio.run(wrapper({"type": "http", "path": "/api/sync/manifest"}, None, None)) + assert calls == {"raw": 1, "gzip": 1} + # Non-http scope (websocket, lifespan) → gzip app (it handles lifespan as pass-through) + asyncio.run(wrapper({"type": "lifespan"}, None, None)) + assert calls == {"raw": 1, "gzip": 2} diff --git a/tests/test_setup_instructions.py b/tests/test_setup_instructions.py new file mode 100644 index 0000000..13df5f9 --- /dev/null +++ b/tests/test_setup_instructions.py @@ -0,0 +1,56 @@ +"""Tests for the setup-instructions template + resolver. + +`uv tool install` validates the PEP 427 filename in the URL path before +fetching, so our setup snippet cannot use a stable alias like `agnes.whl`. +These tests pin the wheel-filename substitution behavior. +""" + + +def test_resolve_lines_substitutes_wheel_filename(): + from app.web.setup_instructions import resolve_lines + + lines = resolve_lines("agnes_the_ai_analyst-2.0.0-py3-none-any.whl") + joined = "\n".join(lines) + assert "{wheel_filename}" not in joined + assert "/cli/wheel/agnes_the_ai_analyst-2.0.0-py3-none-any.whl" in joined + + +def test_resolve_lines_fallback_filename_is_honoured(): + """Callers pass `'agnes.whl'` when no wheel is on disk; substitution still works.""" + from app.web.setup_instructions import resolve_lines + + lines = resolve_lines("agnes.whl") + assert "{wheel_filename}" not in "\n".join(lines) + assert any("/cli/wheel/agnes.whl" in line for line in lines) + + +def test_render_setup_instructions_wires_all_placeholders(): + from app.web.setup_instructions import render_setup_instructions + + out = render_setup_instructions( + server_url="https://agnes.example.com", + token="T-123", + wheel_filename="agnes_the_ai_analyst-2.0.0-py3-none-any.whl", + ) + assert "{server_url}" not in out + assert "{token}" not in out + assert "{wheel_filename}" not in out + assert "https://agnes.example.com/cli/wheel/agnes_the_ai_analyst-2.0.0-py3-none-any.whl" in out + assert "T-123" in out + + +def test_install_page_uses_versioned_wheel_url(monkeypatch, tmp_path): + """End-to-end: the /install preview must render the PEP 427 wheel URL, + so a user copy-pasting the snippet gets a URL `uv tool install` accepts.""" + wheel = tmp_path / "agnes_the_ai_analyst-2.0.0-py3-none-any.whl" + wheel.write_bytes(b"PK\x03\x04") + monkeypatch.setenv("AGNES_CLI_DIST_DIR", str(tmp_path)) + + from fastapi.testclient import TestClient + from app.main import app + client = TestClient(app) + resp = client.get("/install", headers={"host": "agnes.test", "Accept": "text/html"}) + assert resp.status_code == 200 + assert "/cli/wheel/agnes_the_ai_analyst-2.0.0-py3-none-any.whl" in resp.text + # The bare alias must no longer appear in the rendered snippet. + assert "/cli/agnes.whl" not in resp.text diff --git a/tests/test_web_ui.py b/tests/test_web_ui.py index 7c58ca4..c8f7dba 100644 --- a/tests/test_web_ui.py +++ b/tests/test_web_ui.py @@ -160,8 +160,11 @@ class TestClaudeSetupPreview: assert "What Claude Code will receive" in body assert "<will be generated on click>" in body assert 'class="placeholder-token"' in body - # Setup payload text substituted with real server URL - assert "/cli/agnes.whl" in body + # Setup payload text substituted with real server URL. The wheel URL + # must be under /cli/wheel/ (uv tool install rejects a bare .whl alias + # because it validates the PEP 427 filename in the URL before fetch). + assert "/cli/wheel/" in body + assert "/cli/agnes.whl" not in body # New numbered headers + da diagnose step assert "1) Install the CLI" in body assert "4) Run diagnostics" in body