Extract 4 self-contained services into services/ module: - server/telegram_bot/ -> services/telegram_bot/ - server/ws_gateway/ -> services/ws_gateway/ - server/corporate_memory/ -> services/corporate_memory/ - server/session_collector.py -> services/session_collector/ Each service now has its own systemd/ directory with .service and .timer files. deploy.sh updated to auto-discover service units from services/*/systemd/*. server/ now contains only deployment infrastructure (deploy.sh, setup scripts, bin/ management tools, sudoers, nginx config). All imports updated: webapp/app.py, server/bin/ scripts, systemd ExecStart paths.
249 lines
7.9 KiB
Python
249 lines
7.9 KiB
Python
"""Main WebSocket Gateway with TCP WebSocket server and Unix socket HTTP dispatch."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from collections import defaultdict
|
|
|
|
import aiohttp
|
|
from aiohttp import web, WSMsgType
|
|
|
|
from .auth import validate_token
|
|
from .config import (
|
|
HEARTBEAT_INTERVAL_SECONDS,
|
|
HEARTBEAT_TIMEOUT_MISSED,
|
|
MAX_CONNECTIONS_PER_USER,
|
|
WS_DISPATCH_SOCKET,
|
|
WS_GATEWAY_HOST,
|
|
WS_GATEWAY_PORT,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Global connection registry: username -> list of WebSocket responses.
# This is a defaultdict, so indexing with an unknown username creates an
# empty list entry; use .get() or `in` when a lookup must not create one.
connections: dict[str, list[web.WebSocketResponse]] = defaultdict(list)
|
|
|
|
|
|
def _total_connections() -> int:
    """Count the WebSocket connections currently held in the registry.

    Returns:
        The sum of the per-user connection list lengths in ``connections``.
    """
    return sum(map(len, connections.values()))
|
|
|
|
|
|
def _remove_connection(username: str, ws: web.WebSocketResponse) -> None:
    """Drop ``ws`` from the registry entry of ``username``.

    Unknown usernames and sockets that are not registered are ignored. When
    the user's list becomes (or already is) empty, the user's key is removed
    from ``connections`` as well.

    Args:
        username: Registry key the socket was stored under.
        ws: The WebSocket response to remove.
    """
    # .get() avoids creating an empty entry in the defaultdict.
    user_sockets = connections.get(username)
    if user_sockets is None:
        return

    try:
        user_sockets.remove(ws)
    except ValueError:
        # The socket was never registered (or was already removed).
        pass

    if not user_sockets:
        del connections[username]
|
|
|
|
|
|
async def _heartbeat_loop(
    username: str, ws: web.WebSocketResponse
) -> None:
    """Ping the client at a fixed interval and close the socket on timeout.

    Every ``HEARTBEAT_INTERVAL_SECONDS`` a ``{"type": "ping"}`` message is sent
    and counted. Once ``HEARTBEAT_TIMEOUT_MISSED`` pings have been sent by this
    task, the socket is closed. The check runs right after the ping is sent.

    NOTE: ws_handler cancels this task and starts a fresh one whenever the
    client sends a pong, which is what resets the counter. Removing the
    connection from the registry is done by ws_handler's finally block, NOT
    here.

    Args:
        username: Name of the connected user; used only in the log message.
        ws: The WebSocket to ping.
    """
    pings_sent = 0
    try:
        while not ws.closed:
            await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS)
            # The socket may have been closed while this task was sleeping.
            if ws.closed:
                return

            try:
                await ws.send_json({"type": "ping"})
            except ConnectionResetError:
                return
            pings_sent += 1

            if pings_sent < HEARTBEAT_TIMEOUT_MISSED:
                continue

            logger.warning(
                "User %s missed %d heartbeats, disconnecting",
                username,
                pings_sent,
            )
            await ws.close()
            return
    except asyncio.CancelledError:
        # Cancellation is the normal way this task is stopped; end quietly.
        return
|
|
|
|
|
|
async def ws_handler(request: web.Request) -> web.WebSocketResponse:
    """Handle incoming WebSocket connections on the TCP server.

    The first frame must arrive within 10 seconds and be a JSON text message
    of the form ``{"type": "auth", "token": ...}``. If ``validate_token``
    accepts the token and the user is below ``MAX_CONNECTIONS_PER_USER``, the
    socket is registered in ``connections`` under the payload's ``sub`` value,
    ``auth_ok`` is sent and a heartbeat task is started. After that only
    ``pong`` messages are acted on; all other messages are ignored.

    Args:
        request: The incoming HTTP request to upgrade to a WebSocket.

    Returns:
        The WebSocket response, whether or not authentication succeeded.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    # Set only once the socket is actually registered, so the cleanup in the
    # finally block runs exclusively for connections that were added.
    username: str | None = None
    heartbeat_task: asyncio.Task | None = None

    try:
        # Wait for auth message
        try:
            msg = await asyncio.wait_for(ws.receive(), timeout=10.0)
        except asyncio.TimeoutError:
            await ws.send_json({"type": "auth_error", "message": "Auth timeout"})
            await ws.close()
            return ws

        if msg.type != WSMsgType.TEXT:
            await ws.close()
            return ws

        try:
            data = json.loads(msg.data)
        except json.JSONDecodeError:
            await ws.send_json({"type": "auth_error", "message": "Invalid JSON"})
            await ws.close()
            return ws

        # json.loads may return a list, string or number; those have no
        # .get(), so they are rejected like any other malformed auth message.
        if (
            not isinstance(data, dict)
            or data.get("type") != "auth"
            or "token" not in data
        ):
            await ws.send_json(
                {"type": "auth_error", "message": "Expected auth message with token"}
            )
            await ws.close()
            return ws

        payload = validate_token(data["token"])
        if payload is None:
            await ws.send_json({"type": "auth_error", "message": "Invalid token"})
            await ws.close()
            return ws

        authenticated_user = payload["sub"]

        # Enforce per-user connection limit
        if len(connections[authenticated_user]) >= MAX_CONNECTIONS_PER_USER:
            await ws.send_json(
                {"type": "auth_error", "message": "Too many connections"}
            )
            await ws.close()
            return ws

        connections[authenticated_user].append(ws)
        username = authenticated_user
        await ws.send_json({"type": "auth_ok", "username": username})
        logger.info("User %s connected (total: %d)", username, _total_connections())

        # Start heartbeat
        heartbeat_task = asyncio.create_task(_heartbeat_loop(username, ws))

        # Read loop - handle pong responses and ignore other messages
        async for msg in ws:
            logger.debug("User %s msg type=%s", username, msg.type)
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                except json.JSONDecodeError:
                    continue
                # Valid JSON that is not an object cannot be a pong; ignore it
                # instead of letting .get() raise and drop the connection.
                if not isinstance(data, dict):
                    continue
                if data.get("type") == "pong":
                    # Reset missed counter by cancelling and restarting heartbeat
                    heartbeat_task.cancel()
                    heartbeat_task = asyncio.create_task(
                        _heartbeat_loop(username, ws)
                    )
            elif msg.type == WSMsgType.CLOSE:
                logger.info("User %s sent CLOSE frame", username)
                break
            elif msg.type == WSMsgType.ERROR:
                logger.warning("User %s WS error: %s", username, ws.exception())
                break

    except Exception:
        logger.exception("Error in WebSocket handler")
    finally:
        if heartbeat_task is not None:
            heartbeat_task.cancel()
        if username is not None:
            _remove_connection(username, ws)
            logger.info(
                "User %s disconnected (total: %d)", username, _total_connections()
            )

    return ws
|
|
|
|
|
|
# --- HTTP dispatch handlers (Unix socket) ---
|
|
|
|
|
|
async def dispatch_handler(request: web.Request) -> web.Response:
    """Handle POST /dispatch to send notifications to connected users.

    The request body must be a JSON object with a ``user`` and a
    ``notification`` object. The notification's fields are merged into a
    ``{"type": "notification", ...}`` message that is sent to every open
    socket registered for that user.

    Args:
        request: The incoming HTTP request.

    Returns:
        ``{"sent": <count>}`` with the number of sockets the message was
        sent to, or a 400 response with an ``error`` field when the body is
        malformed.
    """
    try:
        body = await request.json()
    except json.JSONDecodeError:
        return web.json_response({"error": "Invalid JSON"}, status=400)

    # A JSON array, string or number has no .get(); answer 400, not 500.
    if not isinstance(body, dict):
        return web.json_response(
            {"error": "Request body must be a JSON object"}, status=400
        )

    user = body.get("user")
    notification = body.get("notification")

    if not user or not notification:
        return web.json_response(
            {"error": "Missing 'user' or 'notification'"}, status=400
        )

    # Lists and objects are unhashable and cannot be used as a registry key.
    if isinstance(user, (dict, list)):
        return web.json_response(
            {"error": "'user' must not be an array or object"}, status=400
        )

    # The notification is unpacked with ** below, which requires a mapping.
    if not isinstance(notification, dict):
        return web.json_response(
            {"error": "'notification' must be a JSON object"}, status=400
        )

    # .get() avoids creating an empty registry entry for unknown users.
    user_connections = connections.get(user, [])
    sent_count = 0
    message = {"type": "notification", **notification}

    # Iterate over a copy: the registry list can change while awaiting sends.
    for ws in list(user_connections):
        if ws.closed:
            continue
        try:
            await ws.send_json(message)
            sent_count += 1
        except Exception:
            # Best effort: one failing socket must not block the others.
            logger.warning(
                "Failed to send notification to user %s", user, exc_info=True
            )

    return web.json_response({"sent": sent_count})
|
|
|
|
|
|
async def health_handler(request: web.Request) -> web.Response:
    """Handle GET /health to report gateway status.

    Args:
        request: The incoming HTTP request (not inspected).

    Returns:
        A JSON response with ``status`` set to ``"ok"``, the total number of
        registered connections, and a per-user connection count.
    """
    open_total = _total_connections()
    per_user = {name: len(sockets) for name, sockets in connections.items()}
    report = {"status": "ok", "connections": open_total, "users": per_user}
    return web.json_response(report)
|
|
|
|
|
|
async def main() -> None:
    """Run both the WebSocket TCP server and the HTTP Unix socket dispatch server.

    Configures logging, starts the WebSocket application on
    ``WS_GATEWAY_HOST:WS_GATEWAY_PORT`` and the dispatch/health application on
    the ``WS_DISPATCH_SOCKET`` Unix socket, then blocks until the coroutine is
    cancelled. Both runners are cleaned up on the way out.
    """
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        level=logging.INFO,
    )

    # WebSocket TCP server
    gateway_app = web.Application()
    gateway_app.router.add_get("/", ws_handler)
    gateway_runner = web.AppRunner(gateway_app)
    await gateway_runner.setup()
    await web.TCPSite(gateway_runner, WS_GATEWAY_HOST, WS_GATEWAY_PORT).start()
    logger.info(
        "WebSocket server listening on %s:%d", WS_GATEWAY_HOST, WS_GATEWAY_PORT
    )

    # HTTP dispatch Unix socket server
    internal_app = web.Application()
    internal_app.router.add_post("/dispatch", dispatch_handler)
    internal_app.router.add_get("/health", health_handler)
    internal_runner = web.AppRunner(internal_app)
    await internal_runner.setup()
    await web.UnixSite(internal_runner, WS_DISPATCH_SOCKET).start()
    # Allow group members (data-ops: www-data, deploy) to connect to the socket
    os.chmod(WS_DISPATCH_SOCKET, 0o770)
    logger.info("Dispatch server listening on %s", WS_DISPATCH_SOCKET)

    # Run forever: this event is never set, so wait() only ends on cancellation.
    never_set = asyncio.Event()
    try:
        await never_set.wait()
    finally:
        await gateway_runner.cleanup()
        await internal_runner.cleanup()
|