From 79b0b66f2e720de8e946dc9e3eab2498ba5fc673 Mon Sep 17 00:00:00 2001 From: ZdenekSrotyr Date: Fri, 27 Mar 2026 15:06:55 +0100 Subject: [PATCH] feat: add DuckDB state layer with all repository classes - src/db.py: schema with 14 tables matching design spec - 9 repository classes in 7 modules: SyncState, Users, Knowledge, Audit, Telegram, PendingCode, Script, TableRegistry, Profiles - 37 tests covering all CRUD operations --- src/db.py | 315 +++++++++++++++-------------- src/repositories/__init__.py | 4 + src/repositories/audit.py | 54 +++++ src/repositories/knowledge.py | 103 ++++++++++ src/repositories/notifications.py | 105 ++++++++++ src/repositories/profiles.py | 44 ++++ src/repositories/sync_state.py | 82 ++++++++ src/repositories/table_registry.py | 47 +++++ src/repositories/users.py | 61 ++++++ tests/test_db.py | 57 ++---- tests/test_repositories.py | 315 +++++++++++++++++++++++++++++ 11 files changed, 992 insertions(+), 195 deletions(-) create mode 100644 src/repositories/__init__.py create mode 100644 src/repositories/audit.py create mode 100644 src/repositories/knowledge.py create mode 100644 src/repositories/notifications.py create mode 100644 src/repositories/profiles.py create mode 100644 src/repositories/sync_state.py create mode 100644 src/repositories/table_registry.py create mode 100644 src/repositories/users.py create mode 100644 tests/test_repositories.py diff --git a/src/db.py b/src/db.py index 66052a2..de71c6c 100644 --- a/src/db.py +++ b/src/db.py @@ -1,8 +1,9 @@ -"""DuckDB connection management and schema initialization. +"""DuckDB connection management and schema versioning. -Provides connections to the system state database and analytics database, -with automatic directory creation and schema bootstrapping. +Provides get_system_db() for the system state database +and get_analytics_db() for the analytics database with parquet views. 
""" + import os from pathlib import Path @@ -10,181 +11,181 @@ import duckdb SCHEMA_VERSION = 1 -_SCHEMA_SQL = """ +_SYSTEM_SCHEMA = """ CREATE TABLE IF NOT EXISTS schema_version ( - version INTEGER NOT NULL, - applied_at TIMESTAMP DEFAULT current_timestamp -); - -CREATE TABLE IF NOT EXISTS audit_log ( - id VARCHAR PRIMARY KEY, - timestamp TIMESTAMP DEFAULT current_timestamp, - actor VARCHAR, - action VARCHAR NOT NULL, - entity_type VARCHAR, - entity_id VARCHAR, - details JSON -); - -CREATE TABLE IF NOT EXISTS dataset_permissions ( - id VARCHAR PRIMARY KEY, - user_email VARCHAR NOT NULL, - dataset VARCHAR NOT NULL, - permission VARCHAR NOT NULL DEFAULT 'read', - granted_by VARCHAR, - granted_at TIMESTAMP DEFAULT current_timestamp -); - -CREATE TABLE IF NOT EXISTS knowledge_items ( - id VARCHAR PRIMARY KEY, - title VARCHAR NOT NULL, - content VARCHAR, - category VARCHAR, - author VARCHAR, - status VARCHAR DEFAULT 'active', - metadata JSON, - created_at TIMESTAMP DEFAULT current_timestamp, - updated_at TIMESTAMP DEFAULT current_timestamp -); - -CREATE TABLE IF NOT EXISTS knowledge_votes ( - id VARCHAR PRIMARY KEY, - item_id VARCHAR NOT NULL, - user_email VARCHAR NOT NULL, - vote INTEGER NOT NULL, - created_at TIMESTAMP DEFAULT current_timestamp -); - -CREATE TABLE IF NOT EXISTS pending_codes ( - code VARCHAR PRIMARY KEY, - user_email VARCHAR NOT NULL, - purpose VARCHAR, - created_at TIMESTAMP DEFAULT current_timestamp, - expires_at TIMESTAMP -); - -CREATE TABLE IF NOT EXISTS script_registry ( - id VARCHAR PRIMARY KEY, - name VARCHAR NOT NULL, - path VARCHAR NOT NULL, - description VARCHAR, - author VARCHAR, - metadata JSON, - created_at TIMESTAMP DEFAULT current_timestamp, - updated_at TIMESTAMP DEFAULT current_timestamp -); - -CREATE TABLE IF NOT EXISTS sync_history ( - id VARCHAR PRIMARY KEY, - table_name VARCHAR NOT NULL, - status VARCHAR NOT NULL, - rows_synced INTEGER, - started_at TIMESTAMP DEFAULT current_timestamp, - finished_at TIMESTAMP, - error VARCHAR, - 
metadata JSON -); - -CREATE TABLE IF NOT EXISTS sync_state ( - table_name VARCHAR PRIMARY KEY, - last_sync TIMESTAMP, - status VARCHAR DEFAULT 'pending', - row_count INTEGER, - file_hash VARCHAR, - metadata JSON -); - -CREATE TABLE IF NOT EXISTS table_profiles ( - table_name VARCHAR PRIMARY KEY, - profile JSON, - profiled_at TIMESTAMP DEFAULT current_timestamp -); - -CREATE TABLE IF NOT EXISTS table_registry ( - table_name VARCHAR PRIMARY KEY, - bucket VARCHAR, - source VARCHAR, - sync_strategy VARCHAR DEFAULT 'full', - primary_key VARCHAR, - description VARCHAR, - metadata JSON, - registered_at TIMESTAMP DEFAULT current_timestamp, - updated_at TIMESTAMP DEFAULT current_timestamp -); - -CREATE TABLE IF NOT EXISTS telegram_links ( - chat_id VARCHAR PRIMARY KEY, - user_email VARCHAR NOT NULL, - linked_at TIMESTAMP DEFAULT current_timestamp, - active BOOLEAN DEFAULT true -); - -CREATE TABLE IF NOT EXISTS user_sync_settings ( - user_email VARCHAR PRIMARY KEY, - settings JSON, - updated_at TIMESTAMP DEFAULT current_timestamp + version INTEGER NOT NULL, + applied_at TIMESTAMP DEFAULT current_timestamp ); CREATE TABLE IF NOT EXISTS users ( - email VARCHAR PRIMARY KEY, - name VARCHAR, - picture VARCHAR, - role VARCHAR DEFAULT 'analyst', - is_active BOOLEAN DEFAULT true, - metadata JSON, - created_at TIMESTAMP DEFAULT current_timestamp, - last_login TIMESTAMP + id VARCHAR PRIMARY KEY, + email VARCHAR UNIQUE NOT NULL, + name VARCHAR, + role VARCHAR DEFAULT 'analyst', + password_hash VARCHAR, + setup_token VARCHAR, + setup_token_created TIMESTAMP, + reset_token VARCHAR, + reset_token_created TIMESTAMP, + created_at TIMESTAMP DEFAULT current_timestamp, + updated_at TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS sync_state ( + table_id VARCHAR PRIMARY KEY, + last_sync TIMESTAMP, + rows BIGINT, + file_size_bytes BIGINT, + uncompressed_size_bytes BIGINT, + columns INTEGER, + hash VARCHAR, + status VARCHAR DEFAULT 'ok', + error TEXT +); + +CREATE TABLE IF NOT EXISTS sync_history ( + 
id VARCHAR PRIMARY KEY, + table_id VARCHAR NOT NULL, + synced_at TIMESTAMP NOT NULL, + rows BIGINT, + duration_ms INTEGER, + status VARCHAR, + error TEXT +); + +CREATE TABLE IF NOT EXISTS user_sync_settings ( + user_id VARCHAR NOT NULL, + dataset VARCHAR NOT NULL, + enabled BOOLEAN DEFAULT false, + table_mode VARCHAR DEFAULT 'all', + tables JSON, + updated_at TIMESTAMP, + PRIMARY KEY (user_id, dataset) +); + +CREATE TABLE IF NOT EXISTS knowledge_items ( + id VARCHAR PRIMARY KEY, + title VARCHAR NOT NULL, + content TEXT, + category VARCHAR, + tags JSON, + status VARCHAR DEFAULT 'pending', + contributors JSON, + source_user VARCHAR, + audience VARCHAR, + created_at TIMESTAMP DEFAULT current_timestamp, + updated_at TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS knowledge_votes ( + item_id VARCHAR NOT NULL, + user_id VARCHAR NOT NULL, + vote INTEGER, + voted_at TIMESTAMP DEFAULT current_timestamp, + PRIMARY KEY (item_id, user_id) +); + +CREATE TABLE IF NOT EXISTS audit_log ( + id VARCHAR PRIMARY KEY, + timestamp TIMESTAMP NOT NULL DEFAULT current_timestamp, + user_id VARCHAR, + action VARCHAR NOT NULL, + resource VARCHAR, + params JSON, + result VARCHAR, + duration_ms INTEGER +); + +CREATE TABLE IF NOT EXISTS telegram_links ( + user_id VARCHAR PRIMARY KEY, + chat_id BIGINT NOT NULL, + linked_at TIMESTAMP DEFAULT current_timestamp +); + +CREATE TABLE IF NOT EXISTS pending_codes ( + code VARCHAR PRIMARY KEY, + chat_id BIGINT NOT NULL, + created_at TIMESTAMP DEFAULT current_timestamp +); + +CREATE TABLE IF NOT EXISTS script_registry ( + id VARCHAR PRIMARY KEY, + name VARCHAR NOT NULL, + owner VARCHAR, + schedule VARCHAR, + source TEXT NOT NULL, + deployed_at TIMESTAMP DEFAULT current_timestamp, + last_run TIMESTAMP, + last_status VARCHAR +); + +CREATE TABLE IF NOT EXISTS table_registry ( + id VARCHAR PRIMARY KEY, + name VARCHAR NOT NULL, + folder VARCHAR, + sync_strategy VARCHAR, + primary_key VARCHAR, + description TEXT, + registered_by VARCHAR, + registered_at TIMESTAMP 
DEFAULT current_timestamp +); + +CREATE TABLE IF NOT EXISTS table_profiles ( + table_id VARCHAR PRIMARY KEY, + profile JSON NOT NULL, + profiled_at TIMESTAMP DEFAULT current_timestamp +); + +CREATE TABLE IF NOT EXISTS dataset_permissions ( + user_id VARCHAR NOT NULL, + dataset VARCHAR NOT NULL, + access VARCHAR DEFAULT 'read', + PRIMARY KEY (user_id, dataset) ); """ def _get_data_dir() -> Path: - """Return the DATA_DIR path, defaulting to ./data.""" - return Path(os.environ.get("DATA_DIR", "data")) + return Path(os.environ.get("DATA_DIR", "./data")) def get_system_db() -> duckdb.DuckDBPyConnection: - """Open (or create) the system state database and ensure schema exists. - - Returns a DuckDB connection to {DATA_DIR}/state/system.duckdb. - Creates directories and all schema tables on first call. - """ - db_dir = _get_data_dir() / "state" - db_dir.mkdir(parents=True, exist_ok=True) - db_path = db_dir / "system.duckdb" - + """Get a connection to the system state database. Creates schema if needed.""" + db_path = _get_data_dir() / "state" / "system.duckdb" + db_path.parent.mkdir(parents=True, exist_ok=True) conn = duckdb.connect(str(db_path)) - conn.execute(_SCHEMA_SQL) - - # Seed schema_version if empty - row = conn.execute("SELECT COUNT(*) FROM schema_version").fetchone() - if row[0] == 0: - conn.execute( - "INSERT INTO schema_version (version) VALUES (?)", [SCHEMA_VERSION] - ) - + _ensure_schema(conn) return conn def get_analytics_db() -> duckdb.DuckDBPyConnection: - """Open (or create) the analytics database. - - Returns a DuckDB connection to {DATA_DIR}/analytics/server.duckdb. - Creates directories if needed. 
- """ - db_dir = _get_data_dir() / "analytics" - db_dir.mkdir(parents=True, exist_ok=True) - db_path = db_dir / "server.duckdb" - + """Get a connection to the analytics database (parquet views).""" + db_path = _get_data_dir() / "analytics" / "server.duckdb" + db_path.parent.mkdir(parents=True, exist_ok=True) return duckdb.connect(str(db_path)) +def _ensure_schema(conn: duckdb.DuckDBPyConnection) -> None: + """Create tables if they don't exist. Apply migrations if schema version changed.""" + current = get_schema_version(conn) + if current < SCHEMA_VERSION: + conn.execute(_SYSTEM_SCHEMA) + if current == 0: + conn.execute( + "INSERT INTO schema_version (version) VALUES (?)", + [SCHEMA_VERSION], + ) + else: + conn.execute( + "UPDATE schema_version SET version = ?, applied_at = current_timestamp", + [SCHEMA_VERSION], + ) + + def get_schema_version(conn: duckdb.DuckDBPyConnection) -> int: - """Return the current schema version, or 0 if no schema_version table.""" + """Get current schema version. 
Returns 0 if no schema exists.""" try: - row = conn.execute( - "SELECT version FROM schema_version ORDER BY applied_at DESC LIMIT 1" - ).fetchone() - return row[0] if row else 0 + result = conn.execute("SELECT MAX(version) FROM schema_version").fetchone() + return result[0] if result and result[0] else 0 except duckdb.CatalogException: return 0 diff --git a/src/repositories/__init__.py b/src/repositories/__init__.py new file mode 100644 index 0000000..c5068e0 --- /dev/null +++ b/src/repositories/__init__.py @@ -0,0 +1,4 @@ +"""Repository layer for DuckDB state management.""" +from src.db import get_system_db, get_analytics_db + +__all__ = ["get_system_db", "get_analytics_db"] diff --git a/src/repositories/audit.py b/src/repositories/audit.py new file mode 100644 index 0000000..4473887 --- /dev/null +++ b/src/repositories/audit.py @@ -0,0 +1,54 @@ +"""Repository for audit logging.""" + +import json +import uuid +from datetime import datetime, timezone +from typing import Any, Optional, List, Dict + +import duckdb + + +class AuditRepository: + def __init__(self, conn: duckdb.DuckDBPyConnection): + self.conn = conn + + def log( + self, + user_id: Optional[str] = None, + action: str = "", + resource: Optional[str] = None, + params: Optional[dict] = None, + result: Optional[str] = None, + duration_ms: Optional[int] = None, + ) -> str: + entry_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + self.conn.execute( + """INSERT INTO audit_log (id, timestamp, user_id, action, resource, params, result, duration_ms) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + [entry_id, now, user_id, action, resource, + json.dumps(params) if params else None, result, duration_ms], + ) + return entry_id + + def query( + self, + user_id: Optional[str] = None, + action: Optional[str] = None, + limit: int = 50, + ) -> List[Dict[str, Any]]: + sql = "SELECT * FROM audit_log WHERE 1=1" + params: List[Any] = [] + if user_id: + sql += " AND user_id = ?" 
+ params.append(user_id) + if action: + sql += " AND action = ?" + params.append(action) + sql += " ORDER BY timestamp DESC LIMIT ?" + params.append(limit) + results = self.conn.execute(sql, params).fetchall() + if not results: + return [] + columns = [desc[0] for desc in self.conn.description] + return [dict(zip(columns, row)) for row in results] diff --git a/src/repositories/knowledge.py b/src/repositories/knowledge.py new file mode 100644 index 0000000..1553562 --- /dev/null +++ b/src/repositories/knowledge.py @@ -0,0 +1,103 @@ +"""Repository for corporate memory knowledge items and votes.""" + +import json +from datetime import datetime, timezone +from typing import Any, Optional, List, Dict + +import duckdb + + +class KnowledgeRepository: + def __init__(self, conn: duckdb.DuckDBPyConnection): + self.conn = conn + + def _row_to_dict(self, row) -> Optional[Dict[str, Any]]: + if not row: + return None + columns = [desc[0] for desc in self.conn.description] + return dict(zip(columns, row)) + + def _rows_to_dicts(self, rows) -> List[Dict[str, Any]]: + if not rows: + return [] + columns = [desc[0] for desc in self.conn.description] + return [dict(zip(columns, row)) for row in rows] + + def get_by_id(self, item_id: str) -> Optional[Dict[str, Any]]: + result = self.conn.execute("SELECT * FROM knowledge_items WHERE id = ?", [item_id]).fetchone() + return self._row_to_dict(result) + + def create( + self, + id: str, + title: str, + content: str, + category: str, + source_user: Optional[str] = None, + tags: Optional[List[str]] = None, + status: str = "pending", + ) -> None: + now = datetime.now(timezone.utc) + self.conn.execute( + """INSERT INTO knowledge_items (id, title, content, category, source_user, + tags, status, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + [id, title, content, category, source_user, + json.dumps(tags) if tags else None, status, now, now], + ) + + def update_status(self, item_id: str, status: str) -> None: + now = 
datetime.now(timezone.utc) + self.conn.execute( + "UPDATE knowledge_items SET status = ?, updated_at = ? WHERE id = ?", + [status, now, item_id], + ) + + def list_items( + self, + statuses: Optional[List[str]] = None, + category: Optional[str] = None, + limit: int = 100, + offset: int = 0, + ) -> List[Dict[str, Any]]: + query = "SELECT * FROM knowledge_items WHERE 1=1" + params: List[Any] = [] + if statuses: + placeholders = ", ".join("?" for _ in statuses) + query += f" AND status IN ({placeholders})" + params.extend(statuses) + if category: + query += " AND category = ?" + params.append(category) + query += " ORDER BY updated_at DESC LIMIT ? OFFSET ?" + params.extend([limit, offset]) + return self._rows_to_dicts(self.conn.execute(query, params).fetchall()) + + def search(self, query: str) -> List[Dict[str, Any]]: + pattern = f"%{query}%" + results = self.conn.execute( + """SELECT * FROM knowledge_items + WHERE title ILIKE ? OR content ILIKE ? + ORDER BY updated_at DESC""", + [pattern, pattern], + ).fetchall() + return self._rows_to_dicts(results) + + def vote(self, item_id: str, user_id: str, vote: int) -> None: + now = datetime.now(timezone.utc) + self.conn.execute( + """INSERT INTO knowledge_votes (item_id, user_id, vote, voted_at) + VALUES (?, ?, ?, ?) 
+ ON CONFLICT (item_id, user_id) DO UPDATE SET vote = excluded.vote, voted_at = excluded.voted_at""", + [item_id, user_id, vote, now], + ) + + def get_votes(self, item_id: str) -> Dict[str, int]: + result = self.conn.execute( + """SELECT + COALESCE(SUM(CASE WHEN vote > 0 THEN 1 ELSE 0 END), 0) as upvotes, + COALESCE(SUM(CASE WHEN vote < 0 THEN 1 ELSE 0 END), 0) as downvotes + FROM knowledge_votes WHERE item_id = ?""", + [item_id], + ).fetchone() + return {"upvotes": result[0], "downvotes": result[1]} diff --git a/src/repositories/notifications.py b/src/repositories/notifications.py new file mode 100644 index 0000000..bf9a6f6 --- /dev/null +++ b/src/repositories/notifications.py @@ -0,0 +1,105 @@ +"""Repositories for Telegram links, pending codes, and script registry.""" + +from datetime import datetime, timezone +from typing import Any, Optional, List, Dict + +import duckdb + + +class TelegramRepository: + def __init__(self, conn: duckdb.DuckDBPyConnection): + self.conn = conn + + def link_user(self, user_id: str, chat_id: int) -> None: + now = datetime.now(timezone.utc) + self.conn.execute( + """INSERT INTO telegram_links (user_id, chat_id, linked_at) + VALUES (?, ?, ?) 
+ ON CONFLICT (user_id) DO UPDATE SET chat_id = excluded.chat_id, linked_at = excluded.linked_at""", + [user_id, chat_id, now], + ) + + def unlink_user(self, user_id: str) -> None: + self.conn.execute("DELETE FROM telegram_links WHERE user_id = ?", [user_id]) + + def get_link(self, user_id: str) -> Optional[Dict[str, Any]]: + result = self.conn.execute( + "SELECT * FROM telegram_links WHERE user_id = ?", [user_id] + ).fetchone() + if not result: + return None + columns = [desc[0] for desc in self.conn.description] + return dict(zip(columns, result)) + + def get_all_links(self) -> List[Dict[str, Any]]: + results = self.conn.execute("SELECT * FROM telegram_links").fetchall() + if not results: + return [] + columns = [desc[0] for desc in self.conn.description] + return [dict(zip(columns, row)) for row in results] + + +class PendingCodeRepository: + def __init__(self, conn: duckdb.DuckDBPyConnection): + self.conn = conn + + def create_code(self, code: str, chat_id: int) -> None: + now = datetime.now(timezone.utc) + self.conn.execute( + "INSERT INTO pending_codes (code, chat_id, created_at) VALUES (?, ?, ?)", + [code, chat_id, now], + ) + + def verify_code(self, code: str) -> Optional[Dict[str, Any]]: + result = self.conn.execute( + "SELECT * FROM pending_codes WHERE code = ?", [code] + ).fetchone() + if not result: + return None + columns = [desc[0] for desc in self.conn.description] + row = dict(zip(columns, result)) + self.conn.execute("DELETE FROM pending_codes WHERE code = ?", [code]) + return row + + +class ScriptRepository: + def __init__(self, conn: duckdb.DuckDBPyConnection): + self.conn = conn + + def deploy( + self, id: str, name: str, owner: Optional[str] = None, + schedule: Optional[str] = None, source: str = "", + ) -> None: + now = datetime.now(timezone.utc) + self.conn.execute( + """INSERT INTO script_registry (id, name, owner, schedule, source, deployed_at) + VALUES (?, ?, ?, ?, ?, ?) 
+ ON CONFLICT (id) DO UPDATE SET + name = excluded.name, schedule = excluded.schedule, + source = excluded.source, deployed_at = excluded.deployed_at""", + [id, name, owner, schedule, source, now], + ) + + def undeploy(self, script_id: str) -> None: + self.conn.execute("DELETE FROM script_registry WHERE id = ?", [script_id]) + + def get(self, script_id: str) -> Optional[Dict[str, Any]]: + result = self.conn.execute( + "SELECT * FROM script_registry WHERE id = ?", [script_id] + ).fetchone() + if not result: + return None + columns = [desc[0] for desc in self.conn.description] + return dict(zip(columns, result)) + + def list_all(self, owner: Optional[str] = None) -> List[Dict[str, Any]]: + if owner: + results = self.conn.execute( + "SELECT * FROM script_registry WHERE owner = ? ORDER BY name", [owner] + ).fetchall() + else: + results = self.conn.execute("SELECT * FROM script_registry ORDER BY name").fetchall() + if not results: + return [] + columns = [desc[0] for desc in self.conn.description] + return [dict(zip(columns, row)) for row in results] diff --git a/src/repositories/profiles.py b/src/repositories/profiles.py new file mode 100644 index 0000000..08970dc --- /dev/null +++ b/src/repositories/profiles.py @@ -0,0 +1,44 @@ +"""Repository for table profiles.""" + +import json +from datetime import datetime, timezone +from typing import Any, Optional, Dict + +import duckdb + + +class ProfileRepository: + def __init__(self, conn: duckdb.DuckDBPyConnection): + self.conn = conn + + def save(self, table_id: str, profile: dict) -> None: + now = datetime.now(timezone.utc) + self.conn.execute( + """INSERT INTO table_profiles (table_id, profile, profiled_at) + VALUES (?, ?, ?) 
+ ON CONFLICT (table_id) DO UPDATE SET + profile = excluded.profile, profiled_at = excluded.profiled_at""", + [table_id, json.dumps(profile), now], + ) + + def get(self, table_id: str) -> Optional[Dict[str, Any]]: + result = self.conn.execute( + "SELECT profile, profiled_at FROM table_profiles WHERE table_id = ?", + [table_id], + ).fetchone() + if not result: + return None + profile = json.loads(result[0]) if isinstance(result[0], str) else result[0] + profile["profiled_at"] = result[1] + return profile + + def get_all(self) -> Dict[str, dict]: + results = self.conn.execute( + "SELECT table_id, profile, profiled_at FROM table_profiles ORDER BY table_id" + ).fetchall() + out = {} + for row in results: + profile = json.loads(row[1]) if isinstance(row[1], str) else row[1] + profile["profiled_at"] = row[2] + out[row[0]] = profile + return out diff --git a/src/repositories/sync_state.py b/src/repositories/sync_state.py new file mode 100644 index 0000000..5744dd0 --- /dev/null +++ b/src/repositories/sync_state.py @@ -0,0 +1,82 @@ +"""Repository for sync state and history.""" + +import uuid +from datetime import datetime, timezone +from typing import Any, Optional, List, Dict + +import duckdb + + +class SyncStateRepository: + def __init__(self, conn: duckdb.DuckDBPyConnection): + self.conn = conn + + def _row_to_dict(self, row) -> Optional[Dict[str, Any]]: + if not row: + return None + columns = [desc[0] for desc in self.conn.description] + return dict(zip(columns, row)) + + def _rows_to_dicts(self, rows) -> List[Dict[str, Any]]: + if not rows: + return [] + columns = [desc[0] for desc in self.conn.description] + return [dict(zip(columns, row)) for row in rows] + + def get_table_state(self, table_id: str) -> Optional[Dict[str, Any]]: + result = self.conn.execute( + "SELECT * FROM sync_state WHERE table_id = ?", [table_id] + ).fetchone() + return self._row_to_dict(result) + + def get_last_sync(self, table_id: str) -> Optional[datetime]: + result = self.conn.execute( + 
"SELECT last_sync FROM sync_state WHERE table_id = ?", [table_id] + ).fetchone() + return result[0] if result else None + + def get_all_states(self) -> List[Dict[str, Any]]: + results = self.conn.execute("SELECT * FROM sync_state ORDER BY table_id").fetchall() + return self._rows_to_dicts(results) + + def update_sync( + self, + table_id: str, + rows: int, + file_size_bytes: int, + hash: str, + uncompressed_size_bytes: int = 0, + columns: int = 0, + status: str = "ok", + error: Optional[str] = None, + duration_ms: Optional[int] = None, + ) -> None: + now = datetime.now(timezone.utc) + self.conn.execute( + """INSERT INTO sync_state (table_id, last_sync, rows, file_size_bytes, + uncompressed_size_bytes, columns, hash, status, error) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (table_id) DO UPDATE SET + last_sync = excluded.last_sync, + rows = excluded.rows, + file_size_bytes = excluded.file_size_bytes, + uncompressed_size_bytes = excluded.uncompressed_size_bytes, + columns = excluded.columns, + hash = excluded.hash, + status = excluded.status, + error = excluded.error""", + [table_id, now, rows, file_size_bytes, uncompressed_size_bytes, + columns, hash, status, error], + ) + self.conn.execute( + """INSERT INTO sync_history (id, table_id, synced_at, rows, duration_ms, status, error) + VALUES (?, ?, ?, ?, ?, ?, ?)""", + [str(uuid.uuid4()), table_id, now, rows, duration_ms, status, error], + ) + + def get_sync_history(self, table_id: str, limit: int = 10) -> List[Dict[str, Any]]: + results = self.conn.execute( + "SELECT * FROM sync_history WHERE table_id = ? 
ORDER BY synced_at DESC LIMIT ?", + [table_id, limit], + ).fetchall() + return self._rows_to_dicts(results) diff --git a/src/repositories/table_registry.py b/src/repositories/table_registry.py new file mode 100644 index 0000000..35dda1c --- /dev/null +++ b/src/repositories/table_registry.py @@ -0,0 +1,47 @@ +"""Repository for table registry.""" + +from datetime import datetime, timezone +from typing import Any, Optional, List, Dict + +import duckdb + + +class TableRegistryRepository: + def __init__(self, conn: duckdb.DuckDBPyConnection): + self.conn = conn + + def register( + self, id: str, name: str, folder: Optional[str] = None, + sync_strategy: Optional[str] = None, primary_key: Optional[str] = None, + description: Optional[str] = None, registered_by: Optional[str] = None, + ) -> None: + now = datetime.now(timezone.utc) + self.conn.execute( + """INSERT INTO table_registry (id, name, folder, sync_strategy, + primary_key, description, registered_by, registered_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT (id) DO UPDATE SET + name = excluded.name, folder = excluded.folder, + sync_strategy = excluded.sync_strategy, primary_key = excluded.primary_key, + description = excluded.description, registered_at = excluded.registered_at""", + [id, name, folder, sync_strategy, primary_key, description, registered_by, now], + ) + + def unregister(self, table_id: str) -> None: + self.conn.execute("DELETE FROM table_registry WHERE id = ?", [table_id]) + + def get(self, table_id: str) -> Optional[Dict[str, Any]]: + result = self.conn.execute( + "SELECT * FROM table_registry WHERE id = ?", [table_id] + ).fetchone() + if not result: + return None + columns = [desc[0] for desc in self.conn.description] + return dict(zip(columns, result)) + + def list_all(self) -> List[Dict[str, Any]]: + results = self.conn.execute("SELECT * FROM table_registry ORDER BY name").fetchall() + if not results: + return [] + columns = [desc[0] for desc in self.conn.description] + return [dict(zip(columns, row)) for row in results] diff --git a/src/repositories/users.py b/src/repositories/users.py new file mode 100644 index 0000000..1990f06 --- /dev/null +++ b/src/repositories/users.py @@ -0,0 +1,61 @@ +"""Repository for user management.""" + +from datetime import datetime, timezone +from typing import Any, Optional, List, Dict + +import duckdb + + +class UserRepository: + def __init__(self, conn: duckdb.DuckDBPyConnection): + self.conn = conn + + def _row_to_dict(self, row) -> Optional[Dict[str, Any]]: + if not row: + return None + columns = [desc[0] for desc in self.conn.description] + return dict(zip(columns, row)) + + def get_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + result = self.conn.execute("SELECT * FROM users WHERE id = ?", [user_id]).fetchone() + return self._row_to_dict(result) + + def get_by_email(self, email: str) -> Optional[Dict[str, Any]]: + result = self.conn.execute("SELECT * FROM users WHERE email = ?", [email]).fetchone() + return self._row_to_dict(result) + + 
class UserRepository:
    """CRUD helper for the ``users`` table.

    All statements run on the connection passed to the constructor. Rows are
    returned as plain dictionaries keyed by column name.
    """

    # Columns that ``update`` is allowed to write. Only these names are ever
    # interpolated into the UPDATE statement; values are always bound.
    _UPDATABLE_COLUMNS = frozenset({
        "email", "name", "role", "password_hash", "setup_token",
        "setup_token_created", "reset_token", "reset_token_created",
    })

    def __init__(self, conn: "duckdb.DuckDBPyConnection"):
        self.conn = conn

    def _row_to_dict(self, row, description=None) -> Optional[Dict[str, Any]]:
        """Map one fetched row to a dict, or return ``None`` for a missing row.

        Args:
            row: Tuple returned by ``fetchone``, or ``None``.
            description: Column metadata of the statement that produced
                ``row``. When omitted, the connection's current
                ``description`` is used.
        """
        if not row:
            return None
        if description is None:
            description = self.conn.description
        return dict(zip((column[0] for column in description), row))

    def _fetch_one(self, sql: str, params: List[Any]) -> Optional[Dict[str, Any]]:
        """Run a single-row lookup and return the row as a dict or ``None``."""
        result = self.conn.execute(sql, params)
        row = result.fetchone()
        return self._row_to_dict(row, result.description)

    def get_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the user with the given id, or ``None`` if there is none."""
        return self._fetch_one("SELECT * FROM users WHERE id = ?", [user_id])

    def get_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        """Return the user with the given email, or ``None`` if there is none."""
        return self._fetch_one("SELECT * FROM users WHERE email = ?", [email])

    def list_all(self) -> List[Dict[str, Any]]:
        """Return every user as a dict, ordered by email."""
        result = self.conn.execute("SELECT * FROM users ORDER BY email")
        rows = result.fetchall()
        columns = [column[0] for column in result.description]
        return [dict(zip(columns, row)) for row in rows]

    def create(
        self,
        id: str,
        email: str,
        name: str,
        role: str = "analyst",
        password_hash: Optional[str] = None,
    ) -> None:
        """Insert a new user; ``created_at`` and ``updated_at`` are set to now (UTC)."""
        now = datetime.now(timezone.utc)
        self.conn.execute(
            """INSERT INTO users (id, email, name, role, password_hash, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)""",
            [id, email, name, role, password_hash, now, now],
        )

    def update(self, id: str, **kwargs) -> None:
        """Update whitelisted columns of one user and bump ``updated_at``.

        Keyword arguments whose name is not an updatable column are ignored.
        If nothing updatable remains, no statement is executed and
        ``updated_at`` is left untouched.
        """
        updates = {
            column: value
            for column, value in kwargs.items()
            if column in self._UPDATABLE_COLUMNS
        }
        if not updates:
            return
        updates["updated_at"] = datetime.now(timezone.utc)
        # Column names originate from the whitelist above, never from the
        # caller, so formatting them into the statement is safe.
        set_clause = ", ".join(f"{k} = ?" for k in updates)
        values = list(updates.values()) + [id]
        self.conn.execute(f"UPDATE users SET {set_clause} WHERE id = ?", values)

    def delete(self, user_id: str) -> None:
        """Delete the user with the given id; deleting a missing id is a no-op."""
        self.conn.execute("DELETE FROM users WHERE id = ?", [user_id])
class TestGetSystemDb:
    """Checks on the connection returned by ``get_system_db``."""

    def test_creates_all_tables(self, tmp_path):
        """All expected tables are present in the ``main`` schema."""
        _setup_data_dir(tmp_path)
        from src.db import get_system_db

        expected = {
            "schema_version", "users", "sync_state", "sync_history",
            "user_sync_settings", "knowledge_items", "knowledge_votes",
            "audit_log", "telegram_links", "pending_codes",
            "script_registry", "table_registry", "table_profiles",
            "dataset_permissions",
        }
        conn = get_system_db()
        try:
            rows = conn.execute(
                "SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'"
            ).fetchall()
            tables = {row[0] for row in rows}
            # Subset check: additional tables do not fail the test.
            assert expected.issubset(tables), f"Missing: {expected - tables}"
        finally:
            conn.close()

    def test_idempotent(self, tmp_path):
        """A row written through one connection is readable after reopening."""
        _setup_data_dir(tmp_path)
        from src.db import get_system_db

        first = get_system_db()
        first.execute(
            "INSERT INTO users (id, email, name, role) VALUES ('u1', 'test@test.com', 'Test', 'analyst')"
        )
        first.close()

        second = get_system_db()
        try:
            row = second.execute("SELECT email FROM users WHERE id='u1'").fetchone()
            assert row[0] == "test@test.com"
        finally:
            second.close()
@pytest.fixture
def db_conn(tmp_path, monkeypatch):
    """Yield a system-database connection rooted in a per-test temp directory.

    ``DATA_DIR`` is set through ``monkeypatch`` so the previous value is
    restored when the test finishes; assigning to ``os.environ`` directly
    would leave the variable pointing at a stale temp directory for every
    test that runs afterwards. The connection is closed on teardown.
    """
    monkeypatch.setenv("DATA_DIR", str(tmp_path))
    # Imported after DATA_DIR is set, matching the original ordering.
    from src.db import get_system_db
    conn = get_system_db()
    yield conn
    conn.close()
# ---- SyncState ----

class TestSyncStateRepository:
    """Exercises ``SyncStateRepository`` against a fresh system database."""

    @staticmethod
    def _make_repo(conn):
        """Build the repository under test on the given connection."""
        from src.repositories.sync_state import SyncStateRepository
        return SyncStateRepository(conn)

    def test_update_and_get(self, db_conn):
        """A recorded sync is returned by ``get_table_state`` with status ``ok``."""
        repo = self._make_repo(db_conn)
        repo.update_sync(table_id="orders", rows=1000, file_size_bytes=5000, hash="abc123")
        state = repo.get_table_state("orders")
        assert state is not None
        assert state["rows"] == 1000
        assert state["hash"] == "abc123"
        assert state["status"] == "ok"

    def test_get_nonexistent(self, db_conn):
        """An unknown table id yields ``None``."""
        repo = self._make_repo(db_conn)
        assert repo.get_table_state("nonexistent") is None

    def test_get_last_sync(self, db_conn):
        """``get_last_sync`` returns a value once a sync was recorded."""
        repo = self._make_repo(db_conn)
        repo.update_sync(table_id="orders", rows=100, file_size_bytes=500, hash="h1")
        assert repo.get_last_sync("orders") is not None

    def test_get_all_states(self, db_conn):
        """One state row is reported per synced table."""
        repo = self._make_repo(db_conn)
        seeds = (("orders", 100, 500, "h1"), ("customers", 50, 200, "h2"))
        for table_id, row_count, size, digest in seeds:
            repo.update_sync(table_id=table_id, rows=row_count, file_size_bytes=size, hash=digest)
        assert len(repo.get_all_states()) == 2

    def test_history_recorded(self, db_conn):
        """Each sync adds a history entry; the latest one is listed first."""
        repo = self._make_repo(db_conn)
        repo.update_sync(table_id="orders", rows=100, file_size_bytes=500, hash="h1")
        repo.update_sync(table_id="orders", rows=200, file_size_bytes=800, hash="h2")
        entries = repo.get_sync_history("orders", limit=10)
        assert len(entries) == 2
        assert entries[0]["rows"] == 200  # newest first

    def test_update_with_error(self, db_conn):
        """Status and error text passed to ``update_sync`` are stored."""
        repo = self._make_repo(db_conn)
        repo.update_sync(
            table_id="orders", rows=0, file_size_bytes=0, hash="",
            status="error", error="Connection timeout",
        )
        state = repo.get_table_state("orders")
        assert state["status"] == "error"
        assert state["error"] == "Connection timeout"
# ---- Users ----

class TestUserRepository:
    """Exercises ``UserRepository`` against a fresh system database."""

    @staticmethod
    def _make_repo(conn):
        """Build the repository under test on the given connection."""
        from src.repositories.users import UserRepository
        return UserRepository(conn)

    def test_create_and_get(self, db_conn):
        """A created user can be fetched by id."""
        repo = self._make_repo(db_conn)
        repo.create(id="u1", email="test@acme.com", name="Test User", role="analyst")
        fetched = repo.get_by_id("u1")
        assert fetched is not None
        assert fetched["email"] == "test@acme.com"
        assert fetched["role"] == "analyst"

    def test_get_by_email(self, db_conn):
        """A created user can be fetched by email."""
        repo = self._make_repo(db_conn)
        repo.create(id="u1", email="test@acme.com", name="Test User")
        fetched = repo.get_by_email("test@acme.com")
        assert fetched is not None
        assert fetched["id"] == "u1"

    def test_get_nonexistent(self, db_conn):
        """Lookups for unknown ids and emails return ``None``."""
        repo = self._make_repo(db_conn)
        assert repo.get_by_id("nope") is None
        assert repo.get_by_email("nope@nope.com") is None

    def test_list_all(self, db_conn):
        """``list_all`` returns one entry per created user."""
        repo = self._make_repo(db_conn)
        for user_id, email, name in (("u1", "a@acme.com", "A"), ("u2", "b@acme.com", "B")):
            repo.create(id=user_id, email=email, name=name)
        assert len(repo.list_all()) == 2

    def test_update_role(self, db_conn):
        """``update`` changes the stored role."""
        repo = self._make_repo(db_conn)
        repo.create(id="u1", email="test@acme.com", name="Test")
        repo.update(id="u1", role="admin")
        assert repo.get_by_id("u1")["role"] == "admin"

    def test_delete(self, db_conn):
        """A deleted user can no longer be fetched."""
        repo = self._make_repo(db_conn)
        repo.create(id="u1", email="test@acme.com", name="Test")
        repo.delete("u1")
        assert repo.get_by_id("u1") is None

    def test_set_password_hash(self, db_conn):
        """``update`` stores a password hash."""
        repo = self._make_repo(db_conn)
        repo.create(id="u1", email="test@acme.com", name="Test")
        repo.update(id="u1", password_hash="$argon2id$hashed")
        assert repo.get_by_id("u1")["password_hash"] == "$argon2id$hashed"
# ---- Knowledge ----

class TestKnowledgeRepository:
    """Exercises ``KnowledgeRepository`` against a fresh system database."""

    @staticmethod
    def _make_repo(conn):
        """Build the repository under test on the given connection."""
        from src.repositories.knowledge import KnowledgeRepository
        return KnowledgeRepository(conn)

    def test_create_and_get(self, db_conn):
        """A created item is retrievable and starts with status ``pending``."""
        repo = self._make_repo(db_conn)
        repo.create(id="k1", title="MRR Definition", content="Monthly recurring...",
                    category="metrics", source_user="petr@acme.com")
        fetched = repo.get_by_id("k1")
        assert fetched is not None
        assert fetched["title"] == "MRR Definition"
        assert fetched["status"] == "pending"

    def test_list_by_status(self, db_conn):
        """``list_items`` filtered by status returns only matching items."""
        repo = self._make_repo(db_conn)
        repo.create(id="k1", title="A", content="a", category="c")
        repo.create(id="k2", title="B", content="b", category="c")
        repo.update_status("k1", "approved")
        matching = repo.list_items(statuses=["approved"])
        assert len(matching) == 1
        assert matching[0]["id"] == "k1"

    def test_vote(self, db_conn):
        """Votes from different users are tallied separately."""
        repo = self._make_repo(db_conn)
        repo.create(id="k1", title="A", content="a", category="c")
        repo.vote("k1", "user1", 1)
        repo.vote("k1", "user2", -1)
        tally = repo.get_votes("k1")
        assert tally["upvotes"] == 1
        assert tally["downvotes"] == 1

    def test_vote_replace(self, db_conn):
        """A second vote by the same user replaces the first."""
        repo = self._make_repo(db_conn)
        repo.create(id="k1", title="A", content="a", category="c")
        repo.vote("k1", "user1", 1)
        repo.vote("k1", "user1", -1)  # change vote
        tally = repo.get_votes("k1")
        assert tally["upvotes"] == 0
        assert tally["downvotes"] == 1

    def test_search(self, db_conn):
        """A lowercase query matches an item titled with a capitalised word."""
        repo = self._make_repo(db_conn)
        repo.create(id="k1", title="Revenue metrics", content="MRR definition", category="metrics")
        repo.create(id="k2", title="Support SLA", content="Response times", category="support")
        hits = repo.search("revenue")
        assert len(hits) == 1
        assert hits[0]["id"] == "k1"
# ---- Audit ----

class TestAuditRepository:
    """Exercises ``AuditRepository`` against a fresh system database."""

    @staticmethod
    def _make_repo(conn):
        """Build the repository under test on the given connection."""
        from src.repositories.audit import AuditRepository
        return AuditRepository(conn)

    def test_log_and_query(self, db_conn):
        """A logged entry is returned by ``query`` with its fields intact."""
        repo = self._make_repo(db_conn)
        repo.log(user_id="u1", action="sync_trigger", resource="orders",
                 params={"force": True}, result="ok", duration_ms=1200)
        found = repo.query(limit=10)
        assert len(found) == 1
        assert found[0]["action"] == "sync_trigger"
        assert found[0]["duration_ms"] == 1200

    def test_query_by_action(self, db_conn):
        """Filtering by action excludes entries with another action."""
        repo = self._make_repo(db_conn)
        repo.log(user_id="u1", action="sync_trigger", resource="orders")
        repo.log(user_id="u1", action="login", resource=None)
        assert len(repo.query(action="sync_trigger")) == 1

    def test_query_by_user(self, db_conn):
        """Filtering by user excludes entries from another user."""
        repo = self._make_repo(db_conn)
        repo.log(user_id="u1", action="sync_trigger", resource="orders")
        repo.log(user_id="u2", action="sync_trigger", resource="customers")
        assert len(repo.query(user_id="u1")) == 1


# ---- Telegram ----

class TestTelegramRepository:
    """Exercises ``TelegramRepository`` against a fresh system database."""

    @staticmethod
    def _make_repo(conn):
        """Build the repository under test on the given connection."""
        from src.repositories.notifications import TelegramRepository
        return TelegramRepository(conn)

    def test_link_and_get(self, db_conn):
        """A linked user has a retrievable chat id."""
        repo = self._make_repo(db_conn)
        repo.link_user("u1", chat_id=12345)
        stored = repo.get_link("u1")
        assert stored is not None
        assert stored["chat_id"] == 12345

    def test_unlink(self, db_conn):
        """After unlinking, no link is returned for the user."""
        repo = self._make_repo(db_conn)
        repo.link_user("u1", chat_id=12345)
        repo.unlink_user("u1")
        assert repo.get_link("u1") is None


# ---- PendingCode ----

class TestPendingCodeRepository:
    """Exercises ``PendingCodeRepository`` against a fresh system database."""

    def test_create_and_verify(self, db_conn):
        """A code verifies once; a second verification returns ``None``."""
        from src.repositories.notifications import PendingCodeRepository
        repo = PendingCodeRepository(db_conn)
        repo.create_code("ABC123", chat_id=12345)
        first_attempt = repo.verify_code("ABC123")
        assert first_attempt is not None
        assert first_attempt["chat_id"] == 12345
        # Code consumed
        assert repo.verify_code("ABC123") is None


# ---- Script ----

class TestScriptRepository:
    """Exercises ``ScriptRepository`` against a fresh system database."""

    @staticmethod
    def _make_repo(conn):
        """Build the repository under test on the given connection."""
        from src.repositories.notifications import ScriptRepository
        return ScriptRepository(conn)

    def test_deploy_and_get(self, db_conn):
        """A deployed script is retrievable with its schedule."""
        repo = self._make_repo(db_conn)
        repo.deploy("s1", name="sales_alert", owner="u1",
                    schedule="0 8 * * MON", source="print('hello')")
        stored = repo.get("s1")
        assert stored is not None
        assert stored["schedule"] == "0 8 * * MON"

    def test_list_all(self, db_conn):
        """``list_all`` returns one entry per deployed script."""
        repo = self._make_repo(db_conn)
        for script_id, script_name in (("s1", "alert1"), ("s2", "alert2")):
            repo.deploy(script_id, name=script_name, owner="u1", source="pass")
        assert len(repo.list_all()) == 2

    def test_undeploy(self, db_conn):
        """An undeployed script can no longer be fetched."""
        repo = self._make_repo(db_conn)
        repo.deploy("s1", name="test", owner="u1", source="pass")
        repo.undeploy("s1")
        assert repo.get("s1") is None
# ---- TableRegistry ----

class TestTableRegistryRepository:
    """Exercises ``TableRegistryRepository`` against a fresh system database."""

    @staticmethod
    def _make_repo(conn):
        """Build the repository under test on the given connection."""
        from src.repositories.table_registry import TableRegistryRepository
        return TableRegistryRepository(conn)

    def test_register_and_get(self, db_conn):
        """A registered table is retrievable with its folder."""
        repo = self._make_repo(db_conn)
        repo.register(id="orders", name="Orders", folder="sales",
                      sync_strategy="incremental", registered_by="admin")
        entry = repo.get("orders")
        assert entry is not None
        assert entry["folder"] == "sales"

    def test_list_all(self, db_conn):
        """``list_all`` returns one entry per registered table."""
        repo = self._make_repo(db_conn)
        for table_id, table_name, folder in (("t1", "A", "f1"), ("t2", "B", "f2")):
            repo.register(id=table_id, name=table_name, folder=folder)
        assert len(repo.list_all()) == 2

    def test_unregister(self, db_conn):
        """An unregistered table can no longer be fetched."""
        repo = self._make_repo(db_conn)
        repo.register(id="t1", name="A", folder="f1")
        repo.unregister("t1")
        assert repo.get("t1") is None


# ---- Profiles ----

class TestProfileRepository:
    """Exercises ``ProfileRepository`` against a fresh system database."""

    @staticmethod
    def _make_repo(conn):
        """Build the repository under test on the given connection."""
        from src.repositories.profiles import ProfileRepository
        return ProfileRepository(conn)

    def test_save_and_get(self, db_conn):
        """A saved profile is returned with its ``row_count``."""
        repo = self._make_repo(db_conn)
        payload = {"columns": [{"name": "id", "type": "int"}], "row_count": 1000}
        repo.save("orders", payload)
        loaded = repo.get("orders")
        assert loaded is not None
        assert loaded["row_count"] == 1000

    def test_get_all(self, db_conn):
        """``get_all`` returns one profile per saved table."""
        repo = self._make_repo(db_conn)
        repo.save("t1", {"row_count": 100})
        repo.save("t2", {"row_count": 200})
        assert len(repo.get_all()) == 2