# File: Neda/Back/core/websocket.py
#
# 28 lines
# 962 B
# Python
# Executable File

from fastapi import WebSocket, WebSocketException, status
from core.jwt import decode_token
from db.session import AsyncSessionLocal
from domains.users.repo import get_user_by_id
async def get_ws_current_user(websocket: WebSocket):
    """Authenticate a WebSocket connection from its ``token`` query parameter.

    The token is read from ``websocket.query_params``, decoded with
    ``decode_token``, and the user named by its ``sub`` claim is loaded with
    ``get_user_by_id`` inside a fresh ``AsyncSessionLocal`` session.

    Args:
        websocket: The incoming WebSocket connection.

    Returns:
        The user object returned by ``get_user_by_id``.

    Raises:
        WebSocketException: With code ``WS_1008_POLICY_VIOLATION`` when the
            ``token`` parameter is missing or empty, ``decode_token`` returns
            ``None``, the payload has no ``sub`` claim, no user is found, the
            user is not active, or the payload's ``token_version`` claim
            differs from ``user.token_version``.
    """
    # NOTE(review): a token in the query string may be recorded in proxy or
    # access logs; consider another transport if that is a concern.
    token = websocket.query_params.get("token")
    if not token:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)

    # decode_token may return None for a token it cannot decode/verify.
    payload = decode_token(token)
    if payload is None:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)

    user_id = payload.get("sub")
    # Reject a subject-less token before creating a DB session: this check
    # does not need the database.
    if user_id is None:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
    # NOTE(review): JWT "sub" is conventionally a string; confirm that
    # get_user_by_id accepts that type for its id argument.
    token_version = payload.get("token_version")

    async with AsyncSessionLocal() as db:
        user = await get_user_by_id(db, user_id)
        # Reject unknown or inactive users, and tokens whose token_version
        # claim no longer matches the value stored on the user.
        if not user or not user.is_active or user.token_version != token_version:
            raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
        return user