From e5058ae01cca1ceb783c38d2eb29c83f07a86a8c Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Mon, 27 Nov 2023 11:51:34 +0000 Subject: [PATCH] fix: Revert to Core API segment evaluation (#3036) --- api/conftest.py | 3 +- .../tests/test_dynamodb_identity_wrapper.py | 3 +- .../identities/tests/test_models.py | 18 +- .../identities/tests/test_views.py | 18 +- .../feature_segments/tests/test_models.py | 3 +- api/segments/models.py | 264 +++++++++++++++--- api/segments/serializers.py | 3 +- api/segments/tests/test_models.py | 46 +++ ...test_environments_views_sdk_environment.py | 3 +- .../test_unit_environments_views.py | 3 +- .../test_unit_import_export_export.py | 4 +- .../unit/projects/test_unit_projects_admin.py | 3 +- api/tests/unit/segments/test_conditions.py | 157 +++++++++++ .../unit/segments/test_unit_segments_views.py | 3 +- api/util/mappers/engine.py | 65 ++--- 15 files changed, 480 insertions(+), 116 deletions(-) create mode 100644 api/segments/tests/test_models.py create mode 100644 api/tests/unit/segments/test_conditions.py diff --git a/api/conftest.py b/api/conftest.py index 71ea7d54d92d..146788a42ce0 100644 --- a/api/conftest.py +++ b/api/conftest.py @@ -3,7 +3,6 @@ import pytest from django.contrib.contenttypes.models import ContentType from django.core.cache import cache -from flag_engine.segments.constants import EQUAL from rest_framework.authtoken.models import Token from rest_framework.test import APIClient @@ -47,7 +46,7 @@ ) from projects.permissions import VIEW_PROJECT from projects.tags.models import Tag -from segments.models import Condition, Segment, SegmentRule +from segments.models import EQUAL, Condition, Segment, SegmentRule from task_processor.task_run_method import TaskRunMethod from users.models import FFAdminUser, UserPermissionGroup diff --git a/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py b/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py index f1ca7411506b..262b83f1d7c7 100644 --- a/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py +++ b/api/environments/dynamodb/tests/test_dynamodb_identity_wrapper.py @@ -5,13 +5,12 @@ from core.constants import INTEGER from django.core.exceptions import ObjectDoesNotExist from flag_engine.identities.builders import build_identity_model -from flag_engine.segments.constants import IN from rest_framework.exceptions import NotFound from environments.dynamodb import DynamoIdentityWrapper from environments.identities.models import Identity from environments.identities.traits.models import Trait -from segments.models import Condition, Segment, SegmentRule +from segments.models import IN, Condition, Segment, SegmentRule from util.mappers import ( map_environment_to_environment_document, map_identity_to_identity_document, diff --git a/api/environments/identities/tests/test_models.py b/api/environments/identities/tests/test_models.py index 9e0691b22f12..3ec3bad9f5a6 100644 --- a/api/environments/identities/tests/test_models.py +++ b/api/environments/identities/tests/test_models.py @@ -1,13 +1,6 @@ import pytest from core.constants import FLOAT from django.utils import timezone -from flag_engine.segments.constants import ( - EQUAL, - GREATER_THAN, - GREATER_THAN_INCLUSIVE, - LESS_THAN_INCLUSIVE, - NOT_EQUAL, -) from rest_framework.test import APITestCase from environments.identities.models import Identity @@ -22,7 +15,16 @@ from features.value_types import BOOLEAN, INTEGER, STRING from organisations.models import Organisation from projects.models import Project -from segments.models import Condition, Segment, SegmentRule +from segments.models import ( + EQUAL, + GREATER_THAN, + GREATER_THAN_INCLUSIVE, + LESS_THAN_INCLUSIVE, + NOT_EQUAL, + Condition, + Segment, + SegmentRule, +) from .helpers import ( create_trait_for_identity, diff --git a/api/environments/identities/tests/test_views.py b/api/environments/identities/tests/test_views.py index 146de7fe4e93..6086c707b002 100644 --- a/api/environments/identities/tests/test_views.py +++ b/api/environments/identities/tests/test_views.py @@ -4,10 +4,9 @@ from unittest.case import TestCase import pytest -from core.constants import FLAGSMITH_UPDATED_AT_HEADER, STRING +from core.constants import FLAGSMITH_UPDATED_AT_HEADER from django.test import override_settings from django.urls import reverse -from flag_engine.segments.constants import PERCENTAGE_SPLIT from rest_framework import status from rest_framework.test import APIClient, APITestCase @@ -21,6 +20,7 @@ from integrations.amplitude.models import AmplitudeConfiguration from organisations.models import Organisation, OrganisationRole from projects.models import Project +from segments import models from segments.models import Condition, Segment, SegmentRule from util.tests import Helper @@ -370,7 +370,7 @@ def test_identities_endpoint_returns_traits(self, mock_amplitude_wrapper): trait = Trait.objects.create( identity=self.identity, trait_key="trait_key", - value_type=STRING, + value_type="STRING", string_value="trait_value", ) @@ -422,7 +422,7 @@ def test_identities_endpoint_returns_value_for_segment_if_identity_in_segment( Trait.objects.create( identity=self.identity, trait_key=trait_key, - value_type=STRING, + value_type="STRING", string_value=trait_value, ) segment = Segment.objects.create(name="Test Segment", project=self.project) @@ -476,7 +476,7 @@ def test_identities_endpoint_returns_value_for_segment_if_identity_in_segment_an Trait.objects.create( identity=self.identity, trait_key=trait_key, - value_type=STRING, + value_type="STRING", string_value=trait_value, ) segment = Segment.objects.create(name="Test Segment", project=self.project) @@ -528,7 +528,7 @@ def test_identities_endpoint_returns_value_for_segment_if_rule_type_percentage_s [segment.id, self.identity.id] ) Condition.objects.create( - operator=PERCENTAGE_SPLIT, + operator=models.PERCENTAGE_SPLIT, value=(identity_percentage_value + (1 - identity_percentage_value) / 2) * 100.0, rule=segment_rule, @@ -575,7 +575,7 @@ def test_identities_endpoint_returns_default_value_if_rule_type_percentage_split [segment.id, self.identity.id] ) Condition.objects.create( - operator=PERCENTAGE_SPLIT, + operator=models.PERCENTAGE_SPLIT, value=identity_percentage_value / 2, rule=segment_rule, ) @@ -628,13 +628,13 @@ def test_post_identify_deletes_a_trait_if_trait_value_is_none(self): trait_1 = Trait.objects.create( identity=self.identity, trait_key="trait_key_1", - value_type=STRING, + value_type="STRING", string_value="trait_value", ) trait_2 = Trait.objects.create( identity=self.identity, trait_key="trait_key_2", - value_type=STRING, + value_type="STRING", string_value="trait_value", ) diff --git a/api/features/feature_segments/tests/test_models.py b/api/features/feature_segments/tests/test_models.py index 34222c6f3f3b..2ff952d1a663 100644 --- a/api/features/feature_segments/tests/test_models.py +++ b/api/features/feature_segments/tests/test_models.py @@ -1,7 +1,6 @@ import pytest from core.constants import STRING from django.test import TestCase -from flag_engine.segments.constants import EQUAL from environments.identities.models import Identity from environments.identities.traits.models import Trait @@ -9,7 +8,7 @@ from features.models import Feature, FeatureSegment from organisations.models import Organisation from projects.models import Project -from segments.models import Condition, Segment, SegmentRule +from segments.models import EQUAL, Condition, Segment, SegmentRule @pytest.mark.django_db diff --git a/api/segments/models.py b/api/segments/models.py index 3bdc7d8399a1..a49e73b94ebc 100644 --- a/api/segments/models.py +++ b/api/segments/models.py @@ -2,6 +2,8 @@ import typing from copy import deepcopy +import semver +from core.constants import BOOLEAN, FLOAT, INTEGER from core.models import ( AbstractBaseExportableModel, SoftDeleteExportableModel, @@ -9,18 +11,15 @@ ) from django.core.exceptions import ValidationError from django.db import models -from flag_engine.segments import constants -from flag_engine.segments.evaluator import evaluate_identity_in_segment +from flag_engine.utils.semver import is_semver, remove_semver_suffix from audit.constants import SEGMENT_CREATED_MESSAGE, SEGMENT_UPDATED_MESSAGE from audit.related_object_type import RelatedObjectType +from environments.identities.helpers import ( + get_hashed_percentage_for_object_ids, +) from features.models import Feature from projects.models import Project -from util.mappers.engine import ( - map_identity_to_engine, - map_segment_to_engine, - map_traits_to_trait_models, -) if typing.TYPE_CHECKING: from environments.identities.models import Identity @@ -29,6 +28,30 @@ logger = logging.getLogger(__name__) +try: + import re2 as re + + logger.info("Using re2 library for regex.") +except ImportError: + logger.warning("Unable to import re2. Falling back to re.") + import re + +# Condition Types +EQUAL = "EQUAL" +GREATER_THAN = "GREATER_THAN" +LESS_THAN = "LESS_THAN" +LESS_THAN_INCLUSIVE = "LESS_THAN_INCLUSIVE" +CONTAINS = "CONTAINS" +GREATER_THAN_INCLUSIVE = "GREATER_THAN_INCLUSIVE" +NOT_CONTAINS = "NOT_CONTAINS" +NOT_EQUAL = "NOT_EQUAL" +REGEX = "REGEX" +PERCENTAGE_SPLIT = "PERCENTAGE_SPLIT" +MODULO = "MODULO" +IS_SET = "IS_SET" +IS_NOT_SET = "IS_NOT_SET" +IN = "IN" + class Segment( SoftDeleteExportableModel, @@ -86,18 +109,9 @@ def id_exists_in_rules_data(rules_data: typing.List[dict]) -> bool: def does_identity_match( self, identity: "Identity", traits: typing.List["Trait"] = None ) -> bool: - segment_model = map_segment_to_engine(self) - identity_model = map_identity_to_engine( - identity, - with_overrides=False, - with_traits=not traits, - ) - trait_models = map_traits_to_trait_models(traits) if traits else None - - return evaluate_identity_in_segment( - identity=identity_model, - segment=segment_model, - override_traits=trait_models, + rules = self.rules.all() + return rules.count() > 0 and all( + rule.does_identity_match(identity, traits) for rule in rules ) def get_create_log_message(self, history_instance) -> typing.Optional[str]: @@ -141,6 +155,34 @@ def __str__(self): str(self.segment) if self.segment else str(self.rule), ) + def does_identity_match( + self, identity: "Identity", traits: typing.List["Trait"] = None + ) -> bool: + matches_conditions = False + conditions = self.conditions.all() + + if conditions.count() == 0: + matches_conditions = True + elif self.type == self.ALL_RULE: + matches_conditions = all( + condition.does_identity_match(identity, traits) + for condition in conditions + ) + elif self.type == self.ANY_RULE: + matches_conditions = any( + condition.does_identity_match(identity, traits) + for condition in conditions + ) + elif self.type == self.NONE_RULE: + matches_conditions = not any( + condition.does_identity_match(identity, traits) + for condition in conditions + ) + + return matches_conditions and all( + rule.does_identity_match(identity, traits) for rule in self.rules.all() + ) + def get_segment(self): """ rules can be a child of a parent rule instead of a segment, this method iterates back up the tree to find the @@ -161,20 +203,20 @@ class Condition( related_object_type = RelatedObjectType.SEGMENT CONDITION_TYPES = ( - (constants.EQUAL, "Exactly Matches"), - (constants.GREATER_THAN, "Greater than"), - (constants.LESS_THAN, "Less than"), - (constants.CONTAINS, "Contains"), - (constants.GREATER_THAN_INCLUSIVE, "Greater than or equal to"), - (constants.LESS_THAN_INCLUSIVE, "Less than or equal to"), - (constants.NOT_CONTAINS, "Does not contain"), - (constants.NOT_EQUAL, "Does not match"), - (constants.REGEX, "Matches regex"), - (constants.PERCENTAGE_SPLIT, "Percentage split"), - (constants.MODULO, "Modulo Operation"), - (constants.IS_SET, "Is set"), - (constants.IS_NOT_SET, "Is not set"), - (constants.IN, "In"), + (EQUAL, "Exactly Matches"), + (GREATER_THAN, "Greater than"), + (LESS_THAN, "Less than"), + (CONTAINS, "Contains"), + (GREATER_THAN_INCLUSIVE, "Greater than or equal to"), + (LESS_THAN_INCLUSIVE, "Less than or equal to"), + (NOT_CONTAINS, "Does not contain"), + (NOT_EQUAL, "Does not match"), + (REGEX, "Matches regex"), + (PERCENTAGE_SPLIT, "Percentage split"), + (MODULO, "Modulo Operation"), + (IS_SET, "Is set"), + (IS_NOT_SET, "Is not set"), + (IN, "In"), ) operator = models.CharField(choices=CONDITION_TYPES, max_length=500) @@ -199,6 +241,162 @@ def __str__(self): self.value, ) + def does_identity_match( # noqa: C901 + self, identity: "Identity", traits: typing.List["Trait"] = None + ) -> bool: + if self.operator == PERCENTAGE_SPLIT: + return self._check_percentage_split_operator(identity) + + # we allow passing in traits to handle when they aren't + # persisted for certain organisations + traits = identity.identity_traits.all() if traits is None else traits + matching_trait = next( + filter(lambda t: t.trait_key == self.property, traits), None + ) + if matching_trait is None: + return self.operator == IS_NOT_SET + + if self.operator in (IS_SET, IS_NOT_SET): + return self.operator == IS_SET + elif self.operator == MODULO: + if matching_trait.value_type in [INTEGER, FLOAT]: + return self._check_modulo_operator(matching_trait.trait_value) + elif self.operator == IN: + return str(matching_trait.trait_value) in self.value.split(",") + elif matching_trait.value_type == INTEGER: + return self.check_integer_value(matching_trait.integer_value) + elif matching_trait.value_type == FLOAT: + return self.check_float_value(matching_trait.float_value) + elif matching_trait.value_type == BOOLEAN: + return self.check_boolean_value(matching_trait.boolean_value) + elif is_semver(self.value): + return self.check_semver_value(matching_trait.string_value) + + return self.check_string_value(matching_trait.string_value) + + def _check_percentage_split_operator(self, identity): + try: + float_value = float(self.value) / 100.0 + except ValueError: + return False + + segment = self.rule.get_segment() + return ( + get_hashed_percentage_for_object_ids( + object_ids=[segment.id, identity.get_hash_key()] + ) + <= float_value + ) + + def _check_modulo_operator(self, value: typing.Union[int, float]) -> bool: + try: + divisor, remainder = self.value.split("|") + divisor = float(divisor) + remainder = float(remainder) + except ValueError: + return False + + return value % divisor == remainder + + def check_integer_value(self, value: int) -> bool: + try: + int_value = int(str(self.value)) + except ValueError: + return False + + if self.operator == EQUAL: + return value == int_value + elif self.operator == GREATER_THAN: + return value > int_value + elif self.operator == GREATER_THAN_INCLUSIVE: + return value >= int_value + elif self.operator == LESS_THAN: + return value < int_value + elif self.operator == LESS_THAN_INCLUSIVE: + return value <= int_value + elif self.operator == NOT_EQUAL: + return value != int_value + + return False + + def check_float_value(self, value: float) -> bool: + try: + float_value = float(str(self.value)) + except ValueError: + return False + + if self.operator == EQUAL: + return value == float_value + elif self.operator == GREATER_THAN: + return value > float_value + elif self.operator == GREATER_THAN_INCLUSIVE: + return value >= float_value + elif self.operator == LESS_THAN: + return value < float_value + elif self.operator == LESS_THAN_INCLUSIVE: + return value <= float_value + elif self.operator == NOT_EQUAL: + return value != float_value + + return False + + def check_boolean_value(self, value: bool) -> bool: + if self.value in ("False", "false", "0"): + bool_value = False + elif self.value in ("True", "true", "1"): + bool_value = True + else: + return False + + if self.operator == EQUAL: + return value == bool_value + elif self.operator == NOT_EQUAL: + return value != bool_value + + return False + + def check_semver_value(self, value: str) -> bool: + try: + condition_version_info = semver.VersionInfo.parse( + remove_semver_suffix(self.value) + ) + except ValueError: + return False + + if self.operator == EQUAL: + return value == condition_version_info + elif self.operator == GREATER_THAN: + return value > condition_version_info + elif self.operator == GREATER_THAN_INCLUSIVE: + return value >= condition_version_info + elif self.operator == LESS_THAN: + return value < condition_version_info + elif self.operator == LESS_THAN_INCLUSIVE: + return value <= condition_version_info + elif self.operator == NOT_EQUAL: + return value != condition_version_info + + return False + + def check_string_value(self, value: str) -> bool: + try: + str_value = str(self.value) + except ValueError: + return False + + if self.operator == EQUAL: + return value == str_value + elif self.operator == NOT_EQUAL: + return value != str_value + elif self.operator == CONTAINS: + return str_value in value + elif self.operator == NOT_CONTAINS: + return str_value not in value + elif self.operator == REGEX: + return re.compile(str(self.value)).match(value) is not None + + return False + def get_update_log_message(self, history_instance) -> typing.Optional[str]: return f"Condition updated on segment '{self._get_segment().name}'." diff --git a/api/segments/serializers.py b/api/segments/serializers.py index 28966e65fd7c..29a99c292e25 100644 --- a/api/segments/serializers.py +++ b/api/segments/serializers.py @@ -1,13 +1,12 @@ import typing -from flag_engine.segments.constants import PERCENTAGE_SPLIT from rest_framework import serializers from rest_framework.exceptions import ValidationError from rest_framework.serializers import ListSerializer from rest_framework_recursive.fields import RecursiveField from projects.models import Project -from segments.models import Condition, Segment, SegmentRule +from segments.models import PERCENTAGE_SPLIT, Condition, Segment, SegmentRule class ConditionSerializer(serializers.ModelSerializer): diff --git a/api/segments/tests/test_models.py b/api/segments/tests/test_models.py new file mode 100644 index 000000000000..7c49ad1e263e --- /dev/null +++ b/api/segments/tests/test_models.py @@ -0,0 +1,46 @@ +from unittest import TestCase + +import pytest + +from environments.identities.models import Identity +from environments.models import Environment +from organisations.models import Organisation +from projects.models import Project +from segments.models import PERCENTAGE_SPLIT, Condition, Segment, SegmentRule + + +@pytest.mark.django_db +class SegmentRuleTest(TestCase): + def setUp(self) -> None: + self.organisation = Organisation.objects.create(name="Test Org") + self.project = Project.objects.create( + name="Test Project", organisation=self.organisation + ) + self.environment = Environment.objects.create( + name="Test Environment", project=self.project + ) + self.identity = Identity.objects.create( + environment=self.environment, identifier="test_identity" + ) + self.segment = Segment.objects.create(project=self.project, name="test_segment") + + def test_get_segment_returns_parent_segment_for_nested_rule(self): + # Given + parent_rule = SegmentRule.objects.create( + segment=self.segment, type=SegmentRule.ALL_RULE + ) + child_rule = SegmentRule.objects.create( + rule=parent_rule, type=SegmentRule.ALL_RULE + ) + grandchild_rule = SegmentRule.objects.create( + rule=child_rule, type=SegmentRule.ALL_RULE + ) + Condition.objects.create( + operator=PERCENTAGE_SPLIT, value=0.1, rule=grandchild_rule + ) + + # When + segment = grandchild_rule.get_segment() + + # Then + assert segment == self.segment diff --git a/api/tests/unit/environments/test_environments_views_sdk_environment.py b/api/tests/unit/environments/test_environments_views_sdk_environment.py index 5857fa66b895..9dc3983b1466 100644 --- a/api/tests/unit/environments/test_environments_views_sdk_environment.py +++ b/api/tests/unit/environments/test_environments_views_sdk_environment.py @@ -1,12 +1,11 @@ from core.constants import FLAGSMITH_UPDATED_AT_HEADER from django.urls import reverse -from flag_engine.segments.constants import EQUAL from rest_framework import status from rest_framework.test import APIClient from environments.models import Environment, EnvironmentAPIKey from features.models import Feature -from segments.models import Condition, Segment, SegmentRule +from segments.models import EQUAL, Condition, Segment, SegmentRule def test_get_environment_document( diff --git a/api/tests/unit/environments/test_unit_environments_views.py b/api/tests/unit/environments/test_unit_environments_views.py index 2c903e1c463b..d6614c2f0c85 100644 --- a/api/tests/unit/environments/test_unit_environments_views.py +++ b/api/tests/unit/environments/test_unit_environments_views.py @@ -6,7 +6,6 @@ from core.constants import STRING from django.contrib.contenttypes.models import ContentType from django.urls import reverse -from flag_engine.segments.constants import EQUAL from pytest_lazyfixture import lazy_fixture from pytest_mock import MockerFixture from rest_framework import status @@ -27,7 +26,7 @@ UserProjectPermission, ) from projects.permissions import CREATE_ENVIRONMENT, VIEW_PROJECT -from segments.models import Condition, SegmentRule +from segments.models import EQUAL, Condition, SegmentRule from users.models import FFAdminUser from util.tests import Helper diff --git a/api/tests/unit/import_export/test_unit_import_export_export.py b/api/tests/unit/import_export/test_unit_import_export_export.py index 19b48a50f6db..509734db1ef5 100644 --- a/api/tests/unit/import_export/test_unit_import_export_export.py +++ b/api/tests/unit/import_export/test_unit_import_export_export.py @@ -7,7 +7,7 @@ from django.contrib.contenttypes.models import ContentType from django.core.management import call_command from django.core.serializers.json import DjangoJSONEncoder -from flag_engine.segments.constants import ALL_RULE, EQUAL +from flag_engine.segments.constants import ALL_RULE from moto import mock_s3 from environments.models import Environment, EnvironmentAPIKey, Webhook @@ -42,7 +42,7 @@ from organisations.models import Organisation, OrganisationWebhook from projects.models import Project from projects.tags.models import Tag -from segments.models import Condition, Segment, SegmentRule +from segments.models import EQUAL, Condition, Segment, SegmentRule def test_export_organisation(db): diff --git a/api/tests/unit/projects/test_unit_projects_admin.py b/api/tests/unit/projects/test_unit_projects_admin.py index e1f735f576e1..0be0f847cad3 100644 --- a/api/tests/unit/projects/test_unit_projects_admin.py +++ b/api/tests/unit/projects/test_unit_projects_admin.py @@ -3,13 +3,12 @@ import pytest from django.contrib.admin import AdminSite -from flag_engine.segments.constants import EQUAL from environments.models import Environment from features.models import Feature, FeatureSegment, FeatureState from projects.admin import ProjectAdmin from projects.models import Project -from segments.models import Condition, Segment, SegmentRule +from segments.models import EQUAL, Condition, Segment, SegmentRule if typing.TYPE_CHECKING: from django.contrib.auth.models import AbstractUser diff --git a/api/tests/unit/segments/test_conditions.py b/api/tests/unit/segments/test_conditions.py new file mode 100644 index 000000000000..48191203996d --- /dev/null +++ b/api/tests/unit/segments/test_conditions.py @@ -0,0 +1,157 @@ +import pytest +from core.constants import INTEGER, STRING + +from environments.identities.traits.models import Trait +from segments.models import ( + EQUAL, + GREATER_THAN, + GREATER_THAN_INCLUSIVE, + IN, + IS_NOT_SET, + IS_SET, + LESS_THAN, + LESS_THAN_INCLUSIVE, + MODULO, + NOT_EQUAL, + Condition, +) + + +@pytest.mark.parametrize( + "operator, trait_value, condition_value, result", + [ + (EQUAL, "1.0.0", "1.0.0:semver", True), + (EQUAL, "1.0.0", "1.0.1:semver", False), + (NOT_EQUAL, "1.0.0", "1.0.0:semver", False), + (NOT_EQUAL, "1.0.0", "1.0.1:semver", True), + (GREATER_THAN, "1.0.1", "1.0.0:semver", True), + (GREATER_THAN, "1.0.0", "1.0.0-beta:semver", True), + (GREATER_THAN, "1.0.1", "1.2.0:semver", False), + (GREATER_THAN, "1.0.1", "1.0.1:semver", False), + (GREATER_THAN, "1.2.4", "1.2.3-pre.2+build.4:semver", True), + (LESS_THAN, "1.0.0", "1.0.1:semver", True), + (LESS_THAN, "1.0.0", "1.0.0:semver", False), + (LESS_THAN, "1.0.1", "1.0.0:semver", False), + (LESS_THAN, "1.0.0-rc.2", "1.0.0-rc.3:semver", True), + (GREATER_THAN_INCLUSIVE, "1.0.1", "1.0.0:semver", True), + (GREATER_THAN_INCLUSIVE, "1.0.1", "1.2.0:semver", False), + (GREATER_THAN_INCLUSIVE, "1.0.1", "1.0.1:semver", True), + (LESS_THAN_INCLUSIVE, "1.0.0", "1.0.1:semver", True), + (LESS_THAN_INCLUSIVE, "1.0.0", "1.0.0:semver", True), + (LESS_THAN_INCLUSIVE, "1.0.1", "1.0.0:semver", False), + ], +) +def test_does_identity_match_for_semver_values( + identity, operator, trait_value, condition_value, result +): + # Given + condition = Condition(operator=operator, property="version", value=condition_value) + traits = [ + Trait( + trait_key="version", + string_value=trait_value, + identity=identity, + ) + ] + # Then + assert condition.does_identity_match(identity, traits) is result + + +@pytest.mark.parametrize( + "trait_value, condition_value, result", + [ + (1, "2|0", False), + (2, "2|0", True), + (3, "2|0", False), + (34.2, "4|3", False), + (35.0, "4|3", True), + ("dummy", "3|0", False), + ("1.0.0", "3|0", False), + (False, "1|3", False), + ], +) +def test_does_identity_match_for_modulo_operator( + identity, trait_value, condition_value, result +): + condition = Condition(operator=MODULO, property="user_id", value=condition_value) + + trait_value_data = Trait.generate_trait_value_data(trait_value) + traits = [Trait(trait_key="user_id", identity=identity, **trait_value_data)] + + assert condition.does_identity_match(identity, traits) is result + + +def test_does_identity_match_is_set_true(identity): + # Given + trait_key = "some_property" + condition = Condition(operator=IS_SET, property=trait_key) + traits = [Trait(trait_key=trait_key, identity=identity)] + + # Then + assert condition.does_identity_match(identity, traits) is True + + +def test_does_identity_match_is_set_false(identity): + # Given + trait_key = "some_property" + condition = Condition(operator=IS_SET, property=trait_key) + traits = [] + + # Then + assert condition.does_identity_match(identity, traits) is False + + +def test_does_identity_match_is_not_set_true(identity): + # Given + trait_key = "some_property" + condition = Condition(operator=IS_NOT_SET, property=trait_key) + traits = [Trait(trait_key=trait_key, identity=identity)] + + # Then + assert condition.does_identity_match(identity, traits) is False + + +def test_does_identity_match_is_not_set_false(identity): + # Given + trait_key = "some_property" + condition = Condition(operator=IS_NOT_SET, property=trait_key) + traits = [] + + # Then + assert condition.does_identity_match(identity, traits) is True + + +@pytest.mark.parametrize( + "condition_value, trait_value_type, trait_string_value, trait_integer_value, expected_result", + ( + ("", STRING, "foo", None, False), + ("foo,bar", STRING, "foo", None, True), + ("foo", STRING, "foo", None, True), + ("1,2,3,4", INTEGER, None, 1, True), + ("", INTEGER, None, 1, False), + ("1", INTEGER, None, 1, True), + ), +) +def test_does_identity_match_in( + identity, + condition_value, + trait_value_type, + trait_string_value, + trait_integer_value, + expected_result, +): + # Given + trait_key = "some_property" + condition = Condition(operator=IN, property=trait_key, value=condition_value) + traits = [ + Trait( + trait_key=trait_key, + identity=identity, + value_type=trait_value_type, + string_value=trait_string_value, + integer_value=trait_integer_value, + ) + ] + + # Then + assert condition.does_identity_match(identity, traits) is expected_result diff --git a/api/tests/unit/segments/test_unit_segments_views.py b/api/tests/unit/segments/test_unit_segments_views.py index e0126b386a1e..ebdbe7c545c3 100644 --- a/api/tests/unit/segments/test_unit_segments_views.py +++ b/api/tests/unit/segments/test_unit_segments_views.py @@ -4,7 +4,6 @@ import pytest from django.contrib.auth import get_user_model from django.urls import reverse -from flag_engine.segments.constants import EQUAL from pytest_lazyfixture import lazy_fixture from rest_framework import status @@ -12,7 +11,7 @@ from audit.related_object_type import RelatedObjectType from environments.models import Environment from features.models import Feature -from segments.models import Condition, Segment, SegmentRule +from segments.models import EQUAL, Condition, Segment, SegmentRule from util.mappers import map_identity_to_identity_document User = get_user_model() diff --git a/api/util/mappers/engine.py b/api/util/mappers/engine.py index fb691a07f806..eb545a425dc5 100644 --- a/api/util/mappers/engine.py +++ b/api/util/mappers/engine.py @@ -48,29 +48,6 @@ ) -def map_traits_to_trait_models(traits: Iterable["Trait"]) -> list[TraitModel]: - return [ - TraitModel(trait_key=trait.trait_key, trait_value=trait.trait_value) - for trait in traits - ] - - -def map_segment_to_engine( - segment: "Segment", -) -> SegmentModel: - segment_rules = segment.rules.all() - - # No reading from ORM past this point! - - return SegmentModel( - id=segment.pk, - name=segment.name, - rules=[ - map_segment_rule_to_engine(segment_rule) for segment_rule in segment_rules - ], - ) - - def map_segment_rule_to_engine( segment_rule: "SegmentRule", ) -> SegmentRuleModel: @@ -190,7 +167,7 @@ def map_environment_to_engine( int, Iterable["SegmentRule"], ] = {segment.pk: segment.rules.all() for segment in project_segments} - project_segment_feature_states_by_segment_id = _get_segment_feature_states( + project_segment_feature_states_by_segment_id = _get_project_segment_feature_states( project_segments, environment.pk, ) @@ -353,30 +330,19 @@ def map_environment_api_key_to_engine( ) -def map_identity_to_engine( - identity: "Identity", - *, - with_overrides: bool = True, - with_traits: bool = True, -) -> IdentityModel: +def map_identity_to_engine(identity: "Identity") -> IdentityModel: environment_api_key = identity.environment.api_key # Read relationships - grab all the data needed from the ORM here. - if with_overrides: - identity_feature_states: List["FeatureState"] = _get_prioritised_feature_states( - identity.identity_features.all(), - ) - multivariate_feature_state_values_by_feature_state_id = { - feature_state.pk: feature_state.multivariate_feature_state_values.all() - for feature_state in identity_feature_states - } - else: - identity_feature_states = [] - multivariate_feature_state_values_by_feature_state_id = {} - - identity_traits: Iterable["Trait"] = ( - identity.identity_traits.all() if with_traits else [] + identity_feature_states: List["FeatureState"] = _get_prioritised_feature_states( + identity.identity_features.all(), ) + multivariate_feature_state_values_by_feature_state_id = { + feature_state.pk: feature_state.multivariate_feature_state_values.all() + for feature_state in identity_feature_states + } + + identity_traits: Iterable["Trait"] = identity.identity_traits.all() # Prepare relationships. identity_feature_state_models = [ @@ -386,7 +352,10 @@ def map_identity_to_engine( ) for feature_state in identity_feature_states ] - identity_trait_models = map_traits_to_trait_models(identity_traits) + identity_trait_models = [ + TraitModel(trait_key=trait.trait_key, trait_value=trait.trait_value) + for trait in identity_traits + ] return IdentityModel( # Attributes: @@ -419,12 +388,12 @@ def _get_prioritised_feature_states( return list(prioritised_feature_state_by_feature_id.values()) -def _get_segment_feature_states( - segments: Iterable["Segment"], +def _get_project_segment_feature_states( + project_segments: Iterable["Segment"], environment_id: int, ) -> Dict[int, List["FeatureState"]]: feature_states_by_segment_id = {} - for segment in segments: + for segment in project_segments: segment_feature_states = feature_states_by_segment_id.setdefault(segment.pk, []) for feature_segment in segment.feature_segments.all(): if feature_segment.environment_id != environment_id: