refactor: delete old sync pipeline — 9,500 lines removed

Phase 5 cleanup: remove all code replaced by extract.duckdb architecture.

Deleted modules:
- src/config.py (653) — replaced by DuckDB table_registry
- src/parquet_manager.py (755) — replaced by DuckDB COPY TO
- src/data_sync.py (734) — replaced by SyncOrchestrator
- src/remote_query.py (636) — replaced by DuckDB BigQuery ATTACH
- src/table_registry.py (464) — replaced by DuckDB repository
- connectors/keboola/adapter.py (820) — replaced by extractor.py
- connectors/bigquery/adapter.py (665) — replaced by extractor.py
- connectors/bigquery/client.py (644) — replaced by DuckDB BQ extension

Updated all imports in webapp, catalog_export, enricher, router,
sync_settings_service, generate_sample_data. Kept keboola/client.py
as fallback (removed src.config dependency).

704 tests passing.
This commit is contained in:
ZdenekSrotyr 2026-03-31 07:50:37 +02:00
parent 9f20529f10
commit b502bd8bdd
26 changed files with 188 additions and 9490 deletions

View file

@ -236,24 +236,25 @@ async def catalog(
enabled_datasets = settings_repo.get_enabled_datasets(user["id"])
datasets = get_datasets()
# Build catalog data from config
# Build catalog data from table_registry in DuckDB
try:
from src.config import get_config
config = get_config()
from src.repositories.table_registry import TableRegistryRepository
table_repo = TableRegistryRepository(conn)
registered = table_repo.list_all()
tables = []
for tc in config.tables:
for tc in registered:
table_data = {
"id": tc.id,
"name": tc.name,
"description": tc.description,
"dataset": getattr(tc, "dataset", None),
"sync_strategy": tc.sync_strategy,
"query_mode": getattr(tc, "query_mode", "local"),
"profile": all_profiles.get(tc.id),
"id": tc.get("id", ""),
"name": tc.get("name", ""),
"description": tc.get("description", ""),
"dataset": tc.get("bucket"),
"sync_strategy": tc.get("sync_strategy", "full_refresh"),
"query_mode": tc.get("query_mode", "local"),
"profile": all_profiles.get(tc.get("id", "")),
}
# Add sync state
for state in all_states:
if state["table_id"] == tc.id:
if state["table_id"] == tc.get("id"):
table_data["last_sync"] = state.get("last_sync")
table_data["rows"] = state.get("rows")
break

View file

@ -1,665 +0,0 @@
"""
BigQuery data source adapter.
Implements the DataSource interface for Google BigQuery.
Reads tables via the BigQuery API, converts directly to Parquet files
using PyArrow (no CSV intermediate step).
"""
import logging
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta, date
import pyarrow as pa
import pyarrow.parquet as pq
from src.config import get_config, TableConfig
from src.data_sync import DataSource, SyncState, _get_uncompressed_size
from src.parquet_manager import (
convert_date_columns_to_date32,
apply_schema_to_table,
)
from .client import create_client as create_bq_client
logger = logging.getLogger(__name__)
class BigQueryDataSource(DataSource):
"""
Data source: Google BigQuery.
Downloads data directly from BigQuery via PyArrow (no CSV step),
writes to local Parquet files with schema enforcement.
"""
def __init__(self):
    """Load the project configuration and build the BigQuery client.

    The configuration comes from ``get_config()`` and the client from
    ``create_bq_client()``; both are kept on the instance for later use.
    """
    loaded_config = get_config()
    self.config = loaded_config
    self.bq_client = create_bq_client()
def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]:
"""Return BigQuery column metadata for schema generation.
Returns:
{"columns": {"col_name": {"source_type": "...", "description": "..."}}}
or None if metadata unavailable.
"""
raw = self.bq_client.get_table_metadata(table_id)
column_types = raw.get("column_types", {})
column_descriptions = raw.get("column_descriptions", {})
if not column_types:
return None
result = {}
for col_name, bq_type in column_types.items():
entry = {"source_type": bq_type}
if col_name in column_descriptions:
entry["description"] = column_descriptions[col_name]
result[col_name] = entry
return {"columns": result}
def discover_tables(self) -> List[Dict[str, Any]]:
"""Discover all available tables from BigQuery."""
return self.bq_client.discover_all_tables()
def get_source_name(self) -> str:
    """Return the human-readable name of this data source."""
    source_name = "Google BigQuery"
    return source_name
def sync_table(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Synchronize one table from BigQuery using its configured strategy.

    Any exception raised while syncing (including an unknown strategy) is
    logged and reported in the returned dictionary instead of propagating.

    Args:
        table_config: Table configuration (id, name, sync_strategy, ...)
        sync_state: Sync state manager; updated after a successful sync
    Returns:
        On success: ``{"success": True, "rows", "strategy", "file_size_mb"}``.
        On failure: ``{"success": False, "error", "strategy"}``.
    """
    table_id = table_config.id
    table_name = table_config.name
    strategy = table_config.sync_strategy
    logger.info(f"Syncing BQ table: {table_name} ({strategy})")

    # Drop the in-memory metadata entry so column types are re-read.
    cache = self.bq_client.metadata_cache
    if table_id in cache:
        cache.pop(table_id)
        logger.debug(f"Cleared BQ metadata cache for {table_id}")

    try:
        if strategy == "full_refresh":
            outcome = self._full_refresh(table_config)
        elif strategy == "incremental":
            outcome = self._incremental_sync(table_config, sync_state)
        elif strategy == "partitioned":
            outcome = self._partitioned_sync(table_config, sync_state)
        else:
            raise ValueError(f"Unknown sync strategy: {strategy}")

        # A partitioned sync that updated nothing must not advance
        # last_sync, so the scheduler retries on its next tick instead of
        # treating stale data as freshly synced.
        # NOTE(review): an "incremental" table with partition_by also runs
        # _partitioned_sync but is not covered by this guard - confirm
        # that is intended.
        no_new_partitions = (
            strategy == "partitioned"
            and outcome.get("partitions_updated", -1) == 0
        )
        if no_new_partitions:
            logger.warning(
                f"Partitioned sync for {table_name} got 0 new partitions "
                f"- NOT updating last_sync (will retry next tick)"
            )
        else:
            sync_state.update_sync(
                table_id=table_id,
                table_name=table_name,
                strategy=strategy,
                rows=outcome["rows"],
                file_size_bytes=outcome["file_size_bytes"],
                columns=outcome.get("columns", 0),
                uncompressed_bytes=outcome.get("uncompressed_bytes", 0),
            )

        return {
            "success": True,
            "rows": outcome["rows"],
            "strategy": strategy,
            "file_size_mb": outcome["file_size_bytes"] / 1024 / 1024,
        }
    except Exception as e:
        logger.error(f"Error syncing BQ table {table_name}: {e}")
        return {
            "success": False,
            "error": str(e),
            "strategy": strategy,
        }
def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]:
    """
    Full refresh: stream the table from BigQuery into a single Parquet file.

    Each non-empty RecordBatch is converted, schema-enforced and written
    through one ``ParquetWriter``, so the whole table is never held in
    memory at once. The writer is always closed, even when streaming or
    conversion fails part-way.

    If the query yields no non-empty batch, no file is written; a file that
    already exists at the target path is left untouched and its size is
    what gets reported.

    Args:
        table_config: Table configuration (id, columns, row_filter, ...)
    Returns:
        Dictionary with ``rows``, ``columns``, ``file_size_bytes`` and
        ``uncompressed_bytes``.
    """
    logger.info(f"Full refresh (streaming): {table_config.name}")
    parquet_path = self.config.get_parquet_path(table_config)
    date_columns = self.bq_client.get_date_columns(table_config.id)
    pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)
    parquet_path.parent.mkdir(parents=True, exist_ok=True)

    writer = None
    total_rows = 0
    num_columns = 0
    try:
        for batch in self.bq_client.read_table_streaming(
            table_config.id,
            columns=table_config.columns,
            row_filter=table_config.row_filter,
        ):
            if batch.num_rows == 0:
                continue

            # Schema helpers operate on Tables, so wrap the single batch.
            chunk = pa.Table.from_batches([batch])
            if date_columns:
                chunk = convert_date_columns_to_date32(chunk, date_columns)
            if pyarrow_schema:
                chunk = apply_schema_to_table(chunk, pyarrow_schema)

            # The writer is created lazily so its schema is the one
            # produced by the conversions above.
            if writer is None:
                writer = pq.ParquetWriter(
                    parquet_path, chunk.schema, compression="snappy",
                )
                num_columns = chunk.num_columns

            writer.write_table(chunk)
            total_rows += chunk.num_rows

            # Log whenever this chunk crossed a 1M-row boundary.
            if total_rows % 1_000_000 < chunk.num_rows:
                logger.info(f" -> {total_rows:,} rows written...")
    finally:
        # Release the file handle on both success and failure.
        if writer is not None:
            writer.close()

    file_size = parquet_path.stat().st_size if parquet_path.exists() else 0
    logger.info(
        f"Full refresh complete: {total_rows:,} rows, "
        f"{file_size / 1024 / 1024:.2f} MB"
    )
    return {
        "rows": total_rows,
        "columns": num_columns,
        "file_size_bytes": file_size,
        "uncompressed_bytes": _get_uncompressed_size(parquet_path) if total_rows > 0 else 0,
    }
def _incremental_sync(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Pick the concrete implementation for the "incremental" strategy.

    Precedence: ``partition_by`` selects the partitioned sync, otherwise
    ``incremental_column`` selects the column-based sync, otherwise the
    table is fully refreshed (with a warning).

    Args:
        table_config: Table configuration
        sync_state: Sync state manager
    Returns:
        Result dictionary of whichever sync implementation ran.
    """
    if table_config.partition_by:
        outcome = self._partitioned_sync(table_config, sync_state)
    elif table_config.incremental_column:
        outcome = self._incremental_column_sync(table_config, sync_state)
    else:
        # Nothing to drive an incremental read with: refresh everything.
        logger.warning(
            f"Table {table_config.name}: incremental strategy but no "
            f"incremental_column or partition_by configured, falling back to full refresh"
        )
        outcome = self._full_refresh(table_config)
    return outcome
def _incremental_column_sync(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Incremental sync driven by ``incremental_column``.

    First sync (no recorded last sync, or no Parquet file yet): read the
    table - limited to ``max_history_days`` when configured - and write it.
    Later syncs: re-read rows newer than ``last_sync`` minus the incremental
    window (7 days when unset), merge them into the existing file with
    primary-key deduplication, and rewrite the file.

    Args:
        table_config: Table configuration
        sync_state: Sync state manager (source of the last sync timestamp)
    Returns:
        Dictionary with ``rows``, ``columns``, ``file_size_bytes`` and
        ``uncompressed_bytes`` describing the file on disk.
    """
    logger.info(
        f"Incremental column sync: {table_config.name} "
        f"(column: {table_config.incremental_column})"
    )
    target = self.config.get_parquet_path(table_config)
    date_columns = self.bq_client.get_date_columns(table_config.id)
    pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)

    def enforce_schema(table):
        # Same two conversions are applied on both code paths.
        if date_columns:
            table = convert_date_columns_to_date32(table, date_columns)
        if pyarrow_schema:
            table = apply_schema_to_table(table, pyarrow_schema)
        return table

    def summarize(row_count, column_count):
        # Sizes are always read back from the file on disk.
        size_on_disk = target.stat().st_size
        return {
            "rows": row_count,
            "columns": column_count,
            "file_size_bytes": size_on_disk,
            "uncompressed_bytes": _get_uncompressed_size(target),
        }

    previous_sync = sync_state.get_last_sync(table_config.id)

    if not (previous_sync and target.exists()):
        # First sync or no existing file -- full read
        logger.info(" -> First sync, reading all data")
        if table_config.max_history_days:
            history_start = datetime.now() - timedelta(days=table_config.max_history_days)
            initial = self.bq_client.read_table_incremental(
                table_id=table_config.id,
                incremental_column=table_config.incremental_column,
                since_value=history_start.isoformat(),
                columns=table_config.columns,
            )
        else:
            initial = self.bq_client.read_table(
                table_config.id,
                columns=table_config.columns,
                row_filter=table_config.row_filter,
            )
        initial = enforce_schema(initial)
        target.parent.mkdir(parents=True, exist_ok=True)
        pq.write_table(initial, target, compression="snappy")
        return summarize(initial.num_rows, initial.num_columns)

    # Step back from the last sync by the configured window so rows that
    # changed shortly before it are read again.
    window_days = table_config.incremental_window_days or 7
    since_value = (
        datetime.fromisoformat(previous_sync) - timedelta(days=window_days)
    ).isoformat()
    logger.info(f" -> Since: {since_value} (window: {window_days} days)")

    fresh_rows = self.bq_client.read_table_incremental(
        table_id=table_config.id,
        incremental_column=table_config.incremental_column,
        since_value=since_value,
        columns=table_config.columns,
    )

    if fresh_rows.num_rows == 0:
        # Nothing changed: report the existing file without rewriting it.
        logger.info(" -> No new data since last sync")
        current_file = pq.ParquetFile(target)
        return summarize(
            current_file.metadata.num_rows,
            len(current_file.schema_arrow),
        )

    logger.info(f" -> Merging {fresh_rows.num_rows} new rows with existing data")
    on_disk = pq.read_table(target)
    combined = self._merge_arrow_tables(
        on_disk, fresh_rows, table_config.get_primary_key_columns()
    )
    combined = enforce_schema(combined)
    pq.write_table(combined, target, compression="snappy")
    logger.info(
        f" -> Incremental sync complete: {combined.num_rows} total rows"
    )
    return summarize(combined.num_rows, combined.num_columns)
def _partitioned_sync(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Sync a table one partition period at a time.

    The partition column is ``partition_by``, falling back to
    ``incremental_column``; with neither set the table is fully refreshed.
    The date range starts at ``last_sync`` minus the incremental window
    (7 days when unset), or on a first sync at ``max_history_days`` ago
    (365 days when unset), and ends today. Each period is delegated to
    ``_sync_single_partition``; afterwards partitions beyond the retention
    window are removed.

    Args:
        table_config: Table configuration
        sync_state: Sync state manager (source of the last sync timestamp)
    Returns:
        Totals over all partition files plus ``partitions_updated``.
    """
    partition_col = table_config.partition_by or table_config.incremental_column
    if not partition_col:
        logger.warning(
            f"Table {table_config.name}: partitioned strategy but no "
            f"partition_by or incremental_column, falling back to full refresh"
        )
        return self._full_refresh(table_config)

    granularity = table_config.partition_granularity or "day"
    column_type = table_config.partition_column_type
    logger.info(
        f"Partitioned sync: {table_config.name} "
        f"(by {partition_col}, {granularity}, type={column_type})"
    )

    partition_dir = self.config.get_parquet_path(table_config)
    date_columns = self.bq_client.get_date_columns(table_config.id)
    pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)

    # Work out where the date range begins.
    last_sync = sync_state.get_last_sync(table_config.id)
    today = date.today()
    if last_sync:
        window_days = table_config.incremental_window_days or 7
        resume_from = datetime.fromisoformat(last_sync) - timedelta(days=window_days)
        start_date = resume_from.date()
        logger.info(f" -> Incremental sync from {start_date} (window: {window_days} days)")
    elif table_config.max_history_days:
        start_date = today - timedelta(days=table_config.max_history_days)
        logger.info(f" -> First sync, last {table_config.max_history_days} days from {start_date}")
    else:
        start_date = today - timedelta(days=365)
        logger.info(" -> First sync, no max_history_days, defaulting to 365 days")

    partition_dates = self._generate_partition_dates(start_date, today, granularity)
    logger.info(f" -> Processing {len(partition_dates)} partitions")

    # Partitions are processed strictly in date order, one at a time.
    row_counts = [
        self._sync_single_partition(
            table_config=table_config,
            partition_col=partition_col,
            partition_date=partition_date,
            partition_dir=partition_dir,
            date_columns=date_columns,
            pyarrow_schema=pyarrow_schema,
            granularity=granularity,
            column_type=column_type,
        )
        for partition_date in partition_dates
    ]
    updated_counts = [count for count in row_counts if count > 0]
    partitions_updated = len(updated_counts)
    total_rows = sum(updated_counts)

    # Cleanup old partitions beyond retention window
    removed = self._cleanup_old_partitions(table_config, partition_dir, granularity)
    if removed > 0:
        logger.info(f" -> Cleaned up {removed} old partition files")

    logger.info(
        f" -> Partitioned sync complete: {partitions_updated} partitions updated, "
        f"{total_rows} total rows processed"
    )
    totals = self._get_partition_totals(partition_dir)
    totals["partitions_updated"] = partitions_updated
    return totals
@staticmethod
def _generate_partition_dates(
start_date: date,
end_date: date,
granularity: str,
) -> List[date]:
"""Generate list of partition start dates between start and end."""
dates = []
current = start_date
if granularity == "day":
while current <= end_date:
dates.append(current)
current += timedelta(days=1)
elif granularity == "month":
# Align to first of month
current = current.replace(day=1)
while current <= end_date:
dates.append(current)
# Move to first of next month
if current.month == 12:
current = current.replace(year=current.year + 1, month=1)
else:
current = current.replace(month=current.month + 1)
elif granularity == "year":
current = current.replace(month=1, day=1)
while current <= end_date:
dates.append(current)
current = current.replace(year=current.year + 1)
return dates
def _sync_single_partition(
    self,
    table_config: TableConfig,
    partition_col: str,
    partition_date: date,
    partition_dir: Path,
    date_columns: List[str],
    pyarrow_schema,
    granularity: str,
    column_type: str,
) -> int:
    """
    Fetch one partition period from BigQuery and merge it into its file.

    The period is the half-open range [partition_date, next period start).
    All batches of the period are collected in memory, schema-enforced,
    merged with the existing partition file (deduplicated on the primary
    key) when there is one, and written back.

    Args:
        table_config: Table configuration
        partition_col: Column the period filter is applied to
        partition_date: First day of the period
        partition_dir: Partition directory (not used here; the file path
            comes from ``self.config.get_partition_path``)
        date_columns: Columns to convert to date32
        pyarrow_schema: Schema to enforce, or a falsy value to skip
        granularity: "day", "month" or "year"
        column_type: Passed through to the BigQuery client
    Returns:
        Row count of the partition file after the merge, or 0 when
        BigQuery returned no rows for the period (nothing is written).
    Raises:
        ValueError: If ``granularity`` is not one of the supported values.
    """
    # Calculate partition range [start, end)
    start = partition_date
    if granularity == "day":
        end = start + timedelta(days=1)
        partition_key = start.strftime("%Y_%m_%d")
    elif granularity == "month":
        if start.month == 12:
            end = start.replace(year=start.year + 1, month=1)
        else:
            end = start.replace(month=start.month + 1)
        partition_key = start.strftime("%Y_%m")
    elif granularity == "year":
        end = start.replace(year=start.year + 1)
        partition_key = start.strftime("%Y")
    else:
        raise ValueError(f"Unknown granularity: {granularity}")

    partition_path = self.config.get_partition_path(table_config, partition_key)

    # Collect every batch of this period before building the table.
    batches = list(
        self.bq_client.read_table_partitioned_streaming(
            table_id=table_config.id,
            partition_column=partition_col,
            start=start.isoformat(),
            end=end.isoformat(),
            columns=table_config.columns,
            column_type=column_type,
        )
    )
    if not batches:
        return 0

    new_data = pa.Table.from_batches(batches)
    if new_data.num_rows == 0:
        return 0

    # Apply schema conversions
    if date_columns:
        new_data = convert_date_columns_to_date32(new_data, date_columns)
    if pyarrow_schema:
        new_data = apply_schema_to_table(new_data, pyarrow_schema)

    # Merge with existing partition file if present
    primary_key_cols = table_config.get_primary_key_columns()
    if partition_path.exists():
        existing = pq.read_table(partition_path)
        merged = self._merge_arrow_tables(existing, new_data, primary_key_cols)
    else:
        merged = new_data

    # Make sure the target directory exists before writing, as the other
    # write paths of this class do; a no-op when it is already there.
    partition_path.parent.mkdir(parents=True, exist_ok=True)
    pq.write_table(merged, partition_path, compression="snappy")

    row_count = merged.num_rows
    logger.debug(
        f" Partition {partition_key}: {new_data.num_rows} new rows, "
        f"{row_count} total after merge"
    )
    return row_count
def _cleanup_old_partitions(
    self,
    table_config: TableConfig,
    partition_dir: Path,
    granularity: str,
) -> int:
    """
    Delete partition files whose period starts before the retention cutoff.

    The cutoff is today minus ``max_history_days``. Files whose name cannot
    be parsed as a partition key for ``granularity`` are kept and reported
    with a warning.

    Args:
        table_config: Table configuration (``max_history_days`` is read)
        partition_dir: Directory holding the ``*.parquet`` partition files
        granularity: "day", "month" or "year"
    Returns:
        Number of deleted files; 0 when no retention is configured or the
        directory does not exist.
    """
    if not table_config.max_history_days:
        return 0
    if not partition_dir.exists():
        return 0

    cutoff_date = date.today() - timedelta(days=table_config.max_history_days)
    deleted = 0
    for part_path in partition_dir.glob("*.parquet"):
        # _parse_partition_date reports an unparseable name by returning
        # None; the except clause is kept only as a safety net.
        try:
            partition_date = self._parse_partition_date(part_path.stem, granularity)
        except (ValueError, IndexError):
            partition_date = None

        if partition_date is None:
            logger.warning(f" Skipping unrecognized partition file: {part_path.name}")
            continue

        # NOTE(review): the comparison uses the period's start date, so a
        # month/year partition is removed as soon as its first day is past
        # the cutoff - confirm that is the intended retention rule.
        if partition_date < cutoff_date:
            part_path.unlink()
            deleted += 1
            logger.debug(f" Deleted old partition: {part_path.name}")
    return deleted
@staticmethod
def _parse_partition_date(partition_key: str, granularity: str) -> Optional[date]:
"""Parse a partition key back to a date."""
try:
if granularity == "day":
return datetime.strptime(partition_key, "%Y_%m_%d").date()
elif granularity == "month":
return datetime.strptime(partition_key, "%Y_%m").date()
elif granularity == "year":
return datetime.strptime(partition_key, "%Y").date()
except ValueError:
return None
return None
def _merge_arrow_tables(
    self,
    existing: pa.Table,
    new_data: pa.Table,
    primary_key: List[str],
) -> pa.Table:
    """
    Combine two Arrow tables, keeping one row per primary key.

    Both tables are converted to pandas, concatenated with the new rows
    last, and deduplicated keeping the last occurrence, so a new row
    replaces an existing row with the same key.

    Args:
        existing: Rows already stored
        new_data: New or changed rows
        primary_key: Column names that identify a row
    Returns:
        Merged PyArrow Table (without a pandas index column)
    """
    import pandas as pd

    # Order matters: new rows go last so keep="last" makes them win.
    frames = [existing.to_pandas(), new_data.to_pandas()]
    combined = pd.concat(frames, ignore_index=True)
    deduplicated = combined.drop_duplicates(subset=primary_key, keep="last")
    return pa.Table.from_pandas(deduplicated, preserve_index=False)
def _get_partition_totals(self, partition_dir: Path) -> Dict[str, Any]:
"""
Calculate totals from all partition files in a directory.
"""
total_rows = 0
total_size = 0
total_uncompressed = 0
total_columns = 0
if not partition_dir.exists():
return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0}
all_partitions = list(partition_dir.glob("*.parquet"))
for part_path in all_partitions:
try:
pf = pq.ParquetFile(part_path)
meta = pf.metadata
total_rows += meta.num_rows
total_size += part_path.stat().st_size
if total_columns == 0:
total_columns = len(pf.schema_arrow)
for rg_idx in range(meta.num_row_groups):
rg = meta.row_group(rg_idx)
for col_idx in range(rg.num_columns):
total_uncompressed += rg.column(col_idx).total_uncompressed_size
except Exception as e:
logger.warning(f"Skipping corrupt partition {part_path.name}: {e}")
return {
"rows": total_rows,
"file_size_bytes": total_size,
"partitions": len(all_partitions),
"columns": total_columns,
"uncompressed_bytes": total_uncompressed,
}
def create_data_source() -> BigQueryDataSource:
    """Build a ``BigQueryDataSource``; entry point for dynamic imports."""
    data_source = BigQueryDataSource()
    return data_source

View file

@ -1,644 +0,0 @@
"""
Google BigQuery API Client
Low-level wrapper for Google BigQuery with these functions:
1. Authentication using Application Default Credentials (ADC)
2. Query tables to PyArrow (no CSV intermediate step)
3. Get table metadata (schema, columns, data types)
4. Cache metadata for faster repeated use
5. Incremental reads (timestamp-based and partition-based)
Uses google-cloud-bigquery with native PyArrow support.
"""
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import pyarrow as pa
from google.cloud import bigquery
try:
from google.cloud import bigquery_storage_v1
_HAS_BQ_STORAGE = True
except ImportError:
_HAS_BQ_STORAGE = False
from src.config import get_config
logger = logging.getLogger(__name__)
# Mapping BigQuery types to PyArrow types
# Keys are BigQuery field type names (both legacy and standard spellings,
# e.g. INTEGER/INT64); values are the PyArrow types used when building a
# schema. Types missing from this table fall back to string at lookup time.
BIGQUERY_TO_PYARROW_TYPES = {
"STRING": pa.string(),
"BYTES": pa.binary(),
"INTEGER": pa.int64(),
"INT64": pa.int64(),
"FLOAT": pa.float64(),
"FLOAT64": pa.float64(),
# NOTE(review): decimal types are mapped to float64, which cannot hold
# every NUMERIC/BIGNUMERIC value exactly - confirm the loss is acceptable.
"NUMERIC": pa.float64(),
"BIGNUMERIC": pa.float64(),
"BOOLEAN": pa.bool_(),
"BOOL": pa.bool_(),
"TIMESTAMP": pa.timestamp("us", tz="UTC"),
"DATE": pa.date32(),
"TIME": pa.string(),
# DATETIME has no time zone, hence the naive timestamp type.
"DATETIME": pa.timestamp("us"),
"GEOGRAPHY": pa.string(),
"JSON": pa.string(),
# Nested and repeated types are represented as strings.
"STRUCT": pa.string(),
"RECORD": pa.string(),
"ARRAY": pa.string(),
}
class BigQueryClient:
"""
Wrapper for Google BigQuery API.
Provides high-level methods for working with BigQuery tables:
- Query tables to PyArrow Tables (no CSV step)
- Get metadata (schema, columns)
- Incremental and partitioned reads
"""
def __init__(
    self,
    project_id: Optional[str] = None,
    location: Optional[str] = None,
):
    """
    Create the BigQuery client, the optional Storage API client and the
    metadata cache.

    Args:
        project_id: GCP project used for job execution/billing.
            Falls back to the BIGQUERY_PROJECT environment variable.
        location: BigQuery job location (e.g., "us-central1").
            Falls back to the BIGQUERY_LOCATION environment variable.
    Raises:
        ValueError: If no project ID is given and BIGQUERY_PROJECT is unset.
    """
    resolved_project = project_id or os.environ.get("BIGQUERY_PROJECT")
    self.project_id = resolved_project
    if not resolved_project:
        raise ValueError(
            "BigQuery project ID not set. "
            "Set BIGQUERY_PROJECT environment variable."
        )
    self.location = location or os.environ.get("BIGQUERY_LOCATION")

    # The project is only the execution/billing project; queries address
    # tables by fully-qualified ID (project.dataset.table), so the data
    # itself may live in a different project.
    if self.location:
        self.client = bigquery.Client(
            project=self.project_id, location=self.location
        )
    else:
        self.client = bigquery.Client(project=self.project_id)

    # An explicit Storage API client enables fast parallel gRPC reads;
    # without one, to_arrow_iterable() silently uses slow REST pagination.
    self.bqstorage_client = None
    if not _HAS_BQ_STORAGE:
        logger.info("BQ Storage API not available (install google-cloud-bigquery-storage)")
    else:
        try:
            self.bqstorage_client = bigquery_storage_v1.BigQueryReadClient()
            logger.info("BQ Storage API client initialized (fast parallel gRPC reads)")
        except Exception as e:
            self.bqstorage_client = None
            logger.warning(f"BQ Storage API client failed to initialize: {e}")

    # Metadata cache: in-memory dict backed by a JSON file.
    self.metadata_cache: Dict[str, Dict[str, Any]] = {}
    self.metadata_cache_path = get_config().get_metadata_path() / "bq_table_metadata.json"
    self._load_metadata_cache()

    logger.info(
        f"BigQuery client initialized: project={self.project_id}, "
        f"location={self.location or 'auto'}"
    )
def _load_metadata_cache(self):
    """Read the metadata cache from its JSON file, if the file exists.

    A file that cannot be read or parsed is logged and replaced by an
    empty cache instead of raising.
    """
    cache_file = self.metadata_cache_path
    if not cache_file.exists():
        return
    try:
        with open(cache_file, "r") as handle:
            self.metadata_cache = json.load(handle)
        logger.info(
            f"BQ metadata cache loaded: {len(self.metadata_cache)} tables"
        )
    except Exception as e:
        logger.warning(f"Error loading BQ metadata cache: {e}")
        self.metadata_cache = {}
def _save_metadata_cache(self):
    """Write the metadata cache to its JSON file (best effort).

    The parent directory is created when missing. Failures are logged as
    warnings and never raised.
    """
    cache_file = self.metadata_cache_path
    try:
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        with open(cache_file, "w") as handle:
            json.dump(self.metadata_cache, handle, indent=2)
        logger.debug("BQ metadata cache saved")
    except Exception as e:
        logger.warning(f"Error saving BQ metadata cache: {e}")
def get_table_metadata(
    self,
    table_id: str,
    use_cache: bool = True,
    cache_ttl_hours: int = 24,
) -> Dict[str, Any]:
    """
    Return metadata for a BigQuery table, served from cache when fresh.

    A cached entry is used while it is younger than ``cache_ttl_hours``;
    otherwise the table is fetched, the entry is replaced and the cache
    file is rewritten.

    Args:
        table_id: Full table ID (e.g., "project.dataset.table")
        use_cache: Whether a cached entry may be returned
        cache_ttl_hours: Maximum age of a usable cached entry, in hours
    Returns:
        Dictionary with table identity, column names, types and
        descriptions, row count, size, timestamps and partitioning info.
    Raises:
        Exception: Whatever the fetch raised, after it has been logged.
    """
    if use_cache and table_id in self.metadata_cache:
        cached = self.metadata_cache[table_id]
        # Entries without a timestamp are treated as very old.
        cached_at = datetime.fromisoformat(cached.get("_cached_at", "2000-01-01"))
        age = datetime.now() - cached_at
        if age < timedelta(hours=cache_ttl_hours):
            logger.debug(f"Using BQ metadata cache for {table_id}")
            return cached

    logger.info(f"Fetching metadata for BQ table: {table_id}")
    try:
        table_ref = self.client.get_table(table_id)

        schema_fields = list(table_ref.schema)
        columns = [field.name for field in schema_fields]
        column_types = {field.name: field.field_type for field in schema_fields}
        # Only columns that actually carry a description are recorded.
        column_descriptions = {
            field.name: field.description
            for field in schema_fields
            if field.description
        }

        created = table_ref.created
        modified = table_ref.modified
        metadata = {
            "table_id": table_id,
            "name": table_ref.table_id,
            "dataset": table_ref.dataset_id,
            "project": table_ref.project,
            "columns": columns,
            "column_types": column_types,
            "column_descriptions": column_descriptions,
            "row_count": table_ref.num_rows,
            "size_bytes": table_ref.num_bytes,
            "created": created.isoformat() if created else None,
            "modified": modified.isoformat() if modified else None,
            "partitioning": None,
            "_cached_at": datetime.now().isoformat(),
        }

        time_partitioning = table_ref.time_partitioning
        if time_partitioning:
            metadata["partitioning"] = {
                "type": time_partitioning.type_,
                "field": time_partitioning.field,
                "expiration_ms": time_partitioning.expiration_ms,
            }

        self.metadata_cache[table_id] = metadata
        self._save_metadata_cache()
        return metadata
    except Exception as e:
        logger.error(f"Error getting metadata for {table_id}: {e}")
        raise
def get_pyarrow_schema(self, table_id: str) -> Optional[pa.Schema]:
    """
    Derive a PyArrow schema from the table's BigQuery metadata.

    Columns follow the order of the metadata's ``columns`` list. A column
    without a recorded type is treated as STRING, and a type missing from
    ``BIGQUERY_TO_PYARROW_TYPES`` maps to ``pa.string()``.

    Args:
        table_id: Full table ID
    Returns:
        PyArrow schema, or None when the metadata has no column types
    """
    metadata = self.get_table_metadata(table_id)
    column_types = metadata.get("column_types", {})
    if not column_types:
        logger.warning(f"No column types for {table_id}, schema will not be applied")
        return None

    def to_field(column_name):
        bq_type = column_types.get(column_name, "STRING")
        return pa.field(
            column_name, BIGQUERY_TO_PYARROW_TYPES.get(bq_type, pa.string())
        )

    return pa.schema([to_field(name) for name in metadata.get("columns", [])])
