From 7a6a2c8f929bc079526a852494e3cfb87f796fb3 Mon Sep 17 00:00:00 2001 From: Matthew Elwell Date: Thu, 17 Aug 2023 10:55:46 +0100 Subject: [PATCH] feat: re-add totals and limits (#2631) --- api/audit/serializers.py | 4 +- api/environments/serializers.py | 14 +++- api/environments/views.py | 14 +++- api/organisations/views.py | 4 +- api/projects/serializers.py | 36 +++++++- api/projects/tests/test_serializers.py | 44 ++++++++-- api/projects/views.py | 9 +- .../unit/projects/test_unit_projects_views.py | 82 +++++++++++++++++++ 8 files changed, 186 insertions(+), 21 deletions(-) create mode 100644 api/tests/unit/projects/test_unit_projects_views.py diff --git a/api/audit/serializers.py b/api/audit/serializers.py index e4de2d68bb48..a936a137fe1e 100644 --- a/api/audit/serializers.py +++ b/api/audit/serializers.py @@ -2,14 +2,14 @@ from audit.models import AuditLog from environments.serializers import EnvironmentSerializerLight -from projects.serializers import ProjectSerializer +from projects.serializers import ProjectListSerializer from users.serializers import UserListSerializer class AuditLogSerializer(serializers.ModelSerializer): author = UserListSerializer() environment = EnvironmentSerializerLight() - project = ProjectSerializer() + project = ProjectListSerializer() class Meta: model = AuditLog diff --git a/api/environments/serializers.py b/api/environments/serializers.py index 86d8e72d894e..19609da9be4d 100644 --- a/api/environments/serializers.py +++ b/api/environments/serializers.py @@ -10,7 +10,7 @@ ReadOnlyIfNotValidPlanMixin, ) from projects.models import Project -from projects.serializers import ProjectSerializer +from projects.serializers import ProjectListSerializer from util.drf_writable_nested.serializers import ( DeleteBeforeUpdateWritableNestedModelSerializer, ) @@ -18,7 +18,7 @@ class EnvironmentSerializerFull(serializers.ModelSerializer): feature_states = FeatureStateSerializerFull(many=True) - project = ProjectSerializer() + project = ProjectListSerializer() class Meta: model = Environment @@ -86,6 +86,16 @@ def get_project(self, validated_data: dict = None) -> Project: ) +class EnvironmentRetrieveSerializerWithMetadata(EnvironmentSerializerWithMetadata): + total_segment_overrides = serializers.IntegerField() + + class Meta(EnvironmentSerializerWithMetadata.Meta): + fields = EnvironmentSerializerWithMetadata.Meta.fields + ( + "total_segment_overrides", + ) + read_only_fields = ("total_segment_overrides",) + + class CreateUpdateEnvironmentSerializer( ReadOnlyIfNotValidPlanMixin, EnvironmentSerializerWithMetadata ): diff --git a/api/environments/views.py b/api/environments/views.py index 34129decf3ad..e23209558273 100644 --- a/api/environments/views.py +++ b/api/environments/views.py @@ -3,6 +3,7 @@ import logging +from django.db.models import Count from django.utils.decorators import method_decorator from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema @@ -41,6 +42,7 @@ CloneEnvironmentSerializer, CreateUpdateEnvironmentSerializer, EnvironmentAPIKeySerializer, + EnvironmentRetrieveSerializerWithMetadata, EnvironmentSerializerWithMetadata, WebhookSerializer, ) @@ -73,6 +75,8 @@ def get_serializer_class(self): return DeleteAllTraitKeysSerializer if self.action == "clone": return CloneEnvironmentSerializer + if self.action == "retrieve": + return EnvironmentRetrieveSerializerWithMetadata elif self.action in ("create", "update", "partial_update"): return CreateUpdateEnvironmentSerializer return EnvironmentSerializerWithMetadata @@ -98,12 +102,20 @@ def get_queryset(self): return ( self.request.master_api_key.organisation.projects.environments.all() ) + return self.request.user.get_permitted_environments( "VIEW_ENVIRONMENT", project=project ) # Permission class handles validation of permissions for other actions - return Environment.objects.all() + queryset = Environment.objects.all() + + if self.action == "retrieve": + queryset = queryset.annotate( + total_segment_overrides=Count("feature_segments") + ) + + return queryset def perform_create(self, serializer): environment = serializer.save() diff --git a/api/organisations/views.py b/api/organisations/views.py index e15c5a53c58d..998a16373857 100644 --- a/api/organisations/views.py +++ b/api/organisations/views.py @@ -49,7 +49,7 @@ PermissionModelSerializer, UserObjectPermissionsSerializer, ) -from projects.serializers import ProjectSerializer +from projects.serializers import ProjectListSerializer from users.serializers import UserIdSerializer from webhooks.mixins import TriggerSampleWebhookMixin from webhooks.webhooks import WebhookType @@ -118,7 +118,7 @@ def create(self, request, **kwargs): def projects(self, request, pk): organisation = self.get_object() projects = organisation.projects.all() - return Response(ProjectSerializer(projects, many=True).data) + return Response(ProjectListSerializer(projects, many=True).data) @action(detail=True, methods=["POST"]) def invite(self, request, pk): diff --git a/api/projects/serializers.py b/api/projects/serializers.py index 10914a5fa1ff..58582d22eb6f 100644 --- a/api/projects/serializers.py +++ b/api/projects/serializers.py @@ -12,7 +12,7 @@ from users.serializers import UserListSerializer, UserPermissionGroupSerializer -class ProjectSerializer(serializers.ModelSerializer): +class ProjectListSerializer(serializers.ModelSerializer): migration_status = serializers.SerializerMethodField( help_text="Edge migration status of the project; can be one of: " + ", ".join([k.value for k in ProjectIdentityMigrationStatus]) @@ -39,10 +39,8 @@ class Meta: def get_migration_status(self, obj: Project) -> str: if not settings.PROJECT_METADATA_TABLE_NAME_DYNAMO: migration_status = ProjectIdentityMigrationStatus.NOT_APPLICABLE.value - elif obj.is_edge_project_by_default: migration_status = ProjectIdentityMigrationStatus.MIGRATION_COMPLETED.value - else: migration_status = IdentityMigrator(obj.id).migration_status.value @@ -58,6 +56,38 @@ def get_use_edge_identities(self, obj: Project) -> bool: ) +class ProjectRetrieveSerializer(ProjectListSerializer): + total_features = serializers.SerializerMethodField() + total_segments = serializers.SerializerMethodField() + + class Meta(ProjectListSerializer.Meta): + fields = ProjectListSerializer.Meta.fields + ( + "max_segments_allowed", + "max_features_allowed", + "max_segment_overrides_allowed", + "total_features", + "total_segments", + ) + + read_only_fields = ( + "max_segments_allowed", + "max_features_allowed", + "max_segment_overrides_allowed", + "total_features", + "total_segments", + ) + + def get_total_features(self, instance: Project) -> int: + # added here to prevent need for annotate(Count("features", distinct=True)) + # which causes performance issues. + return instance.features.count() + + def get_total_segments(self, instance: Project) -> int: + # added here to prevent need for annotate(Count("segments", distinct=True)) + # which causes performance issues. + return instance.segments.count() + + class CreateUpdateUserProjectPermissionSerializer( CreateUpdateUserPermissionSerializerABC ): diff --git a/api/projects/tests/test_serializers.py b/api/projects/tests/test_serializers.py index 13bb9d00b025..079b87106175 100644 --- a/api/projects/tests/test_serializers.py +++ b/api/projects/tests/test_serializers.py @@ -4,10 +4,13 @@ from django.utils import timezone from environments.dynamodb.types import ProjectIdentityMigrationStatus -from projects.serializers import ProjectSerializer +from projects.serializers import ( + ProjectListSerializer, + ProjectRetrieveSerializer, +) -def test_ProjectSerializer_get_migration_status_returns_migration_not_applicable_if_not_configured( +def test_ProjectListSerializer_get_migration_status_returns_migration_not_applicable_if_not_configured( mocker, project, settings ): # Given @@ -16,7 +19,7 @@ def test_ProjectSerializer_get_migration_status_returns_migration_not_applicable "projects.serializers.IdentityMigrator", autospec=True ) - serializer = ProjectSerializer() + serializer = ProjectListSerializer() # When migration_status = serializer.get_migration_status(project) @@ -26,7 +29,7 @@ def test_ProjectSerializer_get_migration_status_returns_migration_not_applicable mocked_identity_migrator.assert_not_called() -def test_ProjectSerializer_get_migration_status_returns_migration_completed_for_new_projects( +def test_ProjectListSerializer_get_migration_status_returns_migration_completed_for_new_projects( mocker, project, settings ): # Given @@ -36,7 +39,7 @@ def test_ProjectSerializer_get_migration_status_returns_migration_completed_for_ "projects.serializers.IdentityMigrator", autospec=True ) - serializer = ProjectSerializer() + serializer = ProjectListSerializer() # When migration_status = serializer.get_migration_status(project) @@ -46,7 +49,7 @@ def test_ProjectSerializer_get_migration_status_returns_migration_completed_for_ mocked_identity_migrator.assert_not_called() -def test_ProjectSerializer_get_migration_status_calls_migrator_with_correct_arguments_for_old_projects( +def test_ProjectListSerializer_get_migration_status_calls_migrator_with_correct_arguments_for_old_projects( mocker, project, settings ): # Given @@ -57,7 +60,7 @@ def test_ProjectSerializer_get_migration_status_calls_migrator_with_correct_argu settings.EDGE_RELEASE_DATETIME = timezone.now() - serializer = ProjectSerializer() + serializer = ProjectListSerializer() # When migration_status = serializer.get_migration_status(project) @@ -78,9 +81,32 @@ def test_ProjectSerializer_get_migration_status_calls_migrator_with_correct_argu (ProjectIdentityMigrationStatus.NOT_APPLICABLE.value, False), ], ) -def test_ProjectSerializer_get_use_edge_identities(project, migration_status, expected): +def test_ProjectListSerializer_get_use_edge_identities( + project, migration_status, expected +): # Given - serializer = ProjectSerializer(context={"migration_status": migration_status}) + serializer = ProjectListSerializer(context={"migration_status": migration_status}) + + # When/Then + assert expected is serializer.get_use_edge_identities(project) + + +@pytest.mark.parametrize( + "migration_status, expected", + [ + (ProjectIdentityMigrationStatus.MIGRATION_COMPLETED.value, True), + (ProjectIdentityMigrationStatus.MIGRATION_IN_PROGRESS.value, False), + (ProjectIdentityMigrationStatus.MIGRATION_NOT_STARTED.value, False), + (ProjectIdentityMigrationStatus.NOT_APPLICABLE.value, False), + ], +) +def test_ProjectRetrieveSerializer_get_use_edge_identities( + project, migration_status, expected +): + # Given + serializer = ProjectRetrieveSerializer( + context={"migration_status": migration_status} + ) # When/Then assert expected is serializer.get_use_edge_identities(project) diff --git a/api/projects/views.py b/api/projects/views.py index 4a47707406a8..0de363563c67 100644 --- a/api/projects/views.py +++ b/api/projects/views.py @@ -44,7 +44,8 @@ CreateUpdateUserProjectPermissionSerializer, ListUserPermissionGroupProjectPermissionSerializer, ListUserProjectPermissionSerializer, - ProjectSerializer, + ProjectListSerializer, + ProjectRetrieveSerializer, ) @@ -70,7 +71,11 @@ ), ) class ProjectViewSet(viewsets.ModelViewSet): - serializer_class = ProjectSerializer + def get_serializer_class(self): + if self.action == "retrieve": + return ProjectRetrieveSerializer + return ProjectListSerializer + permission_classes = [ProjectPermissions | MasterAPIKeyProjectPermissions] pagination_class = None diff --git a/api/tests/unit/projects/test_unit_projects_views.py b/api/tests/unit/projects/test_unit_projects_views.py new file mode 100644 index 000000000000..f2b7100e39bb --- /dev/null +++ b/api/tests/unit/projects/test_unit_projects_views.py @@ -0,0 +1,82 @@ +import pytest +from django.urls import reverse +from pytest_lazyfixture import lazy_fixture +from rest_framework import status + +from features.models import Feature +from projects.models import Project +from segments.models import Segment + + +@pytest.mark.parametrize( + "client", (lazy_fixture("admin_client"), lazy_fixture("master_api_key_client")) +) +def test_get_project_list_data(client, organisation): + # Given + list_url = reverse("api-v1:projects:project-list") + + project_name = "Test project" + hide_disabled_flags = False + enable_dynamo_db = False + prevent_flag_defaults = True + enable_realtime_updates = False + only_allow_lower_case_feature_names = True + + Project.objects.create( + name=project_name, + organisation=organisation, + hide_disabled_flags=hide_disabled_flags, + enable_dynamo_db=enable_dynamo_db, + prevent_flag_defaults=prevent_flag_defaults, + enable_realtime_updates=enable_realtime_updates, + only_allow_lower_case_feature_names=only_allow_lower_case_feature_names, + ) + + # When + response = client.get(list_url) + + # Then + assert response.status_code == status.HTTP_200_OK + assert response.json()[0]["name"] == project_name + assert response.json()[0]["hide_disabled_flags"] is hide_disabled_flags + assert response.json()[0]["enable_dynamo_db"] is enable_dynamo_db + assert response.json()[0]["prevent_flag_defaults"] is prevent_flag_defaults + assert response.json()[0]["enable_realtime_updates"] is enable_realtime_updates + assert ( + response.json()[0]["only_allow_lower_case_feature_names"] + is only_allow_lower_case_feature_names + ) + assert "max_segments_allowed" not in response.json()[0].keys() + assert "max_features_allowed" not in response.json()[0].keys() + assert "max_segment_overrides_allowed" not in response.json()[0].keys() + assert "total_features" not in response.json()[0].keys() + assert "total_segments" not in response.json()[0].keys() + + +@pytest.mark.parametrize( + "client", (lazy_fixture("admin_client"), lazy_fixture("master_api_key_client")) +) +def test_get_project_data_by_id(client, organisation, project): + # Given + url = reverse("api-v1:projects:project-detail", args=[project.id]) + + num_features = 5 + num_segments = 7 + + for i in range(num_features): + Feature.objects.create(name=f"feature_{i}", project=project) + + for i in range(num_segments): + Segment.objects.create(name=f"feature_{i}", project=project) + + # When + response = client.get(url) + + # Then + assert response.status_code == status.HTTP_200_OK + assert response.json()["name"] == project.name + assert response.json()["max_segments_allowed"] == 100 + assert response.json()["max_features_allowed"] == 400 + assert response.json()["max_segment_overrides_allowed"] == 100 + assert response.json()["total_features"] == num_features + assert response.json()["total_segments"] == num_segments