refact: improve security and improve some code in domains

This commit is contained in:
roai_linux 2026-03-23 21:15:07 +03:30
parent 7f37d7fb60
commit 8ff4e90d13
18 changed files with 206 additions and 132 deletions

View File

@ -26,7 +26,7 @@ async def get_current_user(
if payload is None: if payload is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication token", detail="توکن نامعتبر است",
) )
user_id = payload.get("sub") user_id = payload.get("sub")
@ -34,7 +34,7 @@ async def get_current_user(
if not user_id: if not user_id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token payload", detail="توکن نامعتبر است",
) )
user = await get_user_by_id(db, user_id) user = await get_user_by_id(db, user_id)
@ -42,13 +42,13 @@ async def get_current_user(
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found", detail="کاربری یافت نشد",
) )
if not user.is_active: if not user.is_active:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="User is inactive", detail="کاربر غیرفعال است",
) )
# Check token version for remote logout # Check token version for remote logout
@ -56,7 +56,7 @@ async def get_current_user(
if token_version != user.token_version: if token_version != user.token_version:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token has been invalidated", detail="توکن نامعتبر است",
) )
return user return user
@ -72,7 +72,7 @@ async def get_current_admin(
if not user.is_admin: if not user.is_admin:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Admin privileges required", detail="شما دسترسی لازم را ندارید",
) )
return user return user

50
Back/core/rate_limit.py Normal file
View File

@ -0,0 +1,50 @@
# core/rate_limit.py
from fastapi import Request, HTTPException, status
from db.redis import redis_client
class RateLimiter:
def __init__(self, requests: int, window_seconds: int, scope: str = "global"):
"""
:param requests: number of requests allowed
:param window_seconds: window time in seconds
:param scope: type of limit ("global" for all site, "endpoint" for a specific path)
"""
self.requests = requests
self.window_seconds = window_seconds
self.scope = scope
async def __call__(self, request: Request):
# getting client ip
client_ip = request.client.host if request.client else "127.0.0.1"
# when project is in docker and behind nginx, the real ip is in the headers
real_ip = request.headers.get("x-real-ip", request.headers.get("x-forwarded-for", client_ip))
# if there are multiple ips, take the first ip (the real user ip)
real_ip = real_ip.split(",")[0].strip()
# creating redis key based on scope
if self.scope == "global":
# key for global limit (e.g., rate_limit:global:192.168.1.5)
key = f"rate_limit:global:{real_ip}"
else:
# key for endpoint limit (e.g., rate_limit:endpoint:192.168.1.5:/admin/login)
path = request.scope["path"]
key = f"rate_limit:endpoint:{real_ip}:{path}"
# adding 1 to the number of requests for this ip in redis
current_count = await redis_client.incr(key)
# if this is the first request in this time window, set the expiration time (TTL)
if current_count == 1:
await redis_client.expire(key, self.window_seconds)
# if the number of requests exceeds the limit, access is blocked
if current_count > self.requests:
# penalty: if someone spams, the time they are blocked is extended from zero again
await redis_client.expire(key, self.window_seconds)
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Too many requests. Please try again later."
)

View File

