refact: improve security and improve some code in domains
This commit is contained in:
parent
7f37d7fb60
commit
8ff4e90d13
|
|
@ -26,7 +26,7 @@ async def get_current_user(
|
|||
if payload is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication token",
|
||||
detail="توکن نامعتبر است",
|
||||
)
|
||||
|
||||
user_id = payload.get("sub")
|
||||
|
|
@ -34,7 +34,7 @@ async def get_current_user(
|
|||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token payload",
|
||||
detail="توکن نامعتبر است",
|
||||
)
|
||||
|
||||
user = await get_user_by_id(db, user_id)
|
||||
|
|
@ -42,13 +42,13 @@ async def get_current_user(
|
|||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found",
|
||||
detail="کاربری یافت نشد",
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User is inactive",
|
||||
detail="کاربر غیرفعال است",
|
||||
)
|
||||
|
||||
# Check token version for remote logout
|
||||
|
|
@ -56,7 +56,7 @@ async def get_current_user(
|
|||
if token_version != user.token_version:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has been invalidated",
|
||||
detail="توکن نامعتبر است",
|
||||
)
|
||||
|
||||
return user
|
||||
|
|
@ -72,7 +72,7 @@ async def get_current_admin(
|
|||
if not user.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin privileges required",
|
||||
detail="شما دسترسی لازم را ندارید",
|
||||
)
|
||||
|
||||
return user
|
||||
50
Back/core/rate_limit.py
Normal file
50
Back/core/rate_limit.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
# core/rate_limit.py
|
||||
|
||||
from fastapi import Request, HTTPException, status
|
||||
from db.redis import redis_client
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self, requests: int, window_seconds: int, scope: str = "global"):
|
||||
"""
|
||||
:param requests: number of requests allowed
|
||||
:param window_seconds: window time in seconds
|
||||
:param scope: type of limit ("global" for all site, "endpoint" for a specific path)
|
||||
"""
|
||||
self.requests = requests
|
||||
self.window_seconds = window_seconds
|
||||
self.scope = scope
|
||||
|
||||
async def __call__(self, request: Request):
|
||||
# getting client ip
|
||||
client_ip = request.client.host if request.client else "127.0.0.1"
|
||||
|
||||
# when project is in docker and behind nginx, the real ip is in the headers
|
||||
real_ip = request.headers.get("x-real-ip", request.headers.get("x-forwarded-for", client_ip))
|
||||
# if there are multiple ips, take the first ip (the real user ip)
|
||||
real_ip = real_ip.split(",")[0].strip()
|
||||
|
||||
# creating redis key based on scope
|
||||
if self.scope == "global":
|
||||
# key for global limit (e.g., rate_limit:global:192.168.1.5)
|
||||
key = f"rate_limit:global:{real_ip}"
|
||||
else:
|
||||
# key for endpoint limit (e.g., rate_limit:endpoint:192.168.1.5:/admin/login)
|
||||
path = request.scope["path"]
|
||||
key = f"rate_limit:endpoint:{real_ip}:{path}"
|
||||
|
||||
# adding 1 to the number of requests for this ip in redis
|
||||
current_count = await redis_client.incr(key)
|
||||
|
||||
# if this is the first request in this time window, set the expiration time (TTL)
|
||||
if current_count == 1:
|
||||
await redis_client.expire(key, self.window_seconds)
|
||||
|
||||
# if the number of requests exceeds the limit, access is blocked
|
||||
if current_count > self.requests:
|
||||
# penalty: if someone spams, the time they are blocked is extended from zero again
|
||||
await redis_client.expire(key, self.window_seconds)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Too many requests. Please try again later."
|
||||
)
|
||||
|
|
@ -68,9 +68,9 @@ async def logout_user(
|
|||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
detail="نام کاربری یافت نشد"
|
||||
)
|
||||
return {"message": "User logged out successfully"}
|
||||
return {"message": "خروج کاربر با موفقیت انجام شد"}
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/reset-secret",
|
||||
|
|
@ -89,7 +89,7 @@ async def reset_secret(
|
|||
if not new_secret:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
detail="نام کاربری یافت نشد"
|
||||
)
|
||||
|
||||
return {"secret": new_secret}
|
||||
|
|
|
|||
|
|
@ -1,24 +1,50 @@
|
|||
from typing import Required
|
||||
import uuid
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, StrictBool
|
||||
|
||||
class AdminCreateUser(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
username: str
|
||||
phone_number: str | None = None
|
||||
username: str = Field(..., max_length=20, description='username of the user')
|
||||
phone_number: str | None = Field(None, max_length=11, description='phone number of user')
|
||||
|
||||
# validate phone number
|
||||
@field_validator('phone_number')
|
||||
def validate_phone_number(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
if not v.isdigit():
|
||||
raise ValueError('شماره تلفن باید عدد باشد')
|
||||
if len(v) != 11:
|
||||
raise ValueError('شماره تلفن باید 11 رقم باشد')
|
||||
if not v.startswith('09'):
|
||||
raise ValueError('شماره تلفن باید با 09 شروع شود')
|
||||
return v
|
||||
|
||||
class AdminUserResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
phone_number: str | None
|
||||
is_admin: bool
|
||||
is_active: bool
|
||||
id: uuid.UUID = Field(...)
|
||||
username: str = Field(..., max_length=20, description='username of the user')
|
||||
phone_number: str | None = Field(..., description='phone number of user')
|
||||
is_admin: bool = Field(..., description='is admin')
|
||||
is_active: bool = Field(..., description='is active')
|
||||
|
||||
@field_validator('phone_number')
|
||||
def validate_phone_number(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
if not v.isdigit():
|
||||
raise ValueError('شماره تلفن باید عدد باشد')
|
||||
if len(v) != 11:
|
||||
raise ValueError('شماره تلفن باید 11 رقم باشد')
|
||||
if not v.startswith('09'):
|
||||
raise ValueError('شماره تلفن باید با 09 شروع شود')
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class AdminCreateUserResult(BaseModel):
|
||||
user: AdminUserResponse
|
||||
secret: str
|
||||
user: AdminUserResponse = Field(...)
|
||||
secret: str = Field(...)
|
||||
|
||||
class AdminResetSecretResult(BaseModel):
|
||||
secret: str
|
||||
secret: str = Field(...)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ async def _create_user_with_role(
|
|||
existing = await get_user_by_username(db, username)
|
||||
|
||||
if existing:
|
||||
raise ValueError("Username already exists")
|
||||
raise ValueError("نام کاربری تکراری است")
|
||||
|
||||
secret = generate_user_secret()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from db.session import get_db
|
||||
from core.jwt import decode_token, create_access_token
|
||||
from core.rate_limit import RateLimiter
|
||||
from domains.users.repo import get_user_by_id
|
||||
|
||||
from domains.auth.schemas import (
|
||||
|
|
@ -12,15 +14,14 @@ from domains.auth.schemas import (
|
|||
)
|
||||
|
||||
from domains.auth.service import login_user
|
||||
import uuid
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/auth",
|
||||
tags=["auth"]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
login_limiter = RateLimiter(requests=5, window_seconds=60, scope="endpoint")
|
||||
@router.post("/login", response_model=TokenResponse, dependencies=[Depends(login_limiter)])
|
||||
async def login(
|
||||
payload: LoginRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
|
|
@ -35,7 +36,7 @@ async def login(
|
|||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or secret"
|
||||
detail="نام کاربری یا رمز عبور اشتباه است"
|
||||
)
|
||||
|
||||
return token
|
||||
|
|
@ -49,7 +50,7 @@ async def refresh(
|
|||
if not payload_data or payload_data.get("type") != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
detail="رفرش توکن نامعتبر است",
|
||||
)
|
||||
|
||||
user_id = payload_data.get("sub")
|
||||
|
|
@ -58,7 +59,7 @@ async def refresh(
|
|||
if not user_id or token_version is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token payload",
|
||||
detail="پیلود رفرش توکن نامعتبر است",
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -66,7 +67,7 @@ async def refresh(
|
|||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid user ID in token",
|
||||
detail="شناسه کاربر نامعتبر است",
|
||||
)
|
||||
|
||||
user = await get_user_by_id(db, user_uuid)
|
||||
|
|
@ -74,7 +75,7 @@ async def refresh(
|
|||
if not user or not user.is_active or user.token_version != token_version:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
detail="رفرش توکن نامعتبر است",
|
||||
)
|
||||
|
||||
access_token = create_access_token(
|
||||
|
|
|
|||
|
|
@ -1,25 +1,25 @@
|
|||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
secret: str
|
||||
username: str = Field(..., description='username of the user')
|
||||
secret: str = Field(..., description='secret of the user')
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
access_token: str = Field(..., description='access token')
|
||||
refresh_token: str = Field(..., description='refresh token')
|
||||
token_type: str = Field("bearer", description='token type')
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
||||
refresh_token: str = Field(..., description='refresh token')
|
||||
|
||||
|
||||
class AuthUser(BaseModel):
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
is_admin: bool
|
||||
id: uuid.UUID = Field(..., description='user id')
|
||||
username: str = Field(..., description='username of the user')
|
||||
is_admin: bool = Field(..., description='is admin')
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
|
@ -100,7 +100,7 @@ async def invite_member(
|
|||
user.id,
|
||||
payload.username
|
||||
)
|
||||
return {"message": "Invitation sent", "notification_id": notification.id}
|
||||
return {"message": "دعوت ارسال شد", "notification_id": notification.id}
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -125,7 +125,7 @@ async def remove_member(
|
|||
user_id,
|
||||
user
|
||||
)
|
||||
return {"message": "Member removed successfully"}
|
||||
return {"message": "عضو با موفقیت حذف شد"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
|
|||
|
|
@ -1,28 +1,28 @@
|
|||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from domains.groups.models import GroupType, GroupMemberRole
|
||||
|
||||
class GroupCreate(BaseModel):
|
||||
name: str
|
||||
name: str = Field(..., max_length=50, description='name of the group')
|
||||
|
||||
class GroupResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
type: GroupType
|
||||
is_active: bool
|
||||
id: uuid.UUID = Field(...)
|
||||
name: str = Field(..., max_length=50, description='name of the group')
|
||||
type: GroupType = Field(..., description='type of the group')
|
||||
is_active: bool = Field(..., description='is active')
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class AddMemberRequest(BaseModel):
|
||||
username: str # Req 12 says user enters username
|
||||
username: str = Field(..., max_length=20, description='username of the user')
|
||||
|
||||
class GroupMemberResponse(BaseModel):
|
||||
user_id: uuid.UUID
|
||||
username: str
|
||||
role: GroupMemberRole
|
||||
is_online: bool = False
|
||||
user_id: uuid.UUID = Field(...)
|
||||
username: str = Field(..., max_length=20, description='username of the user')
|
||||
role: GroupMemberRole = Field(..., description='role of the user in the group')
|
||||
is_online: bool = Field(False, description='is online')
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
|
@ -15,7 +15,8 @@ from domains.groups.repo import (
|
|||
get_all_groups as repo_get_all_groups
|
||||
)
|
||||
from domains.realtime.presence_service import list_online_users
|
||||
|
||||
from domains.users.repo import get_user_by_username
|
||||
from domains.notifications.service import send_join_request
|
||||
|
||||
async def create_new_group(
|
||||
db: AsyncSession,
|
||||
|
|
@ -49,29 +50,28 @@ async def invite_member_to_group(
|
|||
sender_id: uuid.UUID,
|
||||
target_username: str
|
||||
):
|
||||
from domains.users.repo import get_user_by_username
|
||||
from domains.notifications.service import send_join_request
|
||||
|
||||
|
||||
group_id_uuid = group_id if isinstance(group_id, uuid.UUID) else uuid.UUID(group_id)
|
||||
|
||||
# 1. Check if group exists
|
||||
group = await get_group_by_id(db, group_id_uuid)
|
||||
if not group:
|
||||
raise ValueError("Group not found")
|
||||
raise ValueError("گروهی یافت نشد")
|
||||
|
||||
sender = await get_user_by_id(db, sender_id)
|
||||
if not sender:
|
||||
raise ValueError("Sender not found")
|
||||
raise ValueError("فرستنده یافت نشد")
|
||||
|
||||
if not sender.is_admin:
|
||||
membership = await get_group_member(db, group_id_uuid, sender_id)
|
||||
if not membership:
|
||||
raise ValueError("Not a group member")
|
||||
raise ValueError("شما عضو این گروه نیستید")
|
||||
|
||||
# 2. Check if target user exists
|
||||
target_user = await get_user_by_username(db, target_username)
|
||||
if not target_user:
|
||||
raise ValueError("User not found")
|
||||
raise ValueError("کاربری یافت نشد")
|
||||
|
||||
# 3. Send notification (Req 12)
|
||||
return await send_join_request(
|
||||
|
|
@ -79,8 +79,8 @@ async def invite_member_to_group(
|
|||
sender_id=sender_id,
|
||||
receiver_id=target_user.id,
|
||||
group_id=group.id,
|
||||
title="Group Invitation",
|
||||
description=f"You have been invited to join group {group.name}"
|
||||
title="دعوت به گروه",
|
||||
description=f"شما به گروه {group.name} دعوت شدهاید"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -95,7 +95,7 @@ async def add_member_to_group(
|
|||
|
||||
existing = await get_group_member(db, group_id_uuid, user_id_uuid)
|
||||
if existing:
|
||||
raise ValueError("User already in group")
|
||||
raise ValueError("کاربر از قبل عضو گروه است")
|
||||
|
||||
membership = GroupMember(
|
||||
group_id=group_id_uuid,
|
||||
|
|
@ -146,6 +146,6 @@ async def remove_member_from_group(
|
|||
if not requesting_user.is_admin:
|
||||
membership = await get_group_member(db, group_id_uuid, requesting_user.id)
|
||||
if not membership or membership.role != GroupMemberRole.MANAGER:
|
||||
raise ValueError("Permission denied")
|
||||
raise ValueError("دسترسی لازم را ندارید")
|
||||
|
||||
await delete_group_member(db, group_id_uuid, target_user_id_uuid)
|
||||
|
|
|
|||
|
|
@ -1,25 +1,26 @@
|
|||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from domains.notifications.models import NotificationType
|
||||
|
||||
class NotificationBase(BaseModel):
|
||||
title: str
|
||||
description: str | None = None
|
||||
type: NotificationType
|
||||
group_id: uuid.UUID | None = None
|
||||
title: str = Field(..., description='title of the notification')
|
||||
description: str | None = Field(None, description='description of the notification')
|
||||
type: NotificationType = Field(..., description='type of the notification')
|
||||
group_id: uuid.UUID | None = Field(None, description='group id of the notification')
|
||||
|
||||
class NotificationCreate(NotificationBase):
|
||||
receiver_id: uuid.UUID
|
||||
sender_id: uuid.UUID | None = None
|
||||
receiver_id: uuid.UUID = Field(..., description='receiver id of the notification')
|
||||
sender_id: uuid.UUID | None = Field(None, description='sender id of the notification')
|
||||
|
||||
class NotificationResponse(NotificationBase):
|
||||
id: uuid.UUID
|
||||
is_accepted: bool | None
|
||||
receiver_id: uuid.UUID
|
||||
sender_id: uuid.UUID | None
|
||||
id: uuid.UUID = Field(..., description='notification id')
|
||||
is_accepted: bool | None = Field(..., description='is accepted')
|
||||
receiver_id: uuid.UUID = Field(..., description='receiver id of the notification')
|
||||
sender_id: uuid.UUID | None = Field(..., description='sender id of the notification')
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class NotificationAction(BaseModel):
|
||||
is_accepted: bool
|
||||
is_accepted: bool = Field(..., description='is accepted')
|
||||
|
|
|
|||
|
|
@ -62,10 +62,10 @@ async def respond_to_notification(
|
|||
notification_id_uuid = notification_id if isinstance(notification_id, uuid.UUID) else uuid.UUID(notification_id)
|
||||
notification = await get_notification_by_id(db, notification_id_uuid)
|
||||
if not notification:
|
||||
raise ValueError("Notification not found")
|
||||
raise ValueError("نوتیفیکیشنی یافت نشد")
|
||||
|
||||
if str(notification.receiver_id) != str(user_id):
|
||||
raise ValueError("Permission denied")
|
||||
raise ValueError("دسترسی لازم را ندارید")
|
||||
|
||||
notification.is_accepted = is_accepted
|
||||
await update_notification(db, notification)
|
||||
|
|
|
|||
|
|
@ -13,17 +13,10 @@ from domains.groups.repo import get_group_by_id
|
|||
async def request_speak(
|
||||
group_id: str | uuid.UUID,
|
||||
user_id: str | uuid.UUID,
|
||||
group_type: str
|
||||
) -> bool:
|
||||
group_id_str = str(group_id)
|
||||
user_id_str = str(user_id)
|
||||
|
||||
|
||||
if group_type == "private":
|
||||
await grant_publish_permission(group_id_str, user_id_str, True)
|
||||
return True
|
||||
|
||||
# group chat → push-to-talk
|
||||
granted = await acquire_speaker(group_id_str, user_id_str)
|
||||
if not granted:
|
||||
return False
|
||||
|
|
@ -35,46 +28,28 @@ async def request_speak(
|
|||
async def stop_speaking(
|
||||
group_id: str | uuid.UUID,
|
||||
user_id: str | uuid.UUID,
|
||||
group_type: str
|
||||
):
|
||||
group_id_str = str(group_id)
|
||||
user_id_str = str(user_id)
|
||||
|
||||
if group_type == "private":
|
||||
await grant_publish_permission(group_id_str, user_id_str, False)
|
||||
return True
|
||||
|
||||
released = await release_speaker(group_id_str, user_id_str)
|
||||
if released:
|
||||
await grant_publish_permission(group_id_str, user_id_str, False)
|
||||
return released
|
||||
|
||||
async def current_speaker(
|
||||
db: AsyncSession,
|
||||
group_id: str | uuid.UUID
|
||||
):
|
||||
group_id_uuid = group_id if isinstance(group_id, uuid.UUID) else uuid.UUID(group_id)
|
||||
group_id_str = str(group_id_uuid)
|
||||
|
||||
group = await get_group_by_id(db, group_id_uuid)
|
||||
|
||||
if not group:
|
||||
return None
|
||||
|
||||
if str(group.type) == "private":
|
||||
return None
|
||||
|
||||
async def current_speaker(group_id: str | uuid.UUID):
|
||||
group_id_str = str(group_id)
|
||||
return await get_active_speaker(group_id_str)
|
||||
|
||||
async def grant_publish_permission(room_name: str, identity: str, can_publish: bool):
|
||||
lk_api = get_livekit_api() # همان متدی که در client.py نوشتی
|
||||
lk_api = get_livekit_api()
|
||||
await lk_api.room.update_participant(
|
||||
api.UpdateParticipantRequest(
|
||||
room=room_name,
|
||||
identity=identity,
|
||||
permission=api.ParticipantPermission(
|
||||
can_publish=can_publish,
|
||||
can_subscribe=True # همیشه بتواند بشنود
|
||||
can_subscribe=True
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
import uuid
|
||||
import time
|
||||
import json
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Request, Header, status, HTTPException
|
||||
from livekit import api
|
||||
from core.config import settings
|
||||
|
|
@ -79,10 +81,6 @@ async def group_ws(websocket: WebSocket, group_id: str):
|
|||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
from domains.groups.repo import get_group_by_id
|
||||
group = await get_group_by_id(db, group_id_uuid)
|
||||
group_type = str(group.type) if group else "public"
|
||||
|
||||
user_id = str(user.id)
|
||||
# connect websocket
|
||||
await manager.connect(group_id, websocket)
|
||||
|
|
@ -111,13 +109,28 @@ async def group_ws(websocket: WebSocket, group_id: str):
|
|||
}
|
||||
)
|
||||
|
||||
last_action_time = 0.0
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
text_data = await websocket.receive_text()
|
||||
|
||||
if len(text_data) > 20000:
|
||||
await websocket.close(code=status.WS_1009_MESSAGE_TOO_BIG)
|
||||
return
|
||||
|
||||
data = json.loads(text_data)
|
||||
event = data.get("type")
|
||||
|
||||
# anti spam
|
||||
current_time = time.time()
|
||||
if current_time - last_action_time < 0.5:
|
||||
continue
|
||||
last_action_time = current_time
|
||||
|
||||
# user wants to speak
|
||||
if event == "request_speak":
|
||||
success = await request_speak(group_id, user_id, group_type)
|
||||
success = await request_speak(group_id, user_id)
|
||||
if success:
|
||||
# Broadcast globally that someone is speaking
|
||||
await manager.broadcast(
|
||||
|
|
@ -133,8 +146,7 @@ async def group_ws(websocket: WebSocket, group_id: str):
|
|||
})
|
||||
else:
|
||||
# someone else is speaking
|
||||
async with AsyncSessionLocal() as temp_db:
|
||||
speaker = await current_speaker(temp_db, group_id)
|
||||
speaker = await current_speaker(group_id)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "speaker_busy",
|
||||
|
|
@ -144,7 +156,7 @@ async def group_ws(websocket: WebSocket, group_id: str):
|
|||
|
||||
# user stops speaking
|
||||
elif event == "stop_speak":
|
||||
released = await stop_speaking(group_id, user_id, group_type)
|
||||
released = await stop_speaking(group_id, user_id)
|
||||
if released:
|
||||
await manager.broadcast(
|
||||
group_id,
|
||||
|
|
@ -160,7 +172,7 @@ async def group_ws(websocket: WebSocket, group_id: str):
|
|||
except WebSocketDisconnect:
|
||||
manager.disconnect(group_id, websocket)
|
||||
await user_leave_group(group_id, user_id)
|
||||
await stop_speaking(group_id, user_id, group_type)
|
||||
await stop_speaking(group_id, user_id)
|
||||
await manager.broadcast(
|
||||
group_id,
|
||||
{
|
||||
|
|
|
|||
|
|
@ -17,8 +17,6 @@ class ConnectionManager:
|
|||
async def _setup_pubsub(self):
|
||||
if self._pubsub is None:
|
||||
self._pubsub = redis_client.pubsub()
|
||||
# Subscribe to all websocket broadcasting channels
|
||||
# We use Any for _pubsub to satisfy type checkers that don't know redis-py return types
|
||||
await self._pubsub.psubscribe("ws:group:*")
|
||||
self._listen_task = asyncio.create_task(self._redis_listener())
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from enum import Enum
|
||||
from sqlalchemy import String, Boolean, Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ from pydantic import BaseModel, Field, field_validator
|
|||
import re
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
username: str = Field(..., max_length=20, description='username of the user')
|
||||
phone_number: str | None = Field(None, description="11 digit phone number")
|
||||
is_admin: bool = False
|
||||
is_admin: bool = Field(False)
|
||||
|
||||
@field_validator("phone_number")
|
||||
@classmethod
|
||||
|
|
@ -13,19 +13,19 @@ class UserCreate(BaseModel):
|
|||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^09\d{9}$", v):
|
||||
raise ValueError("Phone number must start with 09 and be exactly 11 digits")
|
||||
raise ValueError("شماره تلفن باید ۱۱ رقم باشد و با ۰۹ شروع شود")
|
||||
return v
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
phone_number: str | None
|
||||
is_admin: bool
|
||||
is_active: bool
|
||||
id: uuid.UUID = Field(...)
|
||||
username: str = Field(..., description='username of the user')
|
||||
phone_number: str | None = Field(..., description='phone number of the user')
|
||||
is_admin: bool = Field(..., description='is admin')
|
||||
is_active: bool = Field(..., description='is active')
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class UserCreateResult(BaseModel):
|
||||
user: UserResponse
|
||||
secret: str
|
||||
user: UserResponse = Field(..., description='user created')
|
||||
secret: str = Field(..., description='secret of the user')
|
||||
|
|
|
|||
18
Back/main.py
18
Back/main.py
|
|
@ -1,5 +1,5 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi_swagger import patch_fastapi
|
||||
|
||||
|
|
@ -11,6 +11,7 @@ from domains.realtime.ws import router as realtime_router
|
|||
from domains.notifications.api import router as notifications_router
|
||||
from integrations.livekit.client import close_livekit_api
|
||||
from db.redis import redis_client
|
||||
from core.rate_limit import RateLimiter
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
@ -33,6 +34,7 @@ async def lifespan(app: FastAPI):
|
|||
await close_livekit_api()
|
||||
await redis_client.close()
|
||||
|
||||
global_limiter = RateLimiter(requests=30, window_seconds=60, scope="global")
|
||||
|
||||
app = FastAPI(
|
||||
title="NEDA API",
|
||||
|
|
@ -40,7 +42,8 @@ app = FastAPI(
|
|||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
docs_url=None,
|
||||
swagger_ui_oauth2_redirect_url=None
|
||||
swagger_ui_oauth2_redirect_url=None,
|
||||
dependencies=[Depends(global_limiter)]
|
||||
)
|
||||
patch_fastapi(app,docs_url="/swagger")
|
||||
|
||||
|
|
@ -55,7 +58,16 @@ app.add_middleware(
|
|||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# app.add_middleware(
|
||||
# CORSMiddleware,
|
||||
# allow_origins=[
|
||||
# "https://app.neda.com",
|
||||
# "http://localhost:3000" # فقط برای تست برنامهنویس فرانتاند
|
||||
# ],
|
||||
# allow_credentials=True,
|
||||
# allow_methods=["GET", "POST", "PUT", "DELETE"], # محدود کردن متدها
|
||||
# allow_headers=["Authorization", "Content-Type"], # محدود کردن هدرها
|
||||
# )
|
||||
|
||||
# -------------------------
|
||||
# Routers
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user