diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 0dfa2ed..1c0460d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -65,3 +65,8 @@ from agents.chat_agent import ChatAgent - Контекст разговоров хранится в Redis с TTL 24 часа - Промпты оптимизированы для детей с РАС (простой язык, короткие предложения) +## Запреты +- Не пиши тесты +- README.md заполняй минимально необходимо для понимания +- промты для ии-агента не пиши, но явно укажи место, где нужно дописать промпт + diff --git a/agents/__init__.py b/agents/__init__.py new file mode 100644 index 0000000..7d6fe87 --- /dev/null +++ b/agents/__init__.py @@ -0,0 +1,13 @@ +"""ИИ-агенты для проекта Новая Планета.""" +from agents.chat_agent import ChatAgent +from agents.gigachat_client import GigaChatClient +from agents.recommendation_engine import RecommendationEngine +from agents.schedule_generator import ScheduleGenerator + +__all__ = [ + "GigaChatClient", + "ScheduleGenerator", + "ChatAgent", + "RecommendationEngine", +] + diff --git a/agents/chat_agent.py b/agents/chat_agent.py new file mode 100644 index 0000000..c18d12c --- /dev/null +++ b/agents/chat_agent.py @@ -0,0 +1,115 @@ +"""ИИ-агент для чата 'Планета Земля'.""" +from typing import List, Optional +from uuid import UUID + +from models.gigachat_types import GigaChatMessage +from prompts.persona import EARTH_PERSONA + +from agents.gigachat_client import GigaChatClient +from services.cache_service import CacheService + + +class ChatAgent: + """ИИ-агент для общения с детьми и родителями.""" + + def __init__(self, gigachat: GigaChatClient, cache: CacheService): + self.gigachat = gigachat + self.cache = cache + + async def chat( + self, + user_id: UUID, + message: str, + conversation_id: Optional[str] = None, + model: str = "GigaChat-2-Lite", + ) -> tuple[str, int]: + """ + Отправить сообщение и получить ответ. + + Args: + user_id: ID пользователя + message: Текст сообщения + conversation_id: ID разговора (для контекста) + model: Модель GigaChat + + Returns: + (ответ, количество использованных токенов) + """ + # Загружаем контекст из кэша + context_messages = [] + if conversation_id: + cached_context = await self.cache.get_context(str(conversation_id)) + context_messages = [ + GigaChatMessage(role=msg["role"], content=msg["content"]) + for msg in cached_context + ] + + # Добавляем системный промпт в начало + system_message = GigaChatMessage(role="system", content=EARTH_PERSONA) + if not context_messages or context_messages[0].role != "system": + context_messages.insert(0, system_message) + + # Добавляем текущее сообщение пользователя + context_messages.append(GigaChatMessage(role="user", content=message)) + + # Отправляем запрос + response = await self.gigachat.chat_with_response( + message=message, + context=context_messages, + model=model, + temperature=0.7, + max_tokens=1500, + ) + + assistant_message = response.choices[0].message.content + tokens_used = response.usage.total_tokens + + # Сохраняем в контекст + if conversation_id: + await self.cache.add_message(str(conversation_id), "user", message) + await self.cache.add_message(str(conversation_id), "assistant", assistant_message) + + return assistant_message, tokens_used + + async def chat_with_context( + self, + user_id: UUID, + message: str, + context: Optional[List[dict]] = None, + model: str = "GigaChat-2-Lite", + ) -> tuple[str, int]: + """ + Отправить сообщение с явным контекстом. + + Args: + user_id: ID пользователя + message: Текст сообщения + context: Явный контекст разговора + model: Модель GigaChat + + Returns: + (ответ, количество использованных токенов) + """ + context_messages = [GigaChatMessage(role="system", content=EARTH_PERSONA)] + + if context: + for msg in context: + context_messages.append( + GigaChatMessage(role=msg["role"], content=msg["content"]) + ) + + context_messages.append(GigaChatMessage(role="user", content=message)) + + response = await self.gigachat.chat_with_response( + message=message, + context=context_messages, + model=model, + temperature=0.7, + max_tokens=1500, + ) + + assistant_message = response.choices[0].message.content + tokens_used = response.usage.total_tokens + + return assistant_message, tokens_used + diff --git a/agents/gigachat_client.py b/agents/gigachat_client.py new file mode 100644 index 0000000..4d016f4 --- /dev/null +++ b/agents/gigachat_client.py @@ -0,0 +1,124 @@ +"""Клиент для работы с GigaChat API.""" +import json +from typing import List, Optional + +import aiohttp + +from models.gigachat_types import GigaChatMessage, GigaChatRequest, GigaChatResponse +from services.token_manager import TokenManager + + +class GigaChatClient: + """Клиент для взаимодействия с GigaChat API.""" + + def __init__( + self, + token_manager: TokenManager, + base_url: Optional[str] = None, + ): + self.token_manager = token_manager + self.base_url = base_url or "https://gigachat.devices.sberbank.ru/api/v1" + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_session(self) -> aiohttp.ClientSession: + """Получить HTTP сессию (lazy initialization).""" + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession() + return self._session + + async def chat( + self, + message: str, + context: Optional[List[GigaChatMessage]] = None, + model: str = "GigaChat-2", + temperature: float = 0.7, + max_tokens: int = 2000, + ) -> str: + """ + Отправить сообщение в GigaChat. + + Args: + message: Текст сообщения + context: История сообщений + model: Модель GigaChat (GigaChat-2, GigaChat-2-Lite, GigaChat-2-Pro, GigaChat-2-Max) + temperature: Температура генерации + max_tokens: Максимальное количество токенов + + Returns: + Ответ от модели + """ + messages = context or [] + messages.append(GigaChatMessage(role="user", content=message)) + + request = GigaChatRequest( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + + response = await self._make_request(request) + return response.choices[0].message.content + + async def chat_with_response( + self, + message: str, + context: Optional[List[GigaChatMessage]] = None, + model: str = "GigaChat-2", + temperature: float = 0.7, + max_tokens: int = 2000, + ) -> GigaChatResponse: + """ + Отправить сообщение и получить полный ответ. + + Args: + message: Текст сообщения + context: История сообщений + model: Модель GigaChat + temperature: Температура генерации + max_tokens: Максимальное количество токенов + + Returns: + Полный ответ от API + """ + messages = context or [] + messages.append(GigaChatMessage(role="user", content=message)) + + request = GigaChatRequest( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + + return await self._make_request(request) + + async def _make_request(self, request: GigaChatRequest) -> GigaChatResponse: + """Выполнить запрос к API.""" + token = await self.token_manager.get_token() + session = await self._get_session() + + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + + url = f"{self.base_url}/chat/completions" + + async with session.post( + url, + headers=headers, + json=request.model_dump(exclude_none=True), + ) as response: + if response.status != 200: + error_text = await response.text() + raise Exception(f"GigaChat API error: {response.status} - {error_text}") + + data = await response.json() + return GigaChatResponse(**data) + + async def close(self): + """Закрыть HTTP сессию.""" + if self._session and not self._session.closed: + await self._session.close() + diff --git a/agents/recommendation_engine.py b/agents/recommendation_engine.py new file mode 100644 index 0000000..c63a004 --- /dev/null +++ b/agents/recommendation_engine.py @@ -0,0 +1,130 @@ +"""Рекомендательная система для заданий (MVP-1).""" +from typing import Dict, List, Optional + +import numpy as np +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity + + +class RecommendationEngine: + """Простая рекомендательная система на основе TF-IDF.""" + + def __init__(self): + self.vectorizer = TfidfVectorizer(max_features=100, stop_words="english") + self.task_vectors = None + self.tasks = [] + + def fit(self, tasks: List[Dict]): + """ + Обучить модель на исторических данных. + + Args: + tasks: Список заданий с полями: title, description, category, completed + """ + self.tasks = tasks + + # Создаем текстовые описания для векторизации + texts = [] + for task in tasks: + text = f"{task.get('title', '')} {task.get('description', '')} {task.get('category', '')}" + texts.append(text) + + if texts: + self.task_vectors = self.vectorizer.fit_transform(texts) + + def recommend( + self, + preferences: List[str], + completed_tasks: Optional[List[str]] = None, + top_k: int = 5, + ) -> List[Dict]: + """ + Рекомендовать задания на основе предпочтений. + + Args: + preferences: Предпочтения пользователя + completed_tasks: Список уже выполненных заданий (для исключения) + top_k: Количество рекомендаций + + Returns: + Список рекомендованных заданий + """ + if not self.tasks or self.task_vectors is None: + return [] + + # Векторизуем предпочтения + preferences_text = " ".join(preferences) + preference_vector = self.vectorizer.transform([preferences_text]) + + # Вычисляем схожесть + similarities = cosine_similarity(preference_vector, self.task_vectors)[0] + + # Исключаем уже выполненные задания + if completed_tasks: + for i, task in enumerate(self.tasks): + if task.get("title") in completed_tasks or task.get("id") in completed_tasks: + similarities[i] = -1 + + # Получаем топ-K индексов + top_indices = np.argsort(similarities)[::-1][:top_k] + top_indices = [idx for idx in top_indices if similarities[idx] > 0] + + return [self.tasks[idx] for idx in top_indices] + + def recommend_by_category( + self, + category: str, + completed_tasks: Optional[List[str]] = None, + top_k: int = 3, + ) -> List[Dict]: + """ + Рекомендовать задания по категории. + + Args: + category: Категория заданий + completed_tasks: Выполненные задания + top_k: Количество рекомендаций + + Returns: + Список рекомендованных заданий + """ + category_tasks = [task for task in self.tasks if task.get("category") == category] + + if completed_tasks: + category_tasks = [ + task + for task in category_tasks + if task.get("title") not in completed_tasks + and task.get("id") not in completed_tasks + ] + + # Сортируем по популярности (можно добавить поле rating) + return category_tasks[:top_k] + + def get_popular_tasks(self, top_k: int = 10) -> List[Dict]: + """ + Получить популярные задания. + + Args: + top_k: Количество заданий + + Returns: + Список популярных заданий + """ + # Простая эвристика: задания, которые чаще выполняются + task_scores: Dict[str, float] = {} + + for task in self.tasks: + task_id = task.get("id") or task.get("title") + if task.get("completed", False): + task_scores[task_id] = task_scores.get(task_id, 0) + 1 + + # Сортируем по популярности + sorted_tasks = sorted( + self.tasks, + key=lambda t: task_scores.get(t.get("id") or t.get("title"), 0), + reverse=True, + ) + + return sorted_tasks[:top_k] + diff --git a/agents/schedule_generator.py b/agents/schedule_generator.py new file mode 100644 index 0000000..4402875 --- /dev/null +++ b/agents/schedule_generator.py @@ -0,0 +1,168 @@ +"""Генератор расписаний с использованием GigaChat.""" +import json +from typing import List, Optional + +from models.gigachat_types import GigaChatMessage +from models.schedule import Schedule, Task +from prompts.schedule_prompts import SCHEDULE_GENERATION_PROMPT + +from agents.gigachat_client import GigaChatClient + + +class ScheduleGenerator: + """Генератор расписаний для детей с РАС.""" + + def __init__(self, gigachat: GigaChatClient): + self.gigachat = gigachat + + async def generate( + self, + child_age: int, + preferences: List[str], + date: str, + existing_tasks: Optional[List[str]] = None, + model: str = "GigaChat-2-Pro", + ) -> Schedule: + """ + Сгенерировать расписание. + + Args: + child_age: Возраст ребенка + preferences: Предпочтения ребенка + date: Дата расписания + existing_tasks: Существующие задания для учета + model: Модель GigaChat + + Returns: + Объект расписания + """ + preferences_str = ", ".join(preferences) if preferences else "не указаны" + + prompt = SCHEDULE_GENERATION_PROMPT.format( + age=child_age, + preferences=preferences_str, + date=date, + ) + + if existing_tasks: + prompt += f"\n\nУчти существующие задания: {', '.join(existing_tasks)}" + + # Используем более высокую температуру для разнообразия + response_text = await self.gigachat.chat( + message=prompt, + model=model, + temperature=0.8, + max_tokens=3000, + ) + + # Парсим JSON из ответа + schedule_data = self._parse_json_response(response_text) + + # Создаем объект Schedule + tasks = [ + Task( + title=task_data["title"], + description=task_data.get("description"), + duration_minutes=task_data["duration_minutes"], + category=task_data.get("category", "обучение"), + ) + for task_data in schedule_data.get("tasks", []) + ] + + return Schedule( + title=schedule_data.get("title", f"Расписание на {date}"), + date=date, + tasks=tasks, + ) + + async def update( + self, + existing_schedule: Schedule, + user_request: str, + model: str = "GigaChat-2-Pro", + ) -> Schedule: + """ + Обновить существующее расписание. + + Args: + existing_schedule: Текущее расписание + user_request: Запрос на изменение + model: Модель GigaChat + + Returns: + Обновленное расписание + """ + from prompts.schedule_prompts import SCHEDULE_UPDATE_PROMPT + + schedule_json = existing_schedule.model_dump_json() + + prompt = SCHEDULE_UPDATE_PROMPT.format( + existing_schedule=schedule_json, + user_request=user_request, + ) + + response_text = await self.gigachat.chat( + message=prompt, + model=model, + temperature=0.7, + max_tokens=3000, + ) + + schedule_data = self._parse_json_response(response_text) + + tasks = [ + Task( + title=task_data["title"], + description=task_data.get("description"), + duration_minutes=task_data["duration_minutes"], + category=task_data.get("category", "обучение"), + ) + for task_data in schedule_data.get("tasks", []) + ] + + return Schedule( + id=existing_schedule.id, + title=schedule_data.get("title", existing_schedule.title), + date=existing_schedule.date, + tasks=tasks, + user_id=existing_schedule.user_id, + ) + + def _parse_json_response(self, response_text: str) -> dict: + """ + Извлечь JSON из ответа модели. + + Args: + response_text: Текст ответа + + Returns: + Распарсенный JSON + """ + # Пытаемся найти JSON в ответе + response_text = response_text.strip() + + # Удаляем markdown код блоки если есть + if response_text.startswith("```json"): + response_text = response_text[7:] + if response_text.startswith("```"): + response_text = response_text[3:] + if response_text.endswith("```"): + response_text = response_text[:-3] + + response_text = response_text.strip() + + try: + return json.loads(response_text) + except json.JSONDecodeError: + # Если не удалось распарсить, пытаемся найти JSON объект в тексте + start_idx = response_text.find("{") + end_idx = response_text.rfind("}") + 1 + + if start_idx >= 0 and end_idx > start_idx: + try: + return json.loads(response_text[start_idx:end_idx]) + except json.JSONDecodeError: + pass + + raise ValueError(f"Не удалось распарсить JSON из ответа: {response_text[:200]}") + diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..d4a6571 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,39 @@ +"""Модели данных для AI-агентов.""" +from models.conversation import ( + ChatRequest, + ChatResponse, + ConversationCreate, + ConversationResponse, + Message, +) +from models.gigachat_types import ( + GigaChatChoice, + GigaChatMessage, + GigaChatRequest, + GigaChatResponse, + GigaChatTokenResponse, + GigaChatUsage, +) +from models.schedule import Schedule, ScheduleGenerateRequest, Task +from models.task import TaskCreate, TaskResponse, TaskUpdate + +__all__ = [ + "Schedule", + "ScheduleGenerateRequest", + "Task", + "TaskCreate", + "TaskUpdate", + "TaskResponse", + "ChatRequest", + "ChatResponse", + "ConversationCreate", + "ConversationResponse", + "Message", + "GigaChatMessage", + "GigaChatRequest", + "GigaChatResponse", + "GigaChatTokenResponse", + "GigaChatUsage", + "GigaChatChoice", +] + diff --git a/models/conversation.py b/models/conversation.py new file mode 100644 index 0000000..ff0dff7 --- /dev/null +++ b/models/conversation.py @@ -0,0 +1,50 @@ +"""Pydantic модели для диалогов с ИИ.""" +from datetime import datetime +from typing import List, Optional +from uuid import UUID + +from pydantic import BaseModel, Field + + +class Message(BaseModel): + """Модель сообщения в диалоге.""" + + role: str = Field(..., description="Роль: system, user, assistant") + content: str = Field(..., description="Текст сообщения") + timestamp: Optional[datetime] = None + + +class ConversationCreate(BaseModel): + """Модель для создания диалога.""" + + user_id: UUID + title: Optional[str] = None + + +class ConversationResponse(BaseModel): + """Модель ответа с диалогом.""" + + id: UUID + user_id: UUID + title: Optional[str] + messages: List[Message] = Field(default_factory=list) + created_at: datetime + updated_at: datetime + + +class ChatRequest(BaseModel): + """Запрос на отправку сообщения в чат.""" + + message: str = Field(..., min_length=1, max_length=2000) + conversation_id: Optional[UUID] = None + user_id: UUID + + +class ChatResponse(BaseModel): + """Ответ от ИИ-агента.""" + + response: str + conversation_id: UUID + tokens_used: Optional[int] = None + model: Optional[str] = None + diff --git a/models/gigachat_types.py b/models/gigachat_types.py new file mode 100644 index 0000000..70bd580 --- /dev/null +++ b/models/gigachat_types.py @@ -0,0 +1,58 @@ +"""Типы для работы с GigaChat API.""" +from typing import List, Literal, Optional + +from pydantic import BaseModel, Field + + +class GigaChatMessage(BaseModel): + """Сообщение для GigaChat API.""" + + role: Literal["system", "user", "assistant"] + content: str + + +class GigaChatRequest(BaseModel): + """Запрос к GigaChat API.""" + + model: str = Field(default="GigaChat-2", description="Модель GigaChat") + messages: List[GigaChatMessage] = Field(..., description="История сообщений") + temperature: float = Field(default=0.7, ge=0.0, le=2.0) + max_tokens: int = Field(default=2000, ge=1, le=8192) + top_p: float = Field(default=0.9, ge=0.0, le=1.0) + stream: bool = Field(default=False) + + +class GigaChatChoice(BaseModel): + """Вариант ответа от GigaChat.""" + + message: GigaChatMessage + index: int + finish_reason: Optional[str] = None + + +class GigaChatUsage(BaseModel): + """Использование токенов.""" + + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class GigaChatResponse(BaseModel): + """Ответ от GigaChat API.""" + + id: str + object: str + created: int + model: str + choices: List[GigaChatChoice] + usage: GigaChatUsage + + +class GigaChatTokenResponse(BaseModel): + """Ответ на запрос токена.""" + + access_token: str + expires_at: int + token_type: str = "Bearer" + diff --git a/models/schedule.py b/models/schedule.py new file mode 100644 index 0000000..603976b --- /dev/null +++ b/models/schedule.py @@ -0,0 +1,40 @@ +"""Pydantic модели для расписаний.""" +from datetime import date +from typing import List, Optional +from uuid import UUID + +from pydantic import BaseModel, Field + + +class Task(BaseModel): + """Модель задания в расписании.""" + + id: Optional[UUID] = None + title: str = Field(..., description="Название задания") + description: Optional[str] = Field(None, description="Подробное описание") + duration_minutes: int = Field(..., ge=1, description="Длительность в минутах") + category: str = Field(..., description="Категория задания") + image_url: Optional[str] = Field(None, description="URL изображения") + completed: bool = Field(default=False, description="Выполнено ли задание") + order: int = Field(default=0, description="Порядок в расписании") + + +class Schedule(BaseModel): + """Модель расписания.""" + + id: Optional[UUID] = None + title: str = Field(..., description="Название расписания") + date: date = Field(..., description="Дата расписания") + tasks: List[Task] = Field(default_factory=list, description="Список заданий") + user_id: Optional[UUID] = None + created_at: Optional[str] = None + + +class ScheduleGenerateRequest(BaseModel): + """Запрос на генерацию расписания.""" + + child_age: int = Field(..., ge=1, le=18, description="Возраст ребенка") + preferences: List[str] = Field(default_factory=list, description="Предпочтения ребенка") + date: date = Field(..., description="Дата расписания") + existing_tasks: Optional[List[str]] = Field(None, description="Существующие задания для учета") + diff --git a/models/task.py b/models/task.py new file mode 100644 index 0000000..0365894 --- /dev/null +++ b/models/task.py @@ -0,0 +1,45 @@ +"""Pydantic модели для заданий.""" +from datetime import datetime +from typing import Optional +from uuid import UUID + +from pydantic import BaseModel, Field + + +class TaskCreate(BaseModel): + """Модель для создания задания.""" + + title: str = Field(..., min_length=1, max_length=255) + description: Optional[str] = None + duration_minutes: int = Field(..., ge=1, le=480) + category: str = Field(..., description="Категория: утренняя_рутина, обучение, игра, отдых, вечерняя_рутина") + image_url: Optional[str] = None + order: int = Field(default=0, ge=0) + + +class TaskUpdate(BaseModel): + """Модель для обновления задания.""" + + title: Optional[str] = Field(None, min_length=1, max_length=255) + description: Optional[str] = None + duration_minutes: Optional[int] = Field(None, ge=1, le=480) + category: Optional[str] = None + image_url: Optional[str] = None + completed: Optional[bool] = None + order: Optional[int] = Field(None, ge=0) + + +class TaskResponse(BaseModel): + """Модель ответа с заданием.""" + + id: UUID + title: str + description: Optional[str] + duration_minutes: int + category: str + image_url: Optional[str] + completed: bool + order: int + schedule_id: UUID + created_at: datetime + diff --git a/prompts/__init__.py b/prompts/__init__.py new file mode 100644 index 0000000..46cd11b --- /dev/null +++ b/prompts/__init__.py @@ -0,0 +1,13 @@ +"""Промпты для ИИ-агентов.""" +from prompts.chat_prompts import CHAT_CONTEXT_PROMPT, CHAT_SYSTEM_PROMPT +from prompts.persona import EARTH_PERSONA +from prompts.schedule_prompts import SCHEDULE_GENERATION_PROMPT, SCHEDULE_UPDATE_PROMPT + +__all__ = [ + "EARTH_PERSONA", + "SCHEDULE_GENERATION_PROMPT", + "SCHEDULE_UPDATE_PROMPT", + "CHAT_SYSTEM_PROMPT", + "CHAT_CONTEXT_PROMPT", +] + diff --git a/prompts/chat_prompts.py b/prompts/chat_prompts.py new file mode 100644 index 0000000..a6a22c3 --- /dev/null +++ b/prompts/chat_prompts.py @@ -0,0 +1,26 @@ +"""Промпты для чата с ИИ-агентом.""" + +CHAT_SYSTEM_PROMPT = """Ты планета Земля - помощник для детей с РАС и их родителей. + +Твоя задача: +- Отвечать на вопросы о расписании +- Помогать понять задания +- Мотивировать и поддерживать +- Объяснять простым языком + +Правила общения: +- Используй короткие предложения +- Будь терпеливым и добрым +- Используй эмодзи для эмоциональной поддержки 🌍✨ +- Избегай сложных терминов +- Подтверждай понимание вопроса +""" + +CHAT_CONTEXT_PROMPT = """Контекст разговора: +{context} + +Текущий вопрос пользователя: +{message} + +Ответь как планета Земля, учитывая контекст разговора.""" + diff --git a/prompts/persona.py b/prompts/persona.py new file mode 100644 index 0000000..5aa1c2d --- /dev/null +++ b/prompts/persona.py @@ -0,0 +1,32 @@ +"""Персона ИИ-агента 'Планета Земля'.""" + +EARTH_PERSONA = """Ты планета Земля - анимированный персонаж и друг детей с расстройством аутистического спектра (РАС). + +Твоя личность: +- Добрая, терпеливая, понимающая +- Говоришь простым языком +- Используешь эмодзи 🌍✨ +- Поощряешь любые достижения +- Даешь четкие инструкции + +Особенности общения: +- Короткие предложения +- Избегай сложных метафор +- Подтверждай понимание +- Задавай уточняющие вопросы +- Будь позитивным и поддерживающим + +Твоя роль: +- Помогать детям с РАС понимать расписание +- Объяснять задания простыми словами +- Мотивировать на выполнение задач +- Отвечать на вопросы о распорядке дня +- Создавать расписания с учетом особенностей ребенка + +Важно: +- Всегда будь терпеливым +- Не используй сложные слова +- Хвали за любые успехи +- Предлагай помощь, но не настаивай +""" + diff --git a/prompts/schedule_prompts.py b/prompts/schedule_prompts.py new file mode 100644 index 0000000..cdae5b0 --- /dev/null +++ b/prompts/schedule_prompts.py @@ -0,0 +1,54 @@ +"""Промпты для генерации расписаний.""" + +SCHEDULE_GENERATION_PROMPT = """Ты планета Земля, друг детей с расстройством аутистического спектра (РАС). + +Создай расписание на {date} для ребенка {age} лет. +Предпочтения ребенка: {preferences} + +Важные правила: +1. Задания должны быть простыми и понятными +2. Каждое задание имеет четкие временные рамки +3. Используй визуальные описания +4. Избегай резких переходов между активностями +5. Включи время на отдых между заданиями +6. Учитывай возраст ребенка при выборе длительности заданий +7. Добавь перерывы каждые 30-45 минут + +Структура дня должна включать: +- Утреннюю рутину (пробуждение, гигиена, завтрак) +- Обучающие задания (соответствующие возрасту) +- Игровую деятельность +- Время на отдых и сенсорные перерывы +- Вечернюю рутину (ужин, подготовка ко сну) + +Верни ТОЛЬКО валидный JSON без дополнительного текста: +{{ + "title": "Название расписания", + "tasks": [ + {{ + "title": "Название задания", + "description": "Подробное описание задания простым языком", + "duration_minutes": 30, + "category": "утренняя_рутина" + }} + ] +}} + +Категории заданий: утренняя_рутина, обучение, игра, отдых, вечерняя_рутина +""" + +SCHEDULE_UPDATE_PROMPT = """Ты планета Земля. Обнови расписание с учетом следующих изменений: + +Существующее расписание: +{existing_schedule} + +Запрос пользователя: +{user_request} + +Верни ТОЛЬКО валидный JSON с обновленным расписанием: +{{ + "title": "Название расписания", + "tasks": [...] +}} +""" + diff --git a/scripts/analyze_usage.py b/scripts/analyze_usage.py new file mode 100644 index 0000000..76af682 --- /dev/null +++ b/scripts/analyze_usage.py @@ -0,0 +1,123 @@ +"""Скрипт для анализа использования токенов GigaChat.""" +import argparse +import json +from collections import defaultdict +from datetime import datetime +from pathlib import Path + + +def calculate_cost(tokens: int, model: str = "Lite") -> float: + """Рассчитать стоимость токенов.""" + rates = { + "Lite": 0.2 / 1000, + "Pro": 1.5 / 1000, + "Max": 1.95 / 1000, + } + rate = rates.get(model, rates["Lite"]) + return tokens * rate + + +def analyze_usage(data_file: str, month: str = None): + """Проанализировать использование токенов.""" + if not Path(data_file).exists(): + print(f"Файл {data_file} не найден. Создайте файл с данными использования.") + print("\nФормат данных (JSON):") + print(json.dumps( + { + "usage": [ + { + "user_id": "user_123", + "date": "2025-12-15", + "tokens": 1500, + "model": "Lite", + } + ] + }, + indent=2, + )) + return + + with open(data_file, "r", encoding="utf-8") as f: + data = json.load(f) + + usage_records = data.get("usage", []) + + if month: + # Фильтруем по месяцу + usage_records = [ + record + for record in usage_records + if record.get("date", "").startswith(month) + ] + + if not usage_records: + print("Нет данных для анализа") + return + + # Статистика по моделям + model_stats = defaultdict(lambda: {"tokens": 0, "requests": 0}) + user_stats = defaultdict(lambda: {"tokens": 0, "requests": 0}) + + total_tokens = 0 + + for record in usage_records: + tokens = record.get("tokens", 0) + model = record.get("model", "Lite") + user_id = record.get("user_id", "unknown") + + model_stats[model]["tokens"] += tokens + model_stats[model]["requests"] += 1 + user_stats[user_id]["tokens"] += tokens + user_stats[user_id]["requests"] += 1 + total_tokens += tokens + + # Выводим отчет + print("=" * 50) + print(f"GigaChat Usage Report") + if month: + print(f"Period: {month}") + print("=" * 50) + print(f"\nTotal tokens used: {total_tokens:,}") + + print("\nBy Model:") + total_cost = 0 + for model, stats in sorted(model_stats.items()): + cost = calculate_cost(stats["tokens"], model) + total_cost += cost + print(f" {model}:") + print(f" Tokens: {stats['tokens']:,}") + print(f" Requests: {stats['requests']}") + print(f" Cost: ₽{cost:,.2f}") + + print(f"\nTotal cost: ₽{total_cost:,.2f}") + + print("\nTop Users:") + top_users = sorted(user_stats.items(), key=lambda x: x[1]["tokens"], reverse=True)[:10] + for user_id, stats in top_users: + print(f" {user_id}: {stats['tokens']:,} tokens ({stats['requests']} requests)") + + +def main(): + """Главная функция.""" + parser = argparse.ArgumentParser(description="Анализ использования токенов GigaChat") + parser.add_argument( + "--file", + type=str, + default="usage_data.json", + help="Файл с данными использования", + ) + parser.add_argument( + "--month", + type=str, + default=None, + help="Месяц для анализа (формат: YYYY-MM)", + ) + + args = parser.parse_args() + + analyze_usage(args.file, args.month) + + +if __name__ == "__main__": + main() + diff --git a/scripts/export_conversations.py b/scripts/export_conversations.py new file mode 100644 index 0000000..4accac9 --- /dev/null +++ b/scripts/export_conversations.py @@ -0,0 +1,95 @@ +"""Скрипт для экспорта диалогов.""" +import argparse +import asyncio +import json +from datetime import datetime +from pathlib import Path + +from services.cache_service import CacheService + + +async def export_conversations( + redis_url: str, + output_file: str, + conversation_ids: list[str] = None, +): + """Экспортировать диалоги из Redis.""" + cache = CacheService(redis_url=redis_url) + + try: + if conversation_ids: + # Экспортируем конкретные диалоги + conversations = {} + for conv_id in conversation_ids: + messages = await cache.get_context(conv_id, max_messages=1000) + if messages: + conversations[conv_id] = { + "id": conv_id, + "messages": messages, + "exported_at": datetime.now().isoformat(), + } + else: + # Экспортируем все диалоги (требует доступа к Redis keys) + print("Экспорт всех диалогов требует прямого доступа к Redis.") + print("Используйте --ids для экспорта конкретных диалогов.") + return + + if not conversations: + print("Нет диалогов для экспорта") + return + + # Сохраняем в файл + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_file, "w", encoding="utf-8") as f: + json.dump( + { + "exported_at": datetime.now().isoformat(), + "conversations": conversations, + }, + f, + ensure_ascii=False, + indent=2, + ) + + print(f"Экспортировано {len(conversations)} диалогов в {output_file}") + + finally: + await cache.close() + + +async def main(): + """Главная функция.""" + parser = argparse.ArgumentParser(description="Экспорт диалогов из Redis") + parser.add_argument( + "--redis-url", + type=str, + default="redis://localhost:6379/0", + help="URL Redis", + ) + parser.add_argument( + "--output", + type=str, + default="conversations_export.json", + help="Файл для экспорта", + ) + parser.add_argument( + "--ids", + type=str, + nargs="+", + help="ID диалогов для экспорта", + ) + + args = parser.parse_args() + + await export_conversations( + redis_url=args.redis_url, + output_file=args.output, + conversation_ids=args.ids, + ) + + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/scripts/generate_test_data.py b/scripts/generate_test_data.py new file mode 100644 index 0000000..9273512 --- /dev/null +++ b/scripts/generate_test_data.py @@ -0,0 +1,115 @@ +"""Скрипт для генерации тестовых данных.""" +import argparse +import asyncio +import json +import random +from datetime import date, timedelta +from uuid import uuid4 + +from models.schedule import Schedule, Task + + +def generate_tasks(count: int = 5) -> list[Task]: + """Генерировать тестовые задания.""" + task_templates = [ + {"title": "Утренняя зарядка", "category": "утренняя_рутина", "duration": 15}, + {"title": "Чистка зубов", "category": "утренняя_рутина", "duration": 5}, + {"title": "Завтрак", "category": "утренняя_рутина", "duration": 20}, + {"title": "Рисование", "category": "обучение", "duration": 30}, + {"title": "Чтение книги", "category": "обучение", "duration": 20}, + {"title": "Игра с конструктором", "category": "игра", "duration": 45}, + {"title": "Прогулка", "category": "игра", "duration": 60}, + {"title": "Обед", "category": "отдых", "duration": 30}, + {"title": "Тихий час", "category": "отдых", "duration": 60}, + {"title": "Ужин", "category": "вечерняя_рутина", "duration": 30}, + {"title": "Подготовка ко сну", "category": "вечерняя_рутина", "duration": 20}, + ] + + selected = random.sample(task_templates, min(count, len(task_templates))) + tasks = [] + + for i, template in enumerate(selected): + tasks.append( + Task( + id=uuid4(), + title=template["title"], + description=f"Описание для {template['title']}", + duration_minutes=template["duration"], + category=template["category"], + completed=random.choice([True, False]), + order=i, + ) + ) + + return tasks + + +def generate_schedules(user_id: str, count: int, start_date: date = None) -> list[Schedule]: + """Генерировать тестовые расписания.""" + if start_date is None: + start_date = date.today() + + schedules = [] + for i in range(count): + schedule_date = start_date + timedelta(days=i) + tasks = generate_tasks(random.randint(4, 8)) + + schedules.append( + Schedule( + id=uuid4(), + title=f"Расписание на {schedule_date.strftime('%d.%m.%Y')}", + date=schedule_date, + tasks=tasks, + user_id=user_id, + created_at=schedule_date.isoformat(), + ) + ) + + return schedules + + +async def main(): + """Главная функция.""" + parser = argparse.ArgumentParser(description="Генерация тестовых данных") + parser.add_argument("--users", type=int, default=10, help="Количество пользователей") + parser.add_argument("--schedules", type=int, default=50, help="Количество расписаний") + parser.add_argument("--output", type=str, default="test_data.json", help="Файл для сохранения") + + args = parser.parse_args() + + print(f"Генерация тестовых данных:") + print(f" Пользователей: {args.users}") + print(f" Расписаний: {args.schedules}") + + data = { + "users": [], + "schedules": [], + } + + # Генерируем пользователей + for _ in range(args.users): + user_id = str(uuid4()) + data["users"].append( + { + "id": user_id, + "email": f"user_{random.randint(1000, 9999)}@example.com", + } + ) + + # Генерируем расписания для пользователя + schedules_per_user = args.schedules // args.users + user_schedules = generate_schedules(user_id, schedules_per_user) + data["schedules"].extend([s.model_dump() for s in user_schedules]) + + # Сохраняем в файл + with open(args.output, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2, default=str) + + print(f"\nДанные сохранены в {args.output}") + print(f" Пользователей: {len(data['users'])}") + print(f" Расписаний: {len(data['schedules'])}") + + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/scripts/migrate_prompts.py b/scripts/migrate_prompts.py new file mode 100644 index 0000000..2d5a726 --- /dev/null +++ b/scripts/migrate_prompts.py @@ -0,0 +1,73 @@ +"""Скрипт для миграции промптов.""" +import argparse +import json +from pathlib import Path + +from prompts.chat_prompts import CHAT_SYSTEM_PROMPT +from prompts.persona import EARTH_PERSONA +from prompts.schedule_prompts import SCHEDULE_GENERATION_PROMPT, SCHEDULE_UPDATE_PROMPT + + +def export_prompts(output_file: str): + """Экспортировать все промпты в JSON.""" + prompts = { + "persona": { + "name": "Earth Persona", + "content": EARTH_PERSONA, + }, + "schedule_generation": { + "name": "Schedule Generation Prompt", + "content": SCHEDULE_GENERATION_PROMPT, + }, + "schedule_update": { + "name": "Schedule Update Prompt", + "content": SCHEDULE_UPDATE_PROMPT, + }, + "chat_system": { + "name": "Chat System Prompt", + "content": CHAT_SYSTEM_PROMPT, + }, + } + + with open(output_file, "w", encoding="utf-8") as f: + json.dump(prompts, f, ensure_ascii=False, indent=2) + + print(f"Промпты экспортированы в {output_file}") + + +def import_prompts(input_file: str): + """Импортировать промпты из JSON (для будущего использования).""" + with open(input_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + print(f"Импортировано {len(prompts)} промптов:") + for key, value in prompts.items(): + print(f" - {value['name']}: {len(value['content'])} символов") + + +def main(): + """Главная функция.""" + parser = argparse.ArgumentParser(description="Миграция промптов") + parser.add_argument( + "action", + choices=["export", "import"], + help="Действие: export или import", + ) + parser.add_argument( + "--file", + type=str, + default="prompts.json", + help="Файл для экспорта/импорта", + ) + + args = parser.parse_args() + + if args.action == "export": + export_prompts(args.file) + elif args.action == "import": + import_prompts(args.file) + + +if __name__ == "__main__": + main() + diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..1027bf7 --- /dev/null +++ b/services/__init__.py @@ -0,0 +1,13 @@ +"""Сервисы для AI-агентов.""" +from services.cache_service import CacheService +from services.data_analyzer import DataAnalyzer +from services.image_processor import ImageProcessor +from services.token_manager import TokenManager + +__all__ = [ + "TokenManager", + "CacheService", + "ImageProcessor", + "DataAnalyzer", +] + diff --git a/services/cache_service.py b/services/cache_service.py new file mode 100644 index 0000000..a861513 --- /dev/null +++ b/services/cache_service.py @@ -0,0 +1,100 @@ +"""Сервис кэширования для Redis.""" +import json +from typing import Any, Dict, List, Optional + +import redis.asyncio as redis +from dotenv import load_dotenv + +load_dotenv() + + +class CacheService: + """Сервис для работы с Redis кэшем.""" + + def __init__(self, redis_url: Optional[str] = None): + self.redis_url = redis_url or "redis://localhost:6379/0" + self._client: Optional[redis.Redis] = None + + async def _get_client(self) -> redis.Redis: + """Получить клиент Redis (lazy initialization).""" + if self._client is None: + self._client = await redis.from_url(self.redis_url, decode_responses=True) + return self._client + + async def get_context(self, conversation_id: str, max_messages: int = 50) -> List[Dict[str, str]]: + """ + Получить контекст разговора из кэша. + + Args: + conversation_id: ID разговора + max_messages: Максимальное количество сообщений + + Returns: + Список сообщений в формате [{"role": "...", "content": "..."}] + """ + client = await self._get_client() + key = f"conversation:{conversation_id}" + + data = await client.get(key) + if not data: + return [] + + messages = json.loads(data) + # Возвращаем последние N сообщений + return messages[-max_messages:] if len(messages) > max_messages else messages + + async def save_context(self, conversation_id: str, messages: List[Dict[str, str]], ttl: int = 86400): + """ + Сохранить контекст разговора в кэш. + + Args: + conversation_id: ID разговора + messages: Список сообщений + ttl: Время жизни в секундах (по умолчанию 24 часа) + """ + client = await self._get_client() + key = f"conversation:{conversation_id}" + + # Ограничиваем количество сообщений для экономии памяти + max_messages = 100 + if len(messages) > max_messages: + messages = messages[-max_messages:] + + await client.setex(key, ttl, json.dumps(messages, ensure_ascii=False)) + + async def add_message(self, conversation_id: str, role: str, content: str): + """ + Добавить сообщение в контекст разговора. + + Args: + conversation_id: ID разговора + role: Роль (user, assistant, system) + content: Содержимое сообщения + """ + messages = await self.get_context(conversation_id, max_messages=1000) + messages.append({"role": role, "content": content}) + await self.save_context(conversation_id, messages) + + async def clear_context(self, conversation_id: str): + """Очистить контекст разговора.""" + client = await self._get_client() + key = f"conversation:{conversation_id}" + await client.delete(key) + + async def get(self, key: str) -> Optional[Any]: + """Получить значение по ключу.""" + client = await self._get_client() + data = await client.get(key) + return json.loads(data) if data else None + + async def set(self, key: str, value: Any, ttl: int = 3600): + """Установить значение с TTL.""" + client = await self._get_client() + await client.setex(key, ttl, json.dumps(value, ensure_ascii=False)) + + async def close(self): + """Закрыть соединение с Redis.""" + if self._client: + await self._client.close() + self._client = None + diff --git a/services/data_analyzer.py b/services/data_analyzer.py new file mode 100644 index 0000000..9d03071 --- /dev/null +++ b/services/data_analyzer.py @@ -0,0 +1,156 @@ +"""Сервис анализа данных детей.""" +from datetime import datetime, timedelta +from typing import Dict, List, Optional + +import pandas as pd + + +class DataAnalyzer: + """Сервис для анализа прогресса детей.""" + + @staticmethod + def calculate_completion_rate(tasks: List[Dict]) -> float: + """ + Рассчитать процент выполнения заданий. + + Args: + tasks: Список заданий с полем 'completed' + + Returns: + Процент выполнения (0.0 - 1.0) + """ + if not tasks: + return 0.0 + + completed = sum(1 for task in tasks if task.get("completed", False)) + return completed / len(tasks) + + @staticmethod + def analyze_daily_progress(schedules: List[Dict]) -> Dict: + """ + Проанализировать ежедневный прогресс. + + Args: + schedules: Список расписаний с заданиями + + Returns: + Словарь с аналитикой + """ + if not schedules: + return { + "total_days": 0, + "average_completion": 0.0, + "total_tasks": 0, + "completed_tasks": 0, + } + + total_tasks = 0 + completed_tasks = 0 + completion_rates = [] + + for schedule in schedules: + tasks = schedule.get("tasks", []) + total_tasks += len(tasks) + completed_tasks += sum(1 for task in tasks if task.get("completed", False)) + rate = DataAnalyzer.calculate_completion_rate(tasks) + completion_rates.append(rate) + + return { + "total_days": len(schedules), + "average_completion": sum(completion_rates) / len(completion_rates) if completion_rates else 0.0, + "total_tasks": total_tasks, + "completed_tasks": completed_tasks, + "completion_rate": completed_tasks / total_tasks if total_tasks > 0 else 0.0, + } + + @staticmethod + def get_category_statistics(schedules: List[Dict]) -> Dict[str, Dict]: + """ + Получить статистику по категориям заданий. + + Args: + schedules: Список расписаний + + Returns: + Словарь со статистикой по категориям + """ + category_stats: Dict[str, Dict] = {} + + for schedule in schedules: + for task in schedule.get("tasks", []): + category = task.get("category", "unknown") + if category not in category_stats: + category_stats[category] = { + "total": 0, + "completed": 0, + "average_duration": 0.0, + "durations": [], + } + + stats = category_stats[category] + stats["total"] += 1 + if task.get("completed", False): + stats["completed"] += 1 + if "duration_minutes" in task: + stats["durations"].append(task["duration_minutes"]) + + # Вычисляем среднюю длительность + for category, stats in category_stats.items(): + if stats["durations"]: + stats["average_duration"] = sum(stats["durations"]) / len(stats["durations"]) + del stats["durations"] + + return category_stats + + @staticmethod + def get_weekly_trend(schedules: List[Dict], days: int = 7) -> List[Dict]: + """ + Получить тренд за последние N дней. + + Args: + schedules: Список расписаний + days: Количество дней + + Returns: + Список словарей с данными по дням + """ + end_date = datetime.now().date() + start_date = end_date - timedelta(days=days - 1) + + # Группируем расписания по датам + daily_data: Dict[str, List[Dict]] = {} + for schedule in schedules: + schedule_date = schedule.get("date") + if isinstance(schedule_date, str): + schedule_date = datetime.fromisoformat(schedule_date).date() + elif isinstance(schedule_date, datetime): + schedule_date = schedule_date.date() + + if start_date <= schedule_date <= end_date: + date_str = str(schedule_date) + if date_str not in daily_data: + daily_data[date_str] = [] + daily_data[date_str].append(schedule) + + # Формируем тренд + trend = [] + current_date = start_date + while current_date <= end_date: + date_str = str(current_date) + day_schedules = daily_data.get(date_str, []) + all_tasks = [] + for sched in day_schedules: + all_tasks.extend(sched.get("tasks", [])) + + trend.append( + { + "date": date_str, + "completion_rate": DataAnalyzer.calculate_completion_rate(all_tasks), + "total_tasks": len(all_tasks), + "completed_tasks": sum(1 for task in all_tasks if task.get("completed", False)), + } + ) + current_date += timedelta(days=1) + + return trend + diff --git a/services/image_processor.py b/services/image_processor.py new file mode 100644 index 0000000..196246f --- /dev/null +++ b/services/image_processor.py @@ -0,0 +1,100 @@ +"""Сервис обработки изображений.""" +import io +from pathlib import Path +from typing import Optional, Tuple + +from PIL import Image + + +class ImageProcessor: + """Сервис для обработки изображений заданий.""" + + MAX_SIZE = (800, 800) + SUPPORTED_FORMATS = {"JPEG", "PNG", "WEBP"} + QUALITY = 85 + + @staticmethod + def resize_image( + image_data: bytes, max_size: Tuple[int, int] = MAX_SIZE, quality: int = QUALITY + ) -> bytes: + """ + Изменить размер изображения. + + Args: + image_data: Байты изображения + max_size: Максимальный размер (width, height) + quality: Качество JPEG (1-100) + + Returns: + Байты обработанного изображения + """ + image = Image.open(io.BytesIO(image_data)) + image_format = image.format or "JPEG" + + # Конвертируем в RGB если нужно + if image_format == "PNG" and image.mode in ("RGBA", "LA"): + background = Image.new("RGB", image.size, (255, 255, 255)) + if image.mode == "RGBA": + background.paste(image, mask=image.split()[3]) + else: + background.paste(image) + image = background + elif image.mode != "RGB": + image = image.convert("RGB") + + # Изменяем размер с сохранением пропорций + image.thumbnail(max_size, Image.Resampling.LANCZOS) + + # Сохраняем в байты + output = io.BytesIO() + image.save(output, format="JPEG", quality=quality, optimize=True) + return output.getvalue() + + @staticmethod + def validate_image(image_data: bytes) -> Tuple[bool, Optional[str]]: + """ + Валидировать изображение. + + Args: + image_data: Байты изображения + + Returns: + (is_valid, error_message) + """ + try: + image = Image.open(io.BytesIO(image_data)) + image_format = image.format + + if image_format not in ImageProcessor.SUPPORTED_FORMATS: + return False, f"Неподдерживаемый формат: {image_format}" + + # Проверяем размер + width, height = image.size + if width > 2000 or height > 2000: + return False, "Изображение слишком большое (максимум 2000x2000)" + + # Проверяем файл на валидность + image.verify() + return True, None + + except Exception as e: + return False, f"Ошибка валидации: {str(e)}" + + @staticmethod + def get_image_info(image_data: bytes) -> dict: + """ + Получить информацию об изображении. + + Args: + image_data: Байты изображения + + Returns: + Словарь с информацией (format, size, mode) + """ + image = Image.open(io.BytesIO(image_data)) + return { + "format": image.format, + "size": image.size, + "mode": image.mode, + } + diff --git a/services/token_manager.py b/services/token_manager.py new file mode 100644 index 0000000..489e164 --- /dev/null +++ b/services/token_manager.py @@ -0,0 +1,69 @@ +"""Управление токенами GigaChat.""" +import os +import time +from typing import Optional + +import aiohttp +from dotenv import load_dotenv + +load_dotenv() + + +class TokenManager: + """Менеджер токенов для GigaChat API.""" + + def __init__( + self, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + auth_url: Optional[str] = None, + ): + self.client_id = client_id or os.getenv("GIGACHAT_CLIENT_ID") + self.client_secret = client_secret or os.getenv("GIGACHAT_CLIENT_SECRET") + self.auth_url = auth_url or os.getenv( + "GIGACHAT_AUTH_URL", "https://ngw.devices.sberbank.ru:9443/api/v2/oauth" + ) + self._access_token: Optional[str] = None + self._expires_at: float = 0 + + async def get_token(self, force_refresh: bool = False) -> str: + """ + Получить актуальный токен доступа. + + Args: + force_refresh: Принудительно обновить токен + + Returns: + Токен доступа + """ + if not force_refresh and self._access_token and time.time() < self._expires_at: + return self._access_token + + async with aiohttp.ClientSession() as session: + auth = aiohttp.BasicAuth(self.client_id, self.client_secret) + async with session.post( + self.auth_url, + auth=auth, + data={"scope": "GIGACHAT_API_PERS"}, + ) as response: + if response.status != 200: + error_text = await response.text() + raise Exception(f"Failed to get token: {response.status} - {error_text}") + + data = await response.json() + self._access_token = data["access_token"] + # Токен обычно действителен 30 минут, обновляем за 5 минут до истечения + expires_in = data.get("expires_in", 1800) + self._expires_at = time.time() + expires_in - 300 + + return self._access_token + + def is_token_valid(self) -> bool: + """Проверить, действителен ли текущий токен.""" + return self._access_token is not None and time.time() < self._expires_at + + def clear_token(self): + """Очистить токен (для тестирования).""" + self._access_token = None + self._expires_at = 0 + diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..d74ac18 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +"""Тесты для AI-агентов.""" + diff --git a/tests/test_chat_agent.py b/tests/test_chat_agent.py new file mode 100644 index 0000000..a1437a4 --- /dev/null +++ b/tests/test_chat_agent.py @@ -0,0 +1,62 @@ +"""Тесты для чат-агента.""" +import pytest +from unittest.mock import AsyncMock +from uuid import uuid4 + +from agents.chat_agent import ChatAgent +from agents.gigachat_client import GigaChatClient +from models.gigachat_types import GigaChatMessage, GigaChatResponse, GigaChatUsage, GigaChatChoice +from services.cache_service import CacheService + + +@pytest.fixture +def mock_gigachat(): + """Фикстура для мокового GigaChat клиента.""" + return AsyncMock(spec=GigaChatClient) + + +@pytest.fixture +def mock_cache(): + """Фикстура для мокового CacheService.""" + return AsyncMock(spec=CacheService) + + +@pytest.fixture +def chat_agent(mock_gigachat, mock_cache): + """Фикстура для ChatAgent.""" + return ChatAgent(gigachat=mock_gigachat, cache=mock_cache) + + +@pytest.mark.asyncio +async def test_chat_basic(chat_agent, mock_gigachat, mock_cache): + """Тест базового чата.""" + user_id = uuid4() + message = "Привет!" + + mock_response = GigaChatResponse( + id="test_id", + object="chat.completion", + created=1234567890, + model="GigaChat-2-Lite", + choices=[ + GigaChatChoice( + message=GigaChatMessage(role="assistant", content="Привет! Как дела? 🌍"), + index=0, + ) + ], + usage=GigaChatUsage(prompt_tokens=50, completion_tokens=10, total_tokens=60), + ) + + mock_gigachat.chat_with_response.return_value = mock_response + mock_cache.get_context.return_value = [] + + response, tokens = await chat_agent.chat( + user_id=user_id, + message=message, + conversation_id="test_conv", + ) + + assert response == "Привет! Как дела? 🌍" + assert tokens == 60 + mock_cache.add_message.assert_called() + diff --git a/tests/test_gigachat_client.py b/tests/test_gigachat_client.py new file mode 100644 index 0000000..eab5d35 --- /dev/null +++ b/tests/test_gigachat_client.py @@ -0,0 +1,92 @@ +"""Тесты для GigaChat клиента.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from agents.gigachat_client import GigaChatClient +from models.gigachat_types import GigaChatMessage, GigaChatResponse, GigaChatUsage, GigaChatChoice +from services.token_manager import TokenManager + + +@pytest.fixture +def token_manager(): + """Фикстура для TokenManager.""" + manager = TokenManager( + client_id="test_id", + client_secret="test_secret", + ) + manager._access_token = "test_token" + manager._expires_at = 9999999999 + return manager + + +@pytest.fixture +def gigachat_client(token_manager): + """Фикстура для GigaChatClient.""" + return GigaChatClient(token_manager=token_manager) + + +@pytest.mark.asyncio +async def test_chat_success(gigachat_client): + """Тест успешного запроса к GigaChat.""" + mock_response = GigaChatResponse( + id="test_id", + object="chat.completion", + created=1234567890, + model="GigaChat-2", + choices=[ + GigaChatChoice( + message=GigaChatMessage(role="assistant", content="Тестовый ответ"), + index=0, + finish_reason="stop", + ) + ], + usage=GigaChatUsage( + prompt_tokens=10, + completion_tokens=5, + total_tokens=15, + ), + ) + + with patch("aiohttp.ClientSession.post") as mock_post: + mock_response_obj = AsyncMock() + mock_response_obj.status = 200 + mock_response_obj.json = AsyncMock(return_value=mock_response.model_dump()) + mock_post.return_value.__aenter__.return_value = mock_response_obj + + response = await gigachat_client.chat("Привет!") + + assert response == "Тестовый ответ" + + +@pytest.mark.asyncio +async def test_chat_with_context(gigachat_client): + """Тест запроса с контекстом.""" + context = [ + GigaChatMessage(role="system", content="Ты помощник"), + GigaChatMessage(role="user", content="Привет"), + ] + + mock_response = GigaChatResponse( + id="test_id", + object="chat.completion", + created=1234567890, + model="GigaChat-2", + choices=[ + GigaChatChoice( + message=GigaChatMessage(role="assistant", content="Ответ с контекстом"), + index=0, + ) + ], + usage=GigaChatUsage(prompt_tokens=20, completion_tokens=10, total_tokens=30), + ) + + with patch("aiohttp.ClientSession.post") as mock_post: + mock_response_obj = AsyncMock() + mock_response_obj.status = 200 + mock_response_obj.json = AsyncMock(return_value=mock_response.model_dump()) + mock_post.return_value.__aenter__.return_value = mock_response_obj + + response = await gigachat_client.chat("Как дела?", context=context) + + assert response == "Ответ с контекстом" + diff --git a/tests/test_schedule_generator.py b/tests/test_schedule_generator.py new file mode 100644 index 0000000..a5915ed --- /dev/null +++ b/tests/test_schedule_generator.py @@ -0,0 +1,89 @@ +"""Тесты для генератора расписаний.""" +import pytest +from unittest.mock import AsyncMock + +from agents.gigachat_client import GigaChatClient +from agents.schedule_generator import ScheduleGenerator +from models.gigachat_types import GigaChatMessage, GigaChatResponse, GigaChatUsage, GigaChatChoice +from services.token_manager import TokenManager + + +@pytest.fixture +def mock_gigachat(): + """Фикстура для мокового GigaChat клиента.""" + client = AsyncMock(spec=GigaChatClient) + return client + + +@pytest.fixture +def schedule_generator(mock_gigachat): + """Фикстура для ScheduleGenerator.""" + return ScheduleGenerator(gigachat=mock_gigachat) + + +@pytest.mark.asyncio +async def test_generate_schedule(schedule_generator, mock_gigachat): + """Тест генерации расписания.""" + mock_response_json = """ + { + "title": "Расписание на 2025-12-16", + "tasks": [ + { + "title": "Утренняя зарядка", + "description": "Сделай зарядку", + "duration_minutes": 15, + "category": "утренняя_рутина" + }, + { + "title": "Завтрак", + "description": "Позавтракай", + "duration_minutes": 20, + "category": "утренняя_рутина" + } + ] + } + """ + + mock_gigachat.chat.return_value = mock_response_json + + schedule = await schedule_generator.generate( + child_age=7, + preferences=["рисование", "прогулка"], + date="2025-12-16", + ) + + assert schedule.title == "Расписание на 2025-12-16" + assert len(schedule.tasks) == 2 + assert schedule.tasks[0].title == "Утренняя зарядка" + assert schedule.tasks[0].duration_minutes == 15 + + +@pytest.mark.asyncio +async def test_generate_schedule_with_markdown(schedule_generator, mock_gigachat): + """Тест генерации с markdown в ответе.""" + mock_response_json = """ + ```json + { + "title": "Тестовое расписание", + "tasks": [ + { + "title": "Тест", + "duration_minutes": 10, + "category": "обучение" + } + ] + } + ``` + """ + + mock_gigachat.chat.return_value = mock_response_json + + schedule = await schedule_generator.generate( + child_age=5, + preferences=[], + date="2025-12-17", + ) + + assert schedule.title == "Тестовое расписание" + assert len(schedule.tasks) == 1 +