Files
2025-12-13 14:39:50 +03:00

110 lines
3.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
from typing import Dict, Set, Optional
import json
from app.services.chat_service import chat_service
from app.db.session import AsyncSessionLocal
from app.api.deps import get_current_user
from app.core.security import decode_token
# Router carrying the chat WebSocket endpoint defined in this module.
router = APIRouter()
# Module-level storage of active connections (keys presumably user ids).
# NOTE(review): nothing in this module reads or writes this dict —
# ConnectionManager keeps its own ``active_connections`` registry.
# Confirm no other module imports this name before removing it.
active_connections: Dict[str, Set[WebSocket]] = {}
class ConnectionManager:
    """Registry of open WebSocket connections grouped by user id.

    A user may have more than one open connection at a time, so each user id
    maps to a set of sockets. Empty sets are removed from the registry.
    """

    def __init__(self):
        # user_id -> sockets currently registered for that user
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept the WebSocket handshake and register the socket for ``user_id``."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister ``websocket`` and drop the user's entry once it is empty.

        Calling this for an unknown user or an unregistered socket is a no-op.
        """
        sockets = self.active_connections.get(user_id)
        if sockets is None:
            return
        sockets.discard(websocket)
        if not sockets:
            del self.active_connections[user_id]

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send one text frame to a single socket."""
        await websocket.send_text(message)

    async def broadcast_to_user(self, user_id: str, message: str):
        """Send one text frame to every socket registered for ``user_id``.

        Does nothing when the user has no registered sockets. Sockets that are
        unregistered while the broadcast is in progress are skipped.
        """
        # Iterate over a snapshot: every ``await`` yields to the event loop and
        # another task may call ``disconnect`` meanwhile. Iterating the live set
        # would then fail with "RuntimeError: Set changed size during iteration".
        for connection in tuple(self.active_connections.get(user_id, ())):
            # Skip sockets that were disconnected after the snapshot was taken.
            if connection not in self.active_connections.get(user_id, ()):
                continue
            await connection.send_text(message)
# Shared connection registry used by the websocket_chat endpoint.
# Being a module-level object, its state is local to this Python process.
manager = ConnectionManager()
async def get_user_from_token(token: str) -> Optional[str]:
    """Extract the user id carried by an access token.

    Args:
        token: Raw token string as received from the client.

    Returns:
        The ``"sub"`` entry of the payload produced by ``decode_token``, or
        ``None`` when that payload is falsy or has no ``"sub"`` entry.
    """
    claims = decode_token(token)
    return claims.get("sub") if claims else None
@router.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
    """WebSocket endpoint for chatting with the AI.

    Authentication: the access token is read from the ``token`` query
    parameter. The connection is closed with code 1008 (policy violation)
    when the token is missing or does not resolve to a user id.

    Protocol: each incoming text frame is expected to be a JSON object with a
    ``message`` field and an optional ``conversation_id``. Every frame gets
    exactly one JSON text frame in reply, either
    ``{"type": "message", "response", "conversation_id", "tokens_used"}`` or
    ``{"type": "error", "message": <description>}``. Malformed frames get an
    error reply and the connection stays open.
    """
    # Deferred import kept from the original code (possibly to avoid an import
    # cycle — TODO confirm); now executed once per connection, not per message.
    from app.schemas.ai import ChatRequest

    # The token comes from the query string.
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=1008, reason="Token required")
        return

    user_id = await get_user_from_token(token)
    if not user_id:
        await websocket.close(code=1008, reason="Unauthorized")
        return

    await manager.connect(websocket, user_id)
    try:
        while True:
            data = await websocket.receive_text()

            # Validate the frame before doing any work. Previously a malformed
            # frame raised out of the loop and the socket was never unregistered.
            # json.JSONDecodeError subclasses ValueError; ChatRequest validation
            # errors presumably do too (pydantic's ValidationError) — confirm.
            try:
                message_data = json.loads(data)
                if not isinstance(message_data, dict):
                    raise ValueError("Expected a JSON object")
                request = ChatRequest(
                    message=message_data.get("message", ""),
                    conversation_id=message_data.get("conversation_id"),
                )
            except ValueError as e:
                await manager.send_personal_message(
                    json.dumps({"type": "error", "message": str(e)}),
                    websocket,
                )
                continue

            # Get the AI reply. Serialization stays inside the try so that a
            # non-JSON-serializable field is reported to the client, not fatal.
            try:
                # One short-lived DB session per message, so a failed
                # transaction cannot break the session for later messages.
                # Presumably this also avoids holding a pooled DB connection
                # while waiting in receive_text() — TODO confirm.
                async with AsyncSessionLocal() as db:
                    response = await chat_service.chat(
                        db=db,
                        user_id=user_id,
                        request=request,
                    )
                reply = json.dumps({
                    "type": "message",
                    "response": response.response,
                    "conversation_id": response.conversation_id,
                    "tokens_used": response.tokens_used,
                })
            except Exception as e:
                # NOTE(review): str(e) may expose internal details (e.g. DB
                # errors) to the client; consider a generic message plus
                # server-side logging.
                reply = json.dumps({"type": "error", "message": str(e)})

            await manager.send_personal_message(reply, websocket)
    except WebSocketDisconnect:
        # Normal client-initiated close; cleanup happens in ``finally``.
        pass
    finally:
        # Unregister on every exit path (disconnect or unexpected error) so the
        # manager never keeps references to dead sockets.
        manager.disconnect(websocket, user_id)