From 126d15141365311ec4ecb45c260afff87937ecad Mon Sep 17 00:00:00 2001 From: ZdenekSrotyr Date: Fri, 10 Apr 2026 19:56:00 +0200 Subject: [PATCH] =?UTF-8?q?fix:=20address=20code=20review=20=E2=80=94=20pa?= =?UTF-8?q?th=20injection,=20multi-table=20search,=20metrics=20import=20AP?= =?UTF-8?q?I,=20error=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Validate view names with _SAFE_IDENTIFIER regex and check path traversal in _initialize_duckdb() - find_by_table() and get_table_map() now also search the tables[] array field - Add POST /api/admin/metrics/import endpoint for YAML file upload - Replace generic except in _connect_to_instance() with specific HTTPStatusError/TimeoutException handlers - Generate .claude/settings.json in _generate_claude_md() bootstrap - Update test_find_by_table and test_get_table_map to cover tables[] array lookups - Add test_import_metrics_yaml in TestMetricsAPI --- app/api/metrics.py | 26 +++++++++++++++++++++++++- cli/commands/analyst.py | 33 ++++++++++++++++++++++++++++++++- src/repositories/metrics.py | 10 ++++++++-- tests/test_api.py | 11 +++++++++++ tests/test_metrics.py | 37 +++++++++++++++++++++++++++++++++++-- 5 files changed, 111 insertions(+), 6 deletions(-) diff --git a/app/api/metrics.py b/app/api/metrics.py index 2d7b915..a608559 100644 --- a/app/api/metrics.py +++ b/app/api/metrics.py @@ -1,9 +1,12 @@ """Metrics API endpoints — CRUD for metric definitions stored in DuckDB.""" +import os +import tempfile from typing import List, Optional import duckdb -from fastapi import APIRouter, Depends, HTTPException +import yaml +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from pydantic import BaseModel from app.auth.dependencies import get_current_user, require_admin, _get_db @@ -106,3 +109,24 @@ async def delete_metric( if not deleted: raise HTTPException(status_code=404, detail=f"Metric '{metric_id}' not found") return {"status": "deleted", "id": metric_id} + + +@router.post("/api/admin/metrics/import", status_code=200) +async def import_metrics( + file: UploadFile = File(...), + user: dict = Depends(require_admin), + conn: duckdb.DuckDBPyConnection = Depends(_get_db), +): + """Import metrics from uploaded YAML file.""" + content = await file.read() + + with tempfile.NamedTemporaryFile(suffix=".yml", delete=False, mode="wb") as tmp: + tmp.write(content) + tmp_path = tmp.name + + try: + repo = MetricRepository(conn) + count = repo.import_from_yaml(tmp_path) + return {"status": "imported", "count": count} + finally: + os.unlink(tmp_path) diff --git a/cli/commands/analyst.py b/cli/commands/analyst.py index b2ffd6c..f62ce28 100644 --- a/cli/commands/analyst.py +++ b/cli/commands/analyst.py @@ -1,6 +1,7 @@ """Analyst bootstrap commands — da analyst setup, da analyst status.""" import json +import re from datetime import datetime, timezone from pathlib import Path from typing import Optional @@ -8,6 +9,8 @@ from urllib.parse import urlparse import typer +_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]{0,63}$") + analyst_app = typer.Typer(help="Analyst workspace bootstrap and status") # --------------------------------------------------------------------------- @@ -59,6 +62,17 @@ def _connect_to_instance(server_url: str) -> str: ) resp.raise_for_status() data = resp.json() + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + typer.echo("Authentication failed: invalid credentials", err=True) + elif e.response.status_code == 403: + typer.echo("Authentication failed: account disabled or forbidden", err=True) + else: + typer.echo(f"Authentication failed: HTTP {e.response.status_code}", err=True) + raise typer.Exit(1) + except httpx.TimeoutException: + typer.echo(f"Authentication failed: connection timeout to {server_url}", err=True) + raise typer.Exit(1) except Exception as e: typer.echo(f"Authentication failed: {e}", err=True) raise typer.Exit(1) @@ -199,9 +213,21 @@ def _initialize_duckdb(workspace: Path) -> int: conn = duckdb.connect(str(db_path)) total_rows = 0 + parquet_dir_resolved = parquet_dir.resolve() for pq_file in parquet_dir.glob("*.parquet"): view_name = pq_file.stem - abs_path = str(pq_file.resolve()) + # Validate path is within the expected parquet directory (no path traversal) + try: + pq_resolved = pq_file.resolve() + pq_resolved.relative_to(parquet_dir_resolved) + except ValueError: + typer.echo(f" Warning: Skipping {pq_file.name}: path traversal detected", err=True) + continue + # Validate view name is a safe SQL identifier + if not _SAFE_IDENTIFIER.match(view_name): + typer.echo(f" Warning: Skipping {pq_file.name}: unsafe view name", err=True) + continue + abs_path = str(pq_resolved) try: conn.execute(f'DROP VIEW IF EXISTS "{view_name}"') conn.execute( @@ -285,6 +311,11 @@ def _generate_claude_md( encoding="utf-8", ) + settings_path = workspace / ".claude" / "settings.json" + if not settings_path.exists(): + settings = {"model": "sonnet", "permissions": {"allow": ["Read", "Bash", "Grep", "Glob"]}} + settings_path.write_text(json.dumps(settings, indent=2)) + # --------------------------------------------------------------------------- # Helper: data freshness check (for returning-session detection) diff --git a/src/repositories/metrics.py b/src/repositories/metrics.py index 6335208..d37e447 100644 --- a/src/repositories/metrics.py +++ b/src/repositories/metrics.py @@ -156,8 +156,8 @@ class MetricRepository: def find_by_table(self, table_name: str) -> List[Dict[str, Any]]: rows = self.conn.execute( - "SELECT * FROM metric_definitions WHERE table_name = ? ORDER BY name", - [table_name], + "SELECT * FROM metric_definitions WHERE table_name = ? OR list_contains(tables, ?) ORDER BY name", + [table_name, table_name], ).fetchall() return self._rows_to_dicts(rows) @@ -176,6 +176,12 @@ class MetricRepository: result: Dict[str, List[str]] = {} for table_name, metric_name in rows: result.setdefault(table_name, []).append(metric_name) + # Also include metrics that reference tables via the 'tables' array + results2 = self.conn.execute( + "SELECT unnest(tables) AS tbl, name FROM metric_definitions WHERE tables IS NOT NULL" + ).fetchall() + for tbl, metric_name in results2: + result.setdefault(tbl, []).append(metric_name) return result def import_from_yaml(self, path: Union[str, Path]) -> int: diff --git a/tests/test_api.py b/tests/test_api.py index 0e752c2..df016c6 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -330,6 +330,17 @@ class TestMetricsAPI: assert data["count"] == 1 assert data["metrics"][0]["category"] == "finance" + def test_import_metrics_yaml(self, seeded_client): + client, admin_token, _ = seeded_client + yaml_content = b"- name: test_metric\n display_name: Test\n category: test\n sql: SELECT 1\n" + resp = client.post( + "/api/admin/metrics/import", + files={"file": ("test.yml", yaml_content, "application/x-yaml")}, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert resp.status_code == 200 + assert resp.json()["count"] == 1 + class TestMetadataAPI: def test_get_metadata_empty(self, seeded_client): diff --git a/tests/test_metrics.py b/tests/test_metrics.py index fc8a064..2165202 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -230,6 +230,15 @@ class TestMetricRepositorySearch: table_name="events", sql="SELECT COUNT(DISTINCT user_id) FROM events", ) + # 1 metric using tables[] array (no table_name) + repo.create( + id="combined/multi", + name="multi", + display_name="Multi Table Metric", + category="combined", + tables=["a", "b"], + sql="SELECT 1", + ) sub_metrics = repo.find_by_table("subscriptions") assert len(sub_metrics) == 2 ids = {m["id"] for m in sub_metrics} @@ -240,6 +249,16 @@ class TestMetricRepositorySearch: assert len(event_metrics) == 1 assert event_metrics[0]["id"] == "engagement/dau" + # Metric referencing 'a' via tables[] array should be found + a_metrics = repo.find_by_table("a") + assert len(a_metrics) == 1 + assert a_metrics[0]["id"] == "combined/multi" + + # Metric referencing 'b' via tables[] array should also be found + b_metrics = repo.find_by_table("b") + assert len(b_metrics) == 1 + assert b_metrics[0]["id"] == "combined/multi" + def test_find_by_synonym(self, db_conn): from src.repositories.metrics import MetricRepository repo = MetricRepository(db_conn) @@ -283,12 +302,26 @@ class TestMetricRepositorySearch: table_name="events", sql="SELECT COUNT(DISTINCT user_id) FROM events", ) + # Metric using tables[] array + repo.create( + id="combined/multi", + name="multi", + display_name="Multi Table Metric", + category="combined", + tables=["subscriptions", "events"], + sql="SELECT 1", + ) table_map = repo.get_table_map() assert isinstance(table_map, dict) assert "subscriptions" in table_map assert "events" in table_map - assert set(table_map["subscriptions"]) == {"mrr", "arr"} - assert table_map["events"] == ["dau"] + # 'subscriptions' should include mrr, arr (table_name) plus multi (tables[]) + assert "mrr" in table_map["subscriptions"] + assert "arr" in table_map["subscriptions"] + assert "multi" in table_map["subscriptions"] + # 'events' should include dau (table_name) plus multi (tables[]) + assert "dau" in table_map["events"] + assert "multi" in table_map["events"] def test_get_table_map_excludes_null_table(self, db_conn): from src.repositories.metrics import MetricRepository