# Source: Neda/Back/domains/realtime/ws_manager.py (Python; listed as 70 lines, 2.7 KiB)

import asyncio
import json
import logging
from collections import defaultdict
from typing import Dict, Set, Optional, Any

from fastapi import WebSocket

from db.redis import redis_client
class ConnectionManager:
    """Track local WebSocket connections per group and relay Redis pub/sub messages to them.

    ``broadcast`` publishes a JSON payload on the Redis channel
    ``ws:group:{group_id}``. A background task started on the first
    ``connect`` pattern-subscribes to ``ws:group:*`` and forwards every
    received payload to the WebSockets registered locally for that group.

    If the listener task stops (for example because the Redis connection
    fails), the subscription state is reset and the next ``connect`` call
    subscribes again; there is no automatic retry in between.
    """

    # Prefix shared by the publish channel names and the subscribe pattern.
    _CHANNEL_PREFIX = "ws:group:"

    _log = logging.getLogger(__name__)

    def __init__(self) -> None:
        # Local connections on this server instance, keyed by group id.
        self.active_connections: Dict[str, Set["WebSocket"]] = defaultdict(set)
        # Typed as Any to satisfy type checkers that don't know redis-py return types.
        self._pubsub: Any = None
        self._listen_task: Optional[asyncio.Task] = None

    @classmethod
    def _group_from_channel(cls, channel: Any) -> Optional[str]:
        """Return the group id encoded in a channel name, or None if it does not match.

        Accepts ``str`` as well as ``bytes`` channel names, since the Redis
        client may or may not be configured to decode responses. Only the
        leading prefix is removed, so group ids that themselves contain the
        prefix text are preserved.
        """
        if isinstance(channel, (bytes, bytearray)):
            channel = channel.decode("utf-8", errors="replace")
        if not isinstance(channel, str) or not channel.startswith(cls._CHANNEL_PREFIX):
            return None
        return channel[len(cls._CHANNEL_PREFIX):]

    async def _close_pubsub(self, pubsub: Any) -> None:
        """Best-effort close of a pubsub object; errors are logged and ignored."""
        # Newer redis-py releases expose aclose(); older ones only have close().
        closer = getattr(pubsub, "aclose", None) or getattr(pubsub, "close", None)
        if closer is None:
            return
        try:
            result = closer()
            if hasattr(result, "__await__"):
                await result
        except Exception:
            self._log.debug("Ignoring error while closing Redis pubsub", exc_info=True)

    async def _setup_pubsub(self) -> None:
        """Subscribe to all group channels and start the listener task, once.

        Raises whatever ``psubscribe`` raises; in that case the state is
        rolled back so a later call can try again.
        """
        if self._pubsub is not None:
            return
        pubsub = redis_client.pubsub()
        # Claim the slot before the first await so concurrent connect() calls
        # do not each open their own subscription.
        self._pubsub = pubsub
        try:
            # Subscribe to all websocket broadcasting channels.
            await pubsub.psubscribe(self._CHANNEL_PREFIX + "*")
        except Exception:
            # Without this reset a failed subscribe would leave _pubsub set
            # with no listener running, and no later call would retry.
            if self._pubsub is pubsub:
                self._pubsub = None
            await self._close_pubsub(pubsub)
            raise
        self._listen_task = asyncio.create_task(self._redis_listener())

    async def _redis_listener(self) -> None:
        """Forward every ``pmessage`` received from Redis to the local group sockets."""
        pubsub = self._pubsub
        if pubsub is None:
            return
        try:
            async for message in pubsub.listen():
                if message["type"] != "pmessage":
                    continue
                group_id = self._group_from_channel(message["channel"])
                if group_id is None:
                    continue
                try:
                    data = json.loads(message["data"])
                except (TypeError, ValueError):
                    # A single malformed payload must not stop delivery for
                    # every group, so skip it and keep listening.
                    self._log.warning(
                        "Dropping non-JSON message on group %r", group_id
                    )
                    continue
                # Forward to local websockets for this group.
                await self._local_broadcast(group_id, data)
        except Exception:
            self._log.exception(
                "Redis listener stopped; it will be re-created on the next connect()"
            )
        finally:
            # Reset so the next connect() subscribes again. The identity check
            # avoids clobbering a newer subscription created in the meantime.
            if self._pubsub is pubsub:
                self._pubsub = None
                self._listen_task = None
            await self._close_pubsub(pubsub)

    async def _local_broadcast(self, group_id: str, message: dict) -> None:
        """Send ``message`` as JSON to every local socket in ``group_id``.

        Sockets whose send fails are unregistered.
        """
        # .get() avoids creating an empty defaultdict entry for unknown groups.
        sockets = self.active_connections.get(group_id)
        if not sockets:
            return
        # Iterate over a snapshot: the set can change while a send is awaited.
        for ws in list(sockets):
            try:
                await ws.send_json(message)
            except Exception:
                self._log.debug(
                    "Send failed; dropping websocket from group %r",
                    group_id,
                    exc_info=True,
                )
                # disconnect() also removes the group entry once it is empty.
                self.disconnect(group_id, ws)

    async def connect(self, group_id: str, websocket: "WebSocket") -> None:
        """Accept ``websocket``, ensure the Redis subscription exists, and register it."""
        await websocket.accept()
        await self._setup_pubsub()
        self.active_connections[group_id].add(websocket)

    def disconnect(self, group_id: str, websocket: "WebSocket") -> None:
        """Unregister ``websocket`` from ``group_id``; unknown groups or sockets are ignored."""
        sockets = self.active_connections.get(group_id)
        if sockets is None:
            return
        sockets.discard(websocket)
        if not sockets:
            # Drop empty groups so the mapping does not grow without bound.
            del self.active_connections[group_id]

    async def broadcast(self, group_id: str, message: dict) -> None:
        """
        Publish message to Redis. ALL server instances will receive it
        and forward it to their local connections for this group.
        """
        await redis_client.publish(
            f"{self._CHANNEL_PREFIX}{group_id}", json.dumps(message)
        )
# Module-level ConnectionManager instance, created when this module is imported.
manager = ConnectionManager()