feat: add graceful shutdown handler
- Add close_system_db() function in src/db.py to cleanly close the shared DB connection
- Add lifespan context manager in app/main.py to trigger shutdown on app exit
- Integrate lifespan into FastAPI app initialization
- All API tests pass (77/77)
This commit is contained in:
parent
d6458a5b53
commit
53a9e838f9
11 changed files with 88 additions and 21 deletions
|
|
@ -1,22 +1,16 @@
|
|||
"""Data download endpoint — streaming parquet files."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import FileResponse
|
||||
import duckdb
|
||||
|
||||
from app.auth.dependencies import get_current_user, _get_db
|
||||
from app.utils import get_data_dir as _get_data_dir
|
||||
from src.rbac import can_access_table
|
||||
|
||||
router = APIRouter(prefix="/api/data", tags=["data"])
|
||||
|
||||
|
||||
def _get_data_dir() -> Path:
|
||||
return Path(os.environ.get("DATA_DIR", "./data"))
|
||||
|
||||
|
||||
@router.get("/{table_id}/download")
|
||||
async def download_table(
|
||||
table_id: str,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from pydantic import BaseModel
|
|||
import duckdb
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_role, Role, _get_db
|
||||
from app.utils import get_data_dir as _get_data_dir
|
||||
from src.repositories.sync_state import SyncStateRepository
|
||||
from src.repositories.sync_settings import SyncSettingsRepository, DatasetPermissionRepository
|
||||
from src.rbac import can_access_table
|
||||
|
|
@ -31,10 +32,6 @@ def _file_hash(path: Path) -> str:
|
|||
return h.hexdigest()
|
||||
|
||||
|
||||
def _get_data_dir() -> Path:
|
||||
return Path(os.environ.get("DATA_DIR", "./data"))
|
||||
|
||||
|
||||
def _run_sync(tables: Optional[List[str]] = None):
|
||||
"""Run extractor as subprocess + orchestrator rebuild.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""FastAPI main application — unified server for web UI + API."""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
|
|
@ -31,11 +32,19 @@ from app.web.router import router as web_router
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app):
    """Application lifespan hook that closes the shared system DB on exit.

    Nothing is done before ``yield`` (no startup work). Everything after the
    ``yield`` is shutdown work.

    The cleanup sits in a ``finally`` block so that it also runs when the
    context is left because an exception (or cancellation) was thrown in at
    the ``yield`` point, not only on a clean exit.

    Args:
        app: The application instance handed in by the framework. Unused.
    """
    try:
        yield
    finally:
        # Imported at call time; this function is the only user in this block.
        from src.db import close_system_db

        close_system_db()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="AI Data Analyst",
|
||||
description="Data distribution platform for AI analytical systems",
|
||||
version="2.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Session middleware (required for OAuth state)
|
||||
|
|
|
|||
8
app/utils.py
Normal file
8
app/utils.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
"""Shared utilities for the FastAPI application."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_data_dir() -> Path:
    """Resolve the data directory from the ``DATA_DIR`` environment variable.

    Returns:
        ``Path`` built from ``DATA_DIR``, or ``Path("./data")`` when the
        variable is not set.
    """
    configured = os.environ.get("DATA_DIR", "./data")
    return Path(configured)
