feat: add import_from_yaml and export_to_yaml to MetricRepository
Adds YAML-based bulk import/export to MetricRepository, supporting list-wrapped and plain-dict YAML formats, table→table_name field mapping, and sql_by_* → sql_variants collection (and reverse on export). All 24 tests pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
88d536ca29
commit
a65de8574e
2 changed files with 284 additions and 1 deletions
|
|
@ -2,9 +2,11 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Optional, List, Dict
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, List, Dict, Union
|
||||||
|
|
||||||
import duckdb
|
import duckdb
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
def json_dumps(obj) -> Optional[str]:
|
def json_dumps(obj) -> Optional[str]:
|
||||||
|
|
@ -175,3 +177,169 @@ class MetricRepository:
|
||||||
for table_name, metric_name in rows:
|
for table_name, metric_name in rows:
|
||||||
result.setdefault(table_name, []).append(metric_name)
|
result.setdefault(table_name, []).append(metric_name)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def import_from_yaml(self, path: Union[str, Path]) -> int:
|
||||||
|
"""Import metrics from a YAML file or directory of YAML files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Path to a single .yml file or a directory containing */*.yml files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of metrics imported.
|
||||||
|
"""
|
||||||
|
path = Path(path)
|
||||||
|
files: List[Path] = []
|
||||||
|
|
||||||
|
if path.is_file():
|
||||||
|
files = [path]
|
||||||
|
elif path.is_dir():
|
||||||
|
files = sorted(path.glob("*/*.yml"))
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for file_path in files:
|
||||||
|
# Infer category from parent directory name
|
||||||
|
category_from_dir = file_path.parent.name
|
||||||
|
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
raw = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# Support both list-wrapped [{...}] and plain {...} formats
|
||||||
|
if isinstance(raw, list):
|
||||||
|
metrics_data = raw
|
||||||
|
elif isinstance(raw, dict):
|
||||||
|
metrics_data = [raw]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for data in metrics_data:
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = data.get("name")
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
category = data.get("category") or category_from_dir
|
||||||
|
|
||||||
|
# Build id as "category/name"
|
||||||
|
metric_id = f"{category}/{name}"
|
||||||
|
|
||||||
|
# Map YAML 'table' -> DB 'table_name'
|
||||||
|
table_name = data.get("table") or data.get("table_name")
|
||||||
|
|
||||||
|
# Collect sql_by_* keys into sql_variants dict
|
||||||
|
# e.g. sql_by_channel → {"by_channel": "..."}
|
||||||
|
sql_variants: Dict[str, str] = {}
|
||||||
|
for key, value in data.items():
|
||||||
|
if key.startswith("sql_by_"):
|
||||||
|
variant_key = key[len("sql_"):] # strip 'sql_' prefix → 'by_channel'
|
||||||
|
sql_variants[variant_key] = value
|
||||||
|
|
||||||
|
self.create(
|
||||||
|
id=metric_id,
|
||||||
|
name=name,
|
||||||
|
display_name=data.get("display_name", name),
|
||||||
|
category=category,
|
||||||
|
sql=data.get("sql", ""),
|
||||||
|
description=data.get("description"),
|
||||||
|
type=data.get("type", "sum"),
|
||||||
|
unit=data.get("unit"),
|
||||||
|
grain=data.get("grain", "monthly"),
|
||||||
|
table_name=table_name,
|
||||||
|
tables=data.get("tables"),
|
||||||
|
expression=data.get("expression"),
|
||||||
|
time_column=data.get("time_column"),
|
||||||
|
dimensions=data.get("dimensions"),
|
||||||
|
filters=data.get("filters"),
|
||||||
|
synonyms=data.get("synonyms"),
|
||||||
|
notes=data.get("notes"),
|
||||||
|
sql_variants=sql_variants if sql_variants else None,
|
||||||
|
validation=data.get("validation"),
|
||||||
|
source="yaml_import",
|
||||||
|
)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
def export_to_yaml(self, output_dir: Union[str, Path]) -> int:
|
||||||
|
"""Export all metrics to YAML files under output_dir/{category}/{name}.yml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Root directory for the exported YAML files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of metrics exported.
|
||||||
|
"""
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
metrics = self.list()
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
for metric in metrics:
|
||||||
|
category = metric.get("category") or "uncategorized"
|
||||||
|
name = metric.get("name") or metric["id"].split("/")[-1]
|
||||||
|
|
||||||
|
category_dir = output_dir / category
|
||||||
|
category_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Build the YAML dict — map table_name back to table
|
||||||
|
data: Dict[str, Any] = {"name": name}
|
||||||
|
if metric.get("display_name"):
|
||||||
|
data["display_name"] = metric["display_name"]
|
||||||
|
data["category"] = category
|
||||||
|
if metric.get("type"):
|
||||||
|
data["type"] = metric["type"]
|
||||||
|
if metric.get("unit"):
|
||||||
|
data["unit"] = metric["unit"]
|
||||||
|
if metric.get("grain"):
|
||||||
|
data["grain"] = metric["grain"]
|
||||||
|
if metric.get("time_column"):
|
||||||
|
data["time_column"] = metric["time_column"]
|
||||||
|
# Use 'table' (not 'table_name') in YAML output
|
||||||
|
if metric.get("table_name"):
|
||||||
|
data["table"] = metric["table_name"]
|
||||||
|
if metric.get("expression"):
|
||||||
|
data["expression"] = metric["expression"]
|
||||||
|
if metric.get("description"):
|
||||||
|
data["description"] = metric["description"]
|
||||||
|
if metric.get("dimensions"):
|
||||||
|
data["dimensions"] = metric["dimensions"]
|
||||||
|
if metric.get("filters"):
|
||||||
|
data["filters"] = metric["filters"]
|
||||||
|
if metric.get("synonyms"):
|
||||||
|
data["synonyms"] = metric["synonyms"]
|
||||||
|
if metric.get("notes"):
|
||||||
|
data["notes"] = metric["notes"]
|
||||||
|
if metric.get("sql"):
|
||||||
|
data["sql"] = metric["sql"]
|
||||||
|
|
||||||
|
# Expand sql_variants back to sql_by_* keys
|
||||||
|
sql_variants = metric.get("sql_variants")
|
||||||
|
if sql_variants:
|
||||||
|
if isinstance(sql_variants, str):
|
||||||
|
try:
|
||||||
|
sql_variants = json.loads(sql_variants)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
sql_variants = {}
|
||||||
|
if isinstance(sql_variants, dict):
|
||||||
|
for variant_key, variant_sql in sql_variants.items():
|
||||||
|
# variant_key is e.g. 'by_channel' → YAML key 'sql_by_channel'
|
||||||
|
data[f"sql_{variant_key}"] = variant_sql
|
||||||
|
|
||||||
|
# Handle validation JSON
|
||||||
|
validation = metric.get("validation")
|
||||||
|
if validation:
|
||||||
|
if isinstance(validation, str):
|
||||||
|
try:
|
||||||
|
validation = json.loads(validation)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
validation = None
|
||||||
|
if validation:
|
||||||
|
data["validation"] = validation
|
||||||
|
|
||||||
|
out_file = category_dir / f"{name}.yml"
|
||||||
|
with open(out_file, "w") as f:
|
||||||
|
yaml.dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
||||||
|
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
|
||||||
|
|
@ -304,3 +304,118 @@ class TestMetricRepositorySearch:
|
||||||
table_map = repo.get_table_map()
|
table_map = repo.get_table_map()
|
||||||
assert "None" not in table_map
|
assert "None" not in table_map
|
||||||
assert None not in table_map
|
assert None not in table_map
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def metrics_dir(tmp_path):
|
||||||
|
revenue_dir = tmp_path / "metrics" / "revenue"
|
||||||
|
revenue_dir.mkdir(parents=True)
|
||||||
|
ops_dir = tmp_path / "metrics" / "operations"
|
||||||
|
ops_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
# total_revenue.yml — list-wrapped format with table key and sql_by_channel variant
|
||||||
|
(revenue_dir / "total_revenue.yml").write_text(
|
||||||
|
"- name: total_revenue\n"
|
||||||
|
" display_name: Total Revenue\n"
|
||||||
|
" category: revenue\n"
|
||||||
|
" type: sum\n"
|
||||||
|
" unit: USD\n"
|
||||||
|
" grain: monthly\n"
|
||||||
|
" table: orders\n"
|
||||||
|
" sql: |\n"
|
||||||
|
" SELECT DATE_TRUNC('month', order_date) AS month, SUM(total_amount) AS revenue FROM orders GROUP BY 1\n"
|
||||||
|
" sql_by_channel: |\n"
|
||||||
|
" SELECT channel, SUM(total_amount) AS revenue FROM orders GROUP BY 1\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# resolution_time.yml — plain dict format (no list wrapper)
|
||||||
|
(ops_dir / "resolution_time.yml").write_text(
|
||||||
|
"name: resolution_time\n"
|
||||||
|
"display_name: Resolution Time\n"
|
||||||
|
"type: avg\n"
|
||||||
|
"unit: hours\n"
|
||||||
|
"grain: weekly\n"
|
||||||
|
"table: tickets\n"
|
||||||
|
"sql: |\n"
|
||||||
|
" SELECT AVG(resolution_hours) FROM tickets\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
return tmp_path / "metrics"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricRepositoryImport:
|
||||||
|
def test_import_from_directory(self, db_conn, metrics_dir):
|
||||||
|
from src.repositories.metrics import MetricRepository
|
||||||
|
repo = MetricRepository(db_conn)
|
||||||
|
count = repo.import_from_yaml(metrics_dir)
|
||||||
|
assert count == 2
|
||||||
|
all_metrics = repo.list()
|
||||||
|
assert len(all_metrics) == 2
|
||||||
|
ids = {m["id"] for m in all_metrics}
|
||||||
|
assert "revenue/total_revenue" in ids
|
||||||
|
assert "operations/resolution_time" in ids
|
||||||
|
|
||||||
|
def test_import_maps_table_to_table_name(self, db_conn, metrics_dir):
|
||||||
|
from src.repositories.metrics import MetricRepository
|
||||||
|
repo = MetricRepository(db_conn)
|
||||||
|
repo.import_from_yaml(metrics_dir)
|
||||||
|
metric = repo.get("revenue/total_revenue")
|
||||||
|
assert metric is not None
|
||||||
|
assert metric["table_name"] == "orders"
|
||||||
|
|
||||||
|
def test_import_collects_sql_variants(self, db_conn, metrics_dir):
|
||||||
|
from src.repositories.metrics import MetricRepository
|
||||||
|
import json
|
||||||
|
repo = MetricRepository(db_conn)
|
||||||
|
repo.import_from_yaml(metrics_dir)
|
||||||
|
metric = repo.get("revenue/total_revenue")
|
||||||
|
assert metric is not None
|
||||||
|
sql_variants = metric["sql_variants"]
|
||||||
|
# DuckDB may return as a string — parse if so
|
||||||
|
if isinstance(sql_variants, str):
|
||||||
|
sql_variants = json.loads(sql_variants)
|
||||||
|
assert isinstance(sql_variants, dict)
|
||||||
|
assert "by_channel" in sql_variants
|
||||||
|
assert "channel" in sql_variants["by_channel"]
|
||||||
|
|
||||||
|
def test_import_single_file(self, db_conn, metrics_dir):
|
||||||
|
from src.repositories.metrics import MetricRepository
|
||||||
|
repo = MetricRepository(db_conn)
|
||||||
|
single_file = metrics_dir / "revenue" / "total_revenue.yml"
|
||||||
|
count = repo.import_from_yaml(single_file)
|
||||||
|
assert count == 1
|
||||||
|
metric = repo.get("revenue/total_revenue")
|
||||||
|
assert metric is not None
|
||||||
|
|
||||||
|
def test_import_idempotent(self, db_conn, metrics_dir):
|
||||||
|
from src.repositories.metrics import MetricRepository
|
||||||
|
repo = MetricRepository(db_conn)
|
||||||
|
repo.import_from_yaml(metrics_dir)
|
||||||
|
repo.import_from_yaml(metrics_dir)
|
||||||
|
all_metrics = repo.list()
|
||||||
|
assert len(all_metrics) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricRepositoryExport:
|
||||||
|
def test_export_to_yaml(self, db_conn, metrics_dir, tmp_path):
|
||||||
|
from src.repositories.metrics import MetricRepository
|
||||||
|
import yaml
|
||||||
|
repo = MetricRepository(db_conn)
|
||||||
|
repo.import_from_yaml(metrics_dir)
|
||||||
|
output_dir = tmp_path / "exported"
|
||||||
|
count = repo.export_to_yaml(output_dir)
|
||||||
|
assert count == 2
|
||||||
|
# Check expected files exist
|
||||||
|
revenue_file = output_dir / "revenue" / "total_revenue.yml"
|
||||||
|
ops_file = output_dir / "operations" / "resolution_time.yml"
|
||||||
|
assert revenue_file.exists()
|
||||||
|
assert ops_file.exists()
|
||||||
|
# Verify content uses 'table' not 'table_name'
|
||||||
|
with open(revenue_file) as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
assert "table" in data
|
||||||
|
assert "table_name" not in data
|
||||||
|
assert data["table"] == "orders"
|
||||||
|
# Verify sql_variants are expanded back to sql_by_* keys
|
||||||
|
assert "sql_by_channel" in data
|
||||||
|
assert "sql_variants" not in data
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue