From a65de8574ec32ea1f7ba56b5e37a5ebbb497b48c Mon Sep 17 00:00:00 2001 From: ZdenekSrotyr Date: Fri, 10 Apr 2026 19:25:11 +0200 Subject: [PATCH] feat: add import_from_yaml and export_to_yaml to MetricRepository MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds YAML-based bulk import/export to MetricRepository, supporting list-wrapped and plain-dict YAML formats, table→table_name field mapping, and sql_by_* → sql_variants collection (and reverse on export). All 24 tests pass. Co-Authored-By: Claude Sonnet 4.6 --- src/repositories/metrics.py | 170 +++++++++++++++++++++++++++++++++++- tests/test_metrics.py | 115 ++++++++++++++++++++++++ 2 files changed, 284 insertions(+), 1 deletion(-) diff --git a/src/repositories/metrics.py b/src/repositories/metrics.py index 495a7a1..6335208 100644 --- a/src/repositories/metrics.py +++ b/src/repositories/metrics.py @@ -2,9 +2,11 @@ import json from datetime import datetime, timezone -from typing import Any, Optional, List, Dict +from pathlib import Path +from typing import Any, Optional, List, Dict, Union import duckdb +import yaml def json_dumps(obj) -> Optional[str]: @@ -175,3 +177,169 @@ class MetricRepository: for table_name, metric_name in rows: result.setdefault(table_name, []).append(metric_name) return result + + def import_from_yaml(self, path: Union[str, Path]) -> int: + """Import metrics from a YAML file or directory of YAML files. + + Args: + path: Path to a single .yml file or a directory containing */*.yml files. + + Returns: + Number of metrics imported. + """ + path = Path(path) + files: List[Path] = [] + + if path.is_file(): + files = [path] + elif path.is_dir(): + files = sorted(path.glob("*/*.yml")) + + count = 0 + for file_path in files: + # Infer category from parent directory name + category_from_dir = file_path.parent.name + + with open(file_path, "r") as f: + raw = yaml.safe_load(f) + + # Support both list-wrapped [{...}] and plain {...} formats + if isinstance(raw, list): + metrics_data = raw + elif isinstance(raw, dict): + metrics_data = [raw] + else: + continue + + for data in metrics_data: + if not isinstance(data, dict): + continue + + name = data.get("name") + if not name: + continue + + category = data.get("category") or category_from_dir + + # Build id as "category/name" + metric_id = f"{category}/{name}" + + # Map YAML 'table' -> DB 'table_name' + table_name = data.get("table") or data.get("table_name") + + # Collect sql_by_* keys into sql_variants dict + # e.g. sql_by_channel → {"by_channel": "..."} + sql_variants: Dict[str, str] = {} + for key, value in data.items(): + if key.startswith("sql_by_"): + variant_key = key[len("sql_"):] # strip 'sql_' prefix → 'by_channel' + sql_variants[variant_key] = value + + self.create( + id=metric_id, + name=name, + display_name=data.get("display_name", name), + category=category, + sql=data.get("sql", ""), + description=data.get("description"), + type=data.get("type", "sum"), + unit=data.get("unit"), + grain=data.get("grain", "monthly"), + table_name=table_name, + tables=data.get("tables"), + expression=data.get("expression"), + time_column=data.get("time_column"), + dimensions=data.get("dimensions"), + filters=data.get("filters"), + synonyms=data.get("synonyms"), + notes=data.get("notes"), + sql_variants=sql_variants if sql_variants else None, + validation=data.get("validation"), + source="yaml_import", + ) + count += 1 + + return count + + def export_to_yaml(self, output_dir: Union[str, Path]) -> int: + """Export all metrics to YAML files under output_dir/{category}/{name}.yml. + + Args: + output_dir: Root directory for the exported YAML files. + + Returns: + Number of metrics exported. + """ + output_dir = Path(output_dir) + metrics = self.list() + count = 0 + + for metric in metrics: + category = metric.get("category") or "uncategorized" + name = metric.get("name") or metric["id"].split("/")[-1] + + category_dir = output_dir / category + category_dir.mkdir(parents=True, exist_ok=True) + + # Build the YAML dict — map table_name back to table + data: Dict[str, Any] = {"name": name} + if metric.get("display_name"): + data["display_name"] = metric["display_name"] + data["category"] = category + if metric.get("type"): + data["type"] = metric["type"] + if metric.get("unit"): + data["unit"] = metric["unit"] + if metric.get("grain"): + data["grain"] = metric["grain"] + if metric.get("time_column"): + data["time_column"] = metric["time_column"] + # Use 'table' (not 'table_name') in YAML output + if metric.get("table_name"): + data["table"] = metric["table_name"] + if metric.get("expression"): + data["expression"] = metric["expression"] + if metric.get("description"): + data["description"] = metric["description"] + if metric.get("dimensions"): + data["dimensions"] = metric["dimensions"] + if metric.get("filters"): + data["filters"] = metric["filters"] + if metric.get("synonyms"): + data["synonyms"] = metric["synonyms"] + if metric.get("notes"): + data["notes"] = metric["notes"] + if metric.get("sql"): + data["sql"] = metric["sql"] + + # Expand sql_variants back to sql_by_* keys + sql_variants = metric.get("sql_variants") + if sql_variants: + if isinstance(sql_variants, str): + try: + sql_variants = json.loads(sql_variants) + except (json.JSONDecodeError, ValueError): + sql_variants = {} + if isinstance(sql_variants, dict): + for variant_key, variant_sql in sql_variants.items(): + # variant_key is e.g. 'by_channel' → YAML key 'sql_by_channel' + data[f"sql_{variant_key}"] = variant_sql + + # Handle validation JSON + validation = metric.get("validation") + if validation: + if isinstance(validation, str): + try: + validation = json.loads(validation) + except (json.JSONDecodeError, ValueError): + validation = None + if validation: + data["validation"] = validation + + out_file = category_dir / f"{name}.yml" + with open(out_file, "w") as f: + yaml.dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False) + + count += 1 + + return count diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 310f096..6ced1ac 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -304,3 +304,118 @@ class TestMetricRepositorySearch: table_map = repo.get_table_map() assert "None" not in table_map assert None not in table_map + + +@pytest.fixture +def metrics_dir(tmp_path): + revenue_dir = tmp_path / "metrics" / "revenue" + revenue_dir.mkdir(parents=True) + ops_dir = tmp_path / "metrics" / "operations" + ops_dir.mkdir(parents=True) + + # total_revenue.yml — list-wrapped format with table key and sql_by_channel variant + (revenue_dir / "total_revenue.yml").write_text( + "- name: total_revenue\n" + " display_name: Total Revenue\n" + " category: revenue\n" + " type: sum\n" + " unit: USD\n" + " grain: monthly\n" + " table: orders\n" + " sql: |\n" + " SELECT DATE_TRUNC('month', order_date) AS month, SUM(total_amount) AS revenue FROM orders GROUP BY 1\n" + " sql_by_channel: |\n" + " SELECT channel, SUM(total_amount) AS revenue FROM orders GROUP BY 1\n" + ) + + # resolution_time.yml — plain dict format (no list wrapper) + (ops_dir / "resolution_time.yml").write_text( + "name: resolution_time\n" + "display_name: Resolution Time\n" + "type: avg\n" + "unit: hours\n" + "grain: weekly\n" + "table: tickets\n" + "sql: |\n" + " SELECT AVG(resolution_hours) FROM tickets\n" + ) + + return tmp_path / "metrics" + + +class TestMetricRepositoryImport: + def test_import_from_directory(self, db_conn, metrics_dir): + from src.repositories.metrics import MetricRepository + repo = MetricRepository(db_conn) + count = repo.import_from_yaml(metrics_dir) + assert count == 2 + all_metrics = repo.list() + assert len(all_metrics) == 2 + ids = {m["id"] for m in all_metrics} + assert "revenue/total_revenue" in ids + assert "operations/resolution_time" in ids + + def test_import_maps_table_to_table_name(self, db_conn, metrics_dir): + from src.repositories.metrics import MetricRepository + repo = MetricRepository(db_conn) + repo.import_from_yaml(metrics_dir) + metric = repo.get("revenue/total_revenue") + assert metric is not None + assert metric["table_name"] == "orders" + + def test_import_collects_sql_variants(self, db_conn, metrics_dir): + from src.repositories.metrics import MetricRepository + import json + repo = MetricRepository(db_conn) + repo.import_from_yaml(metrics_dir) + metric = repo.get("revenue/total_revenue") + assert metric is not None + sql_variants = metric["sql_variants"] + # DuckDB may return as a string — parse if so + if isinstance(sql_variants, str): + sql_variants = json.loads(sql_variants) + assert isinstance(sql_variants, dict) + assert "by_channel" in sql_variants + assert "channel" in sql_variants["by_channel"] + + def test_import_single_file(self, db_conn, metrics_dir): + from src.repositories.metrics import MetricRepository + repo = MetricRepository(db_conn) + single_file = metrics_dir / "revenue" / "total_revenue.yml" + count = repo.import_from_yaml(single_file) + assert count == 1 + metric = repo.get("revenue/total_revenue") + assert metric is not None + + def test_import_idempotent(self, db_conn, metrics_dir): + from src.repositories.metrics import MetricRepository + repo = MetricRepository(db_conn) + repo.import_from_yaml(metrics_dir) + repo.import_from_yaml(metrics_dir) + all_metrics = repo.list() + assert len(all_metrics) == 2 + + +class TestMetricRepositoryExport: + def test_export_to_yaml(self, db_conn, metrics_dir, tmp_path): + from src.repositories.metrics import MetricRepository + import yaml + repo = MetricRepository(db_conn) + repo.import_from_yaml(metrics_dir) + output_dir = tmp_path / "exported" + count = repo.export_to_yaml(output_dir) + assert count == 2 + # Check expected files exist + revenue_file = output_dir / "revenue" / "total_revenue.yml" + ops_file = output_dir / "operations" / "resolution_time.yml" + assert revenue_file.exists() + assert ops_file.exists() + # Verify content uses 'table' not 'table_name' + with open(revenue_file) as f: + data = yaml.safe_load(f) + assert "table" in data + assert "table_name" not in data + assert data["table"] == "orders" + # Verify sql_variants are expanded back to sql_by_* keys + assert "sql_by_channel" in data + assert "sql_variants" not in data