agnes-the-ai-analyst/connectors/bigquery/client.py
Petr 85c87ec375 Pass explicit bqstorage_client to to_arrow_iterable() for Storage API
Without an explicit bqstorage_client argument, to_arrow_iterable() silently
falls back to REST API pagination (~5K rows/sec). With an explicit client,
it uses parallel gRPC streams via the BQ Storage API (~300K rows/sec).

No temp table materialization - BQ already writes query results to an
internal temp table automatically. We just tell the reader to use the
fast gRPC path instead of slow HTTP pagination.
2026-03-12 10:51:44 +01:00

644 lines
22 KiB
Python

"""
Google BigQuery API Client
Low-level wrapper for Google BigQuery with these functions:
1. Authentication using Application Default Credentials (ADC)
2. Query tables to PyArrow (no CSV intermediate step)
3. Get table metadata (schema, columns, data types)
4. Cache metadata for faster repeated use
5. Incremental reads (timestamp-based and partition-based)
Uses google-cloud-bigquery with native PyArrow support.
"""
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import pyarrow as pa
from google.cloud import bigquery
try:
from google.cloud import bigquery_storage_v1
_HAS_BQ_STORAGE = True
except ImportError:
_HAS_BQ_STORAGE = False
from src.config import get_config
logger = logging.getLogger(__name__)

