test: add Block C services tests (68 tests across 6 files)
Cover ws_gateway JWT auth, telegram storage user linking and verification codes, telegram bot handlers, scheduler pure functions, corporate memory collector hash detection and governance, and session file collection.
This commit is contained in:
parent
c24205a1bf
commit
5a651ca59c
6 changed files with 773 additions and 0 deletions
245
tests/test_corporate_memory_collector.py
Normal file
245
tests/test_corporate_memory_collector.py
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
"""Tests for Corporate Memory knowledge collector."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal mock LLM extractor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MockLLMProvider:
|
||||
"""A minimal mock for connectors.llm.StructuredExtractor."""
|
||||
|
||||
def __init__(self, response: dict):
|
||||
self._response = response
|
||||
|
||||
def extract_json(self, prompt: str, max_tokens: int, json_schema: dict, schema_name: str) -> dict:
|
||||
return self._response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _write_json(path: Path, data: dict):
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _generate_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGenerateId:
|
||||
def test_returns_km_prefix(self):
|
||||
from services.corporate_memory.collector import _generate_id
|
||||
item_id = _generate_id("hello world")
|
||||
assert item_id.startswith("km_")
|
||||
|
||||
def test_deterministic(self):
|
||||
from services.corporate_memory.collector import _generate_id
|
||||
assert _generate_id("same") == _generate_id("same")
|
||||
|
||||
def test_different_content_different_id(self):
|
||||
from services.corporate_memory.collector import _generate_id
|
||||
assert _generate_id("aaa") != _generate_id("bbb")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _process_catalog_response (hash change / governance preservation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProcessCatalogResponse:
|
||||
def test_new_item_gets_generated_id(self):
|
||||
from services.corporate_memory.collector import _process_catalog_response
|
||||
response_items = [
|
||||
{
|
||||
"existing_id": None,
|
||||
"title": "Tip One",
|
||||
"content": "Always check the logs first.",
|
||||
"category": "debugging",
|
||||
"tags": ["logs"],
|
||||
"source_users": ["alice"],
|
||||
}
|
||||
]
|
||||
result = _process_catalog_response(response_items, existing={"items": {}})
|
||||
assert len(result) == 1
|
||||
item_id, item = next(iter(result.items()))
|
||||
assert item_id.startswith("km_")
|
||||
assert item["title"] == "Tip One"
|
||||
assert item["status"] == "approved" # default initial_status
|
||||
|
||||
def test_existing_id_preserved(self):
|
||||
from services.corporate_memory.collector import _process_catalog_response
|
||||
existing = {
|
||||
"items": {
|
||||
"km_abc123": {
|
||||
"id": "km_abc123",
|
||||
"title": "Old Title",
|
||||
"content": "Old content",
|
||||
"category": "debugging",
|
||||
"tags": [],
|
||||
"source_users": ["bob"],
|
||||
"extracted_at": "2026-01-01T00:00:00+00:00",
|
||||
"status": "approved",
|
||||
"approved_by": "admin",
|
||||
"approved_at": "2026-01-02T00:00:00+00:00",
|
||||
"mandatory_reason": None,
|
||||
"audience": "all",
|
||||
"review_by": None,
|
||||
"edited_by": None,
|
||||
"edited_at": None,
|
||||
}
|
||||
}
|
||||
}
|
||||
response_items = [
|
||||
{
|
||||
"existing_id": "km_abc123",
|
||||
"title": "Updated Title",
|
||||
"content": "New content",
|
||||
"category": "debugging",
|
||||
"tags": ["updated"],
|
||||
"source_users": ["bob"],
|
||||
}
|
||||
]
|
||||
result = _process_catalog_response(response_items, existing=existing)
|
||||
assert "km_abc123" in result
|
||||
item = result["km_abc123"]
|
||||
assert item["title"] == "Updated Title"
|
||||
|
||||
def test_governance_fields_preserved(self):
|
||||
from services.corporate_memory.collector import GOVERNANCE_FIELDS, _process_catalog_response
|
||||
existing = {
|
||||
"items": {
|
||||
"km_abc123": {
|
||||
"id": "km_abc123",
|
||||
"title": "T",
|
||||
"content": "C",
|
||||
"category": "workflow",
|
||||
"tags": [],
|
||||
"source_users": ["carol"],
|
||||
"extracted_at": "2026-01-01T00:00:00+00:00",
|
||||
"status": "approved",
|
||||
"approved_by": "manager",
|
||||
"approved_at": "2026-03-01T00:00:00+00:00",
|
||||
"mandatory_reason": "Policy",
|
||||
"audience": "team",
|
||||
"review_by": "2026-12-31",
|
||||
"edited_by": "carol",
|
||||
"edited_at": "2026-02-01T00:00:00+00:00",
|
||||
}
|
||||
}
|
||||
}
|
||||
response_items = [
|
||||
{
|
||||
"existing_id": "km_abc123",
|
||||
"title": "T",
|
||||
"content": "C updated",
|
||||
"category": "workflow",
|
||||
"tags": [],
|
||||
"source_users": ["carol"],
|
||||
}
|
||||
]
|
||||
result = _process_catalog_response(response_items, existing=existing)
|
||||
item = result["km_abc123"]
|
||||
assert item["approved_by"] == "manager"
|
||||
assert item["mandatory_reason"] == "Policy"
|
||||
assert item["audience"] == "team"
|
||||
|
||||
def test_new_item_with_pending_initial_status(self):
|
||||
from services.corporate_memory.collector import _process_catalog_response
|
||||
response_items = [
|
||||
{
|
||||
"existing_id": None,
|
||||
"title": "Another tip",
|
||||
"content": "Some content",
|
||||
"category": "workflow",
|
||||
"tags": [],
|
||||
"source_users": ["dave"],
|
||||
}
|
||||
]
|
||||
result = _process_catalog_response(
|
||||
response_items, existing={"items": {}}, initial_status="pending"
|
||||
)
|
||||
item = next(iter(result.values()))
|
||||
assert item["status"] == "pending"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for check_sensitivity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckSensitivity:
|
||||
def test_safe_item_returns_true(self):
|
||||
from services.corporate_memory.collector import check_sensitivity
|
||||
extractor = MockLLMProvider({"safe": True})
|
||||
item = {"id": "km_x", "title": "T", "content": "C", "tags": []}
|
||||
assert check_sensitivity(extractor, item) is True
|
||||
|
||||
def test_unsafe_item_returns_false(self):
|
||||
from services.corporate_memory.collector import check_sensitivity
|
||||
extractor = MockLLMProvider({"safe": False, "reason": "Contains PII"})
|
||||
item = {"id": "km_y", "title": "T", "content": "C", "tags": []}
|
||||
assert check_sensitivity(extractor, item) is False
|
||||
|
||||
def test_llm_error_returns_false(self):
|
||||
"""When the LLM raises an LLMError, the item is treated as unsafe."""
|
||||
from connectors.llm.exceptions import LLMError
|
||||
from services.corporate_memory.collector import check_sensitivity
|
||||
|
||||
class ErrorExtractor:
|
||||
def extract_json(self, *args, **kwargs):
|
||||
raise LLMError("Network error")
|
||||
|
||||
item = {"id": "km_z", "title": "T", "content": "C", "tags": []}
|
||||
assert check_sensitivity(ErrorExtractor(), item) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration-style: collect_all with mocked I/O
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCollectAllSkipsWhenNoChanges:
|
||||
def test_skips_when_no_user_files(self, tmp_path):
|
||||
"""collect_all returns skipped=True when no CLAUDE.local.md files exist."""
|
||||
from services.corporate_memory import collector
|
||||
|
||||
with (
|
||||
patch.object(collector, "HOME_BASE", tmp_path / "home"),
|
||||
patch.object(collector, "KNOWLEDGE_FILE", tmp_path / "knowledge.json"),
|
||||
patch.object(collector, "USER_HASHES_FILE", tmp_path / "user_hashes.json"),
|
||||
):
|
||||
(tmp_path / "home").mkdir()
|
||||
stats = collector.collect_all(dry_run=True)
|
||||
assert stats["skipped"] is True
|
||||
|
||||
def test_skips_when_hashes_unchanged(self, tmp_path):
|
||||
"""collect_all skips when hashes match stored values."""
|
||||
from services.corporate_memory import collector
|
||||
|
||||
home = tmp_path / "home"
|
||||
home.mkdir()
|
||||
user_dir = home / "alice"
|
||||
user_dir.mkdir()
|
||||
claude_file = user_dir / "CLAUDE.local.md"
|
||||
claude_file.write_text("# My tips\n- Always document code")
|
||||
|
||||
content = claude_file.read_text(encoding="utf-8")
|
||||
md5 = hashlib.md5(content.encode()).hexdigest()
|
||||
user_hashes_file = tmp_path / "user_hashes.json"
|
||||
_write_json(user_hashes_file, {"hashes": {"alice": md5}})
|
||||
|
||||
with (
|
||||
patch.object(collector, "HOME_BASE", home),
|
||||
patch.object(collector, "KNOWLEDGE_FILE", tmp_path / "knowledge.json"),
|
||||
patch.object(collector, "USER_HASHES_FILE", user_hashes_file),
|
||||
):
|
||||
stats = collector.collect_all(dry_run=True)
|
||||
assert stats["skipped"] is True
|
||||
134
tests/test_scheduler_full.py
Normal file
134
tests/test_scheduler_full.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""Tests for schedule parsing and due-check logic in src/scheduler.py."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from src.scheduler import is_table_due, parse_interval_minutes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_interval_minutes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseIntervalMinutes:
|
||||
def test_every_15m(self):
|
||||
assert parse_interval_minutes("every 15m") == 15
|
||||
|
||||
def test_every_1h(self):
|
||||
assert parse_interval_minutes("every 1h") == 60
|
||||
|
||||
def test_every_2h(self):
|
||||
assert parse_interval_minutes("every 2h") == 120
|
||||
|
||||
def test_every_30m(self):
|
||||
assert parse_interval_minutes("every 30m") == 30
|
||||
|
||||
def test_daily_returns_none(self):
|
||||
assert parse_interval_minutes("daily 05:00") is None
|
||||
|
||||
def test_invalid_string_returns_none(self):
|
||||
assert parse_interval_minutes("gibberish") is None
|
||||
|
||||
def test_empty_string_returns_none(self):
|
||||
assert parse_interval_minutes("") is None
|
||||
|
||||
def test_every_0m(self):
|
||||
# Edge case: zero minutes is still a valid parse
|
||||
assert parse_interval_minutes("every 0m") == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_table_due — interval schedules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _utc(year, month, day, hour=0, minute=0, second=0):
|
||||
return datetime(year, month, day, hour, minute, second, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
class TestIsTableDueNeverSynced:
|
||||
def test_never_synced_is_always_due(self):
|
||||
assert is_table_due("every 1h", last_sync_iso=None) is True
|
||||
|
||||
def test_empty_string_last_sync_is_due(self):
|
||||
assert is_table_due("every 1h", last_sync_iso="") is True
|
||||
|
||||
|
||||
class TestIsTableDueInterval:
|
||||
NOW = _utc(2026, 4, 12, 10, 0, 0)
|
||||
|
||||
def test_interval_not_elapsed(self):
|
||||
# Synced 10 minutes ago, interval is 30m
|
||||
last = _utc(2026, 4, 12, 9, 50, 0).isoformat()
|
||||
assert is_table_due("every 30m", last, now=self.NOW) is False
|
||||
|
||||
def test_interval_elapsed(self):
|
||||
# Synced 31 minutes ago, interval is 30m
|
||||
last = _utc(2026, 4, 12, 9, 29, 0).isoformat()
|
||||
assert is_table_due("every 30m", last, now=self.NOW) is True
|
||||
|
||||
def test_exact_boundary(self):
|
||||
# Synced exactly 30 minutes ago — boundary is inclusive (>=)
|
||||
last = _utc(2026, 4, 12, 9, 30, 0).isoformat()
|
||||
assert is_table_due("every 30m", last, now=self.NOW) is True
|
||||
|
||||
def test_interval_1h_not_elapsed(self):
|
||||
last = _utc(2026, 4, 12, 9, 30, 0).isoformat()
|
||||
assert is_table_due("every 1h", last, now=self.NOW) is False
|
||||
|
||||
def test_interval_1h_elapsed(self):
|
||||
last = _utc(2026, 4, 12, 8, 59, 0).isoformat()
|
||||
assert is_table_due("every 1h", last, now=self.NOW) is True
|
||||
|
||||
|
||||
class TestIsTableDueDaily:
|
||||
def test_before_target_time_not_due(self):
|
||||
# now is 04:00 UTC, target is 05:00 — not yet reached
|
||||
now = _utc(2026, 4, 12, 4, 0, 0)
|
||||
last = _utc(2026, 4, 11, 5, 0, 0).isoformat() # yesterday
|
||||
assert is_table_due("daily 05:00", last, now=now) is False
|
||||
|
||||
def test_after_target_time_due(self):
|
||||
# now is 06:00 UTC, target is 05:00 — past target
|
||||
now = _utc(2026, 4, 12, 6, 0, 0)
|
||||
last = _utc(2026, 4, 11, 5, 0, 0).isoformat() # last sync was yesterday
|
||||
assert is_table_due("daily 05:00", last, now=now) is True
|
||||
|
||||
def test_already_synced_today(self):
|
||||
# Now is 10:00 UTC, synced at 05:30 today — not due again
|
||||
now = _utc(2026, 4, 12, 10, 0, 0)
|
||||
last = _utc(2026, 4, 12, 5, 30, 0).isoformat()
|
||||
assert is_table_due("daily 05:00", last, now=now) is False
|
||||
|
||||
def test_daily_multiple_times_second_time_due(self):
|
||||
# Schedule: daily 07:00,13:00,18:00
|
||||
# Now is 14:00, already synced at 07:30 — second target (13:00) is due
|
||||
now = _utc(2026, 4, 12, 14, 0, 0)
|
||||
last = _utc(2026, 4, 12, 7, 30, 0).isoformat()
|
||||
assert is_table_due("daily 07:00,13:00,18:00", last, now=now) is True
|
||||
|
||||
def test_daily_multiple_times_not_due_after_all(self):
|
||||
# Now is 19:00, synced at 18:30 — all targets passed
|
||||
now = _utc(2026, 4, 12, 19, 0, 0)
|
||||
last = _utc(2026, 4, 12, 18, 30, 0).isoformat()
|
||||
assert is_table_due("daily 07:00,13:00,18:00", last, now=now) is False
|
||||
|
||||
|
||||
class TestIsTableDueEdgeCases:
|
||||
def test_unknown_format_returns_false(self):
|
||||
now = _utc(2026, 4, 12, 10, 0, 0)
|
||||
assert is_table_due("weekly monday", "2026-04-11T09:00:00", now=now) is False
|
||||
|
||||
def test_invalid_timestamp_treated_as_due(self):
|
||||
assert is_table_due("every 1h", "not-a-timestamp") is True
|
||||
|
||||
def test_naive_last_sync_timestamp(self):
|
||||
# ISO timestamp without timezone info should still work
|
||||
now = _utc(2026, 4, 12, 10, 0, 0)
|
||||
last = "2026-04-12T08:00:00" # no tz info
|
||||
assert is_table_due("every 1h", last, now=now) is True
|
||||
|
||||
def test_none_now_uses_current_time(self):
|
||||
# Simply smoke-test that it doesn't crash with now=None
|
||||
result = is_table_due("every 1h", last_sync_iso=None, now=None)
|
||||
assert result is True # never synced
|
||||
98
tests/test_session_collector.py
Normal file
98
tests/test_session_collector.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""Tests for session_collector.collector."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from services.session_collector.collector import copy_session_file, find_session_files
|
||||
|
||||
|
||||
class TestCopySessionFile:
|
||||
def test_skips_if_target_exists(self, tmp_path):
|
||||
"""Returns False and does not overwrite if target already exists."""
|
||||
source = tmp_path / "session.jsonl"
|
||||
source.write_text('{"event": "start"}')
|
||||
target = tmp_path / "dest" / "session.jsonl"
|
||||
target.parent.mkdir(parents=True)
|
||||
target.write_text("existing content")
|
||||
|
||||
result = copy_session_file(source, target)
|
||||
assert result is False
|
||||
# Target content should not be overwritten
|
||||
assert target.read_text() == "existing content"
|
||||
|
||||
def test_copies_new_file(self, tmp_path):
|
||||
"""Returns True and creates the target when it does not exist."""
|
||||
source = tmp_path / "session.jsonl"
|
||||
source.write_text('{"event": "start"}')
|
||||
target = tmp_path / "dest" / "session.jsonl"
|
||||
|
||||
result = copy_session_file(source, target)
|
||||
assert result is True
|
||||
assert target.exists()
|
||||
assert target.read_text() == '{"event": "start"}'
|
||||
|
||||
def test_dry_run_returns_true_without_copying(self, tmp_path):
|
||||
"""In dry_run mode, returns True but does not create the file."""
|
||||
source = tmp_path / "session.jsonl"
|
||||
source.write_text('{"event": "start"}')
|
||||
target = tmp_path / "dest" / "session.jsonl"
|
||||
|
||||
result = copy_session_file(source, target, dry_run=True)
|
||||
assert result is True
|
||||
assert not target.exists()
|
||||
|
||||
def test_creates_parent_directory(self, tmp_path):
|
||||
"""Parent directories are created automatically."""
|
||||
source = tmp_path / "session.jsonl"
|
||||
source.write_text("data")
|
||||
target = tmp_path / "a" / "b" / "c" / "session.jsonl"
|
||||
|
||||
copy_session_file(source, target)
|
||||
assert target.exists()
|
||||
|
||||
def test_dry_run_skips_existing_target(self, tmp_path):
|
||||
"""dry_run still returns False if target already exists."""
|
||||
source = tmp_path / "session.jsonl"
|
||||
source.write_text("data")
|
||||
target = tmp_path / "session.jsonl"
|
||||
target.write_text("old")
|
||||
|
||||
result = copy_session_file(source, target, dry_run=True)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestFindSessionFiles:
|
||||
def test_finds_jsonl_files(self, tmp_path):
|
||||
"""find_session_files yields .jsonl files from user/sessions/."""
|
||||
user_home = tmp_path / "alice"
|
||||
sessions_dir = user_home / "user" / "sessions"
|
||||
sessions_dir.mkdir(parents=True)
|
||||
f1 = sessions_dir / "session1.jsonl"
|
||||
f2 = sessions_dir / "session2.jsonl"
|
||||
f1.write_text("{}")
|
||||
f2.write_text("{}")
|
||||
|
||||
found = list(find_session_files(user_home))
|
||||
assert len(found) == 2
|
||||
assert all(f.suffix == ".jsonl" for f in found)
|
||||
|
||||
def test_ignores_non_jsonl_files(self, tmp_path):
|
||||
"""Non-.jsonl files are not returned."""
|
||||
user_home = tmp_path / "bob"
|
||||
sessions_dir = user_home / "user" / "sessions"
|
||||
sessions_dir.mkdir(parents=True)
|
||||
(sessions_dir / "notes.txt").write_text("ignore me")
|
||||
(sessions_dir / "session.jsonl").write_text("{}")
|
||||
|
||||
found = list(find_session_files(user_home))
|
||||
assert len(found) == 1
|
||||
assert found[0].name == "session.jsonl"
|
||||
|
||||
def test_returns_empty_when_no_sessions_dir(self, tmp_path):
|
||||
"""Returns empty iterator when user/sessions/ doesn't exist."""
|
||||
user_home = tmp_path / "carol"
|
||||
user_home.mkdir()
|
||||
|
||||
found = list(find_session_files(user_home))
|
||||
assert found == []
|
||||
101
tests/test_telegram_bot.py
Normal file
101
tests/test_telegram_bot.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
"""Tests for Telegram bot message handlers."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_message(text: str, chat_id: int = 10) -> dict:
|
||||
return {"chat": {"id": chat_id}, "text": text}
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run a coroutine synchronously."""
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def patch_bot_log(tmp_path_factory):
|
||||
"""Patch BOT_LOG_FILE before the bot module is imported so the FileHandler succeeds."""
|
||||
log_dir = tmp_path_factory.mktemp("notify_bot")
|
||||
log_file = str(log_dir / "bot.log")
|
||||
|
||||
import services.telegram_bot.config as cfg
|
||||
original = cfg.BOT_LOG_FILE
|
||||
cfg.BOT_LOG_FILE = log_file
|
||||
|
||||
# Remove cached bot module so it re-imports with patched config
|
||||
sys.modules.pop("services.telegram_bot.bot", None)
|
||||
|
||||
yield
|
||||
|
||||
cfg.BOT_LOG_FILE = original
|
||||
sys.modules.pop("services.telegram_bot.bot", None)
|
||||
|
||||
|
||||
class TestHandleMessage:
|
||||
def test_start_unlinked_user_generates_verification_code(self):
|
||||
"""'/start' for an unlinked user generates and sends a verification code."""
|
||||
with (
|
||||
patch("services.telegram_bot.bot.get_username_by_chat_id", return_value=None),
|
||||
patch("services.telegram_bot.bot.create_verification_code", return_value="123456") as mock_code,
|
||||
patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send,
|
||||
):
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message(_make_message("/start", chat_id=10)))
|
||||
mock_code.assert_called_once_with(10)
|
||||
mock_send.assert_called_once()
|
||||
sent_text = mock_send.call_args[0][1]
|
||||
assert "123456" in sent_text
|
||||
|
||||
def test_start_already_linked_user_no_code(self):
|
||||
"""'/start' for an already-linked user does NOT generate a new code."""
|
||||
with (
|
||||
patch("services.telegram_bot.bot.get_username_by_chat_id", return_value="alice"),
|
||||
patch("services.telegram_bot.bot.create_verification_code") as mock_code,
|
||||
patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock),
|
||||
):
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message(_make_message("/start", chat_id=10)))
|
||||
mock_code.assert_not_called()
|
||||
|
||||
def test_help_returns_help_text(self):
|
||||
"""'/help' sends a message containing help information."""
|
||||
with patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send:
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message(_make_message("/help", chat_id=20)))
|
||||
mock_send.assert_called_once()
|
||||
sent_text = mock_send.call_args[0][1]
|
||||
assert "/start" in sent_text
|
||||
assert "/help" in sent_text
|
||||
|
||||
def test_unknown_command_sends_unknown_response(self):
|
||||
"""An unknown command sends an 'Unknown command' reply."""
|
||||
with patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send:
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message(_make_message("/foobar", chat_id=30)))
|
||||
mock_send.assert_called_once()
|
||||
sent_text = mock_send.call_args[0][1]
|
||||
assert "Unknown" in sent_text or "unknown" in sent_text
|
||||
|
||||
def test_message_with_no_chat_id_is_ignored(self):
|
||||
"""A message without a chat id does nothing."""
|
||||
with patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send:
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message({"text": "/help"}))
|
||||
mock_send.assert_not_called()
|
||||
|
||||
def test_whoami_linked_user_sends_username(self):
|
||||
"""'/whoami' for a linked user sends the username."""
|
||||
with (
|
||||
patch("services.telegram_bot.bot.get_username_by_chat_id", return_value="dave"),
|
||||
patch("services.telegram_bot.bot.send_message", new_callable=AsyncMock) as mock_send,
|
||||
):
|
||||
from services.telegram_bot.bot import handle_message
|
||||
_run(handle_message(_make_message("/whoami", chat_id=40)))
|
||||
mock_send.assert_called_once()
|
||||
sent_text = mock_send.call_args[0][1]
|
||||
assert "dave" in sent_text
|
||||
115
tests/test_telegram_storage.py
Normal file
115
tests/test_telegram_storage.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""Tests for Telegram bot storage (user linking and verification codes)."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def storage_paths(tmp_path, monkeypatch):
|
||||
"""Redirect storage file paths to tmp_path."""
|
||||
users_file = str(tmp_path / "telegram_users.json")
|
||||
codes_file = str(tmp_path / "pending_codes.json")
|
||||
|
||||
import services.telegram_bot.config as cfg
|
||||
monkeypatch.setattr(cfg, "TELEGRAM_USERS_FILE", users_file)
|
||||
monkeypatch.setattr(cfg, "PENDING_CODES_FILE", codes_file)
|
||||
# Also patch in the storage module namespace
|
||||
import services.telegram_bot.storage as storage_mod
|
||||
monkeypatch.setattr(storage_mod, "config", cfg)
|
||||
|
||||
return {"users": users_file, "codes": codes_file}
|
||||
|
||||
|
||||
class TestUserLinking:
|
||||
def test_link_user_and_get_chat_id(self, storage_paths):
|
||||
from services.telegram_bot.storage import get_chat_id, link_user
|
||||
link_user("alice", 100)
|
||||
assert get_chat_id("alice") == 100
|
||||
|
||||
def test_get_chat_id_unknown_user_returns_none(self, storage_paths):
|
||||
from services.telegram_bot.storage import get_chat_id
|
||||
assert get_chat_id("nobody") is None
|
||||
|
||||
def test_unlink_user_returns_true_when_linked(self, storage_paths):
|
||||
from services.telegram_bot.storage import link_user, unlink_user
|
||||
link_user("bob", 200)
|
||||
result = unlink_user("bob")
|
||||
assert result is True
|
||||
|
||||
def test_unlink_user_removes_entry(self, storage_paths):
|
||||
from services.telegram_bot.storage import get_chat_id, link_user, unlink_user
|
||||
link_user("carol", 300)
|
||||
unlink_user("carol")
|
||||
assert get_chat_id("carol") is None
|
||||
|
||||
def test_unlink_user_returns_false_when_not_linked(self, storage_paths):
|
||||
from services.telegram_bot.storage import unlink_user
|
||||
result = unlink_user("ghost")
|
||||
assert result is False
|
||||
|
||||
def test_link_multiple_users(self, storage_paths):
|
||||
from services.telegram_bot.storage import get_chat_id, link_user
|
||||
link_user("user1", 111)
|
||||
link_user("user2", 222)
|
||||
assert get_chat_id("user1") == 111
|
||||
assert get_chat_id("user2") == 222
|
||||
|
||||
|
||||
class TestVerificationCodes:
|
||||
def test_create_verification_code_returns_string(self, storage_paths):
|
||||
from services.telegram_bot.storage import create_verification_code
|
||||
code = create_verification_code(chat_id=42)
|
||||
assert isinstance(code, str)
|
||||
assert len(code) > 0
|
||||
|
||||
def test_verify_code_returns_chat_id(self, storage_paths):
|
||||
from services.telegram_bot.storage import create_verification_code, verify_code
|
||||
code = create_verification_code(chat_id=55)
|
||||
result = verify_code(code)
|
||||
assert result == 55
|
||||
|
||||
def test_code_consumed_after_first_verify(self, storage_paths):
|
||||
from services.telegram_bot.storage import create_verification_code, verify_code
|
||||
code = create_verification_code(chat_id=77)
|
||||
verify_code(code)
|
||||
# Second call must return None (code consumed)
|
||||
result = verify_code(code)
|
||||
assert result is None
|
||||
|
||||
def test_verify_invalid_code_returns_none(self, storage_paths):
|
||||
from services.telegram_bot.storage import verify_code
|
||||
result = verify_code("000000")
|
||||
assert result is None
|
||||
|
||||
def test_create_code_replaces_existing_for_same_chat_id(self, storage_paths):
|
||||
from services.telegram_bot.storage import create_verification_code, verify_code
|
||||
old_code = create_verification_code(chat_id=88)
|
||||
new_code = create_verification_code(chat_id=88)
|
||||
# Old code should be gone
|
||||
assert verify_code(old_code) is None
|
||||
# New code should work
|
||||
assert verify_code(new_code) == 88
|
||||
|
||||
def test_expired_code_not_valid(self, storage_paths):
|
||||
"""Manually write an expired code and verify it returns None."""
|
||||
import services.telegram_bot.config as cfg
|
||||
from services.telegram_bot.storage import verify_code
|
||||
|
||||
# Write a code that expired long ago
|
||||
expired_data = {
|
||||
"123456": {
|
||||
"chat_id": 99,
|
||||
"created_at": time.time() - cfg.CODE_TTL_SECONDS - 1,
|
||||
}
|
||||
}
|
||||
Path(cfg.PENDING_CODES_FILE).parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(cfg.PENDING_CODES_FILE, "w") as f:
|
||||
json.dump(expired_data, f)
|
||||
|
||||
result = verify_code("123456")
|
||||
assert result is None
|
||||
80
tests/test_ws_gateway.py
Normal file
80
tests/test_ws_gateway.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
"""Tests for WebSocket Gateway JWT authentication."""
|
||||
|
||||
import time
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
|
||||
SECRET = "test-secret-ws-gateway"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_gateway_secret(monkeypatch):
|
||||
"""Patch the DESKTOP_JWT_SECRET so the module can be imported."""
|
||||
monkeypatch.setenv("DESKTOP_JWT_SECRET", SECRET)
|
||||
# Force reload of config module so the env var is picked up
|
||||
import importlib
|
||||
import services.ws_gateway.config as cfg
|
||||
importlib.reload(cfg)
|
||||
import services.ws_gateway.auth as auth_mod
|
||||
importlib.reload(auth_mod)
|
||||
|
||||
|
||||
def _make_token(payload: dict, secret: str = SECRET, algorithm: str = "HS256") -> str:
|
||||
return jwt.encode(payload, secret, algorithm=algorithm)
|
||||
|
||||
|
||||
def _import_validate():
|
||||
"""Return the validate_token function (after env is patched)."""
|
||||
from services.ws_gateway.auth import validate_token
|
||||
return validate_token
|
||||
|
||||
|
||||
class TestValidateToken:
|
||||
def test_valid_token_returns_payload(self):
|
||||
"""A token with 'sub' and a future 'exp' returns the decoded payload."""
|
||||
validate_token = _import_validate()
|
||||
payload = {"sub": "alice", "exp": int(time.time()) + 3600}
|
||||
token = _make_token(payload)
|
||||
result = validate_token(token)
|
||||
assert result is not None
|
||||
assert result["sub"] == "alice"
|
||||
|
||||
def test_expired_token_returns_none(self):
|
||||
"""An expired token returns None."""
|
||||
validate_token = _import_validate()
|
||||
payload = {"sub": "bob", "exp": int(time.time()) - 10}
|
||||
token = _make_token(payload)
|
||||
result = validate_token(token)
|
||||
assert result is None
|
||||
|
||||
def test_invalid_signature_returns_none(self):
|
||||
"""A token signed with a different secret returns None."""
|
||||
validate_token = _import_validate()
|
||||
payload = {"sub": "charlie", "exp": int(time.time()) + 3600}
|
||||
token = _make_token(payload, secret="wrong-secret")
|
||||
result = validate_token(token)
|
||||
assert result is None
|
||||
|
||||
def test_token_missing_sub_returns_none(self):
|
||||
"""A token that has no 'sub' claim returns None."""
|
||||
validate_token = _import_validate()
|
||||
payload = {"exp": int(time.time()) + 3600, "role": "admin"}
|
||||
token = _make_token(payload)
|
||||
result = validate_token(token)
|
||||
assert result is None
|
||||
|
||||
def test_garbage_string_returns_none(self):
|
||||
"""A completely invalid token string returns None."""
|
||||
validate_token = _import_validate()
|
||||
result = validate_token("not.a.token")
|
||||
assert result is None
|
||||
|
||||
def test_valid_token_includes_all_claims(self):
|
||||
"""All custom claims are present in the returned payload."""
|
||||
validate_token = _import_validate()
|
||||
payload = {"sub": "dave", "exp": int(time.time()) + 3600, "role": "analyst"}
|
||||
token = _make_token(payload)
|
||||
result = validate_token(token)
|
||||
assert result is not None
|
||||
assert result["role"] == "analyst"
|
||||
Loading…
Reference in a new issue