mirror of
https://github.com/lorsanstand/Aether.git
synced 2026-06-19 12:05:16 +03:00
Add websckets connection
This commit is contained in:
@@ -46,13 +46,15 @@ async def login(response: Response, credentials: OAuth2PasswordRequestForm = Dep
|
||||
'access_token',
|
||||
token.access_token,
|
||||
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
httponly=True
|
||||
httponly=True,
|
||||
samesite='lax'
|
||||
)
|
||||
response.set_cookie(
|
||||
'refresh_token',
|
||||
str(token.refresh_token),
|
||||
max_age=settings.REFRESH_TOKEN_EXPIRE_DAYS * 30 * 24 * 60,
|
||||
httponly=True
|
||||
httponly=True,
|
||||
samesite='lax'
|
||||
)
|
||||
return token
|
||||
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
import uuid
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import func, select, and_
|
||||
|
||||
from app.core.dao import BaseDAO
|
||||
from app.chats.models import ChatModel, MessageModel, ParticipantModel
|
||||
from app.chats.schemas import ChatCreateDB, MessageCreateDB, ParticipantCreateDB
|
||||
from app.chats.schemas import ChatUpdateDB, MessageUpdateDB, ParticipantUpdateDB
|
||||
from app.users.models import UserModel
|
||||
|
||||
|
||||
class ChatDAO(BaseDAO[ChatModel, ChatCreateDB, ChatUpdateDB]):
|
||||
model = ChatModel
|
||||
|
||||
@classmethod
|
||||
async def get_chat_id(cls, session: AsyncSession, user_a_id: int, user_b_id: int) -> Optional[uuid.UUID]:
|
||||
stmt = (
|
||||
select(ParticipantModel.chat_id)
|
||||
.join(ChatModel)
|
||||
.where(ChatModel.is_group == False)
|
||||
.where(ParticipantModel.user_id.in_([user_a_id, user_b_id]))
|
||||
.group_by(ParticipantModel.chat_id)
|
||||
.having(func.count(ParticipantModel.chat_id) == 2)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
@classmethod
|
||||
async def get_chats(cls, session: AsyncSession, user_id: int, offset: int = 0, limit: int = 10):
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
ChatModel.id.label("chat_id"),
|
||||
ChatModel.last_message,
|
||||
UserModel.id.label("user_id"),
|
||||
UserModel.avatar_url,
|
||||
UserModel.display_name
|
||||
)
|
||||
.join(ParticipantModel, ParticipantModel.chat_id==ChatModel.id)
|
||||
.join(UserModel, UserModel.id==ParticipantModel.user_id)
|
||||
.where(
|
||||
ChatModel.id.in_(
|
||||
select(ParticipantModel.chat_id).where(ParticipantModel.user_id==user_id)
|
||||
)
|
||||
)
|
||||
.where(UserModel.id!=user_id)
|
||||
.where(ChatModel.is_group==False)
|
||||
.order_by(ChatModel.updated_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.mappings().all()
|
||||
|
||||
|
||||
@classmethod
|
||||
async def get_chat_with_participant(cls, session: AsyncSession, chat_id: uuid.UUID, user_id: int):
|
||||
stmt = (
|
||||
select(ChatModel, ParticipantModel.id.label("participant_id"))
|
||||
.outerjoin(
|
||||
ParticipantModel,
|
||||
and_(
|
||||
ParticipantModel.chat_id==ChatModel.id,
|
||||
ParticipantModel.user_id==user_id
|
||||
)
|
||||
)
|
||||
.where(ChatModel.id==chat_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.first()
|
||||
|
||||
|
||||
class MessageDAO(BaseDAO[MessageModel, MessageCreateDB, MessageUpdateDB]):
|
||||
model = MessageModel
|
||||
|
||||
@classmethod
|
||||
async def find_all_asc(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
offset: Optional[int],
|
||||
limit: Optional[int],
|
||||
*filter,
|
||||
**filter_by
|
||||
) -> List[MessageModel]:
|
||||
stmt = select(MessageModel).filter(*filter).filter_by(**filter_by).order_by(MessageModel.created_at.asc())
|
||||
|
||||
if offset is not None:
|
||||
stmt = stmt.offset(offset)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
class ParticipantDAO(BaseDAO[ParticipantModel, ParticipantCreateDB, ParticipantUpdateDB]):
|
||||
model = ParticipantModel
|
||||
@@ -1,5 +1,7 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import ForeignKey, UUID, UniqueConstraint
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
@@ -7,7 +9,28 @@ from app.core.database import Base
|
||||
class MessageModel(Base):
|
||||
__tablename__ = "message"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
sender_id: Mapped[int] = mapped_column(ForeignKey("user.id", ondelete="CASCADE"), index=True)
|
||||
recipient_id: Mapped[int] = mapped_column(ForeignKey("user.id", ondelete="CASCADE"), index=True)
|
||||
content: Mapped[str] = mapped_column()
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True, index=True, default=uuid.uuid4)
|
||||
sender_id: Mapped[int] = mapped_column(ForeignKey("user.id", ondelete="SET NULL"), index=True)
|
||||
chat_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("chat.id", ondelete="CASCADE"), index=True)
|
||||
content: Mapped[str] = mapped_column()
|
||||
is_read: Mapped[bool] = mapped_column(default=False)
|
||||
|
||||
|
||||
class ChatModel(Base):
|
||||
__tablename__ = "chat"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True, index=True, default=uuid.uuid4)
|
||||
is_group: Mapped[bool] = mapped_column(default=False)
|
||||
last_message: Mapped[str] = mapped_column(nullable=True)
|
||||
|
||||
|
||||
class ParticipantModel(Base):
|
||||
__tablename__ = "Participant"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True, index=True, default=uuid.uuid4)
|
||||
chat_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("chat.id", ondelete="CASCADE"), index=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("user.id", ondelete="CASCADE"), index=True)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("chat_id", "user_id", name="uq_chat_user"),
|
||||
)
|
||||
@@ -0,0 +1,44 @@
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||
|
||||
from app.chats.service import ChatService
|
||||
from app.auth.dependencies import get_current_verified_user
|
||||
from app.users.models import UserModel
|
||||
from app.chats.schemas import Chat, MessageCreate, Message
|
||||
|
||||
router = APIRouter(prefix="/chats", tags=["chats"])
|
||||
|
||||
@router.get("/")
|
||||
async def get_chats(
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
user: UserModel = Depends(get_current_verified_user)
|
||||
) -> List[Chat]:
|
||||
return await ChatService.get_chats(user, offset, limit)
|
||||
|
||||
@router.get("/{chat_id}")
|
||||
async def get_chat(
|
||||
chat_id: uuid.UUID,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
user: UserModel = Depends(get_current_verified_user)
|
||||
) -> List[Message]:
|
||||
return await ChatService.get_chat(chat_id, user, offset, limit)
|
||||
|
||||
@router.post("/message")
|
||||
async def send_message(message: MessageCreate, user: UserModel = Depends(get_current_verified_user)) -> Message:
|
||||
return await ChatService.send_message(user, message)
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(ws: WebSocket, user: UserModel = Depends(get_current_verified_user)):
|
||||
await ws.accept()
|
||||
await ChatService.save_websocket(user, ws)
|
||||
|
||||
try:
|
||||
while True:
|
||||
await ws.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
await ChatService.delete_websocket(user)
|
||||
@@ -0,0 +1,67 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MessageCreate(BaseModel):
|
||||
recipient_id: Optional[int] = None
|
||||
chat_id: Optional[uuid.UUID] = None
|
||||
content: str
|
||||
|
||||
|
||||
class MessageUpdate(BaseModel):
|
||||
id: uuid.UUID
|
||||
content: str
|
||||
|
||||
|
||||
class MessageCreateDB(BaseModel):
|
||||
sender_id: Optional[int]
|
||||
chat_id: Optional[uuid.UUID]
|
||||
content: Optional[str]
|
||||
is_read: Optional[bool] = False
|
||||
|
||||
|
||||
class MessageUpdateDB(BaseModel):
|
||||
content: Optional[str]
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
id: uuid.UUID
|
||||
sender_id: int
|
||||
chat_id: uuid.UUID
|
||||
content: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class ChatBase(BaseModel):
|
||||
is_group: Optional[bool] = False
|
||||
last_message: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCreateDB(ChatBase):
|
||||
pass
|
||||
|
||||
|
||||
class ChatUpdateDB(ChatBase):
|
||||
pass
|
||||
|
||||
|
||||
class Chat(BaseModel):
|
||||
chat_id: uuid.UUID
|
||||
user_id: int
|
||||
last_message: Optional[str]
|
||||
avatar_url: Optional[str]
|
||||
display_name: str
|
||||
|
||||
|
||||
class ParticipantCreateDB(BaseModel):
|
||||
chat_id: Optional[uuid.UUID]
|
||||
user_id: Optional[int]
|
||||
|
||||
|
||||
class ParticipantUpdateDB(BaseModel):
|
||||
chat_id: Optional[uuid.UUID]
|
||||
user_id: Optional[int]
|
||||
@@ -0,0 +1,183 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
import logging
|
||||
|
||||
from fastapi import HTTPException, status, WebSocket
|
||||
from sqlalchemy import and_
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.chats.dao import ChatDAO, MessageDAO, ParticipantDAO
|
||||
from app.chats.models import ChatModel, MessageModel, ParticipantModel
|
||||
from app.chats.schemas import Chat, MessageCreate, MessageCreateDB, ChatCreateDB, ParticipantCreateDB, Message
|
||||
from app.users.models import UserModel
|
||||
from app.core.redis import get_redis
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatService:
|
||||
active_connections: Dict[str, WebSocket] = {}
|
||||
|
||||
@classmethod
|
||||
async def get_chats(cls, user: UserModel, offset: int, limit: int) -> List[Chat]:
|
||||
log.debug("Getting chats", extra={"user_id": user.id, "offset": offset, "limit": limit})
|
||||
async with async_session_maker() as session:
|
||||
chats = await ChatDAO.get_chats(session, user.id, offset, limit)
|
||||
log.debug("Retrieved chats", extra={"user_id": user.id, "count": len(chats)})
|
||||
return chats
|
||||
|
||||
|
||||
@classmethod
|
||||
async def send_message(cls, sender: UserModel, message: MessageCreate) -> Message:
|
||||
log.info("Sending message", extra={"sender_id": sender.id, "chat_id": message.chat_id, "recipient_id": message.recipient_id})
|
||||
async with async_session_maker() as session:
|
||||
target_chat_id = message.chat_id
|
||||
|
||||
if target_chat_id is None:
|
||||
if message.recipient_id is None:
|
||||
log.warning("Message send failed: missing chat_id and recipient_id", extra={"sender_id": sender.id})
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Need chat_id or user_id")
|
||||
|
||||
target_chat_id = await ChatDAO.get_chat_id(session, sender.id, message.recipient_id)
|
||||
|
||||
if target_chat_id is None:
|
||||
log.info("Creating new chat", extra={"sender_id": sender.id, "recipient_id": message.recipient_id})
|
||||
target_chat_db = await ChatDAO.add(
|
||||
session,
|
||||
obj_in=ChatCreateDB(
|
||||
is_group=False,
|
||||
last_message=message.content
|
||||
)
|
||||
)
|
||||
target_chat_id: uuid.UUID = target_chat_db.id
|
||||
|
||||
await ParticipantDAO.add(
|
||||
session,
|
||||
obj_in=ParticipantCreateDB(
|
||||
user_id=sender.id,
|
||||
chat_id=target_chat_id
|
||||
)
|
||||
)
|
||||
await ParticipantDAO.add(
|
||||
session,
|
||||
obj_in=ParticipantCreateDB(
|
||||
user_id=message.recipient_id,
|
||||
chat_id=target_chat_id
|
||||
)
|
||||
)
|
||||
log.info("Created new chat", extra={"chat_id": target_chat_id, "sender_id": sender.id, "recipient_id": message.recipient_id})
|
||||
|
||||
|
||||
members = await ParticipantDAO.find_all(
|
||||
session,
|
||||
None,
|
||||
None,
|
||||
ParticipantModel.chat_id==target_chat_id
|
||||
)
|
||||
|
||||
members_ids = [member.user_id for member in members]
|
||||
|
||||
if not sender.id in members_ids :
|
||||
log.warning("Access denied to chat", extra={"user_id": sender.id, "chat_id": message.chat_id})
|
||||
raise HTTPException(status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
|
||||
|
||||
message_db = await MessageDAO.add(
|
||||
session,
|
||||
obj_in=MessageCreateDB(
|
||||
sender_id=sender.id,
|
||||
chat_id=target_chat_id,
|
||||
content=message.content
|
||||
)
|
||||
)
|
||||
|
||||
await cls._send_ws_message(members_ids, Message(
|
||||
id=message_db.id,
|
||||
sender_id=message_db.sender_id,
|
||||
chat_id=message_db.chat_id,
|
||||
content=message_db.content,
|
||||
created_at=message_db.created_at,
|
||||
updated_at=message_db.updated_at
|
||||
))
|
||||
|
||||
await ChatDAO.update(
|
||||
session,
|
||||
ChatModel.id==target_chat_id,
|
||||
obj_in={"last_message": message.content}
|
||||
)
|
||||
await session.commit()
|
||||
log.info("Message sent", extra={"message_id": message_db.id, "sender_id": sender.id, "chat_id": target_chat_id})
|
||||
return message_db
|
||||
|
||||
|
||||
@classmethod
|
||||
async def get_chat(cls, chat_id: uuid.UUID, user: UserModel, offset: int = 0, limit: int = 0) -> List[Message]:
|
||||
log.debug("Getting chat messages", extra={"user_id": user.id, "chat_id": chat_id, "offset": offset, "limit": limit})
|
||||
async with async_session_maker() as session:
|
||||
chat_exist = await ChatDAO.get_chat_with_participant(session, chat_id, user.id)
|
||||
|
||||
if chat_exist is None:
|
||||
log.warning("Chat not found", extra={"user_id": user.id, "chat_id": chat_id})
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Chat not found")
|
||||
if chat_exist.participant_id is None:
|
||||
log.warning("Access denied to chat", extra={"user_id": user.id, "chat_id": chat_id})
|
||||
raise HTTPException(status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
|
||||
messages = await MessageDAO.find_all_asc(
|
||||
session,
|
||||
offset,
|
||||
limit,
|
||||
MessageModel.chat_id==chat_id
|
||||
)
|
||||
|
||||
log.debug("Retrieved chat messages", extra={"user_id": user.id, "chat_id": chat_id, "count": len(messages)})
|
||||
return messages
|
||||
|
||||
|
||||
@classmethod
|
||||
async def save_websocket(cls, user: UserModel, ws: WebSocket):
|
||||
cls.active_connections[str(user.id)] = ws
|
||||
log.info("WebSocket connection saved", extra={"user_id": user.id, "active_connections": len(cls.active_connections) + 1})
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
async def delete_websocket(cls, user: UserModel):
|
||||
cls.active_connections.pop(str(user.id))
|
||||
log.info("WebSocket connection deleted", extra={"user_id": user.id, "active_connections": len(cls.active_connections) - 1})
|
||||
|
||||
|
||||
@classmethod
|
||||
async def message_listener(cls):
|
||||
redis_client = await get_redis()
|
||||
|
||||
pubsub = redis_client.pubsub()
|
||||
|
||||
await pubsub.subscribe("messenger_updates")
|
||||
|
||||
async for message in pubsub.listen():
|
||||
log.debug(f"Received message from Redis: {message}")
|
||||
if message["type"] == "message":
|
||||
payload = json.loads(message["data"])
|
||||
user_id = payload["user_id"]
|
||||
|
||||
if user_id in cls.active_connections:
|
||||
ws = cls.active_connections[user_id]
|
||||
await ws.send_json(payload["message"])
|
||||
log.info(f"Message sent to user {user_id} via WebSocket")
|
||||
else:
|
||||
log.debug(f"User {user_id} not connected")
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
async def _send_ws_message(cls, user_ids: List[int], message: Message):
|
||||
redis_client = await get_redis()
|
||||
for user_id in user_ids:
|
||||
payload = {
|
||||
"user_id": str(user_id),
|
||||
"message": message.model_dump(mode='json')
|
||||
}
|
||||
await redis_client.publish("messenger_updates", json.dumps(payload))
|
||||
log.debug(f"Published message for user_id: {user_id}")
|
||||
@@ -3,6 +3,7 @@ from contextlib import asynccontextmanager
|
||||
import uvicorn
|
||||
import logging
|
||||
import uuid
|
||||
import asyncio
|
||||
|
||||
from fastapi import FastAPI, APIRouter, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -10,6 +11,8 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.core.redis import close_redis, init_redis
|
||||
from app.users.router import router as user_router
|
||||
from app.auth.router import router as auth_router
|
||||
from app.chats.router import router as chat_router
|
||||
from app.chats.service import ChatService
|
||||
from app.core.log_config import set_logging
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -21,14 +24,19 @@ log = logging.getLogger(__name__)
|
||||
async def lifespan(app: FastAPI):
|
||||
await init_redis()
|
||||
log.info("Redis connected")
|
||||
task_send_message = asyncio.create_task(ChatService.message_listener())
|
||||
log.info("Message sender started")
|
||||
yield
|
||||
await close_redis()
|
||||
log.info("Redis disconnected")
|
||||
task_send_message.cancel()
|
||||
log.info("Message sender stopped")
|
||||
|
||||
|
||||
api_router = APIRouter(prefix="/api/v1")
|
||||
api_router.include_router(user_router)
|
||||
api_router.include_router(auth_router)
|
||||
api_router.include_router(chat_router)
|
||||
|
||||
@api_router.get("/health")
|
||||
async def test_health():
|
||||
|
||||
@@ -9,6 +9,8 @@ from alembic import context
|
||||
|
||||
from app.core.database import Base
|
||||
from app.core.config import settings
|
||||
from app.users.models import UserModel
|
||||
from app.chats.models import MessageModel, ChatModel, ParticipantModel
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Create tables
|
||||
|
||||
Revision ID: 0d3f7039ba77
|
||||
Revises: 7ad624ae1699
|
||||
Create Date: 2026-01-12 15:51:43.453822
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0d3f7039ba77'
|
||||
down_revision: Union[str, Sequence[str], None] = '7ad624ae1699'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('chat',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('is_group', sa.Boolean(), nullable=False),
|
||||
sa.Column('last_message', sa.String(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name=op.f('chat_pkey'))
|
||||
)
|
||||
op.create_index(op.f('chat_id_idx'), 'chat', ['id'], unique=False)
|
||||
op.create_table('user',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('display_name', sa.String(), nullable=False),
|
||||
sa.Column('username', sa.String(), nullable=False),
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('birth_day', sa.DATE(), nullable=True),
|
||||
sa.Column('description', sa.String(), nullable=True),
|
||||
sa.Column('avatar_url', sa.String(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_verified', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
||||
sa.Column('hashed_password', sa.String(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name=op.f('user_pkey'))
|
||||
)
|
||||
op.create_index(op.f('user_email_idx'), 'user', ['email'], unique=True)
|
||||
op.create_index(op.f('user_id_idx'), 'user', ['id'], unique=False)
|
||||
op.create_index(op.f('user_username_idx'), 'user', ['username'], unique=True)
|
||||
op.create_table('Participant',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('chat_id', sa.UUID(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], name=op.f('Participant_chat_id_fkey'), ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('Participant_user_id_fkey'), ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id', name=op.f('Participant_pkey')),
|
||||
sa.UniqueConstraint('chat_id', 'user_id', name='uq_chat_user')
|
||||
)
|
||||
op.create_index(op.f('Participant_chat_id_idx'), 'Participant', ['chat_id'], unique=False)
|
||||
op.create_index(op.f('Participant_id_idx'), 'Participant', ['id'], unique=False)
|
||||
op.create_index(op.f('Participant_user_id_idx'), 'Participant', ['user_id'], unique=False)
|
||||
op.create_table('message',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('sender_id', sa.Integer(), nullable=False),
|
||||
sa.Column('chat_id', sa.UUID(), nullable=False),
|
||||
sa.Column('content', sa.String(), nullable=False),
|
||||
sa.Column('is_read', sa.Boolean(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], name=op.f('message_chat_id_fkey'), ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['sender_id'], ['user.id'], name=op.f('message_sender_id_fkey'), ondelete='SET NULL'),
|
||||
sa.PrimaryKeyConstraint('id', name=op.f('message_pkey'))
|
||||
)
|
||||
op.create_index(op.f('message_chat_id_idx'), 'message', ['chat_id'], unique=False)
|
||||
op.create_index(op.f('message_id_idx'), 'message', ['id'], unique=False)
|
||||
op.create_index(op.f('message_sender_id_idx'), 'message', ['sender_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('message_sender_id_idx'), table_name='message')
|
||||
op.drop_index(op.f('message_id_idx'), table_name='message')
|
||||
op.drop_index(op.f('message_chat_id_idx'), table_name='message')
|
||||
op.drop_table('message')
|
||||
op.drop_index(op.f('Participant_user_id_idx'), table_name='Participant')
|
||||
op.drop_index(op.f('Participant_id_idx'), table_name='Participant')
|
||||
op.drop_index(op.f('Participant_chat_id_idx'), table_name='Participant')
|
||||
op.drop_table('Participant')
|
||||
op.drop_index(op.f('user_username_idx'), table_name='user')
|
||||
op.drop_index(op.f('user_id_idx'), table_name='user')
|
||||
op.drop_index(op.f('user_email_idx'), table_name='user')
|
||||
op.drop_table('user')
|
||||
op.drop_index(op.f('chat_id_idx'), table_name='chat')
|
||||
op.drop_table('chat')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,32 @@
|
||||
"""EDIT: chat tables
|
||||
|
||||
Revision ID: 7a5ccb6859fe
|
||||
Revises: fd15ec3ae3fb
|
||||
Create Date: 2026-01-12 14:51:11.514074
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7a5ccb6859fe'
|
||||
down_revision: Union[str, Sequence[str], None] = 'fd15ec3ae3fb'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,32 @@
|
||||
"""EDIT: chat tables
|
||||
|
||||
Revision ID: 7ad624ae1699
|
||||
Revises: 7a5ccb6859fe
|
||||
Create Date: 2026-01-12 14:54:04.459361
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7ad624ae1699'
|
||||
down_revision: Union[str, Sequence[str], None] = '7a5ccb6859fe'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,53 @@
|
||||
"""ADD: chat tables
|
||||
|
||||
Revision ID: fd15ec3ae3fb
|
||||
Revises: 4d00c9b0516e
|
||||
Create Date: 2026-01-11 21:54:51.418126
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'fd15ec3ae3fb'
|
||||
down_revision: Union[str, Sequence[str], None] = '4d00c9b0516e'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('user_email_idx'), table_name='user')
|
||||
op.drop_index(op.f('user_id_idx'), table_name='user')
|
||||
op.drop_index(op.f('user_username_idx'), table_name='user')
|
||||
op.drop_table('user')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('user',
|
||||
sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column('display_name', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column('username', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column('email', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column('birth_day', sa.DATE(), autoincrement=False, nullable=True),
|
||||
sa.Column('description', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('avatar_url', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('is_active', sa.BOOLEAN(), autoincrement=False, nullable=False),
|
||||
sa.Column('is_verified', sa.BOOLEAN(), autoincrement=False, nullable=False),
|
||||
sa.Column('is_superuser', sa.BOOLEAN(), autoincrement=False, nullable=False),
|
||||
sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False),
|
||||
sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False),
|
||||
sa.Column('hashed_password', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name=op.f('user_pkey'))
|
||||
)
|
||||
op.create_index(op.f('user_username_idx'), 'user', ['username'], unique=True)
|
||||
op.create_index(op.f('user_id_idx'), 'user', ['id'], unique=False)
|
||||
op.create_index(op.f('user_email_idx'), 'user', ['email'], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
@@ -24,6 +24,10 @@ async def get_users(offset: int, limit: int, user: UserModel = Depends(get_curre
|
||||
log.info("Getting users list", extra={"offset": offset, "limit": limit})
|
||||
return await UserService.get_users_list(offset=offset, limit=limit)
|
||||
|
||||
@router.get("/{user_id}")
|
||||
async def get_user(user_id: int, user: UserModel = Depends(get_current_verified_user)):
|
||||
return await UserService.get_user(user_id)
|
||||
|
||||
@router.put("/me")
|
||||
async def update_current_user(update_user: UserUpdate, user: UserModel = Depends(get_current_verified_user)) -> User:
|
||||
return await UserService.update_user(user.id, update_user)
|
||||
|
||||
@@ -40,6 +40,7 @@ class User(UserBase):
|
||||
is_verified: bool
|
||||
is_superuser: bool
|
||||
|
||||
|
||||
class UserCreateDB(UserBase):
|
||||
email: Optional[str] = None
|
||||
hashed_password: Optional[str] = None
|
||||
@@ -47,6 +48,7 @@ class UserCreateDB(UserBase):
|
||||
is_verified: Optional[bool] = None
|
||||
is_superuser: Optional[bool] = None
|
||||
|
||||
|
||||
class UserUpdateDB(UserBase):
|
||||
email: Optional[str] = None
|
||||
hashed_password: Optional[str] = None
|
||||
@@ -56,6 +58,7 @@ class UserUpdateDB(UserBase):
|
||||
is_verified: Optional[bool] = None
|
||||
is_superuser: Optional[bool] = None
|
||||
|
||||
|
||||
class ChangePassword(BaseModel):
|
||||
old_password: str
|
||||
new_password: str
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi import HTTPException, Request, status, WebSocket
|
||||
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
|
||||
from fastapi.security import OAuth2
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
@@ -20,8 +20,19 @@ class OAuth2PasswordBearerWithCookie(OAuth2):
|
||||
password={"tokenUrl": tokenUrl, "scopes": scopes})
|
||||
super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
authorization: str = request.cookies.get("access_token")
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
request: Request = None,
|
||||
websocket: WebSocket = None
|
||||
) -> Optional[str]:
|
||||
connection = request or websocket
|
||||
|
||||
if connection is None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No connection found")
|
||||
|
||||
authorization: str = connection.cookies.get("access_token")
|
||||
print(authorization)
|
||||
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if not authorization or scheme.lower() != "bearer":
|
||||
|
||||
Reference in New Issue
Block a user