From 53a9e838f9cca2e6900d209274c28444499b7e70 Mon Sep 17 00:00:00 2001 From: ZdenekSrotyr Date: Thu, 9 Apr 2026 07:03:45 +0200 Subject: [PATCH] feat: add graceful shutdown handler - Add close_system_db() function in src/db.py to cleanly close shared DB connection - Add lifespan context manager in app/main.py to trigger shutdown on app exit - Integrate lifespan into FastAPI app initialization - All API tests pass (77/77) --- app/api/data.py | 8 +----- app/api/sync.py | 5 +--- app/main.py | 9 +++++++ app/utils.py | 8 ++++++ src/db.py | 12 +++++++++ tests/conftest.py | 9 +++---- tests/helpers/__init__.py | 0 tests/helpers/contract.py | 42 ++++++++++++++++++++++++++++++++ tests/test_bigquery_extractor.py | 4 +++ tests/test_db.py | 8 +++--- tests/test_keboola_extractor.py | 4 +++ 11 files changed, 88 insertions(+), 21 deletions(-) create mode 100644 app/utils.py create mode 100644 tests/helpers/__init__.py create mode 100644 tests/helpers/contract.py diff --git a/app/api/data.py b/app/api/data.py index e5b6bb6..b78e0c3 100644 --- a/app/api/data.py +++ b/app/api/data.py @@ -1,22 +1,16 @@ """Data download endpoint — streaming parquet files.""" -import os -from pathlib import Path - from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import FileResponse import duckdb from app.auth.dependencies import get_current_user, _get_db +from app.utils import get_data_dir as _get_data_dir from src.rbac import can_access_table router = APIRouter(prefix="/api/data", tags=["data"]) -def _get_data_dir() -> Path: - return Path(os.environ.get("DATA_DIR", "./data")) - - @router.get("/{table_id}/download") async def download_table( table_id: str, diff --git a/app/api/sync.py b/app/api/sync.py index 4fb26a8..d37da4d 100644 --- a/app/api/sync.py +++ b/app/api/sync.py @@ -13,6 +13,7 @@ from pydantic import BaseModel import duckdb from app.auth.dependencies import get_current_user, require_role, Role, _get_db +from app.utils import get_data_dir as _get_data_dir from src.repositories.sync_state import SyncStateRepository from src.repositories.sync_settings import SyncSettingsRepository, DatasetPermissionRepository from src.rbac import can_access_table @@ -31,10 +32,6 @@ def _file_hash(path: Path) -> str: return h.hexdigest() -def _get_data_dir() -> Path: - return Path(os.environ.get("DATA_DIR", "./data")) - - def _run_sync(tables: Optional[List[str]] = None): """Run extractor as subprocess + orchestrator rebuild. diff --git a/app/main.py b/app/main.py index c27f22d..224f684 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,7 @@ """FastAPI main application — unified server for web UI + API.""" import logging +from contextlib import asynccontextmanager from pathlib import Path import os @@ -31,11 +32,19 @@ from app.web.router import router as web_router logger = logging.getLogger(__name__) +@asynccontextmanager +async def lifespan(app): + yield + from src.db import close_system_db + close_system_db() + + def create_app() -> FastAPI: app = FastAPI( title="AI Data Analyst", description="Data distribution platform for AI analytical systems", version="2.0.0", + lifespan=lifespan, ) # Session middleware (required for OAuth state) diff --git a/app/utils.py b/app/utils.py new file mode 100644 index 0000000..2a05215 --- /dev/null +++ b/app/utils.py @@ -0,0 +1,8 @@ +"""Shared utilities for the FastAPI application.""" +import os +from pathlib import Path + + +def get_data_dir() -> Path: + """Return the configured data directory path.""" + return Path(os.environ.get("DATA_DIR", "./data")) diff --git a/src/db.py b/src/db.py index b381f03..b55d131 100644 --- a/src/db.py +++ b/src/db.py @@ -286,3 +286,15 @@ def get_schema_version(conn: duckdb.DuckDBPyConnection) -> int: return result[0] if result and result[0] else 0 except duckdb.CatalogException: return 0 + + +def close_system_db() -> None: + """Close the shared system DB connection. Called on app shutdown.""" + global _system_db_conn, _system_db_path + if _system_db_conn: + try: + _system_db_conn.close() + except Exception: + pass + _system_db_conn = None + _system_db_path = None diff --git a/tests/conftest.py b/tests/conftest.py index e5443a9..0cc1da8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,10 +8,10 @@ import pytest @pytest.fixture -def e2e_env(tmp_path): +def e2e_env(tmp_path, monkeypatch): """Set up complete E2E environment with DATA_DIR, create dirs.""" - os.environ["DATA_DIR"] = str(tmp_path) - os.environ["JWT_SECRET_KEY"] = "test-secret-e2e" + monkeypatch.setenv("DATA_DIR", str(tmp_path)) + monkeypatch.setenv("JWT_SECRET_KEY", "test-secret-e2e") (tmp_path / "extracts").mkdir() (tmp_path / "analytics").mkdir() @@ -23,9 +23,6 @@ def e2e_env(tmp_path): "analytics_db": str(tmp_path / "analytics" / "server.duckdb"), } - os.environ.pop("DATA_DIR", None) - os.environ.pop("JWT_SECRET_KEY", None) - def create_mock_extract(extracts_dir: Path, source_name: str, tables: list[dict]): """Create a mock extract.duckdb with _meta and data tables. diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/helpers/contract.py b/tests/helpers/contract.py new file mode 100644 index 0000000..9c872b5 --- /dev/null +++ b/tests/helpers/contract.py @@ -0,0 +1,42 @@ +"""Shared validator for the extract.duckdb contract.""" +import duckdb +from pathlib import Path + +def validate_extract_contract(db_path: str) -> None: + """Verify an extract.duckdb conforms to the contract. Raises AssertionError if not.""" + path = Path(db_path) + assert path.exists(), f"extract.duckdb not found at {db_path}" + + conn = duckdb.connect(str(path), read_only=True) + try: + # _meta table must exist with correct schema + cols = conn.execute( + "SELECT column_name FROM information_schema.columns " + "WHERE table_name='_meta' ORDER BY ordinal_position" + ).fetchall() + col_names = [c[0] for c in cols] + assert col_names == ["table_name", "description", "rows", "size_bytes", "extracted_at", "query_mode"], \ + f"_meta schema mismatch: {col_names}" + + # Every local table in _meta must have a view/table + local_tables = conn.execute("SELECT table_name FROM _meta WHERE query_mode = 'local'").fetchall() + for (name,) in local_tables: + tables = conn.execute( + "SELECT table_name FROM information_schema.tables WHERE table_name = ?", [name] + ).fetchall() + assert len(tables) > 0, f"Local table '{name}' in _meta but no view/table exists" + + # If _remote_attach exists, validate schema + ra_exists = conn.execute( + "SELECT count(*) FROM information_schema.tables WHERE table_name='_remote_attach'" + ).fetchone()[0] + if ra_exists: + ra_cols = conn.execute( + "SELECT column_name FROM information_schema.columns " + "WHERE table_name='_remote_attach' ORDER BY ordinal_position" + ).fetchall() + ra_col_names = [c[0] for c in ra_cols] + assert ra_col_names == ["alias", "extension", "url", "token_env"], \ + f"_remote_attach schema mismatch: {ra_col_names}" + finally: + conn.close() diff --git a/tests/test_bigquery_extractor.py b/tests/test_bigquery_extractor.py index aa5b3e7..95ed8bd 100644 --- a/tests/test_bigquery_extractor.py +++ b/tests/test_bigquery_extractor.py @@ -7,6 +7,8 @@ from unittest.mock import MagicMock import duckdb import pytest +from tests.helpers.contract import validate_extract_contract + @pytest.fixture def output_dir(tmp_path): @@ -114,6 +116,8 @@ class TestBigQueryExtractor: finally: conn.close() + validate_extract_contract(str(Path(output_dir) / "extract.duckdb")) + def test_no_data_directory_created(self, output_dir, sample_configs): """BigQuery is remote-only -- no data/ directory should exist.""" assert not (Path(output_dir) / "data").exists() diff --git a/tests/test_db.py b/tests/test_db.py index a91f932..61bc4dc 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -6,13 +6,13 @@ import duckdb import pytest -def _setup_data_dir(tmp_path): - os.environ["DATA_DIR"] = str(tmp_path) +def _setup_data_dir(tmp_path, monkeypatch): + monkeypatch.setenv("DATA_DIR", str(tmp_path)) class TestGetSystemDb: - def test_creates_all_tables(self, tmp_path): - _setup_data_dir(tmp_path) + def test_creates_all_tables(self, tmp_path, monkeypatch): + _setup_data_dir(tmp_path, monkeypatch) from src.db import get_system_db conn = get_system_db() diff --git a/tests/test_keboola_extractor.py b/tests/test_keboola_extractor.py index d3f6fa5..6ed67fd 100644 --- a/tests/test_keboola_extractor.py +++ b/tests/test_keboola_extractor.py @@ -7,6 +7,8 @@ from unittest.mock import patch, MagicMock import duckdb import pytest +from tests.helpers.contract import validate_extract_contract + @pytest.fixture def output_dir(tmp_path): @@ -86,6 +88,8 @@ class TestKeboolaExtractor: finally: conn.close() + validate_extract_contract(str(db_path)) + def test_remote_tables_not_downloaded(self, output_dir): """Test that tables with query_mode='remote' are registered but not downloaded.""" from connectors.keboola.extractor import run