Stream BQ results to Parquet instead of loading into memory

Replace to_arrow(), which loads the entire result into RAM, with
to_arrow_iterable(), which streams RecordBatches. Each batch is written
directly to disk via ParquetWriter, so memory use stays constant regardless
of table size. This prevents OOM on the 8 GB server for multi-million-row tables.
This commit is contained in:
Petr 2026-03-11 20:13:03 +01:00
parent a191ede28c
commit ee70da86c3
3 changed files with 137 additions and 32 deletions

View file

@ -132,42 +132,66 @@ class BigQueryDataSource(DataSource):
def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]:
    """
    Full refresh: stream table from BQ and write to Parquet in batches.

    Uses streaming instead of loading the entire table into RAM: each
    RecordBatch from BQ is converted and written to disk via ParquetWriter
    before the next one is fetched.

    Batches are written to a temporary sibling file that replaces the target
    Parquet file only after the whole stream has been written. A failure
    mid-stream therefore never leaves a truncated file at the target path.

    Args:
        table_config: Table to refresh; ``id``, ``name``, ``columns`` and
            ``row_filter`` are read from it.

    Returns:
        Dict with keys "rows", "columns", "file_size_bytes" and
        "uncompressed_bytes".

    Raises:
        Any exception raised while streaming, converting or writing is
        re-raised after the writer is closed and the temporary file removed.
    """
    logger.info(f"Full refresh (streaming): {table_config.name}")

    parquet_path = self.config.get_parquet_path(table_config)
    date_columns = self.bq_client.get_date_columns(table_config.id)
    pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id)

    parquet_path.parent.mkdir(parents=True, exist_ok=True)
    # Same directory as the target so the final replace stays on one filesystem.
    tmp_path = parquet_path.with_name(parquet_path.name + ".tmp")

    writer = None
    total_rows = 0
    num_columns = 0

    try:
        try:
            for batch in self.bq_client.read_table_streaming(
                table_config.id,
                columns=table_config.columns,
                row_filter=table_config.row_filter,
            ):
                if batch.num_rows == 0:
                    continue

                # Convert batch to table for schema enforcement
                chunk = pa.Table.from_batches([batch])
                if date_columns:
                    chunk = convert_date_columns_to_date32(chunk, date_columns)
                if pyarrow_schema:
                    chunk = apply_schema_to_table(chunk, pyarrow_schema)

                # The writer is created lazily so it uses the schema of the
                # first chunk *after* schema enforcement.
                if writer is None:
                    writer = pq.ParquetWriter(
                        tmp_path, chunk.schema, compression="snappy",
                    )
                    num_columns = chunk.num_columns

                writer.write_table(chunk)
                total_rows += chunk.num_rows

                # Log progress every ~1M rows (true when this chunk crossed
                # a multiple of 1,000,000).
                if total_rows % 1_000_000 < chunk.num_rows:
                    logger.info(f" -> {total_rows:,} rows written...")
        finally:
            # Always release the file handle, even when streaming fails.
            if writer is not None:
                writer.close()
    except BaseException:
        # Discard partial output; the previous Parquet file stays untouched.
        tmp_path.unlink(missing_ok=True)
        raise

    if writer is not None:
        # Publish the completed file in a single step.
        tmp_path.replace(parquet_path)
    # NOTE(review): when the stream yields no rows, no file is written and any
    # existing Parquet file is left in place - confirm this is intended.

    file_size = parquet_path.stat().st_size if parquet_path.exists() else 0

    logger.info(
        f"Full refresh complete: {total_rows:,} rows, "
        f"{file_size / 1024 / 1024:.2f} MB"
    )

    return {
        "rows": total_rows,
        "columns": num_columns,
        "file_size_bytes": file_size,
        "uncompressed_bytes": _get_uncompressed_size(parquet_path) if total_rows > 0 else 0,
    }
def _incremental_sync(

View file

@ -292,6 +292,80 @@ class BigQueryClient:
logger.debug(f"Query returned {arrow_table.num_rows} rows, {arrow_table.num_columns} columns")
return arrow_table
def query_to_arrow_batches(
    self,
    sql: str,
    params: Optional[List[bigquery.ScalarQueryParameter]] = None,
):
    """
    Execute SQL query and yield results as streaming RecordBatches.

    Unlike query_to_arrow(), this does NOT load entire result into memory.
    Batches are yielded one at a time so the caller can write each to disk
    immediately.

    The Storage API reader is tried first. If creating it or fetching the
    first batch fails with a permission-related error, the same query job is
    read again without a Storage API client. Errors raised after the first
    batch has been yielded are propagated instead of triggering the fallback,
    because restarting the stream at that point would hand the caller
    duplicate rows.

    Args:
        sql: SQL query string (use @param_name for parameterized values)
        params: List of BigQuery query parameters

    Yields:
        pyarrow.RecordBatch objects

    Raises:
        Exception: Any error from the query or the readers that is not
            recognised as a Storage API permission problem, and any error
            raised after streaming has started.
    """
    job_config = bigquery.QueryJobConfig()
    if params:
        job_config.query_parameters = params

    logger.debug(f"Executing BQ query (streaming): {sql[:200]}...")
    query_job = self.client.query(sql, job_config=job_config)

    # NOTE(review): confirm that to_arrow_iterable() and its
    # create_bqstorage_client argument exist on the object returned by
    # client.query() in the installed google-cloud-bigquery version.

    # Try BQ Storage API first.
    # Fall back to REST API if SA lacks bigquery.readsessions.create permission.
    try:
        batch_iter = query_job.to_arrow_iterable()
        # Probe first batch to detect Storage API permission errors early.
        # Only the probe is guarded: nothing has been yielded yet, so a
        # fallback cannot duplicate data.
        first_batch = next(batch_iter, None)
    except Exception as storage_err:
        message = str(storage_err)
        if "readsessions" not in message and "PERMISSION_DENIED" not in message:
            raise
        logger.warning(
            "BQ Storage API unavailable (missing readsessions permission), "
            "falling back to REST API (streaming)"
        )
        # Fallback: REST API streaming (same query_job, just different reader)
        yield from query_job.to_arrow_iterable(create_bqstorage_client=False)
        return

    if first_batch is not None:
        yield first_batch
    yield from batch_iter
def read_table_streaming(
    self,
    table_id: str,
    columns: Optional[List[str]] = None,
    row_filter: Optional[str] = None,
):
    """
    Stream a table's rows as RecordBatches.

    Builds a ``SELECT ... FROM `table_id``` statement, optionally restricted
    to ``columns`` and ``row_filter``, and delegates execution to
    query_to_arrow_batches().

    Args:
        table_id: Full table ID (e.g., "project.dataset.table")
        columns: Optional list of columns to select; all columns when omitted
        row_filter: Optional SQL WHERE clause (without WHERE keyword)

    Yields:
        pyarrow.RecordBatch objects
    """
    if columns:
        projection = ", ".join([f"`{name}`" for name in columns])
    else:
        projection = "*"

    where_clause = f" WHERE {row_filter}" if row_filter else ""
    query = f"SELECT {projection} FROM `{table_id}`" + where_clause

    logger.info(f"Streaming BQ table: {table_id} (filter: {row_filter or 'none'})")
    yield from self.query_to_arrow_batches(query)
def read_table(
self,
table_id: str,

View file

@ -104,6 +104,11 @@ def _sample_arrow_table(ids: list[int], names: list[str]) -> pa.Table:
return pa.table({"id": ids, "name": names})
def _as_batches(arrow_table: pa.Table) -> list:
    """Return the table's RecordBatches as a list, standing in for a BQ stream."""
    record_batches = arrow_table.to_batches()
    return record_batches
def _create_adapter(mock_config, mock_bq_client):
"""Instantiate BigQueryDataSource with mocked dependencies.
@ -124,13 +129,13 @@ def _create_adapter(mock_config, mock_bq_client):
class TestFullRefresh:
def test_writes_valid_parquet(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
"""full_refresh should write a valid, readable Parquet file."""
"""full_refresh should stream batches and write a valid Parquet file."""
table_config = _make_table_config(sync_strategy="full_refresh")
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
arrow_data = _sample_arrow_table([1, 2, 3], ["Alice", "Bob", "Charlie"])
mock_bq_client.read_table.return_value = arrow_data
mock_bq_client.read_table_streaming.return_value = _as_batches(arrow_data)
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.sync_table(table_config, sync_state)
@ -145,35 +150,37 @@ class TestFullRefresh:
assert read_back.column_names == ["id", "name"]
def test_applies_date_columns(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
    """full_refresh should call convert_date_columns_to_date32 per batch."""
    table_config = _make_table_config()
    parquet_path = tmp_parquet_dir / "orders.parquet"
    mock_config.get_parquet_path.return_value = parquet_path

    arrow_data = _sample_arrow_table([1], ["Alice"])
    mock_bq_client.read_table_streaming.return_value = _as_batches(arrow_data)
    mock_bq_client.get_date_columns.return_value = ["created_at"]

    with patch(
        "connectors.bigquery.adapter.convert_date_columns_to_date32",
        return_value=arrow_data,
    ) as mock_conv:
        adapter = _create_adapter(mock_config, mock_bq_client)
        adapter.sync_table(table_config, sync_state)

    # The sample table yields a single batch, so exactly one call is expected.
    mock_conv.assert_called_once()
    passed_table, passed_columns = mock_conv.call_args[0]
    # The adapter rebuilds a table from each batch, so compare by value
    # rather than by identity.
    assert passed_table.equals(arrow_data)
    assert passed_columns == ["created_at"]
def test_applies_pyarrow_schema(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state):
    """full_refresh should call apply_schema_to_table per batch."""
    table_config = _make_table_config()
    parquet_path = tmp_parquet_dir / "orders.parquet"
    mock_config.get_parquet_path.return_value = parquet_path

    arrow_data = _sample_arrow_table([1], ["Alice"])
    mock_bq_client.read_table_streaming.return_value = _as_batches(arrow_data)
    schema = pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.string())])
    mock_bq_client.get_pyarrow_schema.return_value = schema

    with patch(
        "connectors.bigquery.adapter.apply_schema_to_table",
        return_value=arrow_data,
    ) as mock_apply:
        adapter = _create_adapter(mock_config, mock_bq_client)
        adapter.sync_table(table_config, sync_state)

    # The sample table yields a single batch, so exactly one call is expected.
    mock_apply.assert_called_once()
    passed_table, passed_schema = mock_apply.call_args[0]
    # The adapter rebuilds a table from each batch, so compare by value
    # rather than by identity.
    assert passed_table.equals(arrow_data)
    assert passed_schema == schema
# ---------------------------------------------------------------------------
@ -396,7 +403,7 @@ class TestErrorHandling:
"""When BigQuery query raises, sync_table returns {success: False, error: ...}."""
table_config = _make_table_config()
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
mock_bq_client.read_table.side_effect = RuntimeError("BigQuery API timeout")
mock_bq_client.read_table_streaming.side_effect = RuntimeError("BigQuery API timeout")
adapter = _create_adapter(mock_config, mock_bq_client)
result = adapter.sync_table(table_config, sync_state)
@ -538,7 +545,7 @@ class TestSyncTableDispatch:
"""sync_strategy='full_refresh' should call _full_refresh."""
table_config = _make_table_config(sync_strategy="full_refresh")
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"]))
adapter = _create_adapter(mock_config, mock_bq_client)
@ -610,7 +617,7 @@ class TestSyncTableDispatch:
incremental_window_days=7,
)
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"]))
adapter = _create_adapter(mock_config, mock_bq_client)
@ -708,7 +715,7 @@ class TestMetadataCacheClearing:
table_config = _make_table_config()
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"])
mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"]))
# Pre-populate cache
mock_bq_client.metadata_cache[table_config.id] = {"some": "cached_data"}
@ -728,7 +735,7 @@ class TestSyncStateUpdate:
table_config = _make_table_config()
parquet_path = tmp_parquet_dir / "orders.parquet"
mock_config.get_parquet_path.return_value = parquet_path
mock_bq_client.read_table.return_value = _sample_arrow_table([1, 2], ["A", "B"])
mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1, 2], ["A", "B"]))
adapter = _create_adapter(mock_config, mock_bq_client)
adapter.sync_table(table_config, sync_state)
@ -745,7 +752,7 @@ class TestSyncStateUpdate:
"""On sync failure, the sync state should NOT be updated."""
table_config = _make_table_config()
mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet"
mock_bq_client.read_table.side_effect = RuntimeError("boom")
mock_bq_client.read_table_streaming.side_effect = RuntimeError("boom")
adapter = _create_adapter(mock_config, mock_bq_client)
adapter.sync_table(table_config, sync_state)