agnes-the-ai-analyst/src/repositories/metrics.py
ZdenekSrotyr 126d151413 fix: address code review — path injection, multi-table search, metrics import API, error handling
- Validate view names with _SAFE_IDENTIFIER regex and check path traversal in _initialize_duckdb()
- find_by_table() and get_table_map() now also search the tables[] array field
- Add POST /api/admin/metrics/import endpoint for YAML file upload
- Replace generic except in _connect_to_instance() with specific HTTPStatusError/TimeoutException handlers
- Generate .claude/settings.json in _generate_claude_md() bootstrap
- Update test_find_by_table and test_get_table_map to cover tables[] array lookups
- Add test_import_metrics_yaml in TestMetricsAPI
2026-04-10 19:56:00 +02:00

351 lines
13 KiB
Python

"""Repository for metric definitions."""
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional, List, Dict, Union
import duckdb
import yaml
def json_dumps(obj) -> Optional[str]:
    """Return the JSON text for ``obj``; ``None`` is passed through unchanged."""
    return None if obj is None else json.dumps(obj)
class MetricRepository:
    """Data-access layer for rows of the ``metric_definitions`` table.

    Wraps a DuckDB connection and offers create/read/update/delete helpers,
    lookups by table and synonym, and YAML import/export of metric files.

    NOTE: this class defines a method named ``list``. Signature annotations
    inside the class body are evaluated in class scope, so they must keep
    using ``typing.List`` rather than the builtin ``list``.
    """

    # (YAML key, table column) pairs written by export_to_yaml() when the
    # column holds a truthy value. Order here is the key order in the file.
    _EXPORT_FIELDS = (
        ("type", "type"),
        ("unit", "unit"),
        ("grain", "grain"),
        ("time_column", "time_column"),
        # YAML uses 'table' where the database column is 'table_name'.
        ("table", "table_name"),
        ("tables", "tables"),
        ("expression", "expression"),
        ("description", "description"),
        ("dimensions", "dimensions"),
        ("filters", "filters"),
        ("synonyms", "synonyms"),
        ("notes", "notes"),
        ("sql", "sql"),
    )

    def __init__(self, conn: "duckdb.DuckDBPyConnection"):
        """Store the connection used by every query in this repository."""
        self.conn = conn

    def _row_to_dict(self, row) -> Optional[Dict[str, Any]]:
        """Map one fetched row to ``{column: value}``; ``None`` for a missing row.

        Column names are read from ``self.conn.description``, so this must be
        called right after the query that produced ``row``.
        """
        if not row:
            return None
        columns = [desc[0] for desc in self.conn.description]
        return dict(zip(columns, row))

    def _rows_to_dicts(self, rows) -> List[Dict[str, Any]]:
        """Map fetched rows to a list of ``{column: value}`` dicts.

        Column names are read from ``self.conn.description``, so this must be
        called right after the query that produced ``rows``.
        """
        if not rows:
            return []
        columns = [desc[0] for desc in self.conn.description]
        return [dict(zip(columns, row)) for row in rows]

    @staticmethod
    def _decode_json_field(value: Any) -> Any:
        """Parse ``value`` as JSON when it is a string; ``None`` if parsing fails.

        Non-string values are returned unchanged.
        """
        if not isinstance(value, str):
            return value
        try:
            return json.loads(value)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None

    def create(
        self,
        id: str,
        name: str,
        display_name: str,
        category: str,
        sql: str,
        description: Optional[str] = None,
        type: str = "sum",
        unit: Optional[str] = None,
        grain: str = "monthly",
        table_name: Optional[str] = None,
        tables: Optional[List[str]] = None,
        expression: Optional[str] = None,
        time_column: Optional[str] = None,
        dimensions: Optional[List[str]] = None,
        filters: Optional[List[str]] = None,
        synonyms: Optional[List[str]] = None,
        notes: Optional[List[str]] = None,
        sql_variants: Optional[Dict[str, Any]] = None,
        validation: Optional[Dict[str, Any]] = None,
        source: str = "manual",
        **kwargs,
    ) -> Dict[str, Any]:
        """Insert a metric, or overwrite the existing row with the same ``id``.

        On an ``id`` conflict every column except ``id`` and ``created_at`` is
        replaced with the new value. ``sql_variants`` and ``validation`` are
        stored as JSON text. Extra keyword arguments are accepted and ignored.

        Returns:
            The stored row as a dict, re-read via :meth:`get`.
        """
        now = datetime.now(timezone.utc)
        self.conn.execute(
            """INSERT INTO metric_definitions (
                id, name, display_name, category, description, type, unit, grain,
                table_name, tables, expression, time_column, dimensions, filters,
                synonyms, notes, sql, sql_variants, validation, source,
                created_at, updated_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT (id) DO UPDATE SET
                name = excluded.name,
                display_name = excluded.display_name,
                category = excluded.category,
                description = excluded.description,
                type = excluded.type,
                unit = excluded.unit,
                grain = excluded.grain,
                table_name = excluded.table_name,
                tables = excluded.tables,
                expression = excluded.expression,
                time_column = excluded.time_column,
                dimensions = excluded.dimensions,
                filters = excluded.filters,
                synonyms = excluded.synonyms,
                notes = excluded.notes,
                sql = excluded.sql,
                sql_variants = excluded.sql_variants,
                validation = excluded.validation,
                source = excluded.source,
                updated_at = excluded.updated_at""",
            [
                id, name, display_name, category, description, type, unit, grain,
                table_name, tables, expression, time_column, dimensions, filters,
                synonyms, notes, sql,
                json_dumps(sql_variants), json_dumps(validation), source,
                now, now,
            ],
        )
        return self.get(id)

    def get(self, metric_id: str) -> Optional[Dict[str, Any]]:
        """Return the metric with the given id as a dict, or ``None`` if absent."""
        result = self.conn.execute(
            "SELECT * FROM metric_definitions WHERE id = ?", [metric_id]
        ).fetchone()
        return self._row_to_dict(result)

    def list(self, category: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return all metrics ordered by name, optionally limited to one category."""
        if category is not None:
            rows = self.conn.execute(
                "SELECT * FROM metric_definitions WHERE category = ? ORDER BY name",
                [category],
            ).fetchall()
        else:
            rows = self.conn.execute(
                "SELECT * FROM metric_definitions ORDER BY name"
            ).fetchall()
        return self._rows_to_dicts(rows)

    def update(self, metric_id: str, **kwargs) -> Optional[Dict[str, Any]]:
        """Update selected columns of an existing metric.

        Only keyword arguments naming a known column are applied; anything
        else is ignored. ``sql_variants`` and ``validation`` are serialized to
        JSON text. ``updated_at`` is refreshed whenever a column is written.

        Returns:
            The updated row, the unchanged row when no applicable field was
            given, or ``None`` when ``metric_id`` does not exist.
        """
        existing = self.get(metric_id)
        if existing is None:
            return None
        allowed = {
            "name", "display_name", "category", "description", "type", "unit",
            "grain", "table_name", "tables", "expression", "time_column",
            "dimensions", "filters", "synonyms", "notes", "sql",
            "sql_variants", "validation", "source",
        }
        # JSON fields that need serialization
        json_fields = {"sql_variants", "validation"}
        updates = {
            k: (json_dumps(v) if k in json_fields else v)
            for k, v in kwargs.items()
            if k in allowed
        }
        if not updates:
            return existing
        updates["updated_at"] = datetime.now(timezone.utc)
        # Column names come only from the `allowed` whitelist above, so
        # interpolating them into the statement is safe; values stay bound.
        set_clause = ", ".join(f"{k} = ?" for k in updates)
        values = [*updates.values(), metric_id]
        self.conn.execute(
            f"UPDATE metric_definitions SET {set_clause} WHERE id = ?", values
        )
        return self.get(metric_id)

    def delete(self, metric_id: str) -> bool:
        """Delete a metric; return ``False`` when it did not exist."""
        existing = self.get(metric_id)
        if existing is None:
            return False
        self.conn.execute("DELETE FROM metric_definitions WHERE id = ?", [metric_id])
        return True

    def find_by_table(self, table_name: str) -> List[Dict[str, Any]]:
        """Return metrics whose ``table_name`` equals, or whose ``tables`` list
        contains, ``table_name``, ordered by name."""
        rows = self.conn.execute(
            "SELECT * FROM metric_definitions WHERE table_name = ? OR list_contains(tables, ?) ORDER BY name",
            [table_name, table_name],
        ).fetchall()
        return self._rows_to_dicts(rows)

    def find_by_synonym(self, term: str) -> List[Dict[str, Any]]:
        """Return metrics whose ``synonyms`` list contains ``term``, ordered by name."""
        rows = self.conn.execute(
            "SELECT * FROM metric_definitions WHERE list_contains(synonyms, ?) ORDER BY name",
            [term],
        ).fetchall()
        return self._rows_to_dicts(rows)

    def get_table_map(self) -> Dict[str, List[str]]:
        """Return {table_name: [metric_name, ...]} for profiler use.

        A metric is attached to its ``table_name`` and to every entry of its
        ``tables`` array. Each metric (identified by ``id``) is listed at most
        once per table, even when a table appears in both places or is
        repeated inside the array.
        """
        result: Dict[str, List[str]] = {}
        seen = set()  # (table, metric id) pairs already recorded

        primary_rows = self.conn.execute(
            "SELECT table_name, name, id FROM metric_definitions WHERE table_name IS NOT NULL ORDER BY table_name, name"
        ).fetchall()
        for table_name, metric_name, metric_id in primary_rows:
            seen.add((table_name, metric_id))
            result.setdefault(table_name, []).append(metric_name)

        # Also include metrics that reference tables via the 'tables' array
        array_rows = self.conn.execute(
            "SELECT unnest(tables) AS tbl, name, id FROM metric_definitions WHERE tables IS NOT NULL"
        ).fetchall()
        for tbl, metric_name, metric_id in array_rows:
            # Skip NULL array elements and pairs already recorded above.
            if tbl is None or (tbl, metric_id) in seen:
                continue
            seen.add((tbl, metric_id))
            result.setdefault(tbl, []).append(metric_name)
        return result

    def import_from_yaml(self, path: Union[str, Path]) -> int:
        """Import metrics from a YAML file or directory of YAML files.

        Each file may hold one mapping or a list of mappings. Entries that
        are not mappings or have no ``name`` are skipped. The metric id is
        ``"{category}/{name}"``, where the category falls back to the name of
        the file's parent directory. Rows are written through :meth:`create`,
        so re-importing overwrites metrics with the same id.

        Args:
            path: Path to a single .yml file or a directory containing */*.yml files.
        Returns:
            Number of metrics imported (0 when ``path`` is neither a file nor
            a directory).
        """
        path = Path(path)
        if path.is_file():
            files = [path]
        elif path.is_dir():
            files = sorted(path.glob("*/*.yml"))
        else:
            files = []

        count = 0
        for file_path in files:
            # Infer category from parent directory name
            category_from_dir = file_path.parent.name
            # Read as UTF-8 explicitly so the result does not depend on the
            # platform's default encoding.
            with open(file_path, "r", encoding="utf-8") as f:
                raw = yaml.safe_load(f)
            # Support both list-wrapped [{...}] and plain {...} formats
            if isinstance(raw, list):
                metrics_data = raw
            elif isinstance(raw, dict):
                metrics_data = [raw]
            else:
                continue
            for data in metrics_data:
                if not isinstance(data, dict):
                    continue
                name = data.get("name")
                if not name:
                    continue
                category = data.get("category") or category_from_dir
                # Build id as "category/name"
                metric_id = f"{category}/{name}"
                # Map YAML 'table' -> DB 'table_name'
                table_name = data.get("table") or data.get("table_name")
                # Collect sql_by_* keys into sql_variants dict by stripping
                # the 'sql_' prefix, e.g. sql_by_channel -> {"by_channel": "..."}.
                # YAML keys are not guaranteed to be strings, hence the guard.
                sql_variants: Dict[str, Any] = {
                    key[len("sql_"):]: value
                    for key, value in data.items()
                    if isinstance(key, str) and key.startswith("sql_by_")
                }
                # `or` (rather than a .get default) also covers keys that are
                # present but explicitly null in the YAML.
                self.create(
                    id=metric_id,
                    name=name,
                    display_name=data.get("display_name") or name,
                    category=category,
                    sql=data.get("sql") or "",
                    description=data.get("description"),
                    type=data.get("type") or "sum",
                    unit=data.get("unit"),
                    grain=data.get("grain") or "monthly",
                    table_name=table_name,
                    tables=data.get("tables"),
                    expression=data.get("expression"),
                    time_column=data.get("time_column"),
                    dimensions=data.get("dimensions"),
                    filters=data.get("filters"),
                    synonyms=data.get("synonyms"),
                    notes=data.get("notes"),
                    sql_variants=sql_variants if sql_variants else None,
                    validation=data.get("validation"),
                    source="yaml_import",
                )
                count += 1
        return count

    def export_to_yaml(self, output_dir: Union[str, Path]) -> int:
        """Export all metrics to YAML files under output_dir/{category}/{name}.yml.

        Only truthy fields are written. ``table_name`` is written as ``table``,
        the ``tables`` array is written as ``tables``, and each entry of
        ``sql_variants`` becomes a ``sql_<variant>`` key, mirroring what
        :meth:`import_from_yaml` reads.

        Args:
            output_dir: Root directory for the exported YAML files.
        Returns:
            Number of metrics exported.
        """
        output_dir = Path(output_dir)
        count = 0
        for metric in self.list():
            category = metric.get("category") or "uncategorized"
            name = metric.get("name") or metric["id"].split("/")[-1]
            # NOTE(review): category and name are used verbatim as path
            # components; values containing path separators or '..' would
            # resolve outside output_dir. Confirm they are validated upstream.
            category_dir = output_dir / category
            category_dir.mkdir(parents=True, exist_ok=True)

            data: Dict[str, Any] = {"name": name}
            if metric.get("display_name"):
                data["display_name"] = metric["display_name"]
            data["category"] = category
            for yaml_key, column in self._EXPORT_FIELDS:
                value = metric.get(column)
                if value:
                    data[yaml_key] = value

            # Expand sql_variants back to sql_by_* keys
            sql_variants = self._decode_json_field(metric.get("sql_variants"))
            if isinstance(sql_variants, dict):
                for variant_key, variant_sql in sql_variants.items():
                    # variant_key is e.g. 'by_channel' -> YAML key 'sql_by_channel'
                    data[f"sql_{variant_key}"] = variant_sql

            # Validation is stored as JSON text; write it back as structured data
            validation = self._decode_json_field(metric.get("validation"))
            if validation:
                data["validation"] = validation

            out_file = category_dir / f"{name}.yml"
            # allow_unicode=True emits non-ASCII characters as-is, so the file
            # must be opened as UTF-8 regardless of the platform default.
            with open(out_file, "w", encoding="utf-8") as f:
                yaml.dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
            count += 1
        return count