From 8bb46a9e0a14815f1c5a6e9b91dd3023a2d6159a Mon Sep 17 00:00:00 2001 From: Petr Date: Thu, 12 Mar 2026 13:20:41 +0100 Subject: [PATCH] Add per-partition streaming sync and hybrid query architecture Partitioned sync: iterates day-by-day instead of loading full dataset. Each partition: query BQ -> stream to disk -> free RAM. Peak ~50 MB. New helpers: _sync_single_partition, _cleanup_old_partitions, _generate_partition_dates. Config: added partition_column_type (DATE/TIMESTAMP/DATETIME), query_mode (local/remote/hybrid). DuckDB manager: hybrid architecture support (local Parquet + remote BQ tables). Data sync: skips remote tables, filters by query_mode. Tests: 113 passing (adapter, client, config, data_sync, duckdb_manager). --- connectors/bigquery/adapter.py | 278 +++++++++++----- scripts/duckdb_manager.py | 213 +++++++++++-- src/config.py | 20 ++ src/data_sync.py | 15 + tests/test_bigquery_adapter.py | 407 +++++++++++++++++++++--- tests/test_bigquery_client.py | 148 +++++++++ tests/test_config_query_mode.py | 69 ++++ tests/test_data_sync_query_mode.py | 228 ++++++++++++++ tests/test_duckdb_manager.py | 488 +++++++++++++++++++++++++++++ 9 files changed, 1731 insertions(+), 135 deletions(-) create mode 100644 tests/test_config_query_mode.py create mode 100644 tests/test_data_sync_query_mode.py create mode 100644 tests/test_duckdb_manager.py diff --git a/connectors/bigquery/adapter.py b/connectors/bigquery/adapter.py index c7992d1..19fde49 100644 --- a/connectors/bigquery/adapter.py +++ b/connectors/bigquery/adapter.py @@ -9,7 +9,7 @@ using PyArrow (no CSV intermediate step). import logging from pathlib import Path from typing import Dict, List, Optional, Any -from datetime import datetime, timedelta +from datetime import datetime, timedelta, date import pyarrow as pa import pyarrow.parquet as pq @@ -336,10 +336,11 @@ class BigQueryDataSource(DataSource): sync_state: SyncState, ) -> Dict[str, Any]: """ - Partition-based sync: read data by partition range and write partition files. - """ - import pandas as pd + Per-partition streaming sync: process one partition (day) at a time. + Queries BQ for a single day, streams result to disk, then moves to next day. + Memory usage is constant (~20-50 MB per partition) regardless of total data volume. + """ partition_col = table_config.partition_by if not partition_col and table_config.incremental_column: partition_col = table_config.incremental_column @@ -351,96 +352,231 @@ class BigQueryDataSource(DataSource): ) return self._full_refresh(table_config) - granularity = table_config.partition_granularity or "month" + granularity = table_config.partition_granularity or "day" + column_type = table_config.partition_column_type logger.info( f"Partitioned sync: {table_config.name} " - f"(by {partition_col}, {granularity})" + f"(by {partition_col}, {granularity}, type={column_type})" ) partition_dir = self.config.get_parquet_path(table_config) date_columns = self.bq_client.get_date_columns(table_config.id) pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id) - # Determine time range + # Determine date range last_sync = sync_state.get_last_sync(table_config.id) + today = date.today() if last_sync: last_sync_dt = datetime.fromisoformat(last_sync) window_days = table_config.incremental_window_days or 7 - start_dt = last_sync_dt - timedelta(days=window_days) - logger.info(f" -> Reading from {start_dt.isoformat()} (window: {window_days} days)") + start_date = (last_sync_dt - timedelta(days=window_days)).date() + logger.info(f" -> Incremental sync from {start_date} (window: {window_days} days)") else: if table_config.max_history_days: - start_dt = datetime.now() - timedelta(days=table_config.max_history_days) - logger.info(f" -> First sync, limited to last {table_config.max_history_days} days") + start_date = today - timedelta(days=table_config.max_history_days) + logger.info(f" -> First sync, last {table_config.max_history_days} days from {start_date}") else: - start_dt = None - logger.info(" -> First sync, reading all data") + start_date = today - timedelta(days=365) + logger.info(" -> First sync, no max_history_days, defaulting to 365 days") - # Read data from BigQuery - if start_dt: - arrow_table = self.bq_client.read_table_partitioned( - table_id=table_config.id, - partition_column=partition_col, - start=start_dt.isoformat(), - columns=table_config.columns, - ) - else: - arrow_table = self.bq_client.read_table( - table_config.id, - columns=table_config.columns, - row_filter=table_config.row_filter, + # Generate list of partition dates + partition_dates = self._generate_partition_dates(start_date, today, granularity) + logger.info(f" -> Processing {len(partition_dates)} partitions") + + total_rows = 0 + partitions_updated = 0 + + for partition_date in partition_dates: + rows = self._sync_single_partition( + table_config=table_config, + partition_col=partition_col, + partition_date=partition_date, + partition_dir=partition_dir, + date_columns=date_columns, + pyarrow_schema=pyarrow_schema, + granularity=granularity, + column_type=column_type, ) + if rows > 0: + partitions_updated += 1 + total_rows += rows - if arrow_table.num_rows == 0: - logger.info(" -> No data to sync") - return self._get_partition_totals(partition_dir) + # Cleanup old partitions beyond retention window + deleted = self._cleanup_old_partitions(table_config, partition_dir, granularity) + if deleted > 0: + logger.info(f" -> Cleaned up {deleted} old partition files") - logger.info(f" -> Processing {arrow_table.num_rows} rows into partitions") - - # Convert to pandas for partitioning - df = arrow_table.to_pandas() - - # Ensure partition column is datetime - if not pd.api.types.is_datetime64_any_dtype(df[partition_col]): - df[partition_col] = pd.to_datetime(df[partition_col], format="ISO8601", utc=True) - - # Create partition key - if granularity == "month": - df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m") - elif granularity == "day": - df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m_%d") - elif granularity == "year": - df["_partition_key"] = df[partition_col].dt.strftime("%Y") - - primary_key_cols = table_config.get_primary_key_columns() - partitions_updated = set() - - for partition_key, group_df in df.groupby("_partition_key"): - group_df = group_df.drop(columns=["_partition_key"]) - partition_path = self.config.get_partition_path(table_config, partition_key) - partitions_updated.add(partition_key) - - # Merge with existing partition if it exists - if partition_path.exists(): - existing_df = pd.read_parquet(partition_path) - merged_df = pd.concat([existing_df, group_df], ignore_index=True) - merged_df = merged_df.drop_duplicates(subset=primary_key_cols, keep="last") - else: - merged_df = group_df - - # Write partition - table = pa.Table.from_pandas(merged_df, preserve_index=False) - if date_columns: - table = convert_date_columns_to_date32(table, date_columns) - if pyarrow_schema: - table = apply_schema_to_table(table, pyarrow_schema) - pq.write_table(table, partition_path, compression="snappy") - - logger.info(f" -> Partitioned sync complete: {len(partitions_updated)} partitions updated") + logger.info( + f" -> Partitioned sync complete: {partitions_updated} partitions updated, " + f"{total_rows} total rows processed" + ) return self._get_partition_totals(partition_dir) + @staticmethod + def _generate_partition_dates( + start_date: date, + end_date: date, + granularity: str, + ) -> List[date]: + """Generate list of partition start dates between start and end.""" + dates = [] + current = start_date + + if granularity == "day": + while current <= end_date: + dates.append(current) + current += timedelta(days=1) + elif granularity == "month": + # Align to first of month + current = current.replace(day=1) + while current <= end_date: + dates.append(current) + # Move to first of next month + if current.month == 12: + current = current.replace(year=current.year + 1, month=1) + else: + current = current.replace(month=current.month + 1) + elif granularity == "year": + current = current.replace(month=1, day=1) + while current <= end_date: + dates.append(current) + current = current.replace(year=current.year + 1) + + return dates + + def _sync_single_partition( + self, + table_config: TableConfig, + partition_col: str, + partition_date: date, + partition_dir: Path, + date_columns: List[str], + pyarrow_schema, + granularity: str, + column_type: str, + ) -> int: + """ + Query BQ for one partition period, stream to disk, merge with existing file. + + Returns row count for this partition after merge. + """ + import pandas as pd + + # Calculate partition range [start, end) + start = partition_date + if granularity == "day": + end = start + timedelta(days=1) + partition_key = start.strftime("%Y_%m_%d") + elif granularity == "month": + if start.month == 12: + end = start.replace(year=start.year + 1, month=1) + else: + end = start.replace(month=start.month + 1) + partition_key = start.strftime("%Y_%m") + elif granularity == "year": + end = start.replace(year=start.year + 1) + partition_key = start.strftime("%Y") + else: + raise ValueError(f"Unknown granularity: {granularity}") + + partition_path = self.config.get_partition_path(table_config, partition_key) + + # Stream data from BQ for this single partition + batches = [] + for batch in self.bq_client.read_table_partitioned_streaming( + table_id=table_config.id, + partition_column=partition_col, + start=start.isoformat(), + end=end.isoformat(), + columns=table_config.columns, + column_type=column_type, + ): + batches.append(batch) + + if not batches: + return 0 + + new_data = pa.Table.from_batches(batches) + if new_data.num_rows == 0: + return 0 + + # Apply schema conversions + if date_columns: + new_data = convert_date_columns_to_date32(new_data, date_columns) + if pyarrow_schema: + new_data = apply_schema_to_table(new_data, pyarrow_schema) + + # Merge with existing partition file if present + primary_key_cols = table_config.get_primary_key_columns() + + if partition_path.exists(): + existing = pq.read_table(partition_path) + merged = self._merge_arrow_tables(existing, new_data, primary_key_cols) + else: + merged = new_data + + # Write partition file + pq.write_table(merged, partition_path, compression="snappy") + row_count = merged.num_rows + + logger.debug( + f" Partition {partition_key}: {new_data.num_rows} new rows, " + f"{row_count} total after merge" + ) + + # Release memory + del batches, new_data, merged + + return row_count + + def _cleanup_old_partitions( + self, + table_config: TableConfig, + partition_dir: Path, + granularity: str, + ) -> int: + """ + Delete partition files older than max_history_days. + + Returns count of deleted files. + """ + if not table_config.max_history_days: + return 0 + + if not partition_dir.exists(): + return 0 + + cutoff_date = date.today() - timedelta(days=table_config.max_history_days) + deleted = 0 + + for part_path in partition_dir.glob("*.parquet"): + try: + partition_date = self._parse_partition_date(part_path.stem, granularity) + if partition_date and partition_date < cutoff_date: + part_path.unlink() + deleted += 1 + logger.debug(f" Deleted old partition: {part_path.name}") + except (ValueError, IndexError): + logger.warning(f" Skipping unrecognized partition file: {part_path.name}") + + return deleted + + @staticmethod + def _parse_partition_date(partition_key: str, granularity: str) -> Optional[date]: + """Parse a partition key back to a date.""" + try: + if granularity == "day": + return datetime.strptime(partition_key, "%Y_%m_%d").date() + elif granularity == "month": + return datetime.strptime(partition_key, "%Y_%m").date() + elif granularity == "year": + return datetime.strptime(partition_key, "%Y").date() + except ValueError: + return None + return None + def _merge_arrow_tables( self, existing: pa.Table, diff --git a/scripts/duckdb_manager.py b/scripts/duckdb_manager.py index 1cf73d1..ecfd025 100644 --- a/scripts/duckdb_manager.py +++ b/scripts/duckdb_manager.py @@ -1,9 +1,11 @@ #!/usr/bin/env python3 """ -DuckDB Manager - Initialize and manage DuckDB database with views from parquet files. +DuckDB Manager - Initialize and manage DuckDB database with views from parquet files +and runtime BigQuery query registration for remote/hybrid tables. -This script dynamically reads table configurations from docs/data_description.md -and creates DuckDB views accordingly. No hardcoded table list needed! +For BigQuery data sources, tables with query_mode="remote" or "hybrid" are queried +at runtime via the Python BQ client and registered as in-memory Arrow tables in DuckDB. +This avoids the DuckDB BigQuery extension limitation (cannot read BQ views). Usage: python3 scripts/duckdb_manager.py --reinit # Initialize/reinitialize all views @@ -11,13 +13,17 @@ Usage: """ import duckdb +import logging import os import sys import argparse import re import yaml from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple + + +logger = logging.getLogger(__name__) def find_project_root() -> Path: @@ -157,17 +163,131 @@ def get_parquet_path(table_config: Dict, folder_mapping: Dict[str, str], data_di return parquet_dir / f"{table_name}.parquet" -def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbose=True): +def _get_bq_project_from_table_id(table_id: str) -> Optional[str]: + """Extract BQ project ID from a fully-qualified table ID. + + Args: + table_id: e.g. "prj-grp-dataview-prod-1ff9.finance_unit_economics.unit_economics" + + Returns: + Project ID or None if format doesn't match BQ convention + """ + parts = table_id.split(".") + if len(parts) == 3 and "-" in parts[0]: + return parts[0] + return None + + +def _create_bq_client(project: str): + """Create a BigQuery client. Separated for testability. + + Args: + project: GCP project ID for billing + + Returns: + google.cloud.bigquery.Client instance + """ + from google.cloud import bigquery as bq_module + + return bq_module.Client(project=project) + + +def register_bq_table( + conn: duckdb.DuckDBPyConnection, + table_id: str, + view_name: str, + sql: str, + bq_project: Optional[str] = None, + _bq_client_factory=None, +) -> int: + """ + Execute a BigQuery SQL query and register the result as a DuckDB view. + + Uses the Python BigQuery client (Query API) which supports BQ views, + unlike the DuckDB BigQuery extension (Storage Read API). + The result is held in memory as a PyArrow table -- no disk I/O. + + Args: + conn: Open DuckDB connection + table_id: BQ table ID for logging (e.g., "project.dataset.table") + view_name: Name to register in DuckDB (e.g., "unit_economics_live") + sql: Full BigQuery SQL query to execute + bq_project: GCP project for billing. If None, uses BIGQUERY_PROJECT env var. + _bq_client_factory: Override BQ client creation (for testing) + + Returns: + Number of rows in the result + + Raises: + ImportError: If google-cloud-bigquery is not installed + ValueError: If bq_project is not set + """ + project = bq_project or os.environ.get("BIGQUERY_PROJECT") + if not project: + raise ValueError( + "BigQuery project not set. " + "Pass bq_project or set BIGQUERY_PROJECT env var." + ) + + logger.info(f"Querying BQ: {table_id} -> {view_name}") + logger.debug(f"SQL: {sql[:200]}...") + + factory = _bq_client_factory or _create_bq_client + client = factory(project) + job = client.query(sql) + + # Use Query API (not Storage Read API) to support BQ views + try: + arrow_table = job.to_arrow() + except Exception as e: + if "readsessions" in str(e) or "PERMISSION_DENIED" in str(e): + logger.warning("BQ Storage API unavailable, falling back to REST") + arrow_table = job.to_arrow(create_bqstorage_client=False) + else: + raise + + conn.register(view_name, arrow_table) + logger.info( + f"Registered {view_name}: {arrow_table.num_rows} rows, " + f"{arrow_table.num_columns} cols (in-memory)" + ) + + return arrow_table.num_rows + + +def get_remote_tables(table_configs: List[Dict]) -> List[Dict]: + """Return table configs with query_mode 'remote' or 'hybrid'. + + Args: + table_configs: List of table configuration dicts + + Returns: + List of remote/hybrid table configs + """ + return [ + tc for tc in table_configs + if tc.get("query_mode") in ("remote", "hybrid") + ] + + +def init_duckdb( + db_path="user/duckdb/analytics.duckdb", + data_dir="server", + verbose=True, + bq_project: Optional[str] = None, +): """ Initialize DuckDB database with views from parquet files. - Dynamically reads table configurations from docs/data_description.md - and creates views accordingly. + Creates DuckDB views for local/hybrid tables (from Parquet). + Remote tables are NOT pre-loaded -- they are registered at query time + via register_bq_table(). Args: db_path: Path to DuckDB database file data_dir: Base data directory (e.g., "server" for analysts, "data" for server) verbose: Print progress messages + bq_project: BigQuery execution project (for informational purposes only) Returns: True if successful, False otherwise @@ -176,29 +296,48 @@ def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbo os.makedirs(os.path.dirname(db_path), exist_ok=True) if verbose: - print("🦆 Inicializuji DuckDB databázi...") + print("Initializing DuckDB database...") try: # Find project root and parse data_description.md project_root = find_project_root() if verbose: - print(f" 📂 Project root: {project_root}") + print(f" Project root: {project_root}") table_configs, folder_mapping = parse_data_description(project_root) if verbose: - print(f" 📋 Načteno {len(table_configs)} tabulek z data_description.md") + print(f" Loaded {len(table_configs)} tables from data_description.md") + + # Separate tables by query_mode + local_tables = [] + remote_tables = [] + hybrid_tables = [] + + for tc in table_configs: + mode = tc.get("query_mode", "local") + if mode == "remote": + remote_tables.append(tc) + elif mode == "hybrid": + hybrid_tables.append(tc) + else: + local_tables.append(tc) + + if verbose: + print(f" Query modes: {len(local_tables)} local, " + f"{len(remote_tables)} remote, {len(hybrid_tables)} hybrid") # Connect to database (creates if doesn't exist) conn = duckdb.connect(db_path) - # Create views + # Create local views from parquet files if verbose: - print("\n📊 Vytvářím views z parquet souborů...") + print("\n Creating views from parquet files...") created_views = [] skipped_views = [] - for table_config in table_configs: + # Process local and hybrid tables (both have local parquet) + for table_config in local_tables + hybrid_tables: table_name = table_config['name'] try: @@ -209,7 +348,7 @@ def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbo if not parquet_path.exists(): skipped_views.append(table_name) if verbose: - print(f" ⚠️ Přeskakuji {table_name} - parquet neexistuje: {parquet_path}") + print(f" [SKIP] {table_name} - parquet not found: {parquet_path}") continue # Determine if partitioned @@ -229,49 +368,65 @@ def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbo if not partition_files: skipped_views.append(table_name) if verbose: - print(f" ⚠️ Přeskakuji {table_name} - žádné partition soubory") + print(f" [SKIP] {table_name} - no partition files") continue sql = f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{glob_pattern}', union_by_name=true)" if verbose: - print(f" ✅ {table_name} ({len(partition_files)} partitions)") + mode_label = "hybrid" if table_config.get("query_mode") == "hybrid" else "local" + print(f" [OK] {table_name} ({len(partition_files)} partitions, {mode_label})") else: # Single parquet file sql = f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{parquet_path}')" if verbose: - print(f" ✅ {table_name}") + mode_label = "hybrid" if table_config.get("query_mode") == "hybrid" else "local" + print(f" [OK] {table_name} ({mode_label})") conn.execute(sql) created_views.append(table_name) except Exception as e: if verbose: - print(f" ❌ Chyba při vytváření {table_name}: {e}") + print(f" [ERR] Error creating {table_name}: {e}") return False + # Log remote tables (queried at runtime via register_bq_table) + if remote_tables: + if verbose: + print("\n Remote tables (queried at runtime via BigQuery):") + for table_config in remote_tables: + table_name = table_config['name'] + table_id = table_config['id'] + if verbose: + print(f" [BQ] {table_name} -> {table_id}") + # Display table list with row counts if verbose: - print(f"\n📋 Seznam dostupných tabulek ({len(created_views)} vytvořeno):") + print(f"\n Available tables ({len(created_views)} local views):") tables = conn.execute("SHOW TABLES").fetchall() for table in tables: try: row_count = conn.execute(f"SELECT COUNT(*) FROM {table[0]}").fetchone()[0] - print(f" - {table[0]}: {row_count:,} řádků") - except Exception as e: - print(f" - {table[0]}: (chyba při počítání řádků)") + print(f" - {table[0]}: {row_count:,} rows (local)") + except Exception: + print(f" - {table[0]}: (error counting rows)") + + if remote_tables: + print(f"\n Remote tables ({len(remote_tables)}, loaded on demand):") + for tc in remote_tables: + print(f" - {tc['name']}: via BQ Query API (use date filters!)") # Close connection conn.close() if verbose: - print(f"\n✅ DuckDB databáze vytvořena: {db_path}") - print("💡 Můžeš začít analyzovat data pomocí DuckDB SQL dotazů") + print(f"\n DuckDB database created: {db_path}") return True except Exception as e: if verbose: - print(f"\n❌ Chyba při inicializaci DuckDB: {e}") + print(f"\n Error initializing DuckDB: {e}") import traceback traceback.print_exc() return False @@ -297,6 +452,11 @@ def main(): default='server', help='Base data directory (default: server, use "data" for server deployment)' ) + parser.add_argument( + '--bq-project', + default=None, + help='BigQuery execution project (informational only)' + ) parser.add_argument( '--quiet', action='store_true', @@ -314,7 +474,8 @@ def main(): success = init_duckdb( db_path=args.db_path, data_dir=args.data_dir, - verbose=not args.quiet + verbose=not args.quiet, + bq_project=args.bq_project, ) # Exit with appropriate code diff --git a/src/config.py b/src/config.py index 980d08e..4ee3df9 100644 --- a/src/config.py +++ b/src/config.py @@ -104,9 +104,19 @@ class TableConfig: incremental_column: Optional[str] = None # Column for timestamp-based incremental sync (BigQuery) columns: Optional[List[str]] = None # Subset of columns to sync (None = all) row_filter: Optional[str] = None # SQL WHERE clause for filtering (e.g., "event_date >= '2024-01-01'") + query_mode: str = "local" # "local" (Parquet) | "remote" (BQ direct) | "hybrid" (sync subset, query BQ) + partition_column_type: str = "TIMESTAMP" # BQ SQL type for partition column: "DATE", "TIMESTAMP", "DATETIME" def __post_init__(self): """Validate configuration after initialization.""" + # Validate query_mode + valid_query_modes = ("local", "remote", "hybrid") + if self.query_mode not in valid_query_modes: + raise ValueError( + f"Invalid query_mode '{self.query_mode}' for table {self.id}. " + f"Allowed values: {', '.join(valid_query_modes)}" + ) + # Validate sync_strategy if self.sync_strategy not in ["full_refresh", "incremental", "partitioned"]: raise ValueError( @@ -139,6 +149,14 @@ class TableConfig: f"Allowed values: 'month', 'day', 'year'" ) + # Validate partition_column_type + valid_column_types = ("DATE", "TIMESTAMP", "DATETIME") + if self.partition_column_type not in valid_column_types: + raise ValueError( + f"Invalid partition_column_type '{self.partition_column_type}' for table {self.id}. " + f"Allowed values: {', '.join(valid_column_types)}" + ) + # For partitioned, partition_by must be defined if self.sync_strategy == "partitioned": if not self.partition_by: @@ -435,6 +453,8 @@ class Config: incremental_column=table_data.get("incremental_column"), columns=table_data.get("columns"), row_filter=table_data.get("row_filter"), + query_mode=table_data.get("query_mode", "local"), + partition_column_type=table_data.get("partition_column_type", "TIMESTAMP"), ) table_configs.append(config) diff --git a/src/data_sync.py b/src/data_sync.py index 77169b6..bd601ca 100644 --- a/src/data_sync.py +++ b/src/data_sync.py @@ -406,6 +406,21 @@ class DataSyncManager: else: table_configs = self.config.tables + # Filter out remote-only tables (no local sync needed) + remote_skipped = [ + tc for tc in table_configs if tc.query_mode == "remote" + ] + table_configs = [ + tc for tc in table_configs if tc.query_mode != "remote" + ] + + if remote_skipped: + logger.info( + f"Skipping {len(remote_skipped)} remote-only tables " + f"(query via BigQuery): " + f"{', '.join(tc.name for tc in remote_skipped)}" + ) + logger.info(f"Synchronizing {len(table_configs)} tables...") results = {} diff --git a/tests/test_bigquery_adapter.py b/tests/test_bigquery_adapter.py index dafe0fa..fd5f0bb 100644 --- a/tests/test_bigquery_adapter.py +++ b/tests/test_bigquery_adapter.py @@ -9,6 +9,7 @@ so we install stub modules in sys.modules before importing the adapter. """ import sys +from datetime import date, datetime, timedelta from pathlib import Path from unittest.mock import MagicMock, patch @@ -83,6 +84,7 @@ def _make_table_config( partition_by: str | None = None, partition_granularity: str | None = None, max_history_days: int | None = None, + partition_column_type: str = "TIMESTAMP", ) -> TableConfig: """Helper to build a TableConfig with safe defaults.""" return TableConfig( @@ -96,6 +98,7 @@ def _make_table_config( partition_by=partition_by, partition_granularity=partition_granularity, max_history_days=max_history_days, + partition_column_type=partition_column_type, ) @@ -274,55 +277,217 @@ class TestIncrementalNoNewData: # --------------------------------------------------------------------------- -# 4. partitioned_sync creates partition files +# 4. partitioned_sync - per-day streaming behaviour # --------------------------------------------------------------------------- class TestPartitionedSync: + """Tests for the rewritten _partitioned_sync() that streams per-day from BQ.""" - def test_creates_partition_files(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): - """Partitioned sync should create separate Parquet files per partition key.""" - import pandas as pd - + def _setup_partition_config( + self, + mock_config, + tmp_parquet_dir, + *, + granularity: str = "day", + max_history_days: int | None = 3, + incremental_window_days: int | None = 2, + partition_column_type: str = "TIMESTAMP", + ): + """Common setup: create table config + wire mock_config partition paths.""" table_config = _make_table_config( - sync_strategy="incremental", - incremental_column="created_at", - partition_by="created_at", - partition_granularity="month", - incremental_window_days=7, + sync_strategy="partitioned", + incremental_column="event_date", + partition_by="event_date", + partition_granularity=granularity, + max_history_days=max_history_days, + incremental_window_days=incremental_window_days, + partition_column_type=partition_column_type, ) - # For partitioned tables, parquet_path is a directory partition_dir = tmp_parquet_dir / "orders" partition_dir.mkdir(parents=True, exist_ok=True) mock_config.get_parquet_path.return_value = partition_dir - # Configure partition paths def _partition_path(tc, key): return partition_dir / f"{key}.parquet" mock_config.get_partition_path.side_effect = _partition_path - # Build arrow table with timestamps in two months - ts_jan = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")] - ts_feb = [pd.Timestamp("2026-02-20 14:00:00", tz="UTC")] - arrow_data = pa.table({ - "id": [1, 2], - "name": ["Jan_Order", "Feb_Order"], - "created_at": pa.array(ts_jan + ts_feb, type=pa.timestamp("us", tz="UTC")), + return table_config, partition_dir + + @staticmethod + def _make_day_table(row_id: int, day: date) -> pa.Table: + """Build a one-row Arrow table for a given day.""" + return pa.table({ + "id": [row_id], + "event_date": pa.array([datetime(day.year, day.month, day.day)], type=pa.timestamp("us")), }) - mock_bq_client.read_table.return_value = arrow_data + + def test_creates_daily_partition_files( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """First sync with max_history_days creates one Parquet file per day with data.""" + table_config, partition_dir = self._setup_partition_config( + mock_config, tmp_parquet_dir, max_history_days=3, granularity="day", + ) + + today = date.today() + day0 = today - timedelta(days=3) + day1 = today - timedelta(days=2) + # day2, day3 (today-1, today) will have no data + + # Build per-day Arrow data for the two days that have rows + day0_table = self._make_day_table(1, day0) + day1_table = self._make_day_table(2, day1) + + # read_table_partitioned_streaming is called once per partition date. + # We need to return data for day0 and day1, empty iterators for the rest. + def _streaming_side_effect(*, table_id, partition_column, start, end, columns, column_type): + start_date = date.fromisoformat(start) + if start_date == day0: + return iter(day0_table.to_batches()) + if start_date == day1: + return iter(day1_table.to_batches()) + return iter([]) # empty for other days + + mock_bq_client.read_table_partitioned_streaming.side_effect = _streaming_side_effect adapter = _create_adapter(mock_config, mock_bq_client) result = adapter.sync_table(table_config, sync_state) assert result["success"] is True - # Should have created two partition files - partition_files = list(partition_dir.glob("*.parquet")) + # Should have exactly 2 partition files (days with data) + partition_files = sorted(partition_dir.glob("*.parquet")) assert len(partition_files) == 2 - partition_names = sorted(f.stem for f in partition_files) - assert "2026_01" in partition_names - assert "2026_02" in partition_names + file_names = sorted(f.stem for f in partition_files) + assert day0.strftime("%Y_%m_%d") in file_names + assert day1.strftime("%Y_%m_%d") in file_names + + # Verify content of each partition + t0 = pq.read_table(partition_dir / f"{day0.strftime('%Y_%m_%d')}.parquet") + assert t0.num_rows == 1 + assert t0.column("id").to_pylist() == [1] + + def test_incremental_sync_only_fetches_window( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """After a previous sync, only the incremental window of days is fetched.""" + table_config, partition_dir = self._setup_partition_config( + mock_config, tmp_parquet_dir, + max_history_days=30, + incremental_window_days=2, + granularity="day", + ) + + # Simulate a previous sync 1 day ago + sync_time = (datetime.now() - timedelta(days=1)).isoformat() + sync_state.update_sync( + table_id=table_config.id, + table_name=table_config.name, + strategy="partitioned", + rows=100, + file_size_bytes=5000, + ) + + # Return empty for all calls -- we just want to verify the call count + mock_bq_client.read_table_partitioned_streaming.return_value = iter([]) + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.sync_table(table_config, sync_state) + + assert result["success"] is True + + # With incremental_window_days=2, it should go back 2 days from last_sync. + # The number of partition dates from (last_sync - 2 days) to today. + last_sync_str = sync_state.get_last_sync(table_config.id) + last_sync_dt = datetime.fromisoformat(last_sync_str) + start_date = (last_sync_dt - timedelta(days=2)).date() + today = date.today() + expected_days = (today - start_date).days + 1 # inclusive + + actual_calls = mock_bq_client.read_table_partitioned_streaming.call_count + assert actual_calls == expected_days, ( + f"Expected {expected_days} BQ calls (from {start_date} to {today}), got {actual_calls}" + ) + + def test_merges_with_existing_partition( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """New data for an existing partition merges and deduplicates on PK.""" + table_config, partition_dir = self._setup_partition_config( + mock_config, tmp_parquet_dir, max_history_days=3, granularity="day", + ) + + today = date.today() + target_day = today - timedelta(days=1) + partition_key = target_day.strftime("%Y_%m_%d") + partition_path = partition_dir / f"{partition_key}.parquet" + + # Pre-write an existing partition file with id=1 + existing = pa.table({ + "id": [1], + "event_date": pa.array( + [datetime(target_day.year, target_day.month, target_day.day)], + type=pa.timestamp("us"), + ), + }) + pq.write_table(existing, partition_path, compression="snappy") + + # New data: id=1 (update) + id=2 (new row) + new_data = pa.table({ + "id": [1, 2], + "event_date": pa.array( + [ + datetime(target_day.year, target_day.month, target_day.day), + datetime(target_day.year, target_day.month, target_day.day), + ], + type=pa.timestamp("us"), + ), + }) + + def _streaming_side_effect(*, table_id, partition_column, start, end, columns, column_type): + start_date = date.fromisoformat(start) + if start_date == target_day: + return iter(new_data.to_batches()) + return iter([]) + + mock_bq_client.read_table_partitioned_streaming.side_effect = _streaming_side_effect + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.sync_table(table_config, sync_state) + + assert result["success"] is True + + # Read back the target partition -- should have 2 rows (dedup on id) + merged = pq.read_table(partition_path) + assert merged.num_rows == 2 + assert sorted(merged.column("id").to_pylist()) == [1, 2] + + def test_empty_partition_skipped( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """A partition day with no data from BQ should not create a file.""" + table_config, partition_dir = self._setup_partition_config( + mock_config, tmp_parquet_dir, max_history_days=2, granularity="day", + ) + + # Return empty iterator for every call + mock_bq_client.read_table_partitioned_streaming.return_value = iter([]) + # side_effect takes precedence over return_value when set, but let's use + # a function so each call gets a fresh empty iterator + mock_bq_client.read_table_partitioned_streaming.side_effect = ( + lambda **kw: iter([]) + ) + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.sync_table(table_config, sync_state) + + assert result["success"] is True + + # No partition files should have been created + partition_files = list(partition_dir.glob("*.parquet")) + assert len(partition_files) == 0 # --------------------------------------------------------------------------- @@ -574,15 +739,13 @@ class TestSyncTableDispatch: def test_dispatches_partitioned( self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state ): - """sync_strategy='incremental' with partition_by should call _partitioned_sync.""" - import pandas as pd - + """sync_strategy='partitioned' should call _partitioned_sync.""" table_config = _make_table_config( - sync_strategy="incremental", + sync_strategy="partitioned", incremental_column="created_at", partition_by="created_at", - partition_granularity="month", - incremental_window_days=7, + partition_granularity="day", + max_history_days=2, ) partition_dir = tmp_parquet_dir / "orders" partition_dir.mkdir(parents=True, exist_ok=True) @@ -592,13 +755,9 @@ class TestSyncTableDispatch: return partition_dir / f"{key}.parquet" mock_config.get_partition_path.side_effect = _partition_path - ts = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")] - arrow_data = pa.table({ - "id": [1], - "name": ["A"], - "created_at": pa.array(ts, type=pa.timestamp("us", tz="UTC")), - }) - mock_bq_client.read_table.return_value = arrow_data + mock_bq_client.read_table_partitioned_streaming.side_effect = ( + lambda **kw: iter([]) + ) adapter = _create_adapter(mock_config, mock_bq_client) @@ -770,3 +929,175 @@ class TestCreateDataSourceFactory: from connectors.bigquery.adapter import create_data_source, BigQueryDataSource instance = create_data_source() assert isinstance(instance, BigQueryDataSource) + + +# --------------------------------------------------------------------------- +# 14. _cleanup_old_partitions deletes files beyond retention window +# --------------------------------------------------------------------------- + +class TestPartitionCleanup: + + def test_deletes_old_partitions(self, mock_config, mock_bq_client, tmp_parquet_dir): + """Partition files older than max_history_days should be deleted.""" + table_config = _make_table_config( + sync_strategy="partitioned", + partition_by="event_date", + partition_granularity="day", + max_history_days=5, + ) + + partition_dir = tmp_parquet_dir / "orders" + partition_dir.mkdir(parents=True, exist_ok=True) + + today = date.today() + # Create files: 3 days ago (keep), 6 days ago (delete), 10 days ago (delete) + keep_day = today - timedelta(days=3) + delete_day1 = today - timedelta(days=6) + delete_day2 = today - timedelta(days=10) + + for d in [keep_day, delete_day1, delete_day2]: + key = d.strftime("%Y_%m_%d") + dummy = pa.table({"id": [1]}) + pq.write_table(dummy, partition_dir / f"{key}.parquet") + + adapter = _create_adapter(mock_config, mock_bq_client) + deleted = adapter._cleanup_old_partitions(table_config, partition_dir, "day") + + assert deleted == 2 + + # Only the recent file should remain + remaining = [f.stem for f in partition_dir.glob("*.parquet")] + assert keep_day.strftime("%Y_%m_%d") in remaining + assert delete_day1.strftime("%Y_%m_%d") not in remaining + assert delete_day2.strftime("%Y_%m_%d") not in remaining + + def test_no_cleanup_without_max_history_days(self, mock_config, mock_bq_client, tmp_parquet_dir): + """Without max_history_days, no partition files should be deleted.""" + table_config = _make_table_config( + sync_strategy="partitioned", + partition_by="event_date", + partition_granularity="day", + max_history_days=None, + ) + + partition_dir = tmp_parquet_dir / "orders" + partition_dir.mkdir(parents=True, exist_ok=True) + + # Create an old file (100 days ago) + old_day = date.today() - timedelta(days=100) + key = old_day.strftime("%Y_%m_%d") + pq.write_table(pa.table({"id": [1]}), partition_dir / f"{key}.parquet") + + adapter = _create_adapter(mock_config, mock_bq_client) + deleted = adapter._cleanup_old_partitions(table_config, partition_dir, "day") + + assert deleted == 0 + assert len(list(partition_dir.glob("*.parquet"))) == 1 + + +# --------------------------------------------------------------------------- +# 15. _generate_partition_dates produces correct date ranges +# --------------------------------------------------------------------------- + +class TestGeneratePartitionDates: + + def test_daily_generation(self, mock_config, mock_bq_client): + """Daily granularity should generate one date per day, inclusive.""" + adapter = _create_adapter(mock_config, mock_bq_client) + + start = date(2026, 3, 1) + end = date(2026, 3, 5) + dates = adapter._generate_partition_dates(start, end, "day") + + assert dates == [ + date(2026, 3, 1), + date(2026, 3, 2), + date(2026, 3, 3), + date(2026, 3, 4), + date(2026, 3, 5), + ] + + def test_monthly_generation(self, mock_config, mock_bq_client): + """Monthly granularity should generate first-of-month dates, aligned.""" + adapter = _create_adapter(mock_config, mock_bq_client) + + # Start mid-month -- should align to 1st + start = date(2026, 1, 15) + end = date(2026, 4, 10) + dates = adapter._generate_partition_dates(start, end, "month") + + assert dates == [ + date(2026, 1, 1), + date(2026, 2, 1), + date(2026, 3, 1), + date(2026, 4, 1), + ] + + def test_monthly_generation_across_year_boundary(self, mock_config, mock_bq_client): + """Monthly generation should cross year boundaries correctly.""" + adapter = _create_adapter(mock_config, mock_bq_client) + + start = date(2025, 11, 1) + end = date(2026, 2, 15) + dates = adapter._generate_partition_dates(start, end, "month") + + assert dates == [ + date(2025, 11, 1), + date(2025, 12, 1), + date(2026, 1, 1), + date(2026, 2, 1), + ] + + def test_daily_single_day(self, mock_config, mock_bq_client): + """When start == end, should return a single date.""" + adapter = _create_adapter(mock_config, mock_bq_client) + + d = date(2026, 6, 15) + dates = adapter._generate_partition_dates(d, d, "day") + + assert dates == [d] + + def test_empty_range(self, mock_config, mock_bq_client): + """When start > end, should return an empty list.""" + adapter = _create_adapter(mock_config, mock_bq_client) + + dates = adapter._generate_partition_dates(date(2026, 3, 10), date(2026, 3, 5), "day") + assert dates == [] + + +# --------------------------------------------------------------------------- +# 16. _parse_partition_date converts partition keys back to dates +# --------------------------------------------------------------------------- + +class TestParsePartitionDate: + + def test_parse_day_format(self, mock_config, mock_bq_client): + """'2026_01_15' with day granularity should parse to date(2026, 1, 15).""" + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter._parse_partition_date("2026_01_15", "day") + assert result == date(2026, 1, 15) + + def test_parse_month_format(self, mock_config, mock_bq_client): + """'2026_01' with month granularity should parse to date(2026, 1, 1).""" + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter._parse_partition_date("2026_01", "month") + assert result == date(2026, 1, 1) + + def test_parse_year_format(self, mock_config, mock_bq_client): + """'2026' with year granularity should parse to date(2026, 1, 1).""" + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter._parse_partition_date("2026", "year") + assert result == date(2026, 1, 1) + + def test_parse_invalid_returns_none(self, mock_config, mock_bq_client): + """Invalid partition key should return None.""" + adapter = _create_adapter(mock_config, mock_bq_client) + assert adapter._parse_partition_date("invalid", "day") is None + assert adapter._parse_partition_date("not_a_date", "month") is None + assert adapter._parse_partition_date("abc", "year") is None + + def test_parse_mismatched_granularity_returns_none(self, mock_config, mock_bq_client): + """Day key with month granularity should return None (format mismatch).""" + adapter = _create_adapter(mock_config, mock_bq_client) + # "2026_01_15" is a day format -- parsing as "month" (%Y_%m) should fail + assert adapter._parse_partition_date("2026_01_15", "month") is None diff --git a/tests/test_bigquery_client.py b/tests/test_bigquery_client.py index d499eab..aa0a678 100644 --- a/tests/test_bigquery_client.py +++ b/tests/test_bigquery_client.py @@ -868,3 +868,151 @@ class TestCreateClient: assert isinstance(result, BigQueryClient) assert result.project_id == "factory-project" + + +# --------------------------------------------------------------------------- +# 14. read_table_partitioned_streaming yields RecordBatches +# --------------------------------------------------------------------------- + +class TestReadTablePartitionedStreaming: + def test_streaming_yields_batches(self, client, mock_bq_client): + """read_table_partitioned_streaming yields RecordBatches, not a Table.""" + batch1 = pa.record_batch({"a": [1, 2]}) + batch2 = pa.record_batch({"a": [3, 4]}) + + mock_query_job = MagicMock() + mock_row_iter = MagicMock() + mock_row_iter.to_arrow_iterable.return_value = iter([batch1, batch2]) + mock_query_job.result.return_value = mock_row_iter + mock_bq_client.query.return_value = mock_query_job + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = MagicMock() + mock_bq_module.ScalarQueryParameter.return_value = MagicMock() + client.client = mock_bq_client + client.bqstorage_client = None # disable Storage API for simplicity + + batches = list(client.read_table_partitioned_streaming( + table_id="proj.dataset.events", + partition_column="event_date", + start="2025-01-01", + )) + + assert len(batches) == 2 + assert isinstance(batches[0], pa.RecordBatch) + assert isinstance(batches[1], pa.RecordBatch) + + def test_streaming_with_date_column_type(self, client, mock_bq_client): + """With column_type='DATE', ScalarQueryParameter uses 'DATE' type.""" + batch = pa.record_batch({"a": [1]}) + + mock_query_job = MagicMock() + mock_row_iter = MagicMock() + mock_row_iter.to_arrow_iterable.return_value = iter([batch]) + mock_query_job.result.return_value = mock_row_iter + mock_bq_client.query.return_value = mock_query_job + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = MagicMock() + mock_bq_module.ScalarQueryParameter.return_value = MagicMock() + client.client = mock_bq_client + client.bqstorage_client = None + + list(client.read_table_partitioned_streaming( + table_id="proj.dataset.events", + partition_column="event_date", + start="2025-01-01", + column_type="DATE", + )) + + # Verify ScalarQueryParameter was called with "DATE" type + mock_bq_module.ScalarQueryParameter.assert_called_once_with( + "start_value", "DATE", "2025-01-01" + ) + + def test_streaming_start_and_end(self, client, mock_bq_client): + """With start and end, both params are created with correct column_type.""" + batch = pa.record_batch({"a": [1]}) + + mock_query_job = MagicMock() + mock_row_iter = MagicMock() + mock_row_iter.to_arrow_iterable.return_value = iter([batch]) + mock_query_job.result.return_value = mock_row_iter + mock_bq_client.query.return_value = mock_query_job + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = MagicMock() + mock_bq_module.ScalarQueryParameter.return_value = MagicMock() + client.client = mock_bq_client + client.bqstorage_client = None + + list(client.read_table_partitioned_streaming( + table_id="proj.dataset.events", + partition_column="event_date", + start="2025-01-01", + end="2025-06-01", + column_type="DATE", + )) + + sql = mock_bq_client.query.call_args[0][0] + assert "`event_date` >= @start_value" in sql + assert "`event_date` < @end_value" in sql + + # Both parameters created with "DATE" type + assert mock_bq_module.ScalarQueryParameter.call_count == 2 + calls = mock_bq_module.ScalarQueryParameter.call_args_list + assert calls[0].args == ("start_value", "DATE", "2025-01-01") + assert calls[1].args == ("end_value", "DATE", "2025-06-01") + + +# --------------------------------------------------------------------------- +# 15. read_table_partitioned column_type parameter +# --------------------------------------------------------------------------- + +class TestReadTablePartitionedColumnType: + def test_date_column_type(self, client, mock_bq_client): + """read_table_partitioned with column_type='DATE' creates DATE params.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"a": [1]}) + mock_bq_client.query.return_value = mock_job + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = MagicMock() + mock_bq_module.ScalarQueryParameter.return_value = MagicMock() + client.client = mock_bq_client + + client.read_table_partitioned( + table_id="proj.dataset.events", + partition_column="event_date", + start="2025-01-01", + end="2025-06-01", + column_type="DATE", + ) + + # Both parameters should use "DATE" type + assert mock_bq_module.ScalarQueryParameter.call_count == 2 + calls = mock_bq_module.ScalarQueryParameter.call_args_list + assert calls[0].args == ("start_value", "DATE", "2025-01-01") + assert calls[1].args == ("end_value", "DATE", "2025-06-01") + + def test_default_column_type_is_timestamp(self, client, mock_bq_client): + """Default column_type is TIMESTAMP when not specified.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"a": [1]}) + mock_bq_client.query.return_value = mock_job + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = MagicMock() + mock_bq_module.ScalarQueryParameter.return_value = MagicMock() + client.client = mock_bq_client + + client.read_table_partitioned( + table_id="proj.dataset.events", + partition_column="created_at", + start="2025-01-01T00:00:00Z", + ) + + # Should default to "TIMESTAMP" + mock_bq_module.ScalarQueryParameter.assert_called_once_with( + "start_value", "TIMESTAMP", "2025-01-01T00:00:00Z" + ) diff --git a/tests/test_config_query_mode.py b/tests/test_config_query_mode.py new file mode 100644 index 0000000..6370045 --- /dev/null +++ b/tests/test_config_query_mode.py @@ -0,0 +1,69 @@ +"""Tests for TableConfig.query_mode field validation.""" + +import pytest + +from src.config import TableConfig + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _make_table(**overrides) -> TableConfig: + """Create a TableConfig with sensible defaults, applying overrides.""" + defaults = dict( + id="test.dataset.table", + name="test_table", + description="Test", + primary_key="id", + sync_strategy="full_refresh", + ) + defaults.update(overrides) + return TableConfig(**defaults) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +class TestQueryModeDefault: + def test_default_is_local(self): + table = _make_table() + assert table.query_mode == "local" + + +class TestQueryModeValidValues: + @pytest.mark.parametrize("mode", ["local", "remote", "hybrid"]) + def test_valid_query_mode(self, mode): + table = _make_table(query_mode=mode) + assert table.query_mode == mode + + +class TestQueryModeInvalid: + @pytest.mark.parametrize("bad_mode", ["invalid", "Local", "REMOTE", "", "sql"]) + def test_invalid_query_mode_raises(self, bad_mode): + with pytest.raises(ValueError, match="Invalid query_mode"): + _make_table(query_mode=bad_mode) + + +class TestQueryModeFromKwarg: + def test_kwarg_sets_query_mode(self): + """Simulate what _parse_data_description does: pass query_mode as kwarg.""" + table = TableConfig( + id="proj.dataset.orders", + name="orders", + description="Order data", + primary_key="order_id", + sync_strategy="full_refresh", + query_mode="remote", + ) + assert table.query_mode == "remote" + + def test_kwarg_default_when_omitted(self): + """When YAML has no query_mode, _parse_data_description passes 'local'.""" + table = TableConfig( + id="proj.dataset.orders", + name="orders", + description="Order data", + primary_key="order_id", + sync_strategy="full_refresh", + ) + assert table.query_mode == "local" diff --git a/tests/test_data_sync_query_mode.py b/tests/test_data_sync_query_mode.py new file mode 100644 index 0000000..211b94e --- /dev/null +++ b/tests/test_data_sync_query_mode.py @@ -0,0 +1,228 @@ +"""Tests for remote table skipping in DataSyncManager.sync_all(). + +Tables with query_mode == "remote" should be skipped during sync (no local +Parquet file is needed -- queries go directly to BigQuery). Tables with +query_mode "local" or "hybrid" must still be synced normally. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from src.config import TableConfig +from src.data_sync import DataSyncManager + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_table_config( + table_id: str, + name: str, + query_mode: str = "local", +) -> TableConfig: + """Create a minimal TableConfig for testing.""" + return TableConfig( + id=table_id, + name=name, + description=f"Test table {name}", + primary_key="id", + sync_strategy="full_refresh", + query_mode=query_mode, + ) + + +def _successful_sync_result() -> dict: + """Return a fake successful sync result dict.""" + return { + "success": True, + "rows": 100, + "file_size_mb": 0.5, + } + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def table_local(): + return _make_table_config("t.local", "local_table", query_mode="local") + + +@pytest.fixture +def table_remote(): + return _make_table_config("t.remote", "remote_table", query_mode="remote") + + +@pytest.fixture +def table_hybrid(): + return _make_table_config("t.hybrid", "hybrid_table", query_mode="hybrid") + + +@pytest.fixture +def all_tables(table_local, table_remote, table_hybrid): + return [table_local, table_remote, table_hybrid] + + +@pytest.fixture +def mock_config(all_tables): + """Return a mock Config whose .tables list contains all three query modes.""" + cfg = MagicMock() + cfg.tables = all_tables + cfg.get_metadata_path.return_value = MagicMock() # Path-like + + def _get_table_config(tid): + return next((t for t in all_tables if t.id == tid), None) + + cfg.get_table_config.side_effect = _get_table_config + return cfg + + +@pytest.fixture +def mock_data_source(): + """Return a mock DataSource that always succeeds.""" + ds = MagicMock() + ds.sync_table.return_value = _successful_sync_result() + return ds + + +@pytest.fixture +def sync_manager(mock_config, mock_data_source): + """Create a DataSyncManager with mocked dependencies.""" + with ( + patch("src.data_sync.get_config", return_value=mock_config), + patch("src.data_sync.create_data_source", return_value=mock_data_source), + patch("src.data_sync.SyncState"), + ): + manager = DataSyncManager() + # Patch out schema generation and auto-profiling (not under test) + manager._generate_schema_yaml = MagicMock() + yield manager + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestSyncAllRemoteSkipping: + """Verify that sync_all filters out remote tables.""" + + def test_remote_table_not_synced(self, sync_manager, mock_data_source, table_remote): + """Remote table must NOT be passed to data_source.sync_table.""" + sync_manager.sync_all() + + synced_ids = [ + call.args[0].id for call in mock_data_source.sync_table.call_args_list + ] + assert table_remote.id not in synced_ids + + def test_local_table_is_synced(self, sync_manager, mock_data_source, table_local): + """Local table must be synced normally.""" + sync_manager.sync_all() + + synced_ids = [ + call.args[0].id for call in mock_data_source.sync_table.call_args_list + ] + assert table_local.id in synced_ids + + def test_hybrid_table_is_synced(self, sync_manager, mock_data_source, table_hybrid): + """Hybrid table must be synced (needs local parquet for profiling).""" + sync_manager.sync_all() + + synced_ids = [ + call.args[0].id for call in mock_data_source.sync_table.call_args_list + ] + assert table_hybrid.id in synced_ids + + def test_sync_call_count(self, sync_manager, mock_data_source): + """Only local + hybrid tables should result in sync_table calls.""" + sync_manager.sync_all() + + # 3 tables total, 1 remote -> 2 sync calls + assert mock_data_source.sync_table.call_count == 2 + + def test_results_exclude_remote(self, sync_manager, table_remote): + """The results dict must not contain an entry for the remote table.""" + results = sync_manager.sync_all() + + assert table_remote.id not in results + + def test_results_include_local_and_hybrid( + self, sync_manager, table_local, table_hybrid + ): + """Results dict must contain entries for local and hybrid tables.""" + results = sync_manager.sync_all() + + assert table_local.id in results + assert table_hybrid.id in results + + +class TestSyncAllAllRemote: + """Edge case: every table is remote.""" + + def test_no_sync_calls_when_all_remote(self, mock_config, mock_data_source): + remote_only = [ + _make_table_config("t.r1", "remote1", query_mode="remote"), + _make_table_config("t.r2", "remote2", query_mode="remote"), + ] + mock_config.tables = remote_only + + with ( + patch("src.data_sync.get_config", return_value=mock_config), + patch("src.data_sync.create_data_source", return_value=mock_data_source), + patch("src.data_sync.SyncState"), + ): + manager = DataSyncManager() + manager._generate_schema_yaml = MagicMock() + results = manager.sync_all() + + assert mock_data_source.sync_table.call_count == 0 + assert results == {} + + +class TestSyncAllNoRemote: + """Edge case: no remote tables at all -- everything syncs.""" + + def test_all_tables_synced(self, mock_config, mock_data_source): + local_only = [ + _make_table_config("t.l1", "local1", query_mode="local"), + _make_table_config("t.l2", "local2", query_mode="local"), + ] + mock_config.tables = local_only + + with ( + patch("src.data_sync.get_config", return_value=mock_config), + patch("src.data_sync.create_data_source", return_value=mock_data_source), + patch("src.data_sync.SyncState"), + ): + manager = DataSyncManager() + manager._generate_schema_yaml = MagicMock() + results = manager.sync_all() + + assert mock_data_source.sync_table.call_count == 2 + assert "t.l1" in results + assert "t.l2" in results + + +class TestSyncAllWithTableFilter: + """When sync_all receives an explicit table list, remote filtering still applies.""" + + def test_explicit_remote_table_still_skipped( + self, sync_manager, mock_data_source, table_remote + ): + """Even if explicitly listed, a remote table should be skipped.""" + sync_manager.sync_all(tables=[table_remote.id]) + + assert mock_data_source.sync_table.call_count == 0 + + def test_explicit_local_table_synced( + self, sync_manager, mock_data_source, table_local + ): + """An explicitly listed local table should be synced.""" + sync_manager.sync_all(tables=[table_local.id]) + + assert mock_data_source.sync_table.call_count == 1 + synced_id = mock_data_source.sync_table.call_args_list[0].args[0].id + assert synced_id == table_local.id diff --git a/tests/test_duckdb_manager.py b/tests/test_duckdb_manager.py new file mode 100644 index 0000000..ba5233d --- /dev/null +++ b/tests/test_duckdb_manager.py @@ -0,0 +1,488 @@ +"""Tests for DuckDB Manager - query_mode classification and BQ registration. + +Tests cover: +- _get_bq_project_from_table_id: extracting BQ project from table IDs +- get_remote_tables: filtering tables by query_mode +- register_bq_table: registering BQ query results in DuckDB +- init_duckdb: table classification by query_mode, local view creation, + remote table logging +""" + +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +import duckdb +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from scripts.duckdb_manager import ( + _get_bq_project_from_table_id, + get_remote_tables, + init_duckdb, + register_bq_table, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def tmp_project(tmp_path): + """Create a minimal project layout with docs/data_description.md and a parquet file. + + Returns (project_root, db_path, data_dir) tuple. + """ + docs_dir = tmp_path / "docs" + docs_dir.mkdir() + + data_description = """\ +# Data Description + +```yaml +folder_mapping: + in.c-crm: crm_data + +tables: + - id: "in.c-crm.company" + name: "company" + description: "Company master data" + primary_key: "id" + sync_strategy: "full_refresh" +``` +""" + (docs_dir / "data_description.md").write_text(data_description) + + # Create parquet directory and a minimal parquet file + data_dir = tmp_path / "server" / "parquet" / "crm_data" + data_dir.mkdir(parents=True) + + table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}) + pq.write_table(table, data_dir / "company.parquet") + + db_dir = tmp_path / "user" / "duckdb" + db_dir.mkdir(parents=True) + db_path = str(db_dir / "test.duckdb") + + return tmp_path, db_path, str(tmp_path / "server") + + +@pytest.fixture +def tmp_project_mixed(tmp_path): + """Project with local, remote, and hybrid tables.""" + docs_dir = tmp_path / "docs" + docs_dir.mkdir() + + data_description = """\ +# Data Description + +```yaml +folder_mapping: + in.c-crm: crm_data + +tables: + - id: "in.c-crm.company" + name: "company" + description: "Local table" + primary_key: "id" + sync_strategy: "full_refresh" + + - id: "prj-grp-dataview-prod-1ff9.finance.revenue" + name: "revenue" + description: "Remote BQ table" + primary_key: "id" + query_mode: "remote" + + - id: "prj-grp-dataview-prod-1ff9.marketing.campaigns" + name: "campaigns" + description: "Hybrid table" + primary_key: "id" + sync_strategy: "full_refresh" + query_mode: "hybrid" +``` +""" + (docs_dir / "data_description.md").write_text(data_description) + + # Create parquet files for local and hybrid tables + crm_dir = tmp_path / "server" / "parquet" / "crm_data" + crm_dir.mkdir(parents=True) + table = pa.table({"id": [1, 2], "name": ["a", "b"]}) + pq.write_table(table, crm_dir / "company.parquet") + + marketing_dir = tmp_path / "server" / "parquet" / "prj-grp-dataview-prod-1ff9.marketing" + marketing_dir.mkdir(parents=True) + campaigns_table = pa.table({"id": [10], "campaign": ["test"]}) + pq.write_table(campaigns_table, marketing_dir / "campaigns.parquet") + + db_dir = tmp_path / "user" / "duckdb" + db_dir.mkdir(parents=True) + db_path = str(db_dir / "test.duckdb") + + return tmp_path, db_path, str(tmp_path / "server") + + +@pytest.fixture +def tmp_project_remote_only(tmp_path): + """Project with only remote tables (no local parquet needed).""" + docs_dir = tmp_path / "docs" + docs_dir.mkdir() + + data_description = """\ +# Data Description + +```yaml +tables: + - id: "prj-grp-dataview-prod-1ff9.finance.revenue" + name: "revenue" + description: "Remote BQ table" + primary_key: "id" + query_mode: "remote" + + - id: "prj-grp-dataview-prod-1ff9.finance.costs" + name: "costs" + description: "Remote BQ table" + primary_key: "id" + query_mode: "remote" +``` +""" + (docs_dir / "data_description.md").write_text(data_description) + + db_dir = tmp_path / "user" / "duckdb" + db_dir.mkdir(parents=True) + db_path = str(db_dir / "test.duckdb") + + return tmp_path, db_path, str(tmp_path / "server") + + +# --------------------------------------------------------------------------- +# Tests: _get_bq_project_from_table_id +# --------------------------------------------------------------------------- + +class TestGetBqProjectFromTableId: + """Test extracting BigQuery project ID from fully-qualified table IDs.""" + + def test_valid_bq_table_id(self): + result = _get_bq_project_from_table_id( + "prj-grp-dataview-prod-1ff9.finance.table" + ) + assert result == "prj-grp-dataview-prod-1ff9" + + def test_valid_bq_table_id_different_project(self): + result = _get_bq_project_from_table_id( + "my-gcp-project.dataset_name.table_name" + ) + assert result == "my-gcp-project" + + def test_keboola_format_returns_none(self): + result = _get_bq_project_from_table_id("in.c-crm.table") + assert result is None + + def test_two_part_id_returns_none(self): + result = _get_bq_project_from_table_id("dataset.table") + assert result is None + + def test_single_part_returns_none(self): + result = _get_bq_project_from_table_id("table_only") + assert result is None + + def test_four_parts_returns_none(self): + result = _get_bq_project_from_table_id("a-b.c.d.e") + assert result is None + + def test_empty_string_returns_none(self): + result = _get_bq_project_from_table_id("") + assert result is None + + def test_three_parts_no_hyphen_returns_none(self): + result = _get_bq_project_from_table_id("project.dataset.table") + assert result is None + + def test_hyphen_in_first_part_is_key(self): + result = _get_bq_project_from_table_id("a-b.dataset.table") + assert result == "a-b" + + +# --------------------------------------------------------------------------- +# Tests: get_remote_tables +# --------------------------------------------------------------------------- + +class TestGetRemoteTables: + """Test filtering table configs by query_mode.""" + + def test_returns_remote_tables(self): + configs = [ + {"name": "local", "query_mode": "local"}, + {"name": "remote1", "query_mode": "remote"}, + {"name": "hybrid1", "query_mode": "hybrid"}, + ] + result = get_remote_tables(configs) + names = [tc["name"] for tc in result] + assert "remote1" in names + assert "hybrid1" in names + assert "local" not in names + + def test_returns_empty_when_all_local(self): + configs = [ + {"name": "t1", "query_mode": "local"}, + {"name": "t2"}, # defaults to local (no query_mode key) + ] + result = get_remote_tables(configs) + assert result == [] + + def test_missing_query_mode_treated_as_local(self): + configs = [{"name": "t1"}] # no query_mode + result = get_remote_tables(configs) + assert result == [] + + +# --------------------------------------------------------------------------- +# Tests: register_bq_table +# --------------------------------------------------------------------------- + +class TestRegisterBqTable: + """Test registering BQ query results as DuckDB views.""" + + @staticmethod + def _make_factory(arrow_table, side_effect=None): + """Create a mock BQ client factory returning a client that yields arrow_table.""" + mock_job = MagicMock() + if side_effect: + mock_job.to_arrow.side_effect = side_effect + else: + mock_job.to_arrow.return_value = arrow_table + mock_client = MagicMock() + mock_client.query.return_value = mock_job + factory = MagicMock(return_value=mock_client) + factory._mock_client = mock_client + factory._mock_job = mock_job + return factory + + def test_registers_arrow_table_in_duckdb(self): + """Result from BQ should be queryable in DuckDB after registration.""" + arrow_table = pa.table({"id": [1, 2], "val": [10.0, 20.0]}) + factory = self._make_factory(arrow_table) + + conn = duckdb.connect() + rows = register_bq_table( + conn=conn, + table_id="proj.dataset.table", + view_name="test_view", + sql="SELECT * FROM table", + bq_project="test-project", + _bq_client_factory=factory, + ) + + assert rows == 2 + result = conn.execute("SELECT SUM(val) FROM test_view").fetchone()[0] + assert result == 30.0 + conn.close() + + def test_raises_without_bq_project(self): + conn = duckdb.connect() + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="BigQuery project not set"): + register_bq_table( + conn=conn, + table_id="proj.ds.tbl", + view_name="v", + sql="SELECT 1", + ) + conn.close() + + def test_uses_env_var_when_no_project_arg(self): + arrow_table = pa.table({"x": [1]}) + factory = self._make_factory(arrow_table) + + conn = duckdb.connect() + with patch.dict(os.environ, {"BIGQUERY_PROJECT": "env-proj"}): + register_bq_table( + conn=conn, + table_id="p.d.t", + view_name="v", + sql="SELECT 1", + _bq_client_factory=factory, + ) + + factory.assert_called_once_with("env-proj") + conn.close() + + def test_storage_api_fallback(self): + """Falls back to REST when Storage API permission denied.""" + arrow_table = pa.table({"x": [1]}) + factory = self._make_factory( + arrow_table, + side_effect=[ + Exception("PERMISSION_DENIED readsessions"), + arrow_table, + ], + ) + + conn = duckdb.connect() + rows = register_bq_table( + conn=conn, + table_id="p.d.t", + view_name="v", + sql="SELECT 1", + bq_project="proj", + _bq_client_factory=factory, + ) + + assert rows == 1 + factory._mock_job.to_arrow.assert_called_with(create_bqstorage_client=False) + conn.close() + + +# --------------------------------------------------------------------------- +# Tests: init_duckdb - table classification +# --------------------------------------------------------------------------- + +class TestInitDuckdbClassification: + """Test that tables are correctly classified by query_mode.""" + + def test_local_tables_create_parquet_views(self, tmp_project): + project_root, db_path, data_dir = tmp_project + + with patch("scripts.duckdb_manager.find_project_root", return_value=project_root): + result = init_duckdb( + db_path=db_path, data_dir=data_dir, verbose=False + ) + + assert result is True + + conn = duckdb.connect(db_path, read_only=True) + tables = [row[0] for row in conn.execute("SHOW TABLES").fetchall()] + assert "company" in tables + row_count = conn.execute("SELECT COUNT(*) FROM company").fetchone()[0] + assert row_count == 3 + conn.close() + + def test_remote_tables_not_created_as_local_views(self, tmp_project_mixed): + project_root, db_path, data_dir = tmp_project_mixed + + with patch("scripts.duckdb_manager.find_project_root", return_value=project_root): + result = init_duckdb( + db_path=db_path, data_dir=data_dir, verbose=False + ) + + assert result is True + + conn = duckdb.connect(db_path, read_only=True) + tables = [row[0] for row in conn.execute("SHOW TABLES").fetchall()] + assert "revenue" not in tables + assert "company" in tables + conn.close() + + def test_hybrid_tables_create_local_views(self, tmp_project_mixed): + project_root, db_path, data_dir = tmp_project_mixed + + with patch("scripts.duckdb_manager.find_project_root", return_value=project_root): + result = init_duckdb( + db_path=db_path, data_dir=data_dir, verbose=False + ) + + assert result is True + + conn = duckdb.connect(db_path, read_only=True) + tables = [row[0] for row in conn.execute("SHOW TABLES").fetchall()] + assert "campaigns" in tables + conn.close() + + def test_default_query_mode_is_local(self, tmp_project): + project_root, db_path, data_dir = tmp_project + + with patch("scripts.duckdb_manager.find_project_root", return_value=project_root): + result = init_duckdb( + db_path=db_path, data_dir=data_dir, verbose=False + ) + + assert result is True + + conn = duckdb.connect(db_path, read_only=True) + tables = [row[0] for row in conn.execute("SHOW TABLES").fetchall()] + assert "company" in tables + conn.close() + + +# --------------------------------------------------------------------------- +# Tests: init_duckdb - remote table logging +# --------------------------------------------------------------------------- + +class TestInitDuckdbRemoteLogging: + """Test that remote tables are logged correctly.""" + + def test_remote_tables_logged(self, tmp_project_remote_only, capsys): + project_root, db_path, data_dir = tmp_project_remote_only + + with patch("scripts.duckdb_manager.find_project_root", return_value=project_root): + init_duckdb( + db_path=db_path, data_dir=data_dir, verbose=True, + ) + + output = capsys.readouterr().out + assert "revenue" in output + assert "costs" in output + assert "[BQ]" in output + + def test_remote_only_project_succeeds(self, tmp_project_remote_only): + project_root, db_path, data_dir = tmp_project_remote_only + + with patch("scripts.duckdb_manager.find_project_root", return_value=project_root): + result = init_duckdb( + db_path=db_path, data_dir=data_dir, verbose=False, + ) + + assert result is True + + +# --------------------------------------------------------------------------- +# Tests: init_duckdb - missing parquet handling +# --------------------------------------------------------------------------- + +class TestInitDuckdbMissingParquet: + """Test behavior when parquet files are missing.""" + + def test_missing_parquet_skips_view(self, tmp_path): + docs_dir = tmp_path / "docs" + docs_dir.mkdir() + + data_description = """\ +# Data Description + +```yaml +tables: + - id: "in.c-crm.missing_table" + name: "missing_table" + description: "No parquet exists" + primary_key: "id" + sync_strategy: "full_refresh" +``` +""" + (docs_dir / "data_description.md").write_text(data_description) + + db_dir = tmp_path / "user" / "duckdb" + db_dir.mkdir(parents=True) + db_path = str(db_dir / "test.duckdb") + + with patch("scripts.duckdb_manager.find_project_root", return_value=tmp_path): + result = init_duckdb( + db_path=db_path, data_dir=str(tmp_path / "server"), verbose=False + ) + + assert result is True + + conn = duckdb.connect(db_path, read_only=True) + tables = [row[0] for row in conn.execute("SHOW TABLES").fetchall()] + assert "missing_table" not in tables + conn.close() + + def test_remote_table_no_local_parquet_needed(self, tmp_project_remote_only): + project_root, db_path, data_dir = tmp_project_remote_only + + with patch("scripts.duckdb_manager.find_project_root", return_value=project_root): + result = init_duckdb( + db_path=db_path, data_dir=data_dir, verbose=False, + ) + + assert result is True