diff --git a/tests/test_corporate_memory_collector.py b/tests/test_corporate_memory_collector.py new file mode 100644 index 0000000..827abf3 --- /dev/null +++ b/tests/test_corporate_memory_collector.py @@ -0,0 +1,245 @@ +"""Tests for Corporate Memory knowledge collector.""" + +import hashlib +import json +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Minimal mock LLM extractor +# --------------------------------------------------------------------------- + +class MockLLMProvider: + """A minimal mock for connectors.llm.StructuredExtractor.""" + + def __init__(self, response: dict): + self._response = response + + def extract_json(self, prompt: str, max_tokens: int, json_schema: dict, schema_name: str) -> dict: + return self._response + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _write_json(path: Path, data: dict): + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + json.dump(data, f) + + +# --------------------------------------------------------------------------- +# Tests for _generate_id +# --------------------------------------------------------------------------- + +class TestGenerateId: + def test_returns_km_prefix(self): + from services.corporate_memory.collector import _generate_id + item_id = _generate_id("hello world") + assert item_id.startswith("km_") + + def test_deterministic(self): + from services.corporate_memory.collector import _generate_id + assert _generate_id("same") == _generate_id("same") + + def test_different_content_different_id(self): + from services.corporate_memory.collector import _generate_id + assert _generate_id("aaa") != _generate_id("bbb") + + +# --------------------------------------------------------------------------- +# Tests for _process_catalog_response (hash change / governance preservation) +# --------------------------------------------------------------------------- + +class TestProcessCatalogResponse: + def test_new_item_gets_generated_id(self): + from services.corporate_memory.collector import _process_catalog_response + response_items = [ + { + "existing_id": None, + "title": "Tip One", + "content": "Always check the logs first.", + "category": "debugging", + "tags": ["logs"], + "source_users": ["alice"], + } + ] + result = _process_catalog_response(response_items, existing={"items": {}}) + assert len(result) == 1 + item_id, item = next(iter(result.items())) + assert item_id.startswith("km_") + assert item["title"] == "Tip One" + assert item["status"] == "approved" # default initial_status + + def test_existing_id_preserved(self): + from services.corporate_memory.collector import _process_catalog_response + existing = { + "items": { + "km_abc123": { + "id": "km_abc123", + "title": "Old Title", + "content": "Old content", + "category": "debugging", + "tags": [], + "source_users": ["bob"], + "extracted_at": "2026-01-01T00:00:00+00:00", + "status": "approved", + "approved_by": "admin", + "approved_at": "2026-01-02T00:00:00+00:00", + "mandatory_reason": None, + "audience": "all", + "review_by": None, + "edited_by": None, + "edited_at": None, + } + } + } + response_items = [ + { + "existing_id": "km_abc123", + "title": "Updated Title", + "content": "New content", + "category": "debugging", + "tags": ["updated"], + "source_users": ["bob"], + } + ] + result = _process_catalog_response(response_items, existing=existing) + assert "km_abc123" in result + item = result["km_abc123"] + assert item["title"] == "Updated Title" + + def test_governance_fields_preserved(self): + from services.corporate_memory.collector import GOVERNANCE_FIELDS, _process_catalog_response + existing = { + "items": { + "km_abc123": { + "id": "km_abc123", + "title": "T", + "content": "C", + "category": "workflow", + "tags": [], + "source_users": ["carol"], + "extracted_at": "2026-01-01T00:00:00+00:00", + "status": "approved", + "approved_by": "manager", + "approved_at": "2026-03-01T00:00:00+00:00", + "mandatory_reason": "Policy", + "audience": "team", + "review_by": "2026-12-31", + "edited_by": "carol", + "edited_at": "2026-02-01T00:00:00+00:00", + } + } + } + response_items = [ + { + "existing_id": "km_abc123", + "title": "T", + "content": "C updated", + "category": "workflow", + "tags": [], + "source_users": ["carol"], + } + ] + result = _process_catalog_response(response_items, existing=existing) + item = result["km_abc123"] + assert item["approved_by"] == "manager" + assert item["mandatory_reason"] == "Policy" + assert item["audience"] == "team" + + def test_new_item_with_pending_initial_status(self): + from services.corporate_memory.collector import _process_catalog_response + response_items = [ + { + "existing_id": None, + "title": "Another tip", + "content": "Some content", + "category": "workflow", + "tags": [], + "source_users": ["dave"], + } + ] + result = _process_catalog_response( + response_items, existing={"items": {}}, initial_status="pending" + ) + item = next(iter(result.values())) + assert item["status"] == "pending" + + +# --------------------------------------------------------------------------- +# Tests for check_sensitivity +# --------------------------------------------------------------------------- + +class TestCheckSensitivity: + def test_safe_item_returns_true(self): + from services.corporate_memory.collector import check_sensitivity + extractor = MockLLMProvider({"safe": True}) + item = {"id": "km_x", "title": "T", "content": "C", "tags": []} + assert check_sensitivity(extractor, item) is True + + def test_unsafe_item_returns_false(self): + from services.corporate_memory.collector import check_sensitivity + extractor = MockLLMProvider({"safe": False, "reason": "Contains PII"}) + item = {"id": "km_y", "title": "T", "content": "C", "tags": []} + assert check_sensitivity(extractor, item) is False + + def test_llm_error_returns_false(self): + """When the LLM raises an LLMError, the item is treated as unsafe.""" + from connectors.llm.exceptions import LLMError + from services.corporate_memory.collector import check_sensitivity + + class ErrorExtractor: + def extract_json(self, *args, **kwargs): + raise LLMError("Network error") + + item = {"id": "km_z", "title": "T", "content": "C", "tags": []} + assert check_sensitivity(ErrorExtractor(), item) is False + + +# --------------------------------------------------------------------------- +# Integration-style: collect_all with mocked I/O +# --------------------------------------------------------------------------- + +class TestCollectAllSkipsWhenNoChanges: + def test_skips_when_no_user_files(self, tmp_path): + """collect_all returns skipped=True when no CLAUDE.local.md files exist.""" + from services.corporate_memory import collector + + with ( + patch.object(collector, "HOME_BASE", tmp_path / "home"), + patch.object(collector, "KNOWLEDGE_FILE", tmp_path / "knowledge.json"), + patch.object(collector, "USER_HASHES_FILE", tmp_path / "user_hashes.json"), + ): + (tmp_path / "home").mkdir() + stats = collector.collect_all(dry_run=True) + assert stats["skipped"] is True + + def test_skips_when_hashes_unchanged(self, tmp_path): + """collect_all skips when hashes match stored values.""" + from services.corporate_memory import collector + + home = tmp_path / "home" + home.mkdir() + user_dir = home / "alice" + user_dir.mkdir() + claude_file = user_dir / "CLAUDE.local.md" + claude_file.write_text("# My tips\n- Always document code") + + content = claude_file.read_text(encoding="utf-8") + md5 = hashlib.md5(content.encode()).hexdigest() + user_hashes_file = tmp_path / "user_hashes.json" + _write_json(user_hashes_file, {"hashes": {"alice": md5}}) + + with ( + patch.object(collector, "HOME_BASE", home), + patch.object(collector, "KNOWLEDGE_FILE", tmp_path / "knowledge.json"), + patch.object(collector, "USER_HASHES_FILE", user_hashes_file), + ): + stats = collector.collect_all(dry_run=True) + assert stats["skipped"] is True diff --git a/tests/test_scheduler_full.py b/tests/test_scheduler_full.py new file mode 100644 index 0000000..332a53c --- /dev/null +++ b/tests/test_scheduler_full.py @@ -0,0 +1,134 @@ +"""Tests for schedule parsing and due-check logic in src/scheduler.py.""" + +from datetime import datetime, timezone + +import pytest + +from src.scheduler import is_table_due, parse_interval_minutes + + +# --------------------------------------------------------------------------- +# parse_interval_minutes +# --------------------------------------------------------------------------- + +class TestParseIntervalMinutes: + def test_every_15m(self): + assert parse_interval_minutes("every 15m") == 15 + + def test_every_1h(self): + assert parse_interval_minutes("every 1h") == 60 + + def test_every_2h(self): + assert parse_interval_minutes("every 2h") == 120 + + def test_every_30m(self): + assert parse_interval_minutes("every 30m") == 30 + + def test_daily_returns_none(self): + assert parse_interval_minutes("daily 05:00") is None + + def test_invalid_string_returns_none(self): + assert parse_interval_minutes("gibberish") is None + + def test_empty_string_returns_none(self): + assert parse_interval_minutes("") is None + + def test_every_0m(self): + # Edge case: zero minutes is still a valid parse + assert parse_interval_minutes("every 0m") == 0 + + +# --------------------------------------------------------------------------- +# is_table_due — interval schedules +# --------------------------------------------------------------------------- + +def _utc(year, month, day, hour=0, minute=0, second=0): + return datetime(year, month, day, hour, minute, second, tzinfo=timezone.utc) + + +class TestIsTableDueNeverSynced: + def test_never_synced_is_always_due(self): + assert is_table_due("every 1h", last_sync_iso=None) is True + + def test_empty_string_last_sync_is_due(self): + assert is_table_due("every 1h", last_sync_iso="") is True + + +class TestIsTableDueInterval: + NOW = _utc(2026, 4, 12, 10, 0, 0) + + def test_interval_not_elapsed(self): + # Synced 10 minutes ago, interval is 30m + last = _utc(2026, 4, 12, 9, 50, 0).isoformat() + assert is_table_due("every 30m", last, now=self.NOW) is False + + def test_interval_elapsed(self): + # Synced 31 minutes ago, interval is 30m + last = _utc(2026, 4, 12, 9, 29, 0).isoformat() + assert is_table_due("every 30m", last, now=self.NOW) is True + + def test_exact_boundary(self): + # Synced exactly 30 minutes ago — boundary is inclusive (>=) + last = _utc(2026, 4, 12, 9, 30, 0).isoformat() + assert is_table_due("every 30m", last, now=self.NOW) is True + + def test_interval_1h_not_elapsed(self): + last = _utc(2026, 4, 12, 9, 30, 0).isoformat() + assert is_table_due("every 1h", last, now=self.NOW) is False + + def test_interval_1h_elapsed(self): + last = _utc(2026, 4, 12, 8, 59, 0).isoformat() + assert is_table_due("every 1h", last, now=self.NOW) is True + + +class TestIsTableDueDaily: + def test_before_target_time_not_due(self): + # now is 04:00 UTC, target is 05:00 — not yet reached + now = _utc(2026, 4, 12, 4, 0, 0) + last = _utc(2026, 4, 11, 5, 0, 0).isoformat() # yesterday + assert is_table_due("daily 05:00", last, now=now) is False + + def test_after_target_time_due(self): + # now is 06:00 UTC, target is 05:00 — past target + now = _utc(2026, 4, 12, 6, 0, 0) + last = _utc(2026, 4, 11, 5, 0, 0).isoformat() # last sync was yesterday + assert is_table_due("daily 05:00", last, now=now) is True + + def test_already_synced_today(self): + # Now is 10:00 UTC, synced at 05:30 today — not due again + now = _utc(2026, 4, 12, 10, 0, 0) + last = _utc(2026, 4, 12, 5, 30, 0).isoformat() + assert is_table_due("daily 05:00", last, now=now) is False + + def test_daily_multiple_times_second_time_due(self): + # Schedule: daily 07:00,13:00,18:00 + # Now is 14:00, already synced at 07:30 — second target (13:00) is due + now = _utc(2026, 4, 12, 14, 0, 0) + last = _utc(2026, 4, 12, 7, 30, 0).isoformat() + assert is_table_due("daily 07:00,13:00,18:00", last, now=now) is True + + def test_daily_multiple_times_not_due_after_all(self): + # Now is 19:00, synced at 18:30 — all targets passed + now = _utc(2026, 4, 12, 19, 0, 0) + last = _utc(2026, 4, 12, 18, 30, 0).isoformat() + assert is_table_due("daily 07:00,13:00,18:00", last, now=now) is False + + +class TestIsTableDueEdgeCases: + def test_unknown_format_returns_false(self): + now = _utc(2026, 4, 12, 10, 0, 0) + assert is_table_due("weekly monday", "2026-04-11T09:00:00", now=now) is False + + def test_invalid_timestamp_treated_as_due(self): + assert is_table_due("every 1h", "not-a-timestamp") is True + + def test_naive_last_sync_timestamp(self): + # ISO timestamp without timezone info should still work + now = _utc(2026, 4, 12, 10, 0, 0) + last = "2026-04-12T08:00:00" # no tz info + assert is_table_due("every 1h", last, now=now) is True + + def test_none_now_uses_current_time(self): + # Simply smoke-test that it doesn't crash with now=None + result = is_table_due("every 1h", last_sync_iso=None, now=None) + assert result is True # never synced diff --git a/tests/test_session_collector.py b/tests/test_session_collector.py new file mode 100644 index 0000000..4246af9 --- /dev/null +++ b/tests/test_session_collector.py @@ -0,0 +1,98 @@ +"""Tests for session_collector.collector.""" + +from pathlib import Path + +import pytest + +from services.session_collector.collector import copy_session_file, find_session_files + + +class TestCopySessionFile: + def test_skips_if_target_exists(self, tmp_path): + """Returns False and does not overwrite if target already exists.""" + source = tmp_path / "session.jsonl" + source.write_text('{"event": "start"}') + target = tmp_path / "dest" / "session.jsonl" + target.parent.mkdir(parents=True) + target.write_text("existing content") + + result = copy_session_file(source, target) + assert result is False + # Target content should not be overwritten + assert target.read_text() == "existing content" + + def test_copies_new_file(self, tmp_path): + """Returns True and creates the target when it does not exist.""" + source = tmp_path / "session.jsonl" + source.write_text('{"event": "start"}') + target = tmp_path / "dest" / "session.jsonl" + + result = copy_session_file(source, target) + assert result is True + assert target.exists() + assert target.read_text() == '{"event": "start"}' + + def test_dry_run_returns_true_without_copying(self, tmp_path): + """In dry_run mode, returns True but does not create the file.""" + source = tmp_path / "session.jsonl" + source.write_text('{"event": "start"}') + target = tmp_path / "dest" / "session.jsonl" + + result = copy_session_file(source, target, dry_run=True) + assert result is True + assert not target.exists() + + def test_creates_parent_directory(self, tmp_path): + """Parent directories are created automatically.""" + source = tmp_path / "session.jsonl" + source.write_text("data") + target = tmp_path / "a" / "b" / "c" / "session.jsonl" + + copy_session_file(source, target) + assert target.exists() + + def test_dry_run_skips_existing_target(self, tmp_path): + """dry_run still returns False if target already exists.""" + source = tmp_path / "session.jsonl" + source.write_text("data") + target = tmp_path / "session.jsonl" + target.write_text("old") + + result = copy_session_file(source, target, dry_run=True) + assert result is False + + +class TestFindSessionFiles: + def test_finds_jsonl_files(self, tmp_path): + """find_session_files yields .jsonl files from user/sessions/.""" + user_home = tmp_path / "alice" + sessions_dir = user_home / "user" / "sessions" + sessions_dir.mkdir(parents=True) + f1 = sessions_dir / "session1.jsonl" + f2 = sessions_dir / "session2.jsonl" + f1.write_text("{}") + f2.write_text("{}") + + found = list(find_session_files(user_home)) + assert len(found) == 2 + assert all(f.suffix == ".jsonl" for f in found) + + def test_ignores_non_jsonl_files(self, tmp_path): + """Non-.jsonl files are not returned.""" + user_home = tmp_path / "bob" + sessions_dir = user_home / "user" / "sessions" + sessions_dir.mkdir(parents=True) + (sessions_dir / "notes.txt").write_text("ignore me") + (sessions_dir / "session.jsonl").write_text("{}") + + found = list(find_session_files(user_home)) + assert len(found) == 1 + assert found[0].name == "session.jsonl" + + def test_returns_empty_when_no_sessions_dir(self, tmp_path): + """Returns empty iterator when user/sessions/ doesn't exist.""" + user_home = tmp_path / "carol" + user_home.mkdir() + + found = list(find_session_files(user_home)) + assert found == [] diff --git a/tests/test_telegram_bot.py b/tests/test_telegram_bot.py new file mode 100644 index 0000000..3074231 --- /dev/null +++ b/tests/test_telegram_bot.py @@ -0,0 +1,101 @@ +"""Tests for Telegram bot message handlers.""" + +import asyncio +import os +import sys +from unittest.mock import AsyncMock, patch + +import pytest + + +def _make_message(text: str, chat_id: int = 10) -> dict: + return {"chat": {"id": chat_id}, "text": text} + + +def _run(coro): + """Run a coroutine synchronously.""" + return asyncio.get_event_loop().run_until_complete(coro) + + +@pytest.fixture(autouse=True, scope="module") +def patch_bot_log(tmp_path_factory): + """Patch BOT_LOG_FILE before the bot module is imported so the FileHandler succeeds.""" + log_dir = tmp_path_factory.mktemp("notify_bot") + log_file = str(log_dir / "bot.log") + + import services.telegram_bot.config as cfg + original = cfg.BOT_LOG_FILE + cfg.BOT_LOG_FILE = log_file + + # Remove cached bot module so it re-imports with patched config + sys.modules.pop("services.telegram_bot.bot", None) + + yield + + cfg.BOT_LOG_FILE = original + sys.modules.pop("services.telegram_bot.bot", None) + + +class TestHandleMessage: + def test_start_unlinked_user_generates_verification_code(self): + """'/start' for an unlinked user generates and sends a verification code.""" + with ( + patch("services.telegram_bot.bot.get_username_by_chat_id", return_value=None), + patch("services.telegram_bot.bot.create_verification_code", return_value="123456") as mock_code, + patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send, + ): + from services.telegram_bot.bot import handle_message + _run(handle_message(_make_message("/start", chat_id=10))) + mock_code.assert_called_once_with(10) + mock_send.assert_called_once() + sent_text = mock_send.call_args[0][1] + assert "123456" in sent_text + + def test_start_already_linked_user_no_code(self): + """'/start' for an already-linked user does NOT generate a new code.""" + with ( + patch("services.telegram_bot.bot.get_username_by_chat_id", return_value="alice"), + patch("services.telegram_bot.bot.create_verification_code") as mock_code, + patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock), + ): + from services.telegram_bot.bot import handle_message + _run(handle_message(_make_message("/start", chat_id=10))) + mock_code.assert_not_called() + + def test_help_returns_help_text(self): + """'/help' sends a message containing help information.""" + with patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send: + from services.telegram_bot.bot import handle_message + _run(handle_message(_make_message("/help", chat_id=20))) + mock_send.assert_called_once() + sent_text = mock_send.call_args[0][1] + assert "/start" in sent_text + assert "/help" in sent_text + + def test_unknown_command_sends_unknown_response(self): + """An unknown command sends an 'Unknown command' reply.""" + with patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send: + from services.telegram_bot.bot import handle_message + _run(handle_message(_make_message("/foobar", chat_id=30))) + mock_send.assert_called_once() + sent_text = mock_send.call_args[0][1] + assert "Unknown" in sent_text or "unknown" in sent_text + + def test_message_with_no_chat_id_is_ignored(self): + """A message without a chat id does nothing.""" + with patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send: + from services.telegram_bot.bot import handle_message + _run(handle_message({"text": "/help"})) + mock_send.assert_not_called() + + def test_whoami_linked_user_sends_username(self): + """'/whoami' for a linked user sends the username.""" + with ( + patch("services.telegram_bot.bot.get_username_by_chat_id", return_value="dave"), + patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send, + ): + from services.telegram_bot.bot import handle_message + _run(handle_message(_make_message("/whoami", chat_id=40))) + mock_send.assert_called_once() + sent_text = mock_send.call_args[0][1] + assert "dave" in sent_text diff --git a/tests/test_telegram_storage.py b/tests/test_telegram_storage.py new file mode 100644 index 0000000..accf308 --- /dev/null +++ b/tests/test_telegram_storage.py @@ -0,0 +1,115 @@ +"""Tests for Telegram bot storage (user linking and verification codes).""" + +import json +import os +import time +from pathlib import Path +from unittest.mock import patch + +import pytest + + +@pytest.fixture() +def storage_paths(tmp_path, monkeypatch): + """Redirect storage file paths to tmp_path.""" + users_file = str(tmp_path / "telegram_users.json") + codes_file = str(tmp_path / "pending_codes.json") + + import services.telegram_bot.config as cfg + monkeypatch.setattr(cfg, "TELEGRAM_USERS_FILE", users_file) + monkeypatch.setattr(cfg, "PENDING_CODES_FILE", codes_file) + # Also patch in the storage module namespace + import services.telegram_bot.storage as storage_mod + monkeypatch.setattr(storage_mod, "config", cfg) + + return {"users": users_file, "codes": codes_file} + + +class TestUserLinking: + def test_link_user_and_get_chat_id(self, storage_paths): + from services.telegram_bot.storage import get_chat_id, link_user + link_user("alice", 100) + assert get_chat_id("alice") == 100 + + def test_get_chat_id_unknown_user_returns_none(self, storage_paths): + from services.telegram_bot.storage import get_chat_id + assert get_chat_id("nobody") is None + + def test_unlink_user_returns_true_when_linked(self, storage_paths): + from services.telegram_bot.storage import link_user, unlink_user + link_user("bob", 200) + result = unlink_user("bob") + assert result is True + + def test_unlink_user_removes_entry(self, storage_paths): + from services.telegram_bot.storage import get_chat_id, link_user, unlink_user + link_user("carol", 300) + unlink_user("carol") + assert get_chat_id("carol") is None + + def test_unlink_user_returns_false_when_not_linked(self, storage_paths): + from services.telegram_bot.storage import unlink_user + result = unlink_user("ghost") + assert result is False + + def test_link_multiple_users(self, storage_paths): + from services.telegram_bot.storage import get_chat_id, link_user + link_user("user1", 111) + link_user("user2", 222) + assert get_chat_id("user1") == 111 + assert get_chat_id("user2") == 222 + + +class TestVerificationCodes: + def test_create_verification_code_returns_string(self, storage_paths): + from services.telegram_bot.storage import create_verification_code + code = create_verification_code(chat_id=42) + assert isinstance(code, str) + assert len(code) > 0 + + def test_verify_code_returns_chat_id(self, storage_paths): + from services.telegram_bot.storage import create_verification_code, verify_code + code = create_verification_code(chat_id=55) + result = verify_code(code) + assert result == 55 + + def test_code_consumed_after_first_verify(self, storage_paths): + from services.telegram_bot.storage import create_verification_code, verify_code + code = create_verification_code(chat_id=77) + verify_code(code) + # Second call must return None (code consumed) + result = verify_code(code) + assert result is None + + def test_verify_invalid_code_returns_none(self, storage_paths): + from services.telegram_bot.storage import verify_code + result = verify_code("000000") + assert result is None + + def test_create_code_replaces_existing_for_same_chat_id(self, storage_paths): + from services.telegram_bot.storage import create_verification_code, verify_code + old_code = create_verification_code(chat_id=88) + new_code = create_verification_code(chat_id=88) + # Old code should be gone + assert verify_code(old_code) is None + # New code should work + assert verify_code(new_code) == 88 + + def test_expired_code_not_valid(self, storage_paths): + """Manually write an expired code and verify it returns None.""" + import services.telegram_bot.config as cfg + from services.telegram_bot.storage import verify_code + + # Write a code that expired long ago + expired_data = { + "123456": { + "chat_id": 99, + "created_at": time.time() - cfg.CODE_TTL_SECONDS - 1, + } + } + Path(cfg.PENDING_CODES_FILE).parent.mkdir(parents=True, exist_ok=True) + with open(cfg.PENDING_CODES_FILE, "w") as f: + json.dump(expired_data, f) + + result = verify_code("123456") + assert result is None diff --git a/tests/test_ws_gateway.py b/tests/test_ws_gateway.py new file mode 100644 index 0000000..9dbb9c5 --- /dev/null +++ b/tests/test_ws_gateway.py @@ -0,0 +1,80 @@ +"""Tests for WebSocket Gateway JWT authentication.""" + +import time + +import jwt +import pytest + +SECRET = "test-secret-ws-gateway" + + +@pytest.fixture(autouse=True) +def patch_gateway_secret(monkeypatch): + """Patch the DESKTOP_JWT_SECRET so the module can be imported.""" + monkeypatch.setenv("DESKTOP_JWT_SECRET", SECRET) + # Force reload of config module so the env var is picked up + import importlib + import services.ws_gateway.config as cfg + importlib.reload(cfg) + import services.ws_gateway.auth as auth_mod + importlib.reload(auth_mod) + + +def _make_token(payload: dict, secret: str = SECRET, algorithm: str = "HS256") -> str: + return jwt.encode(payload, secret, algorithm=algorithm) + + +def _import_validate(): + """Return the validate_token function (after env is patched).""" + from services.ws_gateway.auth import validate_token + return validate_token + + +class TestValidateToken: + def test_valid_token_returns_payload(self): + """A token with 'sub' and a future 'exp' returns the decoded payload.""" + validate_token = _import_validate() + payload = {"sub": "alice", "exp": int(time.time()) + 3600} + token = _make_token(payload) + result = validate_token(token) + assert result is not None + assert result["sub"] == "alice" + + def test_expired_token_returns_none(self): + """An expired token returns None.""" + validate_token = _import_validate() + payload = {"sub": "bob", "exp": int(time.time()) - 10} + token = _make_token(payload) + result = validate_token(token) + assert result is None + + def test_invalid_signature_returns_none(self): + """A token signed with a different secret returns None.""" + validate_token = _import_validate() + payload = {"sub": "charlie", "exp": int(time.time()) + 3600} + token = _make_token(payload, secret="wrong-secret") + result = validate_token(token) + assert result is None + + def test_token_missing_sub_returns_none(self): + """A token that has no 'sub' claim returns None.""" + validate_token = _import_validate() + payload = {"exp": int(time.time()) + 3600, "role": "admin"} + token = _make_token(payload) + result = validate_token(token) + assert result is None + + def test_garbage_string_returns_none(self): + """A completely invalid token string returns None.""" + validate_token = _import_validate() + result = validate_token("not.a.token") + assert result is None + + def test_valid_token_includes_all_claims(self): + """All custom claims are present in the returned payload.""" + validate_token = _import_validate() + payload = {"sub": "dave", "exp": int(time.time()) + 3600, "role": "analyst"} + token = _make_token(payload) + result = validate_token(token) + assert result is not None + assert result["role"] == "analyst"