- Add an `AS _cnt` alias to the COUNT(*) subquery (BigQuery Standard SQL requires it).
- Catch ImportError in `_get_bq_client()` and raise RemoteQueryError so the API endpoint returns a proper 400 instead of a 500.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
290 lines
10 KiB
Python
290 lines
10 KiB
Python
"""Tests for RemoteQueryEngine — two-phase BQ registration + DuckDB execution."""
|
|
|
|
import sys
|
|
from datetime import date
|
|
from decimal import Decimal
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import duckdb
|
|
import pyarrow as pa
|
|
import pytest
|
|
|
|
from src.remote_query import RemoteQueryEngine, RemoteQueryError, _validate_bq_sql, _validate_sql
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def analytics_conn():
    """Yield an in-memory DuckDB connection seeded with a small ``orders`` table.

    Schema: ``orders(id INT, date DATE, amount DECIMAL(10,2))`` holding two
    rows (ids 1 and 2). The connection is closed during fixture teardown.
    """
    db = duckdb.connect()
    seed_statements = (
        "CREATE TABLE orders (id INT, date DATE, amount DECIMAL(10,2))",
        "INSERT INTO orders VALUES (1, '2026-01-01', 100.0), (2, '2026-01-15', 200.0)",
    )
    for statement in seed_statements:
        db.execute(statement)
    yield db
    db.close()
