Add per-partition streaming sync and hybrid query architecture
Partitioned sync: iterates day-by-day instead of loading full dataset. Each partition: query BQ -> stream to disk -> free RAM. Peak ~50 MB. New helpers: _sync_single_partition, _cleanup_old_partitions, _generate_partition_dates. Config: added partition_column_type (DATE/TIMESTAMP/DATETIME), query_mode (local/remote/hybrid). DuckDB manager: hybrid architecture support (local Parquet + remote BQ tables). Data sync: skips remote tables, filters by query_mode. Tests: 113 passing (adapter, client, config, data_sync, duckdb_manager).
This commit is contained in:
parent
d2e83ce9d0
commit
8bb46a9e0a
9 changed files with 1731 additions and 135 deletions
|
|
@ -9,7 +9,7 @@ using PyArrow (no CSV intermediate step).
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, date
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
|
@ -336,10 +336,11 @@ class BigQueryDataSource(DataSource):
|
|||
sync_state: SyncState,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Partition-based sync: read data by partition range and write partition files.
|
||||
"""
|
||||
import pandas as pd
|
||||
Per-partition streaming sync: process one partition (day) at a time.
|
||||
|
||||
Queries BQ for a single day, streams result to disk, then moves to next day.
|
||||
Memory usage is constant (~20-50 MB per partition) regardless of total data volume.
|
||||
"""
|
||||
partition_col = table_config.partition_by
|
||||
if not partition_col and table_config.incremental_column:
|
||||
partition_col = table_config.incremental_column
|
||||
|
|
@ -351,96 +352,231 @@ class BigQueryDataSource(DataSource):
|
|||
)
|
||||
return self._full_refresh(table_config)
|
||||
|
||||
granularity = table_config.partition_granularity or "month"
|
||||
granularity = table_config.partition_granularity or "day"
|
||||
column_type = table_config.partition_column_type
|
||||
logger.info(
|
||||
f"Partitioned sync: {table_config.name} "
|
||||
f"(by {partition_col}, {granularity})"
|
||||
f"(by {partition_col}, {granularity}, type={column_type})"
|
||||
)
|
||||
|
||||
partition_dir = self.config.get_parquet_path(table_config)
|
||||
date_columns = self.bq_client.get_date_columns(table_config.id)
|
||||
pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
|
||||
|
||||
# Determine time range
|
||||
# Determine date range
|
||||
last_sync = sync_state.get_last_sync(table_config.id)
|
||||
today = date.today()
|
||||
|
||||
if last_sync:
|
||||
last_sync_dt = datetime.fromisoformat(last_sync)
|
||||
window_days = table_config.incremental_window_days or 7
|
||||
start_dt = last_sync_dt - timedelta(days=window_days)
|
||||
logger.info(f" -> Reading from {start_dt.isoformat()} (window: {window_days} days)")
|
||||
start_date = (last_sync_dt - timedelta(days=window_days)).date()
|
||||
logger.info(f" -> Incremental sync from {start_date} (window: {window_days} days)")
|
||||
else:
|
||||
if table_config.max_history_days:
|
||||
start_dt = datetime.now() - timedelta(days=table_config.max_history_days)
|
||||
logger.info(f" -> First sync, limited to last {table_config.max_history_days} days")
|
||||
start_date = today - timedelta(days=table_config.max_history_days)
|
||||
logger.info(f" -> First sync, last {table_config.max_history_days} days from {start_date}")
|
||||
else:
|
||||
start_dt = None
|
||||
logger.info(" -> First sync, reading all data")
|
||||
start_date = today - timedelta(days=365)
|
||||
logger.info(" -> First sync, no max_history_days, defaulting to 365 days")
|
||||
|
||||
# Read data from BigQuery
|
||||
if start_dt:
|
||||
arrow_table = self.bq_client.read_table_partitioned(
|
||||
table_id=table_config.id,
|
||||
partition_column=partition_col,
|
||||
start=start_dt.isoformat(),
|
||||
columns=table_config.columns,
|
||||
)
|
||||
else:
|
||||
arrow_table = self.bq_client.read_table(
|
||||
table_config.id,
|
||||
columns=table_config.columns,
|
||||
row_filter=table_config.row_filter,
|
||||
# Generate list of partition dates
|
||||
partition_dates = self._generate_partition_dates(start_date, today, granularity)
|
||||
logger.info(f" -> Processing {len(partition_dates)} partitions")
|
||||
|
||||
total_rows = 0
|
||||
partitions_updated = 0
|
||||
|
||||
for partition_date in partition_dates:
|
||||
rows = self._sync_single_partition(
|
||||
table_config=table_config,
|
||||
partition_col=partition_col,
|
||||
partition_date=partition_date,
|
||||
partition_dir=partition_dir,
|
||||
date_columns=date_columns,
|
||||
pyarrow_schema=pyarrow_schema,
|
||||
granularity=granularity,
|
||||
column_type=column_type,
|
||||
)
|
||||
if rows > 0:
|
||||
partitions_updated += 1
|
||||
total_rows += rows
|
||||
|
||||
if arrow_table.num_rows == 0:
|
||||
logger.info(" -> No data to sync")
|
||||
return self._get_partition_totals(partition_dir)
|
||||
# Cleanup old partitions beyond retention window
|
||||
deleted = self._cleanup_old_partitions(table_config, partition_dir, granularity)
|
||||
if deleted > 0:
|
||||
logger.info(f" -> Cleaned up {deleted} old partition files")
|
||||
|
||||
logger.info(f" -> Processing {arrow_table.num_rows} rows into partitions")
|
||||
|
||||
# Convert to pandas for partitioning
|
||||
df = arrow_table.to_pandas()
|
||||
|
||||
# Ensure partition column is datetime
|
||||
if not pd.api.types.is_datetime64_any_dtype(df[partition_col]):
|
||||
df[partition_col] = pd.to_datetime(df[partition_col], format="ISO8601", utc=True)
|
||||
|
||||
# Create partition key
|
||||
if granularity == "month":
|
||||
df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m")
|
||||
elif granularity == "day":
|
||||
df["_partition_key"] = df[partition_col].dt.strftime("%Y_%m_%d")
|
||||
elif granularity == "year":
|
||||
df["_partition_key"] = df[partition_col].dt.strftime("%Y")
|
||||
|
||||
primary_key_cols = table_config.get_primary_key_columns()
|
||||
partitions_updated = set()
|
||||
|
||||
for partition_key, group_df in df.groupby("_partition_key"):
|
||||
group_df = group_df.drop(columns=["_partition_key"])
|
||||
partition_path = self.config.get_partition_path(table_config, partition_key)
|
||||
partitions_updated.add(partition_key)
|
||||
|
||||
# Merge with existing partition if it exists
|
||||
if partition_path.exists():
|
||||
existing_df = pd.read_parquet(partition_path)
|
||||
merged_df = pd.concat([existing_df, group_df], ignore_index=True)
|
||||
merged_df = merged_df.drop_duplicates(subset=primary_key_cols, keep="last")
|
||||
else:
|
||||
merged_df = group_df
|
||||
|
||||
# Write partition
|
||||
table = pa.Table.from_pandas(merged_df, preserve_index=False)
|
||||
if date_columns:
|
||||
table = convert_date_columns_to_date32(table, date_columns)
|
||||
if pyarrow_schema:
|
||||
table = apply_schema_to_table(table, pyarrow_schema)
|
||||
pq.write_table(table, partition_path, compression="snappy")
|
||||
|
||||
logger.info(f" -> Partitioned sync complete: {len(partitions_updated)} partitions updated")
|
||||
logger.info(
|
||||
f" -> Partitioned sync complete: {partitions_updated} partitions updated, "
|
||||
f"{total_rows} total rows processed"
|
||||
)
|
||||
|
||||
return self._get_partition_totals(partition_dir)
|
||||
|
||||
@staticmethod
|
||||
def _generate_partition_dates(
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
granularity: str,
|
||||
) -> List[date]:
|
||||
"""Generate list of partition start dates between start and end."""
|
||||
dates = []
|
||||
current = start_date
|
||||
|
||||
if granularity == "day":
|
||||
while current <= end_date:
|
||||
dates.append(current)
|
||||
current += timedelta(days=1)
|
||||
elif granularity == "month":
|
||||
# Align to first of month
|
||||
current = current.replace(day=1)
|
||||
while current <= end_date:
|
||||
dates.append(current)
|
||||
# Move to first of next month
|
||||
if current.month == 12:
|
||||
current = current.replace(year=current.year + 1, month=1)
|
||||
else:
|
||||
current = current.replace(month=current.month + 1)
|
||||
elif granularity == "year":
|
||||
current = current.replace(month=1, day=1)
|
||||
while current <= end_date:
|
||||
dates.append(current)
|
||||
current = current.replace(year=current.year + 1)
|
||||
|
||||
return dates
|
||||
|
||||
def _sync_single_partition(
|
||||
self,
|
||||
table_config: TableConfig,
|
||||
partition_col: str,
|
||||
partition_date: date,
|
||||
partition_dir: Path,
|
||||
date_columns: List[str],
|
||||
pyarrow_schema,
|
||||
granularity: str,
|
||||
column_type: str,
|
||||
) -> int:
|
||||
"""
|
||||
Query BQ for one partition period, stream to disk, merge with existing file.
|
||||
|
||||
Returns row count for this partition after merge.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
# Calculate partition range [start, end)
|
||||
start = partition_date
|
||||
if granularity == "day":
|
||||
end = start + timedelta(days=1)
|
||||
partition_key = start.strftime("%Y_%m_%d")
|
||||
elif granularity == "month":
|
||||
if start.month == 12:
|
||||
end = start.replace(year=start.year + 1, month=1)
|
||||
else:
|
||||
end = start.replace(month=start.month + 1)
|
||||
partition_key = start.strftime("%Y_%m")
|
||||
elif granularity == "year":
|
||||
end = start.replace(year=start.year + 1)
|
||||
partition_key = start.strftime("%Y")
|
||||
else:
|
||||
raise ValueError(f"Unknown granularity: {granularity}")
|
||||
|
||||
partition_path = self.config.get_partition_path(table_config, partition_key)
|
||||
|
||||
# Stream data from BQ for this single partition
|
||||
batches = []
|
||||
for batch in self.bq_client.read_table_partitioned_streaming(
|
||||
table_id=table_config.id,
|
||||
partition_column=partition_col,
|
||||
start=start.isoformat(),
|
||||
end=end.isoformat(),
|
||||
columns=table_config.columns,
|
||||
column_type=column_type,
|
||||
):
|
||||
batches.append(batch)
|
||||
|
||||
if not batches:
|
||||
return 0
|
||||
|
||||
new_data = pa.Table.from_batches(batches)
|
||||
if new_data.num_rows == 0:
|
||||
return 0
|
||||
|
||||
# Apply schema conversions
|
||||
if date_columns:
|
||||
new_data = convert_date_columns_to_date32(new_data, date_columns)
|
||||
if pyarrow_schema:
|
||||
new_data = apply_schema_to_table(new_data, pyarrow_schema)
|
||||
|
||||
# Merge with existing partition file if present
|
||||
primary_key_cols = table_config.get_primary_key_columns()
|
||||
|
||||
if partition_path.exists():
|
||||
existing = pq.read_table(partition_path)
|
||||
merged = self._merge_arrow_tables(existing, new_data, primary_key_cols)
|
||||
else:
|
||||
merged = new_data
|
||||
|
||||
# Write partition file
|
||||
pq.write_table(merged, partition_path, compression="snappy")
|
||||
row_count = merged.num_rows
|
||||
|
||||
logger.debug(
|
||||
f" Partition {partition_key}: {new_data.num_rows} new rows, "
|
||||
f"{row_count} total after merge"
|
||||
)
|
||||
|
||||
# Release memory
|
||||
del batches, new_data, merged
|
||||
|
||||
return row_count
|
||||
|
||||
def _cleanup_old_partitions(
|
||||
self,
|
||||
table_config: TableConfig,
|
||||
partition_dir: Path,
|
||||
granularity: str,
|
||||
) -> int:
|
||||
"""
|
||||
Delete partition files older than max_history_days.
|
||||
|
||||
Returns count of deleted files.
|
||||
"""
|
||||
if not table_config.max_history_days:
|
||||
return 0
|
||||
|
||||
if not partition_dir.exists():
|
||||
return 0
|
||||
|
||||
cutoff_date = date.today() - timedelta(days=table_config.max_history_days)
|
||||
deleted = 0
|
||||
|
||||
for part_path in partition_dir.glob("*.parquet"):
|
||||
try:
|
||||
partition_date = self._parse_partition_date(part_path.stem, granularity)
|
||||
if partition_date and partition_date < cutoff_date:
|
||||
part_path.unlink()
|
||||
deleted += 1
|
||||
logger.debug(f" Deleted old partition: {part_path.name}")
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f" Skipping unrecognized partition file: {part_path.name}")
|
||||
|
||||
return deleted
|
||||
|
||||
@staticmethod
|
||||
def _parse_partition_date(partition_key: str, granularity: str) -> Optional[date]:
|
||||
"""Parse a partition key back to a date."""
|
||||
try:
|
||||
if granularity == "day":
|
||||
return datetime.strptime(partition_key, "%Y_%m_%d").date()
|
||||
elif granularity == "month":
|
||||
return datetime.strptime(partition_key, "%Y_%m").date()
|
||||
elif granularity == "year":
|
||||
return datetime.strptime(partition_key, "%Y").date()
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _merge_arrow_tables(
|
||||
self,
|
||||
existing: pa.Table,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
DuckDB Manager - Initialize and manage DuckDB database with views from parquet files.
|
||||
DuckDB Manager - Initialize and manage DuckDB database with views from parquet files
|
||||
and runtime BigQuery query registration for remote/hybrid tables.
|
||||
|
||||
This script dynamically reads table configurations from docs/data_description.md
|
||||
and creates DuckDB views accordingly. No hardcoded table list needed!
|
||||
For BigQuery data sources, tables with query_mode="remote" or "hybrid" are queried
|
||||
at runtime via the Python BQ client and registered as in-memory Arrow tables in DuckDB.
|
||||
This avoids the DuckDB BigQuery extension limitation (cannot read BQ views).
|
||||
|
||||
Usage:
|
||||
python3 scripts/duckdb_manager.py --reinit # Initialize/reinitialize all views
|
||||
|
|
@ -11,13 +13,17 @@ Usage:
|
|||
"""
|
||||
|
||||
import duckdb
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import re
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_project_root() -> Path:
|
||||
|
|
@ -157,17 +163,131 @@ def get_parquet_path(table_config: Dict, folder_mapping: Dict[str, str], data_di
|
|||
return parquet_dir / f"{table_name}.parquet"
|
||||
|
||||
|
||||
def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbose=True):
|
||||
def _get_bq_project_from_table_id(table_id: str) -> Optional[str]:
|
||||
"""Extract BQ project ID from a fully-qualified table ID.
|
||||
|
||||
Args:
|
||||
table_id: e.g. "prj-grp-dataview-prod-1ff9.finance_unit_economics.unit_economics"
|
||||
|
||||
Returns:
|
||||
Project ID or None if format doesn't match BQ convention
|
||||
"""
|
||||
parts = table_id.split(".")
|
||||
if len(parts) == 3 and "-" in parts[0]:
|
||||
return parts[0]
|
||||
return None
|
||||
|
||||
|
||||
def _create_bq_client(project: str):
|
||||
"""Create a BigQuery client. Separated for testability.
|
||||
|
||||
Args:
|
||||
project: GCP project ID for billing
|
||||
|
||||
Returns:
|
||||
google.cloud.bigquery.Client instance
|
||||
"""
|
||||
from google.cloud import bigquery as bq_module
|
||||
|
||||
return bq_module.Client(project=project)
|
||||
|
||||
|
||||
def register_bq_table(
|
||||
conn: duckdb.DuckDBPyConnection,
|
||||
table_id: str,
|
||||
view_name: str,
|
||||
sql: str,
|
||||
bq_project: Optional[str] = None,
|
||||
_bq_client_factory=None,
|
||||
) -> int:
|
||||
"""
|
||||
Execute a BigQuery SQL query and register the result as a DuckDB view.
|
||||
|
||||
Uses the Python BigQuery client (Query API) which supports BQ views,
|
||||
unlike the DuckDB BigQuery extension (Storage Read API).
|
||||
The result is held in memory as a PyArrow table -- no disk I/O.
|
||||
|
||||
Args:
|
||||
conn: Open DuckDB connection
|
||||
table_id: BQ table ID for logging (e.g., "project.dataset.table")
|
||||
view_name: Name to register in DuckDB (e.g., "unit_economics_live")
|
||||
sql: Full BigQuery SQL query to execute
|
||||
bq_project: GCP project for billing. If None, uses BIGQUERY_PROJECT env var.
|
||||
_bq_client_factory: Override BQ client creation (for testing)
|
||||
|
||||
Returns:
|
||||
Number of rows in the result
|
||||
|
||||
Raises:
|
||||
ImportError: If google-cloud-bigquery is not installed
|
||||
ValueError: If bq_project is not set
|
||||
"""
|
||||
project = bq_project or os.environ.get("BIGQUERY_PROJECT")
|
||||
if not project:
|
||||
raise ValueError(
|
||||
"BigQuery project not set. "
|
||||
"Pass bq_project or set BIGQUERY_PROJECT env var."
|
||||
)
|
||||
|
||||
logger.info(f"Querying BQ: {table_id} -> {view_name}")
|
||||
logger.debug(f"SQL: {sql[:200]}...")
|
||||
|
||||
factory = _bq_client_factory or _create_bq_client
|
||||
client = factory(project)
|
||||
job = client.query(sql)
|
||||
|
||||
# Use Query API (not Storage Read API) to support BQ views
|
||||
try:
|
||||
arrow_table = job.to_arrow()
|
||||
except Exception as e:
|
||||
if "readsessions" in str(e) or "PERMISSION_DENIED" in str(e):
|
||||
logger.warning("BQ Storage API unavailable, falling back to REST")
|
||||
arrow_table = job.to_arrow(create_bqstorage_client=False)
|
||||
else:
|
||||
raise
|
||||
|
||||
conn.register(view_name, arrow_table)
|
||||
logger.info(
|
||||
f"Registered {view_name}: {arrow_table.num_rows} rows, "
|
||||
f"{arrow_table.num_columns} cols (in-memory)"
|
||||
)
|
||||
|
||||
return arrow_table.num_rows
|
||||
|
||||
|
||||
def get_remote_tables(table_configs: List[Dict]) -> List[Dict]:
|
||||
"""Return table configs with query_mode 'remote' or 'hybrid'.
|
||||
|
||||
Args:
|
||||
table_configs: List of table configuration dicts
|
||||
|
||||
Returns:
|
||||
List of remote/hybrid table configs
|
||||
"""
|
||||
return [
|
||||
tc for tc in table_configs
|
||||
if tc.get("query_mode") in ("remote", "hybrid")
|
||||
]
|
||||
|
||||
|
||||
def init_duckdb(
|
||||
db_path="user/duckdb/analytics.duckdb",
|
||||
data_dir="server",
|
||||
verbose=True,
|
||||
bq_project: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize DuckDB database with views from parquet files.
|
||||
|
||||
Dynamically reads table configurations from docs/data_description.md
|
||||
and creates views accordingly.
|
||||
Creates DuckDB views for local/hybrid tables (from Parquet).
|
||||
Remote tables are NOT pre-loaded -- they are registered at query time
|
||||
via register_bq_table().
|
||||
|
||||
Args:
|
||||
db_path: Path to DuckDB database file
|
||||
data_dir: Base data directory (e.g., "server" for analysts, "data" for server)
|
||||
verbose: Print progress messages
|
||||
bq_project: BigQuery execution project (for informational purposes only)
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
|
|
@ -176,29 +296,48 @@ def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbo
|
|||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
if verbose:
|
||||
print("🦆 Inicializuji DuckDB databázi...")
|
||||
print("Initializing DuckDB database...")
|
||||
|
||||
try:
|
||||
# Find project root and parse data_description.md
|
||||
project_root = find_project_root()
|
||||
if verbose:
|
||||
print(f" 📂 Project root: {project_root}")
|
||||
print(f" Project root: {project_root}")
|
||||
|
||||
table_configs, folder_mapping = parse_data_description(project_root)
|
||||
if verbose:
|
||||
print(f" 📋 Načteno {len(table_configs)} tabulek z data_description.md")
|
||||
print(f" Loaded {len(table_configs)} tables from data_description.md")
|
||||
|
||||
# Separate tables by query_mode
|
||||
local_tables = []
|
||||
remote_tables = []
|
||||
hybrid_tables = []
|
||||
|
||||
for tc in table_configs:
|
||||
mode = tc.get("query_mode", "local")
|
||||
if mode == "remote":
|
||||
remote_tables.append(tc)
|
||||
elif mode == "hybrid":
|
||||
hybrid_tables.append(tc)
|
||||
else:
|
||||
local_tables.append(tc)
|
||||
|
||||
if verbose:
|
||||
print(f" Query modes: {len(local_tables)} local, "
|
||||
f"{len(remote_tables)} remote, {len(hybrid_tables)} hybrid")
|
||||
|
||||
# Connect to database (creates if doesn't exist)
|
||||
conn = duckdb.connect(db_path)
|
||||
|
||||
# Create views
|
||||
# Create local views from parquet files
|
||||
if verbose:
|
||||
print("\n📊 Vytvářím views z parquet souborů...")
|
||||
print("\n Creating views from parquet files...")
|
||||
|
||||
created_views = []
|
||||
skipped_views = []
|
||||
|
||||
for table_config in table_configs:
|
||||
# Process local and hybrid tables (both have local parquet)
|
||||
for table_config in local_tables + hybrid_tables:
|
||||
table_name = table_config['name']
|
||||
|
||||
try:
|
||||
|
|
@ -209,7 +348,7 @@ def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbo
|
|||
if not parquet_path.exists():
|
||||
skipped_views.append(table_name)
|
||||
if verbose:
|
||||
print(f" ⚠️ Přeskakuji {table_name} - parquet neexistuje: {parquet_path}")
|
||||
print(f" [SKIP] {table_name} - parquet not found: {parquet_path}")
|
||||
continue
|
||||
|
||||
# Determine if partitioned
|
||||
|
|
@ -229,49 +368,65 @@ def init_duckdb(db_path="user/duckdb/analytics.duckdb", data_dir="server", verbo
|
|||
if not partition_files:
|
||||
skipped_views.append(table_name)
|
||||
if verbose:
|
||||
print(f" ⚠️ Přeskakuji {table_name} - žádné partition soubory")
|
||||
print(f" [SKIP] {table_name} - no partition files")
|
||||
continue
|
||||
|
||||
sql = f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{glob_pattern}', union_by_name=true)"
|
||||
if verbose:
|
||||
print(f" ✅ {table_name} ({len(partition_files)} partitions)")
|
||||
mode_label = "hybrid" if table_config.get("query_mode") == "hybrid" else "local"
|
||||
print(f" [OK] {table_name} ({len(partition_files)} partitions, {mode_label})")
|
||||
else:
|
||||
# Single parquet file
|
||||
sql = f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{parquet_path}')"
|
||||
if verbose:
|
||||
print(f" ✅ {table_name}")
|
||||
mode_label = "hybrid" if table_config.get("query_mode") == "hybrid" else "local"
|
||||
print(f" [OK] {table_name} ({mode_label})")
|
||||
|
||||
conn.execute(sql)
|
||||
created_views.append(table_name)
|
||||
|
||||
except Exception as e:
|
||||
if verbose:
|
||||
print(f" ❌ Chyba při vytváření {table_name}: {e}")
|
||||
print(f" [ERR] Error creating {table_name}: {e}")
|
||||
return False
|
||||
|
||||
# Log remote tables (queried at runtime via register_bq_table)
|
||||
if remote_tables:
|
||||
if verbose:
|
||||
print("\n Remote tables (queried at runtime via BigQuery):")
|
||||
for table_config in remote_tables:
|
||||
table_name = table_config['name']
|
||||
table_id = table_config['id']
|
||||
if verbose:
|
||||
print(f" [BQ] {table_name} -> {table_id}")
|
||||
|
||||
# Display table list with row counts
|
||||
if verbose:
|
||||
print(f"\n📋 Seznam dostupných tabulek ({len(created_views)} vytvořeno):")
|
||||
print(f"\n Available tables ({len(created_views)} local views):")
|
||||
tables = conn.execute("SHOW TABLES").fetchall()
|
||||
for table in tables:
|
||||
try:
|
||||
row_count = conn.execute(f"SELECT COUNT(*) FROM {table[0]}").fetchone()[0]
|
||||
print(f" - {table[0]}: {row_count:,} řádků")
|
||||
except Exception as e:
|
||||
print(f" - {table[0]}: (chyba při počítání řádků)")
|
||||
print(f" - {table[0]}: {row_count:,} rows (local)")
|
||||
except Exception:
|
||||
print(f" - {table[0]}: (error counting rows)")
|
||||
|
||||
if remote_tables:
|
||||
print(f"\n Remote tables ({len(remote_tables)}, loaded on demand):")
|
||||
for tc in remote_tables:
|
||||
print(f" - {tc['name']}: via BQ Query API (use date filters!)")
|
||||
|
||||
# Close connection
|
||||
conn.close()
|
||||
|
||||
if verbose:
|
||||
print(f"\n✅ DuckDB databáze vytvořena: {db_path}")
|
||||
print("💡 Můžeš začít analyzovat data pomocí DuckDB SQL dotazů")
|
||||
print(f"\n DuckDB database created: {db_path}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
if verbose:
|
||||
print(f"\n❌ Chyba při inicializaci DuckDB: {e}")
|
||||
print(f"\n Error initializing DuckDB: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
|
@ -297,6 +452,11 @@ def main():
|
|||
default='server',
|
||||
help='Base data directory (default: server, use "data" for server deployment)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--bq-project',
|
||||
default=None,
|
||||
help='BigQuery execution project (informational only)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--quiet',
|
||||
action='store_true',
|
||||
|
|
@ -314,7 +474,8 @@ def main():
|
|||
success = init_duckdb(
|
||||
db_path=args.db_path,
|
||||
data_dir=args.data_dir,
|
||||
verbose=not args.quiet
|
||||
verbose=not args.quiet,
|
||||
bq_project=args.bq_project,
|
||||
)
|
||||
|
||||
# Exit with appropriate code
|
||||
|
|
|
|||
|
|
@ -104,9 +104,19 @@ class TableConfig:
|
|||
incremental_column: Optional[str] = None # Column for timestamp-based incremental sync (BigQuery)
|
||||
columns: Optional[List[str]] = None # Subset of columns to sync (None = all)
|
||||
row_filter: Optional[str] = None # SQL WHERE clause for filtering (e.g., "event_date >= '2024-01-01'")
|
||||
query_mode: str = "local" # "local" (Parquet) | "remote" (BQ direct) | "hybrid" (sync subset, query BQ)
|
||||
partition_column_type: str = "TIMESTAMP" # BQ SQL type for partition column: "DATE", "TIMESTAMP", "DATETIME"
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
# Validate query_mode
|
||||
valid_query_modes = ("local", "remote", "hybrid")
|
||||
if self.query_mode not in valid_query_modes:
|
||||
raise ValueError(
|
||||
f"Invalid query_mode '{self.query_mode}' for table {self.id}. "
|
||||
f"Allowed values: {', '.join(valid_query_modes)}"
|
||||
)
|
||||
|
||||
# Validate sync_strategy
|
||||
if self.sync_strategy not in ["full_refresh", "incremental", "partitioned"]:
|
||||
raise ValueError(
|
||||
|
|
@ -139,6 +149,14 @@ class TableConfig:
|
|||
f"Allowed values: 'month', 'day', 'year'"
|
||||
)
|
||||
|
||||
# Validate partition_column_type
|
||||
valid_column_types = ("DATE", "TIMESTAMP", "DATETIME")
|
||||
if self.partition_column_type not in valid_column_types:
|
||||
raise ValueError(
|
||||
f"Invalid partition_column_type '{self.partition_column_type}' for table {self.id}. "
|
||||
f"Allowed values: {', '.join(valid_column_types)}"
|
||||
)
|
||||
|
||||
# For partitioned, partition_by must be defined
|
||||
if self.sync_strategy == "partitioned":
|
||||
if not self.partition_by:
|
||||
|
|
@ -435,6 +453,8 @@ class Config:
|
|||
incremental_column=table_data.get("incremental_column"),
|
||||
columns=table_data.get("columns"),
|
||||
row_filter=table_data.get("row_filter"),
|
||||
query_mode=table_data.get("query_mode", "local"),
|
||||
partition_column_type=table_data.get("partition_column_type", "TIMESTAMP"),
|
||||
)
|
||||
table_configs.append(config)
|
||||
|
||||
|
|
|
|||
|
|
@ -406,6 +406,21 @@ class DataSyncManager:
|
|||
else:
|
||||
table_configs = self.config.tables
|
||||
|
||||
# Filter out remote-only tables (no local sync needed)
|
||||
remote_skipped = [
|
||||
tc for tc in table_configs if tc.query_mode == "remote"
|
||||
]
|
||||
table_configs = [
|
||||
tc for tc in table_configs if tc.query_mode != "remote"
|
||||
]
|
||||
|
||||
if remote_skipped:
|
||||
logger.info(
|
||||
f"Skipping {len(remote_skipped)} remote-only tables "
|
||||
f"(query via BigQuery): "
|
||||
f"{', '.join(tc.name for tc in remote_skipped)}"
|
||||
)
|
||||
|
||||
logger.info(f"Synchronizing {len(table_configs)} tables...")
|
||||
|
||||
results = {}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ so we install stub modules in sys.modules before importing the adapter.
|
|||
"""
|
||||
|
||||
import sys
|
||||
from datetime import date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
|
@ -83,6 +84,7 @@ def _make_table_config(
|
|||
partition_by: str | None = None,
|
||||
partition_granularity: str | None = None,
|
||||
max_history_days: int | None = None,
|
||||
partition_column_type: str = "TIMESTAMP",
|
||||
) -> TableConfig:
|
||||
"""Helper to build a TableConfig with safe defaults."""
|
||||
return TableConfig(
|
||||
|
|
@ -96,6 +98,7 @@ def _make_table_config(
|
|||
partition_by=partition_by,
|
||||
partition_granularity=partition_granularity,
|
||||
max_history_days=max_history_days,
|
||||
partition_column_type=partition_column_type,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -274,55 +277,217 @@ class TestIncrementalNoNewData:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. partitioned_sync creates partition files
|
||||
# 4. partitioned_sync - per-day streaming behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPartitionedSync:
|
||||
"""Tests for the rewritten _partitioned_sync() that streams per-day from BQ."""
|
||||
|
||||
def test_creates_partition_files(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
|
||||
"""Partitioned sync should create separate Parquet files per partition key."""
|
||||
import pandas as pd
|
||||
|
||||
def _setup_partition_config(
|
||||
self,
|
||||
mock_config,
|
||||
tmp_parquet_dir,
|
||||
*,
|
||||
granularity: str = "day",
|
||||
max_history_days: int | None = 3,
|
||||
incremental_window_days: int | None = 2,
|
||||
partition_column_type: str = "TIMESTAMP",
|
||||
):
|
||||
"""Common setup: create table config + wire mock_config partition paths."""
|
||||
table_config = _make_table_config(
|
||||
sync_strategy="incremental",
|
||||
incremental_column="created_at",
|
||||
partition_by="created_at",
|
||||
partition_granularity="month",
|
||||
incremental_window_days=7,
|
||||
sync_strategy="partitioned",
|
||||
incremental_column="event_date",
|
||||
partition_by="event_date",
|
||||
partition_granularity=granularity,
|
||||
max_history_days=max_history_days,
|
||||
incremental_window_days=incremental_window_days,
|
||||
partition_column_type=partition_column_type,
|
||||
)
|
||||
|
||||
# For partitioned tables, parquet_path is a directory
|
||||
partition_dir = tmp_parquet_dir / "orders"
|
||||
partition_dir.mkdir(parents=True, exist_ok=True)
|
||||
mock_config.get_parquet_path.return_value = partition_dir
|
||||
|
||||
# Configure partition paths
|
||||
def _partition_path(tc, key):
|
||||
return partition_dir / f"{key}.parquet"
|
||||
mock_config.get_partition_path.side_effect = _partition_path
|
||||
|
||||
# Build arrow table with timestamps in two months
|
||||
ts_jan = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")]
|
||||
ts_feb = [pd.Timestamp("2026-02-20 14:00:00", tz="UTC")]
|
||||
arrow_data = pa.table({
|
||||
"id": [1, 2],
|
||||
"name": ["Jan_Order", "Feb_Order"],
|
||||
"created_at": pa.array(ts_jan + ts_feb, type=pa.timestamp("us", tz="UTC")),
|
||||
return table_config, partition_dir
|
||||
|
||||
@staticmethod
|
||||
def _make_day_table(row_id: int, day: date) -> pa.Table:
|
||||
"""Build a one-row Arrow table for a given day."""
|
||||
return pa.table({
|
||||
"id": [row_id],
|
||||
"event_date": pa.array([datetime(day.year, day.month, day.day)], type=pa.timestamp("us")),
|
||||
})
|
||||
mock_bq_client.read_table.return_value = arrow_data
|
||||
|
||||
def test_creates_daily_partition_files(
|
||||
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||
):
|
||||
"""First sync with max_history_days creates one Parquet file per day with data."""
|
||||
table_config, partition_dir = self._setup_partition_config(
|
||||
mock_config, tmp_parquet_dir, max_history_days=3, granularity="day",
|
||||
)
|
||||
|
||||
today = date.today()
|
||||
day0 = today - timedelta(days=3)
|
||||
day1 = today - timedelta(days=2)
|
||||
# day2, day3 (today-1, today) will have no data
|
||||
|
||||
# Build per-day Arrow data for the two days that have rows
|
||||
day0_table = self._make_day_table(1, day0)
|
||||
day1_table = self._make_day_table(2, day1)
|
||||
|
||||
# read_table_partitioned_streaming is called once per partition date.
|
||||
# We need to return data for day0 and day1, empty iterators for the rest.
|
||||
def _streaming_side_effect(*, table_id, partition_column, start, end, columns, column_type):
|
||||
start_date = date.fromisoformat(start)
|
||||
if start_date == day0:
|
||||
return iter(day0_table.to_batches())
|
||||
if start_date == day1:
|
||||
return iter(day1_table.to_batches())
|
||||
return iter([]) # empty for other days
|
||||
|
||||
mock_bq_client.read_table_partitioned_streaming.side_effect = _streaming_side_effect
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
result = adapter.sync_table(table_config, sync_state)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
# Should have created two partition files
|
||||
partition_files = list(partition_dir.glob("*.parquet"))
|
||||
# Should have exactly 2 partition files (days with data)
|
||||
partition_files = sorted(partition_dir.glob("*.parquet"))
|
||||
assert len(partition_files) == 2
|
||||
|
||||
partition_names = sorted(f.stem for f in partition_files)
|
||||
assert "2026_01" in partition_names
|
||||
assert "2026_02" in partition_names
|
||||
file_names = sorted(f.stem for f in partition_files)
|
||||
assert day0.strftime("%Y_%m_%d") in file_names
|
||||
assert day1.strftime("%Y_%m_%d") in file_names
|
||||
|
||||
# Verify content of each partition
|
||||
t0 = pq.read_table(partition_dir / f"{day0.strftime('%Y_%m_%d')}.parquet")
|
||||
assert t0.num_rows == 1
|
||||
assert t0.column("id").to_pylist() == [1]
|
||||
|
||||
def test_incremental_sync_only_fetches_window(
|
||||
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||
):
|
||||
"""After a previous sync, only the incremental window of days is fetched."""
|
||||
table_config, partition_dir = self._setup_partition_config(
|
||||
mock_config, tmp_parquet_dir,
|
||||
max_history_days=30,
|
||||
incremental_window_days=2,
|
||||
granularity="day",
|
||||
)
|
||||
|
||||
# Simulate a previous sync 1 day ago
|
||||
sync_time = (datetime.now() - timedelta(days=1)).isoformat()
|
||||
sync_state.update_sync(
|
||||
table_id=table_config.id,
|
||||
table_name=table_config.name,
|
||||
strategy="partitioned",
|
||||
rows=100,
|
||||
file_size_bytes=5000,
|
||||
)
|
||||
|
||||
# Return empty for all calls -- we just want to verify the call count
|
||||
mock_bq_client.read_table_partitioned_streaming.return_value = iter([])
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
result = adapter.sync_table(table_config, sync_state)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
# With incremental_window_days=2, it should go back 2 days from last_sync.
|
||||
# The number of partition dates from (last_sync - 2 days) to today.
|
||||
last_sync_str = sync_state.get_last_sync(table_config.id)
|
||||
last_sync_dt = datetime.fromisoformat(last_sync_str)
|
||||
start_date = (last_sync_dt - timedelta(days=2)).date()
|
||||
today = date.today()
|
||||
expected_days = (today - start_date).days + 1 # inclusive
|
||||
|
||||
actual_calls = mock_bq_client.read_table_partitioned_streaming.call_count
|
||||
assert actual_calls == expected_days, (
|
||||
f"Expected {expected_days} BQ calls (from {start_date} to {today}), got {actual_calls}"
|
||||
)
|
||||
|
||||
def test_merges_with_existing_partition(
|
||||
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||
):
|
||||
"""New data for an existing partition merges and deduplicates on PK."""
|
||||
table_config, partition_dir = self._setup_partition_config(
|
||||
mock_config, tmp_parquet_dir, max_history_days=3, granularity="day",
|
||||
)
|
||||
|
||||
today = date.today()
|
||||
target_day = today - timedelta(days=1)
|
||||
partition_key = target_day.strftime("%Y_%m_%d")
|
||||
partition_path = partition_dir / f"{partition_key}.parquet"
|
||||
|
||||
# Pre-write an existing partition file with id=1
|
||||
existing = pa.table({
|
||||
"id": [1],
|
||||
"event_date": pa.array(
|
||||
[datetime(target_day.year, target_day.month, target_day.day)],
|
||||
type=pa.timestamp("us"),
|
||||
),
|
||||
})
|
||||
pq.write_table(existing, partition_path, compression="snappy")
|
||||
|
||||
# New data: id=1 (update) + id=2 (new row)
|
||||
new_data = pa.table({
|
||||
"id": [1, 2],
|
||||
"event_date": pa.array(
|
||||
[
|
||||
datetime(target_day.year, target_day.month, target_day.day),
|
||||
datetime(target_day.year, target_day.month, target_day.day),
|
||||
],
|
||||
type=pa.timestamp("us"),
|
||||
),
|
||||
})
|
||||
|
||||
def _streaming_side_effect(*, table_id, partition_column, start, end, columns, column_type):
|
||||
start_date = date.fromisoformat(start)
|
||||
if start_date == target_day:
|
||||
return iter(new_data.to_batches())
|
||||
return iter([])
|
||||
|
||||
mock_bq_client.read_table_partitioned_streaming.side_effect = _streaming_side_effect
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
result = adapter.sync_table(table_config, sync_state)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
# Read back the target partition -- should have 2 rows (dedup on id)
|
||||
merged = pq.read_table(partition_path)
|
||||
assert merged.num_rows == 2
|
||||
assert sorted(merged.column("id").to_pylist()) == [1, 2]
|
||||
|
||||
def test_empty_partition_skipped(
|
||||
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||
):
|
||||
"""A partition day with no data from BQ should not create a file."""
|
||||
table_config, partition_dir = self._setup_partition_config(
|
||||
mock_config, tmp_parquet_dir, max_history_days=2, granularity="day",
|
||||
)
|
||||
|
||||
# Return empty iterator for every call
|
||||
mock_bq_client.read_table_partitioned_streaming.return_value = iter([])
|
||||
# side_effect takes precedence over return_value when set, but let's use
|
||||
# a function so each call gets a fresh empty iterator
|
||||
mock_bq_client.read_table_partitioned_streaming.side_effect = (
|
||||
lambda **kw: iter([])
|
||||
)
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
result = adapter.sync_table(table_config, sync_state)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
# No partition files should have been created
|
||||
partition_files = list(partition_dir.glob("*.parquet"))
|
||||
assert len(partition_files) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -574,15 +739,13 @@ class TestSyncTableDispatch:
|
|||
def test_dispatches_partitioned(
|
||||
self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state
|
||||
):
|
||||
"""sync_strategy='incremental' with partition_by should call _partitioned_sync."""
|
||||
import pandas as pd
|
||||
|
||||
"""sync_strategy='partitioned' should call _partitioned_sync."""
|
||||
table_config = _make_table_config(
|
||||
sync_strategy="incremental",
|
||||
sync_strategy="partitioned",
|
||||
incremental_column="created_at",
|
||||
partition_by="created_at",
|
||||
partition_granularity="month",
|
||||
incremental_window_days=7,
|
||||
partition_granularity="day",
|
||||
max_history_days=2,
|
||||
)
|
||||
partition_dir = tmp_parquet_dir / "orders"
|
||||
partition_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -592,13 +755,9 @@ class TestSyncTableDispatch:
|
|||
return partition_dir / f"{key}.parquet"
|
||||
mock_config.get_partition_path.side_effect = _partition_path
|
||||
|
||||
ts = [pd.Timestamp("2026-01-15 10:00:00", tz="UTC")]
|
||||
arrow_data = pa.table({
|
||||
"id": [1],
|
||||
"name": ["A"],
|
||||
"created_at": pa.array(ts, type=pa.timestamp("us", tz="UTC")),
|
||||
})
|
||||
mock_bq_client.read_table.return_value = arrow_data
|
||||
mock_bq_client.read_table_partitioned_streaming.side_effect = (
|
||||
lambda **kw: iter([])
|
||||
)
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
|
||||
|
|
@ -770,3 +929,175 @@ class TestCreateDataSourceFactory:
|
|||
from connectors.bigquery.adapter import create_data_source, BigQueryDataSource
|
||||
instance = create_data_source()
|
||||
assert isinstance(instance, BigQueryDataSource)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 14. _cleanup_old_partitions deletes files beyond retention window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPartitionCleanup:
|
||||
|
||||
def test_deletes_old_partitions(self, mock_config, mock_bq_client, tmp_parquet_dir):
|
||||
"""Partition files older than max_history_days should be deleted."""
|
||||
table_config = _make_table_config(
|
||||
sync_strategy="partitioned",
|
||||
partition_by="event_date",
|
||||
partition_granularity="day",
|
||||
max_history_days=5,
|
||||
)
|
||||
|
||||
partition_dir = tmp_parquet_dir / "orders"
|
||||
partition_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
today = date.today()
|
||||
# Create files: 3 days ago (keep), 6 days ago (delete), 10 days ago (delete)
|
||||
keep_day = today - timedelta(days=3)
|
||||
delete_day1 = today - timedelta(days=6)
|
||||
delete_day2 = today - timedelta(days=10)
|
||||
|
||||
for d in [keep_day, delete_day1, delete_day2]:
|
||||
key = d.strftime("%Y_%m_%d")
|
||||
dummy = pa.table({"id": [1]})
|
||||
pq.write_table(dummy, partition_dir / f"{key}.parquet")
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
deleted = adapter._cleanup_old_partitions(table_config, partition_dir, "day")
|
||||
|
||||
assert deleted == 2
|
||||
|
||||
# Only the recent file should remain
|
||||
remaining = [f.stem for f in partition_dir.glob("*.parquet")]
|
||||
assert keep_day.strftime("%Y_%m_%d") in remaining
|
||||
assert delete_day1.strftime("%Y_%m_%d") not in remaining
|
||||
assert delete_day2.strftime("%Y_%m_%d") not in remaining
|
||||
|
||||
def test_no_cleanup_without_max_history_days(self, mock_config, mock_bq_client, tmp_parquet_dir):
|
||||
"""Without max_history_days, no partition files should be deleted."""
|
||||
table_config = _make_table_config(
|
||||
sync_strategy="partitioned",
|
||||
partition_by="event_date",
|
||||
partition_granularity="day",
|
||||
max_history_days=None,
|
||||
)
|
||||
|
||||
partition_dir = tmp_parquet_dir / "orders"
|
||||
partition_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create an old file (100 days ago)
|
||||
old_day = date.today() - timedelta(days=100)
|
||||
key = old_day.strftime("%Y_%m_%d")
|
||||
pq.write_table(pa.table({"id": [1]}), partition_dir / f"{key}.parquet")
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
deleted = adapter._cleanup_old_partitions(table_config, partition_dir, "day")
|
||||
|
||||
assert deleted == 0
|
||||
assert len(list(partition_dir.glob("*.parquet"))) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 15. _generate_partition_dates produces correct date ranges
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGeneratePartitionDates:
|
||||
|
||||
def test_daily_generation(self, mock_config, mock_bq_client):
|
||||
"""Daily granularity should generate one date per day, inclusive."""
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
|
||||
start = date(2026, 3, 1)
|
||||
end = date(2026, 3, 5)
|
||||
dates = adapter._generate_partition_dates(start, end, "day")
|
||||
|
||||
assert dates == [
|
||||
date(2026, 3, 1),
|
||||
date(2026, 3, 2),
|
||||
date(2026, 3, 3),
|
||||
date(2026, 3, 4),
|
||||
date(2026, 3, 5),
|
||||
]
|
||||
|
||||
def test_monthly_generation(self, mock_config, mock_bq_client):
|
||||
"""Monthly granularity should generate first-of-month dates, aligned."""
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
|
||||
# Start mid-month -- should align to 1st
|
||||
start = date(2026, 1, 15)
|
||||
end = date(2026, 4, 10)
|
||||
dates = adapter._generate_partition_dates(start, end, "month")
|
||||
|
||||
assert dates == [
|
||||
date(2026, 1, 1),
|
||||
date(2026, 2, 1),
|
||||
date(2026, 3, 1),
|
||||
date(2026, 4, 1),
|
||||
]
|
||||
|
||||
def test_monthly_generation_across_year_boundary(self, mock_config, mock_bq_client):
|
||||
"""Monthly generation should cross year boundaries correctly."""
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
|
||||
start = date(2025, 11, 1)
|
||||
end = date(2026, 2, 15)
|
||||
dates = adapter._generate_partition_dates(start, end, "month")
|
||||
|
||||
assert dates == [
|
||||
date(2025, 11, 1),
|
||||
date(2025, 12, 1),
|
||||
date(2026, 1, 1),
|
||||
date(2026, 2, 1),
|
||||
]
|
||||
|
||||
def test_daily_single_day(self, mock_config, mock_bq_client):
|
||||
"""When start == end, should return a single date."""
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
|
||||
d = date(2026, 6, 15)
|
||||
dates = adapter._generate_partition_dates(d, d, "day")
|
||||
|
||||
assert dates == [d]
|
||||
|
||||
def test_empty_range(self, mock_config, mock_bq_client):
|
||||
"""When start > end, should return an empty list."""
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
|
||||
dates = adapter._generate_partition_dates(date(2026, 3, 10), date(2026, 3, 5), "day")
|
||||
assert dates == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 16. _parse_partition_date converts partition keys back to dates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParsePartitionDate:
|
||||
|
||||
def test_parse_day_format(self, mock_config, mock_bq_client):
|
||||
"""'2026_01_15' with day granularity should parse to date(2026, 1, 15)."""
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
result = adapter._parse_partition_date("2026_01_15", "day")
|
||||
assert result == date(2026, 1, 15)
|
||||
|
||||
def test_parse_month_format(self, mock_config, mock_bq_client):
|
||||
"""'2026_01' with month granularity should parse to date(2026, 1, 1)."""
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
result = adapter._parse_partition_date("2026_01", "month")
|
||||
assert result == date(2026, 1, 1)
|
||||
|
||||
def test_parse_year_format(self, mock_config, mock_bq_client):
|
||||
"""'2026' with year granularity should parse to date(2026, 1, 1)."""
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
result = adapter._parse_partition_date("2026", "year")
|
||||
assert result == date(2026, 1, 1)
|
||||
|
||||
def test_parse_invalid_returns_none(self, mock_config, mock_bq_client):
|
||||
"""Invalid partition key should return None."""
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
assert adapter._parse_partition_date("invalid", "day") is None
|
||||
assert adapter._parse_partition_date("not_a_date", "month") is None
|
||||
assert adapter._parse_partition_date("abc", "year") is None
|
||||
|
||||
def test_parse_mismatched_granularity_returns_none(self, mock_config, mock_bq_client):
|
||||
"""Day key with month granularity should return None (format mismatch)."""
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
# "2026_01_15" is a day format -- parsing as "month" (%Y_%m) should fail
|
||||
assert adapter._parse_partition_date("2026_01_15", "month") is None
|
||||
|
|
|
|||
|
|
@ -868,3 +868,151 @@ class TestCreateClient:
|
|||
|
||||
assert isinstance(result, BigQueryClient)
|
||||
assert result.project_id == "factory-project"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 14. read_table_partitioned_streaming yields RecordBatches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadTablePartitionedStreaming:
|
||||
def test_streaming_yields_batches(self, client, mock_bq_client):
|
||||
"""read_table_partitioned_streaming yields RecordBatches, not a Table."""
|
||||
batch1 = pa.record_batch({"a": [1, 2]})
|
||||
batch2 = pa.record_batch({"a": [3, 4]})
|
||||
|
||||
mock_query_job = MagicMock()
|
||||
mock_row_iter = MagicMock()
|
||||
mock_row_iter.to_arrow_iterable.return_value = iter([batch1, batch2])
|
||||
mock_query_job.result.return_value = mock_row_iter
|
||||
mock_bq_client.query.return_value = mock_query_job
|
||||
|
||||
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||
mock_bq_module.QueryJobConfig.return_value = MagicMock()
|
||||
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
|
||||
client.client = mock_bq_client
|
||||
client.bqstorage_client = None # disable Storage API for simplicity
|
||||
|
||||
batches = list(client.read_table_partitioned_streaming(
|
||||
table_id="proj.dataset.events",
|
||||
partition_column="event_date",
|
||||
start="2025-01-01",
|
||||
))
|
||||
|
||||
assert len(batches) == 2
|
||||
assert isinstance(batches[0], pa.RecordBatch)
|
||||
assert isinstance(batches[1], pa.RecordBatch)
|
||||
|
||||
def test_streaming_with_date_column_type(self, client, mock_bq_client):
|
||||
"""With column_type='DATE', ScalarQueryParameter uses 'DATE' type."""
|
||||
batch = pa.record_batch({"a": [1]})
|
||||
|
||||
mock_query_job = MagicMock()
|
||||
mock_row_iter = MagicMock()
|
||||
mock_row_iter.to_arrow_iterable.return_value = iter([batch])
|
||||
mock_query_job.result.return_value = mock_row_iter
|
||||
mock_bq_client.query.return_value = mock_query_job
|
||||
|
||||
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||
mock_bq_module.QueryJobConfig.return_value = MagicMock()
|
||||
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
|
||||
client.client = mock_bq_client
|
||||
client.bqstorage_client = None
|
||||
|
||||
list(client.read_table_partitioned_streaming(
|
||||
table_id="proj.dataset.events",
|
||||
partition_column="event_date",
|
||||
start="2025-01-01",
|
||||
column_type="DATE",
|
||||
))
|
||||
|
||||
# Verify ScalarQueryParameter was called with "DATE" type
|
||||
mock_bq_module.ScalarQueryParameter.assert_called_once_with(
|
||||
"start_value", "DATE", "2025-01-01"
|
||||
)
|
||||
|
||||
def test_streaming_start_and_end(self, client, mock_bq_client):
|
||||
"""With start and end, both params are created with correct column_type."""
|
||||
batch = pa.record_batch({"a": [1]})
|
||||
|
||||
mock_query_job = MagicMock()
|
||||
mock_row_iter = MagicMock()
|
||||
mock_row_iter.to_arrow_iterable.return_value = iter([batch])
|
||||
mock_query_job.result.return_value = mock_row_iter
|
||||
mock_bq_client.query.return_value = mock_query_job
|
||||
|
||||
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||
mock_bq_module.QueryJobConfig.return_value = MagicMock()
|
||||
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
|
||||
client.client = mock_bq_client
|
||||
client.bqstorage_client = None
|
||||
|
||||
list(client.read_table_partitioned_streaming(
|
||||
table_id="proj.dataset.events",
|
||||
partition_column="event_date",
|
||||
start="2025-01-01",
|
||||
end="2025-06-01",
|
||||
column_type="DATE",
|
||||
))
|
||||
|
||||
sql = mock_bq_client.query.call_args[0][0]
|
||||
assert "`event_date` >= @start_value" in sql
|
||||
assert "`event_date` < @end_value" in sql
|
||||
|
||||
# Both parameters created with "DATE" type
|
||||
assert mock_bq_module.ScalarQueryParameter.call_count == 2
|
||||
calls = mock_bq_module.ScalarQueryParameter.call_args_list
|
||||
assert calls[0].args == ("start_value", "DATE", "2025-01-01")
|
||||
assert calls[1].args == ("end_value", "DATE", "2025-06-01")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 15. read_table_partitioned column_type parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadTablePartitionedColumnType:
|
||||
def test_date_column_type(self, client, mock_bq_client):
|
||||
"""read_table_partitioned with column_type='DATE' creates DATE params."""
|
||||
mock_job = MagicMock()
|
||||
mock_job.to_arrow.return_value = pa.table({"a": [1]})
|
||||
mock_bq_client.query.return_value = mock_job
|
||||
|
||||
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||
mock_bq_module.QueryJobConfig.return_value = MagicMock()
|
||||
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
|
||||
client.client = mock_bq_client
|
||||
|
||||
client.read_table_partitioned(
|
||||
table_id="proj.dataset.events",
|
||||
partition_column="event_date",
|
||||
start="2025-01-01",
|
||||
end="2025-06-01",
|
||||
column_type="DATE",
|
||||
)
|
||||
|
||||
# Both parameters should use "DATE" type
|
||||
assert mock_bq_module.ScalarQueryParameter.call_count == 2
|
||||
calls = mock_bq_module.ScalarQueryParameter.call_args_list
|
||||
assert calls[0].args == ("start_value", "DATE", "2025-01-01")
|
||||
assert calls[1].args == ("end_value", "DATE", "2025-06-01")
|
||||
|
||||
def test_default_column_type_is_timestamp(self, client, mock_bq_client):
|
||||
"""Default column_type is TIMESTAMP when not specified."""
|
||||
mock_job = MagicMock()
|
||||
mock_job.to_arrow.return_value = pa.table({"a": [1]})
|
||||
mock_bq_client.query.return_value = mock_job
|
||||
|
||||
with patch("connectors.bigquery.client.bigquery") as mock_bq_module:
|
||||
mock_bq_module.QueryJobConfig.return_value = MagicMock()
|
||||
mock_bq_module.ScalarQueryParameter.return_value = MagicMock()
|
||||
client.client = mock_bq_client
|
||||
|
||||
client.read_table_partitioned(
|
||||
table_id="proj.dataset.events",
|
||||
partition_column="created_at",
|
||||
start="2025-01-01T00:00:00Z",
|
||||
)
|
||||
|
||||
# Should default to "TIMESTAMP"
|
||||
mock_bq_module.ScalarQueryParameter.assert_called_once_with(
|
||||
"start_value", "TIMESTAMP", "2025-01-01T00:00:00Z"
|
||||
)
|
||||
|
|
|
|||
69
tests/test_config_query_mode.py
Normal file
69
tests/test_config_query_mode.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""Tests for TableConfig.query_mode field validation."""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import TableConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def _make_table(**overrides) -> TableConfig:
|
||||
"""Create a TableConfig with sensible defaults, applying overrides."""
|
||||
defaults = dict(
|
||||
id="test.dataset.table",
|
||||
name="test_table",
|
||||
description="Test",
|
||||
primary_key="id",
|
||||
sync_strategy="full_refresh",
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return TableConfig(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestQueryModeDefault:
|
||||
def test_default_is_local(self):
|
||||
table = _make_table()
|
||||
assert table.query_mode == "local"
|
||||
|
||||
|
||||
class TestQueryModeValidValues:
|
||||
@pytest.mark.parametrize("mode", ["local", "remote", "hybrid"])
|
||||
def test_valid_query_mode(self, mode):
|
||||
table = _make_table(query_mode=mode)
|
||||
assert table.query_mode == mode
|
||||
|
||||
|
||||
class TestQueryModeInvalid:
|
||||
@pytest.mark.parametrize("bad_mode", ["invalid", "Local", "REMOTE", "", "sql"])
|
||||
def test_invalid_query_mode_raises(self, bad_mode):
|
||||
with pytest.raises(ValueError, match="Invalid query_mode"):
|
||||
_make_table(query_mode=bad_mode)
|
||||
|
||||
|
||||
class TestQueryModeFromKwarg:
|
||||
def test_kwarg_sets_query_mode(self):
|
||||
"""Simulate what _parse_data_description does: pass query_mode as kwarg."""
|
||||
table = TableConfig(
|
||||
id="proj.dataset.orders",
|
||||
name="orders",
|
||||
description="Order data",
|
||||
primary_key="order_id",
|
||||
sync_strategy="full_refresh",
|
||||
query_mode="remote",
|
||||
)
|
||||
assert table.query_mode == "remote"
|
||||
|
||||
def test_kwarg_default_when_omitted(self):
|
||||
"""When YAML has no query_mode, _parse_data_description passes 'local'."""
|
||||
table = TableConfig(
|
||||
id="proj.dataset.orders",
|
||||
name="orders",
|
||||
description="Order data",
|
||||
primary_key="order_id",
|
||||
sync_strategy="full_refresh",
|
||||
)
|
||||
assert table.query_mode == "local"
|
||||
228
tests/test_data_sync_query_mode.py
Normal file
228
tests/test_data_sync_query_mode.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""Tests for remote table skipping in DataSyncManager.sync_all().
|
||||
|
||||
Tables with query_mode == "remote" should be skipped during sync (no local
|
||||
Parquet file is needed -- queries go directly to BigQuery). Tables with
|
||||
query_mode "local" or "hybrid" must still be synced normally.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import TableConfig
|
||||
from src.data_sync import DataSyncManager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_table_config(
|
||||
table_id: str,
|
||||
name: str,
|
||||
query_mode: str = "local",
|
||||
) -> TableConfig:
|
||||
"""Create a minimal TableConfig for testing."""
|
||||
return TableConfig(
|
||||
id=table_id,
|
||||
name=name,
|
||||
description=f"Test table {name}",
|
||||
primary_key="id",
|
||||
sync_strategy="full_refresh",
|
||||
query_mode=query_mode,
|
||||
)
|
||||
|
||||
|
||||
def _successful_sync_result() -> dict:
|
||||
"""Return a fake successful sync result dict."""
|
||||
return {
|
||||
"success": True,
|
||||
"rows": 100,
|
||||
"file_size_mb": 0.5,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def table_local():
|
||||
return _make_table_config("t.local", "local_table", query_mode="local")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def table_remote():
|
||||
return _make_table_config("t.remote", "remote_table", query_mode="remote")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def table_hybrid():
|
||||
return _make_table_config("t.hybrid", "hybrid_table", query_mode="hybrid")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def all_tables(table_local, table_remote, table_hybrid):
|
||||
return [table_local, table_remote, table_hybrid]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(all_tables):
|
||||
"""Return a mock Config whose .tables list contains all three query modes."""
|
||||
cfg = MagicMock()
|
||||
cfg.tables = all_tables
|
||||
cfg.get_metadata_path.return_value = MagicMock() # Path-like
|
||||
|
||||
def _get_table_config(tid):
|
||||
return next((t for t in all_tables if t.id == tid), None)
|
||||
|
||||
cfg.get_table_config.side_effect = _get_table_config
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_data_source():
|
||||
"""Return a mock DataSource that always succeeds."""
|
||||
ds = MagicMock()
|
||||
ds.sync_table.return_value = _successful_sync_result()
|
||||
return ds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sync_manager(mock_config, mock_data_source):
|
||||
"""Create a DataSyncManager with mocked dependencies."""
|
||||
with (
|
||||
patch("src.data_sync.get_config", return_value=mock_config),
|
||||
patch("src.data_sync.create_data_source", return_value=mock_data_source),
|
||||
patch("src.data_sync.SyncState"),
|
||||
):
|
||||
manager = DataSyncManager()
|
||||
# Patch out schema generation and auto-profiling (not under test)
|
||||
manager._generate_schema_yaml = MagicMock()
|
||||
yield manager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSyncAllRemoteSkipping:
|
||||
"""Verify that sync_all filters out remote tables."""
|
||||
|
||||
def test_remote_table_not_synced(self, sync_manager, mock_data_source, table_remote):
|
||||
"""Remote table must NOT be passed to data_source.sync_table."""
|
||||
sync_manager.sync_all()
|
||||
|
||||
synced_ids = [
|
||||
call.args[0].id for call in mock_data_source.sync_table.call_args_list
|
||||
]
|
||||
assert table_remote.id not in synced_ids
|
||||
|
||||
def test_local_table_is_synced(self, sync_manager, mock_data_source, table_local):
|
||||
"""Local table must be synced normally."""
|
||||
sync_manager.sync_all()
|
||||
|
||||
synced_ids = [
|
||||
call.args[0].id for call in mock_data_source.sync_table.call_args_list
|
||||
]
|
||||
assert table_local.id in synced_ids
|
||||
|
||||
def test_hybrid_table_is_synced(self, sync_manager, mock_data_source, table_hybrid):
|
||||
"""Hybrid table must be synced (needs local parquet for profiling)."""
|
||||
sync_manager.sync_all()
|
||||
|
||||
synced_ids = [
|
||||
call.args[0].id for call in mock_data_source.sync_table.call_args_list
|
||||
]
|
||||
assert table_hybrid.id in synced_ids
|
||||
|
||||
def test_sync_call_count(self, sync_manager, mock_data_source):
|
||||
"""Only local + hybrid tables should result in sync_table calls."""
|
||||
sync_manager.sync_all()
|
||||
|
||||
# 3 tables total, 1 remote -> 2 sync calls
|
||||
assert mock_data_source.sync_table.call_count == 2
|
||||
|
||||
def test_results_exclude_remote(self, sync_manager, table_remote):
|
||||
"""The results dict must not contain an entry for the remote table."""
|
||||
results = sync_manager.sync_all()
|
||||
|
||||
assert table_remote.id not in results
|
||||
|
||||
def test_results_include_local_and_hybrid(
|
||||
self, sync_manager, table_local, table_hybrid
|
||||
):
|
||||
"""Results dict must contain entries for local and hybrid tables."""
|
||||
results = sync_manager.sync_all()
|
||||
|
||||
assert table_local.id in results
|
||||
assert table_hybrid.id in results
|
||||
|
||||
|
||||
class TestSyncAllAllRemote:
|
||||
"""Edge case: every table is remote."""
|
||||
|
||||
def test_no_sync_calls_when_all_remote(self, mock_config, mock_data_source):
|
||||
remote_only = [
|
||||
_make_table_config("t.r1", "remote1", query_mode="remote"),
|
||||
_make_table_config("t.r2", "remote2", query_mode="remote"),
|
||||
]
|
||||
mock_config.tables = remote_only
|
||||
|
||||
with (
|
||||
patch("src.data_sync.get_config", return_value=mock_config),
|
||||
patch("src.data_sync.create_data_source", return_value=mock_data_source),
|
||||
patch("src.data_sync.SyncState"),
|
||||
):
|
||||
manager = DataSyncManager()
|
||||
manager._generate_schema_yaml = MagicMock()
|
||||
results = manager.sync_all()
|
||||
|
||||
assert mock_data_source.sync_table.call_count == 0
|
||||
assert results == {}
|
||||
|
||||
|
||||
class TestSyncAllNoRemote:
|
||||
"""Edge case: no remote tables at all -- everything syncs."""
|
||||
|
||||
def test_all_tables_synced(self, mock_config, mock_data_source):
|
||||
local_only = [
|
||||
_make_table_config("t.l1", "local1", query_mode="local"),
|
||||
_make_table_config("t.l2", "local2", query_mode="local"),
|
||||
]
|
||||
mock_config.tables = local_only
|
||||
|
||||
with (
|
||||
patch("src.data_sync.get_config", return_value=mock_config),
|
||||
patch("src.data_sync.create_data_source", return_value=mock_data_source),
|
||||
patch("src.data_sync.SyncState"),
|
||||
):
|
||||
manager = DataSyncManager()
|
||||
manager._generate_schema_yaml = MagicMock()
|
||||
results = manager.sync_all()
|
||||
|
||||
assert mock_data_source.sync_table.call_count == 2
|
||||
assert "t.l1" in results
|
||||
assert "t.l2" in results
|
||||
|
||||
|
||||
class TestSyncAllWithTableFilter:
|
||||
"""When sync_all receives an explicit table list, remote filtering still applies."""
|
||||
|
||||
def test_explicit_remote_table_still_skipped(
|
||||
self, sync_manager, mock_data_source, table_remote
|
||||
):
|
||||
"""Even if explicitly listed, a remote table should be skipped."""
|
||||
sync_manager.sync_all(tables=[table_remote.id])
|
||||
|
||||
assert mock_data_source.sync_table.call_count == 0
|
||||
|
||||
def test_explicit_local_table_synced(
|
||||
self, sync_manager, mock_data_source, table_local
|
||||
):
|
||||
"""An explicitly listed local table should be synced."""
|
||||
sync_manager.sync_all(tables=[table_local.id])
|
||||
|
||||
assert mock_data_source.sync_table.call_count == 1
|
||||
synced_id = mock_data_source.sync_table.call_args_list[0].args[0].id
|
||||
assert synced_id == table_local.id
|
||||
488
tests/test_duckdb_manager.py
Normal file
488
tests/test_duckdb_manager.py
Normal file
|
|
@ -0,0 +1,488 @@
|
|||
"""Tests for DuckDB Manager - query_mode classification and BQ registration.
|
||||
|
||||
Tests cover:
|
||||
- _get_bq_project_from_table_id: extracting BQ project from table IDs
|
||||
- get_remote_tables: filtering tables by query_mode
|
||||
- register_bq_table: registering BQ query results in DuckDB
|
||||
- init_duckdb: table classification by query_mode, local view creation,
|
||||
remote table logging
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import duckdb
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from scripts.duckdb_manager import (
|
||||
_get_bq_project_from_table_id,
|
||||
get_remote_tables,
|
||||
init_duckdb,
|
||||
register_bq_table,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_project(tmp_path):
|
||||
"""Create a minimal project layout with docs/data_description.md and a parquet file.
|
||||
|
||||
Returns (project_root, db_path, data_dir) tuple.
|
||||
"""
|
||||
docs_dir = tmp_path / "docs"
|
||||
docs_dir.mkdir()
|
||||
|
||||
data_description = """\
|
||||
# Data Description
|
||||
|
||||
```yaml
|
||||
folder_mapping:
|
||||
in.c-crm: crm_data
|
||||
|
||||
tables:
|
||||
- id: "in.c-crm.company"
|
||||
name: "company"
|
||||
description: "Company master data"
|
||||
primary_key: "id"
|
||||
sync_strategy: "full_refresh"
|
||||
```
|
||||
"""
|
||||
(docs_dir / "data_description.md").write_text(data_description)
|
||||
|
||||
# Create parquet directory and a minimal parquet file
|
||||
data_dir = tmp_path / "server" / "parquet" / "crm_data"
|
||||
data_dir.mkdir(parents=True)
|
||||
|
||||
table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
|
||||
pq.write_table(table, data_dir / "company.parquet")
|
||||
|
||||
db_dir = tmp_path / "user" / "duckdb"
|
||||
db_dir.mkdir(parents=True)
|
||||
db_path = str(db_dir / "test.duckdb")
|
||||
|
||||
return tmp_path, db_path, str(tmp_path / "server")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_project_mixed(tmp_path):
|
||||
"""Project with local, remote, and hybrid tables."""
|
||||
docs_dir = tmp_path / "docs"
|
||||
docs_dir.mkdir()
|
||||
|
||||
data_description = """\
|
||||
# Data Description
|
||||
|
||||
```yaml
|
||||
folder_mapping:
|
||||
in.c-crm: crm_data
|
||||
|
||||
tables:
|
||||
- id: "in.c-crm.company"
|
||||
name: "company"
|
||||
description: "Local table"
|
||||
primary_key: "id"
|
||||
sync_strategy: "full_refresh"
|
||||
|
||||
- id: "prj-grp-dataview-prod-1ff9.finance.revenue"
|
||||
name: "revenue"
|
||||
description: "Remote BQ table"
|
||||
primary_key: "id"
|
||||
query_mode: "remote"
|
||||
|
||||
- id: "prj-grp-dataview-prod-1ff9.marketing.campaigns"
|
||||
name: "campaigns"
|
||||
description: "Hybrid table"
|
||||
primary_key: "id"
|
||||
sync_strategy: "full_refresh"
|
||||
query_mode: "hybrid"
|
||||
```
|
||||
"""
|
||||
(docs_dir / "data_description.md").write_text(data_description)
|
||||
|
||||
# Create parquet files for local and hybrid tables
|
||||
crm_dir = tmp_path / "server" / "parquet" / "crm_data"
|
||||
crm_dir.mkdir(parents=True)
|
||||
table = pa.table({"id": [1, 2], "name": ["a", "b"]})
|
||||
pq.write_table(table, crm_dir / "company.parquet")
|
||||
|
||||
marketing_dir = tmp_path / "server" / "parquet" / "prj-grp-dataview-prod-1ff9.marketing"
|
||||
marketing_dir.mkdir(parents=True)
|
||||
campaigns_table = pa.table({"id": [10], "campaign": ["test"]})
|
||||
pq.write_table(campaigns_table, marketing_dir / "campaigns.parquet")
|
||||
|
||||
db_dir = tmp_path / "user" / "duckdb"
|
||||
db_dir.mkdir(parents=True)
|
||||
db_path = str(db_dir / "test.duckdb")
|
||||
|
||||
return tmp_path, db_path, str(tmp_path / "server")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_project_remote_only(tmp_path):
|
||||
"""Project with only remote tables (no local parquet needed)."""
|
||||
docs_dir = tmp_path / "docs"
|
||||
docs_dir.mkdir()
|
||||
|
||||
data_description = """\
|
||||
# Data Description
|
||||
|
||||
```yaml
|
||||
tables:
|
||||
- id: "prj-grp-dataview-prod-1ff9.finance.revenue"
|
||||
name: "revenue"
|
||||
description: "Remote BQ table"
|
||||
primary_key: "id"
|
||||
query_mode: "remote"
|
||||
|
||||
- id: "prj-grp-dataview-prod-1ff9.finance.costs"
|
||||
name: "costs"
|
||||
description: "Remote BQ table"
|
||||
primary_key: "id"
|
||||
query_mode: "remote"
|
||||
```
|
||||
"""
|
||||
(docs_dir / "data_description.md").write_text(data_description)
|
||||
|
||||
db_dir = tmp_path / "user" / "duckdb"
|
||||
db_dir.mkdir(parents=True)
|
||||
db_path = str(db_dir / "test.duckdb")
|
||||
|
||||
return tmp_path, db_path, str(tmp_path / "server")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _get_bq_project_from_table_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetBqProjectFromTableId:
|
||||
"""Test extracting BigQuery project ID from fully-qualified table IDs."""
|
||||
|
||||
def test_valid_bq_table_id(self):
|
||||
result = _get_bq_project_from_table_id(
|
||||
"prj-grp-dataview-prod-1ff9.finance.table"
|
||||
)
|
||||
assert result == "prj-grp-dataview-prod-1ff9"
|
||||
|
||||
def test_valid_bq_table_id_different_project(self):
|
||||
result = _get_bq_project_from_table_id(
|
||||
"my-gcp-project.dataset_name.table_name"
|
||||
)
|
||||
assert result == "my-gcp-project"
|
||||
|
||||
def test_keboola_format_returns_none(self):
|
||||
result = _get_bq_project_from_table_id("in.c-crm.table")
|
||||
assert result is None
|
||||
|
||||
def test_two_part_id_returns_none(self):
|
||||
result = _get_bq_project_from_table_id("dataset.table")
|
||||
assert result is None
|
||||
|
||||
def test_single_part_returns_none(self):
|
||||
result = _get_bq_project_from_table_id("table_only")
|
||||
assert result is None
|
||||
|
||||
def test_four_parts_returns_none(self):
|
||||
result = _get_bq_project_from_table_id("a-b.c.d.e")
|
||||
assert result is None
|
||||
|
||||
def test_empty_string_returns_none(self):
|
||||
result = _get_bq_project_from_table_id("")
|
||||
assert result is None
|
||||
|
||||
def test_three_parts_no_hyphen_returns_none(self):
|
||||
result = _get_bq_project_from_table_id("project.dataset.table")
|
||||
assert result is None
|
||||
|
||||
def test_hyphen_in_first_part_is_key(self):
|
||||
result = _get_bq_project_from_table_id("a-b.dataset.table")
|
||||
assert result == "a-b"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: get_remote_tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetRemoteTables:
|
||||
"""Test filtering table configs by query_mode."""
|
||||
|
||||
def test_returns_remote_tables(self):
|
||||
configs = [
|
||||
{"name": "local", "query_mode": "local"},
|
||||
{"name": "remote1", "query_mode": "remote"},
|
||||
{"name": "hybrid1", "query_mode": "hybrid"},
|
||||
]
|
||||
result = get_remote_tables(configs)
|
||||
names = [tc["name"] for tc in result]
|
||||
assert "remote1" in names
|
||||
assert "hybrid1" in names
|
||||
assert "local" not in names
|
||||
|
||||
def test_returns_empty_when_all_local(self):
|
||||
configs = [
|
||||
{"name": "t1", "query_mode": "local"},
|
||||
{"name": "t2"}, # defaults to local (no query_mode key)
|
||||
]
|
||||
result = get_remote_tables(configs)
|
||||
assert result == []
|
||||
|
||||
def test_missing_query_mode_treated_as_local(self):
|
||||
configs = [{"name": "t1"}] # no query_mode
|
||||
result = get_remote_tables(configs)
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: register_bq_table
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRegisterBqTable:
|
||||
"""Test registering BQ query results as DuckDB views."""
|
||||
|
||||
@staticmethod
|
||||
def _make_factory(arrow_table, side_effect=None):
|
||||
"""Create a mock BQ client factory returning a client that yields arrow_table."""
|
||||
mock_job = MagicMock()
|
||||
if side_effect:
|
||||
mock_job.to_arrow.side_effect = side_effect
|
||||
else:
|
||||
mock_job.to_arrow.return_value = arrow_table
|
||||
mock_client = MagicMock()
|
||||
mock_client.query.return_value = mock_job
|
||||
factory = MagicMock(return_value=mock_client)
|
||||
factory._mock_client = mock_client
|
||||
factory._mock_job = mock_job
|
||||
return factory
|
||||
|
||||
def test_registers_arrow_table_in_duckdb(self):
|
||||
"""Result from BQ should be queryable in DuckDB after registration."""
|
||||
arrow_table = pa.table({"id": [1, 2], "val": [10.0, 20.0]})
|
||||
factory = self._make_factory(arrow_table)
|
||||
|
||||
conn = duckdb.connect()
|
||||
rows = register_bq_table(
|
||||
conn=conn,
|
||||
table_id="proj.dataset.table",
|
||||
view_name="test_view",
|
||||
sql="SELECT * FROM table",
|
||||
bq_project="test-project",
|
||||
_bq_client_factory=factory,
|
||||
)
|
||||
|
||||
assert rows == 2
|
||||
result = conn.execute("SELECT SUM(val) FROM test_view").fetchone()[0]
|
||||
assert result == 30.0
|
||||
conn.close()
|
||||
|
||||
def test_raises_without_bq_project(self):
|
||||
conn = duckdb.connect()
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(ValueError, match="BigQuery project not set"):
|
||||
register_bq_table(
|
||||
conn=conn,
|
||||
table_id="proj.ds.tbl",
|
||||
view_name="v",
|
||||
sql="SELECT 1",
|
||||
)
|
||||
conn.close()
|
||||
|
||||
def test_uses_env_var_when_no_project_arg(self):
|
||||
arrow_table = pa.table({"x": [1]})
|
||||
factory = self._make_factory(arrow_table)
|
||||
|
||||
conn = duckdb.connect()
|
||||
with patch.dict(os.environ, {"BIGQUERY_PROJECT": "env-proj"}):
|
||||
register_bq_table(
|
||||
conn=conn,
|
||||
table_id="p.d.t",
|
||||
view_name="v",
|
||||
sql="SELECT 1",
|
||||
_bq_client_factory=factory,
|
||||
)
|
||||
|
||||
factory.assert_called_once_with("env-proj")
|
||||
conn.close()
|
||||
|
||||
def test_storage_api_fallback(self):
|
||||
"""Falls back to REST when Storage API permission denied."""
|
||||
arrow_table = pa.table({"x": [1]})
|
||||
factory = self._make_factory(
|
||||
arrow_table,
|
||||
side_effect=[
|
||||
Exception("PERMISSION_DENIED readsessions"),
|
||||
arrow_table,
|
||||
],
|
||||
)
|
||||
|
||||
conn = duckdb.connect()
|
||||
rows = register_bq_table(
|
||||
conn=conn,
|
||||
table_id="p.d.t",
|
||||
view_name="v",
|
||||
sql="SELECT 1",
|
||||
bq_project="proj",
|
||||
_bq_client_factory=factory,
|
||||
)
|
||||
|
||||
assert rows == 1
|
||||
factory._mock_job.to_arrow.assert_called_with(create_bqstorage_client=False)
|
||||
conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: init_duckdb - table classification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInitDuckdbClassification:
|
||||
"""Test that tables are correctly classified by query_mode."""
|
||||
|
||||
def test_local_tables_create_parquet_views(self, tmp_project):
|
||||
project_root, db_path, data_dir = tmp_project
|
||||
|
||||
with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
|
||||
result = init_duckdb(
|
||||
db_path=db_path, data_dir=data_dir, verbose=False
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
conn = duckdb.connect(db_path, read_only=True)
|
||||
tables = [row[0] for row in conn.execute("SHOW TABLES").fetchall()]
|
||||
assert "company" in tables
|
||||
row_count = conn.execute("SELECT COUNT(*) FROM company").fetchone()[0]
|
||||
assert row_count == 3
|
||||
conn.close()
|
||||
|
||||
def test_remote_tables_not_created_as_local_views(self, tmp_project_mixed):
|
||||
project_root, db_path, data_dir = tmp_project_mixed
|
||||
|
||||
with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
|
||||
result = init_duckdb(
|
||||
db_path=db_path, data_dir=data_dir, verbose=False
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
conn = duckdb.connect(db_path, read_only=True)
|
||||
tables = [row[0] for row in conn.execute("SHOW TABLES").fetchall()]
|
||||
assert "revenue" not in tables
|
||||
assert "company" in tables
|
||||
conn.close()
|
||||
|
||||
def test_hybrid_tables_create_local_views(self, tmp_project_mixed):
|
||||
project_root, db_path, data_dir = tmp_project_mixed
|
||||
|
||||
with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
|
||||
result = init_duckdb(
|
||||
db_path=db_path, data_dir=data_dir, verbose=False
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
conn = duckdb.connect(db_path, read_only=True)
|
||||
tables = [row[0] for row in conn.execute("SHOW TABLES").fetchall()]
|
||||
assert "campaigns" in tables
|
||||
conn.close()
|
||||
|
||||
def test_default_query_mode_is_local(self, tmp_project):
|
||||
project_root, db_path, data_dir = tmp_project
|
||||
|
||||
with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
|
||||
result = init_duckdb(
|
||||
db_path=db_path, data_dir=data_dir, verbose=False
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
conn = duckdb.connect(db_path, read_only=True)
|
||||
tables = [row[0] for row in conn.execute("SHOW TABLES").fetchall()]
|
||||
assert "company" in tables
|
||||
conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: init_duckdb - remote table logging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInitDuckdbRemoteLogging:
|
||||
"""Test that remote tables are logged correctly."""
|
||||
|
||||
def test_remote_tables_logged(self, tmp_project_remote_only, capsys):
|
||||
project_root, db_path, data_dir = tmp_project_remote_only
|
||||
|
||||
with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
|
||||
init_duckdb(
|
||||
db_path=db_path, data_dir=data_dir, verbose=True,
|
||||
)
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "revenue" in output
|
||||
assert "costs" in output
|
||||
assert "[BQ]" in output
|
||||
|
||||
def test_remote_only_project_succeeds(self, tmp_project_remote_only):
|
||||
project_root, db_path, data_dir = tmp_project_remote_only
|
||||
|
||||
with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
|
||||
result = init_duckdb(
|
||||
db_path=db_path, data_dir=data_dir, verbose=False,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: init_duckdb - missing parquet handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInitDuckdbMissingParquet:
|
||||
"""Test behavior when parquet files are missing."""
|
||||
|
||||
def test_missing_parquet_skips_view(self, tmp_path):
|
||||
docs_dir = tmp_path / "docs"
|
||||
docs_dir.mkdir()
|
||||
|
||||
data_description = """\
|
||||
# Data Description
|
||||
|
||||
```yaml
|
||||
tables:
|
||||
- id: "in.c-crm.missing_table"
|
||||
name: "missing_table"
|
||||
description: "No parquet exists"
|
||||
primary_key: "id"
|
||||
sync_strategy: "full_refresh"
|
||||
```
|
||||
"""
|
||||
(docs_dir / "data_description.md").write_text(data_description)
|
||||
|
||||
db_dir = tmp_path / "user" / "duckdb"
|
||||
db_dir.mkdir(parents=True)
|
||||
db_path = str(db_dir / "test.duckdb")
|
||||
|
||||
with patch("scripts.duckdb_manager.find_project_root", return_value=tmp_path):
|
||||
result = init_duckdb(
|
||||
db_path=db_path, data_dir=str(tmp_path / "server"), verbose=False
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
conn = duckdb.connect(db_path, read_only=True)
|
||||
tables = [row[0] for row in conn.execute("SHOW TABLES").fetchall()]
|
||||
assert "missing_table" not in tables
|
||||
conn.close()
|
||||
|
||||
def test_remote_table_no_local_parquet_needed(self, tmp_project_remote_only):
|
||||
project_root, db_path, data_dir = tmp_project_remote_only
|
||||
|
||||
with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
|
||||
result = init_duckdb(
|
||||
db_path=db_path, data_dir=data_dir, verbose=False,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
Loading…
Reference in a new issue