@ -68,9 +68,9 @@ async def logout_user(
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="User not found" detail="نام‌ کاربری یافت نشد"
) )
return {"message": "User logged out successfully"} return {"message": "خروج کاربر با موفقیت انجام شد"}
@router.post("/users/{user_id}/reset-secret", @router.post("/users/{user_id}/reset-secret",
@ -89,7 +89,7 @@ async def reset_secret(
if not new_secret: if not new_secret:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="User not found" detail="نام کاربری یافت نشد"
) )
return {"secret": new_secret} return {"secret": new_secret}

View File

@ -1,24 +1,50 @@
from typing import Required
import uuid import uuid
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field, field_validator, StrictBool
class AdminCreateUser(BaseModel): class AdminCreateUser(BaseModel):
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
username: str username: str = Field(..., max_length=20, description='username of the user')
phone_number: str | None = None phone_number: str | None = Field(None, max_length=11, description='phone number of user')
# validate phone number
@field_validator('phone_number')
def validate_phone_number(cls, v: str | None) -> str | None:
if v is None:
return None
if not v.isdigit():
raise ValueError('شماره تلفن باید عدد باشد')
if len(v) != 11:
raise ValueError('شماره تلفن باید 11 رقم باشد')
if not v.startswith('09'):
raise ValueError('شماره تلفن باید با 09 شروع شود')
return v
class AdminUserResponse(BaseModel): class AdminUserResponse(BaseModel):
id: uuid.UUID id: uuid.UUID = Field(...)
username: str username: str = Field(..., max_length=20, description='username of the user')
phone_number: str | None phone_number: str | None = Field(..., description='phone number of user')
is_admin: bool is_admin: bool = Field(..., description='is admin')
is_active: bool is_active: bool = Field(..., description='is active')
@field_validator('phone_number')
def validate_phone_number(cls, v: str | None) -> str | None:
if v is None:
return None
if not v.isdigit():
raise ValueError('شماره تلفن باید عدد باشد')
if len(v) != 11:
raise ValueError('شماره تلفن باید 11 رقم باشد')
if not v.startswith('09'):
raise ValueError('شماره تلفن باید با 09 شروع شود')
return v
class Config: class Config:
from_attributes = True from_attributes = True
class AdminCreateUserResult(BaseModel): class AdminCreateUserResult(BaseModel):
user: AdminUserResponse user: AdminUserResponse = Field(...)
secret: str secret: str = Field(...)
class AdminResetSecretResult(BaseModel): class AdminResetSecretResult(BaseModel):
secret: str secret: str = Field(...)

View File

@ -41,7 +41,7 @@ async def _create_user_with_role(
existing = await get_user_by_username(db, username) existing = await get_user_by_username(db, username)
if existing: if existing:
raise ValueError("Username already exists") raise ValueError("نام کاربری تکراری است")
secret = generate_user_secret() secret = generate_user_secret()

View File

@ -1,8 +1,10 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from db.session import get_db from db.session import get_db
from core.jwt import decode_token, create_access_token from core.jwt import decode_token, create_access_token
from core.rate_limit import RateLimiter
from domains.users.repo import get_user_by_id from domains.users.repo import get_user_by_id
from domains.auth.schemas import ( from domains.auth.schemas import (
@ -12,15 +14,14 @@ from domains.auth.schemas import (
) )
from domains.auth.service import login_user from domains.auth.service import login_user
import uuid
router = APIRouter( router = APIRouter(
prefix="/auth", prefix="/auth",
tags=["auth"] tags=["auth"]
) )
login_limiter = RateLimiter(requests=5, window_seconds=60, scope="endpoint")
@router.post("/login", response_model=TokenResponse) @router.post("/login", response_model=TokenResponse, dependencies=[Depends(login_limiter)])
async def login( async def login(
payload: LoginRequest, payload: LoginRequest,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
@ -35,7 +36,7 @@ async def login(
if not token: if not token:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or secret" detail="نام کاربری یا رمز عبور اشتباه است"
) )
return token return token
@ -49,7 +50,7 @@ async def refresh(
if not payload_data or payload_data.get("type") != "refresh": if not payload_data or payload_data.get("type") != "refresh":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token", detail="رفرش توکن نامعتبر است",
) )
user_id = payload_data.get("sub") user_id = payload_data.get("sub")
@ -58,7 +59,7 @@ async def refresh(
if not user_id or token_version is None: if not user_id or token_version is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token payload", detail="پیلود رفرش توکن نامعتبر است",
) )
try: try:
@ -66,7 +67,7 @@ async def refresh(
except ValueError: except ValueError:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid user ID in token", detail="شناسه کاربر نامعتبر است",
) )
user = await get_user_by_id(db, user_uuid) user = await get_user_by_id(db, user_uuid)
@ -74,7 +75,7 @@ async def refresh(
if not user or not user.is_active or user.token_version != token_version: if not user or not user.is_active or user.token_version != token_version:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token", detail="رفرش توکن نامعتبر است",
) )
access_token = create_access_token( access_token = create_access_token(

View File

@ -1,25 +1,25 @@
import uuid import uuid
from pydantic import BaseModel from pydantic import BaseModel, Field
class LoginRequest(BaseModel): class LoginRequest(BaseModel):
username: str username: str = Field(..., description='username of the user')
secret: str secret: str = Field(..., description='secret of the user')
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
access_token: str access_token: str = Field(..., description='access token')
refresh_token: str refresh_token: str = Field(..., description='refresh token')
token_type: str = "bearer" token_type: str = Field("bearer", description='token type')
class RefreshTokenRequest(BaseModel): class RefreshTokenRequest(BaseModel):
refresh_token: str refresh_token: str = Field(..., description='refresh token')
class AuthUser(BaseModel): class AuthUser(BaseModel):
id: uuid.UUID id: uuid.UUID = Field(..., description='user id')
username: str username: str = Field(..., description='username of the user')
is_admin: bool is_admin: bool = Field(..., description='is admin')
class Config: class Config:
from_attributes = True from_attributes = True

View File

@ -100,7 +100,7 @@ async def invite_member(
user.id, user.id,
payload.username payload.username
) )
return {"message": "Invitation sent", "notification_id": notification.id} return {"message": "دعوت ارسال شد", "notification_id": notification.id}
except ValueError as e: except ValueError as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -125,7 +125,7 @@ async def remove_member(
user_id, user_id,
user user
) )
return {"message": "Member removed successfully"} return {"message": "عضو با موفقیت حذف شد"}
except ValueError as e: except ValueError as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,

