refactor: delete old sync pipeline — 9,500 lines removed
Phase 5 cleanup: remove all code replaced by extract.duckdb architecture. Deleted modules: - src/config.py (653) — replaced by DuckDB table_registry - src/parquet_manager.py (755) — replaced by DuckDB COPY TO - src/data_sync.py (734) — replaced by SyncOrchestrator - src/remote_query.py (636) — replaced by DuckDB BigQuery ATTACH - src/table_registry.py (464) — replaced by DuckDB repository - connectors/keboola/adapter.py (820) — replaced by extractor.py - connectors/bigquery/adapter.py (665) — replaced by extractor.py - connectors/bigquery/client.py (644) — replaced by DuckDB BQ extension Updated all imports in webapp, catalog_export, enricher, router, sync_settings_service, generate_sample_data. Kept keboola/client.py as fallback (removed src.config dependency). 704 tests passing.
This commit is contained in:
parent
9f20529f10
commit
b502bd8bdd
26 changed files with 188 additions and 9490 deletions
|
|
@ -236,24 +236,25 @@ async def catalog(
|
|||
enabled_datasets = settings_repo.get_enabled_datasets(user["id"])
|
||||
datasets = get_datasets()
|
||||
|
||||
# Build catalog data from config
|
||||
# Build catalog data from table_registry in DuckDB
|
||||
try:
|
||||
from src.config import get_config
|
||||
config = get_config()
|
||||
from src.repositories.table_registry import TableRegistryRepository
|
||||
table_repo = TableRegistryRepository(conn)
|
||||
registered = table_repo.list_all()
|
||||
tables = []
|
||||
for tc in config.tables:
|
||||
for tc in registered:
|
||||
table_data = {
|
||||
"id": tc.id,
|
||||
"name": tc.name,
|
||||
"description": tc.description,
|
||||
"dataset": getattr(tc, "dataset", None),
|
||||
"sync_strategy": tc.sync_strategy,
|
||||
"query_mode": getattr(tc, "query_mode", "local"),
|
||||
"profile": all_profiles.get(tc.id),
|
||||
"id": tc.get("id", ""),
|
||||
"name": tc.get("name", ""),
|
||||
"description": tc.get("description", ""),
|
||||
"dataset": tc.get("bucket"),
|
||||
"sync_strategy": tc.get("sync_strategy", "full_refresh"),
|
||||
"query_mode": tc.get("query_mode", "local"),
|
||||
"profile": all_profiles.get(tc.get("id", "")),
|
||||
}
|
||||
# Add sync state
|
||||
for state in all_states:
|
||||
if state["table_id"] == tc.id:
|
||||
if state["table_id"] == tc.get("id"):
|
||||
table_data["last_sync"] = state.get("last_sync")
|
||||
table_data["rows"] = state.get("rows")
|
||||
break
|
||||
|
|
|
|||
|
|
@ -1,665 +0,0 @@
|
|||
"""
|
||||
BigQuery data source adapter.
|
||||
|
||||
Implements the DataSource interface for Google BigQuery.
|
||||
Reads tables via the BigQuery API, converts directly to Parquet files
|
||||
using PyArrow (no CSV intermediate step).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta, date
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from src.config import get_config, TableConfig
|
||||
from src.data_sync import DataSource, SyncState, _get_uncompressed_size
|
||||
from src.parquet_manager import (
|
||||
convert_date_columns_to_date32,
|
||||
apply_schema_to_table,
|
||||
)
|
||||
from .client import create_client as create_bq_client
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BigQueryDataSource(DataSource):
|
||||
"""
|
||||
Data source: Google BigQuery.
|
||||
|
||||
Downloads data directly from BigQuery via PyArrow (no CSV step),
|
||||
writes to local Parquet files with schema enforcement.
|
||||
"""
|
||||
|
||||
def __init__(self):
    """Load the application config and create the BigQuery client used by every sync method."""
    # Config supplies the local Parquet / partition paths used by the sync strategies.
    self.config = get_config()
    # Client wrapper performs all reads and metadata lookups against BigQuery.
    self.bq_client = create_bq_client()
|
||||
|
||||
def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return BigQuery column metadata for schema generation.
|
||||
|
||||
Returns:
|
||||
{"columns": {"col_name": {"source_type": "...", "description": "..."}}}
|
||||
or None if metadata unavailable.
|
||||
"""
|
||||
raw = self.bq_client.get_table_metadata(table_id)
|
||||
column_types = raw.get("column_types", {})
|
||||
column_descriptions = raw.get("column_descriptions", {})
|
||||
|
||||
if not column_types:
|
||||
return None
|
||||
|
||||
result = {}
|
||||
for col_name, bq_type in column_types.items():
|
||||
entry = {"source_type": bq_type}
|
||||
if col_name in column_descriptions:
|
||||
entry["description"] = column_descriptions[col_name]
|
||||
result[col_name] = entry
|
||||
|
||||
return {"columns": result}
|
||||
|
||||
def discover_tables(self) -> List[Dict[str, Any]]:
|
||||
"""Discover all available tables from BigQuery."""
|
||||
return self.bq_client.discover_all_tables()
|
||||
|
||||
def get_source_name(self) -> str:
    """Return the human-readable label for this data source."""
    display_name = "Google BigQuery"
    return display_name
|
||||
|
||||
def sync_table(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Sync one BigQuery table with the strategy named in its config.

    Args:
        table_config: Table configuration; ``sync_strategy`` selects
            full_refresh, incremental or partitioned handling.
        sync_state: Sync state manager, updated on success.

    Returns:
        On success: ``{"success": True, "rows", "strategy", "file_size_mb"}``.
        On any exception: ``{"success": False, "error", "strategy"}`` --
        errors are logged and reported, never raised to the caller.
    """
    logger.info(f"Syncing BQ table: {table_config.name} ({table_config.sync_strategy})")
    strategy = table_config.sync_strategy

    # Drop any cached metadata so column types are fetched fresh for this run.
    cache = self.bq_client.metadata_cache
    if table_config.id in cache:
        del cache[table_config.id]
        logger.debug(f"Cleared BQ metadata cache for {table_config.id}")

    try:
        if strategy == "full_refresh":
            result = self._full_refresh(table_config)
        elif strategy == "incremental":
            result = self._incremental_sync(table_config, sync_state)
        elif strategy == "partitioned":
            result = self._partitioned_sync(table_config, sync_state)
        else:
            raise ValueError(f"Unknown sync strategy: {strategy}")

        # A partitioned run that updated nothing must not record a sync time:
        # leaving last_sync untouched lets the scheduler retry on its next
        # tick instead of treating stale data as done for the day.
        nothing_new = (
            strategy == "partitioned"
            and result.get("partitions_updated", -1) == 0
        )

        if not nothing_new:
            sync_state.update_sync(
                table_id=table_config.id,
                table_name=table_config.name,
                strategy=strategy,
                rows=result["rows"],
                file_size_bytes=result["file_size_bytes"],
                columns=result.get("columns", 0),
                uncompressed_bytes=result.get("uncompressed_bytes", 0),
            )
        else:
            logger.warning(
                f"Partitioned sync for {table_config.name} got 0 new partitions "
                f"- NOT updating last_sync (will retry next tick)"
            )

        return {
            "success": True,
            "rows": result["rows"],
            "strategy": strategy,
            "file_size_mb": result["file_size_bytes"] / 1024 / 1024,
        }

    except Exception as e:
        logger.error(f"Error syncing BQ table {table_config.name}: {e}")
        return {
            "success": False,
            "error": str(e),
            "strategy": table_config.sync_strategy,
        }
|
||||
|
||||
def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]:
    """
    Full refresh: stream the table from BigQuery into a single Parquet file.

    Each non-empty RecordBatch is converted (date columns, then the target
    schema when available) and appended through one ``ParquetWriter``, so
    the whole table is never held in memory at once.

    Args:
        table_config: Table configuration (id, name, columns, row_filter).

    Returns:
        Dict with ``rows``, ``columns``, ``file_size_bytes`` and
        ``uncompressed_bytes`` for the written file.

    Raises:
        Whatever the BigQuery read or Parquet write raises; the writer is
        closed before the exception propagates.
    """
    logger.info(f"Full refresh (streaming): {table_config.name}")

    parquet_path = self.config.get_parquet_path(table_config)
    date_columns = self.bq_client.get_date_columns(table_config.id)
    pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)

    parquet_path.parent.mkdir(parents=True, exist_ok=True)

    writer = None
    total_rows = 0
    num_columns = 0

    try:
        for batch in self.bq_client.read_table_streaming(
            table_config.id,
            columns=table_config.columns,
            row_filter=table_config.row_filter,
        ):
            if batch.num_rows == 0:
                continue

            # Convert batch to a table so the schema helpers can be applied.
            chunk = pa.Table.from_batches([batch])
            if date_columns:
                chunk = convert_date_columns_to_date32(chunk, date_columns)
            if pyarrow_schema:
                chunk = apply_schema_to_table(chunk, pyarrow_schema)

            # The writer is opened lazily: its schema comes from the first
            # converted chunk.
            if writer is None:
                writer = pq.ParquetWriter(
                    parquet_path, chunk.schema, compression="snappy",
                )
                num_columns = chunk.num_columns

            writer.write_table(chunk)
            total_rows += chunk.num_rows

            # Log progress roughly every 1M rows (whenever a chunk crosses
            # a million-row boundary).
            if total_rows % 1_000_000 < chunk.num_rows:
                logger.info(f" -> {total_rows:,} rows written...")
    finally:
        # Always release the file handle, including when streaming or
        # conversion fails part-way through.
        if writer is not None:
            writer.close()

    # NOTE(review): when no rows were streamed no writer is opened, so a file
    # already present at parquet_path is left untouched and its size is
    # reported below alongside rows=0 -- confirm this is intended.
    file_size = parquet_path.stat().st_size if parquet_path.exists() else 0
    logger.info(
        f"Full refresh complete: {total_rows:,} rows, "
        f"{file_size / 1024 / 1024:.2f} MB"
    )

    return {
        "rows": total_rows,
        "columns": num_columns,
        "file_size_bytes": file_size,
        "uncompressed_bytes": _get_uncompressed_size(parquet_path) if total_rows > 0 else 0,
    }
|
||||
|
||||
def _incremental_sync(
|
||||
self,
|
||||
table_config: TableConfig,
|
||||
sync_state: SyncState,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Incremental sync: dispatch to column-based or partition-based strategy.
|
||||
"""
|
||||
# If partition_by is set, use partitioned incremental
|
||||
if table_config.partition_by:
|
||||
return self._partitioned_sync(table_config, sync_state)
|
||||
|
||||
# If incremental_column is set, use timestamp-based incremental
|
||||
if table_config.incremental_column:
|
||||
return self._incremental_column_sync(table_config, sync_state)
|
||||
|
||||
# Fallback: full refresh (no incremental column configured)
|
||||
logger.warning(
|
||||
f"Table {table_config.name}: incremental strategy but no "
|
||||
f"incremental_column or partition_by configured, falling back to full refresh"
|
||||
)
|
||||
return self._full_refresh(table_config)
|
||||
|
||||
def _incremental_column_sync(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Incremental sync driven by ``incremental_column``.

    First sync (no recorded sync, or no local Parquet file): read the table
    -- limited to ``max_history_days`` when set -- and write it out.
    Later syncs: read rows newer than ``last_sync`` minus the incremental
    window (default 7 days), merge them into the existing file with new
    rows winning on the primary key, and rewrite the file.

    Returns:
        Dict with ``rows``, ``columns``, ``file_size_bytes`` and
        ``uncompressed_bytes`` of the resulting Parquet file.
    """
    logger.info(
        f"Incremental column sync: {table_config.name} "
        f"(column: {table_config.incremental_column})"
    )

    parquet_path = self.config.get_parquet_path(table_config)
    date_columns = self.bq_client.get_date_columns(table_config.id)
    pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)

    last_sync = sync_state.get_last_sync(table_config.id)
    has_baseline = bool(last_sync and parquet_path.exists())

    if not has_baseline:
        # First sync or no existing file -- read everything we are allowed to.
        logger.info(" -> First sync, reading all data")

        if table_config.max_history_days:
            since_dt = datetime.now() - timedelta(days=table_config.max_history_days)
            arrow_table = self.bq_client.read_table_incremental(
                table_id=table_config.id,
                incremental_column=table_config.incremental_column,
                since_value=since_dt.isoformat(),
                columns=table_config.columns,
            )
        else:
            arrow_table = self.bq_client.read_table(
                table_config.id,
                columns=table_config.columns,
                row_filter=table_config.row_filter,
            )

        # Schema enforcement before writing.
        if date_columns:
            arrow_table = convert_date_columns_to_date32(arrow_table, date_columns)
        if pyarrow_schema:
            arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema)

        parquet_path.parent.mkdir(parents=True, exist_ok=True)
        pq.write_table(arrow_table, parquet_path, compression="snappy")

        file_size = parquet_path.stat().st_size
        return {
            "rows": arrow_table.num_rows,
            "columns": arrow_table.num_columns,
            "file_size_bytes": file_size,
            "uncompressed_bytes": _get_uncompressed_size(parquet_path),
        }

    # Re-read a window before the last sync so late-arriving changes are picked up.
    last_sync_dt = datetime.fromisoformat(last_sync)
    window_days = table_config.incremental_window_days or 7
    since_dt = last_sync_dt - timedelta(days=window_days)
    since_value = since_dt.isoformat()

    logger.info(f" -> Since: {since_value} (window: {window_days} days)")

    new_data = self.bq_client.read_table_incremental(
        table_id=table_config.id,
        incremental_column=table_config.incremental_column,
        since_value=since_value,
        columns=table_config.columns,
    )

    if new_data.num_rows == 0:
        # Nothing changed: report the stats of the file already on disk.
        logger.info(" -> No new data since last sync")
        existing_pf = pq.ParquetFile(parquet_path)
        return {
            "rows": existing_pf.metadata.num_rows,
            "columns": len(existing_pf.schema_arrow),
            "file_size_bytes": parquet_path.stat().st_size,
            "uncompressed_bytes": _get_uncompressed_size(parquet_path),
        }

    logger.info(f" -> Merging {new_data.num_rows} new rows with existing data")
    existing_table = pq.read_table(parquet_path)
    merged = self._merge_arrow_tables(
        existing_table, new_data, table_config.get_primary_key_columns()
    )

    # Schema enforcement on the merged result.
    if date_columns:
        merged = convert_date_columns_to_date32(merged, date_columns)
    if pyarrow_schema:
        merged = apply_schema_to_table(merged, pyarrow_schema)

    pq.write_table(merged, parquet_path, compression="snappy")

    file_size = parquet_path.stat().st_size
    logger.info(
        f" -> Incremental sync complete: {merged.num_rows} total rows"
    )

    return {
        "rows": merged.num_rows,
        "columns": merged.num_columns,
        "file_size_bytes": file_size,
        "uncompressed_bytes": _get_uncompressed_size(parquet_path),
    }
|
||||
|
||||
def _partitioned_sync(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Sync a table one partition period at a time.

    The partition column is ``partition_by``, falling back to
    ``incremental_column``; with neither set a full refresh is done instead.
    The date range starts at ``last_sync`` minus the incremental window
    (default 7 days), or on a first sync at ``max_history_days`` ago
    (default 365 days), and ends today. Each period is synced through
    ``_sync_single_partition``; afterwards old partitions are cleaned up.

    Returns:
        Totals from ``_get_partition_totals`` plus ``partitions_updated``
        (number of periods that produced rows in this run).
    """
    partition_col = table_config.partition_by or table_config.incremental_column

    if not partition_col:
        logger.warning(
            f"Table {table_config.name}: partitioned strategy but no "
            f"partition_by or incremental_column, falling back to full refresh"
        )
        return self._full_refresh(table_config)

    granularity = table_config.partition_granularity or "day"
    column_type = table_config.partition_column_type
    logger.info(
        f"Partitioned sync: {table_config.name} "
        f"(by {partition_col}, {granularity}, type={column_type})"
    )

    partition_dir = self.config.get_parquet_path(table_config)
    date_columns = self.bq_client.get_date_columns(table_config.id)
    pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)

    # Work out the first date to (re)process.
    last_sync = sync_state.get_last_sync(table_config.id)
    today = date.today()

    if last_sync:
        last_sync_dt = datetime.fromisoformat(last_sync)
        window_days = table_config.incremental_window_days or 7
        start_date = (last_sync_dt - timedelta(days=window_days)).date()
        logger.info(f" -> Incremental sync from {start_date} (window: {window_days} days)")
    elif table_config.max_history_days:
        start_date = today - timedelta(days=table_config.max_history_days)
        logger.info(f" -> First sync, last {table_config.max_history_days} days from {start_date}")
    else:
        start_date = today - timedelta(days=365)
        logger.info(" -> First sync, no max_history_days, defaulting to 365 days")

    partition_dates = self._generate_partition_dates(start_date, today, granularity)
    logger.info(f" -> Processing {len(partition_dates)} partitions")

    total_rows = 0
    partitions_updated = 0

    for partition_date in partition_dates:
        rows = self._sync_single_partition(
            table_config=table_config,
            partition_col=partition_col,
            partition_date=partition_date,
            partition_dir=partition_dir,
            date_columns=date_columns,
            pyarrow_schema=pyarrow_schema,
            granularity=granularity,
            column_type=column_type,
        )
        if rows <= 0:
            # Period had no data in BigQuery; nothing was written.
            continue
        partitions_updated += 1
        total_rows += rows

    # Drop partitions that fell out of the retention window.
    deleted = self._cleanup_old_partitions(table_config, partition_dir, granularity)
    if deleted > 0:
        logger.info(f" -> Cleaned up {deleted} old partition files")

    logger.info(
        f" -> Partitioned sync complete: {partitions_updated} partitions updated, "
        f"{total_rows} total rows processed"
    )

    result = self._get_partition_totals(partition_dir)
    result["partitions_updated"] = partitions_updated
    return result
|
||||
|
||||
@staticmethod
|
||||
def _generate_partition_dates(
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
granularity: str,
|
||||
) -> List[date]:
|
||||
"""Generate list of partition start dates between start and end."""
|
||||
dates = []
|
||||
current = start_date
|
||||
|
||||
if granularity == "day":
|
||||
while current <= end_date:
|
||||
dates.append(current)
|
||||
current += timedelta(days=1)
|
||||
elif granularity == "month":
|
||||
# Align to first of month
|
||||
current = current.replace(day=1)
|
||||
while current <= end_date:
|
||||
dates.append(current)
|
||||
# Move to first of next month
|
||||
if current.month == 12:
|
||||
current = current.replace(year=current.year + 1, month=1)
|
||||
else:
|
||||
current = current.replace(month=current.month + 1)
|
||||
elif granularity == "year":
|
||||
current = current.replace(month=1, day=1)
|
||||
while current <= end_date:
|
||||
dates.append(current)
|
||||
current = current.replace(year=current.year + 1)
|
||||
|
||||
return dates
|
||||
|
||||
def _sync_single_partition(
    self,
    table_config: TableConfig,
    partition_col: str,
    partition_date: date,
    partition_dir: Path,
    date_columns: List[str],
    pyarrow_schema,
    granularity: str,
    column_type: str,
) -> int:
    """
    Sync one partition period from BigQuery into its own Parquet file.

    Reads rows in ``[start, end)`` for the period beginning at
    ``partition_date``, applies schema conversions, merges with the existing
    partition file (new rows win on the primary key) and rewrites the file.

    Args:
        table_config: Table configuration (id, columns, primary key).
        partition_col: Column used to select the period.
        partition_date: First day of the period.
        partition_dir: Directory holding the partition files (the actual
            file path comes from ``config.get_partition_path``).
        date_columns: Columns to convert to date32, may be empty.
        pyarrow_schema: Target schema, or a falsy value to skip enforcement.
        granularity: ``"day"``, ``"month"`` or ``"year"``.
        column_type: Passed through to the BigQuery read.

    Returns:
        Row count of the partition file after the merge, or 0 when
        BigQuery returned no rows (the file is then left untouched).

    Raises:
        ValueError: If ``granularity`` is not day/month/year.
    """
    # Calculate partition range [start, end) and the file key.
    start = partition_date
    if granularity == "day":
        end = start + timedelta(days=1)
        partition_key = start.strftime("%Y_%m_%d")
    elif granularity == "month":
        if start.month == 12:
            end = start.replace(year=start.year + 1, month=1)
        else:
            end = start.replace(month=start.month + 1)
        partition_key = start.strftime("%Y_%m")
    elif granularity == "year":
        end = start.replace(year=start.year + 1)
        partition_key = start.strftime("%Y")
    else:
        raise ValueError(f"Unknown granularity: {granularity}")

    partition_path = self.config.get_partition_path(table_config, partition_key)

    # Collect the batches for this single period.
    batches = list(
        self.bq_client.read_table_partitioned_streaming(
            table_id=table_config.id,
            partition_column=partition_col,
            start=start.isoformat(),
            end=end.isoformat(),
            columns=table_config.columns,
            column_type=column_type,
        )
    )

    if not batches:
        return 0

    new_data = pa.Table.from_batches(batches)
    if new_data.num_rows == 0:
        return 0

    # Apply schema conversions to the fresh rows.
    if date_columns:
        new_data = convert_date_columns_to_date32(new_data, date_columns)
    if pyarrow_schema:
        new_data = apply_schema_to_table(new_data, pyarrow_schema)

    primary_key_cols = table_config.get_primary_key_columns()

    if partition_path.exists():
        existing = pq.read_table(partition_path)
        merged = self._merge_arrow_tables(existing, new_data, primary_key_cols)
        # The merge round-trips through pandas, which can change column
        # types; re-apply the conversions so the written file matches the
        # target schema, as the column-based incremental sync does.
        if date_columns:
            merged = convert_date_columns_to_date32(merged, date_columns)
        if pyarrow_schema:
            merged = apply_schema_to_table(merged, pyarrow_schema)
    else:
        merged = new_data

    # Make sure the target directory exists before writing, like the other
    # sync strategies do.
    partition_path.parent.mkdir(parents=True, exist_ok=True)
    pq.write_table(merged, partition_path, compression="snappy")
    row_count = merged.num_rows

    logger.debug(
        f" Partition {partition_key}: {new_data.num_rows} new rows, "
        f"{row_count} total after merge"
    )

    # Release references to the large tables before the next partition.
    del batches, new_data, merged

    return row_count
|
||||
|
||||
def _cleanup_old_partitions(
|
||||
self,
|
||||
table_config: TableConfig,
|
||||
partition_dir: Path,
|
||||
granularity: str,
|
||||
) -> int:
|
||||
"""
|
||||
Delete partition files older than max_history_days.
|
||||
|
||||
Returns count of deleted files.
|
||||
"""
|
||||
if not table_config.max_history_days:
|
||||
return 0
|
||||
|
||||
if not partition_dir.exists():
|
||||
return 0
|
||||
|
||||
cutoff_date = date.today() - timedelta(days=table_config.max_history_days)
|
||||
deleted = 0
|
||||
|
||||
for part_path in partition_dir.glob("*.parquet"):
|
||||
try:
|
||||
partition_date = self._parse_partition_date(part_path.stem, granularity)
|
||||
if partition_date and partition_date < cutoff_date:
|
||||
part_path.unlink()
|
||||
deleted += 1
|
||||
logger.debug(f" Deleted old partition: {part_path.name}")
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f" Skipping unrecognized partition file: {part_path.name}")
|
||||
|
||||
return deleted
|
||||
|
||||
@staticmethod
|
||||
def _parse_partition_date(partition_key: str, granularity: str) -> Optional[date]:
|
||||
"""Parse a partition key back to a date."""
|
||||
try:
|
||||
if granularity == "day":
|
||||
return datetime.strptime(partition_key, "%Y_%m_%d").date()
|
||||
elif granularity == "month":
|
||||
return datetime.strptime(partition_key, "%Y_%m").date()
|
||||
elif granularity == "year":
|
||||
return datetime.strptime(partition_key, "%Y").date()
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _merge_arrow_tables(
    self,
    existing: pa.Table,
    new_data: pa.Table,
    primary_key: List[str],
) -> pa.Table:
    """
    Combine existing and new rows, keeping one row per primary key.

    Both tables are converted to pandas, concatenated with the new rows
    last, and de-duplicated with ``keep="last"`` so a new row replaces an
    existing row that shares its primary key.

    Args:
        existing: Rows already stored.
        new_data: New or changed rows.
        primary_key: Column names used as the duplicate key.

    Returns:
        The merged data as a PyArrow Table (without the pandas index).
    """
    import pandas as pd

    # Order matters: new rows go last so keep="last" prefers them.
    frames = [existing.to_pandas(), new_data.to_pandas()]
    combined = pd.concat(frames, ignore_index=True)
    deduplicated = combined.drop_duplicates(subset=primary_key, keep="last")

    return pa.Table.from_pandas(deduplicated, preserve_index=False)
|
||||
|
||||
def _get_partition_totals(self, partition_dir: Path) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate totals from all partition files in a directory.
|
||||
"""
|
||||
total_rows = 0
|
||||
total_size = 0
|
||||
total_uncompressed = 0
|
||||
total_columns = 0
|
||||
|
||||
if not partition_dir.exists():
|
||||
return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0}
|
||||
|
||||
all_partitions = list(partition_dir.glob("*.parquet"))
|
||||
|
||||
for part_path in all_partitions:
|
||||
try:
|
||||
pf = pq.ParquetFile(part_path)
|
||||
meta = pf.metadata
|
||||
total_rows += meta.num_rows
|
||||
total_size += part_path.stat().st_size
|
||||
if total_columns == 0:
|
||||
total_columns = len(pf.schema_arrow)
|
||||
for rg_idx in range(meta.num_row_groups):
|
||||
rg = meta.row_group(rg_idx)
|
||||
for col_idx in range(rg.num_columns):
|
||||
total_uncompressed += rg.column(col_idx).total_uncompressed_size
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping corrupt partition {part_path.name}: {e}")
|
||||
|
||||
return {
|
||||
"rows": total_rows,
|
||||
"file_size_bytes": total_size,
|
||||
"partitions": len(all_partitions),
|
||||
"columns": total_columns,
|
||||
"uncompressed_bytes": total_uncompressed,
|
||||
}
|
||||
|
||||
|
||||
def create_data_source() -> BigQueryDataSource:
    """Build and return a ``BigQueryDataSource`` (module-level factory for dynamic imports)."""
    source = BigQueryDataSource()
    return source
|
||||
|
|
@ -1,644 +0,0 @@
|
|||
"""
|
||||
Google BigQuery API Client
|
||||
|
||||
Low-level wrapper for Google BigQuery with these functions:
|
||||
1. Authentication using Application Default Credentials (ADC)
|
||||
2. Query tables to PyArrow (no CSV intermediate step)
|
||||
3. Get table metadata (schema, columns, data types)
|
||||
4. Cache metadata for faster repeated use
|
||||
5. Incremental reads (timestamp-based and partition-based)
|
||||
|
||||
Uses google-cloud-bigquery with native PyArrow support.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pyarrow as pa
|
||||
from google.cloud import bigquery
|
||||
|
||||
try:
|
||||
from google.cloud import bigquery_storage_v1
|
||||
|
||||
_HAS_BQ_STORAGE = True
|
||||
except ImportError:
|
||||
_HAS_BQ_STORAGE = False
|
||||
|
||||
from src.config import get_config
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping BigQuery types to PyArrow types.
# Keys are BigQuery field type names (both legacy and standard spellings,
# e.g. INTEGER / INT64); values are the PyArrow type used for that column.
BIGQUERY_TO_PYARROW_TYPES = {
    "STRING": pa.string(),
    "BYTES": pa.binary(),
    "INTEGER": pa.int64(),
    "INT64": pa.int64(),
    "FLOAT": pa.float64(),
    "FLOAT64": pa.float64(),
    # NOTE(review): NUMERIC / BIGNUMERIC are stored as float64 here --
    # presumably lossy for high-precision decimals; confirm this is acceptable.
    "NUMERIC": pa.float64(),
    "BIGNUMERIC": pa.float64(),
    "BOOLEAN": pa.bool_(),
    "BOOL": pa.bool_(),
    # TIMESTAMP is timezone-aware (UTC); DATETIME below has no timezone.
    "TIMESTAMP": pa.timestamp("us", tz="UTC"),
    "DATE": pa.date32(),
    "TIME": pa.string(),
    "DATETIME": pa.timestamp("us"),
    # The remaining types, including nested ones, are all mapped to strings.
    "GEOGRAPHY": pa.string(),
    "JSON": pa.string(),
    "STRUCT": pa.string(),
    "RECORD": pa.string(),
    "ARRAY": pa.string(),
}
|
||||
|
||||
|
||||
class BigQueryClient:
|
||||
"""
|
||||
Wrapper for Google BigQuery API.
|
||||
|
||||
Provides high-level methods for working with BigQuery tables:
|
||||
- Query tables to PyArrow Tables (no CSV step)
|
||||
- Get metadata (schema, columns)
|
||||
- Incremental and partitioned reads
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project_id: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize BigQuery client.
|
||||
|
||||
Args:
|
||||
project_id: GCP project ID for job execution/billing.
|
||||
If None, reads from BIGQUERY_PROJECT env var.
|
||||
location: BigQuery location for job execution (e.g., "us-central1").
|
||||
If None, reads from BIGQUERY_LOCATION env var.
|
||||
|
||||
Raises:
|
||||
ValueError: If project_id is not provided and BIGQUERY_PROJECT is not set.
|
||||
"""
|
||||
self.project_id = project_id or os.environ.get("BIGQUERY_PROJECT")
|
||||
|
||||
if not self.project_id:
|
||||
raise ValueError(
|
||||
"BigQuery project ID not set. "
|
||||
"Set BIGQUERY_PROJECT environment variable."
|
||||
)
|
||||
|
||||
self.location = location or os.environ.get("BIGQUERY_LOCATION")
|
||||
|
||||
# Initialize BigQuery client with ADC
|
||||
# project_id is used for job execution and billing.
|
||||
# Data can live in a different project -- table IDs in queries
|
||||
# use fully-qualified format (project.dataset.table).
|
||||
client_kwargs = {"project": self.project_id}
|
||||
if self.location:
|
||||
client_kwargs["location"] = self.location
|
||||
self.client = bigquery.Client(**client_kwargs)
|
||||
|
||||
# BQ Storage API client for fast parallel reads (gRPC streams).
|
||||
# Without explicit bqstorage_client, to_arrow_iterable() silently
|
||||
# falls back to slow REST API pagination (~5K rows/sec vs ~300K rows/sec).
|
||||
if _HAS_BQ_STORAGE:
|
||||
try:
|
||||
self.bqstorage_client = bigquery_storage_v1.BigQueryReadClient()
|
||||
logger.info("BQ Storage API client initialized (fast parallel gRPC reads)")
|
||||
except Exception as e:
|
||||
self.bqstorage_client = None
|
||||
logger.warning(f"BQ Storage API client failed to initialize: {e}")
|
||||
else:
|
||||
self.bqstorage_client = None
|
||||
logger.info("BQ Storage API not available (install google-cloud-bigquery-storage)")
|
||||
|
||||
# Metadata cache
|
||||
config = get_config()
|
||||
self.metadata_cache: Dict[str, Dict[str, Any]] = {}
|
||||
self.metadata_cache_path = config.get_metadata_path() / "bq_table_metadata.json"
|
||||
|
||||
# Load cache from disk if exists
|
||||
self._load_metadata_cache()
|
||||
|
||||
logger.info(
|
||||
f"BigQuery client initialized: project={self.project_id}, "
|
||||
f"location={self.location or 'auto'}"
|
||||
)
|
||||
|
||||
def _load_metadata_cache(self):
    """
    Load the metadata cache from ``metadata_cache_path`` if the file exists.

    Loading is best-effort: any failure (unreadable file, invalid JSON, or
    JSON that is not an object) is logged as a warning and the cache is
    reset to an empty dict. A missing file leaves the cache unchanged.
    """
    if not self.metadata_cache_path.exists():
        return

    try:
        with open(self.metadata_cache_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
        # The cache is looked up by table ID, so anything but a JSON object
        # (e.g. a list or a bare string) would break later key lookups.
        if not isinstance(loaded, dict):
            raise ValueError(
                f"expected a JSON object, got {type(loaded).__name__}"
            )
        self.metadata_cache = loaded
        logger.info(
            f"BQ metadata cache loaded: {len(self.metadata_cache)} tables"
        )
    except Exception as e:
        # Deliberately broad: a bad cache file must never stop the client.
        logger.warning(f"Error loading BQ metadata cache: {e}")
        self.metadata_cache = {}
|
||||
|
||||
def _save_metadata_cache(self):
    """
    Write the metadata cache to ``metadata_cache_path`` as indented JSON.

    Saving is best-effort: failures are logged as a warning and not raised.
    The cache is serialised before the file is opened, so a value that
    cannot be serialised leaves the previous cache file intact instead of
    truncating it.
    """
    try:
        # Serialise first: opening with "w" truncates the file immediately.
        payload = json.dumps(self.metadata_cache, indent=2)
        self.metadata_cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.metadata_cache_path, "w", encoding="utf-8") as f:
            f.write(payload)
        logger.debug("BQ metadata cache saved")
    except Exception as e:
        # Deliberately broad: failing to persist the cache is not fatal.
        logger.warning(f"Error saving BQ metadata cache: {e}")
|
||||
|
||||
def get_table_metadata(
    self,
    table_id: str,
    use_cache: bool = True,
    cache_ttl_hours: int = 24,
) -> Dict[str, Any]:
    """
    Get table metadata from BigQuery.

    A cached entry is returned as-is when it is younger than
    ``cache_ttl_hours``. Otherwise the table is fetched through
    ``self.client.get_table``, the result is stored in
    ``self.metadata_cache`` and the cache is saved to disk.

    Args:
        table_id: Full table ID (e.g., "project.dataset.table")
        use_cache: Use cache if available
        cache_ttl_hours: Cache TTL in hours (default 24h)

    Returns:
        Dictionary with metadata including columns, types, descriptions, row count.

    Raises:
        Exception: Any error raised while fetching or building the metadata
            is logged and re-raised unchanged.
    """
    # Check cache
    if use_cache and table_id in self.metadata_cache:
        cached = self.metadata_cache[table_id]
        try:
            cached_time = datetime.fromisoformat(cached.get("_cached_at", "2000-01-01"))
            is_fresh = datetime.now() - cached_time < timedelta(hours=cache_ttl_hours)
        except (AttributeError, TypeError, ValueError):
            # The cache comes from a JSON file on disk; an entry whose
            # timestamp cannot be parsed or compared is treated as expired
            # so the metadata is refetched instead of failing the call.
            is_fresh = False

        if is_fresh:
            logger.debug(f"Using BQ metadata cache for {table_id}")
            return cached

    logger.info(f"Fetching metadata for BQ table: {table_id}")

    try:
        table_ref = self.client.get_table(table_id)

        # Build column metadata
        columns = []
        column_types = {}
        column_descriptions = {}
        for field in table_ref.schema:
            columns.append(field.name)
            column_types[field.name] = field.field_type
            if field.description:
                column_descriptions[field.name] = field.description

        metadata = {
            "table_id": table_id,
            "name": table_ref.table_id,
            "dataset": table_ref.dataset_id,
            "project": table_ref.project,
            "columns": columns,
            "column_types": column_types,
            "column_descriptions": column_descriptions,
            "row_count": table_ref.num_rows,
            "size_bytes": table_ref.num_bytes,
            "created": table_ref.created.isoformat() if table_ref.created else None,
            "modified": table_ref.modified.isoformat() if table_ref.modified else None,
            "partitioning": None,
            "_cached_at": datetime.now().isoformat(),
        }

        # Capture partitioning info
        if table_ref.time_partitioning:
            metadata["partitioning"] = {
                "type": table_ref.time_partitioning.type_,
                "field": table_ref.time_partitioning.field,
                "expiration_ms": table_ref.time_partitioning.expiration_ms,
            }

        # Save to cache
        self.metadata_cache[table_id] = metadata
        self._save_metadata_cache()

        return metadata

    except Exception as e:
        logger.error(f"Error getting metadata for {table_id}: {e}")
        raise
|
||||
|
||||
def get_pyarrow_schema(self, table_id: str) -> Optional[pa.Schema]:
    """
    Derive a PyArrow schema from the table's BigQuery column types.

    Columns are taken in the order of the metadata's "columns" list. A
    column with no recorded type is looked up as "STRING", and a BigQuery
    type missing from BIGQUERY_TO_PYARROW_TYPES maps to ``pa.string()``.

    Args:
        table_id: Full table ID

    Returns:
        PyArrow schema, or None when the metadata has no column types
    """
    meta = self.get_table_metadata(table_id)
    type_by_column = meta.get("column_types", {})

    if not type_by_column:
        logger.warning(f"No column types for {table_id}, schema will not be applied")
        return None

    return pa.schema([
        pa.field(
            name,
            BIGQUERY_TO_PYARROW_TYPES.get(type_by_column.get(name, "STRING"), pa.string()),
        )
        for name in meta.get("columns", [])
    ])
|
||||
|
||||
def get_date_columns(self, table_id: str) -> List[str]:
    """
    Return the names of columns whose BigQuery type is exactly "DATE".

    Args:
        table_id: Full table ID

    Returns:
        Column names with DATE type, in the order they appear in the
        metadata's "column_types" mapping
    """
    type_by_column = self.get_table_metadata(table_id).get("column_types", {})

    date_columns = []
    for name, bq_type in type_by_column.items():
        if bq_type == "DATE":
            date_columns.append(name)
    return date_columns
|
||||
|
||||
def query_to_arrow(
    self,
    sql: str,
    params: Optional[List[bigquery.ScalarQueryParameter]] = None,
) -> pa.Table:
    """
    Run a SQL query and return the whole result as a PyArrow Table.

    ``self.bqstorage_client`` is passed to ``to_arrow`` when it is set. If
    the read fails with an error mentioning "readsessions" or
    "PERMISSION_DENIED", the read is retried with
    ``create_bqstorage_client=False``; any other error propagates.

    Args:
        sql: SQL query string (use @param_name for parameterized values)
        params: List of BigQuery query parameters

    Returns:
        PyArrow Table with query results
    """
    job_config = bigquery.QueryJobConfig()
    if params:
        job_config.query_parameters = params

    logger.debug(f"Executing BQ query: {sql[:200]}...")

    query_job = self.client.query(sql, job_config=job_config)

    try:
        read_kwargs = {}
        if self.bqstorage_client:
            read_kwargs["bqstorage_client"] = self.bqstorage_client
        arrow_table = query_job.to_arrow(**read_kwargs)
    except Exception as storage_err:
        message = str(storage_err)
        if "readsessions" not in message and "PERMISSION_DENIED" not in message:
            raise
        logger.warning(
            "BQ Storage API unavailable (missing readsessions permission), "
            "falling back to REST API"
        )
        arrow_table = query_job.to_arrow(create_bqstorage_client=False)

    logger.debug(f"Query returned {arrow_table.num_rows} rows, {arrow_table.num_columns} columns")
    return arrow_table
|
||||
|
||||
def query_to_arrow_batches(
    self,
    sql: str,
    params: Optional[List[bigquery.ScalarQueryParameter]] = None,
):
    """
    Execute SQL query and yield results as streaming RecordBatches.

    Unlike query_to_arrow(), this does NOT load entire result into memory.

    Only the first batch is fetched under the fallback guard. If that probe
    fails with an error mentioning "readsessions" or "PERMISSION_DENIED",
    the query is executed again and read with
    ``create_bqstorage_client=False``. Errors raised after batches have
    already been yielded always propagate: re-running the query at that
    point would hand the consumer the same rows a second time.

    Args:
        sql: SQL query string (use @param_name for parameterized values)
        params: List of BigQuery query parameters

    Yields:
        pyarrow.RecordBatch objects
    """
    job_config = bigquery.QueryJobConfig()
    if params:
        job_config.query_parameters = params

    logger.debug(f"Executing BQ query (streaming): {sql[:200]}...")

    query_job = self.client.query(sql, job_config=job_config)

    # result() returns RowIterator which has to_arrow_iterable()
    # (QueryJob itself only has to_arrow(), not to_arrow_iterable())
    row_iter = query_job.result()

    use_rest_fallback = False
    batch_iter = None
    first_batch = None
    try:
        # Pass the storage client explicitly when we have one, then pull one
        # batch so permission errors surface before anything is yielded.
        storage_kwargs = {}
        if self.bqstorage_client:
            storage_kwargs["bqstorage_client"] = self.bqstorage_client
        batch_iter = row_iter.to_arrow_iterable(**storage_kwargs)
        first_batch = next(batch_iter, None)
    except Exception as storage_err:
        message = str(storage_err)
        if "readsessions" not in message and "PERMISSION_DENIED" not in message:
            raise
        logger.warning(
            "BQ Storage API unavailable (missing readsessions permission), "
            "falling back to REST API (streaming)"
        )
        use_rest_fallback = True

    if use_rest_fallback:
        # Fallback: REST API streaming (re-execute query for fresh RowIterator)
        row_iter = self.client.query(sql, job_config=job_config).result()
        yield from row_iter.to_arrow_iterable(create_bqstorage_client=False)
        return

    # Yield outside the try block so that mid-stream failures are never
    # mistaken for a probe failure.
    if first_batch is not None:
        yield first_batch
    yield from batch_iter
|
||||
|
||||
def read_table_streaming(
    self,
    table_id: str,
    columns: Optional[List[str]] = None,
    row_filter: Optional[str] = None,
):
    """
    Stream a table (or a filtered subset) as RecordBatches.

    Builds a SELECT statement and delegates to query_to_arrow_batches().
    ``row_filter`` is appended to the SQL verbatim.

    Args:
        table_id: Full table ID (e.g., "project.dataset.table")
        columns: Optional list of columns to select
        row_filter: Optional SQL WHERE clause (without WHERE keyword)

    Yields:
        pyarrow.RecordBatch objects
    """
    projection = "*" if not columns else ", ".join(f"`{c}`" for c in columns)
    where_clause = f" WHERE {row_filter}" if row_filter else ""
    sql = f"SELECT {projection} FROM `{table_id}`{where_clause}"

    logger.info(
        f"Streaming BQ table: {table_id} "
        f"(filter: {row_filter or 'none'}, "
        f"storage_api={'yes' if self.bqstorage_client else 'no'})"
    )
    yield from self.query_to_arrow_batches(sql)
|
||||
|
||||
def read_table(
    self,
    table_id: str,
    columns: Optional[List[str]] = None,
    row_filter: Optional[str] = None,
) -> pa.Table:
    """
    Load a table (or a filtered subset) into a single PyArrow Table.

    Builds a SELECT statement and delegates to query_to_arrow().
    ``row_filter`` is appended to the SQL verbatim.

    Args:
        table_id: Full table ID (e.g., "project.dataset.table")
        columns: Optional list of columns to select
        row_filter: Optional SQL WHERE clause (without WHERE keyword)

    Returns:
        PyArrow Table with table data
    """
    projection = "*" if not columns else ", ".join(f"`{c}`" for c in columns)
    where_clause = f" WHERE {row_filter}" if row_filter else ""
    sql = f"SELECT {projection} FROM `{table_id}`{where_clause}"

    logger.info(f"Reading BQ table: {table_id} (filter: {row_filter or 'none'})")
    return self.query_to_arrow(sql)
|
||||
|
||||
def read_table_incremental(
    self,
    table_id: str,
    incremental_column: str,
    since_value: str,
    columns: Optional[List[str]] = None,
    column_type: str = "TIMESTAMP",
) -> pa.Table:
    """
    Read rows where incremental_column > since_value.

    The comparison value is passed as a query parameter; table and column
    names are interpolated into the SQL text.

    Args:
        table_id: Full table ID
        incremental_column: Column name for incremental filter
        since_value: Lower bound (exclusive) - fetch rows after this value
        columns: Optional list of columns to select
        column_type: BQ SQL type used for the ``since_value`` parameter.
            Defaults to "TIMESTAMP", which was previously the only
            supported type; pass e.g. "DATE" or "DATETIME" to match the
            type of ``incremental_column``.

    Returns:
        PyArrow Table with incremental data
    """
    select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*"

    sql = (
        f"SELECT {select_cols} FROM `{table_id}` "
        f"WHERE `{incremental_column}` > @since_value"
    )

    params = [
        bigquery.ScalarQueryParameter("since_value", column_type, since_value),
    ]

    logger.info(
        f"Incremental read: {table_id} WHERE {incremental_column} > {since_value}"
    )
    return self.query_to_arrow(sql, params=params)
|
||||
|
||||
def read_table_partitioned(
    self,
    table_id: str,
    partition_column: str,
    start: str,
    end: Optional[str] = None,
    columns: Optional[List[str]] = None,
    column_type: str = "TIMESTAMP",
) -> pa.Table:
    """
    Load the rows of a partition range into a single PyArrow Table.

    ``start`` and ``end`` are passed as query parameters typed with
    ``column_type``; the upper bound is only added when ``end`` is truthy.

    Args:
        table_id: Full table ID
        partition_column: Partition column name
        start: Start date/timestamp (inclusive)
        end: End date/timestamp (exclusive). If None, no upper bound is applied.
        columns: Optional list of columns to select
        column_type: BQ SQL type for the partition column ("DATE", "TIMESTAMP", "DATETIME")

    Returns:
        PyArrow Table with partition range data
    """
    projection = "*" if not columns else ", ".join(f"`{c}`" for c in columns)

    conditions = [f"`{partition_column}` >= @start_value"]
    params = [
        bigquery.ScalarQueryParameter("start_value", column_type, start),
    ]
    if end:
        conditions.append(f"`{partition_column}` < @end_value")
        params.append(
            bigquery.ScalarQueryParameter("end_value", column_type, end),
        )

    sql = f"SELECT {projection} FROM `{table_id}` WHERE " + " AND ".join(conditions)

    logger.info(
        f"Partitioned read: {table_id} [{start} .. {end or 'now'})"
    )
    return self.query_to_arrow(sql, params=params)
|
||||
|
||||
def read_table_partitioned_streaming(
    self,
    table_id: str,
    partition_column: str,
    start: str,
    end: Optional[str] = None,
    columns: Optional[List[str]] = None,
    column_type: str = "TIMESTAMP",
):
    """
    Stream the rows of a partition range as RecordBatches.

    Builds the same range query as read_table_partitioned() but delegates
    to query_to_arrow_batches(), so batches are yielded one at a time
    instead of being collected into one table.

    Args:
        table_id: Full table ID
        partition_column: Partition column name
        start: Start date/timestamp (inclusive)
        end: End date/timestamp (exclusive). If None, no upper bound is applied.
        columns: Optional list of columns to select
        column_type: BQ SQL type for the partition column ("DATE", "TIMESTAMP", "DATETIME")

    Yields:
        pyarrow.RecordBatch objects
    """
    projection = "*" if not columns else ", ".join(f"`{c}`" for c in columns)

    conditions = [f"`{partition_column}` >= @start_value"]
    params = [
        bigquery.ScalarQueryParameter("start_value", column_type, start),
    ]
    if end:
        conditions.append(f"`{partition_column}` < @end_value")
        params.append(
            bigquery.ScalarQueryParameter("end_value", column_type, end),
        )

    sql = f"SELECT {projection} FROM `{table_id}` WHERE " + " AND ".join(conditions)

    logger.info(
        f"Partitioned streaming read: {table_id} [{start} .. {end or 'now'})"
    )
    yield from self.query_to_arrow_batches(sql, params=params)
|
||||
|
||||
def discover_all_tables(self, dataset_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Enumerate tables in one dataset or in every dataset the client lists.

    Datasets whose tables cannot be listed, and tables whose details cannot
    be fetched, are skipped with a warning.

    Args:
        dataset_id: Optional dataset ID to limit scope

    Returns:
        List of table dicts with keys id, name, bucket_id, bucket_name,
        columns, row_count, size_bytes, primary_key, last_change, last_import.
    """
    logger.info(f"Discovering BQ tables (dataset={dataset_id or 'all'})...")

    if dataset_id:
        datasets = [self.client.get_dataset(dataset_id)]
    else:
        datasets = list(self.client.list_datasets())

    discovered = []
    for dataset in datasets:
        ds_ref = dataset.reference if hasattr(dataset, "reference") else dataset.dataset_id
        ds_id = str(ds_ref)

        try:
            listed = list(self.client.list_tables(ds_ref))
        except Exception as e:
            logger.warning(f"Could not list tables in dataset {ds_id}: {e}")
            continue

        for item in listed:
            full_id = f"{item.project}.{item.dataset_id}.{item.table_id}"

            try:
                detail = self.client.get_table(full_id)
                column_names = [f.name for f in detail.schema]
                modified = detail.modified.isoformat() if detail.modified else None

                entry = {
                    "id": full_id,
                    "name": item.table_id,
                    "bucket_id": item.dataset_id,
                    "bucket_name": item.dataset_id,
                    "columns": column_names,
                    "row_count": detail.num_rows or 0,
                    "size_bytes": detail.num_bytes or 0,
                    "primary_key": [],
                    "last_change": modified,
                    "last_import": None,
                }
                discovered.append(entry)
            except Exception as e:
                logger.warning(f"Could not get details for {full_id}: {e}")

    logger.info(f"Discovered {len(discovered)} BQ tables")
    return discovered
|
||||
|
||||
def test_connection(self) -> bool:
    """
    Check connectivity by running ``SELECT 1`` and consuming its result.

    Returns:
        True if the query completes, False if any exception is raised
        (the error is logged)
    """
    try:
        probe_job = self.client.query("SELECT 1")
        list(probe_job.result())
        logger.info(f"BigQuery connection OK (project: {self.project_id})")
    except Exception as e:
        logger.error(f"BigQuery connection test failed: {e}")
        return False
    return True
|
||||
|
||||
|
||||
def create_client() -> BigQueryClient:
    """
    Build a BigQueryClient with its default constructor arguments.

    Returns:
        A new BigQueryClient instance
    """
    client = BigQueryClient()
    return client
|
||||
|
|
@ -1,820 +0,0 @@
|
|||
"""
|
||||
Keboola data source adapter.
|
||||
|
||||
Implements the DataSource interface for Keboola Storage API.
|
||||
Downloads tables via the Storage API, converts CSV exports to Parquet files
|
||||
with full type metadata from Keboola column metadata.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pyarrow as pa
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.config import get_config, TableConfig
|
||||
from src.data_sync import DataSource, SyncState, _get_uncompressed_size
|
||||
from src.parquet_manager import (
|
||||
create_parquet_manager,
|
||||
_convert_column,
|
||||
convert_date_columns_to_date32,
|
||||
apply_schema_to_table,
|
||||
)
|
||||
from .client import create_client as create_keboola_client
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeboolaDataSource(DataSource):
|
||||
"""
|
||||
Data source: Direct download from Keboola Storage API.
|
||||
|
||||
Downloads data directly from a Keboola project, converts CSV exports
|
||||
to typed Parquet files using column metadata for schema enforcement.
|
||||
"""
|
||||
|
||||
def __init__(self):
    """Load configuration, check required Keboola settings, build helpers.

    Raises:
        ValueError: If the Keboola token, stack URL or project ID is not
            set in the configuration; the message lists the matching
            environment variable names.
    """
    self.config = get_config()

    # Pair each required config value with the env var name reported to the user.
    required = (
        ("KEBOOLA_STORAGE_TOKEN", self.config.keboola_token),
        ("KEBOOLA_STACK_URL", self.config.keboola_stack_url),
        ("KEBOOLA_PROJECT_ID", self.config.keboola_project_id),
    )
    missing = [env_name for env_name, value in required if not value]
    if missing:
        raise ValueError(
            f"Missing required environment variables for Keboola connector: "
            f"{', '.join(missing)}. See config/.env.template"
        )

    self.keboola_client = create_keboola_client()
    self.parquet_manager = create_parquet_manager()
|
||||
|
||||
def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]:
    """Return Keboola metadata with provider cascade applied.

    Delegates type resolution to the client's _resolve_keboola_type(),
    and picks each column's description from its "KBC.description"
    entries by provider priority: "user", then "ai-metadata-enrichment",
    then "keboola.snowflake-transformation". Entries with an empty value
    are ignored, so a blank description from a higher-priority provider
    does not hide a real one from a lower-priority provider. Descriptions
    from providers outside the priority list are not used.

    Args:
        table_id: Keboola table ID passed to the client's
            get_table_metadata()

    Returns:
        {"columns": {"col_name": {"source_type": "...", "description": "..."}}}
        ("description" is present only when one was found),
        or None if metadata is unavailable.
    """
    raw = self.keboola_client.get_table_metadata(table_id)
    column_metadata = raw.get("column_metadata", {})

    if not column_metadata:
        return None

    provider_priority = (
        "user",
        "ai-metadata-enrichment",
        "keboola.snowflake-transformation",
    )

    result = {}
    for col_name, col_meta_list in column_metadata.items():
        # Delegate type resolution to client
        source_type = self.keboola_client._resolve_keboola_type(col_meta_list)

        # Extract description via provider cascade
        description = None
        if isinstance(col_meta_list, list):
            description_by_provider = {}
            for entry in col_meta_list:
                if entry.get("key", "") != "KBC.description":
                    continue
                value = entry.get("value", "")
                # Skip blank values so they cannot win the cascade below.
                if value:
                    description_by_provider[entry.get("provider", "")] = value

            description = next(
                (
                    description_by_provider[provider]
                    for provider in provider_priority
                    if provider in description_by_provider
                ),
                None,
            )

        result[col_name] = {"source_type": source_type}
        if description:
            result[col_name]["description"] = description

    return {"columns": result}
|
||||
|
||||
def discover_tables(self) -> List[Dict[str, Any]]:
    """Return whatever the Keboola client's discover_all_tables() reports."""
    client = self.keboola_client
    return client.discover_all_tables()
|
||||
|
||||
def get_source_name(self) -> str:
    """Return the human-readable name of this data source."""
    source_name = "Keboola Storage API"
    return source_name
|
||||
|
||||
def _cleanup_staging(self):
    """
    Delete every regular file directly inside the staging directory.

    Subdirectories are left alone. A file that cannot be removed is logged
    as a warning and the remaining files are still processed.
    """
    staging_dir = self.config.get_staging_path()
    for entry in staging_dir.glob("*"):
        if not entry.is_file():
            continue
        try:
            entry.unlink()
            logger.debug(f"Cleaned up staging file: {entry.name}")
        except Exception as e:
            logger.warning(f"Failed to clean up {entry.name}: {e}")
|
||||
|
||||
def sync_table(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Synchronize one table from Keboola and record the outcome.

    Dispatches on ``table_config.sync_strategy``: "full_refresh" and
    "partitioned" have dedicated handlers, any other value is handled as
    incremental. Errors are caught, logged and reported in the result
    instead of being raised.

    Args:
        table_config: Table configuration
        sync_state: Sync state manager

    Returns:
        Dictionary with result:
        - success: bool
        - rows: int (on success)
        - strategy: str
        - file_size_mb: float (on success)
        - error: str (if failed)
    """
    logger.info(f"Syncing table: {table_config.name} ({table_config.sync_strategy})")

    # Drop this table's cached metadata so the latest types are fetched from Keboola
    cache = self.keboola_client.metadata_cache
    if table_config.id in cache:
        del cache[table_config.id]
        logger.debug(f"Cleared metadata cache for {table_config.id}")

    try:
        strategy = table_config.sync_strategy
        if strategy == "full_refresh":
            outcome = self._full_refresh(table_config)
        elif strategy == "partitioned":
            outcome = self._partitioned_sync(table_config)
        else:
            outcome = self._incremental_sync(table_config, sync_state)

        # Persist the result in the sync state
        sync_state.update_sync(
            table_id=table_config.id,
            table_name=table_config.name,
            strategy=table_config.sync_strategy,
            rows=outcome["rows"],
            file_size_bytes=outcome["file_size_bytes"],
            columns=outcome.get("columns", 0),
            uncompressed_bytes=outcome.get("uncompressed_bytes", 0),
        )

        return {
            "success": True,
            "rows": outcome["rows"],
            "strategy": table_config.sync_strategy,
            "file_size_mb": outcome["file_size_bytes"] / 1024 / 1024,
        }

    except Exception as e:
        logger.error(f"Error syncing table {table_config.name}: {e}")
        return {
            "success": False,
            "error": str(e),
            "strategy": table_config.sync_strategy,
        }
|
||||
|
||||
def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]:
    """
    Full refresh sync strategy.

    Exports the table (with ``where_filters`` when configured) to a
    temporary CSV in the staging directory, converts it to the table's
    Parquet path, and always removes the temporary CSV afterwards.

    Returns:
        Dict with rows, file_size_bytes, columns and uncompressed_bytes.
    """
    logger.info(f"Full refresh: {table_config.name}")

    parquet_path = self.config.get_parquet_path(table_config)
    staging_dir = self.config.get_staging_path()

    # Reserve a unique file name; the handle is closed right away and the
    # file is filled by the export below.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, dir=staging_dir
    ) as tmp_file:
        tmp_csv_path = Path(tmp_file.name)

    try:
        # 1. Export from Keboola to CSV
        if table_config.where_filters:
            filters_desc = f" (filters: {len(table_config.where_filters)})"
        else:
            filters_desc = ""
        logger.info(f"  -> Exporting from Keboola...{filters_desc}")
        self.keboola_client.export_table(
            table_id=table_config.id,
            output_path=tmp_csv_path,
            where_filters=table_config.where_filters or None,
        )

        # 2. Collect type information for the conversion
        dtypes = self.keboola_client.get_pandas_dtypes(table_config.id)
        date_columns = self.keboola_client.get_date_columns(table_config.id)
        pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id)

        # 3. Convert CSV -> Parquet
        logger.info("  -> Converting to Parquet...")
        parquet_info = self.parquet_manager.csv_to_parquet(
            csv_path=tmp_csv_path,
            parquet_path=parquet_path,
            dtypes=dtypes,
            table_id=table_config.id,
            date_columns=date_columns,
            pyarrow_schema=pyarrow_schema,
        )

        return {
            "rows": parquet_info["rows"],
            "file_size_bytes": parquet_info["parquet_size_bytes"],
            "columns": parquet_info.get("columns", 0),
            "uncompressed_bytes": _get_uncompressed_size(parquet_path),
        }

    finally:
        if tmp_csv_path.exists():
            tmp_csv_path.unlink()
|
||||
|
||||
def _incremental_sync(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Incremental sync strategy.

    Routes to the partitioned variant when ``table_config.partition_by``
    is set, otherwise to the single-file variant, and returns that
    handler's result.
    """
    if table_config.partition_by:
        handler = self._incremental_partitioned_sync
    else:
        handler = self._incremental_single_file_sync
    return handler(table_config, sync_state)
|
||||
|
||||
def _incremental_single_file_sync(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Incremental sync to a single Parquet file (no partitioning).

    The changedSince value is the last sync time minus
    ``incremental_window_days`` (7 when unset). On a first sync it is
    "now minus ``max_history_days``" when that is configured, otherwise
    no changedSince is sent. Exported rows are merged into the existing
    Parquet file, or written as a new file when none exists. Temporary
    files in the staging directory are always removed.

    Returns:
        Dict with rows, file_size_bytes, columns and uncompressed_bytes.
    """
    logger.info(f"Incremental sync (single file): {table_config.name}")

    parquet_path = self.config.get_parquet_path(table_config)
    staging_dir = self.config.get_staging_path()

    # Determine timestamp for changedSince
    last_sync = sync_state.get_last_sync(table_config.id)

    if last_sync:
        window_days = table_config.incremental_window_days or 7
        changed_since = (
            datetime.fromisoformat(last_sync) - timedelta(days=window_days)
        ).isoformat()
        logger.info(
            f"  -> ChangedSince: {changed_since} (window: {window_days} days)"
        )
    elif table_config.max_history_days:
        changed_since = (
            datetime.now() - timedelta(days=table_config.max_history_days)
        ).isoformat()
        logger.info(
            f"  -> First sync, limited to last {table_config.max_history_days} days "
            f"(changedSince: {changed_since})"
        )
    else:
        logger.info("  -> First sync, downloading all data...")
        changed_since = None

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, dir=staging_dir
    ) as tmp_file:
        tmp_csv_path = Path(tmp_file.name)

    try:
        # 1. Export changed data from Keboola
        logger.info("  -> Exporting changes from Keboola...")
        export_info = self.keboola_client.export_table(
            table_id=table_config.id,
            output_path=tmp_csv_path,
            changed_since=changed_since,
        )

        if export_info["exported_rows"] == 0:
            logger.info("  -> No changes since last synchronization")
            if not parquet_path.exists():
                return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0}
            existing_info = self.parquet_manager.get_parquet_info(parquet_path)
            return {
                "rows": existing_info["rows"],
                "file_size_bytes": existing_info["file_size_bytes"],
                "columns": existing_info.get("columns", 0),
                "uncompressed_bytes": _get_uncompressed_size(parquet_path),
            }

        # 2. Get dtypes and date columns
        dtypes = self.keboola_client.get_pandas_dtypes(table_config.id)
        date_columns = self.keboola_client.get_date_columns(table_config.id)
        pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id)

        # 3a. No Parquet yet: create it straight from the export
        if not parquet_path.exists():
            logger.info("  -> Creating new Parquet...")
            parquet_info = self.parquet_manager.csv_to_parquet(
                csv_path=tmp_csv_path,
                parquet_path=parquet_path,
                dtypes=dtypes,
                table_id=table_config.id,
                date_columns=date_columns,
                pyarrow_schema=pyarrow_schema,
            )
            return {
                "rows": parquet_info["rows"],
                "file_size_bytes": parquet_info["parquet_size_bytes"],
                "columns": parquet_info.get("columns", 0),
                "uncompressed_bytes": _get_uncompressed_size(parquet_path),
            }

        # 3b. Parquet exists: merge into a temp file, then swap it in
        logger.info(
            f"  -> Merging {export_info['exported_rows']} changes into Parquet..."
        )

        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".parquet", delete=False, dir=staging_dir
        ) as tmp_parquet_file:
            tmp_parquet_path = Path(tmp_parquet_file.name)

        try:
            merge_info = self.parquet_manager.merge_parquet(
                existing_parquet=parquet_path,
                new_csv=tmp_csv_path,
                output_parquet=tmp_parquet_path,
                primary_key=table_config.get_primary_key_columns(),
                dtypes=dtypes,
                date_columns=date_columns,
                pyarrow_schema=pyarrow_schema,
            )

            tmp_parquet_path.replace(parquet_path)

            return {
                "rows": merge_info["total_rows"],
                "file_size_bytes": parquet_path.stat().st_size,
                "columns": merge_info.get("total_columns", 0),
                "uncompressed_bytes": _get_uncompressed_size(parquet_path),
            }
        finally:
            # Only still present when the merge or the swap failed
            if tmp_parquet_path.exists():
                tmp_parquet_path.unlink()

    finally:
        if tmp_csv_path.exists():
            tmp_csv_path.unlink()
|
||||
|
||||
def _incremental_partitioned_sync(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Incremental sync with partitioned output.

    If ``sync_state`` has no previous sync timestamp for the table, the whole
    job is delegated to ``_chunked_initial_load``. Otherwise the table is
    exported with ``changed_since`` set to the last sync time minus
    ``table_config.incremental_window_days`` (7 days when that value is
    falsy), the exported rows are merged into the partition files and the
    touched partitions are de-duplicated.

    Args:
        table_config: Table definition (id, name, partitioning settings).
        sync_state: Provides ``get_last_sync(table_id)`` as an ISO timestamp.

    Returns:
        The totals dict produced by ``_get_partition_totals`` for the
        table's partition directory.
    """
    logger.info(
        f"Incremental sync (partitioned): {table_config.name} "
        f"(by {table_config.partition_by}, {table_config.partition_granularity})"
    )

    partition_dir = self.config.get_parquet_path(table_config)
    staging_dir = self.config.get_staging_path()

    last_sync = sync_state.get_last_sync(table_config.id)

    # No previous sync recorded: always take the chunked initial-load path.
    if not last_sync:
        return self._chunked_initial_load(table_config, partition_dir, staging_dir)

    # Regular incremental sync: look back an extra window before the last
    # sync so late-arriving changes are picked up again.
    last_sync_dt = datetime.fromisoformat(last_sync)
    window_days = table_config.incremental_window_days or 7
    changed_since_dt = last_sync_dt - timedelta(days=window_days)
    changed_since = changed_since_dt.isoformat()

    logger.info(
        f" -> ChangedSince: {changed_since} (window: {window_days} days)"
    )

    # Only the file name is needed; the export call writes to this path.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, dir=staging_dir
    ) as tmp_file:
        tmp_csv_path = Path(tmp_file.name)

    try:
        logger.info(" -> Exporting changes from Keboola...")
        export_info = self.keboola_client.export_table(
            table_id=table_config.id,
            output_path=tmp_csv_path,
            changed_since=changed_since,
        )

        if export_info["exported_rows"] == 0:
            logger.info(" -> No changes since last synchronization")
            return self._get_partition_totals(partition_dir)

        dtypes = self.keboola_client.get_pandas_dtypes(table_config.id)
        date_columns = self.keboola_client.get_date_columns(table_config.id)
        pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id)

        logger.info(f" -> Processing {export_info['exported_rows']} changed rows...")

        partitions_updated = self._process_csv_to_partitions(
            tmp_csv_path, table_config, partition_dir,
            dtypes=dtypes, date_columns=date_columns, pyarrow_schema=pyarrow_schema,
        )

        # Appending re-exported rows can duplicate primary keys; clean up
        # only the partitions that were touched.
        self._deduplicate_partitions(
            table_config, partitions_updated,
            date_columns=date_columns, pyarrow_schema=pyarrow_schema,
        )

        logger.info(f" -> Incremental sync complete, {len(partitions_updated)} partitions updated")
        return self._get_partition_totals(partition_dir)

    finally:
        # The temp file was created with delete=False, so remove it here.
        if tmp_csv_path.exists():
            tmp_csv_path.unlink()
|
||||
|
||||
def _chunked_initial_load(
    self,
    table_config: TableConfig,
    partition_dir: Path,
    staging_dir: Path,
) -> Dict[str, Any]:
    """
    Load a table for the first time in backwards-moving time windows.

    Each window spans ``initial_load_chunk_days`` plus a one-day overlap and
    is exported with ``changed_since`` / ``changed_until``. The first window
    has no upper bound. With ``max_history_days`` set, a fixed number of
    windows is processed and the oldest one is clamped to that horizon;
    without it, iteration stops after two consecutive empty windows. A hard
    cap of 120 windows applies in both modes. Partitions touched by any
    window are de-duplicated once at the end.

    Returns:
        The totals dict produced by ``_get_partition_totals``.
    """
    overlap_days = 1
    max_chunks_safety = 120
    chunk_days = table_config.initial_load_chunk_days
    max_history_days = table_config.max_history_days

    now = datetime.now()

    if not max_history_days:
        num_chunks = None
        logger.info(
            f" -> CHUNKED INITIAL LOAD: iterating backwards in {chunk_days}-day chunks "
            f"until no more data (with {overlap_days}-day overlap)"
        )
    else:
        # Ceiling division: number of windows needed to cover the history.
        num_chunks = (max_history_days + chunk_days - 1) // chunk_days
        logger.info(
            f" -> CHUNKED INITIAL LOAD: {max_history_days} days in {num_chunks} chunks "
            f"of {chunk_days} days each (with {overlap_days}-day overlap)"
        )

    self._cleanup_staging()

    table_id = table_config.id
    dtypes = self.keboola_client.get_pandas_dtypes(table_id)
    date_columns = self.keboola_client.get_date_columns(table_id)
    pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_id)

    touched = set()
    index = 0
    empty_streak = 0

    while True:
        if index >= max_chunks_safety:
            logger.warning(f" -> Reached safety limit of {max_chunks_safety} chunks, stopping")
            break

        if num_chunks is not None and index >= num_chunks:
            break

        end_offset = index * chunk_days
        start_offset = end_offset + chunk_days + overlap_days

        # The newest window is open-ended so rows up to "now" are included.
        window_end = None if index == 0 else now - timedelta(days=end_offset)
        if max_history_days and start_offset > max_history_days:
            window_start = now - timedelta(days=max_history_days)
        else:
            window_start = now - timedelta(days=start_offset)

        if num_chunks is None:
            label = f"{index + 1}"
        else:
            label = f"{index + 1}/{num_chunks}"
        end_text = window_end.strftime('%Y-%m-%d') if window_end else 'now'
        logger.info(
            f" -> Chunk {label}: "
            f"{window_start.strftime('%Y-%m-%d')} to "
            f"{end_text}"
        )

        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False, dir=staging_dir
        ) as handle:
            chunk_csv = Path(handle.name)

        try:
            export_info = self.keboola_client.export_table(
                table_id=table_config.id,
                output_path=chunk_csv,
                changed_since=window_start.isoformat(),
                changed_until=window_end.isoformat() if window_end else None,
            )

            if export_info["exported_rows"] == 0:
                logger.info(" -> No data in this chunk")
                empty_streak += 1
                if num_chunks is None and empty_streak >= 2:
                    logger.info(
                        f" -> Found {empty_streak} consecutive empty chunks, "
                        f"assuming end of history"
                    )
                    break
            else:
                empty_streak = 0
                logger.info(f" -> Exported {export_info['exported_rows']} rows")

                chunk_partitions = self._process_csv_to_partitions(
                    chunk_csv, table_config, partition_dir,
                    dtypes=dtypes, date_columns=date_columns, pyarrow_schema=pyarrow_schema,
                )
                touched.update(chunk_partitions)

                logger.info(f" -> Processed into {len(chunk_partitions)} partitions")
        finally:
            # Drop the staging file before the next window is downloaded.
            if chunk_csv.exists():
                chunk_csv.unlink()

        index += 1

    if touched:
        logger.info(
            f" -> Final deduplication of {len(touched)} partitions "
            f"(removing duplicates from {overlap_days}-day overlaps)..."
        )
        self._deduplicate_partitions(
            table_config, touched,
            date_columns=date_columns, pyarrow_schema=pyarrow_schema,
        )

    logger.info(
        f" -> Chunked initial load complete: {len(touched)} partitions"
    )

    return self._get_partition_totals(partition_dir)
