diff --git a/CHANGELOG.md b/CHANGELOG.md index 56fd2cc..4b5d694 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,80 @@ CalVer image tags (`stable-YYYY.MM.N`, `dev-YYYY.MM.N`) are produced for every C ## [Unreleased] +## [0.39.0] — 2026-05-06 + +### Performance +- **`/api/query` (and `agnes query --remote`) now rewrites user SQL referencing + `query_mode='remote'` BigQuery rows into a single `bigquery_query()` call + before execute** (`app/api/query.py`). Pre-fix the master view + (`CREATE VIEW AS SELECT * FROM bigquery..`) did + not push WHERE / SELECT / LIMIT into BQ — the DuckDB BQ extension opened a + Storage Read API session over the entire upstream table, scanning the full + partitioned dataset before the local DuckDB filter ran. On 100M+ row + remote-mode tables this was 50-100× slower than the equivalent direct + `bigquery_query()` call (70-150 s vs 1.5 s) and frequently failed with + `Response too large to return`. The rewriter (shared core with the existing + dry-run helper) wraps the user's whole SQL in `bigquery_query('', + '')` so the BQ planner receives the full query and applies + partition pruning + projection pushdown server-side. Conservative + fall-through: cross-source JOINs (BQ ↔ Keboola/Jira local), queries already + containing `bigquery_query(`, and unconfigured BQ project all keep the + original ATTACH-catalog path so behavior degrades gracefully. +- **DuckDB BigQuery-extension session pool** + (`connectors/bigquery/access.py`). `BqAccess.duckdb_session()` now acquires + pre-warmed connections from a bounded process-local pool instead of running + `INSTALL bigquery; LOAD bigquery; CREATE SECRET; ATTACH …` on every request. + Each acquire saves the ~0.5 s extension-load + secret-creation cost when + the pool has a warm entry; auth SECRET is refreshed on acquire so a + long-lived pooled entry doesn't keep a stale GCE metadata token past its + TTL. Pool size is configurable via `data_source.bigquery.session_pool_size` + (default 4; sentinel `0` disables pooling). Affects every BQ-touching path + — `/api/query`, `/api/v2/scan`, `/api/v2/sample`, `/api/v2/schema`, + materialize, and the orchestrator's remote-attach. +- **`agnes pull` chunked download for large parquets**: when the server + advertises `accept-ranges: bytes` and a parquet exceeds + `AGNES_PULL_CHUNK_THRESHOLD_BYTES` (default 50 MB), the CLI now splits + the file into N parallel HTTP Range requests + (`AGNES_PULL_CHUNK_PARALLELISM`, default 4, capped 1..16) and assembles + the parts into the destination atomically. Targets the per-flow-shaped + network (corp VPN with per-TCP-connection rate-limiting) where a single + stream is throttled but N parallel streams over the same connection + scale roughly linearly. Falls back to single-stream when the server + responds 200 instead of 206 to a Range probe, when no + `accept-ranges: bytes` is advertised, or when content is below the + threshold — no behavior change in the small-file / non-cooperating- + server cases. +- **Persistent HTTP/2 client across `agnes pull`**: `stream_download` now + routes through a process-wide pooled `httpx.Client` so N parquet + downloads share a single TLS handshake; HTTP/2 multiplexing + (when the optional `h2` package is installed) lets all chunk Range + requests share one TCP connection. Gracefully falls back to HTTP/1.1 + pooling when `h2` is missing — no crash, just slightly less benefit. + +### Fixed +- **BigQuery `responseTooLarge` no longer surfaces as a generic 400 / 502 with + the raw upstream message** (`connectors/bigquery/access.py`). The + `translate_bq_error` helper now classifies "Response too large to return" + errors via a dedicated `bq_response_too_large` kind (HTTP 400) with an + actionable hint pointing at the WHERE / aggregation / materialized-table + remediations. Pre-fix this failure mode fell through to the generic + `bq_bad_request` mapping, which implied the user's SQL had a syntax error + — wrong root cause. Affects every BQ-touching path (`/api/query`, + `/api/v2/scan`, `/api/v2/sample`, `/api/v2/schema`, materialize) since + they all share `translate_bq_error`. + +### Added +- New optional dependency `h2>=4.1.0` (HTTP/2 transport for httpx). Pure + performance — `agnes pull` works on HTTP/1.1 if the install skips it. +- **Textual progress fallback for non-TTY `agnes pull`**: when stderr is + not a terminal (Claude Code SessionStart hook, CI runner, Docker log + capture, …), `agnes pull --no-quiet` now emits a plain-text progress + line per file at most every 10% or 30 s, plus a final completion line. + Replaces the previous Rich-bar-on-pipe behavior that either suppressed + output entirely or leaked ANSI escape sequences. TTY path unchanged + (Rich progress bar with bytes / speed / ETA, aggregated per-file + across chunked-download chunks). + ## [0.38.3] — 2026-05-06 ### Changed diff --git a/app/api/query.py b/app/api/query.py index a758dc4..6dbbadd 100644 --- a/app/api/query.py +++ b/app/api/query.py @@ -31,6 +31,50 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/query", tags=["query"]) + +# Heuristic: did the BQ-side execution of a `bigquery_query()`-rewritten +# query reject the inner SQL because of a **DuckDB-vs-BQ dialect mismatch** +# specifically? We want to fall back ONLY on cases where the same SQL +# would have worked under the legacy DuckDB ATTACH-catalog path — +# DuckDB-only syntax (``::INT`` casts, ``STRPTIME``, COALESCE arity quirks) +# that BQ's parser rejects. +# +# We DO NOT want to fall back on user-data errors that BQ would reject in +# either path (unknown column name, wrong function signature, invalid cast +# of literal user input). For those, the legacy ATTACH path would issue +# the same query and fail the same way — just 50-100× slower. Triggering +# fallback there is a 2× latency tax on every typo (devil's-advocate R1 +# finding #2). +# +# Conservative pattern set: only the BQ-emitted ``Syntax error: `` +# (with trailing colon) covers genuine parse-level dialect mismatch. +# ``Unrecognized name`` etc. surface for both bad-user-column AND +# DuckDB-only-name cases — the safe assumption is that user-column-typo +# is the more common case, so we don't fall back. If a deployment +# surfaces a real DuckDB-only-name regression, it's better caught as +# a BinderException with the original SQL in the logs than amplified +# via slow-path retry. +# +# The trailing colon (devil's-advocate R2 finding #3) anchors the match +# against BQ's verbatim error format and avoids false positives where +# the literal substring `Syntax error` appears in a user's SQL string +# literal that DuckDB then echoes back in an unrelated error message +# (e.g. `WHERE log_msg = 'Syntax error in foo'` failing on quota). +_BQ_REWRITE_PARSE_ERROR_PATTERNS = ( + "Syntax error: ", + "syntax error: ", +) + + +def _looks_like_bq_rewrite_parse_error(exc: BaseException) -> bool: + """Return True when ``exc`` is the BQ-rejected-inner-SQL flavour we + want to fall back from. Conservative: matches against the exception + message text only, no isinstance checks, so it works whether the + DuckDB BQ extension wrapped the error as BinderException, IOException, + or a plain Python Exception.""" + msg = str(exc) + return any(pat in msg for pat in _BQ_REWRITE_PARSE_ERROR_PATTERNS) + # Issue #160 §4.3.1 — direct `bq..` references in user # SQL. Catalog token accepts both `bq` (the unquoted DuckDB-style name) and # `"bq"` (quoted identifier). DuckDB resolves both to the same ATTACHed @@ -192,8 +236,64 @@ def execute_query( else contextlib.nullcontext() ) with guard: - # Open in read-only mode for extra safety - result = analytics.execute(request.sql).fetchmany(request.limit + 1) + # Performance fix: rewrite user SQL referencing BQ-remote tables + # to a single ``bigquery_query()`` call so WHERE / projection / + # LIMIT push into BQ via jobs.query (1-2 s) instead of falling + # through DuckDB's ATTACH-catalog Storage Read API session over + # the full table (often 70-150 s, fails with "Response too + # large to return" on >100M-row sources). Helper returns the + # original SQL unchanged when rewriting would be unsafe + # (cross-source JOIN, no BQ tables referenced, double-wrap). + execution_sql, did_rewrite = _rewrite_user_sql_for_bigquery_query( + request.sql, conn, + ) + if did_rewrite: + # Memory-safety: ``bigquery_query()`` materialises the entire + # BQ result into DuckDB before fetchmany sees it (vs the + # ATTACH-catalog Storage Read API path, which streams rows + # lazily). Wrap the rewritten SQL in an outer ``LIMIT N+1`` + # so a `SELECT *` against a billion-row remote table doesn't + # buffer the full table into the worker process — the cap + # is pushed into the BQ job itself. Aliased subquery so the + # outer LIMIT applies to the final rewritten result. + execution_sql = ( + f"SELECT * FROM ({execution_sql}) AS _bqq_outer " + f"LIMIT {request.limit + 1}" + ) + logger.info( + "query_rewrite_to_bigquery_query: user_id=%s — wrapped " + "SQL in bigquery_query() with outer LIMIT for BQ " + "predicate pushdown", + user_id, + ) + else: + logger.debug( + "query_rewrite_skipped: user_id=%s — running original " + "SQL via ATTACH-catalog path", + user_id, + ) + + # Open in read-only mode for extra safety. If the rewritten + # path errors (e.g. user SQL contained DuckDB-only syntax — + # ``::INT`` casts, ``STRPTIME``, COALESCE arity differences — + # that survives identifier rewrite but BQ refuses), fall back + # to the original SQL via the legacy ATTACH-catalog path so + # the request still succeeds (slower, but correct). Same + # safety contract as the dry-run fallback in + # ``_bq_quota_and_cap_guard``. + try: + result = analytics.execute(execution_sql).fetchmany(request.limit + 1) + except Exception as exc: + if did_rewrite and _looks_like_bq_rewrite_parse_error(exc): + logger.warning( + "query_rewrite_fallback: user_id=%s — bigquery_query() " + "rewrite rejected by BQ (%s); retrying via " + "ATTACH-catalog path", + user_id, type(exc).__name__, + ) + result = analytics.execute(request.sql).fetchmany(request.limit + 1) + else: + raise columns = [desc[0] for desc in analytics.description] if analytics.description else [] truncated = len(result) > request.limit rows = result[:request.limit] @@ -432,12 +532,11 @@ def _bq_guardrail_inputs( return dry_run, name_lookups, None -def _rewrite_user_sql_for_bq_dry_run( +def _rewrite_bq_table_refs_to_native( sql: str, name_lookups: list, project: str, ) -> str: - """Rewrite user SQL from DuckDB-flavor to BQ-native so a single - `_bq_dry_run_bytes` call can estimate scan size for the EXACT query - the user submitted (issue #171). + """Core identifier rewrite: DuckDB-flavor table references → BQ-native + backtick form. Shared between dry-run and execution-path rewriters. Two transformations: @@ -463,13 +562,15 @@ def _rewrite_user_sql_for_bq_dry_run( inside a string literal (e.g. an `IN (...)` value or a `LIKE` pattern) will also be rewritten. This is acceptable because (a) it's vanishingly rare to have a string literal exactly matching a registered table name, - and (b) when it does happen the dry-run errors out and the caller falls - back to the per-table SELECT * estimate (current behavior, no regression). + and (b) when it does happen the caller's error path covers the case + (dry-run falls back to per-table SELECT * estimate; execution falls + through to the ATTACH-catalog path). CTE shadowing: a `WITH unit_economics AS (...)` followed by `FROM unit_economics` would also rewrite the `FROM` reference. BQ then treats - the CTE as unreferenced (legal) and the dry-run estimates the rewritten - physical table — likely an over-estimate. Same fallback path covers this. + the CTE as unreferenced (legal) and the rewriter's caller deals with + the consequence — over-estimation for dry-run, fall-through-to-ATTACH + via BQ parse error for execution. """ out = sql @@ -509,6 +610,206 @@ def _rewrite_user_sql_for_bq_dry_run( return out +def _rewrite_user_sql_for_bq_dry_run( + sql: str, name_lookups: list, project: str, +) -> str: + """Rewrite user SQL from DuckDB-flavor to BQ-native so a single + `_bq_dry_run_bytes` call can estimate scan size for the EXACT query + the user submitted (issue #171). Thin wrapper around the shared + core; kept as a stable name for callers in /api/query's cap-guard. + """ + return _rewrite_bq_table_refs_to_native(sql, name_lookups, project) + + +def _rewrite_user_sql_for_bigquery_query( + user_sql: str, conn: duckdb.DuckDBPyConnection, +) -> tuple[str, bool]: + """Rewrite user SQL so the entire query ships to BQ as a single + ``bigquery_query(, )`` call. + + Returns ``(rewritten_sql, did_rewrite)``. When ``did_rewrite`` is + ``False``, the caller MUST execute the original ``user_sql`` via the + ATTACH-catalog path (slow but correct); the rewriter is conservative + on purpose — wrapping cross-source queries in ``bigquery_query()`` + would silently lose the local-side data. + + Why this matters + ---------------- + The orchestrator's master view (``CREATE VIEW name AS SELECT * FROM + bigquery..``) does not push WHERE / projections + into BQ when DuckDB resolves the query — the BQ extension opens a + Storage Read API session over the entire table, which on multi-100M-row + tables is 50-100× slower than letting BQ run the query server-side. + Wrapping the user's SQL in ``bigquery_query('', '')`` + makes the BQ extension issue a ``jobs.query`` instead, with full + predicate pushdown. + + Skip rules (returns ``(user_sql, False)``) + ------------------------------------------ + 1. No registered ``query_mode='remote'`` BQ row referenced in the SQL. + Nothing to rewrite — original SQL passes through unchanged. + 2. User SQL already contains ``bigquery_query(`` — never double-wrap. + (The /api/query keyword denylist also blocks this in production; + defensive guard for callers in other contexts.) + 3. SQL also references a non-BQ master view (Keboola/Jira local-mode + table). Wrapping would lose those references — fall through to + ATTACH-catalog so the cross-source query still runs. + 4. ``get_bq_access()`` returns the unconfigured sentinel + (``data == ''``). No project to fill into ``bigquery_query()``. + + Edge cases preserved by design + ------------------------------ + - CTEs / sub-queries referencing BQ tables: the table-name rewrite + happens at every match position, then the whole SQL is wrapped in + one ``bigquery_query()``. BQ supports CTEs, so this works. + - Multiple BQ tables, same project: combined into ONE wrap (single + jobs.query). DuckDB's BQ extension doesn't support multi-project + JOINs in a single ``bigquery_query()`` call today; if/when the + registry grows per-table source_project, this helper would need to + gate on cross-project mixing. + - ``bq."ds"."tbl"`` direct paths: rewritten to BQ-native backticks + via the same shared core as dry-run. + """ + # Skip 2: don't double-wrap. Cheap pre-check before any registry I/O. + if "bigquery_query(" in user_sql.lower(): + return user_sql, False + + # Find all referenced BQ remote-mode rows (bare-name + direct bq.path). + # Mirrors the non-RBAC parts of `_bq_guardrail_inputs`. + sql_lower = user_sql.lower() + name_lookups: list = [] + seen_paths: set = set() + + try: + repo = TableRegistryRepository(conn) + bq_rows = repo.list_by_source("bigquery") + all_rows = repo.list_all() + except Exception: + # Registry read failure — let the original SQL run through the + # ATTACH-catalog path. The handler's generic error path will + # surface anything user-visible. + return user_sql, False + + # Multi-project guard (devil's-advocate R1 finding #5): the rewriter + # assumes every BQ-remote table resolves under the single + # `bq.projects.data` project. The current registry schema doesn't + # store `source_project` per row, so `bucket` is the only place a + # cross-project leak could hide. A bucket containing `.` (e.g. + # `other_prj.dataset`) suggests the operator encoded a project + # prefix into the bucket name — wrapping that under our single + # project would silently target the wrong project. Conservative + # skip: any BQ row whose bucket contains `.` aborts the rewrite, + # falling through to the legacy ATTACH-catalog path which uses + # whatever resolution the operator's _remote_attach configured. + for r in bq_rows: + if (r.get("query_mode") or "") != "remote": + continue + bucket = r.get("bucket") + source_table = r.get("source_table") + name = r.get("name") + if not (bucket and source_table and name): + continue + if "." in str(bucket): + # Project-qualified bucket — can't safely wrap under our + # single-project assumption. Bail out completely so we don't + # mix rewritten and non-rewritten BQ paths in one query. + return user_sql, False + pattern = r'\b' + re.escape(str(name).lower()) + r'\b' + if re.search(pattern, sql_lower): + key = (bucket.lower(), source_table.lower()) + if key not in seen_paths: + seen_paths.add(key) + name_lookups.append((str(name), bucket, source_table)) + + # Direct bq."ds"."tbl" references — pull the registered (bucket, + # source_table) pair so the inner SQL receives a backticked BQ-native + # path. Mismatched / unregistered paths are caught upstream by the + # guardrail; here we just collect the mappings the rewriter needs. + direct_paths: set[tuple[str, str]] = set() + for m in BQ_PATH.finditer(user_sql): + bucket_raw = m.group(1).strip('"') + source_table_raw = m.group(2).strip('"') + direct_paths.add((bucket_raw, source_table_raw)) + + if not name_lookups and not direct_paths: + # Skip 1: no BQ tables referenced. + return user_sql, False + + # Skip 3: cross-source query (BQ + local-mode). If user SQL also + # references a non-BQ master view, we can't push the whole thing to + # BQ — DuckDB needs to do the join. + bq_names_lc = {n.lower() for n, _, _ in name_lookups} + for r in all_rows: + st = (r.get("source_type") or "").lower() + qm = (r.get("query_mode") or "").lower() + if st == "bigquery" and qm == "remote": + continue # already handled + name = r.get("name") + if not name: + continue + name_lc = str(name).lower() + if name_lc in bq_names_lc: + # Same name registered both BQ-remote and local? Pathological; + # skip as a safety measure. + return user_sql, False + if re.search(r'\b' + re.escape(name_lc) + r'\b', sql_lower): + logger.info( + "rewrite_skip_cross_source: user SQL references both " + "BQ-remote and local-mode tables; falling back to " + "ATTACH-catalog path", + ) + return user_sql, False + + # Skip 4: BQ project not configured. + try: + bq = get_bq_access() + data_project = bq.projects.data + # The first arg to `bigquery_query()` is the **execution / billing** + # project — the project under which the BQ job runs and is billed. + # In cross-project deployments the SA may only have + # `serviceusage.services.use` on the billing project, so passing + # the data project there returns 403 USER_PROJECT_DENIED. Match + # the convention used everywhere else in the codebase (v2_scan / + # v2_sample / v2_schema / extractor): backtick paths use the + # **data** project, `bigquery_query()` first-arg uses the + # **billing** project. For single-project deploys the two are + # identical so the fix is a no-op there. + billing_project = bq.projects.billing or data_project + except Exception: + return user_sql, False + if not data_project: + return user_sql, False + + # Rewrite identifiers using the DATA project — backtick paths + # `..` resolve to the same logical + # source no matter which project bills the query. + inner_sql = _rewrite_bq_table_refs_to_native(user_sql, name_lookups, data_project) + + # Embed the inner SQL using DuckDB's dollar-quoted string literal form + # (`$tag$ ... $tag$`). Naive `replace("'", "''")` doubling misses + # backslash-escape sequences DuckDB's lexer recognises (`\\`, `\n`, + # `\t`, …) — a predicate like `WHERE name = 'O\'Brien'` is unsafe + # under doubling. Dollar-quoting takes the inner SQL verbatim with no + # escape sequences whatsoever, so the user's exact bytes reach BQ. + # Tag is a fixed conventional value; the absurdly unlikely collision + # (user SQL containing the literal `$bqq_inner$`) falls back to the + # legacy doubling path so the rewrite still proceeds — over-doubled + # quotes are at worst a parse error caught by the handler's fallback + # at the call site, not a silent bad result. + DOLLAR_TAG = "$bqq_inner$" + if DOLLAR_TAG in inner_sql: + escaped_inner = inner_sql.replace("'", "''") + rewritten = ( + f"SELECT * FROM bigquery_query('{billing_project}', '{escaped_inner}')" + ) + else: + rewritten = ( + f"SELECT * FROM bigquery_query('{billing_project}', " + f"{DOLLAR_TAG}{inner_sql}{DOLLAR_TAG})" + ) + return rewritten, True + + @contextlib.contextmanager def _bq_quota_and_cap_guard( *, diff --git a/cli/client.py b/cli/client.py index 544026c..069ce3d 100644 --- a/cli/client.py +++ b/cli/client.py @@ -1,8 +1,13 @@ """HTTP client wrapper for CLI — handles auth, retries, streaming.""" +import atexit +import glob import os +import re +import threading import time import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timezone from pathlib import Path from typing import Optional @@ -11,6 +16,60 @@ import httpx from cli.config import _config_dir, get_server_url, get_token + +# PID-suffixed tmp / part files — see `_download_chunked` and +# `_download_single_stream`. We extract the embedded PID and reap any +# leftover whose process is no longer alive on every pull. Without this, +# every SIGKILL'd pull leaks files indefinitely (devil's-advocate R3 +# finding #1). +_PID_SUFFIX_RE = re.compile(r"\.(\d+)\.(?:tmp|part\d+)$") + + +def _is_pid_alive(pid: int) -> bool: + """Return True if a process with the given PID exists. POSIX-only; + Windows users get the conservative `True` (file kept) which means + no reaping but also no false-deletion of a live sibling.""" + if pid <= 0: + return False + try: + # Signal 0 = no-op kill; raises ProcessLookupError when PID is + # gone, PermissionError when the PID exists but isn't ours + # (still alive, just owned by someone else — keep the file). + os.kill(pid, 0) + return True + except ProcessLookupError: + return False + except PermissionError: + return True + except Exception: + # Anything else (e.g. AttributeError on Windows where os.kill + # exists but signal 0 isn't supported the same way): be + # conservative and don't reap. + return True + + +def _reap_dead_pid_leftovers(target_path: str) -> None: + """Remove `.{pid}.tmp` and `.{pid}.partN` files + whose embedded PID is no longer alive. Called at the start of every + download to keep the parquet directory tidy across SIGKILL'd or + crashed prior runs. Never raises — leaked file is preferable to + failing the new pull on a permission error.""" + candidates = glob.glob(f"{target_path}.*.tmp") + glob.glob(f"{target_path}.*.part*") + for path in candidates: + m = _PID_SUFFIX_RE.search(path) + if not m: + continue + try: + pid = int(m.group(1)) + except ValueError: + continue + if _is_pid_alive(pid): + continue + try: + os.unlink(path) + except OSError: + pass + # Retry policy for transient failures during stream downloads. Scoped to # network issues and 5xx — 4xx (auth, 404, 400) is NOT retried. Tunable via # env for tests; defaults sit in the "one flaky network blip" window. @@ -22,6 +81,18 @@ _RETRY_BACKOFFS_S = (0.3, 1.0, 3.0) # seconds before attempt 2, 3, 4 # timeout dies long before BQ finishes. Operators tune via AGNES_QUERY_TIMEOUT. QUERY_TIMEOUT_S = float(os.environ.get("AGNES_QUERY_TIMEOUT", "300")) +# Range-chunked parallel download — see `stream_download` docstring. Defaults +# tuned for the corp-VPN per-flow rate-limiting case (single-stream throttled +# but N parallel range requests scale linearly). Disabled implicitly for +# files below the threshold or when the server doesn't advertise byte-range +# support. Operators can hard-disable by setting parallelism to 1. +_CHUNK_PARALLELISM = max(1, min(16, int( + os.environ.get("AGNES_PULL_CHUNK_PARALLELISM", "4"), +))) +_CHUNK_THRESHOLD_BYTES = int( + os.environ.get("AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(50 * 1024 * 1024)), +) + # ── Transport-error translation ───────────────────────────────────────── # Pavel's Issue #185 Phase 3B caught the failure mode: when httpx raises @@ -143,7 +214,13 @@ def _translate_transport_error( def get_client(timeout: float = 30.0) -> httpx.Client: - """Get an authenticated httpx client.""" + """Get an authenticated httpx client. + + This factory creates a fresh client per call — used by the small + `api_*` helpers (one request, then close). The big-stream path + (`stream_download`) routes through `_get_shared_client()` to amortize + TLS handshakes and HTTP/2 multiplexing across N parquet downloads. + """ token = get_token() headers = {} if token: @@ -155,6 +232,80 @@ def get_client(timeout: float = 30.0) -> httpx.Client: ) +# ── Shared persistent client ──────────────────────────────────────────── +# `agnes pull` issues N stream_download calls — one per parquet — plus +# (with chunked downloads) M Range requests per file. Without pooling, +# each call performs a fresh TLS handshake; with HTTP/2 enabled, all +# those requests multiplex over a single TCP connection. The shared +# client is created lazily on first stream-download request, kept alive +# for the duration of the process, and closed at exit. +# +# HTTP/2 requires the optional `h2` package. If it's unavailable (slim +# install), we fall back to HTTP/1.1 — pooling alone still saves the +# handshake cost — and never raise. The CLI must not crash on `agnes +# pull` because of an h2 import error. + +_SHARED_CLIENT: Optional[httpx.Client] = None +_SHARED_CLIENT_LOCK = threading.Lock() + + +def _get_shared_client() -> httpx.Client: + """Lazily create + return a process-wide httpx.Client. + + Pool defaults: keep up to 32 keepalive connections (covers the + chunk-parallelism cap of 16 × 2 simultaneous files comfortably) and + cap the total at 64 so a runaway loop can't open thousands of + sockets. HTTP/2 is opt-in via httpx's `http2=True` and gracefully + degrades when the `h2` extra is missing. + """ + global _SHARED_CLIENT + with _SHARED_CLIENT_LOCK: + if _SHARED_CLIENT is not None: + return _SHARED_CLIENT + token = get_token() + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + limits = httpx.Limits( + max_keepalive_connections=32, + max_connections=64, + ) + try: + client = httpx.Client( + base_url=get_server_url(), + headers=headers, + timeout=300.0, + http2=True, + limits=limits, + ) + except (ImportError, RuntimeError): + # `h2` not installed → httpx raises; fall back to HTTP/1.1. + # Pooling alone still amortizes the TLS handshake. + client = httpx.Client( + base_url=get_server_url(), + headers=headers, + timeout=300.0, + limits=limits, + ) + _SHARED_CLIENT = client + return client + + +def _close_shared_client() -> None: + """Close the shared client and clear the slot. Safe to call twice.""" + global _SHARED_CLIENT + with _SHARED_CLIENT_LOCK: + if _SHARED_CLIENT is not None: + try: + _SHARED_CLIENT.close() + except Exception: + pass + _SHARED_CLIENT = None + + +atexit.register(_close_shared_client) + + def api_get(path: str, *, timeout: float = 30.0, **kwargs) -> httpx.Response: try: with get_client(timeout=timeout) as client: @@ -197,33 +348,259 @@ def _is_transient(exc: Exception) -> bool: return False -def stream_download(path: str, target_path: str, progress_callback=None) -> int: - """Stream a file to `target_path` atomically and with retries. +def _read_chunk_threshold_bytes() -> int: + """Re-read threshold each call so tests / operators can flip it via + env var without restarting the process.""" + try: + return int(os.environ.get( + "AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(_CHUNK_THRESHOLD_BYTES), + )) + except ValueError: + return _CHUNK_THRESHOLD_BYTES - Durability properties: - - Writes to `target_path + ".tmp"`, then `os.replace` on success. The - real target file never exists in a half-written state. - - Retries up to `_RETRY_ATTEMPTS` times on transient errors (network - blip, 5xx); 4xx (auth/404) is raised immediately. - - No hash check here — that's done in the sync command against the - manifest hash, because only the caller knows the expected value. + +def _read_chunk_parallelism() -> int: + """Re-read parallelism each call (same rationale as threshold). Floor 1, + ceiling 16.""" + try: + n = int(os.environ.get( + "AGNES_PULL_CHUNK_PARALLELISM", str(_CHUNK_PARALLELISM), + )) + except ValueError: + n = _CHUNK_PARALLELISM + return max(1, min(16, n)) + + +def _probe_range_support(client: httpx.Client, path: str) -> tuple[int, bool]: + """Send HEAD; return (content-length, accepts-byte-ranges). + + `(0, False)` means "we couldn't tell — fall back to single-stream". + Never raises; transport errors during the probe are treated as + "no chunking, try the GET instead and let it surface the failure + in the normal retry loop". + + Probe order: HEAD first (cheap, idempotent), then GET-with-tiny-range + fallback. The HEAD path covers Caddy's `file_server` (which advertises + HEAD) and Caddy's `reverse_proxy` (which forwards HEAD upstream). The + GET-fallback covers the dev `docker compose up` deployment where + requests go straight to FastAPI's GET-only `/api/data/{tid}/download` + route — FastAPI returns **405 Method Not Allowed** to a HEAD on a + GET-only route, which without this fallback would silently disable + chunked download for every dev / non-TLS install. The GET-with-Range + probe asks for 1 byte so the server response is bounded; we discard + the body and read only the headers + status code. """ - tmp_path = Path(f"{target_path}.tmp") + try: + resp = client.head(path) + status = getattr(resp, "status_code", 200) + if status < 400: + size = int(resp.headers.get("content-length", "0") or 0) + accepts = (resp.headers.get("accept-ranges", "").lower() == "bytes") + if size > 0: + return (size, accepts) + # HEAD failed (405 from GET-only route is the common case in + # non-Caddy deployments) or returned 0-length — fall through to + # the tiny-Range GET probe. + except Exception: + pass + try: + with client.stream("GET", path, headers={"Range": "bytes=0-0"}) as resp: + status = getattr(resp, "status_code", 0) + if status not in (200, 206): + return (0, False) + # Drain the 1-byte body so the connection is reusable. + for _ in resp.iter_bytes(): + pass + # Content-Range on a 206 response carries the total: `bytes 0-0/12345`. + # On a 200 response the server didn't honor Range — content-length is the total. + if status == 206: + cr = resp.headers.get("content-range", "") + if "/" in cr: + try: + total = int(cr.rsplit("/", 1)[1]) + return (total, True) + except ValueError: + return (0, False) + return (0, False) + # status == 200 → server ignored Range; we can read content-length but + # accept-ranges is False (or missing) so the caller will not chunk. + size = int(resp.headers.get("content-length", "0") or 0) + accepts = (resp.headers.get("accept-ranges", "").lower() == "bytes") + return (size, accepts) + except Exception: + return (0, False) + + +class _RangeNotHonored(Exception): + """Internal sentinel — server returned 200 instead of 206 to a Range + request. Caller catches and falls back to the single-stream path.""" + + +def _download_chunk( + client: httpx.Client, + path: str, + start: int, + end: int, + part_path: Path, + progress_callback, +) -> None: + """Stream `bytes=start-end` to `part_path`. Caller deals with retry + + cleanup. Raises on any failure (HTTPStatusError on non-206 response, + httpx.* on transport blip, `_RangeNotHonored` if server returned 200 + instead of 206 — chunked path can't trust that result).""" + headers = {"Range": f"bytes={start}-{end}"} + with client.stream("GET", path, headers=headers) as response: + # Server didn't honor the Range — RFC says it MAY return 200 with + # the full body. We can't safely splice that into one part of N, + # so we abort the whole chunked path and let the caller fall back. + if response.status_code == 200: + raise _RangeNotHonored() + response.raise_for_status() + with open(part_path, "wb") as f: + for piece in response.iter_bytes(chunk_size=65536): + f.write(piece) + if progress_callback and piece: + progress_callback(len(piece)) + + +def _download_chunked( + client: httpx.Client, + path: str, + target_path: str, + total_size: int, + parallelism: int, + progress_callback, +) -> int: + """Range-based parallel download. Returns total bytes written. + + Raises `_RangeNotHonored` on the first 200-instead-of-206 response so + the caller can fall back. All other exceptions propagate. + + Cleanup discipline: every part file we create gets removed before + return (success or failure). The destination is written via the + caller's `.tmp` and renamed atomically. + """ + target = Path(target_path) + # Reap leftovers from previously SIGKILL'd / crashed pulls before we + # start writing — without this, PID-suffixed files from dead PIDs + # accumulate forever on disk (devil's-advocate R3 finding #1). + _reap_dead_pid_leftovers(target_path) + # Per-process tmp + part suffixes (devil's-advocate R2 finding #2): + # if two `agnes pull` invocations target the same parquet + # concurrently (e.g. SessionStart hook + manual run, or two + # terminals), bare `.tmp` and `.partN` paths would + # collide — one process's part-write yanks the other's in-progress + # write, manifest hash check then fails spuriously. Including PID + # in the suffix makes each invocation's intermediate files + # disjoint; the final `os.replace` to the bare target is atomic so + # last-writer-wins, both processes succeed individually. + pid = os.getpid() + tmp_path = Path(f"{target_path}.{pid}.tmp") + parallelism = max(1, parallelism) + # Build chunks — last chunk takes the remainder. + chunk_size = total_size // parallelism + if chunk_size <= 0: + chunk_size = total_size # tiny file, single chunk + parallelism = 1 + ranges = [] + for i in range(parallelism): + start = i * chunk_size + end = (start + chunk_size - 1) if i < parallelism - 1 else (total_size - 1) + ranges.append((i, start, end)) + + part_paths = [Path(f"{target_path}.{pid}.part{i}") for i, _, _ in ranges] + # Pre-clean any leftovers from a prior run of THIS process. + for p in part_paths: + p.unlink(missing_ok=True) + + def _attempt_chunk(i: int, start: int, end: int) -> None: + last_exc: Optional[Exception] = None + for attempt in range(_RETRY_ATTEMPTS + 1): + try: + _download_chunk( + client, path, start, end, part_paths[i], + progress_callback, + ) + return + except _RangeNotHonored: + # Don't retry — server policy, not a transport blip. + raise + except Exception as exc: + last_exc = exc + if attempt == _RETRY_ATTEMPTS or not _is_transient(exc): + break + time.sleep(_RETRY_BACKOFFS_S[ + min(attempt, len(_RETRY_BACKOFFS_S) - 1) + ]) + assert last_exc is not None + raise last_exc + + try: + if parallelism == 1: + _attempt_chunk(*ranges[0]) + else: + # Use a thread pool so each chunk gets its own concurrent + # request slot on the (HTTP/2-multiplexed when available) + # shared client. httpx.Client is thread-safe for stream(). + with ThreadPoolExecutor(max_workers=parallelism) as ex: + futs = [ex.submit(_attempt_chunk, *r) for r in ranges] + for fut in as_completed(futs): + fut.result() # propagate first error + + # Concatenate parts → tmp_path → atomic rename. + tmp_path.unlink(missing_ok=True) + total_written = 0 + with open(tmp_path, "wb") as out: + for p in part_paths: + with open(p, "rb") as inp: + while True: + block = inp.read(65536) + if not block: + break + out.write(block) + total_written += len(block) + os.replace(tmp_path, target) + return total_written + finally: + # Always clean up part files + any stray tmp. + for p in part_paths: + p.unlink(missing_ok=True) + if tmp_path.exists(): + try: + tmp_path.unlink() + except OSError: + pass + + +def _download_single_stream( + client: httpx.Client, + path: str, + target_path: str, + progress_callback, +) -> int: + """Original single-stream path with retry. Used when chunking is + disabled (small file, no range support, or fallback after 200-on-Range).""" + # Same dead-PID reap as `_download_chunked` so leftovers from + # crashed prior pulls don't accumulate indefinitely. + _reap_dead_pid_leftovers(target_path) + # Per-process tmp suffix — same rationale as `_download_chunked` + # (devil's-advocate R2 finding #2): concurrent `agnes pull` + # invocations against the same target dir must not yank each + # other's in-progress writes. + tmp_path = Path(f"{target_path}.{os.getpid()}.tmp") last_exc: Optional[Exception] = None for attempt in range(_RETRY_ATTEMPTS + 1): try: tmp_path.unlink(missing_ok=True) - with get_client(timeout=300.0) as client: - with client.stream("GET", path) as response: - response.raise_for_status() - total = 0 - with open(tmp_path, "wb") as f: - for chunk in response.iter_bytes(chunk_size=65536): - f.write(chunk) - total += len(chunk) - if progress_callback: - progress_callback(len(chunk)) - # os.replace is atomic on POSIX and Windows for same-filesystem moves. + with client.stream("GET", path) as response: + response.raise_for_status() + total = 0 + with open(tmp_path, "wb") as f: + for chunk in response.iter_bytes(chunk_size=65536): + f.write(chunk) + total += len(chunk) + if progress_callback: + progress_callback(len(chunk)) os.replace(tmp_path, target_path) return total except Exception as exc: @@ -231,23 +608,115 @@ def stream_download(path: str, target_path: str, progress_callback=None) -> int: if attempt == _RETRY_ATTEMPTS or not _is_transient(exc): break time.sleep(_RETRY_BACKOFFS_S[min(attempt, len(_RETRY_BACKOFFS_S) - 1)]) - # Clean up any leftover tmp, then surface the last exception. Translate - # transport errors (timeouts, connection drops, protocol errors) to - # AgnesTransportError so the CLI prints a clean message instead of a - # Python traceback (Pavel's #185 Phase 3B). HTTPStatusError (4xx/5xx - # response from the server) is NOT a transport failure and must - # re-raise verbatim so the caller's status-code handling + the rich - # server error body (e.g. 401 with "token expired", 403 with - # cross_project_forbidden detail) reach the analyst — Devin Review on - # PR #188 caught: HTTPStatusError is a subclass of HTTPError, so the - # generic isinstance(HTTPError) translation was eating status codes. tmp_path.unlink(missing_ok=True) assert last_exc is not None - if isinstance(last_exc, httpx.HTTPStatusError): - raise last_exc - if isinstance(last_exc, httpx.HTTPError): - raise _translate_transport_error( - last_exc, context=f"GET {path} (stream → {target_path})", - timeout_s=300.0, - ) from last_exc raise last_exc + + +def stream_download(path: str, target_path: str, progress_callback=None) -> int: + """Stream a file to `target_path` atomically and with retries. + + Two paths: + 1. **Chunked parallel** — when the server advertises `accept-ranges: + bytes` and `content-length` exceeds `AGNES_PULL_CHUNK_THRESHOLD_BYTES` + (default 50 MB), split into N range requests + (`AGNES_PULL_CHUNK_PARALLELISM`, default 4, capped 1..16) and + download in parallel. Concatenate the part files into `.tmp`, + then `os.replace`. Falls back to single-stream if the server + responds 200 instead of 206 to a Range probe. + 2. **Single-stream** — for small files, no range support, or fallback + from the chunked path. Same atomic-rename + retry semantics as + before. + + Durability properties (unchanged): + - Writes to `.tmp`, then `os.replace` on success. The real + target file never exists in a half-written state. + - Retries up to `_RETRY_ATTEMPTS` on transient errors (network blip, + 5xx); 4xx (auth/404) is raised immediately. + - No hash check here — that's the caller's job (manifest hash). + + Threading: the chunked path uses a ThreadPoolExecutor sized to the + parallelism. httpx.Client.stream() is safe to call concurrently from + multiple threads on a single client (the connection pool serializes + the underlying socket access; HTTP/2 multiplexes streams when the + `h2` extra is installed). + """ + # Use the shared persistent client when available — one TLS + # handshake amortized across N stream_download calls within the same + # process, and HTTP/2 stream multiplexing across the chunk Range + # requests within a single download. Falls back to a fresh per-call + # client if shared-client construction fails (e.g. `h2` install + # broken at runtime). Devil's-advocate R2 finding #1: scope the + # try/except to *only* the shared-client construction — the actual + # download must NOT be retried under this except, otherwise hard + # failures (401/403/404/5xx) waste a full second download attempt + # and revoked-PAT cases don't fail-fast. + try: + client = _get_shared_client() + except Exception: + with get_client(timeout=300.0) as client: + return _stream_download_via(client, path, target_path, progress_callback) + return _stream_download_via(client, path, target_path, progress_callback) + + +def _stream_download_via( + client: httpx.Client, + path: str, + target_path: str, + progress_callback, +) -> int: + """The shared body of `stream_download` parameterized on the client. + Split out so tests can inject a fake client.""" + threshold = _read_chunk_threshold_bytes() + parallelism = _read_chunk_parallelism() + + total_size = 0 + accepts_ranges = False + if parallelism > 1: + total_size, accepts_ranges = _probe_range_support(client, path) + + # Sanity bound on the advertised total size (devil's-advocate R1 + # finding #4): a misconfigured proxy or buggy server returning a + # wildly inflated `Content-Length` would make us split into huge + # `Range: bytes=N-M` requests; the server then clamps each to actual + # bytes available, and we end up with overlapping bytes from the + # start of the file in every part → corrupt assembled output (caught + # later by manifest hash check, but only after wasted bandwidth). + # 100 GiB is the operational ceiling for any single materialized + # parquet on a typical Agnes deployment; values above suggest a + # server / proxy bug rather than a legitimate huge file. Drop to + # single-stream (which can't be confused by overlapping chunks). + SANE_MAX_TOTAL = 100 * 1024**3 # 100 GiB + if total_size > SANE_MAX_TOTAL: + total_size = 0 + accepts_ranges = False + + use_chunked = ( + parallelism > 1 + and accepts_ranges + and total_size > threshold + ) + + try: + if use_chunked: + try: + return _download_chunked( + client, path, target_path, total_size, parallelism, + progress_callback, + ) + except _RangeNotHonored: + # Server lied / proxy stripped the Range — fall through. + pass + return _download_single_stream( + client, path, target_path, progress_callback, + ) + except httpx.HTTPStatusError: + # 4xx / 5xx response from the server — re-raise verbatim so the + # caller's status-code handling + the rich server error body + # reach the analyst (Devin Review on PR #188). + raise + except httpx.HTTPError as exc: + raise _translate_transport_error( + exc, context=f"GET {path} (stream → {target_path})", + timeout_s=300.0, + ) from exc diff --git a/cli/lib/pull.py b/cli/lib/pull.py index b33c4b3..59d4db9 100644 --- a/cli/lib/pull.py +++ b/cli/lib/pull.py @@ -65,6 +65,128 @@ class PullResult: _SAFE_ID_RE = re.compile(r"^[a-zA-Z0-9_\-]{1,128}$") +class _TextualProgress: + """Plain-text progress emitter for non-TTY stderr. + + When `agnes pull` is invoked from a Claude Code SessionStart hook, + a CI runner, or any pipe consumer, stderr is not a terminal. Rich's + progress bar in that mode either suppresses output (silent for + minutes on a multi-GB parquet) or emits raw ANSI noise. This class + instead emits one terse line per file at sensible cadence. + + Cadence policy: emit when *either*: + - per-file bytes-downloaded crosses a 10%-of-total boundary, OR + - 30 s have elapsed since this file's last emission. + + Always emits one final "done" line per file via `finish()` so the + operator sees a confirmed completion even on tiny files. + + Format: `[N/T files] : 25% (16 MB / 66 MB) at 1.5 MB/s` — the + "[N/T files]" prefix lets the operator see overall pull progress + in a multi-table run without buffering all per-file lines. + + Thread-safe — `advance` is called from the chunked-download worker + threads; an internal lock serializes the update + emit. + """ + + _HUMAN_UNITS = ( + (1024 * 1024 * 1024 * 1024, "TB"), + (1024 * 1024 * 1024, "GB"), + (1024 * 1024, "MB"), + (1024, "KB"), + ) + + def __init__(self, *, stream, total_files: int, file_sizes: dict[str, int]): + import threading + self._stream = stream + self._total_files = total_files + self._file_sizes = file_sizes + self._lock = threading.Lock() + # Per-file state. + self._bytes: dict[str, int] = {tid: 0 for tid in file_sizes} + self._started_at: dict[str, float] = {} + self._last_emit_at: dict[str, float] = {} + self._last_emit_pct: dict[str, int] = {} + self._finished_idx: int = 0 # files whose `finish` line has been emitted + + def advance(self, tid: str, n: int) -> None: + """Add `n` bytes to the file's total. Emit a textual update if + the cadence policy allows.""" + with self._lock: + now = time.monotonic() + if tid not in self._started_at: + self._started_at[tid] = now + self._last_emit_at[tid] = now + self._last_emit_pct[tid] = 0 + self._bytes[tid] = self._bytes.get(tid, 0) + n + + total = self._file_sizes.get(tid, 0) + current = self._bytes[tid] + pct = int((current * 100) / total) if total > 0 else 0 + elapsed = now - self._last_emit_at[tid] + crossed_10 = pct >= self._last_emit_pct[tid] + 10 + if crossed_10 or elapsed >= 30.0: + self._last_emit_at[tid] = now + self._last_emit_pct[tid] = pct - (pct % 10) + self._emit_line(tid, current, total, now) + + def finish(self) -> None: + """Emit a final `done` line for any file we never closed out.""" + with self._lock: + now = time.monotonic() + for tid, total in self._file_sizes.items(): + # Treat any file we observed bytes for as needing a + # final line. Files that errored out before any callback + # are still announced (operator wants visibility even on + # zero-byte attempts). + self._finished_idx += 1 + bytes_ = self._bytes.get(tid, 0) + started = self._started_at.get(tid, now) + duration = max(0.001, now - started) + rate = bytes_ / duration + line = ( + f"[{self._finished_idx}/{self._total_files} files] " + f"{tid}: 100% done " + f"({self._fmt_bytes(bytes_)} in {duration:.1f}s, " + f"{self._fmt_bytes(int(rate))}/s)\n" + ) + self._stream.write(line) + try: + self._stream.flush() + except Exception: + pass + + def _emit_line(self, tid: str, current: int, total: int, now: float) -> None: + started = self._started_at.get(tid, now) + duration = max(0.001, now - started) + rate = current / duration + if total > 0: + pct_str = f"{int((current * 100) / total)}%" + size_str = ( + f"({self._fmt_bytes(current)} / {self._fmt_bytes(total)})" + ) + else: + pct_str = "?" + size_str = f"({self._fmt_bytes(current)})" + idx = self._finished_idx + 1 # 1-based "currently working on file N" + line = ( + f"[{idx}/{self._total_files} files] {tid}: {pct_str} " + f"{size_str} at {self._fmt_bytes(int(rate))}/s\n" + ) + self._stream.write(line) + try: + self._stream.flush() + except Exception: + pass + + @classmethod + def _fmt_bytes(cls, n: int) -> str: + for divisor, suffix in cls._HUMAN_UNITS: + if n >= divisor: + return f"{n / divisor:.1f} {suffix}" + return f"{n} B" + + @contextmanager def _override_server_env(server_url: str, token: str) -> Iterator[None]: """Set AGNES_SERVER + scoped token override for the duration of the call. @@ -219,15 +341,34 @@ def run_pull( # the executor + thread overhead for the common single-update case. workers = min(workers, len(to_download)) if to_download else 1 - # Optional progress bar — Rich's Progress tracks per-file bytes - # streamed, aggregated across the parallel ThreadPoolExecutor - # workers. Pavel's #185 Phase 1: a single 6.3 GB parquet on first - # init went 44 minutes silent, looked frozen. Now: aggregate "X.Y - # GB / Z.A GB · 56 MB/s · ETA 1m 20s" to stderr while threads - # stream. None when show_progress=False (SessionStart hooks etc.). + # Optional progress reporting — two paths. + # + # 1. Rich progress bar: per-file bytes-streamed bar with speed + + # ETA. Rendered to stderr when stderr is a TTY. Aggregates + # across the parallel ThreadPoolExecutor workers and across + # chunked-download chunks (all chunks call the same callback + # advancing the same task). + # 2. Textual fallback: when `show_progress=True` but stderr is + # NOT a TTY (Claude Code SessionStart hook, CI run, Docker + # log capture), Rich would either suppress the bar or emit + # raw control sequences. Instead we emit one plain-text line + # per file at most every 10% or 30 s — enough signal to know + # the pull isn't frozen on a multi-GB parquet, terse enough + # not to spam the consumer's log. + # + # Both paths receive the same per-file callback so the chunked- + # download contract ("one file = one task, sum-of-chunks bytes") + # is honored uniformly. + import sys as _sys progress = None progress_tasks: dict[str, int] = {} - if show_progress and to_download: + textual = None + use_textual_fallback = ( + show_progress + and to_download + and not _sys.stderr.isatty() + ) + if show_progress and to_download and not use_textual_fallback: from rich.progress import ( Progress, BarColumn, DownloadColumn, TextColumn, TimeRemainingColumn, TransferSpeedColumn, @@ -248,13 +389,22 @@ def run_pull( progress_tasks[tid] = progress.add_task( "download", label=tid, total=size if size > 0 else None, ) + elif use_textual_fallback: + textual = _TextualProgress( + stream=_sys.stderr, + total_files=len(to_download), + file_sizes={ + tid: int(server_tables[tid].get("size_bytes") or 0) + for tid in to_download + }, + ) def _download_one(tid: str) -> tuple[str, dict | None, str | None]: """Returns (tid, local_table_entry_or_None, error_or_None). One bound thread per call; stream_download is sync I/O so a ThreadPoolExecutor (not asyncio) is the right tool. The progress callback is thread-safe — Rich's Progress.update - holds an internal lock.""" + and the textual fallback's lock both serialize internally.""" target = parquet_dir / f"{tid}.parquet" expected_hash = server_tables[tid].get("hash", "") cb = None @@ -262,6 +412,9 @@ def run_pull( task_id = progress_tasks[tid] def cb(n: int, _tid=tid, _task=task_id): progress.update(_task, advance=n) + elif textual is not None: + def cb(n: int, _tid=tid): + textual.advance(_tid, n) try: stream_download(f"/api/data/{tid}/download", str(target), progress_callback=cb) @@ -294,6 +447,8 @@ def run_pull( finally: if progress is not None: progress.stop() + if textual is not None: + textual.finish() for tid, entry, err in outcomes: if err is not None: diff --git a/config/instance.yaml.example b/config/instance.yaml.example index c836144..be2c8a0 100644 --- a/config/instance.yaml.example +++ b/config/instance.yaml.example @@ -135,6 +135,13 @@ data_source: # # view-backed datasets -- bumped to 600 000 ms = 10 min by default. # # Set 0 to fall through to the extension default. Configurable via # # /admin/server-config UI. + # session_pool_size: 4 + # # Number of pre-warmed DuckDB+bigquery-extension sessions kept + # # in a process-local pool. Each acquire amortizes the + # # ~0.5 s INSTALL/LOAD/CREATE-SECRET cost across requests; a fresh + # # build only happens when the pool is empty. Default 4. Set 0 + # # to disable pooling (every acquire builds + closes a fresh + # # session; matches pre-pool behavior). # --- OpenMetadata catalog (optional) --- # Enriches table and column metadata from OpenMetadata REST API. diff --git a/connectors/bigquery/access.py b/connectors/bigquery/access.py index 48e4e9f..b26a2d1 100644 --- a/connectors/bigquery/access.py +++ b/connectors/bigquery/access.py @@ -8,6 +8,8 @@ from __future__ import annotations import functools import logging +import threading +from collections import deque from contextlib import contextmanager from dataclasses import dataclass from typing import Callable, Iterator, Literal @@ -42,6 +44,12 @@ class BqAccessError(Exception): "bq_forbidden": 502, # other Forbidden from BQ "bq_bad_request": 400, # 400 from BQ when caller flagged it as client-derived "bq_upstream_error": 502, # all other upstream BQ failures + # `responseTooLarge` is a BQ refusal whose root cause is query shape + # (the user asked for too many rows back inline), not auth or syntax. + # 400 with a specific actionable hint instead of the generic + # bq_bad_request / bq_upstream_error mappings, which surfaced the + # raw BQ message and gave operators no path forward. + "bq_response_too_large": 400, } def __init__(self, kind: str, message: str, details: dict | None = None): @@ -51,6 +59,43 @@ class BqAccessError(Exception): super().__init__(message) +_RESPONSE_TOO_LARGE_HINT = ( + "BigQuery refused to return the result inline; the query exceeded BQ's " + "response size limit. Narrow the WHERE clause, aggregate further, " + "select fewer columns, or query a materialized table that's already " + "been bounded server-side." +) + + +def _classify_response_too_large(msg: str, projects: BqProjects) -> BqAccessError: + """Build the `bq_response_too_large` BqAccessError with the canonical + actionable hint and the original BQ message preserved in details for + operator debugging.""" + return BqAccessError( + "bq_response_too_large", + _RESPONSE_TOO_LARGE_HINT, + details={ + "original": msg, + "billing_project": projects.billing, + "data_project": projects.data, + }, + ) + + +def _is_response_too_large(msg: str) -> bool: + """Detect BQ's `responseTooLarge` failure mode by message substring. + + The reason code is stable across HTTP transports (gax.BadRequest from + google-cloud-bigquery, duckdb.IOException from the BQ extension's own + HTTP layer); both surface 'Response too large to return' verbatim in + the message body. Match case-insensitively + tolerate the slight + variant 'response too large' that some surfaces emit without the + 'to return' suffix. + """ + ml = msg.lower() + return "response too large" in ml + + def translate_bq_error( e: Exception, projects: BqProjects, @@ -67,12 +112,24 @@ def translate_bq_error( 2. Forbidden + 'serviceusage' in str(e).lower() -> cross_project_forbidden (with hint) 3. Forbidden -> bq_forbidden - 4. BadRequest, bad_request_status='client_error' + 4. 'response too large' in str(e).lower() + -> bq_response_too_large (HTTP 400, with + actionable hint pointing at WHERE / + aggregate / materialized remediations) + 5. BadRequest, bad_request_status='client_error' -> bq_bad_request (HTTP 400) - 5. BadRequest, bad_request_status='upstream_error' + 6. BadRequest, bad_request_status='upstream_error' -> bq_upstream_error (HTTP 502) - 6. GoogleAPICallError (other) -> bq_upstream_error - 7. Anything else -> RE-RAISED unchanged (don't swallow programmer errors) + 7. GoogleAPICallError (other) -> bq_upstream_error + 8. Anything else -> RE-RAISED unchanged (don't swallow programmer errors) + + The `responseTooLarge` mapping (4) sits ahead of the generic BadRequest + cases on purpose: BQ surfaces this failure mode as a 400 with a + specific reason, but the actionable remediation is "shape your query + differently" — not "your SQL has a syntax error" (the typical + bq_bad_request user-facing meaning) and not "BQ is broken" + (bq_upstream_error). Routing it via its own kind keeps the user-facing + message tight + correct. """ if isinstance(e, BqAccessError): return e @@ -106,6 +163,13 @@ def translate_bq_error( details={"billing_project": projects.billing, "data_project": projects.data}, ) + # Special-case: `responseTooLarge` arrives as gax.BadRequest (HTTP 400) + # but has a unique reason code with a specific, actionable remediation. + # Catch it BEFORE the generic BadRequest mapping below so it doesn't + # surface as a confusing "bad request" (which implies bad SQL). + if _is_response_too_large(msg): + return _classify_response_too_large(msg, projects) + if isinstance(e, gax.BadRequest): if bad_request_status == "client_error": return BqAccessError("bq_bad_request", msg) @@ -196,15 +260,40 @@ def _default_client_factory(projects: BqProjects): ) -@contextmanager -def _default_duckdb_session_factory(projects: BqProjects): - """Yield an in-memory DuckDB conn with bigquery extension loaded + SECRET set - from get_metadata_token(). Auto-cleanup. Translates auth/install failures - to BqAccessError(kind='auth_failed' or 'bq_lib_missing'). +def _default_pool_size() -> int: + """Resolve the BQ DuckDB-extension session pool size from instance.yaml. - Note: `projects.billing` is not used by this factory directly — bigquery_query() - callers pass it themselves as the first positional arg to identify the billing - project. The factory keeps the parameter for symmetry with _default_client_factory. + Reads ``data_source.bigquery.session_pool_size`` (default 4). Sentinel + ``0`` disables pooling (every acquire builds + closes a fresh session; + matches pre-pool behavior). Negative / non-numeric values fall back to + the default — the pool is a perf optimization, not a correctness + boundary, so an unparseable config shouldn't fail-stop the app. + """ + try: + from app.instance_config import get_value + except Exception: + return 4 + raw = get_value("data_source", "bigquery", "session_pool_size", default=4) + try: + n = int(raw) if raw is not None else 4 + except (TypeError, ValueError): + logger.warning( + "BQ session_pool_size=%r is not an int; falling back to default 4", + raw, + ) + return 4 + if n < 0: + return 4 + return n + + +def _build_fresh_bq_session(): + """Build a single fresh in-memory DuckDB conn with the bigquery extension + INSTALL/LOAD'd, the auth SECRET created from get_metadata_token(), and + per-session settings applied. Translates auth / install failures to + BqAccessError. Caller owns the close. + + Used internally by the pool; also used directly when pooling is disabled. """ import duckdb # type: ignore from connectors.bigquery.auth import get_metadata_token, BQMetadataAuthError @@ -220,22 +309,176 @@ def _default_duckdb_session_factory(projects: BqProjects): conn = duckdb.connect(":memory:") try: + conn.execute("INSTALL bigquery FROM community; LOAD bigquery;") + escaped = token.replace("'", "''") + conn.execute( + f"CREATE OR REPLACE SECRET bq_s (TYPE bigquery, ACCESS_TOKEN '{escaped}')" + ) + except Exception as e: + # Build failed — must close the half-initialised conn, otherwise it + # leaks across the pool's lifetime. try: - conn.execute("INSTALL bigquery FROM community; LOAD bigquery;") - escaped = token.replace("'", "''") - conn.execute( - f"CREATE OR REPLACE SECRET bq_s (TYPE bigquery, ACCESS_TOKEN '{escaped}')" - ) - except Exception as e: - raise BqAccessError( - "bq_lib_missing", - f"failed to install/load BigQuery DuckDB extension: {e}", - details={"original": str(e)}, - ) - apply_bq_session_settings(conn) + conn.close() + except Exception: + pass + raise BqAccessError( + "bq_lib_missing", + f"failed to install/load BigQuery DuckDB extension: {e}", + details={"original": str(e)}, + ) + apply_bq_session_settings(conn) + return conn + + +def _refresh_bq_secret(conn) -> None: + """Refresh the auth SECRET on a pooled connection so token rotation + (default GCE metadata token TTL ~1 hr) doesn't break long-lived + pooled entries. + + Cheap when the token cache is warm (a few µs). Failures are + non-fatal here — the pool's liveness probe + per-acquire build + fallback will catch genuinely-broken entries. + """ + from connectors.bigquery.auth import get_metadata_token + try: + token = get_metadata_token() + escaped = token.replace("'", "''") + conn.execute( + f"CREATE OR REPLACE SECRET bq_s (TYPE bigquery, ACCESS_TOKEN '{escaped}')" + ) + except Exception as e: + # Bubble up so the pool drops this entry and rebuilds. + raise BqAccessError( + "auth_failed", + f"could not refresh BQ secret on pooled session: {e}", + details={"original": str(e)}, + ) + + +def _is_pool_entry_alive(conn) -> bool: + """Cheap liveness probe — `SELECT 1`. Returns False on any error so + the pool reaper drops the entry and builds a fresh one.""" + try: + result = conn.execute("SELECT 1").fetchone() + return result is not None and result[0] == 1 + except Exception: + return False + + +# Module-level pool state. Process-cached (mirrors get_bq_access's lifetime). +# Not fork-safe — single uvicorn worker process is the supported deployment +# shape per CLAUDE.md. +_pool: deque = deque() +_pool_lock = threading.Lock() + + +def _reset_session_pool_for_tests() -> None: + """Drop and close every pooled entry. Test helper — production code + should not call this. Exposed so test fixtures + the existing + test_bq_access tests can pin pre-test pool state to empty.""" + with _pool_lock: + while _pool: + entry = _pool.popleft() + try: + entry.close() + except Exception: + pass + + +@contextmanager +def _default_duckdb_session_factory(projects: BqProjects): + """Yield a pooled in-memory DuckDB conn with bigquery extension loaded + + SECRET set from get_metadata_token(). Translates auth / install + failures to BqAccessError(kind='auth_failed' or 'bq_lib_missing'). + + Pooling: amortizes the ~0.5 s INSTALL/LOAD/ATTACH cost across requests + by keeping pre-warmed connections in a bounded deque. Acquire reuses + an existing entry when available (refreshing its auth SECRET so + token rotation doesn't break long-lived entries) and probes liveness + cheaply via ``SELECT 1`` before handing it to the caller. On normal + exit the connection returns to the pool; on exception it's closed + instead (the underlying session may carry dirty state). + + Pool size is ``data_source.bigquery.session_pool_size`` (default 4; + sentinel ``0`` disables pooling entirely, matching pre-pool + behavior). Process-cached, not fork-safe. + + Note: `projects.billing` is not used by this factory directly — bigquery_query() + callers pass it themselves as the first positional arg to identify the billing + project. The factory keeps the parameter for symmetry with _default_client_factory. + """ + pool_size = _default_pool_size() + + # Acquire: prefer a warm entry, fall back to fresh build. + conn = None + if pool_size > 0: + while True: + with _pool_lock: + entry = _pool.popleft() if _pool else None + if entry is None: + break + if not _is_pool_entry_alive(entry): + # Reaper: drop broken entries. + try: + entry.close() + except Exception: + pass + continue + try: + # Refresh the auth SECRET so a long-lived pool entry + # doesn't keep a stale token past its TTL. Cheap when + # the token cache is warm. + _refresh_bq_secret(entry) + except BqAccessError: + try: + entry.close() + except Exception: + pass + continue + # Re-apply session settings (`bq_query_timeout_ms`, …) on + # every reuse so an operator's `/admin/server-config` change + # propagates to pooled entries without requiring container + # restart. Without this, a long-lived pool entry keeps the + # value baked in at first build forever (devil's-advocate + # R1 finding #3). `apply_bq_session_settings` is idempotent + # and fail-soft — re-running on every acquire is cheap. + try: + apply_bq_session_settings(entry) + except Exception: + # apply_bq_session_settings is documented as never + # raising for legitimate "extension doesn't recognise + # setting" cases (it only logs). Defensive guard for + # any unforeseen failure mode — keep the entry, the + # caller's actual query may still succeed. + pass + conn = entry + break + + if conn is None: + conn = _build_fresh_bq_session() + + try: yield conn - finally: - conn.close() + except Exception: + # Caller saw an exception — the conn may be in a dirty state. + # Don't return to pool; close to release native resources. + try: + conn.close() + except Exception: + pass + raise + else: + # Normal exit — return to pool if there's room. + if pool_size > 0: + with _pool_lock: + if len(_pool) < pool_size: + _pool.append(conn) + return + # Pool disabled or full — close. + try: + conn.close() + except Exception: + pass def apply_bq_session_settings(conn) -> None: diff --git a/pyproject.toml b/pyproject.toml index b6ccc9d..0093060 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "agnes-the-ai-analyst" -version = "0.38.3" +version = "0.39.0" description = "Agnes — AI Data Analyst platform for AI analytical systems" requires-python = ">=3.11,<3.14" license = "MIT" @@ -20,8 +20,13 @@ dependencies = [ "itsdangerous>=2.1.0", "authlib>=1.6.11", "argon2-cffi>=23.1.0", - # HTTP client + # HTTP client. `h2` enables HTTP/2 multiplexing for the persistent + # CLI client used by `agnes pull` (one TCP connection serves N + # concurrent parquet streams + range chunks). `cli/client.py` + # gracefully falls back to HTTP/1.1 if h2 is missing, so this + # extra is for performance, not correctness. "httpx>=0.27.0", + "h2>=4.1.0", # CLI "typer>=0.12.0", "rich>=13.0.0", diff --git a/tests/test_bq_access.py b/tests/test_bq_access.py index cfa1282..5f64918 100644 --- a/tests/test_bq_access.py +++ b/tests/test_bq_access.py @@ -1,5 +1,6 @@ """Tests for connectors/bigquery/access.py — the BqAccess facade.""" import pytest +import threading class TestBqProjects: @@ -36,6 +37,12 @@ class TestBqAccessError: "bq_forbidden": 502, "bq_bad_request": 400, "bq_upstream_error": 502, + # User-facing class for "Response too large to return" — an + # upstream BQ refusal, but caused by query shape (too many rows + # to fit in a single jobs.query response) rather than auth or + # syntax. 400 so the user sees an actionable error and not a + # 502 that suggests "BQ is broken". + "bq_response_too_large": 400, } assert BqAccessError.HTTP_STATUS == expected @@ -143,6 +150,70 @@ class TestTranslateBqError: translate_bq_error(ValueError("not a BQ error"), self.projects, bad_request_status="client_error") + def test_response_too_large_via_gax_bad_request(self): + """BQ ``responseTooLarge`` arrives as ``gax.BadRequest`` (HTTP 400 + with a specific `reason` field). Pre-fix this fell through to the + generic ``bq_bad_request`` mapping — surfacing as a 400 with the + raw upstream message and no actionable hint. Now it routes to a + dedicated ``bq_response_too_large`` kind whose message tells the + user exactly what to do (narrow WHERE / aggregate / use materialized). + """ + from google.api_core.exceptions import BadRequest + from connectors.bigquery.access import translate_bq_error + e = BadRequest("Response too large to return. Consider setting allowLargeResults to true ...") + result = translate_bq_error( + e, self.projects, bad_request_status="client_error", + ) + assert result.kind == "bq_response_too_large", ( + f"got {result.kind!r}; expected dedicated mapping for " + "'Response too large' to avoid the generic bq_bad_request 400 " + "with no actionable hint" + ) + # User-facing message must point at the actionable remediations, + # not just echo the raw BQ string. + assert "exceeded" in result.message.lower() or "too large" in result.message.lower() + assert "where" in result.message.lower() or "aggregate" in result.message.lower() or "materialized" in result.message.lower() + # Original upstream text preserved in details for operator debugging. + assert "original" in result.details + assert "Response too large" in result.details["original"] + + def test_response_too_large_via_duckdb_native_string(self): + """DuckDB-native exceptions (the BQ extension's C++ HTTP path) + carry the same 'Response too large' marker in plain ``Exception`` + messages — must classify the same way as the gax.BadRequest case.""" + from connectors.bigquery.access import translate_bq_error + e = Exception("HTTP 400: Response too large to return.") + result = translate_bq_error( + e, self.projects, bad_request_status="upstream_error", + ) + assert result.kind == "bq_response_too_large" + + def test_response_too_large_classification_is_status_independent(self): + """The mapping must fire regardless of ``bad_request_status`` + (some callers route via 'upstream_error', others via 'client_error'). + It's the BQ error shape that matters, not who's calling.""" + from google.api_core.exceptions import BadRequest + from connectors.bigquery.access import translate_bq_error + e = BadRequest("Response too large to return") + for status in ("client_error", "upstream_error"): + result = translate_bq_error(e, self.projects, bad_request_status=status) + assert result.kind == "bq_response_too_large", ( + f"bad_request_status={status!r} routed to {result.kind!r}; " + "expected bq_response_too_large for both" + ) + + def test_response_too_large_does_not_trigger_on_unrelated_bad_request(self): + """Other BadRequests (syntax errors, malformed identifiers, …) + must keep going through the generic bq_bad_request mapping — only + the 'Response too large' substring triggers the dedicated kind.""" + from google.api_core.exceptions import BadRequest + from connectors.bigquery.access import translate_bq_error + e = BadRequest("Syntax error at [1:23] near unexpected token") + result = translate_bq_error( + e, self.projects, bad_request_status="client_error", + ) + assert result.kind == "bq_bad_request" + class TestDefaultClientFactory: def test_constructs_client_with_billing_project_as_quota(self, monkeypatch): @@ -208,8 +279,21 @@ class TestDefaultClientFactory: class TestDefaultDuckdbSessionFactory: - def test_yields_duckdb_conn_with_secret_then_closes(self, monkeypatch): - from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects + def test_yields_duckdb_conn_with_secret_set_via_pool(self, monkeypatch): + """The pool's first acquire on an empty pool runs the full + INSTALL/LOAD/SECRET sequence. After the with-block exits the + connection is RETURNED to the pool (not closed) so the next + acquire amortizes the extension-load cost. + + Pre-pool semantics (close-on-exit) are preserved on broken + entries + on the explicit pool-reset path; covered in + TestBqSessionPool. + """ + from connectors.bigquery.access import ( + _default_duckdb_session_factory, BqProjects, + _reset_session_pool_for_tests, + ) + _reset_session_pool_for_tests() executed_sql = [] @@ -218,7 +302,10 @@ class TestDefaultDuckdbSessionFactory: self.closed = False def execute(self, sql, params=None): executed_sql.append((sql, params)) - return self + class _Result: + def fetchone(self_inner): + return (1,) + return _Result() def close(self): self.closed = True @@ -228,19 +315,36 @@ class TestDefaultDuckdbSessionFactory: with _default_duckdb_session_factory(BqProjects(billing="b", data="d")) as conn: assert conn is fake_conn - assert fake_conn.closed is True + # Pool retains the conn — close happens at pool reset / shutdown. + assert fake_conn.closed is False # Verify INSTALL/LOAD/SECRET sequence ran assert any("INSTALL bigquery" in sql for sql, _ in executed_sql) assert any("LOAD bigquery" in sql for sql, _ in executed_sql) assert any("CREATE OR REPLACE SECRET" in sql and "tok123" in sql for sql, _ in executed_sql) + # Explicit pool reset closes the retained entry. + _reset_session_pool_for_tests() + assert fake_conn.closed is True + def test_closes_on_exception_inside_with_block(self, monkeypatch): - from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects + """Exceptions inside the with-block leave the underlying conn in + an unknown state (half-completed query, dirty session); the pool + treats it as broken and closes it rather than returning to pool. + """ + from connectors.bigquery.access import ( + _default_duckdb_session_factory, BqProjects, + _reset_session_pool_for_tests, + ) + _reset_session_pool_for_tests() class FakeConn: closed = False - def execute(self, *a, **kw): return self + def execute(self, *a, **kw): + class _Result: + def fetchone(self_inner): + return (1,) + return _Result() def close(self): self.closed = True fake_conn = FakeConn() @@ -449,3 +553,222 @@ class TestGetBqAccess: assert a is b assert isinstance(a, BqAccess) assert a.projects.billing == "" + + +# --------------------------------------------------------------------------- +# DuckDB BQ-extension session pool — amortizes the ~0.5 s INSTALL/LOAD/ATTACH +# cost across requests by keeping pre-warmed DuckDB connections in a +# bounded pool. Each acquire reuses an existing connection (refreshing the +# auth SECRET so token rotation doesn't break long-lived entries) instead +# of spinning up a fresh DuckDB+extension load every time. +# --------------------------------------------------------------------------- + + +class _PoolFakeConn: + """Fake DuckDB connection that records executed SQL and supports + ``close()``. Used across pool tests so we can pin behavior without + booting the real BigQuery extension.""" + _serial = 0 + + def __init__(self): + type(self)._serial += 1 + self.id = type(self)._serial + self.closed = False + self.executed: list[str] = [] + + def execute(self, sql, params=None): + self.executed.append(sql) + # Liveness probe: SELECT 1 returns something fetchable. + class _Result: + def fetchone(self_inner): + return (1,) + def fetchall(self_inner): + return [(1,)] + return _Result() + + def close(self): + self.closed = True + + +@pytest.fixture +def reset_pool(monkeypatch): + """Reset the BQ session pool singleton between tests so leak-detection + assertions don't carry state.""" + from connectors.bigquery import access as bq_access_mod + if hasattr(bq_access_mod, "_reset_session_pool_for_tests"): + bq_access_mod._reset_session_pool_for_tests() + monkeypatch.setattr( + "connectors.bigquery.auth.get_metadata_token", + lambda: "tok-pool", + ) + yield + if hasattr(bq_access_mod, "_reset_session_pool_for_tests"): + bq_access_mod._reset_session_pool_for_tests() + + +class TestBqSessionPool: + def test_pool_reuses_connections_across_acquires(self, monkeypatch, reset_pool): + """Acquiring a session, releasing, then acquiring again must return + the SAME underlying DuckDB connection — no INSTALL/LOAD overhead on + the second request. This is the whole point of the pool.""" + from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects + + # Each duckdb.connect() yields a fresh _PoolFakeConn so we can tell + # them apart by id. + connections_made = [] + def fake_connect(_path): + c = _PoolFakeConn() + connections_made.append(c) + return c + monkeypatch.setattr("duckdb.connect", fake_connect) + + # First acquire: pool is empty, factory builds a new entry. + with _default_duckdb_session_factory(BqProjects(billing="b", data="d")) as conn1: + id1 = conn1.id + + # Second acquire: pool has a warm entry, must hand back the same conn. + with _default_duckdb_session_factory(BqProjects(billing="b", data="d")) as conn2: + id2 = conn2.id + + assert id1 == id2, ( + "expected the same pooled connection across two acquires; " + f"got id1={id1}, id2={id2}" + ) + # And we must NOT have re-INSTALLed/LOADed the extension on reuse — + # only one duckdb.connect() call ever happened. + assert len(connections_made) == 1, ( + f"pool re-built the conn on second acquire; created {len(connections_made)}" + ) + + def test_pool_size_is_configurable(self, monkeypatch, reset_pool): + """``data_source.bigquery.session_pool_size`` controls the upper + bound on warm entries. Above the cap, releasing extra entries + closes them rather than retaining.""" + from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects + + def fake_get_value(*keys, default=None): + if keys == ("data_source", "bigquery", "session_pool_size"): + return 2 # tiny pool + if keys == ("data_source", "bigquery", "query_timeout_ms"): + return 0 # don't try to SET timeout in tests + return default + + monkeypatch.setattr("app.instance_config.get_value", fake_get_value) + monkeypatch.setattr("duckdb.connect", lambda _: _PoolFakeConn()) + + # Acquire 3 in parallel to force 3 simultaneous entries. + cm1 = _default_duckdb_session_factory(BqProjects(billing="b", data="d")) + c1 = cm1.__enter__() + cm2 = _default_duckdb_session_factory(BqProjects(billing="b", data="d")) + c2 = cm2.__enter__() + cm3 = _default_duckdb_session_factory(BqProjects(billing="b", data="d")) + c3 = cm3.__enter__() + + # Release all three. The 3rd release should close the conn since + # the pool already has 2. + cm1.__exit__(None, None, None) + cm2.__exit__(None, None, None) + cm3.__exit__(None, None, None) + + # At least one of the three connections must be closed (pool overflow). + closed_count = sum(1 for c in (c1, c2, c3) if c.closed) + assert closed_count >= 1, ( + "pool retained more than its configured size; expected at least " + f"one close. closed_count={closed_count}" + ) + # Pool retained at most `size` entries, so total live + closed = 3, + # closed >= 1 means pool size <= 2. + assert closed_count == 1 + + def test_pool_replaces_broken_connection(self, monkeypatch, reset_pool): + """If a pooled entry's liveness check fails on acquire (the + underlying DuckDB conn was closed externally, BQ extension state + corrupted, etc.), the pool must drop it and build a fresh entry — + not hand the broken one to the caller.""" + from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects + + # First acquire creates entry #1; we'll then mark it broken. + all_conns: list[_PoolFakeConn] = [] + def fake_connect(_path): + c = _PoolFakeConn() + all_conns.append(c) + return c + monkeypatch.setattr("duckdb.connect", fake_connect) + + with _default_duckdb_session_factory(BqProjects(billing="b", data="d")) as conn1: + id1 = conn1.id + # Simulate corruption: make execute() raise on next call. + def broken_execute(*a, **kw): + raise RuntimeError("connection broken") + conn1.execute = broken_execute # type: ignore[assignment] + + # Second acquire must skip the broken entry and build a fresh one. + with _default_duckdb_session_factory(BqProjects(billing="b", data="d")) as conn2: + id2 = conn2.id + + assert id1 != id2, ( + f"expected a fresh conn after broken-pool reaper; both acquires " + f"returned id={id1}" + ) + assert len(all_conns) >= 2 + + def test_pool_handles_reentrant_acquires_thread_safe(self, monkeypatch, reset_pool): + """Concurrent acquires from multiple threads must never hand the + same underlying DuckDB conn to two threads at once. The pool's + lock acquires/releases are the load-bearing invariant here. + """ + from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects + + monkeypatch.setattr("duckdb.connect", lambda _: _PoolFakeConn()) + + active_ids: set = set() + active_lock = threading.Lock() + violations: list = [] + + def worker(): + for _ in range(20): + with _default_duckdb_session_factory( + BqProjects(billing="b", data="d"), + ) as conn: + with active_lock: + if conn.id in active_ids: + violations.append(conn.id) + active_ids.add(conn.id) + # Hold briefly to give other threads a chance to race. + time.sleep(0.001) + with active_lock: + active_ids.discard(conn.id) + + import time + threads = [threading.Thread(target=worker) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not violations, ( + f"pool handed the same conn to multiple threads concurrently: " + f"{violations}" + ) + + def test_pool_does_not_apply_when_factory_is_injected(self, monkeypatch, reset_pool): + """Test fixtures that inject a custom ``duckdb_session_factory`` + (e.g. tests/conftest.py's ``bq_access`` fixture) MUST bypass the + pool entirely — otherwise their nullcontext-wrapped fake would + get retained between tests and corrupt downstream assertions. + """ + from connectors.bigquery.access import BqAccess, BqProjects + from contextlib import contextmanager + + sentinel = object() + + @contextmanager + def custom_factory(_projects): + yield sentinel + + bq = BqAccess( + BqProjects(billing="b", data="d"), + duckdb_session_factory=custom_factory, + ) + with bq.duckdb_session() as conn: + assert conn is sentinel diff --git a/tests/test_pull_chunked.py b/tests/test_pull_chunked.py new file mode 100644 index 0000000..87edd17 --- /dev/null +++ b/tests/test_pull_chunked.py @@ -0,0 +1,399 @@ +"""Tests for range-based chunked download in cli/client.py:stream_download. + +Background — the previous diagnosis measured `agnes pull` on a single 5.1 GB +materialized parquet at 0.29 MB/s on a corp VPN with per-flow rate-limiting; +4 parallel range requests over the same connection sustained 1.65 MB/s +aggregate. Existing `AGNES_PULL_PARALLELISM=4` parallelizes across files, +not within a file, so a manifest with 1 large materialized parquet + 10 +remote tables yields 1 active worker = single-stream throughput. + +These tests exercise the chunking code path: HEAD probe, Range-request +splitting, fallback when the server doesn't honor ranges, cleanup on +chunk failure, and the small-file bypass. +""" + +from __future__ import annotations + +import os +import threading +from pathlib import Path +from unittest.mock import patch + +import pytest + + +# ── Fake HTTP layer ───────────────────────────────────────────────────── +# The real httpx Client / AsyncClient surface is large; we mock at the +# client-method level. Our `stream_download` should: +# 1. Call HEAD to learn `content-length` + `accept-ranges`. +# 2. If ranges supported and size > threshold, issue N parallel +# `GET` with `Range: bytes=A-B`, each returning 206 + body chunk. +# 3. Concatenate part files into the destination. + +class _FakeResponse: + def __init__(self, status_code: int, headers: dict | None = None, + body: bytes = b""): + self.status_code = status_code + self.headers = headers or {} + self._body = body + + def raise_for_status(self): + if self.status_code >= 400: + import httpx + raise httpx.HTTPStatusError( + f"HTTP {self.status_code}", request=None, response=self, + ) + + def iter_bytes(self, chunk_size: int = 65536): + # Yield in chunk_size pieces so the sink loop runs realistically. + for i in range(0, len(self._body), chunk_size): + yield self._body[i:i + chunk_size] + + def __enter__(self): + return self + + def __exit__(self, *a): + return False + + +class _FakeClient: + """Captures calls + returns canned responses.""" + + def __init__(self, *, body: bytes, accept_ranges: bool = True, + reject_range_with_200: bool = False, + fail_chunk_indices: tuple[int, ...] = (), + head_status: int = 200): + self._body = body + self._accept_ranges = accept_ranges + self._reject_range_with_200 = reject_range_with_200 + self._fail_chunk_indices = set(fail_chunk_indices) + self._head_status = head_status + self.head_calls = 0 + self.range_calls: list[tuple[int, int]] = [] + self.full_get_calls = 0 + self._lock = threading.Lock() + self._chunk_attempt_counts: dict[tuple[int, int], int] = {} + + # `stream_download` calls `client.head(path)` once to probe. + def head(self, path: str, **kwargs): + with self._lock: + self.head_calls += 1 + if self._head_status >= 400: + return _FakeResponse(self._head_status) + headers = {"content-length": str(len(self._body))} + if self._accept_ranges: + headers["accept-ranges"] = "bytes" + return _FakeResponse(200, headers=headers) + + # `stream_download` uses `client.stream("GET", path, headers=...)` + # for both the chunked and full-file paths. Range header presence + # tells us which one. + def stream(self, method: str, path: str, *, headers: dict | None = None, + **kwargs): + rng = (headers or {}).get("Range") or (headers or {}).get("range") + if rng: + # bytes=START-END + spec = rng.split("=", 1)[1] + start_s, end_s = spec.split("-", 1) + start = int(start_s) + end = int(end_s) + with self._lock: + self.range_calls.append((start, end)) + key = (start, end) + attempt = self._chunk_attempt_counts.get(key, 0) + self._chunk_attempt_counts[key] = attempt + 1 + # Determine chunk index (in order of unique starts). + # We map by start to a stable index for fail-injection. + chunk_idx = self._chunk_index_for_start(start) + # Should this attempt fail? Fail only on first attempt for + # listed indices — retry succeeds. + if chunk_idx in self._fail_chunk_indices and attempt == 0: + import httpx + raise httpx.ReadError("simulated chunk failure") + if self._reject_range_with_200: + # Server ignored Range — returns full body with 200. + return _FakeResponse(200, body=self._body) + piece = self._body[start:end + 1] + return _FakeResponse( + 206, + headers={"content-range": f"bytes {start}-{end}/{len(self._body)}"}, + body=piece, + ) + # Full-file GET (single-stream fallback). + with self._lock: + self.full_get_calls += 1 + return _FakeResponse(200, body=self._body) + + def _chunk_index_for_start(self, start: int) -> int: + # Unique sorted starts so fail_chunk_indices is deterministic. + starts = sorted({s for s, _ in self.range_calls}) + try: + return starts.index(start) + except ValueError: + return -1 + + def __enter__(self): + return self + + def __exit__(self, *a): + return False + + def close(self): + pass + + +# ── Test fixtures ─────────────────────────────────────────────────────── + +@pytest.fixture(autouse=True) +def _isolate_config_dir(tmp_path, monkeypatch): + cfg = tmp_path / "_cfg" + cfg.mkdir() + monkeypatch.setenv("AGNES_CONFIG_DIR", str(cfg)) + + +@pytest.fixture(autouse=True) +def _reset_shared_client(monkeypatch): + """Reset the persistent shared httpx.Client between tests so each + test starts from a known state. Tests that need to inject a fake + client also stub `_get_shared_client` directly via the + `_inject_fake_client` helper below.""" + import cli.client as cc + if hasattr(cc, "_SHARED_CLIENT"): + monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False) + yield + if hasattr(cc, "_SHARED_CLIENT"): + monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False) + + +def _inject_fake_client(monkeypatch, fake): + """Patch both client factories to return the same fake. Tests target + `_get_shared_client` (the path stream_download actually takes) and + also `get_client` so the fallback path also lands on the fake.""" + monkeypatch.setattr("cli.client.get_client", lambda timeout=300.0: fake) + monkeypatch.setattr("cli.client._get_shared_client", + lambda: fake, raising=False) + + +# ── Tests ─────────────────────────────────────────────────────────────── + +def test_chunked_download_success(tmp_path, monkeypatch): + """Server advertises ranges, file is large enough — 4 chunks, assembled + correctly into target.""" + body = bytes(range(256)) * 2048 # 512 KB + threshold = 1024 # 1 KB so 512 KB is "large" + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(threshold)) + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + + fake = _FakeClient(body=body, accept_ranges=True) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.parquet" + progress_bytes = [] + total = stream_download("/api/data/x/download", str(target), + progress_callback=lambda n: progress_bytes.append(n)) + + assert total == len(body) + assert target.read_bytes() == body + # 4 distinct ranges issued (no overlaps; last one carries remainder). + assert len(set(fake.range_calls)) == 4 + assert fake.head_calls == 1 + assert fake.full_get_calls == 0 + # Progress callback was called and total bytes match. + assert sum(progress_bytes) == len(body) + # Chunk parts cleaned up. + leftovers = list(tmp_path.glob("*.part*")) + assert leftovers == [], f"orphan part files: {leftovers}" + + +def test_chunked_download_fallback_when_server_ignores_range( + tmp_path, monkeypatch, +): + """Server returns 200 instead of 206 on the first range probe — abort + chunked path, fall back to single-stream. No corrupt output.""" + body = b"X" * 200_000 + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024") + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + + # accept_ranges=True (HEAD lies), but every Range GET returns 200 + # with the full body — that's the "server ignored Range" path. + fake = _FakeClient(body=body, accept_ranges=True, + reject_range_with_200=True) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + total = stream_download("/api/data/x/download", str(target)) + + assert total == len(body) + assert target.read_bytes() == body + # Fell back to a single full-body GET. + assert fake.full_get_calls >= 1 + + +def test_small_file_uses_single_stream_path(tmp_path, monkeypatch): + """Below threshold → no HEAD probe needed (or HEAD short-circuits), + no Range requests, plain single-stream download.""" + body = b"x" * 500 # tiny + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "10000") # 10 KB + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + + fake = _FakeClient(body=body, accept_ranges=True) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + total = stream_download("/api/data/x/download", str(target)) + + assert total == len(body) + assert target.read_bytes() == body + assert fake.range_calls == [], "small file must not split into ranges" + assert fake.full_get_calls >= 1 + + +def test_chunked_download_no_accept_ranges_falls_back(tmp_path, monkeypatch): + """HEAD doesn't advertise byte-range support → skip chunked path, + plain single-stream.""" + body = b"y" * 200_000 + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024") + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + + fake = _FakeClient(body=body, accept_ranges=False) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + total = stream_download("/api/data/x/download", str(target)) + + assert total == len(body) + assert target.read_bytes() == body + assert fake.range_calls == [] + assert fake.full_get_calls >= 1 + + +def test_chunked_download_one_chunk_retries_then_succeeds( + tmp_path, monkeypatch, +): + """One chunk fails on first attempt; retry path completes the file.""" + body = bytes(range(256)) * 1024 # 256 KB + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024") + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + monkeypatch.setenv("AGNES_STREAM_RETRIES", "2") + + fake = _FakeClient(body=body, accept_ranges=True, + fail_chunk_indices=(1,)) # second chunk blips once + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + total = stream_download("/api/data/x/download", str(target)) + + assert total == len(body) + assert target.read_bytes() == body + # Cleanup of all part files. + assert list(tmp_path.glob("*.part*")) == [] + + +def test_chunked_download_failure_cleans_up_part_files(tmp_path, monkeypatch): + """All retries exhausted on a chunk → no destination file, no orphan + part files.""" + body = b"z" * 200_000 + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024") + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + monkeypatch.setenv("AGNES_STREAM_RETRIES", "0") + + # Inject a permanent failure on chunk 2 (retries=0 → first failure + # is fatal). + class _ChronicFail(_FakeClient): + def stream(self, method, path, *, headers=None, **kwargs): + rng = (headers or {}).get("Range") + if rng: + spec = rng.split("=", 1)[1] + start = int(spec.split("-", 1)[0]) + # Permanently fail the chunk starting at exactly half. + if start >= len(body) // 4 and start <= len(body) // 2: + import httpx + raise httpx.ReadError("permanent") + return super().stream(method, path, headers=headers, **kwargs) + return super().stream(method, path, headers=headers, **kwargs) + + fake = _ChronicFail(body=body, accept_ranges=True) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + with pytest.raises(Exception): + stream_download("/api/data/x/download", str(target)) + + assert not target.exists(), "no destination file after total failure" + # No orphan parts. + assert list(tmp_path.glob("*.part*")) == [] + assert not (tmp_path / "out.bin.tmp").exists() + + +def test_progress_callback_aggregates_across_chunks(tmp_path, monkeypatch): + """The progress callback should fire with byte deltas summing to the + full file across all chunks — caller treats one file as one task.""" + body = bytes(range(256)) * 4096 # 1 MB + monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024") + monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4") + + fake = _FakeClient(body=body, accept_ranges=True) + _inject_fake_client(monkeypatch, fake) + + from cli.client import stream_download + target = tmp_path / "out.bin" + advances = [] + stream_download("/api/data/x/download", str(target), + progress_callback=lambda n: advances.append(n)) + assert sum(advances) == len(body) + + +def test_dead_pid_leftovers_are_reaped(tmp_path, monkeypatch): + """Devil's-advocate R3 finding #1: PID-suffixed `.{pid}.tmp` + and `.partN` files from a SIGKILL'd previous pull must be reaped on + the next pull, otherwise they accumulate on disk indefinitely. + + PID 1 (init) is always alive, so a file with pid=1 must NOT be + reaped. PID 99999999 (~10⁸) is essentially guaranteed not-alive on + any modern Linux/macOS — used as the dead-PID marker. + """ + target = tmp_path / "out.bin" + + # Live-PID leftover (pid=1 = init, always alive). Must NOT be reaped. + live_path = tmp_path / "out.bin.1.tmp" + live_path.write_bytes(b"live process leftover") + + # Dead-PID leftovers — both .tmp and .part0 forms. + dead_tmp = tmp_path / "out.bin.99999999.tmp" + dead_tmp.write_bytes(b"dead process leftover tmp") + dead_part = tmp_path / "out.bin.99999999.part0" + dead_part.write_bytes(b"dead process leftover part") + + # Bare-name leftover (no PID suffix) — pre-existing pattern, NOT + # touched by the new reaper. Reaper only matches `.{digits}.tmp` + # / `.{digits}.partN` exactly. + bare_tmp = tmp_path / "out.bin.tmp" + bare_tmp.write_bytes(b"bare leftover") + + from cli.client import _reap_dead_pid_leftovers + _reap_dead_pid_leftovers(str(target)) + + assert live_path.exists(), "live-PID leftover must be preserved" + assert not dead_tmp.exists(), "dead-PID .tmp must be reaped" + assert not dead_part.exists(), "dead-PID .partN must be reaped" + assert bare_tmp.exists(), "bare-name leftover is out of scope for the reaper" + + +def test_reap_handles_garbage_in_filename(tmp_path): + """Files in the parquet dir whose names happen to glob-match but + don't conform to the PID-suffix shape must be skipped without + raising.""" + target = tmp_path / "out.bin" + weird = tmp_path / "out.bin.garbage.tmp" + weird.write_bytes(b"x") + + from cli.client import _reap_dead_pid_leftovers + # Must not raise even though the filename has no integer PID. + _reap_dead_pid_leftovers(str(target)) + assert weird.exists(), "non-PID-shaped file must not be reaped" diff --git a/tests/test_pull_progress.py b/tests/test_pull_progress.py new file mode 100644 index 0000000..fe31c72 --- /dev/null +++ b/tests/test_pull_progress.py @@ -0,0 +1,123 @@ +"""Tests for `agnes pull` progress UX (Change 3). + +The Rich progress bar handles the TTY case fine, but Claude Code's +SessionStart context — and any hook running `agnes pull` non-interactively — +has stderr connected to a pipe, not a TTY. In that case Rich either +suppresses output entirely or emits raw ANSI noise into the consumer's +log. Goal: when the caller asks for progress and stderr is not a TTY, +emit a plain-text per-10%-or-30s update so the operator gets *some* +signal instead of multi-minute silence. +""" + +from __future__ import annotations + +import io +import time +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +@pytest.fixture(autouse=True) +def _isolate_config_dir(tmp_path, monkeypatch): + cfg = tmp_path / "_cfg" + cfg.mkdir() + monkeypatch.setenv("AGNES_CONFIG_DIR", str(cfg)) + + +@pytest.fixture +def fake_pull_io(monkeypatch): + """Stub the manifest + memory + download endpoints so run_pull can + execute end-to-end with a fake parquet write per table.""" + canned_manifest = { + "tables": { + "tbl_big": {"hash": "h1", "rows": 0, "size_bytes": 1_000_000}, + }, + } + canned_memory = {"mandatory": [], "approved": []} + + def _api_get(path, *args, **kwargs): + resp = MagicMock() + resp.status_code = 200 + if path == "/api/sync/manifest": + resp.json.return_value = canned_manifest + elif path == "/api/memory/bundle": + resp.json.return_value = canned_memory + resp.raise_for_status = lambda: None + return resp + + def _stream_download(path, target_path, progress_callback=None): + # Simulate a chunked download: emit progress in 4 increments + # totaling the announced size. + total = 1_000_000 + slices = [total // 4] * 3 + [total - 3 * (total // 4)] + Path(target_path).write_bytes(b"PAR1" + b"\x00" * 1000 + b"PAR1") + if progress_callback: + for s in slices: + progress_callback(s) + return total + + monkeypatch.setattr("cli.lib.pull.api_get", _api_get, raising=False) + monkeypatch.setattr("cli.lib.pull.stream_download", _stream_download, + raising=False) + monkeypatch.setattr("cli.lib.pull._is_valid_parquet", lambda p: True, + raising=False) + monkeypatch.setattr("cli.lib.pull._file_md5", lambda p: "h1", raising=False) + + +def test_textual_progress_when_stderr_is_not_tty( + tmp_path, fake_pull_io, monkeypatch, capsys, +): + """Non-TTY stderr → emit a plain-text progress line per file.""" + # Force the non-TTY branch even if pytest's fake stderr is a tty. + monkeypatch.setattr("sys.stderr.isatty", lambda: False, raising=False) + + from cli.lib.pull import run_pull + result = run_pull( + server_url="http://x", token="t", workspace=tmp_path, + show_progress=True, + ) + captured = capsys.readouterr() + # Some indication of the file + bytes ran; we don't pin exact format. + assert "tbl_big" in captured.err + assert result.tables_updated == 1 + # No raw ANSI escape sequences in the textual fallback. + assert "\x1b[" not in captured.err.split("tbl_big")[0] + + +def test_no_progress_output_when_show_progress_is_false( + tmp_path, fake_pull_io, monkeypatch, capsys, +): + """`show_progress=False` (the SessionStart hook path) emits no + progress text on stderr in either TTY or non-TTY mode.""" + monkeypatch.setattr("sys.stderr.isatty", lambda: False, raising=False) + + from cli.lib.pull import run_pull + run_pull( + server_url="http://x", token="t", workspace=tmp_path, + show_progress=False, + ) + captured = capsys.readouterr() + assert "tbl_big" not in captured.err + + +def test_textual_progress_emits_at_completion( + tmp_path, fake_pull_io, monkeypatch, capsys, +): + """At least one final completion line gets emitted per file even if + the throttle window doesn't trigger mid-file.""" + monkeypatch.setattr("sys.stderr.isatty", lambda: False, raising=False) + from cli.lib.pull import run_pull + run_pull( + server_url="http://x", token="t", workspace=tmp_path, + show_progress=True, + ) + captured = capsys.readouterr() + # Final line marks the file as done — either "100%" or a "✓ tbl_big" / + # "tbl_big done" indicator. We accept any final-completion form. + assert ( + "100%" in captured.err + or "done" in captured.err.lower() + or "complete" in captured.err.lower() + ) diff --git a/tests/test_pull_shared_client.py b/tests/test_pull_shared_client.py new file mode 100644 index 0000000..113ed3a --- /dev/null +++ b/tests/test_pull_shared_client.py @@ -0,0 +1,85 @@ +"""Tests for the persistent HTTP/2-capable shared client (Change 2). + +`agnes pull` issues N stream_download calls — one per parquet. Without +pooling, each call performs a fresh TLS handshake. The shared client is +created lazily once per process and closed at exit; HTTP/2 (when `h2` is +available) further multiplexes all chunk Range requests over a single +TCP connection. +""" + +from __future__ import annotations + +import pytest + + +@pytest.fixture(autouse=True) +def _isolate_config_dir(tmp_path, monkeypatch): + cfg = tmp_path / "_cfg" + cfg.mkdir() + monkeypatch.setenv("AGNES_CONFIG_DIR", str(cfg)) + # Some dev environments point SSL_CERT_FILE / REQUESTS_CA_BUNDLE at a + # corp-CA bundle that may not exist on every laptop running the test + # suite. Clear those so httpx.Client() construction in the shared- + # client path can build a default SSL context without trying to load + # a missing PEM file. + for var in ("SSL_CERT_FILE", "REQUESTS_CA_BUNDLE", "CURL_CA_BUNDLE"): + monkeypatch.delenv(var, raising=False) + + +@pytest.fixture(autouse=True) +def _reset_shared(monkeypatch): + import cli.client as cc + cc._close_shared_client() + monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False) + yield + cc._close_shared_client() + + +def test_get_shared_client_is_cached(monkeypatch): + """Multiple calls return the same client instance — no fresh TLS + handshake per stream_download invocation.""" + monkeypatch.setenv("AGNES_SERVER", "https://x.example.test") + from cli.client import _get_shared_client + c1 = _get_shared_client() + c2 = _get_shared_client() + assert c1 is c2, "shared client must be a single instance" + + +def test_get_shared_client_falls_back_when_http2_unavailable(monkeypatch): + """If httpx raises during HTTP/2 client construction (e.g. `h2` not + installed in the runtime env), we must gracefully build a HTTP/1.1 + client instead of crashing the pull.""" + import httpx + + monkeypatch.setenv("AGNES_SERVER", "https://x.example.test") + import cli.client as cc + + real_client = httpx.Client + + construction_calls = [] + + def fake_client(*args, **kwargs): + construction_calls.append(kwargs.copy()) + if kwargs.get("http2") is True: + raise ImportError("Using http2=True, but the 'h2' package is not installed") + return real_client(*args, **kwargs) + + monkeypatch.setattr(httpx, "Client", fake_client) + + client = cc._get_shared_client() + assert client is not None + # Two construction attempts: first http2=True (raised), second falls + # back to HTTP/1.1 (no http2 kwarg). + assert construction_calls[0].get("http2") is True + assert construction_calls[1].get("http2") is None or construction_calls[1].get("http2") is False + cc._close_shared_client() + + +def test_close_shared_client_idempotent(monkeypatch): + """Calling close twice (once explicitly, once via atexit) must not + raise.""" + monkeypatch.setenv("AGNES_SERVER", "https://x.example.test") + from cli.client import _get_shared_client, _close_shared_client + _get_shared_client() + _close_shared_client() + _close_shared_client() # second close is a no-op diff --git a/tests/test_query_remote_rewrite.py b/tests/test_query_remote_rewrite.py new file mode 100644 index 0000000..bc84c51 --- /dev/null +++ b/tests/test_query_remote_rewrite.py @@ -0,0 +1,701 @@ +"""Unit tests for ``_rewrite_user_sql_for_bigquery_query``. + +The helper rewrites user SQL referencing query_mode='remote' BigQuery +tables so the entire query ships to BQ via the DuckDB BQ extension's +``bigquery_query(, )`` UDF — engaging WHERE / SELECT / +LIMIT predicate pushdown instead of falling through to ATTACH-catalog +mode (which opens a Storage Read API session over the whole table). + +These tests pin down each conservative-skip rule plus the happy-path +rewrites. Edge cases (CTE shadowing, double-wrap, mixed-source JOIN) +are intentionally explicit so a future refactor doesn't quietly +loosen the guard. +""" +from __future__ import annotations + +import pytest + + +# --------------------------------------------------------------------------- +# Test infrastructure: an in-memory DuckDB seeded with table_registry rows +# matching the shapes the production registry produces. Avoids the full app +# bootstrap path; the rewriter only needs ``conn.execute("SELECT * FROM +# table_registry ...")`` to resolve names. +# --------------------------------------------------------------------------- + + +@pytest.fixture +def seeded_registry(tmp_path, monkeypatch): + """Build a fresh ``system.duckdb`` in tmp_path with the schema migrated. + + Returns the open connection so tests can pass it to the rewriter. + Cleanup is automatic via tmp_path teardown — but we close the + open singleton handle first so a different DATA_DIR in the next + test doesn't see the previous tmp's lock. + """ + from src.db import get_system_db, close_system_db + + monkeypatch.setenv("DATA_DIR", str(tmp_path)) + (tmp_path / "state").mkdir(parents=True, exist_ok=True) + close_system_db() + conn = get_system_db() + yield conn + close_system_db() + + +def _register_bq_remote(conn, *, table_id, name, bucket, source_table): + from src.repositories.table_registry import TableRegistryRepository + TableRegistryRepository(conn).register( + id=table_id, + name=name, + source_type="bigquery", + bucket=bucket, + source_table=source_table, + query_mode="remote", + ) + + +def _register_local(conn, *, table_id, name, source_type="keboola"): + from src.repositories.table_registry import TableRegistryRepository + TableRegistryRepository(conn).register( + id=table_id, + name=name, + source_type=source_type, + bucket="bkt", + source_table=name, + query_mode="local", + ) + + +def _set_bq_project(monkeypatch, project="test-prj", billing=None): + """Stub get_bq_access so the rewriter sees a real-looking project ID. + + `project` configures the data project (used in backtick paths). + `billing` (when provided) configures a different billing project so + cross-project deployments can be exercised; defaults to `project` + for the single-project case.""" + from connectors.bigquery.access import BqAccess, BqProjects, get_bq_access + bq = BqAccess( + BqProjects(billing=billing or project, data=project), + client_factory=lambda projects: object(), + ) + monkeypatch.setattr( + "app.api.query.get_bq_access", + lambda: bq, + raising=False, + ) + get_bq_access.cache_clear() + + +# --------------------------------------------------------------------------- +# Happy-path rewrites +# --------------------------------------------------------------------------- + + +def test_simple_select_where_against_one_bq_table_rewrites(seeded_registry, monkeypatch): + """Single-table SELECT-WHERE against a registered BQ remote row → + full SQL wrapped in ``bigquery_query('project', '')``. + The bare-name reference gets translated to BQ-native backtick form.""" + from app.api.query import _rewrite_user_sql_for_bigquery_query + _register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue", + bucket="fin", source_table="ue") + _set_bq_project(monkeypatch, "test-prj") + + rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query( + "SELECT count(*) FROM ue WHERE event_date = '2026-01-01'", + seeded_registry, + ) + + assert did_rewrite is True + # Outer wrap must be a single bigquery_query() FROM-source. + assert "bigquery_query(" in rewritten + assert "test-prj" in rewritten + # Inner SQL: bare name rewritten to backticked BQ-native path. + assert "`test-prj.fin.ue`" in rewritten + # Inner SQL is dollar-quoted (`$bqq_inner$ ... $bqq_inner$`), so + # single quotes inside the WHERE predicate remain literal — no + # doubling, no backslash escaping. Verifies the safer embedding form + # introduced after the code review caught naive single-quote-only + # escape doubling missing DuckDB backslash sequences. + assert "$bqq_inner$" in rewritten + assert "event_date = '2026-01-01'" in rewritten + + +def test_direct_bq_path_rewrites(seeded_registry, monkeypatch): + """User wrote the direct ``bq."ds"."tbl"`` form. The rewriter must + still translate to BQ-native backtick form before wrapping.""" + from app.api.query import _rewrite_user_sql_for_bigquery_query + _register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue", + bucket="fin", source_table="ue") + _set_bq_project(monkeypatch, "test-prj") + + rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query( + 'SELECT * FROM bq."fin"."ue" LIMIT 10', + seeded_registry, + ) + + assert did_rewrite is True + assert "bigquery_query(" in rewritten + assert "`test-prj.fin.ue`" in rewritten + # Original duckdb-flavor path must NOT remain (it'd parse-fail under BQ). + assert 'bq."fin"."ue"' not in rewritten + + +def test_cte_referencing_bq_table_rewrites_inside_cte(seeded_registry, monkeypatch): + """A WITH clause whose body references a BQ table must rewrite that + inner reference; the wrapping happens at the top level so BQ sees a + valid BQ-flavor CTE.""" + from app.api.query import _rewrite_user_sql_for_bigquery_query + _register_bq_remote(seeded_registry, table_id="bq.fin.orders", name="orders", + bucket="fin", source_table="orders") + _set_bq_project(monkeypatch, "test-prj") + + rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query( + "WITH x AS (SELECT id FROM orders WHERE total > 0) SELECT count(*) FROM x", + seeded_registry, + ) + assert did_rewrite is True + # Inner reference is rewritten. + assert "`test-prj.fin.orders`" in rewritten + # The whole thing is wrapped — bigquery_query is the outermost FROM. + assert rewritten.lower().count("bigquery_query(") == 1 + + +def test_subquery_referencing_bq_table_rewrites(seeded_registry, monkeypatch): + """Subquery in FROM position — same handling as a CTE: rewrite the + inner table reference, wrap the whole at the top.""" + from app.api.query import _rewrite_user_sql_for_bigquery_query + _register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue", + bucket="fin", source_table="ue") + _set_bq_project(monkeypatch, "test-prj") + + rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query( + "SELECT s.cnt FROM (SELECT count(*) AS cnt FROM ue) s", + seeded_registry, + ) + assert did_rewrite is True + assert "`test-prj.fin.ue`" in rewritten + assert rewritten.lower().count("bigquery_query(") == 1 + + +def test_multiple_bq_tables_one_project_combine(seeded_registry, monkeypatch): + """Two registered BQ tables in the same project → single + ``bigquery_query()`` wraps the whole SQL with both refs rewritten + inline. No separate parallel calls.""" + from app.api.query import _rewrite_user_sql_for_bigquery_query + _register_bq_remote(seeded_registry, table_id="bq.fin.orders", name="orders", + bucket="fin", source_table="orders") + _register_bq_remote(seeded_registry, table_id="bq.fin.users", name="users", + bucket="fin", source_table="users") + _set_bq_project(monkeypatch, "test-prj") + + rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query( + "SELECT u.id, count(o.id) " + "FROM users u JOIN orders o ON u.id = o.user_id " + "GROUP BY u.id", + seeded_registry, + ) + assert did_rewrite is True + # Both rewritten. + assert "`test-prj.fin.users`" in rewritten + assert "`test-prj.fin.orders`" in rewritten + # Single wrap. + assert rewritten.lower().count("bigquery_query(") == 1 + + +# --------------------------------------------------------------------------- +# Conservative-skip cases +# --------------------------------------------------------------------------- + + +def test_join_bq_to_local_skips_rewrite(seeded_registry, monkeypatch): + """A JOIN between a BQ table and a local-mode (Keboola/Jira) table + is a cross-source query — wrapping it in bigquery_query() would lose + the local table. The rewriter must fall through to the ATTACH-catalog + path (slow but correct). + """ + from app.api.query import _rewrite_user_sql_for_bigquery_query + _register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue", + bucket="fin", source_table="ue") + _register_local(seeded_registry, table_id="kbc.in.local_orders", + name="local_orders") + _set_bq_project(monkeypatch, "test-prj") + + user_sql = ( + "SELECT u.id, lo.total " + "FROM ue u JOIN local_orders lo ON u.id = lo.user_id" + ) + rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query( + user_sql, seeded_registry, + ) + assert did_rewrite is False + assert rewritten == user_sql # untouched + + +def test_no_bq_tables_passes_through(seeded_registry, monkeypatch): + """User SQL referencing only local-source tables → no rewrite, + no log spam, original SQL returned.""" + from app.api.query import _rewrite_user_sql_for_bigquery_query + _register_local(seeded_registry, table_id="kbc.in.orders", name="orders") + _set_bq_project(monkeypatch, "test-prj") + + user_sql = "SELECT * FROM orders WHERE id = 1" + rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query( + user_sql, seeded_registry, + ) + assert did_rewrite is False + assert rewritten == user_sql + + +def test_already_contains_bigquery_query_passes_through(seeded_registry, monkeypatch): + """User SQL already calls bigquery_query() — never double-wrap. + + Note: the /api/query endpoint blocks ``bigquery_query`` in user SQL + via the keyword denylist, so this scenario can't reach the rewriter + in production today. Defensive guard for callers from other paths. + """ + from app.api.query import _rewrite_user_sql_for_bigquery_query + _register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue", + bucket="fin", source_table="ue") + _set_bq_project(monkeypatch, "test-prj") + + user_sql = ( + "SELECT * FROM bigquery_query('test-prj', 'SELECT * FROM `test-prj.fin.ue`')" + ) + rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query( + user_sql, seeded_registry, + ) + assert did_rewrite is False + assert rewritten == user_sql + + +def test_unconfigured_bq_project_skips(seeded_registry, monkeypatch): + """If get_bq_access() is the not-configured sentinel (data=''), + don't rewrite — there's no project to fill into bigquery_query().""" + from app.api.query import _rewrite_user_sql_for_bigquery_query + _register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue", + bucket="fin", source_table="ue") + + # Override to sentinel (empty data project). + from connectors.bigquery.access import BqAccess, BqProjects, get_bq_access + monkeypatch.setattr( + "app.api.query.get_bq_access", + lambda: BqAccess(BqProjects(billing="", data="")), + raising=False, + ) + get_bq_access.cache_clear() + + user_sql = "SELECT * FROM ue" + rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query( + user_sql, seeded_registry, + ) + assert did_rewrite is False + assert rewritten == user_sql + + +# --------------------------------------------------------------------------- +# Backwards-compat: dry-run helper still available + behaves the same +# --------------------------------------------------------------------------- + + +def test_existing_dry_run_helper_still_callable(): + """The original ``_rewrite_user_sql_for_bq_dry_run`` is now a thin + wrapper around the shared core rewriter (Pass 1 + Pass 2). Callers + that pass an explicit ``project`` argument keep working unchanged. + """ + from app.api.query import _rewrite_user_sql_for_bq_dry_run + + rewritten = _rewrite_user_sql_for_bq_dry_run( + sql="SELECT * FROM ue", + name_lookups=[("ue", "fin", "ue")], + project="some-prj", + ) + assert "`some-prj.fin.ue`" in rewritten + # The dry-run helper does NOT add a bigquery_query() wrapper; that's + # only the new execution-path helper's job. + assert "bigquery_query(" not in rewritten + + +# --------------------------------------------------------------------------- +# End-to-end: the /api/query handler must invoke the rewriter and execute +# the rewritten SQL (not the original) when there's a BQ-remote table. +# --------------------------------------------------------------------------- + + +def _auth(token: str) -> dict: + return {"Authorization": f"Bearer {token}"} + + +def _register_bq_remote_row(name: str, bucket: str, source_table: str) -> None: + from src.db import get_system_db + from src.repositories.table_registry import TableRegistryRepository + sys_conn = get_system_db() + try: + TableRegistryRepository(sys_conn).register( + id=f"bq.{bucket}.{source_table}", + name=name, + source_type="bigquery", + bucket=bucket, + source_table=source_table, + query_mode="remote", + ) + finally: + sys_conn.close() + + +@pytest.fixture +def stub_bq_for_endpoint(monkeypatch): + """Stub _bq_dry_run_bytes + get_bq_access at the endpoint level so the + cap-guard sees a real-looking BQ project but doesn't issue real RPCs. + """ + monkeypatch.setattr( + "app.api.query._bq_dry_run_bytes", + lambda *a, **k: 1024, # tiny — pass cap + raising=False, + ) + + class _FakeProjects: + data = "test-data-prj" + billing = "test-billing-prj" + + class _FakeBqAccess: + projects = _FakeProjects() + + monkeypatch.setattr( + "app.api.query.get_bq_access", + lambda: _FakeBqAccess(), + raising=False, + ) + + +def test_endpoint_executes_rewritten_sql_against_analytics( + seeded_app, stub_bq_for_endpoint, monkeypatch, +): + """The /api/query handler must call ``analytics.execute(rewritten_sql)`` + — NOT the user's original SQL — when a BQ-remote table is referenced. + Capture what reaches DuckDB and assert the bigquery_query() wrap is + present. + """ + _register_bq_remote_row("ue", "fin", "ue") + + # Capture analytics.execute calls. The handler does + # `analytics = get_analytics_db_readonly(); analytics.execute(sql)`, + # so we patch the connection factory to return a stub. + captured = {"sql": None} + + class _StubAnalytics: + description = [("c0",)] + def execute(self, sql, *args, **kwargs): + captured["sql"] = sql + class _R: + def fetchmany(self, _n): + return [] + return _R() + def close(self): + pass + + monkeypatch.setattr( + "app.api.query.get_analytics_db_readonly", + lambda: _StubAnalytics(), + raising=False, + ) + + c = seeded_app["client"] + token = seeded_app["admin_token"] + r = c.post( + "/api/query", + json={"sql": "SELECT count(*) FROM ue WHERE country = 'CZ'"}, + headers=_auth(token), + ) + assert r.status_code == 200, r.json() + sent = captured["sql"] + assert sent is not None, "analytics.execute was never called" + assert "bigquery_query(" in sent, ( + f"endpoint did not wrap user SQL in bigquery_query(); sent: {sent!r}" + ) + assert "test-data-prj" in sent + assert "`test-data-prj.fin.ue`" in sent + + +def test_endpoint_passes_original_sql_when_no_bq_table( + seeded_app, stub_bq_for_endpoint, monkeypatch, +): + """For queries that don't touch any BQ-remote registered name, the + handler must pass the original SQL through unchanged — the + ATTACH-catalog path handles local-source tables natively and any + rewrite would be wasted work.""" + captured = {"sql": None} + + class _StubAnalytics: + description = [("c0",)] + def execute(self, sql, *args, **kwargs): + captured["sql"] = sql + class _R: + def fetchmany(self, _n): + return [] + return _R() + def close(self): + pass + + monkeypatch.setattr( + "app.api.query.get_analytics_db_readonly", + lambda: _StubAnalytics(), + raising=False, + ) + + c = seeded_app["client"] + token = seeded_app["admin_token"] + user_sql = "SELECT 1 AS x" + r = c.post("/api/query", json={"sql": user_sql}, headers=_auth(token)) + assert r.status_code == 200, r.json() + assert captured["sql"] == user_sql + assert "bigquery_query(" not in captured["sql"] + + +def test_endpoint_wraps_rewritten_sql_with_outer_limit( + seeded_app, stub_bq_for_endpoint, monkeypatch, +): + """Memory-safety regression — when the rewriter fires, the handler + MUST wrap the bigquery_query() call in an outer ``LIMIT N+1`` so a + `SELECT *` against a billion-row remote table doesn't materialise the + full result into the worker before fetchmany applies the cap. + Code-review #2a fix. + """ + _register_bq_remote_row("ue", "fin", "ue") + + captured = {"sql": None} + + class _StubAnalytics: + description = [("c0",)] + def execute(self, sql, *args, **kwargs): + captured["sql"] = sql + class _R: + def fetchmany(self, _n): + return [] + return _R() + def close(self): + pass + + monkeypatch.setattr( + "app.api.query.get_analytics_db_readonly", + lambda: _StubAnalytics(), + raising=False, + ) + + c = seeded_app["client"] + token = seeded_app["admin_token"] + r = c.post( + "/api/query", + json={"sql": "SELECT * FROM ue", "limit": 100}, + headers=_auth(token), + ) + assert r.status_code == 200, r.json() + sent = captured["sql"] + # The bigquery_query() wrap is present, AND the whole thing is wrapped + # again with an outer LIMIT that includes the user-requested cap +1 + # (the +1 is the existing truncation-detection pattern). + assert "bigquery_query(" in sent + assert "_bqq_outer" in sent + assert "LIMIT 101" in sent # request.limit (100) + 1 + + +def test_endpoint_falls_back_to_original_sql_on_bq_parse_error( + seeded_app, stub_bq_for_endpoint, monkeypatch, +): + """When the rewritten ``bigquery_query()`` path fails with a parse- + level error (e.g. user SQL contained DuckDB-only syntax that BQ + can't parse), the handler MUST retry with the original SQL via the + ATTACH-catalog path so the user request still succeeds. Code-review + #4 fix. + """ + _register_bq_remote_row("ue", "fin", "ue") + + calls = {"sqls": []} + + class _StubAnalytics: + description = [("c0",)] + def execute(self, sql, *args, **kwargs): + calls["sqls"].append(sql) + # First call (rewritten) raises a BQ-style parse error; + # second call (original SQL fallback) returns rows. + if "bigquery_query(" in sql: + raise RuntimeError( + "BinderException: Query execution failed: " + "Syntax error: Unexpected token at [1:42]" + ) + class _R: + def fetchmany(self, _n): + return [(1,)] + return _R() + def close(self): + pass + + monkeypatch.setattr( + "app.api.query.get_analytics_db_readonly", + lambda: _StubAnalytics(), + raising=False, + ) + + c = seeded_app["client"] + token = seeded_app["admin_token"] + r = c.post( + "/api/query", + # DuckDB-only ::INT cast — survives identifier rewrite, BQ refuses. + json={"sql": "SELECT (count(*))::INT FROM ue"}, + headers=_auth(token), + ) + assert r.status_code == 200, r.json() + # Two execute calls: 1) rewritten (raised) 2) fallback to original. + assert len(calls["sqls"]) == 2 + assert "bigquery_query(" in calls["sqls"][0] + assert calls["sqls"][1] == "SELECT (count(*))::INT FROM ue" + + +def test_rewriter_uses_billing_project_for_bigquery_query_first_arg( + seeded_registry, monkeypatch, +): + """Devin-review BUG #1: `bigquery_query()` first arg is the + **billing** project (where BQ jobs are billed + executed), backtick + paths use the **data** project. In cross-project deploys the SA + has `serviceusage.services.use` only on the billing project, so + using the data project as billing → 403 USER_PROJECT_DENIED. + + Match the existing convention in v2_scan / v2_sample / v2_schema / + extractor. + """ + from app.api.query import _rewrite_user_sql_for_bigquery_query + _register_bq_remote( + seeded_registry, table_id="bq.fin.ue", name="ue", + bucket="fin", source_table="ue", + ) + _set_bq_project(monkeypatch, project="data-prj", billing="billing-prj") + + rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query( + "SELECT count(*) FROM ue", + seeded_registry, + ) + assert did_rewrite is True + # First arg of bigquery_query must be the billing project. + assert "bigquery_query('billing-prj'" in rewritten + # Backtick path must use the data project. + assert "`data-prj.fin.ue`" in rewritten + # And the data project must NOT appear as the first arg. + assert "bigquery_query('data-prj'" not in rewritten + + +def test_rewriter_skips_when_bq_row_bucket_contains_dot( + seeded_registry, monkeypatch, +): + """Devil's-advocate R1 finding #5: a BQ row whose `bucket` contains + `.` suggests the operator encoded a project prefix in the bucket + name. Wrapping under our single-project assumption could silently + target the wrong project. Rewriter must skip in that case (fall + through to ATTACH-catalog path which respects the operator's + `_remote_attach` configuration). + """ + from app.api.query import _rewrite_user_sql_for_bigquery_query + _register_bq_remote( + seeded_registry, + table_id="bq.other-prj.dataset.ue", + name="ue", + # Project-qualified bucket — the multi-project red flag. + bucket="other-prj.dataset", + source_table="ue", + ) + _set_bq_project(monkeypatch, "test-prj") + + rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query( + "SELECT count(*) FROM ue", + seeded_registry, + ) + # Skip — original SQL returned, no rewrite. + assert did_rewrite is False + assert rewritten == "SELECT count(*) FROM ue" + + +def test_fallback_does_not_trigger_on_user_column_typo( + seeded_app, stub_bq_for_endpoint, monkeypatch, +): + """Devil's-advocate R1 finding #2: previously the fallback + heuristic matched `Unrecognized name`, which BQ surfaces for both + DuckDB-only-name AND user-column-typo cases. The user-typo case + triggered re-running the original SQL through the slow ATTACH- + catalog path (90+ s) → 2× latency tax on every typo. + + Post-fix: heuristic only matches `Syntax error`. A BQ-side + `Unrecognized name: bad_col` should propagate as-is, NOT trigger + a fallback retry. + """ + _register_bq_remote_row("ue", "fin", "ue") + + calls = {"sqls": []} + + class _StubAnalytics: + description = [("c0",)] + def execute(self, sql, *args, **kwargs): + calls["sqls"].append(sql) + raise RuntimeError( + "BinderException: Query execution failed: " + "Unrecognized name: bad_col at [1:8]" + ) + def close(self): + pass + + monkeypatch.setattr( + "app.api.query.get_analytics_db_readonly", + lambda: _StubAnalytics(), + raising=False, + ) + + c = seeded_app["client"] + token = seeded_app["admin_token"] + r = c.post( + "/api/query", + json={"sql": "SELECT bad_col FROM ue"}, + headers=_auth(token), + ) + # Error propagates; fallback NOT triggered (only one execute call). + assert r.status_code in (400, 500, 502) + assert len(calls["sqls"]) == 1, ( + "user column typo must NOT trigger fallback retry" + ) + assert "bigquery_query(" in calls["sqls"][0] + + +def test_endpoint_does_not_fall_back_on_non_parse_errors( + seeded_app, stub_bq_for_endpoint, monkeypatch, +): + """Non-parse-error exceptions from the rewritten path (network, + quota, forbidden, generic runtime) must propagate, NOT silently + retry against the legacy path. Otherwise the legacy path would + just fail again and the user sees a slow + double-failure. + """ + _register_bq_remote_row("ue", "fin", "ue") + + calls = {"sqls": []} + + class _StubAnalytics: + description = [("c0",)] + def execute(self, sql, *args, **kwargs): + calls["sqls"].append(sql) + raise RuntimeError("Network unreachable: BQ endpoint timed out") + def close(self): + pass + + monkeypatch.setattr( + "app.api.query.get_analytics_db_readonly", + lambda: _StubAnalytics(), + raising=False, + ) + + c = seeded_app["client"] + token = seeded_app["admin_token"] + r = c.post( + "/api/query", + json={"sql": "SELECT count(*) FROM ue"}, + headers=_auth(token), + ) + # Generic 400 from the handler's outer except — body will surface + # the runtime error message; we just need to confirm no fallback. + assert r.status_code in (400, 500, 502) + assert len(calls["sqls"]) == 1, "must not retry on non-parse error"