View File

@ -1,28 +1,28 @@
import uuid import uuid
from pydantic import BaseModel from pydantic import BaseModel, Field
from domains.groups.models import GroupType, GroupMemberRole from domains.groups.models import GroupType, GroupMemberRole
class GroupCreate(BaseModel): class GroupCreate(BaseModel):
name: str name: str = Field(..., max_length=50, description='name of the group')
class GroupResponse(BaseModel): class GroupResponse(BaseModel):
id: uuid.UUID id: uuid.UUID = Field(...)
name: str name: str = Field(..., max_length=50, description='name of the group')
type: GroupType type: GroupType = Field(..., description='type of the group')
is_active: bool is_active: bool = Field(..., description='is active')
class Config: class Config:
from_attributes = True from_attributes = True
class AddMemberRequest(BaseModel): class AddMemberRequest(BaseModel):
username: str # Req 12 says user enters username username: str = Field(..., max_length=20, description='username of the user')
class GroupMemberResponse(BaseModel): class GroupMemberResponse(BaseModel):
user_id: uuid.UUID user_id: uuid.UUID = Field(...)
username: str username: str = Field(..., max_length=20, description='username of the user')
role: GroupMemberRole role: GroupMemberRole = Field(..., description='role of the user in the group')
is_online: bool = False is_online: bool = Field(False, description='is online')
class Config: class Config:
from_attributes = True from_attributes = True

View File

