agnes-the-ai-analyst/tests/test_remote_query.py
Petr 8c6c162417 Fix: --sql not required when --stdin is used
argparse was rejecting --stdin mode because --sql was required=True.
Changed to required=False with runtime validation in main().
2026-03-21 12:17:02 +01:00

779 lines
27 KiB
Python

"""Tests for remote_query module - hybrid local Parquet + remote BigQuery queries.
Tests cover:
- CLI argument parsing (_parse_register_bq, build_parser)
- Local view setup (_setup_local_views via create_local_views)
- BQ registration with safety checks (_validate_bq_result_size, _estimate_memory_mb, _register_bq_views)
- Output formatting (_print_table, _format_output)
- End-to-end local-only queries (no BQ mocking needed)
"""
import argparse
import csv
import json
import os
from io import StringIO
from pathlib import Path
from unittest.mock import MagicMock, patch
import duckdb
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from src.remote_query import (
RemoteQueryError,
_estimate_memory_mb,
_format_output,
_parse_register_bq,
_print_table,
_register_bq_views,
_validate_bq_result_size,
build_parser,
execute_remote_query,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def tmp_local_project(tmp_path):
    """Create a minimal project with docs/data_description.md and parquet files.

    Layout::

        tmp_path/
            docs/data_description.md   (YAML with local + remote + hybrid tables)
            server/parquet/crm_data/orders.parquet
            server/parquet/crm_data/products.parquet
            server/parquet/crm_data/inventory.parquet

    Returns:
        (project_root, data_dir) where project_root is ``tmp_path`` and
        data_dir is ``str(tmp_path / "server")``.
    """
    docs_dir = tmp_path / "docs"
    docs_dir.mkdir()

    # The YAML nesting is significant: each table's fields must be indented
    # under its "- id:" list item, and the folder mapping must be a nested
    # mapping. With every line at column 0 the document does not describe a
    # list of table mappings at all.
    data_description = """\
# Data Description
```yaml
folder_mapping:
  in.c-crm: crm_data
tables:
  - id: "in.c-crm.orders"
    name: "orders"
    description: "Order data"
    primary_key: "order_id"
    sync_strategy: "full_refresh"
  - id: "in.c-crm.products"
    name: "products"
    description: "Product catalog"
    primary_key: "product_id"
    sync_strategy: "full_refresh"
  - id: "prj-grp-dataview-prod-1ff9.supply.traffic"
    name: "traffic"
    description: "Remote BQ traffic table"
    primary_key: "id"
    query_mode: "remote"
  - id: "in.c-crm.inventory"
    name: "inventory"
    description: "Hybrid inventory"
    primary_key: "id"
    sync_strategy: "full_refresh"
    query_mode: "hybrid"
```
"""
    (docs_dir / "data_description.md").write_text(data_description, encoding="utf-8")

    # Parquet files for the local tables (orders, products) and for the
    # hybrid table (inventory). The remote "traffic" table gets no file.
    crm_dir = tmp_path / "server" / "parquet" / "crm_data"
    crm_dir.mkdir(parents=True)
    parquet_tables = {
        "orders": pa.table({
            "order_id": [1, 2, 3, 4, 5],
            "amount": [10.0, 20.0, 30.0, 40.0, 50.0],
            "product_id": [101, 102, 101, 103, 102],
        }),
        "products": pa.table({
            "product_id": [101, 102, 103],
            "name": ["Widget", "Gadget", "Doohickey"],
        }),
        "inventory": pa.table({
            "id": [1, 2],
            "stock": [100, 200],
        }),
    }
    for table_name, table in parquet_tables.items():
        pq.write_table(table, crm_dir / f"{table_name}.parquet")

    data_dir = str(tmp_path / "server")
    return tmp_path, data_dir
@pytest.fixture
def duckdb_conn():
    """Yield a fresh in-memory DuckDB connection; close it once the test is done."""
    memory_db = duckdb.connect(":memory:")
    yield memory_db
    memory_db.close()
class _DuckDBConnectionProxy:
    """Proxy around DuckDBPyConnection that silently ignores unsupported SET commands.

    DuckDB versions may not support 'statement_timeout'. This proxy catches
    CatalogException for SET commands so end-to-end tests work across versions.
    The real connection's execute method is read-only, so we wrap it.

    Every attribute other than ``execute`` is forwarded to the wrapped
    connection unchanged.
    """

    def __init__(self, conn):
        """Wrap *conn*; the proxy does not take over closing it."""
        self._conn = conn

    def execute(self, sql, *args, **kwargs):
        """Run *sql* on the wrapped connection.

        Returns whatever the wrapped ``execute`` returns. For a statement
        that starts with ``SET `` (case-insensitive, leading whitespace
        ignored) a ``duckdb.CatalogException`` is swallowed and ``None`` is
        returned instead. All other statements propagate their errors.
        """
        if isinstance(sql, str) and sql.strip().upper().startswith("SET "):
            try:
                return self._conn.execute(sql, *args, **kwargs)
            except duckdb.CatalogException:
                return None
        return self._conn.execute(sql, *args, **kwargs)

    def __getattr__(self, name):
        """Forward lookups that fail on the proxy to the wrapped connection."""
        # __getattr__ only runs after normal lookup failed, so being asked
        # for "_conn" means the instance was never initialised (e.g. built
        # via __new__ by copy/pickle). Reading self._conn here would re-enter
        # __getattr__ until RecursionError; fail with AttributeError instead.
        if name == "_conn":
            raise AttributeError(name)
        return getattr(self._conn, name)
def _patched_duckdb_connect(*args, **kwargs):
    """Open a DuckDB connection with the given arguments and return it proxied."""
    return _DuckDBConnectionProxy(duckdb.connect(*args, **kwargs))
# ---------------------------------------------------------------------------
# Tests: CLI argument parsing
# ---------------------------------------------------------------------------
class TestCLIArgParsing:
    """Exercise _parse_register_bq() and the parser returned by build_parser()."""

    @staticmethod
    def _parse(*argv):
        """Parse *argv* with a freshly built parser and return the namespace."""
        return build_parser().parse_args(list(argv))

    def test_sql_defaults_to_none(self):
        """--sql may be left out when --stdin is given."""
        namespace = self._parse("--stdin")
        assert namespace.sql is None
        assert namespace.stdin is True

    def test_register_bq_parsing(self):
        """'alias=SELECT ...' is split into an (alias, sql) tuple."""
        parsed = _parse_register_bq("traffic=SELECT report_date FROM `proj.ds.table`")
        assert parsed == ("traffic", "SELECT report_date FROM `proj.ds.table`")

    def test_register_bq_invalid_format(self):
        """A value without '=' is rejected with ArgumentTypeError."""
        with pytest.raises(argparse.ArgumentTypeError, match="Invalid --register-bq format"):
            _parse_register_bq("no_equals_sign_here")

    def test_register_bq_empty_sql(self):
        """An alias followed by nothing after '=' is rejected."""
        with pytest.raises(argparse.ArgumentTypeError, match="Empty SQL"):
            _parse_register_bq("alias=")

    def test_register_bq_empty_alias(self):
        """A value that starts with '=' (no alias) is rejected."""
        with pytest.raises(argparse.ArgumentTypeError, match="Invalid --register-bq format"):
            _parse_register_bq("=SELECT 1")

    def test_multiple_register_bq(self):
        """Repeated --register-bq options accumulate in order."""
        registrations = self._parse(
            "--sql", "SELECT 1",
            "--register-bq", "t1=SELECT a FROM x",
            "--register-bq", "t2=SELECT b FROM y",
        ).bq_registrations
        assert len(registrations) == 2
        assert registrations[0] == ("t1", "SELECT a FROM x")
        assert registrations[1] == ("t2", "SELECT b FROM y")

    def test_default_format_is_none(self):
        """Without --format, fmt is None (the config default applies at runtime)."""
        assert self._parse("--sql", "SELECT 1").fmt is None

    def test_explicit_format(self):
        """An explicit --format value is kept as given."""
        assert self._parse("--sql", "SELECT 1", "--format", "csv").fmt == "csv"

    def test_invalid_format_rejected(self):
        """An unsupported --format value makes the parser exit."""
        parser = build_parser()
        with pytest.raises(SystemExit):
            parser.parse_args(["--sql", "SELECT 1", "--format", "xml"])

    def test_no_register_bq_yields_empty_list(self):
        """With no --register-bq option, bq_registrations is an empty list."""
        assert self._parse("--sql", "SELECT 1").bq_registrations == []

    def test_register_bq_sql_with_equals(self):
        """Only the first '=' separates alias from SQL; later ones stay in the SQL."""
        parsed = _parse_register_bq("view=SELECT * FROM t WHERE col = 5")
        assert parsed[0] == "view"
        assert parsed[1] == "SELECT * FROM t WHERE col = 5"

    def test_quiet_flag(self):
        """--quiet sets quiet to True."""
        assert self._parse("--sql", "SELECT 1", "--quiet").quiet is True

    def test_max_rows_parsing(self):
        """--max-rows is converted to an integer."""
        assert self._parse("--sql", "SELECT 1", "--max-rows", "500").max_rows == 500
# ---------------------------------------------------------------------------
# Tests: Local view setup
# ---------------------------------------------------------------------------
class TestLocalViewSetup:
    """Check which views _setup_local_views creates for the tmp_local_project fixture."""

    @staticmethod
    def _build_views(project, conn):
        """Run _setup_local_views on *conn* with the project root patched in.

        *project* is the (project_root, data_dir) pair from tmp_local_project.
        Returns whatever _setup_local_views returns.
        """
        root, data_dir = project
        with patch("scripts.duckdb_manager.find_project_root", return_value=root):
            from src.remote_query import _setup_local_views

            return _setup_local_views(conn, data_dir, quiet=True)

    def test_creates_views_from_parquet(self, tmp_local_project, duckdb_conn):
        """Local tables become queryable DuckDB views."""
        views = self._build_views(tmp_local_project, duckdb_conn)
        assert "orders" in views
        assert "products" in views
        # The view must return the five rows written by the fixture.
        (row_total,) = duckdb_conn.execute("SELECT COUNT(*) FROM orders").fetchone()[:1]
        assert row_total == 5

    def test_skips_remote_tables(self, tmp_local_project, duckdb_conn):
        """Tables with query_mode='remote' get no local view."""
        views = self._build_views(tmp_local_project, duckdb_conn)
        assert "traffic" not in views
        # The remote table must not show up in the catalog either.
        visible = [entry[0] for entry in duckdb_conn.execute("SHOW TABLES").fetchall()]
        assert "traffic" not in visible

    def test_includes_hybrid_tables(self, tmp_local_project, duckdb_conn):
        """Tables with query_mode='hybrid' do get a local view."""
        views = self._build_views(tmp_local_project, duckdb_conn)
        assert "inventory" in views
        (row_total,) = duckdb_conn.execute("SELECT COUNT(*) FROM inventory").fetchone()[:1]
        assert row_total == 2
# ---------------------------------------------------------------------------
# Tests: BQ registration with safety checks
# ---------------------------------------------------------------------------
class TestBQRegistration:
    """Test BQ result validation and registration (mocked BigQuery)."""

    @staticmethod
    def _make_mock_bq_client(count_result: int = 100, schema_fields: int = 5):
        """Create a mock BQ client that returns controlled count and schema.

        Args:
            count_result: Value returned when the single COUNT(*) result row
                is indexed.
            schema_fields: Number of entries in the ``schema`` list of the
                ``LIMIT 0`` query result.

        Returns:
            A MagicMock whose ``query(sql)`` returns a job mock. The job's
            ``result()`` iterates over one count row when *sql* starts with
            ``SELECT COUNT(*)``, and exposes ``schema`` when *sql* contains
            ``LIMIT 0``. Any other SQL gets a default MagicMock job.
        """
        mock_client = MagicMock()

        # Row returned by the COUNT(*) query; any index yields count_result.
        count_row = MagicMock()
        count_row.__getitem__ = MagicMock(return_value=count_result)

        # Schema returned by the LIMIT 0 probe; only its length is controlled.
        mock_schema_fields = [MagicMock() for _ in range(schema_fields)]

        def query_side_effect(sql):
            job = MagicMock()
            if sql.startswith("SELECT COUNT(*)"):
                result = MagicMock()
                # A fresh iterator per call, so repeated queries each see the row.
                result.__iter__ = MagicMock(return_value=iter([count_row]))
                job.result.return_value = result
            elif "LIMIT 0" in sql:
                result = MagicMock()
                result.schema = mock_schema_fields
                job.result.return_value = result
            return job

        mock_client.query.side_effect = query_side_effect
        return mock_client

    def test_count_check_blocks_large_result(self):
        """BQ sub-query exceeding max_rows should raise RemoteQueryError."""
        mock_client = self._make_mock_bq_client(count_result=1_000_000)
        with pytest.raises(RemoteQueryError, match="would return 1,000,000 rows"):
            _validate_bq_result_size(
                bq_client=mock_client,
                sql="SELECT * FROM big_table",
                alias="big_table",
                max_rows=500_000,
            )

    def test_validates_small_result_passes(self):
        """BQ sub-query within limits should return the row count."""
        mock_client = self._make_mock_bq_client(count_result=1000)
        row_count = _validate_bq_result_size(
            bq_client=mock_client,
            sql="SELECT * FROM small_table",
            alias="small_table",
            max_rows=500_000,
        )
        assert row_count == 1000

    def test_memory_estimate_blocks_huge_result(self, duckdb_conn):
        """_register_bq_views should refuse when estimated memory exceeds 2 GB."""
        # Passes the count check but fails the memory check:
        # 500K rows x 100 cols x 50 bytes/cell = ~2384 MB > 2048 MB limit
        mock_client = self._make_mock_bq_client(count_result=500_000, schema_fields=100)
        with patch("src.remote_query._create_bq_client", return_value=mock_client), \
                patch.dict(os.environ, {"BIGQUERY_PROJECT": "test-proj"}):
            with pytest.raises(RemoteQueryError, match="estimated memory"):
                _register_bq_views(
                    conn=duckdb_conn,
                    registrations=[("huge", "SELECT * FROM huge_table")],
                    max_bq_rows=1_000_000,
                    timeout_seconds=60,
                    quiet=True,
                )

    def test_registers_small_result(self, duckdb_conn):
        """BQ sub-query within all limits should register successfully."""
        # Small result: 100 rows x 5 cols = ~0.02 MB
        mock_client = self._make_mock_bq_client(count_result=100, schema_fields=5)
        # register_bq_table is mocked to report the row count it "loaded".
        with patch("src.remote_query._create_bq_client", return_value=mock_client), \
                patch("src.remote_query.register_bq_table", return_value=100) as mock_reg, \
                patch.dict(os.environ, {"BIGQUERY_PROJECT": "test-proj"}):
            results = _register_bq_views(
                conn=duckdb_conn,
                registrations=[("small_view", "SELECT * FROM small_table")],
                max_bq_rows=500_000,
                timeout_seconds=60,
                quiet=True,
            )
        assert results == {"small_view": 100}
        mock_reg.assert_called_once()

    def test_missing_bigquery_project_raises(self, duckdb_conn):
        """Missing BIGQUERY_PROJECT env var should raise RemoteQueryError."""
        # clear=True empties os.environ for the duration of the block.
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(RemoteQueryError, match="BIGQUERY_PROJECT"):
                _register_bq_views(
                    conn=duckdb_conn,
                    registrations=[("v", "SELECT 1")],
                    max_bq_rows=100,
                    timeout_seconds=60,
                )

    def test_empty_registrations_returns_empty(self, duckdb_conn):
        """Empty registration list should return an empty dict."""
        result = _register_bq_views(
            conn=duckdb_conn,
            registrations=[],
            max_bq_rows=100,
            timeout_seconds=60,
        )
        assert result == {}
# ---------------------------------------------------------------------------
# Tests: Memory estimation
# ---------------------------------------------------------------------------
class TestMemoryEstimation:
    """Check the megabyte figures produced by _estimate_memory_mb."""

    def test_small_table(self):
        """100 rows x 10 cols is expected to come out as 50_000 bytes (~0.048 MB)."""
        expected_mb = 50_000 / (1024 * 1024)
        assert abs(_estimate_memory_mb(100, 10) - expected_mb) < 0.001

    def test_large_table(self):
        """1M rows x 50 cols x 50 bytes is expected to be ~2384 MB."""
        rows, cols, bytes_per_cell = 1_000_000, 50, 50
        expected_mb = (rows * cols * bytes_per_cell) / (1024 * 1024)
        assert abs(_estimate_memory_mb(rows, cols) - expected_mb) < 0.01

    def test_zero_rows(self):
        """No rows means an estimate of 0 MB."""
        estimate = _estimate_memory_mb(0, 50)
        assert estimate == 0.0

    def test_zero_columns(self):
        """No columns means an estimate of 0 MB."""
        estimate = _estimate_memory_mb(1000, 0)
        assert estimate == 0.0
# ---------------------------------------------------------------------------
# Tests: Output formatting
# ---------------------------------------------------------------------------
class TestOutputFormatting:
    """Test _print_table and _format_output for various formats."""

    @staticmethod
    def _config(output_dir):
        """Return the remote-query config dict used by the parquet tests.

        Only ``output_dir`` varies between tests; the remaining values are
        the same fixed settings for every caller.
        """
        return {
            "output_dir": output_dir,
            "timeout_seconds": 300,
            "max_result_rows": 100_000,
            "max_bq_registration_rows": 500_000,
            "default_format": "table",
        }

    def test_table_format_aligned(self, capsys):
        """Table format should produce aligned columns with header separator."""
        columns = ["id", "name", "value"]
        rows = [(1, "alice", 100), (2, "bob", 200)]
        _print_table(columns, rows)
        output = capsys.readouterr().out
        lines = output.strip().split("\n")
        # Header line
        assert "id" in lines[0]
        assert "name" in lines[0]
        assert "value" in lines[0]
        # Separator line
        assert "-+-" in lines[1]
        # Data rows
        assert "alice" in lines[2]
        assert "bob" in lines[3]
        # Row count footer
        assert "(2 rows)" in output

    def test_table_format_empty_result(self, capsys):
        """Empty result should print '(empty result)'."""
        _print_table(["col1"], [])
        output = capsys.readouterr().out
        assert "(empty result)" in output

    def test_table_format_null_values(self, capsys):
        """None values should be rendered as 'NULL'."""
        _print_table(["col"], [(None,)])
        output = capsys.readouterr().out
        assert "NULL" in output

    def test_csv_format(self, tmp_path, duckdb_conn):
        """CSV output should contain header + data rows."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id, 'hello' AS msg")
        output_path = str(tmp_path / "result.csv")
        _format_output(
            conn=duckdb_conn,
            sql="SELECT * FROM test",
            fmt="csv",
            output_path=output_path,
            max_rows=1000,
        )
        # The csv module expects files opened with newline="" so that it
        # handles line endings itself.
        with open(output_path, newline="", encoding="utf-8") as f:
            reader = csv.reader(f)
            header = next(reader)
            rows = list(reader)
        assert header == ["id", "msg"]
        assert len(rows) == 1
        assert rows[0][1] == "hello"

    def test_csv_format_to_stdout(self, capsys, duckdb_conn):
        """CSV with no output path should write to stdout."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 42 AS val")
        _format_output(
            conn=duckdb_conn,
            sql="SELECT * FROM test",
            fmt="csv",
            output_path=None,
            max_rows=1000,
        )
        output = capsys.readouterr().out
        assert "val" in output
        assert "42" in output

    def test_json_format(self, tmp_path, duckdb_conn):
        """JSON output should contain a list of dicts."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id, 'world' AS msg")
        output_path = str(tmp_path / "result.json")
        _format_output(
            conn=duckdb_conn,
            sql="SELECT * FROM test",
            fmt="json",
            output_path=output_path,
            max_rows=1000,
        )
        with open(output_path, encoding="utf-8") as f:
            data = json.load(f)
        assert len(data) == 1
        assert data[0]["id"] == 1
        assert data[0]["msg"] == "world"

    def test_json_format_to_stdout(self, capsys, duckdb_conn):
        """JSON with no output path should print to stdout."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 99 AS num")
        _format_output(
            conn=duckdb_conn,
            sql="SELECT * FROM test",
            fmt="json",
            output_path=None,
            max_rows=1000,
        )
        output = capsys.readouterr().out
        data = json.loads(output)
        assert data[0]["num"] == 99

    def test_parquet_write(self, tmp_path, duckdb_conn):
        """Parquet output should create a readable parquet file."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id, 2.5 AS val")
        # The target sits in a sub-directory that does not exist yet.
        output_path = str(tmp_path / "output" / "result.parquet")
        with patch(
            "src.remote_query._load_remote_query_config",
            return_value=self._config(str(tmp_path / "default_output")),
        ):
            _format_output(
                conn=duckdb_conn,
                sql="SELECT * FROM test",
                fmt="parquet",
                output_path=output_path,
                max_rows=1000,
            )
        assert Path(output_path).exists()
        # Read it back and verify
        result = pq.read_table(output_path)
        assert result.num_rows == 1
        assert result.num_columns == 2
        assert result.column_names == ["id", "val"]

    def test_parquet_default_path(self, tmp_path, duckdb_conn):
        """Parquet with no output path should use config default dir."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS x")
        default_dir = str(tmp_path / "default_output")
        with patch(
            "src.remote_query._load_remote_query_config",
            return_value=self._config(default_dir),
        ):
            _format_output(
                conn=duckdb_conn,
                sql="SELECT * FROM test",
                fmt="parquet",
                output_path=None,
                max_rows=1000,
            )
        expected_path = Path(default_dir) / "result.parquet"
        assert expected_path.exists()

    def test_unknown_format_raises(self, duckdb_conn):
        """Unknown format should raise RemoteQueryError."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id")
        with pytest.raises(RemoteQueryError, match="Unknown format"):
            _format_output(
                conn=duckdb_conn,
                sql="SELECT * FROM test",
                fmt="xml",
                output_path=None,
                max_rows=1000,
            )
# ---------------------------------------------------------------------------
# Tests: End-to-end (local-only, no BQ mocking needed)
# ---------------------------------------------------------------------------
class TestEndToEnd:
    """End-to-end tests with local Parquet data only (no BigQuery dependency).

    Uses _patched_duckdb_connect to handle DuckDB version differences
    (statement_timeout may not be supported in all versions).
    """

    # Base config returned by the patched _load_remote_query_config.
    # "output_dir" is added per test in _run so it points inside the test's
    # own temporary project instead of a fixed, shared /tmp location.
    _CONFIG = {
        "timeout_seconds": 300,
        "max_result_rows": 100_000,
        "max_bq_registration_rows": 500_000,
        "default_format": "table",
    }

    def _run(self, tmp_local_project, **kwargs):
        """Run execute_remote_query with the standard patches applied.

        Args:
            tmp_local_project: The (project_root, data_dir) pair from the
                fixture of the same name.
            **kwargs: Passed on to execute_remote_query. ``data_dir``,
                ``bq_registrations`` (``[]``) and ``quiet`` (``True``) are
                filled in when absent. The special key ``config_overrides``
                is removed and merged over the base config.
        """
        project_root, data_dir = tmp_local_project
        config = {
            **self._CONFIG,
            "output_dir": str(project_root / "remote_query_output"),
            **kwargs.pop("config_overrides", {}),
        }
        with patch("scripts.duckdb_manager.find_project_root", return_value=project_root), \
                patch("src.remote_query._load_remote_query_config", return_value=config), \
                patch("src.remote_query.duckdb") as mock_duckdb_mod:
            # Only connect() is given a real implementation; it returns a
            # proxied real connection.
            mock_duckdb_mod.connect = _patched_duckdb_connect
            kwargs.setdefault("data_dir", data_dir)
            kwargs.setdefault("bq_registrations", [])
            kwargs.setdefault("quiet", True)
            execute_remote_query(**kwargs)

    def test_local_only_query(self, tmp_local_project, capsys):
        """Execute a query against local Parquet views and verify table output."""
        self._run(
            tmp_local_project,
            sql="SELECT COUNT(*) AS cnt FROM orders",
            fmt="table",
        )
        output = capsys.readouterr().out
        assert "cnt" in output
        assert "5" in output

    def test_local_join_query(self, tmp_local_project, capsys):
        """JOIN between two local tables should work."""
        self._run(
            tmp_local_project,
            sql=(
                "SELECT p.name, SUM(o.amount) AS total "
                "FROM orders o JOIN products p ON o.product_id = p.product_id "
                "GROUP BY p.name ORDER BY total DESC"
            ),
            fmt="json",
        )
        output = capsys.readouterr().out
        data = json.loads(output)
        assert len(data) == 3
        # Widget: orders 1,3 -> 10+30=40
        widget = next(r for r in data if r["name"] == "Widget")
        assert widget["total"] == 40.0

    def test_result_row_limit(self, tmp_local_project, capsys):
        """Result exceeding max_rows should be truncated."""
        self._run(
            tmp_local_project,
            sql="SELECT * FROM orders ORDER BY order_id",
            fmt="table",
            max_rows=2,
            quiet=False,
            config_overrides={"max_result_rows": 2},
        )
        out = capsys.readouterr().out
        # Table output should show exactly 2 data rows
        assert "(2 rows)" in out

    def test_csv_output_to_file(self, tmp_local_project, tmp_path):
        """End-to-end CSV output written to a file."""
        output_path = str(tmp_path / "result.csv")
        self._run(
            tmp_local_project,
            sql="SELECT order_id, amount FROM orders ORDER BY order_id",
            fmt="csv",
            output=output_path,
        )
        # newline="" lets the csv module handle line endings itself.
        with open(output_path, newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            rows = list(reader)
        assert len(rows) == 5
        assert rows[0]["order_id"] == "1"
        assert rows[0]["amount"] == "10.0"

    def test_hybrid_table_queryable(self, tmp_local_project, capsys):
        """Hybrid table should be accessible in local queries."""
        self._run(
            tmp_local_project,
            sql="SELECT SUM(stock) AS total_stock FROM inventory",
            fmt="json",
        )
        output = capsys.readouterr().out
        data = json.loads(output)
        assert data[0]["total_stock"] == 300

    def test_quiet_mode_suppresses_stderr(self, tmp_local_project, capsys):
        """With quiet=True, no progress messages should appear on stderr."""
        self._run(
            tmp_local_project,
            sql="SELECT COUNT(*) AS cnt FROM orders",
            fmt="table",
            quiet=True,
        )
        err = capsys.readouterr().err
        # Neither progress phrase may appear on stderr in quiet mode.
        assert "Setting up" not in err
        assert "local views" not in err