|
||||
12
src/db.py
12
src/db.py
|
|
@ -286,3 +286,15 @@ def get_schema_version(conn: duckdb.DuckDBPyConnection) -> int:
|
|||
return result[0] if result and result[0] else 0
|
||||
except duckdb.CatalogException:
|
||||
return 0
|
||||
|
||||
|
||||
def close_system_db() -> None:
    """Close the shared system DB connection. Called on app shutdown.

    The module-level ``_system_db_conn`` and ``_system_db_path`` are always
    reset to ``None``, so the function is safe to call when nothing is open
    and safe to call more than once.

    A failure while closing is logged as a warning instead of being raised,
    so shutdown is not interrupted by it.
    """
    global _system_db_conn, _system_db_path

    conn = _system_db_conn
    # Drop the shared references first so no caller can pick up a connection
    # that is in the middle of being closed.
    _system_db_conn = None
    _system_db_path = None

    if conn is None:
        return

    try:
        conn.close()
    except Exception:
        # Best-effort close: report the problem rather than hiding it.
        import logging

        logging.getLogger(__name__).warning(
            "Failed to close system DB connection", exc_info=True
        )
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ import pytest
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def e2e_env(tmp_path):
|
||||
def e2e_env(tmp_path, monkeypatch):
|
||||
"""Set up complete E2E environment with DATA_DIR, create dirs."""
|
||||
os.environ["DATA_DIR"] = str(tmp_path)
|
||||
os.environ["JWT_SECRET_KEY"] = "test-secret-e2e"
|
||||
monkeypatch.setenv("DATA_DIR", str(tmp_path))
|
||||
monkeypatch.setenv("JWT_SECRET_KEY", "test-secret-e2e")
|
||||
|
||||
(tmp_path / "extracts").mkdir()
|
||||
(tmp_path / "analytics").mkdir()
|
||||
|
|
@ -23,9 +23,6 @@ def e2e_env(tmp_path):
|
|||
"analytics_db": str(tmp_path / "analytics" / "server.duckdb"),
|
||||
}
|
||||
|
||||
os.environ.pop("DATA_DIR", None)
|
||||
os.environ.pop("JWT_SECRET_KEY", None)
|
||||
|
||||
|
||||
def create_mock_extract(extracts_dir: Path, source_name: str, tables: list[dict]):
|
||||
"""Create a mock extract.duckdb with _meta and data tables.
|
||||
|
|
|
|||
0
tests/helpers/__init__.py
Normal file
0
tests/helpers/__init__.py
Normal file
42
tests/helpers/contract.py
Normal file
42
tests/helpers/contract.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""Shared validator for the extract.duckdb contract."""
|
||||
import duckdb
|
||||
from pathlib import Path
|
||||
|
||||
def validate_extract_contract(db_path: str) -> None:
    """Check that the DuckDB file at *db_path* satisfies the extract contract.

    Checks, in order:
      1. the file exists;
      2. ``_meta`` has exactly the expected columns, in ordinal order;
      3. every ``_meta`` row with ``query_mode = 'local'`` names an entry in
         ``information_schema.tables``;
      4. if ``_remote_attach`` is present, it has exactly the expected
         columns, in ordinal order.

    The database is opened read-only and is always closed before returning.

    Raises:
        AssertionError: on the first check that fails.
    """
    expected_meta = ["table_name", "description", "rows", "size_bytes", "extracted_at", "query_mode"]
    expected_remote_attach = ["alias", "extension", "url", "token_env"]

    db_file = Path(db_path)
    assert db_file.exists(), f"extract.duckdb not found at {db_path}"

    conn = duckdb.connect(str(db_file), read_only=True)

    def first_column(query):
        # Run a parameterless query and keep only column 0 of every row.
        return [row[0] for row in conn.execute(query).fetchall()]

    try:
        # _meta table must exist with correct schema
        col_names = first_column(
            "SELECT column_name FROM information_schema.columns "
            "WHERE table_name='_meta' ORDER BY ordinal_position"
        )
        assert col_names == expected_meta, \
            f"_meta schema mismatch: {col_names}"

        # Every local table in _meta must have a view/table
        for name in first_column("SELECT table_name FROM _meta WHERE query_mode = 'local'"):
            matches = conn.execute(
                "SELECT table_name FROM information_schema.tables WHERE table_name = ?", [name]
            ).fetchall()
            assert len(matches) > 0, f"Local table '{name}' in _meta but no view/table exists"

        # If _remote_attach exists, validate schema
        remote_attach_count = conn.execute(
            "SELECT count(*) FROM information_schema.tables WHERE table_name='_remote_attach'"
        ).fetchone()[0]
        if not remote_attach_count:
            return

        ra_col_names = first_column(
            "SELECT column_name FROM information_schema.columns "
            "WHERE table_name='_remote_attach' ORDER BY ordinal_position"
        )
        assert ra_col_names == expected_remote_attach, \
            f"_remote_attach schema mismatch: {ra_col_names}"
    finally:
        conn.close()
|
||||
|
|
@ -7,6 +7,8 @@ from unittest.mock import MagicMock
|
|||
import duckdb
|
||||
import pytest
|
||||
|
||||
from tests.helpers.contract import validate_extract_contract
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def output_dir(tmp_path):
|
||||
|
|
@ -114,6 +116,8 @@ class TestBigQueryExtractor:
|
|||
finally:
|
||||
conn.close()
|
||||
|
||||
validate_extract_contract(str(Path(output_dir) / "extract.duckdb"))
|
||||
|
||||
def test_no_data_directory_created(self, output_dir, sample_configs):
|
||||
"""BigQuery is remote-only -- no data/ directory should exist."""
|
||||
assert not (Path(output_dir) / "data").exists()
|
||||
|
|
|
|||
|
|
@ -6,13 +6,13 @@ import duckdb
|
|||
import pytest
|
||||
|
||||
|
||||
def _setup_data_dir(tmp_path):
|
||||
os.environ["DATA_DIR"] = str(tmp_path)
|
||||
def _setup_data_dir(tmp_path, monkeypatch):
    """Set ``DATA_DIR`` to ``str(tmp_path)`` via ``monkeypatch.setenv``."""
    data_dir = str(tmp_path)
    monkeypatch.setenv("DATA_DIR", data_dir)
|
||||
|
||||
|
||||
class TestGetSystemDb:
|
||||
def test_creates_all_tables(self, tmp_path):
|
||||
_setup_data_dir(tmp_path)
|
||||
def test_creates_all_tables(self, tmp_path, monkeypatch):
|
||||
_setup_data_dir(tmp_path, monkeypatch)
|
||||
from src.db import get_system_db
|
||||
|
||||
conn = get_system_db()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from unittest.mock import patch, MagicMock
|
|||
import duckdb
|
||||
import pytest
|
||||
|
||||
from tests.helpers.contract import validate_extract_contract
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def output_dir(tmp_path):
|
||||
|
|
@ -86,6 +88,8 @@ class TestKeboolaExtractor:
|
|||
finally:
|
||||
conn.close()
|
||||
|
||||
validate_extract_contract(str(db_path))
|
||||
|
||||
def test_remote_tables_not_downloaded(self, output_dir):
|
||||
"""Test that tables with query_mode='remote' are registered but not downloaded."""
|
||||
from connectors.keboola.extractor import run
|
||||
|
|
|
|||
Loading…
Reference in a new issue