@ -15,7 +15,8 @@ from domains.groups.repo import (
get_all_groups as repo_get_all_groups get_all_groups as repo_get_all_groups
) )
from domains.realtime.presence_service import list_online_users from domains.realtime.presence_service import list_online_users
from domains.users.repo import get_user_by_username
from domains.notifications.service import send_join_request
async def create_new_group( async def create_new_group(
db: AsyncSession, db: AsyncSession,
@ -49,29 +50,28 @@ async def invite_member_to_group(
sender_id: uuid.UUID, sender_id: uuid.UUID,
target_username: str target_username: str
): ):
from domains.users.repo import get_user_by_username
from domains.notifications.service import send_join_request
group_id_uuid = group_id if isinstance(group_id, uuid.UUID) else uuid.UUID(group_id) group_id_uuid = group_id if isinstance(group_id, uuid.UUID) else uuid.UUID(group_id)
# 1. Check if group exists # 1. Check if group exists
group = await get_group_by_id(db, group_id_uuid) group = await get_group_by_id(db, group_id_uuid)
if not group: if not group:
raise ValueError("Group not found") raise ValueError("گروهی یافت نشد")
sender = await get_user_by_id(db, sender_id) sender = await get_user_by_id(db, sender_id)
if not sender: if not sender:
raise ValueError("Sender not found") raise ValueError("فرستنده یافت نشد")
if not sender.is_admin: if not sender.is_admin:
membership = await get_group_member(db, group_id_uuid, sender_id) membership = await get_group_member(db, group_id_uuid, sender_id)
if not membership: if not membership:
raise ValueError("Not a group member") raise ValueError("شما عضو این گروه نیستید")
# 2. Check if target user exists # 2. Check if target user exists
target_user = await get_user_by_username(db, target_username) target_user = await get_user_by_username(db, target_username)
if not target_user: if not target_user:
raise ValueError("User not found") raise ValueError("کاربری یافت نشد")
# 3. Send notification (Req 12) # 3. Send notification (Req 12)
return await send_join_request( return await send_join_request(
@ -79,8 +79,8 @@ async def invite_member_to_group(
sender_id=sender_id, sender_id=sender_id,
receiver_id=target_user.id, receiver_id=target_user.id,
group_id=group.id, group_id=group.id,
title="Group Invitation", title="دعوت به گروه",
description=f"You have been invited to join group {group.name}" description=f"شما به گروه {group.name} دعوت شده‌اید"
) )
@ -95,7 +95,7 @@ async def add_member_to_group(
existing = await get_group_member(db, group_id_uuid, user_id_uuid) existing = await get_group_member(db, group_id_uuid, user_id_uuid)
if existing: if existing:
raise ValueError("User already in group") raise ValueError("کاربر از قبل عضو گروه است")
membership = GroupMember( membership = GroupMember(
group_id=group_id_uuid, group_id=group_id_uuid,
@ -146,6 +146,6 @@ async def remove_member_from_group(
if not requesting_user.is_admin: if not requesting_user.is_admin:
membership = await get_group_member(db, group_id_uuid, requesting_user.id) membership = await get_group_member(db, group_id_uuid, requesting_user.id)
if not membership or membership.role != GroupMemberRole.MANAGER: if not membership or membership.role != GroupMemberRole.MANAGER:
raise ValueError("Permission denied") raise ValueError("دسترسی لازم را ندارید")
await delete_group_member(db, group_id_uuid, target_user_id_uuid) await delete_group_member(db, group_id_uuid, target_user_id_uuid)

View File

@ -1,25 +1,26 @@
import uuid import uuid
from pydantic import BaseModel from pydantic import BaseModel, Field
from domains.notifications.models import NotificationType from domains.notifications.models import NotificationType
class NotificationBase(BaseModel): class NotificationBase(BaseModel):
title: str title: str = Field(..., description='title of the notification')
description: str | None = None description: str | None = Field(None, description='description of the notification')
type: NotificationType type: NotificationType = Field(..., description='type of the notification')
group_id: uuid.UUID | None = None group_id: uuid.UUID | None = Field(None, description='group id of the notification')
class NotificationCreate(NotificationBase): class NotificationCreate(NotificationBase):
receiver_id: uuid.UUID receiver_id: uuid.UUID = Field(..., description='receiver id of the notification')
sender_id: uuid.UUID | None = None sender_id: uuid.UUID | None = Field(None, description='sender id of the notification')
class NotificationResponse(NotificationBase): class NotificationResponse(NotificationBase):
id: uuid.UUID id: uuid.UUID = Field(..., description='notification id')
is_accepted: bool | None is_accepted: bool | None = Field(..., description='is accepted')
receiver_id: uuid.UUID receiver_id: uuid.UUID = Field(..., description='receiver id of the notification')
sender_id: uuid.UUID | None sender_id: uuid.UUID | None = Field(..., description='sender id of the notification')
class Config: class Config:
from_attributes = True from_attributes = True
class NotificationAction(BaseModel): class NotificationAction(BaseModel):
is_accepted: bool is_accepted: bool = Field(..., description='is accepted')

View File

@ -62,10 +62,10 @@ async def respond_to_notification(
notification_id_uuid = notification_id if isinstance(notification_id, uuid.UUID) else uuid.UUID(notification_id) notification_id_uuid = notification_id if isinstance(notification_id, uuid.UUID) else uuid.UUID(notification_id)
notification = await get_notification_by_id(db, notification_id_uuid) notification = await get_notification_by_id(db, notification_id_uuid)
if not notification: if not notification:
raise ValueError("Notification not found") raise ValueError("نوتیفیکیشنی یافت نشد")
if str(notification.receiver_id) != str(user_id): if str(notification.receiver_id) != str(user_id):
raise ValueError("Permission denied") raise ValueError("دسترسی لازم را ندارید")
notification.is_accepted = is_accepted notification.is_accepted = is_accepted
await update_notification(db, notification) await update_notification(db, notification)

View File

@ -13,17 +13,10 @@ from domains.groups.repo import get_group_by_id
async def request_speak( async def request_speak(
group_id: str | uuid.UUID, group_id: str | uuid.UUID,
user_id: str | uuid.UUID, user_id: str | uuid.UUID,
group_type: str
) -> bool: ) -> bool:
group_id_str = str(group_id) group_id_str = str(group_id)
user_id_str = str(user_id) user_id_str = str(user_id)
if group_type == "private":
await grant_publish_permission(group_id_str, user_id_str, True)
return True
# group chat → push-to-talk
granted = await acquire_speaker(group_id_str, user_id_str) granted = await acquire_speaker(group_id_str, user_id_str)
if not granted: if not granted:
return False return False
@ -35,46 +28,28 @@ async def request_speak(
async def stop_speaking( async def stop_speaking(
group_id: str | uuid.UUID, group_id: str | uuid.UUID,
user_id: str | uuid.UUID, user_id: str | uuid.UUID,
group_type: str
): ):
group_id_str = str(group_id) group_id_str = str(group_id)
user_id_str = str(user_id) user_id_str = str(user_id)
if group_type == "private":
await grant_publish_permission(group_id_str, user_id_str, False)
return True
released = await release_speaker(group_id_str, user_id_str) released = await release_speaker(group_id_str, user_id_str)
if released: if released:
await grant_publish_permission(group_id_str, user_id_str, False) await grant_publish_permission(group_id_str, user_id_str, False)
return released return released
async def current_speaker( async def current_speaker(group_id: str | uuid.UUID):
db: AsyncSession, group_id_str = str(group_id)
group_id: str | uuid.UUID
):
group_id_uuid = group_id if isinstance(group_id, uuid.UUID) else uuid.UUID(group_id)
group_id_str = str(group_id_uuid)
group = await get_group_by_id(db, group_id_uuid)
if not group:
return None
if str(group.type) == "private":
return None
return await get_active_speaker(group_id_str) return await get_active_speaker(group_id_str)
async def grant_publish_permission(room_name: str, identity: str, can_publish: bool): async def grant_publish_permission(room_name: str, identity: str, can_publish: bool):
lk_api = get_livekit_api() # همان متدی که در client.py نوشتی lk_api = get_livekit_api()
await lk_api.room.update_participant( await lk_api.room.update_participant(
api.UpdateParticipantRequest( api.UpdateParticipantRequest(
room=room_name, room=room_name,
identity=identity, identity=identity,
permission=api.ParticipantPermission( permission=api.ParticipantPermission(
can_publish=can_publish, can_publish=can_publish,
can_subscribe=True # همیشه بتواند بشنود can_subscribe=True
) )
) )
) )

View File

@ -1,4 +1,6 @@
import uuid import uuid
import time
import json
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Request, Header, status, HTTPException from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Request, Header, status, HTTPException
from livekit import api from livekit import api
from core.config import settings from core.config import settings
@ -79,10 +81,6 @@ async def group_ws(websocket: WebSocket, group_id: str):
await websocket.close(code=status.WS_1008_POLICY_VIOLATION) await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return return
from domains.groups.repo import get_group_by_id
group = await get_group_by_id(db, group_id_uuid)
group_type = str(group.type) if group else "public"
user_id = str(user.id) user_id = str(user.id)
# connect websocket # connect websocket
await manager.connect(group_id, websocket) await manager.connect(group_id, websocket)
@ -111,13 +109,28 @@ async def group_ws(websocket: WebSocket, group_id: str):
} }
) )
last_action_time = 0.0
try: try:
while True: while True:
data = await websocket.receive_json() text_data = await websocket.receive_text()
if len(text_data) > 20000:
await websocket.close(code=status.WS_1009_MESSAGE_TOO_BIG)
return
data = json.loads(text_data)
event = data.get("type") event = data.get("type")
# anti spam
current_time = time.time()
if current_time - last_action_time < 0.5:
continue
last_action_time = current_time
# user wants to speak # user wants to speak
if event == "request_speak": if event == "request_speak":
success = await request_speak(group_id, user_id, group_type) success = await request_speak(group_id, user_id)
if success: if success:
# Broadcast globally that someone is speaking # Broadcast globally that someone is speaking
await manager.broadcast( await manager.broadcast(
@ -133,8 +146,7 @@ async def group_ws(websocket: WebSocket, group_id: str):
}) })
else: else:
# someone else is speaking # someone else is speaking
async with AsyncSessionLocal() as temp_db: speaker = await current_speaker(group_id)
speaker = await current_speaker(temp_db, group_id)
await websocket.send_json( await websocket.send_json(
{ {
"type": "speaker_busy", "type": "speaker_busy",
@ -144,7 +156,7 @@ async def group_ws(websocket: WebSocket, group_id: str):
# user stops speaking # user stops speaking
elif event == "stop_speak": elif event == "stop_speak":
released = await stop_speaking(group_id, user_id, group_type) released = await stop_speaking(group_id, user_id)
if released: if released:
await manager.broadcast( await manager.broadcast(
group_id, group_id,
@ -160,7 +172,7 @@ async def group_ws(websocket: WebSocket, group_id: str):
except WebSocketDisconnect: except WebSocketDisconnect:
manager.disconnect(group_id, websocket) manager.disconnect(group_id, websocket)
await user_leave_group(group_id, user_id) await user_leave_group(group_id, user_id)
await stop_speaking(group_id, user_id, group_type) await stop_speaking(group_id, user_id)
await manager.broadcast( await manager.broadcast(
group_id, group_id,
{ {

View File

@ -17,8 +17,6 @@ class ConnectionManager:
async def _setup_pubsub(self): async def _setup_pubsub(self):
if self._pubsub is None: if self._pubsub is None:
self._pubsub = redis_client.pubsub() self._pubsub = redis_client.pubsub()
# Subscribe to all websocket broadcasting channels
# We use Any for _pubsub to satisfy type checkers that don't know redis-py return types
await self._pubsub.psubscribe("ws:group:*") await self._pubsub.psubscribe("ws:group:*")
self._listen_task = asyncio.create_task(self._redis_listener()) self._listen_task = asyncio.create_task(self._redis_listener())

View File

@ -1,4 +1,3 @@
from enum import Enum
from sqlalchemy import String, Boolean, Integer from sqlalchemy import String, Boolean, Integer
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column

View File

@ -3,9 +3,9 @@ from pydantic import BaseModel, Field, field_validator
import re import re
class UserCreate(BaseModel): class UserCreate(BaseModel):
username: str username: str = Field(..., max_length=20, description='username of the user')
phone_number: str | None = Field(None, description="11 digit phone number") phone_number: str | None = Field(None, description="11 digit phone number")
is_admin: bool = False is_admin: bool = Field(False)
@field_validator("phone_number") @field_validator("phone_number")
@classmethod @classmethod
@ -13,19 +13,19 @@ class UserCreate(BaseModel):
if v is None: if v is None:
return v return v
if not re.match(r"^09\d{9}$", v): if not re.match(r"^09\d{9}$", v):
raise ValueError("Phone number must start with 09 and be exactly 11 digits") raise ValueError("شماره تلفن باید ۱۱ رقم باشد و با ۰۹ شروع شود")
return v return v
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: uuid.UUID id: uuid.UUID = Field(...)
username: str username: str = Field(..., description='username of the user')
phone_number: str | None phone_number: str | None = Field(..., description='phone number of the user')
is_admin: bool is_admin: bool = Field(..., description='is admin')
is_active: bool is_active: bool = Field(..., description='is active')
class Config: class Config:
from_attributes = True from_attributes = True
class UserCreateResult(BaseModel): class UserCreateResult(BaseModel):
user: UserResponse user: UserResponse = Field(..., description='user created')
secret: str secret: str = Field(..., description='secret of the user')

View File

@ -1,5 +1,5 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi_swagger import patch_fastapi from fastapi_swagger import patch_fastapi
@ -11,6 +11,7 @@ from domains.realtime.ws import router as realtime_router
from domains.notifications.api import router as notifications_router from domains.notifications.api import router as notifications_router
from integrations.livekit.client import close_livekit_api from integrations.livekit.client import close_livekit_api
from db.redis import redis_client from db.redis import redis_client
from core.rate_limit import RateLimiter
@asynccontextmanager @asynccontextmanager
@ -33,6 +34,7 @@ async def lifespan(app: FastAPI):
await close_livekit_api() await close_livekit_api()
await redis_client.close() await redis_client.close()
global_limiter = RateLimiter(requests=30, window_seconds=60, scope="global")
app = FastAPI( app = FastAPI(
title="NEDA API", title="NEDA API",
@ -40,7 +42,8 @@ app = FastAPI(
version="1.0.0", version="1.0.0",
lifespan=lifespan, lifespan=lifespan,
docs_url=None, docs_url=None,
swagger_ui_oauth2_redirect_url=None swagger_ui_oauth2_redirect_url=None,
dependencies=[Depends(global_limiter)]
) )
patch_fastapi(app,docs_url="/swagger") patch_fastapi(app,docs_url="/swagger")
@ -55,7 +58,16 @@ app.add_middleware(
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
# app.add_middleware(
# CORSMiddleware,
# allow_origins=[
# "https://app.neda.com",
# "http://localhost:3000" # فقط برای تست برنامه‌نویس فرانت‌اند
# ],
# allow_credentials=True,
# allow_methods=["GET", "POST", "PUT", "DELETE"], # محدود کردن متدها
# allow_headers=["Authorization", "Content-Type"], # محدود کردن هدرها
# )
# ------------------------- # -------------------------
# Routers # Routers