# BigQuery column type name -> PyArrow type used when building Arrow schemas.
# Legacy-SQL and standard-SQL spellings (INTEGER/INT64, FLOAT/FLOAT64,
# BOOLEAN/BOOL) map to the same Arrow type. TIME, GEOGRAPHY, JSON and the
# nested/repeated types (STRUCT, RECORD, ARRAY) are carried as plain strings.
# NOTE(review): NUMERIC/BIGNUMERIC are exact decimals in BigQuery; mapping them
# to float64 can lose precision -- confirm this is acceptable downstream.
BIGQUERY_TO_PYARROW_TYPES = dict(
    STRING=pa.string(),
    BYTES=pa.binary(),
    INTEGER=pa.int64(),
    INT64=pa.int64(),
    FLOAT=pa.float64(),
    FLOAT64=pa.float64(),
    NUMERIC=pa.float64(),
    BIGNUMERIC=pa.float64(),
    BOOLEAN=pa.bool_(),
    BOOL=pa.bool_(),
    TIMESTAMP=pa.timestamp("us", tz="UTC"),
    DATE=pa.date32(),
    TIME=pa.string(),
    DATETIME=pa.timestamp("us"),
    GEOGRAPHY=pa.string(),
    JSON=pa.string(),
    STRUCT=pa.string(),
    RECORD=pa.string(),
    ARRAY=pa.string(),
)
class BigQueryClient:
    """
    Wrapper for Google BigQuery API.

    Provides high-level methods for working with BigQuery tables:
    - Query tables to PyArrow Tables or streaming RecordBatches (no CSV step)
    - Get metadata (schema, columns), cached on disk as JSON
    - Incremental and partitioned reads
    """

    def __init__(
        self,
        project_id: Optional[str] = None,
        location: Optional[str] = None,
    ):
        """
        Initialize BigQuery client.

        Args:
            project_id: GCP project ID for job execution/billing.
                If None, reads from BIGQUERY_PROJECT env var.
            location: BigQuery location for job execution (e.g., "us-central1").
                If None, reads from BIGQUERY_LOCATION env var.

        Raises:
            ValueError: If project_id is not provided and BIGQUERY_PROJECT is not set.
        """
        self.project_id = project_id or os.environ.get("BIGQUERY_PROJECT")
        if not self.project_id:
            raise ValueError(
                "BigQuery project ID not set. "
                "Set BIGQUERY_PROJECT environment variable."
            )
        self.location = location or os.environ.get("BIGQUERY_LOCATION")

        # Initialize BigQuery client with ADC.
        # project_id is used for job execution and billing. Data can live in a
        # different project -- table IDs in queries use the fully-qualified
        # format (project.dataset.table).
        client_kwargs = {"project": self.project_id}
        if self.location:
            client_kwargs["location"] = self.location
        self.client = bigquery.Client(**client_kwargs)

        # BQ Storage API client for fast parallel reads (gRPC streams).
        # Without an explicit bqstorage_client, to_arrow_iterable() silently
        # falls back to slow REST API pagination (~5K rows/sec vs ~300K rows/sec).
        if _HAS_BQ_STORAGE:
            try:
                self.bqstorage_client = bigquery_storage_v1.BigQueryReadClient()
                logger.info("BQ Storage API client initialized (fast parallel gRPC reads)")
            except Exception as e:
                self.bqstorage_client = None
                logger.warning(f"BQ Storage API client failed to initialize: {e}")
        else:
            self.bqstorage_client = None
            logger.info("BQ Storage API not available (install google-cloud-bigquery-storage)")

        # Metadata cache: table_id -> metadata dict, persisted as JSON.
        config = get_config()
        self.metadata_cache: Dict[str, Dict[str, Any]] = {}
        self.metadata_cache_path = config.get_metadata_path() / "bq_table_metadata.json"
        self._load_metadata_cache()

        logger.info(
            f"BigQuery client initialized: project={self.project_id}, "
            f"location={self.location or 'auto'}"
        )

    # ------------------------------------------------------------------
    # Metadata cache persistence
    # ------------------------------------------------------------------

    def _load_metadata_cache(self):
        """
        Load the metadata cache from disk, if the file exists.

        Any read/parse problem (including a file whose top level is not a JSON
        object) is logged and leaves the cache empty rather than raising.
        """
        if not self.metadata_cache_path.exists():
            return
        try:
            # json.dump writes ASCII by default, so utf-8 reads any file we wrote.
            with open(self.metadata_cache_path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            if not isinstance(loaded, dict):
                # A list/scalar here would break later `cache[table_id] = ...` writes.
                raise ValueError(
                    f"expected a JSON object, got {type(loaded).__name__}"
                )
            self.metadata_cache = loaded
            logger.info(
                f"BQ metadata cache loaded: {len(self.metadata_cache)} tables"
            )
        except Exception as e:
            logger.warning(f"Error loading BQ metadata cache: {e}")
            self.metadata_cache = {}

    def _save_metadata_cache(self):
        """
        Save the metadata cache to disk (best effort; errors are logged, not raised).

        The JSON is written to a sibling temporary file and then moved over the
        real path with os.replace(), so an interrupted write cannot leave a
        truncated cache file behind.
        """
        cache_path = self.metadata_cache_path
        # Per-process temp name so concurrent writers don't clobber each other's temp file.
        tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
        try:
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self.metadata_cache, f, indent=2)
            os.replace(tmp_path, cache_path)
            logger.debug("BQ metadata cache saved")
        except Exception as e:
            logger.warning(f"Error saving BQ metadata cache: {e}")

    @staticmethod
    def _is_cache_fresh(cached: Any, cache_ttl_hours: int) -> bool:
        """
        Return True if a cache entry's `_cached_at` stamp is younger than the TTL.

        Entries without a stamp default to 2000-01-01 (i.e. stale). A malformed
        entry or stamp, or a stamp in the future, is treated as stale so the
        metadata is refetched instead of raising.
        """
        try:
            cached_time = datetime.fromisoformat(cached.get("_cached_at", "2000-01-01"))
            cache_age = datetime.now() - cached_time
        except (AttributeError, TypeError, ValueError):
            return False
        return timedelta(0) <= cache_age < timedelta(hours=cache_ttl_hours)

    # ------------------------------------------------------------------
    # Metadata
    # ------------------------------------------------------------------

    def get_table_metadata(
        self,
        table_id: str,
        use_cache: bool = True,
        cache_ttl_hours: int = 24,
    ) -> Dict[str, Any]:
        """
        Get table metadata from BigQuery.

        Args:
            table_id: Full table ID (e.g., "project.dataset.table")
            use_cache: Use cache if available
            cache_ttl_hours: Cache TTL in hours (default 24h)

        Returns:
            Dictionary with metadata including columns, types, descriptions,
            row count, size, timestamps and time-partitioning info. The dict
            is also stored in (and shared with) the in-memory cache.

        Raises:
            Any exception from the BigQuery get_table() call (logged, then re-raised).
        """
        if use_cache and table_id in self.metadata_cache:
            cached = self.metadata_cache[table_id]
            if self._is_cache_fresh(cached, cache_ttl_hours):
                logger.debug(f"Using BQ metadata cache for {table_id}")
                return cached

        logger.info(f"Fetching metadata for BQ table: {table_id}")
        try:
            table_ref = self.client.get_table(table_id)

            columns = []
            column_types = {}
            column_descriptions = {}
            for field in table_ref.schema:
                columns.append(field.name)
                column_types[field.name] = field.field_type
                if field.description:
                    column_descriptions[field.name] = field.description

            metadata = {
                "table_id": table_id,
                "name": table_ref.table_id,
                "dataset": table_ref.dataset_id,
                "project": table_ref.project,
                "columns": columns,
                "column_types": column_types,
                "column_descriptions": column_descriptions,
                "row_count": table_ref.num_rows,
                "size_bytes": table_ref.num_bytes,
                "created": table_ref.created.isoformat() if table_ref.created else None,
                "modified": table_ref.modified.isoformat() if table_ref.modified else None,
                "partitioning": None,
                # Naive local time; compared against datetime.now() in _is_cache_fresh.
                "_cached_at": datetime.now().isoformat(),
            }

            # Only time-based partitioning is captured.
            if table_ref.time_partitioning:
                metadata["partitioning"] = {
                    "type": table_ref.time_partitioning.type_,
                    "field": table_ref.time_partitioning.field,
                    "expiration_ms": table_ref.time_partitioning.expiration_ms,
                }

            self.metadata_cache[table_id] = metadata
            self._save_metadata_cache()
            return metadata
        except Exception as e:
            logger.error(f"Error getting metadata for {table_id}: {e}")
            raise

    def get_pyarrow_schema(self, table_id: str) -> Optional[pa.Schema]:
        """
        Build PyArrow schema from BigQuery table schema.

        Unknown BigQuery types map to pa.string().
        NOTE(review): column mode (e.g. REPEATED) is not recorded in the
        metadata, so array columns presumably get their element type here -- verify.

        Args:
            table_id: Full table ID

        Returns:
            PyArrow schema or None if metadata has no column types
        """
        metadata = self.get_table_metadata(table_id)
        column_types = metadata.get("column_types", {})
        if not column_types:
            logger.warning(f"No column types for {table_id}, schema will not be applied")
            return None
        fields = []
        for col_name in metadata.get("columns", []):
            bq_type = column_types.get(col_name, "STRING")
            pa_type = BIGQUERY_TO_PYARROW_TYPES.get(bq_type, pa.string())
            fields.append(pa.field(col_name, pa_type))
        return pa.schema(fields)

    def get_date_columns(self, table_id: str) -> List[str]:
        """
        Get list of DATE-only columns for a table.

        Args:
            table_id: Full table ID

        Returns:
            List of column names that have DATE type in BigQuery
        """
        metadata = self.get_table_metadata(table_id)
        column_types = metadata.get("column_types", {})
        return [
            col_name for col_name, bq_type in column_types.items()
            if bq_type == "DATE"
        ]

    # ------------------------------------------------------------------
    # Query helpers
    # ------------------------------------------------------------------

    def _start_query(
        self,
        sql: str,
        params: Optional[List[bigquery.ScalarQueryParameter]] = None,
    ):
        """Submit `sql` (with optional query parameters) and return the QueryJob."""
        job_config = bigquery.QueryJobConfig()
        if params:
            job_config.query_parameters = params
        return self.client.query(sql, job_config=job_config)

    @staticmethod
    def _is_storage_permission_error(err: Exception) -> bool:
        """True if `err` looks like a missing bigquery.readsessions permission."""
        message = str(err)
        return "readsessions" in message or "PERMISSION_DENIED" in message

    @staticmethod
    def _select_sql(
        table_id: str,
        columns: Optional[List[str]] = None,
        row_filter: Optional[str] = None,
    ) -> str:
        """
        Build `SELECT <cols> FROM `<table_id>` [WHERE <row_filter>]`.

        Column names are backtick-quoted; no/empty `columns` selects `*`.
        `row_filter` is raw SQL interpolated verbatim -- never pass untrusted input.
        """
        select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*"
        sql = f"SELECT {select_cols} FROM `{table_id}`"
        if row_filter:
            sql += f" WHERE {row_filter}"
        return sql

    def _partition_range_query(
        self,
        table_id: str,
        partition_column: str,
        start: str,
        end: Optional[str],
        columns: Optional[List[str]],
        column_type: str,
    ):
        """
        Build the SQL and parameters for a [start, end) range on `partition_column`.

        Returns:
            (sql, params) tuple; the upper bound is omitted when `end` is falsy.
        """
        sql = (
            f"{self._select_sql(table_id, columns)} "
            f"WHERE `{partition_column}` >= @start_value"
        )
        params = [
            bigquery.ScalarQueryParameter("start_value", column_type, start),
        ]
        if end:
            sql += f" AND `{partition_column}` < @end_value"
            params.append(
                bigquery.ScalarQueryParameter("end_value", column_type, end),
            )
        return sql, params

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def query_to_arrow(
        self,
        sql: str,
        params: Optional[List[bigquery.ScalarQueryParameter]] = None,
    ) -> pa.Table:
        """
        Execute SQL query and return results as PyArrow Table.

        Reads through the BQ Storage API (parallel gRPC) when possible and
        falls back to REST pagination if the credentials lack the
        bigquery.readsessions.create permission.

        Args:
            sql: SQL query string (use @param_name for parameterized values)
            params: List of BigQuery query parameters

        Returns:
            PyArrow Table with query results

        Raises:
            Any query/download error other than a Storage API permission failure.
        """
        logger.debug(f"Executing BQ query: {sql[:200]}...")
        query_job = self._start_query(sql, params)

        try:
            if self.bqstorage_client:
                arrow_table = query_job.to_arrow(bqstorage_client=self.bqstorage_client)
            else:
                arrow_table = query_job.to_arrow()
        except Exception as storage_err:
            if not self._is_storage_permission_error(storage_err):
                raise
            logger.warning(
                "BQ Storage API unavailable (missing readsessions permission), "
                "falling back to REST API"
            )
            arrow_table = query_job.to_arrow(create_bqstorage_client=False)

        logger.debug(f"Query returned {arrow_table.num_rows} rows, {arrow_table.num_columns} columns")
        return arrow_table

    def query_to_arrow_batches(
        self,
        sql: str,
        params: Optional[List[bigquery.ScalarQueryParameter]] = None,
    ):
        """
        Execute SQL query and yield results as streaming RecordBatches.

        Unlike query_to_arrow(), this does NOT load the entire result into
        memory; each RecordBatch can be written to disk immediately.

        The Storage API is only abandoned for REST if it fails with a
        permission error before the first batch is produced. Errors after
        streaming has started propagate, so rows are never yielded twice.

        Args:
            sql: SQL query string (use @param_name for parameterized values)
            params: List of BigQuery query parameters

        Yields:
            pyarrow.RecordBatch objects
        """
        logger.debug(f"Executing BQ query (streaming): {sql[:200]}...")
        query_job = self._start_query(sql, params)
        # result() returns a RowIterator, which has to_arrow_iterable()
        # (QueryJob itself only has to_arrow(), not to_arrow_iterable()).
        row_iter = query_job.result()

        # to_arrow_iterable() only uses the Storage API (parallel gRPC streams,
        # ~300K rows/sec) when given an explicit bqstorage_client; without one
        # it pages over REST (~5K rows/sec). This matters most for views, whose
        # results BQ materializes into a temp table for the Storage API to read.
        use_storage = bool(self.bqstorage_client)
        storage_kwargs = {"bqstorage_client": self.bqstorage_client} if use_storage else {}

        # Only the probe is guarded: once a batch has been yielded, falling
        # back would replay the result from the start and duplicate rows.
        try:
            batch_iter = row_iter.to_arrow_iterable(**storage_kwargs)
            first_batch = next(batch_iter, None)
        except Exception as storage_err:
            if not use_storage or not self._is_storage_permission_error(storage_err):
                raise
            logger.warning(
                "BQ Storage API unavailable (missing readsessions permission), "
                "falling back to REST API (streaming)"
            )
            # Reuse the finished job: result() returns a fresh RowIterator over
            # the same results without re-running (and re-billing) the query.
            # to_arrow_iterable() has no create_bqstorage_client flag; omitting
            # bqstorage_client is what selects REST pagination.
            yield from query_job.result().to_arrow_iterable()
            return

        if first_batch is not None:
            yield first_batch
            yield from batch_iter

    # ------------------------------------------------------------------
    # Table reads
    # ------------------------------------------------------------------

    def read_table_streaming(
        self,
        table_id: str,
        columns: Optional[List[str]] = None,
        row_filter: Optional[str] = None,
    ):
        """
        Read table as streaming RecordBatches (constant memory).

        Args:
            table_id: Full table ID (e.g., "project.dataset.table")
            columns: Optional list of columns to select
            row_filter: Optional SQL WHERE clause (without WHERE keyword);
                interpolated verbatim into the query

        Yields:
            pyarrow.RecordBatch objects
        """
        sql = self._select_sql(table_id, columns, row_filter)
        logger.info(
            f"Streaming BQ table: {table_id} "
            f"(filter: {row_filter or 'none'}, "
            f"storage_api={'yes' if self.bqstorage_client else 'no'})"
        )
        yield from self.query_to_arrow_batches(sql)

    def read_table(
        self,
        table_id: str,
        columns: Optional[List[str]] = None,
        row_filter: Optional[str] = None,
    ) -> pa.Table:
        """
        Read full table (or filtered subset) as PyArrow Table.

        Args:
            table_id: Full table ID (e.g., "project.dataset.table")
            columns: Optional list of columns to select
            row_filter: Optional SQL WHERE clause (without WHERE keyword);
                interpolated verbatim into the query

        Returns:
            PyArrow Table with table data
        """
        sql = self._select_sql(table_id, columns, row_filter)
        logger.info(f"Reading BQ table: {table_id} (filter: {row_filter or 'none'})")
        return self.query_to_arrow(sql)

    def read_table_incremental(
        self,
        table_id: str,
        incremental_column: str,
        since_value: str,
        columns: Optional[List[str]] = None,
        column_type: str = "TIMESTAMP",
    ) -> pa.Table:
        """
        Read rows where incremental_column > since_value.

        Uses a query parameter for the value to prevent SQL injection.

        Args:
            table_id: Full table ID
            incremental_column: Column name for incremental filter
            since_value: Value (ISO timestamp/date string) - fetch rows after this value
            columns: Optional list of columns to select
            column_type: BQ SQL type of the incremental column ("TIMESTAMP",
                "DATE", "DATETIME", ...); defaults to TIMESTAMP

        Returns:
            PyArrow Table with incremental data
        """
        sql = (
            f"{self._select_sql(table_id, columns)} "
            f"WHERE `{incremental_column}` > @since_value"
        )
        params = [
            bigquery.ScalarQueryParameter("since_value", column_type, since_value),
        ]
        logger.info(
            f"Incremental read: {table_id} WHERE {incremental_column} > {since_value}"
        )
        return self.query_to_arrow(sql, params=params)

    def read_table_partitioned(
        self,
        table_id: str,
        partition_column: str,
        start: str,
        end: Optional[str] = None,
        columns: Optional[List[str]] = None,
        column_type: str = "TIMESTAMP",
    ) -> pa.Table:
        """
        Read data within a partition range.

        Args:
            table_id: Full table ID
            partition_column: Partition column name
            start: Start date/timestamp (inclusive)
            end: End date/timestamp (exclusive). If None, no upper bound.
            columns: Optional list of columns to select
            column_type: BQ SQL type for the partition column ("DATE", "TIMESTAMP", "DATETIME")

        Returns:
            PyArrow Table with partition range data
        """
        sql, params = self._partition_range_query(
            table_id, partition_column, start, end, columns, column_type
        )
        logger.info(
            f"Partitioned read: {table_id} [{start} .. {end or 'now'})"
        )
        return self.query_to_arrow(sql, params=params)

    def read_table_partitioned_streaming(
        self,
        table_id: str,
        partition_column: str,
        start: str,
        end: Optional[str] = None,
        columns: Optional[List[str]] = None,
        column_type: str = "TIMESTAMP",
    ):
        """
        Read data within a partition range as streaming RecordBatches (constant memory).

        Unlike read_table_partitioned(), this does NOT load the entire result
        into memory; each RecordBatch can be written to disk immediately.

        Args:
            table_id: Full table ID
            partition_column: Partition column name
            start: Start date/timestamp (inclusive)
            end: End date/timestamp (exclusive). If None, no upper bound.
            columns: Optional list of columns to select
            column_type: BQ SQL type for the partition column ("DATE", "TIMESTAMP", "DATETIME")

        Yields:
            pyarrow.RecordBatch objects
        """
        sql, params = self._partition_range_query(
            table_id, partition_column, start, end, columns, column_type
        )
        logger.info(
            f"Partitioned streaming read: {table_id} [{start} .. {end or 'now'})"
        )
        yield from self.query_to_arrow_batches(sql, params=params)

    # ------------------------------------------------------------------
    # Discovery / diagnostics
    # ------------------------------------------------------------------

    def discover_all_tables(self, dataset_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        List all tables in the client's project (or one specific dataset).

        Datasets or tables that cannot be read are logged and skipped.

        Args:
            dataset_id: Optional dataset ID to limit scope

        Returns:
            Normalized list of table dicts with id, name, columns, row_count, etc.
        """
        logger.info(f"Discovering BQ tables (dataset={dataset_id or 'all'})...")
        result = []

        if dataset_id:
            datasets = [self.client.get_dataset(dataset_id)]
        else:
            datasets = list(self.client.list_datasets())

        for dataset in datasets:
            ds_ref = dataset.reference if hasattr(dataset, "reference") else dataset.dataset_id
            ds_id = str(ds_ref)
            try:
                tables = list(self.client.list_tables(ds_ref))
            except Exception as e:
                logger.warning(f"Could not list tables in dataset {ds_id}: {e}")
                continue

            for table_item in tables:
                full_id = f"{table_item.project}.{table_item.dataset_id}.{table_item.table_id}"
                try:
                    # One get_table() call per table to obtain schema and sizes.
                    table_detail = self.client.get_table(full_id)
                    columns = [f.name for f in table_detail.schema]
                    result.append({
                        "id": full_id,
                        "name": table_item.table_id,
                        "bucket_id": table_item.dataset_id,
                        "bucket_name": table_item.dataset_id,
                        "columns": columns,
                        "row_count": table_detail.num_rows or 0,
                        "size_bytes": table_detail.num_bytes or 0,
                        "primary_key": [],
                        "last_change": (
                            table_detail.modified.isoformat()
                            if table_detail.modified else None
                        ),
                        "last_import": None,
                    })
                except Exception as e:
                    logger.warning(f"Could not get details for {full_id}: {e}")

        logger.info(f"Discovered {len(result)} BQ tables")
        return result

    def test_connection(self) -> bool:
        """
        Test connection to BigQuery by running `SELECT 1`.

        Returns:
            True if connection works, False otherwise (the error is logged)
        """
        try:
            query_job = self.client.query("SELECT 1")
            list(query_job.result())
            logger.info(f"BigQuery connection OK (project: {self.project_id})")
            return True
        except Exception as e:
            logger.error(f"BigQuery connection test failed: {e}")
            return False
def create_client() -> BigQueryClient:
    """
    Build a BigQueryClient with default settings.

    No arguments are forwarded, so the project and location come from the
    BIGQUERY_PROJECT / BIGQUERY_LOCATION environment variables.

    Returns:
        A freshly constructed BigQueryClient instance.
    """
    client = BigQueryClient()
    return client