Merge branch 'worktree-agent-af11156d' into feature/v2-fastapi-duckdb-docker-cli
This commit is contained in:
commit
b6ace1e09a
6 changed files with 773 additions and 0 deletions
245
tests/test_corporate_memory_collector.py
Normal file
245
tests/test_corporate_memory_collector.py
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
"""Tests for Corporate Memory knowledge collector."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal mock LLM extractor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MockLLMProvider:
|
||||
"""A minimal mock for connectors.llm.StructuredExtractor."""
|
||||
|
||||
def __init__(self, response: dict):
|
||||
self._response = response
|
||||
|
||||
def extract_json(self, prompt: str, max_tokens: int, json_schema: dict, schema_name: str) -> dict:
|
||||
return self._response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _write_json(path: Path, data: dict):
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _generate_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGenerateId:
|
||||
def test_returns_km_prefix(self):
|
||||
from services.corporate_memory.collector import _generate_id
|
||||
item_id = _generate_id("hello world")
|
||||
assert item_id.startswith("km_")
|
||||
|
||||
def test_deterministic(self):
|
||||
from services.corporate_memory.collector import _generate_id
|
||||
assert _generate_id("same") == _generate_id("same")
|
||||
|
||||
def test_different_content_different_id(self):
|
||||
from services.corporate_memory.collector import _generate_id
|
||||
assert _generate_id("aaa") != _generate_id("bbb")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _process_catalog_response (hash change / governance preservation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProcessCatalogResponse:
|
||||
def test_new_item_gets_generated_id(self):
|
||||
from services.corporate_memory.collector import _process_catalog_response
|
||||
response_items = [
|
||||
{
|
||||
"existing_id": None,
|
||||
"title": "Tip One",
|
||||
"content": "Always check the logs first.",
|
||||
"category": "debugging",
|
||||
"tags": ["logs"],
|
||||
"source_users": ["alice"],
|
||||
}
|
||||
]
|
||||
result = _process_catalog_response(response_items, existing={"items": {}})
|
||||
assert len(result) == 1
|
||||
item_id, item = next(iter(result.items()))
|
||||
assert item_id.startswith("km_")
|
||||
assert item["title"] == "Tip One"
|
||||
assert item["status"] == "approved" # default initial_status
|
||||
|
||||
def test_existing_id_preserved(self):
|
||||
from services.corporate_memory.collector import _process_catalog_response
|
||||
existing = {
|
||||
"items": {
|
||||
"km_abc123": {
|
||||
"id": "km_abc123",
|
||||
"title": "Old Title",
|
||||
"content": "Old content",
|
||||
"category": "debugging",
|
||||
"tags": [],
|
||||
"source_users": ["bob"],
|
||||
"extracted_at": "2026-01-01T00:00:00+00:00",
|
||||
"status": "approved",
|
||||
"approved_by": "admin",
|
||||
"approved_at": "2026-01-02T00:00:00+00:00",
|
||||
"mandatory_reason": None,
|
||||
"audience": "all",
|
||||
"review_by": None,
|
||||
"edited_by": None,
|
||||
"edited_at": None,
|
||||
}
|
||||
}
|
||||
}
|
||||
response_items = [
|
||||
{
|
||||
"existing_id": "km_abc123",
|
||||
"title": "Updated Title",
|
||||
"content": "New content",
|
||||
"category": "debugging",
|
||||
"tags": ["updated"],
|
||||
"source_users": ["bob"],
|
||||
}
|
||||
]
|
||||
result = _process_catalog_response(response_items, existing=existing)
|
||||
assert "km_abc123" in result
|
||||
item = result["km_abc123"]
|
||||
assert item["title"] == "Updated Title"
|
||||
|
||||
def test_governance_fields_preserved(self):
|
||||
from services.corporate_memory.collector import GOVERNANCE_FIELDS, _process_catalog_response
|
||||
existing = {
|
||||
"items": {
|
||||
"km_abc123": {
|
||||
"id": "km_abc123",
|
||||
"title": "T",
|
||||
"content": "C",
|
||||
"category": "workflow",
|
||||
"tags": [],
|
||||
"source_users": ["carol"],
|
||||
"extracted_at": "2026-01-01T00:00:00+00:00",
|
||||
"status": "approved",
|
||||
"approved_by": "manager",
|
||||
"approved_at": "2026-03-01T00:00:00+00:00",
|
||||
"mandatory_reason": "Policy",
|
||||
"audience": "team",
|
||||
"review_by": "2026-12-31",
|
||||
"edited_by": "carol",
|
||||
"edited_at": "2026-02-01T00:00:00+00:00",
|
||||
}
|
||||
}
|
||||
}
|
||||
response_items = [
|
||||
{
|
||||
"existing_id": "km_abc123",
|
||||
"title": "T",
|
||||
"content": "C updated",
|
||||
"category": "workflow",
|
||||
"tags": [],
|
||||
"source_users": ["carol"],
|
||||
}
|
||||
]
|
||||
result = _process_catalog_response(response_items, existing=existing)
|
||||
item = result["km_abc123"]
|
||||
assert item["approved_by"] == "manager"
|
||||
assert item["mandatory_reason"] == "Policy"
|
||||
assert item["audience"] == "team"
|
||||
|
||||
def test_new_item_with_pending_initial_status(self):
|
||||
from services.corporate_memory.collector import _process_catalog_response
|
||||
response_items = [
|
||||
{
|
||||
"existing_id": None,
|
||||
"title": "Another tip",
|
||||
"content": "Some content",
|
||||
"category": "workflow",
|
||||
"tags": [],
|
||||
"source_users": ["dave"],
|
||||
}
|
||||
]
|
||||
result = _process_catalog_response(
|
||||
response_items, existing={"items": {}}, initial_status="pending"
|
||||
)
|
||||
item = next(iter(result.values()))
|
||||
assert item["status"] == "pending"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for check_sensitivity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckSensitivity:
|
||||
def test_safe_item_returns_true(self):
|
||||
from services.corporate_memory.collector import check_sensitivity
|
||||
extractor = MockLLMProvider({"safe": True})
|
||||
item = {"id": "km_x", "title": "T", "content": "C", "tags": []}
|
||||
assert check_sensitivity(extractor, item) is True
|
||||
|
||||
def test_unsafe_item_returns_false(self):
|
||||
from services.corporate_memory.collector import check_sensitivity
|
||||
extractor = MockLLMProvider({"safe": False, "reason": "Contains PII"})
|
||||
item = {"id": "km_y", "title": "T", "content": "C", "tags": []}
|
||||
assert check_sensitivity(extractor, item) is False
|
||||
|
||||
def test_llm_error_returns_false(self):
|
||||
"""When the LLM raises an LLMError, the item is treated as unsafe."""
|
||||
from connectors.llm.exceptions import LLMError
|
||||
from services.corporate_memory.collector import check_sensitivity
|
||||
|
||||
class ErrorExtractor:
|
||||
def extract_json(self, *args, **kwargs):
|
||||
raise LLMError("Network error")
|
||||
|
||||
item = {"id": "km_z", "title": "T", "content": "C", "tags": []}
|
||||
assert check_sensitivity(ErrorExtractor(), item) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration-style: collect_all with mocked I/O
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCollectAllSkipsWhenNoChanges:
|
||||
def test_skips_when_no_user_files(self, tmp_path):
|
||||
"""collect_all returns skipped=True when no CLAUDE.local.md files exist."""
|
||||
from services.corporate_memory import collector
|
||||
|
||||
with (
|
||||
patch.object(collector, "HOME_BASE", tmp_path / "home"),
|
||||
patch.object(collector, "KNOWLEDGE_FILE", tmp_path / "knowledge.json"),
|
||||
patch.object(collector, "USER_HASHES_FILE", tmp_path / "user_hashes.json"),
|
||||
):
|
||||
(tmp_path / "home").mkdir()
|
||||
stats = collector.collect_all(dry_run=True)
|
||||
assert stats["skipped"] is True
|
||||
|
||||
def test_skips_when_hashes_unchanged(self, tmp_path):
|
||||
"""collect_all skips when hashes match stored values."""
|
||||
from services.corporate_memory import collector
|
||||
|
||||
home = tmp_path / "home"
|
||||
home.mkdir()
|
||||
user_dir = home / "alice"
|
||||
user_dir.mkdir()
|
||||
claude_file = user_dir / "CLAUDE.local.md"
|
||||
claude_file.write_text("# My tips\n- Always document code")
|
||||
|
||||
content = claude_file.read_text(encoding="utf-8")
|
||||
md5 = hashlib.md5(content.encode()).hexdigest()
|
||||
user_hashes_file = tmp_path / "user_hashes.json"
|
||||
_write_json(user_hashes_file, {"hashes": {"alice": md5}})
|
||||
|
||||
with (
|
||||
patch.object(collector, "HOME_BASE", home),
|
||||
patch.object(collector, "KNOWLEDGE_FILE", tmp_path / "knowledge.json"),
|
||||
patch.object(collector, "USER_HASHES_FILE", user_hashes_file),
|
||||
):
|
||||
stats = collector.collect_all(dry_run=True)
|
||||
assert stats["skipped"] is True
|
||||
134
tests/test_scheduler_full.py
Normal file
134
tests/test_scheduler_full.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""Tests for schedule parsing and due-check logic in src/scheduler.py."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from src.scheduler import is_table_due, parse_interval_minutes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_interval_minutes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseIntervalMinutes:
|
||||
def test_every_15m(self):
|
||||
assert parse_interval_minutes("every 15m") == 15
|
||||
|
||||
def test_every_1h(self):
|
||||
assert parse_interval_minutes("every 1h") == 60
|
||||
|
||||
def test_every_2h(self):
|
||||
assert parse_interval_minutes("every 2h") == 120
|
||||
|
||||
def test_every_30m(self):
|
||||
assert parse_interval_minutes("every 30m") == 30
|
||||
|
||||
def test_daily_returns_none(self):
|
||||
assert parse_interval_minutes("daily 05:00") is None
|
||||
|
||||
def test_invalid_string_returns_none(self):
|
||||
assert parse_interval_minutes("gibberish") is None
|
||||
|
||||
def test_empty_string_returns_none(self):
|
||||
assert parse_interval_minutes("") is None
|
||||
|
||||
def test_every_0m(self):
|
||||
# Edge case: zero minutes is still a valid parse
|
||||
assert parse_interval_minutes("every 0m") == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_table_due — interval schedules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _utc(year, month, day, hour=0, minute=0, second=0):
|
||||
return datetime(year, month, day, hour, minute, second, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
class TestIsTableDueNeverSynced:
|
||||
def test_never_synced_is_always_due(self):
|
||||
assert is_table_due("every 1h", last_sync_iso=None) is True
|
||||
|
||||
def test_empty_string_last_sync_is_due(self):
|
||||
assert is_table_due("every 1h", last_sync_iso="") is True
|
||||
|
||||
|
||||
class TestIsTableDueInterval:
|
||||
NOW = _utc(2026, 4, 12, 10, 0, 0)
|
||||
|
||||
def test_interval_not_elapsed(self):
|
||||
# Synced 10 minutes ago, interval is 30m
|
||||
last = _utc(2026, 4, 12, 9, 50, 0).isoformat()
|
||||
assert is_table_due("every 30m", last, now=self.NOW) is False
|
||||
|
||||
def test_interval_elapsed(self):
|
||||
# Synced 31 minutes ago, interval is 30m
|
||||
last = _utc(2026, 4, 12, 9, 29, 0).isoformat()
|
||||
assert is_table_due("every 30m", last, now=self.NOW) is True
|
||||
|
||||
def test_exact_boundary(self):
|
||||
# Synced exactly 30 minutes ago — boundary is inclusive (>=)
|
||||
last = _utc(2026, 4, 12, 9, 30, 0).isoformat()
|
||||
assert is_table_due("every 30m", last, now=self.NOW) is True
|
||||
|
||||
def test_interval_1h_not_elapsed(self):
|
||||
last = _utc(2026, 4, 12, 9, 30, 0).isoformat()
|
||||
assert is_table_due("every 1h", last, now=self.NOW) is False
|
||||
|
||||
def test_interval_1h_elapsed(self):
|
||||
last = _utc(2026, 4, 12, 8, 59, 0).isoformat()
|
||||
assert is_table_due("every 1h", last, now=self.NOW) is True
|
||||
|
||||
|
||||
class TestIsTableDueDaily:
|
||||
def test_before_target_time_not_due(self):
|
||||
# now is 04:00 UTC, target is 05:00 — not yet reached
|
||||
now = _utc(2026, 4, 12, 4, 0, 0)
|
||||
last = _utc(2026, 4, 11, 5, 0, 0).isoformat() # yesterday
|
||||
assert is_table_due("daily 05:00", last, now=now) is False
|
||||
|
||||
def test_after_target_time_due(self):
|
||||
# now is 06:00 UTC, target is 05:00 — past target
|
||||
now = _utc(2026, 4, 12, 6, 0, 0)
|
||||
last = _utc(2026, 4, 11, 5, 0, 0).isoformat() # last sync was yesterday
|
||||
assert is_table_due("daily 05:00", last, now=now) is True
|
||||
|
||||
def test_already_synced_today(self):
|
||||
# Now is 10:00 UTC, synced at 05:30 today — not due again
|
||||
now = _utc(2026, 4, 12, 10, 0, 0)
|
||||
last = _utc(2026, 4, 12, 5, 30, 0).isoformat()
|
||||
assert is_table_due("daily 05:00", last, now=now) is False
|
||||
|
||||
def test_daily_multiple_times_second_time_due(self):
|
||||
# Schedule: daily 07:00,13:00,18:00
|
||||
# Now is 14:00, already synced at 07:30 — second target (13:00) is due
|
||||
now = _utc(2026, 4, 12, 14, 0, 0)
|
||||
last = _utc(2026, 4, 12, 7, 30, 0).isoformat()
|
||||
assert is_table_due("daily 07:00,13:00,18:00", last, now=now) is True
|
||||
|
||||
def test_daily_multiple_times_not_due_after_all(self):
|
||||
# Now is 19:00, synced at 18:30 — all targets passed
|
||||
now = _utc(2026, 4, 12, 19, 0, 0)
|
||||
last = _utc(2026, 4, 12, 18, 30, 0).isoformat()
|
||||
assert is_table_due("daily 07:00,13:00,18:00", last, now=now) is False
|
||||
|
||||
|
||||
class TestIsTableDueEdgeCases:
|
||||
def test_unknown_format_returns_false(self):
|
||||
now = _utc(2026, 4, 12, 10, 0, 0)
|
||||
assert is_table_due("weekly monday", "2026-04-11T09:00:00", now=now) is False
|
||||
|
||||
def test_invalid_timestamp_treated_as_due(self):
|
||||
assert is_table_due("every 1h", "not-a-timestamp") is True
|
||||
|
||||
def test_naive_last_sync_timestamp(self):
|
||||
# ISO timestamp without timezone info should still work
|
||||
now = _utc(2026, 4, 12, 10, 0, 0)
|
||||
last = "2026-04-12T08:00:00" # no tz info
|
||||
assert is_table_due("every 1h", last, now=now) is True
|
||||
|
||||
def test_none_now_uses_current_time(self):
|
||||
# Simply smoke-test that it doesn't crash with now=None
|
||||
result = is_table_due("every 1h", last_sync_iso=None, now=None)
|
||||
assert result is True # never synced
|
||||
98
tests/test_session_collector.py
Normal file
98
tests/test_session_collector.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""Tests for session_collector.collector."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from services.session_collector.collector import copy_session_file, find_session_files
|
||||
|
||||
|
||||
class TestCopySessionFile:
|
||||
def test_skips_if_target_exists(self, tmp_path):
|
||||
"""Returns False and does not overwrite if target already exists."""
|
||||
source = tmp_path / "session.jsonl"
|
||||
source.write_text('{"event": "start"}')
|
||||
target = tmp_path / "dest" / "session.jsonl"
|
||||
target.parent.mkdir(parents=True)
|
||||
target.write_text("existing content")
|
||||
|
||||
result = copy_session_file(source, target)
|
||||
assert result is False
|
||||
# Target content should not be overwritten
|
||||
assert target.read_text() == "existing content"
|
||||
|
||||
def test_copies_new_file(self, tmp_path):
|
||||
"""Returns True and creates the target when it does not exist."""
|
||||
source = tmp_path / "session.jsonl"
|
||||
source.write_text('{"event": "start"}')
|
||||
target = tmp_path / "dest" / "session.jsonl"
|
||||
|
||||
result = copy_session_file(source, target)
|
||||
assert result is True
|
||||
assert target.exists()
|
||||
assert target.read_text() == '{"event": "start"}'
|
||||
|
||||
def test_dry_run_returns_true_without_copying(self, tmp_path):
|
||||
"""In dry_run mode, returns True but does not create the file."""
|
||||
source = tmp_path / "session.jsonl"
|
||||
source.write_text('{"event": "start"}')
|
||||
target = tmp_path / "dest" / "session.jsonl"
|
||||
|
||||
result = copy_session_file(source, target, dry_run=True)
|
||||
assert result is True
|
||||
assert not target.exists()
|
||||
|
||||
def test_creates_parent_directory(self, tmp_path):
|
||||
"""Parent directories are created automatically."""
|
||||
source = tmp_path / "session.jsonl"
|
||||
source.write_text("data")
|
||||
target = tmp_path / "a" / "b" / "c" / "session.jsonl"
|
||||
|
||||
copy_session_file(source, target)
|
||||
assert target.exists()
|
||||
|
||||
def test_dry_run_skips_existing_target(self, tmp_path):
|
||||
"""dry_run still returns False if target already exists."""
|
||||
source = tmp_path / "session.jsonl"
|
||||
source.write_text("data")
|
||||
target = tmp_path / "session.jsonl"
|
||||
target.write_text("old")
|
||||
|
||||
result = copy_session_file(source, target, dry_run=True)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestFindSessionFiles:
|
||||
def test_finds_jsonl_files(self, tmp_path):
|
||||
"""find_session_files yields .jsonl files from user/sessions/."""
|
||||
user_home = tmp_path / "alice"
|
||||
sessions_dir = user_home / "user" / "sessions"
|
||||
sessions_dir.mkdir(parents=True)
|
||||
f1 = sessions_dir / "session1.jsonl"
|
||||
f2 = sessions_dir / "session2.jsonl"
|
||||
f1.write_text("{}")
|
||||
f2.write_text("{}")
|
||||
|
||||
found = list(find_session_files(user_home))
|
||||
assert len(found) == 2
|
||||
assert all(f.suffix == ".jsonl" for f in found)
|
||||
|
||||
def test_ignores_non_jsonl_files(self, tmp_path):
|
||||
"""Non-.jsonl files are not returned."""
|
||||
user_home = tmp_path / "bob"
|
||||
sessions_dir = user_home / "user" / "sessions"
|
||||
sessions_dir.mkdir(parents=True)
|
||||
(sessions_dir / "notes.txt").write_text("ignore me")
|
||||
(sessions_dir / "session.jsonl").write_text("{}")
|
||||
|
||||
found = list(find_session_files(user_home))
|
||||
assert len(found) == 1
|
||||
assert found[0].name == "session.jsonl"
|
||||
|
||||
def test_returns_empty_when_no_sessions_dir(self, tmp_path):
|
||||
"""Returns empty iterator when user/sessions/ doesn't exist."""
|
||||
user_home = tmp_path / "carol"
|
||||
user_home.mkdir()
|
||||
|
||||
found = list(find_session_files(user_home))
|
||||
assert found == []
|
||||
101
tests/test_telegram_bot.py
Normal file
101
tests/test_telegram_bot.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
"""Tests for Telegram bot message handlers."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_message(text: str, chat_id: int = 10) -> dict:
|
||||
return {"chat": {"id": chat_id}, "text": text}
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run a coroutine synchronously."""
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def patch_bot_log(tmp_path_factory):
|
||||
"""Patch BOT_LOG_FILE before the bot module is imported so the FileHandler succeeds."""
|
||||
log_dir = tmp_path_factory.mktemp("notify_bot")
|
||||
log_file = str(log_dir / "bot.log")
|
||||
|
||||
import services.telegram_bot.config as cfg
|
||||
original = cfg.BOT_LOG_FILE
|
||||
cfg.BOT_LOG_FILE = log_file
|
||||
|
||||
# Remove cached bot module so it re-imports with patched config
|
||||
sys.modules.pop("services.telegram_bot.bot", None)
|
||||
|
||||
yield
|
||||
|
||||
cfg.BOT_LOG_FILE = original
|
||||
sys.modules.pop("services.telegram_bot.bot", None)
|
||||
|
||||
|
||||
class TestHandleMessage:
|
||||
def test_start_unlinked_user_generates_verification_code(self):
|
||||
"""'/start' for an unlinked user generates and sends a verification code."""
|
||||
with (
|
||||
patch("services.telegram_bot.bot.get_username_by_chat_id", return_value=None),
|
||||
patch("services.telegram_bot.bot.create_verification_code", return_value="123456") as mock_code,
|
||||
patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send,
|
||||
):
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message(_make_message("/start", chat_id=10)))
|
||||
mock_code.assert_called_once_with(10)
|
||||
mock_send.assert_called_once()
|
||||
sent_text = mock_send.call_args[0][1]
|
||||
assert "123456" in sent_text
|
||||
|
||||
def test_start_already_linked_user_no_code(self):
|
||||
"""'/start' for an already-linked user does NOT generate a new code."""
|
||||
with (
|
||||
patch("services.telegram_bot.bot.get_username_by_chat_id", return_value="alice"),
|
||||
patch("services.telegram_bot.bot.create_verification_code") as mock_code,
|
||||
patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock),
|
||||
):
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message(_make_message("/start", chat_id=10)))
|
||||
mock_code.assert_not_called()
|
||||
|
||||
def test_help_returns_help_text(self):
|
||||
"""'/help' sends a message containing help information."""
|
||||
with patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send:
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message(_make_message("/help", chat_id=20)))
|
||||
mock_send.assert_called_once()
|
||||
sent_text = mock_send.call_args[0][1]
|
||||
assert "/start" in sent_text
|
||||
assert "/help" in sent_text
|
||||
|
||||
def test_unknown_command_sends_unknown_response(self):
|
||||
"""An unknown command sends an 'Unknown command' reply."""
|
||||
with patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send:
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message(_make_message("/foobar", chat_id=30)))
|
||||
mock_send.assert_called_once()
|
||||
sent_text = mock_send.call_args[0][1]
|
||||
assert "Unknown" in sent_text or "unknown" in sent_text
|
||||
|
||||
def test_message_with_no_chat_id_is_ignored(self):
|
||||
"""A message without a chat id does nothing."""
|
||||
with patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send:
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message({"text": "/help"}))
|
||||
mock_send.assert_not_called()
|
||||
|
||||
def test_whoami_linked_user_sends_username(self):
|
||||
"""'/whoami' for a linked user sends the username."""
|
||||
with (
|
||||
patch("services.telegram_bot.bot.get_username_by_chat_id", return_value="dave"),
|
||||
patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send,
|
||||
):
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message(_make_message("/whoami", chat_id=40)))
|
||||
mock_send.assert_called_once()
|
||||
sent_text = mock_send.call_args[0][1]
|
||||
assert "dave" in sent_text
|
||||
115
tests/test_telegram_storage.py
Normal file
115
tests/test_telegram_storage.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""Tests for Telegram bot storage (user linking and verification codes)."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def storage_paths(tmp_path, monkeypatch):
|
||||
"""Redirect storage file paths to tmp_path."""
|
||||
users_file = str(tmp_path / "telegram_users.json")
|
||||
codes_file = str(tmp_path / "pending_codes.json")
|
||||
|
||||
import services.telegram_bot.config as cfg
|
||||
monkeypatch.setattr(cfg, "TELEGRAM_USERS_FILE", users_file)
|
||||
monkeypatch.setattr(cfg, "PENDING_CODES_FILE", codes_file)
|
||||
# Also patch in the storage module namespace
|
||||
import services.telegram_bot.storage as storage_mod
|
||||
monkeypatch.setattr(storage_mod, "config", cfg)
|
||||
|
||||
return {"users": users_file, "codes": codes_file}
|
||||
|
||||
|
||||
class TestUserLinking:
|
||||
def test_link_user_and_get_chat_id(self, storage_paths):
|
||||
from services.telegram_bot.storage import get_chat_id, link_user
|
||||
link_user("alice", 100)
|
||||
assert get_chat_id("alice") == 100
|
||||
|
||||
def test_get_chat_id_unknown_user_returns_none(self, storage_paths):
|
||||
from services.telegram_bot.storage import get_chat_id
|
||||
assert get_chat_id("nobody") is None
|
||||
|
||||
def test_unlink_user_returns_true_when_linked(self, storage_paths):
|
||||
from services.telegram_bot.storage import link_user, unlink_user
|
||||
link_user("bob", 200)
|
||||
result = unlink_user("bob")
|
||||
assert result is True
|
||||
|
||||
def test_unlink_user_removes_entry(self, storage_paths):
|
||||
from services.telegram_bot.storage import get_chat_id, link_user, unlink_user
|
||||
link_user("carol", 300)
|
||||
unlink_user("carol")
|
||||
assert get_chat_id("carol") is None
|
||||
|
||||
def test_unlink_user_returns_false_when_not_linked(self, storage_paths):
|
||||
from services.telegram_bot.storage import unlink_user
|
||||
result = unlink_user("ghost")
|
||||
assert result is False
|
||||
|
||||
def test_link_multiple_users(self, storage_paths):
|
||||
from services.telegram_bot.storage import get_chat_id, link_user
|
||||
link_user("user1", 111)
|
||||
link_user("user2", 222)
|
||||
assert get_chat_id("user1") == 111
|
||||
assert get_chat_id("user2") == 222
|
||||
|
||||
|
||||
class TestVerificationCodes:
|
||||
def test_create_verification_code_returns_string(self, storage_paths):
|
||||
from services.telegram_bot.storage import create_verification_code
|
||||
code = create_verification_code(chat_id=42)
|
||||
assert isinstance(code, str)
|
||||
assert len(code) > 0
|
||||
|
||||
def test_verify_code_returns_chat_id(self, storage_paths):
|
||||
from services.telegram_bot.storage import create_verification_code, verify_code
|
||||
code = create_verification_code(chat_id=55)
|
||||
result = verify_code(code)
|
||||
assert result == 55
|
||||
|
||||
def test_code_consumed_after_first_verify(self, storage_paths):
|
||||
from services.telegram_bot.storage import create_verification_code, verify_code
|
||||
code = create_verification_code(chat_id=77)
|
||||
verify_code(code)
|
||||
# Second call must return None (code consumed)
|
||||
result = verify_code(code)
|
||||
assert result is None
|
||||
|
||||
def test_verify_invalid_code_returns_none(self, storage_paths):
|
||||
from services.telegram_bot.storage import verify_code
|
||||
result = verify_code("000000")
|
||||
assert result is None
|
||||
|
||||
def test_create_code_replaces_existing_for_same_chat_id(self, storage_paths):
|
||||
from services.telegram_bot.storage import create_verification_code, verify_code
|
||||
old_code = create_verification_code(chat_id=88)
|
||||
new_code = create_verification_code(chat_id=88)
|
||||
# Old code should be gone
|
||||
assert verify_code(old_code) is None
|
||||
# New code should work
|
||||
assert verify_code(new_code) == 88
|
||||
|
||||
def test_expired_code_not_valid(self, storage_paths):
|
||||
"""Manually write an expired code and verify it returns None."""
|
||||
import services.telegram_bot.config as cfg
|
||||
from services.telegram_bot.storage import verify_code
|
||||
|
||||
# Write a code that expired long ago
|
||||
expired_data = {
|
||||
"123456": {
|
||||
"chat_id": 99,
|
||||
"created_at": time.time() - cfg.CODE_TTL_SECONDS - 1,
|
||||
}
|
||||
}
|
||||
Path(cfg.PENDING_CODES_FILE).parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(cfg.PENDING_CODES_FILE, "w") as f:
|
||||
json.dump(expired_data, f)
|
||||
|
||||
result = verify_code("123456")
|
||||
assert result is None
|
||||
80
tests/test_ws_gateway.py
Normal file
80
tests/test_ws_gateway.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
"""Tests for WebSocket Gateway JWT authentication."""
|
||||
|
||||
import time
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
|
||||
SECRET = "test-secret-ws-gateway"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_gateway_secret(monkeypatch):
|
||||
"""Patch the DESKTOP_JWT_SECRET so the module can be imported."""
|
||||
monkeypatch.setenv("DESKTOP_JWT_SECRET", SECRET)
|
||||
# Force reload of config module so the env var is picked up
|
||||
import importlib
|
||||
import services.ws_gateway.config as cfg
|
||||
importlib.reload(cfg)
|
||||
import services.ws_gateway.auth as auth_mod
|
||||
importlib.reload(auth_mod)
|
||||
|
||||
|
||||
def _make_token(payload: dict, secret: str = SECRET, algorithm: str = "HS256") -> str:
|
||||
return jwt.encode(payload, secret, algorithm=algorithm)
|
||||
|
||||
|
||||
def _import_validate():
|
||||
"""Return the validate_token function (after env is patched)."""
|
||||
from services.ws_gateway.auth import validate_token
|
||||
return validate_token
|
||||
|
||||
|
||||
class TestValidateToken:
|
||||
def test_valid_token_returns_payload(self):
|
||||
"""A token with 'sub' and a future 'exp' returns the decoded payload."""
|
||||
validate_token = _import_validate()
|
||||
payload = {"sub": "alice", "exp": int(time.time()) + 3600}
|
||||
token = _make_token(payload)
|
||||
result = validate_token(token)
|
||||
assert result is not None
|
||||
assert result["sub"] == "alice"
|
||||
|
||||
def test_expired_token_returns_none(self):
|
||||
"""An expired token returns None."""
|
||||
validate_token = _import_validate()
|
||||
payload = {"sub": "bob", "exp": int(time.time()) - 10}
|
||||
token = _make_token(payload)
|
||||
result = validate_token(token)
|
||||
assert result is None
|
||||
|
||||
def test_invalid_signature_returns_none(self):
|
||||
"""A token signed with a different secret returns None."""
|
||||
validate_token = _import_validate()
|
||||
payload = {"sub": "charlie", "exp": int(time.time()) + 3600}
|
||||
token = _make_token(payload, secret="wrong-secret")
|
||||
result = validate_token(token)
|
||||
assert result is None
|
||||
|
||||
def test_token_missing_sub_returns_none(self):
|
||||
"""A token that has no 'sub' claim returns None."""
|
||||
validate_token = _import_validate()
|
||||
payload = {"exp": int(time.time()) + 3600, "role": "admin"}
|
||||
token = _make_token(payload)
|
||||
result = validate_token(token)
|
||||
assert result is None
|
||||
|
||||
def test_garbage_string_returns_none(self):
|
||||
"""A completely invalid token string returns None."""
|
||||
validate_token = _import_validate()
|
||||
result = validate_token("not.a.token")
|
||||
assert result is None
|
||||
|
||||
def test_valid_token_includes_all_claims(self):
|
||||
"""All custom claims are present in the returned payload."""
|
||||
validate_token = _import_validate()
|
||||
payload = {"sub": "dave", "exp": int(time.time()) + 3600, "role": "analyst"}
|
||||
token = _make_token(payload)
|
||||
result = validate_token(token)
|
||||
assert result is not None
|
||||
assert result["role"] == "analyst"
|
||||
Loading…
Reference in a new issue