Stream BQ results to Parquet instead of loading into memory
Replace to_arrow(), which loads the entire result into RAM, with to_arrow_iterable(), which streams RecordBatches. Each batch is written directly to disk via ParquetWriter, so memory use stays bounded by the batch size regardless of table size. This prevents OOM on the 8 GB server for multi-million-row tables.
This commit is contained in:
parent
a191ede28c
commit
ee70da86c3
3 changed files with 137 additions and 32 deletions
|
|
@ -132,42 +132,66 @@ class BigQueryDataSource(DataSource):
|
|||
|
||||
def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]:
    """
    Full refresh: stream table from BQ and write to Parquet in batches.

    Uses streaming instead of loading the entire table into RAM: each
    RecordBatch from BQ is converted, schema-enforced and appended to the
    output via ParquetWriter, so this method only holds a reference to the
    batch currently being written.

    The data is written to a sibling ``<file name>.tmp`` file which is moved
    over the real Parquet path only after the stream has been fully consumed
    and the writer closed. If anything raises mid-stream, the writer is
    closed, the temporary file is removed, and whatever file was previously
    at the target path is left untouched.

    Args:
        table_config: Table to sync; its ``id``, ``name``, ``columns`` and
            ``row_filter`` attributes are read.

    Returns:
        Dict with ``rows``, ``columns``, ``file_size_bytes`` and
        ``uncompressed_bytes``.

    Raises:
        Any exception raised while streaming, converting or writing is
        propagated after cleanup.
    """
    logger.info(f"Full refresh (streaming): {table_config.name}")

    parquet_path = self.config.get_parquet_path(table_config)
    date_columns = self.bq_client.get_date_columns(table_config.id)
    pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)

    parquet_path.parent.mkdir(parents=True, exist_ok=True)

    # Same directory as the target so the final replace() is a rename on one
    # filesystem rather than a cross-device copy.
    tmp_path = parquet_path.with_name(parquet_path.name + ".tmp")

    # Stream BQ results to the temporary Parquet file, one batch at a time.
    writer = None
    total_rows = 0
    num_columns = 0
    published = False

    try:
        for batch in self.bq_client.read_table_streaming(
            table_config.id,
            columns=table_config.columns,
            row_filter=table_config.row_filter,
        ):
            if batch.num_rows == 0:
                continue

            # Convert batch to table for schema enforcement
            chunk = pa.Table.from_batches([batch])
            if date_columns:
                chunk = convert_date_columns_to_date32(chunk, date_columns)
            if pyarrow_schema:
                chunk = apply_schema_to_table(chunk, pyarrow_schema)

            # The writer is created lazily because its schema comes from the
            # first converted chunk.
            if writer is None:
                writer = pq.ParquetWriter(
                    tmp_path, chunk.schema, compression="snappy",
                )
                num_columns = chunk.num_columns

            writer.write_table(chunk)
            total_rows += chunk.num_rows

            # Log progress every ~1M rows: true exactly when this chunk
            # pushed total_rows across a multiple of 1,000,000.
            if total_rows % 1_000_000 < chunk.num_rows:
                logger.info(f" -> {total_rows:,} rows written...")

        if writer is not None:
            # Close first so the Parquet footer is flushed, then publish.
            writer.close()
            tmp_path.replace(parquet_path)
        # NOTE(review): when the stream yields no rows, no writer is created
        # and any existing file at parquet_path is left as-is (its size is
        # still reported below). Confirm whether an empty table should
        # instead produce an empty Parquet file.
        published = True
    finally:
        if not published:
            # Best-effort cleanup; never mask the original exception.
            if writer is not None:
                try:
                    writer.close()
                except Exception:
                    logger.warning(
                        "Failed to close Parquet writer during cleanup",
                        exc_info=True,
                    )
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                logger.warning(
                    "Failed to remove temporary Parquet file during cleanup",
                    exc_info=True,
                )

    file_size = parquet_path.stat().st_size if parquet_path.exists() else 0
    logger.info(
        f"Full refresh complete: {total_rows:,} rows, "
        f"{file_size / 1024 / 1024:.2f} MB"
    )

    return {
        "rows": total_rows,
        "columns": num_columns,
        "file_size_bytes": file_size,
        "uncompressed_bytes": _get_uncompressed_size(parquet_path) if total_rows > 0 else 0,
    }
