Add per-partition streaming sync and hybrid query architecture

Partitioned sync: iterates day-by-day instead of loading full dataset.
Each partition: query BQ -> stream to disk -> free RAM. Peak ~50 MB.
New helpers: _sync_single_partition, _cleanup_old_partitions, _generate_partition_dates.

Config: added partition_column_type (DATE/TIMESTAMP/DATETIME), query_mode (local/remote/hybrid).
DuckDB manager: hybrid architecture support (local Parquet + remote BQ tables).
Data sync: skips remote tables, filters by query_mode.

Tests: 113 passing (adapter, client, config, data_sync, duckdb_manager).
This commit is contained in:
Petr 2026-03-12 13:20:41 +01:00
parent d2e83ce9d0
commit 8bb46a9e0a
9 changed files with 1731 additions and 135 deletions

View file

@ -9,7 +9,7 @@ using PyArrow (no CSV intermediate step).
import logging
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
from datetime import datetime, timedelta, date
import pyarrow as pa
import pyarrow.parquet as pq
@ -336,10 +336,11 @@ class BigQueryDataSource(DataSource):
sync_state: SyncState,
) -> Dict[str, Any]:
"""
Partition-based sync: read data by partition range and write partition files.
"""
import pandas as pd
Per-partition streaming sync: process one partition (day) at a time.
Queries BQ for a single day, streams result to disk, then moves to next day.
Memory usage is constant (~20-50 MB per partition) regardless of total data volume.
"""
partition_col = table_config.partition_by
if not partition_col and table_config.incremental_column:
partition_col = table_config.incremental_column
@ -351,96 +352,231 @@ class BigQueryDataSource(DataSource):
)
return self._full_refresh(table_config)
granularity = table_config.partition_granularity or "month"
granularity = table_config.partition_granularity or "day"
column_type = table_config.partition_column_type
logger.info(
f"Partitioned sync: {table_config.name} "
f"(by {partition_col}, {granularity})"
f"(by {partition_col}, {granularity}, type={column_type})"
)
partition_dir = self.config.get_parquet_path(table_config)
date_columns = self.bq_client.get_date_columns(table_config.id)
pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
# Determine time range
# Determine date range
last_sync = sync_state.get_last_sync(table_config.id)
today = date.today()
if last_sync:
last_sync_dt = datetime.fromisoformat(last_sync)
window_days = table_config.incremental_window_days or 7
start_dt = last_sync_dt - timedelta(days=window_days)
logger.info(f" -> Reading from {start_dt.isoformat()} (window: {window_days} days)")
start_date = (last_sync_dt - timedelta(days=window_days)).date()
logger.info(f" -> Incremental sync from {start_date} (window: {window_days} days)")
else:
if table_config.max_history_days:
start_dt = datetime.now() - timedelta(days=table_config.max_history_days)
logger.info(f" -> First sync, limited to last {table_config.max_history_days} days")
start_date = today - timedelta(days=table_config.max_history_days)
logger.info(f" -> First sync, last {table_config.max_history_days} days from {start_date}")
else:
start_dt = None
logger.info(" -> First sync, reading all data")
start_date = today - timedelta(days=365)
logger.info(" -> First sync, no max_history_days, defaulting to 365 days")
# Read data from BigQuery
if start_dt:
arrow_table = self.bq_client.read_table_partitioned(
table_id=table_config.id,
partition_column=partition_col,
start=start_dt.isoformat(),
columns=table_config.columns,
)
else:
arrow_table = self.bq_client.read_table(
table_config.id,
columns=table_config.columns,
row_filter=table_config.row_filter,
# Generate list of partition dates
partition_dates = self._generate_partition_dates(start_date, today, granularity)
logger.info(f" -> Processing {len(partition_dates)} partitions")
total_rows = 0
partitions_updated = 0
for partition_date in partition_dates:
rows = self._sync_single_partition(
table_config=table_config,
partition_col=partition_col,
partition_date=partition_date,
partition_dir=partition_dir,
date_columns=date_columns,
pyarrow_schema=pyarrow_schema,
granularity=granularity,
column_type=column_type,
)
if rows > 0:
partitions_updated += 1
total_rows += rows
if arrow_table.num_rows == 0:
logger.info(" -> No data to sync")
return self._get_partition_totals(partition_dir)
# Cleanup old partitions beyond retention window
deleted = self._cleanup_old_partitions(table_config, partition_dir, granularity)
if deleted > 0:
logger.info(f" -> Cleaned up {deleted} old partition files")
logger.info(f" -> Processing {arrow_table.num_rows} rows into partitions")
# Convert to pandas for partitioning
df = arrow_table.to_pandas()
# Ensure partition column is datetime
if not pd.api.types.is_datetime64_any_dtype(df[partition_col]):
df[partition_col] = pd.to_datetime(df[partition_col], format="ISO8601", utc=True)
# Create partition key
if granularity == "month":
df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m")
elif granularity == "day":
df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m_%d")
elif granularity == "year":
df["_partition_key"] = df[partition_col].dt.strftime("%Y")
primary_key_cols = table_config.get_primary_key_columns()
partitions_updated = set()
for partition_key, group_df in df.groupby("_partition_key"):
group_df = group_df.drop(columns=["_partition_key"])
partition_path = self.config.get_partition_path(table_config, partition_key)
partitions_updated.add(partition_key)
# Merge with existing partition if it exists
if partition_path.exists():
existing_df = pd.read_parquet(partition_path)
merged_df = pd.concat([existing_df, group_df], ignore_index=True)
merged_df = merged_df.drop_duplicates(subset=primary_key_cols, keep="last")
else:
merged_df = group_df
# Write partition
table = pa.Table.from_pandas(merged_df, preserve_index=False)
if date_columns:
table = convert_date_columns_to_date32(table, date_columns)
if pyarrow_schema:
table = apply_schema_to_table(table, pyarrow_schema)
pq.write_table(table, partition_path, compression="snappy")
logger.info(f" -> Partitioned sync complete: {len(partitions_updated)} partitions updated")
logger.info(
f" -> Partitioned sync complete: {partitions_updated} partitions updated, "
f"{total_rows} total rows processed"
)
return self._get_partition_totals(partition_dir)
@staticmethod
def _generate_partition_dates(
    start_date: date,
    end_date: date,
    granularity: str,
) -> List[date]:
    """
    List the first day of every partition period touching [start_date, end_date].

    Args:
        start_date: First date that must be covered.
        end_date: Last date that must be covered (inclusive).
        granularity: "day", "month" or "year".

    Returns:
        Period start dates in ascending order. For "month" and "year" the
        first entry is snapped back to the first day of the period containing
        start_date. An unrecognised granularity yields an empty list.
    """

    def _next_day(d: date) -> date:
        return d + timedelta(days=1)

    def _next_month(d: date) -> date:
        # divmod turns December (12) into a carry of one year and month index 0.
        year_carry, month_index = divmod(d.month, 12)
        return d.replace(year=d.year + year_carry, month=month_index + 1)

    def _next_year(d: date) -> date:
        return d.replace(year=d.year + 1)

    if granularity == "day":
        cursor, advance = start_date, _next_day
    elif granularity == "month":
        cursor, advance = start_date.replace(day=1), _next_month
    elif granularity == "year":
        cursor, advance = start_date.replace(month=1, day=1), _next_year
    else:
        return []

    period_starts = []
    while cursor <= end_date:
        period_starts.append(cursor)
        cursor = advance(cursor)
    return period_starts
def _sync_single_partition(
    self,
    table_config: TableConfig,
    partition_col: str,
    partition_date: date,
    partition_dir: Path,
    date_columns: List[str],
    pyarrow_schema,
    granularity: str,
    column_type: str,
) -> int:
    """
    Sync one partition period from BigQuery into its Parquet file.

    Reads the half-open range [start, end) of the period containing
    partition_date, collects the returned record batches in memory, applies
    the date/schema conversions, merges the result with an existing partition
    file (if any) and rewrites that file.

    Args:
        table_config: Table being synced (id, columns, primary key).
        partition_col: Column used to filter the period in BigQuery.
        partition_date: Any date inside the period; aligned to the period start.
        partition_dir: Directory holding the partition files. The file path
            itself comes from self.config.get_partition_path().
        date_columns: Columns passed to convert_date_columns_to_date32.
        pyarrow_schema: Schema passed to apply_schema_to_table (may be falsy).
        granularity: "day", "month" or "year".
        column_type: Partition column type forwarded to the BQ client.

    Returns:
        Row count of the partition file after the merge, or 0 when BigQuery
        returned no rows for the period (no file is written in that case).

    Raises:
        ValueError: If granularity is not one of the supported values.
    """
    # Calculate partition range [start, end). The date is aligned to the
    # period start so that a mid-period date can neither overflow in
    # replace() (e.g. Jan 31 -> "Feb 31") nor query a range that is shifted
    # relative to the partition key.
    if granularity == "day":
        start = partition_date
        end = start + timedelta(days=1)
        partition_key = start.strftime("%Y_%m_%d")
    elif granularity == "month":
        start = partition_date.replace(day=1)
        if start.month == 12:
            end = start.replace(year=start.year + 1, month=1)
        else:
            end = start.replace(month=start.month + 1)
        partition_key = start.strftime("%Y_%m")
    elif granularity == "year":
        start = partition_date.replace(month=1, day=1)
        end = start.replace(year=start.year + 1)
        partition_key = start.strftime("%Y")
    else:
        raise ValueError(f"Unknown granularity: {granularity}")

    partition_path = self.config.get_partition_path(table_config, partition_key)

    # Pull all record batches for this single period from BQ.
    batches = list(
        self.bq_client.read_table_partitioned_streaming(
            table_id=table_config.id,
            partition_column=partition_col,
            start=start.isoformat(),
            end=end.isoformat(),
            columns=table_config.columns,
            column_type=column_type,
        )
    )
    if not batches:
        return 0

    new_data = pa.Table.from_batches(batches)
    if new_data.num_rows == 0:
        return 0

    # Apply schema conversions
    if date_columns:
        new_data = convert_date_columns_to_date32(new_data, date_columns)
    if pyarrow_schema:
        new_data = apply_schema_to_table(new_data, pyarrow_schema)

    # Merge with existing partition file if present
    primary_key_cols = table_config.get_primary_key_columns()
    if partition_path.exists():
        existing = pq.read_table(partition_path)
        merged = self._merge_arrow_tables(existing, new_data, primary_key_cols)
    else:
        merged = new_data

    # The target directory may not exist yet on a first sync; creating it is
    # a no-op when it is already there.
    partition_path.parent.mkdir(parents=True, exist_ok=True)
    pq.write_table(merged, partition_path, compression="snappy")

    row_count = merged.num_rows
    logger.debug(
        f" Partition {partition_key}: {new_data.num_rows} new rows, "
        f"{row_count} total after merge"
    )
    return row_count
def _cleanup_old_partitions(
    self,
    table_config: TableConfig,
    partition_dir: Path,
    granularity: str,
) -> int:
    """
    Delete partition files that lie entirely before the retention cutoff.

    The cutoff is today minus table_config.max_history_days. A partition is
    removed only when its whole period ends on or before the cutoff, so a
    month or year partition that still overlaps the retention window is kept.

    Args:
        table_config: Table whose max_history_days defines the retention.
        partition_dir: Directory scanned for "*.parquet" partition files.
        granularity: "day", "month" or "year"; decides how file names are
            parsed and how long each period is.

    Returns:
        Number of files deleted. 0 when no retention is configured or the
        directory does not exist.
    """
    if not table_config.max_history_days:
        return 0
    if not partition_dir.exists():
        return 0

    cutoff_date = date.today() - timedelta(days=table_config.max_history_days)
    deleted = 0

    for part_path in partition_dir.glob("*.parquet"):
        period_start = self._parse_partition_date(part_path.stem, granularity)
        if period_start is None:
            # _parse_partition_date reports bad names as None, not by raising.
            logger.warning(f" Skipping unrecognized partition file: {part_path.name}")
            continue

        # Exclusive end of the period this file covers.
        if granularity == "month":
            if period_start.month == 12:
                period_end = period_start.replace(year=period_start.year + 1, month=1)
            else:
                period_end = period_start.replace(month=period_start.month + 1)
        elif granularity == "year":
            period_end = period_start.replace(year=period_start.year + 1)
        else:
            period_end = period_start + timedelta(days=1)

        # Comparing the period END keeps partitions that still hold rows
        # inside the retention window.
        if period_end <= cutoff_date:
            part_path.unlink()
            deleted += 1
            logger.debug(f" Deleted old partition: {part_path.name}")

    return deleted
@staticmethod
def _parse_partition_date(partition_key: str, granularity: str) -> Optional[date]:
    """
    Convert a partition key (file stem) back into the period's start date.

    Args:
        partition_key: Key such as "2026_03_12", "2026_03" or "2026".
        granularity: "day", "month" or "year"; selects the expected key format.

    Returns:
        The parsed date (month/year keys resolve to the first day of the
        period), or None when the granularity is unknown or the key does not
        match the expected format.
    """
    key_formats = {
        "day": "%Y_%m_%d",
        "month": "%Y_%m",
        "year": "%Y",
    }
    key_format = key_formats.get(granularity)
    if key_format is None:
        return None
    try:
        parsed = datetime.strptime(partition_key, key_format)
    except ValueError:
        return None
    return parsed.date()
def _merge_arrow_tables(
self,
existing: pa.Table,

View file

@ -1,9 +1,11 @@
#!/usr/bin/env python3
"""
DuckDB Manager - Initialize and manage DuckDB database with views from parquet files.
DuckDB Manager - Initialize and manage DuckDB database with views from parquet files
and runtime BigQuery query registration for remote/hybrid tables.
This script dynamically reads table configurations from docs/data_description.md
and creates DuckDB views accordingly. No hardcoded table list needed!
For BigQuery data sources, tables with query_mode="remote" or "hybrid" are queried
at runtime via the Python BQ client and registered as in-memory Arrow tables in DuckDB.
This avoids the DuckDB BigQuery extension limitation (cannot read BQ views).
Usage:
python3 scripts/duckdb_manager.py --reinit # Initialize/reinitialize all views
@ -11,13 +13,17 @@ Usage:
"""
import duckdb
import logging
import os
import sys
import argparse
import re
import yaml
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
def find_project_root() -> Path:
@ -157,17 +163,131 @@ def get_parquet_path(table_config: Dict, folder_mapping: Dict[str, str], data_di
return parquet_dir / f"{table_name}.parquet"
def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbose=True):
def _get_bq_project_from_table_id(table_id: str) -> Optional[str]:
    """Return the project part of a "project.dataset.table" identifier.

    Args:
        table_id: e.g. "prj-grp-dataview-prod-1ff9.finance_unit_economics.unit_economics"

    Returns:
        The first dot-separated segment when the ID has exactly three
        segments and that first segment contains a hyphen; otherwise None.
    """
    segments = table_id.split(".")
    if len(segments) != 3:
        return None
    project = segments[0]
    if "-" not in project:
        return None
    return project
def _create_bq_client(project: str):
    """Build a BigQuery client; kept as its own function so tests can swap it.

    The google-cloud-bigquery import happens inside the function, so the
    package is only required when a client is actually created.

    Args:
        project: GCP project ID for billing

    Returns:
        google.cloud.bigquery.Client instance
    """
    from google.cloud import bigquery

    client = bigquery.Client(project=project)
    return client
def register_bq_table(
    conn: duckdb.DuckDBPyConnection,
    table_id: str,
    view_name: str,
    sql: str,
    bq_project: Optional[str] = None,
    _bq_client_factory=None,
) -> int:
    """
    Run a BigQuery SQL query and expose its result in DuckDB under view_name.

    The query result is fetched as a PyArrow table and registered on the
    DuckDB connection, so it lives in memory only.

    Args:
        conn: Open DuckDB connection
        table_id: BQ table ID, used for logging only (e.g., "project.dataset.table")
        view_name: Name to register in DuckDB (e.g., "unit_economics_live")
        sql: Full BigQuery SQL query to execute
        bq_project: GCP project for billing. If None, uses BIGQUERY_PROJECT env var.
        _bq_client_factory: Override BQ client creation (for testing)

    Returns:
        Number of rows in the result

    Raises:
        ImportError: If google-cloud-bigquery is not installed
        ValueError: If neither bq_project nor BIGQUERY_PROJECT is set
    """
    billing_project = bq_project if bq_project else os.environ.get("BIGQUERY_PROJECT")
    if not billing_project:
        raise ValueError(
            "BigQuery project not set. "
            "Pass bq_project or set BIGQUERY_PROJECT env var."
        )

    logger.info(f"Querying BQ: {table_id} -> {view_name}")
    logger.debug(f"SQL: {sql[:200]}...")

    make_client = _bq_client_factory if _bq_client_factory else _create_bq_client
    query_job = make_client(billing_project).query(sql)

    try:
        result = query_job.to_arrow()
    except Exception as exc:
        detail = str(exc)
        storage_api_blocked = "readsessions" in detail or "PERMISSION_DENIED" in detail
        if not storage_api_blocked:
            raise
        # Retry the download without a BQ Storage client when the failure
        # message points at missing Storage API access.
        logger.warning("BQ Storage API unavailable, falling back to REST")
        result = query_job.to_arrow(create_bqstorage_client=False)

    conn.register(view_name, result)

    logger.info(
        f"Registered {view_name}: {result.num_rows} rows, "
        f"{result.num_columns} cols (in-memory)"
    )
    return result.num_rows
def get_remote_tables(table_configs: List[Dict]) -> List[Dict]:
    """Select the table configs whose query_mode is 'remote' or 'hybrid'.

    Args:
        table_configs: List of table configuration dicts

    Returns:
        The matching configs, in their original order. Configs without a
        "query_mode" key are not included.
    """
    runtime_modes = ("remote", "hybrid")
    selected = []
    for cfg in table_configs:
        if cfg.get("query_mode") in runtime_modes:
            selected.append(cfg)
    return selected
def init_duckdb(
db_path="user/duckdb/analytics.duckdb",
data_dir="server",
verbose=True,
bq_project: Optional[str] = None,
):
"""
Initialize DuckDB database with views from parquet files.
Dynamically reads table configurations from docs/data_description.md
and creates views accordingly.
Creates DuckDB views for local/hybrid tables (from Parquet).
Remote tables are NOT pre-loaded -- they are registered at query time
via register_bq_table().
Args:
db_path: Path to DuckDB database file
data_dir: Base data directory (e.g., "server" for analysts, "data" for server)
verbose: Print progress messages
bq_project: BigQuery execution project (for informational purposes only)
Returns:
True if successful, False otherwise
@ -176,29 +296,48 @@ def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbo
os.makedirs(os.path.dirname(db_path), exist_ok=True)
if verbose:
print("🦆 Inicializuji DuckDB databázi...")
print("Initializing DuckDB database...")
try:
# Find project root and parse data_description.md
project_root = find_project_root()
if verbose:
print(f" 📂 Project root: {project_root}")
print(f" Project root: {project_root}")
table_configs, folder_mapping = parse_data_description(project_root)
if verbose:
print(f" 📋 Načteno {len(table_configs)} tabulek z data_description.md")
print(f" Loaded {len(table_configs)} tables from data_description.md")
# Separate tables by query_mode
local_tables = []
remote_tables = []
hybrid_tables = []
for tc in table_configs:
mode = tc.get("query_mode", "local")
if mode == "remote":
remote_tables.append(tc)
elif mode == "hybrid":
hybrid_tables.append(tc)
else:
local_tables.append(tc)
if verbose:
print(f" Query modes: {len(local_tables)} local, "
f"{len(remote_tables)} remote, {len(hybrid_tables)} hybrid")
# Connect to database (creates if doesn't exist)
conn = duckdb.connect(db_path)
# Create views
# Create local views from parquet files
if verbose:
print("\n📊 Vytvářím views z parquet souborů...")
print("\n Creating views from parquet files...")
created_views = []
skipped_views = []
for table_config in table_configs:
# Process local and hybrid tables (both have local parquet)
for table_config in local_tables + hybrid_tables:
table_name = table_config['name']
try:
@ -209,7 +348,7 @@ def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbo
if not parquet_path.exists():
skipped_views.append(table_name)
if verbose:
print(f" ⚠️ Přeskakuji {table_name} - parquet neexistuje: {parquet_path}")
print(f" [SKIP] {table_name} - parquet not found: {parquet_path}")
continue
# Determine if partitioned
@ -229,49 +368,65 @@ def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbo
if not partition_files:
skipped_views.append(table_name)
if verbose:
print(f" ⚠️ Přeskakuji {table_name} - žádné partition soubory")
print(f" [SKIP] {table_name} - no partition files")
continue
sql = f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{glob_pattern}', union_by_name=true)"
if verbose:
print(f"{table_name} ({len(partition_files)} partitions)")
mode_label = "hybrid" if table_config.get("query_mode") == "hybrid" else "local"
print(f" [OK] {table_name} ({len(partition_files)} partitions, {mode_label})")
else:
# Single parquet file
sql = f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{parquet_path}')"
if verbose:
print(f"{table_name}")
mode_label = "hybrid" if table_config.get("query_mode") == "hybrid" else "local"
print(f" [OK] {table_name} ({mode_label})")
conn.execute(sql)
created_views.append(table_name)
except Exception as e:
if verbose:
print(f" ❌ Chyba při vytváření {table_name}: {e}")
print(f" [ERR] Error creating {table_name}: {e}")
return False
# Log remote tables (queried at runtime via register_bq_table)
if remote_tables:
if verbose:
print("\n Remote tables (queried at runtime via BigQuery):")
for table_config in remote_tables:
table_name = table_config['name']
table_id = table_config['id']
if verbose:
print(f" [BQ] {table_name} -> {table_id}")
# Display table list with row counts
if verbose:
print(f"\n📋 Seznam dostupných tabulek ({len(created_views)} vytvořeno):")
print(f"\n Available tables ({len(created_views)} local views):")
tables = conn.execute("SHOW TABLES").fetchall()
for table in tables:
try:
row_count = conn.execute(f"SELECT COUNT(*) FROM {table[0]}").fetchone()[0]
print(f" - {table[0]}: {row_count:,} řádků")
except Exception as e:
print(f" - {table[0]}: (chyba při počítání řádků)")
print(f" - {table[0]}: {row_count:,} rows (local)")
except Exception:
print(f" - {table[0]}: (error counting rows)")
if remote_tables:
print(f"\n Remote tables ({len(remote_tables)}, loaded on demand):")
for tc in remote_tables:
print(f" - {tc['name']}: via BQ Query API (use date filters!)")
# Close connection
conn.close()
if verbose:
print(f"\n✅ DuckDB databáze vytvořena: {db_path}")
print("💡 Můžeš začít analyzovat data pomocí DuckDB SQL dotazů")
print(f"\n DuckDB database created: {db_path}")
return True
except Exception as e:
if verbose:
print(f"\n❌ Chyba při inicializaci DuckDB: {e}")
print(f"\n Error initializing DuckDB: {e}")
import traceback
traceback.print_exc()
return False
@ -297,6 +452,11 @@ def main():
default='server',
help='Base data directory (default: server, use "data" for server deployment)'
)
parser.add_argument(
'--bq-project',
default=None,
help='BigQuery execution project (informational only)'
)
parser.add_argument(
'--quiet',
action='store_true',
@ -314,7 +474,8 @@ def main():
success = init_duckdb(
db_path=args.db_path,
data_dir=args.data_dir,
verbose=not args.quiet
verbose=not args.quiet,
bq_project=args.bq_project,
)
# Exit with appropriate code

View file

@ -104,9 +104,19 @@ class TableConfig:
incremental_column: Optional[str] = None # Column for timestamp-based incremental sync (BigQuery)
columns: Optional[List[str]] = None # Subset of columns to sync (None = all)
row_filter: Optional[str] = None # SQL WHERE clause for filtering (e.g., "event_date >= '2024-01-01'")
query_mode: str = "local" # "local" (Parquet) | "remote" (BQ direct) | "hybrid" (sync subset, query BQ)
partition_column_type: str = "TIMESTAMP" # BQ SQL type for partition column: "DATE", "TIMESTAMP", "DATETIME"
def __post_init__(self):
"""Validate configuration after initialization."""
# Validate query_mode
valid_query_modes = ("local", "remote", "hybrid")
if self.query_mode not in valid_query_modes:
raise ValueError(
f"Invalid query_mode '{self.query_mode}' for table {self.id}. "
f"Allowed values: {', '.join(valid_query_modes)}"
)
# Validate sync_strategy
if self.sync_strategy not in ["full_refresh", "incremental", "partitioned"]:
raise ValueError(
@ -139,6 +149,14 @@ class TableConfig:
f"Allowed values: 'month', 'day', 'year'"
)
# Validate partition_column_type
valid_column_types = ("DATE", "TIMESTAMP", "DATETIME")
if self.partition_column_type not in valid_column_types:
raise ValueError(
f"Invalid partition_column_type '{self.partition_column_type}' for table {self.id}. "
f"Allowed values: {', '.join(valid_column_types)}"
)
# For partitioned, partition_by must be defined
if self.sync_strategy == "partitioned":
if not self.partition_by:
@ -435,6 +453,8 @@ class Config:
incremental_column=table_data.get("incremental_column"),
columns=table_data.get("columns"),
row_filter=table_data.get("row_filter"),
query_mode=table_data.get("query_mode", "local"),
partition_column_type=table_data.get("partition_column_type", "TIMESTAMP"),
)
table_configs.append(config)

