Skip to content

Commit

Permalink
fix: Add stale_flags_limit_days to Project serializer (#3607)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewelwell authored Mar 14, 2024
1 parent a03b681 commit 99e0148
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 13 deletions.
9 changes: 9 additions & 0 deletions api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,15 @@ def system_tag(project: Project) -> Tag:
)


@pytest.fixture()
def enterprise_subscription(organisation: Organisation) -> Subscription:
Subscription.objects.filter(organisation=organisation).update(
plan="enterprise", subscription_id="subscription-id"
)
organisation.refresh_from_db()
return organisation.subscription


@pytest.fixture()
def project(organisation):
return Project.objects.create(name="Test Project", organisation=organisation)
Expand Down
15 changes: 13 additions & 2 deletions api/organisations/subscriptions/serializers/mixins.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import typing

from organisations.models import Subscription
Expand All @@ -24,6 +25,13 @@ def get_subscription(self):

invalid_plans: typing.Iterable[str] = None
field_names: typing.Iterable[str] = None
invalid_plans_regex: str = None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.invalid_plans_regex_matcher = (
re.compile(self.invalid_plans_regex) if self.invalid_plans_regex else None
)

def get_fields(self, *args, **kwargs):
fields = super().get_fields(*args, **kwargs)
Expand All @@ -33,12 +41,15 @@ def get_fields(self, *args, **kwargs):

subscription = self.get_subscription()
field_names = self.field_names or []
invalid_plans = self.invalid_plans or []

for field_name in field_names:
if field_name in fields and (
not (subscription and subscription.plan)
or subscription.plan in invalid_plans
or (self.invalid_plans and subscription.plan in self.invalid_plans)
or (
self.invalid_plans_regex
and re.match(self.invalid_plans_regex_matcher, subscription.plan)
)
):
fields[field_name].read_only = True

Expand Down
31 changes: 31 additions & 0 deletions api/projects/serializers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import typing

from django.conf import settings
from rest_framework import serializers

from environments.dynamodb.migrator import IdentityMigrator
from environments.dynamodb.types import ProjectIdentityMigrationStatus
from organisations.models import Subscription
from organisations.subscriptions.serializers.mixins import (
ReadOnlyIfNotValidPlanMixin,
)
from permissions.serializers import CreateUpdateUserPermissionSerializerABC
from projects.models import (
Project,
Expand Down Expand Up @@ -35,6 +41,7 @@ class Meta:
"only_allow_lower_case_feature_names",
"feature_name_regex",
"show_edge_identity_overrides_for_feature",
"stale_flags_limit_days",
)

def get_migration_status(self, obj: Project) -> str:
Expand All @@ -57,6 +64,30 @@ def get_use_edge_identities(self, obj: Project) -> bool:
)


class ProjectUpdateOrCreateSerializer(
ReadOnlyIfNotValidPlanMixin, ProjectListSerializer
):
invalid_plans_regex = r"^(free|startup.*|scale-up.*)$"
field_names = ("stale_flags_limit_days",)

def get_subscription(self) -> typing.Optional[Subscription]:
view = self.context["view"]

if view.action == "create":
# handle `organisation` not being part of the data
# When request comes from yasg2 (as part of schema generation)
organisation_id = view.request.data.get("organisation")
if not organisation_id:
return None

# Organisation should only have a single subscription
return Subscription.objects.filter(organisation_id=organisation_id).first()
elif view.action in ("update", "partial_update"):
return getattr(self.instance.organisation, "subscription", None)

return None


class ProjectRetrieveSerializer(ProjectListSerializer):
total_features = serializers.SerializerMethodField()
total_segments = serializers.SerializerMethodField()
Expand Down
6 changes: 6 additions & 0 deletions api/projects/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
ListUserProjectPermissionSerializer,
ProjectListSerializer,
ProjectRetrieveSerializer,
ProjectUpdateOrCreateSerializer,
)


Expand Down Expand Up @@ -75,10 +76,15 @@ class ProjectViewSet(viewsets.ModelViewSet):
def get_serializer_class(self):
if self.action == "retrieve":
return ProjectRetrieveSerializer
elif self.action in ("create", "update", "partial_update"):
return ProjectUpdateOrCreateSerializer
return ProjectListSerializer

pagination_class = None

def get_serializer_context(self):
return super().get_serializer_context()

def get_queryset(self):
if getattr(self, "swagger_fake_view", False):
return Project.objects.none()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import MagicMock

import pytest
from rest_framework import serializers

from organisations.models import Subscription
Expand All @@ -8,20 +9,29 @@
)


