agnes-the-ai-analyst/src/remote_query.py
Petr 8c6c162417 Fix: --sql not required when --stdin is used
argparse was rejecting --stdin mode because --sql was required=True.
Changed to required=False with runtime validation in main().
2026-03-21 12:17:02 +01:00

636 lines
20 KiB
Python

"""
Remote Query - Execute DuckDB queries spanning local Parquet + remote BigQuery tables.
Provides a server-side CLI for the AI agent to run SQL queries that JOIN local
(Parquet-backed) tables with on-demand BigQuery results. Designed for tables too
large to sync locally (e.g., daily_deal_traffic: ~3M rows/day).
Two-phase query protocol:
1. BQ sub-queries (--register-bq "alias=SQL") run on BigQuery, results registered
as DuckDB views via PyArrow (reuses register_bq_table from duckdb_manager).
2. DuckDB SQL (--sql) runs against local Parquet views + registered BQ results.
Usage:
python -m src.remote_query \\
--sql "SELECT ... FROM order_economics o JOIN traffic t ON ..." \\
--register-bq "traffic=SELECT ... FROM \\`project.dataset.table\\` WHERE ..." \\
--format table
Safety features:
- COUNT(*) pre-check before fetching BQ data
- Memory estimation (refuses queries > 2 GB estimated)
- Configurable row limits (per BQ sub-query and final result)
- Query timeout support
"""
import argparse
import csv
import io
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Optional
import duckdb
from config.loader import get_instance_value
from scripts.duckdb_manager import (
create_local_views,
register_bq_table,
_create_bq_client,
)
# Module-level logger. Used for warnings/tracebacks; interactive progress text
# goes through _log_progress() instead. Handlers are set up by main() via
# logging.basicConfig().
logger = logging.getLogger(__name__)
class RemoteQueryError(Exception):
    """Error during remote query execution.

    Raised by this module for expected, user-correctable failures (row or
    memory limits exceeded, missing environment variables, malformed stdin
    input, unknown output format). ``main()`` catches it, prints
    ``ERROR: <message>`` to stderr and exits with status 1.
    """
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
def _load_remote_query_config() -> dict:
    """Load remote_query settings from instance.yaml with defaults.

    Uses raw YAML loading instead of load_instance_config() to avoid
    requiring webapp secrets (WEBAPP_SECRET_KEY etc.) that analysts
    don't have access to.

    The file is looked up as ``$CONFIG_DIR/instance.yaml`` (``./config`` when
    ``CONFIG_DIR`` is unset). A missing, unreadable, or malformed file is not
    an error: a warning is logged where applicable and the defaults apply.

    Returns:
        Dict with the keys ``timeout_seconds``, ``max_result_rows``,
        ``max_bq_registration_rows``, ``default_format`` and ``output_dir``.
    """
    # Imported lazily, as in the original, so that merely importing this
    # module does not require PyYAML.
    import yaml as _yaml

    config_dir = Path(os.environ.get("CONFIG_DIR", "./config"))
    yaml_path = config_dir / "instance.yaml"

    instance_config: dict = {}
    if yaml_path.exists():
        try:
            # Explicit encoding: the platform default may not be UTF-8.
            with open(yaml_path, encoding="utf-8") as f:
                loaded = _yaml.safe_load(f) or {}
        except Exception as e:
            # Deliberate best-effort: configuration problems must not prevent
            # queries from running with defaults.
            logger.warning("Could not load instance.yaml: %s. Using defaults.", e)
        else:
            if isinstance(loaded, dict):
                instance_config = loaded
            else:
                # A list/scalar at the YAML root cannot hold a remote_query
                # section; treat it like an absent file.
                logger.warning(
                    "instance.yaml root is a %s, expected a mapping. Using defaults.",
                    type(loaded).__name__,
                )

    defaults = {
        "timeout_seconds": 300,
        "max_result_rows": 100_000,
        "max_bq_registration_rows": 500_000,
        "default_format": "table",
        "output_dir": "/tmp/remote_query",
    }
    return {
        key: get_instance_value(instance_config, "remote_query", key, default=default)
        for key, default in defaults.items()
    }
