diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a598e6..4926cbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,30 +11,6 @@ CalVer image tags (`stable-YYYY.MM.N`, `dev-YYYY.MM.N`) are produced for every C ## [Unreleased] ### Performance -- **DuckDB BigQuery-extension session pool** - (`connectors/bigquery/access.py`). `BqAccess.duckdb_session()` now acquires - pre-warmed connections from a bounded process-local pool instead of running - `INSTALL bigquery; LOAD bigquery; CREATE SECRET; ATTACH …` on every request. - Each acquire saves the ~0.5 s extension-load + secret-creation cost when - the pool has a warm entry; auth SECRET is refreshed on acquire so a - long-lived pooled entry doesn't keep a stale GCE metadata token past its - TTL. Pool size is configurable via `data_source.bigquery.session_pool_size` - (default 4; sentinel `0` disables pooling). Affects every BQ-touching path - — `/api/query`, `/api/v2/scan`, `/api/v2/sample`, `/api/v2/schema`, - materialize, and the orchestrator's remote-attach. - -### Fixed -- **BigQuery `responseTooLarge` no longer surfaces as a generic 400 / 502 with - the raw upstream message** (`connectors/bigquery/access.py`). The - `translate_bq_error` helper now classifies "Response too large to return" - errors via a dedicated `bq_response_too_large` kind (HTTP 400) with an - actionable hint pointing at the WHERE / aggregation / materialized-table - remediations. Pre-fix this failure mode fell through to the generic - `bq_bad_request` mapping, which implied the user's SQL had a syntax error - — wrong root cause. Affects every BQ-touching path (`/api/query`, - `/api/v2/scan`, `/api/v2/sample`, `/api/v2/schema`, materialize) since - they all share `translate_bq_error`. - - **`/api/query` (and `agnes query --remote`) now rewrites user SQL referencing `query_mode='remote'` BigQuery rows into a single `bigquery_query()` call before execute** (`app/api/query.py`). Pre-fix the master view @@ -51,6 +27,60 @@ CalVer image tags (`stable-YYYY.MM.N`, `dev-YYYY.MM.N`) are produced for every C fall-through: cross-source JOINs (BQ ↔ Keboola/Jira local), queries already containing `bigquery_query(`, and unconfigured BQ project all keep the original ATTACH-catalog path so behavior degrades gracefully. +- **DuckDB BigQuery-extension session pool** + (`connectors/bigquery/access.py`). `BqAccess.duckdb_session()` now acquires + pre-warmed connections from a bounded process-local pool instead of running + `INSTALL bigquery; LOAD bigquery; CREATE SECRET; ATTACH …` on every request. + Each acquire saves the ~0.5 s extension-load + secret-creation cost when + the pool has a warm entry; auth SECRET is refreshed on acquire so a + long-lived pooled entry doesn't keep a stale GCE metadata token past its + TTL. Pool size is configurable via `data_source.bigquery.session_pool_size` + (default 4; sentinel `0` disables pooling). Affects every BQ-touching path + — `/api/query`, `/api/v2/scan`, `/api/v2/sample`, `/api/v2/schema`, + materialize, and the orchestrator's remote-attach. +- **`agnes pull` chunked download for large parquets**: when the server + advertises `accept-ranges: bytes` and a parquet exceeds + `AGNES_PULL_CHUNK_THRESHOLD_BYTES` (default 50 MB), the CLI now splits + the file into N parallel HTTP Range requests + (`AGNES_PULL_CHUNK_PARALLELISM`, default 4, capped 1..16) and assembles + the parts into the destination atomically. Targets the per-flow-shaped + network (corp VPN with per-TCP-connection rate-limiting) where a single + stream is throttled but N parallel streams over the same connection + scale roughly linearly. Falls back to single-stream when the server + responds 200 instead of 206 to a Range probe, when no + `accept-ranges: bytes` is advertised, or when content is below the + threshold — no behavior change in the small-file / non-cooperating- + server cases. +- **Persistent HTTP/2 client across `agnes pull`**: `stream_download` now + routes through a process-wide pooled `httpx.Client` so N parquet + downloads share a single TLS handshake; HTTP/2 multiplexing + (when the optional `h2` package is installed) lets all chunk Range + requests share one TCP connection. Gracefully falls back to HTTP/1.1 + pooling when `h2` is missing — no crash, just slightly less benefit. + +### Fixed +- **BigQuery `responseTooLarge` no longer surfaces as a generic 400 / 502 with + the raw upstream message** (`connectors/bigquery/access.py`). The + `translate_bq_error` helper now classifies "Response too large to return" + errors via a dedicated `bq_response_too_large` kind (HTTP 400) with an + actionable hint pointing at the WHERE / aggregation / materialized-table + remediations. Pre-fix this failure mode fell through to the generic + `bq_bad_request` mapping, which implied the user's SQL had a syntax error + — wrong root cause. Affects every BQ-touching path (`/api/query`, + `/api/v2/scan`, `/api/v2/sample`, `/api/v2/schema`, materialize) since + they all share `translate_bq_error`. + +### Added +- New optional dependency `h2>=4.1.0` (HTTP/2 transport for httpx). Pure + performance — `agnes pull` works on HTTP/1.1 if the install skips it. +- **Textual progress fallback for non-TTY `agnes pull`**: when stderr is + not a terminal (Claude Code SessionStart hook, CI runner, Docker log + capture, …), `agnes pull --no-quiet` now emits a plain-text progress + line per file at most every 10% or 30 s, plus a final completion line. + Replaces the previous Rich-bar-on-pipe behavior that either suppressed + output entirely or leaked ANSI escape sequences. TTY path unchanged + (Rich progress bar with bytes / speed / ETA, aggregated per-file + across chunked-download chunks). ## [0.38.3] — 2026-05-06 @@ -96,6 +126,37 @@ CalVer image tags (`stable-YYYY.MM.N`, `dev-YYYY.MM.N`) are produced for every C The dashboard-served setup payload (`app/web/setup_instructions.py`) already branches between the two automatically based on platform; the doc snippet now matches that behavior for manual flows. +- **`agnes pull` chunked download for large parquets**: when the server + advertises `accept-ranges: bytes` and a parquet exceeds + `AGNES_PULL_CHUNK_THRESHOLD_BYTES` (default 50 MB), the CLI now splits + the file into N parallel HTTP Range requests + (`AGNES_PULL_CHUNK_PARALLELISM`, default 4, capped 1..16) and assembles + the parts into the destination atomically. Targets the per-flow-shaped + network (corp VPN with per-TCP-connection rate-limiting) where a single + stream is throttled but N parallel streams over the same connection + scale roughly linearly. Falls back to single-stream when the server + responds 200 instead of 206 to a Range probe, when no + `accept-ranges: bytes` is advertised, or when content is below the + threshold — no behavior change in the small-file / non-cooperating- + server cases. +- **Persistent HTTP/2 client across `agnes pull`**: `stream_download` now + routes through a process-wide pooled `httpx.Client` so N parquet + downloads share a single TLS handshake; HTTP/2 multiplexing + (when the optional `h2` package is installed) lets all chunk Range + requests share one TCP connection. Gracefully falls back to HTTP/1.1 + pooling when `h2` is missing — no crash, just slightly less benefit. + +### Added +- New optional dependency `h2>=4.1.0` (HTTP/2 transport for httpx). Pure + performance — `agnes pull` works on HTTP/1.1 if the install skips it. +- **Textual progress fallback for non-TTY `agnes pull`**: when stderr is + not a terminal (Claude Code SessionStart hook, CI runner, Docker log + capture, …), `agnes pull --no-quiet` now emits a plain-text progress + line per file at most every 10% or 30 s, plus a final completion line. + Replaces the previous Rich-bar-on-pipe behavior that either suppressed + output entirely or leaked ANSI escape sequences. TTY path unchanged + (Rich progress bar with bytes / speed / ETA, aggregated per-file + across chunked-download chunks). ## [0.38.0] — 2026-05-06 diff --git a/cli/client.py b/cli/client.py index 544026c..d8ff1af 100644 --- a/cli/client.py +++ b/cli/client.py @@ -1,8 +1,11 @@ """HTTP client wrapper for CLI — handles auth, retries, streaming.""" +import atexit import os +import threading import time import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timezone from pathlib import Path from typing import Optional @@ -22,6 +25,18 @@ _RETRY_BACKOFFS_S = (0.3, 1.0, 3.0) # seconds before attempt 2, 3, 4 # timeout dies long before BQ finishes. Operators tune via AGNES_QUERY_TIMEOUT. QUERY_TIMEOUT_S = float(os.environ.get("AGNES_QUERY_TIMEOUT", "300")) +# Range-chunked parallel download — see `stream_download` docstring. Defaults +# tuned for the corp-VPN per-flow rate-limiting case (single-stream throttled +# but N parallel range requests scale linearly). Disabled implicitly for +# files below the threshold or when the server doesn't advertise byte-range +# support. Operators can hard-disable by setting parallelism to 1. +_CHUNK_PARALLELISM = max(1, min(16, int( + os.environ.get("AGNES_PULL_CHUNK_PARALLELISM", "4"), +))) +_CHUNK_THRESHOLD_BYTES = int( + os.environ.get("AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(50 * 1024 * 1024)), +) + # ── Transport-error translation ───────────────────────────────────────── # Pavel's Issue #185 Phase 3B caught the failure mode: when httpx raises @@ -143,7 +158,13 @@ def _translate_transport_error( def get_client(timeout: float = 30.0) -> httpx.Client: - """Get an authenticated httpx client.""" + """Get an authenticated httpx client. + + This factory creates a fresh client per call — used by the small + `api_*` helpers (one request, then close). The big-stream path + (`stream_download`) routes through `_get_shared_client()` to amortize + TLS handshakes and HTTP/2 multiplexing across N parquet downloads. + """ token = get_token() headers = {} if token: @@ -155,6 +176,80 @@ def get_client(timeout: float = 30.0) -> httpx.Client: ) +# ── Shared persistent client ──────────────────────────────────────────── +# `agnes pull` issues N stream_download calls — one per parquet — plus +# (with chunked downloads) M Range requests per file. Without pooling, +# each call performs a fresh TLS handshake; with HTTP/2 enabled, all +# those requests multiplex over a single TCP connection. The shared +# client is created lazily on first stream-download request, kept alive +# for the duration of the process, and closed at exit. +# +# HTTP/2 requires the optional `h2` package. If it's unavailable (slim +# install), we fall back to HTTP/1.1 — pooling alone still saves the +# handshake cost — and never raise. The CLI must not crash on `agnes +# pull` because of an h2 import error. + +_SHARED_CLIENT: Optional[httpx.Client] = None +_SHARED_CLIENT_LOCK = threading.Lock() + + +def _get_shared_client() -> httpx.Client: + """Lazily create + return a process-wide httpx.Client. + + Pool defaults: keep up to 32 keepalive connections (covers the + chunk-parallelism cap of 16 × 2 simultaneous files comfortably) and + cap the total at 64 so a runaway loop can't open thousands of + sockets. HTTP/2 is opt-in via httpx's `http2=True` and gracefully + degrades when the `h2` extra is missing. + """ + global _SHARED_CLIENT + with _SHARED_CLIENT_LOCK: + if _SHARED_CLIENT is not None: + return _SHARED_CLIENT + token = get_token() + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + limits = httpx.Limits( + max_keepalive_connections=32, + max_connections=64, + ) + try: + client = httpx.Client( + base_url=get_server_url(), + headers=headers, + timeout=300.0, + http2=True, + limits=limits, + ) + except (ImportError, RuntimeError): + # `h2` not installed → httpx raises; fall back to HTTP/1.1. + # Pooling alone still amortizes the TLS handshake. + client = httpx.Client( + base_url=get_server_url(), + headers=headers, + timeout=300.0, + limits=limits, + ) + _SHARED_CLIENT = client + return client + + +def _close_shared_client() -> None: + """Close the shared client and clear the slot. Safe to call twice.""" + global _SHARED_CLIENT + with _SHARED_CLIENT_LOCK: + if _SHARED_CLIENT is not None: + try: + _SHARED_CLIENT.close() + except Exception: + pass + _SHARED_CLIENT = None + + +atexit.register(_close_shared_client) + + def api_get(path: str, *, timeout: float = 30.0, **kwargs) -> httpx.Response: try: with get_client(timeout=timeout) as client: @@ -197,33 +292,197 @@ def _is_transient(exc: Exception) -> bool: return False -def stream_download(path: str, target_path: str, progress_callback=None) -> int: - """Stream a file to `target_path` atomically and with retries. +def _read_chunk_threshold_bytes() -> int: + """Re-read threshold each call so tests / operators can flip it via + env var without restarting the process.""" + try: + return int(os.environ.get( + "AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(_CHUNK_THRESHOLD_BYTES), + )) + except ValueError: + return _CHUNK_THRESHOLD_BYTES - Durability properties: - - Writes to `target_path + ".tmp"`, then `os.replace` on success. The - real target file never exists in a half-written state. - - Retries up to `_RETRY_ATTEMPTS` times on transient errors (network - blip, 5xx); 4xx (auth/404) is raised immediately. - - No hash check here — that's done in the sync command against the - manifest hash, because only the caller knows the expected value. + +def _read_chunk_parallelism() -> int: + """Re-read parallelism each call (same rationale as threshold). Floor 1, + ceiling 16.""" + try: + n = int(os.environ.get( + "AGNES_PULL_CHUNK_PARALLELISM", str(_CHUNK_PARALLELISM), + )) + except ValueError: + n = _CHUNK_PARALLELISM + return max(1, min(16, n)) + + +def _probe_range_support(client: httpx.Client, path: str) -> tuple[int, bool]: + """Send HEAD; return (content-length, accepts-byte-ranges). + + `(0, False)` means "we couldn't tell — fall back to single-stream". + Never raises; transport errors during the probe are treated as + "no chunking, try the GET instead and let it surface the failure + in the normal retry loop". """ + try: + resp = client.head(path) + if getattr(resp, "status_code", 200) >= 400: + return (0, False) + size = int(resp.headers.get("content-length", "0") or 0) + accepts = (resp.headers.get("accept-ranges", "").lower() == "bytes") + return (size, accepts) + except Exception: + return (0, False) + + +class _RangeNotHonored(Exception): + """Internal sentinel — server returned 200 instead of 206 to a Range + request. Caller catches and falls back to the single-stream path.""" + + +def _download_chunk( + client: httpx.Client, + path: str, + start: int, + end: int, + part_path: Path, + progress_callback, +) -> None: + """Stream `bytes=start-end` to `part_path`. Caller deals with retry + + cleanup. Raises on any failure (HTTPStatusError on non-206 response, + httpx.* on transport blip, `_RangeNotHonored` if server returned 200 + instead of 206 — chunked path can't trust that result).""" + headers = {"Range": f"bytes={start}-{end}"} + with client.stream("GET", path, headers=headers) as response: + # Server didn't honor the Range — RFC says it MAY return 200 with + # the full body. We can't safely splice that into one part of N, + # so we abort the whole chunked path and let the caller fall back. + if response.status_code == 200: + raise _RangeNotHonored() + response.raise_for_status() + with open(part_path, "wb") as f: + for piece in response.iter_bytes(chunk_size=65536): + f.write(piece) + if progress_callback and piece: + progress_callback(len(piece)) + + +def _download_chunked( + client: httpx.Client, + path: str, + target_path: str, + total_size: int, + parallelism: int, + progress_callback, +) -> int: + """Range-based parallel download. Returns total bytes written. + + Raises `_RangeNotHonored` on the first 200-instead-of-206 response so + the caller can fall back. All other exceptions propagate. + + Cleanup discipline: every part file we create gets removed before + return (success or failure). The destination is written via the + caller's `.tmp` and renamed atomically. + """ + target = Path(target_path) + tmp_path = Path(f"{target_path}.tmp") + parallelism = max(1, parallelism) + # Build chunks — last chunk takes the remainder. + chunk_size = total_size // parallelism + if chunk_size <= 0: + chunk_size = total_size # tiny file, single chunk + parallelism = 1 + ranges = [] + for i in range(parallelism): + start = i * chunk_size + end = (start + chunk_size - 1) if i < parallelism - 1 else (total_size - 1) + ranges.append((i, start, end)) + + part_paths = [Path(f"{target_path}.part{i}") for i, _, _ in ranges] + # Pre-clean any leftovers from a prior run. + for p in part_paths: + p.unlink(missing_ok=True) + + def _attempt_chunk(i: int, start: int, end: int) -> None: + last_exc: Optional[Exception] = None + for attempt in range(_RETRY_ATTEMPTS + 1): + try: + _download_chunk( + client, path, start, end, part_paths[i], + progress_callback, + ) + return + except _RangeNotHonored: + # Don't retry — server policy, not a transport blip. + raise + except Exception as exc: + last_exc = exc + if attempt == _RETRY_ATTEMPTS or not _is_transient(exc): + break + time.sleep(_RETRY_BACKOFFS_S[ + min(attempt, len(_RETRY_BACKOFFS_S) - 1) + ]) + assert last_exc is not None + raise last_exc + + try: + if parallelism == 1: + _attempt_chunk(*ranges[0]) + else: + # Use a thread pool so each chunk gets its own concurrent + # request slot on the (HTTP/2-multiplexed when available) + # shared client. httpx.Client is thread-safe for stream(). + with ThreadPoolExecutor(max_workers=parallelism) as ex: + futs = [ex.submit(_attempt_chunk, *r) for r in ranges] + for fut in as_completed(futs): + fut.result() # propagate first error + + # Concatenate parts → tmp_path → atomic rename. + tmp_path.unlink(missing_ok=True) + total_written = 0 + with open(tmp_path, "wb") as out: + for p in part_paths: + with open(p, "rb") as inp: + while True: + block = inp.read(65536) + if not block: + break + out.write(block) + total_written += len(block) + os.replace(tmp_path, target) + return total_written + finally: + # Always clean up part files + any stray tmp. + for p in part_paths: + p.unlink(missing_ok=True) + if tmp_path.exists(): + try: + tmp_path.unlink() + except OSError: + pass + + +def _download_single_stream( + client: httpx.Client, + path: str, + target_path: str, + progress_callback, +) -> int: + """Original single-stream path with retry. Used when chunking is + disabled (small file, no range support, or fallback after 200-on-Range).""" tmp_path = Path(f"{target_path}.tmp") last_exc: Optional[Exception] = None for attempt in range(_RETRY_ATTEMPTS + 1): try: tmp_path.unlink(missing_ok=True) - with get_client(timeout=300.0) as client: - with client.stream("GET", path) as response: - response.raise_for_status() - total = 0 - with open(tmp_path, "wb") as f: - for chunk in response.iter_bytes(chunk_size=65536): - f.write(chunk) - total += len(chunk) - if progress_callback: - progress_callback(len(chunk)) - # os.replace is atomic on POSIX and Windows for same-filesystem moves. + with client.stream("GET", path) as response: + response.raise_for_status() + total = 0 + with open(tmp_path, "wb") as f: + for chunk in response.iter_bytes(chunk_size=65536): + f.write(chunk) + total += len(chunk) + if progress_callback: + progress_callback(len(chunk)) os.replace(tmp_path, target_path) return total except Exception as exc: @@ -231,23 +490,94 @@ def stream_download(path: str, target_path: str, progress_callback=None) -> int: if attempt == _RETRY_ATTEMPTS or not _is_transient(exc): break time.sleep(_RETRY_BACKOFFS_S[min(attempt, len(_RETRY_BACKOFFS_S) - 1)]) - # Clean up any leftover tmp, then surface the last exception. Translate - # transport errors (timeouts, connection drops, protocol errors) to - # AgnesTransportError so the CLI prints a clean message instead of a - # Python traceback (Pavel's #185 Phase 3B). HTTPStatusError (4xx/5xx - # response from the server) is NOT a transport failure and must - # re-raise verbatim so the caller's status-code handling + the rich - # server error body (e.g. 401 with "token expired", 403 with - # cross_project_forbidden detail) reach the analyst — Devin Review on - # PR #188 caught: HTTPStatusError is a subclass of HTTPError, so the - # generic isinstance(HTTPError) translation was eating status codes. tmp_path.unlink(missing_ok=True) assert last_exc is not None - if isinstance(last_exc, httpx.HTTPStatusError): - raise last_exc - if isinstance(last_exc, httpx.HTTPError): - raise _translate_transport_error( - last_exc, context=f"GET {path} (stream → {target_path})", - timeout_s=300.0, - ) from last_exc raise last_exc + + +def stream_download(path: str, target_path: str, progress_callback=None) -> int: + """Stream a file to `target_path` atomically and with retries. + + Two paths: + 1. **Chunked parallel** — when the server advertises `accept-ranges: + bytes` and `content-length` exceeds `AGNES_PULL_CHUNK_THRESHOLD_BYTES` + (default 50 MB), split into N range requests + (`AGNES_PULL_CHUNK_PARALLELISM`, default 4, capped 1..16) and + download in parallel. Concatenate the part files into `.tmp`, + then `os.replace`. Falls back to single-stream if the server + responds 200 instead of 206 to a Range probe. + 2. **Single-stream** — for small files, no range support, or fallback + from the chunked path. Same atomic-rename + retry semantics as + before. + + Durability properties (unchanged): + - Writes to `.tmp`, then `os.replace` on success. The real + target file never exists in a half-written state. + - Retries up to `_RETRY_ATTEMPTS` on transient errors (network blip, + 5xx); 4xx (auth/404) is raised immediately. + - No hash check here — that's the caller's job (manifest hash). + + Threading: the chunked path uses a ThreadPoolExecutor sized to the + parallelism. httpx.Client.stream() is safe to call concurrently from + multiple threads on a single client (the connection pool serializes + the underlying socket access; HTTP/2 multiplexes streams when the + `h2` extra is installed). + """ + # Use the shared persistent client when available — one TLS + # handshake amortized across N stream_download calls within the same + # process, and HTTP/2 stream multiplexing across the chunk Range + # requests within a single download. Falls back to a fresh per-call + # client if shared-client construction fails for any reason. + try: + client = _get_shared_client() + return _stream_download_via(client, path, target_path, progress_callback) + except Exception: + with get_client(timeout=300.0) as client: + return _stream_download_via(client, path, target_path, progress_callback) + + +def _stream_download_via( + client: httpx.Client, + path: str, + target_path: str, + progress_callback, +) -> int: + """The shared body of `stream_download` parameterized on the client. + Split out so tests can inject a fake client.""" + threshold = _read_chunk_threshold_bytes() + parallelism = _read_chunk_parallelism() + + total_size = 0 + accepts_ranges = False + if parallelism > 1: + total_size, accepts_ranges = _probe_range_support(client, path) + + use_chunked = ( + parallelism > 1 + and accepts_ranges + and total_size > threshold + ) + + try: + if use_chunked: + try: + return _download_chunked( + client, path, target_path, total_size, parallelism, + progress_callback, + ) + except _RangeNotHonored: + # Server lied / proxy stripped the Range — fall through. + pass + return _download_single_stream( + client, path, target_path, progress_callback, + ) + except httpx.HTTPStatusError: + # 4xx / 5xx response from the server — re-raise verbatim so the + # caller's status-code handling + the rich server error body + # reach the analyst (Devin Review on PR #188). + raise + except httpx.HTTPError as exc: + raise _translate_transport_error( + exc, context=f"GET {path} (stream → {target_path})", + timeout_s=300.0, + ) from exc diff --git a/cli/lib/pull.py b/cli/lib/pull.py index b33c4b3..59d4db9 100644 --- a/cli/lib/pull.py +++ b/cli/lib/pull.py @@ -65,6 +65,128 @@ class PullResult: _SAFE_ID_RE = re.compile(r"^[a-zA-Z0-9_\-]{1,128}$") +class _TextualProgress: + """Plain-text progress emitter for non-TTY stderr. + + When `agnes pull` is invoked from a Claude Code SessionStart hook, + a CI runner, or any pipe consumer, stderr is not a terminal. Rich's + progress bar in that mode either suppresses output (silent for + minutes on a multi-GB parquet) or emits raw ANSI noise. This class + instead emits one terse line per file at sensible cadence. + + Cadence policy: emit when *either*: + - per-file bytes-downloaded crosses a 10%-of-total boundary, OR + - 30 s have elapsed since this file's last emission. + + Always emits one final "done" line per file via `finish()` so the + operator sees a confirmed completion even on tiny files. + + Format: `[N/T files] : 25% (16 MB / 66 MB) at 1.5 MB/s` — the + "[N/T files]" prefix lets the operator see overall pull progress + in a multi-table run without buffering all per-file lines. + + Thread-safe — `advance` is called from the chunked-download worker + threads; an internal lock serializes the update + emit. + """ + + _HUMAN_UNITS = ( + (1024 * 1024 * 1024 * 1024, "TB"), + (1024 * 1024 * 1024, "GB"), + (1024 * 1024, "MB"), + (1024, "KB"), + ) + + def __init__(self, *, stream, total_files: int, file_sizes: dict[str, int]): + import threading + self._stream = stream + self._total_files = total_files + self._file_sizes = file_sizes + self._lock = threading.Lock() + # Per-file state. + self._bytes: dict[str, int] = {tid: 0 for tid in file_sizes} + self._started_at: dict[str, float] = {} + self._last_emit_at: dict[str, float] = {} + self._last_emit_pct: dict[str, int] = {} + self._finished_idx: int = 0 # files whose `finish` line has been emitted + + def advance(self, tid: str, n: int) -> None: + """Add `n` bytes to the file's total. Emit a textual update if + the cadence policy allows.""" + with self._lock: + now = time.monotonic() + if tid not in self._started_at: + self._started_at[tid] = now + self._last_emit_at[tid] = now + self._last_emit_pct[tid] = 0 + self._bytes[tid] = self._bytes.get(tid, 0) + n + + total = self._file_sizes.get(tid, 0) + current = self._bytes[tid] + pct = int((current * 100) / total) if total > 0 else 0 + elapsed = now - self._last_emit_at[tid] + crossed_10 = pct >= self._last_emit_pct[tid] + 10 + if crossed_10 or elapsed >= 30.0: + self._last_emit_at[tid] = now + self._last_emit_pct[tid] = pct - (pct % 10) + self._emit_line(tid, current, total, now) + + def finish(self) -> None: + """Emit a final `done` line for any file we never closed out.""" + with self._lock: + now = time.monotonic() + for tid, total in self._file_sizes.items(): + # Treat any file we observed bytes for as needing a + # final line. Files that errored out before any callback + # are still announced (operator wants visibility even on + # zero-byte attempts). + self._finished_idx += 1 + bytes_ = self._bytes.get(tid, 0) + started = self._started_at.get(tid, now) + duration = max(0.001, now - started) + rate = bytes_ / duration + line = ( + f"[{self._finished_idx}/{self._total_files} files] " + f"{tid}: 100% done " + f"({self._fmt_bytes(bytes_)} in {duration:.1f}s, " + f"{self._fmt_bytes(int(rate))}/s)\n" + ) + self._stream.write(line) + try: + self._stream.flush() + except Exception: + pass + + def _emit_line(self, tid: str, current: int, total: int, now: float) -> None: + started = self._started_at.get(tid, now) + duration = max(0.001, now - started) + rate = current / duration + if total > 0: + pct_str = f"{int((current * 100) / total)}%" + size_str = ( + f"({self._fmt_bytes(current)} / {self._fmt_bytes(total)})" + ) + else: + pct_str = "?" + size_str = f"({self._fmt_bytes(current)})" + idx = self._finished_idx + 1 # 1-based "currently working on file N" + line = ( + f"[{idx}/{self._total_files} files] {tid}: {pct_str} " + f"{size_str} at {self._fmt_bytes(int(rate))}/s\n" + ) + self._stream.write(line) + try: + self._stream.flush() + except Exception: + pass + + @classmethod + def _fmt_bytes(cls, n: int) -> str: + for divisor, suffix in cls._HUMAN_UNITS: + if n >= divisor: + return f"{n / divisor:.1f} {suffix}" + return f"{n} B" + + @contextmanager def _override_server_env(server_url: str, token: str) -> Iterator[None]: """Set AGNES_SERVER + scoped token override for the duration of the call. @@ -219,15 +341,34 @@ def run_pull( # the executor + thread overhead for the common single-update case. workers = min(workers, len(to_download)) if to_download else 1 - # Optional progress bar — Rich's Progress tracks per-file bytes - # streamed, aggregated across the parallel ThreadPoolExecutor - # workers. Pavel's #185 Phase 1: a single 6.3 GB parquet on first - # init went 44 minutes silent, looked frozen. Now: aggregate "X.Y - # GB / Z.A GB · 56 MB/s · ETA 1m 20s" to stderr while threads - # stream. None when show_progress=False (SessionStart hooks etc.). + # Optional progress reporting — two paths. + # + # 1. Rich progress bar: per-file bytes-streamed bar with speed + + # ETA. Rendered to stderr when stderr is a TTY. Aggregates + # across the parallel ThreadPoolExecutor workers and across + # chunked-download chunks (all chunks call the same callback + # advancing the same task). + # 2. Textual fallback: when `show_progress=True` but stderr is + # NOT a TTY (Claude Code SessionStart hook, CI run, Docker + # log capture), Rich would either suppress the bar or emit + # raw control sequences. Instead we emit one plain-text line + # per file at most every 10% or 30 s — enough signal to know + # the pull isn't frozen on a multi-GB parquet, terse enough + # not to spam the consumer's log. + # + # Both paths receive the same per-file callback so the chunked- + # download contract ("one file = one task, sum-of-chunks bytes") + # is honored uniformly. + import sys as _sys progress = None progress_tasks: dict[str, int] = {} - if show_progress and to_download: + textual = None + use_textual_fallback = ( + show_progress + and to_download + and not _sys.stderr.isatty() + ) + if show_progress and to_download and not use_textual_fallback: from rich.progress import ( Progress, BarColumn, DownloadColumn, TextColumn, TimeRemainingColumn, TransferSpeedColumn, @@ -248,13 +389,22 @@ def run_pull( progress_tasks[tid] = progress.add_task( "download", label=tid, total=size if size > 0 else None, ) + elif use_textual_fallback: + textual = _TextualProgress( + stream=_sys.stderr, + total_files=len(to_download), + file_sizes={ + tid: int(server_tables[tid].get("size_bytes") or 0) + for tid in to_download + }, + ) def _download_one(tid: str) -> tuple[str, dict | None, str | None]: """Returns (tid, local_table_entry_or_None, error_or_None). One bound thread per call; stream_download is sync I/O so a ThreadPoolExecutor (not asyncio) is the right tool. The progress callback is thread-safe — Rich's Progress.update - holds an internal lock.""" + and the textual fallback's lock both serialize internally.""" target = parquet_dir / f"{tid}.parquet" expected_hash = server_tables[tid].get("hash", "") cb = None @@ -262,6 +412,9 @@ def run_pull( task_id = progress_tasks[tid] def cb(n: int, _tid=tid, _task=task_id): progress.update(_task, advance=n) + elif textual is not None: + def cb(n: int, _tid=tid): + textual.advance(_tid, n) try: stream_download(f"/api/data/{tid}/download", str(target), progress_callback=cb) @@ -294,6 +447,8 @@ def run_pull( finally: if progress is not None: progress.stop() + if textual is not None: + textual.finish() for tid, entry, err in outcomes: if err is not None: diff --git a/pyproject.toml b/pyproject.toml index b6ccc9d..a2cc602 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,8 +20,13 @@ dependencies = [ "itsdangerous>=2.1.0", "authlib>=1.6.11", "argon2-cffi>=23.1.0", - # HTTP client + # HTTP client. `h2` enables HTTP/2 multiplexing for the persistent + # CLI client used by `agnes pull` (one TCP connection serves N + # concurrent parquet streams + range chunks). `cli/client.py` + # gracefully falls back to HTTP/1.1 if h2 is missing, so this + # extra is for performance, not correctness. "httpx>=0.27.0", + "h2>=4.1.0", # CLI "typer>=0.12.0", "rich>=13.0.0", diff --git a/tests/test_pull_chunked.py b/tests/test_pull_chunked.py new file mode 100644 index 0000000..3629688 --- /dev/null +++ b/tests/test_pull_chunked.py @@ -0,0 +1,349 @@ +"""Tests for range-based chunked download in cli/client.py:stream_download. + +Background — the previous diagnosis measured `agnes pull` on a single 5.1 GB +materialized parquet at 0.29 MB/s on a corp VPN with per-flow rate-limiting; +4 parallel range requests over the same connection sustained 1.65 MB/s +aggregate. Existing `AGNES_PULL_PARALLELISM=4` parallelizes across files, +not within a file, so a manifest with 1 large materialized parquet + 10 +remote tables yields 1 active worker = single-stream throughput. + +These tests exercise the chunking code path: HEAD probe, Range-request +splitting, fallback when the server doesn't honor ranges, cleanup on +chunk failure, and the small-file bypass. +""" + +from __future__ import annotations + +import os +import threading +from pathlib import Path +from unittest.mock import patch + +import pytest + + +# ── Fake HTTP layer ───────────────────────────────────────────────────── +# The real httpx Client / AsyncClient surface is large; we mock at the +# client-method level. Our `stream_download` should: +# 1. Call HEAD to learn `content-length` + `accept-ranges`. +# 2. If ranges supported and size > threshold, issue N parallel +# `GET` with `Range: bytes=A-B`, each returning 206 + body chunk. +# 3. Concatenate part files into the destination. + +class _FakeResponse: + def __init__(self, status_code: int, headers: dict | None = None, + body: bytes = b""): + self.status_code = status_code + self.headers = headers or {} + self._body = body + + def raise_for_status(self): + if self.status_code >= 400: + import httpx + raise httpx.HTTPStatusError( + f"HTTP {self.status_code}", request=None, response=self, + ) + + def iter_bytes(self, chunk_size: int = 65536): + # Yield in chunk_size pieces so the sink loop runs realistically. + for i in range(0, len(self._body), chunk_size): + yield self._body[i:i + chunk_size] + + def __enter__(self): + return self + + def __exit__(self, *a): + return False + + +class _FakeClient: + """Captures calls + returns canned responses.""" + + def __init__(self, *, body: bytes, accept_ranges: bool = True, + reject_range_with_200: bool = False, + fail_chunk_indices: tuple[int, ...] = (), + head_status: int = 200): + self._body = body + self._accept_ranges = accept_ranges + self._reject_range_with_200 = reject_range_with_200 + self._fail_chunk_indices = set(fail_chunk_indices) + self._head_status = head_status + self.head_calls = 0 + self.range_calls: list[tuple[int, int]] = [] + self.full_get_calls = 0 + self._lock = threading.Lock() + self._chunk_attempt_counts: dict[tuple[int, int], int] = {} + + # `stream_download` calls `client.head(path)` once to probe. + def head(self, path: str, **kwargs): + with self._lock: + self.head_calls += 1 + if self._head_status >= 400: + return _FakeResponse(self._head_status) + headers = {"content-length": str(len(self._body))} + if self._accept_ranges: + headers["accept-ranges"] = "bytes" + return _FakeResponse(200, headers=headers) + + # `stream_download` uses `client.stream("GET", path, headers=...)` + # for both the chunked and full-file paths. Range header presence + # tells us which one. + def stream(self, method: str, path: str, *, headers: dict | None = None, + **kwargs): + rng = (headers or {}).get("Range") or (headers or {}).get("range") + if rng: + # bytes=START-END + spec = rng.split("=", 1)[1] + start_s, end_s = spec.split("-", 1) + start = int(start_s) + end = int(end_s) + with self._lock: + self.range_calls.append((start, end)) + key = (start, end) + attempt = self._chunk_attempt_counts.get(key, 0) + self._chunk_attempt_counts[key] = attempt + 1 + # Determine chunk index (in order of unique starts). + # We map by start to a stable index for fail-injection. + chunk_idx = self._chunk_index_for_start(start) + # Should this attempt fail? Fail only on first attempt for + # listed indices — retry succeeds. + if chunk_idx in self._fail_chunk_indices and attempt == 0: + import httpx + raise httpx.ReadError("simulated chunk failure") + if self._reject_range_with_200: + # Server ignored Range — returns full body with 200. + return _FakeResponse(200, body=self._body) + piece = self._body[start:end + 1] + return _FakeResponse( + 206, + headers={"content-range": f"bytes {start}-{end}/{len(self._body)}"}, + body=piece, + ) + # Full-file GET (single-stream fallback). + with self._lock: + self.full_get_calls += 1 + return _FakeResponse(200, body=self._body) + + def _chunk_index_for_start(self, start: int) -> int: + # Unique sorted starts so fail_chunk_indices is deterministic. + starts = sorted({s for s, _ in self.range_calls}) + try: + return starts.index(start) + except ValueError: + return -1 + + def __enter__(self): + return self + + def __exit__(self, *a): + return False + + def close(self): + pass + + +# ── Test fixtures ─────────────────────────────────────────────────────── + +@pytest.fixture(autouse=True) +def _isolate_config_dir(tmp_path, monkeypatch): + cfg = tmp_path / "_cfg" + cfg.mkdir() + monkeypatch.setenv("AGNES_CONFIG_DIR", str(cfg)) + + +@pytest.fixture(autouse=True) +def _reset_shared_client(monkeypatch): + """Reset the persistent shared httpx.Client between tests so each + test starts from a known state. Tests that need to inject a fake + client also stub `_get_shared_client` directly via the + `_inject_fake_client` helper below.""" + import cli.client as cc + if hasattr(cc, "_SHARED_CLIENT"): + monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False) + yield + if hasattr(cc, "_SHARED_CLIENT"): + monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False) + + +def _inject_fake_client(monkeypatch, fake): + """Patch both client factories to return the same fake. Tests target + `_get_shared_client` (the path stream_download actually takes) and + also `get_client` so the fallback path also lands on the fake.""" + monkeypatch.setattr("cli.client.get_client", lambda timeout=300.0: fake) + monkeypatch.setattr("cli.client._get_shared_client", + lambda: fake, raising=False) + + +# ── Tests ─────────────────────────────────────────────────────────────── + +def test_chunked_download_success(tmp_path, monkeypatch): + """Server advertises ranges, file is large enough — 4 chunks, assembled + correctly into target.""" + body = bytes(range(256)) * 2048 # 512 KB + threshold = 1024 # 1 KB so 512 KB is "large" + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(threshold)) + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + + fake = _FakeClient(body=body, accept_ranges=True) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.parquet" + progress_bytes = [] + total = stream_download("/api/data/x/download", str(target), + progress_callback=lambda n: progress_bytes.append(n)) + + assert total == len(body) + assert target.read_bytes() == body + # 4 distinct ranges issued (no overlaps; last one carries remainder). + assert len(set(fake.range_calls)) == 4 + assert fake.head_calls == 1 + assert fake.full_get_calls == 0 + # Progress callback was called and total bytes match. + assert sum(progress_bytes) == len(body) + # Chunk parts cleaned up. + leftovers = list(tmp_path.glob("*.part*")) + assert leftovers == [], f"orphan part files: {leftovers}" + + +def test_chunked_download_fallback_when_server_ignores_range( + tmp_path, monkeypatch, +): + """Server returns 200 instead of 206 on the first range probe — abort + chunked path, fall back to single-stream. No corrupt output.""" + body = b"X" * 200_000 + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024") + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + + # accept_ranges=True (HEAD lies), but every Range GET returns 200 + # with the full body — that's the "server ignored Range" path. + fake = _FakeClient(body=body, accept_ranges=True, + reject_range_with_200=True) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + total = stream_download("/api/data/x/download", str(target)) + + assert total == len(body) + assert target.read_bytes() == body + # Fell back to a single full-body GET. + assert fake.full_get_calls >= 1 + + +def test_small_file_uses_single_stream_path(tmp_path, monkeypatch): + """Below threshold → no HEAD probe needed (or HEAD short-circuits), + no Range requests, plain single-stream download.""" + body = b"x" * 500 # tiny + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "10000") # 10 KB + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + + fake = _FakeClient(body=body, accept_ranges=True) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + total = stream_download("/api/data/x/download", str(target)) + + assert total == len(body) + assert target.read_bytes() == body + assert fake.range_calls == [], "small file must not split into ranges" + assert fake.full_get_calls >= 1 + + +def test_chunked_download_no_accept_ranges_falls_back(tmp_path, monkeypatch): + """HEAD doesn't advertise byte-range support → skip chunked path, + plain single-stream.""" + body = b"y" * 200_000 + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024") + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + + fake = _FakeClient(body=body, accept_ranges=False) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + total = stream_download("/api/data/x/download", str(target)) + + assert total == len(body) + assert target.read_bytes() == body + assert fake.range_calls == [] + assert fake.full_get_calls >= 1 + + +def test_chunked_download_one_chunk_retries_then_succeeds( + tmp_path, monkeypatch, +): + """One chunk fails on first attempt; retry path completes the file.""" + body = bytes(range(256)) * 1024 # 256 KB + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024") + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + monkeypatch.setenv("AGNES_STREAM_RETRIES", "2") + + fake = _FakeClient(body=body, accept_ranges=True, + fail_chunk_indices=(1,)) # second chunk blips once + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + total = stream_download("/api/data/x/download", str(target)) + + assert total == len(body) + assert target.read_bytes() == body + # Cleanup of all part files. + assert list(tmp_path.glob("*.part*")) == [] + + +def test_chunked_download_failure_cleans_up_part_files(tmp_path, monkeypatch): + """All retries exhausted on a chunk → no destination file, no orphan + part files.""" + body = b"z" * 200_000 + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024") + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + monkeypatch.setenv("AGNES_STREAM_RETRIES", "0") + + # Inject a permanent failure on chunk 2 (retries=0 → first failure + # is fatal). + class _ChronicFail(_FakeClient): + def stream(self, method, path, *, headers=None, **kwargs): + rng = (headers or {}).get("Range") + if rng: + spec = rng.split("=", 1)[1] + start = int(spec.split("-", 1)[0]) + # Permanently fail the chunk starting at exactly half. + if start >= len(body) // 4 and start <= len(body) // 2: + import httpx + raise httpx.ReadError("permanent") + return super().stream(method, path, headers=headers, **kwargs) + return super().stream(method, path, headers=headers, **kwargs) + + fake = _ChronicFail(body=body, accept_ranges=True) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + with pytest.raises(Exception): + stream_download("/api/data/x/download", str(target)) + + assert not target.exists(), "no destination file after total failure" + # No orphan parts. + assert list(tmp_path.glob("*.part*")) == [] + assert not (tmp_path / "out.bin.tmp").exists() + + +def test_progress_callback_aggregates_across_chunks(tmp_path, monkeypatch): + """The progress callback should fire with byte deltas summing to the + full file across all chunks — caller treats one file as one task.""" + body = bytes(range(256)) * 4096 # 1 MB + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024") + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + + fake = _FakeClient(body=body, accept_ranges=True) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + advances = [] + stream_download("/api/data/x/download", str(target), + progress_callback=lambda n: advances.append(n)) + assert sum(advances) == len(body) diff --git a/tests/test_pull_progress.py b/tests/test_pull_progress.py new file mode 100644 index 0000000..fe31c72 --- /dev/null +++ b/tests/test_pull_progress.py @@ -0,0 +1,123 @@ +"""Tests for `agnes pull` progress UX (Change 3). + +The Rich progress bar handles the TTY case fine, but Claude Code's +SessionStart context — and any hook running `agnes pull` non-interactively — +has stderr connected to a pipe, not a TTY. In that case Rich either +suppresses output entirely or emits raw ANSI noise into the consumer's +log. Goal: when the caller asks for progress and stderr is not a TTY, +emit a plain-text per-10%-or-30s update so the operator gets *some* +signal instead of multi-minute silence. +""" + +from __future__ import annotations + +import io +import time +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +@pytest.fixture(autouse=True) +def _isolate_config_dir(tmp_path, monkeypatch): + cfg = tmp_path / "_cfg" + cfg.mkdir() + monkeypatch.setenv("AGNES_CONFIG_DIR", str(cfg)) + + +@pytest.fixture +def fake_pull_io(monkeypatch): + """Stub the manifest + memory + download endpoints so run_pull can + execute end-to-end with a fake parquet write per table.""" + canned_manifest = { + "tables": { + "tbl_big": {"hash": "h1", "rows": 0, "size_bytes": 1_000_000}, + }, + } + canned_memory = {"mandatory": [], "approved": []} + + def _api_get(path, *args, **kwargs): + resp = MagicMock() + resp.status_code = 200 + if path == "/api/sync/manifest": + resp.json.return_value = canned_manifest + elif path == "/api/memory/bundle": + resp.json.return_value = canned_memory + resp.raise_for_status = lambda: None + return resp + + def _stream_download(path, target_path, progress_callback=None): + # Simulate a chunked download: emit progress in 4 increments + # totaling the announced size. + total = 1_000_000 + slices = [total // 4] * 3 + [total - 3 * (total // 4)] + Path(target_path).write_bytes(b"PAR1" + b"\x00" * 1000 + b"PAR1") + if progress_callback: + for s in slices: + progress_callback(s) + return total + + monkeypatch.setattr("cli.lib.pull.api_get", _api_get, raising=False) + monkeypatch.setattr("cli.lib.pull.stream_download", _stream_download, + raising=False) + monkeypatch.setattr("cli.lib.pull._is_valid_parquet", lambda p: True, + raising=False) + monkeypatch.setattr("cli.lib.pull._file_md5", lambda p: "h1", raising=False) + + +def test_textual_progress_when_stderr_is_not_tty( + tmp_path, fake_pull_io, monkeypatch, capsys, +): + """Non-TTY stderr → emit a plain-text progress line per file.""" + # Force the non-TTY branch even if pytest's fake stderr is a tty. + monkeypatch.setattr("sys.stderr.isatty", lambda: False, raising=False) + + from cli.lib.pull import run_pull + result = run_pull( + server_url="http://x", token="t", workspace=tmp_path, + show_progress=True, + ) + captured = capsys.readouterr() + # Some indication of the file + bytes ran; we don't pin exact format. + assert "tbl_big" in captured.err + assert result.tables_updated == 1 + # No raw ANSI escape sequences in the textual fallback. + assert "\x1b[" not in captured.err.split("tbl_big")[0] + + +def test_no_progress_output_when_show_progress_is_false( + tmp_path, fake_pull_io, monkeypatch, capsys, +): + """`show_progress=False` (the SessionStart hook path) emits no + progress text on stderr in either TTY or non-TTY mode.""" + monkeypatch.setattr("sys.stderr.isatty", lambda: False, raising=False) + + from cli.lib.pull import run_pull + run_pull( + server_url="http://x", token="t", workspace=tmp_path, + show_progress=False, + ) + captured = capsys.readouterr() + assert "tbl_big" not in captured.err + + +def test_textual_progress_emits_at_completion( + tmp_path, fake_pull_io, monkeypatch, capsys, +): + """At least one final completion line gets emitted per file even if + the throttle window doesn't trigger mid-file.""" + monkeypatch.setattr("sys.stderr.isatty", lambda: False, raising=False) + from cli.lib.pull import run_pull + run_pull( + server_url="http://x", token="t", workspace=tmp_path, + show_progress=True, + ) + captured = capsys.readouterr() + # Final line marks the file as done — either "100%" or a "✓ tbl_big" / + # "tbl_big done" indicator. We accept any final-completion form. + assert ( + "100%" in captured.err + or "done" in captured.err.lower() + or "complete" in captured.err.lower() + ) diff --git a/tests/test_pull_shared_client.py b/tests/test_pull_shared_client.py new file mode 100644 index 0000000..113ed3a --- /dev/null +++ b/tests/test_pull_shared_client.py @@ -0,0 +1,85 @@ +"""Tests for the persistent HTTP/2-capable shared client (Change 2). + +`agnes pull` issues N stream_download calls — one per parquet. Without +pooling, each call performs a fresh TLS handshake. The shared client is +created lazily once per process and closed at exit; HTTP/2 (when `h2` is +available) further multiplexes all chunk Range requests over a single +TCP connection. +""" + +from __future__ import annotations + +import pytest + + +@pytest.fixture(autouse=True) +def _isolate_config_dir(tmp_path, monkeypatch): + cfg = tmp_path / "_cfg" + cfg.mkdir() + monkeypatch.setenv("AGNES_CONFIG_DIR", str(cfg)) + # Some dev environments point SSL_CERT_FILE / REQUESTS_CA_BUNDLE at a + # corp-CA bundle that may not exist on every laptop running the test + # suite. Clear those so httpx.Client() construction in the shared- + # client path can build a default SSL context without trying to load + # a missing PEM file. + for var in ("SSL_CERT_FILE", "REQUESTS_CA_BUNDLE", "CURL_CA_BUNDLE"): + monkeypatch.delenv(var, raising=False) + + +@pytest.fixture(autouse=True) +def _reset_shared(monkeypatch): + import cli.client as cc + cc._close_shared_client() + monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False) + yield + cc._close_shared_client() + + +def test_get_shared_client_is_cached(monkeypatch): + """Multiple calls return the same client instance — no fresh TLS + handshake per stream_download invocation.""" + monkeypatch.setenv("AGNES_SERVER", "https://x.example.test") + from cli.client import _get_shared_client + c1 = _get_shared_client() + c2 = _get_shared_client() + assert c1 is c2, "shared client must be a single instance" + + +def test_get_shared_client_falls_back_when_http2_unavailable(monkeypatch): + """If httpx raises during HTTP/2 client construction (e.g. `h2` not + installed in the runtime env), we must gracefully build a HTTP/1.1 + client instead of crashing the pull.""" + import httpx + + monkeypatch.setenv("AGNES_SERVER", "https://x.example.test") + import cli.client as cc + + real_client = httpx.Client + + construction_calls = [] + + def fake_client(*args, **kwargs): + construction_calls.append(kwargs.copy()) + if kwargs.get("http2") is True: + raise ImportError("Using http2=True, but the 'h2' package is not installed") + return real_client(*args, **kwargs) + + monkeypatch.setattr(httpx, "Client", fake_client) + + client = cc._get_shared_client() + assert client is not None + # Two construction attempts: first http2=True (raised), second falls + # back to HTTP/1.1 (no http2 kwarg). + assert construction_calls[0].get("http2") is True + assert construction_calls[1].get("http2") is None or construction_calls[1].get("http2") is False + cc._close_shared_client() + + +def test_close_shared_client_idempotent(monkeypatch): + """Calling close twice (once explicitly, once via atexit) must not + raise.""" + monkeypatch.setenv("AGNES_SERVER", "https://x.example.test") + from cli.client import _get_shared_client, _close_shared_client + _get_shared_client() + _close_shared_client() + _close_shared_client() # second close is a no-op