275 lines
10 KiB
Python
275 lines
10 KiB
Python
import json

from django.db import transaction
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework import viewsets, permissions, status
from rest_framework import serializers
from rest_framework.decorators import action, api_view, permission_classes
from rest_framework.response import Response
from rest_framework_simplejwt.views import TokenObtainPairView

from .models import Brief, TextVariant, Test, Segment, Assignment, ResultEntry, MetricsSnapshot, OptimizationPolicy
from .serializers import (
    RegisterSerializer,
    BriefSerializer,
    TextVariantSerializer,
    TestSerializer,
    SegmentSerializer,
    AssignmentSerializer,
    ResultEntrySerializer,
    MetricsSnapshotSerializer,
    OptimizationPolicySerializer,
)
from .services import agents_generate_texts, agents_analyze, aggregate_test_rows, aggregate_test_rows_by_segment
|
|
|
|
|
|
# --- Public Auth Endpoints ---
|
|
|
|
# OpenAPI shape of a successful registration response.
_REGISTER_RESPONSE = inline_serializer(
    name="RegisterResponse",
    fields={
        "id": serializers.IntegerField(),
        "username": serializers.CharField(),
    },
)


@extend_schema(auth=[], request=RegisterSerializer, responses={200: _REGISTER_RESPONSE})
@api_view(["POST"])
@permission_classes([permissions.AllowAny])
def register(request):
    """Create a user account from the posted registration payload.

    Anyone may call this endpoint (``AllowAny``; ``auth=[]`` in the schema).
    Invalid input makes ``is_valid(raise_exception=True)`` raise, which DRF
    renders as a 400 response.

    Returns:
        200 with the created user's ``id`` and ``username``.
    """
    registration = RegisterSerializer(data=request.data)
    registration.is_valid(raise_exception=True)
    account = registration.save()
    body = {"id": account.id, "username": account.username}
    return Response(body)
|
|
|
|
|
|
class PublicTokenObtainPairView(TokenObtainPairView):
    """JWT obtain-pair endpoint that the OpenAPI schema marks as public.

    Runtime behaviour is inherited unchanged from ``TokenObtainPairView``.
    The ``post`` override exists only to attach ``extend_schema(auth=[])``,
    so the generated schema lists no authentication requirement for it.
    """

    @extend_schema(auth=[])
    def post(self, request, *args, **kwargs):
        # Pure delegation; the decorator above is the whole point of the override.
        inherited_post = super().post
        return inherited_post(request, *args, **kwargs)
