diff --git a/connectors/bigquery/adapter.py b/connectors/bigquery/adapter.py index 62ecd30..c7992d1 100644 --- a/connectors/bigquery/adapter.py +++ b/connectors/bigquery/adapter.py @@ -132,42 +132,66 @@ class BigQueryDataSource(DataSource): def _full_refresh(self, table_config: TableConfig) -> Dict[str, Any]: """ - Full refresh: read entire table and replace Parquet file. + Full refresh: stream table from BQ and write to Parquet in batches. + + Uses streaming (constant memory) instead of loading entire table into RAM. + Each RecordBatch from BQ is written directly to disk via ParquetWriter. """ - logger.info(f"Full refresh: {table_config.name}") + logger.info(f"Full refresh (streaming): {table_config.name}") parquet_path = self.config.get_parquet_path(table_config) date_columns = self.bq_client.get_date_columns(table_config.id) pyarrow_schema = self.bq_client.get_pyarrow_schema(table_config.id) - # Read full table from BigQuery -> PyArrow - arrow_table = self.bq_client.read_table( + parquet_path.parent.mkdir(parents=True, exist_ok=True) + + # Stream BQ results directly to Parquet file (constant memory) + writer = None + total_rows = 0 + num_columns = 0 + + for batch in self.bq_client.read_table_streaming( table_config.id, columns=table_config.columns, row_filter=table_config.row_filter, - ) + ): + if batch.num_rows == 0: + continue - # Apply schema enforcement - if date_columns: - arrow_table = convert_date_columns_to_date32(arrow_table, date_columns) - if pyarrow_schema: - arrow_table = apply_schema_to_table(arrow_table, pyarrow_schema) + # Convert batch to table for schema enforcement + chunk = pa.Table.from_batches([batch]) + if date_columns: + chunk = convert_date_columns_to_date32(chunk, date_columns) + if pyarrow_schema: + chunk = apply_schema_to_table(chunk, pyarrow_schema) - # Write to Parquet - parquet_path.parent.mkdir(parents=True, exist_ok=True) - pq.write_table(arrow_table, parquet_path, compression="snappy") + if writer is None: + writer = pq.ParquetWriter( + parquet_path, chunk.schema, compression="snappy", + ) + num_columns = chunk.num_columns - file_size = parquet_path.stat().st_size + writer.write_table(chunk) + total_rows += chunk.num_rows + + # Log progress every ~1M rows + if total_rows % 1_000_000 < chunk.num_rows: + logger.info(f" -> {total_rows:,} rows written...") + + if writer: + writer.close() + + file_size = parquet_path.stat().st_size if parquet_path.exists() else 0 logger.info( - f"Full refresh complete: {arrow_table.num_rows} rows, " + f"Full refresh complete: {total_rows:,} rows, " f"{file_size / 1024 / 1024:.2f} MB" ) return { - "rows": arrow_table.num_rows, - "columns": arrow_table.num_columns, + "rows": total_rows, + "columns": num_columns, "file_size_bytes": file_size, - "uncompressed_bytes": _get_uncompressed_size(parquet_path), + "uncompressed_bytes": _get_uncompressed_size(parquet_path) if total_rows > 0 else 0, } def _incremental_sync( diff --git a/connectors/bigquery/client.py b/connectors/bigquery/client.py index cd07f7e..c07aaad 100644 --- a/connectors/bigquery/client.py +++ b/connectors/bigquery/client.py @@ -292,6 +292,80 @@ class BigQueryClient: logger.debug(f"Query returned {arrow_table.num_rows} rows, {arrow_table.num_columns} columns") return arrow_table + def query_to_arrow_batches( + self, + sql: str, + params: Optional[List[bigquery.ScalarQueryParameter]] = None, + ): + """ + Execute SQL query and yield results as streaming RecordBatches. + + Unlike query_to_arrow(), this does NOT load entire result into memory. + Each RecordBatch is a small chunk (typically a few MB) that can be + written to disk immediately. + + Args: + sql: SQL query string (use @param_name for parameterized values) + params: List of BigQuery query parameters + + Yields: + pyarrow.RecordBatch objects + """ + job_config = bigquery.QueryJobConfig() + if params: + job_config.query_parameters = params + + logger.debug(f"Executing BQ query (streaming): {sql[:200]}...") + + query_job = self.client.query(sql, job_config=job_config) + + # Try BQ Storage API first (faster, parallel gRPC streams). + # Fall back to REST API if SA lacks bigquery.readsessions.create permission. + try: + batch_iter = query_job.to_arrow_iterable() + # Probe first batch to detect Storage API permission errors early + first_batch = next(batch_iter, None) + if first_batch is not None: + yield first_batch + yield from batch_iter + return + except Exception as storage_err: + if "readsessions" not in str(storage_err) and "PERMISSION_DENIED" not in str(storage_err): + raise + logger.warning( + "BQ Storage API unavailable (missing readsessions permission), " + "falling back to REST API (streaming)" + ) + + # Fallback: REST API streaming (same query_job, just different reader) + yield from query_job.to_arrow_iterable(create_bqstorage_client=False) + + def read_table_streaming( + self, + table_id: str, + columns: Optional[List[str]] = None, + row_filter: Optional[str] = None, + ): + """ + Read table as streaming RecordBatches (constant memory). + + Args: + table_id: Full table ID (e.g., "project.dataset.table") + columns: Optional list of columns to select + row_filter: Optional SQL WHERE clause (without WHERE keyword) + + Yields: + pyarrow.RecordBatch objects + """ + select_cols = ", ".join(f"`{c}`" for c in columns) if columns else "*" + + sql = f"SELECT {select_cols} FROM `{table_id}`" + if row_filter: + sql += f" WHERE {row_filter}" + + logger.info(f"Streaming BQ table: {table_id} (filter: {row_filter or 'none'})") + yield from self.query_to_arrow_batches(sql) + def read_table( self, table_id: str, diff --git a/tests/test_bigquery_adapter.py b/tests/test_bigquery_adapter.py index d85e11d..dafe0fa 100644 --- a/tests/test_bigquery_adapter.py +++ b/tests/test_bigquery_adapter.py @@ -104,6 +104,11 @@ def _sample_arrow_table(ids: list[int], names: list[str]) -> pa.Table: return pa.table({"id": ids, "name": names}) +def _as_batches(arrow_table: pa.Table) -> list: + """Convert Arrow table to list of RecordBatches (mimics streaming from BQ).""" + return arrow_table.to_batches() + + def _create_adapter(mock_config, mock_bq_client): """Instantiate BigQueryDataSource with mocked dependencies. @@ -124,13 +129,13 @@ def _create_adapter(mock_config, mock_bq_client): class TestFullRefresh: def test_writes_valid_parquet(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): - """full_refresh should write a valid, readable Parquet file.""" + """full_refresh should stream batches and write a valid Parquet file.""" table_config = _make_table_config(sync_strategy="full_refresh") parquet_path = tmp_parquet_dir / "orders.parquet" mock_config.get_parquet_path.return_value = parquet_path arrow_data = _sample_arrow_table([1, 2, 3], ["Alice", "Bob", "Charlie"]) - mock_bq_client.read_table.return_value = arrow_data + mock_bq_client.read_table_streaming.return_value = _as_batches(arrow_data) adapter = _create_adapter(mock_config, mock_bq_client) result = adapter.sync_table(table_config, sync_state) @@ -145,35 +150,37 @@ class TestFullRefresh: assert read_back.column_names == ["id", "name"] def test_applies_date_columns(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): - """full_refresh should call convert_date_columns_to_date32 when date columns exist.""" + """full_refresh should call convert_date_columns_to_date32 per batch.""" table_config = _make_table_config() parquet_path = tmp_parquet_dir / "orders.parquet" mock_config.get_parquet_path.return_value = parquet_path arrow_data = _sample_arrow_table([1], ["Alice"]) - mock_bq_client.read_table.return_value = arrow_data + mock_bq_client.read_table_streaming.return_value = _as_batches(arrow_data) mock_bq_client.get_date_columns.return_value = ["created_at"] with patch("connectors.bigquery.adapter.convert_date_columns_to_date32", return_value=arrow_data) as mock_conv: adapter = _create_adapter(mock_config, mock_bq_client) adapter.sync_table(table_config, sync_state) - mock_conv.assert_called_once_with(arrow_data, ["created_at"]) + mock_conv.assert_called_once() + assert mock_conv.call_args[0][1] == ["created_at"] def test_applies_pyarrow_schema(self, mock_config, mock_bq_client, tmp_parquet_dir, sync_state): - """full_refresh should call apply_schema_to_table when schema is available.""" + """full_refresh should call apply_schema_to_table per batch.""" table_config = _make_table_config() parquet_path = tmp_parquet_dir / "orders.parquet" mock_config.get_parquet_path.return_value = parquet_path arrow_data = _sample_arrow_table([1], ["Alice"]) - mock_bq_client.read_table.return_value = arrow_data + mock_bq_client.read_table_streaming.return_value = _as_batches(arrow_data) schema = pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.string())]) mock_bq_client.get_pyarrow_schema.return_value = schema with patch("connectors.bigquery.adapter.apply_schema_to_table", return_value=arrow_data) as mock_apply: adapter = _create_adapter(mock_config, mock_bq_client) adapter.sync_table(table_config, sync_state) - mock_apply.assert_called_once_with(arrow_data, schema) + mock_apply.assert_called_once() + assert mock_apply.call_args[0][1] == schema # --------------------------------------------------------------------------- @@ -396,7 +403,7 @@ class TestErrorHandling: """When BigQuery query raises, sync_table returns {success: False, error: ...}.""" table_config = _make_table_config() mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" - mock_bq_client.read_table.side_effect = RuntimeError("BigQuery API timeout") + mock_bq_client.read_table_streaming.side_effect = RuntimeError("BigQuery API timeout") adapter = _create_adapter(mock_config, mock_bq_client) result = adapter.sync_table(table_config, sync_state) @@ -538,7 +545,7 @@ class TestSyncTableDispatch: """sync_strategy='full_refresh' should call _full_refresh.""" table_config = _make_table_config(sync_strategy="full_refresh") mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" - mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"]) + mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"])) adapter = _create_adapter(mock_config, mock_bq_client) @@ -610,7 +617,7 @@ class TestSyncTableDispatch: incremental_window_days=7, ) mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" - mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"]) + mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"])) adapter = _create_adapter(mock_config, mock_bq_client) @@ -708,7 +715,7 @@ class TestMetadataCacheClearing: table_config = _make_table_config() parquet_path = tmp_parquet_dir / "orders.parquet" mock_config.get_parquet_path.return_value = parquet_path - mock_bq_client.read_table.return_value = _sample_arrow_table([1], ["A"]) + mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1], ["A"])) # Pre-populate cache mock_bq_client.metadata_cache[table_config.id] = {"some": "cached_data"} @@ -728,7 +735,7 @@ class TestSyncStateUpdate: table_config = _make_table_config() parquet_path = tmp_parquet_dir / "orders.parquet" mock_config.get_parquet_path.return_value = parquet_path - mock_bq_client.read_table.return_value = _sample_arrow_table([1, 2], ["A", "B"]) + mock_bq_client.read_table_streaming.return_value = _as_batches(_sample_arrow_table([1, 2], ["A", "B"])) adapter = _create_adapter(mock_config, mock_bq_client) adapter.sync_table(table_config, sync_state) @@ -745,7 +752,7 @@ class TestSyncStateUpdate: """On sync failure, the sync state should NOT be updated.""" table_config = _make_table_config() mock_config.get_parquet_path.return_value = tmp_parquet_dir / "orders.parquet" - mock_bq_client.read_table.side_effect = RuntimeError("boom") + mock_bq_client.read_table_streaming.side_effect = RuntimeError("boom") adapter = _create_adapter(mock_config, mock_bq_client) adapter.sync_table(table_config, sync_state)