# ---------------------------------------------------------------------------
# BQ registration with safety checks
# ---------------------------------------------------------------------------
def _subquery_text(sql: str) -> str:
    """Return ``sql`` trimmed for embedding inside ``FROM ( ... )``.

    Surrounding whitespace and trailing semicolons are removed, because a
    statement terminator inside the parentheses makes the wrapping query
    invalid.
    """
    return sql.strip().rstrip(";").rstrip()


def _run_bq_query(bq_client, sql: str, timeout: Optional[float], purpose: str):
    """Run ``sql`` on BigQuery and wait for the result, bounded by ``timeout``.

    Args:
        bq_client: BigQuery client instance
        sql: SQL text to execute
        timeout: Seconds to wait for the job result; None waits indefinitely
        purpose: Short description used in the timeout error message

    Returns:
        Whatever ``job.result()`` returns (iterable of rows with ``.schema``)

    Raises:
        RemoteQueryError: If the job does not finish within ``timeout``
    """
    # Local import keeps this block self-contained; the module is stdlib.
    import concurrent.futures

    job = bq_client.query(sql)
    try:
        return job.result(timeout=timeout)
    except concurrent.futures.TimeoutError as exc:
        raise RemoteQueryError(
            f"BigQuery {purpose} did not finish within {timeout} seconds."
        ) from exc


def _validate_bq_result_size(
    bq_client, sql: str, alias: str, max_rows: int,
    timeout: Optional[float] = None,
) -> int:
    """Execute COUNT(*) on the BQ sub-query before fetching all rows.

    Args:
        bq_client: BigQuery client instance
        sql: The BQ SQL query to count
        alias: Alias name (for error messages)
        max_rows: Maximum allowed rows
        timeout: Optional seconds to wait for the count query (None = no limit)

    Returns:
        Row count

    Raises:
        RemoteQueryError: If count exceeds max_rows or the count query
            times out
    """
    # Newlines around the sub-query keep a trailing "-- comment" in the user
    # SQL from swallowing the closing parenthesis.
    count_sql = f"SELECT COUNT(*) AS cnt FROM (\n{_subquery_text(sql)}\n)"
    _log_progress(f" Counting rows for '{alias}'...")
    result = _run_bq_query(bq_client, count_sql, timeout, f"row count for '{alias}'")
    row_count = next(iter(result))[0]
    if row_count > max_rows:
        raise RemoteQueryError(
            f"BQ sub-query '{alias}' would return {row_count:,} rows "
            f"(limit: {max_rows:,}). Add more WHERE filters or GROUP BY "
            f"to reduce the result set."
        )
    return row_count


def _estimate_memory_mb(
    row_count: int, column_count: int, bytes_per_cell: int = 50,
) -> float:
    """Estimate memory usage in MB for a PyArrow table.

    Uses ~50 bytes per cell as a rough average across data types unless a
    different ``bytes_per_cell`` is supplied.

    Args:
        row_count: Number of rows
        column_count: Number of columns
        bytes_per_cell: Assumed average size of one cell in bytes

    Returns:
        Estimated size in mebibytes (1 MB = 1024 * 1024 bytes here)
    """
    return (row_count * column_count * bytes_per_cell) / (1024 * 1024)


