"""API views: public auth endpoints plus per-user Brief, TextVariant and Test resources."""
import json

from django.db import transaction
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework import permissions, serializers, status, viewsets
from rest_framework.decorators import action, api_view, permission_classes
from rest_framework.generics import get_object_or_404
from rest_framework.response import Response
from rest_framework_simplejwt.views import TokenObtainPairView

from .models import (
    Assignment,
    Brief,
    MetricsSnapshot,
    OptimizationPolicy,
    ResultEntry,
    Segment,
    Test,
    TextVariant,
)
from .serializers import (
    AssignmentSerializer,
    BriefSerializer,
    MetricsSnapshotSerializer,
    OptimizationPolicySerializer,
    RegisterSerializer,
    ResultEntrySerializer,
    SegmentSerializer,
    TestSerializer,
    TextVariantSerializer,
)
from .services import (
    aggregate_test_rows,
    aggregate_test_rows_by_segment,
    agents_analyze,
    agents_generate_texts,
)


# Policy fields a client may set through the ``policy`` action (anything else is ignored).
_POLICY_FIELDS = (
    "mode",
    "primary_kpi",
    "direction",
    "good_threshold",
    "ok_threshold",
    "min_impressions",
    "min_clicks",
    "query_text",
)


# --- Request-parsing helpers ---


def _request_dict(request):
    """Return the request body as a mapping: empty body -> ``{}``, non-object body -> 400."""
    data = request.data
    if not data:
        return {}
    if not isinstance(data, dict):
        raise serializers.ValidationError({"detail": "Request body must be a JSON object."})
    return data


def _request_list(request, key):
    """Return ``body[key]`` as a list (missing/empty -> ``[]``); 400 if it is not a list."""
    items = _request_dict(request).get(key) or []
    if not isinstance(items, list):
        raise serializers.ValidationError({key: "Expected a list."})
    return items


def _require_keys(item, *keys):
    """Raise a 400 unless *item* is a dict that contains every key in *keys*."""
    if not isinstance(item, dict):
        raise serializers.ValidationError({"detail": "Each item must be a JSON object."})
    missing = [key for key in keys if key not in item]
    if missing:
        raise serializers.ValidationError(
            {"detail": f"Missing required field(s): {', '.join(missing)}."}
        )


def _segment_and_variant(test, item):
    """Resolve the item's segment (within *test*) and variant (within the test's brief).

    Uses DRF's ``get_object_or_404`` so unknown ids *and* malformed ids
    (e.g. non-numeric strings) produce a 404 instead of a 500.
    """
    segment = get_object_or_404(Segment, id=item["segment_id"], test=test)
    variant = get_object_or_404(TextVariant, id=item["variant_id"], brief=test.brief)
    return segment, variant


def _normalize_formats(value):
    """Coerce ``Brief.formats`` into a list of format names.

    Lists pass through unchanged. Strings are parsed as a JSON array when they
    look like one (``'["a", "b"]'``), otherwise split on commas. Anything else,
    including empty values, becomes ``[]``.
    """
    if isinstance(value, str):
        text = value.strip()
        if text.startswith("["):
            try:
                parsed = json.loads(text)
            except ValueError:
                parsed = None
            if isinstance(parsed, list):
                return [x.strip() for x in parsed if isinstance(x, str) and x.strip()]
        return [x.strip() for x in value.split(",") if x.strip()]
    return value if isinstance(value, list) else []


def _normalize_benefits(value):
    """Coerce ``Brief.benefits`` into a list: a lone string is wrapped, non-lists become ``[]``."""
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    return value if isinstance(value, list) else []


# --- Public Auth Endpoints ---


@extend_schema(
    auth=[],
    request=RegisterSerializer,
    responses={
        200: inline_serializer(
            name="RegisterResponse",
            fields={
                "id": serializers.IntegerField(),
                "username": serializers.CharField(),
            },
        )
    },
)
@api_view(["POST"])
@permission_classes([permissions.AllowAny])
def register(request):
    """Create a user from ``RegisterSerializer`` input and return its id and username."""
    ser = RegisterSerializer(data=request.data)
    ser.is_valid(raise_exception=True)
    user = ser.save()
    return Response({"id": user.id, "username": user.username})


class PublicTokenObtainPairView(TokenObtainPairView):
    """Same JWT token endpoint, but marked as public for Swagger (auth=[])."""

    @extend_schema(auth=[])
    def post(self, request, *args, **kwargs):
        return super().post(request, *args, **kwargs)


# --- Protected API (every viewset below sets IsAuthenticated explicitly) ---


class BriefViewSet(viewsets.ModelViewSet):
    """CRUD for the current user's briefs, text generation, and per-brief test listings."""

    serializer_class = BriefSerializer
    permission_classes = [permissions.IsAuthenticated]
    # Class-level queryset (presumably for router basename detection);
    # get_queryset() narrows it to the requesting user.
    queryset = Brief.objects.all()

    def get_queryset(self):
        return Brief.objects.filter(owner=self.request.user).order_by("-id")

    def perform_create(self, serializer):
        serializer.save(owner=self.request.user)

    # Segment/Assignment/ResultEntry hang off Test, not Brief, so these listings
    # span the relation (test__brief) instead of filtering test=<Brief>, which
    # Django rejects. They are further limited to tests the user owns.

    @action(detail=True, methods=["GET"])
    def segments(self, request, pk=None):
        """List segments of the user's tests that use this brief."""
        brief = self.get_object()
        qs = Segment.objects.filter(test__brief=brief, test__owner=request.user).order_by("id")
        return Response(SegmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def assignments(self, request, pk=None):
        """List assignments of the user's tests that use this brief."""
        brief = self.get_object()
        qs = Assignment.objects.filter(test__brief=brief, test__owner=request.user).order_by("id")
        return Response(AssignmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def results(self, request, pk=None):
        """List result rows of the user's tests that use this brief, newest date first."""
        brief = self.get_object()
        qs = ResultEntry.objects.filter(
            test__brief=brief, test__owner=request.user
        ).order_by("-date", "-id")
        return Response(ResultEntrySerializer(qs, many=True).data)

    @action(detail=True, methods=["POST"])
    def generate(self, request, pk=None):
        """Ask the agents service for text variants for this brief and store them.

        Returns ``{"created_variant_ids": [...]}``, or 400 if the brief has no formats.
        """
        brief = self.get_object()
        # Build the agents payload with the types its validation expects.
        formats = _normalize_formats(brief.formats)
        benefits = _normalize_benefits(brief.benefits)
        if not formats:
            # Without formats the agents service would answer 422; give the user a clear error instead.
            return Response(
                {"detail": "Выберите хотя бы один формат в брифе (например: Поисковое объявление)."},
                status=status.HTTP_400_BAD_REQUEST,
            )
        payload = {
            "product": brief.product,
            "audience": brief.audience,
            "usp": brief.usp,
            "benefits": benefits,
            "constraints": brief.constraints,
            "tone": brief.tone,
            "formats": formats,
            # Falsy values fall back to 3; the result is clamped to 1..10.
            "variants_per_format": max(1, min(int(brief.variants_per_format or 3), 10)),
        }
        res = agents_generate_texts(payload)
        created = []
        # Store all variants or none if a row fails to save.
        with transaction.atomic():
            for block in res.get("variants", []):
                fmt = block.get("format")
                for item in block.get("items", []):
                    variant = TextVariant.objects.create(
                        brief=brief,
                        format=fmt,
                        payload=item.get("payload") or {},
                        placement_tips=item.get("placement_tips", ""),
                        expected_effect=item.get("expected_effect", ""),
                    )
                    created.append(variant.id)
        return Response({"created_variant_ids": created})


class TextVariantViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only access to text variants of the user's briefs; optional ``?brief=<id>`` filter."""

    serializer_class = TextVariantSerializer
    permission_classes = [permissions.IsAuthenticated]
    queryset = TextVariant.objects.all()

    def get_queryset(self):
        qs = TextVariant.objects.filter(brief__owner=self.request.user).order_by("-id")
        brief_id = self.request.query_params.get("brief")
        if brief_id:
            try:
                brief_pk = int(brief_id)
            except ValueError:
                # A non-numeric id would otherwise raise inside the ORM (500).
                raise serializers.ValidationError({"brief": "Must be an integer id."})
            qs = qs.filter(brief_id=brief_pk)
        return qs


class TestViewSet(viewsets.ModelViewSet):
    """CRUD for the current user's tests plus segments, assignments, results and analysis."""

    serializer_class = TestSerializer
    permission_classes = [permissions.IsAuthenticated]
    queryset = Test.objects.all()

    def get_queryset(self):
        return Test.objects.filter(owner=self.request.user).order_by("-id")

    def perform_create(self, serializer):
        # NOTE(review): confirm TestSerializer checks that the chosen brief
        # belongs to request.user; assign/add_results trust test.brief.
        serializer.save(owner=self.request.user)

    @action(detail=True, methods=["GET", "POST"])
    def policy(self, request, pk=None):
        """Get or update optimization rules for this test.

        GET returns the policy, creating it with model defaults on first access.
        POST copies the fields listed in ``_POLICY_FIELDS`` from the body onto
        the policy. If no ``direction`` is supplied, it is derived from the KPI:
        ``ctr``/``cr`` are maximized, every other KPI (default ``cpl``) is minimized.
        """
        test = self.get_object()
        obj, _ = OptimizationPolicy.objects.get_or_create(test=test)
        if request.method == "GET":
            return Response(OptimizationPolicySerializer(obj).data)

        data = _request_dict(request)
        # str() guards against non-string KPI values before lowercasing.
        kpi = str(data.get("primary_kpi") or obj.primary_kpi or "cpl").lower()
        direction = data.get("direction") or ("max" if kpi in ("ctr", "cr") else "min")
        for field in _POLICY_FIELDS:
            if field in data:
                setattr(obj, field, data.get(field))
        obj.direction = direction
        obj.save()
        return Response(OptimizationPolicySerializer(obj).data)

    @action(detail=True, methods=["GET"])
    def segments(self, request, pk=None):
        """List this test's segments in creation order."""
        test = self.get_object()
        qs = Segment.objects.filter(test=test).order_by("id")
        return Response(SegmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def assignments(self, request, pk=None):
        """List this test's segment-to-variant assignments in creation order."""
        test = self.get_object()
        qs = Assignment.objects.filter(test=test).order_by("id")
        return Response(AssignmentSerializer(qs, many=True).data)

    @action(detail=True, methods=["GET"])
    def results(self, request, pk=None):
        """List this test's result rows, newest date first."""
        test = self.get_object()
        qs = ResultEntry.objects.filter(test=test).order_by("-date", "-id")
        return Response(ResultEntrySerializer(qs, many=True).data)

    @action(detail=True, methods=["POST"])
    def add_segments(self, request, pk=None):
        """Create segments from ``{"segments": [{"name": ..., "description": ...}, ...]}``.

        ``name`` defaults to ``"Segment"`` and ``description`` to ``""``. All
        segments are created in a single transaction.
        """
        test = self.get_object()
        items = _request_list(request, "segments")
        for item in items:
            _require_keys(item)
        ids = []
        with transaction.atomic():
            for item in items:
                segment = Segment.objects.create(
                    test=test,
                    name=item.get("name", "Segment"),
                    description=item.get("description", ""),
                )
                ids.append(segment.id)
        return Response({"created_segment_ids": ids})

    @action(detail=True, methods=["POST"])
    def assign(self, request, pk=None):
        """Create assignments from ``{"assignments": [{"segment_id", "variant_id"}, ...]}``.

        Segments must belong to this test and variants to its brief. Items are
        validated first (400 on missing keys); an unknown or malformed id gives
        404 and rolls back every assignment created by this request.
        """
        test = self.get_object()
        items = _request_list(request, "assignments")
        for item in items:
            _require_keys(item, "segment_id", "variant_id")
        ids = []
        with transaction.atomic():
            for item in items:
                segment, variant = _segment_and_variant(test, item)
                assignment = Assignment.objects.create(test=test, segment=segment, variant=variant)
                ids.append(assignment.id)
        return Response({"created_assignment_ids": ids})

    @action(detail=True, methods=["POST"])
    def add_results(self, request, pk=None):
        """Create result rows from ``{"results": [{"segment_id", "variant_id", "date", ...}, ...]}``.

        Optional counters (impressions, clicks, conversions, leads) default to 0
        and ``spend`` to 0.0. Validation and rollback behave as in ``assign``.
        """
        test = self.get_object()
        rows = _request_list(request, "results")
        for row in rows:
            _require_keys(row, "segment_id", "variant_id", "date")
        ids = []
        with transaction.atomic():
            for row in rows:
                segment, variant = _segment_and_variant(test, row)
                entry = ResultEntry.objects.create(
                    test=test,
                    segment=segment,
                    variant=variant,
                    date=row["date"],
                    impressions=row.get("impressions", 0),
                    clicks=row.get("clicks", 0),
                    conversions=row.get("conversions", 0),
                    leads=row.get("leads", 0),
                    spend=row.get("spend", 0.0),
                )
                ids.append(entry.id)
        return Response({"created_result_ids": ids})

    @action(detail=True, methods=["POST"])
    def analyze(self, request, pk=None):
        """Run the agents analysis on per-segment aggregates and upsert the test's snapshot.

        The stored policy, if any, is sent to the analysis unless the body has a
        truthy ``policy`` value, which then takes precedence.
        """
        test = self.get_object()
        policy = None
        # hasattr() is False when the reverse one-to-one row does not exist yet.
        if hasattr(test, "policy"):
            policy = OptimizationPolicySerializer(test.policy).data
        # Allow overriding the policy from the request.
        if isinstance(request.data, dict) and request.data.get("policy"):
            policy = request.data.get("policy")
        rows = aggregate_test_rows_by_segment(test)
        res = agents_analyze(rows, test.objective, policy=policy)
        snap, _ = MetricsSnapshot.objects.update_or_create(
            test=test,
            defaults={
                "ranking": res.get("ranking", []),
                "recommendations": res.get("recommendations", []),
            },
        )
        return Response(MetricsSnapshotSerializer(snap).data)

    @action(detail=True, methods=["GET"])
    def report(self, request, pk=None):
        """Return the stored analysis snapshot, or 404 if ``analyze`` has not been run."""
        test = self.get_object()
        if hasattr(test, "snapshot"):
            return Response(MetricsSnapshotSerializer(test.snapshot).data)
        return Response({"detail": "No analysis yet"}, status=status.HTTP_404_NOT_FOUND)