|
||||
|
||||
def _process_csv_to_partitions(
    self,
    csv_path: Path,
    table_config: TableConfig,
    partition_dir: Path,
    dtypes: Optional[Dict[str, str]] = None,
    date_columns: Optional[List[str]] = None,
    pyarrow_schema: Optional[pa.Schema] = None,
) -> set:
    """
    Process CSV file and write to partition files.

    The CSV is read as strings in 500k-row chunks. Non-datetime ``dtypes``
    are applied per column (failures are logged, not raised), the partition
    column is parsed as UTC ISO-8601 datetimes, and each chunk is grouped by
    a key derived from ``partition_granularity`` ("month" when unset). Rows
    are appended to an existing partition file or written to a new one;
    de-duplication is left to the caller.

    Args:
        csv_path: CSV file to read.
        table_config: Supplies ``partition_by`` and ``partition_granularity``.
        partition_dir: Not used here; file paths come from
            ``self.config.get_partition_path``.
        dtypes: Optional column -> dtype-name mapping.
        date_columns: Optional columns passed to ``convert_date_columns_to_date32``.
        pyarrow_schema: Optional schema passed to ``apply_schema_to_table``.

    Returns:
        Set of partition keys that were updated

    Raises:
        ValueError: If the granularity is not "month", "day" or "year", or
            the partition column is missing from the data.
    """
    import pandas as pd
    import pyarrow.parquet as pq

    partition_col = table_config.partition_by
    granularity = table_config.partition_granularity or "month"

    # Fail fast on an unknown granularity instead of hitting a KeyError on
    # the never-created "_partition_key" column further down.
    key_formats = {"month": "%Y_%m", "day": "%Y_%m_%d", "year": "%Y"}
    if granularity not in key_formats:
        raise ValueError(
            f"Unsupported partition granularity '{granularity}' "
            f"(expected one of: {', '.join(key_formats)})"
        )
    key_format = key_formats[granularity]

    partitions_updated = set()
    chunk_size = 500000  # 500k rows per pandas chunk

    chunks = pd.read_csv(csv_path, chunksize=chunk_size, dtype=str)
    for chunk_num, chunk_df in enumerate(chunks, start=1):
        logger.debug(f" -> Processing pandas chunk {chunk_num} ({len(chunk_df)} rows)...")

        if partition_col not in chunk_df.columns:
            raise ValueError(f"Partition column '{partition_col}' not found in data")

        # Apply dtypes using _convert_column (except datetime columns)
        if dtypes:
            for col, dtype in dtypes.items():
                if col in chunk_df.columns and "datetime" not in dtype:
                    try:
                        chunk_df[col] = _convert_column(chunk_df[col], dtype, col_name=col)
                    except Exception as e:
                        logger.warning(f"Failed to apply dtype {dtype} to column {col}: {e}")

        # Convert partition column to datetime
        if not pd.api.types.is_datetime64_any_dtype(chunk_df[partition_col]):
            chunk_df[partition_col] = pd.to_datetime(
                chunk_df[partition_col], format="ISO8601", utc=True
            )

        # groupby skips rows with a null key, so make that loss visible.
        null_rows = int(chunk_df[partition_col].isna().sum())
        if null_rows:
            logger.warning(
                f"{null_rows} rows have no value in partition column "
                f"'{partition_col}' and will not be written to any partition"
            )

        # Create partition key based on granularity
        chunk_df["_partition_key"] = chunk_df[partition_col].dt.strftime(key_format)

        # Group by partition and append to partition files
        for partition_key, partition_df in chunk_df.groupby("_partition_key"):
            partition_df = partition_df.drop(columns=["_partition_key"])
            partition_path = self.config.get_partition_path(table_config, partition_key)
            partitions_updated.add(partition_key)

            if partition_path.exists():
                existing_df = pd.read_parquet(partition_path)
                partition_df = pd.concat([existing_df, partition_df], ignore_index=True)

            table = pa.Table.from_pandas(partition_df, preserve_index=False)
            if date_columns:
                table = convert_date_columns_to_date32(table, date_columns)
            if pyarrow_schema:
                table = apply_schema_to_table(table, pyarrow_schema)
            pq.write_table(table, partition_path, compression="snappy")

    return partitions_updated