def _register_bq_views(
    conn: duckdb.DuckDBPyConnection,
    registrations: list[tuple[str, str]],
    max_bq_rows: int,
    timeout_seconds: int,
    quiet: bool = False,
) -> dict[str, int]:
    """Register BQ query results as DuckDB views with safety checks.

    For every registration this runs a COUNT(*) pre-check, a LIMIT 0 probe to
    learn the column count, a memory estimate, and finally the actual fetch
    via ``register_bq_table``.

    Args:
        conn: DuckDB connection
        registrations: List of (alias, bq_sql) tuples
        max_bq_rows: Max rows per sub-query
        timeout_seconds: BQ job timeout, applied to the COUNT(*) and the
            LIMIT 0 probe queries issued here. The final fetch is delegated
            to ``register_bq_table`` and is not passed a timeout here.
        quiet: Accepted for interface compatibility; progress output is
            controlled by the module-level quiet flag read by _log_progress

    Returns:
        Dict of {alias: row_count}

    Raises:
        RemoteQueryError: If BIGQUERY_PROJECT is unset, a limit is exceeded,
            or a pre-check query times out
    """
    if not registrations:
        return {}

    bq_project = os.environ.get("BIGQUERY_PROJECT")
    if not bq_project:
        raise RemoteQueryError(
            "BIGQUERY_PROJECT environment variable not set. "
            "Required for BigQuery sub-queries."
        )
    bq_client = _create_bq_client(bq_project)

    max_estimated_mb = 2048  # 2 GB = 25% of 8 GB server RAM
    results: dict[str, int] = {}
    for alias, bq_sql in registrations:
        start_time = time.time()

        # Phase 1: COUNT(*) safety check
        row_count = _validate_bq_result_size(
            bq_client, bq_sql, alias, max_bq_rows, timeout=timeout_seconds,
        )
        _log_progress(f" '{alias}': {row_count:,} rows (within limit)")

        # Phase 2: Memory estimation.
        # The column count comes from the schema of a LIMIT 0 query.
        probe_sql = f"SELECT * FROM (\n{_subquery_text(bq_sql)}\n) LIMIT 0"
        schema = _run_bq_query(
            bq_client, probe_sql, timeout_seconds, f"schema probe for '{alias}'",
        ).schema
        col_count = len(schema)
        estimated_mb = _estimate_memory_mb(row_count, col_count)
        if estimated_mb > max_estimated_mb:
            raise RemoteQueryError(
                f"BQ sub-query '{alias}' estimated memory: {estimated_mb:.0f} MB "
                f"({row_count:,} rows x {col_count} cols). "
                f"Limit is {max_estimated_mb} MB. Add more aggregation or filters."
            )

        # Phase 3: Execute and register
        _log_progress(f" Fetching '{alias}' ({row_count:,} rows, ~{estimated_mb:.0f} MB)...")
        actual_rows = register_bq_table(
            conn=conn,
            table_id=f"bq_registration.{alias}",
            view_name=alias,
            sql=bq_sql,
            bq_project=bq_project,
        )
        elapsed = time.time() - start_time
        _log_progress(f" '{alias}' registered: {actual_rows:,} rows in {elapsed:.1f}s")
        results[alias] = actual_rows
    return results
# ---------------------------------------------------------------------------
# Local view setup
# ---------------------------------------------------------------------------
def _setup_local_views(
    conn: duckdb.DuckDBPyConnection,
    data_dir: str,
    quiet: bool = False,
) -> list[str]:
    """Expose the local Parquet data as DuckDB views on ``conn``.

    Thin wrapper around ``create_local_views`` that inverts ``quiet`` into
    its ``verbose`` flag and discards the list of skipped tables.

    Args:
        conn: DuckDB connection
        data_dir: Path to data directory (e.g., "/data/src_data")
        quiet: Suppress progress messages

    Returns:
        List of created view names
    """
    verbose = not quiet
    view_names, _skipped = create_local_views(
        conn=conn, data_dir=data_dir, verbose=verbose,
    )
    return view_names
