fix: address code review — path injection, multi-table search, metrics import API, error handling
- Validate view names with _SAFE_IDENTIFIER regex and check path traversal in _initialize_duckdb() - find_by_table() and get_table_map() now also search the tables[] array field - Add POST /api/admin/metrics/import endpoint for YAML file upload - Replace generic except in _connect_to_instance() with specific HTTPStatusError/TimeoutException handlers - Generate .claude/settings.json in _generate_claude_md() bootstrap - Update test_find_by_table and test_get_table_map to cover tables[] array lookups - Add test_import_metrics_yaml in TestMetricsAPI
This commit is contained in:
parent
847b48f3af
commit
126d151413
5 changed files with 111 additions and 6 deletions
|
|
@ -1,9 +1,12 @@
|
|||
"""Metrics API endpoints — CRUD for metric definitions stored in DuckDB."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import List, Optional
|
||||
|
||||
import duckdb
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.auth.dependencies import get_current_user, require_admin, _get_db
|
||||
|
|
@ -106,3 +109,24 @@ async def delete_metric(
|
|||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Metric '{metric_id}' not found")
|
||||
return {"status": "deleted", "id": metric_id}
|
||||
|
||||
|
||||
@router.post("/api/admin/metrics/import", status_code=200)
|
||||
async def import_metrics(
|
||||
file: UploadFile = File(...),
|
||||
user: dict = Depends(require_admin),
|
||||
conn: duckdb.DuckDBPyConnection = Depends(_get_db),
|
||||
):
|
||||
"""Import metrics from uploaded YAML file."""
|
||||
content = await file.read()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".yml", delete=False, mode="wb") as tmp:
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
repo = MetricRepository(conn)
|
||||
count = repo.import_from_yaml(tmp_path)
|
||||
return {"status": "imported", "count": count}
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Analyst bootstrap commands — da analyst setup, da analyst status."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
|
@ -8,6 +9,8 @@ from urllib.parse import urlparse
|
|||
|
||||
import typer
|
||||
|
||||
_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$")
|
||||
|
||||
analyst_app = typer.Typer(help="Analyst workspace bootstrap and status")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -59,6 +62,17 @@ def _connect_to_instance(server_url: str) -> str:
|
|||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
typer.echo("Authentication failed: invalid credentials", err=True)
|
||||
elif e.response.status_code == 403:
|
||||
typer.echo("Authentication failed: account disabled or forbidden", err=True)
|
||||
else:
|
||||
typer.echo(f"Authentication failed: HTTP {e.response.status_code}", err=True)
|
||||
raise typer.Exit(1)
|
||||
except httpx.TimeoutException:
|
||||
typer.echo(f"Authentication failed: connection timeout to {server_url}", err=True)
|
||||
raise typer.Exit(1)
|
||||
except Exception as e:
|
||||
typer.echo(f"Authentication failed: {e}", err=True)
|
||||
raise typer.Exit(1)
|
||||
|
|
@ -199,9 +213,21 @@ def _initialize_duckdb(workspace: Path) -> int:
|
|||
conn = duckdb.connect(str(db_path))
|
||||
total_rows = 0
|
||||
|
||||
parquet_dir_resolved = parquet_dir.resolve()
|
||||
for pq_file in parquet_dir.glob("*.parquet"):
|
||||
view_name = pq_file.stem
|
||||
abs_path = str(pq_file.resolve())
|
||||
# Validate path is within the expected parquet directory (no path traversal)
|
||||
try:
|
||||
pq_resolved = pq_file.resolve()
|
||||
pq_resolved.relative_to(parquet_dir_resolved)
|
||||
except ValueError:
|
||||
typer.echo(f" Warning: Skipping {pq_file.name}: path traversal detected", err=True)
|
||||
continue
|
||||
# Validate view name is a safe SQL identifier
|
||||
if not _SAFE_IDENTIFIER.match(view_name):
|
||||
typer.echo(f" Warning: Skipping {pq_file.name}: unsafe view name", err=True)
|
||||
continue
|
||||
abs_path = str(pq_resolved)
|
||||
try:
|
||||
conn.execute(f'DROP VIEW IF EXISTS "{view_name}"')
|
||||
conn.execute(
|
||||
|
|
@ -285,6 +311,11 @@ def _generate_claude_md(
|
|||
encoding="utf-8",
|
||||
)
|
||||
|
||||
settings_path = workspace / ".claude" / "settings.json"
|
||||
if not settings_path.exists():
|
||||
settings = {"model": "sonnet", "permissions": {"allow": ["Read", "Bash", "Grep", "Glob"]}}
|
||||
settings_path.write_text(json.dumps(settings, indent=2))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: data freshness check (for returning-session detection)
|
||||
|
|
|
|||
|
|
@ -156,8 +156,8 @@ class MetricRepository:
|
|||
|
||||
def find_by_table(self, table_name: str) -> List[Dict[str, Any]]:
|
||||
rows = self.conn.execute(
|
||||
"SELECT * FROM metric_definitions WHERE table_name = ? ORDER BY name",
|
||||
[table_name],
|
||||
"SELECT * FROM metric_definitions WHERE table_name = ? OR list_contains(tables, ?) ORDER BY name",
|
||||
[table_name, table_name],
|
||||
).fetchall()
|
||||
return self._rows_to_dicts(rows)
|
||||
|
||||
|
|
@ -176,6 +176,12 @@ class MetricRepository:
|
|||
result: Dict[str, List[str]] = {}
|
||||
for table_name, metric_name in rows:
|
||||
result.setdefault(table_name, []).append(metric_name)
|
||||
# Also include metrics that reference tables via the 'tables' array
|
||||
results2 = self.conn.execute(
|
||||
"SELECT unnest(tables) AS tbl, name FROM metric_definitions WHERE tables IS NOT NULL"
|
||||
).fetchall()
|
||||
for tbl, metric_name in results2:
|
||||
result.setdefault(tbl, []).append(metric_name)
|
||||
return result
|
||||
|
||||
def import_from_yaml(self, path: Union[str, Path]) -> int:
|
||||
|
|
|
|||
|
|
@ -330,6 +330,17 @@ class TestMetricsAPI:
|
|||
assert data["count"] == 1
|
||||
assert data["metrics"][0]["category"] == "finance"
|
||||
|
||||
def test_import_metrics_yaml(self, seeded_client):
|
||||
client, admin_token, _ = seeded_client
|
||||
yaml_content = b"- name: test_metric\n display_name: Test\n category: test\n sql: SELECT 1\n"
|
||||
resp = client.post(
|
||||
"/api/admin/metrics/import",
|
||||
files={"file": ("test.yml", yaml_content, "application/x-yaml")},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["count"] == 1
|
||||
|
||||
|
||||
class TestMetadataAPI:
|
||||
def test_get_metadata_empty(self, seeded_client):
|
||||
|
|
|
|||
|
|
@ -230,6 +230,15 @@ class TestMetricRepositorySearch:
|
|||
table_name="events",
|
||||
sql="SELECT COUNT(DISTINCT user_id) FROM events",
|
||||
)
|
||||
# 1 metric using tables[] array (no table_name)
|
||||
repo.create(
|
||||
id="combined/multi",
|
||||
name="multi",
|
||||
display_name="Multi Table Metric",
|
||||
category="combined",
|
||||
tables=["a", "b"],
|
||||
sql="SELECT 1",
|
||||
)
|
||||
sub_metrics = repo.find_by_table("subscriptions")
|
||||
assert len(sub_metrics) == 2
|
||||
ids = {m["id"] for m in sub_metrics}
|
||||
|
|
@ -240,6 +249,16 @@ class TestMetricRepositorySearch:
|
|||
assert len(event_metrics) == 1
|
||||
assert event_metrics[0]["id"] == "engagement/dau"
|
||||
|
||||
# Metric referencing 'a' via tables[] array should be found
|
||||
a_metrics = repo.find_by_table("a")
|
||||
assert len(a_metrics) == 1
|
||||
assert a_metrics[0]["id"] == "combined/multi"
|
||||
|
||||
# Metric referencing 'b' via tables[] array should also be found
|
||||
b_metrics = repo.find_by_table("b")
|
||||
assert len(b_metrics) == 1
|
||||
assert b_metrics[0]["id"] == "combined/multi"
|
||||
|
||||
def test_find_by_synonym(self, db_conn):
|
||||
from src.repositories.metrics import MetricRepository
|
||||
repo = MetricRepository(db_conn)
|
||||
|
|
@ -283,12 +302,26 @@ class TestMetricRepositorySearch:
|
|||
table_name="events",
|
||||
sql="SELECT COUNT(DISTINCT user_id) FROM events",
|
||||
)
|
||||
# Metric using tables[] array
|
||||
repo.create(
|
||||
id="combined/multi",
|
||||
name="multi",
|
||||
display_name="Multi Table Metric",
|
||||
category="combined",
|
||||
tables=["subscriptions", "events"],
|
||||
sql="SELECT 1",
|
||||
)
|
||||
table_map = repo.get_table_map()
|
||||
assert isinstance(table_map, dict)
|
||||
assert "subscriptions" in table_map
|
||||
assert "events" in table_map
|
||||
assert set(table_map["subscriptions"]) == {"mrr", "arr"}
|
||||
assert table_map["events"] == ["dau"]
|
||||
# 'subscriptions' should include mrr, arr (table_name) plus multi (tables[])
|
||||
assert "mrr" in table_map["subscriptions"]
|
||||
assert "arr" in table_map["subscriptions"]
|
||||
assert "multi" in table_map["subscriptions"]
|
||||
# 'events' should include dau (table_name) plus multi (tables[])
|
||||
assert "dau" in table_map["events"]
|
||||
assert "multi" in table_map["events"]
|
||||
|
||||
def test_get_table_map_excludes_null_table(self, db_conn):
|
||||
from src.repositories.metrics import MetricRepository
|
||||
|
|
|
|||
Loading…
Reference in a new issue