Files
adsassistant/backend_django/adsassistant_backend/api/views.py
2026-03-05 06:55:42 +03:00

275 lines
10 KiB
Python

from rest_framework import viewsets, permissions, status
from rest_framework.decorators import action, api_view, permission_classes
from rest_framework.response import Response
from rest_framework_simplejwt.views import TokenObtainPairView
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework import serializers
from .models import Brief, TextVariant, Test, Segment, Assignment, ResultEntry, MetricsSnapshot, OptimizationPolicy
from .serializers import (
RegisterSerializer,
BriefSerializer,
TextVariantSerializer,
TestSerializer,
SegmentSerializer,
AssignmentSerializer,
ResultEntrySerializer,
MetricsSnapshotSerializer,
OptimizationPolicySerializer,
)
from .services import agents_generate_texts, agents_analyze, aggregate_test_rows, aggregate_test_rows_by_segment
# --- Public Auth Endpoints ---
@extend_schema(
    auth=[],
    request=RegisterSerializer,
    responses={
        200: inline_serializer(
            name="RegisterResponse",
            fields={
                "id": serializers.IntegerField(),
                "username": serializers.CharField(),
            },
        )
    },
)
@api_view(["POST"])
@permission_classes([permissions.AllowAny])
def register(request):
    """Create a user account from the posted registration data.

    Open to anonymous callers. Input is validated by ``RegisterSerializer``
    with ``raise_exception=True``, so invalid data raises a DRF
    ``ValidationError`` (HTTP 400). On success, responds with the new
    user's ``id`` and ``username``.
    """
    registration = RegisterSerializer(data=request.data)
    registration.is_valid(raise_exception=True)
    new_user = registration.save()
    return Response(dict(id=new_user.id, username=new_user.username))
class PublicTokenObtainPairView(TokenObtainPairView):
    """JWT obtain-pair endpoint whose OpenAPI schema declares no auth.

    Request handling is inherited unchanged from ``TokenObtainPairView``;
    this override exists only to attach ``extend_schema(auth=[])`` so
    Swagger shows the endpoint as public.
    """

    @extend_schema(auth=[])
    def post(self, request, *args, **kwargs):
        parent_post = super().post
        return parent_post(request, *args, **kwargs)
