diff --git a/api/conftest.py b/api/conftest.py index 625f1094194f..e7537ea549f1 100644 --- a/api/conftest.py +++ b/api/conftest.py @@ -659,23 +659,23 @@ def task_processor_synchronously(settings): @pytest.fixture() -def a_metadata_field(organisation): +def a_metadata_field(organisation: Organisation) -> MetadataField: return MetadataField.objects.create(name="a", type="int", organisation=organisation) @pytest.fixture() -def b_metadata_field(organisation): +def b_metadata_field(organisation: Organisation) -> MetadataField: return MetadataField.objects.create(name="b", type="str", organisation=organisation) @pytest.fixture() def required_a_environment_metadata_field( - organisation, - a_metadata_field, - environment, - project, - project_content_type, -): + organisation: Organisation, + a_metadata_field: MetadataField, + environment: Environment, + project: Project, + project_content_type: ContentType, +) -> MetadataModelField: environment_type = ContentType.objects.get_for_model(environment) model_field = MetadataModelField.objects.create( field=a_metadata_field, @@ -689,7 +689,119 @@ def required_a_environment_metadata_field( @pytest.fixture() -def optional_b_environment_metadata_field(organisation, b_metadata_field, environment): +def required_a_feature_metadata_field( + organisation: Organisation, + a_metadata_field: MetadataField, + feature_content_type: ContentType, + project: Project, + project_content_type: ContentType, +) -> MetadataModelField: + model_field = MetadataModelField.objects.create( + field=a_metadata_field, + content_type=feature_content_type, + ) + + MetadataModelFieldRequirement.objects.create( + content_type=project_content_type, object_id=project.id, model_field=model_field + ) + + return model_field + + +@pytest.fixture() +def required_a_feature_metadata_field_using_organisation_content_type( + organisation: Organisation, + a_metadata_field: MetadataField, + feature_content_type: ContentType, + project: Project, + organisation_content_type: ContentType, +) -> MetadataModelField: + model_field = MetadataModelField.objects.create( + field=a_metadata_field, + content_type=feature_content_type, + ) + + MetadataModelFieldRequirement.objects.create( + content_type=organisation_content_type, + object_id=organisation.id, + model_field=model_field, + ) + + return model_field + + +@pytest.fixture() +def required_a_segment_metadata_field( + organisation: Organisation, + a_metadata_field: MetadataField, + segment_content_type: ContentType, + project: Project, + project_content_type: ContentType, +) -> MetadataModelField: + model_field = MetadataModelField.objects.create( + field=a_metadata_field, + content_type=segment_content_type, + ) + + MetadataModelFieldRequirement.objects.create( + content_type=project_content_type, object_id=project.id, model_field=model_field + ) + + return model_field + + +@pytest.fixture() +def required_a_segment_metadata_field_using_organisation_content_type( + organisation: Organisation, + a_metadata_field: MetadataField, + segment_content_type: ContentType, + project: Project, + organisation_content_type: ContentType, +) -> MetadataModelField: + model_field = MetadataModelField.objects.create( + field=a_metadata_field, + content_type=segment_content_type, + ) + + MetadataModelFieldRequirement.objects.create( + content_type=organisation_content_type, + object_id=organisation.id, + model_field=model_field, + ) + + return model_field + + +@pytest.fixture() +def optional_b_feature_metadata_field( + organisation: Organisation, b_metadata_field: MetadataField, feature: Feature +) -> MetadataModelField: + feature_type = ContentType.objects.get_for_model(feature) + + return MetadataModelField.objects.create( + field=b_metadata_field, + content_type=feature_type, + ) + + +@pytest.fixture() +def optional_b_segment_metadata_field( + organisation: Organisation, b_metadata_field: MetadataField, segment: Segment +) -> MetadataModelField: + segment_type = ContentType.objects.get_for_model(segment) + + return MetadataModelField.objects.create( + field=b_metadata_field, + content_type=segment_type, + ) + + +@pytest.fixture() +def optional_b_environment_metadata_field( + organisation: Organisation, + b_metadata_field: MetadataField, + environment: Environment, +) -> MetadataModelField: environment_type = ContentType.objects.get_for_model(environment) return MetadataModelField.objects.create( @@ -699,7 +811,10 @@ def optional_b_environment_metadata_field(organisation, b_metadata_field, enviro @pytest.fixture() -def environment_metadata_a(environment, required_a_environment_metadata_field): +def environment_metadata_a( + environment: Environment, + required_a_environment_metadata_field: MetadataModelField, +) -> Metadata: environment_type = ContentType.objects.get_for_model(environment) return Metadata.objects.create( object_id=environment.id, @@ -710,7 +825,10 @@ def environment_metadata_a(environment, required_a_environment_metadata_field): @pytest.fixture() -def environment_metadata_b(environment, optional_b_environment_metadata_field): +def environment_metadata_b( + environment: Environment, + optional_b_environment_metadata_field: MetadataModelField, +) -> Metadata: environment_type = ContentType.objects.get_for_model(environment) return Metadata.objects.create( object_id=environment.id, @@ -721,15 +839,30 @@ def environment_metadata_b(environment, optional_b_environment_metadata_field): @pytest.fixture() -def environment_content_type(): +def environment_content_type() -> ContentType: return ContentType.objects.get_for_model(Environment) @pytest.fixture() -def project_content_type(): +def feature_content_type() -> ContentType: + return ContentType.objects.get_for_model(Feature) + + +@pytest.fixture() +def segment_content_type() -> ContentType: + return ContentType.objects.get_for_model(Segment) + + +@pytest.fixture() +def project_content_type() -> ContentType: return ContentType.objects.get_for_model(Project) +@pytest.fixture() +def organisation_content_type() -> ContentType: + return ContentType.objects.get_for_model(Organisation) + + @pytest.fixture def manage_user_group_permission(db): return OrganisationPermissionModel.objects.get(key=MANAGE_USER_GROUPS) diff --git a/api/features/models.py b/api/features/models.py index b00737bba703..59014edac2c4 100644 --- a/api/features/models.py +++ b/api/features/models.py @@ -12,6 +12,7 @@ SoftDeleteExportableModel, abstract_base_auditable_model_factory, ) +from django.contrib.contenttypes.fields import GenericRelation from django.core.exceptions import ( NON_FIELD_ERRORS, ObjectDoesNotExist, @@ -75,6 +76,7 @@ ) from features.versioning.models import EnvironmentFeatureVersion from integrations.github.models import GithubConfiguration +from metadata.models import Metadata from projects.models import Project from projects.tags.models import Tag @@ -129,6 +131,8 @@ class Feature( objects = FeatureManager() + metadata = GenericRelation(Metadata) + class Meta: # Note: uniqueness index is added in explicit SQL in the migrations (See 0005, 0050) # TODO: after upgrade to Django 4.0 use UniqueConstraint() diff --git a/api/features/serializers.py b/api/features/serializers.py index b5f5748988e0..242710c65de9 100644 --- a/api/features/serializers.py +++ b/api/features/serializers.py @@ -12,6 +12,7 @@ from environments.sdk.serializers_mixins import ( HideSensitiveFieldsSerializerMixin, ) +from metadata.serializers import MetadataSerializer, SerializerWithMetadata from projects.models import Project from users.serializers import ( UserIdsSerializer, @@ -296,17 +297,44 @@ def get_last_modified_in_current_environment( return getattr(instance, "last_modified_in_current_environment", None) -class ListFeatureSerializer(CreateFeatureSerializer): +class FeatureSerializerWithMetadata(SerializerWithMetadata, CreateFeatureSerializer): + metadata = MetadataSerializer(required=False, many=True) + + class Meta(CreateFeatureSerializer.Meta): + fields = CreateFeatureSerializer.Meta.fields + ("metadata",) + + def get_project(self, validated_data: dict = None) -> Project: + project = self.context.get("project") + if project: + return project + else: + raise serializers.ValidationError( + "Unable to retrieve project for metadata validation." + ) + + +class UpdateFeatureSerializerWithMetadata(FeatureSerializerWithMetadata): + """prevent users from changing certain values after creation""" + + class Meta(FeatureSerializerWithMetadata.Meta): + read_only_fields = FeatureSerializerWithMetadata.Meta.read_only_fields + ( + "default_enabled", + "initial_value", + "name", + ) + + +class ListFeatureSerializer(FeatureSerializerWithMetadata): # This exists purely to reduce the conflicts for the EE repository # which has some extra behaviour here to support Oracle DB. pass -class UpdateFeatureSerializer(CreateFeatureSerializer): +class UpdateFeatureSerializer(ListFeatureSerializer): """prevent users from changing certain values after creation""" - class Meta(CreateFeatureSerializer.Meta): - read_only_fields = CreateFeatureSerializer.Meta.read_only_fields + ( + class Meta(ListFeatureSerializer.Meta): + read_only_fields = ListFeatureSerializer.Meta.read_only_fields + ( "default_enabled", "initial_value", "name", diff --git a/api/features/views.py b/api/features/views.py index 323bef8dd175..6ac0312b8ebe 100644 --- a/api/features/views.py +++ b/api/features/views.py @@ -107,8 +107,8 @@ class FeatureViewSet(viewsets.ModelViewSet): def get_serializer_class(self): return { "list": ListFeatureSerializer, - "retrieve": CreateFeatureSerializer, - "create": CreateFeatureSerializer, + "retrieve": ListFeatureSerializer, + "create": ListFeatureSerializer, "update": UpdateFeatureSerializer, "partial_update": UpdateFeatureSerializer, }.get(self.action, ProjectFeatureSerializer) @@ -131,7 +131,9 @@ def get_queryset(self): ), ), ) - .prefetch_related("multivariate_options", "owners", "tags", "group_owners") + .prefetch_related( + "multivariate_options", "owners", "tags", "group_owners", "metadata" + ) ) query_serializer = FeatureQuerySerializer(data=self.request.query_params) diff --git a/api/metadata/models.py b/api/metadata/models.py index ccdd376780c8..26b228a5e283 100644 --- a/api/metadata/models.py +++ b/api/metadata/models.py @@ -6,22 +6,16 @@ from django.db import models from organisations.models import Organisation -from projects.models import Project from .fields import GenericObjectID FIELD_VALUE_MAX_LENGTH = 2000 -METADATA_SUPPORTED_MODELS = ["environment"] - # A map of model name to a function that takes the object id and returns the organisation_id SUPPORTED_REQUIREMENTS_MAPPING = { - "environment": { - "organisation": lambda org_id: org_id, - "project": lambda project_id: Project.objects.get( - id=project_id - ).organisation_id, - } + "environment": ["organisation", "project"], + "feature": ["organisation", "project"], + "segment": ["organisation", "project"], } diff --git a/api/metadata/serializers.py b/api/metadata/serializers.py index 3d5f9a1e6f50..a6fdddf89582 100644 --- a/api/metadata/serializers.py +++ b/api/metadata/serializers.py @@ -9,7 +9,6 @@ ) from .models import ( - SUPPORTED_REQUIREMENTS_MAPPING, Metadata, MetadataField, MetadataModelField, @@ -55,21 +54,13 @@ class Meta: def validate(self, data): data = super().validate(data) for requirement in data.get("is_required_for", []): - try: - get_org_id_func = SUPPORTED_REQUIREMENTS_MAPPING[ - data["content_type"].model - ][requirement["content_type"].model] - except KeyError: - raise serializers.ValidationError( - "Invalid requirement for model {}".format( - data["content_type"].model - ) - ) - - if ( - get_org_id_func(requirement["object_id"]) - != data["field"].organisation_id - ): + org_id = ( + requirement["content_type"] + .model_class() + .objects.get(id=requirement["object_id"]) + .organisation_id + ) + if org_id != data["field"].organisation_id: raise serializers.ValidationError( "The requirement organisation does not match the field organisation" ) diff --git a/api/metadata/views.py b/api/metadata/views.py index 098e266318dc..49680ee866ec 100644 --- a/api/metadata/views.py +++ b/api/metadata/views.py @@ -1,13 +1,14 @@ +from itertools import chain + from django.contrib.contenttypes.models import ContentType from django.utils.decorators import method_decorator from drf_yasg.utils import swagger_auto_schema -from rest_framework import viewsets +from rest_framework import status, viewsets from rest_framework.decorators import action from rest_framework.exceptions import ValidationError from rest_framework.response import Response from .models import ( - METADATA_SUPPORTED_MODELS, SUPPORTED_REQUIREMENTS_MAPPING, MetadataField, MetadataModelField, @@ -81,7 +82,13 @@ def get_queryset(self): url_path="supported-content-types", ) def supported_content_types(self, request, organisation_pk=None): - qs = ContentType.objects.filter(model__in=METADATA_SUPPORTED_MODELS) + need_content_type_of = list( + chain.from_iterable( + (key, *value) for key, value in SUPPORTED_REQUIREMENTS_MAPPING.items() + ) + ) + + qs = ContentType.objects.filter(model__in=need_content_type_of) serializer = ContentTypeSerializer(qs, many=True) return Response(serializer.data) @@ -100,11 +107,16 @@ def supported_required_for_models(self, request, organisation_pk=None): serializer = SupportedRequiredForModelQuerySerializer(data=request.query_params) serializer.is_valid(raise_exception=True) - qs = ContentType.objects.filter( - model__in=SUPPORTED_REQUIREMENTS_MAPPING.get( - serializer.data["model_name"], {} - ).keys() + supported_models = SUPPORTED_REQUIREMENTS_MAPPING.get( + serializer.data["model_name"], [] ) + if not supported_models: + return Response( + {"message": "No supported models found for the given model name."}, + status=status.HTTP_404_NOT_FOUND, + ) + + qs = ContentType.objects.filter(model__in=supported_models) serializer = ContentTypeSerializer(qs, many=True) return Response(serializer.data) diff --git a/api/segments/models.py b/api/segments/models.py index 78ddd68ae31d..dea6f980a7ea 100644 --- a/api/segments/models.py +++ b/api/segments/models.py @@ -7,6 +7,7 @@ abstract_base_auditable_model_factory, ) from django.conf import settings +from django.contrib.contenttypes.fields import GenericRelation from django.core.exceptions import ValidationError from django.db import models from flag_engine.segments import constants @@ -18,6 +19,7 @@ ) from audit.related_object_type import RelatedObjectType from features.models import Feature +from metadata.models import Metadata from projects.models import Project logger = logging.getLogger(__name__) @@ -43,6 +45,8 @@ class Segment( Feature, on_delete=models.CASCADE, related_name="segments", null=True ) + metadata = GenericRelation(Metadata) + class Meta: ordering = ("id",) # explicit ordering to prevent pagination warnings diff --git a/api/segments/serializers.py b/api/segments/serializers.py index f5364c408449..1909ddebafd7 100644 --- a/api/segments/serializers.py +++ b/api/segments/serializers.py @@ -1,12 +1,15 @@ import typing from django.conf import settings +from django.contrib.contenttypes.models import ContentType from flag_engine.segments.constants import PERCENTAGE_SPLIT from rest_framework import serializers from rest_framework.exceptions import ValidationError from rest_framework.serializers import ListSerializer from rest_framework_recursive.fields import RecursiveField +from metadata.models import Metadata +from metadata.serializers import MetadataSerializer, SerializerWithMetadata from projects.models import Project from segments.models import Condition, Segment, SegmentRule @@ -40,25 +43,32 @@ class Meta: fields = ("id", "type", "rules", "conditions", "delete") -class SegmentSerializer(serializers.ModelSerializer): +class SegmentSerializer(serializers.ModelSerializer, SerializerWithMetadata): rules = RuleSerializer(many=True) + metadata = MetadataSerializer(required=False, many=True) class Meta: model = Segment fields = "__all__" def validate(self, attrs): + attrs = super().validate(attrs) + self.validate_required_metadata(attrs) if not attrs.get("rules"): raise ValidationError( {"rules": "Segment cannot be created without any rules."} ) return attrs + def get_project(self, validated_data: dict = None) -> Project: + return validated_data.get("project") + def create(self, validated_data): project = validated_data["project"] self.validate_project_segment_limit(project) rules_data = validated_data.pop("rules", []) + metadata_data = validated_data.pop("metadata", []) self.validate_segment_rules_conditions_limit(rules_data) # create segment with nested rules and conditions @@ -66,13 +76,16 @@ def create(self, validated_data): self._update_or_create_segment_rules( rules_data, segment=segment, is_create=True ) + self._update_or_create_metadata(metadata_data, segment=segment) return segment def update(self, instance, validated_data): # use the initial data since we need the ids included to determine which to update & which to create rules_data = self.initial_data.pop("rules", []) + metadata_data = validated_data.pop("metadata", []) self.validate_segment_rules_conditions_limit(rules_data) self._update_segment_rules(rules_data, segment=instance) + self._update_or_create_metadata(metadata_data, segment=instance) # remove rules from validated data to prevent error trying to create segment with nested rules del validated_data["rules"] return super().update(instance, validated_data) @@ -156,6 +169,28 @@ def _update_or_create_segment_rules( child_rules, rule=child_rule, is_create=is_create ) + def _update_or_create_metadata( + self, metadata_data: typing.Dict, segment: typing.Optional[Segment] = None + ) -> None: + if len(metadata_data) == 0: + Metadata.objects.filter(object_id=segment.id).delete() + return + if metadata_data is not None: + for metadata_item in metadata_data: + metadata_model_field = metadata_item.pop("model_field", None) + if metadata_item.get("delete"): + Metadata.objects.filter(model_field=metadata_model_field).delete() + continue + + Metadata.objects.update_or_create( + model_field=metadata_model_field, + defaults={ + **metadata_item, + "content_type": ContentType.objects.get_for_model(Segment), + "object_id": segment.id, + }, + ) + @staticmethod def _update_or_create_segment_rule( rule_data: dict, segment: Segment = None, rule: SegmentRule = None diff --git a/api/tests/unit/features/test_unit_features_views.py b/api/tests/unit/features/test_unit_features_views.py index e63ac8cb6e9d..c579d1ab0af9 100644 --- a/api/tests/unit/features/test_unit_features_views.py +++ b/api/tests/unit/features/test_unit_features_views.py @@ -13,6 +13,7 @@ from freezegun import freeze_time from pytest_django import DjangoAssertNumQueries from pytest_django.fixtures import SettingsWrapper +from pytest_lazyfixture import lazy_fixture from pytest_mock import MockerFixture from rest_framework import status from rest_framework.test import APIClient @@ -36,6 +37,7 @@ from features.multivariate.models import MultivariateFeatureOption from features.value_types import BOOLEAN, INTEGER, STRING from features.versioning.models import EnvironmentFeatureVersion +from metadata.models import MetadataModelField from organisations.models import Organisation, OrganisationRole from projects.models import Project, UserProjectPermission from projects.permissions import CREATE_FEATURE, VIEW_PROJECT @@ -2294,6 +2296,138 @@ def test_cannot_update_feature_of_a_feature_state( ) +@pytest.mark.parametrize( + "client", + [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")], +) +def test_create_feature_without_required_metadata_returns_400( + project: Project, + client: APIClient, + required_a_feature_metadata_field: MetadataModelField, +) -> None: + # Given + url = reverse("api-v1:projects:project-features-list", args=[project.id]) + description = "This is the description" + data = { + "name": "Test feature", + "description": description, + } + + # When + response = client.post(url, data=json.dumps(data), content_type="application/json") + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + + +@pytest.mark.parametrize( + "client", + [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")], +) +def test_create_feature_with_optional_metadata_returns_201( + project: Project, + client: APIClient, + optional_b_feature_metadata_field: MetadataModelField, +) -> None: + # Given + url = reverse("api-v1:projects:project-features-list", args=[project.id]) + description = "This is the description" + field_value = 10 + data = { + "name": "Test feature", + "description": description, + "metadata": [ + { + "model_field": optional_b_feature_metadata_field.id, + "field_value": field_value, + }, + ], + } + + # When + response = client.post(url, data=json.dumps(data), content_type="application/json") + + # Then + assert response.status_code == status.HTTP_201_CREATED + assert ( + response.json()["metadata"][0]["model_field"] + == optional_b_feature_metadata_field.id + ) + assert response.json()["metadata"][0]["field_value"] == str(field_value) + + +@pytest.mark.parametrize( + "client", + [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")], +) +def test_create_feature_with_required_metadata_returns_201( + project: Project, + client: APIClient, + required_a_feature_metadata_field: MetadataModelField, +) -> None: + # Given + url = reverse("api-v1:projects:project-features-list", args=[project.id]) + description = "This is the description" + field_value = 10 + data = { + "name": "Test feature", + "description": description, + "metadata": [ + { + "model_field": required_a_feature_metadata_field.id, + "field_value": field_value, + }, + ], + } + + # When + response = client.post(url, data=json.dumps(data), content_type="application/json") + + # Then + assert response.status_code == status.HTTP_201_CREATED + assert ( + response.json()["metadata"][0]["model_field"] + == required_a_feature_metadata_field.id + ) + assert response.json()["metadata"][0]["field_value"] == str(field_value) + + +@pytest.mark.parametrize( + "client", + [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")], +) +def test_create_feature_with_required_metadata_using_organisation_content_typereturns_201( + project: Project, + client: APIClient, + required_a_feature_metadata_field_using_organisation_content_type: MetadataModelField, +) -> None: + # Given + url = reverse("api-v1:projects:project-features-list", args=[project.id]) + description = "This is the description" + field_value = 10 + data = { + "name": "Test feature", + "description": description, + "metadata": [ + { + "model_field": required_a_feature_metadata_field_using_organisation_content_type.id, + "field_value": field_value, + }, + ], + } + + # When + response = client.post(url, data=json.dumps(data), content_type="application/json") + + # Then + assert response.status_code == status.HTTP_201_CREATED + assert ( + response.json()["metadata"][0]["model_field"] + == required_a_feature_metadata_field_using_organisation_content_type.id + ) + assert response.json()["metadata"][0]["field_value"] == str(field_value) + + def test_create_segment_override__using_simple_feature_state_viewset__allows_manage_segment_overrides( staff_client: APIClient, with_environment_permissions: WithEnvironmentPermissionsCallable, @@ -2466,7 +2600,7 @@ def test_list_features_n_plus_1( v1_feature_state.clone(env=environment, version=i, live_from=timezone.now()) # When - with django_assert_num_queries(16): + with django_assert_num_queries(17): response = staff_client.get(url) # Then @@ -2674,7 +2808,7 @@ def test_list_features_with_feature_state( url = f"{base_url}?environment={environment.id}" # When - with django_assert_num_queries(16): + with django_assert_num_queries(17): response = staff_client.get(url) # Then @@ -2968,7 +3102,7 @@ def test_feature_list_last_modified_values( Feature.objects.create(name=f"feature_{i}", project=project) # When - with django_assert_num_queries(18): # TODO: reduce this number of queries! + with django_assert_num_queries(19): # TODO: reduce this number of queries! response = staff_client.get(url) # Then diff --git a/api/tests/unit/metadata/test_views.py b/api/tests/unit/metadata/test_views.py index 18680ce0d07d..2dc750d0d3f5 100644 --- a/api/tests/unit/metadata/test_views.py +++ b/api/tests/unit/metadata/test_views.py @@ -1,11 +1,19 @@ import json +from itertools import chain from django.contrib.contenttypes.models import ContentType from django.urls import reverse from rest_framework import status +from rest_framework.test import APIClient -from metadata.models import MetadataModelField, MetadataModelFieldRequirement -from metadata.views import METADATA_SUPPORTED_MODELS +from metadata.models import ( + MetadataField, + MetadataModelField, + MetadataModelFieldRequirement, +) +from metadata.views import SUPPORTED_REQUIREMENTS_MAPPING +from organisations.models import Organisation +from projects.models import Project def test_can_create_metadata_field(admin_client, organisation): @@ -267,15 +275,14 @@ def test_can_not_update_model_metadata_field_from_other_organisation( assert response.status_code == status.HTTP_404_NOT_FOUND -def test_create_model_metadata_field( - admin_client, - a_metadata_field, - organisation, - environment, - project_content_type, - environment_content_type, - project, -): +def test_create_model_metadata_field_for_environments( + admin_client: APIClient, + a_metadata_field: MetadataField, + organisation: Organisation, + project_content_type: ContentType, + environment_content_type: ContentType, + project: Project, +) -> None: # Given url = reverse( "api-v1:organisations:metadata-model-fields-list", args=[organisation.id] @@ -301,6 +308,72 @@ def test_create_model_metadata_field( } +def test_create_model_metadata_field_for_features( + admin_client: APIClient, + a_metadata_field: MetadataField, + organisation: Organisation, + project_content_type: ContentType, + feature_content_type: ContentType, + project: Project, +) -> None: + # Given + url = reverse( + "api-v1:organisations:metadata-model-fields-list", args=[organisation.id] + ) + data = { + "field": a_metadata_field.id, + "is_required_for": [ + {"content_type": project_content_type.id, "object_id": project.id} + ], + "content_type": feature_content_type.id, + } + + # When + response = admin_client.post( + url, data=json.dumps(data), content_type="application/json" + ) + # Then + assert response.status_code == status.HTTP_201_CREATED + assert response.json()["field"] == a_metadata_field.id + assert response.json()["is_required_for"][0] == { + "content_type": project_content_type.id, + "object_id": project.id, + } + + +def test_create_model_metadata_field_for_segments( + admin_client: APIClient, + a_metadata_field: MetadataField, + organisation: Organisation, + project_content_type: ContentType, + segment_content_type: ContentType, + project: Project, +) -> None: + # Given + url = reverse( + "api-v1:organisations:metadata-model-fields-list", args=[organisation.id] + ) + data = { + "field": a_metadata_field.id, + "is_required_for": [ + {"content_type": project_content_type.id, "object_id": project.id} + ], + "content_type": segment_content_type.id, + } + + # When + response = admin_client.post( + url, data=json.dumps(data), content_type="application/json" + ) + # Then + assert response.status_code == status.HTTP_201_CREATED + assert response.json()["field"] == a_metadata_field.id + assert response.json()["is_required_for"][0] == { + "content_type": project_content_type.id, + "object_id": project.id, + } + + def test_can_not_create_model_metadata_field_using_field_from_other_organisation( admin_client, environment_metadata_field_different_org, organisation, project ): @@ -322,22 +395,31 @@ def test_can_not_create_model_metadata_field_using_field_from_other_organisation assert response.status_code == status.HTTP_403_FORBIDDEN -def test_get_supported_content_type(admin_client, organisation): +def test_get_supported_content_type( + admin_client: APIClient, organisation: Organisation +): # Given url = reverse( "api-v1:organisations:metadata-model-fields-supported-content-types", args=[organisation.id], ) + + supported_models = list( + chain.from_iterable( + (key, *value) for key, value in SUPPORTED_REQUIREMENTS_MAPPING.items() + ) + ) + # When response = admin_client.get(url) # Then assert response.status_code == status.HTTP_200_OK - assert len(response.json()) == len(METADATA_SUPPORTED_MODELS) - assert set(content_type["model"] for content_type in response.json()) == set( - METADATA_SUPPORTED_MODELS - ) + response_models = set(content_type["model"] for content_type in response.json()) + + for model in response_models: + assert model in supported_models def test_get_supported_required_for_models(admin_client, organisation): diff --git a/api/tests/unit/segments/test_unit_segments_views.py b/api/tests/unit/segments/test_unit_segments_views.py index d990bc4dc4f2..74fa7ae22127 100644 --- a/api/tests/unit/segments/test_unit_segments_views.py +++ b/api/tests/unit/segments/test_unit_segments_views.py @@ -15,6 +15,7 @@ from audit.related_object_type import RelatedObjectType from environments.models import Environment from features.models import Feature +from metadata.models import MetadataModelField from projects.models import Project from projects.permissions import MANAGE_SEGMENTS, VIEW_PROJECT from segments.models import Condition, Segment, SegmentRule, WhitelistedSegment @@ -336,8 +337,8 @@ def test_get_segment_by_uuid(client, project, segment): @pytest.mark.parametrize( "client, num_queries", [ - (lazy_fixture("admin_master_api_key_client"), 11), - (lazy_fixture("admin_client"), 10), + (lazy_fixture("admin_master_api_key_client"), 16), + (lazy_fixture("admin_client"), 15), ], ) def test_list_segments(django_assert_num_queries, project, client, num_queries): @@ -608,6 +609,108 @@ def test_update_segment_delete_existing_rule(project, client, segment, segment_r assert segment_rule.conditions.count() == 0 +@pytest.mark.parametrize( + "client", + [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")], +) +def test_create_segment_with_required_metadata_returns_201( + project: Project, + client: APIClient, + required_a_segment_metadata_field: MetadataModelField, +) -> None: + # Given + url = reverse("api-v1:projects:project-segments-list", args=[project.id]) + description = "This is the description" + field_value = 10 + data = { + "name": "Test Segment", + "description": description, + "project": project.id, + "rules": [{"type": "ALL", "rules": [], "conditions": []}], + "metadata": [ + { + "model_field": required_a_segment_metadata_field.id, + "field_value": field_value, + }, + ], + } + + # When + response = client.post(url, data=json.dumps(data), content_type="application/json") + + # Then + assert response.status_code == status.HTTP_201_CREATED + assert ( + response.json()["metadata"][0]["model_field"] + == required_a_segment_metadata_field.id + ) + assert response.json()["metadata"][0]["field_value"] == str(field_value) + + +@pytest.mark.parametrize( + "client", + [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")], +) +def test_create_segment_with_required_metadata_using_organisation_content_type_returns_201( + project: Project, + client: APIClient, + required_a_segment_metadata_field_using_organisation_content_type: MetadataModelField, +) -> None: + # Given + url = reverse("api-v1:projects:project-segments-list", args=[project.id]) + description = "This is the description" + field_value = 10 + data = { + "name": "Test Segment", + "description": description, + "project": project.id, + "rules": [{"type": "ALL", "rules": [], "conditions": []}], + "metadata": [ + { + "model_field": required_a_segment_metadata_field_using_organisation_content_type.id, + "field_value": field_value, + }, + ], + } + + # When + response = client.post(url, data=json.dumps(data), content_type="application/json") + + # Then + assert response.status_code == status.HTTP_201_CREATED + assert ( + response.json()["metadata"][0]["model_field"] + == required_a_segment_metadata_field_using_organisation_content_type.id + ) + assert response.json()["metadata"][0]["field_value"] == str(field_value) + + +@pytest.mark.parametrize( + "client", + [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")], +) +def test_create_segment_without_required_metadata_returns_400( + project: Project, + client: APIClient, + required_a_segment_metadata_field: MetadataModelField, +) -> None: + # Given + url = reverse("api-v1:projects:project-segments-list", args=[project.id]) + description = "This is the description" + data = { + "name": "Test Segment", + "description": description, + "project": project.id, + "rules": [{"type": "ALL", "rules": [], "conditions": []}], + } + + # When + response = client.post(url, data=json.dumps(data), content_type="application/json") + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_update_segment_obeys_max_conditions( project: Project, admin_client: APIClient, @@ -683,6 +786,44 @@ def test_update_segment_obeys_max_conditions( assert nested_rule.conditions.count() == 1 +@pytest.mark.parametrize( + "client", + [lazy_fixture("admin_master_api_key_client"), lazy_fixture("admin_client")], +) +def test_create_segment_with_optional_metadata_returns_201( + project: Project, + client: APIClient, + optional_b_segment_metadata_field: MetadataModelField, +) -> None: + # Given + url = reverse("api-v1:projects:project-segments-list", args=[project.id]) + description = "This is the description" + field_value = 10 + data = { + "name": "Test Segment", + "description": description, + "project": project.id, + "rules": [{"type": "ALL", "rules": [], "conditions": []}], + "metadata": [ + { + "model_field": optional_b_segment_metadata_field.id, + "field_value": field_value, + }, + ], + } + + # When + response = client.post(url, data=json.dumps(data), content_type="application/json") + + # Then + assert response.status_code == status.HTTP_201_CREATED + assert ( + response.json()["metadata"][0]["model_field"] + == optional_b_segment_metadata_field.id + ) + assert response.json()["metadata"][0]["field_value"] == str(field_value) + + def test_update_segment_evades_max_conditions_when_whitelisted( project: Project, admin_client: APIClient, @@ -809,7 +950,6 @@ def test_create_segment_obeys_max_conditions( assert response.json() == { "segment": "The segment has 11 conditions, which exceeds the maximum condition count of 10." } - assert Segment.objects.count() == 0