Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions api/experimentation/views.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
from typing import Any

from django.db import IntegrityError
from django.db.models import Q, QuerySet
from rest_framework import mixins, status
from rest_framework import mixins, serializers, status
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import BaseSerializer

from app.pagination import CustomPagination
from environments.views import NestedEnvironmentViewSet
from experimentation.models import (
Experiment,
Expand Down Expand Up @@ -101,7 +103,7 @@ class ExperimentViewSet(
mixins.DestroyModelMixin,
):
serializer_class = ExperimentSerializer
pagination_class = None
pagination_class = CustomPagination
permission_classes = [IsAuthenticated, ExperimentPermission]
model_class = Experiment
lookup_field = "id"
Expand All @@ -125,6 +127,10 @@ def get_queryset(self) -> "QuerySet[Experiment]":
)
status_filter = self.request.query_params.get("status")
if status_filter:
if status_filter not in ExperimentStatus.values:
raise serializers.ValidationError(
{"status": f"Invalid status '{status_filter}'."}
)
qs = qs.filter(status=status_filter)

q = self.request.query_params.get("q")
Expand Down Expand Up @@ -152,7 +158,13 @@ def create(self, request: Request, *args: object, **kwargs: object) -> Response:
status=status.HTTP_409_CONFLICT,
)

self.perform_create(serializer)
try:
self.perform_create(serializer)
except IntegrityError:
return Response(
{"detail": "An active experiment already exists for this feature."},
status=status.HTTP_409_CONFLICT,
)
return Response(serializer.data, status=status.HTTP_201_CREATED)

def perform_create(self, serializer: BaseSerializer[Experiment]) -> None:
Expand All @@ -162,12 +174,28 @@ def perform_create(self, serializer: BaseSerializer[Experiment]) -> None:
)

def perform_update(self, serializer: BaseSerializer[Experiment]) -> None:
changed_fields = {
field
for field, value in serializer.validated_data.items()
if getattr(serializer.instance, field, None) != value
}
if not changed_fields:
return
experiment: Experiment = serializer.save()
create_experiment_audit_log(
experiment, self._get_user(self.request), action="updated"
)

def perform_destroy(self, instance: Experiment) -> None:
if instance.status == ExperimentStatus.RUNNING:
raise serializers.ValidationError(
{
"detail": (
"Cannot delete a running experiment. "
"Pause or complete it first."
)
}
)
create_experiment_audit_log(
instance, self._get_user(self.request), action="deleted"
)
Expand Down
119 changes: 107 additions & 12 deletions api/tests/unit/experimentation/test_experiment_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from typing import TYPE_CHECKING

import pytest
from django.db import IntegrityError
from django.urls import reverse
from pytest_mock import MockerFixture
from rest_framework import status
from rest_framework.test import APIClient

Expand Down Expand Up @@ -265,8 +267,9 @@ def test_get_list__with_experiments__returns_all(

# Then
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1
assert response.json()[0]["id"] == experiment.id
results = response.json()["results"]
assert len(results) == 1
assert results[0]["id"] == experiment.id


def test_get_list__with_experiments__returns_nested_feature(
Expand All @@ -284,9 +287,9 @@ def test_get_list__with_experiments__returns_nested_feature(

# Then
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data) == 1
feature_data = data[0]["feature"]
results = response.json()["results"]
assert len(results) == 1
feature_data = results[0]["feature"]
assert isinstance(feature_data, dict)
assert feature_data["id"] == multivariate_feature.id
assert feature_data["name"] == multivariate_feature.name
Expand Down Expand Up @@ -329,7 +332,7 @@ def test_get_list__empty__returns_200(

# Then
assert response.status_code == status.HTTP_200_OK
assert response.json() == []
assert response.json()["results"] == []


@pytest.mark.parametrize(
Expand Down Expand Up @@ -357,7 +360,7 @@ def test_get_list__filter_by_status__returns_filtered(

# Then
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == expected_count
assert len(response.json()["results"]) == expected_count


def test_get_list__search_by_experiment_name__returns_matching(
Expand All @@ -374,8 +377,9 @@ def test_get_list__search_by_experiment_name__returns_matching(

# Then
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1
assert response.json()[0]["id"] == experiment.id
results = response.json()["results"]
assert len(results) == 1
assert results[0]["id"] == experiment.id


def test_get_list__search_by_feature_name__returns_matching(
Expand All @@ -395,8 +399,9 @@ def test_get_list__search_by_feature_name__returns_matching(

# Then
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 1
assert response.json()[0]["id"] == experiment.id
results = response.json()["results"]
assert len(results) == 1
assert results[0]["id"] == experiment.id


def test_get_list__search_no_match__returns_empty(
Expand All @@ -413,7 +418,7 @@ def test_get_list__search_no_match__returns_empty(

# Then
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 0
assert len(response.json()["results"]) == 0


def test_get_detail__exists__returns_200(
Expand Down Expand Up @@ -670,3 +675,93 @@ def test_delete__valid_delete__creates_audit_log(
).last()
assert audit is not None
assert "deleted" in audit.log


def test_get_list__invalid_status__returns_400(
admin_client_new: APIClient,
environment: Environment,
enable_features: EnableFeaturesFixture,
) -> None:
# Given
enable_features(EXPERIMENT_FLAG)

# When
response = admin_client_new.get(_list_url(environment), {"status": "garbage"})

# Then
assert response.status_code == status.HTTP_400_BAD_REQUEST


def test_delete__running_experiment__returns_400(
admin_client_new: APIClient,
environment: Environment,
experiment: Experiment,
enable_features: EnableFeaturesFixture,
) -> None:
# Given
enable_features(EXPERIMENT_FLAG)
experiment.status = ExperimentStatus.RUNNING
experiment.save()

# When
response = admin_client_new.delete(_detail_url(environment, experiment))

# Then
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert Experiment.objects.filter(id=experiment.id).exists()


def test_patch__no_change__skips_audit_log(
admin_client_new: APIClient,
environment: Environment,
experiment: Experiment,
enable_features: EnableFeaturesFixture,
) -> None:
# Given
enable_features(EXPERIMENT_FLAG)
audit_count_before = AuditLog.objects.filter(
related_object_type=RelatedObjectType.EXPERIMENT.name
).count()

# When
response = admin_client_new.patch(
_detail_url(environment, experiment),
data={"name": experiment.name},
format="json",
)

# Then
assert response.status_code == status.HTTP_200_OK
audit_count_after = AuditLog.objects.filter(
related_object_type=RelatedObjectType.EXPERIMENT.name
).count()
assert audit_count_after == audit_count_before


def test_post__concurrent_create_race__returns_409(
admin_client_new: APIClient,
environment: Environment,
multivariate_feature: Feature,
enable_features: EnableFeaturesFixture,
mocker: MockerFixture,
) -> None:
# Given
enable_features(EXPERIMENT_FLAG)
mocker.patch(
"experimentation.views.ExperimentViewSet.perform_create",
side_effect=IntegrityError(),
)

# When
response = admin_client_new.post(
_list_url(environment),
data={
"feature": multivariate_feature.id,
"name": "Race",
"hypothesis": "Should 409",
},
format="json",
)

# Then
assert response.status_code == status.HTTP_409_CONFLICT
Loading