feat: add MetricRepository with full CRUD and search for metric_definitions

Implements MetricRepository following the table_registry pattern — raw SQL,
dict returns, ON CONFLICT upsert, and json.dumps for sql_variants/validation.
Includes 18 tests covering create, read, list, update, delete, find_by_table,
find_by_synonym, and get_table_map.
This commit is contained in:
ZdenekSrotyr 2026-04-10 19:21:25 +02:00
parent cc1445f7ed
commit 88d536ca29
2 changed files with 483 additions and 0 deletions

177
src/repositories/metrics.py Normal file
View file

@ -0,0 +1,177 @@
"""Repository for metric definitions."""
import json
from datetime import datetime, timezone
from typing import Any, Optional, List, Dict
import duckdb
def json_dumps(obj) -> Optional[str]:
"""Serialize obj to JSON string, or None if obj is None."""
if obj is None:
return None
return json.dumps(obj)
class MetricRepository:
def __init__(self, conn: duckdb.DuckDBPyConnection):
self.conn = conn
def _row_to_dict(self, row) -> Optional[Dict[str, Any]]:
if not row:
return None
columns = [desc[0] for desc in self.conn.description]
return dict(zip(columns, row))
def _rows_to_dicts(self, rows) -> List[Dict[str, Any]]:
if not rows:
return []
columns = [desc[0] for desc in self.conn.description]
return [dict(zip(columns, row)) for row in rows]
def create(
self,
id: str,
name: str,
display_name: str,
category: str,
sql: str,
description: Optional[str] = None,
type: str = "sum",
unit: Optional[str] = None,
grain: str = "monthly",
table_name: Optional[str] = None,
tables: Optional[List[str]] = None,
expression: Optional[str] = None,
time_column: Optional[str] = None,
dimensions: Optional[List[str]] = None,
filters: Optional[List[str]] = None,
synonyms: Optional[List[str]] = None,
notes: Optional[List[str]] = None,
sql_variants: Optional[Dict[str, Any]] = None,
validation: Optional[Dict[str, Any]] = None,
source: str = "manual",
**kwargs,
) -> Dict[str, Any]:
now = datetime.now(timezone.utc)
self.conn.execute(
"""INSERT INTO metric_definitions (
id, name, display_name, category, description, type, unit, grain,
table_name, tables, expression, time_column, dimensions, filters,
synonyms, notes, sql, sql_variants, validation, source,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (id) DO UPDATE SET
name = excluded.name,
display_name = excluded.display_name,
category = excluded.category,
description = excluded.description,
type = excluded.type,
unit = excluded.unit,
grain = excluded.grain,
table_name = excluded.table_name,
tables = excluded.tables,
expression = excluded.expression,
time_column = excluded.time_column,
dimensions = excluded.dimensions,
filters = excluded.filters,
synonyms = excluded.synonyms,
notes = excluded.notes,
sql = excluded.sql,
sql_variants = excluded.sql_variants,
validation = excluded.validation,
source = excluded.source,
updated_at = excluded.updated_at""",
[
id, name, display_name, category, description, type, unit, grain,
table_name, tables, expression, time_column, dimensions, filters,
synonyms, notes, sql,
json_dumps(sql_variants), json_dumps(validation), source,
now, now,
],
)
return self.get(id)
def get(self, metric_id: str) -> Optional[Dict[str, Any]]:
result = self.conn.execute(
"SELECT * FROM metric_definitions WHERE id = ?", [metric_id]
).fetchone()
return self._row_to_dict(result)
def list(self, category: Optional[str] = None) -> List[Dict[str, Any]]:
if category is not None:
rows = self.conn.execute(
"SELECT * FROM metric_definitions WHERE category = ? ORDER BY name",
[category],
).fetchall()
else:
rows = self.conn.execute(
"SELECT * FROM metric_definitions ORDER BY name"
).fetchall()
return self._rows_to_dicts(rows)
def update(self, metric_id: str, **kwargs) -> Optional[Dict[str, Any]]:
# Check existence first
existing = self.get(metric_id)
if existing is None:
return None
allowed = {
"name", "display_name", "category", "description", "type", "unit",
"grain", "table_name", "tables", "expression", "time_column",
"dimensions", "filters", "synonyms", "notes", "sql",
"sql_variants", "validation", "source",
}
# JSON fields that need serialization
json_fields = {"sql_variants", "validation"}
updates = {}
for k, v in kwargs.items():
if k in allowed:
if k in json_fields:
updates[k] = json_dumps(v)
else:
updates[k] = v
if not updates:
return existing
updates["updated_at"] = datetime.now(timezone.utc)
set_clause = ", ".join(f"{k} = ?" for k in updates)
values = list(updates.values()) + [metric_id]
self.conn.execute(
f"UPDATE metric_definitions SET {set_clause} WHERE id = ?", values
)
return self.get(metric_id)
def delete(self, metric_id: str) -> bool:
existing = self.get(metric_id)
if existing is None:
return False
self.conn.execute("DELETE FROM metric_definitions WHERE id = ?", [metric_id])
return True
def find_by_table(self, table_name: str) -> List[Dict[str, Any]]:
rows = self.conn.execute(
"SELECT * FROM metric_definitions WHERE table_name = ? ORDER BY name",
[table_name],
).fetchall()
return self._rows_to_dicts(rows)
def find_by_synonym(self, term: str) -> List[Dict[str, Any]]:
rows = self.conn.execute(
"SELECT * FROM metric_definitions WHERE list_contains(synonyms, ?) ORDER BY name",
[term],
).fetchall()
return self._rows_to_dicts(rows)
def get_table_map(self) -> Dict[str, List[str]]:
"""Return {table_name: [metric_name, ...]} for profiler use."""
rows = self.conn.execute(
"SELECT table_name, name FROM metric_definitions WHERE table_name IS NOT NULL ORDER BY table_name, name"
).fetchall()
result: Dict[str, List[str]] = {}
for table_name, metric_name in rows:
result.setdefault(table_name, []).append(metric_name)
return result

