agnes-the-ai-analyst/connectors/bigquery/adapter.py
Petr 8bb46a9e0a Add per-partition streaming sync and hybrid query architecture
Partitioned sync: iterates day-by-day instead of loading full dataset.
Each partition: query BQ -> stream to disk -> free RAM. Peak ~50 MB.
New helpers: _sync_single_partition, _cleanup_old_partitions, _generate_partition_dates.

Config: added partition_column_type (DATE/TIMESTAMP/DATETIME), query_mode (local/remote/hybrid).
DuckDB manager: hybrid architecture support (local Parquet + remote BQ tables).
Data sync: skips remote tables, filters by query_mode.

Tests: 113 passing (adapter, client, config, data_sync, duckdb_manager).
2026-03-12 13:20:41 +01:00

650 lines
23 KiB
Python

"""
BigQuery data source adapter.
Implements the DataSource interface for Google BigQuery.
Reads tables via the BigQuery API, converts directly to Parquet files
using PyArrow (no CSV intermediate step).
"""
import logging
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta, date
import pyarrow as pa
import pyarrow.parquet as pq
from src.config import get_config, TableConfig
from src.data_sync import DataSource, SyncState, _get_uncompressed_size
from src.parquet_manager import (
convert_date_columns_to_date32,
apply_schema_to_table,
)
from .client import create_client as create_bq_client
logger = logging.getLogger(__name__)
class BigQueryDataSource(DataSource):
"""
Data source: Google BigQuery.
Downloads data directly from BigQuery via PyArrow (no CSV step),
writes to local Parquet files with schema enforcement.
"""
def __init__(self):
"""Initialize BigQuery source with env var validation."""
self.config = get_config()
self.bq_client = create_bq_client()
def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]:
"""Return BigQuery column metadata for schema generation.
Returns:
{"columns": {"col_name": {"source_type": "...", "description": "..."}}}
or None if metadata unavailable.
"""
raw = self.bq_client.get_table_metadata(table_id)
column_types = raw.get("column_types", {})
column_descriptions = raw.get("column_descriptions", {})
if not column_types:
return None
result = {}
for col_name, bq_type in column_types.items():
entry = {"source_type": bq_type}
if col_name in column_descriptions:
entry["description"] = column_descriptions[col_name]
result[col_name] = entry
return {"columns": result}
def discover_tables(self) -> List[Dict[str, Any]]:
"""Discover all available tables from BigQuery."""
return self.bq_client.discover_all_tables()
def get_source_name(self) -> str:
"""Display name of this data source."""
return "Google BigQuery"
def sync_table(
self,
table_config: TableConfig,
sync_state: SyncState,
) -> Dict[str, Any]:
"""
Synchronize table from BigQuery.
Dispatches to the appropriate strategy based on table config.
Args:
table_config: Table configuration
sync_state: Sync state manager
Returns:
Dictionary with sync result
"""
logger.info(f"Syncing BQ table: {table_config.name} ({table_config.sync_strategy})")
# Clear metadata cache for fresh types
if table_config.id in self.bq_client.metadata_cache:
del self.bq_client.metadata_cache[table_config.id]
logger.debug(f"Cleared BQ metadata cache for {table_config.id}")
try:
if table_config.sync_strategy == "full_refresh":
result = self._full_refresh(table_config)
elif table_config.sync_strategy == "incremental":
result = self._incremental_sync(table_config, sync_state)
elif table_config.sync_strategy == "partitioned":
result = self._partitioned_sync(table_config, sync_state)
else:
raise ValueError(f"Unknown sync strategy: {table_config.sync_strategy}")
# Update sync state
sync_state.update_sync(
table_id=table_config.id,
table_name=table_config.name,
strategy=table_config.sync_strategy,
rows=result["rows"],
file_size_bytes=result["file_size_bytes"],
columns=result.get("columns", 0),
uncompressed_bytes=result.get("uncompressed_bytes", 0),
)
return {
"success": True,
"rows": result["rows"],
"strategy": table_config.sync_strategy,
"file_size_mb": result["file_size_bytes"] / 1024 / 1024,
}
except Exception as e:
logger.error(f"Error syncing BQ table {table_config.name}: {e}")
return {
"success": False,
"error": str(e),
"strategy": table_config.sync_strategy,
}
def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]:
"""
Full refresh: stream table from BQ and write to Parquet in batches.
Uses streaming (constant memory) instead of loading entire table into RAM.
Each RecordBatch from BQ is written directly to disk via ParquetWriter.
"""
logger.info(f"Full refresh (streaming): {table_config.name}")
parquet_path = self.config.get_parquet_path(table_config)
date_columns = self.bq_client.get_date_columns(table_config.id)
pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
parquet_path.parent.mkdir(parents=True, exist_ok=True)
# Stream BQ results directly to Parquet file (constant memory)
writer = None
total_rows = 0
num_columns = 0
for batch in self.bq_client.read_table_streaming(
table_config.id,
columns=table_config.columns,
row_filter=table_config.row_filter,
):
if batch.num_rows == 0:
continue
# Convert batch to table for schema enforcement
chunk = pa.Table.from_batches([batch])
if date_columns:
chunk = convert_date_columns_to_date32(chunk, date_columns)
if pyarrow_schema:
chunk = apply_schema_to_table(chunk, pyarrow_schema)
if writer is None:
writer = pq.ParquetWriter(
parquet_path, chunk.schema, compression="snappy",
)
num_columns = chunk.num_columns
writer.write_table(chunk)
total_rows += chunk.num_rows
# Log progress every ~1M rows
if total_rows % 1_000_000 < chunk.num_rows:
logger.info(f" -> {total_rows:,} rows written...")
if writer:
writer.close()
file_size = parquet_path.stat().st_size if parquet_path.exists() else 0
logger.info(
f"Full refresh complete: {total_rows:,} rows, "
f"{file_size / 1024 / 1024:.2f} MB"
)
return {
"rows": total_rows,
"columns": num_columns,
"file_size_bytes": file_size,
"uncompressed_bytes": _get_uncompressed_size(parquet_path) if total_rows > 0 else 0,
}
def _incremental_sync(
self,
table_config: TableConfig,
sync_state: SyncState,
) -> Dict[str, Any]:
"""
Incremental sync: dispatch to column-based or partition-based strategy.
"""
# If partition_by is set, use partitioned incremental
if table_config.partition_by:
return self._partitioned_sync(table_config, sync_state)
# If incremental_column is set, use timestamp-based incremental
if table_config.incremental_column:
return self._incremental_column_sync(table_config, sync_state)
# Fallback: full refresh (no incremental column configured)
logger.warning(
f"Table {table_config.name}: incremental strategy but no "
f"incremental_column or partition_by configured, falling back to full refresh"
)
return self._full_refresh(table_config)
def _incremental_column_sync(
self,
table_config: TableConfig,
sync_state: SyncState,
) -> Dict[str, Any]:
"""
Timestamp-based incremental sync using incremental_column.
Reads rows WHERE incremental_column > last_sync_value,
merges with existing Parquet (dedup on primary key).
"""
logger.info(
f"Incremental column sync: {table_config.name} "
f"(column: {table_config.incremental_column})"
)
parquet_path = self.config.get_parquet_path(table_config)
date_columns = self.bq_client.get_date_columns(table_config.id)
pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
# Determine since_value from last sync
last_sync = sync_state.get_last_sync(table_config.id)
if last_sync and parquet_path.exists():
# Apply window: go back incremental_window_days from last sync
last_sync_dt = datetime.fromisoformat(last_sync)
window_days = table_config.incremental_window_days or 7
since_dt = last_sync_dt - timedelta(days=window_days)
since_value = since_dt.isoformat()
logger.info(f" -> Since: {since_value} (window: {window_days} days)")
# Read incremental data
new_data = self.bq_client.read_table_incremental(
table_id=table_config.id,
incremental_column=table_config.incremental_column,
since_value=since_value,
columns=table_config.columns,
)
if new_data.num_rows == 0:
logger.info(" -> No new data since last sync")
existing_pf = pq.ParquetFile(parquet_path)
return {
"rows": existing_pf.metadata.num_rows,
"columns": len(existing_pf.schema_arrow),
"file_size_bytes": parquet_path.stat().st_size,
"uncompressed_bytes": _get_uncompressed_size(parquet_path),
}
# Merge with existing data
logger.info(f" -> Merging {new_data.num_rows} new rows with existing data")
existing_table = pq.read_table(parquet_path)
merged = self._merge_arrow_tables(
existing_table, new_data, table_config.get_primary_key_columns()
)
# Apply schema enforcement
if date_columns:
merged = convert_date_columns_to_date32(merged, date_columns)
if pyarrow_schema:
merged = apply_schema_to_table(merged, pyarrow_schema)
pq.write_table(merged, parquet_path, compression="snappy")
file_size = parquet_path.stat().st_size
logger.info(
f" -> Incremental sync complete: {merged.num_rows} total rows"
)
return {
"rows": merged.num_rows,
"columns": merged.num_columns,
"file_size_bytes": file_size,
"uncompressed_bytes": _get_uncompressed_size(parquet_path),
}
else:
# First sync or no existing file -- full read
logger.info(" -> First sync, reading all data")
if table_config.max_history_days:
since_dt = datetime.now() - timedelta(days=table_config.max_history_days)
arrow_table = self.bq_client.read_table_incremental(
table_id=table_config.id,
incremental_column=table_config.incremental_column,
since_value=since_dt.isoformat(),
columns=table_config.columns,
)
else:
arrow_table = self.bq_client.read_table(
table_config.id,
columns=table_config.columns,
row_filter=table_config.row_filter,
)
# Apply schema enforcement
if date_columns:
arrow_table = convert_date_columns_to_date32(arrow_table, date_columns)
if pyarrow_schema:
arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema)
parquet_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(arrow_table, parquet_path, compression="snappy")
file_size = parquet_path.stat().st_size
return {
"rows": arrow_table.num_rows,
"columns": arrow_table.num_columns,
"file_size_bytes": file_size,
"uncompressed_bytes": _get_uncompressed_size(parquet_path),
}
def _partitioned_sync(
self,
table_config: TableConfig,
sync_state: SyncState,
) -> Dict[str, Any]:
"""
Per-partition streaming sync: process one partition (day) at a time.
Queries BQ for a single day, streams result to disk, then moves to next day.
Memory usage is constant (~20-50 MB per partition) regardless of total data volume.
"""
partition_col = table_config.partition_by
if not partition_col and table_config.incremental_column:
partition_col = table_config.incremental_column
if not partition_col:
logger.warning(
f"Table {table_config.name}: partitioned strategy but no "
f"partition_by or incremental_column, falling back to full refresh"
)
return self._full_refresh(table_config)
granularity = table_config.partition_granularity or "day"
column_type = table_config.partition_column_type
logger.info(
f"Partitioned sync: {table_config.name} "
f"(by {partition_col}, {granularity}, type={column_type})"
)
partition_dir = self.config.get_parquet_path(table_config)
date_columns = self.bq_client.get_date_columns(table_config.id)
pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
# Determine date range
last_sync = sync_state.get_last_sync(table_config.id)
today = date.today()
if last_sync:
last_sync_dt = datetime.fromisoformat(last_sync)
window_days = table_config.incremental_window_days or 7
start_date = (last_sync_dt - timedelta(days=window_days)).date()
logger.info(f" -> Incremental sync from {start_date} (window: {window_days} days)")
else:
if table_config.max_history_days:
start_date = today - timedelta(days=table_config.max_history_days)
logger.info(f" -> First sync, last {table_config.max_history_days} days from {start_date}")
else:
start_date = today - timedelta(days=365)
logger.info(" -> First sync, no max_history_days, defaulting to 365 days")
# Generate list of partition dates
partition_dates = self._generate_partition_dates(start_date, today, granularity)
logger.info(f" -> Processing {len(partition_dates)} partitions")
total_rows = 0
partitions_updated = 0
for partition_date in partition_dates:
rows = self._sync_single_partition(
table_config=table_config,
partition_col=partition_col,
partition_date=partition_date,
partition_dir=partition_dir,
date_columns=date_columns,
pyarrow_schema=pyarrow_schema,
granularity=granularity,
column_type=column_type,
)
if rows > 0:
partitions_updated += 1
total_rows += rows
# Cleanup old partitions beyond retention window
deleted = self._cleanup_old_partitions(table_config, partition_dir, granularity)
if deleted > 0:
logger.info(f" -> Cleaned up {deleted} old partition files")
logger.info(
f" -> Partitioned sync complete: {partitions_updated} partitions updated, "
f"{total_rows} total rows processed"
)
return self._get_partition_totals(partition_dir)
@staticmethod
def _generate_partition_dates(
start_date: date,
end_date: date,
granularity: str,
) -> List[date]:
"""Generate list of partition start dates between start and end."""
dates = []
current = start_date
if granularity == "day":
while current <= end_date:
dates.append(current)
current += timedelta(days=1)
elif granularity == "month":
# Align to first of month
current = current.replace(day=1)
while current <= end_date:
dates.append(current)
# Move to first of next month
if current.month == 12:
current = current.replace(year=current.year + 1, month=1)
else:
current = current.replace(month=current.month + 1)
elif granularity == "year":
current = current.replace(month=1, day=1)
while current <= end_date:
dates.append(current)
current = current.replace(year=current.year + 1)
return dates
def _sync_single_partition(
self,
table_config: TableConfig,
partition_col: str,
partition_date: date,
partition_dir: Path,
date_columns: List[str],
pyarrow_schema,
granularity: str,
column_type: str,
) -> int:
"""
Query BQ for one partition period, stream to disk, merge with existing file.
Returns row count for this partition after merge.
"""
import pandas as pd
# Calculate partition range [start, end)
start = partition_date
if granularity == "day":
end = start + timedelta(days=1)
partition_key = start.strftime("%Y_%m_%d")
elif granularity == "month":
if start.month == 12:
end = start.replace(year=start.year + 1, month=1)
else:
end = start.replace(month=start.month + 1)
partition_key = start.strftime("%Y_%m")
elif granularity == "year":
end = start.replace(year=start.year + 1)
partition_key = start.strftime("%Y")
else:
raise ValueError(f"Unknown granularity: {granularity}")
partition_path = self.config.get_partition_path(table_config, partition_key)
# Stream data from BQ for this single partition
batches = []
for batch in self.bq_client.read_table_partitioned_streaming(
table_id=table_config.id,
partition_column=partition_col,
start=start.isoformat(),
end=end.isoformat(),
columns=table_config.columns,
column_type=column_type,
):
batches.append(batch)
if not batches:
return 0
new_data = pa.Table.from_batches(batches)
if new_data.num_rows == 0:
return 0
# Apply schema conversions
if date_columns:
new_data = convert_date_columns_to_date32(new_data, date_columns)
if pyarrow_schema:
new_data = apply_schema_to_table(new_data, pyarrow_schema)
# Merge with existing partition file if present
primary_key_cols = table_config.get_primary_key_columns()
if partition_path.exists():
existing = pq.read_table(partition_path)
merged = self._merge_arrow_tables(existing, new_data, primary_key_cols)
else:
merged = new_data
# Write partition file
pq.write_table(merged, partition_path, compression="snappy")
row_count = merged.num_rows
logger.debug(
f" Partition {partition_key}: {new_data.num_rows} new rows, "
f"{row_count} total after merge"
)
# Release memory
del batches, new_data, merged
return row_count
def _cleanup_old_partitions(
self,
table_config: TableConfig,
partition_dir: Path,
granularity: str,
) -> int:
"""
Delete partition files older than max_history_days.
Returns count of deleted files.
"""
if not table_config.max_history_days:
return 0
if not partition_dir.exists():
return 0
cutoff_date = date.today() - timedelta(days=table_config.max_history_days)
deleted = 0
for part_path in partition_dir.glob("*.parquet"):
try:
partition_date = self._parse_partition_date(part_path.stem, granularity)
if partition_date and partition_date < cutoff_date:
part_path.unlink()
deleted += 1
logger.debug(f" Deleted old partition: {part_path.name}")
except (ValueError, IndexError):
logger.warning(f" Skipping unrecognized partition file: {part_path.name}")
return deleted
@staticmethod
def _parse_partition_date(partition_key: str, granularity: str) -> Optional[date]:
"""Parse a partition key back to a date."""
try:
if granularity == "day":
return datetime.strptime(partition_key, "%Y_%m_%d").date()
elif granularity == "month":
return datetime.strptime(partition_key, "%Y_%m").date()
elif granularity == "year":
return datetime.strptime(partition_key, "%Y").date()
except ValueError:
return None
return None
def _merge_arrow_tables(
self,
existing: pa.Table,
new_data: pa.Table,
primary_key: List[str],
) -> pa.Table:
"""
Merge two Arrow tables with deduplication on primary key.
New data overwrites existing rows with the same primary key.
Args:
existing: Existing data
new_data: New/changed data
primary_key: List of PK column names
Returns:
Merged PyArrow Table
"""
import pandas as pd
existing_df = existing.to_pandas()
new_df = new_data.to_pandas()
# Concat and dedup (keep last = new data wins)
merged_df = pd.concat([existing_df, new_df], ignore_index=True)
merged_df = merged_df.drop_duplicates(subset=primary_key, keep="last")
return pa.Table.from_pandas(merged_df, preserve_index=False)
def _get_partition_totals(self, partition_dir: Path) -> Dict[str, Any]:
"""
Calculate totals from all partition files in a directory.
"""
total_rows = 0
total_size = 0
total_uncompressed = 0
total_columns = 0
if not partition_dir.exists():
return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0}
all_partitions = list(partition_dir.glob("*.parquet"))
for part_path in all_partitions:
try:
pf = pq.ParquetFile(part_path)
meta = pf.metadata
total_rows += meta.num_rows
total_size += part_path.stat().st_size
if total_columns == 0:
total_columns = len(pf.schema_arrow)
for rg_idx in range(meta.num_row_groups):
rg = meta.row_group(rg_idx)
for col_idx in range(rg.num_columns):
total_uncompressed += rg.column(col_idx).total_uncompressed_size
except Exception as e:
logger.warning(f"Skipping corrupt partition {part_path.name}: {e}")
return {
"rows": total_rows,
"file_size_bytes": total_size,
"partitions": len(all_partitions),
"columns": total_columns,
"uncompressed_bytes": total_uncompressed,
}
def create_data_source() -> BigQueryDataSource:
"""Factory function for dynamic import compatibility."""
return BigQueryDataSource()