agnes-the-ai-analyst/app/api/v2_scan.py
ZdenekSrotyr 2e1dfb7553
feat(v2): claude-driven fetch primitives + 0.14.0 (#102)
Replaces the BigQuery wrap-view pattern with a discovery + scoped-fetch toolkit driven by the analyst's Claude session. Adds /api/v2/{catalog,schema,sample,scan,scan/estimate}, da catalog/schema/describe/fetch/snapshot/disk-info CLI commands, sqlglot-backed WHERE validator, process-local quota tracker, agent rails skill (cli/skills/agnes-data-querying.md). BREAKING: BQ wrap views off by default — set data_source.bigquery.legacy_wrap_views=true for one cycle. Backward-compat field_validator on primary_key. Catalog cache now matches documented 300s TTL with RBAC fresh per request. Cuts release v0.14.0.
2026-04-29 01:07:19 +02:00

411 lines
17 KiB
Python

"""POST /api/v2/scan and POST /api/v2/scan/estimate (spec §3.4 + §3.5)."""
from __future__ import annotations
import logging
import re
from typing import Optional
import pyarrow as pa
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel, Field
import duckdb
from app.auth.dependencies import get_current_user, _get_db
from app.instance_config import get_value
from src.rbac import can_access_table
from src.repositories.table_registry import TableRegistryRepository
from app.api.where_validator import (
validate_where, safe_where_predicate, WhereValidationError,
)
from app.api.v2_schema import build_schema # reused for column resolution
from app.api.v2_arrow import arrow_table_to_ipc_bytes, CONTENT_TYPE
from app.api.v2_quota import QuotaTracker, QuotaExceededError
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v2", tags=["v2"])
class ScanRequest(BaseModel):
table_id: str
select: Optional[list[str]] = None
where: Optional[str] = None
limit: Optional[int] = Field(default=None, ge=1)
order_by: Optional[list[str]] = None
def _resolve_schema(conn, user, table_id: str, project_id: str) -> dict:
"""Get {column: type} dict for the target table — used by validator + projection check."""
s = build_schema(conn, user, table_id, project_id=project_id)
return {c["name"]: c["type"] for c in s.get("columns", [])}
def _bq_dry_run_bytes(project: str, sql: str) -> int:
"""Run a BQ dry-run via the google-cloud-bigquery client and return totalBytesProcessed."""
from google.cloud import bigquery
from google.api_core.client_options import ClientOptions
client = bigquery.Client(
project=project,
client_options=ClientOptions(quota_project_id=project),
)
job = client.query(
sql,
job_config=bigquery.QueryJobConfig(dry_run=True, use_query_cache=False),
)
return int(job.total_bytes_processed or 0)
_COLUMN_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]{0,63}$")
_ORDER_BY_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(\s+(ASC|DESC))?$", re.IGNORECASE)
def _validate_select_columns(select: list[str] | None, schema: dict) -> None:
"""Reject SELECT column names that don't fit the safe-identifier shape.
Schema-existence is checked separately; this guard is defense-in-depth
so a backtick / double-quote in a column name can't break out of the
`…` (BQ) or "" (DuckDB) identifier wrapper in `_build_bq_sql` and the
local-scan path. Today, schema names from BQ INFORMATION_SCHEMA never
contain those characters — but Devin called this out as relying on an
implicit upstream constraint. Make it explicit."""
if not select:
return
for entry in select:
if not _COLUMN_NAME_RE.match(entry or ""):
raise ValueError(f"invalid column name: {entry!r}")
def _validate_order_by(order_by: list[str] | None, schema: dict) -> None:
"""Reject anything other than `<column>` or `<column> ASC|DESC` against the schema.
Without this, `order_by` is concatenated raw into the FROM clause SQL — exploitable."""
if not order_by:
return
known = {c.lower() for c in schema}
for entry in order_by:
s = (entry or "").strip()
if not _ORDER_BY_RE.match(s):
raise ValueError(f"invalid order_by entry: {entry!r}")
col = s.split()[0].lower()
if col not in known:
raise ValueError(f"unknown order_by column: {entry!r}")
def _quote_order_by_bq(entry: str) -> str:
"""Backtick-quote the column part of an order_by entry, preserve direction."""
parts = entry.strip().split()
return f"`{parts[0]}`" + ("" if len(parts) == 1 else " " + " ".join(parts[1:]))
def _quote_order_by_duckdb(entry: str) -> str:
parts = entry.strip().split()
return f'"{parts[0]}"' + ("" if len(parts) == 1 else " " + " ".join(parts[1:]))
def _build_bq_sql(
table_row: dict, project_id: str, req: ScanRequest, *, safe_where: str | None = None,
) -> str:
"""Build the BQ SQL string. ``safe_where`` MUST be the comment-stripped
fragment from ``safe_where_predicate`` — splicing ``req.where`` raw lets a
`1=1 --` predicate comment out everything that follows (LIMIT/ORDER BY).
Identifier quoting: column names are validated against the schema before
we get here, but reserved words (`order`, `group`, `timestamp`, …) still
need backticks to parse as identifiers in BQ.
"""
from src.identifier_validation import validate_quoted_identifier
bucket = table_row.get('bucket') or ''
src_table = table_row.get('source_table') or req.table_id
if not (validate_quoted_identifier(project_id, "BQ project")
and validate_quoted_identifier(bucket, "BQ dataset")
and validate_quoted_identifier(src_table, "BQ source_table")):
raise ValueError("unsafe BQ identifier in registry — refusing to build SQL")
select_sql = ", ".join(f"`{c}`" for c in req.select) if req.select else "*"
table_ref = f"`{project_id}.{bucket}.{src_table}`"
sql = f"SELECT {select_sql} FROM {table_ref}"
if safe_where:
sql += f" WHERE {safe_where}"
if req.order_by:
sql += f" ORDER BY {', '.join(_quote_order_by_bq(e) for e in req.order_by)}"
if req.limit:
sql += f" LIMIT {int(req.limit)}"
return sql
def estimate(conn, user, raw_request: dict, *, project_id: str, billing_project: str | None = None) -> dict:
req = ScanRequest(**raw_request)
repo = TableRegistryRepository(conn)
row = repo.get(req.table_id)
if not row:
raise FileNotFoundError(req.table_id)
if user.get("role") != "admin" and not can_access_table(user, req.table_id, conn):
raise PermissionError(req.table_id)
schema = _resolve_schema(conn, user, req.table_id, project_id)
dialect = "bigquery" if (row.get("source_type") or "") == "bigquery" else "duckdb"
# Validate WHERE and capture the comment-stripped fragment for splicing.
safe_where = (
safe_where_predicate(req.where, req.table_id, schema, dialect=dialect)
if req.where else None
)
# Validate select columns exist (case-insensitive, matching order_by).
if req.select:
_validate_select_columns(req.select, schema)
known = {c.lower() for c in schema}
unknown = [c for c in req.select if c.lower() not in known]
if unknown:
raise ValueError(f"unknown columns: {unknown}")
_validate_order_by(req.order_by, schema)
if (row.get("source_type") or "") != "bigquery":
return {
"table_id": req.table_id,
"estimated_scan_bytes": 0,
"estimated_result_rows": None,
"estimated_result_bytes": None,
"bq_cost_estimate_usd": 0.0,
}
bq_sql = _build_bq_sql(row, project_id, req, safe_where=safe_where)
scan_bytes = _bq_dry_run_bytes(billing_project or project_id, bq_sql)
cost_per_tb = float(get_value("api", "scan", "bq_cost_per_tb_usd", default=5.0) or 5.0)
cost = (scan_bytes / 1_099_511_627_776) * cost_per_tb # 1 TiB = 2^40
# Heuristic for result row/byte estimate. A row contains all selected
# columns, so per-row bytes = sum of per-column estimates (NOT average).
# If req.select is set, narrow to those columns; otherwise use full schema.
# Case-insensitive lookup matches the SELECT-validation policy — analysts
# often write a lowercased column name where INFORMATION_SCHEMA returned
# mixed-case; the schema lookup must follow.
schema_lower = {k.lower(): v for k, v in schema.items()}
cols_for_estimate = (
[schema_lower[c.lower()] for c in (req.select or []) if c.lower() in schema_lower]
or list(schema.values())
)
avg_row_bytes = max(1, sum(_avg_bytes_for_type(t) for t in cols_for_estimate))
rows_est = scan_bytes // max(avg_row_bytes, 1)
if req.limit:
rows_est = min(rows_est, req.limit)
return {
"table_id": req.table_id,
"estimated_scan_bytes": int(scan_bytes),
"estimated_result_rows": int(rows_est),
"estimated_result_bytes": int(rows_est * avg_row_bytes),
"bq_cost_estimate_usd": round(cost, 4),
}
def _avg_bytes_for_type(t: str) -> int:
t = (t or "").upper()
if t in ("INT64", "FLOAT64", "DATE", "TIMESTAMP", "DATETIME", "TIME"):
return 8
if t == "STRING":
return 32 # rough average
if t == "BYTES":
return 64
if t == "BOOL":
return 1
return 16
@router.post("/scan/estimate")
async def scan_estimate_endpoint(
raw: dict,
user: dict = Depends(get_current_user),
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
):
project_id = get_value("data_source", "bigquery", "project", default="") or ""
billing_project = get_value("data_source", "bigquery", "billing_project", default="") or project_id
try:
return estimate(conn, user, raw, project_id=project_id, billing_project=billing_project)
except WhereValidationError as e:
raise HTTPException(status_code=400, detail={"error": "validator_rejected", "kind": e.kind, "details": e.detail or {}})
except PermissionError:
raise HTTPException(status_code=403, detail="not authorized")
except FileNotFoundError:
raise HTTPException(status_code=404, detail="table not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# Module-level singleton (process-local quota state per spec §3.8). FastAPI
# dispatches sync handlers via a thread pool, so two concurrent first-time
# requests can both observe `_quota_singleton is None` and each construct a
# separate tracker; the second assignment wins and the first reference leaks
# split-brain quota state. Guard with an init lock + double-check.
import threading as _threading
_quota_init_lock = _threading.Lock()
_quota_singleton: QuotaTracker | None = None
def _build_quota_tracker() -> QuotaTracker:
"""Returns or constructs the process-local quota tracker (thread-safe)."""
global _quota_singleton
if _quota_singleton is not None:
return _quota_singleton
with _quota_init_lock:
if _quota_singleton is None:
_quota_singleton = QuotaTracker(
max_concurrent_per_user=int(get_value("api", "scan", "max_concurrent_per_user", default=5) or 5),
max_daily_bytes_per_user=int(get_value("api", "scan", "max_daily_bytes_per_user", default=53687091200) or 53687091200),
)
return _quota_singleton
def _max_result_bytes() -> int:
return int(get_value("api", "scan", "max_result_bytes", default=2_147_483_648) or 2_147_483_648)
def _max_limit() -> int:
return int(get_value("api", "scan", "max_limit", default=10_000_000) or 10_000_000)
def _run_bq_scan(project: str, sql: str) -> pa.Table:
"""Execute SQL via DuckDB BQ extension, return pyarrow Table."""
from connectors.bigquery.auth import get_metadata_token
token = get_metadata_token()
conn = duckdb.connect(":memory:")
try:
conn.execute("INSTALL bigquery FROM community; LOAD bigquery;")
escaped = token.replace("'", "''")
conn.execute(f"CREATE OR REPLACE SECRET bq_s (TYPE bigquery, ACCESS_TOKEN '{escaped}')")
# Use bigquery_query() since the SQL is already authored against the BQ jobs API
return conn.execute(
"SELECT * FROM bigquery_query(?, ?)",
[project, sql],
).arrow()
finally:
conn.close()
def run_scan(
conn: duckdb.DuckDBPyConnection,
user: dict,
raw_request: dict,
*,
project_id: str,
quota: QuotaTracker,
billing_project: str | None = None,
) -> bytes:
"""Validate → quota → execute → serialize. Returns Arrow IPC bytes.
Raises:
WhereValidationError, QuotaExceededError, FileNotFoundError, PermissionError, ValueError
"""
req = ScanRequest(**raw_request)
repo = TableRegistryRepository(conn)
row = repo.get(req.table_id)
if not row:
raise FileNotFoundError(req.table_id)
if user.get("role") != "admin" and not can_access_table(user, req.table_id, conn):
raise PermissionError(req.table_id)
if req.limit and req.limit > _max_limit():
raise ValueError(f"limit {req.limit} exceeds max {_max_limit()}")
schema = _resolve_schema(conn, user, req.table_id, project_id)
dialect = "bigquery" if (row.get("source_type") or "") == "bigquery" else "duckdb"
# Validate WHERE and capture the comment-stripped fragment for splicing.
safe_where = (
safe_where_predicate(req.where, req.table_id, schema, dialect=dialect)
if req.where else None
)
if req.select:
# Case-insensitive (BQ identifiers are case-insensitive; mixed-case
# names from INFORMATION_SCHEMA.COLUMNS shouldn't 400-reject the
# lowercased form a typical analyst writes).
_validate_select_columns(req.select, schema)
known = {c.lower() for c in schema}
unknown = [c for c in req.select if c.lower() not in known]
if unknown:
raise ValueError(f"unknown columns: {unknown}")
_validate_order_by(req.order_by, schema)
source_type = row.get("source_type") or ""
user_id = user.get("email") or "anon"
# Pre-flight quota check — fail BEFORE running the BQ scan so the user
# doesn't pay for a query whose result we'd then refuse to return.
quota.check_daily_budget(user=user_id)
with quota.acquire(user=user_id):
if source_type != "bigquery":
# Local source: query parquet directly. `source_type` extracted above
# because `row["source_type"]` could be NULL for legacy registry rows
# and `Path(...) / None` raises TypeError.
from app.utils import get_data_dir
parquet = (
get_data_dir() / "extracts" / source_type / "data" / f"{req.table_id}.parquet"
)
local = duckdb.connect(":memory:")
try:
projection = ", ".join(f'"{c}"' for c in req.select) if req.select else "*"
sql = f"SELECT {projection} FROM read_parquet(?)"
if safe_where:
sql += f" WHERE {safe_where}"
if req.order_by:
sql += f" ORDER BY {', '.join(_quote_order_by_duckdb(e) for e in req.order_by)}"
if req.limit:
sql += f" LIMIT {int(req.limit)}"
table = local.execute(sql, [str(parquet)]).arrow()
finally:
local.close()
else:
bq_sql = _build_bq_sql(row, project_id, req, safe_where=safe_where)
table = _run_bq_scan(billing_project or project_id, bq_sql)
ipc = arrow_table_to_ipc_bytes(table)
# Enforce max_result_bytes guard (spec §3.4 step 8)
if len(ipc) > _max_result_bytes():
# Truncate by taking only as many rows as fit roughly
# Simple heuristic: cap rows to estimated avg per max_bytes
row_count = table.num_rows
avg = max(1, len(ipc) // max(row_count, 1))
keep = min(row_count, _max_result_bytes() // max(avg, 1))
table = table.slice(0, keep)
ipc = arrow_table_to_ipc_bytes(table)
# Record bytes for daily quota
quota.record_bytes(user=user_id, n=len(ipc))
return ipc
@router.post("/scan")
async def scan_endpoint(
raw: dict,
user: dict = Depends(get_current_user),
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
):
project_id = get_value("data_source", "bigquery", "project", default="") or ""
billing_project = get_value("data_source", "bigquery", "billing_project", default="") or project_id
quota = _build_quota_tracker()
try:
ipc = run_scan(conn, user, raw, project_id=project_id, quota=quota, billing_project=billing_project)
return Response(content=ipc, media_type=CONTENT_TYPE)
except WhereValidationError as e:
raise HTTPException(
status_code=400,
detail={"error": "validator_rejected", "kind": e.kind, "details": e.detail or {}},
)
except QuotaExceededError as e:
raise HTTPException(
status_code=429,
detail={
"error": "quota_exceeded",
"kind": e.kind,
"current": e.current,
"limit": e.limit,
"retry_after_seconds": e.retry_after_seconds,
},
)
except FileNotFoundError:
raise HTTPException(status_code=404, detail="table not found")
except PermissionError:
raise HTTPException(status_code=403, detail="not authorized")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))