306
tests/test_metrics.py Normal file
View file

@ -0,0 +1,306 @@
"""Tests for MetricRepository (metric_definitions table)."""
import pytest
@pytest.fixture
def db_conn(tmp_path, monkeypatch):
monkeypatch.setenv("DATA_DIR", str(tmp_path))
from src.db import get_system_db
conn = get_system_db()
yield conn
conn.close()
SAMPLE_METRIC = {
"id": "revenue/mrr",
"name": "mrr",
"display_name": "Monthly Recurring Revenue",
"category": "revenue",
"description": "Total MRR from all subscriptions",
"type": "sum",
"unit": "USD",
"grain": "monthly",
"table_name": "subscriptions",
"expression": "SUM(mrr_amount)",
"time_column": "billing_date",
"dimensions": ["plan_type", "region"],
"synonyms": ["monthly_revenue", "recurring_revenue"],
"notes": ["Excludes one-time fees"],
"sql": "SELECT DATE_TRUNC('month', billing_date) AS month, SUM(mrr_amount) AS mrr FROM subscriptions GROUP BY 1",
}
class TestMetricRepositoryCreate:
def test_create_metric(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
result = repo.create(**SAMPLE_METRIC)
assert result is not None
assert result["id"] == "revenue/mrr"
assert result["name"] == "mrr"
assert result["display_name"] == "Monthly Recurring Revenue"
assert result["category"] == "revenue"
assert result["description"] == "Total MRR from all subscriptions"
assert result["type"] == "sum"
assert result["unit"] == "USD"
assert result["grain"] == "monthly"
assert result["table_name"] == "subscriptions"
assert result["expression"] == "SUM(mrr_amount)"
assert result["time_column"] == "billing_date"
assert result["dimensions"] == ["plan_type", "region"]
assert result["synonyms"] == ["monthly_revenue", "recurring_revenue"]
assert result["notes"] == ["Excludes one-time fees"]
assert "SELECT" in result["sql"]
assert result["source"] == "manual"
def test_create_duplicate_upserts(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
repo.create(**SAMPLE_METRIC)
# Create again with different display_name
updated = {**SAMPLE_METRIC, "display_name": "MRR (Updated)"}
repo.create(**updated)
# Should only have one record
all_metrics = repo.list()
assert len(all_metrics) == 1
assert all_metrics[0]["display_name"] == "MRR (Updated)"
def test_create_with_defaults(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
result = repo.create(
id="test/metric",
name="test_metric",
display_name="Test Metric",
category="test",
sql="SELECT 1",
)
assert result["type"] == "sum"
assert result["grain"] == "monthly"
assert result["source"] == "manual"
def test_create_with_json_fields(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
result = repo.create(
**SAMPLE_METRIC,
sql_variants={"weekly": "SELECT DATE_TRUNC('week', billing_date), SUM(mrr) FROM subscriptions GROUP BY 1"},
validation={"min": 0, "max": 1000000},
)
assert result is not None
assert result["id"] == "revenue/mrr"
class TestMetricRepositoryRead:
def test_get_existing(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
repo.create(**SAMPLE_METRIC)
metric = repo.get("revenue/mrr")
assert metric is not None
assert metric["name"] == "mrr"
assert metric["category"] == "revenue"
def test_get_missing(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
result = repo.get("nonexistent/metric")
assert result is None
def test_list_all(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
repo.create(**SAMPLE_METRIC)
repo.create(
id="engagement/dau",
name="dau",
display_name="Daily Active Users",
category="engagement",
sql="SELECT COUNT(DISTINCT user_id) FROM events WHERE DATE(created_at) = CURRENT_DATE",
)
all_metrics = repo.list()
assert len(all_metrics) == 2
ids = {m["id"] for m in all_metrics}
assert "revenue/mrr" in ids
assert "engagement/dau" in ids
def test_list_by_category(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
repo.create(**SAMPLE_METRIC)
repo.create(
id="engagement/dau",
name="dau",
display_name="Daily Active Users",
category="engagement",
sql="SELECT COUNT(DISTINCT user_id) FROM events",
)
revenue_metrics = repo.list(category="revenue")
assert len(revenue_metrics) == 1
assert revenue_metrics[0]["id"] == "revenue/mrr"
engagement_metrics = repo.list(category="engagement")
assert len(engagement_metrics) == 1
assert engagement_metrics[0]["id"] == "engagement/dau"
class TestMetricRepositoryUpdate:
def test_update_fields(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
repo.create(**SAMPLE_METRIC)
updated = repo.update("revenue/mrr", display_name="MRR (New)", unit="EUR")
assert updated is not None
assert updated["display_name"] == "MRR (New)"
assert updated["unit"] == "EUR"
# Unchanged fields should persist
assert updated["name"] == "mrr"
assert updated["category"] == "revenue"
assert updated["description"] == "Total MRR from all subscriptions"
def test_update_missing_returns_none(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
result = repo.update("nonexistent/metric", display_name="Doesn't matter")
assert result is None
def test_update_persists_to_db(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
repo.create(**SAMPLE_METRIC)
repo.update("revenue/mrr", unit="GBP")
# Re-fetch from DB to verify persistence
metric = repo.get("revenue/mrr")
assert metric["unit"] == "GBP"
class TestMetricRepositoryDelete:
def test_delete_existing(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
repo.create(**SAMPLE_METRIC)
result = repo.delete("revenue/mrr")
assert result is True
assert repo.get("revenue/mrr") is None
def test_delete_missing(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
result = repo.delete("nonexistent/metric")
assert result is False
def test_delete_only_target(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
repo.create(**SAMPLE_METRIC)
repo.create(
id="engagement/dau",
name="dau",
display_name="Daily Active Users",
category="engagement",
sql="SELECT 1",
)
repo.delete("revenue/mrr")
all_metrics = repo.list()
assert len(all_metrics) == 1
assert all_metrics[0]["id"] == "engagement/dau"
class TestMetricRepositorySearch:
def test_find_by_table(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
# 2 metrics with table_name='subscriptions'
repo.create(**SAMPLE_METRIC)
repo.create(
id="revenue/arr",
name="arr",
display_name="Annual Recurring Revenue",
category="revenue",
table_name="subscriptions",
sql="SELECT SUM(mrr_amount) * 12 AS arr FROM subscriptions",
)
# 1 metric with different table
repo.create(
id="engagement/dau",
name="dau",
display_name="Daily Active Users",
category="engagement",
table_name="events",
sql="SELECT COUNT(DISTINCT user_id) FROM events",
)
sub_metrics = repo.find_by_table("subscriptions")
assert len(sub_metrics) == 2
ids = {m["id"] for m in sub_metrics}
assert "revenue/mrr" in ids
assert "revenue/arr" in ids
event_metrics = repo.find_by_table("events")
assert len(event_metrics) == 1
assert event_metrics[0]["id"] == "engagement/dau"
def test_find_by_synonym(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
repo.create(**SAMPLE_METRIC) # has synonyms: ["monthly_revenue", "recurring_revenue"]
repo.create(
id="engagement/dau",
name="dau",
display_name="Daily Active Users",
category="engagement",
synonyms=["active_users", "daily_users"],
sql="SELECT COUNT(DISTINCT user_id) FROM events",
)
results = repo.find_by_synonym("monthly_revenue")
assert len(results) == 1
assert results[0]["id"] == "revenue/mrr"
results2 = repo.find_by_synonym("active_users")
assert len(results2) == 1
assert results2[0]["id"] == "engagement/dau"
results3 = repo.find_by_synonym("nonexistent_synonym")
assert len(results3) == 0
def test_get_table_map(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
repo.create(**SAMPLE_METRIC) # table_name='subscriptions'
repo.create(
id="revenue/arr",
name="arr",
display_name="Annual Recurring Revenue",
category="revenue",
table_name="subscriptions",
sql="SELECT SUM(mrr_amount) * 12 FROM subscriptions",
)
repo.create(
id="engagement/dau",
name="dau",
display_name="Daily Active Users",
category="engagement",
table_name="events",
sql="SELECT COUNT(DISTINCT user_id) FROM events",
)
table_map = repo.get_table_map()
assert isinstance(table_map, dict)
assert "subscriptions" in table_map
assert "events" in table_map
assert set(table_map["subscriptions"]) == {"mrr", "arr"}
assert table_map["events"] == ["dau"]
def test_get_table_map_excludes_null_table(self, db_conn):
from src.repositories.metrics import MetricRepository
repo = MetricRepository(db_conn)
# Metric without table_name
repo.create(
id="test/no_table",
name="no_table",
display_name="No Table Metric",
category="test",
sql="SELECT 1",
)
table_map = repo.get_table_map()
assert "None" not in table_map
assert None not in table_map