fix: address code review — path injection, multi-table search, metrics import API, error handling
- Validate view names with _SAFE_IDENTIFIER regex and check path traversal in _initialize_duckdb()
- find_by_table() and get_table_map() now also search the tables[] array field
- Add POST /api/admin/metrics/import endpoint for YAML file upload
- Replace generic except in _connect_to_instance() with specific HTTPStatusError/TimeoutException handlers
- Generate .claude/settings.json in _generate_claude_md() bootstrap
- Update test_find_by_table and test_get_table_map to cover tables[] array lookups
- Add test_import_metrics_yaml in TestMetricsAPI
This commit is contained in:
parent
847b48f3af
commit
126d151413
5 changed files with 111 additions and 6 deletions
|
|
@ -1,9 +1,12 @@
|
||||||
"""Metrics API endpoints — CRUD for metric definitions stored in DuckDB."""
|
"""Metrics API endpoints — CRUD for metric definitions stored in DuckDB."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import duckdb
|
import duckdb
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
import yaml
|
||||||
|
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.auth.dependencies import get_current_user, require_admin, _get_db
|
from app.auth.dependencies import get_current_user, require_admin, _get_db
|
||||||
|
|
@ -106,3 +109,24 @@ async def delete_metric(
|
||||||
if not deleted:
|
if not deleted:
|
||||||
raise HTTPException(status_code=404, detail=f"Metric '{metric_id}' not found")
|
raise HTTPException(status_code=404, detail=f"Metric '{metric_id}' not found")
|
||||||
return {"status": "deleted", "id": metric_id}
|
return {"status": "deleted", "id": metric_id}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/admin/metrics/import", status_code=200)
async def import_metrics(
    file: UploadFile = File(...),
    user: dict = Depends(require_admin),
    conn: duckdb.DuckDBPyConnection = Depends(_get_db),
):
    """Import metric definitions from an uploaded YAML file.

    The upload is spooled to a temporary ``.yml`` file because
    ``MetricRepository.import_from_yaml`` takes a filesystem path.

    Args:
        file: The uploaded YAML document (multipart form field ``file``).
        user: Resolved by the ``require_admin`` dependency; unused in the body.
        conn: DuckDB connection resolved by the ``_get_db`` dependency.

    Returns:
        ``{"status": "imported", "count": <n>}`` where ``count`` is whatever
        ``import_from_yaml`` returns.

    Raises:
        HTTPException: 400 when the upload is empty or is not valid YAML.
    """
    # NOTE(review): the whole upload is read into memory with no size cap;
    # consider enforcing a limit if untrusted admins can reach this endpoint.
    content = await file.read()
    if not content.strip():
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # delete=False so the file survives closing and can be reopened by path
    # (required on Windows); it is removed explicitly in ``finally``.
    tmp = tempfile.NamedTemporaryFile(suffix=".yml", delete=False, mode="wb")
    tmp_path = tmp.name
    try:
        # Write inside the try so a failed write still cleans up the file.
        with tmp:
            tmp.write(content)
        # NOTE(review): import_from_yaml is synchronous and will block the
        # event loop for large files — confirm this is acceptable.
        count = MetricRepository(conn).import_from_yaml(tmp_path)
    except yaml.YAMLError as exc:
        # Presumably import_from_yaml parses with PyYAML — confirm. A client
        # sending malformed YAML should get a 400, not an unhandled 500.
        raise HTTPException(status_code=400, detail=f"Invalid YAML: {exc}") from exc
    finally:
        os.unlink(tmp_path)

    return {"status": "imported", "count": count}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""Analyst bootstrap commands — da analyst setup, da analyst status."""
