Add websckets connection

This commit is contained in:
2026-01-20 17:06:06 +03:00
parent 8167c77a27
commit a690116399
19 changed files with 748 additions and 83 deletions
+100
View File
@@ -0,0 +1,100 @@
import uuid
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import func, select, and_
from app.core.dao import BaseDAO
from app.chats.models import ChatModel, MessageModel, ParticipantModel
from app.chats.schemas import ChatCreateDB, MessageCreateDB, ParticipantCreateDB
from app.chats.schemas import ChatUpdateDB, MessageUpdateDB, ParticipantUpdateDB
from app.users.models import UserModel
class ChatDAO(BaseDAO[ChatModel, ChatCreateDB, ChatUpdateDB]):
model = ChatModel
@classmethod
async def get_chat_id(cls, session: AsyncSession, user_a_id: int, user_b_id: int) -> Optional[uuid.UUID]:
stmt = (
select(ParticipantModel.chat_id)
.join(ChatModel)
.where(ChatModel.is_group == False)
.where(ParticipantModel.user_id.in_([user_a_id, user_b_id]))
.group_by(ParticipantModel.chat_id)
.having(func.count(ParticipantModel.chat_id) == 2)
)
result = await session.execute(stmt)
return result.scalar_one_or_none()
@classmethod
async def get_chats(cls, session: AsyncSession, user_id: int, offset: int = 0, limit: int = 10):
stmt = (
select(
ChatModel.id.label("chat_id"),
ChatModel.last_message,
UserModel.id.label("user_id"),
UserModel.avatar_url,
UserModel.display_name
)
.join(ParticipantModel, ParticipantModel.chat_id==ChatModel.id)
.join(UserModel, UserModel.id==ParticipantModel.user_id)
.where(
ChatModel.id.in_(
select(ParticipantModel.chat_id).where(ParticipantModel.user_id==user_id)
)
)
.where(UserModel.id!=user_id)
.where(ChatModel.is_group==False)
.order_by(ChatModel.updated_at.desc())
.limit(limit)
.offset(offset)
)
result = await session.execute(stmt)
return result.mappings().all()
@classmethod
async def get_chat_with_participant(cls, session: AsyncSession, chat_id: uuid.UUID, user_id: int):
stmt = (
select(ChatModel, ParticipantModel.id.label("participant_id"))
.outerjoin(
ParticipantModel,
and_(
ParticipantModel.chat_id==ChatModel.id,
ParticipantModel.user_id==user_id
)
)
.where(ChatModel.id==chat_id)
)
result = await session.execute(stmt)
return result.first()
class MessageDAO(BaseDAO[MessageModel, MessageCreateDB, MessageUpdateDB]):
model = MessageModel
@classmethod
async def find_all_asc(
cls,
session: AsyncSession,
offset: Optional[int],
limit: Optional[int],
*filter,
**filter_by
) -> List[MessageModel]:
stmt = select(MessageModel).filter(*filter).filter_by(**filter_by).order_by(MessageModel.created_at.asc())
if offset is not None:
stmt = stmt.offset(offset)
if limit is not None:
stmt = stmt.limit(limit)
result = await session.execute(stmt)
return result.scalars().all()
class ParticipantDAO(BaseDAO[ParticipantModel, ParticipantCreateDB, ParticipantUpdateDB]):
model = ParticipantModel
+28 -5
View File
@@ -1,5 +1,7 @@
import uuid
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import ForeignKey
from sqlalchemy import ForeignKey, UUID, UniqueConstraint
from app.core.database import Base
@@ -7,7 +9,28 @@ from app.core.database import Base
class MessageModel(Base):
__tablename__ = "message"
id: Mapped[int] = mapped_column(primary_key=True, index=True)
sender_id: Mapped[int] = mapped_column(ForeignKey("user.id", ondelete="CASCADE"), index=True)
recipient_id: Mapped[int] = mapped_column(ForeignKey("user.id", ondelete="CASCADE"), index=True)
content: Mapped[str] = mapped_column()
id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True, index=True, default=uuid.uuid4)
sender_id: Mapped[int] = mapped_column(ForeignKey("user.id", ondelete="SET NULL"), index=True)
chat_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("chat.id", ondelete="CASCADE"), index=True)
content: Mapped[str] = mapped_column()
is_read: Mapped[bool] = mapped_column(default=False)
class ChatModel(Base):
__tablename__ = "chat"
id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True, index=True, default=uuid.uuid4)
is_group: Mapped[bool] = mapped_column(default=False)
last_message: Mapped[str] = mapped_column(nullable=True)
class ParticipantModel(Base):
__tablename__ = "Participant"
id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True, index=True, default=uuid.uuid4)
chat_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey("chat.id", ondelete="CASCADE"), index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("user.id", ondelete="CASCADE"), index=True)
__table_args__ = (
UniqueConstraint("chat_id", "user_id", name="uq_chat_user"),
)
+44
View File
@@ -0,0 +1,44 @@
import uuid
from typing import List
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from app.chats.service import ChatService
from app.auth.dependencies import get_current_verified_user
from app.users.models import UserModel
from app.chats.schemas import Chat, MessageCreate, Message
router = APIRouter(prefix="/chats", tags=["chats"])
@router.get("/")
async def get_chats(
offset: int = 0,
limit: int = 10,
user: UserModel = Depends(get_current_verified_user)
) -> List[Chat]:
return await ChatService.get_chats(user, offset, limit)
@router.get("/{chat_id}")
async def get_chat(
chat_id: uuid.UUID,
offset: int = 0,
limit: int = 10,
user: UserModel = Depends(get_current_verified_user)
) -> List[Message]:
return await ChatService.get_chat(chat_id, user, offset, limit)
@router.post("/message")
async def send_message(message: MessageCreate, user: UserModel = Depends(get_current_verified_user)) -> Message:
return await ChatService.send_message(user, message)
@router.websocket("/ws")
async def websocket_endpoint(ws: WebSocket, user: UserModel = Depends(get_current_verified_user)):
await ws.accept()
await ChatService.save_websocket(user, ws)
try:
while True:
await ws.receive_text()
except WebSocketDisconnect:
await ChatService.delete_websocket(user)
+67
View File
@@ -0,0 +1,67 @@
from datetime import datetime
from typing import Optional
import uuid
from pydantic import BaseModel
class MessageCreate(BaseModel):
recipient_id: Optional[int] = None
chat_id: Optional[uuid.UUID] = None
content: str
class MessageUpdate(BaseModel):
id: uuid.UUID
content: str
class MessageCreateDB(BaseModel):
sender_id: Optional[int]
chat_id: Optional[uuid.UUID]
content: Optional[str]
is_read: Optional[bool] = False
class MessageUpdateDB(BaseModel):
content: Optional[str]
class Message(BaseModel):
id: uuid.UUID
sender_id: int
chat_id: uuid.UUID
content: str
created_at: datetime
updated_at: datetime
class ChatBase(BaseModel):
is_group: Optional[bool] = False
last_message: Optional[str] = None
class ChatCreateDB(ChatBase):
pass
class ChatUpdateDB(ChatBase):
pass
class Chat(BaseModel):
chat_id: uuid.UUID
user_id: int
last_message: Optional[str]
avatar_url: Optional[str]
display_name: str
class ParticipantCreateDB(BaseModel):
chat_id: Optional[uuid.UUID]
user_id: Optional[int]
class ParticipantUpdateDB(BaseModel):
chat_id: Optional[uuid.UUID]
user_id: Optional[int]
+183
View File
@@ -0,0 +1,183 @@
import json
import uuid
from typing import List, Dict
import logging
from fastapi import HTTPException, status, WebSocket
from sqlalchemy import and_
from app.core.database import async_session_maker
from app.chats.dao import ChatDAO, MessageDAO, ParticipantDAO
from app.chats.models import ChatModel, MessageModel, ParticipantModel
from app.chats.schemas import Chat, MessageCreate, MessageCreateDB, ChatCreateDB, ParticipantCreateDB, Message
from app.users.models import UserModel
from app.core.redis import get_redis
log = logging.getLogger(__name__)
class ChatService:
active_connections: Dict[str, WebSocket] = {}
@classmethod
async def get_chats(cls, user: UserModel, offset: int, limit: int) -> List[Chat]:
log.debug("Getting chats", extra={"user_id": user.id, "offset": offset, "limit": limit})
async with async_session_maker() as session:
chats = await ChatDAO.get_chats(session, user.id, offset, limit)
log.debug("Retrieved chats", extra={"user_id": user.id, "count": len(chats)})
return chats
@classmethod
async def send_message(cls, sender: UserModel, message: MessageCreate) -> Message:
log.info("Sending message", extra={"sender_id": sender.id, "chat_id": message.chat_id, "recipient_id": message.recipient_id})
async with async_session_maker() as session:
target_chat_id = message.chat_id
if target_chat_id is None:
if message.recipient_id is None:
log.warning("Message send failed: missing chat_id and recipient_id", extra={"sender_id": sender.id})
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Need chat_id or user_id")
target_chat_id = await ChatDAO.get_chat_id(session, sender.id, message.recipient_id)
if target_chat_id is None:
log.info("Creating new chat", extra={"sender_id": sender.id, "recipient_id": message.recipient_id})
target_chat_db = await ChatDAO.add(
session,
obj_in=ChatCreateDB(
is_group=False,
last_message=message.content
)
)
target_chat_id: uuid.UUID = target_chat_db.id
await ParticipantDAO.add(
session,
obj_in=ParticipantCreateDB(
user_id=sender.id,
chat_id=target_chat_id
)
)
await ParticipantDAO.add(
session,
obj_in=ParticipantCreateDB(
user_id=message.recipient_id,
chat_id=target_chat_id
)
)
log.info("Created new chat", extra={"chat_id": target_chat_id, "sender_id": sender.id, "recipient_id": message.recipient_id})
members = await ParticipantDAO.find_all(
session,
None,
None,
ParticipantModel.chat_id==target_chat_id
)
members_ids = [member.user_id for member in members]
if not sender.id in members_ids :
log.warning("Access denied to chat", extra={"user_id": sender.id, "chat_id": message.chat_id})
raise HTTPException(status.HTTP_403_FORBIDDEN, detail="Access denied")
message_db = await MessageDAO.add(
session,
obj_in=MessageCreateDB(
sender_id=sender.id,
chat_id=target_chat_id,
content=message.content
)
)
await cls._send_ws_message(members_ids, Message(
id=message_db.id,
sender_id=message_db.sender_id,
chat_id=message_db.chat_id,
content=message_db.content,
created_at=message_db.created_at,
updated_at=message_db.updated_at
))
await ChatDAO.update(
session,
ChatModel.id==target_chat_id,
obj_in={"last_message": message.content}
)
await session.commit()
log.info("Message sent", extra={"message_id": message_db.id, "sender_id": sender.id, "chat_id": target_chat_id})
return message_db
@classmethod
async def get_chat(cls, chat_id: uuid.UUID, user: UserModel, offset: int = 0, limit: int = 0) -> List[Message]:
log.debug("Getting chat messages", extra={"user_id": user.id, "chat_id": chat_id, "offset": offset, "limit": limit})
async with async_session_maker() as session:
chat_exist = await ChatDAO.get_chat_with_participant(session, chat_id, user.id)
if chat_exist is None:
log.warning("Chat not found", extra={"user_id": user.id, "chat_id": chat_id})
raise HTTPException(status.HTTP_404_NOT_FOUND, "Chat not found")
if chat_exist.participant_id is None:
log.warning("Access denied to chat", extra={"user_id": user.id, "chat_id": chat_id})
raise HTTPException(status.HTTP_403_FORBIDDEN, detail="Access denied")
messages = await MessageDAO.find_all_asc(
session,
offset,
limit,
MessageModel.chat_id==chat_id
)
log.debug("Retrieved chat messages", extra={"user_id": user.id, "chat_id": chat_id, "count": len(messages)})
return messages
@classmethod
async def save_websocket(cls, user: UserModel, ws: WebSocket):
cls.active_connections[str(user.id)] = ws
log.info("WebSocket connection saved", extra={"user_id": user.id, "active_connections": len(cls.active_connections) + 1})
@classmethod
async def delete_websocket(cls, user: UserModel):
cls.active_connections.pop(str(user.id))
log.info("WebSocket connection deleted", extra={"user_id": user.id, "active_connections": len(cls.active_connections) - 1})
@classmethod
async def message_listener(cls):
redis_client = await get_redis()
pubsub = redis_client.pubsub()
await pubsub.subscribe("messenger_updates")
async for message in pubsub.listen():
log.debug(f"Received message from Redis: {message}")
if message["type"] == "message":
payload = json.loads(message["data"])
user_id = payload["user_id"]
if user_id in cls.active_connections:
ws = cls.active_connections[user_id]
await ws.send_json(payload["message"])
log.info(f"Message sent to user {user_id} via WebSocket")
else:
log.debug(f"User {user_id} not connected")
@classmethod
async def _send_ws_message(cls, user_ids: List[int], message: Message):
redis_client = await get_redis()
for user_id in user_ids:
payload = {
"user_id": str(user_id),
"message": message.model_dump(mode='json')
}
await redis_client.publish("messenger_updates", json.dumps(payload))
log.debug(f"Published message for user_id: {user_id}")