fix: address code review — path injection, multi-table search, metrics import API, error handling

- Validate view names with _SAFE_IDENTIFIER regex and check path traversal in _initialize_duckdb()
- find_by_table() and get_table_map() now also search the tables[] array field
- Add POST /api/admin/metrics/import endpoint for YAML file upload
- Replace generic except in _connect_to_instance() with specific HTTPStatusError/TimeoutException handlers
- Generate .claude/settings.json in _generate_claude_md() bootstrap
- Update test_find_by_table and test_get_table_map to cover tables[] array lookups
- Add test_import_metrics_yaml in TestMetricsAPI
This commit is contained in:
ZdenekSrotyr 2026-04-10 19:56:00 +02:00
parent 847b48f3af
commit 126d151413
5 changed files with 111 additions and 6 deletions

View file

@ -1,9 +1,12 @@
"""Metrics API endpoints — CRUD for metric definitions stored in DuckDB."""
import os
import tempfile
from typing import List, Optional
import duckdb
from fastapi import APIRouter, Depends, HTTPException
import yaml
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from pydantic import BaseModel
from app.auth.dependencies import get_current_user, require_admin, _get_db
@ -106,3 +109,24 @@ async def delete_metric(
if not deleted:
raise HTTPException(status_code=404, detail=f"Metric '{metric_id}' not found")
return {"status": "deleted", "id": metric_id}
@router.post("/api/admin/metrics/import", status_code=200)
async def import_metrics(
    file: UploadFile = File(...),
    user: dict = Depends(require_admin),
    conn: duckdb.DuckDBPyConnection = Depends(_get_db),
):
    """Import metrics from an uploaded YAML file.

    The upload is spooled to a temporary ``.yml`` file because
    ``MetricRepository.import_from_yaml`` takes a filesystem path. The
    temporary file is removed whether or not the import succeeds.

    Args:
        file: The uploaded YAML document (multipart form field ``file``).
        user: Resolved by ``require_admin``; present only to enforce the
            dependency, not read in the body.
        conn: DuckDB connection handed to ``MetricRepository``.

    Returns:
        ``{"status": "imported", "count": <value returned by import_from_yaml>}``.

    Raises:
        HTTPException: 400 when the upload is not parseable YAML.
    """
    content = await file.read()
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".yml", delete=False, mode="wb") as tmp:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = tmp.name
            tmp.write(content)
        # The file must be closed before the repository reads it by path.
        count = MetricRepository(conn).import_from_yaml(tmp_path)
    except yaml.YAMLError as exc:
        # NOTE(review): assumes import_from_yaml lets PyYAML parse errors
        # propagate; a malformed upload is a client error, not a server fault.
        raise HTTPException(status_code=400, detail=f"Invalid YAML: {exc}") from exc
    finally:
        if tmp_path is not None:
            os.unlink(tmp_path)
    return {"status": "imported", "count": count}

View file

