From b502bd8bdd96e4e0304dc4a7371aff2f3e77c8a0 Mon Sep 17 00:00:00 2001 From: ZdenekSrotyr Date: Tue, 31 Mar 2026 07:50:37 +0200 Subject: [PATCH] =?UTF-8?q?refactor:=20delete=20old=20sync=20pipeline=20?= =?UTF-8?q?=E2=80=94=209,500=20lines=20removed?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 5 cleanup: remove all code replaced by extract.duckdb architecture. Deleted modules: - src/config.py (653) — replaced by DuckDB table_registry - src/parquet_manager.py (755) — replaced by DuckDB COPY TO - src/data_sync.py (734) — replaced by SyncOrchestrator - src/remote_query.py (636) — replaced by DuckDB BigQuery ATTACH - src/table_registry.py (464) — replaced by DuckDB repository - connectors/keboola/adapter.py (820) — replaced by extractor.py - connectors/bigquery/adapter.py (665) — replaced by extractor.py - connectors/bigquery/client.py (644) — replaced by DuckDB BQ extension Updated all imports in webapp, catalog_export, enricher, router, sync_settings_service, generate_sample_data. Kept keboola/client.py as fallback (removed src.config dependency). 704 tests passing. --- app/web/router.py | 25 +- connectors/bigquery/adapter.py | 665 ------------- connectors/bigquery/client.py | 644 ------------- connectors/keboola/adapter.py | 820 ---------------- connectors/keboola/client.py | 33 +- connectors/keboola/tests/test_adapter.py | 194 ---- connectors/openmetadata/enricher.py | 15 +- scripts/generate_sample_data.py | 36 +- src/catalog_export.py | 31 +- src/config.py | 653 ------------- src/data_sync.py | 734 -------------- src/parquet_manager.py | 755 --------------- src/remote_query.py | 636 ------------- src/table_registry.py | 464 --------- tests/test_bigquery_adapter.py | 1103 ---------------------- tests/test_bigquery_client.py | 1018 -------------------- tests/test_catalog_export.py | 56 +- tests/test_config_bq_entity_type.py | 69 -- tests/test_config_query_mode.py | 69 -- tests/test_config_sync_schedule.py | 103 -- tests/test_data_sync_query_mode.py | 228 ----- tests/test_openmetadata_enricher.py | 20 +- tests/test_remote_query.py | 779 --------------- tests/test_table_registry.py | 363 ------- webapp/app.py | 156 +-- webapp/sync_settings_service.py | 9 +- 26 files changed, 188 insertions(+), 9490 deletions(-) delete mode 100644 connectors/bigquery/adapter.py delete mode 100644 connectors/bigquery/client.py delete mode 100644 connectors/keboola/adapter.py delete mode 100644 connectors/keboola/tests/test_adapter.py delete mode 100644 src/config.py delete mode 100644 src/data_sync.py delete mode 100644 src/parquet_manager.py delete mode 100644 src/remote_query.py delete mode 100644 src/table_registry.py delete mode 100644 tests/test_bigquery_adapter.py delete mode 100644 tests/test_bigquery_client.py delete mode 100644 tests/test_config_bq_entity_type.py delete mode 100644 tests/test_config_query_mode.py delete mode 100644 tests/test_config_sync_schedule.py delete mode 100644 tests/test_data_sync_query_mode.py delete mode 100644 tests/test_remote_query.py delete mode 100644 tests/test_table_registry.py diff --git a/app/web/router.py b/app/web/router.py index 252d3bd..2759134 100644 --- a/app/web/router.py +++ b/app/web/router.py @@ -236,24 +236,25 @@ async def catalog( enabled_datasets = settings_repo.get_enabled_datasets(user["id"]) datasets = get_datasets() - # Build catalog data from config + # Build catalog data from table_registry in DuckDB try: - from src.config import get_config - config = get_config() + from src.repositories.table_registry import TableRegistryRepository + table_repo = TableRegistryRepository(conn) + registered = table_repo.list_all() tables = [] - for tc in config.tables: + for tc in registered: table_data = { - "id": tc.id, - "name": tc.name, - "description": tc.description, - "dataset": getattr(tc, "dataset", None), - "sync_strategy": tc.sync_strategy, - "query_mode": getattr(tc, "query_mode", "local"), - "profile": all_profiles.get(tc.id), + "id": tc.get("id", ""), + "name": tc.get("name", ""), + "description": tc.get("description", ""), + "dataset": tc.get("bucket"), + "sync_strategy": tc.get("sync_strategy", "full_refresh"), + "query_mode": tc.get("query_mode", "local"), + "profile": all_profiles.get(tc.get("id", "")), } # Add sync state for state in all_states: - if state["table_id"] == tc.id: + if state["table_id"] == tc.get("id"): table_data["last_sync"] = state.get("last_sync") table_data["rows"] = state.get("rows") break diff --git a/connectors/bigquery/adapter.py b/connectors/bigquery/adapter.py deleted file mode 100644 index b63bc04..0000000 --- a/connectors/bigquery/adapter.py +++ /dev/null @@ -1,665 +0,0 @@ -""" -BigQuery data source adapter. - -Implements the DataSource interface for Google BigQuery. -Reads tables via the BigQuery API, converts directly to Parquet files -using PyArrow (no CSV intermediate step). -""" - -import logging -from pathlib import Path -from typing import Dict, List, Optional, Any -from datetime import datetime, timedelta, date - -import pyarrow as pa -import pyarrow.parquet as pq - -from src.config import get_config, TableConfig -from src.data_sync import DataSource, SyncState, _get_uncompressed_size -from src.parquet_manager import ( - convert_date_columns_to_date32, - apply_schema_to_table, -) -from .client import create_client as create_bq_client - - -logger = logging.getLogger(__name__) - - -class BigQueryDataSource(DataSource): - """ - Data source: Google BigQuery. - - Downloads data directly from BigQuery via PyArrow (no CSV step), - writes to local Parquet files with schema enforcement. - """ - - def __init__(self): - """Initialize BigQuery source with env var validation.""" - self.config = get_config() - self.bq_client = create_bq_client() - - def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]: - """Return BigQuery column metadata for schema generation. - - Returns: - {"columns": {"col_name": {"source_type": "...", "description": "..."}}} - or None if metadata unavailable. - """ - raw = self.bq_client.get_table_metadata(table_id) - column_types = raw.get("column_types", {}) - column_descriptions = raw.get("column_descriptions", {}) - - if not column_types: - return None - - result = {} - for col_name, bq_type in column_types.items(): - entry = {"source_type": bq_type} - if col_name in column_descriptions: - entry["description"] = column_descriptions[col_name] - result[col_name] = entry - - return {"columns": result} - - def discover_tables(self) -> List[Dict[str, Any]]: - """Discover all available tables from BigQuery.""" - return self.bq_client.discover_all_tables() - - def get_source_name(self) -> str: - """Display name of this data source.""" - return "Google BigQuery" - - def sync_table( - self, - table_config: TableConfig, - sync_state: SyncState, - ) -> Dict[str, Any]: - """ - Synchronize table from BigQuery. - - Dispatches to the appropriate strategy based on table config. - - Args: - table_config: Table configuration - sync_state: Sync state manager - - Returns: - Dictionary with sync result - """ - logger.info(f"Syncing BQ table: {table_config.name} ({table_config.sync_strategy})") - - # Clear metadata cache for fresh types - if table_config.id in self.bq_client.metadata_cache: - del self.bq_client.metadata_cache[table_config.id] - logger.debug(f"Cleared BQ metadata cache for {table_config.id}") - - try: - if table_config.sync_strategy == "full_refresh": - result = self._full_refresh(table_config) - elif table_config.sync_strategy == "incremental": - result = self._incremental_sync(table_config, sync_state) - elif table_config.sync_strategy == "partitioned": - result = self._partitioned_sync(table_config, sync_state) - else: - raise ValueError(f"Unknown sync strategy: {table_config.sync_strategy}") - - # Skip sync state update if partitioned sync got no new data. - # This lets the scheduler retry on the next tick instead of - # marking the sync as done for the day with stale data. - skip_state_update = ( - table_config.sync_strategy == "partitioned" - and result.get("partitions_updated", -1) == 0 - ) - - if skip_state_update: - logger.warning( - f"Partitioned sync for {table_config.name} got 0 new partitions " - f"- NOT updating last_sync (will retry next tick)" - ) - else: - sync_state.update_sync( - table_id=table_config.id, - table_name=table_config.name, - strategy=table_config.sync_strategy, - rows=result["rows"], - file_size_bytes=result["file_size_bytes"], - columns=result.get("columns", 0), - uncompressed_bytes=result.get("uncompressed_bytes", 0), - ) - - return { - "success": True, - "rows": result["rows"], - "strategy": table_config.sync_strategy, - "file_size_mb": result["file_size_bytes"] / 1024 / 1024, - } - - except Exception as e: - logger.error(f"Error syncing BQ table {table_config.name}: {e}") - return { - "success": False, - "error": str(e), - "strategy": table_config.sync_strategy, - } - - def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]: - """ - Full refresh: stream table from BQ and write to Parquet in batches. - - Uses streaming (constant memory) instead of loading entire table into RAM. - Each RecordBatch from BQ is written directly to disk via ParquetWriter. - """ - logger.info(f"Full refresh (streaming): {table_config.name}") - - parquet_path = self.config.get_parquet_path(table_config) - date_columns = self.bq_client.get_date_columns(table_config.id) - pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id) - - parquet_path.parent.mkdir(parents=True, exist_ok=True) - - # Stream BQ results directly to Parquet file (constant memory) - writer = None - total_rows = 0 - num_columns = 0 - - for batch in self.bq_client.read_table_streaming( - table_config.id, - columns=table_config.columns, - row_filter=table_config.row_filter, - ): - if batch.num_rows == 0: - continue - - # Convert batch to table for schema enforcement - chunk = pa.Table.from_batches([batch]) - if date_columns: - chunk = convert_date_columns_to_date32(chunk, date_columns) - if pyarrow_schema: - chunk = apply_schema_to_table(chunk, pyarrow_schema) - - if writer is None: - writer = pq.ParquetWriter( - parquet_path, chunk.schema, compression="snappy", - ) - num_columns = chunk.num_columns - - writer.write_table(chunk) - total_rows += chunk.num_rows - - # Log progress every ~1M rows - if total_rows % 1_000_000 < chunk.num_rows: - logger.info(f" -> {total_rows:,} rows written...") - - if writer: - writer.close() - - file_size = parquet_path.stat().st_size if parquet_path.exists() else 0 - logger.info( - f"Full refresh complete: {total_rows:,} rows, " - f"{file_size / 1024 / 1024:.2f} MB" - ) - - return { - "rows": total_rows, - "columns": num_columns, - "file_size_bytes": file_size, - "uncompressed_bytes": _get_uncompressed_size(parquet_path) if total_rows > 0 else 0, - } - - def _incremental_sync( - self, - table_config: TableConfig, - sync_state: SyncState, - ) -> Dict[str, Any]: - """ - Incremental sync: dispatch to column-based or partition-based strategy. - """ - # If partition_by is set, use partitioned incremental - if table_config.partition_by: - return self._partitioned_sync(table_config, sync_state) - - # If incremental_column is set, use timestamp-based incremental - if table_config.incremental_column: - return self._incremental_column_sync(table_config, sync_state) - - # Fallback: full refresh (no incremental column configured) - logger.warning( - f"Table {table_config.name}: incremental strategy but no " - f"incremental_column or partition_by configured, falling back to full refresh" - ) - return self._full_refresh(table_config) - - def _incremental_column_sync( - self, - table_config: TableConfig, - sync_state: SyncState, - ) -> Dict[str, Any]: - """ - Timestamp-based incremental sync using incremental_column. - - Reads rows WHERE incremental_column > last_sync_value, - merges with existing Parquet (dedup on primary key). - """ - logger.info( - f"Incremental column sync: {table_config.name} " - f"(column: {table_config.incremental_column})" - ) - - parquet_path = self.config.get_parquet_path(table_config) - date_columns = self.bq_client.get_date_columns(table_config.id) - pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id) - - # Determine since_value from last sync - last_sync = sync_state.get_last_sync(table_config.id) - - if last_sync and parquet_path.exists(): - # Apply window: go back incremental_window_days from last sync - last_sync_dt = datetime.fromisoformat(last_sync) - window_days = table_config.incremental_window_days or 7 - since_dt = last_sync_dt - timedelta(days=window_days) - since_value = since_dt.isoformat() - - logger.info(f" -> Since: {since_value} (window: {window_days} days)") - - # Read incremental data - new_data = self.bq_client.read_table_incremental( - table_id=table_config.id, - incremental_column=table_config.incremental_column, - since_value=since_value, - columns=table_config.columns, - ) - - if new_data.num_rows == 0: - logger.info(" -> No new data since last sync") - existing_pf = pq.ParquetFile(parquet_path) - return { - "rows": existing_pf.metadata.num_rows, - "columns": len(existing_pf.schema_arrow), - "file_size_bytes": parquet_path.stat().st_size, - "uncompressed_bytes": _get_uncompressed_size(parquet_path), - } - - # Merge with existing data - logger.info(f" -> Merging {new_data.num_rows} new rows with existing data") - existing_table = pq.read_table(parquet_path) - merged = self._merge_arrow_tables( - existing_table, new_data, table_config.get_primary_key_columns() - ) - - # Apply schema enforcement - if date_columns: - merged = convert_date_columns_to_date32(merged, date_columns) - if pyarrow_schema: - merged = apply_schema_to_table(merged, pyarrow_schema) - - pq.write_table(merged, parquet_path, compression="snappy") - - file_size = parquet_path.stat().st_size - logger.info( - f" -> Incremental sync complete: {merged.num_rows} total rows" - ) - - return { - "rows": merged.num_rows, - "columns": merged.num_columns, - "file_size_bytes": file_size, - "uncompressed_bytes": _get_uncompressed_size(parquet_path), - } - - else: - # First sync or no existing file -- full read - logger.info(" -> First sync, reading all data") - - if table_config.max_history_days: - since_dt = datetime.now() - timedelta(days=table_config.max_history_days) - arrow_table = self.bq_client.read_table_incremental( - table_id=table_config.id, - incremental_column=table_config.incremental_column, - since_value=since_dt.isoformat(), - columns=table_config.columns, - ) - else: - arrow_table = self.bq_client.read_table( - table_config.id, - columns=table_config.columns, - row_filter=table_config.row_filter, - ) - - # Apply schema enforcement - if date_columns: - arrow_table = convert_date_columns_to_date32(arrow_table, date_columns) - if pyarrow_schema: - arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema) - - parquet_path.parent.mkdir(parents=True, exist_ok=True) - pq.write_table(arrow_table, parquet_path, compression="snappy") - - file_size = parquet_path.stat().st_size - return { - "rows": arrow_table.num_rows, - "columns": arrow_table.num_columns, - "file_size_bytes": file_size, - "uncompressed_bytes": _get_uncompressed_size(parquet_path), - } - - def _partitioned_sync( - self, - table_config: TableConfig, - sync_state: SyncState, - ) -> Dict[str, Any]: - """ - Per-partition streaming sync: process one partition (day) at a time. - - Queries BQ for a single day, streams result to disk, then moves to next day. - Memory usage is constant (~20-50 MB per partition) regardless of total data volume. - """ - partition_col = table_config.partition_by - if not partition_col and table_config.incremental_column: - partition_col = table_config.incremental_column - - if not partition_col: - logger.warning( - f"Table {table_config.name}: partitioned strategy but no " - f"partition_by or incremental_column, falling back to full refresh" - ) - return self._full_refresh(table_config) - - granularity = table_config.partition_granularity or "day" - column_type = table_config.partition_column_type - logger.info( - f"Partitioned sync: {table_config.name} " - f"(by {partition_col}, {granularity}, type={column_type})" - ) - - partition_dir = self.config.get_parquet_path(table_config) - date_columns = self.bq_client.get_date_columns(table_config.id) - pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id) - - # Determine date range - last_sync = sync_state.get_last_sync(table_config.id) - today = date.today() - - if last_sync: - last_sync_dt = datetime.fromisoformat(last_sync) - window_days = table_config.incremental_window_days or 7 - start_date = (last_sync_dt - timedelta(days=window_days)).date() - logger.info(f" -> Incremental sync from {start_date} (window: {window_days} days)") - else: - if table_config.max_history_days: - start_date = today - timedelta(days=table_config.max_history_days) - logger.info(f" -> First sync, last {table_config.max_history_days} days from {start_date}") - else: - start_date = today - timedelta(days=365) - logger.info(" -> First sync, no max_history_days, defaulting to 365 days") - - # Generate list of partition dates - partition_dates = self._generate_partition_dates(start_date, today, granularity) - logger.info(f" -> Processing {len(partition_dates)} partitions") - - total_rows = 0 - partitions_updated = 0 - - for partition_date in partition_dates: - rows = self._sync_single_partition( - table_config=table_config, - partition_col=partition_col, - partition_date=partition_date, - partition_dir=partition_dir, - date_columns=date_columns, - pyarrow_schema=pyarrow_schema, - granularity=granularity, - column_type=column_type, - ) - if rows > 0: - partitions_updated += 1 - total_rows += rows - - # Cleanup old partitions beyond retention window - deleted = self._cleanup_old_partitions(table_config, partition_dir, granularity) - if deleted > 0: - logger.info(f" -> Cleaned up {deleted} old partition files") - - logger.info( - f" -> Partitioned sync complete: {partitions_updated} partitions updated, " - f"{total_rows} total rows processed" - ) - - result = self._get_partition_totals(partition_dir) - result["partitions_updated"] = partitions_updated - return result - - @staticmethod - def _generate_partition_dates( - start_date: date, - end_date: date, - granularity: str, - ) -> List[date]: - """Generate list of partition start dates between start and end.""" - dates = [] - current = start_date - - if granularity == "day": - while current <= end_date: - dates.append(current) - current += timedelta(days=1) - elif granularity == "month": - # Align to first of month - current = current.replace(day=1) - while current <= end_date: - dates.append(current) - # Move to first of next month - if current.month == 12: - current = current.replace(year=current.year + 1, month=1) - else: - current = current.replace(month=current.month + 1) - elif granularity == "year": - current = current.replace(month=1, day=1) - while current <= end_date: - dates.append(current) - current = current.replace(year=current.year + 1) - - return dates - - def _sync_single_partition( - self, - table_config: TableConfig, - partition_col: str, - partition_date: date, - partition_dir: Path, - date_columns: List[str], - pyarrow_schema, - granularity: str, - column_type: str, - ) -> int: - """ - Query BQ for one partition period, stream to disk, merge with existing file. - - Returns row count for this partition after merge. - """ - import pandas as pd - - # Calculate partition range [start, end) - start = partition_date - if granularity == "day": - end = start + timedelta(days=1) - partition_key = start.strftime("%Y_%m_%d") - elif granularity == "month": - if start.month == 12: - end = start.replace(year=start.year + 1, month=1) - else: - end = start.replace(month=start.month + 1) - partition_key = start.strftime("%Y_%m") - elif granularity == "year": - end = start.replace(year=start.year + 1) - partition_key = start.strftime("%Y") - else: - raise ValueError(f"Unknown granularity: {granularity}") - - partition_path = self.config.get_partition_path(table_config, partition_key) - - # Stream data from BQ for this single partition - batches = [] - for batch in self.bq_client.read_table_partitioned_streaming( - table_id=table_config.id, - partition_column=partition_col, - start=start.isoformat(), - end=end.isoformat(), - columns=table_config.columns, - column_type=column_type, - ): - batches.append(batch) - - if not batches: - return 0 - - new_data = pa.Table.from_batches(batches) - if new_data.num_rows == 0: - return 0 - - # Apply schema conversions - if date_columns: - new_data = convert_date_columns_to_date32(new_data, date_columns) - if pyarrow_schema: - new_data = apply_schema_to_table(new_data, pyarrow_schema) - - # Merge with existing partition file if present - primary_key_cols = table_config.get_primary_key_columns() - - if partition_path.exists(): - existing = pq.read_table(partition_path) - merged = self._merge_arrow_tables(existing, new_data, primary_key_cols) - else: - merged = new_data - - # Write partition file - pq.write_table(merged, partition_path, compression="snappy") - row_count = merged.num_rows - - logger.debug( - f" Partition {partition_key}: {new_data.num_rows} new rows, " - f"{row_count} total after merge" - ) - - # Release memory - del batches, new_data, merged - - return row_count - - def _cleanup_old_partitions( - self, - table_config: TableConfig, - partition_dir: Path, - granularity: str, - ) -> int: - """ - Delete partition files older than max_history_days. - - Returns count of deleted files. - """ - if not table_config.max_history_days: - return 0 - - if not partition_dir.exists(): - return 0 - - cutoff_date = date.today() - timedelta(days=table_config.max_history_days) - deleted = 0 - - for part_path in partition_dir.glob("*.parquet"): - try: - partition_date = self._parse_partition_date(part_path.stem, granularity) - if partition_date and partition_date < cutoff_date: - part_path.unlink() - deleted += 1 - logger.debug(f" Deleted old partition: {part_path.name}") - except (ValueError, IndexError): - logger.warning(f" Skipping unrecognized partition file: {part_path.name}") - - return deleted - - @staticmethod - def _parse_partition_date(partition_key: str, granularity: str) -> Optional[date]: - """Parse a partition key back to a date.""" - try: - if granularity == "day": - return datetime.strptime(partition_key, "%Y_%m_%d").date() - elif granularity == "month": - return datetime.strptime(partition_key, "%Y_%m").date() - elif granularity == "year": - return datetime.strptime(partition_key, "%Y").date() - except ValueError: - return None - return None - - def _merge_arrow_tables( - self, - existing: pa.Table, - new_data: pa.Table, - primary_key: List[str], - ) -> pa.Table: - """ - Merge two Arrow tables with deduplication on primary key. - - New data overwrites existing rows with the same primary key. - - Args: - existing: Existing data - new_data: New/changed data - primary_key: List of PK column names - - Returns: - Merged PyArrow Table - """ - import pandas as pd - - existing_df = existing.to_pandas() - new_df = new_data.to_pandas() - - # Concat and dedup (keep last = new data wins) - merged_df = pd.concat([existing_df, new_df], ignore_index=True) - merged_df = merged_df.drop_duplicates(subset=primary_key, keep="last") - - return pa.Table.from_pandas(merged_df, preserve_index=False) - - def _get_partition_totals(self, partition_dir: Path) -> Dict[str, Any]: - """ - Calculate totals from all partition files in a directory. - """ - total_rows = 0 - total_size = 0 - total_uncompressed = 0 - total_columns = 0 - - if not partition_dir.exists(): - return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0} - - all_partitions = list(partition_dir.glob("*.parquet")) - - for part_path in all_partitions: - try: - pf = pq.ParquetFile(part_path) - meta = pf.metadata - total_rows += meta.num_rows - total_size += part_path.stat().st_size - if total_columns == 0: - total_columns = len(pf.schema_arrow) - for rg_idx in range(meta.num_row_groups): - rg = meta.row_group(rg_idx) - for col_idx in range(rg.num_columns): - total_uncompressed += rg.column(col_idx).total_uncompressed_size - except Exception as e: - logger.warning(f"Skipping corrupt partition {part_path.name}: {e}") - - return { - "rows": total_rows, - "file_size_bytes": total_size, - "partitions": len(all_partitions), - "columns": total_columns, - "uncompressed_bytes": total_uncompressed, - } - - -def create_data_source() -> BigQueryDataSource: - """Factory function for dynamic import compatibility.""" - return BigQueryDataSource() diff --git a/connectors/bigquery/client.py b/connectors/bigquery/client.py deleted file mode 100644 index 59b0341..0000000 --- a/connectors/bigquery/client.py +++ /dev/null @@ -1,644 +0,0 @@ -""" -Google BigQuery API Client - -Low-level wrapper for Google BigQuery with these functions: -1. Authentication using Application Default Credentials (ADC) -2. Query tables to PyArrow (no CSV intermediate step) -3. Get table metadata (schema, columns, data types) -4. Cache metadata for faster repeated use -5. Incremental reads (timestamp-based and partition-based) - -Uses google-cloud-bigquery with native PyArrow support. -""" - -import json -import logging -import os -from pathlib import Path -from typing import Dict, List, Optional, Any -from datetime import datetime, timedelta - -import pyarrow as pa -from google.cloud import bigquery - -try: - from google.cloud import bigquery_storage_v1 - - _HAS_BQ_STORAGE = True -except ImportError: - _HAS_BQ_STORAGE = False - -from src.config import get_config - - -logger = logging.getLogger(__name__) - - -# Mapping BigQuery types to PyArrow types -BIGQUERY_TO_PYARROW_TYPES = { - "STRING": pa.string(), - "BYTES": pa.binary(), - "INTEGER": pa.int64(), - "INT64": pa.int64(), - "FLOAT": pa.float64(), - "FLOAT64": pa.float64(), - "NUMERIC": pa.float64(), - "BIGNUMERIC": pa.float64(), - "BOOLEAN": pa.bool_(), - "BOOL": pa.bool_(), - "TIMESTAMP": pa.timestamp("us", tz="UTC"), - "DATE": pa.date32(), - "TIME": pa.string(), - "DATETIME": pa.timestamp("us"), - "GEOGRAPHY": pa.string(), - "JSON": pa.string(), - "STRUCT": pa.string(), - "RECORD": pa.string(), - "ARRAY": pa.string(), -} - - -class BigQueryClient: - """ - Wrapper for Google BigQuery API. - - Provides high-level methods for working with BigQuery tables: - - Query tables to PyArrow Tables (no CSV step) - - Get metadata (schema, columns) - - Incremental and partitioned reads - """ - - def __init__( - self, - project_id: Optional[str] = None, - location: Optional[str] = None, - ): - """ - Initialize BigQuery client. - - Args: - project_id: GCP project ID for job execution/billing. - If None, reads from BIGQUERY_PROJECT env var. - location: BigQuery location for job execution (e.g., "us-central1"). - If None, reads from BIGQUERY_LOCATION env var. - - Raises: - ValueError: If project_id is not provided and BIGQUERY_PROJECT is not set. - """ - self.project_id = project_id or os.environ.get("BIGQUERY_PROJECT") - - if not self.project_id: - raise ValueError( - "BigQuery project ID not set. " - "Set BIGQUERY_PROJECT environment variable." - ) - - self.location = location or os.environ.get("BIGQUERY_LOCATION") - - # Initialize BigQuery client with ADC - # project_id is used for job execution and billing. - # Data can live in a different project -- table IDs in queries - # use fully-qualified format (project.dataset.table). - client_kwargs = {"project": self.project_id} - if self.location: - client_kwargs["location"] = self.location - self.client = bigquery.Client(**client_kwargs) - - # BQ Storage API client for fast parallel reads (gRPC streams). - # Without explicit bqstorage_client, to_arrow_iterable() silently - # falls back to slow REST API pagination (~5K rows/sec vs ~300K rows/sec). - if _HAS_BQ_STORAGE: - try: - self.bqstorage_client = bigquery_storage_v1.BigQueryReadClient() - logger.info("BQ Storage API client initialized (fast parallel gRPC reads)") - except Exception as e: - self.bqstorage_client = None - logger.warning(f"BQ Storage API client failed to initialize: {e}") - else: - self.bqstorage_client = None - logger.info("BQ Storage API not available (install google-cloud-bigquery-storage)") - - # Metadata cache - config = get_config() - self.metadata_cache: Dict[str, Dict[str, Any]] = {} - self.metadata_cache_path = config.get_metadata_path() / "bq_table_metadata.json" - - # Load cache from disk if exists - self._load_metadata_cache() - - logger.info( - f"BigQuery client initialized: project={self.project_id}, " - f"location={self.location or 'auto'}" - ) - - def _load_metadata_cache(self): - """Load metadata cache from disk.""" - if self.metadata_cache_path.exists(): - try: - with open(self.metadata_cache_path, "r") as f: - self.metadata_cache = json.load(f) - logger.info( - f"BQ metadata cache loaded: {len(self.metadata_cache)} tables" - ) - except Exception as e: - logger.warning(f"Error loading BQ metadata cache: {e}") - self.metadata_cache = {} - - def _save_metadata_cache(self): - """Save metadata cache to disk.""" - try: - self.metadata_cache_path.parent.mkdir(parents=True, exist_ok=True) - with open(self.metadata_cache_path, "w") as f: - json.dump(self.metadata_cache, f, indent=2) - logger.debug("BQ metadata cache saved") - except Exception as e: - logger.warning(f"Error saving BQ metadata cache: {e}") - - def get_table_metadata( - self, - table_id: str, - use_cache: bool = True, - cache_ttl_hours: int = 24, - ) -> Dict[str, Any]: - """ - Get table metadata from BigQuery. - - Args: - table_id: Full table ID (e.g., "project.dataset.table") - use_cache: Use cache if available - cache_ttl_hours: Cache TTL in hours (default 24h) - - Returns: - Dictionary with metadata including columns, types, descriptions, row count. - """ - # Check cache - if use_cache and table_id in self.metadata_cache: - cached = self.metadata_cache[table_id] - cached_time = datetime.fromisoformat(cached.get("_cached_at", "2000-01-01")) - cache_age = datetime.now() - cached_time - - if cache_age < timedelta(hours=cache_ttl_hours): - logger.debug(f"Using BQ metadata cache for {table_id}") - return cached - - logger.info(f"Fetching metadata for BQ table: {table_id}") - - try: - table_ref = self.client.get_table(table_id) - - # Build column metadata - columns = [] - column_types = {} - column_descriptions = {} - for field in table_ref.schema: - columns.append(field.name) - column_types[field.name] = field.field_type - if field.description: - column_descriptions[field.name] = field.description - - metadata = { - "table_id": table_id, - "name": table_ref.table_id, - "dataset": table_ref.dataset_id, - "project": table_ref.project, - "columns": columns, - "column_types": column_types, - "column_descriptions": column_descriptions, - "row_count": table_ref.num_rows, - "size_bytes": table_ref.num_bytes, - "created": table_ref.created.isoformat() if table_ref.created else None, - "modified": table_ref.modified.isoformat() if table_ref.modified else None, - "partitioning": None, - "_cached_at": datetime.now().isoformat(), - } - - # Capture partitioning info - if table_ref.time_partitioning: - metadata["partitioning"] = { - "type": table_ref.time_partitioning.type_, - "field": table_ref.time_partitioning.field, - "expiration_ms": table_ref.time_partitioning.expiration_ms, - } - - # Save to cache - self.metadata_cache[table_id] = metadata - self._save_metadata_cache() - - return metadata - - except Exception as e: - logger.error(f"Error getting metadata for {table_id}: {e}") - raise - - def get_pyarrow_schema(self, table_id: str) -> Optional[pa.Schema]: - """ - Build PyArrow schema from BigQuery table schema. - - Args: - table_id: Full table ID - - Returns: - PyArrow schema or None if metadata unavailable - """ - metadata = self.get_table_metadata(table_id) - column_types = metadata.get("column_types", {}) - - if not column_types: - logger.warning(f"No column types for {table_id}, schema will not be applied") - return None - - fields = [] - for col_name in metadata.get("columns", []): - bq_type = column_types.get(col_name, "STRING") - pa_type = BIGQUERY_TO_PYARROW_TYPES.get(bq_type, pa.string()) - fields.append(pa.field(col_name, pa_type)) - - return pa.schema(fields) - - def get_date_columns(self, table_id: str) -> List[str]: - """ - Get list of DATE-only columns for a table. - - Args: - table_id: Full table ID - - Returns: - List of column names that have DATE type in BigQuery - """ - metadata = self.get_table_metadata(table_id) - column_types = metadata.get("column_types", {}) - - return [ - col_name for col_name, bq_type in column_types.items() - if bq_type == "DATE" - ] - - def query_to_arrow( - self, - sql: str, - params: Optional[List[bigquery.ScalarQueryParameter]] = None, - ) -> pa.Table: - """ - Execute SQL query and return results as PyArrow Table. - - Args: - sql: SQL query string (use @param_name for parameterized values) - params: List of BigQuery query parameters - - Returns: - PyArrow Table with query results - """ - job_config = bigquery.QueryJobConfig() - if params: - job_config.query_parameters = params - - logger.debug(f"Executing BQ query: {sql[:200]}...") - - query_job = self.client.query(sql, job_config=job_config) - - # Use BQ Storage API for fast reads (parallel gRPC) if available. - # Fall back to REST API if SA lacks bigquery.readsessions.create permission. - try: - if self.bqstorage_client: - arrow_table = query_job.to_arrow(bqstorage_client=self.bqstorage_client) - else: - arrow_table = query_job.to_arrow() - except Exception as storage_err: - if "readsessions" in str(storage_err) or "PERMISSION_DENIED" in str(storage_err): - logger.warning( - "BQ Storage API unavailable (missing readsessions permission), " - "falling back to REST API" - ) - arrow_table = query_job.to_arrow(create_bqstorage_client=False) - else: - raise - - logger.debug(f"Query returned {arrow_table.num_rows} rows, {arrow_table.num_columns} columns") - return arrow_table - - def query_to_arrow_batches( - self, - sql: str, - params: Optional[List[bigquery.ScalarQueryParameter]] = None, - ): - """ - Execute SQL query and yield results as streaming RecordBatches. - - Unlike query_to_arrow(), this does NOT load entire result into memory. - Each RecordBatch is a small chunk (typically a few MB) that can be - written to disk immediately. - - Args: - sql: SQL query string (use @param_name for parameterized values) - params: List of BigQuery query parameters - - Yields: - pyarrow.RecordBatch objects - """ - job_config = bigquery.QueryJobConfig() - if params: - job_config.query_parameters = params - - logger.debug(f"Executing BQ query (streaming): {sql[:200]}...") - - query_job = self.client.query(sql, job_config=job_config) - - # result() returns RowIterator which has to_arrow_iterable() - # (QueryJob itself only has to_arrow(), not to_arrow_iterable()) - row_iter = query_job.result() - - # IMPORTANT: to_arrow_iterable() requires explicit bqstorage_client - # to use BQ Storage API (parallel gRPC streams, ~300K rows/sec). - # Without it, silently falls back to REST pagination (~5K rows/sec). - # This is critical when querying VIEWS (DataView): BQ materializes - # the view into a temp table, and Storage API reads from that temp table. - try: - storage_kwargs = {} - if self.bqstorage_client: - storage_kwargs["bqstorage_client"] = self.bqstorage_client - batch_iter = row_iter.to_arrow_iterable(**storage_kwargs) - # Probe first batch to detect Storage API permission errors early - first_batch = next(batch_iter, None) - if first_batch is not None: - yield first_batch - yield from batch_iter - return - except Exception as storage_err: - if "readsessions" not in str(storage_err) and "PERMISSION_DENIED" not in str(storage_err): - raise - logger.warning( - "BQ Storage API unavailable (missing readsessions permission), " - "falling back to REST API (streaming)" - ) - - # Fallback: REST API streaming (re-execute query for fresh RowIterator) - row_iter = self.client.query(sql, job_config=job_config).result() - yield from row_iter.to_arrow_iterable(create_bqstorage_client=False) - - def read_table_streaming( - self, - table_id: str, - columns: Optional[List[str]] = None, - row_filter: Optional[str] = None, - ): - """ - Read table as streaming RecordBatches (constant memory). - - Args: - table_id: Full table ID (e.g., "project.dataset.table") - columns: Optional list of columns to select - row_filter: Optional SQL WHERE clause (without WHERE keyword) - - Yields: - pyarrow.RecordBatch objects - """ - select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*" - - sql = f"SELECT {select_cols} FROM `{table_id}`" - if row_filter: - sql += f" WHERE {row_filter}" - - logger.info( - f"Streaming BQ table: {table_id} " - f"(filter: {row_filter or 'none'}, " - f"storage_api={'yes' if self.bqstorage_client else 'no'})" - ) - yield from self.query_to_arrow_batches(sql) - - def read_table( - self, - table_id: str, - columns: Optional[List[str]] = None, - row_filter: Optional[str] = None, - ) -> pa.Table: - """ - Read full table (or filtered subset) as PyArrow Table. - - Args: - table_id: Full table ID (e.g., "project.dataset.table") - columns: Optional list of columns to select - row_filter: Optional SQL WHERE clause (without WHERE keyword) - - Returns: - PyArrow Table with table data - """ - # Build SELECT clause - select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*" - - sql = f"SELECT {select_cols} FROM `{table_id}`" - if row_filter: - sql += f" WHERE {row_filter}" - - logger.info(f"Reading BQ table: {table_id} (filter: {row_filter or 'none'})") - return self.query_to_arrow(sql) - - def read_table_incremental( - self, - table_id: str, - incremental_column: str, - since_value: str, - columns: Optional[List[str]] = None, - ) -> pa.Table: - """ - Read rows where incremental_column > since_value. - - Uses parameterized query to prevent SQL injection. - - Args: - table_id: Full table ID - incremental_column: Column name for incremental filter - since_value: ISO timestamp string - fetch rows after this value - columns: Optional list of columns to select - - Returns: - PyArrow Table with incremental data - """ - select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*" - - sql = ( - f"SELECT {select_cols} FROM `{table_id}` " - f"WHERE `{incremental_column}` > @since_value" - ) - - params = [ - bigquery.ScalarQueryParameter("since_value", "TIMESTAMP", since_value), - ] - - logger.info( - f"Incremental read: {table_id} WHERE {incremental_column} > {since_value}" - ) - return self.query_to_arrow(sql, params=params) - - def read_table_partitioned( - self, - table_id: str, - partition_column: str, - start: str, - end: Optional[str] = None, - columns: Optional[List[str]] = None, - column_type: str = "TIMESTAMP", - ) -> pa.Table: - """ - Read data within a partition range. - - Args: - table_id: Full table ID - partition_column: Partition column name - start: Start date/timestamp (inclusive) - end: End date/timestamp (exclusive). If None, reads to present. - columns: Optional list of columns to select - column_type: BQ SQL type for the partition column ("DATE", "TIMESTAMP", "DATETIME") - - Returns: - PyArrow Table with partition range data - """ - select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*" - - sql = ( - f"SELECT {select_cols} FROM `{table_id}` " - f"WHERE `{partition_column}` >= @start_value" - ) - params = [ - bigquery.ScalarQueryParameter("start_value", column_type, start), - ] - - if end: - sql += f" AND `{partition_column}` < @end_value" - params.append( - bigquery.ScalarQueryParameter("end_value", column_type, end), - ) - - logger.info( - f"Partitioned read: {table_id} [{start} .. {end or 'now'})" - ) - return self.query_to_arrow(sql, params=params) - - def read_table_partitioned_streaming( - self, - table_id: str, - partition_column: str, - start: str, - end: Optional[str] = None, - columns: Optional[List[str]] = None, - column_type: str = "TIMESTAMP", - ): - """ - Read data within a partition range as streaming RecordBatches (constant memory). - - Unlike read_table_partitioned(), this does NOT load entire result into memory. - Each RecordBatch is a small chunk that can be written to disk immediately. - - Args: - table_id: Full table ID - partition_column: Partition column name - start: Start date/timestamp (inclusive) - end: End date/timestamp (exclusive). If None, reads to present. - columns: Optional list of columns to select - column_type: BQ SQL type for the partition column ("DATE", "TIMESTAMP", "DATETIME") - - Yields: - pyarrow.RecordBatch objects - """ - select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*" - - sql = ( - f"SELECT {select_cols} FROM `{table_id}` " - f"WHERE `{partition_column}` >= @start_value" - ) - params = [ - bigquery.ScalarQueryParameter("start_value", column_type, start), - ] - - if end: - sql += f" AND `{partition_column}` < @end_value" - params.append( - bigquery.ScalarQueryParameter("end_value", column_type, end), - ) - - logger.info( - f"Partitioned streaming read: {table_id} [{start} .. {end or 'now'})" - ) - yield from self.query_to_arrow_batches(sql, params=params) - - def discover_all_tables(self, dataset_id: Optional[str] = None) -> List[Dict[str, Any]]: - """ - List all tables in the project (or specific dataset). - - Args: - dataset_id: Optional dataset ID to limit scope - - Returns: - Normalized list of table dicts with id, name, columns, row_count, etc. - """ - logger.info(f"Discovering BQ tables (dataset={dataset_id or 'all'})...") - - result = [] - - if dataset_id: - datasets = [self.client.get_dataset(dataset_id)] - else: - datasets = list(self.client.list_datasets()) - - for dataset in datasets: - ds_ref = dataset.reference if hasattr(dataset, "reference") else dataset.dataset_id - ds_id = str(ds_ref) - - try: - tables = list(self.client.list_tables(ds_ref)) - except Exception as e: - logger.warning(f"Could not list tables in dataset {ds_id}: {e}") - continue - - for table_item in tables: - full_id = f"{table_item.project}.{table_item.dataset_id}.{table_item.table_id}" - - try: - table_detail = self.client.get_table(full_id) - columns = [f.name for f in table_detail.schema] - - result.append({ - "id": full_id, - "name": table_item.table_id, - "bucket_id": table_item.dataset_id, - "bucket_name": table_item.dataset_id, - "columns": columns, - "row_count": table_detail.num_rows or 0, - "size_bytes": table_detail.num_bytes or 0, - "primary_key": [], - "last_change": ( - table_detail.modified.isoformat() - if table_detail.modified else None - ), - "last_import": None, - }) - except Exception as e: - logger.warning(f"Could not get details for {full_id}: {e}") - - logger.info(f"Discovered {len(result)} BQ tables") - return result - - def test_connection(self) -> bool: - """ - Test connection to BigQuery. - - Returns: - True if connection works, False otherwise - """ - try: - query_job = self.client.query("SELECT 1") - list(query_job.result()) - logger.info(f"BigQuery connection OK (project: {self.project_id})") - return True - except Exception as e: - logger.error(f"BigQuery connection test failed: {e}") - return False - - -def create_client() -> BigQueryClient: - """ - Factory function to create BigQuery client. - - Returns: - BigQueryClient instance - """ - return BigQueryClient() diff --git a/connectors/keboola/adapter.py b/connectors/keboola/adapter.py deleted file mode 100644 index 760e928..0000000 --- a/connectors/keboola/adapter.py +++ /dev/null @@ -1,820 +0,0 @@ -""" -Keboola data source adapter. - -Implements the DataSource interface for Keboola Storage API. -Downloads tables via the Storage API, converts CSV exports to Parquet files -with full type metadata from Keboola column metadata. -""" - -import logging -import tempfile -from pathlib import Path -from typing import Dict, List, Optional, Any -from datetime import datetime, timedelta - -import pyarrow as pa -from tqdm import tqdm - -from src.config import get_config, TableConfig -from src.data_sync import DataSource, SyncState, _get_uncompressed_size -from src.parquet_manager import ( - create_parquet_manager, - _convert_column, - convert_date_columns_to_date32, - apply_schema_to_table, -) -from .client import create_client as create_keboola_client - - -logger = logging.getLogger(__name__) - - -class KeboolaDataSource(DataSource): - """ - Data source: Direct download from Keboola Storage API. - - Downloads data directly from a Keboola project, converts CSV exports - to typed Parquet files using column metadata for schema enforcement. - """ - - def __init__(self): - """Initialize Keboola source with full env var validation.""" - self.config = get_config() - - # Validate all required Keboola env vars before proceeding - missing = [] - if not self.config.keboola_token: - missing.append("KEBOOLA_STORAGE_TOKEN") - if not self.config.keboola_stack_url: - missing.append("KEBOOLA_STACK_URL") - if not self.config.keboola_project_id: - missing.append("KEBOOLA_PROJECT_ID") - if missing: - raise ValueError( - f"Missing required environment variables for Keboola connector: " - f"{', '.join(missing)}. See config/.env.template" - ) - - self.keboola_client = create_keboola_client() - self.parquet_manager = create_parquet_manager() - - def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]: - """Return Keboola metadata with provider cascade applied. - - Delegates type resolution to the client's _resolve_keboola_type(), - and extracts descriptions via provider priority cascade. - - Returns: - {"columns": {"col_name": {"source_type": "...", "description": "..."}}} - or None if metadata is unavailable. - """ - raw = self.keboola_client.get_table_metadata(table_id) - column_metadata = raw.get("column_metadata", {}) - - if not column_metadata: - return None - - PROVIDER_PRIORITY = [ - "user", - "ai-metadata-enrichment", - "keboola.snowflake-transformation", - ] - - result = {} - for col_name, col_meta_list in column_metadata.items(): - # Delegate type resolution to client - source_type = self.keboola_client._resolve_keboola_type(col_meta_list) - - # Extract description via provider cascade - description = None - if isinstance(col_meta_list, list): - description_by_provider = {} - for entry in col_meta_list: - provider = entry.get("provider", "") - key = entry.get("key", "") - value = entry.get("value", "") - if key == "KBC.description": - description_by_provider[provider] = value - - for p in PROVIDER_PRIORITY: - if p in description_by_provider: - description = description_by_provider[p] - break - - result[col_name] = {"source_type": source_type} - if description: - result[col_name]["description"] = description - - return {"columns": result} - - def discover_tables(self) -> List[Dict[str, Any]]: - """Discover all available tables from Keboola Storage.""" - return self.keboola_client.discover_all_tables() - - def get_source_name(self) -> str: - """Display name of this data source.""" - return "Keboola Storage API" - - def _cleanup_staging(self): - """ - Remove all files from staging directory. - - Called before chunked initial load and after failures to free up disk space. - """ - staging_dir = self.config.get_staging_path() - for f in staging_dir.glob("*"): - if f.is_file(): - try: - f.unlink() - logger.debug(f"Cleaned up staging file: {f.name}") - except Exception as e: - logger.warning(f"Failed to clean up {f.name}: {e}") - - def sync_table( - self, - table_config: TableConfig, - sync_state: SyncState, - ) -> Dict[str, Any]: - """ - Synchronize table from Keboola. - - According to sync_strategy calls _full_refresh or _incremental_sync. - - Args: - table_config: Table configuration - sync_state: Sync state manager - - Returns: - Dictionary with result: - - success: bool - - rows: int - - strategy: str - - error: str (if failed) - """ - logger.info(f"Syncing table: {table_config.name} ({table_config.sync_strategy})") - - # Refresh metadata cache for this table to get latest types from Keboola - if table_config.id in self.keboola_client.metadata_cache: - del self.keboola_client.metadata_cache[table_config.id] - logger.debug(f"Cleared metadata cache for {table_config.id}") - - try: - if table_config.sync_strategy == "full_refresh": - result = self._full_refresh(table_config) - elif table_config.sync_strategy == "partitioned": - result = self._partitioned_sync(table_config) - else: # incremental - result = self._incremental_sync(table_config, sync_state) - - # Update sync state - sync_state.update_sync( - table_id=table_config.id, - table_name=table_config.name, - strategy=table_config.sync_strategy, - rows=result["rows"], - file_size_bytes=result["file_size_bytes"], - columns=result.get("columns", 0), - uncompressed_bytes=result.get("uncompressed_bytes", 0), - ) - - return { - "success": True, - "rows": result["rows"], - "strategy": table_config.sync_strategy, - "file_size_mb": result["file_size_bytes"] / 1024 / 1024, - } - - except Exception as e: - logger.error(f"Error syncing table {table_config.name}: {e}") - return { - "success": False, - "error": str(e), - "strategy": table_config.sync_strategy, - } - - def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]: - """ - Full refresh sync strategy. - - Downloads entire table and replaces existing Parquet file. - """ - logger.info(f"Full refresh: {table_config.name}") - - parquet_path = self.config.get_parquet_path(table_config) - staging_dir = self.config.get_staging_path() - - with tempfile.NamedTemporaryFile( - mode="w", suffix=".csv", delete=False, dir=staging_dir - ) as tmp_file: - tmp_csv_path = Path(tmp_file.name) - - try: - # 1. Export from Keboola to CSV - filters_desc = "" - if table_config.where_filters: - filters_desc = f" (filters: {len(table_config.where_filters)})" - logger.info(f" -> Exporting from Keboola...{filters_desc}") - export_info = self.keboola_client.export_table( - table_id=table_config.id, - output_path=tmp_csv_path, - where_filters=table_config.where_filters if table_config.where_filters else None, - ) - - # 2. Get dtypes for proper conversion - dtypes = self.keboola_client.get_pandas_dtypes(table_config.id) - date_columns = self.keboola_client.get_date_columns(table_config.id) - pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id) - - # 3. Convert CSV -> Parquet - logger.info(" -> Converting to Parquet...") - parquet_info = self.parquet_manager.csv_to_parquet( - csv_path=tmp_csv_path, - parquet_path=parquet_path, - dtypes=dtypes, - table_id=table_config.id, - date_columns=date_columns, - pyarrow_schema=pyarrow_schema, - ) - - return { - "rows": parquet_info["rows"], - "file_size_bytes": parquet_info["parquet_size_bytes"], - "columns": parquet_info.get("columns", 0), - "uncompressed_bytes": _get_uncompressed_size(parquet_path), - } - - finally: - if tmp_csv_path.exists(): - tmp_csv_path.unlink() - - def _incremental_sync( - self, - table_config: TableConfig, - sync_state: SyncState, - ) -> Dict[str, Any]: - """ - Incremental sync strategy. - - Downloads only changed rows using changedSince API parameter. - If partition_by is configured, outputs are partitioned. - Otherwise, merges into a single Parquet file. - """ - if table_config.partition_by: - return self._incremental_partitioned_sync(table_config, sync_state) - return self._incremental_single_file_sync(table_config, sync_state) - - def _incremental_single_file_sync( - self, - table_config: TableConfig, - sync_state: SyncState, - ) -> Dict[str, Any]: - """ - Incremental sync to a single Parquet file (no partitioning). - """ - logger.info(f"Incremental sync (single file): {table_config.name}") - - parquet_path = self.config.get_parquet_path(table_config) - staging_dir = self.config.get_staging_path() - - # Determine timestamp for changedSince - last_sync = sync_state.get_last_sync(table_config.id) - - if last_sync: - last_sync_dt = datetime.fromisoformat(last_sync) - window_days = table_config.incremental_window_days or 7 - changed_since_dt = last_sync_dt - timedelta(days=window_days) - changed_since = changed_since_dt.isoformat() - logger.info( - f" -> ChangedSince: {changed_since} (window: {window_days} days)" - ) - else: - if table_config.max_history_days: - changed_since_dt = datetime.now() - timedelta(days=table_config.max_history_days) - changed_since = changed_since_dt.isoformat() - logger.info( - f" -> First sync, limited to last {table_config.max_history_days} days " - f"(changedSince: {changed_since})" - ) - else: - logger.info(" -> First sync, downloading all data...") - changed_since = None - - with tempfile.NamedTemporaryFile( - mode="w", suffix=".csv", delete=False, dir=staging_dir - ) as tmp_file: - tmp_csv_path = Path(tmp_file.name) - - try: - # 1. Export changed data from Keboola - logger.info(" -> Exporting changes from Keboola...") - export_info = self.keboola_client.export_table( - table_id=table_config.id, - output_path=tmp_csv_path, - changed_since=changed_since, - ) - - if export_info["exported_rows"] == 0: - logger.info(" -> No changes since last synchronization") - if parquet_path.exists(): - existing_info = self.parquet_manager.get_parquet_info(parquet_path) - return { - "rows": existing_info["rows"], - "file_size_bytes": existing_info["file_size_bytes"], - "columns": existing_info.get("columns", 0), - "uncompressed_bytes": _get_uncompressed_size(parquet_path), - } - else: - return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0} - - # 2. Get dtypes and date columns - dtypes = self.keboola_client.get_pandas_dtypes(table_config.id) - date_columns = self.keboola_client.get_date_columns(table_config.id) - pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id) - - # 3. If Parquet exists, merge; otherwise create new - if parquet_path.exists(): - logger.info( - f" -> Merging {export_info['exported_rows']} changes into Parquet..." - ) - - with tempfile.NamedTemporaryFile( - mode="w", suffix=".parquet", delete=False, dir=staging_dir - ) as tmp_parquet_file: - tmp_parquet_path = Path(tmp_parquet_file.name) - - try: - merge_info = self.parquet_manager.merge_parquet( - existing_parquet=parquet_path, - new_csv=tmp_csv_path, - output_parquet=tmp_parquet_path, - primary_key=table_config.get_primary_key_columns(), - dtypes=dtypes, - date_columns=date_columns, - pyarrow_schema=pyarrow_schema, - ) - - tmp_parquet_path.replace(parquet_path) - - return { - "rows": merge_info["total_rows"], - "file_size_bytes": parquet_path.stat().st_size, - "columns": merge_info.get("total_columns", 0), - "uncompressed_bytes": _get_uncompressed_size(parquet_path), - } - - finally: - if tmp_parquet_path.exists(): - tmp_parquet_path.unlink() - - else: - logger.info(" -> Creating new Parquet...") - parquet_info = self.parquet_manager.csv_to_parquet( - csv_path=tmp_csv_path, - parquet_path=parquet_path, - dtypes=dtypes, - table_id=table_config.id, - date_columns=date_columns, - pyarrow_schema=pyarrow_schema, - ) - - return { - "rows": parquet_info["rows"], - "file_size_bytes": parquet_info["parquet_size_bytes"], - "columns": parquet_info.get("columns", 0), - "uncompressed_bytes": _get_uncompressed_size(parquet_path), - } - - finally: - if tmp_csv_path.exists(): - tmp_csv_path.unlink() - - def _incremental_partitioned_sync( - self, - table_config: TableConfig, - sync_state: SyncState, - ) -> Dict[str, Any]: - """ - Incremental sync with partitioned output. - - Downloads only changed rows using changedSince API parameter, - then partitions by partition_by column and merges into existing - partition files. Same logic as _partitioned_sync but uses - changedSince instead of whereFilters. - - For initial load of large tables (max_history_days > chunk_days), - uses chunked download to avoid filling up disk space. Each chunk - has 1-day overlap with the next to ensure no data is lost at boundaries. - """ - import pandas as pd - - logger.info( - f"Incremental sync (partitioned): {table_config.name} " - f"(by {table_config.partition_by}, {table_config.partition_granularity})" - ) - - partition_dir = self.config.get_parquet_path(table_config) - staging_dir = self.config.get_staging_path() - - last_sync = sync_state.get_last_sync(table_config.id) - - # For initial load (no last_sync), always use chunked approach - if not last_sync: - return self._chunked_initial_load(table_config, partition_dir, staging_dir) - - # Regular incremental sync - last_sync_dt = datetime.fromisoformat(last_sync) - window_days = table_config.incremental_window_days or 7 - changed_since_dt = last_sync_dt - timedelta(days=window_days) - changed_since = changed_since_dt.isoformat() - - logger.info( - f" -> ChangedSince: {changed_since} (window: {window_days} days)" - ) - - with tempfile.NamedTemporaryFile( - mode="w", suffix=".csv", delete=False, dir=staging_dir - ) as tmp_file: - tmp_csv_path = Path(tmp_file.name) - - try: - logger.info(" -> Exporting changes from Keboola...") - export_info = self.keboola_client.export_table( - table_id=table_config.id, - output_path=tmp_csv_path, - changed_since=changed_since, - ) - - if export_info["exported_rows"] == 0: - logger.info(" -> No changes since last synchronization") - return self._get_partition_totals(partition_dir) - - dtypes = self.keboola_client.get_pandas_dtypes(table_config.id) - date_columns = self.keboola_client.get_date_columns(table_config.id) - pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id) - - logger.info(f" -> Processing {export_info['exported_rows']} changed rows...") - - partitions_updated = self._process_csv_to_partitions( - tmp_csv_path, table_config, partition_dir, - dtypes=dtypes, date_columns=date_columns, pyarrow_schema=pyarrow_schema, - ) - - self._deduplicate_partitions( - table_config, partitions_updated, - date_columns=date_columns, pyarrow_schema=pyarrow_schema, - ) - - logger.info(f" -> Incremental sync complete, {len(partitions_updated)} partitions updated") - return self._get_partition_totals(partition_dir) - - finally: - if tmp_csv_path.exists(): - tmp_csv_path.unlink() - - def _chunked_initial_load( - self, - table_config: TableConfig, - partition_dir: Path, - staging_dir: Path, - ) -> Dict[str, Any]: - """ - Chunked initial load for large tables. - - Downloads data in time-window chunks to avoid filling up disk space. - Each chunk has 1-day overlap with the next to ensure no data is lost - at boundaries. Deduplication removes any duplicates from overlaps. - """ - chunk_days = table_config.initial_load_chunk_days - max_history_days = table_config.max_history_days - overlap_days = 1 - max_chunks_safety = 120 - - now = datetime.now() - - if max_history_days: - num_chunks = (max_history_days + chunk_days - 1) // chunk_days - logger.info( - f" -> CHUNKED INITIAL LOAD: {max_history_days} days in {num_chunks} chunks " - f"of {chunk_days} days each (with {overlap_days}-day overlap)" - ) - else: - num_chunks = None - logger.info( - f" -> CHUNKED INITIAL LOAD: iterating backwards in {chunk_days}-day chunks " - f"until no more data (with {overlap_days}-day overlap)" - ) - - self._cleanup_staging() - - dtypes = self.keboola_client.get_pandas_dtypes(table_config.id) - date_columns = self.keboola_client.get_date_columns(table_config.id) - pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id) - - all_partitions_updated = set() - chunk_idx = 0 - consecutive_empty_chunks = 0 - - while True: - if chunk_idx >= max_chunks_safety: - logger.warning(f" -> Reached safety limit of {max_chunks_safety} chunks, stopping") - break - - if num_chunks is not None and chunk_idx >= num_chunks: - break - - chunk_end_offset = chunk_idx * chunk_days - chunk_start_offset = chunk_end_offset + chunk_days + overlap_days - - chunk_end = now - timedelta(days=chunk_end_offset) if chunk_idx > 0 else None - chunk_start = now - timedelta(days=chunk_start_offset) - - if max_history_days and chunk_start_offset > max_history_days: - chunk_start = now - timedelta(days=max_history_days) - - chunk_label = f"{chunk_idx + 1}" if num_chunks is None else f"{chunk_idx + 1}/{num_chunks}" - logger.info( - f" -> Chunk {chunk_label}: " - f"{chunk_start.strftime('%Y-%m-%d')} to " - f"{chunk_end.strftime('%Y-%m-%d') if chunk_end else 'now'}" - ) - - with tempfile.NamedTemporaryFile( - mode="w", suffix=".csv", delete=False, dir=staging_dir - ) as tmp_file: - tmp_csv_path = Path(tmp_file.name) - - try: - export_info = self.keboola_client.export_table( - table_id=table_config.id, - output_path=tmp_csv_path, - changed_since=chunk_start.isoformat(), - changed_until=chunk_end.isoformat() if chunk_end else None, - ) - - if export_info["exported_rows"] == 0: - logger.info(" -> No data in this chunk") - consecutive_empty_chunks += 1 - if num_chunks is None and consecutive_empty_chunks >= 2: - logger.info( - f" -> Found {consecutive_empty_chunks} consecutive empty chunks, " - f"assuming end of history" - ) - break - chunk_idx += 1 - continue - - consecutive_empty_chunks = 0 - logger.info(f" -> Exported {export_info['exported_rows']} rows") - - partitions_updated = self._process_csv_to_partitions( - tmp_csv_path, table_config, partition_dir, - dtypes=dtypes, date_columns=date_columns, pyarrow_schema=pyarrow_schema, - ) - all_partitions_updated.update(partitions_updated) - - logger.info(f" -> Processed into {len(partitions_updated)} partitions") - - finally: - if tmp_csv_path.exists(): - tmp_csv_path.unlink() - - chunk_idx += 1 - - if all_partitions_updated: - logger.info( - f" -> Final deduplication of {len(all_partitions_updated)} partitions " - f"(removing duplicates from {overlap_days}-day overlaps)..." - ) - self._deduplicate_partitions( - table_config, all_partitions_updated, - date_columns=date_columns, pyarrow_schema=pyarrow_schema, - ) - - logger.info( - f" -> Chunked initial load complete: {len(all_partitions_updated)} partitions" - ) - - return self._get_partition_totals(partition_dir) - - def _process_csv_to_partitions( - self, - csv_path: Path, - table_config: TableConfig, - partition_dir: Path, - dtypes: Optional[Dict[str, str]] = None, - date_columns: Optional[List[str]] = None, - pyarrow_schema: Optional[pa.Schema] = None, - ) -> set: - """ - Process CSV file and write to partition files. - - Returns: - Set of partition keys that were updated - """ - import pandas as pd - import pyarrow.parquet as pq - - partition_col = table_config.partition_by - granularity = table_config.partition_granularity or "month" - - partitions_updated = set() - chunk_size = 500000 # 500k rows per pandas chunk - - chunk_num = 0 - for chunk_df in pd.read_csv(csv_path, chunksize=chunk_size, dtype=str): - chunk_num += 1 - logger.debug(f" -> Processing pandas chunk {chunk_num} ({len(chunk_df)} rows)...") - - if partition_col not in chunk_df.columns: - raise ValueError(f"Partition column '{partition_col}' not found in data") - - # Apply dtypes using _convert_column (except datetime columns) - if dtypes: - for col, dtype in dtypes.items(): - if col in chunk_df.columns and "datetime" not in dtype: - try: - chunk_df[col] = _convert_column(chunk_df[col], dtype, col_name=col) - except Exception as e: - logger.warning(f"Failed to apply dtype {dtype} to column {col}: {e}") - - # Convert partition column to datetime - if not pd.api.types.is_datetime64_any_dtype(chunk_df[partition_col]): - chunk_df[partition_col] = pd.to_datetime( - chunk_df[partition_col], format="ISO8601", utc=True - ) - - # Create partition key based on granularity - if granularity == "month": - chunk_df["_partition_key"] = chunk_df[partition_col].dt.strftime("%Y_%m") - elif granularity == "day": - chunk_df["_partition_key"] = chunk_df[partition_col].dt.strftime("%Y_%m_%d") - elif granularity == "year": - chunk_df["_partition_key"] = chunk_df[partition_col].dt.strftime("%Y") - - # Group by partition and append to partition files - for partition_key, partition_df in chunk_df.groupby("_partition_key"): - partition_df = partition_df.drop(columns=["_partition_key"]) - partition_path = self.config.get_partition_path(table_config, partition_key) - partitions_updated.add(partition_key) - - if partition_path.exists(): - existing_df = pd.read_parquet(partition_path) - merged_df = pd.concat([existing_df, partition_df], ignore_index=True) - table = pa.Table.from_pandas(merged_df, preserve_index=False) - if date_columns: - table = convert_date_columns_to_date32(table, date_columns) - if pyarrow_schema: - table = apply_schema_to_table(table, pyarrow_schema) - pq.write_table(table, partition_path, compression="snappy") - else: - table = pa.Table.from_pandas(partition_df, preserve_index=False) - if date_columns: - table = convert_date_columns_to_date32(table, date_columns) - if pyarrow_schema: - table = apply_schema_to_table(table, pyarrow_schema) - pq.write_table(table, partition_path, compression="snappy") - - return partitions_updated - - def _deduplicate_partitions( - self, - table_config: TableConfig, - partitions_to_dedup: set, - date_columns: Optional[List[str]] = None, - pyarrow_schema: Optional[pa.Schema] = None, - ): - """ - Deduplicate partition files based on primary key. - """ - import pandas as pd - import pyarrow.parquet as pq - - primary_key_cols = table_config.get_primary_key_columns() - - logger.info(f" -> Deduplicating {len(partitions_to_dedup)} partitions...") - - for partition_key in sorted(partitions_to_dedup): - partition_path = self.config.get_partition_path(table_config, partition_key) - - if not partition_path.exists(): - continue - - df = pd.read_parquet(partition_path) - rows_before = len(df) - df = df.drop_duplicates(subset=primary_key_cols, keep="last") - rows_after = len(df) - - table = pa.Table.from_pandas(df, preserve_index=False) - if date_columns: - table = convert_date_columns_to_date32(table, date_columns) - if pyarrow_schema: - table = apply_schema_to_table(table, pyarrow_schema) - pq.write_table(table, partition_path, compression="snappy") - - if rows_before != rows_after: - logger.debug( - f" -> Partition {partition_key}: {rows_before} -> {rows_after} rows " - f"(removed {rows_before - rows_after} duplicates)" - ) - - def _get_partition_totals(self, partition_dir: Path) -> Dict[str, Any]: - """ - Calculate totals from all partition files in a directory. - """ - import pyarrow.parquet as pq - - total_rows = 0 - total_size = 0 - total_uncompressed = 0 - total_columns = 0 - - if not partition_dir.exists(): - return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0} - - all_partitions = list(partition_dir.glob("*.parquet")) - - for part_path in all_partitions: - try: - pf = pq.ParquetFile(part_path) - meta = pf.metadata - total_rows += meta.num_rows - total_size += part_path.stat().st_size - if total_columns == 0: - total_columns = len(pf.schema_arrow) - for rg_idx in range(meta.num_row_groups): - rg = meta.row_group(rg_idx) - for col_idx in range(rg.num_columns): - total_uncompressed += rg.column(col_idx).total_uncompressed_size - except Exception as e: - logger.warning(f" -> Skipping corrupt partition {part_path.name}: {e}") - - return { - "rows": total_rows, - "file_size_bytes": total_size, - "partitions": len(all_partitions), - "columns": total_columns, - "uncompressed_bytes": total_uncompressed, - } - - def _partitioned_sync(self, table_config: TableConfig) -> Dict[str, Any]: - """ - Partitioned sync strategy. - - Downloads data and splits into monthly (or other granularity) partitions. - Each partition is stored as separate Parquet file and merged independently. - """ - logger.info( - f"Partitioned sync: {table_config.name} " - f"(by {table_config.partition_by}, {table_config.partition_granularity})" - ) - - partition_dir = self.config.get_parquet_path(table_config) - staging_dir = self.config.get_staging_path() - - with tempfile.NamedTemporaryFile( - mode="w", suffix=".csv", delete=False, dir=staging_dir - ) as tmp_file: - tmp_csv_path = Path(tmp_file.name) - - try: - filters_desc = "" - if table_config.where_filters: - filters_desc = f" (filters: {len(table_config.where_filters)})" - logger.info(f" -> Exporting from Keboola...{filters_desc}") - export_info = self.keboola_client.export_table( - table_id=table_config.id, - output_path=tmp_csv_path, - where_filters=table_config.where_filters if table_config.where_filters else None, - ) - - if export_info["exported_rows"] == 0: - logger.info(" -> No data exported") - return self._get_partition_totals(partition_dir) - - dtypes = self.keboola_client.get_pandas_dtypes(table_config.id) - date_columns = self.keboola_client.get_date_columns(table_config.id) - pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id) - - logger.info(f" -> Processing CSV in chunks ({export_info['exported_rows']} rows)...") - - partitions_seen = self._process_csv_to_partitions( - tmp_csv_path, table_config, partition_dir, - dtypes=dtypes, date_columns=date_columns, pyarrow_schema=pyarrow_schema, - ) - - self._deduplicate_partitions( - table_config, partitions_seen, - date_columns=date_columns, pyarrow_schema=pyarrow_schema, - ) - - totals = self._get_partition_totals(partition_dir) - logger.info( - f" -> Partitioned sync complete: {totals.get('partitions', 0)} partitions on disk, " - f"{totals['rows']} total rows" - ) - - return totals - - finally: - if tmp_csv_path.exists(): - tmp_csv_path.unlink() diff --git a/connectors/keboola/client.py b/connectors/keboola/client.py index 1c1b98e..e1bdf7d 100644 --- a/connectors/keboola/client.py +++ b/connectors/keboola/client.py @@ -23,7 +23,16 @@ import requests from kbcstorage.client import Client from kbcstorage.tables import Tables -from src.config import get_config, TableConfig, WhereFilter +from dataclasses import dataclass, field +from typing import Optional as _Opt + + +@dataclass +class WhereFilter: + """Keboola where filter for export operations.""" + column: str + operator: str = "eq" + values: list = field(default_factory=list) logger = logging.getLogger(__name__) @@ -86,10 +95,8 @@ class KeboolaClient: token: Storage API token. If None, loads from configuration. url: Stack URL. If None, loads from configuration. """ - config = get_config() - - self.token = token or config.keboola_token - self.url = url or config.keboola_stack_url + self.token = token or os.environ.get("KEBOOLA_STORAGE_TOKEN", "") + self.url = url or os.environ.get("KEBOOLA_STACK_URL", "") if not self.token: raise ValueError( @@ -824,7 +831,7 @@ def create_client() -> KeboolaClient: """ Factory function to create Keboola client. - Uses configuration from get_config(). + Uses environment variables for token and URL. Returns: KeboolaClient instance @@ -848,21 +855,21 @@ if __name__ == "__main__": print(" ❌ Connection failed!") exit(1) - # Test metadata + # Test metadata with discovered tables print("\n2️⃣ Testing metadata...") - config = get_config() - if config.tables: - test_table = config.tables[0] - print(f" Testing table: {test_table.id}") + tables = client.discover_all_tables() + if tables: + test_table_id = tables[0].get("id", tables[0].get("name", "")) + print(f" Testing table: {test_table_id}") - metadata = client.get_table_metadata(test_table.id) + metadata = client.get_table_metadata(test_table_id) print(f" ✅ Metadata loaded:") print(f" Columns: {len(metadata.get('columns', []))}") print(f" Rows: {metadata.get('row_count', 0):,}") print(f" Size: {metadata.get('data_size_bytes', 0) / 1024 / 1024:.2f} MB") # Test dtypes - dtypes = client.get_pandas_dtypes(test_table.id) + dtypes = client.get_pandas_dtypes(test_table_id) print(f" Pandas dtypes:") for col, dtype in list(dtypes.items())[:5]: print(f" {col}: {dtype}") diff --git a/connectors/keboola/tests/test_adapter.py b/connectors/keboola/tests/test_adapter.py deleted file mode 100644 index 9cd0795..0000000 --- a/connectors/keboola/tests/test_adapter.py +++ /dev/null @@ -1,194 +0,0 @@ -"""Tests for Keboola adapter and DataSource ABC / factory in src.data_sync. - -Covers: -- DataSource ABC default method behaviour -- create_data_source factory: keboola import error, unknown source, dynamic lookup -- KeboolaDataSource env var validation -""" - -from unittest.mock import patch, MagicMock - -import pytest - -from src.data_sync import DataSource, SyncState, create_data_source - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -class _MinimalSource(DataSource): - """Concrete DataSource that only implements the required abstract method.""" - - def sync_table(self, table_config, sync_state): - return {"success": True, "rows": 0} - - -# --------------------------------------------------------------------------- -# 1. DataSource ABC default methods -# --------------------------------------------------------------------------- - -class TestDataSourceABCDefaultMethods: - """Verify that the optional methods on the DataSource ABC return sensible defaults.""" - - def test_get_column_metadata_returns_none(self): - source = _MinimalSource() - assert source.get_column_metadata("any.table.id") is None - - def test_get_source_name_returns_unknown(self): - source = _MinimalSource() - assert source.get_source_name() == "Unknown" - - -# --------------------------------------------------------------------------- -# 2. Factory: keboola without kbcstorage -# --------------------------------------------------------------------------- - -class TestFactoryKeboolaWithoutKbcstorage: - """create_data_source('keboola') must raise ImportError when kbcstorage is missing.""" - - def test_raises_import_error(self): - # Patch the import inside create_data_source so that importing - # connectors.keboola.adapter triggers a ModuleNotFoundError - # mentioning kbcstorage (simulates the package not being installed). - original_import = __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ - - def _fake_import(name, *args, **kwargs): - if name == "connectors.keboola.adapter": - raise ModuleNotFoundError("No module named 'kbcstorage'") - return original_import(name, *args, **kwargs) - - with patch("builtins.__import__", side_effect=_fake_import): - with pytest.raises(ImportError, match="kbcstorage"): - create_data_source("keboola") - - -# --------------------------------------------------------------------------- -# 3. Factory: unknown source type -# --------------------------------------------------------------------------- - -class TestFactoryUnknownSource: - """create_data_source with a non-existent source type must raise ValueError.""" - - def test_raises_value_error(self): - with pytest.raises(ValueError, match="Unknown data source.*nonexistent"): - create_data_source("nonexistent") - - def test_error_message_contains_guidance(self): - with pytest.raises(ValueError, match="connectors/nonexistent/adapter.py"): - create_data_source("nonexistent") - - -# --------------------------------------------------------------------------- -# 4. Factory: dynamic connector lookup -# --------------------------------------------------------------------------- - -class TestFactoryDynamicConnectorLookup: - """create_data_source attempts dynamic import for unknown connector types.""" - - def test_jira_lookup_falls_through_to_value_error(self): - """'jira' has no adapter.py exporting a DataSource, so the factory - should try importing connectors.jira.adapter, fail, and finally - raise ValueError.""" - with pytest.raises(ValueError, match="Unknown data source.*jira"): - create_data_source("jira") - - def test_dynamic_import_is_attempted(self): - """Verify that importlib.import_module is called with the expected - module path when the source type is not hard-coded.""" - with patch("src.data_sync.importlib.import_module", side_effect=ModuleNotFoundError) as mock_imp: - with pytest.raises(ValueError): - create_data_source("custom_source") - mock_imp.assert_called_once_with("connectors.custom_source.adapter") - - def test_dynamic_import_with_factory_function(self): - """If the dynamically loaded module exposes create_data_source(), - the factory should call it and return its result.""" - fake_source = _MinimalSource() - fake_module = MagicMock() - fake_module.create_data_source = MagicMock(return_value=fake_source) - - with patch("src.data_sync.importlib.import_module", return_value=fake_module): - result = create_data_source("my_connector") - - assert result is fake_source - fake_module.create_data_source.assert_called_once() - - def test_dynamic_import_with_datasource_subclass(self): - """If the dynamically loaded module has no factory but exposes a - DataSource subclass, the factory should instantiate it.""" - import types - - fake_module = types.ModuleType("connectors.my_connector.adapter") - fake_module.MyDataSource = _MinimalSource - - with patch("src.data_sync.importlib.import_module", return_value=fake_module): - result = create_data_source("my_connector") - - assert isinstance(result, _MinimalSource) - - -# --------------------------------------------------------------------------- -# 5. KeboolaDataSource validates env vars -# --------------------------------------------------------------------------- - -class TestKeboolaAdapterValidatesEnvVars: - """KeboolaDataSource.__init__ must raise ValueError when required - Keboola env vars are missing.""" - - def _make_mock_config(self, token="", stack_url="", project_id=""): - """Build a mock config with the given Keboola credential values.""" - cfg = MagicMock() - cfg.keboola_token = token - cfg.keboola_stack_url = stack_url - cfg.keboola_project_id = project_id - return cfg - - def test_all_missing(self): - mock_cfg = self._make_mock_config() - with patch("connectors.keboola.adapter.get_config", return_value=mock_cfg): - with pytest.raises(ValueError, match="KEBOOLA_STORAGE_TOKEN"): - from connectors.keboola.adapter import KeboolaDataSource - KeboolaDataSource() - - def test_token_missing(self): - mock_cfg = self._make_mock_config( - stack_url="https://connection.keboola.com", - project_id="12345", - ) - with patch("connectors.keboola.adapter.get_config", return_value=mock_cfg): - with pytest.raises(ValueError, match="KEBOOLA_STORAGE_TOKEN"): - from connectors.keboola.adapter import KeboolaDataSource - KeboolaDataSource() - - def test_stack_url_missing(self): - mock_cfg = self._make_mock_config( - token="my-token", - project_id="12345", - ) - with patch("connectors.keboola.adapter.get_config", return_value=mock_cfg): - with pytest.raises(ValueError, match="KEBOOLA_STACK_URL"): - from connectors.keboola.adapter import KeboolaDataSource - KeboolaDataSource() - - def test_project_id_missing(self): - mock_cfg = self._make_mock_config( - token="my-token", - stack_url="https://connection.keboola.com", - ) - with patch("connectors.keboola.adapter.get_config", return_value=mock_cfg): - with pytest.raises(ValueError, match="KEBOOLA_PROJECT_ID"): - from connectors.keboola.adapter import KeboolaDataSource - KeboolaDataSource() - - def test_error_lists_all_missing_vars(self): - """When multiple env vars are missing, all should appear in the error message.""" - mock_cfg = self._make_mock_config() - with patch("connectors.keboola.adapter.get_config", return_value=mock_cfg): - with pytest.raises(ValueError) as exc_info: - from connectors.keboola.adapter import KeboolaDataSource - KeboolaDataSource() - msg = str(exc_info.value) - assert "KEBOOLA_STORAGE_TOKEN" in msg - assert "KEBOOLA_STACK_URL" in msg - assert "KEBOOLA_PROJECT_ID" in msg diff --git a/connectors/openmetadata/enricher.py b/connectors/openmetadata/enricher.py index 5f01316..c550033 100644 --- a/connectors/openmetadata/enricher.py +++ b/connectors/openmetadata/enricher.py @@ -15,10 +15,23 @@ from datetime import datetime, timedelta from pathlib import Path from typing import Dict, List, Optional, Any -from src.config import TableConfig +from dataclasses import dataclass as _dataclass from .client import OpenMetadataClient +@_dataclass +class TableConfig: + """Minimal table config used by the enricher. + + Attributes expected by CatalogEnricher: id, name, and optional catalog_fqn. + Can be constructed from a dict, e.g. ``TableConfig(**row)`` where *row* + comes from ``TableRegistryRepository.get()``. + """ + id: str + name: str + catalog_fqn: str = "" + + logger = logging.getLogger(__name__) diff --git a/scripts/generate_sample_data.py b/scripts/generate_sample_data.py index 572079e..3ec0d60 100644 --- a/scripts/generate_sample_data.py +++ b/scripts/generate_sample_data.py @@ -905,35 +905,35 @@ class SampleDataGenerator: # ── Parquet conversion ───────────────────────────────────── def _convert_to_parquet(self, parquet_dir: Path) -> None: - """Convert generated CSVs to Parquet using project's ParquetManager.""" - # Ensure project root is importable (script may run from any cwd) - project_root = Path(__file__).resolve().parent.parent - if str(project_root) not in sys.path: - sys.path.insert(0, str(project_root)) - from src.parquet_manager import create_parquet_manager + """Convert generated CSVs to Parquet using DuckDB.""" + import duckdb - manager = create_parquet_manager() parquet_dir.mkdir(parents=True, exist_ok=True) logger.info(f" Converting to Parquet -> {parquet_dir}/") + conn = duckdb.connect() for csv_path in sorted(self.output_dir.glob("*.csv")): table_name = csv_path.stem - schema = TABLE_SCHEMAS.get(table_name, {}) parquet_path = parquet_dir / f"{table_name}.parquet" - result = manager.csv_to_parquet( - csv_path=csv_path, - parquet_path=parquet_path, - dtypes=schema.get("dtypes"), - parse_dates=schema.get("parse_dates"), - date_columns=schema.get("date_columns"), - table_id=f"sample.{table_name}", + conn.execute( + f"COPY (SELECT * FROM read_csv_auto('{csv_path}')) " + f"TO '{parquet_path}' (FORMAT PARQUET, COMPRESSION ZSTD)" ) + + # Report stats + row_count = conn.execute( + f"SELECT count(*) FROM '{parquet_path}'" + ).fetchone()[0] + parquet_size = parquet_path.stat().st_size + csv_size = csv_path.stat().st_size + ratio = csv_size / parquet_size if parquet_size > 0 else 0 logger.info( - f" {table_name}: {result['rows']:,} rows, " - f"{result['parquet_size_bytes'] / 1024:.0f} KB " - f"({result['compression_ratio']:.1f}x compression)" + f" {table_name}: {row_count:,} rows, " + f"{parquet_size / 1024:.0f} KB " + f"({ratio:.1f}x compression)" ) + conn.close() # ── Orchestration ────────────────────────────────────────── diff --git a/src/catalog_export.py b/src/catalog_export.py index e306c38..bbe4242 100644 --- a/src/catalog_export.py +++ b/src/catalog_export.py @@ -34,7 +34,8 @@ from connectors.openmetadata.transformer import ( sanitize_filename, table_to_yaml_dict, ) -from src.config import Config +from src.db import get_system_db +from src.repositories.table_registry import TableRegistryRepository # --------------------------------------------------------------------------- # Logging @@ -279,14 +280,14 @@ def _write_metrics_index( def export_tables( client: OpenMetadataClient, - config: Config, + tables: list[dict], docs_dir: Path, catalog_url: str, ) -> int: """ Export table metadata from OpenMetadata to YAML files. - For each table defined in data_description.md: + For each table in the registry: 1. Derives the OpenMetadata FQN 2. Fetches table metadata (columns, owners, tags, description) 3. Transforms to YAML dict @@ -294,7 +295,7 @@ def export_tables( Args: client: Initialized OpenMetadata API client - config: Application config with table definitions + tables: List of table dicts from TableRegistryRepository.list_all() docs_dir: Base docs directory catalog_url: Catalog URL for header comments @@ -307,10 +308,12 @@ def export_tables( written_files: set[Path] = set() count = 0 - for table_config in config.tables: + for tbl in tables: + table_id = tbl.get("id", "") + table_name = tbl.get("name", "") try: - # Derive FQN: explicit override or auto-derive - fqn = table_config.catalog_fqn or f"bigquery.{table_config.id}" + # Use explicit catalog_fqn if set, otherwise derive from table id + fqn = tbl.get("catalog_fqn") or f"bigquery.{table_id}" logger.debug(f"Fetching table metadata: {fqn}") raw_table = client.get_table(fqn) @@ -318,7 +321,7 @@ def export_tables( yaml_dict = table_to_yaml_dict(raw_table) # Write table YAML - file_path = tables_dir / f"{table_config.name}.yml" + file_path = tables_dir / f"{table_name}.yml" header = _yaml_header(catalog_url, fqn, entity_type="table") yaml_content = yaml.dump( yaml_dict, @@ -330,10 +333,10 @@ def export_tables( written_files.add(file_path) count += 1 - logger.info(f"Exported table: {table_config.name} ({len(yaml_dict.get('columns', []))} columns)") + logger.info(f"Exported table: {table_name} ({len(yaml_dict.get('columns', []))} columns)") except Exception as e: - logger.warning(f"Failed to export table {table_config.name}: {e}") + logger.warning(f"Failed to export table {table_name}: {e}") continue # Cleanup stale auto-generated table files @@ -443,10 +446,12 @@ def main() -> None: # Export tables try: - config = Config() - tables_count = export_tables(client, config, docs_dir, catalog_url) + conn = get_system_db() + repo = TableRegistryRepository(conn) + registered_tables = repo.list_all() + tables_count = export_tables(client, registered_tables, docs_dir, catalog_url) except Exception as e: - logger.warning(f"Table export skipped (config error): {e}") + logger.warning(f"Table export skipped (registry error): {e}") tables_count = 0 # Write sync state diff --git a/src/config.py b/src/config.py deleted file mode 100644 index f3422c9..0000000 --- a/src/config.py +++ /dev/null @@ -1,653 +0,0 @@ -""" -Configuration Management - -This module handles: -1. Loading environment variables from .env file -2. Parsing data_description.md (YAML blocks with table definitions) -3. Validating configuration -4. Providing structured configuration data for other modules - -SINGLE SOURCE OF TRUTH is data_description.md - it defines: -- List of tables to synchronize -- Sync strategies (full_refresh vs incremental) -- Primary keys and foreign keys -- Incremental columns and windows -""" - -import os -import re -import logging -from pathlib import Path -from typing import Dict, List, Optional, Any -from dataclasses import dataclass, field -from datetime import datetime, timedelta - -import yaml -from dotenv import load_dotenv - - -# Logging setup -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - - -@dataclass -class ForeignKey: - """ - Representation of foreign key relationship between tables. - - Attributes: - column: Column name in this table (e.g., "company_id") - references: Reference table and column (e.g., "company.id") - description: Relationship description - """ - column: str - references: str - description: Optional[str] = None - - -@dataclass -class WhereFilter: - """ - Filter for exporting subset of table data. - - Used with Keboola Storage API whereFilters parameter. - - Attributes: - column: Column name to filter on - operator: Comparison operator (eq, ne, gt, ge, lt, le) - values: List of values to compare against - """ - column: str - operator: str # eq, ne, gt, ge, lt, le - values: List[str] = field(default_factory=list) - - -@dataclass -class TableConfig: - """ - Configuration for a single table. - - Attributes: - id: Full table ID in Keboola (e.g., "in.c-sfdc.company") - name: Short table name (e.g., "company") - description: Table description - primary_key: Primary key column name (optional for remote-only tables) - sync_strategy: "full_refresh", "incremental", "partitioned", or "none" (for remote-only tables) - incremental_window_days: Number of days to backtrack for incremental sync - partition_by: Column name to partition by (for incremental/partitioned with partitions) - partition_granularity: Partition granularity: "month", "day", or "year" - foreign_keys: List of foreign key relationships - where_filters: List of filters to apply when exporting (for downloading subset of data) - folder: Override folder name (instead of bucket-level folder_mapping) - max_history_days: Max days of history for initial incremental load (None = download all) - dataset: Dataset group name for on-demand tables (e.g., "kbc_telemetry_expert") - initial_load_chunk_days: Chunk size in days for chunked initial load (default: 30) - sync_schedule: Schedule for automatic sync: "every 15m", "every 1h", "daily 05:00" (UTC) - profile_after_sync: Run profiler after sync (default True; disable for frequently synced tables) - """ - id: str - name: str - description: str - primary_key: str = "" # Optional for remote-only tables - sync_strategy: str = "none" # "full_refresh", "incremental", "partitioned", or "none" (remote-only) - incremental_window_days: Optional[int] = None - partition_by: Optional[str] = None - partition_granularity: Optional[str] = None # "month", "day", "year" - foreign_keys: List[ForeignKey] = field(default_factory=list) - where_filters: List[WhereFilter] = field(default_factory=list) - folder: Optional[str] = None - max_history_days: Optional[int] = None - dataset: Optional[str] = None - initial_load_chunk_days: int = 30 - incremental_column: Optional[str] = None # Column for timestamp-based incremental sync (BigQuery) - columns: Optional[List[str]] = None # Subset of columns to sync (None = all) - row_filter: Optional[str] = None # SQL WHERE clause for filtering (e.g., "event_date >= '2024-01-01'") - query_mode: str = "local" # "local" (Parquet) | "remote" (BQ direct) | "hybrid" (sync subset, query BQ) - bq_entity_type: str = "view" # "view" (Python BQ client) | "table" (DuckDB BQ extension) - partition_column_type: str = "TIMESTAMP" # BQ SQL type for partition column: "DATE", "TIMESTAMP", "DATETIME" - catalog_fqn: Optional[str] = None # Explicit OpenMetadata FQN override (auto-derived if not set) - sync_schedule: Optional[str] = None # Schedule: "every 15m", "every 1h", "daily 05:00" (UTC) - profile_after_sync: bool = True # Run profiler after sync (disable for frequently synced tables) - - def __post_init__(self): - """Validate configuration after initialization.""" - # Validate query_mode - valid_query_modes = ("local", "remote", "hybrid") - if self.query_mode not in valid_query_modes: - raise ValueError( - f"Invalid query_mode '{self.query_mode}' for table {self.id}. " - f"Allowed values: {', '.join(valid_query_modes)}" - ) - - # Validate bq_entity_type - valid_entity_types = ("view", "table") - if self.bq_entity_type not in valid_entity_types: - raise ValueError( - f"Invalid bq_entity_type '{self.bq_entity_type}' for table {self.id}. " - f"Allowed values: {', '.join(valid_entity_types)}" - ) - - # Validate sync_strategy - if self.sync_strategy not in ["full_refresh", "incremental", "partitioned", "none"]: - raise ValueError( - f"Invalid sync_strategy '{self.sync_strategy}' for table {self.id}. " - f"Allowed values: 'full_refresh', 'incremental', 'partitioned', 'none'" - ) - - # For incremental strategy: - # - changedSince is calculated from last sync timestamp (Keboola internal) - # - partition_by is optional - if set, output will be partitioned - if self.sync_strategy == "incremental": - if not self.incremental_window_days: - # Default 7 days if not specified - self.incremental_window_days = 7 - logger.warning( - f"Table {self.id}: incremental_window_days not set, " - f"using default 7 days" - ) - # If partition_by is set, validate partition_granularity - if self.partition_by: - if not self.partition_granularity: - self.partition_granularity = "month" - logger.info( - f"Table {self.id}: partition_granularity not set, " - f"using default 'month'" - ) - if self.partition_granularity not in ["month", "day", "year"]: - raise ValueError( - f"Invalid partition_granularity '{self.partition_granularity}' for table {self.id}. " - f"Allowed values: 'month', 'day', 'year'" - ) - - # Validate partition_column_type - valid_column_types = ("DATE", "TIMESTAMP", "DATETIME") - if self.partition_column_type not in valid_column_types: - raise ValueError( - f"Invalid partition_column_type '{self.partition_column_type}' for table {self.id}. " - f"Allowed values: {', '.join(valid_column_types)}" - ) - - # Validate sync_schedule format - if self.sync_schedule: - import re as _re - valid_schedule = ( - _re.match(r"^every \d+[mh]$", self.sync_schedule) - or _re.match(r"^daily \d{2}:\d{2}(,\d{2}:\d{2})*$", self.sync_schedule) - ) - if not valid_schedule: - raise ValueError( - f"Invalid sync_schedule '{self.sync_schedule}' for table {self.id}. " - f"Allowed formats: 'every 15m', 'every 1h', 'daily 05:00', " - f"'daily 07:00,13:00,18:00'" - ) - - # For partitioned, partition_by must be defined - if self.sync_strategy == "partitioned": - if not self.partition_by: - raise ValueError( - f"Table {self.id} has sync_strategy='partitioned', " - f"but partition_by is missing" - ) - if not self.partition_granularity: - self.partition_granularity = "month" - logger.info( - f"Table {self.id}: partition_granularity not set, " - f"using default 'month'" - ) - if self.partition_granularity not in ["month", "day", "year"]: - raise ValueError( - f"Invalid partition_granularity '{self.partition_granularity}' for table {self.id}. " - f"Allowed values: 'month', 'day', 'year'" - ) - - def get_primary_key_columns(self) -> List[str]: - """ - Get primary key as list of column names. - - Supports both single and composite primary keys. - Composite PKs are defined as comma-separated string: "col1, col2" - - Returns: - List of column names forming the primary key - """ - # Split by comma and strip whitespace - return [col.strip() for col in self.primary_key.split(",")] - - def is_partitioned(self) -> bool: - """Check if table output should be partitioned. - - Returns True for: - - partitioned strategy (always partitioned) - - incremental strategy with partition_by set - """ - if self.sync_strategy == "partitioned": - return True - if self.sync_strategy == "incremental" and self.partition_by: - return True - return False - - -class Config: - """ - Main configuration class. - - Loads environment variables and parses data_description.md. - Provides access to all configuration parameters. - """ - - def __init__(self, env_file: Optional[str] = None): - """ - Initialize configuration. - - Args: - env_file: Path to .env file. If None, looks for .env in project root. - """ - # Find project root (folder containing data_description.md) - self.project_root = self._find_project_root() - - # Load environment variables - if env_file is None: - env_file = self.project_root / ".env" - - if env_file.exists(): - load_dotenv(env_file) - logger.info(f"Loaded from .env: {env_file}") - else: - logger.warning( - f".env file not found: {env_file}. " - f"Use config/.env.template as reference." - ) - - # Read by connectors/keboola/ if enabled - self.keboola_token = os.getenv("KEBOOLA_STORAGE_TOKEN") - self.keboola_stack_url = os.getenv("KEBOOLA_STACK_URL") - self.keboola_project_id = os.getenv("KEBOOLA_PROJECT_ID") - self.data_dir = Path(os.getenv("DATA_DIR", "./data")) - self.docs_output_dir = Path(os.getenv("DOCS_OUTPUT_DIR", "./docs")) - self.data_source = os.getenv("DATA_SOURCE", "local") - self.log_level = os.getenv("LOG_LEVEL", "INFO") - - # Set log level - logging.getLogger().setLevel(self.log_level) - - # Validate required environment variables - self._validate_env_vars() - - # Parse data_description.md - self.tables, self.folder_mapping = self._parse_data_description() - - logger.info(f"Configuration loaded: {len(self.tables)} tables") - - def _find_project_root(self) -> Path: - """ - Find project root (folder containing docs/data_description.md). - - Searches from current folder upwards. - - Returns: - Path to project root - - Raises: - FileNotFoundError: If docs/data_description.md is not found - """ - current = Path.cwd() - - # Try current folder first - if (current / "docs" / "data_description.md").exists(): - return current - - # Try parent folders (up to 5 levels) - for _ in range(5): - current = current.parent - if (current / "docs" / "data_description.md").exists(): - return current - - raise FileNotFoundError( - "docs/data_description.md not found. " - "Make sure you're running from project root." - ) - - def _resolve_placeholder(self, value: str) -> str: - """ - Resolve placeholders in filter values. - - Supported placeholders: - - {{last_week}}: 7 days ago - - {{last_month}}: 30 days ago - - {{last_2_months}}: 60 days ago - - {{last_3_months}}: 90 days ago - - {{last_6_months}}: 180 days ago - - {{last_year}}: 365 days ago - - {{last_2_years}}: 730 days ago - - {{today}}: Today's date - - Args: - value: String that may contain placeholder - - Returns: - Resolved string with actual date values - """ - if not isinstance(value, str): - return value - - today = datetime.now() - - placeholders = { - "{{last_week}}": (today - timedelta(days=7)).strftime("%Y-%m-%d"), - "{{last_month}}": (today - timedelta(days=30)).strftime("%Y-%m-%d"), - "{{last_2_months}}": (today - timedelta(days=60)).strftime("%Y-%m-%d"), - "{{last_3_months}}": (today - timedelta(days=90)).strftime("%Y-%m-%d"), - "{{last_6_months}}": (today - timedelta(days=180)).strftime("%Y-%m-%d"), - "{{last_year}}": (today - timedelta(days=365)).strftime("%Y-%m-%d"), - "{{last_2_years}}": (today - timedelta(days=730)).strftime("%Y-%m-%d"), - "{{today}}": today.strftime("%Y-%m-%d"), - } - - result = value - for placeholder, replacement in placeholders.items(): - if placeholder in result: - result = result.replace(placeholder, replacement) - logger.debug(f"Resolved placeholder: {placeholder} -> {replacement}") - - return result - - def _validate_env_vars(self): - """ - Validate that required environment variables are set based on data source type. - - Raises: - ValueError: If any required variable is missing - """ - # Keboola env vars are validated by connectors/keboola/adapter.py at init time. - # No source-specific validation needed here. - pass - - def _parse_data_description(self) -> tuple[List[TableConfig], Dict[str, str]]: - """ - Parse docs/data_description.md and extract table definitions. - - Looks for YAML blocks in markdown file and parses them. - - Returns: - Tuple of (List of TableConfig objects, folder_mapping dict) - - Raises: - FileNotFoundError: If docs/data_description.md doesn't exist - yaml.YAMLError: If YAML is invalid - """ - # Check CONFIG_DIR first, then project root - config_dir = Path(os.environ.get("CONFIG_DIR", "")) - if config_dir and (config_dir / "data_description.md").exists(): - data_desc_path = config_dir / "data_description.md" - else: - data_desc_path = self.project_root / "docs" / "data_description.md" - - if not data_desc_path.exists(): - raise FileNotFoundError( - f"docs/data_description.md not found: {data_desc_path}" - ) - - # Collect all markdown files to parse: main + dataset files - md_files = [data_desc_path] - datasets_dir = self.project_root / "docs" / "datasets" - if datasets_dir.exists(): - for md_file in sorted(datasets_dir.glob("*.md")): - md_files.append(md_file) - logger.info(f"Found dataset file: {md_file.name}") - - # Find YAML blocks (between ```yaml and ```) from all files - yaml_pattern = r'```yaml\n(.*?)```' - yaml_matches = [] - for md_file in md_files: - content = md_file.read_text() - yaml_matches.extend(re.findall(yaml_pattern, content, re.DOTALL)) - - if not yaml_matches: - raise ValueError( - "data_description.md contains no YAML blocks. " - "Make sure tables are defined in ```yaml blocks." - ) - - # Parse all YAML blocks and merge them - all_tables = [] - folder_mapping = {} - for yaml_block in yaml_matches: - try: - data = yaml.safe_load(yaml_block) - if data: - if "tables" in data: - all_tables.extend(data["tables"]) - if "folder_mapping" in data: - folder_mapping.update(data["folder_mapping"]) - except yaml.YAMLError as e: - logger.error(f"Error parsing YAML: {e}") - raise - - if not all_tables: - raise ValueError( - "data_description.md contains no tables. " - "Make sure YAML block contains 'tables:' key." - ) - - # Convert to TableConfig objects - table_configs = [] - for table_data in all_tables: - # Parse foreign keys - fk_list = [] - if "foreign_keys" in table_data: - for fk_data in table_data["foreign_keys"]: - fk = ForeignKey( - column=fk_data["column"], - references=fk_data["references"], - description=fk_data.get("description") - ) - fk_list.append(fk) - - # Parse where filters with placeholder resolution - wf_list = [] - if "where_filters" in table_data: - for wf_data in table_data["where_filters"]: - # Resolve placeholders in values - resolved_values = [ - self._resolve_placeholder(v) for v in wf_data.get("values", []) - ] - wf = WhereFilter( - column=wf_data["column"], - operator=wf_data["operator"], - values=resolved_values - ) - wf_list.append(wf) - - # Create TableConfig - config = TableConfig( - id=table_data["id"], - name=table_data["name"], - description=table_data["description"], - primary_key=table_data.get("primary_key", ""), - sync_strategy=table_data.get("sync_strategy", "none"), - incremental_window_days=table_data.get("incremental_window_days"), - partition_by=table_data.get("partition_by"), - partition_granularity=table_data.get("partition_granularity"), - foreign_keys=fk_list, - where_filters=wf_list, - folder=table_data.get("folder"), - max_history_days=table_data.get("max_history_days"), - dataset=table_data.get("dataset"), - initial_load_chunk_days=table_data.get("initial_load_chunk_days", 30), - incremental_column=table_data.get("incremental_column"), - columns=table_data.get("columns"), - row_filter=table_data.get("row_filter"), - query_mode=table_data.get("query_mode", "local"), - bq_entity_type=table_data.get("bq_entity_type", "view"), - partition_column_type=table_data.get("partition_column_type", "TIMESTAMP"), - catalog_fqn=table_data.get("catalog_fqn"), - sync_schedule=table_data.get("sync_schedule"), - profile_after_sync=table_data.get("profile_after_sync", True), - ) - table_configs.append(config) - - return table_configs, folder_mapping - - def get_table_config(self, table_id: str) -> Optional[TableConfig]: - """ - Get configuration for specific table by ID. - - Args: - table_id: Full table ID (e.g., "in.c-sfdc.company") - - Returns: - TableConfig or None if table not in configuration - """ - for table in self.tables: - if table.id == table_id: - return table - return None - - def get_parquet_path(self, table_config: TableConfig) -> Path: - """ - Get path to Parquet file for given table. - - Format: data/parquet/{folder_name}/{table_name}.parquet - For partitioned tables: data/parquet/{folder_name}/{table_name}/ (directory) - - Folder name is determined by folder_mapping in data_description.md. - Falls back to bucket name if no mapping exists. - - Args: - table_config: Table configuration - - Returns: - Path to Parquet file (or directory for partitioned tables) - """ - # Extract bucket name from table ID (e.g., "in.c-crm" from "in.c-crm.company") - bucket_name = ".".join(table_config.id.split(".")[:-1]) - - # Use folder mapping if available, otherwise fall back to bucket name - folder_name = self.folder_mapping.get(bucket_name, bucket_name) - - # Table-level folder override (e.g., folder: kbc_telemetry_expert) - if table_config.folder: - folder_name = table_config.folder - - parquet_dir = self.data_dir / "parquet" / folder_name - parquet_dir.mkdir(parents=True, exist_ok=True) - - if table_config.is_partitioned(): - # For partitioned tables, return directory path - partition_dir = parquet_dir / table_config.name - partition_dir.mkdir(parents=True, exist_ok=True) - return partition_dir - else: - return parquet_dir / f"{table_config.name}.parquet" - - def get_partition_path(self, table_config: TableConfig, partition_key: str) -> Path: - """ - Get path to specific partition file. - - Args: - table_config: Table configuration (must be partitioned) - partition_key: Partition key (e.g., "2026_01" for monthly) - - Returns: - Path to partition Parquet file - """ - if not table_config.is_partitioned(): - raise ValueError(f"Table {table_config.id} is not partitioned") - - partition_dir = self.get_parquet_path(table_config) - return partition_dir / f"{partition_key}.parquet" - - def get_metadata_path(self) -> Path: - """ - Get path to metadata folder. - - Returns: - Path to metadata folder - """ - metadata_dir = self.data_dir / "metadata" - metadata_dir.mkdir(parents=True, exist_ok=True) - return metadata_dir - - def get_staging_path(self) -> Path: - """ - Get path to staging folder for temporary files. - - Uses /tmp/data_analyst_staging for faster I/O and to avoid filling /data disk. - Directory is created by deploy.sh on server startup. - - Returns: - Path to staging folder - """ - staging_dir = Path("/tmp/data_analyst_staging") - staging_dir.mkdir(parents=True, exist_ok=True) - return staging_dir - - def get_duckdb_path(self) -> Path: - """ - Get path to DuckDB database. - - Returns: - Path to DuckDB file - """ - duckdb_dir = self.data_dir / "duckdb" - duckdb_dir.mkdir(parents=True, exist_ok=True) - return duckdb_dir / "analytics.duckdb" - - -# Singleton instance for easy access from entire application -_config_instance: Optional[Config] = None - - -def get_config() -> Config: - """ - Get singleton configuration instance. - - On first call initializes configuration, then returns existing instance. - - Returns: - Config instance - """ - global _config_instance - if _config_instance is None: - _config_instance = Config() - return _config_instance - - -# For testing - allows resetting config -def reset_config(): - """Reset singleton config instance. For testing only.""" - global _config_instance - _config_instance = None - - -if __name__ == "__main__": - # Test configuration - print("🔧 Testing configuration...") - - try: - config = get_config() - - print(f"\n✅ Configuration loaded successfully!") - print(f" Project ID: {config.keboola_project_id}") - print(f" Stack URL: {config.keboola_stack_url}") - print(f" Data dir: {config.data_dir}") - print(f" Number of tables: {len(config.tables)}") - - print(f"\n📊 Tables:") - for table in config.tables: - print(f" - {table.name} ({table.id})") - print(f" Strategy: {table.sync_strategy}") - if table.sync_strategy == "incremental": - print(f" Incremental window: {table.incremental_window_days} days") - if table.partition_by: - print(f" Partitioned by: {table.partition_by} ({table.partition_granularity})") - print(f" Parquet: {config.get_parquet_path(table)}") - - except Exception as e: - print(f"\n❌ Error: {e}") - import traceback - traceback.print_exc() diff --git a/src/data_sync.py b/src/data_sync.py deleted file mode 100644 index 34f97f0..0000000 --- a/src/data_sync.py +++ /dev/null @@ -1,734 +0,0 @@ -""" -Data Synchronization Manager - -Orchestrates data synchronization from configured sources to local Parquet files. - -Main functions: -1. Tracking sync state (when was last synchronization) -2. DataSource ABC for pluggable connectors -3. Sync single table or all tables at once -4. Progress tracking and error handling -5. Schema generation from synced Parquet files - -Sync State: -- Stored in data/metadata/sync_state.json -- Contains timestamp of last synchronization for each table -- Used for incremental sync (changedSince parameter) -""" - -import importlib -import json -import logging -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Dict, List, Optional, Any -from datetime import datetime - -from tqdm import tqdm - -from .config import get_config, TableConfig -from config.loader import load_instance_config -from connectors.openmetadata.enricher import CatalogEnricher - - -logger = logging.getLogger(__name__) - - -class SyncState: - """ - Synchronization state management. - - Stores and loads information about last synchronization of each table. - """ - - def __init__(self, state_file: Path): - """ - Args: - state_file: Path to JSON file with sync state - """ - self.state_file = state_file - self.state: Dict[str, Any] = self._load_state() - - def _load_state(self) -> Dict[str, Any]: - """ - Load sync state from disk. - - Returns: - Dictionary with sync state - """ - if self.state_file.exists(): - try: - with open(self.state_file, "r") as f: - return json.load(f) - except Exception as e: - logger.error(f"Error loading sync state: {e}") - return {} - return {} - - def _save_state(self): - """ - Save sync state to disk. - - Creates data/metadata/ directory if needed. - """ - try: - self.state_file.parent.mkdir(parents=True, exist_ok=True) - with open(self.state_file, "w") as f: - json.dump(self.state, f, indent=2, default=str) - except Exception as e: - logger.error(f"Error saving sync state: {e}") - - def get_last_sync(self, table_id: str) -> Optional[str]: - """ - Get timestamp of last synchronization for given table. - - Args: - table_id: Table identifier - - Returns: - ISO timestamp string, or None if not synced yet - """ - table_state = self.state.get(table_id, {}) - return table_state.get("last_sync") - - def get_table_state(self, table_id: str) -> Dict[str, Any]: - """ - Get complete sync state for a table. - - Args: - table_id: Table identifier - - Returns: - Dictionary with table sync state - """ - return self.state.get(table_id, {}) - - def update_sync( - self, - table_id: str, - table_name: str, - strategy: str, - rows: int, - file_size_bytes: int, - columns: int = 0, - uncompressed_bytes: int = 0, - ): - """ - Update synchronization state for a table. - - Args: - table_id: Table identifier - table_name: Human-readable table name - strategy: Sync strategy used - rows: Number of rows synced - file_size_bytes: Size of Parquet file in bytes - columns: Number of columns - uncompressed_bytes: Uncompressed data size - """ - self.state[table_id] = { - "table_name": table_name, - "last_sync": datetime.now().isoformat(), - "strategy": strategy, - "rows": rows, - "columns": columns, - "file_size_mb": round(file_size_bytes / 1024 / 1024, 2), - "uncompressed_mb": round(uncompressed_bytes / 1024 / 1024, 2), - } - - self._save_state() - - -class DataSource(ABC): - """ - Abstract class for data source. - - Connectors implement this to integrate different data backends. - See connectors/keboola/ for a reference implementation. - """ - - @abstractmethod - def sync_table( - self, - table_config: TableConfig, - sync_state: SyncState, - ) -> Dict[str, Any]: - """ - Synchronize single table. - - Args: - table_config: Table configuration - sync_state: Sync state manager - - Returns: - Dictionary with sync result - """ - pass - - def discover_tables(self) -> List[Dict[str, Any]]: - """List all available tables in the data source. - - Returns list of dicts with at minimum: - id, name, bucket_id, columns, row_count, size_bytes, - primary_key, last_change - Default: empty list (source doesn't support discovery). - """ - return [] - - def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]: - """Return processed column metadata for schema generation. - - Returns: - {"columns": {"col_name": {"source_type": "...", "description": "..."}}} - or None if the source doesn't support metadata. - """ - return None - - def get_source_name(self) -> str: - """Display name of this data source for schema comments.""" - return "Unknown" - - -def _get_uncompressed_size(parquet_path: Path) -> int: - """Read total uncompressed size from Parquet file metadata.""" - try: - import pyarrow.parquet as pq - - meta = pq.ParquetFile(parquet_path).metadata - total = 0 - for rg_idx in range(meta.num_row_groups): - rg = meta.row_group(rg_idx) - for col_idx in range(rg.num_columns): - total += rg.column(col_idx).total_uncompressed_size - return total - except Exception: - return 0 - - -class DataSyncManager: - """ - Main data synchronization orchestrator. - - Manages sync of all tables and tracks results. - """ - - def __init__(self): - """Initialize sync manager.""" - self.config = get_config() - self.sync_state = SyncState( - self.config.get_metadata_path() / "sync_state.json" - ) - self.data_source = create_data_source() - - # Initialize OpenMetadata catalog enricher - try: - instance_config = load_instance_config() - self.catalog_enricher = CatalogEnricher(instance_config) - except Exception as e: - logger.warning(f"Failed to initialize catalog enricher: {e}") - self.catalog_enricher = CatalogEnricher({}) # Disabled enricher - - def _generate_schema_yaml(self): - """ - Generate schema.yml file with actual table schemas from Parquet files. - - This file is auto-generated and contains: - - Table names and descriptions - - Column names, types (from Parquet), and descriptions (from source metadata) - - Primary keys - - Output: DOCS_OUTPUT_DIR/schema.yml (default: ./docs/schema.yml) - """ - import yaml - import pyarrow.parquet as pq - - source_name = self.data_source.get_source_name() - - logger.info("Generating schema.yml from synced tables...") - - schema_data = { - "_metadata": { - "_schema_version": 2, - "generated_at": datetime.now().isoformat(), - "note": "AUTO-GENERATED - DO NOT EDIT. This file contains actual table schemas from synced Parquet files.", - "source": source_name, - "generator": "src/data_sync.py::DataSyncManager._generate_schema_yaml()", - }, - "tables": {}, - } - - # Process each table in configuration - for table_config in self.config.tables: - try: - parquet_path = self.config.get_parquet_path(table_config) - - # Skip if Parquet doesn't exist (table not synced yet) - if table_config.partition_by: - if not parquet_path.exists() or not list(parquet_path.glob("*.parquet")): - logger.debug(f" Skipping {table_config.name} (not synced yet)") - continue - first_partition = next(parquet_path.glob("*.parquet")) - pf = pq.ParquetFile(first_partition) - else: - if not parquet_path.exists(): - logger.debug(f" Skipping {table_config.name} (not synced yet)") - continue - pf = pq.ParquetFile(parquet_path) - - arrow_schema = pf.schema_arrow - - # Get column metadata from data source (if supported) - col_metadata = self.data_source.get_column_metadata(table_config.id) - - # Enrich with catalog metadata (OpenMetadata) - catalog_data = self.catalog_enricher.enrich_table(table_config) - - # Extract column information - columns = [] - for field_item in arrow_schema: - col_name = field_item.name - col_name_lower = col_name.lower() - pyarrow_type = str(field_item.type) - - column_info = { - "name": col_name, - "type": pyarrow_type, - } - - # Priority for description: catalog > BQ API > (nothing) - description = None - if catalog_data and col_name_lower in catalog_data.columns: - description = catalog_data.columns[col_name_lower].description - elif col_metadata and "columns" in col_metadata: - col_meta = col_metadata["columns"].get(col_name, {}) - description = col_meta.get("description") - - if description: - column_info["description"] = description - - # Add source type from connector metadata - if col_metadata and "columns" in col_metadata: - col_meta = col_metadata["columns"].get(col_name, {}) - if "source_type" in col_meta: - column_info["source_type"] = col_meta["source_type"] - - columns.append(column_info) - - primary_key = table_config.get_primary_key_columns() - - # Priority for table description: catalog > data_description.md - table_description = table_config.description - if catalog_data: - table_description = catalog_data.description or table_description - - table_info = { - "table_id": table_config.id, - "description": table_description, - "primary_key": primary_key, - "sync_strategy": table_config.sync_strategy, - "columns": columns, - } - - if table_config.partition_by: - table_info["partitioned_by"] = table_config.partition_by - - schema_data["tables"][table_config.name] = table_info - - logger.debug(f" {table_config.name}: {len(columns)} columns") - - except Exception as e: - logger.warning(f" Error processing {table_config.name}: {e}") - - # Split tables into core (no dataset) and per-dataset groups - core_tables = {} - dataset_tables = {} # {dataset_name: {table_name: table_info}} - for table_name, table_info in schema_data["tables"].items(): - table_config = next( - (t for t in self.config.tables if t.name == table_name), None - ) - if table_config and table_config.dataset: - ds = table_config.dataset - if ds not in dataset_tables: - dataset_tables[ds] = {} - dataset_tables[ds][table_name] = table_info - else: - core_tables[table_name] = table_info - - generated_at = schema_data["_metadata"]["generated_at"] - - def _write_schema_file(filepath, tables, note=""): - """Write a schema YAML file with header comments.""" - filepath.parent.mkdir(parents=True, exist_ok=True) - data = { - "_metadata": { - "_schema_version": 2, - "generated_at": generated_at, - "note": "AUTO-GENERATED - DO NOT EDIT.", - "source": source_name, - "generator": "src/data_sync.py::DataSyncManager._generate_schema_yaml()", - }, - "tables": tables, - } - with open(filepath, "w") as f: - f.write("# AUTO-GENERATED - DO NOT EDIT\n") - f.write("# This file is automatically generated during data sync\n") - f.write(f"# Generated: {generated_at}\n") - if note: - f.write(f"# {note}\n") - f.write("#\n") - f.write("# Contains actual table schemas from synced Parquet files:\n") - f.write("# - Column names and PyArrow types (from Parquet)\n") - f.write(f"# - Source types and descriptions (from {source_name})\n") - f.write("# - Primary keys and sync strategies\n") - f.write("#\n") - f.write("# For architectural documentation and relationships, see data_description.md\n") - f.write("\n") - yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) - - # Write core schema.yml - schema_file = self.config.docs_output_dir / "schema.yml" - _write_schema_file(schema_file, core_tables) - logger.info(f"Core schema YAML: {len(core_tables)} tables -> {schema_file}") - - # Write per-dataset schema files - for ds_name, ds_tables in dataset_tables.items(): - ds_schema_file = self.config.docs_output_dir / "datasets" / ds_name / "schema.yml" - _write_schema_file(ds_schema_file, ds_tables, note=f"Dataset: {ds_name}") - logger.info(f"Dataset schema YAML: {len(ds_tables)} tables -> {ds_schema_file}") - - total = len(core_tables) + sum(len(t) for t in dataset_tables.values()) - logger.info(f"Schema generation complete: {total} tables total") - - return schema_file - - def sync_table(self, table_id: str) -> Dict[str, Any]: - """ - Synchronize single table by ID. - - Args: - table_id: Table ID to synchronize - - Returns: - Dictionary with sync result - """ - table_config = self.config.get_table_config(table_id) - if not table_config: - raise ValueError(f"Table {table_id} not found in configuration") - - return self.data_source.sync_table(table_config, self.sync_state) - - def sync_all(self, tables: Optional[List[str]] = None) -> Dict[str, Dict[str, Any]]: - """ - Synchronize all tables (or subset according to list). - - Args: - tables: List of table IDs to synchronize. If None, syncs all. - - Returns: - Dictionary {table_id: result} with sync results - """ - if tables: - table_configs = [ - self.config.get_table_config(tid) for tid in tables - ] - table_configs = [tc for tc in table_configs if tc is not None] - else: - table_configs = self.config.tables - - # Filter out remote-only tables (no local sync needed) - remote_skipped = [ - tc for tc in table_configs if tc.query_mode == "remote" - ] - table_configs = [ - tc for tc in table_configs if tc.query_mode != "remote" - ] - - if remote_skipped: - logger.info( - f"Skipping {len(remote_skipped)} remote-only tables " - f"(query via BigQuery): " - f"{', '.join(tc.name for tc in remote_skipped)}" - ) - - logger.info(f"Synchronizing {len(table_configs)} tables...") - - results = {} - with tqdm(table_configs, desc="Syncing tables") as pbar: - for table_config in pbar: - pbar.set_description(f"Sync: {table_config.name}") - - result = self.data_source.sync_table(table_config, self.sync_state) - results[table_config.id] = result - - if result["success"]: - pbar.write( - f" {table_config.name}: {result['rows']:,} rows, " - f"{result['file_size_mb']:.2f} MB" - ) - else: - pbar.write(f" {table_config.name}: {result['error']}") - - success_count = sum(1 for r in results.values() if r["success"]) - logger.info( - f"Synchronization completed: {success_count}/{len(results)} tables successful" - ) - - # Generate schema.yml from synced tables - if success_count > 0: - try: - self._generate_schema_yaml() - except Exception as e: - logger.warning(f"Failed to generate schema.yml: {e}") - - # Auto-profile changed tables - if success_count > 0: - self._auto_profile(results) - - return results - - def _auto_profile( - self, - results: Dict[str, Dict[str, Any]], - skip_tables: Optional[List[str]] = None, - ): - """Run profiler on successfully synced tables. - - Args: - results: Sync results dict {table_id: result} - skip_tables: Table IDs to skip profiling for - """ - skip_set = set(skip_tables or []) - try: - from src.profiler import profile_changed_tables - changed = [ - self.config.get_table_config(tid).name - for tid, r in results.items() - if r.get("success") - and self.config.get_table_config(tid) - and tid not in skip_set - ] - if changed: - result = profile_changed_tables(changed) - logger.info( - f"Auto-profiling: {result['success']} profiled, " - f"{result['errors']} errors, {result['skipped']} skipped" - ) - else: - logger.info("No tables to profile (all skipped or none succeeded)") - except Exception as e: - logger.warning(f"Auto-profiling failed (non-fatal): {e}") - - def sync_scheduled(self) -> Dict[str, Dict[str, Any]]: - """Synchronize only tables whose sync_schedule says they are due. - - Evaluates each table's sync_schedule against its last_sync timestamp. - Only syncs tables that are due. Respects profile_after_sync flag. - - Returns: - Dictionary {table_id: result} with sync results (only for synced tables) - """ - from src.scheduler import is_table_due - - scheduled_tables = [ - tc for tc in self.config.tables - if tc.sync_schedule and tc.query_mode != "remote" - ] - - if not scheduled_tables: - logger.info("No tables with sync_schedule configured") - return {} - - # Evaluate which tables are due - due_tables = [] - for tc in scheduled_tables: - last_sync = self.sync_state.get_last_sync(tc.id) - if is_table_due(tc.sync_schedule, last_sync): - due_tables.append(tc) - logger.info(f"Table {tc.name} is DUE (schedule: {tc.sync_schedule})") - else: - logger.debug(f"Table {tc.name} is not due (schedule: {tc.sync_schedule})") - - if not due_tables: - logger.info( - f"Checked {len(scheduled_tables)} scheduled tables, none are due" - ) - return {} - - logger.info( - f"Syncing {len(due_tables)}/{len(scheduled_tables)} due tables: " - f"{', '.join(tc.name for tc in due_tables)}" - ) - - # Sync due tables - results = {} - for table_config in due_tables: - try: - result = self.data_source.sync_table(table_config, self.sync_state) - results[table_config.id] = result - if result["success"]: - logger.info( - f" {table_config.name}: {result['rows']:,} rows, " - f"{result['file_size_mb']:.2f} MB" - ) - else: - logger.error(f" {table_config.name}: {result['error']}") - except Exception as e: - logger.error(f" {table_config.name}: sync failed: {e}") - results[table_config.id] = {"success": False, "error": str(e)} - - success_count = sum(1 for r in results.values() if r["success"]) - logger.info(f"Scheduled sync: {success_count}/{len(results)} tables successful") - - # Generate schema.yml - if success_count > 0: - try: - self._generate_schema_yaml() - except Exception as e: - logger.warning(f"Failed to generate schema.yml: {e}") - - # Profile only tables with profile_after_sync=True - skip_profiler = [ - tc.id for tc in due_tables if not tc.profile_after_sync - ] - if skip_profiler: - logger.info( - f"Skipping profiler for: " - f"{', '.join(self.config.get_table_config(tid).name for tid in skip_profiler)}" - ) - - profiled_any = False - if success_count > 0: - tables_to_profile = [ - tid for tid, r in results.items() - if r.get("success") and tid not in set(skip_profiler) - ] - if tables_to_profile: - self._auto_profile(results, skip_tables=skip_profiler) - profiled_any = True - - # Restart webapp if profiler ran (new profiles.json needs reload) - if profiled_any: - self._restart_webapp() - - return results - - def _restart_webapp(self): - """Restart webapp service to pick up new profiles.json.""" - import subprocess - try: - subprocess.run( - ["sudo", "systemctl", "restart", "webapp"], - check=True, - capture_output=True, - timeout=30, - ) - logger.info("Webapp restarted successfully") - except subprocess.CalledProcessError as e: - logger.warning(f"Failed to restart webapp: {e.stderr.decode() if e.stderr else e}") - except FileNotFoundError: - logger.debug("systemctl not found (not running on server)") - - -def create_sync_manager() -> DataSyncManager: - """ - Factory function to create DataSyncManager. - - Returns: - DataSyncManager instance - """ - return DataSyncManager() - - -def create_data_source(source_type: str = None) -> DataSource: - """Create a data source based on configuration. - - Args: - source_type: Override source type. If None, uses DATA_SOURCE env var. - - Returns: - DataSource instance - - Raises: - ValueError: If source type is unknown - ImportError: If connector dependencies are missing - """ - if source_type is None: - source_type = get_config().data_source - - if source_type in ("local", "keboola"): - try: - from connectors.keboola.adapter import KeboolaDataSource - except ModuleNotFoundError as e: - if "kbcstorage" in str(e): - raise ImportError( - "Keboola connector requires 'kbcstorage' package. " - "Install with: pip install kbcstorage" - ) from e - raise # Re-raise real import errors - return KeboolaDataSource() - - # Try dynamic connector import for other types - try: - mod = importlib.import_module(f"connectors.{source_type}.adapter") - factory = getattr(mod, "create_data_source", None) - if factory: - return factory() - # Fallback: look for a class named *DataSource - for attr_name in dir(mod): - attr = getattr(mod, attr_name) - if isinstance(attr, type) and issubclass(attr, DataSource) and attr is not DataSource: - return attr() - except ModuleNotFoundError: - pass - - raise ValueError( - f"Unknown data source: '{source_type}'. " - f"Available connectors: keboola, bigquery. " - f"Create connectors/{source_type}/adapter.py to add a new one." - ) - - -if __name__ == "__main__": - # CLI interface for sync - import sys - - scheduled_mode = "--scheduled" in sys.argv - table_args = [a for a in sys.argv[1:] if a != "--scheduled"] - - try: - manager = create_sync_manager() - - if scheduled_mode: - print("Data Sync (scheduled mode)") - results = manager.sync_scheduled() - - if not results: - print("No tables due for sync") - sys.exit(0) - elif table_args: - print("Data Sync") - print(f"\nSynchronizing selected tables: {', '.join(table_args)}") - results = manager.sync_all(tables=table_args) - else: - print("Data Sync") - print("\nSynchronizing all tables...") - results = manager.sync_all() - - success_count = sum(1 for r in results.values() if r["success"]) - total_count = len(results) - - if success_count == total_count: - print(f"\nAll {total_count} tables synchronized successfully!") - sys.exit(0) - else: - print( - f"\n{success_count}/{total_count} tables synchronized. " - f"Check logs for details." - ) - sys.exit(1) - - except Exception as e: - print(f"\nError: {e}") - import traceback - - traceback.print_exc() - sys.exit(1) diff --git a/src/parquet_manager.py b/src/parquet_manager.py deleted file mode 100644 index 9a45b61..0000000 --- a/src/parquet_manager.py +++ /dev/null @@ -1,755 +0,0 @@ -""" -Parquet File Manager - -Parquet file management: -1. CSV -> Parquet conversion with data type application -2. Compression (snappy) for space saving -3. Metadata embedding (table_id, export_date) -4. Information about existing Parquet files -5. Merge/upsert operations for incremental sync - -Parquet format advantages: -- Columnar storage -> faster analytical queries -- Compression -> smaller size than CSV -- Schema enforcement -> type safety -- Metadata support -> self-documenting -""" - -import logging -from pathlib import Path -from typing import Dict, Optional, List, Any -from datetime import datetime - -import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq - - -logger = logging.getLogger(__name__) - - -def convert_date_columns_to_date32(table: pa.Table, date_columns: List[str]) -> pa.Table: - """ - Convert timestamp/string columns to DATE32 type. - - Extracted from csv_to_parquet() for reuse in partitioned sync. - Invalid date values (like '0000-00-00') are converted to NULL, type stays DATE32. - - Args: - table: PyArrow table to convert - date_columns: List of column names to convert to DATE32 - - Returns: - PyArrow table with converted date columns - """ - if not date_columns: - return table - - schema_fields = [] - for i, field in enumerate(table.schema): - if field.name in date_columns: - schema_fields.append(pa.field(field.name, pa.date32())) - else: - schema_fields.append(field) - - # Cast columns to DATE32 - columns = [] - for i, field in enumerate(table.schema): - if field.name in date_columns: - col = table.column(i) - - # Skip if all nulls - nothing to convert, just cast type - if col.null_count == len(col): - columns.append(pa.nulls(len(col), type=pa.date32())) - continue - - # If column is string type, use pandas for robust parsing with errors='coerce' - # This converts invalid dates to NaT (NULL) while keeping the DATE type - if pa.types.is_string(col.type) or pa.types.is_large_string(col.type): - # Convert to pandas, parse dates with coerce (invalid -> NaT) - series = col.to_pandas() - parsed = pd.to_datetime(series, errors='coerce', format='mixed') - - # Count invalid values for logging - invalid_count = parsed.isna().sum() - series.isna().sum() - if invalid_count > 0: - # Find examples of invalid values - invalid_mask = series.notna() & parsed.isna() - examples = series[invalid_mask].head(3).tolist() - logger.warning( - f"Column '{field.name}': {invalid_count} invalid date values converted to NULL. " - f"Examples: {examples}" - ) - - # Convert to date only (remove time component) and then to PyArrow - date_series = parsed.dt.date - date_array = pa.array(date_series, type=pa.date32()) - columns.append(date_array) - else: - # Column is already timestamp/date type, just cast - try: - columns.append(col.cast(pa.date32())) - except Exception as e: - logger.warning( - f"Column '{field.name}': Failed to cast to date32, keeping original type. Error: {e}" - ) - columns.append(col) - schema_fields[i] = field - else: - columns.append(table.column(i)) - - # Rebuild table with new schema - new_schema = pa.schema(schema_fields, metadata=table.schema.metadata) - return pa.Table.from_arrays(columns, schema=new_schema) - - -def apply_schema_to_table(table: pa.Table, target_schema: pa.Schema) -> pa.Table: - """ - Apply target schema to PyArrow table, handling type mismatches gracefully. - - This function: - - Casts null-type columns to proper types (prevents DuckDB schema mismatch) - - Attempts safe casts for other type mismatches - - Keeps original type on cast failure (logs warning, no data loss) - - Preserves columns not in target schema with their inferred type - - Args: - table: PyArrow table to cast - target_schema: Target PyArrow schema - - Returns: - PyArrow table with applied schema - """ - if not target_schema: - return table - - # Build mapping of column names to target types - target_types = {field.name: field.type for field in target_schema} - - columns = [] - schema_fields = [] - - for i, field in enumerate(table.schema): - col = table.column(i) - col_name = field.name - - # Column not in target schema - keep as-is - if col_name not in target_types: - columns.append(col) - schema_fields.append(field) - continue - - target_type = target_types[col_name] - - # Case 1: Column has null type -> create typed null array - if pa.types.is_null(col.type): - columns.append(pa.nulls(len(col), type=target_type)) - schema_fields.append(pa.field(col_name, target_type)) - logger.debug(f"Column '{col_name}': converted null type to {target_type}") - continue - - # Case 2: Column type matches target -> keep as-is - if col.type == target_type: - columns.append(col) - schema_fields.append(pa.field(col_name, target_type)) - continue - - # Case 3: Type mismatch -> try safe cast - try: - casted = col.cast(target_type, safe=True) - columns.append(casted) - schema_fields.append(pa.field(col_name, target_type)) - logger.debug(f"Column '{col_name}': cast from {col.type} to {target_type}") - except Exception as e: - # Cast failed -> keep original type, log warning - logger.warning( - f"Column '{col_name}': cannot cast from {col.type} to {target_type}, " - f"keeping original type. Error: {e}" - ) - columns.append(col) - schema_fields.append(field) - - # Rebuild table with new schema - new_schema = pa.schema(schema_fields, metadata=table.schema.metadata) - return pa.Table.from_arrays(columns, schema=new_schema) - - -def _convert_column(series: pd.Series, dtype: str, col_name: str = "") -> pd.Series: - """ - Convert pandas Series to dtype, handling empty strings. - - Empty strings become NA/NaN for non-string types. - Logs warning if invalid (non-empty) values are found. - - Args: - series: Input pandas Series - dtype: Target dtype (e.g., "Int64", "float64", "boolean") - col_name: Column name for logging - - Returns: - Converted Series - """ - # Replace empty strings with NA for non-string types - if dtype != "object": - series = series.replace('', pd.NA) - - # Numeric types - use errors='coerce' but log invalid values - if dtype in ("Int64", "float64", "Float64"): - # Count non-null values before conversion - non_null_before = series.notna().sum() - - converted = pd.to_numeric(series, errors='coerce') - - # Count how many became NA after conversion (excluding already NA) - non_null_after = converted.notna().sum() - invalid_count = non_null_before - non_null_after - - if invalid_count > 0: - # Find examples of invalid values - invalid_mask = series.notna() & converted.isna() - examples = series[invalid_mask].head(3).tolist() - logger.warning( - f"Column '{col_name}': {invalid_count} invalid values converted to NULL. " - f"Examples: {examples}" - ) - - return converted.astype(dtype) - - # Boolean type - map string representations - if dtype == "boolean": - # If pandas already parsed as bool, just convert to nullable boolean - if series.dtype == bool or series.dtype == 'object' and series.dropna().apply(lambda x: isinstance(x, bool)).all(): - return series.astype(dtype) - - bool_map = { - 'true': True, 'false': False, - 'True': True, 'False': False, - 'TRUE': True, 'FALSE': False, - '1': True, '0': False, - 'yes': True, 'no': False, - 'Yes': True, 'No': False, - 'YES': True, 'NO': False, - } - # Log unknown values (non-empty strings that aren't in bool_map) - known_values = set(bool_map.keys()) - non_na_values = series.dropna() - unknown = non_na_values[~non_na_values.isin(known_values)] - if len(unknown) > 0: - examples = unknown.head(3).tolist() - logger.warning( - f"Column '{col_name}': {len(unknown)} unknown boolean values converted to NULL. " - f"Examples: {examples}" - ) - return series.map(bool_map).astype(dtype) - - # Default: direct conversion - return series.astype(dtype) - - -class ParquetManager: - """ - Parquet file manager. - - Provides methods for: - - CSV -> Parquet conversion - - Getting information about Parquet files - - Merge/upsert for incremental sync - """ - - def __init__(self): - """Initialize Parquet manager.""" - # Compression codec - snappy is fast and has good compression ratio - self.compression = "snappy" - - def csv_to_parquet( - self, - csv_path: Path, - parquet_path: Path, - dtypes: Optional[Dict[str, str]] = None, - table_id: Optional[str] = None, - parse_dates: Optional[List[str]] = None, - date_columns: Optional[List[str]] = None, - pyarrow_schema: Optional[pa.Schema] = None - ) -> Dict[str, Any]: - """ - Convert CSV file to Parquet format. - - Args: - csv_path: Path to source CSV file - parquet_path: Path where to save Parquet file - dtypes: Dictionary with data types for columns (pandas dtypes) - table_id: Table ID for metadata - parse_dates: List of columns with dates/timestamps to parse - date_columns: List of DATE-only columns (without time) to convert to PyArrow DATE32 - pyarrow_schema: Optional PyArrow schema to enforce (prevents null-type columns) - - Returns: - Dictionary with conversion information: - - rows: Number of rows - - columns: Number of columns - - file_size_bytes: Parquet file size - - compression_ratio: Compression ratio (CSV size / Parquet size) - - Raises: - Exception: If conversion fails - """ - logger.info(f"Converting CSV -> Parquet: {csv_path.name}") - - try: - # Load CSV into pandas DataFrame - # IMPORTANT: Use dtype=str to prevent pandas from guessing types - # We apply our own types from Keboola metadata using _convert_column - read_kwargs = {"dtype": str} - - # Get actual column names from CSV header first - with open(csv_path, 'r') as f: - header_line = f.readline().strip() - csv_columns = set(col.strip('"') for col in header_line.split(',')) - - # Parse datetime columns - only those that exist in CSV - if parse_dates: - valid_parse_dates = [col for col in parse_dates if col in csv_columns] - if valid_parse_dates: - read_kwargs["parse_dates"] = valid_parse_dates - elif dtypes: - # Auto-detect datetime columns from dtypes (only existing columns) - datetime_cols = [ - col for col, dtype in dtypes.items() - if "datetime" in dtype and col in csv_columns - ] - if datetime_cols: - read_kwargs["parse_dates"] = datetime_cols - - df = pd.read_csv(csv_path, **read_kwargs) - - logger.debug(f"CSV loaded: {len(df)} rows, {len(df.columns)} columns") - - # Apply dtypes using _convert_column to handle empty strings - if dtypes: - for col, dtype in dtypes.items(): - if col in df.columns and "datetime" not in dtype: - try: - df[col] = _convert_column(df[col], dtype, col_name=col) - except Exception as e: - logger.warning( - f"Failed to apply dtype {dtype} to column {col}: {e}" - ) - - # Add metadata as custom schema metadata - metadata = { - "created_at": datetime.now().isoformat(), - } - if table_id: - metadata["table_id"] = table_id - - # Convert to PyArrow Table - # PyArrow preserves pandas dtypes and adds metadata - table = pa.Table.from_pandas(df) - - # Convert DATE columns from timestamp/string to DATE32 - if date_columns: - table = convert_date_columns_to_date32(table, date_columns) - - # Apply explicit schema (prevents null-type columns) - if pyarrow_schema: - table = apply_schema_to_table(table, pyarrow_schema) - - # Add custom metadata - existing_metadata = table.schema.metadata or {} - new_metadata = { - **existing_metadata, - **{k.encode(): v.encode() for k, v in metadata.items()} - } - table = table.replace_schema_metadata(new_metadata) - - # Ensure output folder exists - parquet_path.parent.mkdir(parents=True, exist_ok=True) - - # Write to Parquet - pq.write_table( - table, - parquet_path, - compression=self.compression - ) - - # Get file sizes for compression ratio - csv_size = csv_path.stat().st_size - parquet_size = parquet_path.stat().st_size - compression_ratio = csv_size / parquet_size if parquet_size > 0 else 0 - - result = { - "rows": len(df), - "columns": len(df.columns), - "csv_size_bytes": csv_size, - "parquet_size_bytes": parquet_size, - "compression_ratio": compression_ratio, - "parquet_path": str(parquet_path) - } - - logger.info( - f"Parquet created: {len(df)} rows, " - f"{parquet_size / 1024 / 1024:.2f} MB, " - f"compression {compression_ratio:.2f}x" - ) - - return result - - except Exception as e: - logger.error(f"Error converting CSV -> Parquet: {e}") - raise - - def get_parquet_info(self, parquet_path: Path) -> Optional[Dict[str, Any]]: - """ - Get information about existing Parquet file. - - Args: - parquet_path: Path to Parquet file - - Returns: - Dictionary with information: - - rows: Number of rows - - columns: Number of columns - - file_size_bytes: File size - - modified_at: Last modification timestamp - - schema: PyArrow schema - - metadata: Custom metadata - Or None if file doesn't exist. - """ - if not parquet_path.exists(): - return None - - try: - # Load Parquet metadata (without loading data) - parquet_file = pq.ParquetFile(parquet_path) - - # Basic info - file_size = parquet_path.stat().st_size - modified_at = datetime.fromtimestamp(parquet_path.stat().st_mtime) - - # Schema and metadata - schema = parquet_file.schema_arrow - custom_metadata = {} - if schema.metadata: - custom_metadata = { - k.decode(): v.decode() - for k, v in schema.metadata.items() - if k.decode() not in ["pandas"] # Filter pandas internal metadata - } - - info = { - "rows": parquet_file.metadata.num_rows, - "columns": len(schema), - "file_size_bytes": file_size, - "modified_at": modified_at.isoformat(), - "schema": schema, - "metadata": custom_metadata, - "parquet_path": str(parquet_path) - } - - return info - - except Exception as e: - logger.error(f"Error reading Parquet info: {e}") - return None - - def merge_parquet( - self, - existing_parquet: Path, - new_csv: Path, - output_parquet: Path, - primary_key: List[str], - dtypes: Optional[Dict[str, str]] = None, - parse_dates: Optional[List[str]] = None, - date_columns: Optional[List[str]] = None, - pyarrow_schema: Optional[pa.Schema] = None - ) -> Dict[str, Any]: - """ - Merge new data from CSV into existing Parquet file. - - Performs upsert operation: - - Rows with existing primary_key are updated - - Rows with new primary_key are added - - Args: - existing_parquet: Path to existing Parquet file - new_csv: Path to CSV with new data - output_parquet: Path where to save resulting Parquet - primary_key: List of column names forming the primary key (supports composite PK) - dtypes: Dictionary with data types - parse_dates: List of datetime columns - date_columns: List of DATE-only columns to convert to PyArrow DATE32 - pyarrow_schema: Optional PyArrow schema to enforce (prevents null-type columns) - - Returns: - Dictionary with information: - - total_rows: Total number of rows after merge - - added_rows: Number of newly added rows - - updated_rows: Number of updated rows - - unchanged_rows: Number of unchanged rows - - Raises: - Exception: If merge fails - """ - pk_str = ", ".join(primary_key) - logger.info( - f"Merging Parquet: {existing_parquet.name} + {new_csv.name} -> {output_parquet.name} (PK: {pk_str})" - ) - - try: - # Load existing Parquet - existing_df = pd.read_parquet(existing_parquet) - original_count = len(existing_df) - - logger.debug(f"Existing data: {original_count} rows") - - # Load new data from CSV - # IMPORTANT: Use dtype=str to prevent pandas from guessing types - # We apply our own types from Keboola metadata using _convert_column - read_kwargs = {"dtype": str} - - if parse_dates: - read_kwargs["parse_dates"] = parse_dates - elif dtypes: - datetime_cols = [ - col for col, dtype in dtypes.items() - if "datetime" in dtype - ] - if datetime_cols: - read_kwargs["parse_dates"] = datetime_cols - - new_df = pd.read_csv(new_csv, **read_kwargs) - - # Apply dtypes using _convert_column to handle empty strings - if dtypes: - for col, dtype in dtypes.items(): - if col in new_df.columns and "datetime" not in dtype: - try: - new_df[col] = _convert_column(new_df[col], dtype, col_name=col) - except Exception as e: - logger.warning( - f"Failed to apply dtype {dtype} to column {col}: {e}" - ) - - new_count = len(new_df) - - logger.debug(f"New data: {new_count} rows") - - # Check that all primary_key columns exist in both dataframes - for pk_col in primary_key: - if pk_col not in existing_df.columns: - raise ValueError( - f"Primary key column '{pk_col}' not found in existing data" - ) - if pk_col not in new_df.columns: - raise ValueError( - f"Primary key column '{pk_col}' not found in new data" - ) - - # Perform upsert: concat and then drop_duplicates with keep='last' - # Keep='last' means that new data (which is second) will overwrite old - merged_df = pd.concat([existing_df, new_df], ignore_index=True) - merged_df = merged_df.drop_duplicates(subset=primary_key, keep='last') - - # Calculate statistics - final_count = len(merged_df) - added_rows = final_count - original_count - # Updated rows = rows that were in both datasets - updated_rows = new_count - added_rows if added_rows < new_count else 0 - unchanged_rows = original_count - updated_rows - - logger.info( - f"Merge completed: {final_count} total rows " - f"(+{added_rows} new, ~{updated_rows} updates)" - ) - - # Save as new Parquet - # Prepare metadata - metadata = { - "created_at": datetime.now().isoformat(), - "merged_from": new_csv.name - } - - table = pa.Table.from_pandas(merged_df) - - # Convert DATE columns from timestamp to DATE32 - if date_columns: - table = convert_date_columns_to_date32(table, date_columns) - - # Apply explicit schema (prevents null-type columns) - if pyarrow_schema: - table = apply_schema_to_table(table, pyarrow_schema) - - new_metadata = { - k.encode(): v.encode() for k, v in metadata.items() - } - table = table.replace_schema_metadata(new_metadata) - - # Ensure output folder exists - output_parquet.parent.mkdir(parents=True, exist_ok=True) - - # Write Parquet - pq.write_table(table, output_parquet, compression=self.compression) - - result = { - "total_rows": final_count, - "total_columns": len(merged_df.columns), - "added_rows": added_rows, - "updated_rows": updated_rows, - "unchanged_rows": unchanged_rows, - "parquet_path": str(output_parquet) - } - - return result - - except Exception as e: - logger.error(f"Error merging Parquet: {e}") - raise - - def validate_parquet(self, parquet_path: Path) -> bool: - """ - Validate that Parquet file is readable and not corrupted. - - Args: - parquet_path: Path to Parquet file - - Returns: - True if file is valid, False otherwise - """ - try: - # Try to load schema (fast operation) - parquet_file = pq.ParquetFile(parquet_path) - _ = parquet_file.schema_arrow - - # Try to load first row (data validation) - # Note: pyarrow 23.0 doesn't have nrows parameter, load full file then limit - df = pd.read_parquet(parquet_path) - _ = df.head(1) - - return True - except Exception as e: - logger.error(f"Parquet validation failed: {e}") - return False - - -def create_parquet_manager() -> ParquetManager: - """ - Factory function to create ParquetManager instance. - - Returns: - ParquetManager instance - """ - return ParquetManager() - - -if __name__ == "__main__": - # Test Parquet manager - print("📦 Testing Parquet manager...") - - import tempfile - - try: - manager = create_parquet_manager() - - # Create test CSV - with tempfile.TemporaryDirectory() as tmpdir: - tmpdir = Path(tmpdir) - - # Test data - test_csv = tmpdir / "test.csv" - test_csv.write_text( - "id,name,value,created_at\n" - "1,Alice,100,2026-01-01\n" - "2,Bob,200,2026-01-02\n" - "3,Charlie,300,2026-01-03\n" - ) - - test_parquet = tmpdir / "test.parquet" - - # Test 1: CSV -> Parquet - print("\n1️⃣ Testing CSV -> Parquet conversion...") - result = manager.csv_to_parquet( - csv_path=test_csv, - parquet_path=test_parquet, - dtypes={"id": "Int64", "name": "object", "value": "Int64"}, - parse_dates=["created_at"], - table_id="test.table" - ) - print(f" ✅ Conversion OK:") - print(f" Rows: {result['rows']}") - print(f" Compression: {result['compression_ratio']:.2f}x") - - # Test 2: Parquet info - print("\n2️⃣ Testing Parquet info...") - info = manager.get_parquet_info(test_parquet) - if info: - print(f" ✅ Info loaded:") - print(f" Rows: {info['rows']}") - print(f" Metadata: {info['metadata']}") - - # Test 3: Validation - print("\n3️⃣ Testing validation...") - if manager.validate_parquet(test_parquet): - print(" ✅ Parquet is valid!") - - # Test 4: Merge - print("\n4️⃣ Testing merge...") - # Create update CSV - update_csv = tmpdir / "update.csv" - update_csv.write_text( - "id,name,value,created_at\n" - "2,Bob Updated,250,2026-01-04\n" # Update - "4,David,400,2026-01-05\n" # New - ) - - merged_parquet = tmpdir / "merged.parquet" - merge_result = manager.merge_parquet( - existing_parquet=test_parquet, - new_csv=update_csv, - output_parquet=merged_parquet, - primary_key=["id"], # Now uses list - dtypes={"id": "Int64", "name": "object", "value": "Int64"}, - parse_dates=["created_at"] - ) - print(f" ✅ Merge OK:") - print(f" Total rows: {merge_result['total_rows']}") - print(f" Added: {merge_result['added_rows']}") - print(f" Updated: {merge_result['updated_rows']}") - - # Test 5: Empty string handling - print("\n5️⃣ Testing empty string handling...") - empty_csv = tmpdir / "empty_strings.csv" - empty_csv.write_text( - "id,is_active,revenue,note\n" - '1,true,100.5,text\n' - '2,,200.0,\n' # Empty boolean and string - '3,false,,note\n' # Empty float - '4,TRUE,N/A,\n' # Invalid float value - ) - - empty_parquet = tmpdir / "empty_strings.parquet" - result = manager.csv_to_parquet( - csv_path=empty_csv, - parquet_path=empty_parquet, - dtypes={ - "id": "Int64", - "is_active": "boolean", - "revenue": "float64", - "note": "object" - }, - table_id="test.empty_strings" - ) - print(f" ✅ Conversion with empty strings OK:") - print(f" Rows: {result['rows']}") - - # Verify the data - df = pd.read_parquet(empty_parquet) - print(f" Dtypes: {dict(df.dtypes)}") - print(f" is_active nulls: {df['is_active'].isna().sum()}") - print(f" revenue nulls: {df['revenue'].isna().sum()}") - - print("\n✅ All tests passed!") - - except Exception as e: - print(f"\n❌ Error: {e}") - import traceback - traceback.print_exc() diff --git a/src/remote_query.py b/src/remote_query.py deleted file mode 100644 index 7132d2b..0000000 --- a/src/remote_query.py +++ /dev/null @@ -1,636 +0,0 @@ -""" -Remote Query - Execute DuckDB queries spanning local Parquet + remote BigQuery tables. - -Provides a server-side CLI for the AI agent to run SQL queries that JOIN local -(Parquet-backed) tables with on-demand BigQuery results. Designed for tables too -large to sync locally (e.g., daily_deal_traffic: ~3M rows/day). - -Two-phase query protocol: -1. BQ sub-queries (--register-bq "alias=SQL") run on BigQuery, results registered - as DuckDB views via PyArrow (reuses register_bq_table from duckdb_manager). -2. DuckDB SQL (--sql) runs against local Parquet views + registered BQ results. - -Usage: - python -m src.remote_query \\ - --sql "SELECT ... FROM order_economics o JOIN traffic t ON ..." \\ - --register-bq "traffic=SELECT ... FROM \\`project.dataset.table\\` WHERE ..." \\ - --format table - -Safety features: -- COUNT(*) pre-check before fetching BQ data -- Memory estimation (refuses queries > 2 GB estimated) -- Configurable row limits (per BQ sub-query and final result) -- Query timeout support -""" - -import argparse -import csv -import io -import json -import logging -import os -import sys -import time -from pathlib import Path -from typing import Optional - -import duckdb - -from config.loader import get_instance_value -from scripts.duckdb_manager import ( - create_local_views, - register_bq_table, - _create_bq_client, -) - -logger = logging.getLogger(__name__) - - -class RemoteQueryError(Exception): - """Error during remote query execution.""" - - -# --------------------------------------------------------------------------- -# Configuration -# --------------------------------------------------------------------------- - -def _load_remote_query_config() -> dict: - """Load remote_query settings from instance.yaml with defaults. - - Uses raw YAML loading instead of load_instance_config() to avoid - requiring webapp secrets (WEBAPP_SECRET_KEY etc.) that analysts - don't have access to. - """ - import yaml as _yaml - from pathlib import Path as _Path - - instance_config: dict = {} - config_dir = _Path(os.environ.get("CONFIG_DIR", "./config")) - yaml_path = config_dir / "instance.yaml" - if yaml_path.exists(): - try: - with open(yaml_path) as f: - instance_config = _yaml.safe_load(f) or {} - except Exception as e: - logger.warning("Could not load instance.yaml: %s. Using defaults.", e) - - return { - "timeout_seconds": get_instance_value( - instance_config, "remote_query", "timeout_seconds", default=300, - ), - "max_result_rows": get_instance_value( - instance_config, "remote_query", "max_result_rows", default=100_000, - ), - "max_bq_registration_rows": get_instance_value( - instance_config, "remote_query", "max_bq_registration_rows", default=500_000, - ), - "default_format": get_instance_value( - instance_config, "remote_query", "default_format", default="table", - ), - "output_dir": get_instance_value( - instance_config, "remote_query", "output_dir", default="/tmp/remote_query", - ), - } - - -# --------------------------------------------------------------------------- -# BQ registration with safety checks -# --------------------------------------------------------------------------- - -def _validate_bq_result_size( - bq_client, sql: str, alias: str, max_rows: int, -) -> int: - """Execute COUNT(*) on the BQ sub-query before fetching all rows. - - Args: - bq_client: BigQuery client instance - sql: The BQ SQL query to count - alias: Alias name (for error messages) - max_rows: Maximum allowed rows - - Returns: - Row count - - Raises: - RemoteQueryError: If count exceeds max_rows - """ - count_sql = f"SELECT COUNT(*) AS cnt FROM ({sql})" - _log_progress(f" Counting rows for '{alias}'...") - - job = bq_client.query(count_sql) - result = job.result() - row_count = next(iter(result))[0] - - if row_count > max_rows: - raise RemoteQueryError( - f"BQ sub-query '{alias}' would return {row_count:,} rows " - f"(limit: {max_rows:,}). Add more WHERE filters or GROUP BY " - f"to reduce the result set." - ) - - return row_count - - -def _estimate_memory_mb(row_count: int, column_count: int) -> float: - """Estimate memory usage in MB for a PyArrow table. - - Uses ~50 bytes per cell as a rough average across data types. - """ - return (row_count * column_count * 50) / (1024 * 1024) - - -def _register_bq_views( - conn: duckdb.DuckDBPyConnection, - registrations: list[tuple[str, str]], - max_bq_rows: int, - timeout_seconds: int, - quiet: bool = False, -) -> dict[str, int]: - """Register BQ query results as DuckDB views with safety checks. - - Args: - conn: DuckDB connection - registrations: List of (alias, bq_sql) tuples - max_bq_rows: Max rows per sub-query - timeout_seconds: BQ job timeout - quiet: Suppress progress messages - - Returns: - Dict of {alias: row_count} - """ - if not registrations: - return {} - - bq_project = os.environ.get("BIGQUERY_PROJECT") - if not bq_project: - raise RemoteQueryError( - "BIGQUERY_PROJECT environment variable not set. " - "Required for BigQuery sub-queries." - ) - - bq_client = _create_bq_client(bq_project) - results = {} - - for alias, bq_sql in registrations: - start_time = time.time() - - # Phase 1: COUNT(*) safety check - row_count = _validate_bq_result_size(bq_client, bq_sql, alias, max_bq_rows) - _log_progress(f" '{alias}': {row_count:,} rows (within limit)") - - # Phase 2: Memory estimation - # Estimate column count from a LIMIT 0 query (cheap) - sample_job = bq_client.query(f"SELECT * FROM ({bq_sql}) LIMIT 0") - schema = sample_job.result().schema - col_count = len(schema) - estimated_mb = _estimate_memory_mb(row_count, col_count) - - if estimated_mb > 2048: # 2 GB = 25% of 8 GB server RAM - raise RemoteQueryError( - f"BQ sub-query '{alias}' estimated memory: {estimated_mb:.0f} MB " - f"({row_count:,} rows x {col_count} cols). " - f"Limit is 2048 MB. Add more aggregation or filters." - ) - - # Phase 3: Execute and register - _log_progress(f" Fetching '{alias}' ({row_count:,} rows, ~{estimated_mb:.0f} MB)...") - actual_rows = register_bq_table( - conn=conn, - table_id=f"bq_registration.{alias}", - view_name=alias, - sql=bq_sql, - bq_project=bq_project, - ) - - elapsed = time.time() - start_time - _log_progress(f" '{alias}' registered: {actual_rows:,} rows in {elapsed:.1f}s") - results[alias] = actual_rows - - return results - - -# --------------------------------------------------------------------------- -# Local view setup -# --------------------------------------------------------------------------- - -def _setup_local_views( - conn: duckdb.DuckDBPyConnection, - data_dir: str, - quiet: bool = False, -) -> list[str]: - """Create DuckDB views for all local/hybrid tables from Parquet. - - Args: - conn: DuckDB connection - data_dir: Path to data directory (e.g., "/data/src_data") - quiet: Suppress progress messages - - Returns: - List of created view names - """ - created, skipped = create_local_views( - conn=conn, - data_dir=data_dir, - verbose=not quiet, - ) - return created - - -# --------------------------------------------------------------------------- -# Output formatting -# --------------------------------------------------------------------------- - -def _print_table(columns: list[str], rows: list[tuple]) -> None: - """Print an aligned ASCII table to stdout.""" - if not rows: - print("(empty result)") - return - - # Calculate column widths - str_rows = [[str(v) if v is not None else "NULL" for v in row] for row in rows] - widths = [len(col) for col in columns] - for row in str_rows: - for i, val in enumerate(row): - widths[i] = max(widths[i], len(val)) - - # Header - header = " | ".join(col.ljust(widths[i]) for i, col in enumerate(columns)) - separator = "-+-".join("-" * widths[i] for i in range(len(columns))) - print(header) - print(separator) - - # Rows - for row in str_rows: - line = " | ".join(val.ljust(widths[i]) for i, val in enumerate(row)) - print(line) - - print(f"\n({len(rows)} rows)") - - -def _format_output( - conn: duckdb.DuckDBPyConnection, - sql: str, - fmt: str, - output_path: Optional[str], - max_rows: int, -) -> None: - """Execute final SQL and output results in the requested format. - - Args: - conn: DuckDB connection with all views registered - sql: The final DuckDB SQL query - fmt: Output format (table, csv, json, parquet) - output_path: File path for file-based outputs - max_rows: Maximum rows to return - """ - # Add LIMIT to prevent runaway results - limited_sql = f"SELECT * FROM ({sql}) AS _rq LIMIT {max_rows + 1}" - result = conn.execute(limited_sql) - columns = [desc[0] for desc in result.description] - rows = result.fetchall() - - # Check if result exceeded limit - if len(rows) > max_rows: - rows = rows[:max_rows] - _log_progress( - f" WARNING: Result truncated to {max_rows:,} rows. " - f"Add more filters or increase --max-rows." - ) - - if fmt == "table": - _print_table(columns, rows) - - elif fmt == "csv": - if output_path: - with open(output_path, "w", newline="") as f: - writer = csv.writer(f) - writer.writerow(columns) - writer.writerows(rows) - _log_progress(f" CSV written: {output_path} ({len(rows)} rows)") - else: - writer = csv.writer(sys.stdout) - writer.writerow(columns) - writer.writerows(rows) - - elif fmt == "json": - data = [dict(zip(columns, row)) for row in rows] - json_str = json.dumps(data, default=str, indent=2) - if output_path: - with open(output_path, "w") as f: - f.write(json_str) - _log_progress(f" JSON written: {output_path} ({len(rows)} rows)") - else: - print(json_str) - - elif fmt == "parquet": - import pyarrow as pa - import pyarrow.parquet as pq - - # Re-execute without limit wrapper for clean Arrow export - arrow_result = conn.execute( - f"SELECT * FROM ({sql}) AS _rq LIMIT {max_rows}" - ).arrow().read_all() - - if not output_path: - output_path = str(Path(_load_remote_query_config()["output_dir"]) / "result.parquet") - - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - pq.write_table(arrow_result, output_path) - _log_progress( - f" Parquet written: {output_path} " - f"({arrow_result.num_rows} rows, {arrow_result.num_columns} cols)" - ) - - else: - raise RemoteQueryError(f"Unknown format: {fmt}") - - -# --------------------------------------------------------------------------- -# Progress logging (stderr so stdout stays clean for data) -# --------------------------------------------------------------------------- - -_quiet_mode = False - - -def _log_progress(msg: str) -> None: - """Print progress message to stderr (keeps stdout clean for data output).""" - if not _quiet_mode: - print(msg, file=sys.stderr) - - -# --------------------------------------------------------------------------- -# Main execution -# --------------------------------------------------------------------------- - -def execute_remote_query( - sql: str, - bq_registrations: list[tuple[str, str]], - fmt: str = "table", - output: Optional[str] = None, - max_rows: Optional[int] = None, - max_bq_rows: Optional[int] = None, - timeout: Optional[int] = None, - data_dir: str = "/data/src_data", - quiet: bool = False, -) -> None: - """Main execution function for remote queries. - - Args: - sql: DuckDB SQL query to execute - bq_registrations: List of (alias, bq_sql) tuples - fmt: Output format (table, csv, json, parquet) - output: Output file path (for parquet/csv/json) - max_rows: Max rows in final result - max_bq_rows: Max rows per BQ sub-query - timeout: Query timeout in seconds - data_dir: Path to data directory - quiet: Suppress progress messages - """ - global _quiet_mode - _quiet_mode = quiet - - config = _load_remote_query_config() - max_rows = max_rows or config["max_result_rows"] - max_bq_rows = max_bq_rows or config["max_bq_registration_rows"] - timeout = timeout or config["timeout_seconds"] - fmt = fmt or config["default_format"] - - start_time = time.time() - - # Create in-memory DuckDB connection - conn = duckdb.connect(":memory:") - - try: - # Step 1: Register local Parquet views - _log_progress("Setting up local views...") - local_views = _setup_local_views(conn, data_dir, quiet=quiet) - _log_progress(f" {len(local_views)} local views ready") - - # Step 2: Register BQ sub-query results - if bq_registrations: - _log_progress(f"Registering {len(bq_registrations)} BQ sub-queries...") - bq_results = _register_bq_views( - conn, bq_registrations, max_bq_rows, timeout, quiet=quiet, - ) - for alias, count in bq_results.items(): - _log_progress(f" {alias}: {count:,} rows") - - # Step 3: Execute the final DuckDB query - _log_progress("Executing query...") - _format_output(conn, sql, fmt, output, max_rows) - - elapsed = time.time() - start_time - _log_progress(f"Done in {elapsed:.1f}s") - - finally: - conn.close() - - -# --------------------------------------------------------------------------- -# CLI argument parsing -# --------------------------------------------------------------------------- - -def _parse_register_bq(value: str) -> tuple[str, str]: - """Parse --register-bq argument in 'alias=SQL' format. - - Args: - value: String in format "alias=SELECT ..." - - Returns: - Tuple of (alias, sql) - - Raises: - argparse.ArgumentTypeError: If format is invalid - """ - eq_pos = value.find("=") - if eq_pos <= 0: - raise argparse.ArgumentTypeError( - f"Invalid --register-bq format: '{value}'. " - f"Expected: 'alias=SELECT ...' (e.g., 'traffic=SELECT report_date, ...')" - ) - alias = value[:eq_pos].strip() - sql = value[eq_pos + 1:].strip() - if not sql: - raise argparse.ArgumentTypeError( - f"Empty SQL in --register-bq for alias '{alias}'" - ) - return alias, sql - - -def build_parser() -> argparse.ArgumentParser: - """Build the argument parser for remote_query CLI.""" - parser = argparse.ArgumentParser( - description="Execute DuckDB queries spanning local Parquet + remote BigQuery tables", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Local-only query (no BigQuery): - python -m src.remote_query --sql "SELECT COUNT(*) FROM order_economics" - - # Register BQ result and query it: - python -m src.remote_query \\ - --register-bq "traffic=SELECT report_date, SUM(visitors) FROM \\`proj.ds.table\\` GROUP BY 1" \\ - --sql "SELECT * FROM traffic ORDER BY report_date" - - # JOIN local + remote: - python -m src.remote_query \\ - --register-bq "traffic=SELECT ... GROUP BY ..." \\ - --sql "SELECT o.*, t.visitors FROM order_economics o JOIN traffic t ON ..." \\ - --format parquet --output /tmp/result.parquet - """, - ) - parser.add_argument( - "--sql", - required=False, # not required when --stdin is used - default=None, - help="DuckDB SQL query (executed after all views are registered)", - ) - parser.add_argument( - "--register-bq", - action="append", - type=_parse_register_bq, - default=[], - metavar="ALIAS=SQL", - dest="bq_registrations", - help='Register BQ query result as DuckDB view. Format: "alias=BQ_SQL". Repeatable.', - ) - parser.add_argument( - "--format", - choices=["table", "csv", "json", "parquet"], - default=None, - dest="fmt", - help="Output format (default: from config or 'table')", - ) - parser.add_argument( - "--output", - default=None, - help="Output file path for parquet/csv/json (default: auto for parquet)", - ) - parser.add_argument( - "--max-rows", - type=int, - default=None, - help="Max rows in final result (default: from config)", - ) - parser.add_argument( - "--max-bq-rows", - type=int, - default=None, - help="Max rows per BQ sub-query (default: from config)", - ) - parser.add_argument( - "--timeout", - type=int, - default=None, - help="Query timeout in seconds (default: from config)", - ) - parser.add_argument( - "--data-dir", - default="/data/src_data", - help="Parquet data directory (default: /data/src_data)", - ) - parser.add_argument( - "--quiet", - action="store_true", - help="Suppress progress messages (stderr)", - ) - parser.add_argument( - "--stdin", - action="store_true", - help="Read query spec from stdin as JSON. Avoids shell escaping issues.", - ) - return parser - - -def _parse_stdin_query() -> dict: - """Parse query specification from stdin JSON. - - Expected format: - { - "sql": "SELECT ... FROM ...", - "register_bq": {"alias": "BQ SQL", ...}, - "format": "table", - "output": "/path/to/file", - "max_rows": 100000, - "max_bq_rows": 500000 - } - - Returns: - Dict with parsed query spec - """ - raw = sys.stdin.read().strip() - if not raw: - raise RemoteQueryError("Empty stdin. Provide JSON query spec.") - - try: - spec = json.loads(raw) - except json.JSONDecodeError as e: - raise RemoteQueryError(f"Invalid JSON on stdin: {e}") - - if "sql" not in spec: - raise RemoteQueryError("JSON must contain 'sql' field.") - - return spec - - -def main(argv: Optional[list[str]] = None) -> None: - """CLI entry point.""" - parser = build_parser() - args = parser.parse_args(argv) - - # Setup logging - logging.basicConfig( - level=logging.WARNING if args.quiet else logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - stream=sys.stderr, - ) - - try: - # --stdin mode: read query spec from JSON on stdin (no shell escaping needed) - if args.stdin: - spec = _parse_stdin_query() - bq_regs = [ - (alias, sql) for alias, sql in spec.get("register_bq", {}).items() - ] - execute_remote_query( - sql=spec["sql"], - bq_registrations=bq_regs, - fmt=spec.get("format", args.fmt), - output=spec.get("output", args.output), - max_rows=spec.get("max_rows", args.max_rows), - max_bq_rows=spec.get("max_bq_rows", args.max_bq_rows), - timeout=args.timeout, - data_dir=args.data_dir, - quiet=args.quiet, - ) - return - - # Validate --sql is provided when not using --stdin - if not args.sql: - parser.error("--sql is required (or use --stdin for JSON input)") - - execute_remote_query( - sql=args.sql, - bq_registrations=args.bq_registrations, - fmt=args.fmt, - output=args.output, - max_rows=args.max_rows, - max_bq_rows=args.max_bq_rows, - timeout=args.timeout, - data_dir=args.data_dir, - quiet=args.quiet, - ) - except RemoteQueryError as e: - print(f"ERROR: {e}", file=sys.stderr) - sys.exit(1) - except KeyboardInterrupt: - print("\nInterrupted.", file=sys.stderr) - sys.exit(130) - except Exception as e: - print(f"UNEXPECTED ERROR: {e}", file=sys.stderr) - logger.exception("Unexpected error in remote_query") - sys.exit(2) - - -if __name__ == "__main__": - main() diff --git a/src/table_registry.py b/src/table_registry.py deleted file mode 100644 index e8b84cf..0000000 --- a/src/table_registry.py +++ /dev/null @@ -1,464 +0,0 @@ -""" -Table Registry - Central source of truth for registered tables. - -Manages table registrations in a JSON file. Generates data_description.md -as a read-only output for downstream consumers (config.py, profiler.py, webapp). - -Supports: -- CRUD operations on registered tables -- Folder mapping (bucket -> folder name) -- Atomic persistence (tempfile + os.replace) -- Optimistic locking (version field) -- Audit logging -- One-time migration from existing data_description.md -- Generation of data_description.md with checksum header -""" - -import hashlib -import json -import logging -import os -import re -import tempfile -from datetime import datetime, timezone -from pathlib import Path -from typing import Any, Optional - -import yaml - -logger = logging.getLogger(__name__) - -# Default registry location -_DEFAULT_REGISTRY_DIR = Path( - os.environ.get("REGISTRY_DIR", "/data/src_data/metadata") -) -_REGISTRY_FILENAME = "table_registry.json" - - -def _now_iso() -> str: - """Return current UTC time as ISO string.""" - return datetime.now(timezone.utc).isoformat() - - -def _atomic_write_json(path: Path, data: dict) -> None: - """Write JSON atomically using tempfile + os.replace.""" - path.parent.mkdir(parents=True, exist_ok=True) - fd, tmp_path = tempfile.mkstemp( - dir=str(path.parent), suffix=".tmp" - ) - try: - with os.fdopen(fd, "w") as f: - json.dump(data, f, indent=2, default=str) - os.chmod(tmp_path, 0o660) - os.replace(tmp_path, str(path)) - except Exception: - try: - os.unlink(tmp_path) - except OSError: - pass - raise - - -def _audit_log(registry_path: Path, action: str, details: dict) -> None: - """Append entry to registry audit log.""" - audit_path = registry_path.parent / "registry_audit.log" - try: - entry = { - "timestamp": _now_iso(), - "action": action, - **details, - } - with open(audit_path, "a") as f: - f.write(json.dumps(entry, default=str) + "\n") - except Exception as e: - logger.warning(f"Could not write audit log: {e}") - - -class TableRegistry: - """Manages table registrations. Source of truth for what gets synced.""" - - def __init__(self, registry_path: Path): - self.registry_path = registry_path - self._data = self._load() - - @classmethod - def default(cls) -> "TableRegistry": - """Create registry at the default location.""" - return cls(_DEFAULT_REGISTRY_DIR / _REGISTRY_FILENAME) - - # ── Persistence ────────────────────────────────────────────────── - - def _load(self) -> dict: - """Load registry from disk. Returns empty structure if not found.""" - if self.registry_path.exists(): - try: - with open(self.registry_path) as f: - data = json.load(f) - logger.info( - f"Registry loaded: {len(data.get('tables', []))} tables" - ) - return data - except Exception as e: - logger.error(f"Error loading registry: {e}") - return self._empty_registry() - - def _save(self) -> None: - """Save registry to disk atomically.""" - self._data["_metadata"]["updated_at"] = _now_iso() - self._data["_metadata"]["version"] = self.version + 1 - _atomic_write_json(self.registry_path, self._data) - logger.debug("Registry saved (version %d)", self.version) - - @staticmethod - def _empty_registry() -> dict: - now = _now_iso() - return { - "_metadata": { - "version": 0, - "created_at": now, - "updated_at": now, - }, - "folder_mapping": {}, - "tables": [], - } - - # ── Properties ─────────────────────────────────────────────────── - - @property - def version(self) -> int: - return self._data.get("_metadata", {}).get("version", 0) - - # ── Core CRUD ──────────────────────────────────────────────────── - - def list_tables(self) -> list[dict]: - """Return all registered tables.""" - return list(self._data.get("tables", [])) - - def get_table(self, table_id: str) -> Optional[dict]: - """Get a single table by ID.""" - for t in self._data.get("tables", []): - if t["id"] == table_id: - return dict(t) - return None - - def is_registered(self, table_id: str) -> bool: - return any(t["id"] == table_id for t in self._data.get("tables", [])) - - def register_table( - self, - table_def: dict, - registered_by: str, - expected_version: Optional[int] = None, - ) -> None: - """Register a new table. - - Args: - table_def: Table definition dict (must contain id, name, sync_strategy, primary_key). - registered_by: Email of the admin who registered the table. - expected_version: If provided, reject if registry version doesn't match (optimistic lock). - - Raises: - ValueError: If table already registered or validation fails. - ConflictError: If expected_version doesn't match. - """ - if expected_version is not None and expected_version != self.version: - raise ConflictError( - f"Version conflict: expected {expected_version}, current {self.version}" - ) - - table_id = table_def.get("id", "") - if not table_id: - raise ValueError("Table definition must include 'id'") - - if self.is_registered(table_id): - raise ValueError(f"Table '{table_id}' is already registered") - - # Validate required fields - for field in ("name", "sync_strategy", "primary_key"): - if not table_def.get(field): - raise ValueError(f"Table definition must include '{field}'") - - # Validate sync_strategy - valid_strategies = ("full_refresh", "incremental", "partitioned") - if table_def["sync_strategy"] not in valid_strategies: - raise ValueError( - f"Invalid sync_strategy '{table_def['sync_strategy']}'. " - f"Allowed: {', '.join(valid_strategies)}" - ) - - # Build full record - record = { - "id": table_id, - "name": table_def["name"], - "description": table_def.get("description", ""), - "primary_key": table_def["primary_key"], - "sync_strategy": table_def["sync_strategy"], - "incremental_window_days": table_def.get("incremental_window_days"), - "partition_by": table_def.get("partition_by"), - "partition_granularity": table_def.get("partition_granularity"), - "foreign_keys": table_def.get("foreign_keys", []), - "where_filters": table_def.get("where_filters", []), - "folder": table_def.get("folder"), - "dataset": table_def.get("dataset"), - "initial_load_chunk_days": table_def.get("initial_load_chunk_days", 30), - "registered_at": _now_iso(), - "registered_by": registered_by, - "source_metadata": table_def.get("source_metadata", {}), - } - - self._data["tables"].append(record) - self._save() - - _audit_log(self.registry_path, "register", { - "table_id": table_id, - "by": registered_by, - }) - - def unregister_table( - self, - table_id: str, - unregistered_by: str = "", - expected_version: Optional[int] = None, - ) -> None: - """Remove a table from the registry. - - Raises: - ValueError: If table not found. - ConflictError: If expected_version doesn't match. - """ - if expected_version is not None and expected_version != self.version: - raise ConflictError( - f"Version conflict: expected {expected_version}, current {self.version}" - ) - - tables = self._data.get("tables", []) - new_tables = [t for t in tables if t["id"] != table_id] - - if len(new_tables) == len(tables): - raise ValueError(f"Table '{table_id}' is not registered") - - self._data["tables"] = new_tables - self._save() - - _audit_log(self.registry_path, "unregister", { - "table_id": table_id, - "by": unregistered_by, - }) - - def update_table( - self, - table_id: str, - updates: dict, - updated_by: str = "", - expected_version: Optional[int] = None, - ) -> None: - """Update table configuration. - - Raises: - ValueError: If table not found. - ConflictError: If expected_version doesn't match. - """ - if expected_version is not None and expected_version != self.version: - raise ConflictError( - f"Version conflict: expected {expected_version}, current {self.version}" - ) - - # Fields that can be updated - allowed_fields = { - "description", "primary_key", "sync_strategy", - "incremental_window_days", "partition_by", "partition_granularity", - "foreign_keys", "where_filters", "folder", "dataset", - "initial_load_chunk_days", - } - - for t in self._data.get("tables", []): - if t["id"] == table_id: - for key, value in updates.items(): - if key in allowed_fields: - t[key] = value - self._save() - _audit_log(self.registry_path, "update", { - "table_id": table_id, - "fields": list(updates.keys()), - "by": updated_by, - }) - return - - raise ValueError(f"Table '{table_id}' is not registered") - - # ── Folder mapping ─────────────────────────────────────────────── - - def get_folder_mapping(self) -> dict[str, str]: - return dict(self._data.get("folder_mapping", {})) - - def set_folder_mapping(self, bucket_id: str, folder: str) -> None: - self._data.setdefault("folder_mapping", {})[bucket_id] = folder - self._save() - - # ── Generation ─────────────────────────────────────────────────── - - def generate_data_description_md(self, output_path: Path) -> None: - """Regenerate data_description.md from registry. - - The generated file is read-only and includes a checksum header. - Existing readers (config.py, profiler.py) consume this without changes. - """ - tables = self.list_tables() - folder_mapping = self.get_folder_mapping() - - # Build YAML structure matching existing data_description.md format - yaml_data: dict[str, Any] = {} - - if folder_mapping: - yaml_data["folder_mapping"] = folder_mapping - - yaml_tables = [] - for t in tables: - entry: dict[str, Any] = { - "id": t["id"], - "name": t["name"], - "description": t.get("description", ""), - "primary_key": t["primary_key"], - "sync_strategy": t["sync_strategy"], - } - - # Optional fields -- only include if set - if t.get("incremental_window_days"): - entry["incremental_window_days"] = t["incremental_window_days"] - if t.get("partition_by"): - entry["partition_by"] = t["partition_by"] - if t.get("partition_granularity"): - entry["partition_granularity"] = t["partition_granularity"] - if t.get("max_history_days"): - entry["max_history_days"] = t["max_history_days"] - if t.get("initial_load_chunk_days") and t["initial_load_chunk_days"] != 30: - entry["initial_load_chunk_days"] = t["initial_load_chunk_days"] - if t.get("foreign_keys"): - entry["foreign_keys"] = t["foreign_keys"] - if t.get("where_filters"): - entry["where_filters"] = t["where_filters"] - if t.get("folder"): - entry["folder"] = t["folder"] - if t.get("dataset"): - entry["dataset"] = t["dataset"] - - yaml_tables.append(entry) - - yaml_data["tables"] = yaml_tables - - yaml_str = yaml.dump( - yaml_data, default_flow_style=False, sort_keys=False, allow_unicode=True - ) - - # Compute checksum - checksum = hashlib.sha256(yaml_str.encode()).hexdigest()[:16] - - # Build markdown - lines = [ - f"", - f"", - f"", - "", - "# Data Description", - "", - f"Generated at {_now_iso()} from table registry " - f"(version {self.version}, {len(yaml_tables)} tables).", - "", - "```yaml", - yaml_str.rstrip(), - "```", - "", - ] - - content = "\n".join(lines) - - output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_text(content) - logger.info( - f"Generated data_description.md: {len(yaml_tables)} tables " - f"(checksum: {checksum})" - ) - - # ── Migration ──────────────────────────────────────────────────── - - @classmethod - def import_from_data_description( - cls, - md_path: Path, - registry_path: Path, - registered_by: str = "migration", - ) -> "TableRegistry": - """One-time migration: parse existing data_description.md into registry. - - Creates a new registry JSON from the existing markdown YAML blocks. - """ - if not md_path.exists(): - raise FileNotFoundError(f"data_description.md not found: {md_path}") - - content = md_path.read_text() - - # Extract YAML blocks - yaml_matches = re.findall(r"```yaml\n(.*?)```", content, re.DOTALL) - if not yaml_matches: - raise ValueError("No YAML blocks found in data_description.md") - - all_tables: list[dict] = [] - folder_mapping: dict[str, str] = {} - - for yaml_block in yaml_matches: - data = yaml.safe_load(yaml_block) - if data: - if "tables" in data: - all_tables.extend(data["tables"]) - if "folder_mapping" in data: - folder_mapping.update(data["folder_mapping"]) - - if not all_tables: - raise ValueError("No tables found in YAML blocks") - - # Build registry - registry = cls(registry_path) - registry._data = cls._empty_registry() - registry._data["folder_mapping"] = folder_mapping - registry._data["_metadata"]["migrated_from"] = str(md_path) - - now = _now_iso() - for table_data in all_tables: - record = { - "id": table_data.get("id", ""), - "name": table_data.get("name", ""), - "description": table_data.get("description", ""), - "primary_key": table_data.get("primary_key", ""), - "sync_strategy": table_data.get("sync_strategy", "full_refresh"), - "incremental_window_days": table_data.get("incremental_window_days"), - "partition_by": table_data.get("partition_by"), - "partition_granularity": table_data.get("partition_granularity"), - "foreign_keys": table_data.get("foreign_keys", []), - "where_filters": table_data.get("where_filters", []), - "folder": table_data.get("folder"), - "dataset": table_data.get("dataset"), - "initial_load_chunk_days": table_data.get("initial_load_chunk_days", 30), - "max_history_days": table_data.get("max_history_days"), - "registered_at": now, - "registered_by": registered_by, - "source_metadata": {}, - } - registry._data["tables"].append(record) - - registry._save() - - _audit_log(registry_path, "migrate", { - "source": str(md_path), - "tables_imported": len(all_tables), - "by": registered_by, - }) - - logger.info( - f"Migrated {len(all_tables)} tables from {md_path} to registry" - ) - return registry - - -class ConflictError(Exception): - """Raised when optimistic locking version doesn't match.""" - pass diff --git a/tests/test_bigquery_adapter.py b/tests/test_bigquery_adapter.py deleted file mode 100644 index fd5f0bb..0000000 --- a/tests/test_bigquery_adapter.py +++ /dev/null @@ -1,1103 +0,0 @@ -""" -Comprehensive unit tests for the BigQuery data source adapter. - -Tests the BigQueryDataSource class from connectors/bigquery/adapter.py -with all external dependencies (BigQueryClient, config, parquet_manager) mocked. - -The google-cloud-bigquery package is not installed in test environments, -so we install stub modules in sys.modules before importing the adapter. -""" - -import sys -from datetime import date, datetime, timedelta -from pathlib import Path -from unittest.mock import MagicMock, patch - -import pyarrow as pa -import pyarrow.parquet as pq -import pytest - -# --------------------------------------------------------------------------- -# Stub google.cloud.bigquery before any connector import -# --------------------------------------------------------------------------- -_bq_stub = MagicMock() -sys.modules.setdefault("google", _bq_stub) -sys.modules.setdefault("google.cloud", _bq_stub) -sys.modules.setdefault("google.cloud.bigquery", _bq_stub) - -from src.config import TableConfig # noqa: E402 -from src.data_sync import SyncState # noqa: E402 - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - -@pytest.fixture -def tmp_parquet_dir(tmp_path): - """Provide a temporary directory for Parquet file output.""" - parquet_dir = tmp_path / "parquet" / "test_bucket" - parquet_dir.mkdir(parents=True) - return parquet_dir - - -@pytest.fixture -def mock_config(tmp_parquet_dir): - """Create a mock Config object that returns paths inside tmp_parquet_dir.""" - config = MagicMock() - config.get_parquet_path = MagicMock() - config.get_partition_path = MagicMock() - config.get_metadata_path.return_value = tmp_parquet_dir.parent / "metadata" - return config - - -@pytest.fixture -def mock_bq_client(): - """Create a mock BigQueryClient with sensible defaults.""" - client = MagicMock() - client.metadata_cache = {} - client.get_date_columns.return_value = [] - client.get_pyarrow_schema.return_value = None - return client - - -@pytest.fixture -def sync_state(tmp_path): - """Create a real SyncState backed by a temp JSON file.""" - state_file = tmp_path / "metadata" / "sync_state.json" - state_file.parent.mkdir(parents=True, exist_ok=True) - return SyncState(state_file) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _make_table_config( - *, - table_id: str = "project.dataset.orders", - name: str = "orders", - primary_key: str = "id", - sync_strategy: str = "full_refresh", - incremental_column: str | None = None, - incremental_window_days: int | None = None, - partition_by: str | None = None, - partition_granularity: str | None = None, - max_history_days: int | None = None, - partition_column_type: str = "TIMESTAMP", -) -> TableConfig: - """Helper to build a TableConfig with safe defaults.""" - return TableConfig( - id=table_id, - name=name, - description="Test table", - primary_key=primary_key, - sync_strategy=sync_strategy, - incremental_column=incremental_column, - incremental_window_days=incremental_window_days, - partition_by=partition_by, - partition_granularity=partition_granularity, - max_history_days=max_history_days, - partition_column_type=partition_column_type, - ) - - -def _sample_arrow_table(ids: list[int], names: list[str]) -> pa.Table: - """Build a small PyArrow Table with id and name columns.""" - return pa.table({"id": ids, "name": names}) - - -def _as_batches(arrow_table: pa.Table) -> list: - """Convert Arrow table to list of RecordBatches (mimics streaming from BQ).""" - return arrow_table.to_batches() - - -def _create_adapter(mock_config, mock_bq_client): - """Instantiate BigQueryDataSource with mocked dependencies. - - Patches get_config and create_bq_client so that no real GCP - credentials or network access are needed. - """ - with patch("connectors.bigquery.adapter.get_config", return_value=mock_config), \ - patch("connectors.bigquery.adapter.create_bq_client", return_value=mock_bq_client): - from connectors.bigquery.adapter import BigQueryDataSource - adapter = BigQueryDataSource() - return adapter - - -# --------------------------------------------------------------------------- -# 1. full_refresh writes valid Parquet file from Arrow table -# --------------------------------------------------------------------------- - -class TestFullRefresh: - - def test_writes_valid_parquet(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): - """full_refresh should stream batches and write a valid Parquet file.""" - table_config = _make_table_config(sync_strategy="full_refresh") - parquet_path = tmp_parquet_dir / "orders.parquet" - mock_config.get_parquet_path.return_value = parquet_path - - arrow_data = _sample_arrow_table([1, 2, 3], ["Alice", "Bob", "Charlie"]) - mock_bq_client.read_table_streaming.return_value = _as_batches(arrow_data) - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.sync_table(table_config, sync_state) - - assert result["success"] is True - assert result["rows"] == 3 - assert parquet_path.exists() - - # Verify Parquet content matches source data - read_back = pq.read_table(parquet_path) - assert read_back.num_rows == 3 - assert read_back.column_names == ["id", "name"] - - def test_applies_date_columns(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): - """full_refresh should call convert_date_columns_to_date32 per batch.""" - table_config = _make_table_config() - parquet_path = tmp_parquet_dir / "orders.parquet" - mock_config.get_parquet_path.return_value = parquet_path - - arrow_data = _sample_arrow_table([1], ["Alice"]) - mock_bq_client.read_table_streaming.return_value = _as_batches(arrow_data) - mock_bq_client.get_date_columns.return_value = ["created_at"] - - with patch("connectors.bigquery.adapter.convert_date_columns_to_date32", return_value=arrow_data) as mock_conv: - adapter = _create_adapter(mock_config, mock_bq_client) - adapter.sync_table(table_config, sync_state) - mock_conv.assert_called_once() - assert mock_conv.call_args[0][1] == ["created_at"] - - def test_applies_pyarrow_schema(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): - """full_refresh should call apply_schema_to_table per batch.""" - table_config = _make_table_config() - parquet_path = tmp_parquet_dir / "orders.parquet" - mock_config.get_parquet_path.return_value = parquet_path - - arrow_data = _sample_arrow_table([1], ["Alice"]) - mock_bq_client.read_table_streaming.return_value = _as_batches(arrow_data) - schema = pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.string())]) - mock_bq_client.get_pyarrow_schema.return_value = schema - - with patch("connectors.bigquery.adapter.apply_schema_to_table", return_value=arrow_data) as mock_apply: - adapter = _create_adapter(mock_config, mock_bq_client) - adapter.sync_table(table_config, sync_state) - mock_apply.assert_called_once() - assert mock_apply.call_args[0][1] == schema - - -# --------------------------------------------------------------------------- -# 2. incremental_column_sync merges correctly (dedup on PK, new data wins) -# --------------------------------------------------------------------------- - -class TestIncrementalColumnSync: - - def test_merge_dedup_new_data_wins(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): - """Incremental sync should overwrite existing rows when PK matches (new data wins).""" - table_config = _make_table_config( - sync_strategy="incremental", - incremental_column="updated_at", - incremental_window_days=7, - ) - parquet_path = tmp_parquet_dir / "orders.parquet" - mock_config.get_parquet_path.return_value = parquet_path - - # Write existing data - existing = _sample_arrow_table([1, 2], ["Alice", "Bob"]) - pq.write_table(existing, parquet_path) - - # Simulate a previous sync timestamp - sync_state.update_sync( - table_id=table_config.id, - table_name=table_config.name, - strategy="incremental", - rows=2, - file_size_bytes=100, - ) - - # New data: id=2 gets updated name, id=3 is new - new_data = _sample_arrow_table([2, 3], ["Bob_Updated", "Charlie"]) - mock_bq_client.read_table_incremental.return_value = new_data - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.sync_table(table_config, sync_state) - - assert result["success"] is True - assert result["rows"] == 3 # Alice + Bob_Updated + Charlie - - read_back = pq.read_table(parquet_path) - df = read_back.to_pandas() - assert set(df["id"].tolist()) == {1, 2, 3} - # id=2 should have the updated name - bob_row = df[df["id"] == 2].iloc[0] - assert bob_row["name"] == "Bob_Updated" - - -# --------------------------------------------------------------------------- -# 3. incremental_column_sync with no new data returns existing file info -# --------------------------------------------------------------------------- - -class TestIncrementalNoNewData: - - def test_returns_existing_file_info(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): - """When there is no new data, sync returns stats from the existing Parquet file.""" - table_config = _make_table_config( - sync_strategy="incremental", - incremental_column="updated_at", - incremental_window_days=7, - ) - parquet_path = tmp_parquet_dir / "orders.parquet" - mock_config.get_parquet_path.return_value = parquet_path - - # Write existing data - existing = _sample_arrow_table([1, 2, 3], ["A", "B", "C"]) - pq.write_table(existing, parquet_path) - - # Mark a previous sync - sync_state.update_sync( - table_id=table_config.id, - table_name=table_config.name, - strategy="incremental", - rows=3, - file_size_bytes=100, - ) - - # No new rows - empty_table = pa.table({ - "id": pa.array([], type=pa.int64()), - "name": pa.array([], type=pa.string()), - }) - mock_bq_client.read_table_incremental.return_value = empty_table - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.sync_table(table_config, sync_state) - - assert result["success"] is True - assert result["rows"] == 3 # existing row count preserved - - -# --------------------------------------------------------------------------- -# 4. partitioned_sync - per-day streaming behaviour -# --------------------------------------------------------------------------- - -class TestPartitionedSync: - """Tests for the rewritten _partitioned_sync() that streams per-day from BQ.""" - - def _setup_partition_config( - self, - mock_config, - tmp_parquet_dir, - *, - granularity: str = "day", - max_history_days: int | None = 3, - incremental_window_days: int | None = 2, - partition_column_type: str = "TIMESTAMP", - ): - """Common setup: create table config + wire mock_config partition paths.""" - table_config = _make_table_config( - sync_strategy="partitioned", - incremental_column="event_date", - partition_by="event_date", - partition_granularity=granularity, - max_history_days=max_history_days, - incremental_window_days=incremental_window_days, - partition_column_type=partition_column_type, - ) - - partition_dir = tmp_parquet_dir / "orders" - partition_dir.mkdir(parents=True, exist_ok=True) - mock_config.get_parquet_path.return_value = partition_dir - - def _partition_path(tc, key): - return partition_dir / f"{key}.parquet" - mock_config.get_partition_path.side_effect = _partition_path - - return table_config, partition_dir - - @staticmethod - def _make_day_table(row_id: int, day: date) -> pa.Table: - """Build a one-row Arrow table for a given day.""" - return pa.table({ - "id": [row_id], - "event_date": pa.array([datetime(day.year, day.month, day.day)], type=pa.timestamp("us")), - }) - - def test_creates_daily_partition_files( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """First sync with max_history_days creates one Parquet file per day with data.""" - table_config, partition_dir = self._setup_partition_config( - mock_config, tmp_parquet_dir, max_history_days=3, granularity="day", - ) - - today = date.today() - day0 = today - timedelta(days=3) - day1 = today - timedelta(days=2) - # day2, day3 (today-1, today) will have no data - - # Build per-day Arrow data for the two days that have rows - day0_table = self._make_day_table(1, day0) - day1_table = self._make_day_table(2, day1) - - # read_table_partitioned_streaming is called once per partition date. - # We need to return data for day0 and day1, empty iterators for the rest. - def _streaming_side_effect(*, table_id, partition_column, start, end, columns, column_type): - start_date = date.fromisoformat(start) - if start_date == day0: - return iter(day0_table.to_batches()) - if start_date == day1: - return iter(day1_table.to_batches()) - return iter([]) # empty for other days - - mock_bq_client.read_table_partitioned_streaming.side_effect = _streaming_side_effect - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.sync_table(table_config, sync_state) - - assert result["success"] is True - - # Should have exactly 2 partition files (days with data) - partition_files = sorted(partition_dir.glob("*.parquet")) - assert len(partition_files) == 2 - - file_names = sorted(f.stem for f in partition_files) - assert day0.strftime("%Y_%m_%d") in file_names - assert day1.strftime("%Y_%m_%d") in file_names - - # Verify content of each partition - t0 = pq.read_table(partition_dir / f"{day0.strftime('%Y_%m_%d')}.parquet") - assert t0.num_rows == 1 - assert t0.column("id").to_pylist() == [1] - - def test_incremental_sync_only_fetches_window( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """After a previous sync, only the incremental window of days is fetched.""" - table_config, partition_dir = self._setup_partition_config( - mock_config, tmp_parquet_dir, - max_history_days=30, - incremental_window_days=2, - granularity="day", - ) - - # Simulate a previous sync 1 day ago - sync_time = (datetime.now() - timedelta(days=1)).isoformat() - sync_state.update_sync( - table_id=table_config.id, - table_name=table_config.name, - strategy="partitioned", - rows=100, - file_size_bytes=5000, - ) - - # Return empty for all calls -- we just want to verify the call count - mock_bq_client.read_table_partitioned_streaming.return_value = iter([]) - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.sync_table(table_config, sync_state) - - assert result["success"] is True - - # With incremental_window_days=2, it should go back 2 days from last_sync. - # The number of partition dates from (last_sync - 2 days) to today. - last_sync_str = sync_state.get_last_sync(table_config.id) - last_sync_dt = datetime.fromisoformat(last_sync_str) - start_date = (last_sync_dt - timedelta(days=2)).date() - today = date.today() - expected_days = (today - start_date).days + 1 # inclusive - - actual_calls = mock_bq_client.read_table_partitioned_streaming.call_count - assert actual_calls == expected_days, ( - f"Expected {expected_days} BQ calls (from {start_date} to {today}), got {actual_calls}" - ) - - def test_merges_with_existing_partition( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """New data for an existing partition merges and deduplicates on PK.""" - table_config, partition_dir = self._setup_partition_config( - mock_config, tmp_parquet_dir, max_history_days=3, granularity="day", - ) - - today = date.today() - target_day = today - timedelta(days=1) - partition_key = target_day.strftime("%Y_%m_%d") - partition_path = partition_dir / f"{partition_key}.parquet" - - # Pre-write an existing partition file with id=1 - existing = pa.table({ - "id": [1], - "event_date": pa.array( - [datetime(target_day.year, target_day.month, target_day.day)], - type=pa.timestamp("us"), - ), - }) - pq.write_table(existing, partition_path, compression="snappy") - - # New data: id=1 (update) + id=2 (new row) - new_data = pa.table({ - "id": [1, 2], - "event_date": pa.array( - [ - datetime(target_day.year, target_day.month, target_day.day), - datetime(target_day.year, target_day.month, target_day.day), - ], - type=pa.timestamp("us"), - ), - }) - - def _streaming_side_effect(*, table_id, partition_column, start, end, columns, column_type): - start_date = date.fromisoformat(start) - if start_date == target_day: - return iter(new_data.to_batches()) - return iter([]) - - mock_bq_client.read_table_partitioned_streaming.side_effect = _streaming_side_effect - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.sync_table(table_config, sync_state) - - assert result["success"] is True - - # Read back the target partition -- should have 2 rows (dedup on id) - merged = pq.read_table(partition_path) - assert merged.num_rows == 2 - assert sorted(merged.column("id").to_pylist()) == [1, 2] - - def test_empty_partition_skipped( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """A partition day with no data from BQ should not create a file.""" - table_config, partition_dir = self._setup_partition_config( - mock_config, tmp_parquet_dir, max_history_days=2, granularity="day", - ) - - # Return empty iterator for every call - mock_bq_client.read_table_partitioned_streaming.return_value = iter([]) - # side_effect takes precedence over return_value when set, but let's use - # a function so each call gets a fresh empty iterator - mock_bq_client.read_table_partitioned_streaming.side_effect = ( - lambda **kw: iter([]) - ) - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.sync_table(table_config, sync_state) - - assert result["success"] is True - - # No partition files should have been created - partition_files = list(partition_dir.glob("*.parquet")) - assert len(partition_files) == 0 - - -# --------------------------------------------------------------------------- -# 5. discover_tables delegates to BigQueryClient.discover_all_tables() -# --------------------------------------------------------------------------- - -class TestDiscoverTables: - - def test_delegates_to_client(self, mock_config, mock_bq_client): - """discover_tables should forward the call to BigQueryClient.discover_all_tables.""" - expected = [{"id": "proj.ds.t1", "name": "t1", "columns": ["a", "b"]}] - mock_bq_client.discover_all_tables.return_value = expected - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.discover_tables() - - mock_bq_client.discover_all_tables.assert_called_once() - assert result == expected - - -# --------------------------------------------------------------------------- -# 6. get_source_name returns "Google BigQuery" -# --------------------------------------------------------------------------- - -class TestGetSourceName: - - def test_returns_google_bigquery(self, mock_config, mock_bq_client): - adapter = _create_adapter(mock_config, mock_bq_client) - assert adapter.get_source_name() == "Google BigQuery" - - -# --------------------------------------------------------------------------- -# 7. get_column_metadata returns correct format -# --------------------------------------------------------------------------- - -class TestGetColumnMetadata: - - def test_returns_correct_format(self, mock_config, mock_bq_client): - """get_column_metadata should transform BQ raw metadata into {columns: ...} format.""" - mock_bq_client.get_table_metadata.return_value = { - "column_types": {"id": "INT64", "name": "STRING", "email": "STRING"}, - "column_descriptions": {"id": "Primary key", "email": "User email address"}, - } - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.get_column_metadata("project.dataset.users") - - assert "columns" in result - assert result["columns"]["id"] == {"source_type": "INT64", "description": "Primary key"} - assert result["columns"]["name"] == {"source_type": "STRING"} - assert result["columns"]["email"] == { - "source_type": "STRING", - "description": "User email address", - } - - def test_returns_none_when_no_column_types(self, mock_config, mock_bq_client): - """get_column_metadata should return None if the metadata has no column types.""" - mock_bq_client.get_table_metadata.return_value = { - "column_types": {}, - "column_descriptions": {}, - } - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.get_column_metadata("project.dataset.users") - - assert result is None - - -# --------------------------------------------------------------------------- -# 8. Error handling (query failure -> {success: False, error: ...}) -# --------------------------------------------------------------------------- - -class TestErrorHandling: - - def test_query_failure_returns_error_dict( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """When BigQuery query raises, sync_table returns {success: False, error: ...}.""" - table_config = _make_table_config() - mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" - mock_bq_client.read_table_streaming.side_effect = RuntimeError("BigQuery API timeout") - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.sync_table(table_config, sync_state) - - assert result["success"] is False - assert "BigQuery API timeout" in result["error"] - assert result["strategy"] == "full_refresh" - - def test_unknown_strategy_returns_error(self, mock_config, mock_bq_client, sync_state): - """Unknown sync_strategy in internal dispatch should produce an error result.""" - # We cannot create a TableConfig with an invalid strategy via constructor - # (it validates). Instead, we mutate it after creation. - table_config = _make_table_config() - table_config.sync_strategy = "magic_sync" - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.sync_table(table_config, sync_state) - - assert result["success"] is False - assert "Unknown sync strategy" in result["error"] - - -# --------------------------------------------------------------------------- -# 9. incremental_column config is used in WHERE clause -# --------------------------------------------------------------------------- - -class TestIncrementalColumnUsedInWhere: - - def test_incremental_column_passed_to_client( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """The configured incremental_column should be forwarded to read_table_incremental.""" - table_config = _make_table_config( - sync_strategy="incremental", - incremental_column="modified_at", - incremental_window_days=14, - ) - parquet_path = tmp_parquet_dir / "orders.parquet" - mock_config.get_parquet_path.return_value = parquet_path - - # Write existing data so we enter the incremental path - existing = _sample_arrow_table([1], ["Alice"]) - pq.write_table(existing, parquet_path) - - sync_state.update_sync( - table_id=table_config.id, - table_name=table_config.name, - strategy="incremental", - rows=1, - file_size_bytes=100, - ) - - # Return empty to keep the test simple - empty = pa.table({ - "id": pa.array([], type=pa.int64()), - "name": pa.array([], type=pa.string()), - }) - mock_bq_client.read_table_incremental.return_value = empty - - adapter = _create_adapter(mock_config, mock_bq_client) - adapter.sync_table(table_config, sync_state) - - call_kwargs = mock_bq_client.read_table_incremental.call_args - assert call_kwargs.kwargs["incremental_column"] == "modified_at" - assert call_kwargs.kwargs["table_id"] == "project.dataset.orders" - # since_value should be an ISO string - assert "since_value" in call_kwargs.kwargs - - -# --------------------------------------------------------------------------- -# 10. First sync without existing file downloads all data -# --------------------------------------------------------------------------- - -class TestFirstSyncDownloadsAll: - - def test_first_sync_reads_full_table( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """On first incremental sync (no existing file), adapter should read all data.""" - table_config = _make_table_config( - sync_strategy="incremental", - incremental_column="updated_at", - incremental_window_days=7, - ) - parquet_path = tmp_parquet_dir / "orders.parquet" - mock_config.get_parquet_path.return_value = parquet_path - - # No previous sync, no existing file - arrow_data = _sample_arrow_table([1, 2, 3], ["A", "B", "C"]) - mock_bq_client.read_table.return_value = arrow_data - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.sync_table(table_config, sync_state) - - assert result["success"] is True - assert result["rows"] == 3 - # Should call read_table (full), not read_table_incremental - mock_bq_client.read_table.assert_called_once_with( - table_config.id, columns=None, row_filter=None, - ) - mock_bq_client.read_table_incremental.assert_not_called() - - def test_first_sync_with_max_history_days( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """First sync with max_history_days should use read_table_incremental.""" - table_config = _make_table_config( - sync_strategy="incremental", - incremental_column="updated_at", - incremental_window_days=7, - max_history_days=90, - ) - parquet_path = tmp_parquet_dir / "orders.parquet" - mock_config.get_parquet_path.return_value = parquet_path - - arrow_data = _sample_arrow_table([1, 2], ["A", "B"]) - mock_bq_client.read_table_incremental.return_value = arrow_data - - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter.sync_table(table_config, sync_state) - - assert result["success"] is True - # Should use read_table_incremental (not read_table) because max_history_days is set - mock_bq_client.read_table_incremental.assert_called_once() - call_kwargs = mock_bq_client.read_table_incremental.call_args.kwargs - assert call_kwargs["incremental_column"] == "updated_at" - mock_bq_client.read_table.assert_not_called() - - -# --------------------------------------------------------------------------- -# 11. sync_table dispatches to correct strategy based on sync_strategy -# --------------------------------------------------------------------------- - -class TestSyncTableDispatch: - - def test_dispatches_full_refresh( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """sync_strategy='full_refresh' should call _full_refresh.""" - table_config = _make_table_config(sync_strategy="full_refresh") - mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" - mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"])) - - adapter = _create_adapter(mock_config, mock_bq_client) - - with patch.object(adapter, "_full_refresh", wraps=adapter._full_refresh) as spy: - adapter.sync_table(table_config, sync_state) - spy.assert_called_once_with(table_config) - - def test_dispatches_incremental( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """sync_strategy='incremental' should call _incremental_sync.""" - table_config = _make_table_config( - sync_strategy="incremental", - incremental_column="updated_at", - incremental_window_days=7, - ) - mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" - mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"]) - - adapter = _create_adapter(mock_config, mock_bq_client) - - with patch.object(adapter, "_incremental_sync", wraps=adapter._incremental_sync) as spy: - adapter.sync_table(table_config, sync_state) - spy.assert_called_once_with(table_config, sync_state) - - def test_dispatches_partitioned( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """sync_strategy='partitioned' should call _partitioned_sync.""" - table_config = _make_table_config( - sync_strategy="partitioned", - incremental_column="created_at", - partition_by="created_at", - partition_granularity="day", - max_history_days=2, - ) - partition_dir = tmp_parquet_dir / "orders" - partition_dir.mkdir(parents=True, exist_ok=True) - mock_config.get_parquet_path.return_value = partition_dir - - def _partition_path(tc, key): - return partition_dir / f"{key}.parquet" - mock_config.get_partition_path.side_effect = _partition_path - - mock_bq_client.read_table_partitioned_streaming.side_effect = ( - lambda **kw: iter([]) - ) - - adapter = _create_adapter(mock_config, mock_bq_client) - - with patch.object(adapter, "_partitioned_sync", wraps=adapter._partitioned_sync) as spy: - adapter.sync_table(table_config, sync_state) - spy.assert_called_once() - - def test_incremental_without_column_falls_back_to_full_refresh( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """incremental strategy without incremental_column or partition_by falls back to full_refresh.""" - table_config = _make_table_config( - sync_strategy="incremental", - incremental_column=None, - partition_by=None, - incremental_window_days=7, - ) - mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" - mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"])) - - adapter = _create_adapter(mock_config, mock_bq_client) - - with patch.object(adapter, "_full_refresh", wraps=adapter._full_refresh) as spy: - result = adapter.sync_table(table_config, sync_state) - spy.assert_called_once() - assert result["success"] is True - - -# --------------------------------------------------------------------------- -# 12. _merge_arrow_tables deduplicates correctly -# --------------------------------------------------------------------------- - -class TestMergeArrowTables: - - def test_dedup_on_single_pk(self, mock_config, mock_bq_client): - """Merge should deduplicate on single primary key column, new data wins.""" - adapter = _create_adapter(mock_config, mock_bq_client) - - existing = pa.table({"id": [1, 2, 3], "val": ["a", "b", "c"]}) - new_data = pa.table({"id": [2, 4], "val": ["B_new", "d"]}) - - merged = adapter._merge_arrow_tables(existing, new_data, primary_key=["id"]) - df = merged.to_pandas().sort_values("id").reset_index(drop=True) - - assert list(df["id"]) == [1, 2, 3, 4] - assert list(df["val"]) == ["a", "B_new", "c", "d"] - - def test_dedup_on_composite_pk(self, mock_config, mock_bq_client): - """Merge should deduplicate on composite primary key.""" - adapter = _create_adapter(mock_config, mock_bq_client) - - existing = pa.table({ - "pk1": [1, 1, 2], - "pk2": ["a", "b", "a"], - "val": ["old_1a", "old_1b", "old_2a"], - }) - new_data = pa.table({ - "pk1": [1, 2], - "pk2": ["a", "a"], - "val": ["new_1a", "new_2a"], - }) - - merged = adapter._merge_arrow_tables(existing, new_data, primary_key=["pk1", "pk2"]) - df = merged.to_pandas().sort_values(["pk1", "pk2"]).reset_index(drop=True) - - assert len(df) == 3 - # (1, a) should be updated - row_1a = df[(df["pk1"] == 1) & (df["pk2"] == "a")].iloc[0] - assert row_1a["val"] == "new_1a" - # (1, b) should be preserved - row_1b = df[(df["pk1"] == 1) & (df["pk2"] == "b")].iloc[0] - assert row_1b["val"] == "old_1b" - # (2, a) should be updated - row_2a = df[(df["pk1"] == 2) & (df["pk2"] == "a")].iloc[0] - assert row_2a["val"] == "new_2a" - - def test_merge_with_empty_new_data(self, mock_config, mock_bq_client): - """Merging with empty new data should return existing data unchanged.""" - adapter = _create_adapter(mock_config, mock_bq_client) - - existing = pa.table({"id": [1, 2], "val": ["a", "b"]}) - empty = pa.table({ - "id": pa.array([], type=pa.int64()), - "val": pa.array([], type=pa.string()), - }) - - merged = adapter._merge_arrow_tables(existing, empty, primary_key=["id"]) - assert merged.num_rows == 2 - - def test_merge_with_empty_existing(self, mock_config, mock_bq_client): - """Merging with empty existing data should return new data.""" - adapter = _create_adapter(mock_config, mock_bq_client) - - empty = pa.table({ - "id": pa.array([], type=pa.int64()), - "val": pa.array([], type=pa.string()), - }) - new_data = pa.table({"id": [1, 2], "val": ["a", "b"]}) - - merged = adapter._merge_arrow_tables(empty, new_data, primary_key=["id"]) - assert merged.num_rows == 2 - - -# --------------------------------------------------------------------------- -# Additional edge cases -# --------------------------------------------------------------------------- - -class TestMetadataCacheClearing: - - def test_clears_metadata_cache_before_sync( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """sync_table should clear the BQ metadata cache entry for the table being synced.""" - table_config = _make_table_config() - parquet_path = tmp_parquet_dir / "orders.parquet" - mock_config.get_parquet_path.return_value = parquet_path - mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"])) - - # Pre-populate cache - mock_bq_client.metadata_cache[table_config.id] = {"some": "cached_data"} - - adapter = _create_adapter(mock_config, mock_bq_client) - adapter.sync_table(table_config, sync_state) - - assert table_config.id not in mock_bq_client.metadata_cache - - -class TestSyncStateUpdate: - - def test_sync_state_updated_after_success( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """After successful sync, the sync state should be updated with correct values.""" - table_config = _make_table_config() - parquet_path = tmp_parquet_dir / "orders.parquet" - mock_config.get_parquet_path.return_value = parquet_path - mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1, 2], ["A", "B"])) - - adapter = _create_adapter(mock_config, mock_bq_client) - adapter.sync_table(table_config, sync_state) - - state = sync_state.get_table_state(table_config.id) - assert state["rows"] == 2 - assert state["strategy"] == "full_refresh" - assert state["table_name"] == "orders" - assert "last_sync" in state - - def test_sync_state_not_updated_on_failure( - self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state - ): - """On sync failure, the sync state should NOT be updated.""" - table_config = _make_table_config() - mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" - mock_bq_client.read_table_streaming.side_effect = RuntimeError("boom") - - adapter = _create_adapter(mock_config, mock_bq_client) - adapter.sync_table(table_config, sync_state) - - state = sync_state.get_table_state(table_config.id) - assert state == {} - - -class TestCreateDataSourceFactory: - - def test_factory_returns_adapter_instance(self, mock_config, mock_bq_client): - """create_data_source() factory should return a BigQueryDataSource instance.""" - with patch("connectors.bigquery.adapter.get_config", return_value=mock_config), \ - patch("connectors.bigquery.adapter.create_bq_client", return_value=mock_bq_client): - from connectors.bigquery.adapter import create_data_source, BigQueryDataSource - instance = create_data_source() - assert isinstance(instance, BigQueryDataSource) - - -# --------------------------------------------------------------------------- -# 14. _cleanup_old_partitions deletes files beyond retention window -# --------------------------------------------------------------------------- - -class TestPartitionCleanup: - - def test_deletes_old_partitions(self, mock_config, mock_bq_client, tmp_parquet_dir): - """Partition files older than max_history_days should be deleted.""" - table_config = _make_table_config( - sync_strategy="partitioned", - partition_by="event_date", - partition_granularity="day", - max_history_days=5, - ) - - partition_dir = tmp_parquet_dir / "orders" - partition_dir.mkdir(parents=True, exist_ok=True) - - today = date.today() - # Create files: 3 days ago (keep), 6 days ago (delete), 10 days ago (delete) - keep_day = today - timedelta(days=3) - delete_day1 = today - timedelta(days=6) - delete_day2 = today - timedelta(days=10) - - for d in [keep_day, delete_day1, delete_day2]: - key = d.strftime("%Y_%m_%d") - dummy = pa.table({"id": [1]}) - pq.write_table(dummy, partition_dir / f"{key}.parquet") - - adapter = _create_adapter(mock_config, mock_bq_client) - deleted = adapter._cleanup_old_partitions(table_config, partition_dir, "day") - - assert deleted == 2 - - # Only the recent file should remain - remaining = [f.stem for f in partition_dir.glob("*.parquet")] - assert keep_day.strftime("%Y_%m_%d") in remaining - assert delete_day1.strftime("%Y_%m_%d") not in remaining - assert delete_day2.strftime("%Y_%m_%d") not in remaining - - def test_no_cleanup_without_max_history_days(self, mock_config, mock_bq_client, tmp_parquet_dir): - """Without max_history_days, no partition files should be deleted.""" - table_config = _make_table_config( - sync_strategy="partitioned", - partition_by="event_date", - partition_granularity="day", - max_history_days=None, - ) - - partition_dir = tmp_parquet_dir / "orders" - partition_dir.mkdir(parents=True, exist_ok=True) - - # Create an old file (100 days ago) - old_day = date.today() - timedelta(days=100) - key = old_day.strftime("%Y_%m_%d") - pq.write_table(pa.table({"id": [1]}), partition_dir / f"{key}.parquet") - - adapter = _create_adapter(mock_config, mock_bq_client) - deleted = adapter._cleanup_old_partitions(table_config, partition_dir, "day") - - assert deleted == 0 - assert len(list(partition_dir.glob("*.parquet"))) == 1 - - -# --------------------------------------------------------------------------- -# 15. _generate_partition_dates produces correct date ranges -# --------------------------------------------------------------------------- - -class TestGeneratePartitionDates: - - def test_daily_generation(self, mock_config, mock_bq_client): - """Daily granularity should generate one date per day, inclusive.""" - adapter = _create_adapter(mock_config, mock_bq_client) - - start = date(2026, 3, 1) - end = date(2026, 3, 5) - dates = adapter._generate_partition_dates(start, end, "day") - - assert dates == [ - date(2026, 3, 1), - date(2026, 3, 2), - date(2026, 3, 3), - date(2026, 3, 4), - date(2026, 3, 5), - ] - - def test_monthly_generation(self, mock_config, mock_bq_client): - """Monthly granularity should generate first-of-month dates, aligned.""" - adapter = _create_adapter(mock_config, mock_bq_client) - - # Start mid-month -- should align to 1st - start = date(2026, 1, 15) - end = date(2026, 4, 10) - dates = adapter._generate_partition_dates(start, end, "month") - - assert dates == [ - date(2026, 1, 1), - date(2026, 2, 1), - date(2026, 3, 1), - date(2026, 4, 1), - ] - - def test_monthly_generation_across_year_boundary(self, mock_config, mock_bq_client): - """Monthly generation should cross year boundaries correctly.""" - adapter = _create_adapter(mock_config, mock_bq_client) - - start = date(2025, 11, 1) - end = date(2026, 2, 15) - dates = adapter._generate_partition_dates(start, end, "month") - - assert dates == [ - date(2025, 11, 1), - date(2025, 12, 1), - date(2026, 1, 1), - date(2026, 2, 1), - ] - - def test_daily_single_day(self, mock_config, mock_bq_client): - """When start == end, should return a single date.""" - adapter = _create_adapter(mock_config, mock_bq_client) - - d = date(2026, 6, 15) - dates = adapter._generate_partition_dates(d, d, "day") - - assert dates == [d] - - def test_empty_range(self, mock_config, mock_bq_client): - """When start > end, should return an empty list.""" - adapter = _create_adapter(mock_config, mock_bq_client) - - dates = adapter._generate_partition_dates(date(2026, 3, 10), date(2026, 3, 5), "day") - assert dates == [] - - -# --------------------------------------------------------------------------- -# 16. _parse_partition_date converts partition keys back to dates -# --------------------------------------------------------------------------- - -class TestParsePartitionDate: - - def test_parse_day_format(self, mock_config, mock_bq_client): - """'2026_01_15' with day granularity should parse to date(2026, 1, 15).""" - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter._parse_partition_date("2026_01_15", "day") - assert result == date(2026, 1, 15) - - def test_parse_month_format(self, mock_config, mock_bq_client): - """'2026_01' with month granularity should parse to date(2026, 1, 1).""" - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter._parse_partition_date("2026_01", "month") - assert result == date(2026, 1, 1) - - def test_parse_year_format(self, mock_config, mock_bq_client): - """'2026' with year granularity should parse to date(2026, 1, 1).""" - adapter = _create_adapter(mock_config, mock_bq_client) - result = adapter._parse_partition_date("2026", "year") - assert result == date(2026, 1, 1) - - def test_parse_invalid_returns_none(self, mock_config, mock_bq_client): - """Invalid partition key should return None.""" - adapter = _create_adapter(mock_config, mock_bq_client) - assert adapter._parse_partition_date("invalid", "day") is None - assert adapter._parse_partition_date("not_a_date", "month") is None - assert adapter._parse_partition_date("abc", "year") is None - - def test_parse_mismatched_granularity_returns_none(self, mock_config, mock_bq_client): - """Day key with month granularity should return None (format mismatch).""" - adapter = _create_adapter(mock_config, mock_bq_client) - # "2026_01_15" is a day format -- parsing as "month" (%Y_%m) should fail - assert adapter._parse_partition_date("2026_01_15", "month") is None diff --git a/tests/test_bigquery_client.py b/tests/test_bigquery_client.py deleted file mode 100644 index aa0a678..0000000 --- a/tests/test_bigquery_client.py +++ /dev/null @@ -1,1018 +0,0 @@ -"""Tests for the BigQuery client connector. - -All external dependencies (google.cloud.bigquery, src.config) are mocked. -Tests cover initialization, metadata caching, schema building, query methods, -and connection testing. -""" - -import json -import sys -from datetime import datetime, timedelta -from pathlib import Path -from unittest.mock import MagicMock, mock_open, patch - -import pyarrow as pa -import pytest - -# Pre-populate sys.modules with a mock google.cloud.bigquery if not installed, -# so the client module can be imported without the real SDK. -_bq_mock_installed = False -try: - from google.cloud import bigquery as _bq_test # noqa: F401 -except ImportError: - _bq_mock_installed = True - _mock_bigquery = MagicMock() - # Expose commonly used classes as MagicMock so the client module - # can reference bigquery.Client, bigquery.QueryJobConfig, etc. - sys.modules.setdefault("google", MagicMock()) - sys.modules.setdefault("google.cloud", MagicMock()) - sys.modules.setdefault("google.cloud.bigquery", _mock_bigquery) - -from connectors.bigquery.client import ( - BIGQUERY_TO_PYARROW_TYPES, - BigQueryClient, - create_client, -) - -# Import the real or mock bigquery reference used in the client module -from google.cloud import bigquery - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _make_bq_field(name: str, field_type: str, description: str = None): - """Create a mock BigQuery SchemaField.""" - field = MagicMock() - field.name = name - field.field_type = field_type - field.description = description - return field - - -def _make_table_ref( - table_id: str = "my-project.my_dataset.my_table", - schema=None, - num_rows: int = 1000, - num_bytes: int = 50000, - created: datetime = None, - modified: datetime = None, - time_partitioning=None, -): - """Create a mock BigQuery Table reference object.""" - table_ref = MagicMock() - table_ref.table_id = table_id.split(".")[-1] - table_ref.dataset_id = table_id.split(".")[1] if "." in table_id else "dataset" - table_ref.project = table_id.split(".")[0] if "." in table_id else "project" - table_ref.schema = schema or [] - table_ref.num_rows = num_rows - table_ref.num_bytes = num_bytes - table_ref.created = created or datetime(2025, 1, 1, 12, 0, 0) - table_ref.modified = modified or datetime(2025, 6, 1, 12, 0, 0) - table_ref.time_partitioning = time_partitioning - return table_ref - - -@pytest.fixture -def mock_config(tmp_path): - """Mock get_config() to return a config with metadata path in tmp_path.""" - config = MagicMock() - metadata_dir = tmp_path / "metadata" - metadata_dir.mkdir(parents=True, exist_ok=True) - config.get_metadata_path.return_value = metadata_dir - return config - - -@pytest.fixture -def mock_bq_client(): - """Create a mock BigQuery Client.""" - return MagicMock() - - -@pytest.fixture -def client(mock_config, mock_bq_client): - """Create a BigQueryClient instance with mocked dependencies.""" - with ( - patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq_client), - patch("connectors.bigquery.client.get_config", return_value=mock_config), - patch.dict("os.environ", {"BIGQUERY_PROJECT": "test-project"}), - ): - bq_client = BigQueryClient() - return bq_client - - -# --------------------------------------------------------------------------- -# 1. Init validates BIGQUERY_PROJECT env var -# --------------------------------------------------------------------------- - -class TestInit: - def test_raises_value_error_when_project_not_set(self, mock_config): - """Init raises ValueError if project_id is None and env var is missing.""" - with ( - patch("connectors.bigquery.client.bigquery.Client"), - patch("connectors.bigquery.client.get_config", return_value=mock_config), - patch.dict("os.environ", {}, clear=True), - ): - with pytest.raises(ValueError, match="BigQuery project ID not set"): - BigQueryClient() - - def test_raises_value_error_when_project_empty_string(self, mock_config): - """Init raises ValueError if BIGQUERY_PROJECT is set to empty string.""" - with ( - patch("connectors.bigquery.client.bigquery.Client"), - patch("connectors.bigquery.client.get_config", return_value=mock_config), - patch.dict("os.environ", {"BIGQUERY_PROJECT": ""}, clear=True), - ): - with pytest.raises(ValueError, match="BigQuery project ID not set"): - BigQueryClient() - - # ------------------------------------------------------------------- - # 2. Init creates client with correct project_id - # ------------------------------------------------------------------- - - def test_creates_client_with_env_project_id(self, mock_config): - """Client uses BIGQUERY_PROJECT from environment.""" - mock_bq = MagicMock() - with ( - patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq) as bq_cls, - patch("connectors.bigquery.client.get_config", return_value=mock_config), - patch.dict("os.environ", {"BIGQUERY_PROJECT": "env-project-123"}), - ): - client = BigQueryClient() - bq_cls.assert_called_once_with(project="env-project-123") - assert client.project_id == "env-project-123" - - def test_creates_client_with_explicit_project_id(self, mock_config): - """Explicit project_id argument takes precedence over env var.""" - mock_bq = MagicMock() - with ( - patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq) as bq_cls, - patch("connectors.bigquery.client.get_config", return_value=mock_config), - ): - client = BigQueryClient(project_id="explicit-project") - bq_cls.assert_called_once_with(project="explicit-project") - assert client.project_id == "explicit-project" - - -# --------------------------------------------------------------------------- -# 3. get_table_metadata fetches and caches metadata correctly -# --------------------------------------------------------------------------- - -class TestGetTableMetadata: - def test_fetches_metadata_from_bigquery(self, client, mock_bq_client): - """get_table_metadata calls client.get_table and returns correct dict.""" - table_id = "proj.dataset.orders" - schema = [ - _make_bq_field("order_id", "INTEGER"), - _make_bq_field("customer_name", "STRING", description="Full name"), - _make_bq_field("created_at", "TIMESTAMP"), - ] - table_ref = _make_table_ref( - table_id=table_id, - schema=schema, - num_rows=5000, - num_bytes=120000, - ) - mock_bq_client.get_table.return_value = table_ref - - metadata = client.get_table_metadata(table_id, use_cache=False) - - mock_bq_client.get_table.assert_called_once_with(table_id) - assert metadata["table_id"] == table_id - assert metadata["name"] == "orders" - assert metadata["dataset"] == "dataset" - assert metadata["project"] == "proj" - assert metadata["columns"] == ["order_id", "customer_name", "created_at"] - assert metadata["column_types"]["order_id"] == "INTEGER" - assert metadata["column_types"]["customer_name"] == "STRING" - assert metadata["column_types"]["created_at"] == "TIMESTAMP" - assert metadata["column_descriptions"]["customer_name"] == "Full name" - assert "order_id" not in metadata["column_descriptions"] - assert metadata["row_count"] == 5000 - assert metadata["size_bytes"] == 120000 - assert "_cached_at" in metadata - - def test_caches_metadata_in_memory(self, client, mock_bq_client): - """After first fetch, metadata is stored in the in-memory cache.""" - table_id = "proj.dataset.tbl" - table_ref = _make_table_ref(table_id=table_id) - mock_bq_client.get_table.return_value = table_ref - - client.get_table_metadata(table_id, use_cache=False) - - assert table_id in client.metadata_cache - assert client.metadata_cache[table_id]["table_id"] == table_id - - def test_captures_partitioning_info(self, client, mock_bq_client): - """Partitioning metadata is captured when table is partitioned.""" - table_id = "proj.dataset.events" - partition = MagicMock() - partition.type_ = "DAY" - partition.field = "event_date" - partition.expiration_ms = 7776000000 - - table_ref = _make_table_ref(table_id=table_id, time_partitioning=partition) - mock_bq_client.get_table.return_value = table_ref - - metadata = client.get_table_metadata(table_id, use_cache=False) - - assert metadata["partitioning"] is not None - assert metadata["partitioning"]["type"] == "DAY" - assert metadata["partitioning"]["field"] == "event_date" - assert metadata["partitioning"]["expiration_ms"] == 7776000000 - - def test_no_partitioning_when_absent(self, client, mock_bq_client): - """Partitioning is None when table has no partitioning.""" - table_id = "proj.dataset.simple" - table_ref = _make_table_ref(table_id=table_id, time_partitioning=None) - mock_bq_client.get_table.return_value = table_ref - - metadata = client.get_table_metadata(table_id, use_cache=False) - assert metadata["partitioning"] is None - - # ------------------------------------------------------------------- - # 4. get_table_metadata uses cache when available (within TTL) - # ------------------------------------------------------------------- - - def test_uses_cache_within_ttl(self, client, mock_bq_client): - """When cache is fresh (within TTL), BQ API is not called again.""" - table_id = "proj.dataset.cached_tbl" - now = datetime.now() - client.metadata_cache[table_id] = { - "table_id": table_id, - "columns": ["a", "b"], - "column_types": {"a": "STRING", "b": "INTEGER"}, - "_cached_at": now.isoformat(), - } - - result = client.get_table_metadata(table_id, use_cache=True, cache_ttl_hours=24) - - mock_bq_client.get_table.assert_not_called() - assert result["table_id"] == table_id - assert result["columns"] == ["a", "b"] - - def test_refetches_when_cache_expired(self, client, mock_bq_client): - """When cache is older than TTL, metadata is re-fetched from BQ.""" - table_id = "proj.dataset.stale_tbl" - old_time = (datetime.now() - timedelta(hours=48)).isoformat() - client.metadata_cache[table_id] = { - "table_id": table_id, - "columns": ["old_col"], - "column_types": {"old_col": "STRING"}, - "_cached_at": old_time, - } - - table_ref = _make_table_ref( - table_id=table_id, - schema=[_make_bq_field("new_col", "INTEGER")], - ) - mock_bq_client.get_table.return_value = table_ref - - result = client.get_table_metadata(table_id, use_cache=True, cache_ttl_hours=24) - - mock_bq_client.get_table.assert_called_once_with(table_id) - assert result["columns"] == ["new_col"] - - def test_bypasses_cache_when_use_cache_false(self, client, mock_bq_client): - """When use_cache=False, always fetches from BQ even if cache is fresh.""" - table_id = "proj.dataset.force_fetch" - client.metadata_cache[table_id] = { - "table_id": table_id, - "columns": ["cached"], - "column_types": {"cached": "STRING"}, - "_cached_at": datetime.now().isoformat(), - } - - table_ref = _make_table_ref( - table_id=table_id, - schema=[_make_bq_field("fresh", "INTEGER")], - ) - mock_bq_client.get_table.return_value = table_ref - - result = client.get_table_metadata(table_id, use_cache=False) - mock_bq_client.get_table.assert_called_once() - assert result["columns"] == ["fresh"] - - -# --------------------------------------------------------------------------- -# 5. get_pyarrow_schema builds correct schema from BQ types -# --------------------------------------------------------------------------- - -class TestGetPyarrowSchema: - def test_builds_correct_schema(self, client, mock_bq_client): - """Schema maps BQ types to correct PyArrow types.""" - table_id = "proj.dataset.typed_tbl" - schema = [ - _make_bq_field("id", "INT64"), - _make_bq_field("name", "STRING"), - _make_bq_field("price", "FLOAT64"), - _make_bq_field("active", "BOOLEAN"), - _make_bq_field("created", "DATE"), - _make_bq_field("updated_at", "TIMESTAMP"), - ] - table_ref = _make_table_ref(table_id=table_id, schema=schema) - mock_bq_client.get_table.return_value = table_ref - - pa_schema = client.get_pyarrow_schema(table_id) - - assert pa_schema is not None - assert pa_schema.field("id").type == pa.int64() - assert pa_schema.field("name").type == pa.string() - assert pa_schema.field("price").type == pa.float64() - assert pa_schema.field("active").type == pa.bool_() - assert pa_schema.field("created").type == pa.date32() - assert pa_schema.field("updated_at").type == pa.timestamp("us", tz="UTC") - - def test_returns_none_when_no_column_types(self, client): - """Returns None when metadata has no column_types.""" - table_id = "proj.dataset.empty_schema" - client.metadata_cache[table_id] = { - "table_id": table_id, - "columns": [], - "column_types": {}, - "_cached_at": datetime.now().isoformat(), - } - - result = client.get_pyarrow_schema(table_id) - assert result is None - - def test_unknown_type_falls_back_to_string(self, client, mock_bq_client): - """Unknown BQ types default to pa.string() in the schema.""" - table_id = "proj.dataset.exotic_types" - schema = [_make_bq_field("exotic_col", "SOME_UNKNOWN_TYPE")] - table_ref = _make_table_ref(table_id=table_id, schema=schema) - mock_bq_client.get_table.return_value = table_ref - - pa_schema = client.get_pyarrow_schema(table_id) - assert pa_schema.field("exotic_col").type == pa.string() - - -# --------------------------------------------------------------------------- -# 6. get_date_columns returns only DATE columns -# --------------------------------------------------------------------------- - -class TestGetDateColumns: - def test_returns_only_date_columns(self, client, mock_bq_client): - """Only columns with BQ type DATE are returned.""" - table_id = "proj.dataset.mixed_dates" - schema = [ - _make_bq_field("event_date", "DATE"), - _make_bq_field("created_at", "TIMESTAMP"), - _make_bq_field("name", "STRING"), - _make_bq_field("birth_date", "DATE"), - _make_bq_field("updated_ts", "DATETIME"), - ] - table_ref = _make_table_ref(table_id=table_id, schema=schema) - mock_bq_client.get_table.return_value = table_ref - - date_cols = client.get_date_columns(table_id) - assert sorted(date_cols) == ["birth_date", "event_date"] - - def test_returns_empty_when_no_date_columns(self, client, mock_bq_client): - """Returns empty list when no DATE columns exist.""" - table_id = "proj.dataset.no_dates" - schema = [ - _make_bq_field("id", "INTEGER"), - _make_bq_field("ts", "TIMESTAMP"), - ] - table_ref = _make_table_ref(table_id=table_id, schema=schema) - mock_bq_client.get_table.return_value = table_ref - - date_cols = client.get_date_columns(table_id) - assert date_cols == [] - - -# --------------------------------------------------------------------------- -# 7. query_to_arrow executes SQL and returns PyArrow table -# --------------------------------------------------------------------------- - -class TestQueryToArrow: - def test_executes_query_and_returns_arrow(self, client, mock_bq_client): - """query_to_arrow passes SQL to BQ and returns the arrow result.""" - expected_table = pa.table({"col1": [1, 2, 3]}) - mock_job = MagicMock() - mock_job.to_arrow.return_value = expected_table - mock_bq_client.query.return_value = mock_job - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = MagicMock(query_parameters=None) - client.client = mock_bq_client - - result = client.query_to_arrow("SELECT * FROM `proj.dataset.tbl`") - - mock_bq_client.query.assert_called_once() - call_args = mock_bq_client.query.call_args - assert call_args[0][0] == "SELECT * FROM `proj.dataset.tbl`" - assert result.equals(expected_table) - - def test_passes_query_parameters(self, client, mock_bq_client): - """query_to_arrow forwards BQ query parameters in job config.""" - expected_table = pa.table({"col1": [10]}) - mock_job = MagicMock() - mock_job.to_arrow.return_value = expected_table - mock_bq_client.query.return_value = mock_job - - mock_job_config = MagicMock() - params = [MagicMock()] # Mock ScalarQueryParameter - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = mock_job_config - client.client = mock_bq_client - - client.query_to_arrow("SELECT 1 WHERE x > @val", params=params) - - # Verify params were set on the job config - assert mock_job_config.query_parameters == params - - def test_no_params_does_not_set_query_parameters(self, client, mock_bq_client): - """When no params given, query_parameters is not set on job config.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"x": [1]}) - mock_bq_client.query.return_value = mock_job - - mock_job_config = MagicMock(spec=[]) - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = mock_job_config - client.client = mock_bq_client - - client.query_to_arrow("SELECT 1") - - # query_parameters should not have been set - assert not hasattr(mock_job_config, "query_parameters") or not getattr( - mock_job_config, "query_parameters", None - ) - - -# --------------------------------------------------------------------------- -# 8. read_table builds correct SQL query -# --------------------------------------------------------------------------- - -class TestReadTable: - def test_full_table_select_all(self, client, mock_bq_client): - """read_table with no columns or filter generates SELECT *.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"a": [1]}) - mock_bq_client.query.return_value = mock_job - - client.read_table("proj.dataset.tbl") - - sql = mock_bq_client.query.call_args[0][0] - assert "SELECT *" in sql - assert "`proj.dataset.tbl`" in sql - assert "WHERE" not in sql - - def test_select_specific_columns(self, client, mock_bq_client): - """read_table with columns list generates SELECT with backtick-quoted names.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"a": [1]}) - mock_bq_client.query.return_value = mock_job - - client.read_table("proj.dataset.tbl", columns=["col_a", "col_b"]) - - sql = mock_bq_client.query.call_args[0][0] - assert "`col_a`" in sql - assert "`col_b`" in sql - assert "*" not in sql - - def test_with_row_filter(self, client, mock_bq_client): - """read_table with row_filter appends WHERE clause.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"a": [1]}) - mock_bq_client.query.return_value = mock_job - - client.read_table("proj.dataset.tbl", row_filter="status = 'active'") - - sql = mock_bq_client.query.call_args[0][0] - assert "WHERE status = 'active'" in sql - - def test_columns_and_filter_combined(self, client, mock_bq_client): - """read_table with both columns and filter generates correct SQL.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"x": [1]}) - mock_bq_client.query.return_value = mock_job - - client.read_table( - "proj.dataset.tbl", - columns=["id", "name"], - row_filter="id > 100", - ) - - sql = mock_bq_client.query.call_args[0][0] - assert "`id`, `name`" in sql - assert "WHERE id > 100" in sql - assert "`proj.dataset.tbl`" in sql - - -# --------------------------------------------------------------------------- -# 9. read_table_incremental builds parameterized WHERE clause -# --------------------------------------------------------------------------- - -class TestReadTableIncremental: - def test_incremental_query_structure(self, client, mock_bq_client): - """read_table_incremental builds WHERE col > @since_value with params.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"a": [1]}) - mock_bq_client.query.return_value = mock_job - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = MagicMock() - mock_param = MagicMock() - mock_bq_module.ScalarQueryParameter.return_value = mock_param - # Re-assign the client's bq client (the fixture already set it up) - client.client = mock_bq_client - - client.read_table_incremental( - table_id="proj.dataset.events", - incremental_column="updated_at", - since_value="2025-01-01T00:00:00Z", - ) - - sql = mock_bq_client.query.call_args[0][0] - assert "SELECT *" in sql - assert "`proj.dataset.events`" in sql - assert "`updated_at` > @since_value" in sql - - # Verify ScalarQueryParameter was constructed correctly - mock_bq_module.ScalarQueryParameter.assert_called_once_with( - "since_value", "TIMESTAMP", "2025-01-01T00:00:00Z" - ) - - def test_incremental_with_columns(self, client, mock_bq_client): - """read_table_incremental with columns list selects specific columns.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"a": [1]}) - mock_bq_client.query.return_value = mock_job - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = MagicMock() - mock_bq_module.ScalarQueryParameter.return_value = MagicMock() - client.client = mock_bq_client - - client.read_table_incremental( - table_id="proj.dataset.events", - incremental_column="updated_at", - since_value="2025-01-01T00:00:00Z", - columns=["id", "name"], - ) - - sql = mock_bq_client.query.call_args[0][0] - assert "`id`, `name`" in sql - assert "*" not in sql - - -# --------------------------------------------------------------------------- -# 10. read_table_partitioned builds correct range query -# --------------------------------------------------------------------------- - -class TestReadTablePartitioned: - def test_partitioned_start_only(self, client, mock_bq_client): - """With only start, generates >= @start_value without end clause.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"a": [1]}) - mock_bq_client.query.return_value = mock_job - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = MagicMock() - mock_bq_module.ScalarQueryParameter.return_value = MagicMock() - client.client = mock_bq_client - - client.read_table_partitioned( - table_id="proj.dataset.events", - partition_column="event_date", - start="2025-01-01", - ) - - sql = mock_bq_client.query.call_args[0][0] - assert "`event_date` >= @start_value" in sql - assert "@end_value" not in sql - - # Only start_value parameter created - assert mock_bq_module.ScalarQueryParameter.call_count == 1 - mock_bq_module.ScalarQueryParameter.assert_called_with( - "start_value", "TIMESTAMP", "2025-01-01" - ) - - def test_partitioned_start_and_end(self, client, mock_bq_client): - """With start and end, generates >= @start_value AND < @end_value.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"a": [1]}) - mock_bq_client.query.return_value = mock_job - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = MagicMock() - mock_bq_module.ScalarQueryParameter.return_value = MagicMock() - client.client = mock_bq_client - - client.read_table_partitioned( - table_id="proj.dataset.events", - partition_column="event_date", - start="2025-01-01", - end="2025-06-01", - ) - - sql = mock_bq_client.query.call_args[0][0] - assert "`event_date` >= @start_value" in sql - assert "`event_date` < @end_value" in sql - - # Both start_value and end_value parameters created - assert mock_bq_module.ScalarQueryParameter.call_count == 2 - calls = mock_bq_module.ScalarQueryParameter.call_args_list - assert calls[0].args == ("start_value", "TIMESTAMP", "2025-01-01") - assert calls[1].args == ("end_value", "TIMESTAMP", "2025-06-01") - - def test_partitioned_with_columns(self, client, mock_bq_client): - """read_table_partitioned with columns selects specific columns.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"a": [1]}) - mock_bq_client.query.return_value = mock_job - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = MagicMock() - mock_bq_module.ScalarQueryParameter.return_value = MagicMock() - client.client = mock_bq_client - - client.read_table_partitioned( - table_id="proj.dataset.events", - partition_column="event_date", - start="2025-01-01", - columns=["id", "event_date", "value"], - ) - - sql = mock_bq_client.query.call_args[0][0] - assert "`id`, `event_date`, `value`" in sql - assert "*" not in sql - - -# --------------------------------------------------------------------------- -# 11. test_connection returns True on success, False on failure -# --------------------------------------------------------------------------- - -class TestTestConnection: - def test_returns_true_on_success(self, client, mock_bq_client): - """test_connection returns True when SELECT 1 query succeeds.""" - mock_job = MagicMock() - mock_job.result.return_value = iter([(1,)]) - mock_bq_client.query.return_value = mock_job - - assert client.test_connection() is True - mock_bq_client.query.assert_called_once_with("SELECT 1") - - def test_returns_false_on_failure(self, client, mock_bq_client): - """test_connection returns False when the query raises an exception.""" - mock_bq_client.query.side_effect = Exception("Connection refused") - - assert client.test_connection() is False - - def test_returns_false_when_result_fails(self, client, mock_bq_client): - """test_connection returns False when result iteration fails.""" - mock_job = MagicMock() - mock_job.result.side_effect = Exception("Timeout") - mock_bq_client.query.return_value = mock_job - - assert client.test_connection() is False - - -# --------------------------------------------------------------------------- -# 12. Type mapping completeness (all BQ types have PyArrow mapping) -# --------------------------------------------------------------------------- - -class TestTypeMapping: - # All standard BigQuery types that should be mapped - EXPECTED_BQ_TYPES = [ - "STRING", "BYTES", "INTEGER", "INT64", - "FLOAT", "FLOAT64", "NUMERIC", "BIGNUMERIC", - "BOOLEAN", "BOOL", - "TIMESTAMP", "DATE", "TIME", "DATETIME", - "GEOGRAPHY", "JSON", - "STRUCT", "RECORD", "ARRAY", - ] - - def test_all_standard_bq_types_are_mapped(self): - """Every standard BigQuery type has an entry in BIGQUERY_TO_PYARROW_TYPES.""" - for bq_type in self.EXPECTED_BQ_TYPES: - assert bq_type in BIGQUERY_TO_PYARROW_TYPES, ( - f"Missing PyArrow mapping for BQ type: {bq_type}" - ) - - def test_all_mappings_produce_valid_pyarrow_types(self): - """Every mapped value is a valid PyArrow DataType.""" - for bq_type, pa_type in BIGQUERY_TO_PYARROW_TYPES.items(): - assert isinstance(pa_type, pa.DataType), ( - f"BQ type {bq_type} maps to non-DataType: {pa_type!r}" - ) - - def test_integer_types_map_to_int64(self): - """Both INTEGER and INT64 map to pa.int64().""" - assert BIGQUERY_TO_PYARROW_TYPES["INTEGER"] == pa.int64() - assert BIGQUERY_TO_PYARROW_TYPES["INT64"] == pa.int64() - - def test_float_types_map_to_float64(self): - """FLOAT, FLOAT64, NUMERIC, BIGNUMERIC all map to pa.float64().""" - for t in ["FLOAT", "FLOAT64", "NUMERIC", "BIGNUMERIC"]: - assert BIGQUERY_TO_PYARROW_TYPES[t] == pa.float64() - - def test_boolean_types_map_to_bool(self): - """Both BOOLEAN and BOOL map to pa.bool_().""" - assert BIGQUERY_TO_PYARROW_TYPES["BOOLEAN"] == pa.bool_() - assert BIGQUERY_TO_PYARROW_TYPES["BOOL"] == pa.bool_() - - def test_date_maps_to_date32(self): - """DATE maps to pa.date32().""" - assert BIGQUERY_TO_PYARROW_TYPES["DATE"] == pa.date32() - - def test_timestamp_has_utc_timezone(self): - """TIMESTAMP maps to pa.timestamp with UTC timezone.""" - ts_type = BIGQUERY_TO_PYARROW_TYPES["TIMESTAMP"] - assert ts_type == pa.timestamp("us", tz="UTC") - - def test_datetime_has_no_timezone(self): - """DATETIME maps to pa.timestamp without timezone.""" - dt_type = BIGQUERY_TO_PYARROW_TYPES["DATETIME"] - assert dt_type == pa.timestamp("us") - - def test_complex_types_map_to_string(self): - """STRUCT, RECORD, ARRAY, GEOGRAPHY, JSON all serialize as string.""" - for t in ["STRUCT", "RECORD", "ARRAY", "GEOGRAPHY", "JSON"]: - assert BIGQUERY_TO_PYARROW_TYPES[t] == pa.string() - - -# --------------------------------------------------------------------------- -# 13. Metadata cache save/load from disk -# --------------------------------------------------------------------------- - -class TestMetadataCachePersistence: - def test_save_and_load_cache(self, tmp_path): - """Metadata cache is persisted to disk and reloaded on new client init.""" - metadata_dir = tmp_path / "metadata" - metadata_dir.mkdir(parents=True, exist_ok=True) - cache_file = metadata_dir / "bq_table_metadata.json" - - mock_config = MagicMock() - mock_config.get_metadata_path.return_value = metadata_dir - - # First client: fetch metadata and save to cache - mock_bq = MagicMock() - table_id = "proj.ds.tbl" - schema = [_make_bq_field("col1", "STRING")] - table_ref = _make_table_ref(table_id=table_id, schema=schema) - mock_bq.get_table.return_value = table_ref - - with ( - patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq), - patch("connectors.bigquery.client.get_config", return_value=mock_config), - patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}), - ): - client1 = BigQueryClient() - client1.get_table_metadata(table_id, use_cache=False) - - # Verify the cache file was written - assert cache_file.exists() - saved_data = json.loads(cache_file.read_text()) - assert table_id in saved_data - assert saved_data[table_id]["columns"] == ["col1"] - - # Second client: loads cache from disk on init - mock_bq2 = MagicMock() - with ( - patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq2), - patch("connectors.bigquery.client.get_config", return_value=mock_config), - patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}), - ): - client2 = BigQueryClient() - - assert table_id in client2.metadata_cache - assert client2.metadata_cache[table_id]["columns"] == ["col1"] - - def test_load_handles_corrupt_cache_file(self, tmp_path): - """Client handles corrupt cache JSON gracefully without crashing.""" - metadata_dir = tmp_path / "metadata" - metadata_dir.mkdir(parents=True, exist_ok=True) - cache_file = metadata_dir / "bq_table_metadata.json" - cache_file.write_text("{corrupt json!!!") - - mock_config = MagicMock() - mock_config.get_metadata_path.return_value = metadata_dir - - mock_bq = MagicMock() - with ( - patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq), - patch("connectors.bigquery.client.get_config", return_value=mock_config), - patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}), - ): - client = BigQueryClient() - - # Cache should be empty after corrupt file - assert client.metadata_cache == {} - - def test_load_handles_missing_cache_file(self, tmp_path): - """Client initializes with empty cache when no cache file exists.""" - metadata_dir = tmp_path / "metadata" - metadata_dir.mkdir(parents=True, exist_ok=True) - # No cache file created - - mock_config = MagicMock() - mock_config.get_metadata_path.return_value = metadata_dir - - mock_bq = MagicMock() - with ( - patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq), - patch("connectors.bigquery.client.get_config", return_value=mock_config), - patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}), - ): - client = BigQueryClient() - - assert client.metadata_cache == {} - - def test_save_creates_parent_directories(self, tmp_path): - """_save_metadata_cache creates parent directories if they do not exist.""" - # Use a nested path that does not yet exist - metadata_dir = tmp_path / "deep" / "nested" / "metadata" - # Do NOT create directories upfront - - mock_config = MagicMock() - mock_config.get_metadata_path.return_value = metadata_dir - - mock_bq = MagicMock() - table_id = "proj.ds.tbl" - schema = [_make_bq_field("x", "INTEGER")] - table_ref = _make_table_ref(table_id=table_id, schema=schema) - mock_bq.get_table.return_value = table_ref - - with ( - patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq), - patch("connectors.bigquery.client.get_config", return_value=mock_config), - patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}), - ): - client = BigQueryClient() - client.get_table_metadata(table_id, use_cache=False) - - cache_file = metadata_dir / "bq_table_metadata.json" - assert cache_file.exists() - - -# --------------------------------------------------------------------------- -# Factory function -# --------------------------------------------------------------------------- - -class TestCreateClient: - def test_create_client_returns_bigquery_client(self, mock_config): - """create_client() factory returns a BigQueryClient instance.""" - mock_bq = MagicMock() - with ( - patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq), - patch("connectors.bigquery.client.get_config", return_value=mock_config), - patch.dict("os.environ", {"BIGQUERY_PROJECT": "factory-project"}), - ): - result = create_client() - - assert isinstance(result, BigQueryClient) - assert result.project_id == "factory-project" - - -# --------------------------------------------------------------------------- -# 14. read_table_partitioned_streaming yields RecordBatches -# --------------------------------------------------------------------------- - -class TestReadTablePartitionedStreaming: - def test_streaming_yields_batches(self, client, mock_bq_client): - """read_table_partitioned_streaming yields RecordBatches, not a Table.""" - batch1 = pa.record_batch({"a": [1, 2]}) - batch2 = pa.record_batch({"a": [3, 4]}) - - mock_query_job = MagicMock() - mock_row_iter = MagicMock() - mock_row_iter.to_arrow_iterable.return_value = iter([batch1, batch2]) - mock_query_job.result.return_value = mock_row_iter - mock_bq_client.query.return_value = mock_query_job - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = MagicMock() - mock_bq_module.ScalarQueryParameter.return_value = MagicMock() - client.client = mock_bq_client - client.bqstorage_client = None # disable Storage API for simplicity - - batches = list(client.read_table_partitioned_streaming( - table_id="proj.dataset.events", - partition_column="event_date", - start="2025-01-01", - )) - - assert len(batches) == 2 - assert isinstance(batches[0], pa.RecordBatch) - assert isinstance(batches[1], pa.RecordBatch) - - def test_streaming_with_date_column_type(self, client, mock_bq_client): - """With column_type='DATE', ScalarQueryParameter uses 'DATE' type.""" - batch = pa.record_batch({"a": [1]}) - - mock_query_job = MagicMock() - mock_row_iter = MagicMock() - mock_row_iter.to_arrow_iterable.return_value = iter([batch]) - mock_query_job.result.return_value = mock_row_iter - mock_bq_client.query.return_value = mock_query_job - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = MagicMock() - mock_bq_module.ScalarQueryParameter.return_value = MagicMock() - client.client = mock_bq_client - client.bqstorage_client = None - - list(client.read_table_partitioned_streaming( - table_id="proj.dataset.events", - partition_column="event_date", - start="2025-01-01", - column_type="DATE", - )) - - # Verify ScalarQueryParameter was called with "DATE" type - mock_bq_module.ScalarQueryParameter.assert_called_once_with( - "start_value", "DATE", "2025-01-01" - ) - - def test_streaming_start_and_end(self, client, mock_bq_client): - """With start and end, both params are created with correct column_type.""" - batch = pa.record_batch({"a": [1]}) - - mock_query_job = MagicMock() - mock_row_iter = MagicMock() - mock_row_iter.to_arrow_iterable.return_value = iter([batch]) - mock_query_job.result.return_value = mock_row_iter - mock_bq_client.query.return_value = mock_query_job - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = MagicMock() - mock_bq_module.ScalarQueryParameter.return_value = MagicMock() - client.client = mock_bq_client - client.bqstorage_client = None - - list(client.read_table_partitioned_streaming( - table_id="proj.dataset.events", - partition_column="event_date", - start="2025-01-01", - end="2025-06-01", - column_type="DATE", - )) - - sql = mock_bq_client.query.call_args[0][0] - assert "`event_date` >= @start_value" in sql - assert "`event_date` < @end_value" in sql - - # Both parameters created with "DATE" type - assert mock_bq_module.ScalarQueryParameter.call_count == 2 - calls = mock_bq_module.ScalarQueryParameter.call_args_list - assert calls[0].args == ("start_value", "DATE", "2025-01-01") - assert calls[1].args == ("end_value", "DATE", "2025-06-01") - - -# --------------------------------------------------------------------------- -# 15. read_table_partitioned column_type parameter -# --------------------------------------------------------------------------- - -class TestReadTablePartitionedColumnType: - def test_date_column_type(self, client, mock_bq_client): - """read_table_partitioned with column_type='DATE' creates DATE params.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"a": [1]}) - mock_bq_client.query.return_value = mock_job - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = MagicMock() - mock_bq_module.ScalarQueryParameter.return_value = MagicMock() - client.client = mock_bq_client - - client.read_table_partitioned( - table_id="proj.dataset.events", - partition_column="event_date", - start="2025-01-01", - end="2025-06-01", - column_type="DATE", - ) - - # Both parameters should use "DATE" type - assert mock_bq_module.ScalarQueryParameter.call_count == 2 - calls = mock_bq_module.ScalarQueryParameter.call_args_list - assert calls[0].args == ("start_value", "DATE", "2025-01-01") - assert calls[1].args == ("end_value", "DATE", "2025-06-01") - - def test_default_column_type_is_timestamp(self, client, mock_bq_client): - """Default column_type is TIMESTAMP when not specified.""" - mock_job = MagicMock() - mock_job.to_arrow.return_value = pa.table({"a": [1]}) - mock_bq_client.query.return_value = mock_job - - with patch("connectors.bigquery.client.bigquery") as mock_bq_module: - mock_bq_module.QueryJobConfig.return_value = MagicMock() - mock_bq_module.ScalarQueryParameter.return_value = MagicMock() - client.client = mock_bq_client - - client.read_table_partitioned( - table_id="proj.dataset.events", - partition_column="created_at", - start="2025-01-01T00:00:00Z", - ) - - # Should default to "TIMESTAMP" - mock_bq_module.ScalarQueryParameter.assert_called_once_with( - "start_value", "TIMESTAMP", "2025-01-01T00:00:00Z" - ) diff --git a/tests/test_catalog_export.py b/tests/test_catalog_export.py index 082dcb5..9d9d04d 100644 --- a/tests/test_catalog_export.py +++ b/tests/test_catalog_export.py @@ -93,17 +93,15 @@ def mock_client(): @pytest.fixture -def mock_config(): - """Return a mock Config with a single table.""" - cfg = MagicMock() - cfg.tables = [ - FakeTableConfig( - name="order_economics", - id="prj.dataset.order_economics", - catalog_fqn="bigquery.prj.dataset.order_economics", - ) +def mock_tables(): + """Return a list of table dicts matching TableRegistryRepository.list_all() format.""" + return [ + { + "id": "prj.dataset.order_economics", + "name": "order_economics", + "catalog_fqn": "bigquery.prj.dataset.order_economics", + } ] - return cfg CATALOG_URL = "https://catalog.example.com" @@ -351,11 +349,11 @@ class TestExportMetrics: class TestExportTables: def test_export_tables_writes_files( - self, tmp_path: Path, mock_client, mock_config + self, tmp_path: Path, mock_client, mock_tables ): """Creates table YAML with columns.""" docs = tmp_path / "docs" - count = export_tables(mock_client, mock_config, docs, CATALOG_URL) + count = export_tables(mock_client, mock_tables, docs, CATALOG_URL) assert count == 1 @@ -374,21 +372,12 @@ class TestExportTables: assert parsed["columns"][0]["name"] == "order_id" def test_export_tables_handles_api_error( - self, tmp_path: Path, mock_client, mock_config + self, tmp_path: Path, mock_client ): """Continues on per-table errors, exports remaining tables.""" - # Two tables: first will fail, second succeeds - mock_config.tables = [ - FakeTableConfig( - name="broken_table", - id="prj.dataset.broken", - catalog_fqn="bigquery.prj.dataset.broken", - ), - FakeTableConfig( - name="good_table", - id="prj.dataset.good", - catalog_fqn="bigquery.prj.dataset.good", - ), + tables = [ + {"id": "prj.dataset.broken", "name": "broken_table", "catalog_fqn": "bigquery.prj.dataset.broken"}, + {"id": "prj.dataset.good", "name": "good_table", "catalog_fqn": "bigquery.prj.dataset.good"}, ] def side_effect(fqn): @@ -399,7 +388,7 @@ class TestExportTables: mock_client.get_table.side_effect = side_effect docs = tmp_path / "docs" - count = export_tables(mock_client, mock_config, docs, CATALOG_URL) + count = export_tables(mock_client, tables, docs, CATALOG_URL) # Only the good table should succeed assert count == 1 @@ -407,11 +396,11 @@ class TestExportTables: assert (docs / "tables" / "good_table.yml").exists() def test_export_tables_uses_catalog_fqn( - self, tmp_path: Path, mock_client, mock_config + self, tmp_path: Path, mock_client, mock_tables ): """Uses explicit catalog_fqn when set on table config.""" docs = tmp_path / "docs" - export_tables(mock_client, mock_config, docs, CATALOG_URL) + export_tables(mock_client, mock_tables, docs, CATALOG_URL) mock_client.get_table.assert_called_once_with( "bigquery.prj.dataset.order_economics" @@ -421,17 +410,12 @@ class TestExportTables: self, tmp_path: Path, mock_client ): """When catalog_fqn is None, derives FQN as 'bigquery.{id}'.""" - cfg = MagicMock() - cfg.tables = [ - FakeTableConfig( - name="my_table", - id="project.dataset.my_table", - catalog_fqn=None, - ) + tables = [ + {"id": "project.dataset.my_table", "name": "my_table"}, ] docs = tmp_path / "docs" - export_tables(mock_client, cfg, docs, CATALOG_URL) + export_tables(mock_client, tables, docs, CATALOG_URL) mock_client.get_table.assert_called_once_with( "bigquery.project.dataset.my_table" diff --git a/tests/test_config_bq_entity_type.py b/tests/test_config_bq_entity_type.py deleted file mode 100644 index 6640074..0000000 --- a/tests/test_config_bq_entity_type.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Tests for TableConfig.bq_entity_type field validation.""" - -import pytest - -from src.config import TableConfig - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- -def _make_table(**overrides) -> TableConfig: - """Create a TableConfig with sensible defaults, applying overrides.""" - defaults = dict( - id="test.dataset.table", - name="test_table", - description="Test", - primary_key="id", - sync_strategy="full_refresh", - ) - defaults.update(overrides) - return TableConfig(**defaults) - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- -class TestBqEntityTypeDefault: - def test_default_is_view(self): - table = _make_table() - assert table.bq_entity_type == "view" - - -class TestBqEntityTypeValidValues: - @pytest.mark.parametrize("entity_type", ["view", "table"]) - def test_valid_bq_entity_type(self, entity_type): - table = _make_table(bq_entity_type=entity_type) - assert table.bq_entity_type == entity_type - - -class TestBqEntityTypeInvalid: - @pytest.mark.parametrize("bad_type", ["VIEW", "physical", "", "tables", "materialized"]) - def test_invalid_bq_entity_type_raises(self, bad_type): - with pytest.raises(ValueError, match="Invalid bq_entity_type"): - _make_table(bq_entity_type=bad_type) - - -class TestBqEntityTypeFromKwarg: - def test_kwarg_sets_bq_entity_type(self): - """Simulate what _parse_data_description does: pass bq_entity_type as kwarg.""" - table = TableConfig( - id="proj.dataset.orders", - name="orders", - description="Order data", - primary_key="order_id", - sync_strategy="full_refresh", - bq_entity_type="table", - ) - assert table.bq_entity_type == "table" - - def test_kwarg_default_when_omitted(self): - """When YAML has no bq_entity_type, _parse_data_description passes 'view'.""" - table = TableConfig( - id="proj.dataset.orders", - name="orders", - description="Order data", - primary_key="order_id", - sync_strategy="full_refresh", - ) - assert table.bq_entity_type == "view" diff --git a/tests/test_config_query_mode.py b/tests/test_config_query_mode.py deleted file mode 100644 index 6370045..0000000 --- a/tests/test_config_query_mode.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Tests for TableConfig.query_mode field validation.""" - -import pytest - -from src.config import TableConfig - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- -def _make_table(**overrides) -> TableConfig: - """Create a TableConfig with sensible defaults, applying overrides.""" - defaults = dict( - id="test.dataset.table", - name="test_table", - description="Test", - primary_key="id", - sync_strategy="full_refresh", - ) - defaults.update(overrides) - return TableConfig(**defaults) - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- -class TestQueryModeDefault: - def test_default_is_local(self): - table = _make_table() - assert table.query_mode == "local" - - -class TestQueryModeValidValues: - @pytest.mark.parametrize("mode", ["local", "remote", "hybrid"]) - def test_valid_query_mode(self, mode): - table = _make_table(query_mode=mode) - assert table.query_mode == mode - - -class TestQueryModeInvalid: - @pytest.mark.parametrize("bad_mode", ["invalid", "Local", "REMOTE", "", "sql"]) - def test_invalid_query_mode_raises(self, bad_mode): - with pytest.raises(ValueError, match="Invalid query_mode"): - _make_table(query_mode=bad_mode) - - -class TestQueryModeFromKwarg: - def test_kwarg_sets_query_mode(self): - """Simulate what _parse_data_description does: pass query_mode as kwarg.""" - table = TableConfig( - id="proj.dataset.orders", - name="orders", - description="Order data", - primary_key="order_id", - sync_strategy="full_refresh", - query_mode="remote", - ) - assert table.query_mode == "remote" - - def test_kwarg_default_when_omitted(self): - """When YAML has no query_mode, _parse_data_description passes 'local'.""" - table = TableConfig( - id="proj.dataset.orders", - name="orders", - description="Order data", - primary_key="order_id", - sync_strategy="full_refresh", - ) - assert table.query_mode == "local" diff --git a/tests/test_config_sync_schedule.py b/tests/test_config_sync_schedule.py deleted file mode 100644 index 0a07c4a..0000000 --- a/tests/test_config_sync_schedule.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Tests for TableConfig.sync_schedule field validation.""" - -import pytest - -from src.config import TableConfig - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- -def _make_table(**overrides) -> TableConfig: - """Create a TableConfig with sensible defaults, applying overrides.""" - defaults = dict( - id="test.dataset.table", - name="test_table", - description="Test", - primary_key="id", - sync_strategy="full_refresh", - ) - defaults.update(overrides) - return TableConfig(**defaults) - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- -class TestSyncScheduleDefault: - def test_default_is_none(self): - table = _make_table() - assert table.sync_schedule is None - - -class TestSyncScheduleValidValues: - @pytest.mark.parametrize( - "schedule", - [ - "every 15m", - "every 1h", - "daily 05:00", - "daily 07:00,13:00,18:00", - "daily 00:00,12:00", - ], - ids=[ - "every-15m", - "every-1h", - "daily-single", - "daily-three-times", - "daily-two-times", - ], - ) - def test_valid_schedule_accepted(self, schedule: str): - table = _make_table(sync_schedule=schedule) - assert table.sync_schedule == schedule - - -class TestSyncScheduleEdgeCases: - def test_every_zero_minutes(self): - """every 0m matches the regex -- validation is syntactic, not semantic.""" - table = _make_table(sync_schedule="every 0m") - assert table.sync_schedule == "every 0m" - - def test_daily_2359(self): - table = _make_table(sync_schedule="daily 23:59") - assert table.sync_schedule == "daily 23:59" - - -class TestSyncScheduleInvalid: - @pytest.mark.parametrize( - "bad_schedule", - [ - "daily 07:00,13:00,18:00,", # trailing comma - "daily 7:00", # single-digit hour - "daily", # missing time - "hourly", # unsupported keyword - "weekly", # unsupported keyword - ], - ids=[ - "trailing-comma", - "single-digit-hour", - "daily-no-time", - "hourly-keyword", - "weekly-keyword", - ], - ) - def test_invalid_schedule_raises(self, bad_schedule: str): - with pytest.raises(ValueError, match="Invalid sync_schedule"): - _make_table(sync_schedule=bad_schedule) - - def test_empty_string_treated_as_none(self): - """Empty string is falsy, so validation is skipped (same as None).""" - table = _make_table(sync_schedule="") - assert table.sync_schedule == "" - - def test_daily_25_accepted_by_regex(self): - """25:00 passes regex validation (two digits). Document this behavior.""" - table = _make_table(sync_schedule="daily 25:00") - assert table.sync_schedule == "daily 25:00" - - -class TestSyncScheduleNoneExplicit: - def test_explicit_none_accepted(self): - table = _make_table(sync_schedule=None) - assert table.sync_schedule is None diff --git a/tests/test_data_sync_query_mode.py b/tests/test_data_sync_query_mode.py deleted file mode 100644 index 211b94e..0000000 --- a/tests/test_data_sync_query_mode.py +++ /dev/null @@ -1,228 +0,0 @@ -"""Tests for remote table skipping in DataSyncManager.sync_all(). - -Tables with query_mode == "remote" should be skipped during sync (no local -Parquet file is needed -- queries go directly to BigQuery). Tables with -query_mode "local" or "hybrid" must still be synced normally. -""" - -from unittest.mock import MagicMock, patch - -import pytest - -from src.config import TableConfig -from src.data_sync import DataSyncManager - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _make_table_config( - table_id: str, - name: str, - query_mode: str = "local", -) -> TableConfig: - """Create a minimal TableConfig for testing.""" - return TableConfig( - id=table_id, - name=name, - description=f"Test table {name}", - primary_key="id", - sync_strategy="full_refresh", - query_mode=query_mode, - ) - - -def _successful_sync_result() -> dict: - """Return a fake successful sync result dict.""" - return { - "success": True, - "rows": 100, - "file_size_mb": 0.5, - } - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - -@pytest.fixture -def table_local(): - return _make_table_config("t.local", "local_table", query_mode="local") - - -@pytest.fixture -def table_remote(): - return _make_table_config("t.remote", "remote_table", query_mode="remote") - - -@pytest.fixture -def table_hybrid(): - return _make_table_config("t.hybrid", "hybrid_table", query_mode="hybrid") - - -@pytest.fixture -def all_tables(table_local, table_remote, table_hybrid): - return [table_local, table_remote, table_hybrid] - - -@pytest.fixture -def mock_config(all_tables): - """Return a mock Config whose .tables list contains all three query modes.""" - cfg = MagicMock() - cfg.tables = all_tables - cfg.get_metadata_path.return_value = MagicMock() # Path-like - - def _get_table_config(tid): - return next((t for t in all_tables if t.id == tid), None) - - cfg.get_table_config.side_effect = _get_table_config - return cfg - - -@pytest.fixture -def mock_data_source(): - """Return a mock DataSource that always succeeds.""" - ds = MagicMock() - ds.sync_table.return_value = _successful_sync_result() - return ds - - -@pytest.fixture -def sync_manager(mock_config, mock_data_source): - """Create a DataSyncManager with mocked dependencies.""" - with ( - patch("src.data_sync.get_config", return_value=mock_config), - patch("src.data_sync.create_data_source", return_value=mock_data_source), - patch("src.data_sync.SyncState"), - ): - manager = DataSyncManager() - # Patch out schema generation and auto-profiling (not under test) - manager._generate_schema_yaml = MagicMock() - yield manager - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - -class TestSyncAllRemoteSkipping: - """Verify that sync_all filters out remote tables.""" - - def test_remote_table_not_synced(self, sync_manager, mock_data_source, table_remote): - """Remote table must NOT be passed to data_source.sync_table.""" - sync_manager.sync_all() - - synced_ids = [ - call.args[0].id for call in mock_data_source.sync_table.call_args_list - ] - assert table_remote.id not in synced_ids - - def test_local_table_is_synced(self, sync_manager, mock_data_source, table_local): - """Local table must be synced normally.""" - sync_manager.sync_all() - - synced_ids = [ - call.args[0].id for call in mock_data_source.sync_table.call_args_list - ] - assert table_local.id in synced_ids - - def test_hybrid_table_is_synced(self, sync_manager, mock_data_source, table_hybrid): - """Hybrid table must be synced (needs local parquet for profiling).""" - sync_manager.sync_all() - - synced_ids = [ - call.args[0].id for call in mock_data_source.sync_table.call_args_list - ] - assert table_hybrid.id in synced_ids - - def test_sync_call_count(self, sync_manager, mock_data_source): - """Only local + hybrid tables should result in sync_table calls.""" - sync_manager.sync_all() - - # 3 tables total, 1 remote -> 2 sync calls - assert mock_data_source.sync_table.call_count == 2 - - def test_results_exclude_remote(self, sync_manager, table_remote): - """The results dict must not contain an entry for the remote table.""" - results = sync_manager.sync_all() - - assert table_remote.id not in results - - def test_results_include_local_and_hybrid( - self, sync_manager, table_local, table_hybrid - ): - """Results dict must contain entries for local and hybrid tables.""" - results = sync_manager.sync_all() - - assert table_local.id in results - assert table_hybrid.id in results - - -class TestSyncAllAllRemote: - """Edge case: every table is remote.""" - - def test_no_sync_calls_when_all_remote(self, mock_config, mock_data_source): - remote_only = [ - _make_table_config("t.r1", "remote1", query_mode="remote"), - _make_table_config("t.r2", "remote2", query_mode="remote"), - ] - mock_config.tables = remote_only - - with ( - patch("src.data_sync.get_config", return_value=mock_config), - patch("src.data_sync.create_data_source", return_value=mock_data_source), - patch("src.data_sync.SyncState"), - ): - manager = DataSyncManager() - manager._generate_schema_yaml = MagicMock() - results = manager.sync_all() - - assert mock_data_source.sync_table.call_count == 0 - assert results == {} - - -class TestSyncAllNoRemote: - """Edge case: no remote tables at all -- everything syncs.""" - - def test_all_tables_synced(self, mock_config, mock_data_source): - local_only = [ - _make_table_config("t.l1", "local1", query_mode="local"), - _make_table_config("t.l2", "local2", query_mode="local"), - ] - mock_config.tables = local_only - - with ( - patch("src.data_sync.get_config", return_value=mock_config), - patch("src.data_sync.create_data_source", return_value=mock_data_source), - patch("src.data_sync.SyncState"), - ): - manager = DataSyncManager() - manager._generate_schema_yaml = MagicMock() - results = manager.sync_all() - - assert mock_data_source.sync_table.call_count == 2 - assert "t.l1" in results - assert "t.l2" in results - - -class TestSyncAllWithTableFilter: - """When sync_all receives an explicit table list, remote filtering still applies.""" - - def test_explicit_remote_table_still_skipped( - self, sync_manager, mock_data_source, table_remote - ): - """Even if explicitly listed, a remote table should be skipped.""" - sync_manager.sync_all(tables=[table_remote.id]) - - assert mock_data_source.sync_table.call_count == 0 - - def test_explicit_local_table_synced( - self, sync_manager, mock_data_source, table_local - ): - """An explicitly listed local table should be synced.""" - sync_manager.sync_all(tables=[table_local.id]) - - assert mock_data_source.sync_table.call_count == 1 - synced_id = mock_data_source.sync_table.call_args_list[0].args[0].id - assert synced_id == table_local.id diff --git a/tests/test_openmetadata_enricher.py b/tests/test_openmetadata_enricher.py index 66ec8b9..4a33f71 100644 --- a/tests/test_openmetadata_enricher.py +++ b/tests/test_openmetadata_enricher.py @@ -7,11 +7,11 @@ from datetime import datetime, timedelta from unittest.mock import Mock, patch, MagicMock from dataclasses import dataclass -from src.config import TableConfig from connectors.openmetadata.enricher import ( CatalogEnricher, CatalogTableData, CatalogColumnData, + TableConfig, ) @@ -21,9 +21,6 @@ def sample_table_config(): return TableConfig( id="prj-grp-dataview-prod-1ff9.marketing.roi_datamart_v2", name="roi_datamart_v2", - description="ROI metrics", - primary_key="id", - sync_strategy="full_refresh", ) @@ -111,9 +108,6 @@ def test_enrich_table_disabled(): table_config = TableConfig( id="test.table", name="test", - description="Test", - primary_key="id", - sync_strategy="full_refresh", ) result = enricher.enrich_table(table_config) @@ -145,9 +139,6 @@ def test_enrich_table_cache_hit(): table_config = TableConfig( id="prj-grp-dataview-prod-1ff9.marketing.test", name="test", - description="Test", - primary_key="id", - sync_strategy="full_refresh", ) result = enricher.enrich_table(table_config) @@ -199,9 +190,6 @@ def test_derive_fqn_auto(): table_config = TableConfig( id="prj-grp-dataview-prod-1ff9.marketing.roi_datamart_v2", name="roi_datamart_v2", - description="Test", - primary_key="id", - sync_strategy="full_refresh", ) fqn = enricher._derive_fqn(table_config) @@ -223,9 +211,6 @@ def test_derive_fqn_explicit_override(): table_config = TableConfig( id="prj-grp-dataview-prod-1ff9.marketing.roi_datamart_v2", name="roi_datamart_v2", - description="Test", - primary_key="id", - sync_strategy="full_refresh", ) table_config.catalog_fqn = "bigquery.custom.fqn.override" @@ -390,9 +375,6 @@ def test_enrich_table_http_error_graceful(): table_config = TableConfig( id="test.table", name="test", - description="Test", - primary_key="id", - sync_strategy="full_refresh", ) # Should return None, not raise diff --git a/tests/test_remote_query.py b/tests/test_remote_query.py deleted file mode 100644 index 72b7e8e..0000000 --- a/tests/test_remote_query.py +++ /dev/null @@ -1,779 +0,0 @@ -"""Tests for remote_query module - hybrid local Parquet + remote BigQuery queries. - -Tests cover: -- CLI argument parsing (_parse_register_bq, build_parser) -- Local view setup (_setup_local_views via create_local_views) -- BQ registration with safety checks (_validate_bq_result_size, _estimate_memory_mb, _register_bq_views) -- Output formatting (_print_table, _format_output) -- End-to-end local-only queries (no BQ mocking needed) -""" - -import argparse -import csv -import json -import os -from io import StringIO -from pathlib import Path -from unittest.mock import MagicMock, patch - -import duckdb -import pyarrow as pa -import pyarrow.parquet as pq -import pytest - -from src.remote_query import ( - RemoteQueryError, - _estimate_memory_mb, - _format_output, - _parse_register_bq, - _print_table, - _register_bq_views, - _validate_bq_result_size, - build_parser, - execute_remote_query, -) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - -@pytest.fixture -def tmp_local_project(tmp_path): - """Create a minimal project with docs/data_description.md and parquet files. - - Layout: - tmp_path/ - docs/data_description.md (YAML with local + remote + hybrid tables) - server/parquet/crm_data/orders.parquet - server/parquet/crm_data/products.parquet - - Returns (project_root, data_dir) where data_dir = tmp_path / "server". - """ - docs_dir = tmp_path / "docs" - docs_dir.mkdir() - - data_description = """\ -# Data Description - -```yaml -folder_mapping: - in.c-crm: crm_data - -tables: - - id: "in.c-crm.orders" - name: "orders" - description: "Order data" - primary_key: "order_id" - sync_strategy: "full_refresh" - - - id: "in.c-crm.products" - name: "products" - description: "Product catalog" - primary_key: "product_id" - sync_strategy: "full_refresh" - - - id: "prj-grp-dataview-prod-1ff9.supply.traffic" - name: "traffic" - description: "Remote BQ traffic table" - primary_key: "id" - query_mode: "remote" - - - id: "in.c-crm.inventory" - name: "inventory" - description: "Hybrid inventory" - primary_key: "id" - sync_strategy: "full_refresh" - query_mode: "hybrid" -``` -""" - (docs_dir / "data_description.md").write_text(data_description) - - # Create parquet files for local tables - crm_dir = tmp_path / "server" / "parquet" / "crm_data" - crm_dir.mkdir(parents=True) - - orders_table = pa.table({ - "order_id": [1, 2, 3, 4, 5], - "amount": [10.0, 20.0, 30.0, 40.0, 50.0], - "product_id": [101, 102, 101, 103, 102], - }) - pq.write_table(orders_table, crm_dir / "orders.parquet") - - products_table = pa.table({ - "product_id": [101, 102, 103], - "name": ["Widget", "Gadget", "Doohickey"], - }) - pq.write_table(products_table, crm_dir / "products.parquet") - - # Create parquet for hybrid table - inventory_table = pa.table({ - "id": [1, 2], - "stock": [100, 200], - }) - pq.write_table(inventory_table, crm_dir / "inventory.parquet") - - data_dir = str(tmp_path / "server") - return tmp_path, data_dir - - -@pytest.fixture -def duckdb_conn(): - """Create an in-memory DuckDB connection, closed after test.""" - conn = duckdb.connect(":memory:") - yield conn - conn.close() - - -class _DuckDBConnectionProxy: - """Proxy around DuckDBPyConnection that silently ignores unsupported SET commands. - - DuckDB versions may not support 'statement_timeout'. This proxy catches - CatalogException for SET commands so end-to-end tests work across versions. - The real connection's execute method is read-only, so we wrap it. - """ - - def __init__(self, conn): - object.__setattr__(self, "_conn", conn) - - def execute(self, sql, *args, **kwargs): - if isinstance(sql, str) and sql.strip().upper().startswith("SET "): - try: - return self._conn.execute(sql, *args, **kwargs) - except duckdb.CatalogException: - return None - return self._conn.execute(sql, *args, **kwargs) - - def __getattr__(self, name): - return getattr(self._conn, name) - - -def _patched_duckdb_connect(*args, **kwargs): - """Create a DuckDB connection wrapped in the proxy.""" - conn = duckdb.connect(*args, **kwargs) - return _DuckDBConnectionProxy(conn) - - -# --------------------------------------------------------------------------- -# Tests: CLI argument parsing -# --------------------------------------------------------------------------- - -class TestCLIArgParsing: - """Test _parse_register_bq() and build_parser().""" - - def test_sql_defaults_to_none(self): - """Parser allows --sql to be omitted (for --stdin mode).""" - parser = build_parser() - args = parser.parse_args(["--stdin"]) - assert args.sql is None - assert args.stdin is True - - def test_register_bq_parsing(self): - """'alias=SELECT ...' parses into (alias, sql) tuple.""" - result = _parse_register_bq("traffic=SELECT report_date FROM `proj.ds.table`") - assert result == ("traffic", "SELECT report_date FROM `proj.ds.table`") - - def test_register_bq_invalid_format(self): - """Missing '=' should raise ArgumentTypeError.""" - with pytest.raises(argparse.ArgumentTypeError, match="Invalid --register-bq format"): - _parse_register_bq("no_equals_sign_here") - - def test_register_bq_empty_sql(self): - """Alias with empty SQL after '=' should raise.""" - with pytest.raises(argparse.ArgumentTypeError, match="Empty SQL"): - _parse_register_bq("alias=") - - def test_register_bq_empty_alias(self): - """'=SELECT ...' (empty alias) should raise.""" - with pytest.raises(argparse.ArgumentTypeError, match="Invalid --register-bq format"): - _parse_register_bq("=SELECT 1") - - def test_multiple_register_bq(self): - """Multiple --register-bq args should be collected into a list.""" - parser = build_parser() - args = parser.parse_args([ - "--sql", "SELECT 1", - "--register-bq", "t1=SELECT a FROM x", - "--register-bq", "t2=SELECT b FROM y", - ]) - assert len(args.bq_registrations) == 2 - assert args.bq_registrations[0] == ("t1", "SELECT a FROM x") - assert args.bq_registrations[1] == ("t2", "SELECT b FROM y") - - def test_default_format_is_none(self): - """Default --format should be None (uses config default at runtime).""" - parser = build_parser() - args = parser.parse_args(["--sql", "SELECT 1"]) - assert args.fmt is None - - def test_explicit_format(self): - """Explicit --format should be respected.""" - parser = build_parser() - args = parser.parse_args(["--sql", "SELECT 1", "--format", "csv"]) - assert args.fmt == "csv" - - def test_invalid_format_rejected(self): - """Invalid --format value should cause parser error.""" - parser = build_parser() - with pytest.raises(SystemExit): - parser.parse_args(["--sql", "SELECT 1", "--format", "xml"]) - - def test_no_register_bq_yields_empty_list(self): - """When no --register-bq is provided, bq_registrations defaults to [].""" - parser = build_parser() - args = parser.parse_args(["--sql", "SELECT 1"]) - assert args.bq_registrations == [] - - def test_register_bq_sql_with_equals(self): - """SQL containing '=' should be parsed correctly (split only on first '=').""" - result = _parse_register_bq("view=SELECT * FROM t WHERE col = 5") - assert result[0] == "view" - assert result[1] == "SELECT * FROM t WHERE col = 5" - - def test_quiet_flag(self): - """--quiet should set quiet=True.""" - parser = build_parser() - args = parser.parse_args(["--sql", "SELECT 1", "--quiet"]) - assert args.quiet is True - - def test_max_rows_parsing(self): - """--max-rows should be parsed as integer.""" - parser = build_parser() - args = parser.parse_args(["--sql", "SELECT 1", "--max-rows", "500"]) - assert args.max_rows == 500 - - -# --------------------------------------------------------------------------- -# Tests: Local view setup -# --------------------------------------------------------------------------- - -class TestLocalViewSetup: - """Test _setup_local_views via create_local_views with tmp_path fixture.""" - - def test_creates_views_from_parquet(self, tmp_local_project, duckdb_conn): - """Local tables should be available as DuckDB views after setup.""" - project_root, data_dir = tmp_local_project - - with patch("scripts.duckdb_manager.find_project_root", return_value=project_root): - from src.remote_query import _setup_local_views - created = _setup_local_views(duckdb_conn, data_dir, quiet=True) - - assert "orders" in created - assert "products" in created - - # Verify data is queryable - count = duckdb_conn.execute("SELECT COUNT(*) FROM orders").fetchone()[0] - assert count == 5 - - def test_skips_remote_tables(self, tmp_local_project, duckdb_conn): - """Remote tables (query_mode='remote') should NOT create local views.""" - project_root, data_dir = tmp_local_project - - with patch("scripts.duckdb_manager.find_project_root", return_value=project_root): - from src.remote_query import _setup_local_views - created = _setup_local_views(duckdb_conn, data_dir, quiet=True) - - assert "traffic" not in created - - # Verify the remote table is not queryable - tables = [row[0] for row in duckdb_conn.execute("SHOW TABLES").fetchall()] - assert "traffic" not in tables - - def test_includes_hybrid_tables(self, tmp_local_project, duckdb_conn): - """Hybrid tables (query_mode='hybrid') should create local views.""" - project_root, data_dir = tmp_local_project - - with patch("scripts.duckdb_manager.find_project_root", return_value=project_root): - from src.remote_query import _setup_local_views - created = _setup_local_views(duckdb_conn, data_dir, quiet=True) - - assert "inventory" in created - - count = duckdb_conn.execute("SELECT COUNT(*) FROM inventory").fetchone()[0] - assert count == 2 - - -# --------------------------------------------------------------------------- -# Tests: BQ registration with safety checks -# --------------------------------------------------------------------------- - -class TestBQRegistration: - """Test BQ result validation and registration (mocked BigQuery).""" - - @staticmethod - def _make_mock_bq_client(count_result: int = 100, schema_fields: int = 5): - """Create a mock BQ client that returns controlled count and schema. - - Args: - count_result: Row count returned by COUNT(*) query - schema_fields: Number of fields in the schema - """ - mock_client = MagicMock() - - # COUNT(*) query result - count_row = MagicMock() - count_row.__getitem__ = MagicMock(return_value=count_result) - count_iter = iter([count_row]) - - # Schema query result (LIMIT 0) - mock_schema_fields = [MagicMock() for _ in range(schema_fields)] - mock_schema = MagicMock() - mock_schema.__len__ = MagicMock(return_value=schema_fields) - - # Use side_effect to return different results for different queries - def query_side_effect(sql): - job = MagicMock() - if sql.startswith("SELECT COUNT(*)"): - result = MagicMock() - result.__iter__ = MagicMock(return_value=iter([count_row])) - job.result.return_value = result - elif "LIMIT 0" in sql: - result = MagicMock() - result.schema = mock_schema_fields - job.result.return_value = result - return job - - mock_client.query.side_effect = query_side_effect - return mock_client - - def test_count_check_blocks_large_result(self): - """BQ sub-query exceeding max_rows should raise RemoteQueryError.""" - mock_client = self._make_mock_bq_client(count_result=1_000_000) - - with pytest.raises(RemoteQueryError, match="would return 1,000,000 rows"): - _validate_bq_result_size( - bq_client=mock_client, - sql="SELECT * FROM big_table", - alias="big_table", - max_rows=500_000, - ) - - def test_validates_small_result_passes(self): - """BQ sub-query within limits should return the row count.""" - mock_client = self._make_mock_bq_client(count_result=1000) - - row_count = _validate_bq_result_size( - bq_client=mock_client, - sql="SELECT * FROM small_table", - alias="small_table", - max_rows=500_000, - ) - - assert row_count == 1000 - - def test_memory_estimate_blocks_huge_result(self): - """_register_bq_views should refuse when estimated memory exceeds 2 GB.""" - # Create a mock that passes count check but fails memory check - # 500K rows x 100 cols x 50 bytes/cell = ~2384 MB > 2048 MB limit - mock_client = self._make_mock_bq_client(count_result=500_000, schema_fields=100) - - conn = duckdb.connect(":memory:") - try: - with patch("src.remote_query._create_bq_client", return_value=mock_client), \ - patch.dict(os.environ, {"BIGQUERY_PROJECT": "test-proj"}): - with pytest.raises(RemoteQueryError, match="estimated memory"): - _register_bq_views( - conn=conn, - registrations=[("huge", "SELECT * FROM huge_table")], - max_bq_rows=1_000_000, - timeout_seconds=60, - quiet=True, - ) - finally: - conn.close() - - def test_registers_small_result(self): - """BQ sub-query within all limits should register successfully.""" - # Small result: 100 rows x 5 cols = ~0.02 MB - mock_client = self._make_mock_bq_client(count_result=100, schema_fields=5) - - # Mock register_bq_table to return the row count - conn = duckdb.connect(":memory:") - try: - with patch("src.remote_query._create_bq_client", return_value=mock_client), \ - patch("src.remote_query.register_bq_table", return_value=100) as mock_reg, \ - patch.dict(os.environ, {"BIGQUERY_PROJECT": "test-proj"}): - results = _register_bq_views( - conn=conn, - registrations=[("small_view", "SELECT * FROM small_table")], - max_bq_rows=500_000, - timeout_seconds=60, - quiet=True, - ) - - assert results == {"small_view": 100} - mock_reg.assert_called_once() - finally: - conn.close() - - def test_missing_bigquery_project_raises(self): - """Missing BIGQUERY_PROJECT env var should raise RemoteQueryError.""" - conn = duckdb.connect(":memory:") - try: - with patch.dict(os.environ, {}, clear=True): - with pytest.raises(RemoteQueryError, match="BIGQUERY_PROJECT"): - _register_bq_views( - conn=conn, - registrations=[("v", "SELECT 1")], - max_bq_rows=100, - timeout_seconds=60, - ) - finally: - conn.close() - - def test_empty_registrations_returns_empty(self): - """Empty registration list should return empty dict without BQ calls.""" - conn = duckdb.connect(":memory:") - try: - result = _register_bq_views( - conn=conn, - registrations=[], - max_bq_rows=100, - timeout_seconds=60, - ) - assert result == {} - finally: - conn.close() - - -# --------------------------------------------------------------------------- -# Tests: Memory estimation -# --------------------------------------------------------------------------- - -class TestMemoryEstimation: - """Test _estimate_memory_mb calculation.""" - - def test_small_table(self): - """100 rows x 10 cols = 50_000 bytes ~ 0.048 MB.""" - result = _estimate_memory_mb(100, 10) - assert abs(result - 50_000 / (1024 * 1024)) < 0.001 - - def test_large_table(self): - """1M rows x 50 cols x 50 bytes = ~2384 MB.""" - result = _estimate_memory_mb(1_000_000, 50) - expected = (1_000_000 * 50 * 50) / (1024 * 1024) - assert abs(result - expected) < 0.01 - - def test_zero_rows(self): - """Zero rows should return 0 MB.""" - assert _estimate_memory_mb(0, 50) == 0.0 - - def test_zero_columns(self): - """Zero columns should return 0 MB.""" - assert _estimate_memory_mb(1000, 0) == 0.0 - - -# --------------------------------------------------------------------------- -# Tests: Output formatting -# --------------------------------------------------------------------------- - -class TestOutputFormatting: - """Test _print_table and _format_output for various formats.""" - - def test_table_format_aligned(self, capsys): - """Table format should produce aligned columns with header separator.""" - columns = ["id", "name", "value"] - rows = [(1, "alice", 100), (2, "bob", 200)] - - _print_table(columns, rows) - - output = capsys.readouterr().out - lines = output.strip().split("\n") - - # Header line - assert "id" in lines[0] - assert "name" in lines[0] - assert "value" in lines[0] - - # Separator line - assert "-+-" in lines[1] - - # Data rows - assert "alice" in lines[2] - assert "bob" in lines[3] - - # Row count footer - assert "(2 rows)" in output - - def test_table_format_empty_result(self, capsys): - """Empty result should print '(empty result)'.""" - _print_table(["col1"], []) - - output = capsys.readouterr().out - assert "(empty result)" in output - - def test_table_format_null_values(self, capsys): - """None values should be rendered as 'NULL'.""" - _print_table(["col"], [(None,)]) - - output = capsys.readouterr().out - assert "NULL" in output - - def test_csv_format(self, tmp_path, duckdb_conn): - """CSV output should contain header + data rows.""" - duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id, 'hello' AS msg") - - output_path = str(tmp_path / "result.csv") - _format_output( - conn=duckdb_conn, - sql="SELECT * FROM test", - fmt="csv", - output_path=output_path, - max_rows=1000, - ) - - with open(output_path) as f: - reader = csv.reader(f) - header = next(reader) - rows = list(reader) - - assert header == ["id", "msg"] - assert len(rows) == 1 - assert rows[0][1] == "hello" - - def test_csv_format_to_stdout(self, capsys, duckdb_conn): - """CSV with no output path should write to stdout.""" - duckdb_conn.execute("CREATE TABLE test AS SELECT 42 AS val") - - _format_output( - conn=duckdb_conn, - sql="SELECT * FROM test", - fmt="csv", - output_path=None, - max_rows=1000, - ) - - output = capsys.readouterr().out - assert "val" in output - assert "42" in output - - def test_json_format(self, tmp_path, duckdb_conn): - """JSON output should contain a list of dicts.""" - duckdb_conn.execute( - "CREATE TABLE test AS SELECT 1 AS id, 'world' AS msg" - ) - - output_path = str(tmp_path / "result.json") - _format_output( - conn=duckdb_conn, - sql="SELECT * FROM test", - fmt="json", - output_path=output_path, - max_rows=1000, - ) - - with open(output_path) as f: - data = json.load(f) - - assert len(data) == 1 - assert data[0]["id"] == 1 - assert data[0]["msg"] == "world" - - def test_json_format_to_stdout(self, capsys, duckdb_conn): - """JSON with no output path should print to stdout.""" - duckdb_conn.execute("CREATE TABLE test AS SELECT 99 AS num") - - _format_output( - conn=duckdb_conn, - sql="SELECT * FROM test", - fmt="json", - output_path=None, - max_rows=1000, - ) - - output = capsys.readouterr().out - data = json.loads(output) - assert data[0]["num"] == 99 - - def test_parquet_write(self, tmp_path, duckdb_conn): - """Parquet output should create a readable parquet file.""" - duckdb_conn.execute( - "CREATE TABLE test AS SELECT 1 AS id, 2.5 AS val" - ) - - output_path = str(tmp_path / "output" / "result.parquet") - - with patch("src.remote_query._load_remote_query_config", return_value={ - "output_dir": str(tmp_path / "default_output"), - "timeout_seconds": 300, - "max_result_rows": 100_000, - "max_bq_registration_rows": 500_000, - "default_format": "table", - }): - _format_output( - conn=duckdb_conn, - sql="SELECT * FROM test", - fmt="parquet", - output_path=output_path, - max_rows=1000, - ) - - assert Path(output_path).exists() - - # Read it back and verify - result = pq.read_table(output_path) - assert result.num_rows == 1 - assert result.num_columns == 2 - assert result.column_names == ["id", "val"] - - def test_parquet_default_path(self, tmp_path, duckdb_conn): - """Parquet with no output path should use config default dir.""" - duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS x") - - default_dir = str(tmp_path / "default_output") - with patch("src.remote_query._load_remote_query_config", return_value={ - "output_dir": default_dir, - "timeout_seconds": 300, - "max_result_rows": 100_000, - "max_bq_registration_rows": 500_000, - "default_format": "table", - }): - _format_output( - conn=duckdb_conn, - sql="SELECT * FROM test", - fmt="parquet", - output_path=None, - max_rows=1000, - ) - - expected_path = Path(default_dir) / "result.parquet" - assert expected_path.exists() - - def test_unknown_format_raises(self, duckdb_conn): - """Unknown format should raise RemoteQueryError.""" - duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id") - - with pytest.raises(RemoteQueryError, match="Unknown format"): - _format_output( - conn=duckdb_conn, - sql="SELECT * FROM test", - fmt="xml", - output_path=None, - max_rows=1000, - ) - - -# --------------------------------------------------------------------------- -# Tests: End-to-end (local-only, no BQ mocking needed) -# --------------------------------------------------------------------------- - -class TestEndToEnd: - """End-to-end tests with local Parquet data only (no BigQuery dependency). - - Uses _patched_duckdb_connect to handle DuckDB version differences - (statement_timeout may not be supported in all versions). - """ - - _CONFIG = { - "timeout_seconds": 300, - "max_result_rows": 100_000, - "max_bq_registration_rows": 500_000, - "default_format": "table", - "output_dir": "/tmp/remote_query_test", - } - - def _run(self, tmp_local_project, **kwargs): - """Helper to run execute_remote_query with standard patches.""" - project_root, data_dir = tmp_local_project - config = dict(self._CONFIG) - config.update(kwargs.pop("config_overrides", {})) - - with patch("scripts.duckdb_manager.find_project_root", return_value=project_root), \ - patch("src.remote_query._load_remote_query_config", return_value=config), \ - patch("src.remote_query.duckdb") as mock_duckdb_mod: - mock_duckdb_mod.connect = _patched_duckdb_connect - kwargs.setdefault("data_dir", data_dir) - kwargs.setdefault("bq_registrations", []) - kwargs.setdefault("quiet", True) - execute_remote_query(**kwargs) - - def test_local_only_query(self, tmp_local_project, capsys): - """Execute a query against local Parquet views and verify table output.""" - self._run( - tmp_local_project, - sql="SELECT COUNT(*) AS cnt FROM orders", - fmt="table", - ) - - output = capsys.readouterr().out - assert "cnt" in output - assert "5" in output - - def test_local_join_query(self, tmp_local_project, capsys): - """JOIN between two local tables should work.""" - self._run( - tmp_local_project, - sql=( - "SELECT p.name, SUM(o.amount) AS total " - "FROM orders o JOIN products p ON o.product_id = p.product_id " - "GROUP BY p.name ORDER BY total DESC" - ), - fmt="json", - ) - - output = capsys.readouterr().out - data = json.loads(output) - assert len(data) == 3 - # Widget: orders 1,3 -> 10+30=40 - widget = next(r for r in data if r["name"] == "Widget") - assert widget["total"] == 40.0 - - def test_result_row_limit(self, tmp_local_project, capsys): - """Result exceeding max_rows should be truncated.""" - self._run( - tmp_local_project, - sql="SELECT * FROM orders ORDER BY order_id", - fmt="table", - max_rows=2, - quiet=False, - config_overrides={"max_result_rows": 2}, - ) - - out = capsys.readouterr().out - # Table output should show exactly 2 data rows - assert "(2 rows)" in out - - def test_csv_output_to_file(self, tmp_local_project, tmp_path): - """End-to-end CSV output written to a file.""" - output_path = str(tmp_path / "result.csv") - - self._run( - tmp_local_project, - sql="SELECT order_id, amount FROM orders ORDER BY order_id", - fmt="csv", - output=output_path, - ) - - with open(output_path) as f: - reader = csv.DictReader(f) - rows = list(reader) - - assert len(rows) == 5 - assert rows[0]["order_id"] == "1" - assert rows[0]["amount"] == "10.0" - - def test_hybrid_table_queryable(self, tmp_local_project, capsys): - """Hybrid table should be accessible in local queries.""" - self._run( - tmp_local_project, - sql="SELECT SUM(stock) AS total_stock FROM inventory", - fmt="json", - ) - - output = capsys.readouterr().out - data = json.loads(output) - assert data[0]["total_stock"] == 300 - - def test_quiet_mode_suppresses_stderr(self, tmp_local_project, capsys): - """With quiet=True, no progress messages should appear on stderr.""" - self._run( - tmp_local_project, - sql="SELECT COUNT(*) AS cnt FROM orders", - fmt="table", - quiet=True, - ) - - err = capsys.readouterr().err - # In quiet mode, _log_progress should not emit anything - assert "Setting up" not in err - assert "local views" not in err diff --git a/tests/test_table_registry.py b/tests/test_table_registry.py deleted file mode 100644 index 819bfa7..0000000 --- a/tests/test_table_registry.py +++ /dev/null @@ -1,363 +0,0 @@ -"""Tests for the Table Registry module.""" - -import json -from pathlib import Path - -import pytest -import yaml - -from src.table_registry import ConflictError, TableRegistry - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - -@pytest.fixture -def registry_path(tmp_path): - """Return a temp path for the registry JSON.""" - return tmp_path / "table_registry.json" - - -@pytest.fixture -def registry(registry_path): - """Create an empty registry.""" - return TableRegistry(registry_path) - - -@pytest.fixture -def sample_table(): - """Minimal valid table definition.""" - return { - "id": "in.c-crm.company", - "name": "company", - "description": "Customer master data", - "primary_key": "id", - "sync_strategy": "full_refresh", - } - - -@pytest.fixture -def sample_table_incremental(): - """Incremental table definition.""" - return { - "id": "in.c-crm.events", - "name": "events", - "description": "User events", - "primary_key": "event_id", - "sync_strategy": "incremental", - "incremental_window_days": 14, - "partition_by": "created_at", - "partition_granularity": "month", - } - - -# --------------------------------------------------------------------------- -# Basic CRUD -# --------------------------------------------------------------------------- - -class TestRegistryCRUD: - - def test_empty_registry(self, registry): - assert registry.list_tables() == [] - assert registry.version == 0 - - def test_register_table(self, registry, sample_table): - registry.register_table(sample_table, registered_by="admin@test.com") - tables = registry.list_tables() - assert len(tables) == 1 - assert tables[0]["id"] == "in.c-crm.company" - assert tables[0]["registered_by"] == "admin@test.com" - assert registry.version == 1 - - def test_register_duplicate_raises(self, registry, sample_table): - registry.register_table(sample_table, registered_by="admin@test.com") - with pytest.raises(ValueError, match="already registered"): - registry.register_table(sample_table, registered_by="admin@test.com") - - def test_get_table(self, registry, sample_table): - registry.register_table(sample_table, registered_by="admin@test.com") - t = registry.get_table("in.c-crm.company") - assert t is not None - assert t["name"] == "company" - - def test_get_table_not_found(self, registry): - assert registry.get_table("nonexistent") is None - - def test_is_registered(self, registry, sample_table): - assert not registry.is_registered("in.c-crm.company") - registry.register_table(sample_table, registered_by="admin@test.com") - assert registry.is_registered("in.c-crm.company") - - def test_unregister_table(self, registry, sample_table): - registry.register_table(sample_table, registered_by="admin@test.com") - registry.unregister_table("in.c-crm.company", unregistered_by="admin@test.com") - assert not registry.is_registered("in.c-crm.company") - assert registry.list_tables() == [] - - def test_unregister_nonexistent_raises(self, registry): - with pytest.raises(ValueError, match="not registered"): - registry.unregister_table("nonexistent") - - def test_update_table(self, registry, sample_table): - registry.register_table(sample_table, registered_by="admin@test.com") - registry.update_table( - "in.c-crm.company", - {"description": "Updated description", "sync_strategy": "incremental"}, - updated_by="admin@test.com", - ) - t = registry.get_table("in.c-crm.company") - assert t["description"] == "Updated description" - assert t["sync_strategy"] == "incremental" - - def test_update_nonexistent_raises(self, registry): - with pytest.raises(ValueError, match="not registered"): - registry.update_table("nonexistent", {"description": "x"}) - - -# --------------------------------------------------------------------------- -# Validation -# --------------------------------------------------------------------------- - -class TestValidation: - - def test_missing_id_raises(self, registry): - with pytest.raises(ValueError, match="must include 'id'"): - registry.register_table( - {"name": "x", "sync_strategy": "full_refresh", "primary_key": "id"}, - registered_by="admin@test.com", - ) - - def test_missing_name_raises(self, registry): - with pytest.raises(ValueError, match="must include 'name'"): - registry.register_table( - {"id": "x.y.z", "sync_strategy": "full_refresh", "primary_key": "id"}, - registered_by="admin@test.com", - ) - - def test_invalid_sync_strategy_raises(self, registry): - with pytest.raises(ValueError, match="Invalid sync_strategy"): - registry.register_table( - { - "id": "x.y.z", - "name": "z", - "sync_strategy": "magic", - "primary_key": "id", - }, - registered_by="admin@test.com", - ) - - -# --------------------------------------------------------------------------- -# Optimistic locking -# --------------------------------------------------------------------------- - -class TestOptimisticLocking: - - def test_register_with_wrong_version_raises(self, registry, sample_table): - with pytest.raises(ConflictError, match="Version conflict"): - registry.register_table( - sample_table, registered_by="admin@test.com", expected_version=99 - ) - - def test_register_with_correct_version(self, registry, sample_table): - registry.register_table( - sample_table, registered_by="admin@test.com", expected_version=0 - ) - assert registry.version == 1 - - def test_unregister_with_wrong_version_raises(self, registry, sample_table): - registry.register_table(sample_table, registered_by="admin@test.com") - with pytest.raises(ConflictError): - registry.unregister_table( - "in.c-crm.company", expected_version=0 - ) - - def test_update_with_wrong_version_raises(self, registry, sample_table): - registry.register_table(sample_table, registered_by="admin@test.com") - with pytest.raises(ConflictError): - registry.update_table( - "in.c-crm.company", - {"description": "x"}, - expected_version=0, - ) - - -# --------------------------------------------------------------------------- -# Persistence -# --------------------------------------------------------------------------- - -class TestPersistence: - - def test_save_and_reload(self, registry_path, sample_table): - reg1 = TableRegistry(registry_path) - reg1.register_table(sample_table, registered_by="admin@test.com") - - # Reload from disk - reg2 = TableRegistry(registry_path) - assert len(reg2.list_tables()) == 1 - assert reg2.get_table("in.c-crm.company")["name"] == "company" - assert reg2.version == 1 - - def test_json_format(self, registry_path, sample_table): - reg = TableRegistry(registry_path) - reg.register_table(sample_table, registered_by="admin@test.com") - - with open(registry_path) as f: - data = json.load(f) - - assert "_metadata" in data - assert "tables" in data - assert data["_metadata"]["version"] == 1 - assert len(data["tables"]) == 1 - - -# --------------------------------------------------------------------------- -# Folder mapping -# --------------------------------------------------------------------------- - -class TestFolderMapping: - - def test_set_and_get(self, registry): - registry.set_folder_mapping("in.c-crm", "crm") - assert registry.get_folder_mapping() == {"in.c-crm": "crm"} - - def test_persists(self, registry_path): - reg1 = TableRegistry(registry_path) - reg1.set_folder_mapping("in.c-crm", "crm") - - reg2 = TableRegistry(registry_path) - assert reg2.get_folder_mapping() == {"in.c-crm": "crm"} - - -# --------------------------------------------------------------------------- -# Generation -# --------------------------------------------------------------------------- - -class TestGeneration: - - def test_generate_data_description_md(self, registry, sample_table, tmp_path): - registry.register_table(sample_table, registered_by="admin@test.com") - registry.set_folder_mapping("in.c-crm", "crm") - - output = tmp_path / "data_description.md" - registry.generate_data_description_md(output) - - content = output.read_text() - - # Check header - assert "AUTO-GENERATED" in content - assert "checksum: sha256:" in content - - # Check YAML block is parseable - yaml_match = __import__("re").search(r"```yaml\n(.*?)```", content, __import__("re").DOTALL) - assert yaml_match - yaml_data = yaml.safe_load(yaml_match.group(1)) - assert len(yaml_data["tables"]) == 1 - assert yaml_data["tables"][0]["id"] == "in.c-crm.company" - assert yaml_data["folder_mapping"] == {"in.c-crm": "crm"} - - def test_generate_includes_incremental_fields( - self, registry, sample_table_incremental, tmp_path - ): - registry.register_table(sample_table_incremental, registered_by="admin@test.com") - - output = tmp_path / "data_description.md" - registry.generate_data_description_md(output) - - content = output.read_text() - yaml_match = __import__("re").search(r"```yaml\n(.*?)```", content, __import__("re").DOTALL) - yaml_data = yaml.safe_load(yaml_match.group(1)) - table = yaml_data["tables"][0] - assert table["partition_by"] == "created_at" - assert table["partition_granularity"] == "month" - assert table["incremental_window_days"] == 14 - - -# --------------------------------------------------------------------------- -# Migration -# --------------------------------------------------------------------------- - -class TestMigration: - - def test_import_from_data_description(self, tmp_path): - # Create a fake data_description.md - md_content = """# Data Description - -```yaml -folder_mapping: - in.c-crm: crm - -tables: - - id: in.c-crm.company - name: company - description: Companies - primary_key: id - sync_strategy: full_refresh - - - id: in.c-crm.contact - name: contact - description: Contacts - primary_key: id - sync_strategy: incremental - incremental_window_days: 7 -``` -""" - md_path = tmp_path / "data_description.md" - md_path.write_text(md_content) - - registry_path = tmp_path / "table_registry.json" - registry = TableRegistry.import_from_data_description(md_path, registry_path) - - assert len(registry.list_tables()) == 2 - assert registry.is_registered("in.c-crm.company") - assert registry.is_registered("in.c-crm.contact") - assert registry.get_folder_mapping() == {"in.c-crm": "crm"} - - # Check migrated_from marker - with open(registry_path) as f: - data = json.load(f) - assert "migrated_from" in data["_metadata"] - - def test_import_no_yaml_raises(self, tmp_path): - md_path = tmp_path / "data_description.md" - md_path.write_text("# Empty file\nNo YAML here.") - - with pytest.raises(ValueError, match="No YAML blocks"): - TableRegistry.import_from_data_description( - md_path, tmp_path / "registry.json" - ) - - def test_import_file_not_found_raises(self, tmp_path): - with pytest.raises(FileNotFoundError): - TableRegistry.import_from_data_description( - tmp_path / "nonexistent.md", tmp_path / "registry.json" - ) - - -# --------------------------------------------------------------------------- -# Audit log -# --------------------------------------------------------------------------- - -class TestAuditLog: - - def test_register_writes_audit(self, registry, sample_table): - registry.register_table(sample_table, registered_by="admin@test.com") - - audit_path = registry.registry_path.parent / "registry_audit.log" - assert audit_path.exists() - - lines = audit_path.read_text().strip().split("\n") - assert len(lines) >= 1 - entry = json.loads(lines[-1]) - assert entry["action"] == "register" - assert entry["table_id"] == "in.c-crm.company" - - def test_unregister_writes_audit(self, registry, sample_table): - registry.register_table(sample_table, registered_by="admin@test.com") - registry.unregister_table("in.c-crm.company", unregistered_by="admin@test.com") - - audit_path = registry.registry_path.parent / "registry_audit.log" - lines = audit_path.read_text().strip().split("\n") - last_entry = json.loads(lines[-1]) - assert last_entry["action"] == "unregister" diff --git a/webapp/app.py b/webapp/app.py index 6efb011..36fda08 100644 --- a/webapp/app.py +++ b/webapp/app.py @@ -611,19 +611,11 @@ def _load_catalog_data() -> list: # Enrich with catalog metadata (OpenMetadata) if _catalog_enricher: try: - # Create config for enrichment with all available fields - from src.config import TableConfig - table_config = TableConfig( + # Create lightweight config for enrichment (enricher uses .id, .name, .catalog_fqn) + from types import SimpleNamespace + table_config = SimpleNamespace( id=table_id, name=table.get("name", ""), - description=table.get("description", ""), - primary_key=table.get("primary_key", "id"), - sync_strategy=table.get("sync_strategy", "full_refresh"), - incremental_window_days=table.get("incremental_window_days"), - partition_by=table.get("partition_by"), - partition_granularity=table.get("partition_granularity"), - max_history_days=table.get("max_history_days"), - partition_column_type=table.get("partition_column_type", "TIMESTAMP"), catalog_fqn=table.get("catalog_fqn"), ) catalog_data = _catalog_enricher.enrich_table(table_config) @@ -1097,7 +1089,7 @@ def register_routes(app: Flask) -> None: if _catalog_enricher and _catalog_enricher.enabled: try: # Find table config from data_description.md - from src.config import TableConfig + from types import SimpleNamespace from config.loader import load_instance_config # Load data_description.md to find table config by name @@ -1116,17 +1108,10 @@ def register_routes(app: Flask) -> None: # Find table by name for table_def in yaml_data["tables"]: if table_def.get("name") == table_name: - table_config = TableConfig( + # Lightweight config (enricher uses .id, .name, .catalog_fqn) + table_config = SimpleNamespace( id=table_def.get("id", ""), name=table_def.get("name", ""), - description=table_def.get("description", ""), - primary_key=table_def.get("primary_key", "id"), - sync_strategy=table_def.get("sync_strategy", "full_refresh"), - incremental_window_days=table_def.get("incremental_window_days"), - partition_by=table_def.get("partition_by"), - partition_granularity=table_def.get("partition_granularity"), - max_history_days=table_def.get("max_history_days"), - partition_column_type=table_def.get("partition_column_type", "TIMESTAMP"), catalog_fqn=table_def.get("catalog_fqn"), ) catalog_data = _catalog_enricher.enrich_table(table_config) @@ -1862,17 +1847,25 @@ def register_routes(app: Flask) -> None: def admin_discover_tables(): """Discover all available tables from the data source.""" try: - from src.data_sync import create_data_source + from app.instance_config import get_data_source_type, get_value - ds = create_data_source() - raw_tables = ds.discover_tables() + source_type = get_data_source_type() + raw_tables = [] + if source_type == "keboola": + from connectors.keboola.client import KeboolaClient + url = get_value("keboola", "url", default="") + token = os.environ.get(get_value("keboola", "token_env", default="KEBOOLA_STORAGE_TOKEN"), "") + client = KeboolaClient(token=token, url=url) + raw_tables = client.discover_all_tables() # Check which tables are already registered registered_ids = set() try: - from src.table_registry import TableRegistry - registry = TableRegistry.default() - registered_ids = {t["id"] for t in registry.list_tables()} + from src.db import get_system_db + from src.repositories.table_registry import TableRegistryRepository + conn = get_system_db() + repo = TableRegistryRepository(conn) + registered_ids = {t["id"] for t in repo.list_all()} except Exception: pass @@ -1905,14 +1898,16 @@ def register_routes(app: Flask) -> None: def admin_registry_list(): """Return the full table registry.""" try: - from src.table_registry import TableRegistry + from src.db import get_system_db + from src.repositories.table_registry import TableRegistryRepository - registry = TableRegistry.default() + conn = get_system_db() + repo = TableRegistryRepository(conn) return jsonify({ "ok": True, - "version": registry.version, - "folder_mapping": registry.get_folder_mapping(), - "tables": registry.list_tables(), + "version": 0, + "folder_mapping": {}, + "tables": repo.list_all(), }) except Exception as e: logger.error(f"Registry list failed: {e}") @@ -1923,7 +1918,8 @@ def register_routes(app: Flask) -> None: @admin_required def admin_register_table(): """Register a new table from discovery results.""" - from src.table_registry import ConflictError, TableRegistry + from src.db import get_system_db + from src.repositories.table_registry import TableRegistryRepository user = session.get("user", {}) email = user.get("email", "") @@ -1933,21 +1929,26 @@ def register_routes(app: Flask) -> None: return jsonify({"error": "Missing table 'id'"}), 400 try: - registry = TableRegistry.default() - registry.register_table( - table_def=data, + conn = get_system_db() + repo = TableRegistryRepository(conn) + repo.register( + id=data["id"], + name=data.get("name", ""), + folder=data.get("folder"), + sync_strategy=data.get("sync_strategy"), + primary_key=data.get("primary_key"), + description=data.get("description"), registered_by=email, - expected_version=data.get("version"), + source_type=data.get("source_type"), + bucket=data.get("bucket"), + source_table=data.get("source_table"), + query_mode=data.get("query_mode", "local"), + sync_schedule=data.get("sync_schedule"), + profile_after_sync=data.get("profile_after_sync", True), ) - # Regenerate data_description.md - docs_path = Path(os.path.dirname(__file__)) / ".." / "docs" / "data_description.md" - registry.generate_data_description_md(docs_path.resolve()) + return jsonify({"ok": True}) - return jsonify({"ok": True, "version": registry.version}) - - except ConflictError as e: - return jsonify({"error": str(e)}), 409 except ValueError as e: return jsonify({"error": str(e)}), 400 except Exception as e: @@ -1959,30 +1960,42 @@ def register_routes(app: Flask) -> None: @admin_required def admin_update_table(table_id): """Update configuration of a registered table.""" - from src.table_registry import ConflictError, TableRegistry + from src.db import get_system_db + from src.repositories.table_registry import TableRegistryRepository user = session.get("user", {}) email = user.get("email", "") data = request.get_json(silent=True) or {} + data.pop("version", None) # Not used by DuckDB repo try: - registry = TableRegistry.default() - registry.update_table( - table_id=table_id, - updates=data, - updated_by=email, - expected_version=data.pop("version", None), + conn = get_system_db() + repo = TableRegistryRepository(conn) + + # Get existing record and merge updates + existing = repo.get(table_id) + if not existing: + return jsonify({"error": f"Table '{table_id}' not found"}), 404 + + repo.register( + id=table_id, + name=data.get("name", existing.get("name", "")), + folder=data.get("folder", existing.get("folder")), + sync_strategy=data.get("sync_strategy", existing.get("sync_strategy")), + primary_key=data.get("primary_key", existing.get("primary_key")), + description=data.get("description", existing.get("description")), + registered_by=email, + source_type=data.get("source_type", existing.get("source_type")), + bucket=data.get("bucket", existing.get("bucket")), + source_table=data.get("source_table", existing.get("source_table")), + query_mode=data.get("query_mode", existing.get("query_mode", "local")), + sync_schedule=data.get("sync_schedule", existing.get("sync_schedule")), + profile_after_sync=data.get("profile_after_sync", existing.get("profile_after_sync", True)), ) - # Regenerate data_description.md - docs_path = Path(os.path.dirname(__file__)) / ".." / "docs" / "data_description.md" - registry.generate_data_description_md(docs_path.resolve()) + return jsonify({"ok": True}) - return jsonify({"ok": True, "version": registry.version}) - - except ConflictError as e: - return jsonify({"error": str(e)}), 409 except ValueError as e: return jsonify({"error": str(e)}), 400 except Exception as e: @@ -1994,25 +2007,18 @@ def register_routes(app: Flask) -> None: @admin_required def admin_unregister_table(table_id): """Unregister a table and clean up subscriptions.""" - from src.table_registry import ConflictError, TableRegistry - - user = session.get("user", {}) - email = user.get("email", "") - - data = request.get_json(silent=True) or {} + from src.db import get_system_db + from src.repositories.table_registry import TableRegistryRepository try: - registry = TableRegistry.default() + conn = get_system_db() + repo = TableRegistryRepository(conn) # Get table name before deletion (for subscription cleanup) - table_info = registry.get_table(table_id) + table_info = repo.get(table_id) table_name = table_info["name"] if table_info else None - registry.unregister_table( - table_id=table_id, - unregistered_by=email, - expected_version=data.get("version"), - ) + repo.unregister(table_id) # Clean up per-user subscriptions for removed table if table_name: @@ -2021,14 +2027,8 @@ def register_routes(app: Flask) -> None: except Exception as ce: logger.warning(f"Subscription cleanup for {table_name} failed: {ce}") - # Regenerate data_description.md - docs_path = Path(os.path.dirname(__file__)) / ".." / "docs" / "data_description.md" - registry.generate_data_description_md(docs_path.resolve()) + return jsonify({"ok": True}) - return jsonify({"ok": True, "version": registry.version}) - - except ConflictError as e: - return jsonify({"error": str(e)}), 409 except ValueError as e: return jsonify({"error": str(e)}), 400 except Exception as e: diff --git a/webapp/sync_settings_service.py b/webapp/sync_settings_service.py index 764db1b..b6d64f8 100644 --- a/webapp/sync_settings_service.py +++ b/webapp/sync_settings_service.py @@ -194,9 +194,12 @@ def _write_rsync_filter(username: str, dataset_settings: dict, table_mode: str, # Load folder_mapping from table registry (or instance config as fallback) folder_mapping = {} try: - from src.table_registry import TableRegistry - registry = TableRegistry.default() - folder_mapping = registry.get_folder_mapping() + from src.db import get_system_db + from src.repositories.table_registry import TableRegistryRepository + conn = get_system_db() + repo = TableRegistryRepository(conn) + tables = repo.list_all() + folder_mapping = {t["bucket"]: t["folder"] for t in tables if t.get("bucket") and t.get("folder")} except Exception: try: from config.loader import load_instance_config, get_instance_value