Add BigQuery data source adapter

BigQuery connector that syncs BQ tables to local Parquet files via PyArrow
(no CSV intermediate step). Supports full refresh, timestamp-based
incremental (via incremental_column), and partition-based sync strategies.

- connectors/bigquery/client.py: BQ API wrapper with ADC auth, parameterized
  queries, metadata cache, cross-project support (job project != data project)
- connectors/bigquery/adapter.py: DataSource implementation with merge/dedup
- src/config.py: Add incremental_column field to TableConfig
- 72 unit tests (mocked, no GCP SDK required)
This commit is contained in:
Petr 2026-03-11 13:56:12 +01:00
parent eb5264b903
commit 758910463b
9 changed files with 2619 additions and 2 deletions

View file

@ -44,13 +44,38 @@ auth:
google_client_id: "${GOOGLE_CLIENT_ID}"
google_client_secret: "${GOOGLE_CLIENT_SECRET}"
# --- Theme (optional) ---
# Customize colors, fonts, and shape to match your brand.
# All values are optional - defaults provide a clean blue theme.
# See docs/theme-reference.html for a visual guide.
theme:
# primary: "#0073D1" # Main brand color (buttons, links, accents)
# primary_dark: "#005BA3" # Hover/active state of primary
# primary_light: "rgba(0, 115, 209, 0.1)" # Light tint backgrounds
# text_primary: "#1A253C" # Main text color
# text_secondary: "#6B7280" # Muted/secondary text
# background: "#F5F7FA" # Page background
# surface: "#FFFFFF" # Card/panel background
# border: "#E5E7EB" # Borders and dividers
# font_primary: "'Inter', system-ui, sans-serif"
# font_url: "https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap"
# radius: "6px" # Border radius (cards, buttons, inputs)
# success: "#10B77F"
# warning: "#F59F0A"
# error: "#EA580C"
# --- Data source ---
data_source:
type: "keboola" # keboola | csv (bigquery planned)
type: "keboola" # keboola | bigquery | local
keboola:
storage_token: "${KEBOOLA_STORAGE_TOKEN}"
stack_url: "" # e.g., "https://connection.keboola.com"
project_id: ""
bigquery:
project: "${BIGQUERY_PROJECT}" # GCP project for job execution/billing
location: "${BIGQUERY_LOCATION}" # BigQuery location (e.g., "us-central1", "US")
# Uses ADC (Application Default Credentials) - VM service account on GCP
# Data can live in a different project -- use fully-qualified table IDs in data_description.md
# --- Email delivery (optional, for magic link auth) ---
# Without SMTP, magic links are shown directly in browser (development mode).

View file

@ -0,0 +1,11 @@
"""
BigQuery connector - data source adapter for Google BigQuery.
Syncs tables from BigQuery using the BigQuery Storage API,
converting query results directly to Parquet files via PyArrow
(no CSV intermediate step).
Enable by setting data_source.type: "bigquery" in config/instance.yaml
and providing BIGQUERY_PROJECT environment variable.
Uses Application Default Credentials (ADC) for authentication.
"""

View file

@ -0,0 +1,475 @@
"""
BigQuery data source adapter.
Implements the DataSource interface for Google BigQuery.
Reads tables via the BigQuery API, converts directly to Parquet files
using PyArrow (no CSV intermediate step).
"""
import logging
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import pyarrow as pa
import pyarrow.parquet as pq
from src.config import get_config, TableConfig
from src.data_sync import DataSource, SyncState, _get_uncompressed_size
from src.parquet_manager import (
convert_date_columns_to_date32,
apply_schema_to_table,
)
from .client import create_client as create_bq_client
logger = logging.getLogger(__name__)
class BigQueryDataSource(DataSource):
"""
Data source: Google BigQuery.
Downloads data directly from BigQuery via PyArrow (no CSV step),
writes to local Parquet files with schema enforcement.
"""
def __init__(self):
"""Initialize BigQuery source with env var validation."""
self.config = get_config()
self.bq_client = create_bq_client()
def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]:
"""Return BigQuery column metadata for schema generation.
Returns:
{"columns": {"col_name": {"source_type": "...", "description": "..."}}}
or None if metadata unavailable.
"""
raw = self.bq_client.get_table_metadata(table_id)
column_types = raw.get("column_types", {})
column_descriptions = raw.get("column_descriptions", {})
if not column_types:
return None
result = {}
for col_name, bq_type in column_types.items():
entry = {"source_type": bq_type}
if col_name in column_descriptions:
entry["description"] = column_descriptions[col_name]
result[col_name] = entry
return {"columns": result}
def discover_tables(self) -> List[Dict[str, Any]]:
"""Discover all available tables from BigQuery."""
return self.bq_client.discover_all_tables()
def get_source_name(self) -> str:
"""Display name of this data source."""
return "Google BigQuery"
def sync_table(
self,
table_config: TableConfig,
sync_state: SyncState,
) -> Dict[str, Any]:
"""
Synchronize table from BigQuery.
Dispatches to the appropriate strategy based on table config.
Args:
table_config: Table configuration
sync_state: Sync state manager
Returns:
Dictionary with sync result
"""
logger.info(f"Syncing BQ table: {table_config.name} ({table_config.sync_strategy})")
# Clear metadata cache for fresh types
if table_config.id in self.bq_client.metadata_cache:
del self.bq_client.metadata_cache[table_config.id]
logger.debug(f"Cleared BQ metadata cache for {table_config.id}")
try:
if table_config.sync_strategy == "full_refresh":
result = self._full_refresh(table_config)
elif table_config.sync_strategy == "incremental":
result = self._incremental_sync(table_config, sync_state)
elif table_config.sync_strategy == "partitioned":
result = self._partitioned_sync(table_config, sync_state)
else:
raise ValueError(f"Unknown sync strategy: {table_config.sync_strategy}")
# Update sync state
sync_state.update_sync(
table_id=table_config.id,
table_name=table_config.name,
strategy=table_config.sync_strategy,
rows=result["rows"],
file_size_bytes=result["file_size_bytes"],
columns=result.get("columns", 0),
uncompressed_bytes=result.get("uncompressed_bytes", 0),
)
return {
"success": True,
"rows": result["rows"],
"strategy": table_config.sync_strategy,
"file_size_mb": result["file_size_bytes"] / 1024 / 1024,
}
except Exception as e:
logger.error(f"Error syncing BQ table {table_config.name}: {e}")
return {
"success": False,
"error": str(e),
"strategy": table_config.sync_strategy,
}
def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]:
"""
Full refresh: read entire table and replace Parquet file.
"""
logger.info(f"Full refresh: {table_config.name}")
parquet_path = self.config.get_parquet_path(table_config)
date_columns = self.bq_client.get_date_columns(table_config.id)
pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
# Read full table from BigQuery -> PyArrow
arrow_table = self.bq_client.read_table(table_config.id)
# Apply schema enforcement
if date_columns:
arrow_table = convert_date_columns_to_date32(arrow_table, date_columns)
if pyarrow_schema:
arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema)
# Write to Parquet
parquet_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(arrow_table, parquet_path, compression="snappy")
file_size = parquet_path.stat().st_size
logger.info(
f"Full refresh complete: {arrow_table.num_rows} rows, "
f"{file_size / 1024 / 1024:.2f} MB"
)
return {
"rows": arrow_table.num_rows,
"columns": arrow_table.num_columns,
"file_size_bytes": file_size,
"uncompressed_bytes": _get_uncompressed_size(parquet_path),
}
def _incremental_sync(
self,
table_config: TableConfig,
sync_state: SyncState,
) -> Dict[str, Any]:
"""
Incremental sync: dispatch to column-based or partition-based strategy.
"""
# If partition_by is set, use partitioned incremental
if table_config.partition_by:
return self._partitioned_sync(table_config, sync_state)
# If incremental_column is set, use timestamp-based incremental
if table_config.incremental_column:
return self._incremental_column_sync(table_config, sync_state)
# Fallback: full refresh (no incremental column configured)
logger.warning(
f"Table {table_config.name}: incremental strategy but no "
f"incremental_column or partition_by configured, falling back to full refresh"
)
return self._full_refresh(table_config)
def _incremental_column_sync(
self,
table_config: TableConfig,
sync_state: SyncState,
) -> Dict[str, Any]:
"""
Timestamp-based incremental sync using incremental_column.
Reads rows WHERE incremental_column > last_sync_value,
merges with existing Parquet (dedup on primary key).
"""
logger.info(
f"Incremental column sync: {table_config.name} "
f"(column: {table_config.incremental_column})"
)
parquet_path = self.config.get_parquet_path(table_config)
date_columns = self.bq_client.get_date_columns(table_config.id)
pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
# Determine since_value from last sync
last_sync = sync_state.get_last_sync(table_config.id)
if last_sync and parquet_path.exists():
# Apply window: go back incremental_window_days from last sync
last_sync_dt = datetime.fromisoformat(last_sync)
window_days = table_config.incremental_window_days or 7
since_dt = last_sync_dt - timedelta(days=window_days)
since_value = since_dt.isoformat()
logger.info(f" -> Since: {since_value} (window: {window_days} days)")
# Read incremental data
new_data = self.bq_client.read_table_incremental(
table_id=table_config.id,
incremental_column=table_config.incremental_column,
since_value=since_value,
)
if new_data.num_rows == 0:
logger.info(" -> No new data since last sync")
existing_pf = pq.ParquetFile(parquet_path)
return {
"rows": existing_pf.metadata.num_rows,
"columns": len(existing_pf.schema_arrow),
"file_size_bytes": parquet_path.stat().st_size,
"uncompressed_bytes": _get_uncompressed_size(parquet_path),
}
# Merge with existing data
logger.info(f" -> Merging {new_data.num_rows} new rows with existing data")
existing_table = pq.read_table(parquet_path)
merged = self._merge_arrow_tables(
existing_table, new_data, table_config.get_primary_key_columns()
)
# Apply schema enforcement
if date_columns:
merged = convert_date_columns_to_date32(merged, date_columns)
if pyarrow_schema:
merged = apply_schema_to_table(merged, pyarrow_schema)
pq.write_table(merged, parquet_path, compression="snappy")
file_size = parquet_path.stat().st_size
logger.info(
f" -> Incremental sync complete: {merged.num_rows} total rows"
)
return {
"rows": merged.num_rows,
"columns": merged.num_columns,
"file_size_bytes": file_size,
"uncompressed_bytes": _get_uncompressed_size(parquet_path),
}
else:
# First sync or no existing file -- full read
logger.info(" -> First sync, reading all data")
if table_config.max_history_days:
since_dt = datetime.now() - timedelta(days=table_config.max_history_days)
arrow_table = self.bq_client.read_table_incremental(
table_id=table_config.id,
incremental_column=table_config.incremental_column,
since_value=since_dt.isoformat(),
)
else:
arrow_table = self.bq_client.read_table(table_config.id)
# Apply schema enforcement
if date_columns:
arrow_table = convert_date_columns_to_date32(arrow_table, date_columns)
if pyarrow_schema:
arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema)
parquet_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(arrow_table, parquet_path, compression="snappy")
file_size = parquet_path.stat().st_size
return {
"rows": arrow_table.num_rows,
"columns": arrow_table.num_columns,
"file_size_bytes": file_size,
"uncompressed_bytes": _get_uncompressed_size(parquet_path),
}
def _partitioned_sync(
self,
table_config: TableConfig,
sync_state: SyncState,
) -> Dict[str, Any]:
"""
Partition-based sync: read data by partition range and write partition files.
"""
import pandas as pd
partition_col = table_config.partition_by
if not partition_col and table_config.incremental_column:
partition_col = table_config.incremental_column
if not partition_col:
logger.warning(
f"Table {table_config.name}: partitioned strategy but no "
f"partition_by or incremental_column, falling back to full refresh"
)
return self._full_refresh(table_config)
granularity = table_config.partition_granularity or "month"
logger.info(
f"Partitioned sync: {table_config.name} "
f"(by {partition_col}, {granularity})"
)
partition_dir = self.config.get_parquet_path(table_config)
date_columns = self.bq_client.get_date_columns(table_config.id)
pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
# Determine time range
last_sync = sync_state.get_last_sync(table_config.id)
if last_sync:
last_sync_dt = datetime.fromisoformat(last_sync)
window_days = table_config.incremental_window_days or 7
start_dt = last_sync_dt - timedelta(days=window_days)
logger.info(f" -> Reading from {start_dt.isoformat()} (window: {window_days} days)")
else:
if table_config.max_history_days:
start_dt = datetime.now() - timedelta(days=table_config.max_history_days)
logger.info(f" -> First sync, limited to last {table_config.max_history_days} days")
else:
start_dt = None
logger.info(" -> First sync, reading all data")
# Read data from BigQuery
if start_dt:
arrow_table = self.bq_client.read_table_partitioned(
table_id=table_config.id,
partition_column=partition_col,
start=start_dt.isoformat(),
)
else:
arrow_table = self.bq_client.read_table(table_config.id)
if arrow_table.num_rows == 0:
logger.info(" -> No data to sync")
return self._get_partition_totals(partition_dir)
logger.info(f" -> Processing {arrow_table.num_rows} rows into partitions")
# Convert to pandas for partitioning
df = arrow_table.to_pandas()
# Ensure partition column is datetime
if not pd.api.types.is_datetime64_any_dtype(df[partition_col]):
df[partition_col] = pd.to_datetime(df[partition_col], format="ISO8601", utc=True)
# Create partition key
if granularity == "month":
df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m")
elif granularity == "day":
df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m_%d")
elif granularity == "year":
df["_partition_key"] = df[partition_col].dt.strftime("%Y")
primary_key_cols = table_config.get_primary_key_columns()
partitions_updated = set()
for partition_key, group_df in df.groupby("_partition_key"):
group_df = group_df.drop(columns=["_partition_key"])
partition_path = self.config.get_partition_path(table_config, partition_key)
partitions_updated.add(partition_key)
# Merge with existing partition if it exists
if partition_path.exists():
existing_df = pd.read_parquet(partition_path)
merged_df = pd.concat([existing_df, group_df], ignore_index=True)
merged_df = merged_df.drop_duplicates(subset=primary_key_cols, keep="last")
else:
merged_df = group_df
# Write partition
table = pa.Table.from_pandas(merged_df, preserve_index=False)
if date_columns:
table = convert_date_columns_to_date32(table, date_columns)
if pyarrow_schema:
table = apply_schema_to_table(table, pyarrow_schema)
pq.write_table(table, partition_path, compression="snappy")
logger.info(f" -> Partitioned sync complete: {len(partitions_updated)} partitions updated")
return self._get_partition_totals(partition_dir)
def _merge_arrow_tables(
self,
existing: pa.Table,
new_data: pa.Table,
primary_key: List[str],
) -> pa.Table:
"""
Merge two Arrow tables with deduplication on primary key.
New data overwrites existing rows with the same primary key.
Args:
existing: Existing data
new_data: New/changed data
primary_key: List of PK column names
Returns:
Merged PyArrow Table
"""
import pandas as pd
existing_df = existing.to_pandas()
new_df = new_data.to_pandas()
# Concat and dedup (keep last = new data wins)
merged_df = pd.concat([existing_df, new_df], ignore_index=True)
merged_df = merged_df.drop_duplicates(subset=primary_key, keep="last")
return pa.Table.from_pandas(merged_df, preserve_index=False)
def _get_partition_totals(self, partition_dir: Path) -> Dict[str, Any]:
"""
Calculate totals from all partition files in a directory.
"""
total_rows = 0
total_size = 0
total_uncompressed = 0
total_columns = 0
if not partition_dir.exists():
return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0}
all_partitions = list(partition_dir.glob("*.parquet"))
for part_path in all_partitions:
try:
pf = pq.ParquetFile(part_path)
meta = pf.metadata
total_rows += meta.num_rows
total_size += part_path.stat().st_size
if total_columns == 0:
total_columns = len(pf.schema_arrow)
for rg_idx in range(meta.num_row_groups):
rg = meta.row_group(rg_idx)
for col_idx in range(rg.num_columns):
total_uncompressed += rg.column(col_idx).total_uncompressed_size
except Exception as e:
logger.warning(f"Skipping corrupt partition {part_path.name}: {e}")
return {
"rows": total_rows,
"file_size_bytes": total_size,
"partitions": len(all_partitions),
"columns": total_columns,
"uncompressed_bytes": total_uncompressed,
}
def create_data_source() -> BigQueryDataSource:
"""Factory function for dynamic import compatibility."""
return BigQueryDataSource()