@ -1,6 +1,7 @@
"""Analyst bootstrap commands — da analyst setup, da analyst status."""
import json
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
@ -8,6 +9,8 @@ from urllib.parse import urlparse
import typer
_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
analyst_app = typer.Typer(help="Analyst workspace bootstrap and status")
# ---------------------------------------------------------------------------
@ -59,6 +62,17 @@ def _connect_to_instance(server_url: str) -> str:
)
resp.raise_for_status()
data = resp.json()
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
typer.echo("Authentication failed: invalid credentials", err=True)
elif e.response.status_code == 403:
typer.echo("Authentication failed: account disabled or forbidden", err=True)
else:
typer.echo(f"Authentication failed: HTTP {e.response.status_code}", err=True)
raise typer.Exit(1)
except httpx.TimeoutException:
typer.echo(f"Authentication failed: connection timeout to {server_url}", err=True)
raise typer.Exit(1)
except Exception as e:
typer.echo(f"Authentication failed: {e}", err=True)
raise typer.Exit(1)
@ -199,9 +213,21 @@ def _initialize_duckdb(workspace: Path) -> int:
conn = duckdb.connect(str(db_path))
total_rows = 0
parquet_dir_resolved = parquet_dir.resolve()
for pq_file in parquet_dir.glob("*.parquet"):
view_name = pq_file.stem
abs_path = str(pq_file.resolve())
# Validate path is within the expected parquet directory (no path traversal)
try:
pq_resolved = pq_file.resolve()
pq_resolved.relative_to(parquet_dir_resolved)
except ValueError:
typer.echo(f" Warning: Skipping {pq_file.name}: path traversal detected", err=True)
continue
# Validate view name is a safe SQL identifier
if not _SAFE_IDENTIFIER.match(view_name):
typer.echo(f" Warning: Skipping {pq_file.name}: unsafe view name", err=True)
continue
abs_path = str(pq_resolved)
try:
conn.execute(f'DROP VIEW IF EXISTS "{view_name}"')
conn.execute(
@ -285,6 +311,11 @@ def _generate_claude_md(
encoding="utf-8",
)
settings_path = workspace / ".claude" / "settings.json"
if not settings_path.exists():
settings = {"model": "sonnet", "permissions": {"allow": ["Read", "Bash", "Grep", "Glob"]}}
settings_path.write_text(json.dumps(settings, indent=2))
# ---------------------------------------------------------------------------
# Helper: data freshness check (for returning-session detection)

View file

@ -156,8 +156,8 @@ class MetricRepository:
def find_by_table(self, table_name: str) -> List[Dict[str, Any]]:
rows = self.conn.execute(
"SELECT * FROM metric_definitions WHERE table_name = ? ORDER BY name",
[table_name],
"SELECT * FROM metric_definitions WHERE table_name = ? OR list_contains(tables, ?) ORDER BY name",
[table_name, table_name],
).fetchall()
return self._rows_to_dicts(rows)
@ -176,6 +176,12 @@ class MetricRepository:
result: Dict[str, List[str]] = {}
for table_name, metric_name in rows:
result.setdefault(table_name, []).append(metric_name)
# Also include metrics that reference tables via the 'tables' array
results2 = self.conn.execute(
"SELECT unnest(tables) AS tbl, name FROM metric_definitions WHERE tables IS NOT NULL"
).fetchall()
for tbl, metric_name in results2:
result.setdefault(tbl, []).append(metric_name)
return result
def import_from_yaml(self, path: Union[str, Path]) -> int:

View file

@ -330,6 +330,17 @@ class TestMetricsAPI:
assert data["count"] == 1
assert data["metrics"][0]["category"] == "finance"
def test_import_metrics_yaml(self, seeded_client):
    """Uploading a one-metric YAML file as admin reports one imported metric."""
    client, admin_token, _ = seeded_client
    # Continuation keys are indented two spaces so they align with `name`
    # after the "- " sequence marker; any other indent is not a valid
    # block mapping and the document would not parse as a single metric.
    yaml_content = (
        b"- name: test_metric\n"
        b"  display_name: Test\n"
        b"  category: test\n"
        b"  sql: SELECT 1\n"
    )
    resp = client.post(
        "/api/admin/metrics/import",
        files={"file": ("test.yml", yaml_content, "application/x-yaml")},
        headers={"Authorization": f"Bearer {admin_token}"},
    )
    assert resp.status_code == 200
    assert resp.json()["count"] == 1
class TestMetadataAPI:
def test_get_metadata_empty(self, seeded_client):

View file

@ -230,6 +230,15 @@ class TestMetricRepositorySearch:
table_name="events",
sql="SELECT COUNT(DISTINCT user_id) FROM events",
)
# 1 metric using tables[] array (no table_name)
repo.create(
id="combined/multi",
name="multi",
display_name="Multi Table Metric",
category="combined",
tables=["a", "b"],
sql="SELECT 1",
)
sub_metrics = repo.find_by_table("subscriptions")
assert len(sub_metrics) == 2
ids = {m["id"] for m in sub_metrics}
@ -240,6 +249,16 @@ class TestMetricRepositorySearch:
assert len(event_metrics) == 1
assert event_metrics[0]["id"] == "engagement/dau"
# Metric referencing 'a' via tables[] array should be found
a_metrics = repo.find_by_table("a")
assert len(a_metrics) == 1
assert a_metrics[0]["id"] == "combined/multi"
# Metric referencing 'b' via tables[] array should also be found
b_metrics = repo.find_by_table("b")
assert len(b_metrics) == 1
assert b_metrics[0]["id"] == "combined/multi"
def test_find_by_synonym(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
@ -283,12 +302,26 @@ class TestMetricRepositorySearch:
table_name="events",
sql="SELECT COUNT(DISTINCT user_id) FROM events",
)
# Metric using tables[] array
repo.create(
id="combined/multi",
name="multi",
display_name="Multi Table Metric",
category="combined",
tables=["subscriptions", "events"],
sql="SELECT 1",
)
table_map = repo.get_table_map()
assert isinstance(table_map, dict)
assert "subscriptions" in table_map
assert "events" in table_map
assert set(table_map["subscriptions"]) == {"mrr", "arr"}
assert table_map["events"] == ["dau"]
# 'subscriptions' should include mrr, arr (table_name) plus multi (tables[])
assert "mrr" in table_map["subscriptions"]
assert "arr" in table_map["subscriptions"]
assert "multi" in table_map["subscriptions"]
# 'events' should include dau (table_name) plus multi (tables[])
assert "dau" in table_map["events"]
assert "multi" in table_map["events"]
def test_get_table_map_excludes_null_table(self, db_conn):
from src.repositories.metrics import MetricRepository