|
"""Analyst bootstrap commands — da analyst setup, da analyst status."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
@ -8,6 +9,8 @@ from urllib.parse import urlparse
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
|
_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
|
||||||
|
|
||||||
analyst_app = typer.Typer(help="Analyst workspace bootstrap and status")
|
analyst_app = typer.Typer(help="Analyst workspace bootstrap and status")
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -59,6 +62,17 @@ def _connect_to_instance(server_url: str) -> str:
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code == 401:
|
||||||
|
typer.echo("Authentication failed: invalid credentials", err=True)
|
||||||
|
elif e.response.status_code == 403:
|
||||||
|
typer.echo("Authentication failed: account disabled or forbidden", err=True)
|
||||||
|
else:
|
||||||
|
typer.echo(f"Authentication failed: HTTP {e.response.status_code}", err=True)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
typer.echo(f"Authentication failed: connection timeout to {server_url}", err=True)
|
||||||
|
raise typer.Exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
typer.echo(f"Authentication failed: {e}", err=True)
|
typer.echo(f"Authentication failed: {e}", err=True)
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
@ -199,9 +213,21 @@ def _initialize_duckdb(workspace: Path) -> int:
|
||||||
conn = duckdb.connect(str(db_path))
|
conn = duckdb.connect(str(db_path))
|
||||||
total_rows = 0
|
total_rows = 0
|
||||||
|
|
||||||
|
parquet_dir_resolved = parquet_dir.resolve()
|
||||||
for pq_file in parquet_dir.glob("*.parquet"):
|
for pq_file in parquet_dir.glob("*.parquet"):
|
||||||
view_name = pq_file.stem
|
view_name = pq_file.stem
|
||||||
abs_path = str(pq_file.resolve())
|
# Validate path is within the expected parquet directory (no path traversal)
|
||||||
|
try:
|
||||||
|
pq_resolved = pq_file.resolve()
|
||||||
|
pq_resolved.relative_to(parquet_dir_resolved)
|
||||||
|
except ValueError:
|
||||||
|
typer.echo(f" Warning: Skipping {pq_file.name}: path traversal detected", err=True)
|
||||||
|
continue
|
||||||
|
# Validate view name is a safe SQL identifier
|
||||||
|
if not _SAFE_IDENTIFIER.match(view_name):
|
||||||
|
typer.echo(f" Warning: Skipping {pq_file.name}: unsafe view name", err=True)
|
||||||
|
continue
|
||||||
|
abs_path = str(pq_resolved)
|
||||||
try:
|
try:
|
||||||
conn.execute(f'DROP VIEW IF EXISTS "{view_name}"')
|
conn.execute(f'DROP VIEW IF EXISTS "{view_name}"')
|
||||||
conn.execute(
|
conn.execute(
|
||||||
|
|
@ -285,6 +311,11 @@ def _generate_claude_md(
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
settings_path = workspace / ".claude" / "settings.json"
|
||||||
|
if not settings_path.exists():
|
||||||
|
settings = {"model": "sonnet", "permissions": {"allow": ["Read", "Bash", "Grep", "Glob"]}}
|
||||||
|
settings_path.write_text(json.dumps(settings, indent=2))
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helper: data freshness check (for returning-session detection)
|
# Helper: data freshness check (for returning-session detection)
|
||||||
|
|
|
||||||
|
|
@ -156,8 +156,8 @@ class MetricRepository:
|
||||||
|
|
||||||
def find_by_table(self, table_name: str) -> List[Dict[str, Any]]:
|
def find_by_table(self, table_name: str) -> List[Dict[str, Any]]:
|
||||||
rows = self.conn.execute(
|
rows = self.conn.execute(
|
||||||
"SELECT * FROM metric_definitions WHERE table_name = ? ORDER BY name",
|
"SELECT * FROM metric_definitions WHERE table_name = ? OR list_contains(tables, ?) ORDER BY name",
|
||||||
[table_name],
|
[table_name, table_name],
|
||||||
).fetchall()
|
).fetchall()
|
||||||
return self._rows_to_dicts(rows)
|
return self._rows_to_dicts(rows)
|
||||||
|
|
||||||
|
|
@ -176,6 +176,12 @@ class MetricRepository:
|
||||||
result: Dict[str, List[str]] = {}
|
result: Dict[str, List[str]] = {}
|
||||||
for table_name, metric_name in rows:
|
for table_name, metric_name in rows:
|
||||||
result.setdefault(table_name, []).append(metric_name)
|
result.setdefault(table_name, []).append(metric_name)
|
||||||
|
# Also include metrics that reference tables via the 'tables' array
|
||||||
|
results2 = self.conn.execute(
|
||||||
|
"SELECT unnest(tables) AS tbl, name FROM metric_definitions WHERE tables IS NOT NULL"
|
||||||
|
).fetchall()
|
||||||
|
for tbl, metric_name in results2:
|
||||||
|
result.setdefault(tbl, []).append(metric_name)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def import_from_yaml(self, path: Union[str, Path]) -> int:
|
def import_from_yaml(self, path: Union[str, Path]) -> int:
|
||||||
|
|
|
||||||
|
|
@ -330,6 +330,17 @@ class TestMetricsAPI:
|
||||||
assert data["count"] == 1
|
assert data["count"] == 1
|
||||||
assert data["metrics"][0]["category"] == "finance"
|
assert data["metrics"][0]["category"] == "finance"
|
||||||
|
|
||||||
|
def test_import_metrics_yaml(self, seeded_client):
    """Uploading a one-metric YAML file as admin imports exactly one metric."""
    client, admin_token, _ = seeded_client
    # Continuation keys must line up with ``name`` (two spaces under "- ");
    # a one-space indent is not a valid YAML block mapping.
    yaml_content = (
        b"- name: test_metric\n"
        b"  display_name: Test\n"
        b"  category: test\n"
        b"  sql: SELECT 1\n"
    )
    resp = client.post(
        "/api/admin/metrics/import",
        files={"file": ("test.yml", yaml_content, "application/x-yaml")},
        headers={"Authorization": f"Bearer {admin_token}"},
    )
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "imported"
    assert body["count"] == 1
|
||||||
|
|
||||||
|
|
||||||
class TestMetadataAPI:
|
class TestMetadataAPI:
|
||||||
def test_get_metadata_empty(self, seeded_client):
|
def test_get_metadata_empty(self, seeded_client):
|
||||||
|
|
|
||||||
|
|
@ -230,6 +230,15 @@ class TestMetricRepositorySearch:
|
||||||
table_name="events",
|
table_name="events",
|
||||||
sql="SELECT COUNT(DISTINCT user_id) FROM events",
|
sql="SELECT COUNT(DISTINCT user_id) FROM events",
|
||||||
)
|
)
|
||||||
|
# 1 metric using tables[] array (no table_name)
|
||||||
|
repo.create(
|
||||||
|
id="combined/multi",
|
||||||
|
name="multi",
|
||||||
|
display_name="Multi Table Metric",
|
||||||
|
category="combined",
|
||||||
|
tables=["a", "b"],
|
||||||
|
sql="SELECT 1",
|
||||||
|
)
|
||||||
sub_metrics = repo.find_by_table("subscriptions")
|
sub_metrics = repo.find_by_table("subscriptions")
|
||||||
assert len(sub_metrics) == 2
|
assert len(sub_metrics) == 2
|
||||||
ids = {m["id"] for m in sub_metrics}
|
ids = {m["id"] for m in sub_metrics}
|
||||||
|
|
@ -240,6 +249,16 @@ class TestMetricRepositorySearch:
|
||||||
assert len(event_metrics) == 1
|
assert len(event_metrics) == 1
|
||||||
assert event_metrics[0]["id"] == "engagement/dau"
|
assert event_metrics[0]["id"] == "engagement/dau"
|
||||||
|
|
||||||
|
# Metric referencing 'a' via tables[] array should be found
|
||||||
|
a_metrics = repo.find_by_table("a")
|
||||||
|
assert len(a_metrics) == 1
|
||||||
|
assert a_metrics[0]["id"] == "combined/multi"
|
||||||
|
|
||||||
|
# Metric referencing 'b' via tables[] array should also be found
|
||||||
|
b_metrics = repo.find_by_table("b")
|
||||||
|
assert len(b_metrics) == 1
|
||||||
|
assert b_metrics[0]["id"] == "combined/multi"
|
||||||
|
|
||||||
def test_find_by_synonym(self, db_conn):
|
def test_find_by_synonym(self, db_conn):
|
||||||
from src.repositories.metrics import MetricRepository
|
from src.repositories.metrics import MetricRepository
|
||||||
repo = MetricRepository(db_conn)
|
repo = MetricRepository(db_conn)
|
||||||
|
|
@ -283,12 +302,26 @@ class TestMetricRepositorySearch:
|
||||||
table_name="events",
|
table_name="events",
|
||||||
sql="SELECT COUNT(DISTINCT user_id) FROM events",
|
sql="SELECT COUNT(DISTINCT user_id) FROM events",
|
||||||
)
|
)
|
||||||
|
# Metric using tables[] array
|
||||||
|
repo.create(
|
||||||
|
id="combined/multi",
|
||||||
|
name="multi",
|
||||||
|
display_name="Multi Table Metric",
|
||||||
|
category="combined",
|
||||||
|
tables=["subscriptions", "events"],
|
||||||
|
sql="SELECT 1",
|
||||||
|
)
|
||||||
table_map = repo.get_table_map()
|
table_map = repo.get_table_map()
|
||||||
assert isinstance(table_map, dict)
|
assert isinstance(table_map, dict)
|
||||||
assert "subscriptions" in table_map
|
assert "subscriptions" in table_map
|
||||||
assert "events" in table_map
|
assert "events" in table_map
|
||||||
assert set(table_map["subscriptions"]) == {"mrr", "arr"}
|
# 'subscriptions' should include mrr, arr (table_name) plus multi (tables[])
|
||||||
assert table_map["events"] == ["dau"]
|
assert "mrr" in table_map["subscriptions"]
|
||||||
|
assert "arr" in table_map["subscriptions"]
|
||||||
|
assert "multi" in table_map["subscriptions"]
|
||||||
|
# 'events' should include dau (table_name) plus multi (tables[])
|
||||||
|
assert "dau" in table_map["events"]
|
||||||
|
assert "multi" in table_map["events"]
|
||||||
|
|
||||||
def test_get_table_map_excludes_null_table(self, db_conn):
|
def test_get_table_map_excludes_null_table(self, db_conn):
|
||||||
from src.repositories.metrics import MetricRepository
|
from src.repositories.metrics import MetricRepository
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue