MVP0
This commit is contained in:
274
backend_django/adsassistant_backend/api/views.py
Normal file
274
backend_django/adsassistant_backend/api/views.py
Normal file
@@ -0,0 +1,274 @@
|
||||
from rest_framework import viewsets, permissions, status
|
||||
from rest_framework.decorators import action, api_view, permission_classes
|
||||
from rest_framework.response import Response
|
||||
|
||||
from rest_framework_simplejwt.views import TokenObtainPairView
|
||||
|
||||
from drf_spectacular.utils import extend_schema, inline_serializer
|
||||
from rest_framework import serializers
|
||||
|
||||
from .models import Brief, TextVariant, Test, Segment, Assignment, ResultEntry, MetricsSnapshot, OptimizationPolicy
|
||||
from .serializers import (
|
||||
RegisterSerializer,
|
||||
BriefSerializer,
|
||||
TextVariantSerializer,
|
||||
TestSerializer,
|
||||
SegmentSerializer,
|
||||
AssignmentSerializer,
|
||||
ResultEntrySerializer,
|
||||
MetricsSnapshotSerializer,
|
||||
OptimizationPolicySerializer,
|
||||
)
|
||||
from .services import agents_generate_texts, agents_analyze, aggregate_test_rows, aggregate_test_rows_by_segment
|
||||
|
||||
|
||||
# --- Public Auth Endpoints ---
|
||||
|
||||
@extend_schema(
|
||||
auth=[],
|
||||
request=RegisterSerializer,
|
||||
responses={
|
||||
200: inline_serializer(
|
||||
name="RegisterResponse",
|
||||
fields={
|
||||
"id": serializers.IntegerField(),
|
||||
"username": serializers.CharField(),
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
@api_view(["POST"])
|
||||
@permission_classes([permissions.AllowAny])
|
||||
def register(request):
|
||||
ser = RegisterSerializer(data=request.data)
|
||||
ser.is_valid(raise_exception=True)
|
||||
user = ser.save()
|
||||
return Response({"id": user.id, "username": user.username})
|
||||
|
||||
|
||||
class PublicTokenObtainPairView(TokenObtainPairView):
|
||||
"""Same JWT token endpoint, but marked as public for Swagger (auth=[])."""
|
||||
|
||||
@extend_schema(auth=[])
|
||||
def post(self, request, *args, **kwargs):
|
||||
return super().post(request, *args, **kwargs)
|
||||
|
||||
|
||||
# --- Protected API (JWT required by DEFAULT_PERMISSION_CLASSES) ---
|
||||
|
||||
class BriefViewSet(viewsets.ModelViewSet):
|
||||
serializer_class = BriefSerializer
|
||||
permission_classes = [permissions.IsAuthenticated]
|
||||
queryset = Brief.objects.all()
|
||||
|
||||
def get_queryset(self):
|
||||
return Brief.objects.filter(owner=self.request.user).order_by("-id")
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save(owner=self.request.user)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def segments(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = Segment.objects.filter(test=test).order_by("id")
|
||||
return Response(SegmentSerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def assignments(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = Assignment.objects.filter(test=test).order_by("id")
|
||||
return Response(AssignmentSerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def results(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = ResultEntry.objects.filter(test=test).order_by("-date", "-id")
|
||||
return Response(ResultEntrySerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["POST"])
|
||||
def generate(self, request, pk=None):
|
||||
brief = self.get_object()
|
||||
# Build payload for agents service and ensure types are correct for validation
|
||||
formats = brief.formats or []
|
||||
if isinstance(formats, str):
|
||||
# try to parse comma-separated / json-like strings
|
||||
formats = [x.strip() for x in formats.split(",") if x.strip()]
|
||||
if not isinstance(formats, list):
|
||||
formats = []
|
||||
|
||||
benefits = brief.benefits or []
|
||||
if isinstance(benefits, str):
|
||||
benefits = [benefits]
|
||||
if not isinstance(benefits, list):
|
||||
benefits = []
|
||||
|
||||
if not formats:
|
||||
# Without formats the agents service will return 422. Give a clear error to user.
|
||||
return Response({"detail": "Выберите хотя бы один формат в брифе (например: Поисковое объявление)."}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
payload = {
|
||||
"product": brief.product,
|
||||
"audience": brief.audience,
|
||||
"usp": brief.usp,
|
||||
"benefits": benefits,
|
||||
"constraints": brief.constraints,
|
||||
"tone": brief.tone,
|
||||
"formats": formats,
|
||||
"variants_per_format": max(1, min(int(brief.variants_per_format or 3), 10)),
|
||||
}
|
||||
res = agents_generate_texts(payload)
|
||||
|
||||
created = []
|
||||
for block in res.get("variants", []):
|
||||
fmt = block.get("format")
|
||||
for item in block.get("items", []):
|
||||
v = TextVariant.objects.create(
|
||||
brief=brief,
|
||||
format=fmt,
|
||||
payload=item.get("payload") or {},
|
||||
placement_tips=item.get("placement_tips", ""),
|
||||
expected_effect=item.get("expected_effect", ""),
|
||||
)
|
||||
created.append(v.id)
|
||||
return Response({"created_variant_ids": created})
|
||||
|
||||
|
||||
class TextVariantViewSet(viewsets.ReadOnlyModelViewSet):
|
||||
serializer_class = TextVariantSerializer
|
||||
permission_classes = [permissions.IsAuthenticated]
|
||||
queryset = TextVariant.objects.all()
|
||||
|
||||
def get_queryset(self):
|
||||
qs = TextVariant.objects.filter(brief__owner=self.request.user).order_by("-id")
|
||||
brief_id = self.request.query_params.get("brief")
|
||||
if brief_id:
|
||||
qs = qs.filter(brief_id=brief_id)
|
||||
return qs
|
||||
|
||||
|
||||
class TestViewSet(viewsets.ModelViewSet):
|
||||
serializer_class = TestSerializer
|
||||
permission_classes = [permissions.IsAuthenticated]
|
||||
queryset = Test.objects.all()
|
||||
|
||||
def get_queryset(self):
|
||||
return Test.objects.filter(owner=self.request.user).order_by("-id")
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save(owner=self.request.user)
|
||||
|
||||
|
||||
@action(detail=True, methods=["GET","POST"])
|
||||
def policy(self, request, pk=None):
|
||||
"""Get or update optimization rules for this test."""
|
||||
test = self.get_object()
|
||||
obj, _ = OptimizationPolicy.objects.get_or_create(test=test)
|
||||
if request.method == "GET":
|
||||
return Response(OptimizationPolicySerializer(obj).data)
|
||||
|
||||
data = request.data or {}
|
||||
# Normalize direction by KPI
|
||||
kpi = (data.get("primary_kpi") or obj.primary_kpi or "cpl").lower()
|
||||
direction = data.get("direction")
|
||||
if not direction:
|
||||
direction = "max" if kpi in ("ctr","cr") else "min"
|
||||
for k in ("mode","primary_kpi","direction","good_threshold","ok_threshold","min_impressions","min_clicks","query_text"):
|
||||
if k in data:
|
||||
setattr(obj, k, data.get(k))
|
||||
obj.direction = direction
|
||||
obj.save()
|
||||
return Response(OptimizationPolicySerializer(obj).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def segments(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = Segment.objects.filter(test=test).order_by("id")
|
||||
return Response(SegmentSerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def assignments(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = Assignment.objects.filter(test=test).order_by("id")
|
||||
return Response(AssignmentSerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def results(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = ResultEntry.objects.filter(test=test).order_by("-date", "-id")
|
||||
return Response(ResultEntrySerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["POST"])
|
||||
def add_segments(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
segments = request.data.get("segments") or []
|
||||
ids = []
|
||||
for s in segments:
|
||||
seg = Segment.objects.create(
|
||||
test=test,
|
||||
name=s.get("name", "Segment"),
|
||||
description=s.get("description", ""),
|
||||
)
|
||||
ids.append(seg.id)
|
||||
return Response({"created_segment_ids": ids})
|
||||
|
||||
@action(detail=True, methods=["POST"])
|
||||
def assign(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
assignments = request.data.get("assignments") or []
|
||||
ids = []
|
||||
for a in assignments:
|
||||
seg = Segment.objects.get(id=a["segment_id"], test=test)
|
||||
var = TextVariant.objects.get(id=a["variant_id"], brief=test.brief)
|
||||
obj = Assignment.objects.create(test=test, segment=seg, variant=var)
|
||||
ids.append(obj.id)
|
||||
return Response({"created_assignment_ids": ids})
|
||||
|
||||
@action(detail=True, methods=["POST"])
|
||||
def add_results(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
rows = request.data.get("results") or []
|
||||
ids = []
|
||||
for r in rows:
|
||||
seg = Segment.objects.get(id=r["segment_id"], test=test)
|
||||
var = TextVariant.objects.get(id=r["variant_id"], brief=test.brief)
|
||||
obj = ResultEntry.objects.create(
|
||||
test=test,
|
||||
segment=seg,
|
||||
variant=var,
|
||||
date=r["date"],
|
||||
impressions=r.get("impressions", 0),
|
||||
clicks=r.get("clicks", 0),
|
||||
conversions=r.get("conversions", 0),
|
||||
leads=r.get("leads", 0),
|
||||
spend=r.get("spend", 0.0),
|
||||
)
|
||||
ids.append(obj.id)
|
||||
return Response({"created_result_ids": ids})
|
||||
|
||||
@action(detail=True, methods=["POST"])
|
||||
def analyze(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
policy = None
|
||||
if hasattr(test, "policy"):
|
||||
policy = OptimizationPolicySerializer(test.policy).data
|
||||
# allow overriding policy from request
|
||||
if isinstance(request.data, dict) and request.data.get("policy"):
|
||||
policy = request.data.get("policy")
|
||||
rows = aggregate_test_rows_by_segment(test)
|
||||
res = agents_analyze(rows, test.objective, policy=policy)
|
||||
|
||||
snap, _ = MetricsSnapshot.objects.update_or_create(
|
||||
test=test,
|
||||
defaults={
|
||||
"ranking": res.get("ranking", []),
|
||||
"recommendations": res.get("recommendations", []),
|
||||
},
|
||||
)
|
||||
return Response(MetricsSnapshotSerializer(snap).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def report(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
if hasattr(test, "snapshot"):
|
||||
return Response(MetricsSnapshotSerializer(test.snapshot).data)
|
||||
return Response({"detail": "No analysis yet"}, status=status.HTTP_404_NOT_FOUND)
|
||||
Reference in New Issue
Block a user