From 758910463b55fedc42171d88f6746decfa19785a Mon Sep 17 00:00:00 2001 From: Petr Date: Wed, 11 Mar 2026 13:56:12 +0100 Subject: [PATCH] Add BigQuery data source adapter BigQuery connector that syncs BQ tables to local Parquet files via PyArrow (no CSV intermediate step). Supports full refresh, timestamp-based incremental (via incremental_column), and partition-based sync strategies. - connectors/bigquery/client.py: BQ API wrapper with ADC auth, parameterized queries, metadata cache, cross-project support (job project != data project) - connectors/bigquery/adapter.py: DataSource implementation with merge/dedup - src/config.py: Add incremental_column field to TableConfig - 72 unit tests (mocked, no GCP SDK required) --- config/instance.yaml.example | 27 +- connectors/bigquery/__init__.py | 11 + connectors/bigquery/adapter.py | 475 +++++++++++++++++ connectors/bigquery/client.py | 469 +++++++++++++++++ requirements.txt | 2 + src/config.py | 2 + src/data_sync.py | 2 +- tests/test_bigquery_adapter.py | 763 ++++++++++++++++++++++++++++ tests/test_bigquery_client.py | 870 ++++++++++++++++++++++++++++++++ 9 files changed, 2619 insertions(+), 2 deletions(-) create mode 100644 connectors/bigquery/__init__.py create mode 100644 connectors/bigquery/adapter.py create mode 100644 connectors/bigquery/client.py create mode 100644 tests/test_bigquery_adapter.py create mode 100644 tests/test_bigquery_client.py diff --git a/config/instance.yaml.example b/config/instance.yaml.example index 30f54ed..c783409 100644 --- a/config/instance.yaml.example +++ b/config/instance.yaml.example @@ -44,13 +44,38 @@ auth: google_client_id: "${GOOGLE_CLIENT_ID}" google_client_secret: "${GOOGLE_CLIENT_SECRET}" +# --- Theme (optional) --- +# Customize colors, fonts, and shape to match your brand. +# All values are optional - defaults provide a clean blue theme. +# See docs/theme-reference.html for a visual guide. +theme: + # primary: "#0073D1" # Main brand color (buttons, links, accents) + # primary_dark: "#005BA3" # Hover/active state of primary + # primary_light: "rgba(0, 115, 209, 0.1)" # Light tint backgrounds + # text_primary: "#1A253C" # Main text color + # text_secondary: "#6B7280" # Muted/secondary text + # background: "#F5F7FA" # Page background + # surface: "#FFFFFF" # Card/panel background + # border: "#E5E7EB" # Borders and dividers + # font_primary: "'Inter', system-ui, sans-serif" + # font_url: "https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" + # radius: "6px" # Border radius (cards, buttons, inputs) + # success: "#10B77F" + # warning: "#F59F0A" + # error: "#EA580C" + # --- Data source --- data_source: - type: "keboola" # keboola | csv (bigquery planned) + type: "keboola" # keboola | bigquery | local keboola: storage_token: "${KEBOOLA_STORAGE_TOKEN}" stack_url: "" # e.g., "https://connection.keboola.com" project_id: "" + bigquery: + project: "${BIGQUERY_PROJECT}" # GCP project for job execution/billing + location: "${BIGQUERY_LOCATION}" # BigQuery location (e.g., "us-central1", "US") + # Uses ADC (Application Default Credentials) - VM service account on GCP + # Data can live in a different project -- use fully-qualified table IDs in data_description.md # --- Email delivery (optional, for magic link auth) --- # Without SMTP, magic links are shown directly in browser (development mode). diff --git a/connectors/bigquery/__init__.py b/connectors/bigquery/__init__.py new file mode 100644 index 0000000..e6962fb --- /dev/null +++ b/connectors/bigquery/__init__.py @@ -0,0 +1,11 @@ +""" +BigQuery connector - data source adapter for Google BigQuery. + +Syncs tables from BigQuery using the BigQuery Storage API, +converting query results directly to Parquet files via PyArrow +(no CSV intermediate step). + +Enable by setting data_source.type: "bigquery" in config/instance.yaml +and providing BIGQUERY_PROJECT environment variable. +Uses Application Default Credentials (ADC) for authentication. +""" diff --git a/connectors/bigquery/adapter.py b/connectors/bigquery/adapter.py new file mode 100644 index 0000000..e20fe0c --- /dev/null +++ b/connectors/bigquery/adapter.py @@ -0,0 +1,475 @@ +""" +BigQuery data source adapter. + +Implements the DataSource interface for Google BigQuery. +Reads tables via the BigQuery API, converts directly to Parquet files +using PyArrow (no CSV intermediate step). +""" + +import logging +from pathlib import Path +from typing import Dict, List, Optional, Any +from datetime import datetime, timedelta + +import pyarrow as pa +import pyarrow.parquet as pq + +from src.config import get_config, TableConfig +from src.data_sync import DataSource, SyncState, _get_uncompressed_size +from src.parquet_manager import ( + convert_date_columns_to_date32, + apply_schema_to_table, +) +from .client import create_client as create_bq_client + + +logger = logging.getLogger(__name__) + + +class BigQueryDataSource(DataSource): + """ + Data source: Google BigQuery. + + Downloads data directly from BigQuery via PyArrow (no CSV step), + writes to local Parquet files with schema enforcement. + """ + + def __init__(self): + """Initialize BigQuery source with env var validation.""" + self.config = get_config() + self.bq_client = create_bq_client() + + def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]: + """Return BigQuery column metadata for schema generation. + + Returns: + {"columns": {"col_name": {"source_type": "...", "description": "..."}}} + or None if metadata unavailable. + """ + raw = self.bq_client.get_table_metadata(table_id) + column_types = raw.get("column_types", {}) + column_descriptions = raw.get("column_descriptions", {}) + + if not column_types: + return None + + result = {} + for col_name, bq_type in column_types.items(): + entry = {"source_type": bq_type} + if col_name in column_descriptions: + entry["description"] = column_descriptions[col_name] + result[col_name] = entry + + return {"columns": result} + + def discover_tables(self) -> List[Dict[str, Any]]: + """Discover all available tables from BigQuery.""" + return self.bq_client.discover_all_tables() + + def get_source_name(self) -> str: + """Display name of this data source.""" + return "Google BigQuery" + + def sync_table( + self, + table_config: TableConfig, + sync_state: SyncState, + ) -> Dict[str, Any]: + """ + Synchronize table from BigQuery. + + Dispatches to the appropriate strategy based on table config. + + Args: + table_config: Table configuration + sync_state: Sync state manager + + Returns: + Dictionary with sync result + """ + logger.info(f"Syncing BQ table: {table_config.name} ({table_config.sync_strategy})") + + # Clear metadata cache for fresh types + if table_config.id in self.bq_client.metadata_cache: + del self.bq_client.metadata_cache[table_config.id] + logger.debug(f"Cleared BQ metadata cache for {table_config.id}") + + try: + if table_config.sync_strategy == "full_refresh": + result = self._full_refresh(table_config) + elif table_config.sync_strategy == "incremental": + result = self._incremental_sync(table_config, sync_state) + elif table_config.sync_strategy == "partitioned": + result = self._partitioned_sync(table_config, sync_state) + else: + raise ValueError(f"Unknown sync strategy: {table_config.sync_strategy}") + + # Update sync state + sync_state.update_sync( + table_id=table_config.id, + table_name=table_config.name, + strategy=table_config.sync_strategy, + rows=result["rows"], + file_size_bytes=result["file_size_bytes"], + columns=result.get("columns", 0), + uncompressed_bytes=result.get("uncompressed_bytes", 0), + ) + + return { + "success": True, + "rows": result["rows"], + "strategy": table_config.sync_strategy, + "file_size_mb": result["file_size_bytes"] / 1024 / 1024, + } + + except Exception as e: + logger.error(f"Error syncing BQ table {table_config.name}: {e}") + return { + "success": False, + "error": str(e), + "strategy": table_config.sync_strategy, + } + + def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]: + """ + Full refresh: read entire table and replace Parquet file. + """ + logger.info(f"Full refresh: {table_config.name}") + + parquet_path = self.config.get_parquet_path(table_config) + date_columns = self.bq_client.get_date_columns(table_config.id) + pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id) + + # Read full table from BigQuery -> PyArrow + arrow_table = self.bq_client.read_table(table_config.id) + + # Apply schema enforcement + if date_columns: + arrow_table = convert_date_columns_to_date32(arrow_table, date_columns) + if pyarrow_schema: + arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema) + + # Write to Parquet + parquet_path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(arrow_table, parquet_path, compression="snappy") + + file_size = parquet_path.stat().st_size + logger.info( + f"Full refresh complete: {arrow_table.num_rows} rows, " + f"{file_size / 1024 / 1024:.2f} MB" + ) + + return { + "rows": arrow_table.num_rows, + "columns": arrow_table.num_columns, + "file_size_bytes": file_size, + "uncompressed_bytes": _get_uncompressed_size(parquet_path), + } + + def _incremental_sync( + self, + table_config: TableConfig, + sync_state: SyncState, + ) -> Dict[str, Any]: + """ + Incremental sync: dispatch to column-based or partition-based strategy. + """ + # If partition_by is set, use partitioned incremental + if table_config.partition_by: + return self._partitioned_sync(table_config, sync_state) + + # If incremental_column is set, use timestamp-based incremental + if table_config.incremental_column: + return self._incremental_column_sync(table_config, sync_state) + + # Fallback: full refresh (no incremental column configured) + logger.warning( + f"Table {table_config.name}: incremental strategy but no " + f"incremental_column or partition_by configured, falling back to full refresh" + ) + return self._full_refresh(table_config) + + def _incremental_column_sync( + self, + table_config: TableConfig, + sync_state: SyncState, + ) -> Dict[str, Any]: + """ + Timestamp-based incremental sync using incremental_column. + + Reads rows WHERE incremental_column > last_sync_value, + merges with existing Parquet (dedup on primary key). + """ + logger.info( + f"Incremental column sync: {table_config.name} " + f"(column: {table_config.incremental_column})" + ) + + parquet_path = self.config.get_parquet_path(table_config) + date_columns = self.bq_client.get_date_columns(table_config.id) + pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id) + + # Determine since_value from last sync + last_sync = sync_state.get_last_sync(table_config.id) + + if last_sync and parquet_path.exists(): + # Apply window: go back incremental_window_days from last sync + last_sync_dt = datetime.fromisoformat(last_sync) + window_days = table_config.incremental_window_days or 7 + since_dt = last_sync_dt - timedelta(days=window_days) + since_value = since_dt.isoformat() + + logger.info(f" -> Since: {since_value} (window: {window_days} days)") + + # Read incremental data + new_data = self.bq_client.read_table_incremental( + table_id=table_config.id, + incremental_column=table_config.incremental_column, + since_value=since_value, + ) + + if new_data.num_rows == 0: + logger.info(" -> No new data since last sync") + existing_pf = pq.ParquetFile(parquet_path) + return { + "rows": existing_pf.metadata.num_rows, + "columns": len(existing_pf.schema_arrow), + "file_size_bytes": parquet_path.stat().st_size, + "uncompressed_bytes": _get_uncompressed_size(parquet_path), + } + + # Merge with existing data + logger.info(f" -> Merging {new_data.num_rows} new rows with existing data") + existing_table = pq.read_table(parquet_path) + merged = self._merge_arrow_tables( + existing_table, new_data, table_config.get_primary_key_columns() + ) + + # Apply schema enforcement + if date_columns: + merged = convert_date_columns_to_date32(merged, date_columns) + if pyarrow_schema: + merged = apply_schema_to_table(merged, pyarrow_schema) + + pq.write_table(merged, parquet_path, compression="snappy") + + file_size = parquet_path.stat().st_size + logger.info( + f" -> Incremental sync complete: {merged.num_rows} total rows" + ) + + return { + "rows": merged.num_rows, + "columns": merged.num_columns, + "file_size_bytes": file_size, + "uncompressed_bytes": _get_uncompressed_size(parquet_path), + } + + else: + # First sync or no existing file -- full read + logger.info(" -> First sync, reading all data") + + if table_config.max_history_days: + since_dt = datetime.now() - timedelta(days=table_config.max_history_days) + arrow_table = self.bq_client.read_table_incremental( + table_id=table_config.id, + incremental_column=table_config.incremental_column, + since_value=since_dt.isoformat(), + ) + else: + arrow_table = self.bq_client.read_table(table_config.id) + + # Apply schema enforcement + if date_columns: + arrow_table = convert_date_columns_to_date32(arrow_table, date_columns) + if pyarrow_schema: + arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema) + + parquet_path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(arrow_table, parquet_path, compression="snappy") + + file_size = parquet_path.stat().st_size + return { + "rows": arrow_table.num_rows, + "columns": arrow_table.num_columns, + "file_size_bytes": file_size, + "uncompressed_bytes": _get_uncompressed_size(parquet_path), + } + + def _partitioned_sync( + self, + table_config: TableConfig, + sync_state: SyncState, + ) -> Dict[str, Any]: + """ + Partition-based sync: read data by partition range and write partition files. + """ + import pandas as pd + + partition_col = table_config.partition_by + if not partition_col and table_config.incremental_column: + partition_col = table_config.incremental_column + + if not partition_col: + logger.warning( + f"Table {table_config.name}: partitioned strategy but no " + f"partition_by or incremental_column, falling back to full refresh" + ) + return self._full_refresh(table_config) + + granularity = table_config.partition_granularity or "month" + logger.info( + f"Partitioned sync: {table_config.name} " + f"(by {partition_col}, {granularity})" + ) + + partition_dir = self.config.get_parquet_path(table_config) + date_columns = self.bq_client.get_date_columns(table_config.id) + pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id) + + # Determine time range + last_sync = sync_state.get_last_sync(table_config.id) + + if last_sync: + last_sync_dt = datetime.fromisoformat(last_sync) + window_days = table_config.incremental_window_days or 7 + start_dt = last_sync_dt - timedelta(days=window_days) + logger.info(f" -> Reading from {start_dt.isoformat()} (window: {window_days} days)") + else: + if table_config.max_history_days: + start_dt = datetime.now() - timedelta(days=table_config.max_history_days) + logger.info(f" -> First sync, limited to last {table_config.max_history_days} days") + else: + start_dt = None + logger.info(" -> First sync, reading all data") + + # Read data from BigQuery + if start_dt: + arrow_table = self.bq_client.read_table_partitioned( + table_id=table_config.id, + partition_column=partition_col, + start=start_dt.isoformat(), + ) + else: + arrow_table = self.bq_client.read_table(table_config.id) + + if arrow_table.num_rows == 0: + logger.info(" -> No data to sync") + return self._get_partition_totals(partition_dir) + + logger.info(f" -> Processing {arrow_table.num_rows} rows into partitions") + + # Convert to pandas for partitioning + df = arrow_table.to_pandas() + + # Ensure partition column is datetime + if not pd.api.types.is_datetime64_any_dtype(df[partition_col]): + df[partition_col] = pd.to_datetime(df[partition_col], format="ISO8601", utc=True) + + # Create partition key + if granularity == "month": + df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m") + elif granularity == "day": + df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m_%d") + elif granularity == "year": + df["_partition_key"] = df[partition_col].dt.strftime("%Y") + + primary_key_cols = table_config.get_primary_key_columns() + partitions_updated = set() + + for partition_key, group_df in df.groupby("_partition_key"): + group_df = group_df.drop(columns=["_partition_key"]) + partition_path = self.config.get_partition_path(table_config, partition_key) + partitions_updated.add(partition_key) + + # Merge with existing partition if it exists + if partition_path.exists(): + existing_df = pd.read_parquet(partition_path) + merged_df = pd.concat([existing_df, group_df], ignore_index=True) + merged_df = merged_df.drop_duplicates(subset=primary_key_cols, keep="last") + else: + merged_df = group_df + + # Write partition + table = pa.Table.from_pandas(merged_df, preserve_index=False) + if date_columns: + table = convert_date_columns_to_date32(table, date_columns) + if pyarrow_schema: + table = apply_schema_to_table(table, pyarrow_schema) + pq.write_table(table, partition_path, compression="snappy") + + logger.info(f" -> Partitioned sync complete: {len(partitions_updated)} partitions updated") + + return self._get_partition_totals(partition_dir) + + def _merge_arrow_tables( + self, + existing: pa.Table, + new_data: pa.Table, + primary_key: List[str], + ) -> pa.Table: + """ + Merge two Arrow tables with deduplication on primary key. + + New data overwrites existing rows with the same primary key. + + Args: + existing: Existing data + new_data: New/changed data + primary_key: List of PK column names + + Returns: + Merged PyArrow Table + """ + import pandas as pd + + existing_df = existing.to_pandas() + new_df = new_data.to_pandas() + + # Concat and dedup (keep last = new data wins) + merged_df = pd.concat([existing_df, new_df], ignore_index=True) + merged_df = merged_df.drop_duplicates(subset=primary_key, keep="last") + + return pa.Table.from_pandas(merged_df, preserve_index=False) + + def _get_partition_totals(self, partition_dir: Path) -> Dict[str, Any]: + """ + Calculate totals from all partition files in a directory. + """ + total_rows = 0 + total_size = 0 + total_uncompressed = 0 + total_columns = 0 + + if not partition_dir.exists(): + return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0} + + all_partitions = list(partition_dir.glob("*.parquet")) + + for part_path in all_partitions: + try: + pf = pq.ParquetFile(part_path) + meta = pf.metadata + total_rows += meta.num_rows + total_size += part_path.stat().st_size + if total_columns == 0: + total_columns = len(pf.schema_arrow) + for rg_idx in range(meta.num_row_groups): + rg = meta.row_group(rg_idx) + for col_idx in range(rg.num_columns): + total_uncompressed += rg.column(col_idx).total_uncompressed_size + except Exception as e: + logger.warning(f"Skipping corrupt partition {part_path.name}: {e}") + + return { + "rows": total_rows, + "file_size_bytes": total_size, + "partitions": len(all_partitions), + "columns": total_columns, + "uncompressed_bytes": total_uncompressed, + } + + +def create_data_source() -> BigQueryDataSource: + """Factory function for dynamic import compatibility.""" + return BigQueryDataSource() diff --git a/connectors/bigquery/client.py b/connectors/bigquery/client.py new file mode 100644 index 0000000..b72e65b --- /dev/null +++ b/connectors/bigquery/client.py @@ -0,0 +1,469 @@ +""" +Google BigQuery API Client + +Low-level wrapper for Google BigQuery with these functions: +1. Authentication using Application Default Credentials (ADC) +2. Query tables to PyArrow (no CSV intermediate step) +3. Get table metadata (schema, columns, data types) +4. Cache metadata for faster repeated use +5. Incremental reads (timestamp-based and partition-based) + +Uses google-cloud-bigquery with native PyArrow support. +""" + +import json +import logging +import os +from pathlib import Path +from typing import Dict, List, Optional, Any +from datetime import datetime, timedelta + +import pyarrow as pa +from google.cloud import bigquery + +from src.config import get_config + + +logger = logging.getLogger(__name__) + + +# Mapping BigQuery types to PyArrow types +BIGQUERY_TO_PYARROW_TYPES = { + "STRING": pa.string(), + "BYTES": pa.binary(), + "INTEGER": pa.int64(), + "INT64": pa.int64(), + "FLOAT": pa.float64(), + "FLOAT64": pa.float64(), + "NUMERIC": pa.float64(), + "BIGNUMERIC": pa.float64(), + "BOOLEAN": pa.bool_(), + "BOOL": pa.bool_(), + "TIMESTAMP": pa.timestamp("us", tz="UTC"), + "DATE": pa.date32(), + "TIME": pa.string(), + "DATETIME": pa.timestamp("us"), + "GEOGRAPHY": pa.string(), + "JSON": pa.string(), + "STRUCT": pa.string(), + "RECORD": pa.string(), + "ARRAY": pa.string(), +} + + +class BigQueryClient: + """ + Wrapper for Google BigQuery API. + + Provides high-level methods for working with BigQuery tables: + - Query tables to PyArrow Tables (no CSV step) + - Get metadata (schema, columns) + - Incremental and partitioned reads + """ + + def __init__( + self, + project_id: Optional[str] = None, + location: Optional[str] = None, + ): + """ + Initialize BigQuery client. + + Args: + project_id: GCP project ID for job execution/billing. + If None, reads from BIGQUERY_PROJECT env var. + location: BigQuery location for job execution (e.g., "us-central1"). + If None, reads from BIGQUERY_LOCATION env var. + + Raises: + ValueError: If project_id is not provided and BIGQUERY_PROJECT is not set. + """ + self.project_id = project_id or os.environ.get("BIGQUERY_PROJECT") + + if not self.project_id: + raise ValueError( + "BigQuery project ID not set. " + "Set BIGQUERY_PROJECT environment variable." + ) + + self.location = location or os.environ.get("BIGQUERY_LOCATION") + + # Initialize BigQuery client with ADC + # project_id is used for job execution and billing. + # Data can live in a different project -- table IDs in queries + # use fully-qualified format (project.dataset.table). + client_kwargs = {"project": self.project_id} + if self.location: + client_kwargs["location"] = self.location + self.client = bigquery.Client(**client_kwargs) + + # Metadata cache + config = get_config() + self.metadata_cache: Dict[str, Dict[str, Any]] = {} + self.metadata_cache_path = config.get_metadata_path() / "bq_table_metadata.json" + + # Load cache from disk if exists + self._load_metadata_cache() + + logger.info( + f"BigQuery client initialized: project={self.project_id}, " + f"location={self.location or 'auto'}" + ) + + def _load_metadata_cache(self): + """Load metadata cache from disk.""" + if self.metadata_cache_path.exists(): + try: + with open(self.metadata_cache_path, "r") as f: + self.metadata_cache = json.load(f) + logger.info( + f"BQ metadata cache loaded: {len(self.metadata_cache)} tables" + ) + except Exception as e: + logger.warning(f"Error loading BQ metadata cache: {e}") + self.metadata_cache = {} + + def _save_metadata_cache(self): + """Save metadata cache to disk.""" + try: + self.metadata_cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(self.metadata_cache_path, "w") as f: + json.dump(self.metadata_cache, f, indent=2) + logger.debug("BQ metadata cache saved") + except Exception as e: + logger.warning(f"Error saving BQ metadata cache: {e}") + + def get_table_metadata( + self, + table_id: str, + use_cache: bool = True, + cache_ttl_hours: int = 24, + ) -> Dict[str, Any]: + """ + Get table metadata from BigQuery. + + Args: + table_id: Full table ID (e.g., "project.dataset.table") + use_cache: Use cache if available + cache_ttl_hours: Cache TTL in hours (default 24h) + + Returns: + Dictionary with metadata including columns, types, descriptions, row count. + """ + # Check cache + if use_cache and table_id in self.metadata_cache: + cached = self.metadata_cache[table_id] + cached_time = datetime.fromisoformat(cached.get("_cached_at", "2000-01-01")) + cache_age = datetime.now() - cached_time + + if cache_age < timedelta(hours=cache_ttl_hours): + logger.debug(f"Using BQ metadata cache for {table_id}") + return cached + + logger.info(f"Fetching metadata for BQ table: {table_id}") + + try: + table_ref = self.client.get_table(table_id) + + # Build column metadata + columns = [] + column_types = {} + column_descriptions = {} + for field in table_ref.schema: + columns.append(field.name) + column_types[field.name] = field.field_type + if field.description: + column_descriptions[field.name] = field.description + + metadata = { + "table_id": table_id, + "name": table_ref.table_id, + "dataset": table_ref.dataset_id, + "project": table_ref.project, + "columns": columns, + "column_types": column_types, + "column_descriptions": column_descriptions, + "row_count": table_ref.num_rows, + "size_bytes": table_ref.num_bytes, + "created": table_ref.created.isoformat() if table_ref.created else None, + "modified": table_ref.modified.isoformat() if table_ref.modified else None, + "partitioning": None, + "_cached_at": datetime.now().isoformat(), + } + + # Capture partitioning info + if table_ref.time_partitioning: + metadata["partitioning"] = { + "type": table_ref.time_partitioning.type_, + "field": table_ref.time_partitioning.field, + "expiration_ms": table_ref.time_partitioning.expiration_ms, + } + + # Save to cache + self.metadata_cache[table_id] = metadata + self._save_metadata_cache() + + return metadata + + except Exception as e: + logger.error(f"Error getting metadata for {table_id}: {e}") + raise + + def get_pyarrow_schema(self, table_id: str) -> Optional[pa.Schema]: + """ + Build PyArrow schema from BigQuery table schema. + + Args: + table_id: Full table ID + + Returns: + PyArrow schema or None if metadata unavailable + """ + metadata = self.get_table_metadata(table_id) + column_types = metadata.get("column_types", {}) + + if not column_types: + logger.warning(f"No column types for {table_id}, schema will not be applied") + return None + + fields = [] + for col_name in metadata.get("columns", []): + bq_type = column_types.get(col_name, "STRING") + pa_type = BIGQUERY_TO_PYARROW_TYPES.get(bq_type, pa.string()) + fields.append(pa.field(col_name, pa_type)) + + return pa.schema(fields) + + def get_date_columns(self, table_id: str) -> List[str]: + """ + Get list of DATE-only columns for a table. + + Args: + table_id: Full table ID + + Returns: + List of column names that have DATE type in BigQuery + """ + metadata = self.get_table_metadata(table_id) + column_types = metadata.get("column_types", {}) + + return [ + col_name for col_name, bq_type in column_types.items() + if bq_type == "DATE" + ] + + def query_to_arrow( + self, + sql: str, + params: Optional[List[bigquery.ScalarQueryParameter]] = None, + ) -> pa.Table: + """ + Execute SQL query and return results as PyArrow Table. + + Args: + sql: SQL query string (use @param_name for parameterized values) + params: List of BigQuery query parameters + + Returns: + PyArrow Table with query results + """ + job_config = bigquery.QueryJobConfig() + if params: + job_config.query_parameters = params + + logger.debug(f"Executing BQ query: {sql[:200]}...") + + query_job = self.client.query(sql, job_config=job_config) + arrow_table = query_job.to_arrow() + + logger.debug(f"Query returned {arrow_table.num_rows} rows, {arrow_table.num_columns} columns") + return arrow_table + + def read_table( + self, + table_id: str, + columns: Optional[List[str]] = None, + row_filter: Optional[str] = None, + ) -> pa.Table: + """ + Read full table (or filtered subset) as PyArrow Table. + + Args: + table_id: Full table ID (e.g., "project.dataset.table") + columns: Optional list of columns to select + row_filter: Optional SQL WHERE clause (without WHERE keyword) + + Returns: + PyArrow Table with table data + """ + # Build SELECT clause + select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*" + + sql = f"SELECT {select_cols} FROM `{table_id}`" + if row_filter: + sql += f" WHERE {row_filter}" + + logger.info(f"Reading BQ table: {table_id} (filter: {row_filter or 'none'})") + return self.query_to_arrow(sql) + + def read_table_incremental( + self, + table_id: str, + incremental_column: str, + since_value: str, + columns: Optional[List[str]] = None, + ) -> pa.Table: + """ + Read rows where incremental_column > since_value. + + Uses parameterized query to prevent SQL injection. + + Args: + table_id: Full table ID + incremental_column: Column name for incremental filter + since_value: ISO timestamp string - fetch rows after this value + columns: Optional list of columns to select + + Returns: + PyArrow Table with incremental data + """ + select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*" + + sql = ( + f"SELECT {select_cols} FROM `{table_id}` " + f"WHERE `{incremental_column}` > @since_value" + ) + + params = [ + bigquery.ScalarQueryParameter("since_value", "TIMESTAMP", since_value), + ] + + logger.info( + f"Incremental read: {table_id} WHERE {incremental_column} > {since_value}" + ) + return self.query_to_arrow(sql, params=params) + + def read_table_partitioned( + self, + table_id: str, + partition_column: str, + start: str, + end: Optional[str] = None, + columns: Optional[List[str]] = None, + ) -> pa.Table: + """ + Read data within a partition range. + + Args: + table_id: Full table ID + partition_column: Partition column name + start: Start date/timestamp (inclusive) + end: End date/timestamp (exclusive). If None, reads to present. + columns: Optional list of columns to select + + Returns: + PyArrow Table with partition range data + """ + select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*" + + sql = ( + f"SELECT {select_cols} FROM `{table_id}` " + f"WHERE `{partition_column}` >= @start_value" + ) + params = [ + bigquery.ScalarQueryParameter("start_value", "TIMESTAMP", start), + ] + + if end: + sql += f" AND `{partition_column}` < @end_value" + params.append( + bigquery.ScalarQueryParameter("end_value", "TIMESTAMP", end), + ) + + logger.info( + f"Partitioned read: {table_id} [{start} .. {end or 'now'})" + ) + return self.query_to_arrow(sql, params=params) + + def discover_all_tables(self, dataset_id: Optional[str] = None) -> List[Dict[str, Any]]: + """ + List all tables in the project (or specific dataset). + + Args: + dataset_id: Optional dataset ID to limit scope + + Returns: + Normalized list of table dicts with id, name, columns, row_count, etc. + """ + logger.info(f"Discovering BQ tables (dataset={dataset_id or 'all'})...") + + result = [] + + if dataset_id: + datasets = [self.client.get_dataset(dataset_id)] + else: + datasets = list(self.client.list_datasets()) + + for dataset in datasets: + ds_ref = dataset.reference if hasattr(dataset, "reference") else dataset.dataset_id + ds_id = str(ds_ref) + + try: + tables = list(self.client.list_tables(ds_ref)) + except Exception as e: + logger.warning(f"Could not list tables in dataset {ds_id}: {e}") + continue + + for table_item in tables: + full_id = f"{table_item.project}.{table_item.dataset_id}.{table_item.table_id}" + + try: + table_detail = self.client.get_table(full_id) + columns = [f.name for f in table_detail.schema] + + result.append({ + "id": full_id, + "name": table_item.table_id, + "bucket_id": table_item.dataset_id, + "bucket_name": table_item.dataset_id, + "columns": columns, + "row_count": table_detail.num_rows or 0, + "size_bytes": table_detail.num_bytes or 0, + "primary_key": [], + "last_change": ( + table_detail.modified.isoformat() + if table_detail.modified else None + ), + "last_import": None, + }) + except Exception as e: + logger.warning(f"Could not get details for {full_id}: {e}") + + logger.info(f"Discovered {len(result)} BQ tables") + return result + + def test_connection(self) -> bool: + """ + Test connection to BigQuery. + + Returns: + True if connection works, False otherwise + """ + try: + query_job = self.client.query("SELECT 1") + list(query_job.result()) + logger.info(f"BigQuery connection OK (project: {self.project_id})") + return True + except Exception as e: + logger.error(f"BigQuery connection test failed: {e}") + return False + + +def create_client() -> BigQueryClient: + """ + Factory function to create BigQuery client. + + Returns: + BigQueryClient instance + """ + return BigQueryClient() diff --git a/requirements.txt b/requirements.txt index 81d63f2..914e03e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ # Data source adapters (install only what you need) kbcstorage>=0.9.0 # For Keboola adapter +google-cloud-bigquery>=3.0.0 # For BigQuery adapter +google-cloud-bigquery-storage>=2.0.0 # For BigQuery adapter (fast Arrow transfer) # Data processing # pandas - core tabular data processing library diff --git a/src/config.py b/src/config.py index 5094864..578f6ee 100644 --- a/src/config.py +++ b/src/config.py @@ -101,6 +101,7 @@ class TableConfig: max_history_days: Optional[int] = None dataset: Optional[str] = None initial_load_chunk_days: int = 30 + incremental_column: Optional[str] = None # Column for timestamp-based incremental sync (BigQuery) def __post_init__(self): """Validate configuration after initialization.""" @@ -429,6 +430,7 @@ class Config: max_history_days=table_data.get("max_history_days"), dataset=table_data.get("dataset"), initial_load_chunk_days=table_data.get("initial_load_chunk_days", 30), + incremental_column=table_data.get("incremental_column"), ) table_configs.append(config) diff --git a/src/data_sync.py b/src/data_sync.py index 5484279..77169b6 100644 --- a/src/data_sync.py +++ b/src/data_sync.py @@ -511,7 +511,7 @@ def create_data_source(source_type: str = None) -> DataSource: raise ValueError( f"Unknown data source: '{source_type}'. " - f"Available connectors: keboola. " + f"Available connectors: keboola, bigquery. " f"Create connectors/{source_type}/adapter.py to add a new one." ) diff --git a/tests/test_bigquery_adapter.py b/tests/test_bigquery_adapter.py new file mode 100644 index 0000000..edec77a --- /dev/null +++ b/tests/test_bigquery_adapter.py @@ -0,0 +1,763 @@ +""" +Comprehensive unit tests for the BigQuery data source adapter. + +Tests the BigQueryDataSource class from connectors/bigquery/adapter.py +with all external dependencies (BigQueryClient, config, parquet_manager) mocked. + +The google-cloud-bigquery package is not installed in test environments, +so we install stub modules in sys.modules before importing the adapter. +""" + +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +# --------------------------------------------------------------------------- +# Stub google.cloud.bigquery before any connector import +# --------------------------------------------------------------------------- +_bq_stub = MagicMock() +sys.modules.setdefault("google", _bq_stub) +sys.modules.setdefault("google.cloud", _bq_stub) +sys.modules.setdefault("google.cloud.bigquery", _bq_stub) + +from src.config import TableConfig # noqa: E402 +from src.data_sync import SyncState # noqa: E402 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def tmp_parquet_dir(tmp_path): + """Provide a temporary directory for Parquet file output.""" + parquet_dir = tmp_path / "parquet" / "test_bucket" + parquet_dir.mkdir(parents=True) + return parquet_dir + + +@pytest.fixture +def mock_config(tmp_parquet_dir): + """Create a mock Config object that returns paths inside tmp_parquet_dir.""" + config = MagicMock() + config.get_parquet_path = MagicMock() + config.get_partition_path = MagicMock() + config.get_metadata_path.return_value = tmp_parquet_dir.parent / "metadata" + return config + + +@pytest.fixture +def mock_bq_client(): + """Create a mock BigQueryClient with sensible defaults.""" + client = MagicMock() + client.metadata_cache = {} + client.get_date_columns.return_value = [] + client.get_pyarrow_schema.return_value = None + return client + + +@pytest.fixture +def sync_state(tmp_path): + """Create a real SyncState backed by a temp JSON file.""" + state_file = tmp_path / "metadata" / "sync_state.json" + state_file.parent.mkdir(parents=True, exist_ok=True) + return SyncState(state_file) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_table_config( + *, + table_id: str = "project.dataset.orders", + name: str = "orders", + primary_key: str = "id", + sync_strategy: str = "full_refresh", + incremental_column: str | None = None, + incremental_window_days: int | None = None, + partition_by: str | None = None, + partition_granularity: str | None = None, + max_history_days: int | None = None, +) -> TableConfig: + """Helper to build a TableConfig with safe defaults.""" + return TableConfig( + id=table_id, + name=name, + description="Test table", + primary_key=primary_key, + sync_strategy=sync_strategy, + incremental_column=incremental_column, + incremental_window_days=incremental_window_days, + partition_by=partition_by, + partition_granularity=partition_granularity, + max_history_days=max_history_days, + ) + + +def _sample_arrow_table(ids: list[int], names: list[str]) -> pa.Table: + """Build a small PyArrow Table with id and name columns.""" + return pa.table({"id": ids, "name": names}) + + +def _create_adapter(mock_config, mock_bq_client): + """Instantiate BigQueryDataSource with mocked dependencies. + + Patches get_config and create_bq_client so that no real GCP + credentials or network access are needed. + """ + with patch("connectors.bigquery.adapter.get_config", return_value=mock_config), \ + patch("connectors.bigquery.adapter.create_bq_client", return_value=mock_bq_client): + from connectors.bigquery.adapter import BigQueryDataSource + adapter = BigQueryDataSource() + return adapter + + +# --------------------------------------------------------------------------- +# 1. full_refresh writes valid Parquet file from Arrow table +# --------------------------------------------------------------------------- + +class TestFullRefresh: + + def test_writes_valid_parquet(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): + """full_refresh should write a valid, readable Parquet file.""" + table_config = _make_table_config(sync_strategy="full_refresh") + parquet_path = tmp_parquet_dir / "orders.parquet" + mock_config.get_parquet_path.return_value = parquet_path + + arrow_data = _sample_arrow_table([1, 2, 3], ["Alice", "Bob", "Charlie"]) + mock_bq_client.read_table.return_value = arrow_data + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.sync_table(table_config, sync_state) + + assert result["success"] is True + assert result["rows"] == 3 + assert parquet_path.exists() + + # Verify Parquet content matches source data + read_back = pq.read_table(parquet_path) + assert read_back.num_rows == 3 + assert read_back.column_names == ["id", "name"] + + def test_applies_date_columns(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): + """full_refresh should call convert_date_columns_to_date32 when date columns exist.""" + table_config = _make_table_config() + parquet_path = tmp_parquet_dir / "orders.parquet" + mock_config.get_parquet_path.return_value = parquet_path + + arrow_data = _sample_arrow_table([1], ["Alice"]) + mock_bq_client.read_table.return_value = arrow_data + mock_bq_client.get_date_columns.return_value = ["created_at"] + + with patch("connectors.bigquery.adapter.convert_date_columns_to_date32", return_value=arrow_data) as mock_conv: + adapter = _create_adapter(mock_config, mock_bq_client) + adapter.sync_table(table_config, sync_state) + mock_conv.assert_called_once_with(arrow_data, ["created_at"]) + + def test_applies_pyarrow_schema(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): + """full_refresh should call apply_schema_to_table when schema is available.""" + table_config = _make_table_config() + parquet_path = tmp_parquet_dir / "orders.parquet" + mock_config.get_parquet_path.return_value = parquet_path + + arrow_data = _sample_arrow_table([1], ["Alice"]) + mock_bq_client.read_table.return_value = arrow_data + schema = pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.string())]) + mock_bq_client.get_pyarrow_schema.return_value = schema + + with patch("connectors.bigquery.adapter.apply_schema_to_table", return_value=arrow_data) as mock_apply: + adapter = _create_adapter(mock_config, mock_bq_client) + adapter.sync_table(table_config, sync_state) + mock_apply.assert_called_once_with(arrow_data, schema) + + +# --------------------------------------------------------------------------- +# 2. incremental_column_sync merges correctly (dedup on PK, new data wins) +# --------------------------------------------------------------------------- + +class TestIncrementalColumnSync: + + def test_merge_dedup_new_data_wins(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): + """Incremental sync should overwrite existing rows when PK matches (new data wins).""" + table_config = _make_table_config( + sync_strategy="incremental", + incremental_column="updated_at", + incremental_window_days=7, + ) + parquet_path = tmp_parquet_dir / "orders.parquet" + mock_config.get_parquet_path.return_value = parquet_path + + # Write existing data + existing = _sample_arrow_table([1, 2], ["Alice", "Bob"]) + pq.write_table(existing, parquet_path) + + # Simulate a previous sync timestamp + sync_state.update_sync( + table_id=table_config.id, + table_name=table_config.name, + strategy="incremental", + rows=2, + file_size_bytes=100, + ) + + # New data: id=2 gets updated name, id=3 is new + new_data = _sample_arrow_table([2, 3], ["Bob_Updated", "Charlie"]) + mock_bq_client.read_table_incremental.return_value = new_data + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.sync_table(table_config, sync_state) + + assert result["success"] is True + assert result["rows"] == 3 # Alice + Bob_Updated + Charlie + + read_back = pq.read_table(parquet_path) + df = read_back.to_pandas() + assert set(df["id"].tolist()) == {1, 2, 3} + # id=2 should have the updated name + bob_row = df[df["id"] == 2].iloc[0] + assert bob_row["name"] == "Bob_Updated" + + +# --------------------------------------------------------------------------- +# 3. incremental_column_sync with no new data returns existing file info +# --------------------------------------------------------------------------- + +class TestIncrementalNoNewData: + + def test_returns_existing_file_info(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): + """When there is no new data, sync returns stats from the existing Parquet file.""" + table_config = _make_table_config( + sync_strategy="incremental", + incremental_column="updated_at", + incremental_window_days=7, + ) + parquet_path = tmp_parquet_dir / "orders.parquet" + mock_config.get_parquet_path.return_value = parquet_path + + # Write existing data + existing = _sample_arrow_table([1, 2, 3], ["A", "B", "C"]) + pq.write_table(existing, parquet_path) + + # Mark a previous sync + sync_state.update_sync( + table_id=table_config.id, + table_name=table_config.name, + strategy="incremental", + rows=3, + file_size_bytes=100, + ) + + # No new rows + empty_table = pa.table({ + "id": pa.array([], type=pa.int64()), + "name": pa.array([], type=pa.string()), + }) + mock_bq_client.read_table_incremental.return_value = empty_table + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.sync_table(table_config, sync_state) + + assert result["success"] is True + assert result["rows"] == 3 # existing row count preserved + + +# --------------------------------------------------------------------------- +# 4. partitioned_sync creates partition files +# --------------------------------------------------------------------------- + +class TestPartitionedSync: + + def test_creates_partition_files(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): + """Partitioned sync should create separate Parquet files per partition key.""" + import pandas as pd + + table_config = _make_table_config( + sync_strategy="incremental", + incremental_column="created_at", + partition_by="created_at", + partition_granularity="month", + incremental_window_days=7, + ) + + # For partitioned tables, parquet_path is a directory + partition_dir = tmp_parquet_dir / "orders" + partition_dir.mkdir(parents=True, exist_ok=True) + mock_config.get_parquet_path.return_value = partition_dir + + # Configure partition paths + def _partition_path(tc, key): + return partition_dir / f"{key}.parquet" + mock_config.get_partition_path.side_effect = _partition_path + + # Build arrow table with timestamps in two months + ts_jan = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")] + ts_feb = [pd.Timestamp("2026-02-20 14:00:00", tz="UTC")] + arrow_data = pa.table({ + "id": [1, 2], + "name": ["Jan_Order", "Feb_Order"], + "created_at": pa.array(ts_jan + ts_feb, type=pa.timestamp("us", tz="UTC")), + }) + mock_bq_client.read_table.return_value = arrow_data + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.sync_table(table_config, sync_state) + + assert result["success"] is True + + # Should have created two partition files + partition_files = list(partition_dir.glob("*.parquet")) + assert len(partition_files) == 2 + + partition_names = sorted(f.stem for f in partition_files) + assert "2026_01" in partition_names + assert "2026_02" in partition_names + + +# --------------------------------------------------------------------------- +# 5. discover_tables delegates to BigQueryClient.discover_all_tables() +# --------------------------------------------------------------------------- + +class TestDiscoverTables: + + def test_delegates_to_client(self, mock_config, mock_bq_client): + """discover_tables should forward the call to BigQueryClient.discover_all_tables.""" + expected = [{"id": "proj.ds.t1", "name": "t1", "columns": ["a", "b"]}] + mock_bq_client.discover_all_tables.return_value = expected + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.discover_tables() + + mock_bq_client.discover_all_tables.assert_called_once() + assert result == expected + + +# --------------------------------------------------------------------------- +# 6. get_source_name returns "Google BigQuery" +# --------------------------------------------------------------------------- + +class TestGetSourceName: + + def test_returns_google_bigquery(self, mock_config, mock_bq_client): + adapter = _create_adapter(mock_config, mock_bq_client) + assert adapter.get_source_name() == "Google BigQuery" + + +# --------------------------------------------------------------------------- +# 7. get_column_metadata returns correct format +# --------------------------------------------------------------------------- + +class TestGetColumnMetadata: + + def test_returns_correct_format(self, mock_config, mock_bq_client): + """get_column_metadata should transform BQ raw metadata into {columns: ...} format.""" + mock_bq_client.get_table_metadata.return_value = { + "column_types": {"id": "INT64", "name": "STRING", "email": "STRING"}, + "column_descriptions": {"id": "Primary key", "email": "User email address"}, + } + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.get_column_metadata("project.dataset.users") + + assert "columns" in result + assert result["columns"]["id"] == {"source_type": "INT64", "description": "Primary key"} + assert result["columns"]["name"] == {"source_type": "STRING"} + assert result["columns"]["email"] == { + "source_type": "STRING", + "description": "User email address", + } + + def test_returns_none_when_no_column_types(self, mock_config, mock_bq_client): + """get_column_metadata should return None if the metadata has no column types.""" + mock_bq_client.get_table_metadata.return_value = { + "column_types": {}, + "column_descriptions": {}, + } + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.get_column_metadata("project.dataset.users") + + assert result is None + + +# --------------------------------------------------------------------------- +# 8. Error handling (query failure -> {success: False, error: ...}) +# --------------------------------------------------------------------------- + +class TestErrorHandling: + + def test_query_failure_returns_error_dict( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """When BigQuery query raises, sync_table returns {success: False, error: ...}.""" + table_config = _make_table_config() + mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" + mock_bq_client.read_table.side_effect = RuntimeError("BigQuery API timeout") + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.sync_table(table_config, sync_state) + + assert result["success"] is False + assert "BigQuery API timeout" in result["error"] + assert result["strategy"] == "full_refresh" + + def test_unknown_strategy_returns_error(self, mock_config, mock_bq_client, sync_state): + """Unknown sync_strategy in internal dispatch should produce an error result.""" + # We cannot create a TableConfig with an invalid strategy via constructor + # (it validates). Instead, we mutate it after creation. + table_config = _make_table_config() + table_config.sync_strategy = "magic_sync" + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.sync_table(table_config, sync_state) + + assert result["success"] is False + assert "Unknown sync strategy" in result["error"] + + +# --------------------------------------------------------------------------- +# 9. incremental_column config is used in WHERE clause +# --------------------------------------------------------------------------- + +class TestIncrementalColumnUsedInWhere: + + def test_incremental_column_passed_to_client( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """The configured incremental_column should be forwarded to read_table_incremental.""" + table_config = _make_table_config( + sync_strategy="incremental", + incremental_column="modified_at", + incremental_window_days=14, + ) + parquet_path = tmp_parquet_dir / "orders.parquet" + mock_config.get_parquet_path.return_value = parquet_path + + # Write existing data so we enter the incremental path + existing = _sample_arrow_table([1], ["Alice"]) + pq.write_table(existing, parquet_path) + + sync_state.update_sync( + table_id=table_config.id, + table_name=table_config.name, + strategy="incremental", + rows=1, + file_size_bytes=100, + ) + + # Return empty to keep the test simple + empty = pa.table({ + "id": pa.array([], type=pa.int64()), + "name": pa.array([], type=pa.string()), + }) + mock_bq_client.read_table_incremental.return_value = empty + + adapter = _create_adapter(mock_config, mock_bq_client) + adapter.sync_table(table_config, sync_state) + + call_kwargs = mock_bq_client.read_table_incremental.call_args + assert call_kwargs.kwargs["incremental_column"] == "modified_at" + assert call_kwargs.kwargs["table_id"] == "project.dataset.orders" + # since_value should be an ISO string + assert "since_value" in call_kwargs.kwargs + + +# --------------------------------------------------------------------------- +# 10. First sync without existing file downloads all data +# --------------------------------------------------------------------------- + +class TestFirstSyncDownloadsAll: + + def test_first_sync_reads_full_table( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """On first incremental sync (no existing file), adapter should read all data.""" + table_config = _make_table_config( + sync_strategy="incremental", + incremental_column="updated_at", + incremental_window_days=7, + ) + parquet_path = tmp_parquet_dir / "orders.parquet" + mock_config.get_parquet_path.return_value = parquet_path + + # No previous sync, no existing file + arrow_data = _sample_arrow_table([1, 2, 3], ["A", "B", "C"]) + mock_bq_client.read_table.return_value = arrow_data + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.sync_table(table_config, sync_state) + + assert result["success"] is True + assert result["rows"] == 3 + # Should call read_table (full), not read_table_incremental + mock_bq_client.read_table.assert_called_once_with(table_config.id) + mock_bq_client.read_table_incremental.assert_not_called() + + def test_first_sync_with_max_history_days( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """First sync with max_history_days should use read_table_incremental.""" + table_config = _make_table_config( + sync_strategy="incremental", + incremental_column="updated_at", + incremental_window_days=7, + max_history_days=90, + ) + parquet_path = tmp_parquet_dir / "orders.parquet" + mock_config.get_parquet_path.return_value = parquet_path + + arrow_data = _sample_arrow_table([1, 2], ["A", "B"]) + mock_bq_client.read_table_incremental.return_value = arrow_data + + adapter = _create_adapter(mock_config, mock_bq_client) + result = adapter.sync_table(table_config, sync_state) + + assert result["success"] is True + # Should use read_table_incremental (not read_table) because max_history_days is set + mock_bq_client.read_table_incremental.assert_called_once() + call_kwargs = mock_bq_client.read_table_incremental.call_args.kwargs + assert call_kwargs["incremental_column"] == "updated_at" + mock_bq_client.read_table.assert_not_called() + + +# --------------------------------------------------------------------------- +# 11. sync_table dispatches to correct strategy based on sync_strategy +# --------------------------------------------------------------------------- + +class TestSyncTableDispatch: + + def test_dispatches_full_refresh( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """sync_strategy='full_refresh' should call _full_refresh.""" + table_config = _make_table_config(sync_strategy="full_refresh") + mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" + mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"]) + + adapter = _create_adapter(mock_config, mock_bq_client) + + with patch.object(adapter, "_full_refresh", wraps=adapter._full_refresh) as spy: + adapter.sync_table(table_config, sync_state) + spy.assert_called_once_with(table_config) + + def test_dispatches_incremental( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """sync_strategy='incremental' should call _incremental_sync.""" + table_config = _make_table_config( + sync_strategy="incremental", + incremental_column="updated_at", + incremental_window_days=7, + ) + mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" + mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"]) + + adapter = _create_adapter(mock_config, mock_bq_client) + + with patch.object(adapter, "_incremental_sync", wraps=adapter._incremental_sync) as spy: + adapter.sync_table(table_config, sync_state) + spy.assert_called_once_with(table_config, sync_state) + + def test_dispatches_partitioned( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """sync_strategy='incremental' with partition_by should call _partitioned_sync.""" + import pandas as pd + + table_config = _make_table_config( + sync_strategy="incremental", + incremental_column="created_at", + partition_by="created_at", + partition_granularity="month", + incremental_window_days=7, + ) + partition_dir = tmp_parquet_dir / "orders" + partition_dir.mkdir(parents=True, exist_ok=True) + mock_config.get_parquet_path.return_value = partition_dir + + def _partition_path(tc, key): + return partition_dir / f"{key}.parquet" + mock_config.get_partition_path.side_effect = _partition_path + + ts = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")] + arrow_data = pa.table({ + "id": [1], + "name": ["A"], + "created_at": pa.array(ts, type=pa.timestamp("us", tz="UTC")), + }) + mock_bq_client.read_table.return_value = arrow_data + + adapter = _create_adapter(mock_config, mock_bq_client) + + with patch.object(adapter, "_partitioned_sync", wraps=adapter._partitioned_sync) as spy: + adapter.sync_table(table_config, sync_state) + spy.assert_called_once() + + def test_incremental_without_column_falls_back_to_full_refresh( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """incremental strategy without incremental_column or partition_by falls back to full_refresh.""" + table_config = _make_table_config( + sync_strategy="incremental", + incremental_column=None, + partition_by=None, + incremental_window_days=7, + ) + mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" + mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"]) + + adapter = _create_adapter(mock_config, mock_bq_client) + + with patch.object(adapter, "_full_refresh", wraps=adapter._full_refresh) as spy: + result = adapter.sync_table(table_config, sync_state) + spy.assert_called_once() + assert result["success"] is True + + +# --------------------------------------------------------------------------- +# 12. _merge_arrow_tables deduplicates correctly +# --------------------------------------------------------------------------- + +class TestMergeArrowTables: + + def test_dedup_on_single_pk(self, mock_config, mock_bq_client): + """Merge should deduplicate on single primary key column, new data wins.""" + adapter = _create_adapter(mock_config, mock_bq_client) + + existing = pa.table({"id": [1, 2, 3], "val": ["a", "b", "c"]}) + new_data = pa.table({"id": [2, 4], "val": ["B_new", "d"]}) + + merged = adapter._merge_arrow_tables(existing, new_data, primary_key=["id"]) + df = merged.to_pandas().sort_values("id").reset_index(drop=True) + + assert list(df["id"]) == [1, 2, 3, 4] + assert list(df["val"]) == ["a", "B_new", "c", "d"] + + def test_dedup_on_composite_pk(self, mock_config, mock_bq_client): + """Merge should deduplicate on composite primary key.""" + adapter = _create_adapter(mock_config, mock_bq_client) + + existing = pa.table({ + "pk1": [1, 1, 2], + "pk2": ["a", "b", "a"], + "val": ["old_1a", "old_1b", "old_2a"], + }) + new_data = pa.table({ + "pk1": [1, 2], + "pk2": ["a", "a"], + "val": ["new_1a", "new_2a"], + }) + + merged = adapter._merge_arrow_tables(existing, new_data, primary_key=["pk1", "pk2"]) + df = merged.to_pandas().sort_values(["pk1", "pk2"]).reset_index(drop=True) + + assert len(df) == 3 + # (1, a) should be updated + row_1a = df[(df["pk1"] == 1) & (df["pk2"] == "a")].iloc[0] + assert row_1a["val"] == "new_1a" + # (1, b) should be preserved + row_1b = df[(df["pk1"] == 1) & (df["pk2"] == "b")].iloc[0] + assert row_1b["val"] == "old_1b" + # (2, a) should be updated + row_2a = df[(df["pk1"] == 2) & (df["pk2"] == "a")].iloc[0] + assert row_2a["val"] == "new_2a" + + def test_merge_with_empty_new_data(self, mock_config, mock_bq_client): + """Merging with empty new data should return existing data unchanged.""" + adapter = _create_adapter(mock_config, mock_bq_client) + + existing = pa.table({"id": [1, 2], "val": ["a", "b"]}) + empty = pa.table({ + "id": pa.array([], type=pa.int64()), + "val": pa.array([], type=pa.string()), + }) + + merged = adapter._merge_arrow_tables(existing, empty, primary_key=["id"]) + assert merged.num_rows == 2 + + def test_merge_with_empty_existing(self, mock_config, mock_bq_client): + """Merging with empty existing data should return new data.""" + adapter = _create_adapter(mock_config, mock_bq_client) + + empty = pa.table({ + "id": pa.array([], type=pa.int64()), + "val": pa.array([], type=pa.string()), + }) + new_data = pa.table({"id": [1, 2], "val": ["a", "b"]}) + + merged = adapter._merge_arrow_tables(empty, new_data, primary_key=["id"]) + assert merged.num_rows == 2 + + +# --------------------------------------------------------------------------- +# Additional edge cases +# --------------------------------------------------------------------------- + +class TestMetadataCacheClearing: + + def test_clears_metadata_cache_before_sync( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """sync_table should clear the BQ metadata cache entry for the table being synced.""" + table_config = _make_table_config() + parquet_path = tmp_parquet_dir / "orders.parquet" + mock_config.get_parquet_path.return_value = parquet_path + mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"]) + + # Pre-populate cache + mock_bq_client.metadata_cache[table_config.id] = {"some": "cached_data"} + + adapter = _create_adapter(mock_config, mock_bq_client) + adapter.sync_table(table_config, sync_state) + + assert table_config.id not in mock_bq_client.metadata_cache + + +class TestSyncStateUpdate: + + def test_sync_state_updated_after_success( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """After successful sync, the sync state should be updated with correct values.""" + table_config = _make_table_config() + parquet_path = tmp_parquet_dir / "orders.parquet" + mock_config.get_parquet_path.return_value = parquet_path + mock_bq_client.read_table.return_value = _sample_arrow_table([1, 2], ["A", "B"]) + + adapter = _create_adapter(mock_config, mock_bq_client) + adapter.sync_table(table_config, sync_state) + + state = sync_state.get_table_state(table_config.id) + assert state["rows"] == 2 + assert state["strategy"] == "full_refresh" + assert state["table_name"] == "orders" + assert "last_sync" in state + + def test_sync_state_not_updated_on_failure( + self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state + ): + """On sync failure, the sync state should NOT be updated.""" + table_config = _make_table_config() + mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" + mock_bq_client.read_table.side_effect = RuntimeError("boom") + + adapter = _create_adapter(mock_config, mock_bq_client) + adapter.sync_table(table_config, sync_state) + + state = sync_state.get_table_state(table_config.id) + assert state == {} + + +class TestCreateDataSourceFactory: + + def test_factory_returns_adapter_instance(self, mock_config, mock_bq_client): + """create_data_source() factory should return a BigQueryDataSource instance.""" + with patch("connectors.bigquery.adapter.get_config", return_value=mock_config), \ + patch("connectors.bigquery.adapter.create_bq_client", return_value=mock_bq_client): + from connectors.bigquery.adapter import create_data_source, BigQueryDataSource + instance = create_data_source() + assert isinstance(instance, BigQueryDataSource) diff --git a/tests/test_bigquery_client.py b/tests/test_bigquery_client.py new file mode 100644 index 0000000..d499eab --- /dev/null +++ b/tests/test_bigquery_client.py @@ -0,0 +1,870 @@ +"""Tests for the BigQuery client connector. + +All external dependencies (google.cloud.bigquery, src.config) are mocked. +Tests cover initialization, metadata caching, schema building, query methods, +and connection testing. +""" + +import json +import sys +from datetime import datetime, timedelta +from pathlib import Path +from unittest.mock import MagicMock, mock_open, patch + +import pyarrow as pa +import pytest + +# Pre-populate sys.modules with a mock google.cloud.bigquery if not installed, +# so the client module can be imported without the real SDK. +_bq_mock_installed = False +try: + from google.cloud import bigquery as _bq_test # noqa: F401 +except ImportError: + _bq_mock_installed = True + _mock_bigquery = MagicMock() + # Expose commonly used classes as MagicMock so the client module + # can reference bigquery.Client, bigquery.QueryJobConfig, etc. + sys.modules.setdefault("google", MagicMock()) + sys.modules.setdefault("google.cloud", MagicMock()) + sys.modules.setdefault("google.cloud.bigquery", _mock_bigquery) + +from connectors.bigquery.client import ( + BIGQUERY_TO_PYARROW_TYPES, + BigQueryClient, + create_client, +) + +# Import the real or mock bigquery reference used in the client module +from google.cloud import bigquery + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_bq_field(name: str, field_type: str, description: str = None): + """Create a mock BigQuery SchemaField.""" + field = MagicMock() + field.name = name + field.field_type = field_type + field.description = description + return field + + +def _make_table_ref( + table_id: str = "my-project.my_dataset.my_table", + schema=None, + num_rows: int = 1000, + num_bytes: int = 50000, + created: datetime = None, + modified: datetime = None, + time_partitioning=None, +): + """Create a mock BigQuery Table reference object.""" + table_ref = MagicMock() + table_ref.table_id = table_id.split(".")[-1] + table_ref.dataset_id = table_id.split(".")[1] if "." in table_id else "dataset" + table_ref.project = table_id.split(".")[0] if "." in table_id else "project" + table_ref.schema = schema or [] + table_ref.num_rows = num_rows + table_ref.num_bytes = num_bytes + table_ref.created = created or datetime(2025, 1, 1, 12, 0, 0) + table_ref.modified = modified or datetime(2025, 6, 1, 12, 0, 0) + table_ref.time_partitioning = time_partitioning + return table_ref + + +@pytest.fixture +def mock_config(tmp_path): + """Mock get_config() to return a config with metadata path in tmp_path.""" + config = MagicMock() + metadata_dir = tmp_path / "metadata" + metadata_dir.mkdir(parents=True, exist_ok=True) + config.get_metadata_path.return_value = metadata_dir + return config + + +@pytest.fixture +def mock_bq_client(): + """Create a mock BigQuery Client.""" + return MagicMock() + + +@pytest.fixture +def client(mock_config, mock_bq_client): + """Create a BigQueryClient instance with mocked dependencies.""" + with ( + patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq_client), + patch("connectors.bigquery.client.get_config", return_value=mock_config), + patch.dict("os.environ", {"BIGQUERY_PROJECT": "test-project"}), + ): + bq_client = BigQueryClient() + return bq_client + + +# --------------------------------------------------------------------------- +# 1. Init validates BIGQUERY_PROJECT env var +# --------------------------------------------------------------------------- + +class TestInit: + def test_raises_value_error_when_project_not_set(self, mock_config): + """Init raises ValueError if project_id is None and env var is missing.""" + with ( + patch("connectors.bigquery.client.bigquery.Client"), + patch("connectors.bigquery.client.get_config", return_value=mock_config), + patch.dict("os.environ", {}, clear=True), + ): + with pytest.raises(ValueError, match="BigQuery project ID not set"): + BigQueryClient() + + def test_raises_value_error_when_project_empty_string(self, mock_config): + """Init raises ValueError if BIGQUERY_PROJECT is set to empty string.""" + with ( + patch("connectors.bigquery.client.bigquery.Client"), + patch("connectors.bigquery.client.get_config", return_value=mock_config), + patch.dict("os.environ", {"BIGQUERY_PROJECT": ""}, clear=True), + ): + with pytest.raises(ValueError, match="BigQuery project ID not set"): + BigQueryClient() + + # ------------------------------------------------------------------- + # 2. Init creates client with correct project_id + # ------------------------------------------------------------------- + + def test_creates_client_with_env_project_id(self, mock_config): + """Client uses BIGQUERY_PROJECT from environment.""" + mock_bq = MagicMock() + with ( + patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq) as bq_cls, + patch("connectors.bigquery.client.get_config", return_value=mock_config), + patch.dict("os.environ", {"BIGQUERY_PROJECT": "env-project-123"}), + ): + client = BigQueryClient() + bq_cls.assert_called_once_with(project="env-project-123") + assert client.project_id == "env-project-123" + + def test_creates_client_with_explicit_project_id(self, mock_config): + """Explicit project_id argument takes precedence over env var.""" + mock_bq = MagicMock() + with ( + patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq) as bq_cls, + patch("connectors.bigquery.client.get_config", return_value=mock_config), + ): + client = BigQueryClient(project_id="explicit-project") + bq_cls.assert_called_once_with(project="explicit-project") + assert client.project_id == "explicit-project" + + +# --------------------------------------------------------------------------- +# 3. get_table_metadata fetches and caches metadata correctly +# --------------------------------------------------------------------------- + +class TestGetTableMetadata: + def test_fetches_metadata_from_bigquery(self, client, mock_bq_client): + """get_table_metadata calls client.get_table and returns correct dict.""" + table_id = "proj.dataset.orders" + schema = [ + _make_bq_field("order_id", "INTEGER"), + _make_bq_field("customer_name", "STRING", description="Full name"), + _make_bq_field("created_at", "TIMESTAMP"), + ] + table_ref = _make_table_ref( + table_id=table_id, + schema=schema, + num_rows=5000, + num_bytes=120000, + ) + mock_bq_client.get_table.return_value = table_ref + + metadata = client.get_table_metadata(table_id, use_cache=False) + + mock_bq_client.get_table.assert_called_once_with(table_id) + assert metadata["table_id"] == table_id + assert metadata["name"] == "orders" + assert metadata["dataset"] == "dataset" + assert metadata["project"] == "proj" + assert metadata["columns"] == ["order_id", "customer_name", "created_at"] + assert metadata["column_types"]["order_id"] == "INTEGER" + assert metadata["column_types"]["customer_name"] == "STRING" + assert metadata["column_types"]["created_at"] == "TIMESTAMP" + assert metadata["column_descriptions"]["customer_name"] == "Full name" + assert "order_id" not in metadata["column_descriptions"] + assert metadata["row_count"] == 5000 + assert metadata["size_bytes"] == 120000 + assert "_cached_at" in metadata + + def test_caches_metadata_in_memory(self, client, mock_bq_client): + """After first fetch, metadata is stored in the in-memory cache.""" + table_id = "proj.dataset.tbl" + table_ref = _make_table_ref(table_id=table_id) + mock_bq_client.get_table.return_value = table_ref + + client.get_table_metadata(table_id, use_cache=False) + + assert table_id in client.metadata_cache + assert client.metadata_cache[table_id]["table_id"] == table_id + + def test_captures_partitioning_info(self, client, mock_bq_client): + """Partitioning metadata is captured when table is partitioned.""" + table_id = "proj.dataset.events" + partition = MagicMock() + partition.type_ = "DAY" + partition.field = "event_date" + partition.expiration_ms = 7776000000 + + table_ref = _make_table_ref(table_id=table_id, time_partitioning=partition) + mock_bq_client.get_table.return_value = table_ref + + metadata = client.get_table_metadata(table_id, use_cache=False) + + assert metadata["partitioning"] is not None + assert metadata["partitioning"]["type"] == "DAY" + assert metadata["partitioning"]["field"] == "event_date" + assert metadata["partitioning"]["expiration_ms"] == 7776000000 + + def test_no_partitioning_when_absent(self, client, mock_bq_client): + """Partitioning is None when table has no partitioning.""" + table_id = "proj.dataset.simple" + table_ref = _make_table_ref(table_id=table_id, time_partitioning=None) + mock_bq_client.get_table.return_value = table_ref + + metadata = client.get_table_metadata(table_id, use_cache=False) + assert metadata["partitioning"] is None + + # ------------------------------------------------------------------- + # 4. get_table_metadata uses cache when available (within TTL) + # ------------------------------------------------------------------- + + def test_uses_cache_within_ttl(self, client, mock_bq_client): + """When cache is fresh (within TTL), BQ API is not called again.""" + table_id = "proj.dataset.cached_tbl" + now = datetime.now() + client.metadata_cache[table_id] = { + "table_id": table_id, + "columns": ["a", "b"], + "column_types": {"a": "STRING", "b": "INTEGER"}, + "_cached_at": now.isoformat(), + } + + result = client.get_table_metadata(table_id, use_cache=True, cache_ttl_hours=24) + + mock_bq_client.get_table.assert_not_called() + assert result["table_id"] == table_id + assert result["columns"] == ["a", "b"] + + def test_refetches_when_cache_expired(self, client, mock_bq_client): + """When cache is older than TTL, metadata is re-fetched from BQ.""" + table_id = "proj.dataset.stale_tbl" + old_time = (datetime.now() - timedelta(hours=48)).isoformat() + client.metadata_cache[table_id] = { + "table_id": table_id, + "columns": ["old_col"], + "column_types": {"old_col": "STRING"}, + "_cached_at": old_time, + } + + table_ref = _make_table_ref( + table_id=table_id, + schema=[_make_bq_field("new_col", "INTEGER")], + ) + mock_bq_client.get_table.return_value = table_ref + + result = client.get_table_metadata(table_id, use_cache=True, cache_ttl_hours=24) + + mock_bq_client.get_table.assert_called_once_with(table_id) + assert result["columns"] == ["new_col"] + + def test_bypasses_cache_when_use_cache_false(self, client, mock_bq_client): + """When use_cache=False, always fetches from BQ even if cache is fresh.""" + table_id = "proj.dataset.force_fetch" + client.metadata_cache[table_id] = { + "table_id": table_id, + "columns": ["cached"], + "column_types": {"cached": "STRING"}, + "_cached_at": datetime.now().isoformat(), + } + + table_ref = _make_table_ref( + table_id=table_id, + schema=[_make_bq_field("fresh", "INTEGER")], + ) + mock_bq_client.get_table.return_value = table_ref + + result = client.get_table_metadata(table_id, use_cache=False) + mock_bq_client.get_table.assert_called_once() + assert result["columns"] == ["fresh"] + + +# --------------------------------------------------------------------------- +# 5. get_pyarrow_schema builds correct schema from BQ types +# --------------------------------------------------------------------------- + +class TestGetPyarrowSchema: + def test_builds_correct_schema(self, client, mock_bq_client): + """Schema maps BQ types to correct PyArrow types.""" + table_id = "proj.dataset.typed_tbl" + schema = [ + _make_bq_field("id", "INT64"), + _make_bq_field("name", "STRING"), + _make_bq_field("price", "FLOAT64"), + _make_bq_field("active", "BOOLEAN"), + _make_bq_field("created", "DATE"), + _make_bq_field("updated_at", "TIMESTAMP"), + ] + table_ref = _make_table_ref(table_id=table_id, schema=schema) + mock_bq_client.get_table.return_value = table_ref + + pa_schema = client.get_pyarrow_schema(table_id) + + assert pa_schema is not None + assert pa_schema.field("id").type == pa.int64() + assert pa_schema.field("name").type == pa.string() + assert pa_schema.field("price").type == pa.float64() + assert pa_schema.field("active").type == pa.bool_() + assert pa_schema.field("created").type == pa.date32() + assert pa_schema.field("updated_at").type == pa.timestamp("us", tz="UTC") + + def test_returns_none_when_no_column_types(self, client): + """Returns None when metadata has no column_types.""" + table_id = "proj.dataset.empty_schema" + client.metadata_cache[table_id] = { + "table_id": table_id, + "columns": [], + "column_types": {}, + "_cached_at": datetime.now().isoformat(), + } + + result = client.get_pyarrow_schema(table_id) + assert result is None + + def test_unknown_type_falls_back_to_string(self, client, mock_bq_client): + """Unknown BQ types default to pa.string() in the schema.""" + table_id = "proj.dataset.exotic_types" + schema = [_make_bq_field("exotic_col", "SOME_UNKNOWN_TYPE")] + table_ref = _make_table_ref(table_id=table_id, schema=schema) + mock_bq_client.get_table.return_value = table_ref + + pa_schema = client.get_pyarrow_schema(table_id) + assert pa_schema.field("exotic_col").type == pa.string() + + +# --------------------------------------------------------------------------- +# 6. get_date_columns returns only DATE columns +# --------------------------------------------------------------------------- + +class TestGetDateColumns: + def test_returns_only_date_columns(self, client, mock_bq_client): + """Only columns with BQ type DATE are returned.""" + table_id = "proj.dataset.mixed_dates" + schema = [ + _make_bq_field("event_date", "DATE"), + _make_bq_field("created_at", "TIMESTAMP"), + _make_bq_field("name", "STRING"), + _make_bq_field("birth_date", "DATE"), + _make_bq_field("updated_ts", "DATETIME"), + ] + table_ref = _make_table_ref(table_id=table_id, schema=schema) + mock_bq_client.get_table.return_value = table_ref + + date_cols = client.get_date_columns(table_id) + assert sorted(date_cols) == ["birth_date", "event_date"] + + def test_returns_empty_when_no_date_columns(self, client, mock_bq_client): + """Returns empty list when no DATE columns exist.""" + table_id = "proj.dataset.no_dates" + schema = [ + _make_bq_field("id", "INTEGER"), + _make_bq_field("ts", "TIMESTAMP"), + ] + table_ref = _make_table_ref(table_id=table_id, schema=schema) + mock_bq_client.get_table.return_value = table_ref + + date_cols = client.get_date_columns(table_id) + assert date_cols == [] + + +# --------------------------------------------------------------------------- +# 7. query_to_arrow executes SQL and returns PyArrow table +# --------------------------------------------------------------------------- + +class TestQueryToArrow: + def test_executes_query_and_returns_arrow(self, client, mock_bq_client): + """query_to_arrow passes SQL to BQ and returns the arrow result.""" + expected_table = pa.table({"col1": [1, 2, 3]}) + mock_job = MagicMock() + mock_job.to_arrow.return_value = expected_table + mock_bq_client.query.return_value = mock_job + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = MagicMock(query_parameters=None) + client.client = mock_bq_client + + result = client.query_to_arrow("SELECT * FROM `proj.dataset.tbl`") + + mock_bq_client.query.assert_called_once() + call_args = mock_bq_client.query.call_args + assert call_args[0][0] == "SELECT * FROM `proj.dataset.tbl`" + assert result.equals(expected_table) + + def test_passes_query_parameters(self, client, mock_bq_client): + """query_to_arrow forwards BQ query parameters in job config.""" + expected_table = pa.table({"col1": [10]}) + mock_job = MagicMock() + mock_job.to_arrow.return_value = expected_table + mock_bq_client.query.return_value = mock_job + + mock_job_config = MagicMock() + params = [MagicMock()] # Mock ScalarQueryParameter + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = mock_job_config + client.client = mock_bq_client + + client.query_to_arrow("SELECT 1 WHERE x > @val", params=params) + + # Verify params were set on the job config + assert mock_job_config.query_parameters == params + + def test_no_params_does_not_set_query_parameters(self, client, mock_bq_client): + """When no params given, query_parameters is not set on job config.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"x": [1]}) + mock_bq_client.query.return_value = mock_job + + mock_job_config = MagicMock(spec=[]) + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = mock_job_config + client.client = mock_bq_client + + client.query_to_arrow("SELECT 1") + + # query_parameters should not have been set + assert not hasattr(mock_job_config, "query_parameters") or not getattr( + mock_job_config, "query_parameters", None + ) + + +# --------------------------------------------------------------------------- +# 8. read_table builds correct SQL query +# --------------------------------------------------------------------------- + +class TestReadTable: + def test_full_table_select_all(self, client, mock_bq_client): + """read_table with no columns or filter generates SELECT *.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"a": [1]}) + mock_bq_client.query.return_value = mock_job + + client.read_table("proj.dataset.tbl") + + sql = mock_bq_client.query.call_args[0][0] + assert "SELECT *" in sql + assert "`proj.dataset.tbl`" in sql + assert "WHERE" not in sql + + def test_select_specific_columns(self, client, mock_bq_client): + """read_table with columns list generates SELECT with backtick-quoted names.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"a": [1]}) + mock_bq_client.query.return_value = mock_job + + client.read_table("proj.dataset.tbl", columns=["col_a", "col_b"]) + + sql = mock_bq_client.query.call_args[0][0] + assert "`col_a`" in sql + assert "`col_b`" in sql + assert "*" not in sql + + def test_with_row_filter(self, client, mock_bq_client): + """read_table with row_filter appends WHERE clause.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"a": [1]}) + mock_bq_client.query.return_value = mock_job + + client.read_table("proj.dataset.tbl", row_filter="status = 'active'") + + sql = mock_bq_client.query.call_args[0][0] + assert "WHERE status = 'active'" in sql + + def test_columns_and_filter_combined(self, client, mock_bq_client): + """read_table with both columns and filter generates correct SQL.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"x": [1]}) + mock_bq_client.query.return_value = mock_job + + client.read_table( + "proj.dataset.tbl", + columns=["id", "name"], + row_filter="id > 100", + ) + + sql = mock_bq_client.query.call_args[0][0] + assert "`id`, `name`" in sql + assert "WHERE id > 100" in sql + assert "`proj.dataset.tbl`" in sql + + +# --------------------------------------------------------------------------- +# 9. read_table_incremental builds parameterized WHERE clause +# --------------------------------------------------------------------------- + +class TestReadTableIncremental: + def test_incremental_query_structure(self, client, mock_bq_client): + """read_table_incremental builds WHERE col > @since_value with params.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"a": [1]}) + mock_bq_client.query.return_value = mock_job + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = MagicMock() + mock_param = MagicMock() + mock_bq_module.ScalarQueryParameter.return_value = mock_param + # Re-assign the client's bq client (the fixture already set it up) + client.client = mock_bq_client + + client.read_table_incremental( + table_id="proj.dataset.events", + incremental_column="updated_at", + since_value="2025-01-01T00:00:00Z", + ) + + sql = mock_bq_client.query.call_args[0][0] + assert "SELECT *" in sql + assert "`proj.dataset.events`" in sql + assert "`updated_at` > @since_value" in sql + + # Verify ScalarQueryParameter was constructed correctly + mock_bq_module.ScalarQueryParameter.assert_called_once_with( + "since_value", "TIMESTAMP", "2025-01-01T00:00:00Z" + ) + + def test_incremental_with_columns(self, client, mock_bq_client): + """read_table_incremental with columns list selects specific columns.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"a": [1]}) + mock_bq_client.query.return_value = mock_job + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = MagicMock() + mock_bq_module.ScalarQueryParameter.return_value = MagicMock() + client.client = mock_bq_client + + client.read_table_incremental( + table_id="proj.dataset.events", + incremental_column="updated_at", + since_value="2025-01-01T00:00:00Z", + columns=["id", "name"], + ) + + sql = mock_bq_client.query.call_args[0][0] + assert "`id`, `name`" in sql + assert "*" not in sql + + +# --------------------------------------------------------------------------- +# 10. read_table_partitioned builds correct range query +# --------------------------------------------------------------------------- + +class TestReadTablePartitioned: + def test_partitioned_start_only(self, client, mock_bq_client): + """With only start, generates >= @start_value without end clause.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"a": [1]}) + mock_bq_client.query.return_value = mock_job + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = MagicMock() + mock_bq_module.ScalarQueryParameter.return_value = MagicMock() + client.client = mock_bq_client + + client.read_table_partitioned( + table_id="proj.dataset.events", + partition_column="event_date", + start="2025-01-01", + ) + + sql = mock_bq_client.query.call_args[0][0] + assert "`event_date` >= @start_value" in sql + assert "@end_value" not in sql + + # Only start_value parameter created + assert mock_bq_module.ScalarQueryParameter.call_count == 1 + mock_bq_module.ScalarQueryParameter.assert_called_with( + "start_value", "TIMESTAMP", "2025-01-01" + ) + + def test_partitioned_start_and_end(self, client, mock_bq_client): + """With start and end, generates >= @start_value AND < @end_value.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"a": [1]}) + mock_bq_client.query.return_value = mock_job + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = MagicMock() + mock_bq_module.ScalarQueryParameter.return_value = MagicMock() + client.client = mock_bq_client + + client.read_table_partitioned( + table_id="proj.dataset.events", + partition_column="event_date", + start="2025-01-01", + end="2025-06-01", + ) + + sql = mock_bq_client.query.call_args[0][0] + assert "`event_date` >= @start_value" in sql + assert "`event_date` < @end_value" in sql + + # Both start_value and end_value parameters created + assert mock_bq_module.ScalarQueryParameter.call_count == 2 + calls = mock_bq_module.ScalarQueryParameter.call_args_list + assert calls[0].args == ("start_value", "TIMESTAMP", "2025-01-01") + assert calls[1].args == ("end_value", "TIMESTAMP", "2025-06-01") + + def test_partitioned_with_columns(self, client, mock_bq_client): + """read_table_partitioned with columns selects specific columns.""" + mock_job = MagicMock() + mock_job.to_arrow.return_value = pa.table({"a": [1]}) + mock_bq_client.query.return_value = mock_job + + with patch("connectors.bigquery.client.bigquery") as mock_bq_module: + mock_bq_module.QueryJobConfig.return_value = MagicMock() + mock_bq_module.ScalarQueryParameter.return_value = MagicMock() + client.client = mock_bq_client + + client.read_table_partitioned( + table_id="proj.dataset.events", + partition_column="event_date", + start="2025-01-01", + columns=["id", "event_date", "value"], + ) + + sql = mock_bq_client.query.call_args[0][0] + assert "`id`, `event_date`, `value`" in sql + assert "*" not in sql + + +# --------------------------------------------------------------------------- +# 11. test_connection returns True on success, False on failure +# --------------------------------------------------------------------------- + +class TestTestConnection: + def test_returns_true_on_success(self, client, mock_bq_client): + """test_connection returns True when SELECT 1 query succeeds.""" + mock_job = MagicMock() + mock_job.result.return_value = iter([(1,)]) + mock_bq_client.query.return_value = mock_job + + assert client.test_connection() is True + mock_bq_client.query.assert_called_once_with("SELECT 1") + + def test_returns_false_on_failure(self, client, mock_bq_client): + """test_connection returns False when the query raises an exception.""" + mock_bq_client.query.side_effect = Exception("Connection refused") + + assert client.test_connection() is False + + def test_returns_false_when_result_fails(self, client, mock_bq_client): + """test_connection returns False when result iteration fails.""" + mock_job = MagicMock() + mock_job.result.side_effect = Exception("Timeout") + mock_bq_client.query.return_value = mock_job + + assert client.test_connection() is False + + +# --------------------------------------------------------------------------- +# 12. Type mapping completeness (all BQ types have PyArrow mapping) +# --------------------------------------------------------------------------- + +class TestTypeMapping: + # All standard BigQuery types that should be mapped + EXPECTED_BQ_TYPES = [ + "STRING", "BYTES", "INTEGER", "INT64", + "FLOAT", "FLOAT64", "NUMERIC", "BIGNUMERIC", + "BOOLEAN", "BOOL", + "TIMESTAMP", "DATE", "TIME", "DATETIME", + "GEOGRAPHY", "JSON", + "STRUCT", "RECORD", "ARRAY", + ] + + def test_all_standard_bq_types_are_mapped(self): + """Every standard BigQuery type has an entry in BIGQUERY_TO_PYARROW_TYPES.""" + for bq_type in self.EXPECTED_BQ_TYPES: + assert bq_type in BIGQUERY_TO_PYARROW_TYPES, ( + f"Missing PyArrow mapping for BQ type: {bq_type}" + ) + + def test_all_mappings_produce_valid_pyarrow_types(self): + """Every mapped value is a valid PyArrow DataType.""" + for bq_type, pa_type in BIGQUERY_TO_PYARROW_TYPES.items(): + assert isinstance(pa_type, pa.DataType), ( + f"BQ type {bq_type} maps to non-DataType: {pa_type!r}" + ) + + def test_integer_types_map_to_int64(self): + """Both INTEGER and INT64 map to pa.int64().""" + assert BIGQUERY_TO_PYARROW_TYPES["INTEGER"] == pa.int64() + assert BIGQUERY_TO_PYARROW_TYPES["INT64"] == pa.int64() + + def test_float_types_map_to_float64(self): + """FLOAT, FLOAT64, NUMERIC, BIGNUMERIC all map to pa.float64().""" + for t in ["FLOAT", "FLOAT64", "NUMERIC", "BIGNUMERIC"]: + assert BIGQUERY_TO_PYARROW_TYPES[t] == pa.float64() + + def test_boolean_types_map_to_bool(self): + """Both BOOLEAN and BOOL map to pa.bool_().""" + assert BIGQUERY_TO_PYARROW_TYPES["BOOLEAN"] == pa.bool_() + assert BIGQUERY_TO_PYARROW_TYPES["BOOL"] == pa.bool_() + + def test_date_maps_to_date32(self): + """DATE maps to pa.date32().""" + assert BIGQUERY_TO_PYARROW_TYPES["DATE"] == pa.date32() + + def test_timestamp_has_utc_timezone(self): + """TIMESTAMP maps to pa.timestamp with UTC timezone.""" + ts_type = BIGQUERY_TO_PYARROW_TYPES["TIMESTAMP"] + assert ts_type == pa.timestamp("us", tz="UTC") + + def test_datetime_has_no_timezone(self): + """DATETIME maps to pa.timestamp without timezone.""" + dt_type = BIGQUERY_TO_PYARROW_TYPES["DATETIME"] + assert dt_type == pa.timestamp("us") + + def test_complex_types_map_to_string(self): + """STRUCT, RECORD, ARRAY, GEOGRAPHY, JSON all serialize as string.""" + for t in ["STRUCT", "RECORD", "ARRAY", "GEOGRAPHY", "JSON"]: + assert BIGQUERY_TO_PYARROW_TYPES[t] == pa.string() + + +# --------------------------------------------------------------------------- +# 13. Metadata cache save/load from disk +# --------------------------------------------------------------------------- + +class TestMetadataCachePersistence: + def test_save_and_load_cache(self, tmp_path): + """Metadata cache is persisted to disk and reloaded on new client init.""" + metadata_dir = tmp_path / "metadata" + metadata_dir.mkdir(parents=True, exist_ok=True) + cache_file = metadata_dir / "bq_table_metadata.json" + + mock_config = MagicMock() + mock_config.get_metadata_path.return_value = metadata_dir + + # First client: fetch metadata and save to cache + mock_bq = MagicMock() + table_id = "proj.ds.tbl" + schema = [_make_bq_field("col1", "STRING")] + table_ref = _make_table_ref(table_id=table_id, schema=schema) + mock_bq.get_table.return_value = table_ref + + with ( + patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq), + patch("connectors.bigquery.client.get_config", return_value=mock_config), + patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}), + ): + client1 = BigQueryClient() + client1.get_table_metadata(table_id, use_cache=False) + + # Verify the cache file was written + assert cache_file.exists() + saved_data = json.loads(cache_file.read_text()) + assert table_id in saved_data + assert saved_data[table_id]["columns"] == ["col1"] + + # Second client: loads cache from disk on init + mock_bq2 = MagicMock() + with ( + patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq2), + patch("connectors.bigquery.client.get_config", return_value=mock_config), + patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}), + ): + client2 = BigQueryClient() + + assert table_id in client2.metadata_cache + assert client2.metadata_cache[table_id]["columns"] == ["col1"] + + def test_load_handles_corrupt_cache_file(self, tmp_path): + """Client handles corrupt cache JSON gracefully without crashing.""" + metadata_dir = tmp_path / "metadata" + metadata_dir.mkdir(parents=True, exist_ok=True) + cache_file = metadata_dir / "bq_table_metadata.json" + cache_file.write_text("{corrupt json!!!") + + mock_config = MagicMock() + mock_config.get_metadata_path.return_value = metadata_dir + + mock_bq = MagicMock() + with ( + patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq), + patch("connectors.bigquery.client.get_config", return_value=mock_config), + patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}), + ): + client = BigQueryClient() + + # Cache should be empty after corrupt file + assert client.metadata_cache == {} + + def test_load_handles_missing_cache_file(self, tmp_path): + """Client initializes with empty cache when no cache file exists.""" + metadata_dir = tmp_path / "metadata" + metadata_dir.mkdir(parents=True, exist_ok=True) + # No cache file created + + mock_config = MagicMock() + mock_config.get_metadata_path.return_value = metadata_dir + + mock_bq = MagicMock() + with ( + patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq), + patch("connectors.bigquery.client.get_config", return_value=mock_config), + patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}), + ): + client = BigQueryClient() + + assert client.metadata_cache == {} + + def test_save_creates_parent_directories(self, tmp_path): + """_save_metadata_cache creates parent directories if they do not exist.""" + # Use a nested path that does not yet exist + metadata_dir = tmp_path / "deep" / "nested" / "metadata" + # Do NOT create directories upfront + + mock_config = MagicMock() + mock_config.get_metadata_path.return_value = metadata_dir + + mock_bq = MagicMock() + table_id = "proj.ds.tbl" + schema = [_make_bq_field("x", "INTEGER")] + table_ref = _make_table_ref(table_id=table_id, schema=schema) + mock_bq.get_table.return_value = table_ref + + with ( + patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq), + patch("connectors.bigquery.client.get_config", return_value=mock_config), + patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}), + ): + client = BigQueryClient() + client.get_table_metadata(table_id, use_cache=False) + + cache_file = metadata_dir / "bq_table_metadata.json" + assert cache_file.exists() + + +# --------------------------------------------------------------------------- +# Factory function +# --------------------------------------------------------------------------- + +class TestCreateClient: + def test_create_client_returns_bigquery_client(self, mock_config): + """create_client() factory returns a BigQueryClient instance.""" + mock_bq = MagicMock() + with ( + patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq), + patch("connectors.bigquery.client.get_config", return_value=mock_config), + patch.dict("os.environ", {"BIGQUERY_PROJECT": "factory-project"}), + ): + result = create_client() + + assert isinstance(result, BigQueryClient) + assert result.project_id == "factory-project"