feat: add ColumnMetadataRepository with CRUD and proposal import
This commit is contained in:
parent
c825ead209
commit
bf90f06774
2 changed files with 251 additions and 0 deletions
107
src/repositories/column_metadata.py
Normal file
107
src/repositories/column_metadata.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""Repository for column metadata."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, List, Dict
|
||||
|
||||
import duckdb
|
||||
|
||||
|
||||
class ColumnMetadataRepository:
|
||||
def __init__(self, conn: duckdb.DuckDBPyConnection):
|
||||
self.conn = conn
|
||||
|
||||
def save(
|
||||
self,
|
||||
table_id: str,
|
||||
column_name: str,
|
||||
basetype: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
confidence: str = "manual",
|
||||
source: str = "manual",
|
||||
) -> dict:
|
||||
"""Insert or update column metadata. Returns the saved record."""
|
||||
now = datetime.now(timezone.utc)
|
||||
self.conn.execute(
|
||||
"""INSERT INTO column_metadata (table_id, column_name, basetype, description, confidence, source, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (table_id, column_name) DO UPDATE SET
|
||||
basetype = excluded.basetype,
|
||||
description = excluded.description,
|
||||
confidence = excluded.confidence,
|
||||
source = excluded.source,
|
||||
updated_at = excluded.updated_at""",
|
||||
[table_id, column_name, basetype, description, confidence, source, now],
|
||||
)
|
||||
return self.get(table_id, column_name)
|
||||
|
||||
def get(self, table_id: str, column_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Select by composite PK. Returns None if not found."""
|
||||
result = self.conn.execute(
|
||||
"SELECT * FROM column_metadata WHERE table_id = ? AND column_name = ?",
|
||||
[table_id, column_name],
|
||||
).fetchone()
|
||||
if not result:
|
||||
return None
|
||||
columns = [desc[0] for desc in self.conn.description]
|
||||
return dict(zip(columns, result))
|
||||
|
||||
def list_for_table(self, table_id: str) -> List[Dict[str, Any]]:
|
||||
"""Select all columns for a table, ordered by column_name."""
|
||||
results = self.conn.execute(
|
||||
"SELECT * FROM column_metadata WHERE table_id = ? ORDER BY column_name",
|
||||
[table_id],
|
||||
).fetchall()
|
||||
if not results:
|
||||
return []
|
||||
columns = [desc[0] for desc in self.conn.description]
|
||||
return [dict(zip(columns, row)) for row in results]
|
||||
|
||||
def delete(self, table_id: str, column_name: str) -> bool:
|
||||
"""Delete column metadata. Returns True if a row was deleted."""
|
||||
before = self.conn.execute(
|
||||
"SELECT COUNT(*) FROM column_metadata WHERE table_id = ? AND column_name = ?",
|
||||
[table_id, column_name],
|
||||
).fetchone()[0]
|
||||
if before == 0:
|
||||
return False
|
||||
self.conn.execute(
|
||||
"DELETE FROM column_metadata WHERE table_id = ? AND column_name = ?",
|
||||
[table_id, column_name],
|
||||
)
|
||||
return True
|
||||
|
||||
def import_proposal(self, proposal_path: str) -> int:
|
||||
"""Import a proposal JSON file.
|
||||
|
||||
Format:
|
||||
{
|
||||
"tables": {
|
||||
"orders": {
|
||||
"columns": {
|
||||
"id": {"basetype": "STRING", "description": "...", "confidence": "high"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Sets source="ai_enrichment". Returns count of columns imported.
|
||||
"""
|
||||
with open(proposal_path, "r", encoding="utf-8") as f:
|
||||
proposal = json.load(f)
|
||||
|
||||
count = 0
|
||||
tables = proposal.get("tables", {})
|
||||
for table_id, table_data in tables.items():
|
||||
columns = table_data.get("columns", {})
|
||||
for column_name, col_data in columns.items():
|
||||
self.save(
|
||||
table_id=table_id,
|
||||
column_name=column_name,
|
||||
basetype=col_data.get("basetype"),
|
||||
description=col_data.get("description"),
|
||||
confidence=col_data.get("confidence", "high"),
|
||||
source="ai_enrichment",
|
||||
)
|
||||
count += 1
|
||||
return count
|
||||
144
tests/test_column_metadata.py
Normal file
144
tests/test_column_metadata.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""Tests for ColumnMetadataRepository."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_conn(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("DATA_DIR", str(tmp_path))
|
||||
from src.db import get_system_db
|
||||
conn = get_system_db()
|
||||
yield conn
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repo(db_conn):
|
||||
from src.repositories.column_metadata import ColumnMetadataRepository
|
||||
return ColumnMetadataRepository(db_conn)
|
||||
|
||||
|
||||
class TestColumnMetadataCreate:
|
||||
def test_save_single_column(self, repo):
|
||||
result = repo.save("orders", "id", basetype="STRING", description="Order ID")
|
||||
assert result["table_id"] == "orders"
|
||||
assert result["column_name"] == "id"
|
||||
assert result["basetype"] == "STRING"
|
||||
assert result["description"] == "Order ID"
|
||||
assert result["confidence"] == "manual"
|
||||
assert result["source"] == "manual"
|
||||
|
||||
def test_upsert_overwrites(self, repo):
|
||||
repo.save("orders", "id", basetype="STRING", description="Old")
|
||||
result = repo.save("orders", "id", basetype="INTEGER", description="New", confidence="high")
|
||||
assert result["basetype"] == "INTEGER"
|
||||
assert result["description"] == "New"
|
||||
assert result["confidence"] == "high"
|
||||
# Should still be only one row
|
||||
rows = repo.list_for_table("orders")
|
||||
assert len(rows) == 1
|
||||
|
||||
|
||||
class TestColumnMetadataRead:
|
||||
def test_list_for_table_filters_by_table(self, repo):
|
||||
repo.save("orders", "id", basetype="STRING")
|
||||
repo.save("orders", "total", basetype="NUMERIC")
|
||||
repo.save("orders", "status", basetype="STRING")
|
||||
repo.save("customers", "email", basetype="STRING")
|
||||
|
||||
orders_cols = repo.list_for_table("orders")
|
||||
assert len(orders_cols) == 3
|
||||
assert all(c["table_id"] == "orders" for c in orders_cols)
|
||||
|
||||
customer_cols = repo.list_for_table("customers")
|
||||
assert len(customer_cols) == 1
|
||||
assert customer_cols[0]["column_name"] == "email"
|
||||
|
||||
def test_list_for_table_ordered_by_column_name(self, repo):
|
||||
repo.save("orders", "total", basetype="NUMERIC")
|
||||
repo.save("orders", "id", basetype="STRING")
|
||||
repo.save("orders", "status", basetype="STRING")
|
||||
|
||||
cols = repo.list_for_table("orders")
|
||||
names = [c["column_name"] for c in cols]
|
||||
assert names == sorted(names)
|
||||
|
||||
def test_get_missing_returns_none(self, repo):
|
||||
result = repo.get("orders", "nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestColumnMetadataDelete:
|
||||
def test_delete_column(self, repo):
|
||||
repo.save("orders", "id", basetype="STRING")
|
||||
deleted = repo.delete("orders", "id")
|
||||
assert deleted is True
|
||||
assert repo.get("orders", "id") is None
|
||||
|
||||
def test_delete_missing_returns_false(self, repo):
|
||||
result = repo.delete("orders", "does_not_exist")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestColumnMetadataProposal:
|
||||
def test_import_proposal_count(self, repo, tmp_path):
|
||||
proposal = {
|
||||
"tables": {
|
||||
"orders": {
|
||||
"columns": {
|
||||
"id": {"basetype": "STRING", "description": "Order ID", "confidence": "high"},
|
||||
"total": {"basetype": "NUMERIC", "description": "Total amount"},
|
||||
}
|
||||
},
|
||||
"customers": {
|
||||
"columns": {
|
||||
"email": {"basetype": "STRING", "description": "Customer email", "confidence": "medium"},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
path = tmp_path / "proposal.json"
|
||||
path.write_text(json.dumps(proposal))
|
||||
|
||||
count = repo.import_proposal(str(path))
|
||||
assert count == 3
|
||||
|
||||
def test_import_proposal_data(self, repo, tmp_path):
|
||||
proposal = {
|
||||
"tables": {
|
||||
"orders": {
|
||||
"columns": {
|
||||
"id": {"basetype": "STRING", "description": "Order ID", "confidence": "high"},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
path = tmp_path / "proposal.json"
|
||||
path.write_text(json.dumps(proposal))
|
||||
|
||||
repo.import_proposal(str(path))
|
||||
|
||||
result = repo.get("orders", "id")
|
||||
assert result is not None
|
||||
assert result["basetype"] == "STRING"
|
||||
assert result["description"] == "Order ID"
|
||||
assert result["confidence"] == "high"
|
||||
|
||||
def test_import_sets_source_ai_enrichment(self, repo, tmp_path):
|
||||
proposal = {
|
||||
"tables": {
|
||||
"orders": {
|
||||
"columns": {
|
||||
"id": {"basetype": "STRING"},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
path = tmp_path / "proposal.json"
|
||||
path.write_text(json.dumps(proposal))
|
||||
|
||||
repo.import_proposal(str(path))
|
||||
|
||||
result = repo.get("orders", "id")
|
||||
assert result["source"] == "ai_enrichment"
|
||||
Loading…
Reference in a new issue