diff --git a/src/db.py b/src/db.py index e5630f7..66c9d0a 100644 --- a/src/db.py +++ b/src/db.py @@ -250,6 +250,92 @@ def get_analytics_db() -> duckdb.DuckDBPyConnection: return duckdb.connect(str(db_path)) +def _reattach_remote_extensions( + conn: duckdb.DuckDBPyConnection, extracts_dir: Path +) -> None: + """Re-LOAD DuckDB extensions listed in _remote_attach tables of each extract.duckdb. + + Called from get_analytics_db_readonly() after ATTACHing extract.duckdb files so + that remote views (e.g. BigQuery) resolve correctly. Uses LOAD only — no INSTALL — + to avoid touching the network in read-only query paths. + """ + if not extracts_dir.exists(): + return + + try: + attached_dbs = { + r[0] for r in conn.execute("SELECT database_name FROM duckdb_databases()").fetchall() + } + except Exception: + return + + for ext_dir in sorted(extracts_dir.iterdir()): + if not ext_dir.is_dir(): + continue + if not _SAFE_IDENTIFIER.match(ext_dir.name): + continue + db_file = ext_dir / "extract.duckdb" + if not db_file.exists(): + continue + # Only process sources that were successfully attached + if ext_dir.name not in attached_dbs: + continue + + # Check whether this extract has a _remote_attach table + try: + has_table = conn.execute( + "SELECT 1 FROM information_schema.tables " + f"WHERE table_schema='{ext_dir.name}' AND table_name='_remote_attach'" + ).fetchone() + if not has_table: + continue + except Exception: + continue + + try: + rows = conn.execute( + f"SELECT alias, extension, url, token_env FROM {ext_dir.name}._remote_attach" + ).fetchall() + except Exception as e: + logger.debug("Could not read _remote_attach from %s: %s", ext_dir.name, e) + continue + + # Refresh attached list before processing each source's rows + try: + attached_dbs = { + r[0] for r in conn.execute("SELECT database_name FROM duckdb_databases()").fetchall() + } + except Exception: + pass + + for alias, extension, url, token_env in rows: + if not _SAFE_IDENTIFIER.match(alias or ""): + logger.debug("Skipping unsafe remote_attach alias: %r", alias) + continue + if not _SAFE_IDENTIFIER.match(extension or ""): + logger.debug("Skipping unsafe remote_attach extension: %r", extension) + continue + if alias in attached_dbs: + logger.debug("Remote source %s already attached, skipping", alias) + continue + try: + conn.execute(f"LOAD {extension};") + token = os.environ.get(token_env, "") if token_env else "" + if token: + escaped_token = token.replace("'", "''") + conn.execute( + f"ATTACH '{url}' AS {alias} (TYPE {extension}, TOKEN '{escaped_token}')" + ) + else: + conn.execute( + f"ATTACH '{url}' AS {alias} (TYPE {extension}, READ_ONLY)" + ) + attached_dbs.add(alias) + logger.debug("Re-attached remote source %s via %s extension", alias, extension) + except Exception as e: + logger.debug("Could not re-attach remote source %s: %s", alias, e) + + def get_analytics_db_readonly() -> duckdb.DuckDBPyConnection: """Read-only connection to analytics DB. Blocks writes and external access. @@ -277,6 +363,8 @@ def get_analytics_db_readonly() -> duckdb.DuckDBPyConnection: conn.execute(f"ATTACH '{db_file}' AS {ext_dir.name} (READ_ONLY)") except Exception: pass + # Re-attach remote extensions so BigQuery / other remote views resolve. + _reattach_remote_extensions(conn, extracts_dir) # Note: external_access stays enabled because views use read_parquet() on local files. # File-path-based attacks are blocked by the SQL blocklist in app/api/query.py. return conn diff --git a/tests/test_db.py b/tests/test_db.py index 8a06937..e99e87e 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -462,6 +462,76 @@ class TestSchemaV4: conn2.close() +class TestExtensionReattach: + """Resilience tests for _reattach_remote_extensions() called by get_analytics_db_readonly().""" + + def _make_analytics_db(self, tmp_path): + """Create an empty analytics server.duckdb so get_analytics_db_readonly() takes the read_only path.""" + analytics_dir = tmp_path / "analytics" + analytics_dir.mkdir(parents=True, exist_ok=True) + import duckdb as _duckdb + conn = _duckdb.connect(str(analytics_dir / "server.duckdb")) + conn.close() + + def _make_extract_db(self, tmp_path, source_name, with_remote_attach=True): + """Create a minimal extract.duckdb, optionally with a _remote_attach table.""" + ext_dir = tmp_path / "extracts" / source_name + ext_dir.mkdir(parents=True, exist_ok=True) + import duckdb as _duckdb + conn = _duckdb.connect(str(ext_dir / "extract.duckdb")) + try: + conn.execute( + "CREATE TABLE _meta (table_name VARCHAR, description VARCHAR, rows BIGINT, " + "size_bytes BIGINT, extracted_at TIMESTAMP, query_mode VARCHAR)" + ) + if with_remote_attach: + conn.execute( + "CREATE TABLE _remote_attach (alias VARCHAR, extension VARCHAR, url VARCHAR, token_env VARCHAR)" + ) + # Use 'bigquery' which won't be installed in CI — tests resilience + conn.execute( + "INSERT INTO _remote_attach VALUES ('bq', 'bigquery', 'project/dataset', '')" + ) + finally: + conn.close() + + def test_reads_remote_attach_table(self, tmp_path, monkeypatch): + """get_analytics_db_readonly() doesn't crash even when LOAD fails for missing extension.""" + monkeypatch.setenv("DATA_DIR", str(tmp_path)) + import importlib + import src.db as db_module + importlib.reload(db_module) + + self._make_analytics_db(tmp_path) + self._make_extract_db(tmp_path, "mysource", with_remote_attach=True) + + # Should not raise even though 'bigquery' extension is not installed + conn = db_module.get_analytics_db_readonly() + try: + # Connection must still be usable for local queries + result = conn.execute("SELECT 42 AS n").fetchone() + assert result[0] == 42 + finally: + conn.close() + + def test_skips_missing_remote_attach(self, tmp_path, monkeypatch): + """get_analytics_db_readonly() works fine when _remote_attach table is absent.""" + monkeypatch.setenv("DATA_DIR", str(tmp_path)) + import importlib + import src.db as db_module + importlib.reload(db_module) + + self._make_analytics_db(tmp_path) + self._make_extract_db(tmp_path, "localsource", with_remote_attach=False) + + conn = db_module.get_analytics_db_readonly() + try: + result = conn.execute("SELECT 'ok' AS status").fetchone() + assert result[0] == "ok" + finally: + conn.close() + + class TestGetAnalyticsDbReadonly: def test_analytics_readonly_rejects_malicious_dir_name(self, tmp_path, monkeypatch): """Directories with SQL-injection chars in their name are skipped."""