# --- Protected API (authentication required: every viewset below sets permission_classes = [IsAuthenticated]) ---
class BriefViewSet(viewsets.ModelViewSet):
    """CRUD for the requesting user's briefs, plus brief-level actions.

    Briefs are always scoped to ``request.user`` (see ``get_queryset``), so
    ``get_object()`` in the actions below only ever returns the caller's own
    brief.

    Actions:
        segments / assignments / results (GET): rows belonging to the
            caller's tests that were created from this brief.
        generate (POST): ask the agents service for ad texts and store each
            returned item as a ``TextVariant``.
    """

    serializer_class = BriefSerializer
    permission_classes = [permissions.IsAuthenticated]
    queryset = Brief.objects.all()

    def get_queryset(self):
        return Brief.objects.filter(owner=self.request.user).order_by("-id")

    def perform_create(self, serializer):
        # The owner always comes from the authenticated user, never the payload.
        serializer.save(owner=self.request.user)

    def _rows_for_brief(self, model):
        """Return ``model`` rows whose test uses this brief and belongs to the caller.

        ``Segment``/``Assignment``/``ResultEntry`` reference a ``Test``, not a
        ``Brief``. Filtering them with ``test=<Brief>`` is rejected by Django,
        so they are reached through ``test__brief`` instead. ``test__owner``
        keeps the result consistent with ``TestViewSet``'s per-owner scoping.
        """
        brief = self.get_object()
        return model.objects.filter(test__brief=brief, test__owner=self.request.user)

    @action(detail=True, methods=["GET"])
    def segments(self, request, pk=None):
        """List segments of the caller's tests that use this brief."""
        qs = self._rows_for_brief(Segment).order_by("id")
        return Response(SegmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def assignments(self, request, pk=None):
        """List segment/variant assignments of the caller's tests that use this brief."""
        qs = self._rows_for_brief(Assignment).order_by("id")
        return Response(AssignmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def results(self, request, pk=None):
        """List result rows (newest date first) of the caller's tests that use this brief."""
        qs = self._rows_for_brief(ResultEntry).order_by("-date", "-id")
        return Response(ResultEntrySerializer(qs, many=True).data)

    @staticmethod
    def _normalize_formats(raw):
        """Coerce a brief's ``formats`` value into a list of formats.

        Accepts a list (returned as-is), a JSON array string such as
        ``'["a", "b"]'``, or a comma-separated string. A string that is not a
        JSON array falls back to the comma split. Falsy or any other type
        yields ``[]``.
        """
        if not raw:
            return []
        if isinstance(raw, str):
            text = raw.strip()
            if text.startswith("["):
                import json

                try:
                    parsed = json.loads(text)
                except ValueError:  # json.JSONDecodeError subclasses ValueError
                    parsed = None
                if isinstance(parsed, list):
                    return parsed
            return [part.strip() for part in raw.split(",") if part.strip()]
        return raw if isinstance(raw, list) else []

    @staticmethod
    def _normalize_benefits(raw):
        """Coerce a brief's ``benefits`` into a list; a single string becomes one item."""
        if not raw:
            return []
        if isinstance(raw, str):
            return [raw]
        return raw if isinstance(raw, list) else []

    @action(detail=True, methods=["POST"])
    def generate(self, request, pk=None):
        """Generate ad texts for this brief via the agents service.

        Returns ``{"created_variant_ids": [...]}`` with the ids of the newly
        created ``TextVariant`` rows. Returns a 400 if the brief has no formats.
        """
        brief = self.get_object()
        formats = self._normalize_formats(brief.formats)
        if not formats:
            # Without formats the agents service will return 422. Give a clear error to user.
            return Response({"detail": "Выберите хотя бы один формат в брифе (например: Поисковое объявление)."}, status=status.HTTP_400_BAD_REQUEST)
        payload = {
            "product": brief.product,
            "audience": brief.audience,
            "usp": brief.usp,
            "benefits": self._normalize_benefits(brief.benefits),
            "constraints": brief.constraints,
            "tone": brief.tone,
            "formats": formats,
            # Clamp to 1..10, defaulting to 3 when unset.
            "variants_per_format": max(1, min(int(brief.variants_per_format or 3), 10)),
        }
        res = agents_generate_texts(payload)
        created = []
        # `or []` also tolerates explicit nulls in the agents response.
        for block in res.get("variants") or []:
            fmt = block.get("format")
            for item in block.get("items") or []:
                variant = TextVariant.objects.create(
                    brief=brief,
                    format=fmt,
                    payload=item.get("payload") or {},
                    placement_tips=item.get("placement_tips", ""),
                    expected_effect=item.get("expected_effect", ""),
                )
                created.append(variant.id)
        return Response({"created_variant_ids": created})
class TextVariantViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only access to text variants of the requesting user's briefs.

    The optional ``?brief=<id>`` query parameter narrows results to a single
    brief. An id Django cannot coerce to the pk type yields HTTP 400.
    """

    serializer_class = TextVariantSerializer
    permission_classes = [permissions.IsAuthenticated]
    queryset = TextVariant.objects.all()

    def get_queryset(self):
        qs = TextVariant.objects.filter(brief__owner=self.request.user).order_by("-id")
        brief_id = self.request.query_params.get("brief")
        if brief_id:
            try:
                qs = qs.filter(brief_id=brief_id)
            except (ValueError, TypeError):
                # Django raises while building the lookup when the value cannot
                # be coerced (e.g. ?brief=abc for an integer pk). Report a client
                # error instead of letting it surface as a 500.
                raise serializers.ValidationError({"brief": "Invalid brief id."}) from None
        return qs
class TestViewSet(viewsets.ModelViewSet):
    """CRUD for the requesting user's tests, plus the test workflow actions.

    Tests are scoped to ``request.user`` (see ``get_queryset``), so
    ``get_object()`` in every action only ever returns the caller's own test.
    The write actions reject malformed bodies and unknown references with
    HTTP 400, and they validate every row before writing any of them.
    """

    serializer_class = TestSerializer
    permission_classes = [permissions.IsAuthenticated]
    queryset = Test.objects.all()

    # Policy attributes that a POST to ``policy`` copies from the body when present.
    _POLICY_FIELDS = (
        "mode",
        "primary_kpi",
        "direction",
        "good_threshold",
        "ok_threshold",
        "min_impressions",
        "min_clicks",
        "query_text",
    )

    def get_queryset(self):
        return Test.objects.filter(owner=self.request.user).order_by("-id")

    def perform_create(self, serializer):
        # The owner always comes from the authenticated user, never the payload.
        serializer.save(owner=self.request.user)

    # --- request-body helpers -------------------------------------------

    @staticmethod
    def _body_list(request, key):
        """Return ``request.data[key]`` as a list of dicts (``[]`` if absent or empty).

        Raises a DRF ``ValidationError`` (HTTP 400) if the body is not an
        object, or if ``key`` does not hold a list of objects.
        """
        data = request.data
        if not isinstance(data, dict):
            raise serializers.ValidationError("Expected a JSON object body.")
        items = data.get(key) or []
        if not isinstance(items, list) or not all(isinstance(item, dict) for item in items):
            raise serializers.ValidationError({key: "Expected a list of objects."})
        return items

    @staticmethod
    def _lookup_scoped(model, pk, **scope):
        """Return the ``model`` row with ``id=pk`` that matches ``scope``, or None.

        Also returns None when ``pk`` is missing, or when Django cannot coerce
        it to the pk type (it raises ValueError/TypeError while building the
        lookup).
        """
        if pk is None:
            return None
        try:
            return model.objects.get(id=pk, **scope)
        except (model.DoesNotExist, ValueError, TypeError):
            return None

    def _resolve_segment_and_variant(self, test, row, index, key):
        """Resolve the segment and variant referenced by one body row.

        The segment must belong to ``test`` and the variant to ``test.brief``.
        Otherwise raises a 400 that names the offending item.
        """
        segment = self._lookup_scoped(Segment, row.get("segment_id"), test=test)
        if segment is None:
            raise serializers.ValidationError(
                {key: f"Item {index}: segment_id is missing or not a segment of this test."}
            )
        variant = self._lookup_scoped(TextVariant, row.get("variant_id"), brief=test.brief)
        if variant is None:
            raise serializers.ValidationError(
                {key: f"Item {index}: variant_id is missing or not a variant of this test's brief."}
            )
        return segment, variant

    # --- actions -----------------------------------------------------------

    @action(detail=True, methods=["GET", "POST"])
    def policy(self, request, pk=None):
        """Get or update optimization rules for this test.

        GET returns the policy, creating it via ``get_or_create`` on first
        access. POST copies any of ``_POLICY_FIELDS`` present in the body onto
        the policy. When the body gives no ``direction``, it is derived from
        the KPI: "max" for ctr/cr, "min" otherwise.
        """
        test = self.get_object()
        obj, _ = OptimizationPolicy.objects.get_or_create(test=test)
        if request.method == "GET":
            return Response(OptimizationPolicySerializer(obj).data)
        data = request.data or {}
        if not isinstance(data, dict):
            raise serializers.ValidationError("Expected a JSON object body.")
        # Normalize direction by KPI: ctr/cr are maximized; everything else
        # (including the "cpl" fallback) is minimized.
        kpi = str(data.get("primary_kpi") or obj.primary_kpi or "cpl").lower()
        direction = data.get("direction") or ("max" if kpi in ("ctr", "cr") else "min")
        for field in self._POLICY_FIELDS:
            if field in data:
                setattr(obj, field, data.get(field))
        obj.direction = direction
        obj.save()
        return Response(OptimizationPolicySerializer(obj).data)

    @action(detail=True, methods=["GET"])
    def segments(self, request, pk=None):
        """List this test's segments in creation (id) order."""
        test = self.get_object()
        qs = Segment.objects.filter(test=test).order_by("id")
        return Response(SegmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def assignments(self, request, pk=None):
        """List this test's segment/variant assignments in id order."""
        test = self.get_object()
        qs = Assignment.objects.filter(test=test).order_by("id")
        return Response(AssignmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def results(self, request, pk=None):
        """List this test's result rows, newest date first."""
        test = self.get_object()
        qs = ResultEntry.objects.filter(test=test).order_by("-date", "-id")
        return Response(ResultEntrySerializer(qs, many=True).data)

    @action(detail=True, methods=["POST"])
    def add_segments(self, request, pk=None):
        """Create segments for this test.

        Body: ``{"segments": [{"name": ..., "description": ...}, ...]}``.
        ``name`` defaults to "Segment" and ``description`` to "".
        """
        test = self.get_object()
        specs = self._body_list(request, "segments")
        ids = []
        for spec in specs:
            seg = Segment.objects.create(
                test=test,
                name=spec.get("name", "Segment"),
                description=spec.get("description", ""),
            )
            ids.append(seg.id)
        return Response({"created_segment_ids": ids})

    @action(detail=True, methods=["POST"])
    def assign(self, request, pk=None):
        """Assign text variants to segments of this test.

        Body: ``{"assignments": [{"segment_id": ..., "variant_id": ...}, ...]}``.
        Every row is resolved before any is written, so one bad row yields a
        400 and creates nothing.
        """
        test = self.get_object()
        rows = self._body_list(request, "assignments")
        pairs = [
            self._resolve_segment_and_variant(test, row, index, "assignments")
            for index, row in enumerate(rows)
        ]
        ids = []
        for segment, variant in pairs:
            obj = Assignment.objects.create(test=test, segment=segment, variant=variant)
            ids.append(obj.id)
        return Response({"created_assignment_ids": ids})

    @action(detail=True, methods=["POST"])
    def add_results(self, request, pk=None):
        """Record metrics rows for this test.

        Body: ``{"results": [{"segment_id", "variant_id", "date", ...}, ...]}``.
        ``impressions``/``clicks``/``conversions``/``leads`` default to 0 and
        ``spend`` defaults to 0.0. Segment/variant references and the presence
        of ``date`` are checked for every row before any row is written.
        """
        test = self.get_object()
        rows = self._body_list(request, "results")
        prepared = []
        for index, row in enumerate(rows):
            segment, variant = self._resolve_segment_and_variant(test, row, index, "results")
            if "date" not in row:
                raise serializers.ValidationError({"results": f"Item {index}: date is required."})
            prepared.append(
                {
                    "segment": segment,
                    "variant": variant,
                    "date": row["date"],
                    "impressions": row.get("impressions", 0),
                    "clicks": row.get("clicks", 0),
                    "conversions": row.get("conversions", 0),
                    "leads": row.get("leads", 0),
                    "spend": row.get("spend", 0.0),
                }
            )
        # NOTE(review): there is no transaction here. A database-level failure
        # part-way through this loop would still leave the earlier rows saved.
        ids = []
        for fields in prepared:
            ids.append(ResultEntry.objects.create(test=test, **fields).id)
        return Response({"created_result_ids": ids})

    @action(detail=True, methods=["POST"])
    def analyze(self, request, pk=None):
        """Run the agents-service analysis over this test's per-segment rows.

        Uses the stored policy unless the body supplies a ``policy`` override.
        The returned ranking and recommendations are stored with
        ``update_or_create``, so each test keeps a single ``MetricsSnapshot``.
        """
        test = self.get_object()
        policy = None
        # `policy` is presumably a reverse one-to-one accessor; hasattr()
        # guards the case where no policy row exists yet.
        if hasattr(test, "policy"):
            policy = OptimizationPolicySerializer(test.policy).data
        # Allow overriding the stored policy from the request body.
        if isinstance(request.data, dict) and request.data.get("policy"):
            policy = request.data.get("policy")
        rows = aggregate_test_rows_by_segment(test)
        res = agents_analyze(rows, test.objective, policy=policy)
        snap, _ = MetricsSnapshot.objects.update_or_create(
            test=test,
            defaults={
                "ranking": res.get("ranking", []),
                "recommendations": res.get("recommendations", []),
            },
        )
        return Response(MetricsSnapshotSerializer(snap).data)

    @action(detail=True, methods=["GET"])
    def report(self, request, pk=None):
        """Return the latest analysis snapshot for this test, or 404 if none exists."""
        test = self.get_object()
        if hasattr(test, "snapshot"):
            return Response(MetricsSnapshotSerializer(test.snapshot).data)
        return Response({"detail": "No analysis yet"}, status=status.HTTP_404_NOT_FOUND)