View file

@ -406,6 +406,21 @@ class DataSyncManager:
else:
table_configs = self.config.tables
# Filter out remote-only tables (no local sync needed)
remote_skipped = [
tc for tc in table_configs if tc.query_mode == "remote"
]
table_configs = [
tc for tc in table_configs if tc.query_mode != "remote"
]
if remote_skipped:
logger.info(
f"Skipping {len(remote_skipped)} remote-only tables "
f"(query via BigQuery): "
f"{', '.join(tc.name for tc in remote_skipped)}"
)
logger.info(f"Synchronizing {len(table_configs)} tables...")
results = {}

View file

@ -9,6 +9,7 @@ so we install stub modules in sys.modules before importing the adapter.
"""
import sys
from datetime import date, datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
@ -83,6 +84,7 @@ def _make_table_config(
partition_by: str | None = None,
partition_granularity: str | None = None,
max_history_days: int | None = None,
partition_column_type: str = "TIMESTAMP",
) -> TableConfig:
"""Helper to build a TableConfig with safe defaults."""
return TableConfig(
@ -96,6 +98,7 @@ def _make_table_config(
partition_by=partition_by,
partition_granularity=partition_granularity,
max_history_days=max_history_days,
partition_column_type=partition_column_type,
)
@ -274,55 +277,217 @@ class TestIncrementalNoNewData:
# ---------------------------------------------------------------------------
# 4. partitioned_sync creates partition files
# 4. partitioned_sync - per-day streaming behaviour
# ---------------------------------------------------------------------------
class TestPartitionedSync:
"""Tests for the rewritten _partitioned_sync() that streams per-day from BQ."""
def test_creates_partition_files(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
"""Partitioned sync should create separate Parquet files per partition key."""
import pandas as pd
def _setup_partition_config(
self,
mock_config,
tmp_parquet_dir,
*,
granularity: str = "day",
max_history_days: int | None = 3,
incremental_window_days: int | None = 2,
partition_column_type: str = "TIMESTAMP",
):
"""Common setup: create table config + wire mock_config partition paths."""
table_config = _make_table_config(
sync_strategy="incremental",
incremental_column="created_at",
partition_by="created_at",
partition_granularity="month",
incremental_window_days=7,
sync_strategy="partitioned",
incremental_column="event_date",
partition_by="event_date",
partition_granularity=granularity,
max_history_days=max_history_days,
incremental_window_days=incremental_window_days,
partition_column_type=partition_column_type,
)
# For partitioned tables, parquet_path is a directory
partition_dir = tmp_parquet_dir / "orders"
partition_dir.mkdir(parents=True, exist_ok=True)
mock_config.get_parquet_path.return_value = partition_dir
# Configure partition paths
def _partition_path(tc, key):
return partition_dir / f"{key}.parquet"
mock_config.get_partition_path.side_effect = _partition_path
# Build arrow table with timestamps in two months
ts_jan = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")]
ts_feb = [pd.Timestamp("2026-02-20 14:00:00", tz="UTC")]
arrow_data = pa.table({
"id": [1, 2],
"name": ["Jan_Order", "Feb_Order"],
"created_at": pa.array(ts_jan + ts_feb, type=pa.timestamp("us", tz="UTC")),
return table_config, partition_dir
@staticmethod
def _make_day_table(row_id: int, day: date) -> pa.Table:
    """Return a single-row Arrow table (id, event_date) stamped at midnight of *day*."""
    midnight = datetime(day.year, day.month, day.day)
    columns = {
        "id": [row_id],
        "event_date": pa.array([midnight], type=pa.timestamp("us")),
    }
    return pa.table(columns)
mock_bq_client.read_table.return_value = arrow_data
def test_creates_daily_partition_files(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""First sync with max_history_days creates one Parquet file per day with data."""
table_config, partition_dir = self._setup_partition_config(
mock_config, tmp_parquet_dir, max_history_days=3, granularity="day",
)
today = date.today()
day0 = today - timedelta(days=3)
day1 = today - timedelta(days=2)
# day2, day3 (today-1, today) will have no data
# Build per-day Arrow data for the two days that have rows
day0_table = self._make_day_table(1, day0)
day1_table = self._make_day_table(2, day1)
# read_table_partitioned_streaming is called once per partition date.
# We need to return data for day0 and day1, empty iterators for the rest.
def _streaming_side_effect(*, table_id, partition_column, start, end, columns, column_type):
start_date = date.fromisoformat(start)
if start_date == day0:
return iter(day0_table.to_batches())
if start_date == day1:
return iter(day1_table.to_batches())
return iter([]) # empty for other days
mock_bq_client.read_table_partitioned_streaming.side_effect = _streaming_side_effect
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.sync_table(table_config, sync_state)
assert result["success"] is True
# Should have created two partition files
partition_files = list(partition_dir.glob("*.parquet"))
# Should have exactly 2 partition files (days with data)
partition_files = sorted(partition_dir.glob("*.parquet"))
assert len(partition_files) == 2
partition_names = sorted(f.stem for f in partition_files)
assert "2026_01" in partition_names
assert "2026_02" in partition_names
file_names = sorted(f.stem for f in partition_files)
assert day0.strftime("%Y_%m_%d") in file_names
assert day1.strftime("%Y_%m_%d") in file_names
# Verify content of each partition
t0 = pq.read_table(partition_dir / f"{day0.strftime('%Y_%m_%d')}.parquet")
assert t0.num_rows == 1
assert t0.column("id").to_pylist() == [1]
def test_incremental_sync_only_fetches_window(
    self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
    """After a previous sync, only the incremental window of days is fetched."""
    table_config, partition_dir = self._setup_partition_config(
        mock_config, tmp_parquet_dir,
        max_history_days=30,
        incremental_window_days=2,
        granularity="day",
    )

    # Record a previous sync. update_sync() is given no timestamp here, so the
    # stored last-sync time is whatever SyncState assigns (presumably "now" --
    # TODO confirm). The expected call count below is derived from the stored
    # value, so the assertion holds either way.
    sync_state.update_sync(
        table_id=table_config.id,
        table_name=table_config.name,
        strategy="partitioned",
        rows=100,
        file_size_bytes=5000,
    )

    # Return empty for all calls -- we just want to verify the call count
    mock_bq_client.read_table_partitioned_streaming.return_value = iter([])

    adapter = _create_adapter(mock_config, mock_bq_client)
    result = adapter.sync_table(table_config, sync_state)

    assert result["success"] is True

    # With incremental_window_days=2, it should go back 2 days from last_sync.
    # One BQ call is expected per day from (last_sync - 2 days) to today.
    last_sync_str = sync_state.get_last_sync(table_config.id)
    last_sync_dt = datetime.fromisoformat(last_sync_str)
    start_date = (last_sync_dt - timedelta(days=2)).date()
    today = date.today()
    expected_days = (today - start_date).days + 1  # inclusive

    actual_calls = mock_bq_client.read_table_partitioned_streaming.call_count
    assert actual_calls == expected_days, (
        f"Expected {expected_days} BQ calls (from {start_date} to {today}), got {actual_calls}"
    )
def test_merges_with_existing_partition(
    self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
    """Rows fetched for a day that already has a file are merged and deduplicated on PK."""
    table_config, partition_dir = self._setup_partition_config(
        mock_config, tmp_parquet_dir, max_history_days=3, granularity="day",
    )

    target_day = date.today() - timedelta(days=1)
    target_midnight = datetime(target_day.year, target_day.month, target_day.day)
    partition_path = partition_dir / f"{target_day.strftime('%Y_%m_%d')}.parquet"

    # Seed the partition file with a single row, id=1.
    seeded = pa.table({
        "id": [1],
        "event_date": pa.array([target_midnight], type=pa.timestamp("us")),
    })
    pq.write_table(seeded, partition_path, compression="snappy")

    # BQ returns id=1 again (an update) plus a brand-new id=2.
    incoming = pa.table({
        "id": [1, 2],
        "event_date": pa.array(
            [target_midnight, target_midnight],
            type=pa.timestamp("us"),
        ),
    })

    def _streaming_side_effect(*, table_id, partition_column, start, end, columns, column_type):
        # Only the target day yields rows; every other day is empty.
        if date.fromisoformat(start) == target_day:
            return iter(incoming.to_batches())
        return iter([])

    mock_bq_client.read_table_partitioned_streaming.side_effect = _streaming_side_effect

    adapter = _create_adapter(mock_config, mock_bq_client)
    result = adapter.sync_table(table_config, sync_state)

    assert result["success"] is True

    # The rewritten partition must hold exactly ids 1 and 2 (id=1 not duplicated).
    merged = pq.read_table(partition_path)
    assert merged.num_rows == 2
    assert sorted(merged.column("id").to_pylist()) == [1, 2]
def test_empty_partition_skipped(
    self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
    """A partition day with no data from BQ should not create a file."""
    table_config, partition_dir = self._setup_partition_config(
        mock_config, tmp_parquet_dir, max_history_days=2, granularity="day",
    )

    # Use a side_effect function so that every call gets its own fresh,
    # empty iterator.
    mock_bq_client.read_table_partitioned_streaming.side_effect = (
        lambda **kw: iter([])
    )

    adapter = _create_adapter(mock_config, mock_bq_client)
    result = adapter.sync_table(table_config, sync_state)

    assert result["success"] is True

    # No partition files should have been created
    partition_files = list(partition_dir.glob("*.parquet"))
    assert len(partition_files) == 0
# ---------------------------------------------------------------------------
@ -574,15 +739,13 @@ class TestSyncTableDispatch:
def test_dispatches_partitioned(
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
):
"""sync_strategy='incremental' with partition_by should call _partitioned_sync."""
import pandas as pd
"""sync_strategy='partitioned' should call _partitioned_sync."""
table_config = _make_table_config(
sync_strategy="incremental",
sync_strategy="partitioned",
incremental_column="created_at",
partition_by="created_at",
partition_granularity="month",
incremental_window_days=7,
partition_granularity="day",
max_history_days=2,
)
partition_dir = tmp_parquet_dir / "orders"
partition_dir.mkdir(parents=True, exist_ok=True)
@ -592,13 +755,9 @@ class TestSyncTableDispatch:
return partition_dir / f"{key}.parquet"
mock_config.get_partition_path.side_effect = _partition_path
ts = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")]
arrow_data = pa.table({
"id": [1],
"name": ["A"],
"created_at": pa.array(ts, type=pa.timestamp("us", tz="UTC")),
})
mock_bq_client.read_table.return_value = arrow_data
mock_bq_client.read_table_partitioned_streaming.side_effect = (
lambda **kw: iter([])
)
adapter = _create_adapter(mock_config, mock_bq_client)
@ -770,3 +929,175 @@ class TestCreateDataSourceFactory:
from connectors.bigquery.adapter import create_data_source, BigQueryDataSource
instance = create_data_source()
assert isinstance(instance, BigQueryDataSource)
# ---------------------------------------------------------------------------
# 14. _cleanup_old_partitions deletes files beyond retention window
# ---------------------------------------------------------------------------
class TestPartitionCleanup:
    """Retention behaviour of ``_cleanup_old_partitions`` for daily partition files."""

    def test_deletes_old_partitions(self, mock_config, mock_bq_client, tmp_parquet_dir):
        """Partition files older than max_history_days should be deleted."""
        partition_dir = tmp_parquet_dir / "orders"
        partition_dir.mkdir(parents=True, exist_ok=True)

        table_config = _make_table_config(
            sync_strategy="partitioned",
            partition_by="event_date",
            partition_granularity="day",
            max_history_days=5,
        )

        today = date.today()
        # Ages in days: 3 is expected to be kept, 6 and 10 are expected to go.
        kept_key, stale_key_a, stale_key_b = (
            (today - timedelta(days=age)).strftime("%Y_%m_%d") for age in (3, 6, 10)
        )
        for stem in (kept_key, stale_key_a, stale_key_b):
            pq.write_table(pa.table({"id": [1]}), partition_dir / f"{stem}.parquet")

        adapter = _create_adapter(mock_config, mock_bq_client)
        removed = adapter._cleanup_old_partitions(table_config, partition_dir, "day")
        assert removed == 2

        # Only the recent file should remain
        surviving = [path.stem for path in partition_dir.glob("*.parquet")]
        assert kept_key in surviving
        assert stale_key_a not in surviving
        assert stale_key_b not in surviving

    def test_no_cleanup_without_max_history_days(self, mock_config, mock_bq_client, tmp_parquet_dir):
        """Without max_history_days, no partition files should be deleted."""
        partition_dir = tmp_parquet_dir / "orders"
        partition_dir.mkdir(parents=True, exist_ok=True)

        table_config = _make_table_config(
            sync_strategy="partitioned",
            partition_by="event_date",
            partition_granularity="day",
            max_history_days=None,
        )

        # A single file dated 100 days ago
        stale_key = (date.today() - timedelta(days=100)).strftime("%Y_%m_%d")
        pq.write_table(pa.table({"id": [1]}), partition_dir / f"{stale_key}.parquet")

        adapter = _create_adapter(mock_config, mock_bq_client)
        removed = adapter._cleanup_old_partitions(table_config, partition_dir, "day")

        assert removed == 0
        assert len(list(partition_dir.glob("*.parquet"))) == 1
# ---------------------------------------------------------------------------
# 15. _generate_partition_dates produces correct date ranges
# ---------------------------------------------------------------------------
class TestGeneratePartitionDates:
    """Date sequences produced by ``_generate_partition_dates``."""

    def test_daily_generation(self, mock_config, mock_bq_client):
        """Daily granularity should generate one date per day, inclusive."""
        adapter = _create_adapter(mock_config, mock_bq_client)

        generated = adapter._generate_partition_dates(
            date(2026, 3, 1), date(2026, 3, 5), "day"
        )

        assert generated == [date(2026, 3, day) for day in range(1, 6)]

    def test_monthly_generation(self, mock_config, mock_bq_client):
        """Monthly granularity should generate first-of-month dates, aligned."""
        adapter = _create_adapter(mock_config, mock_bq_client)

        # The range starts mid-month; the first entry is expected on the 1st.
        generated = adapter._generate_partition_dates(
            date(2026, 1, 15), date(2026, 4, 10), "month"
        )

        assert generated == [date(2026, month, 1) for month in (1, 2, 3, 4)]

    def test_monthly_generation_across_year_boundary(self, mock_config, mock_bq_client):
        """Monthly generation should cross year boundaries correctly."""
        adapter = _create_adapter(mock_config, mock_bq_client)

        generated = adapter._generate_partition_dates(
            date(2025, 11, 1), date(2026, 2, 15), "month"
        )

        expected = [
            date(year, month, 1)
            for year, month in ((2025, 11), (2025, 12), (2026, 1), (2026, 2))
        ]
        assert generated == expected

    def test_daily_single_day(self, mock_config, mock_bq_client):
        """When start == end, should return a single date."""
        adapter = _create_adapter(mock_config, mock_bq_client)
        only_day = date(2026, 6, 15)

        assert adapter._generate_partition_dates(only_day, only_day, "day") == [only_day]

    def test_empty_range(self, mock_config, mock_bq_client):
        """When start > end, should return an empty list."""
        adapter = _create_adapter(mock_config, mock_bq_client)
        later, earlier = date(2026, 3, 10), date(2026, 3, 5)

        assert adapter._generate_partition_dates(later, earlier, "day") == []
# ---------------------------------------------------------------------------
# 16. _parse_partition_date converts partition keys back to dates
# ---------------------------------------------------------------------------
class TestParsePartitionDate:
    """Conversion of partition-file keys back into dates by ``_parse_partition_date``."""

    def test_parse_day_format(self, mock_config, mock_bq_client):
        """'2026_01_15' with day granularity should parse to date(2026, 1, 15)."""
        adapter = _create_adapter(mock_config, mock_bq_client)
        parsed = adapter._parse_partition_date("2026_01_15", "day")
        assert parsed == date(2026, 1, 15)

    def test_parse_month_format(self, mock_config, mock_bq_client):
        """'2026_01' with month granularity should parse to date(2026, 1, 1)."""
        adapter = _create_adapter(mock_config, mock_bq_client)
        parsed = adapter._parse_partition_date("2026_01", "month")
        assert parsed == date(2026, 1, 1)

    def test_parse_year_format(self, mock_config, mock_bq_client):
        """'2026' with year granularity should parse to date(2026, 1, 1)."""
        adapter = _create_adapter(mock_config, mock_bq_client)
        parsed = adapter._parse_partition_date("2026", "year")
        assert parsed == date(2026, 1, 1)

    def test_parse_invalid_returns_none(self, mock_config, mock_bq_client):
        """Invalid partition key should return None."""
        adapter = _create_adapter(mock_config, mock_bq_client)
        unparseable = (
            ("invalid", "day"),
            ("not_a_date", "month"),
            ("abc", "year"),
        )
        for key, granularity in unparseable:
            assert adapter._parse_partition_date(key, granularity) is None

    def test_parse_mismatched_granularity_returns_none(self, mock_config, mock_bq_client):
        """Day key with month granularity should return None (format mismatch)."""
        adapter = _create_adapter(mock_config, mock_bq_client)
        # A day-shaped key handed to the "month" parser must be rejected.
        day_shaped_key = "2026_01_15"
        assert adapter._parse_partition_date(day_shaped_key, "month") is None

View file

@ -868,3 +868,151 @@ class TestCreateClient:
assert isinstance(result, BigQueryClient)
assert result.project_id == "factory-project"
# ---------------------------------------------------------------------------
# 14. read_table_partitioned_streaming yields RecordBatches
# ---------------------------------------------------------------------------
class TestReadTablePartitionedStreaming:
    """Streaming partition reads: batch iteration and query-parameter typing."""

    @staticmethod
    def _stub_query_result(mock_bq_client, batches):
        """Wire ``mock_bq_client.query(...).result().to_arrow_iterable()`` to yield *batches*."""
        row_iterator = MagicMock()
        row_iterator.to_arrow_iterable.return_value = iter(batches)
        query_job = MagicMock()
        query_job.result.return_value = row_iterator
        mock_bq_client.query.return_value = query_job

    def test_streaming_yields_batches(self, client, mock_bq_client):
        """read_table_partitioned_streaming yields RecordBatches, not a Table."""
        self._stub_query_result(
            mock_bq_client,
            [pa.record_batch({"a": [1, 2]}), pa.record_batch({"a": [3, 4]})],
        )

        with patch("connectors.bigquery.client.bigquery") as bq_module:
            bq_module.QueryJobConfig.return_value = MagicMock()
            bq_module.ScalarQueryParameter.return_value = MagicMock()
            client.client = mock_bq_client
            client.bqstorage_client = None  # disable Storage API for simplicity

            stream = client.read_table_partitioned_streaming(
                table_id="proj.dataset.events",
                partition_column="event_date",
                start="2025-01-01",
            )
            received = list(stream)

        assert len(received) == 2
        assert isinstance(received[0], pa.RecordBatch)
        assert isinstance(received[1], pa.RecordBatch)

    def test_streaming_with_date_column_type(self, client, mock_bq_client):
        """With column_type='DATE', ScalarQueryParameter uses 'DATE' type."""
        self._stub_query_result(mock_bq_client, [pa.record_batch({"a": [1]})])

        with patch("connectors.bigquery.client.bigquery") as bq_module:
            bq_module.QueryJobConfig.return_value = MagicMock()
            bq_module.ScalarQueryParameter.return_value = MagicMock()
            client.client = mock_bq_client
            client.bqstorage_client = None

            # Drain the generator so the query is actually issued.
            list(client.read_table_partitioned_streaming(
                table_id="proj.dataset.events",
                partition_column="event_date",
                start="2025-01-01",
                column_type="DATE",
            ))

        # Exactly one parameter, typed as DATE
        bq_module.ScalarQueryParameter.assert_called_once_with(
            "start_value", "DATE", "2025-01-01"
        )

    def test_streaming_start_and_end(self, client, mock_bq_client):
        """With start and end, both params are created with correct column_type."""
        self._stub_query_result(mock_bq_client, [pa.record_batch({"a": [1]})])

        with patch("connectors.bigquery.client.bigquery") as bq_module:
            bq_module.QueryJobConfig.return_value = MagicMock()
            bq_module.ScalarQueryParameter.return_value = MagicMock()
            client.client = mock_bq_client
            client.bqstorage_client = None

            list(client.read_table_partitioned_streaming(
                table_id="proj.dataset.events",
                partition_column="event_date",
                start="2025-01-01",
                end="2025-06-01",
                column_type="DATE",
            ))

            # First positional argument of query() is the SQL text.
            issued_sql = mock_bq_client.query.call_args[0][0]
            assert "`event_date` >= @start_value" in issued_sql
            assert "`event_date` < @end_value" in issued_sql

            # Both parameters created with "DATE" type
            assert bq_module.ScalarQueryParameter.call_count == 2
            start_call, end_call = bq_module.ScalarQueryParameter.call_args_list
            assert start_call.args == ("start_value", "DATE", "2025-01-01")
            assert end_call.args == ("end_value", "DATE", "2025-06-01")
# ---------------------------------------------------------------------------
# 15. read_table_partitioned column_type parameter
# ---------------------------------------------------------------------------
class TestReadTablePartitionedColumnType:
    """Query-parameter typing of the non-streaming ``read_table_partitioned``."""

    @staticmethod
    def _stub_query_table(mock_bq_client):
        """Make ``mock_bq_client.query(...).to_arrow()`` return a one-row Arrow table."""
        query_job = MagicMock()
        query_job.to_arrow.return_value = pa.table({"a": [1]})
        mock_bq_client.query.return_value = query_job

    def test_date_column_type(self, client, mock_bq_client):
        """read_table_partitioned with column_type='DATE' creates DATE params."""
        self._stub_query_table(mock_bq_client)

        with patch("connectors.bigquery.client.bigquery") as bq_module:
            bq_module.QueryJobConfig.return_value = MagicMock()
            bq_module.ScalarQueryParameter.return_value = MagicMock()
            client.client = mock_bq_client

            client.read_table_partitioned(
                table_id="proj.dataset.events",
                partition_column="event_date",
                start="2025-01-01",
                end="2025-06-01",
                column_type="DATE",
            )

            # Both parameters should use "DATE" type
            assert bq_module.ScalarQueryParameter.call_count == 2
            start_call, end_call = bq_module.ScalarQueryParameter.call_args_list
            assert start_call.args == ("start_value", "DATE", "2025-01-01")
            assert end_call.args == ("end_value", "DATE", "2025-06-01")

    def test_default_column_type_is_timestamp(self, client, mock_bq_client):
        """Default column_type is TIMESTAMP when not specified."""
        self._stub_query_table(mock_bq_client)

        with patch("connectors.bigquery.client.bigquery") as bq_module:
            bq_module.QueryJobConfig.return_value = MagicMock()
            bq_module.ScalarQueryParameter.return_value = MagicMock()
            client.client = mock_bq_client

            client.read_table_partitioned(
                table_id="proj.dataset.events",
                partition_column="created_at",
                start="2025-01-01T00:00:00Z",
            )

            # Should default to "TIMESTAMP"
            bq_module.ScalarQueryParameter.assert_called_once_with(
                "start_value", "TIMESTAMP", "2025-01-01T00:00:00Z"
            )

View file

@ -0,0 +1,69 @@
"""Tests for TableConfig.query_mode field validation."""
import pytest
from src.config import TableConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_table(**overrides) -> TableConfig:
    """Build a TableConfig from baseline field values; *overrides* win over the baseline."""
    baseline = {
        "id": "test.dataset.table",
        "name": "test_table",
        "description": "Test",
        "primary_key": "id",
        "sync_strategy": "full_refresh",
    }
    return TableConfig(**{**baseline, **overrides})
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestQueryModeDefault:
    """Value of ``query_mode`` when the caller does not set it."""

    def test_default_is_local(self):
        """A table built without an explicit query_mode reports "local"."""
        assert _make_table().query_mode == "local"
class TestQueryModeValidValues:
    """Each supported query_mode is accepted and stored unchanged."""

    @pytest.mark.parametrize("mode", ["local", "remote", "hybrid"])
    def test_valid_query_mode(self, mode):
        """Constructing with a supported mode keeps that exact value."""
        configured = _make_table(query_mode=mode)
        assert configured.query_mode == mode
class TestQueryModeInvalid:
    """Unsupported query_mode values are rejected at construction time."""

    @pytest.mark.parametrize("bad_mode", ["invalid", "Local", "REMOTE", "", "sql"])
    def test_invalid_query_mode_raises(self, bad_mode):
        """Unknown, wrongly-cased, and empty modes all raise ValueError."""
        expected_failure = pytest.raises(ValueError, match="Invalid query_mode")
        with expected_failure:
            _make_table(query_mode=bad_mode)
class TestQueryModeFromKwarg:
    """query_mode handling when TableConfig is constructed directly with keyword arguments."""

    # Fields shared by both tests; query_mode is added (or not) per test.
    _ORDER_FIELDS = {
        "id": "proj.dataset.orders",
        "name": "orders",
        "description": "Order data",
        "primary_key": "order_id",
        "sync_strategy": "full_refresh",
    }

    def test_kwarg_sets_query_mode(self):
        """Simulate what _parse_data_description does: pass query_mode as kwarg."""
        orders = TableConfig(**self._ORDER_FIELDS, query_mode="remote")
        assert orders.query_mode == "remote"

    def test_kwarg_default_when_omitted(self):
        """With the query_mode kwarg left out entirely, the table reports "local"."""
        orders = TableConfig(**self._ORDER_FIELDS)
        assert orders.query_mode == "local"

View file

@ -0,0 +1,228 @@
"""Tests for remote table skipping in DataSyncManager.sync_all().
Tables with query_mode == "remote" should be skipped during sync (no local
Parquet file is needed -- queries go directly to BigQuery). Tables with
query_mode "local" or "hybrid" must still be synced normally.
"""
from unittest.mock import MagicMock, patch
import pytest
from src.config import TableConfig
from src.data_sync import DataSyncManager
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_table_config(
    table_id: str,
    name: str,
    query_mode: str = "local",
) -> TableConfig:
    """Build a full_refresh TableConfig with the given id, name and query_mode."""
    fields = {
        "id": table_id,
        "name": name,
        "description": f"Test table {name}",
        "primary_key": "id",
        "sync_strategy": "full_refresh",
        "query_mode": query_mode,
    }
    return TableConfig(**fields)
def _successful_sync_result() -> dict:
"""Return a fake successful sync result dict."""
return {
"success": True,
"rows": 100,
"file_size_mb": 0.5,
}
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def table_local():
    """TableConfig with query_mode "local"."""
    return _make_table_config(table_id="t.local", name="local_table", query_mode="local")
@pytest.fixture
def table_remote():
    """TableConfig with query_mode "remote"."""
    return _make_table_config(table_id="t.remote", name="remote_table", query_mode="remote")
@pytest.fixture
def table_hybrid():
    """TableConfig with query_mode "hybrid"."""
    return _make_table_config(table_id="t.hybrid", name="hybrid_table", query_mode="hybrid")
@pytest.fixture
def all_tables(table_local, table_remote, table_hybrid):
    """The three table fixtures as a list, in local / remote / hybrid order."""
    tables = [table_local]
    tables.append(table_remote)
    tables.append(table_hybrid)
    return tables
@pytest.fixture
def mock_config(all_tables):
    """Return a mock Config whose .tables list contains all three query modes."""

    def _get_table_config(tid):
        # First table whose id matches, or None when nothing matches.
        for candidate in all_tables:
            if candidate.id == tid:
                return candidate
        return None

    config = MagicMock()
    config.tables = all_tables
    config.get_metadata_path.return_value = MagicMock()  # Path-like
    config.get_table_config.side_effect = _get_table_config
    return config
@pytest.fixture
def mock_data_source():
    """Return a mock DataSource whose sync_table always reports success."""
    source = MagicMock()
    source.sync_table.return_value = _successful_sync_result()
    return source
@pytest.fixture
def sync_manager(mock_config, mock_data_source):
    """Yield a DataSyncManager built while config, data source and SyncState are patched.

    The patches stay active for the duration of the test because the manager
    is yielded from inside the ``with`` block.
    """
    config_patch = patch("src.data_sync.get_config", return_value=mock_config)
    source_patch = patch("src.data_sync.create_data_source", return_value=mock_data_source)
    state_patch = patch("src.data_sync.SyncState")

    with config_patch, source_patch, state_patch:
        manager = DataSyncManager()
        # Schema generation is not under test; replace it with a mock.
        manager._generate_schema_yaml = MagicMock()
        yield manager
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestSyncAllRemoteSkipping:
    """Verify that sync_all filters out remote tables."""

    @staticmethod
    def _synced_ids(mock_data_source):
        """Ids of the table configs passed (positionally) to each sync_table call."""
        return [
            recorded.args[0].id
            for recorded in mock_data_source.sync_table.call_args_list
        ]

    def test_remote_table_not_synced(self, sync_manager, mock_data_source, table_remote):
        """Remote table must NOT be passed to data_source.sync_table."""
        sync_manager.sync_all()
        assert table_remote.id not in self._synced_ids(mock_data_source)

    def test_local_table_is_synced(self, sync_manager, mock_data_source, table_local):
        """Local table must be synced normally."""
        sync_manager.sync_all()
        assert table_local.id in self._synced_ids(mock_data_source)

    def test_hybrid_table_is_synced(self, sync_manager, mock_data_source, table_hybrid):
        """Hybrid table must be synced (needs local parquet for profiling)."""
        sync_manager.sync_all()
        assert table_hybrid.id in self._synced_ids(mock_data_source)

    def test_sync_call_count(self, sync_manager, mock_data_source):
        """Only local + hybrid tables should result in sync_table calls."""
        sync_manager.sync_all()
        # Three configured tables minus the one remote table leaves two calls.
        assert mock_data_source.sync_table.call_count == 2

    def test_results_exclude_remote(self, sync_manager, table_remote):
        """The results dict must not contain an entry for the remote table."""
        outcome = sync_manager.sync_all()
        assert table_remote.id not in outcome

    def test_results_include_local_and_hybrid(
        self, sync_manager, table_local, table_hybrid
    ):
        """Results dict must contain entries for local and hybrid tables."""
        outcome = sync_manager.sync_all()
        assert table_local.id in outcome
        assert table_hybrid.id in outcome
class TestSyncAllAllRemote:
    """Edge case: every table is remote."""

    def test_no_sync_calls_when_all_remote(self, mock_config, mock_data_source):
        """With only remote tables configured, nothing is synced and results are empty."""
        mock_config.tables = [
            _make_table_config("t.r1", "remote1", query_mode="remote"),
            _make_table_config("t.r2", "remote2", query_mode="remote"),
        ]

        config_patch = patch("src.data_sync.get_config", return_value=mock_config)
        source_patch = patch("src.data_sync.create_data_source", return_value=mock_data_source)
        state_patch = patch("src.data_sync.SyncState")

        with config_patch, source_patch, state_patch:
            manager = DataSyncManager()
            manager._generate_schema_yaml = MagicMock()
            outcome = manager.sync_all()

        assert mock_data_source.sync_table.call_count == 0
        assert outcome == {}
class TestSyncAllNoRemote:
    """Edge case: no remote tables at all -- everything syncs."""

    def test_all_tables_synced(self, mock_config, mock_data_source):
        """With only local tables configured, each one is synced and reported."""
        mock_config.tables = [
            _make_table_config("t.l1", "local1", query_mode="local"),
            _make_table_config("t.l2", "local2", query_mode="local"),
        ]

        config_patch = patch("src.data_sync.get_config", return_value=mock_config)
        source_patch = patch("src.data_sync.create_data_source", return_value=mock_data_source)
        state_patch = patch("src.data_sync.SyncState")

        with config_patch, source_patch, state_patch:
            manager = DataSyncManager()
            manager._generate_schema_yaml = MagicMock()
            outcome = manager.sync_all()

        assert mock_data_source.sync_table.call_count == 2
        assert "t.l1" in outcome
        assert "t.l2" in outcome
class TestSyncAllWithTableFilter:
    """When sync_all receives an explicit table list, remote filtering still applies."""

    def test_explicit_remote_table_still_skipped(
        self, sync_manager, mock_data_source, table_remote
    ):
        """Even if explicitly listed, a remote table should be skipped."""
        requested = [table_remote.id]
        sync_manager.sync_all(tables=requested)
        assert mock_data_source.sync_table.call_count == 0

    def test_explicit_local_table_synced(
        self, sync_manager, mock_data_source, table_local
    ):
        """An explicitly listed local table should be synced."""
        requested = [table_local.id]
        sync_manager.sync_all(tables=requested)

        assert mock_data_source.sync_table.call_count == 1
        first_call = mock_data_source.sync_table.call_args_list[0]
        assert first_call.args[0].id == table_local.id

View file

@ -0,0 +1,488 @@
"""Tests for DuckDB Manager - query_mode classification and BQ registration.
Tests cover:
- _get_bq_project_from_table_id: extracting BQ project from table IDs
- get_remote_tables: filtering tables by query_mode
- register_bq_table: registering BQ query results in DuckDB
- init_duckdb: table classification by query_mode, local view creation,
remote table logging
"""
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
import duckdb
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from scripts.duckdb_manager import (
_get_bq_project_from_table_id,
get_remote_tables,
init_duckdb,
register_bq_table,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def tmp_project(tmp_path):
"""Create a minimal project layout with docs/data_description.md and a parquet file.

The description declares a single table ("company") and a folder mapping
from "in.c-crm" to "crm_data"; a three-row company.parquet is written under
server/parquet/crm_data to match.

Returns (project_root, db_path, data_dir) tuple, where project_root is
tmp_path, db_path is the string path user/duckdb/test.duckdb (the file itself
is not created here), and data_dir is the string path of the "server"
directory -- the parent of "parquet", not the parquet folder itself.
"""
docs_dir = tmp_path / "docs"
docs_dir.mkdir()
# The trailing backslash after the opening quotes keeps the written file from
# starting with an empty line.
data_description = """\
# Data Description
```yaml
folder_mapping:
in.c-crm: crm_data
tables:
- id: "in.c-crm.company"
name: "company"
description: "Company master data"
primary_key: "id"
sync_strategy: "full_refresh"
```
"""
(docs_dir / "data_description.md").write_text(data_description)
# Create parquet directory and a minimal parquet file
data_dir = tmp_path / "server" / "parquet" / "crm_data"
data_dir.mkdir(parents=True)
table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
pq.write_table(table, data_dir / "company.parquet")
# Only the directory for the DuckDB file is created; the database file is not.
db_dir = tmp_path / "user" / "duckdb"
db_dir.mkdir(parents=True)
db_path = str(db_dir / "test.duckdb")
# NOTE: the local name data_dir above is the crm_data folder; the value
# returned in the data_dir position is the "server" directory.
return tmp_path, db_path, str(tmp_path / "server")
@pytest.fixture
def tmp_project_mixed(tmp_path):
"""Project with local, remote, and hybrid tables.

The description declares three tables: "company" (no query_mode given),
"revenue" (query_mode "remote", no sync_strategy) and "campaigns"
(query_mode "hybrid"). Parquet files are written for company and campaigns
only; nothing is written for revenue.

Returns (project_root, db_path, data_dir) where project_root is tmp_path,
db_path is the string path user/duckdb/test.duckdb, and data_dir is the
string path of the "server" directory.
"""
docs_dir = tmp_path / "docs"
docs_dir.mkdir()
data_description = """\
# Data Description
```yaml
folder_mapping:
in.c-crm: crm_data
tables:
- id: "in.c-crm.company"
name: "company"
description: "Local table"
primary_key: "id"
sync_strategy: "full_refresh"
- id: "prj-grp-dataview-prod-1ff9.finance.revenue"
name: "revenue"
description: "Remote BQ table"
primary_key: "id"
query_mode: "remote"
- id: "prj-grp-dataview-prod-1ff9.marketing.campaigns"
name: "campaigns"
description: "Hybrid table"
primary_key: "id"
sync_strategy: "full_refresh"
query_mode: "hybrid"
```
"""
(docs_dir / "data_description.md").write_text(data_description)
# Create parquet files for local and hybrid tables
crm_dir = tmp_path / "server" / "parquet" / "crm_data"
crm_dir.mkdir(parents=True)
table = pa.table({"id": [1, 2], "name": ["a", "b"]})
pq.write_table(table, crm_dir / "company.parquet")
# The campaigns id has no folder_mapping entry; its parquet folder is named
# after the first two dot-separated parts of the table id.
# NOTE(review): presumably this mirrors how the manager derives the folder
# for unmapped ids -- confirm against scripts.duckdb_manager.
marketing_dir = tmp_path / "server" / "parquet" / "prj-grp-dataview-prod-1ff9.marketing"
marketing_dir.mkdir(parents=True)
campaigns_table = pa.table({"id": [10], "campaign": ["test"]})
pq.write_table(campaigns_table, marketing_dir / "campaigns.parquet")
db_dir = tmp_path / "user" / "duckdb"
db_dir.mkdir(parents=True)
db_path = str(db_dir / "test.duckdb")
return tmp_path, db_path, str(tmp_path / "server")
@pytest.fixture
def tmp_project_remote_only(tmp_path):
"""Project with only remote tables (no local parquet needed).

The description declares two tables, "revenue" and "costs", both with
query_mode "remote" and no folder_mapping section. No parquet files and no
"server" directory are created.

Returns (project_root, db_path, data_dir) where project_root is tmp_path,
db_path is the string path user/duckdb/test.duckdb, and data_dir is the
string path tmp_path/"server", which does not exist on disk.
"""
docs_dir = tmp_path / "docs"
docs_dir.mkdir()
data_description = """\
# Data Description
```yaml
tables:
- id: "prj-grp-dataview-prod-1ff9.finance.revenue"
name: "revenue"
description: "Remote BQ table"
primary_key: "id"
query_mode: "remote"
- id: "prj-grp-dataview-prod-1ff9.finance.costs"
name: "costs"
description: "Remote BQ table"
primary_key: "id"
query_mode: "remote"
```
"""
(docs_dir / "data_description.md").write_text(data_description)
# Only the DuckDB directory is prepared; there is deliberately no parquet tree.
db_dir = tmp_path / "user" / "duckdb"
db_dir.mkdir(parents=True)
db_path = str(db_dir / "test.duckdb")
return tmp_path, db_path, str(tmp_path / "server")
# ---------------------------------------------------------------------------
# Tests: _get_bq_project_from_table_id
# ---------------------------------------------------------------------------
class TestGetBqProjectFromTableId:
    """Test extracting BigQuery project ID from fully-qualified table IDs."""

    def test_valid_bq_table_id(self):
        """A three-part id with a hyphenated first part yields that first part."""
        table_id = "prj-grp-dataview-prod-1ff9.finance.table"
        assert _get_bq_project_from_table_id(table_id) == "prj-grp-dataview-prod-1ff9"

    def test_valid_bq_table_id_different_project(self):
        """The project is taken from the id, not hard-coded."""
        table_id = "my-gcp-project.dataset_name.table_name"
        assert _get_bq_project_from_table_id(table_id) == "my-gcp-project"

    def test_keboola_format_returns_none(self):
        """A Keboola-style id ("in.c-crm.table") is not treated as a BQ id."""
        assert _get_bq_project_from_table_id("in.c-crm.table") is None

    def test_two_part_id_returns_none(self):
        """Two dot-separated parts give None."""
        assert _get_bq_project_from_table_id("dataset.table") is None

    def test_single_part_returns_none(self):
        """A bare table name gives None."""
        assert _get_bq_project_from_table_id("table_only") is None

    def test_four_parts_returns_none(self):
        """Four dot-separated parts give None even with a hyphenated first part."""
        assert _get_bq_project_from_table_id("a-b.c.d.e") is None

    def test_empty_string_returns_none(self):
        """The empty string gives None."""
        assert _get_bq_project_from_table_id("") is None

    def test_three_parts_no_hyphen_returns_none(self):
        """Three parts without a hyphen in the first part give None."""
        assert _get_bq_project_from_table_id("project.dataset.table") is None

    def test_hyphen_in_first_part_is_key(self):
        """A minimal hyphenated first part is enough to be returned as the project."""
        assert _get_bq_project_from_table_id("a-b.dataset.table") == "a-b"
# ---------------------------------------------------------------------------
# Tests: get_remote_tables
# ---------------------------------------------------------------------------
class TestGetRemoteTables:
    """Test filtering table configs by query_mode."""

    def test_returns_remote_tables(self):
        """Remote and hybrid entries are returned; local entries are not."""
        selected = get_remote_tables([
            {"name": "local", "query_mode": "local"},
            {"name": "remote1", "query_mode": "remote"},
            {"name": "hybrid1", "query_mode": "hybrid"},
        ])

        selected_names = [entry["name"] for entry in selected]
        assert "remote1" in selected_names
        assert "hybrid1" in selected_names
        assert "local" not in selected_names

    def test_returns_empty_when_all_local(self):
        """Explicitly local and mode-less entries together yield an empty list."""
        explicit_local = {"name": "t1", "query_mode": "local"}
        implicit_local = {"name": "t2"}  # no query_mode key at all
        assert get_remote_tables([explicit_local, implicit_local]) == []

    def test_missing_query_mode_treated_as_local(self):
        """An entry without a query_mode key is not returned."""
        assert get_remote_tables([{"name": "t1"}]) == []
# ---------------------------------------------------------------------------
# Tests: register_bq_table
# ---------------------------------------------------------------------------
class TestRegisterBqTable:
    """Test registering BQ query results as DuckDB views.

    Each test opens its in-memory DuckDB connection as a context manager so
    the connection is closed even when an assertion fails part-way through.
    """

    @staticmethod
    def _make_factory(arrow_table, side_effect=None):
        """Create a mock BQ client factory returning a client that yields arrow_table.

        Args:
            arrow_table: Value returned by the mocked job's ``to_arrow()``.
            side_effect: Optional ``side_effect`` for ``to_arrow`` (e.g. a list
                whose items are raised/returned on successive calls); when
                given it is used instead of ``arrow_table``.

        Returns:
            A MagicMock factory. The client and job it hands out are exposed
            as ``factory._mock_client`` and ``factory._mock_job`` so tests can
            assert on them.
        """
        mock_job = MagicMock()
        if side_effect:
            mock_job.to_arrow.side_effect = side_effect
        else:
            mock_job.to_arrow.return_value = arrow_table

        mock_client = MagicMock()
        mock_client.query.return_value = mock_job

        factory = MagicMock(return_value=mock_client)
        factory._mock_client = mock_client
        factory._mock_job = mock_job
        return factory

    def test_registers_arrow_table_in_duckdb(self):
        """Result from BQ should be queryable in DuckDB after registration."""
        arrow_table = pa.table({"id": [1, 2], "val": [10.0, 20.0]})
        factory = self._make_factory(arrow_table)

        with duckdb.connect() as conn:
            rows = register_bq_table(
                conn=conn,
                table_id="proj.dataset.table",
                view_name="test_view",
                sql="SELECT * FROM table",
                bq_project="test-project",
                _bq_client_factory=factory,
            )

            assert rows == 2
            result = conn.execute("SELECT SUM(val) FROM test_view").fetchone()[0]
            assert result == 30.0

    def test_raises_without_bq_project(self):
        """With no bq_project argument and an empty environment, ValueError is raised."""
        with duckdb.connect() as conn:
            # clear=True empties os.environ for the duration of the block.
            with patch.dict(os.environ, {}, clear=True):
                with pytest.raises(ValueError, match="BigQuery project not set"):
                    register_bq_table(
                        conn=conn,
                        table_id="proj.ds.tbl",
                        view_name="v",
                        sql="SELECT 1",
                    )

    def test_uses_env_var_when_no_project_arg(self):
        """BIGQUERY_PROJECT from the environment is handed to the client factory."""
        arrow_table = pa.table({"x": [1]})
        factory = self._make_factory(arrow_table)

        with duckdb.connect() as conn:
            with patch.dict(os.environ, {"BIGQUERY_PROJECT": "env-proj"}):
                register_bq_table(
                    conn=conn,
                    table_id="p.d.t",
                    view_name="v",
                    sql="SELECT 1",
                    _bq_client_factory=factory,
                )

            factory.assert_called_once_with("env-proj")

    def test_storage_api_fallback(self):
        """Falls back to REST when Storage API permission denied."""
        arrow_table = pa.table({"x": [1]})
        # First to_arrow() call raises, the second returns the table.
        factory = self._make_factory(
            arrow_table,
            side_effect=[
                Exception("PERMISSION_DENIED readsessions"),
                arrow_table,
            ],
        )

        with duckdb.connect() as conn:
            rows = register_bq_table(
                conn=conn,
                table_id="p.d.t",
                view_name="v",
                sql="SELECT 1",
                bq_project="proj",
                _bq_client_factory=factory,
            )

            assert rows == 1
            # assert_called_with checks the most recent call, i.e. the retry.
            factory._mock_job.to_arrow.assert_called_with(create_bqstorage_client=False)
# ---------------------------------------------------------------------------
# Tests: init_duckdb - table classification
# ---------------------------------------------------------------------------
class TestInitDuckdbClassification:
    """Test that tables are correctly classified by query_mode.

    The verification connection is opened as a context manager so the
    database file handle is released even when an assertion fails.
    """

    @staticmethod
    def _run_init(project):
        """Run init_duckdb against a fixture tuple with find_project_root patched.

        Args:
            project: ``(project_root, db_path, data_dir)`` as produced by the
                tmp_project* fixtures.

        Returns:
            ``(result, db_path)`` -- init_duckdb's return value and the path
            of the database it was pointed at.
        """
        project_root, db_path, data_dir = project
        with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
            result = init_duckdb(
                db_path=db_path, data_dir=data_dir, verbose=False
            )
        return result, db_path

    @staticmethod
    def _table_names(conn):
        """Names listed by ``SHOW TABLES`` on *conn* (first column of each row)."""
        return [row[0] for row in conn.execute("SHOW TABLES").fetchall()]

    def test_local_tables_create_parquet_views(self, tmp_project):
        """A local table is exposed under its name and returns the parquet rows."""
        result, db_path = self._run_init(tmp_project)
        assert result is True

        with duckdb.connect(db_path, read_only=True) as conn:
            assert "company" in self._table_names(conn)
            row_count = conn.execute("SELECT COUNT(*) FROM company").fetchone()[0]
            assert row_count == 3

    def test_remote_tables_not_created_as_local_views(self, tmp_project_mixed):
        """A remote table is absent from DuckDB while the local one is present."""
        result, db_path = self._run_init(tmp_project_mixed)
        assert result is True

        with duckdb.connect(db_path, read_only=True) as conn:
            tables = self._table_names(conn)
            assert "revenue" not in tables
            assert "company" in tables

    def test_hybrid_tables_create_local_views(self, tmp_project_mixed):
        """A hybrid table with a parquet file is exposed under its name."""
        result, db_path = self._run_init(tmp_project_mixed)
        assert result is True

        with duckdb.connect(db_path, read_only=True) as conn:
            assert "campaigns" in self._table_names(conn)

    def test_default_query_mode_is_local(self, tmp_project):
        """A table declared without query_mode is exposed like a local table."""
        result, db_path = self._run_init(tmp_project)
        assert result is True

        with duckdb.connect(db_path, read_only=True) as conn:
            assert "company" in self._table_names(conn)
# ---------------------------------------------------------------------------
# Tests: init_duckdb - remote table logging
# ---------------------------------------------------------------------------
class TestInitDuckdbRemoteLogging:
    """Check what ``init_duckdb`` prints and returns for a remote-only project."""

    def test_remote_tables_logged(self, tmp_project_remote_only, capsys):
        """Verbose output on stdout mentions ``revenue``, ``costs`` and ``[BQ]``."""
        root, database, data_root = tmp_project_remote_only
        target = "scripts.duckdb_manager.find_project_root"
        with patch(target, return_value=root):
            init_duckdb(db_path=database, data_dir=data_root, verbose=True)
        captured = capsys.readouterr()
        output = captured.out
        for fragment in ("revenue", "costs", "[BQ]"):
            assert fragment in output

    def test_remote_only_project_succeeds(self, tmp_project_remote_only):
        """``init_duckdb`` returns ``True`` for the remote-only fixture project."""
        root, database, data_root = tmp_project_remote_only
        target = "scripts.duckdb_manager.find_project_root"
        with patch(target, return_value=root):
            outcome = init_duckdb(db_path=database, data_dir=data_root, verbose=False)
        assert outcome is True
# ---------------------------------------------------------------------------
# Tests: init_duckdb - missing parquet handling
# ---------------------------------------------------------------------------
class TestInitDuckdbMissingParquet:
    """Test behavior when parquet files are missing."""

    def test_missing_parquet_skips_view(self, tmp_path):
        """A declared table with no parquet file is not listed by ``SHOW TABLES``.

        Builds a minimal project under ``tmp_path``: a ``docs/data_description.md``
        declaring a single table, an empty ``user/duckdb`` directory for the
        database file, and a ``server`` data directory that is never created.
        ``init_duckdb`` must still return ``True``.
        """
        docs_dir = tmp_path / "docs"
        docs_dir.mkdir()
        # The table fields must be nested under the ``- id`` list item;
        # otherwise YAML reads them as top-level keys and the table is never
        # declared with a name, so the assertion below passes vacuously.
        data_description = """\
# Data Description
```yaml
tables:
  - id: "in.c-crm.missing_table"
    name: "missing_table"
    description: "No parquet exists"
    primary_key: "id"
    sync_strategy: "full_refresh"
```
"""
        (docs_dir / "data_description.md").write_text(data_description)
        db_dir = tmp_path / "user" / "duckdb"
        db_dir.mkdir(parents=True)
        db_path = str(db_dir / "test.duckdb")
        with patch("scripts.duckdb_manager.find_project_root", return_value=tmp_path):
            result = init_duckdb(
                db_path=db_path, data_dir=str(tmp_path / "server"), verbose=False
            )
        assert result is True
        conn = duckdb.connect(db_path, read_only=True)
        # Close in ``finally`` so a failing assertion cannot leak the handle.
        try:
            tables = [row[0] for row in conn.execute("SHOW TABLES").fetchall()]
            assert "missing_table" not in tables
        finally:
            conn.close()

    def test_remote_table_no_local_parquet_needed(self, tmp_project_remote_only):
        """``init_duckdb`` returns ``True`` for the remote-only fixture project.

        NOTE(review): presumably the fixture provides no local parquet files
        for its tables -- confirm against ``tmp_project_remote_only``.
        """
        project_root, db_path, data_dir = tmp_project_remote_only
        with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
            result = init_duckdb(
                db_path=db_path, data_dir=data_dir, verbose=False,
            )
        assert result is True