From 88d536ca29dea59a11ae2041089154500754b884 Mon Sep 17 00:00:00 2001
From: ZdenekSrotyr
Date: Fri, 10 Apr 2026 19:21:25 +0200
Subject: [PATCH] feat: add MetricRepository with full CRUD and search for metric_definitions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implements MetricRepository following the table_registry pattern — raw SQL,
dict returns, ON CONFLICT upsert, and json.dumps for sql_variants/validation.
Includes 18 tests covering create, read, list, update, delete, find_by_table,
find_by_synonym, and get_table_map.
---
 src/repositories/metrics.py | 220 ++++++++++++++++++++++++++
 tests/test_metrics.py       | 306 ++++++++++++++++++++++++++++++++++++
 2 files changed, 526 insertions(+)
 create mode 100644 src/repositories/metrics.py
 create mode 100644 tests/test_metrics.py

diff --git a/src/repositories/metrics.py b/src/repositories/metrics.py
new file mode 100644
index 0000000..495a7a1
--- /dev/null
+++ b/src/repositories/metrics.py
@@ -0,0 +1,220 @@
+"""Repository for metric definitions."""
+
+import json
+from datetime import datetime, timezone
+from typing import Any, Optional, List, Dict
+
+import duckdb
+
+
+def json_dumps(obj) -> Optional[str]:
+    """Serialize obj to JSON string, or None if obj is None."""
+    if obj is None:
+        return None
+    return json.dumps(obj)
+
+
+class MetricRepository:
+    """CRUD and lookup helpers for the ``metric_definitions`` table.
+
+    Every method runs raw SQL on the DuckDB connection passed to the
+    constructor and returns rows as plain dicts keyed by column name.
+    """
+
+    def __init__(self, conn: duckdb.DuckDBPyConnection):
+        """Store the DuckDB connection used by every query."""
+        self.conn = conn
+
+    def _row_to_dict(self, row) -> Optional[Dict[str, Any]]:
+        """Convert one fetched row into a dict keyed by column name.
+
+        Returns None for a falsy row (e.g. ``fetchone()`` found nothing).
+        """
+        if not row:
+            return None
+        # conn.description describes the last statement executed on this connection.
+        columns = [desc[0] for desc in self.conn.description]
+        return dict(zip(columns, row))
+
+    def _rows_to_dicts(self, rows) -> List[Dict[str, Any]]:
+        """Convert fetched rows into a list of dicts keyed by column name.
+
+        Returns an empty list when ``rows`` is empty or None.
+        """
+        if not rows:
+            return []
+        # conn.description describes the last statement executed on this connection.
+        columns = [desc[0] for desc in self.conn.description]
+        return [dict(zip(columns, row)) for row in rows]
+
+    def create(
+        self,
+        id: str,
+        name: str,
+        display_name: str,
+        category: str,
+        sql: str,
+        description: Optional[str] = None,
+        type: str = "sum",
+        unit: Optional[str] = None,
+        grain: str = "monthly",
+        table_name: Optional[str] = None,
+        tables: Optional[List[str]] = None,
+        expression: Optional[str] = None,
+        time_column: Optional[str] = None,
+        dimensions: Optional[List[str]] = None,
+        filters: Optional[List[str]] = None,
+        synonyms: Optional[List[str]] = None,
+        notes: Optional[List[str]] = None,
+        sql_variants: Optional[Dict[str, Any]] = None,
+        validation: Optional[Dict[str, Any]] = None,
+        source: str = "manual",
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """Insert a metric, or overwrite the existing row with the same ``id``.
+
+        On conflict every column except ``id`` and ``created_at`` is replaced
+        and ``updated_at`` is refreshed. ``sql_variants`` and ``validation``
+        are stored as JSON text via ``json_dumps``; the list arguments are
+        passed to DuckDB unchanged. Extra keyword arguments are accepted but
+        ignored.
+
+        Returns the stored row as read back by ``get``.
+        """
+        now = datetime.now(timezone.utc)
+        self.conn.execute(
+            """INSERT INTO metric_definitions (
+                id, name, display_name, category, description, type, unit, grain,
+                table_name, tables, expression, time_column, dimensions, filters,
+                synonyms, notes, sql, sql_variants, validation, source,
+                created_at, updated_at
+            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+            ON CONFLICT (id) DO UPDATE SET
+                name = excluded.name,
+                display_name = excluded.display_name,
+                category = excluded.category,
+                description = excluded.description,
+                type = excluded.type,
+                unit = excluded.unit,
+                grain = excluded.grain,
+                table_name = excluded.table_name,
+                tables = excluded.tables,
+                expression = excluded.expression,
+                time_column = excluded.time_column,
+                dimensions = excluded.dimensions,
+                filters = excluded.filters,
+                synonyms = excluded.synonyms,
+                notes = excluded.notes,
+                sql = excluded.sql,
+                sql_variants = excluded.sql_variants,
+                validation = excluded.validation,
+                source = excluded.source,
+                updated_at = excluded.updated_at""",
+            [
+                id, name, display_name, category, description, type, unit, grain,
+                table_name, tables, expression, time_column, dimensions, filters,
+                synonyms, notes, sql,
+                json_dumps(sql_variants), json_dumps(validation), source,
+                now, now,
+            ],
+        )
+        return self.get(id)
+
+    def get(self, metric_id: str) -> Optional[Dict[str, Any]]:
+        """Return the metric with the given id as a dict, or None if absent."""
+        result = self.conn.execute(
+            "SELECT * FROM metric_definitions WHERE id = ?", [metric_id]
+        ).fetchone()
+        return self._row_to_dict(result)
+
+    def list(self, category: Optional[str] = None) -> List[Dict[str, Any]]:
+        """Return all metrics ordered by name, optionally limited to one category."""
+        if category is not None:
+            rows = self.conn.execute(
+                "SELECT * FROM metric_definitions WHERE category = ? ORDER BY name",
+                [category],
+            ).fetchall()
+        else:
+            rows = self.conn.execute(
+                "SELECT * FROM metric_definitions ORDER BY name"
+            ).fetchall()
+        return self._rows_to_dicts(rows)
+
+    def update(self, metric_id: str, **kwargs) -> Optional[Dict[str, Any]]:
+        """Update selected columns of an existing metric.
+
+        Only keys in the ``allowed`` set below are applied; any other keyword
+        is ignored. ``sql_variants`` and ``validation`` are JSON-serialized,
+        and ``updated_at`` is refreshed whenever something is written.
+
+        Returns the re-read row, the existing row when no applicable key was
+        passed, or None when no metric has this id.
+        """
+        # Check existence first
+        existing = self.get(metric_id)
+        if existing is None:
+            return None
+
+        allowed = {
+            "name", "display_name", "category", "description", "type", "unit",
+            "grain", "table_name", "tables", "expression", "time_column",
+            "dimensions", "filters", "synonyms", "notes", "sql",
+            "sql_variants", "validation", "source",
+        }
+        # JSON fields that need serialization
+        json_fields = {"sql_variants", "validation"}
+
+        updates = {}
+        for k, v in kwargs.items():
+            if k in allowed:
+                if k in json_fields:
+                    updates[k] = json_dumps(v)
+                else:
+                    updates[k] = v
+
+        if not updates:
+            return existing
+
+        updates["updated_at"] = datetime.now(timezone.utc)
+        # Column names are interpolated into the SQL, but only after the
+        # whitelist check above; values are always bound as parameters.
+        set_clause = ", ".join(f"{k} = ?" for k in updates)
+        values = list(updates.values()) + [metric_id]
+        self.conn.execute(
+            f"UPDATE metric_definitions SET {set_clause} WHERE id = ?", values
+        )
+        return self.get(metric_id)
+
+    def delete(self, metric_id: str) -> bool:
+        """Delete a metric by id; return True if it existed, False otherwise."""
+        existing = self.get(metric_id)
+        if existing is None:
+            return False
+        self.conn.execute("DELETE FROM metric_definitions WHERE id = ?", [metric_id])
+        return True
+
+    def find_by_table(self, table_name: str) -> List[Dict[str, Any]]:
+        """Return metrics with the given table_name, ordered by name."""
+        rows = self.conn.execute(
+            "SELECT * FROM metric_definitions WHERE table_name = ? ORDER BY name",
+            [table_name],
+        ).fetchall()
+        return self._rows_to_dicts(rows)
+
+    def find_by_synonym(self, term: str) -> List[Dict[str, Any]]:
+        """Return metrics whose synonyms list contains ``term``, ordered by name."""
+        rows = self.conn.execute(
+            "SELECT * FROM metric_definitions WHERE list_contains(synonyms, ?) ORDER BY name",
+            [term],
+        ).fetchall()
+        return self._rows_to_dicts(rows)
+
+    def get_table_map(self) -> Dict[str, List[str]]:
+        """Return {table_name: [metric_name, ...]} for profiler use."""
+        rows = self.conn.execute(
+            "SELECT table_name, name FROM metric_definitions WHERE table_name IS NOT NULL ORDER BY table_name, name"
+        ).fetchall()
+        result: Dict[str, List[str]] = {}
+        for table_name, metric_name in rows:
+            result.setdefault(table_name, []).append(metric_name)
+        return result
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
new file mode 100644
index 0000000..310f096
--- /dev/null
+++ b/tests/test_metrics.py
@@ -0,0 +1,306 @@
+"""Tests for MetricRepository (metric_definitions table)."""
+
+import pytest
+
+
+@pytest.fixture
+def db_conn(tmp_path, monkeypatch):
+    monkeypatch.setenv("DATA_DIR", str(tmp_path))
+    from src.db import get_system_db
+    conn = get_system_db()
+    yield conn
+    conn.close()
+
+
+SAMPLE_METRIC = {
+    "id": "revenue/mrr",
+    "name": "mrr",
+    "display_name": "Monthly Recurring Revenue",
+    "category": "revenue",
+    "description": "Total MRR from all subscriptions",
+    "type": "sum",
+    "unit": "USD",
+    "grain": "monthly",
+    "table_name": "subscriptions",
+    "expression": "SUM(mrr_amount)",
+    "time_column": "billing_date",
+    "dimensions": ["plan_type", "region"],
+    "synonyms": ["monthly_revenue", "recurring_revenue"],
+    "notes": ["Excludes one-time fees"],
+    "sql": "SELECT DATE_TRUNC('month', billing_date) AS month, SUM(mrr_amount) AS mrr FROM subscriptions GROUP BY 1",
+}
+
+
+class TestMetricRepositoryCreate:
+    def test_create_metric(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        result = repo.create(**SAMPLE_METRIC)
+        assert result is not None
+        assert result["id"] == "revenue/mrr"
+        assert result["name"] == "mrr"
+        assert result["display_name"] == "Monthly Recurring Revenue"
+        assert result["category"] == "revenue"
+        assert result["description"] == "Total MRR from all subscriptions"
+        assert result["type"] == "sum"
+        assert result["unit"] == "USD"
+        assert result["grain"] == "monthly"
+        assert result["table_name"] == "subscriptions"
+        assert result["expression"] == "SUM(mrr_amount)"
+        assert result["time_column"] == "billing_date"
+        assert result["dimensions"] == ["plan_type", "region"]
+        assert result["synonyms"] == ["monthly_revenue", "recurring_revenue"]
+        assert result["notes"] == ["Excludes one-time fees"]
+        assert "SELECT" in result["sql"]
+        assert result["source"] == "manual"
+
+    def test_create_duplicate_upserts(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        repo.create(**SAMPLE_METRIC)
+        # Create again with different display_name
+        updated = {**SAMPLE_METRIC, "display_name": "MRR (Updated)"}
+        repo.create(**updated)
+        # Should only have one record
+        all_metrics = repo.list()
+        assert len(all_metrics) == 1
+        assert all_metrics[0]["display_name"] == "MRR (Updated)"
+
+    def test_create_with_defaults(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        result = repo.create(
+            id="test/metric",
+            name="test_metric",
+            display_name="Test Metric",
+            category="test",
+            sql="SELECT 1",
+        )
+        assert result["type"] == "sum"
+        assert result["grain"] == "monthly"
+        assert result["source"] == "manual"
+
+    def test_create_with_json_fields(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        result = repo.create(
+            **SAMPLE_METRIC,
+            sql_variants={"weekly": "SELECT DATE_TRUNC('week', billing_date), SUM(mrr) FROM subscriptions GROUP BY 1"},
+            validation={"min": 0, "max": 1000000},
+        )
+        assert result is not None
+        assert result["id"] == "revenue/mrr"
+
+
+class TestMetricRepositoryRead:
+    def test_get_existing(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        repo.create(**SAMPLE_METRIC)
+        metric = repo.get("revenue/mrr")
+        assert metric is not None
+        assert metric["name"] == "mrr"
+        assert metric["category"] == "revenue"
+
+    def test_get_missing(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        result = repo.get("nonexistent/metric")
+        assert result is None
+
+    def test_list_all(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        repo.create(**SAMPLE_METRIC)
+        repo.create(
+            id="engagement/dau",
+            name="dau",
+            display_name="Daily Active Users",
+            category="engagement",
+            sql="SELECT COUNT(DISTINCT user_id) FROM events WHERE DATE(created_at) = CURRENT_DATE",
+        )
+        all_metrics = repo.list()
+        assert len(all_metrics) == 2
+        ids = {m["id"] for m in all_metrics}
+        assert "revenue/mrr" in ids
+        assert "engagement/dau" in ids
+
+    def test_list_by_category(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        repo.create(**SAMPLE_METRIC)
+        repo.create(
+            id="engagement/dau",
+            name="dau",
+            display_name="Daily Active Users",
+            category="engagement",
+            sql="SELECT COUNT(DISTINCT user_id) FROM events",
+        )
+        revenue_metrics = repo.list(category="revenue")
+        assert len(revenue_metrics) == 1
+        assert revenue_metrics[0]["id"] == "revenue/mrr"
+
+        engagement_metrics = repo.list(category="engagement")
+        assert len(engagement_metrics) == 1
+        assert engagement_metrics[0]["id"] == "engagement/dau"
+
+
+class TestMetricRepositoryUpdate:
+    def test_update_fields(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        repo.create(**SAMPLE_METRIC)
+        updated = repo.update("revenue/mrr", display_name="MRR (New)", unit="EUR")
+        assert updated is not None
+        assert updated["display_name"] == "MRR (New)"
+        assert updated["unit"] == "EUR"
+        # Unchanged fields should persist
+        assert updated["name"] == "mrr"
+        assert updated["category"] == "revenue"
+        assert updated["description"] == "Total MRR from all subscriptions"
+
+    def test_update_missing_returns_none(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        result = repo.update("nonexistent/metric", display_name="Doesn't matter")
+        assert result is None
+
+    def test_update_persists_to_db(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        repo.create(**SAMPLE_METRIC)
+        repo.update("revenue/mrr", unit="GBP")
+        # Re-fetch from DB to verify persistence
+        metric = repo.get("revenue/mrr")
+        assert metric["unit"] == "GBP"
+
+
+class TestMetricRepositoryDelete:
+    def test_delete_existing(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        repo.create(**SAMPLE_METRIC)
+        result = repo.delete("revenue/mrr")
+        assert result is True
+        assert repo.get("revenue/mrr") is None
+
+    def test_delete_missing(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        result = repo.delete("nonexistent/metric")
+        assert result is False
+
+    def test_delete_only_target(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        repo.create(**SAMPLE_METRIC)
+        repo.create(
+            id="engagement/dau",
+            name="dau",
+            display_name="Daily Active Users",
+            category="engagement",
+            sql="SELECT 1",
+        )
+        repo.delete("revenue/mrr")
+        all_metrics = repo.list()
+        assert len(all_metrics) == 1
+        assert all_metrics[0]["id"] == "engagement/dau"
+
+
+class TestMetricRepositorySearch:
+    def test_find_by_table(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        # 2 metrics with table_name='subscriptions'
+        repo.create(**SAMPLE_METRIC)
+        repo.create(
+            id="revenue/arr",
+            name="arr",
+            display_name="Annual Recurring Revenue",
+            category="revenue",
+            table_name="subscriptions",
+            sql="SELECT SUM(mrr_amount) * 12 AS arr FROM subscriptions",
+        )
+        # 1 metric with different table
+        repo.create(
+            id="engagement/dau",
+            name="dau",
+            display_name="Daily Active Users",
+            category="engagement",
+            table_name="events",
+            sql="SELECT COUNT(DISTINCT user_id) FROM events",
+        )
+        sub_metrics = repo.find_by_table("subscriptions")
+        assert len(sub_metrics) == 2
+        ids = {m["id"] for m in sub_metrics}
+        assert "revenue/mrr" in ids
+        assert "revenue/arr" in ids
+
+        event_metrics = repo.find_by_table("events")
+        assert len(event_metrics) == 1
+        assert event_metrics[0]["id"] == "engagement/dau"
+
+    def test_find_by_synonym(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        repo.create(**SAMPLE_METRIC)  # has synonyms: ["monthly_revenue", "recurring_revenue"]
+        repo.create(
+            id="engagement/dau",
+            name="dau",
+            display_name="Daily Active Users",
+            category="engagement",
+            synonyms=["active_users", "daily_users"],
+            sql="SELECT COUNT(DISTINCT user_id) FROM events",
+        )
+        results = repo.find_by_synonym("monthly_revenue")
+        assert len(results) == 1
+        assert results[0]["id"] == "revenue/mrr"
+
+        results2 = repo.find_by_synonym("active_users")
+        assert len(results2) == 1
+        assert results2[0]["id"] == "engagement/dau"
+
+        results3 = repo.find_by_synonym("nonexistent_synonym")
+        assert len(results3) == 0
+
+    def test_get_table_map(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        repo.create(**SAMPLE_METRIC)  # table_name='subscriptions'
+        repo.create(
+            id="revenue/arr",
+            name="arr",
+            display_name="Annual Recurring Revenue",
+            category="revenue",
+            table_name="subscriptions",
+            sql="SELECT SUM(mrr_amount) * 12 FROM subscriptions",
+        )
+        repo.create(
+            id="engagement/dau",
+            name="dau",
+            display_name="Daily Active Users",
+            category="engagement",
+            table_name="events",
+            sql="SELECT COUNT(DISTINCT user_id) FROM events",
+        )
+        table_map = repo.get_table_map()
+        assert isinstance(table_map, dict)
+        assert "subscriptions" in table_map
+        assert "events" in table_map
+        assert set(table_map["subscriptions"]) == {"mrr", "arr"}
+        assert table_map["events"] == ["dau"]
+
+    def test_get_table_map_excludes_null_table(self, db_conn):
+        from src.repositories.metrics import MetricRepository
+        repo = MetricRepository(db_conn)
+        # Metric without table_name
+        repo.create(
+            id="test/no_table",
+            name="no_table",
+            display_name="No Table Metric",
+            category="test",
+            sql="SELECT 1",
+        )
+        table_map = repo.get_table_map()
+        assert "None" not in table_map
+        assert None not in table_map