Add BigQuery data source adapter
BigQuery connector that syncs BQ tables to local Parquet files via PyArrow (no CSV intermediate step). Supports full refresh, timestamp-based incremental (via incremental_column), and partition-based sync strategies. - connectors/bigquery/client.py: BQ API wrapper with ADC auth, parameterized queries, metadata cache, cross-project support (job project != data project) - connectors/bigquery/adapter.py: DataSource implementation with merge/dedup - src/config.py: Add incremental_column field to TableConfig - 72 unit tests (mocked, no GCP SDK required)
This commit is contained in:
parent
eb5264b903
commit
758910463b
9 changed files with 2619 additions and 2 deletions
|
|
@ -44,13 +44,38 @@ auth:
|
||||||
google_client_id: "${GOOGLE_CLIENT_ID}"
|
google_client_id: "${GOOGLE_CLIENT_ID}"
|
||||||
google_client_secret: "${GOOGLE_CLIENT_SECRET}"
|
google_client_secret: "${GOOGLE_CLIENT_SECRET}"
|
||||||
|
|
||||||
|
# --- Theme (optional) ---
|
||||||
|
# Customize colors, fonts, and shape to match your brand.
|
||||||
|
# All values are optional - defaults provide a clean blue theme.
|
||||||
|
# See docs/theme-reference.html for a visual guide.
|
||||||
|
theme:
|
||||||
|
# primary: "#0073D1" # Main brand color (buttons, links, accents)
|
||||||
|
# primary_dark: "#005BA3" # Hover/active state of primary
|
||||||
|
# primary_light: "rgba(0, 115, 209, 0.1)" # Light tint backgrounds
|
||||||
|
# text_primary: "#1A253C" # Main text color
|
||||||
|
# text_secondary: "#6B7280" # Muted/secondary text
|
||||||
|
# background: "#F5F7FA" # Page background
|
||||||
|
# surface: "#FFFFFF" # Card/panel background
|
||||||
|
# border: "#E5E7EB" # Borders and dividers
|
||||||
|
# font_primary: "'Inter', system-ui, sans-serif"
|
||||||
|
# font_url: "https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap"
|
||||||
|
# radius: "6px" # Border radius (cards, buttons, inputs)
|
||||||
|
# success: "#10B77F"
|
||||||
|
# warning: "#F59F0A"
|
||||||
|
# error: "#EA580C"
|
||||||
|
|
||||||
# --- Data source ---
|
# --- Data source ---
|
||||||
data_source:
|
data_source:
|
||||||
type: "keboola" # keboola | csv (bigquery planned)
|
type: "keboola" # keboola | bigquery | local
|
||||||
keboola:
|
keboola:
|
||||||
storage_token: "${KEBOOLA_STORAGE_TOKEN}"
|
storage_token: "${KEBOOLA_STORAGE_TOKEN}"
|
||||||
stack_url: "" # e.g., "https://connection.keboola.com"
|
stack_url: "" # e.g., "https://connection.keboola.com"
|
||||||
project_id: ""
|
project_id: ""
|
||||||
|
bigquery:
|
||||||
|
project: "${BIGQUERY_PROJECT}" # GCP project for job execution/billing
|
||||||
|
location: "${BIGQUERY_LOCATION}" # BigQuery location (e.g., "us-central1", "US")
|
||||||
|
# Uses ADC (Application Default Credentials) - VM service account on GCP
|
||||||
|
# Data can live in a different project -- use fully-qualified table IDs in data_description.md
|
||||||
|
|
||||||
# --- Email delivery (optional, for magic link auth) ---
|
# --- Email delivery (optional, for magic link auth) ---
|
||||||
# Without SMTP, magic links are shown directly in browser (development mode).
|
# Without SMTP, magic links are shown directly in browser (development mode).
|
||||||
|
|
|
||||||
11
connectors/bigquery/__init__.py
Normal file
11
connectors/bigquery/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
"""
|
||||||
|
BigQuery connector - data source adapter for Google BigQuery.
|
||||||
|
|
||||||
|
Syncs tables from BigQuery using the BigQuery Storage API,
|
||||||
|
converting query results directly to Parquet files via PyArrow
|
||||||
|
(no CSV intermediate step).
|
||||||
|
|
||||||
|
Enable by setting data_source.type: "bigquery" in config/instance.yaml
|
||||||
|
and providing BIGQUERY_PROJECT environment variable.
|
||||||
|
Uses Application Default Credentials (ADC) for authentication.
|
||||||
|
"""
|
||||||
475
connectors/bigquery/adapter.py
Normal file
475
connectors/bigquery/adapter.py
Normal file
|
|
@ -0,0 +1,475 @@
|
||||||
|
"""
|
||||||
|
BigQuery data source adapter.
|
||||||
|
|
||||||
|
Implements the DataSource interface for Google BigQuery.
|
||||||
|
Reads tables via the BigQuery API, converts directly to Parquet files
|
||||||
|
using PyArrow (no CSV intermediate step).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
|
||||||
|
from src.config import get_config, TableConfig
|
||||||
|
from src.data_sync import DataSource, SyncState, _get_uncompressed_size
|
||||||
|
from src.parquet_manager import (
|
||||||
|
convert_date_columns_to_date32,
|
||||||
|
apply_schema_to_table,
|
||||||
|
)
|
||||||
|
from .client import create_client as create_bq_client
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BigQueryDataSource(DataSource):
|
||||||
|
"""
|
||||||
|
Data source: Google BigQuery.
|
||||||
|
|
||||||
|
Downloads data directly from BigQuery via PyArrow (no CSV step),
|
||||||
|
writes to local Parquet files with schema enforcement.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize BigQuery source with env var validation."""
|
||||||
|
self.config = get_config()
|
||||||
|
self.bq_client = create_bq_client()
|
||||||
|
|
||||||
|
def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Return BigQuery column metadata for schema generation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"columns": {"col_name": {"source_type": "...", "description": "..."}}}
|
||||||
|
or None if metadata unavailable.
|
||||||
|
"""
|
||||||
|
raw = self.bq_client.get_table_metadata(table_id)
|
||||||
|
column_types = raw.get("column_types", {})
|
||||||
|
column_descriptions = raw.get("column_descriptions", {})
|
||||||
|
|
||||||
|
if not column_types:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for col_name, bq_type in column_types.items():
|
||||||
|
entry = {"source_type": bq_type}
|
||||||
|
if col_name in column_descriptions:
|
||||||
|
entry["description"] = column_descriptions[col_name]
|
||||||
|
result[col_name] = entry
|
||||||
|
|
||||||
|
return {"columns": result}
|
||||||
|
|
||||||
|
def discover_tables(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Discover all available tables from BigQuery."""
|
||||||
|
return self.bq_client.discover_all_tables()
|
||||||
|
|
||||||
|
def get_source_name(self) -> str:
|
||||||
|
"""Display name of this data source."""
|
||||||
|
return "Google BigQuery"
|
||||||
|
|
||||||
|
def sync_table(
|
||||||
|
self,
|
||||||
|
table_config: TableConfig,
|
||||||
|
sync_state: SyncState,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Synchronize table from BigQuery.
|
||||||
|
|
||||||
|
Dispatches to the appropriate strategy based on table config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_config: Table configuration
|
||||||
|
sync_state: Sync state manager
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with sync result
|
||||||
|
"""
|
||||||
|
logger.info(f"Syncing BQ table: {table_config.name} ({table_config.sync_strategy})")
|
||||||
|
|
||||||
|
# Clear metadata cache for fresh types
|
||||||
|
if table_config.id in self.bq_client.metadata_cache:
|
||||||
|
del self.bq_client.metadata_cache[table_config.id]
|
||||||
|
logger.debug(f"Cleared BQ metadata cache for {table_config.id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if table_config.sync_strategy == "full_refresh":
|
||||||
|
result = self._full_refresh(table_config)
|
||||||
|
elif table_config.sync_strategy == "incremental":
|
||||||
|
result = self._incremental_sync(table_config, sync_state)
|
||||||
|
elif table_config.sync_strategy == "partitioned":
|
||||||
|
result = self._partitioned_sync(table_config, sync_state)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown sync strategy: {table_config.sync_strategy}")
|
||||||
|
|
||||||
|
# Update sync state
|
||||||
|
sync_state.update_sync(
|
||||||
|
table_id=table_config.id,
|
||||||
|
table_name=table_config.name,
|
||||||
|
strategy=table_config.sync_strategy,
|
||||||
|
rows=result["rows"],
|
||||||
|
file_size_bytes=result["file_size_bytes"],
|
||||||
|
columns=result.get("columns", 0),
|
||||||
|
uncompressed_bytes=result.get("uncompressed_bytes", 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"rows": result["rows"],
|
||||||
|
"strategy": table_config.sync_strategy,
|
||||||
|
"file_size_mb": result["file_size_bytes"] / 1024 / 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error syncing BQ table {table_config.name}: {e}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"strategy": table_config.sync_strategy,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Full refresh: read entire table and replace Parquet file.
|
||||||
|
"""
|
||||||
|
logger.info(f"Full refresh: {table_config.name}")
|
||||||
|
|
||||||
|
parquet_path = self.config.get_parquet_path(table_config)
|
||||||
|
date_columns = self.bq_client.get_date_columns(table_config.id)
|
||||||
|
pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
|
||||||
|
|
||||||
|
# Read full table from BigQuery -> PyArrow
|
||||||
|
arrow_table = self.bq_client.read_table(table_config.id)
|
||||||
|
|
||||||
|
# Apply schema enforcement
|
||||||
|
if date_columns:
|
||||||
|
arrow_table = convert_date_columns_to_date32(arrow_table, date_columns)
|
||||||
|
if pyarrow_schema:
|
||||||
|
arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema)
|
||||||
|
|
||||||
|
# Write to Parquet
|
||||||
|
parquet_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
pq.write_table(arrow_table, parquet_path, compression="snappy")
|
||||||
|
|
||||||
|
file_size = parquet_path.stat().st_size
|
||||||
|
logger.info(
|
||||||
|
f"Full refresh complete: {arrow_table.num_rows} rows, "
|
||||||
|
f"{file_size / 1024 / 1024:.2f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"rows": arrow_table.num_rows,
|
||||||
|
"columns": arrow_table.num_columns,
|
||||||
|
"file_size_bytes": file_size,
|
||||||
|
"uncompressed_bytes": _get_uncompressed_size(parquet_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _incremental_sync(
|
||||||
|
self,
|
||||||
|
table_config: TableConfig,
|
||||||
|
sync_state: SyncState,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Incremental sync: dispatch to column-based or partition-based strategy.
|
||||||
|
"""
|
||||||
|
# If partition_by is set, use partitioned incremental
|
||||||
|
if table_config.partition_by:
|
||||||
|
return self._partitioned_sync(table_config, sync_state)
|
||||||
|
|
||||||
|
# If incremental_column is set, use timestamp-based incremental
|
||||||
|
if table_config.incremental_column:
|
||||||
|
return self._incremental_column_sync(table_config, sync_state)
|
||||||
|
|
||||||
|
# Fallback: full refresh (no incremental column configured)
|
||||||
|
logger.warning(
|
||||||
|
f"Table {table_config.name}: incremental strategy but no "
|
||||||
|
f"incremental_column or partition_by configured, falling back to full refresh"
|
||||||
|
)
|
||||||
|
return self._full_refresh(table_config)
|
||||||
|
|
||||||
|
def _incremental_column_sync(
|
||||||
|
self,
|
||||||
|
table_config: TableConfig,
|
||||||
|
sync_state: SyncState,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Timestamp-based incremental sync using incremental_column.
|
||||||
|
|
||||||
|
Reads rows WHERE incremental_column > last_sync_value,
|
||||||
|
merges with existing Parquet (dedup on primary key).
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"Incremental column sync: {table_config.name} "
|
||||||
|
f"(column: {table_config.incremental_column})"
|
||||||
|
)
|
||||||
|
|
||||||
|
parquet_path = self.config.get_parquet_path(table_config)
|
||||||
|
date_columns = self.bq_client.get_date_columns(table_config.id)
|
||||||
|
pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
|
||||||
|
|
||||||
|
# Determine since_value from last sync
|
||||||
|
last_sync = sync_state.get_last_sync(table_config.id)
|
||||||
|
|
||||||
|
if last_sync and parquet_path.exists():
|
||||||
|
# Apply window: go back incremental_window_days from last sync
|
||||||
|
last_sync_dt = datetime.fromisoformat(last_sync)
|
||||||
|
window_days = table_config.incremental_window_days or 7
|
||||||
|
since_dt = last_sync_dt - timedelta(days=window_days)
|
||||||
|
since_value = since_dt.isoformat()
|
||||||
|
|
||||||
|
logger.info(f" -> Since: {since_value} (window: {window_days} days)")
|
||||||
|
|
||||||
|
# Read incremental data
|
||||||
|
new_data = self.bq_client.read_table_incremental(
|
||||||
|
table_id=table_config.id,
|
||||||
|
incremental_column=table_config.incremental_column,
|
||||||
|
since_value=since_value,
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_data.num_rows == 0:
|
||||||
|
logger.info(" -> No new data since last sync")
|
||||||
|
existing_pf = pq.ParquetFile(parquet_path)
|
||||||
|
return {
|
||||||
|
"rows": existing_pf.metadata.num_rows,
|
||||||
|
"columns": len(existing_pf.schema_arrow),
|
||||||
|
"file_size_bytes": parquet_path.stat().st_size,
|
||||||
|
"uncompressed_bytes": _get_uncompressed_size(parquet_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Merge with existing data
|
||||||
|
logger.info(f" -> Merging {new_data.num_rows} new rows with existing data")
|
||||||
|
existing_table = pq.read_table(parquet_path)
|
||||||
|
merged = self._merge_arrow_tables(
|
||||||
|
existing_table, new_data, table_config.get_primary_key_columns()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply schema enforcement
|
||||||
|
if date_columns:
|
||||||
|
merged = convert_date_columns_to_date32(merged, date_columns)
|
||||||
|
if pyarrow_schema:
|
||||||
|
merged = apply_schema_to_table(merged, pyarrow_schema)
|
||||||
|
|
||||||
|
pq.write_table(merged, parquet_path, compression="snappy")
|
||||||
|
|
||||||
|
file_size = parquet_path.stat().st_size
|
||||||
|
logger.info(
|
||||||
|
f" -> Incremental sync complete: {merged.num_rows} total rows"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"rows": merged.num_rows,
|
||||||
|
"columns": merged.num_columns,
|
||||||
|
"file_size_bytes": file_size,
|
||||||
|
"uncompressed_bytes": _get_uncompressed_size(parquet_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
# First sync or no existing file -- full read
|
||||||
|
logger.info(" -> First sync, reading all data")
|
||||||
|
|
||||||
|
if table_config.max_history_days:
|
||||||
|
since_dt = datetime.now() - timedelta(days=table_config.max_history_days)
|
||||||
|
arrow_table = self.bq_client.read_table_incremental(
|
||||||
|
table_id=table_config.id,
|
||||||
|
incremental_column=table_config.incremental_column,
|
||||||
|
since_value=since_dt.isoformat(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
arrow_table = self.bq_client.read_table(table_config.id)
|
||||||
|
|
||||||
|
# Apply schema enforcement
|
||||||
|
if date_columns:
|
||||||
|
arrow_table = convert_date_columns_to_date32(arrow_table, date_columns)
|
||||||
|
if pyarrow_schema:
|
||||||
|
arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema)
|
||||||
|
|
||||||
|
parquet_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
pq.write_table(arrow_table, parquet_path, compression="snappy")
|
||||||
|
|
||||||
|
file_size = parquet_path.stat().st_size
|
||||||
|
return {
|
||||||
|
"rows": arrow_table.num_rows,
|
||||||
|
"columns": arrow_table.num_columns,
|
||||||
|
"file_size_bytes": file_size,
|
||||||
|
"uncompressed_bytes": _get_uncompressed_size(parquet_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _partitioned_sync(
|
||||||
|
self,
|
||||||
|
table_config: TableConfig,
|
||||||
|
sync_state: SyncState,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Partition-based sync: read data by partition range and write partition files.
|
||||||
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
partition_col = table_config.partition_by
|
||||||
|
if not partition_col and table_config.incremental_column:
|
||||||
|
partition_col = table_config.incremental_column
|
||||||
|
|
||||||
|
if not partition_col:
|
||||||
|
logger.warning(
|
||||||
|
f"Table {table_config.name}: partitioned strategy but no "
|
||||||
|
f"partition_by or incremental_column, falling back to full refresh"
|
||||||
|
)
|
||||||
|
return self._full_refresh(table_config)
|
||||||
|
|
||||||
|
granularity = table_config.partition_granularity or "month"
|
||||||
|
logger.info(
|
||||||
|
f"Partitioned sync: {table_config.name} "
|
||||||
|
f"(by {partition_col}, {granularity})"
|
||||||
|
)
|
||||||
|
|
||||||
|
partition_dir = self.config.get_parquet_path(table_config)
|
||||||
|
date_columns = self.bq_client.get_date_columns(table_config.id)
|
||||||
|
pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
|
||||||
|
|
||||||
|
# Determine time range
|
||||||
|
last_sync = sync_state.get_last_sync(table_config.id)
|
||||||
|
|
||||||
|
if last_sync:
|
||||||
|
last_sync_dt = datetime.fromisoformat(last_sync)
|
||||||
|
window_days = table_config.incremental_window_days or 7
|
||||||
|
start_dt = last_sync_dt - timedelta(days=window_days)
|
||||||
|
logger.info(f" -> Reading from {start_dt.isoformat()} (window: {window_days} days)")
|
||||||
|
else:
|
||||||
|
if table_config.max_history_days:
|
||||||
|
start_dt = datetime.now() - timedelta(days=table_config.max_history_days)
|
||||||
|
logger.info(f" -> First sync, limited to last {table_config.max_history_days} days")
|
||||||
|
else:
|
||||||
|
start_dt = None
|
||||||
|
logger.info(" -> First sync, reading all data")
|
||||||
|
|
||||||
|
# Read data from BigQuery
|
||||||
|
if start_dt:
|
||||||
|
arrow_table = self.bq_client.read_table_partitioned(
|
||||||
|
table_id=table_config.id,
|
||||||
|
partition_column=partition_col,
|
||||||
|
start=start_dt.isoformat(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
arrow_table = self.bq_client.read_table(table_config.id)
|
||||||
|
|
||||||
|
if arrow_table.num_rows == 0:
|
||||||
|
logger.info(" -> No data to sync")
|
||||||
|
return self._get_partition_totals(partition_dir)
|
||||||
|
|
||||||
|
logger.info(f" -> Processing {arrow_table.num_rows} rows into partitions")
|
||||||
|
|
||||||
|
# Convert to pandas for partitioning
|
||||||
|
df = arrow_table.to_pandas()
|
||||||
|
|
||||||
|
# Ensure partition column is datetime
|
||||||
|
if not pd.api.types.is_datetime64_any_dtype(df[partition_col]):
|
||||||
|
df[partition_col] = pd.to_datetime(df[partition_col], format="ISO8601", utc=True)
|
||||||
|
|
||||||
|
# Create partition key
|
||||||
|
if granularity == "month":
|
||||||
|
df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m")
|
||||||
|
elif granularity == "day":
|
||||||
|
df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m_%d")
|
||||||
|
elif granularity == "year":
|
||||||
|
df["_partition_key"] = df[partition_col].dt.strftime("%Y")
|
||||||
|
|
||||||
|
primary_key_cols = table_config.get_primary_key_columns()
|
||||||
|
partitions_updated = set()
|
||||||
|
|
||||||
|
for partition_key, group_df in df.groupby("_partition_key"):
|
||||||
|
group_df = group_df.drop(columns=["_partition_key"])
|
||||||
|
partition_path = self.config.get_partition_path(table_config, partition_key)
|
||||||
|
partitions_updated.add(partition_key)
|
||||||
|
|
||||||
|
# Merge with existing partition if it exists
|
||||||
|
if partition_path.exists():
|
||||||
|
existing_df = pd.read_parquet(partition_path)
|
||||||
|
merged_df = pd.concat([existing_df, group_df], ignore_index=True)
|
||||||
|
merged_df = merged_df.drop_duplicates(subset=primary_key_cols, keep="last")
|
||||||
|
else:
|
||||||
|
merged_df = group_df
|
||||||
|
|
||||||
|
# Write partition
|
||||||
|
table = pa.Table.from_pandas(merged_df, preserve_index=False)
|
||||||
|
if date_columns:
|
||||||
|
table = convert_date_columns_to_date32(table, date_columns)
|
||||||
|
if pyarrow_schema:
|
||||||
|
table = apply_schema_to_table(table, pyarrow_schema)
|
||||||
|
pq.write_table(table, partition_path, compression="snappy")
|
||||||
|
|
||||||
|
logger.info(f" -> Partitioned sync complete: {len(partitions_updated)} partitions updated")
|
||||||
|
|
||||||
|
return self._get_partition_totals(partition_dir)
|
||||||
|
|
||||||
|
def _merge_arrow_tables(
|
||||||
|
self,
|
||||||
|
existing: pa.Table,
|
||||||
|
new_data: pa.Table,
|
||||||
|
primary_key: List[str],
|
||||||
|
) -> pa.Table:
|
||||||
|
"""
|
||||||
|
Merge two Arrow tables with deduplication on primary key.
|
||||||
|
|
||||||
|
New data overwrites existing rows with the same primary key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
existing: Existing data
|
||||||
|
new_data: New/changed data
|
||||||
|
primary_key: List of PK column names
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Merged PyArrow Table
|
||||||
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
existing_df = existing.to_pandas()
|
||||||
|
new_df = new_data.to_pandas()
|
||||||
|
|
||||||
|
# Concat and dedup (keep last = new data wins)
|
||||||
|
merged_df = pd.concat([existing_df, new_df], ignore_index=True)
|
||||||
|
merged_df = merged_df.drop_duplicates(subset=primary_key, keep="last")
|
||||||
|
|
||||||
|
return pa.Table.from_pandas(merged_df, preserve_index=False)
|
||||||
|
|
||||||
|
def _get_partition_totals(self, partition_dir: Path) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Calculate totals from all partition files in a directory.
|
||||||
|
"""
|
||||||
|
total_rows = 0
|
||||||
|
total_size = 0
|
||||||
|
total_uncompressed = 0
|
||||||
|
total_columns = 0
|
||||||
|
|
||||||
|
if not partition_dir.exists():
|
||||||
|
return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0}
|
||||||
|
|
||||||
|
all_partitions = list(partition_dir.glob("*.parquet"))
|
||||||
|
|
||||||
|
for part_path in all_partitions:
|
||||||
|
try:
|
||||||
|
pf = pq.ParquetFile(part_path)
|
||||||
|
meta = pf.metadata
|
||||||
|
total_rows += meta.num_rows
|
||||||
|
total_size += part_path.stat().st_size
|
||||||
|
if total_columns == 0:
|
||||||
|
total_columns = len(pf.schema_arrow)
|
||||||
|
for rg_idx in range(meta.num_row_groups):
|
||||||
|
rg = meta.row_group(rg_idx)
|
||||||
|
for col_idx in range(rg.num_columns):
|
||||||
|
total_uncompressed += rg.column(col_idx).total_uncompressed_size
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Skipping corrupt partition {part_path.name}: {e}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"rows": total_rows,
|
||||||
|
"file_size_bytes": total_size,
|
||||||
|
"partitions": len(all_partitions),
|
||||||
|
"columns": total_columns,
|
||||||
|
"uncompressed_bytes": total_uncompressed,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_data_source() -> BigQueryDataSource:
|
||||||
|
"""Factory function for dynamic import compatibility."""
|
||||||
|
return BigQueryDataSource()
|
||||||
469
connectors/bigquery/client.py
Normal file
469
connectors/bigquery/client.py
Normal file
|
|
@ -0,0 +1,469 @@
|
||||||
|
"""
|
||||||
|
Google BigQuery API Client
|
||||||
|
|
||||||
|
Low-level wrapper for Google BigQuery with these functions:
|
||||||
|
1. Authentication using Application Default Credentials (ADC)
|
||||||
|
2. Query tables to PyArrow (no CSV intermediate step)
|
||||||
|
3. Get table metadata (schema, columns, data types)
|
||||||
|
4. Cache metadata for faster repeated use
|
||||||
|
5. Incremental reads (timestamp-based and partition-based)
|
||||||
|
|
||||||
|
Uses google-cloud-bigquery with native PyArrow support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
from google.cloud import bigquery
|
||||||
|
|
||||||
|
from src.config import get_config
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping BigQuery types to PyArrow types
|
||||||
|
BIGQUERY_TO_PYARROW_TYPES = {
|
||||||
|
"STRING": pa.string(),
|
||||||
|
"BYTES": pa.binary(),
|
||||||
|
"INTEGER": pa.int64(),
|
||||||
|
"INT64": pa.int64(),
|
||||||
|
"FLOAT": pa.float64(),
|
||||||
|
"FLOAT64": pa.float64(),
|
||||||
|
"NUMERIC": pa.float64(),
|
||||||
|
"BIGNUMERIC": pa.float64(),
|
||||||
|
"BOOLEAN": pa.bool_(),
|
||||||
|
"BOOL": pa.bool_(),
|
||||||
|
"TIMESTAMP": pa.timestamp("us", tz="UTC"),
|
||||||
|
"DATE": pa.date32(),
|
||||||
|
"TIME": pa.string(),
|
||||||
|
"DATETIME": pa.timestamp("us"),
|
||||||
|
"GEOGRAPHY": pa.string(),
|
||||||
|
"JSON": pa.string(),
|
||||||
|
"STRUCT": pa.string(),
|
||||||
|
"RECORD": pa.string(),
|
||||||
|
"ARRAY": pa.string(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BigQueryClient:
|
||||||
|
"""
|
||||||
|
Wrapper for Google BigQuery API.
|
||||||
|
|
||||||
|
Provides high-level methods for working with BigQuery tables:
|
||||||
|
- Query tables to PyArrow Tables (no CSV step)
|
||||||
|
- Get metadata (schema, columns)
|
||||||
|
- Incremental and partitioned reads
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
project_id: Optional[str] = None,
|
||||||
|
location: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize BigQuery client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: GCP project ID for job execution/billing.
|
||||||
|
If None, reads from BIGQUERY_PROJECT env var.
|
||||||
|
location: BigQuery location for job execution (e.g., "us-central1").
|
||||||
|
If None, reads from BIGQUERY_LOCATION env var.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If project_id is not provided and BIGQUERY_PROJECT is not set.
|
||||||
|
"""
|
||||||
|
self.project_id = project_id or os.environ.get("BIGQUERY_PROJECT")
|
||||||
|
|
||||||
|
if not self.project_id:
|
||||||
|
raise ValueError(
|
||||||
|
"BigQuery project ID not set. "
|
||||||
|
"Set BIGQUERY_PROJECT environment variable."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.location = location or os.environ.get("BIGQUERY_LOCATION")
|
||||||
|
|
||||||
|
# Initialize BigQuery client with ADC
|
||||||
|
# project_id is used for job execution and billing.
|
||||||
|
# Data can live in a different project -- table IDs in queries
|
||||||
|
# use fully-qualified format (project.dataset.table).
|
||||||
|
client_kwargs = {"project": self.project_id}
|
||||||
|
if self.location:
|
||||||
|
client_kwargs["location"] = self.location
|
||||||
|
self.client = bigquery.Client(**client_kwargs)
|
||||||
|
|
||||||
|
# Metadata cache
|
||||||
|
config = get_config()
|
||||||
|
self.metadata_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
|
self.metadata_cache_path = config.get_metadata_path() / "bq_table_metadata.json"
|
||||||
|
|
||||||
|
# Load cache from disk if exists
|
||||||
|
self._load_metadata_cache()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"BigQuery client initialized: project={self.project_id}, "
|
||||||
|
f"location={self.location or 'auto'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_metadata_cache(self):
|
||||||
|
"""Load metadata cache from disk."""
|
||||||
|
if self.metadata_cache_path.exists():
|
||||||
|
try:
|
||||||
|
with open(self.metadata_cache_path, "r") as f:
|
||||||
|
self.metadata_cache = json.load(f)
|
||||||
|
logger.info(
|
||||||
|
f"BQ metadata cache loaded: {len(self.metadata_cache)} tables"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error loading BQ metadata cache: {e}")
|
||||||
|
self.metadata_cache = {}
|
||||||
|
|
||||||
|
def _save_metadata_cache(self):
|
||||||
|
"""Save metadata cache to disk."""
|
||||||
|
try:
|
||||||
|
self.metadata_cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(self.metadata_cache_path, "w") as f:
|
||||||
|
json.dump(self.metadata_cache, f, indent=2)
|
||||||
|
logger.debug("BQ metadata cache saved")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error saving BQ metadata cache: {e}")
|
||||||
|
|
||||||
|
def get_table_metadata(
|
||||||
|
self,
|
||||||
|
table_id: str,
|
||||||
|
use_cache: bool = True,
|
||||||
|
cache_ttl_hours: int = 24,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get table metadata from BigQuery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_id: Full table ID (e.g., "project.dataset.table")
|
||||||
|
use_cache: Use cache if available
|
||||||
|
cache_ttl_hours: Cache TTL in hours (default 24h)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with metadata including columns, types, descriptions, row count.
|
||||||
|
"""
|
||||||
|
# Check cache
|
||||||
|
if use_cache and table_id in self.metadata_cache:
|
||||||
|
cached = self.metadata_cache[table_id]
|
||||||
|
cached_time = datetime.fromisoformat(cached.get("_cached_at", "2000-01-01"))
|
||||||
|
cache_age = datetime.now() - cached_time
|
||||||
|
|
||||||
|
if cache_age < timedelta(hours=cache_ttl_hours):
|
||||||
|
logger.debug(f"Using BQ metadata cache for {table_id}")
|
||||||
|
return cached
|
||||||
|
|
||||||
|
logger.info(f"Fetching metadata for BQ table: {table_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
table_ref = self.client.get_table(table_id)
|
||||||
|
|
||||||
|
# Build column metadata
|
||||||
|
columns = []
|
||||||
|
column_types = {}
|
||||||
|
column_descriptions = {}
|
||||||
|
for field in table_ref.schema:
|
||||||
|
columns.append(field.name)
|
||||||
|
column_types[field.name] = field.field_type
|
||||||
|
if field.description:
|
||||||
|
column_descriptions[field.name] = field.description
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"table_id": table_id,
|
||||||
|
"name": table_ref.table_id,
|
||||||
|
"dataset": table_ref.dataset_id,
|
||||||
|
"project": table_ref.project,
|
||||||
|
"columns": columns,
|
||||||
|
"column_types": column_types,
|
||||||
|
"column_descriptions": column_descriptions,
|
||||||
|
"row_count": table_ref.num_rows,
|
||||||
|
"size_bytes": table_ref.num_bytes,
|
||||||
|
"created": table_ref.created.isoformat() if table_ref.created else None,
|
||||||
|
"modified": table_ref.modified.isoformat() if table_ref.modified else None,
|
||||||
|
"partitioning": None,
|
||||||
|
"_cached_at": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Capture partitioning info
|
||||||
|
if table_ref.time_partitioning:
|
||||||
|
metadata["partitioning"] = {
|
||||||
|
"type": table_ref.time_partitioning.type_,
|
||||||
|
"field": table_ref.time_partitioning.field,
|
||||||
|
"expiration_ms": table_ref.time_partitioning.expiration_ms,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save to cache
|
||||||
|
self.metadata_cache[table_id] = metadata
|
||||||
|
self._save_metadata_cache()
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting metadata for {table_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_pyarrow_schema(self, table_id: str) -> Optional[pa.Schema]:
|
||||||
|
"""
|
||||||
|
Build PyArrow schema from BigQuery table schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_id: Full table ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyArrow schema or None if metadata unavailable
|
||||||
|
"""
|
||||||
|
metadata = self.get_table_metadata(table_id)
|
||||||
|
column_types = metadata.get("column_types", {})
|
||||||
|
|
||||||
|
if not column_types:
|
||||||
|
logger.warning(f"No column types for {table_id}, schema will not be applied")
|
||||||
|
return None
|
||||||
|
|
||||||
|
fields = []
|
||||||
|
for col_name in metadata.get("columns", []):
|
||||||
|
bq_type = column_types.get(col_name, "STRING")
|
||||||
|
pa_type = BIGQUERY_TO_PYARROW_TYPES.get(bq_type, pa.string())
|
||||||
|
fields.append(pa.field(col_name, pa_type))
|
||||||
|
|
||||||
|
return pa.schema(fields)
|
||||||
|
|
||||||
|
def get_date_columns(self, table_id: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get list of DATE-only columns for a table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_id: Full table ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of column names that have DATE type in BigQuery
|
||||||
|
"""
|
||||||
|
metadata = self.get_table_metadata(table_id)
|
||||||
|
column_types = metadata.get("column_types", {})
|
||||||
|
|
||||||
|
return [
|
||||||
|
col_name for col_name, bq_type in column_types.items()
|
||||||
|
if bq_type == "DATE"
|
||||||
|
]
|
||||||
|
|
||||||
|
def query_to_arrow(
|
||||||
|
self,
|
||||||
|
sql: str,
|
||||||
|
params: Optional[List[bigquery.ScalarQueryParameter]] = None,
|
||||||
|
) -> pa.Table:
|
||||||
|
"""
|
||||||
|
Execute SQL query and return results as PyArrow Table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sql: SQL query string (use @param_name for parameterized values)
|
||||||
|
params: List of BigQuery query parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyArrow Table with query results
|
||||||
|
"""
|
||||||
|
job_config = bigquery.QueryJobConfig()
|
||||||
|
if params:
|
||||||
|
job_config.query_parameters = params
|
||||||
|
|
||||||
|
logger.debug(f"Executing BQ query: {sql[:200]}...")
|
||||||
|
|
||||||
|
query_job = self.client.query(sql, job_config=job_config)
|
||||||
|
arrow_table = query_job.to_arrow()
|
||||||
|
|
||||||
|
logger.debug(f"Query returned {arrow_table.num_rows} rows, {arrow_table.num_columns} columns")
|
||||||
|
return arrow_table
|
||||||
|
|
||||||
|
def read_table(
|
||||||
|
self,
|
||||||
|
table_id: str,
|
||||||
|
columns: Optional[List[str]] = None,
|
||||||
|
row_filter: Optional[str] = None,
|
||||||
|
) -> pa.Table:
|
||||||
|
"""
|
||||||
|
Read full table (or filtered subset) as PyArrow Table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_id: Full table ID (e.g., "project.dataset.table")
|
||||||
|
columns: Optional list of columns to select
|
||||||
|
row_filter: Optional SQL WHERE clause (without WHERE keyword)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyArrow Table with table data
|
||||||
|
"""
|
||||||
|
# Build SELECT clause
|
||||||
|
select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*"
|
||||||
|
|
||||||
|
sql = f"SELECT {select_cols} FROM `{table_id}`"
|
||||||
|
if row_filter:
|
||||||
|
sql += f" WHERE {row_filter}"
|
||||||
|
|
||||||
|
logger.info(f"Reading BQ table: {table_id} (filter: {row_filter or 'none'})")
|
||||||
|
return self.query_to_arrow(sql)
|
||||||
|
|
||||||
|
def read_table_incremental(
|
||||||
|
self,
|
||||||
|
table_id: str,
|
||||||
|
incremental_column: str,
|
||||||
|
since_value: str,
|
||||||
|
columns: Optional[List[str]] = None,
|
||||||
|
) -> pa.Table:
|
||||||
|
"""
|
||||||
|
Read rows where incremental_column > since_value.
|
||||||
|
|
||||||
|
Uses parameterized query to prevent SQL injection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_id: Full table ID
|
||||||
|
incremental_column: Column name for incremental filter
|
||||||
|
since_value: ISO timestamp string - fetch rows after this value
|
||||||
|
columns: Optional list of columns to select
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyArrow Table with incremental data
|
||||||
|
"""
|
||||||
|
select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*"
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
f"SELECT {select_cols} FROM `{table_id}` "
|
||||||
|
f"WHERE `{incremental_column}` > @since_value"
|
||||||
|
)
|
||||||
|
|
||||||
|
params = [
|
||||||
|
bigquery.ScalarQueryParameter("since_value", "TIMESTAMP", since_value),
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Incremental read: {table_id} WHERE {incremental_column} > {since_value}"
|
||||||
|
)
|
||||||
|
return self.query_to_arrow(sql, params=params)
|
||||||
|
|
||||||
|
def read_table_partitioned(
|
||||||
|
self,
|
||||||
|
table_id: str,
|
||||||
|
partition_column: str,
|
||||||
|
start: str,
|
||||||
|
end: Optional[str] = None,
|
||||||
|
columns: Optional[List[str]] = None,
|
||||||
|
) -> pa.Table:
|
||||||
|
"""
|
||||||
|
Read data within a partition range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_id: Full table ID
|
||||||
|
partition_column: Partition column name
|
||||||
|
start: Start date/timestamp (inclusive)
|
||||||
|
end: End date/timestamp (exclusive). If None, reads to present.
|
||||||
|
columns: Optional list of columns to select
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PyArrow Table with partition range data
|
||||||
|
"""
|
||||||
|
select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*"
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
f"SELECT {select_cols} FROM `{table_id}` "
|
||||||
|
f"WHERE `{partition_column}` >= @start_value"
|
||||||
|
)
|
||||||
|
params = [
|
||||||
|
bigquery.ScalarQueryParameter("start_value", "TIMESTAMP", start),
|
||||||
|
]
|
||||||
|
|
||||||
|
if end:
|
||||||
|
sql += f" AND `{partition_column}` < @end_value"
|
||||||
|
params.append(
|
||||||
|
bigquery.ScalarQueryParameter("end_value", "TIMESTAMP", end),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Partitioned read: {table_id} [{start} .. {end or 'now'})"
|
||||||
|
)
|
||||||
|
return self.query_to_arrow(sql, params=params)
|
||||||
|
|
||||||
|
def discover_all_tables(self, dataset_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
List all tables in the project (or specific dataset).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: Optional dataset ID to limit scope
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized list of table dicts with id, name, columns, row_count, etc.
|
||||||
|
"""
|
||||||
|
logger.info(f"Discovering BQ tables (dataset={dataset_id or 'all'})...")
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
if dataset_id:
|
||||||
|
datasets = [self.client.get_dataset(dataset_id)]
|
||||||
|
else:
|
||||||
|
datasets = list(self.client.list_datasets())
|
||||||
|
|
||||||
|
for dataset in datasets:
|
||||||
|
ds_ref = dataset.reference if hasattr(dataset, "reference") else dataset.dataset_id
|
||||||
|
ds_id = str(ds_ref)
|
||||||
|
|
||||||
|
try:
|
||||||
|
tables = list(self.client.list_tables(ds_ref))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not list tables in dataset {ds_id}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for table_item in tables:
|
||||||
|
full_id = f"{table_item.project}.{table_item.dataset_id}.{table_item.table_id}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
table_detail = self.client.get_table(full_id)
|
||||||
|
columns = [f.name for f in table_detail.schema]
|
||||||
|
|
||||||
|
result.append({
|
||||||
|
"id": full_id,
|
||||||
|
"name": table_item.table_id,
|
||||||
|
"bucket_id": table_item.dataset_id,
|
||||||
|
"bucket_name": table_item.dataset_id,
|
||||||
|
"columns": columns,
|
||||||
|
"row_count": table_detail.num_rows or 0,
|
||||||
|
"size_bytes": table_detail.num_bytes or 0,
|
||||||
|
"primary_key": [],
|
||||||
|
"last_change": (
|
||||||
|
table_detail.modified.isoformat()
|
||||||
|
if table_detail.modified else None
|
||||||
|
),
|
||||||
|
"last_import": None,
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not get details for {full_id}: {e}")
|
||||||
|
|
||||||
|
logger.info(f"Discovered {len(result)} BQ tables")
|
||||||
|
return result
|
||||||
|
|
||||||
|
def test_connection(self) -> bool:
|
||||||
|
"""
|
||||||
|
Test connection to BigQuery.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection works, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query_job = self.client.query("SELECT 1")
|
||||||
|
list(query_job.result())
|
||||||
|
logger.info(f"BigQuery connection OK (project: {self.project_id})")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"BigQuery connection test failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def create_client() -> BigQueryClient:
|
||||||
|
"""
|
||||||
|
Factory function to create BigQuery client.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BigQueryClient instance
|
||||||
|
"""
|
||||||
|
return BigQueryClient()
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
# Data source adapters (install only what you need)
|
# Data source adapters (install only what you need)
|
||||||
kbcstorage>=0.9.0 # For Keboola adapter
|
kbcstorage>=0.9.0 # For Keboola adapter
|
||||||
|
google-cloud-bigquery>=3.0.0 # For BigQuery adapter
|
||||||
|
google-cloud-bigquery-storage>=2.0.0 # For BigQuery adapter (fast Arrow transfer)
|
||||||
|
|
||||||
# Data processing
|
# Data processing
|
||||||
# pandas - core tabular data processing library
|
# pandas - core tabular data processing library
|
||||||
|
|
|
||||||
|
|
@ -101,6 +101,7 @@ class TableConfig:
|
||||||
max_history_days: Optional[int] = None
|
max_history_days: Optional[int] = None
|
||||||
dataset: Optional[str] = None
|
dataset: Optional[str] = None
|
||||||
initial_load_chunk_days: int = 30
|
initial_load_chunk_days: int = 30
|
||||||
|
incremental_column: Optional[str] = None # Column for timestamp-based incremental sync (BigQuery)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Validate configuration after initialization."""
|
"""Validate configuration after initialization."""
|
||||||
|
|
@ -429,6 +430,7 @@ class Config:
|
||||||
max_history_days=table_data.get("max_history_days"),
|
max_history_days=table_data.get("max_history_days"),
|
||||||
dataset=table_data.get("dataset"),
|
dataset=table_data.get("dataset"),
|
||||||
initial_load_chunk_days=table_data.get("initial_load_chunk_days", 30),
|
initial_load_chunk_days=table_data.get("initial_load_chunk_days", 30),
|
||||||
|
incremental_column=table_data.get("incremental_column"),
|
||||||
)
|
)
|
||||||
table_configs.append(config)
|
table_configs.append(config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -511,7 +511,7 @@ def create_data_source(source_type: str = None) -> DataSource:
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown data source: '{source_type}'. "
|
f"Unknown data source: '{source_type}'. "
|
||||||
f"Available connectors: keboola. "
|
f"Available connectors: keboola, bigquery. "
|
||||||
f"Create connectors/{source_type}/adapter.py to add a new one."
|
f"Create connectors/{source_type}/adapter.py to add a new one."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
763
tests/test_bigquery_adapter.py
Normal file
763
tests/test_bigquery_adapter.py
Normal file
|
|
@ -0,0 +1,763 @@
|
||||||
|
"""
|
||||||
|
Comprehensive unit tests for the BigQuery data source adapter.
|
||||||
|
|
||||||
|
Tests the BigQueryDataSource class from connectors/bigquery/adapter.py
|
||||||
|
with all external dependencies (BigQueryClient, config, parquet_manager) mocked.
|
||||||
|
|
||||||
|
The google-cloud-bigquery package is not installed in test environments,
|
||||||
|
so we install stub modules in sys.modules before importing the adapter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Stub google.cloud.bigquery before any connector import
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
_bq_stub = MagicMock()
|
||||||
|
sys.modules.setdefault("google", _bq_stub)
|
||||||
|
sys.modules.setdefault("google.cloud", _bq_stub)
|
||||||
|
sys.modules.setdefault("google.cloud.bigquery", _bq_stub)
|
||||||
|
|
||||||
|
from src.config import TableConfig # noqa: E402
|
||||||
|
from src.data_sync import SyncState # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_parquet_dir(tmp_path):
|
||||||
|
"""Provide a temporary directory for Parquet file output."""
|
||||||
|
parquet_dir = tmp_path / "parquet" / "test_bucket"
|
||||||
|
parquet_dir.mkdir(parents=True)
|
||||||
|
return parquet_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_config(tmp_parquet_dir):
|
||||||
|
"""Create a mock Config object that returns paths inside tmp_parquet_dir."""
|
||||||
|
config = MagicMock()
|
||||||
|
config.get_parquet_path = MagicMock()
|
||||||
|
config.get_partition_path = MagicMock()
|
||||||
|
config.get_metadata_path.return_value = tmp_parquet_dir.parent / "metadata"
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_bq_client():
|
||||||
|
"""Create a mock BigQueryClient with sensible defaults."""
|
||||||
|
client = MagicMock()
|
||||||
|
client.metadata_cache = {}
|
||||||
|
client.get_date_columns.return_value = []
|
||||||
|
client.get_pyarrow_schema.return_value = None
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sync_state(tmp_path):
|
||||||
|
"""Create a real SyncState backed by a temp JSON file."""
|
||||||
|
state_file = tmp_path / "metadata" / "sync_state.json"
|
||||||
|
state_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
return SyncState(state_file)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_table_config(
|
||||||
|
*,
|
||||||
|
table_id: str = "project.dataset.orders",
|
||||||
|
name: str = "orders",
|
||||||
|
primary_key: str = "id",
|
||||||
|
sync_strategy: str = "full_refresh",
|
||||||
|
incremental_column: str | None = None,
|
||||||
|
incremental_window_days: int | None = None,
|
||||||
|
partition_by: str | None = None,
|
||||||
|
partition_granularity: str | None = None,
|
||||||
|
max_history_days: int | None = None,
|
||||||
|
) -> TableConfig:
|
||||||
|
"""Helper to build a TableConfig with safe defaults."""
|
||||||
|
return TableConfig(
|
||||||
|
id=table_id,
|
||||||
|
name=name,
|
||||||
|
description="Test table",
|
||||||
|
primary_key=primary_key,
|
||||||
|
sync_strategy=sync_strategy,
|
||||||
|
incremental_column=incremental_column,
|
||||||
|
incremental_window_days=incremental_window_days,
|
||||||
|
partition_by=partition_by,
|
||||||
|
partition_granularity=partition_granularity,
|
||||||
|
max_history_days=max_history_days,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_arrow_table(ids: list[int], names: list[str]) -> pa.Table:
|
||||||
|
"""Build a small PyArrow Table with id and name columns."""
|
||||||
|
return pa.table({"id": ids, "name": names})
|
||||||
|
|
||||||
|
|
||||||
|
def _create_adapter(mock_config, mock_bq_client):
|
||||||
|
"""Instantiate BigQueryDataSource with mocked dependencies.
|
||||||
|
|
||||||
|
Patches get_config and create_bq_client so that no real GCP
|
||||||
|
credentials or network access are needed.
|
||||||
|
"""
|
||||||
|
with patch("connectors.bigquery.adapter.get_config", return_value=mock_config), \
|
||||||
|
patch("connectors.bigquery.adapter.create_bq_client", return_value=mock_bq_client):
|
||||||
|
from connectors.bigquery.adapter import BigQueryDataSource
|
||||||
|
adapter = BigQueryDataSource()
|
||||||
|
return adapter
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 1. full_refresh writes valid Parquet file from Arrow table
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFullRefresh:
|
||||||
|
|
||||||
|
def test_writes_valid_parquet(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
|
||||||
|
"""full_refresh should write a valid, readable Parquet file."""
|
||||||
|
table_config = _make_table_config(sync_strategy="full_refresh")
|
||||||
|
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_config.get_parquet_path.return_value = parquet_path
|
||||||
|
|
||||||
|
arrow_data = _sample_arrow_table([1, 2, 3], ["Alice", "Bob", "Charlie"])
|
||||||
|
mock_bq_client.read_table.return_value = arrow_data
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
result = adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["rows"] == 3
|
||||||
|
assert parquet_path.exists()
|
||||||
|
|
||||||
|
# Verify Parquet content matches source data
|
||||||
|
read_back = pq.read_table(parquet_path)
|
||||||
|
assert read_back.num_rows == 3
|
||||||
|
assert read_back.column_names == ["id", "name"]
|
||||||
|
|
||||||
|
def test_applies_date_columns(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
|
||||||
|
"""full_refresh should call convert_date_columns_to_date32 when date columns exist."""
|
||||||
|
table_config = _make_table_config()
|
||||||
|
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_config.get_parquet_path.return_value = parquet_path
|
||||||
|
|
||||||
|
arrow_data = _sample_arrow_table([1], ["Alice"])
|
||||||
|
mock_bq_client.read_table.return_value = arrow_data
|
||||||
|
mock_bq_client.get_date_columns.return_value = ["created_at"]
|
||||||
|
|
||||||
|
with patch("connectors.bigquery.adapter.convert_date_columns_to_date32", return_value=arrow_data) as mock_conv:
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
adapter.sync_table(table_config, sync_state)
|
||||||
|
mock_conv.assert_called_once_with(arrow_data, ["created_at"])
|
||||||
|
|
||||||
|
def test_applies_pyarrow_schema(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
|
||||||
|
"""full_refresh should call apply_schema_to_table when schema is available."""
|
||||||
|
table_config = _make_table_config()
|
||||||
|
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_config.get_parquet_path.return_value = parquet_path
|
||||||
|
|
||||||
|
arrow_data = _sample_arrow_table([1], ["Alice"])
|
||||||
|
mock_bq_client.read_table.return_value = arrow_data
|
||||||
|
schema = pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.string())])
|
||||||
|
mock_bq_client.get_pyarrow_schema.return_value = schema
|
||||||
|
|
||||||
|
with patch("connectors.bigquery.adapter.apply_schema_to_table", return_value=arrow_data) as mock_apply:
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
adapter.sync_table(table_config, sync_state)
|
||||||
|
mock_apply.assert_called_once_with(arrow_data, schema)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 2. incremental_column_sync merges correctly (dedup on PK, new data wins)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestIncrementalColumnSync:
|
||||||
|
|
||||||
|
def test_merge_dedup_new_data_wins(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
|
||||||
|
"""Incremental sync should overwrite existing rows when PK matches (new data wins)."""
|
||||||
|
table_config = _make_table_config(
|
||||||
|
sync_strategy="incremental",
|
||||||
|
incremental_column="updated_at",
|
||||||
|
incremental_window_days=7,
|
||||||
|
)
|
||||||
|
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_config.get_parquet_path.return_value = parquet_path
|
||||||
|
|
||||||
|
# Write existing data
|
||||||
|
existing = _sample_arrow_table([1, 2], ["Alice", "Bob"])
|
||||||
|
pq.write_table(existing, parquet_path)
|
||||||
|
|
||||||
|
# Simulate a previous sync timestamp
|
||||||
|
sync_state.update_sync(
|
||||||
|
table_id=table_config.id,
|
||||||
|
table_name=table_config.name,
|
||||||
|
strategy="incremental",
|
||||||
|
rows=2,
|
||||||
|
file_size_bytes=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
# New data: id=2 gets updated name, id=3 is new
|
||||||
|
new_data = _sample_arrow_table([2, 3], ["Bob_Updated", "Charlie"])
|
||||||
|
mock_bq_client.read_table_incremental.return_value = new_data
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
result = adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["rows"] == 3 # Alice + Bob_Updated + Charlie
|
||||||
|
|
||||||
|
read_back = pq.read_table(parquet_path)
|
||||||
|
df = read_back.to_pandas()
|
||||||
|
assert set(df["id"].tolist()) == {1, 2, 3}
|
||||||
|
# id=2 should have the updated name
|
||||||
|
bob_row = df[df["id"] == 2].iloc[0]
|
||||||
|
assert bob_row["name"] == "Bob_Updated"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 3. incremental_column_sync with no new data returns existing file info
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestIncrementalNoNewData:
|
||||||
|
|
||||||
|
def test_returns_existing_file_info(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
|
||||||
|
"""When there is no new data, sync returns stats from the existing Parquet file."""
|
||||||
|
table_config = _make_table_config(
|
||||||
|
sync_strategy="incremental",
|
||||||
|
incremental_column="updated_at",
|
||||||
|
incremental_window_days=7,
|
||||||
|
)
|
||||||
|
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_config.get_parquet_path.return_value = parquet_path
|
||||||
|
|
||||||
|
# Write existing data
|
||||||
|
existing = _sample_arrow_table([1, 2, 3], ["A", "B", "C"])
|
||||||
|
pq.write_table(existing, parquet_path)
|
||||||
|
|
||||||
|
# Mark a previous sync
|
||||||
|
sync_state.update_sync(
|
||||||
|
table_id=table_config.id,
|
||||||
|
table_name=table_config.name,
|
||||||
|
strategy="incremental",
|
||||||
|
rows=3,
|
||||||
|
file_size_bytes=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
# No new rows
|
||||||
|
empty_table = pa.table({
|
||||||
|
"id": pa.array([], type=pa.int64()),
|
||||||
|
"name": pa.array([], type=pa.string()),
|
||||||
|
})
|
||||||
|
mock_bq_client.read_table_incremental.return_value = empty_table
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
result = adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["rows"] == 3 # existing row count preserved
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 4. partitioned_sync creates partition files
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPartitionedSync:
|
||||||
|
|
||||||
|
def test_creates_partition_files(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
|
||||||
|
"""Partitioned sync should create separate Parquet files per partition key."""
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
table_config = _make_table_config(
|
||||||
|
sync_strategy="incremental",
|
||||||
|
incremental_column="created_at",
|
||||||
|
partition_by="created_at",
|
||||||
|
partition_granularity="month",
|
||||||
|
incremental_window_days=7,
|
||||||
|
)
|
||||||
|
|
||||||
|
# For partitioned tables, parquet_path is a directory
|
||||||
|
partition_dir = tmp_parquet_dir / "orders"
|
||||||
|
partition_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
mock_config.get_parquet_path.return_value = partition_dir
|
||||||
|
|
||||||
|
# Configure partition paths
|
||||||
|
def _partition_path(tc, key):
|
||||||
|
return partition_dir / f"{key}.parquet"
|
||||||
|
mock_config.get_partition_path.side_effect = _partition_path
|
||||||
|
|
||||||
|
# Build arrow table with timestamps in two months
|
||||||
|
ts_jan = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")]
|
||||||
|
ts_feb = [pd.Timestamp("2026-02-20 14:00:00", tz="UTC")]
|
||||||
|
arrow_data = pa.table({
|
||||||
|
"id": [1, 2],
|
||||||
|
"name": ["Jan_Order", "Feb_Order"],
|
||||||
|
"created_at": pa.array(ts_jan + ts_feb, type=pa.timestamp("us", tz="UTC")),
|
||||||
|
})
|
||||||
|
mock_bq_client.read_table.return_value = arrow_data
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
result = adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
# Should have created two partition files
|
||||||
|
partition_files = list(partition_dir.glob("*.parquet"))
|
||||||
|
assert len(partition_files) == 2
|
||||||
|
|
||||||
|
partition_names = sorted(f.stem for f in partition_files)
|
||||||
|
assert "2026_01" in partition_names
|
||||||
|
assert "2026_02" in partition_names
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 5. discover_tables delegates to BigQueryClient.discover_all_tables()
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDiscoverTables:
|
||||||
|
|
||||||
|
def test_delegates_to_client(self, mock_config, mock_bq_client):
|
||||||
|
"""discover_tables should forward the call to BigQueryClient.discover_all_tables."""
|
||||||
|
expected = [{"id": "proj.ds.t1", "name": "t1", "columns": ["a", "b"]}]
|
||||||
|
mock_bq_client.discover_all_tables.return_value = expected
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
result = adapter.discover_tables()
|
||||||
|
|
||||||
|
mock_bq_client.discover_all_tables.assert_called_once()
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 6. get_source_name returns "Google BigQuery"
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetSourceName:
|
||||||
|
|
||||||
|
def test_returns_google_bigquery(self, mock_config, mock_bq_client):
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
assert adapter.get_source_name() == "Google BigQuery"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 7. get_column_metadata returns correct format
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetColumnMetadata:
|
||||||
|
|
||||||
|
def test_returns_correct_format(self, mock_config, mock_bq_client):
|
||||||
|
"""get_column_metadata should transform BQ raw metadata into {columns: ...} format."""
|
||||||
|
mock_bq_client.get_table_metadata.return_value = {
|
||||||
|
"column_types": {"id": "INT64", "name": "STRING", "email": "STRING"},
|
||||||
|
"column_descriptions": {"id": "Primary key", "email": "User email address"},
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
result = adapter.get_column_metadata("project.dataset.users")
|
||||||
|
|
||||||
|
assert "columns" in result
|
||||||
|
assert result["columns"]["id"] == {"source_type": "INT64", "description": "Primary key"}
|
||||||
|
assert result["columns"]["name"] == {"source_type": "STRING"}
|
||||||
|
assert result["columns"]["email"] == {
|
||||||
|
"source_type": "STRING",
|
||||||
|
"description": "User email address",
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_returns_none_when_no_column_types(self, mock_config, mock_bq_client):
|
||||||
|
"""get_column_metadata should return None if the metadata has no column types."""
|
||||||
|
mock_bq_client.get_table_metadata.return_value = {
|
||||||
|
"column_types": {},
|
||||||
|
"column_descriptions": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
result = adapter.get_column_metadata("project.dataset.users")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 8. Error handling (query failure -> {success: False, error: ...})
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestErrorHandling:
|
||||||
|
|
||||||
|
def test_query_failure_returns_error_dict(
|
||||||
|
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||||
|
):
|
||||||
|
"""When BigQuery query raises, sync_table returns {success: False, error: ...}."""
|
||||||
|
table_config = _make_table_config()
|
||||||
|
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_bq_client.read_table.side_effect = RuntimeError("BigQuery API timeout")
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
result = adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "BigQuery API timeout" in result["error"]
|
||||||
|
assert result["strategy"] == "full_refresh"
|
||||||
|
|
||||||
|
def test_unknown_strategy_returns_error(self, mock_config, mock_bq_client, sync_state):
|
||||||
|
"""Unknown sync_strategy in internal dispatch should produce an error result."""
|
||||||
|
# We cannot create a TableConfig with an invalid strategy via constructor
|
||||||
|
# (it validates). Instead, we mutate it after creation.
|
||||||
|
table_config = _make_table_config()
|
||||||
|
table_config.sync_strategy = "magic_sync"
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
result = adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "Unknown sync strategy" in result["error"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 9. incremental_column config is used in WHERE clause
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestIncrementalColumnUsedInWhere:
|
||||||
|
|
||||||
|
def test_incremental_column_passed_to_client(
|
||||||
|
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||||
|
):
|
||||||
|
"""The configured incremental_column should be forwarded to read_table_incremental."""
|
||||||
|
table_config = _make_table_config(
|
||||||
|
sync_strategy="incremental",
|
||||||
|
incremental_column="modified_at",
|
||||||
|
incremental_window_days=14,
|
||||||
|
)
|
||||||
|
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_config.get_parquet_path.return_value = parquet_path
|
||||||
|
|
||||||
|
# Write existing data so we enter the incremental path
|
||||||
|
existing = _sample_arrow_table([1], ["Alice"])
|
||||||
|
pq.write_table(existing, parquet_path)
|
||||||
|
|
||||||
|
sync_state.update_sync(
|
||||||
|
table_id=table_config.id,
|
||||||
|
table_name=table_config.name,
|
||||||
|
strategy="incremental",
|
||||||
|
rows=1,
|
||||||
|
file_size_bytes=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return empty to keep the test simple
|
||||||
|
empty = pa.table({
|
||||||
|
"id": pa.array([], type=pa.int64()),
|
||||||
|
"name": pa.array([], type=pa.string()),
|
||||||
|
})
|
||||||
|
mock_bq_client.read_table_incremental.return_value = empty
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
call_kwargs = mock_bq_client.read_table_incremental.call_args
|
||||||
|
assert call_kwargs.kwargs["incremental_column"] == "modified_at"
|
||||||
|
assert call_kwargs.kwargs["table_id"] == "project.dataset.orders"
|
||||||
|
# since_value should be an ISO string
|
||||||
|
assert "since_value" in call_kwargs.kwargs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 10. First sync without existing file downloads all data
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFirstSyncDownloadsAll:
|
||||||
|
|
||||||
|
def test_first_sync_reads_full_table(
|
||||||
|
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||||
|
):
|
||||||
|
"""On first incremental sync (no existing file), adapter should read all data."""
|
||||||
|
table_config = _make_table_config(
|
||||||
|
sync_strategy="incremental",
|
||||||
|
incremental_column="updated_at",
|
||||||
|
incremental_window_days=7,
|
||||||
|
)
|
||||||
|
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_config.get_parquet_path.return_value = parquet_path
|
||||||
|
|
||||||
|
# No previous sync, no existing file
|
||||||
|
arrow_data = _sample_arrow_table([1, 2, 3], ["A", "B", "C"])
|
||||||
|
mock_bq_client.read_table.return_value = arrow_data
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
result = adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["rows"] == 3
|
||||||
|
# Should call read_table (full), not read_table_incremental
|
||||||
|
mock_bq_client.read_table.assert_called_once_with(table_config.id)
|
||||||
|
mock_bq_client.read_table_incremental.assert_not_called()
|
||||||
|
|
||||||
|
def test_first_sync_with_max_history_days(
|
||||||
|
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||||
|
):
|
||||||
|
"""First sync with max_history_days should use read_table_incremental."""
|
||||||
|
table_config = _make_table_config(
|
||||||
|
sync_strategy="incremental",
|
||||||
|
incremental_column="updated_at",
|
||||||
|
incremental_window_days=7,
|
||||||
|
max_history_days=90,
|
||||||
|
)
|
||||||
|
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_config.get_parquet_path.return_value = parquet_path
|
||||||
|
|
||||||
|
arrow_data = _sample_arrow_table([1, 2], ["A", "B"])
|
||||||
|
mock_bq_client.read_table_incremental.return_value = arrow_data
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
result = adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
# Should use read_table_incremental (not read_table) because max_history_days is set
|
||||||
|
mock_bq_client.read_table_incremental.assert_called_once()
|
||||||
|
call_kwargs = mock_bq_client.read_table_incremental.call_args.kwargs
|
||||||
|
assert call_kwargs["incremental_column"] == "updated_at"
|
||||||
|
mock_bq_client.read_table.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 11. sync_table dispatches to correct strategy based on sync_strategy
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestSyncTableDispatch:
|
||||||
|
|
||||||
|
def test_dispatches_full_refresh(
|
||||||
|
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||||
|
):
|
||||||
|
"""sync_strategy='full_refresh' should call _full_refresh."""
|
||||||
|
table_config = _make_table_config(sync_strategy="full_refresh")
|
||||||
|
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
|
||||||
|
with patch.object(adapter, "_full_refresh", wraps=adapter._full_refresh) as spy:
|
||||||
|
adapter.sync_table(table_config, sync_state)
|
||||||
|
spy.assert_called_once_with(table_config)
|
||||||
|
|
||||||
|
def test_dispatches_incremental(
|
||||||
|
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||||
|
):
|
||||||
|
"""sync_strategy='incremental' should call _incremental_sync."""
|
||||||
|
table_config = _make_table_config(
|
||||||
|
sync_strategy="incremental",
|
||||||
|
incremental_column="updated_at",
|
||||||
|
incremental_window_days=7,
|
||||||
|
)
|
||||||
|
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
|
||||||
|
with patch.object(adapter, "_incremental_sync", wraps=adapter._incremental_sync) as spy:
|
||||||
|
adapter.sync_table(table_config, sync_state)
|
||||||
|
spy.assert_called_once_with(table_config, sync_state)
|
||||||
|
|
||||||
|
def test_dispatches_partitioned(
|
||||||
|
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||||
|
):
|
||||||
|
"""sync_strategy='incremental' with partition_by should call _partitioned_sync."""
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
table_config = _make_table_config(
|
||||||
|
sync_strategy="incremental",
|
||||||
|
incremental_column="created_at",
|
||||||
|
partition_by="created_at",
|
||||||
|
partition_granularity="month",
|
||||||
|
incremental_window_days=7,
|
||||||
|
)
|
||||||
|
partition_dir = tmp_parquet_dir / "orders"
|
||||||
|
partition_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
mock_config.get_parquet_path.return_value = partition_dir
|
||||||
|
|
||||||
|
def _partition_path(tc, key):
|
||||||
|
return partition_dir / f"{key}.parquet"
|
||||||
|
mock_config.get_partition_path.side_effect = _partition_path
|
||||||
|
|
||||||
|
ts = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")]
|
||||||
|
arrow_data = pa.table({
|
||||||
|
"id": [1],
|
||||||
|
"name": ["A"],
|
||||||
|
"created_at": pa.array(ts, type=pa.timestamp("us", tz="UTC")),
|
||||||
|
})
|
||||||
|
mock_bq_client.read_table.return_value = arrow_data
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
|
||||||
|
with patch.object(adapter, "_partitioned_sync", wraps=adapter._partitioned_sync) as spy:
|
||||||
|
adapter.sync_table(table_config, sync_state)
|
||||||
|
spy.assert_called_once()
|
||||||
|
|
||||||
|
def test_incremental_without_column_falls_back_to_full_refresh(
|
||||||
|
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||||
|
):
|
||||||
|
"""incremental strategy without incremental_column or partition_by falls back to full_refresh."""
|
||||||
|
table_config = _make_table_config(
|
||||||
|
sync_strategy="incremental",
|
||||||
|
incremental_column=None,
|
||||||
|
partition_by=None,
|
||||||
|
incremental_window_days=7,
|
||||||
|
)
|
||||||
|
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
|
||||||
|
with patch.object(adapter, "_full_refresh", wraps=adapter._full_refresh) as spy:
|
||||||
|
result = adapter.sync_table(table_config, sync_state)
|
||||||
|
spy.assert_called_once()
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 12. _merge_arrow_tables deduplicates correctly
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMergeArrowTables:
|
||||||
|
|
||||||
|
def test_dedup_on_single_pk(self, mock_config, mock_bq_client):
|
||||||
|
"""Merge should deduplicate on single primary key column, new data wins."""
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
|
||||||
|
existing = pa.table({"id": [1, 2, 3], "val": ["a", "b", "c"]})
|
||||||
|
new_data = pa.table({"id": [2, 4], "val": ["B_new", "d"]})
|
||||||
|
|
||||||
|
merged = adapter._merge_arrow_tables(existing, new_data, primary_key=["id"])
|
||||||
|
df = merged.to_pandas().sort_values("id").reset_index(drop=True)
|
||||||
|
|
||||||
|
assert list(df["id"]) == [1, 2, 3, 4]
|
||||||
|
assert list(df["val"]) == ["a", "B_new", "c", "d"]
|
||||||
|
|
||||||
|
def test_dedup_on_composite_pk(self, mock_config, mock_bq_client):
|
||||||
|
"""Merge should deduplicate on composite primary key."""
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
|
||||||
|
existing = pa.table({
|
||||||
|
"pk1": [1, 1, 2],
|
||||||
|
"pk2": ["a", "b", "a"],
|
||||||
|
"val": ["old_1a", "old_1b", "old_2a"],
|
||||||
|
})
|
||||||
|
new_data = pa.table({
|
||||||
|
"pk1": [1, 2],
|
||||||
|
"pk2": ["a", "a"],
|
||||||
|
"val": ["new_1a", "new_2a"],
|
||||||
|
})
|
||||||
|
|
||||||
|
merged = adapter._merge_arrow_tables(existing, new_data, primary_key=["pk1", "pk2"])
|
||||||
|
df = merged.to_pandas().sort_values(["pk1", "pk2"]).reset_index(drop=True)
|
||||||
|
|
||||||
|
assert len(df) == 3
|
||||||
|
# (1, a) should be updated
|
||||||
|
row_1a = df[(df["pk1"] == 1) & (df["pk2"] == "a")].iloc[0]
|
||||||
|
assert row_1a["val"] == "new_1a"
|
||||||
|
# (1, b) should be preserved
|
||||||
|
row_1b = df[(df["pk1"] == 1) & (df["pk2"] == "b")].iloc[0]
|
||||||
|
assert row_1b["val"] == "old_1b"
|
||||||
|
# (2, a) should be updated
|
||||||
|
row_2a = df[(df["pk1"] == 2) & (df["pk2"] == "a")].iloc[0]
|
||||||
|
assert row_2a["val"] == "new_2a"
|
||||||
|
|
||||||
|
def test_merge_with_empty_new_data(self, mock_config, mock_bq_client):
|
||||||
|
"""Merging with empty new data should return existing data unchanged."""
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
|
||||||
|
existing = pa.table({"id": [1, 2], "val": ["a", "b"]})
|
||||||
|
empty = pa.table({
|
||||||
|
"id": pa.array([], type=pa.int64()),
|
||||||
|
"val": pa.array([], type=pa.string()),
|
||||||
|
})
|
||||||
|
|
||||||
|
merged = adapter._merge_arrow_tables(existing, empty, primary_key=["id"])
|
||||||
|
assert merged.num_rows == 2
|
||||||
|
|
||||||
|
def test_merge_with_empty_existing(self, mock_config, mock_bq_client):
|
||||||
|
"""Merging with empty existing data should return new data."""
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
|
||||||
|
empty = pa.table({
|
||||||
|
"id": pa.array([], type=pa.int64()),
|
||||||
|
"val": pa.array([], type=pa.string()),
|
||||||
|
})
|
||||||
|
new_data = pa.table({"id": [1, 2], "val": ["a", "b"]})
|
||||||
|
|
||||||
|
merged = adapter._merge_arrow_tables(empty, new_data, primary_key=["id"])
|
||||||
|
assert merged.num_rows == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Additional edge cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMetadataCacheClearing:
|
||||||
|
|
||||||
|
def test_clears_metadata_cache_before_sync(
|
||||||
|
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||||
|
):
|
||||||
|
"""sync_table should clear the BQ metadata cache entry for the table being synced."""
|
||||||
|
table_config = _make_table_config()
|
||||||
|
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_config.get_parquet_path.return_value = parquet_path
|
||||||
|
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
|
||||||
|
|
||||||
|
# Pre-populate cache
|
||||||
|
mock_bq_client.metadata_cache[table_config.id] = {"some": "cached_data"}
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
assert table_config.id not in mock_bq_client.metadata_cache
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncStateUpdate:
|
||||||
|
|
||||||
|
def test_sync_state_updated_after_success(
|
||||||
|
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||||
|
):
|
||||||
|
"""After successful sync, the sync state should be updated with correct values."""
|
||||||
|
table_config = _make_table_config()
|
||||||
|
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_config.get_parquet_path.return_value = parquet_path
|
||||||
|
mock_bq_client.read_table.return_value = _sample_arrow_table([1, 2], ["A", "B"])
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
state = sync_state.get_table_state(table_config.id)
|
||||||
|
assert state["rows"] == 2
|
||||||
|
assert state["strategy"] == "full_refresh"
|
||||||
|
assert state["table_name"] == "orders"
|
||||||
|
assert "last_sync" in state
|
||||||
|
|
||||||
|
def test_sync_state_not_updated_on_failure(
|
||||||
|
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||||
|
):
|
||||||
|
"""On sync failure, the sync state should NOT be updated."""
|
||||||
|
table_config = _make_table_config()
|
||||||
|
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
|
||||||
|
mock_bq_client.read_table.side_effect = RuntimeError("boom")
|
||||||
|
|
||||||
|
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||||
|
adapter.sync_table(table_config, sync_state)
|
||||||
|
|
||||||
|
state = sync_state.get_table_state(table_config.id)
|
||||||
|
assert state == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateDataSourceFactory:
|
||||||
|
|
||||||
|
def test_factory_returns_adapter_instance(self, mock_config, mock_bq_client):
|
||||||
|
"""create_data_source() factory should return a BigQueryDataSource instance."""
|
||||||
|
with patch("connectors.bigquery.adapter.get_config", return_value=mock_config), \
|
||||||
|
patch("connectors.bigquery.adapter.create_bq_client", return_value=mock_bq_client):
|
||||||
|
from connectors.bigquery.adapter import create_data_source, BigQueryDataSource
|
||||||
|
instance = create_data_source()
|
||||||
|
assert isinstance(instance, BigQueryDataSource)
|
||||||
870
tests/test_bigquery_client.py
Normal file
870
tests/test_bigquery_client.py
Normal file
|
|
@ -0,0 +1,870 @@
|
||||||
|
"""Tests for the BigQuery client connector.
|
||||||
|
|
||||||
|
All external dependencies (google.cloud.bigquery, src.config) are mocked.
|
||||||
|
Tests cover initialization, metadata caching, schema building, query methods,
|
||||||
|
and connection testing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, mock_open, patch
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Pre-populate sys.modules with a mock google.cloud.bigquery if not installed,
|
||||||
|
# so the client module can be imported without the real SDK.
|
||||||
|
_bq_mock_installed = False
|
||||||
|
try:
|
||||||
|
from google.cloud import bigquery as _bq_test # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
_bq_mock_installed = True
|
||||||
|
_mock_bigquery = MagicMock()
|
||||||
|
# Expose commonly used classes as MagicMock so the client module
|
||||||
|
# can reference bigquery.Client, bigquery.QueryJobConfig, etc.
|
||||||
|
sys.modules.setdefault("google", MagicMock())
|
||||||
|
sys.modules.setdefault("google.cloud", MagicMock())
|
||||||
|
sys.modules.setdefault("google.cloud.bigquery", _mock_bigquery)
|
||||||
|
|
||||||
|
from connectors.bigquery.client import (
|
||||||
|
BIGQUERY_TO_PYARROW_TYPES,
|
||||||
|
BigQueryClient,
|
||||||
|
create_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import the real or mock bigquery reference used in the client module
|
||||||
|
from google.cloud import bigquery
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_bq_field(name: str, field_type: str, description: str = None):
|
||||||
|
"""Create a mock BigQuery SchemaField."""
|
||||||
|
field = MagicMock()
|
||||||
|
field.name = name
|
||||||
|
field.field_type = field_type
|
||||||
|
field.description = description
|
||||||
|
return field
|
||||||
|
|
||||||
|
|
||||||
|
def _make_table_ref(
|
||||||
|
table_id: str = "my-project.my_dataset.my_table",
|
||||||
|
schema=None,
|
||||||
|
num_rows: int = 1000,
|
||||||
|
num_bytes: int = 50000,
|
||||||
|
created: datetime = None,
|
||||||
|
modified: datetime = None,
|
||||||
|
time_partitioning=None,
|
||||||
|
):
|
||||||
|
"""Create a mock BigQuery Table reference object."""
|
||||||
|
table_ref = MagicMock()
|
||||||
|
table_ref.table_id = table_id.split(".")[-1]
|
||||||
|
table_ref.dataset_id = table_id.split(".")[1] if "." in table_id else "dataset"
|
||||||
|
table_ref.project = table_id.split(".")[0] if "." in table_id else "project"
|
||||||
|
table_ref.schema = schema or []
|
||||||
|
table_ref.num_rows = num_rows
|
||||||
|
table_ref.num_bytes = num_bytes
|
||||||
|
table_ref.created = created or datetime(2025, 1, 1, 12, 0, 0)
|
||||||
|
table_ref.modified = modified or datetime(2025, 6, 1, 12, 0, 0)
|
||||||
|
table_ref.time_partitioning = time_partitioning
|
||||||
|
return table_ref
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_config(tmp_path):
|
||||||
|
"""Mock get_config() to return a config with metadata path in tmp_path."""
|
||||||
|
config = MagicMock()
|
||||||
|
metadata_dir = tmp_path / "metadata"
|
||||||
|
metadata_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
config.get_metadata_path.return_value = metadata_dir
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_bq_client():
|
||||||
|
"""Create a mock BigQuery Client."""
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(mock_config, mock_bq_client):
|
||||||
|
"""Create a BigQueryClient instance with mocked dependencies."""
|
||||||
|
with (
|
||||||
|
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq_client),
|
||||||
|
patch("connectors.bigquery.client.get_config", return_value=mock_config),
|
||||||
|
patch.dict("os.environ", {"BIGQUERY_PROJECT": "test-project"}),
|
||||||
|
):
|
||||||
|
bq_client = BigQueryClient()
|
||||||
|
return bq_client
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 1. Init validates BIGQUERY_PROJECT env var
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestInit:
|
||||||
|
def test_raises_value_error_when_project_not_set(self, mock_config):
|
||||||
|
"""Init raises ValueError if project_id is None and env var is missing."""
|
||||||
|
with (
|
||||||
|
patch("connectors.bigquery.client.bigquery.Client"),
|
||||||
|
patch("connectors.bigquery.client.get_config", return_value=mock_config),
|
||||||
|
patch.dict("os.environ", {}, clear=True),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="BigQuery project ID not set"):
|
||||||
|
BigQueryClient()
|
||||||
|
|
||||||
|
def test_raises_value_error_when_project_empty_string(self, mock_config):
|
||||||
|
"""Init raises ValueError if BIGQUERY_PROJECT is set to empty string."""
|
||||||
|
with (
|
||||||
|
patch("connectors.bigquery.client.bigquery.Client"),
|
||||||
|
patch("connectors.bigquery.client.get_config", return_value=mock_config),
|
||||||
|
patch.dict("os.environ", {"BIGQUERY_PROJECT": ""}, clear=True),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="BigQuery project ID not set"):
|
||||||
|
BigQueryClient()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# 2. Init creates client with correct project_id
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_creates_client_with_env_project_id(self, mock_config):
|
||||||
|
"""Client uses BIGQUERY_PROJECT from environment."""
|
||||||
|
mock_bq = MagicMock()
|
||||||
|
with (
|
||||||
|
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq) as bq_cls,
|
||||||
|
patch("connectors.bigquery.client.get_config", return_value=mock_config),
|
||||||
|
patch.dict("os.environ", {"BIGQUERY_PROJECT": "env-project-123"}),
|
||||||
|
):
|
||||||
|
client = BigQueryClient()
|
||||||
|
bq_cls.assert_called_once_with(project="env-project-123")
|
||||||
|
assert client.project_id == "env-project-123"
|
||||||
|
|
||||||
|
def test_creates_client_with_explicit_project_id(self, mock_config):
|
||||||
|
"""Explicit project_id argument takes precedence over env var."""
|
||||||
|
mock_bq = MagicMock()
|
||||||
|
with (
|
||||||
|
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq) as bq_cls,
|
||||||
|
patch("connectors.bigquery.client.get_config", return_value=mock_config),
|
||||||
|
):
|
||||||
|
client = BigQueryClient(project_id="explicit-project")
|
||||||
|
bq_cls.assert_called_once_with(project="explicit-project")
|
||||||
|
assert client.project_id == "explicit-project"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 3. get_table_metadata fetches and caches metadata correctly
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetTableMetadata:
|
||||||
|
def test_fetches_metadata_from_bigquery(self, client, mock_bq_client):
|
||||||
|
"""get_table_metadata calls client.get_table and returns correct dict."""
|
||||||
|
table_id = "proj.dataset.orders"
|
||||||
|
schema = [
|
||||||
|
_make_bq_field("order_id", "INTEGER"),
|
||||||
|
_make_bq_field("customer_name", "STRING", description="Full name"),
|
||||||
|
_make_bq_field("created_at", "TIMESTAMP"),
|
||||||
|
]
|
||||||
|
table_ref = _make_table_ref(
|
||||||
|
table_id=table_id,
|
||||||
|
schema=schema,
|
||||||
|
num_rows=5000,
|
||||||
|
num_bytes=120000,
|
||||||
|
)
|
||||||
|
mock_bq_client.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
metadata = client.get_table_metadata(table_id, use_cache=False)
|
||||||
|
|
||||||
|
mock_bq_client.get_table.assert_called_once_with(table_id)
|
||||||
|
assert metadata["table_id"] == table_id
|
||||||
|
assert metadata["name"] == "orders"
|
||||||
|
assert metadata["dataset"] == "dataset"
|
||||||
|
assert metadata["project"] == "proj"
|
||||||
|
assert metadata["columns"] == ["order_id", "customer_name", "created_at"]
|
||||||
|
assert metadata["column_types"]["order_id"] == "INTEGER"
|
||||||
|
assert metadata["column_types"]["customer_name"] == "STRING"
|
||||||
|
assert metadata["column_types"]["created_at"] == "TIMESTAMP"
|
||||||
|
assert metadata["column_descriptions"]["customer_name"] == "Full name"
|
||||||
|
assert "order_id" not in metadata["column_descriptions"]
|
||||||
|
assert metadata["row_count"] == 5000
|
||||||
|
assert metadata["size_bytes"] == 120000
|
||||||
|
assert "_cached_at" in metadata
|
||||||
|
|
||||||
|
def test_caches_metadata_in_memory(self, client, mock_bq_client):
|
||||||
|
"""After first fetch, metadata is stored in the in-memory cache."""
|
||||||
|
table_id = "proj.dataset.tbl"
|
||||||
|
table_ref = _make_table_ref(table_id=table_id)
|
||||||
|
mock_bq_client.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
client.get_table_metadata(table_id, use_cache=False)
|
||||||
|
|
||||||
|
assert table_id in client.metadata_cache
|
||||||
|
assert client.metadata_cache[table_id]["table_id"] == table_id
|
||||||
|
|
||||||
|
def test_captures_partitioning_info(self, client, mock_bq_client):
|
||||||
|
"""Partitioning metadata is captured when table is partitioned."""
|
||||||
|
table_id = "proj.dataset.events"
|
||||||
|
partition = MagicMock()
|
||||||
|
partition.type_ = "DAY"
|
||||||
|
partition.field = "event_date"
|
||||||
|
partition.expiration_ms = 7776000000
|
||||||
|
|
||||||
|
table_ref = _make_table_ref(table_id=table_id, time_partitioning=partition)
|
||||||
|
mock_bq_client.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
metadata = client.get_table_metadata(table_id, use_cache=False)
|
||||||
|
|
||||||
|
assert metadata["partitioning"] is not None
|
||||||
|
assert metadata["partitioning"]["type"] == "DAY"
|
||||||
|
assert metadata["partitioning"]["field"] == "event_date"
|
||||||
|
assert metadata["partitioning"]["expiration_ms"] == 7776000000
|
||||||
|
|
||||||
|
def test_no_partitioning_when_absent(self, client, mock_bq_client):
|
||||||
|
"""Partitioning is None when table has no partitioning."""
|
||||||
|
table_id = "proj.dataset.simple"
|
||||||
|
table_ref = _make_table_ref(table_id=table_id, time_partitioning=None)
|
||||||
|
mock_bq_client.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
metadata = client.get_table_metadata(table_id, use_cache=False)
|
||||||
|
assert metadata["partitioning"] is None
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
# 4. get_table_metadata uses cache when available (within TTL)
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_uses_cache_within_ttl(self, client, mock_bq_client):
|
||||||
|
"""When cache is fresh (within TTL), BQ API is not called again."""
|
||||||
|
table_id = "proj.dataset.cached_tbl"
|
||||||
|
now = datetime.now()
|
||||||
|
client.metadata_cache[table_id] = {
|
||||||
|
"table_id": table_id,
|
||||||
|
"columns": ["a", "b"],
|
||||||
|
"column_types": {"a": "STRING", "b": "INTEGER"},
|
||||||
|
"_cached_at": now.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
result = client.get_table_metadata(table_id, use_cache=True, cache_ttl_hours=24)
|
||||||
|
|
||||||
|
mock_bq_client.get_table.assert_not_called()
|
||||||
|
assert result["table_id"] == table_id
|
||||||
|
assert result["columns"] == ["a", "b"]
|
||||||
|
|
||||||
|
def test_refetches_when_cache_expired(self, client, mock_bq_client):
|
||||||
|
"""When cache is older than TTL, metadata is re-fetched from BQ."""
|
||||||
|
table_id = "proj.dataset.stale_tbl"
|
||||||
|
old_time = (datetime.now() - timedelta(hours=48)).isoformat()
|
||||||
|
client.metadata_cache[table_id] = {
|
||||||
|
"table_id": table_id,
|
||||||
|
"columns": ["old_col"],
|
||||||
|
"column_types": {"old_col": "STRING"},
|
||||||
|
"_cached_at": old_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
table_ref = _make_table_ref(
|
||||||
|
table_id=table_id,
|
||||||
|
schema=[_make_bq_field("new_col", "INTEGER")],
|
||||||
|
)
|
||||||
|
mock_bq_client.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
result = client.get_table_metadata(table_id, use_cache=True, cache_ttl_hours=24)
|
||||||
|
|
||||||
|
mock_bq_client.get_table.assert_called_once_with(table_id)
|
||||||
|
assert result["columns"] == ["new_col"]
|
||||||
|
|
||||||
|
def test_bypasses_cache_when_use_cache_false(self, client, mock_bq_client):
|
||||||
|
"""When use_cache=False, always fetches from BQ even if cache is fresh."""
|
||||||
|
table_id = "proj.dataset.force_fetch"
|
||||||
|
client.metadata_cache[table_id] = {
|
||||||
|
"table_id": table_id,
|
||||||
|
"columns": ["cached"],
|
||||||
|
"column_types": {"cached": "STRING"},
|
||||||
|
"_cached_at": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
table_ref = _make_table_ref(
|
||||||
|
table_id=table_id,
|
||||||
|
schema=[_make_bq_field("fresh", "INTEGER")],
|
||||||
|
)
|
||||||
|
mock_bq_client.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
result = client.get_table_metadata(table_id, use_cache=False)
|
||||||
|
mock_bq_client.get_table.assert_called_once()
|
||||||
|
assert result["columns"] == ["fresh"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 5. get_pyarrow_schema builds correct schema from BQ types
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetPyarrowSchema:
|
||||||
|
def test_builds_correct_schema(self, client, mock_bq_client):
|
||||||
|
"""Schema maps BQ types to correct PyArrow types."""
|
||||||
|
table_id = "proj.dataset.typed_tbl"
|
||||||
|
schema = [
|
||||||
|
_make_bq_field("id", "INT64"),
|
||||||
|
_make_bq_field("name", "STRING"),
|
||||||
|
_make_bq_field("price", "FLOAT64"),
|
||||||
|
_make_bq_field("active", "BOOLEAN"),
|
||||||
|
_make_bq_field("created", "DATE"),
|
||||||
|
_make_bq_field("updated_at", "TIMESTAMP"),
|
||||||
|
]
|
||||||
|
table_ref = _make_table_ref(table_id=table_id, schema=schema)
|
||||||
|
mock_bq_client.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
pa_schema = client.get_pyarrow_schema(table_id)
|
||||||
|
|
||||||
|
assert pa_schema is not None
|
||||||
|
assert pa_schema.field("id").type == pa.int64()
|
||||||
|
assert pa_schema.field("name").type == pa.string()
|
||||||
|
assert pa_schema.field("price").type == pa.float64()
|
||||||
|
assert pa_schema.field("active").type == pa.bool_()
|
||||||
|
assert pa_schema.field("created").type == pa.date32()
|
||||||
|
assert pa_schema.field("updated_at").type == pa.timestamp("us", tz="UTC")
|
||||||
|
|
||||||
|
def test_returns_none_when_no_column_types(self, client):
|
||||||
|
"""Returns None when metadata has no column_types."""
|
||||||
|
table_id = "proj.dataset.empty_schema"
|
||||||
|
client.metadata_cache[table_id] = {
|
||||||
|
"table_id": table_id,
|
||||||
|
"columns": [],
|
||||||
|
"column_types": {},
|
||||||
|
"_cached_at": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
result = client.get_pyarrow_schema(table_id)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_unknown_type_falls_back_to_string(self, client, mock_bq_client):
|
||||||
|
"""Unknown BQ types default to pa.string() in the schema."""
|
||||||
|
table_id = "proj.dataset.exotic_types"
|
||||||
|
schema = [_make_bq_field("exotic_col", "SOME_UNKNOWN_TYPE")]
|
||||||
|
table_ref = _make_table_ref(table_id=table_id, schema=schema)
|
||||||
|
mock_bq_client.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
pa_schema = client.get_pyarrow_schema(table_id)
|
||||||
|
assert pa_schema.field("exotic_col").type == pa.string()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 6. get_date_columns returns only DATE columns
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetDateColumns:
|
||||||
|
def test_returns_only_date_columns(self, client, mock_bq_client):
|
||||||
|
"""Only columns with BQ type DATE are returned."""
|
||||||
|
table_id = "proj.dataset.mixed_dates"
|
||||||
|
schema = [
|
||||||
|
_make_bq_field("event_date", "DATE"),
|
||||||
|
_make_bq_field("created_at", "TIMESTAMP"),
|
||||||
|
_make_bq_field("name", "STRING"),
|
||||||
|
_make_bq_field("birth_date", "DATE"),
|
||||||
|
_make_bq_field("updated_ts", "DATETIME"),
|
||||||
|
]
|
||||||
|
table_ref = _make_table_ref(table_id=table_id, schema=schema)
|
||||||
|
mock_bq_client.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
date_cols = client.get_date_columns(table_id)
|
||||||
|
assert sorted(date_cols) == ["birth_date", "event_date"]
|
||||||
|
|
||||||
|
def test_returns_empty_when_no_date_columns(self, client, mock_bq_client):
|
||||||
|
"""Returns empty list when no DATE columns exist."""
|
||||||
|
table_id = "proj.dataset.no_dates"
|
||||||
|
schema = [
|
||||||
|
_make_bq_field("id", "INTEGER"),
|
||||||
|
_make_bq_field("ts", "TIMESTAMP"),
|
||||||
|
]
|
||||||
|
table_ref = _make_table_ref(table_id=table_id, schema=schema)
|
||||||
|
mock_bq_client.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
date_cols = client.get_date_columns(table_id)
|
||||||
|
assert date_cols == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 7. query_to_arrow executes SQL and returns PyArrow table
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestQueryToArrow:
|
||||||
|
def test_executes_query_and_returns_arrow(self, client, mock_bq_client):
|
||||||
|
"""query_to_arrow passes SQL to BQ and returns the arrow result."""
|
||||||
|
expected_table = pa.table({"col1": [1, 2, 3]})
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = expected_table
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||||
|
mock_bq_module.QueryJobConfig.return_value = MagicMock(query_parameters=None)
|
||||||
|
client.client = mock_bq_client
|
||||||
|
|
||||||
|
result = client.query_to_arrow("SELECT * FROM `proj.dataset.tbl`")
|
||||||
|
|
||||||
|
mock_bq_client.query.assert_called_once()
|
||||||
|
call_args = mock_bq_client.query.call_args
|
||||||
|
assert call_args[0][0] == "SELECT * FROM `proj.dataset.tbl`"
|
||||||
|
assert result.equals(expected_table)
|
||||||
|
|
||||||
|
def test_passes_query_parameters(self, client, mock_bq_client):
|
||||||
|
"""query_to_arrow forwards BQ query parameters in job config."""
|
||||||
|
expected_table = pa.table({"col1": [10]})
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = expected_table
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
mock_job_config = MagicMock()
|
||||||
|
params = [MagicMock()] # Mock ScalarQueryParameter
|
||||||
|
|
||||||
|
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||||
|
mock_bq_module.QueryJobConfig.return_value = mock_job_config
|
||||||
|
client.client = mock_bq_client
|
||||||
|
|
||||||
|
client.query_to_arrow("SELECT 1 WHERE x > @val", params=params)
|
||||||
|
|
||||||
|
# Verify params were set on the job config
|
||||||
|
assert mock_job_config.query_parameters == params
|
||||||
|
|
||||||
|
def test_no_params_does_not_set_query_parameters(self, client, mock_bq_client):
|
||||||
|
"""When no params given, query_parameters is not set on job config."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = pa.table({"x": [1]})
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
mock_job_config = MagicMock(spec=[])
|
||||||
|
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||||
|
mock_bq_module.QueryJobConfig.return_value = mock_job_config
|
||||||
|
client.client = mock_bq_client
|
||||||
|
|
||||||
|
client.query_to_arrow("SELECT 1")
|
||||||
|
|
||||||
|
# query_parameters should not have been set
|
||||||
|
assert not hasattr(mock_job_config, "query_parameters") or not getattr(
|
||||||
|
mock_job_config, "query_parameters", None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 8. read_table builds correct SQL query
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReadTable:
|
||||||
|
def test_full_table_select_all(self, client, mock_bq_client):
|
||||||
|
"""read_table with no columns or filter generates SELECT *."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = pa.table({"a": [1]})
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
client.read_table("proj.dataset.tbl")
|
||||||
|
|
||||||
|
sql = mock_bq_client.query.call_args[0][0]
|
||||||
|
assert "SELECT *" in sql
|
||||||
|
assert "`proj.dataset.tbl`" in sql
|
||||||
|
assert "WHERE" not in sql
|
||||||
|
|
||||||
|
def test_select_specific_columns(self, client, mock_bq_client):
|
||||||
|
"""read_table with columns list generates SELECT with backtick-quoted names."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = pa.table({"a": [1]})
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
client.read_table("proj.dataset.tbl", columns=["col_a", "col_b"])
|
||||||
|
|
||||||
|
sql = mock_bq_client.query.call_args[0][0]
|
||||||
|
assert "`col_a`" in sql
|
||||||
|
assert "`col_b`" in sql
|
||||||
|
assert "*" not in sql
|
||||||
|
|
||||||
|
def test_with_row_filter(self, client, mock_bq_client):
|
||||||
|
"""read_table with row_filter appends WHERE clause."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = pa.table({"a": [1]})
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
client.read_table("proj.dataset.tbl", row_filter="status = 'active'")
|
||||||
|
|
||||||
|
sql = mock_bq_client.query.call_args[0][0]
|
||||||
|
assert "WHERE status = 'active'" in sql
|
||||||
|
|
||||||
|
def test_columns_and_filter_combined(self, client, mock_bq_client):
|
||||||
|
"""read_table with both columns and filter generates correct SQL."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = pa.table({"x": [1]})
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
client.read_table(
|
||||||
|
"proj.dataset.tbl",
|
||||||
|
columns=["id", "name"],
|
||||||
|
row_filter="id > 100",
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = mock_bq_client.query.call_args[0][0]
|
||||||
|
assert "`id`, `name`" in sql
|
||||||
|
assert "WHERE id > 100" in sql
|
||||||
|
assert "`proj.dataset.tbl`" in sql
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 9. read_table_incremental builds parameterized WHERE clause
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReadTableIncremental:
|
||||||
|
def test_incremental_query_structure(self, client, mock_bq_client):
|
||||||
|
"""read_table_incremental builds WHERE col > @since_value with params."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = pa.table({"a": [1]})
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||||
|
mock_bq_module.QueryJobConfig.return_value = MagicMock()
|
||||||
|
mock_param = MagicMock()
|
||||||
|
mock_bq_module.ScalarQueryParameter.return_value = mock_param
|
||||||
|
# Re-assign the client's bq client (the fixture already set it up)
|
||||||
|
client.client = mock_bq_client
|
||||||
|
|
||||||
|
client.read_table_incremental(
|
||||||
|
table_id="proj.dataset.events",
|
||||||
|
incremental_column="updated_at",
|
||||||
|
since_value="2025-01-01T00:00:00Z",
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = mock_bq_client.query.call_args[0][0]
|
||||||
|
assert "SELECT *" in sql
|
||||||
|
assert "`proj.dataset.events`" in sql
|
||||||
|
assert "`updated_at` > @since_value" in sql
|
||||||
|
|
||||||
|
# Verify ScalarQueryParameter was constructed correctly
|
||||||
|
mock_bq_module.ScalarQueryParameter.assert_called_once_with(
|
||||||
|
"since_value", "TIMESTAMP", "2025-01-01T00:00:00Z"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_incremental_with_columns(self, client, mock_bq_client):
|
||||||
|
"""read_table_incremental with columns list selects specific columns."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = pa.table({"a": [1]})
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||||
|
mock_bq_module.QueryJobConfig.return_value = MagicMock()
|
||||||
|
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
|
||||||
|
client.client = mock_bq_client
|
||||||
|
|
||||||
|
client.read_table_incremental(
|
||||||
|
table_id="proj.dataset.events",
|
||||||
|
incremental_column="updated_at",
|
||||||
|
since_value="2025-01-01T00:00:00Z",
|
||||||
|
columns=["id", "name"],
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = mock_bq_client.query.call_args[0][0]
|
||||||
|
assert "`id`, `name`" in sql
|
||||||
|
assert "*" not in sql
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 10. read_table_partitioned builds correct range query
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReadTablePartitioned:
|
||||||
|
def test_partitioned_start_only(self, client, mock_bq_client):
|
||||||
|
"""With only start, generates >= @start_value without end clause."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = pa.table({"a": [1]})
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||||
|
mock_bq_module.QueryJobConfig.return_value = MagicMock()
|
||||||
|
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
|
||||||
|
client.client = mock_bq_client
|
||||||
|
|
||||||
|
client.read_table_partitioned(
|
||||||
|
table_id="proj.dataset.events",
|
||||||
|
partition_column="event_date",
|
||||||
|
start="2025-01-01",
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = mock_bq_client.query.call_args[0][0]
|
||||||
|
assert "`event_date` >= @start_value" in sql
|
||||||
|
assert "@end_value" not in sql
|
||||||
|
|
||||||
|
# Only start_value parameter created
|
||||||
|
assert mock_bq_module.ScalarQueryParameter.call_count == 1
|
||||||
|
mock_bq_module.ScalarQueryParameter.assert_called_with(
|
||||||
|
"start_value", "TIMESTAMP", "2025-01-01"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_partitioned_start_and_end(self, client, mock_bq_client):
|
||||||
|
"""With start and end, generates >= @start_value AND < @end_value."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = pa.table({"a": [1]})
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||||
|
mock_bq_module.QueryJobConfig.return_value = MagicMock()
|
||||||
|
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
|
||||||
|
client.client = mock_bq_client
|
||||||
|
|
||||||
|
client.read_table_partitioned(
|
||||||
|
table_id="proj.dataset.events",
|
||||||
|
partition_column="event_date",
|
||||||
|
start="2025-01-01",
|
||||||
|
end="2025-06-01",
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = mock_bq_client.query.call_args[0][0]
|
||||||
|
assert "`event_date` >= @start_value" in sql
|
||||||
|
assert "`event_date` < @end_value" in sql
|
||||||
|
|
||||||
|
# Both start_value and end_value parameters created
|
||||||
|
assert mock_bq_module.ScalarQueryParameter.call_count == 2
|
||||||
|
calls = mock_bq_module.ScalarQueryParameter.call_args_list
|
||||||
|
assert calls[0].args == ("start_value", "TIMESTAMP", "2025-01-01")
|
||||||
|
assert calls[1].args == ("end_value", "TIMESTAMP", "2025-06-01")
|
||||||
|
|
||||||
|
def test_partitioned_with_columns(self, client, mock_bq_client):
|
||||||
|
"""read_table_partitioned with columns selects specific columns."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.to_arrow.return_value = pa.table({"a": [1]})
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||||
|
mock_bq_module.QueryJobConfig.return_value = MagicMock()
|
||||||
|
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
|
||||||
|
client.client = mock_bq_client
|
||||||
|
|
||||||
|
client.read_table_partitioned(
|
||||||
|
table_id="proj.dataset.events",
|
||||||
|
partition_column="event_date",
|
||||||
|
start="2025-01-01",
|
||||||
|
columns=["id", "event_date", "value"],
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = mock_bq_client.query.call_args[0][0]
|
||||||
|
assert "`id`, `event_date`, `value`" in sql
|
||||||
|
assert "*" not in sql
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 11. test_connection returns True on success, False on failure
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTestConnection:
|
||||||
|
def test_returns_true_on_success(self, client, mock_bq_client):
|
||||||
|
"""test_connection returns True when SELECT 1 query succeeds."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.result.return_value = iter([(1,)])
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
assert client.test_connection() is True
|
||||||
|
mock_bq_client.query.assert_called_once_with("SELECT 1")
|
||||||
|
|
||||||
|
def test_returns_false_on_failure(self, client, mock_bq_client):
|
||||||
|
"""test_connection returns False when the query raises an exception."""
|
||||||
|
mock_bq_client.query.side_effect = Exception("Connection refused")
|
||||||
|
|
||||||
|
assert client.test_connection() is False
|
||||||
|
|
||||||
|
def test_returns_false_when_result_fails(self, client, mock_bq_client):
|
||||||
|
"""test_connection returns False when result iteration fails."""
|
||||||
|
mock_job = MagicMock()
|
||||||
|
mock_job.result.side_effect = Exception("Timeout")
|
||||||
|
mock_bq_client.query.return_value = mock_job
|
||||||
|
|
||||||
|
assert client.test_connection() is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 12. Type mapping completeness (all BQ types have PyArrow mapping)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTypeMapping:
|
||||||
|
# All standard BigQuery types that should be mapped
|
||||||
|
EXPECTED_BQ_TYPES = [
|
||||||
|
"STRING", "BYTES", "INTEGER", "INT64",
|
||||||
|
"FLOAT", "FLOAT64", "NUMERIC", "BIGNUMERIC",
|
||||||
|
"BOOLEAN", "BOOL",
|
||||||
|
"TIMESTAMP", "DATE", "TIME", "DATETIME",
|
||||||
|
"GEOGRAPHY", "JSON",
|
||||||
|
"STRUCT", "RECORD", "ARRAY",
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_all_standard_bq_types_are_mapped(self):
|
||||||
|
"""Every standard BigQuery type has an entry in BIGQUERY_TO_PYARROW_TYPES."""
|
||||||
|
for bq_type in self.EXPECTED_BQ_TYPES:
|
||||||
|
assert bq_type in BIGQUERY_TO_PYARROW_TYPES, (
|
||||||
|
f"Missing PyArrow mapping for BQ type: {bq_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_all_mappings_produce_valid_pyarrow_types(self):
|
||||||
|
"""Every mapped value is a valid PyArrow DataType."""
|
||||||
|
for bq_type, pa_type in BIGQUERY_TO_PYARROW_TYPES.items():
|
||||||
|
assert isinstance(pa_type, pa.DataType), (
|
||||||
|
f"BQ type {bq_type} maps to non-DataType: {pa_type!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_integer_types_map_to_int64(self):
|
||||||
|
"""Both INTEGER and INT64 map to pa.int64()."""
|
||||||
|
assert BIGQUERY_TO_PYARROW_TYPES["INTEGER"] == pa.int64()
|
||||||
|
assert BIGQUERY_TO_PYARROW_TYPES["INT64"] == pa.int64()
|
||||||
|
|
||||||
|
def test_float_types_map_to_float64(self):
|
||||||
|
"""FLOAT, FLOAT64, NUMERIC, BIGNUMERIC all map to pa.float64()."""
|
||||||
|
for t in ["FLOAT", "FLOAT64", "NUMERIC", "BIGNUMERIC"]:
|
||||||
|
assert BIGQUERY_TO_PYARROW_TYPES[t] == pa.float64()
|
||||||
|
|
||||||
|
def test_boolean_types_map_to_bool(self):
|
||||||
|
"""Both BOOLEAN and BOOL map to pa.bool_()."""
|
||||||
|
assert BIGQUERY_TO_PYARROW_TYPES["BOOLEAN"] == pa.bool_()
|
||||||
|
assert BIGQUERY_TO_PYARROW_TYPES["BOOL"] == pa.bool_()
|
||||||
|
|
||||||
|
def test_date_maps_to_date32(self):
|
||||||
|
"""DATE maps to pa.date32()."""
|
||||||
|
assert BIGQUERY_TO_PYARROW_TYPES["DATE"] == pa.date32()
|
||||||
|
|
||||||
|
def test_timestamp_has_utc_timezone(self):
|
||||||
|
"""TIMESTAMP maps to pa.timestamp with UTC timezone."""
|
||||||
|
ts_type = BIGQUERY_TO_PYARROW_TYPES["TIMESTAMP"]
|
||||||
|
assert ts_type == pa.timestamp("us", tz="UTC")
|
||||||
|
|
||||||
|
def test_datetime_has_no_timezone(self):
|
||||||
|
"""DATETIME maps to pa.timestamp without timezone."""
|
||||||
|
dt_type = BIGQUERY_TO_PYARROW_TYPES["DATETIME"]
|
||||||
|
assert dt_type == pa.timestamp("us")
|
||||||
|
|
||||||
|
def test_complex_types_map_to_string(self):
|
||||||
|
"""STRUCT, RECORD, ARRAY, GEOGRAPHY, JSON all serialize as string."""
|
||||||
|
for t in ["STRUCT", "RECORD", "ARRAY", "GEOGRAPHY", "JSON"]:
|
||||||
|
assert BIGQUERY_TO_PYARROW_TYPES[t] == pa.string()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 13. Metadata cache save/load from disk
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMetadataCachePersistence:
|
||||||
|
def test_save_and_load_cache(self, tmp_path):
|
||||||
|
"""Metadata cache is persisted to disk and reloaded on new client init."""
|
||||||
|
metadata_dir = tmp_path / "metadata"
|
||||||
|
metadata_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
cache_file = metadata_dir / "bq_table_metadata.json"
|
||||||
|
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.get_metadata_path.return_value = metadata_dir
|
||||||
|
|
||||||
|
# First client: fetch metadata and save to cache
|
||||||
|
mock_bq = MagicMock()
|
||||||
|
table_id = "proj.ds.tbl"
|
||||||
|
schema = [_make_bq_field("col1", "STRING")]
|
||||||
|
table_ref = _make_table_ref(table_id=table_id, schema=schema)
|
||||||
|
mock_bq.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq),
|
||||||
|
patch("connectors.bigquery.client.get_config", return_value=mock_config),
|
||||||
|
patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}),
|
||||||
|
):
|
||||||
|
client1 = BigQueryClient()
|
||||||
|
client1.get_table_metadata(table_id, use_cache=False)
|
||||||
|
|
||||||
|
# Verify the cache file was written
|
||||||
|
assert cache_file.exists()
|
||||||
|
saved_data = json.loads(cache_file.read_text())
|
||||||
|
assert table_id in saved_data
|
||||||
|
assert saved_data[table_id]["columns"] == ["col1"]
|
||||||
|
|
||||||
|
# Second client: loads cache from disk on init
|
||||||
|
mock_bq2 = MagicMock()
|
||||||
|
with (
|
||||||
|
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq2),
|
||||||
|
patch("connectors.bigquery.client.get_config", return_value=mock_config),
|
||||||
|
patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}),
|
||||||
|
):
|
||||||
|
client2 = BigQueryClient()
|
||||||
|
|
||||||
|
assert table_id in client2.metadata_cache
|
||||||
|
assert client2.metadata_cache[table_id]["columns"] == ["col1"]
|
||||||
|
|
||||||
|
def test_load_handles_corrupt_cache_file(self, tmp_path):
|
||||||
|
"""Client handles corrupt cache JSON gracefully without crashing."""
|
||||||
|
metadata_dir = tmp_path / "metadata"
|
||||||
|
metadata_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
cache_file = metadata_dir / "bq_table_metadata.json"
|
||||||
|
cache_file.write_text("{corrupt json!!!")
|
||||||
|
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.get_metadata_path.return_value = metadata_dir
|
||||||
|
|
||||||
|
mock_bq = MagicMock()
|
||||||
|
with (
|
||||||
|
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq),
|
||||||
|
patch("connectors.bigquery.client.get_config", return_value=mock_config),
|
||||||
|
patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}),
|
||||||
|
):
|
||||||
|
client = BigQueryClient()
|
||||||
|
|
||||||
|
# Cache should be empty after corrupt file
|
||||||
|
assert client.metadata_cache == {}
|
||||||
|
|
||||||
|
def test_load_handles_missing_cache_file(self, tmp_path):
|
||||||
|
"""Client initializes with empty cache when no cache file exists."""
|
||||||
|
metadata_dir = tmp_path / "metadata"
|
||||||
|
metadata_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# No cache file created
|
||||||
|
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.get_metadata_path.return_value = metadata_dir
|
||||||
|
|
||||||
|
mock_bq = MagicMock()
|
||||||
|
with (
|
||||||
|
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq),
|
||||||
|
patch("connectors.bigquery.client.get_config", return_value=mock_config),
|
||||||
|
patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}),
|
||||||
|
):
|
||||||
|
client = BigQueryClient()
|
||||||
|
|
||||||
|
assert client.metadata_cache == {}
|
||||||
|
|
||||||
|
def test_save_creates_parent_directories(self, tmp_path):
|
||||||
|
"""_save_metadata_cache creates parent directories if they do not exist."""
|
||||||
|
# Use a nested path that does not yet exist
|
||||||
|
metadata_dir = tmp_path / "deep" / "nested" / "metadata"
|
||||||
|
# Do NOT create directories upfront
|
||||||
|
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.get_metadata_path.return_value = metadata_dir
|
||||||
|
|
||||||
|
mock_bq = MagicMock()
|
||||||
|
table_id = "proj.ds.tbl"
|
||||||
|
schema = [_make_bq_field("x", "INTEGER")]
|
||||||
|
table_ref = _make_table_ref(table_id=table_id, schema=schema)
|
||||||
|
mock_bq.get_table.return_value = table_ref
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq),
|
||||||
|
patch("connectors.bigquery.client.get_config", return_value=mock_config),
|
||||||
|
patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}),
|
||||||
|
):
|
||||||
|
client = BigQueryClient()
|
||||||
|
client.get_table_metadata(table_id, use_cache=False)
|
||||||
|
|
||||||
|
cache_file = metadata_dir / "bq_table_metadata.json"
|
||||||
|
assert cache_file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Factory function
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCreateClient:
|
||||||
|
def test_create_client_returns_bigquery_client(self, mock_config):
|
||||||
|
"""create_client() factory returns a BigQueryClient instance."""
|
||||||
|
mock_bq = MagicMock()
|
||||||
|
with (
|
||||||
|
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq),
|
||||||
|
patch("connectors.bigquery.client.get_config", return_value=mock_config),
|
||||||
|
patch.dict("os.environ", {"BIGQUERY_PROJECT": "factory-project"}),
|
||||||
|
):
|
||||||
|
result = create_client()
|
||||||
|
|
||||||
|
assert isinstance(result, BigQueryClient)
|
||||||
|
assert result.project_id == "factory-project"
|
||||||
Loading…
Reference in a new issue