From bf90f06774d0d2a39f952ab7c9c878263a23a836 Mon Sep 17 00:00:00 2001
From: ZdenekSrotyr
Date: Fri, 10 Apr 2026 19:41:53 +0200
Subject: [PATCH] feat: add ColumnMetadataRepository with CRUD and proposal import

---
 src/repositories/column_metadata.py | 118 +++++++++++++++++++++++
 tests/test_column_metadata.py       | 163 +++++++++++++++++++++++++++++++
 2 files changed, 281 insertions(+)
 create mode 100644 src/repositories/column_metadata.py
 create mode 100644 tests/test_column_metadata.py

diff --git a/src/repositories/column_metadata.py b/src/repositories/column_metadata.py
new file mode 100644
index 0000000..26e78dc
--- /dev/null
+++ b/src/repositories/column_metadata.py
@@ -0,0 +1,118 @@
+"""Repository for column metadata."""
+
+import json
+from datetime import datetime, timezone
+from typing import Any, Optional, List, Dict
+
+import duckdb
+
+
+class ColumnMetadataRepository:
+    """CRUD access to column_metadata rows keyed by (table_id, column_name)."""
+
+    def __init__(self, conn: duckdb.DuckDBPyConnection):
+        self.conn = conn
+
+    @staticmethod
+    def _rows_to_dicts(cursor, rows) -> List[Dict[str, Any]]:
+        """Pair each fetched row with the column names reported by cursor."""
+        columns = [desc[0] for desc in cursor.description]
+        return [dict(zip(columns, row)) for row in rows]
+
+    def save(
+        self,
+        table_id: str,
+        column_name: str,
+        basetype: Optional[str] = None,
+        description: Optional[str] = None,
+        confidence: str = "manual",
+        source: str = "manual",
+    ) -> dict:
+        """Insert or update column metadata. Returns the saved record.
+
+        An existing (table_id, column_name) row has every other field
+        overwritten, and updated_at is set to the current UTC time.
+        """
+        now = datetime.now(timezone.utc)
+        self.conn.execute(
+            """INSERT INTO column_metadata (table_id, column_name, basetype, description, confidence, source, updated_at)
+               VALUES (?, ?, ?, ?, ?, ?, ?)
+               ON CONFLICT (table_id, column_name) DO UPDATE SET
+                   basetype = excluded.basetype,
+                   description = excluded.description,
+                   confidence = excluded.confidence,
+                   source = excluded.source,
+                   updated_at = excluded.updated_at""",
+            [table_id, column_name, basetype, description, confidence, source, now],
+        )
+        return self.get(table_id, column_name)
+
+    def get(self, table_id: str, column_name: str) -> Optional[Dict[str, Any]]:
+        """Select by composite PK. Returns None if not found."""
+        cursor = self.conn.execute(
+            "SELECT * FROM column_metadata WHERE table_id = ? AND column_name = ?",
+            [table_id, column_name],
+        )
+        row = cursor.fetchone()
+        if row is None:
+            return None
+        return self._rows_to_dicts(cursor, [row])[0]
+
+    def list_for_table(self, table_id: str) -> List[Dict[str, Any]]:
+        """Select all columns for a table, ordered by column_name."""
+        cursor = self.conn.execute(
+            "SELECT * FROM column_metadata WHERE table_id = ? ORDER BY column_name",
+            [table_id],
+        )
+        return self._rows_to_dicts(cursor, cursor.fetchall())
+
+    def delete(self, table_id: str, column_name: str) -> bool:
+        """Delete column metadata. Returns True if a row was deleted."""
+        # RETURNING yields the removed rows, so one statement both deletes
+        # and reports whether anything matched (no separate COUNT query).
+        removed = self.conn.execute(
+            "DELETE FROM column_metadata WHERE table_id = ? AND column_name = ? "
+            "RETURNING column_name",
+            [table_id, column_name],
+        ).fetchall()
+        return bool(removed)
+
+    def import_proposal(self, proposal_path: str) -> int:
+        """Import a proposal JSON file.
+
+        Format:
+        {
+          "tables": {
+            "orders": {
+              "columns": {
+                "id": {"basetype": "STRING", "description": "...", "confidence": "high"}
+              }
+            }
+          }
+        }
+
+        Missing or null "tables" / "columns" sections and null column entries
+        are treated as empty; a missing or null confidence becomes "high".
+
+        Sets source="ai_enrichment". Returns count of columns imported.
+        """
+        with open(proposal_path, "r", encoding="utf-8") as f:
+            proposal = json.load(f)
+
+        count = 0
+        # "or {}" also covers keys that are present but explicitly null.
+        tables = proposal.get("tables") or {}
+        for table_id, table_data in tables.items():
+            columns = (table_data or {}).get("columns") or {}
+            for column_name, col_data in columns.items():
+                col_data = col_data or {}
+                self.save(
+                    table_id=table_id,
+                    column_name=column_name,
+                    basetype=col_data.get("basetype"),
+                    description=col_data.get("description"),
+                    confidence=col_data.get("confidence") or "high",
+                    source="ai_enrichment",
+                )
+                count += 1
+        return count
diff --git a/tests/test_column_metadata.py b/tests/test_column_metadata.py
new file mode 100644
index 0000000..a183b9c
--- /dev/null
+++ b/tests/test_column_metadata.py
@@ -0,0 +1,163 @@
+"""Tests for ColumnMetadataRepository."""
+
+import json
+import pytest
+
+
+@pytest.fixture
+def db_conn(tmp_path, monkeypatch):
+    monkeypatch.setenv("DATA_DIR", str(tmp_path))
+    from src.db import get_system_db
+    conn = get_system_db()
+    yield conn
+    conn.close()
+
+
+@pytest.fixture
+def repo(db_conn):
+    from src.repositories.column_metadata import ColumnMetadataRepository
+    return ColumnMetadataRepository(db_conn)
+
+
+class TestColumnMetadataCreate:
+    def test_save_single_column(self, repo):
+        result = repo.save("orders", "id", basetype="STRING", description="Order ID")
+        assert result["table_id"] == "orders"
+        assert result["column_name"] == "id"
+        assert result["basetype"] == "STRING"
+        assert result["description"] == "Order ID"
+        assert result["confidence"] == "manual"
+        assert result["source"] == "manual"
+
+    def test_upsert_overwrites(self, repo):
+        repo.save("orders", "id", basetype="STRING", description="Old")
+        result = repo.save("orders", "id", basetype="INTEGER", description="New", confidence="high")
+        assert result["basetype"] == "INTEGER"
+        assert result["description"] == "New"
+        assert result["confidence"] == "high"
+        # Should still be only one row
+        rows = repo.list_for_table("orders")
+        assert len(rows) == 1
+
+
+class TestColumnMetadataRead:
+    def test_list_for_table_filters_by_table(self, repo):
+        repo.save("orders", "id", basetype="STRING")
+        repo.save("orders", "total", basetype="NUMERIC")
+        repo.save("orders", "status", basetype="STRING")
+        repo.save("customers", "email", basetype="STRING")
+
+        orders_cols = repo.list_for_table("orders")
+        assert len(orders_cols) == 3
+        assert all(c["table_id"] == "orders" for c in orders_cols)
+
+        customer_cols = repo.list_for_table("customers")
+        assert len(customer_cols) == 1
+        assert customer_cols[0]["column_name"] == "email"
+
+    def test_list_for_table_ordered_by_column_name(self, repo):
+        repo.save("orders", "total", basetype="NUMERIC")
+        repo.save("orders", "id", basetype="STRING")
+        repo.save("orders", "status", basetype="STRING")
+
+        cols = repo.list_for_table("orders")
+        names = [c["column_name"] for c in cols]
+        assert names == sorted(names)
+
+    def test_get_missing_returns_none(self, repo):
+        result = repo.get("orders", "nonexistent")
+        assert result is None
+
+
+class TestColumnMetadataDelete:
+    def test_delete_column(self, repo):
+        repo.save("orders", "id", basetype="STRING")
+        deleted = repo.delete("orders", "id")
+        assert deleted is True
+        assert repo.get("orders", "id") is None
+
+    def test_delete_missing_returns_false(self, repo):
+        result = repo.delete("orders", "does_not_exist")
+        assert result is False
+
+
+class TestColumnMetadataProposal:
+    def test_import_proposal_count(self, repo, tmp_path):
+        proposal = {
+            "tables": {
+                "orders": {
+                    "columns": {
+                        "id": {"basetype": "STRING", "description": "Order ID", "confidence": "high"},
+                        "total": {"basetype": "NUMERIC", "description": "Total amount"},
+                    }
+                },
+                "customers": {
+                    "columns": {
+                        "email": {"basetype": "STRING", "description": "Customer email", "confidence": "medium"},
+                    }
+                },
+            }
+        }
+        path = tmp_path / "proposal.json"
+        path.write_text(json.dumps(proposal))
+
+        count = repo.import_proposal(str(path))
+        assert count == 3
+
+    def test_import_proposal_data(self, repo, tmp_path):
+        proposal = {
+            "tables": {
+                "orders": {
+                    "columns": {
+                        "id": {"basetype": "STRING", "description": "Order ID", "confidence": "high"},
+                    }
+                }
+            }
+        }
+        path = tmp_path / "proposal.json"
+        path.write_text(json.dumps(proposal))
+
+        repo.import_proposal(str(path))
+
+        result = repo.get("orders", "id")
+        assert result is not None
+        assert result["basetype"] == "STRING"
+        assert result["description"] == "Order ID"
+        assert result["confidence"] == "high"
+
+    def test_import_sets_source_ai_enrichment(self, repo, tmp_path):
+        proposal = {
+            "tables": {
+                "orders": {
+                    "columns": {
+                        "id": {"basetype": "STRING"},
+                    }
+                }
+            }
+        }
+        path = tmp_path / "proposal.json"
+        path.write_text(json.dumps(proposal))
+
+        repo.import_proposal(str(path))
+
+        result = repo.get("orders", "id")
+        assert result["source"] == "ai_enrichment"
+
+    def test_import_tolerates_null_sections(self, repo, tmp_path):
+        proposal = {
+            "tables": {
+                "orders": {
+                    "columns": {"id": {"basetype": "STRING", "confidence": None}},
+                },
+                "customers": {"columns": None},
+                "events": None,
+            }
+        }
+        path = tmp_path / "proposal.json"
+        path.write_text(json.dumps(proposal))
+
+        count = repo.import_proposal(str(path))
+        assert count == 1
+
+        result = repo.get("orders", "id")
+        assert result["confidence"] == "high"