def get_date_columns(self, table_id: str) -> List[str]:
"""
Get list of DATE-only columns for a table.
Args:
table_id: Full table ID
Returns:
List of column names that have DATE type in BigQuery
"""
metadata = self.get_table_metadata(table_id)
column_types = metadata.get("column_types", {})
return [
col_name for col_name, bq_type in column_types.items()
if bq_type == "DATE"
]
def query_to_arrow(
    self,
    sql: str,
    params: Optional[List[bigquery.ScalarQueryParameter]] = None,
) -> pa.Table:
    """
    Run a SQL query and load the complete result as a PyArrow Table.

    The Storage API client is used when one was initialized. If the read
    fails with an error mentioning "readsessions" or "PERMISSION_DENIED",
    the result is fetched again over the REST API; any other error
    propagates.

    Args:
        sql: SQL query string (use @param_name for parameterized values)
        params: List of BigQuery query parameters
    Returns:
        PyArrow Table with query results
    """
    job_config = bigquery.QueryJobConfig()
    if params:
        job_config.query_parameters = params

    logger.debug(f"Executing BQ query: {sql[:200]}...")
    query_job = self.client.query(sql, job_config=job_config)

    fetch_kwargs = {}
    if self.bqstorage_client:
        fetch_kwargs["bqstorage_client"] = self.bqstorage_client

    try:
        arrow_table = query_job.to_arrow(**fetch_kwargs)
    except Exception as storage_err:
        message = str(storage_err)
        permission_problem = "readsessions" in message or "PERMISSION_DENIED" in message
        if not permission_problem:
            raise
        # The service account lacks bigquery.readsessions.create.
        logger.warning(
            "BQ Storage API unavailable (missing readsessions permission), "
            "falling back to REST API"
        )
        arrow_table = query_job.to_arrow(create_bqstorage_client=False)

    logger.debug(f"Query returned {arrow_table.num_rows} rows, {arrow_table.num_columns} columns")
    return arrow_table
def query_to_arrow_batches(
    self,
    sql: str,
    params: Optional[List[bigquery.ScalarQueryParameter]] = None,
):
    """
    Run a SQL query and yield the result as streaming RecordBatches.

    Unlike query_to_arrow(), the result is not loaded into memory as a
    whole; each batch can be written to disk as soon as it is yielded.

    The Storage API client is used when one was initialized. Opening the
    stream and reading its first batch act as a probe: if that fails with
    an error mentioning "readsessions" or "PERMISSION_DENIED", the query is
    executed again and streamed over the REST API. The fallback is only
    taken before anything has been yielded, so an error in the middle of a
    stream propagates to the caller instead of restarting the query and
    delivering the rows already yielded a second time.

    Args:
        sql: SQL query string (use @param_name for parameterized values)
        params: List of BigQuery query parameters
    Yields:
        pyarrow.RecordBatch objects
    """
    job_config = bigquery.QueryJobConfig()
    if params:
        job_config.query_parameters = params

    logger.debug(f"Executing BQ query (streaming): {sql[:200]}...")
    query_job = self.client.query(sql, job_config=job_config)

    # result() returns a RowIterator, which is what provides
    # to_arrow_iterable() (QueryJob itself only has to_arrow()).
    row_iter = query_job.result()

    # Without an explicit bqstorage_client, to_arrow_iterable() silently
    # falls back to slow REST pagination.
    storage_kwargs = {}
    if self.bqstorage_client:
        storage_kwargs["bqstorage_client"] = self.bqstorage_client

    batch_iter = None
    first_batch = None
    try:
        batch_iter = row_iter.to_arrow_iterable(**storage_kwargs)
        # Probe first batch to detect Storage API permission errors early
        first_batch = next(batch_iter, None)
    except Exception as storage_err:
        message = str(storage_err)
        if "readsessions" not in message and "PERMISSION_DENIED" not in message:
            raise
        logger.warning(
            "BQ Storage API unavailable (missing readsessions permission), "
            "falling back to REST API (streaming)"
        )
        batch_iter = None

    if batch_iter is not None:
        # Probe succeeded: stream outside the try so that a later failure
        # can never trigger the re-executing fallback below.
        if first_batch is not None:
            yield first_batch
        yield from batch_iter
        return

    # Fallback: REST API streaming (re-execute query for fresh RowIterator)
    row_iter = self.client.query(sql, job_config=job_config).result()
    yield from row_iter.to_arrow_iterable(create_bqstorage_client=False)
def read_table_streaming(
    self,
    table_id: str,
    columns: Optional[List[str]] = None,
    row_filter: Optional[str] = None,
):
    """
    Stream a table (or a filtered subset) as RecordBatches.

    Builds a SELECT over the back-quoted table ID and delegates to
    query_to_arrow_batches(). ``row_filter`` is inserted into the SQL text
    as-is.

    Args:
        table_id: Full table ID (e.g., "project.dataset.table")
        columns: Optional list of columns to select (all when omitted)
        row_filter: Optional SQL WHERE clause (without WHERE keyword)
    Yields:
        pyarrow.RecordBatch objects
    """
    if columns:
        select_cols = ", ".join(f"`{column}`" for column in columns)
    else:
        select_cols = "*"

    statement = f"SELECT {select_cols} FROM `{table_id}`"
    if row_filter:
        statement = f"{statement} WHERE {row_filter}"

    logger.info(
        f"Streaming BQ table: {table_id} "
        f"(filter: {row_filter or 'none'}, "
        f"storage_api={'yes' if self.bqstorage_client else 'no'})"
    )
    yield from self.query_to_arrow_batches(statement)
def read_table(
    self,
    table_id: str,
    columns: Optional[List[str]] = None,
    row_filter: Optional[str] = None,
) -> pa.Table:
    """
    Load a table, or a projected/filtered slice of it, via ``query_to_arrow``.

    Args:
        table_id: Full table ID (e.g., "project.dataset.table")
        columns: Column names to select; all columns when empty or None
        row_filter: SQL predicate appended after ``WHERE`` (pass it without
            the WHERE keyword). It is inserted into the statement verbatim.

    Returns:
        PyArrow Table with table data
    """
    projection = "*"
    if columns:
        projection = ", ".join(f"`{c}`" for c in columns)

    # Assemble the statement piecewise; the WHERE part is optional.
    pieces = [f"SELECT {projection} FROM `{table_id}`"]
    if row_filter:
        pieces.append(f" WHERE {row_filter}")
    statement = "".join(pieces)

    logger.info(f"Reading BQ table: {table_id} (filter: {row_filter or 'none'})")
    return self.query_to_arrow(statement)
def read_table_incremental(
    self,
    table_id: str,
    incremental_column: str,
    since_value: str,
    columns: Optional[List[str]] = None,
    column_type: str = "TIMESTAMP",
) -> pa.Table:
    """
    Read rows where ``incremental_column > since_value``.

    The comparison value is passed as a query parameter (``@since_value``)
    rather than being interpolated into the SQL text. Table and column
    identifiers are still interpolated (backtick-quoted).

    Args:
        table_id: Full table ID
        incremental_column: Column name for incremental filter
        since_value: Lower bound (exclusive); fetch rows after this value
        columns: Optional list of columns to select
        column_type: BQ SQL type of the incremental column
            ("DATE", "TIMESTAMP", "DATETIME"). Defaults to "TIMESTAMP",
            which preserves the previous behavior; mirrors the parameter of
            the same name on ``read_table_partitioned``.

    Returns:
        PyArrow Table with incremental data
    """
    select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*"
    sql = (
        f"SELECT {select_cols} FROM `{table_id}` "
        f"WHERE `{incremental_column}` > @since_value"
    )
    # The parameter type must match the column type, otherwise BigQuery
    # rejects the comparison; hence it is configurable instead of fixed.
    params = [
        bigquery.ScalarQueryParameter("since_value", column_type, since_value),
    ]
    logger.info(
        f"Incremental read: {table_id} WHERE {incremental_column} > {since_value}"
    )
    return self.query_to_arrow(sql, params=params)
def read_table_partitioned(
    self,
    table_id: str,
    partition_column: str,
    start: str,
    end: Optional[str] = None,
    columns: Optional[List[str]] = None,
    column_type: str = "TIMESTAMP",
) -> pa.Table:
    """
    Fetch the rows whose partition column falls in ``[start, end)``.

    Both bounds are sent as query parameters typed with ``column_type``.

    Args:
        table_id: Full table ID
        partition_column: Partition column name
        start: Start date/timestamp (inclusive)
        end: End date/timestamp (exclusive). When falsy, no upper bound
            is applied.
        columns: Optional list of columns to select
        column_type: BQ SQL type for the partition column
            ("DATE", "TIMESTAMP", "DATETIME")

    Returns:
        PyArrow Table with partition range data
    """
    projection = "*"
    if columns:
        projection = ", ".join(f"`{c}`" for c in columns)

    predicate = f"WHERE `{partition_column}` >= @start_value"
    query_params = [
        bigquery.ScalarQueryParameter("start_value", column_type, start),
    ]
    if end:
        # Upper bound is optional and exclusive.
        predicate = predicate + f" AND `{partition_column}` < @end_value"
        query_params.append(
            bigquery.ScalarQueryParameter("end_value", column_type, end),
        )
    statement = f"SELECT {projection} FROM `{table_id}` " + predicate

    logger.info(
        f"Partitioned read: {table_id} [{start} .. {end or 'now'})"
    )
    return self.query_to_arrow(statement, params=query_params)
def read_table_partitioned_streaming(
    self,
    table_id: str,
    partition_column: str,
    start: str,
    end: Optional[str] = None,
    columns: Optional[List[str]] = None,
    column_type: str = "TIMESTAMP",
):
    """
    Stream the rows whose partition column falls in ``[start, end)``.

    Streaming counterpart of ``read_table_partitioned()``: instead of
    returning one table, it re-yields the batches produced by
    ``query_to_arrow_batches`` so the caller can handle each one as it
    arrives.

    Args:
        table_id: Full table ID
        partition_column: Partition column name
        start: Start date/timestamp (inclusive)
        end: End date/timestamp (exclusive). When falsy, no upper bound
            is applied.
        columns: Optional list of columns to select
        column_type: BQ SQL type for the partition column
            ("DATE", "TIMESTAMP", "DATETIME")

    Yields:
        pyarrow.RecordBatch objects
    """
    if columns:
        quoted = [f"`{c}`" for c in columns]
        projection = ", ".join(quoted)
    else:
        projection = "*"

    statement = (
        f"SELECT {projection} FROM `{table_id}` "
        f"WHERE `{partition_column}` >= @start_value"
    )
    bound_params = [
        bigquery.ScalarQueryParameter("start_value", column_type, start),
    ]
    if end:
        statement += f" AND `{partition_column}` < @end_value"
        bound_params.append(
            bigquery.ScalarQueryParameter("end_value", column_type, end),
        )

    logger.info(
        f"Partitioned streaming read: {table_id} [{start} .. {end or 'now'})"
    )
    yield from self.query_to_arrow_batches(statement, params=bound_params)
def discover_all_tables(self, dataset_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Enumerate tables in one dataset, or in every dataset the client lists.

    Discovery is best-effort: a dataset whose tables cannot be listed, or a
    table whose details cannot be fetched, is logged as a warning and
    skipped rather than aborting the whole scan.

    Args:
        dataset_id: Optional dataset ID to limit scope

    Returns:
        List of dicts with keys id, name, bucket_id, bucket_name, columns,
        row_count, size_bytes, primary_key, last_change, last_import.
    """
    logger.info(f"Discovering BQ tables (dataset={dataset_id or 'all'})...")

    if dataset_id:
        datasets = [self.client.get_dataset(dataset_id)]
    else:
        datasets = list(self.client.list_datasets())

    discovered = []
    for dataset in datasets:
        # Use the dataset's reference when it has one, else its plain ID.
        if hasattr(dataset, "reference"):
            ds_ref = dataset.reference
        else:
            ds_ref = dataset.dataset_id
        ds_id = str(ds_ref)

        try:
            table_items = list(self.client.list_tables(ds_ref))
        except Exception as list_err:
            logger.warning(f"Could not list tables in dataset {ds_id}: {list_err}")
            continue

        for item in table_items:
            full_id = f"{item.project}.{item.dataset_id}.{item.table_id}"
            try:
                detail = self.client.get_table(full_id)
                column_names = [schema_field.name for schema_field in detail.schema]
                modified_iso = detail.modified.isoformat() if detail.modified else None
                entry = {
                    "id": full_id,
                    "name": item.table_id,
                    "bucket_id": item.dataset_id,
                    "bucket_name": item.dataset_id,
                    "columns": column_names,
                    "row_count": detail.num_rows or 0,
                    "size_bytes": detail.num_bytes or 0,
                    "primary_key": [],
                    "last_change": modified_iso,
                    "last_import": None,
                }
                discovered.append(entry)
            except Exception as detail_err:
                logger.warning(f"Could not get details for {full_id}: {detail_err}")

    logger.info(f"Discovered {len(discovered)} BQ tables")
    return discovered
def test_connection(self) -> bool:
    """
    Check connectivity by running ``SELECT 1`` and consuming its result.

    Any exception is logged and reported as ``False`` instead of being
    propagated.

    Returns:
        True if connection works, False otherwise
    """
    try:
        probe_job = self.client.query("SELECT 1")
        # Drain the result so the query is actually executed and fetched.
        list(probe_job.result())
        logger.info(f"BigQuery connection OK (project: {self.project_id})")
    except Exception as err:
        logger.error(f"BigQuery connection test failed: {err}")
        return False
    else:
        return True
def create_client() -> BigQueryClient:
    """
    Build a ``BigQueryClient`` using its default constructor arguments.

    Returns:
        BigQueryClient instance
    """
    client = BigQueryClient()
    return client

View file

@ -1,820 +0,0 @@
"""
Keboola data source adapter.
Implements the DataSource interface for Keboola Storage API.
Downloads tables via the Storage API, converts CSV exports to Parquet files
with full type metadata from Keboola column metadata.
"""
import logging
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import pyarrow as pa
from tqdm import tqdm
from src.config import get_config, TableConfig
from src.data_sync import DataSource, SyncState, _get_uncompressed_size
from src.parquet_manager import (
create_parquet_manager,
_convert_column,
convert_date_columns_to_date32,
apply_schema_to_table,
)
from .client import create_client as create_keboola_client
logger = logging.getLogger(__name__)
class KeboolaDataSource(DataSource):
    """
    DataSource implementation backed by the Keboola Storage API.

    Tables are exported from a Keboola project as CSV and converted to
    Parquet, using Keboola column metadata to derive dtypes and schema.
    """

    def __init__(self):
        """Load configuration, validate Keboola settings, build helpers.

        Raises:
            ValueError: if any of the Keboola token, stack URL or project
                ID is missing from the loaded configuration.
        """
        self.config = get_config()

        # (config attribute, environment variable name reported to the user)
        required_settings = (
            ("keboola_token", "KEBOOLA_STORAGE_TOKEN"),
            ("keboola_stack_url", "KEBOOLA_STACK_URL"),
            ("keboola_project_id", "KEBOOLA_PROJECT_ID"),
        )
        missing = [
            env_name
            for attr_name, env_name in required_settings
            if not getattr(self.config, attr_name)
        ]
        if missing:
            raise ValueError(
                f"Missing required environment variables for Keboola connector: "
                f"{', '.join(missing)}. See config/.env.template"
            )

        self.keboola_client = create_keboola_client()
        self.parquet_manager = create_parquet_manager()
def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]:
    """Return per-column source type and description for a Keboola table.

    The source type comes from the client's ``_resolve_keboola_type()``.
    The description is taken from ``KBC.description`` metadata entries,
    choosing the first provider (in priority order) that supplies one;
    descriptions from providers outside the priority list are ignored.

    Returns:
        {"columns": {"col_name": {"source_type": "...", "description": "..."}}}
        ("description" is present only when a non-empty one was found),
        or None if the table has no column metadata.
    """
    table_meta = self.keboola_client.get_table_metadata(table_id)
    per_column = table_meta.get("column_metadata", {})
    if not per_column:
        return None

    # Earlier providers win over later ones.
    provider_order = (
        "user",
        "ai-metadata-enrichment",
        "keboola.snowflake-transformation",
    )

    columns_out = {}
    for column_name, meta_entries in per_column.items():
        info = {"source_type": self.keboola_client._resolve_keboola_type(meta_entries)}

        chosen = None
        if isinstance(meta_entries, list):
            # Last entry per provider wins, as with sequential assignment.
            descriptions = {
                item.get("provider", ""): item.get("value", "")
                for item in meta_entries
                if item.get("key", "") == "KBC.description"
            }
            chosen = next(
                (descriptions[name] for name in provider_order if name in descriptions),
                None,
            )

        if chosen:
            info["description"] = chosen
        columns_out[column_name] = info

    return {"columns": columns_out}
def discover_tables(self) -> List[Dict[str, Any]]:
    """Return the table list reported by the Keboola client's discovery call."""
    discovered = self.keboola_client.discover_all_tables()
    return discovered
def get_source_name(self) -> str:
    """Return the human-readable label for this data source."""
    source_label = "Keboola Storage API"
    return source_label
def _cleanup_staging(self):
    """
    Delete every regular file directly inside the staging directory.

    Sub-directories are left alone. A file that cannot be removed is logged
    as a warning and skipped, so one failure does not stop the cleanup.
    """
    for entry in self.config.get_staging_path().glob("*"):
        if not entry.is_file():
            continue
        try:
            entry.unlink()
            logger.debug(f"Cleaned up staging file: {entry.name}")
        except Exception as err:
            logger.warning(f"Failed to clean up {entry.name}: {err}")
def sync_table(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Synchronize one table from Keboola using its configured strategy.

    Dispatches on ``table_config.sync_strategy``: "full_refresh" and
    "partitioned" have dedicated handlers; any other value is treated as
    incremental. On success the sync state is updated. Failures never
    propagate: they are logged and returned as an error result.

    Args:
        table_config: Table configuration
        sync_state: Sync state manager

    Returns:
        On success: {"success": True, "rows", "strategy", "file_size_mb"}.
        On failure: {"success": False, "error", "strategy"}.
    """
    logger.info(f"Syncing table: {table_config.name} ({table_config.sync_strategy})")

    # Drop any cached metadata so this run sees the latest types from Keboola.
    cache = self.keboola_client.metadata_cache
    if table_config.id in cache:
        del cache[table_config.id]
        logger.debug(f"Cleared metadata cache for {table_config.id}")

    try:
        strategy = table_config.sync_strategy
        if strategy == "partitioned":
            outcome = self._partitioned_sync(table_config)
        elif strategy == "full_refresh":
            outcome = self._full_refresh(table_config)
        else:
            # Everything else is handled as an incremental sync.
            outcome = self._incremental_sync(table_config, sync_state)

        sync_state.update_sync(
            table_id=table_config.id,
            table_name=table_config.name,
            strategy=table_config.sync_strategy,
            rows=outcome["rows"],
            file_size_bytes=outcome["file_size_bytes"],
            columns=outcome.get("columns", 0),
            uncompressed_bytes=outcome.get("uncompressed_bytes", 0),
        )
        size_mb = outcome["file_size_bytes"] / 1024 / 1024
        return {
            "success": True,
            "rows": outcome["rows"],
            "strategy": table_config.sync_strategy,
            "file_size_mb": size_mb,
        }
    except Exception as err:
        logger.error(f"Error syncing table {table_config.name}: {err}")
        return {
            "success": False,
            "error": str(err),
            "strategy": table_config.sync_strategy,
        }
def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]:
    """
    Full refresh sync strategy.

    Exports the whole table (subject to ``where_filters``) to a temporary
    CSV in the staging directory, converts it to the table's Parquet path,
    and always removes the temporary CSV afterwards.

    Returns:
        Dict with rows, file_size_bytes, columns and uncompressed_bytes.
    """
    logger.info(f"Full refresh: {table_config.name}")
    target_parquet = self.config.get_parquet_path(table_config)
    staging = self.config.get_staging_path()

    # Only the reserved file name is needed; delete=False keeps the file
    # on disk after the handle is closed.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, dir=staging
    ) as handle:
        csv_tmp = Path(handle.name)

    try:
        # 1. Export from Keboola to CSV
        if table_config.where_filters:
            filters_desc = f" (filters: {len(table_config.where_filters)})"
        else:
            filters_desc = ""
        logger.info(f" -> Exporting from Keboola...{filters_desc}")
        self.keboola_client.export_table(
            table_id=table_config.id,
            output_path=csv_tmp,
            where_filters=table_config.where_filters if table_config.where_filters else None,
        )

        # 2. Type information used by the conversion
        table_id = table_config.id
        dtypes = self.keboola_client.get_pandas_dtypes(table_id)
        date_columns = self.keboola_client.get_date_columns(table_id)
        pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_id)

        # 3. Convert CSV -> Parquet
        logger.info(" -> Converting to Parquet...")
        converted = self.parquet_manager.csv_to_parquet(
            csv_path=csv_tmp,
            parquet_path=target_parquet,
            dtypes=dtypes,
            table_id=table_config.id,
            date_columns=date_columns,
            pyarrow_schema=pyarrow_schema,
        )
        return {
            "rows": converted["rows"],
            "file_size_bytes": converted["parquet_size_bytes"],
            "columns": converted.get("columns", 0),
            "uncompressed_bytes": _get_uncompressed_size(target_parquet),
        }
    finally:
        if csv_tmp.exists():
            csv_tmp.unlink()
def _incremental_sync(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Incremental sync strategy: route to the right incremental handler.

    Tables with ``partition_by`` set go to the partitioned handler; all
    others are merged into a single Parquet file.
    """
    if not table_config.partition_by:
        return self._incremental_single_file_sync(table_config, sync_state)
    return self._incremental_partitioned_sync(table_config, sync_state)
def _incremental_single_file_sync(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Incremental sync into one Parquet file (no partitioning).

    Exports rows changed since (last sync - window) and merges them into
    the existing Parquet file, or creates the file if it does not exist.
    On a first sync the export is limited by ``max_history_days`` when set,
    otherwise everything is exported. Temporary files are always removed.

    Returns:
        Dict with rows, file_size_bytes, columns and uncompressed_bytes.
    """
    logger.info(f"Incremental sync (single file): {table_config.name}")
    target_parquet = self.config.get_parquet_path(table_config)
    staging = self.config.get_staging_path()

    # Work out the changedSince lower bound.
    previous_sync = sync_state.get_last_sync(table_config.id)
    if previous_sync:
        window_days = table_config.incremental_window_days or 7
        # Re-read a window before the last sync so late changes are caught.
        since_dt = datetime.fromisoformat(previous_sync) - timedelta(days=window_days)
        changed_since = since_dt.isoformat()
        logger.info(
            f" -> ChangedSince: {changed_since} (window: {window_days} days)"
        )
    elif table_config.max_history_days:
        since_dt = datetime.now() - timedelta(days=table_config.max_history_days)
        changed_since = since_dt.isoformat()
        logger.info(
            f" -> First sync, limited to last {table_config.max_history_days} days "
            f"(changedSince: {changed_since})"
        )
    else:
        logger.info(" -> First sync, downloading all data...")
        changed_since = None

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, dir=staging
    ) as csv_handle:
        csv_tmp = Path(csv_handle.name)

    try:
        # 1. Export changed data from Keboola
        logger.info(" -> Exporting changes from Keboola...")
        export_info = self.keboola_client.export_table(
            table_id=table_config.id,
            output_path=csv_tmp,
            changed_since=changed_since,
        )

        if export_info["exported_rows"] == 0:
            logger.info(" -> No changes since last synchronization")
            if not target_parquet.exists():
                return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0}
            # Nothing new: report the current state of the existing file.
            current = self.parquet_manager.get_parquet_info(target_parquet)
            return {
                "rows": current["rows"],
                "file_size_bytes": current["file_size_bytes"],
                "columns": current.get("columns", 0),
                "uncompressed_bytes": _get_uncompressed_size(target_parquet),
            }

        # 2. Type information used by the conversion / merge
        dtypes = self.keboola_client.get_pandas_dtypes(table_config.id)
        date_columns = self.keboola_client.get_date_columns(table_config.id)
        pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id)

        # 3a. No Parquet yet: create it directly from the CSV.
        if not target_parquet.exists():
            logger.info(" -> Creating new Parquet...")
            created = self.parquet_manager.csv_to_parquet(
                csv_path=csv_tmp,
                parquet_path=target_parquet,
                dtypes=dtypes,
                table_id=table_config.id,
                date_columns=date_columns,
                pyarrow_schema=pyarrow_schema,
            )
            return {
                "rows": created["rows"],
                "file_size_bytes": created["parquet_size_bytes"],
                "columns": created.get("columns", 0),
                "uncompressed_bytes": _get_uncompressed_size(target_parquet),
            }

        # 3b. Parquet exists: merge into a temp file, then swap it in.
        logger.info(
            f" -> Merging {export_info['exported_rows']} changes into Parquet..."
        )
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".parquet", delete=False, dir=staging
        ) as parquet_handle:
            parquet_tmp = Path(parquet_handle.name)
        try:
            merged = self.parquet_manager.merge_parquet(
                existing_parquet=target_parquet,
                new_csv=csv_tmp,
                output_parquet=parquet_tmp,
                primary_key=table_config.get_primary_key_columns(),
                dtypes=dtypes,
                date_columns=date_columns,
                pyarrow_schema=pyarrow_schema,
            )
            parquet_tmp.replace(target_parquet)
            return {
                "rows": merged["total_rows"],
                "file_size_bytes": target_parquet.stat().st_size,
                "columns": merged.get("total_columns", 0),
                "uncompressed_bytes": _get_uncompressed_size(target_parquet),
            }
        finally:
            # Still present only if the merge or the swap failed.
            if parquet_tmp.exists():
                parquet_tmp.unlink()
    finally:
        if csv_tmp.exists():
            csv_tmp.unlink()
def _incremental_partitioned_sync(
    self,
    table_config: TableConfig,
    sync_state: SyncState,
) -> Dict[str, Any]:
    """
    Incremental sync with partitioned output.

    With no previous sync recorded, delegates to ``_chunked_initial_load``.
    Otherwise exports the rows changed since (last sync - window), appends
    them to the matching partition files, deduplicates the touched
    partitions, and returns totals over the partition directory. The
    temporary CSV is always removed.

    Returns:
        Totals dict as produced by ``_get_partition_totals``.
    """
    logger.info(
        f"Incremental sync (partitioned): {table_config.name} "
        f"(by {table_config.partition_by}, {table_config.partition_granularity})"
    )
    partition_dir = self.config.get_parquet_path(table_config)
    staging_dir = self.config.get_staging_path()

    last_sync = sync_state.get_last_sync(table_config.id)
    # For initial load (no last_sync), always use chunked approach
    if not last_sync:
        return self._chunked_initial_load(table_config, partition_dir, staging_dir)

    # Regular incremental sync: look back a window before the last sync.
    last_sync_dt = datetime.fromisoformat(last_sync)
    window_days = table_config.incremental_window_days or 7
    changed_since_dt = last_sync_dt - timedelta(days=window_days)
    changed_since = changed_since_dt.isoformat()
    logger.info(
        f" -> ChangedSince: {changed_since} (window: {window_days} days)"
    )

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, dir=staging_dir
    ) as tmp_file:
        tmp_csv_path = Path(tmp_file.name)

    try:
        logger.info(" -> Exporting changes from Keboola...")
        export_info = self.keboola_client.export_table(
            table_id=table_config.id,
            output_path=tmp_csv_path,
            changed_since=changed_since,
        )
        if export_info["exported_rows"] == 0:
            logger.info(" -> No changes since last synchronization")
            return self._get_partition_totals(partition_dir)

        dtypes = self.keboola_client.get_pandas_dtypes(table_config.id)
        date_columns = self.keboola_client.get_date_columns(table_config.id)
        pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id)

        logger.info(f" -> Processing {export_info['exported_rows']} changed rows...")
        partitions_updated = self._process_csv_to_partitions(
            tmp_csv_path, table_config, partition_dir,
            dtypes=dtypes, date_columns=date_columns, pyarrow_schema=pyarrow_schema,
        )
        # Appending can introduce repeated primary keys; collapse them.
        self._deduplicate_partitions(
            table_config, partitions_updated,
            date_columns=date_columns, pyarrow_schema=pyarrow_schema,
        )
        logger.info(f" -> Incremental sync complete, {len(partitions_updated)} partitions updated")
        return self._get_partition_totals(partition_dir)
    finally:
        if tmp_csv_path.exists():
            tmp_csv_path.unlink()
def _chunked_initial_load(
    self,
    table_config: TableConfig,
    partition_dir: Path,
    staging_dir: Path,
) -> Dict[str, Any]:
    """
    Initial load performed as a series of time-window exports.

    Walks backwards from "now" in windows of ``initial_load_chunk_days``
    days, each extended by a 1-day overlap into the next (older) window.
    With ``max_history_days`` set, the number of windows is fixed up front;
    otherwise the walk stops after two consecutive empty windows. A hard
    cap of 120 windows applies in both cases. Partitions touched by any
    window are deduplicated once at the end.

    Returns:
        Totals dict as produced by ``_get_partition_totals``.
    """
    window_days = table_config.initial_load_chunk_days
    history_days = table_config.max_history_days
    overlap_days = 1
    max_chunks_safety = 120
    reference_time = datetime.now()

    if history_days:
        # Ceiling division: number of windows needed to cover the history.
        planned_chunks = (history_days + window_days - 1) // window_days
        logger.info(
            f" -> CHUNKED INITIAL LOAD: {history_days} days in {planned_chunks} chunks "
            f"of {window_days} days each (with {overlap_days}-day overlap)"
        )
    else:
        planned_chunks = None
        logger.info(
            f" -> CHUNKED INITIAL LOAD: iterating backwards in {window_days}-day chunks "
            f"until no more data (with {overlap_days}-day overlap)"
        )

    self._cleanup_staging()

    dtypes = self.keboola_client.get_pandas_dtypes(table_config.id)
    date_columns = self.keboola_client.get_date_columns(table_config.id)
    pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id)

    touched_partitions = set()
    index = 0
    empty_streak = 0

    while True:
        if index >= max_chunks_safety:
            logger.warning(f" -> Reached safety limit of {max_chunks_safety} chunks, stopping")
            break
        if planned_chunks is not None and index >= planned_chunks:
            break

        # Offsets are measured in days back from reference_time.
        end_offset = index * window_days
        start_offset = end_offset + window_days + overlap_days
        # The newest window has no upper bound.
        if index > 0:
            window_end = reference_time - timedelta(days=end_offset)
        else:
            window_end = None
        window_start = reference_time - timedelta(days=start_offset)
        if history_days and start_offset > history_days:
            # Never reach further back than the configured history.
            window_start = reference_time - timedelta(days=history_days)

        if planned_chunks is None:
            chunk_label = f"{index + 1}"
        else:
            chunk_label = f"{index + 1}/{planned_chunks}"
        logger.info(
            f" -> Chunk {chunk_label}: "
            f"{window_start.strftime('%Y-%m-%d')} to "
            f"{window_end.strftime('%Y-%m-%d') if window_end else 'now'}"
        )

        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False, dir=staging_dir
        ) as handle:
            csv_tmp = Path(handle.name)

        try:
            export_info = self.keboola_client.export_table(
                table_id=table_config.id,
                output_path=csv_tmp,
                changed_since=window_start.isoformat(),
                changed_until=window_end.isoformat() if window_end else None,
            )

            if export_info["exported_rows"] != 0:
                empty_streak = 0
                logger.info(f" -> Exported {export_info['exported_rows']} rows")
                chunk_partitions = self._process_csv_to_partitions(
                    csv_tmp, table_config, partition_dir,
                    dtypes=dtypes, date_columns=date_columns, pyarrow_schema=pyarrow_schema,
                )
                touched_partitions.update(chunk_partitions)
                logger.info(f" -> Processed into {len(chunk_partitions)} partitions")
            else:
                logger.info(" -> No data in this chunk")
                empty_streak += 1
                if planned_chunks is None and empty_streak >= 2:
                    logger.info(
                        f" -> Found {empty_streak} consecutive empty chunks, "
                        f"assuming end of history"
                    )
                    break
        finally:
            # Runs on normal completion, on break, and on error.
            if csv_tmp.exists():
                csv_tmp.unlink()

        index += 1

    if touched_partitions:
        logger.info(
            f" -> Final deduplication of {len(touched_partitions)} partitions "
            f"(removing duplicates from {overlap_days}-day overlaps)..."
        )
        self._deduplicate_partitions(
            table_config, touched_partitions,
            date_columns=date_columns, pyarrow_schema=pyarrow_schema,
        )

    logger.info(
        f" -> Chunked initial load complete: {len(touched_partitions)} partitions"
    )
    return self._get_partition_totals(partition_dir)
def _process_csv_to_partitions(
    self,
    csv_path: Path,
    table_config: TableConfig,
    partition_dir: Path,
    dtypes: Optional[Dict[str, str]] = None,
    date_columns: Optional[List[str]] = None,
    pyarrow_schema: Optional[pa.Schema] = None,
) -> set:
    """
    Split a CSV export into partition Parquet files, appending to existing ones.

    The CSV is read in chunks of 500k rows with every column as string.
    Non-datetime dtypes are applied per column (a failing conversion is
    logged and the column left as-is), the partition column is parsed as
    ISO-8601 UTC datetimes, and rows are grouped by a key derived from the
    configured granularity. Each group is appended to its partition file.
    No deduplication happens here.

    Args:
        csv_path: CSV file to read
        table_config: Supplies partition_by and partition_granularity
            (granularity defaults to "month" when unset)
        partition_dir: Not used in this method; partition paths come from
            ``self.config.get_partition_path``
        dtypes: Optional column -> dtype mapping
        date_columns: Optional columns to convert via
            ``convert_date_columns_to_date32``
        pyarrow_schema: Optional schema applied via ``apply_schema_to_table``

    Returns:
        Set of partition keys that were updated

    Raises:
        ValueError: if the granularity is not "month", "day" or "year", or
            if the partition column is missing from the data.
    """
    import pandas as pd
    import pyarrow.parquet as pq

    partition_col = table_config.partition_by
    granularity = table_config.partition_granularity or "month"

    # Validate the granularity before any data is read. Previously an
    # unknown value skipped key creation and failed later with an opaque
    # KeyError on "_partition_key".
    key_formats = {"month": "%Y_%m", "day": "%Y_%m_%d", "year": "%Y"}
    if granularity not in key_formats:
        raise ValueError(
            f"Unsupported partition granularity '{granularity}' "
            f"(expected one of: {', '.join(key_formats)})"
        )
    key_format = key_formats[granularity]

    partitions_updated = set()
    chunk_size = 500000  # 500k rows per pandas chunk

    reader = pd.read_csv(csv_path, chunksize=chunk_size, dtype=str)
    for chunk_num, chunk_df in enumerate(reader, start=1):
        logger.debug(f" -> Processing pandas chunk {chunk_num} ({len(chunk_df)} rows)...")
        if partition_col not in chunk_df.columns:
            raise ValueError(f"Partition column '{partition_col}' not found in data")

        # Apply dtypes using _convert_column (except datetime columns)
        if dtypes:
            for col, dtype in dtypes.items():
                if col in chunk_df.columns and "datetime" not in dtype:
                    try:
                        chunk_df[col] = _convert_column(chunk_df[col], dtype, col_name=col)
                    except Exception as e:
                        logger.warning(f"Failed to apply dtype {dtype} to column {col}: {e}")

        # Convert partition column to datetime
        if not pd.api.types.is_datetime64_any_dtype(chunk_df[partition_col]):
            chunk_df[partition_col] = pd.to_datetime(
                chunk_df[partition_col], format="ISO8601", utc=True
            )

        # NOTE(review): rows whose partition value is missing get a null key
        # and are presumably dropped by groupby — confirm this is intended.
        chunk_df["_partition_key"] = chunk_df[partition_col].dt.strftime(key_format)

        # Group by partition and append to partition files
        for partition_key, partition_df in chunk_df.groupby("_partition_key"):
            partition_df = partition_df.drop(columns=["_partition_key"])
            partition_path = self.config.get_partition_path(table_config, partition_key)
            partitions_updated.add(partition_key)

            if partition_path.exists():
                # Append: existing rows first, new rows after them.
                existing_df = pd.read_parquet(partition_path)
                partition_df = pd.concat([existing_df, partition_df], ignore_index=True)

            table = pa.Table.from_pandas(partition_df, preserve_index=False)
            if date_columns:
                table = convert_date_columns_to_date32(table, date_columns)
            if pyarrow_schema:
                table = apply_schema_to_table(table, pyarrow_schema)
            pq.write_table(table, partition_path, compression="snappy")

    return partitions_updated
def _deduplicate_partitions(
    self,
    table_config: TableConfig,
    partitions_to_dedup: set,
    date_columns: Optional[List[str]] = None,
    pyarrow_schema: Optional[pa.Schema] = None,
):
    """
    Drop rows with repeated primary keys from the given partition files.

    Partitions are processed in sorted key order; missing files are
    skipped. For each repeated key the last row is kept. Every processed
    partition is rewritten (with date conversion and schema applied when
    provided), whether or not duplicates were found.
    """
    import pandas as pd
    import pyarrow.parquet as pq

    key_columns = table_config.get_primary_key_columns()
    logger.info(f" -> Deduplicating {len(partitions_to_dedup)} partitions...")

    for partition_key in sorted(partitions_to_dedup):
        path = self.config.get_partition_path(table_config, partition_key)
        if not path.exists():
            continue

        frame = pd.read_parquet(path)
        rows_before = len(frame)
        frame = frame.drop_duplicates(subset=key_columns, keep="last")
        rows_after = len(frame)

        arrow_table = pa.Table.from_pandas(frame, preserve_index=False)
        if date_columns:
            arrow_table = convert_date_columns_to_date32(arrow_table, date_columns)
        if pyarrow_schema:
            arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema)
        pq.write_table(arrow_table, path, compression="snappy")

        if rows_before == rows_after:
            continue
        logger.debug(
            f" -> Partition {partition_key}: {rows_before} -> {rows_after} rows "
            f"(removed {rows_before - rows_after} duplicates)"
        )
def _get_partition_totals(self, partition_dir: Path) -> Dict[str, Any]:
    """
    Sum row counts and sizes over the ``*.parquet`` files in a directory.

    Only Parquet metadata is read. The column count is taken from the first
    readable file that reports a non-zero count. A file that cannot be read
    is logged and skipped, but still counted in "partitions".

    Returns:
        Dict with rows, file_size_bytes, partitions, columns and
        uncompressed_bytes. When the directory does not exist, all-zero
        totals without the "partitions" key.
    """
    import pyarrow.parquet as pq

    rows = 0
    size_on_disk = 0
    uncompressed = 0
    column_count = 0

    if not partition_dir.exists():
        return {"rows": 0, "file_size_bytes": 0, "columns": 0, "uncompressed_bytes": 0}

    partition_files = list(partition_dir.glob("*.parquet"))
    for file_path in partition_files:
        try:
            parquet_file = pq.ParquetFile(file_path)
            metadata = parquet_file.metadata
            rows += metadata.num_rows
            size_on_disk += file_path.stat().st_size
            if column_count == 0:
                column_count = len(parquet_file.schema_arrow)
            # Uncompressed size is recorded per column chunk in each row group.
            for group_index in range(metadata.num_row_groups):
                row_group = metadata.row_group(group_index)
                for column_index in range(row_group.num_columns):
                    uncompressed += row_group.column(column_index).total_uncompressed_size
        except Exception as err:
            logger.warning(f" -> Skipping corrupt partition {file_path.name}: {err}")

    return {
        "rows": rows,
        "file_size_bytes": size_on_disk,
        "partitions": len(partition_files),
        "columns": column_count,
        "uncompressed_bytes": uncompressed,
    }
def _partitioned_sync(self, table_config: TableConfig) -> Dict[str, Any]:
    """
    Partitioned sync strategy.

    Exports the table (subject to ``where_filters``) to a temporary CSV,
    distributes the rows into partition files, deduplicates the partitions
    that were touched, and returns totals over the partition directory.
    The temporary CSV is always removed.

    Returns:
        Totals dict as produced by ``_get_partition_totals``.
    """
    logger.info(
        f"Partitioned sync: {table_config.name} "
        f"(by {table_config.partition_by}, {table_config.partition_granularity})"
    )
    partition_dir = self.config.get_parquet_path(table_config)
    staging = self.config.get_staging_path()

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, dir=staging
    ) as handle:
        csv_tmp = Path(handle.name)

    try:
        if table_config.where_filters:
            filters_desc = f" (filters: {len(table_config.where_filters)})"
        else:
            filters_desc = ""
        logger.info(f" -> Exporting from Keboola...{filters_desc}")
        export_info = self.keboola_client.export_table(
            table_id=table_config.id,
            output_path=csv_tmp,
            where_filters=table_config.where_filters if table_config.where_filters else None,
        )
        exported_rows = export_info["exported_rows"]
        if exported_rows == 0:
            logger.info(" -> No data exported")
            return self._get_partition_totals(partition_dir)

        dtypes = self.keboola_client.get_pandas_dtypes(table_config.id)
        date_columns = self.keboola_client.get_date_columns(table_config.id)
        pyarrow_schema = self.keboola_client.get_pyarrow_schema(table_config.id)

        logger.info(f" -> Processing CSV in chunks ({export_info['exported_rows']} rows)...")
        touched = self._process_csv_to_partitions(
            csv_tmp, table_config, partition_dir,
            dtypes=dtypes, date_columns=date_columns, pyarrow_schema=pyarrow_schema,
        )
        self._deduplicate_partitions(
            table_config, touched,
            date_columns=date_columns, pyarrow_schema=pyarrow_schema,
        )

        totals = self._get_partition_totals(partition_dir)
        logger.info(
            f" -> Partitioned sync complete: {totals.get('partitions', 0)} partitions on disk, "
            f"{totals['rows']} total rows"
        )
        return totals
    finally:
        if csv_tmp.exists():
            csv_tmp.unlink()

View file

@ -23,7 +23,16 @@ import requests
from kbcstorage.client import Client
from kbcstorage.tables import Tables
from src.config import get_config, TableConfig, WhereFilter
from dataclasses import dataclass, field
from typing import Optional as _Opt
@dataclass
class WhereFilter:
    """One where-filter condition passed to Keboola export operations."""

    # Name of the column the condition applies to.
    column: str
    # Comparison operator name; "eq" unless specified.
    operator: str = "eq"
    # Values to compare against; each instance gets its own fresh list.
    values: list = field(default_factory=list)
logger = logging.getLogger(__name__)
@ -86,10 +95,8 @@ class KeboolaClient:
token: Storage API token. If None, loads from configuration.
url: Stack URL. If None, loads from configuration.
"""
config = get_config()
self.token = token or config.keboola_token
self.url = url or config.keboola_stack_url
self.token = token or os.environ.get("KEBOOLA_STORAGE_TOKEN", "")
self.url = url or os.environ.get("KEBOOLA_STACK_URL", "")
if not self.token:
raise ValueError(
@ -824,7 +831,7 @@ def create_client() -> KeboolaClient:
"""
Factory function to create Keboola client.
Uses configuration from get_config().
Uses environment variables for token and URL.
Returns:
KeboolaClient instance
@ -848,21 +855,21 @@ if __name__ == "__main__":
print(" ❌ Connection failed!")
exit(1)
# Test metadata
# Test metadata with discovered tables
print("\n2⃣ Testing metadata...")
config = get_config()
if config.tables:
test_table = config.tables[0]
print(f" Testing table: {test_table.id}")
tables = client.discover_all_tables()
if tables:
test_table_id = tables[0].get("id", tables[0].get("name", ""))
print(f" Testing table: {test_table_id}")
metadata = client.get_table_metadata(test_table.id)
metadata = client.get_table_metadata(test_table_id)
print(f" ✅ Metadata loaded:")
print(f" Columns: {len(metadata.get('columns', []))}")
print(f" Rows: {metadata.get('row_count', 0):,}")
print(f" Size: {metadata.get('data_size_bytes', 0) / 1024 / 1024:.2f} MB")
# Test dtypes
dtypes = client.get_pandas_dtypes(test_table.id)
dtypes = client.get_pandas_dtypes(test_table_id)
print(f" Pandas dtypes:")
for col, dtype in list(dtypes.items())[:5]:
print(f" {col}: {dtype}")

View file

@ -1,194 +0,0 @@
"""Tests for Keboola adapter and DataSource ABC / factory in src.data_sync.
Covers:
- DataSource ABC default method behaviour
- create_data_source factory: keboola import error, unknown source, dynamic lookup
- KeboolaDataSource env var validation
"""
from unittest.mock import patch, MagicMock
import pytest
from src.data_sync import DataSource, SyncState, create_data_source
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _MinimalSource(DataSource):
    """Smallest concrete DataSource: supplies only the abstract ``sync_table``."""

    def sync_table(self, table_config, sync_state):
        # Both arguments are ignored; every call reports an empty, successful sync.
        return dict(success=True, rows=0)
# ---------------------------------------------------------------------------
# 1. DataSource ABC default methods
# ---------------------------------------------------------------------------
class TestDataSourceABCDefaultMethods:
    """The optional DataSource hooks must fall back to sensible defaults."""

    def test_get_column_metadata_returns_none(self):
        metadata = _MinimalSource().get_column_metadata("any.table.id")
        assert metadata is None

    def test_get_source_name_returns_unknown(self):
        source_name = _MinimalSource().get_source_name()
        assert source_name == "Unknown"
# ---------------------------------------------------------------------------
# 2. Factory: keboola without kbcstorage
# ---------------------------------------------------------------------------
class TestFactoryKeboolaWithoutKbcstorage:
    """create_data_source('keboola') must raise ImportError when kbcstorage is missing."""

    def test_raises_import_error(self):
        # Make importing connectors.keboola.adapter fail with a
        # ModuleNotFoundError that mentions kbcstorage, simulating the
        # package not being installed.
        #
        # ``__builtins__`` is a CPython implementation detail (a module in
        # ``__main__``, a dict elsewhere); the ``builtins`` module is the
        # supported way to reach the real ``__import__``.
        import builtins

        real_import = builtins.__import__

        def _fake_import(name, *args, **kwargs):
            if name == "connectors.keboola.adapter":
                raise ModuleNotFoundError("No module named 'kbcstorage'")
            # Everything else must keep importing normally.
            return real_import(name, *args, **kwargs)

        with patch("builtins.__import__", side_effect=_fake_import):
            with pytest.raises(ImportError, match="kbcstorage"):
                create_data_source("keboola")
# ---------------------------------------------------------------------------
# 3. Factory: unknown source type
# ---------------------------------------------------------------------------
class TestFactoryUnknownSource:
    """An unregistered source type must be rejected with ValueError."""

    def test_raises_value_error(self):
        with pytest.raises(ValueError) as caught:
            create_data_source("nonexistent")
        # ExceptionInfo.match performs the same regex search as ``match=``.
        caught.match("Unknown data source.*nonexistent")

    def test_error_message_contains_guidance(self):
        with pytest.raises(ValueError) as caught:
            create_data_source("nonexistent")
        caught.match("connectors/nonexistent/adapter.py")
# ---------------------------------------------------------------------------
# 4. Factory: dynamic connector lookup
# ---------------------------------------------------------------------------
class TestFactoryDynamicConnectorLookup:
    """Source types that are not hard-coded go through a dynamic import."""

    def test_jira_lookup_falls_through_to_value_error(self):
        """'jira' ships no adapter exporting a DataSource, so after the
        failed import of connectors.jira.adapter the factory must end in
        ValueError."""
        with pytest.raises(ValueError) as caught:
            create_data_source("jira")
        caught.match("Unknown data source.*jira")

    def test_dynamic_import_is_attempted(self):
        """importlib.import_module must be called with the derived module
        path for a source type that is not hard-coded."""
        import_patch = patch(
            "src.data_sync.importlib.import_module",
            side_effect=ModuleNotFoundError,
        )
        with import_patch as mock_imp:
            with pytest.raises(ValueError):
                create_data_source("custom_source")
            mock_imp.assert_called_once_with("connectors.custom_source.adapter")

    def test_dynamic_import_with_factory_function(self):
        """A loaded module exposing create_data_source() must have that
        function called, and its return value handed back."""
        expected = _MinimalSource()
        module_stub = MagicMock()
        module_stub.create_data_source = MagicMock(return_value=expected)
        with patch("src.data_sync.importlib.import_module", return_value=module_stub):
            produced = create_data_source("my_connector")
        assert produced is expected
        module_stub.create_data_source.assert_called_once()

    def test_dynamic_import_with_datasource_subclass(self):
        """Without a factory function, a DataSource subclass found on the
        loaded module must be instantiated."""
        import types

        # A real module object (not a MagicMock) so it has no
        # create_data_source attribute, only the subclass.
        module_stub = types.ModuleType("connectors.my_connector.adapter")
        module_stub.MyDataSource = _MinimalSource
        with patch("src.data_sync.importlib.import_module", return_value=module_stub):
            produced = create_data_source("my_connector")
        assert isinstance(produced, _MinimalSource)
# ---------------------------------------------------------------------------
# 5. KeboolaDataSource validates env vars
# ---------------------------------------------------------------------------
class TestKeboolaAdapterValidatesEnvVars:
    """KeboolaDataSource.__init__ must raise ValueError naming every
    required Keboola env var that is missing."""

    def _make_mock_config(self, token="", stack_url="", project_id=""):
        """Build a mock config carrying the given Keboola credential values."""
        cfg = MagicMock()
        cfg.keboola_token = token
        cfg.keboola_stack_url = stack_url
        cfg.keboola_project_id = project_id
        return cfg

    def _construct_expecting_value_error(self, cfg):
        """Instantiate the adapter against *cfg* and return the captured ValueError info."""
        with patch("connectors.keboola.adapter.get_config", return_value=cfg):
            with pytest.raises(ValueError) as caught:
                from connectors.keboola.adapter import KeboolaDataSource
                KeboolaDataSource()
        return caught

    def test_all_missing(self):
        caught = self._construct_expecting_value_error(self._make_mock_config())
        caught.match("KEBOOLA_STORAGE_TOKEN")

    def test_token_missing(self):
        cfg = self._make_mock_config(
            stack_url="https://connection.keboola.com",
            project_id="12345",
        )
        self._construct_expecting_value_error(cfg).match("KEBOOLA_STORAGE_TOKEN")

    def test_stack_url_missing(self):
        cfg = self._make_mock_config(
            token="my-token",
            project_id="12345",
        )
        self._construct_expecting_value_error(cfg).match("KEBOOLA_STACK_URL")

    def test_project_id_missing(self):
        cfg = self._make_mock_config(
            token="my-token",
            stack_url="https://connection.keboola.com",
        )
        self._construct_expecting_value_error(cfg).match("KEBOOLA_PROJECT_ID")

    def test_error_lists_all_missing_vars(self):
        """With several env vars missing, the message must mention each of them."""
        caught = self._construct_expecting_value_error(self._make_mock_config())
        msg = str(caught.value)
        for var_name in (
            "KEBOOLA_STORAGE_TOKEN",
            "KEBOOLA_STACK_URL",
            "KEBOOLA_PROJECT_ID",
        ):
            assert var_name in msg

View file

@ -15,10 +15,23 @@ from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Any
from src.config import TableConfig
from dataclasses import dataclass as _dataclass
from .client import OpenMetadataClient
@_dataclass
class TableConfig:
    """Minimal table config used by the enricher.

    Attributes expected by CatalogEnricher: id, name, and optional catalog_fqn.

    Build instances from registry rows with ``TableConfig.from_row(row)``.
    Plain ``TableConfig(**row)`` only works when the row holds nothing but
    these three keys; registry rows appear to carry more (the catalog view
    reads description, bucket, sync_strategy and query_mode from them), and
    any extra key makes the generated ``__init__`` raise ``TypeError``.
    """

    # Full table identifier.
    id: str
    # Short table name.
    name: str
    # Explicit catalog FQN override; empty string means "not set".
    catalog_fqn: str = ""

    @classmethod
    def from_row(cls, row: Dict[str, Any]) -> "TableConfig":
        """Create a config from a registry row, ignoring unrelated keys.

        Args:
            row: Mapping with at least ``id`` and ``name``; ``catalog_fqn``
                is optional and may be missing or ``None``.

        Returns:
            TableConfig holding only the fields the enricher uses.

        Raises:
            KeyError: If ``id`` or ``name`` is absent from *row*.
        """
        return cls(
            id=row["id"],
            name=row["name"],
            # Normalise a missing/NULL FQN to the dataclass default.
            catalog_fqn=row.get("catalog_fqn") or "",
        )
logger = logging.getLogger(__name__)

View file

@ -905,35 +905,35 @@ class SampleDataGenerator:
# ── Parquet conversion ─────────────────────────────────────
def _convert_to_parquet(self, parquet_dir: Path) -> None:
    """Convert generated CSVs to Parquet using DuckDB.

    Every ``*.csv`` in ``self.output_dir`` is read with ``read_csv_auto`` and
    written to ``parquet_dir/<csv stem>.parquet`` with ZSTD compression. Row
    count, Parquet size and CSV-to-Parquet size ratio are logged per table.

    Args:
        parquet_dir: Destination directory; created if it does not exist.
    """
    import duckdb

    def _sql_literal(path: Path) -> str:
        # File names are spliced into SQL text, so embedded single quotes
        # must be doubled to keep the statement well-formed.
        return "'" + str(path).replace("'", "''") + "'"

    parquet_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f" Converting to Parquet -> {parquet_dir}/")

    conn = duckdb.connect()
    try:
        for csv_path in sorted(self.output_dir.glob("*.csv")):
            table_name = csv_path.stem
            parquet_path = parquet_dir / f"{table_name}.parquet"
            csv_sql = _sql_literal(csv_path)
            parquet_sql = _sql_literal(parquet_path)

            conn.execute(
                f"COPY (SELECT * FROM read_csv_auto({csv_sql})) "
                f"TO {parquet_sql} (FORMAT PARQUET, COMPRESSION ZSTD)"
            )

            # Report stats
            row_count = conn.execute(
                f"SELECT count(*) FROM read_parquet({parquet_sql})"
            ).fetchone()[0]
            parquet_size = parquet_path.stat().st_size
            csv_size = csv_path.stat().st_size
            # Guard against a zero-byte output file.
            ratio = csv_size / parquet_size if parquet_size > 0 else 0
            logger.info(
                f" {table_name}: {row_count:,} rows, "
                f"{parquet_size / 1024:.0f} KB "
                f"({ratio:.1f}x compression)"
            )
    finally:
        # Release the connection even when a conversion fails part-way.
        conn.close()
# ── Orchestration ──────────────────────────────────────────

View file

@ -34,7 +34,8 @@ from connectors.openmetadata.transformer import (
sanitize_filename,
table_to_yaml_dict,
)
from src.config import Config
from src.db import get_system_db
from src.repositories.table_registry import TableRegistryRepository
# ---------------------------------------------------------------------------
# Logging
@ -279,14 +280,14 @@ def _write_metrics_index(
def export_tables(
client: OpenMetadataClient,
config: Config,
tables: list[dict],
docs_dir: Path,
catalog_url: str,
) -> int:
"""
Export table metadata from OpenMetadata to YAML files.
For each table defined in data_description.md:
For each table in the registry:
1. Derives the OpenMetadata FQN
2. Fetches table metadata (columns, owners, tags, description)
3. Transforms to YAML dict
@ -294,7 +295,7 @@ def export_tables(
Args:
client: Initialized OpenMetadata API client
config: Application config with table definitions
tables: List of table dicts from TableRegistryRepository.list_all()
docs_dir: Base docs directory
catalog_url: Catalog URL for header comments
@ -307,10 +308,12 @@ def export_tables(
written_files: set[Path] = set()
count = 0
for table_config in config.tables:
for tbl in tables:
table_id = tbl.get("id", "")
table_name = tbl.get("name", "")
try:
# Derive FQN: explicit override or auto-derive
fqn = table_config.catalog_fqn or f"bigquery.{table_config.id}"
# Use explicit catalog_fqn if set, otherwise derive from table id
fqn = tbl.get("catalog_fqn") or f"bigquery.{table_id}"
logger.debug(f"Fetching table metadata: {fqn}")
raw_table = client.get_table(fqn)
@ -318,7 +321,7 @@ def export_tables(
yaml_dict = table_to_yaml_dict(raw_table)
# Write table YAML
file_path = tables_dir / f"{table_config.name}.yml"
file_path = tables_dir / f"{table_name}.yml"
header = _yaml_header(catalog_url, fqn, entity_type="table")
yaml_content = yaml.dump(
yaml_dict,
@ -330,10 +333,10 @@ def export_tables(
written_files.add(file_path)
count += 1
logger.info(f"Exported table: {table_config.name} ({len(yaml_dict.get('columns', []))} columns)")
logger.info(f"Exported table: {table_name} ({len(yaml_dict.get('columns', []))} columns)")
except Exception as e:
logger.warning(f"Failed to export table {table_config.name}: {e}")
logger.warning(f"Failed to export table {table_name}: {e}")
continue
# Cleanup stale auto-generated table files
@ -443,10 +446,12 @@ def main() -> None:
# Export tables
try:
config = Config()
tables_count = export_tables(client, config, docs_dir, catalog_url)
conn = get_system_db()
repo = TableRegistryRepository(conn)
registered_tables = repo.list_all()
tables_count = export_tables(client, registered_tables, docs_dir, catalog_url)
except Exception as e:
logger.warning(f"Table export skipped (config error): {e}")
logger.warning(f"Table export skipped (registry error): {e}")
tables_count = 0
# Write sync state

View file

@ -1,653 +0,0 @@
"""
Configuration Management
This module handles:
1. Loading environment variables from .env file
2. Parsing data_description.md (YAML blocks with table definitions)
3. Validating configuration
4. Providing structured configuration data for other modules
SINGLE SOURCE OF TRUTH is data_description.md - it defines:
- List of tables to synchronize
- Sync strategies (full_refresh vs incremental)
- Primary keys and foreign keys
- Incremental columns and windows
"""
import os
import re
import logging
from pathlib import Path
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import yaml
from dotenv import load_dotenv
# Logging setup
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
@dataclass
class ForeignKey:
    """
    A foreign-key link from a column of this table to another table.

    Attributes:
        column: Referencing column in this table (e.g., "company_id")
        references: Target as "table.column" (e.g., "company.id")
        description: Optional free-text description of the relationship
    """

    column: str
    references: str
    description: Optional[str] = None
@dataclass
class WhereFilter:
    """
    Restricts a table export to a subset of its rows.

    Used with the Keboola Storage API ``whereFilters`` parameter.

    Attributes:
        column: Column the filter is evaluated on
        operator: Comparison operator (eq, ne, gt, ge, lt, le)
        values: Values the column is compared against (empty by default)
    """

    column: str
    operator: str  # eq, ne, gt, ge, lt, le
    # default_factory gives each instance its own list.
    values: List[str] = field(default_factory=list)
@dataclass
class TableConfig:
    """
    Configuration for a single table.

    Attributes:
        id: Full table ID in Keboola (e.g., "in.c-sfdc.company")
        name: Short table name (e.g., "company")
        description: Table description
        primary_key: Primary key column name, or comma-separated names for a
            composite key (optional for remote-only tables)
        sync_strategy: "full_refresh", "incremental", "partitioned", or "none" (for remote-only tables)
        incremental_window_days: Number of days to backtrack for incremental sync
        partition_by: Column name to partition by (for incremental/partitioned with partitions)
        partition_granularity: Partition granularity: "month", "day", or "year"
        foreign_keys: List of foreign key relationships
        where_filters: List of filters to apply when exporting (for downloading subset of data)
        folder: Override folder name (instead of bucket-level folder_mapping)
        max_history_days: Max days of history for initial incremental load (None = download all)
        dataset: Dataset group name for on-demand tables (e.g., "kbc_telemetry_expert")
        initial_load_chunk_days: Chunk size in days for chunked initial load (default: 30)
        incremental_column: Column for timestamp-based incremental sync (BigQuery)
        columns: Subset of columns to sync (None = all)
        row_filter: SQL WHERE clause for filtering (e.g., "event_date >= '2024-01-01'")
        query_mode: "local" (Parquet), "remote" (BQ direct) or "hybrid" (sync subset, query BQ)
        bq_entity_type: "view" (Python BQ client) or "table" (DuckDB BQ extension)
        partition_column_type: BQ SQL type of the partition column: "DATE", "TIMESTAMP", "DATETIME"
        catalog_fqn: Explicit OpenMetadata FQN override (auto-derived if not set)
        sync_schedule: Schedule for automatic sync: "every 15m", "every 1h", "daily 05:00" (UTC)
        profile_after_sync: Run profiler after sync (default True; disable for frequently synced tables)
    """

    id: str
    name: str
    description: str
    primary_key: str = ""  # Optional for remote-only tables
    sync_strategy: str = "none"  # "full_refresh", "incremental", "partitioned", or "none" (remote-only)
    incremental_window_days: Optional[int] = None
    partition_by: Optional[str] = None
    partition_granularity: Optional[str] = None  # "month", "day", "year"
    foreign_keys: List[ForeignKey] = field(default_factory=list)
    where_filters: List[WhereFilter] = field(default_factory=list)
    folder: Optional[str] = None
    max_history_days: Optional[int] = None
    dataset: Optional[str] = None
    initial_load_chunk_days: int = 30
    incremental_column: Optional[str] = None  # Column for timestamp-based incremental sync (BigQuery)
    columns: Optional[List[str]] = None  # Subset of columns to sync (None = all)
    row_filter: Optional[str] = None  # SQL WHERE clause for filtering (e.g., "event_date >= '2024-01-01'")
    query_mode: str = "local"  # "local" (Parquet) | "remote" (BQ direct) | "hybrid" (sync subset, query BQ)
    bq_entity_type: str = "view"  # "view" (Python BQ client) | "table" (DuckDB BQ extension)
    partition_column_type: str = "TIMESTAMP"  # BQ SQL type for partition column: "DATE", "TIMESTAMP", "DATETIME"
    catalog_fqn: Optional[str] = None  # Explicit OpenMetadata FQN override (auto-derived if not set)
    sync_schedule: Optional[str] = None  # Schedule: "every 15m", "every 1h", "daily 05:00" (UTC)
    profile_after_sync: bool = True  # Run profiler after sync (disable for frequently synced tables)

    def __post_init__(self):
        """Validate configuration after initialization and fill in defaults.

        Defaults applied (each one is logged): ``incremental_window_days`` = 7
        for incremental tables, ``partition_granularity`` = "month" when a
        partitioned output is requested without a granularity.

        Raises:
            ValueError: If query_mode, bq_entity_type, sync_strategy,
                partition_granularity, partition_column_type or sync_schedule
                holds an unsupported value, or a "partitioned" table has no
                partition_by.
        """
        # Validate query_mode
        valid_query_modes = ("local", "remote", "hybrid")
        if self.query_mode not in valid_query_modes:
            raise ValueError(
                f"Invalid query_mode '{self.query_mode}' for table {self.id}. "
                f"Allowed values: {', '.join(valid_query_modes)}"
            )

        # Validate bq_entity_type
        valid_entity_types = ("view", "table")
        if self.bq_entity_type not in valid_entity_types:
            raise ValueError(
                f"Invalid bq_entity_type '{self.bq_entity_type}' for table {self.id}. "
                f"Allowed values: {', '.join(valid_entity_types)}"
            )

        # Validate sync_strategy
        if self.sync_strategy not in ["full_refresh", "incremental", "partitioned", "none"]:
            raise ValueError(
                f"Invalid sync_strategy '{self.sync_strategy}' for table {self.id}. "
                f"Allowed values: 'full_refresh', 'incremental', 'partitioned', 'none'"
            )

        # For incremental strategy:
        # - changedSince is calculated from last sync timestamp (Keboola internal)
        # - partition_by is optional - if set, output will be partitioned
        if self.sync_strategy == "incremental":
            if not self.incremental_window_days:
                # Default 7 days if not specified
                self.incremental_window_days = 7
                logger.warning(
                    f"Table {self.id}: incremental_window_days not set, "
                    f"using default 7 days"
                )
            # If partition_by is set, validate partition_granularity
            if self.partition_by:
                self._normalize_partition_granularity()

        # Validate partition_column_type
        valid_column_types = ("DATE", "TIMESTAMP", "DATETIME")
        if self.partition_column_type not in valid_column_types:
            raise ValueError(
                f"Invalid partition_column_type '{self.partition_column_type}' for table {self.id}. "
                f"Allowed values: {', '.join(valid_column_types)}"
            )

        # Validate sync_schedule format
        if self.sync_schedule:
            import re as _re

            valid_schedule = (
                _re.match(r"^every \d+[mh]$", self.sync_schedule)
                or _re.match(r"^daily \d{2}:\d{2}(,\d{2}:\d{2})*$", self.sync_schedule)
            )
            if not valid_schedule:
                raise ValueError(
                    f"Invalid sync_schedule '{self.sync_schedule}' for table {self.id}. "
                    f"Allowed formats: 'every 15m', 'every 1h', 'daily 05:00', "
                    f"'daily 07:00,13:00,18:00'"
                )

        # For partitioned, partition_by must be defined
        if self.sync_strategy == "partitioned":
            if not self.partition_by:
                raise ValueError(
                    f"Table {self.id} has sync_strategy='partitioned', "
                    f"but partition_by is missing"
                )
            self._normalize_partition_granularity()

    def _normalize_partition_granularity(self):
        """Default partition_granularity to "month" (logged) and reject unknown values."""
        if not self.partition_granularity:
            self.partition_granularity = "month"
            logger.info(
                f"Table {self.id}: partition_granularity not set, "
                f"using default 'month'"
            )
        if self.partition_granularity not in ["month", "day", "year"]:
            raise ValueError(
                f"Invalid partition_granularity '{self.partition_granularity}' for table {self.id}. "
                f"Allowed values: 'month', 'day', 'year'"
            )

    def get_primary_key_columns(self) -> List[str]:
        """
        Get primary key as list of column names.

        Supports both single and composite primary keys.
        Composite PKs are defined as comma-separated string: "col1, col2"

        Returns:
            List of column names forming the primary key; empty when no
            primary key is configured.
        """
        # Split by comma, strip whitespace, and drop empty pieces so that the
        # default primary_key "" yields [] rather than [""].
        stripped = (col.strip() for col in self.primary_key.split(","))
        return [col for col in stripped if col]

    def is_partitioned(self) -> bool:
        """Check if table output should be partitioned.

        Returns True for:
        - partitioned strategy (always partitioned)
        - incremental strategy with partition_by set
        """
        if self.sync_strategy == "partitioned":
            return True
        return self.sync_strategy == "incremental" and bool(self.partition_by)
class Config:
    """
    Main configuration class.

    Loads environment variables and parses data_description.md.
    Provides access to all configuration parameters.
    """

    def __init__(self, env_file: Optional[str] = None):
        """
        Initialize configuration.

        Args:
            env_file: Path to .env file (string or Path). If None, looks for
                .env in project root.

        Raises:
            FileNotFoundError: If docs/data_description.md cannot be located.
            ValueError: If the description files define no YAML blocks or no tables.
        """
        # Find project root (folder containing docs/data_description.md)
        self.project_root = self._find_project_root()

        # Load environment variables. The signature advertises a string, so
        # coerce to Path before calling .exists() on it.
        env_path = self.project_root / ".env" if env_file is None else Path(env_file)
        if env_path.exists():
            load_dotenv(env_path)
            logger.info(f"Loaded from .env: {env_path}")
        else:
            logger.warning(
                f".env file not found: {env_path}. "
                f"Use config/.env.template as reference."
            )

        # Read by connectors/keboola/ if enabled
        self.keboola_token = os.getenv("KEBOOLA_STORAGE_TOKEN")
        self.keboola_stack_url = os.getenv("KEBOOLA_STACK_URL")
        self.keboola_project_id = os.getenv("KEBOOLA_PROJECT_ID")

        self.data_dir = Path(os.getenv("DATA_DIR", "./data"))
        self.docs_output_dir = Path(os.getenv("DOCS_OUTPUT_DIR", "./docs"))
        self.data_source = os.getenv("DATA_SOURCE", "local")
        self.log_level = os.getenv("LOG_LEVEL", "INFO")

        # Set log level
        logging.getLogger().setLevel(self.log_level)

        # Validate required environment variables
        self._validate_env_vars()

        # Parse data_description.md
        self.tables, self.folder_mapping = self._parse_data_description()
        logger.info(f"Configuration loaded: {len(self.tables)} tables")

    def _find_project_root(self) -> Path:
        """
        Find project root (folder containing docs/data_description.md).

        Searches from current folder upwards (at most 5 parent levels).

        Returns:
            Path to project root

        Raises:
            FileNotFoundError: If docs/data_description.md is not found
        """
        current = Path.cwd()

        # Try current folder first
        if (current / "docs" / "data_description.md").exists():
            return current

        # Try parent folders (up to 5 levels)
        for _ in range(5):
            current = current.parent
            if (current / "docs" / "data_description.md").exists():
                return current

        raise FileNotFoundError(
            "docs/data_description.md not found. "
            "Make sure you're running from project root."
        )

    def _resolve_placeholder(self, value: str) -> str:
        """
        Resolve placeholders in filter values.

        Supported placeholders:
        - {{last_week}}: 7 days ago
        - {{last_month}}: 30 days ago
        - {{last_2_months}}: 60 days ago
        - {{last_3_months}}: 90 days ago
        - {{last_6_months}}: 180 days ago
        - {{last_year}}: 365 days ago
        - {{last_2_years}}: 730 days ago
        - {{today}}: Today's date

        Args:
            value: String that may contain placeholder. Non-string values
                are returned unchanged.

        Returns:
            Resolved string with actual date values (YYYY-MM-DD)
        """
        if not isinstance(value, str):
            return value

        today = datetime.now()
        days_back = {
            "{{last_week}}": 7,
            "{{last_month}}": 30,
            "{{last_2_months}}": 60,
            "{{last_3_months}}": 90,
            "{{last_6_months}}": 180,
            "{{last_year}}": 365,
            "{{last_2_years}}": 730,
            "{{today}}": 0,
        }

        result = value
        for placeholder, days in days_back.items():
            if placeholder in result:
                replacement = (today - timedelta(days=days)).strftime("%Y-%m-%d")
                result = result.replace(placeholder, replacement)
                logger.debug(f"Resolved placeholder: {placeholder} -> {replacement}")
        return result

    def _validate_env_vars(self):
        """
        Validate that required environment variables are set based on data source type.

        Currently a no-op: nothing is checked here.
        """
        # Keboola env vars are validated by connectors/keboola/adapter.py at init time.
        # No source-specific validation needed here.
        pass

    def _locate_data_description(self) -> Path:
        """Return the data_description.md path to parse (it may not exist).

        ``$CONFIG_DIR/data_description.md`` wins when CONFIG_DIR is set to a
        non-empty value and the file exists; otherwise the project-root copy
        under docs/ is used.
        """
        # An unset variable must not be turned into Path(""), which is the
        # current directory and would make ./data_description.md win.
        config_dir = os.environ.get("CONFIG_DIR", "")
        if config_dir:
            candidate = Path(config_dir) / "data_description.md"
            if candidate.exists():
                return candidate
        return self.project_root / "docs" / "data_description.md"

    def _parse_data_description(self) -> tuple[List[TableConfig], Dict[str, str]]:
        """
        Parse docs/data_description.md and extract table definitions.

        Looks for YAML blocks in the main markdown file and in every
        docs/datasets/*.md file, and merges their ``tables`` and
        ``folder_mapping`` entries.

        Returns:
            Tuple of (List of TableConfig objects, folder_mapping dict)

        Raises:
            FileNotFoundError: If docs/data_description.md doesn't exist
            ValueError: If no YAML blocks or no tables are defined
            yaml.YAMLError: If YAML is invalid
        """
        # Check CONFIG_DIR first, then project root
        data_desc_path = self._locate_data_description()
        if not data_desc_path.exists():
            raise FileNotFoundError(
                f"docs/data_description.md not found: {data_desc_path}"
            )

        # Collect all markdown files to parse: main + dataset files
        md_files = [data_desc_path]
        datasets_dir = self.project_root / "docs" / "datasets"
        if datasets_dir.exists():
            for md_file in sorted(datasets_dir.glob("*.md")):
                md_files.append(md_file)
                logger.info(f"Found dataset file: {md_file.name}")

        # Find YAML blocks (between ```yaml and ```) from all files
        yaml_pattern = r'```yaml\n(.*?)```'
        yaml_matches = []
        for md_file in md_files:
            # Explicit encoding: do not depend on the platform default.
            content = md_file.read_text(encoding="utf-8")
            yaml_matches.extend(re.findall(yaml_pattern, content, re.DOTALL))

        if not yaml_matches:
            raise ValueError(
                "data_description.md contains no YAML blocks. "
                "Make sure tables are defined in ```yaml blocks."
            )

        # Parse all YAML blocks and merge them
        all_tables = []
        folder_mapping = {}
        for yaml_block in yaml_matches:
            try:
                data = yaml.safe_load(yaml_block)
            except yaml.YAMLError as e:
                logger.error(f"Error parsing YAML: {e}")
                raise
            if data:
                if "tables" in data:
                    all_tables.extend(data["tables"])
                if "folder_mapping" in data:
                    folder_mapping.update(data["folder_mapping"])

        if not all_tables:
            raise ValueError(
                "data_description.md contains no tables. "
                "Make sure YAML block contains 'tables:' key."
            )

        # Convert to TableConfig objects
        table_configs = [self._build_table_config(row) for row in all_tables]
        return table_configs, folder_mapping

    def _build_table_config(self, table_data: Dict[str, Any]) -> TableConfig:
        """Turn one parsed ``tables:`` entry into a TableConfig."""
        # Parse foreign keys
        fk_list = [
            ForeignKey(
                column=fk_data["column"],
                references=fk_data["references"],
                description=fk_data.get("description"),
            )
            for fk_data in table_data.get("foreign_keys", [])
        ]

        # Parse where filters with placeholder resolution
        wf_list = [
            WhereFilter(
                column=wf_data["column"],
                operator=wf_data["operator"],
                values=[self._resolve_placeholder(v) for v in wf_data.get("values", [])],
            )
            for wf_data in table_data.get("where_filters", [])
        ]

        return TableConfig(
            id=table_data["id"],
            name=table_data["name"],
            description=table_data["description"],
            primary_key=table_data.get("primary_key", ""),
            sync_strategy=table_data.get("sync_strategy", "none"),
            incremental_window_days=table_data.get("incremental_window_days"),
            partition_by=table_data.get("partition_by"),
            partition_granularity=table_data.get("partition_granularity"),
            foreign_keys=fk_list,
            where_filters=wf_list,
            folder=table_data.get("folder"),
            max_history_days=table_data.get("max_history_days"),
            dataset=table_data.get("dataset"),
            initial_load_chunk_days=table_data.get("initial_load_chunk_days", 30),
            incremental_column=table_data.get("incremental_column"),
            columns=table_data.get("columns"),
            row_filter=table_data.get("row_filter"),
            query_mode=table_data.get("query_mode", "local"),
            bq_entity_type=table_data.get("bq_entity_type", "view"),
            partition_column_type=table_data.get("partition_column_type", "TIMESTAMP"),
            catalog_fqn=table_data.get("catalog_fqn"),
            sync_schedule=table_data.get("sync_schedule"),
            profile_after_sync=table_data.get("profile_after_sync", True),
        )

    def get_table_config(self, table_id: str) -> Optional[TableConfig]:
        """
        Get configuration for specific table by ID.

        Args:
            table_id: Full table ID (e.g., "in.c-sfdc.company")

        Returns:
            TableConfig or None if table not in configuration
        """
        return next((table for table in self.tables if table.id == table_id), None)

    def get_parquet_path(self, table_config: TableConfig) -> Path:
        """
        Get path to Parquet file for given table.

        Format: data/parquet/{folder_name}/{table_name}.parquet
        For partitioned tables: data/parquet/{folder_name}/{table_name}/ (directory)

        Folder name is determined by folder_mapping in data_description.md.
        Falls back to bucket name if no mapping exists. The containing
        directories are created as a side effect.

        Args:
            table_config: Table configuration

        Returns:
            Path to Parquet file (or directory for partitioned tables)
        """
        # Extract bucket name from table ID (e.g., "in.c-crm" from "in.c-crm.company")
        bucket_name = ".".join(table_config.id.split(".")[:-1])

        # Use folder mapping if available, otherwise fall back to bucket name
        folder_name = self.folder_mapping.get(bucket_name, bucket_name)

        # Table-level folder override (e.g., folder: kbc_telemetry_expert)
        if table_config.folder:
            folder_name = table_config.folder

        parquet_dir = self.data_dir / "parquet" / folder_name
        parquet_dir.mkdir(parents=True, exist_ok=True)

        if not table_config.is_partitioned():
            return parquet_dir / f"{table_config.name}.parquet"

        # For partitioned tables, return directory path
        partition_dir = parquet_dir / table_config.name
        partition_dir.mkdir(parents=True, exist_ok=True)
        return partition_dir

    def get_partition_path(self, table_config: TableConfig, partition_key: str) -> Path:
        """
        Get path to specific partition file.

        Args:
            table_config: Table configuration (must be partitioned)
            partition_key: Partition key (e.g., "2026_01" for monthly)

        Returns:
            Path to partition Parquet file

        Raises:
            ValueError: If the table is not partitioned
        """
        if not table_config.is_partitioned():
            raise ValueError(f"Table {table_config.id} is not partitioned")
        partition_dir = self.get_parquet_path(table_config)
        return partition_dir / f"{partition_key}.parquet"

    def get_metadata_path(self) -> Path:
        """
        Get path to metadata folder (created if missing).

        Returns:
            Path to metadata folder
        """
        metadata_dir = self.data_dir / "metadata"
        metadata_dir.mkdir(parents=True, exist_ok=True)
        return metadata_dir

    def get_staging_path(self) -> Path:
        """
        Get path to staging folder for temporary files (created if missing).

        Uses /tmp/data_analyst_staging for faster I/O and to avoid filling /data disk.
        Directory is created by deploy.sh on server startup.

        Returns:
            Path to staging folder
        """
        staging_dir = Path("/tmp/data_analyst_staging")
        staging_dir.mkdir(parents=True, exist_ok=True)
        return staging_dir

    def get_duckdb_path(self) -> Path:
        """
        Get path to DuckDB database (parent folder is created if missing).

        Returns:
            Path to DuckDB file
        """
        duckdb_dir = self.data_dir / "duckdb"
        duckdb_dir.mkdir(parents=True, exist_ok=True)
        return duckdb_dir / "analytics.duckdb"
# Singleton instance for easy access from entire application
_config_instance: Optional[Config] = None
def get_config() -> Config:
    """
    Return the shared Config instance, creating it on first use.

    The first call constructs ``Config()`` and caches it in the module-level
    ``_config_instance``; later calls hand back that same object.

    Returns:
        Config instance
    """
    global _config_instance
    if _config_instance is not None:
        return _config_instance
    _config_instance = Config()
    return _config_instance
# For testing - allows resetting config
def reset_config():
    """Clear the cached Config so the next get_config() builds a new one. For testing only."""
    global _config_instance
    _config_instance = None
# Manual smoke test: when the module is run as a script, load the singleton
# configuration and print a summary of it; any exception is printed together
# with its traceback instead of propagating.
if __name__ == "__main__":
# Test configuration
print("🔧 Testing configuration...")
try:
# Building the Config reads the environment and parses data_description.md.
config = get_config()
print(f"\n✅ Configuration loaded successfully!")
print(f" Project ID: {config.keboola_project_id}")
print(f" Stack URL: {config.keboola_stack_url}")
print(f" Data dir: {config.data_dir}")
print(f" Number of tables: {len(config.tables)}")
# Per-table details: identity, sync strategy and resolved Parquet location.
print(f"\n📊 Tables:")
for table in config.tables:
print(f" - {table.name} ({table.id})")
print(f" Strategy: {table.sync_strategy}")
if table.sync_strategy == "incremental":
print(f" Incremental window: {table.incremental_window_days} days")
# NOTE(review): indentation is lost in this view, so it is unclear whether
# the partition line is printed only for incremental tables — confirm.
if table.partition_by:
print(f" Partitioned by: {table.partition_by} ({table.partition_granularity})")
# get_parquet_path() also creates the target directory as a side effect.
print(f" Parquet: {config.get_parquet_path(table)}")
except Exception as e:
print(f"\n❌ Error: {e}")
# Imported lazily: only needed on the failure path.
import traceback
traceback.print_exc()

View file

@ -1,734 +0,0 @@
"""
Data Synchronization Manager
Orchestrates data synchronization from configured sources to local Parquet files.
Main functions:
1. Tracking sync state (when was last synchronization)
2. DataSource ABC for pluggable connectors
3. Sync single table or all tables at once
4. Progress tracking and error handling
5. Schema generation from synced Parquet files
Sync State:
- Stored in data/metadata/sync_state.json
- Contains timestamp of last synchronization for each table
- Used for incremental sync (changedSince parameter)
"""
import importlib
import json
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime
from tqdm import tqdm
from .config import get_config, TableConfig
from config.loader import load_instance_config
from connectors.openmetadata.enricher import CatalogEnricher
logger = logging.getLogger(__name__)
class SyncState:
    """
    Synchronization state management.

    Stores and loads information about the last synchronization of each
    table in a JSON file. The in-memory state is a dict keyed by table ID.
    """

    def __init__(self, state_file: Path):
        """
        Args:
            state_file: Path to JSON file with sync state
        """
        self.state_file = state_file
        self.state: Dict[str, Any] = self._load_state()

    def _load_state(self) -> Dict[str, Any]:
        """
        Load sync state from disk.

        A missing, unreadable or malformed file yields an empty state
        (the problem is logged, not raised).

        Returns:
            Dictionary with sync state
        """
        if not self.state_file.exists():
            return {}
        try:
            with open(self.state_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except Exception as e:
            logger.error(f"Error loading sync state: {e}")
            return {}
        # Valid JSON that is not an object (e.g. a list) would break every
        # later .get() call, so treat it like a corrupt file.
        if not isinstance(loaded, dict):
            logger.error(
                f"Error loading sync state: expected a JSON object, "
                f"got {type(loaded).__name__}"
            )
            return {}
        return loaded

    def _save_state(self):
        """
        Save sync state to disk.

        Creates the parent directory if needed. The data is written to a
        sibling temporary file and then moved over the real file, so an
        interrupted write cannot leave a truncated state file behind.
        Failures are logged, not raised.
        """
        tmp_file = self.state_file.with_name(self.state_file.name + ".tmp")
        try:
            self.state_file.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(self.state, f, indent=2, default=str)
            tmp_file.replace(self.state_file)
        except Exception as e:
            logger.error(f"Error saving sync state: {e}")
            # Best-effort cleanup of a partially written temporary file
            try:
                tmp_file.unlink()
            except OSError:
                pass

    def get_last_sync(self, table_id: str) -> Optional[str]:
        """
        Get timestamp of last synchronization for given table.

        Args:
            table_id: Table identifier

        Returns:
            ISO timestamp string, or None if not synced yet
        """
        table_state = self.state.get(table_id, {})
        return table_state.get("last_sync")

    def get_table_state(self, table_id: str) -> Dict[str, Any]:
        """
        Get complete sync state for a table.

        Args:
            table_id: Table identifier

        Returns:
            Dictionary with table sync state (empty if the table is unknown)
        """
        return self.state.get(table_id, {})

    def update_sync(
        self,
        table_id: str,
        table_name: str,
        strategy: str,
        rows: int,
        file_size_bytes: int,
        columns: int = 0,
        uncompressed_bytes: int = 0,
    ):
        """
        Update synchronization state for a table and persist it.

        The entry for ``table_id`` is replaced as a whole; ``last_sync`` is
        set to ``datetime.now()`` (naive local time) in ISO format.

        Args:
            table_id: Table identifier
            table_name: Human-readable table name
            strategy: Sync strategy used
            rows: Number of rows synced
            file_size_bytes: Size of Parquet file in bytes
            columns: Number of columns
            uncompressed_bytes: Uncompressed data size
        """
        self.state[table_id] = {
            "table_name": table_name,
            "last_sync": datetime.now().isoformat(),
            "strategy": strategy,
            "rows": rows,
            "columns": columns,
            # Sizes are stored in MiB, rounded to two decimals
            "file_size_mb": round(file_size_bytes / 1024 / 1024, 2),
            "uncompressed_mb": round(uncompressed_bytes / 1024 / 1024, 2),
        }
        self._save_state()
class DataSource(ABC):
    """
    Base interface every data-source connector has to implement.

    Only ``sync_table`` is mandatory; the remaining hooks ship with
    neutral defaults that a connector may override.
    See connectors/keboola/ for a reference implementation.
    """

    @abstractmethod
    def sync_table(
        self,
        table_config: TableConfig,
        sync_state: SyncState,
    ) -> Dict[str, Any]:
        """
        Synchronize one table.

        Args:
            table_config: Table configuration
            sync_state: Sync state manager

        Returns:
            Dictionary with sync result
        """

    def discover_tables(self) -> List[Dict[str, Any]]:
        """Enumerate the tables available in the data source.

        Each returned dict is expected to carry at least:
        id, name, bucket_id, columns, row_count, size_bytes,
        primary_key, last_change

        The base implementation returns an empty list (no discovery support).
        """
        no_tables: List[Dict[str, Any]] = []
        return no_tables

    def get_column_metadata(self, table_id: str) -> Optional[Dict[str, Any]]:
        """Provide processed column metadata for schema generation.

        Returns:
            {"columns": {"col_name": {"source_type": "...", "description": "..."}}}
            or None when the source has no metadata support (the base default).
        """
        return None

    def get_source_name(self) -> str:
        """Human-readable name of this data source, used in schema comments."""
        return "Unknown"
def _get_uncompressed_size(parquet_path: Path) -> int:
"""Read total uncompressed size from Parquet file metadata."""
try:
import pyarrow.parquet as pq
meta = pq.ParquetFile(parquet_path).metadata
total = 0
for rg_idx in range(meta.num_row_groups):
rg = meta.row_group(rg_idx)
for col_idx in range(rg.num_columns):
total += rg.column(col_idx).total_uncompressed_size
return total
except Exception:
return 0
class DataSyncManager:
"""
Main data synchronization orchestrator.
Manages sync of all tables and tracks results.
"""
def __init__(self):
    """Wire up configuration, sync state, data source and catalog enricher."""
    self.config = get_config()
    state_path = self.config.get_metadata_path() / "sync_state.json"
    self.sync_state = SyncState(state_path)
    self.data_source = create_data_source()
    # The OpenMetadata enricher is optional: if it cannot be set up from
    # the instance config, continue with one built from an empty config.
    try:
        self.catalog_enricher = CatalogEnricher(load_instance_config())
    except Exception as e:
        logger.warning(f"Failed to initialize catalog enricher: {e}")
        self.catalog_enricher = CatalogEnricher({})  # Disabled enricher
def _generate_schema_yaml(self):
    """
    Generate schema.yml files with actual table schemas from Parquet files.

    For every configured table that already has Parquet data this collects:
    - Column names and PyArrow types (from the Parquet schema)
    - Column descriptions (catalog enricher first, then source metadata)
    - Source types (from source metadata)
    - Primary key, sync strategy and partition column

    Tables without a dataset go to ``docs_output_dir/schema.yml``; tables
    with a dataset go to ``docs_output_dir/datasets/<dataset>/schema.yml``.
    Files are written as UTF-8 because the YAML is dumped with
    ``allow_unicode=True``. A failure on one table is logged and skipped.

    Returns:
        Path of the core schema.yml file
    """
    import yaml
    import pyarrow.parquet as pq

    source_name = self.data_source.get_source_name()
    logger.info("Generating schema.yml from synced tables...")
    generated_at = datetime.now().isoformat()
    all_tables = {}

    # Process each table in configuration
    for table_config in self.config.tables:
        try:
            parquet_path = self.config.get_parquet_path(table_config)

            # Skip if Parquet doesn't exist (table not synced yet)
            if table_config.partition_by:
                # Partitioned tables are a folder; the first partition
                # file supplies the schema.
                partitions = (
                    list(parquet_path.glob("*.parquet"))
                    if parquet_path.exists()
                    else []
                )
                if not partitions:
                    logger.debug(f" Skipping {table_config.name} (not synced yet)")
                    continue
                pf = pq.ParquetFile(partitions[0])
            else:
                if not parquet_path.exists():
                    logger.debug(f" Skipping {table_config.name} (not synced yet)")
                    continue
                pf = pq.ParquetFile(parquet_path)
            arrow_schema = pf.schema_arrow

            # Get column metadata from data source (if supported)
            col_metadata = self.data_source.get_column_metadata(table_config.id)
            source_columns = (
                col_metadata["columns"]
                if col_metadata and "columns" in col_metadata
                else None
            )

            # Enrich with catalog metadata (OpenMetadata)
            catalog_data = self.catalog_enricher.enrich_table(table_config)

            # Extract column information
            columns = []
            for field_item in arrow_schema:
                col_name = field_item.name
                col_name_lower = col_name.lower()
                column_info = {
                    "name": col_name,
                    "type": str(field_item.type),
                }

                # Priority for description: catalog > source metadata > (nothing).
                # Catalog columns are looked up by lower-cased name, source
                # metadata by the exact Parquet column name.
                description = None
                if catalog_data and col_name_lower in catalog_data.columns:
                    description = catalog_data.columns[col_name_lower].description
                elif source_columns is not None:
                    description = source_columns.get(col_name, {}).get("description")
                if description:
                    column_info["description"] = description

                # Add source type from connector metadata
                if source_columns is not None:
                    col_meta = source_columns.get(col_name, {})
                    if "source_type" in col_meta:
                        column_info["source_type"] = col_meta["source_type"]

                columns.append(column_info)

            primary_key = table_config.get_primary_key_columns()

            # Priority for table description: catalog > configured description
            table_description = table_config.description
            if catalog_data:
                table_description = catalog_data.description or table_description

            table_info = {
                "table_id": table_config.id,
                "description": table_description,
                "primary_key": primary_key,
                "sync_strategy": table_config.sync_strategy,
                "columns": columns,
            }
            if table_config.partition_by:
                table_info["partitioned_by"] = table_config.partition_by

            all_tables[table_config.name] = table_info
            logger.debug(f" {table_config.name}: {len(columns)} columns")
        except Exception as e:
            logger.warning(f" Error processing {table_config.name}: {e}")

    # Map each table name to the first config carrying that name, so the
    # grouping below is a dict lookup instead of a scan per table.
    config_by_name = {}
    for candidate in self.config.tables:
        config_by_name.setdefault(candidate.name, candidate)

    # Split tables into core (no dataset) and per-dataset groups
    core_tables = {}
    dataset_tables = {}  # {dataset_name: {table_name: table_info}}
    for table_name, table_info in all_tables.items():
        owner = config_by_name.get(table_name)
        if owner and owner.dataset:
            dataset_tables.setdefault(owner.dataset, {})[table_name] = table_info
        else:
            core_tables[table_name] = table_info

    def _write_schema_file(filepath, tables, note=""):
        """Write a schema YAML file with header comments."""
        filepath.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "_metadata": {
                "_schema_version": 2,
                "generated_at": generated_at,
                "note": "AUTO-GENERATED - DO NOT EDIT.",
                "source": source_name,
                "generator": "src/data_sync.py::DataSyncManager._generate_schema_yaml()",
            },
            "tables": tables,
        }
        # Explicit UTF-8: allow_unicode=True emits non-ASCII text verbatim,
        # which fails under a non-UTF-8 default encoding.
        with open(filepath, "w", encoding="utf-8") as f:
            f.write("# AUTO-GENERATED - DO NOT EDIT\n")
            f.write("# This file is automatically generated during data sync\n")
            f.write(f"# Generated: {generated_at}\n")
            if note:
                f.write(f"# {note}\n")
            f.write("#\n")
            f.write("# Contains actual table schemas from synced Parquet files:\n")
            f.write("# - Column names and PyArrow types (from Parquet)\n")
            f.write(f"# - Source types and descriptions (from {source_name})\n")
            f.write("# - Primary keys and sync strategies\n")
            f.write("#\n")
            f.write("# For architectural documentation and relationships, see data_description.md\n")
            f.write("\n")
            yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True)

    # Write core schema.yml
    schema_file = self.config.docs_output_dir / "schema.yml"
    _write_schema_file(schema_file, core_tables)
    logger.info(f"Core schema YAML: {len(core_tables)} tables -> {schema_file}")

    # Write per-dataset schema files
    for ds_name, ds_tables in dataset_tables.items():
        ds_schema_file = self.config.docs_output_dir / "datasets" / ds_name / "schema.yml"
        _write_schema_file(ds_schema_file, ds_tables, note=f"Dataset: {ds_name}")
        logger.info(f"Dataset schema YAML: {len(ds_tables)} tables -> {ds_schema_file}")

    total = len(core_tables) + sum(len(t) for t in dataset_tables.values())
    logger.info(f"Schema generation complete: {total} tables total")
    return schema_file
def sync_table(self, table_id: str) -> Dict[str, Any]:
    """
    Run the sync for one table, looked up by its ID.

    Args:
        table_id: Table ID to synchronize

    Returns:
        Dictionary with sync result, as produced by the data source

    Raises:
        ValueError: If no configuration exists for ``table_id``
    """
    resolved = self.config.get_table_config(table_id)
    if resolved:
        return self.data_source.sync_table(resolved, self.sync_state)
    raise ValueError(f"Table {table_id} not found in configuration")
def sync_all(self, tables: Optional[List[str]] = None) -> Dict[str, Dict[str, Any]]:
    """
    Synchronize all tables (or subset according to list).

    Tables whose ``query_mode`` is "remote" are skipped. An exception
    raised while syncing one table is recorded as a failed result for
    that table and the run continues with the next one (same handling as
    ``sync_scheduled``). After at least one success, schema.yml is
    regenerated and the profiler is run.

    Args:
        tables: List of table IDs to synchronize. If None, syncs all.
            IDs that are not in the configuration are skipped with a
            warning.

    Returns:
        Dictionary {table_id: result} with sync results
    """
    if tables:
        table_configs = []
        for tid in tables:
            found = self.config.get_table_config(tid)
            if found is None:
                logger.warning(f"Table {tid} not found in configuration, skipping")
                continue
            table_configs.append(found)
    else:
        table_configs = self.config.tables

    # Filter out remote-only tables (no local sync needed)
    remote_skipped = [
        tc for tc in table_configs if tc.query_mode == "remote"
    ]
    table_configs = [
        tc for tc in table_configs if tc.query_mode != "remote"
    ]
    if remote_skipped:
        logger.info(
            f"Skipping {len(remote_skipped)} remote-only tables "
            f"(query via BigQuery): "
            f"{', '.join(tc.name for tc in remote_skipped)}"
        )

    logger.info(f"Synchronizing {len(table_configs)} tables...")
    results = {}
    with tqdm(table_configs, desc="Syncing tables") as pbar:
        for table_config in pbar:
            pbar.set_description(f"Sync: {table_config.name}")
            try:
                result = self.data_source.sync_table(table_config, self.sync_state)
            except Exception as e:
                # One failing table must not abort the remaining ones
                logger.error(f" {table_config.name}: sync failed: {e}")
                result = {"success": False, "error": str(e)}
            results[table_config.id] = result
            if result["success"]:
                pbar.write(
                    f" {table_config.name}: {result['rows']:,} rows, "
                    f"{result['file_size_mb']:.2f} MB"
                )
            else:
                pbar.write(f" {table_config.name}: {result['error']}")

    success_count = sum(1 for r in results.values() if r["success"])
    logger.info(
        f"Synchronization completed: {success_count}/{len(results)} tables successful"
    )

    if success_count > 0:
        # Generate schema.yml from synced tables (failure is non-fatal)
        try:
            self._generate_schema_yaml()
        except Exception as e:
            logger.warning(f"Failed to generate schema.yml: {e}")
        # Auto-profile changed tables
        self._auto_profile(results)

    return results
def _auto_profile(
    self,
    results: Dict[str, Dict[str, Any]],
    skip_tables: Optional[List[str]] = None,
):
    """Profile the tables that were synced successfully.

    Any error raised here (including a failed profiler import) is logged
    as a warning and not propagated.

    Args:
        results: Sync results dict {table_id: result}
        skip_tables: Table IDs to skip profiling for
    """
    excluded = set(skip_tables or [])
    try:
        from src.profiler import profile_changed_tables

        # Collect names of successfully synced, configured, non-excluded tables
        changed = []
        for tid, outcome in results.items():
            if not outcome.get("success"):
                continue
            if not self.config.get_table_config(tid):
                continue
            if tid in excluded:
                continue
            changed.append(self.config.get_table_config(tid).name)

        if not changed:
            logger.info("No tables to profile (all skipped or none succeeded)")
        else:
            result = profile_changed_tables(changed)
            logger.info(
                f"Auto-profiling: {result['success']} profiled, "
                f"{result['errors']} errors, {result['skipped']} skipped"
            )
    except Exception as e:
        logger.warning(f"Auto-profiling failed (non-fatal): {e}")
def sync_scheduled(self) -> Dict[str, Dict[str, Any]]:
    """Sync the tables whose sync_schedule marks them as due.

    Each non-remote table with a sync_schedule is checked against its
    last_sync timestamp via ``is_table_due``. Due tables are synced, then
    schema.yml is regenerated, tables with profile_after_sync are
    profiled, and the webapp is restarted when profiling ran.

    Returns:
        Dictionary {table_id: result} with sync results (only for synced tables)
    """
    from src.scheduler import is_table_due

    candidates = [
        tc for tc in self.config.tables
        if tc.sync_schedule and tc.query_mode != "remote"
    ]
    if not candidates:
        logger.info("No tables with sync_schedule configured")
        return {}

    # Work out which of the scheduled tables are due right now
    due = []
    for tc in candidates:
        previous = self.sync_state.get_last_sync(tc.id)
        if not is_table_due(tc.sync_schedule, previous):
            logger.debug(f"Table {tc.name} is not due (schedule: {tc.sync_schedule})")
            continue
        due.append(tc)
        logger.info(f"Table {tc.name} is DUE (schedule: {tc.sync_schedule})")

    if not due:
        logger.info(
            f"Checked {len(candidates)} scheduled tables, none are due"
        )
        return {}

    logger.info(
        f"Syncing {len(due)}/{len(candidates)} due tables: "
        f"{', '.join(tc.name for tc in due)}"
    )

    # Sync the due tables; a crash in one table becomes a failed result
    results = {}
    for table_config in due:
        try:
            outcome = self.data_source.sync_table(table_config, self.sync_state)
            results[table_config.id] = outcome
            if outcome["success"]:
                logger.info(
                    f" {table_config.name}: {outcome['rows']:,} rows, "
                    f"{outcome['file_size_mb']:.2f} MB"
                )
            else:
                logger.error(f" {table_config.name}: {outcome['error']}")
        except Exception as e:
            logger.error(f" {table_config.name}: sync failed: {e}")
            results[table_config.id] = {"success": False, "error": str(e)}

    success_count = sum(1 for r in results.values() if r["success"])
    logger.info(f"Scheduled sync: {success_count}/{len(results)} tables successful")

    # Regenerate schema.yml (failure is non-fatal)
    if success_count > 0:
        try:
            self._generate_schema_yaml()
        except Exception as e:
            logger.warning(f"Failed to generate schema.yml: {e}")

    # Tables that opted out of profiling via profile_after_sync=False
    no_profile_ids = [tc.id for tc in due if not tc.profile_after_sync]
    if no_profile_ids:
        logger.info(
            f"Skipping profiler for: "
            f"{', '.join(self.config.get_table_config(tid).name for tid in no_profile_ids)}"
        )

    profiled_any = False
    if success_count > 0:
        no_profile_set = set(no_profile_ids)
        has_targets = any(
            r.get("success") and tid not in no_profile_set
            for tid, r in results.items()
        )
        if has_targets:
            self._auto_profile(results, skip_tables=no_profile_ids)
            profiled_any = True

    # Restart webapp if profiler ran (new profiles.json needs reload)
    if profiled_any:
        self._restart_webapp()

    return results
def _restart_webapp(self):
    """Restart webapp service to pick up new profiles.json.

    Best effort: a failing, hanging (30 s timeout) or unavailable
    restart command is logged and never raised, so it cannot fail a
    sync run that already completed.
    """
    import subprocess

    try:
        subprocess.run(
            ["sudo", "systemctl", "restart", "webapp"],
            check=True,
            capture_output=True,
            timeout=30,
        )
    except subprocess.CalledProcessError as e:
        # errors="replace": undecodable stderr bytes must not raise here
        detail = e.stderr.decode(errors="replace") if e.stderr else e
        logger.warning(f"Failed to restart webapp: {detail}")
    except subprocess.TimeoutExpired as e:
        # timeout=30 above raises this when the command does not finish
        logger.warning(f"Failed to restart webapp: timed out after {e.timeout}s")
    except FileNotFoundError:
        logger.debug("systemctl not found (not running on server)")
    else:
        logger.info("Webapp restarted successfully")
def create_sync_manager() -> DataSyncManager:
    """
    Build a ready-to-use DataSyncManager.

    Returns:
        DataSyncManager instance
    """
    manager = DataSyncManager()
    return manager
def create_data_source(source_type: str = None) -> DataSource:
    """Create a data source based on configuration.

    "local" and "keboola" map to the Keboola connector. Any other value is
    resolved dynamically from ``connectors/<source_type>/adapter.py``: a
    module-level ``create_data_source`` factory is preferred, otherwise the
    first ``DataSource`` subclass found in the module is instantiated.

    Args:
        source_type: Override source type. If None, uses the ``data_source``
            value of the application config.

    Returns:
        DataSource instance

    Raises:
        ValueError: If source type is unknown
        ImportError: If connector dependencies are missing
    """
    if source_type is None:
        source_type = get_config().data_source

    if source_type in ("local", "keboola"):
        try:
            from connectors.keboola.adapter import KeboolaDataSource
        except ModuleNotFoundError as e:
            if "kbcstorage" in str(e):
                raise ImportError(
                    "Keboola connector requires 'kbcstorage' package. "
                    "Install with: pip install kbcstorage"
                ) from e
            raise  # Re-raise real import errors
        return KeboolaDataSource()

    # Try dynamic connector import for other types
    adapter_module = f"connectors.{source_type}.adapter"
    try:
        mod = importlib.import_module(adapter_module)
    except ModuleNotFoundError as e:
        # e.name is the module that could not be found. Only when it is the
        # adapter itself (or one of its parent packages) is the connector
        # unknown; any other name is a missing dependency of an existing
        # connector and must be reported as such.
        missing = e.name
        connector_absent = (
            missing is None
            or missing == adapter_module
            or adapter_module.startswith(f"{missing}.")
        )
        if not connector_absent:
            raise ImportError(
                f"Connector '{source_type}' requires the '{missing}' module, "
                f"which is not installed."
            ) from e
    else:
        factory = getattr(mod, "create_data_source", None)
        if factory:
            return factory()
        # Fallback: look for a DataSource subclass defined or imported there
        for attr_name in dir(mod):
            attr = getattr(mod, attr_name)
            if isinstance(attr, type) and issubclass(attr, DataSource) and attr is not DataSource:
                return attr()

    raise ValueError(
        f"Unknown data source: '{source_type}'. "
        f"Available connectors: keboola, bigquery. "
        f"Create connectors/{source_type}/adapter.py to add a new one."
    )
if __name__ == "__main__":
    # Command-line entry point.
    #   --scheduled    sync only the tables that are due
    #   <table ids>    sync just the listed tables
    #   (no args)      sync everything
    # Exit code 0 when every attempted table succeeded, 1 otherwise.
    import sys

    scheduled_mode = "--scheduled" in sys.argv
    table_args = [arg for arg in sys.argv[1:] if arg != "--scheduled"]

    try:
        manager = create_sync_manager()
        if scheduled_mode:
            print("Data Sync (scheduled mode)")
            results = manager.sync_scheduled()
            if not results:
                print("No tables due for sync")
                sys.exit(0)
        else:
            print("Data Sync")
            if table_args:
                print(f"\nSynchronizing selected tables: {', '.join(table_args)}")
                results = manager.sync_all(tables=table_args)
            else:
                print("\nSynchronizing all tables...")
                results = manager.sync_all()

        total_count = len(results)
        success_count = sum(1 for r in results.values() if r["success"])
        if success_count != total_count:
            print(
                f"\n{success_count}/{total_count} tables synchronized. "
                f"Check logs for details."
            )
            sys.exit(1)
        print(f"\nAll {total_count} tables synchronized successfully!")
        sys.exit(0)
    except Exception as e:
        # sys.exit() raises SystemExit, which is not caught here
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

View file

@ -1,755 +0,0 @@
"""
Parquet File Manager
Parquet file management:
1. CSV -> Parquet conversion with data type application
2. Compression (snappy) for space saving
3. Metadata embedding (table_id, export_date)
4. Information about existing Parquet files
5. Merge/upsert operations for incremental sync
Parquet format advantages:
- Columnar storage -> faster analytical queries
- Compression -> smaller size than CSV
- Schema enforcement -> type safety
- Metadata support -> self-documenting
"""
import logging
from pathlib import Path
from typing import Dict, Optional, List, Any
from datetime import datetime
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
logger = logging.getLogger(__name__)
def convert_date_columns_to_date32(table: pa.Table, date_columns: List[str]) -> pa.Table:
    """
    Turn the listed columns of a PyArrow table into DATE32 columns.

    String columns are parsed through pandas with ``errors='coerce'``, so
    unparseable values (like '0000-00-00') become NULL while the column
    still ends up as DATE32. Other column types are cast directly; when
    that cast fails the column keeps its original type and a warning is
    logged. Columns not listed are passed through unchanged.

    Args:
        table: PyArrow table to convert
        date_columns: List of column names to convert to DATE32

    Returns:
        PyArrow table with converted date columns
    """
    if not date_columns:
        return table

    out_fields = []
    out_columns = []
    for idx, src_field in enumerate(table.schema):
        name = src_field.name

        # Not a date column: pass through untouched
        if name not in date_columns:
            out_fields.append(src_field)
            out_columns.append(table.column(idx))
            continue

        column = table.column(idx)
        date_field = pa.field(name, pa.date32())

        # Entirely NULL: nothing to parse, emit a typed null column
        if column.null_count == len(column):
            out_columns.append(pa.nulls(len(column), type=pa.date32()))
            out_fields.append(date_field)
            continue

        if pa.types.is_string(column.type) or pa.types.is_large_string(column.type):
            # Parse via pandas; invalid values turn into NaT (NULL)
            raw = column.to_pandas()
            parsed = pd.to_datetime(raw, errors='coerce', format='mixed')

            # Values that were present before parsing but are NULL afterwards
            invalid_count = parsed.isna().sum() - raw.isna().sum()
            if invalid_count > 0:
                examples = raw[raw.notna() & parsed.isna()].head(3).tolist()
                logger.warning(
                    f"Column '{name}': {invalid_count} invalid date values converted to NULL. "
                    f"Examples: {examples}"
                )

            # Drop the time component, then hand the dates to PyArrow
            out_columns.append(pa.array(parsed.dt.date, type=pa.date32()))
            out_fields.append(date_field)
            continue

        # Already a timestamp/date-like type: a plain cast is enough
        try:
            casted = column.cast(pa.date32())
        except Exception as e:
            logger.warning(
                f"Column '{name}': Failed to cast to date32, keeping original type. Error: {e}"
            )
            out_columns.append(column)
            out_fields.append(src_field)
        else:
            out_columns.append(casted)
            out_fields.append(date_field)

    # Rebuild the table, carrying over the schema-level metadata
    rebuilt_schema = pa.schema(out_fields, metadata=table.schema.metadata)
    return pa.Table.from_arrays(out_columns, schema=rebuilt_schema)
def apply_schema_to_table(table: pa.Table, target_schema: pa.Schema) -> pa.Table:
    """
    Bring a PyArrow table in line with a target schema, column by column.

    Per column:
    - not in the target schema: kept with its inferred type
    - null type: replaced by a typed all-null column
    - already the target type: kept
    - anything else: safe cast; if the cast fails the original type is
      kept and a warning is logged (no data loss)

    Args:
        table: PyArrow table to cast
        target_schema: Target PyArrow schema

    Returns:
        PyArrow table with applied schema
    """
    if not target_schema:
        return table

    wanted_types = {target_field.name: target_field.type for target_field in target_schema}

    out_columns = []
    out_fields = []
    for idx, src_field in enumerate(table.schema):
        column = table.column(idx)
        col_name = src_field.name

        # Unknown to the target schema: leave alone
        if col_name not in wanted_types:
            out_columns.append(column)
            out_fields.append(src_field)
            continue

        target_type = wanted_types[col_name]

        if pa.types.is_null(column.type):
            # Null-typed column: swap in a typed null array
            out_columns.append(pa.nulls(len(column), type=target_type))
            out_fields.append(pa.field(col_name, target_type))
            logger.debug(f"Column '{col_name}': converted null type to {target_type}")
        elif column.type == target_type:
            # Types already agree
            out_columns.append(column)
            out_fields.append(pa.field(col_name, target_type))
        else:
            # Types differ: attempt a safe cast, fall back to the original
            try:
                casted = column.cast(target_type, safe=True)
            except Exception as e:
                logger.warning(
                    f"Column '{col_name}': cannot cast from {column.type} to {target_type}, "
                    f"keeping original type. Error: {e}"
                )
                out_columns.append(column)
                out_fields.append(src_field)
            else:
                out_columns.append(casted)
                out_fields.append(pa.field(col_name, target_type))
                logger.debug(f"Column '{col_name}': cast from {column.type} to {target_type}")

    # Rebuild the table, carrying over the schema-level metadata
    rebuilt_schema = pa.schema(out_fields, metadata=table.schema.metadata)
    return pa.Table.from_arrays(out_columns, schema=rebuilt_schema)
def _convert_column(series: pd.Series, dtype: str, col_name: str = "") -> pd.Series:
"""
Convert pandas Series to dtype, handling empty strings.
Empty strings become NA/NaN for non-string types.
Logs warning if invalid (non-empty) values are found.
Args:
series: Input pandas Series
dtype: Target dtype (e.g., "Int64", "float64", "boolean")
col_name: Column name for logging
Returns:
Converted Series
"""
# Replace empty strings with NA for non-string types
if dtype != "object":
series = series.replace('', pd.NA)
# Numeric types - use errors='coerce' but log invalid values
if dtype in ("Int64", "float64", "Float64"):
# Count non-null values before conversion
non_null_before = series.notna().sum()
converted = pd.to_numeric(series, errors='coerce')
# Count how many became NA after conversion (excluding already NA)
non_null_after = converted.notna().sum()
invalid_count = non_null_before - non_null_after
if invalid_count > 0:
# Find examples of invalid values
invalid_mask = series.notna() & converted.isna()
examples = series[invalid_mask].head(3).tolist()
logger.warning(
f"Column '{col_name}': {invalid_count} invalid values converted to NULL. "
f"Examples: {examples}"
)
return converted.astype(dtype)
# Boolean type - map string representations
if dtype == "boolean":
# If pandas already parsed as bool, just convert to nullable boolean
if series.dtype == bool or series.dtype == 'object' and series.dropna().apply(lambda x: isinstance(x, bool)).all():
return series.astype(dtype)
bool_map = {
'true': True, 'false': False,
'True': True, 'False': False,
'TRUE': True, 'FALSE': False,
'1': True, '0': False,
'yes': True, 'no': False,
'Yes': True, 'No': False,
'YES': True, 'NO': False,
}
# Log unknown values (non-empty strings that aren't in bool_map)
known_values = set(bool_map.keys())
non_na_values = series.dropna()
unknown = non_na_values[~non_na_values.isin(known_values)]
if len(unknown) > 0:
examples = unknown.head(3).tolist()
logger.warning(
f"Column '{col_name}': {len(unknown)} unknown boolean values converted to NULL. "
f"Examples: {examples}"
)
return series.map(bool_map).astype(dtype)
# Default: direct conversion
return series.astype(dtype)
class ParquetManager:
"""
Parquet file manager.
Provides methods for:
- CSV -> Parquet conversion
- Getting information about Parquet files
- Merge/upsert for incremental sync
"""
def __init__(self):
    """Set up the manager with its default compression codec."""
    # snappy: fast, with a good compression ratio
    self.compression = "snappy"
def csv_to_parquet(
    self,
    csv_path: Path,
    parquet_path: Path,
    dtypes: Optional[Dict[str, str]] = None,
    table_id: Optional[str] = None,
    parse_dates: Optional[List[str]] = None,
    date_columns: Optional[List[str]] = None,
    pyarrow_schema: Optional[pa.Schema] = None
) -> Dict[str, Any]:
    """
    Convert CSV file to Parquet format.

    Args:
        csv_path: Path to source CSV file
        parquet_path: Path where to save Parquet file
        dtypes: Dictionary with data types for columns (pandas dtypes)
        table_id: Table ID for metadata
        parse_dates: List of columns with dates/timestamps to parse.
            Names missing from the CSV header are ignored. When not
            given, columns whose dtype contains "datetime" are parsed.
        date_columns: List of DATE-only columns (without time) to convert to PyArrow DATE32
        pyarrow_schema: Optional PyArrow schema to enforce (prevents null-type columns)

    Returns:
        Dictionary with conversion information:
        - rows: Number of rows
        - columns: Number of columns
        - csv_size_bytes: Source CSV file size
        - parquet_size_bytes: Parquet file size
        - compression_ratio: Compression ratio (CSV size / Parquet size)
        - parquet_path: Path of the written Parquet file (string)

    Raises:
        Exception: If conversion fails (logged, then re-raised)
    """
    import csv  # local import: only needed to read the header row

    logger.info(f"Converting CSV -> Parquet: {csv_path.name}")
    try:
        # Load CSV into pandas DataFrame
        # IMPORTANT: Use dtype=str to prevent pandas from guessing types
        # We apply our own types from Keboola metadata using _convert_column
        read_kwargs = {"dtype": str}

        # Get actual column names from the CSV header first. csv.reader
        # handles quoted names (including embedded commas); utf-8-sig
        # drops a leading BOM so the first name is not polluted.
        with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
            csv_columns = set(next(csv.reader(f), []))

        # Parse datetime columns - only those that exist in CSV
        if parse_dates:
            valid_parse_dates = [col for col in parse_dates if col in csv_columns]
            if valid_parse_dates:
                read_kwargs["parse_dates"] = valid_parse_dates
        elif dtypes:
            # Auto-detect datetime columns from dtypes (only existing columns)
            datetime_cols = [
                col for col, dtype in dtypes.items()
                if "datetime" in dtype and col in csv_columns
            ]
            if datetime_cols:
                read_kwargs["parse_dates"] = datetime_cols

        df = pd.read_csv(csv_path, **read_kwargs)
        logger.debug(f"CSV loaded: {len(df)} rows, {len(df.columns)} columns")

        # Apply dtypes using _convert_column to handle empty strings.
        # Datetime columns were already handled by read_csv above.
        if dtypes:
            for col, dtype in dtypes.items():
                if col in df.columns and "datetime" not in dtype:
                    try:
                        df[col] = _convert_column(df[col], dtype, col_name=col)
                    except Exception as e:
                        # Keep the column as loaded when its dtype cannot be applied
                        logger.warning(
                            f"Failed to apply dtype {dtype} to column {col}: {e}"
                        )

        # Custom schema metadata embedded into the Parquet file
        metadata = {
            "created_at": datetime.now().isoformat(),
        }
        if table_id:
            metadata["table_id"] = table_id

        # Convert to PyArrow Table
        table = pa.Table.from_pandas(df)

        # Convert DATE columns from timestamp/string to DATE32
        if date_columns:
            table = convert_date_columns_to_date32(table, date_columns)

        # Apply explicit schema (prevents null-type columns)
        if pyarrow_schema:
            table = apply_schema_to_table(table, pyarrow_schema)

        # Merge custom metadata into the existing schema metadata
        # (keys and values are stored as bytes)
        existing_metadata = table.schema.metadata or {}
        new_metadata = {
            **existing_metadata,
            **{k.encode(): v.encode() for k, v in metadata.items()}
        }
        table = table.replace_schema_metadata(new_metadata)

        # Ensure output folder exists
        parquet_path.parent.mkdir(parents=True, exist_ok=True)

        # Write to Parquet
        pq.write_table(
            table,
            parquet_path,
            compression=self.compression
        )

        # Get file sizes for compression ratio
        csv_size = csv_path.stat().st_size
        parquet_size = parquet_path.stat().st_size
        compression_ratio = csv_size / parquet_size if parquet_size > 0 else 0

        result = {
            "rows": len(df),
            "columns": len(df.columns),
            "csv_size_bytes": csv_size,
            "parquet_size_bytes": parquet_size,
            "compression_ratio": compression_ratio,
            "parquet_path": str(parquet_path)
        }
        logger.info(
            f"Parquet created: {len(df)} rows, "
            f"{parquet_size / 1024 / 1024:.2f} MB, "
            f"compression {compression_ratio:.2f}x"
        )
        return result
    except Exception as e:
        logger.error(f"Error converting CSV -> Parquet: {e}")
        raise
def get_parquet_info(self, parquet_path: Path) -> Optional[Dict[str, Any]]:
    """
    Describe an existing Parquet file without loading its data.

    Args:
        parquet_path: Path to Parquet file

    Returns:
        Dictionary with information:
        - rows: Number of rows
        - columns: Number of columns
        - file_size_bytes: File size
        - modified_at: Last modification timestamp (ISO string)
        - schema: PyArrow schema
        - metadata: Custom metadata (the "pandas" entry is left out)
        - parquet_path: Path of the file (string)
        Or None if the file doesn't exist or cannot be read.
    """
    if not parquet_path.exists():
        return None

    try:
        # Opening the file reads only its footer/metadata
        handle = pq.ParquetFile(parquet_path)

        size_bytes = parquet_path.stat().st_size
        mtime = datetime.fromtimestamp(parquet_path.stat().st_mtime)

        arrow_schema = handle.schema_arrow

        # Decode schema metadata, leaving out the pandas-internal entry
        custom_metadata = {}
        if arrow_schema.metadata:
            for raw_key, raw_value in arrow_schema.metadata.items():
                if raw_key.decode() in ["pandas"]:
                    continue
                custom_metadata[raw_key.decode()] = raw_value.decode()

        return {
            "rows": handle.metadata.num_rows,
            "columns": len(arrow_schema),
            "file_size_bytes": size_bytes,
            "modified_at": mtime.isoformat(),
            "schema": arrow_schema,
            "metadata": custom_metadata,
            "parquet_path": str(parquet_path)
        }
    except Exception as e:
        logger.error(f"Error reading Parquet info: {e}")
        return None
def merge_parquet(
    self,
    existing_parquet: Path,
    new_csv: Path,
    output_parquet: Path,
    primary_key: List[str],
    dtypes: Optional[Dict[str, str]] = None,
    parse_dates: Optional[List[str]] = None,
    date_columns: Optional[List[str]] = None,
    pyarrow_schema: Optional[pa.Schema] = None
) -> Dict[str, Any]:
    """
    Upsert the rows of a CSV file into an existing Parquet file.

    The existing rows and the CSV rows are concatenated (existing first) and
    de-duplicated on ``primary_key`` keeping the last occurrence, so a CSV row
    replaces an existing row with the same key and rows with new keys are
    appended.

    Args:
        existing_parquet: Path to existing Parquet file
        new_csv: Path to CSV with new data
        output_parquet: Path where to save resulting Parquet
        primary_key: Column names forming the primary key (composite allowed)
        dtypes: Column -> dtype name; non-datetime entries are applied through
            ``_convert_column`` after the CSV has been read as strings
        parse_dates: Columns to parse as datetimes while reading the CSV; when
            empty, columns whose dtype name contains "datetime" are used
        date_columns: DATE-only columns passed to
            ``convert_date_columns_to_date32``
        pyarrow_schema: Optional schema passed to ``apply_schema_to_table``

    Returns:
        Dictionary with:
            - total_rows: Row count after the merge
            - total_columns: Column count after the merge
            - added_rows: total_rows minus the previous row count
            - updated_rows: CSV rows not counted as added (0 if none)
            - unchanged_rows: Previous row count minus updated_rows
            - parquet_path: Output path as a string

    Raises:
        ValueError: If a primary key column is missing from either input
        Exception: Any other failure is logged and re-raised
    """
    pk_str = ", ".join(primary_key)
    logger.info(
        f"Merging Parquet: {existing_parquet.name} + {new_csv.name} -> {output_parquet.name} (PK: {pk_str})"
    )

    try:
        # --- existing side ------------------------------------------------
        current = pd.read_parquet(existing_parquet)
        rows_before = len(current)
        logger.debug(f"Existing data: {rows_before} rows")

        # --- incoming side --------------------------------------------------
        # Everything is read as str so pandas never guesses a type; the
        # configured dtypes are applied explicitly afterwards.
        csv_options = {"dtype": str}
        datetime_columns = parse_dates
        if not datetime_columns and dtypes:
            datetime_columns = [
                name for name, type_name in dtypes.items()
                if "datetime" in type_name
            ]
        if datetime_columns:
            csv_options["parse_dates"] = datetime_columns

        incoming = pd.read_csv(new_csv, **csv_options)

        for name, type_name in (dtypes or {}).items():
            if name not in incoming.columns or "datetime" in type_name:
                continue
            try:
                incoming[name] = _convert_column(incoming[name], type_name, col_name=name)
            except Exception as e:
                logger.warning(
                    f"Failed to apply dtype {type_name} to column {name}: {e}"
                )

        rows_incoming = len(incoming)
        logger.debug(f"New data: {rows_incoming} rows")

        # --- key sanity check -----------------------------------------------
        for key_column in primary_key:
            if key_column not in current.columns:
                raise ValueError(
                    f"Primary key column '{key_column}' not found in existing data"
                )
            if key_column not in incoming.columns:
                raise ValueError(
                    f"Primary key column '{key_column}' not found in new data"
                )

        # --- upsert -----------------------------------------------------------
        # Incoming rows come second, so keep='last' lets them win on key clashes.
        combined = pd.concat([current, incoming], ignore_index=True)
        combined = combined.drop_duplicates(subset=primary_key, keep='last')

        rows_after = len(combined)
        added_rows = rows_after - rows_before
        if added_rows < rows_incoming:
            # Incoming rows that did not grow the table replaced existing ones
            updated_rows = rows_incoming - added_rows
        else:
            updated_rows = 0
        unchanged_rows = rows_before - updated_rows

        logger.info(
            f"Merge completed: {rows_after} total rows "
            f"(+{added_rows} new, ~{updated_rows} updates)"
        )

        # --- write ------------------------------------------------------------
        file_metadata = {
            "created_at": datetime.now().isoformat(),
            "merged_from": new_csv.name
        }

        table = pa.Table.from_pandas(combined)
        if date_columns:
            table = convert_date_columns_to_date32(table, date_columns)
        if pyarrow_schema:
            table = apply_schema_to_table(table, pyarrow_schema)

        encoded_metadata = {
            key.encode(): value.encode() for key, value in file_metadata.items()
        }
        table = table.replace_schema_metadata(encoded_metadata)

        output_parquet.parent.mkdir(parents=True, exist_ok=True)
        pq.write_table(table, output_parquet, compression=self.compression)

        return {
            "total_rows": rows_after,
            "total_columns": len(combined.columns),
            "added_rows": added_rows,
            "updated_rows": updated_rows,
            "unchanged_rows": unchanged_rows,
            "parquet_path": str(output_parquet)
        }
    except Exception as e:
        logger.error(f"Error merging Parquet: {e}")
        raise
def validate_parquet(self, parquet_path: Path) -> bool:
    """
    Validate that a Parquet file is readable and not obviously corrupted.

    Two checks are made: the footer/schema must load, and the first record
    batch must decode. Only a single one-row batch is read, so memory use
    does not grow with file size (the previous implementation loaded the
    whole file into pandas just to look at the first row).

    Args:
        parquet_path: Path to Parquet file

    Returns:
        True if both checks succeed (a file with no rows is valid),
        False otherwise. Failures are logged, never raised.
    """
    try:
        # Footer + schema: fast, no row data touched
        parquet_file = pq.ParquetFile(parquet_path)
        _ = parquet_file.schema_arrow

        # Data check: decode at most one row. The default of None covers
        # files without any rows, where the iterator is empty.
        next(parquet_file.iter_batches(batch_size=1), None)
        return True
    except Exception as e:
        logger.error(f"Parquet validation failed: {e}")
        return False
def create_parquet_manager() -> ParquetManager:
    """
    Build a ParquetManager with its default constructor arguments.

    Returns:
        A new ParquetManager instance
    """
    manager = ParquetManager()
    return manager
if __name__ == "__main__":
# Manual smoke test, run only when this module is executed as a script.
# It exercises csv_to_parquet, get_parquet_info, validate_parquet and
# merge_parquet against throw-away files in a temporary directory and
# prints the outcome; nothing is asserted, failures surface as exceptions.
# Test Parquet manager
print("📦 Testing Parquet manager...")
import tempfile
try:
manager = create_parquet_manager()
# Create test CSV
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
# Test data
test_csv = tmpdir / "test.csv"
test_csv.write_text(
"id,name,value,created_at\n"
"1,Alice,100,2026-01-01\n"
"2,Bob,200,2026-01-02\n"
"3,Charlie,300,2026-01-03\n"
)
test_parquet = tmpdir / "test.parquet"
# Test 1: CSV -> Parquet
print("\n1⃣ Testing CSV -> Parquet conversion...")
result = manager.csv_to_parquet(
csv_path=test_csv,
parquet_path=test_parquet,
dtypes={"id": "Int64", "name": "object", "value": "Int64"},
parse_dates=["created_at"],
table_id="test.table"
)
print(f" ✅ Conversion OK:")
print(f" Rows: {result['rows']}")
print(f" Compression: {result['compression_ratio']:.2f}x")
# Test 2: Parquet info
print("\n2⃣ Testing Parquet info...")
info = manager.get_parquet_info(test_parquet)
# get_parquet_info returns None on failure, in which case nothing is printed
if info:
print(f" ✅ Info loaded:")
print(f" Rows: {info['rows']}")
print(f" Metadata: {info['metadata']}")
# Test 3: Validation
print("\n3⃣ Testing validation...")
if manager.validate_parquet(test_parquet):
print(" ✅ Parquet is valid!")
# Test 4: Merge
print("\n4⃣ Testing merge...")
# Create update CSV
# id 2 already exists in test.csv, id 4 does not: one update plus one insert
update_csv = tmpdir / "update.csv"
update_csv.write_text(
"id,name,value,created_at\n"
"2,Bob Updated,250,2026-01-04\n" # Update
"4,David,400,2026-01-05\n" # New
)
merged_parquet = tmpdir / "merged.parquet"
merge_result = manager.merge_parquet(
existing_parquet=test_parquet,
new_csv=update_csv,
output_parquet=merged_parquet,
primary_key=["id"], # Now uses list
dtypes={"id": "Int64", "name": "object", "value": "Int64"},
parse_dates=["created_at"]
)
print(f" ✅ Merge OK:")
print(f" Total rows: {merge_result['total_rows']}")
print(f" Added: {merge_result['added_rows']}")
print(f" Updated: {merge_result['updated_rows']}")
# Test 5: Empty string handling
print("\n5⃣ Testing empty string handling...")
# Input deliberately contains blank boolean/float/string cells and a
# non-numeric "N/A" in a float column
empty_csv = tmpdir / "empty_strings.csv"
empty_csv.write_text(
"id,is_active,revenue,note\n"
'1,true,100.5,text\n'
'2,,200.0,\n' # Empty boolean and string
'3,false,,note\n' # Empty float
'4,TRUE,N/A,\n' # Invalid float value
)
empty_parquet = tmpdir / "empty_strings.parquet"
result = manager.csv_to_parquet(
csv_path=empty_csv,
parquet_path=empty_parquet,
dtypes={
"id": "Int64",
"is_active": "boolean",
"revenue": "float64",
"note": "object"
},
table_id="test.empty_strings"
)
print(f" ✅ Conversion with empty strings OK:")
print(f" Rows: {result['rows']}")
# Verify the data
# Read the file back and print dtypes / null counts for manual inspection
df = pd.read_parquet(empty_parquet)
print(f" Dtypes: {dict(df.dtypes)}")
print(f" is_active nulls: {df['is_active'].isna().sum()}")
print(f" revenue nulls: {df['revenue'].isna().sum()}")
print("\n✅ All tests passed!")
except Exception as e:
# Report any failure with a full traceback instead of exiting silently
print(f"\n❌ Error: {e}")
import traceback
traceback.print_exc()

View file

@ -1,636 +0,0 @@
"""
Remote Query - Execute DuckDB queries spanning local Parquet + remote BigQuery tables.
Provides a server-side CLI for the AI agent to run SQL queries that JOIN local
(Parquet-backed) tables with on-demand BigQuery results. Designed for tables too
large to sync locally (e.g., daily_deal_traffic: ~3M rows/day).
Two-phase query protocol:
1. BQ sub-queries (--register-bq "alias=SQL") run on BigQuery, results registered
as DuckDB views via PyArrow (reuses register_bq_table from duckdb_manager).
2. DuckDB SQL (--sql) runs against local Parquet views + registered BQ results.
Usage:
python -m src.remote_query \\
--sql "SELECT ... FROM order_economics o JOIN traffic t ON ..." \\
--register-bq "traffic=SELECT ... FROM \\`project.dataset.table\\` WHERE ..." \\
--format table
Safety features:
- COUNT(*) pre-check before fetching BQ data
- Memory estimation (refuses queries > 2 GB estimated)
- Configurable row limits (per BQ sub-query and final result)
- Query timeout support
"""
import argparse
import csv
import io
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Optional
import duckdb
from config.loader import get_instance_value
from scripts.duckdb_manager import (
create_local_views,
register_bq_table,
_create_bq_client,
)
logger = logging.getLogger(__name__)
class RemoteQueryError(Exception):
    """Expected, user-facing failure while preparing or running a remote query.

    Raised for problems such as exceeded row/memory limits, a missing
    BIGQUERY_PROJECT, bad stdin JSON or an unknown output format.
    """
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
def _load_remote_query_config() -> dict:
    """Return the ``remote_query`` settings from instance.yaml, with defaults.

    The YAML file is read directly from ``$CONFIG_DIR/instance.yaml``
    (``./config`` when the variable is unset) instead of going through
    load_instance_config(), so that webapp secrets (WEBAPP_SECRET_KEY etc.)
    are not required. A missing or unreadable file is not an error: a warning
    is logged for read failures and the defaults below are used.

    Returns:
        Dict with the keys timeout_seconds, max_result_rows,
        max_bq_registration_rows, default_format and output_dir.
    """
    import yaml as _yaml
    from pathlib import Path as _Path

    loaded: dict = {}
    instance_file = _Path(os.environ.get("CONFIG_DIR", "./config")) / "instance.yaml"
    if instance_file.exists():
        try:
            with open(instance_file) as handle:
                loaded = _yaml.safe_load(handle) or {}
        except Exception as e:
            logger.warning("Could not load instance.yaml: %s. Using defaults.", e)

    # (setting name, fallback used when instance.yaml does not provide it)
    fallbacks = (
        ("timeout_seconds", 300),
        ("max_result_rows", 100_000),
        ("max_bq_registration_rows", 500_000),
        ("default_format", "table"),
        ("output_dir", "/tmp/remote_query"),
    )
    return {
        setting: get_instance_value(
            loaded, "remote_query", setting, default=fallback,
        )
        for setting, fallback in fallbacks
    }
# ---------------------------------------------------------------------------
# BQ registration with safety checks
# ---------------------------------------------------------------------------
def _validate_bq_result_size(
    bq_client, sql: str, alias: str, max_rows: int,
) -> int:
    """Count the rows a BQ sub-query would return and enforce the row limit.

    The sub-query is wrapped in ``SELECT COUNT(*) ... FROM (<sql>)`` and run
    through ``bq_client.query`` so the size is known before any data is
    fetched.

    Args:
        bq_client: Client object exposing ``query(sql)`` whose job has ``result()``
        sql: The BQ SQL query to count
        alias: Alias name (used in progress and error messages)
        max_rows: Maximum allowed rows

    Returns:
        Row count reported by BigQuery

    Raises:
        RemoteQueryError: If the count exceeds max_rows
    """
    counting_sql = f"SELECT COUNT(*) AS cnt FROM ({sql})"
    _log_progress(f" Counting rows for '{alias}'...")

    count_rows = bq_client.query(counting_sql).result()
    # The count query yields a single row with a single column
    first_row = next(iter(count_rows))
    row_count = first_row[0]

    if row_count <= max_rows:
        return row_count

    raise RemoteQueryError(
        f"BQ sub-query '{alias}' would return {row_count:,} rows "
        f"(limit: {max_rows:,}). Add more WHERE filters or GROUP BY "
        f"to reduce the result set."
    )
def _estimate_memory_mb(row_count: int, column_count: int) -> float:
"""Estimate memory usage in MB for a PyArrow table.
Uses ~50 bytes per cell as a rough average across data types.
"""
return (row_count * column_count * 50) / (1024 * 1024)
def _register_bq_views(
    conn: duckdb.DuckDBPyConnection,
    registrations: list[tuple[str, str]],
    max_bq_rows: int,
    timeout_seconds: int,
    quiet: bool = False,
) -> dict[str, int]:
    """Run each BQ sub-query and register its result as a DuckDB view.

    For every (alias, sql) pair three steps run in order: a COUNT(*) row-limit
    check, a memory estimate based on the column count of a ``LIMIT 0`` probe,
    and finally the actual fetch + registration via ``register_bq_table``.

    Args:
        conn: DuckDB connection
        registrations: List of (alias, bq_sql) tuples
        max_bq_rows: Max rows per sub-query
        timeout_seconds: BQ job timeout. NOTE: accepted for interface
            compatibility but not used anywhere in this function.
        quiet: Suppress progress messages. NOTE: not used here either;
            progress output is governed by ``_log_progress``.

    Returns:
        Dict of {alias: row_count}, using the count reported by
        ``register_bq_table``. Empty when there is nothing to register.

    Raises:
        RemoteQueryError: If BIGQUERY_PROJECT is unset, or a sub-query exceeds
            the row limit or the 2048 MB memory estimate.
    """
    if not registrations:
        return {}

    bq_project = os.environ.get("BIGQUERY_PROJECT")
    if not bq_project:
        raise RemoteQueryError(
            "BIGQUERY_PROJECT environment variable not set. "
            "Required for BigQuery sub-queries."
        )

    bq_client = _create_bq_client(bq_project)
    memory_cap_mb = 2048  # 2 GB = 25% of 8 GB server RAM
    registered: dict[str, int] = {}

    for alias, bq_sql in registrations:
        started_at = time.time()

        # Step 1: row-count guard (raises when over the limit)
        row_count = _validate_bq_result_size(bq_client, bq_sql, alias, max_bq_rows)
        _log_progress(f" '{alias}': {row_count:,} rows (within limit)")

        # Step 2: memory guard; LIMIT 0 returns the schema without any rows
        probe_job = bq_client.query(f"SELECT * FROM ({bq_sql}) LIMIT 0")
        col_count = len(probe_job.result().schema)
        estimated_mb = _estimate_memory_mb(row_count, col_count)
        if estimated_mb > memory_cap_mb:
            raise RemoteQueryError(
                f"BQ sub-query '{alias}' estimated memory: {estimated_mb:.0f} MB "
                f"({row_count:,} rows x {col_count} cols). "
                f"Limit is 2048 MB. Add more aggregation or filters."
            )

        # Step 3: fetch the data and expose it under the alias
        _log_progress(f" Fetching '{alias}' ({row_count:,} rows, ~{estimated_mb:.0f} MB)...")
        actual_rows = register_bq_table(
            conn=conn,
            table_id=f"bq_registration.{alias}",
            view_name=alias,
            sql=bq_sql,
            bq_project=bq_project,
        )

        took = time.time() - started_at
        _log_progress(f" '{alias}' registered: {actual_rows:,} rows in {took:.1f}s")
        registered[alias] = actual_rows

    return registered
# ---------------------------------------------------------------------------
# Local view setup
# ---------------------------------------------------------------------------
def _setup_local_views(
    conn: duckdb.DuckDBPyConnection,
    data_dir: str,
    quiet: bool = False,
) -> list[str]:
    """Create the local DuckDB views by delegating to ``create_local_views``.

    Args:
        conn: DuckDB connection
        data_dir: Path to data directory (e.g., "/data/src_data")
        quiet: Suppress progress messages (passed on as ``verbose=not quiet``)

    Returns:
        The first element of the helper's (created, skipped) result, i.e. the
        created view names; the skipped part is discarded.
    """
    created_views, _skipped = create_local_views(
        conn=conn,
        data_dir=data_dir,
        verbose=not quiet,
    )
    return created_views
# ---------------------------------------------------------------------------
# Output formatting
# ---------------------------------------------------------------------------
def _print_table(columns: list[str], rows: list[tuple]) -> None:
"""Print an aligned ASCII table to stdout."""
if not rows:
print("(empty result)")
return
# Calculate column widths
str_rows = [[str(v) if v is not None else "NULL" for v in row] for row in rows]
widths = [len(col) for col in columns]
for row in str_rows:
for i, val in enumerate(row):
widths[i] = max(widths[i], len(val))
# Header
header = " | ".join(col.ljust(widths[i]) for i, col in enumerate(columns))
separator = "-+-".join("-" * widths[i] for i in range(len(columns)))
print(header)
print(separator)
# Rows
for row in str_rows:
line = " | ".join(val.ljust(widths[i]) for i, val in enumerate(row))
print(line)
print(f"\n({len(rows)} rows)")
def _format_output(
conn: duckdb.DuckDBPyConnection,
sql: str,
fmt: str,
output_path: Optional[str],
max_rows: int,
) -> None:
"""Execute final SQL and output results in the requested format.
Args:
conn: DuckDB connection with all views registered
sql: The final DuckDB SQL query
fmt: Output format (table, csv, json, parquet)
output_path: File path for file-based outputs
max_rows: Maximum rows to return
"""
# Add LIMIT to prevent runaway results
limited_sql = f"SELECT * FROM ({sql}) AS _rq LIMIT {max_rows + 1}"
result = conn.execute(limited_sql)
columns = [desc[0] for desc in result.description]
rows = result.fetchall()
# Check if result exceeded limit
if len(rows) > max_rows:
rows = rows[:max_rows]
_log_progress(
f" WARNING: Result truncated to {max_rows:,} rows. "
f"Add more filters or increase --max-rows."
)
if fmt == "table":
_print_table(columns, rows)
elif fmt == "csv":
if output_path:
with open(output_path, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(columns)
writer.writerows(rows)
_log_progress(f" CSV written: {output_path} ({len(rows)} rows)")
else:
writer = csv.writer(sys.stdout)
writer.writerow(columns)
writer.writerows(rows)
elif fmt == "json":
data = [dict(zip(columns, row)) for row in rows]
json_str = json.dumps(data, default=str, indent=2)
if output_path:
with open(output_path, "w") as f:
f.write(json_str)
_log_progress(f" JSON written: {output_path} ({len(rows)} rows)")
else:
print(json_str)
elif fmt == "parquet":
import pyarrow as pa
import pyarrow.parquet as pq
# Re-execute without limit wrapper for clean Arrow export
arrow_result = conn.execute(
f"SELECT * FROM ({sql}) AS _rq LIMIT {max_rows}"
).arrow().read_all()
if not output_path:
output_path = str(Path(_load_remote_query_config()["output_dir"]) / "result.parquet")
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
pq.write_table(arrow_result, output_path)
_log_progress(
f" Parquet written: {output_path} "
f"({arrow_result.num_rows} rows, {arrow_result.num_columns} cols)"
)
else:
raise RemoteQueryError(f"Unknown format: {fmt}")
# ---------------------------------------------------------------------------
# Progress logging (stderr so stdout stays clean for data)
# ---------------------------------------------------------------------------
_quiet_mode = False
def _log_progress(msg: str) -> None:
"""Print progress message to stderr (keeps stdout clean for data output)."""
if not _quiet_mode:
print(msg, file=sys.stderr)
# ---------------------------------------------------------------------------
# Main execution
# ---------------------------------------------------------------------------
def execute_remote_query(
    sql: str,
    bq_registrations: list[tuple[str, str]],
    fmt: str = "table",
    output: Optional[str] = None,
    max_rows: Optional[int] = None,
    max_bq_rows: Optional[int] = None,
    timeout: Optional[int] = None,
    data_dir: str = "/data/src_data",
    quiet: bool = False,
) -> None:
    """Run a remote query end to end on a fresh in-memory DuckDB connection.

    Steps: set up the local views, register any BQ sub-query results, then
    execute *sql* and emit the result. Falsy values for fmt, max_rows,
    max_bq_rows and timeout are replaced by the configured defaults. The
    module-level quiet flag is updated as a side effect, and the connection
    is closed even when a step fails.

    Args:
        sql: DuckDB SQL query to execute
        bq_registrations: List of (alias, bq_sql) tuples
        fmt: Output format (table, csv, json, parquet)
        output: Output file path (for parquet/csv/json)
        max_rows: Max rows in final result
        max_bq_rows: Max rows per BQ sub-query
        timeout: Query timeout in seconds (forwarded to the BQ registration step)
        data_dir: Path to data directory
        quiet: Suppress progress messages
    """
    global _quiet_mode
    _quiet_mode = quiet

    settings = _load_remote_query_config()
    if not max_rows:
        max_rows = settings["max_result_rows"]
    if not max_bq_rows:
        max_bq_rows = settings["max_bq_registration_rows"]
    if not timeout:
        timeout = settings["timeout_seconds"]
    if not fmt:
        fmt = settings["default_format"]

    began = time.time()
    conn = duckdb.connect(":memory:")
    try:
        # 1) local Parquet-backed views
        _log_progress("Setting up local views...")
        view_names = _setup_local_views(conn, data_dir, quiet=quiet)
        _log_progress(f" {len(view_names)} local views ready")

        # 2) BigQuery sub-query results, if any were requested
        if bq_registrations:
            _log_progress(f"Registering {len(bq_registrations)} BQ sub-queries...")
            registered = _register_bq_views(
                conn, bq_registrations, max_bq_rows, timeout, quiet=quiet,
            )
            for alias, count in registered.items():
                _log_progress(f" {alias}: {count:,} rows")

        # 3) the final query against everything registered above
        _log_progress("Executing query...")
        _format_output(conn, sql, fmt, output, max_rows)

        _log_progress(f"Done in {time.time() - began:.1f}s")
    finally:
        conn.close()
# ---------------------------------------------------------------------------
# CLI argument parsing
# ---------------------------------------------------------------------------
def _parse_register_bq(value: str) -> tuple[str, str]:
"""Parse --register-bq argument in 'alias=SQL' format.
Args:
value: String in format "alias=SELECT ..."
Returns:
Tuple of (alias, sql)
Raises:
argparse.ArgumentTypeError: If format is invalid
"""
eq_pos = value.find("=")
if eq_pos <= 0:
raise argparse.ArgumentTypeError(
f"Invalid --register-bq format: '{value}'. "
f"Expected: 'alias=SELECT ...' (e.g., 'traffic=SELECT report_date, ...')"
)
alias = value[:eq_pos].strip()
sql = value[eq_pos + 1:].strip()
if not sql:
raise argparse.ArgumentTypeError(
f"Empty SQL in --register-bq for alias '{alias}'"
)
return alias, sql
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the remote_query CLI.

    Options are declared in a table and added in that order, which is also
    the order they appear in ``--help``. ``--sql`` is optional at the parser
    level because ``--stdin`` can supply the query instead.
    """
    examples = """
Examples:
# Local-only query (no BigQuery):
python -m src.remote_query --sql "SELECT COUNT(*) FROM order_economics"
# Register BQ result and query it:
python -m src.remote_query \\
--register-bq "traffic=SELECT report_date, SUM(visitors) FROM \\`proj.ds.table\\` GROUP BY 1" \\
--sql "SELECT * FROM traffic ORDER BY report_date"
# JOIN local + remote:
python -m src.remote_query \\
--register-bq "traffic=SELECT ... GROUP BY ..." \\
--sql "SELECT o.*, t.visitors FROM order_economics o JOIN traffic t ON ..." \\
--format parquet --output /tmp/result.parquet
"""
    cli = argparse.ArgumentParser(
        description="Execute DuckDB queries spanning local Parquet + remote BigQuery tables",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=examples,
    )

    # (flag, keyword arguments for add_argument)
    option_table = [
        ("--sql", {
            "required": False,  # not required when --stdin is used
            "default": None,
            "help": "DuckDB SQL query (executed after all views are registered)",
        }),
        ("--register-bq", {
            "action": "append",
            "type": _parse_register_bq,
            "default": [],
            "metavar": "ALIAS=SQL",
            "dest": "bq_registrations",
            "help": 'Register BQ query result as DuckDB view. Format: "alias=BQ_SQL". Repeatable.',
        }),
        ("--format", {
            "choices": ["table", "csv", "json", "parquet"],
            "default": None,
            "dest": "fmt",
            "help": "Output format (default: from config or 'table')",
        }),
        ("--output", {
            "default": None,
            "help": "Output file path for parquet/csv/json (default: auto for parquet)",
        }),
        ("--max-rows", {
            "type": int,
            "default": None,
            "help": "Max rows in final result (default: from config)",
        }),
        ("--max-bq-rows", {
            "type": int,
            "default": None,
            "help": "Max rows per BQ sub-query (default: from config)",
        }),
        ("--timeout", {
            "type": int,
            "default": None,
            "help": "Query timeout in seconds (default: from config)",
        }),
        ("--data-dir", {
            "default": "/data/src_data",
            "help": "Parquet data directory (default: /data/src_data)",
        }),
        ("--quiet", {
            "action": "store_true",
            "help": "Suppress progress messages (stderr)",
        }),
        ("--stdin", {
            "action": "store_true",
            "help": "Read query spec from stdin as JSON. Avoids shell escaping issues.",
        }),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli
def _parse_stdin_query() -> dict:
"""Parse query specification from stdin JSON.
Expected format:
{
"sql": "SELECT ... FROM ...",
"register_bq": {"alias": "BQ SQL", ...},
"format": "table",
"output": "/path/to/file",
"max_rows": 100000,
"max_bq_rows": 500000
}
Returns:
Dict with parsed query spec
"""
raw = sys.stdin.read().strip()
if not raw:
raise RemoteQueryError("Empty stdin. Provide JSON query spec.")
try:
spec = json.loads(raw)
except json.JSONDecodeError as e:
raise RemoteQueryError(f"Invalid JSON on stdin: {e}")
if "sql" not in spec:
raise RemoteQueryError("JSON must contain 'sql' field.")
return spec
def main(argv: Optional[list[str]] = None) -> None:
    """CLI entry point: parse arguments, run the query, map errors to exit codes.

    With ``--stdin`` the query, BQ registrations, format, output and row
    limits come from the JSON spec (falling back to the command-line values
    for the optional ones); otherwise ``--sql`` is mandatory. Timeout, data
    directory and quiet flag always come from the command line.

    Exit codes: 1 for RemoteQueryError, 130 on Ctrl-C, 2 for anything else.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    logging.basicConfig(
        level=logging.WARNING if args.quiet else logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        stream=sys.stderr,
    )

    try:
        if args.stdin:
            # JSON on stdin sidesteps shell quoting of the SQL entirely
            spec = _parse_stdin_query()
            request = {
                "sql": spec["sql"],
                "bq_registrations": list(spec.get("register_bq", {}).items()),
                "fmt": spec.get("format", args.fmt),
                "output": spec.get("output", args.output),
                "max_rows": spec.get("max_rows", args.max_rows),
                "max_bq_rows": spec.get("max_bq_rows", args.max_bq_rows),
            }
        else:
            if not args.sql:
                parser.error("--sql is required (or use --stdin for JSON input)")
            request = {
                "sql": args.sql,
                "bq_registrations": args.bq_registrations,
                "fmt": args.fmt,
                "output": args.output,
                "max_rows": args.max_rows,
                "max_bq_rows": args.max_bq_rows,
            }

        execute_remote_query(
            timeout=args.timeout,
            data_dir=args.data_dir,
            quiet=args.quiet,
            **request,
        )
    except RemoteQueryError as e:
        print(f"ERROR: {e}", file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        print("\nInterrupted.", file=sys.stderr)
        sys.exit(130)
    except Exception as e:
        print(f"UNEXPECTED ERROR: {e}", file=sys.stderr)
        logger.exception("Unexpected error in remote_query")
        sys.exit(2)
if __name__ == "__main__":
main()

View file

@ -1,464 +0,0 @@
"""
Table Registry - Central source of truth for registered tables.
Manages table registrations in a JSON file. Generates data_description.md
as a read-only output for downstream consumers (config.py, profiler.py, webapp).
Supports:
- CRUD operations on registered tables
- Folder mapping (bucket -> folder name)
- Atomic persistence (tempfile + os.replace)
- Optimistic locking (version field)
- Audit logging
- One-time migration from existing data_description.md
- Generation of data_description.md with checksum header
"""
import hashlib
import json
import logging
import os
import re
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
import yaml
logger = logging.getLogger(__name__)
# Default registry location
_DEFAULT_REGISTRY_DIR = Path(
os.environ.get("REGISTRY_DIR", "/data/src_data/metadata")
)
_REGISTRY_FILENAME = "table_registry.json"
def _now_iso() -> str:
"""Return current UTC time as ISO string."""
return datetime.now(timezone.utc).isoformat()
def _atomic_write_json(path: Path, data: dict) -> None:
"""Write JSON atomically using tempfile + os.replace."""
path.parent.mkdir(parents=True, exist_ok=True)
fd, tmp_path = tempfile.mkstemp(
dir=str(path.parent), suffix=".tmp"
)
try:
with os.fdopen(fd, "w") as f:
json.dump(data, f, indent=2, default=str)
os.chmod(tmp_path, 0o660)
os.replace(tmp_path, str(path))
except Exception:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
def _audit_log(registry_path: Path, action: str, details: dict) -> None:
    """Append one JSON line describing *action* to ``registry_audit.log``.

    The log lives in the same directory as the registry file. Each line holds
    a timestamp, the action name and the entries of *details* (which override
    the first two on key clashes). Logging is best-effort: any failure is
    reported as a warning and never propagates to the caller.
    """
    log_file = registry_path.parent / "registry_audit.log"
    try:
        record = {
            "timestamp": _now_iso(),
            "action": action,
            **details,
        }
        line = json.dumps(record, default=str) + "\n"
        with open(log_file, "a") as sink:
            sink.write(line)
    except Exception as e:
        logger.warning(f"Could not write audit log: {e}")
class TableRegistry:
"""Manages table registrations. Source of truth for what gets synced."""
def __init__(self, registry_path: Path):
    """Bind the registry to *registry_path* and load its current contents."""
    self.registry_path = registry_path
    loaded = self._load()
    self._data = loaded
@classmethod
def default(cls) -> "TableRegistry":
    """Open the registry file at the module's default directory and filename."""
    default_file = _DEFAULT_REGISTRY_DIR / _REGISTRY_FILENAME
    return cls(default_file)
# ── Persistence ──────────────────────────────────────────────────
def _load(self) -> dict:
    """Read the registry JSON from disk.

    Returns the parsed file content, or a fresh empty structure when the
    file does not exist or cannot be read/parsed (the error is logged, not
    raised).
    """
    if not self.registry_path.exists():
        return self._empty_registry()

    try:
        with open(self.registry_path) as f:
            data = json.load(f)
        table_count = len(data.get('tables', []))
        logger.info(
            f"Registry loaded: {table_count} tables"
        )
        return data
    except Exception as e:
        logger.error(f"Error loading registry: {e}")

    return self._empty_registry()
def _save(self) -> None:
    """Bump the version, stamp ``updated_at`` and write the registry atomically.

    The ``_metadata`` section is created if the loaded file did not have
    one, matching the ``version`` property, which already treats missing
    metadata as version 0. Previously such a file raised KeyError on the
    first save.
    """
    metadata = self._data.setdefault("_metadata", {})
    metadata["updated_at"] = _now_iso()
    metadata["version"] = self.version + 1
    _atomic_write_json(self.registry_path, self._data)
    logger.debug("Registry saved (version %d)", self.version)
@staticmethod
def _empty_registry() -> dict:
    """Return the skeleton of a new registry: version 0, no mappings, no tables."""
    stamp = _now_iso()
    metadata = {
        "version": 0,
        "created_at": stamp,
        "updated_at": stamp,
    }
    return {
        "_metadata": metadata,
        "folder_mapping": {},
        "tables": [],
    }
# ── Properties ───────────────────────────────────────────────────
@property
def version(self) -> int:
return self._data.get("_metadata", {}).get("version", 0)
# ── Core CRUD ────────────────────────────────────────────────────
def list_tables(self) -> list[dict]:
"""Return all registered tables."""
return list(self._data.get("tables", []))
def get_table(self, table_id: str) -> Optional[dict]:
"""Get a single table by ID."""
for t in self._data.get("tables", []):
if t["id"] == table_id:
return dict(t)
return None
def is_registered(self, table_id: str) -> bool:
return any(t["id"] == table_id for t in self._data.get("tables", []))
def register_table(
self,
table_def: dict,
registered_by: str,
expected_version: Optional[int] = None,
) -> None:
"""Register a new table.
Args:
table_def: Table definition dict (must contain id, name, sync_strategy, primary_key).
registered_by: Email of the admin who registered the table.
expected_version: If provided, reject if registry version doesn't match (optimistic lock).
Raises:
ValueError: If table already registered or validation fails.
ConflictError: If expected_version doesn't match.
"""
if expected_version is not None and expected_version != self.version:
raise ConflictError(
f"Version conflict: expected {expected_version}, current {self.version}"
)
table_id = table_def.get("id", "")
if not table_id:
raise ValueError("Table definition must include 'id'")
if self.is_registered(table_id):
raise ValueError(f"Table '{table_id}' is already registered")
# Validate required fields
for field in ("name", "sync_strategy", "primary_key"):
if not table_def.get(field):
raise ValueError(f"Table definition must include '{field}'")
# Validate sync_strategy
valid_strategies = ("full_refresh", "incremental", "partitioned")
if table_def["sync_strategy"] not in valid_strategies:
raise ValueError(
f"Invalid sync_strategy '{table_def['sync_strategy']}'. "
f"Allowed: {', '.join(valid_strategies)}"
)
# Build full record
record = {
"id": table_id,
"name": table_def["name"],
"description": table_def.get("description", ""),
"primary_key": table_def["primary_key"],
"sync_strategy": table_def["sync_strategy"],
"incremental_window_days": table_def.get("incremental_window_days"),
"partition_by": table_def.get("partition_by"),
"partition_granularity": table_def.get("partition_granularity"),
"foreign_keys": table_def.get("foreign_keys", []),
"where_filters": table_def.get("where_filters", []),
"folder": table_def.get("folder"),
"dataset": table_def.get("dataset"),
"initial_load_chunk_days": table_def.get("initial_load_chunk_days", 30),
"registered_at": _now_iso(),
"registered_by": registered_by,
"source_metadata": table_def.get("source_metadata", {}),
}
self._data["tables"].append(record)
self._save()
_audit_log(self.registry_path, "register", {
"table_id": table_id,
"by": registered_by,
})
def unregister_table(
self,
table_id: str,
unregistered_by: str = "",
expected_version: Optional[int] = None,
) -> None:
"""Remove a table from the registry.
Raises:
ValueError: If table not found.
ConflictError: If expected_version doesn't match.
"""
if expected_version is not None and expected_version != self.version:
raise ConflictError(
f"Version conflict: expected {expected_version}, current {self.version}"
)
tables = self._data.get("tables", [])
new_tables = [t for t in tables if t["id"] != table_id]
if len(new_tables) == len(tables):
raise ValueError(f"Table '{table_id}' is not registered")
self._data["tables"] = new_tables
self._save()
_audit_log(self.registry_path, "unregister", {
"table_id": table_id,
"by": unregistered_by,
})
def update_table(
    self,
    table_id: str,
    updates: dict,
    updated_by: str = "",
    expected_version: Optional[int] = None,
) -> None:
    """Update the mutable configuration fields of a registered table.

    Only keys in the allow-list below are applied; any other key in
    ``updates`` (for example ``id`` or ``name``) is ignored. The audit log
    records only the fields that were actually applied.

    Args:
        table_id: Identifier of the table entry to modify.
        updates: Mapping of field name to new value.
        updated_by: Actor name written to the audit log.
        expected_version: When given, must equal the current registry
            version or the call is rejected (optimistic locking).

    Raises:
        ValueError: If table not found, or if ``sync_strategy`` is being
            set to a value outside the allowed strategies.
        ConflictError: If expected_version doesn't match.
    """
    if expected_version is not None and expected_version != self.version:
        raise ConflictError(
            f"Version conflict: expected {expected_version}, current {self.version}"
        )
    # Fields that can be updated
    allowed_fields = {
        "description", "primary_key", "sync_strategy",
        "incremental_window_days", "partition_by", "partition_granularity",
        "foreign_keys", "where_filters", "folder", "dataset",
        "initial_load_chunk_days",
    }
    # Same set of strategies that registration accepts.
    valid_strategies = ("full_refresh", "incremental", "partitioned")
    applied = {
        key: value for key, value in updates.items() if key in allowed_fields
    }

    for t in self._data.get("tables", []):
        if t["id"] != table_id:
            continue
        # Validate before mutating so a rejected update leaves the entry intact.
        if (
            "sync_strategy" in applied
            and applied["sync_strategy"] not in valid_strategies
        ):
            raise ValueError(
                f"Invalid sync_strategy '{applied['sync_strategy']}'. "
                f"Allowed: {', '.join(valid_strategies)}"
            )
        t.update(applied)
        self._save()
        _audit_log(self.registry_path, "update", {
            "table_id": table_id,
            # Only what was really changed, not every key the caller sent.
            "fields": list(applied),
            "by": updated_by,
        })
        return
    raise ValueError(f"Table '{table_id}' is not registered")
# ── Folder mapping ───────────────────────────────────────────────
def get_folder_mapping(self) -> dict[str, str]:
    """Return a shallow copy of the stored folder mapping (empty when unset)."""
    stored = self._data.get("folder_mapping", {})
    return dict(stored)
def set_folder_mapping(self, bucket_id: str, folder: str) -> None:
    """Record ``folder`` for ``bucket_id`` in the folder mapping and persist."""
    mapping = self._data.setdefault("folder_mapping", {})
    mapping[bucket_id] = folder
    self._save()
# ── Generation ───────────────────────────────────────────────────
def generate_data_description_md(self, output_path: Path) -> None:
    """Regenerate data_description.md from the registry contents.

    The output consists of three HTML comment header lines (the last one
    carrying a truncated SHA-256 checksum of the YAML payload), a heading,
    a one-line generation note, and a single fenced ``yaml`` block holding
    ``folder_mapping`` (only when non-empty) and ``tables``.

    Optional table fields are emitted only when truthy, and
    ``initial_load_chunk_days`` is omitted when it equals the default of 30.
    The header marks the file as auto-generated; nothing here enforces
    read-only access at the filesystem level.

    Args:
        output_path: Destination file; missing parent directories are created.
    """
    # Emitted in this order after the mandatory keys (sort_keys=False below).
    optional_fields = (
        "incremental_window_days",
        "partition_by",
        "partition_granularity",
        "max_history_days",
        "initial_load_chunk_days",
        "foreign_keys",
        "where_filters",
        "folder",
        "dataset",
    )

    tables = self.list_tables()
    folder_mapping = self.get_folder_mapping()

    # Build YAML structure matching existing data_description.md format
    yaml_data: dict[str, Any] = {}
    if folder_mapping:
        yaml_data["folder_mapping"] = folder_mapping

    yaml_tables = []
    for t in tables:
        entry: dict[str, Any] = {
            "id": t["id"],
            "name": t["name"],
            "description": t.get("description", ""),
            "primary_key": t["primary_key"],
            "sync_strategy": t["sync_strategy"],
        }
        for field in optional_fields:
            value = t.get(field)
            if not value:
                continue
            # The default chunk size is implied, so it is left out.
            if field == "initial_load_chunk_days" and value == 30:
                continue
            entry[field] = value
        yaml_tables.append(entry)
    yaml_data["tables"] = yaml_tables

    yaml_str = yaml.dump(
        yaml_data, default_flow_style=False, sort_keys=False, allow_unicode=True
    )

    # Checksum covers the YAML payload only, not the surrounding markdown.
    checksum = hashlib.sha256(yaml_str.encode("utf-8")).hexdigest()[:16]

    lines = [
        "<!-- AUTO-GENERATED from table_registry.json -- do not edit manually -->",
        "<!-- Use the admin UI at /admin/tables to manage table registrations -->",
        f"<!-- checksum: sha256:{checksum} -->",
        "",
        "# Data Description",
        "",
        f"Generated at {_now_iso()} from table registry "
        f"(version {self.version}, {len(yaml_tables)} tables).",
        "",
        "```yaml",
        yaml_str.rstrip(),
        "```",
        "",
    ]
    content = "\n".join(lines)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # allow_unicode=True can put non-ASCII characters in the payload, so the
    # encoding must not depend on the platform default.
    output_path.write_text(content, encoding="utf-8")
    logger.info(
        "Generated data_description.md: %d tables (checksum: %s)",
        len(yaml_tables),
        checksum,
    )
# ── Migration ────────────────────────────────────────────────────
@classmethod
def import_from_data_description(
    cls,
    md_path: Path,
    registry_path: Path,
    registered_by: str = "migration",
) -> "TableRegistry":
    """One-time migration: parse existing data_description.md into a registry.

    Every fenced ``yaml`` block in the markdown file is parsed; their
    ``tables`` lists are concatenated and their ``folder_mapping`` entries
    merged (later blocks win on duplicate keys). Blocks that do not parse
    to a mapping, and ``tables`` / ``folder_mapping`` keys with a null
    value, are ignored. The resulting registry is saved and the migration
    is written to the audit log.

    Args:
        md_path: Path of the markdown file to import.
        registry_path: Path passed to the registry constructor.
        registered_by: Actor name stored on each record and in the audit log.

    Returns:
        The newly built registry instance.

    Raises:
        FileNotFoundError: If ``md_path`` does not exist.
        ValueError: If the file has no YAML blocks or they contain no tables.
    """
    if not md_path.exists():
        raise FileNotFoundError(f"data_description.md not found: {md_path}")

    # Explicit encoding so the result does not depend on the platform default.
    content = md_path.read_text(encoding="utf-8")

    # Extract YAML blocks
    yaml_matches = re.findall(r"```yaml\n(.*?)```", content, re.DOTALL)
    if not yaml_matches:
        raise ValueError("No YAML blocks found in data_description.md")

    all_tables: list[dict] = []
    folder_mapping: dict[str, str] = {}
    for yaml_block in yaml_matches:
        data = yaml.safe_load(yaml_block)
        # A block may legally parse to None, a scalar or a list; only
        # mappings can carry the keys we are interested in.
        if not isinstance(data, dict):
            continue
        # "tables:" with no value parses to None, hence the "or" fallbacks.
        all_tables.extend(data.get("tables") or [])
        folder_mapping.update(data.get("folder_mapping") or {})

    if not all_tables:
        raise ValueError("No tables found in YAML blocks")

    # Build registry
    registry = cls(registry_path)
    registry._data = cls._empty_registry()
    registry._data["folder_mapping"] = folder_mapping
    registry._data["_metadata"]["migrated_from"] = str(md_path)

    now = _now_iso()
    for table_data in all_tables:
        record = {
            "id": table_data.get("id", ""),
            "name": table_data.get("name", ""),
            "description": table_data.get("description", ""),
            "primary_key": table_data.get("primary_key", ""),
            "sync_strategy": table_data.get("sync_strategy", "full_refresh"),
            "incremental_window_days": table_data.get("incremental_window_days"),
            "partition_by": table_data.get("partition_by"),
            "partition_granularity": table_data.get("partition_granularity"),
            "foreign_keys": table_data.get("foreign_keys", []),
            "where_filters": table_data.get("where_filters", []),
            "folder": table_data.get("folder"),
            "dataset": table_data.get("dataset"),
            "initial_load_chunk_days": table_data.get("initial_load_chunk_days", 30),
            "max_history_days": table_data.get("max_history_days"),
            "registered_at": now,
            "registered_by": registered_by,
            "source_metadata": {},
        }
        registry._data["tables"].append(record)

    registry._save()
    _audit_log(registry_path, "migrate", {
        "source": str(md_path),
        "tables_imported": len(all_tables),
        "by": registered_by,
    })
    logger.info(
        "Migrated %d tables from %s to registry", len(all_tables), md_path
    )
    return registry
class ConflictError(Exception):
    """Signals an optimistic-locking failure: the expected registry version is stale."""

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -93,17 +93,15 @@ def mock_client():
@pytest.fixture
def mock_config():
"""Return a mock Config with a single table."""
cfg = MagicMock()
cfg.tables = [
FakeTableConfig(
name="order_economics",
id="prj.dataset.order_economics",
catalog_fqn="bigquery.prj.dataset.order_economics",
)
def mock_tables():
"""Return a list of table dicts matching TableRegistryRepository.list_all() format."""
return [
{
"id": "prj.dataset.order_economics",
"name": "order_economics",
"catalog_fqn": "bigquery.prj.dataset.order_economics",
}
]
return cfg
CATALOG_URL = "https://catalog.example.com"
@ -351,11 +349,11 @@ class TestExportMetrics:
class TestExportTables:
def test_export_tables_writes_files(
self, tmp_path: Path, mock_client, mock_config
self, tmp_path: Path, mock_client, mock_tables
):
"""Creates table YAML with columns."""
docs = tmp_path / "docs"
count = export_tables(mock_client, mock_config, docs, CATALOG_URL)
count = export_tables(mock_client, mock_tables, docs, CATALOG_URL)
assert count == 1
@ -374,21 +372,12 @@ class TestExportTables:
assert parsed["columns"][0]["name"] == "order_id"
def test_export_tables_handles_api_error(
self, tmp_path: Path, mock_client, mock_config
self, tmp_path: Path, mock_client
):
"""Continues on per-table errors, exports remaining tables."""
# Two tables: first will fail, second succeeds
mock_config.tables = [
FakeTableConfig(
name="broken_table",
id="prj.dataset.broken",
catalog_fqn="bigquery.prj.dataset.broken",
),
FakeTableConfig(
name="good_table",
id="prj.dataset.good",
catalog_fqn="bigquery.prj.dataset.good",
),
tables = [
{"id": "prj.dataset.broken", "name": "broken_table", "catalog_fqn": "bigquery.prj.dataset.broken"},
{"id": "prj.dataset.good", "name": "good_table", "catalog_fqn": "bigquery.prj.dataset.good"},
]
def side_effect(fqn):
@ -399,7 +388,7 @@ class TestExportTables:
mock_client.get_table.side_effect = side_effect
docs = tmp_path / "docs"
count = export_tables(mock_client, mock_config, docs, CATALOG_URL)
count = export_tables(mock_client, tables, docs, CATALOG_URL)
# Only the good table should succeed
assert count == 1
@ -407,11 +396,11 @@ class TestExportTables:
assert (docs / "tables" / "good_table.yml").exists()
def test_export_tables_uses_catalog_fqn(
self, tmp_path: Path, mock_client, mock_config
self, tmp_path: Path, mock_client, mock_tables
):
"""Uses explicit catalog_fqn when set on table config."""
docs = tmp_path / "docs"
export_tables(mock_client, mock_config, docs, CATALOG_URL)
export_tables(mock_client, mock_tables, docs, CATALOG_URL)
mock_client.get_table.assert_called_once_with(
"bigquery.prj.dataset.order_economics"
@ -421,17 +410,12 @@ class TestExportTables:
self, tmp_path: Path, mock_client
):
"""When catalog_fqn is None, derives FQN as 'bigquery.{id}'."""
cfg = MagicMock()
cfg.tables = [
FakeTableConfig(
name="my_table",
id="project.dataset.my_table",
catalog_fqn=None,
)
tables = [
{"id": "project.dataset.my_table", "name": "my_table"},
]
docs = tmp_path / "docs"
export_tables(mock_client, cfg, docs, CATALOG_URL)
export_tables(mock_client, tables, docs, CATALOG_URL)
mock_client.get_table.assert_called_once_with(
"bigquery.project.dataset.my_table"

View file

@ -1,69 +0,0 @@
"""Tests for TableConfig.bq_entity_type field validation."""
import pytest
from src.config import TableConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_table(**overrides) -> TableConfig:
    """Build a TableConfig from baseline values; keyword overrides win."""
    baseline = {
        "id": "test.dataset.table",
        "name": "test_table",
        "description": "Test",
        "primary_key": "id",
        "sync_strategy": "full_refresh",
    }
    return TableConfig(**{**baseline, **overrides})
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestBqEntityTypeDefault:
    """bq_entity_type value when the caller does not supply one."""

    def test_default_is_view(self):
        assert _make_table().bq_entity_type == "view"
class TestBqEntityTypeValidValues:
    """Accepted bq_entity_type values are stored unchanged."""

    @pytest.mark.parametrize("entity_type", ["view", "table"])
    def test_valid_bq_entity_type(self, entity_type):
        built = _make_table(bq_entity_type=entity_type)
        assert built.bq_entity_type == entity_type
class TestBqEntityTypeInvalid:
    """Unsupported bq_entity_type values are rejected at construction."""

    @pytest.mark.parametrize(
        "bad_type",
        ["VIEW", "physical", "", "tables", "materialized"],
    )
    def test_invalid_bq_entity_type_raises(self, bad_type):
        expected_error = pytest.raises(ValueError, match="Invalid bq_entity_type")
        with expected_error:
            _make_table(bq_entity_type=bad_type)
class TestBqEntityTypeFromKwarg:
    """Direct TableConfig construction with and without bq_entity_type."""

    def test_kwarg_sets_bq_entity_type(self):
        """An explicit bq_entity_type keyword is kept on the instance."""
        fields = {
            "id": "proj.dataset.orders",
            "name": "orders",
            "description": "Order data",
            "primary_key": "order_id",
            "sync_strategy": "full_refresh",
            "bq_entity_type": "table",
        }
        assert TableConfig(**fields).bq_entity_type == "table"

    def test_kwarg_default_when_omitted(self):
        """Leaving bq_entity_type out yields 'view'."""
        fields = {
            "id": "proj.dataset.orders",
            "name": "orders",
            "description": "Order data",
            "primary_key": "order_id",
            "sync_strategy": "full_refresh",
        }
        assert TableConfig(**fields).bq_entity_type == "view"

View file

@ -1,69 +0,0 @@
"""Tests for TableConfig.query_mode field validation."""
import pytest
from src.config import TableConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_table(**overrides) -> TableConfig:
    """Return a TableConfig built from stock field values merged with overrides."""
    params = {
        "id": "test.dataset.table",
        "name": "test_table",
        "description": "Test",
        "primary_key": "id",
        "sync_strategy": "full_refresh",
    }
    for field_name, field_value in overrides.items():
        params[field_name] = field_value
    return TableConfig(**params)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestQueryModeDefault:
    """query_mode value when the caller does not supply one."""

    def test_default_is_local(self):
        assert _make_table().query_mode == "local"
class TestQueryModeValidValues:
    """Each supported query_mode is stored unchanged."""

    @pytest.mark.parametrize("mode", ["local", "remote", "hybrid"])
    def test_valid_query_mode(self, mode):
        built = _make_table(query_mode=mode)
        assert built.query_mode == mode
class TestQueryModeInvalid:
    """Unsupported query_mode values are rejected at construction."""

    @pytest.mark.parametrize(
        "bad_mode",
        ["invalid", "Local", "REMOTE", "", "sql"],
    )
    def test_invalid_query_mode_raises(self, bad_mode):
        expected_error = pytest.raises(ValueError, match="Invalid query_mode")
        with expected_error:
            _make_table(query_mode=bad_mode)
class TestQueryModeFromKwarg:
    """Direct TableConfig construction with and without query_mode."""

    def test_kwarg_sets_query_mode(self):
        """An explicit query_mode keyword is kept on the instance."""
        fields = {
            "id": "proj.dataset.orders",
            "name": "orders",
            "description": "Order data",
            "primary_key": "order_id",
            "sync_strategy": "full_refresh",
            "query_mode": "remote",
        }
        assert TableConfig(**fields).query_mode == "remote"

    def test_kwarg_default_when_omitted(self):
        """Leaving query_mode out yields 'local'."""
        fields = {
            "id": "proj.dataset.orders",
            "name": "orders",
            "description": "Order data",
            "primary_key": "order_id",
            "sync_strategy": "full_refresh",
        }
        assert TableConfig(**fields).query_mode == "local"

View file

@ -1,103 +0,0 @@
"""Tests for TableConfig.sync_schedule field validation."""
import pytest
from src.config import TableConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_table(**overrides) -> TableConfig:
    """Construct a TableConfig using default field values overlaid with overrides."""
    settings = dict(
        id="test.dataset.table",
        name="test_table",
        description="Test",
        primary_key="id",
        sync_strategy="full_refresh",
    )
    settings = {**settings, **overrides}
    return TableConfig(**settings)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestSyncScheduleDefault:
    """sync_schedule value when the caller does not supply one."""

    def test_default_is_none(self):
        assert _make_table().sync_schedule is None
class TestSyncScheduleValidValues:
    """Well-formed 'every …' and 'daily …' schedules are stored unchanged."""

    @pytest.mark.parametrize(
        "schedule",
        [
            pytest.param("every 15m", id="every-15m"),
            pytest.param("every 1h", id="every-1h"),
            pytest.param("daily 05:00", id="daily-single"),
            pytest.param("daily 07:00,13:00,18:00", id="daily-three-times"),
            pytest.param("daily 00:00,12:00", id="daily-two-times"),
        ],
    )
    def test_valid_schedule_accepted(self, schedule: str):
        built = _make_table(sync_schedule=schedule)
        assert built.sync_schedule == schedule
class TestSyncScheduleEdgeCases:
    """Boundary schedule strings that are accepted."""

    def test_every_zero_minutes(self):
        """'every 0m' is accepted: the check is syntactic, not semantic."""
        schedule = "every 0m"
        assert _make_table(sync_schedule=schedule).sync_schedule == "every 0m"

    def test_daily_2359(self):
        schedule = "daily 23:59"
        assert _make_table(sync_schedule=schedule).sync_schedule == "daily 23:59"
class TestSyncScheduleInvalid:
    """Malformed schedules are rejected; two lenient cases are pinned as-is."""

    @pytest.mark.parametrize(
        "bad_schedule",
        [
            # trailing comma
            pytest.param("daily 07:00,13:00,18:00,", id="trailing-comma"),
            # single-digit hour
            pytest.param("daily 7:00", id="single-digit-hour"),
            # missing time
            pytest.param("daily", id="daily-no-time"),
            # unsupported keywords
            pytest.param("hourly", id="hourly-keyword"),
            pytest.param("weekly", id="weekly-keyword"),
        ],
    )
    def test_invalid_schedule_raises(self, bad_schedule: str):
        expected_error = pytest.raises(ValueError, match="Invalid sync_schedule")
        with expected_error:
            _make_table(sync_schedule=bad_schedule)

    def test_empty_string_treated_as_none(self):
        """An empty string raises nothing and is stored as the empty string."""
        built = _make_table(sync_schedule="")
        assert built.sync_schedule == ""

    def test_daily_25_accepted_by_regex(self):
        """'daily 25:00' is accepted; this pins the current lenient behavior."""
        built = _make_table(sync_schedule="daily 25:00")
        assert built.sync_schedule == "daily 25:00"
class TestSyncScheduleNoneExplicit:
    """Passing sync_schedule=None explicitly is allowed."""

    def test_explicit_none_accepted(self):
        assert _make_table(sync_schedule=None).sync_schedule is None

View file

@ -1,228 +0,0 @@
"""Tests for remote table skipping in DataSyncManager.sync_all().
Tables with query_mode == "remote" should be skipped during sync (no local
Parquet file is needed -- queries go directly to BigQuery). Tables with
query_mode "local" or "hybrid" must still be synced normally.
"""
from unittest.mock import MagicMock, patch
import pytest
from src.config import TableConfig
from src.data_sync import DataSyncManager
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_table_config(
    table_id: str,
    name: str,
    query_mode: str = "local",
) -> TableConfig:
    """Build a full_refresh TableConfig with the given id, name and query mode."""
    fields = {
        "id": table_id,
        "name": name,
        "description": f"Test table {name}",
        "primary_key": "id",
        "sync_strategy": "full_refresh",
        "query_mode": query_mode,
    }
    return TableConfig(**fields)
def _successful_sync_result() -> dict:
    """Produce a fresh dict describing a successful sync of 100 rows."""
    return dict(success=True, rows=100, file_size_mb=0.5)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def table_local():
    """TableConfig 't.local' with query_mode 'local'."""
    return _make_table_config(table_id="t.local", name="local_table", query_mode="local")
@pytest.fixture
def table_remote():
    """TableConfig 't.remote' with query_mode 'remote'."""
    return _make_table_config(table_id="t.remote", name="remote_table", query_mode="remote")
@pytest.fixture
def table_hybrid():
    """TableConfig 't.hybrid' with query_mode 'hybrid'."""
    return _make_table_config(table_id="t.hybrid", name="hybrid_table", query_mode="hybrid")
@pytest.fixture
def all_tables(table_local, table_remote, table_hybrid):
    """The three table fixtures as a list: local, remote, hybrid."""
    combined = [table_local]
    combined.append(table_remote)
    combined.append(table_hybrid)
    return combined
@pytest.fixture
def mock_config(all_tables):
    """Mock Config exposing ``all_tables`` via .tables and get_table_config()."""
    config = MagicMock()
    config.tables = all_tables
    config.get_metadata_path.return_value = MagicMock()  # Path-like

    # Look up by id at call time; unknown ids resolve to None.
    config.get_table_config.side_effect = lambda tid: next(
        (table for table in all_tables if table.id == tid),
        None,
    )
    return config
@pytest.fixture
def mock_data_source():
    """Mock DataSource whose sync_table() returns a successful result."""
    source = MagicMock()
    source.sync_table.return_value = _successful_sync_result()
    return source
@pytest.fixture
def sync_manager(mock_config, mock_data_source):
    """Yield a DataSyncManager built while config, data source and state are patched."""
    config_patch = patch("src.data_sync.get_config", return_value=mock_config)
    source_patch = patch(
        "src.data_sync.create_data_source", return_value=mock_data_source
    )
    state_patch = patch("src.data_sync.SyncState")
    with config_patch, source_patch, state_patch:
        manager = DataSyncManager()
        # Schema generation is not under test, so it is stubbed out.
        manager._generate_schema_yaml = MagicMock()
        yield manager
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestSyncAllRemoteSkipping:
    """sync_all() must leave remote tables out and sync local and hybrid ones."""

    @staticmethod
    def _synced_ids(data_source):
        """Ids of the table configs passed positionally to sync_table()."""
        return [
            recorded.args[0].id
            for recorded in data_source.sync_table.call_args_list
        ]

    def test_remote_table_not_synced(self, sync_manager, mock_data_source, table_remote):
        """The remote table is never handed to data_source.sync_table."""
        sync_manager.sync_all()
        assert table_remote.id not in self._synced_ids(mock_data_source)

    def test_local_table_is_synced(self, sync_manager, mock_data_source, table_local):
        """The local table is handed to data_source.sync_table."""
        sync_manager.sync_all()
        assert table_local.id in self._synced_ids(mock_data_source)

    def test_hybrid_table_is_synced(self, sync_manager, mock_data_source, table_hybrid):
        """The hybrid table is handed to data_source.sync_table."""
        sync_manager.sync_all()
        assert table_hybrid.id in self._synced_ids(mock_data_source)

    def test_sync_call_count(self, sync_manager, mock_data_source):
        """Three tables with one remote result in exactly two sync_table calls."""
        sync_manager.sync_all()
        assert mock_data_source.sync_table.call_count == 2

    def test_results_exclude_remote(self, sync_manager, table_remote):
        """The returned results have no entry for the remote table."""
        outcome = sync_manager.sync_all()
        assert table_remote.id not in outcome

    def test_results_include_local_and_hybrid(
        self, sync_manager, table_local, table_hybrid
    ):
        """The returned results have entries for the local and hybrid tables."""
        outcome = sync_manager.sync_all()
        assert table_local.id in outcome
        assert table_hybrid.id in outcome
class TestSyncAllAllRemote:
    """Edge case: the configuration contains only remote tables."""

    def test_no_sync_calls_when_all_remote(self, mock_config, mock_data_source):
        mock_config.tables = [
            _make_table_config("t.r1", "remote1", query_mode="remote"),
            _make_table_config("t.r2", "remote2", query_mode="remote"),
        ]
        config_patch = patch("src.data_sync.get_config", return_value=mock_config)
        source_patch = patch(
            "src.data_sync.create_data_source", return_value=mock_data_source
        )
        state_patch = patch("src.data_sync.SyncState")
        with config_patch, source_patch, state_patch:
            manager = DataSyncManager()
            manager._generate_schema_yaml = MagicMock()
            outcome = manager.sync_all()

        assert mock_data_source.sync_table.call_count == 0
        assert outcome == {}
class TestSyncAllNoRemote:
    """Edge case: the configuration contains no remote tables."""

    def test_all_tables_synced(self, mock_config, mock_data_source):
        mock_config.tables = [
            _make_table_config("t.l1", "local1", query_mode="local"),
            _make_table_config("t.l2", "local2", query_mode="local"),
        ]
        config_patch = patch("src.data_sync.get_config", return_value=mock_config)
        source_patch = patch(
            "src.data_sync.create_data_source", return_value=mock_data_source
        )
        state_patch = patch("src.data_sync.SyncState")
        with config_patch, source_patch, state_patch:
            manager = DataSyncManager()
            manager._generate_schema_yaml = MagicMock()
            outcome = manager.sync_all()

        assert mock_data_source.sync_table.call_count == 2
        assert "t.l1" in outcome
        assert "t.l2" in outcome
class TestSyncAllWithTableFilter:
    """Remote filtering also applies when sync_all() gets an explicit table list."""

    def test_explicit_remote_table_still_skipped(
        self, sync_manager, mock_data_source, table_remote
    ):
        """Naming a remote table explicitly still produces no sync_table call."""
        requested = [table_remote.id]
        sync_manager.sync_all(tables=requested)
        assert mock_data_source.sync_table.call_count == 0

    def test_explicit_local_table_synced(
        self, sync_manager, mock_data_source, table_local
    ):
        """Naming a local table explicitly syncs exactly that table."""
        requested = [table_local.id]
        sync_manager.sync_all(tables=requested)
        assert mock_data_source.sync_table.call_count == 1
        first_call = mock_data_source.sync_table.call_args_list[0]
        assert first_call.args[0].id == table_local.id

View file

@ -7,11 +7,11 @@ from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock
from dataclasses import dataclass
from src.config import TableConfig
from connectors.openmetadata.enricher import (
CatalogEnricher,
CatalogTableData,
CatalogColumnData,
TableConfig,
)
@ -21,9 +21,6 @@ def sample_table_config():
return TableConfig(
id="prj-grp-dataview-prod-1ff9.marketing.roi_datamart_v2",
name="roi_datamart_v2",
description="ROI metrics",
primary_key="id",
sync_strategy="full_refresh",
)
@ -111,9 +108,6 @@ def test_enrich_table_disabled():
table_config = TableConfig(
id="test.table",
name="test",
description="Test",
primary_key="id",
sync_strategy="full_refresh",
)
result = enricher.enrich_table(table_config)
@ -145,9 +139,6 @@ def test_enrich_table_cache_hit():
table_config = TableConfig(
id="prj-grp-dataview-prod-1ff9.marketing.test",
name="test",
description="Test",
primary_key="id",
sync_strategy="full_refresh",
)
result = enricher.enrich_table(table_config)
@ -199,9 +190,6 @@ def test_derive_fqn_auto():
table_config = TableConfig(
id="prj-grp-dataview-prod-1ff9.marketing.roi_datamart_v2",
name="roi_datamart_v2",
description="Test",
primary_key="id",
sync_strategy="full_refresh",
)
fqn = enricher._derive_fqn(table_config)
@ -223,9 +211,6 @@ def test_derive_fqn_explicit_override():
table_config = TableConfig(
id="prj-grp-dataview-prod-1ff9.marketing.roi_datamart_v2",
name="roi_datamart_v2",
description="Test",
primary_key="id",
sync_strategy="full_refresh",
)
table_config.catalog_fqn = "bigquery.custom.fqn.override"
@ -390,9 +375,6 @@ def test_enrich_table_http_error_graceful():
table_config = TableConfig(
id="test.table",
name="test",
description="Test",
primary_key="id",
sync_strategy="full_refresh",
)
# Should return None, not raise

View file

@ -1,779 +0,0 @@
"""Tests for remote_query module - hybrid local Parquet + remote BigQuery queries.
Tests cover:
- CLI argument parsing (_parse_register_bq, build_parser)
- Local view setup (_setup_local_views via create_local_views)
- BQ registration with safety checks (_validate_bq_result_size, _estimate_memory_mb, _register_bq_views)
- Output formatting (_print_table, _format_output)
- End-to-end local-only queries (no BQ mocking needed)
"""
import argparse
import csv
import json
import os
from io import StringIO
from pathlib import Path
from unittest.mock import MagicMock, patch
import duckdb
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from src.remote_query import (
RemoteQueryError,
_estimate_memory_mb,
_format_output,
_parse_register_bq,
_print_table,
_register_bq_views,
_validate_bq_result_size,
build_parser,
execute_remote_query,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def tmp_local_project(tmp_path):
"""Create a minimal project with docs/data_description.md and parquet files.
Layout:
tmp_path/
docs/data_description.md (YAML with local + remote + hybrid tables)
server/parquet/crm_data/orders.parquet
server/parquet/crm_data/products.parquet
server/parquet/crm_data/inventory.parquet
Returns (project_root, data_dir) where project_root is tmp_path and
data_dir is str(tmp_path / "server").
"""
docs_dir = tmp_path / "docs"
docs_dir.mkdir()
# Markdown document with one fenced YAML block: a folder mapping plus four
# table entries (orders, products, traffic, inventory).
data_description = """\
# Data Description
```yaml
folder_mapping:
in.c-crm: crm_data
tables:
- id: "in.c-crm.orders"
name: "orders"
description: "Order data"
primary_key: "order_id"
sync_strategy: "full_refresh"
- id: "in.c-crm.products"
name: "products"
description: "Product catalog"
primary_key: "product_id"
sync_strategy: "full_refresh"
- id: "prj-grp-dataview-prod-1ff9.supply.traffic"
name: "traffic"
description: "Remote BQ traffic table"
primary_key: "id"
query_mode: "remote"
- id: "in.c-crm.inventory"
name: "inventory"
description: "Hybrid inventory"
primary_key: "id"
sync_strategy: "full_refresh"
query_mode: "hybrid"
```
"""
(docs_dir / "data_description.md").write_text(data_description)
# Create parquet files for local tables
crm_dir = tmp_path / "server" / "parquet" / "crm_data"
crm_dir.mkdir(parents=True)
# Five order rows referencing three distinct product ids.
orders_table = pa.table({
"order_id": [1, 2, 3, 4, 5],
"amount": [10.0, 20.0, 30.0, 40.0, 50.0],
"product_id": [101, 102, 101, 103, 102],
})
pq.write_table(orders_table, crm_dir / "orders.parquet")
products_table = pa.table({
"product_id": [101, 102, 103],
"name": ["Widget", "Gadget", "Doohickey"],
})
pq.write_table(products_table, crm_dir / "products.parquet")
# Create parquet for hybrid table
inventory_table = pa.table({
"id": [1, 2],
"stock": [100, 200],
})
pq.write_table(inventory_table, crm_dir / "inventory.parquet")
# No parquet file is written for the "traffic" entry.
data_dir = str(tmp_path / "server")
return tmp_path, data_dir
@pytest.fixture
def duckdb_conn():
    """Yield an in-memory DuckDB connection and close it once the test is done."""
    connection = duckdb.connect(":memory:")
    yield connection
    connection.close()
class _DuckDBConnectionProxy:
    """Wrap a DuckDB connection and tolerate SET commands it does not recognise.

    Some DuckDB versions may not support settings such as 'statement_timeout'.
    For statements starting with "SET ", a duckdb.CatalogException is
    swallowed and None is returned; every other statement and every other
    attribute is delegated to the wrapped connection unchanged.
    """

    def __init__(self, conn):
        # Stored via object.__setattr__, bypassing any attribute hooks.
        object.__setattr__(self, "_conn", conn)

    def execute(self, sql, *args, **kwargs):
        is_set_command = isinstance(sql, str) and sql.strip().upper().startswith("SET ")
        if not is_set_command:
            return self._conn.execute(sql, *args, **kwargs)
        try:
            return self._conn.execute(sql, *args, **kwargs)
        except duckdb.CatalogException:
            return None

    def __getattr__(self, name):
        # Only reached for names not found on the proxy itself.
        return getattr(self._conn, name)
def _patched_duckdb_connect(*args, **kwargs):
    """Open a DuckDB connection with the given arguments and return it proxied."""
    return _DuckDBConnectionProxy(duckdb.connect(*args, **kwargs))
# ---------------------------------------------------------------------------
# Tests: CLI argument parsing
# ---------------------------------------------------------------------------
class TestCLIArgParsing:
    """Argument handling of _parse_register_bq() and build_parser()."""

    def test_sql_defaults_to_none(self):
        """--sql may be left out when --stdin is given."""
        parsed = build_parser().parse_args(["--stdin"])
        assert parsed.sql is None
        assert parsed.stdin is True

    def test_register_bq_parsing(self):
        """'alias=SELECT ...' becomes an (alias, sql) tuple."""
        expected = ("traffic", "SELECT report_date FROM `proj.ds.table`")
        actual = _parse_register_bq("traffic=SELECT report_date FROM `proj.ds.table`")
        assert actual == expected

    def test_register_bq_invalid_format(self):
        """A value without '=' raises ArgumentTypeError."""
        with pytest.raises(argparse.ArgumentTypeError, match="Invalid --register-bq format"):
            _parse_register_bq("no_equals_sign_here")

    def test_register_bq_empty_sql(self):
        """An alias followed by '=' and no SQL raises ArgumentTypeError."""
        with pytest.raises(argparse.ArgumentTypeError, match="Empty SQL"):
            _parse_register_bq("alias=")

    def test_register_bq_empty_alias(self):
        """A value starting with '=' (no alias) raises ArgumentTypeError."""
        with pytest.raises(argparse.ArgumentTypeError, match="Invalid --register-bq format"):
            _parse_register_bq("=SELECT 1")

    def test_multiple_register_bq(self):
        """Repeated --register-bq options accumulate in order."""
        argv = [
            "--sql", "SELECT 1",
            "--register-bq", "t1=SELECT a FROM x",
            "--register-bq", "t2=SELECT b FROM y",
        ]
        parsed = build_parser().parse_args(argv)
        assert len(parsed.bq_registrations) == 2
        assert parsed.bq_registrations[0] == ("t1", "SELECT a FROM x")
        assert parsed.bq_registrations[1] == ("t2", "SELECT b FROM y")

    def test_default_format_is_none(self):
        """Without --format, args.fmt is None."""
        parsed = build_parser().parse_args(["--sql", "SELECT 1"])
        assert parsed.fmt is None

    def test_explicit_format(self):
        """--format csv is stored in args.fmt."""
        parsed = build_parser().parse_args(["--sql", "SELECT 1", "--format", "csv"])
        assert parsed.fmt == "csv"

    def test_invalid_format_rejected(self):
        """An unsupported --format value makes the parser exit."""
        cli = build_parser()
        with pytest.raises(SystemExit):
            cli.parse_args(["--sql", "SELECT 1", "--format", "xml"])

    def test_no_register_bq_yields_empty_list(self):
        """Without --register-bq, bq_registrations is an empty list."""
        parsed = build_parser().parse_args(["--sql", "SELECT 1"])
        assert parsed.bq_registrations == []

    def test_register_bq_sql_with_equals(self):
        """Only the first '=' separates alias from SQL."""
        parsed_pair = _parse_register_bq("view=SELECT * FROM t WHERE col = 5")
        assert parsed_pair[0] == "view"
        assert parsed_pair[1] == "SELECT * FROM t WHERE col = 5"

    def test_quiet_flag(self):
        """--quiet sets args.quiet to True."""
        parsed = build_parser().parse_args(["--sql", "SELECT 1", "--quiet"])
        assert parsed.quiet is True

    def test_max_rows_parsing(self):
        """--max-rows is converted to an integer."""
        parsed = build_parser().parse_args(["--sql", "SELECT 1", "--max-rows", "500"])
        assert parsed.max_rows == 500
# ---------------------------------------------------------------------------
# Tests: Local view setup
# ---------------------------------------------------------------------------
class TestLocalViewSetup:
    """Local view creation by _setup_local_views against the tmp project fixture."""

    @staticmethod
    def _build_views(project, conn):
        """Run _setup_local_views with find_project_root patched to the fixture root."""
        project_root, data_dir = project
        with patch("scripts.duckdb_manager.find_project_root", return_value=project_root):
            from src.remote_query import _setup_local_views
            return _setup_local_views(conn, data_dir, quiet=True)

    def test_creates_views_from_parquet(self, tmp_local_project, duckdb_conn):
        """Local tables become queryable DuckDB views."""
        created = self._build_views(tmp_local_project, duckdb_conn)
        assert "orders" in created
        assert "products" in created
        # Verify data is queryable
        row_count = duckdb_conn.execute("SELECT COUNT(*) FROM orders").fetchone()[0]
        assert row_count == 5

    def test_skips_remote_tables(self, tmp_local_project, duckdb_conn):
        """Tables with query_mode 'remote' get no local view."""
        created = self._build_views(tmp_local_project, duckdb_conn)
        assert "traffic" not in created
        # Verify the remote table is not queryable
        visible = [row[0] for row in duckdb_conn.execute("SHOW TABLES").fetchall()]
        assert "traffic" not in visible

    def test_includes_hybrid_tables(self, tmp_local_project, duckdb_conn):
        """Tables with query_mode 'hybrid' get a local view."""
        created = self._build_views(tmp_local_project, duckdb_conn)
        assert "inventory" in created
        row_count = duckdb_conn.execute("SELECT COUNT(*) FROM inventory").fetchone()[0]
        assert row_count == 2
# ---------------------------------------------------------------------------
# Tests: BQ registration with safety checks
# ---------------------------------------------------------------------------
class TestBQRegistration:
    """Test BQ result validation and registration (mocked BigQuery)."""

    @staticmethod
    def _make_mock_bq_client(count_result: int = 100, schema_fields: int = 5) -> MagicMock:
        """Create a mock BQ client that returns a controlled count and schema.

        Args:
            count_result: Value returned when the single COUNT(*) row is indexed.
            schema_fields: Number of mock fields in the ``LIMIT 0`` result schema.

        Returns:
            A ``MagicMock`` whose ``query(sql)`` returns a mock job. The job's
            ``result()`` is shaped by the SQL text: SQL starting with
            ``SELECT COUNT(*)`` yields one row, SQL containing ``LIMIT 0``
            exposes ``.schema``; any other SQL gets an unconfigured mock job.
        """
        mock_client = MagicMock()

        # Single row of the COUNT(*) result; indexing it yields the count.
        count_row = MagicMock()
        count_row.__getitem__ = MagicMock(return_value=count_result)

        # Schema for the LIMIT 0 probe: one mock per column.
        mock_schema_fields = [MagicMock() for _ in range(schema_fields)]

        # side_effect lets one client answer both kinds of query.
        def query_side_effect(sql):
            job = MagicMock()
            if sql.startswith("SELECT COUNT(*)"):
                result = MagicMock()
                # A fresh iterator per call, so repeated queries are not exhausted.
                result.__iter__ = MagicMock(return_value=iter([count_row]))
                job.result.return_value = result
            elif "LIMIT 0" in sql:
                result = MagicMock()
                result.schema = mock_schema_fields
                job.result.return_value = result
            return job

        mock_client.query.side_effect = query_side_effect
        return mock_client

    def test_count_check_blocks_large_result(self):
        """BQ sub-query exceeding max_rows should raise RemoteQueryError."""
        mock_client = self._make_mock_bq_client(count_result=1_000_000)

        with pytest.raises(RemoteQueryError, match="would return 1,000,000 rows"):
            _validate_bq_result_size(
                bq_client=mock_client,
                sql="SELECT * FROM big_table",
                alias="big_table",
                max_rows=500_000,
            )

    def test_validates_small_result_passes(self):
        """BQ sub-query within limits should return the row count."""
        mock_client = self._make_mock_bq_client(count_result=1000)

        row_count = _validate_bq_result_size(
            bq_client=mock_client,
            sql="SELECT * FROM small_table",
            alias="small_table",
            max_rows=500_000,
        )

        assert row_count == 1000

    def test_memory_estimate_blocks_huge_result(self):
        """_register_bq_views should refuse when estimated memory exceeds 2 GB."""
        # Passes the count check (500K <= 1M) but fails the memory check:
        # 500K rows x 100 cols x 50 bytes/cell = ~2384 MB > 2048 MB limit
        mock_client = self._make_mock_bq_client(count_result=500_000, schema_fields=100)

        conn = duckdb.connect(":memory:")
        try:
            with patch("src.remote_query._create_bq_client", return_value=mock_client), \
                    patch.dict(os.environ, {"BIGQUERY_PROJECT": "test-proj"}):
                with pytest.raises(RemoteQueryError, match="estimated memory"):
                    _register_bq_views(
                        conn=conn,
                        registrations=[("huge", "SELECT * FROM huge_table")],
                        max_bq_rows=1_000_000,
                        timeout_seconds=60,
                        quiet=True,
                    )
        finally:
            conn.close()

    def test_registers_small_result(self):
        """BQ sub-query within all limits should register successfully."""
        # Small result: 100 rows x 5 cols = ~0.02 MB
        mock_client = self._make_mock_bq_client(count_result=100, schema_fields=5)

        conn = duckdb.connect(":memory:")
        try:
            # register_bq_table is patched to report the row count without
            # touching BigQuery.
            with patch("src.remote_query._create_bq_client", return_value=mock_client), \
                    patch("src.remote_query.register_bq_table", return_value=100) as mock_reg, \
                    patch.dict(os.environ, {"BIGQUERY_PROJECT": "test-proj"}):
                results = _register_bq_views(
                    conn=conn,
                    registrations=[("small_view", "SELECT * FROM small_table")],
                    max_bq_rows=500_000,
                    timeout_seconds=60,
                    quiet=True,
                )

                assert results == {"small_view": 100}
                mock_reg.assert_called_once()
        finally:
            conn.close()

    def test_missing_bigquery_project_raises(self):
        """Missing BIGQUERY_PROJECT env var should raise RemoteQueryError."""
        conn = duckdb.connect(":memory:")
        try:
            # clear=True empties os.environ for the duration of the block.
            with patch.dict(os.environ, {}, clear=True):
                with pytest.raises(RemoteQueryError, match="BIGQUERY_PROJECT"):
                    _register_bq_views(
                        conn=conn,
                        registrations=[("v", "SELECT 1")],
                        max_bq_rows=100,
                        timeout_seconds=60,
                    )
        finally:
            conn.close()

    def test_empty_registrations_returns_empty(self):
        """Empty registration list should return empty dict without BQ calls."""
        conn = duckdb.connect(":memory:")
        try:
            result = _register_bq_views(
                conn=conn,
                registrations=[],
                max_bq_rows=100,
                timeout_seconds=60,
            )
            assert result == {}
        finally:
            conn.close()
# ---------------------------------------------------------------------------
# Tests: Memory estimation
# ---------------------------------------------------------------------------
class TestMemoryEstimation:
    """Check the arithmetic expected from ``_estimate_memory_mb``."""

    def test_small_table(self):
        """100 rows x 10 cols is expected to be 50_000 bytes, about 0.048 MB."""
        expected_mb = 50_000 / (1024 * 1024)
        estimate = _estimate_memory_mb(100, 10)
        assert abs(estimate - expected_mb) < 0.001

    def test_large_table(self):
        """1M rows x 50 cols x 50 bytes is expected to be about 2384 MB."""
        expected_mb = (1_000_000 * 50 * 50) / (1024 * 1024)
        estimate = _estimate_memory_mb(1_000_000, 50)
        assert abs(estimate - expected_mb) < 0.01

    def test_zero_rows(self):
        """No rows means an estimate of exactly 0 MB."""
        estimate = _estimate_memory_mb(0, 50)
        assert estimate == 0.0

    def test_zero_columns(self):
        """No columns means an estimate of exactly 0 MB."""
        estimate = _estimate_memory_mb(1000, 0)
        assert estimate == 0.0
# ---------------------------------------------------------------------------
# Tests: Output formatting
# ---------------------------------------------------------------------------
class TestOutputFormatting:
    """Cover ``_print_table`` and ``_format_output`` for each output format."""

    @staticmethod
    def _remote_config(output_dir):
        """Build the dict the patched ``_load_remote_query_config`` returns."""
        return {
            "output_dir": output_dir,
            "timeout_seconds": 300,
            "max_result_rows": 100_000,
            "max_bq_registration_rows": 500_000,
            "default_format": "table",
        }

    def test_table_format_aligned(self, capsys):
        """Table output has a header, a separator, the data rows and a footer."""
        header_names = ["id", "name", "value"]
        records = [(1, "alice", 100), (2, "bob", 200)]

        _print_table(header_names, records)

        printed = capsys.readouterr().out
        rendered = printed.strip().split("\n")

        # First line carries every column name.
        assert "id" in rendered[0]
        assert "name" in rendered[0]
        assert "value" in rendered[0]
        # Second line is the separator.
        assert "-+-" in rendered[1]
        # Then one line per record, in order.
        assert "alice" in rendered[2]
        assert "bob" in rendered[3]
        # The footer reports the row count.
        assert "(2 rows)" in printed

    def test_table_format_empty_result(self, capsys):
        """No rows prints the '(empty result)' marker."""
        _print_table(["col1"], [])
        printed = capsys.readouterr().out
        assert "(empty result)" in printed

    def test_table_format_null_values(self, capsys):
        """A None cell is shown as 'NULL'."""
        _print_table(["col"], [(None,)])
        printed = capsys.readouterr().out
        assert "NULL" in printed

    def test_csv_format(self, tmp_path, duckdb_conn):
        """CSV written to a file has a header line followed by the data rows."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id, 'hello' AS msg")
        csv_file = str(tmp_path / "result.csv")

        _format_output(
            conn=duckdb_conn,
            sql="SELECT * FROM test",
            fmt="csv",
            output_path=csv_file,
            max_rows=1000,
        )

        with open(csv_file) as handle:
            csv_reader = csv.reader(handle)
            header = next(csv_reader)
            records = list(csv_reader)

        assert header == ["id", "msg"]
        assert len(records) == 1
        assert records[0][1] == "hello"

    def test_csv_format_to_stdout(self, capsys, duckdb_conn):
        """CSV without an output path goes to stdout."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 42 AS val")

        _format_output(
            conn=duckdb_conn,
            sql="SELECT * FROM test",
            fmt="csv",
            output_path=None,
            max_rows=1000,
        )

        printed = capsys.readouterr().out
        assert "val" in printed
        assert "42" in printed

    def test_json_format(self, tmp_path, duckdb_conn):
        """JSON written to a file is a list with one dict per row."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id, 'world' AS msg")
        json_file = str(tmp_path / "result.json")

        _format_output(
            conn=duckdb_conn,
            sql="SELECT * FROM test",
            fmt="json",
            output_path=json_file,
            max_rows=1000,
        )

        with open(json_file) as handle:
            payload = json.load(handle)

        assert len(payload) == 1
        first = payload[0]
        assert first["id"] == 1
        assert first["msg"] == "world"

    def test_json_format_to_stdout(self, capsys, duckdb_conn):
        """JSON without an output path is printed to stdout."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 99 AS num")

        _format_output(
            conn=duckdb_conn,
            sql="SELECT * FROM test",
            fmt="json",
            output_path=None,
            max_rows=1000,
        )

        payload = json.loads(capsys.readouterr().out)
        assert payload[0]["num"] == 99

    def test_parquet_write(self, tmp_path, duckdb_conn):
        """Parquet output to an explicit path produces a readable file."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id, 2.5 AS val")
        parquet_file = str(tmp_path / "output" / "result.parquet")
        fake_config = self._remote_config(str(tmp_path / "default_output"))

        with patch("src.remote_query._load_remote_query_config", return_value=fake_config):
            _format_output(
                conn=duckdb_conn,
                sql="SELECT * FROM test",
                fmt="parquet",
                output_path=parquet_file,
                max_rows=1000,
            )

        assert Path(parquet_file).exists()

        # Round-trip through pyarrow to check shape and column names.
        table = pq.read_table(parquet_file)
        assert table.num_rows == 1
        assert table.num_columns == 2
        assert table.column_names == ["id", "val"]

    def test_parquet_default_path(self, tmp_path, duckdb_conn):
        """Parquet without an output path lands in the configured output dir."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS x")
        default_dir = str(tmp_path / "default_output")
        fake_config = self._remote_config(default_dir)

        with patch("src.remote_query._load_remote_query_config", return_value=fake_config):
            _format_output(
                conn=duckdb_conn,
                sql="SELECT * FROM test",
                fmt="parquet",
                output_path=None,
                max_rows=1000,
            )

        assert (Path(default_dir) / "result.parquet").exists()

    def test_unknown_format_raises(self, duckdb_conn):
        """An unrecognised format name raises RemoteQueryError."""
        duckdb_conn.execute("CREATE TABLE test AS SELECT 1 AS id")

        with pytest.raises(RemoteQueryError, match="Unknown format"):
            _format_output(
                conn=duckdb_conn,
                sql="SELECT * FROM test",
                fmt="xml",
                output_path=None,
                max_rows=1000,
            )
# ---------------------------------------------------------------------------
# Tests: End-to-end (local-only, no BQ mocking needed)
# ---------------------------------------------------------------------------
class TestEndToEnd:
    """Full ``execute_remote_query`` runs over local Parquet data, no BigQuery.

    ``duckdb.connect`` is swapped for ``_patched_duckdb_connect`` to cope with
    DuckDB version differences (statement_timeout may not be supported in all
    versions).
    """

    _CONFIG = {
        "timeout_seconds": 300,
        "max_result_rows": 100_000,
        "max_bq_registration_rows": 500_000,
        "default_format": "table",
        "output_dir": "/tmp/remote_query_test",
    }

    def _run(self, tmp_local_project, **kwargs):
        """Call ``execute_remote_query`` under the standard set of patches.

        ``config_overrides`` (optional keyword) is merged over ``_CONFIG``;
        every other keyword is forwarded, with defaults filled in for
        ``data_dir``, ``bq_registrations`` and ``quiet``.
        """
        root, parquet_dir = tmp_local_project
        overrides = kwargs.pop("config_overrides", {})
        effective_config = {**self._CONFIG, **overrides}

        kwargs.setdefault("data_dir", parquet_dir)
        kwargs.setdefault("bq_registrations", [])
        kwargs.setdefault("quiet", True)

        with patch("scripts.duckdb_manager.find_project_root", return_value=root), \
                patch("src.remote_query._load_remote_query_config", return_value=effective_config), \
                patch("src.remote_query.duckdb") as fake_duckdb:
            fake_duckdb.connect = _patched_duckdb_connect
            execute_remote_query(**kwargs)

    def test_local_only_query(self, tmp_local_project, capsys):
        """A COUNT over a local Parquet view is printed as a table."""
        self._run(
            tmp_local_project,
            sql="SELECT COUNT(*) AS cnt FROM orders",
            fmt="table",
        )

        printed = capsys.readouterr().out
        assert "cnt" in printed
        assert "5" in printed

    def test_local_join_query(self, tmp_local_project, capsys):
        """Two local tables can be joined and aggregated."""
        join_sql = (
            "SELECT p.name, SUM(o.amount) AS total "
            "FROM orders o JOIN products p ON o.product_id = p.product_id "
            "GROUP BY p.name ORDER BY total DESC"
        )
        self._run(tmp_local_project, sql=join_sql, fmt="json")

        payload = json.loads(capsys.readouterr().out)
        assert len(payload) == 3

        # Widget: orders 1,3 -> 10+30=40
        widget_row = next(entry for entry in payload if entry["name"] == "Widget")
        assert widget_row["total"] == 40.0

    def test_result_row_limit(self, tmp_local_project, capsys):
        """Output is cut off at ``max_rows``."""
        self._run(
            tmp_local_project,
            sql="SELECT * FROM orders ORDER BY order_id",
            fmt="table",
            max_rows=2,
            quiet=False,
            config_overrides={"max_result_rows": 2},
        )

        printed = capsys.readouterr().out
        # The table footer should report exactly 2 data rows.
        assert "(2 rows)" in printed

    def test_csv_output_to_file(self, tmp_local_project, tmp_path):
        """CSV output of a full run is written to the requested file."""
        csv_file = str(tmp_path / "result.csv")
        self._run(
            tmp_local_project,
            sql="SELECT order_id, amount FROM orders ORDER BY order_id",
            fmt="csv",
            output=csv_file,
        )

        with open(csv_file) as handle:
            records = list(csv.DictReader(handle))

        assert len(records) == 5
        first = records[0]
        assert first["order_id"] == "1"
        assert first["amount"] == "10.0"

    def test_hybrid_table_queryable(self, tmp_local_project, capsys):
        """A hybrid table can be used in a local query."""
        self._run(
            tmp_local_project,
            sql="SELECT SUM(stock) AS total_stock FROM inventory",
            fmt="json",
        )

        payload = json.loads(capsys.readouterr().out)
        assert payload[0]["total_stock"] == 300

    def test_quiet_mode_suppresses_stderr(self, tmp_local_project, capsys):
        """quiet=True keeps progress messages off stderr."""
        self._run(
            tmp_local_project,
            sql="SELECT COUNT(*) AS cnt FROM orders",
            fmt="table",
            quiet=True,
        )

        captured_err = capsys.readouterr().err
        # In quiet mode, _log_progress should not emit anything
        assert "Setting up" not in captured_err
        assert "local views" not in captured_err

View file

@ -1,363 +0,0 @@
"""Tests for the Table Registry module."""
import json
from pathlib import Path
import pytest
import yaml
from src.table_registry import ConflictError, TableRegistry
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def registry_path(tmp_path):
    """Path of the registry JSON file inside the temporary directory."""
    json_file = tmp_path / "table_registry.json"
    return json_file
@pytest.fixture
def registry(registry_path):
    """A ``TableRegistry`` built on ``registry_path`` with no tables added."""
    fresh_registry = TableRegistry(registry_path)
    return fresh_registry
@pytest.fixture
def sample_table():
    """Smallest table definition the tests register (full-refresh strategy)."""
    definition = {
        "id": "in.c-crm.company",
        "name": "company",
        "description": "Customer master data",
        "primary_key": "id",
        "sync_strategy": "full_refresh",
    }
    return definition
@pytest.fixture
def sample_table_incremental():
    """Table definition using the incremental strategy with partition settings."""
    definition = {
        "id": "in.c-crm.events",
        "name": "events",
        "description": "User events",
        "primary_key": "event_id",
        "sync_strategy": "incremental",
        "incremental_window_days": 14,
        "partition_by": "created_at",
        "partition_granularity": "month",
    }
    return definition
# ---------------------------------------------------------------------------
# Basic CRUD
# ---------------------------------------------------------------------------
class TestRegistryCRUD:
    """Register / read / update / unregister behaviour of ``TableRegistry``."""

    def test_empty_registry(self, registry):
        """A new registry lists no tables and reports version 0."""
        assert registry.list_tables() == []
        assert registry.version == 0

    def test_register_table(self, registry, sample_table):
        """Registering stores the table, records the author and bumps the version."""
        registry.register_table(sample_table, registered_by="admin@test.com")

        listed = registry.list_tables()
        assert len(listed) == 1
        only_entry = listed[0]
        assert only_entry["id"] == "in.c-crm.company"
        assert only_entry["registered_by"] == "admin@test.com"
        assert registry.version == 1

    def test_register_duplicate_raises(self, registry, sample_table):
        """Registering the same table twice raises ValueError."""
        registry.register_table(sample_table, registered_by="admin@test.com")

        with pytest.raises(ValueError, match="already registered"):
            registry.register_table(sample_table, registered_by="admin@test.com")

    def test_get_table(self, registry, sample_table):
        """A registered table can be fetched by id."""
        registry.register_table(sample_table, registered_by="admin@test.com")

        entry = registry.get_table("in.c-crm.company")
        assert entry is not None
        assert entry["name"] == "company"

    def test_get_table_not_found(self, registry):
        """Looking up an unknown id returns None."""
        missing = registry.get_table("nonexistent")
        assert missing is None

    def test_is_registered(self, registry, sample_table):
        """``is_registered`` is false before registration and true after."""
        table_id = "in.c-crm.company"
        assert not registry.is_registered(table_id)

        registry.register_table(sample_table, registered_by="admin@test.com")
        assert registry.is_registered(table_id)

    def test_unregister_table(self, registry, sample_table):
        """Unregistering removes the table from the registry."""
        registry.register_table(sample_table, registered_by="admin@test.com")
        registry.unregister_table("in.c-crm.company", unregistered_by="admin@test.com")

        assert not registry.is_registered("in.c-crm.company")
        assert registry.list_tables() == []

    def test_unregister_nonexistent_raises(self, registry):
        """Unregistering an unknown id raises ValueError."""
        with pytest.raises(ValueError, match="not registered"):
            registry.unregister_table("nonexistent")

    def test_update_table(self, registry, sample_table):
        """Updated fields are visible on the next lookup."""
        registry.register_table(sample_table, registered_by="admin@test.com")

        changes = {"description": "Updated description", "sync_strategy": "incremental"}
        registry.update_table("in.c-crm.company", changes, updated_by="admin@test.com")

        entry = registry.get_table("in.c-crm.company")
        assert entry["description"] == "Updated description"
        assert entry["sync_strategy"] == "incremental"

    def test_update_nonexistent_raises(self, registry):
        """Updating an unknown id raises ValueError."""
        with pytest.raises(ValueError, match="not registered"):
            registry.update_table("nonexistent", {"description": "x"})
# ---------------------------------------------------------------------------
# Validation
# ---------------------------------------------------------------------------
class TestValidation:
    """Rejection of malformed table definitions by ``register_table``."""

    def test_missing_id_raises(self, registry):
        """A definition without 'id' is rejected with ValueError."""
        definition = {"name": "x", "sync_strategy": "full_refresh", "primary_key": "id"}

        with pytest.raises(ValueError, match="must include 'id'"):
            registry.register_table(definition, registered_by="admin@test.com")

    def test_missing_name_raises(self, registry):
        """A definition without 'name' is rejected with ValueError."""
        definition = {"id": "x.y.z", "sync_strategy": "full_refresh", "primary_key": "id"}

        with pytest.raises(ValueError, match="must include 'name'"):
            registry.register_table(definition, registered_by="admin@test.com")

    def test_invalid_sync_strategy_raises(self, registry):
        """An unrecognised sync_strategy is rejected with ValueError."""
        definition = {
            "id": "x.y.z",
            "name": "z",
            "sync_strategy": "magic",
            "primary_key": "id",
        }

        with pytest.raises(ValueError, match="Invalid sync_strategy"):
            registry.register_table(definition, registered_by="admin@test.com")
# ---------------------------------------------------------------------------
# Optimistic locking
# ---------------------------------------------------------------------------
class TestOptimisticLocking:
    """``expected_version`` checks on the mutating registry operations."""

    def test_register_with_wrong_version_raises(self, registry, sample_table):
        """Registering against a stale version raises ConflictError."""
        stale_version = 99

        with pytest.raises(ConflictError, match="Version conflict"):
            registry.register_table(
                sample_table,
                registered_by="admin@test.com",
                expected_version=stale_version,
            )

    def test_register_with_correct_version(self, registry, sample_table):
        """Registering against the current version succeeds and bumps it."""
        current_version = 0
        registry.register_table(
            sample_table,
            registered_by="admin@test.com",
            expected_version=current_version,
        )

        assert registry.version == 1

    def test_unregister_with_wrong_version_raises(self, registry, sample_table):
        """Unregistering with the pre-registration version raises ConflictError."""
        registry.register_table(sample_table, registered_by="admin@test.com")

        with pytest.raises(ConflictError):
            registry.unregister_table("in.c-crm.company", expected_version=0)

    def test_update_with_wrong_version_raises(self, registry, sample_table):
        """Updating with the pre-registration version raises ConflictError."""
        registry.register_table(sample_table, registered_by="admin@test.com")

        with pytest.raises(ConflictError):
            registry.update_table(
                "in.c-crm.company",
                {"description": "x"},
                expected_version=0,
            )
# ---------------------------------------------------------------------------
# Persistence
# ---------------------------------------------------------------------------
class TestPersistence:
    """What the registry writes to disk and reads back."""

    def test_save_and_reload(self, registry_path, sample_table):
        """A second registry on the same path sees the first one's data."""
        writer = TableRegistry(registry_path)
        writer.register_table(sample_table, registered_by="admin@test.com")

        # A new instance on the same path has to load from the file.
        reader = TableRegistry(registry_path)
        assert len(reader.list_tables()) == 1
        assert reader.get_table("in.c-crm.company")["name"] == "company"
        assert reader.version == 1

    def test_json_format(self, registry_path, sample_table):
        """The file is JSON with ``_metadata`` and ``tables`` sections."""
        writer = TableRegistry(registry_path)
        writer.register_table(sample_table, registered_by="admin@test.com")

        with open(registry_path) as handle:
            stored = json.load(handle)

        assert "_metadata" in stored
        assert "tables" in stored
        assert stored["_metadata"]["version"] == 1
        assert len(stored["tables"]) == 1
# ---------------------------------------------------------------------------
# Folder mapping
# ---------------------------------------------------------------------------
class TestFolderMapping:
    """Bucket-to-folder mapping stored in the registry."""

    def test_set_and_get(self, registry):
        """A mapping that was set is returned by ``get_folder_mapping``."""
        registry.set_folder_mapping("in.c-crm", "crm")

        mapping = registry.get_folder_mapping()
        assert mapping == {"in.c-crm": "crm"}

    def test_persists(self, registry_path):
        """A mapping set on one instance is seen by another on the same path."""
        writer = TableRegistry(registry_path)
        writer.set_folder_mapping("in.c-crm", "crm")

        reader = TableRegistry(registry_path)
        assert reader.get_folder_mapping() == {"in.c-crm": "crm"}
# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
class TestGeneration:
    """Tests for ``TableRegistry.generate_data_description_md``."""

    @staticmethod
    def _find_yaml_block(content):
        """Return the regex match for the first fenced ``yaml`` block, or None.

        Group 1 of the match is the text between the opening and closing fence.
        """
        # Function-scope import: ``re`` is not among the imports visible at the
        # top of this file.
        import re

        # DOTALL lets ``.`` span newlines so the whole multi-line block matches.
        return re.search(r"```yaml\n(.*?)```", content, re.DOTALL)

    def test_generate_data_description_md(self, registry, sample_table, tmp_path):
        """Generated markdown has the header markers and a parseable YAML block."""
        registry.register_table(sample_table, registered_by="admin@test.com")
        registry.set_folder_mapping("in.c-crm", "crm")

        output = tmp_path / "data_description.md"
        registry.generate_data_description_md(output)

        content = output.read_text()

        # Check header
        assert "AUTO-GENERATED" in content
        assert "checksum: sha256:" in content

        # Check YAML block is parseable
        yaml_match = self._find_yaml_block(content)
        assert yaml_match
        yaml_data = yaml.safe_load(yaml_match.group(1))
        assert len(yaml_data["tables"]) == 1
        assert yaml_data["tables"][0]["id"] == "in.c-crm.company"
        assert yaml_data["folder_mapping"] == {"in.c-crm": "crm"}

    def test_generate_includes_incremental_fields(
        self, registry, sample_table_incremental, tmp_path
    ):
        """Incremental/partition settings appear in the generated YAML."""
        registry.register_table(sample_table_incremental, registered_by="admin@test.com")

        output = tmp_path / "data_description.md"
        registry.generate_data_description_md(output)

        content = output.read_text()
        yaml_match = self._find_yaml_block(content)
        # Fail with a clear assertion instead of an AttributeError on None.
        assert yaml_match
        yaml_data = yaml.safe_load(yaml_match.group(1))
        table = yaml_data["tables"][0]
        assert table["partition_by"] == "created_at"
        assert table["partition_granularity"] == "month"
        assert table["incremental_window_days"] == 14
# ---------------------------------------------------------------------------
# Migration
# ---------------------------------------------------------------------------
class TestMigration:
    """Tests for ``TableRegistry.import_from_data_description``."""

    def test_import_from_data_description(self, tmp_path):
        """Importing a markdown file with a YAML block yields a populated registry.

        Checks that both tables and the folder mapping are imported and that
        the written registry JSON carries a ``migrated_from`` metadata key.
        """
        # Create a fake data_description.md
        # NOTE(review): the YAML inside this literal shows no nesting
        # indentation here; it looks like leading whitespace was lost in this
        # rendering. Confirm against the real file before relying on it.
        md_content = """# Data Description
```yaml
folder_mapping:
in.c-crm: crm
tables:
- id: in.c-crm.company
name: company
description: Companies
primary_key: id
sync_strategy: full_refresh
- id: in.c-crm.contact
name: contact
description: Contacts
primary_key: id
sync_strategy: incremental
incremental_window_days: 7
```
"""
        md_path = tmp_path / "data_description.md"
        md_path.write_text(md_content)
        registry_path = tmp_path / "table_registry.json"
        registry = TableRegistry.import_from_data_description(md_path, registry_path)
        assert len(registry.list_tables()) == 2
        assert registry.is_registered("in.c-crm.company")
        assert registry.is_registered("in.c-crm.contact")
        assert registry.get_folder_mapping() == {"in.c-crm": "crm"}
        # Check migrated_from marker
        with open(registry_path) as f:
            data = json.load(f)
        assert "migrated_from" in data["_metadata"]

    def test_import_no_yaml_raises(self, tmp_path):
        """A markdown file without any YAML block raises ValueError."""
        md_path = tmp_path / "data_description.md"
        md_path.write_text("# Empty file\nNo YAML here.")
        with pytest.raises(ValueError, match="No YAML blocks"):
            TableRegistry.import_from_data_description(
                md_path, tmp_path / "registry.json"
            )

    def test_import_file_not_found_raises(self, tmp_path):
        """A missing source markdown file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            TableRegistry.import_from_data_description(
                tmp_path / "nonexistent.md", tmp_path / "registry.json"
            )
# ---------------------------------------------------------------------------
# Audit log
# ---------------------------------------------------------------------------
class TestAuditLog:
    """Entries appended to ``registry_audit.log`` next to the registry file."""

    def test_register_writes_audit(self, registry, sample_table):
        """Registering appends a JSON line with action 'register'."""
        registry.register_table(sample_table, registered_by="admin@test.com")

        log_file = registry.registry_path.parent / "registry_audit.log"
        assert log_file.exists()

        log_lines = log_file.read_text().strip().split("\n")
        assert len(log_lines) >= 1

        newest = json.loads(log_lines[-1])
        assert newest["action"] == "register"
        assert newest["table_id"] == "in.c-crm.company"

    def test_unregister_writes_audit(self, registry, sample_table):
        """Unregistering appends a JSON line with action 'unregister'."""
        registry.register_table(sample_table, registered_by="admin@test.com")
        registry.unregister_table("in.c-crm.company", unregistered_by="admin@test.com")

        log_file = registry.registry_path.parent / "registry_audit.log"
        log_lines = log_file.read_text().strip().split("\n")

        newest = json.loads(log_lines[-1])
        assert newest["action"] == "unregister"

View file

@ -611,19 +611,11 @@ def _load_catalog_data() -> list:
# Enrich with catalog metadata (OpenMetadata)
if _catalog_enricher:
try:
# Create config for enrichment with all available fields
from src.config import TableConfig
table_config = TableConfig(
# Create lightweight config for enrichment (enricher uses .id, .name, .catalog_fqn)
from types import SimpleNamespace
table_config = SimpleNamespace(
id=table_id,
name=table.get("name", ""),
description=table.get("description", ""),
primary_key=table.get("primary_key", "id"),
sync_strategy=table.get("sync_strategy", "full_refresh"),
incremental_window_days=table.get("incremental_window_days"),
partition_by=table.get("partition_by"),
partition_granularity=table.get("partition_granularity"),
max_history_days=table.get("max_history_days"),
partition_column_type=table.get("partition_column_type", "TIMESTAMP"),
catalog_fqn=table.get("catalog_fqn"),
)
catalog_data = _catalog_enricher.enrich_table(table_config)
@ -1097,7 +1089,7 @@ def register_routes(app: Flask) -> None:
if _catalog_enricher and _catalog_enricher.enabled:
try:
# Find table config from data_description.md
from src.config import TableConfig
from types import SimpleNamespace
from config.loader import load_instance_config
# Load data_description.md to find table config by name
@ -1116,17 +1108,10 @@ def register_routes(app: Flask) -> None:
# Find table by name
for table_def in yaml_data["tables"]:
if table_def.get("name") == table_name:
table_config = TableConfig(
# Lightweight config (enricher uses .id, .name, .catalog_fqn)
table_config = SimpleNamespace(
id=table_def.get("id", ""),
name=table_def.get("name", ""),
description=table_def.get("description", ""),
primary_key=table_def.get("primary_key", "id"),
sync_strategy=table_def.get("sync_strategy", "full_refresh"),
incremental_window_days=table_def.get("incremental_window_days"),
partition_by=table_def.get("partition_by"),
partition_granularity=table_def.get("partition_granularity"),
max_history_days=table_def.get("max_history_days"),
partition_column_type=table_def.get("partition_column_type", "TIMESTAMP"),
catalog_fqn=table_def.get("catalog_fqn"),
)
catalog_data = _catalog_enricher.enrich_table(table_config)
@ -1862,17 +1847,25 @@ def register_routes(app: Flask) -> None:
def admin_discover_tables():
"""Discover all available tables from the data source."""
try:
from src.data_sync import create_data_source
from app.instance_config import get_data_source_type, get_value
ds = create_data_source()
raw_tables = ds.discover_tables()
source_type = get_data_source_type()
raw_tables = []
if source_type == "keboola":
from connectors.keboola.client import KeboolaClient
url = get_value("keboola", "url", default="")
token = os.environ.get(get_value("keboola", "token_env", default="KEBOOLA_STORAGE_TOKEN"), "")
client = KeboolaClient(token=token, url=url)
raw_tables = client.discover_all_tables()
# Check which tables are already registered
registered_ids = set()
try:
from src.table_registry import TableRegistry
registry = TableRegistry.default()
registered_ids = {t["id"] for t in registry.list_tables()}
from src.db import get_system_db
from src.repositories.table_registry import TableRegistryRepository
conn = get_system_db()
repo = TableRegistryRepository(conn)
registered_ids = {t["id"] for t in repo.list_all()}
except Exception:
pass
@ -1905,14 +1898,16 @@ def register_routes(app: Flask) -> None:
def admin_registry_list():
"""Return the full table registry."""
try:
from src.table_registry import TableRegistry
from src.db import get_system_db
from src.repositories.table_registry import TableRegistryRepository
registry = TableRegistry.default()
conn = get_system_db()
repo = TableRegistryRepository(conn)
return jsonify({
"ok": True,
"version": registry.version,
"folder_mapping": registry.get_folder_mapping(),
"tables": registry.list_tables(),
"version": 0,
"folder_mapping": {},
"tables": repo.list_all(),
})
except Exception as e:
logger.error(f"Registry list failed: {e}")
@ -1923,7 +1918,8 @@ def register_routes(app: Flask) -> None:
@admin_required
def admin_register_table():
"""Register a new table from discovery results."""
from src.table_registry import ConflictError, TableRegistry
from src.db import get_system_db
from src.repositories.table_registry import TableRegistryRepository
user = session.get("user", {})
email = user.get("email", "")
@ -1933,21 +1929,26 @@ def register_routes(app: Flask) -> None:
return jsonify({"error": "Missing table 'id'"}), 400
try:
registry = TableRegistry.default()
registry.register_table(
table_def=data,
conn = get_system_db()
repo = TableRegistryRepository(conn)
repo.register(
id=data["id"],
name=data.get("name", ""),
folder=data.get("folder"),
sync_strategy=data.get("sync_strategy"),
primary_key=data.get("primary_key"),
description=data.get("description"),
registered_by=email,
expected_version=data.get("version"),
source_type=data.get("source_type"),
bucket=data.get("bucket"),
source_table=data.get("source_table"),
query_mode=data.get("query_mode", "local"),
sync_schedule=data.get("sync_schedule"),
profile_after_sync=data.get("profile_after_sync", True),
)
# Regenerate data_description.md
docs_path = Path(os.path.dirname(__file__)) / ".." / "docs" / "data_description.md"
registry.generate_data_description_md(docs_path.resolve())
return jsonify({"ok": True})
return jsonify({"ok": True, "version": registry.version})
except ConflictError as e:
return jsonify({"error": str(e)}), 409
except ValueError as e:
return jsonify({"error": str(e)}), 400
except Exception as e:
@ -1959,30 +1960,42 @@ def register_routes(app: Flask) -> None:
@admin_required
def admin_update_table(table_id):
"""Update configuration of a registered table."""
from src.table_registry import ConflictError, TableRegistry
from src.db import get_system_db
from src.repositories.table_registry import TableRegistryRepository
user = session.get("user", {})
email = user.get("email", "")
data = request.get_json(silent=True) or {}
data.pop("version", None) # Not used by DuckDB repo
try:
registry = TableRegistry.default()
registry.update_table(
table_id=table_id,
updates=data,
updated_by=email,
expected_version=data.pop("version", None),
conn = get_system_db()
repo = TableRegistryRepository(conn)
# Get existing record and merge updates
existing = repo.get(table_id)
if not existing:
return jsonify({"error": f"Table '{table_id}' not found"}), 404
repo.register(
id=table_id,
name=data.get("name", existing.get("name", "")),
folder=data.get("folder", existing.get("folder")),
sync_strategy=data.get("sync_strategy", existing.get("sync_strategy")),
primary_key=data.get("primary_key", existing.get("primary_key")),
description=data.get("description", existing.get("description")),
registered_by=email,
source_type=data.get("source_type", existing.get("source_type")),
bucket=data.get("bucket", existing.get("bucket")),
source_table=data.get("source_table", existing.get("source_table")),
query_mode=data.get("query_mode", existing.get("query_mode", "local")),
sync_schedule=data.get("sync_schedule", existing.get("sync_schedule")),
profile_after_sync=data.get("profile_after_sync", existing.get("profile_after_sync", True)),
)
# Regenerate data_description.md
docs_path = Path(os.path.dirname(__file__)) / ".." / "docs" / "data_description.md"
registry.generate_data_description_md(docs_path.resolve())
return jsonify({"ok": True})
return jsonify({"ok": True, "version": registry.version})
except ConflictError as e:
return jsonify({"error": str(e)}), 409
except ValueError as e:
return jsonify({"error": str(e)}), 400
except Exception as e:
@ -1994,25 +2007,18 @@ def register_routes(app: Flask) -> None:
@admin_required
def admin_unregister_table(table_id):
"""Unregister a table and clean up subscriptions."""
from src.table_registry import ConflictError, TableRegistry
user = session.get("user", {})
email = user.get("email", "")
data = request.get_json(silent=True) or {}
from src.db import get_system_db
from src.repositories.table_registry import TableRegistryRepository
try:
registry = TableRegistry.default()
conn = get_system_db()
repo = TableRegistryRepository(conn)
# Get table name before deletion (for subscription cleanup)
table_info = registry.get_table(table_id)
table_info = repo.get(table_id)
table_name = table_info["name"] if table_info else None
registry.unregister_table(
table_id=table_id,
unregistered_by=email,
expected_version=data.get("version"),
)
repo.unregister(table_id)
# Clean up per-user subscriptions for removed table
if table_name:
@ -2021,14 +2027,8 @@ def register_routes(app: Flask) -> None:
except Exception as ce:
logger.warning(f"Subscription cleanup for {table_name} failed: {ce}")
# Regenerate data_description.md
docs_path = Path(os.path.dirname(__file__)) / ".." / "docs" / "data_description.md"
registry.generate_data_description_md(docs_path.resolve())
return jsonify({"ok": True})
return jsonify({"ok": True, "version": registry.version})
except ConflictError as e:
return jsonify({"error": str(e)}), 409
except ValueError as e:
return jsonify({"error": str(e)}), 400
except Exception as e:

View file

@ -194,9 +194,12 @@ def _write_rsync_filter(username: str, dataset_settings: dict, table_mode: str,
# Load folder_mapping from table registry (or instance config as fallback)
folder_mapping = {}
try:
from src.table_registry import TableRegistry
registry = TableRegistry.default()
folder_mapping = registry.get_folder_mapping()
from src.db import get_system_db
from src.repositories.table_registry import TableRegistryRepository
conn = get_system_db()
repo = TableRegistryRepository(conn)
tables = repo.list_all()
folder_mapping = {t["bucket"]: t["folder"] for t in tables if t.get("bucket") and t.get("folder")}
except Exception:
try:
from config.loader import load_instance_config, get_instance_value