# File: src/remote_query.py
"""RemoteQueryEngine — two-phase BQ registration + DuckDB execution.

Phase 1 (register_bq): validate SQL, COUNT(*) pre-check against BigQuery,
fetch Arrow table, check row count and memory, register as DuckDB view.

Phase 2 (execute): validate SQL, execute against DuckDB (which may reference
registered BQ views), serialize and return results.
"""

from __future__ import annotations

import logging
import os
import re
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

if TYPE_CHECKING:
    # Only referenced by the ``conn`` annotation, and annotations are lazy in
    # this module, so duckdb is not needed just to import the engine.
    import duckdb

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# SQL blocklist — mirrors app/api/query.py lines 40-63
# ---------------------------------------------------------------------------

# Entries ending in a space are statement keywords and are matched as whole
# words (see ``_keyword_pattern``); every other entry is a plain substring.
_BLOCKED_KEYWORDS: List[str] = [
    "drop ",
    "delete ",
    "insert ",
    "update ",
    "alter ",
    "create ",
    "copy ",
    "attach ",
    "detach ",
    "load ",
    "install ",
    "export ",
    "import ",
    "pragma ",
    "call ",
    # File access functions
    "read_csv",
    "read_json",
    "read_parquet",
    "read_text",
    "write_csv",
    "write_parquet",
    "read_blob",
    "read_ndjson",
    "parquet_scan",
    "parquet_metadata",
    "parquet_schema",
    "json_scan",
    "csv_scan",
    "query_table",
    "iceberg_scan",
    "delta_scan",
    "glob(",
    "list_files",
    "'/",
    '"/',
    "http://",
    "https://",
    "s3://",
    "gcs://",
    # DuckDB metadata (leaks schema info regardless of RBAC)
    "information_schema",
    "duckdb_tables",
    "duckdb_columns",
    "duckdb_databases",
    "duckdb_settings",
    "duckdb_functions",
    "duckdb_views",
    "duckdb_indexes",
    "duckdb_schemas",
    "pragma_table_info",
    "pragma_storage_info",
    # Relative path traversal
    "'../",
    '"../',
    # Multiple statements
    ";",
]

# A statement keyword only counts when it stands alone as a word: not glued
# to identifier characters (``last_update``) and not directly after an
# opening quote (``'update'``, ``"update"``).
_KEYWORD_START = r"""(?<![\w'"])"""
_KEYWORD_END = r"(?![\w'])"

# Accept any token boundary after the leading keyword, so ``SELECT\n  id``
# is treated the same as ``SELECT id``.
_ALLOWED_PREFIX = re.compile(r"(?:select|with)(?!\w)")

# Values of these types are returned by ``execute()`` unchanged; everything
# else (except None) is converted with ``str()``.
_PASSTHROUGH_TYPES = (int, float, bool, str)


# ---------------------------------------------------------------------------
# Exception
# ---------------------------------------------------------------------------