# ---------------------------------------------------------------------------
# Output formatting
# ---------------------------------------------------------------------------
def _print_table(columns: list[str], rows: list[tuple]) -> None:
    """Write ``rows`` to stdout as a left-aligned, pipe-separated text table.

    ``None`` cells are rendered as ``NULL``. Every column is padded to the
    widest value (or header) it contains. With no rows, only the line
    ``(empty result)`` is printed and the header is omitted.
    """
    if not rows:
        print("(empty result)")
        return

    # Stringify once; the same text is used for measuring and for printing.
    rendered = [
        ["NULL" if cell is None else str(cell) for cell in record]
        for record in rows
    ]

    widths = [len(name) for name in columns]
    for cells in rendered:
        for idx, text in enumerate(cells):
            if len(text) > widths[idx]:
                widths[idx] = len(text)

    def _join(cells) -> str:
        return " | ".join(text.ljust(widths[idx]) for idx, text in enumerate(cells))

    print(_join(columns))
    print("-+-".join("-" * width for width in widths))
    for cells in rendered:
        print(_join(cells))
    print(f"\n({len(rows)} rows)")
def _format_output(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    fmt: str,
    output_path: Optional[str],
    max_rows: int,
) -> None:
    """Execute final SQL and output results in the requested format.

    The query is wrapped in ``SELECT * FROM (...) LIMIT max_rows + 1`` and run
    exactly once; the extra row is only used to detect truncation and is never
    emitted. ``table`` always prints to stdout, ``csv``/``json`` print to
    stdout unless ``output_path`` is given, and ``parquet`` always writes a
    file (default: ``<output_dir>/result.parquet`` from the configuration).

    Args:
        conn: DuckDB connection with all views registered
        sql: The final DuckDB SQL query
        fmt: Output format (table, csv, json, parquet)
        output_path: File path for file-based outputs
        max_rows: Maximum rows to return

    Raises:
        RemoteQueryError: If ``fmt`` is not one of the supported formats
    """
    # Reject a bad format before spending time on the query.
    if fmt not in ("table", "csv", "json", "parquet"):
        raise RemoteQueryError(f"Unknown format: {fmt}")

    # A trailing ";" would be a syntax error inside the wrapping parentheses;
    # the newlines keep a trailing "-- comment" from hiding the ")".
    inner_sql = sql.strip().rstrip(";").rstrip()
    # Add LIMIT to prevent runaway results (+1 row to detect truncation).
    limited_sql = f"SELECT * FROM (\n{inner_sql}\n) AS _rq LIMIT {max_rows + 1}"
    truncation_warning = (
        f" WARNING: Result truncated to {max_rows:,} rows. "
        f"Add more filters or increase --max-rows."
    )

    if fmt == "parquet":
        import pyarrow.parquet as pq

        # Fetch straight into Arrow instead of materialising Python tuples
        # first and then running the query a second time.
        arrow_result = conn.execute(limited_sql).fetch_arrow_table()
        if arrow_result.num_rows > max_rows:
            arrow_result = arrow_result.slice(0, max_rows)
            _log_progress(truncation_warning)
        if not output_path:
            output_path = str(Path(_load_remote_query_config()["output_dir"]) / "result.parquet")
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        pq.write_table(arrow_result, output_path)
        _log_progress(
            f" Parquet written: {output_path} "
            f"({arrow_result.num_rows} rows, {arrow_result.num_columns} cols)"
        )
        return

    result = conn.execute(limited_sql)
    columns = [desc[0] for desc in result.description]
    rows = result.fetchall()
    # Check if result exceeded limit
    if len(rows) > max_rows:
        rows = rows[:max_rows]
        _log_progress(truncation_warning)

    if output_path and fmt in ("csv", "json"):
        # Same convenience the parquet branch offers: create missing folders.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    if fmt == "table":
        _print_table(columns, rows)
    elif fmt == "csv":
        if output_path:
            # Explicit encoding so non-ASCII data does not depend on the locale.
            with open(output_path, "w", newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                writer.writerow(columns)
                writer.writerows(rows)
            _log_progress(f" CSV written: {output_path} ({len(rows)} rows)")
        else:
            writer = csv.writer(sys.stdout)
            writer.writerow(columns)
            writer.writerows(rows)
    else:  # fmt == "json"
        data = [dict(zip(columns, row)) for row in rows]
        # default=str stringifies values json cannot encode natively.
        json_str = json.dumps(data, default=str, indent=2)
        if output_path:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(json_str)
            _log_progress(f" JSON written: {output_path} ({len(rows)} rows)")
        else:
            print(json_str)
# ---------------------------------------------------------------------------
# Progress logging (stderr so stdout stays clean for data)
# ---------------------------------------------------------------------------
# Process-wide switch for progress output; execute_remote_query() sets it
# from its ``quiet`` argument.
_quiet_mode = False


def _log_progress(msg: str) -> None:
    """Emit ``msg`` on stderr unless quiet mode is active.

    Progress text goes to stderr so that stdout carries only query results.
    """
    if _quiet_mode:
        return
    print(msg, file=sys.stderr)
# ---------------------------------------------------------------------------
# Main execution
# ---------------------------------------------------------------------------
def execute_remote_query(
    sql: str,
    bq_registrations: list[tuple[str, str]],
    fmt: str = "table",
    output: Optional[str] = None,
    max_rows: Optional[int] = None,
    max_bq_rows: Optional[int] = None,
    timeout: Optional[int] = None,
    data_dir: str = "/data/src_data",
    quiet: bool = False,
) -> None:
    """Run one remote query end to end on a fresh in-memory DuckDB.

    Steps: create views over the local Parquet data, register each BigQuery
    sub-query result as a view, then execute ``sql`` and emit the result.
    The connection is closed even when a step fails.

    Args:
        sql: DuckDB SQL query to execute
        bq_registrations: List of (alias, bq_sql) tuples
        fmt: Output format (table, csv, json, parquet)
        output: Output file path (for parquet/csv/json)
        max_rows: Max rows in final result
        max_bq_rows: Max rows per BQ sub-query
        timeout: Query timeout in seconds
        data_dir: Path to data directory
        quiet: Suppress progress messages

    Falsy values for ``max_rows``, ``max_bq_rows``, ``timeout`` and ``fmt``
    are replaced by the corresponding configuration values.
    """
    # _log_progress() consults this module-level flag.
    global _quiet_mode
    _quiet_mode = quiet

    settings = _load_remote_query_config()
    if not max_rows:
        max_rows = settings["max_result_rows"]
    if not max_bq_rows:
        max_bq_rows = settings["max_bq_registration_rows"]
    if not timeout:
        timeout = settings["timeout_seconds"]
    if not fmt:
        fmt = settings["default_format"]

    started_at = time.time()
    conn = duckdb.connect(":memory:")
    try:
        # 1) local Parquet-backed views
        _log_progress("Setting up local views...")
        view_names = _setup_local_views(conn, data_dir, quiet=quiet)
        _log_progress(f" {len(view_names)} local views ready")

        # 2) BigQuery sub-query results
        if bq_registrations:
            _log_progress(f"Registering {len(bq_registrations)} BQ sub-queries...")
            registered = _register_bq_views(
                conn, bq_registrations, max_bq_rows, timeout, quiet=quiet,
            )
            for name, n_rows in registered.items():
                _log_progress(f" {name}: {n_rows:,} rows")

        # 3) final query + output
        _log_progress("Executing query...")
        _format_output(conn, sql, fmt, output, max_rows)
        _log_progress(f"Done in {time.time() - started_at:.1f}s")
    finally:
        conn.close()
# ---------------------------------------------------------------------------
# CLI argument parsing
# ---------------------------------------------------------------------------
def _parse_register_bq(value: str) -> tuple[str, str]:
    """Parse --register-bq argument in 'alias=SQL' format.

    The value is split at the first ``=`` only, so the SQL part may itself
    contain ``=``. Both parts are stripped of surrounding whitespace.

    Args:
        value: String in format "alias=SELECT ..."

    Returns:
        Tuple of (alias, sql)

    Raises:
        argparse.ArgumentTypeError: If there is no ``=``, the alias is empty
            or whitespace-only, or the SQL part is empty
    """
    raw_alias, separator, raw_sql = value.partition("=")
    alias = raw_alias.strip()
    # Check the alias after stripping: " =SELECT 1" must not yield an empty
    # view name.
    if not separator or not alias:
        raise argparse.ArgumentTypeError(
            f"Invalid --register-bq format: '{value}'. "
            f"Expected: 'alias=SELECT ...' (e.g., 'traffic=SELECT report_date, ...')"
        )
    sql = raw_sql.strip()
    if not sql:
        raise argparse.ArgumentTypeError(
            f"Empty SQL in --register-bq for alias '{alias}'"
        )
    return alias, sql
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the remote_query CLI.

    ``--sql`` is optional at the parser level because ``--stdin`` can supply
    the query instead; ``main()`` enforces that one of the two is present.
    Unset numeric/format options stay ``None`` so that configuration values
    can fill them in later.
    """
    examples = """
Examples:
# Local-only query (no BigQuery):
python -m src.remote_query --sql "SELECT COUNT(*) FROM order_economics"
# Register BQ result and query it:
python -m src.remote_query \\
--register-bq "traffic=SELECT report_date, SUM(visitors) FROM \\`proj.ds.table\\` GROUP BY 1" \\
--sql "SELECT * FROM traffic ORDER BY report_date"
# JOIN local + remote:
python -m src.remote_query \\
--register-bq "traffic=SELECT ... GROUP BY ..." \\
--sql "SELECT o.*, t.visitors FROM order_economics o JOIN traffic t ON ..." \\
--format parquet --output /tmp/result.parquet
"""
    parser = argparse.ArgumentParser(
        description="Execute DuckDB queries spanning local Parquet + remote BigQuery tables",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=examples,
    )
    add = parser.add_argument

    # --- query inputs ---
    add(
        "--sql",
        required=False,  # not required when --stdin is used
        default=None,
        help="DuckDB SQL query (executed after all views are registered)",
    )
    add(
        "--register-bq",
        dest="bq_registrations",
        metavar="ALIAS=SQL",
        action="append",
        type=_parse_register_bq,
        default=[],
        help='Register BQ query result as DuckDB view. Format: "alias=BQ_SQL". Repeatable.',
    )

    # --- output ---
    add(
        "--format",
        dest="fmt",
        choices=["table", "csv", "json", "parquet"],
        default=None,
        help="Output format (default: from config or 'table')",
    )
    add(
        "--output",
        default=None,
        help="Output file path for parquet/csv/json (default: auto for parquet)",
    )

    # --- limits ---
    add(
        "--max-rows",
        type=int,
        default=None,
        help="Max rows in final result (default: from config)",
    )
    add(
        "--max-bq-rows",
        type=int,
        default=None,
        help="Max rows per BQ sub-query (default: from config)",
    )
    add(
        "--timeout",
        type=int,
        default=None,
        help="Query timeout in seconds (default: from config)",
    )

    # --- environment / behaviour ---
    add(
        "--data-dir",
        default="/data/src_data",
        help="Parquet data directory (default: /data/src_data)",
    )
    add(
        "--quiet",
        action="store_true",
        help="Suppress progress messages (stderr)",
    )
    add(
        "--stdin",
        action="store_true",
        help="Read query spec from stdin as JSON. Avoids shell escaping issues.",
    )
    return parser
def _parse_stdin_query() -> dict:
    """Parse query specification from stdin JSON.

    Expected format:
        {
            "sql": "SELECT ... FROM ...",
            "register_bq": {"alias": "BQ SQL", ...},
            "format": "table",
            "output": "/path/to/file",
            "max_rows": 100000,
            "max_bq_rows": 500000
        }

    Only ``sql`` is mandatory. An explicit ``"register_bq": null`` is
    normalised to an empty object so callers can iterate it unconditionally.

    Returns:
        Dict with parsed query spec

    Raises:
        RemoteQueryError: If stdin is empty, is not valid JSON, is not a JSON
            object, lacks a non-empty string ``sql``, or has a ``register_bq``
            that is not an object mapping strings to strings
    """
    raw = sys.stdin.read().strip()
    if not raw:
        raise RemoteQueryError("Empty stdin. Provide JSON query spec.")
    try:
        spec = json.loads(raw)
    except json.JSONDecodeError as e:
        raise RemoteQueryError(f"Invalid JSON on stdin: {e}") from e

    # Valid JSON may still be a list/number/string; report that as a user
    # error instead of failing later with TypeError/AttributeError.
    if not isinstance(spec, dict):
        raise RemoteQueryError(
            f"JSON query spec must be an object, got {type(spec).__name__}."
        )
    if "sql" not in spec:
        raise RemoteQueryError("JSON must contain 'sql' field.")
    if not isinstance(spec["sql"], str) or not spec["sql"].strip():
        raise RemoteQueryError("'sql' field must be a non-empty string.")

    if "register_bq" in spec:
        register_bq = spec["register_bq"]
        if register_bq is None:
            spec["register_bq"] = {}
        elif not isinstance(register_bq, dict) or not all(
            isinstance(bq_sql, str) for bq_sql in register_bq.values()
        ):
            raise RemoteQueryError(
                "'register_bq' must be an object mapping alias to BigQuery SQL text."
            )
    return spec
def main(argv: Optional[list[str]] = None) -> None:
    """CLI entry point.

    Args:
        argv: Argument list to parse; None lets argparse read ``sys.argv``.

    Exit status:
        1 for RemoteQueryError (message printed as ``ERROR: ...``),
        130 on KeyboardInterrupt, 2 for any other exception (traceback is
        logged). Argument errors are reported by argparse itself.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Setup logging (stderr, so stdout stays reserved for query results)
    logging.basicConfig(
        level=logging.WARNING if args.quiet else logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        stream=sys.stderr,
    )

    try:
        # --stdin mode: read query spec from JSON on stdin (no shell escaping needed)
        if args.stdin:
            spec = _parse_stdin_query()
            # "register_bq" may be absent or an explicit JSON null; both mean
            # "no BigQuery sub-queries".
            register_bq = spec.get("register_bq") or {}
            if not isinstance(register_bq, dict):
                raise RemoteQueryError(
                    "'register_bq' must be a JSON object mapping alias to BigQuery SQL."
                )
            execute_remote_query(
                sql=spec["sql"],
                bq_registrations=list(register_bq.items()),
                # Values in the JSON spec take precedence over CLI flags.
                fmt=spec.get("format", args.fmt),
                output=spec.get("output", args.output),
                max_rows=spec.get("max_rows", args.max_rows),
                max_bq_rows=spec.get("max_bq_rows", args.max_bq_rows),
                timeout=args.timeout,
                data_dir=args.data_dir,
                quiet=args.quiet,
            )
            return

        # Validate --sql is provided when not using --stdin
        if not args.sql:
            # parser.error() raises SystemExit, which the handlers below do
            # not intercept.
            parser.error("--sql is required (or use --stdin for JSON input)")

        execute_remote_query(
            sql=args.sql,
            bq_registrations=args.bq_registrations,
            fmt=args.fmt,
            output=args.output,
            max_rows=args.max_rows,
            max_bq_rows=args.max_bq_rows,
            timeout=args.timeout,
            data_dir=args.data_dir,
            quiet=args.quiet,
        )
    except RemoteQueryError as e:
        print(f"ERROR: {e}", file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        print("\nInterrupted.", file=sys.stderr)
        sys.exit(130)
    except Exception as e:
        # Top-level boundary: report, log the traceback, and exit non-zero.
        print(f"UNEXPECTED ERROR: {e}", file=sys.stderr)
        logger.exception("Unexpected error in remote_query")
        sys.exit(2)
# Run the CLI only when this file is executed as a script (or via
# ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()