|
||||
|
||||
def _deduplicate_partitions(
    self,
    table_config: TableConfig,
    partitions_to_dedup: set,
    date_columns: Optional[List[str]] = None,
    pyarrow_schema: Optional[pa.Schema] = None,
):
    """
    Remove duplicate primary-key rows from the given partition files.

    Partitions are visited in sorted key order. Each existing file is read,
    reduced to the last row per primary key, passed through the optional
    date/schema conversions and written back in place (even when nothing was
    removed). Keys whose file does not exist are skipped.
    """
    import pandas as pd
    import pyarrow.parquet as pq

    key_columns = table_config.get_primary_key_columns()

    logger.info(f" -> Deduplicating {len(partitions_to_dedup)} partitions...")

    for partition_key in sorted(partitions_to_dedup):
        target = self.config.get_partition_path(table_config, partition_key)
        if not target.exists():
            continue

        original = pd.read_parquet(target)
        # keep="last": the most recently appended row wins.
        unique_rows = original.drop_duplicates(subset=key_columns, keep="last")
        rows_before, rows_after = len(original), len(unique_rows)

        arrow_table = pa.Table.from_pandas(unique_rows, preserve_index=False)
        if date_columns:
            arrow_table = convert_date_columns_to_date32(arrow_table, date_columns)
        if pyarrow_schema:
            arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema)
        pq.write_table(arrow_table, target, compression="snappy")

        if rows_before == rows_after:
            continue
        logger.debug(
            f" -> Partition {partition_key}: {rows_before} -> {rows_after} rows "
            f"(removed {rows_before - rows_after} duplicates)"
        )
|
||||
|
||||
def _get_partition_totals(self, partition_dir: Path) -> Dict[str, Any]:
    """
    Calculate totals from all partition files in a directory.

    Args:
        partition_dir: Directory holding ``*.parquet`` partition files.

    Returns:
        Dict with ``rows``, ``file_size_bytes``, ``partitions``, ``columns``
        and ``uncompressed_bytes``. All values are 0 when the directory does
        not exist. ``partitions`` counts every matching file, including ones
        that could not be read; ``columns`` comes from the first readable
        file.
    """
    if not partition_dir.exists():
        # Same key set as the normal result so callers can index uniformly.
        return {
            "rows": 0,
            "file_size_bytes": 0,
            "partitions": 0,
            "columns": 0,
            "uncompressed_bytes": 0,
        }

    # Imported after the guard: nothing below is needed for a missing dir.
    import pyarrow.parquet as pq

    total_rows = 0
    total_size = 0
    total_uncompressed = 0
    total_columns = 0

    all_partitions = list(partition_dir.glob("*.parquet"))

    for part_path in all_partitions:
        try:
            pf = pq.ParquetFile(part_path)
            meta = pf.metadata
            total_rows += meta.num_rows
            total_size += part_path.stat().st_size
            if total_columns == 0:
                total_columns = len(pf.schema_arrow)
            # Uncompressed size is only available per column chunk.
            for rg_idx in range(meta.num_row_groups):
                rg = meta.row_group(rg_idx)
                for col_idx in range(rg.num_columns):
                    total_uncompressed += rg.column(col_idx).total_uncompressed_size
        except Exception as e:
            # Best effort: an unreadable file must not fail the whole sync.
            logger.warning(f" -> Skipping corrupt partition {part_path.name}: {e}")

    return {
        "rows": total_rows,
        "file_size_bytes": total_size,
        "partitions": len(all_partitions),
        "columns": total_columns,
        "uncompressed_bytes": total_uncompressed,
    }
|
||||
|
||||
def _partitioned_sync(self, table_config: TableConfig) -> Dict[str, Any]:
    """
    Export a table and merge it into its partition files.

    The table is exported to a staging CSV (with ``where_filters`` when the
    config defines any), split into partition files, and the touched
    partitions are de-duplicated. If the export yields no rows, the current
    on-disk totals are returned unchanged.

    Returns:
        The totals dict produced by ``_get_partition_totals``.
    """
    logger.info(
        f"Partitioned sync: {table_config.name} "
        f"(by {table_config.partition_by}, {table_config.partition_granularity})"
    )

    partition_dir = self.config.get_parquet_path(table_config)
    staging_dir = self.config.get_staging_path()

    # Reserve a staging file name; the export writes the content.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, dir=staging_dir
    ) as handle:
        staged_csv = Path(handle.name)

    try:
        filters = table_config.where_filters
        filters_desc = f" (filters: {len(filters)})" if filters else ""
        logger.info(f" -> Exporting from Keboola...{filters_desc}")

        export_info = self.keboola_client.export_table(
            table_id=table_config.id,
            output_path=staged_csv,
            where_filters=filters or None,
        )

        if export_info["exported_rows"] == 0:
            logger.info(" -> No data exported")
            return self._get_partition_totals(partition_dir)

        table_id = table_config.id
        dtypes = self.keboola_client.get_pandas_dtypes(table_id)
        date_columns = self.keboola_client.get_date_columns(table_id)
        pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_id)

        logger.info(f" -> Processing CSV in chunks ({export_info['exported_rows']} rows)...")

        touched = self._process_csv_to_partitions(
            staged_csv, table_config, partition_dir,
            dtypes=dtypes, date_columns=date_columns, pyarrow_schema=pyarrow_schema,
        )
        self._deduplicate_partitions(
            table_config, touched,
            date_columns=date_columns, pyarrow_schema=pyarrow_schema,
        )

        totals = self._get_partition_totals(partition_dir)
        logger.info(
            f" -> Partitioned sync complete: {totals.get('partitions', 0)} partitions on disk, "
            f"{totals['rows']} total rows"
        )
        return totals
    finally:
        # Created with delete=False, so clean up explicitly.
        if staged_csv.exists():
            staged_csv.unlink()
|
||||
|
|
@ -23,7 +23,16 @@ import requests
|
|||
from kbcstorage.client import Client
|
||||
from kbcstorage.tables import Tables
|
||||
|
||||
from src.config import get_config, TableConfig, WhereFilter
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional as _Opt
|
||||
|
||||
|
||||
@dataclass
class WhereFilter:
    """Keboola where filter for export operations."""

    # Name of the column the filter applies to.
    column: str
    # Comparison operator name; "eq" unless overridden.
    operator: str = "eq"
    # Values to compare against; every instance gets its own fresh list.
    values: list = field(default_factory=list)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -86,10 +95,8 @@ class KeboolaClient:
|
|||
token: Storage API token. If None, loads from configuration.
|
||||
url: Stack URL. If None, loads from configuration.
|
||||
"""
|
||||
config = get_config()
|
||||
|
||||
self.token = token or config.keboola_token
|
||||
self.url = url or config.keboola_stack_url
|
||||
self.token = token or os.environ.get("KEBOOLA_STORAGE_TOKEN", "")
|
||||
self.url = url or os.environ.get("KEBOOLA_STACK_URL", "")
|
||||
|
||||
if not self.token:
|
||||
raise ValueError(
|
||||
|
|
@ -824,7 +831,7 @@ def create_client() -> KeboolaClient:
|
|||
"""
|
||||
Factory function to create Keboola client.
|
||||
|
||||
Uses configuration from get_config().
|
||||
Uses environment variables for token and URL.
|
||||
|
||||
Returns:
|
||||
KeboolaClient instance
|
||||
|
|
@ -848,21 +855,21 @@ if __name__ == "__main__":
|
|||
print(" ❌ Connection failed!")
|
||||
exit(1)
|
||||
|
||||
# Test metadata
|
||||
# Test metadata with discovered tables
|
||||
print("\n2️⃣ Testing metadata...")
|
||||
config = get_config()
|
||||
if config.tables:
|
||||
test_table = config.tables[0]
|
||||
print(f" Testing table: {test_table.id}")
|
||||
tables = client.discover_all_tables()
|
||||
if tables:
|
||||
test_table_id = tables[0].get("id", tables[0].get("name", ""))
|
||||
print(f" Testing table: {test_table_id}")
|
||||
|
||||
metadata = client.get_table_metadata(test_table.id)
|
||||
metadata = client.get_table_metadata(test_table_id)
|
||||
print(f" ✅ Metadata loaded:")
|
||||
print(f" Columns: {len(metadata.get('columns', []))}")
|
||||
print(f" Rows: {metadata.get('row_count', 0):,}")
|
||||
print(f" Size: {metadata.get('data_size_bytes', 0) / 1024 / 1024:.2f} MB")
|
||||
|
||||
# Test dtypes
|
||||
dtypes = client.get_pandas_dtypes(test_table.id)
|
||||
dtypes = client.get_pandas_dtypes(test_table_id)
|
||||
print(f" Pandas dtypes:")
|
||||
for col, dtype in list(dtypes.items())[:5]:
|
||||
print(f" {col}: {dtype}")
|
||||
|
|
|
|||
|
|
@ -1,194 +0,0 @@
|
|||
"""Tests for Keboola adapter and DataSource ABC / factory in src.data_sync.
|
||||
|
||||
Covers:
|
||||
- DataSource ABC default method behaviour
|
||||
- create_data_source factory: keboola import error, unknown source, dynamic lookup
|
||||
- KeboolaDataSource env var validation
|
||||
"""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.data_sync import DataSource, SyncState, create_data_source
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _MinimalSource(DataSource):
    """Smallest concrete DataSource: implements only ``sync_table``."""

    def sync_table(self, table_config, sync_state):
        # Both arguments are ignored; a fixed success payload is returned.
        return dict(success=True, rows=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. DataSource ABC default methods
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDataSourceABCDefaultMethods:
    """Check the defaults a DataSource subclass gets for its optional methods."""

    def test_get_column_metadata_returns_none(self):
        # Not overridden by _MinimalSource, so the base default is exercised.
        assert _MinimalSource().get_column_metadata("any.table.id") is None

    def test_get_source_name_returns_unknown(self):
        assert _MinimalSource().get_source_name() == "Unknown"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Factory: keboola without kbcstorage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFactoryKeboolaWithoutKbcstorage:
    """create_data_source('keboola') must raise ImportError when kbcstorage is missing."""

    def test_raises_import_error(self):
        # Simulate the package not being installed: any import of
        # connectors.keboola.adapter fails with a ModuleNotFoundError that
        # mentions kbcstorage; every other import is passed through.
        import builtins

        # builtins.__import__ is the supported handle on the real import
        # hook; __builtins__ may be a dict or a module depending on context.
        original_import = builtins.__import__

        def _fake_import(name, *args, **kwargs):
            if name == "connectors.keboola.adapter":
                raise ModuleNotFoundError("No module named 'kbcstorage'")
            return original_import(name, *args, **kwargs)

        with patch("builtins.__import__", side_effect=_fake_import):
            with pytest.raises(ImportError, match="kbcstorage"):
                create_data_source("keboola")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Factory: unknown source type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFactoryUnknownSource:
    """An unrecognised source type must make create_data_source raise ValueError."""

    def test_raises_value_error(self):
        expected = "Unknown data source.*nonexistent"
        with pytest.raises(ValueError, match=expected):
            create_data_source("nonexistent")

    def test_error_message_contains_guidance(self):
        # The message should point at the adapter module path for this source.
        expected = "connectors/nonexistent/adapter.py"
        with pytest.raises(ValueError, match=expected):
            create_data_source("nonexistent")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Factory: dynamic connector lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFactoryDynamicConnectorLookup:
    """Exercise the dynamic-import path of create_data_source."""

    def test_jira_lookup_falls_through_to_value_error(self):
        """Requesting 'jira' is expected to end in a ValueError naming the
        unknown source."""
        expected = "Unknown data source.*jira"
        with pytest.raises(ValueError, match=expected):
            create_data_source("jira")

    def test_dynamic_import_is_attempted(self):
        """importlib.import_module must be called exactly once with the
        connectors.<type>.adapter path for a source type that is not
        hard-coded."""
        target = "src.data_sync.importlib.import_module"
        with patch(target, side_effect=ModuleNotFoundError) as import_mock:
            with pytest.raises(ValueError):
                create_data_source("custom_source")
        import_mock.assert_called_once_with("connectors.custom_source.adapter")

    def test_dynamic_import_with_factory_function(self):
        """A loaded module that exposes create_data_source() has it called,
        and its return value is handed back."""
        expected_source = _MinimalSource()
        factory = MagicMock(return_value=expected_source)
        adapter_module = MagicMock()
        adapter_module.create_data_source = factory

        with patch("src.data_sync.importlib.import_module", return_value=adapter_module):
            produced = create_data_source("my_connector")

        assert produced is expected_source
        adapter_module.create_data_source.assert_called_once()

    def test_dynamic_import_with_datasource_subclass(self):
        """A loaded module without a factory but with a DataSource subclass
        gets that subclass instantiated."""
        import types

        # A real module object: it has no create_data_source attribute.
        adapter_module = types.ModuleType("connectors.my_connector.adapter")
        setattr(adapter_module, "MyDataSource", _MinimalSource)

        with patch("src.data_sync.importlib.import_module", return_value=adapter_module):
            produced = create_data_source("my_connector")

        assert isinstance(produced, _MinimalSource)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. KeboolaDataSource validates env vars
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestKeboolaAdapterValidatesEnvVars:
    """KeboolaDataSource() must raise ValueError naming each Keboola setting
    that the (patched) config leaves empty."""

    def _make_mock_config(self, token="", stack_url="", project_id=""):
        """Build a mock config with the given Keboola credential values."""
        cfg = MagicMock()
        cfg.keboola_token = token
        cfg.keboola_stack_url = stack_url
        cfg.keboola_project_id = project_id
        return cfg

    def _construct_expecting_error(self, cfg, match=None):
        """Instantiate KeboolaDataSource under a patched get_config and
        return the captured ValueError info."""
        with patch("connectors.keboola.adapter.get_config", return_value=cfg):
            with pytest.raises(ValueError, match=match) as exc_info:
                from connectors.keboola.adapter import KeboolaDataSource
                KeboolaDataSource()
        return exc_info

    def test_all_missing(self):
        self._construct_expecting_error(
            self._make_mock_config(),
            match="KEBOOLA_STORAGE_TOKEN",
        )

    def test_token_missing(self):
        cfg = self._make_mock_config(
            stack_url="https://connection.keboola.com",
            project_id="12345",
        )
        self._construct_expecting_error(cfg, match="KEBOOLA_STORAGE_TOKEN")

    def test_stack_url_missing(self):
        cfg = self._make_mock_config(
            token="my-token",
            project_id="12345",
        )
        self._construct_expecting_error(cfg, match="KEBOOLA_STACK_URL")

    def test_project_id_missing(self):
        cfg = self._make_mock_config(
            token="my-token",
            stack_url="https://connection.keboola.com",
        )
        self._construct_expecting_error(cfg, match="KEBOOLA_PROJECT_ID")

    def test_error_lists_all_missing_vars(self):
        """When multiple env vars are missing, all should appear in the error message."""
        exc_info = self._construct_expecting_error(self._make_mock_config())
        msg = str(exc_info.value)
        for var_name in ("KEBOOLA_STORAGE_TOKEN", "KEBOOLA_STACK_URL", "KEBOOLA_PROJECT_ID"):
            assert var_name in msg
|
||||
|
|
@ -15,10 +15,23 @@ from datetime import datetime, timedelta
|
|||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from src.config import TableConfig
|
||||
from dataclasses import dataclass as _dataclass
|
||||
from .client import OpenMetadataClient
|
||||
|
||||
|
||||
@_dataclass
class TableConfig:
    """Minimal table config used by the enricher.

    Attributes expected by CatalogEnricher: id, name, and optional catalog_fqn.

    ``TableConfig(**row)`` only works when *row* contains nothing but these
    fields. Rows from ``TableRegistryRepository`` appear to carry further
    keys (e.g. ``description``, ``bucket``), which the generated constructor
    rejects with ``TypeError``; use ``TableConfig.from_row(row)`` for such
    dicts.
    """

    # Table identifier.
    id: str
    # Table name.
    name: str
    # Explicit catalog FQN; empty string means "not set".
    catalog_fqn: str = ""

    @classmethod
    def from_row(cls, row: Dict[str, Any]) -> "TableConfig":
        """Build a config from a dict, ignoring keys that are not fields.

        A ``catalog_fqn`` of ``None`` is treated as absent so the field keeps
        its empty-string default.

        Raises:
            TypeError: If ``id`` or ``name`` is missing from *row*.
        """
        from dataclasses import fields

        kwargs = {f.name: row[f.name] for f in fields(cls) if f.name in row}
        if kwargs.get("catalog_fqn") is None:
            kwargs.pop("catalog_fqn", None)
        return cls(**kwargs)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -905,35 +905,35 @@ class SampleDataGenerator:
|
|||
# ── Parquet conversion ─────────────────────────────────────
|
||||
|
||||
def _convert_to_parquet(self, parquet_dir: Path) -> None:
    """Convert generated CSVs to Parquet using DuckDB.

    Every ``*.csv`` in ``self.output_dir`` is copied (in sorted order) to
    ``<parquet_dir>/<stem>.parquet`` with ZSTD compression, and the row
    count, Parquet size and CSV/Parquet size ratio are logged per table.

    Args:
        parquet_dir: Output directory; created if it does not exist.
    """
    import duckdb

    def _sql_literal(path: Path) -> str:
        # Paths are interpolated into SQL text; double any single quote so
        # the value stays one string literal.
        return str(path).replace("'", "''")

    parquet_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f" Converting to Parquet -> {parquet_dir}/")

    conn = duckdb.connect()
    try:
        for csv_path in sorted(self.output_dir.glob("*.csv")):
            table_name = csv_path.stem
            parquet_path = parquet_dir / f"{table_name}.parquet"
            csv_sql = _sql_literal(csv_path)
            parquet_sql = _sql_literal(parquet_path)

            conn.execute(
                f"COPY (SELECT * FROM read_csv_auto('{csv_sql}')) "
                f"TO '{parquet_sql}' (FORMAT PARQUET, COMPRESSION ZSTD)"
            )

            # Report stats
            row_count = conn.execute(
                f"SELECT count(*) FROM '{parquet_sql}'"
            ).fetchone()[0]
            parquet_size = parquet_path.stat().st_size
            csv_size = csv_path.stat().st_size
            ratio = csv_size / parquet_size if parquet_size > 0 else 0
            logger.info(
                f" {table_name}: {row_count:,} rows, "
                f"{parquet_size / 1024:.0f} KB "
                f"({ratio:.1f}x compression)"
            )
    finally:
        # Release the connection even if one conversion fails.
        conn.close()
