refact: improve security and improve some code in domains
This commit is contained in:
parent
7f37d7fb60
commit
8ff4e90d13
|
|
@ -26,7 +26,7 @@ async def get_current_user(
|
||||||
if payload is None:
|
if payload is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid authentication token",
|
detail="توکن نامعتبر است",
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id = payload.get("sub")
|
user_id = payload.get("sub")
|
||||||
|
|
@ -34,7 +34,7 @@ async def get_current_user(
|
||||||
if not user_id:
|
if not user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid token payload",
|
detail="توکن نامعتبر است",
|
||||||
)
|
)
|
||||||
|
|
||||||
user = await get_user_by_id(db, user_id)
|
user = await get_user_by_id(db, user_id)
|
||||||
|
|
@ -42,13 +42,13 @@ async def get_current_user(
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="User not found",
|
detail="کاربری یافت نشد",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="User is inactive",
|
detail="کاربر غیرفعال است",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check token version for remote logout
|
# Check token version for remote logout
|
||||||
|
|
@ -56,7 +56,7 @@ async def get_current_user(
|
||||||
if token_version != user.token_version:
|
if token_version != user.token_version:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Token has been invalidated",
|
detail="توکن نامعتبر است",
|
||||||
)
|
)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
@ -72,7 +72,7 @@ async def get_current_admin(
|
||||||
if not user.is_admin:
|
if not user.is_admin:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Admin privileges required",
|
detail="شما دسترسی لازم را ندارید",
|
||||||
)
|
)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
50
Back/core/rate_limit.py
Normal file
50
Back/core/rate_limit.py
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
# core/rate_limit.py
|
||||||
|
|
||||||
|
from fastapi import Request, HTTPException, status
|
||||||
|
from db.redis import redis_client
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
def __init__(self, requests: int, window_seconds: int, scope: str = "global"):
|
||||||
|
"""
|
||||||
|
:param requests: number of requests allowed
|
||||||
|
:param window_seconds: window time in seconds
|
||||||
|
:param scope: type of limit ("global" for all site, "endpoint" for a specific path)
|
||||||
|
"""
|
||||||
|
self.requests = requests
|
||||||
|
self.window_seconds = window_seconds
|
||||||
|
self.scope = scope
|
||||||
|
|
||||||
|
async def __call__(self, request: Request):
|
||||||
|
# getting client ip
|
||||||
|
client_ip = request.client.host if request.client else "127.0.0.1"
|
||||||
|
|
||||||
|
# when project is in docker and behind nginx, the real ip is in the headers
|
||||||
|
real_ip = request.headers.get("x-real-ip", request.headers.get("x-forwarded-for", client_ip))
|
||||||
|
# if there are multiple ips, take the first ip (the real user ip)
|
||||||
|
real_ip = real_ip.split(",")[0].strip()
|
||||||
|
|
||||||
|
# creating redis key based on scope
|
||||||
|
if self.scope == "global":
|
||||||
|
# key for global limit (e.g., rate_limit:global:192.168.1.5)
|
||||||
|
key = f"rate_limit:global:{real_ip}"
|
||||||
|
else:
|
||||||
|
# key for endpoint limit (e.g., rate_limit:endpoint:192.168.1.5:/admin/login)
|
||||||
|
path = request.scope["path"]
|
||||||
|
key = f"rate_limit:endpoint:{real_ip}:{path}"
|
||||||
|
|
||||||
|
# adding 1 to the number of requests for this ip in redis
|
||||||
|
current_count = await redis_client.incr(key)
|
||||||
|
|
||||||
|
# if this is the first request in this time window, set the expiration time (TTL)
|
||||||
|
if current_count == 1:
|
||||||
|
await redis_client.expire(key, self.window_seconds)
|
||||||
|
|
||||||
|
# if the number of requests exceeds the limit, access is blocked
|
||||||
|
if current_count > self.requests:
|
||||||
|
# penalty: if someone spams, the time they are blocked is extended from zero again
|
||||||
|
await redis_client.expire(key, self.window_seconds)
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
detail="Too many requests. Please try again later."
|
||||||
|
)
|
||||||
|
|
@ -68,9 +68,9 @@ async def logout_user(
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="User not found"
|
detail="نام کاربری یافت نشد"
|
||||||
)
|
)
|
||||||
return {"message": "User logged out successfully"}
|
return {"message": "خروج کاربر با موفقیت انجام شد"}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/users/{user_id}/reset-secret",
|
@router.post("/users/{user_id}/reset-secret",
|
||||||
|
|
@ -89,7 +89,7 @@ async def reset_secret(
|
||||||
if not new_secret:
|
if not new_secret:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="User not found"
|
detail="نام کاربری یافت نشد"
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"secret": new_secret}
|
return {"secret": new_secret}
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,50 @@
|
||||||
|
from typing import Required
|
||||||
import uuid
|
import uuid
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, StrictBool
|
||||||
|
|
||||||
class AdminCreateUser(BaseModel):
|
class AdminCreateUser(BaseModel):
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
username: str
|
username: str = Field(..., max_length=20, description='username of the user')
|
||||||
phone_number: str | None = None
|
phone_number: str | None = Field(None, max_length=11, description='phone number of user')
|
||||||
|
|
||||||
|
# validate phone number
|
||||||
|
@field_validator('phone_number')
|
||||||
|
def validate_phone_number(cls, v: str | None) -> str | None:
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
if not v.isdigit():
|
||||||
|
raise ValueError('شماره تلفن باید عدد باشد')
|
||||||
|
if len(v) != 11:
|
||||||
|
raise ValueError('شماره تلفن باید 11 رقم باشد')
|
||||||
|
if not v.startswith('09'):
|
||||||
|
raise ValueError('شماره تلفن باید با 09 شروع شود')
|
||||||
|
return v
|
||||||
|
|
||||||
class AdminUserResponse(BaseModel):
|
class AdminUserResponse(BaseModel):
|
||||||
id: uuid.UUID
|
id: uuid.UUID = Field(...)
|
||||||
username: str
|
username: str = Field(..., max_length=20, description='username of the user')
|
||||||
phone_number: str | None
|
phone_number: str | None = Field(..., description='phone number of user')
|
||||||
is_admin: bool
|
is_admin: bool = Field(..., description='is admin')
|
||||||
is_active: bool
|
is_active: bool = Field(..., description='is active')
|
||||||
|
|
||||||
|
@field_validator('phone_number')
|
||||||
|
def validate_phone_number(cls, v: str | None) -> str | None:
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
if not v.isdigit():
|
||||||
|
raise ValueError('شماره تلفن باید عدد باشد')
|
||||||
|
if len(v) != 11:
|
||||||
|
raise ValueError('شماره تلفن باید 11 رقم باشد')
|
||||||
|
if not v.startswith('09'):
|
||||||
|
raise ValueError('شماره تلفن باید با 09 شروع شود')
|
||||||
|
return v
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
class AdminCreateUserResult(BaseModel):
|
class AdminCreateUserResult(BaseModel):
|
||||||
user: AdminUserResponse
|
user: AdminUserResponse = Field(...)
|
||||||
secret: str
|
secret: str = Field(...)
|
||||||
|
|
||||||
class AdminResetSecretResult(BaseModel):
|
class AdminResetSecretResult(BaseModel):
|
||||||
secret: str
|
secret: str = Field(...)
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ async def _create_user_with_role(
|
||||||
existing = await get_user_by_username(db, username)
|
existing = await get_user_by_username(db, username)
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
raise ValueError("Username already exists")
|
raise ValueError("نام کاربری تکراری است")
|
||||||
|
|
||||||
secret = generate_user_secret()
|
secret = generate_user_secret()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
|
import uuid
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from db.session import get_db
|
from db.session import get_db
|
||||||
from core.jwt import decode_token, create_access_token
|
from core.jwt import decode_token, create_access_token
|
||||||
|
from core.rate_limit import RateLimiter
|
||||||
from domains.users.repo import get_user_by_id
|
from domains.users.repo import get_user_by_id
|
||||||
|
|
||||||
from domains.auth.schemas import (
|
from domains.auth.schemas import (
|
||||||
|
|
@ -12,15 +14,14 @@ from domains.auth.schemas import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from domains.auth.service import login_user
|
from domains.auth.service import login_user
|
||||||
import uuid
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/auth",
|
prefix="/auth",
|
||||||
tags=["auth"]
|
tags=["auth"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
login_limiter = RateLimiter(requests=5, window_seconds=60, scope="endpoint")
|
||||||
@router.post("/login", response_model=TokenResponse)
|
@router.post("/login", response_model=TokenResponse, dependencies=[Depends(login_limiter)])
|
||||||
async def login(
|
async def login(
|
||||||
payload: LoginRequest,
|
payload: LoginRequest,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
|
|
@ -35,7 +36,7 @@ async def login(
|
||||||
if not token:
|
if not token:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid username or secret"
|
detail="نام کاربری یا رمز عبور اشتباه است"
|
||||||
)
|
)
|
||||||
|
|
||||||
return token
|
return token
|
||||||
|
|
@ -49,7 +50,7 @@ async def refresh(
|
||||||
if not payload_data or payload_data.get("type") != "refresh":
|
if not payload_data or payload_data.get("type") != "refresh":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid refresh token",
|
detail="رفرش توکن نامعتبر است",
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id = payload_data.get("sub")
|
user_id = payload_data.get("sub")
|
||||||
|
|
@ -58,7 +59,7 @@ async def refresh(
|
||||||
if not user_id or token_version is None:
|
if not user_id or token_version is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid refresh token payload",
|
detail="پیلود رفرش توکن نامعتبر است",
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -66,7 +67,7 @@ async def refresh(
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid user ID in token",
|
detail="شناسه کاربر نامعتبر است",
|
||||||
)
|
)
|
||||||
|
|
||||||
user = await get_user_by_id(db, user_uuid)
|
user = await get_user_by_id(db, user_uuid)
|
||||||
|
|
@ -74,7 +75,7 @@ async def refresh(
|
||||||
if not user or not user.is_active or user.token_version != token_version:
|
if not user or not user.is_active or user.token_version != token_version:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid refresh token",
|
detail="رفرش توکن نامعتبر است",
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = create_access_token(
|
access_token = create_access_token(
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,25 @@
|
||||||
import uuid
|
import uuid
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class LoginRequest(BaseModel):
|
class LoginRequest(BaseModel):
|
||||||
username: str
|
username: str = Field(..., description='username of the user')
|
||||||
secret: str
|
secret: str = Field(..., description='secret of the user')
|
||||||
|
|
||||||
|
|
||||||
class TokenResponse(BaseModel):
|
class TokenResponse(BaseModel):
|
||||||
access_token: str
|
access_token: str = Field(..., description='access token')
|
||||||
refresh_token: str
|
refresh_token: str = Field(..., description='refresh token')
|
||||||
token_type: str = "bearer"
|
token_type: str = Field("bearer", description='token type')
|
||||||
|
|
||||||
class RefreshTokenRequest(BaseModel):
|
class RefreshTokenRequest(BaseModel):
|
||||||
refresh_token: str
|
refresh_token: str = Field(..., description='refresh token')
|
||||||
|
|
||||||
|
|
||||||
class AuthUser(BaseModel):
|
class AuthUser(BaseModel):
|
||||||
id: uuid.UUID
|
id: uuid.UUID = Field(..., description='user id')
|
||||||
username: str
|
username: str = Field(..., description='username of the user')
|
||||||
is_admin: bool
|
is_admin: bool = Field(..., description='is admin')
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
@ -100,7 +100,7 @@ async def invite_member(
|
||||||
user.id,
|
user.id,
|
||||||
payload.username
|
payload.username
|
||||||
)
|
)
|
||||||
return {"message": "Invitation sent", "notification_id": notification.id}
|
return {"message": "دعوت ارسال شد", "notification_id": notification.id}
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|
@ -125,7 +125,7 @@ async def remove_member(
|
||||||
user_id,
|
user_id,
|
||||||
user
|
user
|
||||||
)
|
)
|
||||||
return {"message": "Member removed successfully"}
|
return {"message": "عضو با موفقیت حذف شد"}
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
|
|
||||||
|
|
@ -1,28 +1,28 @@
|
||||||
import uuid
|
import uuid
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from domains.groups.models import GroupType, GroupMemberRole
|
from domains.groups.models import GroupType, GroupMemberRole
|
||||||
|
|
||||||
class GroupCreate(BaseModel):
|
class GroupCreate(BaseModel):
|
||||||
name: str
|
name: str = Field(..., max_length=50, description='name of the group')
|
||||||
|
|
||||||
class GroupResponse(BaseModel):
|
class GroupResponse(BaseModel):
|
||||||
id: uuid.UUID
|
id: uuid.UUID = Field(...)
|
||||||
name: str
|
name: str = Field(..., max_length=50, description='name of the group')
|
||||||
type: GroupType
|
type: GroupType = Field(..., description='type of the group')
|
||||||
is_active: bool
|
is_active: bool = Field(..., description='is active')
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
class AddMemberRequest(BaseModel):
|
class AddMemberRequest(BaseModel):
|
||||||
username: str # Req 12 says user enters username
|
username: str = Field(..., max_length=20, description='username of the user')
|
||||||
|
|
||||||
class GroupMemberResponse(BaseModel):
|
class GroupMemberResponse(BaseModel):
|
||||||
user_id: uuid.UUID
|
user_id: uuid.UUID = Field(...)
|
||||||
username: str
|
username: str = Field(..., max_length=20, description='username of the user')
|
||||||
role: GroupMemberRole
|
role: GroupMemberRole = Field(..., description='role of the user in the group')
|
||||||
is_online: bool = False
|
is_online: bool = Field(False, description='is online')
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
@ -15,7 +15,8 @@ from domains.groups.repo import (
|
||||||
get_all_groups as repo_get_all_groups
|
get_all_groups as repo_get_all_groups
|
||||||
)
|
)
|
||||||
from domains.realtime.presence_service import list_online_users
|
from domains.realtime.presence_service import list_online_users
|
||||||
|
from domains.users.repo import get_user_by_username
|
||||||
|
from domains.notifications.service import send_join_request
|
||||||
|
|
||||||
async def create_new_group(
|
async def create_new_group(
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
|
|
@ -49,29 +50,28 @@ async def invite_member_to_group(
|
||||||
sender_id: uuid.UUID,
|
sender_id: uuid.UUID,
|
||||||
target_username: str
|
target_username: str
|
||||||
):
|
):
|
||||||
from domains.users.repo import get_user_by_username
|
|
||||||
from domains.notifications.service import send_join_request
|
|
||||||
|
|
||||||
group_id_uuid = group_id if isinstance(group_id, uuid.UUID) else uuid.UUID(group_id)
|
group_id_uuid = group_id if isinstance(group_id, uuid.UUID) else uuid.UUID(group_id)
|
||||||
|
|
||||||
# 1. Check if group exists
|
# 1. Check if group exists
|
||||||
group = await get_group_by_id(db, group_id_uuid)
|
group = await get_group_by_id(db, group_id_uuid)
|
||||||
if not group:
|
if not group:
|
||||||
raise ValueError("Group not found")
|
raise ValueError("گروهی یافت نشد")
|
||||||
|
|
||||||
sender = await get_user_by_id(db, sender_id)
|
sender = await get_user_by_id(db, sender_id)
|
||||||
if not sender:
|
if not sender:
|
||||||
raise ValueError("Sender not found")
|
raise ValueError("فرستنده یافت نشد")
|
||||||
|
|
||||||
if not sender.is_admin:
|
if not sender.is_admin:
|
||||||
membership = await get_group_member(db, group_id_uuid, sender_id)
|
membership = await get_group_member(db, group_id_uuid, sender_id)
|
||||||
if not membership:
|
if not membership:
|
||||||
raise ValueError("Not a group member")
|
raise ValueError("شما عضو این گروه نیستید")
|
||||||
|
|
||||||
# 2. Check if target user exists
|
# 2. Check if target user exists
|
||||||
target_user = await get_user_by_username(db, target_username)
|
target_user = await get_user_by_username(db, target_username)
|
||||||
if not target_user:
|
if not target_user:
|
||||||
raise ValueError("User not found")
|
raise ValueError("کاربری یافت نشد")
|
||||||
|
|
||||||
# 3. Send notification (Req 12)
|
# 3. Send notification (Req 12)
|
||||||
return await send_join_request(
|
return await send_join_request(
|
||||||
|
|
@ -79,8 +79,8 @@ async def invite_member_to_group(
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
receiver_id=target_user.id,
|
receiver_id=target_user.id,
|
||||||
group_id=group.id,
|
group_id=group.id,
|
||||||
title="Group Invitation",
|
title="دعوت به گروه",
|
||||||
description=f"You have been invited to join group {group.name}"
|
description=f"شما به گروه {group.name} دعوت شدهاید"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -95,7 +95,7 @@ async def add_member_to_group(
|
||||||
|
|
||||||
existing = await get_group_member(db, group_id_uuid, user_id_uuid)
|
existing = await get_group_member(db, group_id_uuid, user_id_uuid)
|
||||||
if existing:
|
if existing:
|
||||||
raise ValueError("User already in group")
|
raise ValueError("کاربر از قبل عضو گروه است")
|
||||||
|
|
||||||
membership = GroupMember(
|
membership = GroupMember(
|
||||||
group_id=group_id_uuid,
|
group_id=group_id_uuid,
|
||||||
|
|
@ -146,6 +146,6 @@ async def remove_member_from_group(
|
||||||
if not requesting_user.is_admin:
|
if not requesting_user.is_admin:
|
||||||
membership = await get_group_member(db, group_id_uuid, requesting_user.id)
|
membership = await get_group_member(db, group_id_uuid, requesting_user.id)
|
||||||
if not membership or membership.role != GroupMemberRole.MANAGER:
|
if not membership or membership.role != GroupMemberRole.MANAGER:
|
||||||
raise ValueError("Permission denied")
|
raise ValueError("دسترسی لازم را ندارید")
|
||||||
|
|
||||||
await delete_group_member(db, group_id_uuid, target_user_id_uuid)
|
await delete_group_member(db, group_id_uuid, target_user_id_uuid)
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,26 @@
|
||||||
import uuid
|
import uuid
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from domains.notifications.models import NotificationType
|
from domains.notifications.models import NotificationType
|
||||||
|
|
||||||
class NotificationBase(BaseModel):
|
class NotificationBase(BaseModel):
|
||||||
title: str
|
title: str = Field(..., description='title of the notification')
|
||||||
description: str | None = None
|
description: str | None = Field(None, description='description of the notification')
|
||||||
type: NotificationType
|
type: NotificationType = Field(..., description='type of the notification')
|
||||||
group_id: uuid.UUID | None = None
|
group_id: uuid.UUID | None = Field(None, description='group id of the notification')
|
||||||
|
|
||||||
class NotificationCreate(NotificationBase):
|
class NotificationCreate(NotificationBase):
|
||||||
receiver_id: uuid.UUID
|
receiver_id: uuid.UUID = Field(..., description='receiver id of the notification')
|
||||||
sender_id: uuid.UUID | None = None
|
sender_id: uuid.UUID | None = Field(None, description='sender id of the notification')
|
||||||
|
|
||||||
class NotificationResponse(NotificationBase):
|
class NotificationResponse(NotificationBase):
|
||||||
id: uuid.UUID
|
id: uuid.UUID = Field(..., description='notification id')
|
||||||
is_accepted: bool | None
|
is_accepted: bool | None = Field(..., description='is accepted')
|
||||||
receiver_id: uuid.UUID
|
receiver_id: uuid.UUID = Field(..., description='receiver id of the notification')
|
||||||
sender_id: uuid.UUID | None
|
sender_id: uuid.UUID | None = Field(..., description='sender id of the notification')
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
class NotificationAction(BaseModel):
|
class NotificationAction(BaseModel):
|
||||||
is_accepted: bool
|
is_accepted: bool = Field(..., description='is accepted')
|
||||||
|
|
|
||||||
|
|
@ -62,10 +62,10 @@ async def respond_to_notification(
|
||||||
notification_id_uuid = notification_id if isinstance(notification_id, uuid.UUID) else uuid.UUID(notification_id)
|
notification_id_uuid = notification_id if isinstance(notification_id, uuid.UUID) else uuid.UUID(notification_id)
|
||||||
notification = await get_notification_by_id(db, notification_id_uuid)
|
notification = await get_notification_by_id(db, notification_id_uuid)
|
||||||
if not notification:
|
if not notification:
|
||||||
raise ValueError("Notification not found")
|
raise ValueError("نوتیفیکیشنی یافت نشد")
|
||||||
|
|
||||||
if str(notification.receiver_id) != str(user_id):
|
if str(notification.receiver_id) != str(user_id):
|
||||||
raise ValueError("Permission denied")
|
raise ValueError("دسترسی لازم را ندارید")
|
||||||
|
|
||||||
notification.is_accepted = is_accepted
|
notification.is_accepted = is_accepted
|
||||||
await update_notification(db, notification)
|
await update_notification(db, notification)
|
||||||
|
|
|
||||||
|
|
@ -13,17 +13,10 @@ from domains.groups.repo import get_group_by_id
|
||||||
async def request_speak(
|
async def request_speak(
|
||||||
group_id: str | uuid.UUID,
|
group_id: str | uuid.UUID,
|
||||||
user_id: str | uuid.UUID,
|
user_id: str | uuid.UUID,
|
||||||
group_type: str
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
group_id_str = str(group_id)
|
group_id_str = str(group_id)
|
||||||
user_id_str = str(user_id)
|
user_id_str = str(user_id)
|
||||||
|
|
||||||
|
|
||||||
if group_type == "private":
|
|
||||||
await grant_publish_permission(group_id_str, user_id_str, True)
|
|
||||||
return True
|
|
||||||
|
|
||||||
# group chat → push-to-talk
|
|
||||||
granted = await acquire_speaker(group_id_str, user_id_str)
|
granted = await acquire_speaker(group_id_str, user_id_str)
|
||||||
if not granted:
|
if not granted:
|
||||||
return False
|
return False
|
||||||
|
|
@ -35,46 +28,28 @@ async def request_speak(
|
||||||
async def stop_speaking(
|
async def stop_speaking(
|
||||||
group_id: str | uuid.UUID,
|
group_id: str | uuid.UUID,
|
||||||
user_id: str | uuid.UUID,
|
user_id: str | uuid.UUID,
|
||||||
group_type: str
|
|
||||||
):
|
):
|
||||||
group_id_str = str(group_id)
|
group_id_str = str(group_id)
|
||||||
user_id_str = str(user_id)
|
user_id_str = str(user_id)
|
||||||
|
|
||||||
if group_type == "private":
|
|
||||||
await grant_publish_permission(group_id_str, user_id_str, False)
|
|
||||||
return True
|
|
||||||
|
|
||||||
released = await release_speaker(group_id_str, user_id_str)
|
released = await release_speaker(group_id_str, user_id_str)
|
||||||
if released:
|
if released:
|
||||||
await grant_publish_permission(group_id_str, user_id_str, False)
|
await grant_publish_permission(group_id_str, user_id_str, False)
|
||||||
return released
|
return released
|
||||||
|
|
||||||
async def current_speaker(
|
async def current_speaker(group_id: str | uuid.UUID):
|
||||||
db: AsyncSession,
|
group_id_str = str(group_id)
|
||||||
group_id: str | uuid.UUID
|
|
||||||
):
|
|
||||||
group_id_uuid = group_id if isinstance(group_id, uuid.UUID) else uuid.UUID(group_id)
|
|
||||||
group_id_str = str(group_id_uuid)
|
|
||||||
|
|
||||||
group = await get_group_by_id(db, group_id_uuid)
|
|
||||||
|
|
||||||
if not group:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if str(group.type) == "private":
|
|
||||||
return None
|
|
||||||
|
|
||||||
return await get_active_speaker(group_id_str)
|
return await get_active_speaker(group_id_str)
|
||||||
|
|
||||||
async def grant_publish_permission(room_name: str, identity: str, can_publish: bool):
|
async def grant_publish_permission(room_name: str, identity: str, can_publish: bool):
|
||||||
lk_api = get_livekit_api() # همان متدی که در client.py نوشتی
|
lk_api = get_livekit_api()
|
||||||
await lk_api.room.update_participant(
|
await lk_api.room.update_participant(
|
||||||
api.UpdateParticipantRequest(
|
api.UpdateParticipantRequest(
|
||||||
room=room_name,
|
room=room_name,
|
||||||
identity=identity,
|
identity=identity,
|
||||||
permission=api.ParticipantPermission(
|
permission=api.ParticipantPermission(
|
||||||
can_publish=can_publish,
|
can_publish=can_publish,
|
||||||
can_subscribe=True # همیشه بتواند بشنود
|
can_subscribe=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
import uuid
|
import uuid
|
||||||
|
import time
|
||||||
|
import json
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Request, Header, status, HTTPException
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Request, Header, status, HTTPException
|
||||||
from livekit import api
|
from livekit import api
|
||||||
from core.config import settings
|
from core.config import settings
|
||||||
|
|
@ -79,10 +81,6 @@ async def group_ws(websocket: WebSocket, group_id: str):
|
||||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||||
return
|
return
|
||||||
|
|
||||||
from domains.groups.repo import get_group_by_id
|
|
||||||
group = await get_group_by_id(db, group_id_uuid)
|
|
||||||
group_type = str(group.type) if group else "public"
|
|
||||||
|
|
||||||
user_id = str(user.id)
|
user_id = str(user.id)
|
||||||
# connect websocket
|
# connect websocket
|
||||||
await manager.connect(group_id, websocket)
|
await manager.connect(group_id, websocket)
|
||||||
|
|
@ -111,13 +109,28 @@ async def group_ws(websocket: WebSocket, group_id: str):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
last_action_time = 0.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
data = await websocket.receive_json()
|
text_data = await websocket.receive_text()
|
||||||
|
|
||||||
|
if len(text_data) > 20000:
|
||||||
|
await websocket.close(code=status.WS_1009_MESSAGE_TOO_BIG)
|
||||||
|
return
|
||||||
|
|
||||||
|
data = json.loads(text_data)
|
||||||
event = data.get("type")
|
event = data.get("type")
|
||||||
|
|
||||||
|
# anti spam
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time - last_action_time < 0.5:
|
||||||
|
continue
|
||||||
|
last_action_time = current_time
|
||||||
|
|
||||||
# user wants to speak
|
# user wants to speak
|
||||||
if event == "request_speak":
|
if event == "request_speak":
|
||||||
success = await request_speak(group_id, user_id, group_type)
|
success = await request_speak(group_id, user_id)
|
||||||
if success:
|
if success:
|
||||||
# Broadcast globally that someone is speaking
|
# Broadcast globally that someone is speaking
|
||||||
await manager.broadcast(
|
await manager.broadcast(
|
||||||
|
|
@ -133,8 +146,7 @@ async def group_ws(websocket: WebSocket, group_id: str):
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
# someone else is speaking
|
# someone else is speaking
|
||||||
async with AsyncSessionLocal() as temp_db:
|
speaker = await current_speaker(group_id)
|
||||||
speaker = await current_speaker(temp_db, group_id)
|
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
"type": "speaker_busy",
|
"type": "speaker_busy",
|
||||||
|
|
@ -144,7 +156,7 @@ async def group_ws(websocket: WebSocket, group_id: str):
|
||||||
|
|
||||||
# user stops speaking
|
# user stops speaking
|
||||||
elif event == "stop_speak":
|
elif event == "stop_speak":
|
||||||
released = await stop_speaking(group_id, user_id, group_type)
|
released = await stop_speaking(group_id, user_id)
|
||||||
if released:
|
if released:
|
||||||
await manager.broadcast(
|
await manager.broadcast(
|
||||||
group_id,
|
group_id,
|
||||||
|
|
@ -160,7 +172,7 @@ async def group_ws(websocket: WebSocket, group_id: str):
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
manager.disconnect(group_id, websocket)
|
manager.disconnect(group_id, websocket)
|
||||||
await user_leave_group(group_id, user_id)
|
await user_leave_group(group_id, user_id)
|
||||||
await stop_speaking(group_id, user_id, group_type)
|
await stop_speaking(group_id, user_id)
|
||||||
await manager.broadcast(
|
await manager.broadcast(
|
||||||
group_id,
|
group_id,
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -17,8 +17,6 @@ class ConnectionManager:
|
||||||
async def _setup_pubsub(self):
|
async def _setup_pubsub(self):
|
||||||
if self._pubsub is None:
|
if self._pubsub is None:
|
||||||
self._pubsub = redis_client.pubsub()
|
self._pubsub = redis_client.pubsub()
|
||||||
# Subscribe to all websocket broadcasting channels
|
|
||||||
# We use Any for _pubsub to satisfy type checkers that don't know redis-py return types
|
|
||||||
await self._pubsub.psubscribe("ws:group:*")
|
await self._pubsub.psubscribe("ws:group:*")
|
||||||
self._listen_task = asyncio.create_task(self._redis_listener())
|
self._listen_task = asyncio.create_task(self._redis_listener())
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
from enum import Enum
|
|
||||||
from sqlalchemy import String, Boolean, Integer
|
from sqlalchemy import String, Boolean, Integer
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,9 @@ from pydantic import BaseModel, Field, field_validator
|
||||||
import re
|
import re
|
||||||
|
|
||||||
class UserCreate(BaseModel):
|
class UserCreate(BaseModel):
|
||||||
username: str
|
username: str = Field(..., max_length=20, description='username of the user')
|
||||||
phone_number: str | None = Field(None, description="11 digit phone number")
|
phone_number: str | None = Field(None, description="11 digit phone number")
|
||||||
is_admin: bool = False
|
is_admin: bool = Field(False)
|
||||||
|
|
||||||
@field_validator("phone_number")
|
@field_validator("phone_number")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -13,19 +13,19 @@ class UserCreate(BaseModel):
|
||||||
if v is None:
|
if v is None:
|
||||||
return v
|
return v
|
||||||
if not re.match(r"^09\d{9}$", v):
|
if not re.match(r"^09\d{9}$", v):
|
||||||
raise ValueError("Phone number must start with 09 and be exactly 11 digits")
|
raise ValueError("شماره تلفن باید ۱۱ رقم باشد و با ۰۹ شروع شود")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
id: uuid.UUID
|
id: uuid.UUID = Field(...)
|
||||||
username: str
|
username: str = Field(..., description='username of the user')
|
||||||
phone_number: str | None
|
phone_number: str | None = Field(..., description='phone number of the user')
|
||||||
is_admin: bool
|
is_admin: bool = Field(..., description='is admin')
|
||||||
is_active: bool
|
is_active: bool = Field(..., description='is active')
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
class UserCreateResult(BaseModel):
|
class UserCreateResult(BaseModel):
|
||||||
user: UserResponse
|
user: UserResponse = Field(..., description='user created')
|
||||||
secret: str
|
secret: str = Field(..., description='secret of the user')
|
||||||
|
|
|
||||||
18
Back/main.py
18
Back/main.py
|
|
@ -1,5 +1,5 @@
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Depends
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi_swagger import patch_fastapi
|
from fastapi_swagger import patch_fastapi
|
||||||
|
|
||||||
|
|
@ -11,6 +11,7 @@ from domains.realtime.ws import router as realtime_router
|
||||||
from domains.notifications.api import router as notifications_router
|
from domains.notifications.api import router as notifications_router
|
||||||
from integrations.livekit.client import close_livekit_api
|
from integrations.livekit.client import close_livekit_api
|
||||||
from db.redis import redis_client
|
from db.redis import redis_client
|
||||||
|
from core.rate_limit import RateLimiter
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
|
@ -33,6 +34,7 @@ async def lifespan(app: FastAPI):
|
||||||
await close_livekit_api()
|
await close_livekit_api()
|
||||||
await redis_client.close()
|
await redis_client.close()
|
||||||
|
|
||||||
|
global_limiter = RateLimiter(requests=30, window_seconds=60, scope="global")
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="NEDA API",
|
title="NEDA API",
|
||||||
|
|
@ -40,7 +42,8 @@ app = FastAPI(
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
docs_url=None,
|
docs_url=None,
|
||||||
swagger_ui_oauth2_redirect_url=None
|
swagger_ui_oauth2_redirect_url=None,
|
||||||
|
dependencies=[Depends(global_limiter)]
|
||||||
)
|
)
|
||||||
patch_fastapi(app,docs_url="/swagger")
|
patch_fastapi(app,docs_url="/swagger")
|
||||||
|
|
||||||
|
|
@ -55,7 +58,16 @@ app.add_middleware(
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
# app.add_middleware(
|
||||||
|
# CORSMiddleware,
|
||||||
|
# allow_origins=[
|
||||||
|
# "https://app.neda.com",
|
||||||
|
# "http://localhost:3000" # فقط برای تست برنامهنویس فرانتاند
|
||||||
|
# ],
|
||||||
|
# allow_credentials=True,
|
||||||
|
# allow_methods=["GET", "POST", "PUT", "DELETE"], # محدود کردن متدها
|
||||||
|
# allow_headers=["Authorization", "Content-Type"], # محدود کردن هدرها
|
||||||
|
# )
|
||||||
|
|
||||||
# -------------------------
|
# -------------------------
|
||||||
# Routers
|
# Routers
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user