169 lines
5.5 KiB
Python
169 lines
5.5 KiB
Python
"""Генератор расписаний с использованием GigaChat."""
|
||
import json
|
||
from typing import List, Optional
|
||
|
||
from models.gigachat_types import GigaChatMessage
|
||
from models.schedule import Schedule, Task
|
||
from prompts.schedule_prompts import SCHEDULE_GENERATION_PROMPT
|
||
|
||
from agents.gigachat_client import GigaChatClient
|
||
|
||
|
||
class ScheduleGenerator:
|
||
"""Генератор расписаний для детей с РАС."""
|
||
|
||
def __init__(self, gigachat: GigaChatClient):
|
||
self.gigachat = gigachat
|
||
|
||
async def generate(
|
||
self,
|
||
child_age: int,
|
||
preferences: List[str],
|
||
date: str,
|
||
existing_tasks: Optional[List[str]] = None,
|
||
model: str = "GigaChat-2-Pro",
|
||
) -> Schedule:
|
||
"""
|
||
Сгенерировать расписание.
|
||
|
||
Args:
|
||
child_age: Возраст ребенка
|
||
preferences: Предпочтения ребенка
|
||
date: Дата расписания
|
||
existing_tasks: Существующие задания для учета
|
||
model: Модель GigaChat
|
||
|
||
Returns:
|
||
Объект расписания
|
||
"""
|
||
preferences_str = ", ".join(preferences) if preferences else "не указаны"
|
||
|
||
prompt = SCHEDULE_GENERATION_PROMPT.format(
|
||
age=child_age,
|
||
preferences=preferences_str,
|
||
date=date,
|
||
)
|
||
|
||
if existing_tasks:
|
||
prompt += f"\n\nУчти существующие задания: {', '.join(existing_tasks)}"
|
||
|
||
# Используем более высокую температуру для разнообразия
|
||
response_text = await self.gigachat.chat(
|
||
message=prompt,
|
||
model=model,
|
||
temperature=0.8,
|
||
max_tokens=3000,
|
||
)
|
||
|
||
# Парсим JSON из ответа
|
||
schedule_data = self._parse_json_response(response_text)
|
||
|
||
# Создаем объект Schedule
|
||
tasks = [
|
||
Task(
|
||
title=task_data["title"],
|
||
description=task_data.get("description"),
|
||
duration_minutes=task_data["duration_minutes"],
|
||
category=task_data.get("category", "обучение"),
|
||
)
|
||
for task_data in schedule_data.get("tasks", [])
|
||
]
|
||
|
||
return Schedule(
|
||
title=schedule_data.get("title", f"Расписание на {date}"),
|
||
date=date,
|
||
tasks=tasks,
|
||
)
|
||
|
||
async def update(
|
||
self,
|
||
existing_schedule: Schedule,
|
||
user_request: str,
|
||
model: str = "GigaChat-2-Pro",
|
||
) -> Schedule:
|
||
"""
|
||
Обновить существующее расписание.
|
||
|
||
Args:
|
||
existing_schedule: Текущее расписание
|
||
user_request: Запрос на изменение
|
||
model: Модель GigaChat
|
||
|
||
Returns:
|
||
Обновленное расписание
|
||
"""
|
||
from prompts.schedule_prompts import SCHEDULE_UPDATE_PROMPT
|
||
|
||
schedule_json = existing_schedule.model_dump_json()
|
||
|
||
prompt = SCHEDULE_UPDATE_PROMPT.format(
|
||
existing_schedule=schedule_json,
|
||
user_request=user_request,
|
||
)
|
||
|
||
response_text = await self.gigachat.chat(
|
||
message=prompt,
|
||
model=model,
|
||
temperature=0.7,
|
||
max_tokens=3000,
|
||
)
|
||
|
||
schedule_data = self._parse_json_response(response_text)
|
||
|
||
tasks = [
|
||
Task(
|
||
title=task_data["title"],
|
||
description=task_data.get("description"),
|
||
duration_minutes=task_data["duration_minutes"],
|
||
category=task_data.get("category", "обучение"),
|
||
)
|
||
for task_data in schedule_data.get("tasks", [])
|
||
]
|
||
|
||
return Schedule(
|
||
id=existing_schedule.id,
|
||
title=schedule_data.get("title", existing_schedule.title),
|
||
date=existing_schedule.date,
|
||
tasks=tasks,
|
||
user_id=existing_schedule.user_id,
|
||
)
|
||
|
||
def _parse_json_response(self, response_text: str) -> dict:
|
||
"""
|
||
Извлечь JSON из ответа модели.
|
||
|
||
Args:
|
||
response_text: Текст ответа
|
||
|
||
Returns:
|
||
Распарсенный JSON
|
||
"""
|
||
# Пытаемся найти JSON в ответе
|
||
response_text = response_text.strip()
|
||
|
||
# Удаляем markdown код блоки если есть
|
||
if response_text.startswith("```json"):
|
||
response_text = response_text[7:]
|
||
if response_text.startswith("```"):
|
||
response_text = response_text[3:]
|
||
if response_text.endswith("```"):
|
||
response_text = response_text[:-3]
|
||
|
||
response_text = response_text.strip()
|
||
|
||
try:
|
||
return json.loads(response_text)
|
||
except json.JSONDecodeError:
|
||
# Если не удалось распарсить, пытаемся найти JSON объект в тексте
|
||
start_idx = response_text.find("{")
|
||
end_idx = response_text.rfind("}") + 1
|
||
|
||
if start_idx >= 0 and end_idx > start_idx:
|
||
try:
|
||
return json.loads(response_text[start_idx:end_idx])
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
raise ValueError(f"Не удалось распарсить JSON из ответа: {response_text[:200]}")
|
||
|