from fastapi import WebSocket, WebSocketException, status

from core.jwt import decode_token
from db.session import AsyncSessionLocal
from domains.users.repo import get_user_by_id


async def get_ws_current_user(websocket: WebSocket):
    """Resolve the user for a WebSocket connection from its ``token`` query parameter.

    The token is read from ``websocket.query_params["token"]`` and decoded with
    ``decode_token``. Its ``sub`` claim is passed to ``get_user_by_id`` inside a
    fresh ``AsyncSessionLocal`` session.

    Args:
        websocket: The incoming WebSocket connection.

    Returns:
        The object returned by ``get_user_by_id`` for the token's ``sub`` claim.

    Raises:
        WebSocketException: With code ``WS_1008_POLICY_VIOLATION`` when the
            token is missing or empty, ``decode_token`` returns ``None``, the
            payload has no ``sub`` claim, or no user is found for it.
    """
    token = websocket.query_params.get("token")
    if not token:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)

    payload = decode_token(token)
    if payload is None:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)

    # A token without a subject cannot identify a user; reject it here instead
    # of opening a session and querying the database with ``None``.
    user_id = payload.get("sub")
    if user_id is None:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)

    # NOTE(review): the session is closed before ``user`` is returned. If
    # callers touch lazily-loaded attributes/relationships this may fail —
    # confirm ``get_user_by_id`` eagerly loads what downstream code needs.
    async with AsyncSessionLocal() as db:
        user = await get_user_by_id(db, user_id)

    if not user:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
    return user