|
|
|
|
|
|
# --- Protected API (each viewset sets permission_classes = [IsAuthenticated]) ---
|
|
|
|
class BriefViewSet(viewsets.ModelViewSet):
    """CRUD for the requesting user's briefs, plus text generation.

    Extra routes list segments / assignments / results of the user's tests
    built on a brief, and ``generate`` asks the agents service for text
    variants and stores them as ``TextVariant`` rows.
    """

    serializer_class = BriefSerializer
    permission_classes = [permissions.IsAuthenticated]
    # The effective per-user queryset comes from get_queryset(); this attribute
    # is presumably kept so the router/schema generator can infer the model.
    queryset = Brief.objects.all()

    def get_queryset(self):
        """Only the requesting user's briefs, newest first."""
        return Brief.objects.filter(owner=self.request.user).order_by("-id")

    def perform_create(self, serializer):
        """Stamp the requesting user as the brief's owner."""
        serializer.save(owner=self.request.user)

    # --- helpers -------------------------------------------------------------

    def _brief_test_filter(self, brief):
        """ORM lookup kwargs selecting rows of the user's tests built on ``brief``.

        Assumes ``Test.brief`` is a relation to ``Brief`` (``test.brief`` is
        compared against ``TextVariant.brief`` elsewhere in this module) — confirm.
        Restricting to ``test__owner`` keeps other users' tests off this route.
        """
        return {"test__brief": brief, "test__owner": self.request.user}

    @staticmethod
    def _as_format_list(raw):
        """Normalise ``Brief.formats`` into a list for the agents service.

        A list is returned unchanged. A string holding a JSON array is
        decoded; any other string is split on commas. Blank entries are
        dropped, and anything else (or an empty value) becomes ``[]``.
        """
        if not raw:
            return []
        if isinstance(raw, str):
            text = raw.strip()
            if text.startswith("["):
                try:
                    decoded = json.loads(text)
                except ValueError:
                    decoded = None
                if isinstance(decoded, list):
                    return [x.strip() for x in decoded if isinstance(x, str) and x.strip()]
            return [part.strip() for part in text.split(",") if part.strip()]
        return raw if isinstance(raw, list) else []

    @staticmethod
    def _as_benefit_list(raw):
        """Normalise ``Brief.benefits``: a string becomes a 1-item list, non-lists ``[]``."""
        if not raw:
            return []
        if isinstance(raw, str):
            return [raw]
        return raw if isinstance(raw, list) else []

    @staticmethod
    def _variants_per_format(raw, default=3, upper=10):
        """Coerce ``raw`` to an int clamped to ``[1, upper]``; unparsable or empty -> ``default``."""
        try:
            count = int(raw or default)
        except (TypeError, ValueError):
            count = default
        return max(1, min(count, upper))

    # --- read-only sub-resources --------------------------------------------

    @action(detail=True, methods=["GET"])
    def segments(self, request, pk=None):
        """Segments of the user's tests that use this brief, by id."""
        brief = self.get_object()
        qs = Segment.objects.filter(**self._brief_test_filter(brief)).order_by("id")
        return Response(SegmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def assignments(self, request, pk=None):
        """Segment->variant assignments of the user's tests that use this brief."""
        brief = self.get_object()
        qs = Assignment.objects.filter(**self._brief_test_filter(brief)).order_by("id")
        return Response(AssignmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def results(self, request, pk=None):
        """Result rows of the user's tests that use this brief, newest date first."""
        brief = self.get_object()
        qs = ResultEntry.objects.filter(**self._brief_test_filter(brief)).order_by("-date", "-id")
        return Response(ResultEntrySerializer(qs, many=True).data)

    # --- generation ------------------------------------------------------------

    @action(detail=True, methods=["POST"])
    def generate(self, request, pk=None):
        """Generate text variants for this brief via the agents service.

        Returns 400 when the brief has no formats. Otherwise it calls
        ``agents_generate_texts`` and stores one ``TextVariant`` per returned
        item, all in a single transaction, and responds with
        ``{"created_variant_ids": [...]}``.
        """
        brief = self.get_object()

        formats = self._as_format_list(brief.formats)
        if not formats:
            # Without formats the agents service will return 422. Give a clear error to user.
            return Response(
                {"detail": "Выберите хотя бы один формат в брифе (например: Поисковое объявление)."},
                status=status.HTTP_400_BAD_REQUEST,
            )

        payload = {
            "product": brief.product,
            "audience": brief.audience,
            "usp": brief.usp,
            "benefits": self._as_benefit_list(brief.benefits),
            "constraints": brief.constraints,
            "tone": brief.tone,
            "formats": formats,
            "variants_per_format": self._variants_per_format(brief.variants_per_format),
        }
        res = agents_generate_texts(payload)

        created = []
        # All-or-nothing: a failure on one item must not leave a partial batch.
        with transaction.atomic():
            for block in res.get("variants") or []:
                fmt = block.get("format")
                for item in block.get("items") or []:
                    variant = TextVariant.objects.create(
                        brief=brief,
                        format=fmt,
                        payload=item.get("payload") or {},
                        placement_tips=item.get("placement_tips", ""),
                        expected_effect=item.get("expected_effect", ""),
                    )
                    created.append(variant.id)
        return Response({"created_variant_ids": created})
|
|
|
|
|
|
class TextVariantViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only access to text variants of the requesting user's briefs.

    The list can be narrowed to a single brief with ``?brief=<id>``.
    """

    serializer_class = TextVariantSerializer
    permission_classes = [permissions.IsAuthenticated]
    queryset = TextVariant.objects.all()

    def get_queryset(self):
        """Return the user's variants, newest first, optionally filtered by brief.

        Raises:
            serializers.ValidationError: ``?brief=`` holds a value the ORM
                cannot use as a brief id. Previously the ORM's ValueError
                escaped as a 500.
        """
        qs = TextVariant.objects.filter(brief__owner=self.request.user).order_by("-id")
        brief_id = self.request.query_params.get("brief")
        if brief_id:
            try:
                # Django validates the lookup value inside filter() itself.
                qs = qs.filter(brief_id=brief_id)
            except (TypeError, ValueError):
                raise serializers.ValidationError({"brief": "Invalid brief id."}) from None
        return qs
|
|
|
|
|
|
class TestViewSet(viewsets.ModelViewSet):
    """CRUD for the requesting user's tests plus test-scoped sub-resources.

    Extra ``@action`` routes manage the optimisation policy, segments,
    segment->variant assignments, raw result rows, and the analysis
    snapshot. Every route resolves the test via ``get_object()``, so it only
    sees tests owned by ``request.user``.
    """

    serializer_class = TestSerializer
    permission_classes = [permissions.IsAuthenticated]
    queryset = Test.objects.all()

    # Policy attributes a client may overwrite through the ``policy`` action.
    _POLICY_FIELDS = (
        "mode",
        "primary_kpi",
        "direction",
        "good_threshold",
        "ok_threshold",
        "min_impressions",
        "min_clicks",
        "query_text",
    )
    # KPIs for which a larger value is better; all other KPIs are minimised.
    _MAXIMIZED_KPIS = ("ctr", "cr")

    def get_queryset(self):
        """Only the requesting user's tests, newest first."""
        return Test.objects.filter(owner=self.request.user).order_by("-id")

    def perform_create(self, serializer):
        """Stamp the requesting user as the test's owner."""
        serializer.save(owner=self.request.user)

    # --- helpers -------------------------------------------------------------

    @staticmethod
    def _body_list(request, key):
        """Return ``request.data[key]`` as a list (missing/empty -> ``[]``).

        Raises:
            serializers.ValidationError: the body is not an object or the
                value under ``key`` is not a list.
        """
        data = request.data
        if not isinstance(data, dict):
            raise serializers.ValidationError({"detail": "Expected a JSON object."})
        items = data.get(key) or []
        if not isinstance(items, list):
            raise serializers.ValidationError({key: "Expected a list."})
        return items

    @staticmethod
    def _required(row, key):
        """``row[key]``, turning a missing key into a 400 instead of a 500."""
        try:
            return row[key]
        except KeyError:
            raise serializers.ValidationError({key: "This field is required."}) from None

    @classmethod
    def _resolve_segment_and_variant(cls, test, row):
        """Look up the segment (within ``test``) and variant (within its brief) named by ``row``.

        Raises:
            serializers.ValidationError: ``row`` is not an object, an id is
                missing or malformed, or it does not belong to this test.
        """
        if not isinstance(row, dict):
            raise serializers.ValidationError({"detail": "Each item must be an object."})
        segment_id = cls._required(row, "segment_id")
        variant_id = cls._required(row, "variant_id")
        try:
            segment = Segment.objects.get(id=segment_id, test=test)
        except (Segment.DoesNotExist, TypeError, ValueError):
            raise serializers.ValidationError(
                {"segment_id": f"Segment {segment_id!r} does not belong to this test."}
            ) from None
        try:
            variant = TextVariant.objects.get(id=variant_id, brief=test.brief)
        except (TextVariant.DoesNotExist, TypeError, ValueError):
            raise serializers.ValidationError(
                {"variant_id": f"Variant {variant_id!r} does not belong to this test's brief."}
            ) from None
        return segment, variant

    # --- policy ----------------------------------------------------------------

    @action(detail=True, methods=["GET","POST"])
    def policy(self, request, pk=None):
        """Get or update optimization rules for this test.

        GET returns the test's ``OptimizationPolicy`` and creates a default
        row on first access. POST copies each ``_POLICY_FIELDS`` key present
        in the body onto the policy. If the body has no truthy ``direction``,
        the direction is derived from the KPI (the body's ``primary_kpi``,
        else the stored one, else ``"cpl"``): ``"max"`` for ctr/cr, ``"min"``
        otherwise.

        Raises:
            serializers.ValidationError: the body is not an object, or a value
                cannot be coerced by the model field on save.
        """
        test = self.get_object()
        obj, _ = OptimizationPolicy.objects.get_or_create(test=test)
        if request.method == "GET":
            return Response(OptimizationPolicySerializer(obj).data)

        data = request.data or {}
        if not isinstance(data, dict):
            raise serializers.ValidationError({"detail": "Expected a JSON object."})

        # Normalize direction by KPI (computed from pre-update state, as before).
        kpi = (data.get("primary_kpi") or obj.primary_kpi or "cpl").lower()
        direction = data.get("direction")
        if not direction:
            direction = "max" if kpi in self._MAXIMIZED_KPIS else "min"

        for field in self._POLICY_FIELDS:
            if field in data:
                setattr(obj, field, data.get(field))
        obj.direction = direction

        try:
            obj.save()
        except (TypeError, ValueError) as exc:
            # e.g. a non-numeric threshold; the ORM raises while preparing values.
            raise serializers.ValidationError({"detail": str(exc)}) from exc
        return Response(OptimizationPolicySerializer(obj).data)

    # --- read-only sub-resources ---------------------------------------------

    @action(detail=True, methods=["GET"])
    def segments(self, request, pk=None):
        """This test's segments, ordered by id."""
        test = self.get_object()
        qs = Segment.objects.filter(test=test).order_by("id")
        return Response(SegmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def assignments(self, request, pk=None):
        """This test's segment->variant assignments, ordered by id."""
        test = self.get_object()
        qs = Assignment.objects.filter(test=test).order_by("id")
        return Response(AssignmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def results(self, request, pk=None):
        """This test's result rows, newest date first."""
        test = self.get_object()
        qs = ResultEntry.objects.filter(test=test).order_by("-date", "-id")
        return Response(ResultEntrySerializer(qs, many=True).data)

    # --- write sub-resources -------------------------------------------------
    # Each write action is all-or-nothing: any invalid item yields a 400 and
    # the transaction rolls back, so no partial batch is persisted.

    @action(detail=True, methods=["POST"])
    def add_segments(self, request, pk=None):
        """Create segments from ``{"segments": [{"name", "description"}, ...]}``."""
        test = self.get_object()
        items = self._body_list(request, "segments")
        if not all(isinstance(item, dict) for item in items):
            raise serializers.ValidationError({"segments": "Each segment must be an object."})
        ids = []
        with transaction.atomic():
            for item in items:
                seg = Segment.objects.create(
                    test=test,
                    name=item.get("name", "Segment"),
                    description=item.get("description", ""),
                )
                ids.append(seg.id)
        return Response({"created_segment_ids": ids})

    @action(detail=True, methods=["POST"])
    def assign(self, request, pk=None):
        """Create assignments from ``{"assignments": [{"segment_id", "variant_id"}, ...]}``."""
        test = self.get_object()
        rows = self._body_list(request, "assignments")
        ids = []
        with transaction.atomic():
            for row in rows:
                seg, var = self._resolve_segment_and_variant(test, row)
                obj = Assignment.objects.create(test=test, segment=seg, variant=var)
                ids.append(obj.id)
        return Response({"created_assignment_ids": ids})

    @action(detail=True, methods=["POST"])
    def add_results(self, request, pk=None):
        """Create result rows from ``{"results": [...]}``.

        Each row needs ``segment_id``, ``variant_id`` and ``date``. The
        counters (``impressions``, ``clicks``, ``conversions``, ``leads``)
        default to 0 and ``spend`` to 0.0.
        """
        test = self.get_object()
        rows = self._body_list(request, "results")
        ids = []
        with transaction.atomic():
            for row in rows:
                seg, var = self._resolve_segment_and_variant(test, row)
                obj = ResultEntry.objects.create(
                    test=test,
                    segment=seg,
                    variant=var,
                    date=self._required(row, "date"),
                    impressions=row.get("impressions", 0),
                    clicks=row.get("clicks", 0),
                    conversions=row.get("conversions", 0),
                    leads=row.get("leads", 0),
                    spend=row.get("spend", 0.0),
                )
                ids.append(obj.id)
        return Response({"created_result_ids": ids})

    # --- analysis --------------------------------------------------------------

    @action(detail=True, methods=["POST"])
    def analyze(self, request, pk=None):
        """Run the agents analysis on per-segment aggregates and store the snapshot.

        The stored policy (if the test has one) is sent along, unless the
        body's truthy ``policy`` overrides it. The test's single
        ``MetricsSnapshot`` is created or replaced with the returned
        ``ranking`` and ``recommendations``.
        """
        test = self.get_object()
        policy = None
        if hasattr(test, "policy"):
            policy = OptimizationPolicySerializer(test.policy).data
        # allow overriding policy from request
        if isinstance(request.data, dict) and request.data.get("policy"):
            policy = request.data.get("policy")

        rows = aggregate_test_rows_by_segment(test)
        res = agents_analyze(rows, test.objective, policy=policy)

        snap, _ = MetricsSnapshot.objects.update_or_create(
            test=test,
            defaults={
                "ranking": res.get("ranking", []),
                "recommendations": res.get("recommendations", []),
            },
        )
        return Response(MetricsSnapshotSerializer(snap).data)

    @action(detail=True, methods=["GET"])
    def report(self, request, pk=None):
        """Return the latest analysis snapshot, or 404 if none has been run."""
        test = self.get_object()
        if hasattr(test, "snapshot"):
            return Response(MetricsSnapshotSerializer(test.snapshot).data)
        return Response({"detail": "No analysis yet"}, status=status.HTTP_404_NOT_FOUND)
|