|
||||
|
||||
# ── Orchestration ──────────────────────────────────────────
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,8 @@ from connectors.openmetadata.transformer import (
|
|||
sanitize_filename,
|
||||
table_to_yaml_dict,
|
||||
)
|
||||
from src.config import Config
|
||||
from src.db import get_system_db
|
||||
from src.repositories.table_registry import TableRegistryRepository
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging
|
||||
|
|
@ -279,14 +280,14 @@ def _write_metrics_index(
|
|||
|
||||
def export_tables(
|
||||
client: OpenMetadataClient,
|
||||
config: Config,
|
||||
tables: list[dict],
|
||||
docs_dir: Path,
|
||||
catalog_url: str,
|
||||
) -> int:
|
||||
"""
|
||||
Export table metadata from OpenMetadata to YAML files.
|
||||
|
||||
For each table defined in data_description.md:
|
||||
For each table in the registry:
|
||||
1. Derives the OpenMetadata FQN
|
||||
2. Fetches table metadata (columns, owners, tags, description)
|
||||
3. Transforms to YAML dict
|
||||
|
|
@ -294,7 +295,7 @@ def export_tables(
|
|||
|
||||
Args:
|
||||
client: Initialized OpenMetadata API client
|
||||
config: Application config with table definitions
|
||||
tables: List of table dicts from TableRegistryRepository.list_all()
|
||||
docs_dir: Base docs directory
|
||||
catalog_url: Catalog URL for header comments
|
||||
|
||||
|
|
@ -307,10 +308,12 @@ def export_tables(
|
|||
written_files: set[Path] = set()
|
||||
count = 0
|
||||
|
||||
for table_config in config.tables:
|
||||
for tbl in tables:
|
||||
table_id = tbl.get("id", "")
|
||||
table_name = tbl.get("name", "")
|
||||
try:
|
||||
# Derive FQN: explicit override or auto-derive
|
||||
fqn = table_config.catalog_fqn or f"bigquery.{table_config.id}"
|
||||
# Use explicit catalog_fqn if set, otherwise derive from table id
|
||||
fqn = tbl.get("catalog_fqn") or f"bigquery.{table_id}"
|
||||
|
||||
logger.debug(f"Fetching table metadata: {fqn}")
|
||||
raw_table = client.get_table(fqn)
|
||||
|
|
@ -318,7 +321,7 @@ def export_tables(
|
|||
yaml_dict = table_to_yaml_dict(raw_table)
|
||||
|
||||
# Write table YAML
|
||||
file_path = tables_dir / f"{table_config.name}.yml"
|
||||
file_path = tables_dir / f"{table_name}.yml"
|
||||
header = _yaml_header(catalog_url, fqn, entity_type="table")
|
||||
yaml_content = yaml.dump(
|
||||
yaml_dict,
|
||||
|
|
@ -330,10 +333,10 @@ def export_tables(
|
|||
written_files.add(file_path)
|
||||
count += 1
|
||||
|
||||
logger.info(f"Exported table: {table_config.name} ({len(yaml_dict.get('columns', []))} columns)")
|
||||
logger.info(f"Exported table: {table_name} ({len(yaml_dict.get('columns', []))} columns)")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to export table {table_config.name}: {e}")
|
||||
logger.warning(f"Failed to export table {table_name}: {e}")
|
||||
continue
|
||||
|
||||
# Cleanup stale auto-generated table files
|
||||
|
|
@ -443,10 +446,12 @@ def main() -> None:
|
|||
|
||||
# Export tables
|
||||
try:
|
||||
config = Config()
|
||||
tables_count = export_tables(client, config, docs_dir, catalog_url)
|
||||
conn = get_system_db()
|
||||
repo = TableRegistryRepository(conn)
|
||||
registered_tables = repo.list_all()
|
||||
tables_count = export_tables(client, registered_tables, docs_dir, catalog_url)
|
||||
except Exception as e:
|
||||
logger.warning(f"Table export skipped (config error): {e}")
|
||||
logger.warning(f"Table export skipped (registry error): {e}")
|
||||
tables_count = 0
|
||||
|
||||
# Write sync state
|
||||
|
|
|
|||
653
src/config.py
653
src/config.py
|
|
@ -1,653 +0,0 @@
|
|||
"""
|
||||
Configuration Management
|
||||
|
||||
This module handles:
|
||||
1. Loading environment variables from .env file
|
||||
2. Parsing data_description.md (YAML blocks with table definitions)
|
||||
3. Validating configuration
|
||||
4. Providing structured configuration data for other modules
|
||||
|
||||
SINGLE SOURCE OF TRUTH is data_description.md - it defines:
|
||||
- List of tables to synchronize
|
||||
- Sync strategies (full_refresh vs incremental)
|
||||
- Primary keys and foreign keys
|
||||
- Incremental columns and windows
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
# Logging setup: configure the root logger once at import time, then create
# the module-level logger used throughout this file.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ForeignKey:
    """
    One foreign-key link from a column of this table to another table.

    Attributes:
        column: Referencing column in this table (e.g., "company_id")
        references: Referenced table and column (e.g., "company.id")
        description: Optional free-text description of the relationship
    """

    column: str
    references: str
    description: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class WhereFilter:
    """
    A single export filter selecting a subset of a table's rows.

    Passed to the Keboola Storage API as a whereFilters entry.

    Attributes:
        column: Column the filter applies to
        operator: Comparison operator (eq, ne, gt, ge, lt, le)
        values: Values the column is compared against
    """

    column: str
    operator: str  # eq, ne, gt, ge, lt, le
    values: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class TableConfig:
    """
    Configuration for a single table.

    Attributes:
        id: Full table ID in Keboola (e.g., "in.c-sfdc.company")
        name: Short table name (e.g., "company")
        description: Table description
        primary_key: Primary key column name(s), comma-separated for composite
            keys (optional for remote-only tables)
        sync_strategy: "full_refresh", "incremental", "partitioned", or "none" (for remote-only tables)
        incremental_window_days: Number of days to backtrack for incremental sync
        partition_by: Column name to partition by (for incremental/partitioned with partitions)
        partition_granularity: Partition granularity: "month", "day", or "year"
        foreign_keys: List of foreign key relationships
        where_filters: List of filters to apply when exporting (for downloading subset of data)
        folder: Override folder name (instead of bucket-level folder_mapping)
        max_history_days: Max days of history for initial incremental load (None = download all)
        dataset: Dataset group name for on-demand tables (e.g., "kbc_telemetry_expert")
        initial_load_chunk_days: Chunk size in days for chunked initial load (default: 30)
        incremental_column: Column for timestamp-based incremental sync (BigQuery)
        columns: Subset of columns to sync (None = all)
        row_filter: SQL WHERE clause for filtering
        query_mode: "local", "remote" or "hybrid"
        bq_entity_type: "view" or "table"
        partition_column_type: BQ SQL type of the partition column: "DATE", "TIMESTAMP" or "DATETIME"
        catalog_fqn: Explicit OpenMetadata FQN override (auto-derived if not set)
        sync_schedule: Schedule for automatic sync: "every 15m", "every 1h", "daily 05:00" (UTC)
        profile_after_sync: Run profiler after sync (default True; disable for frequently synced tables)
    """

    id: str
    name: str
    description: str
    primary_key: str = ""  # Optional for remote-only tables
    sync_strategy: str = "none"  # "full_refresh", "incremental", "partitioned", or "none" (remote-only)
    incremental_window_days: Optional[int] = None
    partition_by: Optional[str] = None
    partition_granularity: Optional[str] = None  # "month", "day", "year"
    foreign_keys: List["ForeignKey"] = field(default_factory=list)
    where_filters: List["WhereFilter"] = field(default_factory=list)
    folder: Optional[str] = None
    max_history_days: Optional[int] = None
    dataset: Optional[str] = None
    initial_load_chunk_days: int = 30
    incremental_column: Optional[str] = None  # Column for timestamp-based incremental sync (BigQuery)
    columns: Optional[List[str]] = None  # Subset of columns to sync (None = all)
    row_filter: Optional[str] = None  # SQL WHERE clause for filtering (e.g., "event_date >= '2024-01-01'")
    query_mode: str = "local"  # "local" (Parquet) | "remote" (BQ direct) | "hybrid" (sync subset, query BQ)
    bq_entity_type: str = "view"  # "view" (Python BQ client) | "table" (DuckDB BQ extension)
    partition_column_type: str = "TIMESTAMP"  # BQ SQL type for partition column: "DATE", "TIMESTAMP", "DATETIME"
    catalog_fqn: Optional[str] = None  # Explicit OpenMetadata FQN override (auto-derived if not set)
    sync_schedule: Optional[str] = None  # Schedule: "every 15m", "every 1h", "daily 05:00" (UTC)
    profile_after_sync: bool = True  # Run profiler after sync (disable for frequently synced tables)

    def __post_init__(self):
        """
        Validate configuration after initialization and fill in defaults.

        Raises:
            ValueError: If an enumerated field holds an unsupported value, the
                schedule string is malformed, or a partitioned table has no
                partition_by column.
        """
        self._require_choice("query_mode", ("local", "remote", "hybrid"))
        self._require_choice("bq_entity_type", ("view", "table"))

        if self.sync_strategy not in ["full_refresh", "incremental", "partitioned", "none"]:
            raise ValueError(
                f"Invalid sync_strategy '{self.sync_strategy}' for table {self.id}. "
                f"Allowed values: 'full_refresh', 'incremental', 'partitioned', 'none'"
            )

        # For incremental strategy:
        # - changedSince is calculated from last sync timestamp (Keboola internal)
        # - partition_by is optional - if set, output will be partitioned
        if self.sync_strategy == "incremental":
            if not self.incremental_window_days:
                # Default 7 days if not specified
                self.incremental_window_days = 7
                logger.warning(
                    f"Table {self.id}: incremental_window_days not set, "
                    f"using default 7 days"
                )
            if self.partition_by:
                self._normalize_partition_granularity()

        self._require_choice("partition_column_type", ("DATE", "TIMESTAMP", "DATETIME"))

        # fullmatch (rather than match + "$") also rejects a trailing newline
        if self.sync_schedule:
            valid_schedule = (
                re.fullmatch(r"every \d+[mh]", self.sync_schedule)
                or re.fullmatch(r"daily \d{2}:\d{2}(,\d{2}:\d{2})*", self.sync_schedule)
            )
            if not valid_schedule:
                raise ValueError(
                    f"Invalid sync_schedule '{self.sync_schedule}' for table {self.id}. "
                    f"Allowed formats: 'every 15m', 'every 1h', 'daily 05:00', "
                    f"'daily 07:00,13:00,18:00'"
                )

        # For partitioned, partition_by must be defined
        if self.sync_strategy == "partitioned":
            if not self.partition_by:
                raise ValueError(
                    f"Table {self.id} has sync_strategy='partitioned', "
                    f"but partition_by is missing"
                )
            self._normalize_partition_granularity()

    def _require_choice(self, field_name: str, allowed: tuple) -> None:
        """Raise ValueError unless the named attribute holds one of the allowed values."""
        value = getattr(self, field_name)
        if value not in allowed:
            raise ValueError(
                f"Invalid {field_name} '{value}' for table {self.id}. "
                f"Allowed values: {', '.join(allowed)}"
            )

    def _normalize_partition_granularity(self) -> None:
        """Default partition_granularity to "month" when unset, then validate it."""
        if not self.partition_granularity:
            self.partition_granularity = "month"
            logger.info(
                f"Table {self.id}: partition_granularity not set, "
                f"using default 'month'"
            )
        if self.partition_granularity not in ["month", "day", "year"]:
            raise ValueError(
                f"Invalid partition_granularity '{self.partition_granularity}' for table {self.id}. "
                f"Allowed values: 'month', 'day', 'year'"
            )

    def get_primary_key_columns(self) -> List[str]:
        """
        Get primary key as list of column names.

        Supports both single and composite primary keys.
        Composite PKs are defined as comma-separated string: "col1, col2"

        Returns:
            List of column names forming the primary key; empty when no
            primary key is configured.
        """
        # Drop blank segments so an unset key yields [] instead of [""]
        parts = (col.strip() for col in (self.primary_key or "").split(","))
        return [col for col in parts if col]

    def is_partitioned(self) -> bool:
        """Check if table output should be partitioned.

        Returns True for:
        - partitioned strategy (always partitioned)
        - incremental strategy with partition_by set
        """
        if self.sync_strategy == "partitioned":
            return True
        return self.sync_strategy == "incremental" and bool(self.partition_by)
|
||||
|
||||
|
||||
class Config:
    """
    Main configuration class.

    Loads environment variables and parses data_description.md.
    Provides access to all configuration parameters and to the filesystem
    paths derived from them.
    """

    # Placeholder -> number of days before "now" it resolves to
    _PLACEHOLDER_DAYS_AGO = {
        "{{last_week}}": 7,
        "{{last_month}}": 30,
        "{{last_2_months}}": 60,
        "{{last_3_months}}": 90,
        "{{last_6_months}}": 180,
        "{{last_year}}": 365,
        "{{last_2_years}}": 730,
        "{{today}}": 0,
    }

    def __init__(self, env_file: Optional[str] = None):
        """
        Initialize configuration.

        Args:
            env_file: Path to .env file (string or Path). If None, looks for
                .env in project root.

        Raises:
            FileNotFoundError: If docs/data_description.md cannot be located
            ValueError: If data_description.md defines no tables
        """
        # Find project root (folder containing data_description.md)
        self.project_root = self._find_project_root()

        # Normalize to Path so a plain string argument works as annotated
        if env_file is None:
            env_path = self.project_root / ".env"
        else:
            env_path = Path(env_file)

        if env_path.exists():
            load_dotenv(env_path)
            logger.info(f"Loaded from .env: {env_path}")
        else:
            logger.warning(
                f".env file not found: {env_path}. "
                f"Use config/.env.template as reference."
            )

        # Read by connectors/keboola/ if enabled
        self.keboola_token = os.getenv("KEBOOLA_STORAGE_TOKEN")
        self.keboola_stack_url = os.getenv("KEBOOLA_STACK_URL")
        self.keboola_project_id = os.getenv("KEBOOLA_PROJECT_ID")
        self.data_dir = Path(os.getenv("DATA_DIR", "./data"))
        self.docs_output_dir = Path(os.getenv("DOCS_OUTPUT_DIR", "./docs"))
        self.data_source = os.getenv("DATA_SOURCE", "local")
        self.log_level = os.getenv("LOG_LEVEL", "INFO")

        # Set log level; logging only knows upper-case level names
        logging.getLogger().setLevel(self.log_level.upper())

        # Validate required environment variables
        self._validate_env_vars()

        # Parse data_description.md
        self.tables, self.folder_mapping = self._parse_data_description()

        logger.info(f"Configuration loaded: {len(self.tables)} tables")

    def _find_project_root(self) -> Path:
        """
        Find project root (folder containing docs/data_description.md).

        Checks the current working directory and then up to five parents.

        Returns:
            Path to project root

        Raises:
            FileNotFoundError: If docs/data_description.md is not found
        """
        candidate = Path.cwd()
        for _ in range(6):  # cwd itself + 5 parent levels
            if (candidate / "docs" / "data_description.md").exists():
                return candidate
            candidate = candidate.parent

        raise FileNotFoundError(
            "docs/data_description.md not found. "
            "Make sure you're running from project root."
        )

    def _resolve_placeholder(self, value: str) -> str:
        """
        Resolve placeholders in filter values.

        Supported placeholders:
        - {{last_week}}: 7 days ago
        - {{last_month}}: 30 days ago
        - {{last_2_months}}: 60 days ago
        - {{last_3_months}}: 90 days ago
        - {{last_6_months}}: 180 days ago
        - {{last_year}}: 365 days ago
        - {{last_2_years}}: 730 days ago
        - {{today}}: Today's date

        Args:
            value: String that may contain placeholder; non-string values are
                returned unchanged

        Returns:
            Resolved string with actual date values (YYYY-MM-DD)
        """
        if not isinstance(value, str):
            return value

        today = datetime.now()
        result = value
        for placeholder, days_ago in self._PLACEHOLDER_DAYS_AGO.items():
            if placeholder in result:
                replacement = (today - timedelta(days=days_ago)).strftime("%Y-%m-%d")
                result = result.replace(placeholder, replacement)
                logger.debug(f"Resolved placeholder: {placeholder} -> {replacement}")

        return result

    def _validate_env_vars(self):
        """
        Validate that required environment variables are set based on data source type.

        Currently a no-op: no source-specific validation is performed here.
        """
        # Keboola env vars are validated by connectors/keboola/adapter.py at init time.
        # No source-specific validation needed here.
        pass

    def _locate_data_description(self) -> Path:
        """
        Return the path of the main data_description.md file.

        CONFIG_DIR takes precedence when it is set and contains the file;
        otherwise the copy under <project_root>/docs is used.

        Raises:
            FileNotFoundError: If the resolved file doesn't exist
        """
        # Only consult CONFIG_DIR when it is actually set: Path("") is "." and
        # always truthy, which would otherwise pick up a file from the CWD.
        config_dir = os.environ.get("CONFIG_DIR", "")
        if config_dir:
            override = Path(config_dir) / "data_description.md"
            if override.exists():
                return override

        data_desc_path = self.project_root / "docs" / "data_description.md"
        if not data_desc_path.exists():
            raise FileNotFoundError(
                f"docs/data_description.md not found: {data_desc_path}"
            )
        return data_desc_path

    def _build_table_config(self, table_data: Dict[str, Any]) -> "TableConfig":
        """Convert one raw table mapping from YAML into a TableConfig."""
        fk_list = [
            ForeignKey(
                column=fk_data["column"],
                references=fk_data["references"],
                description=fk_data.get("description"),
            )
            for fk_data in table_data.get("foreign_keys") or []
        ]

        # Where filters get their date placeholders resolved at load time
        wf_list = [
            WhereFilter(
                column=wf_data["column"],
                operator=wf_data["operator"],
                values=[self._resolve_placeholder(v) for v in wf_data.get("values", [])],
            )
            for wf_data in table_data.get("where_filters") or []
        ]

        return TableConfig(
            id=table_data["id"],
            name=table_data["name"],
            description=table_data["description"],
            primary_key=table_data.get("primary_key", ""),
            sync_strategy=table_data.get("sync_strategy", "none"),
            incremental_window_days=table_data.get("incremental_window_days"),
            partition_by=table_data.get("partition_by"),
            partition_granularity=table_data.get("partition_granularity"),
            foreign_keys=fk_list,
            where_filters=wf_list,
            folder=table_data.get("folder"),
            max_history_days=table_data.get("max_history_days"),
            dataset=table_data.get("dataset"),
            initial_load_chunk_days=table_data.get("initial_load_chunk_days", 30),
            incremental_column=table_data.get("incremental_column"),
            columns=table_data.get("columns"),
            row_filter=table_data.get("row_filter"),
            query_mode=table_data.get("query_mode", "local"),
            bq_entity_type=table_data.get("bq_entity_type", "view"),
            partition_column_type=table_data.get("partition_column_type", "TIMESTAMP"),
            catalog_fqn=table_data.get("catalog_fqn"),
            sync_schedule=table_data.get("sync_schedule"),
            profile_after_sync=table_data.get("profile_after_sync", True),
        )

    def _parse_data_description(self) -> tuple[List["TableConfig"], Dict[str, str]]:
        """
        Parse docs/data_description.md and extract table definitions.

        Looks for YAML blocks in the main markdown file and in every
        docs/datasets/*.md file, and merges their "tables" and
        "folder_mapping" entries.

        Returns:
            Tuple of (List of TableConfig objects, folder_mapping dict)

        Raises:
            FileNotFoundError: If docs/data_description.md doesn't exist
            yaml.YAMLError: If YAML is invalid
            ValueError: If no YAML blocks or no tables are found
        """
        # Collect all markdown files to parse: main + dataset files
        md_files = [self._locate_data_description()]
        datasets_dir = self.project_root / "docs" / "datasets"
        if datasets_dir.exists():
            for md_file in sorted(datasets_dir.glob("*.md")):
                md_files.append(md_file)
                logger.info(f"Found dataset file: {md_file.name}")

        # Find YAML blocks (between ```yaml and ```) from all files
        yaml_pattern = r'```yaml\n(.*?)```'
        yaml_matches = []
        for md_file in md_files:
            # Explicit encoding keeps parsing independent of the platform locale
            content = md_file.read_text(encoding="utf-8")
            yaml_matches.extend(re.findall(yaml_pattern, content, re.DOTALL))

        if not yaml_matches:
            raise ValueError(
                "data_description.md contains no YAML blocks. "
                "Make sure tables are defined in ```yaml blocks."
            )

        # Parse all YAML blocks and merge them
        all_tables = []
        folder_mapping = {}
        for yaml_block in yaml_matches:
            try:
                data = yaml.safe_load(yaml_block)
            except yaml.YAMLError as e:
                logger.error(f"Error parsing YAML: {e}")
                raise
            # Skip empty blocks and blocks that are not mappings (e.g. examples)
            if not isinstance(data, dict):
                continue
            all_tables.extend(data.get("tables") or [])
            folder_mapping.update(data.get("folder_mapping") or {})

        if not all_tables:
            raise ValueError(
                "data_description.md contains no tables. "
                "Make sure YAML block contains 'tables:' key."
            )

        table_configs = [self._build_table_config(table_data) for table_data in all_tables]
        return table_configs, folder_mapping

    def get_table_config(self, table_id: str) -> Optional["TableConfig"]:
        """
        Get configuration for specific table by ID.

        Args:
            table_id: Full table ID (e.g., "in.c-sfdc.company")

        Returns:
            TableConfig or None if table not in configuration
        """
        return next((table for table in self.tables if table.id == table_id), None)

    def get_parquet_path(self, table_config: "TableConfig") -> Path:
        """
        Get path to Parquet file for given table.

        Format: data/parquet/{folder_name}/{table_name}.parquet
        For partitioned tables: data/parquet/{folder_name}/{table_name}/ (directory)

        Folder name is determined by folder_mapping in data_description.md.
        Falls back to bucket name if no mapping exists. The containing
        directories are created as a side effect.

        Args:
            table_config: Table configuration

        Returns:
            Path to Parquet file (or directory for partitioned tables)
        """
        # Extract bucket name from table ID (e.g., "in.c-crm" from "in.c-crm.company")
        bucket_name = ".".join(table_config.id.split(".")[:-1])

        # Table-level folder override wins over the bucket-level mapping,
        # which in turn wins over the raw bucket name
        folder_name = table_config.folder or self.folder_mapping.get(bucket_name, bucket_name)

        parquet_dir = self.data_dir / "parquet" / folder_name
        parquet_dir.mkdir(parents=True, exist_ok=True)

        if not table_config.is_partitioned():
            return parquet_dir / f"{table_config.name}.parquet"

        # For partitioned tables, return directory path
        partition_dir = parquet_dir / table_config.name
        partition_dir.mkdir(parents=True, exist_ok=True)
        return partition_dir

    def get_partition_path(self, table_config: "TableConfig", partition_key: str) -> Path:
        """
        Get path to specific partition file.

        Args:
            table_config: Table configuration (must be partitioned)
            partition_key: Partition key (e.g., "2026_01" for monthly)

        Returns:
            Path to partition Parquet file

        Raises:
            ValueError: If the table is not partitioned
        """
        if not table_config.is_partitioned():
            raise ValueError(f"Table {table_config.id} is not partitioned")

        return self.get_parquet_path(table_config) / f"{partition_key}.parquet"

    def get_metadata_path(self) -> Path:
        """
        Get path to metadata folder, creating it if needed.

        Returns:
            Path to metadata folder
        """
        metadata_dir = self.data_dir / "metadata"
        metadata_dir.mkdir(parents=True, exist_ok=True)
        return metadata_dir

    def get_staging_path(self) -> Path:
        """
        Get path to staging folder for temporary files, creating it if needed.

        Uses /tmp/data_analyst_staging for faster I/O and to avoid filling /data disk.

        Returns:
            Path to staging folder
        """
        staging_dir = Path("/tmp/data_analyst_staging")
        staging_dir.mkdir(parents=True, exist_ok=True)
        return staging_dir

    def get_duckdb_path(self) -> Path:
        """
        Get path to DuckDB database, creating its directory if needed.

        Returns:
            Path to DuckDB file
        """
        duckdb_dir = self.data_dir / "duckdb"
        duckdb_dir.mkdir(parents=True, exist_ok=True)
        return duckdb_dir / "analytics.duckdb"
|
||||
|
||||
|
||||
# Lazily created, process-wide Config shared by the entire application
_config_instance: Optional[Config] = None


def get_config() -> Config:
    """
    Return the shared Config instance.

    The configuration is built on the first call; every later call returns
    that same object.

    Returns:
        Config instance
    """
    global _config_instance
    if _config_instance is not None:
        return _config_instance
    _config_instance = Config()
    return _config_instance
|
||||
|
||||
|
||||
# Test helper: forget the cached Config so the next get_config() rebuilds it
def reset_config():
    """Clear the singleton config instance. For testing only."""
    global _config_instance
    _config_instance = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: load the configuration and print a short summary
    import traceback

    print("🔧 Testing configuration...")

    try:
        config = get_config()

        print("\n✅ Configuration loaded successfully!")
        print(f"   Project ID: {config.keboola_project_id}")
        print(f"   Stack URL: {config.keboola_stack_url}")
        print(f"   Data dir: {config.data_dir}")
        print(f"   Number of tables: {len(config.tables)}")

        print("\n📊 Tables:")
        for tbl in config.tables:
            print(f"   - {tbl.name} ({tbl.id})")
            print(f"     Strategy: {tbl.sync_strategy}")
            if tbl.sync_strategy == "incremental":
                print(f"     Incremental window: {tbl.incremental_window_days} days")
            if tbl.partition_by:
                print(f"     Partitioned by: {tbl.partition_by} ({tbl.partition_granularity})")
            print(f"     Parquet: {config.get_parquet_path(tbl)}")

    except Exception as e:
        # Report any failure together with its traceback
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
|
||||
734
src/data_sync.py
734
src/data_sync.py
|
|
@ -1,734 +0,0 @@
|
|||
"""
|
||||
Data Synchronization Manager
|
||||
|
||||
Orchestrates data synchronization from configured sources to local Parquet files.
|
||||
|
||||
Main functions:
|
||||
1. Tracking sync state (when was last synchronization)
|
||||
2. DataSource ABC for pluggable connectors
|
||||
3. Sync single table or all tables at once
|
||||
4. Progress tracking and error handling
|
||||
5. Schema generation from synced Parquet files
|
||||
|
||||
Sync State:
|
||||
- Stored in data/metadata/sync_state.json
|
||||
- Contains timestamp of last synchronization for each table
|
||||
- Used for incremental sync (changedSince parameter)
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from .config import get_config, TableConfig
|
||||
from config.loader import load_instance_config
|
||||
from connectors.openmetadata.enricher import CatalogEnricher
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SyncState:
    """
    Synchronization state management.

    Stores and loads information about last synchronization of each table,
    persisted as a JSON object keyed by table ID.
    """

    def __init__(self, state_file: Path):
        """
        Args:
            state_file: Path to JSON file with sync state
        """
        self.state_file = state_file
        self.state: Dict[str, Any] = self._load_state()

    def _load_state(self) -> Dict[str, Any]:
        """
        Load sync state from disk.

        Returns:
            Dictionary with sync state; empty when the file is missing,
            unreadable, or does not contain a JSON object
        """
        if not self.state_file.exists():
            return {}

        try:
            with open(self.state_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except Exception as e:
            logger.error(f"Error loading sync state: {e}")
            return {}

        # Anything but a JSON object would break every later .get() lookup
        if not isinstance(loaded, dict):
            logger.error(
                f"Error loading sync state: expected a JSON object, "
                f"got {type(loaded).__name__}"
            )
            return {}
        return loaded

    def _save_state(self):
        """
        Save sync state to disk.

        Creates the parent directory if needed. The state is written to a
        temporary sibling file and then moved over the real file, so an
        interrupted write never leaves a truncated state file behind.
        Failures are logged, not raised.
        """
        try:
            self.state_file.parent.mkdir(parents=True, exist_ok=True)
            tmp_file = self.state_file.with_name(self.state_file.name + ".tmp")
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(self.state, f, indent=2, default=str)
            tmp_file.replace(self.state_file)
        except Exception as e:
            logger.error(f"Error saving sync state: {e}")

    def get_last_sync(self, table_id: str) -> Optional[str]:
        """
        Get timestamp of last synchronization for given table.

        Args:
            table_id: Table identifier

        Returns:
            ISO timestamp string, or None if not synced yet
        """
        return self.state.get(table_id, {}).get("last_sync")

    def get_table_state(self, table_id: str) -> Dict[str, Any]:
        """
        Get complete sync state for a table.

        Args:
            table_id: Table identifier

        Returns:
            Dictionary with table sync state (empty if the table is unknown)
        """
        return self.state.get(table_id, {})

    def update_sync(
        self,
        table_id: str,
        table_name: str,
        strategy: str,
        rows: int,
        file_size_bytes: int,
        columns: int = 0,
        uncompressed_bytes: int = 0,
    ):
        """
        Update synchronization state for a table and persist it.

        Args:
            table_id: Table identifier
            table_name: Human-readable table name
            strategy: Sync strategy used
            rows: Number of rows synced
            file_size_bytes: Size of Parquet file in bytes
            columns: Number of columns
            uncompressed_bytes: Uncompressed data size
        """
        self.state[table_id] = {
            "table_name": table_name,
            "last_sync": datetime.now().isoformat(),
            "strategy": strategy,
            "rows": rows,
            "columns": columns,
            # Sizes are stored in MiB, rounded to two decimals
            "file_size_mb": round(file_size_bytes / 1024 / 1024, 2),
            "uncompressed_mb": round(uncompressed_bytes / 1024 / 1024, 2),
        }

        self._save_state()
|
||||
|
||||
|
||||
class DataSource(ABC):
    """
    Base interface every data source connector implements.

    A connector must provide sync_table(); the remaining methods are optional
    hooks with neutral defaults.
    See connectors/keboola/ for a reference implementation.
    """

    @abstractmethod
    def sync_table(
        self,
        table_config: TableConfig,
        sync_state: SyncState,
    ) -> Dict[str, Any]:
        """
        Synchronize single table.

        Args:
            table_config: Table configuration
            sync_state: Sync state manager

        Returns:
            Dictionary with sync result
        """
        ...

    def discover_tables(self) -> List[Dict[str, Any]]:
        """List all available tables in the data source.

        Each entry is a dict with at minimum:
            id, name, bucket_id, columns, row_count, size_bytes,
            primary_key, last_change

        The default implementation returns an empty list, meaning the source
        doesn't support discovery.
        """
        return []

    def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]:
        """Return processed column metadata for schema generation.

        Returns:
            {"columns": {"col_name": {"source_type": "...", "description": "..."}}}
            or None if the source doesn't support metadata (the default).
        """
        return None

    def get_source_name(self) -> str:
        """Display name of this data source for schema comments."""
        return "Unknown"
|
||||
|
||||
|
||||
def _get_uncompressed_size(parquet_path: Path) -> int:
|
||||
"""Read total uncompressed size from Parquet file metadata."""
|
||||
try:
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
meta = pq.ParquetFile(parquet_path).metadata
|
||||
total = 0
|
||||
for rg_idx in range(meta.num_row_groups):
|
||||
rg = meta.row_group(rg_idx)
|
||||
for col_idx in range(rg.num_columns):
|
||||
total += rg.column(col_idx).total_uncompressed_size
|
||||
return total
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
class DataSyncManager:
    """
    Orchestrator for data synchronization.

    Drives the sync of all tables and keeps track of the results.
    """

    def __init__(self):
        """Set up configuration, sync state, data source and catalog enricher."""
        self.config = get_config()

        # Sync state lives in sync_state.json under the metadata path.
        state_file = self.config.get_metadata_path() / "sync_state.json"
        self.sync_state = SyncState(state_file)

        self.data_source = create_data_source()

        # OpenMetadata catalog enricher. If it cannot be built from the
        # instance config, fall back to one created from an empty config.
        try:
            self.catalog_enricher = CatalogEnricher(load_instance_config())
        except Exception as e:
            logger.warning(f"Failed to initialize catalog enricher: {e}")
            self.catalog_enricher = CatalogEnricher({})  # Disabled enricher
|
||||
|
||||
def _generate_schema_yaml(self):
    """
    Generate schema.yml files with actual table schemas from Parquet files.

    For every configured table that already has Parquet data, the Arrow
    schema is read from the Parquet file (for partitioned tables: from the
    first ``*.parquet`` file found in the table directory) and combined with:
    - column descriptions (catalog enricher first, source metadata as fallback)
    - source column types from the data source's column metadata
    - primary key, sync strategy and partitioning from the table config

    Tables without a dataset are written to ``<docs_output_dir>/schema.yml``,
    tables with a dataset to ``<docs_output_dir>/datasets/<dataset>/schema.yml``.
    A failure while processing one table is logged and skipped so the
    remaining tables are still written.

    Returns:
        Path of the core schema.yml file.
    """
    import yaml
    import pyarrow.parquet as pq

    source_name = self.data_source.get_source_name()
    generated_at = datetime.now().isoformat()

    logger.info("Generating schema.yml from synced tables...")

    tables = {}

    # Process each table in configuration
    for table_config in self.config.tables:
        try:
            parquet_path = self.config.get_parquet_path(table_config)

            # Locate the file to read the schema from. Partitioned tables are
            # globbed exactly once; the first partition supplies the schema.
            if table_config.partition_by:
                schema_source = (
                    next(parquet_path.glob("*.parquet"), None)
                    if parquet_path.exists()
                    else None
                )
            else:
                schema_source = parquet_path if parquet_path.exists() else None

            # Skip if Parquet doesn't exist (table not synced yet)
            if schema_source is None:
                logger.debug(f" Skipping {table_config.name} (not synced yet)")
                continue

            arrow_schema = pq.ParquetFile(schema_source).schema_arrow

            # Get column metadata from data source (if supported)
            col_metadata = self.data_source.get_column_metadata(table_config.id)
            source_columns = {}
            if col_metadata and "columns" in col_metadata:
                source_columns = col_metadata["columns"]

            # Enrich with catalog metadata (OpenMetadata)
            catalog_data = self.catalog_enricher.enrich_table(table_config)

            # Extract column information
            columns = []
            for field_item in arrow_schema:
                col_name = field_item.name
                col_meta = source_columns.get(col_name, {})

                column_info = {
                    "name": col_name,
                    "type": str(field_item.type),
                }

                # Priority for description: catalog > source metadata > (nothing).
                # An empty catalog description falls through to the source
                # metadata, mirroring the table-level fallback below.
                description = None
                catalog_key = col_name.lower()
                if catalog_data and catalog_key in catalog_data.columns:
                    description = catalog_data.columns[catalog_key].description
                if not description:
                    description = col_meta.get("description")

                if description:
                    column_info["description"] = description

                # Add source type from connector metadata
                if "source_type" in col_meta:
                    column_info["source_type"] = col_meta["source_type"]

                columns.append(column_info)

            # Priority for table description: catalog > table config
            table_description = table_config.description
            if catalog_data:
                table_description = catalog_data.description or table_description

            table_info = {
                "table_id": table_config.id,
                "description": table_description,
                "primary_key": table_config.get_primary_key_columns(),
                "sync_strategy": table_config.sync_strategy,
                "columns": columns,
            }

            if table_config.partition_by:
                table_info["partitioned_by"] = table_config.partition_by

            tables[table_config.name] = table_info

            logger.debug(f" {table_config.name}: {len(columns)} columns")

        except Exception as e:
            logger.warning(f" Error processing {table_config.name}: {e}")

    # Index configs by name once (first occurrence wins, like a linear search)
    config_by_name = {}
    for candidate in self.config.tables:
        config_by_name.setdefault(candidate.name, candidate)

    # Split tables into core (no dataset) and per-dataset groups
    core_tables = {}
    dataset_tables = {}  # {dataset_name: {table_name: table_info}}
    for table_name, table_info in tables.items():
        owner = config_by_name.get(table_name)
        if owner and owner.dataset:
            dataset_tables.setdefault(owner.dataset, {})[table_name] = table_info
        else:
            core_tables[table_name] = table_info

    def _write_schema_file(filepath, file_tables, note=""):
        """Write a schema YAML file with header comments."""
        filepath.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "_metadata": {
                "_schema_version": 2,
                "generated_at": generated_at,
                "note": "AUTO-GENERATED - DO NOT EDIT.",
                "source": source_name,
                "generator": "src/data_sync.py::DataSyncManager._generate_schema_yaml()",
            },
            "tables": file_tables,
        }
        # Explicit UTF-8: yaml.dump(allow_unicode=True) emits non-ASCII
        # characters verbatim, which a non-UTF-8 locale default cannot encode.
        with open(filepath, "w", encoding="utf-8") as f:
            f.write("# AUTO-GENERATED - DO NOT EDIT\n")
            f.write("# This file is automatically generated during data sync\n")
            f.write(f"# Generated: {generated_at}\n")
            if note:
                f.write(f"# {note}\n")
            f.write("#\n")
            f.write("# Contains actual table schemas from synced Parquet files:\n")
            f.write("# - Column names and PyArrow types (from Parquet)\n")
            f.write(f"# - Source types and descriptions (from {source_name})\n")
            f.write("# - Primary keys and sync strategies\n")
            f.write("#\n")
            f.write("# For architectural documentation and relationships, see data_description.md\n")
            f.write("\n")
            yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True)

    # Write core schema.yml
    schema_file = self.config.docs_output_dir / "schema.yml"
    _write_schema_file(schema_file, core_tables)
    logger.info(f"Core schema YAML: {len(core_tables)} tables -> {schema_file}")

    # Write per-dataset schema files
    for ds_name, ds_tables in dataset_tables.items():
        ds_schema_file = self.config.docs_output_dir / "datasets" / ds_name / "schema.yml"
        _write_schema_file(ds_schema_file, ds_tables, note=f"Dataset: {ds_name}")
        logger.info(f"Dataset schema YAML: {len(ds_tables)} tables -> {ds_schema_file}")

    total = len(core_tables) + sum(len(t) for t in dataset_tables.values())
    logger.info(f"Schema generation complete: {total} tables total")

    return schema_file
|
||||
|
||||
def sync_table(self, table_id: str) -> Dict[str, Any]:
    """
    Sync one table, looked up by its ID.

    Args:
        table_id: ID of the table to synchronize

    Returns:
        Dictionary with sync result, as returned by the data source

    Raises:
        ValueError: If the configuration has no table with this ID
    """
    table_config = self.config.get_table_config(table_id)
    if table_config:
        return self.data_source.sync_table(table_config, self.sync_state)

    raise ValueError(f"Table {table_id} not found in configuration")
|
||||
|
||||
def sync_all(self, tables: Optional[List[str]] = None) -> Dict[str, Dict[str, Any]]:
    """
    Synchronize all tables (or subset according to list).

    Tables whose ``query_mode`` is ``"remote"`` are never synced locally and
    are left out of the result. A failure in one table (returned or raised)
    is recorded as ``{"success": False, "error": ...}`` and the remaining
    tables are still processed, matching ``sync_scheduled``. When at least
    one table succeeded, schema.yml is regenerated and the profiler is run.

    Args:
        tables: List of table IDs to synchronize. If None (or empty), syncs
            all. IDs missing from the configuration are logged and skipped.

    Returns:
        Dictionary {table_id: result} with sync results
    """
    if tables:
        table_configs = []
        for tid in tables:
            requested = self.config.get_table_config(tid)
            if requested is None:
                # Make a mistyped ID visible instead of dropping it silently.
                logger.warning(f"Table {tid} not found in configuration, skipping")
                continue
            table_configs.append(requested)
    else:
        table_configs = self.config.tables

    # Filter out remote-only tables (no local sync needed)
    remote_skipped = [tc for tc in table_configs if tc.query_mode == "remote"]
    table_configs = [tc for tc in table_configs if tc.query_mode != "remote"]

    if remote_skipped:
        logger.info(
            f"Skipping {len(remote_skipped)} remote-only tables "
            f"(query via BigQuery): "
            f"{', '.join(tc.name for tc in remote_skipped)}"
        )

    logger.info(f"Synchronizing {len(table_configs)} tables...")

    results = {}
    with tqdm(table_configs, desc="Syncing tables") as pbar:
        for table_config in pbar:
            pbar.set_description(f"Sync: {table_config.name}")

            # Isolate failures per table so one broken table cannot abort
            # the whole run (same policy as sync_scheduled).
            try:
                result = self.data_source.sync_table(table_config, self.sync_state)
            except Exception as e:
                logger.error(f" {table_config.name}: sync failed: {e}")
                result = {"success": False, "error": str(e)}
            results[table_config.id] = result

            if result["success"]:
                pbar.write(
                    f" {table_config.name}: {result['rows']:,} rows, "
                    f"{result['file_size_mb']:.2f} MB"
                )
            else:
                pbar.write(f" {table_config.name}: {result['error']}")

    success_count = sum(1 for r in results.values() if r["success"])
    logger.info(
        f"Synchronization completed: {success_count}/{len(results)} tables successful"
    )

    if success_count > 0:
        # Generate schema.yml from synced tables (best effort)
        try:
            self._generate_schema_yaml()
        except Exception as e:
            logger.warning(f"Failed to generate schema.yml: {e}")

        # Auto-profile changed tables
        self._auto_profile(results)

    return results
|
||||
|
||||
def _auto_profile(
    self,
    results: Dict[str, Dict[str, Any]],
    skip_tables: Optional[List[str]] = None,
):
    """Profile the tables that were synced successfully.

    Profiling is best effort: any exception (including a failed import of
    the profiler) is logged as a warning and not propagated.

    Args:
        results: Sync results dict {table_id: result}
        skip_tables: Table IDs to skip profiling for
    """
    excluded = set(skip_tables) if skip_tables else set()
    try:
        from src.profiler import profile_changed_tables

        # Collect names of tables that succeeded, still exist in the
        # configuration and were not explicitly excluded.
        changed = []
        for tid, outcome in results.items():
            if not outcome.get("success"):
                continue
            table_config = self.config.get_table_config(tid)
            if not table_config or tid in excluded:
                continue
            changed.append(table_config.name)

        if not changed:
            logger.info("No tables to profile (all skipped or none succeeded)")
        else:
            result = profile_changed_tables(changed)
            logger.info(
                f"Auto-profiling: {result['success']} profiled, "
                f"{result['errors']} errors, {result['skipped']} skipped"
            )
    except Exception as e:
        logger.warning(f"Auto-profiling failed (non-fatal): {e}")
|
||||
|
||||
def sync_scheduled(self) -> Dict[str, Dict[str, Any]]:
    """Sync only the tables that are due according to their sync_schedule.

    Each non-remote table with a sync_schedule is checked against its
    last_sync timestamp; only due tables are synced. Afterwards schema.yml
    is regenerated, the profiler runs for tables with profile_after_sync
    set, and the webapp is restarted if the profiler was invoked.

    Returns:
        Dictionary {table_id: result} with sync results (only for synced tables)
    """
    from src.scheduler import is_table_due

    scheduled_tables = []
    for tc in self.config.tables:
        if tc.sync_schedule and tc.query_mode != "remote":
            scheduled_tables.append(tc)

    if not scheduled_tables:
        logger.info("No tables with sync_schedule configured")
        return {}

    # Work out which of the scheduled tables are due right now
    due_tables = []
    for tc in scheduled_tables:
        last_sync = self.sync_state.get_last_sync(tc.id)
        if not is_table_due(tc.sync_schedule, last_sync):
            logger.debug(f"Table {tc.name} is not due (schedule: {tc.sync_schedule})")
            continue
        due_tables.append(tc)
        logger.info(f"Table {tc.name} is DUE (schedule: {tc.sync_schedule})")

    if not due_tables:
        logger.info(
            f"Checked {len(scheduled_tables)} scheduled tables, none are due"
        )
        return {}

    due_names = ", ".join(tc.name for tc in due_tables)
    logger.info(
        f"Syncing {len(due_tables)}/{len(scheduled_tables)} due tables: "
        f"{due_names}"
    )

    # Sync due tables; a raised exception is recorded as a failed result
    results = {}
    for table_config in due_tables:
        try:
            outcome = self.data_source.sync_table(table_config, self.sync_state)
            results[table_config.id] = outcome
            if not outcome["success"]:
                logger.error(f" {table_config.name}: {outcome['error']}")
            else:
                logger.info(
                    f" {table_config.name}: {outcome['rows']:,} rows, "
                    f"{outcome['file_size_mb']:.2f} MB"
                )
        except Exception as e:
            logger.error(f" {table_config.name}: sync failed: {e}")
            results[table_config.id] = {"success": False, "error": str(e)}

    success_count = sum(1 for r in results.values() if r["success"])
    logger.info(f"Scheduled sync: {success_count}/{len(results)} tables successful")

    # Generate schema.yml
    if success_count > 0:
        try:
            self._generate_schema_yaml()
        except Exception as e:
            logger.warning(f"Failed to generate schema.yml: {e}")

    # Profile only tables with profile_after_sync=True
    skip_profiler = [tc.id for tc in due_tables if not tc.profile_after_sync]
    if skip_profiler:
        skipped_names = ", ".join(
            self.config.get_table_config(tid).name for tid in skip_profiler
        )
        logger.info(f"Skipping profiler for: {skipped_names}")

    skipped_ids = set(skip_profiler)
    wants_profile = success_count > 0 and any(
        r.get("success") and tid not in skipped_ids
        for tid, r in results.items()
    )

    if wants_profile:
        self._auto_profile(results, skip_tables=skip_profiler)
        # Restart webapp since the profiler ran (new profiles.json needs reload)
        self._restart_webapp()

    return results
|
||||
|
||||
def _restart_webapp(self):
    """Restart webapp service to pick up new profiles.json.

    Runs ``sudo systemctl restart webapp`` with a 30 second timeout. The
    restart is best effort: a non-zero exit status, a timeout, or a missing
    executable is logged and never raised, so a finished sync is not turned
    into a failure by the restart step.
    """
    import subprocess

    try:
        subprocess.run(
            ["sudo", "systemctl", "restart", "webapp"],
            check=True,
            capture_output=True,
            timeout=30,
        )
        logger.info("Webapp restarted successfully")
    except subprocess.CalledProcessError as e:
        logger.warning(f"Failed to restart webapp: {e.stderr.decode() if e.stderr else e}")
    except subprocess.TimeoutExpired as e:
        # subprocess.run raises TimeoutExpired when the timeout elapses;
        # without this handler it would propagate out of the sync.
        logger.warning(f"Failed to restart webapp: timed out after {e.timeout} seconds")
    except FileNotFoundError:
        logger.debug("systemctl not found (not running on server)")
|
||||
|
||||
|
||||
def create_sync_manager() -> DataSyncManager:
    """
    Build a DataSyncManager.

    Returns:
        A newly constructed DataSyncManager
    """
    manager = DataSyncManager()
    return manager
|
||||
|
||||
|
||||
def create_data_source(source_type: Optional[str] = None) -> DataSource:
    """Create a data source based on configuration.

    ``"local"`` and ``"keboola"`` map to the Keboola connector. Any other
    value is resolved dynamically from ``connectors/<source_type>/adapter.py``:
    a module-level ``create_data_source()`` factory is preferred, otherwise
    the first ``DataSource`` subclass found in the module is instantiated.

    Args:
        source_type: Override source type. If None, the ``data_source`` value
            of the loaded configuration is used.

    Returns:
        DataSource instance

    Raises:
        ValueError: If source type is unknown (no such connector module, or
            the module exposes neither a factory nor a DataSource subclass)
        ImportError: If connector dependencies are missing
    """
    if source_type is None:
        source_type = get_config().data_source

    if source_type in ("local", "keboola"):
        try:
            from connectors.keboola.adapter import KeboolaDataSource
        except ModuleNotFoundError as e:
            if "kbcstorage" in str(e):
                raise ImportError(
                    "Keboola connector requires 'kbcstorage' package. "
                    "Install with: pip install kbcstorage"
                ) from e
            raise  # Re-raise real import errors
        return KeboolaDataSource()

    # Try dynamic connector import for other types
    module_name = f"connectors.{source_type}.adapter"
    try:
        mod = importlib.import_module(module_name)
    except ModuleNotFoundError as e:
        # Only "the connector module itself (or one of its parent packages)
        # does not exist" means an unknown source. A missing dependency
        # imported by an existing adapter must surface as an import error
        # instead of being reported as "Unknown data source".
        missing = e.name
        connector_missing = missing is not None and (
            module_name == missing or module_name.startswith(missing + ".")
        )
        if not connector_missing:
            raise
        mod = None

    if mod is not None:
        factory = getattr(mod, "create_data_source", None)
        if factory:
            return factory()
        # Fallback: look for a DataSource subclass defined or imported there
        for attr_name in dir(mod):
            attr = getattr(mod, attr_name)
            if isinstance(attr, type) and issubclass(attr, DataSource) and attr is not DataSource:
                return attr()

    raise ValueError(
        f"Unknown data source: '{source_type}'. "
        f"Available connectors: keboola, bigquery. "
        f"Create connectors/{source_type}/adapter.py to add a new one."
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point. With --scheduled only due tables are synced;
    # otherwise the remaining arguments are table IDs (none = all tables).
    # Exit status is 0 when every synced table succeeded, 1 otherwise.
    import sys

    run_scheduled = "--scheduled" in sys.argv
    requested_tables = [arg for arg in sys.argv[1:] if arg != "--scheduled"]

    try:
        manager = create_sync_manager()

        if run_scheduled:
            print("Data Sync (scheduled mode)")
            results = manager.sync_scheduled()

            if not results:
                print("No tables due for sync")
                sys.exit(0)
        else:
            print("Data Sync")
            if requested_tables:
                print(f"\nSynchronizing selected tables: {', '.join(requested_tables)}")
                results = manager.sync_all(tables=requested_tables)
            else:
                print("\nSynchronizing all tables...")
                results = manager.sync_all()

        total_count = len(results)
        success_count = len([r for r in results.values() if r["success"]])

        if success_count != total_count:
            print(
                f"\n{success_count}/{total_count} tables synchronized. "
                f"Check logs for details."
            )
            sys.exit(1)

        print(f"\nAll {total_count} tables synchronized successfully!")
        sys.exit(0)

    except Exception as e:
        # sys.exit() raises SystemExit, which is not an Exception subclass,
        # so the exits above pass through this handler untouched.
        print(f"\nError: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
|
||||
|
|
@ -1,755 +0,0 @@
|
|||
"""
|
||||
Parquet File Manager
|
||||
|
||||
Parquet file management:
|
||||
1. CSV -> Parquet conversion with data type application
|
||||
2. Compression (snappy) for space saving
|
||||
3. Metadata embedding (table_id, export_date)
|
||||
4. Information about existing Parquet files
|
||||
5. Merge/upsert operations for incremental sync
|
||||
|
||||
Parquet format advantages:
|
||||
- Columnar storage -> faster analytical queries
|
||||
- Compression -> smaller size than CSV
|
||||
- Schema enforcement -> type safety
|
||||
- Metadata support -> self-documenting
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, List, Any
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_date_columns_to_date32(table: pa.Table, date_columns: List[str]) -> pa.Table:
    """
    Turn the listed timestamp/string columns into DATE32 columns.

    Shared between csv_to_parquet() and partitioned sync.
    String values that cannot be parsed as dates (like '0000-00-00') become
    NULL while the column type is still DATE32. For non-string columns a
    failed cast is logged and the column keeps its original type.

    Args:
        table: PyArrow table to convert
        date_columns: List of column names to convert to DATE32

    Returns:
        PyArrow table with converted date columns (the input table itself
        when date_columns is empty)
    """
    if not date_columns:
        return table

    date32 = pa.date32()
    out_fields = []
    out_columns = []

    # Single pass: decide the output column and field for each input column
    for position, field in enumerate(table.schema):
        col = table.column(position)

        if field.name not in date_columns:
            out_columns.append(col)
            out_fields.append(field)
            continue

        date_field = pa.field(field.name, date32)

        # All nulls - nothing to parse, emit a typed null column
        if col.null_count == len(col):
            out_columns.append(pa.nulls(len(col), type=date32))
            out_fields.append(date_field)
            continue

        if pa.types.is_string(col.type) or pa.types.is_large_string(col.type):
            # Strings go through pandas so invalid values are coerced to NaT
            # (NULL) instead of failing the whole column
            raw = col.to_pandas()
            parsed = pd.to_datetime(raw, errors='coerce', format='mixed')

            # Values that were present but did not survive parsing
            invalid_count = parsed.isna().sum() - raw.isna().sum()
            if invalid_count > 0:
                rejected = raw.notna() & parsed.isna()
                examples = raw[rejected].head(3).tolist()
                logger.warning(
                    f"Column '{field.name}': {invalid_count} invalid date values converted to NULL. "
                    f"Examples: {examples}"
                )

            # Drop the time component, then hand back to PyArrow
            out_columns.append(pa.array(parsed.dt.date, type=date32))
            out_fields.append(date_field)
            continue

        # Already a timestamp/date type: a plain cast is enough
        try:
            out_columns.append(col.cast(date32))
            out_fields.append(date_field)
        except Exception as e:
            logger.warning(
                f"Column '{field.name}': Failed to cast to date32, keeping original type. Error: {e}"
            )
            out_columns.append(col)
            out_fields.append(field)

    # Rebuild table with new schema, carrying over the schema metadata
    rebuilt_schema = pa.schema(out_fields, metadata=table.schema.metadata)
    return pa.Table.from_arrays(out_columns, schema=rebuilt_schema)
|
||||
|
||||
|
||||
def apply_schema_to_table(table: pa.Table, target_schema: pa.Schema) -> pa.Table:
    """
    Bring a PyArrow table in line with a target schema, tolerating mismatches.

    Per column:
    - not in the target schema: kept with its inferred type
    - null-typed: replaced by a typed null array (prevents DuckDB schema mismatch)
    - already the target type: kept
    - any other type: safe cast attempted; on failure the original type is
      kept and a warning is logged (no data loss)

    Args:
        table: PyArrow table to cast
        target_schema: Target PyArrow schema

    Returns:
        PyArrow table with applied schema (the input table itself when
        target_schema is empty or None)
    """
    if not target_schema:
        return table

    # Column name -> wanted type
    wanted_types = {target.name: target.type for target in target_schema}

    out_columns = []
    out_fields = []

    for position, field in enumerate(table.schema):
        col = table.column(position)
        col_name = field.name

        # Default outcome: column and field pass through untouched
        new_col, new_field = col, field

        if col_name in wanted_types:
            target_type = wanted_types[col_name]
            typed_field = pa.field(col_name, target_type)

            if pa.types.is_null(col.type):
                # Null-typed column -> typed null array of the same length
                new_col = pa.nulls(len(col), type=target_type)
                new_field = typed_field
                logger.debug(f"Column '{col_name}': converted null type to {target_type}")
            elif col.type == target_type:
                # Types agree: keep the data, use the target field
                new_field = typed_field
            else:
                # Type mismatch: only a safe cast is attempted
                try:
                    new_col = col.cast(target_type, safe=True)
                    new_field = typed_field
                    logger.debug(f"Column '{col_name}': cast from {col.type} to {target_type}")
                except Exception as e:
                    logger.warning(
                        f"Column '{col_name}': cannot cast from {col.type} to {target_type}, "
                        f"keeping original type. Error: {e}"
                    )
                    new_col, new_field = col, field

        out_columns.append(new_col)
        out_fields.append(new_field)

    # Rebuild table with new schema, carrying over the schema metadata
    rebuilt_schema = pa.schema(out_fields, metadata=table.schema.metadata)
    return pa.Table.from_arrays(out_columns, schema=rebuilt_schema)
|
||||
|
||||
|
||||
def _convert_column(series: pd.Series, dtype: str, col_name: str = "") -> pd.Series:
|
||||
"""
|
||||
Convert pandas Series to dtype, handling empty strings.
|
||||
|
||||
Empty strings become NA/NaN for non-string types.
|
||||
Logs warning if invalid (non-empty) values are found.
|
||||
|
||||
Args:
|
||||
series: Input pandas Series
|
||||
dtype: Target dtype (e.g., "Int64", "float64", "boolean")
|
||||
col_name: Column name for logging
|
||||
|
||||
Returns:
|
||||
Converted Series
|
||||
"""
|
||||
# Replace empty strings with NA for non-string types
|
||||
if dtype != "object":
|
||||
series = series.replace('', pd.NA)
|
||||
|
||||
# Numeric types - use errors='coerce' but log invalid values
|
||||
if dtype in ("Int64", "float64", "Float64"):
|
||||
# Count non-null values before conversion
|
||||
non_null_before = series.notna().sum()
|
||||
|
||||
converted = pd.to_numeric(series, errors='coerce')
|
||||
|
||||
# Count how many became NA after conversion (excluding already NA)
|
||||
non_null_after = converted.notna().sum()
|
||||
invalid_count = non_null_before - non_null_after
|
||||
|
||||
if invalid_count > 0:
|
||||
# Find examples of invalid values
|
||||
invalid_mask = series.notna() & converted.isna()
|
||||
examples = series[invalid_mask].head(3).tolist()
|
||||
logger.warning(
|
||||
f"Column '{col_name}': {invalid_count} invalid values converted to NULL. "
|
||||
f"Examples: {examples}"
|
||||
)
|
||||
|
||||
return converted.astype(dtype)
|
||||
|
||||
# Boolean type - map string representations
|
||||
if dtype == "boolean":
|
||||
# If pandas already parsed as bool, just convert to nullable boolean
|
||||
if series.dtype == bool or series.dtype == 'object' and series.dropna().apply(lambda x: isinstance(x, bool)).all():
|
||||
return series.astype(dtype)
|
||||
|
||||
bool_map = {
|
||||
'true': True, 'false': False,
|
||||
'True': True, 'False': False,
|
||||
'TRUE': True, 'FALSE': False,
|
||||
'1': True, '0': False,
|
||||
'yes': True, 'no': False,
|
||||
'Yes': True, 'No': False,
|
||||
'YES': True, 'NO': False,
|
||||
}
|
||||
# Log unknown values (non-empty strings that aren't in bool_map)
|
||||
known_values = set(bool_map.keys())
|
||||
non_na_values = series.dropna()
|
||||
unknown = non_na_values[~non_na_values.isin(known_values)]
|
||||
if len(unknown) > 0:
|
||||
examples = unknown.head(3).tolist()
|
||||
logger.warning(
|
||||
f"Column '{col_name}': {len(unknown)} unknown boolean values converted to NULL. "
|
||||
f"Examples: {examples}"
|
||||
)
|
||||
return series.map(bool_map).astype(dtype)
|
||||
|
||||
# Default: direct conversion
|
||||
return series.astype(dtype)
|
||||
|
||||
|
||||
class ParquetManager:
    """
    Manager for Parquet files.

    Offers:
    - CSV -> Parquet conversion
    - Information about existing Parquet files
    - Merge/upsert for incremental sync
    """

    def __init__(self):
        """Set up the manager with its compression codec."""
        # snappy is fast and has a good compression ratio
        codec = "snappy"
        self.compression = codec
|
||||
|
||||
def csv_to_parquet(
    self,
    csv_path: Path,
    parquet_path: Path,
    dtypes: Optional[Dict[str, str]] = None,
    table_id: Optional[str] = None,
    parse_dates: Optional[List[str]] = None,
    date_columns: Optional[List[str]] = None,
    pyarrow_schema: Optional[pa.Schema] = None
) -> Dict[str, Any]:
    """
    Convert CSV file to Parquet format.

    The CSV is loaded with every column as string, then the requested
    dtypes are applied column by column (a dtype that cannot be applied is
    logged and the column is left as loaded). DATE columns are converted to
    DATE32, the optional PyArrow schema is applied, and ``created_at`` /
    ``table_id`` are stored as schema metadata before writing.

    Args:
        csv_path: Path to source CSV file
        parquet_path: Path where to save Parquet file
        dtypes: Dictionary with data types for columns (pandas dtypes)
        table_id: Table ID for metadata
        parse_dates: List of columns with dates/timestamps to parse. Names
            not present in the CSV header are ignored. If omitted, columns
            whose dtype contains "datetime" are parsed instead.
        date_columns: List of DATE-only columns (without time) to convert to PyArrow DATE32
        pyarrow_schema: Optional PyArrow schema to enforce (prevents null-type columns)

    Returns:
        Dictionary with conversion information:
        - rows: Number of rows
        - columns: Number of columns
        - csv_size_bytes: Source CSV file size
        - parquet_size_bytes: Parquet file size
        - compression_ratio: Compression ratio (CSV size / Parquet size)
        - parquet_path: Path of the written Parquet file (string)

    Raises:
        Exception: If conversion fails
    """
    import csv  # local import: only needed to parse the header row

    logger.info(f"Converting CSV -> Parquet: {csv_path.name}")

    try:
        # Load CSV into pandas DataFrame
        # IMPORTANT: Use dtype=str to prevent pandas from guessing types
        # We apply our own types from Keboola metadata using _convert_column
        read_kwargs = {"dtype": str}

        # Get actual column names from the CSV header first. csv.reader
        # handles quoted names that contain commas or escaped quotes, which
        # a plain split(',') would cut apart. utf-8-sig drops a leading BOM.
        with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
            csv_columns = set(next(csv.reader(f), []))

        # Parse datetime columns - only those that exist in CSV
        if parse_dates:
            valid_parse_dates = [col for col in parse_dates if col in csv_columns]
            if valid_parse_dates:
                read_kwargs["parse_dates"] = valid_parse_dates
        elif dtypes:
            # Auto-detect datetime columns from dtypes (only existing columns)
            datetime_cols = [
                col for col, dtype in dtypes.items()
                if "datetime" in dtype and col in csv_columns
            ]
            if datetime_cols:
                read_kwargs["parse_dates"] = datetime_cols

        df = pd.read_csv(csv_path, **read_kwargs)

        logger.debug(f"CSV loaded: {len(df)} rows, {len(df.columns)} columns")

        # Apply dtypes using _convert_column to handle empty strings
        if dtypes:
            for col, dtype in dtypes.items():
                if col in df.columns and "datetime" not in dtype:
                    try:
                        df[col] = _convert_column(df[col], dtype, col_name=col)
                    except Exception as e:
                        logger.warning(
                            f"Failed to apply dtype {dtype} to column {col}: {e}"
                        )

        # Custom schema metadata
        metadata = {
            "created_at": datetime.now().isoformat(),
        }
        if table_id:
            metadata["table_id"] = table_id

        # Convert to PyArrow Table
        table = pa.Table.from_pandas(df)

        # Convert DATE columns from timestamp/string to DATE32
        if date_columns:
            table = convert_date_columns_to_date32(table, date_columns)

        # Apply explicit schema (prevents null-type columns)
        if pyarrow_schema:
            table = apply_schema_to_table(table, pyarrow_schema)

        # Merge custom metadata into the existing schema metadata
        # (schema metadata keys and values are bytes)
        existing_metadata = table.schema.metadata or {}
        new_metadata = {
            **existing_metadata,
            **{k.encode(): v.encode() for k, v in metadata.items()}
        }
        table = table.replace_schema_metadata(new_metadata)

        # Ensure output folder exists
        parquet_path.parent.mkdir(parents=True, exist_ok=True)

        # Write to Parquet
        pq.write_table(
            table,
            parquet_path,
            compression=self.compression
        )

        # Get file sizes for compression ratio
        csv_size = csv_path.stat().st_size
        parquet_size = parquet_path.stat().st_size
        compression_ratio = csv_size / parquet_size if parquet_size > 0 else 0

        result = {
            "rows": len(df),
            "columns": len(df.columns),
            "csv_size_bytes": csv_size,
            "parquet_size_bytes": parquet_size,
            "compression_ratio": compression_ratio,
            "parquet_path": str(parquet_path)
        }

        logger.info(
            f"Parquet created: {len(df)} rows, "
            f"{parquet_size / 1024 / 1024:.2f} MB, "
            f"compression {compression_ratio:.2f}x"
        )

        return result

    except Exception as e:
        logger.error(f"Error converting CSV -> Parquet: {e}")
        raise
|
||||
|
||||
def get_parquet_info(self, parquet_path: Path) -> Optional[Dict[str, Any]]:
    """
    Describe an existing Parquet file.

    Args:
        parquet_path: Path to Parquet file

    Returns:
        None when the file is missing or cannot be read, otherwise a
        dictionary with:
        - rows: Number of rows
        - columns: Number of columns
        - file_size_bytes: File size
        - modified_at: Last modification timestamp (ISO format)
        - schema: PyArrow schema
        - metadata: Custom schema metadata (the "pandas" entry is dropped)
        - parquet_path: The path, as a string
    """
    if not parquet_path.exists():
        return None

    try:
        # Only file-level metadata is inspected here; row data is not loaded
        pf = pq.ParquetFile(parquet_path)

        size_bytes = parquet_path.stat().st_size
        mtime = datetime.fromtimestamp(parquet_path.stat().st_mtime)

        arrow_schema = pf.schema_arrow
        user_metadata = {}
        if arrow_schema.metadata:
            for raw_key, raw_value in arrow_schema.metadata.items():
                # Skip the pandas-internal metadata entry
                if raw_key.decode() in ["pandas"]:
                    continue
                user_metadata[raw_key.decode()] = raw_value.decode()

        return {
            "rows": pf.metadata.num_rows,
            "columns": len(arrow_schema),
            "file_size_bytes": size_bytes,
            "modified_at": mtime.isoformat(),
            "schema": arrow_schema,
            "metadata": user_metadata,
            "parquet_path": str(parquet_path)
        }

    except Exception as e:
        logger.error(f"Error reading Parquet info: {e}")
        return None
|
||||
|
||||
def merge_parquet(
    self,
    existing_parquet: Path,
    new_csv: Path,
    output_parquet: Path,
    primary_key: List[str],
    dtypes: Optional[Dict[str, str]] = None,
    parse_dates: Optional[List[str]] = None,
    date_columns: Optional[List[str]] = None,
    pyarrow_schema: Optional[pa.Schema] = None
) -> Dict[str, Any]:
    """
    Merge new data from CSV into existing Parquet file.

    Performs upsert operation:
    - Rows with existing primary_key are updated
    - Rows with new primary_key are added

    Args:
        existing_parquet: Path to existing Parquet file
        new_csv: Path to CSV with new data
        output_parquet: Path where to save resulting Parquet
        primary_key: List of column names forming the primary key (supports composite PK)
        dtypes: Dictionary with data types
        parse_dates: List of datetime columns
        date_columns: List of DATE-only columns to convert to PyArrow DATE32
        pyarrow_schema: Optional PyArrow schema to enforce (prevents null-type columns)

    Returns:
        Dictionary with information:
        - total_rows: Total number of rows after merge
        - total_columns: Number of columns after merge
        - added_rows: Number of newly added rows
        - updated_rows: Number of updated rows
        - unchanged_rows: Number of unchanged rows
        - parquet_path: Path of the written file, as a string

    Raises:
        ValueError: If a primary key column is missing from either input
        Exception: If merge fails for any other reason (logged and re-raised)
    """
    pk_str = ", ".join(primary_key)
    logger.info(
        f"Merging Parquet: {existing_parquet.name} + {new_csv.name} -> {output_parquet.name} (PK: {pk_str})"
    )

    try:
        # Load existing Parquet
        existing_df = pd.read_parquet(existing_parquet)
        original_count = len(existing_df)

        logger.debug(f"Existing data: {original_count} rows")

        # Load new data from CSV
        # IMPORTANT: Use dtype=str to prevent pandas from guessing types
        # We apply our own types from Keboola metadata using _convert_column
        read_kwargs = {"dtype": str}

        if parse_dates:
            read_kwargs["parse_dates"] = parse_dates
        elif dtypes:
            datetime_cols = [
                col for col, dtype in dtypes.items()
                if "datetime" in dtype
            ]
            if datetime_cols:
                read_kwargs["parse_dates"] = datetime_cols

        new_df = pd.read_csv(new_csv, **read_kwargs)

        # Apply dtypes using _convert_column to handle empty strings
        if dtypes:
            for col, dtype in dtypes.items():
                if col in new_df.columns and "datetime" not in dtype:
                    try:
                        new_df[col] = _convert_column(new_df[col], dtype, col_name=col)
                    except Exception as e:
                        logger.warning(
                            f"Failed to apply dtype {dtype} to column {col}: {e}"
                        )

        new_count = len(new_df)

        logger.debug(f"New data: {new_count} rows")

        # Check that all primary_key columns exist in both dataframes
        for pk_col in primary_key:
            if pk_col not in existing_df.columns:
                raise ValueError(
                    f"Primary key column '{pk_col}' not found in existing data"
                )
            if pk_col not in new_df.columns:
                raise ValueError(
                    f"Primary key column '{pk_col}' not found in new data"
                )

        # Perform upsert: concat and then drop_duplicates with keep='last'
        # Keep='last' means that new data (which is second) will overwrite old
        merged_df = pd.concat([existing_df, new_df], ignore_index=True)
        merged_df = merged_df.drop_duplicates(subset=primary_key, keep='last')
        # drop_duplicates leaves a gapped integer index; reset it so it is
        # not carried into the Parquet file as a data column.
        merged_df = merged_df.reset_index(drop=True)

        # Calculate statistics
        # NOTE: these are derived from row counts only, so they are exact
        # only when keys are unique within each of the two inputs.
        final_count = len(merged_df)
        added_rows = final_count - original_count
        # Updated rows = rows that were in both datasets
        updated_rows = new_count - added_rows if added_rows < new_count else 0
        unchanged_rows = original_count - updated_rows

        logger.info(
            f"Merge completed: {final_count} total rows "
            f"(+{added_rows} new, ~{updated_rows} updates)"
        )

        # Save as new Parquet
        # Prepare metadata
        metadata = {
            "created_at": datetime.now().isoformat(),
            "merged_from": new_csv.name
        }

        # preserve_index=False: the positional index carries no information
        # and would otherwise be written as "__index_level_0__".
        table = pa.Table.from_pandas(merged_df, preserve_index=False)

        # Convert DATE columns from timestamp to DATE32
        if date_columns:
            table = convert_date_columns_to_date32(table, date_columns)

        # Apply explicit schema (prevents null-type columns)
        if pyarrow_schema:
            table = apply_schema_to_table(table, pyarrow_schema)

        # Add custom metadata on top of what the schema already carries,
        # instead of replacing it wholesale.
        existing_metadata = table.schema.metadata or {}
        new_metadata = {
            **existing_metadata,
            **{k.encode(): v.encode() for k, v in metadata.items()}
        }
        table = table.replace_schema_metadata(new_metadata)

        # Ensure output folder exists
        output_parquet.parent.mkdir(parents=True, exist_ok=True)

        # Write Parquet
        pq.write_table(table, output_parquet, compression=self.compression)

        result = {
            "total_rows": final_count,
            "total_columns": len(merged_df.columns),
            "added_rows": added_rows,
            "updated_rows": updated_rows,
            "unchanged_rows": unchanged_rows,
            "parquet_path": str(output_parquet)
        }

        return result

    except Exception as e:
        logger.error(f"Error merging Parquet: {e}")
        raise
|
||||
|
||||
def validate_parquet(self, parquet_path: Path) -> bool:
    """
    Validate that Parquet file is readable and not corrupted.

    The schema is parsed and the first record batch is decoded. The rest of
    the file is not read, so validation does not load the whole table into
    memory.

    Args:
        parquet_path: Path to Parquet file

    Returns:
        True if the schema and first batch could be read, False otherwise
        (the error is logged)
    """
    try:
        # Try to load schema (fast operation)
        parquet_file = pq.ParquetFile(parquet_path)
        _ = parquet_file.schema_arrow

        # Try to decode the first row only (data validation).
        # The None default covers a valid file with zero rows.
        next(parquet_file.iter_batches(batch_size=1), None)

        return True

    except Exception as e:
        logger.error(f"Parquet validation failed: {e}")
        return False
|
||||
|
||||
|
||||
def create_parquet_manager() -> ParquetManager:
    """
    Build a ParquetManager using its default constructor arguments.

    Returns:
        A new ParquetManager instance
    """
    manager = ParquetManager()
    return manager
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test for ParquetManager: runs conversion, info, validation,
    # merge and empty-string handling against files in a temporary directory.
    # Exits with status 1 when any step raises, so the result is scriptable.
    import sys
    import tempfile
    import traceback

    print("📦 Testing Parquet manager...")

    try:
        manager = create_parquet_manager()

        # Create test CSV
        with tempfile.TemporaryDirectory() as tmp_name:
            tmpdir = Path(tmp_name)

            # Test data
            test_csv = tmpdir / "test.csv"
            test_csv.write_text(
                "id,name,value,created_at\n"
                "1,Alice,100,2026-01-01\n"
                "2,Bob,200,2026-01-02\n"
                "3,Charlie,300,2026-01-03\n"
            )

            test_parquet = tmpdir / "test.parquet"

            # Test 1: CSV -> Parquet
            print("\n1️⃣ Testing CSV -> Parquet conversion...")
            result = manager.csv_to_parquet(
                csv_path=test_csv,
                parquet_path=test_parquet,
                dtypes={"id": "Int64", "name": "object", "value": "Int64"},
                parse_dates=["created_at"],
                table_id="test.table"
            )
            print(" ✅ Conversion OK:")
            print(f" Rows: {result['rows']}")
            print(f" Compression: {result['compression_ratio']:.2f}x")

            # Test 2: Parquet info
            print("\n2️⃣ Testing Parquet info...")
            info = manager.get_parquet_info(test_parquet)
            if info:
                print(" ✅ Info loaded:")
                print(f" Rows: {info['rows']}")
                print(f" Metadata: {info['metadata']}")

            # Test 3: Validation
            print("\n3️⃣ Testing validation...")
            if manager.validate_parquet(test_parquet):
                print(" ✅ Parquet is valid!")

            # Test 4: Merge
            print("\n4️⃣ Testing merge...")
            # Create update CSV
            update_csv = tmpdir / "update.csv"
            update_csv.write_text(
                "id,name,value,created_at\n"
                "2,Bob Updated,250,2026-01-04\n"  # Update
                "4,David,400,2026-01-05\n"  # New
            )

            merged_parquet = tmpdir / "merged.parquet"
            merge_result = manager.merge_parquet(
                existing_parquet=test_parquet,
                new_csv=update_csv,
                output_parquet=merged_parquet,
                primary_key=["id"],  # Now uses list
                dtypes={"id": "Int64", "name": "object", "value": "Int64"},
                parse_dates=["created_at"]
            )
            print(" ✅ Merge OK:")
            print(f" Total rows: {merge_result['total_rows']}")
            print(f" Added: {merge_result['added_rows']}")
            print(f" Updated: {merge_result['updated_rows']}")

            # Test 5: Empty string handling
            print("\n5️⃣ Testing empty string handling...")
            empty_csv = tmpdir / "empty_strings.csv"
            empty_csv.write_text(
                "id,is_active,revenue,note\n"
                '1,true,100.5,text\n'
                '2,,200.0,\n'  # Empty boolean and string
                '3,false,,note\n'  # Empty float
                '4,TRUE,N/A,\n'  # Invalid float value
            )

            empty_parquet = tmpdir / "empty_strings.parquet"
            result = manager.csv_to_parquet(
                csv_path=empty_csv,
                parquet_path=empty_parquet,
                dtypes={
                    "id": "Int64",
                    "is_active": "boolean",
                    "revenue": "float64",
                    "note": "object"
                },
                table_id="test.empty_strings"
            )
            print(" ✅ Conversion with empty strings OK:")
            print(f" Rows: {result['rows']}")

            # Verify the data
            df = pd.read_parquet(empty_parquet)
            print(f" Dtypes: {dict(df.dtypes)}")
            print(f" is_active nulls: {df['is_active'].isna().sum()}")
            print(f" revenue nulls: {df['revenue'].isna().sum()}")

        print("\n✅ All tests passed!")

    except Exception as e:
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        # Report the failure to the caller instead of exiting 0
        sys.exit(1)
|
||||
|
|
@ -1,636 +0,0 @@
|
|||
"""
|
||||
Remote Query - Execute DuckDB queries spanning local Parquet + remote BigQuery tables.
|
||||
|
||||
Provides a server-side CLI for the AI agent to run SQL queries that JOIN local
|
||||
(Parquet-backed) tables with on-demand BigQuery results. Designed for tables too
|
||||
large to sync locally (e.g., daily_deal_traffic: ~3M rows/day).
|
||||
|
||||
Two-phase query protocol:
|
||||
1. BQ sub-queries (--register-bq "alias=SQL") run on BigQuery, results registered
|
||||
as DuckDB views via PyArrow (reuses register_bq_table from duckdb_manager).
|
||||
2. DuckDB SQL (--sql) runs against local Parquet views + registered BQ results.
|
||||
|
||||
Usage:
|
||||
python -m src.remote_query \\
|
||||
--sql "SELECT ... FROM order_economics o JOIN traffic t ON ..." \\
|
||||
--register-bq "traffic=SELECT ... FROM \\`project.dataset.table\\` WHERE ..." \\
|
||||
--format table
|
||||
|
||||
Safety features:
|
||||
- COUNT(*) pre-check before fetching BQ data
|
||||
- Memory estimation (refuses queries > 2 GB estimated)
|
||||
- Configurable row limits (per BQ sub-query and final result)
|
||||
- Query timeout support
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import duckdb
|
||||
|
||||
from config.loader import get_instance_value
|
||||
from scripts.duckdb_manager import (
|
||||
create_local_views,
|
||||
register_bq_table,
|
||||
_create_bq_client,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RemoteQueryError(Exception):
    """Signals a failure while preparing or running a remote query."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_remote_query_config() -> dict:
    """Load remote_query settings from instance.yaml with defaults.

    Uses raw YAML loading instead of load_instance_config() to avoid
    requiring webapp secrets (WEBAPP_SECRET_KEY etc.) that analysts
    don't have access to.

    The file is looked up as ``$CONFIG_DIR/instance.yaml`` (CONFIG_DIR
    defaults to ``./config``). A missing, unreadable or non-mapping file is
    treated as empty, so every setting falls back to its default.

    Returns:
        Dict with keys timeout_seconds, max_result_rows,
        max_bq_registration_rows, default_format and output_dir.
    """
    import yaml as _yaml

    defaults = {
        "timeout_seconds": 300,
        "max_result_rows": 100_000,
        "max_bq_registration_rows": 500_000,
        "default_format": "table",
        "output_dir": "/tmp/remote_query",
    }

    instance_config: dict = {}
    yaml_path = Path(os.environ.get("CONFIG_DIR", "./config")) / "instance.yaml"
    if yaml_path.exists():
        try:
            # Explicit encoding so parsing does not depend on the locale
            with open(yaml_path, encoding="utf-8") as f:
                loaded = _yaml.safe_load(f) or {}
        except Exception as e:
            logger.warning("Could not load instance.yaml: %s. Using defaults.", e)
        else:
            if isinstance(loaded, dict):
                instance_config = loaded
            else:
                # A list or scalar at the YAML root cannot hold settings
                logger.warning(
                    "instance.yaml does not contain a mapping. Using defaults."
                )

    return {
        key: get_instance_value(
            instance_config, "remote_query", key, default=default,
        )
        for key, default in defaults.items()
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BQ registration with safety checks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _validate_bq_result_size(
    bq_client, sql: str, alias: str, max_rows: int,
) -> int:
    """Execute COUNT(*) on the BQ sub-query before fetching all rows.

    Args:
        bq_client: BigQuery client instance
        sql: The BQ SQL query to count
        alias: Alias name (for error messages)
        max_rows: Maximum allowed rows

    Returns:
        Row count

    Raises:
        RemoteQueryError: If count exceeds max_rows
    """
    # A trailing ';' is not valid inside a sub-query, and a trailing
    # '-- comment' would swallow the closing parenthesis if it shared a line
    # with it. Strip the former and put the parenthesis on its own line.
    inner_sql = sql.strip().rstrip(";").rstrip()
    count_sql = f"SELECT COUNT(*) AS cnt FROM (\n{inner_sql}\n)"
    _log_progress(f" Counting rows for '{alias}'...")

    job = bq_client.query(count_sql)
    result = job.result()
    row_count = next(iter(result))[0]

    if row_count > max_rows:
        raise RemoteQueryError(
            f"BQ sub-query '{alias}' would return {row_count:,} rows "
            f"(limit: {max_rows:,}). Add more WHERE filters or GROUP BY "
            f"to reduce the result set."
        )

    return row_count
|
||||
|
||||
|
||||
def _estimate_memory_mb(row_count: int, column_count: int) -> float:
|
||||
"""Estimate memory usage in MB for a PyArrow table.
|
||||
|
||||
Uses ~50 bytes per cell as a rough average across data types.
|
||||
"""
|
||||
return (row_count * column_count * 50) / (1024 * 1024)
|
||||
|
||||
|
||||
def _register_bq_views(
    conn: duckdb.DuckDBPyConnection,
    registrations: list[tuple[str, str]],
    max_bq_rows: int,
    timeout_seconds: int,
    quiet: bool = False,
) -> dict[str, int]:
    """Run each BQ sub-query and expose its result as a DuckDB view.

    Every sub-query goes through three steps: a COUNT(*) check against
    max_bq_rows, a memory estimate from row and column counts, and the
    actual fetch and registration via register_bq_table.

    Args:
        conn: DuckDB connection
        registrations: List of (alias, bq_sql) tuples
        max_bq_rows: Max rows per sub-query
        timeout_seconds: BQ job timeout
        quiet: Suppress progress messages

    Returns:
        Dict of {alias: row_count}

    Raises:
        RemoteQueryError: If BIGQUERY_PROJECT is unset, or a sub-query
            exceeds the row limit or the estimated-memory limit
    """
    # NOTE(review): timeout_seconds and quiet are accepted but never used in
    # this function - confirm whether the timeout should reach the BQ jobs.
    if not registrations:
        return {}

    project = os.environ.get("BIGQUERY_PROJECT")
    if not project:
        raise RemoteQueryError(
            "BIGQUERY_PROJECT environment variable not set. "
            "Required for BigQuery sub-queries."
        )

    memory_limit_mb = 2048  # 2 GB = 25% of 8 GB server RAM
    client = _create_bq_client(project)
    registered: dict = {}

    for alias, bq_sql in registrations:
        started = time.time()

        # Step 1: COUNT(*) safety check
        row_count = _validate_bq_result_size(client, bq_sql, alias, max_bq_rows)
        _log_progress(f" '{alias}': {row_count:,} rows (within limit)")

        # Step 2: memory estimate; a LIMIT 0 query yields the schema cheaply
        probe_job = client.query(f"SELECT * FROM ({bq_sql}) LIMIT 0")
        col_count = len(probe_job.result().schema)
        estimated_mb = _estimate_memory_mb(row_count, col_count)

        if estimated_mb > memory_limit_mb:
            raise RemoteQueryError(
                f"BQ sub-query '{alias}' estimated memory: {estimated_mb:.0f} MB "
                f"({row_count:,} rows x {col_count} cols). "
                f"Limit is 2048 MB. Add more aggregation or filters."
            )

        # Step 3: execute and register
        _log_progress(f" Fetching '{alias}' ({row_count:,} rows, ~{estimated_mb:.0f} MB)...")
        actual_rows = register_bq_table(
            conn=conn,
            table_id=f"bq_registration.{alias}",
            view_name=alias,
            sql=bq_sql,
            bq_project=project,
        )

        elapsed = time.time() - started
        _log_progress(f" '{alias}' registered: {actual_rows:,} rows in {elapsed:.1f}s")
        registered[alias] = actual_rows

    return registered
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local view setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _setup_local_views(
    conn: duckdb.DuckDBPyConnection,
    data_dir: str,
    quiet: bool = False,
) -> list[str]:
    """Delegate to create_local_views and return the created view names.

    Args:
        conn: DuckDB connection
        data_dir: Path to data directory (e.g., "/data/src_data")
        quiet: Suppress progress messages (passed on as verbose=not quiet)

    Returns:
        List of created view names; the skipped list is discarded
    """
    view_names, _skipped = create_local_views(
        conn=conn,
        data_dir=data_dir,
        verbose=not quiet,
    )
    return view_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Output formatting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _print_table(columns: list[str], rows: list[tuple]) -> None:
|
||||
"""Print an aligned ASCII table to stdout."""
|
||||
if not rows:
|
||||
print("(empty result)")
|
||||
return
|
||||
|
||||
# Calculate column widths
|
||||
str_rows = [[str(v) if v is not None else "NULL" for v in row] for row in rows]
|
||||
widths = [len(col) for col in columns]
|
||||
for row in str_rows:
|
||||
for i, val in enumerate(row):
|
||||
widths[i] = max(widths[i], len(val))
|
||||
|
||||
# Header
|
||||
header = " | ".join(col.ljust(widths[i]) for i, col in enumerate(columns))
|
||||
separator = "-+-".join("-" * widths[i] for i in range(len(columns)))
|
||||
print(header)
|
||||
print(separator)
|
||||
|
||||
# Rows
|
||||
for row in str_rows:
|
||||
line = " | ".join(val.ljust(widths[i]) for i, val in enumerate(row))
|
||||
print(line)
|
||||
|
||||
print(f"\n({len(rows)} rows)")
|
||||
|
||||
|
||||
def _format_output(
    conn: duckdb.DuckDBPyConnection,
    sql: str,
    fmt: str,
    output_path: Optional[str],
    max_rows: int,
) -> None:
    """Execute final SQL and output results in the requested format.

    The query is run once with LIMIT max_rows + 1; the extra row is only
    used to detect truncation and is never emitted.

    Args:
        conn: DuckDB connection with all views registered
        sql: The final DuckDB SQL query
        fmt: Output format (table, csv, json, parquet)
        output_path: File path for file-based outputs. For csv/json, None
            means stdout; for parquet, None means "result.parquet" inside
            the configured output_dir.
        max_rows: Maximum rows to return

    Raises:
        RemoteQueryError: If fmt is not one of the supported formats
    """
    # Reject a bad format before spending time on the query
    if fmt not in ("table", "csv", "json", "parquet"):
        raise RemoteQueryError(f"Unknown format: {fmt}")

    # Add LIMIT to prevent runaway results
    limited_sql = f"SELECT * FROM ({sql}) AS _rq LIMIT {max_rows + 1}"
    truncation_warning = (
        f" WARNING: Result truncated to {max_rows:,} rows. "
        f"Add more filters or increase --max-rows."
    )

    if fmt == "parquet":
        import pyarrow.parquet as pq

        # Go straight to Arrow: no Python row tuples and no second
        # execution of the query.
        arrow_result = conn.execute(limited_sql).fetch_arrow_table()
        if arrow_result.num_rows > max_rows:
            arrow_result = arrow_result.slice(0, max_rows)
            _log_progress(truncation_warning)

        if not output_path:
            output_path = str(Path(_load_remote_query_config()["output_dir"]) / "result.parquet")

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        pq.write_table(arrow_result, output_path)
        _log_progress(
            f" Parquet written: {output_path} "
            f"({arrow_result.num_rows} rows, {arrow_result.num_columns} cols)"
        )
        return

    result = conn.execute(limited_sql)
    columns = [desc[0] for desc in result.description]
    rows = result.fetchall()

    # Check if result exceeded limit
    if len(rows) > max_rows:
        rows = rows[:max_rows]
        _log_progress(truncation_warning)

    if fmt == "table":
        _print_table(columns, rows)

    elif fmt == "csv":
        if output_path:
            # newline="" is required by the csv module; encoding is explicit
            # so the file does not depend on the process locale
            with open(output_path, "w", newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                writer.writerow(columns)
                writer.writerows(rows)
            _log_progress(f" CSV written: {output_path} ({len(rows)} rows)")
        else:
            writer = csv.writer(sys.stdout)
            writer.writerow(columns)
            writer.writerows(rows)

    else:  # fmt == "json"
        data = [dict(zip(columns, row)) for row in rows]
        # default=str turns values json cannot encode natively into strings
        json_str = json.dumps(data, default=str, indent=2)
        if output_path:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(json_str)
            _log_progress(f" JSON written: {output_path} ({len(rows)} rows)")
        else:
            print(json_str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Progress logging (stderr so stdout stays clean for data)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_quiet_mode = False
|
||||
|
||||
|
||||
def _log_progress(msg: str) -> None:
|
||||
"""Print progress message to stderr (keeps stdout clean for data output)."""
|
||||
if not _quiet_mode:
|
||||
print(msg, file=sys.stderr)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main execution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def execute_remote_query(
    sql: str,
    bq_registrations: list[tuple[str, str]],
    fmt: str = "table",
    output: Optional[str] = None,
    max_rows: Optional[int] = None,
    max_bq_rows: Optional[int] = None,
    timeout: Optional[int] = None,
    data_dir: str = "/data/src_data",
    quiet: bool = False,
) -> None:
    """Run a remote query end to end on a fresh in-memory DuckDB connection.

    Falsy values for max_rows, max_bq_rows, timeout and fmt are replaced by
    the corresponding remote_query configuration values.

    Args:
        sql: DuckDB SQL query to execute
        bq_registrations: List of (alias, bq_sql) tuples
        fmt: Output format (table, csv, json, parquet)
        output: Output file path (for parquet/csv/json)
        max_rows: Max rows in final result
        max_bq_rows: Max rows per BQ sub-query
        timeout: Query timeout in seconds
        data_dir: Path to data directory
        quiet: Suppress progress messages
    """
    global _quiet_mode
    _quiet_mode = quiet

    settings = _load_remote_query_config()
    if not max_rows:
        max_rows = settings["max_result_rows"]
    if not max_bq_rows:
        max_bq_rows = settings["max_bq_registration_rows"]
    if not timeout:
        timeout = settings["timeout_seconds"]
    if not fmt:
        fmt = settings["default_format"]

    started = time.time()

    # Everything runs on a throwaway in-memory database
    conn = duckdb.connect(":memory:")

    try:
        # 1) Views over the local Parquet data
        _log_progress("Setting up local views...")
        ready_views = _setup_local_views(conn, data_dir, quiet=quiet)
        _log_progress(f" {len(ready_views)} local views ready")

        # 2) Views over BigQuery sub-query results
        if bq_registrations:
            _log_progress(f"Registering {len(bq_registrations)} BQ sub-queries...")
            registered = _register_bq_views(
                conn, bq_registrations, max_bq_rows, timeout, quiet=quiet,
            )
            for alias, count in registered.items():
                _log_progress(f" {alias}: {count:,} rows")

        # 3) The final query, rendered in the requested format
        _log_progress("Executing query...")
        _format_output(conn, sql, fmt, output, max_rows)

        duration = time.time() - started
        _log_progress(f"Done in {duration:.1f}s")

    finally:
        conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI argument parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_register_bq(value: str) -> tuple[str, str]:
|
||||
"""Parse --register-bq argument in 'alias=SQL' format.
|
||||
|
||||
Args:
|
||||
value: String in format "alias=SELECT ..."
|
||||
|
||||
Returns:
|
||||
Tuple of (alias, sql)
|
||||
|
||||
Raises:
|
||||
argparse.ArgumentTypeError: If format is invalid
|
||||
"""
|
||||
eq_pos = value.find("=")
|
||||
if eq_pos <= 0:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Invalid --register-bq format: '{value}'. "
|
||||
f"Expected: 'alias=SELECT ...' (e.g., 'traffic=SELECT report_date, ...')"
|
||||
)
|
||||
alias = value[:eq_pos].strip()
|
||||
sql = value[eq_pos + 1:].strip()
|
||||
if not sql:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Empty SQL in --register-bq for alias '{alias}'"
|
||||
)
|
||||
return alias, sql
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
    """Create the argparse parser that defines the remote_query command line."""
    usage_examples = """
Examples:
# Local-only query (no BigQuery):
python -m src.remote_query --sql "SELECT COUNT(*) FROM order_economics"

# Register BQ result and query it:
python -m src.remote_query \\
--register-bq "traffic=SELECT report_date, SUM(visitors) FROM \\`proj.ds.table\\` GROUP BY 1" \\
--sql "SELECT * FROM traffic ORDER BY report_date"

# JOIN local + remote:
python -m src.remote_query \\
--register-bq "traffic=SELECT ... GROUP BY ..." \\
--sql "SELECT o.*, t.visitors FROM order_economics o JOIN traffic t ON ..." \\
--format parquet --output /tmp/result.parquet
"""
    parser = argparse.ArgumentParser(
        description="Execute DuckDB queries spanning local Parquet + remote BigQuery tables",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    add = parser.add_argument

    # --sql stays optional here; main() enforces it when --stdin is absent
    add(
        "--sql",
        required=False,
        default=None,
        help="DuckDB SQL query (executed after all views are registered)",
    )
    add(
        "--register-bq",
        action="append",
        type=_parse_register_bq,
        default=[],
        metavar="ALIAS=SQL",
        dest="bq_registrations",
        help='Register BQ query result as DuckDB view. Format: "alias=BQ_SQL". Repeatable.',
    )
    add(
        "--format",
        choices=["table", "csv", "json", "parquet"],
        default=None,
        dest="fmt",
        help="Output format (default: from config or 'table')",
    )
    add(
        "--output",
        default=None,
        help="Output file path for parquet/csv/json (default: auto for parquet)",
    )

    # Integer limits that fall back to the configuration when omitted
    int_options = (
        ("--max-rows", "Max rows in final result (default: from config)"),
        ("--max-bq-rows", "Max rows per BQ sub-query (default: from config)"),
        ("--timeout", "Query timeout in seconds (default: from config)"),
    )
    for flag, help_text in int_options:
        add(flag, type=int, default=None, help=help_text)

    add(
        "--data-dir",
        default="/data/src_data",
        help="Parquet data directory (default: /data/src_data)",
    )
    add(
        "--quiet",
        action="store_true",
        help="Suppress progress messages (stderr)",
    )
    add(
        "--stdin",
        action="store_true",
        help="Read query spec from stdin as JSON. Avoids shell escaping issues.",
    )
    return parser
|
||||
|
||||
|
||||
def _parse_stdin_query() -> dict:
    """Parse query specification from stdin JSON.

    Expected format:
        {
            "sql": "SELECT ... FROM ...",
            "register_bq": {"alias": "BQ SQL", ...},
            "format": "table",
            "output": "/path/to/file",
            "max_rows": 100000,
            "max_bq_rows": 500000
        }

    Returns:
        Dict with parsed query spec

    Raises:
        RemoteQueryError: If stdin is empty, is not valid JSON, is not a
            JSON object, or has no 'sql' field
    """
    raw = sys.stdin.read().strip()
    if not raw:
        raise RemoteQueryError("Empty stdin. Provide JSON query spec.")

    try:
        spec = json.loads(raw)
    except json.JSONDecodeError as e:
        # Keep the decoder error as the cause for debugging
        raise RemoteQueryError(f"Invalid JSON on stdin: {e}") from e

    # A JSON array or scalar would pass (or crash) the membership test below
    # and fail later with an unrelated error, so reject it here.
    if not isinstance(spec, dict):
        raise RemoteQueryError("JSON query spec must be an object.")

    if "sql" not in spec:
        raise RemoteQueryError("JSON must contain 'sql' field.")

    return spec
|
||||
|
||||
|
||||
def main(argv: Optional[list[str]] = None) -> None:
    """Run the remote-query command-line interface.

    Args:
        argv: Argument list passed straight to ``parser.parse_args``.

    Exit codes set through ``sys.exit``: 1 for a ``RemoteQueryError``,
    130 for ``KeyboardInterrupt`` and 2 for any other exception.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Log records go to stderr; --quiet raises the threshold to WARNING.
    verbosity = logging.WARNING if args.quiet else logging.INFO
    logging.basicConfig(
        level=verbosity,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        stream=sys.stderr,
    )

    try:
        if args.stdin:
            # JSON spec on stdin: no shell escaping needed. Values in the
            # spec win over the matching CLI flags; timeout, data_dir and
            # quiet always come from the command line.
            spec = _parse_stdin_query()
            registrations = list(spec.get("register_bq", {}).items())
            execute_remote_query(
                sql=spec["sql"],
                bq_registrations=registrations,
                fmt=spec.get("format", args.fmt),
                output=spec.get("output", args.output),
                max_rows=spec.get("max_rows", args.max_rows),
                max_bq_rows=spec.get("max_bq_rows", args.max_bq_rows),
                timeout=args.timeout,
                data_dir=args.data_dir,
                quiet=args.quiet,
            )
        else:
            # Without --stdin the query text has to come from --sql.
            if not args.sql:
                parser.error("--sql is required (or use --stdin for JSON input)")

            execute_remote_query(
                sql=args.sql,
                bq_registrations=args.bq_registrations,
                fmt=args.fmt,
                output=args.output,
                max_rows=args.max_rows,
                max_bq_rows=args.max_bq_rows,
                timeout=args.timeout,
                data_dir=args.data_dir,
                quiet=args.quiet,
            )
    except RemoteQueryError as err:
        print(f"ERROR: {err}", file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        print("\nInterrupted.", file=sys.stderr)
        sys.exit(130)
    except Exception as err:
        print(f"UNEXPECTED ERROR: {err}", file=sys.stderr)
        logger.exception("Unexpected error in remote_query")
        sys.exit(2)
|
||||
|
||||
|
||||
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -1,464 +0,0 @@
|
|||
"""
|
||||
Table Registry - Central source of truth for registered tables.
|
||||
|
||||
Manages table registrations in a JSON file. Generates data_description.md
|
||||
as a read-only output for downstream consumers (config.py, profiler.py, webapp).
|
||||
|
||||
Supports:
|
||||
- CRUD operations on registered tables
|
||||
- Folder mapping (bucket -> folder name)
|
||||
- Atomic persistence (tempfile + os.replace)
|
||||
- Optimistic locking (version field)
|
||||
- Audit logging
|
||||
- One-time migration from existing data_description.md
|
||||
- Generation of data_description.md with checksum header
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default registry location: the directory is taken from the REGISTRY_DIR
# environment variable (read once, when this module is imported) and falls
# back to /data/src_data/metadata.
_DEFAULT_REGISTRY_DIR = Path(
    os.environ.get("REGISTRY_DIR", "/data/src_data/metadata")
)
# File name of the registry JSON inside that directory.
_REGISTRY_FILENAME = "table_registry.json"
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
"""Return current UTC time as ISO string."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _atomic_write_json(path: Path, data: dict) -> None:
|
||||
"""Write JSON atomically using tempfile + os.replace."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
dir=str(path.parent), suffix=".tmp"
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w") as f:
|
||||
json.dump(data, f, indent=2, default=str)
|
||||
os.chmod(tmp_path, 0o660)
|
||||
os.replace(tmp_path, str(path))
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def _audit_log(registry_path: Path, action: str, details: dict) -> None:
    """Append one JSON line describing ``action`` to the registry audit log.

    The log file is ``registry_audit.log`` in the same directory as
    ``registry_path``. Each line holds a timestamp, the action name and the
    entries of ``details``. Any failure is logged as a warning and otherwise
    ignored, so auditing never interrupts the caller.
    """
    log_file_path = registry_path.parent / "registry_audit.log"
    try:
        record = {"timestamp": _now_iso(), "action": action, **details}
        with open(log_file_path, "a") as log_file:
            log_file.write(json.dumps(record, default=str) + "\n")
    except Exception as e:
        logger.warning(f"Could not write audit log: {e}")
|
||||
|
||||
|
||||
class TableRegistry:
    """Manages table registrations. Source of truth for what gets synced.

    The registry is a JSON document with three top-level keys:
    ``_metadata`` (version and timestamps), ``folder_mapping`` and
    ``tables`` (a list of table records). It is loaded once in ``__init__``
    and rewritten in full by every mutating method.

    NOTE(review): the optimistic-lock check compares ``expected_version``
    with the version held in memory by this instance; the file is not
    re-read before a write.
    """

    # Sync strategies accepted by register_table() and update_table().
    _VALID_SYNC_STRATEGIES = ("full_refresh", "incremental", "partitioned")

    # Table fields that update_table() is allowed to overwrite.
    _UPDATABLE_FIELDS = frozenset({
        "description", "primary_key", "sync_strategy",
        "incremental_window_days", "partition_by", "partition_granularity",
        "foreign_keys", "where_filters", "folder", "dataset",
        "initial_load_chunk_days", "max_history_days",
    })

    def __init__(self, registry_path: Path):
        """Load the registry stored at ``registry_path`` (empty if absent)."""
        self.registry_path = registry_path
        self._data = self._load()

    @classmethod
    def default(cls) -> "TableRegistry":
        """Create registry at the default location."""
        return cls(_DEFAULT_REGISTRY_DIR / _REGISTRY_FILENAME)

    # ── Persistence ──────────────────────────────────────────────────

    def _load(self) -> dict:
        """Load registry from disk. Returns empty structure if not found.

        A file that cannot be read or parsed is logged as an error and also
        yields the empty structure. Top-level keys missing from an otherwise
        valid file are filled in, so later code can rely on ``_metadata``,
        ``folder_mapping`` and ``tables`` being present.
        """
        if self.registry_path.exists():
            try:
                with open(self.registry_path) as f:
                    data = json.load(f)
                if not isinstance(data, dict):
                    raise ValueError("registry root must be a JSON object")
                # Older or hand-edited files may lack a section; without
                # these defaults _save()/register_table() raise KeyError.
                for key, fallback in self._empty_registry().items():
                    data.setdefault(key, fallback)
                logger.info(
                    f"Registry loaded: {len(data.get('tables', []))} tables"
                )
                return data
            except Exception as e:
                logger.error(f"Error loading registry: {e}")
        return self._empty_registry()

    def _save(self) -> None:
        """Save registry to disk atomically, bumping the version by one."""
        metadata = self._data.setdefault("_metadata", {})
        metadata["updated_at"] = _now_iso()
        metadata["version"] = self.version + 1
        _atomic_write_json(self.registry_path, self._data)
        logger.debug("Registry saved (version %d)", self.version)

    @staticmethod
    def _empty_registry() -> dict:
        """Return a fresh registry structure with version 0 and no tables."""
        now = _now_iso()
        return {
            "_metadata": {
                "version": 0,
                "created_at": now,
                "updated_at": now,
            },
            "folder_mapping": {},
            "tables": [],
        }

    def _check_version(self, expected_version: Optional[int]) -> None:
        """Raise ConflictError when ``expected_version`` is given and stale."""
        if expected_version is not None and expected_version != self.version:
            raise ConflictError(
                f"Version conflict: expected {expected_version}, current {self.version}"
            )

    @classmethod
    def _check_sync_strategy(cls, strategy: Any) -> None:
        """Raise ValueError unless ``strategy`` is one of the allowed values."""
        if strategy not in cls._VALID_SYNC_STRATEGIES:
            raise ValueError(
                f"Invalid sync_strategy '{strategy}'. "
                f"Allowed: {', '.join(cls._VALID_SYNC_STRATEGIES)}"
            )

    # ── Properties ───────────────────────────────────────────────────

    @property
    def version(self) -> int:
        """Current registry version (0 when no metadata is present)."""
        return self._data.get("_metadata", {}).get("version", 0)

    # ── Core CRUD ────────────────────────────────────────────────────

    def list_tables(self) -> list[dict]:
        """Return all registered tables."""
        return list(self._data.get("tables", []))

    def get_table(self, table_id: str) -> Optional[dict]:
        """Get a single table by ID (a shallow copy), or None if unknown."""
        for t in self._data.get("tables", []):
            if t["id"] == table_id:
                return dict(t)
        return None

    def is_registered(self, table_id: str) -> bool:
        """Return True when a table with ``table_id`` is in the registry."""
        return any(t["id"] == table_id for t in self._data.get("tables", []))

    def register_table(
        self,
        table_def: dict,
        registered_by: str,
        expected_version: Optional[int] = None,
    ) -> None:
        """Register a new table.

        Args:
            table_def: Table definition dict (must contain id, name, sync_strategy, primary_key).
            registered_by: Email of the admin who registered the table.
            expected_version: If provided, reject if registry version doesn't match (optimistic lock).

        Raises:
            ValueError: If table already registered or validation fails.
            ConflictError: If expected_version doesn't match.
        """
        self._check_version(expected_version)

        table_id = table_def.get("id", "")
        if not table_id:
            raise ValueError("Table definition must include 'id'")

        if self.is_registered(table_id):
            raise ValueError(f"Table '{table_id}' is already registered")

        # Validate required fields
        for field in ("name", "sync_strategy", "primary_key"):
            if not table_def.get(field):
                raise ValueError(f"Table definition must include '{field}'")

        self._check_sync_strategy(table_def["sync_strategy"])

        # Build full record. max_history_days is stored as well because
        # generate_data_description_md() and the migration both use it.
        record = {
            "id": table_id,
            "name": table_def["name"],
            "description": table_def.get("description", ""),
            "primary_key": table_def["primary_key"],
            "sync_strategy": table_def["sync_strategy"],
            "incremental_window_days": table_def.get("incremental_window_days"),
            "partition_by": table_def.get("partition_by"),
            "partition_granularity": table_def.get("partition_granularity"),
            "foreign_keys": table_def.get("foreign_keys", []),
            "where_filters": table_def.get("where_filters", []),
            "folder": table_def.get("folder"),
            "dataset": table_def.get("dataset"),
            "initial_load_chunk_days": table_def.get("initial_load_chunk_days", 30),
            "max_history_days": table_def.get("max_history_days"),
            "registered_at": _now_iso(),
            "registered_by": registered_by,
            "source_metadata": table_def.get("source_metadata", {}),
        }

        self._data.setdefault("tables", []).append(record)
        self._save()

        _audit_log(self.registry_path, "register", {
            "table_id": table_id,
            "by": registered_by,
        })

    def unregister_table(
        self,
        table_id: str,
        unregistered_by: str = "",
        expected_version: Optional[int] = None,
    ) -> None:
        """Remove a table from the registry.

        Raises:
            ValueError: If table not found.
            ConflictError: If expected_version doesn't match.
        """
        self._check_version(expected_version)

        tables = self._data.get("tables", [])
        new_tables = [t for t in tables if t["id"] != table_id]

        # Nothing filtered out means the id was never registered.
        if len(new_tables) == len(tables):
            raise ValueError(f"Table '{table_id}' is not registered")

        self._data["tables"] = new_tables
        self._save()

        _audit_log(self.registry_path, "unregister", {
            "table_id": table_id,
            "by": unregistered_by,
        })

    def update_table(
        self,
        table_id: str,
        updates: dict,
        updated_by: str = "",
        expected_version: Optional[int] = None,
    ) -> None:
        """Update table configuration.

        Only the fields in ``_UPDATABLE_FIELDS`` are applied; other keys in
        ``updates`` are ignored. All validation happens before the stored
        record is touched, so a rejected update leaves it unchanged.

        Raises:
            ValueError: If table not found or the new sync_strategy is invalid.
            ConflictError: If expected_version doesn't match.
        """
        self._check_version(expected_version)

        target = next(
            (t for t in self._data.get("tables", []) if t["id"] == table_id),
            None,
        )
        if target is None:
            raise ValueError(f"Table '{table_id}' is not registered")

        applied = {
            key: value
            for key, value in updates.items()
            if key in self._UPDATABLE_FIELDS
        }

        # Same rule as register_table(): never persist an unknown strategy.
        if "sync_strategy" in applied:
            self._check_sync_strategy(applied["sync_strategy"])

        target.update(applied)
        self._save()

        # Audit only what was actually written, not ignored keys.
        _audit_log(self.registry_path, "update", {
            "table_id": table_id,
            "fields": list(applied),
            "by": updated_by,
        })

    # ── Folder mapping ───────────────────────────────────────────────

    def get_folder_mapping(self) -> dict[str, str]:
        """Return a copy of the bucket -> folder mapping."""
        return dict(self._data.get("folder_mapping", {}))

    def set_folder_mapping(self, bucket_id: str, folder: str) -> None:
        """Map ``bucket_id`` to ``folder`` and persist the registry."""
        self._data.setdefault("folder_mapping", {})[bucket_id] = folder
        self._save()

    # ── Generation ───────────────────────────────────────────────────

    def generate_data_description_md(self, output_path: Path) -> None:
        """Regenerate data_description.md from registry.

        The file starts with header comments telling readers not to edit it
        manually and carrying a checksum: the first 16 hex digits of the
        SHA-256 of the YAML body. The YAML is wrapped in a fenced ``yaml``
        block. The file is written as UTF-8 because the YAML is dumped with
        ``allow_unicode=True`` and may contain non-ASCII text.
        """
        tables = self.list_tables()
        folder_mapping = self.get_folder_mapping()

        # Build YAML structure matching existing data_description.md format
        yaml_data: dict[str, Any] = {}

        if folder_mapping:
            yaml_data["folder_mapping"] = folder_mapping

        yaml_tables = []
        for t in tables:
            entry: dict[str, Any] = {
                "id": t["id"],
                "name": t["name"],
                "description": t.get("description", ""),
                "primary_key": t["primary_key"],
                "sync_strategy": t["sync_strategy"],
            }

            # Optional fields -- only include if set. Key order is kept
            # because the YAML is dumped with sort_keys=False.
            for key in (
                "incremental_window_days",
                "partition_by",
                "partition_granularity",
                "max_history_days",
            ):
                if t.get(key):
                    entry[key] = t[key]

            # 30 is the default chunk size, so it is left out of the output.
            chunk_days = t.get("initial_load_chunk_days")
            if chunk_days and chunk_days != 30:
                entry["initial_load_chunk_days"] = chunk_days

            for key in ("foreign_keys", "where_filters", "folder", "dataset"):
                if t.get(key):
                    entry[key] = t[key]

            yaml_tables.append(entry)

        yaml_data["tables"] = yaml_tables

        yaml_str = yaml.dump(
            yaml_data, default_flow_style=False, sort_keys=False, allow_unicode=True
        )

        # Compute checksum
        checksum = hashlib.sha256(yaml_str.encode("utf-8")).hexdigest()[:16]

        # Build markdown
        lines = [
            "<!-- AUTO-GENERATED from table_registry.json -- do not edit manually -->",
            "<!-- Use the admin UI at /admin/tables to manage table registrations -->",
            f"<!-- checksum: sha256:{checksum} -->",
            "",
            "# Data Description",
            "",
            f"Generated at {_now_iso()} from table registry "
            f"(version {self.version}, {len(yaml_tables)} tables).",
            "",
            "```yaml",
            yaml_str.rstrip(),
            "```",
            "",
        ]

        content = "\n".join(lines)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(content, encoding="utf-8")
        logger.info(
            f"Generated data_description.md: {len(yaml_tables)} tables "
            f"(checksum: {checksum})"
        )

    # ── Migration ────────────────────────────────────────────────────

    @classmethod
    def import_from_data_description(
        cls,
        md_path: Path,
        registry_path: Path,
        registered_by: str = "migration",
    ) -> "TableRegistry":
        """One-time migration: parse existing data_description.md into registry.

        Creates a new registry JSON from the existing markdown YAML blocks.
        Any content already stored at ``registry_path`` is replaced.

        Raises:
            FileNotFoundError: If ``md_path`` does not exist.
            ValueError: If the file has no YAML blocks or they list no tables.
        """
        if not md_path.exists():
            raise FileNotFoundError(f"data_description.md not found: {md_path}")

        content = md_path.read_text(encoding="utf-8")

        # Extract YAML blocks
        yaml_matches = re.findall(r"```yaml\n(.*?)```", content, re.DOTALL)
        if not yaml_matches:
            raise ValueError("No YAML blocks found in data_description.md")

        all_tables: list[dict] = []
        folder_mapping: dict[str, str] = {}

        for yaml_block in yaml_matches:
            data = yaml.safe_load(yaml_block)
            # Skip blocks that are empty or not a mapping (e.g. a bare list).
            if not isinstance(data, dict):
                continue
            all_tables.extend(data.get("tables") or [])
            folder_mapping.update(data.get("folder_mapping") or {})

        if not all_tables:
            raise ValueError("No tables found in YAML blocks")

        # Build registry
        registry = cls(registry_path)
        registry._data = cls._empty_registry()
        registry._data["folder_mapping"] = folder_mapping
        registry._data["_metadata"]["migrated_from"] = str(md_path)

        now = _now_iso()
        for table_data in all_tables:
            record = {
                "id": table_data.get("id", ""),
                "name": table_data.get("name", ""),
                "description": table_data.get("description", ""),
                "primary_key": table_data.get("primary_key", ""),
                "sync_strategy": table_data.get("sync_strategy", "full_refresh"),
                "incremental_window_days": table_data.get("incremental_window_days"),
                "partition_by": table_data.get("partition_by"),
                "partition_granularity": table_data.get("partition_granularity"),
                "foreign_keys": table_data.get("foreign_keys", []),
                "where_filters": table_data.get("where_filters", []),
                "folder": table_data.get("folder"),
                "dataset": table_data.get("dataset"),
                "initial_load_chunk_days": table_data.get("initial_load_chunk_days", 30),
                "max_history_days": table_data.get("max_history_days"),
                "registered_at": now,
                "registered_by": registered_by,
                "source_metadata": {},
            }
            registry._data["tables"].append(record)

        registry._save()

        _audit_log(registry_path, "migrate", {
            "source": str(md_path),
            "tables_imported": len(all_tables),
            "by": registered_by,
        })

        logger.info(
            f"Migrated {len(all_tables)} tables from {md_path} to registry"
        )
        return registry
|
||||
|
||||
|
||||
class ConflictError(Exception):
    """Signals that an optimistic-locking version check did not match."""
|
||||
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -93,17 +93,15 @@ def mock_client():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Return a mock Config with a single table."""
|
||||
cfg = MagicMock()
|
||||
cfg.tables = [
|
||||
FakeTableConfig(
|
||||
name="order_economics",
|
||||
id="prj.dataset.order_economics",
|
||||
catalog_fqn="bigquery.prj.dataset.order_economics",
|
||||
)
|
||||
def mock_tables():
|
||||
"""Return a list of table dicts matching TableRegistryRepository.list_all() format."""
|
||||
return [
|
||||
{
|
||||
"id": "prj.dataset.order_economics",
|
||||
"name": "order_economics",
|
||||
"catalog_fqn": "bigquery.prj.dataset.order_economics",
|
||||
}
|
||||
]
|
||||
return cfg
|
||||
|
||||
|
||||
CATALOG_URL = "https://catalog.example.com"
|
||||
|
|
@ -351,11 +349,11 @@ class TestExportMetrics:
|
|||
|
||||
class TestExportTables:
|
||||
def test_export_tables_writes_files(
|
||||
self, tmp_path: Path, mock_client, mock_config
|
||||
self, tmp_path: Path, mock_client, mock_tables
|
||||
):
|
||||
"""Creates table YAML with columns."""
|
||||
docs = tmp_path / "docs"
|
||||
count = export_tables(mock_client, mock_config, docs, CATALOG_URL)
|
||||
count = export_tables(mock_client, mock_tables, docs, CATALOG_URL)
|
||||
|
||||
assert count == 1
|
||||
|
||||
|
|
@ -374,21 +372,12 @@ class TestExportTables:
|
|||
assert parsed["columns"][0]["name"] == "order_id"
|
||||
|
||||
def test_export_tables_handles_api_error(
|
||||
self, tmp_path: Path, mock_client, mock_config
|
||||
self, tmp_path: Path, mock_client
|
||||
):
|
||||
"""Continues on per-table errors, exports remaining tables."""
|
||||
# Two tables: first will fail, second succeeds
|
||||
mock_config.tables = [
|
||||
FakeTableConfig(
|
||||
name="broken_table",
|
||||
id="prj.dataset.broken",
|
||||
catalog_fqn="bigquery.prj.dataset.broken",
|
||||
),
|
||||
FakeTableConfig(
|
||||
name="good_table",
|
||||
id="prj.dataset.good",
|
||||
catalog_fqn="bigquery.prj.dataset.good",
|
||||
),
|
||||
tables = [
|
||||
{"id": "prj.dataset.broken", "name": "broken_table", "catalog_fqn": "bigquery.prj.dataset.broken"},
|
||||
{"id": "prj.dataset.good", "name": "good_table", "catalog_fqn": "bigquery.prj.dataset.good"},
|
||||
]
|
||||
|
||||
def side_effect(fqn):
|
||||
|
|
@ -399,7 +388,7 @@ class TestExportTables:
|
|||
mock_client.get_table.side_effect = side_effect
|
||||
|
||||
docs = tmp_path / "docs"
|
||||
count = export_tables(mock_client, mock_config, docs, CATALOG_URL)
|
||||
count = export_tables(mock_client, tables, docs, CATALOG_URL)
|
||||
|
||||
# Only the good table should succeed
|
||||
assert count == 1
|
||||
|
|
@ -407,11 +396,11 @@ class TestExportTables:
|
|||
assert (docs / "tables" / "good_table.yml").exists()
|
||||
|
||||
def test_export_tables_uses_catalog_fqn(
|
||||
self, tmp_path: Path, mock_client, mock_config
|
||||
self, tmp_path: Path, mock_client, mock_tables
|
||||
):
|
||||
"""Uses explicit catalog_fqn when set on table config."""
|
||||
docs = tmp_path / "docs"
|
||||
export_tables(mock_client, mock_config, docs, CATALOG_URL)
|
||||
export_tables(mock_client, mock_tables, docs, CATALOG_URL)
|
||||
|
||||
mock_client.get_table.assert_called_once_with(
|
||||
"bigquery.prj.dataset.order_economics"
|
||||
|
|
@ -421,17 +410,12 @@ class TestExportTables:
|
|||
self, tmp_path: Path, mock_client
|
||||
):
|
||||
"""When catalog_fqn is None, derives FQN as 'bigquery.{id}'."""
|
||||
cfg = MagicMock()
|
||||
cfg.tables = [
|
||||
FakeTableConfig(
|
||||
name="my_table",
|
||||
id="project.dataset.my_table",
|
||||
catalog_fqn=None,
|
||||
)
|
||||
tables = [
|
||||
{"id": "project.dataset.my_table", "name": "my_table"},
|
||||
]
|
||||
|
||||
docs = tmp_path / "docs"
|
||||
export_tables(mock_client, cfg, docs, CATALOG_URL)
|
||||
export_tables(mock_client, tables, docs, CATALOG_URL)
|
||||
|
||||
mock_client.get_table.assert_called_once_with(
|
||||
"bigquery.project.dataset.my_table"
|
||||
|
|
|
|||
|
|
@ -1,69 +0,0 @@
|
|||
"""Tests for TableConfig.bq_entity_type field validation."""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import TableConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def _make_table(**overrides) -> TableConfig:
    """Build a TableConfig from baseline values; ``overrides`` take precedence."""
    params = {
        "id": "test.dataset.table",
        "name": "test_table",
        "description": "Test",
        "primary_key": "id",
        "sync_strategy": "full_refresh",
        **overrides,
    }
    return TableConfig(**params)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestBqEntityTypeDefault:
    """bq_entity_type when the caller does not supply one."""

    def test_default_is_view(self):
        """A table built without bq_entity_type reports "view"."""
        assert _make_table().bq_entity_type == "view"
|
||||
|
||||
|
||||
class TestBqEntityTypeValidValues:
    """Each accepted bq_entity_type value is stored unchanged."""

    @pytest.mark.parametrize("entity_type", ["view", "table"])
    def test_valid_bq_entity_type(self, entity_type):
        """The value passed in is the value read back."""
        built = _make_table(bq_entity_type=entity_type)
        assert built.bq_entity_type == entity_type
|
||||
|
||||
|
||||
class TestBqEntityTypeInvalid:
    """Values outside the accepted set are rejected at construction."""

    @pytest.mark.parametrize(
        "bad_type",
        ["VIEW", "physical", "", "tables", "materialized"],
    )
    def test_invalid_bq_entity_type_raises(self, bad_type):
        """Expect ValueError whose message matches "Invalid bq_entity_type"."""
        expected_error = pytest.raises(ValueError, match="Invalid bq_entity_type")
        with expected_error:
            _make_table(bq_entity_type=bad_type)
|
||||
|
||||
|
||||
class TestBqEntityTypeFromKwarg:
    """bq_entity_type handling when TableConfig is constructed directly."""

    def test_kwarg_sets_bq_entity_type(self):
        """Simulate what _parse_data_description does: pass bq_entity_type as kwarg."""
        fields = {
            "id": "proj.dataset.orders",
            "name": "orders",
            "description": "Order data",
            "primary_key": "order_id",
            "sync_strategy": "full_refresh",
            "bq_entity_type": "table",
        }
        assert TableConfig(**fields).bq_entity_type == "table"

    def test_kwarg_default_when_omitted(self):
        """With no bq_entity_type keyword, the instance reports "view"."""
        fields = {
            "id": "proj.dataset.orders",
            "name": "orders",
            "description": "Order data",
            "primary_key": "order_id",
            "sync_strategy": "full_refresh",
        }
        assert TableConfig(**fields).bq_entity_type == "view"
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
"""Tests for TableConfig.query_mode field validation."""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import TableConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def _make_table(**overrides) -> TableConfig:
    """Return a TableConfig built from baseline fields merged with ``overrides``."""
    baseline = {
        "id": "test.dataset.table",
        "name": "test_table",
        "description": "Test",
        "primary_key": "id",
        "sync_strategy": "full_refresh",
    }
    return TableConfig(**{**baseline, **overrides})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestQueryModeDefault:
    """query_mode when the caller does not supply one."""

    def test_default_is_local(self):
        """A table built without query_mode reports "local"."""
        assert _make_table().query_mode == "local"
|
||||
|
||||
|
||||
class TestQueryModeValidValues:
    """Each accepted query_mode value is stored unchanged."""

    @pytest.mark.parametrize("mode", ["local", "remote", "hybrid"])
    def test_valid_query_mode(self, mode):
        """The value passed in is the value read back."""
        built = _make_table(query_mode=mode)
        assert built.query_mode == mode
|
||||
|
||||
|
||||
class TestQueryModeInvalid:
    """Values outside the accepted set are rejected at construction."""

    @pytest.mark.parametrize(
        "bad_mode",
        ["invalid", "Local", "REMOTE", "", "sql"],
    )
    def test_invalid_query_mode_raises(self, bad_mode):
        """Expect ValueError whose message matches "Invalid query_mode"."""
        expected_error = pytest.raises(ValueError, match="Invalid query_mode")
        with expected_error:
            _make_table(query_mode=bad_mode)
|
||||
|
||||
|
||||
class TestQueryModeFromKwarg:
    """query_mode handling when TableConfig is constructed directly."""

    def test_kwarg_sets_query_mode(self):
        """Simulate what _parse_data_description does: pass query_mode as kwarg."""
        fields = {
            "id": "proj.dataset.orders",
            "name": "orders",
            "description": "Order data",
            "primary_key": "order_id",
            "sync_strategy": "full_refresh",
            "query_mode": "remote",
        }
        assert TableConfig(**fields).query_mode == "remote"

    def test_kwarg_default_when_omitted(self):
        """With no query_mode keyword, the instance reports "local"."""
        fields = {
            "id": "proj.dataset.orders",
            "name": "orders",
            "description": "Order data",
            "primary_key": "order_id",
            "sync_strategy": "full_refresh",
        }
        assert TableConfig(**fields).query_mode == "local"
|
||||
|
|
@ -1,103 +0,0 @@
|
|||
"""Tests for TableConfig.sync_schedule field validation."""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import TableConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def _make_table(**overrides) -> TableConfig:
    """Construct a TableConfig using baseline values, overridden by ``overrides``."""
    settings = {
        "id": "test.dataset.table",
        "name": "test_table",
        "description": "Test",
        "primary_key": "id",
        "sync_strategy": "full_refresh",
    }
    for field_name, field_value in overrides.items():
        settings[field_name] = field_value
    return TableConfig(**settings)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSyncScheduleDefault:
    """sync_schedule when the caller does not supply one."""

    def test_default_is_none(self):
        """A table built without sync_schedule reports None."""
        assert _make_table().sync_schedule is None
|
||||
|
||||
|
||||
class TestSyncScheduleValidValues:
    """Well-formed "every ..." and "daily ..." schedules are stored unchanged."""

    @pytest.mark.parametrize(
        "schedule",
        [
            "every 15m",
            "every 1h",
            "daily 05:00",
            "daily 07:00,13:00,18:00",
            "daily 00:00,12:00",
        ],
        ids=[
            "every-15m",
            "every-1h",
            "daily-single",
            "daily-three-times",
            "daily-two-times",
        ],
    )
    def test_valid_schedule_accepted(self, schedule: str):
        """The schedule passed in is the schedule read back."""
        built = _make_table(sync_schedule=schedule)
        assert built.sync_schedule == schedule
|
||||
|
||||
|
||||
class TestSyncScheduleEdgeCases:
    """Boundary schedule strings that are expected to be accepted."""

    def test_every_zero_minutes(self):
        """every 0m matches the regex -- validation is syntactic, not semantic."""
        assert _make_table(sync_schedule="every 0m").sync_schedule == "every 0m"

    def test_daily_2359(self):
        """The last minute of the day is accepted as a daily time."""
        assert _make_table(sync_schedule="daily 23:59").sync_schedule == "daily 23:59"
|
||||
|
||||
|
||||
class TestSyncScheduleInvalid:
    """Malformed schedules, plus two odd inputs that are accepted as-is."""

    @pytest.mark.parametrize(
        "bad_schedule",
        [
            "daily 07:00,13:00,18:00,",  # trailing comma
            "daily 7:00",  # single-digit hour
            "daily",  # missing time
            "hourly",  # unsupported keyword
            "weekly",  # unsupported keyword
        ],
        ids=[
            "trailing-comma",
            "single-digit-hour",
            "daily-no-time",
            "hourly-keyword",
            "weekly-keyword",
        ],
    )
    def test_invalid_schedule_raises(self, bad_schedule: str):
        """Expect ValueError whose message matches "Invalid sync_schedule"."""
        expected_error = pytest.raises(ValueError, match="Invalid sync_schedule")
        with expected_error:
            _make_table(sync_schedule=bad_schedule)

    def test_empty_string_treated_as_none(self):
        """An empty-string schedule raises nothing and is stored unchanged."""
        assert _make_table(sync_schedule="").sync_schedule == ""

    def test_daily_25_accepted_by_regex(self):
        """25:00 passes regex validation (two digits). Document this behavior."""
        assert _make_table(sync_schedule="daily 25:00").sync_schedule == "daily 25:00"
|
||||
|
||||
|
||||
class TestSyncScheduleNoneExplicit:
    """Passing sync_schedule=None explicitly."""

    def test_explicit_none_accepted(self):
        """An explicit None is accepted and read back as None."""
        assert _make_table(sync_schedule=None).sync_schedule is None
|
||||
|
|
@ -1,228 +0,0 @@
|
|||
"""Tests for remote table skipping in DataSyncManager.sync_all().
|
||||
|
||||
Tables with query_mode == "remote" should be skipped during sync (no local
|
||||
Parquet file is needed -- queries go directly to BigQuery). Tables with
|
||||
query_mode "local" or "hybrid" must still be synced normally.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import TableConfig
|
||||
from src.data_sync import DataSyncManager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_table_config(
    table_id: str,
    name: str,
    query_mode: str = "local",
) -> TableConfig:
    """Build a minimal full_refresh TableConfig with the given id, name and query mode."""
    fields = {
        "id": table_id,
        "name": name,
        "description": f"Test table {name}",
        "primary_key": "id",
        "sync_strategy": "full_refresh",
        "query_mode": query_mode,
    }
    return TableConfig(**fields)
|
||||
|
||||
|
||||
def _successful_sync_result() -> dict:
|
||||
"""Return a fake successful sync result dict."""
|
||||
return {
|
||||
"success": True,
|
||||
"rows": 100,
|
||||
"file_size_mb": 0.5,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
def table_local():
    """Sample table configured with ``query_mode="local"``."""
    return _make_table_config(
        table_id="t.local",
        name="local_table",
        query_mode="local",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def table_remote():
    """Sample table configured with ``query_mode="remote"``."""
    return _make_table_config(
        table_id="t.remote",
        name="remote_table",
        query_mode="remote",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def table_hybrid():
    """Sample table configured with ``query_mode="hybrid"``."""
    return _make_table_config(
        table_id="t.hybrid",
        name="hybrid_table",
        query_mode="hybrid",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def all_tables(table_local, table_remote, table_hybrid):
    """The three sample tables as a list, in local, remote, hybrid order."""
    return list((table_local, table_remote, table_hybrid))
|
||||
|
||||
|
||||
@pytest.fixture
def mock_config(all_tables):
    """Mock config exposing ``all_tables`` via ``.tables`` and ``.get_table_config``."""

    def _lookup(wanted_id):
        # First table whose id matches, or None when nothing matches.
        for candidate in all_tables:
            if candidate.id == wanted_id:
                return candidate
        return None

    config = MagicMock()
    config.tables = all_tables
    config.get_metadata_path.return_value = MagicMock()  # Path-like stand-in
    config.get_table_config.side_effect = _lookup
    return config
|
||||
|
||||
|
||||
@pytest.fixture
def mock_data_source():
    """Mock data source whose ``sync_table`` always returns a success payload."""
    source = MagicMock()
    source.configure_mock(**{"sync_table.return_value": _successful_sync_result()})
    return source
|
||||
|
||||
|
||||
@pytest.fixture
def sync_manager(mock_config, mock_data_source):
    """Yield a ``DataSyncManager`` built while its collaborators are patched."""
    config_patch = patch("src.data_sync.get_config", return_value=mock_config)
    source_patch = patch("src.data_sync.create_data_source", return_value=mock_data_source)
    state_patch = patch("src.data_sync.SyncState")
    with config_patch, source_patch, state_patch:
        mgr = DataSyncManager()
        # Schema YAML generation is stubbed out; it is not under test.
        mgr._generate_schema_yaml = MagicMock()
        yield mgr
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSyncAllRemoteSkipping:
    """``sync_all`` must leave remote tables alone and sync the rest."""

    @staticmethod
    def _synced_ids(data_source):
        """Ids of every table passed to ``data_source.sync_table`` so far."""
        return [recorded.args[0].id for recorded in data_source.sync_table.call_args_list]

    def test_remote_table_not_synced(self, sync_manager, mock_data_source, table_remote):
        """The remote table is never handed to ``sync_table``."""
        sync_manager.sync_all()

        assert table_remote.id not in self._synced_ids(mock_data_source)

    def test_local_table_is_synced(self, sync_manager, mock_data_source, table_local):
        """The local table is handed to ``sync_table``."""
        sync_manager.sync_all()

        assert table_local.id in self._synced_ids(mock_data_source)

    def test_hybrid_table_is_synced(self, sync_manager, mock_data_source, table_hybrid):
        """The hybrid table is handed to ``sync_table`` (needs local parquet for profiling)."""
        sync_manager.sync_all()

        assert table_hybrid.id in self._synced_ids(mock_data_source)

    def test_sync_call_count(self, sync_manager, mock_data_source):
        """Three tables with one remote means exactly two ``sync_table`` calls."""
        sync_manager.sync_all()

        expected_calls = 2  # local + hybrid
        assert mock_data_source.sync_table.call_count == expected_calls

    def test_results_exclude_remote(self, sync_manager, table_remote):
        """The returned results have no entry for the remote table."""
        outcome = sync_manager.sync_all()

        assert table_remote.id not in outcome

    def test_results_include_local_and_hybrid(
        self, sync_manager, table_local, table_hybrid
    ):
        """The returned results have entries for the local and hybrid tables."""
        outcome = sync_manager.sync_all()

        assert table_local.id in outcome
        assert table_hybrid.id in outcome
|
||||
|
||||
|
||||
class TestSyncAllAllRemote:
    """Edge case: every configured table is remote."""

    def test_no_sync_calls_when_all_remote(self, mock_config, mock_data_source):
        """With only remote tables nothing is synced and the result is empty."""
        mock_config.tables = [
            _make_table_config(f"t.r{n}", f"remote{n}", query_mode="remote")
            for n in (1, 2)
        ]

        config_patch = patch("src.data_sync.get_config", return_value=mock_config)
        source_patch = patch("src.data_sync.create_data_source", return_value=mock_data_source)
        state_patch = patch("src.data_sync.SyncState")
        with config_patch, source_patch, state_patch:
            mgr = DataSyncManager()
            mgr._generate_schema_yaml = MagicMock()
            outcome = mgr.sync_all()

        assert mock_data_source.sync_table.call_count == 0
        assert outcome == {}
|
||||
|
||||
|
||||
class TestSyncAllNoRemote:
    """Edge case: no table is remote, so every table is synced."""

    def test_all_tables_synced(self, mock_config, mock_data_source):
        """Two local tables produce two sync calls and two result entries."""
        mock_config.tables = [
            _make_table_config(f"t.l{n}", f"local{n}", query_mode="local")
            for n in (1, 2)
        ]

        config_patch = patch("src.data_sync.get_config", return_value=mock_config)
        source_patch = patch("src.data_sync.create_data_source", return_value=mock_data_source)
        state_patch = patch("src.data_sync.SyncState")
        with config_patch, source_patch, state_patch:
            mgr = DataSyncManager()
            mgr._generate_schema_yaml = MagicMock()
            outcome = mgr.sync_all()

        assert mock_data_source.sync_table.call_count == 2
        for expected_id in ("t.l1", "t.l2"):
            assert expected_id in outcome
|
||||
|
||||
|
||||
class TestSyncAllWithTableFilter:
    """Remote filtering still applies when ``sync_all`` gets an explicit table list."""

    def test_explicit_remote_table_still_skipped(
        self, sync_manager, mock_data_source, table_remote
    ):
        """Listing a remote table explicitly does not cause it to be synced."""
        requested = [table_remote.id]
        sync_manager.sync_all(tables=requested)

        assert mock_data_source.sync_table.call_count == 0

    def test_explicit_local_table_synced(
        self, sync_manager, mock_data_source, table_local
    ):
        """Listing a local table explicitly syncs exactly that table."""
        requested = [table_local.id]
        sync_manager.sync_all(tables=requested)

        assert mock_data_source.sync_table.call_count == 1
        first_call = mock_data_source.sync_table.call_args_list[0]
        assert first_call.args[0].id == table_local.id
|
||||
|
|
@ -7,11 +7,11 @@ from datetime import datetime, timedelta
|
|||
from unittest.mock import Mock, patch, MagicMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.config import TableConfig
|
||||
from connectors.openmetadata.enricher import (
|
||||
CatalogEnricher,
|
||||
CatalogTableData,
|
||||
CatalogColumnData,
|
||||
TableConfig,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -21,9 +21,6 @@ def sample_table_config():
|
|||
return TableConfig(
|
||||
id="prj-grp-dataview-prod-1ff9.marketing.roi_datamart_v2",
|
||||
name="roi_datamart_v2",
|
||||
description="ROI metrics",
|
||||
primary_key="id",
|
||||
sync_strategy="full_refresh",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -111,9 +108,6 @@ def test_enrich_table_disabled():
|
|||
table_config = TableConfig(
|
||||
id="test.table",
|
||||
name="test",
|
||||
description="Test",
|
||||
primary_key="id",
|
||||
sync_strategy="full_refresh",
|
||||
)
|
||||
|
||||
result = enricher.enrich_table(table_config)
|
||||
|
|
@ -145,9 +139,6 @@ def test_enrich_table_cache_hit():
|
|||
table_config = TableConfig(
|
||||
id="prj-grp-dataview-prod-1ff9.marketing.test",
|
||||
name="test",
|
||||
description="Test",
|
||||
primary_key="id",
|
||||
sync_strategy="full_refresh",
|
||||
)
|
||||
|
||||
result = enricher.enrich_table(table_config)
|
||||
|
|
@ -199,9 +190,6 @@ def test_derive_fqn_auto():
|
|||
table_config = TableConfig(
|
||||
id="prj-grp-dataview-prod-1ff9.marketing.roi_datamart_v2",
|
||||
name="roi_datamart_v2",
|
||||
description="Test",
|
||||
primary_key="id",
|
||||
sync_strategy="full_refresh",
|
||||
)
|
||||
|
||||
fqn = enricher._derive_fqn(table_config)
|
||||
|
|
@ -223,9 +211,6 @@ def test_derive_fqn_explicit_override():
|
|||
table_config = TableConfig(
|
||||
id="prj-grp-dataview-prod-1ff9.marketing.roi_datamart_v2",
|
||||
name="roi_datamart_v2",
|
||||
description="Test",
|
||||
primary_key="id",
|
||||
sync_strategy="full_refresh",
|
||||
)
|
||||
table_config.catalog_fqn = "bigquery.custom.fqn.override"
|
||||
|
||||
|
|
@ -390,9 +375,6 @@ def test_enrich_table_http_error_graceful():
|
|||
table_config = TableConfig(
|
||||
id="test.table",
|
||||
name="test",
|
||||
description="Test",
|
||||
primary_key="id",
|
||||
sync_strategy="full_refresh",
|
||||
)
|
||||
|
||||
# Should return None, not raise
|
||||
|
|
|
|||
|
|
@ -1,779 +0,0 @@
|
|||
"""Tests for remote_query module - hybrid local Parquet + remote BigQuery queries.
|
||||
|
||||
Tests cover:
|
||||
- CLI argument parsing (_parse_register_bq, build_parser)
|
||||
- Local view setup (_setup_local_views via create_local_views)
|
||||
- BQ registration with safety checks (_validate_bq_result_size, _estimate_memory_mb, _register_bq_views)
|
||||
- Output formatting (_print_table, _format_output)
|
||||
- End-to-end local-only queries (no BQ mocking needed)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import duckdb
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from src.remote_query import (
|
||||
RemoteQueryError,
|
||||
_estimate_memory_mb,
|
||||
_format_output,
|
||||
_parse_register_bq,
|
||||
_print_table,
|
||||
_register_bq_views,
|
||||
_validate_bq_result_size,
|
||||
build_parser,
|
||||
execute_remote_query,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
def tmp_local_project(tmp_path):
    """Create a minimal project with docs/data_description.md and parquet files.

    Layout:
        tmp_path/
            docs/data_description.md   (YAML with local + remote + hybrid tables)
            server/parquet/crm_data/orders.parquet
            server/parquet/crm_data/products.parquet
            server/parquet/crm_data/inventory.parquet

    Returns:
        ``(project_root, data_dir)`` where ``project_root`` is ``tmp_path`` and
        ``data_dir`` is the string ``str(tmp_path / "server")``.
    """
    docs_dir = tmp_path / "docs"
    docs_dir.mkdir()

    # NOTE(review): the YAML nesting below (two-space indent under
    # ``folder_mapping`` and each ``tables`` entry) is reconstructed; confirm
    # it matches what the data_description loader expects.
    data_description = """\
# Data Description

```yaml
folder_mapping:
  in.c-crm: crm_data

tables:
  - id: "in.c-crm.orders"
    name: "orders"
    description: "Order data"
    primary_key: "order_id"
    sync_strategy: "full_refresh"

  - id: "in.c-crm.products"
    name: "products"
    description: "Product catalog"
    primary_key: "product_id"
    sync_strategy: "full_refresh"

  - id: "prj-grp-dataview-prod-1ff9.supply.traffic"
    name: "traffic"
    description: "Remote BQ traffic table"
    primary_key: "id"
    query_mode: "remote"

  - id: "in.c-crm.inventory"
    name: "inventory"
    description: "Hybrid inventory"
    primary_key: "id"
    sync_strategy: "full_refresh"
    query_mode: "hybrid"
```
"""
    (docs_dir / "data_description.md").write_text(data_description)

    # Parquet files for the two local tables.
    crm_dir = tmp_path / "server" / "parquet" / "crm_data"
    crm_dir.mkdir(parents=True)

    orders_table = pa.table({
        "order_id": [1, 2, 3, 4, 5],
        "amount": [10.0, 20.0, 30.0, 40.0, 50.0],
        "product_id": [101, 102, 101, 103, 102],
    })
    pq.write_table(orders_table, crm_dir / "orders.parquet")

    products_table = pa.table({
        "product_id": [101, 102, 103],
        "name": ["Widget", "Gadget", "Doohickey"],
    })
    pq.write_table(products_table, crm_dir / "products.parquet")

    # Parquet file for the hybrid table.
    inventory_table = pa.table({
        "id": [1, 2],
        "stock": [100, 200],
    })
    pq.write_table(inventory_table, crm_dir / "inventory.parquet")

    data_dir = str(tmp_path / "server")
    return tmp_path, data_dir
|
||||
|
||||
|
||||
@pytest.fixture
def duckdb_conn():
    """Yield an in-memory DuckDB connection and close it once the test is done."""
    connection = duckdb.connect(":memory:")
    yield connection
    connection.close()
|
||||
|
||||
|
||||
class _DuckDBConnectionProxy:
|
||||
"""Proxy around DuckDBPyConnection that silently ignores unsupported SET commands.
|
||||
|
||||
DuckDB versions may not support 'statement_timeout'. This proxy catches
|
||||
CatalogException for SET commands so end-to-end tests work across versions.
|
||||
The real connection's execute method is read-only, so we wrap it.
|
||||
"""
|
||||
|
||||
def __init__(self, conn):
|
||||
object.__setattr__(self, "_conn", conn)
|
||||
|
||||
def execute(self, sql, *args, **kwargs):
|
||||
if isinstance(sql, str) and sql.strip().upper().startswith("SET "):
|
||||
try:
|
||||
return self._conn.execute(sql, *args, **kwargs)
|
||||
except duckdb.CatalogException:
|
||||
return None
|
||||
return self._conn.execute(sql, *args, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._conn, name)
|
||||
|
||||
|
||||
def _patched_duckdb_connect(*args, **kwargs):
    """Open a DuckDB connection with the given arguments and return it wrapped in the proxy."""
    return _DuckDBConnectionProxy(duckdb.connect(*args, **kwargs))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: CLI argument parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCLIArgParsing:
    """Exercise ``_parse_register_bq()`` and the parser from ``build_parser()``."""

    @staticmethod
    def _parse(*argv):
        """Parse *argv* with a freshly built parser and return the namespace."""
        return build_parser().parse_args(list(argv))

    def test_sql_defaults_to_none(self):
        """``--sql`` may be omitted (``--stdin`` mode); it then defaults to None."""
        parsed = self._parse("--stdin")
        assert parsed.sql is None
        assert parsed.stdin is True

    def test_register_bq_parsing(self):
        """``alias=SELECT ...`` becomes an ``(alias, sql)`` tuple."""
        parsed = _parse_register_bq("traffic=SELECT report_date FROM `proj.ds.table`")
        assert parsed == ("traffic", "SELECT report_date FROM `proj.ds.table`")

    def test_register_bq_invalid_format(self):
        """A value without '=' raises ``ArgumentTypeError``."""
        with pytest.raises(argparse.ArgumentTypeError, match="Invalid --register-bq format"):
            _parse_register_bq("no_equals_sign_here")

    def test_register_bq_empty_sql(self):
        """An alias followed by empty SQL raises ``ArgumentTypeError``."""
        with pytest.raises(argparse.ArgumentTypeError, match="Empty SQL"):
            _parse_register_bq("alias=")

    def test_register_bq_empty_alias(self):
        """``=SELECT ...`` (empty alias) raises ``ArgumentTypeError``."""
        with pytest.raises(argparse.ArgumentTypeError, match="Invalid --register-bq format"):
            _parse_register_bq("=SELECT 1")

    def test_multiple_register_bq(self):
        """Repeated ``--register-bq`` options accumulate in order."""
        parsed = self._parse(
            "--sql", "SELECT 1",
            "--register-bq", "t1=SELECT a FROM x",
            "--register-bq", "t2=SELECT b FROM y",
        )
        first, second = parsed.bq_registrations[0], parsed.bq_registrations[1]
        assert len(parsed.bq_registrations) == 2
        assert first == ("t1", "SELECT a FROM x")
        assert second == ("t2", "SELECT b FROM y")

    def test_default_format_is_none(self):
        """Without ``--format`` the parsed ``fmt`` is None (config default used at runtime)."""
        parsed = self._parse("--sql", "SELECT 1")
        assert parsed.fmt is None

    def test_explicit_format(self):
        """An explicit ``--format`` value is kept."""
        parsed = self._parse("--sql", "SELECT 1", "--format", "csv")
        assert parsed.fmt == "csv"

    def test_invalid_format_rejected(self):
        """An unsupported ``--format`` value makes the parser exit."""
        parser = build_parser()
        with pytest.raises(SystemExit):
            parser.parse_args(["--sql", "SELECT 1", "--format", "xml"])

    def test_no_register_bq_yields_empty_list(self):
        """Without ``--register-bq`` the registrations default to ``[]``."""
        parsed = self._parse("--sql", "SELECT 1")
        assert parsed.bq_registrations == []

    def test_register_bq_sql_with_equals(self):
        """Only the first '=' separates alias from SQL."""
        alias, sql = _parse_register_bq("view=SELECT * FROM t WHERE col = 5")[:2]
        assert alias == "view"
        assert sql == "SELECT * FROM t WHERE col = 5"

    def test_quiet_flag(self):
        """``--quiet`` sets ``quiet`` to True."""
        parsed = self._parse("--sql", "SELECT 1", "--quiet")
        assert parsed.quiet is True

    def test_max_rows_parsing(self):
        """``--max-rows`` is parsed as an integer."""
        parsed = self._parse("--sql", "SELECT 1", "--max-rows", "500")
        assert parsed.max_rows == 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Local view setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLocalViewSetup:
    """Exercise ``_setup_local_views`` against the ``tmp_local_project`` fixture."""

    @staticmethod
    def _build_views(project, conn):
        """Run ``_setup_local_views`` with the project root patched; return its result."""
        project_root, data_dir = project
        with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
            from src.remote_query import _setup_local_views
            return _setup_local_views(conn, data_dir, quiet=True)

    def test_creates_views_from_parquet(self, tmp_local_project, duckdb_conn):
        """Local tables are reported as created and can be queried."""
        created = self._build_views(tmp_local_project, duckdb_conn)

        for view_name in ("orders", "products"):
            assert view_name in created

        # The view must be backed by the five-row orders parquet file.
        (row_total,) = duckdb_conn.execute("SELECT COUNT(*) FROM orders").fetchone()[:1]
        assert row_total == 5

    def test_skips_remote_tables(self, tmp_local_project, duckdb_conn):
        """Tables with query_mode='remote' get no local view."""
        created = self._build_views(tmp_local_project, duckdb_conn)

        assert "traffic" not in created

        # The remote table must not be visible to DuckDB either.
        visible = [record[0] for record in duckdb_conn.execute("SHOW TABLES").fetchall()]
        assert "traffic" not in visible

    def test_includes_hybrid_tables(self, tmp_local_project, duckdb_conn):
        """Tables with query_mode='hybrid' get a local view."""
        created = self._build_views(tmp_local_project, duckdb_conn)

        assert "inventory" in created

        row_total = duckdb_conn.execute("SELECT COUNT(*) FROM inventory").fetchone()[0]
        assert row_total == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: BQ registration with safety checks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBQRegistration:
    """Test BQ result validation and registration (mocked BigQuery)."""

    @staticmethod
    def _make_mock_bq_client(count_result: int = 100, schema_fields: int = 5):
        """Create a mock BQ client that returns controlled count and schema.

        Args:
            count_result: Row count returned by COUNT(*) query
            schema_fields: Number of fields in the schema

        Returns:
            A ``MagicMock`` whose ``query(sql)`` returns a job mock. For SQL
            starting with ``SELECT COUNT(*)`` the job result iterates over one
            row whose item access yields ``count_result``; for SQL containing
            ``LIMIT 0`` the job result's ``schema`` is a list of
            ``schema_fields`` placeholder mocks.
        """
        mock_client = MagicMock()

        # Row returned by the COUNT(*) query; any index yields the fake count.
        count_row = MagicMock()
        count_row.__getitem__ = MagicMock(return_value=count_result)

        # One placeholder per column for the schema (LIMIT 0) query.
        mock_schema_fields = [MagicMock() for _ in range(schema_fields)]

        # Use side_effect to return different results for different queries
        def query_side_effect(sql):
            job = MagicMock()
            if sql.startswith("SELECT COUNT(*)"):
                result = MagicMock()
                # A fresh iterator per query so each job can be consumed once.
                result.__iter__ = MagicMock(return_value=iter([count_row]))
                job.result.return_value = result
            elif "LIMIT 0" in sql:
                result = MagicMock()
                result.schema = mock_schema_fields
                job.result.return_value = result
            return job

        mock_client.query.side_effect = query_side_effect
        return mock_client

    def test_count_check_blocks_large_result(self):
        """BQ sub-query exceeding max_rows should raise RemoteQueryError."""
        mock_client = self._make_mock_bq_client(count_result=1_000_000)

        with pytest.raises(RemoteQueryError, match="would return 1,000,000 rows"):
            _validate_bq_result_size(
                bq_client=mock_client,
                sql="SELECT * FROM big_table",
                alias="big_table",
                max_rows=500_000,
            )

    def test_validates_small_result_passes(self):
        """BQ sub-query within limits should return the row count."""
        mock_client = self._make_mock_bq_client(count_result=1000)

        row_count = _validate_bq_result_size(
            bq_client=mock_client,
            sql="SELECT * FROM small_table",
            alias="small_table",
            max_rows=500_000,
        )

        assert row_count == 1000

    def test_memory_estimate_blocks_huge_result(self):
        """_register_bq_views should refuse when estimated memory exceeds 2 GB."""
        # Create a mock that passes count check but fails memory check
        # 500K rows x 100 cols x 50 bytes/cell = ~2384 MB > 2048 MB limit
        mock_client = self._make_mock_bq_client(count_result=500_000, schema_fields=100)

        conn = duckdb.connect(":memory:")
        try:
            with patch("src.remote_query._create_bq_client", return_value=mock_client), \
                 patch.dict(os.environ, {"BIGQUERY_PROJECT": "test-proj"}):
                with pytest.raises(RemoteQueryError, match="estimated memory"):
                    _register_bq_views(
                        conn=conn,
                        registrations=[("huge", "SELECT * FROM huge_table")],
                        max_bq_rows=1_000_000,
                        timeout_seconds=60,
                        quiet=True,
                    )
        finally:
            conn.close()

    def test_registers_small_result(self):
        """BQ sub-query within all limits should register successfully."""
        # Small result: 100 rows x 5 cols = ~0.02 MB
        mock_client = self._make_mock_bq_client(count_result=100, schema_fields=5)

        conn = duckdb.connect(":memory:")
        try:
            # register_bq_table is patched to return the row count.
            with patch("src.remote_query._create_bq_client", return_value=mock_client), \
                 patch("src.remote_query.register_bq_table", return_value=100) as mock_reg, \
                 patch.dict(os.environ, {"BIGQUERY_PROJECT": "test-proj"}):
                results = _register_bq_views(
                    conn=conn,
                    registrations=[("small_view", "SELECT * FROM small_table")],
                    max_bq_rows=500_000,
                    timeout_seconds=60,
                    quiet=True,
                )

            assert results == {"small_view": 100}
            mock_reg.assert_called_once()
        finally:
            conn.close()

    def test_missing_bigquery_project_raises(self):
        """Missing BIGQUERY_PROJECT env var should raise RemoteQueryError."""
        conn = duckdb.connect(":memory:")
        try:
            # clear=True empties the environment for the duration of the block.
            with patch.dict(os.environ, {}, clear=True):
                with pytest.raises(RemoteQueryError, match="BIGQUERY_PROJECT"):
                    _register_bq_views(
                        conn=conn,
                        registrations=[("v", "SELECT 1")],
                        max_bq_rows=100,
                        timeout_seconds=60,
                    )
        finally:
            conn.close()

    def test_empty_registrations_returns_empty(self):
        """Empty registration list should return empty dict without BQ calls."""
        conn = duckdb.connect(":memory:")
        try:
            result = _register_bq_views(
                conn=conn,
                registrations=[],
                max_bq_rows=100,
                timeout_seconds=60,
            )
            assert result == {}
        finally:
            conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Memory estimation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMemoryEstimation:
    """Check the arithmetic of ``_estimate_memory_mb``."""

    def test_small_table(self):
        """100 rows x 10 cols is expected to be 50_000 bytes (~0.048 MB)."""
        expected_mb = 50_000 / (1024 * 1024)
        estimate = _estimate_memory_mb(100, 10)
        assert abs(estimate - expected_mb) < 0.001

    def test_large_table(self):
        """1M rows x 50 cols x 50 bytes is expected to be ~2384 MB."""
        expected_mb = (1_000_000 * 50 * 50) / (1024 * 1024)
        estimate = _estimate_memory_mb(1_000_000, 50)
        assert abs(estimate - expected_mb) < 0.01

    def test_zero_rows(self):
        """No rows means 0 MB."""
        estimate = _estimate_memory_mb(0, 50)
        assert estimate == 0.0

    def test_zero_columns(self):
        """No columns means 0 MB."""
        estimate = _estimate_memory_mb(1000, 0)
        assert estimate == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Output formatting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOutputFormatting:
|
||||
"""Test _print_table and _format_output for various formats."""
|
||||
|
||||
def test_table_format_aligned(self, capsys):
|
||||
"""Table format should produce aligned columns with header separator."""
|
||||
columns = ["id", "name", "value"]
|
||||
rows = [(1, "alice", 100), (2, "bob", 200)]
|
||||
|
||||
_print_table(columns, rows)
|
||||
|
||||
output = capsys.readouterr().out
|
||||
lines = output.strip().split("\n")
|
||||
|
||||
# Header line
|
||||
assert "id" in lines[0]
|
||||
assert "name" in lines[0]
|
||||
assert "value" in lines[0]
|
||||
|
||||
# Separator line
|
||||
assert "-+-" in lines[1]
|
||||
|
||||
# Data rows
|
||||
assert "alice" in lines[2]
|
||||
assert "bob" in lines[3]
|
||||
|
||||
# Row count footer
|
||||
assert "(2 rows)" in output
|
||||
|
||||
def test_table_format_empty_result(self, capsys):
|
||||
"""Empty result should print '(empty result)'."""
|
||||
_print_table(["col1"], [])
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "(empty result)" in output
|
||||
|
||||
def test_table_format_null_values(self, capsys):
|
||||
"""None values should be rendered as 'NULL'."""
|
||||
_print_table(["col"], [(None,)])
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "NULL" in output
|
||||
|
||||
def test_csv_format(self, tmp_path, duckdb_conn):
|
||||
"""CSV output should contain header + data rows."""
|
||||
duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id, 'hello' AS msg")
|
||||
|
||||
output_path = str(tmp_path / "result.csv")
|
||||
_format_output(
|
||||
conn=duckdb_conn,
|
||||
sql="SELECT * FROM test",
|
||||
fmt="csv",
|
||||
output_path=output_path,
|
||||
max_rows=1000,
|
||||
)
|
||||
|
||||
with open(output_path) as f:
|
||||
reader = csv.reader(f)
|
||||
header = next(reader)
|
||||
rows = list(reader)
|
||||
|
||||
assert header == ["id", "msg"]
|
||||
assert len(rows) == 1
|
||||
assert rows[0][1] == "hello"
|
||||
|
||||
def test_csv_format_to_stdout(self, capsys, duckdb_conn):
|
||||
"""CSV with no output path should write to stdout."""
|
||||
duckdb_conn.execute("CREATE TABLE test AS SELECT 42 AS val")
|
||||
|
||||
_format_output(
|
||||
conn=duckdb_conn,
|
||||
sql="SELECT * FROM test",
|
||||
fmt="csv",
|
||||
output_path=None,
|
||||
max_rows=1000,
|
||||
)
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "val" in output
|
||||
assert "42" in output
|
||||
|
||||
def test_json_format(self, tmp_path, duckdb_conn):
|
||||
"""JSON output should contain a list of dicts."""
|
||||
duckdb_conn.execute(
|
||||
"CREATE TABLE test AS SELECT 1 AS id, 'world' AS msg"
|
||||
)
|
||||
|
||||
output_path = str(tmp_path / "result.json")
|
||||
_format_output(
|
||||
conn=duckdb_conn,
|
||||
sql="SELECT * FROM test",
|
||||
fmt="json",
|
||||
output_path=output_path,
|
||||
max_rows=1000,
|
||||
)
|
||||
|
||||
with open(output_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == 1
|
||||
assert data[0]["msg"] == "world"
|
||||
|
||||
def test_json_format_to_stdout(self, capsys, duckdb_conn):
|
||||
"""JSON with no output path should print to stdout."""
|
||||
duckdb_conn.execute("CREATE TABLE test AS SELECT 99 AS num")
|
||||
|
||||
_format_output(
|
||||
conn=duckdb_conn,
|
||||
sql="SELECT * FROM test",
|
||||
fmt="json",
|
||||
output_path=None,
|
||||
max_rows=1000,
|
||||
)
|
||||
|
||||
output = capsys.readouterr().out
|
||||
data = json.loads(output)
|
||||
assert data[0]["num"] == 99
|
||||
|
||||
def test_parquet_write(self, tmp_path, duckdb_conn):
|
||||
"""Parquet output should create a readable parquet file."""
|
||||
duckdb_conn.execute(
|
||||
"CREATE TABLE test AS SELECT 1 AS id, 2.5 AS val"
|
||||
)
|
||||
|
||||
output_path = str(tmp_path / "output" / "result.parquet")
|
||||
|
||||
with patch("src.remote_query._load_remote_query_config", return_value={
|
||||
"output_dir": str(tmp_path / "default_output"),
|
||||
"timeout_seconds": 300,
|
||||
"max_result_rows": 100_000,
|
||||
"max_bq_registration_rows": 500_000,
|
||||
"default_format": "table",
|
||||
}):
|
||||
_format_output(
|
||||
conn=duckdb_conn,
|
||||
sql="SELECT * FROM test",
|
||||
fmt="parquet",
|
||||
output_path=output_path,
|
||||
max_rows=1000,
|
||||
)
|
||||
|
||||
assert Path(output_path).exists()
|
||||
|
||||
# Read it back and verify
|
||||
result = pq.read_table(output_path)
|
||||
assert result.num_rows == 1
|
||||
assert result.num_columns == 2
|
||||
assert result.column_names == ["id", "val"]
|
||||
|
||||
def test_parquet_default_path(self, tmp_path, duckdb_conn):
|
||||
"""Parquet with no output path should use config default dir."""
|
||||
duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS x")
|
||||
|
||||
default_dir = str(tmp_path / "default_output")
|
||||
with patch("src.remote_query._load_remote_query_config", return_value={
|
||||
"output_dir": default_dir,
|
||||
"timeout_seconds": 300,
|
||||
"max_result_rows": 100_000,
|
||||
"max_bq_registration_rows": 500_000,
|
||||
"default_format": "table",
|
||||
}):
|
||||
_format_output(
|
||||
conn=duckdb_conn,
|
||||
sql="SELECT * FROM test",
|
||||
fmt="parquet",
|
||||
output_path=None,
|
||||
max_rows=1000,
|
||||
)
|
||||
|
||||
expected_path = Path(default_dir) / "result.parquet"
|
||||
assert expected_path.exists()
|
||||
|
||||
def test_unknown_format_raises(self, duckdb_conn):
    """An unsupported fmt value makes _format_output raise RemoteQueryError."""
    duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id")

    call_args = dict(
        conn=duckdb_conn,
        sql="SELECT * FROM test",
        fmt="xml",  # not one of the formats the other tests exercise
        output_path=None,
        max_rows=1000,
    )
    with pytest.raises(RemoteQueryError, match="Unknown format"):
        _format_output(**call_args)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: End-to-end (local-only, no BQ mocking needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEndToEnd:
    """End-to-end tests with local Parquet data only (no BigQuery dependency).

    Every test goes through ``_run``, which passes an empty
    ``bq_registrations`` list by default and replaces ``duckdb.connect``
    inside ``src.remote_query`` with ``_patched_duckdb_connect`` to handle
    DuckDB version differences (statement_timeout may not be supported in
    all versions).
    """

    # Baseline configuration returned by the patched config loader;
    # individual tests override keys through ``config_overrides``.
    _CONFIG = {
        "timeout_seconds": 300,
        "max_result_rows": 100_000,
        "max_bq_registration_rows": 500_000,
        "default_format": "table",
        "output_dir": "/tmp/remote_query_test",
    }

    def _run(self, tmp_local_project, **kwargs):
        """Run execute_remote_query with the standard patches applied.

        Args:
            tmp_local_project: fixture value, unpacked as
                ``(project_root, data_dir)``.
            **kwargs: forwarded to ``execute_remote_query``. An optional
                ``config_overrides`` dict is removed first and merged over
                ``_CONFIG``; ``data_dir``, ``bq_registrations`` and
                ``quiet`` are filled in when the caller did not pass them.
        """
        project_root, data_dir = tmp_local_project
        config = dict(self._CONFIG)
        config.update(kwargs.pop("config_overrides", {}))

        with patch("scripts.duckdb_manager.find_project_root", return_value=project_root), \
                patch("src.remote_query._load_remote_query_config", return_value=config), \
                patch("src.remote_query.duckdb") as mock_duckdb_mod:
            # The whole module is mocked; only connect() gets a real
            # (version-tolerant) implementation.
            mock_duckdb_mod.connect = _patched_duckdb_connect
            kwargs.setdefault("data_dir", data_dir)
            kwargs.setdefault("bq_registrations", [])
            kwargs.setdefault("quiet", True)
            execute_remote_query(**kwargs)

    def test_local_only_query(self, tmp_local_project, capsys):
        """Execute a query against local Parquet views and verify table output."""
        self._run(
            tmp_local_project,
            sql="SELECT COUNT(*) AS cnt FROM orders",
            fmt="table",
        )

        output = capsys.readouterr().out
        assert "cnt" in output
        assert "5" in output

    def test_local_join_query(self, tmp_local_project, capsys):
        """JOIN between two local tables should work."""
        self._run(
            tmp_local_project,
            sql=(
                "SELECT p.name, SUM(o.amount) AS total "
                "FROM orders o JOIN products p ON o.product_id = p.product_id "
                "GROUP BY p.name ORDER BY total DESC"
            ),
            fmt="json",
        )

        output = capsys.readouterr().out
        data = json.loads(output)
        assert len(data) == 3
        # Widget: orders 1,3 -> 10+30=40
        widget = next(r for r in data if r["name"] == "Widget")
        assert widget["total"] == 40.0

    def test_result_row_limit(self, tmp_local_project, capsys):
        """Result exceeding max_rows should be truncated."""
        self._run(
            tmp_local_project,
            sql="SELECT * FROM orders ORDER BY order_id",
            fmt="table",
            max_rows=2,
            quiet=False,
            config_overrides={"max_result_rows": 2},
        )

        out = capsys.readouterr().out
        # Table output should show exactly 2 data rows
        assert "(2 rows)" in out

    def test_csv_output_to_file(self, tmp_local_project, tmp_path):
        """End-to-end CSV output written to a file."""
        output_path = str(tmp_path / "result.csv")

        self._run(
            tmp_local_project,
            sql="SELECT order_id, amount FROM orders ORDER BY order_id",
            fmt="csv",
            output=output_path,
        )

        # newline="" lets the csv module do its own newline handling, as
        # its documentation requires for file objects passed to a reader.
        with open(output_path, newline="") as f:
            rows = list(csv.DictReader(f))

        assert len(rows) == 5
        assert rows[0]["order_id"] == "1"
        assert rows[0]["amount"] == "10.0"

    def test_hybrid_table_queryable(self, tmp_local_project, capsys):
        """Hybrid table should be accessible in local queries."""
        self._run(
            tmp_local_project,
            sql="SELECT SUM(stock) AS total_stock FROM inventory",
            fmt="json",
        )

        output = capsys.readouterr().out
        data = json.loads(output)
        assert data[0]["total_stock"] == 300

    def test_quiet_mode_suppresses_stderr(self, tmp_local_project, capsys):
        """With quiet=True, no progress messages should appear on stderr."""
        self._run(
            tmp_local_project,
            sql="SELECT COUNT(*) AS cnt FROM orders",
            fmt="table",
            quiet=True,
        )

        err = capsys.readouterr().err
        # In quiet mode, _log_progress should not emit anything
        assert "Setting up" not in err
        assert "local views" not in err
|
||||
|
|
@ -1,363 +0,0 @@
|
|||
"""Tests for the Table Registry module."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from src.table_registry import ConflictError, TableRegistry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
def registry_path(tmp_path):
    """Path of the registry JSON file inside pytest's per-test temp directory."""
    json_file = tmp_path / "table_registry.json"
    return json_file
|
||||
|
||||
|
||||
@pytest.fixture
def registry(registry_path):
    """Provide an empty TableRegistry constructed on the temp registry path."""
    empty_registry = TableRegistry(registry_path)
    return empty_registry
|
||||
|
||||
|
||||
@pytest.fixture
def sample_table():
    """Smallest table definition the tests register: a full-refresh table."""
    return dict(
        id="in.c-crm.company",
        name="company",
        description="Customer master data",
        primary_key="id",
        sync_strategy="full_refresh",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def sample_table_incremental():
    """Table definition using the incremental strategy plus partition fields."""
    return dict(
        id="in.c-crm.events",
        name="events",
        description="User events",
        primary_key="event_id",
        sync_strategy="incremental",
        incremental_window_days=14,
        partition_by="created_at",
        partition_granularity="month",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Basic CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRegistryCRUD:
    """Register / look up / update / unregister behaviour of TableRegistry."""

    # Values shared by most tests in this class.
    _ADMIN = "admin@test.com"
    _TABLE_ID = "in.c-crm.company"

    def test_empty_registry(self, registry):
        """A fresh registry lists nothing and reports version 0."""
        assert registry.list_tables() == []
        assert registry.version == 0

    def test_register_table(self, registry, sample_table):
        """Registering stores the table, records the author and bumps the version."""
        registry.register_table(sample_table, registered_by=self._ADMIN)

        tables = registry.list_tables()
        assert len(tables) == 1
        only = tables[0]
        assert only["id"] == self._TABLE_ID
        assert only["registered_by"] == self._ADMIN
        assert registry.version == 1

    def test_register_duplicate_raises(self, registry, sample_table):
        """Registering the same definition twice is rejected with ValueError."""
        registry.register_table(sample_table, registered_by=self._ADMIN)
        with pytest.raises(ValueError, match="already registered"):
            registry.register_table(sample_table, registered_by=self._ADMIN)

    def test_get_table(self, registry, sample_table):
        """get_table returns the stored definition for a known id."""
        registry.register_table(sample_table, registered_by=self._ADMIN)
        found = registry.get_table(self._TABLE_ID)
        assert found is not None
        assert found["name"] == "company"

    def test_get_table_not_found(self, registry):
        """get_table returns None for an unknown id."""
        assert registry.get_table("nonexistent") is None

    def test_is_registered(self, registry, sample_table):
        """is_registered flips from False to True once the table is registered."""
        assert not registry.is_registered(self._TABLE_ID)
        registry.register_table(sample_table, registered_by=self._ADMIN)
        assert registry.is_registered(self._TABLE_ID)

    def test_unregister_table(self, registry, sample_table):
        """Unregistering removes the table from lookups and from the listing."""
        registry.register_table(sample_table, registered_by=self._ADMIN)
        registry.unregister_table(self._TABLE_ID, unregistered_by=self._ADMIN)
        assert not registry.is_registered(self._TABLE_ID)
        assert registry.list_tables() == []

    def test_unregister_nonexistent_raises(self, registry):
        """Unregistering an unknown id raises ValueError."""
        with pytest.raises(ValueError, match="not registered"):
            registry.unregister_table("nonexistent")

    def test_update_table(self, registry, sample_table):
        """update_table applies the given field changes to the stored definition."""
        registry.register_table(sample_table, registered_by=self._ADMIN)

        changes = {
            "description": "Updated description",
            "sync_strategy": "incremental",
        }
        registry.update_table(self._TABLE_ID, changes, updated_by=self._ADMIN)

        stored = registry.get_table(self._TABLE_ID)
        assert stored["description"] == "Updated description"
        assert stored["sync_strategy"] == "incremental"

    def test_update_nonexistent_raises(self, registry):
        """Updating an unknown id raises ValueError."""
        with pytest.raises(ValueError, match="not registered"):
            registry.update_table("nonexistent", {"description": "x"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidation:
    """register_table rejects malformed table definitions with ValueError."""

    _ADMIN = "admin@test.com"

    def test_missing_id_raises(self, registry):
        """A definition without 'id' is rejected."""
        no_id = {"name": "x", "sync_strategy": "full_refresh", "primary_key": "id"}
        with pytest.raises(ValueError, match="must include 'id'"):
            registry.register_table(no_id, registered_by=self._ADMIN)

    def test_missing_name_raises(self, registry):
        """A definition without 'name' is rejected."""
        no_name = {"id": "x.y.z", "sync_strategy": "full_refresh", "primary_key": "id"}
        with pytest.raises(ValueError, match="must include 'name'"):
            registry.register_table(no_name, registered_by=self._ADMIN)

    def test_invalid_sync_strategy_raises(self, registry):
        """An unrecognised sync_strategy value is rejected."""
        bad_strategy = {
            "id": "x.y.z",
            "name": "z",
            "sync_strategy": "magic",
            "primary_key": "id",
        }
        with pytest.raises(ValueError, match="Invalid sync_strategy"):
            registry.register_table(bad_strategy, registered_by=self._ADMIN)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optimistic locking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOptimisticLocking:
    """Mutations that pass a stale expected_version raise ConflictError."""

    _ADMIN = "admin@test.com"
    _TABLE_ID = "in.c-crm.company"

    def test_register_with_wrong_version_raises(self, registry, sample_table):
        """Registering against version 99 on a fresh registry conflicts."""
        with pytest.raises(ConflictError, match="Version conflict"):
            registry.register_table(
                sample_table,
                registered_by=self._ADMIN,
                expected_version=99,
            )

    def test_register_with_correct_version(self, registry, sample_table):
        """Registering against the current version (0) succeeds and bumps it to 1."""
        registry.register_table(
            sample_table,
            registered_by=self._ADMIN,
            expected_version=0,
        )
        assert registry.version == 1

    def test_unregister_with_wrong_version_raises(self, registry, sample_table):
        """After one registration, unregistering with expected_version=0 conflicts."""
        registry.register_table(sample_table, registered_by=self._ADMIN)
        with pytest.raises(ConflictError):
            registry.unregister_table(self._TABLE_ID, expected_version=0)

    def test_update_with_wrong_version_raises(self, registry, sample_table):
        """After one registration, updating with expected_version=0 conflicts."""
        registry.register_table(sample_table, registered_by=self._ADMIN)
        with pytest.raises(ConflictError):
            registry.update_table(
                self._TABLE_ID,
                {"description": "x"},
                expected_version=0,
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPersistence:
    """Registry state is written to disk and can be read back."""

    _ADMIN = "admin@test.com"

    def test_save_and_reload(self, registry_path, sample_table):
        """A second TableRegistry on the same path sees the first one's table."""
        writer = TableRegistry(registry_path)
        writer.register_table(sample_table, registered_by=self._ADMIN)

        # A new instance on the same path has to load what was saved.
        reader = TableRegistry(registry_path)
        assert len(reader.list_tables()) == 1
        assert reader.get_table("in.c-crm.company")["name"] == "company"
        assert reader.version == 1

    def test_json_format(self, registry_path, sample_table):
        """The file on disk is JSON with '_metadata' and 'tables' top-level keys."""
        reg = TableRegistry(registry_path)
        reg.register_table(sample_table, registered_by=self._ADMIN)

        with open(registry_path) as handle:
            on_disk = json.load(handle)

        assert "_metadata" in on_disk
        assert "tables" in on_disk
        assert on_disk["_metadata"]["version"] == 1
        assert len(on_disk["tables"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Folder mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFolderMapping:
    """Folder mapping entries can be set, read back and survive a reload."""

    _EXPECTED = {"in.c-crm": "crm"}

    def test_set_and_get(self, registry):
        """A mapping set on a registry is returned by get_folder_mapping."""
        registry.set_folder_mapping("in.c-crm", "crm")
        assert registry.get_folder_mapping() == self._EXPECTED

    def test_persists(self, registry_path):
        """A mapping set by one instance is visible to a new instance on the same path."""
        first = TableRegistry(registry_path)
        first.set_folder_mapping("in.c-crm", "crm")

        second = TableRegistry(registry_path)
        assert second.get_folder_mapping() == self._EXPECTED
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGeneration:
    """generate_data_description_md writes a markdown file with a YAML block."""

    @staticmethod
    def _load_yaml_block(markdown):
        """Parse and return the first ```yaml fenced block found in *markdown*.

        Fails the calling test with a clear message when no such block
        exists, instead of letting ``.group`` raise AttributeError on None.
        """
        import re  # local import: keeps this block independent of the file's import list

        match = re.search(r"```yaml\n(.*?)```", markdown, re.DOTALL)
        assert match, "generated markdown contains no ```yaml fenced block"
        return yaml.safe_load(match.group(1))

    def test_generate_data_description_md(self, registry, sample_table, tmp_path):
        """Generated file has the header markers and a parseable YAML block."""
        registry.register_table(sample_table, registered_by="admin@test.com")
        registry.set_folder_mapping("in.c-crm", "crm")

        output = tmp_path / "data_description.md"
        registry.generate_data_description_md(output)

        content = output.read_text()

        # Check header
        assert "AUTO-GENERATED" in content
        assert "checksum: sha256:" in content

        # Check YAML block is parseable
        yaml_data = self._load_yaml_block(content)
        assert len(yaml_data["tables"]) == 1
        assert yaml_data["tables"][0]["id"] == "in.c-crm.company"
        assert yaml_data["folder_mapping"] == {"in.c-crm": "crm"}

    def test_generate_includes_incremental_fields(
        self, registry, sample_table_incremental, tmp_path
    ):
        """Incremental/partition fields of a table are carried into the YAML."""
        registry.register_table(sample_table_incremental, registered_by="admin@test.com")

        output = tmp_path / "data_description.md"
        registry.generate_data_description_md(output)

        yaml_data = self._load_yaml_block(output.read_text())
        table = yaml_data["tables"][0]
        assert table["partition_by"] == "created_at"
        assert table["partition_granularity"] == "month"
        assert table["incremental_window_days"] == 14
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Migration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMigration:
    """TableRegistry.import_from_data_description builds a registry from markdown."""

    def test_import_from_data_description(self, tmp_path):
        """Both tables and the folder mapping are imported and the origin is recorded."""
        # Create a fake data_description.md. The document is assembled from
        # explicit lines so the YAML nesting (which carries meaning) cannot
        # be lost to source re-indentation.
        md_lines = [
            "# Data Description",
            "",
            "```yaml",
            "folder_mapping:",
            "  in.c-crm: crm",
            "",
            "tables:",
            "  - id: in.c-crm.company",
            "    name: company",
            "    description: Companies",
            "    primary_key: id",
            "    sync_strategy: full_refresh",
            "",
            "  - id: in.c-crm.contact",
            "    name: contact",
            "    description: Contacts",
            "    primary_key: id",
            "    sync_strategy: incremental",
            "    incremental_window_days: 7",
            "```",
            "",  # yields the trailing newline after the closing fence
        ]
        md_path = tmp_path / "data_description.md"
        md_path.write_text("\n".join(md_lines))

        registry_path = tmp_path / "table_registry.json"
        registry = TableRegistry.import_from_data_description(md_path, registry_path)

        assert len(registry.list_tables()) == 2
        assert registry.is_registered("in.c-crm.company")
        assert registry.is_registered("in.c-crm.contact")
        assert registry.get_folder_mapping() == {"in.c-crm": "crm"}

        # Check migrated_from marker
        with open(registry_path) as f:
            data = json.load(f)
        assert "migrated_from" in data["_metadata"]

    def test_import_no_yaml_raises(self, tmp_path):
        """A markdown file without any YAML block is rejected with ValueError."""
        md_path = tmp_path / "data_description.md"
        md_path.write_text("# Empty file\nNo YAML here.")

        with pytest.raises(ValueError, match="No YAML blocks"):
            TableRegistry.import_from_data_description(
                md_path, tmp_path / "registry.json"
            )

    def test_import_file_not_found_raises(self, tmp_path):
        """A missing source file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            TableRegistry.import_from_data_description(
                tmp_path / "nonexistent.md", tmp_path / "registry.json"
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audit log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAuditLog:
    """Registry mutations append JSON lines to registry_audit.log next to the registry file."""

    _ADMIN = "admin@test.com"
    _TABLE_ID = "in.c-crm.company"

    def test_register_writes_audit(self, registry, sample_table):
        """Registering creates the audit log; its last entry describes the registration."""
        registry.register_table(sample_table, registered_by=self._ADMIN)

        audit_path = registry.registry_path.parent / "registry_audit.log"
        assert audit_path.exists()

        raw_lines = audit_path.read_text().strip().split("\n")
        assert len(raw_lines) >= 1
        newest = json.loads(raw_lines[-1])
        assert newest["action"] == "register"
        assert newest["table_id"] == self._TABLE_ID

    def test_unregister_writes_audit(self, registry, sample_table):
        """After register + unregister, the last audit entry is the unregister action."""
        registry.register_table(sample_table, registered_by=self._ADMIN)
        registry.unregister_table(self._TABLE_ID, unregistered_by=self._ADMIN)

        audit_path = registry.registry_path.parent / "registry_audit.log"
        raw_lines = audit_path.read_text().strip().split("\n")
        newest = json.loads(raw_lines[-1])
        assert newest["action"] == "unregister"
|
||||
156
webapp/app.py
156
webapp/app.py
|
|
@ -611,19 +611,11 @@ def _load_catalog_data() -> list:
|
|||
# Enrich with catalog metadata (OpenMetadata)
|
||||
if _catalog_enricher:
|
||||
try:
|
||||
# Create config for enrichment with all available fields
|
||||
from src.config import TableConfig
|
||||
table_config = TableConfig(
|
||||
# Create lightweight config for enrichment (enricher uses .id, .name, .catalog_fqn)
|
||||
from types import SimpleNamespace
|
||||
table_config = SimpleNamespace(
|
||||
id=table_id,
|
||||
name=table.get("name", ""),
|
||||
description=table.get("description", ""),
|
||||
primary_key=table.get("primary_key", "id"),
|
||||
sync_strategy=table.get("sync_strategy", "full_refresh"),
|
||||
incremental_window_days=table.get("incremental_window_days"),
|
||||
partition_by=table.get("partition_by"),
|
||||
partition_granularity=table.get("partition_granularity"),
|
||||
max_history_days=table.get("max_history_days"),
|
||||
partition_column_type=table.get("partition_column_type", "TIMESTAMP"),
|
||||
catalog_fqn=table.get("catalog_fqn"),
|
||||
)
|
||||
catalog_data = _catalog_enricher.enrich_table(table_config)
|
||||
|
|
@ -1097,7 +1089,7 @@ def register_routes(app: Flask) -> None:
|
|||
if _catalog_enricher and _catalog_enricher.enabled:
|
||||
try:
|
||||
# Find table config from data_description.md
|
||||
from src.config import TableConfig
|
||||
from types import SimpleNamespace
|
||||
from config.loader import load_instance_config
|
||||
|
||||
# Load data_description.md to find table config by name
|
||||
|
|
@ -1116,17 +1108,10 @@ def register_routes(app: Flask) -> None:
|
|||
# Find table by name
|
||||
for table_def in yaml_data["tables"]:
|
||||
if table_def.get("name") == table_name:
|
||||
table_config = TableConfig(
|
||||
# Lightweight config (enricher uses .id, .name, .catalog_fqn)
|
||||
table_config = SimpleNamespace(
|
||||
id=table_def.get("id", ""),
|
||||
name=table_def.get("name", ""),
|
||||
description=table_def.get("description", ""),
|
||||
primary_key=table_def.get("primary_key", "id"),
|
||||
sync_strategy=table_def.get("sync_strategy", "full_refresh"),
|
||||
incremental_window_days=table_def.get("incremental_window_days"),
|
||||
partition_by=table_def.get("partition_by"),
|
||||
partition_granularity=table_def.get("partition_granularity"),
|
||||
max_history_days=table_def.get("max_history_days"),
|
||||
partition_column_type=table_def.get("partition_column_type", "TIMESTAMP"),
|
||||
catalog_fqn=table_def.get("catalog_fqn"),
|
||||
)
|
||||
catalog_data = _catalog_enricher.enrich_table(table_config)
|
||||
|
|
@ -1862,17 +1847,25 @@ def register_routes(app: Flask) -> None:
|
|||
def admin_discover_tables():
|
||||
"""Discover all available tables from the data source."""
|
||||
try:
|
||||
from src.data_sync import create_data_source
|
||||
from app.instance_config import get_data_source_type, get_value
|
||||
|
||||
ds = create_data_source()
|
||||
raw_tables = ds.discover_tables()
|
||||
source_type = get_data_source_type()
|
||||
raw_tables = []
|
||||
if source_type == "keboola":
|
||||
from connectors.keboola.client import KeboolaClient
|
||||
url = get_value("keboola", "url", default="")
|
||||
token = os.environ.get(get_value("keboola", "token_env", default="KEBOOLA_STORAGE_TOKEN"), "")
|
||||
client = KeboolaClient(token=token, url=url)
|
||||
raw_tables = client.discover_all_tables()
|
||||
|
||||
# Check which tables are already registered
|
||||
registered_ids = set()
|
||||
try:
|
||||
from src.table_registry import TableRegistry
|
||||
registry = TableRegistry.default()
|
||||
registered_ids = {t["id"] for t in registry.list_tables()}
|
||||
from src.db import get_system_db
|
||||
from src.repositories.table_registry import TableRegistryRepository
|
||||
conn = get_system_db()
|
||||
repo = TableRegistryRepository(conn)
|
||||
registered_ids = {t["id"] for t in repo.list_all()}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
@ -1905,14 +1898,16 @@ def register_routes(app: Flask) -> None:
|
|||
def admin_registry_list():
|
||||
"""Return the full table registry."""
|
||||
try:
|
||||
from src.table_registry import TableRegistry
|
||||
from src.db import get_system_db
|
||||
from src.repositories.table_registry import TableRegistryRepository
|
||||
|
||||
registry = TableRegistry.default()
|
||||
conn = get_system_db()
|
||||
repo = TableRegistryRepository(conn)
|
||||
return jsonify({
|
||||
"ok": True,
|
||||
"version": registry.version,
|
||||
"folder_mapping": registry.get_folder_mapping(),
|
||||
"tables": registry.list_tables(),
|
||||
"version": 0,
|
||||
"folder_mapping": {},
|
||||
"tables": repo.list_all(),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Registry list failed: {e}")
|
||||
|
|
@ -1923,7 +1918,8 @@ def register_routes(app: Flask) -> None:
|
|||
@admin_required
|
||||
def admin_register_table():
|
||||
"""Register a new table from discovery results."""
|
||||
from src.table_registry import ConflictError, TableRegistry
|
||||
from src.db import get_system_db
|
||||
from src.repositories.table_registry import TableRegistryRepository
|
||||
|
||||
user = session.get("user", {})
|
||||
email = user.get("email", "")
|
||||
|
|
@ -1933,21 +1929,26 @@ def register_routes(app: Flask) -> None:
|
|||
return jsonify({"error": "Missing table 'id'"}), 400
|
||||
|
||||
try:
|
||||
registry = TableRegistry.default()
|
||||
registry.register_table(
|
||||
table_def=data,
|
||||
conn = get_system_db()
|
||||
repo = TableRegistryRepository(conn)
|
||||
repo.register(
|
||||
id=data["id"],
|
||||
name=data.get("name", ""),
|
||||
folder=data.get("folder"),
|
||||
sync_strategy=data.get("sync_strategy"),
|
||||
primary_key=data.get("primary_key"),
|
||||
description=data.get("description"),
|
||||
registered_by=email,
|
||||
expected_version=data.get("version"),
|
||||
source_type=data.get("source_type"),
|
||||
bucket=data.get("bucket"),
|
||||
source_table=data.get("source_table"),
|
||||
query_mode=data.get("query_mode", "local"),
|
||||
sync_schedule=data.get("sync_schedule"),
|
||||
profile_after_sync=data.get("profile_after_sync", True),
|
||||
)
|
||||
|
||||
# Regenerate data_description.md
|
||||
docs_path = Path(os.path.dirname(__file__)) / ".." / "docs" / "data_description.md"
|
||||
registry.generate_data_description_md(docs_path.resolve())
|
||||
return jsonify({"ok": True})
|
||||
|
||||
return jsonify({"ok": True, "version": registry.version})
|
||||
|
||||
except ConflictError as e:
|
||||
return jsonify({"error": str(e)}), 409
|
||||
except ValueError as e:
|
||||
return jsonify({"error": str(e)}), 400
|
||||
except Exception as e:
|
||||
|
|
@ -1959,30 +1960,42 @@ def register_routes(app: Flask) -> None:
|
|||
@admin_required
|
||||
def admin_update_table(table_id):
|
||||
"""Update configuration of a registered table."""
|
||||
from src.table_registry import ConflictError, TableRegistry
|
||||
from src.db import get_system_db
|
||||
from src.repositories.table_registry import TableRegistryRepository
|
||||
|
||||
user = session.get("user", {})
|
||||
email = user.get("email", "")
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
data.pop("version", None) # Not used by DuckDB repo
|
||||
|
||||
try:
|
||||
registry = TableRegistry.default()
|
||||
registry.update_table(
|
||||
table_id=table_id,
|
||||
updates=data,
|
||||
updated_by=email,
|
||||
expected_version=data.pop("version", None),
|
||||
conn = get_system_db()
|
||||
repo = TableRegistryRepository(conn)
|
||||
|
||||
# Get existing record and merge updates
|
||||
existing = repo.get(table_id)
|
||||
if not existing:
|
||||
return jsonify({"error": f"Table '{table_id}' not found"}), 404
|
||||
|
||||
repo.register(
|
||||
id=table_id,
|
||||
name=data.get("name", existing.get("name", "")),
|
||||
folder=data.get("folder", existing.get("folder")),
|
||||
sync_strategy=data.get("sync_strategy", existing.get("sync_strategy")),
|
||||
primary_key=data.get("primary_key", existing.get("primary_key")),
|
||||
description=data.get("description", existing.get("description")),
|
||||
registered_by=email,
|
||||
source_type=data.get("source_type", existing.get("source_type")),
|
||||
bucket=data.get("bucket", existing.get("bucket")),
|
||||
source_table=data.get("source_table", existing.get("source_table")),
|
||||
query_mode=data.get("query_mode", existing.get("query_mode", "local")),
|
||||
sync_schedule=data.get("sync_schedule", existing.get("sync_schedule")),
|
||||
profile_after_sync=data.get("profile_after_sync", existing.get("profile_after_sync", True)),
|
||||
)
|
||||
|
||||
# Regenerate data_description.md
|
||||
docs_path = Path(os.path.dirname(__file__)) / ".." / "docs" / "data_description.md"
|
||||
registry.generate_data_description_md(docs_path.resolve())
|
||||
return jsonify({"ok": True})
|
||||
|
||||
return jsonify({"ok": True, "version": registry.version})
|
||||
|
||||
except ConflictError as e:
|
||||
return jsonify({"error": str(e)}), 409
|
||||
except ValueError as e:
|
||||
return jsonify({"error": str(e)}), 400
|
||||
except Exception as e:
|
||||
|
|
@ -1994,25 +2007,18 @@ def register_routes(app: Flask) -> None:
|
|||
@admin_required
|
||||
def admin_unregister_table(table_id):
|
||||
"""Unregister a table and clean up subscriptions."""
|
||||
from src.table_registry import ConflictError, TableRegistry
|
||||
|
||||
user = session.get("user", {})
|
||||
email = user.get("email", "")
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
from src.db import get_system_db
|
||||
from src.repositories.table_registry import TableRegistryRepository
|
||||
|
||||
try:
|
||||
registry = TableRegistry.default()
|
||||
conn = get_system_db()
|
||||
repo = TableRegistryRepository(conn)
|
||||
|
||||
# Get table name before deletion (for subscription cleanup)
|
||||
table_info = registry.get_table(table_id)
|
||||
table_info = repo.get(table_id)
|
||||
table_name = table_info["name"] if table_info else None
|
||||
|
||||
registry.unregister_table(
|
||||
table_id=table_id,
|
||||
unregistered_by=email,
|
||||
expected_version=data.get("version"),
|
||||
)
|
||||
repo.unregister(table_id)
|
||||
|
||||
# Clean up per-user subscriptions for removed table
|
||||
if table_name:
|
||||
|
|
@ -2021,14 +2027,8 @@ def register_routes(app: Flask) -> None:
|
|||
except Exception as ce:
|
||||
logger.warning(f"Subscription cleanup for {table_name} failed: {ce}")
|
||||
|
||||
# Regenerate data_description.md
|
||||
docs_path = Path(os.path.dirname(__file__)) / ".." / "docs" / "data_description.md"
|
||||
registry.generate_data_description_md(docs_path.resolve())
|
||||
return jsonify({"ok": True})
|
||||
|
||||
return jsonify({"ok": True, "version": registry.version})
|
||||
|
||||
except ConflictError as e:
|
||||
return jsonify({"error": str(e)}), 409
|
||||
except ValueError as e:
|
||||
return jsonify({"error": str(e)}), 400
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -194,9 +194,12 @@ def _write_rsync_filter(username: str, dataset_settings: dict, table_mode: str,
|
|||
# Load folder_mapping from table registry (or instance config as fallback)
|
||||
folder_mapping = {}
|
||||
try:
|
||||
from src.table_registry import TableRegistry
|
||||
registry = TableRegistry.default()
|
||||
folder_mapping = registry.get_folder_mapping()
|
||||
from src.db import get_system_db
|
||||
from src.repositories.table_registry import TableRegistryRepository
|
||||
conn = get_system_db()
|
||||
repo = TableRegistryRepository(conn)
|
||||
tables = repo.list_all()
|
||||
folder_mapping = {t["bucket"]: t["folder"] for t in tables if t.get("bucket") and t.get("folder")}
|
||||
except Exception:
|
||||
try:
|
||||
from config.loader import load_instance_config, get_instance_value
|
||||
|
|
|
|||
Loading…
Reference in a new issue