def test_read_only_if_not_valid_plan_mixin_sets_read_only_if_plan_not_valid():
@pytest.mark.parametrize(
"plan_id, invalid_plans_, invalid_plans_regex_",
(
("invalid-plan-id", ("invalid-plan-id",), ""),
("invalid-plan-id", tuple(), "invalid-.*"),
),
)
def test_read_only_if_not_valid_plan_mixin_sets_read_only_if_plan_not_valid(
plan_id: str, invalid_plans_: list[str], invalid_plans_regex_: str
) -> None:
# Given
invalid_plan_id = "invalid-plan-id"

mock_view = MagicMock()

class MySerializer(ReadOnlyIfNotValidPlanMixin, serializers.Serializer):
invalid_plans = (invalid_plan_id,)
field_names = ("foo",)

invalid_plans = invalid_plans_
invalid_plans_regex = invalid_plans_regex_

foo = serializers.CharField()

def get_subscription(self) -> Subscription:
return MagicMock(plan=invalid_plan_id)
return MagicMock(plan=plan_id)

serializer = MySerializer(data={"foo": "bar"}, context={"view": mock_view})

Expand All @@ -37,12 +47,14 @@ def test_read_only_if_not_valid_plan_mixin_does_not_set_read_only_if_plan_valid(
# Given
valid_plan_id = "plan-id"
invalid_plan_id = "invalid-plan-id"
invalid_plans_regex_ = r"^another-invalid-plan-id-.*$"

mock_view = MagicMock()

class MySerializer(ReadOnlyIfNotValidPlanMixin, serializers.Serializer):
invalid_plans = (invalid_plan_id,)
field_names = ("foo",)
invalid_plans_regex = invalid_plans_regex_

foo = serializers.CharField()

Expand Down
41 changes: 35 additions & 6 deletions api/tests/unit/projects/test_unit_projects_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from environments.dynamodb.types import ProjectIdentityMigrationStatus
from environments.identities.models import Identity
from features.models import Feature, FeatureSegment
from organisations.models import Organisation, OrganisationRole
from organisations.models import Organisation, OrganisationRole, Subscription
from organisations.permissions.models import (
OrganisationPermissionModel,
UserOrganisationPermission,
Expand All @@ -39,11 +39,24 @@
yesterday = now - timedelta(days=1)


def test_should_create_a_project(settings, admin_user, admin_client, organisation):
def test_should_create_a_project(
settings: SettingsWrapper,
admin_user: FFAdminUser,
admin_client: APIClient,
organisation: Organisation,
enterprise_subscription: Subscription,
) -> None:
# Given
project_name = "project1"
settings.PROJECT_METADATA_TABLE_NAME_DYNAMO = None
data = {"name": project_name, "organisation": organisation.id}

project_name = "project1"
stale_flags_limit_days = 15

data = {
"name": project_name,
"organisation": organisation.id,
"stale_flags_limit_days": stale_flags_limit_days,
}
url = reverse("api-v1:projects:project-list")

# When
Expand All @@ -52,6 +65,10 @@ def test_should_create_a_project(settings, admin_user, admin_client, organisatio
# Then
assert response.status_code == status.HTTP_201_CREATED
assert Project.objects.filter(name=project_name).count() == 1

project = Project.objects.get(name=project_name)
assert project.stale_flags_limit_days == stale_flags_limit_days

assert (
response.json()["migration_status"]
== ProjectIdentityMigrationStatus.NOT_APPLICABLE.value
Expand Down Expand Up @@ -95,10 +112,21 @@ def test_should_create_a_project_with_admin_master_api_key_client(
"client",
[(lazy_fixture("admin_master_api_key_client")), (lazy_fixture("admin_client"))],
)
def test_can_update_project(client, project, organisation):
def test_can_update_project(
client: APIClient,
project: Project,
organisation: Organisation,
enterprise_subscription: Subscription,
) -> None:
# Given
new_name = "New project name"
data = {"name": new_name, "organisation": organisation.id}
new_stale_flags_limit_days = 15

data = {
"name": new_name,
"organisation": organisation.id,
"stale_flags_limit_days": new_stale_flags_limit_days,
}
url = reverse("api-v1:projects:project-detail", args=[project.id])

# When
Expand All @@ -107,6 +135,7 @@ def test_can_update_project(client, project, organisation):
# Then
assert response.status_code == status.HTTP_200_OK
assert response.json()["name"] == new_name
assert response.json()["stale_flags_limit_days"] == new_stale_flags_limit_days


@pytest.mark.parametrize(
Expand Down

0 comments on commit 99e0148

Please sign in to comment.