68 lines
2.5 KiB
Python
68 lines
2.5 KiB
Python
import asyncio
import json
import logging
from collections import defaultdict
from typing import Any, Dict, Optional, Set

from fastapi import WebSocket

from db.redis import redis_client
|
|
|
|
|
|
class ConnectionManager:
    """Track WebSocket connections per group and fan messages out via Redis.

    Each server instance keeps only its own sockets in ``active_connections``.
    ``broadcast`` publishes to the Redis channel ``ws:group:{group_id}``; one
    background task per instance (started lazily by ``connect``) pattern-
    subscribes to ``ws:group:*`` and forwards every received message to the
    local sockets of the matching group.
    """

    # Channel naming shared by publish (broadcast) and subscribe (_setup_pubsub).
    _CHANNEL_PREFIX = "ws:group:"
    _log = logging.getLogger(__name__)

    def __init__(self) -> None:
        # Local connections on this server instance, keyed by group id.
        self.active_connections: Dict[str, Set[WebSocket]] = defaultdict(set)
        # Shared handle from redis_client.pubsub(); None until the first connect()
        # and again after the listener stops.
        self._pubsub: Any = None
        # Background task running _redis_listener(); stored so it stays referenced.
        self._listen_task: Optional[asyncio.Task] = None

    async def _setup_pubsub(self) -> None:
        """Start the shared pattern subscription and its listener task if needed.

        A no-op while a subscription exists, so it is safe to call on every
        connect(). If subscribing fails, state is reset (so the next call
        retries) and the exception is re-raised.
        """
        if self._pubsub is not None:
            return
        # Assign before awaiting so concurrent connect() calls do not create a
        # second subscription, which would deliver every message twice.
        pubsub = redis_client.pubsub()
        self._pubsub = pubsub
        try:
            await pubsub.psubscribe(self._CHANNEL_PREFIX + "*")
        except BaseException:
            # Without this reset a failed subscribe would leave _pubsub set but
            # no listener running, and no later connect() would ever retry.
            if self._pubsub is pubsub:
                self._pubsub = None
            await self._close_pubsub(pubsub)
            raise
        self._listen_task = asyncio.create_task(self._redis_listener())

    async def _close_pubsub(self, pubsub: Any) -> None:
        """Best-effort release of a pub/sub handle; logs failures, never raises."""
        # NOTE(review): assumes redis-py's asyncio PubSub (aclose() in 5.x,
        # reset() in earlier versions) — confirm against db.redis.
        close = getattr(pubsub, "aclose", None) or getattr(pubsub, "reset", None)
        if close is None:
            return
        try:
            await close()
        except Exception:
            self._log.warning("Failed to close Redis pub/sub handle", exc_info=True)

    @classmethod
    def _group_id_from_channel(cls, channel: Any) -> Optional[str]:
        """Return the group id from a ``ws:group:{group_id}`` channel name.

        Accepts ``bytes`` as well as ``str`` since the Redis client may not
        decode responses. Returns None for names without the prefix.
        """
        if isinstance(channel, bytes):
            channel = channel.decode("utf-8", errors="replace")
        if not channel.startswith(cls._CHANNEL_PREFIX):
            return None
        # Strip only the leading prefix; str.replace would also mangle ids
        # that happen to contain the prefix text.
        return channel[len(cls._CHANNEL_PREFIX):]

    async def _redis_listener(self) -> None:
        """Forward messages from the shared subscription to local sockets.

        Runs as the task created by _setup_pubsub(). Undecodable payloads are
        logged and skipped. When the subscription fails or ends, the handle is
        closed and state is cleared so the next connect() re-initializes it;
        until then, already-connected sockets receive nothing from Redis.
        """
        pubsub = self._pubsub
        if pubsub is None:
            return
        try:
            async for message in pubsub.listen():
                if message.get("type") != "pmessage":
                    continue
                group_id = self._group_id_from_channel(message["channel"])
                if group_id is None:
                    continue
                try:
                    data = json.loads(message["data"])
                except (TypeError, ValueError):
                    # One bad payload must not stop delivery for every group.
                    self._log.warning(
                        "Dropping undecodable message on channel %r",
                        message["channel"],
                        exc_info=True,
                    )
                    continue
                # Forward to local websockets for this group.
                await self._local_broadcast(group_id, data)
        except Exception:
            self._log.exception("Redis pub/sub listener failed")
        finally:
            # Re-initialize on the next connect(); only clear state that still
            # belongs to this listener (a newer subscription may exist already).
            if self._pubsub is pubsub:
                self._pubsub = None
                self._listen_task = None
            await self._close_pubsub(pubsub)

    async def _local_broadcast(self, group_id: str, message: dict) -> None:
        """Send ``message`` as JSON to every socket of ``group_id`` on this instance.

        Sockets whose send raises are unregistered via disconnect().
        """
        sockets = self.active_connections.get(group_id)
        if not sockets:
            return
        # Snapshot: the set may change while we await each send.
        for ws in list(sockets):
            try:
                await ws.send_json(message)
            except Exception:
                # disconnect() deletes the group once empty and, unlike indexing
                # the defaultdict, never re-creates a group removed meanwhile.
                self.disconnect(group_id, ws)

    async def connect(self, group_id: str, websocket: WebSocket) -> None:
        """Accept ``websocket`` and register it under ``group_id`` on this instance.

        Also ensures the shared Redis subscription is running so the socket
        receives messages published by any instance via broadcast(). If the
        subscribe raises, the socket is already accepted but not registered.
        """
        await websocket.accept()
        await self._setup_pubsub()
        self.active_connections[group_id].add(websocket)

    def disconnect(self, group_id: str, websocket: WebSocket) -> None:
        """Unregister ``websocket`` from ``group_id``; unknown pairs are ignored.

        Empty groups are deleted so ``active_connections`` does not accumulate
        keys. The socket itself is not closed here.
        """
        sockets = self.active_connections.get(group_id)
        if sockets is None:
            return
        sockets.discard(websocket)
        if not sockets:
            del self.active_connections[group_id]

    async def broadcast(self, group_id: str, message: dict) -> None:
        """
        Publish message to Redis. ALL server instances will receive it

        and forward it to their local connections for this group — including
        this instance while its subscription is running, so callers should not
        also deliver it locally. ``message`` must be JSON-serializable
        (``json.dumps`` raises TypeError otherwise).
        """
        await redis_client.publish(
            f"{self._CHANNEL_PREFIX}{group_id}", json.dumps(message)
        )
|
|
|
|
|
|
# Module-level singleton: one ConnectionManager (and so at most one Redis
# listener task) per process that imports this module.
manager = ConnectionManager()