MVP0
This commit is contained in:
9
backend_django/adsassistant_backend/api/admin.py
Normal file
9
backend_django/adsassistant_backend/api/admin.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from django.contrib import admin
|
||||
from .models import Brief, TextVariant, Test, Segment, Assignment, ResultEntry, MetricsSnapshot
|
||||
admin.site.register(Brief)
|
||||
admin.site.register(TextVariant)
|
||||
admin.site.register(Test)
|
||||
admin.site.register(Segment)
|
||||
admin.site.register(Assignment)
|
||||
admin.site.register(ResultEntry)
|
||||
admin.site.register(MetricsSnapshot)
|
||||
98
backend_django/adsassistant_backend/api/models.py
Normal file
98
backend_django/adsassistant_backend/api/models.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from django.db import models
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
class Brief(models.Model):
|
||||
owner = models.ForeignKey(User, on_delete=models.CASCADE, related_name="briefs")
|
||||
product = models.TextField()
|
||||
audience = models.TextField()
|
||||
usp = models.TextField(blank=True, null=True)
|
||||
benefits = models.JSONField(default=list, blank=True)
|
||||
constraints = models.TextField(blank=True, null=True)
|
||||
tone = models.CharField(max_length=120, blank=True, null=True)
|
||||
|
||||
formats = models.JSONField(default=list) # client chooses formats
|
||||
variants_per_format = models.PositiveIntegerField(default=3)
|
||||
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class TextVariant(models.Model):
|
||||
brief = models.ForeignKey(Brief, on_delete=models.CASCADE, related_name="variants")
|
||||
format = models.CharField(max_length=50)
|
||||
payload = models.JSONField(default=dict)
|
||||
placement_tips = models.TextField(blank=True, default="")
|
||||
expected_effect = models.TextField(blank=True, default="")
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Test(models.Model):
|
||||
owner = models.ForeignKey(User, on_delete=models.CASCADE, related_name="tests")
|
||||
brief = models.ForeignKey(Brief, on_delete=models.CASCADE, related_name="tests")
|
||||
name = models.CharField(max_length=200, default="Тест")
|
||||
channel = models.CharField(max_length=120, blank=True, default="")
|
||||
duration_days = models.PositiveIntegerField(default=3)
|
||||
sample_size = models.PositiveIntegerField(default=0)
|
||||
objective = models.CharField(max_length=32, default="leads") # leads|conversions|clicks
|
||||
status = models.CharField(max_length=20, default="draft")
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Segment(models.Model):
|
||||
test = models.ForeignKey(Test, on_delete=models.CASCADE, related_name="segments")
|
||||
name = models.CharField(max_length=200)
|
||||
description = models.TextField(blank=True, default="")
|
||||
|
||||
class Assignment(models.Model):
|
||||
test = models.ForeignKey(Test, on_delete=models.CASCADE, related_name="assignments")
|
||||
segment = models.ForeignKey(Segment, on_delete=models.CASCADE, related_name="assignments")
|
||||
variant = models.ForeignKey(TextVariant, on_delete=models.CASCADE, related_name="assignments")
|
||||
|
||||
class ResultEntry(models.Model):
|
||||
test = models.ForeignKey(Test, on_delete=models.CASCADE, related_name="results")
|
||||
segment = models.ForeignKey(Segment, on_delete=models.CASCADE, related_name="results")
|
||||
variant = models.ForeignKey(TextVariant, on_delete=models.CASCADE, related_name="results")
|
||||
date = models.DateField()
|
||||
impressions = models.PositiveIntegerField(default=0)
|
||||
clicks = models.PositiveIntegerField(default=0)
|
||||
conversions = models.PositiveIntegerField(default=0)
|
||||
leads = models.PositiveIntegerField(default=0)
|
||||
spend = models.FloatField(default=0.0)
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class MetricsSnapshot(models.Model):
|
||||
test = models.OneToOneField(Test, on_delete=models.CASCADE, related_name="snapshot")
|
||||
ranking = models.JSONField(default=list)
|
||||
recommendations = models.JSONField(default=list)
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
|
||||
class OptimizationPolicy(models.Model):
|
||||
"""User-defined optimization rules for Agent #2.
|
||||
|
||||
Supports two modes:
|
||||
- thresholds: structured KPI + thresholds
|
||||
- text: free-form query that describes ranking/thresholds
|
||||
"""
|
||||
MODE_CHOICES = [
|
||||
("thresholds", "thresholds"),
|
||||
("text", "text"),
|
||||
]
|
||||
KPI_CHOICES = [
|
||||
("cpa", "cpa"),
|
||||
("cpl", "cpl"),
|
||||
("cpc", "cpc"),
|
||||
("ctr", "ctr"),
|
||||
("cr", "cr"),
|
||||
]
|
||||
test = models.OneToOneField("Test", on_delete=models.CASCADE, related_name="policy")
|
||||
mode = models.CharField(max_length=16, choices=MODE_CHOICES, default="thresholds")
|
||||
|
||||
# Structured mode
|
||||
primary_kpi = models.CharField(max_length=8, choices=KPI_CHOICES, default="cpl")
|
||||
direction = models.CharField(max_length=8, default="min") # min or max
|
||||
good_threshold = models.FloatField(null=True, blank=True)
|
||||
ok_threshold = models.FloatField(null=True, blank=True)
|
||||
min_impressions = models.IntegerField(default=0)
|
||||
min_clicks = models.IntegerField(default=0)
|
||||
|
||||
# Text mode
|
||||
query_text = models.TextField(blank=True, null=True)
|
||||
|
||||
updated_at = models.DateTimeField(auto_now=True)
|
||||
63
backend_django/adsassistant_backend/api/serializers.py
Normal file
63
backend_django/adsassistant_backend/api/serializers.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from rest_framework import serializers
|
||||
from django.contrib.auth.models import User
|
||||
from .models import Brief, TextVariant, Test, Segment, Assignment, ResultEntry, MetricsSnapshot, OptimizationPolicy
|
||||
|
||||
class RegisterSerializer(serializers.ModelSerializer):
|
||||
password = serializers.CharField(write_only=True, min_length=6)
|
||||
class Meta:
|
||||
model = User
|
||||
fields = ("username","password","email")
|
||||
def create(self, validated):
|
||||
user = User(username=validated["username"], email=validated.get("email",""))
|
||||
user.set_password(validated["password"])
|
||||
user.save()
|
||||
return user
|
||||
|
||||
class BriefSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Brief
|
||||
fields = "__all__"
|
||||
read_only_fields = ("id","owner","created_at")
|
||||
|
||||
class TextVariantSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = TextVariant
|
||||
fields = "__all__"
|
||||
read_only_fields = ("id","created_at")
|
||||
|
||||
class TestSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Test
|
||||
fields = "__all__"
|
||||
read_only_fields = ("id","owner","created_at")
|
||||
|
||||
class SegmentSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Segment
|
||||
fields = "__all__"
|
||||
read_only_fields = ("id",)
|
||||
|
||||
class AssignmentSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Assignment
|
||||
fields = "__all__"
|
||||
read_only_fields = ("id",)
|
||||
|
||||
class ResultEntrySerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = ResultEntry
|
||||
fields = "__all__"
|
||||
read_only_fields = ("id","created_at")
|
||||
|
||||
class MetricsSnapshotSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = MetricsSnapshot
|
||||
fields = "__all__"
|
||||
read_only_fields = ("id","created_at")
|
||||
|
||||
|
||||
class OptimizationPolicySerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = OptimizationPolicy
|
||||
fields = "__all__"
|
||||
read_only_fields = ("id","test","updated_at")
|
||||
77
backend_django/adsassistant_backend/api/services.py
Normal file
77
backend_django/adsassistant_backend/api/services.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import requests
|
||||
|
||||
|
||||
def _raise_with_detail(r: requests.Response):
|
||||
"""Raise an HTTPError but include upstream error body for easier debugging."""
|
||||
try:
|
||||
detail = r.json()
|
||||
except Exception:
|
||||
detail = r.text
|
||||
http_error_msg = f"{r.status_code} {r.url}: {detail}"
|
||||
raise requests.HTTPError(http_error_msg, response=r)
|
||||
from django.conf import settings
|
||||
from django.db.models import Sum
|
||||
from .models import Test, ResultEntry, TextVariant
|
||||
|
||||
def agents_generate_texts(payload: dict) -> dict:
|
||||
url = f"{settings.AGENTS_SERVICE_URL}/api/v1/texts/generate"
|
||||
r = requests.post(url, json=payload, timeout=180)
|
||||
|
||||
if r.status_code >= 400:
|
||||
_raise_with_detail(r)
|
||||
|
||||
return r.json()
|
||||
|
||||
def agents_analyze(rows: list[dict], objective: str, policy: dict | None = None) -> dict:
|
||||
url = f"{settings.AGENTS_SERVICE_URL}/api/v1/tests/analyze"
|
||||
payload = {"rows": rows, "objective": objective}
|
||||
if policy:
|
||||
payload["policy"] = policy
|
||||
r = requests.post(url, json=payload, timeout=90)
|
||||
|
||||
if r.status_code >= 400:
|
||||
_raise_with_detail(r)
|
||||
|
||||
return r.json()
|
||||
|
||||
def aggregate_test_rows(test: Test) -> list[dict]:
|
||||
qs = (ResultEntry.objects.filter(test=test).values("variant_id").annotate(
|
||||
impressions=Sum("impressions"), clicks=Sum("clicks"), conversions=Sum("conversions"), leads=Sum("leads"), spend=Sum("spend")
|
||||
))
|
||||
rows=[]
|
||||
for row in qs:
|
||||
v = TextVariant.objects.get(id=row["variant_id"])
|
||||
rows.append({
|
||||
"variant_id": v.id, "format": v.format,
|
||||
"impressions": int(row["impressions"] or 0),
|
||||
"clicks": int(row["clicks"] or 0),
|
||||
"conversions": int(row["conversions"] or 0),
|
||||
"leads": int(row["leads"] or 0),
|
||||
"spend": float(row["spend"] or 0.0),
|
||||
})
|
||||
return rows
|
||||
|
||||
|
||||
def aggregate_test_rows_by_segment(test: Test) -> list[dict]:
|
||||
"""Aggregate stats per (segment, variant)."""
|
||||
qs = (ResultEntry.objects.filter(test=test)
|
||||
.values("segment_id", "variant_id")
|
||||
.annotate(impressions=Sum("impressions"), clicks=Sum("clicks"),
|
||||
conversions=Sum("conversions"), leads=Sum("leads"), spend=Sum("spend")))
|
||||
rows = []
|
||||
# preload formats
|
||||
fmt_map = {tv.id: tv.format for tv in TextVariant.objects.filter(brief=test.brief)}
|
||||
seg_map = {s.id: s.name for s in test.segments.all()}
|
||||
for r in qs:
|
||||
rows.append({
|
||||
"variant_id": r["variant_id"],
|
||||
"format": fmt_map.get(r["variant_id"]),
|
||||
"segment_id": r["segment_id"],
|
||||
"segment_name": seg_map.get(r["segment_id"]),
|
||||
"impressions": int(r["impressions"] or 0),
|
||||
"clicks": int(r["clicks"] or 0),
|
||||
"conversions": int(r["conversions"] or 0),
|
||||
"leads": int(r["leads"] or 0),
|
||||
"spend": float(r["spend"] or 0.0),
|
||||
})
|
||||
return rows
|
||||
23
backend_django/adsassistant_backend/api/urls.py
Normal file
23
backend_django/adsassistant_backend/api/urls.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from django.urls import path, include
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from rest_framework_simplejwt.views import TokenRefreshView
|
||||
|
||||
from .views import (
|
||||
register,
|
||||
PublicTokenObtainPairView,
|
||||
BriefViewSet,
|
||||
TextVariantViewSet,
|
||||
TestViewSet,
|
||||
)
|
||||
|
||||
router = DefaultRouter()
|
||||
router.register(r"briefs", BriefViewSet, basename="brief")
|
||||
router.register(r"variants", TextVariantViewSet, basename="variant")
|
||||
router.register(r"tests", TestViewSet, basename="test")
|
||||
|
||||
urlpatterns = [
|
||||
path("auth/register/", register),
|
||||
path("auth/token/", PublicTokenObtainPairView.as_view(), name="token_obtain_pair"),
|
||||
path("auth/token/refresh/", TokenRefreshView.as_view(), name="token_refresh"),
|
||||
path("", include(router.urls)),
|
||||
]
|
||||
274
backend_django/adsassistant_backend/api/views.py
Normal file
274
backend_django/adsassistant_backend/api/views.py
Normal file
@@ -0,0 +1,274 @@
|
||||
from rest_framework import viewsets, permissions, status
|
||||
from rest_framework.decorators import action, api_view, permission_classes
|
||||
from rest_framework.response import Response
|
||||
|
||||
from rest_framework_simplejwt.views import TokenObtainPairView
|
||||
|
||||
from drf_spectacular.utils import extend_schema, inline_serializer
|
||||
from rest_framework import serializers
|
||||
|
||||
from .models import Brief, TextVariant, Test, Segment, Assignment, ResultEntry, MetricsSnapshot, OptimizationPolicy
|
||||
from .serializers import (
|
||||
RegisterSerializer,
|
||||
BriefSerializer,
|
||||
TextVariantSerializer,
|
||||
TestSerializer,
|
||||
SegmentSerializer,
|
||||
AssignmentSerializer,
|
||||
ResultEntrySerializer,
|
||||
MetricsSnapshotSerializer,
|
||||
OptimizationPolicySerializer,
|
||||
)
|
||||
from .services import agents_generate_texts, agents_analyze, aggregate_test_rows, aggregate_test_rows_by_segment
|
||||
|
||||
|
||||
# --- Public Auth Endpoints ---
|
||||
|
||||
@extend_schema(
|
||||
auth=[],
|
||||
request=RegisterSerializer,
|
||||
responses={
|
||||
200: inline_serializer(
|
||||
name="RegisterResponse",
|
||||
fields={
|
||||
"id": serializers.IntegerField(),
|
||||
"username": serializers.CharField(),
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
@api_view(["POST"])
|
||||
@permission_classes([permissions.AllowAny])
|
||||
def register(request):
|
||||
ser = RegisterSerializer(data=request.data)
|
||||
ser.is_valid(raise_exception=True)
|
||||
user = ser.save()
|
||||
return Response({"id": user.id, "username": user.username})
|
||||
|
||||
|
||||
class PublicTokenObtainPairView(TokenObtainPairView):
|
||||
"""Same JWT token endpoint, but marked as public for Swagger (auth=[])."""
|
||||
|
||||
@extend_schema(auth=[])
|
||||
def post(self, request, *args, **kwargs):
|
||||
return super().post(request, *args, **kwargs)
|
||||
|
||||
|
||||
# --- Protected API (JWT required by DEFAULT_PERMISSION_CLASSES) ---
|
||||
|
||||
class BriefViewSet(viewsets.ModelViewSet):
|
||||
serializer_class = BriefSerializer
|
||||
permission_classes = [permissions.IsAuthenticated]
|
||||
queryset = Brief.objects.all()
|
||||
|
||||
def get_queryset(self):
|
||||
return Brief.objects.filter(owner=self.request.user).order_by("-id")
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save(owner=self.request.user)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def segments(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = Segment.objects.filter(test=test).order_by("id")
|
||||
return Response(SegmentSerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def assignments(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = Assignment.objects.filter(test=test).order_by("id")
|
||||
return Response(AssignmentSerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def results(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = ResultEntry.objects.filter(test=test).order_by("-date", "-id")
|
||||
return Response(ResultEntrySerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["POST"])
|
||||
def generate(self, request, pk=None):
|
||||
brief = self.get_object()
|
||||
# Build payload for agents service and ensure types are correct for validation
|
||||
formats = brief.formats or []
|
||||
if isinstance(formats, str):
|
||||
# try to parse comma-separated / json-like strings
|
||||
formats = [x.strip() for x in formats.split(",") if x.strip()]
|
||||
if not isinstance(formats, list):
|
||||
formats = []
|
||||
|
||||
benefits = brief.benefits or []
|
||||
if isinstance(benefits, str):
|
||||
benefits = [benefits]
|
||||
if not isinstance(benefits, list):
|
||||
benefits = []
|
||||
|
||||
if not formats:
|
||||
# Without formats the agents service will return 422. Give a clear error to user.
|
||||
return Response({"detail": "Выберите хотя бы один формат в брифе (например: Поисковое объявление)."}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
payload = {
|
||||
"product": brief.product,
|
||||
"audience": brief.audience,
|
||||
"usp": brief.usp,
|
||||
"benefits": benefits,
|
||||
"constraints": brief.constraints,
|
||||
"tone": brief.tone,
|
||||
"formats": formats,
|
||||
"variants_per_format": max(1, min(int(brief.variants_per_format or 3), 10)),
|
||||
}
|
||||
res = agents_generate_texts(payload)
|
||||
|
||||
created = []
|
||||
for block in res.get("variants", []):
|
||||
fmt = block.get("format")
|
||||
for item in block.get("items", []):
|
||||
v = TextVariant.objects.create(
|
||||
brief=brief,
|
||||
format=fmt,
|
||||
payload=item.get("payload") or {},
|
||||
placement_tips=item.get("placement_tips", ""),
|
||||
expected_effect=item.get("expected_effect", ""),
|
||||
)
|
||||
created.append(v.id)
|
||||
return Response({"created_variant_ids": created})
|
||||
|
||||
|
||||
class TextVariantViewSet(viewsets.ReadOnlyModelViewSet):
|
||||
serializer_class = TextVariantSerializer
|
||||
permission_classes = [permissions.IsAuthenticated]
|
||||
queryset = TextVariant.objects.all()
|
||||
|
||||
def get_queryset(self):
|
||||
qs = TextVariant.objects.filter(brief__owner=self.request.user).order_by("-id")
|
||||
brief_id = self.request.query_params.get("brief")
|
||||
if brief_id:
|
||||
qs = qs.filter(brief_id=brief_id)
|
||||
return qs
|
||||
|
||||
|
||||
class TestViewSet(viewsets.ModelViewSet):
|
||||
serializer_class = TestSerializer
|
||||
permission_classes = [permissions.IsAuthenticated]
|
||||
queryset = Test.objects.all()
|
||||
|
||||
def get_queryset(self):
|
||||
return Test.objects.filter(owner=self.request.user).order_by("-id")
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save(owner=self.request.user)
|
||||
|
||||
|
||||
@action(detail=True, methods=["GET","POST"])
|
||||
def policy(self, request, pk=None):
|
||||
"""Get or update optimization rules for this test."""
|
||||
test = self.get_object()
|
||||
obj, _ = OptimizationPolicy.objects.get_or_create(test=test)
|
||||
if request.method == "GET":
|
||||
return Response(OptimizationPolicySerializer(obj).data)
|
||||
|
||||
data = request.data or {}
|
||||
# Normalize direction by KPI
|
||||
kpi = (data.get("primary_kpi") or obj.primary_kpi or "cpl").lower()
|
||||
direction = data.get("direction")
|
||||
if not direction:
|
||||
direction = "max" if kpi in ("ctr","cr") else "min"
|
||||
for k in ("mode","primary_kpi","direction","good_threshold","ok_threshold","min_impressions","min_clicks","query_text"):
|
||||
if k in data:
|
||||
setattr(obj, k, data.get(k))
|
||||
obj.direction = direction
|
||||
obj.save()
|
||||
return Response(OptimizationPolicySerializer(obj).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def segments(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = Segment.objects.filter(test=test).order_by("id")
|
||||
return Response(SegmentSerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def assignments(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = Assignment.objects.filter(test=test).order_by("id")
|
||||
return Response(AssignmentSerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def results(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
qs = ResultEntry.objects.filter(test=test).order_by("-date", "-id")
|
||||
return Response(ResultEntrySerializer(qs, many=True).data)
|
||||
|
||||
@action(detail=True, methods=["POST"])
|
||||
def add_segments(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
segments = request.data.get("segments") or []
|
||||
ids = []
|
||||
for s in segments:
|
||||
seg = Segment.objects.create(
|
||||
test=test,
|
||||
name=s.get("name", "Segment"),
|
||||
description=s.get("description", ""),
|
||||
)
|
||||
ids.append(seg.id)
|
||||
return Response({"created_segment_ids": ids})
|
||||
|
||||
@action(detail=True, methods=["POST"])
|
||||
def assign(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
assignments = request.data.get("assignments") or []
|
||||
ids = []
|
||||
for a in assignments:
|
||||
seg = Segment.objects.get(id=a["segment_id"], test=test)
|
||||
var = TextVariant.objects.get(id=a["variant_id"], brief=test.brief)
|
||||
obj = Assignment.objects.create(test=test, segment=seg, variant=var)
|
||||
ids.append(obj.id)
|
||||
return Response({"created_assignment_ids": ids})
|
||||
|
||||
@action(detail=True, methods=["POST"])
|
||||
def add_results(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
rows = request.data.get("results") or []
|
||||
ids = []
|
||||
for r in rows:
|
||||
seg = Segment.objects.get(id=r["segment_id"], test=test)
|
||||
var = TextVariant.objects.get(id=r["variant_id"], brief=test.brief)
|
||||
obj = ResultEntry.objects.create(
|
||||
test=test,
|
||||
segment=seg,
|
||||
variant=var,
|
||||
date=r["date"],
|
||||
impressions=r.get("impressions", 0),
|
||||
clicks=r.get("clicks", 0),
|
||||
conversions=r.get("conversions", 0),
|
||||
leads=r.get("leads", 0),
|
||||
spend=r.get("spend", 0.0),
|
||||
)
|
||||
ids.append(obj.id)
|
||||
return Response({"created_result_ids": ids})
|
||||
|
||||
@action(detail=True, methods=["POST"])
|
||||
def analyze(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
policy = None
|
||||
if hasattr(test, "policy"):
|
||||
policy = OptimizationPolicySerializer(test.policy).data
|
||||
# allow overriding policy from request
|
||||
if isinstance(request.data, dict) and request.data.get("policy"):
|
||||
policy = request.data.get("policy")
|
||||
rows = aggregate_test_rows_by_segment(test)
|
||||
res = agents_analyze(rows, test.objective, policy=policy)
|
||||
|
||||
snap, _ = MetricsSnapshot.objects.update_or_create(
|
||||
test=test,
|
||||
defaults={
|
||||
"ranking": res.get("ranking", []),
|
||||
"recommendations": res.get("recommendations", []),
|
||||
},
|
||||
)
|
||||
return Response(MetricsSnapshotSerializer(snap).data)
|
||||
|
||||
@action(detail=True, methods=["GET"])
|
||||
def report(self, request, pk=None):
|
||||
test = self.get_object()
|
||||
if hasattr(test, "snapshot"):
|
||||
return Response(MetricsSnapshotSerializer(test.snapshot).data)
|
||||
return Response({"detail": "No analysis yet"}, status=status.HTTP_404_NOT_FOUND)
|
||||
Reference in New Issue
Block a user