|
||||
|
||||
def _incremental_sync(
|
||||
|
|
|
|||
|
|
@ -292,6 +292,80 @@ class BigQueryClient:
|
|||
logger.debug(f"Query returned {arrow_table.num_rows} rows, {arrow_table.num_columns} columns")
|
||||
return arrow_table
|
||||
|
||||
def query_to_arrow_batches(
    self,
    sql: str,
    params: Optional[List[bigquery.ScalarQueryParameter]] = None,
):
    """
    Execute SQL query and yield results as streaming RecordBatches.

    Unlike query_to_arrow(), this does NOT load entire result into memory.
    Each RecordBatch is a small chunk (typically a few MB) that can be
    written to disk immediately.

    The first batch is fetched eagerly as a probe. If that probe fails with
    an error mentioning ``readsessions`` or ``PERMISSION_DENIED``, the same
    query job is re-read with ``create_bqstorage_client=False``. The
    fallback is limited to the probe on purpose: once a batch has been
    handed to the caller, restarting the read would emit the same rows
    again, so later errors are propagated instead.

    This is a generator, so nothing (including submitting the query) runs
    until the caller starts iterating.

    Args:
        sql: SQL query string (use @param_name for parameterized values)
        params: List of BigQuery query parameters

    Yields:
        pyarrow.RecordBatch objects

    Raises:
        Any error from submitting the query or from reading batches, except
        a permission error on the first-batch probe, which triggers the
        fallback.
    """
    job_config = bigquery.QueryJobConfig()
    if params:
        job_config.query_parameters = params

    logger.debug(f"Executing BQ query (streaming): {sql[:200]}...")

    query_job = self.client.query(sql, job_config=job_config)

    # Try BQ Storage API first (faster, parallel gRPC streams).
    # Fall back to REST API if SA lacks bigquery.readsessions.create permission.
    # NOTE(review): confirm that the installed google-cloud-bigquery exposes
    # to_arrow_iterable() on the job object and accepts
    # create_bqstorage_client; the calls are kept exactly as written.
    try:
        batch_iter = query_job.to_arrow_iterable()
        # Probe first batch to detect Storage API permission errors early
        first_batch = next(batch_iter, None)
    except Exception as storage_err:
        message = str(storage_err)
        if "readsessions" not in message and "PERMISSION_DENIED" not in message:
            raise
        logger.warning(
            "BQ Storage API unavailable (missing readsessions permission), "
            "falling back to REST API (streaming)"
        )
    else:
        if first_batch is not None:
            yield first_batch
        # Errors from here on propagate: rows have already been delivered,
        # so a restart would duplicate them.
        yield from batch_iter
        return

    # Fallback: REST API streaming (same query_job, just different reader)
    yield from query_job.to_arrow_iterable(create_bqstorage_client=False)
|
||||
|
||||
def read_table_streaming(
    self,
    table_id: str,
    columns: Optional[List[str]] = None,
    row_filter: Optional[str] = None,
):
    """
    Stream a table's rows as RecordBatches.

    Builds a ``SELECT`` statement for the table and delegates to
    query_to_arrow_batches(). Being a generator, it neither logs nor queries
    until iteration begins.

    Args:
        table_id: Full table ID (e.g., "project.dataset.table")
        columns: Optional list of columns to select; every name is wrapped
            in backticks. All columns (``*``) are selected when omitted or
            empty.
        row_filter: Optional SQL WHERE clause (without WHERE keyword),
            appended verbatim.

    Yields:
        pyarrow.RecordBatch objects
    """
    if columns:
        projection = ", ".join([f"`{column}`" for column in columns])
    else:
        projection = "*"

    query = f"SELECT {projection} FROM `{table_id}`"
    if row_filter:
        query = f"{query} WHERE {row_filter}"

    logger.info(f"Streaming BQ table: {table_id} (filter: {row_filter or 'none'})")
    yield from self.query_to_arrow_batches(query)
|
||||
|
||||
def read_table(
|
||||
self,
|
||||
table_id: str,
|
||||
|
|
|
|||
|
|
@ -104,6 +104,11 @@ def _sample_arrow_table(ids: list[int], names: list[str]) -> pa.Table:
|
|||
return pa.table({"id": ids, "name": names})
|
||||
|
||||
|
||||
def _as_batches(arrow_table: pa.Table) -> list:
    """Split an Arrow table into its RecordBatches, standing in for a BQ stream."""
    record_batches = arrow_table.to_batches()
    return record_batches
|
||||
|
||||
|
||||
def _create_adapter(mock_config, mock_bq_client):
|
||||
"""Instantiate BigQueryDataSource with mocked dependencies.
|
||||
|
||||
|
|
@ -124,13 +129,13 @@ def _create_adapter(mock_config, mock_bq_client):
|
|||
class TestFullRefresh:
|
||||
|
||||
def test_writes_valid_parquet(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
|
||||
"""full_refresh should write a valid, readable Parquet file."""
|
||||
"""full_refresh should stream batches and write a valid Parquet file."""
|
||||
table_config = _make_table_config(sync_strategy="full_refresh")
|
||||
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||
mock_config.get_parquet_path.return_value = parquet_path
|
||||
|
||||
arrow_data = _sample_arrow_table([1, 2, 3], ["Alice", "Bob", "Charlie"])
|
||||
mock_bq_client.read_table.return_value = arrow_data
|
||||
mock_bq_client.read_table_streaming.return_value = _as_batches(arrow_data)
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
result = adapter.sync_table(table_config, sync_state)
|
||||
|
|
@ -145,35 +150,37 @@ class TestFullRefresh:
|
|||
assert read_back.column_names == ["id", "name"]
|
||||
|
||||
def test_applies_date_columns(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
    """Date-column conversion runs once for a single streamed batch, with the configured columns."""
    cfg = _make_table_config()
    target_file = tmp_parquet_dir / "orders.parquet"
    mock_config.get_parquet_path.return_value = target_file

    single_row = _sample_arrow_table([1], ["Alice"])
    mock_bq_client.read_table_streaming.return_value = _as_batches(single_row)
    mock_bq_client.get_date_columns.return_value = ["created_at"]

    # The patched converter returns a real table so the Parquet write still succeeds.
    with patch(
        "connectors.bigquery.adapter.convert_date_columns_to_date32",
        return_value=single_row,
    ) as convert_spy:
        _create_adapter(mock_config, mock_bq_client).sync_table(cfg, sync_state)

    convert_spy.assert_called_once()
    # Only the column list is checked: the first argument is a table rebuilt
    # from the batch, not the fixture object itself.
    positional_args = convert_spy.call_args[0]
    assert positional_args[1] == ["created_at"]
|
||||
|
||||
def test_applies_pyarrow_schema(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
    """Schema enforcement runs once for a single streamed batch, with the client's schema."""
    cfg = _make_table_config()
    target_file = tmp_parquet_dir / "orders.parquet"
    mock_config.get_parquet_path.return_value = target_file

    single_row = _sample_arrow_table([1], ["Alice"])
    mock_bq_client.read_table_streaming.return_value = _as_batches(single_row)
    expected_schema = pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.string())])
    mock_bq_client.get_pyarrow_schema.return_value = expected_schema

    # The patched helper returns a real table so the Parquet write still succeeds.
    with patch(
        "connectors.bigquery.adapter.apply_schema_to_table",
        return_value=single_row,
    ) as apply_spy:
        _create_adapter(mock_config, mock_bq_client).sync_table(cfg, sync_state)

    apply_spy.assert_called_once()
    # Only the schema argument is checked: the first argument is a table
    # rebuilt from the batch, not the fixture object itself.
    positional_args = apply_spy.call_args[0]
    assert positional_args[1] == expected_schema
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -396,7 +403,7 @@ class TestErrorHandling:
|
|||
"""When BigQuery query raises, sync_table returns {success: False, error: ...}."""
|
||||
table_config = _make_table_config()
|
||||
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
|
||||
mock_bq_client.read_table.side_effect = RuntimeError("BigQuery API timeout")
|
||||
mock_bq_client.read_table_streaming.side_effect = RuntimeError("BigQuery API timeout")
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
result = adapter.sync_table(table_config, sync_state)
|
||||
|
|
@ -538,7 +545,7 @@ class TestSyncTableDispatch:
|
|||
"""sync_strategy='full_refresh' should call _full_refresh."""
|
||||
table_config = _make_table_config(sync_strategy="full_refresh")
|
||||
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
|
||||
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
|
||||
mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"]))
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
|
||||
|
|
@ -610,7 +617,7 @@ class TestSyncTableDispatch:
|
|||
incremental_window_days=7,
|
||||
)
|
||||
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
|
||||
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
|
||||
mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"]))
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
|
||||
|
|
@ -708,7 +715,7 @@ class TestMetadataCacheClearing:
|
|||
table_config = _make_table_config()
|
||||
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||
mock_config.get_parquet_path.return_value = parquet_path
|
||||
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
|
||||
mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"]))
|
||||
|
||||
# Pre-populate cache
|
||||
mock_bq_client.metadata_cache[table_config.id] = {"some": "cached_data"}
|
||||
|
|
@ -728,7 +735,7 @@ class TestSyncStateUpdate:
|
|||
table_config = _make_table_config()
|
||||
parquet_path = tmp_parquet_dir / "orders.parquet"
|
||||
mock_config.get_parquet_path.return_value = parquet_path
|
||||
mock_bq_client.read_table.return_value = _sample_arrow_table([1, 2], ["A", "B"])
|
||||
mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1, 2], ["A", "B"]))
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
adapter.sync_table(table_config, sync_state)
|
||||
|
|
@ -745,7 +752,7 @@ class TestSyncStateUpdate:
|
|||
"""On sync failure, the sync state should NOT be updated."""
|
||||
table_config = _make_table_config()
|
||||
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
|
||||
mock_bq_client.read_table.side_effect = RuntimeError("boom")
|
||||
mock_bq_client.read_table_streaming.side_effect = RuntimeError("boom")
|
||||
|
||||
adapter = _create_adapter(mock_config, mock_bq_client)
|
||||
adapter.sync_table(table_config, sync_state)
|
||||
|
|
|
|||
Loading…
Reference in a new issue