class RemoteQueryError(Exception):
    """Raised by RemoteQueryEngine for all controlled error conditions.

    Attributes:
        error_type: One of "row_limit", "memory_limit", "bq_error",
            "query_error", "timeout".
        details: Dict with additional context (empty when none was given).
    """

    def __init__(
        self,
        message: str,
        error_type: str,
        details: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(message)
        self.error_type = error_type
        self.details = details or {}


# ---------------------------------------------------------------------------
# Module-level helpers
# ---------------------------------------------------------------------------


@lru_cache(maxsize=None)
def _keyword_pattern(keyword: str) -> Optional[re.Pattern[str]]:
    """Return the whole-word regex for a statement keyword, else ``None``.

    Blocklist entries that end with a space are statement keywords. Matching
    them as ``"delete "`` substrings is both too loose (``DELETE\\nFROM`` and
    ``DELETE/**/FROM`` slip through) and too strict (``last_update `` is
    rejected), so they are compiled into a word-boundary pattern instead.
    Entries without a trailing space return ``None`` and stay substrings.
    """
    if not keyword.endswith(" "):
        return None
    return re.compile(_KEYWORD_START + re.escape(keyword.rstrip()) + _KEYWORD_END)


def _validate_sql(sql: str) -> None:
    """Raise RemoteQueryError if *sql* contains blocked patterns.

    The check is case-insensitive. Blocklist entries are tested in list
    order and the first hit is reported; only then is the statement required
    to begin with SELECT or WITH.

    Raises:
        RemoteQueryError: with error_type="query_error" if validation fails.
            For blocklist hits, ``details["blocked_keyword"]`` holds the
            matching ``_BLOCKED_KEYWORDS`` entry.
    """
    sql_lower = sql.strip().lower()

    for keyword in _BLOCKED_KEYWORDS:
        pattern = _keyword_pattern(keyword)
        if pattern is None:
            hit = keyword in sql_lower
        else:
            hit = pattern.search(sql_lower) is not None
        if hit:
            raise RemoteQueryError(
                f"Blocked SQL pattern: {keyword!r}",
                error_type="query_error",
                details={"blocked_keyword": keyword},
            )

    if _ALLOWED_PREFIX.match(sql_lower) is None:
        raise RemoteQueryError(
            "Query must start with SELECT or WITH",
            error_type="query_error",
        )


def load_config() -> Dict[str, Any]:
    """Load the ``remote_query:`` section from instance.yaml.

    Returns an empty dict if the section is missing or config cannot be loaded.
    """
    try:
        from app.instance_config import get_value

        return get_value("remote_query", default={}) or {}
    except Exception:
        # Deliberately best-effort: callers fall back to built-in defaults.
        # Leave a trace so a broken config is not completely invisible.
        logger.debug(
            "remote_query config unavailable; using defaults", exc_info=True
        )
        return {}


# ---------------------------------------------------------------------------
# Engine
# ---------------------------------------------------------------------------


class RemoteQueryEngine:
    """Two-phase query engine: BQ registration (Phase 1) + DuckDB execution (Phase 2).

    Args:
        conn: Open DuckDB connection used for both view registration and querying.
        _bq_client_factory: Optional callable ``(project: str) -> BQ client``.
            When omitted, a ``google.cloud.bigquery.Client`` is created for
            the project named by the ``BIGQUERY_PROJECT`` env var.
        max_bq_registration_rows: Maximum rows allowed in a single BQ registration.
        max_memory_mb: Maximum in-memory Arrow table size (MiB).
        max_result_rows: Maximum rows returned by ``execute()``.
        timeout_seconds: Query timeout (reserved for future use).
    """

    def __init__(
        self,
        conn: duckdb.DuckDBPyConnection,
        *,
        _bq_client_factory: Optional[Callable[[str], Any]] = None,
        max_bq_registration_rows: int = 500_000,
        max_memory_mb: float = 2048.0,
        max_result_rows: int = 100_000,
        timeout_seconds: int = 300,
    ) -> None:
        self._conn = conn
        self._bq_client_factory = _bq_client_factory
        self.max_bq_registration_rows = max_bq_registration_rows
        self.max_memory_mb = max_memory_mb
        self.max_result_rows = max_result_rows
        self.timeout_seconds = timeout_seconds

        # Track which aliases have been registered in this session
        self._registered: Dict[str, Dict[str, Any]] = {}

    # ------------------------------------------------------------------
    # Phase 1
    # ------------------------------------------------------------------

    def register_bq(self, alias: str, bq_sql: str) -> Dict[str, Any]:
        """Register a BigQuery query result as a DuckDB view.

        Steps:
            1. Validate *bq_sql* against the SQL blocklist.
            2. COUNT(*) pre-check via BQ client.
            3. Execute the actual BQ query and fetch as Arrow table.
            4. Re-check the fetched row count, then the in-memory size
               against *max_memory_mb*.
            5. Register Arrow table in DuckDB under *alias*.

        Args:
            alias: DuckDB view name to register (e.g. ``"bq_orders"``). It is
                handed to ``conn.register`` unchanged and is not validated
                here.
            bq_sql: SQL query to execute on BigQuery.

        Returns:
            ``{alias, rows, columns, memory_mb}``

        Raises:
            RemoteQueryError: For blocked SQL, row/memory limits or BQ errors.
            ImportError: If google-cloud-bigquery is not installed.
        """
        _validate_sql(bq_sql)

        client = self._get_bq_client()

        # --- Phase 1a: COUNT(*) pre-check ---
        # The newlines keep the closing paren alive when *bq_sql* ends with
        # a ``--`` line comment.
        count_sql = f"SELECT COUNT(*) FROM (\n{bq_sql}\n)"
        try:
            count_arrow = client.query(count_sql).to_arrow()
            count_value = int(count_arrow.column(0)[0].as_py())
        except RemoteQueryError:
            raise
        except Exception as exc:
            raise RemoteQueryError(
                f"BQ COUNT pre-check failed: {exc}",
                error_type="bq_error",
                details={"original_error": str(exc)},
            ) from exc

        self._check_row_limit(count_value)

        # --- Phase 1b: Fetch actual data ---
        try:
            data_job = client.query(bq_sql)
            try:
                arrow_table = data_job.to_arrow()
            except Exception as storage_exc:
                reason = str(storage_exc)
                if "readsessions" in reason or "PERMISSION_DENIED" in reason:
                    logger.warning("BQ Storage API unavailable, falling back to REST")
                    arrow_table = data_job.to_arrow(create_bqstorage_client=False)
                else:
                    raise
        except RemoteQueryError:
            raise
        except Exception as exc:
            raise RemoteQueryError(
                f"BQ query failed: {exc}",
                error_type="bq_error",
                details={"original_error": str(exc)},
            ) from exc

        # The COUNT and the fetch are separate queries, so the result can
        # have grown in between; enforce the limit on what was really fetched.
        self._check_row_limit(arrow_table.num_rows)

        # --- Phase 1c: Memory check (accurate, post-fetch) ---
        memory_mb = arrow_table.nbytes / (1024 * 1024)
        if memory_mb > self.max_memory_mb:
            raise RemoteQueryError(
                f"Arrow table uses {memory_mb:.1f} MiB, exceeding the "
                f"limit of {self.max_memory_mb:.1f} MiB.",
                error_type="memory_limit",
                details={"memory_mb": memory_mb, "max_memory_mb": self.max_memory_mb},
            )

        # --- Phase 1d: Register in DuckDB ---
        self._conn.register(alias, arrow_table)

        info: Dict[str, Any] = {
            "alias": alias,
            "rows": arrow_table.num_rows,
            "columns": arrow_table.schema.names,
            "memory_mb": memory_mb,
        }
        self._registered[alias] = info
        logger.info(
            "Registered BQ alias %r: %d rows, %.2f MiB",
            alias,
            arrow_table.num_rows,
            memory_mb,
        )
        return info

    # ------------------------------------------------------------------
    # Phase 2
    # ------------------------------------------------------------------

    def execute(self, sql: str) -> Dict[str, Any]:
        """Execute SQL against DuckDB (which may reference registered BQ views).

        At most ``max_result_rows`` rows are returned; one extra row is
        fetched only to detect truncation.

        Args:
            sql: SQL query to execute. Must pass the SQL blocklist.

        Returns:
            ``{columns, rows, row_count, truncated, bq_stats}``

        Raises:
            RemoteQueryError: If SQL is blocked or a DuckDB error occurs.
        """
        _validate_sql(sql)

        try:
            result = self._conn.execute(sql).fetchmany(self.max_result_rows + 1)
            description = self._conn.description
            columns = [desc[0] for desc in description] if description else []
        except RemoteQueryError:
            raise
        except Exception as exc:
            raise RemoteQueryError(
                f"Query error: {exc}",
                error_type="query_error",
                details={"original_error": str(exc)},
            ) from exc

        truncated = len(result) > self.max_result_rows
        rows = result[: self.max_result_rows]

        # Serialize non-standard types (mirrors app/api/query.py lines 92-96)
        serializable_rows = [
            [
                v if v is None or isinstance(v, _PASSTHROUGH_TYPES) else str(v)
                for v in row
            ]
            for row in rows
        ]

        return {
            "columns": columns,
            "rows": serializable_rows,
            "row_count": len(serializable_rows),
            "truncated": truncated,
            "bq_stats": {
                "registered_aliases": list(self._registered.keys()),
                "alias_count": len(self._registered),
            },
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _check_row_limit(self, count: int) -> None:
        """Raise a ``row_limit`` error if *count* exceeds the registration cap."""
        if count > self.max_bq_registration_rows:
            raise RemoteQueryError(
                f"BQ result has {count:,} rows, exceeding the "
                f"limit of {self.max_bq_registration_rows:,}.",
                error_type="row_limit",
                details={
                    "count": count,
                    "max": self.max_bq_registration_rows,
                },
            )

    def _get_bq_client(self):
        """Return a BigQuery client from the injected factory or the default one.

        The injected factory is called with the ``BIGQUERY_PROJECT`` env var,
        or ``"unknown"`` when it is unset. Without a factory the env var is
        mandatory.

        Raises:
            ImportError: If google-cloud-bigquery is not installed and no
                factory was injected.
            RemoteQueryError: With error_type="bq_error" if no factory was
                injected and ``BIGQUERY_PROJECT`` is not set.
        """
        if self._bq_client_factory is not None:
            project = os.environ.get("BIGQUERY_PROJECT", "unknown")
            return self._bq_client_factory(project)

        # Trigger ImportError early if the package is missing.
        # This is a lazy import so the module stays usable without BQ installed.
        import google.cloud.bigquery as _bq_module  # noqa: PLC0415, F401

        project = os.environ.get("BIGQUERY_PROJECT")
        if not project:
            raise RemoteQueryError(
                "BIGQUERY_PROJECT env var is not set.",
                error_type="bq_error",
            )
        return _bq_module.Client(project=project)
# File: tests/test_remote_query.py
"""Tests for RemoteQueryEngine — two-phase BQ registration + DuckDB execution."""

import sys
from datetime import date
from decimal import Decimal
from unittest.mock import MagicMock, patch

import duckdb
import pyarrow as pa
import pytest

from src.remote_query import RemoteQueryEngine, RemoteQueryError, _validate_sql


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def analytics_conn():
    """Yield an in-memory DuckDB connection holding a two-row ``orders`` table."""
    conn = duckdb.connect()
    conn.execute("CREATE TABLE orders (id INT, date DATE, amount DECIMAL(10,2))")
    conn.execute(
        "INSERT INTO orders VALUES (1, '2026-01-01', 100.0), (2, '2026-01-15', 200.0)"
    )
    yield conn
    conn.close()


def _make_bq_mock(arrow_table, count_value=None):
    """Build a minimal BQ client mock.

    First call to client.query() returns a count job, second returns a data job.
    If count_value is None, infer it from arrow_table.num_rows.
    """
    if count_value is None:
        count_value = arrow_table.num_rows

    count_arrow = pa.table({"count": pa.array([count_value], type=pa.int64())})

    count_job = MagicMock()
    count_job.to_arrow.return_value = count_arrow

    data_job = MagicMock()
    data_job.to_arrow.return_value = arrow_table

    mock_client = MagicMock()
    # Order matters: register_bq issues the COUNT query before the data query.
    mock_client.query.side_effect = [count_job, data_job]

    return mock_client


# ---------------------------------------------------------------------------
# TestRemoteQueryEngineRegister
# ---------------------------------------------------------------------------


class TestRemoteQueryEngineRegister:
    def test_register_bq_success(self, analytics_conn):
        """Mock BQ client returning an Arrow table; verify view is queryable."""
        arrow_table = pa.table(
            {
                "order_id": pa.array([10, 20, 30], type=pa.int64()),
                "revenue": pa.array([1.0, 2.0, 3.0], type=pa.float64()),
            }
        )
        mock_client = _make_bq_mock(arrow_table)

        engine = RemoteQueryEngine(
            analytics_conn,
            _bq_client_factory=lambda project: mock_client,
            max_bq_registration_rows=500_000,
        )

        result = engine.register_bq("bq_orders", "SELECT order_id, revenue FROM bq.orders")

        assert result["alias"] == "bq_orders"
        assert result["rows"] == 3
        assert result["columns"] == ["order_id", "revenue"]
        assert result["memory_mb"] > 0
        # COUNT pre-check plus the data fetch: exactly two BQ queries.
        assert mock_client.query.call_count == 2

        # The alias must be queryable from DuckDB
        rows = analytics_conn.execute("SELECT COUNT(*) FROM bq_orders").fetchone()
        assert rows[0] == 3

    def test_register_bq_row_limit_exceeded(self, analytics_conn):
        """COUNT pre-check returns a value exceeding the row limit → RemoteQueryError."""
        arrow_table = pa.table({"x": pa.array([1], type=pa.int64())})
        # count exceeds limit
        mock_client = _make_bq_mock(arrow_table, count_value=1_000_000)

        engine = RemoteQueryEngine(
            analytics_conn,
            _bq_client_factory=lambda project: mock_client,
            max_bq_registration_rows=500_000,
        )

        with pytest.raises(RemoteQueryError) as exc_info:
            engine.register_bq("bq_big", "SELECT * FROM bq.huge_table")

        assert exc_info.value.error_type == "row_limit"
        assert exc_info.value.details["count"] == 1_000_000
        # The whole point of the pre-check: the data query is never issued.
        assert mock_client.query.call_count == 1

    def test_register_bq_missing_package(self, analytics_conn):
        """When google-cloud-bigquery is not installed, engine must raise ImportError."""
        engine = RemoteQueryEngine(
            analytics_conn,
            # No factory — will try to import google.cloud.bigquery
            _bq_client_factory=None,
            max_bq_registration_rows=500_000,
        )

        # A ``None`` entry in sys.modules makes the import statement fail.
        missing = {"google": None, "google.cloud": None, "google.cloud.bigquery": None}
        with patch.dict(sys.modules, missing):
            with pytest.raises((ImportError, ModuleNotFoundError)):
                engine.register_bq("bq_alias", "SELECT 1")


# ---------------------------------------------------------------------------
# TestRemoteQueryEngineExecute
# ---------------------------------------------------------------------------


class TestRemoteQueryEngineExecute:
    def test_execute_local_only(self, analytics_conn):
        """Query local table; result dict has correct structure."""
        engine = RemoteQueryEngine(analytics_conn)
        result = engine.execute("SELECT id, amount FROM orders ORDER BY id")

        assert result["columns"] == ["id", "amount"]
        assert result["row_count"] == 2
        assert result["truncated"] is False
        assert len(result["rows"]) == 2
        # Non-standard types (Decimal) must be serialized to str
        for row in result["rows"]:
            for val in row:
                assert isinstance(val, (int, float, bool, str, type(None)))
        # No BQ alias was registered on this engine.
        assert result["bq_stats"] == {"registered_aliases": [], "alias_count": 0}

    def test_execute_serializes_non_primitive_values(self, analytics_conn):
        """DATE and DECIMAL values come back as their ``str()`` form."""
        engine = RemoteQueryEngine(analytics_conn)
        result = engine.execute("SELECT id, date, amount FROM orders ORDER BY id")

        assert result["rows"] == [
            [1, str(date(2026, 1, 1)), str(Decimal("100.00"))],
            [2, str(date(2026, 1, 15)), str(Decimal("200.00"))],
        ]

    def test_execute_with_registered_bq(self, analytics_conn):
        """Register an Arrow table via register_bq, then JOIN it with local orders."""
        bq_arrow = pa.table(
            {
                "id": pa.array([1, 2], type=pa.int64()),
                "label": pa.array(["first", "second"], type=pa.utf8()),
            }
        )
        mock_client = _make_bq_mock(bq_arrow)

        engine = RemoteQueryEngine(
            analytics_conn,
            _bq_client_factory=lambda project: mock_client,
            max_bq_registration_rows=500_000,
        )
        engine.register_bq("bq_labels", "SELECT id, label FROM bq.labels")

        result = engine.execute(
            "SELECT o.id, o.amount, b.label "
            "FROM orders o JOIN bq_labels b ON o.id = b.id "
            "ORDER BY o.id"
        )

        assert result["row_count"] == 2
        assert "label" in result["columns"]
        assert [row[2] for row in result["rows"]] == ["first", "second"]
        assert result["bq_stats"] == {
            "registered_aliases": ["bq_labels"],
            "alias_count": 1,
        }

    def test_execute_respects_max_result_rows(self, analytics_conn):
        """When max_result_rows=1, result is truncated after 1 row."""
        engine = RemoteQueryEngine(analytics_conn, max_result_rows=1)
        result = engine.execute("SELECT id FROM orders ORDER BY id")

        assert result["row_count"] == 1
        assert result["truncated"] is True
        assert result["rows"] == [[1]]

    def test_execute_invalid_sql(self, analytics_conn):
        """DROP TABLE must be rejected with RemoteQueryError(error_type='query_error')."""
        engine = RemoteQueryEngine(analytics_conn)

        with pytest.raises(RemoteQueryError) as exc_info:
            engine.execute("DROP TABLE orders")

        assert exc_info.value.error_type == "query_error"
        # Rejection happens before DuckDB is touched: the table still exists.
        assert analytics_conn.execute("SELECT COUNT(*) FROM orders").fetchone()[0] == 2


# ---------------------------------------------------------------------------
# _validate_sql unit tests
# ---------------------------------------------------------------------------


class TestValidateSql:
    @pytest.mark.parametrize(
        "sql",
        [
            "DROP TABLE foo",
            "DELETE FROM foo",
            "INSERT INTO foo VALUES (1)",
            "UPDATE foo SET x=1",
            "ALTER TABLE foo ADD COLUMN y INT",
            "CREATE TABLE foo (x INT)",
            "COPY foo TO '/tmp/out.csv'",
            "ATTACH '/db.duckdb'",
            "DETACH db",
            "LOAD 'extension'",
            "INSTALL httpfs",
            "SELECT read_parquet('/data/file.parquet')",
            "SELECT * FROM '../secret/file'",
            "SELECT 1; DROP TABLE foo",
        ],
    )
    def test_blocked_sql(self, sql):
        with pytest.raises(RemoteQueryError) as exc_info:
            _validate_sql(sql)
        assert exc_info.value.error_type == "query_error"

    @pytest.mark.parametrize(
        "sql",
        [
            "SELECT id FROM orders",
            "WITH cte AS (SELECT 1 AS x) SELECT x FROM cte",
            "select count(*) from orders",
            "with t as (select 1) select * from t",
        ],
    )
    def test_allowed_sql(self, sql):
        # Should not raise
        _validate_sql(sql)