feat: add RemoteQueryEngine with BQ registration and safety limits

Two-phase query engine: Phase 1 registers BQ query results as DuckDB
Arrow views (with COUNT pre-check, row/memory limits, Storage API
fallback); Phase 2 executes validated SQL against DuckDB with result
serialization and truncation. 25 tests covering all branches.
This commit is contained in:
ZdenekSrotyr 2026-04-11 11:07:08 +02:00
parent 0a69814fca
commit 86bbb8fce4
2 changed files with 599 additions and 0 deletions

375
src/remote_query.py Normal file
View file

@ -0,0 +1,375 @@
"""RemoteQueryEngine — two-phase BQ registration + DuckDB execution.
Phase 1 (register_bq): validate SQL, COUNT(*) pre-check against BigQuery,
fetch Arrow table, check memory, register as DuckDB view.
Phase 2 (execute): validate SQL, execute against DuckDB (which may reference
registered BQ views), serialize and return results.
"""
from __future__ import annotations
import logging
import os
from typing import Any, Dict, List, Optional
import duckdb
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# SQL blocklist — mirrors app/api/query.py lines 40-63
# ---------------------------------------------------------------------------
_BLOCKED_KEYWORDS: List[str] = [
"drop ",
"delete ",
"insert ",
"update ",
"alter ",
"create ",
"copy ",
"attach ",
"detach ",
"load ",
"install ",
"export ",
"import ",
"pragma ",
"call ",
# File access functions
"read_csv",
"read_json",
"read_parquet",
"read_text",
"write_csv",
"write_parquet",
"read_blob",
"read_ndjson",
"parquet_scan",
"parquet_metadata",
"parquet_schema",
"json_scan",
"csv_scan",
"query_table",
"iceberg_scan",
"delta_scan",
"glob(",
"list_files",
"'/",
'\"/',
"http://",
"https://",
"s3://",
"gcs://",
# DuckDB metadata (leaks schema info regardless of RBAC)
"information_schema",
"duckdb_tables",
"duckdb_columns",
"duckdb_databases",
"duckdb_settings",
"duckdb_functions",
"duckdb_views",
"duckdb_indexes",
"duckdb_schemas",
"pragma_table_info",
"pragma_storage_info",
# Relative path traversal
"'../",
'"../',
# Multiple statements
";",
]
# ---------------------------------------------------------------------------
# Exception
# ---------------------------------------------------------------------------
class RemoteQueryError(Exception):
"""Raised by RemoteQueryEngine for all controlled error conditions.
Attributes:
error_type: One of "row_limit", "memory_limit", "bq_error",
"query_error", "timeout".
details: Optional dict with additional context.
"""
def __init__(
self,
message: str,
error_type: str,
details: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(message)
self.error_type = error_type
self.details = details or {}
# ---------------------------------------------------------------------------
# Module-level helpers
# ---------------------------------------------------------------------------
def _validate_sql(sql: str) -> None:
"""Raise RemoteQueryError if *sql* contains blocked patterns.
Raises:
RemoteQueryError: with error_type="query_error" if validation fails.
"""
sql_lower = sql.strip().lower()
for keyword in _BLOCKED_KEYWORDS:
if keyword in sql_lower:
raise RemoteQueryError(
f"Blocked SQL pattern: {keyword!r}",
error_type="query_error",
details={"blocked_keyword": keyword},
)
if not sql_lower.startswith("select ") and not sql_lower.startswith("with "):
raise RemoteQueryError(
"Query must start with SELECT or WITH",
error_type="query_error",
)
def load_config() -> Dict[str, Any]:
"""Load the ``remote_query:`` section from instance.yaml.
Returns an empty dict if the section is missing or config cannot be loaded.
"""
try:
from app.instance_config import get_value
return get_value("remote_query", default={}) or {}
except Exception:
return {}
# ---------------------------------------------------------------------------
# Engine
# ---------------------------------------------------------------------------
class RemoteQueryEngine:
"""Two-phase query engine: BQ registration (Phase 1) + DuckDB execution (Phase 2).
Args:
conn: Open DuckDB connection used for both view registration and querying.
_bq_client_factory: Optional callable ``(project: str) -> BQ client``.
Defaults to ``scripts.duckdb_manager._create_bq_client``.
max_bq_registration_rows: Maximum rows allowed in a single BQ registration.
max_memory_mb: Maximum in-memory Arrow table size (MiB).
max_result_rows: Maximum rows returned by ``execute()``.
timeout_seconds: Query timeout (reserved for future use).
"""
def __init__(
self,
conn: duckdb.DuckDBPyConnection,
*,
_bq_client_factory=None,
max_bq_registration_rows: int = 500_000,
max_memory_mb: float = 2048.0,
max_result_rows: int = 100_000,
timeout_seconds: int = 300,
) -> None:
self._conn = conn
self._bq_client_factory = _bq_client_factory
self.max_bq_registration_rows = max_bq_registration_rows
self.max_memory_mb = max_memory_mb
self.max_result_rows = max_result_rows
self.timeout_seconds = timeout_seconds
# Track which aliases have been registered in this session
self._registered: Dict[str, Dict[str, Any]] = {}
# ------------------------------------------------------------------
# Phase 1
# ------------------------------------------------------------------
def register_bq(self, alias: str, bq_sql: str) -> Dict[str, Any]:
"""Register a BigQuery query result as a DuckDB view.
Steps:
1. Validate *bq_sql* against the SQL blocklist.
2. COUNT(*) pre-check via BQ client.
3. Execute the actual BQ query and fetch as Arrow table.
4. Check in-memory size against *max_memory_mb*.
5. Register Arrow table in DuckDB under *alias*.
Args:
alias: DuckDB view name to register (e.g. ``"bq_orders"``).
bq_sql: SQL query to execute on BigQuery.
Returns:
``{alias, rows, columns, memory_mb}``
Raises:
RemoteQueryError: For row/memory limits or BQ errors.
ImportError: If google-cloud-bigquery is not installed.
"""
_validate_sql(bq_sql)
client = self._get_bq_client()
# --- Phase 1a: COUNT(*) pre-check ---
count_sql = f"SELECT COUNT(*) FROM ({bq_sql})"
try:
count_job = client.query(count_sql)
count_arrow = count_job.to_arrow()
count_value = int(count_arrow.column(0)[0].as_py())
except RemoteQueryError:
raise
except Exception as exc:
raise RemoteQueryError(
f"BQ COUNT pre-check failed: {exc}",
error_type="bq_error",
details={"original_error": str(exc)},
) from exc
if count_value > self.max_bq_registration_rows:
raise RemoteQueryError(
f"BQ result has {count_value:,} rows, exceeding the "
f"limit of {self.max_bq_registration_rows:,}.",
error_type="row_limit",
details={
"count": count_value,
"max": self.max_bq_registration_rows,
},
)
# --- Phase 1b: Fetch actual data ---
try:
data_job = client.query(bq_sql)
try:
arrow_table = data_job.to_arrow()
except Exception as storage_exc:
if "readsessions" in str(storage_exc) or "PERMISSION_DENIED" in str(storage_exc):
logger.warning("BQ Storage API unavailable, falling back to REST")
arrow_table = data_job.to_arrow(create_bqstorage_client=False)
else:
raise
except RemoteQueryError:
raise
except Exception as exc:
raise RemoteQueryError(
f"BQ query failed: {exc}",
error_type="bq_error",
details={"original_error": str(exc)},
) from exc
# --- Phase 1c: Memory check (accurate, post-fetch) ---
memory_mb = arrow_table.nbytes / (1024 * 1024)
if memory_mb > self.max_memory_mb:
raise RemoteQueryError(
f"Arrow table uses {memory_mb:.1f} MiB, exceeding the "
f"limit of {self.max_memory_mb:.1f} MiB.",
error_type="memory_limit",
details={"memory_mb": memory_mb, "max_memory_mb": self.max_memory_mb},
)
# --- Phase 1d: Register in DuckDB ---
self._conn.register(alias, arrow_table)
info: Dict[str, Any] = {
"alias": alias,
"rows": arrow_table.num_rows,
"columns": arrow_table.schema.names,
"memory_mb": memory_mb,
}
self._registered[alias] = info
logger.info(
"Registered BQ alias %r: %d rows, %.2f MiB",
alias,
arrow_table.num_rows,
memory_mb,
)
return info
# ------------------------------------------------------------------
# Phase 2
# ------------------------------------------------------------------
def execute(self, sql: str) -> Dict[str, Any]:
"""Execute SQL against DuckDB (which may reference registered BQ views).
Args:
sql: SQL query to execute. Must pass the SQL blocklist.
Returns:
``{columns, rows, row_count, truncated, bq_stats}``
Raises:
RemoteQueryError: If SQL is blocked or a DuckDB error occurs.
"""
_validate_sql(sql)
try:
result = self._conn.execute(sql).fetchmany(self.max_result_rows + 1)
columns = (
[desc[0] for desc in self._conn.description]
if self._conn.description
else []
)
except RemoteQueryError:
raise
except Exception as exc:
raise RemoteQueryError(
f"Query error: {exc}",
error_type="query_error",
details={"original_error": str(exc)},
) from exc
truncated = len(result) > self.max_result_rows
rows = result[: self.max_result_rows]
# Serialize non-standard types (mirrors app/api/query.py lines 92-96)
serializable_rows = []
for row in rows:
serializable_rows.append(
[
str(v) if v is not None and not isinstance(v, (int, float, bool, str)) else v
for v in row
]
)
return {
"columns": columns,
"rows": serializable_rows,
"row_count": len(serializable_rows),
"truncated": truncated,
"bq_stats": {
"registered_aliases": list(self._registered.keys()),
"alias_count": len(self._registered),
},
}
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _get_bq_client(self):
"""Return a BigQuery client from the injected factory or the default one.
Raises:
ImportError: If google-cloud-bigquery is not installed and no
factory was injected.
"""
if self._bq_client_factory is not None:
project = os.environ.get("BIGQUERY_PROJECT", "unknown")
return self._bq_client_factory(project)
# Trigger ImportError early if the package is missing.
# This is a lazy import so the module stays usable without BQ installed.
import google.cloud.bigquery as _bq_module # noqa: PLC0415, F401
project = os.environ.get("BIGQUERY_PROJECT")
if not project:
raise RemoteQueryError(
"BIGQUERY_PROJECT env var is not set.",
error_type="bq_error",
)
return _bq_module.Client(project=project)

224
tests/test_remote_query.py Normal file
View file

@ -0,0 +1,224 @@
"""Tests for RemoteQueryEngine — two-phase BQ registration + DuckDB execution."""
import sys
from datetime import date
from decimal import Decimal
from unittest.mock import MagicMock, patch
import duckdb
import pyarrow as pa
import pytest
from src.remote_query import RemoteQueryEngine, RemoteQueryError, _validate_sql
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def analytics_conn():
conn = duckdb.connect()
conn.execute("CREATE TABLE orders (id INT, date DATE, amount DECIMAL(10,2))")
conn.execute(
"INSERT INTO orders VALUES (1, '2026-01-01', 100.0), (2, '2026-01-15', 200.0)"
)
yield conn
conn.close()
def _make_bq_mock(arrow_table, count_value=None):
"""Build a minimal BQ client mock.
First call to client.query() returns a count job, second returns a data job.
If count_value is None, infer it from arrow_table.num_rows.
"""
if count_value is None:
count_value = arrow_table.num_rows
count_arrow = pa.table({"count": pa.array([count_value], type=pa.int64())})
count_job = MagicMock()
count_job.to_arrow.return_value = count_arrow
data_job = MagicMock()
data_job.to_arrow.return_value = arrow_table
mock_client = MagicMock()
mock_client.query.side_effect = [count_job, data_job]
return mock_client
# ---------------------------------------------------------------------------
# TestRemoteQueryEngineRegister
# ---------------------------------------------------------------------------
class TestRemoteQueryEngineRegister:
def test_register_bq_success(self, analytics_conn):
"""Mock BQ client returning an Arrow table; verify view is queryable."""
arrow_table = pa.table(
{
"order_id": pa.array([10, 20, 30], type=pa.int64()),
"revenue": pa.array([1.0, 2.0, 3.0], type=pa.float64()),
}
)
mock_client = _make_bq_mock(arrow_table)
engine = RemoteQueryEngine(
analytics_conn,
_bq_client_factory=lambda project: mock_client,
max_bq_registration_rows=500_000,
)
result = engine.register_bq("bq_orders", "SELECT order_id, revenue FROM bq.orders")
assert result["alias"] == "bq_orders"
assert result["rows"] == 3
assert result["columns"] == ["order_id", "revenue"]
assert result["memory_mb"] > 0
# The alias must be queryable from DuckDB
rows = analytics_conn.execute("SELECT COUNT(*) FROM bq_orders").fetchone()
assert rows[0] == 3
def test_register_bq_row_limit_exceeded(self, analytics_conn):
"""COUNT pre-check returns a value exceeding the row limit → RemoteQueryError."""
arrow_table = pa.table({"x": pa.array([1], type=pa.int64())})
# count exceeds limit
mock_client = _make_bq_mock(arrow_table, count_value=1_000_000)
engine = RemoteQueryEngine(
analytics_conn,
_bq_client_factory=lambda project: mock_client,
max_bq_registration_rows=500_000,
)
with pytest.raises(RemoteQueryError) as exc_info:
engine.register_bq("bq_big", "SELECT * FROM bq.huge_table")
assert exc_info.value.error_type == "row_limit"
assert exc_info.value.details["count"] == 1_000_000
def test_register_bq_missing_package(self, analytics_conn):
"""When google-cloud-bigquery is not installed, engine must raise ImportError."""
engine = RemoteQueryEngine(
analytics_conn,
# No factory — will try to import google.cloud.bigquery
_bq_client_factory=None,
max_bq_registration_rows=500_000,
)
with patch.dict(sys.modules, {"google": None, "google.cloud": None, "google.cloud.bigquery": None}):
with pytest.raises((ImportError, ModuleNotFoundError)):
engine.register_bq("bq_alias", "SELECT 1")
# ---------------------------------------------------------------------------
# TestRemoteQueryEngineExecute
# ---------------------------------------------------------------------------
class TestRemoteQueryEngineExecute:
def test_execute_local_only(self, analytics_conn):
"""Query local table; result dict has correct structure."""
engine = RemoteQueryEngine(analytics_conn)
result = engine.execute("SELECT id, amount FROM orders ORDER BY id")
assert result["columns"] == ["id", "amount"]
assert result["row_count"] == 2
assert result["truncated"] is False
assert len(result["rows"]) == 2
# Non-standard types (Decimal) must be serialized to str
for row in result["rows"]:
for val in row:
assert isinstance(val, (int, float, bool, str, type(None)))
def test_execute_with_registered_bq(self, analytics_conn):
"""Manually register an Arrow table, then JOIN it with local orders."""
bq_arrow = pa.table(
{
"id": pa.array([1, 2], type=pa.int64()),
"label": pa.array(["first", "second"], type=pa.utf8()),
}
)
mock_client = _make_bq_mock(bq_arrow)
engine = RemoteQueryEngine(
analytics_conn,
_bq_client_factory=lambda project: mock_client,
max_bq_registration_rows=500_000,
)
engine.register_bq("bq_labels", "SELECT id, label FROM bq.labels")
result = engine.execute(
"SELECT o.id, o.amount, b.label "
"FROM orders o JOIN bq_labels b ON o.id = b.id "
"ORDER BY o.id"
)
assert result["row_count"] == 2
assert "label" in result["columns"]
def test_execute_respects_max_result_rows(self, analytics_conn):
"""When max_result_rows=1, result is truncated after 1 row."""
engine = RemoteQueryEngine(analytics_conn, max_result_rows=1)
result = engine.execute("SELECT id FROM orders ORDER BY id")
assert result["row_count"] == 1
assert result["truncated"] is True
def test_execute_invalid_sql(self, analytics_conn):
"""DROP TABLE must be rejected with RemoteQueryError(error_type='query_error')."""
engine = RemoteQueryEngine(analytics_conn)
with pytest.raises(RemoteQueryError) as exc_info:
engine.execute("DROP TABLE orders")
assert exc_info.value.error_type == "query_error"
# ---------------------------------------------------------------------------
# _validate_sql unit tests
# ---------------------------------------------------------------------------
class TestValidateSql:
@pytest.mark.parametrize(
"sql",
[
"DROP TABLE foo",
"DELETE FROM foo",
"INSERT INTO foo VALUES (1)",
"UPDATE foo SET x=1",
"ALTER TABLE foo ADD COLUMN y INT",
"CREATE TABLE foo (x INT)",
"COPY foo TO '/tmp/out.csv'",
"ATTACH '/db.duckdb'",
"DETACH db",
"LOAD 'extension'",
"INSTALL httpfs",
"SELECT read_parquet('/data/file.parquet')",
"SELECT * FROM '../secret/file'",
"SELECT 1; DROP TABLE foo",
],
)
def test_blocked_sql(self, sql):
with pytest.raises(RemoteQueryError) as exc_info:
_validate_sql(sql)
assert exc_info.value.error_type == "query_error"
@pytest.mark.parametrize(
"sql",
[
"SELECT id FROM orders",
"WITH cte AS (SELECT 1 AS x) SELECT x FROM cte",
"select count(*) from orders",
"with t as (select 1) select * from t",
],
)
def test_allowed_sql(self, sql):
# Should not raise
_validate_sql(sql)