argparse was rejecting --stdin mode because --sql was required=True. Changed to required=False with runtime validation in main().
779 lines
27 KiB
Python
779 lines
27 KiB
Python
"""Tests for remote_query module - hybrid local Parquet + remote BigQuery queries.

Tests cover:
- CLI argument parsing (_parse_register_bq, build_parser)
- Local view setup (_setup_local_views via create_local_views)
- BQ registration with safety checks (_validate_bq_result_size, _estimate_memory_mb, _register_bq_views)
- Output formatting (_print_table, _format_output)
- End-to-end local-only queries (no BQ mocking needed)
"""
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
import os
|
|
from io import StringIO
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import duckdb
|
|
import pyarrow as pa
|
|
import pyarrow.parquet as pq
|
|
import pytest
|
|
|
|
from src.remote_query import (
|
|
RemoteQueryError,
|
|
_estimate_memory_mb,
|
|
_format_output,
|
|
_parse_register_bq,
|
|
_print_table,
|
|
_register_bq_views,
|
|
_validate_bq_result_size,
|
|
build_parser,
|
|
execute_remote_query,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.fixture
def tmp_local_project(tmp_path):
    """Build a throwaway project tree with a data description and Parquet files.

    Layout created under ``tmp_path``:
        docs/data_description.md                   (YAML with local + remote + hybrid tables)
        server/parquet/crm_data/orders.parquet     (5 rows)
        server/parquet/crm_data/products.parquet   (3 rows)
        server/parquet/crm_data/inventory.parquet  (2 rows, the hybrid table)

    No Parquet file is written for the remote ``traffic`` table.

    Returns:
        ``(project_root, data_dir)`` where ``project_root`` is ``tmp_path`` and
        ``data_dir`` is ``str(tmp_path / "server")``.
    """
    description_md = """\
# Data Description

```yaml
folder_mapping:
  in.c-crm: crm_data

tables:
  - id: "in.c-crm.orders"
    name: "orders"
    description: "Order data"
    primary_key: "order_id"
    sync_strategy: "full_refresh"

  - id: "in.c-crm.products"
    name: "products"
    description: "Product catalog"
    primary_key: "product_id"
    sync_strategy: "full_refresh"

  - id: "prj-grp-dataview-prod-1ff9.supply.traffic"
    name: "traffic"
    description: "Remote BQ traffic table"
    primary_key: "id"
    query_mode: "remote"

  - id: "in.c-crm.inventory"
    name: "inventory"
    description: "Hybrid inventory"
    primary_key: "id"
    sync_strategy: "full_refresh"
    query_mode: "hybrid"
```
"""
    docs_dir = tmp_path / "docs"
    docs_dir.mkdir()
    (docs_dir / "data_description.md").write_text(description_md)

    parquet_dir = tmp_path / "server" / "parquet" / "crm_data"
    parquet_dir.mkdir(parents=True)

    # One entry per Parquet file, written in this order: the two local
    # tables first, then the hybrid table.
    table_columns = {
        "orders": {
            "order_id": [1, 2, 3, 4, 5],
            "amount": [10.0, 20.0, 30.0, 40.0, 50.0],
            "product_id": [101, 102, 101, 103, 102],
        },
        "products": {
            "product_id": [101, 102, 103],
            "name": ["Widget", "Gadget", "Doohickey"],
        },
        "inventory": {
            "id": [1, 2],
            "stock": [100, 200],
        },
    }
    for table_name, columns in table_columns.items():
        pq.write_table(pa.table(columns), parquet_dir / f"{table_name}.parquet")

    return tmp_path, str(tmp_path / "server")
|
|
|
|
|
|
@pytest.fixture
def duckdb_conn():
    """Yield a fresh in-memory DuckDB connection and close it at teardown."""
    connection = duckdb.connect(":memory:")
    try:
        yield connection
    finally:
        connection.close()
|
|
|
|
|
|
class _DuckDBConnectionProxy:
    """Proxy around DuckDBPyConnection that silently ignores unsupported SET commands.

    DuckDB versions may not support 'statement_timeout'. This proxy catches
    CatalogException for SET commands so end-to-end tests work across versions.
    The real connection's execute method is read-only, so we wrap it.

    Every attribute other than ``execute`` is forwarded to the wrapped
    connection via ``__getattr__``.
    """

    def __init__(self, conn):
        """Store the wrapped connection.

        Args:
            conn: The connection object whose ``execute`` is wrapped and to
                which all other attribute access is delegated.
        """
        object.__setattr__(self, "_conn", conn)

    def execute(self, sql, *args, **kwargs):
        """Run *sql* on the wrapped connection.

        Args:
            sql: Statement to execute; extra positional and keyword arguments
                are passed through unchanged.

        Returns:
            Whatever the wrapped ``execute`` returns, or ``None`` when a
            statement starting with ``SET `` (case-insensitive, after
            stripping whitespace) raised ``duckdb.CatalogException``.
        """
        if isinstance(sql, str) and sql.strip().upper().startswith("SET "):
            try:
                return self._conn.execute(sql, *args, **kwargs)
            except duckdb.CatalogException:
                # Deliberate best-effort: an unknown setting must not fail the test run.
                return None
        return self._conn.execute(sql, *args, **kwargs)

    def __getattr__(self, name):
        """Delegate attribute lookups that miss on the proxy to the wrapped connection.

        Raises:
            AttributeError: If ``_conn`` itself is missing (e.g. the instance
                was created without ``__init__``), or if the wrapped
                connection lacks *name*.
        """
        # __getattr__ only runs after normal lookup fails. If "_conn" is the
        # missing name, reading self._conn below would re-enter this method
        # forever and end in RecursionError instead of AttributeError.
        if name == "_conn":
            raise AttributeError(name)
        return getattr(self._conn, name)
|
|
|
|
|
|
def _patched_duckdb_connect(*args, **kwargs):
    """Open a DuckDB connection with the given arguments and return it wrapped in the proxy."""
    return _DuckDBConnectionProxy(duckdb.connect(*args, **kwargs))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: CLI argument parsing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestCLIArgParsing:
    """Exercise _parse_register_bq() and the parser returned by build_parser()."""

    @staticmethod
    def _parse(*argv):
        """Parse *argv* with a freshly built parser and return the namespace."""
        return build_parser().parse_args(list(argv))

    def test_sql_defaults_to_none(self):
        """--sql may be omitted when --stdin is given; it then stays None."""
        namespace = self._parse("--stdin")
        assert namespace.sql is None
        assert namespace.stdin is True

    def test_register_bq_parsing(self):
        """'alias=SELECT ...' is split into an (alias, sql) tuple."""
        parsed = _parse_register_bq("traffic=SELECT report_date FROM `proj.ds.table`")
        assert parsed == ("traffic", "SELECT report_date FROM `proj.ds.table`")

    def test_register_bq_invalid_format(self):
        """A value without '=' is rejected with ArgumentTypeError."""
        with pytest.raises(argparse.ArgumentTypeError, match="Invalid --register-bq format"):
            _parse_register_bq("no_equals_sign_here")

    def test_register_bq_empty_sql(self):
        """An alias followed by '=' and nothing else is rejected."""
        with pytest.raises(argparse.ArgumentTypeError, match="Empty SQL"):
            _parse_register_bq("alias=")

    def test_register_bq_empty_alias(self):
        """A value starting with '=' (no alias) is rejected."""
        with pytest.raises(argparse.ArgumentTypeError, match="Invalid --register-bq format"):
            _parse_register_bq("=SELECT 1")

    def test_multiple_register_bq(self):
        """Repeated --register-bq options accumulate in order."""
        namespace = self._parse(
            "--sql", "SELECT 1",
            "--register-bq", "t1=SELECT a FROM x",
            "--register-bq", "t2=SELECT b FROM y",
        )
        assert list(namespace.bq_registrations) == [
            ("t1", "SELECT a FROM x"),
            ("t2", "SELECT b FROM y"),
        ]

    def test_default_format_is_none(self):
        """Without --format, args.fmt is None (config default applies at runtime)."""
        namespace = self._parse("--sql", "SELECT 1")
        assert namespace.fmt is None

    def test_explicit_format(self):
        """An explicit --format value is stored on args.fmt."""
        namespace = self._parse("--sql", "SELECT 1", "--format", "csv")
        assert namespace.fmt == "csv"

    def test_invalid_format_rejected(self):
        """An unsupported --format value makes the parser exit."""
        cli_parser = build_parser()
        with pytest.raises(SystemExit):
            cli_parser.parse_args(["--sql", "SELECT 1", "--format", "xml"])

    def test_no_register_bq_yields_empty_list(self):
        """With no --register-bq option, bq_registrations is an empty list."""
        namespace = self._parse("--sql", "SELECT 1")
        assert namespace.bq_registrations == []

    def test_register_bq_sql_with_equals(self):
        """Only the first '=' separates alias from SQL; later ones stay in the SQL."""
        parsed = _parse_register_bq("view=SELECT * FROM t WHERE col = 5")
        assert (parsed[0], parsed[1]) == ("view", "SELECT * FROM t WHERE col = 5")

    def test_quiet_flag(self):
        """--quiet sets args.quiet to True."""
        namespace = self._parse("--sql", "SELECT 1", "--quiet")
        assert namespace.quiet is True

    def test_max_rows_parsing(self):
        """--max-rows is converted to an int."""
        namespace = self._parse("--sql", "SELECT 1", "--max-rows", "500")
        assert namespace.max_rows == 500
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: Local view setup
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestLocalViewSetup:
    """Check which tables _setup_local_views exposes as DuckDB views."""

    @staticmethod
    def _build_views(project, conn):
        """Run _setup_local_views against the fixture project and return its result."""
        root, data_dir = project
        with patch("scripts.duckdb_manager.find_project_root", return_value=root):
            from src.remote_query import _setup_local_views

            return _setup_local_views(conn, data_dir, quiet=True)

    def test_creates_views_from_parquet(self, tmp_local_project, duckdb_conn):
        """Local tables become queryable views."""
        view_names = self._build_views(tmp_local_project, duckdb_conn)

        for expected in ("orders", "products"):
            assert expected in view_names

        # The view must be backed by the fixture's five order rows.
        (row_total,) = duckdb_conn.execute("SELECT COUNT(*) FROM orders").fetchone()[:1]
        assert row_total == 5

    def test_skips_remote_tables(self, tmp_local_project, duckdb_conn):
        """Tables with query_mode='remote' get no local view."""
        view_names = self._build_views(tmp_local_project, duckdb_conn)

        assert "traffic" not in view_names

        # It must also be absent from the connection's catalog.
        catalog = [entry[0] for entry in duckdb_conn.execute("SHOW TABLES").fetchall()]
        assert "traffic" not in catalog

    def test_includes_hybrid_tables(self, tmp_local_project, duckdb_conn):
        """Tables with query_mode='hybrid' still get a local view."""
        view_names = self._build_views(tmp_local_project, duckdb_conn)

        assert "inventory" in view_names

        (row_total,) = duckdb_conn.execute("SELECT COUNT(*) FROM inventory").fetchone()[:1]
        assert row_total == 2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: BQ registration with safety checks
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestBQRegistration:
    """Test BQ result validation and registration (mocked BigQuery)."""

    @staticmethod
    def _make_mock_bq_client(count_result: int = 100, schema_fields: int = 5):
        """Create a mock BQ client that returns controlled count and schema.

        The mock's ``query(sql)`` returns a job whose ``result()`` depends on
        the SQL text: statements starting with ``SELECT COUNT(*)`` yield one
        row whose indexing returns ``count_result``; statements containing
        ``LIMIT 0`` expose a ``schema`` list of ``schema_fields`` mocks. Any
        other statement gets a plain ``MagicMock`` job.

        Args:
            count_result: Row count returned by COUNT(*) query
            schema_fields: Number of fields in the schema

        Returns:
            A ``MagicMock`` standing in for the BigQuery client.
        """
        mock_client = MagicMock()

        # COUNT(*) query result: any index on the row returns the count.
        count_row = MagicMock()
        count_row.__getitem__ = MagicMock(return_value=count_result)

        # Schema query result (LIMIT 0)
        mock_schema_fields = [MagicMock() for _ in range(schema_fields)]

        # Use side_effect to return different results for different queries
        def query_side_effect(sql):
            job = MagicMock()
            if sql.startswith("SELECT COUNT(*)"):
                result = MagicMock()
                result.__iter__ = MagicMock(return_value=iter([count_row]))
                job.result.return_value = result
            elif "LIMIT 0" in sql:
                result = MagicMock()
                result.schema = mock_schema_fields
                job.result.return_value = result
            return job

        mock_client.query.side_effect = query_side_effect
        return mock_client

    def test_count_check_blocks_large_result(self):
        """BQ sub-query exceeding max_rows should raise RemoteQueryError."""
        mock_client = self._make_mock_bq_client(count_result=1_000_000)

        with pytest.raises(RemoteQueryError, match="would return 1,000,000 rows"):
            _validate_bq_result_size(
                bq_client=mock_client,
                sql="SELECT * FROM big_table",
                alias="big_table",
                max_rows=500_000,
            )

    def test_validates_small_result_passes(self):
        """BQ sub-query within limits should return the row count."""
        mock_client = self._make_mock_bq_client(count_result=1000)

        row_count = _validate_bq_result_size(
            bq_client=mock_client,
            sql="SELECT * FROM small_table",
            alias="small_table",
            max_rows=500_000,
        )

        assert row_count == 1000

    def test_memory_estimate_blocks_huge_result(self, duckdb_conn):
        """_register_bq_views should refuse when estimated memory exceeds 2 GB."""
        # Create a mock that passes count check but fails memory check
        # 500K rows x 100 cols x 50 bytes/cell = ~2384 MB > 2048 MB limit
        mock_client = self._make_mock_bq_client(count_result=500_000, schema_fields=100)

        with patch("src.remote_query._create_bq_client", return_value=mock_client), \
                patch.dict(os.environ, {"BIGQUERY_PROJECT": "test-proj"}):
            with pytest.raises(RemoteQueryError, match="estimated memory"):
                _register_bq_views(
                    conn=duckdb_conn,
                    registrations=[("huge", "SELECT * FROM huge_table")],
                    max_bq_rows=1_000_000,
                    timeout_seconds=60,
                    quiet=True,
                )

    def test_registers_small_result(self, duckdb_conn):
        """BQ sub-query within all limits should register successfully."""
        # Small result: 100 rows x 5 cols = ~0.02 MB
        mock_client = self._make_mock_bq_client(count_result=100, schema_fields=5)

        # register_bq_table is mocked to return the row count.
        with patch("src.remote_query._create_bq_client", return_value=mock_client), \
                patch("src.remote_query.register_bq_table", return_value=100) as mock_reg, \
                patch.dict(os.environ, {"BIGQUERY_PROJECT": "test-proj"}):
            results = _register_bq_views(
                conn=duckdb_conn,
                registrations=[("small_view", "SELECT * FROM small_table")],
                max_bq_rows=500_000,
                timeout_seconds=60,
                quiet=True,
            )

            assert results == {"small_view": 100}
            mock_reg.assert_called_once()

    def test_missing_bigquery_project_raises(self, duckdb_conn):
        """Missing BIGQUERY_PROJECT env var should raise RemoteQueryError."""
        # clear=True empties os.environ for the duration of the block.
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(RemoteQueryError, match="BIGQUERY_PROJECT"):
                _register_bq_views(
                    conn=duckdb_conn,
                    registrations=[("v", "SELECT 1")],
                    max_bq_rows=100,
                    timeout_seconds=60,
                )

    def test_empty_registrations_returns_empty(self, duckdb_conn):
        """Empty registration list should return an empty dict."""
        result = _register_bq_views(
            conn=duckdb_conn,
            registrations=[],
            max_bq_rows=100,
            timeout_seconds=60,
        )
        assert result == {}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: Memory estimation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestMemoryEstimation:
    """Check the arithmetic of _estimate_memory_mb."""

    _BYTES_PER_MB = 1024 * 1024

    def test_small_table(self):
        """100 rows x 10 cols comes to 50_000 bytes, about 0.048 MB."""
        expected_mb = 50_000 / self._BYTES_PER_MB
        estimate = _estimate_memory_mb(100, 10)
        assert abs(estimate - expected_mb) < 0.001

    def test_large_table(self):
        """1M rows x 50 cols x 50 bytes comes to about 2384 MB."""
        expected_mb = (1_000_000 * 50 * 50) / self._BYTES_PER_MB
        estimate = _estimate_memory_mb(1_000_000, 50)
        assert abs(estimate - expected_mb) < 0.01

    def test_zero_rows(self):
        """No rows means 0 MB."""
        estimate = _estimate_memory_mb(0, 50)
        assert estimate == 0.0

    def test_zero_columns(self):
        """No columns means 0 MB."""
        estimate = _estimate_memory_mb(1000, 0)
        assert estimate == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: Output formatting
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestOutputFormatting:
    """Cover _print_table and each output format handled by _format_output."""

    @staticmethod
    def _render(conn, fmt, output_path=None):
        """Format ``SELECT * FROM test`` in *fmt* with a 1000-row cap."""
        _format_output(
            conn=conn,
            sql="SELECT * FROM test",
            fmt=fmt,
            output_path=output_path,
            max_rows=1000,
        )

    @staticmethod
    def _config_for(output_dir):
        """Return the remote-query config dict used by the Parquet tests."""
        return {
            "output_dir": output_dir,
            "timeout_seconds": 300,
            "max_result_rows": 100_000,
            "max_bq_registration_rows": 500_000,
            "default_format": "table",
        }

    def test_table_format_aligned(self, capsys):
        """Table output has a header, a separator, the data rows and a row-count footer."""
        header_names = ["id", "name", "value"]
        records = [(1, "alice", 100), (2, "bob", 200)]

        _print_table(header_names, records)

        printed = capsys.readouterr().out
        printed_lines = printed.strip().split("\n")

        # Line 0: every column name appears in the header.
        for column_name in header_names:
            assert column_name in printed_lines[0]

        # Line 1: separator between header and data.
        assert "-+-" in printed_lines[1]

        # Lines 2-3: one line per record, in input order.
        assert "alice" in printed_lines[2]
        assert "bob" in printed_lines[3]

        # Footer with the number of rows.
        assert "(2 rows)" in printed

    def test_table_format_empty_result(self, capsys):
        """No rows prints '(empty result)'."""
        _print_table(["col1"], [])

        assert "(empty result)" in capsys.readouterr().out

    def test_table_format_null_values(self, capsys):
        """A None cell is shown as 'NULL'."""
        _print_table(["col"], [(None,)])

        assert "NULL" in capsys.readouterr().out

    def test_csv_format(self, tmp_path, duckdb_conn):
        """CSV written to a file has a header line followed by the data."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id, 'hello' AS msg")
        csv_path = str(tmp_path / "result.csv")

        self._render(duckdb_conn, "csv", csv_path)

        with open(csv_path) as handle:
            csv_lines = csv.reader(handle)
            header = next(csv_lines)
            body = list(csv_lines)

        assert header == ["id", "msg"]
        assert len(body) == 1
        assert body[0][1] == "hello"

    def test_csv_format_to_stdout(self, capsys, duckdb_conn):
        """CSV without an output path goes to stdout."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 42 AS val")

        self._render(duckdb_conn, "csv")

        printed = capsys.readouterr().out
        assert "val" in printed
        assert "42" in printed

    def test_json_format(self, tmp_path, duckdb_conn):
        """JSON written to a file is a list with one dict per row."""
        duckdb_conn.execute(
            "CREATE TABLE test AS SELECT 1 AS id, 'world' AS msg"
        )
        json_path = str(tmp_path / "result.json")

        self._render(duckdb_conn, "json", json_path)

        with open(json_path) as handle:
            records = json.load(handle)

        assert len(records) == 1
        first = records[0]
        assert first["id"] == 1
        assert first["msg"] == "world"

    def test_json_format_to_stdout(self, capsys, duckdb_conn):
        """JSON without an output path goes to stdout."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 99 AS num")

        self._render(duckdb_conn, "json")

        records = json.loads(capsys.readouterr().out)
        assert records[0]["num"] == 99

    def test_parquet_write(self, tmp_path, duckdb_conn):
        """Parquet output produces a file that reads back with the same shape."""
        duckdb_conn.execute(
            "CREATE TABLE test AS SELECT 1 AS id, 2.5 AS val"
        )
        parquet_path = str(tmp_path / "output" / "result.parquet")
        config = self._config_for(str(tmp_path / "default_output"))

        with patch("src.remote_query._load_remote_query_config", return_value=config):
            self._render(duckdb_conn, "parquet", parquet_path)

        assert Path(parquet_path).exists()

        # Round-trip through pyarrow to confirm the content.
        written = pq.read_table(parquet_path)
        assert written.num_rows == 1
        assert written.num_columns == 2
        assert written.column_names == ["id", "val"]

    def test_parquet_default_path(self, tmp_path, duckdb_conn):
        """Parquet without an output path lands in the configured output dir."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS x")
        default_dir = str(tmp_path / "default_output")
        config = self._config_for(default_dir)

        with patch("src.remote_query._load_remote_query_config", return_value=config):
            self._render(duckdb_conn, "parquet")

        assert (Path(default_dir) / "result.parquet").exists()

    def test_unknown_format_raises(self, duckdb_conn):
        """An unrecognised format name raises RemoteQueryError."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id")

        with pytest.raises(RemoteQueryError, match="Unknown format"):
            self._render(duckdb_conn, "xml")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: End-to-end (local-only, no BQ mocking needed)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestEndToEnd:
    """Run execute_remote_query against local Parquet data only (no BigQuery).

    Connections are opened through _patched_duckdb_connect so that SET
    commands a DuckDB version rejects (statement_timeout may not be
    supported everywhere) do not abort the run.
    """

    _CONFIG = {
        "timeout_seconds": 300,
        "max_result_rows": 100_000,
        "max_bq_registration_rows": 500_000,
        "default_format": "table",
        "output_dir": "/tmp/remote_query_test",
    }

    def _run(self, tmp_local_project, **kwargs):
        """Call execute_remote_query with the standard patches applied.

        ``config_overrides`` (optional) is merged over ``_CONFIG``; all other
        keyword arguments go to execute_remote_query, with ``data_dir``,
        ``bq_registrations`` and ``quiet`` defaulted when absent.
        """
        project_root, data_dir = tmp_local_project
        overrides = kwargs.pop("config_overrides", {})
        effective_config = {**self._CONFIG, **overrides}
        call_kwargs = {
            "data_dir": data_dir,
            "bq_registrations": [],
            "quiet": True,
            **kwargs,
        }

        with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
            with patch("src.remote_query._load_remote_query_config", return_value=effective_config):
                with patch("src.remote_query.duckdb") as duckdb_stub:
                    # Only connect() is given a real implementation on the patched module.
                    duckdb_stub.connect = _patched_duckdb_connect
                    execute_remote_query(**call_kwargs)

    def test_local_only_query(self, tmp_local_project, capsys):
        """A query over a local Parquet view prints its result as a table."""
        self._run(
            tmp_local_project,
            sql="SELECT COUNT(*) AS cnt FROM orders",
            fmt="table",
        )

        printed = capsys.readouterr().out
        assert "cnt" in printed
        assert "5" in printed

    def test_local_join_query(self, tmp_local_project, capsys):
        """Two local tables can be joined in one query."""
        join_sql = (
            "SELECT p.name, SUM(o.amount) AS total "
            "FROM orders o JOIN products p ON o.product_id = p.product_id "
            "GROUP BY p.name ORDER BY total DESC"
        )
        self._run(tmp_local_project, sql=join_sql, fmt="json")

        records = json.loads(capsys.readouterr().out)
        assert len(records) == 3
        # Widget: orders 1,3 -> 10+30=40
        widget_rows = [record for record in records if record["name"] == "Widget"]
        assert widget_rows[0]["total"] == 40.0

    def test_result_row_limit(self, tmp_local_project, capsys):
        """A result larger than max_rows is cut down to max_rows."""
        self._run(
            tmp_local_project,
            sql="SELECT * FROM orders ORDER BY order_id",
            fmt="table",
            max_rows=2,
            quiet=False,
            config_overrides={"max_result_rows": 2},
        )

        # The table footer reports exactly two data rows.
        assert "(2 rows)" in capsys.readouterr().out

    def test_csv_output_to_file(self, tmp_local_project, tmp_path):
        """CSV output can be written to a file end to end."""
        csv_path = str(tmp_path / "result.csv")

        self._run(
            tmp_local_project,
            sql="SELECT order_id, amount FROM orders ORDER BY order_id",
            fmt="csv",
            output=csv_path,
        )

        with open(csv_path) as handle:
            records = list(csv.DictReader(handle))

        assert len(records) == 5
        first = records[0]
        assert first["order_id"] == "1"
        assert first["amount"] == "10.0"

    def test_hybrid_table_queryable(self, tmp_local_project, capsys):
        """The hybrid table is reachable from a local query."""
        self._run(
            tmp_local_project,
            sql="SELECT SUM(stock) AS total_stock FROM inventory",
            fmt="json",
        )

        records = json.loads(capsys.readouterr().out)
        assert records[0]["total_stock"] == 300

    def test_quiet_mode_suppresses_stderr(self, tmp_local_project, capsys):
        """quiet=True keeps progress messages off stderr."""
        self._run(
            tmp_local_project,
            sql="SELECT COUNT(*) AS cnt FROM orders",
            fmt="table",
            quiet=True,
        )

        stderr_text = capsys.readouterr().err
        # Neither progress phrase may show up when quiet.
        for phrase in ("Setting up", "local views"):
            assert phrase not in stderr_text
|