import asyncio
import inspect
import json
import logging
from collections import defaultdict
from typing import Any, Dict, Optional, Set

from fastapi import WebSocket

from db.redis import redis_client

logger = logging.getLogger(__name__)

# Redis channel namespace shared by ``broadcast`` (publisher) and the
# pattern subscription created in ``_setup_pubsub``.
_CHANNEL_PREFIX = "ws:group:"


class ConnectionManager:
    """Track local WebSocket connections per group and fan messages out via Redis.

    ``broadcast`` publishes to the Redis channel ``ws:group:{group_id}``. Each
    server instance pattern-subscribes to ``ws:group:*`` (lazily, on first
    ``connect``) and forwards what it receives to the WebSockets it holds
    locally for that group.
    """

    def __init__(self):
        # Local connections on this server instance, keyed by group id.
        self.active_connections: Dict[str, Set[WebSocket]] = defaultdict(set)
        # Redis pub/sub handle and the task draining it. Both are created
        # lazily by ``_setup_pubsub`` and reset when the listener stops.
        self._pubsub: Any = None
        self._listen_task: Optional[asyncio.Task] = None

    async def _setup_pubsub(self):
        """Create the Redis pattern subscription and listener task if absent.

        ``self._pubsub`` is assigned before the first ``await`` so concurrent
        callers see it and skip setup. If subscribing fails, the state is
        reset (and the half-open handle closed) so a later call can retry; the
        original error propagates to the caller.
        """
        if self._pubsub is not None:
            return
        pubsub = redis_client.pubsub()
        self._pubsub = pubsub
        try:
            await pubsub.psubscribe(f"{_CHANNEL_PREFIX}*")
        except Exception:
            # Without this reset, every later call would skip setup forever.
            self._pubsub = None
            await self._close_pubsub(pubsub)
            raise
        self._listen_task = asyncio.create_task(self._redis_listener())

    async def _redis_listener(self):
        """Forward Redis pub/sub messages to local connections.

        A message that cannot be decoded is logged and skipped, so one bad
        payload does not stop delivery for every group. When the subscription
        ends (error or exhausted iterator), the handle is closed and the state
        is reset so the next ``connect`` call re-subscribes.

        NOTE: until that next ``connect``, already-connected sockets receive
        no Redis-delivered messages on this instance.
        """
        pubsub = self._pubsub
        if pubsub is None:
            return
        try:
            async for message in pubsub.listen():
                if message["type"] != "pmessage":
                    continue
                try:
                    channel = message["channel"]
                    # Channel/data arrive as bytes unless the client decodes
                    # responses; handle both forms.
                    if isinstance(channel, bytes):
                        channel = channel.decode()
                    # Strip only the leading prefix; str.replace would also
                    # mangle group ids that happen to contain it.
                    group_id = channel[len(_CHANNEL_PREFIX):]
                    data = json.loads(message["data"])
                except (ValueError, TypeError):
                    # ValueError covers JSONDecodeError and UnicodeDecodeError.
                    logger.warning(
                        "Dropping undecodable Redis message on channel %r",
                        message.get("channel"),
                        exc_info=True,
                    )
                    continue
                # Forward to local websockets for this group.
                await self._local_broadcast(group_id, data)
        except Exception:
            logger.exception(
                "Redis pub/sub listener failed; will re-subscribe on next connect"
            )
        finally:
            # Only clear state that still belongs to this listener.
            if self._pubsub is pubsub:
                self._pubsub = None
                self._listen_task = None
            await self._close_pubsub(pubsub)

    @staticmethod
    async def _close_pubsub(pubsub: Any) -> None:
        """Best-effort close of a pub/sub handle; failures are logged, not raised."""
        # Newer redis-py exposes ``aclose``; older versions only have ``close``.
        closer = getattr(pubsub, "aclose", None) or getattr(pubsub, "close", None)
        if closer is None:
            return
        try:
            result = closer()
            if inspect.isawaitable(result):
                await result
        except Exception:
            logger.warning("Failed to close Redis pub/sub handle", exc_info=True)

    async def _local_broadcast(self, group_id: str, message: dict):
        """Send *message* to every WebSocket this instance holds for *group_id*.

        Sends run concurrently so one slow client does not delay the others.
        Sockets whose send raises are removed via ``disconnect``, which also
        drops the group entry once it is empty.
        """
        # .get avoids creating an empty entry in the defaultdict.
        sockets = list(self.active_connections.get(group_id, ()))
        if not sockets:
            return
        results = await asyncio.gather(
            *(ws.send_json(message) for ws in sockets), return_exceptions=True
        )
        for ws, result in zip(sockets, results):
            if isinstance(result, Exception):
                self.disconnect(group_id, ws)

    async def connect(self, group_id: str, websocket: WebSocket):
        """Accept *websocket*, ensure the Redis subscription exists, and register it."""
        await websocket.accept()
        await self._setup_pubsub()
        self.active_connections[group_id].add(websocket)

    def disconnect(self, group_id: str, websocket: WebSocket):
        """Forget *websocket* for *group_id*; drop the group entry once empty."""
        sockets = self.active_connections.get(group_id)
        if sockets is None:
            return
        sockets.discard(websocket)
        if not sockets:
            del self.active_connections[group_id]

    async def broadcast(self, group_id: str, message: dict):
        """Publish *message* to Redis for *group_id*.

        Every server instance subscribed to ``ws:group:*`` (including this one,
        once it has subscribed) receives it and forwards it to its local
        connections for that group. *message* must be JSON-serializable.
        """
        await redis_client.publish(f"{_CHANNEL_PREFIX}{group_id}", json.dumps(message))


manager = ConnectionManager()