feat: add graceful shutdown handler
- Add close_system_db() function in src/db.py to cleanly close shared DB connection - Add lifespan context manager in app/main.py to trigger shutdown on app exit - Integrate lifespan into FastAPI app initialization - All API tests pass (77/77)
This commit is contained in:
parent
d6458a5b53
commit
53a9e838f9
11 changed files with 88 additions and 21 deletions
|
|
@ -1,22 +1,16 @@
|
||||||
"""Data download endpoint — streaming parquet files."""
|
"""Data download endpoint — streaming parquet files."""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
import duckdb
|
import duckdb
|
||||||
|
|
||||||
from app.auth.dependencies import get_current_user, _get_db
|
from app.auth.dependencies import get_current_user, _get_db
|
||||||
|
from app.utils import get_data_dir as _get_data_dir
|
||||||
from src.rbac import can_access_table
|
from src.rbac import can_access_table
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/data", tags=["data"])
|
router = APIRouter(prefix="/api/data", tags=["data"])
|
||||||
|
|
||||||
|
|
||||||
def _get_data_dir() -> Path:
|
|
||||||
return Path(os.environ.get("DATA_DIR", "./data"))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{table_id}/download")
|
@router.get("/{table_id}/download")
|
||||||
async def download_table(
|
async def download_table(
|
||||||
table_id: str,
|
table_id: str,
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from pydantic import BaseModel
|
||||||
import duckdb
|
import duckdb
|
||||||
|
|
||||||
from app.auth.dependencies import get_current_user, require_role, Role, _get_db
|
from app.auth.dependencies import get_current_user, require_role, Role, _get_db
|
||||||
|
from app.utils import get_data_dir as _get_data_dir
|
||||||
from src.repositories.sync_state import SyncStateRepository
|
from src.repositories.sync_state import SyncStateRepository
|
||||||
from src.repositories.sync_settings import SyncSettingsRepository, DatasetPermissionRepository
|
from src.repositories.sync_settings import SyncSettingsRepository, DatasetPermissionRepository
|
||||||
from src.rbac import can_access_table
|
from src.rbac import can_access_table
|
||||||
|
|
@ -31,10 +32,6 @@ def _file_hash(path: Path) -> str:
|
||||||
return h.hexdigest()
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def _get_data_dir() -> Path:
|
|
||||||
return Path(os.environ.get("DATA_DIR", "./data"))
|
|
||||||
|
|
||||||
|
|
||||||
def _run_sync(tables: Optional[List[str]] = None):
|
def _run_sync(tables: Optional[List[str]] = None):
|
||||||
"""Run extractor as subprocess + orchestrator rebuild.
|
"""Run extractor as subprocess + orchestrator rebuild.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""FastAPI main application — unified server for web UI + API."""
|
"""FastAPI main application — unified server for web UI + API."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -31,11 +32,19 @@ from app.web.router import router as web_router
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app):
|
||||||
|
yield
|
||||||
|
from src.db import close_system_db
|
||||||
|
close_system_db()
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="AI Data Analyst",
|
title="AI Data Analyst",
|
||||||
description="Data distribution platform for AI analytical systems",
|
description="Data distribution platform for AI analytical systems",
|
||||||
version="2.0.0",
|
version="2.0.0",
|
||||||
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Session middleware (required for OAuth state)
|
# Session middleware (required for OAuth state)
|
||||||
|
|
|
||||||
8
app/utils.py
Normal file
8
app/utils.py
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
"""Shared utilities for the FastAPI application."""
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_dir() -> Path:
|
||||||
|
"""Return the configured data directory path."""
|
||||||
|
return Path(os.environ.get("DATA_DIR", "./data"))
|
||||||
12
src/db.py
12
src/db.py
|
|
@ -286,3 +286,15 @@ def get_schema_version(conn: duckdb.DuckDBPyConnection) -> int:
|
||||||
return result[0] if result and result[0] else 0
|
return result[0] if result and result[0] else 0
|
||||||
except duckdb.CatalogException:
|
except duckdb.CatalogException:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def close_system_db() -> None:
|
||||||
|
"""Close the shared system DB connection. Called on app shutdown."""
|
||||||
|
global _system_db_conn, _system_db_path
|
||||||
|
if _system_db_conn:
|
||||||
|
try:
|
||||||
|
_system_db_conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
_system_db_conn = None
|
||||||
|
_system_db_path = None
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,10 @@ import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def e2e_env(tmp_path):
|
def e2e_env(tmp_path, monkeypatch):
|
||||||
"""Set up complete E2E environment with DATA_DIR, create dirs."""
|
"""Set up complete E2E environment with DATA_DIR, create dirs."""
|
||||||
os.environ["DATA_DIR"] = str(tmp_path)
|
monkeypatch.setenv("DATA_DIR", str(tmp_path))
|
||||||
os.environ["JWT_SECRET_KEY"] = "test-secret-e2e"
|
monkeypatch.setenv("JWT_SECRET_KEY", "test-secret-e2e")
|
||||||
|
|
||||||
(tmp_path / "extracts").mkdir()
|
(tmp_path / "extracts").mkdir()
|
||||||
(tmp_path / "analytics").mkdir()
|
(tmp_path / "analytics").mkdir()
|
||||||
|
|
@ -23,9 +23,6 @@ def e2e_env(tmp_path):
|
||||||
"analytics_db": str(tmp_path / "analytics" / "server.duckdb"),
|
"analytics_db": str(tmp_path / "analytics" / "server.duckdb"),
|
||||||
}
|
}
|
||||||
|
|
||||||
os.environ.pop("DATA_DIR", None)
|
|
||||||
os.environ.pop("JWT_SECRET_KEY", None)
|
|
||||||
|
|
||||||
|
|
||||||
def create_mock_extract(extracts_dir: Path, source_name: str, tables: list[dict]):
|
def create_mock_extract(extracts_dir: Path, source_name: str, tables: list[dict]):
|
||||||
"""Create a mock extract.duckdb with _meta and data tables.
|
"""Create a mock extract.duckdb with _meta and data tables.
|
||||||
|
|
|
||||||
0
tests/helpers/__init__.py
Normal file
0
tests/helpers/__init__.py
Normal file
42
tests/helpers/contract.py
Normal file
42
tests/helpers/contract.py
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
"""Shared validator for the extract.duckdb contract."""
|
||||||
|
import duckdb
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def validate_extract_contract(db_path: str) -> None:
|
||||||
|
"""Verify an extract.duckdb conforms to the contract. Raises AssertionError if not."""
|
||||||
|
path = Path(db_path)
|
||||||
|
assert path.exists(), f"extract.duckdb not found at {db_path}"
|
||||||
|
|
||||||
|
conn = duckdb.connect(str(path), read_only=True)
|
||||||
|
try:
|
||||||
|
# _meta table must exist with correct schema
|
||||||
|
cols = conn.execute(
|
||||||
|
"SELECT column_name FROM information_schema.columns "
|
||||||
|
"WHERE table_name='_meta' ORDER BY ordinal_position"
|
||||||
|
).fetchall()
|
||||||
|
col_names = [c[0] for c in cols]
|
||||||
|
assert col_names == ["table_name", "description", "rows", "size_bytes", "extracted_at", "query_mode"], \
|
||||||
|
f"_meta schema mismatch: {col_names}"
|
||||||
|
|
||||||
|
# Every local table in _meta must have a view/table
|
||||||
|
local_tables = conn.execute("SELECT table_name FROM _meta WHERE query_mode = 'local'").fetchall()
|
||||||
|
for (name,) in local_tables:
|
||||||
|
tables = conn.execute(
|
||||||
|
"SELECT table_name FROM information_schema.tables WHERE table_name = ?", [name]
|
||||||
|
).fetchall()
|
||||||
|
assert len(tables) > 0, f"Local table '{name}' in _meta but no view/table exists"
|
||||||
|
|
||||||
|
# If _remote_attach exists, validate schema
|
||||||
|
ra_exists = conn.execute(
|
||||||
|
"SELECT count(*) FROM information_schema.tables WHERE table_name='_remote_attach'"
|
||||||
|
).fetchone()[0]
|
||||||
|
if ra_exists:
|
||||||
|
ra_cols = conn.execute(
|
||||||
|
"SELECT column_name FROM information_schema.columns "
|
||||||
|
"WHERE table_name='_remote_attach' ORDER BY ordinal_position"
|
||||||
|
).fetchall()
|
||||||
|
ra_col_names = [c[0] for c in ra_cols]
|
||||||
|
assert ra_col_names == ["alias", "extension", "url", "token_env"], \
|
||||||
|
f"_remote_attach schema mismatch: {ra_col_names}"
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
@ -7,6 +7,8 @@ from unittest.mock import MagicMock
|
||||||
import duckdb
|
import duckdb
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.helpers.contract import validate_extract_contract
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def output_dir(tmp_path):
|
def output_dir(tmp_path):
|
||||||
|
|
@ -114,6 +116,8 @@ class TestBigQueryExtractor:
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
validate_extract_contract(str(Path(output_dir) / "extract.duckdb"))
|
||||||
|
|
||||||
def test_no_data_directory_created(self, output_dir, sample_configs):
|
def test_no_data_directory_created(self, output_dir, sample_configs):
|
||||||
"""BigQuery is remote-only -- no data/ directory should exist."""
|
"""BigQuery is remote-only -- no data/ directory should exist."""
|
||||||
assert not (Path(output_dir) / "data").exists()
|
assert not (Path(output_dir) / "data").exists()
|
||||||
|
|
|
||||||
|
|
@ -6,13 +6,13 @@ import duckdb
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def _setup_data_dir(tmp_path):
|
def _setup_data_dir(tmp_path, monkeypatch):
|
||||||
os.environ["DATA_DIR"] = str(tmp_path)
|
monkeypatch.setenv("DATA_DIR", str(tmp_path))
|
||||||
|
|
||||||
|
|
||||||
class TestGetSystemDb:
|
class TestGetSystemDb:
|
||||||
def test_creates_all_tables(self, tmp_path):
|
def test_creates_all_tables(self, tmp_path, monkeypatch):
|
||||||
_setup_data_dir(tmp_path)
|
_setup_data_dir(tmp_path, monkeypatch)
|
||||||
from src.db import get_system_db
|
from src.db import get_system_db
|
||||||
|
|
||||||
conn = get_system_db()
|
conn = get_system_db()
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ from unittest.mock import patch, MagicMock
|
||||||
import duckdb
|
import duckdb
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.helpers.contract import validate_extract_contract
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def output_dir(tmp_path):
|
def output_dir(tmp_path):
|
||||||
|
|
@ -86,6 +88,8 @@ class TestKeboolaExtractor:
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
validate_extract_contract(str(db_path))
|
||||||
|
|
||||||
def test_remote_tables_not_downloaded(self, output_dir):
|
def test_remote_tables_not_downloaded(self, output_dir):
|
||||||
"""Test that tables with query_mode='remote' are registered but not downloaded."""
|
"""Test that tables with query_mode='remote' are registered but not downloaded."""
|
||||||
from connectors.keboola.extractor import run
|
from connectors.keboola.extractor import run
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue