Merge pull request #199 from keboola/zs/perf-bundle-0.39.0
perf(0.39.0): bundle — BQ query rewrite + session pool + chunked download + HTTP/2
This commit is contained in:
commit
6de7084c9f
12 changed files with 2976 additions and 91 deletions
74
CHANGELOG.md
74
CHANGELOG.md
|
|
@ -10,6 +10,80 @@ CalVer image tags (`stable-YYYY.MM.N`, `dev-YYYY.MM.N`) are produced for every C
|
|||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.39.0] — 2026-05-06
|
||||
|
||||
### Performance
|
||||
- **`/api/query` (and `agnes query --remote`) now rewrites user SQL referencing
|
||||
`query_mode='remote'` BigQuery rows into a single `bigquery_query()` call
|
||||
before execute** (`app/api/query.py`). Pre-fix the master view
|
||||
(`CREATE VIEW <name> AS SELECT * FROM bigquery.<bucket>.<source_table>`) did
|
||||
not push WHERE / SELECT / LIMIT into BQ — the DuckDB BQ extension opened a
|
||||
Storage Read API session over the entire upstream table, scanning the full
|
||||
partitioned dataset before the local DuckDB filter ran. On 100M+ row
|
||||
remote-mode tables this was 50-100× slower than the equivalent direct
|
||||
`bigquery_query()` call (70-150 s vs 1.5 s) and frequently failed with
|
||||
`Response too large to return`. The rewriter (shared core with the existing
|
||||
dry-run helper) wraps the user's whole SQL in `bigquery_query('<project>',
|
||||
'<inner-sql>')` so the BQ planner receives the full query and applies
|
||||
partition pruning + projection pushdown server-side. Conservative
|
||||
fall-through: cross-source JOINs (BQ ↔ Keboola/Jira local), queries already
|
||||
containing `bigquery_query(`, and unconfigured BQ project all keep the
|
||||
original ATTACH-catalog path so behavior degrades gracefully.
|
||||
- **DuckDB BigQuery-extension session pool**
|
||||
(`connectors/bigquery/access.py`). `BqAccess.duckdb_session()` now acquires
|
||||
pre-warmed connections from a bounded process-local pool instead of running
|
||||
`INSTALL bigquery; LOAD bigquery; CREATE SECRET; ATTACH …` on every request.
|
||||
Each acquire saves the ~0.5 s extension-load + secret-creation cost when
|
||||
the pool has a warm entry; auth SECRET is refreshed on acquire so a
|
||||
long-lived pooled entry doesn't keep a stale GCE metadata token past its
|
||||
TTL. Pool size is configurable via `data_source.bigquery.session_pool_size`
|
||||
(default 4; sentinel `0` disables pooling). Affects every BQ-touching path
|
||||
— `/api/query`, `/api/v2/scan`, `/api/v2/sample`, `/api/v2/schema`,
|
||||
materialize, and the orchestrator's remote-attach.
|
||||
- **`agnes pull` chunked download for large parquets**: when the server
|
||||
advertises `accept-ranges: bytes` and a parquet exceeds
|
||||
`AGNES_PULL_CHUNK_THRESHOLD_BYTES` (default 50 MB), the CLI now splits
|
||||
the file into N parallel HTTP Range requests
|
||||
(`AGNES_PULL_CHUNK_PARALLELISM`, default 4, capped 1..16) and assembles
|
||||
the parts into the destination atomically. Targets the per-flow-shaped
|
||||
network (corp VPN with per-TCP-connection rate-limiting) where a single
|
||||
stream is throttled but N parallel streams over the same connection
|
||||
scale roughly linearly. Falls back to single-stream when the server
|
||||
responds 200 instead of 206 to a Range probe, when no
|
||||
`accept-ranges: bytes` is advertised, or when content is below the
|
||||
threshold — no behavior change in the small-file / non-cooperating-
|
||||
server cases.
|
||||
- **Persistent HTTP/2 client across `agnes pull`**: `stream_download` now
|
||||
routes through a process-wide pooled `httpx.Client` so N parquet
|
||||
downloads share a single TLS handshake; HTTP/2 multiplexing
|
||||
(when the optional `h2` package is installed) lets all chunk Range
|
||||
requests share one TCP connection. Gracefully falls back to HTTP/1.1
|
||||
pooling when `h2` is missing — no crash, just slightly less benefit.
|
||||
|
||||
### Fixed
|
||||
- **BigQuery `responseTooLarge` no longer surfaces as a generic 400 / 502 with
|
||||
the raw upstream message** (`connectors/bigquery/access.py`). The
|
||||
`translate_bq_error` helper now classifies "Response too large to return"
|
||||
errors via a dedicated `bq_response_too_large` kind (HTTP 400) with an
|
||||
actionable hint pointing at the WHERE / aggregation / materialized-table
|
||||
remediations. Pre-fix this failure mode fell through to the generic
|
||||
`bq_bad_request` mapping, which implied the user's SQL had a syntax error
|
||||
— wrong root cause. Affects every BQ-touching path (`/api/query`,
|
||||
`/api/v2/scan`, `/api/v2/sample`, `/api/v2/schema`, materialize) since
|
||||
they all share `translate_bq_error`.
|
||||
|
||||
### Added
|
||||
- New optional dependency `h2>=4.1.0` (HTTP/2 transport for httpx). Pure
|
||||
performance — `agnes pull` works on HTTP/1.1 if the install skips it.
|
||||
- **Textual progress fallback for non-TTY `agnes pull`**: when stderr is
|
||||
not a terminal (Claude Code SessionStart hook, CI runner, Docker log
|
||||
capture, …), `agnes pull --no-quiet` now emits a plain-text progress
|
||||
line per file at most every 10% or 30 s, plus a final completion line.
|
||||
Replaces the previous Rich-bar-on-pipe behavior that either suppressed
|
||||
output entirely or leaked ANSI escape sequences. TTY path unchanged
|
||||
(Rich progress bar with bytes / speed / ETA, aggregated per-file
|
||||
across chunked-download chunks).
|
||||
|
||||
## [0.38.3] — 2026-05-06
|
||||
|
||||
### Changed
|
||||
|
|
|
|||
319
app/api/query.py
319
app/api/query.py
|
|
@ -31,6 +31,50 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
router = APIRouter(prefix="/api/query", tags=["query"])
|
||||
|
||||
|
||||
# Heuristic: did the BQ-side execution of a `bigquery_query()`-rewritten
|
||||
# query reject the inner SQL because of a **DuckDB-vs-BQ dialect mismatch**
|
||||
# specifically? We want to fall back ONLY on cases where the same SQL
|
||||
# would have worked under the legacy DuckDB ATTACH-catalog path —
|
||||
# DuckDB-only syntax (``::INT`` casts, ``STRPTIME``, COALESCE arity quirks)
|
||||
# that BQ's parser rejects.
|
||||
#
|
||||
# We DO NOT want to fall back on user-data errors that BQ would reject in
|
||||
# either path (unknown column name, wrong function signature, invalid cast
|
||||
# of literal user input). For those, the legacy ATTACH path would issue
|
||||
# the same query and fail the same way — just 50-100× slower. Triggering
|
||||
# fallback there is a 2× latency tax on every typo (devil's-advocate R1
|
||||
# finding #2).
|
||||
#
|
||||
# Conservative pattern set: only the BQ-emitted ``Syntax error: <detail>``
|
||||
# (with trailing colon) covers genuine parse-level dialect mismatch.
|
||||
# ``Unrecognized name`` etc. surface for both bad-user-column AND
|
||||
# DuckDB-only-name cases — the safe assumption is that user-column-typo
|
||||
# is the more common case, so we don't fall back. If a deployment
|
||||
# surfaces a real DuckDB-only-name regression, it's better caught as
|
||||
# a BinderException with the original SQL in the logs than amplified
|
||||
# via slow-path retry.
|
||||
#
|
||||
# The trailing colon (devil's-advocate R2 finding #3) anchors the match
|
||||
# against BQ's verbatim error format and avoids false positives where
|
||||
# the literal substring `Syntax error` appears in a user's SQL string
|
||||
# literal that DuckDB then echoes back in an unrelated error message
|
||||
# (e.g. `WHERE log_msg = 'Syntax error in foo'` failing on quota).
|
||||
_BQ_REWRITE_PARSE_ERROR_PATTERNS = (
|
||||
"Syntax error: ",
|
||||
"syntax error: ",
|
||||
)
|
||||
|
||||
|
||||
def _looks_like_bq_rewrite_parse_error(exc: BaseException) -> bool:
|
||||
"""Return True when ``exc`` is the BQ-rejected-inner-SQL flavour we
|
||||
want to fall back from. Conservative: matches against the exception
|
||||
message text only, no isinstance checks, so it works whether the
|
||||
DuckDB BQ extension wrapped the error as BinderException, IOException,
|
||||
or a plain Python Exception."""
|
||||
msg = str(exc)
|
||||
return any(pat in msg for pat in _BQ_REWRITE_PARSE_ERROR_PATTERNS)
|
||||
|
||||
# Issue #160 §4.3.1 — direct `bq.<dataset>.<source_table>` references in user
|
||||
# SQL. Catalog token accepts both `bq` (the unquoted DuckDB-style name) and
|
||||
# `"bq"` (quoted identifier). DuckDB resolves both to the same ATTACHed
|
||||
|
|
@ -192,8 +236,64 @@ def execute_query(
|
|||
else contextlib.nullcontext()
|
||||
)
|
||||
with guard:
|
||||
# Open in read-only mode for extra safety
|
||||
# Performance fix: rewrite user SQL referencing BQ-remote tables
|
||||
# to a single ``bigquery_query()`` call so WHERE / projection /
|
||||
# LIMIT push into BQ via jobs.query (1-2 s) instead of falling
|
||||
# through DuckDB's ATTACH-catalog Storage Read API session over
|
||||
# the full table (often 70-150 s, fails with "Response too
|
||||
# large to return" on >100M-row sources). Helper returns the
|
||||
# original SQL unchanged when rewriting would be unsafe
|
||||
# (cross-source JOIN, no BQ tables referenced, double-wrap).
|
||||
execution_sql, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
request.sql, conn,
|
||||
)
|
||||
if did_rewrite:
|
||||
# Memory-safety: ``bigquery_query()`` materialises the entire
|
||||
# BQ result into DuckDB before fetchmany sees it (vs the
|
||||
# ATTACH-catalog Storage Read API path, which streams rows
|
||||
# lazily). Wrap the rewritten SQL in an outer ``LIMIT N+1``
|
||||
# so a `SELECT *` against a billion-row remote table doesn't
|
||||
# buffer the full table into the worker process — the cap
|
||||
# is pushed into the BQ job itself. Aliased subquery so the
|
||||
# outer LIMIT applies to the final rewritten result.
|
||||
execution_sql = (
|
||||
f"SELECT * FROM ({execution_sql}) AS _bqq_outer "
|
||||
f"LIMIT {request.limit + 1}"
|
||||
)
|
||||
logger.info(
|
||||
"query_rewrite_to_bigquery_query: user_id=%s — wrapped "
|
||||
"SQL in bigquery_query() with outer LIMIT for BQ "
|
||||
"predicate pushdown",
|
||||
user_id,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"query_rewrite_skipped: user_id=%s — running original "
|
||||
"SQL via ATTACH-catalog path",
|
||||
user_id,
|
||||
)
|
||||
|
||||
# Open in read-only mode for extra safety. If the rewritten
|
||||
# path errors (e.g. user SQL contained DuckDB-only syntax —
|
||||
# ``::INT`` casts, ``STRPTIME``, COALESCE arity differences —
|
||||
# that survives identifier rewrite but BQ refuses), fall back
|
||||
# to the original SQL via the legacy ATTACH-catalog path so
|
||||
# the request still succeeds (slower, but correct). Same
|
||||
# safety contract as the dry-run fallback in
|
||||
# ``_bq_quota_and_cap_guard``.
|
||||
try:
|
||||
result = analytics.execute(execution_sql).fetchmany(request.limit + 1)
|
||||
except Exception as exc:
|
||||
if did_rewrite and _looks_like_bq_rewrite_parse_error(exc):
|
||||
logger.warning(
|
||||
"query_rewrite_fallback: user_id=%s — bigquery_query() "
|
||||
"rewrite rejected by BQ (%s); retrying via "
|
||||
"ATTACH-catalog path",
|
||||
user_id, type(exc).__name__,
|
||||
)
|
||||
result = analytics.execute(request.sql).fetchmany(request.limit + 1)
|
||||
else:
|
||||
raise
|
||||
columns = [desc[0] for desc in analytics.description] if analytics.description else []
|
||||
truncated = len(result) > request.limit
|
||||
rows = result[:request.limit]
|
||||
|
|
@ -432,12 +532,11 @@ def _bq_guardrail_inputs(
|
|||
return dry_run, name_lookups, None
|
||||
|
||||
|
||||
def _rewrite_user_sql_for_bq_dry_run(
|
||||
def _rewrite_bq_table_refs_to_native(
|
||||
sql: str, name_lookups: list, project: str,
|
||||
) -> str:
|
||||
"""Rewrite user SQL from DuckDB-flavor to BQ-native so a single
|
||||
`_bq_dry_run_bytes` call can estimate scan size for the EXACT query
|
||||
the user submitted (issue #171).
|
||||
"""Core identifier rewrite: DuckDB-flavor table references → BQ-native
|
||||
backtick form. Shared between dry-run and execution-path rewriters.
|
||||
|
||||
Two transformations:
|
||||
|
||||
|
|
@ -463,13 +562,15 @@ def _rewrite_user_sql_for_bq_dry_run(
|
|||
inside a string literal (e.g. an `IN (...)` value or a `LIKE` pattern)
|
||||
will also be rewritten. This is acceptable because (a) it's vanishingly
|
||||
rare to have a string literal exactly matching a registered table name,
|
||||
and (b) when it does happen the dry-run errors out and the caller falls
|
||||
back to the per-table SELECT * estimate (current behavior, no regression).
|
||||
and (b) when it does happen the caller's error path covers the case
|
||||
(dry-run falls back to per-table SELECT * estimate; execution falls
|
||||
through to the ATTACH-catalog path).
|
||||
|
||||
CTE shadowing: a `WITH unit_economics AS (...)` followed by `FROM
|
||||
unit_economics` would also rewrite the `FROM` reference. BQ then treats
|
||||
the CTE as unreferenced (legal) and the dry-run estimates the rewritten
|
||||
physical table — likely an over-estimate. Same fallback path covers this.
|
||||
the CTE as unreferenced (legal) and the rewriter's caller deals with
|
||||
the consequence — over-estimation for dry-run, fall-through-to-ATTACH
|
||||
via BQ parse error for execution.
|
||||
"""
|
||||
out = sql
|
||||
|
||||
|
|
@ -509,6 +610,206 @@ def _rewrite_user_sql_for_bq_dry_run(
|
|||
return out
|
||||
|
||||
|
||||
def _rewrite_user_sql_for_bq_dry_run(
|
||||
sql: str, name_lookups: list, project: str,
|
||||
) -> str:
|
||||
"""Rewrite user SQL from DuckDB-flavor to BQ-native so a single
|
||||
`_bq_dry_run_bytes` call can estimate scan size for the EXACT query
|
||||
the user submitted (issue #171). Thin wrapper around the shared
|
||||
core; kept as a stable name for callers in /api/query's cap-guard.
|
||||
"""
|
||||
return _rewrite_bq_table_refs_to_native(sql, name_lookups, project)
|
||||
|
||||
|
||||
def _rewrite_user_sql_for_bigquery_query(
|
||||
user_sql: str, conn: duckdb.DuckDBPyConnection,
|
||||
) -> tuple[str, bool]:
|
||||
"""Rewrite user SQL so the entire query ships to BQ as a single
|
||||
``bigquery_query(<project>, <inner-sql>)`` call.
|
||||
|
||||
Returns ``(rewritten_sql, did_rewrite)``. When ``did_rewrite`` is
|
||||
``False``, the caller MUST execute the original ``user_sql`` via the
|
||||
ATTACH-catalog path (slow but correct); the rewriter is conservative
|
||||
on purpose — wrapping cross-source queries in ``bigquery_query()``
|
||||
would silently lose the local-side data.
|
||||
|
||||
Why this matters
|
||||
----------------
|
||||
The orchestrator's master view (``CREATE VIEW name AS SELECT * FROM
|
||||
bigquery.<bucket>.<source_table>``) does not push WHERE / projections
|
||||
into BQ when DuckDB resolves the query — the BQ extension opens a
|
||||
Storage Read API session over the entire table, which on multi-100M-row
|
||||
tables is 50-100× slower than letting BQ run the query server-side.
|
||||
Wrapping the user's SQL in ``bigquery_query('<project>', '<inner>')``
|
||||
makes the BQ extension issue a ``jobs.query`` instead, with full
|
||||
predicate pushdown.
|
||||
|
||||
Skip rules (returns ``(user_sql, False)``)
|
||||
------------------------------------------
|
||||
1. No registered ``query_mode='remote'`` BQ row referenced in the SQL.
|
||||
Nothing to rewrite — original SQL passes through unchanged.
|
||||
2. User SQL already contains ``bigquery_query(`` — never double-wrap.
|
||||
(The /api/query keyword denylist also blocks this in production;
|
||||
defensive guard for callers in other contexts.)
|
||||
3. SQL also references a non-BQ master view (Keboola/Jira local-mode
|
||||
table). Wrapping would lose those references — fall through to
|
||||
ATTACH-catalog so the cross-source query still runs.
|
||||
4. ``get_bq_access()`` returns the unconfigured sentinel
|
||||
(``data == ''``). No project to fill into ``bigquery_query()``.
|
||||
|
||||
Edge cases preserved by design
|
||||
------------------------------
|
||||
- CTEs / sub-queries referencing BQ tables: the table-name rewrite
|
||||
happens at every match position, then the whole SQL is wrapped in
|
||||
one ``bigquery_query()``. BQ supports CTEs, so this works.
|
||||
- Multiple BQ tables, same project: combined into ONE wrap (single
|
||||
jobs.query). DuckDB's BQ extension doesn't support multi-project
|
||||
JOINs in a single ``bigquery_query()`` call today; if/when the
|
||||
registry grows per-table source_project, this helper would need to
|
||||
gate on cross-project mixing.
|
||||
- ``bq."ds"."tbl"`` direct paths: rewritten to BQ-native backticks
|
||||
via the same shared core as dry-run.
|
||||
"""
|
||||
# Skip 2: don't double-wrap. Cheap pre-check before any registry I/O.
|
||||
if "bigquery_query(" in user_sql.lower():
|
||||
return user_sql, False
|
||||
|
||||
# Find all referenced BQ remote-mode rows (bare-name + direct bq.path).
|
||||
# Mirrors the non-RBAC parts of `_bq_guardrail_inputs`.
|
||||
sql_lower = user_sql.lower()
|
||||
name_lookups: list = []
|
||||
seen_paths: set = set()
|
||||
|
||||
try:
|
||||
repo = TableRegistryRepository(conn)
|
||||
bq_rows = repo.list_by_source("bigquery")
|
||||
all_rows = repo.list_all()
|
||||
except Exception:
|
||||
# Registry read failure — let the original SQL run through the
|
||||
# ATTACH-catalog path. The handler's generic error path will
|
||||
# surface anything user-visible.
|
||||
return user_sql, False
|
||||
|
||||
# Multi-project guard (devil's-advocate R1 finding #5): the rewriter
|
||||
# assumes every BQ-remote table resolves under the single
|
||||
# `bq.projects.data` project. The current registry schema doesn't
|
||||
# store `source_project` per row, so `bucket` is the only place a
|
||||
# cross-project leak could hide. A bucket containing `.` (e.g.
|
||||
# `other_prj.dataset`) suggests the operator encoded a project
|
||||
# prefix into the bucket name — wrapping that under our single
|
||||
# project would silently target the wrong project. Conservative
|
||||
# skip: any BQ row whose bucket contains `.` aborts the rewrite,
|
||||
# falling through to the legacy ATTACH-catalog path which uses
|
||||
# whatever resolution the operator's _remote_attach configured.
|
||||
for r in bq_rows:
|
||||
if (r.get("query_mode") or "") != "remote":
|
||||
continue
|
||||
bucket = r.get("bucket")
|
||||
source_table = r.get("source_table")
|
||||
name = r.get("name")
|
||||
if not (bucket and source_table and name):
|
||||
continue
|
||||
if "." in str(bucket):
|
||||
# Project-qualified bucket — can't safely wrap under our
|
||||
# single-project assumption. Bail out completely so we don't
|
||||
# mix rewritten and non-rewritten BQ paths in one query.
|
||||
return user_sql, False
|
||||
pattern = r'\b' + re.escape(str(name).lower()) + r'\b'
|
||||
if re.search(pattern, sql_lower):
|
||||
key = (bucket.lower(), source_table.lower())
|
||||
if key not in seen_paths:
|
||||
seen_paths.add(key)
|
||||
name_lookups.append((str(name), bucket, source_table))
|
||||
|
||||
# Direct bq."ds"."tbl" references — pull the registered (bucket,
|
||||
# source_table) pair so the inner SQL receives a backticked BQ-native
|
||||
# path. Mismatched / unregistered paths are caught upstream by the
|
||||
# guardrail; here we just collect the mappings the rewriter needs.
|
||||
direct_paths: set[tuple[str, str]] = set()
|
||||
for m in BQ_PATH.finditer(user_sql):
|
||||
bucket_raw = m.group(1).strip('"')
|
||||
source_table_raw = m.group(2).strip('"')
|
||||
direct_paths.add((bucket_raw, source_table_raw))
|
||||
|
||||
if not name_lookups and not direct_paths:
|
||||
# Skip 1: no BQ tables referenced.
|
||||
return user_sql, False
|
||||
|
||||
# Skip 3: cross-source query (BQ + local-mode). If user SQL also
|
||||
# references a non-BQ master view, we can't push the whole thing to
|
||||
# BQ — DuckDB needs to do the join.
|
||||
bq_names_lc = {n.lower() for n, _, _ in name_lookups}
|
||||
for r in all_rows:
|
||||
st = (r.get("source_type") or "").lower()
|
||||
qm = (r.get("query_mode") or "").lower()
|
||||
if st == "bigquery" and qm == "remote":
|
||||
continue # already handled
|
||||
name = r.get("name")
|
||||
if not name:
|
||||
continue
|
||||
name_lc = str(name).lower()
|
||||
if name_lc in bq_names_lc:
|
||||
# Same name registered both BQ-remote and local? Pathological;
|
||||
# skip as a safety measure.
|
||||
return user_sql, False
|
||||
if re.search(r'\b' + re.escape(name_lc) + r'\b', sql_lower):
|
||||
logger.info(
|
||||
"rewrite_skip_cross_source: user SQL references both "
|
||||
"BQ-remote and local-mode tables; falling back to "
|
||||
"ATTACH-catalog path",
|
||||
)
|
||||
return user_sql, False
|
||||
|
||||
# Skip 4: BQ project not configured.
|
||||
try:
|
||||
bq = get_bq_access()
|
||||
data_project = bq.projects.data
|
||||
# The first arg to `bigquery_query()` is the **execution / billing**
|
||||
# project — the project under which the BQ job runs and is billed.
|
||||
# In cross-project deployments the SA may only have
|
||||
# `serviceusage.services.use` on the billing project, so passing
|
||||
# the data project there returns 403 USER_PROJECT_DENIED. Match
|
||||
# the convention used everywhere else in the codebase (v2_scan /
|
||||
# v2_sample / v2_schema / extractor): backtick paths use the
|
||||
# **data** project, `bigquery_query()` first-arg uses the
|
||||
# **billing** project. For single-project deploys the two are
|
||||
# identical so the fix is a no-op there.
|
||||
billing_project = bq.projects.billing or data_project
|
||||
except Exception:
|
||||
return user_sql, False
|
||||
if not data_project:
|
||||
return user_sql, False
|
||||
|
||||
# Rewrite identifiers using the DATA project — backtick paths
|
||||
# `<data-project>.<dataset>.<table>` resolve to the same logical
|
||||
# source no matter which project bills the query.
|
||||
inner_sql = _rewrite_bq_table_refs_to_native(user_sql, name_lookups, data_project)
|
||||
|
||||
# Embed the inner SQL using DuckDB's dollar-quoted string literal form
|
||||
# (`$tag$ ... $tag$`). Naive `replace("'", "''")` doubling misses
|
||||
# backslash-escape sequences DuckDB's lexer recognises (`\\`, `\n`,
|
||||
# `\t`, …) — a predicate like `WHERE name = 'O\'Brien'` is unsafe
|
||||
# under doubling. Dollar-quoting takes the inner SQL verbatim with no
|
||||
# escape sequences whatsoever, so the user's exact bytes reach BQ.
|
||||
# Tag is a fixed conventional value; the absurdly unlikely collision
|
||||
# (user SQL containing the literal `$bqq_inner$`) falls back to the
|
||||
# legacy doubling path so the rewrite still proceeds — over-doubled
|
||||
# quotes are at worst a parse error caught by the handler's fallback
|
||||
# at the call site, not a silent bad result.
|
||||
DOLLAR_TAG = "$bqq_inner$"
|
||||
if DOLLAR_TAG in inner_sql:
|
||||
escaped_inner = inner_sql.replace("'", "''")
|
||||
rewritten = (
|
||||
f"SELECT * FROM bigquery_query('{billing_project}', '{escaped_inner}')"
|
||||
)
|
||||
else:
|
||||
rewritten = (
|
||||
f"SELECT * FROM bigquery_query('{billing_project}', "
|
||||
f"{DOLLAR_TAG}{inner_sql}{DOLLAR_TAG})"
|
||||
)
|
||||
return rewritten, True
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _bq_quota_and_cap_guard(
|
||||
*,
|
||||
|
|
|
|||
525
cli/client.py
525
cli/client.py
|
|
@ -1,8 +1,13 @@
|
|||
"""HTTP client wrapper for CLI — handles auth, retries, streaming."""
|
||||
|
||||
import atexit
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
|
@ -11,6 +16,60 @@ import httpx
|
|||
|
||||
from cli.config import _config_dir, get_server_url, get_token
|
||||
|
||||
|
||||
# PID-suffixed tmp / part files — see `_download_chunked` and
|
||||
# `_download_single_stream`. We extract the embedded PID and reap any
|
||||
# leftover whose process is no longer alive on every pull. Without this,
|
||||
# every SIGKILL'd pull leaks files indefinitely (devil's-advocate R3
|
||||
# finding #1).
|
||||
_PID_SUFFIX_RE = re.compile(r"\.(\d+)\.(?:tmp|part\d+)$")
|
||||
|
||||
|
||||
def _is_pid_alive(pid: int) -> bool:
|
||||
"""Return True if a process with the given PID exists. POSIX-only;
|
||||
Windows users get the conservative `True` (file kept) which means
|
||||
no reaping but also no false-deletion of a live sibling."""
|
||||
if pid <= 0:
|
||||
return False
|
||||
try:
|
||||
# Signal 0 = no-op kill; raises ProcessLookupError when PID is
|
||||
# gone, PermissionError when the PID exists but isn't ours
|
||||
# (still alive, just owned by someone else — keep the file).
|
||||
os.kill(pid, 0)
|
||||
return True
|
||||
except ProcessLookupError:
|
||||
return False
|
||||
except PermissionError:
|
||||
return True
|
||||
except Exception:
|
||||
# Anything else (e.g. AttributeError on Windows where os.kill
|
||||
# exists but signal 0 isn't supported the same way): be
|
||||
# conservative and don't reap.
|
||||
return True
|
||||
|
||||
|
||||
def _reap_dead_pid_leftovers(target_path: str) -> None:
|
||||
"""Remove `<target>.{pid}.tmp` and `<target>.{pid}.partN` files
|
||||
whose embedded PID is no longer alive. Called at the start of every
|
||||
download to keep the parquet directory tidy across SIGKILL'd or
|
||||
crashed prior runs. Never raises — leaked file is preferable to
|
||||
failing the new pull on a permission error."""
|
||||
candidates = glob.glob(f"{target_path}.*.tmp") + glob.glob(f"{target_path}.*.part*")
|
||||
for path in candidates:
|
||||
m = _PID_SUFFIX_RE.search(path)
|
||||
if not m:
|
||||
continue
|
||||
try:
|
||||
pid = int(m.group(1))
|
||||
except ValueError:
|
||||
continue
|
||||
if _is_pid_alive(pid):
|
||||
continue
|
||||
try:
|
||||
os.unlink(path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Retry policy for transient failures during stream downloads. Scoped to
|
||||
# network issues and 5xx — 4xx (auth, 404, 400) is NOT retried. Tunable via
|
||||
# env for tests; defaults sit in the "one flaky network blip" window.
|
||||
|
|
@ -22,6 +81,18 @@ _RETRY_BACKOFFS_S = (0.3, 1.0, 3.0) # seconds before attempt 2, 3, 4
|
|||
# timeout dies long before BQ finishes. Operators tune via AGNES_QUERY_TIMEOUT.
|
||||
QUERY_TIMEOUT_S = float(os.environ.get("AGNES_QUERY_TIMEOUT", "300"))
|
||||
|
||||
# Range-chunked parallel download — see `stream_download` docstring. Defaults
|
||||
# tuned for the corp-VPN per-flow rate-limiting case (single-stream throttled
|
||||
# but N parallel range requests scale linearly). Disabled implicitly for
|
||||
# files below the threshold or when the server doesn't advertise byte-range
|
||||
# support. Operators can hard-disable by setting parallelism to 1.
|
||||
_CHUNK_PARALLELISM = max(1, min(16, int(
|
||||
os.environ.get("AGNES_PULL_CHUNK_PARALLELISM", "4"),
|
||||
)))
|
||||
_CHUNK_THRESHOLD_BYTES = int(
|
||||
os.environ.get("AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(50 * 1024 * 1024)),
|
||||
)
|
||||
|
||||
|
||||
# ── Transport-error translation ─────────────────────────────────────────
|
||||
# Pavel's Issue #185 Phase 3B caught the failure mode: when httpx raises
|
||||
|
|
@ -143,7 +214,13 @@ def _translate_transport_error(
|
|||
|
||||
|
||||
def get_client(timeout: float = 30.0) -> httpx.Client:
|
||||
"""Get an authenticated httpx client."""
|
||||
"""Get an authenticated httpx client.
|
||||
|
||||
This factory creates a fresh client per call — used by the small
|
||||
`api_*` helpers (one request, then close). The big-stream path
|
||||
(`stream_download`) routes through `_get_shared_client()` to amortize
|
||||
TLS handshakes and HTTP/2 multiplexing across N parquet downloads.
|
||||
"""
|
||||
token = get_token()
|
||||
headers = {}
|
||||
if token:
|
||||
|
|
@ -155,6 +232,80 @@ def get_client(timeout: float = 30.0) -> httpx.Client:
|
|||
)
|
||||
|
||||
|
||||
# ── Shared persistent client ────────────────────────────────────────────
|
||||
# `agnes pull` issues N stream_download calls — one per parquet — plus
|
||||
# (with chunked downloads) M Range requests per file. Without pooling,
|
||||
# each call performs a fresh TLS handshake; with HTTP/2 enabled, all
|
||||
# those requests multiplex over a single TCP connection. The shared
|
||||
# client is created lazily on first stream-download request, kept alive
|
||||
# for the duration of the process, and closed at exit.
|
||||
#
|
||||
# HTTP/2 requires the optional `h2` package. If it's unavailable (slim
|
||||
# install), we fall back to HTTP/1.1 — pooling alone still saves the
|
||||
# handshake cost — and never raise. The CLI must not crash on `agnes
|
||||
# pull` because of an h2 import error.
|
||||
|
||||
_SHARED_CLIENT: Optional[httpx.Client] = None
|
||||
_SHARED_CLIENT_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _get_shared_client() -> httpx.Client:
|
||||
"""Lazily create + return a process-wide httpx.Client.
|
||||
|
||||
Pool defaults: keep up to 32 keepalive connections (covers the
|
||||
chunk-parallelism cap of 16 × 2 simultaneous files comfortably) and
|
||||
cap the total at 64 so a runaway loop can't open thousands of
|
||||
sockets. HTTP/2 is opt-in via httpx's `http2=True` and gracefully
|
||||
degrades when the `h2` extra is missing.
|
||||
"""
|
||||
global _SHARED_CLIENT
|
||||
with _SHARED_CLIENT_LOCK:
|
||||
if _SHARED_CLIENT is not None:
|
||||
return _SHARED_CLIENT
|
||||
token = get_token()
|
||||
headers = {}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
limits = httpx.Limits(
|
||||
max_keepalive_connections=32,
|
||||
max_connections=64,
|
||||
)
|
||||
try:
|
||||
client = httpx.Client(
|
||||
base_url=get_server_url(),
|
||||
headers=headers,
|
||||
timeout=300.0,
|
||||
http2=True,
|
||||
limits=limits,
|
||||
)
|
||||
except (ImportError, RuntimeError):
|
||||
# `h2` not installed → httpx raises; fall back to HTTP/1.1.
|
||||
# Pooling alone still amortizes the TLS handshake.
|
||||
client = httpx.Client(
|
||||
base_url=get_server_url(),
|
||||
headers=headers,
|
||||
timeout=300.0,
|
||||
limits=limits,
|
||||
)
|
||||
_SHARED_CLIENT = client
|
||||
return client
|
||||
|
||||
|
||||
def _close_shared_client() -> None:
|
||||
"""Close the shared client and clear the slot. Safe to call twice."""
|
||||
global _SHARED_CLIENT
|
||||
with _SHARED_CLIENT_LOCK:
|
||||
if _SHARED_CLIENT is not None:
|
||||
try:
|
||||
_SHARED_CLIENT.close()
|
||||
except Exception:
|
||||
pass
|
||||
_SHARED_CLIENT = None
|
||||
|
||||
|
||||
atexit.register(_close_shared_client)
|
||||
|
||||
|
||||
def api_get(path: str, *, timeout: float = 30.0, **kwargs) -> httpx.Response:
|
||||
try:
|
||||
with get_client(timeout=timeout) as client:
|
||||
|
|
@ -197,23 +348,250 @@ def _is_transient(exc: Exception) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def stream_download(path: str, target_path: str, progress_callback=None) -> int:
|
||||
"""Stream a file to `target_path` atomically and with retries.
|
||||
def _read_chunk_threshold_bytes() -> int:
|
||||
"""Re-read threshold each call so tests / operators can flip it via
|
||||
env var without restarting the process."""
|
||||
try:
|
||||
return int(os.environ.get(
|
||||
"AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(_CHUNK_THRESHOLD_BYTES),
|
||||
))
|
||||
except ValueError:
|
||||
return _CHUNK_THRESHOLD_BYTES
|
||||
|
||||
Durability properties:
|
||||
- Writes to `target_path + ".tmp"`, then `os.replace` on success. The
|
||||
real target file never exists in a half-written state.
|
||||
- Retries up to `_RETRY_ATTEMPTS` times on transient errors (network
|
||||
blip, 5xx); 4xx (auth/404) is raised immediately.
|
||||
- No hash check here — that's done in the sync command against the
|
||||
manifest hash, because only the caller knows the expected value.
|
||||
|
||||
def _read_chunk_parallelism() -> int:
|
||||
"""Re-read parallelism each call (same rationale as threshold). Floor 1,
|
||||
ceiling 16."""
|
||||
try:
|
||||
n = int(os.environ.get(
|
||||
"AGNES_PULL_CHUNK_PARALLELISM", str(_CHUNK_PARALLELISM),
|
||||
))
|
||||
except ValueError:
|
||||
n = _CHUNK_PARALLELISM
|
||||
return max(1, min(16, n))
|
||||
|
||||
|
||||
def _probe_range_support(client: httpx.Client, path: str) -> tuple[int, bool]:
|
||||
"""Send HEAD; return (content-length, accepts-byte-ranges).
|
||||
|
||||
`(0, False)` means "we couldn't tell — fall back to single-stream".
|
||||
Never raises; transport errors during the probe are treated as
|
||||
"no chunking, try the GET instead and let it surface the failure
|
||||
in the normal retry loop".
|
||||
|
||||
Probe order: HEAD first (cheap, idempotent), then GET-with-tiny-range
|
||||
fallback. The HEAD path covers Caddy's `file_server` (which advertises
|
||||
HEAD) and Caddy's `reverse_proxy` (which forwards HEAD upstream). The
|
||||
GET-fallback covers the dev `docker compose up` deployment where
|
||||
requests go straight to FastAPI's GET-only `/api/data/{tid}/download`
|
||||
route — FastAPI returns **405 Method Not Allowed** to a HEAD on a
|
||||
GET-only route, which without this fallback would silently disable
|
||||
chunked download for every dev / non-TLS install. The GET-with-Range
|
||||
probe asks for 1 byte so the server response is bounded; we discard
|
||||
the body and read only the headers + status code.
|
||||
"""
|
||||
tmp_path = Path(f"{target_path}.tmp")
|
||||
try:
|
||||
resp = client.head(path)
|
||||
status = getattr(resp, "status_code", 200)
|
||||
if status < 400:
|
||||
size = int(resp.headers.get("content-length", "0") or 0)
|
||||
accepts = (resp.headers.get("accept-ranges", "").lower() == "bytes")
|
||||
if size > 0:
|
||||
return (size, accepts)
|
||||
# HEAD failed (405 from GET-only route is the common case in
|
||||
# non-Caddy deployments) or returned 0-length — fall through to
|
||||
# the tiny-Range GET probe.
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
with client.stream("GET", path, headers={"Range": "bytes=0-0"}) as resp:
|
||||
status = getattr(resp, "status_code", 0)
|
||||
if status not in (200, 206):
|
||||
return (0, False)
|
||||
# Drain the 1-byte body so the connection is reusable.
|
||||
for _ in resp.iter_bytes():
|
||||
pass
|
||||
# Content-Range on a 206 response carries the total: `bytes 0-0/12345`.
|
||||
# On a 200 response the server didn't honor Range — content-length is the total.
|
||||
if status == 206:
|
||||
cr = resp.headers.get("content-range", "")
|
||||
if "/" in cr:
|
||||
try:
|
||||
total = int(cr.rsplit("/", 1)[1])
|
||||
return (total, True)
|
||||
except ValueError:
|
||||
return (0, False)
|
||||
return (0, False)
|
||||
# status == 200 → server ignored Range; we can read content-length but
|
||||
# accept-ranges is False (or missing) so the caller will not chunk.
|
||||
size = int(resp.headers.get("content-length", "0") or 0)
|
||||
accepts = (resp.headers.get("accept-ranges", "").lower() == "bytes")
|
||||
return (size, accepts)
|
||||
except Exception:
|
||||
return (0, False)
|
||||
|
||||
|
||||
class _RangeNotHonored(Exception):
|
||||
"""Internal sentinel — server returned 200 instead of 206 to a Range
|
||||
request. Caller catches and falls back to the single-stream path."""
|
||||
|
||||
|
||||
def _download_chunk(
|
||||
client: httpx.Client,
|
||||
path: str,
|
||||
start: int,
|
||||
end: int,
|
||||
part_path: Path,
|
||||
progress_callback,
|
||||
) -> None:
|
||||
"""Stream `bytes=start-end` to `part_path`. Caller deals with retry +
|
||||
cleanup. Raises on any failure (HTTPStatusError on non-206 response,
|
||||
httpx.* on transport blip, `_RangeNotHonored` if server returned 200
|
||||
instead of 206 — chunked path can't trust that result)."""
|
||||
headers = {"Range": f"bytes={start}-{end}"}
|
||||
with client.stream("GET", path, headers=headers) as response:
|
||||
# Server didn't honor the Range — RFC says it MAY return 200 with
|
||||
# the full body. We can't safely splice that into one part of N,
|
||||
# so we abort the whole chunked path and let the caller fall back.
|
||||
if response.status_code == 200:
|
||||
raise _RangeNotHonored()
|
||||
response.raise_for_status()
|
||||
with open(part_path, "wb") as f:
|
||||
for piece in response.iter_bytes(chunk_size=65536):
|
||||
f.write(piece)
|
||||
if progress_callback and piece:
|
||||
progress_callback(len(piece))
|
||||
|
||||
|
||||
def _download_chunked(
|
||||
client: httpx.Client,
|
||||
path: str,
|
||||
target_path: str,
|
||||
total_size: int,
|
||||
parallelism: int,
|
||||
progress_callback,
|
||||
) -> int:
|
||||
"""Range-based parallel download. Returns total bytes written.
|
||||
|
||||
Raises `_RangeNotHonored` on the first 200-instead-of-206 response so
|
||||
the caller can fall back. All other exceptions propagate.
|
||||
|
||||
Cleanup discipline: every part file we create gets removed before
|
||||
return (success or failure). The destination is written via the
|
||||
caller's `<target>.tmp` and renamed atomically.
|
||||
"""
|
||||
target = Path(target_path)
|
||||
# Reap leftovers from previously SIGKILL'd / crashed pulls before we
|
||||
# start writing — without this, PID-suffixed files from dead PIDs
|
||||
# accumulate forever on disk (devil's-advocate R3 finding #1).
|
||||
_reap_dead_pid_leftovers(target_path)
|
||||
# Per-process tmp + part suffixes (devil's-advocate R2 finding #2):
|
||||
# if two `agnes pull` invocations target the same parquet
|
||||
# concurrently (e.g. SessionStart hook + manual run, or two
|
||||
# terminals), bare `<target>.tmp` and `<target>.partN` paths would
|
||||
# collide — one process's part-write yanks the other's in-progress
|
||||
# write, manifest hash check then fails spuriously. Including PID
|
||||
# in the suffix makes each invocation's intermediate files
|
||||
# disjoint; the final `os.replace` to the bare target is atomic so
|
||||
# last-writer-wins, both processes succeed individually.
|
||||
pid = os.getpid()
|
||||
tmp_path = Path(f"{target_path}.{pid}.tmp")
|
||||
parallelism = max(1, parallelism)
|
||||
# Build chunks — last chunk takes the remainder.
|
||||
chunk_size = total_size // parallelism
|
||||
if chunk_size <= 0:
|
||||
chunk_size = total_size # tiny file, single chunk
|
||||
parallelism = 1
|
||||
ranges = []
|
||||
for i in range(parallelism):
|
||||
start = i * chunk_size
|
||||
end = (start + chunk_size - 1) if i < parallelism - 1 else (total_size - 1)
|
||||
ranges.append((i, start, end))
|
||||
|
||||
part_paths = [Path(f"{target_path}.{pid}.part{i}") for i, _, _ in ranges]
|
||||
# Pre-clean any leftovers from a prior run of THIS process.
|
||||
for p in part_paths:
|
||||
p.unlink(missing_ok=True)
|
||||
|
||||
def _attempt_chunk(i: int, start: int, end: int) -> None:
|
||||
last_exc: Optional[Exception] = None
|
||||
for attempt in range(_RETRY_ATTEMPTS + 1):
|
||||
try:
|
||||
_download_chunk(
|
||||
client, path, start, end, part_paths[i],
|
||||
progress_callback,
|
||||
)
|
||||
return
|
||||
except _RangeNotHonored:
|
||||
# Don't retry — server policy, not a transport blip.
|
||||
raise
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
if attempt == _RETRY_ATTEMPTS or not _is_transient(exc):
|
||||
break
|
||||
time.sleep(_RETRY_BACKOFFS_S[
|
||||
min(attempt, len(_RETRY_BACKOFFS_S) - 1)
|
||||
])
|
||||
assert last_exc is not None
|
||||
raise last_exc
|
||||
|
||||
try:
|
||||
if parallelism == 1:
|
||||
_attempt_chunk(*ranges[0])
|
||||
else:
|
||||
# Use a thread pool so each chunk gets its own concurrent
|
||||
# request slot on the (HTTP/2-multiplexed when available)
|
||||
# shared client. httpx.Client is thread-safe for stream().
|
||||
with ThreadPoolExecutor(max_workers=parallelism) as ex:
|
||||
futs = [ex.submit(_attempt_chunk, *r) for r in ranges]
|
||||
for fut in as_completed(futs):
|
||||
fut.result() # propagate first error
|
||||
|
||||
# Concatenate parts → tmp_path → atomic rename.
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
total_written = 0
|
||||
with open(tmp_path, "wb") as out:
|
||||
for p in part_paths:
|
||||
with open(p, "rb") as inp:
|
||||
while True:
|
||||
block = inp.read(65536)
|
||||
if not block:
|
||||
break
|
||||
out.write(block)
|
||||
total_written += len(block)
|
||||
os.replace(tmp_path, target)
|
||||
return total_written
|
||||
finally:
|
||||
# Always clean up part files + any stray tmp.
|
||||
for p in part_paths:
|
||||
p.unlink(missing_ok=True)
|
||||
if tmp_path.exists():
|
||||
try:
|
||||
tmp_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _download_single_stream(
|
||||
client: httpx.Client,
|
||||
path: str,
|
||||
target_path: str,
|
||||
progress_callback,
|
||||
) -> int:
|
||||
"""Original single-stream path with retry. Used when chunking is
|
||||
disabled (small file, no range support, or fallback after 200-on-Range)."""
|
||||
# Same dead-PID reap as `_download_chunked` so leftovers from
|
||||
# crashed prior pulls don't accumulate indefinitely.
|
||||
_reap_dead_pid_leftovers(target_path)
|
||||
# Per-process tmp suffix — same rationale as `_download_chunked`
|
||||
# (devil's-advocate R2 finding #2): concurrent `agnes pull`
|
||||
# invocations against the same target dir must not yank each
|
||||
# other's in-progress writes.
|
||||
tmp_path = Path(f"{target_path}.{os.getpid()}.tmp")
|
||||
last_exc: Optional[Exception] = None
|
||||
for attempt in range(_RETRY_ATTEMPTS + 1):
|
||||
try:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
with get_client(timeout=300.0) as client:
|
||||
with client.stream("GET", path) as response:
|
||||
response.raise_for_status()
|
||||
total = 0
|
||||
|
|
@ -223,7 +601,6 @@ def stream_download(path: str, target_path: str, progress_callback=None) -> int:
|
|||
total += len(chunk)
|
||||
if progress_callback:
|
||||
progress_callback(len(chunk))
|
||||
# os.replace is atomic on POSIX and Windows for same-filesystem moves.
|
||||
os.replace(tmp_path, target_path)
|
||||
return total
|
||||
except Exception as exc:
|
||||
|
|
@ -231,23 +608,115 @@ def stream_download(path: str, target_path: str, progress_callback=None) -> int:
|
|||
if attempt == _RETRY_ATTEMPTS or not _is_transient(exc):
|
||||
break
|
||||
time.sleep(_RETRY_BACKOFFS_S[min(attempt, len(_RETRY_BACKOFFS_S) - 1)])
|
||||
# Clean up any leftover tmp, then surface the last exception. Translate
|
||||
# transport errors (timeouts, connection drops, protocol errors) to
|
||||
# AgnesTransportError so the CLI prints a clean message instead of a
|
||||
# Python traceback (Pavel's #185 Phase 3B). HTTPStatusError (4xx/5xx
|
||||
# response from the server) is NOT a transport failure and must
|
||||
# re-raise verbatim so the caller's status-code handling + the rich
|
||||
# server error body (e.g. 401 with "token expired", 403 with
|
||||
# cross_project_forbidden detail) reach the analyst — Devin Review on
|
||||
# PR #188 caught: HTTPStatusError is a subclass of HTTPError, so the
|
||||
# generic isinstance(HTTPError) translation was eating status codes.
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
assert last_exc is not None
|
||||
if isinstance(last_exc, httpx.HTTPStatusError):
|
||||
raise last_exc
|
||||
if isinstance(last_exc, httpx.HTTPError):
|
||||
|
||||
|
||||
def stream_download(path: str, target_path: str, progress_callback=None) -> int:
|
||||
"""Stream a file to `target_path` atomically and with retries.
|
||||
|
||||
Two paths:
|
||||
1. **Chunked parallel** — when the server advertises `accept-ranges:
|
||||
bytes` and `content-length` exceeds `AGNES_PULL_CHUNK_THRESHOLD_BYTES`
|
||||
(default 50 MB), split into N range requests
|
||||
(`AGNES_PULL_CHUNK_PARALLELISM`, default 4, capped 1..16) and
|
||||
download in parallel. Concatenate the part files into `<target>.tmp`,
|
||||
then `os.replace`. Falls back to single-stream if the server
|
||||
responds 200 instead of 206 to a Range probe.
|
||||
2. **Single-stream** — for small files, no range support, or fallback
|
||||
from the chunked path. Same atomic-rename + retry semantics as
|
||||
before.
|
||||
|
||||
Durability properties (unchanged):
|
||||
- Writes to `<target>.tmp`, then `os.replace` on success. The real
|
||||
target file never exists in a half-written state.
|
||||
- Retries up to `_RETRY_ATTEMPTS` on transient errors (network blip,
|
||||
5xx); 4xx (auth/404) is raised immediately.
|
||||
- No hash check here — that's the caller's job (manifest hash).
|
||||
|
||||
Threading: the chunked path uses a ThreadPoolExecutor sized to the
|
||||
parallelism. httpx.Client.stream() is safe to call concurrently from
|
||||
multiple threads on a single client (the connection pool serializes
|
||||
the underlying socket access; HTTP/2 multiplexes streams when the
|
||||
`h2` extra is installed).
|
||||
"""
|
||||
# Use the shared persistent client when available — one TLS
|
||||
# handshake amortized across N stream_download calls within the same
|
||||
# process, and HTTP/2 stream multiplexing across the chunk Range
|
||||
# requests within a single download. Falls back to a fresh per-call
|
||||
# client if shared-client construction fails (e.g. `h2` install
|
||||
# broken at runtime). Devil's-advocate R2 finding #1: scope the
|
||||
# try/except to *only* the shared-client construction — the actual
|
||||
# download must NOT be retried under this except, otherwise hard
|
||||
# failures (401/403/404/5xx) waste a full second download attempt
|
||||
# and revoked-PAT cases don't fail-fast.
|
||||
try:
|
||||
client = _get_shared_client()
|
||||
except Exception:
|
||||
with get_client(timeout=300.0) as client:
|
||||
return _stream_download_via(client, path, target_path, progress_callback)
|
||||
return _stream_download_via(client, path, target_path, progress_callback)
|
||||
|
||||
|
||||
def _stream_download_via(
|
||||
client: httpx.Client,
|
||||
path: str,
|
||||
target_path: str,
|
||||
progress_callback,
|
||||
) -> int:
|
||||
"""The shared body of `stream_download` parameterized on the client.
|
||||
Split out so tests can inject a fake client."""
|
||||
threshold = _read_chunk_threshold_bytes()
|
||||
parallelism = _read_chunk_parallelism()
|
||||
|
||||
total_size = 0
|
||||
accepts_ranges = False
|
||||
if parallelism > 1:
|
||||
total_size, accepts_ranges = _probe_range_support(client, path)
|
||||
|
||||
# Sanity bound on the advertised total size (devil's-advocate R1
|
||||
# finding #4): a misconfigured proxy or buggy server returning a
|
||||
# wildly inflated `Content-Length` would make us split into huge
|
||||
# `Range: bytes=N-M` requests; the server then clamps each to actual
|
||||
# bytes available, and we end up with overlapping bytes from the
|
||||
# start of the file in every part → corrupt assembled output (caught
|
||||
# later by manifest hash check, but only after wasted bandwidth).
|
||||
# 100 GiB is the operational ceiling for any single materialized
|
||||
# parquet on a typical Agnes deployment; values above suggest a
|
||||
# server / proxy bug rather than a legitimate huge file. Drop to
|
||||
# single-stream (which can't be confused by overlapping chunks).
|
||||
SANE_MAX_TOTAL = 100 * 1024**3 # 100 GiB
|
||||
if total_size > SANE_MAX_TOTAL:
|
||||
total_size = 0
|
||||
accepts_ranges = False
|
||||
|
||||
use_chunked = (
|
||||
parallelism > 1
|
||||
and accepts_ranges
|
||||
and total_size > threshold
|
||||
)
|
||||
|
||||
try:
|
||||
if use_chunked:
|
||||
try:
|
||||
return _download_chunked(
|
||||
client, path, target_path, total_size, parallelism,
|
||||
progress_callback,
|
||||
)
|
||||
except _RangeNotHonored:
|
||||
# Server lied / proxy stripped the Range — fall through.
|
||||
pass
|
||||
return _download_single_stream(
|
||||
client, path, target_path, progress_callback,
|
||||
)
|
||||
except httpx.HTTPStatusError:
|
||||
# 4xx / 5xx response from the server — re-raise verbatim so the
|
||||
# caller's status-code handling + the rich server error body
|
||||
# reach the analyst (Devin Review on PR #188).
|
||||
raise
|
||||
except httpx.HTTPError as exc:
|
||||
raise _translate_transport_error(
|
||||
last_exc, context=f"GET {path} (stream → {target_path})",
|
||||
exc, context=f"GET {path} (stream → {target_path})",
|
||||
timeout_s=300.0,
|
||||
) from last_exc
|
||||
raise last_exc
|
||||
) from exc
|
||||
|
|
|
|||
171
cli/lib/pull.py
171
cli/lib/pull.py
|
|
@ -65,6 +65,128 @@ class PullResult:
|
|||
_SAFE_ID_RE = re.compile(r"^[a-zA-Z0-9_\-]{1,128}$")
|
||||
|
||||
|
||||
class _TextualProgress:
|
||||
"""Plain-text progress emitter for non-TTY stderr.
|
||||
|
||||
When `agnes pull` is invoked from a Claude Code SessionStart hook,
|
||||
a CI runner, or any pipe consumer, stderr is not a terminal. Rich's
|
||||
progress bar in that mode either suppresses output (silent for
|
||||
minutes on a multi-GB parquet) or emits raw ANSI noise. This class
|
||||
instead emits one terse line per file at sensible cadence.
|
||||
|
||||
Cadence policy: emit when *either*:
|
||||
- per-file bytes-downloaded crosses a 10%-of-total boundary, OR
|
||||
- 30 s have elapsed since this file's last emission.
|
||||
|
||||
Always emits one final "done" line per file via `finish()` so the
|
||||
operator sees a confirmed completion even on tiny files.
|
||||
|
||||
Format: `[N/T files] <tid>: 25% (16 MB / 66 MB) at 1.5 MB/s` — the
|
||||
"[N/T files]" prefix lets the operator see overall pull progress
|
||||
in a multi-table run without buffering all per-file lines.
|
||||
|
||||
Thread-safe — `advance` is called from the chunked-download worker
|
||||
threads; an internal lock serializes the update + emit.
|
||||
"""
|
||||
|
||||
_HUMAN_UNITS = (
|
||||
(1024 * 1024 * 1024 * 1024, "TB"),
|
||||
(1024 * 1024 * 1024, "GB"),
|
||||
(1024 * 1024, "MB"),
|
||||
(1024, "KB"),
|
||||
)
|
||||
|
||||
def __init__(self, *, stream, total_files: int, file_sizes: dict[str, int]):
|
||||
import threading
|
||||
self._stream = stream
|
||||
self._total_files = total_files
|
||||
self._file_sizes = file_sizes
|
||||
self._lock = threading.Lock()
|
||||
# Per-file state.
|
||||
self._bytes: dict[str, int] = {tid: 0 for tid in file_sizes}
|
||||
self._started_at: dict[str, float] = {}
|
||||
self._last_emit_at: dict[str, float] = {}
|
||||
self._last_emit_pct: dict[str, int] = {}
|
||||
self._finished_idx: int = 0 # files whose `finish` line has been emitted
|
||||
|
||||
def advance(self, tid: str, n: int) -> None:
|
||||
"""Add `n` bytes to the file's total. Emit a textual update if
|
||||
the cadence policy allows."""
|
||||
with self._lock:
|
||||
now = time.monotonic()
|
||||
if tid not in self._started_at:
|
||||
self._started_at[tid] = now
|
||||
self._last_emit_at[tid] = now
|
||||
self._last_emit_pct[tid] = 0
|
||||
self._bytes[tid] = self._bytes.get(tid, 0) + n
|
||||
|
||||
total = self._file_sizes.get(tid, 0)
|
||||
current = self._bytes[tid]
|
||||
pct = int((current * 100) / total) if total > 0 else 0
|
||||
elapsed = now - self._last_emit_at[tid]
|
||||
crossed_10 = pct >= self._last_emit_pct[tid] + 10
|
||||
if crossed_10 or elapsed >= 30.0:
|
||||
self._last_emit_at[tid] = now
|
||||
self._last_emit_pct[tid] = pct - (pct % 10)
|
||||
self._emit_line(tid, current, total, now)
|
||||
|
||||
def finish(self) -> None:
|
||||
"""Emit a final `done` line for any file we never closed out."""
|
||||
with self._lock:
|
||||
now = time.monotonic()
|
||||
for tid, total in self._file_sizes.items():
|
||||
# Treat any file we observed bytes for as needing a
|
||||
# final line. Files that errored out before any callback
|
||||
# are still announced (operator wants visibility even on
|
||||
# zero-byte attempts).
|
||||
self._finished_idx += 1
|
||||
bytes_ = self._bytes.get(tid, 0)
|
||||
started = self._started_at.get(tid, now)
|
||||
duration = max(0.001, now - started)
|
||||
rate = bytes_ / duration
|
||||
line = (
|
||||
f"[{self._finished_idx}/{self._total_files} files] "
|
||||
f"{tid}: 100% done "
|
||||
f"({self._fmt_bytes(bytes_)} in {duration:.1f}s, "
|
||||
f"{self._fmt_bytes(int(rate))}/s)\n"
|
||||
)
|
||||
self._stream.write(line)
|
||||
try:
|
||||
self._stream.flush()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _emit_line(self, tid: str, current: int, total: int, now: float) -> None:
|
||||
started = self._started_at.get(tid, now)
|
||||
duration = max(0.001, now - started)
|
||||
rate = current / duration
|
||||
if total > 0:
|
||||
pct_str = f"{int((current * 100) / total)}%"
|
||||
size_str = (
|
||||
f"({self._fmt_bytes(current)} / {self._fmt_bytes(total)})"
|
||||
)
|
||||
else:
|
||||
pct_str = "?"
|
||||
size_str = f"({self._fmt_bytes(current)})"
|
||||
idx = self._finished_idx + 1 # 1-based "currently working on file N"
|
||||
line = (
|
||||
f"[{idx}/{self._total_files} files] {tid}: {pct_str} "
|
||||
f"{size_str} at {self._fmt_bytes(int(rate))}/s\n"
|
||||
)
|
||||
self._stream.write(line)
|
||||
try:
|
||||
self._stream.flush()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _fmt_bytes(cls, n: int) -> str:
|
||||
for divisor, suffix in cls._HUMAN_UNITS:
|
||||
if n >= divisor:
|
||||
return f"{n / divisor:.1f} {suffix}"
|
||||
return f"{n} B"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _override_server_env(server_url: str, token: str) -> Iterator[None]:
|
||||
"""Set AGNES_SERVER + scoped token override for the duration of the call.
|
||||
|
|
@ -219,15 +341,34 @@ def run_pull(
|
|||
# the executor + thread overhead for the common single-update case.
|
||||
workers = min(workers, len(to_download)) if to_download else 1
|
||||
|
||||
# Optional progress bar — Rich's Progress tracks per-file bytes
|
||||
# streamed, aggregated across the parallel ThreadPoolExecutor
|
||||
# workers. Pavel's #185 Phase 1: a single 6.3 GB parquet on first
|
||||
# init went 44 minutes silent, looked frozen. Now: aggregate "X.Y
|
||||
# GB / Z.A GB · 56 MB/s · ETA 1m 20s" to stderr while threads
|
||||
# stream. None when show_progress=False (SessionStart hooks etc.).
|
||||
# Optional progress reporting — two paths.
|
||||
#
|
||||
# 1. Rich progress bar: per-file bytes-streamed bar with speed +
|
||||
# ETA. Rendered to stderr when stderr is a TTY. Aggregates
|
||||
# across the parallel ThreadPoolExecutor workers and across
|
||||
# chunked-download chunks (all chunks call the same callback
|
||||
# advancing the same task).
|
||||
# 2. Textual fallback: when `show_progress=True` but stderr is
|
||||
# NOT a TTY (Claude Code SessionStart hook, CI run, Docker
|
||||
# log capture), Rich would either suppress the bar or emit
|
||||
# raw control sequences. Instead we emit one plain-text line
|
||||
# per file at most every 10% or 30 s — enough signal to know
|
||||
# the pull isn't frozen on a multi-GB parquet, terse enough
|
||||
# not to spam the consumer's log.
|
||||
#
|
||||
# Both paths receive the same per-file callback so the chunked-
|
||||
# download contract ("one file = one task, sum-of-chunks bytes")
|
||||
# is honored uniformly.
|
||||
import sys as _sys
|
||||
progress = None
|
||||
progress_tasks: dict[str, int] = {}
|
||||
if show_progress and to_download:
|
||||
textual = None
|
||||
use_textual_fallback = (
|
||||
show_progress
|
||||
and to_download
|
||||
and not _sys.stderr.isatty()
|
||||
)
|
||||
if show_progress and to_download and not use_textual_fallback:
|
||||
from rich.progress import (
|
||||
Progress, BarColumn, DownloadColumn, TextColumn,
|
||||
TimeRemainingColumn, TransferSpeedColumn,
|
||||
|
|
@ -248,13 +389,22 @@ def run_pull(
|
|||
progress_tasks[tid] = progress.add_task(
|
||||
"download", label=tid, total=size if size > 0 else None,
|
||||
)
|
||||
elif use_textual_fallback:
|
||||
textual = _TextualProgress(
|
||||
stream=_sys.stderr,
|
||||
total_files=len(to_download),
|
||||
file_sizes={
|
||||
tid: int(server_tables[tid].get("size_bytes") or 0)
|
||||
for tid in to_download
|
||||
},
|
||||
)
|
||||
|
||||
def _download_one(tid: str) -> tuple[str, dict | None, str | None]:
|
||||
"""Returns (tid, local_table_entry_or_None, error_or_None).
|
||||
One bound thread per call; stream_download is sync I/O so a
|
||||
ThreadPoolExecutor (not asyncio) is the right tool. The
|
||||
progress callback is thread-safe — Rich's Progress.update
|
||||
holds an internal lock."""
|
||||
and the textual fallback's lock both serialize internally."""
|
||||
target = parquet_dir / f"{tid}.parquet"
|
||||
expected_hash = server_tables[tid].get("hash", "")
|
||||
cb = None
|
||||
|
|
@ -262,6 +412,9 @@ def run_pull(
|
|||
task_id = progress_tasks[tid]
|
||||
def cb(n: int, _tid=tid, _task=task_id):
|
||||
progress.update(_task, advance=n)
|
||||
elif textual is not None:
|
||||
def cb(n: int, _tid=tid):
|
||||
textual.advance(_tid, n)
|
||||
try:
|
||||
stream_download(f"/api/data/{tid}/download", str(target),
|
||||
progress_callback=cb)
|
||||
|
|
@ -294,6 +447,8 @@ def run_pull(
|
|||
finally:
|
||||
if progress is not None:
|
||||
progress.stop()
|
||||
if textual is not None:
|
||||
textual.finish()
|
||||
|
||||
for tid, entry, err in outcomes:
|
||||
if err is not None:
|
||||
|
|
|
|||
|
|
@ -135,6 +135,13 @@ data_source:
|
|||
# # view-backed datasets -- bumped to 600 000 ms = 10 min by default.
|
||||
# # Set 0 to fall through to the extension default. Configurable via
|
||||
# # /admin/server-config UI.
|
||||
# session_pool_size: 4
|
||||
# # Number of pre-warmed DuckDB+bigquery-extension sessions kept
|
||||
# # in a process-local pool. Each acquire amortizes the
|
||||
# # ~0.5 s INSTALL/LOAD/CREATE-SECRET cost across requests; a fresh
|
||||
# # build only happens when the pool is empty. Default 4. Set 0
|
||||
# # to disable pooling (every acquire builds + closes a fresh
|
||||
# # session; matches pre-pool behavior).
|
||||
|
||||
# --- OpenMetadata catalog (optional) ---
|
||||
# Enriches table and column metadata from OpenMetadata REST API.
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from __future__ import annotations
|
|||
|
||||
import functools
|
||||
import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterator, Literal
|
||||
|
|
@ -42,6 +44,12 @@ class BqAccessError(Exception):
|
|||
"bq_forbidden": 502, # other Forbidden from BQ
|
||||
"bq_bad_request": 400, # 400 from BQ when caller flagged it as client-derived
|
||||
"bq_upstream_error": 502, # all other upstream BQ failures
|
||||
# `responseTooLarge` is a BQ refusal whose root cause is query shape
|
||||
# (the user asked for too many rows back inline), not auth or syntax.
|
||||
# 400 with a specific actionable hint instead of the generic
|
||||
# bq_bad_request / bq_upstream_error mappings, which surfaced the
|
||||
# raw BQ message and gave operators no path forward.
|
||||
"bq_response_too_large": 400,
|
||||
}
|
||||
|
||||
def __init__(self, kind: str, message: str, details: dict | None = None):
|
||||
|
|
@ -51,6 +59,43 @@ class BqAccessError(Exception):
|
|||
super().__init__(message)
|
||||
|
||||
|
||||
_RESPONSE_TOO_LARGE_HINT = (
|
||||
"BigQuery refused to return the result inline; the query exceeded BQ's "
|
||||
"response size limit. Narrow the WHERE clause, aggregate further, "
|
||||
"select fewer columns, or query a materialized table that's already "
|
||||
"been bounded server-side."
|
||||
)
|
||||
|
||||
|
||||
def _classify_response_too_large(msg: str, projects: BqProjects) -> BqAccessError:
|
||||
"""Build the `bq_response_too_large` BqAccessError with the canonical
|
||||
actionable hint and the original BQ message preserved in details for
|
||||
operator debugging."""
|
||||
return BqAccessError(
|
||||
"bq_response_too_large",
|
||||
_RESPONSE_TOO_LARGE_HINT,
|
||||
details={
|
||||
"original": msg,
|
||||
"billing_project": projects.billing,
|
||||
"data_project": projects.data,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _is_response_too_large(msg: str) -> bool:
|
||||
"""Detect BQ's `responseTooLarge` failure mode by message substring.
|
||||
|
||||
The reason code is stable across HTTP transports (gax.BadRequest from
|
||||
google-cloud-bigquery, duckdb.IOException from the BQ extension's own
|
||||
HTTP layer); both surface 'Response too large to return' verbatim in
|
||||
the message body. Match case-insensitively + tolerate the slight
|
||||
variant 'response too large' that some surfaces emit without the
|
||||
'to return' suffix.
|
||||
"""
|
||||
ml = msg.lower()
|
||||
return "response too large" in ml
|
||||
|
||||
|
||||
def translate_bq_error(
|
||||
e: Exception,
|
||||
projects: BqProjects,
|
||||
|
|
@ -67,12 +112,24 @@ def translate_bq_error(
|
|||
2. Forbidden + 'serviceusage' in str(e).lower()
|
||||
-> cross_project_forbidden (with hint)
|
||||
3. Forbidden -> bq_forbidden
|
||||
4. BadRequest, bad_request_status='client_error'
|
||||
4. 'response too large' in str(e).lower()
|
||||
-> bq_response_too_large (HTTP 400, with
|
||||
actionable hint pointing at WHERE /
|
||||
aggregate / materialized remediations)
|
||||
5. BadRequest, bad_request_status='client_error'
|
||||
-> bq_bad_request (HTTP 400)
|
||||
5. BadRequest, bad_request_status='upstream_error'
|
||||
6. BadRequest, bad_request_status='upstream_error'
|
||||
-> bq_upstream_error (HTTP 502)
|
||||
6. GoogleAPICallError (other) -> bq_upstream_error
|
||||
7. Anything else -> RE-RAISED unchanged (don't swallow programmer errors)
|
||||
7. GoogleAPICallError (other) -> bq_upstream_error
|
||||
8. Anything else -> RE-RAISED unchanged (don't swallow programmer errors)
|
||||
|
||||
The `responseTooLarge` mapping (4) sits ahead of the generic BadRequest
|
||||
cases on purpose: BQ surfaces this failure mode as a 400 with a
|
||||
specific reason, but the actionable remediation is "shape your query
|
||||
differently" — not "your SQL has a syntax error" (the typical
|
||||
bq_bad_request user-facing meaning) and not "BQ is broken"
|
||||
(bq_upstream_error). Routing it via its own kind keeps the user-facing
|
||||
message tight + correct.
|
||||
"""
|
||||
if isinstance(e, BqAccessError):
|
||||
return e
|
||||
|
|
@ -106,6 +163,13 @@ def translate_bq_error(
|
|||
details={"billing_project": projects.billing, "data_project": projects.data},
|
||||
)
|
||||
|
||||
# Special-case: `responseTooLarge` arrives as gax.BadRequest (HTTP 400)
|
||||
# but has a unique reason code with a specific, actionable remediation.
|
||||
# Catch it BEFORE the generic BadRequest mapping below so it doesn't
|
||||
# surface as a confusing "bad request" (which implies bad SQL).
|
||||
if _is_response_too_large(msg):
|
||||
return _classify_response_too_large(msg, projects)
|
||||
|
||||
if isinstance(e, gax.BadRequest):
|
||||
if bad_request_status == "client_error":
|
||||
return BqAccessError("bq_bad_request", msg)
|
||||
|
|
@ -196,15 +260,40 @@ def _default_client_factory(projects: BqProjects):
|
|||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _default_duckdb_session_factory(projects: BqProjects):
|
||||
"""Yield an in-memory DuckDB conn with bigquery extension loaded + SECRET set
|
||||
from get_metadata_token(). Auto-cleanup. Translates auth/install failures
|
||||
to BqAccessError(kind='auth_failed' or 'bq_lib_missing').
|
||||
def _default_pool_size() -> int:
|
||||
"""Resolve the BQ DuckDB-extension session pool size from instance.yaml.
|
||||
|
||||
Note: `projects.billing` is not used by this factory directly — bigquery_query()
|
||||
callers pass it themselves as the first positional arg to identify the billing
|
||||
project. The factory keeps the parameter for symmetry with _default_client_factory.
|
||||
Reads ``data_source.bigquery.session_pool_size`` (default 4). Sentinel
|
||||
``0`` disables pooling (every acquire builds + closes a fresh session;
|
||||
matches pre-pool behavior). Negative / non-numeric values fall back to
|
||||
the default — the pool is a perf optimization, not a correctness
|
||||
boundary, so an unparseable config shouldn't fail-stop the app.
|
||||
"""
|
||||
try:
|
||||
from app.instance_config import get_value
|
||||
except Exception:
|
||||
return 4
|
||||
raw = get_value("data_source", "bigquery", "session_pool_size", default=4)
|
||||
try:
|
||||
n = int(raw) if raw is not None else 4
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"BQ session_pool_size=%r is not an int; falling back to default 4",
|
||||
raw,
|
||||
)
|
||||
return 4
|
||||
if n < 0:
|
||||
return 4
|
||||
return n
|
||||
|
||||
|
||||
def _build_fresh_bq_session():
|
||||
"""Build a single fresh in-memory DuckDB conn with the bigquery extension
|
||||
INSTALL/LOAD'd, the auth SECRET created from get_metadata_token(), and
|
||||
per-session settings applied. Translates auth / install failures to
|
||||
BqAccessError. Caller owns the close.
|
||||
|
||||
Used internally by the pool; also used directly when pooling is disabled.
|
||||
"""
|
||||
import duckdb # type: ignore
|
||||
from connectors.bigquery.auth import get_metadata_token, BQMetadataAuthError
|
||||
|
|
@ -219,7 +308,6 @@ def _default_duckdb_session_factory(projects: BqProjects):
|
|||
)
|
||||
|
||||
conn = duckdb.connect(":memory:")
|
||||
try:
|
||||
try:
|
||||
conn.execute("INSTALL bigquery FROM community; LOAD bigquery;")
|
||||
escaped = token.replace("'", "''")
|
||||
|
|
@ -227,15 +315,170 @@ def _default_duckdb_session_factory(projects: BqProjects):
|
|||
f"CREATE OR REPLACE SECRET bq_s (TYPE bigquery, ACCESS_TOKEN '{escaped}')"
|
||||
)
|
||||
except Exception as e:
|
||||
# Build failed — must close the half-initialised conn, otherwise it
|
||||
# leaks across the pool's lifetime.
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
raise BqAccessError(
|
||||
"bq_lib_missing",
|
||||
f"failed to install/load BigQuery DuckDB extension: {e}",
|
||||
details={"original": str(e)},
|
||||
)
|
||||
apply_bq_session_settings(conn)
|
||||
return conn
|
||||
|
||||
|
||||
def _refresh_bq_secret(conn) -> None:
|
||||
"""Refresh the auth SECRET on a pooled connection so token rotation
|
||||
(default GCE metadata token TTL ~1 hr) doesn't break long-lived
|
||||
pooled entries.
|
||||
|
||||
Cheap when the token cache is warm (a few µs). Failures are
|
||||
non-fatal here — the pool's liveness probe + per-acquire build
|
||||
fallback will catch genuinely-broken entries.
|
||||
"""
|
||||
from connectors.bigquery.auth import get_metadata_token
|
||||
try:
|
||||
token = get_metadata_token()
|
||||
escaped = token.replace("'", "''")
|
||||
conn.execute(
|
||||
f"CREATE OR REPLACE SECRET bq_s (TYPE bigquery, ACCESS_TOKEN '{escaped}')"
|
||||
)
|
||||
except Exception as e:
|
||||
# Bubble up so the pool drops this entry and rebuilds.
|
||||
raise BqAccessError(
|
||||
"auth_failed",
|
||||
f"could not refresh BQ secret on pooled session: {e}",
|
||||
details={"original": str(e)},
|
||||
)
|
||||
|
||||
|
||||
def _is_pool_entry_alive(conn) -> bool:
|
||||
"""Cheap liveness probe — `SELECT 1`. Returns False on any error so
|
||||
the pool reaper drops the entry and builds a fresh one."""
|
||||
try:
|
||||
result = conn.execute("SELECT 1").fetchone()
|
||||
return result is not None and result[0] == 1
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# Module-level pool state. Process-cached (mirrors get_bq_access's lifetime).
|
||||
# Not fork-safe — single uvicorn worker process is the supported deployment
|
||||
# shape per CLAUDE.md.
|
||||
_pool: deque = deque()
|
||||
_pool_lock = threading.Lock()
|
||||
|
||||
|
||||
def _reset_session_pool_for_tests() -> None:
|
||||
"""Drop and close every pooled entry. Test helper — production code
|
||||
should not call this. Exposed so test fixtures + the existing
|
||||
test_bq_access tests can pin pre-test pool state to empty."""
|
||||
with _pool_lock:
|
||||
while _pool:
|
||||
entry = _pool.popleft()
|
||||
try:
|
||||
entry.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _default_duckdb_session_factory(projects: BqProjects):
|
||||
"""Yield a pooled in-memory DuckDB conn with bigquery extension loaded
|
||||
+ SECRET set from get_metadata_token(). Translates auth / install
|
||||
failures to BqAccessError(kind='auth_failed' or 'bq_lib_missing').
|
||||
|
||||
Pooling: amortizes the ~0.5 s INSTALL/LOAD/ATTACH cost across requests
|
||||
by keeping pre-warmed connections in a bounded deque. Acquire reuses
|
||||
an existing entry when available (refreshing its auth SECRET so
|
||||
token rotation doesn't break long-lived entries) and probes liveness
|
||||
cheaply via ``SELECT 1`` before handing it to the caller. On normal
|
||||
exit the connection returns to the pool; on exception it's closed
|
||||
instead (the underlying session may carry dirty state).
|
||||
|
||||
Pool size is ``data_source.bigquery.session_pool_size`` (default 4;
|
||||
sentinel ``0`` disables pooling entirely, matching pre-pool
|
||||
behavior). Process-cached, not fork-safe.
|
||||
|
||||
Note: `projects.billing` is not used by this factory directly — bigquery_query()
|
||||
callers pass it themselves as the first positional arg to identify the billing
|
||||
project. The factory keeps the parameter for symmetry with _default_client_factory.
|
||||
"""
|
||||
pool_size = _default_pool_size()
|
||||
|
||||
# Acquire: prefer a warm entry, fall back to fresh build.
|
||||
conn = None
|
||||
if pool_size > 0:
|
||||
while True:
|
||||
with _pool_lock:
|
||||
entry = _pool.popleft() if _pool else None
|
||||
if entry is None:
|
||||
break
|
||||
if not _is_pool_entry_alive(entry):
|
||||
# Reaper: drop broken entries.
|
||||
try:
|
||||
entry.close()
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
try:
|
||||
# Refresh the auth SECRET so a long-lived pool entry
|
||||
# doesn't keep a stale token past its TTL. Cheap when
|
||||
# the token cache is warm.
|
||||
_refresh_bq_secret(entry)
|
||||
except BqAccessError:
|
||||
try:
|
||||
entry.close()
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
# Re-apply session settings (`bq_query_timeout_ms`, …) on
|
||||
# every reuse so an operator's `/admin/server-config` change
|
||||
# propagates to pooled entries without requiring container
|
||||
# restart. Without this, a long-lived pool entry keeps the
|
||||
# value baked in at first build forever (devil's-advocate
|
||||
# R1 finding #3). `apply_bq_session_settings` is idempotent
|
||||
# and fail-soft — re-running on every acquire is cheap.
|
||||
try:
|
||||
apply_bq_session_settings(entry)
|
||||
except Exception:
|
||||
# apply_bq_session_settings is documented as never
|
||||
# raising for legitimate "extension doesn't recognise
|
||||
# setting" cases (it only logs). Defensive guard for
|
||||
# any unforeseen failure mode — keep the entry, the
|
||||
# caller's actual query may still succeed.
|
||||
pass
|
||||
conn = entry
|
||||
break
|
||||
|
||||
if conn is None:
|
||||
conn = _build_fresh_bq_session()
|
||||
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
except Exception:
|
||||
# Caller saw an exception — the conn may be in a dirty state.
|
||||
# Don't return to pool; close to release native resources.
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
else:
|
||||
# Normal exit — return to pool if there's room.
|
||||
if pool_size > 0:
|
||||
with _pool_lock:
|
||||
if len(_pool) < pool_size:
|
||||
_pool.append(conn)
|
||||
return
|
||||
# Pool disabled or full — close.
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def apply_bq_session_settings(conn) -> None:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "agnes-the-ai-analyst"
|
||||
version = "0.38.3"
|
||||
version = "0.39.0"
|
||||
description = "Agnes — AI Data Analyst platform for AI analytical systems"
|
||||
requires-python = ">=3.11,<3.14"
|
||||
license = "MIT"
|
||||
|
|
@ -20,8 +20,13 @@ dependencies = [
|
|||
"itsdangerous>=2.1.0",
|
||||
"authlib>=1.6.11",
|
||||
"argon2-cffi>=23.1.0",
|
||||
# HTTP client
|
||||
# HTTP client. `h2` enables HTTP/2 multiplexing for the persistent
|
||||
# CLI client used by `agnes pull` (one TCP connection serves N
|
||||
# concurrent parquet streams + range chunks). `cli/client.py`
|
||||
# gracefully falls back to HTTP/1.1 if h2 is missing, so this
|
||||
# extra is for performance, not correctness.
|
||||
"httpx>=0.27.0",
|
||||
"h2>=4.1.0",
|
||||
# CLI
|
||||
"typer>=0.12.0",
|
||||
"rich>=13.0.0",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Tests for connectors/bigquery/access.py — the BqAccess facade."""
|
||||
import pytest
|
||||
import threading
|
||||
|
||||
|
||||
class TestBqProjects:
|
||||
|
|
@ -36,6 +37,12 @@ class TestBqAccessError:
|
|||
"bq_forbidden": 502,
|
||||
"bq_bad_request": 400,
|
||||
"bq_upstream_error": 502,
|
||||
# User-facing class for "Response too large to return" — an
|
||||
# upstream BQ refusal, but caused by query shape (too many rows
|
||||
# to fit in a single jobs.query response) rather than auth or
|
||||
# syntax. 400 so the user sees an actionable error and not a
|
||||
# 502 that suggests "BQ is broken".
|
||||
"bq_response_too_large": 400,
|
||||
}
|
||||
assert BqAccessError.HTTP_STATUS == expected
|
||||
|
||||
|
|
@ -143,6 +150,70 @@ class TestTranslateBqError:
|
|||
translate_bq_error(ValueError("not a BQ error"), self.projects,
|
||||
bad_request_status="client_error")
|
||||
|
||||
def test_response_too_large_via_gax_bad_request(self):
|
||||
"""BQ ``responseTooLarge`` arrives as ``gax.BadRequest`` (HTTP 400
|
||||
with a specific `reason` field). Pre-fix this fell through to the
|
||||
generic ``bq_bad_request`` mapping — surfacing as a 400 with the
|
||||
raw upstream message and no actionable hint. Now it routes to a
|
||||
dedicated ``bq_response_too_large`` kind whose message tells the
|
||||
user exactly what to do (narrow WHERE / aggregate / use materialized).
|
||||
"""
|
||||
from google.api_core.exceptions import BadRequest
|
||||
from connectors.bigquery.access import translate_bq_error
|
||||
e = BadRequest("Response too large to return. Consider setting allowLargeResults to true ...")
|
||||
result = translate_bq_error(
|
||||
e, self.projects, bad_request_status="client_error",
|
||||
)
|
||||
assert result.kind == "bq_response_too_large", (
|
||||
f"got {result.kind!r}; expected dedicated mapping for "
|
||||
"'Response too large' to avoid the generic bq_bad_request 400 "
|
||||
"with no actionable hint"
|
||||
)
|
||||
# User-facing message must point at the actionable remediations,
|
||||
# not just echo the raw BQ string.
|
||||
assert "exceeded" in result.message.lower() or "too large" in result.message.lower()
|
||||
assert "where" in result.message.lower() or "aggregate" in result.message.lower() or "materialized" in result.message.lower()
|
||||
# Original upstream text preserved in details for operator debugging.
|
||||
assert "original" in result.details
|
||||
assert "Response too large" in result.details["original"]
|
||||
|
||||
def test_response_too_large_via_duckdb_native_string(self):
|
||||
"""DuckDB-native exceptions (the BQ extension's C++ HTTP path)
|
||||
carry the same 'Response too large' marker in plain ``Exception``
|
||||
messages — must classify the same way as the gax.BadRequest case."""
|
||||
from connectors.bigquery.access import translate_bq_error
|
||||
e = Exception("HTTP 400: Response too large to return.")
|
||||
result = translate_bq_error(
|
||||
e, self.projects, bad_request_status="upstream_error",
|
||||
)
|
||||
assert result.kind == "bq_response_too_large"
|
||||
|
||||
def test_response_too_large_classification_is_status_independent(self):
|
||||
"""The mapping must fire regardless of ``bad_request_status``
|
||||
(some callers route via 'upstream_error', others via 'client_error').
|
||||
It's the BQ error shape that matters, not who's calling."""
|
||||
from google.api_core.exceptions import BadRequest
|
||||
from connectors.bigquery.access import translate_bq_error
|
||||
e = BadRequest("Response too large to return")
|
||||
for status in ("client_error", "upstream_error"):
|
||||
result = translate_bq_error(e, self.projects, bad_request_status=status)
|
||||
assert result.kind == "bq_response_too_large", (
|
||||
f"bad_request_status={status!r} routed to {result.kind!r}; "
|
||||
"expected bq_response_too_large for both"
|
||||
)
|
||||
|
||||
def test_response_too_large_does_not_trigger_on_unrelated_bad_request(self):
|
||||
"""Other BadRequests (syntax errors, malformed identifiers, …)
|
||||
must keep going through the generic bq_bad_request mapping — only
|
||||
the 'Response too large' substring triggers the dedicated kind."""
|
||||
from google.api_core.exceptions import BadRequest
|
||||
from connectors.bigquery.access import translate_bq_error
|
||||
e = BadRequest("Syntax error at [1:23] near unexpected token")
|
||||
result = translate_bq_error(
|
||||
e, self.projects, bad_request_status="client_error",
|
||||
)
|
||||
assert result.kind == "bq_bad_request"
|
||||
|
||||
|
||||
class TestDefaultClientFactory:
|
||||
def test_constructs_client_with_billing_project_as_quota(self, monkeypatch):
|
||||
|
|
@ -208,8 +279,21 @@ class TestDefaultClientFactory:
|
|||
|
||||
|
||||
class TestDefaultDuckdbSessionFactory:
|
||||
def test_yields_duckdb_conn_with_secret_then_closes(self, monkeypatch):
|
||||
from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects
|
||||
def test_yields_duckdb_conn_with_secret_set_via_pool(self, monkeypatch):
|
||||
"""The pool's first acquire on an empty pool runs the full
|
||||
INSTALL/LOAD/SECRET sequence. After the with-block exits the
|
||||
connection is RETURNED to the pool (not closed) so the next
|
||||
acquire amortizes the extension-load cost.
|
||||
|
||||
Pre-pool semantics (close-on-exit) are preserved on broken
|
||||
entries + on the explicit pool-reset path; covered in
|
||||
TestBqSessionPool.
|
||||
"""
|
||||
from connectors.bigquery.access import (
|
||||
_default_duckdb_session_factory, BqProjects,
|
||||
_reset_session_pool_for_tests,
|
||||
)
|
||||
_reset_session_pool_for_tests()
|
||||
|
||||
executed_sql = []
|
||||
|
||||
|
|
@ -218,7 +302,10 @@ class TestDefaultDuckdbSessionFactory:
|
|||
self.closed = False
|
||||
def execute(self, sql, params=None):
|
||||
executed_sql.append((sql, params))
|
||||
return self
|
||||
class _Result:
|
||||
def fetchone(self_inner):
|
||||
return (1,)
|
||||
return _Result()
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
|
@ -228,19 +315,36 @@ class TestDefaultDuckdbSessionFactory:
|
|||
|
||||
with _default_duckdb_session_factory(BqProjects(billing="b", data="d")) as conn:
|
||||
assert conn is fake_conn
|
||||
assert fake_conn.closed is True
|
||||
# Pool retains the conn — close happens at pool reset / shutdown.
|
||||
assert fake_conn.closed is False
|
||||
|
||||
# Verify INSTALL/LOAD/SECRET sequence ran
|
||||
assert any("INSTALL bigquery" in sql for sql, _ in executed_sql)
|
||||
assert any("LOAD bigquery" in sql for sql, _ in executed_sql)
|
||||
assert any("CREATE OR REPLACE SECRET" in sql and "tok123" in sql for sql, _ in executed_sql)
|
||||
|
||||
# Explicit pool reset closes the retained entry.
|
||||
_reset_session_pool_for_tests()
|
||||
assert fake_conn.closed is True
|
||||
|
||||
def test_closes_on_exception_inside_with_block(self, monkeypatch):
|
||||
from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects
|
||||
"""Exceptions inside the with-block leave the underlying conn in
|
||||
an unknown state (half-completed query, dirty session); the pool
|
||||
treats it as broken and closes it rather than returning to pool.
|
||||
"""
|
||||
from connectors.bigquery.access import (
|
||||
_default_duckdb_session_factory, BqProjects,
|
||||
_reset_session_pool_for_tests,
|
||||
)
|
||||
_reset_session_pool_for_tests()
|
||||
|
||||
class FakeConn:
|
||||
closed = False
|
||||
def execute(self, *a, **kw): return self
|
||||
def execute(self, *a, **kw):
|
||||
class _Result:
|
||||
def fetchone(self_inner):
|
||||
return (1,)
|
||||
return _Result()
|
||||
def close(self): self.closed = True
|
||||
|
||||
fake_conn = FakeConn()
|
||||
|
|
@ -449,3 +553,222 @@ class TestGetBqAccess:
|
|||
assert a is b
|
||||
assert isinstance(a, BqAccess)
|
||||
assert a.projects.billing == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DuckDB BQ-extension session pool — amortizes the ~0.5 s INSTALL/LOAD/ATTACH
|
||||
# cost across requests by keeping pre-warmed DuckDB connections in a
|
||||
# bounded pool. Each acquire reuses an existing connection (refreshing the
|
||||
# auth SECRET so token rotation doesn't break long-lived entries) instead
|
||||
# of spinning up a fresh DuckDB+extension load every time.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _PoolFakeConn:
|
||||
"""Fake DuckDB connection that records executed SQL and supports
|
||||
``close()``. Used across pool tests so we can pin behavior without
|
||||
booting the real BigQuery extension."""
|
||||
_serial = 0
|
||||
|
||||
def __init__(self):
|
||||
type(self)._serial += 1
|
||||
self.id = type(self)._serial
|
||||
self.closed = False
|
||||
self.executed: list[str] = []
|
||||
|
||||
def execute(self, sql, params=None):
|
||||
self.executed.append(sql)
|
||||
# Liveness probe: SELECT 1 returns something fetchable.
|
||||
class _Result:
|
||||
def fetchone(self_inner):
|
||||
return (1,)
|
||||
def fetchall(self_inner):
|
||||
return [(1,)]
|
||||
return _Result()
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_pool(monkeypatch):
|
||||
"""Reset the BQ session pool singleton between tests so leak-detection
|
||||
assertions don't carry state."""
|
||||
from connectors.bigquery import access as bq_access_mod
|
||||
if hasattr(bq_access_mod, "_reset_session_pool_for_tests"):
|
||||
bq_access_mod._reset_session_pool_for_tests()
|
||||
monkeypatch.setattr(
|
||||
"connectors.bigquery.auth.get_metadata_token",
|
||||
lambda: "tok-pool",
|
||||
)
|
||||
yield
|
||||
if hasattr(bq_access_mod, "_reset_session_pool_for_tests"):
|
||||
bq_access_mod._reset_session_pool_for_tests()
|
||||
|
||||
|
||||
class TestBqSessionPool:
|
||||
def test_pool_reuses_connections_across_acquires(self, monkeypatch, reset_pool):
|
||||
"""Acquiring a session, releasing, then acquiring again must return
|
||||
the SAME underlying DuckDB connection — no INSTALL/LOAD overhead on
|
||||
the second request. This is the whole point of the pool."""
|
||||
from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects
|
||||
|
||||
# Each duckdb.connect() yields a fresh _PoolFakeConn so we can tell
|
||||
# them apart by id.
|
||||
connections_made = []
|
||||
def fake_connect(_path):
|
||||
c = _PoolFakeConn()
|
||||
connections_made.append(c)
|
||||
return c
|
||||
monkeypatch.setattr("duckdb.connect", fake_connect)
|
||||
|
||||
# First acquire: pool is empty, factory builds a new entry.
|
||||
with _default_duckdb_session_factory(BqProjects(billing="b", data="d")) as conn1:
|
||||
id1 = conn1.id
|
||||
|
||||
# Second acquire: pool has a warm entry, must hand back the same conn.
|
||||
with _default_duckdb_session_factory(BqProjects(billing="b", data="d")) as conn2:
|
||||
id2 = conn2.id
|
||||
|
||||
assert id1 == id2, (
|
||||
"expected the same pooled connection across two acquires; "
|
||||
f"got id1={id1}, id2={id2}"
|
||||
)
|
||||
# And we must NOT have re-INSTALLed/LOADed the extension on reuse —
|
||||
# only one duckdb.connect() call ever happened.
|
||||
assert len(connections_made) == 1, (
|
||||
f"pool re-built the conn on second acquire; created {len(connections_made)}"
|
||||
)
|
||||
|
||||
def test_pool_size_is_configurable(self, monkeypatch, reset_pool):
|
||||
"""``data_source.bigquery.session_pool_size`` controls the upper
|
||||
bound on warm entries. Above the cap, releasing extra entries
|
||||
closes them rather than retaining."""
|
||||
from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects
|
||||
|
||||
def fake_get_value(*keys, default=None):
|
||||
if keys == ("data_source", "bigquery", "session_pool_size"):
|
||||
return 2 # tiny pool
|
||||
if keys == ("data_source", "bigquery", "query_timeout_ms"):
|
||||
return 0 # don't try to SET timeout in tests
|
||||
return default
|
||||
|
||||
monkeypatch.setattr("app.instance_config.get_value", fake_get_value)
|
||||
monkeypatch.setattr("duckdb.connect", lambda _: _PoolFakeConn())
|
||||
|
||||
# Acquire 3 in parallel to force 3 simultaneous entries.
|
||||
cm1 = _default_duckdb_session_factory(BqProjects(billing="b", data="d"))
|
||||
c1 = cm1.__enter__()
|
||||
cm2 = _default_duckdb_session_factory(BqProjects(billing="b", data="d"))
|
||||
c2 = cm2.__enter__()
|
||||
cm3 = _default_duckdb_session_factory(BqProjects(billing="b", data="d"))
|
||||
c3 = cm3.__enter__()
|
||||
|
||||
# Release all three. The 3rd release should close the conn since
|
||||
# the pool already has 2.
|
||||
cm1.__exit__(None, None, None)
|
||||
cm2.__exit__(None, None, None)
|
||||
cm3.__exit__(None, None, None)
|
||||
|
||||
# At least one of the three connections must be closed (pool overflow).
|
||||
closed_count = sum(1 for c in (c1, c2, c3) if c.closed)
|
||||
assert closed_count >= 1, (
|
||||
"pool retained more than its configured size; expected at least "
|
||||
f"one close. closed_count={closed_count}"
|
||||
)
|
||||
# Pool retained at most `size` entries, so total live + closed = 3,
|
||||
# closed >= 1 means pool size <= 2.
|
||||
assert closed_count == 1
|
||||
|
||||
def test_pool_replaces_broken_connection(self, monkeypatch, reset_pool):
|
||||
"""If a pooled entry's liveness check fails on acquire (the
|
||||
underlying DuckDB conn was closed externally, BQ extension state
|
||||
corrupted, etc.), the pool must drop it and build a fresh entry —
|
||||
not hand the broken one to the caller."""
|
||||
from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects
|
||||
|
||||
# First acquire creates entry #1; we'll then mark it broken.
|
||||
all_conns: list[_PoolFakeConn] = []
|
||||
def fake_connect(_path):
|
||||
c = _PoolFakeConn()
|
||||
all_conns.append(c)
|
||||
return c
|
||||
monkeypatch.setattr("duckdb.connect", fake_connect)
|
||||
|
||||
with _default_duckdb_session_factory(BqProjects(billing="b", data="d")) as conn1:
|
||||
id1 = conn1.id
|
||||
# Simulate corruption: make execute() raise on next call.
|
||||
def broken_execute(*a, **kw):
|
||||
raise RuntimeError("connection broken")
|
||||
conn1.execute = broken_execute # type: ignore[assignment]
|
||||
|
||||
# Second acquire must skip the broken entry and build a fresh one.
|
||||
with _default_duckdb_session_factory(BqProjects(billing="b", data="d")) as conn2:
|
||||
id2 = conn2.id
|
||||
|
||||
assert id1 != id2, (
|
||||
f"expected a fresh conn after broken-pool reaper; both acquires "
|
||||
f"returned id={id1}"
|
||||
)
|
||||
assert len(all_conns) >= 2
|
||||
|
||||
def test_pool_handles_reentrant_acquires_thread_safe(self, monkeypatch, reset_pool):
|
||||
"""Concurrent acquires from multiple threads must never hand the
|
||||
same underlying DuckDB conn to two threads at once. The pool's
|
||||
lock acquires/releases are the load-bearing invariant here.
|
||||
"""
|
||||
from connectors.bigquery.access import _default_duckdb_session_factory, BqProjects
|
||||
|
||||
monkeypatch.setattr("duckdb.connect", lambda _: _PoolFakeConn())
|
||||
|
||||
active_ids: set = set()
|
||||
active_lock = threading.Lock()
|
||||
violations: list = []
|
||||
|
||||
def worker():
|
||||
for _ in range(20):
|
||||
with _default_duckdb_session_factory(
|
||||
BqProjects(billing="b", data="d"),
|
||||
) as conn:
|
||||
with active_lock:
|
||||
if conn.id in active_ids:
|
||||
violations.append(conn.id)
|
||||
active_ids.add(conn.id)
|
||||
# Hold briefly to give other threads a chance to race.
|
||||
time.sleep(0.001)
|
||||
with active_lock:
|
||||
active_ids.discard(conn.id)
|
||||
|
||||
import time
|
||||
threads = [threading.Thread(target=worker) for _ in range(4)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert not violations, (
|
||||
f"pool handed the same conn to multiple threads concurrently: "
|
||||
f"{violations}"
|
||||
)
|
||||
|
||||
def test_pool_does_not_apply_when_factory_is_injected(self, monkeypatch, reset_pool):
|
||||
"""Test fixtures that inject a custom ``duckdb_session_factory``
|
||||
(e.g. tests/conftest.py's ``bq_access`` fixture) MUST bypass the
|
||||
pool entirely — otherwise their nullcontext-wrapped fake would
|
||||
get retained between tests and corrupt downstream assertions.
|
||||
"""
|
||||
from connectors.bigquery.access import BqAccess, BqProjects
|
||||
from contextlib import contextmanager
|
||||
|
||||
sentinel = object()
|
||||
|
||||
@contextmanager
|
||||
def custom_factory(_projects):
|
||||
yield sentinel
|
||||
|
||||
bq = BqAccess(
|
||||
BqProjects(billing="b", data="d"),
|
||||
duckdb_session_factory=custom_factory,
|
||||
)
|
||||
with bq.duckdb_session() as conn:
|
||||
assert conn is sentinel
|
||||
|
|
|
|||
399
tests/test_pull_chunked.py
Normal file
399
tests/test_pull_chunked.py
Normal file
|
|
@ -0,0 +1,399 @@
|
|||
"""Tests for range-based chunked download in cli/client.py:stream_download.
|
||||
|
||||
Background — the previous diagnosis measured `agnes pull` on a single 5.1 GB
|
||||
materialized parquet at 0.29 MB/s on a corp VPN with per-flow rate-limiting;
|
||||
4 parallel range requests over the same connection sustained 1.65 MB/s
|
||||
aggregate. Existing `AGNES_PULL_PARALLELISM=4` parallelizes across files,
|
||||
not within a file, so a manifest with 1 large materialized parquet + 10
|
||||
remote tables yields 1 active worker = single-stream throughput.
|
||||
|
||||
These tests exercise the chunking code path: HEAD probe, Range-request
|
||||
splitting, fallback when the server doesn't honor ranges, cleanup on
|
||||
chunk failure, and the small-file bypass.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ── Fake HTTP layer ─────────────────────────────────────────────────────
|
||||
# The real httpx Client / AsyncClient surface is large; we mock at the
|
||||
# client-method level. Our `stream_download` should:
|
||||
# 1. Call HEAD to learn `content-length` + `accept-ranges`.
|
||||
# 2. If ranges supported and size > threshold, issue N parallel
|
||||
# `GET` with `Range: bytes=A-B`, each returning 206 + body chunk.
|
||||
# 3. Concatenate part files into the destination.
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status_code: int, headers: dict | None = None,
|
||||
body: bytes = b""):
|
||||
self.status_code = status_code
|
||||
self.headers = headers or {}
|
||||
self._body = body
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code >= 400:
|
||||
import httpx
|
||||
raise httpx.HTTPStatusError(
|
||||
f"HTTP {self.status_code}", request=None, response=self,
|
||||
)
|
||||
|
||||
def iter_bytes(self, chunk_size: int = 65536):
|
||||
# Yield in chunk_size pieces so the sink loop runs realistically.
|
||||
for i in range(0, len(self._body), chunk_size):
|
||||
yield self._body[i:i + chunk_size]
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *a):
|
||||
return False
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
"""Captures calls + returns canned responses."""
|
||||
|
||||
def __init__(self, *, body: bytes, accept_ranges: bool = True,
|
||||
reject_range_with_200: bool = False,
|
||||
fail_chunk_indices: tuple[int, ...] = (),
|
||||
head_status: int = 200):
|
||||
self._body = body
|
||||
self._accept_ranges = accept_ranges
|
||||
self._reject_range_with_200 = reject_range_with_200
|
||||
self._fail_chunk_indices = set(fail_chunk_indices)
|
||||
self._head_status = head_status
|
||||
self.head_calls = 0
|
||||
self.range_calls: list[tuple[int, int]] = []
|
||||
self.full_get_calls = 0
|
||||
self._lock = threading.Lock()
|
||||
self._chunk_attempt_counts: dict[tuple[int, int], int] = {}
|
||||
|
||||
# `stream_download` calls `client.head(path)` once to probe.
|
||||
def head(self, path: str, **kwargs):
|
||||
with self._lock:
|
||||
self.head_calls += 1
|
||||
if self._head_status >= 400:
|
||||
return _FakeResponse(self._head_status)
|
||||
headers = {"content-length": str(len(self._body))}
|
||||
if self._accept_ranges:
|
||||
headers["accept-ranges"] = "bytes"
|
||||
return _FakeResponse(200, headers=headers)
|
||||
|
||||
# `stream_download` uses `client.stream("GET", path, headers=...)`
|
||||
# for both the chunked and full-file paths. Range header presence
|
||||
# tells us which one.
|
||||
def stream(self, method: str, path: str, *, headers: dict | None = None,
|
||||
**kwargs):
|
||||
rng = (headers or {}).get("Range") or (headers or {}).get("range")
|
||||
if rng:
|
||||
# bytes=START-END
|
||||
spec = rng.split("=", 1)[1]
|
||||
start_s, end_s = spec.split("-", 1)
|
||||
start = int(start_s)
|
||||
end = int(end_s)
|
||||
with self._lock:
|
||||
self.range_calls.append((start, end))
|
||||
key = (start, end)
|
||||
attempt = self._chunk_attempt_counts.get(key, 0)
|
||||
self._chunk_attempt_counts[key] = attempt + 1
|
||||
# Determine chunk index (in order of unique starts).
|
||||
# We map by start to a stable index for fail-injection.
|
||||
chunk_idx = self._chunk_index_for_start(start)
|
||||
# Should this attempt fail? Fail only on first attempt for
|
||||
# listed indices — retry succeeds.
|
||||
if chunk_idx in self._fail_chunk_indices and attempt == 0:
|
||||
import httpx
|
||||
raise httpx.ReadError("simulated chunk failure")
|
||||
if self._reject_range_with_200:
|
||||
# Server ignored Range — returns full body with 200.
|
||||
return _FakeResponse(200, body=self._body)
|
||||
piece = self._body[start:end + 1]
|
||||
return _FakeResponse(
|
||||
206,
|
||||
headers={"content-range": f"bytes {start}-{end}/{len(self._body)}"},
|
||||
body=piece,
|
||||
)
|
||||
# Full-file GET (single-stream fallback).
|
||||
with self._lock:
|
||||
self.full_get_calls += 1
|
||||
return _FakeResponse(200, body=self._body)
|
||||
|
||||
def _chunk_index_for_start(self, start: int) -> int:
|
||||
# Unique sorted starts so fail_chunk_indices is deterministic.
|
||||
starts = sorted({s for s, _ in self.range_calls})
|
||||
try:
|
||||
return starts.index(start)
|
||||
except ValueError:
|
||||
return -1
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *a):
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
# ── Test fixtures ───────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_config_dir(tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "_cfg"
|
||||
cfg.mkdir()
|
||||
monkeypatch.setenv("AGNES_CONFIG_DIR", str(cfg))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_shared_client(monkeypatch):
|
||||
"""Reset the persistent shared httpx.Client between tests so each
|
||||
test starts from a known state. Tests that need to inject a fake
|
||||
client also stub `_get_shared_client` directly via the
|
||||
`_inject_fake_client` helper below."""
|
||||
import cli.client as cc
|
||||
if hasattr(cc, "_SHARED_CLIENT"):
|
||||
monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False)
|
||||
yield
|
||||
if hasattr(cc, "_SHARED_CLIENT"):
|
||||
monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False)
|
||||
|
||||
|
||||
def _inject_fake_client(monkeypatch, fake):
|
||||
"""Patch both client factories to return the same fake. Tests target
|
||||
`_get_shared_client` (the path stream_download actually takes) and
|
||||
also `get_client` so the fallback path also lands on the fake."""
|
||||
monkeypatch.setattr("cli.client.get_client", lambda timeout=300.0: fake)
|
||||
monkeypatch.setattr("cli.client._get_shared_client",
|
||||
lambda: fake, raising=False)
|
||||
|
||||
|
||||
# ── Tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
def test_chunked_download_success(tmp_path, monkeypatch):
|
||||
"""Server advertises ranges, file is large enough — 4 chunks, assembled
|
||||
correctly into target."""
|
||||
body = bytes(range(256)) * 2048 # 512 KB
|
||||
threshold = 1024 # 1 KB so 512 KB is "large"
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", str(threshold))
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||
|
||||
fake = _FakeClient(body=body, accept_ranges=True)
|
||||
_inject_fake_client(monkeypatch, fake)
|
||||
|
||||
from cli.client import stream_download
|
||||
target = tmp_path / "out.parquet"
|
||||
progress_bytes = []
|
||||
total = stream_download("/api/data/x/download", str(target),
|
||||
progress_callback=lambda n: progress_bytes.append(n))
|
||||
|
||||
assert total == len(body)
|
||||
assert target.read_bytes() == body
|
||||
# 4 distinct ranges issued (no overlaps; last one carries remainder).
|
||||
assert len(set(fake.range_calls)) == 4
|
||||
assert fake.head_calls == 1
|
||||
assert fake.full_get_calls == 0
|
||||
# Progress callback was called and total bytes match.
|
||||
assert sum(progress_bytes) == len(body)
|
||||
# Chunk parts cleaned up.
|
||||
leftovers = list(tmp_path.glob("*.part*"))
|
||||
assert leftovers == [], f"orphan part files: {leftovers}"
|
||||
|
||||
|
||||
def test_chunked_download_fallback_when_server_ignores_range(
|
||||
tmp_path, monkeypatch,
|
||||
):
|
||||
"""Server returns 200 instead of 206 on the first range probe — abort
|
||||
chunked path, fall back to single-stream. No corrupt output."""
|
||||
body = b"X" * 200_000
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024")
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||
|
||||
# accept_ranges=True (HEAD lies), but every Range GET returns 200
|
||||
# with the full body — that's the "server ignored Range" path.
|
||||
fake = _FakeClient(body=body, accept_ranges=True,
|
||||
reject_range_with_200=True)
|
||||
_inject_fake_client(monkeypatch, fake)
|
||||
|
||||
from cli.client import stream_download
|
||||
target = tmp_path / "out.bin"
|
||||
total = stream_download("/api/data/x/download", str(target))
|
||||
|
||||
assert total == len(body)
|
||||
assert target.read_bytes() == body
|
||||
# Fell back to a single full-body GET.
|
||||
assert fake.full_get_calls >= 1
|
||||
|
||||
|
||||
def test_small_file_uses_single_stream_path(tmp_path, monkeypatch):
|
||||
"""Below threshold → no HEAD probe needed (or HEAD short-circuits),
|
||||
no Range requests, plain single-stream download."""
|
||||
body = b"x" * 500 # tiny
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "10000") # 10 KB
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||
|
||||
fake = _FakeClient(body=body, accept_ranges=True)
|
||||
_inject_fake_client(monkeypatch, fake)
|
||||
|
||||
from cli.client import stream_download
|
||||
target = tmp_path / "out.bin"
|
||||
total = stream_download("/api/data/x/download", str(target))
|
||||
|
||||
assert total == len(body)
|
||||
assert target.read_bytes() == body
|
||||
assert fake.range_calls == [], "small file must not split into ranges"
|
||||
assert fake.full_get_calls >= 1
|
||||
|
||||
|
||||
def test_chunked_download_no_accept_ranges_falls_back(tmp_path, monkeypatch):
|
||||
"""HEAD doesn't advertise byte-range support → skip chunked path,
|
||||
plain single-stream."""
|
||||
body = b"y" * 200_000
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024")
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||
|
||||
fake = _FakeClient(body=body, accept_ranges=False)
|
||||
_inject_fake_client(monkeypatch, fake)
|
||||
|
||||
from cli.client import stream_download
|
||||
target = tmp_path / "out.bin"
|
||||
total = stream_download("/api/data/x/download", str(target))
|
||||
|
||||
assert total == len(body)
|
||||
assert target.read_bytes() == body
|
||||
assert fake.range_calls == []
|
||||
assert fake.full_get_calls >= 1
|
||||
|
||||
|
||||
def test_chunked_download_one_chunk_retries_then_succeeds(
|
||||
tmp_path, monkeypatch,
|
||||
):
|
||||
"""One chunk fails on first attempt; retry path completes the file."""
|
||||
body = bytes(range(256)) * 1024 # 256 KB
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024")
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||
monkeypatch.setenv("AGNES_STREAM_RETRIES", "2")
|
||||
|
||||
fake = _FakeClient(body=body, accept_ranges=True,
|
||||
fail_chunk_indices=(1,)) # second chunk blips once
|
||||
_inject_fake_client(monkeypatch, fake)
|
||||
|
||||
from cli.client import stream_download
|
||||
target = tmp_path / "out.bin"
|
||||
total = stream_download("/api/data/x/download", str(target))
|
||||
|
||||
assert total == len(body)
|
||||
assert target.read_bytes() == body
|
||||
# Cleanup of all part files.
|
||||
assert list(tmp_path.glob("*.part*")) == []
|
||||
|
||||
|
||||
def test_chunked_download_failure_cleans_up_part_files(tmp_path, monkeypatch):
|
||||
"""All retries exhausted on a chunk → no destination file, no orphan
|
||||
part files."""
|
||||
body = b"z" * 200_000
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024")
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||
monkeypatch.setenv("AGNES_STREAM_RETRIES", "0")
|
||||
|
||||
# Inject a permanent failure on chunk 2 (retries=0 → first failure
|
||||
# is fatal).
|
||||
class _ChronicFail(_FakeClient):
|
||||
def stream(self, method, path, *, headers=None, **kwargs):
|
||||
rng = (headers or {}).get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
start = int(spec.split("-", 1)[0])
|
||||
# Permanently fail the chunk starting at exactly half.
|
||||
if start >= len(body) // 4 and start <= len(body) // 2:
|
||||
import httpx
|
||||
raise httpx.ReadError("permanent")
|
||||
return super().stream(method, path, headers=headers, **kwargs)
|
||||
return super().stream(method, path, headers=headers, **kwargs)
|
||||
|
||||
fake = _ChronicFail(body=body, accept_ranges=True)
|
||||
_inject_fake_client(monkeypatch, fake)
|
||||
|
||||
from cli.client import stream_download
|
||||
target = tmp_path / "out.bin"
|
||||
with pytest.raises(Exception):
|
||||
stream_download("/api/data/x/download", str(target))
|
||||
|
||||
assert not target.exists(), "no destination file after total failure"
|
||||
# No orphan parts.
|
||||
assert list(tmp_path.glob("*.part*")) == []
|
||||
assert not (tmp_path / "out.bin.tmp").exists()
|
||||
|
||||
|
||||
def test_progress_callback_aggregates_across_chunks(tmp_path, monkeypatch):
|
||||
"""The progress callback should fire with byte deltas summing to the
|
||||
full file across all chunks — caller treats one file as one task."""
|
||||
body = bytes(range(256)) * 4096 # 1 MB
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_THRESHOLD_BYTES", "1024")
|
||||
monkeypatch.setenv("AGNES_PULL_CHUNK_PARALLELISM", "4")
|
||||
|
||||
fake = _FakeClient(body=body, accept_ranges=True)
|
||||
_inject_fake_client(monkeypatch, fake)
|
||||
|
||||
from cli.client import stream_download
|
||||
target = tmp_path / "out.bin"
|
||||
advances = []
|
||||
stream_download("/api/data/x/download", str(target),
|
||||
progress_callback=lambda n: advances.append(n))
|
||||
assert sum(advances) == len(body)
|
||||
|
||||
|
||||
def test_dead_pid_leftovers_are_reaped(tmp_path, monkeypatch):
|
||||
"""Devil's-advocate R3 finding #1: PID-suffixed `<target>.{pid}.tmp`
|
||||
and `.partN` files from a SIGKILL'd previous pull must be reaped on
|
||||
the next pull, otherwise they accumulate on disk indefinitely.
|
||||
|
||||
PID 1 (init) is always alive, so a file with pid=1 must NOT be
|
||||
reaped. PID 99999999 (~10⁸) is essentially guaranteed not-alive on
|
||||
any modern Linux/macOS — used as the dead-PID marker.
|
||||
"""
|
||||
target = tmp_path / "out.bin"
|
||||
|
||||
# Live-PID leftover (pid=1 = init, always alive). Must NOT be reaped.
|
||||
live_path = tmp_path / "out.bin.1.tmp"
|
||||
live_path.write_bytes(b"live process leftover")
|
||||
|
||||
# Dead-PID leftovers — both .tmp and .part0 forms.
|
||||
dead_tmp = tmp_path / "out.bin.99999999.tmp"
|
||||
dead_tmp.write_bytes(b"dead process leftover tmp")
|
||||
dead_part = tmp_path / "out.bin.99999999.part0"
|
||||
dead_part.write_bytes(b"dead process leftover part")
|
||||
|
||||
# Bare-name leftover (no PID suffix) — pre-existing pattern, NOT
|
||||
# touched by the new reaper. Reaper only matches `.{digits}.tmp`
|
||||
# / `.{digits}.partN` exactly.
|
||||
bare_tmp = tmp_path / "out.bin.tmp"
|
||||
bare_tmp.write_bytes(b"bare leftover")
|
||||
|
||||
from cli.client import _reap_dead_pid_leftovers
|
||||
_reap_dead_pid_leftovers(str(target))
|
||||
|
||||
assert live_path.exists(), "live-PID leftover must be preserved"
|
||||
assert not dead_tmp.exists(), "dead-PID .tmp must be reaped"
|
||||
assert not dead_part.exists(), "dead-PID .partN must be reaped"
|
||||
assert bare_tmp.exists(), "bare-name leftover is out of scope for the reaper"
|
||||
|
||||
|
||||
def test_reap_handles_garbage_in_filename(tmp_path):
|
||||
"""Files in the parquet dir whose names happen to glob-match but
|
||||
don't conform to the PID-suffix shape must be skipped without
|
||||
raising."""
|
||||
target = tmp_path / "out.bin"
|
||||
weird = tmp_path / "out.bin.garbage.tmp"
|
||||
weird.write_bytes(b"x")
|
||||
|
||||
from cli.client import _reap_dead_pid_leftovers
|
||||
# Must not raise even though the filename has no integer PID.
|
||||
_reap_dead_pid_leftovers(str(target))
|
||||
assert weird.exists(), "non-PID-shaped file must not be reaped"
|
||||
123
tests/test_pull_progress.py
Normal file
123
tests/test_pull_progress.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
"""Tests for `agnes pull` progress UX (Change 3).
|
||||
|
||||
The Rich progress bar handles the TTY case fine, but Claude Code's
|
||||
SessionStart context — and any hook running `agnes pull` non-interactively —
|
||||
has stderr connected to a pipe, not a TTY. In that case Rich either
|
||||
suppresses output entirely or emits raw ANSI noise into the consumer's
|
||||
log. Goal: when the caller asks for progress and stderr is not a TTY,
|
||||
emit a plain-text per-10%-or-30s update so the operator gets *some*
|
||||
signal instead of multi-minute silence.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_config_dir(tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "_cfg"
|
||||
cfg.mkdir()
|
||||
monkeypatch.setenv("AGNES_CONFIG_DIR", str(cfg))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_pull_io(monkeypatch):
|
||||
"""Stub the manifest + memory + download endpoints so run_pull can
|
||||
execute end-to-end with a fake parquet write per table."""
|
||||
canned_manifest = {
|
||||
"tables": {
|
||||
"tbl_big": {"hash": "h1", "rows": 0, "size_bytes": 1_000_000},
|
||||
},
|
||||
}
|
||||
canned_memory = {"mandatory": [], "approved": []}
|
||||
|
||||
def _api_get(path, *args, **kwargs):
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
if path == "/api/sync/manifest":
|
||||
resp.json.return_value = canned_manifest
|
||||
elif path == "/api/memory/bundle":
|
||||
resp.json.return_value = canned_memory
|
||||
resp.raise_for_status = lambda: None
|
||||
return resp
|
||||
|
||||
def _stream_download(path, target_path, progress_callback=None):
|
||||
# Simulate a chunked download: emit progress in 4 increments
|
||||
# totaling the announced size.
|
||||
total = 1_000_000
|
||||
slices = [total // 4] * 3 + [total - 3 * (total // 4)]
|
||||
Path(target_path).write_bytes(b"PAR1" + b"\x00" * 1000 + b"PAR1")
|
||||
if progress_callback:
|
||||
for s in slices:
|
||||
progress_callback(s)
|
||||
return total
|
||||
|
||||
monkeypatch.setattr("cli.lib.pull.api_get", _api_get, raising=False)
|
||||
monkeypatch.setattr("cli.lib.pull.stream_download", _stream_download,
|
||||
raising=False)
|
||||
monkeypatch.setattr("cli.lib.pull._is_valid_parquet", lambda p: True,
|
||||
raising=False)
|
||||
monkeypatch.setattr("cli.lib.pull._file_md5", lambda p: "h1", raising=False)
|
||||
|
||||
|
||||
def test_textual_progress_when_stderr_is_not_tty(
|
||||
tmp_path, fake_pull_io, monkeypatch, capsys,
|
||||
):
|
||||
"""Non-TTY stderr → emit a plain-text progress line per file."""
|
||||
# Force the non-TTY branch even if pytest's fake stderr is a tty.
|
||||
monkeypatch.setattr("sys.stderr.isatty", lambda: False, raising=False)
|
||||
|
||||
from cli.lib.pull import run_pull
|
||||
result = run_pull(
|
||||
server_url="http://x", token="t", workspace=tmp_path,
|
||||
show_progress=True,
|
||||
)
|
||||
captured = capsys.readouterr()
|
||||
# Some indication of the file + bytes ran; we don't pin exact format.
|
||||
assert "tbl_big" in captured.err
|
||||
assert result.tables_updated == 1
|
||||
# No raw ANSI escape sequences in the textual fallback.
|
||||
assert "\x1b[" not in captured.err.split("tbl_big")[0]
|
||||
|
||||
|
||||
def test_no_progress_output_when_show_progress_is_false(
|
||||
tmp_path, fake_pull_io, monkeypatch, capsys,
|
||||
):
|
||||
"""`show_progress=False` (the SessionStart hook path) emits no
|
||||
progress text on stderr in either TTY or non-TTY mode."""
|
||||
monkeypatch.setattr("sys.stderr.isatty", lambda: False, raising=False)
|
||||
|
||||
from cli.lib.pull import run_pull
|
||||
run_pull(
|
||||
server_url="http://x", token="t", workspace=tmp_path,
|
||||
show_progress=False,
|
||||
)
|
||||
captured = capsys.readouterr()
|
||||
assert "tbl_big" not in captured.err
|
||||
|
||||
|
||||
def test_textual_progress_emits_at_completion(
|
||||
tmp_path, fake_pull_io, monkeypatch, capsys,
|
||||
):
|
||||
"""At least one final completion line gets emitted per file even if
|
||||
the throttle window doesn't trigger mid-file."""
|
||||
monkeypatch.setattr("sys.stderr.isatty", lambda: False, raising=False)
|
||||
from cli.lib.pull import run_pull
|
||||
run_pull(
|
||||
server_url="http://x", token="t", workspace=tmp_path,
|
||||
show_progress=True,
|
||||
)
|
||||
captured = capsys.readouterr()
|
||||
# Final line marks the file as done — either "100%" or a "✓ tbl_big" /
|
||||
# "tbl_big done" indicator. We accept any final-completion form.
|
||||
assert (
|
||||
"100%" in captured.err
|
||||
or "done" in captured.err.lower()
|
||||
or "complete" in captured.err.lower()
|
||||
)
|
||||
85
tests/test_pull_shared_client.py
Normal file
85
tests/test_pull_shared_client.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
"""Tests for the persistent HTTP/2-capable shared client (Change 2).
|
||||
|
||||
`agnes pull` issues N stream_download calls — one per parquet. Without
|
||||
pooling, each call performs a fresh TLS handshake. The shared client is
|
||||
created lazily once per process and closed at exit; HTTP/2 (when `h2` is
|
||||
available) further multiplexes all chunk Range requests over a single
|
||||
TCP connection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_config_dir(tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "_cfg"
|
||||
cfg.mkdir()
|
||||
monkeypatch.setenv("AGNES_CONFIG_DIR", str(cfg))
|
||||
# Some dev environments point SSL_CERT_FILE / REQUESTS_CA_BUNDLE at a
|
||||
# corp-CA bundle that may not exist on every laptop running the test
|
||||
# suite. Clear those so httpx.Client() construction in the shared-
|
||||
# client path can build a default SSL context without trying to load
|
||||
# a missing PEM file.
|
||||
for var in ("SSL_CERT_FILE", "REQUESTS_CA_BUNDLE", "CURL_CA_BUNDLE"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_shared(monkeypatch):
|
||||
import cli.client as cc
|
||||
cc._close_shared_client()
|
||||
monkeypatch.setattr(cc, "_SHARED_CLIENT", None, raising=False)
|
||||
yield
|
||||
cc._close_shared_client()
|
||||
|
||||
|
||||
def test_get_shared_client_is_cached(monkeypatch):
|
||||
"""Multiple calls return the same client instance — no fresh TLS
|
||||
handshake per stream_download invocation."""
|
||||
monkeypatch.setenv("AGNES_SERVER", "https://x.example.test")
|
||||
from cli.client import _get_shared_client
|
||||
c1 = _get_shared_client()
|
||||
c2 = _get_shared_client()
|
||||
assert c1 is c2, "shared client must be a single instance"
|
||||
|
||||
|
||||
def test_get_shared_client_falls_back_when_http2_unavailable(monkeypatch):
|
||||
"""If httpx raises during HTTP/2 client construction (e.g. `h2` not
|
||||
installed in the runtime env), we must gracefully build a HTTP/1.1
|
||||
client instead of crashing the pull."""
|
||||
import httpx
|
||||
|
||||
monkeypatch.setenv("AGNES_SERVER", "https://x.example.test")
|
||||
import cli.client as cc
|
||||
|
||||
real_client = httpx.Client
|
||||
|
||||
construction_calls = []
|
||||
|
||||
def fake_client(*args, **kwargs):
|
||||
construction_calls.append(kwargs.copy())
|
||||
if kwargs.get("http2") is True:
|
||||
raise ImportError("Using http2=True, but the 'h2' package is not installed")
|
||||
return real_client(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(httpx, "Client", fake_client)
|
||||
|
||||
client = cc._get_shared_client()
|
||||
assert client is not None
|
||||
# Two construction attempts: first http2=True (raised), second falls
|
||||
# back to HTTP/1.1 (no http2 kwarg).
|
||||
assert construction_calls[0].get("http2") is True
|
||||
assert construction_calls[1].get("http2") is None or construction_calls[1].get("http2") is False
|
||||
cc._close_shared_client()
|
||||
|
||||
|
||||
def test_close_shared_client_idempotent(monkeypatch):
|
||||
"""Calling close twice (once explicitly, once via atexit) must not
|
||||
raise."""
|
||||
monkeypatch.setenv("AGNES_SERVER", "https://x.example.test")
|
||||
from cli.client import _get_shared_client, _close_shared_client
|
||||
_get_shared_client()
|
||||
_close_shared_client()
|
||||
_close_shared_client() # second close is a no-op
|
||||
701
tests/test_query_remote_rewrite.py
Normal file
701
tests/test_query_remote_rewrite.py
Normal file
|
|
@ -0,0 +1,701 @@
|
|||
"""Unit tests for ``_rewrite_user_sql_for_bigquery_query``.
|
||||
|
||||
The helper rewrites user SQL referencing query_mode='remote' BigQuery
|
||||
tables so the entire query ships to BQ via the DuckDB BQ extension's
|
||||
``bigquery_query(<project>, <sql>)`` UDF — engaging WHERE / SELECT /
|
||||
LIMIT predicate pushdown instead of falling through to ATTACH-catalog
|
||||
mode (which opens a Storage Read API session over the whole table).
|
||||
|
||||
These tests pin down each conservative-skip rule plus the happy-path
|
||||
rewrites. Edge cases (CTE shadowing, double-wrap, mixed-source JOIN)
|
||||
are intentionally explicit so a future refactor doesn't quietly
|
||||
loosen the guard.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test infrastructure: an in-memory DuckDB seeded with table_registry rows
|
||||
# matching the shapes the production registry produces. Avoids the full app
|
||||
# bootstrap path; the rewriter only needs ``conn.execute("SELECT * FROM
|
||||
# table_registry ...")`` to resolve names.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def seeded_registry(tmp_path, monkeypatch):
|
||||
"""Build a fresh ``system.duckdb`` in tmp_path with the schema migrated.
|
||||
|
||||
Returns the open connection so tests can pass it to the rewriter.
|
||||
Cleanup is automatic via tmp_path teardown — but we close the
|
||||
open singleton handle first so a different DATA_DIR in the next
|
||||
test doesn't see the previous tmp's lock.
|
||||
"""
|
||||
from src.db import get_system_db, close_system_db
|
||||
|
||||
monkeypatch.setenv("DATA_DIR", str(tmp_path))
|
||||
(tmp_path / "state").mkdir(parents=True, exist_ok=True)
|
||||
close_system_db()
|
||||
conn = get_system_db()
|
||||
yield conn
|
||||
close_system_db()
|
||||
|
||||
|
||||
def _register_bq_remote(conn, *, table_id, name, bucket, source_table):
|
||||
from src.repositories.table_registry import TableRegistryRepository
|
||||
TableRegistryRepository(conn).register(
|
||||
id=table_id,
|
||||
name=name,
|
||||
source_type="bigquery",
|
||||
bucket=bucket,
|
||||
source_table=source_table,
|
||||
query_mode="remote",
|
||||
)
|
||||
|
||||
|
||||
def _register_local(conn, *, table_id, name, source_type="keboola"):
|
||||
from src.repositories.table_registry import TableRegistryRepository
|
||||
TableRegistryRepository(conn).register(
|
||||
id=table_id,
|
||||
name=name,
|
||||
source_type=source_type,
|
||||
bucket="bkt",
|
||||
source_table=name,
|
||||
query_mode="local",
|
||||
)
|
||||
|
||||
|
||||
def _set_bq_project(monkeypatch, project="test-prj", billing=None):
|
||||
"""Stub get_bq_access so the rewriter sees a real-looking project ID.
|
||||
|
||||
`project` configures the data project (used in backtick paths).
|
||||
`billing` (when provided) configures a different billing project so
|
||||
cross-project deployments can be exercised; defaults to `project`
|
||||
for the single-project case."""
|
||||
from connectors.bigquery.access import BqAccess, BqProjects, get_bq_access
|
||||
bq = BqAccess(
|
||||
BqProjects(billing=billing or project, data=project),
|
||||
client_factory=lambda projects: object(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.api.query.get_bq_access",
|
||||
lambda: bq,
|
||||
raising=False,
|
||||
)
|
||||
get_bq_access.cache_clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy-path rewrites
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_simple_select_where_against_one_bq_table_rewrites(seeded_registry, monkeypatch):
|
||||
"""Single-table SELECT-WHERE against a registered BQ remote row →
|
||||
full SQL wrapped in ``bigquery_query('project', '<rewritten>')``.
|
||||
The bare-name reference gets translated to BQ-native backtick form."""
|
||||
from app.api.query import _rewrite_user_sql_for_bigquery_query
|
||||
_register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue",
|
||||
bucket="fin", source_table="ue")
|
||||
_set_bq_project(monkeypatch, "test-prj")
|
||||
|
||||
rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
"SELECT count(*) FROM ue WHERE event_date = '2026-01-01'",
|
||||
seeded_registry,
|
||||
)
|
||||
|
||||
assert did_rewrite is True
|
||||
# Outer wrap must be a single bigquery_query() FROM-source.
|
||||
assert "bigquery_query(" in rewritten
|
||||
assert "test-prj" in rewritten
|
||||
# Inner SQL: bare name rewritten to backticked BQ-native path.
|
||||
assert "`test-prj.fin.ue`" in rewritten
|
||||
# Inner SQL is dollar-quoted (`$bqq_inner$ ... $bqq_inner$`), so
|
||||
# single quotes inside the WHERE predicate remain literal — no
|
||||
# doubling, no backslash escaping. Verifies the safer embedding form
|
||||
# introduced after the code review caught naive single-quote-only
|
||||
# escape doubling missing DuckDB backslash sequences.
|
||||
assert "$bqq_inner$" in rewritten
|
||||
assert "event_date = '2026-01-01'" in rewritten
|
||||
|
||||
|
||||
def test_direct_bq_path_rewrites(seeded_registry, monkeypatch):
|
||||
"""User wrote the direct ``bq."ds"."tbl"`` form. The rewriter must
|
||||
still translate to BQ-native backtick form before wrapping."""
|
||||
from app.api.query import _rewrite_user_sql_for_bigquery_query
|
||||
_register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue",
|
||||
bucket="fin", source_table="ue")
|
||||
_set_bq_project(monkeypatch, "test-prj")
|
||||
|
||||
rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
'SELECT * FROM bq."fin"."ue" LIMIT 10',
|
||||
seeded_registry,
|
||||
)
|
||||
|
||||
assert did_rewrite is True
|
||||
assert "bigquery_query(" in rewritten
|
||||
assert "`test-prj.fin.ue`" in rewritten
|
||||
# Original duckdb-flavor path must NOT remain (it'd parse-fail under BQ).
|
||||
assert 'bq."fin"."ue"' not in rewritten
|
||||
|
||||
|
||||
def test_cte_referencing_bq_table_rewrites_inside_cte(seeded_registry, monkeypatch):
|
||||
"""A WITH clause whose body references a BQ table must rewrite that
|
||||
inner reference; the wrapping happens at the top level so BQ sees a
|
||||
valid BQ-flavor CTE."""
|
||||
from app.api.query import _rewrite_user_sql_for_bigquery_query
|
||||
_register_bq_remote(seeded_registry, table_id="bq.fin.orders", name="orders",
|
||||
bucket="fin", source_table="orders")
|
||||
_set_bq_project(monkeypatch, "test-prj")
|
||||
|
||||
rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
"WITH x AS (SELECT id FROM orders WHERE total > 0) SELECT count(*) FROM x",
|
||||
seeded_registry,
|
||||
)
|
||||
assert did_rewrite is True
|
||||
# Inner reference is rewritten.
|
||||
assert "`test-prj.fin.orders`" in rewritten
|
||||
# The whole thing is wrapped — bigquery_query is the outermost FROM.
|
||||
assert rewritten.lower().count("bigquery_query(") == 1
|
||||
|
||||
|
||||
def test_subquery_referencing_bq_table_rewrites(seeded_registry, monkeypatch):
|
||||
"""Subquery in FROM position — same handling as a CTE: rewrite the
|
||||
inner table reference, wrap the whole at the top."""
|
||||
from app.api.query import _rewrite_user_sql_for_bigquery_query
|
||||
_register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue",
|
||||
bucket="fin", source_table="ue")
|
||||
_set_bq_project(monkeypatch, "test-prj")
|
||||
|
||||
rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
"SELECT s.cnt FROM (SELECT count(*) AS cnt FROM ue) s",
|
||||
seeded_registry,
|
||||
)
|
||||
assert did_rewrite is True
|
||||
assert "`test-prj.fin.ue`" in rewritten
|
||||
assert rewritten.lower().count("bigquery_query(") == 1
|
||||
|
||||
|
||||
def test_multiple_bq_tables_one_project_combine(seeded_registry, monkeypatch):
|
||||
"""Two registered BQ tables in the same project → single
|
||||
``bigquery_query()`` wraps the whole SQL with both refs rewritten
|
||||
inline. No separate parallel calls."""
|
||||
from app.api.query import _rewrite_user_sql_for_bigquery_query
|
||||
_register_bq_remote(seeded_registry, table_id="bq.fin.orders", name="orders",
|
||||
bucket="fin", source_table="orders")
|
||||
_register_bq_remote(seeded_registry, table_id="bq.fin.users", name="users",
|
||||
bucket="fin", source_table="users")
|
||||
_set_bq_project(monkeypatch, "test-prj")
|
||||
|
||||
rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
"SELECT u.id, count(o.id) "
|
||||
"FROM users u JOIN orders o ON u.id = o.user_id "
|
||||
"GROUP BY u.id",
|
||||
seeded_registry,
|
||||
)
|
||||
assert did_rewrite is True
|
||||
# Both rewritten.
|
||||
assert "`test-prj.fin.users`" in rewritten
|
||||
assert "`test-prj.fin.orders`" in rewritten
|
||||
# Single wrap.
|
||||
assert rewritten.lower().count("bigquery_query(") == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Conservative-skip cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_join_bq_to_local_skips_rewrite(seeded_registry, monkeypatch):
|
||||
"""A JOIN between a BQ table and a local-mode (Keboola/Jira) table
|
||||
is a cross-source query — wrapping it in bigquery_query() would lose
|
||||
the local table. The rewriter must fall through to the ATTACH-catalog
|
||||
path (slow but correct).
|
||||
"""
|
||||
from app.api.query import _rewrite_user_sql_for_bigquery_query
|
||||
_register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue",
|
||||
bucket="fin", source_table="ue")
|
||||
_register_local(seeded_registry, table_id="kbc.in.local_orders",
|
||||
name="local_orders")
|
||||
_set_bq_project(monkeypatch, "test-prj")
|
||||
|
||||
user_sql = (
|
||||
"SELECT u.id, lo.total "
|
||||
"FROM ue u JOIN local_orders lo ON u.id = lo.user_id"
|
||||
)
|
||||
rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
user_sql, seeded_registry,
|
||||
)
|
||||
assert did_rewrite is False
|
||||
assert rewritten == user_sql # untouched
|
||||
|
||||
|
||||
def test_no_bq_tables_passes_through(seeded_registry, monkeypatch):
|
||||
"""User SQL referencing only local-source tables → no rewrite,
|
||||
no log spam, original SQL returned."""
|
||||
from app.api.query import _rewrite_user_sql_for_bigquery_query
|
||||
_register_local(seeded_registry, table_id="kbc.in.orders", name="orders")
|
||||
_set_bq_project(monkeypatch, "test-prj")
|
||||
|
||||
user_sql = "SELECT * FROM orders WHERE id = 1"
|
||||
rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
user_sql, seeded_registry,
|
||||
)
|
||||
assert did_rewrite is False
|
||||
assert rewritten == user_sql
|
||||
|
||||
|
||||
def test_already_contains_bigquery_query_passes_through(seeded_registry, monkeypatch):
|
||||
"""User SQL already calls bigquery_query() — never double-wrap.
|
||||
|
||||
Note: the /api/query endpoint blocks ``bigquery_query`` in user SQL
|
||||
via the keyword denylist, so this scenario can't reach the rewriter
|
||||
in production today. Defensive guard for callers from other paths.
|
||||
"""
|
||||
from app.api.query import _rewrite_user_sql_for_bigquery_query
|
||||
_register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue",
|
||||
bucket="fin", source_table="ue")
|
||||
_set_bq_project(monkeypatch, "test-prj")
|
||||
|
||||
user_sql = (
|
||||
"SELECT * FROM bigquery_query('test-prj', 'SELECT * FROM `test-prj.fin.ue`')"
|
||||
)
|
||||
rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
user_sql, seeded_registry,
|
||||
)
|
||||
assert did_rewrite is False
|
||||
assert rewritten == user_sql
|
||||
|
||||
|
||||
def test_unconfigured_bq_project_skips(seeded_registry, monkeypatch):
|
||||
"""If get_bq_access() is the not-configured sentinel (data=''),
|
||||
don't rewrite — there's no project to fill into bigquery_query()."""
|
||||
from app.api.query import _rewrite_user_sql_for_bigquery_query
|
||||
_register_bq_remote(seeded_registry, table_id="bq.fin.ue", name="ue",
|
||||
bucket="fin", source_table="ue")
|
||||
|
||||
# Override to sentinel (empty data project).
|
||||
from connectors.bigquery.access import BqAccess, BqProjects, get_bq_access
|
||||
monkeypatch.setattr(
|
||||
"app.api.query.get_bq_access",
|
||||
lambda: BqAccess(BqProjects(billing="", data="")),
|
||||
raising=False,
|
||||
)
|
||||
get_bq_access.cache_clear()
|
||||
|
||||
user_sql = "SELECT * FROM ue"
|
||||
rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
user_sql, seeded_registry,
|
||||
)
|
||||
assert did_rewrite is False
|
||||
assert rewritten == user_sql
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backwards-compat: dry-run helper still available + behaves the same
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_existing_dry_run_helper_still_callable():
|
||||
"""The original ``_rewrite_user_sql_for_bq_dry_run`` is now a thin
|
||||
wrapper around the shared core rewriter (Pass 1 + Pass 2). Callers
|
||||
that pass an explicit ``project`` argument keep working unchanged.
|
||||
"""
|
||||
from app.api.query import _rewrite_user_sql_for_bq_dry_run
|
||||
|
||||
rewritten = _rewrite_user_sql_for_bq_dry_run(
|
||||
sql="SELECT * FROM ue",
|
||||
name_lookups=[("ue", "fin", "ue")],
|
||||
project="some-prj",
|
||||
)
|
||||
assert "`some-prj.fin.ue`" in rewritten
|
||||
# The dry-run helper does NOT add a bigquery_query() wrapper; that's
|
||||
# only the new execution-path helper's job.
|
||||
assert "bigquery_query(" not in rewritten
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-end: the /api/query handler must invoke the rewriter and execute
|
||||
# the rewritten SQL (not the original) when there's a BQ-remote table.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _auth(token: str) -> dict:
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def _register_bq_remote_row(name: str, bucket: str, source_table: str) -> None:
|
||||
from src.db import get_system_db
|
||||
from src.repositories.table_registry import TableRegistryRepository
|
||||
sys_conn = get_system_db()
|
||||
try:
|
||||
TableRegistryRepository(sys_conn).register(
|
||||
id=f"bq.{bucket}.{source_table}",
|
||||
name=name,
|
||||
source_type="bigquery",
|
||||
bucket=bucket,
|
||||
source_table=source_table,
|
||||
query_mode="remote",
|
||||
)
|
||||
finally:
|
||||
sys_conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stub_bq_for_endpoint(monkeypatch):
|
||||
"""Stub _bq_dry_run_bytes + get_bq_access at the endpoint level so the
|
||||
cap-guard sees a real-looking BQ project but doesn't issue real RPCs.
|
||||
"""
|
||||
monkeypatch.setattr(
|
||||
"app.api.query._bq_dry_run_bytes",
|
||||
lambda *a, **k: 1024, # tiny — pass cap
|
||||
raising=False,
|
||||
)
|
||||
|
||||
class _FakeProjects:
|
||||
data = "test-data-prj"
|
||||
billing = "test-billing-prj"
|
||||
|
||||
class _FakeBqAccess:
|
||||
projects = _FakeProjects()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.api.query.get_bq_access",
|
||||
lambda: _FakeBqAccess(),
|
||||
raising=False,
|
||||
)
|
||||
|
||||
|
||||
def test_endpoint_executes_rewritten_sql_against_analytics(
|
||||
seeded_app, stub_bq_for_endpoint, monkeypatch,
|
||||
):
|
||||
"""The /api/query handler must call ``analytics.execute(rewritten_sql)``
|
||||
— NOT the user's original SQL — when a BQ-remote table is referenced.
|
||||
Capture what reaches DuckDB and assert the bigquery_query() wrap is
|
||||
present.
|
||||
"""
|
||||
_register_bq_remote_row("ue", "fin", "ue")
|
||||
|
||||
# Capture analytics.execute calls. The handler does
|
||||
# `analytics = get_analytics_db_readonly(); analytics.execute(sql)`,
|
||||
# so we patch the connection factory to return a stub.
|
||||
captured = {"sql": None}
|
||||
|
||||
class _StubAnalytics:
|
||||
description = [("c0",)]
|
||||
def execute(self, sql, *args, **kwargs):
|
||||
captured["sql"] = sql
|
||||
class _R:
|
||||
def fetchmany(self, _n):
|
||||
return []
|
||||
return _R()
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.api.query.get_analytics_db_readonly",
|
||||
lambda: _StubAnalytics(),
|
||||
raising=False,
|
||||
)
|
||||
|
||||
c = seeded_app["client"]
|
||||
token = seeded_app["admin_token"]
|
||||
r = c.post(
|
||||
"/api/query",
|
||||
json={"sql": "SELECT count(*) FROM ue WHERE country = 'CZ'"},
|
||||
headers=_auth(token),
|
||||
)
|
||||
assert r.status_code == 200, r.json()
|
||||
sent = captured["sql"]
|
||||
assert sent is not None, "analytics.execute was never called"
|
||||
assert "bigquery_query(" in sent, (
|
||||
f"endpoint did not wrap user SQL in bigquery_query(); sent: {sent!r}"
|
||||
)
|
||||
assert "test-data-prj" in sent
|
||||
assert "`test-data-prj.fin.ue`" in sent
|
||||
|
||||
|
||||
def test_endpoint_passes_original_sql_when_no_bq_table(
|
||||
seeded_app, stub_bq_for_endpoint, monkeypatch,
|
||||
):
|
||||
"""For queries that don't touch any BQ-remote registered name, the
|
||||
handler must pass the original SQL through unchanged — the
|
||||
ATTACH-catalog path handles local-source tables natively and any
|
||||
rewrite would be wasted work."""
|
||||
captured = {"sql": None}
|
||||
|
||||
class _StubAnalytics:
|
||||
description = [("c0",)]
|
||||
def execute(self, sql, *args, **kwargs):
|
||||
captured["sql"] = sql
|
||||
class _R:
|
||||
def fetchmany(self, _n):
|
||||
return []
|
||||
return _R()
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.api.query.get_analytics_db_readonly",
|
||||
lambda: _StubAnalytics(),
|
||||
raising=False,
|
||||
)
|
||||
|
||||
c = seeded_app["client"]
|
||||
token = seeded_app["admin_token"]
|
||||
user_sql = "SELECT 1 AS x"
|
||||
r = c.post("/api/query", json={"sql": user_sql}, headers=_auth(token))
|
||||
assert r.status_code == 200, r.json()
|
||||
assert captured["sql"] == user_sql
|
||||
assert "bigquery_query(" not in captured["sql"]
|
||||
|
||||
|
||||
def test_endpoint_wraps_rewritten_sql_with_outer_limit(
|
||||
seeded_app, stub_bq_for_endpoint, monkeypatch,
|
||||
):
|
||||
"""Memory-safety regression — when the rewriter fires, the handler
|
||||
MUST wrap the bigquery_query() call in an outer ``LIMIT N+1`` so a
|
||||
`SELECT *` against a billion-row remote table doesn't materialise the
|
||||
full result into the worker before fetchmany applies the cap.
|
||||
Code-review #2a fix.
|
||||
"""
|
||||
_register_bq_remote_row("ue", "fin", "ue")
|
||||
|
||||
captured = {"sql": None}
|
||||
|
||||
class _StubAnalytics:
|
||||
description = [("c0",)]
|
||||
def execute(self, sql, *args, **kwargs):
|
||||
captured["sql"] = sql
|
||||
class _R:
|
||||
def fetchmany(self, _n):
|
||||
return []
|
||||
return _R()
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.api.query.get_analytics_db_readonly",
|
||||
lambda: _StubAnalytics(),
|
||||
raising=False,
|
||||
)
|
||||
|
||||
c = seeded_app["client"]
|
||||
token = seeded_app["admin_token"]
|
||||
r = c.post(
|
||||
"/api/query",
|
||||
json={"sql": "SELECT * FROM ue", "limit": 100},
|
||||
headers=_auth(token),
|
||||
)
|
||||
assert r.status_code == 200, r.json()
|
||||
sent = captured["sql"]
|
||||
# The bigquery_query() wrap is present, AND the whole thing is wrapped
|
||||
# again with an outer LIMIT that includes the user-requested cap +1
|
||||
# (the +1 is the existing truncation-detection pattern).
|
||||
assert "bigquery_query(" in sent
|
||||
assert "_bqq_outer" in sent
|
||||
assert "LIMIT 101" in sent # request.limit (100) + 1
|
||||
|
||||
|
||||
def test_endpoint_falls_back_to_original_sql_on_bq_parse_error(
|
||||
seeded_app, stub_bq_for_endpoint, monkeypatch,
|
||||
):
|
||||
"""When the rewritten ``bigquery_query()`` path fails with a parse-
|
||||
level error (e.g. user SQL contained DuckDB-only syntax that BQ
|
||||
can't parse), the handler MUST retry with the original SQL via the
|
||||
ATTACH-catalog path so the user request still succeeds. Code-review
|
||||
#4 fix.
|
||||
"""
|
||||
_register_bq_remote_row("ue", "fin", "ue")
|
||||
|
||||
calls = {"sqls": []}
|
||||
|
||||
class _StubAnalytics:
|
||||
description = [("c0",)]
|
||||
def execute(self, sql, *args, **kwargs):
|
||||
calls["sqls"].append(sql)
|
||||
# First call (rewritten) raises a BQ-style parse error;
|
||||
# second call (original SQL fallback) returns rows.
|
||||
if "bigquery_query(" in sql:
|
||||
raise RuntimeError(
|
||||
"BinderException: Query execution failed: "
|
||||
"Syntax error: Unexpected token at [1:42]"
|
||||
)
|
||||
class _R:
|
||||
def fetchmany(self, _n):
|
||||
return [(1,)]
|
||||
return _R()
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.api.query.get_analytics_db_readonly",
|
||||
lambda: _StubAnalytics(),
|
||||
raising=False,
|
||||
)
|
||||
|
||||
c = seeded_app["client"]
|
||||
token = seeded_app["admin_token"]
|
||||
r = c.post(
|
||||
"/api/query",
|
||||
# DuckDB-only ::INT cast — survives identifier rewrite, BQ refuses.
|
||||
json={"sql": "SELECT (count(*))::INT FROM ue"},
|
||||
headers=_auth(token),
|
||||
)
|
||||
assert r.status_code == 200, r.json()
|
||||
# Two execute calls: 1) rewritten (raised) 2) fallback to original.
|
||||
assert len(calls["sqls"]) == 2
|
||||
assert "bigquery_query(" in calls["sqls"][0]
|
||||
assert calls["sqls"][1] == "SELECT (count(*))::INT FROM ue"
|
||||
|
||||
|
||||
def test_rewriter_uses_billing_project_for_bigquery_query_first_arg(
|
||||
seeded_registry, monkeypatch,
|
||||
):
|
||||
"""Devin-review BUG #1: `bigquery_query()` first arg is the
|
||||
**billing** project (where BQ jobs are billed + executed), backtick
|
||||
paths use the **data** project. In cross-project deploys the SA
|
||||
has `serviceusage.services.use` only on the billing project, so
|
||||
using the data project as billing → 403 USER_PROJECT_DENIED.
|
||||
|
||||
Match the existing convention in v2_scan / v2_sample / v2_schema /
|
||||
extractor.
|
||||
"""
|
||||
from app.api.query import _rewrite_user_sql_for_bigquery_query
|
||||
_register_bq_remote(
|
||||
seeded_registry, table_id="bq.fin.ue", name="ue",
|
||||
bucket="fin", source_table="ue",
|
||||
)
|
||||
_set_bq_project(monkeypatch, project="data-prj", billing="billing-prj")
|
||||
|
||||
rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
"SELECT count(*) FROM ue",
|
||||
seeded_registry,
|
||||
)
|
||||
assert did_rewrite is True
|
||||
# First arg of bigquery_query must be the billing project.
|
||||
assert "bigquery_query('billing-prj'" in rewritten
|
||||
# Backtick path must use the data project.
|
||||
assert "`data-prj.fin.ue`" in rewritten
|
||||
# And the data project must NOT appear as the first arg.
|
||||
assert "bigquery_query('data-prj'" not in rewritten
|
||||
|
||||
|
||||
def test_rewriter_skips_when_bq_row_bucket_contains_dot(
|
||||
seeded_registry, monkeypatch,
|
||||
):
|
||||
"""Devil's-advocate R1 finding #5: a BQ row whose `bucket` contains
|
||||
`.` suggests the operator encoded a project prefix in the bucket
|
||||
name. Wrapping under our single-project assumption could silently
|
||||
target the wrong project. Rewriter must skip in that case (fall
|
||||
through to ATTACH-catalog path which respects the operator's
|
||||
`_remote_attach` configuration).
|
||||
"""
|
||||
from app.api.query import _rewrite_user_sql_for_bigquery_query
|
||||
_register_bq_remote(
|
||||
seeded_registry,
|
||||
table_id="bq.other-prj.dataset.ue",
|
||||
name="ue",
|
||||
# Project-qualified bucket — the multi-project red flag.
|
||||
bucket="other-prj.dataset",
|
||||
source_table="ue",
|
||||
)
|
||||
_set_bq_project(monkeypatch, "test-prj")
|
||||
|
||||
rewritten, did_rewrite = _rewrite_user_sql_for_bigquery_query(
|
||||
"SELECT count(*) FROM ue",
|
||||
seeded_registry,
|
||||
)
|
||||
# Skip — original SQL returned, no rewrite.
|
||||
assert did_rewrite is False
|
||||
assert rewritten == "SELECT count(*) FROM ue"
|
||||
|
||||
|
||||
def test_fallback_does_not_trigger_on_user_column_typo(
|
||||
seeded_app, stub_bq_for_endpoint, monkeypatch,
|
||||
):
|
||||
"""Devil's-advocate R1 finding #2: previously the fallback
|
||||
heuristic matched `Unrecognized name`, which BQ surfaces for both
|
||||
DuckDB-only-name AND user-column-typo cases. The user-typo case
|
||||
triggered re-running the original SQL through the slow ATTACH-
|
||||
catalog path (90+ s) → 2× latency tax on every typo.
|
||||
|
||||
Post-fix: heuristic only matches `Syntax error`. A BQ-side
|
||||
`Unrecognized name: bad_col` should propagate as-is, NOT trigger
|
||||
a fallback retry.
|
||||
"""
|
||||
_register_bq_remote_row("ue", "fin", "ue")
|
||||
|
||||
calls = {"sqls": []}
|
||||
|
||||
class _StubAnalytics:
|
||||
description = [("c0",)]
|
||||
def execute(self, sql, *args, **kwargs):
|
||||
calls["sqls"].append(sql)
|
||||
raise RuntimeError(
|
||||
"BinderException: Query execution failed: "
|
||||
"Unrecognized name: bad_col at [1:8]"
|
||||
)
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.api.query.get_analytics_db_readonly",
|
||||
lambda: _StubAnalytics(),
|
||||
raising=False,
|
||||
)
|
||||
|
||||
c = seeded_app["client"]
|
||||
token = seeded_app["admin_token"]
|
||||
r = c.post(
|
||||
"/api/query",
|
||||
json={"sql": "SELECT bad_col FROM ue"},
|
||||
headers=_auth(token),
|
||||
)
|
||||
# Error propagates; fallback NOT triggered (only one execute call).
|
||||
assert r.status_code in (400, 500, 502)
|
||||
assert len(calls["sqls"]) == 1, (
|
||||
"user column typo must NOT trigger fallback retry"
|
||||
)
|
||||
assert "bigquery_query(" in calls["sqls"][0]
|
||||
|
||||
|
||||
def test_endpoint_does_not_fall_back_on_non_parse_errors(
|
||||
seeded_app, stub_bq_for_endpoint, monkeypatch,
|
||||
):
|
||||
"""Non-parse-error exceptions from the rewritten path (network,
|
||||
quota, forbidden, generic runtime) must propagate, NOT silently
|
||||
retry against the legacy path. Otherwise the legacy path would
|
||||
just fail again and the user sees a slow + double-failure.
|
||||
"""
|
||||
_register_bq_remote_row("ue", "fin", "ue")
|
||||
|
||||
calls = {"sqls": []}
|
||||
|
||||
class _StubAnalytics:
|
||||
description = [("c0",)]
|
||||
def execute(self, sql, *args, **kwargs):
|
||||
calls["sqls"].append(sql)
|
||||
raise RuntimeError("Network unreachable: BQ endpoint timed out")
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.api.query.get_analytics_db_readonly",
|
||||
lambda: _StubAnalytics(),
|
||||
raising=False,
|
||||
)
|
||||
|
||||
c = seeded_app["client"]
|
||||
token = seeded_app["admin_token"]
|
||||
r = c.post(
|
||||
"/api/query",
|
||||
json={"sql": "SELECT count(*) FROM ue"},
|
||||
headers=_auth(token),
|
||||
)
|
||||
# Generic 400 from the handler's outer except — body will surface
|
||||
# the runtime error message; we just need to confirm no fallback.
|
||||
assert r.status_code in (400, 500, 502)
|
||||
assert len(calls["sqls"]) == 1, "must not retry on non-parse error"
|
||||
Loading…
Reference in a new issue