|
|
|
|
|
|
def _make_bq_mock(arrow_table, count_value=None):
    """Return a fake BigQuery client for two-phase registration tests.

    The fake's ``query()`` hands back two jobs, in call order:

    1. a count job whose ``to_arrow()`` yields a one-row, one-column int64
       table holding *count_value*;
    2. a data job whose ``to_arrow()`` yields *arrow_table*.

    When *count_value* is None it defaults to ``arrow_table.num_rows``.
    A third ``query()`` call raises StopIteration (the side_effect list is
    exhausted).

    NOTE(review): the count column is named ``count`` here, while the real
    COUNT query aliases it (``_cnt`` per the commit note). This only matches
    if the engine reads the count positionally — confirm against the engine.
    """
    reported_rows = arrow_table.num_rows if count_value is None else count_value

    def _job_yielding(table):
        # Each fake job only needs to_arrow() for the engine's two phases.
        job = MagicMock()
        job.to_arrow.return_value = table
        return job

    count_table = pa.table({"count": pa.array([reported_rows], type=pa.int64())})

    fake_client = MagicMock()
    fake_client.query.side_effect = [
        _job_yielding(count_table),
        _job_yielding(arrow_table),
    ]
    return fake_client
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# TestRemoteQueryEngineRegister
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRemoteQueryEngineRegister:
    """Tests for ``RemoteQueryEngine.register_bq`` (BQ fetch -> DuckDB alias)."""

    def test_register_bq_success(self, analytics_conn):
        """Mock BQ client returning an Arrow table; verify view is queryable."""
        arrow_table = pa.table(
            {
                "order_id": pa.array([10, 20, 30], type=pa.int64()),
                "revenue": pa.array([1.0, 2.0, 3.0], type=pa.float64()),
            }
        )
        mock_client = _make_bq_mock(arrow_table)

        engine = RemoteQueryEngine(
            analytics_conn,
            _bq_client_factory=lambda project: mock_client,
            max_bq_registration_rows=500_000,
        )

        result = engine.register_bq("bq_orders", "SELECT order_id, revenue FROM bq.orders")

        assert result["alias"] == "bq_orders"
        assert result["rows"] == 3
        assert result["columns"] == ["order_id", "revenue"]
        assert result["memory_mb"] > 0

        # The alias must be queryable from DuckDB
        rows = analytics_conn.execute("SELECT COUNT(*) FROM bq_orders").fetchone()
        assert rows[0] == 3

    def test_register_bq_row_limit_exceeded(self, analytics_conn):
        """COUNT pre-check returns a value exceeding the row limit → RemoteQueryError."""
        arrow_table = pa.table({"x": pa.array([1], type=pa.int64())})
        # count exceeds limit
        mock_client = _make_bq_mock(arrow_table, count_value=1_000_000)

        engine = RemoteQueryEngine(
            analytics_conn,
            _bq_client_factory=lambda project: mock_client,
            max_bq_registration_rows=500_000,
        )

        with pytest.raises(RemoteQueryError) as exc_info:
            engine.register_bq("bq_big", "SELECT * FROM bq.huge_table")

        assert exc_info.value.error_type == "row_limit"
        assert exc_info.value.details["count"] == 1_000_000

    def test_register_bq_invalid_alias(self, analytics_conn):
        """Invalid or reserved aliases are rejected; a valid identifier is accepted."""
        engine = RemoteQueryEngine(analytics_conn)
        # Space in alias — invalid identifier
        with pytest.raises(RemoteQueryError) as exc_info:
            engine.register_bq("bad alias", "SELECT 1")
        assert exc_info.value.error_type == "query_error"

        # Reserved alias — information_schema
        with pytest.raises(RemoteQueryError) as exc_info:
            engine.register_bq("information_schema", "SELECT 1")
        assert exc_info.value.error_type == "query_error"

        # Valid alias — must get past alias validation and register.
        # A mocked BQ client keeps this hermetic. Without a factory, the engine
        # tries to import google.cloud.bigquery and build a real client. That
        # may need network/credentials when the package is installed, or it
        # raises a RemoteQueryError when the package is absent, so any
        # assertion on the error would depend on message wording.
        mock_client = _make_bq_mock(pa.table({"x": pa.array([1], type=pa.int64())}))
        valid_engine = RemoteQueryEngine(
            analytics_conn,
            _bq_client_factory=lambda project: mock_client,
            max_bq_registration_rows=500_000,
        )
        result = valid_engine.register_bq("valid_name", "SELECT 1")
        assert result["alias"] == "valid_name"

    def test_register_bq_missing_package(self, analytics_conn):
        """When google-cloud-bigquery is not installed, engine must raise RemoteQueryError."""
        engine = RemoteQueryEngine(
            analytics_conn,
            # No factory — will try to import google.cloud.bigquery
            _bq_client_factory=None,
            max_bq_registration_rows=500_000,
        )

        # Mapping a module name to None in sys.modules makes importing it raise
        # ImportError, simulating the missing package.
        with patch.dict(sys.modules, {"google": None, "google.cloud": None, "google.cloud.bigquery": None}):
            with pytest.raises(RemoteQueryError, match="google-cloud-bigquery"):
                engine.register_bq("bq_alias", "SELECT 1")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# TestRemoteQueryEngineExecute
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRemoteQueryEngineExecute:
    """Tests for ``RemoteQueryEngine.execute`` against the local DuckDB connection."""

    def test_execute_local_only(self, analytics_conn):
        """Query local table; result dict has correct structure."""
        engine = RemoteQueryEngine(analytics_conn)
        result = engine.execute("SELECT id, amount FROM orders ORDER BY id")

        assert result["columns"] == ["id", "amount"]
        assert result["row_count"] == 2
        assert result["truncated"] is False
        assert len(result["rows"]) == 2
        # Non-standard types (Decimal) must be serialized to str
        for row in result["rows"]:
            for val in row:
                assert isinstance(val, (int, float, bool, str, type(None)))

    def test_execute_with_registered_bq(self, analytics_conn):
        """Register a BQ result via a mocked client, then JOIN it with local orders."""
        bq_arrow = pa.table(
            {
                "id": pa.array([1, 2], type=pa.int64()),
                "label": pa.array(["first", "second"], type=pa.utf8()),
            }
        )
        mock_client = _make_bq_mock(bq_arrow)

        engine = RemoteQueryEngine(
            analytics_conn,
            _bq_client_factory=lambda project: mock_client,
            max_bq_registration_rows=500_000,
        )
        engine.register_bq("bq_labels", "SELECT id, label FROM bq.labels")

        result = engine.execute(
            "SELECT o.id, o.amount, b.label "
            "FROM orders o JOIN bq_labels b ON o.id = b.id "
            "ORDER BY o.id"
        )

        assert result["row_count"] == 2
        assert result["columns"] == ["id", "amount", "label"]
        # The JOIN must pair each local order with its registered label.
        label_idx = result["columns"].index("label")
        assert [row[label_idx] for row in result["rows"]] == ["first", "second"]

    def test_execute_respects_max_result_rows(self, analytics_conn):
        """When max_result_rows=1, result is truncated after 1 row."""
        engine = RemoteQueryEngine(analytics_conn, max_result_rows=1)
        result = engine.execute("SELECT id FROM orders ORDER BY id")

        assert result["row_count"] == 1
        assert result["truncated"] is True
        # The payload itself must be capped, not just the reported count.
        assert len(result["rows"]) == 1

    def test_execute_invalid_sql(self, analytics_conn):
        """DROP TABLE must be rejected with RemoteQueryError(error_type='query_error')."""
        engine = RemoteQueryEngine(analytics_conn)

        with pytest.raises(RemoteQueryError) as exc_info:
            engine.execute("DROP TABLE orders")

        assert exc_info.value.error_type == "query_error"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _validate_sql unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestValidateSql:
    """Unit tests for the DuckDB-side ``_validate_sql`` guard."""

    # Writes, DDL, COPY/ATTACH/DETACH, extension LOAD/INSTALL, direct file
    # readers, quoted file paths with traversal, and stacked statements.
    _REJECTED = (
        "DROP TABLE foo",
        "DELETE FROM foo",
        "INSERT INTO foo VALUES (1)",
        "UPDATE foo SET x=1",
        "ALTER TABLE foo ADD COLUMN y INT",
        "CREATE TABLE foo (x INT)",
        "COPY foo TO '/tmp/out.csv'",
        "ATTACH '/db.duckdb'",
        "DETACH db",
        "LOAD 'extension'",
        "INSTALL httpfs",
        "SELECT read_parquet('/data/file.parquet')",
        "SELECT * FROM '../secret/file'",
        "SELECT 1; DROP TABLE foo",
    )

    # Plain read-only SELECT / WITH queries, in upper and lower case.
    _ACCEPTED = (
        "SELECT id FROM orders",
        "WITH cte AS (SELECT 1 AS x) SELECT x FROM cte",
        "select count(*) from orders",
        "with t as (select 1) select * from t",
    )

    @pytest.mark.parametrize("statement", _REJECTED)
    def test_blocked_sql(self, statement):
        """Each rejected statement raises RemoteQueryError with error_type 'query_error'."""
        with pytest.raises(RemoteQueryError) as caught:
            _validate_sql(statement)
        assert caught.value.error_type == "query_error"

    @pytest.mark.parametrize("statement", _ACCEPTED)
    def test_allowed_sql(self, statement):
        """Each accepted statement passes validation without raising."""
        _validate_sql(statement)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _validate_bq_sql unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestValidateBqSql:
    """Unit tests for the BigQuery-side ``_validate_bq_sql`` guard."""

    # Write/mutation statements (DDL, DML, TRUNCATE, MERGE) and stacked statements.
    _MUTATIONS = (
        "DROP TABLE x",
        "INSERT INTO x VALUES (1)",
        "DELETE FROM x",
        "UPDATE x SET y=1",
        "ALTER TABLE x ADD COLUMN z INT",
        "CREATE TABLE x (y INT)",
        "TRUNCATE TABLE x",
        "MERGE INTO x USING y ON x.id=y.id WHEN MATCHED THEN UPDATE SET x.a=y.a",
        "SELECT 1; DROP TABLE x",
    )

    # Read-only queries: INFORMATION_SCHEMA, fully-qualified tables, CTEs.
    _READ_ONLY = (
        "SELECT * FROM dataset.INFORMATION_SCHEMA.COLUMNS",
        "SELECT id FROM project.dataset.table",
        "WITH cte AS (SELECT 1 AS x) SELECT x FROM cte",
    )

    def test_information_schema_is_allowed(self):
        """INFORMATION_SCHEMA queries must pass BQ SQL validation."""
        _validate_bq_sql("SELECT * FROM dataset.INFORMATION_SCHEMA.COLUMNS")

    @pytest.mark.parametrize("statement", _MUTATIONS)
    def test_blocked_bq_sql(self, statement):
        """Write/mutation operations must be rejected with error_type 'query_error'."""
        with pytest.raises(RemoteQueryError) as caught:
            _validate_bq_sql(statement)
        assert caught.value.error_type == "query_error"

    @pytest.mark.parametrize("statement", _READ_ONLY)
    def test_allowed_bq_sql(self, statement):
        """Valid read-only BQ queries must pass without raising."""
        _validate_bq_sql(statement)
|