From ff962a413ef0c909da17960b1cab2f206d7d2ae1 Mon Sep 17 00:00:00 2001 From: Matthew Elwell Date: Fri, 16 Feb 2024 10:29:04 +0000 Subject: [PATCH] Add query param to exclude / include feature specific --- api/conftest.py | 7 +++ api/segments/serializers.py | 12 +++++ api/segments/views.py | 35 +++++--------- .../unit/segments/test_unit_segments_views.py | 47 +++++++++++++++++++ 4 files changed, 79 insertions(+), 22 deletions(-) diff --git a/api/conftest.py b/api/conftest.py index fe9ed05b92ac..df1fd16de462 100644 --- a/api/conftest.py +++ b/api/conftest.py @@ -184,6 +184,13 @@ def segment_rule(segment): return SegmentRule.objects.create(segment=segment, type=SegmentRule.ALL_RULE) +@pytest.fixture() +def feature_specific_segment(feature: Feature) -> Segment: + return Segment.objects.create( + feature=feature, name="feature specific segment", project=feature.project + ) + + @pytest.fixture() def environment(project): return Environment.objects.create(name="Test Environment", project=project) diff --git a/api/segments/serializers.py b/api/segments/serializers.py index 4fc9597dbdaa..f5364c408449 100644 --- a/api/segments/serializers.py +++ b/api/segments/serializers.py @@ -188,3 +188,15 @@ class SegmentSerializerBasic(serializers.ModelSerializer): class Meta: model = Segment fields = ("id", "name", "description") + + +class SegmentListQuerySerializer(serializers.Serializer): + q = serializers.CharField( + required=False, + help_text="Search term to find segment with given term in their name", + ) + identity = serializers.CharField( + required=False, + help_text="Optionally provide the id of an identity to get only the segments they match", + ) + include_feature_specific = serializers.BooleanField(required=False, default=True) diff --git a/api/segments/views.py b/api/segments/views.py index ec864edf20f6..63673b02447a 100644 --- a/api/segments/views.py +++ b/api/segments/views.py @@ -1,7 +1,6 @@ import logging from django.utils.decorators import method_decorator -from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema from rest_framework import viewsets from rest_framework.decorators import action, api_view @@ -17,31 +16,14 @@ from .models import Segment from .permissions import SegmentPermissions -from .serializers import SegmentSerializer +from .serializers import SegmentListQuerySerializer, SegmentSerializer logger = logging.getLogger() @method_decorator( name="list", - decorator=swagger_auto_schema( - manual_parameters=[ - openapi.Parameter( - "identity", - openapi.IN_QUERY, - "Optionally provide the id of an identity to get only the segments they match", - required=False, - type=openapi.TYPE_INTEGER, - ), - openapi.Parameter( - "q", - openapi.IN_QUERY, - "Search term to find segment with given term in their name", - required=False, - type=openapi.TYPE_STRING, - ), - ] - ), + decorator=swagger_auto_schema(query_serializer=SegmentListQuerySerializer()), ) class SegmentViewSet(viewsets.ModelViewSet): serializer_class = SegmentSerializer @@ -70,7 +52,10 @@ def get_queryset(self): "rules__rules__rules", ) - identity_pk = self.request.query_params.get("identity") + query_serializer = SegmentListQuerySerializer(data=self.request.query_params) + query_serializer.is_valid(raise_exception=True) + + identity_pk = query_serializer.validated_data.get("identity") if identity_pk: if identity_pk.isdigit(): identity = Identity.objects.get(pk=identity_pk) @@ -79,10 +64,16 @@ def get_queryset(self): segment_ids = EdgeIdentity.dynamo_wrapper.get_segment_ids(identity_pk) queryset = queryset.filter(id__in=segment_ids) - search_term = self.request.query_params.get("q") + search_term = query_serializer.validated_data.get("q") if search_term: queryset = queryset.filter(name__icontains=search_term) + include_feature_specific = query_serializer.validated_data[ + "include_feature_specific" + ] + if include_feature_specific is False: + queryset = queryset.filter(feature__isnull=True) + return queryset @action( diff --git a/api/tests/unit/segments/test_unit_segments_views.py b/api/tests/unit/segments/test_unit_segments_views.py index 54df65ad9cc8..155838f1a80f 100644 --- a/api/tests/unit/segments/test_unit_segments_views.py +++ b/api/tests/unit/segments/test_unit_segments_views.py @@ -15,7 +15,9 @@ from environments.models import Environment from features.models import Feature from projects.models import Project +from projects.permissions import MANAGE_SEGMENTS, VIEW_PROJECT from segments.models import Condition, Segment, SegmentRule, WhitelistedSegment +from tests.types import WithProjectPermissionsCallable from util.mappers import map_identity_to_identity_document User = get_user_model() @@ -782,3 +784,48 @@ def test_create_segment_obeys_max_conditions( } assert Segment.objects.count() == 0 + + +def test_include_feature_specific_query_filter__true( + staff_client: APIClient, + with_project_permissions: WithProjectPermissionsCallable, + project: Project, + segment: Segment, + feature_specific_segment: Segment, +) -> None: + # Given + with_project_permissions([MANAGE_SEGMENTS, VIEW_PROJECT]) + url = "%s?include_feature_specific=1" % ( + reverse("api-v1:projects:project-segments-list", args=[project.id]), + ) + + # When + response = staff_client.get(url) + + # Then + assert response.json()["count"] == 2 + assert [res["id"] for res in response.json()["results"]] == [ + segment.id, + feature_specific_segment.id, + ] + + +def test_include_feature_specific_query_filter__false( + staff_client: APIClient, + with_project_permissions: WithProjectPermissionsCallable, + project: Project, + segment: Segment, + feature_specific_segment: Segment, +) -> None: + # Given + with_project_permissions([MANAGE_SEGMENTS, VIEW_PROJECT]) + url = "%s?include_feature_specific=0" % ( + reverse("api-v1:projects:project-segments-list", args=[project.id]), + ) + + # When + response = staff_client.get(url) + + # Then + assert response.json()["count"] == 1 + assert [res["id"] for res in response.json()["results"]] == [segment.id]