View file

@ -0,0 +1,469 @@
"""
Google BigQuery API Client
Low-level wrapper for Google BigQuery with these functions:
1. Authentication using Application Default Credentials (ADC)
2. Query tables to PyArrow (no CSV intermediate step)
3. Get table metadata (schema, columns, data types)
4. Cache metadata for faster repeated use
5. Incremental reads (timestamp-based and partition-based)
Uses google-cloud-bigquery with native PyArrow support.
"""
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import pyarrow as pa
from google.cloud import bigquery
from src.config import get_config
logger = logging.getLogger(__name__)
# Mapping BigQuery types to PyArrow types
BIGQUERY_TO_PYARROW_TYPES = {
"STRING": pa.string(),
"BYTES": pa.binary(),
"INTEGER": pa.int64(),
"INT64": pa.int64(),
"FLOAT": pa.float64(),
"FLOAT64": pa.float64(),
"NUMERIC": pa.float64(),
"BIGNUMERIC": pa.float64(),
"BOOLEAN": pa.bool_(),
"BOOL": pa.bool_(),
"TIMESTAMP": pa.timestamp("us", tz="UTC"),
"DATE": pa.date32(),
"TIME": pa.string(),
"DATETIME": pa.timestamp("us"),
"GEOGRAPHY": pa.string(),
"JSON": pa.string(),
"STRUCT": pa.string(),
"RECORD": pa.string(),
"ARRAY": pa.string(),
}
class BigQueryClient:
"""
Wrapper for Google BigQuery API.
Provides high-level methods for working with BigQuery tables:
- Query tables to PyArrow Tables (no CSV step)
- Get metadata (schema, columns)
- Incremental and partitioned reads
"""
def __init__(
self,
project_id: Optional[str] = None,
location: Optional[str] = None,
):
"""
Initialize BigQuery client.
Args:
project_id: GCP project ID for job execution/billing.
If None, reads from BIGQUERY_PROJECT env var.
location: BigQuery location for job execution (e.g., "us-central1").
If None, reads from BIGQUERY_LOCATION env var.
Raises:
ValueError: If project_id is not provided and BIGQUERY_PROJECT is not set.
"""
self.project_id = project_id or os.environ.get("BIGQUERY_PROJECT")
if not self.project_id:
raise ValueError(
"BigQuery project ID not set. "
"Set BIGQUERY_PROJECT environment variable."
)
self.location = location or os.environ.get("BIGQUERY_LOCATION")
# Initialize BigQuery client with ADC
# project_id is used for job execution and billing.
# Data can live in a different project -- table IDs in queries
# use fully-qualified format (project.dataset.table).
client_kwargs = {"project": self.project_id}
if self.location:
client_kwargs["location"] = self.location
self.client = bigquery.Client(**client_kwargs)
# Metadata cache
config = get_config()
self.metadata_cache: Dict[str, Dict[str, Any]] = {}
self.metadata_cache_path = config.get_metadata_path() / "bq_table_metadata.json"
# Load cache from disk if exists
self._load_metadata_cache()
logger.info(
f"BigQuery client initialized: project={self.project_id}, "
f"location={self.location or 'auto'}"
)
def _load_metadata_cache(self):
"""Load metadata cache from disk."""
if self.metadata_cache_path.exists():
try:
with open(self.metadata_cache_path, "r") as f:
self.metadata_cache = json.load(f)
logger.info(
f"BQ metadata cache loaded: {len(self.metadata_cache)} tables"
)
except Exception as e:
logger.warning(f"Error loading BQ metadata cache: {e}")
self.metadata_cache = {}
def _save_metadata_cache(self):
"""Save metadata cache to disk."""
try:
self.metadata_cache_path.parent.mkdir(parents=True, exist_ok=True)
with open(self.metadata_cache_path, "w") as f:
json.dump(self.metadata_cache, f, indent=2)
logger.debug("BQ metadata cache saved")
except Exception as e:
logger.warning(f"Error saving BQ metadata cache: {e}")
def get_table_metadata(
self,
table_id: str,
use_cache: bool = True,
cache_ttl_hours: int = 24,
) -> Dict[str, Any]:
"""
Get table metadata from BigQuery.
Args:
table_id: Full table ID (e.g., "project.dataset.table")
use_cache: Use cache if available
cache_ttl_hours: Cache TTL in hours (default 24h)
Returns:
Dictionary with metadata including columns, types, descriptions, row count.
"""
# Check cache
if use_cache and table_id in self.metadata_cache:
cached = self.metadata_cache[table_id]
cached_time = datetime.fromisoformat(cached.get("_cached_at", "2000-01-01"))
cache_age = datetime.now() - cached_time
if cache_age < timedelta(hours=cache_ttl_hours):
logger.debug(f"Using BQ metadata cache for {table_id}")
return cached
logger.info(f"Fetching metadata for BQ table: {table_id}")
try:
table_ref = self.client.get_table(table_id)
# Build column metadata
columns = []
column_types = {}
column_descriptions = {}
for field in table_ref.schema:
columns.append(field.name)
column_types[field.name] = field.field_type
if field.description:
column_descriptions[field.name] = field.description
metadata = {
"table_id": table_id,
"name": table_ref.table_id,
"dataset": table_ref.dataset_id,
"project": table_ref.project,
"columns": columns,
"column_types": column_types,
"column_descriptions": column_descriptions,
"row_count": table_ref.num_rows,
"size_bytes": table_ref.num_bytes,
"created": table_ref.created.isoformat() if table_ref.created else None,
"modified": table_ref.modified.isoformat() if table_ref.modified else None,
"partitioning": None,
"_cached_at": datetime.now().isoformat(),
}
# Capture partitioning info
if table_ref.time_partitioning:
metadata["partitioning"] = {
"type": table_ref.time_partitioning.type_,
"field": table_ref.time_partitioning.field,
"expiration_ms": table_ref.time_partitioning.expiration_ms,
}
# Save to cache
self.metadata_cache[table_id] = metadata
self._save_metadata_cache()
return metadata
except Exception as e:
logger.error(f"Error getting metadata for {table_id}: {e}")
raise
def get_pyarrow_schema(self, table_id: str) -> Optional[pa.Schema]:
"""
Build PyArrow schema from BigQuery table schema.
Args:
table_id: Full table ID
Returns:
PyArrow schema or None if metadata unavailable
"""
metadata = self.get_table_metadata(table_id)
column_types = metadata.get("column_types", {})
if not column_types:
logger.warning(f"No column types for {table_id}, schema will not be applied")
return None
fields = []
for col_name in metadata.get("columns", []):
bq_type = column_types.get(col_name, "STRING")
pa_type = BIGQUERY_TO_PYARROW_TYPES.get(bq_type, pa.string())
fields.append(pa.field(col_name, pa_type))
return pa.schema(fields)
def get_date_columns(self, table_id: str) -> List[str]:
"""
Get list of DATE-only columns for a table.
Args:
table_id: Full table ID
Returns:
List of column names that have DATE type in BigQuery
"""
metadata = self.get_table_metadata(table_id)
column_types = metadata.get("column_types", {})
return [
col_name for col_name, bq_type in column_types.items()
if bq_type == "DATE"
]
def query_to_arrow(
self,
sql: str,
params: Optional[List[bigquery.ScalarQueryParameter]] = None,
) -> pa.Table:
"""
Execute SQL query and return results as PyArrow Table.
Args:
sql: SQL query string (use @param_name for parameterized values)
params: List of BigQuery query parameters
Returns:
PyArrow Table with query results
"""
job_config = bigquery.QueryJobConfig()
if params:
job_config.query_parameters = params
logger.debug(f"Executing BQ query: {sql[:200]}...")
query_job = self.client.query(sql, job_config=job_config)
arrow_table = query_job.to_arrow()
logger.debug(f"Query returned {arrow_table.num_rows} rows, {arrow_table.num_columns} columns")
return arrow_table
def read_table(
self,
table_id: str,
columns: Optional[List[str]] = None,
row_filter: Optional[str] = None,
) -> pa.Table:
"""
Read full table (or filtered subset) as PyArrow Table.
Args:
table_id: Full table ID (e.g., "project.dataset.table")
columns: Optional list of columns to select
row_filter: Optional SQL WHERE clause (without WHERE keyword)
Returns:
PyArrow Table with table data
"""
# Build SELECT clause
select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*"
sql = f"SELECT {select_cols} FROM `{table_id}`"
if row_filter:
sql += f" WHERE {row_filter}"
logger.info(f"Reading BQ table: {table_id} (filter: {row_filter or 'none'})")
return self.query_to_arrow(sql)
def read_table_incremental(
self,
table_id: str,
incremental_column: str,
since_value: str,
columns: Optional[List[str]] = None,
) -> pa.Table:
"""
Read rows where incremental_column > since_value.
Uses parameterized query to prevent SQL injection.
Args:
table_id: Full table ID
incremental_column: Column name for incremental filter
since_value: ISO timestamp string - fetch rows after this value
columns: Optional list of columns to select
Returns:
PyArrow Table with incremental data
"""
select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*"
sql = (
f"SELECT {select_cols} FROM `{table_id}` "
f"WHERE `{incremental_column}` > @since_value"
)
params = [
bigquery.ScalarQueryParameter("since_value", "TIMESTAMP", since_value),
]
logger.info(
f"Incremental read: {table_id} WHERE {incremental_column} > {since_value}"
)
return self.query_to_arrow(sql, params=params)
def read_table_partitioned(
self,
table_id: str,
partition_column: str,
start: str,
end: Optional[str] = None,
columns: Optional[List[str]] = None,
) -> pa.Table:
"""
Read data within a partition range.
Args:
table_id: Full table ID
partition_column: Partition column name
start: Start date/timestamp (inclusive)
end: End date/timestamp (exclusive). If None, reads to present.
columns: Optional list of columns to select
Returns:
PyArrow Table with partition range data
"""
select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*"
sql = (
f"SELECT {select_cols} FROM `{table_id}` "
f"WHERE `{partition_column}` >= @start_value"
)
params = [
bigquery.ScalarQueryParameter("start_value", "TIMESTAMP", start),
]
if end:
sql += f" AND `{partition_column}` < @end_value"
params.append(
bigquery.ScalarQueryParameter("end_value", "TIMESTAMP", end),
)
logger.info(
f"Partitioned read: {table_id} [{start} .. {end or 'now'})"
)
return self.query_to_arrow(sql, params=params)
def discover_all_tables(self, dataset_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""
List all tables in the project (or specific dataset).
Args:
dataset_id: Optional dataset ID to limit scope
Returns:
Normalized list of table dicts with id, name, columns, row_count, etc.
"""
logger.info(f"Discovering BQ tables (dataset={dataset_id or 'all'})...")
result = []
if dataset_id:
datasets = [self.client.get_dataset(dataset_id)]
else:
datasets = list(self.client.list_datasets())
for dataset in datasets:
ds_ref = dataset.reference if hasattr(dataset, "reference") else dataset.dataset_id
ds_id = str(ds_ref)
try:
tables = list(self.client.list_tables(ds_ref))
except Exception as e:
logger.warning(f"Could not list tables in dataset {ds_id}: {e}")
continue
for table_item in tables:
full_id = f"{table_item.project}.{table_item.dataset_id}.{table_item.table_id}"
try:
table_detail = self.client.get_table(full_id)
columns = [f.name for f in table_detail.schema]
result.append({
"id": full_id,
"name": table_item.table_id,
"bucket_id": table_item.dataset_id,
"bucket_name": table_item.dataset_id,
"columns": columns,
"row_count": table_detail.num_rows or 0,
"size_bytes": table_detail.num_bytes or 0,
"primary_key": [],
"last_change": (
table_detail.modified.isoformat()
if table_detail.modified else None
),
"last_import": None,
})
except Exception as e:
logger.warning(f"Could not get details for {full_id}: {e}")
logger.info(f"Discovered {len(result)} BQ tables")
return result
def test_connection(self) -> bool:
"""
Test connection to BigQuery.
Returns:
True if connection works, False otherwise
"""
try:
query_job = self.client.query("SELECT 1")
list(query_job.result())
logger.info(f"BigQuery connection OK (project: {self.project_id})")
return True
except Exception as e:
logger.error(f"BigQuery connection test failed: {e}")
return False
def create_client() -> BigQueryClient:
"""
Factory function to create BigQuery client.
Returns:
BigQueryClient instance
"""
return BigQueryClient()

View file

@ -1,5 +1,7 @@
# Data source adapters (install only what you need)
kbcstorage>=0.9.0 # For Keboola adapter
google-cloud-bigquery>=3.0.0 # For BigQuery adapter
google-cloud-bigquery-storage>=2.0.0 # For BigQuery adapter (fast Arrow transfer)
# Data processing
# pandas - core tabular data processing library

View file

@ -101,6 +101,7 @@ class TableConfig:
max_history_days: Optional[int] = None
dataset: Optional[str] = None
initial_load_chunk_days: int = 30
incremental_column: Optional[str] = None # Column for timestamp-based incremental sync (BigQuery)
def __post_init__(self):
"""Validate configuration after initialization."""
@ -429,6 +430,7 @@ class Config:
max_history_days=table_data.get("max_history_days"),
dataset=table_data.get("dataset"),
initial_load_chunk_days=table_data.get("initial_load_chunk_days", 30),
incremental_column=table_data.get("incremental_column"),
)
table_configs.append(config)

View file

@ -511,7 +511,7 @@ def create_data_source(source_type: str = None) -> DataSource:
raise ValueError(
f"Unknown data source: '{source_type}'. "
f"Available connectors: keboola. "
f"Available connectors: keboola, bigquery. "
f"Create connectors/{source_type}/adapter.py to add a new one."
)

View file

@ -0,0 +1,763 @@
"""
Comprehensive unit tests for the BigQuery data source adapter.
Tests the BigQueryDataSource class from connectors/bigquery/adapter.py
with all external dependencies (BigQueryClient, config, parquet_manager) mocked.
The google-cloud-bigquery package is not installed in test environments,
so we install stub modules in sys.modules before importing the adapter.
"""
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
# ---------------------------------------------------------------------------
# Stub google.cloud.bigquery before any connector import
# ---------------------------------------------------------------------------
_bq_stub = MagicMock()
sys.modules.setdefault("google", _bq_stub)
sys.modules.setdefault("google.cloud", _bq_stub)
sys.modules.setdefault("google.cloud.bigquery", _bq_stub)
from src.config import TableConfig # noqa: E402
from src.data_sync import SyncState # noqa: E402
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def tmp_parquet_dir(tmp_path):
"""Provide a temporary directory for Parquet file output."""
parquet_dir = tmp_path / "parquet" / "test_bucket"
parquet_dir.mkdir(parents=True)
return parquet_dir
@pytest.fixture
def mock_config(tmp_parquet_dir):
"""Create a mock Config object that returns paths inside tmp_parquet_dir."""
config = MagicMock()
config.get_parquet_path = MagicMock()
config.get_partition_path = MagicMock()
config.get_metadata_path.return_value = tmp_parquet_dir.parent / "metadata"
return config
@pytest.fixture
def mock_bq_client():
"""Create a mock BigQueryClient with sensible defaults."""
client = MagicMock()
client.metadata_cache = {}
client.get_date_columns.return_value = []
client.get_pyarrow_schema.return_value = None
return client
@pytest.fixture
def sync_state(tmp_path):
"""Create a real SyncState backed by a temp JSON file."""
state_file = tmp_path / "metadata" / "sync_state.json"
state_file.parent.mkdir(parents=True, exist_ok=True)
return SyncState(state_file)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_table_config(
*,
table_id: str = "project.dataset.orders",
name: str = "orders",
primary_key: str = "id",
sync_strategy: str = "full_refresh",
incremental_column: str | None = None,
incremental_window_days: int | None = None,
partition_by: str | None = None,
partition_granularity: str | None = None,
max_history_days: int | None = None,
) -> TableConfig:
"""Helper to build a TableConfig with safe defaults."""
return TableConfig(
id=table_id,
name=name,
description="Test table",
primary_key=primary_key,
sync_strategy=sync_strategy,
incremental_column=incremental_column,
incremental_window_days=incremental_window_days,
partition_by=partition_by,
partition_granularity=partition_granularity,
max_history_days=max_history_days,
)
def _sample_arrow_table(ids: list[int], names: list[str]) -> pa.Table:
"""Build a small PyArrow Table with id and name columns."""
return pa.table({"id": ids, "name": names})
def _create_adapter(mock_config, mock_bq_client):
"""Instantiate BigQueryDataSource with mocked dependencies.
Patches get_config and create_bq_client so that no real GCP
credentials or network access are needed.
"""
with patch("connectors.bigquery.adapter.get_config", return_value=mock_config), \
patch("connectors.bigquery.adapter.create_bq_client", return_value=mock_bq_client):
from connectors.bigquery.adapter import BigQueryDataSource
adapter = BigQueryDataSource()
return adapter
# ---------------------------------------------------------------------------
# 1. full_refresh writes valid Parquet file from Arrow table
# ---------------------------------------------------------------------------
class TestFullRefresh:
def test_writes_valid_parquet(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
"""full_refresh should write a valid, readable Parquet file."""
table_config = _make_table_config(sync_strategy="full_refresh")
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
arrow_data = _sample_arrow_table([1, 2, 3], ["Alice", "Bob", "Charlie"])
mock_bq_client.read_table.return_value = arrow_data
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.sync_table(table_config, sync_state)
assert result["success"] is True
assert result["rows"] == 3
assert parquet_path.exists()
# Verify Parquet content matches source data
read_back = pq.read_table(parquet_path)
assert read_back.num_rows == 3
assert read_back.column_names == ["id", "name"]
def test_applies_date_columns(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
"""full_refresh should call convert_date_columns_to_date32 when date columns exist."""
table_config = _make_table_config()
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
arrow_data = _sample_arrow_table([1], ["Alice"])
mock_bq_client.read_table.return_value = arrow_data
mock_bq_client.get_date_columns.return_value = ["created_at"]
with patch("connectors.bigquery.adapter.convert_date_columns_to_date32", return_value=arrow_data) as mock_conv:
adapter = _create_adapter(mock_config, mock_bq_client)
adapter.sync_table(table_config, sync_state)
mock_conv.assert_called_once_with(arrow_data, ["created_at"])
def test_applies_pyarrow_schema(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
"""full_refresh should call apply_schema_to_table when schema is available."""
table_config = _make_table_config()
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
arrow_data = _sample_arrow_table([1], ["Alice"])
mock_bq_client.read_table.return_value = arrow_data
schema = pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.string())])
mock_bq_client.get_pyarrow_schema.return_value = schema
with patch("connectors.bigquery.adapter.apply_schema_to_table", return_value=arrow_data) as mock_apply:
adapter = _create_adapter(mock_config, mock_bq_client)
adapter.sync_table(table_config, sync_state)
mock_apply.assert_called_once_with(arrow_data, schema)
# ---------------------------------------------------------------------------
# 2. incremental_column_sync merges correctly (dedup on PK, new data wins)
# ---------------------------------------------------------------------------
class TestIncrementalColumnSync:
def test_merge_dedup_new_data_wins(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
"""Incremental sync should overwrite existing rows when PK matches (new data wins)."""
table_config = _make_table_config(
sync_strategy="incremental",
incremental_column="updated_at",
incremental_window_days=7,
)
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
# Write existing data
existing = _sample_arrow_table([1, 2], ["Alice", "Bob"])
pq.write_table(existing, parquet_path)
# Simulate a previous sync timestamp
sync_state.update_sync(
table_id=table_config.id,
table_name=table_config.name,
strategy="incremental",
rows=2,
file_size_bytes=100,
)
# New data: id=2 gets updated name, id=3 is new
new_data = _sample_arrow_table([2, 3], ["Bob_Updated", "Charlie"])
mock_bq_client.read_table_incremental.return_value = new_data
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.sync_table(table_config, sync_state)
assert result["success"] is True
assert result["rows"] == 3 # Alice + Bob_Updated + Charlie
read_back = pq.read_table(parquet_path)
df = read_back.to_pandas()
assert set(df["id"].tolist()) == {1, 2, 3}
# id=2 should have the updated name
bob_row = df[df["id"] == 2].iloc[0]
assert bob_row["name"] == "Bob_Updated"
# ---------------------------------------------------------------------------
# 3. incremental_column_sync with no new data returns existing file info
# ---------------------------------------------------------------------------
class TestIncrementalNoNewData:
def test_returns_existing_file_info(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
"""When there is no new data, sync returns stats from the existing Parquet file."""
table_config = _make_table_config(
sync_strategy="incremental",
incremental_column="updated_at",
incremental_window_days=7,
)
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
# Write existing data
existing = _sample_arrow_table([1, 2, 3], ["A", "B", "C"])
pq.write_table(existing, parquet_path)
# Mark a previous sync
sync_state.update_sync(
table_id=table_config.id,
table_name=table_config.name,
strategy="incremental",
rows=3,
file_size_bytes=100,
)
# No new rows
empty_table = pa.table({
"id": pa.array([], type=pa.int64()),
"name": pa.array([], type=pa.string()),
})
mock_bq_client.read_table_incremental.return_value = empty_table
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.sync_table(table_config, sync_state)
assert result["success"] is True
assert result["rows"] == 3 # existing row count preserved
# ---------------------------------------------------------------------------
# 4. partitioned_sync creates partition files
# ---------------------------------------------------------------------------
class TestPartitionedSync:
def test_creates_partition_files(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
"""Partitioned sync should create separate Parquet files per partition key."""
import pandas as pd
table_config = _make_table_config(
sync_strategy="incremental",
incremental_column="created_at",
partition_by="created_at",
partition_granularity="month",
incremental_window_days=7,
)
# For partitioned tables, parquet_path is a directory
partition_dir = tmp_parquet_dir / "orders"
partition_dir.mkdir(parents=True, exist_ok=True)
mock_config.get_parquet_path.return_value = partition_dir
# Configure partition paths
def _partition_path(tc, key):
return partition_dir / f"{key}.parquet"
mock_config.get_partition_path.side_effect = _partition_path
# Build arrow table with timestamps in two months
ts_jan = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")]
ts_feb = [pd.Timestamp("2026-02-20 14:00:00", tz="UTC")]
arrow_data = pa.table({
"id": [1, 2],
"name": ["Jan_Order", "Feb_Order"],
"created_at": pa.array(ts_jan + ts_feb, type=pa.timestamp("us", tz="UTC")),
})
mock_bq_client.read_table.return_value = arrow_data
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.sync_table(table_config, sync_state)
assert result["success"] is True
# Should have created two partition files
partition_files = list(partition_dir.glob("*.parquet"))
assert len(partition_files) == 2
partition_names = sorted(f.stem for f in partition_files)
assert "2026_01" in partition_names
assert "2026_02" in partition_names
# ---------------------------------------------------------------------------
# 5. discover_tables delegates to BigQueryClient.discover_all_tables()
# ---------------------------------------------------------------------------
class TestDiscoverTables:
def test_delegates_to_client(self, mock_config, mock_bq_client):
"""discover_tables should forward the call to BigQueryClient.discover_all_tables."""
expected = [{"id": "proj.ds.t1", "name": "t1", "columns": ["a", "b"]}]
mock_bq_client.discover_all_tables.return_value = expected
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.discover_tables()
mock_bq_client.discover_all_tables.assert_called_once()
assert result == expected
# ---------------------------------------------------------------------------
# 6. get_source_name returns "Google BigQuery"
# ---------------------------------------------------------------------------
class TestGetSourceName:
def test_returns_google_bigquery(self, mock_config, mock_bq_client):
adapter = _create_adapter(mock_config, mock_bq_client)
assert adapter.get_source_name() == "Google BigQuery"
# ---------------------------------------------------------------------------
# 7. get_column_metadata returns correct format
# ---------------------------------------------------------------------------
class TestGetColumnMetadata:
def test_returns_correct_format(self, mock_config, mock_bq_client):
"""get_column_metadata should transform BQ raw metadata into {columns: ...} format."""
mock_bq_client.get_table_metadata.return_value = {
"column_types": {"id": "INT64", "name": "STRING", "email": "STRING"},
"column_descriptions": {"id": "Primary key", "email": "User email address"},
}
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.get_column_metadata("project.dataset.users")
assert "columns" in result
assert result["columns"]["id"] == {"source_type": "INT64", "description": "Primary key"}
assert result["columns"]["name"] == {"source_type": "STRING"}
assert result["columns"]["email"] == {
"source_type": "STRING",
"description": "User email address",
}
def test_returns_none_when_no_column_types(self, mock_config, mock_bq_client):
"""get_column_metadata should return None if the metadata has no column types."""
mock_bq_client.get_table_metadata.return_value = {
"column_types": {},
"column_descriptions": {},
}
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.get_column_metadata("project.dataset.users")
assert result is None
# ---------------------------------------------------------------------------
# 8. Error handling (query failure -> {success: False, error: ...})
# ---------------------------------------------------------------------------
class TestErrorHandling:
def test_query_failure_returns_error_dict(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""When BigQuery query raises, sync_table returns {success: False, error: ...}."""
table_config = _make_table_config()
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
mock_bq_client.read_table.side_effect = RuntimeError("BigQuery API timeout")
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.sync_table(table_config, sync_state)
assert result["success"] is False
assert "BigQuery API timeout" in result["error"]
assert result["strategy"] == "full_refresh"
def test_unknown_strategy_returns_error(self, mock_config, mock_bq_client, sync_state):
"""Unknown sync_strategy in internal dispatch should produce an error result."""
# We cannot create a TableConfig with an invalid strategy via constructor
# (it validates). Instead, we mutate it after creation.
table_config = _make_table_config()
table_config.sync_strategy = "magic_sync"
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.sync_table(table_config, sync_state)
assert result["success"] is False
assert "Unknown sync strategy" in result["error"]
# ---------------------------------------------------------------------------
# 9. incremental_column config is used in WHERE clause
# ---------------------------------------------------------------------------
class TestIncrementalColumnUsedInWhere:
def test_incremental_column_passed_to_client(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""The configured incremental_column should be forwarded to read_table_incremental."""
table_config = _make_table_config(
sync_strategy="incremental",
incremental_column="modified_at",
incremental_window_days=14,
)
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
# Write existing data so we enter the incremental path
existing = _sample_arrow_table([1], ["Alice"])
pq.write_table(existing, parquet_path)
sync_state.update_sync(
table_id=table_config.id,
table_name=table_config.name,
strategy="incremental",
rows=1,
file_size_bytes=100,
)
# Return empty to keep the test simple
empty = pa.table({
"id": pa.array([], type=pa.int64()),
"name": pa.array([], type=pa.string()),
})
mock_bq_client.read_table_incremental.return_value = empty
adapter = _create_adapter(mock_config, mock_bq_client)
adapter.sync_table(table_config, sync_state)
call_kwargs = mock_bq_client.read_table_incremental.call_args
assert call_kwargs.kwargs["incremental_column"] == "modified_at"
assert call_kwargs.kwargs["table_id"] == "project.dataset.orders"
# since_value should be an ISO string
assert "since_value" in call_kwargs.kwargs
# ---------------------------------------------------------------------------
# 10. First sync without existing file downloads all data
# ---------------------------------------------------------------------------
class TestFirstSyncDownloadsAll:
def test_first_sync_reads_full_table(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""On first incremental sync (no existing file), adapter should read all data."""
table_config = _make_table_config(
sync_strategy="incremental",
incremental_column="updated_at",
incremental_window_days=7,
)
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
# No previous sync, no existing file
arrow_data = _sample_arrow_table([1, 2, 3], ["A", "B", "C"])
mock_bq_client.read_table.return_value = arrow_data
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.sync_table(table_config, sync_state)
assert result["success"] is True
assert result["rows"] == 3
# Should call read_table (full), not read_table_incremental
mock_bq_client.read_table.assert_called_once_with(table_config.id)
mock_bq_client.read_table_incremental.assert_not_called()
def test_first_sync_with_max_history_days(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""First sync with max_history_days should use read_table_incremental."""
table_config = _make_table_config(
sync_strategy="incremental",
incremental_column="updated_at",
incremental_window_days=7,
max_history_days=90,
)
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
arrow_data = _sample_arrow_table([1, 2], ["A", "B"])
mock_bq_client.read_table_incremental.return_value = arrow_data
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.sync_table(table_config, sync_state)
assert result["success"] is True
# Should use read_table_incremental (not read_table) because max_history_days is set
mock_bq_client.read_table_incremental.assert_called_once()
call_kwargs = mock_bq_client.read_table_incremental.call_args.kwargs
assert call_kwargs["incremental_column"] == "updated_at"
mock_bq_client.read_table.assert_not_called()
# ---------------------------------------------------------------------------
# 11. sync_table dispatches to correct strategy based on sync_strategy
# ---------------------------------------------------------------------------
class TestSyncTableDispatch:
def test_dispatches_full_refresh(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""sync_strategy='full_refresh' should call _full_refresh."""
table_config = _make_table_config(sync_strategy="full_refresh")
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
adapter = _create_adapter(mock_config, mock_bq_client)
with patch.object(adapter, "_full_refresh", wraps=adapter._full_refresh) as spy:
adapter.sync_table(table_config, sync_state)
spy.assert_called_once_with(table_config)
def test_dispatches_incremental(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""sync_strategy='incremental' should call _incremental_sync."""
table_config = _make_table_config(
sync_strategy="incremental",
incremental_column="updated_at",
incremental_window_days=7,
)
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
adapter = _create_adapter(mock_config, mock_bq_client)
with patch.object(adapter, "_incremental_sync", wraps=adapter._incremental_sync) as spy:
adapter.sync_table(table_config, sync_state)
spy.assert_called_once_with(table_config, sync_state)
def test_dispatches_partitioned(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""sync_strategy='incremental' with partition_by should call _partitioned_sync."""
import pandas as pd
table_config = _make_table_config(
sync_strategy="incremental",
incremental_column="created_at",
partition_by="created_at",
partition_granularity="month",
incremental_window_days=7,
)
partition_dir = tmp_parquet_dir / "orders"
partition_dir.mkdir(parents=True, exist_ok=True)
mock_config.get_parquet_path.return_value = partition_dir
def _partition_path(tc, key):
return partition_dir / f"{key}.parquet"
mock_config.get_partition_path.side_effect = _partition_path
ts = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")]
arrow_data = pa.table({
"id": [1],
"name": ["A"],
"created_at": pa.array(ts, type=pa.timestamp("us", tz="UTC")),
})
mock_bq_client.read_table.return_value = arrow_data
adapter = _create_adapter(mock_config, mock_bq_client)
with patch.object(adapter, "_partitioned_sync", wraps=adapter._partitioned_sync) as spy:
adapter.sync_table(table_config, sync_state)
spy.assert_called_once()
def test_incremental_without_column_falls_back_to_full_refresh(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""incremental strategy without incremental_column or partition_by falls back to full_refresh."""
table_config = _make_table_config(
sync_strategy="incremental",
incremental_column=None,
partition_by=None,
incremental_window_days=7,
)
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
adapter = _create_adapter(mock_config, mock_bq_client)
with patch.object(adapter, "_full_refresh", wraps=adapter._full_refresh) as spy:
result = adapter.sync_table(table_config, sync_state)
spy.assert_called_once()
assert result["success"] is True
# ---------------------------------------------------------------------------
# 12. _merge_arrow_tables deduplicates correctly
# ---------------------------------------------------------------------------
class TestMergeArrowTables:
def test_dedup_on_single_pk(self, mock_config, mock_bq_client):
"""Merge should deduplicate on single primary key column, new data wins."""
adapter = _create_adapter(mock_config, mock_bq_client)
existing = pa.table({"id": [1, 2, 3], "val": ["a", "b", "c"]})
new_data = pa.table({"id": [2, 4], "val": ["B_new", "d"]})
merged = adapter._merge_arrow_tables(existing, new_data, primary_key=["id"])
df = merged.to_pandas().sort_values("id").reset_index(drop=True)
assert list(df["id"]) == [1, 2, 3, 4]
assert list(df["val"]) == ["a", "B_new", "c", "d"]
def test_dedup_on_composite_pk(self, mock_config, mock_bq_client):
"""Merge should deduplicate on composite primary key."""
adapter = _create_adapter(mock_config, mock_bq_client)
existing = pa.table({
"pk1": [1, 1, 2],
"pk2": ["a", "b", "a"],
"val": ["old_1a", "old_1b", "old_2a"],
})
new_data = pa.table({
"pk1": [1, 2],
"pk2": ["a", "a"],
"val": ["new_1a", "new_2a"],
})
merged = adapter._merge_arrow_tables(existing, new_data, primary_key=["pk1", "pk2"])
df = merged.to_pandas().sort_values(["pk1", "pk2"]).reset_index(drop=True)
assert len(df) == 3
# (1, a) should be updated
row_1a = df[(df["pk1"] == 1) & (df["pk2"] == "a")].iloc[0]
assert row_1a["val"] == "new_1a"
# (1, b) should be preserved
row_1b = df[(df["pk1"] == 1) & (df["pk2"] == "b")].iloc[0]
assert row_1b["val"] == "old_1b"
# (2, a) should be updated
row_2a = df[(df["pk1"] == 2) & (df["pk2"] == "a")].iloc[0]
assert row_2a["val"] == "new_2a"
def test_merge_with_empty_new_data(self, mock_config, mock_bq_client):
"""Merging with empty new data should return existing data unchanged."""
adapter = _create_adapter(mock_config, mock_bq_client)
existing = pa.table({"id": [1, 2], "val": ["a", "b"]})
empty = pa.table({
"id": pa.array([], type=pa.int64()),
"val": pa.array([], type=pa.string()),
})
merged = adapter._merge_arrow_tables(existing, empty, primary_key=["id"])
assert merged.num_rows == 2
def test_merge_with_empty_existing(self, mock_config, mock_bq_client):
"""Merging with empty existing data should return new data."""
adapter = _create_adapter(mock_config, mock_bq_client)
empty = pa.table({
"id": pa.array([], type=pa.int64()),
"val": pa.array([], type=pa.string()),
})
new_data = pa.table({"id": [1, 2], "val": ["a", "b"]})
merged = adapter._merge_arrow_tables(empty, new_data, primary_key=["id"])
assert merged.num_rows == 2
# ---------------------------------------------------------------------------
# Additional edge cases
# ---------------------------------------------------------------------------
class TestMetadataCacheClearing:
def test_clears_metadata_cache_before_sync(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""sync_table should clear the BQ metadata cache entry for the table being synced."""
table_config = _make_table_config()
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
# Pre-populate cache
mock_bq_client.metadata_cache[table_config.id] = {"some": "cached_data"}
adapter = _create_adapter(mock_config, mock_bq_client)
adapter.sync_table(table_config, sync_state)
assert table_config.id not in mock_bq_client.metadata_cache
class TestSyncStateUpdate:
def test_sync_state_updated_after_success(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""After successful sync, the sync state should be updated with correct values."""
table_config = _make_table_config()
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
mock_bq_client.read_table.return_value = _sample_arrow_table([1, 2], ["A", "B"])
adapter = _create_adapter(mock_config, mock_bq_client)
adapter.sync_table(table_config, sync_state)
state = sync_state.get_table_state(table_config.id)
assert state["rows"] == 2
assert state["strategy"] == "full_refresh"
assert state["table_name"] == "orders"
assert "last_sync" in state
def test_sync_state_not_updated_on_failure(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""On sync failure, the sync state should NOT be updated."""
table_config = _make_table_config()
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
mock_bq_client.read_table.side_effect = RuntimeError("boom")
adapter = _create_adapter(mock_config, mock_bq_client)
adapter.sync_table(table_config, sync_state)
state = sync_state.get_table_state(table_config.id)
assert state == {}
class TestCreateDataSourceFactory:
def test_factory_returns_adapter_instance(self, mock_config, mock_bq_client):
"""create_data_source() factory should return a BigQueryDataSource instance."""
with patch("connectors.bigquery.adapter.get_config", return_value=mock_config), \
patch("connectors.bigquery.adapter.create_bq_client", return_value=mock_bq_client):
from connectors.bigquery.adapter import create_data_source, BigQueryDataSource
instance = create_data_source()
assert isinstance(instance, BigQueryDataSource)

View file

@ -0,0 +1,870 @@
"""Tests for the BigQuery client connector.
All external dependencies (google.cloud.bigquery, src.config) are mocked.
Tests cover initialization, metadata caching, schema building, query methods,
and connection testing.
"""
import json
import sys
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, mock_open, patch
import pyarrow as pa
import pytest
# Pre-populate sys.modules with a mock google.cloud.bigquery if not installed,
# so the client module can be imported without the real SDK.
_bq_mock_installed = False
try:
from google.cloud import bigquery as _bq_test # noqa: F401
except ImportError:
_bq_mock_installed = True
_mock_bigquery = MagicMock()
# Expose commonly used classes as MagicMock so the client module
# can reference bigquery.Client, bigquery.QueryJobConfig, etc.
sys.modules.setdefault("google", MagicMock())
sys.modules.setdefault("google.cloud", MagicMock())
sys.modules.setdefault("google.cloud.bigquery", _mock_bigquery)
from connectors.bigquery.client import (
BIGQUERY_TO_PYARROW_TYPES,
BigQueryClient,
create_client,
)
# Import the real or mock bigquery reference used in the client module
from google.cloud import bigquery
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_bq_field(name: str, field_type: str, description: str = None):
"""Create a mock BigQuery SchemaField."""
field = MagicMock()
field.name = name
field.field_type = field_type
field.description = description
return field
def _make_table_ref(
table_id: str = "my-project.my_dataset.my_table",
schema=None,
num_rows: int = 1000,
num_bytes: int = 50000,
created: datetime = None,
modified: datetime = None,
time_partitioning=None,
):
"""Create a mock BigQuery Table reference object."""
table_ref = MagicMock()
table_ref.table_id = table_id.split(".")[-1]
table_ref.dataset_id = table_id.split(".")[1] if "." in table_id else "dataset"
table_ref.project = table_id.split(".")[0] if "." in table_id else "project"
table_ref.schema = schema or []
table_ref.num_rows = num_rows
table_ref.num_bytes = num_bytes
table_ref.created = created or datetime(2025, 1, 1, 12, 0, 0)
table_ref.modified = modified or datetime(2025, 6, 1, 12, 0, 0)
table_ref.time_partitioning = time_partitioning
return table_ref
@pytest.fixture
def mock_config(tmp_path):
"""Mock get_config() to return a config with metadata path in tmp_path."""
config = MagicMock()
metadata_dir = tmp_path / "metadata"
metadata_dir.mkdir(parents=True, exist_ok=True)
config.get_metadata_path.return_value = metadata_dir
return config
@pytest.fixture
def mock_bq_client():
"""Create a mock BigQuery Client."""
return MagicMock()
@pytest.fixture
def client(mock_config, mock_bq_client):
"""Create a BigQueryClient instance with mocked dependencies."""
with (
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq_client),
patch("connectors.bigquery.client.get_config", return_value=mock_config),
patch.dict("os.environ", {"BIGQUERY_PROJECT": "test-project"}),
):
bq_client = BigQueryClient()
return bq_client
# ---------------------------------------------------------------------------
# 1. Init validates BIGQUERY_PROJECT env var
# ---------------------------------------------------------------------------
class TestInit:
def test_raises_value_error_when_project_not_set(self, mock_config):
"""Init raises ValueError if project_id is None and env var is missing."""
with (
patch("connectors.bigquery.client.bigquery.Client"),
patch("connectors.bigquery.client.get_config", return_value=mock_config),
patch.dict("os.environ", {}, clear=True),
):
with pytest.raises(ValueError, match="BigQuery project ID not set"):
BigQueryClient()
def test_raises_value_error_when_project_empty_string(self, mock_config):
"""Init raises ValueError if BIGQUERY_PROJECT is set to empty string."""
with (
patch("connectors.bigquery.client.bigquery.Client"),
patch("connectors.bigquery.client.get_config", return_value=mock_config),
patch.dict("os.environ", {"BIGQUERY_PROJECT": ""}, clear=True),
):
with pytest.raises(ValueError, match="BigQuery project ID not set"):
BigQueryClient()
# -------------------------------------------------------------------
# 2. Init creates client with correct project_id
# -------------------------------------------------------------------
def test_creates_client_with_env_project_id(self, mock_config):
"""Client uses BIGQUERY_PROJECT from environment."""
mock_bq = MagicMock()
with (
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq) as bq_cls,
patch("connectors.bigquery.client.get_config", return_value=mock_config),
patch.dict("os.environ", {"BIGQUERY_PROJECT": "env-project-123"}),
):
client = BigQueryClient()
bq_cls.assert_called_once_with(project="env-project-123")
assert client.project_id == "env-project-123"
def test_creates_client_with_explicit_project_id(self, mock_config):
"""Explicit project_id argument takes precedence over env var."""
mock_bq = MagicMock()
with (
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq) as bq_cls,
patch("connectors.bigquery.client.get_config", return_value=mock_config),
):
client = BigQueryClient(project_id="explicit-project")
bq_cls.assert_called_once_with(project="explicit-project")
assert client.project_id == "explicit-project"
# ---------------------------------------------------------------------------
# 3. get_table_metadata fetches and caches metadata correctly
# ---------------------------------------------------------------------------
class TestGetTableMetadata:
def test_fetches_metadata_from_bigquery(self, client, mock_bq_client):
"""get_table_metadata calls client.get_table and returns correct dict."""
table_id = "proj.dataset.orders"
schema = [
_make_bq_field("order_id", "INTEGER"),
_make_bq_field("customer_name", "STRING", description="Full name"),
_make_bq_field("created_at", "TIMESTAMP"),
]
table_ref = _make_table_ref(
table_id=table_id,
schema=schema,
num_rows=5000,
num_bytes=120000,
)
mock_bq_client.get_table.return_value = table_ref
metadata = client.get_table_metadata(table_id, use_cache=False)
mock_bq_client.get_table.assert_called_once_with(table_id)
assert metadata["table_id"] == table_id
assert metadata["name"] == "orders"
assert metadata["dataset"] == "dataset"
assert metadata["project"] == "proj"
assert metadata["columns"] == ["order_id", "customer_name", "created_at"]
assert metadata["column_types"]["order_id"] == "INTEGER"
assert metadata["column_types"]["customer_name"] == "STRING"
assert metadata["column_types"]["created_at"] == "TIMESTAMP"
assert metadata["column_descriptions"]["customer_name"] == "Full name"
assert "order_id" not in metadata["column_descriptions"]
assert metadata["row_count"] == 5000
assert metadata["size_bytes"] == 120000
assert "_cached_at" in metadata
def test_caches_metadata_in_memory(self, client, mock_bq_client):
"""After first fetch, metadata is stored in the in-memory cache."""
table_id = "proj.dataset.tbl"
table_ref = _make_table_ref(table_id=table_id)
mock_bq_client.get_table.return_value = table_ref
client.get_table_metadata(table_id, use_cache=False)
assert table_id in client.metadata_cache
assert client.metadata_cache[table_id]["table_id"] == table_id
def test_captures_partitioning_info(self, client, mock_bq_client):
"""Partitioning metadata is captured when table is partitioned."""
table_id = "proj.dataset.events"
partition = MagicMock()
partition.type_ = "DAY"
partition.field = "event_date"
partition.expiration_ms = 7776000000
table_ref = _make_table_ref(table_id=table_id, time_partitioning=partition)
mock_bq_client.get_table.return_value = table_ref
metadata = client.get_table_metadata(table_id, use_cache=False)
assert metadata["partitioning"] is not None
assert metadata["partitioning"]["type"] == "DAY"
assert metadata["partitioning"]["field"] == "event_date"
assert metadata["partitioning"]["expiration_ms"] == 7776000000
def test_no_partitioning_when_absent(self, client, mock_bq_client):
"""Partitioning is None when table has no partitioning."""
table_id = "proj.dataset.simple"
table_ref = _make_table_ref(table_id=table_id, time_partitioning=None)
mock_bq_client.get_table.return_value = table_ref
metadata = client.get_table_metadata(table_id, use_cache=False)
assert metadata["partitioning"] is None
# -------------------------------------------------------------------
# 4. get_table_metadata uses cache when available (within TTL)
# -------------------------------------------------------------------
def test_uses_cache_within_ttl(self, client, mock_bq_client):
"""When cache is fresh (within TTL), BQ API is not called again."""
table_id = "proj.dataset.cached_tbl"
now = datetime.now()
client.metadata_cache[table_id] = {
"table_id": table_id,
"columns": ["a", "b"],
"column_types": {"a": "STRING", "b": "INTEGER"},
"_cached_at": now.isoformat(),
}
result = client.get_table_metadata(table_id, use_cache=True, cache_ttl_hours=24)
mock_bq_client.get_table.assert_not_called()
assert result["table_id"] == table_id
assert result["columns"] == ["a", "b"]
def test_refetches_when_cache_expired(self, client, mock_bq_client):
"""When cache is older than TTL, metadata is re-fetched from BQ."""
table_id = "proj.dataset.stale_tbl"
old_time = (datetime.now() - timedelta(hours=48)).isoformat()
client.metadata_cache[table_id] = {
"table_id": table_id,
"columns": ["old_col"],
"column_types": {"old_col": "STRING"},
"_cached_at": old_time,
}
table_ref = _make_table_ref(
table_id=table_id,
schema=[_make_bq_field("new_col", "INTEGER")],
)
mock_bq_client.get_table.return_value = table_ref
result = client.get_table_metadata(table_id, use_cache=True, cache_ttl_hours=24)
mock_bq_client.get_table.assert_called_once_with(table_id)
assert result["columns"] == ["new_col"]
def test_bypasses_cache_when_use_cache_false(self, client, mock_bq_client):
"""When use_cache=False, always fetches from BQ even if cache is fresh."""
table_id = "proj.dataset.force_fetch"
client.metadata_cache[table_id] = {
"table_id": table_id,
"columns": ["cached"],
"column_types": {"cached": "STRING"},
"_cached_at": datetime.now().isoformat(),
}
table_ref = _make_table_ref(
table_id=table_id,
schema=[_make_bq_field("fresh", "INTEGER")],
)
mock_bq_client.get_table.return_value = table_ref
result = client.get_table_metadata(table_id, use_cache=False)
mock_bq_client.get_table.assert_called_once()
assert result["columns"] == ["fresh"]
# ---------------------------------------------------------------------------
# 5. get_pyarrow_schema builds correct schema from BQ types
# ---------------------------------------------------------------------------
class TestGetPyarrowSchema:
def test_builds_correct_schema(self, client, mock_bq_client):
"""Schema maps BQ types to correct PyArrow types."""
table_id = "proj.dataset.typed_tbl"
schema = [
_make_bq_field("id", "INT64"),
_make_bq_field("name", "STRING"),
_make_bq_field("price", "FLOAT64"),
_make_bq_field("active", "BOOLEAN"),
_make_bq_field("created", "DATE"),
_make_bq_field("updated_at", "TIMESTAMP"),
]
table_ref = _make_table_ref(table_id=table_id, schema=schema)
mock_bq_client.get_table.return_value = table_ref
pa_schema = client.get_pyarrow_schema(table_id)
assert pa_schema is not None
assert pa_schema.field("id").type == pa.int64()
assert pa_schema.field("name").type == pa.string()
assert pa_schema.field("price").type == pa.float64()
assert pa_schema.field("active").type == pa.bool_()
assert pa_schema.field("created").type == pa.date32()
assert pa_schema.field("updated_at").type == pa.timestamp("us", tz="UTC")
def test_returns_none_when_no_column_types(self, client):
"""Returns None when metadata has no column_types."""
table_id = "proj.dataset.empty_schema"
client.metadata_cache[table_id] = {
"table_id": table_id,
"columns": [],
"column_types": {},
"_cached_at": datetime.now().isoformat(),
}
result = client.get_pyarrow_schema(table_id)
assert result is None
def test_unknown_type_falls_back_to_string(self, client, mock_bq_client):
"""Unknown BQ types default to pa.string() in the schema."""
table_id = "proj.dataset.exotic_types"
schema = [_make_bq_field("exotic_col", "SOME_UNKNOWN_TYPE")]
table_ref = _make_table_ref(table_id=table_id, schema=schema)
mock_bq_client.get_table.return_value = table_ref
pa_schema = client.get_pyarrow_schema(table_id)
assert pa_schema.field("exotic_col").type == pa.string()
# ---------------------------------------------------------------------------
# 6. get_date_columns returns only DATE columns
# ---------------------------------------------------------------------------
class TestGetDateColumns:
def test_returns_only_date_columns(self, client, mock_bq_client):
"""Only columns with BQ type DATE are returned."""
table_id = "proj.dataset.mixed_dates"
schema = [
_make_bq_field("event_date", "DATE"),
_make_bq_field("created_at", "TIMESTAMP"),
_make_bq_field("name", "STRING"),
_make_bq_field("birth_date", "DATE"),
_make_bq_field("updated_ts", "DATETIME"),
]
table_ref = _make_table_ref(table_id=table_id, schema=schema)
mock_bq_client.get_table.return_value = table_ref
date_cols = client.get_date_columns(table_id)
assert sorted(date_cols) == ["birth_date", "event_date"]
def test_returns_empty_when_no_date_columns(self, client, mock_bq_client):
"""Returns empty list when no DATE columns exist."""
table_id = "proj.dataset.no_dates"
schema = [
_make_bq_field("id", "INTEGER"),
_make_bq_field("ts", "TIMESTAMP"),
]
table_ref = _make_table_ref(table_id=table_id, schema=schema)
mock_bq_client.get_table.return_value = table_ref
date_cols = client.get_date_columns(table_id)
assert date_cols == []
# ---------------------------------------------------------------------------
# 7. query_to_arrow executes SQL and returns PyArrow table
# ---------------------------------------------------------------------------
class TestQueryToArrow:
def test_executes_query_and_returns_arrow(self, client, mock_bq_client):
"""query_to_arrow passes SQL to BQ and returns the arrow result."""
expected_table = pa.table({"col1": [1, 2, 3]})
mock_job = MagicMock()
mock_job.to_arrow.return_value = expected_table
mock_bq_client.query.return_value = mock_job
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
mock_bq_module.QueryJobConfig.return_value = MagicMock(query_parameters=None)
client.client = mock_bq_client
result = client.query_to_arrow("SELECT * FROM `proj.dataset.tbl`")
mock_bq_client.query.assert_called_once()
call_args = mock_bq_client.query.call_args
assert call_args[0][0] == "SELECT * FROM `proj.dataset.tbl`"
assert result.equals(expected_table)
def test_passes_query_parameters(self, client, mock_bq_client):
"""query_to_arrow forwards BQ query parameters in job config."""
expected_table = pa.table({"col1": [10]})
mock_job = MagicMock()
mock_job.to_arrow.return_value = expected_table
mock_bq_client.query.return_value = mock_job
mock_job_config = MagicMock()
params = [MagicMock()] # Mock ScalarQueryParameter
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
mock_bq_module.QueryJobConfig.return_value = mock_job_config
client.client = mock_bq_client
client.query_to_arrow("SELECT 1 WHERE x > @val", params=params)
# Verify params were set on the job config
assert mock_job_config.query_parameters == params
def test_no_params_does_not_set_query_parameters(self, client, mock_bq_client):
"""When no params given, query_parameters is not set on job config."""
mock_job = MagicMock()
mock_job.to_arrow.return_value = pa.table({"x": [1]})
mock_bq_client.query.return_value = mock_job
mock_job_config = MagicMock(spec=[])
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
mock_bq_module.QueryJobConfig.return_value = mock_job_config
client.client = mock_bq_client
client.query_to_arrow("SELECT 1")
# query_parameters should not have been set
assert not hasattr(mock_job_config, "query_parameters") or not getattr(
mock_job_config, "query_parameters", None
)
# ---------------------------------------------------------------------------
# 8. read_table builds correct SQL query
# ---------------------------------------------------------------------------
class TestReadTable:
def test_full_table_select_all(self, client, mock_bq_client):
"""read_table with no columns or filter generates SELECT *."""
mock_job = MagicMock()
mock_job.to_arrow.return_value = pa.table({"a": [1]})
mock_bq_client.query.return_value = mock_job
client.read_table("proj.dataset.tbl")
sql = mock_bq_client.query.call_args[0][0]
assert "SELECT *" in sql
assert "`proj.dataset.tbl`" in sql
assert "WHERE" not in sql
def test_select_specific_columns(self, client, mock_bq_client):
"""read_table with columns list generates SELECT with backtick-quoted names."""
mock_job = MagicMock()
mock_job.to_arrow.return_value = pa.table({"a": [1]})
mock_bq_client.query.return_value = mock_job
client.read_table("proj.dataset.tbl", columns=["col_a", "col_b"])
sql = mock_bq_client.query.call_args[0][0]
assert "`col_a`" in sql
assert "`col_b`" in sql
assert "*" not in sql
def test_with_row_filter(self, client, mock_bq_client):
"""read_table with row_filter appends WHERE clause."""
mock_job = MagicMock()
mock_job.to_arrow.return_value = pa.table({"a": [1]})
mock_bq_client.query.return_value = mock_job
client.read_table("proj.dataset.tbl", row_filter="status = 'active'")
sql = mock_bq_client.query.call_args[0][0]
assert "WHERE status = 'active'" in sql
def test_columns_and_filter_combined(self, client, mock_bq_client):
"""read_table with both columns and filter generates correct SQL."""
mock_job = MagicMock()
mock_job.to_arrow.return_value = pa.table({"x": [1]})
mock_bq_client.query.return_value = mock_job
client.read_table(
"proj.dataset.tbl",
columns=["id", "name"],
row_filter="id > 100",
)
sql = mock_bq_client.query.call_args[0][0]
assert "`id`, `name`" in sql
assert "WHERE id > 100" in sql
assert "`proj.dataset.tbl`" in sql
# ---------------------------------------------------------------------------
# 9. read_table_incremental builds parameterized WHERE clause
# ---------------------------------------------------------------------------
class TestReadTableIncremental:
def test_incremental_query_structure(self, client, mock_bq_client):
"""read_table_incremental builds WHERE col > @since_value with params."""
mock_job = MagicMock()
mock_job.to_arrow.return_value = pa.table({"a": [1]})
mock_bq_client.query.return_value = mock_job
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
mock_bq_module.QueryJobConfig.return_value = MagicMock()
mock_param = MagicMock()
mock_bq_module.ScalarQueryParameter.return_value = mock_param
# Re-assign the client's bq client (the fixture already set it up)
client.client = mock_bq_client
client.read_table_incremental(
table_id="proj.dataset.events",
incremental_column="updated_at",
since_value="2025-01-01T00:00:00Z",
)
sql = mock_bq_client.query.call_args[0][0]
assert "SELECT *" in sql
assert "`proj.dataset.events`" in sql
assert "`updated_at` > @since_value" in sql
# Verify ScalarQueryParameter was constructed correctly
mock_bq_module.ScalarQueryParameter.assert_called_once_with(
"since_value", "TIMESTAMP", "2025-01-01T00:00:00Z"
)
def test_incremental_with_columns(self, client, mock_bq_client):
"""read_table_incremental with columns list selects specific columns."""
mock_job = MagicMock()
mock_job.to_arrow.return_value = pa.table({"a": [1]})
mock_bq_client.query.return_value = mock_job
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
mock_bq_module.QueryJobConfig.return_value = MagicMock()
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
client.client = mock_bq_client
client.read_table_incremental(
table_id="proj.dataset.events",
incremental_column="updated_at",
since_value="2025-01-01T00:00:00Z",
columns=["id", "name"],
)
sql = mock_bq_client.query.call_args[0][0]
assert "`id`, `name`" in sql
assert "*" not in sql
# ---------------------------------------------------------------------------
# 10. read_table_partitioned builds correct range query
# ---------------------------------------------------------------------------
class TestReadTablePartitioned:
def test_partitioned_start_only(self, client, mock_bq_client):
"""With only start, generates >= @start_value without end clause."""
mock_job = MagicMock()
mock_job.to_arrow.return_value = pa.table({"a": [1]})
mock_bq_client.query.return_value = mock_job
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
mock_bq_module.QueryJobConfig.return_value = MagicMock()
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
client.client = mock_bq_client
client.read_table_partitioned(
table_id="proj.dataset.events",
partition_column="event_date",
start="2025-01-01",
)
sql = mock_bq_client.query.call_args[0][0]
assert "`event_date` >= @start_value" in sql
assert "@end_value" not in sql
# Only start_value parameter created
assert mock_bq_module.ScalarQueryParameter.call_count == 1
mock_bq_module.ScalarQueryParameter.assert_called_with(
"start_value", "TIMESTAMP", "2025-01-01"
)
def test_partitioned_start_and_end(self, client, mock_bq_client):
"""With start and end, generates >= @start_value AND < @end_value."""
mock_job = MagicMock()
mock_job.to_arrow.return_value = pa.table({"a": [1]})
mock_bq_client.query.return_value = mock_job
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
mock_bq_module.QueryJobConfig.return_value = MagicMock()
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
client.client = mock_bq_client
client.read_table_partitioned(
table_id="proj.dataset.events",
partition_column="event_date",
start="2025-01-01",
end="2025-06-01",
)
sql = mock_bq_client.query.call_args[0][0]
assert "`event_date` >= @start_value" in sql
assert "`event_date` < @end_value" in sql
# Both start_value and end_value parameters created
assert mock_bq_module.ScalarQueryParameter.call_count == 2
calls = mock_bq_module.ScalarQueryParameter.call_args_list
assert calls[0].args == ("start_value", "TIMESTAMP", "2025-01-01")
assert calls[1].args == ("end_value", "TIMESTAMP", "2025-06-01")
def test_partitioned_with_columns(self, client, mock_bq_client):
"""read_table_partitioned with columns selects specific columns."""
mock_job = MagicMock()
mock_job.to_arrow.return_value = pa.table({"a": [1]})
mock_bq_client.query.return_value = mock_job
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
mock_bq_module.QueryJobConfig.return_value = MagicMock()
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
client.client = mock_bq_client
client.read_table_partitioned(
table_id="proj.dataset.events",
partition_column="event_date",
start="2025-01-01",
columns=["id", "event_date", "value"],
)
sql = mock_bq_client.query.call_args[0][0]
assert "`id`, `event_date`, `value`" in sql
assert "*" not in sql
# ---------------------------------------------------------------------------
# 11. test_connection returns True on success, False on failure
# ---------------------------------------------------------------------------
class TestTestConnection:
def test_returns_true_on_success(self, client, mock_bq_client):
"""test_connection returns True when SELECT 1 query succeeds."""
mock_job = MagicMock()
mock_job.result.return_value = iter([(1,)])
mock_bq_client.query.return_value = mock_job
assert client.test_connection() is True
mock_bq_client.query.assert_called_once_with("SELECT 1")
def test_returns_false_on_failure(self, client, mock_bq_client):
"""test_connection returns False when the query raises an exception."""
mock_bq_client.query.side_effect = Exception("Connection refused")
assert client.test_connection() is False
def test_returns_false_when_result_fails(self, client, mock_bq_client):
"""test_connection returns False when result iteration fails."""
mock_job = MagicMock()
mock_job.result.side_effect = Exception("Timeout")
mock_bq_client.query.return_value = mock_job
assert client.test_connection() is False
# ---------------------------------------------------------------------------
# 12. Type mapping completeness (all BQ types have PyArrow mapping)
# ---------------------------------------------------------------------------
class TestTypeMapping:
# All standard BigQuery types that should be mapped
EXPECTED_BQ_TYPES = [
"STRING", "BYTES", "INTEGER", "INT64",
"FLOAT", "FLOAT64", "NUMERIC", "BIGNUMERIC",
"BOOLEAN", "BOOL",
"TIMESTAMP", "DATE", "TIME", "DATETIME",
"GEOGRAPHY", "JSON",
"STRUCT", "RECORD", "ARRAY",
]
def test_all_standard_bq_types_are_mapped(self):
"""Every standard BigQuery type has an entry in BIGQUERY_TO_PYARROW_TYPES."""
for bq_type in self.EXPECTED_BQ_TYPES:
assert bq_type in BIGQUERY_TO_PYARROW_TYPES, (
f"Missing PyArrow mapping for BQ type: {bq_type}"
)
def test_all_mappings_produce_valid_pyarrow_types(self):
"""Every mapped value is a valid PyArrow DataType."""
for bq_type, pa_type in BIGQUERY_TO_PYARROW_TYPES.items():
assert isinstance(pa_type, pa.DataType), (
f"BQ type {bq_type} maps to non-DataType: {pa_type!r}"
)
def test_integer_types_map_to_int64(self):
"""Both INTEGER and INT64 map to pa.int64()."""
assert BIGQUERY_TO_PYARROW_TYPES["INTEGER"] == pa.int64()
assert BIGQUERY_TO_PYARROW_TYPES["INT64"] == pa.int64()
def test_float_types_map_to_float64(self):
"""FLOAT, FLOAT64, NUMERIC, BIGNUMERIC all map to pa.float64()."""
for t in ["FLOAT", "FLOAT64", "NUMERIC", "BIGNUMERIC"]:
assert BIGQUERY_TO_PYARROW_TYPES[t] == pa.float64()
def test_boolean_types_map_to_bool(self):
"""Both BOOLEAN and BOOL map to pa.bool_()."""
assert BIGQUERY_TO_PYARROW_TYPES["BOOLEAN"] == pa.bool_()
assert BIGQUERY_TO_PYARROW_TYPES["BOOL"] == pa.bool_()
def test_date_maps_to_date32(self):
"""DATE maps to pa.date32()."""
assert BIGQUERY_TO_PYARROW_TYPES["DATE"] == pa.date32()
def test_timestamp_has_utc_timezone(self):
"""TIMESTAMP maps to pa.timestamp with UTC timezone."""
ts_type = BIGQUERY_TO_PYARROW_TYPES["TIMESTAMP"]
assert ts_type == pa.timestamp("us", tz="UTC")
def test_datetime_has_no_timezone(self):
"""DATETIME maps to pa.timestamp without timezone."""
dt_type = BIGQUERY_TO_PYARROW_TYPES["DATETIME"]
assert dt_type == pa.timestamp("us")
def test_complex_types_map_to_string(self):
"""STRUCT, RECORD, ARRAY, GEOGRAPHY, JSON all serialize as string."""
for t in ["STRUCT", "RECORD", "ARRAY", "GEOGRAPHY", "JSON"]:
assert BIGQUERY_TO_PYARROW_TYPES[t] == pa.string()
# ---------------------------------------------------------------------------
# 13. Metadata cache save/load from disk
# ---------------------------------------------------------------------------
class TestMetadataCachePersistence:
def test_save_and_load_cache(self, tmp_path):
"""Metadata cache is persisted to disk and reloaded on new client init."""
metadata_dir = tmp_path / "metadata"
metadata_dir.mkdir(parents=True, exist_ok=True)
cache_file = metadata_dir / "bq_table_metadata.json"
mock_config = MagicMock()
mock_config.get_metadata_path.return_value = metadata_dir
# First client: fetch metadata and save to cache
mock_bq = MagicMock()
table_id = "proj.ds.tbl"
schema = [_make_bq_field("col1", "STRING")]
table_ref = _make_table_ref(table_id=table_id, schema=schema)
mock_bq.get_table.return_value = table_ref
with (
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq),
patch("connectors.bigquery.client.get_config", return_value=mock_config),
patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}),
):
client1 = BigQueryClient()
client1.get_table_metadata(table_id, use_cache=False)
# Verify the cache file was written
assert cache_file.exists()
saved_data = json.loads(cache_file.read_text())
assert table_id in saved_data
assert saved_data[table_id]["columns"] == ["col1"]
# Second client: loads cache from disk on init
mock_bq2 = MagicMock()
with (
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq2),
patch("connectors.bigquery.client.get_config", return_value=mock_config),
patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}),
):
client2 = BigQueryClient()
assert table_id in client2.metadata_cache
assert client2.metadata_cache[table_id]["columns"] == ["col1"]
def test_load_handles_corrupt_cache_file(self, tmp_path):
"""Client handles corrupt cache JSON gracefully without crashing."""
metadata_dir = tmp_path / "metadata"
metadata_dir.mkdir(parents=True, exist_ok=True)
cache_file = metadata_dir / "bq_table_metadata.json"
cache_file.write_text("{corrupt json!!!")
mock_config = MagicMock()
mock_config.get_metadata_path.return_value = metadata_dir
mock_bq = MagicMock()
with (
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq),
patch("connectors.bigquery.client.get_config", return_value=mock_config),
patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}),
):
client = BigQueryClient()
# Cache should be empty after corrupt file
assert client.metadata_cache == {}
def test_load_handles_missing_cache_file(self, tmp_path):
"""Client initializes with empty cache when no cache file exists."""
metadata_dir = tmp_path / "metadata"
metadata_dir.mkdir(parents=True, exist_ok=True)
# No cache file created
mock_config = MagicMock()
mock_config.get_metadata_path.return_value = metadata_dir
mock_bq = MagicMock()
with (
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq),
patch("connectors.bigquery.client.get_config", return_value=mock_config),
patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}),
):
client = BigQueryClient()
assert client.metadata_cache == {}
def test_save_creates_parent_directories(self, tmp_path):
"""_save_metadata_cache creates parent directories if they do not exist."""
# Use a nested path that does not yet exist
metadata_dir = tmp_path / "deep" / "nested" / "metadata"
# Do NOT create directories upfront
mock_config = MagicMock()
mock_config.get_metadata_path.return_value = metadata_dir
mock_bq = MagicMock()
table_id = "proj.ds.tbl"
schema = [_make_bq_field("x", "INTEGER")]
table_ref = _make_table_ref(table_id=table_id, schema=schema)
mock_bq.get_table.return_value = table_ref
with (
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq),
patch("connectors.bigquery.client.get_config", return_value=mock_config),
patch.dict("os.environ", {"BIGQUERY_PROJECT": "proj"}),
):
client = BigQueryClient()
client.get_table_metadata(table_id, use_cache=False)
cache_file = metadata_dir / "bq_table_metadata.json"
assert cache_file.exists()
# ---------------------------------------------------------------------------
# Factory function
# ---------------------------------------------------------------------------
class TestCreateClient:
def test_create_client_returns_bigquery_client(self, mock_config):
"""create_client() factory returns a BigQueryClient instance."""
mock_bq = MagicMock()
with (
patch("connectors.bigquery.client.bigquery.Client", return_value=mock_bq),
patch("connectors.bigquery.client.get_config", return_value=mock_config),
patch.dict("os.environ", {"BIGQUERY_PROJECT": "factory-project"}),
):
result = create_client()
assert isinstance(result, BigQueryClient)
assert result.project_id == "factory-project"