Files
New-planet-ai-agent/agents/schedule_generator.py

169 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Генератор расписаний с использованием GigaChat."""
import json
from typing import List, Optional
from models.gigachat_types import GigaChatMessage
from models.schedule import Schedule, Task
from prompts.schedule_prompts import SCHEDULE_GENERATION_PROMPT
from agents.gigachat_client import GigaChatClient
class ScheduleGenerator:
"""Генератор расписаний для детей с РАС."""
def __init__(self, gigachat: GigaChatClient):
self.gigachat = gigachat
async def generate(
self,
child_age: int,
preferences: List[str],
date: str,
existing_tasks: Optional[List[str]] = None,
model: str = "GigaChat-2-Pro",
) -> Schedule:
"""
Сгенерировать расписание.
Args:
child_age: Возраст ребенка
preferences: Предпочтения ребенка
date: Дата расписания
existing_tasks: Существующие задания для учета
model: Модель GigaChat
Returns:
Объект расписания
"""
preferences_str = ", ".join(preferences) if preferences else "не указаны"
prompt = SCHEDULE_GENERATION_PROMPT.format(
age=child_age,
preferences=preferences_str,
date=date,
)
if existing_tasks:
prompt += f"\n\nУчти существующие задания: {', '.join(existing_tasks)}"
# Используем более высокую температуру для разнообразия
response_text = await self.gigachat.chat(
message=prompt,
model=model,
temperature=0.8,
max_tokens=3000,
)
# Парсим JSON из ответа
schedule_data = self._parse_json_response(response_text)
# Создаем объект Schedule
tasks = [
Task(
title=task_data["title"],
description=task_data.get("description"),
duration_minutes=task_data["duration_minutes"],
category=task_data.get("category", "обучение"),
)
for task_data in schedule_data.get("tasks", [])
]
return Schedule(
title=schedule_data.get("title", f"Расписание на {date}"),
date=date,
tasks=tasks,
)
async def update(
self,
existing_schedule: Schedule,
user_request: str,
model: str = "GigaChat-2-Pro",
) -> Schedule:
"""
Обновить существующее расписание.
Args:
existing_schedule: Текущее расписание
user_request: Запрос на изменение
model: Модель GigaChat
Returns:
Обновленное расписание
"""
from prompts.schedule_prompts import SCHEDULE_UPDATE_PROMPT
schedule_json = existing_schedule.model_dump_json()
prompt = SCHEDULE_UPDATE_PROMPT.format(
existing_schedule=schedule_json,
user_request=user_request,
)
response_text = await self.gigachat.chat(
message=prompt,
model=model,
temperature=0.7,
max_tokens=3000,
)
schedule_data = self._parse_json_response(response_text)
tasks = [
Task(
title=task_data["title"],
description=task_data.get("description"),
duration_minutes=task_data["duration_minutes"],
category=task_data.get("category", "обучение"),
)
for task_data in schedule_data.get("tasks", [])
]
return Schedule(
id=existing_schedule.id,
title=schedule_data.get("title", existing_schedule.title),
date=existing_schedule.date,
tasks=tasks,
user_id=existing_schedule.user_id,
)
def _parse_json_response(self, response_text: str) -> dict:
"""
Извлечь JSON из ответа модели.
Args:
response_text: Текст ответа
Returns:
Распарсенный JSON
"""
# Пытаемся найти JSON в ответе
response_text = response_text.strip()
# Удаляем markdown код блоки если есть
if response_text.startswith("```json"):
response_text = response_text[7:]
if response_text.startswith("```"):
response_text = response_text[3:]
if response_text.endswith("```"):
response_text = response_text[:-3]
response_text = response_text.strip()
try:
return json.loads(response_text)
except json.JSONDecodeError:
# Если не удалось распарсить, пытаемся найти JSON объект в тексте
start_idx = response_text.find("{")
end_idx = response_text.rfind("}") + 1
if start_idx >= 0 and end_idx > start_idx:
try:
return json.loads(response_text[start_idx:end_idx])
except json.JSONDecodeError:
pass
raise ValueError(f"Не удалось распарсить JSON из ответа: {response_text[:200]}")